diff --git a/main.c b/main.c
index 9cf7f3c..9b56ad3 100644
--- a/main.c
+++ b/main.c
@@ -11,11 +11,13 @@
 #define TAG_IDGAF 0
 #define TAG_BATCH 1
 #define TAG_NETWK 2
-#define TAG_WEIGH 2
+#define TAG_WEIGH 3
+#define TAG_READY 4
 
 #define COMM 500
 #define ITER 40
 #define BS 50
+#define FSPC 0.2
 
 #define sid(s) s + P_SLAVE
 
@@ -29,20 +31,62 @@ typedef enum{
     MASTER
 } Role;
 
+
+// Fixed-capacity FIFO of ints backed by a ring buffer. One slot is always
+// left unused so that head == tail means empty and (tail + 1) % size == head
+// means full.
+typedef struct IntQueue IntQueue;
+struct IntQueue {
+    int head;     // index of the next element to pop
+    int tail;     // index of the next slot to write
+    size_t size;  // number of slots in data (capacity + 1)
+    int* data;
+};
+
+// Set up q to hold up to s ints; q->data is malloc'd and not checked for NULL.
+void queue_from_size(IntQueue* q, size_t s) {
+    q->size = s + 1;
+    // Allocate every slot the indices can reach (0..s), not just s of them.
+    q->data = malloc(q->size * sizeof(int));
+    q->head = 0;
+    q->tail = 0;
+}
+
+// Append d at the tail. Assumes the queue is not full (see queue_full).
+void push_queue(IntQueue *q, int d) {
+    q->data[q->tail] = d;
+    q->tail = (q->tail + 1) % q->size;
+}
+
+// Remove and return the oldest element. Assumes the queue is not empty.
+int pop_queue(IntQueue *q) {
+    int d = q->data[q->head];
+    q->head = (q->head + 1) % q->size;
+    return d;
+}
+
+// Non-zero when the queue holds no elements.
+int queue_empty(IntQueue *q) {
+    return q->head == q->tail;
+}
+
+// Non-zero when one more push would make tail catch up with head.
+int queue_full(IntQueue *q) {
+    return ((q->tail + 1) % q->size) == q->head;
+}
+
 void data_reader() {
     // Reads some data and converts it to a float array
     printf("Start reader\n");
     size_t batch_numel = (784 + 10) * BS;
     float* batch = malloc(batch_numel * sizeof(float));
     int s = 0;
-    int num_slaves;
-    MPI_Comm_size(MPI_COMM_WORLD, &num_slaves);
-    num_slaves -= P_SLAVE;
+
     while (1) {
+        MPI_Recv(&s, 1, MPI_INT, MPI_ANY_SOURCE, TAG_READY, MPI_COMM_WORLD,
+                MPI_STATUS_IGNORE);
         mnist_batch(batch, BS);
-        MPI_Send(batch, batch_numel, MPI_FLOAT, sid(s),
-                TAG_BATCH, MPI_COMM_WORLD);
-        s = (s + 1) % num_slaves;
+        MPI_Send(batch, batch_numel, MPI_FLOAT, s, TAG_BATCH, MPI_COMM_WORLD);
     }
     free(batch);
 }
@@ -73,8 +117,6 @@ void recv_weights(const Network* c_net, int src, int tag) {
     }
 }
 
-
-
 void send_network(const Network* c_net, int dest, int tag) {
     // Send a network to the expecting destination
     // It's best to receive with `recv_network`
@@ -130,14 +172,19 @@ void free_network_contents(Network* c_net) {
 
 void slave_node() {
     printf("Start slave\n");
+    int me;
+    MPI_Comm_rank(MPI_COMM_WORLD, &me);
+
     size_t batch_numel = (784 + 10) * BS;
     float* batch = malloc(batch_numel * sizeof(float));
     Network net;
     create_c_network(&net);
 
     for i_in_range(COMM) {
+        MPI_Send(&me, 1, MPI_INT, P_MASTER, TAG_READY, MPI_COMM_WORLD);
         recv_weights(&net, P_MASTER, TAG_NETWK);
         for (int k = 0; k < ITER; k++) {
+            MPI_Send(&me, 1, MPI_INT, P_READER, TAG_READY, MPI_COMM_WORLD);
             MPI_Recv(batch, batch_numel, MPI_FLOAT, P_READER, TAG_BATCH,
                     MPI_COMM_WORLD, MPI_STATUS_IGNORE);
             step_net(&net, batch, BS);
@@ -153,9 +200,11 @@ void master_node() {
     // Stores most up-to-date model, sends it to slaves for training
     // First do it synchronously
     // Need a "slave registry"
+    printf("Start master\n");
+
     int world_size;
     MPI_Comm_size(MPI_COMM_WORLD, &world_size);
-    printf("Start master\n");
+
     Network frank;
     create_c_network(&frank);
 
@@ -165,6 +214,9 @@ void master_node() {
     Network* nets = malloc(sizeof(Network) * world_size);
     for s_in_slaves(world_size) create_c_network(nets + s);
 
+    IntQueue slave_queue;
+    queue_from_size(&slave_queue, world_size - P_SLAVE);
+
     for i_in_range(COMM) {
         for s_in_slaves(world_size) {
             send_weights(&frank, sid(s), TAG_WEIGH);