network trains and tests but slow? -- investigate

This commit is contained in:
2019-12-01 11:26:54 -08:00
parent b115077c3b
commit 4bf66bf85e
3 changed files with 59 additions and 49 deletions

View File

@@ -9,8 +9,9 @@ from libc.string cimport memcpy
import library as nn
# Module-level datasets shared by the C-callable helpers in this file.
# nn.load_mnist() is unpacked into four values: train inputs/labels and
# test inputs/labels.
# NOTE(review): this view is a flattened diff, so this line may belong to the
# pre-change side of the commit — confirm against the real file.
X_train, y_train, X_test, y_test = nn.load_mnist()
# NOTE(review): not used anywhere in the visible chunk; presumably a cache of
# tokenizer objects — confirm against the rest of the file.
tokenizers = {}
# The test split is reset to None here; _create_test_dataset() fills both
# names in via nn.create_test_dataset(win) the first time it is called.
X_test = None
y_test = None
cdef extern from "numpy/arrayobject.h":
@@ -74,20 +75,18 @@ cdef public void f_idx_list_to_print(float* f_idxs, size_t num):
# return retval
# One-hot encode `n_idx` indices and copy the result into a C buffer.
#
# y     - output buffer; receives `encoded.size` floats.
# idxs  - input buffer holding `n_idx` indices stored as floats.
# n_idx - number of entries readable from `idxs`.
#
# The number of classes passed to nn.onehot is len(nn.vocab).
cdef public void c_onehot(float* y, float* idxs, size_t n_idx):
    # Wrap the raw pointer so the library can treat it as a 1-D array.
    token_ids = np.asarray(<float[:n_idx]>idxs)
    encoded = nn.onehot(token_ids, nc=len(nn.vocab))
    # memcpy needs a flat, C-ordered buffer; the layout is handled by the
    # ensure_contiguous helper defined elsewhere in this file.
    ensure_contiguous(encoded)
    n_bytes = encoded.size * sizeof(float)
    memcpy(y, PyArray_DATA(encoded), n_bytes)
    # Debug aid: eprint(np.argmax(encoded, axis=1))
# NOTE(review): this view is a flattened diff; c_slices appears to be the
# pre-change definition that cbow_batch (next definition) supersedes, and only
# its first body line is visible here — confirm against the real file.
cdef public void c_slices(float* X, float* idxs, size_t bs, size_t win):
# Expose the caller's X buffer as a bs x (2*win) float array.
X_np = np.asarray(<float[:bs,:2*win]>X)
# Build one CBOW batch directly inside caller-owned C buffers.
#
# X    - output, viewed as bs x (2*win) floats: row r receives the `win`
#        indices before and the `win` indices after position r + win.
# y    - output, viewed as bs x len(nn.vocab) floats and handed to nn.onehot
#        together with the centre indices (presumably filled in place —
#        confirm against the library).
# idxs - input, bs + 2*win indices stored as floats.
# bs   - number of rows (centre positions) in the batch.
# win  - context window size on each side of the centre position.
cdef public void cbow_batch(
    float* X, float* y, float* idxs, size_t bs, size_t win
):
    idxs_np = np.asarray(<float[:bs + 2*win]>idxs)

    # Deal with X: gather every context window with a single vectorised
    # indexing operation instead of a Python-level loop over the rows.
    X_np = np.asarray(<float[:bs, :2*win]>X)
    # Column offsets relative to the start of a row's window; the centre
    # position (offset `win`) is skipped.
    offsets = np.concatenate((np.arange(win), np.arange(win + 1, 2*win + 1)))
    X_np[:, :] = idxs_np[np.arange(bs)[:, None] + offsets]
    # eprint(X_np)

    # Deal with y: the centre indices are idxs[win : win + bs].  An explicit
    # end index is used because a `-win` end collapses to an empty slice
    # when win == 0.
    nn.onehot(
        np.asarray(<float[:bs, :len(nn.vocab)]>y), idxs_np[win:win + bs]
    )
cdef public void debug_print(object o):
@@ -121,26 +120,6 @@ cdef public float eval_net(object net):
return net.evaluate(X_test, y_test, verbose=False)
# Draw a random training batch (sampling with replacement) into C buffers.
#
# X, y  - output buffers; receive the sampled inputs and labels as float32.
# bs    - number of samples to draw.
# part  - index of the shard to sample from when total != 0.
# total - number of equally sized shards X_train is divided into; 0 means
#         sample from the whole training set.
cdef public void mnist_batch(float* X, float* y, size_t bs,
                             int part, int total):
    if total == 0:
        X_pool, y_pool = X_train, y_train
    else:
        # Integer division: any remainder past total * partsize is never
        # sampled.
        partsize = len(X_train) // total
        shard = slice(part * partsize, (part + 1) * partsize)
        X_pool, y_pool = X_train[shard], y_train[shard]

    idx = np.random.choice(len(X_pool), bs, replace=True)

    # The memcpy calls below size the copy with sizeof(float), so the source
    # arrays must be float32 and C-contiguous.  Enforce that explicitly
    # rather than relying on assertions and an unchecked dtype.
    X_r = np.ascontiguousarray(X_pool[idx], dtype=np.float32)
    y_r = np.ascontiguousarray(y_pool[idx], dtype=np.float32)

    memcpy(X, PyArray_DATA(X_r), X_r.size * sizeof(float))
    memcpy(y, PyArray_DATA(y_r), y_r.size * sizeof(float))
cdef public void init_weightlist_like(WeightList* wl, object net):
weights = net.get_weights()
wl.n_weights = len(weights)
@@ -177,6 +156,10 @@ cdef public void combo_weights(
wf += alpha * ww
# C-callable entry point for building the test dataset with window size `win`.
# It only forwards to the plain-Python helper _create_test_dataset, which
# holds the logic and updates the module-level X_test / y_test.
cdef public void create_test_dataset(size_t win):
_create_test_dataset(win)
cdef list wrap_weight_list(WeightList* wl):
weights = []
for i in range(wl.n_weights):
@@ -221,3 +204,9 @@ def ensure_contiguous(a):
def eprint(*args, **kwargs):
    """Print like the built-in ``print`` but always flush the stream.

    All positional and keyword arguments are forwarded to ``print``
    unchanged. ``flush=True`` is always supplied here, so passing ``flush``
    explicitly raises ``TypeError``. Returns whatever ``print`` returns
    (``None``).
    """
    outcome = print(*args, **kwargs, flush=True)
    return outcome
def _create_test_dataset(win):
    """Populate the module-level test split once, on first use.

    Does nothing when both ``X_test`` and ``y_test`` are already set;
    otherwise stores the pair returned by ``nn.create_test_dataset(win)``
    in those two globals.
    """
    global X_test, y_test
    already_built = X_test is not None and y_test is not None
    if already_built:
        return
    X_test, y_test = nn.create_test_dataset(win)