def run(size, rank, epoch, batchsize):
    """Round-based synchronous distributed training.

    Each round: rank 0 evaluates the current model, every rank trains
    locally for `epoch` epochs, then the model is averaged via all_reduce.

    Args:
        size: world size (number of participating ranks).
        rank: this process's rank; rank 0 additionally runs evaluation.
        epoch: number of local epochs per communication round.
        batchsize: per-rank training batch size.

    Relies on module globals: MODEL, DATA_SET, LR, MAX_ROUND,
    ROUND_NUMBER_FOR_SAVE and helpers (get_local_data, get_testset,
    load_model, save_model, all_reduce).
    """
    #print('run')
    # Pick the architecture from global config. NOTE(review): if no
    # combination matches, `model` stays unbound and model.cuda() raises.
    if MODEL == 'CNN' and DATA_SET == 'KWS':
        model = CNNKws()
    if MODEL == 'CNN' and DATA_SET == 'Cifar10':
        model = CNNCifar()
    if MODEL == 'ResNet18' and DATA_SET == 'Cifar10':
        model = ResNet18()
    model = model.cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=LR, weight_decay=1e-3)
    loss_func = torch.nn.CrossEntropyLoss()
    train_loader = get_local_data(size, rank, batchsize)
    if rank == 0:
        # Only rank 0 holds a test loader and reports accuracy.
        test_loader = get_testset(rank)
    #fo = open("file_multi"+str(rank)+".txt", 'w')
    group_list = [i for i in range(size)]
    group = dist.new_group(group_list)
    # Resume from checkpoint; returns the model and the round to resume at.
    # NOTE: `round` shadows the Python builtin of the same name.
    model, round = load_model(model, group, rank)
    while round < MAX_ROUND:
        sys.stdout.flush()
        if rank == 0:
            # Evaluate the current global model over the full test set.
            accuracy = 0
            positive_test_number = 0
            total_test_number = 0
            for step, (test_x, test_y) in enumerate(test_loader):
                test_x = test_x.cuda()
                test_y = test_y.cuda()
                test_output = model(test_x)
                pred_y = torch.max(test_output, 1)[1].data.cpu().numpy()
                positive_test_number += (pred_y == test_y.data.cpu().numpy()).astype(int).sum()
                # print(positive_test_number)
                total_test_number += float(test_y.size(0))
            accuracy = positive_test_number / total_test_number
            print('Round: ', round, ' Rank: ', rank,
                  '| test accuracy: %.4f' % accuracy)
            #fo.write(str(round) + " " + str(rank) + " " + str(accuracy) + "\n")
        # Local training: `epoch` passes over this rank's data shard.
        for epoch_cnt in range(epoch):
            for step, (b_x, b_y) in enumerate(train_loader):
                b_x = b_x.cuda()
                b_y = b_y.cuda()
                optimizer.zero_grad()
                output = model(b_x)
                loss = loss_func(output, b_y)
                loss.backward()
                optimizer.step()
        # Average the model across all ranks once per round.
        model = all_reduce(model, size, group)
        #if (round+1) % ROUND_NUMBER_FOR_REDUCE == 0:
        #model = all_reduce(model, size, group)
        if (round + 1) % ROUND_NUMBER_FOR_SAVE == 0:
            save_model(model, round + 1, rank)
        round += 1
def _start_reduction_threads(self):
    """Spawn one daemon reduction thread per gradient bucket.

    For every bucket this sets up a work queue, a per-device CUDA stream,
    a dedicated process group, and a daemon thread running
    `self._reduction_thread_fn`. Also records the current (default) and a
    fresh NCCL stream for each device in `self.device_ids`.
    """
    num_buckets = len(self.bucket_sizes)
    # One queue per bucket feeds the corresponding reduction thread.
    self._reduction_queues = [queue.Queue() for _ in range(num_buckets)]
    self._reduction_threads = []
    self._reduction_streams = [[] for _ in range(num_buckets)]
    self._nccl_streams = []
    self._default_streams = []
    for dev_id in self.device_ids:
        with torch.cuda.device(dev_id):
            # TODO: don't assume we're on a default stream
            self._default_streams.append(torch.cuda.current_stream())
            self._nccl_streams.append(torch.cuda.Stream())
    for reduction_queue, reduction_streams in zip(self._reduction_queues,
                                                  self._reduction_streams):
        # Each bucket gets its own stream on every device.
        for dev_id in self.device_ids:
            with torch.cuda.device(dev_id):
                reduction_streams.append(torch.cuda.Stream())
        # We only use the first device for distributed reductions
        dist._register_stream(reduction_streams[0])
        # Separate process group per bucket so reductions don't interleave.
        group_id = dist.new_group()
        self._reduction_threads.append(threading.Thread(
            target=self._reduction_thread_fn,
            args=(reduction_queue, group_id, self.device_ids,
                  reduction_streams, self._nccl_streams)))
        # Daemon threads: do not block interpreter shutdown.
        self._reduction_threads[-1].daemon = True
        self._reduction_threads[-1].start()
