Ejemplo n.º 1
0
def select_dataloader_for_train(params):
    """
    Select and build the train/validation data loaders based on ``params``.

    Returns:
        (base_loader, val_loader): loaders for the base (train) split and
        the validation split.

    Raises:
        ValueError: if ``params.method`` is not a recognized method name.
    """
    isAircraft = (params.dataset == 'aircrafts')

    base_file = os.path.join('filelists', params.dataset,
                             params.base + '.json')
    val_file = os.path.join('filelists', params.dataset, 'val.json')

    image_size = params.image_size

    if params.method in ['baseline', 'baseline++']:
        # Plain batched loaders for standard supervised training.
        base_datamgr = SimpleDataManager(image_size,
                                         batch_size=params.bs,
                                         jigsaw=params.jigsaw,
                                         rotation=params.rotation,
                                         isAircraft=isAircraft)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size,
                                        batch_size=params.bs,
                                        jigsaw=params.jigsaw,
                                        rotation=params.rotation,
                                        isAircraft=isAircraft)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

    elif params.method in [
            'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax',
            'maml', 'maml_approx'
    ]:
        # Episodic loaders. If test_n_way is smaller than train_n_way,
        # reduce n_query to keep the episode batch size small.
        n_query = max(
            1, int(params.n_query * params.test_n_way / params.train_n_way))

        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot,
                                     jigsaw=params.jigsaw,
                                     lbda=params.lbda,
                                     rotation=params.rotation)
        base_datamgr = SetDataManager(image_size,
                                      n_query=n_query,
                                      **train_few_shot_params,
                                      isAircraft=isAircraft)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)

        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot,
                                    jigsaw=params.jigsaw,
                                    lbda=params.lbda,
                                    rotation=params.rotation)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params,
                                     isAircraft=isAircraft)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
    else:
        # Bug fix: an unrecognized method previously fell through to a
        # confusing NameError on the return line; fail fast instead.
        raise ValueError('Unknown method')

    return base_loader, val_loader
Ejemplo n.º 2
0
def sib_init(params, split):
    """Build the pretrained SIB encoder and a feature-extraction loader.

    Returns:
        (model, data_loader, outfile, params): ``outfile`` is the hdf5
        path the extracted features should be written to.
    """
    model_name = params.model.lower()
    if params.dataset in ("cross", "miniImagenet"):
        model_dir = os.path.join(configs.sib_dir, "miniImagenet",
                                 "%s_best.pth" % model_name)

    model = SIBWRN(num_classes=64).cuda()
    # Load the checkpoint weights into the encoder only, then freeze.
    encoder_state = model.encoder.state_dict()
    encoder_state.update(torch.load(model_dir))
    model.encoder.load_state_dict(encoder_state)
    model.eval()

    if params.dataset == "cross":
        loadfile = configs.data_dir['CUB'] + split + '.json'
    elif params.dataset == "miniImagenet":
        loadfile = configs.data_dir['miniImagenet'] + split + '.json'

    # SIB backbones expect 80x80 inputs.
    datamgr = SimpleDataManager(80, batch_size=64)
    data_loader = datamgr.get_data_loader(loadfile, aug=False, num_workers=12)

    outfile = '%s/features/%s/%s/%s.hdf5' % (configs.sib_dir, params.dataset,
                                             params.model, split)
    dirname = os.path.dirname(outfile)
    if not os.path.isdir(dirname):
        os.makedirs(dirname)
    return model, data_loader, outfile, params
Ejemplo n.º 3
0
 def test_d_specific_classifiers(self, n_clf):
     """Evaluate the ``n_clf`` dimension-specific linear classifiers on the
     base split.

     Args:
         n_clf: number of classifiers; each one sees a disjoint slice of
             ``feat_dim / n_clf`` features.

     Returns:
         np.ndarray of shape (n_clf,) with each classifier's accuracy.
     """
     in_dim = int(self.feat_dim / n_clf)
     tiered_mini = False
     if self.dataset == "tiered":
         tiered_mini = True
         base_file = "base"
     else:
         base_file = configs.data_dir['miniImagenet'] + "base" + '.json'
     batch_size = 32
     base_datamgr = SimpleDataManager(self.image_size,
                                      batch_size=batch_size)
     base_loader = base_datamgr.get_data_loader(base_file,
                                                aug=False,
                                                num_workers=0,
                                                tiered_mini=tiered_mini)
     correct_counts = np.zeros(n_clf)
     total = 0
     for i, (x, y, _) in enumerate(base_loader):
         x = Variable(x.cuda())
         out = self.get_features(x)
         total += out.shape[0]
         # Hoisted: the original recomputed y.cpu().numpy() once per
         # classifier even though it is identical for all of them.
         y_np = y.cpu().numpy()
         for j in range(n_clf):
             start = in_dim * j
             stop = start + in_dim
             scores = self.clfs[j](out[:, start:stop])
             pred = scores.data.cpu().numpy().argmax(axis=1)
             correct_counts[j] += (pred == y_np).sum()
     correct_counts = correct_counts / total
     # Bug fix: the computed accuracies were previously discarded.
     return correct_counts
Ejemplo n.º 4
0
def get_loader(params):
    """Build (base_loader, val_loader) for training.

    Also normalizes ``params`` in place: omniglot-family runs are switched
    to the Conv4S backbone, and ``stop_epoch`` is filled in when it is -1.
    """
    # Resolve the base/val json file lists (cross-domain setups mix datasets).
    if params.dataset == 'cross':
        base_file = configs.data_dir['miniImagenet'] + 'all.json'
        val_file = configs.data_dir['CUB'] + 'val.json'
    elif params.dataset == 'cross_char':
        base_file = configs.data_dir['omniglot'] + 'noLatin.json'
        val_file = configs.data_dir['emnist'] + 'val.json'
    else:
        base_file = configs.data_dir[params.dataset] + 'base.json'
        val_file = configs.data_dir[params.dataset] + 'val.json'

    # Image size depends on the backbone family and dataset.
    if 'Conv' not in params.model:
        image_size = 224
    elif params.dataset in ['omniglot', 'cross_char']:
        image_size = 28
    else:
        image_size = 84

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'

    # Fill in the default epoch budget when the caller did not set one.
    if params.stop_epoch == -1:
        if params.method in ['baseline', 'baseline++']:
            if params.dataset in ['omniglot', 'cross_char']:
                params.stop_epoch = 5
            elif params.dataset in ['CUB']:
                # Fewer epochs than the paper's 400: baseline over-fits on CUB.
                params.stop_epoch = 200
            else:
                params.stop_epoch = 400
        else:
            # Meta-learning methods: longer schedule except for 5-shot.
            params.stop_epoch = 400 if params.n_shot == 5 else 600

    base_loader = SimpleDataManager(image_size, batch_size=16) \
        .get_data_loader(base_file, aug=params.train_aug)
    val_loader = SimpleDataManager(image_size, batch_size=64) \
        .get_data_loader(val_file, aug=False)

    return base_loader, val_loader
