for some two other files weren't added to commit

This commit is contained in:
2019-12-01 15:05:48 -08:00
parent a2ef842ef8
commit 5d14171631
2 changed files with 14 additions and 21 deletions

View File

@@ -76,17 +76,15 @@ cdef public void f_idx_list_to_print(float* f_idxs, size_t num):
cdef public void cbow_batch(
float* X, float* y, float* idxs, size_t bs, size_t win
float* batch, size_t bs, size_t win
):
idxs_np = np.asarray(<float[:bs + 2*win]>idxs)
batch_np = np.asarray(<float[:bs,:2*win+1]>batch)
# Deal with X
X_np = np.asarray(<float[:bs,:2*win]>X)
for r in range(bs):
X_np[r, :win] = idxs_np[r:r+win]
X_np[r, win:] = idxs_np[r+win+1:r+win+1+win]
# Deal with y
nn.onehot(np.asarray(<float[:bs, :len(nn.vocab)]>y), idxs_np[win:-win])
X_np = np.concatenate([batch_np[:, :win], batch_np[:, win+1:]], axis=1)
y_np = nn.onehot(batch_np[:, win], nc=len(nn.vocab))
eprint(batch_np)
eprint(X_np)
eprint(np.argmax(y_np, axis=1))
cdef public void debug_print(object o):
@@ -102,13 +100,11 @@ cdef public void set_net_weights(object net, WeightList* wl):
cdef public void step_net(
object net, float* X, float* y, size_t batch_size
object net, float* batch, size_t bs
):
in_shape = (batch_size,) + net.input_shape[1:]
out_shape = (batch_size,) + net.output_shape[1:]
X_train = np.asarray(<float[:np.prod(in_shape)]>X).reshape(in_shape)
y_train = np.asarray(<float[:np.prod(out_shape)]>y).reshape(out_shape),
# X_train, y_train = cbow_batch(net, batch, bs)
X_train = None
y_train = None
net.train_on_batch(X_train, y_train)