def _init_group_test(self):
    """Build the fixed two-member test group (ranks 1 and 2).

    Returns (members, group_handle, rank) when this process belongs to
    the group, and ([], None, rank) otherwise.
    """
    members = [1, 2]
    # new_group must be called by every process, even non-members.
    handle = dist.new_group(members)
    my_rank = dist.get_rank()
    if my_rank in members:
        return (members, handle, my_rank)
    return ([], None, my_rank)
def init_processes(master_address, world_size, rank, epoch_per_round,
                   batch_size, run):
    """Join the process group, build the all-ranks group, and start `run`.

    Args:
        master_address: init-method URL of the rendezvous master.
        world_size: total number of processes.
        rank: this process's rank.
        epoch_per_round: local epochs per communication round.
        batch_size: per-rank batch size.
        run: training entry point, called with
             (world_size, rank, group, epoch_per_round, batch_size).
    """
    # change 'tcp' to 'nccl' if running on GPU worker
    dist.init_process_group(backend='tcp',
                            init_method=master_address,
                            world_size=world_size,
                            rank=rank)
    group = dist.new_group(list(range(world_size)))
    run(world_size, rank, group, epoch_per_round, batch_size)
def run():
    """Parameter-server loop: broadcast weights, then average worker sums.

    Runs forever: pushes the current parameters from rank 0 to every rank,
    then receives the SUM-reduced worker parameters and averages them over
    the (size - 1) workers.
    """
    modell = model.CNN()
    # modell = model.AlexNet()
    size = dist.get_world_size()
    rank = dist.get_rank()
    group = dist.new_group(list(range(size)))
    while True:
        # Push current global weights to every worker.
        for param in modell.parameters():
            dist.broadcast(param.data, src=0, group=group)
        # Gather the summed worker weights and average them; this rank's
        # zero-filled contribution does not change the sum.
        for param in modell.parameters():
            accum = torch.zeros_like(param.data)
            dist.reduce(accum, dst=0, op=dist.reduce_op.SUM, group=group)
            param.data = accum / (size - 1)
def _register_nccl_grad_hook(self):
    """
    This function registers the callback all-reduction function for the
    NCCL backend. All gradients will be all reduced in one single step.
    The NCCL reduction will directly be enqueued into the default CUDA
    stream. Therefore, no synchronization is needed.
    """
    # Creating a new group
    self.nccl_reduction_group_id = dist.new_group()

    def reduction_fn_nccl():
        # This function only needs to be called once
        if not self.need_reduction:
            return

        self.need_reduction = False
        all_grads = [[] for _ in range(len(self._module_copies))]
        all_grads_buckets_iters = []

        # Bucketing all the gradients
        for dev_idx, module in enumerate(self._module_copies):
            for param in module.parameters():
                if not param.requires_grad or param.grad is None:
                    continue
                if param.grad.requires_grad:
                    raise RuntimeError("DistributedDataParallel only works "
                                       "with gradients that don't require "
                                       "grad")
                # Adding the gradients for reduction
                all_grads[dev_idx].append(param.grad.data)

            # Now bucketing the parameters
            dev_grads_buckets = _take_tensors(all_grads[dev_idx],
                                              self.nccl_reduce_bucket_size)
            all_grads_buckets_iters.append(dev_grads_buckets)

        # Now reduce each bucket one after another
        for grads_batch in zip(*all_grads_buckets_iters):
            grads_batch_coalesced = []
            # Coalesce each bucket into one flat tensor per device.
            for dev_idx, dev_grads_batch in enumerate(grads_batch):
                dev_id = self.device_ids[dev_idx]
                with torch.cuda.device(dev_id):
                    dev_grads_batch_coalesced = _flatten_dense_tensors(dev_grads_batch)
                    grads_batch_coalesced.append(dev_grads_batch_coalesced)

            # We will only use device 0's results, but this single op should be
            # faster than doing the following two operation sequentially:
            # (1) intra-node reduce to lead GPU, followed by
            # (2) inter-node allreduce for all the first lead GPUs in all nodes
            dist.all_reduce_multigpu(grads_batch_coalesced,
                                     group=self.nccl_reduction_group_id)

            # Now only work on the first device of self.device_ids, uncoalesce
            # the gradients for each bucket.  Divide by world size to turn the
            # all-reduced SUM into an average.
            grads_batch_coalesced[0] /= dist.get_world_size()
            grads_batch_reduced = _unflatten_dense_tensors(grads_batch_coalesced[0],
                                                           grads_batch[0])
            for grad, reduced in zip(grads_batch[0], grads_batch_reduced):
                grad.copy_(reduced)

        # clear the gradients and save memory for replicas
        for module in self._module_copies[1:]:
            for param in module.parameters():
                if param.requires_grad:
                    param.grad = None
                    param.data.set_()

    # Now register the reduction hook on the parameters: the first gradient
    # hook that fires queues reduction_fn_nccl to run at backward completion.
    for p in self.module.parameters():
        if not p.requires_grad:
            continue

        @torch.utils.hooks.unserializable_hook
        def allreduce_hook(*unused):
            Variable._execution_engine.queue_callback(reduction_fn_nccl)

        p.register_hook(allreduce_hook)
