def generate_task(self, extracted_features_queries_dic, shot, number_tasks, extracted_features_shots_dic):
    """
    Sample a batch of few-shot tasks and merge them into one dictionary.

    inputs:
        extracted_features_queries_dic : pre-extracted query features/labels
        extracted_features_shots_dic : pre-extracted support features/labels
        shot : number of support samples per class
        number_tasks : number of tasks to generate

    returns:
        merged_tasks : dict mapping each key produced by get_task to a
            tensor of shape [number_tasks, n_samples, -1], e.g.
            z_support -> [number_tasks, n_ways * shot, feature_dim].
            NOTE(review): label entries also gain a trailing singleton dim
            from the view(..., -1) below — presumably callers expect that.
    """
    print(f" ==> Generating {number_tasks} task ...")

    # Draw `number_tasks` independent tasks from the extracted features.
    tasks_dics = [
        self.get_task(shot=shot,
                      extracted_features_queries_dic=extracted_features_queries_dic,
                      extracted_features_shots_dic=extracted_features_shots_dic)
        for _ in warp_tqdm(range(number_tasks), False)
    ]

    # Concatenate the per-task tensors key by key, then reshape so the
    # first axis indexes the task.
    n_tasks = len(tasks_dics)
    merged_tasks = {}
    for key in tasks_dics[0]:
        n_samples = tasks_dics[0][key].size(0)
        stacked = torch.cat([task[key] for task in tasks_dics], dim=0)
        merged_tasks[key] = stacked.view(n_tasks, n_samples, -1)
    return merged_tasks
def meta_val(self, model, meta_val_way, meta_val_shot, disable_tqdm, callback, epoch):
    """
    Run one meta-validation pass over self.val_loader.

    Each batch is interpreted as one episode: the first
    meta_val_way * meta_val_shot samples are the support set, the rest are
    queries. Support features are averaged per class and queries are
    classified with self.metric_prediction.

    inputs:
        model : feature extractor (called with feature=True)
        meta_val_way : number of classes per episode
        meta_val_shot : number of support samples per class
        disable_tqdm : suppress the progress bar when True
        callback : optional visdom-style logger (scalar method)
        epoch : current epoch index, used for logging

    returns:
        top1.avg : mean episode accuracy (float in [0, 1])
    """
    top1 = AverageMeter()
    model.eval()
    with torch.no_grad():
        tqdm_test_loader = warp_tqdm(self.val_loader, disable_tqdm)
        for i, (inputs, target, _) in enumerate(tqdm_test_loader):
            inputs, target = inputs.to(self.device), target.to(
                self.device, non_blocking=True)
            # Fix: was `.cuda(0)`, which hard-coded GPU 0 and crashed on
            # CPU-only runs; use the same device as the rest of the method.
            output = model(inputs, feature=True)[0].to(self.device)
            # Split the batch into support and query partitions.
            train_out = output[:meta_val_way * meta_val_shot]
            train_label = target[:meta_val_way * meta_val_shot]
            test_out = output[meta_val_way * meta_val_shot:]
            test_label = target[meta_val_way * meta_val_shot:]
            # Average the `shot` support features of each class (prototype).
            train_out = train_out.reshape(meta_val_way, meta_val_shot, -1).mean(1)
            # One label per class (samples are grouped by class in order).
            train_label = train_label[::meta_val_shot]

            prediction = self.metric_prediction(train_out, test_out,
                                                train_label)
            acc = (prediction == test_label).float().mean()
            top1.update(acc.item())
            if not disable_tqdm:
                tqdm_test_loader.set_description('Acc {:.2f}'.format(
                    top1.avg * 100))

    if callback is not None:
        callback.scalar('val_acc', epoch + 1, top1.avg, title='Val acc')
    return top1.avg
