def main_train(params):
    """Build the data loaders and model described by ``params`` and train it.

    The function seeds the run, picks the base/validation split files and
    the input image size from ``params.dataset`` / ``params.model``, fills in
    a default ``params.stop_epoch`` when it is -1, constructs the model for
    ``params.method``, optionally restores weights (``params.resume``) or
    warm-starts the backbone from a baseline checkpoint (``params.warmup``),
    and finally calls ``train`` and saves the results logger.

    Side effects on ``params``: ``model`` may be rewritten to 'Conv4S',
    ``stop_epoch`` may be filled in, and ``checkpoint_dir`` is set.

    Args:
        params: parsed command-line options (attribute container).

    Raises:
        ValueError: if ``params.method`` is unknown, or if warm-up is
            requested but no baseline checkpoint can be found/loaded.
        AssertionError: if the omniglot/cross_char constraints on the
            backbone, augmentation or ``num_classes`` are violated.
    """
    _set_seed(params)

    results_logger = ResultsLogger(params)

    # Select the base (training) and validation split files.
    if params.dataset == 'cross':
        base_file = configs.data_dir['miniImagenet'] + 'all.json'
        val_file = configs.data_dir['CUB'] + 'val.json'
    elif params.dataset == 'cross_char':
        base_file = configs.data_dir['omniglot'] + 'noLatin.json'
        val_file = configs.data_dir['emnist'] + 'val.json'
    else:
        base_file = configs.data_dir[params.dataset] + 'base.json'
        val_file = configs.data_dir[params.dataset] + 'val.json'

    # Input resolution depends on the backbone family and the dataset.
    if 'Conv' in params.model:
        if params.dataset in ['omniglot', 'cross_char']:
            image_size = 28
        else:
            image_size = 84
    else:
        image_size = 224

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'

    optimization = 'Adam'

    # Fill in the default number of epochs when the caller passed -1.
    if params.stop_epoch == -1:
        if params.method in ['baseline', 'baseline++']:
            if params.dataset in ['omniglot', 'cross_char']:
                params.stop_epoch = 5
            elif params.dataset in ['CUB']:
                params.stop_epoch = 200  # This is different as stated in the open-review paper. However, using 400 epoch in baseline actually lead to over-fitting
            elif params.dataset in ['miniImagenet', 'cross']:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 400  # default
        else:  # meta-learning methods
            if params.n_shot == 1:
                params.stop_epoch = 600
            elif params.n_shot == 5:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 600  # default

    if params.method in ['baseline', 'baseline++']:
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist')

    elif params.method in [
            'DKT', 'protonet', 'matchingnet', 'relationnet',
            'relationnet_softmax', 'maml', 'maml_approx'
    ]:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  # if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small

        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        base_datamgr = SetDataManager(image_size,
                                      n_query=n_query,
                                      **train_few_shot_params)  # n_eposide=100
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)

        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        # a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor

        if params.method == 'DKT':
            model = DKT(model_dict[params.model], **train_few_shot_params)
            model.init_summary()
        elif params.method == 'protonet':
            model = ProtoNet(model_dict[params.model], **train_few_shot_params)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model],
                                **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            # RelationNet is given the "NP" backbone variants for the Conv
            # models and an un-flattened feature map for everything else.
            if params.model == 'Conv4':
                feature_model = backbone.Conv4NP
            elif params.model == 'Conv6':
                feature_model = backbone.Conv6NP
            elif params.model == 'Conv4S':
                feature_model = backbone.Conv4SNP
            else:
                feature_model = lambda: model_dict[params.model](flatten=False)
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'

            model = RelationNet(feature_model,
                                loss_type=loss_type,
                                **train_few_shot_params)
        elif params.method in ['maml', 'maml_approx']:
            # Switch the backbone building blocks into their MAML mode
            # before the model is instantiated.
            backbone.ConvBlock.maml = True
            backbone.SimpleBlock.maml = True
            backbone.BottleneckBlock.maml = True
            backbone.ResNet.maml = True
            model = MAML(model_dict[params.model],
                         approx=(params.method == 'maml_approx'),
                         **train_few_shot_params)
            if params.dataset in [
                    'omniglot', 'cross_char'
            ]:  # maml use different parameter in omniglot
                model.n_task = 32
                model.task_update_num = 1
                model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    # Checkpoint directory encodes dataset, backbone, method and options.
    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    if params.method not in ['baseline', 'baseline++']:
        params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way,
                                                    params.n_shot)
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method == 'maml' or params.method == 'maml_approx':
        stop_epoch = params.stop_epoch * model.n_task  # maml use multiple tasks in one update

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
    elif params.warmup:  # We also support warmup from pretrained baseline feature, but we never used in our paper
        baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        # Check for a missing checkpoint BEFORE loading: get_resume_file may
        # return None, and torch.load(None) would fail with an unrelated
        # error instead of the intended message.
        if warmup_resume_file is None:
            raise ValueError('No warm_up file')
        tmp = torch.load(warmup_resume_file)
        if tmp is None:
            raise ValueError('No warm_up file')
        state = tmp['state']
        # Iterate over a snapshot of the keys because the dict is mutated.
        for key in list(state.keys()):
            if "feature." in key:
                newkey = key.replace(
                    "feature.", ""
                )  # an architecture model has attribute 'feature', load architecture feature to backbone by casting name from 'feature.trunk.xx' to 'trunk.xx'
                state[newkey] = state.pop(key)
            else:
                # Non-backbone entries are dropped; only the backbone is warm-started.
                state.pop(key)
        model.feature.load_state_dict(state)

    model = train(base_loader, val_loader, model, optimization, start_epoch,
                  stop_epoch, params, results_logger)
    results_logger.save()
