for some two other files weren't added to commit
This commit is contained in:
26
bridge.pyx
26
bridge.pyx
@@ -76,17 +76,15 @@ cdef public void f_idx_list_to_print(float* f_idxs, size_t num):
|
|||||||
|
|
||||||
|
|
||||||
cdef public void cbow_batch(
|
cdef public void cbow_batch(
|
||||||
float* X, float* y, float* idxs, size_t bs, size_t win
|
float* batch, size_t bs, size_t win
|
||||||
):
|
):
|
||||||
idxs_np = np.asarray(<float[:bs + 2*win]>idxs)
|
batch_np = np.asarray(<float[:bs,:2*win+1]>batch)
|
||||||
# Deal with X
|
# Deal with X
|
||||||
X_np = np.asarray(<float[:bs,:2*win]>X)
|
X_np = np.concatenate([batch_np[:, :win], batch_np[:, win+1:]], axis=1)
|
||||||
for r in range(bs):
|
y_np = nn.onehot(batch_np[:, win], nc=len(nn.vocab))
|
||||||
X_np[r, :win] = idxs_np[r:r+win]
|
eprint(batch_np)
|
||||||
X_np[r, win:] = idxs_np[r+win+1:r+win+1+win]
|
eprint(X_np)
|
||||||
|
eprint(np.argmax(y_np, axis=1))
|
||||||
# Deal with y
|
|
||||||
nn.onehot(np.asarray(<float[:bs, :len(nn.vocab)]>y), idxs_np[win:-win])
|
|
||||||
|
|
||||||
|
|
||||||
cdef public void debug_print(object o):
|
cdef public void debug_print(object o):
|
||||||
@@ -102,13 +100,11 @@ cdef public void set_net_weights(object net, WeightList* wl):
|
|||||||
|
|
||||||
|
|
||||||
cdef public void step_net(
|
cdef public void step_net(
|
||||||
object net, float* X, float* y, size_t batch_size
|
object net, float* batch, size_t bs
|
||||||
):
|
):
|
||||||
in_shape = (batch_size,) + net.input_shape[1:]
|
# X_train, y_train = cbow_batch(net, batch, bs)
|
||||||
out_shape = (batch_size,) + net.output_shape[1:]
|
X_train = None
|
||||||
X_train = np.asarray(<float[:np.prod(in_shape)]>X).reshape(in_shape)
|
y_train = None
|
||||||
y_train = np.asarray(<float[:np.prod(out_shape)]>y).reshape(out_shape),
|
|
||||||
|
|
||||||
net.train_on_batch(X_train, y_train)
|
net.train_on_batch(X_train, y_train)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import numpy as np
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) # STFU!
|
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) # STFU!
|
||||||
|
|
||||||
|
from mynet import onehot
|
||||||
|
|
||||||
|
|
||||||
HERE = os.path.abspath(os.path.dirname(__file__))
|
HERE = os.path.abspath(os.path.dirname(__file__))
|
||||||
CORPUS = os.path.join(HERE, 'melville-moby_dick.txt')
|
CORPUS = os.path.join(HERE, 'melville-moby_dick.txt')
|
||||||
@@ -16,11 +18,6 @@ vocab = {
|
|||||||
inv_vocab = sorted(vocab, key=vocab.get)
|
inv_vocab = sorted(vocab, key=vocab.get)
|
||||||
|
|
||||||
|
|
||||||
def onehot(oh_store, idx):
|
|
||||||
oh_store[:] = 0
|
|
||||||
oh_store[np.arange(len(idx)), idx.astype(np.int)] = 1
|
|
||||||
|
|
||||||
|
|
||||||
def word_tokenize(s: str):
|
def word_tokenize(s: str):
|
||||||
l = ''.join(c.lower() if c.isalpha() else ' ' for c in s)
|
l = ''.join(c.lower() if c.isalpha() else ' ' for c in s)
|
||||||
return l.split()
|
return l.split()
|
||||||
@@ -70,7 +67,7 @@ def create_cbow_network(win, embed):
|
|||||||
|
|
||||||
def token_generator(filename):
|
def token_generator(filename):
|
||||||
with open(filename) as f:
|
with open(filename) as f:
|
||||||
for i, l in enumerate(f.readlines()):
|
for i, l in enumerate(f.readlines(1000)):
|
||||||
if not l.isspace():
|
if not l.isspace():
|
||||||
tok = word_tokenize(l)
|
tok = word_tokenize(l)
|
||||||
if tok:
|
if tok:
|
||||||
|
|||||||
Reference in New Issue
Block a user