Esempio n. 1
0
def get_omniglot(ways, shots):
    """Build train/validation/test task datasets for Omniglot.

    The 1623 class ids are shuffled with the ``random`` module and split into
    1100 training, 100 validation and the remaining 423 test classes. Every
    task draws ``ways`` classes with ``2 * shots`` samples each and applies
    ``RandomClassRotation`` with angles 0/90/180/270.

    Args:
        ways: Number of classes per task.
        shots: ``2 * shots`` samples are drawn per class.

    Returns:
        ``(train_tasks, valid_tasks, test_tasks)`` as ``l2l.data.TaskDataset``
        objects with 20000, 1024 and 1024 tasks respectively.
    """
    image_pipeline = transforms.Compose([
        transforms.Resize(28, interpolation=LANCZOS),
        transforms.ToTensor(),
        lambda img: 1.0 - img,  # invert pixel intensities
    ])
    omniglot = l2l.vision.datasets.FullOmniglot(
        root='~/data',
        transform=image_pipeline,
        download=True,
    )
    dataset = l2l.data.MetaDataset(omniglot)
    classes = list(range(1623))
    random.shuffle(classes)

    def make_tasks(class_subset, num_tasks):
        # Restrict to this split's classes, then build an N-way task with
        # 2 * shots samples per class and random per-class rotations.
        pipeline = [
            FilterLabels(dataset, class_subset),
            NWays(dataset, ways),
            KShots(dataset, 2 * shots),
            LoadData(dataset),
            RemapLabels(dataset),
            ConsecutiveLabels(dataset),
            RandomClassRotation(dataset, [0.0, 90.0, 180.0, 270.0]),
        ]
        return l2l.data.TaskDataset(dataset,
                                    task_transforms=pipeline,
                                    num_tasks=num_tasks)

    train_tasks = make_tasks(classes[:1100], 20000)
    valid_tasks = make_tasks(classes[1100:1200], 1024)
    test_tasks = make_tasks(classes[1200:], 1024)
    return train_tasks, valid_tasks, test_tasks
Esempio n. 2
0
def fc100_tasksets(
    train_ways=5,
    train_samples=10,
    test_ways=5,
    test_samples=10,
    root='~/data',
    **kwargs,
):
    """Tasksets for FC100 benchmarks.

    Loads the FC100 train/validation/test splits (downloading them into
    ``root`` if needed), wraps each in ``l2l.data.MetaDataset`` and builds one
    task-transform pipeline per split.

    Args:
        train_ways: Number of classes per training task.
        train_samples: Samples drawn per class in a training task.
        test_ways: Number of classes per validation/test task.
        test_samples: Samples drawn per class in a validation/test task.
        root: Directory where the dataset is stored.
        **kwargs: Accepted for signature compatibility; unused here.

    Returns:
        ``(datasets, transforms)``, each a ``(train, valid, test)`` tuple.
    """
    data_transform = tv.transforms.ToTensor()
    train_dataset = l2l.vision.datasets.FC100(root=root,
                                              transform=data_transform,
                                              mode='train',
                                              download=True)
    valid_dataset = l2l.vision.datasets.FC100(root=root,
                                              transform=data_transform,
                                              mode='validation',
                                              download=True)
    test_dataset = l2l.vision.datasets.FC100(root=root,
                                             transform=data_transform,
                                             mode='test',
                                             download=True)
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)

    def make_pipeline(dataset, ways, samples):
        # Single source of truth for the transform order so every split
        # (including validation) builds tasks the same way.
        return [
            NWays(dataset, ways),
            KShots(dataset, samples),
            LoadData(dataset),
            RemapLabels(dataset),
            ConsecutiveLabels(dataset),
        ]

    train_transforms = make_pipeline(train_dataset, train_ways, train_samples)
    valid_transforms = make_pipeline(valid_dataset, test_ways, test_samples)
    test_transforms = make_pipeline(test_dataset, test_ways, test_samples)

    _datasets = (train_dataset, valid_dataset, test_dataset)
    _transforms = (train_transforms, valid_transforms, test_transforms)
    return _datasets, _transforms
Esempio n. 3
0
def get_mini_imagenet(ways, shots):
    """Build train/validation/test task datasets for mini-ImageNet.

    Each task samples ``ways`` classes with ``2 * shots`` samples per class.
    The splits are downloaded into ``~/data`` if missing.

    Args:
        ways: Number of classes per task.
        shots: ``2 * shots`` samples are drawn per class.

    Returns:
        ``(train_tasks, valid_tasks, test_tasks)`` as ``l2l.data.TaskDataset``
        objects with 20000, 600 and 600 tasks respectively.
    """
    # Create Datasets
    train_dataset = l2l.vision.datasets.MiniImagenet(root='~/data',
                                                     mode='train',
                                                     download=True)
    valid_dataset = l2l.vision.datasets.MiniImagenet(root='~/data',
                                                     mode='validation',
                                                     download=True)
    test_dataset = l2l.vision.datasets.MiniImagenet(root='~/data',
                                                    mode='test',
                                                    download=True)
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)

    def make_tasks(dataset, num_tasks):
        # Every split shares one transform order so train, validation and
        # test tasks are built identically.
        task_transforms = [
            NWays(dataset, ways),
            KShots(dataset, 2 * shots),
            LoadData(dataset),
            RemapLabels(dataset),
            ConsecutiveLabels(dataset),
        ]
        return l2l.data.TaskDataset(dataset,
                                    task_transforms=task_transforms,
                                    num_tasks=num_tasks)

    train_tasks = make_tasks(train_dataset, 20000)
    valid_tasks = make_tasks(valid_dataset, 600)
    test_tasks = make_tasks(test_dataset, 600)

    return train_tasks, valid_tasks, test_tasks