Пример #2
0
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'

            model = RelationNet(feature_model,
                                loss_type=loss_type,
                                **train_few_shot_params)
        elif params.method in ['maml', 'maml_approx']:
            backbone.ConvBlock.maml = True
            backbone.SimpleBlock.maml = True
            backbone.BottleneckBlock.maml = True
            backbone.ResNet.maml = True
            model = MAML(model_dict[params.model],
                         approx=(params.method == 'maml_approx'),
                         **train_few_shot_params)
            if params.dataset in ['omniglot', 'cross_char'
                                  ]:  #maml use different parameter in omniglot
                model.n_task = 32
                model.task_update_num = 1
                model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = to_cuda(model)

    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    if not params.method in ['baseline', 'baseline++']:
        params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way,
                                                    params.n_shot)
Пример #3
0
def get_model(params, mode):
    '''Construct the model selected by ``params.method``.

    Args:
        params: argparse params. ``params.model`` is rewritten in place
            ('Conv4' -> 'Conv4S') for omniglot / cross_char datasets.
        mode: (str), 'train', 'test'

    Returns:
        The constructed model object.

    Raises:
        ValueError: if ``params.method`` is not recognised, if an
            omniglot / cross_char run is given a non-'ConvS' decoder, or if
            the combination of method, mode, decoder and min_gram matches
            no supported configuration.
        AssertionError: if the omniglot / cross_char backbone constraint or
            the baseline ``num_classes`` lower bound is violated.
    '''
    print('get_model() start...')
    few_shot_params = get_few_shot_params(params, mode)

    # Stays None when none of the branches below handles the requested
    # configuration; checked before returning.
    model = None

    if 'omniglot' in params.dataset or 'cross_char' in params.dataset:
        assert 'Conv4' in params.model and not params.train_aug ,'omniglot/cross_char only support Conv4 without augmentation'
        # Guarded so that calling get_model() twice with the same params
        # (e.g. train then test) does not turn 'Conv4S' into 'Conv4SS'.
        if 'Conv4S' not in params.model:
            params.model = params.model.replace('Conv4', 'Conv4S') # because Conv4Drop should also be Conv4SDrop
        if params.recons_decoder is not None:
            if 'ConvS' not in params.recons_decoder:
                raise ValueError('omniglot / cross_char should use ConvS/HiddenConvS decoder.')

    if params.method in ['baseline', 'baseline++'] and mode == 'train':
        # The classifier must have at least as many outputs as the lower
        # bound registered for this dataset in n_base_classes.
        assert params.num_classes >= n_base_classes[params.dataset], 'class number need to be larger than max label id in base class'

    if params.recons_decoder is None:
        print('params.recons_decoder == None')
        recons_decoder = None
    else:
        recons_decoder = decoder_dict[params.recons_decoder]
        print('recons_decoder:\n',recons_decoder)

    backbone_func = get_backbone_func(params)

    if 'baseline' in params.method:
        loss_types = {
            'baseline': 'softmax',
            'baseline++': 'dist',
        }
        loss_type = loss_types[params.method]

        if recons_decoder is None and params.min_gram is None: # default baseline/baseline++
            if mode == 'train':
                model = BaselineTrain(
                    model_func = backbone_func, loss_type = loss_type,
                    num_class = params.num_classes, **few_shot_params)
            elif mode == 'test':
                model = BaselineFinetune(
                    model_func = backbone_func, loss_type = loss_type,
                    **few_shot_params, finetune_dropout_p = params.finetune_dropout_p)
        elif params.min_gram is not None: # other settings for baseline
            min_gram_params = {
                'min_gram': params.min_gram,
                'lambda_gram': params.lambda_gram,
            }
            if mode == 'train':
                model = BaselineTrainMinGram(
                    model_func = backbone_func, loss_type = loss_type,
                    num_class = params.num_classes, **few_shot_params, **min_gram_params)
            elif mode == 'test':
                # Testing uses the plain finetune model even for min-gram runs.
                model = BaselineFinetune(
                    model_func = backbone_func, loss_type = loss_type,
                    **few_shot_params, finetune_dropout_p = params.finetune_dropout_p)

    elif params.method == 'protonet':
        if recons_decoder is None and params.min_gram is None:
            # default ProtoNet
            model = ProtoNet(backbone_func, **few_shot_params)
        else: # other settings
            if params.min_gram is not None:
                min_gram_params = {
                    'min_gram': params.min_gram,
                    'lambda_gram': params.lambda_gram,
                }
                model = ProtoNetMinGram(backbone_func, **few_shot_params, **min_gram_params)

            # When a reconstruction decoder is also configured it replaces
            # the min-gram model built above.
            if params.recons_decoder is not None:
                if 'Hidden' in params.recons_decoder:
                    # Extra constructor arguments per hidden-layer decoder.
                    hidden_decoder_kwargs = {
                        'HiddenConv': dict(extract_layer = 2),
                        'HiddenConvS': dict(extract_layer = 2, is_color = False),
                        'HiddenRes10': dict(extract_layer = 6),
                        'HiddenRes18': dict(extract_layer = 8),
                    }
                    extra_kwargs = hidden_decoder_kwargs.get(params.recons_decoder)
                    if extra_kwargs is not None:
                        model = ProtoNetAE2(
                            backbone_func, **few_shot_params,
                            recons_func = recons_decoder,
                            lambda_d = params.recons_lambda, **extra_kwargs)
                else:
                    model = ProtoNetAE(
                        backbone_func, **few_shot_params,
                        recons_func = recons_decoder,
                        lambda_d = params.recons_lambda,
                        is_color = 'ConvS' not in params.recons_decoder)
    elif params.method == 'matchingnet':
        model = MatchingNet(backbone_func, **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(backbone_func, loss_type = loss_type, **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
        # Switch the backbone building blocks into their MAML mode before
        # the model is instantiated.
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(backbone_func, approx = (params.method == 'maml_approx'), **few_shot_params)
        if 'omniglot' in params.dataset or 'cross_char' in params.dataset:
            # maml use different parameter in omniglot
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unexpected params.method: %s'%(params.method))

    if model is None:
        # Previously this situation surfaced as an UnboundLocalError on
        # `return model`; report the offending configuration instead.
        raise ValueError(
            'Unsupported configuration: method=%s, mode=%s, recons_decoder=%s, min_gram=%s'
            % (params.method, mode, params.recons_decoder,
               getattr(params, 'min_gram', None)))

    print('get_model() finished.')
    return model