diff --git a/main.c b/main.c index 69cda07..0583c2e 100644 --- a/main.c +++ b/main.c @@ -15,9 +15,9 @@ #define TAG_SWORD 7 #define TAG_IWORD 8 -#define COMM 5 -#define ITER 200 -#define BS 32 +#define COMM 1 +#define ITER 1000 +#define BS 10 #define EMB 20 #define WIN 2 #define FSPC 1 @@ -147,54 +147,56 @@ void tokenizer(const char* source) { void filterer() { Word w = {0, NULL}; - long idx; + const size_t bufsize = 2 * WIN + 1; + long* idx = malloc(bufsize * sizeof(long)); + size_t have = 0; while (1) { - recv_word(&w, role_id_from_mpi_id(TOKENIZER, 0)); - if (!strlen(w.data)) { - break; - } - idx = vocab_idx_of(&w); - if (idx != -1) { - MPI_Send(&idx, 1, MPI_LONG, mpi_id_from_role_id(BATCHER, 0), - TAG_IWORD, MPI_COMM_WORLD); + while (have < bufsize) { + recv_word(&w, role_id_from_mpi_id(TOKENIZER, 0)); + if (!strlen(w.data)) break; + idx[have] = vocab_idx_of(&w); + if (idx[have] != -1) have++; } + if (!strlen(w.data)) break; + have = 0; + MPI_Send(idx, bufsize, MPI_LONG, mpi_id_from_role_id(BATCHER, 0), + TAG_IWORD, MPI_COMM_WORLD); } - idx = -1; - MPI_Send(&idx, 1, MPI_LONG, mpi_id_from_role_id(BATCHER, 0), + idx[0] = -1; + MPI_Send(idx, bufsize, MPI_LONG, mpi_id_from_role_id(BATCHER, 0), TAG_IWORD, MPI_COMM_WORLD); free_word(&w); + free(idx); } void batcher() { // Reads some data and converts it to a float array // INFO_PRINTF("Starting batcher %d\n", getpid()); int s = 0; - const size_t n_words = BS + WIN + WIN; - float* f_widx = malloc(n_words * sizeof(float)); - long l_wid = 0; + const size_t entry_size = 2 * WIN + 1; + const size_t bufsize = BS * entry_size; + float* batch = malloc(bufsize * sizeof(float)); + long* l_wid = malloc(entry_size * sizeof(long)); - while (l_wid != -1) { - - for in_range(i, n_words) { - MPI_Recv(&l_wid, 1, MPI_LONG, mpi_id_from_role_id(FILTERER, 0), + while (1) { + for in_range(r, BS) { + MPI_Recv(l_wid, entry_size, MPI_LONG, + mpi_id_from_role_id(FILTERER, 0), TAG_IWORD, MPI_COMM_WORLD, MPI_STATUS_IGNORE); - if (l_wid == -1) break; - f_widx[i] = (float)l_wid; + if (l_wid[0] == -1) break; + for in_range(c, entry_size) { + batch[r*entry_size + c] = (float)l_wid[c]; + } } - if (l_wid == -1) break; + if (l_wid[0] == -1) break; + cbow_batch(batch, BS, WIN); - // f_idx_list_to_c_string(f_widx, n_words); - // for in_range(i, n_words) { - // INFO_PRINTF("%5.0f ", f_widx[i]); - // } - // INFO_PRINTLN(""); - MPI_Recv(&s, 1, MPI_INT, MPI_ANY_SOURCE, TAG_READY, MPI_COMM_WORLD, - MPI_STATUS_IGNORE); - if (s != -1) { - MPI_Send(f_widx, n_words, MPI_FLOAT, s, TAG_BATCH, MPI_COMM_WORLD); - } + // MPI_Recv(&s, 1, MPI_INT, MPI_ANY_SOURCE, TAG_READY, MPI_COMM_WORLD, + // MPI_STATUS_IGNORE); + // MPI_Send(batch, bufsize, MPI_FLOAT, s, TAG_BATCH, MPI_COMM_WORLD); } - free(f_widx); + free(l_wid); + free(batch); } void free_weightlist(WeightList* wl) { @@ -265,8 +267,9 @@ void slave_node() { MPI_Recv(f_widx, n_words, MPI_FLOAT, mpi_id_from_role_id(BATCHER, 0), TAG_BATCH, MPI_COMM_WORLD, MPI_STATUS_IGNORE); - cbow_batch(X, y, f_widx, BS, WIN); - step_net(net, X, y, BS); + // cbow_batch(X, y, f_widx, BS, WIN); + step_net(net, X, BS); +#warning "fix this" INFO_PRINTLN("."); } printf("%d net: %f\n", my_mpi_id(), eval_net(net)); @@ -347,9 +350,9 @@ int main (int argc, const char **argv) { case BATCHER: batcher(); break; - case SLAVE: - slave_node(); - break; + // case SLAVE: + // slave_node(); + // break; default: INFO_PRINTLN("DYING HORRIBLY!"); // case SLAVE: slave_node(); break;