small refactoring and kicked the visualizer again

2019-12-15 09:51:22 -08:00
parent 06d0b2d565
commit 05480606b0
4 changed files with 82 additions and 30 deletions

View File

@@ -47,6 +47,12 @@ cdef public void serve():
     srv.serve()
 
+cdef public void server_update(float *emb):
+    embeddings = np.asarray(<float[:getvocsize(),:getemb()]>emb)
+    low_dim = nn.calc_TSNE(embeddings)
+    srv.emb_map = dict(zip(nn.inv_vocab, low_dim))
+
 cdef public size_t getwin():
     return nn.WIN
@@ -67,6 +73,10 @@ cdef public float gettarget():
     return nn.CFG['target']
 
+cdef public size_t getvocsize():
+    return len(nn.vocab)
+
 cdef public int get_tokens(WordList* wl, const char *filename):
     fnu = filename.decode('utf-8')
     if fnu not in tokenizers:
@@ -101,8 +111,8 @@ cdef public void _dbg_print(object o):
     eprint(o)
 
-cdef public void _dbg_print_cbow_batch(float* batch, size_t bs):
-    X_np, y_np = cbow_batch(batch, bs)
+cdef public void _dbg_print_cbow_batch(float* batch):
+    X_np, y_np = cbow_batch(batch)
     eprint(X_np)
     eprint(y_np)
@@ -124,10 +134,8 @@ cdef public void set_net_weights(object net, WeightList* wl):
     net.set_weights(wrap_weight_list(wl))
 
-cdef public void step_net(
-    object net, float* batch, size_t bs
-):
-    X_train, y_train = cbow_batch(batch, bs)
+cdef public void step_net(object net, float* batch):
+    X_train, y_train = cbow_batch(batch)
     net.train_on_batch(X_train, y_train)
@@ -183,8 +191,9 @@ cdef public void combo_weights(
         wf += alpha * ww
 
-cdef tuple cbow_batch(float* batch, size_t bs):
-    win = nn.WIN
+cdef tuple cbow_batch(float* batch):
+    win = getwin()
+    bs = getbs()
     batch_np = np.asarray(<float[:bs,:2*win+1]>batch)
     X_np = batch_np[:, [*range(win), *range(win+1, win+win+1)]]
     y_np = nn.onehot(batch_np[:, win], nc=len(nn.vocab))
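
Note on the cbow_batch change: each batch row is one token window of length 2*win+1, and the center column is the prediction target. Below is a small standalone NumPy sketch of the same slicing, with made-up values for win and bs (the real values come from getwin()/getbs()):

import numpy as np

# Hypothetical window radius and batch size; the module reads these
# from the Python config via getwin()/getbs().
win, bs = 2, 4
batch_np = np.arange(bs * (2*win + 1), dtype=np.float32).reshape(bs, 2*win + 1)

# Context columns: every column except the center one (index `win`).
X_np = batch_np[:, [*range(win), *range(win + 1, win + win + 1)]]
# Center column: the target word ids (one-hot encoded in the real code).
y_ids = batch_np[:, win]

print(X_np.shape)   # (4, 4) -> bs rows of 2*win context tokens
print(y_ids.shape)  # (4,)   -> one target per row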

View File

@@ -113,6 +113,13 @@ def get_embeddings(net):
     return net.get_weights()[0]
 
+def calc_TSNE(emb):
+    # import umap
+    # return umap.UMAP().fit_transform(emb)
+    return emb
+
 def save_embeddings(emb):
     import numpy as np
     np.savetxt(os.path.join(RESULTS, f'embeddings_{CFG["data_name"]}.csv'),
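
calc_TSNE is a placeholder for now: it returns the embeddings unchanged, with the UMAP projection left commented out. A sketch of what the enabled version might look like (both library choices here are assumptions; only the umap route is hinted at by the commented-out lines):

def calc_TSNE(emb):
    # Assumed: umap-learn is installed; otherwise fall back to
    # scikit-learn's TSNE. Either maps the (vocab_size, emb_dim)
    # matrix down to 2-D for display.
    try:
        import umap
        return umap.UMAP(n_components=2).fit_transform(emb)
    except ImportError:
        from sklearn.manifold import TSNE
        return TSNE(n_components=2).fit_transform(emb)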

main.c
View File

@@ -17,6 +17,7 @@
 #define TAG_IWIND 8
 #define TAG_INSTR 9
 #define TAG_TERMT 10
+#define TAG_EMBED 11
 
 #define in_range(i, x) (size_t i = 0; i < (x); i++)
 // I am honestly VERY sorry for this
@@ -33,7 +34,7 @@ int g_argc; // sorry!
 typedef enum {
     TOKENIZER,
-    FILTERER,
+    FILTER,
     BATCHER,
     LEARNER,
     VISUALIZER,
@@ -56,19 +57,19 @@ size_t number_of(Role what) {
     switch (what) {
     case TOKENIZER:
         return g_argc - 1;
-    case FILTERER:
+    case FILTER:
         return number_of(TOKENIZER);
     case BATCHER:
         return number_of(TOKENIZER);
     case LEARNER:
         return world_size()
             - number_of(TOKENIZER)
-            - number_of(FILTERER)
+            - number_of(FILTER)
             - number_of(BATCHER)
             - number_of(DISPATCHER)
             - number_of(VISUALIZER);
     case VISUALIZER:
-        return 1;
+        return 0;
     case DISPATCHER:
         return 1;
     }
@@ -77,7 +78,7 @@ size_t number_of(Role what) {
 int mpi_id_from_role_id(Role role, int rid) {
     if (rid >= number_of(role) || rid < 0) {
         INFO_PRINTF("There aren't %d of %d (but %lu)\n",
-                    rid, role, number_of(role));
+                    rid+1, role, number_of(role));
         MPI_Abort(MPI_COMM_WORLD, 1);
     }
     int base = 0;
@@ -176,7 +177,7 @@ void recv_window(long* window, size_t winsize, int src) {
 void tokenizer(const char* source) {
     INFO_PRINTF("Starting tokenizer %d\n", getpid());
     int rid = my_role_id(TOKENIZER);
-    int next = mpi_id_from_role_id(FILTERER, rid);
+    int next = mpi_id_from_role_id(FILTER, rid);
     WordList wl = {0, 0, NULL};
     size_t sync_ctr = 0;
@@ -209,9 +210,9 @@ void tokenizer(const char* source) {
     INFO_PRINTF("Finishing tokenizer %d\n", getpid());
 }
 
-void filterer() {
-    INFO_PRINTF("Starting filterer %d\n", getpid());
-    int rid = my_role_id(FILTERER);
+void filter() {
+    INFO_PRINTF("Starting filter %d\n", getpid());
+    int rid = my_role_id(FILTER);
     int tokenizer = mpi_id_from_role_id(TOKENIZER, rid);
     int batcher = mpi_id_from_role_id(BATCHER, rid);
@@ -239,13 +240,13 @@ void filterer() {
     send_window(window, window_size, batcher);
     free_word(&w);
     free(window);
-    INFO_PRINTF("Finishing filterer %d\n", getpid());
+    INFO_PRINTF("Finishing filter %d\n", getpid());
 }
 
 void batcher() {
     INFO_PRINTF("Starting batcher %d\n", getpid());
     int rid = my_role_id(BATCHER);
-    int tokenizer = mpi_id_from_role_id(FILTERER, rid);
+    int tokenizer = mpi_id_from_role_id(FILTER, rid);
     int bs = getbs();
     int learner_mpi_id = 0;
@@ -326,15 +327,13 @@ void learner() {
     int dispatcher = mpi_id_from_role_id(DISPATCHER, 0);
     INFO_PRINTF("Learner %d (pid %d) is assigned to pipeline %d\n", rid,
                 getpid(), my_batcher_rid);
-    size_t bs = getbs();
-    size_t bpe = getbpe();
 
     PyObject* net = create_network();
     WeightList wl;
     init_weightlist_like(&wl, net);
     size_t window_size = 2 * getwin() + 1;
-    size_t bufsize = bs * window_size;
+    size_t bufsize = getbs() * window_size;
     float* batch = malloc(bufsize * sizeof(float));
     int go;
@@ -344,11 +343,11 @@ void learner() {
     while (go != -1) {
         recv_weights(&wl, dispatcher);
         set_net_weights(net, &wl);
-        for in_range(k, bpe) {
+        for in_range(k, getbpe()) {
             MPI_Send(&me, 1, MPI_INT, batcher, TAG_READY, MPI_COMM_WORLD);
             MPI_Recv(batch, bufsize, MPI_FLOAT, batcher, TAG_BATCH,
                      MPI_COMM_WORLD, MPI_STATUS_IGNORE);
-            step_net(net, batch, bs);
+            step_net(net, batch);
         }
         update_weightlist(&wl, net);
         send_weights(&wl, dispatcher);
@@ -365,9 +364,11 @@ void learner() {
 void dispatcher() {
     INFO_PRINTF("Starting dispatcher %d\n", getpid());
     int go = 1;
+    // int visualizer = mpi_id_from_role_id(VISUALIZER, 0);
     size_t bs = getbs();
     size_t bpe = getbpe();
     float target = gettarget();
+    // size_t emb_mat_size = getemb() * getvocsize();
 
     PyObject* frank = create_network();
     WeightList wl;
@@ -403,6 +404,9 @@ void dispatcher() {
         crt_loss = eval_net(frank);
         min_loss = crt_loss < min_loss ? crt_loss : min_loss;
         INFO_PRINTF("Round %ld, validation loss %f\n", rounds, crt_loss);
+        // MPI_Send(&go, 1, MPI_INT, visualizer, TAG_INSTR, MPI_COMM_WORLD);
+        // MPI_Send(wl.weights[0].W, emb_mat_size, MPI_FLOAT,
+        //          visualizer, TAG_EMBED, MPI_COMM_WORLD);
         ckpt_net(frank);
@@ -419,6 +423,8 @@ void dispatcher() {
         MPI_Send(&go, 1, MPI_INT, mpi_id_from_role_id(LEARNER, l),
                  TAG_INSTR, MPI_COMM_WORLD);
     }
+    // MPI_Send(&go, 1, MPI_INT, mpi_id_from_role_id(VISUALIZER, 0),
+    //          TAG_INSTR, MPI_COMM_WORLD);
 
     save_emb(frank);
@@ -439,16 +445,38 @@ void dispatcher() {
     free(wls);
     free(round);
     INFO_PRINTF("Finishing dispatcher %d\n", getpid());
+    // sleep(4);
+    // INFO_PRINTLN("Visualization server is still running on port 8448\n"
+    //              "To terminate, press Ctrl-C");
 }
 
 void visualizer() {
     INFO_PRINTF("Starting visualizer %d\n", getpid());
     serve();
+    int dispatcher = mpi_id_from_role_id(DISPATCHER, 0);
+    int go_on = 1;
+    size_t emb_mat_size = getvocsize() * getemb();
+    float* embeddings = malloc(emb_mat_size * sizeof(float));
+    MPI_Recv(&go_on, 1, MPI_INT, dispatcher, TAG_INSTR, MPI_COMM_WORLD,
+             MPI_STATUS_IGNORE);
+    while (go_on != -1) {
+        MPI_Recv(embeddings, emb_mat_size, MPI_FLOAT, dispatcher, TAG_EMBED,
+                 MPI_COMM_WORLD, MPI_STATUS_IGNORE);
+        server_update(embeddings);
+        MPI_Recv(&go_on, 1, MPI_INT, dispatcher, TAG_INSTR, MPI_COMM_WORLD,
+                 MPI_STATUS_IGNORE);
+    }
+    INFO_PRINTF("Exiting visualizer node %d\n", getpid());
 }
 
 int main (int argc, const char **argv) {
     MPI_Init(NULL, NULL);
+    // Some sanity checks on the input
     if (my_mpi_id() == 0) {
         if (argc < 2) {
             INFO_PRINTLN("NOT ENOUGH INPUTS!");
@@ -479,8 +507,8 @@ int main (int argc, const char **argv) {
         role_id = role_id_from_mpi_id(TOKENIZER, my_mpi_id());
         tokenizer(argv[role_id + 1]);
         break;
-    case FILTERER:
-        filterer();
+    case FILTER:
+        filter();
         break;
     case BATCHER:
         batcher();
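
The dispatcher-to-visualizer traffic is stubbed out above (number_of(VISUALIZER) now returns 0 and the MPI_Send calls are commented), but the intended handshake is visible: a TAG_INSTR int that doubles as a go/stop flag, then a TAG_EMBED float buffer carrying the embedding matrix, terminated by sending -1. A minimal mpi4py sketch of that loop shape, with an illustrative buffer size and a fixed 2-rank layout rather than the program's real role-to-rank mapping:

from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
TAG_INSTR, TAG_EMBED = 9, 11   # values copied from the #defines above
EMB_MAT_SIZE = 8               # hypothetical; really getvocsize() * getemb()

if comm.Get_rank() == 0:       # stands in for the dispatcher
    emb = np.zeros(EMB_MAT_SIZE, dtype=np.float32)
    for _ in range(3):         # a few training "rounds"
        comm.send(1, dest=1, tag=TAG_INSTR)                  # go on
        comm.Send([emb, MPI.FLOAT], dest=1, tag=TAG_EMBED)   # embeddings
    comm.send(-1, dest=1, tag=TAG_INSTR)                     # terminate
else:                          # stands in for the visualizer
    buf = np.empty(EMB_MAT_SIZE, dtype=np.float32)
    while comm.recv(source=0, tag=TAG_INSTR) != -1:
        comm.Recv([buf, MPI.FLOAT], source=0, tag=TAG_EMBED)
        # server_update(buf) would refresh the served embedding map here

Run with: mpiexec -n 2 python sketch.py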

View File

@@ -1,17 +1,26 @@
 from threading import Thread
-from sys import stderr
 
 import flask
 
 t = None
 app = flask.Flask(__name__)
-counter = 0
+emb_map = None
+
+import logging
+log = logging.getLogger('werkzeug')
+log.setLevel(logging.ERROR)
+app.logger.setLevel(logging.ERROR)
 
 @app.route('/')
 def main():
-    return f'Hello {counter}'
+    if emb_map is None:
+        return 'Hello World!'
+    else:
+        return '\n'.join(f'{w}: {vec}' for w, vec in emb_map.items())
 
 def serve():
@@ -19,7 +28,6 @@ def serve():
     if t is None:
         t = Thread(target=app.run, kwargs={'port': 8448})
         t.start()
-        print('So I kinda started', flush=True, file=stderr)
 
 if __name__ == '__main__':
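
Once serve() has started the background thread, the endpoint can be checked by hand; until the dispatcher actually sends embeddings (it currently doesn't), emb_map stays None and the fallback answer is returned. A hypothetical smoke test:

from urllib.request import urlopen

# Port 8448 comes from serve() above; prints 'Hello World!' until
# emb_map is populated by server_update().
print(urlopen('http://localhost:8448/').read().decode())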