stop refactoring and get some stuff huggin' done
--- a/library.py
+++ b/library.py
@@ -1,19 +1,38 @@
 import os
 import json
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 
 from mynet import onehot
 
 
+WIN = 2
+EMB = 32
+
 HERE = os.path.abspath(os.path.dirname(__file__))
-DATA = os.path.join(HERE, 'data')
+
+
+def read_cfg():
+    with open(os.path.join(HERE, 'cfg.json')) as f:
+        return json.load(f)
+
+
+CFG = read_cfg()
+DATA = os.path.join(HERE, CFG['data'])
 CORPUS = os.path.join(DATA, 'corpus.txt')
 VOCAB = os.path.join(DATA, 'vocab.txt')
 TEST = os.path.join(DATA, 'test.txt')
 
-vocab = {
-    w: i for i, w in enumerate(open(VOCAB).read().splitlines(keepends=False))
-}
-inv_vocab = sorted(vocab, key=vocab.get)
+
+def read_vocab_list():
+    with open(VOCAB) as f:
+        return f.read().split()
+
+
+inv_vocab = read_vocab_list()
+vocab = {w: i for i, w in enumerate(inv_vocab)}
+
+X_test = None
+y_test = None
+
 
 
@@ -21,13 +40,14 @@ def word_tokenize(s: str):
     return l.split()
 
 
-def create_test_dataset(win):
+def create_test_dataset():
     import numpy as np
     test_dataset = np.vectorize(vocab.get)(np.genfromtxt(TEST, dtype=str))
-    assert test_dataset.shape[1] == 2*win + 1
-    X_test = test_dataset[:, [*range(0, win), *range(win+1, win+win+1)]]
-    y_test = onehot(test_dataset[:, win], nc=len(vocab))
-    return X_test, y_test
+    assert test_dataset.shape[1] == 2*WIN + 1
+
+    global X_test, y_test
+    X_test = test_dataset[:, [*range(0, WIN), *range(WIN+1, WIN+WIN+1)]]
+    y_test = onehot(test_dataset[:, WIN], nc=len(vocab))
 
 
 def create_mnist_network():
@@ -44,13 +64,13 @@ def create_mnist_network():
     return model
 
 
-def create_cbow_network(win, embed):
+def create_cbow_network():
     import tensorflow as tf
     tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)  # STFU!
     tf.random.set_random_seed(42)
 
-    ctxt = tf.keras.layers.Input(shape=[2*win])
-    ed = tf.keras.layers.Embedding(len(vocab), embed, input_length=2*win)(ctxt)
+    ctxt = tf.keras.layers.Input(shape=[2*WIN])
+    ed = tf.keras.layers.Embedding(len(vocab), EMB, input_length=2*WIN)(ctxt)
     cbow = tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=1))(ed)
     blowup = tf.keras.layers.Dense(len(vocab), activation='softmax')(cbow)
     mod = tf.keras.Model(inputs=ctxt, outputs=blowup)
@@ -61,9 +81,15 @@ def create_cbow_network(win, embed):
     return mod
 
 
+def eval_network(net):
+    if X_test is None or y_test is None:
+        create_test_dataset()
+    return net.evaluate(X_test, y_test, verbose=False)
+
+
 def token_generator(filename):
     with open(filename) as f:
-        for i, l in enumerate(f.readlines()):
+        for l in f:
             if not l.isspace():
                 tok = word_tokenize(l)
                 if tok:
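Note that `mynet` is not part of this commit, so the behavior of the imported `onehot` can only be inferred from its call site, `onehot(test_dataset[:, WIN], nc=len(vocab))`. A minimal sketch of what it presumably does, assuming integer class indices in and a dense one-hot matrix out (the dtype is a guess):

```python
import numpy as np

def onehot(indices, nc):
    # Assumed behavior of mynet.onehot: map a 1-D array of integer class
    # indices to a (len(indices), nc) one-hot matrix.
    out = np.zeros((len(indices), nc), dtype=np.float32)
    out[np.arange(len(indices)), indices] = 1.0
    return out
```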
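After this commit, callers no longer thread `win`/`embed` through every function; the window and embedding sizes come from the `WIN`/`EMB` module constants, and the test set is built lazily into the module-level `X_test`/`y_test` globals. A hypothetical driver (it assumes `cfg.json` and the data files referenced by `library.py` exist on disk, and that the model is compiled before `return mod`):

```python
import library

# Window and embedding sizes are now module constants, so no arguments.
net = library.create_cbow_network()

# eval_network() calls create_test_dataset() on first use, which fills
# the X_test / y_test globals before delegating to net.evaluate().
metrics = library.eval_network(net)
print(metrics)
```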