for some reason I dockerized it and it works
This commit is contained in:
38
library.py
38
library.py
@@ -1,9 +1,8 @@
|
||||
import os
|
||||
import json
|
||||
from sys import stderr
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
|
||||
from mynet import onehot
|
||||
|
||||
|
||||
WIN = 2
|
||||
EMB = 32
|
||||
@@ -11,21 +10,27 @@ EMB = 32
|
||||
HERE = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
||||
def read_cfg():
|
||||
with open(os.path.join(HERE, 'cfg.json')) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
CFG = read_cfg()
|
||||
DATA = os.path.join(HERE, CFG['data'])
|
||||
DATA = os.path.join(HERE, 'data')
|
||||
RESULTS = os.path.join(HERE, 'trained')
|
||||
CORPUS = os.path.join(DATA, 'corpus.txt')
|
||||
VOCAB = os.path.join(DATA, 'vocab.txt')
|
||||
TEST = os.path.join(DATA, 'test.txt')
|
||||
|
||||
|
||||
if not os.path.exists(RESULTS):
|
||||
os.mkdir(RESULTS)
|
||||
|
||||
|
||||
def read_cfg():
    """Load and return the parsed contents of ``cfg.json`` from the DATA directory."""
    cfg_path = os.path.join(DATA, 'cfg.json')
    with open(cfg_path, encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed
|
||||
|
||||
|
||||
CFG = read_cfg()
|
||||
|
||||
|
||||
def read_vocab_list():
    """Return the whitespace-separated tokens stored in the VOCAB file.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's default locale encoding.
    """
    with open(VOCAB, encoding='utf-8') as f:
        return f.read().split()
|
||||
|
||||
|
||||
@@ -41,6 +46,13 @@ def word_tokenize(s: str):
|
||||
return l.split()
|
||||
|
||||
|
||||
def onehot(a, nc=10):
    """Return a one-hot encoding of the class indices in ``a``.

    Args:
        a: Array-like of class indices. It is flattened before use, so its
            flattened length must equal ``len(a)`` (e.g. shape ``(n,)`` or
            ``(n, 1)``).
        nc: Number of classes, i.e. the width of each one-hot row.

    Returns:
        A ``float32`` array of shape ``(len(a), nc)`` holding a single 1 in
        each row, at the column given by the corresponding index.
    """
    import numpy as np

    # ``np.int`` was deprecated in NumPy 1.20 and removed in 1.24; the builtin
    # ``int`` is the documented drop-in replacement.
    cols = np.asarray(a).flatten().astype(int)
    oh = np.zeros((len(a), nc), dtype=np.float32)
    oh[np.arange(len(a)), cols] = 1
    return oh
|
||||
|
||||
|
||||
def create_test_dataset():
|
||||
import numpy as np
|
||||
test_dataset = np.vectorize(vocab.get)(np.genfromtxt(TEST, dtype=str))
|
||||
@@ -89,7 +101,7 @@ def eval_network(net):
|
||||
|
||||
|
||||
def token_generator(filename):
|
||||
with open(filename) as f:
|
||||
with open(filename, encoding='utf-8') as f:
|
||||
for l in f:
|
||||
if not l.isspace():
|
||||
tok = word_tokenize(l)
|
||||
@@ -103,8 +115,8 @@ def get_embeddings(net):
|
||||
|
||||
def save_embeddings(emb):
    """Write ``emb`` as text to ``embeddings_<name>.csv`` in the RESULTS directory.

    ``<name>`` is the ``"name"`` entry of the loaded configuration (CFG).
    """
    import numpy as np

    np.savetxt(os.path.join(RESULTS, f'embeddings_{CFG["name"]}.csv'), emb)
|
||||
|
||||
|
||||
def ckpt_network(net):
    """Save the weights of ``net`` to ``model_ckpt_<name>.h5`` in the RESULTS directory.

    ``<name>`` is the ``"name"`` entry of the loaded configuration (CFG);
    ``net`` must provide a ``save_weights(path)`` method.
    """
    net.save_weights(os.path.join(RESULTS, f'model_ckpt_{CFG["name"]}.h5'))
|
||||
|
||||
Reference in New Issue
Block a user