def test_wrong_q_query_arg(self):
        """NShotTaskSampler must reject q_queries values that do not fit
        in the available per-class sample counts."""
        episodes_per_epoch = 1
        n_shot = 5
        k_way = 5

        def exhaust(task_sampler):
            # Drain the sampler; the errors under test are raised lazily
            # while iterating, not at construction time.
            for _ in task_sampler:
                pass

        # n_shot + q_queries beyond even the largest class must fail
        # immediately.
        q_queries = self.background_class_count.max() - n_shot + 1
        sampler = NShotTaskSampler(self.background, episodes_per_epoch,
                                   n_shot, k_way, q_queries)
        self.assertRaises(ValueError, exhaust, sampler)

        # q_queries beyond only the smallest class must fail as soon as
        # that class is drawn, which we force via fixed_tasks.
        q_queries = self.background_class_count.min() - n_shot + 1
        class_min_index = self.background_class_count.argmin()
        sampler = NShotTaskSampler(self.background, episodes_per_epoch,
                                   n_shot, k_way, q_queries,
                                   fixed_tasks=[[class_min_index]])
        self.assertRaises(ValueError, exhaust, sampler)
Пример #2
0
    def _test_n_k_q_combination(self, n, k, q):
        """Check matching_net_predictions output shape and row-wise
        normalisation for one (n-shot, k-way, q-query) configuration."""
        task_loader = DataLoader(self.dataset,
                                 batch_sampler=NShotTaskSampler(
                                     self.dataset, 100, n, k, q))

        # Grab a single n-shot, k-way episode from the loader.
        x, y = next(iter(task_loader))

        # Drop the id feature (column 0) and add noise so that no pairwise
        # distance is ever exactly 0.
        support = x[:n * k, 1:]
        queries = x[n * k:, 1:]
        support += torch.rand_like(support)
        queries += torch.rand_like(queries)

        distances = pairwise_distances(queries, support, 'cosine')

        # Softmax over the negated distances plays the role of "attention".
        attention = (-distances).softmax(dim=1).cuda()

        y_pred = matching_net_predictions(attention, n, k, q)

        self.assertEqual(
            y_pred.shape, (q * k, k),
            'Matching Network predictions must have shape (q * k, k).')

        # Each query's class distribution must be a valid probability vector.
        row_sums = y_pred.sum(dim=1)
        self.assertTrue(
            torch.all(
                torch.isclose(row_sums,
                              torch.ones_like(row_sums).double())),
            'Matching Network predictions probabilities must sum to 1 for each '
            'query sample.')
    def test_evaluation_sampler(self):
        """Smoke test: episodes drawn from the evaluation split can be
        prepared and pushed through the dummy fit function."""
        episodes_per_epoch = 10
        n_test, k_test, q_test = 1, 5, 1

        loader = DataLoader(
            self.evaluation,
            batch_sampler=NShotTaskSampler(self.evaluation,
                                           episodes_per_epoch,
                                           n_test, k_test, q_test),
            num_workers=4,
        )

        prepare_batch = prepare_nshot_task(n_test, k_test, q_test)

        # Iterate the full epoch; any sampler/collation error would raise.
        for batch in loader:
            x, y = prepare_batch(batch)

            loss, y_pred = dummy_fit_function(
                dummy_model,
                torch.nn.NLLLoss().to(device),
                x.to(device),
                y.to(device),
                n_shot=n_test,
                k_way=k_test,
                q_queries=q_test,
                train=False,
            )
    def test_background_sampler(self):
        """Smoke test: episodes drawn from the background split can be
        prepared and pushed through the dummy fit function."""
        episodes_per_epoch = 10
        n_train = 5
        k_train = 15
        # Largest query count that still fits in the smallest class.
        q_train = self.background_class_count.min() - n_train

        loader = DataLoader(
            self.background,
            batch_sampler=NShotTaskSampler(self.background,
                                           episodes_per_epoch,
                                           n_train, k_train, q_train),
            num_workers=4,
        )

        prepare_batch = prepare_nshot_task(n_train, k_train, q_train)

        # Iterate the full epoch; any sampler/collation error would raise.
        for batch in loader:
            x, y = prepare_batch(batch)

            loss, y_pred = dummy_fit_function(
                dummy_model,
                torch.nn.NLLLoss().to(device),
                x.to(device),
                y.to(device),
                n_shot=n_train,
                k_way=k_train,
                q_queries=q_train,
                train=False,
            )
Пример #5
0
    def test_n_shot_sampler(self):
        """Check the layout of one episode produced by NShotTaskSampler:
        support labels grouped in runs of n, query labels in runs of q,
        class order consistent between the two, and no sample overlap.
        """
        n, k, q = 2, 4, 3
        # NOTE: i think if num_task > 1 in NShotTaskSampler then we get support as x[:n*k] and next support as
        # x[n*k+n*q:2(n*k+n*q)] etc
        n_shot_taskloader = DataLoader(self.dataset,
                                       batch_sampler=NShotTaskSampler(
                                           self.dataset, 100, n, k, q))

        # Load a single n-shot task and check its properties
        for x, y in n_shot_taskloader:
            support = x[:n * k]
            queries = x[n * k:]
            support_labels = y[:n * k]
            query_labels = y[n * k:]

            # Check ordering of support labels is correct
            for i in range(0, n * k, n):
                support_set_labels_correct = torch.all(
                    support_labels[i:i + n] == support_labels[i])
                self.assertTrue(
                    support_set_labels_correct,
                    'Classes of support set samples should be arranged like: '
                    '[class_1]*n + [class_2]*n + ... + [class_k]*n')

            # Check ordering of query labels is correct
            for i in range(0, q * k, q):
                # FIX: this flag was previously (mis)named
                # `support_set_labels_correct`, copy-pasted from the loop
                # above, even though it checks the query set.
                query_set_labels_correct = torch.all(
                    query_labels[i:i + q] == query_labels[i])
                self.assertTrue(
                    query_set_labels_correct,
                    'Classes of query set samples should be arranged like: '
                    '[class_1]*q + [class_2]*q + ... + [class_k]*q')

            # Check labels are consistent across query and support
            for i in range(k):
                self.assertEqual(
                    support_labels[i * n], query_labels[i * q],
                    'Classes of query and support set should be consistent.')

            # Check no overlap of IDs between support and query.
            # By construction the first feature in the DummyDataset is the
            # id of the sample in the dataset so we can use this to test
            # for overlap between query and support samples
            self.assertEqual(
                len(
                    set(support[:, 0].numpy()).intersection(
                        set(queries[:, 0].numpy()))), 0,
                'There should be no overlap between support and query set samples.'
            )

            break
    def test_wrong_n_shot_arg(self):
        """n_shot larger than the biggest class must raise ValueError."""
        episodes_per_epoch = 100
        # One more support sample per class than the largest class holds.
        n_shot = self.background_class_count.max() + 1
        k_way = 5
        q_queries = 1

        sampler = NShotTaskSampler(self.background, episodes_per_epoch,
                                   n_shot, k_way, q_queries)

        def exhaust():
            # The error is raised lazily, while iterating episodes.
            for _ in sampler:
                pass

        self.assertRaises(ValueError, exhaust)
    def test_wrong_k_way_arg(self):
        """k_way larger than the number of classes must raise ValueError."""
        episodes_per_epoch = 100
        n_shot = 5
        # One more way than there are background classes.
        k_way = self.n_background_classes + 1
        q_queries = 1

        sampler = NShotTaskSampler(self.background, episodes_per_epoch,
                                   n_shot, k_way, q_queries)

        def exhaust():
            # The error is raised lazily, while iterating episodes.
            for _ in sampler:
                pass

        self.assertRaises(ValueError, exhaust)