def extract_features(self, model, model_path, model_tag, used_set, loaders_dic):
    """
    Extract (or load cached) features for every sample of the test loader.

    inputs:
        model : the loaded model containing the feature extractor
        model_path : where the model was loaded from (cache root)
        model_tag : which model ('final' or 'best') was loaded
        used_set : set used, 'test' or 'val' (part of the cache path)
        loaders_dic : dictionary of data loaders; only the 'test' loader
            is iterated here

    returns:
        extracted_features_dic : {'concat_features': Tensor [N, feat_dim],
                                  'concat_labels': Tensor [N]}
    """
    save_dir = os.path.join(model_path, model_tag, used_set)
    filepath = os.path.join(save_dir, 'output.plk')

    # Fast path: features were already extracted and pickled earlier.
    if os.path.isfile(filepath):
        extracted_features_dic = load_pickle(filepath)
        print(" ==> Features loaded from {}".format(filepath))
        return extracted_features_dic

    print(" ==> Beginning feature extraction")
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    model.eval()
    feature_chunks = []
    label_chunks = []
    with torch.no_grad():
        for inputs, labels, _ in warp_tqdm(loaders_dic['test'], False):
            inputs = inputs.to(self.device)
            outputs, _ = model(inputs, True)
            # Keep features on CPU so the whole dataset fits in host memory.
            feature_chunks.append(outputs.cpu())
            label_chunks.append(labels)

    extracted_features_dic = {
        'concat_features': torch.cat(feature_chunks, 0),
        'concat_labels': torch.cat(label_chunks, 0),
    }
    print(" ==> Saving features to {}".format(filepath))
    save_pickle(filepath, extracted_features_dic)
    return extracted_features_dic
def main(seed, pretrain, resume, evaluate, print_runtime, epochs, disable_tqdm, visdom_port, ckpt_path, make_plot, cuda):
    """
    Entry point: build the model, optionally restore weights, then either
    run evaluation only or train for `epochs` epochs and evaluate.

    inputs:
        seed : optional RNG seed for reproducibility
        pretrain : optional directory holding a pretraining checkpoint
        resume : when truthy, resume from ckpt_path/checkpoint.pth.tar
        evaluate : when True, skip training and only evaluate
        epochs : number of training epochs
        disable_tqdm : suppress progress bars when True
        visdom_port : optional visdom port for live logging
        ckpt_path : checkpoint directory
        cuda : use GPU when True

    returns:
        results from Evaluator.run_full_evaluation
    """
    device = torch.device("cuda" if cuda else "cpu")
    callback = None if visdom_port is None else VisdomLogger(port=visdom_port)

    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
        # Fix: was unconditional, which crashed seeded CPU-only runs.
        if cuda:
            torch.cuda.set_device(0)

    # create model
    print("=> Creating model '{}'".format(
        ex.current_run.config['model']['arch']))
    model = torch.nn.DataParallel(get_model()).cuda()
    print('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))
    optimizer = get_optimizer(model)

    if pretrain:
        pretrain = os.path.join(pretrain, 'checkpoint.pth.tar')
        if os.path.isfile(pretrain):
            print("=> loading pretrained weight '{}'".format(pretrain))
            checkpoint = torch.load(pretrain)
            model_dict = model.state_dict()
            params = checkpoint['state_dict']
            # Keep only weights whose names exist in the current model.
            params = {k: v for k, v in params.items() if k in model_dict}
            model_dict.update(params)
            model.load_state_dict(model_dict)
        else:
            print(
                '[Warning]: Did not find pretrained model {}'.format(pretrain))

    # Fix: defaults are set first so a missing resume checkpoint no longer
    # leaves start_epoch/best_prec1 undefined (NameError in the loop below).
    start_epoch = 0
    best_prec1 = -1
    if resume:
        resume_path = ckpt_path + '/checkpoint.pth.tar'
        if os.path.isfile(resume_path):
            print("=> loading checkpoint '{}'".format(resume_path))
            checkpoint = torch.load(resume_path)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            # scheduler.load_state_dict(checkpoint['scheduler'])
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume_path, checkpoint['epoch']))
        else:
            print('[Warning]: Did not find checkpoint {}'.format(resume_path))

    cudnn.benchmark = True

    # Data loading code
    evaluator = Evaluator(device=device, ex=ex)
    if evaluate:
        print("Evaluating")
        results = evaluator.run_full_evaluation(model=model,
                                                model_path=ckpt_path,
                                                callback=callback)
        return results

    # If this line is reached, then training the model
    trainer = Trainer(device=device, ex=ex)
    scheduler = get_scheduler(optimizer=optimizer,
                              num_batches=len(trainer.train_loader),
                              epochs=epochs)
    tqdm_loop = warp_tqdm(list(range(start_epoch, epochs)),
                          disable_tqdm=disable_tqdm)
    for epoch in tqdm_loop:
        # Do one epoch
        trainer.do_epoch(model=model, optimizer=optimizer, epoch=epoch,
                         scheduler=scheduler, disable_tqdm=disable_tqdm,
                         callback=callback)

        # Evaluation on validation set
        prec1 = trainer.meta_val(model=model, disable_tqdm=disable_tqdm,
                                 epoch=epoch, callback=callback)
        print('Meta Val {}: {}'.format(epoch, prec1))
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        if not disable_tqdm:
            tqdm_loop.set_description(
                'Best Acc {:.2f}'.format(best_prec1 * 100.))

        # Save checkpoint
        save_checkpoint(state={
            'epoch': epoch + 1,
            'arch': ex.current_run.config['model']['arch'],
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict()
        }, is_best=is_best, folder=ckpt_path)
        if scheduler is not None:
            scheduler.step()

    # Final evaluation on test set
    results = evaluator.run_full_evaluation(model=model, model_path=ckpt_path)
    return results
