Exemplo n.º 1
0
def load_presaved_model_for_train(model, params):
    """Restore weights into ``model`` before training and compute the epoch range.

    Weight sources are consulted in this order:

    1. ``params.resume``: load the full model state from the checkpoint
       returned by ``get_resume_file(params.checkpoint_dir)`` (if any) and
       continue from the epoch after the one stored in it.
    2. ``params.warmup`` (only when not resuming): load the ``feature.*``
       entries of the baseline checkpoint into ``model.feature``, with the
       ``feature.`` prefix stripped.
    3. ``params.loadfile`` (when non-empty): non-strictly load every entry of
       the checkpoint's ``'state'`` except classifier / ``loss_fn`` entries,
       on top of whatever was loaded before.

    Args:
        model: module to initialise. Needs ``load_state_dict``; also
            ``feature`` when warmup is used and ``n_task`` for MAML methods.
        params: options namespace (start_epoch, stop_epoch, method, resume,
            checkpoint_dir, warmup, dataset, model, train_aug, loadfile).

    Returns:
        ``(model, start_epoch, stop_epoch)``.

    Raises:
        ValueError: warmup was requested but no baseline checkpoint was found.
    """
    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method in ('maml', 'maml_approx'):
        # MAML uses multiple tasks in one update.
        stop_epoch = params.stop_epoch * model.n_task

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
            del tmp
    elif params.warmup:
        # Warm-start the backbone from a pretrained baseline feature
        # (supported, but never used in the paper).
        baseline_checkpoint_dir = 'checkpoints/%s/%s_%s' % (
            params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        # Check the file *before* loading it: torch.load(None) would fail with
        # an unrelated error and the intended ValueError would never be raised.
        if warmup_resume_file is None:
            raise ValueError('No warm_up file')
        state = torch.load(warmup_resume_file)['state']
        # The architecture model stores its backbone under 'feature', so map
        # 'feature.trunk.xx' -> 'trunk.xx' and drop everything else. Building a
        # new dict avoids clobbering entries while renaming in place.
        feature_state = {
            key.replace('feature.', ''): value
            for key, value in state.items()
            if 'feature.' in key
        }
        model.feature.load_state_dict(feature_state)

    if params.loadfile != '':
        print('Loading model from: ' + params.loadfile)
        checkpoint = torch.load(params.loadfile)
        # Remove the last (classification) layer, e.g. of a baseline model.
        pretrained_dict = {
            k: v
            for k, v in checkpoint['state'].items()
            if 'classifier' not in k and 'loss_fn' not in k
        }
        model.load_state_dict(pretrained_dict, strict=False)
    return model, start_epoch, stop_epoch
Exemplo n.º 2
0
 def load_states(self, checkpoint_dir):
     """Restore each loss engine's state from the latest checkpoint in ``checkpoint_dir``.

     Every top-level entry of the checkpoint other than ``'epoch'`` and
     ``'state'`` is looked up by name in ``self.losses_engines`` and that
     engine's ``load_state_dict`` is called with the entry's value.
     """
     checkpoint = torch.load(get_resume_file(checkpoint_dir))
     reserved_keys = ('epoch', 'state')  # model-level entries, not loss engines
     for engine_name in checkpoint:
         if engine_name not in reserved_keys:
             engine = self.losses_engines[engine_name]
             engine.load_state_dict(checkpoint[engine_name])
Exemplo n.º 3
0
def load_weight_file_for_test(model, params):
    """Choose the weight file for the test process and load it into ``model.feature``.

    The file is ``params.loadfile`` when it is non-empty. Otherwise it is
    looked up in ``params.checkpoint_dir``:

    - the checkpoint of iteration ``params.save_iter`` when that is not -1;
    - the latest checkpoint for ``baseline``/``baseline++``;
    - the best checkpoint for every other method.

    Only the ``feature.*`` entries of the checkpoint's ``'state'`` are kept.
    Their ``feature.`` prefix is stripped and they are loaded into
    ``model.feature``. For methods other than ``maml``/``maml_approx`` the
    model is then moved to the GPU and switched to eval mode; MAML models are
    returned without those two steps.

    Returns:
        The model (moved to the GPU for non-MAML methods).

    Raises:
        AssertionError: no weight file could be found.
    """
    if params.loadfile != '':
        modelfile = params.loadfile
        checkpoint_dir = params.loadfile
    else:
        checkpoint_dir = params.checkpoint_dir  # checkpoint path
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        elif params.method in ['baseline', 'baseline++']:
            modelfile = get_resume_file(checkpoint_dir)
        else:
            modelfile = get_best_file(checkpoint_dir)  # the best.tar file

    if not modelfile:
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # the exception type is unchanged for existing callers.
        raise AssertionError(
            "can not find model weight file in {}".format(checkpoint_dir))
    print("use model weight file: ", modelfile)

    state = torch.load(modelfile)['state']
    # The architecture model stores its backbone under 'feature': map
    # 'feature.trunk.xx' -> 'trunk.xx' and drop everything else (e.g. the
    # classifier). A fresh dict avoids clobbering entries while renaming.
    feature_state = {
        key.replace("feature.", ""): value
        for key, value in state.items()
        if "feature." in key
    }
    model.feature.load_state_dict(feature_state)

    if params.method not in ['maml', 'maml_approx']:  # e.g. ProtoNet and others
        model = model.cuda()
        model.eval()
    return model
Exemplo n.º 4
0
                model = wrn_mixup_model.wrn28_10(
                    num_classes=params.num_classes,
                    dct_status=params.dct_status)
            elif params.model == 'ResNet18':
                model = res_mixup_model.resnet18(
                    num_classes=params.num_classes)

        if params.method == 'baseline++':
            if use_gpu:
                if torch.cuda.device_count() > 1:
                    model = torch.nn.DataParallel(
                        model, device_ids=range(torch.cuda.device_count()))
                model.cuda()

            if params.resume:
                resume_file = get_resume_file(params.checkpoint_dir)
                tmp = torch.load(resume_file)
                start_epoch = tmp['epoch'] + 1
                state = tmp['state']
                model.load_state_dict(state)
            model = torch.nn.DataParallel(model).cuda()
            cudnn.benchmark = True
            optimization = 'Adam'
            model = train_baseline(base_loader, base_loader_test, val_loader,
                                   model, start_epoch,
                                   start_epoch + stop_epoch, params, {})

        elif params.method == 'S2M2_R':
            if use_gpu:
                if torch.cuda.device_count() > 1:
                    model = torch.nn.DataParallel(
Exemplo n.º 5
0
def _ensemble_trunk_embeddings(backbone, images):
    """Return globally average-pooled outputs of the 4-D layers along ``backbone.trunk``.

    Trunk layers are applied in order. Every output with four dimensions is
    pooled to ``(1, 1)``, squeezed and detached. For ``ResNet10``/``ResNet18``
    backbones the first four and the last pooled outputs are dropped.
    """
    embeddings = []
    out = images
    for layer in backbone.trunk:
        out = layer(out)
        if out.dim() == 4:
            # NOTE: squeeze() also drops the batch dimension if it is 1.
            pooled = F.adaptive_avg_pool2d(out, (1, 1)).squeeze()
            embeddings.append(pooled.detach())
    if params.model in ("ResNet10", "ResNet18"):
        embeddings = embeddings[4:-1]
    return embeddings


def _ensemble_finetune_backbone(backbone, support_images, n_way, n_support):
    """Fine-tune ``backbone`` in place on one episode's support set.

    A throw-away ``Classifier`` head is trained jointly with the backbone for
    100 epochs (batch size 4, SGD). After each backward pass, gradient entries
    whose magnitude is below their tensor's median magnitude are zeroed.
    """
    backbone.cuda()
    backbone.train()

    support_size = n_way * n_support
    batch_size = 4
    loss_fn = nn.CrossEntropyLoss().cuda()
    cnet = Classifier(backbone.final_feat_dim, n_way).cuda()
    classifier_opt = torch.optim.SGD(cnet.parameters(),
                                     lr=0.01,
                                     momentum=0.9,
                                     dampening=0.9,
                                     weight_decay=0.001)
    feature_opt = torch.optim.SGD(backbone.parameters(),
                                  lr=0.01,
                                  momentum=0.9,
                                  dampening=0.9,
                                  weight_decay=0.001)

    support_images = support_images.cuda()
    support_labels = torch.from_numpy(np.repeat(range(n_way),
                                                n_support)).cuda()  # (25,)

    for _ in range(100):
        rand_id = np.random.permutation(support_size)
        for j in range(0, support_size, batch_size):
            classifier_opt.zero_grad()
            feature_opt.zero_grad()

            selected_id = torch.from_numpy(
                rand_id[j:min(j + batch_size, support_size)]).cuda()
            outputs = cnet(backbone(support_images[selected_id]))
            loss = loss_fn(outputs, support_labels[selected_id])
            loss.backward()

            for param in backbone.parameters():
                # Parameters that took no part in the forward pass have no
                # gradient (None); indexing them would raise a TypeError.
                if param.grad is None:
                    continue
                grad_magnitude = torch.abs(param.grad)
                param.grad[grad_magnitude < grad_magnitude.median()] = 0.0

            classifier_opt.step()
            feature_opt.step()


def test_loop(novel_loader,
              return_std=False,
              loss_type="softmax",
              n_query=15,
              models_to_use=None,
              finetune_each_model=False,
              n_way=5,
              n_support=5):  # overwrite parent function
    """Evaluate an ensemble of pretrained backbones on few-shot episodes.

    For every episode ``(x, y)`` of ``novel_loader`` (``x`` is sliced as
    ``[way, shot + query, ...]``):

    1. one backbone per dataset name in ``models_to_use`` is built from
       ``model_dict[params.model]`` and loaded with the ``feature.*`` weights of
       that dataset's checkpoint;
    2. if ``finetune_each_model`` is set, each backbone is fine-tuned on the
       support images;
    3. pooled trunk activations of support and query images are collected and
       ``train_selection`` chooses which ones form the features;
    4. a linear ``Classifier`` is trained on the selected support features and
       its top-1 query accuracy (percent) is printed and recorded.

    Finally prints the mean accuracy and the 95% confidence half-width
    ``1.96 * std / sqrt(len(novel_loader))``. Nothing is returned.
    ``return_std`` and ``loss_type`` are accepted but unused; ``n_query`` is
    recomputed from ``x`` for every episode.
    """
    if models_to_use is None:
        models_to_use = []

    iter_num = len(novel_loader)
    acc_all = []

    # Checkpoints do not change between episodes, so read each one once; only
    # the backbones are rebuilt per episode because fine-tuning mutates them.
    # load_state_dict copies tensors, so the cached states stay untouched.
    params.save_iter = -1  # always evaluate the resume/best checkpoint
    backbone_states = []
    for dataset_name in models_to_use:
        checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, dataset_name, params.model, params.method)
        if params.train_aug:
            checkpoint_dir += '_aug'
        if params.method in ['baseline', 'baseline++']:
            modelfile = get_resume_file(checkpoint_dir)
        else:
            modelfile = get_best_file(checkpoint_dir)
        state = torch.load(modelfile)['state']
        # Map 'feature.trunk.xx' -> 'trunk.xx'; drop non-backbone entries.
        backbone_states.append({
            key.replace("feature.", ""): value
            for key, value in state.items()
            if "feature." in key
        })

    for x, _ in novel_loader:
        pretrained_models = []
        for state in backbone_states:
            backbone = model_dict[params.model]()
            backbone.load_state_dict(state)
            pretrained_models.append(backbone)

        n_query = x.size(1) - n_support
        x = x.cuda()
        batch_size = 4
        support_size = n_way * n_support
        support_images = x[:, :n_support, :, :, :].contiguous().view(
            n_way * n_support, *x.size()[2:])  # (25, 3, 224, 224)
        query_images = x[:, n_support:, :, :, :].contiguous().view(
            n_way * n_query, *x.size()[2:])

        if finetune_each_model:
            for backbone in pretrained_models:
                _ensemble_finetune_backbone(backbone, support_images, n_way,
                                            n_support)

        for backbone in pretrained_models:
            backbone.cuda()
            backbone.eval()

        # Features are detached anyway; no_grad avoids building the graph.
        with torch.no_grad():
            all_embeddings_train = [
                _ensemble_trunk_embeddings(backbone, support_images)
                for backbone in pretrained_models
            ]

        y_support = np.repeat(range(n_way), n_support)
        embeddings_idx_of_each, embeddings_idx_model, embeddings_train, embeddings_best_of_each = train_selection(
            all_embeddings_train,
            y_support,
            support_size,
            n_support,
            n_way,
            with_replacement=True)

        with torch.no_grad():
            all_embeddings_test = [
                _ensemble_trunk_embeddings(backbone, query_images)
                for backbone in pretrained_models
            ]

        embeddings_test = torch.cat([
            all_embeddings_test[index][embeddings_idx_of_each[index]]
            for index in embeddings_idx_model
        ], 1)

        y_a_i = torch.from_numpy(np.repeat(range(n_way),
                                           n_support)).cuda()  # (25,)
        net = Classifier(embeddings_test.size()[1], n_way).cuda()
        loss_fn = nn.CrossEntropyLoss().cuda()
        classifier_opt = torch.optim.SGD(net.parameters(),
                                         lr=0.01,
                                         momentum=0.9,
                                         dampening=0.9,
                                         weight_decay=0.001)

        total_epoch = 100
        embeddings_train = embeddings_train.cuda()

        net.train()
        for _ in range(total_epoch):
            rand_id = np.random.permutation(support_size)
            for j in range(0, support_size, batch_size):
                classifier_opt.zero_grad()
                selected_id = torch.from_numpy(
                    rand_id[j:min(j + batch_size, support_size)]).cuda()
                outputs = net(embeddings_train[selected_id])
                loss = loss_fn(outputs, y_a_i[selected_id])
                loss.backward()
                classifier_opt.step()

        with torch.no_grad():
            scores = net(embeddings_test.cuda())

        y_query = np.repeat(range(n_way), n_query)
        topk_labels = scores.topk(1, 1, True, True)[1]
        topk_ind = topk_labels.cpu().numpy()
        top1_correct = np.sum(topk_ind[:, 0] == y_query)
        acc = float(top1_correct) / len(y_query) * 100
        print(acc)
        acc_all.append(acc)

    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    print('%d Test Acc = %4.2f%% +- %4.2f%%' %
          (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
    def __init__(self, module):
        """Wrap ``module`` and register it under the attribute name ``module``.

        Assuming this class subclasses ``torch.nn.Module`` (its header is not
        visible here), the wrapped network's state-dict keys gain a
        ``module.`` prefix. That presumably matches checkpoints saved from a
        ``torch.nn.DataParallel`` model — confirm against how the checkpoint
        was written.
        """
        super(WrappedModel, self).__init__()
        self.module = module  # the network that is actually defined and run
    def forward(self, x):
        """Delegate the forward pass unchanged to the wrapped module."""
        return self.module(x)



# Build the CIFAR WideResNet28_10 backbone and restore its latest checkpoint
# for evaluation.
model = backbone.WideResNet28_10(flatten=True, beta_value=50.)
checkpoint_dir = './checkpoints/%s/%s_%s_%s' % ('cifar', 'WideResNet28_10', 'art', 'cifar')
# Wrap before loading so the parameter names carry the wrapper's 'module.'
# prefix (presumably how the checkpoint was saved — e.g. from DataParallel).
model = WrappedModel(model)

print("resuming", checkpoint_dir)
resume_file = get_resume_file(checkpoint_dir)
if resume_file is None:
    print("error no file found")
    # Exit with a non-zero status: a bare exit() reported success (status 0)
    # even though nothing could be loaded.
    raise SystemExit(1)
print("resume_file", resume_file)
tmp = torch.load(resume_file)
model.load_state_dict(tmp['state'])

model = model.cuda()
model.eval()

def normalize(x):
    mean = torch.tensor([0.4914, 0.4822, 0.4465])
    std = torch.tensor([0.2023, 0.1994, 0.2010])
Exemplo n.º 7
0
                                                   aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        model = SSL_Train(model_dict[params.model], params.num_classes)
    else:
        raise ValueError('Unknown method')

    model = model.cuda()
    #Prepare checkpoint_dir
    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    print('checkpoint_dir', params.checkpoint_dir)

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])

    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch

    model = train(base_loader, val_loader, model, optimization, start_epoch,
                  stop_epoch, params)
Exemplo n.º 8
0
        lsl=args.lsl,
        language_model=lang_model,
        lang_supervision=args.lang_supervision,
        l3=args.l3,
        l3_model=l3_model,
        l3_n_infer=args.l3_n_infer)

    model = model.cuda()

    os.makedirs(args.checkpoint_dir, exist_ok=True)

    start_epoch = args.start_epoch
    stop_epoch = args.stop_epoch

    if args.resume:
        resume_file = get_resume_file(args.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp["epoch"] + 1
            model.load_state_dict(tmp["state"])

    metrics_fname = "metrics_{}.json".format(args.n)

    train(
        base_loader,
        val_loader,
        model,
        start_epoch,
        stop_epoch,
        args,
        metrics_fname=metrics_fname,
Exemplo n.º 9
0
def finetune(novel_loader, n_query=15, pretrained_dataset='miniImageNet', freeze_backbone=False, n_way=5, n_support=5):
    """Fine-tune a pretrained backbone plus a new linear classifier on each episode.

    Backbone weights are the ``feature.*`` entries of the checkpoint trained on
    ``pretrained_dataset``: the latest checkpoint for ``baseline``/``baseline++``
    and the best checkpoint otherwise.

    For every episode ``(x, y)`` of ``novel_loader`` (``x`` is sliced as
    ``[way, shot + query, ...]``), a fresh backbone and ``Classifier`` are
    trained for 100 epochs on the support images (batch size 4). Top-1 query
    accuracy (percent) is then printed and recorded.

    The backbone is updated only when ``freeze_backbone is False``; any other
    value (including ``0``/``None``) keeps it frozen in eval mode.

    Finally prints the mean accuracy and the 95% confidence half-width.
    ``n_query`` is recomputed from ``x`` for each episode.
    """
    iter_num = len(novel_loader)
    acc_all = []
    # Preserve the original identity check: only the literal False trains it.
    train_backbone = freeze_backbone is False

    # Load the pretrained checkpoint once. The model itself is rebuilt every
    # episode because fine-tuning changes it; load_state_dict copies tensors,
    # so the cached state is never modified.
    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (configs.save_dir, pretrained_dataset, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    params.save_iter = -1  # always use the resume/best checkpoint
    if params.method in ['baseline', 'baseline++']:
        modelfile = get_resume_file(checkpoint_dir)
    else:
        modelfile = get_best_file(checkpoint_dir)
    state = torch.load(modelfile)['state']
    # Map 'feature.trunk.xx' -> 'trunk.xx'; drop non-backbone entries.
    backbone_state = {
        key.replace("feature.", ""): value
        for key, value in state.items()
        if "feature." in key
    }

    batch_size = 4
    total_epoch = 100
    support_size = n_way * n_support

    for x, _ in novel_loader:
        pretrained_model = model_dict[params.model]()
        pretrained_model.load_state_dict(backbone_state)
        classifier = Classifier(pretrained_model.final_feat_dim, n_way)

        n_query = x.size(1) - n_support
        x = x.cuda()
        y_a_i = torch.from_numpy(np.repeat(range(n_way), n_support)).cuda()  # (25,)
        x_b_i = x[:, n_support:, :, :, :].contiguous().view(n_way * n_query, *x.size()[2:])
        x_a_i = x[:, :n_support, :, :, :].contiguous().view(n_way * n_support, *x.size()[2:])  # (25, 3, 224, 224)

        loss_fn = nn.CrossEntropyLoss().cuda()
        classifier_opt = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9, dampening=0.9, weight_decay=0.001)
        if train_backbone:
            delta_opt = torch.optim.SGD(filter(lambda p: p.requires_grad, pretrained_model.parameters()), lr=0.01)

        pretrained_model.cuda()
        classifier.cuda()

        if train_backbone:
            pretrained_model.train()
        else:
            pretrained_model.eval()
        classifier.train()

        for _ in range(total_epoch):
            rand_id = np.random.permutation(support_size)
            for j in range(0, support_size, batch_size):
                classifier_opt.zero_grad()
                if train_backbone:
                    delta_opt.zero_grad()

                selected_id = torch.from_numpy(rand_id[j:min(j + batch_size, support_size)]).cuda()
                z_batch = x_a_i[selected_id]
                y_batch = y_a_i[selected_id]

                if train_backbone:
                    features = pretrained_model(z_batch)
                else:
                    # Frozen backbone: its gradients were never used, so skip
                    # building its graph. The classifier still trains.
                    with torch.no_grad():
                        features = pretrained_model(z_batch)
                loss = loss_fn(classifier(features), y_batch)
                loss.backward()

                classifier_opt.step()
                if train_backbone:
                    delta_opt.step()

        pretrained_model.eval()
        classifier.eval()
        with torch.no_grad():
            scores = classifier(pretrained_model(x_b_i))

        y_query = np.repeat(range(n_way), n_query)
        topk_labels = scores.topk(1, 1, True, True)[1]
        topk_ind = topk_labels.cpu().numpy()
        top1_correct = np.sum(topk_ind[:, 0] == y_query)
        acc = float(top1_correct) / len(y_query) * 100
        print(acc)
        acc_all.append(acc)

    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
Exemplo n.º 10
0
def run(params):
    """Build the data loaders and model described by ``params`` and train it.

    Steps:

    1. Pick base/val split files from ``params.dataset``; ``cross`` and
       ``cross_char`` pair different datasets.
    2. Fill in a default ``params.stop_epoch`` when it is -1.
    3. Build loaders and the model for ``params.method``.
    4. Derive and create ``params.checkpoint_dir``.
    5. Optionally resume or warm-start, then call ``train``.

    Side effects: mutates ``params`` (``model`` for omniglot settings,
    ``stop_epoch``, ``n_episode``, ``checkpoint_dir``), sets the ``maml`` class
    flags on backbone blocks for MAML methods, and creates the checkpoint
    directory.

    Raises:
        ValueError: unknown ``params.method``, or warmup requested but no
            baseline checkpoint was found.
    """
    if params.dataset == 'cross':
        base_file = configs.data_dir['miniImagenet'] + 'all.json'
        val_file = configs.data_dir['CUB'] + 'val.json'
    elif params.dataset == 'cross_char':
        base_file = configs.data_dir['omniglot'] + 'noLatin.json'
        val_file = configs.data_dir['emnist'] + 'val.json'
    else:
        if params.base_json:
            print(f'Using base classes from {params.base_json}')
            base_file = params.base_json
        else:
            base_file = configs.data_dir[params.dataset] + 'base.json'
        val_file = configs.data_dir[params.dataset] + 'val.json'

    image_size = get_image_size(params)

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'

    optimization = 'Adam'

    if params.stop_epoch == -1:
        if params.method in ['baseline', 'baseline++']:
            if params.dataset in ['omniglot', 'cross_char']:
                params.stop_epoch = 5
            elif params.dataset in ['CUB']:
                # Different from the open-review paper: 400 epochs over-fit the
                # baseline on CUB.
                params.stop_epoch = 200
            elif params.dataset in ['miniImagenet', 'cross']:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 400  # default
        else:
            # Meta-learning methods use 600 epochs regardless of n_shot.
            params.stop_epoch = 600

    if params.method in ['baseline', 'baseline++']:
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        else:  # 'baseline++'
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist')

    elif params.method in [
            'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax',
            'maml', 'maml_approx'
    ]:
        # If test_n_way is smaller than train_n_way, reduce n_query to keep
        # the batch size small.
        n_query = max(1, int(16 * params.test_n_way / params.train_n_way))
        if 'n_episode' not in params:
            if params.stop_epoch >= 500:
                params.n_episode = 1000
                params.stop_epoch = int(params.stop_epoch / 10)
            else:
                params.n_episode = 100
        print(f'| Using {params.n_episode} n_episode for trainloader...')
        print(f'| Using Stop epoch {params.stop_epoch}')
        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        base_datamgr = SetDataManager(image_size,
                                      n_query=n_query,
                                      n_episode=params.n_episode,
                                      **train_few_shot_params)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)

        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        # A SetDataManager batch: [n_way, n_support + n_query, dim, w, h] tensor.

        if params.method == 'protonet':
            model = ProtoNet(model_dict[params.model], **train_few_shot_params)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model],
                                **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            if params.model == 'Conv4':
                feature_model = backbone.Conv4NP
            elif params.model == 'Conv6':
                feature_model = backbone.Conv6NP
            elif params.model == 'Conv4S':
                feature_model = backbone.Conv4SNP
            else:
                feature_model = lambda: model_dict[params.model](flatten=False)
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'

            model = RelationNet(feature_model,
                                loss_type=loss_type,
                                **train_few_shot_params)
        else:  # 'maml', 'maml_approx'
            backbone.ConvBlock.maml = True
            backbone.SimpleBlock.maml = True
            backbone.BottleneckBlock.maml = True
            backbone.ResNet.maml = True
            model = MAML(model_dict[params.model],
                         approx=(params.method == 'maml_approx'),
                         **train_few_shot_params)
            if params.dataset in ['omniglot', 'cross_char']:
                # MAML uses different hyper-parameters on omniglot.
                model.n_task = 32
                model.task_update_num = 1
                model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()
    print(model)

    # getattr instead of hasattr: a logdir attribute that exists but is
    # None/empty must not become the checkpoint directory.
    logdir = getattr(params, 'logdir', None)
    if logdir:
        params.checkpoint_dir = logdir
    else:
        params.checkpoint_dir = '%s/ckpts/%s/%s_%s_%s' % (
            configs.save_dir, params.dataset, params.model, params.method,
            params.base_json)
        if params.train_aug:
            params.checkpoint_dir += '_aug'
        if params.method not in ['baseline', 'baseline++']:
            params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way,
                                                        params.n_shot)
    # Create the directory in both cases so later checkpoint writes (presumably
    # done by train) do not fail on a missing user-supplied logdir.
    os.makedirs(params.checkpoint_dir, exist_ok=True)

    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method in ('maml', 'maml_approx'):
        # MAML uses multiple tasks in one update.
        stop_epoch = params.stop_epoch * model.n_task

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
    elif params.warmup:
        # Warm-start from a pretrained baseline feature (supported, but never
        # used in the paper).
        baseline_checkpoint_dir = '%s/ckpts/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        # Check the file before loading: torch.load(None) would fail with an
        # unrelated error and the intended ValueError would never be raised.
        if warmup_resume_file is None:
            raise ValueError('No warm_up file')
        state = torch.load(warmup_resume_file)['state']
        # The architecture model stores its backbone under 'feature': map
        # 'feature.trunk.xx' -> 'trunk.xx' and drop everything else.
        feature_state = {
            key.replace("feature.", ""): value
            for key, value in state.items()
            if "feature." in key
        }
        model.feature.load_state_dict(feature_state)

    model = train(base_loader, val_loader, model, optimization, start_epoch,
                  stop_epoch, params)
