示例#1
0
def train_baseline(base_loader, base_loader_test, val_loader, model,
                   start_epoch, stop_epoch, params, tmp):
    """Train ``model`` with plain cross-entropy and validate it every epoch.

    Each epoch performs these steps:

    * one supervised pass over ``base_loader``;
    * a periodic ``<epoch>.tar`` checkpoint;
    * a loss/accuracy report on ``base_loader_test``;
    * scoring of cached few-shot validation episodes with a
      ``BaselineFinetune`` head.

    The model whose third few-shot score (``acc_all3``) is highest so far is
    written to ``best.tar``.

    Args:
        base_loader: iterable of ``(inputs, targets)`` training batches.
        base_loader_test: iterable of ``(inputs, targets)`` batches used for
            the per-epoch classification report (may be empty).
        val_loader: iterable of ``(episode, _)`` pairs. The episodes are
            cached to ``<checkpoint_dir>/val_<dataset>.pt`` on first use and
            reloaded from there afterwards.
        model: network whose forward returns ``(features, logits)``.
        start_epoch, stop_epoch: half-open range of epochs to run.
        params: options object; reads ``checkpoint_dir``, ``dataset``,
            ``save_freq``, ``model``, ``train_n_way``, ``n_shot``,
            ``dct_status`` and ``channels``.
        tmp: unused here; kept so the signature matches ``train_s2m2``.

    Returns:
        The trained ``model`` (the same object, updated in place).
    """
    channels = params.channels if params.dct_status else 3
    val_acc_best = 0.0

    os.makedirs(params.checkpoint_dir, exist_ok=True)

    # Cache the validation episodes on disk so every epoch (and every later
    # run with the same checkpoint_dir/dataset) is scored on the same tasks.
    val_cache = os.path.join(params.checkpoint_dir,
                             'val_' + params.dataset + '.pt')
    if os.path.exists(val_cache):
        loader = torch.load(val_cache)
    else:
        loader = [x for x, _ in val_loader]
        torch.save(loader, val_cache)

    # CrossEntropyLoss without class weights holds no tensors, so it needs
    # no device placement (matches train_s2m2).
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    print("stop_epoch", start_epoch, stop_epoch)
    for epoch in range(start_epoch, stop_epoch):
        print('\nEpoch: %d' % epoch)

        # ---- supervised training pass ----
        model.train()
        train_loss = 0
        correct = 0
        total = 0
        for batch_idx, (inputs, targets) in enumerate(base_loader):
            if use_gpu:
                inputs, targets = inputs.cuda(), targets.cuda()
            _, outputs = model.forward(inputs)
            loss = criterion(outputs, targets)
            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += predicted.eq(targets.data).cpu().sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 50 == 0:
                print(
                    '{0}/{1}'.format(batch_idx, len(base_loader)),
                    'Loss: %.3f | Acc: %.3f%%  ' %
                    (train_loss / (batch_idx + 1), 100. * correct / total))

        # Re-create the directory in case it disappeared during the epoch.
        os.makedirs(params.checkpoint_dir, exist_ok=True)
        if (epoch % params.save_freq == 0) or (epoch == stop_epoch - 1):
            outfile = os.path.join(params.checkpoint_dir,
                                   '{:d}.tar'.format(epoch))
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)

        # ---- classification report on base_loader_test ----
        model.eval()
        with torch.no_grad():
            test_loss = 0
            correct = 0
            total = 0
            n_batches = 0
            for n_batches, (inputs, targets) in enumerate(base_loader_test,
                                                          1):
                if use_gpu:
                    inputs, targets = inputs.cuda(), targets.cuda()
                _, outputs = model.forward(inputs)
                loss = criterion(outputs, targets)
                test_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += predicted.eq(targets).cpu().sum()

            # max(..., 1): an empty test loader must not divide by zero.
            print('Loss: %.3f | Acc: %.3f%%' %
                  (test_loss / max(n_batches, 1),
                   100. * correct / max(total, 1)))
        torch.cuda.empty_cache()

        # ---- few-shot validation on the cached episodes ----
        valmodel = BaselineFinetune(model_dict[params.model],
                                    params.train_n_way,
                                    params.n_shot,
                                    loss_type='dist')
        valmodel.n_query = 15
        n_way = params.train_n_way
        # Only the image-size global matching the input domain is read.
        side = image_size_dct if params.dct_status else image_size
        # Query labels derived from the same n_way / n_query used in the
        # view() below, so they stay in sync with the episode layout.
        # NOTE(review): assumes set_forward_adaptation returns query scores in
        # class-major order (n_query rows per class) — confirm.
        query_labels = np.repeat(range(n_way), valmodel.n_query)
        acc_all1, acc_all2, acc_all3 = [], [], []
        for x in loader:
            x = x.view(-1, channels, side, side)
            if use_gpu:
                x = x.cuda()

            with torch.no_grad():
                f, _ = model(x)
            f = f.view(n_way, params.n_shot + valmodel.n_query, -1)
            scores = valmodel.set_forward_adaptation(f.cpu())
            acc = [
                np.mean(
                    s.detach().cpu().numpy().argmax(axis=1) == query_labels)
                * 100 for s in scores
            ]
            acc_all1.append(acc[0])
            acc_all2.append(acc[1])
            acc_all3.append(acc[2])

        # Labels follow the original printout; presumably the three scores
        # are snapshots of the adaptation at 100/200/300 steps — confirm.
        print('Test Acc at 100= %4.2f%%' % (np.mean(acc_all1)))
        print('Test Acc at 200= %4.2f%%' % (np.mean(acc_all2)))
        print('Test Acc at 300= %4.2f%%' % (np.mean(acc_all3)))

        if np.mean(acc_all3) > val_acc_best:
            val_acc_best = np.mean(acc_all3)
            bestfile = os.path.join(params.checkpoint_dir, 'best.tar')
            torch.save({'epoch': epoch, 'state': model.state_dict()}, bestfile)

    return model
