From 4a9e328884dccbe9e9ad8abd58a62b8b86695e38 Mon Sep 17 00:00:00 2001
From: Pavel Lutskov
Date: Fri, 13 Dec 2019 11:05:38 -0800
Subject: [PATCH] visualizer kinda working, more work needs done

---
 .gitignore |  1 +
 bridge.pyx | 17 +++++++++++++++--
 library.py | 14 ++++++++++++++
 main.c     | 37 ++++++++++++++++++++++++++-----------
 server.py  | 26 ++++++++++++++++++++++++++
 5 files changed, 82 insertions(+), 13 deletions(-)
 create mode 100644 server.py

diff --git a/.gitignore b/.gitignore
index 5772e8c..5d8f7d5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,5 +4,6 @@ run
 compile_commands.json
 cfg.json
 build/
+trained/
 __pycache__/
 data_*/
diff --git a/bridge.pyx b/bridge.pyx
index fae5af1..b01e7a8 100644
--- a/bridge.pyx
+++ b/bridge.pyx
@@ -7,7 +7,7 @@ from libc.stdlib cimport malloc, realloc
 from libc.string cimport memcpy
 
 import library as nn
-import flask
+import server as srv
 
 tokenizers = {}
 
@@ -44,7 +44,12 @@ cdef public char *greeting():
 
 
 cdef public void serve():
-    nn.app.run(port=8448)
+    srv.serve()
+
+
+cdef public void bump_count():
+    eprint(f'bumping count from {srv.counter} to {srv.counter + 1}')
+    srv.counter += 1
 
 
 cdef public size_t getwin():
@@ -143,6 +148,14 @@ cdef public float eval_net(object net):
     return nn.eval_network(net)
 
 
+cdef public void ckpt_net(object net):
+    nn.ckpt_network(net)
+
+
+cdef public void save_emb(object net):
+    nn.save_embeddings(nn.get_embeddings(net))
+
+
 cdef public void init_weightlist_like(WeightList* wl, object net):
     weights = net.get_weights()
     wl.n_weights = len(weights)
diff --git a/library.py b/library.py
index 0c6033d..56b6a8f 100644
--- a/library.py
+++ b/library.py
@@ -18,6 +18,7 @@ def read_cfg():
 
 CFG = read_cfg()
 DATA = os.path.join(HERE, CFG['data'])
+RESULTS = os.path.join(HERE, 'trained')
 CORPUS = os.path.join(DATA, 'corpus.txt')
 VOCAB = os.path.join(DATA, 'vocab.txt')
 TEST = os.path.join(DATA, 'test.txt')
@@ -94,3 +95,16 @@ def token_generator(filename):
         tok = word_tokenize(l)
         if tok:
             yield tok
+
+
+def get_embeddings(net):
+    return net.get_weights()[0]
+
+
+def save_embeddings(emb):
+    import numpy as np
+    np.savetxt(os.path.join(RESULTS, f'embeddings_{CFG["data"]}.csv'), emb)
+
+
+def ckpt_network(net):
+    net.save_weights(os.path.join(RESULTS, f'model_ckpt_{CFG["data"]}.h5'))
diff --git a/main.c b/main.c
index b5460bc..0c50d31 100644
--- a/main.c
+++ b/main.c
@@ -3,6 +3,7 @@
 #include
 #include
 #include
 #include
+#include
 #include
 #define TAG_IDGAF 0
@@ -67,13 +68,18 @@ size_t number_of(Role what) {
                - number_of(DISPATCHER) - number_of(VISUALIZER);
     case VISUALIZER:
-        return 0;
+        return 1;
     case DISPATCHER:
         return 1;
     }
 }
 
 int mpi_id_from_role_id(Role role, int rid) {
+    if (rid >= number_of(role) || rid < 0) {
+        INFO_PRINTF("There aren't %d of %d (but %lu)\n",
+                    rid, role, number_of(role));
+        MPI_Abort(MPI_COMM_WORLD, 1);
+    }
     int base = 0;
     for (Role r = TOKENIZER; r < role; r++) {
         base += number_of(r);
     }
@@ -146,20 +152,16 @@ void ssend_word(Word* w, int dest) {
     MPI_Ssend(w->data, len + 1, MPI_CHAR, dest, TAG_SWORD, MPI_COMM_WORLD);
 }
 
-int recv_word(Word* w, int src) {
-    // WAT is going on here I have no idea
+void recv_word(Word* w, int src) {
     long len;
-    MPI_Status stat;
     MPI_Recv(&len, 1, MPI_LONG, src, TAG_STLEN, MPI_COMM_WORLD,
-             &stat);
-    int the_src = stat.MPI_SOURCE;
+             MPI_STATUS_IGNORE);
     if (w->mem < len + 1) {
         w->mem = len + 1;
         w->data = realloc(w->data, sizeof(char) * w->mem);
     }
-    MPI_Recv(w->data, len + 1, MPI_CHAR, the_src, TAG_SWORD, MPI_COMM_WORLD,
+    MPI_Recv(w->data, len + 1, MPI_CHAR, src, TAG_SWORD, MPI_COMM_WORLD,
              MPI_STATUS_IGNORE);
-    return the_src;
 }
 
 void send_window(long* window, size_t winsize, int dest) {
@@ -192,7 +194,11 @@ void tokenizer(const char* source) {
             ssend_word(&wl.words[i], next);
             sync_ctr = 0;
         } else {
-            send_word(&wl.words[i], next);
+            if (rand() % 100) {
+                // drop a word here and there
+                // probably would make sense if there was less data
+                send_word(&wl.words[i], next);
+            }
         }
         sync_ctr++;
     }
@@ -398,6 +404,9 @@ void dispatcher() {
         crt_loss = eval_net(frank);
         min_loss = crt_loss < min_loss ? crt_loss : min_loss;
         INFO_PRINTF("Round %ld, validation loss %f\n", rounds, crt_loss);
+
+        ckpt_net(frank);
+
         rounds++;
     }
     time_t finish = time(NULL);
@@ -412,10 +421,12 @@ void dispatcher() {
                  TAG_INSTR, MPI_COMM_WORLD);
     }
 
+    save_emb(frank);
+
     float delta_t = finish - start;
     float delta_l = first_loss - crt_loss;
     INFO_PRINTF(
-        "WIKI MPI adam consecutive_batch "
+        "Moby MPI adam consecutive_batch "
         "W%lu E%lu BS%lu bpe%lu LPR%d pp%lu,"
         "%f,%f,%f,%f,"
         "%lu,%.0f,%lu\n",
@@ -434,6 +445,10 @@ void dispatcher() {
 void visualizer() {
     INFO_PRINTF("Starting visualizer %d\n", getpid());
     serve();
+    while (1) {
+        sleep(1);
+        bump_count();
+    }
 }
 
 int main (int argc, const char **argv) {
@@ -445,7 +460,7 @@ int main (int argc, const char **argv) {
         MPI_Abort(MPI_COMM_WORLD, 1);
     }
     int pipelines = argc - 1;
-    int min_nodes = 4 * pipelines + 1;
+    int min_nodes = 4 * pipelines + 2;
     if (world_size() < min_nodes) {
         INFO_PRINTF("You requested %d pipeline(s) "
                     "but only provided %d procs "
diff --git a/server.py b/server.py
new file mode 100644
index 0000000..7575957
--- /dev/null
+++ b/server.py
@@ -0,0 +1,26 @@
+from threading import Thread
+from sys import stderr
+
+import flask
+
+
+t = None
+app = flask.Flask(__name__)
+counter = 0
+
+
+@app.route('/')
+def main():
+    return f'Hello {counter}'
+
+
+def serve():
+    global t
+    if t is None:
+        t = Thread(target=app.run, kwargs={'port': 8448})
+        t.start()
+        print('So I kinda started', flush=True, file=stderr)
+
+
+if __name__ == '__main__':
+    serve()
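
A quick way to sanity-check the new visualizer endpoint while a run is going: a minimal sketch, assuming the default port 8448 and the '/' route from server.py; the counter value in the output is only illustrative.

    # Poll the Flask endpoint started by server.serve().
    # Port 8448 is the value hard-coded in server.py.
    from urllib.request import urlopen

    print(urlopen('http://localhost:8448/').read().decode())  # e.g. 'Hello 17'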