def single_test(params, results_logger):
    acc_all = []

    iter_num = 600

    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only supports Conv4 without augmentation'
        params.model = 'Conv4S'

    if params.method == 'baseline':
        model = BaselineFinetune(model_dict[params.model], **few_shot_params)
    elif params.method == 'baseline++':
        model = BaselineFinetune(model_dict[params.model],
                                 loss_type='dist',
                                 **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'DKT':
        model = DKT(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        elif params.model == 'Conv4S':
            feature_model = backbone.Conv4SNP
        else:
            feature_model = lambda: model_dict[params.model](flatten=False)
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model,
                            loss_type=loss_type,
                            **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(model_dict[params.model],
                     approx=(params.method == 'maml_approx'),
                     **few_shot_params)
        if params.dataset in ['omniglot', 'cross_char']:  # MAML uses different hyperparameters on omniglot
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    if params.method not in ['baseline', 'baseline++']:
        checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)

    # modelfile   = get_resume_file(checkpoint_dir)

    if params.method not in ['baseline', 'baseline++']:
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        else:
            modelfile = get_best_file(checkpoint_dir)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])
        else:
            print("[WARNING] Cannot find 'best_file.tar' in: " +
                  str(checkpoint_dir))

    split = params.split
    if params.save_iter != -1:
        split_str = split + "_" + str(params.save_iter)
    else:
        split_str = split
    if params.method in ['maml', 'maml_approx', 'DKT']:  # these methods do not support testing from saved features
        if 'Conv' in params.model:
            if params.dataset in ['omniglot', 'cross_char']:
                image_size = 28
            else:
                image_size = 84
        else:
            image_size = 224

        datamgr = SetDataManager(image_size,
                                 n_eposide=iter_num,
                                 n_query=15,
                                 **few_shot_params)

        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
            else:
                loadfile = configs.data_dir['emnist'] + split + '.json'
        else:
            loadfile = configs.data_dir[params.dataset] + split + '.json'

        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if params.adaptation:
            model.task_update_num = 100  # We perform adaptation on MAML simply by updating more times.
        model.eval()
        acc_mean, acc_std = model.test_loop(novel_loader, return_std=True)

    else:
        novel_file = os.path.join(
            checkpoint_dir.replace("checkpoints", "features"),
            split_str + ".hdf5"
        )  # default split = novel, but base or val classes can also be tested
        cl_data_file = feat_loader.init_loader(novel_file)

        for i in range(iter_num):
            acc = feature_evaluation(cl_data_file,
                                     model,
                                     n_query=15,
                                     adaptation=params.adaptation,
                                     **few_shot_params)
            acc_all.append(acc)

        acc_all = np.asarray(acc_all)
        acc_mean = np.mean(acc_all)
        acc_std = np.std(acc_all)
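        # 1.96 * std / sqrt(n) below is the half-width of a 95% confidence interval over episodes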
        print('%d Test Acc = %4.2f%% +- %4.2f%%' %
              (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
    with open('record/results.txt', 'a') as f:
        timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
        aug_str = '-aug' if params.train_aug else ''
        aug_str += '-adapted' if params.adaptation else ''
        if params.method in ['baseline', 'baseline++']:
            exp_setting = '%s-%s-%s-%s%s %sshot %sway_test' % (
                params.dataset, split_str, params.model, params.method,
                aug_str, params.n_shot, params.test_n_way)
        else:
            exp_setting = '%s-%s-%s-%s%s %sshot %sway_train %sway_test' % (
                params.dataset, split_str, params.model, params.method,
                aug_str, params.n_shot, params.train_n_way, params.test_n_way)
        acc_str = '%d Test Acc = %4.2f%% +- %4.2f%%' % (
            iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num))
        f.write('Time: %s, Setting: %s, Acc: %s \n' %
                (timestamp, exp_setting, acc_str))
        results_logger.log("single_test_acc", acc_mean)
        results_logger.log("single_test_acc_std",
                           1.96 * acc_std / np.sqrt(iter_num))
        results_logger.log("time", timestamp)
        results_logger.log("exp_setting", exp_setting)
        results_logger.log("acc_str", acc_str)
    return acc_mean
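
A minimal driver for single_test might look like the following sketch (parse_args('test') appears in the __main__ blocks below; ResultsLogger is a hypothetical stand-in for whatever object supplies the .log(key, value) method used above):

if __name__ == '__main__':
    params = parse_args('test')             # test-time argparse namespace
    results_logger = ResultsLogger(params)  # hypothetical; needs a .log(key, value) method
    acc_mean = single_test(params, results_logger)
    print('mean accuracy: %4.2f%%' % acc_mean)
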
Example #2
def train_baseline(base_loader, base_loader_test, val_loader, model,
                   start_epoch, stop_epoch, params, tmp):
    if params.dct_status:
        channels = params.channels
    else:
        channels = 3

    val_acc_best = 0.0

    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    if path.exists(params.checkpoint_dir + '/val_' + params.dataset + '.pt'):
        loader = torch.load(params.checkpoint_dir + '/val_' + params.dataset +
                            '.pt')
    else:
        loader = []
        for x, _ in val_loader:
            loader.append(x)
        torch.save(loader,
                   params.checkpoint_dir + '/val_' + params.dataset + '.pt')

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.Adam(model.parameters())
    print("stop_epoch", start_epoch, stop_epoch)
    for epoch in range(start_epoch, stop_epoch):
        print('\nEpoch: %d' % epoch)
        model.train()
        train_loss = 0
        reg_loss = 0
        correct = 0
        correct1 = 0.0
        total = 0

        for batch_idx, (input_var, target_var) in enumerate(base_loader):
            if use_gpu:
                input_var, target_var = input_var.cuda(), target_var.cuda()
            input_dct_var, target_var = Variable(input_var), Variable(
                target_var)
            f, outputs = model.forward(input_dct_var)
            loss = criterion(outputs, target_var)
            train_loss += loss.data.item()
            _, predicted = torch.max(outputs.data, 1)
            total += target_var.size(0)
            correct += predicted.eq(target_var.data).cpu().sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 50 == 0:
                print(
                    '{0}/{1}'.format(batch_idx, len(base_loader)),
                    'Loss: %.3f | Acc: %.3f%%  ' %
                    (train_loss / (batch_idx + 1), 100. * correct / total))

        if not os.path.isdir(params.checkpoint_dir):
            os.makedirs(params.checkpoint_dir)

        if (epoch % params.save_freq == 0) or (epoch == stop_epoch - 1):
            outfile = os.path.join(params.checkpoint_dir,
                                   '{:d}.tar'.format(epoch))
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)

        model.eval()
        with torch.no_grad():
            test_loss = 0
            correct = 0
            total = 0
            for batch_idx, (inputs, targets) in enumerate(base_loader_test):
                if use_gpu:
                    inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets = Variable(inputs), Variable(targets)
                f, outputs = model.forward(inputs)
                loss = criterion(outputs, targets)
                test_loss += loss.data.item()
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += predicted.eq(targets.data).cpu().sum()

            print('Loss: %.3f | Acc: %.3f%%' %
                  (test_loss / (batch_idx + 1), 100. * correct / total))
        torch.cuda.empty_cache()

        valmodel = BaselineFinetune(model_dict[params.model],
                                    params.train_n_way,
                                    params.n_shot,
                                    loss_type='dist')
        valmodel.n_query = 15
        acc_all1, acc_all2, acc_all3 = [], [], []
        for i, x in enumerate(loader):
            if params.dct_status:
                x = x.view(-1, channels, image_size_dct, image_size_dct)
            else:
                x = x.view(-1, channels, image_size, image_size)

            if use_gpu:
                x = x.cuda()

            with torch.no_grad():
                f, scores = model(x)
            f = f.view(params.train_n_way, params.n_shot + valmodel.n_query,
                       -1)
            scores = valmodel.set_forward_adaptation(f.cpu())
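            # set_forward_adaptation here is assumed to return one score tensor per
            # finetuning checkpoint (100/200/300 updates; see the prints below)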
            acc = []
            for each_score in scores:
                pred = each_score.data.cpu().numpy().argmax(axis=1)
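                # ground-truth labels assume 5-way episodes with 15 query images per class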
                y = np.repeat(range(5), 15)
                acc.append(np.mean(pred == y) * 100)
            acc_all1.append(acc[0])
            acc_all2.append(acc[1])
            acc_all3.append(acc[2])

        print('Test Acc at 100= %4.2f%%' % (np.mean(acc_all1)))
        print('Test Acc at 200= %4.2f%%' % (np.mean(acc_all2)))
        print('Test Acc at 300= %4.2f%%' % (np.mean(acc_all3)))

        if np.mean(acc_all3) > val_acc_best:
            val_acc_best = np.mean(acc_all3)
            bestfile = os.path.join(params.checkpoint_dir, 'best.tar')
            torch.save({'epoch': epoch, 'state': model.state_dict()}, bestfile)

    return model
Example #3
def train_s2m2(base_loader, base_loader_test, val_loader, model, start_epoch,
               stop_epoch, params, tmp):

    if params.dct_status:
        channels = params.channels
    else:
        channels = 3

    val_acc_best = 0.0

    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    if path.exists(params.checkpoint_dir + '/val_' + params.dataset + '.pt'):
        loader = torch.load(params.checkpoint_dir + '/val_' + params.dataset +
                            '.pt')
    else:
        loader = []
        for x, _ in val_loader:
            loader.append(x)
        torch.save(loader,
                   params.checkpoint_dir + '/val_' + params.dataset + '.pt')

    def mixup_criterion(criterion, pred, y_a, y_b, lam):
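        # convex combination of the losses against both sets of mixed labels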
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

    criterion = nn.CrossEntropyLoss()

    if params.model == 'WideResNet28_10':
        rotate_classifier = nn.Sequential(nn.Linear(640, 4))
    elif params.model == 'ResNet18':
        rotate_classifier = nn.Sequential(nn.Linear(512, 4))
    else:
        raise ValueError('no rotation head defined for model: %s' % params.model)

    rotate_classifier.cuda()

    if 'rotate' in tmp:
        print("loading rotate model")
        rotate_classifier.load_state_dict(tmp['rotate'])

    optimizer = torch.optim.Adam([{
        'params': model.parameters()
    }, {
        'params': rotate_classifier.parameters()
    }])

    print("stop_epoch", start_epoch, stop_epoch)

    for epoch in range(start_epoch, stop_epoch):
        print('\nEpoch: %d' % epoch)

        model.train()
        train_loss = 0
        rotate_loss = 0
        correct = 0
        total = 0
        torch.cuda.empty_cache()
        print("inside base_loader: ", len(base_loader))
        for batch_idx, (inputs, targets) in enumerate(base_loader):
            if use_gpu:
                inputs, targets = inputs.cuda(), targets.cuda()
            #print("shape of input: ", inputs.shape)
            lam = np.random.beta(params.alpha, params.alpha)
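            # lam ~ Beta(alpha, alpha); passed to the model to mix hidden activations (manifold mixup)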
            f, outputs, target_a, target_b = model(inputs,
                                                   targets,
                                                   mixup_hidden=True,
                                                   mixup_alpha=params.alpha,
                                                   lam=lam)
            loss = mixup_criterion(criterion, outputs, target_a, target_b, lam)
            train_loss += loss.data.item()
            optimizer.zero_grad()
            loss.backward()

            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (
                lam * predicted.eq(target_a.data).cpu().sum().float() +
                (1 - lam) * predicted.eq(target_b.data).cpu().sum().float())

            bs = inputs.size(0)
            inputs_ = []
            targets_ = []
            a_ = []
            indices = np.arange(bs)
            np.random.shuffle(indices)

            split_size = int(bs / 4)
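            # build 4 rotated copies (0/90/180/270 degrees) for a quarter of the batch;
            # transpose(2, 1).flip(1) rotates the HxW plane by 90 degrees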
            for j in indices[0:split_size]:
                x90 = inputs[j].transpose(2, 1).flip(1)
                x180 = x90.transpose(2, 1).flip(1)
                x270 = x180.transpose(2, 1).flip(1)
                inputs_ += [inputs[j], x90, x180, x270]
                targets_ += [targets[j] for _ in range(4)]
                a_ += [
                    torch.tensor(0),
                    torch.tensor(1),
                    torch.tensor(2),
                    torch.tensor(3)
                ]

            inputs = Variable(torch.stack(inputs_, 0))
            targets = Variable(torch.stack(targets_, 0))
            a_ = Variable(torch.stack(a_, 0))

            if use_gpu:
                inputs = inputs.cuda()
                targets = targets.cuda()
                a_ = a_.cuda()

            rf, outputs = model(inputs)
            rotate_outputs = rotate_classifier(rf)
            rloss = criterion(rotate_outputs, a_)
            closs = criterion(outputs, targets)
            loss = (rloss + closs) / 2.0

            rotate_loss += rloss.data.item()

            loss.backward()

            optimizer.step()

            if batch_idx % 50 == 0:
                print(
                    '{0}/{1}'.format(batch_idx, len(base_loader)),
                    'Loss: %.3f | Acc: %.3f%% | RotLoss: %.3f  ' %
                    (train_loss /
                     (batch_idx + 1), 100. * correct / total, rotate_loss /
                     (batch_idx + 1)))

        if not os.path.isdir(params.checkpoint_dir):
            os.makedirs(params.checkpoint_dir)

        if (epoch % params.save_freq == 0) or (epoch == stop_epoch - 1):
            outfile = os.path.join(params.checkpoint_dir,
                                   '{:d}.tar'.format(epoch))
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)

        model.eval()
        with torch.no_grad():
            test_loss = 0
            correct = 0
            total = 0
            for batch_idx, (inputs, targets) in enumerate(base_loader_test):
                if use_gpu:
                    inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets = Variable(inputs), Variable(targets)
                f, outputs = model.forward(inputs)
                loss = criterion(outputs, targets)
                test_loss += loss.data.item()
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += predicted.eq(targets.data).cpu().sum()

            print('Loss: %.3f | Acc: %.3f%%' %
                  (test_loss / (batch_idx + 1), 100. * correct / total))

        model_key = params.model + '_dct' if params.dct_status else params.model
        valmodel = BaselineFinetune(model_dict[model_key],
                                    params.train_n_way,
                                    params.n_shot,
                                    loss_type='dist')
        valmodel.n_query = 15
        acc_all1, acc_all2, acc_all3 = [], [], []
        for i, x in enumerate(loader):
            if params.dct_status:
                x = x.view(-1, channels, image_size_dct, image_size_dct)
            else:
                x = x.view(-1, channels, image_size, image_size)

            if use_gpu:
                x = x.cuda()

            with torch.no_grad():
                f, scores = model(x)
            f = f.view(params.train_n_way, params.n_shot + valmodel.n_query,
                       -1)
            scores = valmodel.set_forward_adaptation(f.cpu())
            acc = []
            for each_score in scores:
                pred = each_score.data.cpu().numpy().argmax(axis=1)
                y = np.repeat(range(5), 15)
                acc.append(np.mean(pred == y) * 100)
            acc_all1.append(acc[0])
            acc_all2.append(acc[1])
            acc_all3.append(acc[2])

        print('Test Acc at 100= %4.2f%%' % (np.mean(acc_all1)))
        print('Test Acc at 200= %4.2f%%' % (np.mean(acc_all2)))
        print('Test Acc at 300= %4.2f%%' % (np.mean(acc_all3)))

        if np.mean(acc_all3) > val_acc_best:
            val_acc_best = np.mean(acc_all3)
            bestfile = os.path.join(params.checkpoint_dir, 'best.tar')
            torch.save(
                {
                    'epoch': epoch,
                    'state': model.state_dict(),
                    'rotate': rotate_classifier.state_dict()
                }, bestfile)

    return model
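
The rotation targets built inside the loop above can be factored into a small helper; this is a sketch assuming CHW image tensors, mirroring the transpose/flip idiom used there:

import torch

def make_rotations(img):
    """Return the four rotations (0/90/180/270 degrees) of a CHW tensor plus labels 0..3."""
    rots, labels = [img], [torch.tensor(0)]
    x = img
    for k in (1, 2, 3):
        x = x.transpose(2, 1).flip(1)  # rotate the HxW plane by another 90 degrees
        rots.append(x)
        labels.append(torch.tensor(k))
    return rots, labels
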
Example #4
        datamgr = SetDataManager(image_size,
                                 n_eposide=iter_num,
                                 n_query=15,
                                 **few_shot_params,
                                 isAircraft=isAircraft,
                                 grey=params.grey)
        loadfile = os.path.join('filelists', params.dataset, 'novel.json')
        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if params.adaptation:
            model.task_update_num = 100  # We perform adaptation on MAML simply by updating more times.
        model.eval()
        acc_mean, acc_std = model.test_loop(novel_loader, return_std=True)
    else:

        if params.method == 'baseline':
            model = BaselineFinetune(model_dict[params.model],
                                     **few_shot_params)
        elif params.method == 'baseline++':
            model = BaselineFinetune(model_dict[params.model],
                                     loss_type='dist',
                                     **few_shot_params)

        if params.save_iter != -1:
            outfile = os.path.join(
                checkpoint_dir.replace("checkpoints", "features"),
                split + "_" + str(params.save_iter) + ".hdf5")
        else:
            outfile = os.path.join(
                checkpoint_dir.replace("checkpoints", "features"),
                split + ".hdf5")

        datamgr = SimpleDataManager(image_size,
Example #5
if __name__ == '__main__':
    params = parse_args('test')
    set_cuda(params.cuda)

    acc_all = []

    iter_num = 600

    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only supports Conv4 without augmentation'
        params.model = 'Conv4S'

    if params.method == 'baseline':
        model = BaselineFinetune(model_dict[params.model], **few_shot_params)
    elif params.method == 'baseline++':
        model = BaselineFinetune(model_dict[params.model],
                                 loss_type='dist',
                                 **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        elif params.model == 'Conv4S':
            feature_model = backbone.Conv4SNP
Example #6
def get_model(params, mode):
    '''
    Args:
        params: argparse params
        mode: (str), 'train', 'test'
    '''
    print('get_model() start...')
    few_shot_params = get_few_shot_params(params, mode)

    if 'omniglot' in params.dataset or 'cross_char' in params.dataset:
        assert 'Conv4' in params.model and not params.train_aug, 'omniglot/cross_char only supports Conv4 without augmentation'
        params.model = params.model.replace('Conv4', 'Conv4S')  # so Conv4Drop also becomes Conv4SDrop
        if params.recons_decoder is not None:
            if 'ConvS' not in params.recons_decoder:
                raise ValueError('omniglot / cross_char should use a ConvS/HiddenConvS decoder.')

    if params.method in ['baseline', 'baseline++'] and mode == 'train':
        # num_classes must cover the number of base classes (max label id in the base split)
        assert params.num_classes >= n_base_classes[params.dataset]

    if params.recons_decoder is None:
        print('params.recons_decoder == None')
        recons_decoder = None
    else:
        recons_decoder = decoder_dict[params.recons_decoder]
        print('recons_decoder:\n', recons_decoder)

    backbone_func = get_backbone_func(params)
    
    if 'baseline' in params.method:
        loss_types = {
            'baseline': 'softmax',
            'baseline++': 'dist',
        }
        loss_type = loss_types[params.method]

        if recons_decoder is None and params.min_gram is None:  # default baseline/baseline++
            if mode == 'train':
                model = BaselineTrain(
                    model_func=backbone_func, loss_type=loss_type,
                    num_class=params.num_classes, **few_shot_params)
            elif mode == 'test':
                model = BaselineFinetune(
                    model_func=backbone_func, loss_type=loss_type,
                    **few_shot_params, finetune_dropout_p=params.finetune_dropout_p)
        else:  # other baseline settings
            if params.min_gram is not None:
                min_gram_params = {
                    'min_gram': params.min_gram,
                    'lambda_gram': params.lambda_gram,
                }
                if mode == 'train':
                    model = BaselineTrainMinGram(
                        model_func=backbone_func, loss_type=loss_type,
                        num_class=params.num_classes, **few_shot_params,
                        **min_gram_params)
                elif mode == 'test':
                    model = BaselineFinetune(
                        model_func=backbone_func, loss_type=loss_type,
                        **few_shot_params, finetune_dropout_p=params.finetune_dropout_p)
    elif params.method == 'protonet':
        if recons_decoder is None and params.min_gram is None:  # default ProtoNet
            model = ProtoNet(backbone_func, **few_shot_params)
        else:  # other ProtoNet settings
            if params.min_gram is not None:
                min_gram_params = {
                    'min_gram': params.min_gram,
                    'lambda_gram': params.lambda_gram,
                }
                model = ProtoNetMinGram(backbone_func, **few_shot_params, **min_gram_params)

            if params.recons_decoder is not None:
                if 'Hidden' in params.recons_decoder:
                    if params.recons_decoder == 'HiddenConv':
                        model = ProtoNetAE2(backbone_func, **few_shot_params, recons_func=recons_decoder,
                                            lambda_d=params.recons_lambda, extract_layer=2)
                    elif params.recons_decoder == 'HiddenConvS':
                        model = ProtoNetAE2(backbone_func, **few_shot_params, recons_func=recons_decoder,
                                            lambda_d=params.recons_lambda, extract_layer=2, is_color=False)
                    elif params.recons_decoder == 'HiddenRes10':
                        model = ProtoNetAE2(backbone_func, **few_shot_params, recons_func=recons_decoder,
                                            lambda_d=params.recons_lambda, extract_layer=6)
                    elif params.recons_decoder == 'HiddenRes18':
                        model = ProtoNetAE2(backbone_func, **few_shot_params, recons_func=recons_decoder,
                                            lambda_d=params.recons_lambda, extract_layer=8)
                else:
                    is_color = 'ConvS' not in params.recons_decoder
                    model = ProtoNetAE(backbone_func, **few_shot_params, recons_func=recons_decoder,
                                       lambda_d=params.recons_lambda, is_color=is_color)
    elif params.method == 'matchingnet':
        model = MatchingNet(backbone_func, **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(backbone_func, loss_type=loss_type, **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(backbone_func, approx=(params.method == 'maml_approx'), **few_shot_params)
        if 'omniglot' in params.dataset or 'cross_char' in params.dataset:
            # MAML uses different hyperparameters on omniglot
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unexpected params.method: %s' % params.method)
    
    print('get_model() finished.')
    return model
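
A hypothetical call site for get_model, mirroring the test-time setup in the other snippets:

params = parse_args('test')
model = get_model(params, mode='test')
model = model.cuda()
model.eval()
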
Example #7
    return acc


if __name__ == '__main__':
    params = parse_args('test')

    acc_all = []

    if params.dataset == 'CUB':
        iter_num = 600
    else:
        iter_num = 10000

    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    model = BaselineFinetune(model_dict[params.model], **few_shot_params)

    if torch.cuda.is_available():
        model = model.cuda()
    if not params.dct_status:
        params.channels = 3

    checkpoint_dir = '%s/checkpoints/%s/%s_%s_%sway_%sshot' % (
        configs.save_dir, params.dataset, params.model, params.method,
        params.test_n_way, params.n_shot)
    if params.train_aug:
        checkpoint_dir += '_aug'

    if params.dct_status:
        checkpoint_dir += '_dct'
Example #8
def get_logits_targets(params):
    acc_all = []
    iter_num = 600
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only supports Conv4 without augmentation'
        params.model = 'Conv4S'

    if params.method == 'baseline':
        model = BaselineFinetune(model_dict[params.model], **few_shot_params)
    elif params.method == 'baseline++':
        model = BaselineFinetune(model_dict[params.model], loss_type='dist', **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'DKT':
        model = DKT(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        elif params.model == 'Conv4S':
            feature_model = backbone.Conv4SNP
        else:
            feature_model = lambda: model_dict[params.model](flatten=False)
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(model_dict[params.model], approx=(params.method == 'maml_approx'), **few_shot_params)
        if params.dataset in ['omniglot', 'cross_char']:  # MAML uses different hyperparameters on omniglot
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    if params.method not in ['baseline', 'baseline++']:
        checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)

    #modelfile   = get_resume_file(checkpoint_dir)

    if params.method not in ['baseline', 'baseline++']:
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        else:
            modelfile = get_best_file(checkpoint_dir)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])
        else:
            print("[WARNING] Cannot find 'best_file.tar' in: " + str(checkpoint_dir))

    split = params.split
    if params.save_iter != -1:
        split_str = split + "_" + str(params.save_iter)
    else:
        split_str = split
    if params.method in ['maml', 'maml_approx', 'DKT']:  # these methods do not support testing from saved features
        if 'Conv' in params.model:
            if params.dataset in ['omniglot', 'cross_char']:
                image_size = 28
            else:
                image_size = 84
        else:
            image_size = 224

        datamgr = SetDataManager(image_size, n_eposide=iter_num, n_query=15, **few_shot_params)

        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
            else:
                loadfile = configs.data_dir['emnist'] + split + '.json'
        else:
            loadfile = configs.data_dir[params.dataset] + split + '.json'

        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if params.adaptation:
            model.task_update_num = 100  # We perform adaptation on MAML simply by updating more times.
        model.eval()

        logits_list = list()
        targets_list = list()
        for i, (x, _) in enumerate(novel_loader):
            logits = model.get_logits(x).detach()
            targets = torch.tensor(np.repeat(range(params.test_n_way), model.n_query)).cuda()
            logits_list.append(logits)
            targets_list.append(targets)
    else:
        novel_file = os.path.join(checkpoint_dir.replace("checkpoints", "features"), split_str + ".hdf5")
        cl_data_file = feat_loader.init_loader(novel_file)
        logits_list = list()
        targets_list = list()
        n_query = 15
        n_way = few_shot_params['n_way']
        n_support = few_shot_params['n_support']
        class_list = list(cl_data_file.keys())  # random.sample needs a sequence, not a dict view
        for i in range(iter_num):
            select_class = random.sample(class_list, n_way)
            z_all = []
            for cl in select_class:
                img_feat = cl_data_file[cl]
                perm_ids = np.random.permutation(len(img_feat)).tolist()
                # stack n_support + n_query features per class
                z_all.append([np.squeeze(img_feat[perm_ids[j]]) for j in range(n_support + n_query)])
            z_all = torch.from_numpy(np.array(z_all))
            model.n_query = n_query
            logits = model.set_forward(z_all, is_feature=True).detach()
            targets = torch.tensor(np.repeat(range(n_way), n_query)).cuda()
            logits_list.append(logits)
            targets_list.append(targets)
    return torch.cat(logits_list, 0), torch.cat(targets_list, 0)
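
A sketch of one possible consumer of get_logits_targets, pooling every episode into a single top-1 accuracy (each logits row is an (n_way,)-dim score vector, concatenated across episodes):

logits, targets = get_logits_targets(params)
pred = logits.argmax(dim=1)
acc = (pred == targets).float().mean().item() * 100
print('Top-1 Acc = %4.2f%%' % acc)
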
Example #9
    iter_num = 600

    few_shot_args = dict(n_way=args.test_n_way, n_support=args.n_shot)

    l3_model = None

    # Always load vocab for compatibility
    if args.dataset == 'CUB':
        vocab = lang_utils.load_vocab(configs.lang_dir)
    else:
        vocab, *_ = lang_utils.load_scenes_vocab(
            './filelists/scenes/AbstractScenes_v1.1/Sentences_1002.txt')

    if args.method == 'baseline':
        model = BaselineFinetune(model_dict[args.model], **few_shot_args)
    elif args.method == 'baseline++':
        model = BaselineFinetune(model_dict[args.model],
                                 loss_type='dist',
                                 **few_shot_args)
    elif args.method == 'protonet':
        lang_model = None
        if args.ml2 or args.l3:
            # Add sos/eos tokens to vocabulary
            embedding_model = nn.Embedding(len(vocab), args.lang_emb_size)
            lang_input_size = 1600 if args.model.startswith('Conv') else 512
            if args.language_task == 'decode':
                # FIXME: This won't work for all models
                lang_model_func = lambda inpsize: TextProposal(
                    embedding_model,
                    input_size=inpsize,
Example #10
        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'num_classes must cover the max label id in the base classes'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'num_classes must cover the max label id in the base classes'

        if params.method == 'baseline':
            train_model = BaselineTrain(model_dict[params.model],
                                        params.num_classes,
                                        loss_type=params.loss_type,
                                        margin=params.margin,
                                        centered=params.centered,
                                        temperature=params.temperature)
            test_model = BaselineFinetune(model_dict[params.model],
                                          params.num_classes,
                                          params.n_shot,
                                          loss_type=params.loss_type,
                                          margin=params.margin,
                                          centered=params.centered,
                                          temperature=params.temperature)

    elif params.method in [
            'simplenet', 'protonet', 'matchingnet', 'relationnet',
            'relationnet_softmax', 'maml', 'maml_approx'
    ]:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  # if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small

        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        base_datamgr = SetDataManager(image_size,
Example #11
    pred = scores.data.cpu().numpy().argmax(axis=1)
    y = np.repeat(range(n_way), n_query)
    acc = np.mean(pred == y) * 100
    return acc


if __name__ == '__main__':
    params = parse_args('test')

    acc_all = []

    iter_num = 600
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    if params.method == 'baseline':
        model = BaselineFinetune(model_dict[params.model], **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    else:
        raise ValueError('Unknown method')

    model = model.cuda()
    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, 'miniImageNet', params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'

    if params.method not in ['baseline']:
        checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)

    if params.method not in ['baseline']:
Example #12
    iter_num = 600

    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    params.reminder = "%s_TempR-%s_%s_Margin-%s_%s" % (
        params.model, params.ratio, params.loss_type, params.large_margin,
        time.time())

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only supports Conv4 without augmentation'
        params.model = 'Conv4S'

    if params.method == 'baseline':
        model = BaselineFinetune(model_dict[params.model],
                                 ratio=params.ratio,
                                 loss_type=params.loss_type,
                                 margin=params.large_margin,
                                 centered=params.centered,
                                 **few_shot_params)
    elif params.method == 'simplenet':
        model = SimpleNetFinetune(model_dict[params.model], **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        elif params.model == 'Conv4S':
            feature_model = backbone.Conv4SNP