small refactoring and kicked the visualizer again

This commit is contained in:
2019-12-15 09:51:22 -08:00
parent 06d0b2d565
commit 05480606b0
4 changed files with 82 additions and 30 deletions

View File

@@ -47,6 +47,12 @@ cdef public void serve():
srv.serve()
cdef public void server_update(float *emb):
embeddings = np.asarray(<float[:getvocsize(),:getemb()]>emb)
low_dim = nn.calc_TSNE(embeddings)
srv.emb_map = dict(zip(nn.inv_vocab, low_dim))
cdef public size_t getwin():
return nn.WIN
@@ -67,6 +73,10 @@ cdef public float gettarget():
return nn.CFG['target']
cdef public size_t getvocsize():
return len(nn.vocab)
cdef public int get_tokens(WordList* wl, const char *filename):
fnu = filename.decode('utf-8')
if fnu not in tokenizers:
@@ -101,8 +111,8 @@ cdef public void _dbg_print(object o):
eprint(o)
cdef public void _dbg_print_cbow_batch(float* batch, size_t bs):
X_np, y_np = cbow_batch(batch, bs)
cdef public void _dbg_print_cbow_batch(float* batch):
X_np, y_np = cbow_batch(batch)
eprint(X_np)
eprint(y_np)
@@ -124,10 +134,8 @@ cdef public void set_net_weights(object net, WeightList* wl):
net.set_weights(wrap_weight_list(wl))
cdef public void step_net(
object net, float* batch, size_t bs
):
X_train, y_train = cbow_batch(batch, bs)
cdef public void step_net(object net, float* batch):
X_train, y_train = cbow_batch(batch)
net.train_on_batch(X_train, y_train)
@@ -183,8 +191,9 @@ cdef public void combo_weights(
wf += alpha * ww
cdef tuple cbow_batch(float* batch, size_t bs):
win = nn.WIN
cdef tuple cbow_batch(float* batch):
win = getwin()
bs = getbs()
batch_np = np.asarray(<float[:bs,:2*win+1]>batch)
X_np = batch_np[:, [*range(win), *range(win+1, win+win+1)]]
y_np = nn.onehot(batch_np[:, win], nc=len(nn.vocab))