def run(size, rank):
    """Worker training loop over MNIST (IID or per-rank non-IID shards).

    Evaluates the freshly fetched global model each epoch, trains locally,
    then SUM-reduces parameters to rank 0 and logs loss/accuracy to
    ./test.txt and stdout. Relies on module globals: model, Mnist,
    Mnist_noniid, LR, MAX_EPOCH, IID, get_new_model, dist.
    """
    modell = model.CNN()
    #modell = model.AlexNet()
    optimizer = torch.optim.Adam(modell.parameters(), lr=LR)
    loss_func = torch.nn.CrossEntropyLoss()
    # Select this rank's data. NOTE(review): in non-IID mode, rank 0 (or any
    # rank > 5) leaves train_loader/test_data unbound, which would raise
    # NameError below — presumably rank 0 is the server; confirm with caller.
    if(IID == True):
        train_loader = Mnist().get_train_data()
        test_data = Mnist().get_test_data()
    else:
        if(rank > 0):
            if(rank == 1):
                train_loader = Mnist_noniid().get_train_data1()
                test_data = Mnist_noniid().get_test_data1()
            if(rank == 2):
                train_loader = Mnist_noniid().get_train_data2()
                test_data = Mnist_noniid().get_test_data2()
            if(rank == 3):
                train_loader = Mnist_noniid().get_train_data3()
                test_data = Mnist_noniid().get_test_data3()
            if(rank == 4):
                train_loader = Mnist_noniid().get_train_data4()
                test_data = Mnist_noniid().get_test_data4()
            if(rank == 5):
                train_loader = Mnist_noniid().get_train_data5()
                test_data = Mnist_noniid().get_test_data5()
    #size = dist.get_world_size()
    #rank = dist.get_rank()
    #train_loader = Mnist().get_train_data()
    #test_data = Mnist().get_test_data()
    # Keep the LAST batch of the test loader as the evaluation set.
    for step, (b_x, b_y) in enumerate(test_data):
        test_x = b_x
        test_y = b_y
    group_list = []
    for i in range(size):
        group_list.append(i)
    group = dist.new_group(group_list)
    for epoch in range(MAX_EPOCH):
        # Fetch the aggregated global model for this epoch.
        modell = get_new_model(modell, group)
        #current_model = copy.deepcopy(modell)
        # Evaluate before local training; the model returns (logits, features).
        test_output, last_layer = modell(test_x)
        pred_y = torch.max(test_output, 1)[1].data.numpy()
        accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
        for step, (b_x, b_y) in enumerate(train_loader):
            #modell = get_new_model(modell)
            #current_model = copy.deepcopy(modell)
            output = modell(b_x)[0]
            loss = loss_func(output, b_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Ship locally trained parameters to rank 0 (summed there).
        for param in modell.parameters():
            dist.reduce(param.data, dst=0, op=dist.reduce_op.SUM, group=group)
        # Append metrics to ./test.txt and echo to stdout.
        f = open('./test.txt', 'a')
        print('Epoch: ', epoch, ' Rank: ', rank,
              '| train loss: %.4f' % loss.data.numpy(),
              '| test accuracy: %.2f' % accuracy, file=f)
        print('Epoch: ', epoch, ' Rank: ', rank,
              '| train loss: %.4f' % loss.data.numpy(),
              '| test accuracy: %.2f' % accuracy)
        f.close()
def run(rank, workers, model, save_path, train_data, test_data):
    """Worker loop: receive the initial model from the server (rank 0),
    train per batch, and exchange updates via gather/scatter each batch.

    Args:
        rank: this worker's rank.
        workers: iterable of worker ranks (server is rank 0).
        model: the local model replica, updated in place.
        save_path: unused here; kept for interface compatibility.
        train_data, test_data: data loaders.
    """
    # BUG FIX: list.append() mutates in place and returns None, so the
    # original `_group = [w for w in workers].append(0)` set _group = None,
    # silently making new_group span ALL ranks instead of workers + server.
    _group = list(workers) + [0]
    group = dist.new_group(_group)

    # Get the initial model from the server (layer by layer).
    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model recved successfully!')

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        optimizer = MySGD(model.parameters(), lr=0.1)
    else:
        optimizer = MySGD(model.parameters(), lr=0.01)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    else:
        decay_period = 100

    print('Begin!')
    time_logs = open("./record" + str(rank), 'w')
    for epoch in range(args.epochs):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        batch_comm_interval = 0.0
        s_time = time.time()
        model.train()

        # Reduce the learning rate LR in some specific epochs
        #if args.model == 'AlexNet':
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))

        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()
            batch_comp_time = time.time()

            # Synchronization: send epoch train loss first, then exchange
            # the per-layer update with the server (gather up, scatter down).
            dist.gather(loss.data, dst=0, group=group)
            for idx, param in enumerate(model.parameters()):
                dist.gather(tensor=delta_ws[idx], dst=0, group=group)
                recv = torch.zeros_like(delta_ws[idx])
                dist.scatter(tensor=recv, src=0, group=group)
                param.data = recv

            epoch_train_loss += loss.data.item()
            batch_end_time = time.time()
            batch_interval += batch_end_time - batch_start_time
            batch_comp_interval += batch_comp_time - batch_start_time
            batch_comm_interval += batch_end_time - batch_comp_time
            logs = torch.tensor([
                0.0,
                batch_interval / (batch_idx + 1),
                batch_comp_interval / (batch_idx + 1),
                batch_comm_interval / (batch_idx + 1)
            ])
            time_logs.write(str(logs) + '\n')
            time_logs.flush()

        print('Rank {}, Epoch {}, Loss:{}'.format(rank, epoch, loss.data.item()))
        e_time = time.time()
        #epoch_train_loss /= len(train_data)
        #epoch_train_loss = format(epoch_train_loss, '.4f')
        # test the model
        #test_loss, acc = test_model(rank, model, test_data, criterion=criterion)
        acc = 0.0
        # BUG FIX: divide by the batch COUNT (batch_idx + 1), not the last
        # index — the original raised ZeroDivisionError with a single batch
        # and over-reported the averages otherwise (matches the sibling
        # implementation that already uses batch_idx + 1).
        batch_interval /= batch_idx + 1
        batch_comp_interval /= batch_idx + 1
        batch_comm_interval /= batch_idx + 1
        logs = torch.tensor(
            [acc, batch_interval, batch_comp_interval, batch_comm_interval])
        time_logs.write(str(logs) + '\n')
        time_logs.flush()
        #dist.gather(tensor=logs, dst = 0, group = group)
    time_logs.close()