Ejemplo n.º 5
0
    def save_pretrain_dataset(self, split):
        """Extract features for every image of ``split`` with the frozen
        backbone and save them (with labels) as .npy files under pretrain/.

        Args:
            split: split name ("base"/"val"/"novel", or a tiered-imagenet
                split keyword passed straight to the loader).

        Returns:
            (features, labels) as numpy arrays.
        """
        params = self.params
        tiered_mini = False
        if params.dataset == 'cross':
            # Cross-domain: base split from miniImagenet, novel from CUB.
            if split == "base":
                base_file = configs.data_dir['miniImagenet'] + split + '.json'
            elif split == "novel":
                base_file = configs.data_dir['CUB'] + split + '.json'
        elif params.dataset == 'cross_char':
            base_file = configs.data_dir['omniglot'] + 'noLatin.json'
        elif params.dataset == "tiered":
            # The tiered loader takes the split name directly, not a json path.
            base_file = split
            tiered_mini = True
        else:
            base_file = configs.data_dir[params.dataset] + split + '.json'

        batch_size = self.batch_size

        base_datamgr = SimpleDataManager(self.image_size,
                                         batch_size=batch_size)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=False,
                                                   num_workers=12,
                                                   tiered_mini=tiered_mini)

        features = []
        labels = []
        print("Saving pretrain dataset...")
        for epoch in range(0, 1):
            for i, (x, y, _) in enumerate(base_loader):
                x = Variable(x.cuda())
                out = self.get_features(x)
                # Perf fix: move the whole batch to host memory once; the
                # original re-ran the GPU->CPU copy for every sample.
                out_np = out.data.cpu().numpy()
                y_np = y.numpy()
                features.extend(out_np)
                labels.extend(y_np)
                print_with_carriage_return(
                    "Epoch %d: %d/%d processed" %
                    (epoch, i, len(base_loader.dataset.meta["image_labels"]) /
                     batch_size))
            end_carriage_return_print()
        dataset = self.params.dataset
        features_dir = "pretrain/features_%s_%s_%s_%s.npy" % (
            dataset, self.params.method, self.model_name, split)
        labels_dir = "pretrain/labels_%s_%s_%s_%s.npy" % (
            dataset, self.params.method, self.model_name, split)
        np.save(features_dir, np.asarray(features))
        np.save(labels_dir, np.asarray(labels))
        return np.asarray(features), np.asarray(labels)
Ejemplo n.º 6
0
    def _calc_pretrained_class_mean(self, normalize=False):
        """Compute the per-class mean feature vector over the base split.

        Args:
            normalize: if True, normalize each feature (via self.normalize)
                before averaging.

        Returns:
            np.ndarray of shape (num_classes, feat_dim); rows for classes
            that never appear remain all-zero.
        """
        params = self.params
        print(params)
        aug = False
        tiered_mini = False
        if params.dataset == 'cross':
            base_file = configs.data_dir['miniImagenet'] + 'base.json'
        elif params.dataset == 'cross_char':
            base_file = configs.data_dir['omniglot'] + 'noLatin.json'
        elif params.dataset == "tiered":
            # The tiered loader takes the split name directly, not a json path.
            base_file = "base"
            tiered_mini = True
        else:
            base_file = configs.data_dir[params.dataset] + 'base.json'

        batch_size = self.batch_size
        base_datamgr = SimpleDataManager(self.image_size,
                                         batch_size=batch_size)
        # No aug for every method
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=aug,
                                                   num_workers=12,
                                                   tiered_mini=tiered_mini)

        features = np.zeros((self.num_classes, self.feat_dim))
        counts = np.zeros(self.num_classes)
        print("saving pretrained mean")
        for epoch in range(0, 1):
            for i, (x, y, path) in enumerate(base_loader):
                x = Variable(x.cuda())
                out = self.get_features(x)
                if normalize:
                    out = self.normalize(out)
                # Perf fix: one GPU->CPU transfer per batch; the original
                # re-ran out.data.cpu().numpy() for every sample.
                out_np = out.data.cpu().numpy()
                y_np = y.numpy()
                for j in range(out_np.shape[0]):
                    features[y_np[j]] += out_np[j]
                    counts[y_np[j]] += 1
                print_with_carriage_return("Epoch %d: %d/%d processed" %
                                           (epoch, i, len(base_loader)))
            end_carriage_return_print()
            print(np.max(features))
        # Divide each class sum by its count; skip empty classes.
        for i in range(0, len(counts)):
            if counts[i] != 0:
                features[i] = features[i] / counts[i]
        return features
Ejemplo n.º 7
0
 def train_d_specific_classifiers(self, n_clf):
     """Train ``n_clf`` linear classifiers, each on a disjoint slice of the
     frozen feature vector, for 40 epochs on the base split.

     The classifier bank is checkpointed after every epoch under
     pretrain/clfs/.
     """
     in_dim = int(self.feat_dim / n_clf)
     out_dim = self.num_classes
     self.clfs = nn.ModuleList(
         [nn.Linear(in_dim, out_dim) for i in range(n_clf)])
     self.clfs = self.clfs.cuda()
     tiered_mini = False
     if self.dataset == "tiered":
         tiered_mini = True
         base_file = "base"
     else:
         base_file = configs.data_dir['miniImagenet'] + "base" + '.json'
     batch_size = 128
     base_datamgr = SimpleDataManager(self.image_size,
                                      batch_size=batch_size)
     base_loader = base_datamgr.get_data_loader(base_file,
                                                aug=True,
                                                num_workers=12,
                                                tiered_mini=tiered_mini)
     loss_fn = nn.CrossEntropyLoss()
     params = self.clfs.parameters()
     optimizer = torch.optim.Adam(params)
     for epoch in range(0, 40):
         for i, (x, y, _) in enumerate(base_loader):
             optimizer.zero_grad()
             x = Variable(x.cuda())
             out = self.get_features(x)
             y = y.cuda()
             # Sum the per-classifier losses and backprop once: gradients
             # are identical to the original per-loss backward calls, but
             # this avoids retaining the graph across n_clf backward passes.
             total_loss = 0
             for j in range(n_clf):
                 start = in_dim * j
                 stop = start + in_dim
                 scores = self.clfs[j](out[:, start:stop])
                 total_loss = total_loss + loss_fn(scores, y)
             total_loss.backward()
             optimizer.step()
             if i % 10 == 0:
                 print("Epoch: %d, Batch %d/%d, Loss=%.3f" %
                       (epoch, i, len(base_loader),
                        total_loss.item() / n_clf))
         # save model
         out_dir = "pretrain/clfs/%s_%s_%s_%d" % (
             self.method, self.model_name, self.base_dataset, n_clf)
         if not os.path.isdir(out_dir):
             os.makedirs(out_dir)
         outfile = os.path.join(out_dir, "%d.tar" % (epoch))
         torch.save(self.clfs.state_dict(), outfile)
