visualizer kinda working, more work needs to be done

2019-12-13 11:05:38 -08:00
parent 1338df9db2
commit 4a9e328884
5 changed files with 82 additions and 13 deletions

.gitignore (vendored): 1 line changed

@@ -4,5 +4,6 @@ run
compile_commands.json
cfg.json
build/
+trained/
__pycache__/
data_*/

(Cython .pyx bridge module; filename not shown in this view)

@@ -7,7 +7,7 @@ from libc.stdlib cimport malloc, realloc
from libc.string cimport memcpy
import library as nn
-import flask
+import server as srv

tokenizers = {}
@@ -44,7 +44,12 @@ cdef public char *greeting():
cdef public void serve():
-    nn.app.run(port=8448)
+    srv.serve()
+
+
+cdef public void bump_count():
+    eprint(f'bumping count from {srv.counter} to {srv.counter + 1}')
+    srv.counter += 1

cdef public size_t getwin():
@@ -143,6 +148,14 @@ cdef public float eval_net(object net):
    return nn.eval_network(net)

+cdef public void ckpt_net(object net):
+    nn.ckpt_network(net)
+
+
+cdef public void save_emb(object net):
+    nn.save_embeddings(nn.get_embeddings(net))
+
cdef public void init_weightlist_like(WeightList* wl, object net):
    weights = net.get_weights()
    wl.n_weights = len(weights)

library.py

@@ -18,6 +18,7 @@ def read_cfg():
CFG = read_cfg()
DATA = os.path.join(HERE, CFG['data'])
+RESULTS = os.path.join(HERE, 'trained')
CORPUS = os.path.join(DATA, 'corpus.txt')
VOCAB = os.path.join(DATA, 'vocab.txt')
TEST = os.path.join(DATA, 'test.txt')
@@ -94,3 +95,16 @@ def token_generator(filename):
        tok = word_tokenize(l)
        if tok:
            yield tok
+
+
+def get_embeddings(net):
+    return net.get_weights()[0]
+
+
+def save_embeddings(emb):
+    import numpy as np
+    np.savetxt(os.path.join(RESULTS, f'embeddings_{CFG["data"]}.csv'), emb)
+
+
+def ckpt_network(net):
+    net.save_weights(os.path.join(RESULTS, f'model_ckpt_{CFG["data"]}.h5'))
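As a side note, a minimal sketch of how the saved embedding matrix could be read back later; load_embeddings is a hypothetical helper, not part of this commit, and the parameter names are assumptions:

# Hypothetical helper (not in this commit): reload the matrix written by
# save_embeddings() above. np.savetxt stores a plain-text 2-D array, so
# np.loadtxt recovers it as float64.
import os
import numpy as np

def load_embeddings(results_dir, dataset):
    # reads back e.g. trained/embeddings_<dataset>.csv
    return np.loadtxt(os.path.join(results_dir, f'embeddings_{dataset}.csv'))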

main.c: 37 lines changed

@@ -3,6 +3,7 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
+#include <unistd.h>
#include <mpi.h>

#define TAG_IDGAF 0
@@ -67,13 +68,18 @@ size_t number_of(Role what) {
- number_of(DISPATCHER) - number_of(DISPATCHER)
- number_of(VISUALIZER); - number_of(VISUALIZER);
case VISUALIZER: case VISUALIZER:
return 0; return 1;
case DISPATCHER: case DISPATCHER:
return 1; return 1;
} }
} }
int mpi_id_from_role_id(Role role, int rid) { int mpi_id_from_role_id(Role role, int rid) {
if (rid >= number_of(role) || rid < 0) {
INFO_PRINTF("There aren't %d of %d (but %lu)\n",
rid, role, number_of(role));
MPI_Abort(MPI_COMM_WORLD, 1);
}
int base = 0; int base = 0;
for (Role r = TOKENIZER; r < role; r++) { for (Role r = TOKENIZER; r < role; r++) {
base += number_of(r); base += number_of(r);
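For intuition, a hedged Python sketch of the rank layout this mapping implies: ranks are assigned role by role, so an instance's global rank is the total count of all preceding roles plus its per-role id. Only TOKENIZER, DISPATCHER, and VISUALIZER appear in this diff; the WORKER role and all counts below are made up for illustration:

# Illustrative sketch of the role -> MPI-rank layout behind mpi_id_from_role_id().
# WORKER and every count below are assumptions, not taken from this codebase.
ROLE_COUNTS = {'TOKENIZER': 2, 'WORKER': 2, 'DISPATCHER': 1, 'VISUALIZER': 1}

def mpi_id_from_role_id(role, rid):
    if not 0 <= rid < ROLE_COUNTS[role]:
        raise ValueError(f'there is no instance {rid} of role {role}')
    roles = list(ROLE_COUNTS)                        # dict order == rank order
    base = sum(ROLE_COUNTS[r] for r in roles[:roles.index(role)])
    return base + rid

# With the counts above: mpi_id_from_role_id('DISPATCHER', 0) == 4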
@@ -146,20 +152,16 @@ void ssend_word(Word* w, int dest) {
  MPI_Ssend(w->data, len + 1, MPI_CHAR, dest, TAG_SWORD, MPI_COMM_WORLD);
}

-int recv_word(Word* w, int src) {
-  // WAT is going on here I have no idea
+void recv_word(Word* w, int src) {
  long len;
-  MPI_Status stat;
  MPI_Recv(&len, 1, MPI_LONG, src, TAG_STLEN, MPI_COMM_WORLD,
-           &stat);
-  int the_src = stat.MPI_SOURCE;
+           MPI_STATUS_IGNORE);
  if (w->mem < len + 1) {
    w->mem = len + 1;
    w->data = realloc(w->data, sizeof(char) * w->mem);
  }
-  MPI_Recv(w->data, len + 1, MPI_CHAR, the_src, TAG_SWORD, MPI_COMM_WORLD,
+  MPI_Recv(w->data, len + 1, MPI_CHAR, src, TAG_SWORD, MPI_COMM_WORLD,
           MPI_STATUS_IGNORE);
-  return the_src;
}

void send_window(long* window, size_t winsize, int dest) {
@@ -192,7 +194,11 @@ void tokenizer(const char* source) {
      ssend_word(&wl.words[i], next);
      sync_ctr = 0;
    } else {
-      send_word(&wl.words[i], next);
+      if (rand() % 100) {
+        // drop a word here and there
+        // probably would make sense if there was less data
+        send_word(&wl.words[i], next);
+      }
    }
    sync_ctr++;
  }
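On the drop rate implied above: rand() % 100 is zero only when the draw is an exact multiple of 100, so roughly one word in a hundred is silently skipped rather than sent on. A purely illustrative Python stand-in for that coin flip:

# Purely illustrative equivalent of the `if (rand() % 100)` guard above:
# True about 99% of the time, so roughly one word in a hundred is dropped.
import random

def should_send():
    return random.randrange(100) != 0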
@@ -398,6 +404,9 @@ void dispatcher() {
      crt_loss = eval_net(frank);
      min_loss = crt_loss < min_loss ? crt_loss : min_loss;
      INFO_PRINTF("Round %ld, validation loss %f\n", rounds, crt_loss);
+
+      ckpt_net(frank);
+
      rounds++;
    }
  time_t finish = time(NULL);
@@ -412,10 +421,12 @@ void dispatcher() {
TAG_INSTR, MPI_COMM_WORLD); TAG_INSTR, MPI_COMM_WORLD);
} }
save_emb(frank);
float delta_t = finish - start; float delta_t = finish - start;
float delta_l = first_loss - crt_loss; float delta_l = first_loss - crt_loss;
INFO_PRINTF( INFO_PRINTF(
"WIKI MPI adam consecutive_batch " "Moby MPI adam consecutive_batch "
"W%lu E%lu BS%lu bpe%lu LPR%d pp%lu," "W%lu E%lu BS%lu bpe%lu LPR%d pp%lu,"
"%f,%f,%f,%f," "%f,%f,%f,%f,"
"%lu,%.0f,%lu\n", "%lu,%.0f,%lu\n",
@@ -434,6 +445,10 @@ void dispatcher() {
void visualizer() {
  INFO_PRINTF("Starting visualizer %d\n", getpid());
  serve();
+  while (1) {
+    sleep(1);
+    bump_count();
+  }
}
int main (int argc, const char **argv) { int main (int argc, const char **argv) {
@@ -445,7 +460,7 @@ int main (int argc, const char **argv) {
MPI_Abort(MPI_COMM_WORLD, 1); MPI_Abort(MPI_COMM_WORLD, 1);
} }
int pipelines = argc - 1; int pipelines = argc - 1;
int min_nodes = 4 * pipelines + 1; int min_nodes = 4 * pipelines + 2;
if (world_size() < min_nodes) { if (world_size() < min_nodes) {
INFO_PRINTF("You requested %d pipeline(s) " INFO_PRINTF("You requested %d pipeline(s) "
"but only provided %d procs " "but only provided %d procs "

server.py (new file): 26 lines added

@@ -0,0 +1,26 @@
+from threading import Thread
+from sys import stderr
+
+import flask
+
+t = None
+app = flask.Flask(__name__)
+counter = 0
+
+
+@app.route('/')
+def main():
+    return f'Hello {counter}'
+
+
+def serve():
+    global t
+    if t is None:
+        t = Thread(target=app.run, kwargs={'port': 8448})
+        t.start()
+        print('So I kinda started', flush=True, file=stderr)
+
+
+if __name__ == '__main__':
+    serve()
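Once the server thread is up (via serve() above, or the visualizer rank in main.c), the counter can be sanity-checked over HTTP. A minimal sketch, assuming the default localhost binding and the port 8448 used above:

# Minimal smoke test, assuming server.py is serving on localhost:8448.
# The body should read "Hello <n>", with <n> growing while bump_count() runs.
import urllib.request

with urllib.request.urlopen('http://localhost:8448/') as resp:
    print(resp.read().decode())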