stop refactoring and get some stuff huggin' done

This commit is contained in:
2019-12-12 21:10:41 -08:00
parent 966bbc904c
commit 2bbb1d243c
4 changed files with 97 additions and 68 deletions

View File

@@ -11,8 +11,6 @@ import flask
tokenizers = {}
X_test = None
y_test = None
cdef extern from "numpy/arrayobject.h":
@@ -49,6 +47,30 @@ cdef public void serve():
nn.app.run(port=8448)
cdef public size_t getwin():
return nn.WIN
cdef public size_t getemb():
return nn.EMB
cdef public size_t getbs():
return nn.CFG['bs']
cdef public size_t getbpe():
return nn.CFG['bpe']
cdef public float gettarget():
return nn.CFG['target']
cdef public float getflpc():
return nn.CFG['flpc']
cdef public int get_tokens(WordList* wl, const char *filename):
fnu = filename.decode('utf-8')
if fnu not in tokenizers:
@@ -82,10 +104,8 @@ cdef public void _dbg_print(object o):
eprint(o)
cdef public void _dbg_print_cbow_batch(
object net, float* batch, size_t bs
):
X_np, y_np = cbow_batch(net, batch, bs)
cdef public void _dbg_print_cbow_batch(float* batch, size_t bs):
X_np, y_np = cbow_batch(batch, bs)
eprint(X_np)
eprint(y_np)
@@ -95,9 +115,9 @@ cdef public void randidx(int* idx, size_t l, size_t how_much):
memcpy(idx, PyArray_DATA(i_np), how_much * sizeof(int))
cdef public object create_network(int win, int embed):
cdef public object create_network():
try:
net = nn.create_cbow_network(win, embed)
net = nn.create_cbow_network()
eprint(net)
return net
except Exception as e:
@@ -111,7 +131,7 @@ cdef public void set_net_weights(object net, WeightList* wl):
cdef public void step_net(
object net, float* batch, size_t bs
):
X_train, y_train = cbow_batch(net, batch, bs)
X_train, y_train = cbow_batch(batch, bs)
net.train_on_batch(X_train, y_train)
@@ -120,10 +140,7 @@ cdef public size_t out_size(object net):
cdef public float eval_net(object net):
try:
return net.evaluate(X_test, y_test, verbose=False)
except Exception as e:
eprint(e)
return nn.eval_network(net)
cdef public void init_weightlist_like(WeightList* wl, object net):
@@ -162,14 +179,8 @@ cdef public void combo_weights(
wf += alpha * ww
cdef public void create_test_dataset(size_t win):
_create_test_dataset(win)
cdef tuple cbow_batch(
object net, float* batch, size_t bs
):
win = net.input_shape[1] // 2
cdef tuple cbow_batch(float* batch, size_t bs):
win = nn.WIN
batch_np = np.asarray(<float[:bs,:2*win+1]>batch)
X_np = batch_np[:, [*range(win), *range(win+1, win+win+1)]]
y_np = nn.onehot(batch_np[:, win], nc=len(nn.vocab))
@@ -177,6 +188,7 @@ cdef tuple cbow_batch(
cdef list wrap_weight_list(WeightList* wl):
"""Thinly wraps a WeightList struct into a NumPy array."""
weights = []
for i in range(wl.n_weights):
w_shape = <long[:wl.weights[i].dims]>wl.weights[i].shape
@@ -220,9 +232,3 @@ def ensure_contiguous(a):
def eprint(*args, **kwargs):
return print(*args, flush=True, **kwargs)
def _create_test_dataset(win):
global X_test, y_test
if X_test is None or y_test is None:
X_test, y_test = nn.create_test_dataset(win)