Committing halfway because I'm feeling stuck
This commit is contained in:
64
main.c
64
main.c
@@ -11,11 +11,13 @@
|
||||
#define TAG_IDGAF 0
|
||||
#define TAG_BATCH 1
|
||||
#define TAG_NETWK 2
|
||||
#define TAG_WEIGH 2
|
||||
#define TAG_WEIGH 3
|
||||
#define TAG_READY 4
|
||||
|
||||
#define COMM 500
|
||||
#define ITER 40
|
||||
#define BS 50
|
||||
#define FSPC 0.2
|
||||
|
||||
/*
 * Offset a slave index by P_SLAVE; the result is used as an MPI destination.
 * The argument and the whole expansion are parenthesised so the macro keeps
 * its meaning inside larger expressions (e.g. 2 * sid(i)) and with compound
 * arguments (e.g. sid(i << 1)).
 */
#define sid(s) ((s) + P_SLAVE)
|
||||
|
||||
@@ -29,20 +31,54 @@ typedef enum{
|
||||
MASTER
|
||||
} Role;
|
||||
|
||||
|
||||
typedef struct IntQueue IntQueue;

/*
 * Fixed-capacity FIFO of ints backed by a circular buffer.
 *
 * head == tail means "empty", so one slot always stays unused: a buffer of
 * `size` slots holds at most `size - 1` elements.
 */
struct IntQueue {
    int head;     /* index of the next element to pop */
    int tail;     /* index of the next slot to push into */
    size_t size;  /* number of slots in `data` (capacity + 1) */
    int* data;    /* backing storage of `size` ints, owned by the queue */
};

/*
 * Initialise `q` as an empty queue able to hold up to `s` elements.
 *
 * The ring needs s + 1 slots (see IntQueue), and indices wrap modulo
 * q->size, so the allocation must cover q->size ints. Allocating only `s`
 * ints would let index `s` run one past the end of the buffer.
 *
 * If the allocation fails, q->data is NULL; callers should check it before
 * pushing. The caller releases the storage with free(q->data).
 */
void queue_from_size(IntQueue* q, size_t s) {
    q->size = s + 1;
    q->data = malloc(q->size * sizeof *q->data);
    q->head = 0;
    q->tail = 0;
}
|
||||
|
||||
/*
 * Store `d` in the slot at the current tail, then advance the tail by one,
 * wrapping around at q->size.
 *
 * No fullness check is performed: the caller must make sure the queue is
 * not full before pushing.
 */
void push_queue(IntQueue *q, int d) {
    const int slot = q->tail;

    q->data[slot] = d;
    q->tail = (slot + 1) % q->size;
}
|
||||
|
||||
/*
 * Remove and return the element at the head of the queue, advancing the
 * head by one and wrapping around at q->size.
 *
 * No emptiness check is performed here.
 */
int pop_queue(IntQueue *q) {
    const int front = q->head;

    q->head = (front + 1) % q->size;
    return q->data[front];
}
|
||||
|
||||
/*
 * Return 1 when the queue holds no elements (head and tail coincide),
 * 0 otherwise.
 */
int queue_empty(IntQueue *q) {
    if (q->tail != q->head) {
        return 0;
    }
    return 1;
}
|
||||
|
||||
/*
 * Return 1 when advancing the tail by one slot (with wrap-around at
 * q->size) would land on the head, i.e. no further push is possible;
 * 0 otherwise.
 */