Пример #8
0
    def test_n_shot_sampler_num_tasks_not_1(self):  # TODO
        """Stub: intended to check episode layout when num_tasks > 1.

        Currently this only unpacks the first task's slices and asserts
        nothing — see the NOTE below for the suspected batch layout.
        """
        n, k, q = 2, 4, 3
        # NOTE: i think if num_task > 1 in NShotTaskSampler then we get support as x[:n*k] and next support as
        # x[n*k+n*q:2(n*k+n*q)] etc
        n_shot_taskloader = DataLoader(self.dataset,
                                       batch_sampler=NShotTaskSampler(
                                           self.dataset,
                                           100,
                                           n,
                                           k,
                                           q,
                                           num_tasks=3))

        # Load a single n-shot task and check its properties
        for x, y in n_shot_taskloader:
            # NOTE(review): with num_tasks=3 these slices presumably cover
            # only the first of three concatenated tasks — confirm against
            # NShotTaskSampler before extending this test.
            support = x[:n * k]
            queries = x[n * k:]
            support_labels = y[:n * k]
            query_labels = y[n * k:]
Пример #9
0
    def setUpClass(cls):
        """Build the shared test fixtures: a dummy dataset, a one-episode
        n-shot task loader over it, a dummy model and its optimiser.

        NOTE(review): the signature takes `cls`, so this is presumably
        decorated with @classmethod on the line above (not visible in this
        fragment) — confirm.
        """
        # 1-shot, 5-way, 1-query episodes.
        cls.n = 1
        cls.k = 5
        cls.q = 1

        cls.meta_batch_size = 1

        cls.dummy = DummyDataset()
        cls.dummy_tasks = DataLoader(
            cls.dummy,
            batch_sampler=NShotTaskSampler(cls.dummy,
                                           cls.meta_batch_size,
                                           n=cls.n,
                                           k=cls.k,
                                           q=cls.q,
                                           num_tasks=1),
        )

        # Model/optimiser pair shared by the tests; double precision to
        # match the double-precision episode tensors used elsewhere.
        cls.model = DummyModel(cls.k).double()
        cls.opt = torch.optim.Adam(cls.model.parameters(), lr=0.001)
Пример #10
0
  def _test_n_k_q_combination(self, n, k, q):
    """Verify compute_prototypes: each class prototype must sit at the mean
    of its support samples, so its second feature equals the class label."""
    task_loader = DataLoader(self.dataset,
                             batch_sampler=NShotTaskSampler(
                                 self.dataset, 100, n, k, q))

    # Pull a single n-shot, k-way task from the loader.
    x, y = next(iter(task_loader))

    support = x[:n * k]
    support_labels = y[:n * k]
    prototypes = compute_prototypes(support, k, n)

    # By construction the second feature of samples from the DummyDataset
    # equals the label. Prototypes are means of per-class support items, so
    # their second feature must reproduce the class label exactly.
    for class_idx in range(k):
      self.assertEqual(support_labels[class_idx * n],
                       prototypes[class_idx, 1],
                       'Prototypes computed incorrectly!')
Пример #11
0
def few_shot_training(datadir=DATA_PATH,
                      dataset='fashion',
                      num_input_channels=3,
                      drop_lr_every=20,
                      validation_episodes=200,
                      evaluation_episodes=1000,
                      episodes_per_epoch=100,
                      n_epochs=80,
                      small_dataset=False,
                      n_train=1,
                      n_test=1,
                      k_train=30,
                      k_test=5,
                      q_train=5,
                      q_test=1,
                      distance='l2',
                      pretrained=False,
                      monitor_validation=False,
                      n_val_classes=10,
                      architecture='resnet18',
                      gpu=None):
    """Train a Prototypical Network on the fashion-product dataset.

    Builds episodic loaders for training (background classes), optional
    per-epoch validation (held-out background classes) and final test
    (evaluation classes), constructs either the standard few-shot encoder
    or a pretrained torchvision backbone, then runs the `fit` loop with
    few-shot evaluation, checkpointing, LR-schedule and CSV-log callbacks.

    Args:
        datadir: root directory of the dataset.
        dataset: dataset name; only 'fashion' is supported.
        num_input_channels: image channels fed to the encoder.
        drop_lr_every: halve the learning rate every this many epochs.
        validation_episodes: episodes per validation evaluation.
        evaluation_episodes: episodes for the final test evaluation.
        episodes_per_epoch: training episodes per epoch.
        n_epochs: number of training epochs.
        small_dataset: use the low-resolution dataset variant.
        n_train, k_train, q_train: n-shot/k-way/query counts for training.
        n_test, k_test, q_test: n-shot/k-way/query counts for evaluation.
        distance: distance metric for prototypical episodes (e.g. 'l2').
        pretrained: use a pretrained torchvision backbone (requires CUDA).
        monitor_validation: hold out background classes for validation.
        n_val_classes: number of held-out validation classes.
        architecture: torchvision model name used when `pretrained`.
        gpu: CUDA device index, or None for the default device.

    Raises:
        ValueError: if `dataset` is not 'fashion'.
    """
    setup_dirs()

    if dataset == 'fashion':
        dataset_class = FashionProductImagesSmall if small_dataset \
            else FashionProductImages
    else:
        # BUG FIX: this previously read `raise (ValueError, 'Unsupported
        # dataset')`, which raises a TypeError in Python 3 because a tuple
        # is not an exception. Raise the intended ValueError instead.
        raise ValueError('Unsupported dataset')

    param_str = f'{dataset}_nt={n_train}_kt={k_train}_qt={q_train}_' \
                f'nv={n_test}_kv={k_test}_qv={q_test}_small={small_dataset}_' \
                f'pretrained={pretrained}_validate={monitor_validation}'

    print(param_str)

    ###################
    # Create datasets #
    ###################

    # ADAPTED: data transforms including augmentation
    resize = (80, 60) if small_dataset else (400, 300)

    background_transform = transforms.Compose([
        transforms.RandomResizedCrop(resize, scale=(0.8, 1.0)),
        # transforms.RandomGrayscale(),
        transforms.RandomPerspective(),
        transforms.RandomHorizontalFlip(),
        # transforms.Resize(resize),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                     std=[0.229, 0.224, 0.225])
    ])

    evaluation_transform = transforms.Compose([
        transforms.Resize(resize),
        # transforms.CenterCrop(224),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                     std=[0.229, 0.224, 0.225])
    ])

    if monitor_validation:
        if not n_val_classes >= k_test:
            n_val_classes = k_test
            print("Warning: `n_val_classes` < `k_test`. Take a larger number"
                  " of validation classes next time. Increased to `k_test`"
                  " classes")

        # class structure for background (training), validation (validation),
        # evaluation (test): take a random subset of background classes
        validation_classes = list(
            np.random.choice(dataset_class.background_classes, n_val_classes))
        background_classes = list(
            set(dataset_class.background_classes).difference(
                set(validation_classes)))

        # use keyword for evaluation classes
        evaluation_classes = 'evaluation'

        # Meta-validation set
        validation = dataset_class(datadir,
                                   split='all',
                                   classes=validation_classes,
                                   transform=evaluation_transform)
        # ADAPTED: in the original code, `episodes_per_epoch` was provided to
        # `NShotTaskSampler` instead of `validation_episodes`.
        validation_sampler = NShotTaskSampler(validation, validation_episodes,
                                              n_test, k_test, q_test)
        validation_taskloader = DataLoader(validation,
                                           batch_sampler=validation_sampler,
                                           num_workers=4)
    else:
        # use keyword for both background and evaluation classes
        background_classes = 'background'
        evaluation_classes = 'evaluation'

    # Meta-training set
    background = dataset_class(datadir,
                               split='all',
                               classes=background_classes,
                               transform=background_transform)
    background_sampler = NShotTaskSampler(background, episodes_per_epoch,
                                          n_train, k_train, q_train)
    background_taskloader = DataLoader(background,
                                       batch_sampler=background_sampler,
                                       num_workers=4)

    # Meta-test set
    evaluation = dataset_class(datadir,
                               split='all',
                               classes=evaluation_classes,
                               transform=evaluation_transform)
    # ADAPTED: in the original code, `episodes_per_epoch` was provided to
    # `NShotTaskSampler` instead of `evaluation_episodes`.
    evaluation_sampler = NShotTaskSampler(evaluation, evaluation_episodes,
                                          n_test, k_test, q_test)
    evaluation_taskloader = DataLoader(evaluation,
                                       batch_sampler=evaluation_sampler,
                                       num_workers=4)

    #########
    # Model #
    #########

    if torch.cuda.is_available():
        if gpu is not None:
            device = torch.device('cuda', gpu)
        else:
            device = torch.device('cuda')
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    if not pretrained:
        model = get_few_shot_encoder(num_input_channels)
        # ADAPTED
        model.to(device)
        # BEFORE
        # model.to(device, dtype=torch.double)
    else:
        assert torch.cuda.is_available()
        model = models.__dict__[architecture](pretrained=True)
        # Strip the classification head; the encoder output is used directly.
        model.fc = Identity()
        if gpu is not None:
            model = model.cuda(gpu)
        else:
            model = model.cuda()
        # TODO this is too risky: I'm not sure that this can work, since in
        #  the few-shot github repo the batch axis is actually split into
        #  support and query samples
        # model = torch.nn.DataParallel(model).cuda()

    def lr_schedule(epoch, lr):
        # Drop lr every 2000 episodes
        if epoch % drop_lr_every == 0:
            return lr / 2
        else:
            return lr

    ############
    # Training #
    ############
    print(f'Training Prototypical network on {dataset}...')
    optimiser = Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().to(device)

    callbacks = [
        # ADAPTED: this is the test monitoring now - and is only done at the
        # end of training.
        EvaluateFewShot(
            eval_fn=proto_net_episode,
            num_tasks=evaluation_episodes,  # THIS IS NOT USED
            n_shot=n_test,
            k_way=k_test,
            q_queries=q_test,
            taskloader=evaluation_taskloader,
            prepare_batch=prepare_nshot_task(n_test,
                                             k_test,
                                             q_test,
                                             device=device),
            distance=distance,
            on_epoch_end=False,
            on_train_end=True,
            prefix='test_')
    ]
    if monitor_validation:
        callbacks.append(
            # ADAPTED: this is the validation monitoring now - computed
            # after every epoch.
            EvaluateFewShot(
                eval_fn=proto_net_episode,
                num_tasks=evaluation_episodes,  # THIS IS NOT USED
                n_shot=n_test,
                k_way=k_test,
                q_queries=q_test,
                # BEFORE taskloader=evaluation_taskloader,
                taskloader=validation_taskloader,  # ADAPTED
                prepare_batch=prepare_nshot_task(n_test,
                                                 k_test,
                                                 q_test,
                                                 device=device),
                distance=distance,
                on_epoch_end=True,  # ADAPTED
                on_train_end=False,  # ADAPTED
                prefix='val_'))
    callbacks.extend([
        ModelCheckpoint(
            filepath=PATH + f'/models/proto_nets/{param_str}.pth',
            monitor=f'val_{n_test}-shot_{k_test}-way_acc',
            verbose=1,  # ADAPTED
            save_best_only=monitor_validation  # ADAPTED
        ),
        LearningRateScheduler(schedule=lr_schedule),
        CSVLogger(PATH + f'/logs/proto_nets/{param_str}.csv'),
    ])

    fit(
        model,
        optimiser,
        loss_fn,
        epochs=n_epochs,
        dataloader=background_taskloader,
        prepare_batch=prepare_nshot_task(n_train,
                                         k_train,
                                         q_train,
                                         device=device),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        fit_function=proto_net_episode,
        fit_function_kwargs={
            'n_shot': n_train,
            'k_way': k_train,
            'q_queries': q_train,
            'train': True,
            'distance': distance
        },
    )
