take a more windowed approach

This commit is contained in:
2019-12-01 15:04:59 -08:00
parent 4bf66bf85e
commit a2ef842ef8

83
main.c
View File

@@ -15,9 +15,9 @@
#define TAG_SWORD 7 #define TAG_SWORD 7
#define TAG_IWORD 8 #define TAG_IWORD 8
#define COMM 5 #define COMM 1
#define ITER 200 #define ITER 1000
#define BS 32 #define BS 10
#define EMB 20 #define EMB 20
#define WIN 2 #define WIN 2
#define FSPC 1 #define FSPC 1
@@ -147,54 +147,56 @@ void tokenizer(const char* source) {
void filterer() { void filterer() {
Word w = {0, NULL}; Word w = {0, NULL};
long idx; const size_t bufsize = 2 * WIN + 1;
long* idx = malloc(bufsize * sizeof(long));
size_t have = 0;
while (1) { while (1) {
recv_word(&w, role_id_from_mpi_id(TOKENIZER, 0)); while (have < bufsize) {
if (!strlen(w.data)) { recv_word(&w, role_id_from_mpi_id(TOKENIZER, 0));
break; if (!strlen(w.data)) break;
} idx[have] = vocab_idx_of(&w);
idx = vocab_idx_of(&w); if (idx[have] != -1) have++;
if (idx != -1) {
MPI_Send(&idx, 1, MPI_LONG, mpi_id_from_role_id(BATCHER, 0),
TAG_IWORD, MPI_COMM_WORLD);
} }
if (!strlen(w.data)) break;
have = 0;
MPI_Send(idx, bufsize, MPI_LONG, mpi_id_from_role_id(BATCHER, 0),
TAG_IWORD, MPI_COMM_WORLD);
} }
idx = -1; idx[0] = -1;
MPI_Send(&idx, 1, MPI_LONG, mpi_id_from_role_id(BATCHER, 0), MPI_Send(idx, bufsize, MPI_LONG, mpi_id_from_role_id(BATCHER, 0),
TAG_IWORD, MPI_COMM_WORLD); TAG_IWORD, MPI_COMM_WORLD);
free_word(&w); free_word(&w);
free(idx);
} }
void batcher() { void batcher() {
// Reads some data and converts it to a float array // Reads some data and converts it to a float array
// INFO_PRINTF("Starting batcher %d\n", getpid()); // INFO_PRINTF("Starting batcher %d\n", getpid());
int s = 0; int s = 0;
const size_t n_words = BS + WIN + WIN; const size_t entry_size = 2 * WIN + 1;
float* f_widx = malloc(n_words * sizeof(float)); const size_t bufsize = BS * entry_size;
long l_wid = 0; float* batch = malloc(bufsize * sizeof(float));
long* l_wid = malloc(entry_size * sizeof(long));
while (l_wid != -1) { while (1) {
for in_range(r, BS) {
for in_range(i, n_words) { MPI_Recv(l_wid, entry_size, MPI_LONG,
MPI_Recv(&l_wid, 1, MPI_LONG, mpi_id_from_role_id(FILTERER, 0), mpi_id_from_role_id(FILTERER, 0),
TAG_IWORD, MPI_COMM_WORLD, MPI_STATUS_IGNORE); TAG_IWORD, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
if (l_wid == -1) break; if (l_wid[0] == -1) break;
f_widx[i] = (float)l_wid; for in_range(c, entry_size) {
batch[r*entry_size + c] = (float)l_wid[c];
}
} }
if (l_wid == -1) break; if (l_wid[0] == -1) break;
cbow_batch(batch, BS, WIN);
// f_idx_list_to_c_string(f_widx, n_words); // MPI_Recv(&s, 1, MPI_INT, MPI_ANY_SOURCE, TAG_READY, MPI_COMM_WORLD,
// for in_range(i, n_words) { // MPI_STATUS_IGNORE);
// INFO_PRINTF("%5.0f ", f_widx[i]); // MPI_Send(batch, bufsize, MPI_FLOAT, s, TAG_BATCH, MPI_COMM_WORLD);
// }
// INFO_PRINTLN("");
MPI_Recv(&s, 1, MPI_INT, MPI_ANY_SOURCE, TAG_READY, MPI_COMM_WORLD,
MPI_STATUS_IGNORE);
if (s != -1) {
MPI_Send(f_widx, n_words, MPI_FLOAT, s, TAG_BATCH, MPI_COMM_WORLD);
}
} }
free(f_widx); free(l_wid);
free(batch);
} }
void free_weightlist(WeightList* wl) { void free_weightlist(WeightList* wl) {
@@ -265,8 +267,9 @@ void slave_node() {
MPI_Recv(f_widx, n_words, MPI_FLOAT, MPI_Recv(f_widx, n_words, MPI_FLOAT,
mpi_id_from_role_id(BATCHER, 0), TAG_BATCH, MPI_COMM_WORLD, mpi_id_from_role_id(BATCHER, 0), TAG_BATCH, MPI_COMM_WORLD,
MPI_STATUS_IGNORE); MPI_STATUS_IGNORE);
cbow_batch(X, y, f_widx, BS, WIN); // cbow_batch(X, y, f_widx, BS, WIN);
step_net(net, X, y, BS); step_net(net, X, BS);
#warning "fix this"
INFO_PRINTLN("."); INFO_PRINTLN(".");
} }
printf("%d net: %f\n", my_mpi_id(), eval_net(net)); printf("%d net: %f\n", my_mpi_id(), eval_net(net));
@@ -347,9 +350,9 @@ int main (int argc, const char **argv) {
case BATCHER: case BATCHER:
batcher(); batcher();
break; break;
case SLAVE: // case SLAVE:
slave_node(); // slave_node();
break; // break;
default: default:
INFO_PRINTLN("DYING HORRIBLY!"); INFO_PRINTLN("DYING HORRIBLY!");
// case SLAVE: slave_node(); break; // case SLAVE: slave_node(); break;