Ejemplo n.º 8
0
def simple_shot_init(params, split):
    """Load a pretrained SimpleShot backbone and the loader for ``split``.

    Returns:
        (model, data_loader, outfile, params): ``outfile`` is the hdf5
        path extracted features should be written to.
    """
    model_name = params.model.lower()
    if params.dataset in ("cross", "miniImagenet"):
        model_dir = os.path.join(configs.simple_shot_dir, "miniImagenet",
                                 model_name, "model_best.pth.tar")
    elif params.dataset == "tiered":
        model_dir = os.path.join(configs.simple_shot_dir, "tiered", model_name,
                                 "model_best.pth.tar")

    # tiered checkpoints were trained over 351 classes, miniImagenet over 64.
    num_classes = 351 if params.dataset == "tiered" else 64
    model = simple_shot_models.__dict__[model_name](num_classes=num_classes,
                                                    remove_linear=True)
    model = model.cuda()
    checkpoint = torch.load(model_dir)
    model_dict = model.state_dict()
    # Normalize checkpoint parameter names and keep only the keys the
    # current architecture actually expects before loading.
    loaded = {
        remove_module_from_param_name(k): v
        for k, v in checkpoint['state_dict'].items()
    }
    loaded = {k: v for k, v in loaded.items() if k in model_dict}
    model_dict.update(loaded)
    model.load_state_dict(model_dict)

    tiered_mini = False
    if params.dataset == "cross":
        loadfile = configs.data_dir['CUB'] + split + '.json'
    elif params.dataset == "miniImagenet":
        loadfile = configs.data_dir['miniImagenet'] + split + '.json'
    elif params.dataset == "tiered":
        loadfile = split
        tiered_mini = True

    datamgr = SimpleDataManager(84, batch_size=64)
    data_loader = datamgr.get_data_loader(loadfile,
                                          aug=False,
                                          num_workers=12,
                                          tiered_mini=tiered_mini)

    outfile = '%s/features/%s/%s/%s.hdf5' % (
        configs.simple_shot_dir, params.dataset, params.model, split)
    dirname = os.path.dirname(outfile)
    if not os.path.isdir(dirname):
        os.makedirs(dirname)
    return model, data_loader, outfile, params
Ejemplo n.º 9
0
def cosine_init(params, split):
    """Load a pretrained cosine-classifier backbone and the loader for
    ``split``.

    Returns:
        (model, data_loader, outfile, params): ``outfile`` is the hdf5
        path extracted features should be written to.

    Raises:
        ValueError: if the dataset or model is not supported here.
    """
    model_name = params.model.lower()
    if params.dataset == "cross" or params.dataset == "miniImagenet":
        num_classes = 64
        model_dir = os.path.join(configs.cosine_dir, "miniImagenet",
                                 model_name, "max_acc.pth")
    else:
        # Bug fix: any other dataset previously fell through to a
        # NameError on model_dir; only the miniImagenet checkpoints are
        # usable by the loadfile resolution below.
        raise ValueError('Unsupported dataset: %s' % params.dataset)

    if model_name == "resnet10":
        feat_dim = 512
    elif model_name == "wrn":
        feat_dim = 640
    else:
        # Bug fix: an unknown model previously fell through to a
        # NameError on feat_dim.
        raise ValueError('Unsupported model: %s' % params.model)

    model = CosinePretrain(model_name, num_classes, feat_dim)
    model = model.cuda()
    model_dict = model.encoder.state_dict()
    ckpt = torch.load(model_dir)["params"]
    model_dict.update(ckpt)
    model.encoder.load_state_dict(model_dict)
    model.eval()

    if params.dataset == "cross":
        loadfile = configs.data_dir['CUB'] + split + '.json'
    elif params.dataset == "miniImagenet":
        loadfile = configs.data_dir['miniImagenet'] + split + '.json'
    image_size = 84
    datamgr = SimpleDataManager(image_size, batch_size=64)
    data_loader = datamgr.get_data_loader(loadfile, aug=False, num_workers=12)

    outfile = '%s/features/%s/%s/%s.hdf5' % (
        configs.cosine_dir, params.dataset, params.model, split)
    dirname = os.path.dirname(outfile)
    if not os.path.isdir(dirname):
        os.makedirs(dirname)
    return model, data_loader, outfile, params
Ejemplo n.º 10
0
def main_train(params):
    """Train a few-shot model end to end according to ``params``.

    Resolves the dataset file lists, builds method-specific data loaders
    and the model, optionally resumes or warm-starts from a checkpoint,
    then runs the training loop and saves the results.

    Raises:
        ValueError: for an unknown ``params.method``, or when warmup is
            requested but no baseline checkpoint exists.
    """
    _set_seed(params)

    results_logger = ResultsLogger(params)

    # Resolve base/val json file lists (cross-domain setups mix datasets).
    if params.dataset == 'cross':
        base_file = configs.data_dir['miniImagenet'] + 'all.json'
        val_file = configs.data_dir['CUB'] + 'val.json'
    elif params.dataset == 'cross_char':
        base_file = configs.data_dir['omniglot'] + 'noLatin.json'
        val_file = configs.data_dir['emnist'] + 'val.json'
    else:
        base_file = configs.data_dir[params.dataset] + 'base.json'
        val_file = configs.data_dir[params.dataset] + 'val.json'
    # Image size depends on the backbone family and dataset.
    if 'Conv' in params.model:
        if params.dataset in ['omniglot', 'cross_char']:
            image_size = 28
        else:
            image_size = 84
    else:
        image_size = 224
    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'
    optimization = 'Adam'
    # Fill in the default epoch budget when the caller did not set one.
    if params.stop_epoch == -1:
        if params.method in ['baseline', 'baseline++']:
            if params.dataset in ['omniglot', 'cross_char']:
                params.stop_epoch = 5
            elif params.dataset in ['CUB']:
                params.stop_epoch = 200  # This is different as stated in the open-review paper. However, using 400 epoch in baseline actually lead to over-fitting
            elif params.dataset in ['miniImagenet', 'cross']:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 400  # default
        else:  # meta-learning methods
            if params.n_shot == 1:
                params.stop_epoch = 600
            elif params.n_shot == 5:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 600  # default
    if params.method in ['baseline', 'baseline++']:
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist')

    elif params.method in [
            'DKT', 'protonet', 'matchingnet', 'relationnet',
            'relationnet_softmax', 'maml', 'maml_approx'
    ]:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  # if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small

        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        base_datamgr = SetDataManager(image_size,
                                      n_query=n_query,
                                      **train_few_shot_params)  # n_eposide=100
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)

        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        # a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor

        if (params.method == 'DKT'):
            model = DKT(model_dict[params.model], **train_few_shot_params)
            model.init_summary()
        elif params.method == 'protonet':
            model = ProtoNet(model_dict[params.model], **train_few_shot_params)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model],
                                **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            if params.model == 'Conv4':
                feature_model = backbone.Conv4NP
            elif params.model == 'Conv6':
                feature_model = backbone.Conv6NP
            elif params.model == 'Conv4S':
                feature_model = backbone.Conv4SNP
            else:
                feature_model = lambda: model_dict[params.model](flatten=False)
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'

            model = RelationNet(feature_model,
                                loss_type=loss_type,
                                **train_few_shot_params)
        elif params.method in ['maml', 'maml_approx']:
            backbone.ConvBlock.maml = True
            backbone.SimpleBlock.maml = True
            backbone.BottleneckBlock.maml = True
            backbone.ResNet.maml = True
            model = MAML(model_dict[params.model],
                         approx=(params.method == 'maml_approx'),
                         **train_few_shot_params)
            if params.dataset in [
                    'omniglot', 'cross_char'
            ]:  # maml use different parameter in omniglot
                model.n_task = 32
                model.task_update_num = 1
                model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')
    model = model.cuda()
    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    if not params.method in ['baseline', 'baseline++']:
        params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way,
                                                    params.n_shot)
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)
    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method == 'maml' or params.method == 'maml_approx':
        stop_epoch = params.stop_epoch * model.n_task  # maml use multiple tasks in one update
    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
    elif params.warmup:  # We also support warmup from pretrained baseline feature, but we never used in our paper
        baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        # Bug fix: get_resume_file returns None when no checkpoint exists;
        # the old code called torch.load(None) first and only checked for
        # None afterwards, so the 'No warm_up file' error was unreachable.
        if warmup_resume_file is None:
            raise ValueError('No warm_up file')
        tmp = torch.load(warmup_resume_file)
        state = tmp['state']
        state_keys = list(state.keys())
        for i, key in enumerate(state_keys):
            if "feature." in key:
                newkey = key.replace(
                    "feature.", ""
                )  # an architecture model has attribute 'feature', load architecture feature to backbone by casting name from 'feature.trunk.xx' to 'trunk.xx'
                state[newkey] = state.pop(key)
            else:
                state.pop(key)
        model.feature.load_state_dict(state)

    model = train(base_loader, val_loader, model, optimization, start_epoch,
                  stop_epoch, params, results_logger)
    results_logger.save()
