Example #1
def validate(args, epoch, trainer, dataset, max_positions, subset):
    """Evaluate the model on the validation set and return the average loss."""

    itr = dataset.eval_dataloader(
        subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences_valid,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
    )
    loss_meter = AverageMeter()
    nll_loss_meter = AverageMeter()
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    prefix = 'valid on \'{}\' subset'.format(subset)
    with utils.build_progress_bar(args, itr, epoch, prefix) as t:
        for _, sample in data.skip_group_enumerator(t, args.num_gpus):
            loss_dict = trainer.valid_step(sample)
            ntokens = sum(s['ntokens'] for s in sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            if 'nll_loss' in loss_dict:
                nll_loss = loss_dict['nll_loss']
                nll_loss_meter.update(nll_loss, ntokens)

            loss_meter.update(loss, ntokens)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(collections.OrderedDict([
                ('valid loss', round(loss_meter.avg, 2)),
            ] + extra_postfix))

        t.print(collections.OrderedDict([
            ('valid loss', round(loss_meter.avg, 2)),
            ('valid ppl', get_perplexity(nll_loss_meter.avg
                                         if nll_loss_meter.count > 0
                                         else loss_meter.avg)),
        ] + [
            (k, meter.avg)
            for k, meter in extra_meters.items()
        ]))

    # return the average validation loss
    return loss_meter.avg
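
Both validate variants lean on an AverageMeter that is not shown in these snippets. Below is a minimal sketch, assuming only the interface the code above exercises (update(value, weight), .avg, .count); it matches the meter shipped with older fairseq-py, but treat it as an assumption rather than the projects' exact class.

class AverageMeter(object):
    """Tracks a weighted running average, as used by the meters above."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0    # most recent value
        self.sum = 0    # weighted sum of all values seen
        self.count = 0  # total weight seen so far
        self.avg = 0    # running weighted average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count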
Example #2
def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""

    itr = dataset.eval_dataloader(
        subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
    )
    loss_meter = AverageMeter()
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    prefix = 'valid on \'{}\' subset'.format(subset)
    with utils.build_progress_bar(args, itr, epoch, prefix) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            loss_dict = trainer.valid_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            loss_meter.update(loss, ntokens)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(collections.OrderedDict([
                ('valid loss', round(loss_meter.avg, 2)),
            ] + extra_postfix))

        t.print(collections.OrderedDict([
            ('valid loss', round(loss_meter.avg, 2)),
            ('valid ppl', get_perplexity(loss_meter.avg)),
        ] + [
            (k, meter.avg)
            for k, meter in extra_meters.items()
        ]))

    # return the average validation loss
    return loss_meter.avg
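
The perplexity numbers come from a get_perplexity helper that is likewise not shown. A plausible minimal version, assuming the loss is measured in bits per token (base-2 cross-entropy, as in the fairseq code these snippets derive from):

import math

def get_perplexity(loss):
    # perplexity = 2**loss when the loss is in bits per token;
    # guard against overflow for very large losses
    try:
        return round(math.pow(2, loss), 2)
    except OverflowError:
        return float('inf')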
Example #3
def main():
    parser = options.get_parser('Generation')
    parser.add_argument('--path',
                        metavar='FILE',
                        required=True,
                        action='append',
                        help='path(s) to model file(s)')
    dataset_args = options.add_dataset_args(parser)
    dataset_args.add_argument('--batch-size',
                              default=32,
                              type=int,
                              metavar='N',
                              help='batch size')
    dataset_args.add_argument(
        '--gen-subset',
        default='test',
        metavar='SPLIT',
        help='data subset to generate (train, valid, test)')
    dataset_args.add_argument('--num-shards',
                              default=1,
                              type=int,
                              metavar='N',
                              help='shard generation over N shards')
    dataset_args.add_argument(
        '--shard-id',
        default=0,
        type=int,
        metavar='ID',
        help='id of the shard to generate (id < num_shards)')
    options.add_generation_args(parser)

    args = parser.parse_args()
    if args.no_progress_bar and args.log_format is None:
        args.log_format = 'none'
#    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu
    if hasattr(torch, 'set_grad_enabled'):
        torch.set_grad_enabled(False)

    # Load dataset
    if args.replace_unk is None:
        dataset = data.load_dataset(args.data, [args.gen_subset],
                                    args.source_lang, args.target_lang)
    else:
        dataset = data.load_raw_text_dataset(args.data, [args.gen_subset],
                                             args.source_lang,
                                             args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    # Load ensemble
#    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict,
                                                  dataset.dst_dict)

    #    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    #    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    #    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)

    # Initialize generator
    translator = SequenceGenerator(models,
                                   beam_size=args.beam,
                                   stop_early=(not args.no_early_stop),
                                   normalize_scores=(not args.unnormalized),
                                   len_penalty=args.lenpen,
                                   unk_penalty=args.unkpen)
    if use_cuda:
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Generate and compute BLEU score
    #scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
    max_positions = min(model.max_encoder_positions() for model in models)
    itr = dataset.eval_dataloader(args.gen_subset,
                                  max_sentences=args.batch_size,
                                  max_positions=max_positions,
                                  skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    if args.num_shards > 1:
        if args.shard_id < 0 or args.shard_id >= args.num_shards:
            raise ValueError('--shard-id must be between 0 and num_shards-1')
        itr = data.sharded_iterator(itr, args.num_shards, args.shard_id)
    num_sentences = 0
    # NOTE: `to_write` is used below but never defined in the original snippet;
    # an output file is assumed here so the attention/hypothesis dump runs.
    to_write = open('generation_output.txt', 'w')
    with utils.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        gen_timer = StopwatchMeter()
        translations = translator.generate_batched_itr(
            t,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda_device=0 if use_cuda else None,
            timer=gen_timer)

        correct = 0
        total = 0
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            target_tokens = target_tokens.int().cpu()
            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = dataset.splits[
                    args.gen_subset].src.get_original_text(sample_id)
                target_str = dataset.splits[
                    args.gen_subset].dst.get_original_text(sample_id)
            else:
                src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
                target_str = dataset.dst_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

#            if not args.quiet:
#                print('S-{}\t{}'.format(sample_id, src_str))
#                print('T-{}\t{}'.format(sample_id, target_str))
            total += 1
            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu(),
                    align_dict=align_dict,
                    dst_dict=dataset.dst_dict,
                    remove_bpe=args.remove_bpe)

                # Score only the top hypothesis
                if i == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str,
                            dataset.dst_dict,
                            add_if_not_exist=True)
                    #scorer.add(target_tokens, hypo_tokens)
                mat = ''
                for row in hypo['attention']:
                    for column in row:
                        mat += str(column) + '\t'
                    mat += '\n'
                tar = '/' + target_str
                tra = '=' + str(target_str == hypo_str)
                to_write.write(mat)
                to_write.write(src_str)
                to_write.write('\n')
                to_write.write(hypo_str)
                to_write.write('\n')
                to_write.write(tar)
                to_write.write('\n')
                to_write.write(tra)
                to_write.write('\n')
                to_write.write('-----------')
                to_write.write('\n')
                if hypo_str == target_str:
                    correct += 1
            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

        print('| Correct : {} - Total: {}. Accuracy: {:.5f}'.format(
            correct, total, correct / total))
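
Example #3 shards generation across workers with data.sharded_iterator. A round-robin wrapper like the sketch below reproduces the behavior the snippet relies on (shard shard_id sees every num_shards-th batch); the real helper lives in fairseq's data module.

def sharded_iterator(itr, num_shards, shard_id):
    """Yield every num_shards-th element of itr, starting at shard_id."""
    assert 0 <= shard_id < num_shards
    for i, v in enumerate(itr):
        if i % num_shards == shard_id:
            yield v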
Example #4
def train(args, epoch, batch_offset, trainer, dataset, max_positions,
          num_gpus):
    """Train the model for one epoch."""

    seed = args.seed + epoch
    torch.manual_seed(seed)
    trainer.set_seed(seed)

    itr = dataset.train_dataloader(
        args.train_subset,
        num_workers=args.workers,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        seed=seed,
        epoch=epoch,
        sample_without_replacement=args.sample_without_replacement,
        sort_by_source_size=(epoch <= args.curriculum))
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()  # sentences per batch
    wpb_meter = AverageMeter()  # words per batch
    wps_meter = TimeMeter()  # words per second
    clip_meter = AverageMeter()  # % of updates clipped
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    lr = trainer.get_lr()
    with utils.build_progress_bar(args, itr, epoch) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss_dict = trainer.train_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            nsentences = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss,
                              nsentences if args.sentence_avg else ntokens)
            bsz_meter.update(nsentences)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(
                collections.OrderedDict([
                    ('loss', loss_meter),
                    ('wps', round(wps_meter.avg)),
                    ('wpb', round(wpb_meter.avg)),
                    ('bsz', round(bsz_meter.avg)),
                    ('lr', lr),
                    ('clip', '{:.0%}'.format(clip_meter.avg)),
                ] + extra_postfix))

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                save_checkpoint(trainer, args, epoch, i + 1)

        t.print(
            collections.OrderedDict([
                ('train loss', round(loss_meter.avg, 2)),
                ('train ppl', get_perplexity(loss_meter.avg)),
                ('s/checkpoint', round(wps_meter.elapsed_time)),
                ('words/s', round(wps_meter.avg)),
                ('words/batch', round(wpb_meter.avg)),
                ('bsz', round(bsz_meter.avg)),
                ('lr', lr),
                ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
            ] + [(k, meter.avg) for k, meter in extra_meters.items()]))
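
The words-per-second figures in both training loops come from TimeMeter, which divides a running event count by wall-clock time. A sketch covering just the calls the code makes (update(ntokens), .avg, .elapsed_time, .reset()), again assumed from usage rather than taken verbatim from these projects:

import time

class TimeMeter(object):
    """Computes the average rate of some event per second."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.start = time.time()
        self.n = 0

    def update(self, val=1):
        self.n += val

    @property
    def elapsed_time(self):
        return time.time() - self.start

    @property
    def avg(self):
        return self.n / self.elapsed_time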
Example #5
def main():
    parser = argparse.ArgumentParser(description='Batch translate')
    #parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
    parser.add_argument('--model', metavar='FILE', required=True, action='append',
                        help='path(s) to model file(s)')
    parser.add_argument('--dictdir', metavar='DIR', required=True, help='directory of dictionary files')
    parser.add_argument('--batch-size', default=32, type=int, metavar='N',
                        help='batch size')

    parser.add_argument('--beam', default=5, type=int, metavar='N',
                       help='beam size (default: 5)')
    #parser.add_argument('--nbest', default=1, type=int, metavar='N',
    #                   help='number of hypotheses to output')
    #parser.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
    #                   help='remove BPE tokens before scoring')
    parser.add_argument('--no-early-stop', action='store_true',
                       help=('continue searching even after finalizing k=beam '
                             'hypotheses; this is more correct, but increases '
                             'generation time by 50%%'))
    #parser.add_argument('--unnormalized', action='store_true',
    #                   help='compare unnormalized hypothesis scores')
    parser.add_argument('--cpu', action='store_true', help='generate on CPU')
    parser.add_argument('--no-beamable-mm', action='store_true',
                       help='don\'t use BeamableMM in attention layers')
    parser.add_argument('--lenpen', default=1, type=float,
                       help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
    parser.add_argument('--unkpen', default=0, type=float,
                       help='unknown word penalty: <0 produces more unks, >0 produces fewer')
    #parser.add_argument('--replace-unk', nargs='?', const=True, default=None,
    #                   help='perform unknown replacement (optionally with alignment dictionary)')
    #parser.add_argument('--quiet', action='store_true',
    #                   help='Only print final scores')

    parser.add_argument('input', metavar='INPUT', help='Input file')

    args = parser.parse_args()
    
    # required by progress bar
    args.log_format = None
    
    USE_CUDA = not args.cpu and torch.cuda.is_available()

    print('Loading model...', file=sys.stderr)
    models, _ = utils.load_ensemble_for_inference(args.model, data_dir=args.dictdir)
    src_dic = models[0].src_dict
    dst_dic = models[0].dst_dict

    for model in models:
        model.make_generation_fast_(beamable_mm_beam_size=args.beam)

    translator = SequenceGenerator(
        models, beam_size=args.beam, stop_early=(not args.no_early_stop),
        len_penalty=args.lenpen, unk_penalty=args.unkpen)
    if USE_CUDA:
        translator.cuda()

    max_positions = min(model.max_encoder_positions() for model in models)

    print('Loading input data...', file=sys.stderr)

    raw_dataset = indexed_dataset.IndexedRawTextDataset(args.input, src_dic)
    # NOTE: SRC_LANG / TGT_LANG are never defined in the original snippet;
    # placeholder language codes are assumed so the call below runs.
    SRC_LANG, TGT_LANG = 'src', 'tgt'
    dataset = fairseq.data.LanguageDatasets(SRC_LANG, TGT_LANG, src_dic, dst_dic)
    dataset.splits['test'] = fairseq.data.LanguagePairDataset(
        raw_dataset, raw_dataset, pad_idx=dataset.src_dict.pad(),
        eos_idx=dataset.src_dict.eos())

#    itr = dataset.eval_dataloader(
#        'test', max_sentences=args.batch_size, max_positions=max_positions)
    itr = dataset.eval_dataloader('test', max_sentences=args.batch_size)
    itr = utils.build_progress_bar(args, itr)

    #out = []

    for sample_id, src_tokens, _, hypos in translator.generate_batched_itr(
        itr, cuda_device=0 if USE_CUDA else None):
        src_str = dataset.src_dict.string(src_tokens, '@@ ')
        #print(src_str)
        hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
            hypo_tokens=hypos[0]['tokens'].int().cpu(),
            src_str=src_str,
            alignment=hypos[0]['alignment'].int().cpu(),
            align_dict=None,
            dst_dict=dataset.dst_dict,
            remove_bpe='@@ ')
        #out.append((sample_id, hypo_str))
        print('{}\t{}'.format(sample_id, hypo_str), flush=True)
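
Example #5 hard-codes remove_bpe='@@ ' so subword units are merged back into full words before printing. The post-processing amounts to a one-liner; this is essentially the trick older fairseq applied internally, shown here as a standalone sketch:

def remove_bpe(sentence, bpe_symbol='@@ '):
    # 'ma@@ chine trans@@ lation' -> 'machine translation'
    return (sentence + ' ').replace(bpe_symbol, '').rstrip()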
Example #6
def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
    """Train the model for one epoch."""

    seed = args.seed + epoch
    torch.manual_seed(seed)
    trainer.set_seed(seed)

    itr = dataset.train_dataloader(
        args.train_subset, num_workers=args.workers,
        max_tokens=args.max_tokens, max_sentences=args.max_sentences,
        max_positions=max_positions, seed=seed, epoch=epoch,
        sample_without_replacement=args.sample_without_replacement,
        sort_by_source_size=(epoch <= args.curriculum))
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()    # sentences per batch
    wpb_meter = AverageMeter()    # words per batch
    wps_meter = TimeMeter()       # words per second
    clip_meter = AverageMeter()   # % of updates clipped
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    lr = trainer.get_lr()
    with utils.build_progress_bar(args, itr, epoch) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss_dict = trainer.train_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            nsentences = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
            bsz_meter.update(nsentences)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(collections.OrderedDict([
                ('loss', loss_meter),
                ('wps', round(wps_meter.avg)),
                ('wpb', round(wpb_meter.avg)),
                ('bsz', round(bsz_meter.avg)),
                ('lr', lr),
                ('clip', '{:.0%}'.format(clip_meter.avg)),
            ] + extra_postfix))

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                save_checkpoint(trainer, args, epoch, i + 1)

        t.print(collections.OrderedDict([
            ('train loss', round(loss_meter.avg, 2)),
            ('train ppl', get_perplexity(loss_meter.avg)),
            ('s/checkpoint', round(wps_meter.elapsed_time)),
            ('words/s', round(wps_meter.avg)),
            ('words/batch', round(wpb_meter.avg)),
            ('bsz', round(bsz_meter.avg)),
            ('lr', lr),
            ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
        ] + [
            (k, meter.avg)
            for k, meter in extra_meters.items()
        ]))
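
Both training loops call save_checkpoint(trainer, args, epoch, i + 1) mid-epoch without defining it. A hypothetical minimal shape, assuming the trainer exposes its model via get_model() and that args carries fairseq's usual save_dir; the real helper in these projects also tracks validation loss and rotates checkpoint files:

import os
import torch

def save_checkpoint(trainer, args, epoch, batch_offset):
    # Hypothetical sketch -- not the projects' actual implementation.
    state = {
        'args': vars(args),
        'epoch': epoch,
        'batch_offset': batch_offset,
        'model': trainer.get_model().state_dict(),  # assumes trainer.get_model()
    }
    path = os.path.join(args.save_dir,
                        'checkpoint{}_{}.pt'.format(epoch, batch_offset))
    torch.save(state, path)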
Example #7
def main():
    parser = options.get_parser('Generation')
    parser.add_argument('--path', metavar='FILE', required=True, action='append',
                        help='path(s) to model file(s)')
    dataset_args = options.add_dataset_args(parser)
    dataset_args.add_argument('--batch-size', default=32, type=int, metavar='N',
                              help='batch size')
    dataset_args.add_argument('--gen-subset', default='test', metavar='SPLIT',
                              help='data subset to generate (train, valid, test)')
    options.add_generation_args(parser)

    args = parser.parse_args()
    if args.no_progress_bar and args.log_format is None:
        args.log_format = 'none'
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset
    if args.replace_unk is None:
        dataset = data.load_dataset(args.data, [args.gen_subset], args.source_lang, args.target_lang)
    else:
        dataset = data.load_raw_text_dataset(args.data, [args.gen_subset], args.source_lang, args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    # Load ensemble
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)

    # Initialize generator
    translator = SequenceGenerator(
        models, beam_size=args.beam, stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
        unk_penalty=args.unkpen)
    if use_cuda:
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Generate and compute BLEU score
    scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
    max_positions = min(model.max_encoder_positions() for model in models)
    itr = dataset.eval_dataloader(
        args.gen_subset, max_sentences=args.batch_size, max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    num_sentences = 0
    with utils.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        gen_timer = StopwatchMeter()
        translations = translator.generate_batched_itr(
            t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
            cuda_device=0 if use_cuda else None, timer=gen_timer)
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            target_tokens = target_tokens.int().cpu()
            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = dataset.splits[args.gen_subset].src.get_original_text(sample_id)
                target_str = dataset.splits[args.gen_subset].dst.get_original_text(sample_id)
            else:
                src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
                target_str = dataset.dst_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                print('T-{}\t{}'.format(sample_id, target_str))

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu(),
                    align_dict=align_dict,
                    dst_dict=dataset.dst_dict,
                    remove_bpe=args.remove_bpe)

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                    print('A-{}\t{}'.format(sample_id, ' '.join(map(str, alignment))))

                # Score only the top hypothesis
                if i == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(target_str,
                                                                     dataset.dst_dict,
                                                                     add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
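
The BLEU bookkeeping in Example #7 reduces to three calls on fairseq's bleu.Scorer: construct it with the pad/eos/unk ids, feed it (reference, hypothesis) IntTensor pairs, and ask for a summary string. A standalone toy usage, with the ids assumed to be the fairseq dictionary defaults:

import torch
from fairseq import bleu

scorer = bleu.Scorer(1, 2, 3)           # pad, eos, unk ids; use dst_dict.pad() etc. in practice
ref = torch.IntTensor([4, 5, 6, 2])     # toy reference, ending in eos
hyp = torch.IntTensor([4, 5, 7, 2])     # toy hypothesis
scorer.add(ref, hyp)
print(scorer.result_string())           # e.g. 'BLEU4 = ...'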
Example #8
def main():
    parser = options.get_parser('Generation')
    parser.add_argument('--path',
                        metavar='FILE',
                        required=True,
                        action='append',
                        help='path(s) to model file(s)')
    dataset_args = options.add_dataset_args(parser)
    dataset_args.add_argument('--batch-size',
                              default=32,
                              type=int,
                              metavar='N',
                              help='batch size')
    dataset_args.add_argument(
        '--gen-subset',
        default='test',
        metavar='SPLIT',
        help='data subset to generate (train, valid, test)')
    options.add_generation_args(parser)

    args = parser.parse_args()
    if args.no_progress_bar and args.log_format is None:
        args.log_format = 'none'
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset
    if args.replace_unk is None:
        dataset = data.load_dataset(args.data, [args.gen_subset],
                                    args.source_lang, args.target_lang)
    else:
        dataset = data.load_raw_text_dataset(args.data, [args.gen_subset],
                                             args.source_lang,
                                             args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    # Load ensemble
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict,
                                                  dataset.dst_dict)

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(dataset.splits[args.gen_subset])))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)

    # Initialize generator
    translator = SequenceGenerator(models,
                                   beam_size=args.beam,
                                   stop_early=(not args.no_early_stop),
                                   normalize_scores=(not args.unnormalized),
                                   len_penalty=args.lenpen,
                                   unk_penalty=args.unkpen)
    if use_cuda:
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Generate and compute BLEU score
    scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(),
                         dataset.dst_dict.unk())
    max_positions = min(model.max_encoder_positions() for model in models)
    itr = dataset.eval_dataloader(args.gen_subset,
                                  max_sentences=args.batch_size,
                                  max_positions=max_positions,
                                  skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    num_sentences = 0
    with utils.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        gen_timer = StopwatchMeter()
        translations = translator.generate_batched_itr(
            t,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda_device=0 if use_cuda else None,
            timer=gen_timer)
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            target_tokens = target_tokens.int().cpu()
            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = dataset.splits[
                    args.gen_subset].src.get_original_text(sample_id)
                target_str = dataset.splits[
                    args.gen_subset].dst.get_original_text(sample_id)
            else:
                src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
                target_str = dataset.dst_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                print('T-{}\t{}'.format(sample_id, target_str))

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu(),
                    align_dict=align_dict,
                    dst_dict=dataset.dst_dict,
                    remove_bpe=args.remove_bpe)

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                hypo_str))
                    print('A-{}\t{}'.format(sample_id,
                                            ' '.join(map(str, alignment))))

                # Score only the top hypothesis
                if i == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str,
                            dataset.dst_dict,
                            add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.
          format(num_sentences, gen_timer.n, gen_timer.sum,
                 1. / gen_timer.avg))
    print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam,
                                                  scorer.result_string()))
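
A closing note on the max_len_a / max_len_b arguments both generation scripts forward to generate_batched_itr: inside fairseq's SequenceGenerator they cap hypothesis length as a linear function of source length, as the sketch below illustrates.

def max_hypo_len(src_len, max_len_a=0, max_len_b=200):
    # generated hypotheses are capped at max_len_a * src_len + max_len_b tokens;
    # a=0, b=200 shown here are the historical fairseq defaults (a flat 200-token cap)
    return int(max_len_a * src_len + max_len_b)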