def run(rank, model, train_pics, train_bsz):
    """Server loop: send the initial model, then per batch gather worker
    losses/sparsification ratios and updates, apply the averaged update,
    and scatter the new model back.

    Args:
        rank: this (server) process's rank.
        model: the global model, updated in place.
        train_pics: total number of training images (sets batches/epoch).
        train_bsz: per-worker batch size.
    """
    workers = [v + 1 for v in range(args.workers_num)]
    # BUG FIX: list.append() returns None, so the original
    # `_group = [w for w in workers].append(rank)` set _group = None and
    # new_group silently spanned all ranks instead of workers + server.
    _group = list(workers) + [rank]
    group = dist.new_group(_group)

    # Broadcast initial parameters: every slot in the scatter list is the
    # server's copy, so each worker receives identical weights.
    for p in model.parameters():
        scatter_p_list = [p.data for _ in range(len(workers) + 1)]
        dist.scatter(tensor=p.data, scatter_list=scatter_p_list, group=group)
    print('Model Sent Finished!')
    print('Begin!')

    trainloss_file = './trainloss' + args.model + '.txt'
    if os.path.isfile(trainloss_file):
        os.remove(trainloss_file)
    f_trainloss = open(trainloss_file, 'a')

    # One dummy entry per global batch in an epoch.
    tmp = [(0, 0) for _ in range(
        int(math.ceil(train_pics / (len(workers) * train_bsz))))]

    s_time = time.time()
    epoch_time = s_time
    global_clock = 0
    epoch_train_loss = 0.0
    sparsification_ratio = 0.0
    for epoch in range(args.epochs):
        for batch_idx, (_, _) in enumerate(tmp):
            # receive the list of train loss from workers
            info_list = [torch.tensor([0.0]) for _ in range(len(workers) + 1)]
            dist.gather(tensor=torch.tensor([0.0]), gather_list=info_list,
                        group=group)
            batch_loss = sum(info_list).item() / len(workers)
            epoch_train_loss += batch_loss

            # receive the sparsification ratios from workers
            sparsification_ratio_list = [
                torch.tensor([0.0]) for _ in range(len(workers) + 1)
            ]
            dist.gather(tensor=torch.tensor([0.0]),
                        gather_list=sparsification_ratio_list,
                        group=group)
            batch_ratio = sum(sparsification_ratio_list).item() / len(workers)
            sparsification_ratio += batch_ratio

            for param in model.parameters():
                tensor = torch.zeros_like(param.data)
                # Each entry in gather_list must be a distinct tensor object,
                # otherwise the gather misbehaves.
                gather_list = [
                    torch.zeros_like(param.data)
                    for _ in range(len(workers) + 1)
                ]
                dist.gather(tensor=tensor, gather_list=gather_list, group=group)
                # Average the worker updates and apply to the global model.
                tensor = sum(gather_list) / len(workers)
                param.data -= tensor
                # Send the updated layer back to every worker.
                scatter_list = [param.data for _ in range(len(workers) + 1)]
                dist.scatter(tensor=tensor, scatter_list=scatter_list,
                             group=group)

            global_clock += 1
            f_trainloss.write(
                str(args.this_rank) + "\t" +
                str(epoch_train_loss / float(batch_idx + 1)) + "\t" +
                str(batch_loss) + "\t" +
                str(0) + "\t" +
                str(0) + "\t" +
                str(epoch) + "\t" +
                str(0) + "\t" +
                str(0) + "\t" +
                str(0) + "\t" +
                str(batch_ratio) + "\t" +
                str(sparsification_ratio / float(batch_idx + 1)) + "\t" +
                str(global_clock) + '\n')
            # f_trainloss.flush()
            #print('Done {}/{}!'.format(batch_idx, len(tmp)))
        print('Done Epoch {}/{}!'.format(epoch + 1, args.epochs))

        e_epoch_time = time.time()
        # test_acc, batch_interval, batch_comp_interval, batch_comm_interval
        # (the worker-side gather is disabled, so these stay zero).
        logs = torch.tensor([0.0, 0.0, 0.0, 0.0])
        logs_list = [torch.zeros_like(logs) for _ in range(len(workers) + 1)]
        #dist.gather(tensor = logs, gather_list = logs_list, group = group)
        test_acc, batch_interval, batch_comp_interval, batch_comm_interval = zip(
            *logs_list)
        test_acc = sum(test_acc) / len(workers)
        batch_interval = sum(batch_interval) / len(workers)
        batch_comp_interval = sum(batch_comp_interval) / len(workers)
        batch_comm_interval = sum(batch_comm_interval) / len(workers)
        f_trainloss.flush()
        epoch_time = e_epoch_time
        epoch_train_loss = 0.0
        sparsification_ratio = 0.0

        # Checkpoint every second epoch.
        if (epoch + 1) % 2 == 0:
            if not os.path.exists('model_state'):
                os.makedirs('model_state')
            torch.save(
                model.state_dict(),
                'model_state' + '/' + args.model + '_' + str(epoch + 1) + '.pkl')
    f_trainloss.close()