Ejemplo n.º 11
0
    # Pick the checkpoint to load: an explicitly requested iteration,
    # the latest one for baseline methods, otherwise the best-on-val file.
    if params.save_iter != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
    elif params.method in ['baseline', 'baseline++']:
        modelfile = get_resume_file(checkpoint_dir)
    else:
        modelfile = get_best_file(checkpoint_dir)

    # Features are written alongside the checkpoints, under .../features/.
    if params.save_iter != -1:
        outfile = os.path.join(checkpoint_dir.replace("checkpoints", "features"),
                               split + "_" + str(params.save_iter) + ".hdf5")
    else:
        outfile = os.path.join(checkpoint_dir.replace("checkpoints", "features"), split + ".hdf5")

    datamgr = SimpleDataManager(image_size, batch_size=64)
    data_loader = datamgr.get_data_loader(loadfile, aug=False)

    # Build the bare feature extractor matching the training method.
    # RelationNet uses the "NP" backbone variants / flatten=False —
    # presumably to keep spatial feature maps; confirm against backbone.py.
    if params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            model = backbone.Conv4NP()
        elif params.model == 'Conv6':
            model = backbone.Conv6NP()
        elif params.model == 'Conv4S':
            model = backbone.Conv4SNP()
        else:
            model = model_dict[params.model](flatten=False)
    elif params.method in ['maml', 'maml_approx']:
        raise ValueError('MAML do not support save feature')
    else:
        model = model_dict[params.model]()
Ejemplo n.º 12
0
            # Meta-learning default schedule by shot count.
            # NOTE(review): sibling variants of this chain default to 600
            # in the final else; here it is 400 — confirm this is intended.
            if params.n_shot == 1:
                params.stop_epoch = 600
            elif params.n_shot == 5:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 400  #default

    if params.method in ['baseline', 'baseline++']:
        # limit_n_images == 0 appears to mean "no explicit limit";
        # presumably the loader treats 1 as a sentinel here — TODO confirm.
        limit_n_images = params.limit_n_images
        if params.limit_n_images == 0:
            limit_n_images = 1

        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        # Optionally restrict the base split to a subset of images/classes.
        base_loader = base_datamgr.get_data_loader(
            base_file,
            aug=params.train_aug,
            n_images=limit_n_images,
            n_classes=params.limit_n_classes,
            seed=0)
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
Ejemplo n.º 13
0
            # (Excerpt continues a stop_epoch default chain started above.)
            elif params.dataset in ['miniImagenet', 'cross']:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 400  #default
        else:  #meta-learning methods
            if params.n_shot == 1:
                params.stop_epoch = 600
            elif params.n_shot == 5:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 600  #default

    if params.method in ['baseline', 'baseline++']:
        # This variant threads an is_train flag into the loaders.
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug,
                                                   is_train=True)
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(val_file,
                                                 aug=False,
                                                 is_train=False)

        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model],
Ejemplo n.º 14
0
        # Non-baseline checkpoints encode the episode shape in the dir name.
        checkpoint_dir += '_%dway_%dshot' %( params.train_n_way, params.n_shot)

    # Pick the checkpoint to load: an explicitly requested iteration,
    # otherwise the best-on-val file.
    if params.save_iter != -1:
        modelfile   = get_assigned_file(checkpoint_dir,params.save_iter)
#    elif params.method in ['baseline', 'baseline++'] :
#        modelfile   = get_resume_file(checkpoint_dir) #comment in 2019/08/03 updates as the validation of baseline/baseline++ is added
    else:
        modelfile   = get_best_file(checkpoint_dir)

    # Features are written alongside the checkpoints, under .../features/.
    if params.save_iter != -1:
        outfile = os.path.join( checkpoint_dir.replace("checkpoints","features"), split + "_" + str(params.save_iter)+ ".hdf5") 
    else:
        outfile = os.path.join( checkpoint_dir.replace("checkpoints","features"), split + ".hdf5") 

    datamgr         = SimpleDataManager(image_size, batch_size = 64)
    data_loader      = datamgr.get_data_loader(loadfile, aug = False, is_train=False)

    # Build the bare feature extractor matching the training method.
    if params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4': 
            model = backbone.Conv4NP()
        elif params.model == 'Conv6': 
            model = backbone.Conv6NP()
        elif params.model == 'Conv4S': 
            model = backbone.Conv4SNP()
        else:
            model = model_dict[params.model]( flatten = False )
    elif params.method in ['maml' , 'maml_approx']: 
       raise ValueError('MAML do not support save feature')
    else:
        model = model_dict[params.model]()
Ejemplo n.º 15
0
    # Build one FFS evaluation loader per (attribute split, example split)
    # combination; each is keyed into eval_loaders_dic for later use.
    test_ffs_params     = dict(n_way = 1, n_support = params.n_shot, n_episodes = 20)
    # for attr_split, example_split in tqdm(itertools.product(['train', 'val', 'test'],['base']), desc='Loading Val Loaders'):
    for attr_split, example_split in itertools.product(['train', 'val', 'test'],['base','val','novel']):
        val_loader = FFSDataManager(params.x_type, image_size, attr_split=attr_split, attr_split_file=params.attr_split_file, n_query = 15, **test_ffs_params).get_data_loader( configs.data_dir[params.dataset] + f'{example_split}.json' , aug = False) 
        eval_loaders_dic[f"FFS,attr={attr_split},example={example_split}"] = val_loader
        
    # FS
    # test_few_shot_params     = dict(n_way = params.test_n_way, n_support = params.n_shot, n_episodes = 20) 
    # for  example_split in ['base','val','novel']:
    #     val_loader = SetDataManager(image_size, n_query = 15, **test_few_shot_params).get_data_loader( configs.data_dir[params.dataset] + f'{example_split}.json' , aug = False) 
    #     eval_loaders_dic[f"FS,example={example_split}"] = val_loader
     

    # Training loader and model, selected per method.
    if params.method in ['baseline', 'baseline++'] :
        base_datamgr    = SimpleDataManager(image_size, batch_size = 16)
        base_loader     = base_datamgr.get_data_loader( base_file , aug = params.train_aug )

        model  = BaselineTrainTest( model_dict[params.model], params.num_classes, params.test_n_way, params.n_shot, loss_type = 'softmax' if params.method == 'baseline' else 'dist')

    elif params.method == 'attr':
        # Attribute-prediction variant trains with a BCE loss.
        base_datamgr    = AttrDataManager(image_size, params.train_attr_split, attr_split_file=params.attr_split_file, batch_size = 16)
        base_loader     = base_datamgr.get_data_loader( base_file , aug = params.train_aug )
        

        model  = BaselineTrainTest( model_dict[params.model],params.num_classes,  params.test_n_way, params.n_shot, loss_type= 'bce')

    elif params.method in ['protonet','matchingnet','relationnet', 'relationnet_softmax', 'maml', 'maml_approx']:
        n_query = max(1, int(16* params.test_n_way/params.train_n_way)) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
        
        if params.train_ffs:
            train_ffs_params    = dict(n_way = 1, n_support = params.n_shot) 