Пример #12
0
def run():
    """Train a Matching Network on miniImageNet.

    Reads configuration from the module-level `args` namespace and the
    `device`/`PATH` globals; builds episodic train/eval loaders (optionally
    using importance sampling for training), then runs the `fit` loop with
    evaluation, checkpoint, LR-plateau and CSV-log callbacks.

    Raises:
        ValueError: if `args.dataset` is not 'miniImageNet'.
    """
    episodes_per_epoch = 600

    if args.dataset == 'miniImageNet':
        n_epochs = 500
        dataset_class = MiniImageNet
        num_input_channels = 3
        lstm_input_size = 1600
    else:
        raise(ValueError('need to make other datasets module'))

    # Experiment identifier used for checkpoint and log filenames.
    param_str = f'{args.dataset}_n={args.n_train}_k={args.k_train}_q={args.q_train}_' \
                f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}_' \
                f'dist={args.distance}_fce={args.fce}_sampling_method={args.sampling_method}_' \
                f'is_diversity={args.is_diversity}_epi_candidate={args.num_s_candidates}'


    #########
    # Model #
    #########
    from few_shot.models import MatchingNetwork
    # NOTE: the model must be created before the loaders because the
    # ImportanceSampler branch below receives it as an argument.
    model = MatchingNetwork(args.n_train, args.k_train, args.q_train, args.fce, num_input_channels,
                            lstm_layers=args.lstm_layers,
                            lstm_input_size=lstm_input_size,
                            unrolling_steps=args.unrolling_steps,
                            device=device)
    model.to(device, dtype=torch.double)


    ###################
    # Create datasets #
    ###################
    train_dataset = dataset_class('train')
    eval_dataset = dataset_class('eval')

    # Original_sampling
    if not args.sampling_method:
        train_dataset_taskloader = DataLoader(
            train_dataset,
            batch_sampler=NShotTaskSampler(train_dataset, episodes_per_epoch, args.n_train, args.k_train, args.q_train),
            num_workers=4
        )
        eval_dataset_taskloader = DataLoader(
            eval_dataset,
            batch_sampler=NShotTaskSampler(eval_dataset, episodes_per_epoch, args.n_test, args.k_test, args.q_test),
            num_workers=4
        )
    # Importance sampling
    else:
        train_dataset_taskloader = DataLoader(
            train_dataset,
            batch_sampler=ImportanceSampler(train_dataset, model,
            episodes_per_epoch, n_epochs, args.n_train, args.k_train, args.q_train,
            args.num_s_candidates, args.init_temperature, args.is_diversity),
            num_workers=4
        )
        # Evaluation always uses the plain episodic sampler.
        eval_dataset_taskloader = DataLoader(
            eval_dataset,
            batch_sampler=NShotTaskSampler(eval_dataset, episodes_per_epoch, args.n_test, args.k_test, args.q_test),
            num_workers=4
        )

    ############
    # Training #
    ############
    print(f'Training Matching Network on {args.dataset}...')
    optimiser = Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().cuda()


    callbacks = [
        EvaluateFewShot(
            eval_fn=matching_net_episode,
            n_shot=args.n_test,
            k_way=args.k_test,
            q_queries=args.q_test,
            taskloader=eval_dataset_taskloader,
            prepare_batch=prepare_nshot_task(args.n_test, args.k_test, args.q_test),
            fce=args.fce,
            distance=args.distance
        ),
        ModelCheckpoint(
            filepath=PATH + f'/models/matching_nets/{param_str}.pth',
            monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc',
            save_best_only=True,
        ),
        ReduceLROnPlateau(patience=20, factor=0.5, monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc'),
        CSVLogger(PATH + f'/logs/matching_nets/{param_str}.csv'),
    ]

    fit(
        model,
        optimiser,
        loss_fn,
        epochs=n_epochs,
        dataloader=train_dataset_taskloader,
        prepare_batch=prepare_nshot_task(args.n_train, args.k_train, args.q_train),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        fit_function=matching_net_episode,
        fit_function_kwargs={'n_shot': args.n_train, 'k_way': args.k_train, 'q_queries': args.q_train, 'train': True,
                            'fce': args.fce, 'distance': args.distance}
    )
Пример #13
0


# Experiment identifier used for checkpoint/log filenames; encodes the
# n-shot/k-way/query configuration for training and evaluation.
param_str = f'{args.exp_name}_exp_name_{args.dataset}_nt={args.n_train}_kt={args.k_train}_qt={args.q_train}_' \
            f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}'

