예제 #1
0
def load_mask_detector():
    """Load a face-mask detection model, dispatching on the global ``flag``.

    Expects module-level globals ``flag`` ('torch' / 'fastai' / other) and
    ``flag_fewShots`` (truthy selects the few-shot checkpoint path) to be
    defined by the caller's module — TODO confirm where they are set.

    Returns:
        The loaded model (moved to CUDA in the few-shot torch branch).
    """
    ############### Change path to path of model ###################
    path_model = '/content/drive/MyDrive/frinks/fewShots/CloserLookFewShot/checkpoints_masks_Conv4_baseline_aug/20.tar'
    path_data = '/content/drive/MyDrive/frinks/Faces/data'
    # BUG FIX: the original used `flag is 'torch'`, which tests object
    # identity, not string equality — it only works by accident of string
    # interning and emits a SyntaxWarning on Python 3.8+. Use `==`.
    if flag == 'torch':
        if not flag_fewShots:
            # Whole serialized model checkpoint.
            model = torch.load(path_model)
        else:
            # Few-shot baseline: rebuild the architecture, then load weights.
            model = BaselineTrain(backbone.Conv4, 4)
            model_dict = torch.load(path_model)
            model.load_state_dict(model_dict['state'])
            model = model.cuda()
    elif flag == 'fastai':
        # fastai learner needs a DataBunch to reconstruct the head.
        data = ImageDataBunch.from_folder(path_data, valid_pct=0.2, size=120)
        model = cnn_learner(data, models.resnet101, metrics=error_rate)
        model.load(path_model)
    else:
        # Fallback loader for any other framework flag.
        model = LoadModel(path_model)
    return model
예제 #2
0
def get_model(params):
    """Build a baseline/baseline++ model and optionally resume or warm-start it.

    Args:
        params: experiment arguments providing ``method``, ``dataset``,
            ``model``, ``num_classes``, ``resume``, ``warmup``,
            ``checkpoint_dir``, ``train_aug`` and ``start_epoch``.

    Returns:
        The constructed (and possibly checkpoint-initialized) model.

    Raises:
        ValueError: if ``params.method`` is unknown, or warm-up was requested
            but no baseline checkpoint exists.
    """
    if params.method in ['baseline', 'baseline++']:
        # Label ids in the base split must fit within the classifier head.
        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            # baseline++ replaces the linear head with a cosine-distance head.
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist')
    else:
        raise ValueError('Unknown method')

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            params.start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
    elif params.warmup:  # We also support warmup from a pretrained baseline feature, but never used it in our paper
        baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        # BUG FIX: get_resume_file may return None; the original called
        # torch.load(warmup_resume_file) first, which raises before the
        # intended ValueError could ever fire (torch.load never returns None).
        if warmup_resume_file is None:
            raise ValueError('No warm_up file')
        tmp = torch.load(warmup_resume_file)
        state = tmp['state']
        for key in list(state.keys()):
            if "feature." in key:
                # The saved model stores the backbone under the 'feature'
                # attribute; strip the prefix ('feature.trunk.xx' -> 'trunk.xx')
                # so the weights load directly into the backbone.
                newkey = key.replace("feature.", "")
                state[newkey] = state.pop(key)
            else:
                # Drop classifier-head weights; only the backbone is warmed up.
                state.pop(key)
        model.feature.load_state_dict(state)

    return model
예제 #3
0
        val_datamgr = SimpleDataManager(image_size)
        val_loader = val_datamgr.get_data_loader(val_file,
                                                 batch_size=64,
                                                 aug=False)

        if params.dataset == "omniglot":
            assert (
                params.num_classes >= 4112
            ), "class number need to be larger than max label id in base class"
        if params.dataset == "cross_char":
            assert (
                params.num_classes >= 1597
            ), "class number need to be larger than max label id in base class"

        if params.method == "baseline":
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == "baseline++":
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type="dist")

    elif params.method in [
            "protonet",
            "matchingnet",
            "relationnet",
            "relationnet_softmax",
            "maml",
            "maml_approx",
    ]:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
예제 #4
0
            base_datamgr = caltech256_few_shot.SimpleDataManager(image_size,
                                                                 batch_size=16)
            base_loader = base_datamgr.get_data_loader(aug=False)
            params.num_classes = 257

        elif params.dataset == "DTD":
            base_datamgr = DTD_few_shot.SimpleDataManager(image_size,
                                                          batch_size=16)
            base_loader = base_datamgr.get_data_loader(aug=True)

        else:
            raise ValueError('Unknown dataset')

        #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        #print(device)
        model = BaselineTrain(model_dict[params.model], params.num_classes)

    elif params.method in [
            'dampnet_full_class', 'dampnet_full_sparse', 'protonet_damp',
            'maml', 'relationnet', 'dampnet_full', 'dampnet', 'protonet',
            'gnnnet', 'gnnnet_maml', 'metaoptnet', 'gnnnet_normalized',
            'gnnnet_neg_margin'
    ]:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)