def mini_imagenet_tasksets(
    train_ways=5,
    train_samples=10,
    test_ways=5,
    test_samples=10,
    root='~/data',
    **kwargs,
):
    """Tasksets for mini-ImageNet benchmarks.

    Loads the mini-ImageNet train/validation/test splits (downloading them
    into ``root`` if needed), wraps each in ``l2l.data.MetaDataset`` and builds
    one task-transform pipeline per split.

    Args:
        train_ways: Number of classes per training task.
        train_samples: Samples drawn per class in a training task.
        test_ways: Number of classes per validation/test task.
        test_samples: Samples drawn per class in a validation/test task.
        root: Directory where the dataset is stored.
        **kwargs: Accepted for signature compatibility; unused here.

    Returns:
        ``(datasets, transforms)``, each a ``(train, valid, test)`` tuple.
    """
    train_dataset = l2l.vision.datasets.MiniImagenet(root=root,
                                                     mode='train',
                                                     download=True)
    valid_dataset = l2l.vision.datasets.MiniImagenet(root=root,
                                                     mode='validation',
                                                     download=True)
    test_dataset = l2l.vision.datasets.MiniImagenet(root=root,
                                                    mode='test',
                                                    download=True)
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)

    def make_pipeline(dataset, ways, samples):
        # Single source of truth for the transform order so every split
        # (including validation) builds tasks the same way.
        return [
            NWays(dataset, ways),
            KShots(dataset, samples),
            LoadData(dataset),
            RemapLabels(dataset),
            ConsecutiveLabels(dataset),
        ]

    train_transforms = make_pipeline(train_dataset, train_ways, train_samples)
    valid_transforms = make_pipeline(valid_dataset, test_ways, test_samples)
    test_transforms = make_pipeline(test_dataset, test_ways, test_samples)

    _datasets = (train_dataset, valid_dataset, test_dataset)
    _transforms = (train_transforms, valid_transforms, test_transforms)
    return _datasets, _transforms
Esempio n. 5
0
def create_task_pool(dataset=None, num_tasks=100, ways=5, shot=1):
    """Wrap ``dataset`` in a MetaDataset and return a pool of N-way K-shot tasks.

    Args:
        dataset: Source dataset handed to ``l2l.data.MetaDataset``.
            NOTE(review): the ``None`` default is presumably not usable;
            callers are expected to pass a real dataset — confirm.
        num_tasks: Number of tasks in the returned ``TaskDataset``.
        ways: Classes sampled per task.
        shot: Samples drawn per class.

    Returns:
        An ``l2l.data.TaskDataset`` over the wrapped dataset.
    """
    meta_dataset = l2l.data.MetaDataset(dataset)
    return l2l.data.TaskDataset(
        meta_dataset,
        task_transforms=[
            NWays(meta_dataset, ways),
            KShots(meta_dataset, shot),
            LoadData(meta_dataset),
        ],
        num_tasks=num_tasks,
    )
Esempio n. 6
0
 def test_k_shots(self):
     # KShots without NWays: for every k in 1..9, with and without
     # replacement, each sampled task must hold exactly k samples of each of
     # the Y_SHAPE classes.
     features = torch.randn(NUM_DATA, X_SHAPE)
     targets = torch.randint(0, Y_SHAPE, (NUM_DATA, ))
     meta_dataset = MetaDataset(TensorDataset(features, targets))
     for with_replacement in (False, True):
         for k in range(1, 10):
             pipeline = [
                 KShots(meta_dataset, k=k, replacement=with_replacement),
                 LoadData(meta_dataset),
             ]
             tasks = TaskDataset(meta_dataset,
                                 task_transforms=pipeline,
                                 num_tasks=NUM_TASKS)
             for sampled in tasks:
                 class_counts = sampled[1].bincount()
                 classes_with_k = (class_counts == k).sum()
                 self.assertEqual(classes_with_k, Y_SHAPE)
Esempio n. 7
0
def tiered_imagenet_tasksets(
    train_ways=5,
    train_samples=10,
    test_ways=5,
    test_samples=10,
    root='~/data',
    data_augmentation=None,
    device=None,
    **kwargs,
):
    """Tasksets for tiered-ImageNet benchmarks.

    Builds the tiered-ImageNet train/validation/test splits (downloading them
    into ``root`` if needed), attaches the per-sample image transforms chosen
    by ``data_augmentation`` and returns one task-transform pipeline per split.

    Args:
        train_ways: Number of classes per training task.
        train_samples: Samples drawn per class in a training task.
        test_ways: Number of classes per validation/test task.
        test_samples: Samples drawn per class in a validation/test task.
        root: Directory where the dataset is stored.
        data_augmentation: ``None``, ``'normalize'`` or ``'lee2019'``.
            ``'lee2019'`` adds random crop, color jitter and horizontal flip
            for training, plus per-channel ``Normalize`` for every split.
        device: When not ``None``, each split is wrapped in
            ``l2l.data.OnDeviceDataset`` on this device, with the image
            transforms passed to it; otherwise the transforms are assigned
            to the datasets directly.
        **kwargs: Accepted for signature compatibility; unused here.

    Returns:
        ``(datasets, transforms)``, each a ``(train, valid, test)`` tuple.

    Raises:
        ValueError: If ``data_augmentation`` is not one of the values above.
    """
    if data_augmentation is None or data_augmentation == 'normalize':
        # NOTE(review): 'normalize' builds exactly the same pipeline as None
        # (no Normalize step); confirm whether that is intended.
        to_tensor = ToTensor() if device is None else lambda x: x
        train_data_transforms = Compose([
            to_tensor,
        ])
        test_data_transforms = train_data_transforms
    elif data_augmentation == 'lee2019':
        normalize = Normalize(
            mean=[120.39586422/255.0, 115.59361427/255.0, 104.54012653/255.0],
            std=[70.68188272/255.0, 68.27635443/255.0, 72.54505529/255.0],
        )
        # With a device, samples presumably reach these transforms as tensors
        # (the datasets below are built with ToTensor()), so convert back to
        # PIL before the image augmentations; otherwise leave them untouched.
        to_pil = ToPILImage() if device is not None else lambda x: x
        to_tensor = ToTensor() if device is None else lambda x: x
        train_data_transforms = Compose([
            to_pil,
            RandomCrop(84, padding=8),
            ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ])
        test_data_transforms = Compose([
            to_tensor,
            normalize,
        ])
    else:
        raise ValueError('Invalid data_augmentation argument.')

    train_dataset = l2l.vision.datasets.TieredImagenet(
        root=root,
        mode='train',
        transform=ToTensor(),
        download=True,
    )
    valid_dataset = l2l.vision.datasets.TieredImagenet(
        root=root,
        mode='validation',
        transform=ToTensor(),
        download=True,
    )
    test_dataset = l2l.vision.datasets.TieredImagenet(
        root=root,
        mode='test',
        transform=ToTensor(),
        download=True,
    )
    if device is None:
        train_dataset.transform = train_data_transforms
        valid_dataset.transform = test_data_transforms
        test_dataset.transform = test_data_transforms
    else:
        train_dataset = l2l.data.OnDeviceDataset(
            dataset=train_dataset,
            transform=train_data_transforms,
            device=device,
        )
        valid_dataset = l2l.data.OnDeviceDataset(
            dataset=valid_dataset,
            transform=test_data_transforms,
            device=device,
        )
        test_dataset = l2l.data.OnDeviceDataset(
            dataset=test_dataset,
            transform=test_data_transforms,
            device=device,
        )
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)

    def make_pipeline(dataset, ways, samples):
        # Single source of truth for the transform order so every split
        # (including validation) builds tasks the same way.
        return [
            NWays(dataset, ways),
            KShots(dataset, samples),
            LoadData(dataset),
            RemapLabels(dataset),
            ConsecutiveLabels(dataset),
        ]

    train_transforms = make_pipeline(train_dataset, train_ways, train_samples)
    valid_transforms = make_pipeline(valid_dataset, test_ways, test_samples)
    test_transforms = make_pipeline(test_dataset, test_ways, test_samples)

    _datasets = (train_dataset, valid_dataset, test_dataset)
    _transforms = (train_transforms, valid_transforms, test_transforms)
    return _datasets, _transforms
    # NOTE(review): orphaned fragment of a different example whose `def` line
    # is missing. It follows the `return` of the function above, so it is
    # unreachable as written, and curetsr_lvl0, curetsr_lvl5 and args are not
    # defined anywhere in the visible code.
    meta_curetsr_lvl0 = l2l.data.MetaDataset(curetsr_lvl0)
    meta_curetsr_lvl5 = l2l.data.MetaDataset(curetsr_lvl5)

    # Train and validation share the lvl0 data (split by class below); test
    # uses the lvl5 data.
    train_dataset = meta_curetsr_lvl0
    valid_dataset = meta_curetsr_lvl0
    test_dataset = meta_curetsr_lvl5

    classes = list(range(14))  # 14 classes of stop signs
    random.shuffle(classes)
    # Changes, end!

    # NOTE(review): train_dataset is already a MetaDataset (meta_curetsr_lvl0),
    # so this wraps it a second time — confirm this is intended.
    train_dataset = l2l.data.MetaDataset(train_dataset)
    # Training tasks use the first 8 shuffled classes; each sampled class
    # provides train_query + shot samples.
    train_transforms = [
        FilterLabels(train_dataset, classes[:8]),
        NWays(train_dataset, args.train_way),
        KShots(train_dataset, args.train_query + args.shot),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
    ]
    train_tasks = l2l.data.TaskDataset(train_dataset,
                                       task_transforms=train_transforms)
    train_loader = DataLoader(train_tasks, pin_memory=True, shuffle=True)

    # NOTE(review): same double MetaDataset wrapping as for train_dataset.
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    # Validation uses the remaining 6 classes, disjoint from training.
    # The fragment is truncated here: no validation TaskDataset is built.
    valid_transforms = [
        FilterLabels(valid_dataset, classes[8:14]),
        NWays(valid_dataset, args.test_way),
        KShots(valid_dataset, args.test_query + args.test_shot),
        LoadData(valid_dataset),
        RemapLabels(valid_dataset),
    ]