print(param_str)

###################
# Create datasets #
###################
background = dataset_class('background')
# no batch size for proto nets
background_taskloader = DataLoader(
    background,
    batch_sampler=NShotTaskSampler(background, episodes_per_epoch,
                                   args.n_train, args.k_train, args.q_train),
    num_workers=4)
evaluation = dataset_class('evaluation')
evaluation_taskloader = DataLoader(
    evaluation,
    batch_sampler=NShotTaskSampler(
        evaluation, episodes_per_epoch, args.n_test, args.k_test, args.q_test
    ),  # NOTE(review): is q_test needed for proto nets? It may not be required — check.
    num_workers=4)

#########
# Model #
#########
model = get_few_shot_encoder(num_input_channels)
model.to(device, dtype=torch.double)
Пример #14
0
def run():
    """Train a Prototypical Network on miniImageNet.

    Reads configuration from the module-level `args` namespace and the
    `device`/`PATH` globals; builds episodic train/eval loaders (optionally
    using importance sampling for training), then runs the `fit` loop with
    evaluation, checkpoint, LR-plateau and CSV-log callbacks.

    Raises:
        ValueError: if `args.dataset` is not 'miniImageNet'.
    """
    episodes_per_epoch = 600
    '''
    ###### LearningRateScheduler ######
    drop_lr_every = 20
    def lr_schedule(epoch, lr):
        # Drop lr every 2000 episodes
        if epoch % drop_lr_every == 0:
            return lr / 2
        else:
            return lr
    # callbacks add: LearningRateScheduler(schedule=lr_schedule)
    '''

    if args.dataset == 'miniImageNet':
        n_epochs = 500
        dataset_class = MiniImageNet
        num_input_channels = 3
    else:
        raise (ValueError('need to make other datasets module'))

    # Experiment identifier used for checkpoint and log filenames.
    param_str = f'{args.dataset}_nt={args.n_train}_kt={args.k_train}_qt={args.q_train}_' \
                f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}_' \
                f'dist={args.distance}_sampling_method={args.sampling_method}_is_diverisity={args.is_diversity}'

    print(param_str)

    #########
    # Model #
    #########
    # NOTE: the model must be created before the loaders because the
    # ImportanceSampler branch below receives it as an argument.
    model = get_few_shot_encoder(num_input_channels)
    model.to(device, dtype=torch.double)

    ###################
    # Create datasets #
    ###################
    train_dataset = dataset_class('train')
    eval_dataset = dataset_class('eval')

    # Original sampling
    if not args.sampling_method:
        train_dataset_taskloader = DataLoader(
            train_dataset,
            batch_sampler=NShotTaskSampler(train_dataset, episodes_per_epoch,
                                           args.n_train, args.k_train,
                                           args.q_train),
            num_workers=4)
        eval_dataset_taskloader = DataLoader(
            eval_dataset,
            batch_sampler=NShotTaskSampler(eval_dataset, episodes_per_epoch,
                                           args.n_test, args.k_test,
                                           args.q_test),
            num_workers=4)
    # Importance sampling
    else:
        # ImportanceSampler: Latent space of model
        train_dataset_taskloader = DataLoader(
            train_dataset,
            batch_sampler=ImportanceSampler(
                train_dataset, model, episodes_per_epoch, n_epochs,
                args.n_train, args.k_train, args.q_train,
                args.num_s_candidates, args.init_temperature,
                args.is_diversity),
            num_workers=4)
        # Evaluation always uses the plain episodic sampler.
        eval_dataset_taskloader = DataLoader(
            eval_dataset,
            batch_sampler=NShotTaskSampler(eval_dataset, episodes_per_epoch,
                                           args.n_test, args.k_test,
                                           args.q_test),
            num_workers=4)

    ############
    # Training #
    ############
    print(f'Training Prototypical network on {args.dataset}...')
    optimiser = Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().cuda()

    callbacks = [
        EvaluateFewShot(eval_fn=proto_net_episode,
                        n_shot=args.n_test,
                        k_way=args.k_test,
                        q_queries=args.q_test,
                        taskloader=eval_dataset_taskloader,
                        prepare_batch=prepare_nshot_task(
                            args.n_test, args.k_test, args.q_test),
                        distance=args.distance),
        ModelCheckpoint(
            filepath=PATH + f'/models/proto_nets/{param_str}.pth',
            monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc',
            save_best_only=True,
        ),
        ReduceLROnPlateau(
            patience=40,
            factor=0.5,
            monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc'),
        CSVLogger(PATH + f'/logs/proto_nets/{param_str}.csv'),
    ]

    fit(
        model,
        optimiser,
        loss_fn,
        epochs=n_epochs,
        dataloader=train_dataset_taskloader,
        prepare_batch=prepare_nshot_task(args.n_train, args.k_train,
                                         args.q_train),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        fit_function=proto_net_episode,
        fit_function_kwargs={
            'n_shot': args.n_train,
            'k_way': args.k_train,
            'q_queries': args.q_train,
            'train': True,
            'distance': args.distance
        },
    )
Пример #15
0
        y = y.reshape(meta_batch_size, n + q, 1)
        y = y.float().to(device)
        #print(x.shape)
        #print(y.reshape(meta_batch_size, n*k + q*k, 1).shape)
        # Create label
        #y = create_nshot_task_label(k, q).cuda().repeat(meta_batch_size)
        return x, y

    return prepare_meta_batch_


# Build the held-out PUF test set and an episodic loader over it.
data = RoPUF('test_untouched', challenge_size, test_board, False)
data_taskloader = DataLoader(data,
                             batch_sampler=NShotTaskSampler(data,
                                                            eval_batches,
                                                            n=n,
                                                            k=k,
                                                            q=q,
                                                            num_tasks=1),
                             num_workers=8)

# Restore the trained MAML few-shot classifier for evaluation.
model = FewShotClassifierPUF(in_features=challenge_size)
model.load_state_dict(torch.load(PATH + f'/models/maml2/{param_str}.pth'))
model.eval()
model.to(device)

meta_optimiser = torch.optim.Adam(model.parameters(), lr=0.001)
# TODO change to nn.BCELoss().to(device)  ??
loss_fn = nn.BCELoss().to(device)
# NOTE(review): `sum` shadows the builtin of the same name; rename it once
# the (not shown) accumulation code that follows is in view.
sum = 0
for batch_index, batch in enumerate(data_taskloader):
    x, y = prepare_meta_batch(n, k, q, 1)(batch)
Пример #16
0
                        globals.Q_TRAIN,
                        globals.FCE,
                        num_input_channels,
                        lstm_layers=globals.LSTM_LAYERS,
                        lstm_input_size=lstm_input_size,
                        unrolling_steps=globals.UNROLLING_STEPS,
                        device=device)
