tokenizer and batcher are sane and work

i wonder how much pain it's gonna cause me to change
something in the code
This commit is contained in:
2019-11-30 22:16:42 -08:00
parent 04c35ed9b6
commit 101248965c
6 changed files with 321 additions and 156 deletions

4
.gitignore vendored
View File

@@ -1,8 +1,8 @@
.*.sw? .*.sw?
DS_Store DS_Store
library.c
library.h
run run
compile_commands.json compile_commands.json
*.txt
build/ build/
cythoned/
__pycache__/ __pycache__/

View File

@@ -3,13 +3,14 @@ import numpy as np
from sys import stderr from sys import stderr
from libc.stdlib cimport malloc from libc.stdlib cimport malloc, realloc
from libc.string cimport memcpy from libc.string cimport memcpy
import nn import library as nn
X_train, y_train, X_test, y_test = nn.load_mnist() X_train, y_train, X_test, y_test = nn.load_mnist()
tokenizers = {}
cdef extern from "numpy/arrayobject.h": cdef extern from "numpy/arrayobject.h":
@@ -27,16 +28,62 @@ ctypedef public struct WeightList:
Weight* weights; Weight* weights;
ctypedef public struct Word:
size_t mem
char* data
ctypedef public struct WordList:
size_t mem
size_t n_words
Word* words
cdef public char *greeting(): cdef public char *greeting():
return f'The value is {3**3**3}'.encode('utf-8') return f'The value is {3**3**3}'.encode('utf-8')
cdef public int get_tokens(WordList* wl, const char *filename):
fnu = filename.decode('utf-8')
if fnu not in tokenizers:
tokenizers[fnu] = nn.token_generator(fnu)
g = tokenizers[fnu]
try:
words = next(g)
except StopIteration:
return 0
words_into_wordlist(wl, words)
return 1
cdef public long vocab_idx_of(Word* w):
word = w.data.decode('utf-8')
if word.lower() in nn.vocab:
return nn.vocab.index(word.lower())
else:
return -1
cdef public void c_onehot(float* y, float* idxs, size_t n_idx):
oh = nn.onehot(np.asarray(<float[:n_idx]>idxs))
ensure_contiguous(oh)
memcpy(y, PyArray_DATA(oh), oh.size * sizeof(float))
cdef public void c_slices(float* X, float* idxs, size_t bs, size_t win):
X_np = np.asarray(<float[:bs,:2*win]>X)
idxs_np = np.asarray(<float[:bs + 2*win]>idxs)
for r in range(bs):
X_np[r, :win] = idxs_np[r:r+win]
X_np[r, win+1:] = idxs_np[r+win+1:r+2*win+1]
cdef public void debug_print(object o): cdef public void debug_print(object o):
print(o) eprint(o)
cdef public object create_network(): cdef public object create_network(int win, int embed):
return nn.create_mnist_network() return nn.create_cbow_network(win, len(nn.vocab), embed)
cdef public void set_net_weights(object net, WeightList* wl): cdef public void set_net_weights(object net, WeightList* wl):
@@ -46,16 +93,20 @@ cdef public void set_net_weights(object net, WeightList* wl):
cdef public void step_net( cdef public void step_net(
object net, float* X, float* y, size_t batch_size object net, float* X, float* y, size_t batch_size
): ):
in_shape = (batch_size,) + net.layers[0].input_shape[1:] in_shape = (batch_size,) + net.input_shape[1:]
out_shape = (batch_size,) + net.layers[-1].output_shape[1:] out_shape = (batch_size,) + net.output_shape[1:]
X_train = np.asarray(<float[:np.prod(in_shape)]>X).reshape(in_shape) X_train = np.asarray(<float[:np.prod(in_shape)]>X).reshape(in_shape)
y_train = np.asarray(<float[:np.prod(out_shape)]>y).reshape(out_shape) y_train = np.asarray(<float[:np.prod(out_shape)]>y).reshape(out_shape),
net.train_on_batch(X_train, y_train) net.train_on_batch(X_train, y_train)
cdef public size_t out_size(object net):
return np.prod(net.output_shape[1:])
cdef public float eval_net(object net): cdef public float eval_net(object net):
return net.evaluate(X_test, y_test, verbose=False)[1] return net.evaluate(X_test, y_test, verbose=False)
cdef public void mnist_batch(float* X, float* y, size_t bs, cdef public void mnist_batch(float* X, float* y, size_t bs,
@@ -74,8 +125,8 @@ cdef public void mnist_batch(float* X, float* y, size_t bs,
assert X_r.flags['C_CONTIGUOUS'] assert X_r.flags['C_CONTIGUOUS']
assert y_r.flags['C_CONTIGUOUS'] assert y_r.flags['C_CONTIGUOUS']
memcpy(X, <float*>PyArray_DATA(X_r), X_r.size * sizeof(float)) memcpy(X, PyArray_DATA(X_r), X_r.size * sizeof(float))
memcpy(y, <float*>PyArray_DATA(y_r), y_r.size * sizeof(float)) memcpy(y, PyArray_DATA(y_r), y_r.size * sizeof(float))
cdef public void init_weightlist_like(WeightList* wl, object net): cdef public void init_weightlist_like(WeightList* wl, object net):
@@ -89,8 +140,7 @@ cdef public void init_weightlist_like(WeightList* wl, object net):
wl.weights[i].W = <float*>malloc(sizeof(float) * w.size) wl.weights[i].W = <float*>malloc(sizeof(float) * w.size)
assert sh.flags['C_CONTIGUOUS'] assert sh.flags['C_CONTIGUOUS']
memcpy(wl.weights[i].shape, <long*>PyArray_DATA(sh), memcpy(wl.weights[i].shape, PyArray_DATA(sh), sh.size * sizeof(long))
sh.size * sizeof(long))
cdef public void update_weightlist(WeightList* wl, object net): cdef public void update_weightlist(WeightList* wl, object net):
@@ -99,8 +149,7 @@ cdef public void update_weightlist(WeightList* wl, object net):
w = w.astype(np.float32) w = w.astype(np.float32)
assert w.flags['C_CONTIGUOUS'] assert w.flags['C_CONTIGUOUS']
memcpy(wl.weights[i].W, <float*>PyArray_DATA(w), memcpy(wl.weights[i].W, PyArray_DATA(w), w.size * sizeof(float))
w.size * sizeof(float))
cdef public void combo_weights( cdef public void combo_weights(
@@ -127,7 +176,36 @@ cdef list wrap_weight_list(WeightList* wl):
return weights return weights
cdef void words_into_wordlist(WordList* wl, list words):
if wl.mem < len(words):
old = wl.mem
wl.mem = len(words)
wl.words = <Word*>realloc(wl.words, wl.mem * sizeof(Word))
for i in range(old, wl.mem):
wl.words[i].mem = 0
wl.words[i].data = <char*>0
wl.n_words = len(words)
for i, w in enumerate(words):
wenc = w.encode('utf-8')
if wl.words[i].mem < len(wenc) + 1:
wl.words[i].mem = len(wenc) + 1
wl.words[i].data = <char*>realloc(
wl.words[i].data, wl.words[i].mem * sizeof(char)
)
memcpy(wl.words[i].data, <char*>wenc, len(wenc) * sizeof(char))
wl.words[i].data[len(wenc)] = 0
def inspect_array(a): def inspect_array(a):
print(a.flags, flush=True) print(a.flags, flush=True)
print(a.dtype, flush=True) print(a.dtype, flush=True)
print(a.sum(), flush=True) print(a.sum(), flush=True)
def ensure_contiguous(a):
assert a.flats['C_CONTIGUOUS']
def eprint(*args, **kwargs):
return print(*args, flush=True, **kwargs)

54
library.py Normal file
View File

@@ -0,0 +1,54 @@
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) # STFU!
# from nltk.corpus import stopwords
# from nltk.tokenize import word_tokenize
from mynet import load_mnist, onehot
def word_tokenize(s: str):
l = ''.join(c if c.isalpha() else ' ' for c in s)
return l.split()
HERE = os.path.abspath(os.path.dirname(__file__))
CORPUS = os.path.join(HERE, 'melville-moby_dick.txt')
# sw = set(stopwords.words('english'))
sw = ['the']
vocab = list(set(
w.lower() for w in word_tokenize(open(CORPUS).read())
if w.isalpha() and not w.lower() in sw
))
def create_mnist_network():
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(30, input_shape=(784,), activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(loss='categorical_crossentropy', optimizer='sgd',
metrics=['accuracy'])
return model
def create_cbow_network(win, vocab, embed):
ctxt = tf.keras.layers.Input(shape=[win])
ed = tf.keras.layers.Embedding(vocab, embed, input_length=win)(ctxt)
avgd = tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=1))(ed)
mod = tf.keras.Model(inputs=ctxt, outputs=avgd)
mod.compile(
optimizer='sgd',
loss='categorical_crossentropy',
)
return mod
def token_generator(filename):
with open(filename) as f:
for l in f.readlines(500):
if not l.isspace():
tok = word_tokenize(l)
if tok:
yield tok

293
main.c
View File

@@ -1,7 +1,8 @@
#include "cythoned/library.h" #include "cythoned/bridge.h"
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h>
#include <mpi.h> #include <mpi.h>
#define TAG_IDGAF 0 #define TAG_IDGAF 0
@@ -10,10 +11,15 @@
#define TAG_WEIGH 3 #define TAG_WEIGH 3
#define TAG_READY 4 #define TAG_READY 4
#define TAG_BREAK 5 #define TAG_BREAK 5
#define TAG_STLEN 6
#define TAG_SWORD 7
#define TAG_IWORD 8
#define COMM 100 #define COMM 100
#define ITER 20 #define ITER 20
#define BS 50 #define BS 50
#define EMB 20
#define WIN 2
#define FSPC 1 #define FSPC 1
#define in_range(i, x) (size_t (i) = 0; (i) < (x); (i)++) #define in_range(i, x) (size_t (i) = 0; (i) < (x); (i)++)
@@ -24,109 +30,75 @@
#define INFO_PRINTLN(what) \ #define INFO_PRINTLN(what) \
do { fprintf(stderr, "%s\n", what); } while(0) do { fprintf(stderr, "%s\n", what); } while(0)
// char_stream -> tokenize -> word_strem -> filter + batch -> slave network
typedef enum{ typedef enum{
DATA, TOKENIZER,
FILTERER,
BATCHER,
SLAVE, SLAVE,
MASTER MASTER
} Role; } Role;
typedef struct IntQueue IntQueue; int world_size() {
struct IntQueue {
int head;
int tail;
size_t size;
int* data;
};
void queue_from_size(IntQueue* q, size_t s) {
q->data = malloc(s * sizeof(int));
q->size = s+1;
q->head = 0;
q->tail = 0;
}
void push_queue(IntQueue *q, int d) {
// Assuming queue is not full
q->data[q->tail] = d;
q->tail = (q->tail + 1) % q->size;
}
int pop_queue(IntQueue *q) {
int d = q->data[q->head];
q->head = (q->head + 1) % q->size;
return d;
}
int queue_empty(IntQueue *q) {
return q->head == q->tail;
}
int queue_full(IntQueue *q) {
return ((q->tail + 1) % q->size) == q->head;
}
int number_of_nodes() {
int n; int n;
MPI_Comm_size(MPI_COMM_WORLD, &n); MPI_Comm_size(MPI_COMM_WORLD, &n);
return n; return n;
} }
int number_of_masters() { int my_mpi_id() {
return 1;
}
int number_of_readers() {
return 1;
}
int number_of_slaves() {
return number_of_nodes() - number_of_masters() - number_of_readers();
}
int my_id() {
int i; int i;
MPI_Comm_rank(MPI_COMM_WORLD, &i); MPI_Comm_rank(MPI_COMM_WORLD, &i);
return i; return i;
} }
int master_id(int m) { size_t number_of(Role what) {
return m; switch (what) {
case TOKENIZER:
return 1;
case FILTERER:
return 1;
case BATCHER:
return 1;
case SLAVE:
return world_size()
- number_of(TOKENIZER)
- number_of(FILTERER)
- number_of(BATCHER)
- number_of(MASTER);
case MASTER:
return 1;
}
} }
int reader_id(int r) { int mpi_id_from_role_id(Role role, int rid) {
return r + number_of_masters(); int base = 0;
for (Role r = TOKENIZER; r < role; r++) {
base += number_of(r);
}
return rid + base;
} }
int slave_id(int s) { int role_id_from_mpi_id(Role role, int mid) {
return s + number_of_masters() + number_of_readers(); int z = mpi_id_from_role_id(role, 0);
int rid = mid - z;
if (rid >= number_of(role) || rid < 0) {
INFO_PRINTF("%d is not a %d\n", mid, role);
exit(1);
}
return rid;
} }
Role map_node() { Role map_node() {
int node; int node = my_mpi_id();
MPI_Comm_rank(MPI_COMM_WORLD, &node); size_t base = 0;
if (node >= reader_id(0) && node <= reader_id(number_of_readers()-1)) { for (Role r = TOKENIZER; r <= MASTER; r++) {
return DATA; if (node < number_of(r) + base) return r;
} base += number_of(r);
if (node >= master_id(0) && node <= master_id(number_of_masters()-1)) {
return MASTER;
}
if (node >= slave_id(0) && node <= slave_id(number_of_slaves()-1)) {
return SLAVE;
} }
exit(1); // this is bad exit(1); // this is bad
} }
int rid(int id, Role what) {
int z;
switch (what) {
case DATA: z = reader_id(0); break;
case SLAVE: z = slave_id(0); break;
case MASTER: z = master_id(0); break;
}
return id - z;
}
void free_weightlist(WeightList* wl) { void free_weightlist(WeightList* wl) {
for in_range(i, wl->n_weights) { for in_range(i, wl->n_weights) {
free(wl->weights[i].shape); free(wl->weights[i].shape);
@@ -135,27 +107,92 @@ void free_weightlist(WeightList* wl) {
free(wl->weights); free(wl->weights);
} }
void data_reader() { void free_word(Word* w) {
// Reads some data and converts it to a float array free(w->data);
INFO_PRINTF("Starting reader %d\n", getpid()); w->data = NULL;
w->mem = 0;
}
size_t X_numel = 784 * BS; void free_wordlist(WordList* wl) {
size_t y_numel = 10 * BS; for in_range(i, wl->mem) {
float* X = malloc(X_numel * sizeof(float)); free_word(wl->words + i);
float* y = malloc(y_numel * sizeof(float)); }
free(wl->words);
wl->words = NULL;
wl->n_words = 0;
}
void send_word(Word* w, int dest) {
long len = strlen(w->data);
MPI_Send(&len, 1, MPI_LONG, dest, TAG_STLEN, MPI_COMM_WORLD);
MPI_Send(w->data, len + 1, MPI_CHAR, dest, TAG_SWORD, MPI_COMM_WORLD);
}
void recv_word(Word* w, int src) {
long len;
MPI_Recv(&len, 1, MPI_LONG, src, TAG_STLEN, MPI_COMM_WORLD,
MPI_STATUS_IGNORE);
if (w->mem < len + 1) {
w->mem = len + 1;
w->data = realloc(w->data, sizeof(char) * w->mem);
}
MPI_Recv(w->data, len + 1, MPI_CHAR, src, TAG_SWORD, MPI_COMM_WORLD,
MPI_STATUS_IGNORE);
}
void tokenizer(const char* source) {
WordList wl = {0, 0, NULL};
while (get_tokens(&wl, source)) {
for in_range(i, wl.n_words) {
send_word(&wl.words[i], mpi_id_from_role_id(FILTERER, 0));
// printf("OI %s\n", wl.words[i].data);
}
// INFO_PRINTLN("");
}
Word terminator = {0, ""};
send_word(&terminator, mpi_id_from_role_id(FILTERER, 0));
free_wordlist(&wl);
}
void filterer() {
Word w = {0, NULL};
while (1) {
recv_word(&w, role_id_from_mpi_id(TOKENIZER, 0));
if (!strlen(w.data)) {
break;
}
INFO_PRINTF("%s: ", w.data);
long idx = vocab_idx_of(&w);
INFO_PRINTF("%ld\n", idx);
// if (idx != -1) {
// MPI_Send(&idx, 1, MPI_LONG, mpi_id_from_role_id(BATCHER, 0),
// TAG_IWORD, MPI_COMM_WORLD);
// }
}
free_word(&w);
}
void batcher() {
// Reads some data and converts it to a float array
INFO_PRINTF("Starting batcher %d\n", getpid());
int s = 0; int s = 0;
const size_t n_words = BS + WIN + WIN;
float* f_widx = malloc(n_words * sizeof(float));
while (s != -1) { while (s != -1) {
for in_range(i, n_words) {
long l_wid;
MPI_Recv(&l_wid, 1, MPI_LONG, role_id_from_mpi_id(FILTERER, 0),
TAG_IWORD, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
f_widx[i] = (float)l_wid;
}
MPI_Recv(&s, 1, MPI_INT, MPI_ANY_SOURCE, TAG_READY, MPI_COMM_WORLD, MPI_Recv(&s, 1, MPI_INT, MPI_ANY_SOURCE, TAG_READY, MPI_COMM_WORLD,
MPI_STATUS_IGNORE); MPI_STATUS_IGNORE);
if (s != -1) { if (s != -1) {
mnist_batch(X, y, BS, rid(s, SLAVE), number_of_slaves()); MPI_Send(f_widx, n_words, MPI_FLOAT, s, TAG_BATCH, MPI_COMM_WORLD);
MPI_Send(X, X_numel, MPI_FLOAT, s, TAG_BATCH, MPI_COMM_WORLD);
MPI_Send(y, y_numel, MPI_FLOAT, s, TAG_BATCH, MPI_COMM_WORLD);
} }
} }
free(X); free(f_widx);
free(y);
} }
void send_weights(const WeightList* wl, int dest, int tag) { void send_weights(const WeightList* wl, int dest, int tag) {
@@ -191,34 +228,39 @@ void slave_node() {
// 3. Do computations // 3. Do computations
// 4. Send weights back to master // 4. Send weights back to master
INFO_PRINTF("Starting slave %d\n", getpid()); INFO_PRINTF("Starting slave %d\n", getpid());
int me = my_mpi_id();
int me; PyObject* net = create_network(WIN, EMB);
MPI_Comm_rank(MPI_COMM_WORLD, &me);
size_t X_numel = 784 * BS;
size_t y_numel = 10 * BS;
float* X = malloc(X_numel * sizeof(float));
float* y = malloc(y_numel * sizeof(float));
PyObject* net = create_network();
WeightList wl; WeightList wl;
init_weightlist_like(&wl, net); init_weightlist_like(&wl, net);
size_t vocab = out_size(net);
size_t n_words = (BS + WIN + WIN);
size_t X_numel = BS * (WIN + WIN);
size_t y_numel = BS * vocab;
float* X = malloc(X_numel * sizeof(float));
float* y = malloc(y_numel * sizeof(float));
float* f_widx = malloc(n_words * sizeof(float));
for in_range(i, COMM) { for in_range(i, COMM) {
MPI_Send(&me, 1, MPI_INT, master_id(0), TAG_READY, MPI_COMM_WORLD); MPI_Send(&me, 1, MPI_INT, mpi_id_from_role_id(MASTER, 0),
recv_weights(&wl, master_id(0), TAG_WEIGH); TAG_READY, MPI_COMM_WORLD);
recv_weights(&wl, mpi_id_from_role_id(MASTER, 0), TAG_WEIGH);
set_net_weights(net, &wl); set_net_weights(net, &wl);
for in_range(k, ITER) { for in_range(k, ITER) {
MPI_Send(&me, 1, MPI_INT, reader_id(0), TAG_READY, MPI_COMM_WORLD); MPI_Send(&me, 1, MPI_INT, mpi_id_from_role_id(BATCHER, 0),
MPI_Recv(X, X_numel, MPI_FLOAT, reader_id(0), TAG_BATCH, TAG_READY, MPI_COMM_WORLD);
MPI_COMM_WORLD, MPI_STATUS_IGNORE); MPI_Recv(f_widx, n_words, MPI_FLOAT,
MPI_Recv(y, y_numel, MPI_FLOAT, reader_id(0), TAG_BATCH, mpi_id_from_role_id(BATCHER, 0), TAG_BATCH, MPI_COMM_WORLD,
MPI_COMM_WORLD, MPI_STATUS_IGNORE); MPI_STATUS_IGNORE);
c_slices(X, f_widx, BS, WIN);
c_onehot(y, f_widx + WIN, BS);
step_net(net, X, y, BS); step_net(net, X, y, BS);
} }
printf("%d net: %f\n", my_id(), eval_net(net)); printf("%d net: %f\n", my_mpi_id(), eval_net(net));
update_weightlist(&wl, net); update_weightlist(&wl, net);
send_weights(&wl, master_id(0), TAG_WEIGH); send_weights(&wl, mpi_id_from_role_id(MASTER, 0), TAG_WEIGH);
} }
Py_DECREF(net); Py_DECREF(net);
free_weightlist(&wl); free_weightlist(&wl);
@@ -232,12 +274,12 @@ void master_node() {
// 3. Average the weights // 3. Average the weights
PyObject* frank = create_network(); PyObject* frank = create_network(WIN, EMB);
WeightList wl; WeightList wl;
init_weightlist_like(&wl, frank); init_weightlist_like(&wl, frank);
update_weightlist(&wl, frank); update_weightlist(&wl, frank);
int spr = number_of_slaves() * FSPC; // Slaves per round int spr = number_of(SLAVE) * FSPC; // Slaves per round
int s; int s;
WeightList *wls = malloc(sizeof(WeightList) * spr); WeightList *wls = malloc(sizeof(WeightList) * spr);
@@ -265,33 +307,40 @@ void master_node() {
free_weightlist(&wl); free_weightlist(&wl);
for in_range(i, spr) free_weightlist(wls + i); for in_range(i, spr) free_weightlist(wls + i);
free(wls); free(wls);
if (rid(my_id(), MASTER) == 0) { // if (role_id_from_mpi_id(my_mpi_id(), MASTER) == 0) {
for in_range(r, number_of_readers()) { // for in_range(r, number_of(BATCHER)) {
int stop = -1; // int stop = -1;
MPI_Send(&stop, 1, MPI_INT, reader_id(r), TAG_READY, // MPI_Send(&stop, 1, MPI_INT, reader_id(r), TAG_READY,
MPI_COMM_WORLD); // MPI_COMM_WORLD);
} // }
} // }
} }
int main (int argc, const char **argv) { int main (int argc, const char **argv) {
MPI_Init(NULL, NULL); MPI_Init(NULL, NULL);
// Cython Boilerplate // Cython Boilerplate
PyImport_AppendInittab("library", PyInit_library); PyImport_AppendInittab("bridge", PyInit_bridge);
Py_Initialize(); Py_Initialize();
PyRun_SimpleString("import sys\nsys.path.insert(0,'')"); PyRun_SimpleString("import sys\nsys.path.insert(0,'')");
PyObject* library_module = PyImport_ImportModule("library"); PyObject* bridge_module = PyImport_ImportModule("bridge");
// Actual Code // Actual Code
switch (map_node()) { switch (map_node()) {
case DATA: data_reader(); break; case TOKENIZER:
case SLAVE: slave_node(); break; tokenizer(argv[1]);
case MASTER: master_node(); break; break;
case FILTERER:
filterer();
break;
default:
INFO_PRINTLN("DYING HORRIBLY!");
// case SLAVE: slave_node(); break;
// case MASTER: master_node(); break;
} }
// Finalizing Boilerplate // Finalizing Boilerplate
Py_DECREF(library_module); Py_DECREF(bridge_module);
Py_Finalize(); Py_Finalize();
MPI_Finalize(); MPI_Finalize();
} }

View File

@@ -12,7 +12,7 @@ numpy_header = include_directories(run_command(
).stdout().strip()) ).stdout().strip())
executable( executable(
'fedavg_mpi', 'main.c', 'cythoned/library.c', 'fedavg_mpi', 'main.c', 'cythoned/bridge.c',
dependencies: [mpi, python], dependencies: [mpi, python],
include_directories: numpy_header, include_directories: numpy_header,
link_args: '-Wl,-w' link_args: '-Wl,-w'

16
nn.py
View File

@@ -1,16 +0,0 @@
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) # STFU!
from mynet import load_mnist
def create_mnist_network():
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(30, input_shape=(784,), activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(loss='categorical_crossentropy', optimizer='sgd',
metrics=['accuracy'])
return model