Esempio n. 9
0
def main(
    ways=5,
    train_shots=15,
    test_shots=5,
    meta_lr=1.0,
    meta_mom=0.0,
    meta_bsz=5,
    fast_lr=0.001,
    train_bsz=10,
    test_bsz=15,
    train_adapt_steps=8,
    test_adapt_steps=50,
    num_iterations=100000,
    test_interval=100,
    adam=0,  # Use Adam (1) or SGD (0) as the meta-optimizer; fast-adapt always uses Adam
    meta_decay=1,  # Linearly decay the meta-lr or not
    cuda=1,
    seed=42,
    meta_lr_final=0.0,  # Meta-lr approached at the last iteration when decaying
):
    """Meta-train a MiniImagenetCNN on mini-ImageNet with a Reptile-style update.

    For each of ``meta_bsz`` tasks per iteration, a copy of the model is
    adapted by ``fast_adapt`` with an Adam inner optimizer whose state is
    carried over from task to task. The meta-gradient is the mean difference
    between the model and its adapted copies, applied by the meta-optimizer.
    Every ``test_interval`` iterations, one validation and one test task per
    meta-batch entry are also adapted, and their metrics are printed.

    Args:
        adam: Truthy selects Adam as the meta-optimizer, otherwise SGD.
        meta_decay: When truthy, the meta-lr decays linearly from ``meta_lr``
            towards ``meta_lr_final`` over ``num_iterations``.
        meta_lr_final: Meta-lr reached at the end of training when decaying.
        Other arguments configure tasks, adaptation and optimizers as named.
    """
    cuda = bool(cuda)
    use_adam = bool(adam)
    meta_decay = bool(meta_decay)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda and torch.cuda.device_count():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    # Create Datasets
    train_dataset = l2l.vision.datasets.MiniImagenet(root='~/data',
                                                     mode='train')
    valid_dataset = l2l.vision.datasets.MiniImagenet(root='~/data',
                                                     mode='validation')
    test_dataset = l2l.vision.datasets.MiniImagenet(root='~/data', mode='test')
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)

    def make_tasks(dataset, shots, num_tasks):
        # Every split shares one transform order; 2 * shots samples are drawn
        # per class.
        task_transforms = [
            NWays(dataset, ways),
            KShots(dataset, 2 * shots),
            LoadData(dataset),
            RemapLabels(dataset),
            ConsecutiveLabels(dataset),
        ]
        return l2l.data.TaskDataset(dataset,
                                    task_transforms=task_transforms,
                                    num_tasks=num_tasks)

    train_tasks = make_tasks(train_dataset, train_shots, 20000)
    valid_tasks = make_tasks(valid_dataset, test_shots, 600)
    test_tasks = make_tasks(test_dataset, test_shots, 600)

    # Create model
    model = l2l.vision.models.MiniImagenetCNN(ways)
    model.to(device)
    if use_adam:
        opt = optim.Adam(model.parameters(), meta_lr, betas=(meta_mom, 0.999))
    else:
        opt = optim.SGD(model.parameters(), lr=meta_lr, momentum=meta_mom)
    # Initial inner-optimizer state; later updated after each training task.
    adapt_opt = optim.Adam(model.parameters(), lr=fast_lr, betas=(0, 0.999))
    adapt_opt_state = adapt_opt.state_dict()
    loss = nn.CrossEntropyLoss(reduction='mean')

    def adapt_copy(tasks, adapt_steps, batch_size, shots):
        # Adapt a fresh copy of ``model`` on one task sampled from ``tasks``;
        # the inner Adam optimizer resumes from the latest saved state.
        learner = deepcopy(model)
        learner_opt = optim.Adam(learner.parameters(),
                                 lr=fast_lr,
                                 betas=(0, 0.999))
        learner_opt.load_state_dict(adapt_opt_state)
        batch = tasks.sample()
        error, accuracy = fast_adapt(
            batch,
            learner,
            fast_lr,
            loss,
            adapt_steps=adapt_steps,
            batch_size=batch_size,
            opt=learner_opt,
            shots=shots,
            ways=ways,
            device=device)
        return learner, learner_opt, error, accuracy

    for iteration in range(num_iterations):
        # Anneal the meta-lr linearly from meta_lr towards meta_lr_final.
        if meta_decay:
            frac_done = float(iteration) / num_iterations
            new_lr = frac_done * meta_lr_final + (1 - frac_done) * meta_lr
            for pg in opt.param_groups:
                pg['lr'] = new_lr

        # Reset the meta-gradient buffers; they accumulate -sum(adapted params).
        for p in model.parameters():
            p.grad = torch.zeros_like(p.data)

        evaluate = iteration % test_interval == 0
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        meta_test_error = 0.0
        meta_test_accuracy = 0.0
        for _ in range(meta_bsz):
            # Compute meta-training loss
            learner, adapt_opt, evaluation_error, evaluation_accuracy = adapt_copy(
                train_tasks, train_adapt_steps, train_bsz, train_shots)
            # Carry the inner optimizer state over to the next task.
            adapt_opt_state = adapt_opt.state_dict()
            for p, fast_p in zip(model.parameters(), learner.parameters()):
                p.grad.data.add_(fast_p.data, alpha=-1.0)

            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            if evaluate:
                # Compute meta-validation loss
                _, _, evaluation_error, evaluation_accuracy = adapt_copy(
                    valid_tasks, test_adapt_steps, test_bsz, test_shots)
                meta_valid_error += evaluation_error.item()
                meta_valid_accuracy += evaluation_accuracy.item()

                # Compute meta-testing loss
                _, _, evaluation_error, evaluation_accuracy = adapt_copy(
                    test_tasks, test_adapt_steps, test_bsz, test_shots)
                meta_test_error += evaluation_error.item()
                meta_test_accuracy += evaluation_accuracy.item()

        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_bsz)
        print('Meta Train Accuracy', meta_train_accuracy / meta_bsz)
        if evaluate:
            print('Meta Valid Error', meta_valid_error / meta_bsz)
            print('Meta Valid Accuracy', meta_valid_accuracy / meta_bsz)
            print('Meta Test Error', meta_test_error / meta_bsz)
            print('Meta Test Accuracy', meta_test_accuracy / meta_bsz)

        # Turn -sum(adapted) into mean(model - adapted) and let the
        # meta-optimizer step along it (moves the model toward the adapted
        # weights).
        for p in model.parameters():
            p.grad.data.mul_(1.0 / meta_bsz).add_(p.data)
        opt.step()
