diff --git a/bridge.pyx b/bridge.pyx
index eb50114..1a0d6b9 100644
--- a/bridge.pyx
+++ b/bridge.pyx
@@ -47,6 +47,12 @@ cdef public void serve():
     srv.serve()
 
 
+cdef public void server_update(float* emb):
+    # View the raw C buffer as a (vocab_size, emb_dim) matrix; a bare
+    # np.asarray() on a float* carries no shape information.
+    embeddings = np.asarray(<float[:getvocsize() * getemb()]> emb)
+    low_dim = nn.calc_TSNE(embeddings.reshape(getvocsize(), getemb()))
+    srv.emb_map = dict(zip(nn.inv_vocab, low_dim))
+
+
 cdef public size_t getwin():
     return nn.WIN
 
@@ -67,6 +73,10 @@ cdef public float gettarget():
     return nn.CFG['target']
 
 
+cdef public size_t getvocsize():
+    return len(nn.vocab)
+
+
 cdef public int get_tokens(WordList* wl, const char *filename):
     fnu = filename.decode('utf-8')
     if fnu not in tokenizers:
@@ -101,8 +111,8 @@ cdef public void _dbg_print(object o):
     eprint(o)
 
 
-cdef public void _dbg_print_cbow_batch(float* batch, size_t bs):
-    X_np, y_np = cbow_batch(batch, bs)
+cdef public void _dbg_print_cbow_batch(float* batch):
+    X_np, y_np = cbow_batch(batch)
     eprint(X_np)
     eprint(y_np)
 
@@ -124,10 +134,8 @@ cdef public void set_net_weights(object net, WeightList* wl):
     net.set_weights(wrap_weight_list(wl))
 
 
-cdef public void step_net(
-    object net, float* batch, size_t bs
-):
-    X_train, y_train = cbow_batch(batch, bs)
+cdef public void step_net(object net, float* batch):
+    X_train, y_train = cbow_batch(batch)
     net.train_on_batch(X_train, y_train)
 
 
@@ -183,8 +191,9 @@ cdef public void combo_weights(
     wf += alpha * ww
 
 
-cdef tuple cbow_batch(float* batch, size_t bs):
-    win = nn.WIN
-    batch_np = np.asarray(batch)
+cdef tuple cbow_batch(float* batch):
+    win = getwin()
+    bs = getbs()
+    # bs and win recover the buffer's (bs, 2*win+1) window shape.
+    batch_np = np.asarray(<float[:bs * (2*win + 1)]> batch).reshape(bs, -1)
     X_np = batch_np[:, [*range(win), *range(win+1, win+win+1)]]
     y_np = nn.onehot(batch_np[:, win], nc=len(nn.vocab))
diff --git a/library.py b/library.py
index 4376e84..b0f0b47 100644
--- a/library.py
+++ b/library.py
@@ -113,6 +113,13 @@ def get_embeddings(net):
    return net.get_weights()[0]
 
 
+def calc_TSNE(emb):
+    # Identity stub for now: swap in a real reducer (e.g. UMAP)
+    # once the visualizer pipeline is enabled.
+    # import umap
+    # return umap.UMAP().fit_transform(emb)
+    return emb
+
+
 def save_embeddings(emb):
     import numpy as np
     np.savetxt(os.path.join(RESULTS, f'embeddings_{CFG["data_name"]}.csv'),
diff --git a/main.c b/main.c
index e2f7556..d983ac6 100644
--- a/main.c
+++ b/main.c
@@ -17,6 +17,7 @@
 #define TAG_IWIND 8
 #define TAG_INSTR 9
 #define TAG_TERMT 10
+#define TAG_EMBED 11
 
 #define in_range(i, x) (size_t i = 0; i < (x); i++) // I am honestly VERY sorry for this
 
@@ -33,7 +34,7 @@ int g_argc; // sorry!
 typedef enum {
     TOKENIZER,
-    FILTERER,
+    FILTER,
     BATCHER,
     LEARNER,
     VISUALIZER,
@@ -56,19 +57,19 @@ size_t number_of(Role what) {
     switch (what) {
     case TOKENIZER:
         return g_argc - 1;
-    case FILTERER:
+    case FILTER:
         return number_of(TOKENIZER);
     case BATCHER:
         return number_of(TOKENIZER);
     case LEARNER:
         return world_size()
             - number_of(TOKENIZER)
-            - number_of(FILTERER)
+            - number_of(FILTER)
             - number_of(BATCHER)
             - number_of(DISPATCHER)
             - number_of(VISUALIZER);
     case VISUALIZER:
-        return 1;
+        return 0;  // visualizer node disabled for now
     case DISPATCHER:
         return 1;
     }
@@ -77,7 +78,7 @@
 int mpi_id_from_role_id(Role role, int rid) {
     if (rid >= number_of(role) || rid < 0) {
         INFO_PRINTF("There aren't %d of %d (but %lu)\n",
-                    rid, role, number_of(role));
+                    rid + 1, role, number_of(role));
         MPI_Abort(MPI_COMM_WORLD, 1);
     }
     int base = 0;
@@ -176,7 +177,7 @@ void recv_window(long* window, size_t winsize, int src) {
 void tokenizer(const char* source) {
     INFO_PRINTF("Starting tokenizer %d\n", getpid());
     int rid = my_role_id(TOKENIZER);
-    int next = mpi_id_from_role_id(FILTERER, rid);
+    int next = mpi_id_from_role_id(FILTER, rid);
 
     WordList wl = {0, 0, NULL};
     size_t sync_ctr = 0;
@@ -209,9 +210,9 @@ void tokenizer(const char* source) {
     INFO_PRINTF("Finishing tokenizer %d\n", getpid());
 }
 
-void filterer() {
-    INFO_PRINTF("Starting filterer %d\n", getpid());
-    int rid = my_role_id(FILTERER);
+void filter() {
+    INFO_PRINTF("Starting filter %d\n", getpid());
+    int rid = my_role_id(FILTER);
     int tokenizer = mpi_id_from_role_id(TOKENIZER, rid);
     int batcher = mpi_id_from_role_id(BATCHER, rid);
 
@@ -239,13 +240,13 @@ void filterer() {
     send_window(window, window_size, batcher);
     free_word(&w);
     free(window);
-    INFO_PRINTF("Finishing filterer %d\n", getpid());
+    INFO_PRINTF("Finishing filter %d\n", getpid());
 }
 
 void batcher() {
     INFO_PRINTF("Starting batcher %d\n", getpid());
     int rid = my_role_id(BATCHER);
-    int tokenizer = mpi_id_from_role_id(FILTERER, rid);
+    int tokenizer = mpi_id_from_role_id(FILTER, rid);  // rank of the upstream filter
 
     int bs = getbs();
     int learner_mpi_id = 0;
@@ -326,15 +327,13 @@ void learner() {
     int dispatcher = mpi_id_from_role_id(DISPATCHER, 0);
     INFO_PRINTF("Learner %d (pid %d) is assigned to pipeline %d\n",
                 rid, getpid(), my_batcher_rid);
-    size_t bs = getbs();
-    size_t bpe = getbpe();
 
     PyObject* net = create_network();
     WeightList wl;
     init_weightlist_like(&wl, net);
 
     size_t window_size = 2 * getwin() + 1;
-    size_t bufsize = bs * window_size;
+    size_t bufsize = getbs() * window_size;
     float* batch = malloc(bufsize * sizeof(float));
 
     int go;
@@ -344,11 +343,11 @@
     while (go != -1) {
         recv_weights(&wl, dispatcher);
         set_net_weights(net, &wl);
-        for in_range(k, bpe) {
+        for in_range(k, getbpe()) {
             MPI_Send(&me, 1, MPI_INT, batcher, TAG_READY, MPI_COMM_WORLD);
             MPI_Recv(batch, bufsize, MPI_FLOAT, batcher, TAG_BATCH,
                      MPI_COMM_WORLD, MPI_STATUS_IGNORE);
-            step_net(net, batch, bs);
+            step_net(net, batch);
         }
         update_weightlist(&wl, net);
         send_weights(&wl, dispatcher);
@@ -365,9 +364,11 @@
 void dispatcher() {
     INFO_PRINTF("Starting dispatcher %d\n", getpid());
     int go = 1;
+    // int visualizer = mpi_id_from_role_id(VISUALIZER, 0);
     size_t bs = getbs();
     size_t bpe = getbpe();
     float target = gettarget();
+    // size_t emb_mat_size = getemb() * getvocsize();
 
     PyObject* frank = create_network();
     WeightList wl;
@@ -403,6 +404,9 @@ void dispatcher() {
             crt_loss = eval_net(frank);
             min_loss = crt_loss < min_loss ?
                            crt_loss : min_loss;
             INFO_PRINTF("Round %ld, validation loss %f\n", rounds, crt_loss);
+            // MPI_Send(&go, 1, MPI_INT, visualizer, TAG_INSTR, MPI_COMM_WORLD);
+            // MPI_Send(wl.weights[0].W, emb_mat_size, MPI_FLOAT,
+            //          visualizer, TAG_EMBED, MPI_COMM_WORLD);
 
             ckpt_net(frank);
 
@@ -419,6 +423,8 @@
         MPI_Send(&go, 1, MPI_INT, mpi_id_from_role_id(LEARNER, l),
                  TAG_INSTR, MPI_COMM_WORLD);
     }
+    // MPI_Send(&go, 1, MPI_INT, mpi_id_from_role_id(VISUALIZER, 0),
+    //          TAG_INSTR, MPI_COMM_WORLD);
 
     save_emb(frank);
 
@@ -439,16 +445,38 @@
     free(wls);
     free(round);
     INFO_PRINTF("Finishing dispatcher %d\n", getpid());
+    // sleep(4);
+    // INFO_PRINTLN("Visualization server is still running on port 8448\n"
+    //              "To terminate, press Ctrl-C");
 }
 
 void visualizer() {
     INFO_PRINTF("Starting visualizer %d\n", getpid());
     serve();
+
+    int dispatcher = mpi_id_from_role_id(DISPATCHER, 0);
+    int go_on = 1;
+
+    size_t emb_mat_size = getvocsize() * getemb();
+    float* embeddings = malloc(emb_mat_size * sizeof(float));
+
+    MPI_Recv(&go_on, 1, MPI_INT, dispatcher, TAG_INSTR, MPI_COMM_WORLD,
+             MPI_STATUS_IGNORE);
+
+    while (go_on != -1) {
+        MPI_Recv(embeddings, emb_mat_size, MPI_FLOAT, dispatcher, TAG_EMBED,
+                 MPI_COMM_WORLD, MPI_STATUS_IGNORE);
+        server_update(embeddings);
+        MPI_Recv(&go_on, 1, MPI_INT, dispatcher, TAG_INSTR, MPI_COMM_WORLD,
+                 MPI_STATUS_IGNORE);
+    }
+    free(embeddings);
+    INFO_PRINTF("Exiting visualizer node %d\n", getpid());
 }
 
 int main (int argc, const char **argv) {
     MPI_Init(NULL, NULL);
+    // Some sanity checks on the input
     if (my_mpi_id() == 0) {
         if (argc < 2) {
             INFO_PRINTLN("NOT ENOUGH INPUTS!");
@@ -479,8 +507,8 @@ int main (int argc, const char **argv) {
         role_id = role_id_from_mpi_id(TOKENIZER, my_mpi_id());
         tokenizer(argv[role_id + 1]);
         break;
-    case FILTERER:
-        filterer();
+    case FILTER:
+        filter();
         break;
     case BATCHER:
         batcher();
diff --git a/server.py b/server.py
index 7575957..90c8502 100644
--- a/server.py
+++ b/server.py
@@ -1,17 +1,26 @@
 from threading import Thread
-from sys import stderr
 import flask
 
 t = None
 app = flask.Flask(__name__)
-counter = 0
+emb_map = None
 
 
+# Silence werkzeug's per-request log lines so they don't drown out the
+# pipeline's own INFO output.
+import logging
+log = logging.getLogger('werkzeug')
+log.setLevel(logging.ERROR)
+app.logger.setLevel(logging.ERROR)
 
 
 @app.route('/')
 def main():
-    return f'Hello {counter}'
+    if emb_map is None:
+        return 'Hello World!'
+    # Plain-text dump of the current word -> low-dimensional vector map.
+    return '\n'.join(f'{w}: {vec}' for w, vec in emb_map.items())
 
 
 def serve():
@@ -19,7 +28,6 @@ def serve():
+    global t  # otherwise the double-start guard below never sees the thread
     if t is None:
         t = Thread(target=app.run, kwargs={'port': 8448})
         t.start()
-        print('So I kinda started', flush=True, file=stderr)
 
 
 if __name__ == '__main__':
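
---

Two notes on the pieces this patch leaves stubbed out.

calc_TSNE in library.py is currently an identity map, with the umap call
commented out. A minimal sketch of what the enabled version could look like,
assuming the umap-learn package (the ImportError fallback is an assumption of
this sketch, not part of the patch):

    def calc_TSNE(emb):
        try:
            import umap
        except ImportError:
            # No reducer installed: keep serving the raw embeddings,
            # exactly like the identity stub in the patch.
            return emb
        # Project the (vocab_size, emb_dim) matrix to 2-D for display.
        return umap.UMAP(n_components=2).fit_transform(emb)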
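Separately, for intuition about cbow_batch: each row of a batch is a window of
2*win+1 token ids, and the centre column is split off as the training target.
A toy numpy run of the same slicing (values are illustrative only):

    import numpy as np

    win = 2
    # Two windows of 2*win + 1 = 5 tokens each.
    batch_np = np.arange(10, dtype=np.float32).reshape(2, 2 * win + 1)
    # Context columns: everything except the centre column `win`.
    X_np = batch_np[:, [*range(win), *range(win + 1, win + win + 1)]]
    y = batch_np[:, win]  # centre words; one-hot encoded in the real code
    print(X_np)  # [[0. 1. 3. 4.]
                 #  [5. 6. 8. 9.]]
    print(y)     # [2. 7.]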