Example #1
    def generate_task(self, extracted_features_queries_dic, shot, number_tasks, extracted_features_shots_dic):
        """
        inputs:
            extracted_features_queries_dic : Dictionary containing the extracted query features and labels
            shot : Number of support shots per class
            number_tasks : Number of tasks to generate
            extracted_features_shots_dic : Dictionary containing the extracted support features and labels

        returns :
            merged_tasks : { z_support : torch.tensor of shape [number_tasks, n_ways * shot, feature_dim]
                             z_query : torch.tensor of shape [number_tasks, n_ways * query_shot, feature_dim]
                             y_support : torch.tensor of shape [number_tasks, n_ways * shot]
                             y_query : torch.tensor of shape [number_tasks, n_ways * query_shot] }
        """
        print(f" ==> Generating  {number_tasks} task ...")
        tasks_dics = []
        for _ in warp_tqdm(range(number_tasks), False):
            task_dic = self.get_task(shot=shot, extracted_features_queries_dic=extracted_features_queries_dic, extracted_features_shots_dic=extracted_features_shots_dic)
            tasks_dics.append(task_dic)

        # Now merging all tasks into a single dictionary
        merged_tasks = {}
        n_tasks = len(tasks_dics)
        for key in tasks_dics[0].keys():
            n_samples = tasks_dics[0][key].size(0)
            merged_tasks[key] = torch.cat([tasks_dics[i][key] for i in range(n_tasks)], dim=0).view(n_tasks, n_samples, -1)
        return merged_tasks
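
A minimal usage sketch of generate_task, assuming an evaluator object exposing this method and feature dictionaries produced beforehand (variable names below are illustrative, not from the source):

    # Hypothetical call: build 1000 five-shot tasks from pre-extracted features.
    # `evaluator`, `query_features` and `support_features` are assumed to exist.
    tasks = evaluator.generate_task(extracted_features_queries_dic=query_features,
                                    shot=5,
                                    number_tasks=1000,
                                    extracted_features_shots_dic=support_features)
    z_support, y_support = tasks['z_support'], tasks['y_support']
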
Example #2
    def meta_val(self, model, meta_val_way, meta_val_shot, disable_tqdm,
                 callback, epoch):
        top1 = AverageMeter()
        model.eval()

        with torch.no_grad():
            tqdm_test_loader = warp_tqdm(self.val_loader, disable_tqdm)
            for i, (inputs, target, _) in enumerate(tqdm_test_loader):
                inputs, target = inputs.to(self.device), target.to(
                    self.device, non_blocking=True)
                output = model(inputs, feature=True)[0].to(self.device)
                train_out = output[:meta_val_way * meta_val_shot]
                train_label = target[:meta_val_way * meta_val_shot]
                test_out = output[meta_val_way * meta_val_shot:]
                test_label = target[meta_val_way * meta_val_shot:]
                train_out = train_out.reshape(meta_val_way, meta_val_shot,
                                              -1).mean(1)
                train_label = train_label[::meta_val_shot]
                prediction = self.metric_prediction(train_out, test_out,
                                                    train_label)
                acc = (prediction == test_label).float().mean()
                top1.update(acc.item())
                if not disable_tqdm:
                    tqdm_test_loader.set_description('Acc {:.2f}'.format(
                        top1.avg * 100))

        if callback is not None:
            callback.scalar('val_acc', epoch + 1, top1.avg, title='Val acc')
        return top1.avg
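
The loop above splits each episode into support and query features, then averages the shot features of each class into a prototype before calling metric_prediction. A self-contained sketch of that prototype step, assuming support samples are ordered class by class as in the slicing above:

    import torch

    meta_val_way, meta_val_shot, feature_dim = 5, 1, 64
    # Fake support set: one row per sample, ordered class by class.
    support = torch.randn(meta_val_way * meta_val_shot, feature_dim)
    labels = torch.arange(meta_val_way).repeat_interleave(meta_val_shot)

    # One prototype per class, mirroring train_out.reshape(way, shot, -1).mean(1).
    prototypes = support.reshape(meta_val_way, meta_val_shot, -1).mean(1)
    prototype_labels = labels[::meta_val_shot]  # one label per class, as in train_label[::meta_val_shot]
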
Example #3
    def extract_features(self, model, model_path, model_tag, used_set,
                         loaders_dic):
        """
        inputs:
            model : The loaded model containing the feature extractor
            loaders_dic : Dictionary containing training and testing loaders
            model_path : Where the model was loaded from
            model_tag : Which model ('final' or 'best') to load
            used_set : Set used, either 'test' or 'val'

        returns :
            extracted_features_dic : Dictionary containing all extracted features and labels
        """

        # Load features from memory if previously saved ...
        save_dir = os.path.join(model_path, model_tag, used_set)
        filepath = os.path.join(save_dir, 'output.plk')
        if os.path.isfile(filepath):
            extracted_features_dic = load_pickle(filepath)
            print(" ==> Features loaded from {}".format(filepath))
            return extracted_features_dic

        # ... otherwise just extract them
        else:
            print(" ==> Beginning feature extraction")
            if not os.path.isdir(save_dir):
                os.makedirs(save_dir)

        model.eval()
        with torch.no_grad():

            all_features = []
            all_labels = []
            for i, (inputs, labels,
                    _) in enumerate(warp_tqdm(loaders_dic['test'], False)):
                inputs = inputs.to(self.device)
                outputs, _ = model(inputs, True)
                all_features.append(outputs.cpu())
                all_labels.append(labels)
            all_features = torch.cat(all_features, 0)
            all_labels = torch.cat(all_labels, 0)
            extracted_features_dic = {
                'concat_features': all_features,
                'concat_labels': all_labels
            }
        print(" ==> Saving features to {}".format(filepath))
        save_pickle(filepath, extracted_features_dic)
        return extracted_features_dic
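
load_pickle and save_pickle are used above but not shown; a minimal sketch of what they could look like, assuming plain pickle serialization (the actual helpers in the repository may differ):

    import pickle

    def save_pickle(file, data):
        # Serialize the extracted-features dictionary to disk.
        with open(file, 'wb') as f:
            pickle.dump(data, f)

    def load_pickle(file):
        # Load a previously saved extracted-features dictionary.
        with open(file, 'rb') as f:
            return pickle.load(f)
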
Example #4
def main(seed, pretrain, resume, evaluate, print_runtime, epochs, disable_tqdm,
         visdom_port, ckpt_path, make_plot, cuda):
    device = torch.device("cuda" if cuda else "cpu")
    callback = None if visdom_port is None else VisdomLogger(port=visdom_port)
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
    torch.cuda.set_device(0)
    # create model
    print("=> Creating model '{}'".format(
        ex.current_run.config['model']['arch']))
    model = torch.nn.DataParallel(get_model()).cuda()

    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    optimizer = get_optimizer(model)

    if pretrain:
        pretrain = os.path.join(pretrain, 'checkpoint.pth.tar')
        if os.path.isfile(pretrain):
            print("=> loading pretrained weight '{}'".format(pretrain))
            checkpoint = torch.load(pretrain)
            model_dict = model.state_dict()
            params = checkpoint['state_dict']
            params = {k: v for k, v in params.items() if k in model_dict}
            model_dict.update(params)
            model.load_state_dict(model_dict)
        else:
            print(
                '[Warning]: Did not find pretrained model {}'.format(pretrain))

    start_epoch = 0
    best_prec1 = -1
    if resume:
        resume_path = os.path.join(ckpt_path, 'checkpoint.pth.tar')
        if os.path.isfile(resume_path):
            print("=> loading checkpoint '{}'".format(resume_path))
            checkpoint = torch.load(resume_path)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            # scheduler.load_state_dict(checkpoint['scheduler'])
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume_path, checkpoint['epoch']))
        else:
            print('[Warning]: Did not find checkpoint {}'.format(resume_path))

    cudnn.benchmark = True

    # Data loading code
    evaluator = Evaluator(device=device, ex=ex)
    if evaluate:
        print("Evaluating")
        results = evaluator.run_full_evaluation(model=model,
                                                model_path=ckpt_path,
                                                callback=callback)
        return results

    # If this line is reached, train the model
    trainer = Trainer(device=device, ex=ex)
    scheduler = get_scheduler(optimizer=optimizer,
                              num_batches=len(trainer.train_loader),
                              epochs=epochs)
    tqdm_loop = warp_tqdm(list(range(start_epoch, epochs)),
                          disable_tqdm=disable_tqdm)
    for epoch in tqdm_loop:
        # Do one epoch
        trainer.do_epoch(model=model,
                         optimizer=optimizer,
                         epoch=epoch,
                         scheduler=scheduler,
                         disable_tqdm=disable_tqdm,
                         callback=callback)

        # Evaluation on validation set
        prec1 = trainer.meta_val(model=model,
                                 disable_tqdm=disable_tqdm,
                                 epoch=epoch,
                                 callback=callback)
        print('Meta Val {}: {}'.format(epoch, prec1))
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        if not disable_tqdm:
            tqdm_loop.set_description('Best Acc {:.2f}'.format(best_prec1 *
                                                               100.))

        # Save checkpoint
        save_checkpoint(state={
            'epoch': epoch + 1,
            'arch': ex.current_run.config['model']['arch'],
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict()
        },
                        is_best=is_best,
                        folder=ckpt_path)
        if scheduler is not None:
            scheduler.step()

    # Final evaluation on test set
    results = evaluator.run_full_evaluation(model=model, model_path=ckpt_path)
    return results
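
save_checkpoint is called in the training loop above but not defined here; a minimal sketch of a compatible helper, assuming the usual torch.save plus copy-on-best pattern (the original implementation may differ):

    import os
    import shutil
    import torch

    def save_checkpoint(state, is_best, folder, filename='checkpoint.pth.tar'):
        os.makedirs(folder, exist_ok=True)
        path = os.path.join(folder, filename)
        torch.save(state, path)
        if is_best:
            # Keep a separate copy of the best-performing weights.
            shutil.copyfile(path, os.path.join(folder, 'model_best.pth.tar'))
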
Example #5
    def do_epoch(self, epoch, scheduler, print_freq, disable_tqdm, callback,
                 model, alpha, optimizer):
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()

        # switch to train mode
        model.train()
        steps_per_epoch = len(self.train_loader)
        end = time.time()
        tqdm_train_loader = warp_tqdm(self.train_loader, disable_tqdm)
        for i, (input, target, _) in enumerate(tqdm_train_loader):

            input, target = input.to(self.device), target.to(self.device,
                                                             non_blocking=True)

            smoothed_targets = self.smooth_one_hot(target)
            assert (smoothed_targets.argmax(1) == target).float().mean() == 1.0
            # Forward pass
            if alpha > 0:  # Mixup augmentation
                # generate mixed sample and targets
                lam = np.random.beta(alpha, alpha)
                rand_index = torch.randperm(input.size()[0]).cuda()
                target_a = smoothed_targets
                target_b = smoothed_targets[rand_index]
                mixed_input = lam * input + (1 - lam) * input[rand_index]

                output = model(mixed_input)
                loss = self.cross_entropy(output,
                                          target_a) * lam + self.cross_entropy(
                                              output, target_b) * (1. - lam)
            else:
                output = model(input)
                loss = self.cross_entropy(output, smoothed_targets)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            prec1 = (output.argmax(1) == target).float().mean()
            top1.update(prec1.item(), input.size(0))
            if not disable_tqdm:
                tqdm_train_loader.set_description('Acc {:.2f}'.format(
                    top1.avg))

            # Record loss and batch time
            losses.update(loss.item(), input.size(0))
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'.format(
                          epoch,
                          i,
                          len(self.train_loader),
                          batch_time=batch_time,
                          loss=losses,
                          top1=top1))
                if callback is not None:
                    callback.scalar('train_loss',
                                    i / steps_per_epoch + epoch,
                                    losses.avg,
                                    title='Train loss')
                    callback.scalar('@1',
                                    i / steps_per_epoch + epoch,
                                    top1.avg,
                                    title='Train Accuracy')
        for param_group in optimizer.param_groups:
            current_lr = param_group['lr']
        if callback is not None:
            callback.scalar('lr', epoch, current_lr, title='Learning rate')
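
smooth_one_hot is called at the top of the loop but not shown; a minimal label-smoothing sketch consistent with the assert that the smoothed targets still argmax to the true class (num_classes and epsilon are assumptions; in the trainer it is a method taking self):

    import torch

    def smooth_one_hot(targets, num_classes=64, epsilon=0.1):
        # Spread epsilon over the wrong classes, keep 1 - epsilon on the true class.
        smoothed = torch.full((targets.size(0), num_classes),
                              epsilon / (num_classes - 1),
                              device=targets.device)
        smoothed.scatter_(1, targets.unsqueeze(1), 1.0 - epsilon)
        return smoothed
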