Example #1
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    if hasattr(trainer.criterion, 'set_epoch'):
        trainer.criterion.set_epoch(epoch_itr.epoch)
    for samples in progress:
        if hasattr(trainer.criterion, 'set_num_updates'):
            trainer.criterion.set_num_updates(trainer.get_num_updates())

        log_output = trainer.train_step(samples)
        num_updates = trainer.get_num_updates()
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(metrics.get_smoothed_values('train'))
        progress.log(stats, tag='train', step=num_updates)

        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
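
The `update_freq` expression near the top of this example picks a gradient-accumulation factor for the current epoch: `--update-freq` is a per-epoch list, and epochs past the end of the list keep using the last value. A standalone restatement of that selection logic (the helper name is ours, not fairseq's):

def update_freq_for_epoch(update_freq, epoch):
    # epochs are 1-indexed; epochs beyond the list reuse the last entry
    if epoch <= len(update_freq):
        return update_freq[epoch - 1]
    return update_freq[-1]

assert update_freq_for_epoch([4, 2, 1], 1) == 4  # accumulate 4 batches per update
assert update_freq_for_epoch([4, 2, 1], 9) == 1  # past the end of the list: keep 1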
Example #2
def get_valid_stats(args, trainer):
    stats = metrics.get_smoothed_values('valid')
    if 'valid_nll_loss' in stats and 'valid_ppl' not in stats:
        stats['valid_ppl'] = utils.get_perplexity(stats['valid_nll_loss'])
    stats['num_updates'] = trainer.get_num_updates()
    if hasattr(checkpoint_utils.save_checkpoint, 'best'):
        key = 'best_{0}'.format(args.best_checkpoint_metric)
        best_function = max if args.maximize_best_checkpoint_metric else min
        stats[key] = best_function(
            checkpoint_utils.save_checkpoint.best,
            stats[args.best_checkpoint_metric],
        )
    return stats
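
The `hasattr(checkpoint_utils.save_checkpoint, 'best')` check works because fairseq stashes the best metric value seen so far as an attribute on the function object itself, so it persists across calls. A minimal sketch of that pattern with hypothetical names (not fairseq's actual implementation):

def save_checkpoint_sketch(metric_value, maximize=False):
    # function attributes survive between calls, acting as cheap module state
    best_fn = max if maximize else min
    prev = getattr(save_checkpoint_sketch, 'best', metric_value)
    save_checkpoint_sketch.best = best_fn(prev, metric_value)
    return save_checkpoint_sketch.best

assert save_checkpoint_sketch(2.0) == 2.0  # first call seeds .best
assert save_checkpoint_sketch(1.5) == 1.5  # a lower loss replaces it
assert save_checkpoint_sketch(3.0) == 1.5  # a worse value is ignored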
Example #3
    def test_nested_duplicate_names(self):
        name = str(uuid.uuid4())
        metrics.reset_meters(name)

        with metrics.aggregate(name):
            metrics.log_scalar('loss', 1)
            with metrics.aggregate() as other:
                with metrics.aggregate(name):
                    metrics.log_scalar('loss', 2)
            metrics.log_scalar('loss', 6)

        self.assertEqual(metrics.get_smoothed_values(name)['loss'], 3)
        self.assertEqual(other.get_smoothed_values()['loss'], 2)
Example #4
    def test_named(self):
        name = str(uuid.uuid4())
        metrics.reset_meters(name)

        with metrics.aggregate(name):
            metrics.log_scalar('loss', 1)

        metrics.log_scalar('loss', 3)

        with metrics.aggregate(name):
            metrics.log_scalar('loss', 2)

        self.assertEqual(metrics.get_smoothed_values(name)['loss'], 1.5)
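
Both tests above hinge on the same semantics: `log_scalar` records a value once into every distinct active aggregation scope, and `get_smoothed_values` averages what each scope saw. A toy re-implementation of those semantics, for illustration only (not fairseq's actual `metrics` module):

import contextlib
import uuid
from collections import defaultdict

_logs = defaultdict(lambda: defaultdict(list))  # scope name -> key -> values
_active = ['default']                           # stack of active scope names

class _Handle:
    def __init__(self, name):
        self.name = name
    def get_smoothed_values(self):
        return get_smoothed_values(self.name)

@contextlib.contextmanager
def aggregate(name=None):
    # anonymous scopes get a throwaway name, as in `with aggregate() as other`
    name = name or str(uuid.uuid4())
    _active.append(name)
    try:
        yield _Handle(name)
    finally:
        _active.remove(name)

def log_scalar(key, value):
    # deduplicate by name, so re-entering an already-active scope (the
    # nested-duplicate-name test) does not double-count the value
    for scope in set(_active):
        _logs[scope][key].append(value)

def get_smoothed_values(name):
    return {k: sum(v) / len(v) for k, v in _logs[name].items()}

def reset_meters(name):
    _logs.pop(name, None)

Under these definitions both tests pass: the named scope in the first test averages 1, 2, and 6 to 3, while `other` only ever sees the 2.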
Example #5
def get_training_stats(stats_key):
    stats = metrics.get_smoothed_values(stats_key)
    if 'nll_loss' in stats and 'ppl' not in stats:
        stats['ppl'] = utils.get_perplexity(stats['nll_loss'])
    stats['wall'] = round(metrics.get_meter('default', 'wall').elapsed_time, 0)
    return stats
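
`utils.get_perplexity` turns the smoothed nll loss into a perplexity. Fairseq logs `nll_loss` in base-2 log units, so the conversion is just an exponentiation; a sketch of the assumed formula:

def perplexity_sketch(nll_loss, base=2):
    # nll_loss is assumed to be a mean negative log-likelihood in base-2 units
    return round(base ** nll_loss, 2)

assert perplexity_sketch(3.0) == 8.0   # 2**3
assert perplexity_sketch(0.0) == 1.0   # a perfect model has perplexity 1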
Example #6
File: train.py  Project: phlrain/example
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf

    start_time = time.time()
    step = 0
    for samples in progress:
        log_output = trainer.train_step(samples)
        num_updates = trainer.get_num_updates()
        step += 1
        """
        if step % 10 == 0:
            print(step)

        if step >= 200:
            pr.disable()
            #pr.dump_stats( "torch_profile")

            sys.exit()

        step += 1
        """

        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(metrics.get_smoothed_values('train'))
        progress.log(stats, tag='train', step=num_updates)

        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            print("validate and save_checkpoint")
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    train_epoch_cost = time.time() - start_time

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)
    print("epoch_cost: %.5f s, avg_speed: %.5f steps/s" %
          (train_epoch_cost, float(step) / train_epoch_cost))

    # reset epoch-level meters
    metrics.reset_meters('train')
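
The disabled block in this example implies a `cProfile.Profile` object named `pr` that was enabled before the training loop. A hedged reconstruction of what that profiling hook likely looked like (`progress` and `trainer` as in the example; the stats reporting at the end is our guess, only the `torch_profile` dump path appears in the original):

import cProfile
import pstats
import sys

pr = cProfile.Profile()
pr.enable()

for step, samples in enumerate(progress, start=1):
    trainer.train_step(samples)
    if step % 10 == 0:
        print(step)
    if step >= 200:
        # profile only the first 200 steps, then report and bail out
        pr.disable()
        pr.dump_stats('torch_profile')
        pstats.Stats('torch_profile').sort_stats('cumulative').print_stats(20)
        sys.exit()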