예제 #5
0
    image_size = 224
    optimization = 'Adam'

    if params.method in ['baseline']:

        if params.dataset == "miniImageNet":

            base_file = configs.data_dir['miniImagenet'] + 'base.json'
            base_datamgr = SimpleDataManager(image_size, batch_size=16)
            base_loader = base_datamgr.get_data_loader(base_file,
                                                       aug=params.train_aug)
        else:
            raise ValueError('Unknown dataset')

        model = BaselineTrain(model_dict[params.model], params.num_classes)

    elif params.method in ['protonet']:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)

        if params.dataset == "miniImageNet":

            base_file = configs.data_dir['miniImagenet'] + 'base.json'
            base_datamgr = SetDataManager(image_size,
                                          n_query=n_query,
예제 #6
0
    image_size = 224
    optimization = 'Adam'

    if params.method in ['baseline', 'myModel']:

        if params.dataset == "miniImageNet":
            datamgr = miniImageNet_few_shot.SimpleDataManager(image_size,
                                                              batch_size=180)
            base_loader = datamgr.get_data_loader(aug=params.train_aug)
            val_loader = None
        else:
            raise ValueError('Unknown dataset')

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        else:
            model = MyModelTrain(model_dict[params.model], params.num_classes,
                                 params.margin, params.embed_dim,
                                 params.logit_scale)

    elif params.method in ['protonet', 'myprotonet']:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)

        if params.dataset == "miniImageNet":
def main_train(params):
    """Run one full training experiment described by ``params``.

    Resolves dataset split files, image size and stop epoch, instantiates the
    model for ``params.method``, optionally resumes or warm-starts from a
    checkpoint, trains, and saves the results log.

    Args:
        params: parsed experiment arguments (dataset, model, method, n_shot,
            train_n_way, test_n_way, train_aug, resume, warmup, ...).

    Raises:
        ValueError: if ``params.method`` is unknown, or warm-up was requested
            but no baseline checkpoint exists.
    """
    _set_seed(params)

    results_logger = ResultsLogger(params)

    # Resolve the base-class / validation split files for the dataset.
    if params.dataset == 'cross':
        base_file = configs.data_dir['miniImagenet'] + 'all.json'
        val_file = configs.data_dir['CUB'] + 'val.json'
    elif params.dataset == 'cross_char':
        base_file = configs.data_dir['omniglot'] + 'noLatin.json'
        val_file = configs.data_dir['emnist'] + 'val.json'
    else:
        base_file = configs.data_dir[params.dataset] + 'base.json'
        val_file = configs.data_dir[params.dataset] + 'val.json'
    # Conv backbones use small inputs; other backbones use 224x224.
    if 'Conv' in params.model:
        if params.dataset in ['omniglot', 'cross_char']:
            image_size = 28
        else:
            image_size = 84
    else:
        image_size = 224
    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'
    optimization = 'Adam'
    # Fill in a method/dataset-dependent default when stop_epoch is unset (-1).
    if params.stop_epoch == -1:
        if params.method in ['baseline', 'baseline++']:
            if params.dataset in ['omniglot', 'cross_char']:
                params.stop_epoch = 5
            elif params.dataset in ['CUB']:
                params.stop_epoch = 200  # This is different as stated in the open-review paper. However, using 400 epoch in baseline actually lead to over-fitting
            elif params.dataset in ['miniImagenet', 'cross']:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 400  # default
        else:  # meta-learning methods
            if params.n_shot == 1:
                params.stop_epoch = 600
            elif params.n_shot == 5:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 600  # default
    if params.method in ['baseline', 'baseline++']:
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        # Label ids in the base split must fit within the classifier head.
        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            # baseline++ uses a cosine-distance classification head.
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist')

    elif params.method in [
            'DKT', 'protonet', 'matchingnet', 'relationnet',
            'relationnet_softmax', 'maml', 'maml_approx'
    ]:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  # if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small

        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        base_datamgr = SetDataManager(image_size,
                                      n_query=n_query,
                                      **train_few_shot_params)  # n_eposide=100
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)

        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        # a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor

        if params.method == 'DKT':
            model = DKT(model_dict[params.model], **train_few_shot_params)
            model.init_summary()
        elif params.method == 'protonet':
            model = ProtoNet(model_dict[params.model], **train_few_shot_params)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model],
                                **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            # RelationNet needs non-flattened (spatial) backbone features.
            if params.model == 'Conv4':
                feature_model = backbone.Conv4NP
            elif params.model == 'Conv6':
                feature_model = backbone.Conv6NP
            elif params.model == 'Conv4S':
                feature_model = backbone.Conv4SNP
            else:
                feature_model = lambda: model_dict[params.model](flatten=False)
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'

            model = RelationNet(feature_model,
                                loss_type=loss_type,
                                **train_few_shot_params)
        elif params.method in ['maml', 'maml_approx']:
            # Enable MAML-style fast-weight updates in all backbone blocks.
            backbone.ConvBlock.maml = True
            backbone.SimpleBlock.maml = True
            backbone.BottleneckBlock.maml = True
            backbone.ResNet.maml = True
            model = MAML(model_dict[params.model],
                         approx=(params.method == 'maml_approx'),
                         **train_few_shot_params)
            if params.dataset in [
                    'omniglot', 'cross_char'
            ]:  # maml use different parameter in omniglot
                model.n_task = 32
                model.task_update_num = 1
                model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')
    model = model.cuda()
    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    if params.method not in ['baseline', 'baseline++']:
        params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way,
                                                    params.n_shot)
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)
    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method == 'maml' or params.method == 'maml_approx':
        stop_epoch = params.stop_epoch * model.n_task  # maml use multiple tasks in one update
    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
    elif params.warmup:  # We also support warmup from pretrained baseline feature, but we never used in our paper
        baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        # BUG FIX: get_resume_file may return None; the original called
        # torch.load(warmup_resume_file) first, which raises before the
        # intended ValueError could ever fire (torch.load never returns None).
        if warmup_resume_file is None:
            raise ValueError('No warm_up file')
        tmp = torch.load(warmup_resume_file)
        state = tmp['state']
        for key in list(state.keys()):
            if "feature." in key:
                # The saved model stores the backbone under the 'feature'
                # attribute; strip the prefix ('feature.trunk.xx' -> 'trunk.xx')
                # so the weights load directly into the backbone.
                newkey = key.replace("feature.", "")
                state[newkey] = state.pop(key)
            else:
                # Drop classifier-head weights; only the backbone is warmed up.
                state.pop(key)
        model.feature.load_state_dict(state)

    model = train(base_loader, val_loader, model, optimization, start_epoch,
                  stop_epoch, params, results_logger)
    results_logger.save()
