def main():
    """Run inference with a trained model and write a submission file.

    Steps:
      1. Read the YAML config given by ``--path_config``.
      2. Load the annotations JSON for ``config['prediction']['split']``.
      3. Build the data loader for that split.
      4. Restore the model from the checkpoint at
         ``config['prediction']['model_path']``. The checkpoint's stored
         training config and question vocabulary are used, so the
         architecture matches training.
      5. Predict answers.
      6. Dump the output of ``create_submission`` as JSON to
         ``config['prediction']['submission_file']``.
    """
    # Load config yaml file
    parser = argparse.ArgumentParser()
    parser.add_argument('--path_config', default='config/default.yaml',
                        type=str, help='path to a yaml config file')
    args = parser.parse_args()

    # Why safe_load: PyYAML >= 6 requires an explicit Loader for yaml.load.
    # A plain config file never needs arbitrary Python object construction.
    with open(args.path_config, 'r') as handle:
        config = yaml.safe_load(handle)

    cudnn.benchmark = True

    split = config['prediction']['split']
    path_annotations = os.path.join(config['annotations']['dir'],
                                    split + '.json')

    # Generate dataset and loader
    print("Loading samples to predict from %s" % path_annotations)
    # Load annotations. The context manager closes the file handle.
    with open(path_annotations, 'r') as fd:
        input_annotations = json.load(fd)
    # Data loader and dataset
    input_loader = vqa_dataset.get_loader(config, split=split)

    # Load model weights
    model_path = config['prediction']['model_path']
    print("Loading Model from %s" % model_path)
    # map_location remaps tensors saved on another device (e.g. a GPU) onto
    # `dev`, so a checkpoint can be restored on a CPU-only machine.
    log = torch.load(model_path, map_location=dev)

    # Num tokens seen during training.
    # NOTE(review): the +1 presumably reserves an index (padding/unknown);
    # confirm against models.Model.
    num_tokens = len(log['vocabs']['question']) + 1
    # Use the same configuration used during training
    train_config = log['config']
    model = models.Model(train_config, num_tokens).to(dev)
    model.load_state_dict(log['weights'])

    predicted, samples_ids = predict_answers(model, input_loader, split=split)

    submission = create_submission(input_annotations, predicted, samples_ids,
                                   input_loader.dataset.vocabs)

    submission_file = config['prediction']['submission_file']
    with open(submission_file, 'w') as fd:
        json.dump(submission, fd)
    print("Submission file saved in %s" % submission_file)
def main():
    """Train a model from a YAML config and save checkpoints to a log dir.

    Each run gets a log directory named after its start time.

    Steps:
      1. Read the YAML config given by ``--path_config``.
      2. Create the run's log directory.
      3. Build the train/val loaders, the model and an Adam optimizer over
         the trainable parameters.
      4. Optionally warm-start from ``config['model']['pretrained_model']``.
      5. Train for ``config['training']['epochs']`` epochs.

    Checkpoints written to the log directory:
      - When ``config['training']['train_split'] == 'train'``, the model is
        evaluated on val after every epoch. The lowest-loss checkpoint goes
        to ``best_loss_log.pth`` and the highest-accuracy checkpoint to
        ``best_accuracy_log.pth``.
      - The final state is always written to ``final_log.pth``.
    """
    # Load config yaml file
    parser = argparse.ArgumentParser()
    parser.add_argument('--path_config', default='config/default.yaml',
                        type=str, help='path to a yaml config file')
    args = parser.parse_args()

    # Why safe_load: PyYAML >= 6 requires an explicit Loader for yaml.load.
    # A plain config file never needs arbitrary Python object construction.
    with open(args.path_config, 'r') as handle:
        config = yaml.safe_load(handle)

    # generate log directory (one per run, named after its start time)
    # NOTE(review): ':' in the name is not a valid path character on Windows.
    dir_name = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
    path_log_dir = os.path.join(config['logs']['dir_logs'], dir_name)
    os.makedirs(path_log_dir, exist_ok=True)
    print('Model logs will be saved in {}'.format(path_log_dir))

    cudnn.benchmark = True

    # Generate datasets and loaders
    train_loader = vqa_dataset.get_loader(config, split='train')
    val_loader = vqa_dataset.get_loader(config, split='val')

    model = models.Model(config, train_loader.dataset.num_tokens).to(dev)
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        config['training']['lr'])

    # Load model weights if necessary. load_state_dict copies into the
    # existing parameters, so the optimizer built above still tracks them.
    pretrained_path = config['model']['pretrained_model']
    if pretrained_path is not None:
        print("Loading Model from %s" % pretrained_path)
        # map_location remaps tensors saved on another device onto `dev`.
        log = torch.load(pretrained_path, map_location=dev)
        model.load_state_dict(log['weights'])

    tracker = utils.Tracker()
    # Start from the worst possible values so the first evaluated epoch is
    # always eligible. The old fixed sentinel (10) skipped any loss >= 10.
    min_loss = float('inf')
    max_accuracy = float('-inf')
    path_best_accuracy = os.path.join(path_log_dir, 'best_accuracy_log.pth')
    path_best_loss = os.path.join(path_log_dir, 'best_loss_log.pth')

    train_split = config['training']['train_split']
    for i in range(config['training']['epochs']):
        train(model, train_loader, optimizer, tracker, epoch=i,
              split=train_split)

        # If we are training on the train split (and not on train+val) we
        # can evaluate on val.
        if train_split == 'train':
            eval_results = evaluate(model, val_loader, tracker, epoch=i,
                                    split='val')

            # save all the information in the log file
            log_data = {
                'epoch': i,
                'tracker': tracker.to_dict(),
                'config': config,
                'weights': model.state_dict(),
                'eval_results': eval_results,
                'vocabs': train_loader.dataset.vocabs,
            }

            # save logs for min validation loss and max validation accuracy
            if eval_results['avg_loss'] < min_loss:
                torch.save(log_data, path_best_loss)  # save model
                min_loss = eval_results['avg_loss']  # update min loss value
            if eval_results['avg_accuracy'] > max_accuracy:
                torch.save(log_data, path_best_accuracy)  # save model
                # update max accuracy value
                max_accuracy = eval_results['avg_accuracy']

    # Save final model
    log_data = {
        'tracker': tracker.to_dict(),
        'config': config,
        'weights': model.state_dict(),
        'vocabs': train_loader.dataset.vocabs,
    }
    path_final_log = os.path.join(path_log_dir, 'final_log.pth')
    torch.save(log_data, path_final_log)