def run(rank, workers, model, save_path, train_data, test_data):
    """Worker loop with gradient sparsification: receive the initial model,
    report the compression ratio, then per batch upload the sparsified
    update (`get_upload`) and download the refreshed model.

    Args:
        rank: this worker's rank.
        workers: iterable of worker ranks (server is rank 0).
        model: local model replica, updated in place.
        save_path: unused here; kept for interface compatibility.
        train_data, test_data: data loaders.
    """
    # BUG FIX: list.append() returns None, so the original
    # `_group = [w for w in workers].append(0)` set _group = None and
    # new_group silently spanned all ranks instead of workers + server.
    _group = list(workers) + [0]
    group = dist.new_group(_group)

    # Receive the initial parameters from the server and count them.
    param_num = 0
    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        param_num += torch.numel(tmp_p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model recved successfully!')

    # Report the achievable compression ratio (at least one element).
    compression_num = int(param_num * args.ratio)
    compression_num = compression_num if compression_num > 0 else 1
    dist.gather(torch.tensor([compression_num / param_num]), dst=0, group=group)

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        learning_rate = 0.1
    else:
        learning_rate = args.lr
    optimizer = MySGD(model.parameters(), lr=learning_rate)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    elif args.model in ['Abalone', 'Bodyfat', 'Housing']:
        criterion = torch.nn.MSELoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    elif args.model in [
            'LROnMnist', 'LROnCifar10', 'LROnCifar100', 'Abalone', 'Bodyfat',
            'Housing'
    ]:
        decay_period = 1000000  # learning rate is constant for LR (convex) models
    else:
        decay_period = 100

    print('Begin!')
    global_clock = 0
    # Residual (not-yet-uploaded) gradient mass per layer.
    g_remain = [torch.zeros_like(param.data) for param in model.parameters()]
    time_logs = open("./record" + str(rank), 'w')
    for epoch in range(args.epochs):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        batch_comm_interval = 0.0
        s_time = time.time()
        model.train()

        # Decay the learning rate at the configured period.
        #if args.model == 'AlexNet':
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))

        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()
            # Split the update into the sparsified part to upload and the
            # residual to carry forward (optionally compensated).
            g_remain, g_large_change = get_upload(g_remain, delta_ws,
                                                  args.ratio,
                                                  args.isCompensate)
            batch_comp_time = time.time()

            # Synchronization: send the train loss first, then exchange the
            # sparsified per-layer update with the server.
            dist.gather(loss.data, dst=0, group=group)
            for idx, param in enumerate(model.parameters()):
                dist.gather(tensor=g_large_change[idx], dst=0, group=group)
                recv = torch.zeros_like(delta_ws[idx])
                dist.scatter(tensor=recv, src=0, group=group)
                param.data = recv

            epoch_train_loss += loss.data.item()
            batch_end_time = time.time()
            batch_interval += batch_end_time - batch_start_time
            batch_comp_interval += batch_comp_time - batch_start_time
            batch_comm_interval += batch_end_time - batch_comp_time
            logs = torch.tensor([
                0.0,
                batch_interval / (batch_idx + 1),
                batch_comp_interval / (batch_idx + 1),
                batch_comm_interval / (batch_idx + 1)
            ])
            time_logs.write(str(logs) + '\n')
            time_logs.flush()

        print('Rank {}, Epoch {}, Loss:{}'.format(rank, epoch, loss.data.item()))
        e_time = time.time()
        #epoch_train_loss /= len(train_data)
        #epoch_train_loss = format(epoch_train_loss, '.4f')
        # test after training
        #test_loss, acc = test_model(rank, model, test_data, criterion=criterion)
        acc = 0.0
        batch_interval /= batch_idx + 1
        batch_comp_interval /= batch_idx + 1
        batch_comm_interval /= batch_idx + 1
        logs = torch.tensor(
            [acc, batch_interval, batch_comp_interval, batch_comm_interval])
        time_logs.write(str(logs) + '\n')
        time_logs.flush()
        #dist.gather(tensor=logs, dst = 0, group = group)
    time_logs.close()