Exemplo n.º 11
0
def meta_test(novel_loader,
              n_query=15,
              pretrained_dataset='miniImageNet',
              freeze_backbone=False,
              n_way=5,
              n_support=5):
    """Fine-tune a new linear classifier per episode and report top-1 accuracy.

    For every episode yielded by ``novel_loader`` a fresh backbone is built
    with ``model_dict[params.model]()`` and initialised with the feature
    weights of the checkpoint trained on ``pretrained_dataset``. A new
    ``Classifier`` is trained on the support set for 100 epochs in
    mini-batches of 4 (SGD, lr=0.01), optionally fine-tuning the backbone too.
    Accuracy is then measured on the query set.

    Args:
        novel_loader: iterable of ``(x, y)`` episodes. ``x[:, :n_support]`` is
            the support set and ``x[:, n_support:]`` the query set, presumably
            shaped (n_way, n_support + n_query, C, H, W). ``y`` is unused.
        n_query: ignored; recomputed per episode as ``x.size(1) - n_support``.
        pretrained_dataset: dataset name used to locate the checkpoint dir.
        freeze_backbone: if False the backbone is fine-tuned together with the
            classifier; if True only the classifier is trained.
        n_way: number of classes per episode.
        n_support: number of support images per class.

    Prints per-task accuracy and, at the end, the mean accuracy with a 95%
    confidence interval (1.96 * std / sqrt(number of tasks)).

    Note: relies on the module-level ``params`` and overwrites
    ``params.save_iter`` with -1.
    """
    iter_num = len(novel_loader)
    acc_all = []

    ###############################################################################################
    # Locate and read the pretrained checkpoint once; it is identical for every task.
    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, pretrained_dataset, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'

    # NOTE(review): forcing save_iter to -1 (which also mutates the global
    # `params`) makes the get_assigned_file branch unreachable; kept as-is to
    # preserve the original behaviour.
    params.save_iter = -1
    if params.save_iter != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
    elif params.method in ['baseline', 'baseline++']:
        modelfile = get_resume_file(checkpoint_dir)
    else:
        modelfile = get_best_file(checkpoint_dir)
    print("load from %s" % (modelfile))

    tmp = torch.load(modelfile)
    # Keep only backbone weights, renaming 'feature.trunk.xx' -> 'trunk.xx';
    # the classifier entries are dropped.
    backbone_state = {
        key.replace("feature.", ""): value
        for key, value in tmp['state'].items() if "feature." in key
    }
    del tmp

    for ti, (x, y) in enumerate(novel_loader):

        ###############################################################################################
        # Fresh backbone per task so fine-tuning in one task never leaks into the next.
        # load_state_dict copies the tensors, so backbone_state itself stays untouched.
        pretrained_model = model_dict[params.model]()
        pretrained_model.load_state_dict(backbone_state)

        # New task-specific linear classifier on top of the backbone features.
        classifier = Classifier(pretrained_model.final_feat_dim, n_way)

        ###############################################################################################
        # Split the episode into support and query sets.
        n_query = x.size(1) - n_support
        x = x.cuda()
        x_var = Variable(x)

        batch_size = 4
        support_size = n_way * n_support

        # Support labels in class-major order: [0]*n_support + [1]*n_support + ...
        y_a_i = Variable(torch.from_numpy(np.repeat(range(n_way),
                                                    n_support))).cuda()

        x_b_i = x_var[:, n_support:, :, :, :].contiguous().view(
            n_way * n_query,
            *x.size()[2:])  # query set
        x_a_i = x_var[:, :n_support, :, :, :].contiguous().view(
            n_way * n_support,
            *x.size()[2:])  # support set

        ###############################################################################################
        # Loss function and optimizers.
        loss_fn = nn.CrossEntropyLoss().cuda()
        classifier_opt = torch.optim.SGD(classifier.parameters(),
                                         lr=0.01,
                                         momentum=0.9,
                                         dampening=0.9,
                                         weight_decay=0.001)

        if freeze_backbone is False:  # fine-tune the backbone as well
            delta_opt = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                               pretrained_model.parameters()),
                                        lr=0.01)

        pretrained_model.cuda()
        classifier.cuda()

        ###############################################################################################
        # Fine-tuning on the support set: 100 epochs over shuffled mini-batches.
        total_epoch = 100

        if freeze_backbone is False:
            pretrained_model.train()
        else:
            pretrained_model.eval()

        classifier.train()

        for epoch in range(total_epoch):
            rand_id = np.random.permutation(support_size)
            for j in range(0, support_size, batch_size):
                classifier_opt.zero_grad()
                if freeze_backbone is False:
                    delta_opt.zero_grad()

                selected_id = torch.from_numpy(
                    rand_id[j:min(j + batch_size, support_size)]).cuda()
                z_batch = x_a_i[selected_id]
                y_batch = y_a_i[selected_id]

                # A frozen backbone needs no autograd graph; only the
                # classifier receives gradients in that case.
                with torch.set_grad_enabled(freeze_backbone is False):
                    output = pretrained_model(z_batch)
                output = classifier(output)

                loss = loss_fn(output, y_batch)
                loss.backward()

                classifier_opt.step()

                if freeze_backbone is False:
                    delta_opt.step()

        ##############################################################################################
        # Inference on the query set (no gradients needed).
        pretrained_model.eval()
        classifier.eval()

        with torch.no_grad():
            output = pretrained_model(x_b_i.cuda())
            scores = classifier(output)

        # Query labels in the same class-major order as x_b_i.
        y_query = np.repeat(range(n_way), n_query)
        # top-1 prediction per row (dim=1, largest=True, sorted=True).
        topk_scores, topk_labels = scores.data.topk(1, 1, True, True)
        topk_ind = topk_labels.cpu().numpy()

        top1_correct = np.sum(topk_ind[:, 0] == y_query)
        correct_this, count_this = float(top1_correct), len(y_query)
        acc_all.append((correct_this / count_this * 100))
        print("Task %d : %4.2f%%  Now avg: %4.2f%%" %
              (ti, correct_this / count_this * 100, np.mean(acc_all)))
        ###############################################################################################

    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    print('%d Test Acc = %4.2f%% +- %4.2f%%' %
          (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
Exemplo n.º 12
0
                                                 batch_size=params.bs,
                                                 shuffle=False,
                                                 num_workers=0)
        config = {
            'epsilon': 8.0 / 255,
            'num_steps': 5,
            'step_size': 2.0 / 255,
            'random_start': True,
            'loss_func': 'xent',
        }

    teacher = model.Model(net='_'.join(params.teacher.split('_')[:-1]),
                          num_classes=params.num_classes)
    teacher_dir = '%s/%s/teacher/%s' % (SAVE_DIR, params.dataset,
                                        params.teacher)
    teacher_file = get_resume_file(teacher_dir)
    print('Teacher file:', teacher_file)
    tmp = torch.load(teacher_file)
    teacher.feature.load_state_dict(tmp['feature'])
    teacher.classifier.load_state_dict(tmp['classifier'])
    teacher.eval()

    model = model.Model(net=params.model, num_classes=params.num_classes)

    optimization = 'Adam'
    params.checkpoint_dir = '%s/%s/student2/%s_%s_%s' % (
        SAVE_DIR, params.dataset, params.model, params.method, params.teacher)
    if params.exp != 'gbp':
        params.checkpoint_dir += '_%s' % (params.exp)
    if params.e != 8.0:
        params.checkpoint_dir += '_eps{}'.format(params.e)
Exemplo n.º 13
0
def train_s2m2(base_loader, base_loader_test, model, params, tmp):
    """Train ``model`` with manifold mixup plus rotation self-supervision (S2M2).

    Each step:
      1. runs a manifold-mixup forward pass and backpropagates its loss;
      2. on a random quarter of the batch, runs a rotation pass: every selected
         image is fed at 0/90/180/270 degrees. ``rotate_classifier`` predicts
         the rotation and the model's own head predicts the class. The mean of
         the two losses is backpropagated into the same gradients;
      3. takes a single optimizer step over both contributions.

    Args:
        base_loader: training loader yielding ``(inputs, targets)``.
        base_loader_test: loader passed to ``test_s2m2`` after every epoch.
        model: network called as ``model(inputs, targets, mixup_hidden=True,
            mixup_alpha=..., lam=...)`` -> ``(features, logits, target_a,
            target_b)`` and as ``model(inputs)`` -> ``(features, logits)``.
        params: options; reads ``alpha``, ``start_epoch``, ``stop_epoch``,
            ``resume``, ``checkpoint_dir`` and ``save_freq``.
        tmp: previously loaded checkpoint dict. If it has a ``'rotate'`` entry,
            the rotation head is initialised from it.

    Returns:
        The trained model. Checkpoints ``{epoch, state, rotate}`` are written
        to ``params.checkpoint_dir`` every ``save_freq`` epochs and at the last
        epoch.
    """
    def mixup_criterion(criterion, pred, y_a, y_b, lam):
        # Convex combination of the losses against both mixed targets.
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

    criterion = nn.CrossEntropyLoss()

    # Rotation head predicting one of 4 rotations from the backbone features.
    # NOTE(review): assumes the model's feature dimension is 640; confirm for
    # other backbones.
    rotate_classifier = nn.Sequential(nn.Linear(640, 4))
    rotate_classifier.to(device)
    model.to(device)

    if 'rotate' in tmp:
        print("loading rotate model")
        rotate_classifier.load_state_dict(tmp['rotate'])

    # One optimizer over both the network and the rotation head.
    optimizer = torch.optim.Adam([{
        'params': model.parameters()
    }, {
        'params': rotate_classifier.parameters()
    }])

    start_epoch, stop_epoch = params.start_epoch, params.start_epoch + params.stop_epoch

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        print('resumefile: {}'.format(resume_file))
        if resume_file is not None:
            checkpoint = torch.load(resume_file, map_location=device)
            model.load_state_dict(checkpoint['state'])
            if 'rotate' in checkpoint:
                rotate_classifier.load_state_dict(checkpoint['rotate'])
            # Checkpoints are written after epoch `epoch` has finished, so
            # continue with the next epoch instead of repeating it.
            start_epoch = checkpoint['epoch'] + 1
            print('Model loaded')

    print("stop_epoch", start_epoch, stop_epoch)

    for epoch in range(start_epoch, stop_epoch):
        print('\nEpoch: %d' % epoch)

        model.train()
        train_loss, rotate_loss = 0, 0
        correct, total = 0, 0
        torch.cuda.empty_cache()

        for batch_idx, (inputs, targets) in enumerate(base_loader):

            inputs, targets = inputs.to(device), targets.to(device)

            # 1) Manifold mixup: the model mixes hidden states with weight `lam`
            #    and returns both sets of targets.
            lam = np.random.beta(params.alpha, params.alpha)
            f, outputs, target_a, target_b = model(inputs,
                                                   targets,
                                                   mixup_hidden=True,
                                                   mixup_alpha=params.alpha,
                                                   lam=lam)
            loss = mixup_criterion(criterion, outputs, target_a, target_b, lam)
            train_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()

            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (
                lam * predicted.eq(target_a.data).cpu().sum().float() +
                (1 - lam) * predicted.eq(target_b.data).cpu().sum().float())

            # 2) Rotation self-supervision on a random quarter of the batch.
            bs = inputs.size(0)
            indices = np.arange(bs)
            np.random.shuffle(indices)
            split_size = bs // 4

            if split_size > 0:  # a batch smaller than 4 has nothing to rotate
                rot_inputs, rot_targets = [], []
                for j in indices[:split_size]:
                    x90 = inputs[j].transpose(2, 1).flip(1)
                    x180 = x90.transpose(2, 1).flip(1)
                    x270 = x180.transpose(2, 1).flip(1)
                    rot_inputs += [inputs[j], x90, x180, x270]
                    rot_targets += [targets[j]] * 4

                rot_inputs = torch.stack(rot_inputs, 0).to(device)
                rot_targets = torch.stack(rot_targets, 0).to(device)
                # Rotation labels 0,1,2,3 repeat in the same order as rot_inputs.
                rot_labels = torch.arange(4).repeat(split_size).to(device)

                rf, rot_outputs = model(rot_inputs)
                rloss = criterion(rotate_classifier(rf), rot_labels)
                closs = criterion(rot_outputs, rot_targets)

                loss = (rloss + closs) / 2.0
                rotate_loss += rloss.item()
                loss.backward()

            # 3) One step on the accumulated mixup + rotation gradients.
            optimizer.step()

            if (batch_idx + 1) % 50 == 0:
                print(
                    '{0}/{1}'.format(batch_idx, len(base_loader)),
                    'Loss: %.3f | Acc: %.3f%% | RotLoss: %.3f  ' %
                    (train_loss /
                     (batch_idx + 1), 100. * correct / total, rotate_loss /
                     (batch_idx + 1)))

        if (epoch % params.save_freq == 0) or (epoch == stop_epoch - 1):
            os.makedirs(params.checkpoint_dir, exist_ok=True)

            outfile = os.path.join(params.checkpoint_dir,
                                   '{:d}.tar'.format(epoch))
            # The rotation head is stored too, so a resumed run (or a later
            # `tmp` checkpoint) can restore it.
            torch.save(
                {
                    'epoch': epoch,
                    'state': model.state_dict(),
                    'rotate': rotate_classifier.state_dict()
                }, outfile)
            print('Model saved')

        test_s2m2(base_loader_test, model, criterion)

    return model
Exemplo n.º 14
0
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    if not params.method in ['baseline', 'baseline++']:
        params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way,
                                                    params.n_shot)

    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method == 'maml' or params.method == 'maml_approx':
        stop_epoch = params.stop_epoch * model.n_task  #maml use multiple tasks in one update

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
    elif params.warmup:  #We also support warmup from pretrained baseline feature, but we never used in our paper
        baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        tmp = torch.load(warmup_resume_file)
        if tmp is not None:
            state = tmp['state']
            state_keys = list(state.keys())
            for i, key in enumerate(state_keys):
