visualizer kinda working, more work needs done
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -4,5 +4,6 @@ run
|
|||||||
compile_commands.json
|
compile_commands.json
|
||||||
cfg.json
|
cfg.json
|
||||||
build/
|
build/
|
||||||
|
trained/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
data_*/
|
data_*/
|
||||||
|
|||||||
17
bridge.pyx
17
bridge.pyx
@@ -7,7 +7,7 @@ from libc.stdlib cimport malloc, realloc
|
|||||||
from libc.string cimport memcpy
|
from libc.string cimport memcpy
|
||||||
|
|
||||||
import library as nn
|
import library as nn
|
||||||
import flask
|
import server as srv
|
||||||
|
|
||||||
|
|
||||||
tokenizers = {}
|
tokenizers = {}
|
||||||
@@ -44,7 +44,12 @@ cdef public char *greeting():
|
|||||||
|
|
||||||
|
|
||||||
cdef public void serve():
|
cdef public void serve():
|
||||||
nn.app.run(port=8448)
|
srv.serve()
|
||||||
|
|
||||||
|
|
||||||
|
cdef public void bump_count():
|
||||||
|
eprint(f'bumping count from {srv.counter} to {srv.counter + 1}')
|
||||||
|
srv.counter += 1
|
||||||
|
|
||||||
|
|
||||||
cdef public size_t getwin():
|
cdef public size_t getwin():
|
||||||
@@ -143,6 +148,14 @@ cdef public float eval_net(object net):
|
|||||||
return nn.eval_network(net)
|
return nn.eval_network(net)
|
||||||
|
|
||||||
|
|
||||||
|
cdef public void ckpt_net(object net):
|
||||||
|
nn.ckpt_network(net)
|
||||||
|
|
||||||
|
|
||||||
|
cdef public void save_emb(object net):
|
||||||
|
nn.save_embeddings(nn.get_embeddings(net))
|
||||||
|
|
||||||
|
|
||||||
cdef public void init_weightlist_like(WeightList* wl, object net):
|
cdef public void init_weightlist_like(WeightList* wl, object net):
|
||||||
weights = net.get_weights()
|
weights = net.get_weights()
|
||||||
wl.n_weights = len(weights)
|
wl.n_weights = len(weights)
|
||||||
|
|||||||
14
library.py
14
library.py
@@ -18,6 +18,7 @@ def read_cfg():
|
|||||||
|
|
||||||
CFG = read_cfg()
|
CFG = read_cfg()
|
||||||
DATA = os.path.join(HERE, CFG['data'])
|
DATA = os.path.join(HERE, CFG['data'])
|
||||||
|
RESULTS = os.path.join(HERE, 'trained')
|
||||||
CORPUS = os.path.join(DATA, 'corpus.txt')
|
CORPUS = os.path.join(DATA, 'corpus.txt')
|
||||||
VOCAB = os.path.join(DATA, 'vocab.txt')
|
VOCAB = os.path.join(DATA, 'vocab.txt')
|
||||||
TEST = os.path.join(DATA, 'test.txt')
|
TEST = os.path.join(DATA, 'test.txt')
|
||||||
@@ -94,3 +95,16 @@ def token_generator(filename):
|
|||||||
tok = word_tokenize(l)
|
tok = word_tokenize(l)
|
||||||
if tok:
|
if tok:
|
||||||
yield tok
|
yield tok
|
||||||
|
|
||||||
|
|
||||||
|
def get_embeddings(net):
|
||||||
|
return net.get_weights()[0]
|
||||||
|
|
||||||
|
|
||||||
|
def save_embeddings(emb):
|
||||||
|
import numpy as np
|
||||||
|
np.savetxt(os.path.join(RESULTS, f'embeddings_{CFG["data"]}.csv'), emb)
|
||||||
|
|
||||||
|
|
||||||
|
def ckpt_network(net):
|
||||||
|
net.save_weights(os.path.join(RESULTS, f'model_ckpt_{CFG["data"]}.h5'))
|
||||||
|
|||||||
37
main.c
37
main.c
@@ -3,6 +3,7 @@
|
|||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
#include <unistd.h>
|
||||||
#include <mpi.h>
|
#include <mpi.h>
|
||||||
|
|
||||||
#define TAG_IDGAF 0
|
#define TAG_IDGAF 0
|
||||||
@@ -67,13 +68,18 @@ size_t number_of(Role what) {
|
|||||||
- number_of(DISPATCHER)
|
- number_of(DISPATCHER)
|
||||||
- number_of(VISUALIZER);
|
- number_of(VISUALIZER);
|
||||||
case VISUALIZER:
|
case VISUALIZER:
|
||||||
return 0;
|
return 1;
|
||||||
case DISPATCHER:
|
case DISPATCHER:
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int mpi_id_from_role_id(Role role, int rid) {
|
int mpi_id_from_role_id(Role role, int rid) {
|
||||||
|
if (rid >= number_of(role) || rid < 0) {
|
||||||
|
INFO_PRINTF("There aren't %d of %d (but %lu)\n",
|
||||||
|
rid, role, number_of(role));
|
||||||
|
MPI_Abort(MPI_COMM_WORLD, 1);
|
||||||
|
}
|
||||||
int base = 0;
|
int base = 0;
|
||||||
for (Role r = TOKENIZER; r < role; r++) {
|
for (Role r = TOKENIZER; r < role; r++) {
|
||||||
base += number_of(r);
|
base += number_of(r);
|
||||||
@@ -146,20 +152,16 @@ void ssend_word(Word* w, int dest) {
|
|||||||
MPI_Ssend(w->data, len + 1, MPI_CHAR, dest, TAG_SWORD, MPI_COMM_WORLD);
|
MPI_Ssend(w->data, len + 1, MPI_CHAR, dest, TAG_SWORD, MPI_COMM_WORLD);
|
||||||
}
|
}
|
||||||
|
|
||||||
int recv_word(Word* w, int src) {
|
void recv_word(Word* w, int src) {
|
||||||
// WAT is going on here I have no idea
|
|
||||||
long len;
|
long len;
|
||||||
MPI_Status stat;
|
|
||||||
MPI_Recv(&len, 1, MPI_LONG, src, TAG_STLEN, MPI_COMM_WORLD,
|
MPI_Recv(&len, 1, MPI_LONG, src, TAG_STLEN, MPI_COMM_WORLD,
|
||||||
&stat);
|
MPI_STATUS_IGNORE);
|
||||||
int the_src = stat.MPI_SOURCE;
|
|
||||||
if (w->mem < len + 1) {
|
if (w->mem < len + 1) {
|
||||||
w->mem = len + 1;
|
w->mem = len + 1;
|
||||||
w->data = realloc(w->data, sizeof(char) * w->mem);
|
w->data = realloc(w->data, sizeof(char) * w->mem);
|
||||||
}
|
}
|
||||||
MPI_Recv(w->data, len + 1, MPI_CHAR, the_src, TAG_SWORD, MPI_COMM_WORLD,
|
MPI_Recv(w->data, len + 1, MPI_CHAR, src, TAG_SWORD, MPI_COMM_WORLD,
|
||||||
MPI_STATUS_IGNORE);
|
MPI_STATUS_IGNORE);
|
||||||
return the_src;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void send_window(long* window, size_t winsize, int dest) {
|
void send_window(long* window, size_t winsize, int dest) {
|
||||||
@@ -192,7 +194,11 @@ void tokenizer(const char* source) {
|
|||||||
ssend_word(&wl.words[i], next);
|
ssend_word(&wl.words[i], next);
|
||||||
sync_ctr = 0;
|
sync_ctr = 0;
|
||||||
} else {
|
} else {
|
||||||
send_word(&wl.words[i], next);
|
if (rand() % 100) {
|
||||||
|
// drop a word here and there
|
||||||
|
// probably would make sense if there was less data
|
||||||
|
send_word(&wl.words[i], next);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
sync_ctr++;
|
sync_ctr++;
|
||||||
}
|
}
|
||||||
@@ -398,6 +404,9 @@ void dispatcher() {
|
|||||||
crt_loss = eval_net(frank);
|
crt_loss = eval_net(frank);
|
||||||
min_loss = crt_loss < min_loss ? crt_loss : min_loss;
|
min_loss = crt_loss < min_loss ? crt_loss : min_loss;
|
||||||
INFO_PRINTF("Round %ld, validation loss %f\n", rounds, crt_loss);
|
INFO_PRINTF("Round %ld, validation loss %f\n", rounds, crt_loss);
|
||||||
|
|
||||||
|
ckpt_net(frank);
|
||||||
|
|
||||||
rounds++;
|
rounds++;
|
||||||
}
|
}
|
||||||
time_t finish = time(NULL);
|
time_t finish = time(NULL);
|
||||||
@@ -412,10 +421,12 @@ void dispatcher() {
|
|||||||
TAG_INSTR, MPI_COMM_WORLD);
|
TAG_INSTR, MPI_COMM_WORLD);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
save_emb(frank);
|
||||||
|
|
||||||
float delta_t = finish - start;
|
float delta_t = finish - start;
|
||||||
float delta_l = first_loss - crt_loss;
|
float delta_l = first_loss - crt_loss;
|
||||||
INFO_PRINTF(
|
INFO_PRINTF(
|
||||||
"WIKI MPI adam consecutive_batch "
|
"Moby MPI adam consecutive_batch "
|
||||||
"W%lu E%lu BS%lu bpe%lu LPR%d pp%lu,"
|
"W%lu E%lu BS%lu bpe%lu LPR%d pp%lu,"
|
||||||
"%f,%f,%f,%f,"
|
"%f,%f,%f,%f,"
|
||||||
"%lu,%.0f,%lu\n",
|
"%lu,%.0f,%lu\n",
|
||||||
@@ -434,6 +445,10 @@ void dispatcher() {
|
|||||||
void visualizer() {
|
void visualizer() {
|
||||||
INFO_PRINTF("Starting visualizer %d\n", getpid());
|
INFO_PRINTF("Starting visualizer %d\n", getpid());
|
||||||
serve();
|
serve();
|
||||||
|
while (1) {
|
||||||
|
sleep(1);
|
||||||
|
bump_count();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int main (int argc, const char **argv) {
|
int main (int argc, const char **argv) {
|
||||||
@@ -445,7 +460,7 @@ int main (int argc, const char **argv) {
|
|||||||
MPI_Abort(MPI_COMM_WORLD, 1);
|
MPI_Abort(MPI_COMM_WORLD, 1);
|
||||||
}
|
}
|
||||||
int pipelines = argc - 1;
|
int pipelines = argc - 1;
|
||||||
int min_nodes = 4 * pipelines + 1;
|
int min_nodes = 4 * pipelines + 2;
|
||||||
if (world_size() < min_nodes) {
|
if (world_size() < min_nodes) {
|
||||||
INFO_PRINTF("You requested %d pipeline(s) "
|
INFO_PRINTF("You requested %d pipeline(s) "
|
||||||
"but only provided %d procs "
|
"but only provided %d procs "
|
||||||
|
|||||||
26
server.py
Normal file
26
server.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from threading import Thread
|
||||||
|
from sys import stderr
|
||||||
|
|
||||||
|
import flask
|
||||||
|
|
||||||
|
|
||||||
|
t = None
|
||||||
|
app = flask.Flask(__name__)
|
||||||
|
counter = 0
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/')
|
||||||
|
def main():
|
||||||
|
return f'Hello {counter}'
|
||||||
|
|
||||||
|
|
||||||
|
def serve():
|
||||||
|
global t
|
||||||
|
if t is None:
|
||||||
|
t = Thread(target=app.run, kwargs={'port': 8448})
|
||||||
|
t.start()
|
||||||
|
print('So I kinda started', flush=True, file=stderr)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
serve()
|
||||||
Reference in New Issue
Block a user