def run(rank, workers, model, save_path, train_data, test_data, global_lr):
    """Asynchronous worker loop: a background sender thread exchanges the
    accumulated local update with the server while training continues,
    bounded by a staleness threshold.

    Args:
        rank: this worker's rank.
        workers: iterable of worker ranks (server is rank 0).
        model: local model replica, updated in place.
        save_path: unused here; kept for interface compatibility.
        train_data, test_data: data loaders.
        global_lr: queue from which the server pushes learning-rate updates.
    """
    # Get the initial model from the server
    print(workers)
    # BUG FIX: list.append() returns None, so the original
    # `_group = [w for w in workers].append(0)` set _group = None and
    # new_group silently spanned all ranks instead of workers + server.
    _group = list(workers) + [0]
    group = dist.new_group(_group)
    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model recved successfully!')

    temp_lr = global_lr.get()
    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        optimizer = MySGD(model.parameters(), lr=temp_lr)
    else:
        optimizer = MySGD(model.parameters(), lr=temp_lr)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()
    print('Begin!')

    # Shared state handed to the sender thread.
    model_cache = [p.data + 0.0 for p in model.parameters()]  # deep copies
    global_update = [torch.zeros_like(p) for p in model.parameters()]
    local_update = [torch.zeros_like(p) for p in model.parameters()]
    it_count = Value(c_float, 0.)  # updates accumulated in this iteration
    data_lock = Lock()
    update_lock = Queue()
    update_lock.put(1)
    loss_t = torch.tensor(0.0)
    receive_end = Value(c_bool, False)
    batch_communication_interval = Value(c_float, 0.0)
    stale_in_iteration = Value(c_float, 0.)
    sender_td = Thread(target=sender,
                       args=(model_cache, global_update, local_update,
                             it_count, loss_t, update_lock, data_lock, group,
                             receive_end, batch_communication_interval,
                             stale_in_iteration,),
                       daemon=True)
    sender_td.start()

    time_logs = open("./record" + str(rank), 'w')
    osp_logs = open("./log" + str(rank), 'w')
    Stale_Threshold = args.stale_threshold
    for epoch in range(args.epochs):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        s_time = time.time()
        model.train()

        # Decay the learning rate at the specific epoch; the rate is decided
        # on the server due to the unmatched updating speed between the
        # local worker and the server.
        if not global_lr.empty():
            g_lr = global_lr.get()
            if args.model == 'AlexNet':
                for param_group in optimizer.param_groups:
                    param_group['lr'] = g_lr
                    print('LR Decreased! Now: {}'.format(param_group['lr']))

        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()
            optimizer.step()

            # Aggregate the local update under the lock shared with sender.
            data_lock.acquire()
            loss_t.data += loss.data
            it_count.value += 1
            for g_idx, update in enumerate(local_update):
                update.data += delta_ws[g_idx].data
            data_lock.release()
            batch_computation_time = time.time()

            # Open the lock once the local update has at least one gradient.
            if it_count.value == 1:
                update_lock.put(1)
            # Busy-wait while too many un-sent updates have accumulated.
            while it_count.value >= Stale_Threshold:
                pass
            # Adopt the freshly received global model when available.
            if receive_end.value:
                receive_end.value = False
                for idx, param in enumerate(model.parameters()):
                    param.data = model_cache[idx]  # without local update
                    # param.data = model_cache[idx] - global_update[idx]  # with local update

            batch_end_time = time.time()
            batch_interval += batch_end_time - batch_start_time
            batch_comp_interval += batch_computation_time - batch_start_time
            osp_logs.write(
                str(batch_end_time - batch_start_time) + "\t" +
                str(batch_computation_time - batch_start_time) + "\n")
            osp_logs.flush()

        print('Rank {}, Epoch {}, Loss:{}'.format(rank, epoch, loss.data.item()))
        e_time = time.time()
        # test after training
        #test_loss, acc = test_model(rank, model, test_data, criterion=criterion)
        acc = 0.0
        # BUG FIX: divide by the batch COUNT (batch_idx + 1), not the last
        # index — the original raised ZeroDivisionError with a single batch
        # and over-reported the averages otherwise.
        batch_interval /= batch_idx + 1
        batch_comp_interval /= batch_idx + 1
        logs = torch.tensor([
            acc, batch_interval, batch_comp_interval,
            batch_communication_interval.value, stale_in_iteration.value
        ])
        time_logs.write(str(logs) + '\n')
        time_logs.flush()
        # dist.gather(tensor=logs, dst = 0, group = group)
    time_logs.close()
    sender_td.join()
