From 598e59ca2e20db02e78cf63d898803fb5ac7bfb7 Mon Sep 17 00:00:00 2001 From: Pavel Lutskov Date: Wed, 11 Dec 2019 22:23:11 -0800 Subject: [PATCH] parallelizing doesn't bring anything --- main.c | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/main.c b/main.c index 13b7b29..0277fab 100644 --- a/main.c +++ b/main.c @@ -17,8 +17,8 @@ #define TAG_INSTR 9 #define TAG_TERMT 10 -#define COMM 25 -#define ITER 690 +#define COMM 5 +#define ITER 500 #define BS 32 #define EMB 32 #define WIN 2 @@ -319,11 +319,14 @@ void recv_weights(WeightList* wl, int src) { void learner() { INFO_PRINTF("Starting learner %d\n", getpid()); int me = my_mpi_id(); - int batcher = mpi_id_from_role_id(BATCHER, 0); + int rid = role_id_from_mpi_id(LEARNER, me); + int my_batcher_rid = rid % number_of(BATCHER); + int batcher = mpi_id_from_role_id(BATCHER, my_batcher_rid); int dispatcher = mpi_id_from_role_id(DISPATCHER, 0); + INFO_PRINTF("%d is Learner %d assigned to batcher %d\n", getpid(), + rid, my_batcher_rid); PyObject* net = create_network(WIN, EMB); - create_test_dataset(WIN); WeightList wl; init_weightlist_like(&wl, net); @@ -408,10 +411,11 @@ void dispatcher() { time_t finish = time(NULL); float delta_t = finish - start; - float delta_l = first_loss - eval_net(frank); + float delta_l = first_loss - crt_loss; INFO_PRINTF( - "Laptop MPI sgd consecutive_batch W%d E%d BS%d R%d bpe%d LPR%d," - "%f,%f,%f\n", WIN, EMB, BS, COMM, ITER, lpr, + "Laptop MPI sgd consecutive_batch W%d E%d " + "BS%d R%d bpe%d LPR%d pp%d," + "%f,%f,%f\n", WIN, EMB, BS, COMM, ITER, lpr, g_argc - 1, delta_l / COMM, delta_l / delta_t, min_loss); Py_DECREF(frank); free_weightlist(&wl); @@ -434,6 +438,15 @@ int main (int argc, const char **argv) { INFO_PRINTLN("NOT ENOUGH INPUTS!"); MPI_Abort(MPI_COMM_WORLD, 1); } + int pipelines = argc - 1; + int min_nodes = 3 * pipelines + 2; + if (world_size() < min_nodes) { + INFO_PRINTF("You requested %d pipelines " + "but only provided %d procs " + "(%d required)\n", + pipelines, world_size(), min_nodes); + MPI_Abort(MPI_COMM_WORLD, 1); + } } g_argc = argc;