Example #1
def get_model(params):
    if params.method in ['baseline', 'baseline++']:
        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'num_classes must be larger than the max label id in the base classes'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'num_classes must be larger than the max label id in the base classes'

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist')
    else:
        raise ValueError('Unknown method')

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            params.start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
    elif params.warmup:  # We also support warmup from a pretrained baseline feature extractor, but we never used it in our paper
        baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        if warmup_resume_file is not None:  # guard first: torch.load cannot take None
            tmp = torch.load(warmup_resume_file)
            state = tmp['state']
            state_keys = list(state.keys())
            for i, key in enumerate(state_keys):
                if "feature." in key:
                    newkey = key.replace(
                        "feature.", ""
                    )  # the checkpoint keeps the backbone under 'feature'; rename 'feature.trunk.xx' to 'trunk.xx' so it loads into model.feature
                    state[newkey] = state.pop(key)
                else:
                    state.pop(key)
            model.feature.load_state_dict(state)
        else:
            raise ValueError('No warm_up file')

    return model
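A minimal usage sketch (not from the original project): the params namespace and its values below are hypothetical, but they only use fields that get_model actually reads, and model_dict is assumed to map backbone names to constructors as in the code above.

from argparse import Namespace

# Hypothetical configuration; every field is read somewhere in get_model.
params = Namespace(
    method='baseline++',      # or 'baseline'
    dataset='omniglot',
    model='Conv4',            # must be a key of model_dict
    num_classes=4112,         # satisfies the omniglot assert above
    resume=False,
    warmup=False,
    train_aug=False,
    checkpoint_dir='./checkpoints/omniglot/Conv4_baseline++',
    start_epoch=0,
)
model = get_model(params)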
Example #2
def load_mask_detector():
    ############### Change this to the path of your model ###################
    # path_model='/content/drive/MyDrive/frinks/models/fastai_resnet101'
    path_model='/content/drive/MyDrive/frinks/fewShots/CloserLookFewShot/checkpoints_masks_Conv4_baseline_aug/20.tar'
    path_data='/content/drive/MyDrive/frinks/Faces/data'
    if flag == 'torch':
      if not flag_fewShots:
        model = torch.load(path_model)
      else:
        model = BaselineTrain(backbone.Conv4, 4)
        model_dict = torch.load(path_model)
        model.load_state_dict(model_dict['state'])
        model = model.cuda()
    elif flag == 'fastai':
      data = ImageDataBunch.from_folder(path_data, valid_pct=0.2, size = 120)
      model = cnn_learner(data, models.resnet101, metrics=error_rate)
      model.load(path_model)
    else:
      model = LoadModel(path_model)
    return model
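A minimal inference sketch, assuming the torch few-shot branch above was taken (flag == 'torch' and flag_fewShots), so the returned model is the BaselineTrain instance whose forward pass yields class scores; the input size is illustrative.

import torch

model = load_mask_detector()
model.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 84, 84).cuda()  # dummy RGB batch; 84x84 matches the Conv4 setting used elsewhere in these examples
    scores = model(x)                     # class scores over the 4 classes passed to BaselineTrain
    pred = scores.argmax(dim=1)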
Example #3
                                                    params.n_shot)

    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method == 'maml' or params.method == 'maml_approx':
        stop_epoch = params.stop_epoch * model.n_task  # MAML uses multiple tasks in one update

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
    elif params.warmup:  # We also support warmup from a pretrained baseline feature extractor, but we never used it in our paper
        baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        if warmup_resume_file is not None:  # guard first: torch.load cannot take None
            tmp = torch.load(warmup_resume_file)
            state = tmp['state']
            state_keys = list(state.keys())
            for i, key in enumerate(state_keys):
                if "feature." in key:
                    newkey = key.replace(
                        "feature.", ""
                    )  # the checkpoint keeps the backbone under 'feature'; rename 'feature.trunk.xx' to 'trunk.xx' so it loads into model.feature
Example #4
def main_train(params):
    _set_seed(params)

    results_logger = ResultsLogger(params)

    if params.dataset == 'cross':
        base_file = configs.data_dir['miniImagenet'] + 'all.json'
        val_file = configs.data_dir['CUB'] + 'val.json'
    elif params.dataset == 'cross_char':
        base_file = configs.data_dir['omniglot'] + 'noLatin.json'
        val_file = configs.data_dir['emnist'] + 'val.json'
    else:
        base_file = configs.data_dir[params.dataset] + 'base.json'
        val_file = configs.data_dir[params.dataset] + 'val.json'
    if 'Conv' in params.model:
        if params.dataset in ['omniglot', 'cross_char']:
            image_size = 28
        else:
            image_size = 84
    else:
        image_size = 224
    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only supports Conv4 without augmentation'
        params.model = 'Conv4S'
    optimization = 'Adam'
    if params.stop_epoch == -1:
        if params.method in ['baseline', 'baseline++']:
            if params.dataset in ['omniglot', 'cross_char']:
                params.stop_epoch = 5
            elif params.dataset in ['CUB']:
                params.stop_epoch = 200  # This differs from the open-review paper; using 400 epochs for the baseline actually leads to over-fitting
            elif params.dataset in ['miniImagenet', 'cross']:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 400  # default
        else:  # meta-learning methods
            if params.n_shot == 1:
                params.stop_epoch = 600
            elif params.n_shot == 5:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 600  # default
    if params.method in ['baseline', 'baseline++']:
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'num_classes must be larger than the max label id in the base classes'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'num_classes must be larger than the max label id in the base classes'

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist')

    elif params.method in [
            'DKT', 'protonet', 'matchingnet', 'relationnet',
            'relationnet_softmax', 'maml', 'maml_approx'
    ]:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  # if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small

        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        base_datamgr = SetDataManager(image_size,
                                      n_query=n_query,
                                      **train_few_shot_params)  # n_eposide=100
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)

        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        # a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor

        if (params.method == 'DKT'):
            model = DKT(model_dict[params.model], **train_few_shot_params)
            model.init_summary()
        elif params.method == 'protonet':
            model = ProtoNet(model_dict[params.model], **train_few_shot_params)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model],
                                **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            if params.model == 'Conv4':
                feature_model = backbone.Conv4NP
            elif params.model == 'Conv6':
                feature_model = backbone.Conv6NP
            elif params.model == 'Conv4S':
                feature_model = backbone.Conv4SNP
            else:
                feature_model = lambda: model_dict[params.model](flatten=False)
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'

            model = RelationNet(feature_model,
                                loss_type=loss_type,
                                **train_few_shot_params)
        elif params.method in ['maml', 'maml_approx']:
            backbone.ConvBlock.maml = True
            backbone.SimpleBlock.maml = True
            backbone.BottleneckBlock.maml = True
            backbone.ResNet.maml = True
            model = MAML(model_dict[params.model],
                         approx=(params.method == 'maml_approx'),
                         **train_few_shot_params)
            if params.dataset in [
                    'omniglot', 'cross_char'
            ]:  # MAML uses different hyperparameters on omniglot
                model.n_task = 32
                model.task_update_num = 1
                model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')
    model = model.cuda()
    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    if params.method not in ['baseline', 'baseline++']:
        params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way,
                                                    params.n_shot)
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)
    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method == 'maml' or params.method == 'maml_approx':
        stop_epoch = params.stop_epoch * model.n_task  # MAML uses multiple tasks in one update
    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
    elif params.warmup:  # We also support warmup from a pretrained baseline feature extractor, but we never used it in our paper
        baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        if warmup_resume_file is not None:  # guard first: torch.load cannot take None
            tmp = torch.load(warmup_resume_file)
            state = tmp['state']
            state_keys = list(state.keys())
            for i, key in enumerate(state_keys):
                if "feature." in key:
                    newkey = key.replace(
                        "feature.", ""
                    )  # the checkpoint keeps the backbone under 'feature'; rename 'feature.trunk.xx' to 'trunk.xx' so it loads into model.feature
                    state[newkey] = state.pop(key)
                else:
                    state.pop(key)
            model.feature.load_state_dict(state)
        else:
            raise ValueError('No warm_up file')

    model = train(base_loader, val_loader, model, optimization, start_epoch,
                  stop_epoch, params, results_logger)
    results_logger.save()
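A toy sketch of the 'feature.' renaming performed in the warmup branch above: only backbone entries survive, renamed so they match the keys expected by model.feature (the values here are placeholders).

state = {
    'feature.trunk.0.weight': 'w0',  # kept, renamed to 'trunk.0.weight'
    'feature.trunk.0.bias': 'b0',    # kept, renamed to 'trunk.0.bias'
    'classifier.weight': 'cw',       # dropped
}
for key in list(state.keys()):
    if 'feature.' in key:
        state[key.replace('feature.', '')] = state.pop(key)
    else:
        state.pop(key)
print(state)  # {'trunk.0.weight': 'w0', 'trunk.0.bias': 'b0'}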
Example #5
                model = res_mixup_model.resnet18(
                    num_classes=params.num_classes)

        if params.method == 'baseline++':
            if use_gpu:
                if torch.cuda.device_count() > 1:
                    model = torch.nn.DataParallel(
                        model, device_ids=range(torch.cuda.device_count()))
                model.cuda()

            if params.resume:
                resume_file = get_resume_file(params.checkpoint_dir)
                tmp = torch.load(resume_file)
                start_epoch = tmp['epoch'] + 1
                state = tmp['state']
                model.load_state_dict(state)
            model = torch.nn.DataParallel(model).cuda()
            cudnn.benchmark = True
            optimization = 'Adam'
            model = train_baseline(base_loader, base_loader_test, val_loader,
                                   model, start_epoch,
                                   start_epoch + stop_epoch, params, {})

        elif params.method == 'S2M2_R':
            if use_gpu:
                if torch.cuda.device_count() > 1:
                    model = torch.nn.DataParallel(
                        model, device_ids=range(torch.cuda.device_count()))
                model.cuda()

            if params.resume:
Example #6
    stop_epoch = params.stop_epoch
    if params.method == 'maml' or params.method == 'maml_approx':
        stop_epoch = params.stop_epoch * model.n_task  # MAML uses multiple tasks in one update

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model_state_dict = model.state_dict()
            pretr_dict = {
                k: v
                for k, v in tmp['state'].items() if k in model_state_dict
            }  # keep only checkpoint weights whose names exist in the current model
            model_state_dict.update(pretr_dict)
            model.load_state_dict(model_state_dict)
            # model.load_state_dict(tmp['state'])
    elif params.warmup:  # We also support warmup from a pretrained baseline feature extractor, but we never used it in our paper
        baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        if warmup_resume_file is not None:  # guard first: torch.load cannot take None
            tmp = torch.load(warmup_resume_file)
            state = tmp['state']
            state_keys = list(state.keys())
            for i, key in enumerate(state_keys):
                if "feature." in key:
                    newkey = key.replace(
                        "feature.", ""
Example #7
    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method == 'maml' or params.method == 'maml_approx':
        stop_epoch = params.stop_epoch * model.n_task  # MAML uses multiple tasks in one update

    ## Use Google paper setting
    if params.method == 'baseline' and 'original' in params.dataset:
        stop_epoch = int(20000 / len(base_loader))
        print('train 20000 iters, which is ' + str(stop_epoch) + ' epochs')

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
            del tmp
    elif params.warmup:  # We also support warmup from a pretrained baseline feature extractor, but we never used it in our paper
        baseline_checkpoint_dir = 'checkpoints/%s/%s_%s' % (
            params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        if warmup_resume_file is not None:  # guard first: torch.load cannot take None
            tmp = torch.load(warmup_resume_file)
            state = tmp['state']
            state_keys = list(state.keys())
            for i, key in enumerate(state_keys):
                if "feature." in key:
                    newkey = key.replace(
                        "feature.", ""
Example #8
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method == "maml" or params.method == "maml_approx":
        stop_epoch = (params.stop_epoch * model.n_task
                      )  # MAML uses multiple tasks in one update

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp["epoch"] + 1
            model.load_state_dict(tmp["state"])
    elif (
            params.warmup
    ):  # We also support warmup from a pretrained baseline feature extractor, but we never used it in our paper
        baseline_checkpoint_dir = "%s/checkpoints/%s/%s_%s" % (
            configs.save_dir,
            params.dataset,
            params.model,
            "baseline",
        )
        if params.train_aug:
            baseline_checkpoint_dir += "_aug"
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        if warmup_resume_file is not None:  # guard first: torch.load cannot take None
            tmp = torch.load(warmup_resume_file)
            state = tmp["state"]
Example #9
                               **train_few_shot_params)

    else:
        raise ValueError('Unknown method')

    model = model.cuda()
    save_dir = configs.save_dir

    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'

    if params.method not in ['baseline', 'myModel']:
        params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way,
                                                    params.n_shot)

    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    print('stop_epoch', stop_epoch)

    if configs.model_path is not None and os.path.exists(configs.model_path):
        checkpoint = torch.load(configs.model_path)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state'])
    model = train(base_loader, val_loader, model, optimization, start_epoch,
                  stop_epoch, params)
Example #10
                                 tf_path=params.tf_dir,
                                 **train_few_shot_params)
 else:
     raise ValueError('Unknown method')
 model = model.cuda()
 # print(model)
 # load model
 start_epoch = params.start_epoch
 stop_epoch = params.stop_epoch
 if params.resume != '':
     if params.resume == 'best':
         resume_file = get_best_file('%s/checkpoints/%s' %
                                     (params.save_dir, params.name))
         tmp = torch.load(resume_file)
         try:
             model.load_state_dict(tmp['state'])
         except RuntimeError:
             print('warning! RuntimeError when load_state_dict()!')
             model.load_state_dict(tmp['state'], strict=False)
         except KeyError:
              for k in tmp['model_state']:  # TODO: revise later
                 if 'running' in k:
                     tmp['model_state'][k] = tmp['model_state'][k].squeeze()
             model.load_state_dict(tmp['model_state'], strict=False)
         except:
             raise
          print('  resume training at epoch {} (model file {})'.
               format(start_epoch, params.resume))
     else:
         resume_file = get_resume_file(
             '%s/checkpoints/%s' % (params.save_dir, params.resume),