commit small improvements before going craZy

This commit is contained in:
2019-11-25 20:54:27 -08:00
parent 04e0b9829c
commit 253c833b5b
2 changed files with 11 additions and 10 deletions

View File

@@ -7,6 +7,7 @@ from libc.stdlib cimport malloc
ctr = [] ctr = []
X_train, y_train, X_test, y_test = mn.load_mnist() X_train, y_train, X_test, y_test = mn.load_mnist()
opt = mn.SGDOptimizer(lr=0.1)
cdef extern from "numpy/arrayobject.h": cdef extern from "numpy/arrayobject.h":
@@ -60,10 +61,13 @@ cdef public object make_like(object neta, object netb):
cdef public void step_net( cdef public void step_net(
object net, object net,
np.ndarray[np.float32_t, ndim=2, mode='c'] batch float* batch_data,
Py_ssize_t batch_size
): ):
opt = mn.SGDOptimizer(lr=0.1) cdef Py_ssize_t in_dim = net.geometry[0]
net.step(batch[:, :784], batch[:, 784:], opt) cdef Py_ssize_t out_dim = net.geometry[-1]
batch = np.asarray(<float[:batch_size,:in_dim+out_dim]>batch_data)
net.step(batch[:, :in_dim], batch[:, in_dim:], opt)
cdef public float eval_net( cdef public float eval_net(

11
main.c
View File

@@ -92,14 +92,11 @@ void slave_node() {
MPI_Recv(shape, 2, MPI_LONG, P_READER, MPI_ANY_TAG, MPI_COMM_WORLD, MPI_Recv(shape, 2, MPI_LONG, P_READER, MPI_ANY_TAG, MPI_COMM_WORLD,
MPI_STATUS_IGNORE); MPI_STATUS_IGNORE);
long size = shape[0] * shape[1]; long size = shape[0] * shape[1];
float* data = malloc(shape[0] * shape[1] * sizeof(float)); float* batch = malloc(shape[0] * shape[1] * sizeof(float));
MPI_Recv(data, size, MPI_FLOAT, P_READER, MPI_ANY_TAG, MPI_Recv(batch, size, MPI_FLOAT, P_READER, MPI_ANY_TAG,
MPI_COMM_WORLD, MPI_STATUS_IGNORE); MPI_COMM_WORLD, MPI_STATUS_IGNORE);
PyArrayObject* batch = PyArray_SimpleNewFromData( step_net(net, batch, BS);
2, shape, NPY_FLOAT32, data); free(batch);
step_net(net, batch);
Py_DECREF(batch);
free(data);
} }
Network c_net; Network c_net;
cify_network(net, &c_net); cify_network(net, &c_net);