Ejemplo n.º 16
0
def get_train_val_loader(params, source_val):
    """Build the train and validation data loaders for ``params.method``.

    Args:
        params: parsed experiment arguments (method, n-way/n-shot settings,
            augmentation flags, optional vaegan/aug options).
        source_val: if True, additionally build an episodic loader over the
            source-domain validation split and return it as a third value.

    Returns:
        ``(base_loader, val_loader)`` or, when ``source_val`` is True,
        ``(base_loader, val_loader, source_val_loader)``.

    Raises:
        ValueError: if ``params.method`` is not a supported method name.
    """
    # local import to prevent circular import
    from data.datamgr import SimpleDataManager, SetDataManager, AugSetDataManager, VAESetDataManager

    image_size = get_img_size(params)
    base_file, val_file = get_train_val_filename(params)
    if source_val:
        source_val_file = get_source_val_filename(params)

    if params.method in ['baseline', 'baseline++']:
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)

        # Validation is episodic even for baseline so that fine-tuning can be
        # performed during validation.
        # If test_n_way is smaller than train_n_way, reduce n_query to keep
        # the episode batch size small.
        n_query = max(1, int(16 * params.test_n_way / params.train_n_way))
        val_few_shot_params = get_few_shot_params(params, 'val')
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **val_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        if source_val:
            source_val_datamgr = SetDataManager(image_size,
                                                n_query=n_query,
                                                **val_few_shot_params)
            # BUG FIX: the original built source_val_datamgr and then read
            # from val_datamgr, leaving the new manager unused.
            source_val_loader = source_val_datamgr.get_data_loader(
                source_val_file, aug=False)

    elif params.method in [
            'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax',
            'maml', 'maml_approx'
    ]:
        # If test_n_way is smaller than train_n_way, reduce n_query to keep
        # the episode batch size small.
        n_query = max(1, int(16 * params.test_n_way / params.train_n_way))

        train_few_shot_params = get_few_shot_params(params, 'train')
        val_few_shot_params = get_few_shot_params(params, 'val')
        if params.vaegan_exp is not None:
            # TODO
            is_training = False
            vaegan = restore_vaegan(params.dataset,
                                    params.vaegan_exp,
                                    params.vaegan_step,
                                    is_training=is_training)

            base_datamgr = VAESetDataManager(
                image_size,
                n_query=n_query,
                vaegan_exp=params.vaegan_exp,
                vaegan_step=params.vaegan_step,
                vaegan_is_train=params.vaegan_is_train,
                lambda_zlogvar=params.zvar_lambda,
                fake_prob=params.fake_prob,
                **train_few_shot_params)
            # train_val or val???
            val_datamgr = SetDataManager(image_size,
                                         n_query=n_query,
                                         **val_few_shot_params)

        elif params.aug_target is None:  # common case
            assert params.aug_type is None

            base_datamgr = SetDataManager(image_size,
                                          n_query=n_query,
                                          **train_few_shot_params)
            val_datamgr = SetDataManager(image_size,
                                         n_query=n_query,
                                         **val_few_shot_params)
            # NOTE: a dedicated (but never-used) source_val SetDataManager was
            # previously constructed here; removed as dead code — the source
            # validation loader below reuses val_datamgr, which has identical
            # configuration.
        else:
            aug_type = params.aug_type
            assert aug_type is not None
            base_datamgr = AugSetDataManager(image_size,
                                             n_query=n_query,
                                             aug_type=aug_type,
                                             aug_target=params.aug_target,
                                             **train_few_shot_params)
            val_datamgr = AugSetDataManager(image_size,
                                            n_query=n_query,
                                            aug_type=aug_type,
                                            aug_target='test-sample',
                                            **val_few_shot_params)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        if source_val:
            # Reuse the validation manager configuration for the source split.
            source_val_loader = val_datamgr.get_data_loader(source_val_file,
                                                            aug=False)
        # A batch from SetDataManager is a
        # [n_way, n_support + n_query, n_channel, w, h] tensor.

    else:
        raise ValueError('Unknown method')

    if source_val:
        return base_loader, val_loader, source_val_loader
    else:
        return base_loader, val_loader
Ejemplo n.º 17
0
    if 'Conv' in params.model:
        image_size = 84
    else:
        image_size = 224

    n_query = max(1, int(16 * params.test_n_way / params.train_n_way))
    train_few_shot_params = dict(n_way=params.train_n_way,
                                 n_support=params.n_shot)
    base_datamgr = SetDataManager(image_size,
                                  n_query=n_query,
                                  **train_few_shot_params)
    aux_datamgr = SimpleDataManager(image_size, batch_size=16)
    aux_iter = iter(
        cycle(
            aux_datamgr.get_data_loader(os.path.join(params.data_dir,
                                                     'miniImagenet',
                                                     'base.json'),
                                        aug=params.train_aug)))
    test_few_shot_params = dict(n_way=params.test_n_way,
                                n_support=params.n_shot)
    val_datamgr = SetDataManager(image_size,
                                 n_query=n_query,
                                 **test_few_shot_params)
    val_loader = val_datamgr.get_data_loader(val_file, aug=False)

    model = LFTNet(params, tf_path=params.tf_dir)
    model.cuda()

    # resume training
    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.resume != '':
Ejemplo n.º 18
0
#    elif params.method in ['baseline', 'baseline++'] :
#        modelfile   = get_resume_file(checkpoint_dir) #comment in 2019/08/03 updates as the validation of baseline/baseline++ is added
    else:
        modelfile = get_best_file(checkpoint_dir)

    if params.save_iter != -1:
        outfile = os.path.join(
            checkpoint_dir.replace("checkpoints", "features"),
            'test' + "_" + str(params.save_iter) + ".hdf5")
    else:
        outfile = os.path.join(
            checkpoint_dir.replace("checkpoints", "features"),
            'test' + ".hdf5")

    datamgr = SimpleDataManager(batch_size=64)
    base_loader = datamgr.get_data_loader(root='./filelists/tabula_muris',
                                          mode='test')

    if params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            model = backbone.Conv4NP()
        elif params.model == 'Conv6':
            model = backbone.Conv6NP()
        elif params.model == 'Conv4S':
            model = backbone.Conv4SNP()
        else:
            model = model_dict[params.model](flatten=False)
    elif params.method in ['maml', 'maml_approx']:
        raise ValueError('MAML do not support save feature')
    else:
        model = backbone.FCNet(x_dim=base_loader.dataset.get_dim())
