def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch.

    Args:
        args: parsed options. This function reads ``fix_batches_to_gpus``,
            ``curriculum``, ``update_freq``, ``valid_subset``, ``max_update``,
            ``disable_validation`` and ``save_interval_updates``.
        trainer: must provide ``train_step``, ``get_num_updates``,
            ``get_model`` and a ``criterion`` attribute.
        task: ``task.begin_epoch`` is called once before iterating.
        epoch_itr: provides ``next_epoch_itr`` and the current ``epoch``.
    """
    # Initialize data iterator; shuffling starts once the epoch number
    # reaches args.curriculum.
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    # Per-epoch update frequency; epochs past the end of the list reuse the
    # last configured value.
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    # A falsy max_update (0 / None) means "no limit".
    max_update = args.max_update or math.inf

    if hasattr(trainer.criterion, 'set_epoch'):
        trainer.criterion.set_epoch(epoch_itr.epoch)

    # Bind num_updates up front so the end-of-epoch logging below still works
    # when the iterator yields no batches (otherwise: NameError).
    num_updates = trainer.get_num_updates()
    for samples in progress:
        if hasattr(trainer.criterion, 'set_num_updates'):
            trainer.criterion.set_num_updates(trainer.get_num_updates())

        log_output = trainer.train_step(samples)
        num_updates = trainer.get_num_updates()
        if log_output is None:
            # Nothing to log for this step; skip logging, validation and
            # the max_update check.
            continue

        # log mid-epoch stats
        stats = get_training_stats(metrics.get_smoothed_values('train'))
        progress.log(stats, tag='train', step=num_updates)

        # Mid-epoch validation + checkpoint every save_interval_updates
        # updates (only when validation is enabled).
        if (
            not args.disable_validation
            and args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
def get_valid_stats(args, trainer):
    """Return the smoothed 'valid' metrics augmented with bookkeeping fields.

    Adds ``ppl`` (derived from ``nll_loss`` when not already present),
    ``num_updates``, and, once ``checkpoint_utils.save_checkpoint`` carries a
    ``best`` attribute, ``best_<metric>`` for ``args.best_checkpoint_metric``.

    Args:
        args: reads ``best_checkpoint_metric`` and
            ``maximize_best_checkpoint_metric``.
        trainer: must provide ``get_num_updates``.

    Returns:
        The stats mapping obtained from ``metrics.get_smoothed_values('valid')``
        with the extra entries set on it.
    """
    stats = metrics.get_smoothed_values('valid')
    # Guard on the key that is actually read, and write the key that the
    # guard checks; the previous mix of 'valid_nll_loss' / 'nll_loss' /
    # 'valid_ppl' could raise KeyError or never fire.
    if 'nll_loss' in stats and 'ppl' not in stats:
        stats['ppl'] = utils.get_perplexity(stats['nll_loss'])
    stats['num_updates'] = trainer.get_num_updates()
    if hasattr(checkpoint_utils.save_checkpoint, 'best'):
        key = 'best_{0}'.format(args.best_checkpoint_metric)
        # "Best" is the max or the min depending on the metric's direction.
        best_function = max if args.maximize_best_checkpoint_metric else min
        stats[key] = best_function(
            checkpoint_utils.save_checkpoint.best,
            stats[args.best_checkpoint_metric],
        )
    return stats
def test_nested_duplicate_names(self):
    """Re-enter a named aggregator while nested inside an anonymous one.

    The named aggregator is expected to report 3 for 'loss' after values
    1, 2 and 6 are logged under it; the anonymous aggregator, active only
    for the value 2, is expected to report 2.
    """
    agg_name = str(uuid.uuid4())
    metrics.reset_meters(agg_name)

    with metrics.aggregate(agg_name):
        metrics.log_scalar('loss', 1)
        with metrics.aggregate() as anonymous:
            # Same name as the enclosing aggregator, entered a second time.
            with metrics.aggregate(agg_name):
                metrics.log_scalar('loss', 2)
        metrics.log_scalar('loss', 6)

    named_loss = metrics.get_smoothed_values(agg_name)['loss']
    self.assertEqual(named_loss, 3)
    anonymous_loss = anonymous.get_smoothed_values()['loss']
    self.assertEqual(anonymous_loss, 2)
def test_named(self):
    """Values logged under the same aggregator name in separate blocks combine.

    1 and 2 are logged inside two separate ``aggregate(name)`` blocks; 3 is
    logged outside any of them. The named aggregator is expected to report
    1.5 for 'loss'.
    """
    agg_name = str(uuid.uuid4())
    metrics.reset_meters(agg_name)

    with metrics.aggregate(agg_name):
        metrics.log_scalar('loss', 1)

    # Logged while the named aggregator is not active.
    metrics.log_scalar('loss', 3)

    with metrics.aggregate(agg_name):
        metrics.log_scalar('loss', 2)

    smoothed = metrics.get_smoothed_values(agg_name)
    self.assertEqual(smoothed['loss'], 1.5)
def get_training_stats(stats_key):
    """Return training stats augmented with perplexity and wall time.

    Args:
        stats_key: either the name of a metrics aggregator, whose smoothed
            values are fetched via ``metrics.get_smoothed_values``, or an
            already-fetched dict of smoothed values, which is used as-is.

    Returns:
        The stats dict, updated in place with ``ppl`` (when ``nll_loss`` is
        present and ``ppl`` is not) and ``wall`` (rounded elapsed time of the
        'default'/'wall' meter).
    """
    # Accept pre-fetched values as well as a key, so callers that pass the
    # result of metrics.get_smoothed_values(...) do not trigger a second
    # lookup keyed by a dict.
    if isinstance(stats_key, dict):
        stats = stats_key
    else:
        stats = metrics.get_smoothed_values(stats_key)
    if 'nll_loss' in stats and 'ppl' not in stats:
        stats['ppl'] = utils.get_perplexity(stats['nll_loss'])
    stats['wall'] = round(metrics.get_meter('default', 'wall').elapsed_time, 0)
    return stats
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch and print a simple throughput summary.

    Args:
        args: parsed options. This function reads ``fix_batches_to_gpus``,
            ``curriculum``, ``update_freq``, ``valid_subset``, ``max_update``,
            ``disable_validation`` and ``save_interval_updates``.
        trainer: must provide ``train_step``, ``get_num_updates`` and
            ``get_model``.
        task: ``task.begin_epoch`` is called once before iterating.
        epoch_itr: provides ``next_epoch_itr`` and the current ``epoch``.
    """
    # Initialize data iterator; shuffling starts once the epoch number
    # reaches args.curriculum.
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    # Per-epoch update frequency; epochs past the end of the list reuse the
    # last configured value.
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    # A falsy max_update (0 / None) means "no limit".
    max_update = args.max_update or math.inf

    start_time = time.time()
    step = 0
    # Bind num_updates up front so the end-of-epoch logging below still works
    # when the iterator yields no batches (otherwise: NameError).
    num_updates = trainer.get_num_updates()
    for samples in progress:
        log_output = trainer.train_step(samples)
        num_updates = trainer.get_num_updates()
        # Count every attempted step, including those with no log output.
        step += 1
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(metrics.get_smoothed_values('train'))
        progress.log(stats, tag='train', step=num_updates)

        # Mid-epoch validation + checkpoint every save_interval_updates
        # updates (only when validation is enabled).
        if (
            not args.disable_validation
            and args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            print("validate and save_checkpoint")
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break
    train_epoch_cost = time.time() - start_time

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)
    # Avoid ZeroDivisionError when the measured elapsed time is 0 (empty
    # epoch or coarse clock resolution).
    avg_speed = float(step) / train_epoch_cost if train_epoch_cost > 0 else 0.0
    print("epoch_cost: %.5f s, avg_speed: %.5f steps/s" % (train_epoch_cost, avg_speed))

    # reset epoch-level meters
    metrics.reset_meters('train')