diff --git a/bridge.pyx b/bridge.pyx index 1e801d3..f7eed7c 100644 --- a/bridge.pyx +++ b/bridge.pyx @@ -58,9 +58,9 @@ cdef public int get_tokens(WordList* wl, const char *filename): cdef public long vocab_idx_of(Word* w): word = w.data.decode('utf-8') - if word.lower() in nn.vocab: - return nn.vocab.index(word.lower()) - else: + try: + return nn.vocab.index(word) + except ValueError: return -1 diff --git a/library.py b/library.py index b7cbda7..7a92159 100644 --- a/library.py +++ b/library.py @@ -9,7 +9,7 @@ from mynet import load_mnist, onehot def word_tokenize(s: str): - l = ''.join(c if c.isalpha() else ' ' for c in s) + l = ''.join(c.lower() if c.isalpha() else ' ' for c in s) return l.split() @@ -47,7 +47,7 @@ def create_cbow_network(win, vocab, embed): def token_generator(filename): with open(filename) as f: - for l in f.readlines(500): + for l in f.readlines(): if not l.isspace(): tok = word_tokenize(l) if tok: diff --git a/main.c b/main.c index 174866d..8b85c25 100644 --- a/main.c +++ b/main.c @@ -17,7 +17,7 @@ #define COMM 100 #define ITER 20 -#define BS 50 +#define BS 20 #define EMB 20 #define WIN 2 #define FSPC 1 @@ -156,41 +156,54 @@ void tokenizer(const char* source) { void filterer() { Word w = {0, NULL}; + long idx; while (1) { recv_word(&w, role_id_from_mpi_id(TOKENIZER, 0)); if (!strlen(w.data)) { break; } - INFO_PRINTF("%s: ", w.data); - long idx = vocab_idx_of(&w); - INFO_PRINTF("%ld\n", idx); - // if (idx != -1) { - // MPI_Send(&idx, 1, MPI_LONG, mpi_id_from_role_id(BATCHER, 0), - // TAG_IWORD, MPI_COMM_WORLD); - // } + // INFO_PRINTF("%s: ", w.data); + idx = vocab_idx_of(&w); + // INFO_PRINTF("%ld\n", idx); + if (idx != -1) { + MPI_Send(&idx, 1, MPI_LONG, mpi_id_from_role_id(BATCHER, 0), + TAG_IWORD, MPI_COMM_WORLD); + } } + idx = -1; + MPI_Send(&idx, 1, MPI_LONG, mpi_id_from_role_id(BATCHER, 0), + TAG_IWORD, MPI_COMM_WORLD); free_word(&w); } void batcher() { // Reads some data and converts it to a float array - INFO_PRINTF("Starting batcher %d\n", getpid()); - int s = 0; + // INFO_PRINTF("Starting batcher %d\n", getpid()); + // int s = 0; const size_t n_words = BS + WIN + WIN; float* f_widx = malloc(n_words * sizeof(float)); + long l_wid = 0; + + while (l_wid != -1) { - while (s != -1) { for in_range(i, n_words) { - long l_wid; - MPI_Recv(&l_wid, 1, MPI_LONG, role_id_from_mpi_id(FILTERER, 0), + MPI_Recv(&l_wid, 1, MPI_LONG, mpi_id_from_role_id(FILTERER, 0), TAG_IWORD, MPI_COMM_WORLD, MPI_STATUS_IGNORE); + if (l_wid == -1) break; f_widx[i] = (float)l_wid; } - MPI_Recv(&s, 1, MPI_INT, MPI_ANY_SOURCE, TAG_READY, MPI_COMM_WORLD, - MPI_STATUS_IGNORE); - if (s != -1) { - MPI_Send(f_widx, n_words, MPI_FLOAT, s, TAG_BATCH, MPI_COMM_WORLD); + if (l_wid == -1) break; + + for in_range(i, n_words) { + INFO_PRINTF("%5.0f ", f_widx[i]); } + + INFO_PRINTLN(""); + // MPI_Recv(&s, 1, MPI_INT, MPI_ANY_SOURCE, TAG_READY, MPI_COMM_WORLD, + // MPI_STATUS_IGNORE); + // if (s != -1) { + // MPI_Send(f_widx, n_words, MPI_FLOAT, s, TAG_BATCH, MPI_COMM_WORLD); + // } } free(f_widx); } @@ -333,6 +346,9 @@ int main (int argc, const char **argv) { case FILTERER: filterer(); break; + case BATCHER: + batcher(); + break; default: INFO_PRINTLN("DYING HORRIBLY!"); // case SLAVE: slave_node(); break;