def run(rank, model, train_data, test_data, queue, param_q, stop_flag):
    # Get the initial model from the server
    while True:
        if not param_q.empty():
            param_dict = param_q.get()
            tmp = OrderedDict(
                map(lambda item: (item[0], torch.from_numpy(item[1])),
                    param_dict.items()))
            model.load_state_dict(tmp)
            break
    print('Model recved successfully!')

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        optimizer = MySGD(model.parameters(), lr=0.1)
    else:
        optimizer = MySGD(model.parameters(), lr=0.01)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    else:
        decay_period = 100

    print('Begin!')

    time_logs = open("./record" + str(rank), 'w')
    for epoch in range(int(args.epochs)):
        model.train()

        # Decay the learning rate at the specified period
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))

        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            it_start = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()
            it_comp_end = time.time()

            # noinspection PyBroadException
            try:
                if delta_ws:
                    # Push the local loss/batch-size record and the update tensors
                    queue.put({
                        rank: [loss.data.numpy(), np.array(args.train_bsz), False]
                    })
                    for delta in delta_ws:
                        dist.send(tensor=delta, dst=0)

                # Pull the refreshed global parameters back from the server
                for idx, param in enumerate(model.parameters()):
                    tmp_tensor = torch.zeros_like(param.data)
                    dist.recv(tensor=tmp_tensor, src=0)
                    param.data = tmp_tensor
                #print('Rank {}, Epoch {}, Batch {}/{}, Loss:{}'
                #      .format(rank, epoch, batch_idx, len(train_data), loss.data[0]))
            except Exception as e:
                print(str(e))
                print('Should Stop: {}!'.format(stop_flag.value))
                break

            it_comm_end = time.time()
            it_duration = it_comm_end - it_start
            it_comp_duration = it_comp_end - it_start
            it_comm_duration = it_comm_end - it_comp_end
            time_logs.write(str(it_duration) + "\t" + str(it_comp_duration) +
                            "\t" + str(it_comm_duration) + "\n")
            time_logs.flush()

        # Test the model
        print("test Model:", epoch)
        # test_model(rank, model, test_data, criterion=criterion)

        if stop_flag.value:
            break

    queue.put({rank: [[], [], True]})
    time_logs.close()
    print("Worker {} has completed epoch {}!".format(args.this_rank, epoch))
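# The workers in this file rely on a custom optimizer, MySGD, whose
# get_delta_w() returns the per-parameter update tensors of the current step so
# they can be shipped to the server instead of being applied locally. The class
# below is only a minimal sketch of that contract, assuming the update is plain
# lr * grad; the real MySGD in this codebase may add momentum, weight decay, etc.
import torch


class MySGDSketch(torch.optim.SGD):
    def get_delta_w(self):
        # Collect lr * grad for every parameter, in parameter order,
        # without modifying the parameters themselves.
        delta_ws = []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                delta_ws.append(group['lr'] * p.grad.data.clone())
        return delta_ws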
def run(rank, workers, model, save_path, train_data, test_data):
    # Get the initial model from the server
    _group = [w for w in workers] + [0]  # list.append() returns None, so build the group list directly
    group = dist.new_group(_group)

    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model recved successfully!')

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        optimizer = MySGD(model.parameters(), lr=0.1)
    else:
        optimizer = MySGD(model.parameters(), lr=0.01)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    else:
        decay_period = 100

    print('Begin!')

    time_logs = open("./record" + str(rank), 'w')
    for epoch in range(args.epochs):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        batch_comm_interval = 0.0
        s_time = time.time()
        model.train()

        # Reduce the learning rate at the specified epochs
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))

        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()
            batch_comp_time = time.time()

            # Synchronization: send the epoch train loss first
            dist.gather(loss.data, dst=0, group=group)
            # Push each parameter's update and pull back the aggregated parameter
            for idx, param in enumerate(model.parameters()):
                dist.gather(tensor=delta_ws[idx], dst=0, group=group)
                recv = torch.zeros_like(delta_ws[idx])
                dist.scatter(tensor=recv, src=0, group=group)
                param.data = recv

            epoch_train_loss += loss.data.item()
            batch_end_time = time.time()
            batch_interval += batch_end_time - batch_start_time
            batch_comp_interval += batch_comp_time - batch_start_time
            batch_comm_interval += batch_end_time - batch_comp_time

            logs = torch.tensor([
                0.0,
                batch_interval / (batch_idx + 1),
                batch_comp_interval / (batch_idx + 1),
                batch_comm_interval / (batch_idx + 1)
            ])
            time_logs.write(str(logs) + '\n')
            time_logs.flush()

        print('Rank {}, Epoch {}, Loss:{}'.format(rank, epoch, loss.data.item()))
        e_time = time.time()
        #epoch_train_loss /= len(train_data)
        #epoch_train_loss = format(epoch_train_loss, '.4f')

        # Test the model after the epoch
        #test_loss, acc = test_model(rank, model, test_data, criterion=criterion)
        acc = 0.0
        batch_interval /= batch_idx
        batch_comp_interval /= batch_idx
        batch_comm_interval /= batch_idx
        logs = torch.tensor(
            [acc, batch_interval, batch_comp_interval, batch_comm_interval])
        time_logs.write(str(logs) + '\n')
        time_logs.flush()
        #dist.gather(tensor=logs, dst = 0, group = group)

    time_logs.close()
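# Server-side (rank 0) counterpart to the gather/scatter loop above: a sketch,
# not the repo's actual parameter-server code. Assumptions: the server sits at
# rank 0 of the same `group`, its own contribution lands in slot 0 of each
# gather list, and it simply averages the workers' deltas each batch before
# scattering the refreshed parameters back. Names here are illustrative only.
import torch
import torch.distributed as dist


def server_sync_step(model, group, group_size):
    # Collect the per-worker losses (slot 0 holds the server's dummy entry).
    loss_slots = [torch.tensor(0.0) for _ in range(group_size)]
    dist.gather(torch.tensor(0.0), gather_list=loss_slots, dst=0, group=group)

    for param in model.parameters():
        # Gather one delta tensor per group member for this parameter.
        delta_slots = [torch.zeros_like(param.data) for _ in range(group_size)]
        dist.gather(torch.zeros_like(param.data), gather_list=delta_slots,
                    dst=0, group=group)
        # Average the worker deltas (skipping the server's own zero slot) and apply.
        avg_delta = sum(delta_slots[1:]) / (group_size - 1)
        param.data -= avg_delta
        # Send everyone the refreshed parameter.
        dist.scatter(torch.zeros_like(param.data),
                     scatter_list=[param.data.clone() for _ in range(group_size)],
                     src=0, group=group)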
def run(rank, workers, model, save_path, train_data, test_data, global_lr):
    # Get the initial model from the server
    print(workers)
    _group = [w for w in workers] + [0]  # list.append() returns None, so build the group list directly
    group = dist.new_group(_group)

    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model recved successfully!')

    # The initial learning rate is provided by the server
    temp_lr = global_lr.get()
    optimizer = MySGD(model.parameters(), lr=temp_lr)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    print('Begin!')

    # State shared with the background sender thread
    model_cache = [p.data + 0.0 for p in model.parameters()]
    global_update = [torch.zeros_like(p) for p in model.parameters()]
    local_update = [torch.zeros_like(p) for p in model.parameters()]
    it_count = Value(c_float, 0.)  # count of local updates accumulated in the current iteration
    data_lock = Lock()
    update_lock = Queue()
    update_lock.put(1)
    loss_t = torch.tensor(0.0)
    receive_end = Value(c_bool, False)
    batch_communication_interval = Value(c_float, 0.0)
    stale_in_iteration = Value(c_float, 0.)

    sender_td = Thread(target=sender,
                       args=(model_cache, global_update, local_update, it_count,
                             loss_t, update_lock, data_lock, group, receive_end,
                             batch_communication_interval, stale_in_iteration),
                       daemon=True)
    sender_td.start()

    time_logs = open("./record" + str(rank), 'w')
    osp_logs = open("./log" + str(rank), 'w')
    Stale_Threshold = args.stale_threshold
    for epoch in range(args.epochs):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        s_time = time.time()
        model.train()

        # Decay the learning rate at the specified epoch. The decayed rate comes
        # from the server because the local worker and the server update at
        # different speeds.
        if not global_lr.empty():
            g_lr = global_lr.get()
            if args.model == 'AlexNet':
                for param_group in optimizer.param_groups:
                    param_group['lr'] = g_lr
                    print('LR Decreased! Now: {}'.format(param_group['lr']))

        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()
            optimizer.step()

            # Aggregate the local update
            data_lock.acquire()
            # Aggregate the loss
            loss_t.data += loss.data
            it_count.value += 1
            for g_idx, update in enumerate(local_update):
                update.data += delta_ws[g_idx].data
            data_lock.release()
            batch_computation_time = time.time()

            # Open the lock once the local update has at least one gradient
            if it_count.value == 1:
                update_lock.put(1)

            # Block until the sender has drained the update (staleness bound)
            while it_count.value >= Stale_Threshold:
                pass

            if receive_end.value:
                receive_end.value = False
                for idx, param in enumerate(model.parameters()):
                    param.data = model_cache[idx]  # without local update
                    # param.data = model_cache[idx] - global_update[idx]  # with local update

            batch_end_time = time.time()
            batch_interval += batch_end_time - batch_start_time
            batch_comp_interval += batch_computation_time - batch_start_time
            osp_logs.write(str(batch_end_time - batch_start_time) + "\t" +
                           str(batch_computation_time - batch_start_time) + "\n")
            osp_logs.flush()

        print('Rank {}, Epoch {}, Loss:{}'.format(rank, epoch, loss.data.item()))
        e_time = time.time()

        # Test the model after the epoch
        #test_loss, acc = test_model(rank, model, test_data, criterion=criterion)
        acc = 0.0
        batch_interval /= batch_idx
        batch_comp_interval /= batch_idx
        logs = torch.tensor([
            acc, batch_interval, batch_comp_interval,
            batch_communication_interval.value, stale_in_iteration.value
        ])
        time_logs.write(str(logs) + '\n')
        time_logs.flush()
        # dist.gather(tensor=logs, dst = 0, group = group)

    time_logs.close()
    sender_td.join()
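# A sketch of the background `sender` thread used by the asynchronous worker
# above. Its real implementation is not shown in this file; this version only
# illustrates the contract implied by the shared state: wait for at least one
# accumulated local update (via `update_lock`), push the summed update and loss
# to the server with gather, pull the new global model into `model_cache` with
# scatter, and raise `receive_end` so the training loop adopts the received
# weights. All internals below are assumptions.
import time
import torch
import torch.distributed as dist


def sender(model_cache, global_update, local_update, it_count, loss_t,
           update_lock, data_lock, group, receive_end,
           batch_communication_interval, stale_in_iteration):
    while True:
        update_lock.get()          # wait until the trainer has produced >= 1 update
        comm_start = time.time()

        with data_lock:            # snapshot and reset the shared accumulators
            staleness = it_count.value
            push_loss = loss_t.data.clone()
            push_update = [u.data.clone() for u in local_update]
            for u in local_update:
                u.data.zero_()
            loss_t.data.zero_()
            it_count.value = 0

        # Push the loss and the accumulated local update, then pull the new model.
        dist.gather(push_loss, dst=0, group=group)
        for idx, cached in enumerate(model_cache):
            dist.gather(tensor=push_update[idx], dst=0, group=group)
            recv = torch.zeros_like(cached)
            dist.scatter(tensor=recv, src=0, group=group)
            global_update[idx].data = cached.data - recv
            cached.data = recv
        receive_end.value = True

        # Bookkeeping consumed by the logs written in the training loop.
        batch_communication_interval.value = time.time() - comm_start
        stale_in_iteration.value = staleness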
def run(rank, workers, model, save_path, train_data, test_data):
    # Get the initial model parameters from the parameter server
    _group = [w for w in workers] + [0]  # list.append() returns None, so build the group list directly
    group = dist.new_group(_group)

    param_num = 0
    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        param_num += torch.numel(tmp_p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model recved successfully!')

    # Number of parameters uploaded per round after compression
    compression_num = int(param_num * args.ratio)
    compression_num = compression_num if compression_num > 0 else 1
    dist.gather(torch.tensor([compression_num / param_num]), dst=0, group=group)

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        learning_rate = 0.1
    else:
        learning_rate = args.lr
    optimizer = MySGD(model.parameters(), lr=learning_rate)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    elif args.model in ['Abalone', 'Bodyfat', 'Housing']:
        criterion = torch.nn.MSELoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    elif args.model in ['LROnMnist', 'LROnCifar10', 'LROnCifar100',
                        'Abalone', 'Bodyfat', 'Housing']:
        decay_period = 1000000  # learning rate is constant for LR (convex) models
    else:
        decay_period = 100

    print('Begin!')

    global_clock = 0
    g_remain = [torch.zeros_like(param.data) for param in model.parameters()]
    time_logs = open("./record" + str(rank), 'w')
    for epoch in range(args.epochs):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        batch_comm_interval = 0.0
        s_time = time.time()
        model.train()

        # Decay the learning rate at the specified epoch
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))

        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()
            # Select the sparsified update to upload; keep the rest as residual
            g_remain, g_large_change = get_upload(g_remain, delta_ws, args.ratio,
                                                  args.isCompensate)
            batch_comp_time = time.time()

            # Synchronization: send the epoch train loss first
            dist.gather(loss.data, dst=0, group=group)
            for idx, param in enumerate(model.parameters()):
                dist.gather(tensor=g_large_change[idx], dst=0, group=group)
                recv = torch.zeros_like(delta_ws[idx])
                dist.scatter(tensor=recv, src=0, group=group)
                param.data = recv

            epoch_train_loss += loss.data.item()
            batch_end_time = time.time()
            batch_interval += batch_end_time - batch_start_time
            batch_comp_interval += batch_comp_time - batch_start_time
            batch_comm_interval += batch_end_time - batch_comp_time

            logs = torch.tensor([
                0.0,
                batch_interval / (batch_idx + 1),
                batch_comp_interval / (batch_idx + 1),
                batch_comm_interval / (batch_idx + 1)
            ])
            time_logs.write(str(logs) + '\n')
            time_logs.flush()

        print('Rank {}, Epoch {}, Loss:{}'.format(rank, epoch, loss.data.item()))
        e_time = time.time()
        #epoch_train_loss /= len(train_data)
        #epoch_train_loss = format(epoch_train_loss, '.4f')

        # Test the model after the epoch
        #test_loss, acc = test_model(rank, model, test_data, criterion=criterion)
        acc = 0.0
        batch_interval /= batch_idx + 1
        batch_comp_interval /= batch_idx + 1
        batch_comm_interval /= batch_idx + 1
        logs = torch.tensor(
            [acc, batch_interval, batch_comp_interval, batch_comm_interval])
        time_logs.write(str(logs) + '\n')
        time_logs.flush()
        #dist.gather(tensor=logs, dst = 0, group = group)

    time_logs.close()
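# A sketch of the gradient-sparsification helper used above. The repo's real
# get_upload() is defined elsewhere; this version only illustrates the shape of
# its contract: add the new deltas to the carried-over residual, upload only the
# top `ratio` fraction of entries by magnitude (as dense tensors that are zero
# everywhere else), and keep the remainder as the next residual when
# `compensate` is set. Details such as per-layer vs. global top-k are assumptions.
import torch


def get_upload_sketch(g_remain, delta_ws, ratio, compensate):
    g_large_change = []
    new_remain = []
    for residual, delta in zip(g_remain, delta_ws):
        # Error compensation: fold the unsent residual into this round's delta.
        full = residual + delta if compensate else delta.clone()

        # Pick the k largest-magnitude entries of this tensor to upload.
        k = max(1, int(full.numel() * ratio))
        flat = full.view(-1)
        _, top_idx = torch.topk(flat.abs(), k)

        upload = torch.zeros_like(flat)
        upload[top_idx] = flat[top_idx]
        g_large_change.append(upload.view_as(full))

        # Whatever was not uploaded is carried over to the next round.
        leftover = flat.clone()
        leftover[top_idx] = 0.0
        new_remain.append(leftover.view_as(full))
    return new_remain, g_large_change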
def run(rank, model, train_data, test_data, queue, param_q, stop_flag):
    # Get the initial model parameters from the parameter server
    while True:
        if not param_q.empty():
            param_dict = param_q.get()
            tmp = OrderedDict(
                map(lambda item: (item[0], torch.from_numpy(item[1])),
                    param_dict.items()))
            model.load_state_dict(tmp)
            break
    print('Model recved successfully!')

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        optimizer = MySGD(model.parameters(), lr=0.1)
    else:
        optimizer = MySGD(model.parameters(), lr=0.01)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    else:
        decay_period = 100

    print('Begin!')

    time_logs = open("./record" + str(rank), 'w')
    for epoch in range(int(args.epochs)):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        batch_comm_interval = 0.0
        batch_push_interval = 0.0
        batch_pull_interval = 0.0
        model.train()

        # Decay the learning rate at the specified epoch
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))

        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()
            batch_comp_time = time.time()

            # noinspection PyBroadException
            try:
                # Catch the exception raised when the parameter-server process stops
                if delta_ws:
                    queue.put({
                        rank: [loss.data.numpy(), np.array(args.train_bsz), False]
                    })
                    for delta in delta_ws:
                        dist.send(tensor=delta, dst=0)
                batch_push_time = time.time()

                for idx, param in enumerate(model.parameters()):
                    tmp_tensor = torch.zeros_like(param.data)
                    dist.recv(tensor=tmp_tensor, src=0)
                    param.data = tmp_tensor
                    batch_tmp_time = time.time()
                batch_pull_time = time.time()
                #print('Rank {}, Epoch {}, Batch {}/{}, Loss:{}'
                #      .format(rank, epoch, batch_idx, len(train_data), loss.data[0]))
            except Exception as e:
                print(str(e))
                print('Should Stop: {}!'.format(stop_flag.value))
                break

            batch_interval += batch_pull_time - batch_start_time
            batch_comp_interval += batch_comp_time - batch_start_time
            batch_comm_interval += batch_pull_time - batch_comp_time
            batch_push_interval += batch_push_time - batch_comp_time
            batch_pull_interval += batch_pull_time - batch_push_time

            b_interval = batch_interval / (batch_idx + 1)
            b_comp_interval = batch_comp_interval / (batch_idx + 1)
            b_comm_interval = batch_comm_interval / (batch_idx + 1)
            b_push_interval = batch_push_interval / (batch_idx + 1)
            b_pull_interval = batch_pull_interval / (batch_idx + 1)
            logs = torch.tensor([
                0.0, b_interval, b_comp_interval, b_comm_interval,
                b_push_interval, b_pull_interval,
                batch_pull_time - batch_tmp_time
            ])
            time_logs.write(str(logs) + '\n')
            time_logs.flush()

        batch_interval /= batch_idx
        batch_comp_interval /= batch_idx
        batch_comm_interval /= batch_idx
        batch_push_interval /= batch_idx
        batch_pull_interval /= batch_idx
        logs = torch.tensor([
            0.0, batch_interval, batch_comp_interval, batch_comm_interval,
            batch_push_interval, batch_pull_interval
        ])
        time_logs.write(str(epoch) + '\t' + str(logs) + '\n')
        time_logs.flush()

        # Test the model after the epoch
        print("test Model:", epoch)
        # test_model(rank, model, test_data, criterion=criterion)

        if stop_flag.value:
            break

    queue.put({rank: [[], [], True]})
    time_logs.close()
    print("Worker {} has completed epoch {}!".format(args.this_rank, epoch))
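# Sketch of the rank-0 parameter-server step that pairs with the send/recv
# workers above: drain the loss/batch-size records from the shared queue,
# receive each worker's deltas with dist.recv, average them into the global
# model, and send the refreshed parameters back with dist.send. Only the
# message order mirrors the worker code; the stop handling and the exact
# aggregation rule are assumptions, and server_step is a hypothetical name.
import torch
import torch.distributed as dist


def server_step(model, queue, worker_ranks):
    # One metadata record per worker; [[], [], True] marks a finished worker.
    finished = set()
    for _ in worker_ranks:
        record = queue.get()
        for rank, (loss, bsz, done) in record.items():
            if done:
                finished.add(rank)

    active = [r for r in worker_ranks if r not in finished]
    if not active:
        return finished

    # Receive the per-parameter deltas in the same order the workers sent them.
    for param in model.parameters():
        acc = torch.zeros_like(param.data)
        for rank in active:
            delta = torch.zeros_like(param.data)
            dist.recv(tensor=delta, src=rank)
            acc += delta
        param.data -= acc / len(active)

    # Push the updated model back, one tensor per parameter per worker.
    for param in model.parameters():
        for rank in active:
            dist.send(tensor=param.data, dst=rank)
    return finished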