Esempio n. 10
0
def main(
    ways=5,
    shots=5,
    meta_lr=0.002,
    fast_lr=0.1,  # original 0.1
    reg_lambda=0,
    adapt_steps=5,  # original: 5
    meta_bsz=32,
    iters=1000,  # original: 1000
    cuda=1,
    seed=42,
):
    """Meta-train a ConvBase feature extractor with a per-task adapted linear head.

    The linear head is wrapped in ``l2l.algorithms.MAML`` and cloned for
    every task passed to ``fast_adapt``. Each iteration samples ``meta_bsz``
    training, validation and test tasks; only the training losses are
    back-propagated, and Adam updates the feature-extractor parameters with
    the batch-averaged gradient.

    Returns:
        ``(training_accuracy, test_accuracy, running_time)``: per-iteration
        mean train/test accuracy and cumulative wall-clock seconds since
        training started, each an array of length ``iters``.
    """
    import time

    cuda = bool(cuda)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda and torch.cuda.device_count():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    train_dataset = l2l.vision.datasets.MiniImagenet(root='~/data',
                                                     mode='train')
    valid_dataset = l2l.vision.datasets.MiniImagenet(root='~/data',
                                                     mode='validation')
    test_dataset = l2l.vision.datasets.MiniImagenet(root='~/data', mode='test')
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)

    def make_tasks(dataset, num_tasks):
        # Every split shares one transform order; 2 * shots samples per class.
        task_transforms = [
            NWays(dataset, ways),
            KShots(dataset, 2 * shots),
            LoadData(dataset),
            RemapLabels(dataset),
            ConsecutiveLabels(dataset),
        ]
        return l2l.data.TaskDataset(dataset,
                                    task_transforms=task_transforms,
                                    num_tasks=num_tasks)

    train_tasks = make_tasks(train_dataset, 20000)
    valid_tasks = make_tasks(valid_dataset, 600)
    test_tasks = make_tasks(test_dataset, 600)

    # Create model: ConvBase output flattened to 1600 features per sample.
    features = l2l.vision.models.ConvBase(output_size=32,
                                          channels=3,
                                          max_pool=True)
    features = torch.nn.Sequential(features,
                                   Lambda(lambda x: x.view(-1, 1600)))
    features.to(device)
    head = torch.nn.Linear(1600, ways)
    head = l2l.algorithms.MAML(head, lr=fast_lr)
    head.to(device)

    # NOTE(review): only the feature extractor is handed to the optimizer, so
    # the head's initial weights are never meta-updated. If separate learning
    # rates for head and features were intended, that still needs two
    # parameter groups — confirm.
    all_parameters = list(features.parameters())
    optimizer = torch.optim.Adam(all_parameters, lr=meta_lr)

    loss = nn.CrossEntropyLoss(reduction='mean')

    training_accuracy = torch.ones(iters)
    test_accuracy = torch.ones(iters)
    running_time = np.ones(iters)
    start_time = time.time()

    splits = (('train', train_tasks), ('valid', valid_tasks), ('test', test_tasks))
    for iteration in range(iters):
        optimizer.zero_grad()
        error_sums = {'train': 0.0, 'valid': 0.0, 'test': 0.0}
        accuracy_sums = {'train': 0.0, 'valid': 0.0, 'test': 0.0}

        for _ in range(meta_bsz):
            for split, tasks in splits:
                learner = head.clone()
                batch = tasks.sample()
                evaluation_error, evaluation_accuracy = fast_adapt(
                    batch, learner, features, loss, reg_lambda, adapt_steps,
                    shots, ways, device)
                if split == 'train':
                    # Only training tasks contribute to the meta-gradient.
                    evaluation_error.backward()
                error_sums[split] += evaluation_error.item()
                accuracy_sums[split] += evaluation_accuracy.item()

        training_accuracy[iteration] = accuracy_sums['train'] / meta_bsz
        test_accuracy[iteration] = accuracy_sums['test'] / meta_bsz

        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', error_sums['train'] / meta_bsz)
        print('Meta Train Accuracy', accuracy_sums['train'] / meta_bsz)
        print('Meta Valid Error', error_sums['valid'] / meta_bsz)
        print('Meta Valid Accuracy', accuracy_sums['valid'] / meta_bsz)
        print('Meta Test Error', error_sums['test'] / meta_bsz)
        print('Meta Test Accuracy', accuracy_sums['test'] / meta_bsz)

        # Average the accumulated gradients and optimize
        for p in all_parameters:
            p.grad.data.mul_(1.0 / meta_bsz)
        optimizer.step()

        elapsed = time.time() - start_time
        running_time[iteration] = elapsed
        print('total running time', elapsed)

    return training_accuracy.numpy(), test_accuracy.numpy(), running_time
