def _test_scatter_helper(self, group, group_id, rank):
    for dest in group:
        tensor = _build_tensor(dest + 1, -1)
        expected_tensor = _build_tensor(dest + 1, rank)
        tensors = (
            [_build_tensor(dest + 1, i) for i in group] if rank == dest else []
        )
        dist.scatter(tensor, src=dest, scatter_list=tensors, group=group_id)
        self.assertEqual(tensor, expected_tensor)

    self._barrier()
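# `_build_tensor` is not defined in this excerpt. A minimal sketch of the
# helper the test above appears to assume: a rank-3 cube of side `size`,
# filled with a constant value (defaulting to `size` itself).
import torch

def _build_tensor(size, value=None):
    if value is None:
        value = size
    return torch.FloatTensor(size, size, size).fill_(value)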
def sender(model_cache, global_update, local_update, it_count, loss_t,
           update_lock, data_lock, group, receive_end,
           batch_communication_interval, stale_in_iteration):
    comm_count = 0
    while True:
        # `update_lock` is a queue controlled by the compute module.
        # Block until at least one gradient has been generated, so the
        # sender has something to send.
        update_lock.get()

        # Note: the data lock must be held before touching local_update.
        # Copy the local update into the global update, then zero it out.
        data_lock.acquire()
        loss = loss_t.data
        loss_t.data = torch.tensor(0.)
        it_times = it_count.value
        it_count.value = 0.
        for idx, update in enumerate(global_update):
            update.data = local_update[idx].data
            local_update[idx].data = torch.zeros_like(update.data)
        data_lock.release()

        comm_s = time.time()
        loss_it = torch.tensor([float(loss), it_times])
        dist.gather(tensor=loss_it, dst=0, group=group)
        for idx, update in enumerate(global_update):
            dist.gather(tensor=update.data, dst=0, group=group)
        for idx, param in enumerate(model_cache):
            dist.scatter(tensor=param.data, src=0, group=group)
        receive_end.value = True
        comm_e = time.time()

        comm_count += 1
        # Running averages of per-iteration communication time and staleness.
        batch_communication_interval.value = (
            batch_communication_interval.value * (comm_count - 1)
            + (comm_e - comm_s)) / comm_count
        stale_in_iteration.value = (
            stale_in_iteration.value * (comm_count - 1) + it_times) / comm_count
    return
def run(rank, workers, model, save_path, train_data, test_data):
    # Get the initial model from the server.
    # Note: list.append() returns None, so the original
    # `[w for w in workers].append(0)` passed None to new_group;
    # build the rank list explicitly instead.
    _group = list(workers) + [0]
    group = dist.new_group(_group)

    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model received successfully!')

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        optimizer = MySGD(model.parameters(), lr=0.1)
    else:
        optimizer = MySGD(model.parameters(), lr=0.01)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    else:
        decay_period = 100

    print('Begin!')
    time_logs = open("./record" + str(rank), 'w')
    for epoch in range(args.epochs):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        batch_comm_interval = 0.0
        s_time = time.time()
        model.train()

        # Decay the learning rate at the specified epochs.
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))

        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()
            batch_comp_time = time.time()

            # Synchronization: send the batch train loss first, then gather
            # each layer's update and scatter back the new weights.
            dist.gather(loss.data, dst=0, group=group)
            for idx, param in enumerate(model.parameters()):
                dist.gather(tensor=delta_ws[idx], dst=0, group=group)
                recv = torch.zeros_like(delta_ws[idx])
                dist.scatter(tensor=recv, src=0, group=group)
                param.data = recv

            epoch_train_loss += loss.data.item()
            batch_end_time = time.time()
            batch_interval += batch_end_time - batch_start_time
            batch_comp_interval += batch_comp_time - batch_start_time
            batch_comm_interval += batch_end_time - batch_comp_time

            logs = torch.tensor([
                0.0,
                batch_interval / (batch_idx + 1),
                batch_comp_interval / (batch_idx + 1),
                batch_comm_interval / (batch_idx + 1)
            ])
            time_logs.write(str(logs) + '\n')
            time_logs.flush()

        print('Rank {}, Epoch {}, Loss: {}'.format(rank, epoch, loss.data.item()))
        e_time = time.time()
        # epoch_train_loss /= len(train_data)
        # epoch_train_loss = format(epoch_train_loss, '.4f')

        # Test the model after the epoch (disabled):
        # test_loss, acc = test_model(rank, model, test_data, criterion=criterion)
        acc = 0.0
        # batch_idx is the last index, so the batch count is batch_idx + 1.
        batch_interval /= batch_idx + 1
        batch_comp_interval /= batch_idx + 1
        batch_comm_interval /= batch_idx + 1
        logs = torch.tensor(
            [acc, batch_interval, batch_comp_interval, batch_comm_interval])
        time_logs.write(str(logs) + '\n')
        time_logs.flush()
        # dist.gather(tensor=logs, dst=0, group=group)
    time_logs.close()
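# `MySGD` is used above but not defined in this excerpt. A minimal sketch
# consistent with its use here (an SGD subclass whose `get_delta_w` returns
# the lr-scaled gradients so they can be shipped to the server instead of
# being applied locally) might look like this:
import torch
from torch.optim import SGD

class MySGD(SGD):
    def get_delta_w(self):
        # One lr * grad tensor per parameter, in parameter order.
        delta_ws = []
        for group in self.param_groups:
            for param in group['params']:
                if param.grad is None:
                    delta_ws.append(torch.zeros_like(param.data))
                else:
                    delta_ws.append(group['lr'] * param.grad.data)
        return delta_ws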
def run(rank, model, train_pics, train_bsz):
    workers = [v + 1 for v in range(args.workers_num)]
    # list.append() returns None; build the rank list explicitly.
    _group = list(workers) + [rank]
    group = dist.new_group(_group)

    for p in model.parameters():
        scatter_p_list = [p.data for _ in range(len(workers) + 1)]
        dist.scatter(tensor=p.data, scatter_list=scatter_p_list, group=group)
    print('Model Sent Finished!')

    print('Begin!')
    trainloss_file = './trainloss' + args.model + '.txt'
    if os.path.isfile(trainloss_file):
        os.remove(trainloss_file)
    f_trainloss = open(trainloss_file, 'a')

    # Dummy iterable with one entry per global batch.
    tmp = [
        (0, 0)
        for _ in range(int(math.ceil(train_pics / (len(workers) * train_bsz))))
    ]
    g_list = [[torch.zeros_like(param.data) for param in model.parameters()]
              for _ in range(len(workers) + 1)]

    s_time = time.time()
    epoch_time = s_time
    global_clock = 0
    epoch_train_loss = 0.0
    sparsification_ratio = 0.0
    for epoch in range(args.epochs):
        for batch_idx, (_, _) in enumerate(tmp):
            # Receive the train losses from the workers.
            info_list = [torch.tensor([0.0]) for _ in range(len(workers) + 1)]
            dist.gather(tensor=torch.tensor([0.0]),
                        gather_list=info_list,
                        group=group)
            batch_loss = sum(info_list).item() / len(workers)
            epoch_train_loss += batch_loss

            # Receive the sparsification ratios from the workers.
            sparsification_ratio_list = [
                torch.tensor([0.0]) for _ in range(len(workers) + 1)
            ]
            dist.gather(tensor=torch.tensor([0.0]),
                        gather_list=sparsification_ratio_list,
                        group=group)
            batch_ratio = sum(sparsification_ratio_list).item() / len(workers)
            sparsification_ratio += batch_ratio

            param_idx = 0
            g_square_sum = torch.tensor([0.])
            for layer_idx, param in enumerate(model.parameters()):
                tensor = torch.zeros_like(param.data)
                # Fixed: every tensor in gather_list must be a fresh object,
                # otherwise the gather misbehaves.
                gather_list = [
                    torch.zeros_like(param.data)
                    for _ in range(len(workers) + 1)
                ]
                dist.gather(tensor=tensor, gather_list=gather_list, group=group)
                tensor = sum(gather_list) / len(workers)
                param.data -= tensor
                scatter_list = [param.data for _ in range(len(workers) + 1)]
                dist.scatter(tensor=tensor,
                             scatter_list=scatter_list,
                             group=group)
                param_idx += 1
            global_clock += 1

            f_trainloss.write(
                str(args.this_rank) + "\t" +
                str(epoch_train_loss / float(batch_idx + 1)) + "\t" +
                str(batch_loss) + "\t" + str(0) + "\t" + str(0) + "\t" +
                str(epoch) + "\t" + str(0) + "\t" + str(0) + "\t" + str(0) +
                "\t" + str(batch_ratio) + "\t" +
                str(sparsification_ratio / float(batch_idx + 1)) + "\t" +
                str(global_clock) + '\n')
            # f_trainloss.flush()
            # print('Done {}/{}!'.format(batch_idx, len(tmp)))

        print('Done Epoch {}/{}!'.format(epoch + 1, args.epochs))
        e_epoch_time = time.time()

        # test_acc, batch_interval, batch_comp_interval, batch_comm_interval
        logs = torch.tensor([0.0, 0.0, 0.0, 0.0])
        logs_list = [torch.zeros_like(logs) for _ in range(len(workers) + 1)]
        # dist.gather(tensor=logs, gather_list=logs_list, group=group)
        test_acc, batch_interval, batch_comp_interval, batch_comm_interval = zip(
            *logs_list)
        test_acc = sum(test_acc) / len(workers)
        batch_interval = sum(batch_interval) / len(workers)
        batch_comp_interval = sum(batch_comp_interval) / len(workers)
        batch_comm_interval = sum(batch_comm_interval) / len(workers)
        # f_trainloss.write(str(args.this_rank) +
        #                   "\t" + str(epoch_train_loss / float(batch_idx + 1)) +
        #                   "\t" + str(0) +
        #                   "\t" + str(e_epoch_time - epoch_time) +
        #                   "\t" + str(e_epoch_time - s_time) +
        #                   "\t" + str(epoch) +
        #                   "\t" + str(test_acc.item()) +
        #                   "\t" + str(batch_interval.item()) +
        #                   "\t" + str(batch_comp_interval.item()) +
        #                   "\t" + str(batch_comm_interval.item()) +
        #                   "\t" + str(sparsification_ratio / float(batch_idx + 1)) +
        #                   "\t" + str(global_clock) +
        #                   '\n')
        f_trainloss.flush()

        epoch_time = e_epoch_time
        epoch_train_loss = 0.0
        sparsification_ratio = 0.0

        # Checkpoint the model every 2 epochs.
        if (epoch + 1) % 2 == 0:
            if not os.path.exists('model_state'):
                os.makedirs('model_state')
            torch.save(
                model.state_dict(),
                'model_state' + '/' + args.model + '_' + str(epoch + 1) + '.pkl')
    f_trainloss.close()
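# Hypothetical reload of a checkpoint written by the loop above (it reuses
# the same 'model_state/<model>_<epoch>.pkl' path pattern passed to torch.save):
def load_checkpoint(model, model_name, epoch):
    state = torch.load('model_state/' + model_name + '_' + str(epoch) + '.pkl')
    model.load_state_dict(state)
    return model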
def run(rank, workers, model, save_path, train_data, test_data):
    # Get the initial model parameters sent by the parameter server.
    _group = list(workers) + [0]  # list.append() returns None; build explicitly
    group = dist.new_group(_group)

    param_num = 0
    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        param_num += torch.numel(tmp_p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model received successfully!')

    # Report the achievable compression ratio to the server.
    compression_num = int(param_num * args.ratio)
    compression_num = compression_num if compression_num > 0 else 1
    dist.gather(torch.tensor([compression_num / param_num]), dst=0, group=group)

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        learning_rate = 0.1
    else:
        learning_rate = args.lr
    optimizer = MySGD(model.parameters(), lr=learning_rate)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    elif args.model in ['Abalone', 'Bodyfat', 'Housing']:
        criterion = torch.nn.MSELoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    elif args.model in [
            'LROnMnist', 'LROnCifar10', 'LROnCifar100', 'Abalone', 'Bodyfat',
            'Housing'
    ]:
        decay_period = 1000000  # learning rate is constant for LR (convex) models
    else:
        decay_period = 100

    print('Begin!')
    global_clock = 0
    g_remain = [torch.zeros_like(param.data) for param in model.parameters()]
    time_logs = open("./record" + str(rank), 'w')
    for epoch in range(args.epochs):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        batch_comm_interval = 0.0
        s_time = time.time()
        model.train()

        # Decay the learning rate at the specified epochs.
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))

        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()
            # Sparsify the update: upload the large changes, keep the rest local.
            g_remain, g_large_change = get_upload(g_remain, delta_ws,
                                                  args.ratio, args.isCompensate)
            batch_comp_time = time.time()

            # Synchronization: send the batch train loss first.
            dist.gather(loss.data, dst=0, group=group)
            for idx, param in enumerate(model.parameters()):
                dist.gather(tensor=g_large_change[idx], dst=0, group=group)
                recv = torch.zeros_like(delta_ws[idx])
                dist.scatter(tensor=recv, src=0, group=group)
                param.data = recv

            epoch_train_loss += loss.data.item()
            batch_end_time = time.time()
            batch_interval += batch_end_time - batch_start_time
            batch_comp_interval += batch_comp_time - batch_start_time
            batch_comm_interval += batch_end_time - batch_comp_time

            logs = torch.tensor([
                0.0,
                batch_interval / (batch_idx + 1),
                batch_comp_interval / (batch_idx + 1),
                batch_comm_interval / (batch_idx + 1)
            ])
            time_logs.write(str(logs) + '\n')
            time_logs.flush()

        print('Rank {}, Epoch {}, Loss: {}'.format(rank, epoch, loss.data.item()))
        e_time = time.time()
        # epoch_train_loss /= len(train_data)
        # epoch_train_loss = format(epoch_train_loss, '.4f')

        # Test the model after the epoch (disabled):
        # test_loss, acc = test_model(rank, model, test_data, criterion=criterion)
        acc = 0.0
        batch_interval /= batch_idx + 1
        batch_comp_interval /= batch_idx + 1
        batch_comm_interval /= batch_idx + 1
        logs = torch.tensor(
            [acc, batch_interval, batch_comp_interval, batch_comm_interval])
        time_logs.write(str(logs) + '\n')
        time_logs.flush()
        # dist.gather(tensor=logs, dst=0, group=group)
    time_logs.close()
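# `get_upload` is not defined in this excerpt. A hypothetical sketch of a
# top-k sparsifier matching its call signature: keep the largest-magnitude
# `ratio` fraction of each layer's update for upload, and (if `is_compensate`)
# fold the residual into `g_remain` so it is carried to the next iteration.
import torch

def get_upload(g_remain, delta_ws, ratio, is_compensate):
    g_large_change = []
    for idx, delta in enumerate(delta_ws):
        # Accumulated update: residual plus fresh delta when compensating.
        acc = g_remain[idx] + delta if is_compensate else delta.clone()
        flat = acc.view(-1)
        k = max(1, int(flat.numel() * ratio))
        _, top_idx = flat.abs().topk(k)
        upload = torch.zeros_like(flat)
        upload[top_idx] = flat[top_idx]
        g_remain[idx] = (flat - upload).view_as(acc)  # residual stays local
        g_large_change.append(upload.view_as(acc))
    return g_remain, g_large_change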
def run(rank, workers, model, save_path, train_data, test_data, global_lr):
    # Get the initial model from the server.
    print(workers)
    _group = list(workers) + [0]  # list.append() returns None; build explicitly
    group = dist.new_group(_group)

    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model received successfully!')

    temp_lr = global_lr.get()
    # Both branches of the original model check used the queue-provided LR.
    optimizer = MySGD(model.parameters(), lr=temp_lr)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    print('Begin!')
    # Shared state handed to the sender thread.
    model_cache = [p.data + 0.0 for p in model.parameters()]
    global_update = [torch.zeros_like(p) for p in model.parameters()]
    local_update = [torch.zeros_like(p) for p in model.parameters()]
    it_count = Value(c_float, 0.)  # updates generated in the current iteration
    data_lock = Lock()
    update_lock = Queue()
    update_lock.put(1)
    loss_t = torch.tensor(0.0)
    receive_end = Value(c_bool, False)
    batch_communication_interval = Value(c_float, 0.0)
    stale_in_iteration = Value(c_float, 0.)
    sender_td = Thread(target=sender,
                       args=(model_cache, global_update, local_update,
                             it_count, loss_t, update_lock, data_lock, group,
                             receive_end, batch_communication_interval,
                             stale_in_iteration),
                       daemon=True)
    sender_td.start()

    time_logs = open("./record" + str(rank), 'w')
    osp_logs = open("./log" + str(rank), 'w')
    Stale_Threshold = args.stale_threshold
    for epoch in range(args.epochs):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        s_time = time.time()
        model.train()

        # Decay the learning rate at the specified epoch. The rate is decided
        # on the server because local workers and the server update at
        # different speeds.
        if not global_lr.empty():
            g_lr = global_lr.get()
            if args.model == 'AlexNet':
                for param_group in optimizer.param_groups:
                    param_group['lr'] = g_lr
                    print('LR Decreased! Now: {}'.format(param_group['lr']))

        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()
            optimizer.step()

            # Aggregate the local update under the data lock.
            data_lock.acquire()
            loss_t.data += loss.data
            it_count.value += 1
            for g_idx, update in enumerate(local_update):
                update.data += delta_ws[g_idx].data
            data_lock.release()
            batch_computation_time = time.time()

            # Release the sender once the local update holds at least one gradient.
            if it_count.value == 1:
                update_lock.put(1)

            # Busy-wait while the staleness bound is exceeded.
            while it_count.value >= Stale_Threshold:
                pass

            if receive_end.value:
                receive_end.value = False
                for idx, param in enumerate(model.parameters()):
                    param.data = model_cache[idx]  # without local update
                    # param.data = model_cache[idx] - global_update[idx]  # with local update

            batch_end_time = time.time()
            batch_interval += batch_end_time - batch_start_time
            batch_comp_interval += batch_computation_time - batch_start_time
            osp_logs.write(
                str(batch_end_time - batch_start_time) + "\t" +
                str(batch_computation_time - batch_start_time) + "\n")
            osp_logs.flush()

        print('Rank {}, Epoch {}, Loss: {}'.format(rank, epoch, loss.data.item()))
        e_time = time.time()

        # Test the model after the epoch (disabled):
        # test_loss, acc = test_model(rank, model, test_data, criterion=criterion)
        acc = 0.0
        batch_interval /= batch_idx + 1
        batch_comp_interval /= batch_idx + 1
        logs = torch.tensor([
            acc, batch_interval, batch_comp_interval,
            batch_communication_interval.value, stale_in_iteration.value
        ])
        time_logs.write(str(logs) + '\n')
        time_logs.flush()
        # dist.gather(tensor=logs, dst=0, group=group)
    time_logs.close()
    sender_td.join()
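# The staleness gate in the worker loop above hard-spins on `it_count`. A
# gentler (hypothetical) variant sleeps between checks, trading a little
# latency for far less CPU burn while the sender thread catches up:
import time

def wait_for_staleness(it_count, stale_threshold, poll_s=0.001):
    # Poll the shared counter instead of spinning at full speed.
    while it_count.value >= stale_threshold:
        time.sleep(poll_s)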
def run(rank, model, train_pics, train_bsz, g_lr):
    workers = [v + 1 for v in range(args.workers_num)]
    print(workers)
    _group = list(workers) + [rank]  # list.append() returns None; build explicitly
    group = dist.new_group(_group)

    for p in model.parameters():
        scatter_p_list = [p.data for _ in range(len(workers) + 1)]
        dist.scatter(tensor=p.data, scatter_list=scatter_p_list, group=group)

    # Initialize the learning rate and push it to every worker.
    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        global_lr = 0.1
    else:
        global_lr = 0.01
    for w in workers:
        g_lr.put(global_lr)
    print('Model Sent Finished!')

    print('Begin!')
    epoch_train_loss = 0.0
    trainloss_file = './trainloss' + args.model + '.txt'
    log_file = './log' + args.model + '.txt'
    if os.path.isfile(trainloss_file):
        os.remove(trainloss_file)
    if os.path.isfile(log_file):
        os.remove(log_file)
    f_trainloss = open(trainloss_file, 'a')
    f_log = open(log_file, 'a')

    # Dummy iterable with one entry per global batch.
    tmp = [(0, 0) for _ in range(
        int(
            math.ceil(
                int(train_pics * args.data_ratio) /
                (len(workers) * train_bsz))))]

    s_time = time.time()
    epoch_time = s_time
    total_iteration_time = 0
    iteration_times_count_epoch = 0
    iteration_times_epoch = len(tmp) * len(workers)
    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    else:
        decay_period = 100

    global_clock = 0
    epoch_clock = 0
    real_epoch = 0
    for epoch in range(args.epochs):
        for batch_idx, (_, _) in enumerate(tmp):
            batch_start_time = time.time()

            # Receive the train loss and local iteration count from each worker.
            loss_it_list = [
                torch.tensor([0.0, 0.0]) for _ in range(len(workers) + 1)
            ]
            dist.gather(tensor=torch.tensor([0.0, 1.0]),
                        gather_list=loss_it_list,
                        group=group)
            loss_avg = [loss_it[0] / loss_it[1] for loss_it in loss_it_list]
            epoch_train_loss += sum(loss_avg).item() / len(workers)
            iteration_times = sum(loss_it_list)[1]
            total_iteration_time += iteration_times
            iteration_times_count_epoch += iteration_times

            # Receive the global update from each worker.
            for update_idx, param in enumerate(model.parameters()):
                tensor = torch.zeros_like(param.data)
                gather_list = [
                    torch.zeros_like(param.data)
                    for _ in range(len(workers) + 1)
                ]
                dist.gather(tensor=tensor, gather_list=gather_list, group=group)
                # A plain average is used here, for simplicity.
                tensor = sum(gather_list) / len(workers)
                param.data -= tensor

            # Send the updated model back to the workers.
            for param_idx, param in enumerate(model.parameters()):
                scatter_list = [param.data for _ in range(len(workers) + 1)]
                dist.scatter(tensor=torch.zeros_like(param.data),
                             scatter_list=scatter_list,
                             group=group)

            global_clock += 1
            epoch_clock += 1

            # Decay the learning rate at the specific epoch boundary,
            # never within an epoch.
            temp_epoch = int(total_iteration_time / iteration_times_epoch) + 1
            if temp_epoch > real_epoch:
                real_epoch = temp_epoch
                if real_epoch % decay_period == 0:
                    global_lr *= 0.1
                    for w in workers:
                        g_lr.put(global_lr)
                    print('LR Decreased! Now: {}'.format(global_lr))

            # Record the wall-clock time of this iteration.
            batch_end_time = time.time()
            batch_time_interval = batch_end_time - batch_start_time
            f_log.write(
                str(batch_time_interval) + "\t" + str(iteration_times) + "\n")
            f_log.flush()

            if iteration_times_count_epoch >= iteration_times_epoch:
                e_epoch_time = time.time()
                iteration_times_count_epoch -= iteration_times_epoch

                # test_acc, batch_interval, batch_comp_interval, batch_comm_interval
                logs = torch.tensor([0.0, 0.0, 0.0, 0.0])
                logs_list = [
                    torch.zeros_like(logs) for _ in range(len(workers) + 1)
                ]
                # dist.gather(tensor=logs, gather_list=logs_list, group=group)
                test_acc, batch_interval, batch_comp_interval, batch_comm_interval = zip(
                    *logs_list)
                test_acc = sum(test_acc) / len(workers)
                batch_interval = sum(batch_interval) / len(workers)
                batch_comp_interval = sum(batch_comp_interval) / len(workers)
                batch_comm_interval = sum(batch_comm_interval) / len(workers)

                f_trainloss.write(
                    str(args.this_rank) + "\t" +
                    str(epoch_train_loss / float(epoch_clock)) + "\t" +
                    str(0) + "\t" + str(e_epoch_time - epoch_time) + "\t" +
                    str(e_epoch_time - s_time) + "\t" +
                    str(int(total_iteration_time / iteration_times_epoch)) +
                    "\t" + str(test_acc.item()) + "\t" +
                    str(batch_interval.item()) + "\t" +
                    str(batch_comp_interval.item()) + "\t" +
                    str(batch_comm_interval.item()) + "\t" +
                    str(global_clock) + '\n')
                print("total_iteration_time: {}, iteration_times: {}".format(
                    total_iteration_time, iteration_times_epoch))
                f_trainloss.flush()
                epoch_time = e_epoch_time
                epoch_train_loss = 0.0
                epoch_clock = 0

        print('Done Epoch {}/{}!'.format(epoch + 1, args.epochs))
    f_trainloss.close()
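# None of the run()/sender() snippets above initialize the default process
# group themselves. A minimal launcher sketch that each process would call
# before run(); the backend, address, and port here are assumptions:
import torch.distributed as dist

def init_processes(rank, world_size):
    # world_size counts the parameter server (rank 0) plus all workers.
    dist.init_process_group(backend='gloo',
                            init_method='tcp://127.0.0.1:23456',
                            rank=rank,
                            world_size=world_size)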