예제 #8
0
파일: train_moco.py 프로젝트: hyoje42/S2M2
            else:
                print(f'{key} will be removed')
                del state[key]
        msg = model_moco.load_state_dict(state, strict=False)
        assert len(msg.missing_keys) == 0 and len(msg.unexpected_keys) == 0, "loading model is wrong"
        # get bottom of ResNet
        encoder = moco.ResNetBottom(model_moco.encoder_q)

    if params.method in ['baseline', 'baseline++'] :
        base_datamgr    = SimpleDataManager(image_size, batch_size = 64)
        base_loader     = base_datamgr.get_data_loader( base_file , aug = params.train_aug )
        val_datamgr     = SimpleDataManager(image_size, batch_size = 256)
        val_loader      = val_datamgr.get_data_loader( val_file, aug = False)
        
        if params.method == 'baseline':
            model           = BaselineTrain( model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            # model           = BaselineTrain( model_dict[params.model], params.num_classes, loss_type = 'dist')
            model           = BaselineTrain( encoder, params.num_classes, loss_type = 'dist')

    elif params.method in ['protonet','matchingnet','relationnet', 'relationnet_softmax', 'maml', 'maml_approx']:
        n_query = max(1, int(16* params.test_n_way/params.train_n_way)) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
 
        train_few_shot_params    = dict(n_way = params.train_n_way, n_support = params.n_shot) 
        base_datamgr            = SetDataManager(image_size, n_query = n_query,  **train_few_shot_params)
        base_loader             = base_datamgr.get_data_loader( base_file , aug = params.train_aug )
         
        test_few_shot_params     = dict(n_way = params.test_n_way, n_support = params.n_shot) 
        val_datamgr             = SetDataManager(image_size, n_query = n_query, **test_few_shot_params)
        val_loader              = val_datamgr.get_data_loader( val_file, aug = False) 
        #a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor        
예제 #9
0
            params.stop_epoch = 40  #default
        else:  #meta-learning methods
            params.stop_epoch = 60  #default

    if params.method in ['baseline', 'baseline++']:
        base_datamgr = SimpleDataManager(batch_size=16)
        base_loader = base_datamgr.get_data_loader(
            root='./filelists/tabula_muris', mode='train')
        val_datamgr = SimpleDataManager(batch_size=64)
        val_loader = val_datamgr.get_data_loader(
            root='./filelists/tabula_muris', mode='val')

        x_dim = base_loader.dataset.get_dim()

        if params.method == 'baseline':
            model = BaselineTrain(backbone.FCNet(x_dim), params.num_classes)
        elif params.method == 'baseline++':
            model = BaselineTrain(backbone.FCNet(x_dim),
                                  params.num_classes,
                                  loss_type='dist')

    elif params.method in [
            'protonet', 'comet', 'matchingnet', 'relationnet',
            'relationnet_softmax', 'maml', 'maml_approx'
    ]:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        base_datamgr = SetDataManager(n_query=n_query, **train_few_shot_params)
예제 #10
0
    base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)

    test_few_shot_params = dict(n_way=params.test_n_way,
                                n_support=params.n_shot)
    val_datamgr = SetDataManager(image_size,
                                 n_query=n_query,
                                 **test_few_shot_params)
    val_loader = val_datamgr.get_data_loader(val_file, aug=False)

    if params.method == 'manifold_mixup':
        model = wrn_mixup_model.wrn28_10(64)
    elif params.method == 'S2M2_R':
        model = ProtoNet(model_dict[params.model], params.train_n_way,
                         params.n_shot)
    elif params.method == 'rotation':
        model = BaselineTrain(model_dict[params.model], 64, loss_type='dist')

    if params.method == 'S2M2_R':

        if use_gpu:
            if torch.cuda.device_count() > 1:
                model = torch.nn.DataParallel(model,
                                              device_ids=range(
                                                  torch.cuda.device_count()))
            model.cuda()

        if params.resume:
            resume_file = get_resume_file(params.checkpoint_dir)
            print("resume_file", resume_file)
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
예제 #11
0
파일: train.py 프로젝트: bigchou/ammai_hw2
            datamgr = miniImageNet_few_shot.SimpleDataManager(image_size,
                                                              batch_size=16)
            print(
                "datamgr = miniImageNet_few_shot.SimpleDataManager(image_size, batch_size = 16)"
            )
            base_loader = datamgr.get_data_loader(
                aug=params.train_aug)  #waste lots of time
            print(
                "base_loader = datamgr.get_data_loader(aug = params.train_aug )"
            )
            val_loader = None
            print("load miniIMageNet [END]")
        else:
            raise ValueError('Unknown dataset')
        print("load model [START]")
        model = BaselineTrain(model_dict[params.model], params.num_classes)
        print("load model [END]")

    elif params.method in ['protonet']:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)

        if params.dataset == "miniImageNet":

            datamgr = miniImageNet_few_shot.SetDataManager(
                image_size,
예제 #12
0
    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch

    base_datamgr = SimpleDataManager(image_size, batch_size=params.batch_size)
    base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)
    val_datamgr = SimpleDataManager(image_size,
                                    batch_size=params.test_batch_size)
    val_loader = base_datamgr.get_data_loader(base_file, aug=False)

    if params.method == 'manifold_mixup':
        model = wrn_mixup_model.wrn28_10(64, 0.9)
    elif params.method == 'S2M2_R':
        model = wrn_mixup_model.wrn28_10(64, 0.9)
    elif params.method == 'rotation':
        model = BaselineTrain(model_dict[params.model],
                              64,
                              dropRate=0.9,
                              loss_type='dist')

    if params.method == 'S2M2_R':

        if use_gpu:
            if torch.cuda.device_count() > 1:
                model = torch.nn.DataParallel(model,
                                              device_ids=range(
                                                  torch.cuda.device_count()))
            model.cuda()

        if params.resume:
            resume_file = get_resume_file(params.checkpoint_dir)
            print("resume_file", resume_file)
            tmp = torch.load(resume_file)