Esempio n. 11
0
def main(
    ways=5,
    shots=5,
    meta_lr=0.003,
    fast_lr=0.5,
    meta_batch_size=32,
    adaptation_steps=1,
    num_iterations=60000,
    cuda=True,
    seed=42,
):
    """Meta-train a MiniImagenetCNN with second-order MAML on mini-ImageNet.

    Each iteration samples ``meta_batch_size`` training, validation and test
    tasks. Only the training losses are back-propagated into the MAML
    parameters, which Adam then updates with the batch-averaged gradient.

    Returns:
        ``(training_accuracy, test_accuracy, running_time)``: per-iteration
        mean train/test accuracy and cumulative wall-clock seconds since
        training started, each an array of length ``num_iterations``.
    """
    import time

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda and torch.cuda.device_count():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    # Create Datasets
    train_dataset = l2l.vision.datasets.MiniImagenet(root='~/data',
                                                     mode='train')
    valid_dataset = l2l.vision.datasets.MiniImagenet(root='~/data',
                                                     mode='validation')
    test_dataset = l2l.vision.datasets.MiniImagenet(root='~/data', mode='test')
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)

    def make_tasks(dataset, num_tasks):
        # Every split shares one transform order; 2 * shots samples per class.
        task_transforms = [
            NWays(dataset, ways),
            KShots(dataset, 2 * shots),
            LoadData(dataset),
            RemapLabels(dataset),
            ConsecutiveLabels(dataset),
        ]
        return l2l.data.TaskDataset(dataset,
                                    task_transforms=task_transforms,
                                    num_tasks=num_tasks)

    train_tasks = make_tasks(train_dataset, 20000)
    valid_tasks = make_tasks(valid_dataset, 600)
    test_tasks = make_tasks(test_dataset, 600)

    # Create model
    model = l2l.vision.models.MiniImagenetCNN(ways)
    model.to(device)
    maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
    opt = optim.Adam(maml.parameters(), meta_lr)
    loss = nn.CrossEntropyLoss(reduction='mean')

    training_accuracy = torch.ones(num_iterations)
    test_accuracy = torch.ones(num_iterations)
    running_time = np.ones(num_iterations)
    start_time = time.time()

    splits = (('train', train_tasks), ('valid', valid_tasks), ('test', test_tasks))
    for iteration in range(num_iterations):
        opt.zero_grad()
        error_sums = {'train': 0.0, 'valid': 0.0, 'test': 0.0}
        accuracy_sums = {'train': 0.0, 'valid': 0.0, 'test': 0.0}
        for _ in range(meta_batch_size):
            for split, tasks in splits:
                learner = maml.clone()
                batch = tasks.sample()
                evaluation_error, evaluation_accuracy = fast_adapt(
                    batch, learner, loss, adaptation_steps, shots, ways, device)
                if split == 'train':
                    # Only training tasks contribute to the meta-gradient.
                    evaluation_error.backward()
                error_sums[split] += evaluation_error.item()
                accuracy_sums[split] += evaluation_accuracy.item()

        training_accuracy[iteration] = accuracy_sums['train'] / meta_batch_size
        test_accuracy[iteration] = accuracy_sums['test'] / meta_batch_size

        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', error_sums['train'] / meta_batch_size)
        print('Meta Train Accuracy', accuracy_sums['train'] / meta_batch_size)
        print('Meta Valid Error', error_sums['valid'] / meta_batch_size)
        print('Meta Valid Accuracy', accuracy_sums['valid'] / meta_batch_size)
        print('Meta Test Error', error_sums['test'] / meta_batch_size)
        print('Meta Test Accuracy', accuracy_sums['test'] / meta_batch_size)

        # Average the accumulated gradients and optimize
        for p in maml.parameters():
            p.grad.data.mul_(1.0 / meta_batch_size)
        opt.step()

        elapsed = time.time() - start_time
        running_time[iteration] = elapsed
        print('total running time', elapsed)

    return training_accuracy.numpy(), test_accuracy.numpy(), running_time
Esempio n. 12
0
def cifarfs_tasksets(
    train_ways=5,
    train_samples=10,
    test_ways=5,
    test_samples=10,
    root='~/data',
    device=None,
    **kwargs,
):
    """Tasksets for CIFAR-FS benchmarks.

    Loads the CIFAR-FS ``train``, ``validation`` and ``test`` splits. Each
    split is optionally wrapped in ``l2l.data.OnDeviceDataset`` (when
    ``device`` is given), then in ``l2l.data.MetaDataset``. A task-transform
    pipeline is built per split.

    :param train_ways: number of classes sampled per training task.
    :param train_samples: number of samples per class in a training task.
    :param test_ways: number of classes sampled per validation/test task.
    :param test_samples: number of samples per class in a validation/test task.
    :param root: directory holding (or receiving) the downloaded data.
    :param device: if not None, passed to ``OnDeviceDataset`` for every split.
    :param kwargs: accepted so callers can pass a shared option dict; unused.
    :return: ``(datasets, transforms)``, each a ``(train, valid, test)`` tuple.
    """

    def _task_transforms(meta_dataset, ways, samples):
        # One pipeline definition for every split, so the transform order
        # (RemapLabels before ConsecutiveLabels) is identical everywhere.
        return [
            NWays(meta_dataset, ways),
            KShots(meta_dataset, samples),
            LoadData(meta_dataset),
            RemapLabels(meta_dataset),
            ConsecutiveLabels(meta_dataset),
        ]

    data_transform = tv.transforms.ToTensor()
    train_dataset = l2l.vision.datasets.CIFARFS(root=root,
                                                transform=data_transform,
                                                mode='train',
                                                download=True)
    valid_dataset = l2l.vision.datasets.CIFARFS(root=root,
                                                transform=data_transform,
                                                mode='validation',
                                                download=True)
    test_dataset = l2l.vision.datasets.CIFARFS(root=root,
                                               transform=data_transform,
                                               mode='test',
                                               download=True)
    if device is not None:
        train_dataset = l2l.data.OnDeviceDataset(
            dataset=train_dataset,
            device=device,
        )
        valid_dataset = l2l.data.OnDeviceDataset(
            dataset=valid_dataset,
            device=device,
        )
        test_dataset = l2l.data.OnDeviceDataset(
            dataset=test_dataset,
            device=device,
        )
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)

    train_transforms = _task_transforms(train_dataset, train_ways,
                                        train_samples)
    valid_transforms = _task_transforms(valid_dataset, test_ways,
                                        test_samples)
    test_transforms = _task_transforms(test_dataset, test_ways, test_samples)

    _datasets = (train_dataset, valid_dataset, test_dataset)
    _transforms = (train_transforms, valid_transforms, test_transforms)
    return _datasets, _transforms
