From 32742059c79fadea4894f3e95d9c66aebad8f18a Mon Sep 17 00:00:00 2001
From: Pavel Lutskov
Date: Wed, 11 Dec 2019 11:12:32 -0800
Subject: [PATCH] commit working termination code for two nodes before
 embarking on yet another side-quest

---
 main.c | 53 +++++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 41 insertions(+), 12 deletions(-)

diff --git a/main.c b/main.c
index 54a3472..5502291 100644
--- a/main.c
+++ b/main.c
@@ -15,8 +15,9 @@
 #define TAG_SWORD 7
 #define TAG_IWORD 8
 #define TAG_INSTR 9
+#define TAG_TERMT 10
 
-#define COMM 25
+#define COMM 1
 #define ITER 690
 #define BS 32
 #define EMB 32
@@ -59,15 +60,11 @@ int my_mpi_id() {
 size_t number_of(Role what) {
     switch (what) {
         case TOKENIZER:
-            if (g_argc < 2) {
-                INFO_PRINTLN("NOT ENOUGH INPUTS!");
-                MPI_Abort(MPI_COMM_WORLD, 1);
-            }
             return g_argc - 1;
         case FILTERER:
-            return 1;
+            return number_of(TOKENIZER);
         case BATCHER:
-            return 1;
+            return number_of(TOKENIZER);
         case LEARNER:
             return world_size()
                 - number_of(TOKENIZER)
@@ -168,19 +165,33 @@ int recv_word(Word* w, int src) {
 
 void tokenizer(const char* source) {
     INFO_PRINTF("Starting tokenizer %d\n", getpid());
+    int rid = role_id_from_mpi_id(TOKENIZER, my_mpi_id());
+
     WordList wl = {0, 0, NULL};
     size_t sync_ctr = 0;
-    while (get_tokens(&wl, source)) {
+
+    Word terminator = {1, ""};
+    MPI_Request stop_req;
+    int stop;
+    MPI_Irecv(&stop, 1, MPI_INT, MPI_ANY_SOURCE, TAG_TERMT, MPI_COMM_WORLD,
+            &stop_req);
+    MPI_Test(&stop_req, &stop, MPI_STATUS_IGNORE);
+
+    while (!stop && get_tokens(&wl, source)) {
         for in_range(i, wl.n_words) {
-            if (sync_ctr == 1000) {
-                ssend_word(&wl.words[i], mpi_id_from_role_id(FILTERER, 0));
+            if (sync_ctr == 10000) {
+                ssend_word(&wl.words[i], mpi_id_from_role_id(FILTERER, rid));
                 sync_ctr = 0;
             } else {
-                send_word(&wl.words[i], mpi_id_from_role_id(FILTERER, 0));
+                send_word(&wl.words[i], mpi_id_from_role_id(FILTERER, rid));
             }
             sync_ctr++;
         }
+        MPI_Test(&stop_req, &stop, MPI_STATUS_IGNORE);
     }
+    free_wordlist(&wl);
+    send_word(&terminator, mpi_id_from_role_id(FILTERER, rid));
+    INFO_PRINTF("Finishing tokenizer %d\n", getpid());
 }
 
 void filterer() {
@@ -196,13 +207,19 @@ void filterer() {
     int src = 0; // WLOG
     while (1) {
         int stream_offs;
-        while (have[src] != entry_size) {
+        while (have[src] != entry_size) { // TODO FLATTEN PIPELINE
             src = recv_word(&w, MPI_ANY_SOURCE);
+
+            if (!strlen(w.data)) break;
+
             src = role_id_from_mpi_id(TOKENIZER, src);
             stream_offs = src*entry_size;
             buffer[stream_offs + have[src]] = vocab_idx_of(&w);
             if (buffer[stream_offs + have[src]] != -1) have[src]++;
         }
+
+        if (!strlen(w.data)) break;
+
         have[src] = 0;
         MPI_Send(buffer + stream_offs, entry_size, MPI_LONG,
                 mpi_id_from_role_id(BATCHER, 0),
@@ -211,6 +228,7 @@ void filterer() {
     free_word(&w);
     free(buffer);
     free(have);
+    INFO_PRINTF("Finishing filterer %d\n", getpid());
 }
 
 void batcher() {
@@ -341,6 +359,11 @@ void dispatcher() {
         min_loss = crt_loss < min_loss ? crt_loss : min_loss;
         INFO_PRINTF("Round %ld, validation loss %f\n", i, crt_loss);
     }
+
+    int stop = 1;
+    MPI_Send(&stop, 1, MPI_INT, mpi_id_from_role_id(TOKENIZER, 0), TAG_TERMT,
+            MPI_COMM_WORLD);
+
     time_t finish = time(NULL);
     float delta_t = finish - start;
     float delta_l = first_loss - eval_net(frank);
@@ -363,6 +386,12 @@ void visualizer() {
 
 int main (int argc, const char **argv) {
     MPI_Init(NULL, NULL);
+    if (my_mpi_id() == 0) {
+        if (argc < 2) {
+            INFO_PRINTLN("NOT ENOUGH INPUTS!");
+            MPI_Abort(MPI_COMM_WORLD, 1);
+        }
+    }
     g_argc = argc;
 
     // Cython Boilerplate