Exemplo n.º 15
0
def finetune(novel_loader,
             n_query=15,
             freeze_backbone=False,
             n_way=5,
             n_support=5,
             loadpath='',
             adaptation=False,
             pretrained_dataset='miniImagenet',
             proto_init=False):
    """Evaluate a pretrained backbone on few-shot episodes.

    For every episode a fresh backbone is initialised from the feature weights
    of the checkpoint selected by the module-level ``params``. Then one of two
    paths runs:

    * ``adaptation=True``: a new linear ``Classifier`` is trained. It is
      optionally initialised from support prototypes when ``proto_init`` is
      set. Training uses Adam (``lr=params.lr_rate``) for ``params.ft_steps``
      epochs over the support set in mini-batches of ``n_way`` images. The
      backbone is fine-tuned too when ``freeze_backbone`` is False.
    * ``adaptation=False``: query images are scored by a ``ProtoClassifier``
      built from the support embeddings.

    Args:
        novel_loader: iterable of ``(x, y)`` episodes. ``x[:, :n_support]`` is
            the support set and ``x[:, n_support:]`` the query set, presumably
            shaped (n_way, n_support + n_query, C, H, W). ``y`` is unused.
        n_query: ignored; recomputed per episode as ``x.size(1) - n_support``.
        freeze_backbone: if False the backbone is fine-tuned during adaptation.
        n_way: classes per episode.
        n_support: support images per class.
        loadpath: unused.
        adaptation: selects the path described above.
        pretrained_dataset: unused (the checkpoint path uses ``params.dataset``).
        proto_init: initialise the linear classifier from support prototypes.

    Prints the mean accuracy over all episodes with a 95% confidence interval.
    """
    iter_num = len(novel_loader)
    acc_all = []

    ###############################################################################################
    # The checkpoint is the same for every episode, so read it from disk once.
    checkpoint_dir = '%s/checkpoints/%s/%s_%s_%s%s_%s%s' % (
        configs.save_dir, params.dataset, params.model, params.method,
        params.n_support, "s" if params.no_aug_support else "s_aug",
        params.n_query, "q" if params.no_aug_query else "q_aug")
    checkpoint_dir += "_bs{}".format(params.batch_size)

    if params.save_iter != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
    elif params.method in ['baseline', 'baseline++']:
        modelfile = get_resume_file(checkpoint_dir)
    else:
        modelfile = get_best_file(checkpoint_dir)

    # Keep only backbone weights. An architecture model stores its backbone
    # under the attribute 'feature', so 'feature.trunk.xx' becomes 'trunk.xx'.
    backbone_state = {
        key.replace("feature.", ""): value
        for key, value in torch.load(modelfile)['state'].items()
        if "feature." in key
    }

    loss_fn = nn.CrossEntropyLoss().cuda()

    with tqdm(enumerate(novel_loader), total=len(novel_loader)) as pbar:

        for _, (x, y) in pbar:

            ###############################################################################################
            # Fresh backbone per episode so adaptation never leaks across episodes.
            # load_state_dict copies the tensors, so backbone_state stays untouched.
            pretrained_model = model_dict[params.model]()
            pretrained_model.load_state_dict(backbone_state)
            pretrained_model.cuda()
            pretrained_model.train()

            # Derive the query size from the episode *before* building the
            # classifier, so ProtoClassifier always receives the real value.
            n_query = x.size(1) - n_support

            if adaptation:
                classifier = Classifier(pretrained_model.final_feat_dim, n_way)
                classifier.cuda()
                classifier.train()
            else:
                classifier = ProtoClassifier(n_way, n_support, n_query)

            ###############################################################################################
            # Split into support (x_a_i, y_a_i) and query (x_b_i) sets.
            x = x.cuda()
            x_var = Variable(x)

            batch_size = n_way
            support_size = n_way * n_support

            # Support labels in class-major order: [0]*n_support + [1]*n_support + ...
            y_a_i = Variable(
                torch.from_numpy(np.repeat(range(n_way), n_support))).cuda()

            x_b_i = x_var[:, n_support:, :, :, :].contiguous().view(
                n_way * n_query,
                *x.size()[2:])
            x_a_i = x_var[:, :n_support, :, :, :].contiguous().view(
                n_way * n_support,
                *x.size()[2:])

            # Support embeddings from the un-adapted backbone in eval mode. They
            # are only used for prototypes, so no autograd graph is kept around.
            pretrained_model.eval()
            with torch.no_grad():
                z_a_i = pretrained_model(x_a_i.cuda())
            pretrained_model.train()

            ###############################################################################################
            if adaptation:
                inner_lr = params.lr_rate
                if proto_init:  # initialise as a distance classifier (distance to prototypes)
                    classifier.init_params_from_prototypes(
                        z_a_i, n_way, n_support)
                classifier_opt = torch.optim.Adam(classifier.parameters(),
                                                  lr=inner_lr)

                if freeze_backbone is False:
                    delta_opt = torch.optim.Adam(filter(
                        lambda p: p.requires_grad,
                        pretrained_model.parameters()),
                                                 lr=inner_lr)
                    pretrained_model.train()
                else:
                    pretrained_model.eval()

                total_epoch = params.ft_steps
                classifier.train()

                for epoch in tqdm(range(total_epoch),
                                  total=total_epoch,
                                  leave=False):
                    rand_id = np.random.permutation(support_size)

                    for j in range(0, support_size, batch_size):
                        classifier_opt.zero_grad()
                        if freeze_backbone is False:
                            delta_opt.zero_grad()

                        selected_id = torch.from_numpy(
                            rand_id[j:min(j +
                                          batch_size, support_size)]).cuda()

                        z_batch = x_a_i[selected_id]
                        y_batch = y_a_i[selected_id]

                        # A frozen backbone needs no autograd graph; only the
                        # classifier receives gradients in that case.
                        with torch.set_grad_enabled(freeze_backbone is False):
                            output = pretrained_model(z_batch)
                        output = classifier(output)
                        loss = loss_fn(output, y_batch)

                        loss.backward()

                        classifier_opt.step()

                        if freeze_backbone is False:
                            delta_opt.step()

                classifier.eval()

            ###############################################################################################
            # Inference on the query set (no gradients needed).
            pretrained_model.eval()

            with torch.no_grad():
                output = pretrained_model(x_b_i.cuda())
                if adaptation:
                    scores = classifier(output)
                else:
                    scores = classifier(z_a_i, y_a_i, output)

            # Query labels in the same class-major order as x_b_i.
            y_query = np.repeat(range(n_way), n_query)
            topk_scores, topk_labels = scores.data.topk(1, 1, True, True)
            topk_ind = topk_labels.cpu().numpy()

            top1_correct = np.sum(topk_ind[:, 0] == y_query)
            correct_this, count_this = float(top1_correct), len(y_query)
            acc_all.append((correct_this / count_this * 100))

            pbar.set_postfix(avg_acc=np.mean(np.asarray(acc_all)))

        acc_all = np.asarray(acc_all)
        acc_mean = np.mean(acc_all)
        acc_std = np.std(acc_all)
        print('%d Test Acc = %4.2f%% +- %4.2f%%' %
              (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
Exemplo n.º 16
0
                            loss_type=loss_type,
                            **few_shot_params)

    elif params.method in ["dampnet_full_class"]:
        model = dampnet_full_class.DampNet(model_dict[params.model],
                                           **few_shot_params)
    elif params.method == "baseline":
        checkpoint_dir_b = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, pretrained_dataset, params.model, "baseline")
        if params.train_aug:
            checkpoint_dir_b += '_aug'

        if params.save_iter != -1:
            modelfile_b = get_assigned_file(checkpoint_dir_b, 400)
        elif params.method in ['baseline', 'baseline++']:
            modelfile_b = get_resume_file(checkpoint_dir_b)
        else:
            modelfile_b = get_best_file(checkpoint_dir_b)

        tmp_b = torch.load(modelfile_b)
        state_b = tmp_b['state']

    elif params.method == "all":
        #model           = ProtoNet( model_dict[params.model], **few_shot_params )
        checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, 'miniImageNet', params.model, "protonet")
        model_2 = GnnNet(model_dict[params.model], **few_shot_params)
        checkpoint_dir2 = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, 'miniImageNet', params.model, "gnnnet")
        #model_3           = dampnet_full_class.DampNet( model_dict[params.model], **few_shot_params )
        checkpoint_dir3 = '%s/checkpoints/%s/%s_%s' % (
Exemplo n.º 17
0
    def __init__(self, params):
        """Build data loaders, the Baseline model and checkpoint paths from ``params``.

        Side effects on ``params``:
        * may set ``model`` (omniglot/cross_char -> 'Conv4S');
        * sets ``train_num_query``;
        * sets ``stop_epoch`` when it is -1;
        * prefixes ``checkpoint_dir`` with ``<configs.save_dir>/checkpoints/<train_dataset>/``;
        * when resuming training, sets ``start_epoch``.

        Creates ``params.vis_log`` and, in train mode, the checkpoint directory.
        In test mode, a standalone backbone is loaded into ``self.feature_model``
        from the best checkpoint.
        """
        np.random.seed(10)

        if params.train_dataset == 'cross':
            base_file = configs.data_dir['miniImagenet'] + 'all.json'
            val_file = configs.data_dir['CUB'] + 'val.json'
        elif params.train_dataset == 'cross_char':
            base_file = configs.data_dir['omniglot'] + 'noLatin.json'
            val_file = configs.data_dir['emnist'] + 'val.json'
        else:
            base_file = configs.data_dir[params.train_dataset] + 'base.json'
            val_file = configs.data_dir[params.train_dataset] + 'val.json'

        if 'Conv' in params.model:
            if params.train_dataset in ['omniglot', 'cross_char']:
                image_size = 28
            else:
                image_size = 84
        else:
            image_size = 224

        if params.train_dataset in ['omniglot', 'cross_char']:
            assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
            params.model = 'Conv4S'

        if params.train_dataset == 'omniglot':
            assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        if params.train_dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'

        # Scale the per-class query count so a training episode has roughly as
        # many query images in total as a test episode (at least 1 per class).
        params.train_num_query = max(
            1,
            int(params.test_num_query * params.test_num_way /
                params.train_num_way))
        if params.episodic:
            train_few_shot_params = dict(n_way=params.train_num_way,
                                         n_support=params.train_num_shot,
                                         n_query=params.train_num_query)
            base_datamgr = SetDataManager(image_size, **train_few_shot_params)
            base_loader = base_datamgr.get_data_loader(base_file,
                                                       aug=params.train_aug)
        else:
            base_datamgr = SimpleDataManager(image_size, batch_size=32)
            base_loader = base_datamgr.get_data_loader(base_file,
                                                       aug=params.train_aug)

        if params.test_dataset == 'cross':
            novel_file = configs.data_dir['CUB'] + 'novel.json'
        elif params.test_dataset == 'cross_char':
            novel_file = configs.data_dir['emnist'] + 'novel.json'
        else:
            novel_file = configs.data_dir[params.test_dataset] + 'novel.json'

        # NOTE(review): val_loader reads the novel split, not val_file; confirm
        # this is intended (val_file is only stored on self).
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(novel_file, aug=False)

        novel_datamgr = SimpleDataManager(image_size, batch_size=64)
        novel_loader = novel_datamgr.get_data_loader(novel_file, aug=False)

        optimizer = params.optimizer

        if params.stop_epoch == -1:
            if params.train_dataset in ['omniglot', 'cross_char']:
                params.stop_epoch = 5
            elif params.train_dataset in ['CUB']:
                params.stop_epoch = 200  # This is different as stated in the open-review paper. However, using 400 epoch in baseline actually lead to over-fitting
            elif params.train_dataset in ['miniImagenet', 'cross']:
                params.stop_epoch = 300
            else:
                params.stop_epoch = 300

        shake_config = {
            'shake_forward': params.shake_forward,
            'shake_backward': params.shake_backward,
            'shake_picture': params.shake_picture
        }
        train_param = {
            'loss_type': params.train_loss_type,
            'temperature': params.train_temperature,
            'margin': params.train_margin,
            'lr': params.train_lr,
            'shake': params.shake,
            'shake_config': shake_config,
            'episodic': params.episodic,
            'num_way': params.train_num_way,
            'num_shot': params.train_num_shot,
            'num_query': params.train_num_query,
            'num_classes': params.num_classes
        }
        test_param = {
            'loss_type': params.test_loss_type,
            'temperature': params.test_temperature,
            'margin': params.test_margin,
            'lr': params.test_lr,
            'num_way': params.test_num_way,
            'num_shot': params.test_num_shot,
            'num_query': params.test_num_query
        }

        model = Baseline(model_dict[params.model], params.entropy, train_param,
                         test_param)

        model = model.cuda()

        # Run identifier; used for the TensorBoard log dir and stored as self.key.
        key = params.tag
        writer = SummaryWriter(log_dir=os.path.join(params.vis_log, key))

        params.checkpoint_dir = '%s/checkpoints/%s/%s' % (
            configs.save_dir, params.train_dataset, params.checkpoint_dir)

        if not os.path.isdir(params.vis_log):
            os.makedirs(params.vis_log)

        # Extracted features go in a sibling "features" tree, one HDF5 file per name.
        outfile_template = os.path.join(
            params.checkpoint_dir.replace("checkpoints", "features"),
            "%s.hdf5")

        if params.mode == 'train' and not os.path.isdir(params.checkpoint_dir):
            os.makedirs(params.checkpoint_dir)

        if params.resume or params.mode == 'test':
            if params.mode == 'test':
                self.feature_model = model_dict[params.model]().cuda()
                resume_file = get_best_file(params.checkpoint_dir)
                tmp = torch.load(resume_file)
                # Keep only backbone weights, renaming 'feature.xx' -> 'xx'.
                # A comprehension is used so its loop variable cannot clobber
                # the local `key` (params.tag) that is stored in self.key.
                state = {
                    state_key.replace("feature.", ""): value
                    for state_key, value in tmp['state'].items()
                    if "feature." in state_key
                }
                self.feature_model.load_state_dict(state)
                self.feature_model.eval()
            else:
                resume_file = get_resume_file(params.checkpoint_dir)
                tmp = torch.load(resume_file)
                state = tmp['state']
                model.load_state_dict(state)
                # The saved epoch has finished; continue with the next one.
                params.start_epoch = tmp['epoch'] + 1

            print('Info: Model loaded!!!')

        self.params = params
        self.val_file = val_file
        self.base_file = base_file
        self.image_size = image_size
        self.optimizer = optimizer
        self.outfile_template = outfile_template
        self.novel_loader = novel_loader
        self.base_loader = base_loader
        self.val_loader = val_loader
        self.writer = writer
        self.model = model
        self.key = key