def do_epoch(self, epoch, scheduler, print_freq, disable_tqdm, callback, model, alpha, optimizer):
    """
    Train `model` for one epoch over self.train_loader.

    inputs:
        epoch : current epoch index (used for logging only)
        scheduler : learning-rate scheduler (kept for interface
            compatibility; stepping happens in the caller)
        print_freq : log every `print_freq` batches
        disable_tqdm : suppress the progress bar when True
        callback : optional visdom-style logger (scalar method)
        model : the network being trained
        alpha : Beta(alpha, alpha) mixup coefficient; <= 0 disables mixup
        optimizer : optimizer stepped once per batch
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to train mode
    model.train()
    steps_per_epoch = len(self.train_loader)
    end = time.time()
    tqdm_train_loader = warp_tqdm(self.train_loader, disable_tqdm)
    # NOTE: loop variable renamed from `input` to avoid shadowing the builtin.
    for i, (inputs, target, _) in enumerate(tqdm_train_loader):
        inputs, target = inputs.to(self.device), target.to(self.device,
                                                           non_blocking=True)

        smoothed_targets = self.smooth_one_hot(target)
        # Sanity check: label smoothing must keep the true class dominant.
        assert (smoothed_targets.argmax(1) == target).float().mean() == 1.0

        # Forward pass
        if alpha > 0:  # Mixup augmentation
            # Blend each sample (and its smoothed target) with a randomly
            # permuted partner from the same batch.
            lam = np.random.beta(alpha, alpha)
            # Fix: was `.cuda()`, which hard-coded CUDA; use the trainer's
            # device, consistent with how inputs/targets are moved above.
            rand_index = torch.randperm(inputs.size(0)).to(self.device)
            target_a = smoothed_targets
            target_b = smoothed_targets[rand_index]
            mixed_input = lam * inputs + (1 - lam) * inputs[rand_index]

            output = model(mixed_input)
            loss = (self.cross_entropy(output, target_a) * lam
                    + self.cross_entropy(output, target_b) * (1. - lam))
        else:
            output = model(inputs)
            loss = self.cross_entropy(output, smoothed_targets)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accuracy against the un-mixed targets (standard mixup reporting).
        prec1 = (output.argmax(1) == target).float().mean()
        top1.update(prec1.item(), inputs.size(0))
        if not disable_tqdm:
            tqdm_train_loader.set_description('Acc {:.2f}'.format(
                top1.avg))

        # Measure accuracy and record loss
        losses.update(loss.item(), inputs.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'.format(
                      epoch, i, len(self.train_loader),
                      batch_time=batch_time, loss=losses, top1=top1))
            if callback is not None:
                # Fractional x-axis: epoch progress in [epoch, epoch + 1).
                callback.scalar('train_loss', i / steps_per_epoch + epoch,
                                losses.avg, title='Train loss')
                callback.scalar('@1', i / steps_per_epoch + epoch,
                                top1.avg, title='Train Accuracy')

    # Log the learning rate of the last parameter group once per epoch.
    for param_group in optimizer.param_groups:
        current_lr = param_group['lr']
    if callback is not None:
        callback.scalar('lr', epoch, current_lr, title='Learning rate')