commit small improvements before going craZy
This commit is contained in:
10
library.pyx
10
library.pyx
@@ -7,6 +7,7 @@ from libc.stdlib cimport malloc
|
||||
|
||||
ctr = []
|
||||
X_train, y_train, X_test, y_test = mn.load_mnist()
|
||||
opt = mn.SGDOptimizer(lr=0.1)
|
||||
|
||||
|
||||
cdef extern from "numpy/arrayobject.h":
|
||||
@@ -60,10 +61,13 @@ cdef public object make_like(object neta, object netb):
|
||||
|
||||
cdef public void step_net(
|
||||
object net,
|
||||
np.ndarray[np.float32_t, ndim=2, mode='c'] batch
|
||||
float* batch_data,
|
||||
Py_ssize_t batch_size
|
||||
):
|
||||
opt = mn.SGDOptimizer(lr=0.1)
|
||||
net.step(batch[:, :784], batch[:, 784:], opt)
|
||||
cdef Py_ssize_t in_dim = net.geometry[0]
|
||||
cdef Py_ssize_t out_dim = net.geometry[-1]
|
||||
batch = np.asarray(<float[:batch_size,:in_dim+out_dim]>batch_data)
|
||||
net.step(batch[:, :in_dim], batch[:, in_dim:], opt)
|
||||
|
||||
|
||||
cdef public float eval_net(
|
||||
|
||||
Reference in New Issue
Block a user