Ejemplo n.º 19
0
            for out, label in zip(outputs, labels):
                output_dict[label.item()].append(out)

        all_info = output_dict
        save_pickle(save_dir + '/%s_features.plk' % set, all_info)
        return all_info


if __name__ == '__main__':
    params = parse_args('test')
    loadfile_base = configs.data_dir[params.dataset] + 'base.json'
    loadfile_novel = configs.data_dir[params.dataset] + 'novel.json'

    datamgr = SimpleDataManager(84, batch_size=256)
    base_loader = datamgr.get_data_loader(loadfile_base, aug=False)
    novel_loader = datamgr.get_data_loader(loadfile_novel, aug=False)

    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    modelfile = get_resume_file(checkpoint_dir)

    model = wrn28_10(num_classes=params.num_classes)
    model = model.to(device)

    cudnn.benchmark = True

    checkpoint = torch.load(modelfile)
    state = checkpoint['state']
    state_keys = list(state.keys())
Ejemplo n.º 20
0
#        modelfile   = get_resume_file(checkpoint_dir) #comment in 2019/08/03 updates as the validation of baseline/baseline++ is added
    else:
        modelfile = get_best_file(checkpoint_dir, params.test_n_way)

    if params.save_iter != -1:
        outfile = os.path.join(
            checkpoint_dir.replace("checkpoints", "features"),
            split + "_" + str(params.save_iter) + ".hdf5")
    else:
        outfile = os.path.join(
            checkpoint_dir.replace("checkpoints", "features"), split + ".hdf5")
        #outfile = os.path.join( checkpoint_dir.replace("checkpoints","features"), split + "_test_random-way.hdf5")

    datamgr = SimpleDataManager(image_size, batch_size=64)
    data_loader = datamgr.get_data_loader(loadfile,
                                          [loadfile_unk, loadfile_sil],
                                          aug=False)

    if params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            model = backbone.Conv4NP()
        elif params.model == 'Conv6':
            model = backbone.Conv6NP()
        elif params.model == 'Conv4S':
            model = backbone.Conv4SNP()
        else:
            model = model_dict[params.model](flatten=False)
    elif params.method in ['maml', 'maml_approx']:
        raise ValueError('MAML do not support save feature')
    else:
        model = model_dict[params.model]()
Ejemplo n.º 21
0
    optimization = 'Adam'

    if params.method in ['baseline']:

        if params.dataset == "miniImageNet":
            #print('hi')
            datamgr = miniImageNet_few_shot.SimpleDataManager(image_size,
                                                              batch_size=16)
            #print("bye")
            base_loader = datamgr.get_data_loader(aug=params.train_aug)
            #print("loaded")
        elif params.dataset == "CUB":

            base_file = configs.data_dir['CUB'] + 'base.json'
            base_datamgr = SimpleDataManager(image_size, batch_size=16)
            base_loader = base_datamgr.get_data_loader(base_file,
                                                       aug=params.train_aug)

        elif params.dataset == "cifar100":
            base_datamgr = cifar_few_shot.SimpleDataManager("CIFAR100",
                                                            image_size,
                                                            batch_size=16)
            base_loader = base_datamgr.get_data_loader("base", aug=True)

            params.num_classes = 100

        elif params.dataset == 'caltech256':
            base_datamgr = caltech256_few_shot.SimpleDataManager(image_size,
                                                                 batch_size=16)
            base_loader = base_datamgr.get_data_loader(aug=False)
            params.num_classes = 257
Ejemplo n.º 22
0
def main():
    """Train (or test) a few-shot model selected by the parsed arguments.

    Builds the data loaders for the chosen method, constructs the
    encoder/model, sets up the optimizer and LR scheduler, optionally
    resumes or warm-starts from a checkpoint, then runs the epoch loop with
    per-epoch validation and best/last checkpointing.  Finishes with a final
    test pass.

    Raises:
        NotImplementedError: if the ConvNet backbone is requested (no
            implementation is wired up here).
        ValueError: for an unknown model or method type.
    """
    timer = Timer()
    args, writer = init()

    train_file = args.dataset_dir + 'train.json'
    val_file = args.dataset_dir + 'val.json'

    few_shot_params = dict(n_way=args.n_way, n_support=args.n_shot, n_query=args.n_query)
    n_episode = 10 if args.debug else 100
    if args.method_type is Method_type.baseline:
        train_datamgr = SimpleDataManager(train_file, args.dataset_dir, args.image_size, batch_size=64)
        train_loader = train_datamgr.get_data_loader(aug = True)
    else:
        train_datamgr = SetDataManager(train_file, args.dataset_dir, args.image_size,
                                       n_episode=n_episode, mode='train', **few_shot_params)
        train_loader = train_datamgr.get_data_loader(aug=True)

    val_datamgr = SetDataManager(val_file, args.dataset_dir, args.image_size,
                                     n_episode=n_episode, mode='val', **few_shot_params)
    val_loader = val_datamgr.get_data_loader(aug=False)

    if args.model_type is Model_type.ConvNet:
        # BUG FIX: the original `pass` left `encoder` unbound, causing a
        # confusing NameError when the model was constructed below.
        raise NotImplementedError('ConvNet backbone is not implemented')
    elif args.model_type is Model_type.ResNet12:
        from methods.backbone import ResNet12
        encoder = ResNet12()
    else:
        raise ValueError('')

    if args.method_type is Method_type.baseline:
        from methods.baselinetrain import BaselineTrain
        model = BaselineTrain(encoder, args)
    elif args.method_type is Method_type.protonet:
        from methods.protonet import ProtoNet
        model = ProtoNet(encoder, args)
    else:
        raise ValueError('')

    from torch.optim import SGD,lr_scheduler
    if args.method_type is Method_type.baseline:
        optimizer = SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch, eta_min=0, last_epoch=-1)
    else:
        optimizer = torch.optim.SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4,
                                    nesterov=True)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)

    args.ngpu = torch.cuda.device_count()
    torch.backends.cudnn.benchmark = True
    model = model.cuda()

    # Fixed episode labels: each of the n_way classes repeated n_query times.
    label = torch.from_numpy(np.repeat(range(args.n_way), args.n_query))
    label = label.cuda()

    if args.test:
        test(model, label, args, few_shot_params)
        return

    if args.resume:
        resume_OK =  resume_model(model, optimizer, args, scheduler)
    else:
        resume_OK = False
    # Warm-start from pretrained weights only when no resume checkpoint loaded.
    if (not resume_OK) and  (args.warmup is not None):
        load_pretrained_weights(model, args)

    if args.debug:
        args.max_epoch = args.start_epoch + 1

    for epoch in range(args.start_epoch, args.max_epoch):
        train_one_epoch(model, optimizer, args, train_loader, label, writer, epoch)
        scheduler.step()

        vl, va = val(model, args, val_loader, label)
        if writer is not None:
            writer.add_scalar('data/val_acc', float(va), epoch)
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))

        # Track the best validation accuracy and checkpoint both best and last.
        if va >= args.max_acc:
            args.max_acc = va
            print('saving the best model! acc={:.4f}'.format(va))
            save_model(model, optimizer, args, epoch, args.max_acc, 'max_acc', scheduler)
        save_model(model, optimizer, args, epoch, args.max_acc, 'epoch-last', scheduler)
        if epoch != 0:
            print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
    if writer is not None:
        writer.close()
    test(model, label, args, few_shot_params)
