diff --git a/bridge.pyx b/bridge.pyx
index 9c9cd59..e8c21be 100644
--- a/bridge.pyx
+++ b/bridge.pyx
@@ -76,17 +76,15 @@ cdef public void f_idx_list_to_print(float* f_idxs, size_t num):
 
 cdef public void cbow_batch(
-    float* X, float* y, float* idxs, size_t bs, size_t win
+    float* batch, size_t bs, size_t win
 ):
-    idxs_np = np.asarray(idxs)
+    batch_np = np.asarray(batch)
 
     # Deal with X
-    X_np = np.asarray(X)
-    for r in range(bs):
-        X_np[r, :win] = idxs_np[r:r+win]
-        X_np[r, win:] = idxs_np[r+win+1:r+win+1+win]
-
-    # Deal with y
-    nn.onehot(np.asarray(y), idxs_np[win:-win])
+    X_np = np.concatenate([batch_np[:, :win], batch_np[:, win+1:]], axis=1)
+    y_np = nn.onehot(batch_np[:, win], nc=len(nn.vocab))
+    eprint(batch_np)
+    eprint(X_np)
+    eprint(np.argmax(y_np, axis=1))
 
 
 cdef public void debug_print(object o):
@@ -102,13 +100,11 @@ cdef public void set_net_weights(object net, WeightList* wl):
 
 
 cdef public void step_net(
-    object net, float* X, float* y, size_t batch_size
+    object net, float* batch, size_t bs
 ):
-    in_shape = (batch_size,) + net.input_shape[1:]
-    out_shape = (batch_size,) + net.output_shape[1:]
-    X_train = np.asarray(X).reshape(in_shape)
-    y_train = np.asarray(y).reshape(out_shape),
-
+    # X_train, y_train = cbow_batch(net, batch, bs)
+    X_train = None
+    y_train = None
     net.train_on_batch(X_train, y_train)
 
 
diff --git a/library.py b/library.py
index 35f349d..ba71a73 100644
--- a/library.py
+++ b/library.py
@@ -5,6 +5,8 @@ import numpy as np
 import tensorflow as tf
 tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)  # STFU!
 
+from mynet import onehot
+
 
 HERE = os.path.abspath(os.path.dirname(__file__))
 CORPUS = os.path.join(HERE, 'melville-moby_dick.txt')
@@ -16,11 +18,6 @@ vocab = {
 inv_vocab = sorted(vocab, key=vocab.get)
 
 
-def onehot(oh_store, idx):
-    oh_store[:] = 0
-    oh_store[np.arange(len(idx)), idx.astype(np.int)] = 1
-
-
 def word_tokenize(s: str):
     l = ''.join(c.lower() if c.isalpha() else ' ' for c in s)
     return l.split()
@@ -70,7 +67,7 @@ def create_cbow_network(win, embed):
 
 def token_generator(filename):
     with open(filename) as f:
-        for i, l in enumerate(f.readlines()):
+        for i, l in enumerate(f.readlines(1000)):
            if not l.isspace():
                tok = word_tokenize(l)
                if tok:
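
The core of the change is that the new `cbow_batch` no longer builds windows row by row: each incoming row is already a window of `2*win + 1` token indices, so X is just that row with the centre column dropped, and y is the one-hot encoding of the centre token. Here is a minimal NumPy sketch of that logic; the `cbow_xy` wrapper and the stand-in `onehot` helper are illustrative names for this example, not the functions actually exported by `mynet` or `bridge.pyx`.

```python
import numpy as np

def onehot(idx, nc):
    # Stand-in for the library's onehot: one row per index, `nc` classes.
    oh = np.zeros((len(idx), nc), dtype=np.float32)
    oh[np.arange(len(idx)), idx.astype(np.int64)] = 1
    return oh

def cbow_xy(batch, win, vocab_size):
    # batch: (bs, 2*win + 1) matrix of token indices, one window per row.
    # X: the surrounding context columns, with the centre column removed.
    X = np.concatenate([batch[:, :win], batch[:, win + 1:]], axis=1)
    # y: one-hot encoding of the centre (target) token in each row.
    y = onehot(batch[:, win], nc=vocab_size)
    return X, y

# Example: window of 2, vocabulary of 10 tokens, batch of 2 rows.
batch = np.array([[3, 1, 7, 2, 5],
                  [0, 4, 9, 4, 0]], dtype=np.float32)
X, y = cbow_xy(batch, win=2, vocab_size=10)
print(X)                     # [[3. 1. 2. 5.], [0. 4. 4. 0.]]
print(np.argmax(y, axis=1))  # [7 9]
```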