model.to(device, dtype=torch.double)

###################
# Create datasets #
###################
# Episodic loaders over the background (train) and evaluation (test)
# splits, configured from the module-level `globals` settings.
background = dataset_class('background')
background_taskloader = DataLoader(background,
                                   batch_sampler=NShotTaskSampler(
                                       background, episodes_per_epoch,
                                       globals.N_TRAIN, globals.K_TRAIN,
                                       globals.Q_TRAIN),
                                   num_workers=4)
evaluation = dataset_class('evaluation')
evaluation_taskloader = DataLoader(evaluation,
                                   batch_sampler=NShotTaskSampler(
                                       evaluation, episodes_per_epoch,
                                       globals.N_TEST, globals.K_TEST,
                                       globals.Q_TEST),
                                   num_workers=4)

############
# Training #
############
print(f'Training Matching Network on {globals.DATASET}...')
optimiser = Adam(model.parameters(), lr=1e-3)
Пример #17
0
# Restore trained Matching Network weights for evaluation.
loaded_model = torch.load(model_path)
model.load_state_dict(loaded_model)


model.to(device)
# Run evaluation in double precision.
model.double()
# print("###########################")
# for param in model.parameters():
#     print(param.data)
# print("###########################")


# Episodic loader over the evaluation split, configured from `globals`.
evaluation = dataset_class('evaluation')
dataloader = DataLoader(
    evaluation,
    batch_sampler=NShotTaskSampler(evaluation, episodes_per_epoch, n=globals.N_TEST, k=globals.K_TEST, 
                        q=globals.Q_TEST),
                        num_workers=4
)
prepare_batch = prepare_nshot_task(globals.N_TEST, globals.K_TEST, globals.Q_TEST)

# Evaluate episode by episode; `matching_net_episode` is called with
# train-related flags off and returns the loss and per-query predictions.
for batch_index, batch in enumerate(dataloader):
    batch_logs = dict(batch=batch_index, size=(dataloader.batch_size or 1))
    x, y = prepare_batch(batch)
    # print(type(x))
    # time.sleep(55)
    loss, y_pred = matching_net_episode(model, None, torch.nn.NLLLoss().cuda(), x, y, globals.N_TEST, globals.K_TEST, 
                    globals.Q_TEST, 'l2', False, False )

    # Hard class decision = argmax over the predicted class scores.
    _, predicted = torch.max(y_pred.data, 1)
    # print(predicted)
    # print(y_pred.argmax(dim=-1))
Пример #18
0
def run_one(names):
    """Evaluate a saved MAML ensemble on few-shot classification tasks.

    One ensemble member is instantiated and loaded per entry of `names`
    (checkpoint file names under ``../models/maml_ens/``), then `save_res`
    is invoked with ``train=False`` so the fit function only evaluates.

    Relies on module-level names defined elsewhere in this file:
    `dataset_class`, `num_input_channels`, `fc_layer_size`, `device`,
    `FewShotClassifier`, the `meta_gradient_ens_step_mgpu_*` fit functions,
    `SaveFewShot`, `create_nshot_task_label` and `save_res`.
    """
    # Hyper-parameters bundled in a plain namespace class (never instantiated).
    class args:
        meta_lr = 1e-3
        dataset = "miniImageNet"
        epoch_len = 800
        n = 5
        k = 5
        q = 5
        meta_batch_size = 2
        n_models = len(names)
        eval_batches = 80
        pred_mode = 'mean'
        order = 2
        epochs = 1
        inner_train_steps = 5
        inner_val_steps = 10
        inner_lr = 0.01

    # Episodic loaders: each batch packs `meta_batch_size` n-shot, k-way tasks.
    background = dataset_class('background')
    background_taskloader = DataLoader(background,
                                       batch_sampler=NShotTaskSampler(
                                           background,
                                           args.epoch_len,
                                           n=args.n,
                                           k=args.k,
                                           q=args.q,
                                           num_tasks=args.meta_batch_size),
                                       num_workers=8)
    evaluation = dataset_class('evaluation')
    evaluation_taskloader = DataLoader(evaluation,
                                       batch_sampler=NShotTaskSampler(
                                           evaluation,
                                           args.eval_batches,
                                           n=args.n,
                                           k=args.k,
                                           q=args.q,
                                           num_tasks=args.meta_batch_size),
                                       num_workers=8)

    ############
    # Training #
    ############
    print(f'Training MAML on {args.dataset}...')

    model_params = [num_input_channels, args.k, fc_layer_size]
    # One classifier per ensemble member, in float64 on the target device.
    meta_models = [
        FewShotClassifier(num_input_channels, args.k,
                          fc_layer_size).to(device, dtype=torch.double)
        for _ in range(args.n_models)
    ]
    meta_optimisers = [
        torch.optim.Adam(meta_model.parameters(), lr=args.meta_lr)
        for meta_model in meta_models
    ]

    # Restore one checkpoint per ensemble member.
    # (Fixed: dropped the unused enumerate() index.)
    for model, name in zip(meta_models, names):
        model.load_state_dict(torch.load("../models/maml_ens/" + name))

    # NLL loss for first/second-order MAML, cross-entropy for the
    # mean-loss variant (order <= 0).
    loss_fn = F.nll_loss if args.order > 0 else F.cross_entropy

    # Select the evaluation step matching the requested MAML order.
    if args.order == 2:
        fit_fn = meta_gradient_ens_step_mgpu_2order
    elif args.order == 1:
        fit_fn = meta_gradient_ens_step_mgpu_1order
    else:
        fit_fn = meta_gradient_ens_step_mgpu_meanloss

    def prepare_meta_batch(n, k, q, meta_batch_size):
        """Return a closure that reshapes a flat batch into meta-task form."""
        def prepare_meta_batch_(batch):
            x, y = batch
            # Reshape to `meta_batch_size` number of tasks. Each task contains
            # n*k support samples to train the fast model on and q*k query
            # samples to evaluate the fast model on and generate meta-gradients.
            x = x.reshape(meta_batch_size, n * k + q * k, num_input_channels,
                          x.shape[-2], x.shape[-1])
            # Move to device
            x = x.double().to(device)
            # Create episode-local labels, repeated once per task in the batch
            y = create_nshot_task_label(k, q).cuda().repeat(meta_batch_size)
            return x, y

        return prepare_meta_batch_

    callbacks = [
        SaveFewShot(
            eval_fn=fit_fn,
            num_tasks=args.eval_batches,
            n_shot=args.n,
            k_way=args.k,
            q_queries=args.q,
            taskloader=evaluation_taskloader,
            prepare_batch=prepare_meta_batch(args.n, args.k, args.q,
                                             args.meta_batch_size),
            # MAML kwargs
            inner_train_steps=args.inner_val_steps,
            inner_lr=args.inner_lr,
            device=device,
            order=args.order,
            pred_mode=args.pred_mode,
            model_params=model_params)
    ]
    # Run name = first checkpoint name with its last 7 characters stripped.
    print(names[0][:-7])
    save_res(
        meta_models,
        meta_optimisers,
        loss_fn,
        epochs=args.epochs,
        dataloader=background_taskloader,
        prepare_batch=prepare_meta_batch(args.n, args.k, args.q,
                                         args.meta_batch_size),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        fit_function=fit_fn,
        name=names[0][:-7],
        fit_function_kwargs={
            'n_shot': args.n,
            'k_way': args.k,
            'q_queries': args.q,
            'train': False,  # evaluation only: no meta-parameter updates
            'pred_mode': args.pred_mode,
            'order': args.order,
            'device': device,
            'inner_train_steps': args.inner_train_steps,
            'inner_lr': args.inner_lr,
            'model_params': model_params
        },
    )
