parallelizing doesn't bring anything

This commit is contained in:
2019-12-11 22:23:11 -08:00
parent 2e5042f0e3
commit 598e59ca2e

27
main.c
View File

@@ -17,8 +17,8 @@
#define TAG_INSTR 9 #define TAG_INSTR 9
#define TAG_TERMT 10 #define TAG_TERMT 10
#define COMM 25 #define COMM 5
#define ITER 690 #define ITER 500
#define BS 32 #define BS 32
#define EMB 32 #define EMB 32
#define WIN 2 #define WIN 2
@@ -319,11 +319,14 @@ void recv_weights(WeightList* wl, int src) {
void learner() { void learner() {
INFO_PRINTF("Starting learner %d\n", getpid()); INFO_PRINTF("Starting learner %d\n", getpid());
int me = my_mpi_id(); int me = my_mpi_id();
int batcher = mpi_id_from_role_id(BATCHER, 0); int rid = role_id_from_mpi_id(LEARNER, me);
int my_batcher_rid = rid % number_of(BATCHER);
int batcher = mpi_id_from_role_id(BATCHER, my_batcher_rid);
int dispatcher = mpi_id_from_role_id(DISPATCHER, 0); int dispatcher = mpi_id_from_role_id(DISPATCHER, 0);
INFO_PRINTF("%d is Learner %d assigned to batcher %d\n", getpid(),
rid, my_batcher_rid);
PyObject* net = create_network(WIN, EMB); PyObject* net = create_network(WIN, EMB);
create_test_dataset(WIN);
WeightList wl; WeightList wl;
init_weightlist_like(&wl, net); init_weightlist_like(&wl, net);
@@ -408,10 +411,11 @@ void dispatcher() {
time_t finish = time(NULL); time_t finish = time(NULL);
float delta_t = finish - start; float delta_t = finish - start;
float delta_l = first_loss - eval_net(frank); float delta_l = first_loss - crt_loss;
INFO_PRINTF( INFO_PRINTF(
"Laptop MPI sgd consecutive_batch W%d E%d BS%d R%d bpe%d LPR%d," "Laptop MPI sgd consecutive_batch W%d E%d "
"%f,%f,%f\n", WIN, EMB, BS, COMM, ITER, lpr, "BS%d R%d bpe%d LPR%d pp%d,"
"%f,%f,%f\n", WIN, EMB, BS, COMM, ITER, lpr, g_argc - 1,
delta_l / COMM, delta_l / delta_t, min_loss); delta_l / COMM, delta_l / delta_t, min_loss);
Py_DECREF(frank); Py_DECREF(frank);
free_weightlist(&wl); free_weightlist(&wl);
@@ -434,6 +438,15 @@ int main (int argc, const char **argv) {
INFO_PRINTLN("NOT ENOUGH INPUTS!"); INFO_PRINTLN("NOT ENOUGH INPUTS!");
MPI_Abort(MPI_COMM_WORLD, 1); MPI_Abort(MPI_COMM_WORLD, 1);
} }
int pipelines = argc - 1;
int min_nodes = 3 * pipelines + 2;
if (world_size() < min_nodes) {
INFO_PRINTF("You requested %d pipelines "
"but only provided %d procs "
"(%d required)\n",
pipelines, world_size(), min_nodes);
MPI_Abort(MPI_COMM_WORLD, 1);
}
} }
g_argc = argc; g_argc = argc;