'dampnet_full_class', 'dampnet_full_sparse', 'protonet_damp',
            'maml', 'relationnet', 'dampnet_full', 'dampnet', 'protonet',
            'gnnnet', 'gnnnet_maml', 'metaoptnet', 'gnnnet_normalized',
            'gnnnet_neg_margin'
    ]:
        # Episode sizing: start from 16 query images, scale by the
        # test/train way ratio, truncate with int() and clamp to at least 1.
        # If test_n_way is smaller than train_n_way this reduces n_query to
        # keep the batch size small; if it is larger, n_query grows above 16.
        # NOTE(review): presumably n_query is counted per class -- confirm
        # against SetDataManager.
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  # never fewer than one query image
        # Keyword bundles (n_way / n_support) that are splatted into the
        # constructors below. Only the train bundle is consumed inside this
        # span; test_few_shot_params is not referenced here.
        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)

        # Only miniImageNet is supported; any other dataset name raises
        # ValueError in the else branch.
        if params.dataset == "miniImageNet":
            print("loading")
            # Build the training loader from the train-way/shot settings;
            # augmentation is controlled by params.train_aug.
            datamgr = miniImageNet_few_shot.SetDataManager(
                image_size, n_query=n_query, **train_few_shot_params)
            base_loader = datamgr.get_data_loader(aug=params.train_aug)
            # Disabled alternative kept from the original (simple, non-episodic
            # manager with a fixed batch size and no augmentation):
            #datamgr         = miniImageNet_few_shot.SimpleDataManager(image_size, batch_size = 64)
            #data_loader     = datamgr.get_data_loader(aug = False )

            # NOTE(review): looks like leftover debug output -- confirm
            # whether it is still wanted.
            print("BYE")

        else:
            raise ValueError('Unknown dataset')

        if params.method == 'protonet':
            model = ProtoNet(model_dict[params.model], **train_few_shot_params)
        elif params.method == 'protonet_damp':
            model = protonet_damp.ProtoNet(model_dict[params.model],
                                           **train_few_shot_params)
        elif params.method == 'relationnet':
Exemple #2
0
    # Fixed constant; not referenced within this span.
    # NOTE(review): presumably a pseudo-label / pseudo-sample count used
    # further down -- confirm.
    n_pseudo = 100

    ##################################################################
    # Loading datasets: choose the evaluation dataset(s) and build the
    # corresponding loaders depending on `task`.
    pretrained_dataset = "miniImageNet"
    # Default list; both branches below overwrite it, so this assignment
    # only matters if code outside this span reads it first.
    dataset_names = ["EuroSAT", "ISIC"]

    novel_loaders = []
    if task == 'fsl':
        # In-domain few-shot evaluation: the backbone is always frozen and
        # the only dataset is miniImageNet (loaded with mode="test").
        freeze_backbone = True

        dataset_names = ["miniImageNet"]
        print("Loading mini-ImageNet")
        # `n_eposide` (sic) is the keyword exactly as the data manager is
        # called here; do not correct the spelling at the call site alone.
        datamgr = miniImageNet_few_shot.SetDataManager(image_size,
                                                       n_eposide=iter_num,
                                                       n_query=15,
                                                       mode="test",
                                                       **few_shot_params)
        novel_loader = datamgr.get_data_loader(aug=False)
        novel_loaders.append(novel_loader)
    else:
        # Any other task: freezing is taken from the command-line params and
        # the dataset list is EuroSAT + ISIC.
        freeze_backbone = params.freeze_backbone

        dataset_names = ["EuroSAT", "ISIC"]

        print("Loading EuroSAT")
        # Same episode settings as above (iter_num episodes, 15 query images,
        # no augmentation), but no `mode` argument is passed for EuroSAT.
        datamgr = EuroSAT_few_shot.SetDataManager(image_size,
                                                  n_eposide=iter_num,
                                                  n_query=15,
                                                  **few_shot_params)
        novel_loader = datamgr.get_data_loader(aug=False)
Exemple #3
0
                                 params.logit_scale)

    elif params.method in ['protonet', 'myprotonet']:
        # Episode sizing: start from 16 query images, scale by the
        # test/train way ratio, truncate with int() and clamp to at least 1.
        # If test_n_way is smaller than train_n_way this reduces n_query to
        # keep the batch size small; if it is larger, n_query grows above 16.
        # NOTE(review): presumably n_query is counted per class -- confirm
        # against SetDataManager.
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  # never fewer than one query image
        # Keyword bundles (n_way / n_support): the train bundle configures the
        # training manager, the test bundle configures the validation manager.
        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)

        # Only miniImageNet is supported; any other dataset name raises
        # ValueError in the else branch.
        if params.dataset == "miniImageNet":

            # Training loader: mode="train", augmentation controlled by
            # params.train_aug.
            datamgr = miniImageNet_few_shot.SetDataManager(
                image_size,
                n_query=n_query,
                mode="train",
                **train_few_shot_params)
            base_loader = datamgr.get_data_loader(aug=params.train_aug)
            # Validation loader: mode="val", test-way settings, never
            # augmented. The same n_query is reused for both loaders.
            val_datamgr = miniImageNet_few_shot.SetDataManager(
                image_size,
                n_query=n_query,
                mode="val",
                **test_few_shot_params)
            val_loader = val_datamgr.get_data_loader(aug=False)

        else:
            raise ValueError('Unknown dataset')

        if params.method == 'protonet':
            model = ProtoNet(model_dict[params.model], **train_few_shot_params)