From 253c833b5b1928093f2ca3c730c645252e7ea2c9 Mon Sep 17 00:00:00 2001 From: Pavel Lutskov Date: Mon, 25 Nov 2019 20:54:27 -0800 Subject: [PATCH] commit small improvements before going craZy --- library.pyx | 10 +++++++--- main.c | 11 ++++------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/library.pyx b/library.pyx index e8320ba..6a1ea16 100644 --- a/library.pyx +++ b/library.pyx @@ -7,6 +7,7 @@ from libc.stdlib cimport malloc ctr = [] X_train, y_train, X_test, y_test = mn.load_mnist() +opt = mn.SGDOptimizer(lr=0.1) cdef extern from "numpy/arrayobject.h": @@ -60,10 +61,13 @@ cdef public object make_like(object neta, object netb): cdef public void step_net( object net, - np.ndarray[np.float32_t, ndim=2, mode='c'] batch + float* batch_data, + Py_ssize_t batch_size ): - opt = mn.SGDOptimizer(lr=0.1) - net.step(batch[:, :784], batch[:, 784:], opt) + cdef Py_ssize_t in_dim = net.geometry[0] + cdef Py_ssize_t out_dim = net.geometry[-1] + batch = np.asarray(batch_data) + net.step(batch[:, :in_dim], batch[:, in_dim:], opt) cdef public float eval_net( diff --git a/main.c b/main.c index 02a743b..6196c0a 100644 --- a/main.c +++ b/main.c @@ -92,14 +92,11 @@ void slave_node() { MPI_Recv(shape, 2, MPI_LONG, P_READER, MPI_ANY_TAG, MPI_COMM_WORLD, MPI_STATUS_IGNORE); long size = shape[0] * shape[1]; - float* data = malloc(shape[0] * shape[1] * sizeof(float)); - MPI_Recv(data, size, MPI_FLOAT, P_READER, MPI_ANY_TAG, + float* batch = malloc(shape[0] * shape[1] * sizeof(float)); + MPI_Recv(batch, size, MPI_FLOAT, P_READER, MPI_ANY_TAG, MPI_COMM_WORLD, MPI_STATUS_IGNORE); - PyArrayObject* batch = PyArray_SimpleNewFromData( - 2, shape, NPY_FLOAT32, data); - step_net(net, batch); - Py_DECREF(batch); - free(data); + step_net(net, batch, BS); + free(batch); } Network c_net; cify_network(net, &c_net);