def test_send_recv(self):
    rank = dist.get_rank()
    tensor = _build_tensor(rank + 1)
    for dest in range(0, dist.get_world_size()):
        if dest == rank:
            continue
        dist.send(tensor, dest)

    for src in range(0, dist.get_world_size()):
        if src == rank:
            continue
        tensor = _build_tensor(src + 1, value=-1)
        expected_tensor = _build_tensor(src + 1)
        dist.recv(tensor, src)
        self.assertEqual(tensor, expected_tensor)

    self._barrier()
def test_send_recv_any_source(self):
    rank = dist.get_rank()
    tensor = _build_tensor(10, rank)
    for dest in range(0, dist.get_world_size()):
        if dest == rank:
            continue
        dist.send(tensor, dest)

    recv_ranks = set()
    for src in range(0, dist.get_world_size()):
        if src == rank:
            continue
        tensor = _build_tensor(10, value=-1)
        sender = dist.recv(tensor)
        self.assertTrue(tensor.eq(sender).all())
        recv_ranks.add(sender)

    self.assertEqual(len(recv_ranks), dist.get_world_size() - 1)
    self._barrier()
def test_irecv(self):
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    if rank == 0:
        expected_tensors = [_build_tensor(src, -1) for src in range(1, world_size)]
        requests = [
            dist.irecv(expected_tensors[src - 1], src)
            for src in range(1, world_size)
        ]

        for src in range(1, world_size):
            requests[src - 1].wait()
            self.assertTrue(requests[src - 1].is_completed())
            self.assertEqual(expected_tensors[src - 1], _build_tensor(src, 10))
    else:
        tensor = _build_tensor(rank, 10)
        dist.send(tensor, 0)

    self._barrier()
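# The three tests above rely on a `_build_tensor` helper that is not defined in
# this section. A minimal sketch consistent with its usage here (a size-cubed
# tensor filled with `value`, defaulting to the size itself) might look like
# the following -- an assumption, not necessarily the original helper:
def _build_tensor(size, value=None):
    if value is None:
        value = size
    return torch.FloatTensor(size, size, size).fill_(value)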
def run(rank, model, train_data, test_data, queue, param_q, stop_flag):
    # Get the initial model from the parameter server
    while True:
        if not param_q.empty():
            param_dict = param_q.get()
            tmp = OrderedDict(
                map(lambda item: (item[0], torch.from_numpy(item[1])),
                    param_dict.items()))
            model.load_state_dict(tmp)
            break
    print('Model received successfully!')

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        optimizer = MySGD(model.parameters(), lr=0.1)
    else:
        optimizer = MySGD(model.parameters(), lr=0.01)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    else:
        decay_period = 100

    print('Begin!')

    time_logs = open("./record" + str(rank), 'w')
    for epoch in range(int(args.epochs)):
        model.train()

        # Decay the learning rate at the specified epochs
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))

        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            it_start = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()
            it_comp_end = time.time()

            # noinspection PyBroadException
            try:
                if delta_ws:
                    # Report the local loss, push the weight deltas to the
                    # server, then pull the updated model back
                    queue.put({
                        rank: [loss.data.numpy(), np.array(args.train_bsz), False]
                    })
                    for delta in delta_ws:
                        dist.send(tensor=delta, dst=0)
                    for idx, param in enumerate(model.parameters()):
                        tmp_tensor = torch.zeros_like(param.data)
                        dist.recv(tensor=tmp_tensor, src=0)
                        param.data = tmp_tensor
                    # print('Rank {}, Epoch {}, Batch {}/{}, Loss:{}'
                    #       .format(rank, epoch, batch_idx, len(train_data), loss.data[0]))
            except Exception as e:
                # The exception usually means the PS process has stopped
                print(str(e))
                print('Should Stop: {}!'.format(stop_flag.value))
                break

            it_comm_end = time.time()
            it_duration = it_comm_end - it_start
            it_comp_duration = it_comp_end - it_start
            it_comm_duration = it_comm_end - it_comp_end
            time_logs.write(str(it_duration) + "\t" +
                            str(it_comp_duration) + "\t" +
                            str(it_comm_duration) + "\n")
            time_logs.flush()

        # Test the model after each epoch
        print("test Model:", epoch)
        # test_model(rank, model, test_data, criterion=criterion)

        if stop_flag.value:
            break

    queue.put({rank: [[], [], True]})
    time_logs.close()
    print("Worker {} has completed epoch {}!".format(args.this_rank, epoch))
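# `MySGD` is a custom optimizer imported from elsewhere in the repo. The worker
# above only needs plain SGD plus a `get_delta_w()` that returns the would-be
# update (lr * grad) per parameter without applying it, since the model weights
# are overwritten by the server's reply. A minimal sketch under that
# assumption -- not the repo's actual implementation:
class MySGD(torch.optim.SGD):
    def get_delta_w(self):
        # Collect lr * grad for every parameter, in the same order
        # as model.parameters(), leaving the parameters untouched
        delta_ws = []
        for group in self.param_groups:
            for param in group['params']:
                if param.grad is not None:
                    delta_ws.append(group['lr'] * param.grad.data)
        return delta_ws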
def run(model, test_data, queue, param_q, stop_signal, train_pics):
    # NLLLoss is used for every model on the server side
    # (the original if/else had identical branches)
    criterion = torch.nn.NLLLoss()

    # Transform the tensors to numpy arrays before broadcasting the model
    tmp = map(lambda item: (item[0], item[1].numpy()), model.state_dict().items())
    _tmp = OrderedDict(tmp)
    workers = [v + 1 for v in range(args.workers_num)]
    for _ in workers:
        param_q.put(_tmp)
    print('Model sending finished!')

    print('Begin!')

    worker_gradient_list = []
    epoch_train_loss = 0
    iteration_in_epoch = 0
    data_size_epoch = 0   # len(train_data), one epoch
    epoch_count = 0
    staleness_sum_square_epoch = 0
    staleness_sum_epoch = 0
    staleness = 0
    learner_staleness = {l: 0 for l in workers}
    s_time = time.time()
    epoch_time = s_time

    # In SSP, the fast workers have to wait for the slowest worker for a given
    # duration; a fast worker exceeding it is queued up to wait
    trainloss_file = './trainloss' + args.model + '.txt'
    staleness_file = './staleness' + args.model + ".txt"
    log_file = './log' + args.model + ".txt"
    if os.path.isfile(trainloss_file):
        os.remove(trainloss_file)
    if os.path.isfile(staleness_file):
        os.remove(staleness_file)
    if os.path.isfile(log_file):
        os.remove(log_file)
    f_trainloss = open(trainloss_file, 'a')
    f_staleness = open(staleness_file, 'a')
    f_log = open(log_file, 'a')

    it_start_time = time.time()
    global_clock = 0
    while True:
        if not queue.empty():
            batch_start_time = time.time()
            tmp_dict = queue.get()
            rank_src = list(tmp_dict.keys())[0]

            isWorkerEnd = tmp_dict[rank_src][2]
            if isWorkerEnd:
                print("Worker {} has completed all its data computation!".format(rank_src))
                learner_staleness.pop(rank_src)
                if len(learner_staleness) == 0:
                    f_trainloss.close()
                    f_staleness.close()
                    stop_signal.put(1)
                    print('Epoch is done: {}'.format(epoch_count))
                    break
                continue

            batch_receive_time = time.time()

            # Receive this worker's weight deltas (one tensor per parameter)
            tmp_gradient = []
            for param in model.parameters():
                tmp = torch.zeros_like(param.data)
                dist.recv(tensor=tmp, src=rank_src)
                tmp_gradient.append(tmp)
            # Add the local deltas to the k-list cache
            worker_gradient_list.append(tmp_gradient)

            iteration_loss = tmp_dict[rank_src][0]
            batch_size = tmp_dict[rank_src][1]

            iteration_in_epoch += 1
            epoch_train_loss += iteration_loss
            data_size_epoch += batch_size

            stale = int(staleness - learner_staleness[rank_src])
            learner_staleness[rank_src] = staleness
            staleness_sum_epoch += stale
            # staleness_sum_square_epoch += stale ** 2

            batch_preprocess_time = time.time()

            # Return the current model to the worker
            for idx, param in enumerate(model.parameters()):
                dist.send(tensor=param.data, dst=rank_src)
            batch_send_time = time.time()

            # Update the model once the number of cached mini-batch gradients
            # reaches args.stale_threshold: average the deltas and apply them
            if len(worker_gradient_list) >= args.stale_threshold:
                global_clock += 1
                for idx, param in enumerate(model.parameters()):
                    delta_ws = torch.zeros_like(param.data)
                    for w_g_idx in range(len(worker_gradient_list)):
                        delta_ws += worker_gradient_list[w_g_idx][idx]
                    param.data -= delta_ws / len(worker_gradient_list)
                f_log.write(str(len(worker_gradient_list)) + "\n")
                f_log.flush()
                worker_gradient_list = []
                staleness += 1  # advance the system clock

            batch_end_time = time.time()
            batch_interval = batch_end_time - batch_start_time
            batch_receive_interval = batch_receive_time - batch_start_time
            batch_send_interval = batch_send_time - batch_preprocess_time
            it_interval = batch_end_time - it_start_time
            it_start_time = time.time()
            # print('Done From Rank {}, Staleness {}!'.format(rank_src, stale))

            # epoch, rank, batch size, stale
            f_staleness.write(str(epoch_count) + "\t" +
                              str(rank_src) + "\t" +
                              str(batch_size) + "\t" +
                              str(stale) + "\t" +
                              str(batch_interval) + "\t" +
                              str(batch_receive_interval) + "\t" +
                              str(batch_send_interval) + "\t" +
                              str(it_interval) + "\t" +
                              str(global_clock) + '\n')

            # Once an epoch's worth of data has been consumed,
            # log the average train loss
            if data_size_epoch >= train_pics:
                e_epoch_time = time.time()
                # Variance of staleness
                # diversity_stale = (staleness_sum_square_epoch / iteration_in_epoch) \
                #                   - (staleness_sum_epoch / iteration_in_epoch) ** 2
                diversity_stale = 0.0
                staleness_sum_square_epoch = 0
                staleness_sum_epoch = 0
                # test_loss, test_acc = test_model(dist.get_rank(), model, test_data, criterion=criterion)
                test_acc = 0
                # rank, trainloss, variance of staleness, time in one epoch, time till now
                f_trainloss.write(str(args.this_rank) + "\t" +
                                  str(epoch_train_loss / float(iteration_in_epoch)) + "\t" +
                                  str(0) + "\t" +
                                  str(e_epoch_time - epoch_time) + "\t" +
                                  str(e_epoch_time - s_time) + "\t" +
                                  str(e_epoch_time - s_time) + "\t" +
                                  str(epoch_count) + "\t" +
                                  str(test_acc) + "\t" +
                                  str(diversity_stale) + "\t" +
                                  str(global_clock) + '\n')
                f_trainloss.flush()
                f_staleness.flush()
                iteration_in_epoch = 0
                epoch_count += 1
                epoch_train_loss = 0
                data_size_epoch = 0
                epoch_time = e_epoch_time

                # Stop once the target number of epochs is reached
                if epoch_count >= args.epochs:
                    f_trainloss.close()
                    f_staleness.close()
                    f_log.close()
                    stop_signal.put(1)
                    print('Epoch is done: {}'.format(epoch_count))
                    break

        e_time = time.time()
        if (e_time - s_time) >= float(args.timeout):
            f_trainloss.close()
            f_staleness.close()
            stop_signal.put(1)
            print('Time up: {}, Stop Now!'.format(e_time - s_time))
            break
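# Neither the server nor the worker functions initialize the process group
# themselves, so a launcher has to do it before calling `run`. A minimal sketch
# of that wiring, assuming a Gloo backend; the helper name `init_processes`,
# the address, and the port are hypothetical and the repo's actual launcher
# may differ:
import torch.distributed as dist
from multiprocessing import Queue

def init_processes(rank, world_size, fn, *fn_args):
    # Rank 0 acts as the parameter server, ranks 1..world_size-1 as workers
    dist.init_process_group(backend='gloo',
                            init_method='tcp://127.0.0.1:23456',
                            rank=rank,
                            world_size=world_size)
    fn(*fn_args)

# The queues are shared between the launcher's processes:
#   queue   -- workers -> server metadata: [loss, batch size, done flag]
#   param_q -- server -> workers: initial model state_dict (as numpy arrays)
queue, param_q = Queue(), Queue()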
def run(model, test_data, queue, param_q, stop_signal, train_pics):
    # NLLLoss is used for every model on the server side
    # (the original if/else had identical branches)
    criterion = torch.nn.NLLLoss()

    # Transform the tensors to numpy arrays before broadcasting the model
    tmp = map(lambda item: (item[0], item[1].numpy()), model.state_dict().items())
    _tmp = OrderedDict(tmp)
    workers = [v + 1 for v in range(args.workers_num)]
    for _ in workers:
        param_q.put(_tmp)
    print('Model sending finished!')

    print('Begin!')

    epoch_train_loss = 0
    iteration_in_epoch = 0
    data_size_epoch = 0   # len(train_data), one epoch
    epoch_count = 0
    staleness_sum_square_epoch = 0
    staleness_sum_epoch = 0
    staleness = 0  # the global clock
    learner_staleness = {l: 0 for l in workers}
    s_time = time.time()
    epoch_time = s_time

    # In SSP, the fast workers have to wait for the slowest worker for a given
    # duration; a fast worker exceeding it is pushed onto the stack to wait
    stale_stack = []

    trainloss_file = './trainloss' + args.model + '.txt'
    staleness_file = './staleness' + args.model + ".txt"
    if os.path.isfile(trainloss_file):
        os.remove(trainloss_file)
    if os.path.isfile(staleness_file):
        os.remove(staleness_file)
    f_trainloss = open(trainloss_file, 'a')
    f_staleness = open(staleness_file, 'a')

    global_clock = 0
    while True:
        if not queue.empty():
            it_start_time = time.time()
            tmp_dict = queue.get()
            rank_src = list(tmp_dict.keys())[0]

            isWorkerEnd = tmp_dict[rank_src][2]
            if isWorkerEnd:
                print("Worker {} has completed all its data computation!".format(rank_src))
                learner_staleness.pop(rank_src)
                if len(learner_staleness) == 0:
                    f_trainloss.close()
                    f_staleness.close()
                    stop_signal.put(1)
                    print('Epoch is done: {}'.format(epoch_count))
                    break
                continue

            stale = int(staleness - learner_staleness[rank_src])
            staleness_sum_epoch += stale
            # staleness_sum_square_epoch += stale ** 2
            staleness_sum_square_epoch += 0.0
            staleness += 1
            learner_staleness[rank_src] = staleness
            stale_stack.append(rank_src)

            # Receive the worker's gradients and update the current model,
            # downweighting stale updates by 1 / (stale + 1)
            for idx, param in enumerate(model.parameters()):
                tmp_tensor = torch.zeros_like(param.data)
                dist.recv(tensor=tmp_tensor, src=rank_src)
                param.data -= tmp_tensor / (stale + 1)
            global_clock += 1

            iteration_loss = tmp_dict[rank_src][0]
            batch_size = tmp_dict[rank_src][1]

            iteration_in_epoch += 1
            epoch_train_loss += iteration_loss
            data_size_epoch += batch_size

            # Check whether any absent worker exceeds the SSP staleness threshold
            outOfStale = False
            for stale_each_worker in learner_staleness:
                if (stale_each_worker not in stale_stack) and \
                        (staleness - learner_staleness[stale_each_worker] > args.stale_threshold):
                    outOfStale = True
                    break
            if not outOfStale:
                # Release every waiting worker and hand back the current model
                for i in range(len(stale_stack)):
                    rank_wait = stale_stack.pop()
                    # Update the staleness value of the corresponding learner
                    learner_staleness[rank_wait] = staleness
                    for idx, param in enumerate(model.parameters()):
                        dist.send(tensor=param.data, dst=rank_wait)
            else:
                continue

            it_end_time = time.time()
            f_staleness.write(str(epoch_count) + "\t" +
                              str(rank_src) + "\t" +
                              str(batch_size) + "\t" +
                              str(stale) + "\t" +
                              str(it_end_time - it_start_time) + "\t" +
                              str(global_clock) + '\n')
            f_staleness.flush()

            # Once an epoch's worth of data has been consumed,
            # log the average train loss
            if data_size_epoch >= train_pics:
                e_epoch_time = time.time()
                # Variance of staleness
                diversity_stale = (staleness_sum_square_epoch / iteration_in_epoch) \
                    - (staleness_sum_epoch / iteration_in_epoch) ** 2
                staleness_sum_square_epoch = 0
                staleness_sum_epoch = 0
                # test_loss, test_acc = test_model(dist.get_rank(), model, test_data, criterion=criterion)
                test_acc = 0.0
                # rank, trainloss, variance of staleness, time in one epoch, time till now
                f_trainloss.write(str(args.this_rank) + "\t" +
                                  str(epoch_train_loss / float(iteration_in_epoch)) + "\t" +
                                  str(diversity_stale) + "\t" +
                                  str(e_epoch_time - epoch_time) + "\t" +
                                  str(e_epoch_time - s_time) + "\t" +
                                  str(epoch_count) + "\t" +
                                  str(test_acc) + "\t" +
                                  str(global_clock) + '\n')
                f_trainloss.flush()
                iteration_in_epoch = 0
                epoch_count += 1
                epoch_train_loss = 0
                data_size_epoch = 0
                epoch_time = e_epoch_time

                # Stop once the target number of epochs is reached
                if epoch_count >= args.epochs:
                    f_trainloss.close()
                    f_staleness.close()
                    stop_signal.put(1)
                    print('Epoch is done: {}'.format(epoch_count))
                    break

        e_time = time.time()
        if (e_time - s_time) >= float(args.timeout):
            f_trainloss.close()
            f_staleness.close()
            stop_signal.put(1)
            print('Time up: {}, Stop Now!'.format(e_time - s_time))
            break
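# The SSP gate above releases the waiting workers only when no absent worker
# lags more than `stale_threshold` clock ticks behind the global clock. The
# same rule, isolated as a toy, self-contained illustration with hypothetical
# values (not code from the repo):
def ssp_can_release(learner_staleness, stale_stack, staleness, stale_threshold):
    # A worker already waiting on the stack can never block the release
    return not any(
        worker not in stale_stack and
        staleness - last_seen > stale_threshold
        for worker, last_seen in learner_staleness.items()
    )

# Example: worker 3 was last served at clock 1; with the clock now at 5 and a
# threshold of 3, its lag of 4 blocks workers 1 and 2 from being released.
assert not ssp_can_release({1: 5, 2: 4, 3: 1}, [1, 2], 5, 3)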
def run(rank, model, train_data, test_data, queue, param_q, stop_flag):
    # Get the initial model parameters sent by the parameter server
    while True:
        if not param_q.empty():
            param_dict = param_q.get()
            tmp = OrderedDict(
                map(lambda item: (item[0], torch.from_numpy(item[1])),
                    param_dict.items()))
            model.load_state_dict(tmp)
            break
    print('Model received successfully!')

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        optimizer = MySGD(model.parameters(), lr=0.1)
    else:
        optimizer = MySGD(model.parameters(), lr=0.01)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    else:
        decay_period = 100

    print('Begin!')

    time_logs = open("./record" + str(rank), 'w')
    for epoch in range(int(args.epochs)):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        batch_comm_interval = 0.0
        batch_push_interval = 0.0
        batch_pull_interval = 0.0
        model.train()

        # Decay the learning rate at the specified epochs
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))

        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()
            batch_comp_time = time.time()

            # noinspection PyBroadException
            try:
                # Catch the exception raised when the PS process stops.
                # The timing code below assumes delta_ws is never empty.
                if delta_ws:
                    queue.put({
                        rank: [loss.data.numpy(), np.array(args.train_bsz), False]
                    })
                    # Push the weight deltas to the server...
                    for delta in delta_ws:
                        dist.send(tensor=delta, dst=0)
                    batch_push_time = time.time()
                    # ...and pull the updated model back
                    for idx, param in enumerate(model.parameters()):
                        tmp_tensor = torch.zeros_like(param.data)
                        dist.recv(tensor=tmp_tensor, src=0)
                        param.data = tmp_tensor
                    batch_tmp_time = time.time()
                    batch_pull_time = time.time()
                    # print('Rank {}, Epoch {}, Batch {}/{}, Loss:{}'
                    #       .format(rank, epoch, batch_idx, len(train_data), loss.data[0]))
            except Exception as e:
                print(str(e))
                print('Should Stop: {}!'.format(stop_flag.value))
                break

            batch_interval += batch_pull_time - batch_start_time
            batch_comp_interval += batch_comp_time - batch_start_time
            batch_comm_interval += batch_pull_time - batch_comp_time
            batch_push_interval += batch_push_time - batch_comp_time
            batch_pull_interval += batch_pull_time - batch_push_time

            b_interval = batch_interval / (batch_idx + 1)
            b_comp_interval = batch_comp_interval / (batch_idx + 1)
            b_comm_interval = batch_comm_interval / (batch_idx + 1)
            b_push_interval = batch_push_interval / (batch_idx + 1)
            b_pull_interval = batch_pull_interval / (batch_idx + 1)
            logs = torch.tensor([
                0.0, b_interval, b_comp_interval, b_comm_interval,
                b_push_interval, b_pull_interval,
                batch_pull_time - batch_tmp_time
            ])
            time_logs.write(str(logs) + '\n')
            time_logs.flush()

        # Average over the number of batches (batch_idx is zero-based)
        batch_interval /= batch_idx + 1
        batch_comp_interval /= batch_idx + 1
        batch_comm_interval /= batch_idx + 1
        batch_push_interval /= batch_idx + 1
        batch_pull_interval /= batch_idx + 1
        logs = torch.tensor([
            0.0, batch_interval, batch_comp_interval, batch_comm_interval,
            batch_push_interval, batch_pull_interval
        ])
        time_logs.write(str(epoch) + '\t' + str(logs) + '\n')
        time_logs.flush()

        # Test the model once the epoch's training finishes
        print("test Model:", epoch)
        # test_model(rank, model, test_data, criterion=criterion)

        if stop_flag.value:
            break

    queue.put({rank: [[], [], True]})
    time_logs.close()
    print("Worker {} has completed epoch {}!".format(args.this_rank, epoch))
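# `test_model` is referenced (commented out) throughout but never defined in
# this section. A minimal evaluation loop matching its apparent signature and
# its `(test_loss, test_acc)` return value -- a sketch, not the repo's version:
def test_model(rank, model, test_data, criterion):
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_data:
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.max(1)[1]  # index of the max log-probability
            correct += pred.eq(target).sum().item()
            total += target.size(0)
    test_loss /= len(test_data)
    test_acc = correct / total
    print('Rank {}: test loss {:.4f}, accuracy {:.4f}'.format(rank, test_loss, test_acc))
    return test_loss, test_acc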