Example #1
def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
    """Train the model for one epoch."""

    itr = dataset.dataloader(
        args.train_subset,
        num_workers=args.workers,
        max_tokens=args.max_tokens,
        seed=args.seed,
        epoch=epoch,
        max_positions=args.max_positions,
        sample_without_replacement=args.sample_without_replacement,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()  # sentences per batch
    wpb_meter = AverageMeter()  # words per batch
    wps_meter = TimeMeter()  # words per second
    clip_meter = AverageMeter()  # % of updates clipped
    gnorm_meter = AverageMeter()  # gradient norm

    desc = '| epoch {:03d}'.format(epoch)
    lr = trainer.get_lr()
    with progress_bar(itr, desc, leave=False) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss, grad_norm = trainer.train_step(sample, criterion)

            ntokens = sum(s['ntokens'] for s in sample)
            src_size = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, ntokens)
            bsz_meter.update(src_size)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if grad_norm > args.clip_norm else 0)
            gnorm_meter.update(grad_norm)

            t.set_postfix(collections.OrderedDict([
                ('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
                ('wps', '{:5d}'.format(round(wps_meter.avg))),
                ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
                ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
                ('lr', lr),
                ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
                ('gnorm', '{:.4f}'.format(gnorm_meter.avg)),
            ]), refresh=False)

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                trainer.save_checkpoint(args, epoch, i + 1)

        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'
        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'
        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}% | gnorm {:.4f}'
        t.write(
            fmt.format(loss_meter.avg, math.pow(2, loss_meter.avg),
                       round(wps_meter.elapsed_time), round(wps_meter.avg),
                       round(wpb_meter.avg), round(bsz_meter.avg), lr,
                       clip_meter.avg * 100, gnorm_meter.avg))
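
All five examples rely on the same meter helpers. A minimal sketch consistent with how they are used above (weighted update/avg/reset on AverageMeter; update/avg/elapsed_time/reset on TimeMeter) is given below; the actual fairseq implementations may differ in detail.

import time


class AverageMeter(object):
    """Running (optionally weighted) average, e.g. loss averaged over tokens."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count > 0 else 0.0


class TimeMeter(object):
    """Rate meter: items processed per second since the last reset."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.start = time.time()
        self.n = 0

    def update(self, n=1):
        self.n += n

    @property
    def elapsed_time(self):
        return time.time() - self.start

    @property
    def avg(self):
        return self.n / max(self.elapsed_time, 1e-6)
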
Example #2
def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
    """Train the model for one epoch."""

    itr = dataset.dataloader(
        args.train_subset,
        num_workers=args.workers,
        max_tokens=args.max_tokens,
        seed=args.seed,
        epoch=epoch,
        max_positions=args.max_positions,
        sample_without_replacement=args.sample_without_replacement,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()  # sentences per batch
    wpb_meter = AverageMeter()  # words per batch
    wps_meter = TimeMeter()  # words per second
    clip_meter = AverageMeter()  # % of updates clipped
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    desc = '| epoch {:03d}'.format(epoch)
    trainer.set_seed(args.seed + epoch)
    lr = trainer.get_lr()
    with progress_bar(itr, desc, leave=False) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss_dict = trainer.train_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            src_size = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, ntokens)
            bsz_meter.update(src_size)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))

            t.set_postfix(collections.OrderedDict([
                ('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
                ('wps', '{:5d}'.format(round(wps_meter.avg))),
                ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
                ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
                ('lr', lr),
                ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
            ] + extra_postfix), refresh=False)

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                save_checkpoint(trainer, args, epoch, i + 1)

        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(
            loss_meter.avg, get_perplexity(loss_meter.avg))
        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'.format(
            round(wps_meter.elapsed_time), round(wps_meter.avg),
            round(wpb_meter.avg))
        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}%'.format(
            round(bsz_meter.avg), lr, clip_meter.avg * 100)
        fmt += ''.join(' | {} {:.4f}'.format(k, meter.avg)
                       for k, meter in extra_meters.items())
        t.write(fmt)
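
Examples #2 through #4 call a get_perplexity() helper where Example #1 inlines math.pow(2, loss). A minimal sketch consistent with that usage, with an overflow guard, might be (an assumption, not necessarily the original helper):

import math


def get_perplexity(loss):
    """Base-2 perplexity of an average loss; returns inf if 2**loss overflows."""
    try:
        return round(math.pow(2, loss), 2)
    except OverflowError:
        return float('inf')
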
Example #3
def train(args, epoch, batch_offset, trainer, dataset, max_positions,
          num_gpus):
    """Train the model for one epoch."""

    seed = args.seed + epoch
    torch.manual_seed(seed)
    trainer.set_seed(seed)

    itr = dataset.train_dataloader(
        args.train_subset,
        num_workers=args.workers,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        seed=seed,
        epoch=epoch,
        sample_without_replacement=args.sample_without_replacement,
        sort_by_source_size=(epoch <= args.curriculum))
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()  # sentences per batch
    wpb_meter = AverageMeter()  # words per batch
    wps_meter = TimeMeter()  # words per second
    clip_meter = AverageMeter()  # % of updates clipped
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    lr = trainer.get_lr()
    with utils.build_progress_bar(args, itr, epoch) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss_dict = trainer.train_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            nsentences = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss,
                              nsentences if args.sentence_avg else ntokens)
            bsz_meter.update(nsentences)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(
                collections.OrderedDict([
                    ('loss', loss_meter),
                    ('wps', round(wps_meter.avg)),
                    ('wpb', round(wpb_meter.avg)),
                    ('bsz', round(bsz_meter.avg)),
                    ('lr', lr),
                    ('clip', '{:.0%}'.format(clip_meter.avg)),
                ] + extra_postfix))

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                save_checkpoint(trainer, args, epoch, i + 1)

        t.print(
            collections.OrderedDict([
                ('train loss', round(loss_meter.avg, 2)),
                ('train ppl', get_perplexity(loss_meter.avg)),
                ('s/checkpoint', round(wps_meter.elapsed_time)),
                ('words/s', round(wps_meter.avg)),
                ('words/batch', round(wpb_meter.avg)),
                ('bsz', round(bsz_meter.avg)),
                ('lr', lr),
                ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
            ] + [(k, meter.avg) for k, meter in extra_meters.items()]))
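
Every example drives its loop with data.skip_group_enumerator(t, num_gpus, batch_offset), which is expected to yield (group_index, samples) pairs holding up to one mini-batch per GPU, skipping the first batch_offset groups so training can resume from a mid-epoch checkpoint. A sketch of that contract (an assumption about the helper, not its actual source) could look like:

def skip_group_enumerator(iterable, ngpus, offset=0):
    """Yield (index, group) where each group holds up to `ngpus` mini-batches;
    the first `offset` groups are skipped and indices continue from `offset`."""
    group = []
    idx = 0
    for batch in iterable:
        group.append(batch)
        if len(group) == ngpus:
            if idx >= offset:
                yield idx, group
            group = []
            idx += 1
    if group and idx >= offset:
        yield idx, group  # last, possibly smaller, group
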
Example #4
def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
    """Train the model for one epoch."""

    seed = args.seed + epoch
    torch.manual_seed(seed)
    trainer.set_seed(seed)

    itr = dataset.train_dataloader(
        args.train_subset, num_workers=args.workers,
        max_tokens=args.max_tokens, max_sentences=args.max_sentences,
        max_positions=max_positions, seed=seed, epoch=epoch,
        sample_without_replacement=args.sample_without_replacement,
        sort_by_source_size=(epoch <= args.curriculum))
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()    # sentences per batch
    wpb_meter = AverageMeter()    # words per batch
    wps_meter = TimeMeter()       # words per second
    clip_meter = AverageMeter()   # % of updates clipped
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    lr = trainer.get_lr()
    with utils.build_progress_bar(args, itr, epoch) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss_dict = trainer.train_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            nsentences = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
            bsz_meter.update(nsentences)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(collections.OrderedDict([
                ('loss', loss_meter),
                ('wps', round(wps_meter.avg)),
                ('wpb', round(wpb_meter.avg)),
                ('bsz', round(bsz_meter.avg)),
                ('lr', lr),
                ('clip', '{:.0%}'.format(clip_meter.avg)),
            ] + extra_postfix))

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                save_checkpoint(trainer, args, epoch, i + 1)

        t.print(collections.OrderedDict([
            ('train loss', round(loss_meter.avg, 2)),
            ('train ppl', get_perplexity(loss_meter.avg)),
            ('s/checkpoint', round(wps_meter.elapsed_time)),
            ('words/s', round(wps_meter.avg)),
            ('words/batch', round(wpb_meter.avg)),
            ('bsz', round(bsz_meter.avg)),
            ('lr', lr),
            ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
        ] + [
            (k, meter.avg)
            for k, meter in extra_meters.items()
        ]))
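
For context, a hypothetical driver loop (the names run_training, first_epoch, and args.max_epoch are illustrative and not part of the examples) shows how epoch and batch_offset are typically threaded into the signature used by Examples #3 and #4; batch_offset is non-zero only when resuming from a checkpoint saved mid-epoch.

def run_training(args, trainer, dataset, max_positions, num_gpus,
                 first_epoch=1, batch_offset=0):
    """Run train() once per epoch; only a resumed epoch starts mid-way."""
    for epoch in range(first_epoch, args.max_epoch + 1):
        train(args, epoch, batch_offset, trainer, dataset,
              max_positions, num_gpus)
        batch_offset = 0  # later epochs always start from the first batch
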
Example #5
def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
    """Train the model for one epoch."""

    itr = dataset.dataloader(
        args.train_subset,
        num_workers=args.workers,
        max_tokens=args.max_tokens,
        seed=args.seed,
        epoch=epoch,
        max_positions=args.max_positions,
        sample_without_replacement=args.sample_without_replacement)
    ###print("itr:"+str(itr))
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()  # sentences per batch
    wpb_meter = AverageMeter()  # words per batch
    wps_meter = TimeMeter()  # words per second
    clip_meter = AverageMeter()  # % of updates clipped
    gnorm_meter = AverageMeter()  # gradient norm

    desc = '| epoch {:03d}'.format(epoch)
    lr = trainer.get_lr()
    with progress_bar(itr, desc, leave=False) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            ###print("i:"+str(i)+" sample:"+str(sample)) ###id,src_tokens,input_tokens,input_positions,target,src_positions,ntokens
            ###print("i:"+str(i)+" sample len:"+str(len(sample))+" sample id:"+str(sample[0]['id'])+" sample src_tokens:"+str(sample[0]['src_tokens'][0]))
            aggregate_res = trainer.train_step(sample, criterion)
            mixed_loss = aggregate_res.loss
            ml_loss = aggregate_res.ml_loss
            grad_norm = aggregate_res.grad_norm
            rl_loss = aggregate_res.rl_loss
            mean_rouge_greedy = aggregate_res.mean_rouge_greedy
            mean_rouge_sampled = aggregate_res.mean_rouge_sampled
            mean_sum_log_prob = aggregate_res.mean_sum_log_prob

            ntokens = sum(s['ntokens'] for s in sample)
            src_size = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(ml_loss, ntokens)
            bsz_meter.update(src_size)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if grad_norm > args.clip_norm else 0)
            gnorm_meter.update(grad_norm)

            t.set_postfix(
                collections.OrderedDict([
                    ('loss', '{:.2f} ({:.2f})'.format(ml_loss,
                                                      loss_meter.avg)),
                    ('wps', '{:5d}'.format(round(wps_meter.avg))),
                    ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
                    ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
                    ('lr', lr),
                    ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
                    ('gnorm', '{:.4f}'.format(gnorm_meter.avg)),
                ]))

            if args.enable_rl:
                fmt_other = 'mixed_loss: {:^10.4f} | ml_loss: {:^10.4f}'
                fmt_other += '| rl_loss: {:^10.4f} | mean_rouge_greedy: {:^10.4f}'
                fmt_other += '| mean_rouge_sampled: {:^10.4f} | mean_sum_log_prob: {:^10.4f}'
                print(
                    fmt_other.format(mixed_loss, ml_loss, rl_loss,
                                     mean_rouge_greedy, mean_rouge_sampled,
                                     mean_sum_log_prob))

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                trainer.save_checkpoint(args, epoch, i + 1)

        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'
        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'
        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}% | gnorm {:.4f}'
        t.write(
            fmt.format(loss_meter.avg, math.pow(2, loss_meter.avg),
                       round(wps_meter.elapsed_time), round(wps_meter.avg),
                       round(wpb_meter.avg), round(bsz_meter.avg), lr,
                       clip_meter.avg * 100, gnorm_meter.avg))
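
Example #5 expects trainer.train_step() to return an object exposing the fields read above. A minimal container matching those attribute accesses (the concrete type used by the original trainer is an assumption) could be:

import collections

# Field names are exactly the attributes Example #5 reads from the result.
AggregateResult = collections.namedtuple('AggregateResult', [
    'loss',                # mixed ML/RL training loss
    'ml_loss',             # maximum-likelihood component
    'rl_loss',             # reinforcement-learning component
    'grad_norm',           # gradient norm before clipping
    'mean_rouge_greedy',   # mean ROUGE of greedy decodes
    'mean_rouge_sampled',  # mean ROUGE of sampled decodes
    'mean_sum_log_prob',   # mean summed log-probability of the samples
])
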