network trains and tests but slow? -- investigate
This commit is contained in:
51
bridge.pyx
51
bridge.pyx
@@ -9,8 +9,9 @@ from libc.string cimport memcpy
|
||||
import library as nn
|
||||
|
||||
|
||||
X_train, y_train, X_test, y_test = nn.load_mnist()
|
||||
tokenizers = {}
|
||||
X_test = None
|
||||
y_test = None
|
||||
|
||||
|
||||
cdef extern from "numpy/arrayobject.h":
|
||||
@@ -74,20 +75,18 @@ cdef public void f_idx_list_to_print(float* f_idxs, size_t num):
|
||||
# return retval
|
||||
|
||||
|
||||
cdef public void c_onehot(float* y, float* idxs, size_t n_idx):
|
||||
oh = nn.onehot(np.asarray(<float[:n_idx]>idxs), nc=len(nn.vocab))
|
||||
ensure_contiguous(oh)
|
||||
memcpy(y, PyArray_DATA(oh), oh.size * sizeof(float))
|
||||
# eprint(np.argmax(oh, axis=1))
|
||||
|
||||
|
||||
cdef public void c_slices(float* X, float* idxs, size_t bs, size_t win):
|
||||
X_np = np.asarray(<float[:bs,:2*win]>X)
|
||||
cdef public void cbow_batch(
|
||||
float* X, float* y, float* idxs, size_t bs, size_t win
|
||||
):
|
||||
idxs_np = np.asarray(<float[:bs + 2*win]>idxs)
|
||||
# Deal with X
|
||||
X_np = np.asarray(<float[:bs,:2*win]>X)
|
||||
for r in range(bs):
|
||||
X_np[r, :win] = idxs_np[r:r+win]
|
||||
X_np[r, win:] = idxs_np[r+win+1:r+win+1+win]
|
||||
# eprint(X_np)
|
||||
|
||||
# Deal with y
|
||||
nn.onehot(np.asarray(<float[:bs, :len(nn.vocab)]>y), idxs_np[win:-win])
|
||||
|
||||
|
||||
cdef public void debug_print(object o):
|
||||
@@ -121,26 +120,6 @@ cdef public float eval_net(object net):
|
||||
return net.evaluate(X_test, y_test, verbose=False)
|
||||
|
||||
|
||||
cdef public void mnist_batch(float* X, float* y, size_t bs,
|
||||
int part, int total):
|
||||
if total == 0:
|
||||
X_pool, y_pool = X_train, y_train
|
||||
else:
|
||||
partsize = len(X_train) // total
|
||||
X_pool = X_train[part*partsize:(part+1)*partsize]
|
||||
y_pool = y_train[part*partsize:(part+1)*partsize]
|
||||
|
||||
idx = np.random.choice(len(X_pool), bs, replace=True)
|
||||
|
||||
X_r = X_pool[idx]
|
||||
y_r = y_pool[idx]
|
||||
|
||||
assert X_r.flags['C_CONTIGUOUS']
|
||||
assert y_r.flags['C_CONTIGUOUS']
|
||||
memcpy(X, PyArray_DATA(X_r), X_r.size * sizeof(float))
|
||||
memcpy(y, PyArray_DATA(y_r), y_r.size * sizeof(float))
|
||||
|
||||
|
||||
cdef public void init_weightlist_like(WeightList* wl, object net):
|
||||
weights = net.get_weights()
|
||||
wl.n_weights = len(weights)
|
||||
@@ -177,6 +156,10 @@ cdef public void combo_weights(
|
||||
wf += alpha * ww
|
||||
|
||||
|
||||
cdef public void create_test_dataset(size_t win):
|
||||
_create_test_dataset(win)
|
||||
|
||||
|
||||
cdef list wrap_weight_list(WeightList* wl):
|
||||
weights = []
|
||||
for i in range(wl.n_weights):
|
||||
@@ -221,3 +204,9 @@ def ensure_contiguous(a):
|
||||
|
||||
def eprint(*args, **kwargs):
|
||||
return print(*args, flush=True, **kwargs)
|
||||
|
||||
|
||||
def _create_test_dataset(win):
|
||||
global X_test, y_test
|
||||
if X_test is None or y_test is None:
|
||||
X_test, y_test = nn.create_test_dataset(win)
|
||||
|
||||
Reference in New Issue
Block a user