this is the baseline for evaluation

This commit is contained in:
2019-12-11 10:31:16 -08:00
parent 5abe7bb413
commit 7043b65532
3 changed files with 52 additions and 45 deletions

View File

@@ -97,7 +97,9 @@ cdef public void randidx(int* idx, size_t l, size_t how_much):
cdef public object create_network(int win, int embed):
try:
return nn.create_cbow_network(win, embed)
net = nn.create_cbow_network(win, embed)
eprint(net)
return net
except Exception as e:
eprint(e)
@@ -169,7 +171,7 @@ cdef tuple cbow_batch(
):
win = net.input_shape[1] // 2
batch_np = np.asarray(<float[:bs,:2*win+1]>batch)
X_np = np.concatenate([batch_np[:, :win], batch_np[:, win+1:]], axis=1)
X_np = batch_np[:, [*range(win), *range(win+1, win+win+1)]]
y_np = nn.onehot(batch_np[:, win], nc=len(nn.vocab))
return X_np, y_np
@@ -192,7 +194,7 @@ cdef void words_into_wordlist(WordList* wl, list words):
wl.words = <Word*>realloc(wl.words, wl.mem * sizeof(Word))
for i in range(old, wl.mem):
wl.words[i].mem = 0
wl.words[i].data = <char*>0
wl.words[i].data = NULL
wl.n_words = len(words)
for i, w in enumerate(words):