Example #1
                                                                 aug=False)
            test_few_shot_params = dict(n_way=params.train_n_way,
                                        n_support=params.n_shot)
            val_datamgr = SetDataManager(image_size,
                                         n_query=15,
                                         **test_few_shot_params)
            val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist')

        elif params.method == 'manifold_mixup':
            if params.model == 'WideResNet28_10':
                model = wrn_mixup_model.wrn28_10(params.num_classes)
            elif params.model == 'ResNet18':
                model = res_mixup_model.resnet18(
                    num_classes=params.num_classes)

        elif params.method in ['S2M2_R', 'rotation']:
            if params.model == 'WideResNet28_10':
                model = wrn_mixup_model.wrn28_10(
                    num_classes=params.num_classes,
                    dct_status=params.dct_status)
            elif params.model == 'ResNet18':
                model = res_mixup_model.resnet18(
                    num_classes=params.num_classes)

        if params.method == 'baseline++':
            if use_gpu:
Example #2
    image_size = 32

    base_file = configs.data_dir[params.dataset] + 'base.json'
    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch

    base_datamgr = SimpleDataManager(image_size, batch_size=params.batch_size)
    base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)
    val_datamgr = SimpleDataManager(image_size,
                                    batch_size=params.test_batch_size)
    val_loader = val_datamgr.get_data_loader(base_file, aug=False)

    if params.method == 'manifold_mixup':
        model = wrn_mixup_model.wrn28_10(64)
    elif params.method == 'S2M2_R':
        model = wrn_mixup_model.wrn28_10(200)
    elif params.method == 'rotation':
        model = BaselineTrain(model_dict[params.model], 64, loss_type='dist')

    if params.method == 'S2M2_R':

        if use_gpu:
            if torch.cuda.device_count() > 1:
                model = torch.nn.DataParallel(model,
                                              device_ids=range(
                                                  torch.cuda.device_count()))
            model.cuda()

        if params.resume:
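
The snippet breaks off inside the resume branch. A minimal sketch of how such a branch is commonly completed, reusing get_resume_file and the 'state' checkpoint key that appear elsewhere on this page (Examples #3 and #8); the 'epoch' key is an assumption, not confirmed by these snippets:

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1        # assumed checkpoint key
            model.load_state_dict(tmp['state'])   # 'state' key as in Example #3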
Example #3
if __name__ == '__main__':
    params = parse_args('test')

    loadfile = configs.data_dir[params.dataset] + 'novel.json'
    print('loading data file from: ', loadfile)

    datamgr = SimpleDataManager(84, batch_size=256)
    novel_loader = datamgr.get_data_loader(loadfile, aug=False)

    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    modelfile = get_resume_file(checkpoint_dir)
    print('loading model file from: ', checkpoint_dir)

    if params.model == 'WideResNet28_10':
        model = wrn_mixup_model.wrn28_10(num_classes=params.num_classes)
    elif params.model == 'ResNet18':
        model = res_mixup_model.resnet18(num_classes=params.num_classes)

    model = model.cuda()
    cudnn.benchmark = True

    checkpoint = torch.load(modelfile)
    state = checkpoint['state']
    state_keys = list(state.keys())

    # a checkpoint saved through torch.nn.DataParallel has keys prefixed with 'module.'
    callwrap = 'module' in state_keys[0]
    if callwrap:
        model = WrappedModel(model)
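
Example #3 stops right after the model is wrapped. Checkpoints saved from a torch.nn.DataParallel model have every state-dict key prefixed with 'module.', which is what the state_keys[0] check detects; wrapping the bare model restores that prefix so load_state_dict matches again. The WrappedModel class itself is not shown in these snippets; a minimal sketch of such a wrapper, written here as an assumption, looks like:

    import torch.nn as nn

    class WrappedModel(nn.Module):
        # re-creates the 'module.' attribute that a DataParallel checkpoint expects
        def __init__(self, module):
            super(WrappedModel, self).__init__()
            self.module = module

        def forward(self, x):
            return self.module(x)

After wrapping, the usual next step is model.load_state_dict(checkpoint['state']).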
Example #4
        modelfile = get_resume_file(checkpoint_dir)

    if params.save_iter != -1:
        outfile = os.path.join(
            checkpoint_dir.replace("checkpoints", "features"),
            split + "_" + str(params.save_iter) + ".hdf5")
    else:
        outfile = os.path.join(
            checkpoint_dir.replace("checkpoints", "features"), split + ".hdf5")

    datamgr = SimpleDataManager(image_size, batch_size=3)
    data_loader = datamgr.get_data_loader(loadfile, aug=False)

    if params.method == 'manifold_mixup':
        if params.dataset == 'cifar':
            model = wrn_mixup_model.wrn28_10(64)
        else:
            model = wrn_mixup_model.wrn28_10(200)
    elif params.method == 'S2M2_R':
        if params.dataset == 'cifar':
            model = wrn_mixup_model.wrn28_10(64, loss_type='softmax')
        else:
            model = wrn_mixup_model.wrn28_10(200)
    else:
        model = model_dict[params.model]()

    print(checkpoint_dir, modelfile)
    if params.method == 'manifold_mixup' or params.method == 'S2M2_R':

        if modelfile is not None:
            if torch.cuda.is_available():
Example #5
    image_size = 32

    base_file = configs.data_dir[params.dataset] + 'base.json'
    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch

    base_datamgr = SimpleDataManager(image_size, batch_size=params.batch_size)
    base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)
    val_datamgr = SimpleDataManager(image_size,
                                    batch_size=params.test_batch_size)
    val_loader = val_datamgr.get_data_loader(base_file, aug=False)

    if params.method == 'manifold_mixup':
        model = wrn_mixup_model.wrn28_10(64, 0.9)
    elif params.method == 'S2M2_R':
        model = wrn_mixup_model.wrn28_10(64, 0.9)
    elif params.method == 'rotation':
        model = BaselineTrain(model_dict[params.model],
                              64,
                              dropRate=0.9,
                              loss_type='dist')

    if params.method == 'S2M2_R':

        if use_gpu:
            if torch.cuda.device_count() > 1:
                model = torch.nn.DataParallel(model,
                                              device_ids=range(
                                                  torch.cuda.device_count()))
Example #6

    if params.save_iter != -1:
        outfile = os.path.join(
            checkpoint_dir_plain.replace("checkpoints", "features"),
            split + "_" + str(params.save_iter) + "_both.hdf5")
    else:
        outfile = os.path.join(
            checkpoint_dir_plain.replace("checkpoints", "features"),
            split + "_both.hdf5")

    datamgr = SimpleDataManager_both(image_size, batch_size=1)

    print("out file: ", outfile)
    data_loader_dct = datamgr.get_data_loader_dct(loadfile, aug=False)
    data_loader_plain = datamgr.get_data_loader(loadfile, aug=False)

    if params.method == 'manifold_mixup':
        if params.dataset == 'cifar':
            model = wrn_mixup_model.wrn28_10(64)
        else:
            model = wrn_mixup_model.wrn28_10(200)
    elif params.method == 'S2M2_R':
        if params.dataset == 'cifar':
            model_plain = wrn_mixup_model.wrn28_10(64, dct_status=False)
            model_dct = wrn_mixup_model.wrn28_10(64, dct_status=True)
        else:
            model_plain = wrn_mixup_model.wrn28_10(200)
            model_dct = wrn_mixup_model.wrn28_10(200, dct_status=True)
    else:
        model_plain = model_dict[params.model]()
        model_dct = model_dict[params.model + 'dct']()

    # print(checkpoint_dir, modelfile)
    if params.method == 'manifold_mixup' or params.method == 'S2M2_R':
Example #7
    novel_loader_dct = datamgr.get_data_loader_dct(loadfile, aug=False)
    novel_loader_plain = datamgr.get_data_loader(loadfile, aug=False)

    checkpoint_dir_plain = '%s/checkpoints/%s/%s_%s_%sway_%sshot_aug' % (
        configs.save_dir, params.dataset, params.model, params.method,
        params.test_n_way, params.n_shot)
    checkpoint_dir_dct = '%s/checkpoints/%s/%s_%s_%sway_%sshot_aug_dct' % (
        configs.save_dir, params.dataset, params.model, params.method,
        params.test_n_way, params.n_shot)

    modelfile_plain = get_best_file(checkpoint_dir_plain)
    modelfile_dct = get_best_file(checkpoint_dir_dct)
    print(checkpoint_dir_plain, checkpoint_dir_dct)

    if params.model == 'WideResNet28_10':
        model_plain = wrn_mixup_model.wrn28_10(num_classes=params.num_classes,
                                               dct_status=False)
        model_dct = wrn_mixup_model.wrn28_10(num_classes=params.num_classes,
                                             dct_status=True)
    elif params.model == 'ResNet18':
        # both names are used below; the DCT-specific ResNet18 constructor is not
        # shown in this snippet, so the plain one is reused here
        model_plain = res_mixup_model.resnet18(num_classes=params.num_classes)
        model_dct = res_mixup_model.resnet18(num_classes=params.num_classes)

    model_plain = model_plain.cuda()
    model_dct = model_dct.cuda()
    cudnn.benchmark = True

    checkpoint_plain = torch.load(modelfile_plain)
    checkpoint_dct = torch.load(modelfile_dct)
    state_plain = checkpoint_plain['state']
    state_dct = checkpoint_dct['state']
    state_keys_plain = list(state_plain.keys())
    state_keys_dct = list(state_dct.keys())
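
Example #7 ends after collecting the two key lists. These are typically used to rename away the 'module.' prefix that DataParallel training leaves in the state dict before calling load_state_dict; the exact key handling of the original script is not shown, so the following is only an illustrative sketch:

    for key in state_keys_plain:
        if 'module.' in key:
            state_plain[key.replace('module.', '')] = state_plain.pop(key)
    model_plain.load_state_dict(state_plain)

    for key in state_keys_dct:
        if 'module.' in key:
            state_dct[key.replace('module.', '')] = state_dct.pop(key)
    model_dct.load_state_dict(state_dct)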
Example #8
    image_size = 32

    base_file = configs.data_dir[params.dataset] + 'base.json'
    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch

    base_datamgr = SimpleDataManager(image_size, batch_size=params.batch_size)
    base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)
    base_datamgr_test = SimpleDataManager(image_size,
                                          batch_size=params.test_batch_size)
    base_loader_test = base_datamgr_test.get_data_loader(base_file, aug=False)

    if params.model == 'WideResNet28_10':
        model = wrn_mixup_model.wrn28_10(num_classes=64)
    elif params.model == 'ResNet18':
        model = res_mixup_model.resnet18(num_classes=64)

    if params.method == 'S2M2_R':

        if use_gpu:
            if torch.cuda.device_count() > 1:
                model = torch.nn.DataParallel(model,
                                              device_ids=range(
                                                  torch.cuda.device_count()))
            model.cuda()

        if params.resume:
            resume_file = get_resume_file(params.checkpoint_dir)
            print("resume_file", resume_file)