diff --git a/bridge.pyx b/bridge.pyx index 84087c0..d2025ae 100644 --- a/bridge.pyx +++ b/bridge.pyx @@ -75,9 +75,10 @@ cdef public void f_idx_list_to_print(float* f_idxs, size_t num): cdef public void c_onehot(float* y, float* idxs, size_t n_idx): - oh = nn.onehot(np.asarray(idxs)) + oh = nn.onehot(np.asarray(idxs), nc=len(nn.vocab)) ensure_contiguous(oh) memcpy(y, PyArray_DATA(oh), oh.size * sizeof(float)) + # eprint(np.argmax(oh, axis=1)) cdef public void c_slices(float* X, float* idxs, size_t bs, size_t win): @@ -85,7 +86,8 @@ cdef public void c_slices(float* X, float* idxs, size_t bs, size_t win): idxs_np = np.asarray(idxs) for r in range(bs): X_np[r, :win] = idxs_np[r:r+win] - X_np[r, win+1:] = idxs_np[r+win+1:r+2*win+1] + X_np[r, win:] = idxs_np[r+win+1:r+win+1+win] + # eprint(X_np) cdef public void debug_print(object o): @@ -93,7 +95,7 @@ cdef public void debug_print(object o): cdef public object create_network(int win, int embed): - return nn.create_cbow_network(win, len(nn.vocab), embed) + return nn.create_cbow_network(win, embed) cdef public void set_net_weights(object net, WeightList* wl): @@ -214,7 +216,7 @@ def inspect_array(a): def ensure_contiguous(a): - assert a.flats['C_CONTIGUOUS'] + assert a.flags['C_CONTIGUOUS'] def eprint(*args, **kwargs): diff --git a/library.py b/library.py index a573384..25561a8 100644 --- a/library.py +++ b/library.py @@ -34,15 +34,17 @@ def create_mnist_network(): return model -def create_cbow_network(win, vocsize, embed): +def create_cbow_network(win, embed): ctxt = tf.keras.layers.Input(shape=[win]) - ed = tf.keras.layers.Embedding(vocsize, embed, input_length=win)(ctxt) - avgd = tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=1))(ed) - mod = tf.keras.Model(inputs=ctxt, outputs=avgd) + ed = tf.keras.layers.Embedding(len(vocab), embed, input_length=win)(ctxt) + cbow = tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=1))(ed) + blowup = tf.keras.layers.Dense(len(vocab), activation='softmax')(cbow) + mod = tf.keras.Model(inputs=ctxt, outputs=blowup) mod.compile( optimizer='sgd', loss='categorical_crossentropy', ) + print(mod, flush=True) return mod diff --git a/main.c b/main.c index 34523ef..98e30bb 100644 --- a/main.c +++ b/main.c @@ -15,8 +15,8 @@ #define TAG_SWORD 7 #define TAG_IWORD 8 -#define COMM 100 -#define ITER 20 +#define COMM 1 +#define ITER 1 #define BS 10 #define EMB 20 #define WIN 2 @@ -67,7 +67,8 @@ size_t number_of(Role what) { - number_of(BATCHER) - number_of(MASTER); case MASTER: - return 1; + return 0; +#warning "set to real number of masters!" } } @@ -167,7 +168,7 @@ void filterer() { void batcher() { // Reads some data and converts it to a float array // INFO_PRINTF("Starting batcher %d\n", getpid()); - // int s = 0; + int s = 0; const size_t n_words = BS + WIN + WIN; float* f_widx = malloc(n_words * sizeof(float)); long l_wid = 0; @@ -187,11 +188,11 @@ void batcher() { // INFO_PRINTF("%5.0f ", f_widx[i]); // } // INFO_PRINTLN(""); - // MPI_Recv(&s, 1, MPI_INT, MPI_ANY_SOURCE, TAG_READY, MPI_COMM_WORLD, - // MPI_STATUS_IGNORE); - // if (s != -1) { - // MPI_Send(f_widx, n_words, MPI_FLOAT, s, TAG_BATCH, MPI_COMM_WORLD); - // } + MPI_Recv(&s, 1, MPI_INT, MPI_ANY_SOURCE, TAG_READY, MPI_COMM_WORLD, + MPI_STATUS_IGNORE); + if (s != -1) { + MPI_Send(f_widx, n_words, MPI_FLOAT, s, TAG_BATCH, MPI_COMM_WORLD); + } } free(f_widx); } @@ -253,10 +254,10 @@ void slave_node() { float* f_widx = malloc(n_words * sizeof(float)); for in_range(i, COMM) { - MPI_Send(&me, 1, MPI_INT, mpi_id_from_role_id(MASTER, 0), - TAG_READY, MPI_COMM_WORLD); - recv_weights(&wl, mpi_id_from_role_id(MASTER, 0), TAG_WEIGH); - set_net_weights(net, &wl); + // MPI_Send(&me, 1, MPI_INT, mpi_id_from_role_id(MASTER, 0), + // TAG_READY, MPI_COMM_WORLD); + // recv_weights(&wl, mpi_id_from_role_id(MASTER, 0), TAG_WEIGH); + // set_net_weights(net, &wl); for in_range(k, ITER) { MPI_Send(&me, 1, MPI_INT, mpi_id_from_role_id(BATCHER, 0), TAG_READY, MPI_COMM_WORLD); @@ -267,9 +268,9 @@ void slave_node() { c_onehot(y, f_widx + WIN, BS); step_net(net, X, y, BS); } - printf("%d net: %f\n", my_mpi_id(), eval_net(net)); + // printf("%d net: %f\n", my_mpi_id(), eval_net(net)); update_weightlist(&wl, net); - send_weights(&wl, mpi_id_from_role_id(MASTER, 0), TAG_WEIGH); + // send_weights(&wl, mpi_id_from_role_id(MASTER, 0), TAG_WEIGH); } Py_DECREF(net); free_weightlist(&wl); @@ -345,6 +346,9 @@ int main (int argc, const char **argv) { case BATCHER: batcher(); break; + case SLAVE: + slave_node(); + break; default: INFO_PRINTLN("DYING HORRIBLY!"); // case SLAVE: slave_node(); break;