示例#2
0
def train_s2m2(base_loader, base_loader_test, val_loader, model, start_epoch,
               stop_epoch, params, tmp):
    """Train ``model`` with manifold mixup plus a rotation-prediction task.

    For every training batch:

    * a mixup classification loss is back-propagated;
    * a random quarter of the batch is expanded into its 0/90/180/270-degree
      rotations, and the mean of a rotation-classification loss and a
      class loss on those images is back-propagated;
    * one optimizer step applies the accumulated gradients.

    After each epoch the function:

    * writes a periodic checkpoint;
    * reports loss/accuracy on ``base_loader_test``;
    * scores cached few-shot validation episodes with a ``BaselineFinetune``
      head. The best-scoring model (third score, ``acc_all3``) goes to
      ``best.tar``.

    Args:
        base_loader: iterable of ``(inputs, targets)`` training batches.
        base_loader_test: iterable of ``(inputs, targets)`` batches for the
            per-epoch classification report (may be empty).
        val_loader: iterable of ``(episode, _)`` pairs, cached to
            ``<checkpoint_dir>/val_<dataset>.pt`` on first use.
        model: network that returns ``(features, logits, target_a,
            target_b)`` when called with mixup arguments and
            ``(features, logits)`` otherwise.
        start_epoch, stop_epoch: half-open range of epochs to run.
        params: options object; reads ``model``, ``dct_status``,
            ``channels``, ``checkpoint_dir``, ``dataset``, ``alpha``,
            ``save_freq``, ``train_n_way`` and ``n_shot``.
        tmp: dict that may contain a ``'rotate'`` state dict for the rotation
            head (presumably a loaded checkpoint).

    Returns:
        The trained ``model`` (the same object, updated in place).

    Raises:
        ValueError: if ``params.model`` is not one of the backbones with a
            known feature width for the rotation head.
    """
    # Feature width of the supported backbones; the rotation head maps it to
    # the four rotation classes.
    rotate_feat_dims = {'WideResNet28_10': 640, 'ResNet18': 512}
    if params.model not in rotate_feat_dims:
        # Fail before any filesystem side effect instead of hitting an
        # unbound rotation head later on.
        raise ValueError('train_s2m2 does not support model %r; expected one '
                         'of %s' % (params.model, sorted(rotate_feat_dims)))

    channels = params.channels if params.dct_status else 3
    val_acc_best = 0.0

    os.makedirs(params.checkpoint_dir, exist_ok=True)

    # Cache the validation episodes so every epoch (and later runs with the
    # same checkpoint_dir/dataset) is scored on the same tasks.
    val_cache = os.path.join(params.checkpoint_dir,
                             'val_' + params.dataset + '.pt')
    if os.path.exists(val_cache):
        loader = torch.load(val_cache)
    else:
        loader = [x for x, _ in val_loader]
        torch.save(loader, val_cache)

    def mixup_criterion(criterion, pred, y_a, y_b, lam):
        """Mixup loss: ``lam``-weighted blend of the losses on both targets."""
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

    criterion = nn.CrossEntropyLoss()

    rotate_classifier = nn.Sequential(
        nn.Linear(rotate_feat_dims[params.model], 4))
    if use_gpu:
        rotate_classifier.cuda()

    if 'rotate' in tmp:
        print("loading rotate model")
        rotate_classifier.load_state_dict(tmp['rotate'])

    optimizer = torch.optim.Adam([
        {'params': model.parameters()},
        {'params': rotate_classifier.parameters()},
    ])

    print("stop_epoch", start_epoch, stop_epoch)

    for epoch in range(start_epoch, stop_epoch):
        print('\nEpoch: %d' % epoch)

        model.train()
        train_loss = 0
        rotate_loss = 0
        correct = 0
        total = 0
        torch.cuda.empty_cache()
        print("inside base_loader: ", len(base_loader))
        for batch_idx, (inputs, targets) in enumerate(base_loader):
            if use_gpu:
                inputs, targets = inputs.cuda(), targets.cuda()

            # 1) Manifold-mixup classification loss. Its gradients stay in
            #    .grad and are applied together with the rotation loss by the
            #    single optimizer.step() below.
            lam = np.random.beta(params.alpha, params.alpha)
            _, outputs, target_a, target_b = model(inputs,
                                                   targets,
                                                   mixup_hidden=True,
                                                   mixup_alpha=params.alpha,
                                                   lam=lam)
            loss = mixup_criterion(criterion, outputs, target_a, target_b, lam)
            train_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()

            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (
                lam * predicted.eq(target_a.data).cpu().sum().float() +
                (1 - lam) * predicted.eq(target_b.data).cpu().sum().float())

            # 2) Rotation self-supervision on a random quarter of the batch.
            #    A batch with fewer than 4 samples yields nothing to rotate,
            #    so it is skipped (torch.stack of an empty list would raise).
            bs = inputs.size(0)
            split_size = bs // 4
            if split_size > 0:
                indices = np.arange(bs)
                np.random.shuffle(indices)

                rot_inputs, rot_targets, rot_labels = [], [], []
                for j in indices[:split_size]:
                    # transpose(H, W) + flip(H) rotates a (C, H, W) image by
                    # 90 degrees; applied repeatedly for 180 and 270.
                    x90 = inputs[j].transpose(2, 1).flip(1)
                    x180 = x90.transpose(2, 1).flip(1)
                    x270 = x180.transpose(2, 1).flip(1)
                    rot_inputs += [inputs[j], x90, x180, x270]
                    rot_targets += [targets[j] for _ in range(4)]
                    rot_labels += [torch.tensor(k) for k in range(4)]

                rot_inputs = torch.stack(rot_inputs, 0)
                rot_targets = torch.stack(rot_targets, 0)
                rot_labels = torch.stack(rot_labels, 0)
                if use_gpu:
                    rot_inputs = rot_inputs.cuda()
                    rot_targets = rot_targets.cuda()
                    rot_labels = rot_labels.cuda()

                rf, rot_outputs = model(rot_inputs)
                rloss = criterion(rotate_classifier(rf), rot_labels)
                closs = criterion(rot_outputs, rot_targets)
                rotate_loss += rloss.item()
                ((rloss + closs) / 2.0).backward()

            optimizer.step()

            if batch_idx % 50 == 0:
                print(
                    '{0}/{1}'.format(batch_idx, len(base_loader)),
                    'Loss: %.3f | Acc: %.3f%% | RotLoss: %.3f  ' %
                    (train_loss /
                     (batch_idx + 1), 100. * correct / total, rotate_loss /
                     (batch_idx + 1)))

        # Re-create the directory in case it disappeared during the epoch.
        os.makedirs(params.checkpoint_dir, exist_ok=True)
        if (epoch % params.save_freq == 0) or (epoch == stop_epoch - 1):
            outfile = os.path.join(params.checkpoint_dir,
                                   '{:d}.tar'.format(epoch))
            # Include the rotation head so that passing this checkpoint back
            # as ``tmp`` restores it (see the 'rotate' handling above).
            torch.save(
                {
                    'epoch': epoch,
                    'state': model.state_dict(),
                    'rotate': rotate_classifier.state_dict()
                }, outfile)

        # ---- classification report on base_loader_test ----
        model.eval()
        with torch.no_grad():
            test_loss = 0
            correct = 0
            total = 0
            n_batches = 0
            for n_batches, (inputs, targets) in enumerate(base_loader_test,
                                                          1):
                if use_gpu:
                    inputs, targets = inputs.cuda(), targets.cuda()
                _, outputs = model.forward(inputs)
                loss = criterion(outputs, targets)
                test_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += predicted.eq(targets).cpu().sum()

            # max(..., 1): an empty test loader must not divide by zero.
            print('Loss: %.3f | Acc: %.3f%%' %
                  (test_loss / max(n_batches, 1),
                   100. * correct / max(total, 1)))

        # ---- few-shot validation on the cached episodes ----
        model_key = params.model + '_dct' if params.dct_status else params.model
        valmodel = BaselineFinetune(model_dict[model_key],
                                    params.train_n_way,
                                    params.n_shot,
                                    loss_type='dist')
        valmodel.n_query = 15
        n_way = params.train_n_way
        # Only the image-size global matching the input domain is read.
        side = image_size_dct if params.dct_status else image_size
        # Query labels derived from the same n_way / n_query used in view().
        # NOTE(review): assumes set_forward_adaptation returns query scores in
        # class-major order (n_query rows per class) — confirm.
        query_labels = np.repeat(range(n_way), valmodel.n_query)
        acc_all1, acc_all2, acc_all3 = [], [], []
        for x in loader:
            x = x.view(-1, channels, side, side)
            if use_gpu:
                x = x.cuda()

            with torch.no_grad():
                f, _ = model(x)
            f = f.view(n_way, params.n_shot + valmodel.n_query, -1)
            scores = valmodel.set_forward_adaptation(f.cpu())
            acc = [
                np.mean(
                    s.detach().cpu().numpy().argmax(axis=1) == query_labels)
                * 100 for s in scores
            ]
            acc_all1.append(acc[0])
            acc_all2.append(acc[1])
            acc_all3.append(acc[2])

        print('Test Acc at 100= %4.2f%%' % (np.mean(acc_all1)))
        print('Test Acc at 200= %4.2f%%' % (np.mean(acc_all2)))
        print('Test Acc at 300= %4.2f%%' % (np.mean(acc_all3)))

        if np.mean(acc_all3) > val_acc_best:
            val_acc_best = np.mean(acc_all3)
            bestfile = os.path.join(params.checkpoint_dir, 'best.tar')
            torch.save(
                {
                    'epoch': epoch,
                    'state': model.state_dict(),
                    'rotate': rotate_classifier.state_dict()
                }, bestfile)

    return model
