Example #1
def test(model, label, args, few_shot_params):
    if args.debug:
        n_test = 10
        print_freq = 2
    else:
        n_test = 1000
        print_freq = 100
    test_file = args.dataset_dir + 'test.json'
    test_datamgr = SetDataManager(test_file, args.dataset_dir, args.image_size,
                                  mode='val', n_episode=n_test, **few_shot_params)
    loader = test_datamgr.get_data_loader(aug=False)

    test_acc_record = np.zeros((n_test,))

    warmup_state = torch.load(osp.join(args.checkpoint_dir, 'max_acc' + '.pth'))['params']
    model.load_state_dict(warmup_state, strict=False)
    model.eval()

    ave_acc = Averager()
    with torch.no_grad():
        for i, batch in enumerate(loader, 1):
            data, index_label = batch[0].cuda(), batch[1].cuda()
            logits = model(data, 'test')
            acc = count_acc(logits, label)
            ave_acc.add(acc)
            test_acc_record[i - 1] = acc
            if i % print_freq == 0:
                print('batch {}: {:.2f}({:.2f})'.format(i, ave_acc.item() * 100, acc * 100))

    m, pm = compute_confidence_interval(test_acc_record)
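    # m is the mean episode accuracy; pm is the 95% confidence half-width (1.96 * std / sqrt(n_test))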
    # print('Val Best Epoch {}, Acc {:.4f}, Test Acc {:.4f}'.format(trlog['max_acc_epoch'], trlog['max_acc'],
    #                                                               ave_acc.item()))
    print('Test Acc {:.4f} + {:.4f}'.format(m, pm))
    acc_str = '%4.2f' % (m * 100)
    with open(args.save_dir + '/result.txt', 'a') as f:
        f.write('%s %s\n' % (acc_str, args.name))
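Note: Averager, count_acc and compute_confidence_interval are not shown in this snippet. The following is a minimal sketch, assuming the usual few-shot conventions and the 1.96 * std / sqrt(n) interval that other examples on this page compute explicitly; the helpers in the source repository may differ.

import numpy as np
import torch


class Averager:
    """Keeps a running mean of the scalar values added to it."""

    def __init__(self):
        self.n = 0
        self.v = 0.0

    def add(self, x):
        self.v = (self.v * self.n + x) / (self.n + 1)
        self.n += 1

    def item(self):
        return self.v


def count_acc(logits, label):
    """Fraction of query samples whose top-1 prediction matches the label."""
    pred = torch.argmax(logits, dim=1)
    return (pred == label).float().mean().item()


def compute_confidence_interval(data):
    """Mean accuracy and 95% confidence half-width over per-episode accuracies."""
    a = np.asarray(data)
    return a.mean(), 1.96 * a.std() / np.sqrt(len(a))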
Example #2
    def meta_train(self,
                   config,
                   method,
                   descriptor_str,
                   debug=True,
                   use_test=False,
                   require_pretrain=False,
                   metric="acc"):
        config["meta_training"] = True
        params = self.params
        params.save_freq = 10
        params.n_query = max(1,
                             int(16 * params.test_n_way / params.train_n_way))
        params.dataset = config["dataset"]
        params.model = config["model"]
        params.method = config["method"]
        params.n_shot = config["n_shot"]
        train_episodes = config["train_episodes"]
        val_episodes = config["val_episodes"]
        end_epoch = config["end_epoch"]
        if "weight_decay" in config:
            weight_decay = config["weight_decay"]
        else:
            weight_decay = 0

        result_dir = "results/meta/%s" % (params.dataset)
        if not os.path.isdir(result_dir):
            os.makedirs(result_dir)
        result_file = os.path.join(
            result_dir,
            "%s_%s_%s.txt" % (params.method, params.model, descriptor_str))

        self.few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot,
                                    n_query=self.params.n_query)
        params.checkpoint_dir = '%s/checkpoints/%s/%s/%s_%s_%s' % (
            configs.save_dir, params.dataset, params.method, params.model,
            descriptor_str, params.n_shot)
        params.stop_epoch = 100
        self.initialize(params, False)
        image_size = self.image_size
        pretrain = PretrainedModel(self.params)

        if use_test:
            file_name = "novel.json"
        else:
            file_name = "val.json"
        if params.dataset == 'cross':
            base_file = configs.data_dir['miniImagenet'] + 'base.json'
            val_file = configs.data_dir['CUB'] + file_name
        else:
            base_file = configs.data_dir[params.dataset] + 'base.json'
            val_file = configs.data_dir[params.dataset] + file_name

        n_query = max(1, int(16 * params.test_n_way / params.train_n_way))

        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        base_datamgr = SetDataManager(image_size,
                                      n_query=n_query,
                                      **train_few_shot_params,
                                      n_eposide=train_episodes)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug,
                                                   debug=debug)

        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params,
                                     n_eposide=val_episodes)
        val_loader = val_datamgr.get_data_loader(val_file,
                                                 aug=False,
                                                 debug=debug)

        backbone = self.get_backbone()
        if "params" in config:
            model_params = config["params"]
        else:
            model_params = {}
        if require_pretrain:
            model_params["pretrain"] = pretrain

        model = method(backbone, **train_few_shot_params, **model_params)
        model = model.cuda()

        # Freeze backbone
        model.feature = None
        if not require_pretrain:
            model.pretrain = pretrain
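        # with model.feature removed, episode embeddings come from the separately
        # loaded PretrainedModel rather than from a trainable backbone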
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=weight_decay)

        max_acc = 0
        for epoch in range(0, end_epoch):
            model.epoch = epoch
            model.train()
            model.train_loop(
                epoch, base_loader,
                optimizer)  # the model is updated in place, so nothing needs to be returned
            model.eval()

            if not os.path.isdir(params.checkpoint_dir):
                os.makedirs(params.checkpoint_dir)

            acc = model.test_loop(val_loader, metric=metric)
            message = "Epoch: %d, Validation accuracy: %.3f, Best validation accuracy: %.3f" % (
                epoch, acc, max_acc)
            print(message)
            append_to_file(result_file, message)

            if acc > max_acc:
                print("best model! save...")
                max_acc = acc
                outfile = os.path.join(params.checkpoint_dir, 'best_model.tar')
                torch.save({
                    'epoch': epoch,
                    'state': model.state_dict()
                }, outfile)

            if (epoch % params.save_freq == 0) or (epoch == params.stop_epoch - 1):
                outfile = os.path.join(params.checkpoint_dir,
                                       '{:d}.tar'.format(epoch))
                torch.save({
                    'epoch': epoch,
                    'state': model.state_dict()
                }, outfile)
        self.meta_test(config, method, descriptor_str, debug, require_pretrain)
Example #3
class Experiment():
    def __init__(self, params):
        np.random.seed(10)

        if params.train_dataset == 'cross':
            base_file = configs.data_dir['miniImagenet'] + 'all.json'
            val_file = configs.data_dir['CUB'] + 'val.json'
        elif params.train_dataset == 'cross_char':
            base_file = configs.data_dir['omniglot'] + 'noLatin.json'
            val_file = configs.data_dir['emnist'] + 'val.json'
        else:
            base_file = configs.data_dir[params.train_dataset] + 'base.json'
            val_file = configs.data_dir[params.train_dataset] + 'val.json'

        if 'Conv' in params.model:
            if params.train_dataset in ['omniglot', 'cross_char']:
                image_size = 28
            else:
                image_size = 84
        else:
            image_size = 224

        if params.train_dataset in ['omniglot', 'cross_char']:
            assert params.model == 'Conv4' and not params.train_aug, 'omniglot only supports Conv4 without augmentation'
            params.model = 'Conv4S'

        if params.train_dataset == 'omniglot':
            assert params.num_classes >= 4112, 'num_classes needs to be at least the max label id in the base classes'
        if params.train_dataset == 'cross_char':
            assert params.num_classes >= 1597, 'num_classes needs to be at least the max label id in the base classes'

        params.train_num_query = max(
            1,
            int(params.test_num_query * params.test_num_way /
                params.train_num_way))
        if params.episodic:
            train_few_shot_params = dict(n_way=params.train_num_way,
                                         n_support=params.train_num_shot,
                                         n_query=params.train_num_query)
            base_datamgr = SetDataManager(image_size, **train_few_shot_params)
            base_loader = base_datamgr.get_data_loader(base_file,
                                                       aug=params.train_aug)
        else:
            base_datamgr = SimpleDataManager(image_size, batch_size=32)
            base_loader = base_datamgr.get_data_loader(base_file,
                                                       aug=params.train_aug)

        if params.test_dataset == 'cross':
            novel_file = configs.data_dir['CUB'] + 'novel.json'
        elif params.test_dataset == 'cross_char':
            novel_file = configs.data_dir['emnist'] + 'novel.json'
        else:
            novel_file = configs.data_dir[params.test_dataset] + 'novel.json'

        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(novel_file, aug=False)

        novel_datamgr = SimpleDataManager(image_size, batch_size=64)
        novel_loader = novel_datamgr.get_data_loader(novel_file, aug=False)

        optimizer = params.optimizer

        if params.stop_epoch == -1:
            if params.train_dataset in ['omniglot', 'cross_char']:
                params.stop_epoch = 5
            elif params.train_dataset in ['CUB']:
                params.stop_epoch = 200  # differs from the open-review paper; using 400 epochs for the baseline actually leads to over-fitting
            elif params.train_dataset in ['miniImagenet', 'cross']:
                params.stop_epoch = 300
            else:
                params.stop_epoch = 300

        shake_config = {
            'shake_forward': params.shake_forward,
            'shake_backward': params.shake_backward,
            'shake_picture': params.shake_picture
        }
        train_param = {
            'loss_type': params.train_loss_type,
            'temperature': params.train_temperature,
            'margin': params.train_margin,
            'lr': params.train_lr,
            'shake': params.shake,
            'shake_config': shake_config,
            'episodic': params.episodic,
            'num_way': params.train_num_way,
            'num_shot': params.train_num_shot,
            'num_query': params.train_num_query,
            'num_classes': params.num_classes
        }
        test_param = {
            'loss_type': params.test_loss_type,
            'temperature': params.test_temperature,
            'margin': params.test_margin,
            'lr': params.test_lr,
            'num_way': params.test_num_way,
            'num_shot': params.test_num_shot,
            'num_query': params.test_num_query
        }

        model = Baseline(model_dict[params.model], params.entropy, train_param,
                         test_param)

        model = model.cuda()

        key = params.tag
        writer = SummaryWriter(log_dir=os.path.join(params.vis_log, key))

        params.checkpoint_dir = '%s/checkpoints/%s/%s' % (
            configs.save_dir, params.train_dataset, params.checkpoint_dir)

        if not os.path.isdir(params.vis_log):
            os.makedirs(params.vis_log)

        outfile_template = os.path.join(
            params.checkpoint_dir.replace("checkpoints", "features"),
            "%s.hdf5")

        if params.mode == 'train' and not os.path.isdir(params.checkpoint_dir):
            os.makedirs(params.checkpoint_dir)

        if params.resume or params.mode == 'test':
            if params.mode == 'test':
                self.feature_model = model_dict[params.model]().cuda()
                resume_file = get_best_file(params.checkpoint_dir)
                tmp = torch.load(resume_file)
                state = tmp['state']
                state_keys = list(state.keys())
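                # keep only backbone weights, renaming "feature.layer" -> "layer" so the
                # bare feature extractor can load them; classifier entries are dropped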
                for i, key in enumerate(state_keys):
                    if "feature." in key:
                        newkey = key.replace("feature.", "")
                        state[newkey] = state.pop(key)
                    else:
                        state.pop(key)
                self.feature_model.load_state_dict(state)
                self.feature_model.eval()
            else:
                resume_file = get_resume_file(params.checkpoint_dir)
                tmp = torch.load(resume_file)
                state = tmp['state']
                model.load_state_dict(state)
                params.start_epoch = tmp['epoch'] + 1

            print('Info: Model loaded!!!')

        self.params = params
        self.val_file = val_file
        self.base_file = base_file
        self.image_size = image_size
        self.optimizer = optimizer
        self.outfile_template = outfile_template
        self.novel_loader = novel_loader
        self.base_loader = base_loader
        self.val_loader = val_loader
        self.writer = writer
        self.model = model
        self.key = key

    def train(self):
        if self.optimizer == 'Adam':
            train_optimizer = torch.optim.Adam(self.model.parameters(),
                                               lr=self.params.train_lr)
            train_scheduler = StepLR(train_optimizer, step_size=75, gamma=0.1)
        elif self.optimizer == 'SGD':
            train_optimizer = torch.optim.SGD(self.model.parameters(),
                                              lr=self.params.train_lr,
                                              momentum=0.9,
                                              weight_decay=0.001)
            train_scheduler = StepLR(train_optimizer, step_size=75, gamma=0.1)
        else:
            raise ValueError('Unknown optimizer, please define by yourself')

        max_acc = 0
        start_epoch = self.params.start_epoch
        stop_epoch = self.params.stop_epoch
        test_start_epoch = int(stop_epoch * self.params.test_start_epoch)
        for epoch in range(start_epoch, stop_epoch):
            self.model.train()
            train_num_way = self.params.train_num_way
            train_num_query = self.params.train_num_query
            if self.params.curriculum:
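                # curriculum: linearly anneal the episode way from train_num_way down to
                # test_num_way over training, rescaling n_query to keep the batch size steady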
                train_num_way = int(
                    self.params.train_num_way -
                    (self.params.train_num_way - self.params.test_num_way) *
                    (epoch - start_epoch) / (stop_epoch - start_epoch))
                train_num_query = max(
                    1,
                    int(self.params.test_num_query * self.params.test_num_way /
                        train_num_way))
                train_few_shot_params = dict(
                    n_way=train_num_way,
                    n_support=self.params.train_num_shot,
                    n_query=train_num_query)
                self.base_datamgr = SetDataManager(self.image_size,
                                                   **train_few_shot_params)
                self.base_loader = self.base_datamgr.get_data_loader(
                    self.base_file, aug=self.params.train_aug)
                self.writer.add_scalar('way/curriculum_way', train_num_way,
                                       epoch)
            self.model.train_loop(epoch, train_num_way, train_num_query,
                                  self.base_loader, train_optimizer,
                                  train_scheduler, self.writer)

            if epoch >= test_start_epoch and (epoch + 1) % 5 == 0:
                self.model.eval()

                acc = self.test('val', epoch)

                if acc > max_acc:  # for baseline and baseline++ we don't validate here, so acc is left as -1
                    print("best model! save...")
                    max_acc = acc
                    outfile = os.path.join(self.params.checkpoint_dir,
                                           'best_model.tar')
                    torch.save(
                        {
                            'epoch': epoch,
                            'state': self.model.state_dict()
                        }, outfile)

            if (epoch % self.params.save_freq == 0) or (epoch == stop_epoch - 1):
                outfile = os.path.join(self.params.checkpoint_dir,
                                       '{:d}.tar'.format(epoch))
                torch.save({
                    'epoch': epoch,
                    'state': self.model.state_dict()
                }, outfile)

    def test(self, split='novel', epoch=0):
        self.outfile = self.outfile_template % split
        if split == 'novel':
            self.save_feature(self.novel_loader)
        else:
            self.save_feature(self.val_loader)
        cl_data_file = feat_loader.init_loader(self.outfile)

        acc_all = []
        for i in tqdm(range(self.params.test_epoch)):
            if self.params.fast_adapt:
                acc = self.model.fast_adapt(cl_data_file)
            else:
                acc = self.model.test_loop(cl_data_file)
            acc_all.append(acc)

        acc_all = np.asarray(acc_all)
        acc_mean = np.mean(acc_all)
        acc_std = np.std(acc_all)
        print('%d Test Acc = %4.2f%% +- %4.2f%%' %
              (self.params.test_epoch, acc_mean,
               1.96 * acc_std / np.sqrt(self.params.test_epoch)))
        if self.params.mode != 'test':
            self.writer.add_scalar('acc/%s_acc' % split, acc_mean, epoch)

        return acc_mean

    def save_feature(self, data_loader):
        print('Info: Saving feature...')
        dirname = os.path.dirname(self.outfile)
        if not os.path.isdir(dirname):
            os.makedirs(dirname)
        f = h5py.File(self.outfile, 'w')
        max_count = len(data_loader) * data_loader.batch_size
        all_labels = f.create_dataset('all_labels', (max_count, ), dtype='i')
        all_feats = None
        count = 0
        for i, (x, y) in enumerate(data_loader):
            x = Variable(x.cuda())
            if self.params.mode == 'train':
                feats = self.model.feature(x)
            else:
                feats = self.feature_model(x)
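            # create the HDF5 feature dataset lazily, once the feature
            # dimension is known from the first batch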
            if all_feats is None:
                all_feats = f.create_dataset('all_feats', [max_count] +
                                             list(feats.size()[1:]),
                                             dtype='f')
            all_feats[count:count + feats.size(0)] = feats.data.cpu().numpy()
            all_labels[count:count + feats.size(0)] = y.cpu().numpy()
            count = count + feats.size(0)

        count_var = f.create_dataset('count', (1, ), dtype='i')
        count_var[0] = count

        f.close()
        print('Info: Done saving feature!!!')

    def run(self):
        if self.params.mode == 'train':
            self.train()
        elif self.params.mode == 'test':
            self.test()
Example #4
                val_file, aug=False, filter_size=params.filter_size)
        else:
            base_datamgr = SimpleDataManager(image_size,
                                             batch_size=params.batch_size)
            base_loader = base_datamgr.get_data_loader(base_file,
                                                       aug=params.train_aug)
            base_datamgr_test = SimpleDataManager(
                image_size, batch_size=params.test_batch_size)
            base_loader_test = base_datamgr_test.get_data_loader(base_file,
                                                                 aug=False)
            test_few_shot_params = dict(n_way=params.train_n_way,
                                        n_support=params.n_shot)
            val_datamgr = SetDataManager(image_size,
                                         n_query=15,
                                         **test_few_shot_params)
            val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist')

        elif params.method == 'manifold_mixup':
            if params.model == 'WideResNet28_10':
                model = wrn_mixup_model.wrn28_10(params.num_classes)
            elif params.model == 'ResNet18':
                model = res_mixup_model.resnet18(
                    num_classes=params.num_classes)

        elif params.method in ['S2M2_R', 'rotation']:
            if params.model == 'WideResNet28_10':
Example #5
    # dataloader
    print('\n--- Prepare dataloader ---')
    print('\ttrain with seen domain {}'.format(params.dataset))
    print('\tval with seen domain {}'.format(params.testset))
    base_file = os.path.join(params.data_dir, params.dataset, 'base.json')
    val_file = os.path.join(params.data_dir, params.testset, 'val.json')

    # model
    image_size = 224
    n_query = max(1, int(16 * params.test_n_way / params.train_n_way))
    base_datamgr = SetDataManager(image_size,
                                  n_query=n_query,
                                  n_way=params.train_n_way,
                                  n_support=params.n_shot)
    base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)
    val_datamgr = SetDataManager(image_size,
                                 n_query=n_query,
                                 n_way=params.test_n_way,
                                 n_support=params.n_shot)
    val_loader = val_datamgr.get_data_loader(val_file, aug=False)

    if params.method == 'MatchingNet':
        model = MatchingNet(model_dict[params.model],
                            n_way=params.train_n_way,
                            n_support=params.n_shot).cuda()
    elif params.method == 'RelationNet':
        model = RelationNet(model_dict[params.model],
                            n_way=params.train_n_way,
                            n_support=params.n_shot).cuda()
    elif params.method == 'RelationNetLRP':
Example #6
        else:
            image_size = 84
    else:
        image_size = 224

    optimization = 'Adam'

    n_query = max(1, int(16 * params.test_n_way / params.train_n_way))
    # if test_n_way is smaller than train_n_way, reduce n_query to keep the batch size small
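    # e.g. train_n_way=20, test_n_way=5: n_query = max(1, int(16*5/20)) = 4,
    # so n_way * n_query stays fixed at 80 query images per episode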
    train_few_shot_params = dict(n_way=params.train_n_way,
                                 n_support=params.n_shot)
    base_datamgr = SetDataManager(image_size,
                                  n_query=n_query,
                                  **train_few_shot_params)
    base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)
    if params.n_shot_test == -1:  # modify val loader support
        params.n_shot_test = params.n_shot
    else:  # modify target loader support
        train_few_shot_params['n_support'] = params.n_shot_test
    target_datamgr = SetDataManager(image_size,
                                    n_query=n_query,
                                    **train_few_shot_params)
    target_loader = target_datamgr.get_data_loader(novel_file, aug=False)
    base_loader = [base_loader, target_loader]

    test_few_shot_params = dict(n_way=params.test_n_way,
                                n_support=params.n_shot_test)
    val_datamgr = SetDataManager(image_size,
                                 n_query=n_query,
                                 **test_few_shot_params)
Example #7
    n_query = max(1, int(16 * args.test_n_way / args.train_n_way))

    train_few_shot_args = dict(n_way=args.train_n_way, n_support=args.n_shot)
    base_datamgr = SetDataManager("CUB",
                                  84,
                                  n_query=n_query,
                                  **train_few_shot_args,
                                  args=args)
    print("Loading train data")

    base_loader = base_datamgr.get_data_loader(
        base_file,
        aug=True,
        lang_dir=constants.LANG_DIR,
        normalize=True,
        vocab=vocab,
        # Maximum training data restrictions only apply at train time
        max_class=args.max_class,
        max_img_per_class=args.max_img_per_class,
        max_lang_per_class=args.max_lang_per_class,
    )

    val_datamgr = SetDataManager(
        "CUB",
        84,
        n_query=n_query,
        n_way=args.test_n_way,
        n_support=args.n_shot,
        args=args,
    )
    print("Loading val data\n")
Example #8
        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
            else:
                loadfile = configs.data_dir['emnist'] + split + '.json'
        else:
            loadfile = configs.data_dir[params.dataset] + split + '.json'

        novel_loader = datamgr.get_data_loader(loadfile,
                                               aug=False,
                                               is_train=False)
        if params.adaptation:
            model.task_update_num = 100  # we perform adaptation on MAML simply by taking more update steps
        model.eval()
        acc_mean, acc_std = model.test_loop(novel_loader, return_std=True)
    elif params.method == 'comet':
        if 'Conv' in params.model:
            if params.dataset in ['omniglot', 'cross_char']:
                image_size = 28
            else:
                image_size = 84
        else:
            image_size = 224
        loadfile = configs.data_dir[params.dataset] + split + '.json'
        datamgr = SetDataManager(image_size,
Example #9
    y = np.repeat(range(n_way), n_query)
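    # episodic convention: query labels are [0]*n_query + [1]*n_query + ...,
    # i.e. each class index repeated once per query image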
    acc = np.mean(pred == y) * 100
    return acc


if __name__ == '__main__':
    params = parse_args('test')

    acc_all = []

    iter_num = 600

    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    datamgr = SetDataManager(n_eposide=iter_num, n_query=15, **few_shot_params)
    novel_loader = datamgr.get_data_loader(root='./filelists/tabula_muris',
                                           mode='test')

    x_dim = novel_loader.dataset.get_dim()
    go_mask = novel_loader.dataset.go_mask

    if params.method == 'baseline':
        model = BaselineFinetune(backbone.FCNet(x_dim), **few_shot_params)
    elif params.method == 'baseline++':
        model = BaselineFinetune(backbone.FCNet(x_dim),
                                 loss_type='dist',
                                 **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(backbone.FCNet(x_dim), **few_shot_params)
    elif params.method == 'comet':
        model = COMET(backbone.EnFCNet(x_dim, go_mask), **few_shot_params)
    elif params.method == 'matchingnet':
Example #10
File: train.py Project: killsking/tcmaml
        elif params.n_shot == 5:
            params.stop_epoch = 400
        else:
            params.stop_epoch = 600

    if params.method in ['tcmaml', 'tcmaml_approx']:
        n_query = max(1, int(16 * params.test_n_way / params.train_n_way))
        # if test_n_way is smaller than train_n_way, reduce n_query to keep the batch size small

        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        train_datamgr = SetDataManager(
            image_size, n_query=n_query, **train_few_shot_params
        )  # default number of episodes (tasks) is 100 per epoch
        train_loader = train_datamgr.get_data_loader(train_file,
                                                     aug=params.train_aug)

        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.ResNet.maml = True
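        # these flags switch the backbone blocks to their MAML variants, letting
        # gradients flow through the inner-loop fast weights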

        model = TCMAML(model_dict[params.model],
                       approx=(params.method == 'tcmaml_approx'),
                       **train_few_shot_params)
Example #11
    test_file = configs.data_dir[params.dataset] + 'novel.json'

    record_dir = './record'
    if not os.path.isdir(record_dir):
        os.makedirs(record_dir)

    image_size = 224
    if params.method in ['jigsaw', 'imprint_jigsaw']:
        extra_data = 15  # extra unlabeled images added to each episode's query set
        test_datamgr = SetDataManager(image_size,
                                      n_way=params.test_n_way,
                                      n_support=params.n_shot,
                                      n_query=params.n_query + extra_data,
                                      n_eposide=1)
        test_loader = test_datamgr.get_data_loader(test_file, aug=False)
        if params.dataset == "miniImagenet":
            num_class = 64
        elif params.dataset == "tieredImagenet":
            num_class = 351
        elif params.dataset == "caltech256":
            num_class = 257
        elif params.dataset == "CUB":
            num_class = 200  # set to 200 since labels range over 0-199, even though only 100 classes are trained
        else:
            raise ValueError('Unknown dataset')

        if params.method == "jigsaw":
            model = Jigsaw(num_class=num_class)
        else:
            model = ImprintJigsaw(num_class=num_class)
Example #12
        datamgr = SetDataManager(image_size,
                                 n_episode=iter_num,
                                 n_query=15,
                                 **few_shot_args,
                                 args=args)

        if args.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        else:
            loadfile = configs.data_dir[args.dataset] + split + '.json'

        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if args.adaptation:
            model.task_update_num = 100  # we perform adaptation on MAML simply by taking more update steps
        model.eval()
        acc_mean, acc_std = model.test_loop(novel_loader, return_std=True)
    elif split == 'attributes':
        if 'Conv' in args.model:
            image_size = 84
        else:
            image_size = 224

        datamgr = AttrDataManager(image_size,
                                  num_workers=args.n_workers,
                                  pin_memory=args.pin_memory)

        if args.attr_dataset is None:
Example #13
def exp_test(params, n_episodes, should_del_features=False):
    start_time = datetime.datetime.now()
    print('exp_test() started at', start_time)

    set_random_seed(0)  # reproduces the "normal" testing run

    if params.gpu_id:
        set_gpu_id(params.gpu_id)

    model = get_model(params, 'test')
    
    ########## get settings ##########
    n_shot = params.test_n_shot if params.test_n_shot is not None else params.n_shot
    few_shot_params = dict(n_way=params.test_n_way, n_support=n_shot)
    if params.gpu_id:
        model = model.cuda()
    else:
        model = to_device(model)
    checkpoint_dir = get_checkpoint_dir(params)
    print('loading from:', checkpoint_dir)
    if params.save_iter != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
    else:
        modelfile = get_best_file(checkpoint_dir)
    
    ########## load model ##########
    if modelfile is not None:
        if params.gpu_id is None:
            tmp = torch.load(modelfile)
        else:  # map the checkpoint onto the first visible GPU
            print('params.gpu_id =', params.gpu_id)
            # see https://hackmd.io/koKAo6kURn2YBqjoXXDhaw#RuntimeError-CUDA-error-invalid-device-ordinal
            map_location = 'cuda:0'
            tmp = torch.load(modelfile, map_location=map_location)
        if params.method not in ['baseline', 'baseline++']:
            model.load_state_dict(tmp['state'])
            print('Model successfully loaded.')
        else:
            # baseline/baseline++ rebuild their classifier at test time, so no weights need loading
            print('No need to load model for baseline/baseline++ when testing.')
        load_epoch = int(tmp['epoch'])
    
    ########## testing ##########
    if params.method in ['maml', 'maml_approx']:  # MAML does not support testing on extracted features
        image_size = get_img_size(params)
        loadfile = get_loadfile_path(params, params.split)

        datamgr = SetDataManager(image_size, n_episode=n_episodes, n_query=15, **few_shot_params)

        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if params.adaptation:
            model.task_update_num = 100  # we perform adaptation on MAML simply by taking more update steps
        model.eval()
        acc_mean, acc_std = model.test_loop(novel_loader, return_std=True)
        
        ########## last record and post-process ##########
        torch.cuda.empty_cache()
        timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
        acc_str = '%4.2f%% +- %4.2f%%' % (acc_mean, 1.96 * acc_std / np.sqrt(n_episodes))
        # writing settings into csv
        acc_mean_str = '%4.2f' % acc_mean
        acc_std_str = '%4.2f' % acc_std
        # record beyond params
        extra_record = {'time': timestamp, 'acc_mean': acc_mean_str, 'acc_std': acc_std_str, 'epoch': load_epoch}
        if should_del_features:
            del_features(params)
        end_time = datetime.datetime.now()
        print('exp_test() started at', start_time, 'and ended at', end_time, '.\n')
        print('exp_test() total runtime:', end_time - start_time)
        return extra_record, None  # the MAML path collects no per-task data

    else: # not MAML
        acc_all = []
#         # draw_task: initialize task acc(actually can replace acc_all), img_path, img_is_correct, etc.
#         task_datas = [None]*n_episodes # list of dict
        # directly use extracted features
        all_feature_files = get_all_feature_files(params)
        
        if params.n_test_candidates is None: # common setting (no candidate)
            # draw_task: initialize task acc(actually can replace acc_all), img_path, img_is_correct, etc.
            task_datas = [None]*n_episodes # list of dict
            
            feature_file = all_feature_files[0]
            cl_feature, cl_filepath = feat_loader.init_loader(feature_file, return_path=True)
            cl_feature_single = [cl_feature]
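            # wrapped in a list so the single-feature case matches the
            # multi-candidate interface of feature_evaluation below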
            
            for i in tqdm(range(n_episodes)):
                # TODO afterward: fix data list? can only fix class list?
                task_data = feature_evaluation(
                    cl_feature_single, model, params=params, n_query=15, **few_shot_params, 
                    cl_filepath=cl_filepath,
                )
                acc = task_data['acc']
                acc_all.append(acc)
                task_datas[i] = task_data
            
            acc_all  = np.asarray(acc_all)
            acc_mean = np.mean(acc_all)
            acc_std  = np.std(acc_all)
            print('loaded from %d epoch model.' %(load_epoch))
            print('%d episodes, Test Acc = %4.2f%% +- %4.2f%%' %(n_episodes, acc_mean, 1.96* acc_std/np.sqrt(n_episodes)))
            
            ########## last record and post-process ##########
            torch.cuda.empty_cache()
            timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
            acc_str = '%4.2f%% +- %4.2f%%' % (acc_mean, 1.96 * acc_std / np.sqrt(n_episodes))
            # writing settings into csv
            acc_mean_str = '%4.2f' % acc_mean
            acc_std_str = '%4.2f' % acc_std
            # record beyond params
            extra_record = {'time': timestamp, 'acc_mean': acc_mean_str, 'acc_std': acc_std_str, 'epoch': load_epoch}
            if should_del_features:
                del_features(params)
            end_time = datetime.datetime.now()
            print('exp_test() started at', start_time, 'and ended at', end_time, '.\n')
            print('exp_test() total runtime:', end_time - start_time)
            return extra_record, task_datas
        else: # n_test_candidates settings
                
            candidate_cl_feature = [] # features of each class of each candidates
            print('Loading features of %s candidates into dictionaries...' %(params.n_test_candidates))
            for n in tqdm(range(params.n_test_candidates)):
                nth_feature_file = all_feature_files[n]
                cl_feature, cl_filepath = feat_loader.init_loader(nth_feature_file, return_path=True)
                candidate_cl_feature.append(cl_feature)

            print('Evaluating...')
            
            # TODO: frac_acc_all
            is_single_exp = not isinstance(params.frac_ensemble, list)
            if is_single_exp:
                # draw_task: initialize task acc(actually can replace acc_all), img_path, img_is_correct, etc.
                task_datas = [None]*n_episodes # list of dict
                ########## test and record acc ##########
                for i in tqdm(range(n_episodes)):
                    # TODO afterward: fix data list? can only fix class list?

                    task_data = feature_evaluation(
                        candidate_cl_feature, model, params=params, n_query=15, **few_shot_params, 
                        cl_filepath=cl_filepath,
                    )
                    acc = task_data['acc']
                    acc_all.append(acc)
                    task_datas[i] = task_data
                    
                    collected = gc.collect()
#                     print("Garbage collector: collected %d objects." % (collected))

                acc_all  = np.asarray(acc_all)
                acc_mean = np.mean(acc_all)
                acc_std  = np.std(acc_all)
                print('loaded from %d epoch model.' %(load_epoch))
                print('%d episodes, Test Acc = %4.2f%% +- %4.2f%%' %(n_episodes, acc_mean, 1.96* acc_std/np.sqrt(n_episodes)))
                collected = gc.collect()
                print("garbage collector: collected %d objects." % (collected))
                
                ########## last record and post-process ##########
                torch.cuda.empty_cache()
                timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
                acc_str = '%4.2f%% +- %4.2f%%' % (acc_mean, 1.96 * acc_std / np.sqrt(n_episodes))
                # writing settings into csv
                acc_mean_str = '%4.2f' % acc_mean
                acc_std_str = '%4.2f' % acc_std
                # record beyond params
                extra_record = {'time': timestamp, 'acc_mean': acc_mean_str, 'acc_std': acc_std_str, 'epoch': load_epoch}
                if should_del_features:
                    del_features(params)
                end_time = datetime.datetime.now()
                print('exp_test() started at', start_time, 'and ended at', end_time, '.\n')
                print('exp_test() total runtime:', end_time - start_time)
                return extra_record, task_datas
            else: ########## multi-frac_ensemble exps ##########
                
                ########## (haven't modified) test and record acc ##########
                n_fracs = len(params.frac_ensemble)
                
                ##### initialize frac_data #####
                frac_acc_alls = [[0]*n_episodes for _ in range(n_fracs)]
                frac_acc_means = [None]*n_fracs
                frac_acc_stds = [None]*n_fracs
                # draw_task: initialize task acc(actually can replace acc_all), img_path, img_is_correct, etc.
                ep_task_data_each_frac = [[None]*n_episodes for _ in range(n_fracs)] # list of list of dict

                for ep_id in tqdm(range(n_episodes)):
                    # TODO afterward: fix data list? can only fix class list?
                    
                    # TODO my_utils.py: feature_eval return frac_task_data
                    frac_task_data = feature_evaluation(
                        candidate_cl_feature, model, params=params, n_query=15, **few_shot_params, 
                        cl_filepath=cl_filepath,
                    )
                    for frac_id in range(n_fracs):
                        task_data = frac_task_data[frac_id]
                        # TODO: I think the problem is here
                        acc = task_data['acc']
                        frac_acc_alls[frac_id][ep_id] = acc
                        ep_task_data_each_frac[frac_id][ep_id] = task_data
                        
                        collected = gc.collect()
#                         print("Garbage collector: collected %d objects." % (collected))
                    collected = gc.collect()
#                     print("Garbage collector: collected %d objects." % (collected))
                for frac_id in range(n_fracs):
                    frac_acc_alls[frac_id]  = np.asarray(frac_acc_alls[frac_id])
                    acc_all = frac_acc_alls[frac_id]
                    acc_mean = np.mean(acc_all)
                    acc_std = np.std(acc_all)
                    frac_acc_means[frac_id] = acc_mean
                    frac_acc_stds[frac_id]  = acc_std
                    print('loaded from %d epoch model, frac_ensemble:'%(load_epoch), params.frac_ensemble[frac_id])
                    print('%d episodes, Test Acc = %4.2f%% +- %4.2f%%' %(n_episodes, acc_mean, 1.96* acc_std/np.sqrt(n_episodes)))
                
                ########## (haven't modified) last record and post-process ##########
                torch.cuda.empty_cache()
                timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
                # TODO afterward: compute this
#                 acc_str = '%4.2f%% +- %4.2f%%' % (acc_mean, 1.96* acc_std/np.sqrt(n_episodes))
                frac_extra_records = []
                for frac_id in range(n_fracs):
                    # writing settings into csv
                    acc_mean = frac_acc_means[frac_id]
                    acc_std = frac_acc_stds[frac_id]
                    acc_mean_str = '%4.2f' % (acc_mean)
                    acc_std_str = '%4.2f' %(acc_std)
                    # record beyond params
                    extra_record = {'time':timestamp, 'acc_mean':acc_mean_str, 'acc_std':acc_std_str, 'epoch':load_epoch}
                    frac_extra_records.append(extra_record)
                
                if should_del_features:
                    del_features(params)
                end_time = datetime.datetime.now()
                print('exp_test() started at', start_time, 'and ended at', end_time, '.\n')
                print('exp_test() total runtime:', end_time - start_time)
                
                return frac_extra_records, ep_task_data_each_frac
Example #14
        modelfile = get_best_file(params.checkpoint_dir)
    print("  load model: %s" % modelfile)

    # start evaluate
    print('\n--- start the testing ---')
    n_exp = params.n_exp
    n_iter = params.n_iter
    tf_path = '%s/log_test/%s_iter_%s_%s' % (params.save_dir, params.name,
                                             params.n_iter, params.opt)
    tf_writer = SummaryWriter(log_dir=tf_path)

    # statistics
    print('\n--- get statistics ---')
    for i in range(n_exp):
        acc_all = np.empty((n_task, 2))
        base_data_loader = datamgr.get_data_loader(base_loadfile, aug=False)
        test_data_loader = datamgr.get_data_loader(test_loadfile, aug=False)

        base_data_generator = iter(base_data_loader)
        test_data_generator = iter(test_data_loader)

        task_pbar = tqdm(range(n_task))
        for j in task_pbar:
            base_task = next(base_data_generator)[0]
            test_task = next(test_data_generator)[0]
            n_sub_query = params.n_sub_query
            _ = model.resume(modelfile)
            if False:  # hard-coded switch; the test_uni branch is disabled
                base_acc, test_acc = test_uni(base_task, test_task, model,
                                              n_iter, n_sub_query, params)
            else:
Example #16
    def meta_test(self,
                  config,
                  method,
                  descriptor_str,
                  debug=True,
                  require_pretrain=False,
                  metric="acc"):
        config["meta_training"] = True
        params = self.params
        params.save_freq = 50
        params.n_query = max(1,
                             int(16 * params.test_n_way / params.train_n_way))
        params.dataset = config["dataset"]
        params.model = config["model"]
        params.method = config["method"]
        params.n_shot = config["n_shot"]

        self.few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot,
                                    n_query=self.params.n_query)
        params.checkpoint_dir = '%s/checkpoints/%s/%s/%s_%s_%s' % (
            configs.save_dir, params.dataset, params.method, params.model,
            descriptor_str, params.n_shot)
        params.stop_epoch = 2000
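        # stop_epoch doubles as the number of test episodes (n_eposide) below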
        self.initialize(params, False)
        image_size = self.image_size
        pretrain = PretrainedModel(self.params)

        if params.dataset == 'cross':
            test_file = configs.data_dir['CUB'] + 'novel.json'
        else:
            test_file = configs.data_dir[params.dataset] + 'novel.json'

        n_query = 15

        few_shot_params = dict(n_way=params.test_n_way,
                               n_support=params.n_shot)
        datamgr = SetDataManager(image_size,
                                 n_query=n_query,
                                 **few_shot_params,
                                 n_eposide=params.stop_epoch)
        loader = datamgr.get_data_loader(test_file, aug=False, debug=debug)

        if "params" in config:
            model_params = config["params"]
        else:
            model_params = {}
        if require_pretrain:
            model_params["pretrain"] = pretrain

        backbone = self.get_backbone()
        model = method(backbone, **few_shot_params, **model_params)
        model = model.cuda()
        model_file = os.path.join(params.checkpoint_dir, "best_model.tar")
        # Load model
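        # merging into the model's own state dict lets keys missing from the
        # checkpoint keep their freshly initialized values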
        state_dict = model.state_dict()
        saved_states = torch.load(model_file)["state"]
        state_dict.update(saved_states)
        model.load_state_dict(state_dict)
        # Freeze backbone
        model.feature = None
        model.pretrain = pretrain

        model.eval()
        acc = model.test_loop(loader, metric=metric)
        print(acc)
Example #17
    print('Testing! {} shots on {} dataset with {} epochs of {}({})'.format(
        params.n_shot, params.testset, params.save_epoch, name, params.method))
    # dataset
    print('  build dataset')
    if 'Conv' in params.model:
        image_size = 84
    else:
        image_size = 224
    split = params.split
    loadfile = os.path.join(params.data_dir, params.testset, split + '.json')
    test_few_shot_params = dict(n_way=params.test_n_way,
                                n_support=params.n_shot)
    val_datamgr = SetDataManager(image_size,
                                 n_query=params.n_query,
                                 **test_few_shot_params)
    val_loader = val_datamgr.get_data_loader(loadfile, aug=False)

    datasets = params.dataset
    datasets.remove(params.testset)

    base_loaders = [
        val_datamgr.get_data_loader(os.path.join(params.data_dir, dataset,
                                                 'base.json'),
                                    aug=False) for dataset in datasets
    ]

    print('  build feature encoder')
    # feature encoder
    checkpoint_dir = '%s/checkmodels/%s' % (params.save_dir, name)
    if params.save_epoch != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_epoch)
Example #18
def explain_relationnet():
    # print(sys.path)
    params = options.parse_args('test')
    feature_model = backbone.model_dict['ResNet10']
    params.method = 'relationnet'
    params.dataset = 'miniImagenet'  # name relationnet --testset miniImagenet
    params.name = 'relationnet'
    params.testset = 'miniImagenet'
    params.data_dir = '/home/sunjiamei/work/fewshotlearning/dataset/'
    params.save_dir = '/home/sunjiamei/work/fewshotlearning/CrossDomainFewShot-master/output'

    if 'Conv' in params.model:
        image_size = 84
    else:
        image_size = 224
    split = params.split
    n_query = 1
    loadfile = os.path.join(params.data_dir, params.testset, split + '.json')
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    data_datamgr = SetDataManager(image_size,
                                  n_query=n_query,
                                  **few_shot_params)
    data_loader = data_datamgr.get_data_loader(loadfile, aug=False)

    acc_all = []
    iter_num = 1000

    # model
    print('  build metric-based model')
    if params.method == 'protonet':
        model = ProtoNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(backbone.model_dict[params.model],
                            **few_shot_params)
    elif params.method == 'gnnnet':
        model = GnnNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        else:
            feature_model = backbone.model_dict[params.model]
        loss_type = 'LRPmse'
        model = RelationNet(feature_model,
                            loss_type=loss_type,
                            **few_shot_params)
    else:
        raise ValueError('Unknown method')

    checkpoint_dir = '%s/checkpoints/%s' % (params.save_dir, params.name)
    # print(checkpoint_dir)
    if params.save_epoch != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_epoch)
    else:
        modelfile = get_best_file(checkpoint_dir)
        # print(modelfile)
    if modelfile is not None:
        tmp = torch.load(modelfile)
        try:
            model.load_state_dict(tmp['state'])
        except RuntimeError:
            print('warning! RuntimeError when load_state_dict()!')
            model.load_state_dict(tmp['state'], strict=False)
        except KeyError:
            for k in tmp['model_state']:  # TODO: revise later
                if 'running' in k:
                    tmp['model_state'][k] = tmp['model_state'][k].squeeze()
            model.load_state_dict(tmp['model_state'], strict=False)
        except:
            raise

    model = model.cuda()
    model.eval()
    model.n_query = n_query
    # ---test the accuracy on the test set to verify the model is loaded----
    acc = 0
    count = 0
    # for i, (x, y) in enumerate(data_loader):
    #   scores = model.set_forward(x)
    #   pred = scores.data.cpu().numpy().argmax(axis=1)
    #   y = np.repeat(range(model.n_way), n_query)
    #   acc += np.sum(pred == y)
    #   count += len(y)
    #   # print(1.0*acc/count)
    # print(1.0*acc/count)
    preset = lrp_presets.SequentialPresetA()

    feature_model = copy.deepcopy(model.feature)
    lrp_wrapper.add_lrp(feature_model, preset=preset)
    relation_model = copy.deepcopy(model.relation_module)
    # print(relation_model)
    lrp_wrapper.add_lrp(relation_model, preset=preset)
    with open(
            '/home/sunjiamei/work/fewshotlearning/dataset/miniImagenet/class_to_readablelabel.json',
            'r') as f:
        class_to_readable = json.load(f)
    explanation_save_dir = os.path.join(params.save_dir, 'explanations',
                                        params.name)
    if not os.path.isdir(explanation_save_dir):
        os.makedirs(explanation_save_dir)
    for i, (x, y, p) in enumerate(data_loader):
        '''x holds the images, shape (n_way, n_support + n_query, 3, img_size, img_size);
        y holds the global labels, shape (n_way, n_support + n_query);
        p holds the image paths as a list of n_support + n_query tuples, each of length n_way'''
        if i >= 3:
            break
        label_to_readableclass, query_img_path, query_gt_class = LRPutil.get_class_label(
            p, class_to_readable, model.n_query)
        z_support, z_query = model.parse_feature(x, is_feature=False)
        z_support = z_support.contiguous()
        z_proto = z_support.view(model.n_way, model.n_support,
                                 *model.feat_dim).mean(1)
        # print(z_proto.shape)
        z_query = z_query.contiguous().view(model.n_way * model.n_query,
                                            *model.feat_dim)
        # print(z_query.shape)
        # get relations with metric function
        z_proto_ext = z_proto.unsqueeze(0).repeat(model.n_query * model.n_way,
                                                  1, 1, 1, 1)
        # print(z_proto_ext.shape)
        z_query_ext = z_query.unsqueeze(0).repeat(model.n_way, 1, 1, 1, 1)

        z_query_ext = torch.transpose(z_query_ext, 0, 1)
        # print(z_query_ext.shape)
        extend_final_feat_dim = model.feat_dim.copy()
        extend_final_feat_dim[0] *= 2
        relation_pairs = torch.cat((z_proto_ext, z_query_ext),
                                   2).view(-1, *extend_final_feat_dim)
        # print(relation_pairs.shape)
        relations = relation_model(relation_pairs)
        # print(relations)
        scores = relations.view(-1, model.n_way)
        preds = scores.data.cpu().numpy().argmax(axis=1)
        # print(preds.shape)
        relations = relations.view(-1, model.n_way)
        # print(relations)
        relations_sf = torch.softmax(relations, dim=-1)
        # print(relations_sf)
        relations_logits = torch.log(LRPutil.LOGIT_BETA * relations_sf /
                                     (1 - relations_sf))
        # print(relations_logits)
        # print(preds)
        relations_logits = relations_logits.view(-1, 1)
        relevance_relations = relation_model.compute_lrp(
            relation_pairs, target=relations_logits)
        # print(relevance_relations.shape)
        # print(model.feat_dim)
        # the second half of the channel dim holds the query features
        relevance_z_query = torch.narrow(relevance_relations, 1,
                                         model.feat_dim[0], model.feat_dim[0])
        # print(relevance_z_query.shape)
        relevance_z_query = relevance_z_query.view(
            model.n_query * model.n_way, model.n_way,
            *relevance_z_query.size()[1:])
        # print(relevance_z_query.shape)
        query_img = x.narrow(1, model.n_support,
                             model.n_query).view(model.n_way * model.n_query,
                                                 *x.size()[2:])
        # query_img_copy = query_img.view(model.n_way, model.n_query, *x.size()[2:])
        # print(query_img.shape)
        for k in range(model.n_way):
            relevance_querry_cls = torch.narrow(relevance_z_query, 1, k,
                                                1).squeeze(1)
            # print(relevance_querry_cls.shape)
            relevance_querry_img = feature_model.compute_lrp(
                query_img.cuda(), target=relevance_querry_cls)
            # print(relevance_querry_img.max(), relevance_querry_img.min())
            # print(relevance_querry_img.shape)
            for j in range(model.n_query * model.n_way):
                predict_class = label_to_readableclass[preds[j]]
                true_class = query_gt_class[int(j % model.n_way)][int(
                    j // model.n_way)]
                explain_class = label_to_readableclass[k]
                img_name = query_img_path[int(j % model.n_way)][int(
                    j // model.n_way)].split('/')[-1]
                # str.strip('.jpg') removes leading/trailing characters, not the
                # suffix; use os.path.splitext to drop the extension safely.
                img_base = os.path.splitext(img_name)[0]
                save_path = os.path.join(explanation_save_dir,
                                         'episode' + str(i), img_base)
                if not os.path.isdir(save_path):
                    os.makedirs(save_path)
                if not os.path.exists(
                        os.path.join(
                            save_path,
                            true_class + '_' + predict_class + img_name)):
                    original_img = Image.fromarray(
                        np.uint8(
                            project(query_img[j].permute(1, 2,
                                                         0).cpu().numpy())))
                    original_img.save(
                        os.path.join(
                            save_path,
                            true_class + '_' + predict_class + img_name))

                img_relevance = relevance_querry_img.narrow(0, j, 1)
                print(predict_class, true_class, explain_class)
                # assert relevance_querry_cls[j].sum() != 0
                # assert img_relevance.sum()!=0
                hm = img_relevance.permute(0, 2, 3, 1).cpu().detach().numpy()
                hm = LRPutil.gamma(hm)
                hm = LRPutil.heatmap(hm)[0]
                hm = project(hm)
                hp_img = Image.fromarray(np.uint8(hm))
                hp_img.save(
                    os.path.join(
                        save_path,
                        true_class + '_' + explain_class + '_lrp_hm.jpg'))
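Example #18 turns softmax probabilities back into signed, logit-like scores via log(LOGIT_BETA * p / (1 - p)) before asking LRP to propagate them. A small sketch of that transform; LOGIT_BETA is treated here as an assumed placeholder for the constant defined in LRPutil:

import torch

LOGIT_BETA = 4.0  # assumption: stand-in for LRPutil.LOGIT_BETA

def probs_to_lrp_targets(probs, eps=1e-6):
    # Probabilities above 1 / (1 + LOGIT_BETA) map to positive scores,
    # smaller ones to negative, giving LRP a signed starting relevance.
    probs = probs.clamp(eps, 1 - eps)
    return torch.log(LOGIT_BETA * probs / (1 - probs))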
Example #19
def get_train_val_loader(params, source_val):
    # to prevent circular import
    from data.datamgr import SimpleDataManager, SetDataManager, AugSetDataManager, VAESetDataManager

    image_size = get_img_size(params)
    base_file, val_file = get_train_val_filename(params)
    if source_val:
        source_val_file = get_source_val_filename(params)

    if params.method in ['baseline', 'baseline++']:
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        #         val_datamgr     = SimpleDataManager(image_size, batch_size = 64)
        #         val_loader      = val_datamgr.get_data_loader( val_file, aug = False)

        # fine-tune during validation
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  # if test_n_way is smaller than train_n_way, reduce n_query to keep the episode batch size small (see the sketch after this function)
        val_few_shot_params = get_few_shot_params(params, 'val')
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **val_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        if source_val:
            source_val_datamgr = SetDataManager(image_size,
                                                n_query=n_query,
                                                **val_few_shot_params)
            source_val_loader = source_val_datamgr.get_data_loader(
                source_val_file, aug=False)

    elif params.method in [
            'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax',
            'maml', 'maml_approx'
    ]:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  # if test_n_way is smaller than train_n_way, reduce n_query to keep the episode batch size small

        #         train_few_shot_params    = dict(n_way = params.train_n_way, n_support = params.n_shot)
        #         val_few_shot_params     = dict(n_way = params.test_n_way, n_support = params.n_shot)
        train_few_shot_params = get_few_shot_params(params, 'train')
        val_few_shot_params = get_few_shot_params(params, 'val')
        if params.vaegan_exp is not None:
            # TODO
            is_training = False
            vaegan = restore_vaegan(params.dataset,
                                    params.vaegan_exp,
                                    params.vaegan_step,
                                    is_training=is_training)

            base_datamgr = VAESetDataManager(
                image_size,
                n_query=n_query,
                vaegan_exp=params.vaegan_exp,
                vaegan_step=params.vaegan_step,
                vaegan_is_train=params.vaegan_is_train,
                lambda_zlogvar=params.zvar_lambda,
                fake_prob=params.fake_prob,
                **train_few_shot_params)
            # train_val or val???
            val_datamgr = SetDataManager(image_size,
                                         n_query=n_query,
                                         **val_few_shot_params)

        elif params.aug_target is None:  # Common Case
            assert params.aug_type is None

            base_datamgr = SetDataManager(image_size,
                                          n_query=n_query,
                                          **train_few_shot_params)
            val_datamgr = SetDataManager(image_size,
                                         n_query=n_query,
                                         **val_few_shot_params)
            if source_val:
                source_val_datamgr = SetDataManager(image_size,
                                                    n_query=n_query,
                                                    **val_few_shot_params)
        else:
            aug_type = params.aug_type
            assert aug_type is not None
            base_datamgr = AugSetDataManager(image_size,
                                             n_query=n_query,
                                             aug_type=aug_type,
                                             aug_target=params.aug_target,
                                             **train_few_shot_params)
            val_datamgr = AugSetDataManager(image_size,
                                            n_query=n_query,
                                            aug_type=aug_type,
                                            aug_target='test-sample',
                                            **val_few_shot_params)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        if source_val:
            source_val_loader = val_datamgr.get_data_loader(source_val_file,
                                                            aug=False)
        # a batch for SetDataManager: a [n_way, n_support + n_query, n_channel, w, h] tensor

    else:
        raise ValueError('Unknown method')

    if source_val:
        return base_loader, val_loader, source_val_loader
    else:
        return base_loader, val_loader
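Both branches above size the query set with max(1, int(16 * test_n_way / train_n_way)), which keeps the number of query images per episode roughly constant when the training way differs from the test way. A worked sketch of the heuristic:

def episode_n_query(train_n_way, test_n_way, target=16):
    # Keep train_n_way * n_query close to target * test_n_way images.
    return max(1, int(target * test_n_way / train_n_way))

# 5-way train, 5-way test:  16 queries/class -> 80 query images per episode
# 20-way train, 5-way test:  4 queries/class -> 80 query images per episode
print(episode_n_query(5, 5), episode_n_query(20, 5))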
Example #20
def explain_gnnnet():
    params = options.parse_args('test')
    feature_model = backbone.model_dict['ResNet10']
    params.method = 'gnnnet'
    params.dataset = 'miniImagenet'
    params.name = 'gnn'
    params.testset = 'miniImagenet'
    params.data_dir = '/home/sunjiamei/work/fewshotlearning/dataset/'
    params.save_dir = '/home/sunjiamei/work/fewshotlearning/CrossDomainFewShot-master/output'

    if 'Conv' in params.model:
        image_size = 84
    else:
        image_size = 224
    split = params.split
    n_query = 1
    loadfile = os.path.join(params.data_dir, params.testset, split + '.json')
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    data_datamgr = SetDataManager(image_size,
                                  n_query=n_query,
                                  **few_shot_params)
    data_loader = data_datamgr.get_data_loader(loadfile, aug=False)

    # model
    print('  build metric-based model')
    if params.method == 'protonet':
        model = ProtoNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(backbone.model_dict[params.model],
                            **few_shot_params)
    elif params.method == 'gnnnet':
        model = GnnNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        else:
            feature_model = backbone.model_dict[params.model]
        loss_type = 'LRP'
        model = RelationNet(feature_model,
                            loss_type=loss_type,
                            **few_shot_params)
    else:
        raise ValueError('Unknown method')

    checkpoint_dir = '%s/checkpoints/%s' % (params.save_dir, params.name)
    # print(checkpoint_dir)
    if params.save_epoch != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_epoch)
    else:
        modelfile = get_best_file(checkpoint_dir)
        # print(modelfile)
    if modelfile is not None:
        tmp = torch.load(modelfile)
        try:
            model.load_state_dict(tmp['state'])
            print('loaded pretrained model')
        except RuntimeError:
            print('warning! RuntimeError when load_state_dict()!')
            model.load_state_dict(tmp['state'], strict=False)
        except KeyError:
            for k in tmp['model_state']:  # TODO: revise later
                if 'running' in k:
                    tmp['model_state'][k] = tmp['model_state'][k].squeeze()
            model.load_state_dict(tmp['model_state'], strict=False)
        except:
            raise

    model = model.cuda()
    model.eval()
    model.n_query = n_query
    # for module in model.modules():
    #   print(type(module))
    lrp_preset = lrp_presets.SequentialPresetA()
    feature_model = model.feature
    fc_encoder = model.fc
    gnn_net = model.gnn
    lrp_wrapper.add_lrp(fc_encoder, lrp_preset)
    # lrp_wrapper.add_lrp(feature_model, lrp_preset)

    # acc = 0
    # count = 0
    # forward pass verified by running the commented-out accuracy check below
    # for i, (x, _, _) in enumerate(data_loader):
    #   x = x.cuda()
    #   support_label = torch.from_numpy(np.repeat(range(model.n_way), model.n_support)).unsqueeze(1)
    #   support_label = torch.zeros(model.n_way*model.n_support, model.n_way).scatter(1, support_label, 1).view(model.n_way, model.n_support, model.n_way)
    #   support_label = torch.cat([support_label, torch.zeros(model.n_way, 1, model.n_way)], dim=1)
    #   support_label = support_label.view(1, -1, model.n_way)
    #   support_label = support_label.cuda()
    #   x = x.view(-1, *x.size()[2:])
    #
    #   x_feature = feature_model(x)
    #   x_fc_encoded = fc_encoder(x_feature)
    #   z = x_fc_encoded.view(model.n_way, -1, x_fc_encoded.size(1))
    #   gnn_feature = [
    #     torch.cat([z[:, :model.n_support], z[:, model.n_support + i:model.n_support + i + 1]], dim=1).view(1, -1, z.size(2))
    #     for i in range(model.n_query)]
    #   gnn_nodes = torch.cat([torch.cat([z, support_label], dim=2) for z in gnn_feature], dim=0)
    #   scores = gnn_net(gnn_nodes)
    #   scores = scores.view(model.n_query, model.n_way, model.n_support + 1, model.n_way)[:, :, -1].permute(1, 0,
    #                                                                                                    2).contiguous().view(
    #     -1, model.n_way)
    #   pred = scores.data.cpu().numpy().argmax(axis=1)
    #   y = np.repeat(range(model.n_way), n_query)
    #   acc += np.sum(pred == y)
    #   count += len(y)
    #   # print(1.0*acc/count)
    # print(1.0*acc/count)
    with open(
            '/home/sunjiamei/work/fewshotlearning/dataset/miniImagenet/class_to_readablelabel.json',
            'r') as f:
        class_to_readable = json.load(f)
    explanation_save_dir = os.path.join(params.save_dir, 'explanations',
                                        params.name)
    if not os.path.isdir(explanation_save_dir):
        os.makedirs(explanation_save_dir)
    for batch_idx, (x, y, p) in enumerate(data_loader):
        print(p)
        label_to_readableclass, query_img_path, query_gt_class = LRPutil.get_class_label(
            p, class_to_readable, model.n_query)
        x = x.cuda()
        support_label = torch.from_numpy(
            np.repeat(range(model.n_way),
                      model.n_support)).unsqueeze(1)  #torch.Size([25, 1])
        support_label = torch.zeros(model.n_way * model.n_support,
                                    model.n_way).scatter(1, support_label,
                                                         1).view(
                                                             model.n_way,
                                                             model.n_support,
                                                             model.n_way)
        support_label = torch.cat(
            [support_label,
             torch.zeros(model.n_way, 1, model.n_way)], dim=1)
        support_label = support_label.view(1, -1, model.n_way)
        support_label = support_label.cuda()  #torch.Size([1, 30, 5])
        x = x.contiguous()
        x = x.view(-1, *x.size()[2:])  #torch.Size([30, 3, 224, 224])
        x_feature = feature_model(x)  #torch.Size([30, 512])
        x_fc_encoded = fc_encoder(x_feature)  #torch.Size([30, 128])
        z = x_fc_encoded.view(model.n_way, -1,
                              x_fc_encoded.size(1))  # (5,6,128)
        gnn_feature = [
            torch.cat([
                z[:, :model.n_support],
                z[:, model.n_support + i:model.n_support + i + 1]
            ],
                      dim=1).view(1, -1, z.size(2))
            for i in range(model.n_query)
        ]  # model.n_query is the number of query images per class
        # gnn_feature is grouped into n_query groups; each group contains the
        # support image features concatenated with one query image's features.
        # print(len(gnn_feature), gnn_feature[0].shape)
        gnn_nodes = torch.cat(
            [torch.cat([z, support_label], dim=2) for z in gnn_feature], dim=0
        )  # features concatenated with the one-hot label; for the unknown (query) image the one-hot label is all zeros

        #  perform gnn_net step by step
        #  the first iteration
        print('x', gnn_nodes.shape)
        W_init = torch.eye(
            gnn_nodes.size(1), device=gnn_nodes.device
        ).unsqueeze(0).repeat(gnn_nodes.size(0), 1, 1).unsqueeze(
            3
        )  # (n_query, n_way*(num_support + 1), n_way*(num_support + 1), 1)
        # print(W_init.shape)

        W1 = gnn_net._modules['layer_w{}'.format(0)](
            gnn_nodes, W_init
        )  # (n_query, n_way*(num_support + 1), n_way*(num_support + 1), 2)
        # print(W1.shape)
        x_new1 = F.leaky_relu(gnn_net._modules['layer_l{}'.format(0)](
            [W1, gnn_nodes])[1])  # (num_query, num_support + 1, num_outputs)
        # print(x_new1.shape)  #torch.Size([1, 30, 48])
        gnn_nodes_1 = torch.cat([gnn_nodes, x_new1],
                                2)  # (concat more features)
        # print('gn1',gnn_nodes_1.shape) #torch.Size([1, 30, 181])

        #  the second iteration
        W2 = gnn_net._modules['layer_w{}'.format(1)](
            gnn_nodes_1, W_init
        )  # (n_query, n_way*(num_support + 1), n_way*(num_support + 1), 2)
        x_new2 = F.leaky_relu(gnn_net._modules['layer_l{}'.format(1)](
            [W2,
             gnn_nodes_1])[1])  # (num_query, num_support + 1, num_outputs)
        # print(x_new2.shape)
        gnn_nodes_2 = torch.cat([gnn_nodes_1, x_new2],
                                2)  # (concat more features)
        # print('gn2', gnn_nodes_2.shape)  #torch.Size([1, 30, 229])

        Wl = gnn_net.w_comp_last(gnn_nodes_2, W_init)
        # print(Wl.shape)  #torch.Size([1, 30, 30, 2])
        scores = gnn_net.layer_last(
            [Wl, gnn_nodes_2])[1]  # (num_query, num_support + 1, num_way)
        print(scores.shape)

        scores_sf = torch.softmax(scores, dim=-1)
        # print(scores_sf)

        gnn_logits = torch.log(LRPutil.LOGIT_BETA * scores_sf /
                               (1 - scores_sf))
        gnn_logits = gnn_logits.view(-1, model.n_way,
                                     model.n_support + n_query, model.n_way)
        # print(gnn_logits)
        query_scores = scores.view(
            model.n_query, model.n_way, model.n_support + 1,
            model.n_way)[:, :,
                         -1].permute(1, 0,
                                     2).contiguous().view(-1, model.n_way)
        preds = query_scores.data.cpu().numpy().argmax(axis=-1)
        # print(preds.shape)
        for k in range(model.n_way):
            mask = torch.zeros(model.n_way).cuda()  # keep only class k's logit
            mask[k] = 1
            gnn_logits_cls = gnn_logits.clone()
            gnn_logits_cls[:, :, -1] = gnn_logits_cls[:, :, -1] * mask
            # print(gnn_logits_cls)
            # print(gnn_logits_cls.shape)
            gnn_logits_cls = gnn_logits_cls.view(-1, model.n_way)
            relevance_gnn_nodes_2 = explain_Gconv(gnn_logits_cls,
                                                  gnn_net.layer_last, Wl,
                                                  gnn_nodes_2)
            # x_new2 occupies dims 181:229 of gnn_nodes_2 (128 fc + 5 label + 48 x_new1 before it)
            relevance_x_new2 = relevance_gnn_nodes_2.narrow(-1, 181, 48)
            # relevance_gnn_nodes = relevance_gnn_nodes_2
            relevance_gnn_nodes_1 = explain_Gconv(
                relevance_x_new2, gnn_net._modules['layer_l{}'.format(1)], W2,
                gnn_nodes_1)
            # x_new1 occupies dims 133:181 of gnn_nodes_1
            relevance_x_new1 = relevance_gnn_nodes_1.narrow(-1, 133, 48)
            relevance_gnn_nodes = explain_Gconv(
                relevance_x_new1, gnn_net._modules['layer_l{}'.format(0)], W1,
                gnn_nodes)
            # the first 128 dims are the fc-encoded image features
            relevance_gnn_features = relevance_gnn_nodes.narrow(-1, 0, 128)
            print(relevance_gnn_features.shape)
            relevance_gnn_features += relevance_gnn_nodes_1.narrow(-1, 0, 128)
            relevance_gnn_features += relevance_gnn_nodes_2.narrow(
                -1, 0, 128)  #[2, 30, 128]
            relevance_gnn_features = relevance_gnn_features.view(
                n_query, model.n_way, model.n_support + 1, 128)
            for i in range(n_query):
                query_i = relevance_gnn_features[i][:,
                    model.n_support:model.n_support + 1]
                if i == 0:
                    relevance_z = query_i
                else:
                    relevance_z = torch.cat((relevance_z, query_i), 1)
            relevance_z = relevance_z.view(-1, 128)
            query_feature = x_feature.view(model.n_way, -1,
                                           512)[:, model.n_support:]
            # print(query_feature.shape)
            query_feature = query_feature.contiguous()
            query_feature = query_feature.view(n_query * model.n_way, 512)
            # print(query_feature.shape)
            relevance_query_features = fc_encoder.compute_lrp(
                query_feature, target=relevance_z)
            # print(relevance_query_features.shape)
            # print(relevance_gnn_features.shape)
            # explain the fc layer and the image encoder
            query_images = x.view(model.n_way, -1,
                                  *x.size()[1:])[:, model.n_support:]
            query_images = query_images.contiguous()

            query_images = query_images.view(-1, *x.size()[1:]).detach()
            # print(query_images.shape)
            lrp_wrapper.add_lrp(feature_model, lrp_preset)
            relevance_query_images = feature_model.compute_lrp(
                query_images, target=relevance_query_features)
            print(relevance_query_images.shape)

            for j in range(n_query * model.n_way):
                predict_class = label_to_readableclass[preds[j]]
                true_class = query_gt_class[int(j % model.n_way)][int(
                    j // model.n_way)]
                explain_class = label_to_readableclass[k]
                img_name = query_img_path[int(j % model.n_way)][int(
                    j // model.n_way)].split('/')[-1]
                # str.strip('.jpg') removes leading/trailing characters, not the
                # suffix; use os.path.splitext to drop the extension safely.
                img_base = os.path.splitext(img_name)[0]
                save_path = os.path.join(explanation_save_dir,
                                         'episode' + str(batch_idx), img_base)
                if not os.path.isdir(save_path):
                    os.makedirs(save_path)
                if not os.path.exists(
                        os.path.join(
                            save_path,
                            true_class + '_' + predict_class + img_name)):
                    original_img = Image.fromarray(
                        np.uint8(
                            project(query_images[j].permute(
                                1, 2, 0).detach().cpu().numpy())))
                    original_img.save(
                        os.path.join(
                            save_path,
                            true_class + '_' + predict_class + img_name))

                img_relevance = relevance_query_images.narrow(0, j, 1)
                print(predict_class, true_class, explain_class)
                # assert relevance_querry_cls[j].sum() != 0
                # assert img_relevance.sum()!=0
                hm = img_relevance.permute(0, 2, 3, 1).cpu().detach().numpy()
                hm = LRPutil.gamma(hm)
                hm = LRPutil.heatmap(hm)[0]
                hm = project(hm)
                hp_img = Image.fromarray(np.uint8(hm))
                hp_img.save(
                    os.path.join(
                        save_path,
                        true_class + '_' + explain_class + '_lrp_hm.jpg'))

        break
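The support-label tensor assembled above one-hot encodes every support image and appends an all-zero row for the query slot, whose label the GNN must infer. A minimal standalone sketch of the same construction:

import torch

def build_support_labels(n_way, n_support):
    # One-hot labels for support images: (n_way * n_support, n_way).
    idx = torch.arange(n_way).repeat_interleave(n_support).unsqueeze(1)
    one_hot = torch.zeros(n_way * n_support, n_way).scatter(1, idx, 1)
    one_hot = one_hot.view(n_way, n_support, n_way)
    # All-zero label for the unknown query position in each class row.
    one_hot = torch.cat([one_hot, torch.zeros(n_way, 1, n_way)], dim=1)
    return one_hot.view(1, -1, n_way)

print(build_support_labels(5, 5).shape)  # torch.Size([1, 30, 5])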
Example #21
    ITER_NUM = 600
    N_QUERY = 15

    test_datamgr = SetDataManager(
        "CUB",
        84,
        n_query=N_QUERY,
        n_way=args.test_n_way,
        n_support=args.n_shot,
        n_episode=ITER_NUM,
        args=args,
    )
    test_loader = test_datamgr.get_data_loader(
        os.path.join(constants.DATA_DIR, f"{args.split}.json"),
        aug=False,
        lang_dir=constants.LANG_DIR,
        normalize=False,
        vocab=vocab,
    )
    normalizer = TransformLoader(84).get_normalize()

    model.eval()

    acc_all = model.test_loop(
        test_loader,
        normalizer=normalizer,
        verbose=True,
        return_all=True,
        # Debug on first loop only
        debug=args.debug,
        debug_dir=os.path.split(args.checkpoint_dir)[0],
    )
Example #22
        elif params.dataset == "CropDiseases":
            target_datamgr = CropDisease_few_shot.SetDataManager(
                image_size, n_query=n_query, **train_few_shot_params)

        else:
            raise ValueError('Unknown dataset')

        target_loader = target_datamgr.get_data_loader(novel_file,
                                                       aug=params.train_aug)

        if params.adversarial or params.adaptFinetune:
            # TODO: check argv
            target_datamgr = SetDataManager(image_size,
                                            n_query=n_query,
                                            **train_few_shot_params)
            target_loader = target_datamgr.get_data_loader(novel_file,
                                                           aug=True)
            base_loader = [base_loader, target_loader]

        if params.method == 'protonet':
            # TODO: check argv
            if params.adversarial:
                model = ProtoNet(model_dict[params.model],
                                 params.test_n_way,
                                 params.n_shot,
                                 discriminator=backbone.Disc_model(
                                     params.train_n_way),
                                 cosine=params.cosine)
            elif params.adaptFinetune:
                assert not params.adversarial
                model = ProtoNet(
                    model_dict[params.model],
Example #23
        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        base_datamgr = SetDataManager(image_size,
                                      n_query=n_query,
                                      **train_few_shot_params)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        if params.n_shot_test == -1:  # modify val loader support
            params.n_shot_test = params.n_shot
        else:  # modify target loader support
            train_few_shot_params['n_support'] = params.n_shot_test
        if params.adversarial or params.adaptFinetune:
            target_datamgr = SetDataManager(image_size,
                                            n_query=n_query,
                                            **train_few_shot_params)
            target_loader = target_datamgr.get_data_loader(novel_file,
                                                           aug=False)
            # ipdb.set_trace()
            # bl, tl = iter(base_loader), iter(target_loader)
            # bx, _ = next(bl)
            # tx, _ = next(tl)
            base_loader = [base_loader, target_loader]

        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot_test)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        # a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor

        if params.method == 'protonet':
Example #24
def main():
    timer = Timer()
    args, writer = init()

    train_file = args.dataset_dir + 'train.json'
    val_file = args.dataset_dir + 'val.json'

    few_shot_params = dict(n_way=args.n_way, n_support=args.n_shot, n_query=args.n_query)
    n_episode = 10 if args.debug else 100
    if args.method_type is Method_type.baseline:
        train_datamgr = SimpleDataManager(train_file, args.dataset_dir, args.image_size, batch_size=64)
        train_loader = train_datamgr.get_data_loader(aug = True)
    else:
        train_datamgr = SetDataManager(train_file, args.dataset_dir, args.image_size,
                                       n_episode=n_episode, mode='train', **few_shot_params)
        train_loader = train_datamgr.get_data_loader(aug=True)

    val_datamgr = SetDataManager(val_file, args.dataset_dir, args.image_size,
                                     n_episode=n_episode, mode='val', **few_shot_params)
    val_loader = val_datamgr.get_data_loader(aug=False)

    if args.model_type is Model_type.ConvNet:
        # no ConvNet encoder is wired up; fail fast rather than leaving
        # `encoder` unbound below
        raise NotImplementedError('ConvNet encoder is not implemented')
    elif args.model_type is Model_type.ResNet12:
        from methods.backbone import ResNet12
        encoder = ResNet12()
    else:
        raise ValueError('Unknown model type')

    if args.method_type is Method_type.baseline:
        from methods.baselinetrain import BaselineTrain
        model = BaselineTrain(encoder, args)
    elif args.method_type is Method_type.protonet:
        from methods.protonet import ProtoNet
        model = ProtoNet(encoder, args)
    else:
        raise ValueError('Unknown method type')

    from torch.optim import SGD, lr_scheduler
    if args.method_type is Method_type.baseline:
        optimizer = SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch, eta_min=0, last_epoch=-1)
    else:
        optimizer = torch.optim.SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4,
                                    nesterov=True)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)

    args.ngpu = torch.cuda.device_count()
    torch.backends.cudnn.benchmark = True
    model = model.cuda()

    label = torch.from_numpy(np.repeat(range(args.n_way), args.n_query))
    label = label.cuda()

    if args.test:
        test(model, label, args, few_shot_params)
        return

    if args.resume:
        resume_OK = resume_model(model, optimizer, args, scheduler)
    else:
        resume_OK = False
    if (not resume_OK) and (args.warmup is not None):
        load_pretrained_weights(model, args)

    if args.debug:
        args.max_epoch = args.start_epoch + 1

    for epoch in range(args.start_epoch, args.max_epoch):
        train_one_epoch(model, optimizer, args, train_loader, label, writer, epoch)
        scheduler.step()

        vl, va = val(model, args, val_loader, label)
        if writer is not None:
            writer.add_scalar('data/val_acc', float(va), epoch)
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))

        if va >= args.max_acc:
            args.max_acc = va
            print('saving the best model! acc={:.4f}'.format(va))
            save_model(model, optimizer, args, epoch, args.max_acc, 'max_acc', scheduler)
        save_model(model, optimizer, args, epoch, args.max_acc, 'epoch-last', scheduler)
        if epoch != 0:
            print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
    if writer is not None:
        writer.close()
    test(model, label, args, few_shot_params)
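main() builds the episode label tensor once and reuses it for every batch, which works because SetDataManager yields classes in a fixed order: query i always belongs to class i // n_query. A tiny sketch of the layout for 3-way, 2-query episodes:

import numpy as np
import torch

label = torch.from_numpy(np.repeat(range(3), 2))
print(label)  # tensor([0, 0, 1, 1, 2, 2])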
Example #25
                                 **few_shot_params)

        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
            else:
                loadfile = configs.data_dir['emnist'] + split + '.json'
        else:
            loadfile = configs.data_dir[params.dataset] + split + '.json'

        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if params.adaptation:
            model.task_update_num = 100  # We perform adaptation on MAML simply by updating more times.
        model.eval()
        acc_mean, acc_std = model.test_loop(novel_loader, return_std=True)

    else:
        novel_file = os.path.join(
            checkpoint_dir.replace("checkpoints",
                                   "features"), split_str + ".hdf5"
        )  # default split = novel, but you can also test base or val classes
        cl_data_file = feat_loader.init_loader(novel_file)

        for i in range(iter_num):
            acc = feature_evaluation(cl_data_file,
                                     model,
Example #26
def get_logits_targets(params):
    acc_all = []
    iter_num = 600
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only supports Conv4 without augmentation'
        params.model = 'Conv4S'

    if params.method == 'baseline':
        model           = BaselineFinetune( model_dict[params.model], **few_shot_params )
    elif params.method == 'baseline++':
        model           = BaselineFinetune( model_dict[params.model], loss_type = 'dist', **few_shot_params )
    elif params.method == 'protonet':
        model           = ProtoNet( model_dict[params.model], **few_shot_params )
    elif params.method == 'DKT':
        model           = DKT(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model           = MatchingNet( model_dict[params.model], **few_shot_params )
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4': 
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6': 
            feature_model = backbone.Conv6NP
        elif params.model == 'Conv4S': 
            feature_model = backbone.Conv4SNP
        else:
            feature_model = lambda: model_dict[params.model]( flatten = False )
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model           = RelationNet( feature_model, loss_type = loss_type , **few_shot_params )
    elif params.method in ['maml' , 'maml_approx']:
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(  model_dict[params.model], approx = (params.method == 'maml_approx') , **few_shot_params )
        if params.dataset in ['omniglot', 'cross_char']: # MAML uses different hyperparameters on omniglot
            model.n_task     = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    checkpoint_dir = '%s/checkpoints/%s/%s_%s' %(configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    if not params.method in ['baseline', 'baseline++'] :
        checkpoint_dir += '_%dway_%dshot' %( params.train_n_way, params.n_shot)

    #modelfile   = get_resume_file(checkpoint_dir)

    if not params.method in ['baseline', 'baseline++'] : 
        if params.save_iter != -1:
            modelfile   = get_assigned_file(checkpoint_dir,params.save_iter)
        else:
            modelfile   = get_best_file(checkpoint_dir)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])
        else:
            print("[WARNING] Cannot find 'best_file.tar' in: " + str(checkpoint_dir))

    split = params.split
    if params.save_iter != -1:
        split_str = split + "_" +str(params.save_iter)
    else:
        split_str = split
    if params.method in ['maml', 'maml_approx', 'DKT']: # MAML does not support testing from saved features
        if 'Conv' in params.model:
            if params.dataset in ['omniglot', 'cross_char']:
                image_size = 28
            else:
                image_size = 84 
        else:
            image_size = 224

        datamgr         = SetDataManager(image_size, n_eposide = iter_num, n_query = 15 , **few_shot_params)
        
        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json' 
            else:
                loadfile   = configs.data_dir['CUB'] + split +'.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json' 
            else:
                loadfile  = configs.data_dir['emnist'] + split +'.json' 
        else: 
            loadfile    = configs.data_dir[params.dataset] + split + '.json'

        novel_loader     = datamgr.get_data_loader( loadfile, aug = False)
        if params.adaptation:
            model.task_update_num = 100 #We perform adaptation on MAML simply by updating more times.
        model.eval()

        logits_list = list()
        targets_list = list()    
        for i, (x,_) in enumerate(novel_loader):
            logits = model.get_logits(x).detach()
            targets = torch.tensor(np.repeat(range(params.test_n_way), model.n_query)).cuda()
        logits_list.append(logits)
        targets_list.append(targets)
    else:
        novel_file = os.path.join( checkpoint_dir.replace("checkpoints","features"), split_str +".hdf5")
        cl_data_file = feat_loader.init_loader(novel_file)
        logits_list = list()
        targets_list = list()
        n_query = 15
        n_way = few_shot_params['n_way']
        n_support = few_shot_params['n_support']
    class_list = list(cl_data_file.keys())  # materialize the keys view so random.sample can index it
        for i in range(iter_num):
            #----------------------
            select_class = random.sample(class_list,n_way)
            z_all  = []
            for cl in select_class:
                img_feat = cl_data_file[cl]
                perm_ids = np.random.permutation(len(img_feat)).tolist()
            z_all.append([np.squeeze(img_feat[perm_ids[i]]) for i in range(n_support + n_query)])  # stack each class's features
            z_all = torch.from_numpy(np.array(z_all))
            model.n_query = n_query
            logits  = model.set_forward(z_all, is_feature = True).detach()
            targets = torch.tensor(np.repeat(range(n_way), n_query)).cuda()
            logits_list.append(logits)
            targets_list.append(targets)
            #----------------------
    return torch.cat(logits_list, 0), torch.cat(targets_list, 0)
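The (logits, targets) pair returned by get_logits_targets can be scored directly, e.g. for calibration studies. A hedged sketch of computing top-1 accuracy from it (the helper name is illustrative, not part of the original code):

import torch

def top1_accuracy(logits, targets):
    # logits: (N, n_way); targets: (N,). Fraction of correct argmax picks.
    return (logits.argmax(dim=1) == targets).float().mean().item()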
Example #27
def single_test(params, results_logger):
    acc_all = []

    iter_num = 600

    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only supports Conv4 without augmentation'
        params.model = 'Conv4S'

    if params.method == 'baseline':
        model = BaselineFinetune(model_dict[params.model], **few_shot_params)
    elif params.method == 'baseline++':
        model = BaselineFinetune(model_dict[params.model],
                                 loss_type='dist',
                                 **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'DKT':
        model = DKT(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        elif params.model == 'Conv4S':
            feature_model = backbone.Conv4SNP
        else:
            feature_model = lambda: model_dict[params.model](flatten=False)
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model,
                            loss_type=loss_type,
                            **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(model_dict[params.model],
                     approx=(params.method == 'maml_approx'),
                     **few_shot_params)
        if params.dataset in ['omniglot', 'cross_char'
                              ]:  # MAML uses different hyperparameters on omniglot
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    if not params.method in ['baseline', 'baseline++']:
        checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)

    # modelfile   = get_resume_file(checkpoint_dir)

    if not params.method in ['baseline', 'baseline++']:
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        else:
            modelfile = get_best_file(checkpoint_dir)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])
        else:
            print("[WARNING] Cannot find 'best_file.tar' in: " +
                  str(checkpoint_dir))

    split = params.split
    if params.save_iter != -1:
        split_str = split + "_" + str(params.save_iter)
    else:
        split_str = split
    if params.method in ['maml', 'maml_approx',
                         'DKT']:  # MAML does not support testing from saved features
        if 'Conv' in params.model:
            if params.dataset in ['omniglot', 'cross_char']:
                image_size = 28
            else:
                image_size = 84
        else:
            image_size = 224

        datamgr = SetDataManager(image_size,
                                 n_eposide=iter_num,
                                 n_query=15,
                                 **few_shot_params)

        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
            else:
                loadfile = configs.data_dir['emnist'] + split + '.json'
        else:
            loadfile = configs.data_dir[params.dataset] + split + '.json'

        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if params.adaptation:
            model.task_update_num = 100  # We perform adaptation on MAML simply by updating more times.
        model.eval()
        acc_mean, acc_std = model.test_loop(novel_loader, return_std=True)

    else:
        novel_file = os.path.join(
            checkpoint_dir.replace("checkpoints",
                                   "features"), split_str + ".hdf5"
        )  # default split = novel, but you can also test base or val classes
        cl_data_file = feat_loader.init_loader(novel_file)

        for i in range(iter_num):
            acc = feature_evaluation(cl_data_file,
                                     model,
                                     n_query=15,
                                     adaptation=params.adaptation,
                                     **few_shot_params)
            acc_all.append(acc)

        acc_all = np.asarray(acc_all)
        acc_mean = np.mean(acc_all)
        acc_std = np.std(acc_all)
        print('%d Test Acc = %4.2f%% +- %4.2f%%' %
              (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
    with open('record/results.txt', 'a') as f:
        timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
        aug_str = '-aug' if params.train_aug else ''
        aug_str += '-adapted' if params.adaptation else ''
        if params.method in ['baseline', 'baseline++']:
            exp_setting = '%s-%s-%s-%s%s %sshot %sway_test' % (
                params.dataset, split_str, params.model, params.method,
                aug_str, params.n_shot, params.test_n_way)
        else:
            exp_setting = '%s-%s-%s-%s%s %sshot %sway_train %sway_test' % (
                params.dataset, split_str, params.model, params.method,
                aug_str, params.n_shot, params.train_n_way, params.test_n_way)
        acc_str = '%d Test Acc = %4.2f%% +- %4.2f%%' % (
            iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num))
        f.write('Time: %s, Setting: %s, Acc: %s \n' %
                (timestamp, exp_setting, acc_str))
        results_logger.log("single_test_acc", acc_mean)
        results_logger.log("single_test_acc_std",
                           1.96 * acc_std / np.sqrt(iter_num))
        results_logger.log("time", timestamp)
        results_logger.log("exp_setting", exp_setting)
        results_logger.log("acc_str", acc_str)
    return acc_mean
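Both test paths report a 95% confidence interval as 1.96 * std / sqrt(iter_num), i.e. a normal approximation over per-episode accuracies. A minimal numpy sketch of the same computation:

import numpy as np

def mean_and_ci95(acc_all):
    # acc_all: per-episode accuracies (in percent); returns the mean and
    # the half-width of a 95% normal-approximation confidence interval.
    acc_all = np.asarray(acc_all)
    return acc_all.mean(), 1.96 * acc_all.std() / np.sqrt(len(acc_all))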
Example #28
            params.stop_epoch = 600
        else:
            params.stop_epoch = 400

    # n_query = max(1, int(16 * params.test_n_way / params.train_n_way))  # if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
    n_query = 16

    train_few_shot_params    = dict(n_way=params.train_n_way, n_query = n_query, max_n_way=params.train_max_way, \
            min_n_way=params.train_min_way, max_shot=params.max_shot, min_shot=params.min_shot, fixed_way=params.fixed_way)

    base_datamgr = SetDataManager(image_size,
                                  n_support=params.train_n_shot,
                                  n_eposide=100,
                                  **train_few_shot_params)
    base_loader = base_datamgr.get_data_loader(base_file,
                                               [base_file_unk, base_file_sil],
                                               aug=params.train_aug)

    val_few_shot_params     = dict(n_way=-1, n_query = n_query, max_n_way=params.test_max_way, min_n_way=params.test_min_way, \
            max_shot=params.max_shot, min_shot=params.min_shot, fixed_way=params.fixed_way, n_eposide=1000)
    val_datamgr = SetDataManager(image_size,
                                 n_support=-1,
                                 **val_few_shot_params)
    val_loader = val_datamgr.get_data_loader(val_file,
                                             [val_file_unk, val_file_sil],
                                             aug=False)

    if params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **train_few_shot_params)
    else:
        raise ValueError('Unknown method')
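Example #28 passes n_way=-1 and n_support=-1 so the data manager can draw a different way and shot for every episode within the configured bounds. A hedged sketch of what such a per-episode draw might look like, assuming uniform sampling between the min and max settings (the actual sampling lives inside this variant's SetDataManager):

import random

def sample_episode_setting(min_n_way, max_n_way, min_shot, max_shot,
                           fixed_way=-1):
    # A non-negative fixed_way pins the way; otherwise draw it per episode.
    n_way = fixed_way if fixed_way >= 0 else random.randint(min_n_way, max_n_way)
    n_shot = random.randint(min_shot, max_shot)
    return n_way, n_shot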