Esempio n. 13
0
def get_few_shot_tasksets(
    root='data',
    dataset='cifar10-fc100',
    train_ways=5,
    train_samples=10,
    test_ways=5,
    test_samples=10,
    n_train_tasks=2000,
    n_test_tasks=1000,
):
    """
    Fetch the train, valid, test meta tasks of given dataset.

    Every split gets its own task-transform pipeline, bound to that split's
    ``MetaDataset``. Validation tasks are therefore sampled and loaded from
    the validation split, never from the test split.

    :param root: data directory.
    :param dataset: name of dataset.
    :param train_ways: number of ways of few-shot training.
    :param train_samples: number of each-way samples for a training task.
    :param test_ways: number of ways of few-shot evaluation and testing.
    :param test_samples: number of each-way samples for a valid or test task.
    :param n_train_tasks: total number of train tasks.
    :param n_test_tasks: total number of valid and test tasks.
    :return: ``BenchmarkTasksets(train_tasks, valid_tasks, test_tasks)``.
    """

    def _task_transforms(meta_dataset, ways, samples):
        # The transforms sample classes/indices from and load data out of
        # ``meta_dataset``, so each split needs its own pipeline.
        return [
            NWays(meta_dataset, ways),
            KShots(meta_dataset, samples),
            LoadData(meta_dataset),
            RemapLabels(meta_dataset),
            ConsecutiveLabels(meta_dataset),
        ]

    train_dataset, valid_dataset, test_dataset = get_normal_tasksets(
        root=root, dataset=dataset)
    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)

    # Instantiate the tasksets
    train_tasks = l2l.data.TaskDataset(
        dataset=train_dataset,
        task_transforms=_task_transforms(train_dataset, train_ways,
                                         train_samples),
        num_tasks=n_train_tasks,
    )

    valid_tasks = l2l.data.TaskDataset(
        dataset=valid_dataset,
        task_transforms=_task_transforms(valid_dataset, test_ways,
                                         test_samples),
        num_tasks=n_test_tasks,
    )

    test_tasks = l2l.data.TaskDataset(
        dataset=test_dataset,
        task_transforms=_task_transforms(test_dataset, test_ways,
                                         test_samples),
        num_tasks=n_test_tasks,
    )

    return BenchmarkTasksets(train_tasks, valid_tasks, test_tasks)
def test(args):
    """Evaluate a saved checkpoint on sampled few-shot test tasks.

    Rebuilds the following modules:

    - the backbone (``Convnet2D`` when ``args.backbone2d``, else ``Convnet``);
    - the ``RelationHead``;
    - the ``TemporalAlignMoudle``, when ``args.temporal_align`` is set.

    Their weights are restored from
    ``<args.root_model>/<args.store_name>/ckpt.pth``. ``fast_adapt`` then
    runs over 2000 sampled test tasks. After each task the running mean
    accuracy and that task's accuracy are printed, both in percent.

    Side effect: ``args.train_list``, ``args.val_list``, ``args.root_path``
    and ``args.test_list`` are (re)assigned from ``dataset_config``.
    """
    device = torch.device('cpu')
    if args.gpu and torch.cuda.device_count():
        print("Using gpu")
        torch.cuda.manual_seed(43)
        device = torch.device('cuda')

    if args.backbone2d:
        model = Convnet2D(st_attention=args.st_attention)
    else:
        model = Convnet(st_attention=args.st_attention)
    model.to(device)
    relation_head = RelationHead(ways=args.test_way, dynamic=args.dynhead)
    relation_head.to(device)

    models = [model, relation_head]
    if args.temporal_align:
        temporal_align = TemporalAlignMoudle(10, shot=args.test_shot)
        temporal_align.to(device)
        models.append(temporal_align)

    (_num_class, args.train_list, args.val_list, args.root_path, prefix,
     _anno_prefix) = dataset_config.return_dataset(args.dataset, 'RGB')
    # The test split list lives next to the train list with 'train' -> 'test'.
    args.test_list = args.train_list.replace('train', 'test')
    path_data = args.root_path
    num_segments = 20

    test_dataset = TSNDataSet(path_data,
                              args.test_list,
                              num_segments=num_segments,
                              new_length=1,
                              modality='RGB',
                              image_tmpl=prefix,
                              transform=Compose([
                                  GroupScale([128, 128]),
                                  StackBatch(roll=False),
                                  To3DTorchFormatTensor(div=True)
                              ]),
                              dense_sample=False,
                              test_mode=True)

    test_dataset = l2l.data.MetaDataset(
        test_dataset, indices_to_labels=test_dataset.indices_to_labels)
    # Each class contributes support (test_shot) + query (test_query) samples.
    test_transforms = [
        NWays(test_dataset, args.test_way),
        KShots(test_dataset, args.test_query + args.test_shot),
        LoadData(test_dataset),
        RemapLabels(test_dataset),
    ]
    test_tasks = l2l.data.TaskDataset(test_dataset,
                                      task_transforms=test_transforms,
                                      num_tasks=2000)
    test_loader = DataLoader(test_tasks,
                             num_workers=1,
                             pin_memory=True,
                             shuffle=True)

    # Read the checkpoint once. Map its tensors onto the evaluation device so
    # a GPU-trained checkpoint can also be evaluated on a CPU-only machine.
    ckpt_path = '%s/%s/ckpt.pth' % (args.root_model, args.store_name)
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    relation_head.load_state_dict(checkpoint['head_dict'])
    if args.temporal_align:
        temporal_align.load_state_dict(checkpoint['align_dict'])
    del checkpoint

    for m in models:
        m.eval()

    loss_ctr = 0
    n_acc = 0
    # Pure evaluation: no gradients are needed, so skip building the graph.
    with torch.no_grad():
        for i, batch in enumerate(test_loader, 1):
            _loss, acc = fast_adapt(models,
                                    batch,
                                    args.test_way,
                                    args.test_shot,
                                    args.test_query,
                                    metric=pairwise_distances_logits,
                                    device=device)
            batch_acc = acc.item()
            loss_ctr += 1
            n_acc += batch_acc
            sys.stdout.write('\rbatch {}: {:.2f}({:.2f})  \033[K'.format(
                i, n_acc / loss_ctr * 100, batch_acc * 100))
    print()