예제 #13
0
def select_model(params):
    """Instantiate the training model requested by ``params``.

    For baseline methods the classifier size is derived from the dataset; for
    meta-learning methods the episode configuration (n_way / n_support plus the
    auxiliary-task options) is forwarded to the method's constructor.

    Args:
        params: experiment arguments (method, model, dataset, num_classes,
            train_n_way, n_shot, jigsaw, lbda, rotation, tracking, ...).

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``params.method`` is not recognized.
    """
    if params.method in ['baseline', 'baseline++']:
        # Known datasets fix the classifier size; anything else keeps the
        # num_classes already present on params.
        class_counts = {
            'CUB': 200,
            'cars': 196,
            'aircrafts': 100,
            'dogs': 120,
            'flowers': 102,
            'miniImagenet': 100,
            'tieredImagenet': 608,
        }
        if params.dataset in class_counts:
            params.num_classes = class_counts[params.dataset]

        # Auxiliary self-supervision options shared by both baselines.
        aux_kwargs = dict(jigsaw=params.jigsaw,
                          lbda=params.lbda,
                          rotation=params.rotation,
                          tracking=params.tracking)
        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes, **aux_kwargs)
        elif params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist', **aux_kwargs)

    elif params.method in [
            'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax',
            'maml', 'maml_approx'
    ]:
        episode_kwargs = dict(n_way=params.train_n_way,
                              n_support=params.n_shot,
                              jigsaw=params.jigsaw,
                              lbda=params.lbda,
                              rotation=params.rotation)
        if params.method == 'protonet':
            model = ProtoNet(model_dict[params.model],
                             **episode_kwargs,
                             use_bn=(not params.no_bn),
                             pretrain=params.pretrain,
                             tracking=params.tracking)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model], **episode_kwargs)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            # RelationNet consumes non-flattened (spatial) backbone features.
            def feature_model():
                return model_dict[params.model](flatten=False)

            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
            model = RelationNet(feature_model,
                                loss_type=loss_type,
                                **episode_kwargs)
        elif params.method in ['maml', 'maml_approx']:
            # Switch every backbone block into MAML fast-weight mode.
            for maml_cls in (backbone.ConvBlock, backbone.SimpleBlock,
                             backbone.BottleneckBlock, backbone.ResNet,
                             BasicBlock, Bottleneck, ResNet):
                maml_cls.maml = True

            model = MAML(model_dict[params.model],
                         approx=(params.method == 'maml_approx'),
                         **episode_kwargs)

    else:
        raise ValueError('Unknown method')
    return model
