this is the baseline for evaluation
This commit is contained in:
58
main.c
58
main.c
@@ -16,14 +16,14 @@
|
||||
#define TAG_IWORD 8
|
||||
#define TAG_INSTR 9
|
||||
|
||||
#define COMM 500
|
||||
#define ITER 50
|
||||
#define BS 64
|
||||
#define EMB 20
|
||||
#define COMM 25
|
||||
#define ITER 690
|
||||
#define BS 32
|
||||
#define EMB 32
|
||||
#define WIN 2
|
||||
#define FLPC 1
|
||||
|
||||
#define in_range(i, x) (size_t (i) = 0; (i) < (x); (i)++)
|
||||
#define in_range(i, x) (size_t i = 0; i < (x); i++)
|
||||
// I am honestly VERY sorry for this but power corrupts even the best of us
|
||||
|
||||
#define INFO_PRINTF(fmt, ...) \
|
||||
@@ -33,7 +33,7 @@
|
||||
#define INFO_PRINT(what) \
|
||||
do { fprintf(stderr, "%s", what); } while(0)
|
||||
|
||||
int g_argc = 1;
|
||||
int g_argc; // sorry!
|
||||
|
||||
typedef enum{
|
||||
TOKENIZER,
|
||||
@@ -76,7 +76,7 @@ size_t number_of(Role what) {
|
||||
- number_of(DISPATCHER)
|
||||
- number_of(VISUALIZER);
|
||||
case VISUALIZER:
|
||||
return 1;
|
||||
return 0;
|
||||
case DISPATCHER:
|
||||
return 1;
|
||||
}
|
||||
@@ -109,6 +109,7 @@ Role map_node() {
|
||||
}
|
||||
INFO_PRINTF("Something went wrong for node %d\n", node);
|
||||
MPI_Abort(MPI_COMM_WORLD, 1); // this is bad
|
||||
return -1; // Not going to happen anyway (i hope)
|
||||
}
|
||||
|
||||
void announce_ready(int dest) {
|
||||
@@ -144,6 +145,12 @@ void send_word(Word* w, int dest) {
|
||||
MPI_Send(w->data, len + 1, MPI_CHAR, dest, TAG_SWORD, MPI_COMM_WORLD);
|
||||
}
|
||||
|
||||
void ssend_word(Word* w, int dest) {
|
||||
long len = strlen(w->data);
|
||||
MPI_Ssend(&len, 1, MPI_LONG, dest, TAG_STLEN, MPI_COMM_WORLD);
|
||||
MPI_Ssend(w->data, len + 1, MPI_CHAR, dest, TAG_SWORD, MPI_COMM_WORLD);
|
||||
}
|
||||
|
||||
int recv_word(Word* w, int src) {
|
||||
long len;
|
||||
MPI_Status stat;
|
||||
@@ -162,10 +169,16 @@ int recv_word(Word* w, int src) {
|
||||
void tokenizer(const char* source) {
|
||||
INFO_PRINTF("Starting tokenizer %d\n", getpid());
|
||||
WordList wl = {0, 0, NULL};
|
||||
size_t sync_ctr = 0;
|
||||
while (get_tokens(&wl, source)) {
|
||||
for in_range(i, wl.n_words) {
|
||||
// int tok = wait_for_ready();
|
||||
send_word(&wl.words[i], mpi_id_from_role_id(FILTERER, 0));
|
||||
if (sync_ctr == 1000) {
|
||||
ssend_word(&wl.words[i], mpi_id_from_role_id(FILTERER, 0));
|
||||
sync_ctr = 0;
|
||||
} else {
|
||||
send_word(&wl.words[i], mpi_id_from_role_id(FILTERER, 0));
|
||||
}
|
||||
sync_ctr++;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -184,8 +197,6 @@ void filterer() {
|
||||
while (1) {
|
||||
int stream_offs;
|
||||
while (have[src] != entry_size) {
|
||||
// src = rand() % num_streams;
|
||||
// announce_ready(role_id_from_mpi_id(TOKENIZER, src));
|
||||
src = recv_word(&w, MPI_ANY_SOURCE);
|
||||
src = role_id_from_mpi_id(TOKENIZER, src);
|
||||
stream_offs = src*entry_size;
|
||||
@@ -218,7 +229,6 @@ void batcher() {
|
||||
batch[r*entry_size + c] = (float)l_wid[c];
|
||||
}
|
||||
}
|
||||
printf(".");
|
||||
MPI_Recv(&s, 1, MPI_INT, MPI_ANY_SOURCE, TAG_READY, MPI_COMM_WORLD,
|
||||
MPI_STATUS_IGNORE);
|
||||
MPI_Send(batch, bufsize, MPI_FLOAT, s, TAG_BATCH, MPI_COMM_WORLD);
|
||||
@@ -304,13 +314,16 @@ void dispatcher() {
|
||||
update_weightlist(&wl, frank);
|
||||
|
||||
int lpr = number_of(LEARNER) * FLPC; // Learners per round
|
||||
|
||||
WeightList *wls = malloc(sizeof(WeightList) * lpr);
|
||||
int *round = malloc(sizeof(int) * lpr);
|
||||
|
||||
for in_range(i, lpr) {
|
||||
init_weightlist_like(wls + i, frank);
|
||||
}
|
||||
int *round = malloc(sizeof(int) * lpr);
|
||||
|
||||
float first_loss = eval_net(frank);
|
||||
float crt_loss = first_loss;
|
||||
float min_loss = crt_loss;
|
||||
time_t start = time(NULL);
|
||||
for in_range(i, COMM) {
|
||||
randidx(round, number_of(LEARNER), lpr);
|
||||
|
||||
@@ -324,8 +337,17 @@ void dispatcher() {
|
||||
}
|
||||
combo_weights(&wl, wls, lpr);
|
||||
set_net_weights(frank, &wl);
|
||||
INFO_PRINTF("Frank: %f\n", eval_net(frank));
|
||||
crt_loss = eval_net(frank);
|
||||
min_loss = crt_loss < min_loss ? crt_loss : min_loss;
|
||||
INFO_PRINTF("Round %ld, validation loss %f\n", i, crt_loss);
|
||||
}
|
||||
time_t finish = time(NULL);
|
||||
float delta_t = finish - start;
|
||||
float delta_l = first_loss - eval_net(frank);
|
||||
INFO_PRINTF(
|
||||
"Laptop MPI sgd consecutive_batch W%d E%d BS%d R%d bpe%d LPR%d,"
|
||||
"%f,%f,%f\n", WIN, EMB, BS, COMM, ITER, lpr,
|
||||
delta_l / COMM, delta_l / delta_t, min_loss);
|
||||
Py_DECREF(frank);
|
||||
free_weightlist(&wl);
|
||||
for in_range(i, lpr) free_weightlist(wls + i);
|
||||
@@ -338,10 +360,11 @@ void visualizer() {
|
||||
serve();
|
||||
}
|
||||
|
||||
|
||||
int main (int argc, const char **argv) {
|
||||
MPI_Init(NULL, NULL);
|
||||
|
||||
g_argc = argc;
|
||||
|
||||
// Cython Boilerplate
|
||||
PyImport_AppendInittab("bridge", PyInit_bridge);
|
||||
Py_Initialize();
|
||||
@@ -350,7 +373,6 @@ int main (int argc, const char **argv) {
|
||||
|
||||
// Actual Code
|
||||
int role_id;
|
||||
g_argc = argc;
|
||||
switch (map_node()) {
|
||||
case TOKENIZER:
|
||||
role_id = role_id_from_mpi_id(TOKENIZER, my_mpi_id());
|
||||
|
||||
Reference in New Issue
Block a user