import os

import torch

# `metrics`, `learn`, and `run_epoch` are project-local helpers assumed to be
# imported at module level: metric-dictionary utilities, loader/optimizer/path
# factories, and the shared train/eval loop, respectively.


def test_model(test_data, model, gen, args):
    '''
    Run the model on test data and return a stats dictionary containing the
    per-batch losses, predictions, probabilities, gold labels, and rationales.
    '''
    if args.cuda:
        model = model.cuda()
        gen = gen.cuda()

    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        drop_last=False)

    test_stats = metrics.init_metrics_dictionary(modes=['test'])

    mode = 'Test'
    train_model = False
    key_prefix = mode.lower()
    print("-------------\nTest")
    epoch_details, _, losses, preds, golds, rationales, probas = run_epoch(
        data_loader=test_loader,
        train_model=train_model,
        model=model,
        gen=gen,
        optimizer=None,
        step=None,
        args=args)

    test_stats, log_statement = metrics.collate_epoch_stat(
        test_stats, epoch_details, key_prefix, args)
    test_stats['losses'] = losses
    test_stats['preds'] = preds
    test_stats['probas'] = probas
    test_stats['golds'] = golds
    test_stats['rationales'] = rationales

    print(log_statement)

    return test_stats
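# A minimal sketch of evaluating a saved checkpoint pair with test_model.
# This helper is not part of the original module: the checkpoint layout is
# inferred from how train_model saves (torch.save of the whole model, with
# the generator at learn.get_gen_path(model_path)), and `args` is assumed to
# carry the fields test_model reads (cuda, batch_size, num_workers).
def _evaluate_checkpoint(test_data, model_path, args):
    model = torch.load(model_path)                     # task model checkpoint
    gen = torch.load(learn.get_gen_path(model_path))   # paired generator
    test_stats = test_model(test_data, model, gen, args)
    print('{} test predictions'.format(len(test_stats['preds'])))
    return test_stats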
def train_model(train_data, dev_data, model, gen, args):
    '''
    Train the model and tune on the dev set. If the model doesn't improve dev
    performance within args.patience epochs, halve the learning rate, restore
    the model to its best checkpoint, and continue training.

    At the end of training, restore the model to its best dev-set version.

    returns epoch_stats: a dictionary of epoch-level metrics for train and dev
    returns model, gen: best model and generator from this call to train
    '''
    if args.cuda:
        model = model.cuda()
        gen = gen.cuda()

    args.lr = args.init_lr
    optimizer = learn.get_optimizer([model, gen], args)

    num_epoch_sans_improvement = 0
    epoch_stats = metrics.init_metrics_dictionary(modes=['train', 'dev'])
    step = 0
    tuning_key = "dev_{}".format(args.tuning_metric)
    # min for loss-style metrics (lower is better), max for everything else.
    best_epoch_func = min if args.tuning_metric == 'loss' else max

    train_loader = learn.get_train_loader(train_data, args)
    dev_loader = learn.get_dev_loader(dev_data, args)

    for epoch in range(1, args.epochs + 1):
        print("-------------\nEpoch {}:\n".format(epoch))
        for mode, dataset, loader in [('Train', train_data, train_loader),
                                      ('Dev', dev_data, dev_loader)]:
            train_model = mode == 'Train'
            print('{}'.format(mode))
            key_prefix = mode.lower()
            epoch_details, step, _, _, _, _, _ = run_epoch(
                data_loader=loader,
                train_model=train_model,
                model=model,
                gen=gen,
                optimizer=optimizer,
                step=step,
                args=args)

            epoch_stats, log_statement = metrics.collate_epoch_stat(
                epoch_stats, epoch_details, key_prefix, args)

            # Log performance
            print(log_statement)

        # Save model if it beats the best dev performance so far
        if best_epoch_func(epoch_stats[tuning_key]) == epoch_stats[tuning_key][-1]:
            num_epoch_sans_improvement = 0
            if not os.path.isdir(args.save_dir):
                os.makedirs(args.save_dir)
            # Subtract one because epoch is 1-indexed and arr is 0-indexed
            epoch_stats['best_epoch'] = epoch - 1
            torch.save(model, args.model_path)
            torch.save(gen, learn.get_gen_path(args.model_path))
        else:
            num_epoch_sans_improvement += 1

        # train_model is False here: the mode loop always ends on 'Dev'.
        if not train_model:
            print('---- Best Dev {} is {:.4f} at epoch {}'.format(
                args.tuning_metric,
                epoch_stats[tuning_key][epoch_stats['best_epoch']],
                epoch_stats['best_epoch'] + 1))

        if num_epoch_sans_improvement >= args.patience:
            print("Reducing learning rate")
            num_epoch_sans_improvement = 0
            # Move the stale weights off the GPU, then reload the best
            # checkpoint and resume with a halved learning rate.
            model.cpu()
            gen.cpu()
            model = torch.load(args.model_path)
            gen = torch.load(learn.get_gen_path(args.model_path))

            if args.cuda:
                model = model.cuda()
                gen = gen.cuda()
            args.lr *= .5
            optimizer = learn.get_optimizer([model, gen], args)

    # Restore model to best dev performance
    if os.path.exists(args.model_path):
        model.cpu()
        model = torch.load(args.model_path)
        gen.cpu()
        gen = torch.load(learn.get_gen_path(args.model_path))

    return epoch_stats, model, gen
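# A plausible sketch of the loader construction that learn.get_train_loader
# is assumed to perform (the real helper lives in the project's learn
# module), assuming an args.class_balance flag and a per-example `weights`
# attribute on the dataset: when balancing, examples are drawn with a
# WeightedRandomSampler; otherwise a plain shuffled loader is used. The dev
# loader, by contrast, would be built with shuffle=False and drop_last=False.
def _sketch_train_loader(train_data, args):
    if args.class_balance:
        # Oversample rare classes using dataset-provided example weights.
        sampler = torch.utils.data.sampler.WeightedRandomSampler(
            weights=train_data.weights,
            num_samples=len(train_data),
            replacement=True)
        return torch.utils.data.DataLoader(
            train_data,
            num_workers=args.num_workers,
            sampler=sampler,
            batch_size=args.batch_size)
    return torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True)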
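# The patience schedule in train_model is easiest to see in isolation. A toy,
# self-contained walk-through of the same rule (not project code), assuming a
# loss-style metric where lower is better: reset the stale counter on every
# new best, and once args.patience epochs pass without improvement, halve the
# learning rate and (in the real function) reload the best checkpoint.
def _demo_patience_schedule(dev_losses, patience=2, lr=1.0):
    best, stale = float('inf'), 0
    for epoch, loss in enumerate(dev_losses, 1):
        if loss < best:
            best, stale = loss, 0   # new best: checkpoint would be saved here
        else:
            stale += 1
        if stale >= patience:
            lr *= 0.5               # halve LR; weights would reset to best
            stale = 0
        print('epoch {}: loss={:.2f} best={:.2f} lr={:.4f}'.format(
            epoch, loss, best, lr))
    return lr


if __name__ == '__main__':
    # The LR halves after epochs 3-4 stall, then epoch 5 sets a new best.
    _demo_patience_schedule([1.00, 0.80, 0.90, 0.85, 0.70], patience=2)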