def calculate_dist(n_way, n_shot, task_num, episode, batch_size=64):
    """Rank every candidate image by its distance to the few-shot task's
    class prototypes and dump the sorted ranking to a JSON file.

    A ProtoNet checkpoint for the given ``episode`` is restored, class mean
    features (prototypes) are computed from the task's support images, and
    each candidate image's minimum prototype distance is recorded.  The
    result — image names, labels, and distances sorted by ascending
    distance — is written to a task-specific JSON file.
    """
    model = ProtoNet(model_dict["ResNet18"], n_way, n_shot)

    device = "cuda:0"
    with torch.cuda.device(device):
        model = model.cuda()

    resume_file = "/home/takumi/research/CloserLookFewShot/instance_selection/feature_space/{0}.tar".format(
        episode)
    checkpoint = torch.load(resume_file)
    model.load_state_dict(checkpoint['state'])

    base_file = "/home/takumi/research/CloserLookFewShot/filelists/full_Imagenet_except_testclass.json"

    candidate_loader = SimpleDataManager(224, batch_size).get_data_loader(
        base_file, aug=False, shuffle=False)

    task_file = "/home/takumi/research/CloserLookFewShot/instance_selection/task/few_shot_task{0}.json".format(
        task_num)

    support_images = task_train_reader(task_file, n_shot, n_way)

    # One prototype per class: the mean feature of that class's support set.
    prototypes = [
        torch.mean(model.feature(images.to(device)), dim=0)
        for images in support_images
    ]
    features_ave = torch.stack(prototypes, dim=0)

    with open(base_file) as f:
        base = json.load(f)

    image_dists = []
    image_id = 0
    for batch, labels in tqdm.tqdm(candidate_loader):
        batch_features = model.feature(batch.to(device))
        # Distance to the nearest prototype for every image in the batch.
        nearest, _ = torch.min(euclidean_dist(batch_features, features_ave),
                               dim=1)
        for d, lab in zip(nearest.tolist(), labels.tolist()):
            image_dists.append([d, lab, base["image_names"][image_id]])
            image_id += 1

    # Ascending by distance (entries start with the distance value).
    image_dists.sort()

    task_image_dist = {
        "label_names": copy.deepcopy(base["label_names"]),
        "image_names": [entry[2] for entry in image_dists],
        "image_labels": [entry[1] for entry in image_dists],
        "distance": [entry[0] for entry in image_dists],
    }

    with open(
            "/home/takumi/research/CloserLookFewShot/instance_selection/task/task{0}_dataset_dist_{1}.json"
            .format(task_num, episode), "w") as f:
        json.dump(task_image_dist, f)

    print(len(task_image_dist["image_names"]),
          len(task_image_dist["image_labels"]),
          len(task_image_dist["distance"]))
Ejemplo n.º 24
0
def run_save(params):
    """Extract backbone features for a split and save them to an HDF5 file.

    Resolves the data split file and checkpoint directory from ``params``,
    loads the best (or a specific-iteration) checkpoint, strips the
    ``feature.`` prefix from its state-dict keys so they load into a bare
    backbone, and writes the extracted features via ``save_features``.

    Returns early (no-op) for MAML-style methods, which do not support
    feature saving.

    Raises:
        ValueError: if ``params.method`` is 'maml'/'maml_approx' past the
            early guard (defensive duplicate check).
    """
    print('Run Save features ... ')
    if 'maml' in params.method:
        print('Continuing since maml doesnt support save_feature')
        return

    image_size = get_image_size(params)

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'

    # Cross-domain setups read base examples from one dataset and
    # val/novel examples from another.
    split = params.split
    if params.dataset == 'cross':
        if split == 'base':
            loadfile = configs.data_dir['miniImagenet'] + 'all.json'
        else:
            loadfile = configs.data_dir['CUB'] + split + '.json'
    elif params.dataset == 'cross_char':
        if split == 'base':
            loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
        else:
            loadfile = configs.data_dir['emnist'] + split + '.json'
    else:
        loadfile = configs.data_dir[params.dataset] + split + '.json'

    if hasattr(params, 'checkpoint_dir'):
        checkpoint_dir = params.checkpoint_dir
    else:
        checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, params.method)
        if params.train_aug:
            checkpoint_dir += '_aug'
        if not params.method in ['baseline', 'baseline++']:
            checkpoint_dir += '_%dway_%dshot' % (params.train_n_way,
                                                 params.n_shot)

    print(f'Checkpoint dir: {checkpoint_dir}')
    if params.save_iter != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
    else:
        modelfile = get_best_file(checkpoint_dir)
    print(f'Model file {modelfile}')
    if params.save_iter != -1:
        outfile = os.path.join(
            checkpoint_dir.replace("checkpoints", "features"),
            split + "_" + str(params.save_iter) + ".hdf5")
    else:
        outfile = os.path.join(
            checkpoint_dir.replace("checkpoints", "features"), split + ".hdf5")

    datamgr = SimpleDataManager(image_size, batch_size=64)
    data_loader = datamgr.get_data_loader(loadfile, aug=False)

    # relationnet variants need an unflattened feature map.
    if params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            model = backbone.Conv4NP()
        elif params.model == 'Conv6':
            model = backbone.Conv6NP()
        elif params.model == 'Conv4S':
            model = backbone.Conv4SNP()
        else:
            model = model_dict[params.model](flatten=False)
    elif params.method in ['maml', 'maml_approx']:
        raise ValueError('MAML do not support save feature')
    else:
        model = model_dict[params.model]()

    model = model.cuda()
    tmp = torch.load(modelfile)
    state = tmp['state']
    # The checkpointed model wraps the backbone in a 'feature' attribute;
    # cast 'feature.trunk.xx' keys to 'trunk.xx' and drop everything else
    # (classifier heads etc.) so the state dict fits the bare backbone.
    for key in list(state.keys()):
        if "feature." in key:
            state[key.replace("feature.", "")] = state.pop(key)
        else:
            state.pop(key)

    model.load_state_dict(state)
    model.eval()

    # exist_ok avoids the isdir/makedirs race of the original guard.
    os.makedirs(os.path.dirname(outfile), exist_ok=True)
    save_features(model, data_loader, outfile)