예제 #14
0
    if 'Conv' in params.model:
        image_size = 84
    else:
        image_size = 224

    if params.method in ['baseline', 'baseline++']:
        print('  pre-training the feature encoder {} using method {}'.format(
            params.model, params.method))
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  tf_path=params.tf_dir)
        elif params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist',
                                  tf_path=params.tf_dir)

    elif params.method in [
            'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax',
            'gnnnet'
    ]:
        print(
            '  baseline training the model {} with feature encoder {}'.format(
                params.method, params.model))
예제 #15
0
    print('Running up to {} epochs'.format(params.stop_epoch))
    device = torch.device("cuda:0" if torch.cuda.device_count() > 0 else "cpu")
        
    if params.method in ['baseline', 'baseline++'] :
        base_datamgr    = SimpleDataManager(image_size, batch_size = 16)
        base_loader     = base_datamgr.get_data_loader( base_file , aug = params.train_aug )
        val_datamgr     = SimpleDataManager(image_size, batch_size = 64)
        val_loader      = val_datamgr.get_data_loader( val_file, aug = False)
        
        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'

        if params.method == 'baseline':
            model           = BaselineTrain( model_dict[params.model], params.num_classes, device=device)
        elif params.method == 'baseline++':
            model           = BaselineTrain( model_dict[params.model], params.num_classes, 
                loss_type = 'dist', init_orthogonal = params.ortho, ortho_reg = params.ortho_reg, device=device)

    elif params.method in ['protonet','matchingnet','relationnet', 'relationnet_softmax', 'maml', 'maml_approx']:
        n_query = max(1, int(16* params.test_n_way/params.train_n_way)) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
 
        train_few_shot_params    = dict(n_way = params.train_n_way, n_support = params.n_shot) 
        base_datamgr            = SetDataManager(image_size, n_query = n_query,  **train_few_shot_params)
        base_loader             = base_datamgr.get_data_loader( base_file , aug = params.train_aug )
         
        test_few_shot_params     = dict(n_way = params.test_n_way, n_support = params.n_shot) 
        val_datamgr             = SetDataManager(image_size, n_query = n_query, **test_few_shot_params)
        val_loader              = val_datamgr.get_data_loader( val_file, aug = False) 
        #a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor        
