Пример #1
0
def run(size, rank, epoch, batchsize):
    """Train a local model and combine it across ranks once per round.

    Every round consists of three steps:

    1. Rank 0 evaluates the current model on the test set and prints the
       accuracy.
    2. Every rank trains for ``epoch`` passes over its local data loader.
    3. The model is passed through ``all_reduce`` together with the
       process group.

    The loop runs until ``MAX_ROUND`` rounds have completed; a checkpoint is
    written through ``save_model`` every ``ROUND_NUMBER_FOR_SAVE`` rounds.

    Args:
        size: Number of ranks placed in the process group (ranks
            ``0 .. size - 1``); also forwarded to ``get_local_data`` and
            ``all_reduce``.
        rank: Rank of the calling process. Only rank 0 loads the test set
            and reports accuracy.
        epoch: Number of passes over the local training data per round.
        batchsize: Batch size forwarded to ``get_local_data``.

    Raises:
        ValueError: If the ``MODEL`` / ``DATA_SET`` combination is not one
            of the supported pairs.
    """
    if MODEL == 'CNN' and DATA_SET == 'Mnist':
        model = CNNMnist()
    elif MODEL == 'CNN' and DATA_SET == 'Cifar10':
        model = CNNCifar()
    elif MODEL == 'ResNet18' and DATA_SET == 'Cifar10':
        model = ResNet18()
    else:
        # Fail early with a clear message instead of an UnboundLocalError
        # on the first use of ``model``.
        raise ValueError(
            'Unsupported MODEL/DATA_SET combination: %r / %r'
            % (MODEL, DATA_SET))

    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)
    loss_func = torch.nn.CrossEntropyLoss()

    train_loader = get_local_data(size, rank, batchsize)
    if rank == 0:
        test_x, test_y = get_testset()

    group = dist.new_group(list(range(size)))

    # NOTE(review): the optimizer was built from the parameters of the model
    # created above. This assumes load_model / all_reduce update that model
    # in place and return the same object - confirm against their
    # definitions.
    model, round_idx = load_model(model, group, rank)
    while round_idx < MAX_ROUND:
        sys.stdout.flush()
        if rank == 0:
            # Evaluate in inference mode: no autograd graph over the whole
            # test set, and dropout / batch-norm layers behave
            # deterministically. The previous mode is restored afterwards.
            was_training = model.training
            model.eval()
            with torch.no_grad():
                test_output = model(test_x)
            model.train(was_training)

            pred_y = torch.max(test_output, 1)[1]
            correct = (pred_y == test_y).sum().item()
            accuracy = float(correct) / float(test_y.size(0))
            print('Round: ', round_idx, ' Rank: ', rank,
                  '| test accuracy: %.2f' % accuracy)

        for _ in range(epoch):
            for b_x, b_y in train_loader:
                optimizer.zero_grad()
                loss = loss_func(model(b_x), b_y)
                loss.backward()
                optimizer.step()

        # Models are combined after every round.
        model = all_reduce(model, size, group)

        round_idx += 1
        if round_idx % ROUND_NUMBER_FOR_SAVE == 0:
            save_model(model, round_idx, rank)
Пример #2
0
def load_model(group, rank):
    """Build the configured model and restore or initialise its weights.

    When ``SAVE`` is truthy and the per-rank checkpoint file
    ``'autoencoder<rank>.t7'`` exists, the model state and the stored round
    counter are loaded from it. Otherwise the round counter starts at 0 and
    ``init_param(model, 0, group)`` is called.

    Args:
        group: Process group forwarded to ``init_param``.
        rank: Rank of the calling process; selects the checkpoint file.

    Returns:
        A ``(model, round)`` tuple: the model instance and the round number
        to continue from.

    Raises:
        ValueError: If the ``MODEL`` / ``DATA_SET`` combination is not one
            of the supported pairs.
    """
    if MODEL == 'CNN' and DATA_SET == 'Mnist':
        model = CNNMnist()
    elif MODEL == 'CNN' and DATA_SET == 'Cifar10':
        model = CNNCifar()
    elif MODEL == 'ResNet18' and DATA_SET == 'Cifar10':
        model = ResNet18()
    else:
        # Fail early with a clear message instead of an UnboundLocalError
        # on the first use of ``model``.
        raise ValueError(
            'Unsupported MODEL/DATA_SET combination: %r / %r'
            % (MODEL, DATA_SET))

    checkpoint_path = 'autoencoder' + str(rank) + '.t7'
    if SAVE and os.path.exists(checkpoint_path):
        # ``logging`` is invoked as a callable here, so it is a helper from
        # the surrounding project rather than the stdlib module.
        logging('===> Try resume from checkpoint')
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['state'])
        start_round = checkpoint['round']
        print('===> Load last checkpoint data')
    else:
        start_round = 0
        # NOTE(review): the literal 0 is presumably the source rank whose
        # parameters are distributed to the group - confirm in init_param.
        init_param(model, 0, group)
    return model, start_round