def train_sweep():
    """Train a prototypical network once as a single wandb sweep run.

    Parses dataset/episode settings from the command line, builds episodic
    train/val loaders, instantiates an `XLNetForEmbedding` model and runs
    `fit` with `proto_net_episode`. Learning rate and optimiser choice come
    from `wandb.config` (populated by the sweep agent).

    Bug fixes vs. the original:
    - `raise (ValueError, 'Unsupported dataset')` raised a TypeError in
      Python 3 ("exceptions must derive from BaseException"); now raises
      `ValueError` properly.
    - Bare `except:` clauses narrowed to `except Exception:` so that
      KeyboardInterrupt/SystemExit are not swallowed.
    """

    from torch.optim import Adam
    from torch.utils.data import DataLoader
    import argparse

    from few_shot.datasets import OmniglotDataset, MiniImageNet, ClinicDataset, SNIPSDataset, CustomDataset
    from few_shot.models import XLNetForEmbedding
    from few_shot.core import NShotTaskSampler, EvaluateFewShot, prepare_nshot_task
    from few_shot.proto import proto_net_episode
    from few_shot.train_with_prints import fit
    from few_shot.callbacks import CallbackList, Callback, DefaultCallback, ProgressBarLogger, CSVLogger, EvaluateMetrics, ReduceLROnPlateau, ModelCheckpoint, LearningRateScheduler
    from few_shot.utils import setup_dirs
    from few_shot.utils import get_gpu_info
    from config import PATH
    import wandb
    from transformers import AdamW

    import torch

    # Report GPU memory before anything is allocated.
    gpu_dict = get_gpu_info()
    print('Total GPU Mem: {} , Used GPU Mem: {}, Used Percent: {}'.format(
        gpu_dict['mem_total'], gpu_dict['mem_used'],
        gpu_dict['mem_used_percent']))

    setup_dirs()
    assert torch.cuda.is_available()
    device = torch.device('cuda')
    torch.backends.cudnn.benchmark = True

    ##############
    # Parameters #
    ##############
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='Custom')
    parser.add_argument('--distance', default='l2')
    parser.add_argument('--n-train', default=2, type=int)
    parser.add_argument('--n-test', default=2, type=int)
    parser.add_argument('--k-train', default=2, type=int)
    parser.add_argument('--k-test', default=2, type=int)
    parser.add_argument('--q-train', default=2, type=int)
    parser.add_argument('--q-test', default=2, type=int)
    args = parser.parse_args()

    evaluation_episodes = 100
    episodes_per_epoch = 10

    # Per-dataset schedule: epochs, dataset class, channels, LR-drop period.
    if args.dataset == 'omniglot':
        n_epochs = 40
        dataset_class = OmniglotDataset
        num_input_channels = 1
        drop_lr_every = 20
    elif args.dataset == 'miniImageNet':
        n_epochs = 80
        dataset_class = MiniImageNet
        num_input_channels = 3
        drop_lr_every = 40
    elif args.dataset == 'clinic150':
        n_epochs = 5
        dataset_class = ClinicDataset
        num_input_channels = 150
        drop_lr_every = 2
    elif args.dataset == 'SNIPS':
        n_epochs = 5
        dataset_class = SNIPSDataset
        num_input_channels = 150
        drop_lr_every = 2
    elif args.dataset == 'Custom':
        n_epochs = 20
        dataset_class = CustomDataset
        num_input_channels = 150
        drop_lr_every = 5
    else:
        raise ValueError('Unsupported dataset')

    # Human-readable run identifier used for checkpoint and log file names.
    param_str = f'{args.dataset}_nt={args.n_train}_kt={args.k_train}_qt={args.q_train}_' \
                f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}'

    print(param_str)

    from sklearn.model_selection import train_test_split

    ###################
    # Create datasets #
    ###################

    train_df = dataset_class('train')

    # Each DataLoader batch is one n-shot, k-way episode.
    train_taskloader = DataLoader(train_df,
                                  batch_sampler=NShotTaskSampler(
                                      train_df, episodes_per_epoch,
                                      args.n_train, args.k_train,
                                      args.q_train))

    val_df = dataset_class('val')

    evaluation_taskloader = DataLoader(
        val_df,
        batch_sampler=NShotTaskSampler(val_df, episodes_per_epoch, args.n_test,
                                       args.k_test, args.q_test))

    #train_iter = iter(train_taskloader)
    #train_taskloader = next(train_iter)

    #val_iter = iter(evaluation_taskloader)
    #evaluation_taskloader = next(val_iter)

    #########
    # Wandb #
    #########

    # Defaults used when the sweep agent does not override them.
    config_defaults = {
        'lr': 0.00001,
        'optimiser': 'adam',
        'batch_size': 16,
    }

    wandb.init(config=config_defaults)

    #########
    # Model #
    #########

    torch.cuda.empty_cache()

    # GPU-memory reporting is best effort only; never abort the run over it.
    try:
        print('Before Model Move')
        gpu_dict = get_gpu_info()
        print('Total GPU Mem: {} , Used GPU Mem: {}, Used Percent: {}'.format(
            gpu_dict['mem_total'], gpu_dict['mem_used'],
            gpu_dict['mem_used_percent']))
    except Exception:
        pass

    #from transformers import XLNetForSequenceClassification, AdamW

    #model = XLNetForSequenceClassification.from_pretrained('xlnet-base-cased', num_labels=150)
    #model.cuda()

    # Free any model left over from a previous sweep iteration.
    try:
        del model
    except Exception:
        print("Cannot delete model. No model with name 'model' exists")

    model = XLNetForEmbedding(num_input_channels)
    model.to(device, dtype=torch.double)

    #param_optimizer = list(model.named_parameters())
    #no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    #optimizer_grouped_parameters = [
    #                                {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    #                                {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay':0.0}
    #]

    try:
        print('After Model Move')
        gpu_dict = get_gpu_info()
        print('Total GPU Mem: {} , Used GPU Mem: {}, Used Percent: {}'.format(
            gpu_dict['mem_total'], gpu_dict['mem_used'],
            gpu_dict['mem_used_percent']))
    except Exception:
        pass

    wandb.watch(model)

    ############
    # Training #
    ############

    print(f'Training Prototypical network on {args.dataset}...')
    # Optimiser choice is a sweep hyper-parameter.
    if wandb.config.optimiser == 'adam':
        optimiser = Adam(model.parameters(), lr=wandb.config.lr)
    else:
        optimiser = AdamW(model.parameters(), lr=wandb.config.lr)

    #optimiser = AdamW(optimizer_grouped_parameters, lr=3e-5)
    #loss_fn = torch.nn.NLLLoss().cuda()

    #loss_fn = torch.nn.CrossEntropyLoss()

    #max_grad_norm = 1.0

    loss_fn = torch.nn.NLLLoss()

    def lr_schedule(epoch, lr):
        """Halve the learning rate every `drop_lr_every` epochs."""
        if epoch % drop_lr_every == 0:
            return lr / 2
        else:
            return lr

    callbacks = [
        EvaluateFewShot(eval_fn=proto_net_episode,
                        num_tasks=evaluation_episodes,
                        n_shot=args.n_test,
                        k_way=args.k_test,
                        q_queries=args.q_test,
                        taskloader=evaluation_taskloader,
                        prepare_batch=prepare_nshot_task(
                            args.n_test, args.k_test, args.q_test),
                        distance=args.distance),
        # Keep the checkpoint with the best validation few-shot accuracy.
        ModelCheckpoint(
            filepath=PATH + f'/models/proto_nets/{param_str}.pth',
            monitor=f'val_{args.n_test}-shot_{args.k_test}-way_acc'),
        LearningRateScheduler(schedule=lr_schedule),
        CSVLogger(PATH + f'/logs/proto_nets/{param_str}.csv'),
    ]

    try:
        print('Before Fit')
        print('optimiser :', optimiser)
        print('Learning Rate: ', wandb.config.lr)
        gpu_dict = get_gpu_info()
        print('Total GPU Mem: {} , Used GPU Mem: {}, Used Percent: {}'.format(
            gpu_dict['mem_total'], gpu_dict['mem_used'],
            gpu_dict['mem_used_percent']))
    except Exception:
        pass

    fit(
        model,
        optimiser,
        loss_fn,
        epochs=n_epochs,
        dataloader=train_taskloader,
        prepare_batch=prepare_nshot_task(args.n_train, args.k_train,
                                         args.q_train),
        callbacks=callbacks,
        metrics=['categorical_accuracy'],
        fit_function=proto_net_episode,
        fit_function_kwargs={
            'n_shot': args.n_train,
            'k_way': args.k_train,
            'q_queries': args.q_train,
            'train': True,
            'distance': args.distance
        },
    )
    drop_lr_every = 2
