Example #1
def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""

    itr = dataset.dataloader(
        subset,
        batch_size=None,
        max_tokens=args.max_tokens,
        max_positions=args.max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    loss_meter = AverageMeter()

    desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
    with progress_bar(itr, desc, leave=False) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            ntokens = sum(s['ntokens'] for s in sample)
            loss = trainer.valid_step(sample, criterion)
            loss_meter.update(loss, ntokens)
            t.set_postfix(loss='{:.2f}'.format(loss_meter.avg), refresh=False)

        val_loss = loss_meter.avg
        t.write(desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
            val_loss, math.pow(2, val_loss)))

    # return the average validation loss
    return val_loss
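The `AverageMeter` used throughout these examples is not defined in the snippets. A minimal sketch of the interface the code relies on (`update(val, n)`, `avg`, `reset()`), assuming a weighted running average:

class AverageMeter(object):
    """Running (optionally weighted) average, e.g. loss averaged per token."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # weight each value by n (e.g. the number of tokens it covers)
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count > 0 else 0.0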
Example #2
def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
    """Train the model for one epoch."""

    itr = dataset.dataloader(
        args.train_subset,
        num_workers=args.workers,
        max_tokens=args.max_tokens,
        seed=args.seed,
        epoch=epoch,
        max_positions=args.max_positions,
        sample_without_replacement=args.sample_without_replacement,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()  # sentences per batch
    wpb_meter = AverageMeter()  # words per batch
    wps_meter = TimeMeter()  # words per second
    clip_meter = AverageMeter()  # % of updates clipped
    gnorm_meter = AverageMeter()  # gradient norm

    desc = '| epoch {:03d}'.format(epoch)
    lr = trainer.get_lr()
    with progress_bar(itr, desc, leave=False) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss, grad_norm = trainer.train_step(sample, criterion)

            ntokens = sum(s['ntokens'] for s in sample)
            src_size = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, ntokens)
            bsz_meter.update(src_size)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if grad_norm > args.clip_norm else 0)
            gnorm_meter.update(grad_norm)

            t.set_postfix(collections.OrderedDict([
                ('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
                ('wps', '{:5d}'.format(round(wps_meter.avg))),
                ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
                ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
                ('lr', lr),
                ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
                ('gnorm', '{:.4f}'.format(gnorm_meter.avg)),
            ]), refresh=False)

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                trainer.save_checkpoint(args, epoch, i + 1)

        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'
        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'
        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}% | gnorm {:.4f}'
        t.write(
            fmt.format(loss_meter.avg, math.pow(2, loss_meter.avg),
                       round(wps_meter.elapsed_time), round(wps_meter.avg),
                       round(wpb_meter.avg), round(bsz_meter.avg), lr,
                       clip_meter.avg * 100, gnorm_meter.avg))
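`TimeMeter` measures a count against wall-clock time (words per second above). A minimal sketch, assuming the `update(n)` / `avg` / `elapsed_time` / `reset()` interface the loop uses:

import time

class TimeMeter(object):
    """Rate meter: counts events (e.g. words) over elapsed wall-clock time."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.start = time.time()
        self.n = 0

    def update(self, val=1):
        self.n += val

    @property
    def elapsed_time(self):
        return time.time() - self.start

    @property
    def avg(self):
        # events per second since the last reset
        return self.n / max(self.elapsed_time, 1e-6)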
Example #3
def build_progress_bar(
    self,
    epoch: Optional[int] = None,
    prefix: Optional[str] = None,
    default_log_format: str = "tqdm",
) -> BaseProgressBar:
    return progress_bar.progress_bar(
        iterator=self.get_dataset_itr(),
        log_format=self.cfg.common.log_format,
        log_interval=self.cfg.common.log_interval,
        epoch=epoch,
        prefix=prefix,
        tensorboard_logdir=self.cfg.common.tensorboard_logdir,
        default_log_format=default_log_format,
    )
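A hypothetical call site for this helper, assuming a `task` object exposing the `cfg` and `get_dataset_itr()` members used above plus the `log`/`print` methods of fairseq-style progress bars; `compute_loss` is a stand-in:

progress = task.build_progress_bar(epoch=1, prefix="valid on 'valid' subset")
stats = {}
for step, sample in enumerate(progress):
    stats = {'loss': '{:.2f}'.format(compute_loss(sample))}
    progress.log(stats, step=step)  # logged every cfg.common.log_interval steps
progress.print(stats)  # end-of-epoch summary line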
Example #4
def validate(args, epoch, trainer, dataset, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""

    itr = dataset.dataloader(
        subset,
        batch_size=None,
        max_tokens=args.max_tokens,
        max_positions=args.max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    loss_meter = AverageMeter()
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
    with progress_bar(itr, desc, leave=False) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            loss_dict = trainer.valid_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            loss_meter.update(loss, ntokens)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))

            t.set_postfix(collections.OrderedDict([
                ('loss', '{:.2f}'.format(loss_meter.avg)),
            ] + extra_postfix), refresh=False)

        val_loss = loss_meter.avg
        fmt = desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
            val_loss, get_perplexity(val_loss))
        fmt += ''.join(' | {} {:.4f}'.format(k, meter.avg)
                       for k, meter in extra_meters.items())
        t.write(fmt)

    # return the average validation loss
    return val_loss
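`get_perplexity` is not shown in this example; a plausible sketch consistent with the `math.pow(2, val_loss)` used by the other examples, with a guard against overflow for large losses:

import math

def get_perplexity(loss):
    """Base-2 perplexity of a cross-entropy loss, capped on overflow."""
    try:
        return math.pow(2, loss)
    except OverflowError:
        return float('inf')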
Example #5
def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""

    itr = dataset.dataloader(subset,
                             batch_size=None,
                             max_tokens=args.max_tokens,
                             max_positions=args.max_positions)
    loss_meter = AverageMeter()
    rouge_greedy_meter = AverageMeter()
    rouge_sampled_meter = AverageMeter()

    desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
    with progress_bar(itr, desc, leave=False) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            ntokens = sum(s['ntokens'] for s in sample)
            loss, mean_rouge_greedy, mean_rouge_sampled = trainer.valid_step(
                sample, criterion)
            loss_meter.update(loss, ntokens)
            rouge_greedy_meter.update(mean_rouge_greedy, 1)
            rouge_sampled_meter.update(mean_rouge_sampled, 1)
            t.set_postfix(
                collections.OrderedDict([
                    ('loss', '{:.2f}'.format(loss_meter.avg)),
                    ('ROUGE-L/f (greedy)',
                     '{:.4f}'.format(rouge_greedy_meter.avg)),
                    ('ROUGE-L/f (sampled)',
                     '{:.4f}'.format(rouge_sampled_meter.avg))
                ]))

        val_loss = loss_meter.avg
        t.write(
            desc +
            ' | valid loss {:2.2f} | valid ppl {:3.2f} | ROUGE-L (greedy): {:.4f} | ROUGE-L (sampled): {:.4f}'
            .format(val_loss, math.pow(2, val_loss), rouge_greedy_meter.avg,
                    rouge_sampled_meter.avg))

    # return the average validation loss
    return val_loss
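`data.skip_group_enumerator` drives every loop above. A minimal sketch of the implied behavior, assuming it groups one sample per GPU and skips the first `offset` groups so training can resume mid-epoch (the `batch_offset` argument in the training examples):

def skip_group_enumerator(iterable, ngpus, offset=0):
    """Yield (index, group) pairs, each group holding up to `ngpus` samples."""
    group = []
    idx = 0
    for sample in iterable:
        group.append(sample)
        if len(group) == ngpus:
            if idx >= offset:
                yield idx, group
            idx += 1
            group = []
    if group and idx >= offset:
        # trailing partial group (fewer samples than GPUs)
        yield idx, group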
Example #6
def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
    """Train the model for one epoch."""

    itr = dataset.dataloader(
        args.train_subset,
        num_workers=args.workers,
        max_tokens=args.max_tokens,
        seed=args.seed,
        epoch=epoch,
        max_positions=args.max_positions,
        sample_without_replacement=args.sample_without_replacement,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()  # sentences per batch
    wpb_meter = AverageMeter()  # words per batch
    wps_meter = TimeMeter()  # words per second
    clip_meter = AverageMeter()  # % of updates clipped
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    desc = '| epoch {:03d}'.format(epoch)
    trainer.set_seed(args.seed + epoch)
    lr = trainer.get_lr()
    with progress_bar(itr, desc, leave=False) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss_dict = trainer.train_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            src_size = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, ntokens)
            bsz_meter.update(src_size)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))

            t.set_postfix(collections.OrderedDict([
                ('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
                ('wps', '{:5d}'.format(round(wps_meter.avg))),
                ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
                ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
                ('lr', lr),
                ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
            ] + extra_postfix), refresh=False)

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                save_checkpoint(trainer, args, epoch, i + 1)

        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(
            loss_meter.avg, get_perplexity(loss_meter.avg))
        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'.format(
            round(wps_meter.elapsed_time), round(wps_meter.avg),
            round(wpb_meter.avg))
        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}%'.format(
            round(bsz_meter.avg), lr, clip_meter.avg * 100)
        fmt += ''.join(' | {} {:.4f}'.format(k, meter.avg)
                       for k, meter in extra_meters.items())
        t.write(fmt)
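Unlike Example #2, this snippet calls a module-level `save_checkpoint` instead of a trainer method. A hypothetical sketch matching the call signature, assuming the trainer can serialize its own state to a path:

import os

def save_checkpoint(trainer, args, epoch, batch_offset):
    """Persist enough state to resume training from (epoch, batch_offset)."""
    path = os.path.join(args.save_dir,
                        'checkpoint_e{}_b{}.pt'.format(epoch, batch_offset))
    extra_state = {'epoch': epoch, 'batch_offset': batch_offset}
    trainer.save_checkpoint(path, extra_state)  # assumed trainer API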
Example #7
def main():
    parser = options.get_parser('Generation')
    parser.add_argument('--path', metavar='FILE', required=True, action='append',
                        help='path(s) to model file(s)')
    dataset_args = options.add_dataset_args(parser)
    dataset_args.add_argument('-i', '--interactive', action='store_true',
                              help='generate translations in interactive mode')
    dataset_args.add_argument('--batch-size', default=32, type=int, metavar='N',
                              help='batch size')
    dataset_args.add_argument('--gen-subset', default='test', metavar='SPLIT',
                              help='data subset to generate (train, valid, test)')
    options.add_generation_args(parser)

    args = parser.parse_args()
    print(args)

    if args.no_progress_bar:
        progress_bar.enabled = False
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load model and dataset
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models, dataset = utils.load_ensemble_for_inference(args.path, args.data)

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    if not args.interactive:
        print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))

    # Optimize model for generation
    for model in models:
        model.make_generation_fast_(not args.no_beamable_mm)

    # Initialize generator
    translator = SequenceGenerator(models, dataset.dst_dict, beam_size=args.beam,
                                   stop_early=(not args.no_early_stop),
                                   normalize_scores=(not args.unnormalized),
                                   len_penalty=args.lenpen)
    align_dict = {}
    if args.unk_replace_dict != '':
        assert args.interactive, "Unknown word replacement requires access to the original source " \
                                 "and is only supported in interactive mode"
        with open(args.unk_replace_dict, 'r') as f:
            for line in f:
                cols = line.split()
                align_dict[cols[0]] = cols[1]

    def replace_unk(hypo_str, align_str, src, unk):
        hypo_tokens = hypo_str.split()
        src_tokens = tokenizer.tokenize_line(src)
        align_idx = [int(i) for i in align_str.split()]
        for i, ht in enumerate(hypo_tokens):
            if ht == unk:
                src_token = src_tokens[align_idx[i]]
                if src_token in align_dict:
                    hypo_tokens[i] = align_dict[src_token]
                else:
                    hypo_tokens[i] = src_token
        return ' '.join(hypo_tokens)

    if use_cuda:
        translator.cuda()

    bpe_symbol = '@@ ' if args.remove_bpe else None
    def display_hypotheses(id, src, orig, ref, hypos):
        id_str = '' if id is None else '-{}'.format(id)
        src_str = to_sentence(dataset.src_dict, src, bpe_symbol)
        print('S{}\t{}'.format(id_str, src_str))
        if orig is not None:
            print('O{}\t{}'.format(id_str, orig.strip()))
        if ref is not None:
            print('T{}\t{}'.format(id_str, to_sentence(dataset.dst_dict, ref, bpe_symbol, ref_unk=True)))
        for hypo in hypos:
            hypo_str = to_sentence(dataset.dst_dict, hypo['tokens'], bpe_symbol)
            align_str = ' '.join(map(str, hypo['alignment']))
            if args.unk_replace_dict != '':
                hypo_str = replace_unk(hypo_str, align_str, orig, unk_symbol(dataset.dst_dict))
            print('H{}\t{}\t{}'.format(
                id_str, hypo['score'], hypo_str))
            print('A{}\t{}'.format(id_str, align_str))

    if args.interactive:
        for line in sys.stdin:
            tokens = tokenizer.Tokenizer.tokenize(line, dataset.src_dict, add_if_not_exist=False).long()
            start = dataset.src_dict.pad() + 1
            positions = torch.arange(start, start + len(tokens)).type_as(tokens)
            if use_cuda:
                positions = positions.cuda()
                tokens = tokens.cuda()
            translations = translator.generate(Variable(tokens.view(1, -1)), Variable(positions.view(1, -1)))
            hypos = translations[0]
            display_hypotheses(None, tokens, line, None, hypos[:min(len(hypos), args.nbest)])

    else:
        def maybe_remove_bpe(tokens):
            """Helper for removing BPE symbols from a hypothesis."""
            if not args.remove_bpe:
                return tokens
            assert (tokens == dataset.dst_dict.pad()).sum() == 0
            hypo_minus_bpe = to_sentence(dataset.dst_dict, tokens, bpe_symbol)
            return tokenizer.Tokenizer.tokenize(hypo_minus_bpe, dataset.dst_dict, add_if_not_exist=True)

        # Generate and compute BLEU score
        scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
        itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size, max_positions=args.max_positions)
        num_sentences = 0
        with progress_bar(itr, smoothing=0, leave=False) as t:
            wps_meter = TimeMeter()
            gen_timer = StopwatchMeter()
            translations = translator.generate_batched_itr(
                t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
                cuda_device=0 if use_cuda else None, timer=gen_timer)
            for id, src, ref, hypos in translations:
                ref = ref.int().cpu()
                top_hypo = hypos[0]['tokens'].int().cpu()
                scorer.add(maybe_remove_bpe(ref), maybe_remove_bpe(top_hypo))
                display_hypotheses(id, src, None, ref, hypos[:min(len(hypos), args.nbest)])

                wps_meter.update(src.size(0))
                t.set_postfix(wps='{:5d}'.format(round(wps_meter.avg)))
                num_sentences += 1

        print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
            num_sentences, gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
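`to_sentence` and `unk_symbol` are helpers not included in this example. A minimal sketch, assuming a fairseq-style `Dictionary` that exposes `string()` and `unk_string()`, and a trailing-space BPE marker like '@@ ':

def unk_symbol(dic):
    # symbol printed for out-of-vocabulary tokens
    return dic.unk_string() if hasattr(dic, 'unk_string') else '<unk>'

def to_sentence(dic, tokens, bpe_symbol=None, ref_unk=False):
    """Detokenize a tensor of ids into a string, optionally stripping BPE markers.

    `ref_unk` would control how reference-side unknowns are rendered; it is a
    no-op in this sketch.
    """
    s = dic.string(tokens)  # assumed Dictionary.string()
    if bpe_symbol is not None:
        s = (s + ' ').replace(bpe_symbol, '').rstrip()
    return s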
Example #8
def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
    """Train the model for one epoch."""

    itr = dataset.dataloader(
        args.train_subset,
        num_workers=args.workers,
        max_tokens=args.max_tokens,
        seed=args.seed,
        epoch=epoch,
        max_positions=args.max_positions,
        sample_without_replacement=args.sample_without_replacement)
    ###print("itr:"+str(itr))
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()  # sentences per batch
    wpb_meter = AverageMeter()  # words per batch
    wps_meter = TimeMeter()  # words per second
    clip_meter = AverageMeter()  # % of updates clipped
    gnorm_meter = AverageMeter()  # gradient norm

    desc = '| epoch {:03d}'.format(epoch)
    lr = trainer.get_lr()
    with progress_bar(itr, desc, leave=False) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            ###print("i:"+str(i)+" sample:"+str(sample)) ###id,src_tokens,input_tokens,input_positions,target,src_positions,ntokens
            ###print("i:"+str(i)+" sample len:"+str(len(sample))+" sample id:"+str(sample[0]['id'])+" sample src_tokens:"+str(sample[0]['src_tokens'][0]))
            aggregate_res = trainer.train_step(sample, criterion)
            mixed_loss = aggregate_res.loss
            ml_loss = aggregate_res.ml_loss
            grad_norm = aggregate_res.grad_norm
            rl_loss = aggregate_res.rl_loss
            mean_rouge_greedy = aggregate_res.mean_rouge_greedy
            mean_rouge_sampled = aggregate_res.mean_rouge_sampled
            mean_sum_log_prob = aggregate_res.mean_sum_log_prob

            ntokens = sum(s['ntokens'] for s in sample)
            src_size = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(ml_loss, ntokens)
            bsz_meter.update(src_size)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if grad_norm > args.clip_norm else 0)
            gnorm_meter.update(grad_norm)

            t.set_postfix(
                collections.OrderedDict([
                    ('loss', '{:.2f} ({:.2f})'.format(ml_loss,
                                                      loss_meter.avg)),
                    ('wps', '{:5d}'.format(round(wps_meter.avg))),
                    ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
                    ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
                    ('lr', lr),
                    ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
                    ('gnorm', '{:.4f}'.format(gnorm_meter.avg)),
                ]))

            if args.enable_rl:
                fmt_other = 'mixed_loss: {:^10.4f} | ml_loss: {:^10.4f}'
                fmt_other += '| rl_loss: {:^10.4f} | mean_rouge_greedy: {:^10.4f}'
                fmt_other += '| mean_rouge_sampled: {:^10.4f} | mean_sum_log_prob: {:^10.4f}'
                print(
                    fmt_other.format(mixed_loss, ml_loss, rl_loss,
                                     mean_rouge_greedy, mean_rouge_sampled,
                                     mean_sum_log_prob))

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                trainer.save_checkpoint(args, epoch, i + 1)

        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'
        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'
        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}% | gnorm {:.4f}'
        t.write(
            fmt.format(loss_meter.avg, math.pow(2, loss_meter.avg),
                       round(wps_meter.elapsed_time), round(wps_meter.avg),
                       round(wpb_meter.avg), round(bsz_meter.avg), lr,
                       clip_meter.avg * 100, gnorm_meter.avg))
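The `aggregate_res` fields (greedy vs. sampled ROUGE and the sample's summed log-probability) suggest a self-critical mixed objective. A hedged sketch of one standard way these quantities combine (Rennie et al., 2017); the mixing weight `gamma` and the exact formula used by this trainer are assumptions:

def mixed_objective(ml_loss, mean_rouge_greedy, mean_rouge_sampled,
                    mean_sum_log_prob, gamma=0.5):
    """Mix a self-critical policy-gradient term with maximum likelihood."""
    # Reward advantage of the greedy baseline over the sampled sequence;
    # a positive value pushes probability mass away from the sample.
    rl_loss = (mean_rouge_greedy - mean_rouge_sampled) * mean_sum_log_prob
    return gamma * rl_loss + (1.0 - gamma) * ml_loss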