stop refactoring and get some stuff huggin' done
This commit is contained in:
60
bridge.pyx
60
bridge.pyx
@@ -11,8 +11,6 @@ import flask
|
||||
|
||||
|
||||
tokenizers = {}
|
||||
X_test = None
|
||||
y_test = None
|
||||
|
||||
|
||||
cdef extern from "numpy/arrayobject.h":
|
||||
@@ -49,6 +47,30 @@ cdef public void serve():
|
||||
nn.app.run(port=8448)
|
||||
|
||||
|
||||
cdef public size_t getwin():
|
||||
return nn.WIN
|
||||
|
||||
|
||||
cdef public size_t getemb():
|
||||
return nn.EMB
|
||||
|
||||
|
||||
cdef public size_t getbs():
|
||||
return nn.CFG['bs']
|
||||
|
||||
|
||||
cdef public size_t getbpe():
|
||||
return nn.CFG['bpe']
|
||||
|
||||
|
||||
cdef public float gettarget():
|
||||
return nn.CFG['target']
|
||||
|
||||
|
||||
cdef public float getflpc():
|
||||
return nn.CFG['flpc']
|
||||
|
||||
|
||||
cdef public int get_tokens(WordList* wl, const char *filename):
|
||||
fnu = filename.decode('utf-8')
|
||||
if fnu not in tokenizers:
|
||||
@@ -82,10 +104,8 @@ cdef public void _dbg_print(object o):
|
||||
eprint(o)
|
||||
|
||||
|
||||
cdef public void _dbg_print_cbow_batch(
|
||||
object net, float* batch, size_t bs
|
||||
):
|
||||
X_np, y_np = cbow_batch(net, batch, bs)
|
||||
cdef public void _dbg_print_cbow_batch(float* batch, size_t bs):
|
||||
X_np, y_np = cbow_batch(batch, bs)
|
||||
eprint(X_np)
|
||||
eprint(y_np)
|
||||
|
||||
@@ -95,9 +115,9 @@ cdef public void randidx(int* idx, size_t l, size_t how_much):
|
||||
memcpy(idx, PyArray_DATA(i_np), how_much * sizeof(int))
|
||||
|
||||
|
||||
cdef public object create_network(int win, int embed):
|
||||
cdef public object create_network():
|
||||
try:
|
||||
net = nn.create_cbow_network(win, embed)
|
||||
net = nn.create_cbow_network()
|
||||
eprint(net)
|
||||
return net
|
||||
except Exception as e:
|
||||
@@ -111,7 +131,7 @@ cdef public void set_net_weights(object net, WeightList* wl):
|
||||
cdef public void step_net(
|
||||
object net, float* batch, size_t bs
|
||||
):
|
||||
X_train, y_train = cbow_batch(net, batch, bs)
|
||||
X_train, y_train = cbow_batch(batch, bs)
|
||||
net.train_on_batch(X_train, y_train)
|
||||
|
||||
|
||||
@@ -120,10 +140,7 @@ cdef public size_t out_size(object net):
|
||||
|
||||
|
||||
cdef public float eval_net(object net):
|
||||
try:
|
||||
return net.evaluate(X_test, y_test, verbose=False)
|
||||
except Exception as e:
|
||||
eprint(e)
|
||||
return nn.eval_network(net)
|
||||
|
||||
|
||||
cdef public void init_weightlist_like(WeightList* wl, object net):
|
||||
@@ -162,14 +179,8 @@ cdef public void combo_weights(
|
||||
wf += alpha * ww
|
||||
|
||||
|
||||
cdef public void create_test_dataset(size_t win):
|
||||
_create_test_dataset(win)
|
||||
|
||||
|
||||
cdef tuple cbow_batch(
|
||||
object net, float* batch, size_t bs
|
||||
):
|
||||
win = net.input_shape[1] // 2
|
||||
cdef tuple cbow_batch(float* batch, size_t bs):
|
||||
win = nn.WIN
|
||||
batch_np = np.asarray(<float[:bs,:2*win+1]>batch)
|
||||
X_np = batch_np[:, [*range(win), *range(win+1, win+win+1)]]
|
||||
y_np = nn.onehot(batch_np[:, win], nc=len(nn.vocab))
|
||||
@@ -177,6 +188,7 @@ cdef tuple cbow_batch(
|
||||
|
||||
|
||||
cdef list wrap_weight_list(WeightList* wl):
|
||||
"""Thinly wraps a WeightList struct into a NumPy array."""
|
||||
weights = []
|
||||
for i in range(wl.n_weights):
|
||||
w_shape = <long[:wl.weights[i].dims]>wl.weights[i].shape
|
||||
@@ -220,9 +232,3 @@ def ensure_contiguous(a):
|
||||
|
||||
def eprint(*args, **kwargs):
|
||||
return print(*args, flush=True, **kwargs)
|
||||
|
||||
|
||||
def _create_test_dataset(win):
|
||||
global X_test, y_test
|
||||
if X_test is None or y_test is None:
|
||||
X_test, y_test = nn.create_test_dataset(win)
|
||||
|
||||
Reference in New Issue
Block a user