commit working termination code for two nodes

before embarking on yet another side-quest
This commit is contained in:
2019-12-11 11:12:32 -08:00
parent 7043b65532
commit 32742059c7

53
main.c
View File

@@ -15,8 +15,9 @@
#define TAG_SWORD 7 #define TAG_SWORD 7
#define TAG_IWORD 8 #define TAG_IWORD 8
#define TAG_INSTR 9 #define TAG_INSTR 9
#define TAG_TERMT 10
#define COMM 25 #define COMM 1
#define ITER 690 #define ITER 690
#define BS 32 #define BS 32
#define EMB 32 #define EMB 32
@@ -59,15 +60,11 @@ int my_mpi_id() {
size_t number_of(Role what) { size_t number_of(Role what) {
switch (what) { switch (what) {
case TOKENIZER: case TOKENIZER:
if (g_argc < 2) {
INFO_PRINTLN("NOT ENOUGH INPUTS!");
MPI_Abort(MPI_COMM_WORLD, 1);
}
return g_argc - 1; return g_argc - 1;
case FILTERER: case FILTERER:
return 1; return number_of(TOKENIZER);
case BATCHER: case BATCHER:
return 1; return number_of(TOKENIZER);
case LEARNER: case LEARNER:
return world_size() return world_size()
- number_of(TOKENIZER) - number_of(TOKENIZER)
@@ -168,19 +165,33 @@ int recv_word(Word* w, int src) {
void tokenizer(const char* source) { void tokenizer(const char* source) {
INFO_PRINTF("Starting tokenizer %d\n", getpid()); INFO_PRINTF("Starting tokenizer %d\n", getpid());
int rid = role_id_from_mpi_id(TOKENIZER, my_mpi_id());
WordList wl = {0, 0, NULL}; WordList wl = {0, 0, NULL};
size_t sync_ctr = 0; size_t sync_ctr = 0;
while (get_tokens(&wl, source)) {
Word terminator = {1, ""};
MPI_Request stop_req;
int stop;
MPI_Irecv(&stop, 1, MPI_INT, MPI_ANY_SOURCE, TAG_TERMT, MPI_COMM_WORLD,
&stop_req);
MPI_Test(&stop_req, &stop, MPI_STATUS_IGNORE);
while (!stop && get_tokens(&wl, source)) {
for in_range(i, wl.n_words) { for in_range(i, wl.n_words) {
if (sync_ctr == 1000) { if (sync_ctr == 10000) {
ssend_word(&wl.words[i], mpi_id_from_role_id(FILTERER, 0)); ssend_word(&wl.words[i], mpi_id_from_role_id(FILTERER, rid));
sync_ctr = 0; sync_ctr = 0;
} else { } else {
send_word(&wl.words[i], mpi_id_from_role_id(FILTERER, 0)); send_word(&wl.words[i], mpi_id_from_role_id(FILTERER, rid));
} }
sync_ctr++; sync_ctr++;
} }
MPI_Test(&stop_req, &stop, MPI_STATUS_IGNORE);
} }
free_wordlist(&wl);
send_word(&terminator, mpi_id_from_role_id(FILTERER, rid));
INFO_PRINTF("Finishing tokenizer %d\n", getpid());
} }
void filterer() { void filterer() {
@@ -196,13 +207,19 @@ void filterer() {
int src = 0; // WLOG int src = 0; // WLOG
while (1) { while (1) {
int stream_offs; int stream_offs;
while (have[src] != entry_size) { while (have[src] != entry_size) { // TODO FLATTEN PIPELINE
src = recv_word(&w, MPI_ANY_SOURCE); src = recv_word(&w, MPI_ANY_SOURCE);
if (!strlen(w.data)) break;
src = role_id_from_mpi_id(TOKENIZER, src); src = role_id_from_mpi_id(TOKENIZER, src);
stream_offs = src*entry_size; stream_offs = src*entry_size;
buffer[stream_offs + have[src]] = vocab_idx_of(&w); buffer[stream_offs + have[src]] = vocab_idx_of(&w);
if (buffer[stream_offs + have[src]] != -1) have[src]++; if (buffer[stream_offs + have[src]] != -1) have[src]++;
} }
if (!strlen(w.data)) break;
have[src] = 0; have[src] = 0;
MPI_Send(buffer + stream_offs, entry_size, MPI_LONG, MPI_Send(buffer + stream_offs, entry_size, MPI_LONG,
mpi_id_from_role_id(BATCHER, 0), mpi_id_from_role_id(BATCHER, 0),
@@ -211,6 +228,7 @@ void filterer() {
free_word(&w); free_word(&w);
free(buffer); free(buffer);
free(have); free(have);
INFO_PRINTF("Finishing filterer %d\n", getpid());
} }
void batcher() { void batcher() {
@@ -341,6 +359,11 @@ void dispatcher() {
min_loss = crt_loss < min_loss ? crt_loss : min_loss; min_loss = crt_loss < min_loss ? crt_loss : min_loss;
INFO_PRINTF("Round %ld, validation loss %f\n", i, crt_loss); INFO_PRINTF("Round %ld, validation loss %f\n", i, crt_loss);
} }
int stop = 1;
MPI_Send(&stop, 1, MPI_INT, mpi_id_from_role_id(TOKENIZER, 0), TAG_TERMT,
MPI_COMM_WORLD);
time_t finish = time(NULL); time_t finish = time(NULL);
float delta_t = finish - start; float delta_t = finish - start;
float delta_l = first_loss - eval_net(frank); float delta_l = first_loss - eval_net(frank);
@@ -363,6 +386,12 @@ void visualizer() {
int main (int argc, const char **argv) { int main (int argc, const char **argv) {
MPI_Init(NULL, NULL); MPI_Init(NULL, NULL);
if (my_mpi_id() == 0) {
if (argc < 2) {
INFO_PRINTLN("NOT ENOUGH INPUTS!");
MPI_Abort(MPI_COMM_WORLD, 1);
}
}
g_argc = argc; g_argc = argc;
// Cython Boilerplate // Cython Boilerplate