def _evaluate_accuracy(model, test_loader):
    """Return the fraction of test samples that ``model`` classifies correctly.

    The model is put in eval mode and gradients are disabled for the pass, so
    BatchNorm/Dropout layers use inference behaviour and no autograd graph is
    built. The model's previous train/eval mode is restored afterwards.
    Returns 0.0 for an empty loader.
    """
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for test_x, test_y in test_loader:
                test_x = test_x.to(device)
                test_y = test_y.to(device)
                pred_y = torch.max(model(test_x), 1)[1]
                correct += int((pred_y == test_y).sum().item())
                total += test_y.size(0)
    finally:
        model.train(was_training)
    # Sample-weighted accuracy: correct even when the last batch is smaller.
    return float(correct) / float(total) if total else 0.0


def _train_local_epochs(model, optimizer, loss_func, train_loader, epochs):
    """Train ``model`` in place for ``epochs`` passes over ``train_loader``."""
    device = next(model.parameters()).device
    model.train()
    for _ in range(epochs):
        for b_x, b_y in train_loader:
            b_x = b_x.to(device)
            b_y = b_y.to(device)
            optimizer.zero_grad()
            loss = loss_func(model(b_x), b_y)
            loss.backward()
            optimizer.step()


def run(size, rank, epoch, batchsize):
    """Train on this rank's data shard, combining models across ranks each round.

    Each round runs three steps:

    1. Rank 0 prints the test accuracy of the current model. This happens
       before training, so the model produced by the final round is not
       evaluated.
    2. Every rank trains ``epoch`` local epochs on its own shard.
    3. The models are combined with ``all_reduce`` over a group containing
       all ``size`` ranks.

    A checkpoint is written with ``save_model`` every ``ROUND_NUMBER_FOR_SAVE``
    rounds, for ``MAX_ROUND`` rounds in total.

    Args:
        size: number of participating ranks.
        rank: this process's rank; rank 0 additionally evaluates on the test set.
        epoch: number of local training epochs per round.
        batchsize: mini-batch size for the local training loader.
    """
    train_loader = get_local_data(size, rank, batchsize)
    test_loader = get_testset(rank) if rank == 0 else None

    # dist.new_group must be entered by every rank.
    group = dist.new_group(list(range(size)))
    # load_model(group, rank) builds the configured network, moves it to the GPU
    # when CUDA is set, initialises it via init_param, and returns only the model.
    model = load_model(group, rank)
    round_idx = 0

    # Build the optimizer only after load_model, which returns a new model.
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)
    loss_func = torch.nn.CrossEntropyLoss()

    while round_idx < MAX_ROUND:
        if rank == 0:
            accuracy = _evaluate_accuracy(model, test_loader)
            print('Round: ', round_idx, 'Rank: ', rank, '| test accuracy: %.4f' % accuracy)

        _train_local_epochs(model, optimizer, loss_func, train_loader, epoch)

        # NOTE(review): assumes all_reduce updates the parameters in place and
        # returns the same module; if it returns a new module, the optimizer
        # above would still point at the old parameters. TODO confirm.
        model = all_reduce(model, size, group)

        if (round_idx + 1) % ROUND_NUMBER_FOR_SAVE == 0:
            save_model(model, round_idx + 1, rank)
        round_idx += 1
def load_model(group, rank):
    """Build the network selected by ``MODEL``/``DATASET`` and initialise it.

    Args:
        group: process group passed through to ``init_param``.
        rank: this process's rank; used only to name the resume checkpoint
            ``'autoencoder<rank>.t7'``.

    Returns:
        The constructed model, moved to the GPU when ``CUDA`` is set.

    Raises:
        ValueError: if the ``MODEL``/``DATASET`` combination is not supported.
    """
    if MODEL == 'CNN' and DATASET == 'Mnist':
        model = CNNMnist()
    elif MODEL == 'CNN' and DATASET == 'Cifar10':
        model = CNNCifar()
    elif MODEL == 'ResNet18' and DATASET == 'Cifar10':
        model = ResNet18()
    else:
        # Fail fast with a clear message instead of an UnboundLocalError on
        # ``model`` further down.
        raise ValueError(
            'Unsupported MODEL/DATASET combination: %r/%r' % (MODEL, DATASET))

    if CUDA:
        model.cuda()

    checkpoint_path = 'autoencoder' + str(rank) + '.t7'
    # Resuming is hard-disabled by the leading ``False``. The condition
    # short-circuits, so SAVE and the file check are never evaluated and the
    # else-branch always runs.
    if False and SAVE and os.path.exists(checkpoint_path):
        logging('===> Try resume from checkpoint')
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['state'])
        logging('model loaded')
    else:
        # NOTE(review): presumably init_param synchronises initial parameters
        # across ``group`` (0 looks like the source rank) — TODO confirm.
        init_param(model, 0, group)
        logging('model created')
    return model