예제 #16
0
            params.num_classes = 200
        elif params.dataset == 'cars_original':
            params.num_classes = 196
        elif params.dataset == 'aircrafts_original':
            params.num_classes = 100
        elif params.dataset == 'dogs_original':
            params.num_classes = 120
        elif params.dataset == 'flowers_original':
            params.num_classes = 102
        elif params.dataset == 'miniImagenet':
            params.num_classes = 100
        elif params.dataset == 'tieredImagenet':
            params.num_classes = 608

        if params.method == 'baseline':
            model           = BaselineTrain( model_dict[params.model], params.num_classes, \
                                            jigsaw=params.jigsaw, lbda=params.lbda, rotation=params.rotation, tracking=params.tracking)
        elif params.method == 'baseline++':
            model           = BaselineTrain( model_dict[params.model], params.num_classes, \
                                            loss_type = 'dist', jigsaw=params.jigsaw, lbda=params.lbda, rotation=params.rotation, tracking=params.tracking)

    elif params.method in [
            'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax',
            'maml', 'maml_approx'
    ]:
        n_query = max(
            1, int(params.n_query * params.test_n_way / params.train_n_way)
        )  #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
        print('n_query:', n_query)

        base_datamgr_u = SimpleDataManager(image_size,
                                           batch_size=params.bs,
예제 #17
0
def main():
    """Train (or test/resume) a few-shot model end to end.

    Builds data loaders, the encoder/method model, optimizer and scheduler
    from the argparse namespace returned by ``init()``, then runs the epoch
    loop with validation, best-checkpoint saving and a final test pass.

    Side effects: writes TensorBoard scalars, saves checkpoints to disk,
    and moves the model to CUDA.
    """
    timer = Timer()
    args, writer = init()

    train_file = args.dataset_dir + 'train.json'
    val_file = args.dataset_dir + 'val.json'

    few_shot_params = dict(n_way=args.n_way, n_support=args.n_shot, n_query=args.n_query)
    # Fewer episodes per epoch in debug mode keeps iterations quick.
    n_episode = 10 if args.debug else 100
    if args.method_type is Method_type.baseline:
        # Plain mini-batch classification pre-training.
        train_datamgr = SimpleDataManager(train_file, args.dataset_dir, args.image_size, batch_size=64)
        train_loader = train_datamgr.get_data_loader(aug=True)
    else:
        # Episodic sampling for meta-learning methods.
        train_datamgr = SetDataManager(train_file, args.dataset_dir, args.image_size,
                                       n_episode=n_episode, mode='train', **few_shot_params)
        train_loader = train_datamgr.get_data_loader(aug=True)

    # Validation is always episodic so the metric matches the test protocol.
    val_datamgr = SetDataManager(val_file, args.dataset_dir, args.image_size,
                                 n_episode=n_episode, mode='val', **few_shot_params)
    val_loader = val_datamgr.get_data_loader(aug=False)

    if args.model_type is Model_type.ConvNet:
        # BUG FIX: the original branch was `pass`, which left `encoder`
        # unbound and crashed later with an opaque NameError at model
        # construction; fail fast with a clear message instead.
        raise NotImplementedError('ConvNet encoder is not implemented')
    elif args.model_type is Model_type.ResNet12:
        from methods.backbone import ResNet12
        encoder = ResNet12()
    else:
        raise ValueError('')

    if args.method_type is Method_type.baseline:
        from methods.baselinetrain import BaselineTrain
        model = BaselineTrain(encoder, args)
    elif args.method_type is Method_type.protonet:
        from methods.protonet import ProtoNet
        model = ProtoNet(encoder, args)
    else:
        raise ValueError('')

    from torch.optim import SGD, lr_scheduler
    if args.method_type is Method_type.baseline:
        optimizer = SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch, eta_min=0, last_epoch=-1)
    else:
        optimizer = SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4,
                        nesterov=True)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)

    args.ngpu = torch.cuda.device_count()
    torch.backends.cudnn.benchmark = True
    model = model.cuda()

    # Query labels are identical for every episode: [0..0, 1..1, ..., n_way-1..].
    label = torch.from_numpy(np.repeat(range(args.n_way), args.n_query))
    label = label.cuda()

    if args.test:
        test(model, label, args, few_shot_params)
        return

    resume_OK = resume_model(model, optimizer, args, scheduler) if args.resume else False
    # Warm-start from pretrained weights only when no checkpoint was resumed.
    if (not resume_OK) and (args.warmup is not None):
        load_pretrained_weights(model, args)

    if args.debug:
        # Run exactly one epoch when debugging.
        args.max_epoch = args.start_epoch + 1

    for epoch in range(args.start_epoch, args.max_epoch):
        train_one_epoch(model, optimizer, args, train_loader, label, writer, epoch)
        scheduler.step()

        vl, va = val(model, args, val_loader, label)
        if writer is not None:
            writer.add_scalar('data/val_acc', float(va), epoch)
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))

        # Keep both the best-accuracy checkpoint and the latest epoch.
        if va >= args.max_acc:
            args.max_acc = va
            print('saving the best model! acc={:.4f}'.format(va))
            save_model(model, optimizer, args, epoch, args.max_acc, 'max_acc', scheduler)
        save_model(model, optimizer, args, epoch, args.max_acc, 'epoch-last', scheduler)
        if epoch != 0:
            print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
    if writer is not None:
        writer.close()
    test(model, label, args, few_shot_params)
