now it kinda learns again and code is kinda clean

This commit is contained in:
2019-12-01 15:44:34 -08:00
parent 5d14171631
commit bc6d34e253
3 changed files with 39 additions and 53 deletions

View File

@@ -75,24 +75,15 @@ cdef public void f_idx_list_to_print(float* f_idxs, size_t num):
# return retval
cdef public void cbow_batch(
float* batch, size_t bs, size_t win
):
batch_np = np.asarray(<float[:bs,:2*win+1]>batch)
# Deal with X
X_np = np.concatenate([batch_np[:, :win], batch_np[:, win+1:]], axis=1)
y_np = nn.onehot(batch_np[:, win], nc=len(nn.vocab))
eprint(batch_np)
eprint(X_np)
eprint(np.argmax(y_np, axis=1))
cdef public void debug_print(object o):
eprint(o)
cdef public object create_network(int win, int embed):
return nn.create_cbow_network(win, embed)
try:
return nn.create_cbow_network(win, embed)
except Exception as e:
eprint(e)
cdef public void set_net_weights(object net, WeightList* wl):
@@ -102,9 +93,7 @@ cdef public void set_net_weights(object net, WeightList* wl):
cdef public void step_net(
object net, float* batch, size_t bs
):
# X_train, y_train = cbow_batch(net, batch, bs)
X_train = None
y_train = None
X_train, y_train = cbow_batch(net, batch, bs)
net.train_on_batch(X_train, y_train)
@@ -113,7 +102,10 @@ cdef public size_t out_size(object net):
cdef public float eval_net(object net):
return net.evaluate(X_test, y_test, verbose=False)
try:
return net.evaluate(X_test, y_test, verbose=False)
except Exception as e:
eprint(e)
cdef public void init_weightlist_like(WeightList* wl, object net):
@@ -156,6 +148,16 @@ cdef public void create_test_dataset(size_t win):
_create_test_dataset(win)
cdef tuple cbow_batch(
object net, float* batch, size_t bs
):
win = net.input_shape[1] // 2
batch_np = np.asarray(<float[:bs,:2*win+1]>batch)
X_np = np.concatenate([batch_np[:, :win], batch_np[:, win+1:]], axis=1)
y_np = nn.onehot(batch_np[:, win], nc=len(nn.vocab))
return X_np, y_np
cdef list wrap_weight_list(WeightList* wl):
weights = []
for i in range(wl.n_weights):