Пример #3
0
def initialize_model(num_classes=10, num_channels=1):
    """Create an untrained classifier selected by the global ``model_name``.

    If ``model_name == "resnet"``, a ``resnet18`` is built without
    pretrained weights; its first convolution is replaced so it accepts
    ``num_channels`` input planes, and its final fully connected layer is
    replaced so it emits ``num_classes`` outputs. For any other value a
    ``CNNMnist`` is built with the same two settings.

    Args:
        num_classes: Number of output classes.
        num_channels: Number of channels in the input images.

    Returns:
        The constructed model.
    """
    if model_name == "resnet":
        # NOTE(review): ``models`` is presumably ``torchvision.models`` -
        # confirm against the file's imports.
        model_ft = models.resnet18(pretrained=False)

        # The stem previously hard-coded a single input channel and ignored
        # ``num_channels``; honour the parameter so both branches agree. The
        # default of 1 keeps existing callers unchanged.
        model_ft.conv1 = nn.Conv2d(num_channels,
                                   64,
                                   kernel_size=7,
                                   stride=2,
                                   padding=3,
                                   bias=False)

        # Swap the classification head for one sized to ``num_classes``.
        model_ft.fc = nn.Linear(model_ft.fc.in_features, num_classes)
    else:
        model_ft = CNNMnist(num_channels=num_channels, num_classes=num_classes)

    return model_ft
def load_model(group, rank):
    """Build the configured model and initialise its parameters.

    The model is chosen from ``MODEL`` / ``DATASET``, moved to the GPU when
    ``CUDA`` is truthy, and then initialised through
    ``init_param(model, 0, group)``. Resuming from the per-rank checkpoint
    ``'autoencoder<rank>.t7'`` is currently switched off (see
    ``resume_enabled`` below), so the checkpoint branch never runs.

    Args:
        group: Process group forwarded to ``init_param``.
        rank: Rank of the calling process; selects the checkpoint file.

    Returns:
        The constructed model.

    Raises:
        ValueError: If the ``MODEL`` / ``DATASET`` combination is not one of
            the supported pairs.
    """
    if MODEL == 'CNN' and DATASET == 'Mnist':
        model = CNNMnist()
    elif MODEL == 'CNN' and DATASET == 'Cifar10':
        model = CNNCifar()
    elif MODEL == 'ResNet18' and DATASET == 'Cifar10':
        model = ResNet18()
    else:
        # Fail early with a clear message instead of an UnboundLocalError
        # on the first use of ``model``.
        raise ValueError(
            'Unsupported MODEL/DATASET combination: %r / %r'
            % (MODEL, DATASET))

    if CUDA:
        model.cuda()

    # Checkpoint resume was hard-coded off (``if False and ...``). It stays
    # off, but behind a named flag so the intent is visible and the branch
    # can be re-enabled in one place.
    resume_enabled = False
    checkpoint_path = 'autoencoder' + str(rank) + '.t7'
    if resume_enabled and SAVE and os.path.exists(checkpoint_path):
        # ``logging`` is invoked as a callable here, so it is a helper from
        # the surrounding project rather than the stdlib module.
        logging('===> Try resume from checkpoint')
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['state'])
        logging('model loaded')
    else:
        # NOTE(review): the literal 0 is presumably the source rank whose
        # parameters are distributed to the group - confirm in init_param.
        init_param(model, 0, group)
        logging('model created')
    return model