def get_valid_stats(args, trainer):
    """Collect the smoothed 'valid' metrics and annotate them for logging.

    Adds the following entries to the returned dict:
      * ``ppl`` -- derived from ``nll_loss`` with ``utils.get_perplexity``,
        only when ``nll_loss`` is present and ``ppl`` is not already set.
      * ``num_updates`` -- taken from ``trainer.get_num_updates()``.
      * ``best_<metric>`` -- only when ``checkpoint_utils.save_checkpoint``
        carries a ``best`` attribute. The value is ``max`` or ``min`` (chosen
        by ``args.maximize_best_checkpoint_metric``) of that stored value and
        the current value of ``args.best_checkpoint_metric``.

    Args:
        args: config object providing ``best_checkpoint_metric`` and
            ``maximize_best_checkpoint_metric``.
        trainer: object exposing ``get_num_updates()``.

    Returns:
        The stats mapping returned by ``metrics.get_smoothed_values('valid')``,
        with the entries above added.

    Raises:
        KeyError: if the ``best`` attribute exists but
            ``args.best_checkpoint_metric`` is not a key of the stats.
    """
    stats = metrics.get_smoothed_values('valid')

    needs_ppl = 'nll_loss' in stats and 'ppl' not in stats
    if needs_ppl:
        stats['ppl'] = utils.get_perplexity(stats['nll_loss'])
    stats['num_updates'] = trainer.get_num_updates()

    # `best` is an attribute attached to the save_checkpoint function object
    # elsewhere (presumably after a checkpoint save -- confirm). Until it
    # exists there is nothing to compare against.
    save_fn = checkpoint_utils.save_checkpoint
    if not hasattr(save_fn, 'best'):
        return stats

    metric_name = args.best_checkpoint_metric
    pick_best = max if args.maximize_best_checkpoint_metric else min
    stats['best_{0}'.format(metric_name)] = pick_best(
        save_fn.best,
        stats[metric_name],
    )
    return stats
def get_training_stats(stats_key):
    """Return the smoothed metrics stored under ``stats_key`` with extras added.

    Adds the following entries to the returned dict:
      * ``ppl`` -- computed from ``nll_loss`` via ``utils.get_perplexity``,
        but only when ``nll_loss`` is present and ``ppl`` is missing.
      * ``wall`` -- the ``elapsed_time`` of the ``('default', 'wall')`` meter,
        rounded to zero decimal places.

    Args:
        stats_key: name of the metrics aggregator passed to
            ``metrics.get_smoothed_values``.

    Returns:
        The stats mapping with ``ppl`` (when applicable) and ``wall`` set.
    """
    stats = metrics.get_smoothed_values(stats_key)

    if 'nll_loss' in stats:
        if 'ppl' not in stats:
            stats['ppl'] = utils.get_perplexity(stats['nll_loss'])

    # NOTE(review): this assumes the 'default'/'wall' meter is registered;
    # the code does not handle a missing meter -- confirm with the caller.
    wall_meter = metrics.get_meter('default', 'wall')
    stats['wall'] = round(wall_meter.elapsed_time, 0)
    return stats