def run(size, rank):
    """Worker training loop over MNIST (IID or per-rank non-IID shards).

    Each epoch: fetch the aggregated global model, train locally,
    SUM-reduce parameters to rank 0, then report loss and test accuracy.
    Relies on module globals: model, Mnist, Mnist_noniid, LR, MAX_EPOCH,
    IID, get_new_model, dist.
    """
    modell = model.CNN()
    # modell = model.AlexNet()
    optimizer = torch.optim.Adam(modell.parameters(), lr=LR)
    loss_func = torch.nn.CrossEntropyLoss()
    # size = dist.get_world_size()
    # rank = dist.get_rank()

    if IID:
        train_loader = Mnist().get_train_data()
        test_x, test_y = _prepare_test_tensors(Mnist().get_test_data())
    elif 1 <= rank <= 5:
        # Deduplicated from five copy-pasted branches: worker rank n loads
        # its own non-IID shard via get_train_data{n} / get_test_data{n}.
        train_loader = getattr(Mnist_noniid(), 'get_train_data{}'.format(rank))()
        test_x, test_y = _prepare_test_tensors(
            getattr(Mnist_noniid(), 'get_test_data{}'.format(rank))())
    # NOTE(review): as in the original, rank 0 (or rank > 5) in non-IID mode
    # leaves train_loader/test_x undefined and fails below — confirm rank 0
    # is handled elsewhere (e.g. as the server).

    group = dist.new_group(list(range(size)))
    for epoch in range(MAX_EPOCH):
        # Fetch the aggregated global model for this epoch.
        modell = get_new_model(modell)
        # current_model = copy.deepcopy(modell)
        for step, (b_x, b_y) in enumerate(train_loader):
            output = modell(b_x)[0]
            loss = loss_func(output, b_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Ship locally trained parameters to rank 0 (summed there).
        for param in modell.parameters():
            dist.reduce(param, dst=0, op=dist.reduce_op.SUM, group=group)
        # Evaluate on this rank's test split; model returns (logits, features).
        test_output, last_layer = modell(test_x)
        pred_y = torch.max(test_output, 1)[1].data.numpy()
        accuracy = float(
            (pred_y == test_y.data.numpy()).astype(int).sum()) / float(
                test_y.size(0))
        print('Epoch: ', epoch, ' Rank: ', rank,
              '| train loss: %.4f' % loss.data.numpy(),
              '| test accuracy: %.2f' % accuracy)


def _prepare_test_tensors(test_data):
    """Return (test_x, test_y) from a raw MNIST test dataset.

    Adds a channel dim and scales to floats in [0, 1]:
    shape from (2000, 28, 28) to (2000, 1, 28, 28).
    """
    test_x = torch.unsqueeze(test_data.test_data, dim=1).type(
        torch.FloatTensor) / 255.
    return test_x, test_data.test_labels
def run(rank, model, train_pics, train_bsz, g_lr):
    """Server loop for the asynchronous scheme: send the initial model and
    learning rate, then per batch gather worker losses/iteration counts and
    updates, apply the averaged update, scatter the new model back, and
    decay the learning rate by effective epoch.

    Args:
        rank: this (server) process's rank.
        model: the global model, updated in place.
        train_pics: total number of training images (sets batches/epoch).
        train_bsz: per-worker batch size.
        g_lr: queue used to push learning-rate updates to the workers.
    """
    workers = [v + 1 for v in range(args.workers_num)]
    print(workers)
    # BUG FIX: list.append() returns None, so the original
    # `_group = [w for w in workers].append(rank)` set _group = None and
    # new_group silently spanned all ranks instead of workers + server.
    _group = list(workers) + [rank]
    group = dist.new_group(_group)

    # Broadcast initial parameters: every scatter slot holds the server copy.
    for p in model.parameters():
        scatter_p_list = [p.data for _ in range(len(workers) + 1)]
        dist.scatter(tensor=p.data, scatter_list=scatter_p_list, group=group)

    # initialize learning rate
    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        global_lr = 0.1
    else:
        global_lr = 0.01
    for w in workers:
        g_lr.put(global_lr)
    print('Model Sent Finished!')
    print('Begin!')

    epoch_train_loss = 0.0
    trainloss_file = './trainloss' + args.model + '.txt'
    log_file = './log' + args.model + '.txt'
    if os.path.isfile(trainloss_file):
        os.remove(trainloss_file)
    if os.path.isfile(log_file):
        os.remove(log_file)
    f_trainloss = open(trainloss_file, 'a')
    f_log = open(log_file, 'a')

    # One dummy entry per global batch in an epoch (scaled by data_ratio).
    tmp = [(0, 0) for _ in range(
        int(
            math.ceil(
                int(train_pics * args.data_ratio) /
                (len(workers) * train_bsz))))]

    s_time = time.time()
    epoch_time = s_time
    total_iteration_time = 0
    iteration_times_count_epoch = 0
    iteration_times_epoch = len(tmp) * len(workers)
    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    else:
        decay_period = 100
    global_clock = 0
    epoch_clock = 0
    real_epoch = 0
    for epoch in range(args.epochs):
        for batch_idx, (_, _) in enumerate(tmp):
            batch_start_time = time.time()
            # receive the list of train loss and local iteration count
            # from workers; each entry is [summed_loss, iteration_count].
            loss_it_list = [
                torch.tensor([0.0, 0.0]) for _ in range(len(workers) + 1)
            ]
            dist.gather(tensor=torch.tensor([0.0, 1.0]),
                        gather_list=loss_it_list,
                        group=group)
            loss_avg = [loss_it[0] / loss_it[1] for loss_it in loss_it_list]
            epoch_train_loss += sum(loss_avg).item() / len(workers)
            iteration_times = sum(loss_it_list)[1]
            total_iteration_time += iteration_times
            iteration_times_count_epoch += iteration_times

            # receive the global update from each worker and apply the mean
            for update_idx, param in enumerate(model.parameters()):
                tensor = torch.zeros_like(param.data)
                gather_list = [
                    torch.zeros_like(param.data)
                    for _ in range(len(workers) + 1)
                ]
                dist.gather(tensor=tensor, gather_list=gather_list, group=group)
                # here we only use average temporally for simplicity
                tensor = sum(gather_list) / len(workers)
                param.data -= tensor

            # send the updated model back to workers
            for param_idx, param in enumerate(model.parameters()):
                scatter_list = [param.data for _ in range(len(workers) + 1)]
                dist.scatter(tensor=torch.zeros_like(param.data),
                             scatter_list=scatter_list,
                             group=group)

            global_clock += 1
            epoch_clock += 1

            # Decay the learning rate once the collective iteration count
            # crosses an effective-epoch boundary (never intra-epoch).
            temp_epoch = int(total_iteration_time / iteration_times_epoch) + 1
            if temp_epoch > real_epoch:
                real_epoch = temp_epoch
                if real_epoch % decay_period == 0:
                    global_lr *= 0.1
                    for w in workers:
                        g_lr.put(global_lr)
                    print('LR Decreased! Now: {}'.format(global_lr))

            # evaluate the time of each iteration
            batch_end_time = time.time()
            batch_time_interval = batch_end_time - batch_start_time
            f_log.write(
                str(batch_time_interval) + "\t" + str(iteration_times) + "\n")
            f_log.flush()

            if iteration_times_count_epoch >= iteration_times_epoch:
                e_epoch_time = time.time()
                iteration_times_count_epoch -= iteration_times_epoch
                # test_acc, batch_interval, batch_comp_interval,
                # batch_comm_interval (worker gather disabled — stays zero).
                logs = torch.tensor([0.0, 0.0, 0.0, 0.0])
                logs_list = [
                    torch.zeros_like(logs) for _ in range(len(workers) + 1)
                ]
                # dist.gather(tensor=logs, gather_list=logs_list, group=group)
                test_acc, batch_interval, batch_comp_interval, batch_comm_interval = zip(
                    *logs_list)
                test_acc = sum(test_acc) / len(workers)
                batch_interval = sum(batch_interval) / len(workers)
                batch_comp_interval = sum(batch_comp_interval) / len(workers)
                batch_comm_interval = sum(batch_comm_interval) / len(workers)
                f_trainloss.write(
                    str(args.this_rank) + "\t" +
                    str(epoch_train_loss / float(epoch_clock)) + "\t" +
                    str(0) + "\t" +
                    str(e_epoch_time - epoch_time) + "\t" +
                    str(e_epoch_time - s_time) + "\t" +
                    str(int(total_iteration_time / iteration_times_epoch)) +
                    "\t" + str(test_acc.item()) + "\t" +
                    str(batch_interval.item()) + "\t" +
                    str(batch_comp_interval.item()) + "\t" +
                    str(batch_comm_interval.item()) + "\t" +
                    str(global_clock) + '\n')
                print("total_iteration_time:{}, iteration_times:{}".format(
                    total_iteration_time, iteration_times_epoch))
                f_trainloss.flush()
                epoch_time = e_epoch_time
                epoch_train_loss = 0.0
                epoch_clock = 0
        print('Done Epoch {}/{}!'.format(epoch + 1, args.epochs))
    f_trainloss.close()