コード例 #1
0
ファイル: multi_slave.py プロジェクト: chenc10/pytorch-fl
def run(size, rank, epoch, batchsize):
    """Run the local training loop of one worker until MAX_ROUND is reached.

    Each round consists of: (rank 0 only) an evaluation pass over the test
    set, ``epoch`` passes over this worker's local training data, and a call
    to ``all_reduce`` that combines the model across the process group.
    The model is checkpointed via ``save_model`` every
    ROUND_NUMBER_FOR_SAVE rounds.

    Args:
        size: Number of processes; ranks ``0 .. size - 1`` form the group.
        rank: Rank of this process. Only rank 0 evaluates on the test set.
        epoch: Number of passes over the local data per round.
        batchsize: Batch size forwarded to ``get_local_data``.

    Raises:
        ValueError: If the MODEL / DATA_SET combination is not supported.
    """
    if MODEL == 'CNN' and DATA_SET == 'Mnist':
        model = CNNMnist()
    elif MODEL == 'CNN' and DATA_SET == 'Cifar10':
        model = CNNCifar()
    elif MODEL == 'ResNet18' and DATA_SET == 'Cifar10':
        model = ResNet18()
    else:
        # Fail with a clear message instead of an unbound-variable error.
        raise ValueError(
            'Unsupported MODEL/DATA_SET combination: %r / %r'
            % (MODEL, DATA_SET))
    model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)
    loss_func = torch.nn.CrossEntropyLoss()

    train_loader = get_local_data(size, rank, batchsize)
    # Only rank 0 reports test accuracy, so only it loads the test set.
    test_loader = get_testset(rank) if rank == 0 else None

    group = dist.new_group(list(range(size)))

    # `round_idx` rather than `round` to avoid shadowing the builtin.
    model, round_idx = load_model(model, group, rank)
    while round_idx < MAX_ROUND:
        if rank == 0:
            correct = 0
            total = 0
            # Evaluate in eval mode and without autograd: BatchNorm/Dropout
            # layers must not behave as in training (or update running
            # statistics from test data), and no graph needs to be kept.
            model.eval()
            with torch.no_grad():
                for test_x, test_y in test_loader:
                    test_x = test_x.cuda()
                    test_y = test_y.cuda()
                    pred_y = model(test_x).argmax(dim=1)
                    correct += (pred_y == test_y).sum().item()
                    total += test_y.size(0)
            # Sample-weighted accuracy; does not assume a fixed number of
            # batches or equally sized batches.
            accuracy = correct / total if total else 0.0
            print('Round: ', round_idx, 'Rank: ', rank,
                  '| test accuracy: %.4f' % accuracy)

        # Restore training mode after the (rank 0) evaluation pass.
        model.train()
        for _ in range(epoch):
            for b_x, b_y in train_loader:
                b_x = b_x.cuda()
                b_y = b_y.cuda()
                optimizer.zero_grad()
                output = model(b_x)
                loss = loss_func(output, b_y)
                loss.backward()
                optimizer.step()

        # Combine the model across the group after every round.
        model = all_reduce(model, size, group)

        if (round_idx + 1) % ROUND_NUMBER_FOR_SAVE == 0:
            save_model(model, round_idx + 1, rank)
        round_idx += 1
コード例 #2
0
def load_model(group, rank):
    """Build the model selected by MODEL / DATASET and initialise it.

    The model is moved to the GPU when CUDA is set. Its parameters are then
    initialised through ``init_param(model, 0, group)``.

    Args:
        group: Process group forwarded to ``init_param``.
        rank: Rank of this process; used only to name the checkpoint file.

    Returns:
        The constructed (and initialised) model.

    Raises:
        ValueError: If the MODEL / DATASET combination is not supported.
    """
    if MODEL == 'CNN' and DATASET == 'Mnist':
        model = CNNMnist()
    elif MODEL == 'CNN' and DATASET == 'Cifar10':
        model = CNNCifar()
    elif MODEL == 'ResNet18' and DATASET == 'Cifar10':
        model = ResNet18()
    else:
        # Fail with a clear message instead of an unbound-variable error.
        raise ValueError(
            'Unsupported MODEL/DATASET combination: %r / %r'
            % (MODEL, DATASET))
    if CUDA:
        model.cuda()
    # NOTE: resuming from a checkpoint is switched off by the leading
    # `False`; the branch is kept so it can be re-enabled deliberately.
    if False and SAVE and os.path.exists('autoencoder' + str(rank) + '.t7'):
        logging('===> Try resume from checkpoint')
        checkpoint = torch.load('autoencoder' + str(rank) + '.t7')
        model.load_state_dict(checkpoint['state'])
        logging('model loaded')
    else:
        init_param(model, 0, group)
        logging('model created')
    return model