else:
    raise (ValueError, 'Unsupported dataset')

# Run identifier built from CLI arguments (parsed earlier, outside this view).
param_str = f'{args.dataset}_nt={args.n_train}_kt={args.k_train}_qt={args.q_train}_' \
            f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}'

print(param_str)

###################
# Create datasets #
###################
train_df = dataset_class('train')
# Each DataLoader batch is one n-shot, k-way episode drawn by the sampler.
train_taskloader = DataLoader(train_df,
                              batch_sampler=NShotTaskSampler(
                                  train_df, episodes_per_epoch, args.n_train,
                                  args.k_train, args.q_train))
val_df = dataset_class('val')
evaluation_taskloader = DataLoader(val_df,
                                   batch_sampler=NShotTaskSampler(
                                       val_df, episodes_per_epoch, args.n_test,
                                       args.k_test, args.q_test))

#train_iter = iter(train_taskloader)
#train_taskloader = next(train_iter)

#val_iter = iter(evaluation_taskloader)
#evaluation_taskloader = next(val_iter)

#########
# Wandb #
Пример #21
0
def main():
    """Evaluate a saved prototypical network on miniImageNet meta-test tasks.

    Loads the checkpoint given by --model_path, runs `n_epochs` passes of
    `episodes_per_epoch` test episodes each, logs per-epoch categorical
    accuracy, and appends it to --result_path via `ResultWriter`.

    Relies on module-level names defined elsewhere in this file:
    `MiniImageNet`, `get_few_shot_encoder`, `device`, `logger`,
    `ResultWriter`, `prepare_nshot_task`, `proto_net_episode` and
    `categorical_accuracy`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        default=
        "./models/proto_nets/miniImageNet_nt=5_kt=5_qt=10_nv=5_kv=5_qv=10_dist=l2_sampling_method=True_is_diverisity=True.pth",
        help="model path")
    parser.add_argument(
        "--result_path",
        type=str,
        default="./results/proto_nets/5shot_training_5shot_diverisity.csv",
        help="Directory for evaluation report result (for experiments)")
    parser.add_argument('--dataset', type=str, required=True)
    parser.add_argument('--distance', default='cosine')
    parser.add_argument('--n_train', default=1, type=int)
    parser.add_argument('--n_test', default=1, type=int)
    parser.add_argument('--k_train', default=5, type=int)
    parser.add_argument('--k_test', default=5, type=int)
    parser.add_argument('--q_train', default=15, type=int)
    parser.add_argument('--q_test', default=15, type=int)
    parser.add_argument(
        "--debug",
        action="store_true",
        help="set logging level DEBUG",
    )
    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.DEBUG if args.debug else logging.INFO,
    )

    ###################
    # Create datasets #
    ###################
    episodes_per_epoch = 600

    if args.dataset == 'miniImageNet':
        n_epochs = 5
        dataset_class = MiniImageNet
        num_input_channels = 3
    else:
        raise ValueError('need to make other datasets module')

    # Each DataLoader batch is one n-shot, k-way meta-test episode.
    test_dataset = dataset_class('test')
    test_dataset_taskloader = DataLoader(
        test_dataset,
        batch_sampler=NShotTaskSampler(test_dataset, episodes_per_epoch,
                                       args.n_test, args.k_test, args.q_test),
        num_workers=4)

    #########
    # Model #
    #########
    model = get_few_shot_encoder(num_input_channels).to(device,
                                                        dtype=torch.double)

    # strict=False tolerates missing/unexpected keys in the checkpoint.
    model.load_state_dict(torch.load(args.model_path), strict=False)
    model.eval()

    #############
    # Inference #
    #############
    logger.info("***** Epochs = %d *****", n_epochs)
    logger.info("***** Num episodes per epoch = %d *****", episodes_per_epoch)

    result_writer = ResultWriter(args.result_path)

    # just argument (function: proto_net_episode)
    prepare_batch = prepare_nshot_task(args.n_test, args.k_test, args.q_test)
    # NOTE(review): optimiser is required by proto_net_episode's signature;
    # with train=False below, presumably no updates occur — confirm there.
    optimiser = Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().cuda()

    train_iterator = trange(
        0,
        int(n_epochs),
        desc="Epoch",
    )
    for i_epoch in train_iterator:
        epoch_iterator = tqdm(
            test_dataset_taskloader,
            desc="Iteration",
        )
        seen = 0
        metric_name = f'test_{args.n_test}-shot_{args.k_test}-way_acc'
        metric = {metric_name: 0.0}
        # (Fixed: dropped the unused enumerate() index.)
        for batch in epoch_iterator:
            x, y = prepare_batch(batch)

            loss, y_pred = proto_net_episode(model,
                                             optimiser,
                                             loss_fn,
                                             x,
                                             y,
                                             n_shot=args.n_test,
                                             k_way=args.k_test,
                                             q_queries=args.q_test,
                                             train=False,
                                             distance=args.distance)

            # Sample-weighted running sum; normalised after the loop.
            seen += y_pred.shape[0]
            metric[metric_name] += categorical_accuracy(
                y, y_pred) * y_pred.shape[0]

        metric[metric_name] = metric[metric_name] / seen

        logger.info("epoch: {},     categorical_accuracy: {}".format(
            i_epoch, metric[metric_name]))
        result_writer.update(**metric)
Пример #22
0
    raise (ValueError, 'Unsupported dataset')

# Run identifier built from CLI arguments (parsed earlier, outside this view).
param_str = f'{args.dataset}_nt={args.n_train}_kt={args.k_train}_qt={args.q_train}_' \
            f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}'

print(param_str)

###################
# Create datasets #
###################
background = dataset_class('background')
# num_tasks > 1: each DataLoader batch packs `meta_batch_size` n-shot,
# k-way tasks for meta-learning.
background_taskloader = DataLoader(background,
                                   batch_sampler=NShotTaskSampler(
                                       background,
                                       episodes_per_epoch,
                                       args.n_train,
                                       args.k_train,
                                       args.q_train,
                                       num_tasks=args.meta_batch_size),
                                   num_workers=8)
evaluation = dataset_class('evaluation')
evaluation_taskloader = DataLoader(evaluation,
                                   batch_sampler=NShotTaskSampler(
                                       evaluation,
                                       episodes_per_epoch,
                                       args.n_test,
                                       args.k_test,
                                       args.q_test,
                                       num_tasks=args.meta_batch_size),
                                   num_workers=8)
Пример #23
0
    raise (ValueError('Unsupported dataset'))

# Run identifier for a MAML-ensemble experiment (args parsed earlier,
# outside this view).
param_str = f'{args.dataset}_order={args.order}_n={args.n}_k={args.k}_metabatch={args.meta_batch_size}_' \
            f'train_steps={args.inner_train_steps}_val_steps={args.inner_val_steps}_n_models={args.n_models}_train_pred_mode={args.train_pred_mode}_' \
            f'test_pred_mode={args.test_pred_mode}'
print(param_str)

###################
# Create datasets #
###################
background = dataset_class('background')
# Each DataLoader batch packs `meta_batch_size` n-shot, k-way tasks.
background_taskloader = DataLoader(background,
                                   batch_sampler=NShotTaskSampler(
                                       background,
                                       args.epoch_len,
                                       n=args.n,
                                       k=args.k,
                                       q=args.q,
                                       num_tasks=args.meta_batch_size),
                                   num_workers=8)
evaluation = dataset_class('evaluation')
evaluation_taskloader = DataLoader(evaluation,
                                   batch_sampler=NShotTaskSampler(
                                       evaluation,
                                       args.eval_batches,
                                       n=args.n,
                                       k=args.k,
                                       q=args.q,
                                       num_tasks=args.meta_batch_size),
                                   num_workers=8)
Пример #24
0
if args.network == 'proto':
    optimiser = Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().cuda()

    n_epochs = 80
    dataset_class = FashionDataset
    num_input_channels = 3
    drop_lr_every = 40
    model = get_few_shot_encoder(num_input_channels)
    eval_fn = proto_net_episode
    evaluation_taskloader = DataLoader(
        evaluation,
        batch_sampler=NShotTaskSampler(
            evaluation,
            episodes_per_epoch,
            args.n_test,
            args.k_test,
            args.q_test,
            eval_classes=None
        ),  # why is qtest needed for protonet i think its not rquired for protonet check it
        num_workers=4)
    callbacks = [
        EvaluateFewShot(
            eval_fn=eval_fn,
            num_tasks=evaluation_episodes,
            n_shot=args.n_test,
            k_way=args.k_test,
            q_queries=args.q_test,
            taskloader=evaluation_taskloader,
            prepare_batch=prepare_nshot_task(
                args.n_test, args.k_test, args.q_test
            ),  # n shot task is a simple function that maps classes to [0-k]
Пример #25
0
else:
    raise(ValueError, 'Unsupported dataset')

# Run identifier including the experiment name (args parsed earlier,
# outside this view).
param_str = f'{args.dataset}_nt={args.n_train}_kt={args.k_train}_qt={args.q_train}_' \
            f'nv={args.n_test}_kv={args.k_test}_qv={args.q_test}_experiment={args.experiment_name}'

print(param_str)

###################
# Create datasets #
###################
# background = dataset_class('images_background')
# Custom on-disk splits instead of the stock background/evaluation folders.
background = dataset_class('datasets/in32_all')
background_taskloader = DataLoader(
    background,
    batch_sampler=NShotTaskSampler(background, episodes_per_epoch, args.n_train, args.k_train, args.q_train),
    num_workers=4
)
# evaluation = dataset_class('images_evaluation')
evaluation = dataset_class('datasets/test')
evaluation_taskloader = DataLoader(
    evaluation,
    batch_sampler=NShotTaskSampler(evaluation, episodes_per_epoch, args.n_test, args.k_test, args.q_test),
    num_workers=4
)


#########
# Model #
#########
model = get_few_shot_encoder(num_input_channels)
Пример #26
0
def evaluate_few_shot(state_dict,
                      n_shot,
                      k_way,
                      q_queries,
                      device,
                      architecture='resnet18',
                      pretrained=False,
                      small_dataset=False,
                      metric_name=None,
                      evaluation_episodes=1000,
                      num_input_channels=3,
                      distance='l2'):
    """Evaluate a prototypical network given by `state_dict` on the fashion
    evaluation classes.

    Args:
        state_dict: Weights loaded into the encoder (or torchvision
            backbone when `pretrained` is True).
        n_shot: Support examples per class in each episode.
        k_way: Classes per episode.
        q_queries: Query examples per class in each episode.
        device: Device the model and loss function are moved to.
        architecture: torchvision model name (only used when `pretrained`).
        pretrained: If True, use a torchvision backbone with its `fc`
            replaced by `Identity`; otherwise the project's few-shot encoder.
        small_dataset: Select the small image variant and a smaller resize.
        metric_name: Key for the accuracy entry in the returned dict;
            autogenerated when None.
        evaluation_episodes: Number of test episodes to sample.
        num_input_channels: Channels for the few-shot encoder.
        distance: Distance metric passed to `proto_net_episode`.

    Returns:
        dict with mean 'loss' and mean accuracy under `metric_name`.
    """
    if not pretrained:
        model = get_few_shot_encoder(num_input_channels)
        model.load_state_dict(state_dict)
    else:
        # assert torch.cuda.is_available()
        # Torchvision backbone with its classification head removed.
        model = models.__dict__[architecture](pretrained=True)
        model.fc = Identity()
        model.load_state_dict(state_dict)
    # Bug fix: the model was never moved to `device` although the loss
    # function was; move it so both live on the requested device.
    model = model.to(device)

    dataset_class = FashionProductImagesSmall if small_dataset \
        else FashionProductImages

    # Meta-test set
    resize = (80, 60) if small_dataset else (400, 300)
    evaluation_transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.ToTensor(),
    ])

    # Bug fix: `dataset_class` was computed above but ignored — the small
    # variant was hard-coded here, so `small_dataset=False` still evaluated
    # on FashionProductImagesSmall.
    evaluation = dataset_class(DATA_PATH,
                               split='all',
                               classes='evaluation',
                               transform=evaluation_transform)
    sampler = NShotTaskSampler(evaluation, evaluation_episodes, n_shot, k_way,
                               q_queries)
    taskloader = DataLoader(evaluation, batch_sampler=sampler, num_workers=4)
    prepare_batch = prepare_nshot_task(n_shot, k_way, q_queries)

    if metric_name is None:
        metric_name = f'test_{n_shot}-shot_{k_way}-way_acc'
    seen = 0
    totals = {'loss': 0, metric_name: 0}

    # NOTE(review): optimiser is required by proto_net_episode's signature;
    # with train=False below, presumably no updates occur — confirm there.
    optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss().to(device)

    # (Fixed: dropped the unused enumerate() index.)
    for batch in taskloader:
        x, y = prepare_batch(batch)

        loss, y_pred = proto_net_episode(model,
                                         optimiser,
                                         loss_fn,
                                         x,
                                         y,
                                         n_shot=n_shot,
                                         k_way=k_way,
                                         q_queries=q_queries,
                                         train=False,
                                         distance=distance)

        # Sample-weighted running sums; normalised after the loop.
        seen += y_pred.shape[0]

        totals['loss'] += loss.item() * y_pred.shape[0]
        totals[metric_name] += categorical_accuracy(y, y_pred) * \
                               y_pred.shape[0]

    totals['loss'] = totals['loss'] / seen
    totals[metric_name] = totals[metric_name] / seen

    return totals