int queue_full(IntQueue *q) {
    const size_t after_tail = (q->tail + 1) % q->size;

    return after_tail == (size_t)q->head;
}
|
||||
|
||||
// Batch-serving loop for the reader process: for every TAG_READY message it
// receives, it fills `batch` via mnist_batch and sends it out under TAG_BATCH.
// NOTE(review): this text is a rendered diff hunk, so pre- and post-commit
// lines may both appear (there are two MPI_Send calls with different
// destinations below) — confirm against the real main.c which ones are live.
void data_reader() {
|
||||
// Reads some data and converts it to a float array
|
||||
printf("Start reader\n");
|
||||
// BS samples of 784 + 10 floats each — presumably MNIST pixels plus a
// one-hot label; TODO confirm against mnist_batch
size_t batch_numel = (784 + 10) * BS;
|
||||
float* batch = malloc(batch_numel * sizeof(float));
|
||||
int s = 0;
|
||||
int num_slaves;
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &num_slaves);
|
||||
// World size minus P_SLAVE — presumably the number of slave ranks; verify
// against P_SLAVE's definition
num_slaves -= P_SLAVE;
|
||||
|
||||
while (1) {
|
||||
// Block until any rank sends a TAG_READY message; its int payload lands in s
MPI_Recv(&s, 1, MPI_INT, MPI_ANY_SOURCE, TAG_READY, MPI_COMM_WORLD,
|
||||
MPI_STATUS_IGNORE);
|
||||
mnist_batch(batch, BS);
|
||||
// Destination is sid(s), i.e. s offset by P_SLAVE.
// NOTE(review): slave_node sends its own rank as the TAG_READY payload, so
// this offset may overshoot the requester — confirm what s is meant to hold.
MPI_Send(batch, batch_numel, MPI_FLOAT, sid(s),
|
||||
TAG_BATCH, MPI_COMM_WORLD);
|
||||
// Round-robin advance of s, then a send to rank s directly (no sid offset)
s = (s + 1) % num_slaves;
|
||||
MPI_Send(batch, batch_numel, MPI_FLOAT, s, TAG_BATCH, MPI_COMM_WORLD);
|
||||
}
|
||||
// Not reached: the while (1) loop above contains no break or return
free(batch);
|
||||
}
|
||||
@@ -73,8 +109,6 @@ void recv_weights(const Network* c_net, int src, int tag) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void send_network(const Network* c_net, int dest, int tag) {
|
||||
// Send a network to the expecting destination
|
||||
// It's best to receive with `recv_network`
|
||||
@@ -130,14 +164,19 @@ void free_network_contents(Network* c_net) {
|
||||
void slave_node() {
|
||||
printf("Start slave\n");
|
||||
|
||||
int me;
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &me);
|
||||
|
||||
size_t batch_numel = (784 + 10) * BS;
|
||||
float* batch = malloc(batch_numel * sizeof(float));
|
||||
Network net;
|
||||
create_c_network(&net);
|
||||
|
||||
for i_in_range(COMM) {
|
||||
MPI_Send(&me, 1, MPI_INT, P_MASTER, TAG_READY, MPI_COMM_WORLD);
|
||||
recv_weights(&net, P_MASTER, TAG_NETWK);
|
||||
for (int k = 0; k < ITER; k++) {
|
||||
MPI_Send(&me, 1, MPI_INT, P_READER, TAG_READY, MPI_COMM_WORLD);
|
||||
MPI_Recv(batch, batch_numel, MPI_FLOAT, P_READER, TAG_BATCH,
|
||||
MPI_COMM_WORLD, MPI_STATUS_IGNORE);
|
||||
step_net(&net, batch, BS);
|
||||
@@ -153,9 +192,11 @@ void master_node() {
|
||||
// Stores most up-to-date model, sends it to slaves for training
|
||||
// First do it synchronously
|
||||
// Need a "slave registry"
|
||||
printf("Start master\n");
|
||||
|
||||
int world_size;
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
|
||||
printf("Start master\n");
|
||||
|
||||
Network frank;
|
||||
create_c_network(&frank);
|
||||
|
||||
@@ -165,6 +206,9 @@ void master_node() {
|
||||
Network* nets = malloc(sizeof(Network) * world_size);
|
||||
for s_in_slaves(world_size) create_c_network(nets + s);
|
||||
|
||||
IntQueue slave_queue;
|
||||
queue_from_size(&slave_queue, world_size - P_SLAVE);
|
||||
|
||||
for i_in_range(COMM) {
|
||||
for s_in_slaves(world_size) {
|
||||
send_weights(&frank, sid(s), TAG_WEIGH);
|
||||
|
||||
Reference in New Issue
Block a user