コード例 #1
0
def run(size, rank, epoch, batchsize):
    """Run the distributed training loop for this process.

    Builds the model selected by the module-level ``MODEL`` / ``DATA_SET``
    settings and then, until ``MAX_ROUND`` rounds have been completed:

    1. rank 0 prints the current test accuracy,
    2. every rank trains for ``epoch`` passes over its local data,
    3. the model is passed through ``all_reduce`` for the whole group,
    4. every ``ROUND_NUMBER_FOR_SAVE`` rounds ``save_model`` is called.

    Args:
        size: Number of processes; ranks ``0 .. size - 1`` form the group.
        rank: Rank of this process. Only rank 0 loads and evaluates the
            test set.
        epoch: Number of passes over the local training data per round.
        batchsize: Batch size forwarded to ``get_local_data``.

    Raises:
        ValueError: If ``MODEL`` / ``DATA_SET`` is not a supported
            combination.
    """
    if MODEL == 'CNN' and DATA_SET == 'KWS':
        model = CNNKws()
    elif MODEL == 'CNN' and DATA_SET == 'Cifar10':
        model = CNNCifar()
    elif MODEL == 'ResNet18' and DATA_SET == 'Cifar10':
        model = ResNet18()
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            'Unsupported MODEL/DATA_SET combination: %r / %r'
            % (MODEL, DATA_SET))

    model = model.cuda()
    loss_func = torch.nn.CrossEntropyLoss()

    train_loader = get_local_data(size, rank, batchsize)
    # Only rank 0 evaluates, so only rank 0 needs the test set.
    test_loader = get_testset(rank) if rank == 0 else None

    group = dist.new_group(list(range(size)))

    model, round_idx = load_model(model, group, rank)

    # Built after load_model so the optimizer tracks the parameters of the
    # model object that load_model returned.
    # NOTE(review): all_reduce also rebinds `model` every round; this is only
    # correct if it returns the same module updated in place -- confirm.
    optimizer = torch.optim.SGD(model.parameters(), lr=LR, weight_decay=1e-3)

    while round_idx < MAX_ROUND:
        sys.stdout.flush()
        if rank == 0:
            accuracy = _evaluate_accuracy(model, test_loader)
            print('Round: ', round_idx, ' Rank: ', rank,
                  '| test accuracy: %.4f' % accuracy)

        for _ in range(epoch):
            for b_x, b_y in train_loader:
                b_x = b_x.cuda()
                b_y = b_y.cuda()
                optimizer.zero_grad()
                output = model(b_x)
                loss = loss_func(output, b_y)
                loss.backward()
                optimizer.step()

        model = all_reduce(model, size, group)

        if (round_idx + 1) % ROUND_NUMBER_FOR_SAVE == 0:
            save_model(model, round_idx + 1, rank)
        round_idx += 1


def _evaluate_accuracy(model, test_loader):
    """Return the fraction of samples in ``test_loader`` classified correctly.

    The model is switched to evaluation mode and gradient tracking is
    disabled while the test set is processed; the previous training mode is
    restored afterwards, even if evaluation raises.

    Args:
        model: Module called on each CUDA input batch; the predicted class is
            the index of the largest value along dimension 1 of its output.
        test_loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        ``correct / total`` as a float, or ``nan`` when the loader yields no
        samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for test_x, test_y in test_loader:
                test_x = test_x.cuda()
                test_y = test_y.cuda()
                pred_y = torch.max(model(test_x), 1)[1]
                correct += (pred_y == test_y).sum().item()
                total += test_y.size(0)
    finally:
        # Hand the model back in the mode the training loop expects.
        model.train(was_training)

    if total == 0:
        # Accuracy is undefined for an empty test set; avoid 0 / 0.
        return float('nan')
    return correct / float(total)
コード例 #2
0
def load_model(group, rank):
    """Create the configured model and restore or initialise its parameters.

    The architecture is chosen from the module-level ``MODEL`` / ``DATA_SET``
    settings. If ``SAVE`` is truthy and the per-rank checkpoint file
    ``autoencoder<rank>.t7`` exists, the model state and round counter are
    loaded from it; otherwise the round counter starts at 0 and
    ``init_param(model, 0, group)`` is called.

    Args:
        group: Process group forwarded to ``init_param``.
        rank: Rank of this process; selects the checkpoint file name.

    Returns:
        A ``(model, round)`` tuple, where ``round`` is the value stored in the
        checkpoint or 0 for a fresh model.

    Raises:
        ValueError: If ``MODEL`` / ``DATA_SET`` is not a supported
            combination.
    """
    if MODEL == 'CNN' and DATA_SET == 'Mnist':
        model = CNNMnist()
    elif MODEL == 'CNN' and DATA_SET == 'Cifar10':
        model = CNNCifar()
    elif MODEL == 'ResNet18' and DATA_SET == 'Cifar10':
        model = ResNet18()
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            'Unsupported MODEL/DATA_SET combination: %r / %r'
            % (MODEL, DATA_SET))

    checkpoint_path = 'autoencoder' + str(rank) + '.t7'
    if SAVE and os.path.exists(checkpoint_path):
        logging('===> Try resume from checkpoint')
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['state'])
        start_round = checkpoint['round']
        print('===> Load last checkpoint data')
    else:
        start_round = 0
        # NOTE(review): init_param is defined elsewhere; presumably it
        # synchronises the initial parameters from rank 0 -- confirm.
        init_param(model, 0, group)
    return model, start_round
コード例 #3
0
def load_model(group, rank, resume=False):
    """Create the configured model and restore or initialise its parameters.

    The architecture is chosen from the module-level ``MODEL`` / ``DATASET``
    settings, and the model is moved to the GPU when ``CUDA`` is truthy.

    Args:
        group: Process group forwarded to ``init_param``.
        rank: Rank of this process; selects the checkpoint file name.
        resume: When true, and ``SAVE`` is truthy, and the per-rank checkpoint
            ``autoencoder<rank>.t7`` exists, load the model state from it
            instead of calling ``init_param``. Defaults to ``False``, which
            matches the earlier behaviour where resuming was hard-disabled.

    Returns:
        The model instance.

    Raises:
        ValueError: If ``MODEL`` / ``DATASET`` is not a supported combination.
    """
    if MODEL == 'CNN' and DATASET == 'Mnist':
        model = CNNMnist()
    elif MODEL == 'CNN' and DATASET == 'Cifar10':
        model = CNNCifar()
    elif MODEL == 'ResNet18' and DATASET == 'Cifar10':
        model = ResNet18()
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            'Unsupported MODEL/DATASET combination: %r / %r'
            % (MODEL, DATASET))

    if CUDA:
        model.cuda()

    checkpoint_path = 'autoencoder' + str(rank) + '.t7'
    if resume and SAVE and os.path.exists(checkpoint_path):
        logging('===> Try resume from checkpoint')
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['state'])
        logging('model loaded')
    else:
        # NOTE(review): init_param is defined elsewhere; presumably it
        # synchronises the initial parameters from rank 0 -- confirm.
        init_param(model, 0, group)
        logging('model created')
    return model