def train(args):
    """Train the video few-shot relation model, validating after every epoch.

    Each epoch runs 100 training episodes from ``DataPrefetcher``. A
    checkpoint is saved (``best_acc`` = -1, not best). Then 100 validation
    tasks are evaluated without gradients, and the checkpoint is saved again,
    flagged as best when the mean validation accuracy improves.

    Train and val accuracies are logged to TensorBoard. The writer is always
    closed, even if training is interrupted, so pending events are flushed.

    Side effect: ``args.train_list``, ``args.val_list``, ``args.root_path``
    and ``args.test_list`` are (re)assigned from ``dataset_config``.
    """
    logger = SummaryWriter(comment=args.comment)

    device = torch.device('cpu')
    if args.gpu and torch.cuda.device_count():
        print("Using gpu")
        torch.cuda.manual_seed(43)
        device = torch.device('cuda')
    if args.backbone2d:
        model = Convnet2D(st_attention=args.st_attention)
    else:
        model = Convnet(st_attention=args.st_attention)
    model.to(device)
    relation_head = RelationHead(ways=args.train_way, dynamic=args.dynhead)
    relation_head.to(device)
    if args.temporal_align:
        temporal_align = TemporalAlignMoudle(10, shot=args.shot)
        temporal_align.to(device)

    train_augmentation = Compose(
        [GroupScale([128, 128]),
         GroupRandomHorizontalFlip(is_flow=False)])

    (_num_class, args.train_list, args.val_list, args.root_path, prefix,
     _anno_prefix) = dataset_config.return_dataset(args.dataset, 'RGB')
    args.test_list = args.train_list.replace('train', 'test')
    path_data = args.root_path
    num_segments = 20

    train_dataset = TSNDataSet(path_data,
                               args.train_list,
                               num_segments=num_segments,
                               new_length=1,
                               modality='RGB',
                               image_tmpl=prefix,
                               transform=Compose([
                                   train_augmentation,
                                   StackBatch(roll=False),
                                   To3DTorchFormatTensor(div=True)
                               ]),
                               dense_sample=False)
    train_dataset = l2l.data.MetaDataset(
        train_dataset, indices_to_labels=train_dataset.indices_to_labels)
    # Each class contributes support (shot) + query (train_query) samples.
    train_transforms = [
        NWays(train_dataset, args.train_way),
        KShots(train_dataset, args.train_query + args.shot),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
    ]
    train_tasks = l2l.data.TaskDataset(train_dataset,
                                       task_transforms=train_transforms)
    train_prefetcher = DataPrefetcher(train_tasks)

    valid_dataset = TSNDataSet(path_data,
                               args.val_list,
                               num_segments=num_segments,
                               new_length=1,
                               modality='RGB',
                               image_tmpl=prefix,
                               transform=Compose([
                                   GroupScale([128, 128]),
                                   StackBatch(roll=False),
                                   To3DTorchFormatTensor(div=True)
                               ]),
                               dense_sample=False)
    valid_dataset = l2l.data.MetaDataset(
        valid_dataset, indices_to_labels=valid_dataset.indices_to_labels)
    valid_transforms = [
        NWays(valid_dataset, args.test_way),
        KShots(valid_dataset, args.test_query + args.test_shot),
        LoadData(valid_dataset),
        RemapLabels(valid_dataset),
    ]
    valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                       task_transforms=valid_transforms,
                                       num_tasks=100)
    valid_loader = DataLoader(valid_tasks,
                              num_workers=1,
                              pin_memory=True,
                              shuffle=True)

    models = [model, relation_head]
    param_groups = [{'params': m.parameters()} for m in models]
    if args.temporal_align:
        models.append(temporal_align)
        param_groups.append({
            'params': temporal_align.parameters(),
            'lr': 1e-3
        })

    optimizer = torch.optim.Adam(param_groups, lr=1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=25,
                                                   gamma=0.5)

    def make_checkpoint(epoch, best_acc):
        # Snapshot of all trainable state; 'epoch' stores the next epoch.
        ckpt = {
            'epoch': epoch + 1,
            'arch': 'raw_relation',
            'state_dict': model.state_dict(),
            'head_dict': relation_head.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_acc': best_acc,
        }
        if args.temporal_align:
            ckpt['align_dict'] = temporal_align.state_dict()
        return ckpt

    print('Start training')
    best_metric = 0
    try:
        for epoch in range(1, args.max_epoch + 1):
            for m in models:
                m.train()

            loss_ctr = 0
            n_loss = 0
            n_acc = 0
            batch = train_prefetcher.next()
            for i in range(100):
                loss, acc = fast_adapt(models,
                                       batch,
                                       args.train_way,
                                       args.shot,
                                       args.train_query,
                                       metric=pairwise_distances_logits,
                                       device=device,
                                       epoch=epoch)

                loss_ctr += 1
                n_loss += loss.item()
                n_acc += acc.item()

                # Request the next batch before the backward pass so the
                # prefetcher can overlap loading with optimization.
                batch = train_prefetcher.next()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if (i + 1) % 10 == 0:
                    sys.stdout.write('\rWorking.... i=%d \033[K' % (i + 1))

            lr_scheduler.step()

            print('epoch {}, train, loss={:.4f} acc={:.4f}'.format(
                epoch, n_loss / loss_ctr, n_acc / loss_ctr))
            logger.add_scalar('train_acc', n_acc / loss_ctr,
                              global_step=epoch)
            # Save right after training so progress survives a crash during
            # validation; best_acc=-1 marks it as not yet validated.
            save_checkpoint(make_checkpoint(epoch, -1), False)

            for m in models:
                m.eval()

            loss_ctr = 0
            n_loss = 0
            n_acc = 0
            with torch.no_grad():
                for i, batch in enumerate(valid_loader):
                    loss, acc = fast_adapt(models,
                                           batch,
                                           args.test_way,
                                           args.test_shot,
                                           args.test_query,
                                           metric=pairwise_distances_logits,
                                           device=device,
                                           epoch=epoch)

                    loss_ctr += 1
                    n_loss += loss.item()
                    n_acc += acc.item()

                    if (i + 1) % 10 == 0:
                        sys.stdout.write('\rEvaling.... i=%d \033[K' %
                                         (i + 1))

            print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(
                epoch, n_loss / loss_ctr, n_acc / loss_ctr))
            logger.add_scalar('val_acc', n_acc / loss_ctr, global_step=epoch)
            metric_for_best = n_acc / loss_ctr
            is_best = metric_for_best > best_metric
            best_metric = max(metric_for_best, best_metric)
            save_checkpoint(make_checkpoint(epoch, best_metric), is_best)
    finally:
        # Flush buffered TensorBoard events even if training is interrupted.
        logger.close()

    torch.cuda.empty_cache()
