commit small improvements before going craZy
This commit is contained in:
10
library.pyx
10
library.pyx
@@ -7,6 +7,7 @@ from libc.stdlib cimport malloc
|
|||||||
|
|
||||||
ctr = []
|
ctr = []
|
||||||
X_train, y_train, X_test, y_test = mn.load_mnist()
|
X_train, y_train, X_test, y_test = mn.load_mnist()
|
||||||
|
opt = mn.SGDOptimizer(lr=0.1)
|
||||||
|
|
||||||
|
|
||||||
cdef extern from "numpy/arrayobject.h":
|
cdef extern from "numpy/arrayobject.h":
|
||||||
@@ -60,10 +61,13 @@ cdef public object make_like(object neta, object netb):
|
|||||||
|
|
||||||
cdef public void step_net(
|
cdef public void step_net(
|
||||||
object net,
|
object net,
|
||||||
np.ndarray[np.float32_t, ndim=2, mode='c'] batch
|
float* batch_data,
|
||||||
|
Py_ssize_t batch_size
|
||||||
):
|
):
|
||||||
opt = mn.SGDOptimizer(lr=0.1)
|
cdef Py_ssize_t in_dim = net.geometry[0]
|
||||||
net.step(batch[:, :784], batch[:, 784:], opt)
|
cdef Py_ssize_t out_dim = net.geometry[-1]
|
||||||
|
batch = np.asarray(<float[:batch_size,:in_dim+out_dim]>batch_data)
|
||||||
|
net.step(batch[:, :in_dim], batch[:, in_dim:], opt)
|
||||||
|
|
||||||
|
|
||||||
cdef public float eval_net(
|
cdef public float eval_net(
|
||||||
|
|||||||
11
main.c
11
main.c
@@ -92,14 +92,11 @@ void slave_node() {
|
|||||||
MPI_Recv(shape, 2, MPI_LONG, P_READER, MPI_ANY_TAG, MPI_COMM_WORLD,
|
MPI_Recv(shape, 2, MPI_LONG, P_READER, MPI_ANY_TAG, MPI_COMM_WORLD,
|
||||||
MPI_STATUS_IGNORE);
|
MPI_STATUS_IGNORE);
|
||||||
long size = shape[0] * shape[1];
|
long size = shape[0] * shape[1];
|
||||||
float* data = malloc(shape[0] * shape[1] * sizeof(float));
|
float* batch = malloc(shape[0] * shape[1] * sizeof(float));
|
||||||
MPI_Recv(data, size, MPI_FLOAT, P_READER, MPI_ANY_TAG,
|
MPI_Recv(batch, size, MPI_FLOAT, P_READER, MPI_ANY_TAG,
|
||||||
MPI_COMM_WORLD, MPI_STATUS_IGNORE);
|
MPI_COMM_WORLD, MPI_STATUS_IGNORE);
|
||||||
PyArrayObject* batch = PyArray_SimpleNewFromData(
|
step_net(net, batch, BS);
|
||||||
2, shape, NPY_FLOAT32, data);
|
free(batch);
|
||||||
step_net(net, batch);
|
|
||||||
Py_DECREF(batch);
|
|
||||||
free(data);
|
|
||||||
}
|
}
|
||||||
Network c_net;
|
Network c_net;
|
||||||
cify_network(net, &c_net);
|
cify_network(net, &c_net);
|
||||||
|
|||||||
Reference in New Issue
Block a user