def run_save(params):
    """Extract backbone features for ``params.split`` and save them to HDF5.

    Resolves the JSON file for the requested split, locates the trained
    checkpoint, loads its ``feature.*`` weights into a freshly built backbone
    and hands the model plus a non-augmented data loader (batch size 64) to
    ``save_features``.

    Args:
        params: option namespace. Reads ``method``, ``dataset``, ``model``,
            ``train_aug``, ``split``, ``save_iter``, ``train_n_way``,
            ``n_shot`` and, if the attribute exists, ``checkpoint_dir``.
            For ``omniglot``/``cross_char`` it overwrites ``params.model``
            with ``'Conv4S'``.

    Returns:
        None. Any method whose name contains ``'maml'`` is skipped with a
        message, since MAML models do not support feature saving.

    Raises:
        AssertionError: for ``omniglot``/``cross_char`` when the model is not
            ``Conv4`` or ``train_aug`` is set.
    """
    print('Run Save features ... ')
    # Substring test covers both 'maml' and 'maml_approx'.
    if 'maml' in params.method:
        print('Continuing since maml doesnt support save_feature')
        return

    image_size = get_image_size(params)

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'

    # Cross-domain settings read the 'base' split from one dataset and every
    # other split from another.
    split = params.split
    if params.dataset == 'cross':
        if split == 'base':
            loadfile = configs.data_dir['miniImagenet'] + 'all.json'
        else:
            loadfile = configs.data_dir['CUB'] + split + '.json'
    elif params.dataset == 'cross_char':
        if split == 'base':
            loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
        else:
            loadfile = configs.data_dir['emnist'] + split + '.json'
    else:
        loadfile = configs.data_dir[params.dataset] + split + '.json'

    if hasattr(params, 'checkpoint_dir'):
        checkpoint_dir = params.checkpoint_dir
    else:
        checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, params.method)
        if params.train_aug:
            checkpoint_dir += '_aug'
        # Non-baseline methods encode the training way/shot in the path.
        if params.method not in ['baseline', 'baseline++']:
            checkpoint_dir += '_%dway_%dshot' % (params.train_n_way,
                                                 params.n_shot)

    print(f'Checkpoint dir: {checkpoint_dir}')
    # save_iter == -1 selects the best checkpoint; any other value selects
    # the checkpoint assigned to that iteration.
    if params.save_iter != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
    else:
        modelfile = get_best_file(checkpoint_dir)
    print(f'Model file {modelfile}')

    # Features mirror the checkpoint directory layout under "features".
    feature_dir = checkpoint_dir.replace("checkpoints", "features")
    if params.save_iter != -1:
        outfile = os.path.join(
            feature_dir, split + "_" + str(params.save_iter) + ".hdf5")
    else:
        outfile = os.path.join(feature_dir, split + ".hdf5")

    datamgr = SimpleDataManager(image_size, batch_size=64)
    data_loader = datamgr.get_data_loader(loadfile, aug=False)

    # RelationNet uses non-flattened backbone variants (the *NP classes or
    # flatten=False). MAML methods never reach this point (early return above).
    if params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            model = backbone.Conv4NP()
        elif params.model == 'Conv6':
            model = backbone.Conv6NP()
        elif params.model == 'Conv4S':
            model = backbone.Conv4SNP()
        else:
            model = model_dict[params.model](flatten=False)
    else:
        model = model_dict[params.model]()

    model = model.cuda()
    tmp = torch.load(modelfile)
    # Keep only the backbone weights: the full architecture stores them under
    # the attribute 'feature', so 'feature.trunk.xx' becomes 'trunk.xx'; all
    # other entries are dropped. Building a new dict (instead of renaming keys
    # in place) prevents a freshly renamed entry from being popped again when
    # its new name also appears later among the original keys.
    state = {
        key.replace("feature.", ""): value
        for key, value in tmp['state'].items()
        if "feature." in key
    }

    model.load_state_dict(state)
    model.eval()

    os.makedirs(os.path.dirname(outfile), exist_ok=True)
    save_features(model, data_loader, outfile)
Esempio n. 2
0
    else:
        modelfile   = get_best_file(checkpoint_dir)

    if params.save_iter != -1:
        outfile = os.path.join( checkpoint_dir.replace("checkpoints","features"), split + "_" + str(params.save_iter)+ ".hdf5") 
    else:
        outfile = os.path.join( checkpoint_dir.replace("checkpoints","features"), split + ".hdf5") 

    datamgr         = SimpleDataManager(image_size, batch_size = 64)
    data_loader      = datamgr.get_data_loader(loadfile, aug = False, is_train=False)

    if params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4': 
            model = backbone.Conv4NP()
        elif params.model == 'Conv6': 
            model = backbone.Conv6NP()
        elif params.model == 'Conv4S': 
            model = backbone.Conv4SNP()
        else:
            model = model_dict[params.model]( flatten = False )
    elif params.method in ['maml' , 'maml_approx']: 
       raise ValueError('MAML do not support save feature')
    else:
        model = model_dict[params.model]()

    model = model.cuda()
    print(modelfile)
    tmp = torch.load(modelfile)
    state = tmp['state']
    state_keys = list(state.keys())
    for i, key in enumerate(state_keys):