Esempio n. 16
0
def main(
    ways=4,
    shots=5,
    meta_lr=0.01,
    fast_lr=0.5,
    meta_batch_size=32,
    adaptation_steps=1,
    num_iterations=10000,
    cuda=True,
    seed=42,
):
    """Meta-train a ``Discriminator`` with MAML, then grid-search test settings.

    Each iteration accumulates second-order MAML gradients over
    ``meta_batch_size`` training tasks, averages them and takes one Adam
    step. A clone of the meta-learner is kept whenever the mean validation
    accuracy improves.

    After training, that clone is evaluated for every combination of test
    adaptation steps and inner learning rate. The results and the clone's
    weights are written under ``/home/sever2users/Desktop/MetaGAN/maml``.

    NOTE(review): ``train_dataset`` and ``valid_dataset`` are not built in
    this function; presumably they are module-level ``MetaDataset`` objects
    defined elsewhere in the file — confirm.
    NOTE(review): ``test`` is called as
    ``test(lr, loss, adaptation_steps, shots, ways, device, model, tasks)``
    and must resolve to a helper with that signature.

    :param ways: classes per task.
    :param shots: ``2 * shots`` samples per class are drawn per task.
    :param meta_lr: Adam learning rate for the outer (meta) update.
    :param fast_lr: inner-loop learning rate used during meta-training.
    :param meta_batch_size: tasks per meta-update.
    :param adaptation_steps: inner-loop steps during meta-training.
    :param num_iterations: number of meta-updates.
    :param cuda: use the GPU when one is available.
    :param seed: seed for ``random``, ``numpy`` and ``torch``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda and torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    # Create dataset

    train_transforms = [
        NWays(train_dataset, ways),
        KShots(train_dataset, shots * 2),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset)
    ]

    valid_transforms = [
        NWays(valid_dataset, ways),
        KShots(valid_dataset, shots * 2),
        LoadData(valid_dataset),
        RemapLabels(valid_dataset),
        ConsecutiveLabels(valid_dataset)
    ]

    train_tasks = l2l.data.TaskDataset(train_dataset,
                                       task_transforms=train_transforms,
                                       num_tasks=1000)

    valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                       task_transforms=valid_transforms,
                                       num_tasks=200)

    # Create model
    model = Discriminator(ways=ways)

    model.to(device)

    maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
    opt = optim.Adam(maml.parameters(), meta_lr)

    loss = nn.CrossEntropyLoss(reduction='mean')

    best_acc = 0.0
    maml_test = None
    # One prefix for every TensorBoard tag of this run (accuracy and loss).
    run_tag = "{}way_{}shot_{}train_as".format(ways, shots, adaptation_steps)

    writer = SummaryWriter("/home/sever2users/Desktop/MetaGAN/maml/viz")

    for iteration in range(num_iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0

        for task in range(meta_batch_size):
            # Compute meta-training loss
            learner = maml.clone()
            batch = train_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch, learner, loss, adaptation_steps, shots, ways, device)
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            # Compute meta-validation loss
            learner = maml.clone()
            batch = valid_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch, learner, loss, adaptation_steps, shots, ways, device)
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

        train_err = meta_train_error / meta_batch_size
        train_acc = meta_train_accuracy / meta_batch_size
        valid_err = meta_valid_error / meta_batch_size
        valid_acc = meta_valid_accuracy / meta_batch_size

        # Print some metrics
        if iteration % 500 == 0:
            print('\n')
            print('Iteration', iteration)
            print('Meta Train Error', train_err)
            print('Meta Train Accuracy', train_acc)
            print('Meta Valid Error', valid_err)
            print('Meta Valid Accuracy', valid_acc)

        writer.add_scalar(run_tag + "/accuracy/valid", valid_acc, iteration)
        writer.add_scalar(run_tag + "/accuracy/train", train_acc, iteration)
        writer.add_scalar(run_tag + "/loss/valid", valid_err, iteration)
        writer.add_scalar(run_tag + "/loss/train", train_err, iteration)

        if valid_acc > best_acc:
            # Snapshot the meta-parameters before this iteration's update.
            maml_test = maml.clone()
            best_acc = valid_acc

        # Average the accumulated gradients and optimize
        for p in maml.parameters():
            p.grad.data.mul_(1.0 / meta_batch_size)
        opt.step()

    writer.close()

    if maml_test is None:
        # No iteration beat 0.0 validation accuracy (or num_iterations == 0):
        # fall back to the final meta-parameters instead of an unbound name.
        maml_test = maml.clone()

    print("\n")
    print(ways, "WAYS", shots, "SHOTS", adaptation_steps,
          "train_adaptation_steps")

    x = []

    # NOTE(review): the grid is evaluated on valid_tasks, the same tasks used
    # to select maml_test, so these accuracies are optimistically biased.
    for a_s in [1, 3, 10, 20]:
        for test_lr in [0.5, 0.1, 0.05, 0.01]:

            l, a = test(test_lr, loss, a_s, shots, ways, device, maml_test,
                        valid_tasks)

            x.append([a_s, test_lr, a])

            print("test_a_s=", a_s, 'fast_lr', test_lr, "acc=", a)

    np.savetxt(
        '/home/sever2users/Desktop/MetaGAN/maml/test_results/{}way_{}shot_{}train_adp_steps.txt'
        .format(ways, shots, adaptation_steps), np.array(x))

    torch.save(
        maml_test.module.state_dict(),
        '/home/sever2users/Desktop/MetaGAN/maml/saved_models/state_dict_{}way_{}shot_{}train_adp_steps.pth'
        .format(ways, shots, adaptation_steps))
Esempio n. 17
0
# Hyper-parameters. They are defined first so the task pipeline and the model
# below are derived from them rather than repeating the same literals.
maml_lr = 0.01
lr = 0.005
iterations = 1000
tps = 32
fas = 5
shots = 1
ways = 3

# Tensor conversion, normalization with the constants 0.1307 / 0.3081, and a
# reshape of every sample to a 1x28x28 tensor.
transformations = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307, ), (0.3081, )),
    lambda x: x.view(1, 28, 28),
])

d = CustomDatasetFromCsvData('./mnist-demo.csv', 28, 28, transformations)
t = l2l.data.MetaDataset(d)
# Each task: `ways` classes with 2 * shots samples per class (presumably split
# into adaptation and evaluation halves later — confirm against the caller).
train_tasks = l2l.data.TaskDataset(t,
                                   task_transforms=[
                                       NWays(t, n=ways),
                                       KShots(t, k=2 * shots),
                                       LoadData(t),
                                       RemapLabels(t),
                                       ConsecutiveLabels(t)
                                   ],
                                   num_tasks=1000)
model = Net(ways)