コード例 #1
0
        base_datamgr = miniImageNet_few_shot.SetDataManager(
            image_size, n_query=n_query, **train_few_shot_params)
        base_loader = base_datamgr.get_data_loader(aug=params.train_aug)

        # use unlabeled data from these novel domains for adversarial domain adaptation
        # TODO: since the data is unlabeled, we need to modify data manager / data loader
        if params.dataset == "ChestX":
            target_datamgr = Chest_few_shot.SetDataManager(
                image_size, n_query=n_query, **train_few_shot_params)

        elif params.dataset == "EuroSAT":
            target_datamgr = EuroSAT_few_shot.SetDataManager(
                image_size, n_query=n_query, **train_few_shot_params)

        elif params.dataset == "ISIC2018":
            target_datamgr = ISIC_few_shot.SetDataManager(
                image_size, n_query=n_query, **train_few_shot_params)

        elif params.dataset == "CropDiseases":
            target_datamgr = CropDisease_few_show.SetDataManager(
                image_size, n_query=n_query, **train_few_shot_params)

        else:
            raise ValueError('Unknown dataset')

        target_loader = target_datamgr.get_data_loader(novel_file,
                                                       aug=params.train_aug)

        if params.adversarial or params.adaptFinetune:
            # TODO: check argv
            target_datamgr = SetDataManager(image_size,
                                            n_query=n_query,
コード例 #2
0
    iter_num = 600

    n_query = max(
        1, int(16 * params.test_n_way / params.train_n_way)
    )  #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    models_to_use = params.models_to_use
    finetune_each_model = params.fine_tune_all_models
    ##################################################################

    dataset_names = ["ISIC", "EuroSAT", "CropDisease", "Chest"]
    novel_loaders = []

    datamgr = ISIC_few_shot.SetDataManager(image_size,
                                           n_eposide=iter_num,
                                           n_query=15,
                                           **few_shot_params)
    novel_loader = datamgr.get_data_loader(aug=False)
    novel_loaders.append(novel_loader)

    datamgr = EuroSAT_few_shot.SetDataManager(image_size,
                                              n_eposide=iter_num,
                                              n_query=15,
                                              **few_shot_params)
    novel_loader = datamgr.get_data_loader(aug=False)
    novel_loaders.append(novel_loader)

    datamgr = CropDisease_few_shot.SetDataManager(image_size,
                                                  n_eposide=iter_num,
                                                  n_query=15,
                                                  **few_shot_params)
コード例 #3
0
ファイル: train_adapt.py プロジェクト: bigchou/ammai_hw2
                                                       n_query=n_query,
                                                       mode="train",
                                                       **train_few_shot_params)
        base_loader = datamgr.get_data_loader(aug=params.train_aug)
        #the above line waste a lot of time
        val_datamgr = miniImageNet_few_shot.SetDataManager(
            image_size, n_query=n_query, mode="val", **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(aug=False)

        #=========== ISIC or EuroSAT ========
        few_shot_params = dict(n_way=params.test_n_way,
                               n_support=params.n_shot)
        if params.method == 'mytpnadaptisic':
            print("init mytpnadaptisic")
            datamgr = ISIC_few_shot.SetDataManager(image_size,
                                                   n_eposide=600,
                                                   n_query=n_query,
                                                   **few_shot_params)
        elif params.method == 'mytpnadapteurosat':
            print("init mytpnadapteurosat")
            datamgr = EuroSAT_few_shot.SetDataManager(image_size,
                                                      n_eposide=600,
                                                      n_query=n_query,
                                                      **few_shot_params)
        else:
            raise ValueError('Unknown method')
        novel_loader = datamgr.get_data_loader(aug=False)
        #=====================================

        #======= load pretrained model and ckpt =======
        model = MyTPN_Adapt(model_dict[params.model], **train_few_shot_params)