Ejemplo n.º 25
0
    if 'Conv' in params.model:
        if params.dataset in ['omniglot', 'cross_char']:
            image_size = 28
        else:
            image_size = 255
    else:
        image_size = params.image_size

    if params.method in ['baseline', 'baseline++']:
        base_datamgr_l = SimpleDataManager(image_size,
                                           batch_size=params.bs,
                                           jigsaw=False,
                                           rotation=False,
                                           isAircraft=isAircraft,
                                           grey=params.grey)
        base_loader_l = base_datamgr_l.get_data_loader(base_file,
                                                       aug=params.train_aug)

        base_datamgr_u = SimpleDataManager(image_size,
                                           batch_size=params.bs,
                                           jigsaw=params.jigsaw,
                                           rotation=params.rotation,
                                           isAircraft=isAircraft,
                                           grey=params.grey)
        if params.dataset_unlabel is not None:
            base_loader_u = base_datamgr_u.get_data_loader(
                base_file_unlabel, aug=params.train_aug)
        else:
            base_loader_u = base_datamgr_u.get_data_loader(
                base_file, aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size,
                                        batch_size=params.bs,
Ejemplo n.º 26
0
            batch_size = 128
        elif params.dataset in ['caltech256', 'CUB']:
            params.stop_epoch = 100
            batch_size = 64
        elif params.dataset in ['tieredImagenet']:
            params.stop_epoch = 100
            batch_size = 256
        else:
            params.stop_epoch = 400  # default
            batch_size = 128
    else:  # meta-learning methods
        raise ValueError('Unknown methods')

    image_size = 224
    base_datamgr = SimpleDataManager(image_size, batch_size=batch_size)
    base_loader = base_datamgr.get_data_loader(base_file, aug=True)

    base_jigsaw_datamgr = JigsawDataManger(
        image_size,
        batch_size=batch_size,
        max_replace_block_num=params.jig_replace_num_train)
    base_jigsaw_loader = base_jigsaw_datamgr.get_data_loader(base_file,
                                                             aug=False)

    extra_data = 15  # extra_unlabeled data
    val_datamgr = SetDataManager(image_size,
                                 n_way=params.test_n_way,
                                 n_support=params.n_shot,
                                 n_query=params.n_query + extra_data,
                                 n_eposide=50)
    val_loader = val_datamgr.get_data_loader(val_file, aug=False)
Ejemplo n.º 27
0
                filter_size=params.filter_size)
            base_datamgr_test = SimpleDataManager(
                image_size_dct, batch_size=params.test_batch_size)
            base_loader_test = base_datamgr_test.get_data_loader_dct(
                base_file, aug=False, filter_size=params.filter_size)
            test_few_shot_params = dict(n_way=params.train_n_way,
                                        n_support=params.n_shot)
            val_datamgr = SetDataManager(image_size_dct,
                                         n_query=15,
                                         **test_few_shot_params)
            val_loader = val_datamgr.get_data_loader_dct(
                val_file, aug=False, filter_size=params.filter_size)
        else:
            base_datamgr = SimpleDataManager(image_size,
                                             batch_size=params.batch_size)
            base_loader = base_datamgr.get_data_loader(base_file,
                                                       aug=params.train_aug)
            base_datamgr_test = SimpleDataManager(
                image_size, batch_size=params.test_batch_size)
            base_loader_test = base_datamgr_test.get_data_loader(base_file,
                                                                 aug=False)
            test_few_shot_params = dict(n_way=params.train_n_way,
                                        n_support=params.n_shot)
            val_datamgr = SetDataManager(image_size,
                                         n_query=15,
                                         **test_few_shot_params)
            val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist')
Ejemplo n.º 28
0
def get_visualize_data(datasets='multi',
                       save_epoch=399,
                       name='tmp',
                       method='baseline',
                       model='ResNet10',
                       split='novel',
                       data_dir='./filelists',
                       save_dir='./output'):
    """Load a trained feature encoder and return t-SNE data for visualization.

    Args:
        datasets: space-separated dataset names to load (e.g. 'cub cars').
        save_epoch: checkpoint epoch to load; -1 selects the best checkpoint.
        name: experiment name (checkpoint subdirectory).
        method: training method name; relationnet variants use an
            unflattened backbone.
        model: backbone architecture name (key into ``model_dict``).
        split: data split to visualize ('base'/'val'/'novel').
        data_dir: root directory containing per-dataset filelists.
        save_dir: root directory containing the 'checkpoints' folder.

    Returns:
        Whatever ``tsne`` produces for the encoder and data loaders.
    """
    print('Visualizing! {} datasets with {} epochs of {}({})'.format(
        datasets, save_epoch, name, method))

    print('\nStage 1: saving features')
    # dataset
    print('  build dataset')
    if 'Conv' in model:
        image_size = 84
    else:
        image_size = 224
    # (removed the original redundant `split = split` self-assignment)

    data_loaders = []
    for dataset in datasets.split():
        loadfile = os.path.join(data_dir, dataset, split + '.json')
        datamgr = SimpleDataManager(image_size, batch_size=64)
        data_loaders.append(datamgr.get_data_loader(loadfile, aug=False))

    print('  build feature encoder')
    # feature encoder
    checkpoint_dir = '%s/checkpoints/%s' % (save_dir, name)
    if save_epoch != -1:
        modelfile = get_assigned_file(checkpoint_dir, save_epoch)
    else:
        modelfile = get_best_file(checkpoint_dir)
    # NOTE: `model` is rebound from the architecture name (str) to the
    # actual module below.
    if method in ['relationnet', 'relationnet_softmax']:
        if model == 'Conv4':
            model = backbone.Conv4NP()
        elif model == 'Conv6':
            model = backbone.Conv6NP()
        else:
            model = model_dict[model](flatten=False)
    else:
        model = model_dict[model]()
    model = model.cuda()
    tmp = torch.load(modelfile)
    # Some checkpoints store weights under 'state', others 'model_state'.
    try:
        state = tmp['state']
    except KeyError:
        state = tmp['model_state']
    # Strip the 'feature.' prefix (but keep FiLM gamma/beta keys out) so the
    # state dict loads into the bare backbone; drop all non-feature keys.
    for key in list(state.keys()):
        if "feature." in key and not 'gamma' in key and not 'beta' in key:
            state[key.replace("feature.", "")] = state.pop(key)
        else:
            state.pop(key)

    model.load_state_dict(state)
    model.eval()

    return tsne(model, data_loaders)
Ejemplo n.º 29
0
                params.stop_epoch = 200  # This is different as stated in the open-review paper. However, using 400 epoch in baseline actually lead to over-fitting
            elif params.dataset in ['miniImagenet', 'cross']:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 400  #default
        else:  #meta-learning methods
            if params.n_shot == 1:
                params.stop_epoch = 600
            elif params.n_shot == 5:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 600  #default

    if params.method in ['baseline', 'baseline++']:
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist')
Ejemplo n.º 30
0
    seed = 1339

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.enabled = True
    cudnn.benchmark = True

    checkpoint_dir = './checkpoint/'
    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    train_datamgr = SimpleDataManager(image_size=image_size, batch_size=batch_size, num_workers=num_workers)
    train_loader = train_datamgr.get_data_loader(data_path='PATH/TO/DATASET/CAER-S', load_set='train', aug=True)

    test_datamgr = SimpleDataManager(image_size=image_size, batch_size=batch_size, num_workers=num_workers)
    test_loader = test_datamgr.get_data_loader(data_path='PATH/TO/DATASET/CAER-S', load_set='test', aug=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = OURnet()
    
    model = model.to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    CE_loss = nn.CrossEntropyLoss()

    
    optimizer = optim.Adam([{'params': model.face_features.parameters(), 'lr': base_lr},