예제 #18
0
            base_loader = base_datamgr.get_data_loader(base_file,
                                                       aug=params.train_aug)
            base_datamgr_test = SimpleDataManager(
                image_size, batch_size=params.test_batch_size)
            base_loader_test = base_datamgr_test.get_data_loader(base_file,
                                                                 aug=False)
            test_few_shot_params = dict(n_way=params.train_n_way,
                                        n_support=params.n_shot)
            val_datamgr = SetDataManager(image_size,
                                         n_query=15,
                                         **test_few_shot_params)
            val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist')

        elif params.method == 'manifold_mixup':
            if params.model == 'WideResNet28_10':
                model = wrn_mixup_model.wrn28_10(params.num_classes)
            elif params.model == 'ResNet18':
                model = res_mixup_model.resnet18(
                    num_classes=params.num_classes)

        elif params.method == 'S2M2_R' or 'rotation':
            if params.model == 'WideResNet28_10':
                model = wrn_mixup_model.wrn28_10(
                    num_classes=params.num_classes,
                    dct_status=params.dct_status)
            elif params.model == 'ResNet18':
예제 #19
0
def get_model(params, mode):
    '''Build and return the few-shot model specified by ``params``.

    Args:
        params: argparse params
        mode: (str), 'train', 'test'

    Returns:
        The instantiated (not yet trained/loaded) model.

    Raises:
        ValueError: for an unknown method, an unsupported decoder, or a
            method/mode/setting combination for which no model exists.
    '''
    print('get_model() start...')
    few_shot_params = get_few_shot_params(params, mode)

    if 'omniglot' in params.dataset or 'cross_char' in params.dataset:
        # Omniglot-style data only supports the single-channel Conv4 family,
        # without augmentation.
        assert 'Conv4' in params.model and not params.train_aug ,'omniglot/cross_char only support Conv4 without augmentation'
        params.model = params.model.replace('Conv4', 'Conv4S') # because Conv4Drop should also be Conv4SDrop
        if params.recons_decoder is not None:
            if 'ConvS' not in params.recons_decoder:
                raise ValueError('omniglot / cross_char should use ConvS/HiddenConvS decoder.')

    if params.method in ['baseline', 'baseline++'] and mode=='train':
        # The classification head must cover every base-class label id.
        assert params.num_classes >= n_base_classes[params.dataset]

    if params.recons_decoder is None:  # FIX: was `== None`; identity check per PEP 8
        print('params.recons_decoder == None')
        recons_decoder = None
    else:
        recons_decoder = decoder_dict[params.recons_decoder]
        print('recons_decoder:\n',recons_decoder)

    backbone_func = get_backbone_func(params)

    model = None  # exactly one branch below should assign this
    if 'baseline' in params.method:
        loss_types = {
            'baseline':'softmax', 
            'baseline++':'dist', 
        }
        loss_type = loss_types[params.method]

        if recons_decoder is None and params.min_gram is None: # default baseline/baseline++
            if mode == 'train':
                model = BaselineTrain(
                    model_func = backbone_func, loss_type = loss_type, 
                    num_class = params.num_classes, **few_shot_params)
            elif mode == 'test':
                model = BaselineFinetune(
                    model_func = backbone_func, loss_type = loss_type, 
                    **few_shot_params, finetune_dropout_p = params.finetune_dropout_p)
        elif params.min_gram is not None:
            min_gram_params = {
                'min_gram':params.min_gram, 
                'lambda_gram':params.lambda_gram, 
            }
            if mode == 'train':
                model = BaselineTrainMinGram(
                    model_func = backbone_func, loss_type = loss_type, 
                    num_class = params.num_classes, **few_shot_params, **min_gram_params)
            elif mode == 'test':
                # No MinGram-specific finetune class; plain finetuning is used.
                model = BaselineFinetune(
                    model_func = backbone_func, loss_type = loss_type, 
                    **few_shot_params, finetune_dropout_p = params.finetune_dropout_p)
        else:
            # FIX: the original silently fell through here (baseline with a
            # reconstruction decoder) and crashed later with a NameError.
            raise ValueError('baseline/baseline++ do not support a reconstruction decoder')

    elif params.method == 'protonet':
        if recons_decoder is None and params.min_gram is None:
            # default ProtoNet
            model = ProtoNet( backbone_func, **few_shot_params )
        else: # other ProtoNet variants
            if params.min_gram is not None:
                min_gram_params = {
                    'min_gram':params.min_gram, 
                    'lambda_gram':params.lambda_gram, 
                }
                model = ProtoNetMinGram(backbone_func, **few_shot_params, **min_gram_params)

            # NOTE: if both min_gram and recons_decoder are set, the decoder
            # variant below deliberately takes precedence (matches original order).
            if params.recons_decoder is not None:
                if 'Hidden' in params.recons_decoder:
                    # Hidden* decoders reconstruct from an intermediate layer;
                    # extract_layer picks which one.
                    if params.recons_decoder == 'HiddenConv':
                        model = ProtoNetAE2(backbone_func, **few_shot_params, recons_func=recons_decoder, lambda_d=params.recons_lambda, extract_layer = 2)
                    elif params.recons_decoder == 'HiddenConvS':
                        model = ProtoNetAE2(backbone_func, **few_shot_params, recons_func=recons_decoder, lambda_d=params.recons_lambda, extract_layer = 2, is_color=False)
                    elif params.recons_decoder == 'HiddenRes10':
                        model = ProtoNetAE2(backbone_func, **few_shot_params, recons_func=recons_decoder, lambda_d=params.recons_lambda, extract_layer = 6)
                    elif params.recons_decoder == 'HiddenRes18':
                        model = ProtoNetAE2(backbone_func, **few_shot_params, recons_func=recons_decoder, lambda_d=params.recons_lambda, extract_layer = 8)
                else:
                    # Input-level reconstruction; ConvS decoders are grayscale.
                    if 'ConvS' in params.recons_decoder:
                        model = ProtoNetAE(backbone_func, **few_shot_params, recons_func=recons_decoder, lambda_d=params.recons_lambda, is_color=False)
                    else:
                        model = ProtoNetAE(backbone_func, **few_shot_params, recons_func=recons_decoder, lambda_d=params.recons_lambda, is_color=True)
    elif params.method == 'matchingnet':
        model           = MatchingNet( backbone_func, **few_shot_params )
    elif params.method in ['relationnet', 'relationnet_softmax']:
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model           = RelationNet( backbone_func, loss_type = loss_type , **few_shot_params )
    elif params.method in ['maml' , 'maml_approx']:
        # MAML needs the fast-weight-capable backbone blocks.
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model           = MAML(  backbone_func, approx = (params.method == 'maml_approx') , **few_shot_params )
        if 'omniglot' in params.dataset or 'cross_char' in params.dataset:
            # MAML uses different hyper-parameters on omniglot-style data.
            model.n_task     = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unexpected params.method: %s'%(params.method))

    if model is None:
        # FIX: the original returned an unbound `model` (NameError) when no
        # branch matched, e.g. an unknown decoder name or mode; fail clearly.
        raise ValueError('No model could be built for method=%s, mode=%s'%(params.method, mode))

    print('get_model() finished.')
    return model
예제 #20
0
                params.stop_epoch = 600  #default

    if params.method in ['baseline', 'baseline++']:
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist')

    elif params.method in [
            'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax',
            'maml', 'maml_approx'
    ]:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small

        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
예제 #21
0
            datamgr = miniImageNet_few_shot.SimpleDataManager(image_size,
                                                              batch_size=16)
            print(
                "datamgr = miniImageNet_few_shot.SimpleDataManager(image_size, batch_size = 16)"
            )
            base_loader = datamgr.get_data_loader(
                aug=params.train_aug)  #waste lots of time
            print(
                "base_loader = datamgr.get_data_loader(aug = params.train_aug )"
            )
            val_loader = None
            print("load miniIMageNet [END]")
        else:
            raise ValueError('Unknown dataset')
        print("load model [START]")
        model = BaselineTrain(model_dict[params.model], params.num_classes)
        print("load model [END]")

    elif params.method in ['protonet']:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)

        if params.dataset == "miniImageNet":

            datamgr = miniImageNet_few_shot.SetDataManager(
                image_size,