Esempio n. 3
0
def baseline_s2m2_init(params):
    """Prepare a trained backbone, data loader and output path for feature saving.

    Args:
        params: option namespace. Reads ``method``, ``model``, ``dataset``,
            ``train_aug``, ``split``, ``save_iter``, ``train_n_way`` and
            ``n_shot``. For ``omniglot``/``cross_char`` it overwrites
            ``params.model`` with ``'Conv4S'``.

    Returns:
        Tuple ``(model, data_loader, outfile, params)``:

        - ``model``: the backbone on CUDA in eval mode with checkpoint weights
          loaded. For ``S2M2_R`` it is wrapped in ``WrappedModel`` when the
          checkpoint's first key contains ``'module'``.
        - ``data_loader``: a non-augmented loader over ``params.split`` with
          batch size 64.
        - ``outfile``: the target ``.hdf5`` path. Its directory is created.
        - ``params``: the (possibly modified) input.

    Raises:
        AssertionError: for MAML methods, or for ``omniglot``/``cross_char``
            with a model other than ``Conv4`` or with ``train_aug`` set.
    """
    assert params.method != 'maml' and params.method != 'maml_approx', 'maml do not support save_feature and run'

    if 'Conv' in params.model:
        if params.dataset in ['omniglot', 'cross_char']:
            image_size = 28
        else:
            image_size = 84
    else:
        if params.method == "S2M2_R":
            image_size = 80
        else:
            image_size = 224

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'

    # Cross-domain settings read the 'base' split from one dataset and every
    # other split from another.
    split = params.split
    if params.dataset == 'cross':
        if split == 'base':
            loadfile = configs.data_dir['miniImagenet'] + 'all.json'
        else:
            loadfile = configs.data_dir['CUB'] + split + '.json'
    elif params.dataset == 'cross_char':
        if split == 'base':
            loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
        else:
            loadfile = configs.data_dir['emnist'] + split + '.json'
    else:
        loadfile = configs.data_dir[params.dataset] + split + '.json'

    # TODO(review): temporary hack - 'cross' runs load the checkpoint that was
    # trained under the miniImagenet directory.
    checkpoint_dataset = "miniImagenet" if params.dataset == "cross" else params.dataset

    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (configs.save_dir, checkpoint_dataset,
                                                  params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    # Non-baseline methods encode the training way/shot in the path.
    if params.method not in ['baseline', 'baseline++', 'S2M2_R']:
        checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)

    # save_iter == -1 selects the best checkpoint; any other value selects
    # the checkpoint assigned to that iteration.
    if params.save_iter != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        filename = split + "_" + str(params.save_iter) + ".hdf5"
    else:
        modelfile = get_best_file(checkpoint_dir)
        filename = split + ".hdf5"

    # Features mirror the checkpoint directory layout under "features".
    outfile = os.path.join(checkpoint_dir.replace("checkpoints", "features"), filename)
    if params.dataset == "cross":
        # Relabel the borrowed miniImagenet path so cross-domain features never
        # overwrite miniImagenet's own feature files. This used to happen only
        # when save_iter == -1, so '<split>_<iter>.hdf5' collided.
        outfile = outfile.replace("miniImagenet", "cross")

    datamgr = SimpleDataManager(image_size, batch_size=64)
    data_loader = datamgr.get_data_loader(loadfile, aug=False, num_workers=12)

    # RelationNet uses non-flattened backbone variants (the *NP classes or
    # flatten=False).
    if params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            model = backbone.Conv4NP()
        elif params.model == 'Conv6':
            model = backbone.Conv6NP()
        elif params.model == 'Conv4S':
            model = backbone.Conv4SNP()
        else:
            model = model_dict[params.model](flatten=False)
    elif params.method in ['maml', 'maml_approx']:
        # Unreachable unless asserts are disabled (python -O).
        raise ValueError('MAML do not support save feature')
    elif params.method == "S2M2_R":
        model = wrn28_10(200)
    else:
        model = model_dict[params.model]()

    print("Using %s" % modelfile)

    model = model.cuda()
    tmp = torch.load(modelfile)
    state = tmp['state']
    if params.method == "S2M2_R":
        # A checkpoint whose first key contains 'module' was saved from a
        # wrapped model, so wrap ours the same way before loading.
        state_keys = list(state.keys())
        if 'module' in state_keys[0]:
            model = WrappedModel(model)

        model_dict_load = model.state_dict()
        model_dict_load.update(state)
        model.load_state_dict(model_dict_load)
    else:
        # Keep only the backbone weights: the full architecture stores them
        # under the attribute 'feature', so 'feature.trunk.xx' becomes
        # 'trunk.xx'; all other entries are dropped. A new dict avoids the
        # in-place rename popping a freshly renamed entry again.
        backbone_state = {
            key.replace("feature.", ""): value
            for key, value in state.items()
            if "feature." in key
        }
        model.load_state_dict(backbone_state)
    model.eval()

    os.makedirs(os.path.dirname(outfile), exist_ok=True)
    return model, data_loader, outfile, params