def get_logits_targets(params):
    """Build the few-shot model described by ``params`` and collect logits.

    The model is selected by ``params.method`` and moved to CUDA. For episodic
    methods (anything except ``baseline``/``baseline++``), trained weights
    are loaded from the checkpoint directory when a file is found. Then 600
    test episodes are scored:

    * ``maml``, ``maml_approx`` and ``DKT`` run on raw images of
      ``params.split`` through a ``SetDataManager``;
    * every other method runs on pre-extracted features read from
      ``<features dir>/<split_str>.hdf5``, with 15 query samples per class.

    Side effects: may rewrite ``params.model`` to ``'Conv4S'`` for the
    omniglot datasets. For MAML methods, sets the ``maml`` class flag on
    several ``backbone`` classes.

    Returns:
        ``(logits, targets)``: per-episode logits and integer class labels
        (``n_query`` consecutive labels per class), each concatenated along
        dim 0 over all episodes.

    Raises:
        ValueError: if ``params.method`` is not recognised.
        AssertionError: for omniglot/cross_char with a model other than
            ``Conv4`` or with ``train_aug`` enabled.
    """
    iter_num = 600
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'

    if params.method == 'baseline':
        model = BaselineFinetune(model_dict[params.model], **few_shot_params)
    elif params.method == 'baseline++':
        model = BaselineFinetune(model_dict[params.model], loss_type='dist',
                                 **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'DKT':
        model = DKT(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        elif params.model == 'Conv4S':
            feature_model = backbone.Conv4SNP
        else:
            feature_model = lambda: model_dict[params.model](flatten=False)
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model, loss_type=loss_type,
                            **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(model_dict[params.model],
                     approx=(params.method == 'maml_approx'),
                     **few_shot_params)
        if params.dataset in ['omniglot', 'cross_char']:
            # MAML uses different hyper-parameters on omniglot.
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    if params.method not in ['baseline', 'baseline++']:
        checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)
        # Episodic methods are evaluated with their trained weights: either
        # a specific saved iteration or the best checkpoint.
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        else:
            modelfile = get_best_file(checkpoint_dir)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])
        else:
            print("[WARNING] Cannot find 'best_file.tar' in: " + str(checkpoint_dir))

    split = params.split
    if params.save_iter != -1:
        split_str = split + "_" + str(params.save_iter)
    else:
        split_str = split

    logits_list = []
    targets_list = []
    if params.method in ['maml', 'maml_approx', 'DKT']:
        # These methods do not support testing on pre-extracted features.
        if 'Conv' in params.model:
            if params.dataset in ['omniglot', 'cross_char']:
                image_size = 28
            else:
                image_size = 84
        else:
            image_size = 224

        datamgr = SetDataManager(image_size, n_eposide=iter_num, n_query=15,
                                 **few_shot_params)

        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
            else:
                loadfile = configs.data_dir['emnist'] + split + '.json'
        else:
            loadfile = configs.data_dir[params.dataset] + split + '.json'

        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if params.adaptation:
            # Adaptation for MAML is simply more inner-loop updates.
            model.task_update_num = 100
        model.eval()

        for x, _ in novel_loader:
            logits = model.get_logits(x).detach()
            # model.n_query is read after get_logits on purpose.
            # NOTE(review): the model may set n_query from the episode
            # shape, so keep this read inside the loop.
            targets = torch.tensor(
                np.repeat(range(params.test_n_way), model.n_query)).cuda()
            logits_list.append(logits)
            targets_list.append(targets)
    else:
        novel_file = os.path.join(
            checkpoint_dir.replace("checkpoints", "features"),
            split_str + ".hdf5")
        cl_data_file = feat_loader.init_loader(novel_file)
        n_query = 15
        n_way = few_shot_params['n_way']
        n_support = few_shot_params['n_support']
        # random.sample needs a sequence: dict key views are rejected on
        # Python 3.11+ (and deprecated before that).
        class_list = list(cl_data_file.keys())
        # Labels are identical for every episode: n_query per class, in class order.
        targets = torch.tensor(np.repeat(range(n_way), n_query)).cuda()
        for _ in range(iter_num):
            select_class = random.sample(class_list, n_way)
            z_all = []
            for cl in select_class:
                img_feat = cl_data_file[cl]
                perm_ids = np.random.permutation(len(img_feat)).tolist()
                # n_support + n_query random samples of this class.
                z_all.append([
                    np.squeeze(img_feat[k])
                    for k in perm_ids[:n_support + n_query]
                ])
            z_all = torch.from_numpy(np.array(z_all))
            model.n_query = n_query
            logits = model.set_forward(z_all, is_feature=True).detach()
            logits_list.append(logits)
            targets_list.append(targets)
    return torch.cat(logits_list, 0), torch.cat(targets_list, 0)