Example #1
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = (sample['target'].size(0)
                       if self.args.sentence_avg else sample['ntokens'])
        # Only evaluate BLEU when not training
        if model.training:
            bleu_stats = utils.get_zero_bleu_stats()
        else:
            use_cuda = torch.cuda.is_available() and not self.args.cpu
            if self.args.sacrebleu:
                if hasattr(self.args, 'lowercase'):
                    scorer = bleu.SacrebleuScorer(
                        lowercase=self.args.lowercase)
                else:
                    scorer = bleu.SacrebleuScorer()
            else:
                scorer = bleu.Scorer(self.task.target_dictionary.pad(),
                                     self.task.target_dictionary.eos(),
                                     self.task.target_dictionary.unk())
            gen_timer = StopwatchMeter()
            wps_meter = TimeMeter()
            utils.run_inference_on_sample(sample=sample,
                                          use_cuda=use_cuda,
                                          args=self.args,
                                          gen_timer=gen_timer,
                                          task=self.task,
                                          generator=self.generator,
                                          model=model,
                                          tgt_dict=self.task.target_dictionary,
                                          align_dict=None,
                                          subset=None,
                                          src_dict=self.task.source_dictionary,
                                          scorer=scorer,
                                          wps_meter=wps_meter)
            result_string = scorer.result_string()
            bleu_stats = utils.BleuStatistics(correct=result_string.counts,
                                              total=result_string.totals,
                                              sys_len=result_string.sys_len,
                                              ref_len=result_string.ref_len)
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
            'bleu': bleu_stats,
        }
        return loss, sample_size, logging_output

    def __init__(self, args, task):
        super().__init__(task=task)
        self.args = args

        self.generator = SimpleSequenceGenerator(beam=args.scst_beam,
                                                 penalty=args.scst_penalty,
                                                 max_pos=args.max_target_positions,
                                                 eos_index=task.target_dictionary.eos_index)

        # Needed for decoding model output to string
        self.conf_tokenizer = encoders.build_tokenizer(args)
        self.conf_decoder = encoders.build_bpe(args)
        self.target_dict = task.target_dictionary

        # Tokenizer needed for computing CIDEr scores
        self.tokenizer = encoders.build_tokenizer(args)
        self.bpe = encoders.build_bpe(args)
 
        self.scorer = bleu.SacrebleuScorer()

        self.pad_idx = task.target_dictionary.pad()
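Every example on this page follows the same scoring pattern: construct either a detokenized `bleu.SacrebleuScorer` or a token-level `bleu.Scorer`, feed it reference/hypothesis pairs, then read off the result. A minimal sketch of that flow, assuming an older fairseq release where the legacy `bleu` module exposes both scorers (constructor details vary between versions):

# Minimal sketch, assuming `from fairseq import bleu` works in this
# fairseq version (the examples on this page all rely on it).
from fairseq import bleu

scorer = bleu.SacrebleuScorer()  # string-based, detokenized BLEU
scorer.add_string('the cat sat on the mat',   # reference
                  'the cat sat on a mat')     # hypothesis
print(scorer.result_string())    # sacreBLEU summary line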
Example #3
    def __init__(self, args, task, data_queue, message_queue):
        """
        Handle detokenize and belu score computation

        Args:
            args (Namespace): paramerter for model and generation
            task (fairseq.tasks.fairseq_task.Fairseq):
                use to load dict for detokenize
            data_queue (multiprocessing.Queue):
                queue store tensor data for detokenize
            message_queue (multiprocessing.Queue): queue store output
        """
        super(PostProcess, self).__init__()
        # Set dictionaries
        try:
            self.src_dict = getattr(task, 'source_dictionary', None)
        except NotImplementedError:
            self.src_dict = None
        self.tgt_dict = task.target_dictionary

        # Load alignment dictionary for unknown word replacement
        # (None if no unknown word replacement, empty if no path to align dictionary)
        self.align_dict = utils.load_align_dict(args.replace_unk)

        # Generate and compute BLEU score
        if args.sacrebleu:
            self.scorer = bleu.SacrebleuScorer()
        else:
            self.scorer = bleu.Scorer(self.tgt_dict.pad(), self.tgt_dict.eos(),
                                      self.tgt_dict.unk())

        self.args = args
        self.task = task
        self.data_queue = data_queue
        self.message_queue = message_queue
        if args.decode_hypothesis:
            self.tokenizer = encoders.build_tokenizer(args)
            self.bpe = encoders.build_bpe(args)
Example #4
    def __init__(self, args, task, message_queue):
        """
        Process to handle IO and compute metrics

        Args:
            args (Namespace): paramerter for model and generation
            task (fairseq.tasks.fairseq_task.Fairseq):
                use to load dict for detokenize
            message_queue (multiprocessing.Queue): queue store output
        """
        super(IOProcess, self).__init__()
        self.tgt_dict = task.target_dictionary

        # Generate and compute BLEU score
        if args.sacrebleu:
            self.scorer = bleu.SacrebleuScorer()
        else:
            self.scorer = bleu.Scorer(self.tgt_dict.pad(), self.tgt_dict.eos(),
                                      self.tgt_dict.unk())

        self.args = args
        self.message_queue = message_queue
        self.has_target = False
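Note that the two scorer flavours consume different inputs: `SacrebleuScorer` takes detokenized strings, while the token-level `bleu.Scorer` takes integer tensors indexed by the target dictionary. This is why the generation loops in the examples below dispatch on `hasattr`; a condensed sketch of that pattern:

# Condensed form of the dispatch used throughout the examples below;
# `scorer`, `target_str`, `hypo_str`, `target_tokens` and `hypo_tokens`
# are assumed to come from the surrounding generation loop.
if hasattr(scorer, 'add_string'):
    scorer.add_string(target_str, hypo_str)   # detokenized strings
else:
    scorer.add(target_tokens, hypo_tokens)    # integer token tensors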
Example #5
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    #pickle.dump(src_dict, open("./translations/seethetensors/src_dict.pkl", "bw") )
    #pickle.dump(tgt_dict, open("./translations/seethetensors/tgt_dict.pkl", "bw") )

    #print("* args.remove_bpe : ", args.remove_bpe)
    #bpe_symbol = args.remove_bpe
    #pickle.dump(bpe_symbol, open("./translations/seethetensors/bpe_symbol.pkl", "bw") )

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
        bert_ratio=args.bert_ratio if args.change_ratio else None,
        encoder_ratio=args.encoder_ratio if args.change_ratio else None,
        geargs=args,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)
    #pickle.dump(align_dict,  open("./translations/seethetensors/align_dict.pkl", "bw"))
    # Load dataset (possibly sharded)

    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    resdict = {}
    #hypo_strings = [] ##!!!
    #src_strings = [] ##!!!
    results_path = args.results_path
    stamp = str(time.time())
    resfp = results_path + "/" + args.gen_subset + "." + stamp + ".gen_sparql.json"
    #resfp_ = results_path+"/"+args.gen_subset+"."+stamp+".txt"
    #sampfp = results_path+"/"+args.gen_subset+"."+stamp+".sampleids.txt"

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            #open(sampfp, "w", encoding="UTF-8").writeline(str(sample['id'].tolist()))
            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)
                #src_strings.append(src_str) ##!!!
                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu()
                        if hypo['alignment'] is not None else None,
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )
                    #hypo_strings.append(hypo_str)  ##!!!
                    resdict[str(int(sample_id) + 1)] = {
                        "sparql": interprete(hypo_str),
                        "en": src_str
                    }  ##!!!
                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                    hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join(
                                    map(lambda x: str(utils.item(x)),
                                        alignment))))

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))
    """
    with open(resfp, "a", encoding="UTF-8") as restore :
        for gen_str in hypo_strings:
            restore.write(gen_str+" \n")
        restore.close() 
    with open(resfp_, "a", encoding="UTF-8") as res_tore :
        for src_str in src_strings:
            res_tore.write(src_str+" \n")
        res_tore.close()"""
    with open(resfp, "w", encoding="UTF-8") as restore:
        json.dump(resdict, restore, ensure_ascii=False, indent=4)
        restore.close()

    return scorer
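# main() writes its hypotheses to the timestamped JSON file `resfp`,
# keyed by 1-based sample id. A hypothetical reader for that file:
#
#     with open(resfp, encoding="UTF-8") as f:
#         resdict = json.load(f)
#     print(resdict['1']['en'], '->', resdict['1']['sparql'])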
def infer_onebyone(args, models, task, input):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu
    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    input = ' '.join(input)
    src_tokens = tgt_dict.encode_line(input).type(torch.LongTensor)
    input_sample = {
        'id': torch.Tensor([0]),
        'nsentences': 1,
        'ntokens': len(src_tokens),
        'net_input': {
            'src_tokens': src_tokens.unsqueeze(0),
            'src_lengths': torch.tensor([len(src_tokens)]),
            'prev_output_tokens': torch.tensor([[tgt_dict.eos()]])
        },
        'target': src_tokens.unsqueeze(0),
    }

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))

    # Optimize ensemble for generation

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True

    wps_meter = TimeMeter()

    sample = input_sample
    sample = utils.move_to_cuda(sample) if use_cuda else sample

    prefix_tokens = None
    if args.prefix_size > 0:
        prefix_tokens = sample['target'][:, :args.prefix_size]

    gen_timer.start()
    #pdb.set_trace()
    hypos = task.inference_step(generator, models, sample, prefix_tokens)
    num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
    gen_timer.stop(num_generated_tokens)

    for i, sample_id in enumerate(sample['id'].tolist()):
        has_target = sample['target'] is not None

        # Remove padding
        src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :],
                                     tgt_dict.pad())
        target_tokens = None
        if has_target:
            target_tokens = utils.strip_pad(sample['target'][i, :],
                                            tgt_dict.pad()).int().cpu()

        # Either retrieve the original sentences or regenerate them from tokens.
        if align_dict is not None:
            src_str = task.dataset(
                args.gen_subset).src.get_original_text(sample_id)
            target_str = task.dataset(
                args.gen_subset).tgt.get_original_text(sample_id)
        else:
            if src_dict is not None:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
            else:
                src_str = ""
            if has_target:
                target_str = tgt_dict.string(target_tokens,
                                             args.remove_bpe,
                                             escape_unk=True)

        if not args.quiet:
            if src_dict is not None:
                print('S-{}\t{}'.format(sample_id, src_str))
            if has_target:
                print('T-{}\t{}'.format(sample_id, target_str))

        # Process top predictions
        for j, hypo in enumerate(hypos[i][:args.nbest]):
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'],
                align_dict=align_dict,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )

            output = ''.join(hypo_str.split(' '))
            if not args.quiet:
                print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                            hypo_str))
                print('P-{}\t{}'.format(
                    sample_id, ' '.join(
                        map(
                            lambda x: '{:.4f}'.format(x),
                            hypo['positional_scores'].tolist(),
                        ))))

                if args.print_alignment:
                    print('A-{}\t{}'.format(
                        sample_id, ' '.join([
                            '{}-{}'.format(src_idx, tgt_idx)
                            for src_idx, tgt_idx in alignment
                        ])))

                if args.print_step:
                    print('I-{}\t{}'.format(sample_id, hypo['steps']))

            # Score only the top hypothesis
            if has_target and j == 0:
                if align_dict is not None or args.remove_bpe is not None:
                    # Convert back to tokens for evaluation with unk replacement and/or without BPE
                    target_tokens = tgt_dict.encode_line(target_str,
                                                         add_if_not_exist=True)
                if hasattr(scorer, 'add_string'):
                    scorer.add_string(target_str, hypo_str)
                else:
                    scorer.add(target_tokens, hypo_tokens)

    wps_meter.update(num_generated_tokens)
    num_sentences += sample['nsentences']

    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))

    return scorer, output
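A hypothetical call into the one-by-one path above. Note that `infer_onebyone` joins the characters of `input` with spaces, encodes them with the target dictionary, and echoes the input as its own 'target', so the returned scorer measures round-trip BLEU rather than translation quality; `args`, `models`, and `task` are assumed to come from a prior main()-style setup:

scorer, output = infer_onebyone(args, models, task, '你好')
print(output)                   # hypothesis with spaces stripped
print(scorer.result_string())   # BLEU of hypothesis vs. echoed input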
Example #7
def validate_translation(args, trainer, task, epoch_itr, generator):
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    models = [trainer.get_model()]
    bleu_dict = {key: None for key in task.eval_lang_pairs}

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer_dict = {
            key: bleu.SacrebleuScorer()
            for key in task.eval_lang_pairs
        }
    else:
        scorer_dict = {
            key: bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
            for key in task.eval_lang_pairs
        }

    itr = task.get_batch_iterator(
        dataset=task.dataset('valid'),
        max_tokens=args.max_tokens_valid,
        max_sentences=args.max_sentences_valid,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            trainer.get_model().max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
        noskip=True,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.build_progress_bar(args,
                                               itr,
                                               epoch_itr.epoch,
                                               prefix='translate subset',
                                               no_progress_bar='simple')

    num_sentences = 0
    has_target = True
    #with progress_bar.build_progress_bar(args, itr) as t:
    for samples in progress:
        if torch.cuda.is_available() and not args.cpu:
            samples = utils.move_to_cuda(samples)
        #if 'net_input' not in samples:
        #    continue

        prefix_tokens = None
        for key, sample in samples.items():
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()
                # Remove padding
                if args.sde:
                    src_tokens = target_tokens
                else:
                    src_tokens = utils.strip_pad(
                        sample['net_input']['src_tokens'][i, :],
                        tgt_dict.pad())

                # Either retrieve the original sentences or regenerate them from tokens.
                #if src_dict is not None:
                #    src_str = src_dict.string(src_tokens, args.remove_bpe)
                #else:
                #    src_str = ""
                if has_target:
                    target_str = tgt_dict.string(target_tokens,
                                                 args.remove_bpe,
                                                 escape_unk=True)

                #if not args.quiet:
                #    if src_dict is not None:
                #        print('S-{}\t{}'.format(sample_id, src_str))
                #    if has_target:
                #        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str="",
                        alignment=hypo['alignment'].int().cpu()
                        if hypo['alignment'] is not None else None,
                        align_dict=None,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                #if not args.quiet:
                #    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                #    print('P-{}\t{}'.format(
                #        sample_id,
                #        ' '.join(map(
                #            lambda x: '{:.4f}'.format(x),
                #            hypo['positional_scores'].tolist(),
                #        ))
                #    ))

                #    if args.print_alignment:
                #        print('A-{}\t{}'.format(
                #            sample_id,
                #            ' '.join(map(lambda x: str(utils.item(x)), alignment))
                #        ))

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer_dict[key], 'add_string'):
                            scorer_dict[key].add_string(target_str, hypo_str)
                        else:
                            scorer_dict[key].add(target_tokens, hypo_tokens)

            num_sentences += sample['nsentences']
    for key, scorer in scorer_dict.items():
        bleu_dict[key] = scorer.score()
    return bleu_dict
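`validate_translation` returns one corpus BLEU per language pair. A hypothetical consumer of that dict:

# Hypothetical consumption of the per-language-pair BLEU dict.
bleu_dict = validate_translation(args, trainer, task, epoch_itr, generator)
for pair, score in sorted(bleu_dict.items()):
    print('| valid BLEU {}: {:.2f}'.format(pair, score))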
Example #8
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models, task.target_dictionary, beam_size=args.beam, minlen=args.min_len,
            stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen, unk_penalty=args.unkpen,
            sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
            diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
            match_source_len=args.match_source_len, no_repeat_ngram_size=args.no_repeat_ngram_size,
        )

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.score_reference:
            translations = translator.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
                cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
            )

        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
            else:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                if has_target:
                    target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                    print('P-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(
                            lambda x: '{:.4f}'.format(x),
                            hypo['positional_scores'].tolist(),
                        ))
                    ))

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(lambda x: str(utils.item(x)), alignment))
                        ))

                # Score only the top hypothesis
                if has_target and i == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str, tgt_dict, add_if_not_exist=True)
                    if hasattr(scorer, 'add_string'):
                        scorer.add_string(target_str, hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
Example #9
def _generate_score(models, args, task, dataset, optimize=True):
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load ensemble
    if not args.quiet:
        print("| loading model(s) from {}".format(", ".join(
            args.path.split(":"))))

    # Optimize ensemble for generation
    if optimize:
        for model in models:
            model.make_generation_fast_(
                beamable_mm_beam_size=None
                if args.no_beamable_mm else args.beam,
                need_attn=True,
            )

    translator = build_sequence_generator(args, task, models)
    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Keep track of translations
    # Initialize with empty translations
    # and zero probs scores
    translated_sentences = [""] * len(dataset)
    translated_scores = [0.0] * len(dataset)

    collect_output_hypos = getattr(args, "output_hypos_binary_path", False)
    if collect_output_hypos:
        output_hypos_token_arrays = [None] * len(dataset)

    # Generate and compute BLEU score
    dst_dict = task.target_dictionary
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(dst_dict.pad(), dst_dict.eos(), dst_dict.unk())
    itr = get_eval_itr(args, models, task, dataset)

    oracle_scorer = None
    if args.report_oracle_bleu:
        oracle_scorer = bleu.Scorer(dst_dict.pad(), dst_dict.eos(),
                                    dst_dict.unk())

    rescoring_model = rescoring.setup_rescoring(args)
    rescoring_scorer = None
    if rescoring_model:
        rescoring_scorer = bleu.Scorer(dst_dict.pad(), dst_dict.eos(),
                                       dst_dict.unk())

    num_sentences = 0
    translation_samples = []
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        gen_timer = StopwatchMeter()
        translations = translator.generate_batched_itr(
            t,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda=use_cuda,
            timer=gen_timer,
            prefix_size=1
            if pytorch_translate_data.is_multilingual(args) else 0,
        )

        for trans_info in _iter_translations(args, task, dataset, translations,
                                             align_dict, rescoring_model):
            scorer.add(trans_info.target_tokens, trans_info.hypo_tokens)
            if oracle_scorer is not None:
                oracle_scorer.add(trans_info.target_tokens,
                                  trans_info.best_hypo_tokens)
            if rescoring_scorer is not None:
                rescoring_scorer.add(trans_info.target_tokens,
                                     trans_info.hypo_tokens_after_rescoring)

            translated_sentences[trans_info.sample_id] = trans_info.hypo_str
            translated_scores[trans_info.sample_id] = trans_info.hypo_score
            if collect_output_hypos:
                output_hypos_token_arrays[
                    trans_info.sample_id] = trans_info.best_hypo_tokens
            translation_samples.append(
                collections.OrderedDict({
                    "sample_id":
                    trans_info.sample_id.item(),
                    "src_str":
                    trans_info.src_str,
                    "target_str":
                    trans_info.target_str,
                    "hypo_str":
                    trans_info.hypo_str,
                }))
            wps_meter.update(trans_info.src_tokens.size(0))
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += 1

    # If applicable, save collected hypothesis tokens to binary output file
    if collect_output_hypos:
        output_dataset = pytorch_translate_data.InMemoryNumpyDataset()
        output_dataset.load_from_sequences(output_hypos_token_arrays)
        output_dataset.save(args.output_hypos_binary_path)

    # If applicable, save the translations to the output file
    # For eg. external evaluation
    if getattr(args, "translation_output_file", False):
        with open(args.translation_output_file, "w") as out_file:
            for hypo_str in translated_sentences:
                print(hypo_str, file=out_file)

    if getattr(args, "translation_probs_file", False):
        with open(args.translation_probs_file, "w") as out_file:
            for hypo_score in translated_scores:
                print(np.exp(hypo_score), file=out_file)

    if oracle_scorer is not None:
        print(
            f"| Oracle BLEU (best hypo in beam): {oracle_scorer.result_string()}"
        )

    if rescoring_scorer is not None:
        print(
            f"| Rescoring BLEU (top hypo in beam after rescoring):{rescoring_scorer.result_string()}"
        )

    return scorer, num_sentences, gen_timer, translation_samples
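# _generate_score returns the filled scorer plus counters and per-sample
# records instead of printing a summary itself; a hypothetical caller:
#
#     scorer, num_sentences, gen_timer, translation_samples = \
#         _generate_score(models, args, task, dataset)
#     print('| BLEU: {}'.format(scorer.result_string()))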
def main(args):
    assert args.path is not None, "--path required for generation!"
    assert (not args.sampling or args.nbest
            == args.beam), "--sampling requires --nbest to be equal to --beam"
    assert (args.replace_unk is None or args.raw_text
            ), "--replace-unk requires a raw text dataset (--raw-text)"

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print("| loading model(s) from {}".format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(":"),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                has_target = sample["target"] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = (utils.strip_pad(
                        sample["target"][i, :], tgt_dict.pad()).int().cpu())

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print("S-{}\t{}".format(sample_id, src_str))
                    if has_target:
                        print("T-{}\t{}".format(sample_id, target_str))

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo["tokens"].int().cpu(),
                        src_str=src_str,
                        alignment=hypo["alignment"],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        print("H-{}\t{}\t{}".format(sample_id, hypo["score"],
                                                    hypo_str))
                        print("P-{}\t{}".format(
                            sample_id,
                            " ".join(
                                map(
                                    lambda x: "{:.4f}".format(x),
                                    hypo["positional_scores"].tolist(),
                                )),
                        ))

                        if args.print_alignment:
                            print("A-{}\t{}".format(
                                sample_id,
                                " ".join([
                                    "{}-{}".format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in alignment
                                ]),
                            ))

                        if args.print_step:
                            print("I-{}\t{}".format(sample_id, hypo["steps"]))

                        if getattr(args, "retain_iter_history", False):
                            print("\n".join([
                                "E-{}_{}\t{}".format(
                                    sample_id,
                                    step,
                                    utils.post_process_prediction(
                                        h["tokens"].int().cpu(),
                                        src_str,
                                        None,
                                        None,
                                        tgt_dict,
                                        None,
                                    )[1],
                                ) for step, h in enumerate(hypo["history"])
                            ]))

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, "add_string"):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += sample["nsentences"]

    print(
        "| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)"
        .format(
            num_sentences,
            gen_timer.n,
            gen_timer.sum,
            num_sentences / gen_timer.sum,
            1.0 / gen_timer.avg,
        ))
    if has_target:
        print("| Generate {} with beam={}: {}".format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))

    return scorer
Example #11
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.dataset_impl == 'raw', \
        '--replace-unk requires a raw text dataset (--dataset-impl=raw)'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))

    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=args.model_overrides,
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        if args.decoding_path is not None:
            src_sents = [[] for _ in range(5000000)]
            tgt_sents = [[] for _ in range(5000000)]
            hyp_sents = [[] for _ in range(5000000)]

        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                    hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join([
                                    '{}-{}'.format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in alignment
                                ])))

                        if args.print_step:
                            print('I-{}\t{}'.format(sample_id, hypo['steps']))

                        if getattr(args, 'retain_iter_history', False):
                            print("\n".join([
                                'E-{}_{}\t{}'.format(
                                    sample_id, step,
                                    utils.post_process_prediction(
                                        h['tokens'].int().cpu(), src_str, None,
                                        None, tgt_dict, None)[1])
                                for step, h in enumerate(hypo['history'])
                            ]))

                    if args.decoding_path is not None:
                        src_sents[int(sample_id)].append(src_str)
                        tgt_sents[int(sample_id)].append(target_str)
                        hyp_sents[int(sample_id)].append(hypo_str)

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))

    if args.decoding_path is not None:
        for fname, sents_list in (('source.txt', src_sents),
                                  ('target.txt', tgt_sents),
                                  ('decoding.txt', hyp_sents)):
            with open(os.path.join(args.decoding_path, fname),
                      'w',
                      encoding='utf-8') as f:
                for sents in sents_list:
                    if len(sents) == 0:
                        continue
                    for sent in sents:
                        f.write(sent + '\n')

    num_ref_values = list(args.num_ref.values())
    if len(num_ref_values) == 1:
        num_ref = int(num_ref_values[0])
    else:
        raise NotImplementedError

    ref_path = []

    if num_ref == 1:
        ref_path.append(
            os.path.join(args.valid_decoding_path,
                         args.gen_subset + '.tok.' + args.target_lang))
    else:
        for i in range(num_ref):
            ref_path.append(
                os.path.join(
                    args.valid_decoding_path,
                    args.gen_subset + '.tok.' + args.target_lang + str(i)))

    decoding_path = os.path.join(args.decoding_path, 'decoding.txt')

    #with open(decoding_path) as out_file:
    #    out_file.seek(0)
    #    subprocess.call(
    #        'perl %s/multi-bleu.perl %s' % (args.multi_bleu_path, ' '.join(ref_path)),
    #        stdin=out_file, shell=True)

    return scorer
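The multi-bleu.perl call above is left commented out. For reference, a minimal sketch of how it could be wired back in, assuming Moses' multi-bleu.perl sits under args.multi_bleu_path (the helper name is ours, not part of the example):

import subprocess

def run_multi_bleu(decoding_path, ref_path, multi_bleu_path):
    # multi-bleu.perl reads hypotheses from stdin and takes the
    # reference file paths as arguments.
    with open(decoding_path, encoding='utf-8') as out_file:
        subprocess.call(
            'perl {}/multi-bleu.perl {}'.format(multi_bleu_path, ' '.join(ref_path)),
            stdin=out_file, shell=True)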
Example #12
def main(args):
    utils.import_user_module(args)

    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu
    use_ctc_loss = args.criterion == 'ctc_loss'

    # Setup task, e.g., image captioning
    task = tasks.setup_task(args)
    # Load dataset split
    task.load_dataset(args.gen_subset, combine=True, epoch=0)

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    model_paths = args.path.split(':')
    models, _model_args = checkpoint_utils.load_model_ensemble(
        model_paths,
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Set dictionaries
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())

    stats = collections.OrderedDict()
    num_sentences = 0
    num_correct = 0
    has_target = True

    with progress_bar.build_progress_bar(
        args, itr,
        prefix='inference on \'{}\' subset'.format(args.gen_subset),
        no_progress_bar='simple',
    ) as progress:
        wps_meter = TimeMeter()
        for sample in progress:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            gen_timer.start()
            hypos = task.inference_step(generator, models, sample)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None
                target_tokens = None
                if has_target:
                    if use_ctc_loss:
                        target_tokens = sample['target'][i]
                    else:
                        # Remove padding
                        target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu()
                    # Regenerate the original sentence from tokens.
                    target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

                if not args.quiet:
                    if has_target:
                        print('\nT-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                hypo = hypos[i][0]
                hypo_tokens = hypo['tokens'] if use_ctc_loss else hypo['tokens'].int().cpu()
                hypo_str = tgt_dict.string(hypo_tokens, args.remove_bpe, escape_unk=True)
                alignment = hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None

                if has_target and hypo_str == target_str:
                    num_correct += 1

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo_str, hypo['score']))
                    print('P-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(
                            lambda x: '{:.4f}'.format(x),
                            hypo['positional_scores'].tolist(),
                        )) if not use_ctc_loss else None
                    ))

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(lambda x: str(utils.item(x)), alignment))
                        ))

                # Score only the top hypothesis
                if has_target:
                    if hasattr(scorer, 'add_string'):
                        scorer.add_string(target_str, hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            num_sentences += sample['nsentences']
            stats['wps'] = round(wps_meter.avg)
            stats['acc'] = num_correct / num_sentences
            progress.log(stats, tag='accuracy')
        progress.print(stats, tag='accuracy')

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
    return scorer
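The add_string/add branch above recurs throughout these examples: fairseq's bleu.SacrebleuScorer consumes detokenized strings, while the token-level bleu.Scorer consumes integer tensors. A small helper capturing the pattern (a sketch; the helper name is ours):

def score_pair(scorer, target_str, hypo_str, target_tokens, hypo_tokens):
    # SacrebleuScorer exposes add_string(ref, hypo); the token-level
    # bleu.Scorer only has add(ref_tensor, hypo_tensor).
    if hasattr(scorer, 'add_string'):
        scorer.add_string(target_str, hypo_str)
    else:
        scorer.add(target_tokens, hypo_tokens)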
Example #13
def main(args):
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=os.environ.get('LOGLEVEL', 'INFO').upper(),
        stream=sys.stdout,
    )
    logger = logging.getLogger('fairseq_cli.generate')

    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    #print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    mose_en = MosesDetokenizer(lang='en')
    mose_de = MosesDetokenizer(lang='de')
    if args.bert_model_path:
        bert_tokenizer = BertTokenizer.from_pretrained(args.bert_model_path)
        bert_model = BertModel.from_pretrained(args.bert_model_path)
        bert_model.cuda()
        bert_model.eval()

        bert_de_tokenizer = BertTokenizer.from_pretrained(args.bert_german_path)
        bert_de_model = BertModel.from_pretrained(args.bert_german_path)
        bert_de_model.cuda()
        bert_de_model.eval()
    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        model.eval()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    batch_len = len(task.dataset(args.gen_subset))

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    #Bert model
    #bert_model = bert_as_lm.Bert_score(torch.cuda.current_device()) 
    #mose_de = MosesDetokenizer(lang='en')
    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True

    total_pearson = []
    total_bert_pearson = []
    random_list = []
    bert_bleu_equal = 0
    sents_num = 0
    if args.gen_subset == 'train':
        random_list = [i for i in range(0, batch_len)]
        random.shuffle(random_list)
        random_list = random_list[:1000]

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            if random_list:
                # Keep only sentences whose ids were sampled into random_list.
                selected = [i in random_list for i in sample['id']]
                selected = torch.nonzero(torch.tensor(selected).ne(0)).squeeze(-1)
                if len(selected) == 0:
                    continue
                for item in sample.keys():
                    if item == 'nsentences' or item == 'ntokens':
                        continue
                    elif item == 'net_input':
                        for key in sample[item].keys():
                            sample[item][key] = sample[item][key].index_select(0, selected)
                    else:
                        sample[item] = sample[item].index_select(0, selected)
                sample['nsentences'] = len(selected)
                sample['ntokens'] = torch.LongTensor([s.ne(1).long().sum() for s in sample['target']]).sum().item()

            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                sents_num += 1
                if random_list and sample_id not in random_list:
                    continue
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                        src_splittokens = src_str.split(' ')
                        src_str_for_bert = mose_de.detokenize(src_splittokens)
                        src_ids, src_bert = \
                                get_bert_out(src_str_for_bert, bert_de_tokenizer, bert_de_model)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

                if not args.quiet:
                    print ("---------------{}--------------".format(sents_num))
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str), file=sys.stdout)
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str), file=sys.stdout)

                # Process top predictions
                probs = []
                bleu_score = []
                cands = []
                sents_bert_score = []
                detoken_cands =[]
                temp_cand_tokens = []

                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    score = hypo['score'] / math.log(2)  # convert to base 2
                    probs.append(score)
                    single_score = sacrebleu.corpus_bleu([hypo_str],[[target_str]], use_effective_order=True, tokenize="none")
                    bleu_score.append(single_score.score)
                    cands.append(hypo_str)
                    temp_cand_tokens.append(hypo['tokens'].cpu())

                    hypo_splittokens = hypo_str.split(' ')
                    detoken_hypo = mose_de.detokenize(hypo_splittokens)
                    detoken_cands.append(detoken_hypo)

                #print (cands)
                cand_ids, cand_bert = get_bert_out(detoken_cands, bert_tokenizer, bert_model)
                src_hidden = models[0].decoder.through_ffnet(src_bert[:,0,:],'src')
                cand_hidden = models[0].decoder.through_ffnet(cand_bert[:,0,:],'cand')
                cand_net_vocab = models[0].decoder.through_ffnet(cand_bert)
                cand_lprob = \
                        models[0].get_normalized_probs([cand_net_vocab],log_probs=True).view(-1,cand_net_vocab.size(-1))
                #compute similarity
                cos_sim = F.cosine_similarity(
                    src_hidden.repeat(args.nbest,1), cand_hidden, dim=1, eps=1e-6) #[beam]
                cos_sim = torch.log((cos_sim + 1)/2)
                cand_probs = cand_lprob.gather(dim=-1,index=cand_ids.view(-1,1))
                cand_loss = cand_probs.view(-1,cand_ids.size(-1)).sum(dim=1)
                cand_loss = cand_loss/cand_ids.ne(0).sum(dim=1).float()
                #print (cand_loss)
                total_score = cos_sim + 0.1*cand_loss
                #print (total_score)
                net_pos = torch.argmax(total_score).item()

                np_prob = np.array(probs)
                np_bleu = np.array(bleu_score)
                pearson = np.corrcoef(np_prob, np_bleu)[0][1]

                if not np.isnan(pearson):
                    total_pearson.append(pearson)
                #else:
                #    print ("cands:", cands, file=sys.stdout)
                #    print ("probs:", np_prob, file=sys.stdout)
                #    print ("bleus:", np_bleu, file=sys.stdout)

                bleu_pos = bleu_score.index(max(bleu_score))
                print ("-----bleu choice: {} bleu:{:.3f}  pos: {}".format(
                            cands[bleu_pos], bleu_score[bleu_pos], bleu_pos+1),file=sys.stdout)
                pos = probs.index(max(probs))
                print ("-----prob choice: {} bleu:{:.3f} pos: {}".format(
                        cands[pos], bleu_score[pos],pos+1),file=sys.stdout)
                print ("-----net choice: {} bleu:{:.3f} pos: {} score:{:.3f}".format(
                        cands[net_pos], bleu_score[net_pos],net_pos+1, total_score[net_pos]),file=sys.stdout)

                '''
                np_bert = np.array(sents_bert_score)
                bert_bleu_pearson = np.corrcoef(np_bert, np_bleu)[0][1]
                if not np.isnan(bert_bleu_pearson):
                   total_bert_pearson.append(bert_bleu_pearson) 

                bert_pos = sents_bert_score.index(min(sents_bert_score))
                print('*****{} bert choice: {}\tprob:{:.3f}\tbleu:{:.3f}\tbertscore:{:.3f}\tposition:{}\tprob_bleu_pearson:{:.3f} bert_bleu_p: {:.3f} '. \
                format(sample_id, cands[bert_pos], probs[bert_pos], bleu_score[bert_pos],
                    sents_bert_score[bert_pos], bert_pos+1, pearson, bert_bleu_pearson), file=sys.stdout)
                '''
                if args.usebleu:
                    final_hypo = cands[bleu_pos]
                elif args.usebert:
                    final_hypo = cands[net_pos]
                else:
                    final_hypo = cands[pos]
                scorer.add_string(target_str, final_hypo)

                print('H choice use bleu: {} usebert: {}'.format(args.usebleu, args.usebert))
                if has_target and sents_num % 800 == 0:
                    print('Generate {} with beam={}: {}\t{}'.format(
                        args.gen_subset, args.beam, scorer.result_string(), sents_num),
                        file=sys.stdout)
            wps_meter.update(num_generated_tokens)
            #t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    logger.info('Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('Generate {} with beam={}: {}\n --- prob&bleu pearson: {:.4f} ---'.format(
            args.gen_subset, args.beam, scorer.result_string(),
            sum(total_pearson) / max(len(total_pearson), 1)), file=sys.stdout)

    return scorer
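The core diagnostic in Example #13 is the Pearson correlation between model scores and sentence-level BLEU over the n-best list. A self-contained sketch of that computation, using the same sacrebleu and numpy calls as the example (the helper name is ours):

import numpy as np
import sacrebleu

def prob_bleu_pearson(cands, probs, target_str):
    # Sentence-level BLEU per candidate; use_effective_order avoids
    # zero scores on short sentences, as in the example above.
    bleus = [sacrebleu.corpus_bleu([c], [[target_str]],
                                   use_effective_order=True,
                                   tokenize="none").score
             for c in cands]
    return np.corrcoef(np.array(probs), np.array(bleus))[0][1]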
Example #14
def main(args, override_args=None):
    utils.import_user_module(args)

    use_cuda = torch.cuda.is_available() and not args.cpu
    eval_task = args.eval_task
    model, task, criterion, generator, model_args = set_up_model(
        args, args.vqvae_path, override_args)
    if eval_task == 'sampling':
        assert args.prior_path is not None
        prior_model, prior_task, prior_criterion, prior_generator, prior_args = set_up_model(
            args, args.prior_path, None)

    dictionary = task.dictionary
    if eval_task == 'code_extract':
        fopt = io.open(os.path.join(args.results_path,
                                    args.gen_subset + ".codes"),
                       "w",
                       encoding='utf-8')
    elif eval_task == 'reconstruct':
        if args.sampling:
            prefix = ".sample"
        else:
            prefix = ".bs"
        if args.code_extract_strategy is not None:
            prefix = prefix + '.' + args.code_extract_strategy
        else:
            prefix += '.orig.sampling'
        if args.prefix_num > 0:
            prefix = prefix + ".prefix_{}".format(args.prefix_num)
        fopt = io.open(os.path.join(
            args.results_path, args.gen_subset + prefix + ".reconstruction"),
                       "w",
                       encoding='utf-8')
        # Generate and compute BLEU score
        if args.sacrebleu:
            scorer = bleu.SacrebleuScorer()
        else:
            scorer = bleu.Scorer(dictionary.pad(), dictionary.eos(),
                                 dictionary.unk())
    elif eval_task == 'sampling':
        fopt = io.open(os.path.join(args.results_path,
                                    args.gen_subset + ".samples"),
                       "w",
                       encoding='utf-8')
    else:
        raise ValueError

    # Initialize generator
    gen_timer = StopwatchMeter()
    num_sentences = 0
    generate_id = 0
    if eval_task != 'sampling':
        # Load valid dataset (we load training data below, based on the latest checkpoint)
        for subset in args.gen_subset.split(','):
            try:
                task.load_dataset(subset, combine=False, epoch=args.shard_id)
                dataset = task.dataset(subset)
            except KeyError:
                raise Exception('Cannot find dataset: ' + subset)

            # Initialize data iterator
            itr = task.get_batch_iterator(
                dataset=dataset,
                max_tokens=args.max_tokens,
                max_sentences=args.max_sentences,
                max_positions=utils.resolve_max_positions(
                    task.max_positions(),
                    model.max_positions(),
                ),
                ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
                required_batch_size_multiple=args.required_batch_size_multiple,
                seed=args.seed,
                num_workers=args.num_workers,
            ).next_epoch_itr(shuffle=False)
            progress = progress_bar.build_progress_bar(
                args,
                itr,
                prefix='valid on \'{}\' subset'.format(subset),
                no_progress_bar='simple')

            log_outputs = []
            wps_meter = TimeMeter()
            all_codes = set()
            all_stats = []
            for jj, sample in enumerate(progress):
                sample = utils.move_to_cuda(sample) if use_cuda else sample
                log_output = {'sample_size': sample['target'].size(0)}
                log_outputs.append(log_output)
                num_sentences += sample['nsentences']

                if eval_task == 'code_extract':
                    codes = task.extract_codes(sample, model)
                    if args.gen_subset == 'valid':
                        all_codes.update(torch.unique(codes).tolist())
                    codes = codes.cpu().numpy()
                elif eval_task == 'reconstruct':
                    prefix_tokens = None if args.prefix_num == 0 else sample[
                        'target'][:, :args.prefix_num]
                    gen_timer.start()
                    hypos, codes, stats = task.reconstruct(
                        sample,
                        model,
                        generator,
                        prefix_tokens,
                        extract_mode=args.code_extract_strategy)
                    if isinstance(codes, list):
                        codes, topk = codes
                    else:
                        topk = None
                    if 'avg_topp' in stats:
                        all_stats.append(stats)
                    all_codes.update(torch.unique(codes).tolist())
                    num_generated_tokens = sum(
                        len(h[0]['tokens']) for h in hypos)
                    gen_timer.stop(num_generated_tokens)
                    wps_meter.update(num_generated_tokens)
                    progress.log({'wps': round(wps_meter.avg)})
                else:
                    raise NotImplementedError

                progress.log(log_output, step=jj)

                for i, sample_id in enumerate(sample['id'].tolist()):
                    tokens = utils.strip_pad(sample['target'][i, :],
                                             dictionary.pad())
                    origin_string = dictionary.string(
                        tokens, bpe_symbol=args.remove_bpe, escape_unk=True)
                    if len(tokens) <= 1:
                        continue
                    bpe_string = dictionary.string(tokens,
                                                   bpe_symbol=None,
                                                   escape_unk=True)
                    fopt.write('T-ori-{}\t{}\n'.format(sample_id,
                                                       origin_string))
                    fopt.write('T-bpe-{}\t{}\n'.format(sample_id, bpe_string))

                    if eval_task == 'code_extract':
                        code = codes[i]
                        fopt.write('C-{}\t{}\n'.format(
                            sample_id, ' '.join([
                                "c-{}".format(x) for x in code.tolist()
                                if x != -1
                            ])))
                    elif eval_task == 'reconstruct':
                        for j, hypo in enumerate(hypos[i][:args.nbest]):
                            code = codes[i]
                            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                                hypo_tokens=hypo['tokens'].int().cpu()[:len(tokens) - 1],
                                src_str="",
                                alignment=hypo['alignment'],
                                align_dict=None,
                                tgt_dict=dictionary,
                                remove_bpe=args.remove_bpe,
                            )
                            _, hypo_str_bpe, _ = utils.post_process_prediction(
                                hypo_tokens=hypo['tokens'].int().cpu()[:len(tokens) - 1],
                                src_str="",
                                alignment=hypo['alignment'],
                                align_dict=None,
                                tgt_dict=dictionary,
                                remove_bpe=None,
                            )

                            fopt.write('H-bpe-{}\t{}\t{}\n'.format(
                                sample_id, hypo['score'], hypo_str_bpe))
                            fopt.write('H-{}\t{}\t{}\n'.format(
                                sample_id, hypo['score'], hypo_str))
                            code_str = ""
                            if topk is not None:
                                prob_str = ""
                                sent_topk = topk[i] / topk[i].sum(1).unsqueeze(1)
                            for ii, token_code in enumerate(code):
                                code_str += " ".join([
                                    "c{}-{}".format(ii, kk)
                                    for kk in token_code if kk != -1
                                ]) + ' '
                                if topk is not None:
                                    prob_str += " ".join([
                                        "p{}-{:.2f}".format(ii, kk.item())
                                        for kk, mm in zip(
                                            sent_topk[ii], token_code)
                                        if mm != -1
                                    ]) + ' '
                            fopt.write('C-{}\t{}\n'.format(
                                sample_id, code_str))
                            if topk is not None:
                                fopt.write('K-{}\n'.format(prob_str))
                            if hypo['attention'] is not None:
                                hypo_attn = hypo['attention'].cpu().numpy()  # src_len x tgt_len
                                entropy, max_idx = compute_attn(hypo_attn)
                                # Uniform-attention baseline: n * log(1/n) / n = log(1/n)
                                baseline_entropy = math.log(1. / hypo_attn.shape[0])
                                fopt.write(
                                    'A-entropy-baseline-{:.2f}\n'.format(
                                        baseline_entropy))
                                fopt.write('A-entropy-{}\n'.format(" ".join(
                                    ["%.2f" % e for e in entropy])))
                                fopt.write('A-max-attn-pos-{}\n'.format(
                                    " ".join([str(kk) for kk in max_idx])))
                            fopt.write('\n')
                            if args.remove_bpe is not None:
                                # Convert back to tokens for evaluation with unk replacement and/or without BPE
                                tokens = dictionary.encode_line(
                                    origin_string, add_if_not_exist=True)
                            if hasattr(scorer, 'add_string'):
                                scorer.add_string(origin_string, hypo_str)
                            else:
                                scorer.add(tokens, hypo_tokens)
                    else:
                        raise NotImplementedError
                generate_id += len(sample['id'])
                if generate_id % 1000 == 0:
                    print("Processed {} lines!".format(generate_id))
            progress.print(log_outputs[0], tag=subset, step=jj)
            print("Total unique active codes = {}".format(len(all_codes)))
            if len(all_stats) > 0:
                avg_topp = sum([stat['avg_topp']
                                for stat in all_stats]) * 1. / len(all_stats)
                max_topp = sum([stat['max_topp']
                                for stat in all_stats]) * 1. / len(all_stats)
                min_topp = sum([stat['min_topp']
                                for stat in all_stats]) * 1. / len(all_stats)
                print(
                    'Max #codes in top 0.9 = {}, min #codes in top 0.9 = {}, avg #codes in top 0.9 = {}'
                    .format(max_topp, min_topp, avg_topp))
        if eval_task == 'reconstruct':
            print(
                '| Reconstructed {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
                .format(num_sentences, gen_timer.n, gen_timer.sum,
                        num_sentences / gen_timer.sum, 1. / gen_timer.avg))
            print('| Reconstruct {} with beam={}: {}'.format(
                args.gen_subset, args.beam, scorer.result_string()))
    else:
        batch_size = 3072 // args.max_len_b
        gen_epochs = args.num_samples // batch_size
        latent_dictionary = prior_task.dictionary
        latent_dictionary_size = len(latent_dictionary)
        for ii in range(gen_epochs):
            dummy_tokens = torch.ones(batch_size, args.max_len_b).long().cuda()
            # One length entry per sentence in the dummy batch.
            dummy_lengths = (torch.ones(batch_size) *
                             args.max_len_b).long().cuda()
            dummy_samples = {
                'net_input': {
                    'src_tokens': dummy_tokens,
                    'src_lengths': dummy_lengths,
                },
                'target': dummy_tokens
            }
            prefix_tokens = None
            code_hypos = prior_task.inference_step(prior_generator,
                                                   [prior_model],
                                                   dummy_samples,
                                                   prefix_tokens)
            list_predictions = []
            for jj in range(batch_size):
                code_hypo = code_hypos[jj][0]  # best output
                latent_hypo_tokens, latent_hypo_str, _ = utils.post_process_prediction(
                    hypo_tokens=code_hypo['tokens'].int().cpu(),
                    src_str="",
                    alignment=code_hypo['alignment'],
                    align_dict=None,
                    tgt_dict=latent_dictionary,
                    remove_bpe=None,
                )
                # should have no pad and eos
                list_predictions.append(
                    torch.LongTensor([
                        int(ss) for ss in latent_hypo_str.strip().split()
                    ]).cuda())
            merged_codes = data_utils.collate_tokens(
                list_predictions,
                latent_dictionary_size,
                left_pad=False,
            )
            code_masks = merged_codes.eq(latent_dictionary_size)
            merged_codes = merged_codes.masked_fill_(code_masks, 0)
            hypos, _, _ = task.sampling(dummy_samples, merged_codes,
                                        code_masks, model, generator)
            for tt in range(len(hypos)):
                hypo = hypos[tt][0]
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str="",
                    alignment=hypo['alignment'],
                    align_dict=None,
                    tgt_dict=dictionary,
                    remove_bpe=args.remove_bpe,
                )
                fopt.write('C-{}\t{}\n'.format(
                    generate_id, " ".join([
                        "c-%d" % kk for kk in list_predictions[tt] if kk != -1
                    ])))
                fopt.write('H-{}\t{}\t{}\n'.format(generate_id, hypo['score'],
                                                   hypo_str))
                generate_id += 1

            if generate_id % 1000 == 0:
                print("Sampled {} sentences!".format(generate_id))
    fopt.close()
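The sampling branch above pads the variable-length predicted code sequences with a sentinel value (the latent dictionary size) and then masks the padding out. A minimal sketch of that step, using fairseq's data_utils.collate_tokens as the example does (the helper name is ours):

from fairseq.data import data_utils

def pad_codes(list_predictions, sentinel):
    # Pad the list of 1-D LongTensors to a rectangle with `sentinel`,
    # then return a mask so the padding can be zeroed out downstream.
    merged = data_utils.collate_tokens(list_predictions, sentinel, left_pad=False)
    masks = merged.eq(sentinel)
    return merged.masked_fill_(masks, 0), masks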
Example #15
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()
        model.decoder.alignment_layer = args.alignment_layer

    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())

    if args.print_vanilla_alignment:
        import string
        punc = string.punctuation
        src_punc_tokens = [
            w for w in range(len(src_dict)) if src_dict[w] in punc
        ]
        tgt_punc_tokens = [
            w for w in range(len(tgt_dict)) if tgt_dict[w] in punc
        ]
    else:
        src_punc_tokens = None
        tgt_punc_tokens = None

    import time
    print('start time is:', time.strftime("%Y-%m-%d %X"))
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.decoding_path is not None:
            align_sents = [[] for _ in range(4000000)]

        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue
            if args.print_vanilla_alignment:
                if args.set_shift:
                    alignments = utils.extract_soft_alignment(
                        sample,
                        models[0],
                        src_punc_tokens,
                        tgt_punc_tokens,
                        alignment_task=args.alignment_task)
                else:
                    alignments = utils.extract_soft_alignment_noshift(
                        sample,
                        models[0],
                        src_punc_tokens,
                        tgt_punc_tokens,
                        alignment_task=args.alignment_task)
            else:
                alignments = None

            for sample_id in sample['id'].tolist():
                if args.print_vanilla_alignment and args.decoding_path is not None:
                    align_sents[int(sample_id)].append(
                        alignments[int(sample_id)])

    print('end time is:', time.strftime("%Y-%m-%d %X"))
    if args.decoding_path is not None and args.print_vanilla_alignment:
        with open(
                os.path.join(
                    args.decoding_path,
                    f'{args.gen_subset}.{args.source_lang}2{args.target_lang}.align'
                ), 'w') as f:
            for sents in align_sents:
                if len(sents) == 0:
                    continue
                for sent in sents:
                    f.write(str(sent) + '\n')
        print("finished ...")
Example #16
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    # print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    trained_epoch = checkpoint_utils.get_checkpoint_epoch(args.path.split(':'))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    # itr = task.get_batch_iterator(
    #     dataset=task.dataset(args.gen_subset),
    #     max_tokens=args.max_tokens,
    #     max_sentences=args.max_sentences,
    #     max_positions=utils.resolve_max_positions(
    #         task.max_positions(),
    #         *[model.max_positions() for model in models]
    #     ),
    #     ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
    #     required_batch_size_multiple=args.required_batch_size_multiple,
    #     num_shards=args.num_shards,
    #     shard_id=args.shard_id,
    #     num_workers=args.num_workers,
    # ).next_epoch_itr(shuffle=False)
    # we modify to use the max_positions only from the task and not the model.
    # the reason is that we keep a low max positions while training transformer
    # to handle large batches, but we need to disable this while testing to get
    # metrics evaluated on full dev/test set.
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=task.max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    # em_scorer = bleu.EmScorer()
    all_metrics = bleu.Metric()
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    all_preds = []
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                all_preds.append({
                    'id': sample_id,
                    'tgt_str': target_str,
                    'src_str': src_str,
                    'url': task.urls[sample_id]
                })

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu()
                        if hypo['alignment'] is not None else None,
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )
                    # print('=========')
                    # print(hypo_tokens)
                    # print(hypo_str)
                    # print(align_dict)
                    if j == 0:
                        # get best hypothesis
                        all_preds[-1]['hypo_str'] = hypo_str

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                    hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join(
                                    map(lambda x: str(utils.item(x)),
                                        alignment))))

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)
                        all_metrics.add_string(target_str, hypo_str)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    # we can dump the preds out in notebook format
    preds_dir = dirname(args.path) + '/preds'
    # sort them in order of index in dev/test set.
    all_preds.sort(key=lambda x: int(x['id']))
    log_preds_to_notebook(preds=all_preds, outdir=preds_dir)

    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))
        print(args.path)

        # compute, store, and print the metrics
        all_metrics.compute_metrics(trained_epoch)
        all_metrics.save(dirname(args.path) + '/metrics.json')
        print('All metrics:')
        print(all_metrics.result_string())

    return all_metrics.get_metric('corpus_bleu'), all_metrics.get_metric('em')
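Example #16 accumulates one record per sentence and dumps them sorted by dataset index; log_preds_to_notebook is project-specific. A hypothetical stand-in that writes the same records as a TSV:

import csv
import os

def dump_preds(all_preds, outdir):
    # Each entry carries id/src_str/tgt_str/hypo_str as built above.
    os.makedirs(outdir, exist_ok=True)
    with open(os.path.join(outdir, 'preds.tsv'), 'w',
              encoding='utf-8', newline='') as f:
        writer = csv.writer(f, delimiter='\t')
        writer.writerow(['id', 'src', 'tgt', 'hypo'])
        for p in sorted(all_preds, key=lambda x: int(x['id'])):
            writer.writerow([p['id'], p['src_str'], p['tgt_str'], p.get('hypo_str', '')])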
Example #17
def run_generation(ckpt, results, ents):
    gen_parser = options.get_generation_parser()
    args = options.parse_args_and_arch(gen_parser,
                                       input_args=[
                                           data_set, '--gen-subset', 'valid',
                                           '--path', ckpt, '--beam', '10',
                                           '--max-tokens', '4000',
                                           '--sacrebleu', '--remove-bpe',
                                           '--log-format', 'none'
                                       ])

    use_cuda = torch.cuda.is_available() and not args.cpu
    # if use_cuda:
    #     lock.acquire()
    #     torch.cuda.set_device(device_id)
    #     lock.release()

    utils.import_user_module(args)
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    entropies = []
    token_counts = []
    with progress_bar.build_progress_bar(args, itr) as t:
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)

            if 'avg_ent' in sample:
                entropies.append(sample['avg_ent'][0])
                token_counts.append(sample['avg_ent'][1])

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)
                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            num_sentences += sample['nsentences']

    results[ckpt] = scorer.score()
    ents[ckpt] = sum(entropies) / sum(token_counts)
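The final average above divides by sum(token_counts), which is only safe when at least one batch carried 'avg_ent'. A guarded version of the aggregation (a sketch):

def mean_entropy(entropies, token_counts):
    # Token-weighted average entropy; NaN when no batch reported stats.
    total_tokens = sum(token_counts)
    return sum(entropies) / total_tokens if total_tokens else float('nan')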
Example #18
    def __init__(self, lowercase=False):
        self.scorer = bleu.SacrebleuScorer(lowercase=lowercase)
        self.target_sum = []
        self.hypo_sum = []
        self.count = 0
Example #19
def _main(args, output_file):
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.generate')

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    args.vocab_size = len(tgt_dict)
    for arg in vars(_model_args).keys():
        if arg in {'decoder_embed_dim', 'vocab_size'}:
            setattr(args, arg, getattr(_model_args, arg))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    if args.knnlm and args.save_knnlm_dstore:
        raise ValueError(
            "Cannot use knnlm while trying to build the datastore!")
    if args.knnlm:
        knn_dstore = KNN_Dstore(args)

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())

    if args.save_knnlm_dstore:
        print('keytype being saved:', args.knn_keytype)
        if args.dstore_fp16:
            print('Saving fp16')
            dstore_keys = np.memmap(args.dstore_mmap + '_keys.npy',
                                    dtype=np.float16,
                                    mode='w+',
                                    shape=(args.dstore_size,
                                           args.decoder_embed_dim))
            dstore_vals = np.memmap(args.dstore_mmap + '_vals.npy',
                                    dtype=np.int16,
                                    mode='w+',
                                    shape=(args.dstore_size, 1))
        else:
            print('Saving fp32')
            dstore_keys = np.memmap(args.dstore_mmap + '_keys.npy',
                                    dtype=np.float32,
                                    mode='w+',
                                    shape=(args.dstore_size,
                                           args.decoder_embed_dim))
            dstore_vals = np.memmap(args.dstore_mmap + '_vals.npy',
                                    dtype=np.int64,  # np.int is removed in modern NumPy
                                    mode='w+',
                                    shape=(args.dstore_size, 1))
        dstore_idx = 0
    if args.save_knnlm_dstore or args.knnlm:
        # source_tokens_file = open(args.output_tokens_file_prefix + '.src' , 'w')
        # 'w+' so the file can be re-read below for the datastore size check
        target_tokens_file = open(args.output_tokens_file_prefix + '.tgt', 'w+')

        # This is only for MT right now, use interactive.py for language modeling
        assert args.task != 'language_modeling'

    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            if args.knnlm:
                hypos = task.inference_step(generator,
                                            models,
                                            sample,
                                            prefix_tokens,
                                            knn_dstore=knn_dstore)
            else:
                hypos = task.inference_step(generator, models, sample,
                                            prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            if args.save_knnlm_dstore:
                for i, hypos_i in enumerate(hypos):
                    hypo = hypos_i[0]
                    shape = hypo['dstore_keys'].shape
                    if dstore_idx + shape[0] > args.dstore_size:
                        shape = [args.dstore_size - dstore_idx]
                        hypo['dstore_keys'] = hypo['dstore_keys'][:shape[0]]
                    if args.dstore_fp16:
                        dstore_keys[dstore_idx:shape[0] +
                                    dstore_idx] = hypo['dstore_keys'].view(
                                        -1, args.decoder_embed_dim).cpu(
                                        ).numpy().astype(np.float16)
                        dstore_vals[dstore_idx:shape[0] +
                                    dstore_idx] = hypo['tokens'].view(
                                        -1, 1).cpu().numpy().astype(np.int16)
                    else:
                        dstore_keys[dstore_idx:shape[0] +
                                    dstore_idx] = hypo['dstore_keys'].view(
                                        -1, args.decoder_embed_dim).cpu(
                                        ).numpy().astype(np.float32)
                        dstore_vals[dstore_idx:shape[0] +
                                    dstore_idx] = hypo['tokens'].view(
                                        -1, 1).cpu().numpy().astype(np.int64)
                    dstore_idx += shape[0]

            if args.save_knnlm_dstore or args.knnlm:
                for i, hypos_i in enumerate(hypos):
                    hypo = hypos_i[0]

                    # dump the tokens to a file, used for analysis and interactive printing
                    # source_tokens = [task.source_dictionary[token] for token in hypo['source_tokens']]
                    # source_tokens_file.write('\n'.join(source_tokens) + '\n')

                    target_tokens = [
                        task.target_dictionary[token]
                        for token in hypo['tokens']
                    ]
                    target_tokens_file.write('\n'.join(target_tokens) + '\n')

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str),
                              file=output_file)
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str),
                              file=output_file)

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        score = hypo['score'] / math.log(
                            2)  # convert to base 2
                        print('H-{}\t{}\t{}'.format(sample_id, score,
                                                    hypo_str),
                              file=output_file)
                        print(
                            'P-{}\t{}'.format(
                                sample_id,
                                ' '.join(
                                    map(
                                        lambda x: '{:.4f}'.format(x),
                                        # convert from base e to base 2
                                        hypo['positional_scores'].div_(
                                            math.log(2)).tolist(),
                                    ))),
                            file=output_file)

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join([
                                    '{}-{}'.format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in alignment
                                ])),
                                  file=output_file)

                        if args.print_step:
                            print('I-{}\t{}'.format(sample_id, hypo['steps']),
                                  file=output_file)

                        if getattr(args, 'retain_iter_history', False):
                            for step, h in enumerate(hypo['history']):
                                _, h_str, _ = utils.post_process_prediction(
                                    hypo_tokens=h['tokens'].int().cpu(),
                                    src_str=src_str,
                                    alignment=None,
                                    align_dict=None,
                                    tgt_dict=tgt_dict,
                                    remove_bpe=None,
                                )
                                print('E-{}_{}\t{}'.format(
                                    sample_id, step, h_str),
                                      file=output_file)
                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info(
        'Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        logger.info('Generate {} with beam={}: {}'.format(
            args.gen_subset, args.beam, scorer.result_string()))

    if args.save_knnlm_dstore:
        print("dstore_idx", dstore_idx, "final shape", shape)
        print("Keys", dstore_keys.shape, dstore_keys.dtype)
        print("Vals", dstore_vals.shape, dstore_vals.dtype)
        target_tokens_file.seek(0)
        num_lines = len(target_tokens_file.readlines())
        if dstore_idx != num_lines:
            print(
                'Warning: KNN datastore size {} does not match the {} lines in the target tokens file'
                .format(dstore_idx, num_lines))

    if args.save_knnlm_dstore or args.knnlm:
        # source_tokens_file.close()
        target_tokens_file.close()

    return scorer
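The datastore written above is just two flat memmaps. A minimal sketch for loading it back, assuming fp32 keys and the same prefix, size, and key dimension that were used when saving (the three constants below are placeholders):

import numpy as np

dstore_mmap = 'dstore'   # same prefix as --dstore-mmap
dstore_size = 1000000    # same value as --dstore-size
key_dim = 1024           # decoder_embed_dim used when saving

# mode='r' opens the existing files read-only; shapes must match the writer.
keys = np.memmap(dstore_mmap + '_keys.npy', dtype=np.float32,
                 mode='r', shape=(dstore_size, key_dim))
vals = np.memmap(dstore_mmap + '_vals.npy', dtype=np.int64,
                 mode='r', shape=(dstore_size, 1))
print(keys.shape, vals.shape)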
Example #20
    def __init__(self):
        (self.correct, self.total, self.sys_len,
         self.ref_len) = utils.get_zero_bleu_stats()
        # TODO: handle lowercase
        self.scorer = bleu.SacrebleuScorer(lowercase=False)
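utils.get_zero_bleu_stats() is unpacked into four counters here. A plausible minimal implementation, an assumption rather than the project's actual helper, returns zeroed sacrebleu-style statistics (per-order counts for 1- to 4-grams plus length counters):

from collections import namedtuple

BleuStatistics = namedtuple('BleuStatistics',
                            ['correct', 'total', 'sys_len', 'ref_len'])

def get_zero_bleu_stats():
    # Zeroed n-gram match/total counts for orders 1..4, plus lengths.
    return BleuStatistics(correct=[0, 0, 0, 0], total=[0, 0, 0, 0],
                          sys_len=0, ref_len=0)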
Example #21
def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile,
                      score_outfile, lm_outfile, fw_outfile, bw_outfile,
                      write_hypos, normalize):

    print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c,
          "target_outfile", target_outfile, "hypo_outfile", hypo_outfile,
          "lm_outfile", lm_outfile)
    gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(
        args)
    word_dict = dictionary.Dictionary()  # renamed from 'dict' to avoid shadowing the builtin
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(word_dict.pad(), word_dict.eos(), word_dict.unk())

    ordered_hypos = {}
    ordered_targets = {}
    ordered_scores = {}
    ordered_lm = {}
    ordered_fw = {}
    ordered_bw = {}

    for shard_id in range(len(bitext1_lst)):
        bitext1 = bitext1_lst[shard_id]
        bitext2 = bitext2_lst[shard_id]
        gen_output = gen_output_lst[shard_id]
        lm_res = lm_res_lst[shard_id]

        total = len(bitext1.rescore_source.keys())
        source_lst = []
        hypo_lst = []
        score_lst = []
        lm_lst = []
        fw_lst = []
        bw_lst = []
        reference_lst = []
        j = 1
        best_score = -math.inf

        for i in range(total):
            # length is measured in terms of words, not bpe tokens, since models may not share the same bpe
            target_len = len(bitext1.rescore_hypo[i].split())

            if lm_res is not None and i in lm_res.score:
                lm_score = lm_res.score[i]
            else:
                lm_score = 0

            if bitext2 is not None:
                bitext2_score = bitext2.rescore_score[i]
                bitext2_backwards = bitext2.backwards
            else:
                bitext2_score = None
                bitext2_backwards = None

            score = rerank_utils.get_score(a,
                                           b,
                                           c,
                                           target_len,
                                           bitext1.rescore_score[i],
                                           bitext2_score,
                                           lm_score=lm_score,
                                           lenpen=lenpen,
                                           src_len=bitext1.source_lengths[i],
                                           tgt_len=bitext1.target_lengths[i],
                                           bitext1_backwards=bitext1.backwards,
                                           bitext2_backwards=bitext2_backwards,
                                           normalize=normalize)

            if score > best_score:
                best_score = score
                best_hypo = bitext1.rescore_hypo[i]
                best_lm = lm_score
                best_fw = bitext1.rescore_score[i]
                best_bw = bitext2_score

            if j == gen_output.num_hypos[i] or j == args.num_rescore:
                j = 1

                hypo_lst.append(best_hypo)
                score_lst.append(best_score)
                lm_lst.append(best_lm)
                fw_lst.append(best_fw)
                bw_lst.append(best_bw)
                source_lst.append(bitext1.rescore_source[i])
                reference_lst.append(bitext1.rescore_target[i])

                best_score = -math.inf
                best_hypo = ""
                best_lm = -math.inf
            else:
                j += 1

        gen_keys = sorted(gen_output.no_bpe_target.keys())

        for key in range(len(gen_keys)):
            if args.prefix_len is None:
                assert hypo_lst[key] in gen_output.no_bpe_hypo[
                    gen_keys[key]], (
                        "pred and rescore hypo mismatch: i: {}, {} {} {}".format(
                            key, hypo_lst[key], gen_keys[key],
                            gen_output.no_bpe_hypo[gen_keys[key]]))
                sys_tok = word_dict.encode_line(hypo_lst[key])
                ref_tok = word_dict.encode_line(
                    gen_output.no_bpe_target[gen_keys[key]])
                if args.sacrebleu:
                    scorer.add_string(gen_output.no_bpe_target[gen_keys[key]],
                                      hypo_lst[key])
                else:
                    scorer.add(ref_tok, sys_tok)

            else:
                full_hypo = rerank_utils.get_full_from_prefix(
                    hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]])
                sys_tok = word_dict.encode_line(full_hypo)
                ref_tok = word_dict.encode_line(
                    gen_output.no_bpe_target[gen_keys[key]])
                if args.sacrebleu:
                    scorer.add_string(gen_output.no_bpe_target[gen_keys[key]],
                                      hypo_lst[key])
                else:
                    scorer.add(ref_tok, sys_tok)

        # if only one set of hyperparameters is provided, write the predictions to a file
        if write_hypos:
            # recover the original ids from n-best list generation
            for key in range(len(gen_output.no_bpe_target)):
                if args.prefix_len is None:
                    assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], \
                        "pred and rescore hypo mismatch: i: {}, {} {}".format(
                            key, hypo_lst[key],
                            gen_output.no_bpe_hypo[gen_keys[key]])
                    ordered_hypos[gen_keys[key]] = hypo_lst[key]
                    ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
                        gen_keys[key]]
                    ordered_scores[gen_keys[key]] = score_lst[key]
                    ordered_lm[gen_keys[key]] = lm_lst[key]
                    ordered_fw[gen_keys[key]] = fw_lst[key]
                    ordered_bw[gen_keys[key]] = bw_lst[key]

                else:
                    full_hypo = rerank_utils.get_full_from_prefix(
                        hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]])
                    ordered_hypos[gen_keys[key]] = full_hypo
                    ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
                        gen_keys[key]]
                    ordered_scores[gen_keys[key]] = score_lst[key]
                    ordered_lm[gen_keys[key]] = lm_lst[key]
                    ordered_fw[gen_keys[key]] = fw_lst[key]
                    ordered_bw[gen_keys[key]] = bw_lst[key]

                # print("Target = " + ordered_targets[gen_keys[key]] + " Hypothesis = " + ordered_hypos[gen_keys[key]])

    # write the hypos in the original order from nbest list generation

    if args.num_shards == len(bitext1_lst):
        with open(target_outfile, 'a') as t, \
                open(hypo_outfile, 'a') as h, \
                open(score_outfile, 'a') as s, \
                open(lm_outfile, 'a') as l, \
                open(fw_outfile, 'a') as f, \
                open(bw_outfile, 'a') as b:
            for key in range(len(ordered_hypos)):
                t.write(ordered_targets[key])
                h.write(ordered_hypos[key])
                s.write(str(ordered_scores[key]) + "\n")
                l.write(str(ordered_lm[key]) + "\n")
                f.write(str(ordered_fw[key]) + "\n")
                b.write(str(ordered_bw[key]) + "\n")

    print(scorer)
    res = scorer.result_string(4)
    if write_hypos:
        print(res)

    if args.sacrebleu:
        score = res.score
    else:
        score = rerank_utils.parse_bleu_scoring(res)
    return score
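A typical caller sweeps the interpolation weights and the length penalty and keeps the configuration with the best BLEU. A minimal sketch, with placeholder grids and output file names:

import itertools

best_bleu, best_cfg = float('-inf'), None
for a, b, c, lenpen in itertools.product([1.0], [0.0, 0.5], [0.0, 0.5],
                                         [0.6, 1.0, 1.4]):
    bleu_score = score_target_hypo(args, a, b, c, lenpen,
                                   'target.out', 'hypo.out', 'score.out',
                                   'lm.out', 'fw.out', 'bw.out',
                                   write_hypos=False, normalize=True)
    if bleu_score > best_bleu:
        best_bleu, best_cfg = bleu_score, (a, b, c, lenpen)
print('best BLEU {:.2f} with (w1, w2, w3, lenpen) = {}'.format(
    best_bleu, best_cfg))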
def main(args):
    utils.import_user_module(args)

    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def encode_fn(x):
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models])

    num_sentences = 0
    if args.buffer_size > 1:
        logger.info('Sentence buffer size: %s', args.buffer_size)
    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info('Type the input sentence and press return:')
    start_id = 0
    for line in sys.stdin:
        inputs = [line.strip()]
        results = []
        for batch in make_batches(inputs, args, task, max_positions,
                                  encode_fn):
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            tgt_tokens = batch.tgt_tokens
            num_sentences += src_tokens[0].size(0)
            if use_cuda:
                if isinstance(src_tokens, list):
                    src_tokens = [tokens.cuda() for tokens in src_tokens]
                    src_lengths = [lengths.cuda() for lengths in src_lengths]
                else:
                    src_tokens = src_tokens.cuda()
                    src_lengths = src_lengths.cuda()

            sample = {
                'net_input': {
                    'src_tokens': src_tokens,
                    'src_lengths': src_lengths,
                },
                'target': tgt_tokens,
            }

            gen_timer.start()
            translations = task.inference_step(generator, models, sample)
            num_generated_tokens = sum(
                len(h[0]['tokens']) for h in translations)
            gen_timer.stop(num_generated_tokens)

            for i, (id,
                    hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                tgt_tokens_i = None
                if tgt_tokens is not None:
                    tgt_tokens_i = utils.strip_pad(tgt_tokens[i, :],
                                                   tgt_dict.pad()).int().cpu()
                results.append(
                    (start_id + id, src_tokens_i, hypos, tgt_tokens_i))

        # sort output to match input order
        for id, src_tokens, hypos, tgt_tokens in sorted(results,
                                                        key=lambda x: x[0]):
            if src_dict is not None:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                print('S-{}\t{}'.format(id, src_str))

            if tgt_tokens is not None:
                tgt_str = tgt_dict.string(tgt_tokens,
                                          args.remove_bpe,
                                          escape_unk=True)
                print('T-{}\t{}'.format(id, tgt_str))

            # Process top predictions
            for j, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )
                hypo_str = decode_fn(hypo_str)
                score = hypo['score'] / math.log(2)  # convert to base 2
                print('H-{}\t{}\t{}'.format(id, score, hypo_str))
                print('P-{}\t{}'.format(
                    id,
                    ' '.join(
                        map(
                            lambda x: '{:.4f}'.format(x),
                            # convert from base e to base 2
                            hypo['positional_scores'].div_(math.log(2)
                                                           ).tolist(),
                        ))))
                if args.print_alignment:
                    alignment_str = " ".join(
                        ["{}-{}".format(src, tgt) for src, tgt in alignment])
                    print('A-{}\t{}'.format(id, alignment_str))
                if args.print_step:
                    print('I-{}\t{}'.format(id, hypo['steps']))
                    print('O-{}\t{}'.format(id, hypo['num_ops']))

                if getattr(args, 'retain_iter_history', False):
                    for step, h in enumerate(hypo['history']):
                        _, h_str, _ = utils.post_process_prediction(
                            hypo_tokens=h['tokens'].int().cpu(),
                            src_str=src_str,
                            alignment=None,
                            align_dict=None,
                            tgt_dict=tgt_dict,
                            remove_bpe=None,
                        )
                        print('E-{}_{}\t{}'.format(id, step, h_str))

                # Score only the top hypothesis
                if tgt_tokens is not None and j == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        tgt_tokens = tgt_dict.encode_line(
                            tgt_str, add_if_not_exist=True)
                    if hasattr(scorer, 'add_string'):
                        scorer.add_string(tgt_str, hypo_str)
                    else:
                        scorer.add(tgt_tokens, hypo_tokens)

        sys.stdout.flush()
        # update running id counter
        start_id += len(inputs)

    logger.info(
        'Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if args.has_target:
        logger.info('Generate with beam={}: {}'.format(args.beam,
                                                       scorer.result_string()))
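encode_fn applies the tokenizer and then BPE, and decode_fn undoes the two steps in the opposite order. A self-contained sketch of that round trip with dummy stand-ins (the real objects come from encoders.build_tokenizer and encoders.build_bpe):

class DummyTokenizer:
    # Stand-in for a Moses-style tokenizer.
    def encode(self, x): return x.replace(',', ' ,')
    def decode(self, x): return x.replace(' ,', ',')

class DummyBPE:
    # Stand-in for a subword model that splits rare words.
    def encode(self, x): return x.replace('translation', 'trans@@ lation')
    def decode(self, x): return x.replace('trans@@ lation', 'translation')

tokenizer, bpe = DummyTokenizer(), DummyBPE()

def encode_fn(x):
    x = tokenizer.encode(x)  # tokenize first ...
    return bpe.encode(x)     # ... then apply BPE

def decode_fn(x):
    x = bpe.decode(x)            # undo BPE first ...
    return tokenizer.decode(x)   # ... then detokenize

assert decode_fn(encode_fn('a translation, roundtripped')) == \
    'a translation, roundtripped'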
def _main(args, output_file):
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.generate')

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    avg_ranks = AverageMeter()
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        all_ents = []
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            if 'ents' in sample:
                all_ents.extend(sample['ents'])
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str),
                              file=output_file)
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str),
                              file=output_file)

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        score = hypo['score'] / math.log(
                            2)  # convert to base 2
                        print('H-{}\t{}\t{}'.format(sample_id, score,
                                                    hypo_str),
                              file=output_file)
                        print(
                            'P-{}\t{}'.format(
                                sample_id,
                                ' '.join(
                                    map(
                                        lambda x: '{:.4f}'.format(x),
                                        # convert from base e to base 2
                                        hypo['positional_scores'].div_(
                                            math.log(2)).tolist(),
                                    ))),
                            file=output_file)

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join([
                                    '{}-{}'.format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in alignment
                                ])),
                                  file=output_file)

                        if args.print_step:
                            print('I-{}\t{}'.format(sample_id, hypo['steps']),
                                  file=output_file)

                        if getattr(args, 'retain_iter_history', False):
                            for step, h in enumerate(hypo['history']):
                                _, h_str, _ = utils.post_process_prediction(
                                    hypo_tokens=h['tokens'].int().cpu(),
                                    src_str=src_str,
                                    alignment=None,
                                    align_dict=None,
                                    tgt_dict=tgt_dict,
                                    remove_bpe=None,
                                )
                                print('E-{}_{}\t{}'.format(
                                    sample_id, step, h_str),
                                      file=output_file)

                        if getattr(args, 'score_reference', False):
                            print('R-{}\t{}'.format(
                                sample_id, '{:.4f}'.format(hypo['avg_ranks'])),
                                  file=output_file)

                    # Score only the top hypothesis
                    if getattr(args, 'score_reference', False):
                        avg_ranks.update(hypo['avg_ranks'])
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)
            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info(
        'Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        logger.info('Generate {} with beam={}: {}'.format(
            args.gen_subset, args.beam, scorer.result_string()))
    if getattr(args, 'score_reference', False):
        logger.info('Average rank of reference={:.4f}, Entropy={:.4f}'.format(
            avg_ranks.avg,
            torch.cat(all_ents, dim=0).mean()))

    return scorer
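Both scorer flavors returned here expose score() and result_string(). A minimal sketch of consuming the return value, assuming args comes from the usual fairseq generation option parser:

import sys

scorer = _main(args, sys.stdout)
print('BLEU = {:.2f}'.format(scorer.score()))
print(scorer.result_string())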
Example #24
def eval(args, task, models, subset):
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True

    results = []
    hits = 0.
    hit_vector = []

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                out_dict = {}
                if src_dict is not None:
                    out_dict['source'] = src_str
                if has_target:
                    out_dict['gold_target'] = target_str

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    out_dict['pred_target'] = hypo_str
                    # Per-hypothesis logging (H-/P-/A-/I-/E- lines) is
                    # disabled in this variant; see the generate examples
                    # above for the full printing code.

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

                    results.append(out_dict)
                    if has_target and out_dict['gold_target'] == out_dict['pred_target']:
                        hit_vector.append(1)
                        hits += 1
                    else:
                        hit_vector.append(0)

            wps_meter.update(num_generated_tokens)
            # t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

        print(
            '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
            .format(num_sentences, gen_timer.n, gen_timer.sum,
                    num_sentences / gen_timer.sum, 1. / gen_timer.avg))
        if has_target:
            print('| Generate {} with beam={}: {}'.format(
                args.gen_subset, args.beam, scorer.result_string()))

        acc = hits / len(results)
        print("Hit {}/{}. ACC: {}".format(hits, len(results), acc))

        return acc, hit_vector
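A caller might run eval over several subsets and compare exact-match accuracy; the subset names below are placeholders:

for subset in ['valid', 'test']:
    acc, hit_vector = eval(args, task, models, subset)
    print('{}: accuracy {:.4f} over {} predictions'.format(
        subset, acc, len(hit_vector)))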
Example #25
def _main(args, output_file):
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.generate')

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(models, args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        prefix_tokens = None
        if args.prefix_size > 0:
            prefix_tokens = sample['target'][:, :args.prefix_size]

        gen_timer.start()
        hypos = task.inference_step(generator, models, sample, prefix_tokens)
        num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample['id'].tolist()):
            has_target = sample['target'] is not None

            # Remove padding
            src_tokens = utils.strip_pad(
                sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
            target_tokens = None
            if has_target:
                target_tokens = utils.strip_pad(sample['target'][i, :],
                                                tgt_dict.pad()).int().cpu()

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(
                    args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(
                    args.gen_subset).tgt.get_original_text(sample_id)
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, args.remove_bpe)
                else:
                    src_str = ""
                if has_target:
                    target_str = tgt_dict.string(target_tokens,
                                                 args.remove_bpe,
                                                 escape_unk=True,
                                                 extra_symbols_to_ignore={
                                                     generator.eos,
                                                 })

            src_str = decode_fn(src_str)
            if has_target:
                target_str = decode_fn(target_str)

            if not args.quiet:
                if src_dict is not None:
                    print('S-{}\t{}'.format(sample_id, src_str),
                          file=output_file)
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str),
                          file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:args.nbest]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                    extra_symbols_to_ignore={
                        generator.eos,
                    })
                detok_hypo_str = decode_fn(hypo_str)
                if not args.quiet:
                    score = hypo['score'] / math.log(2)  # convert to base 2
                    # original hypothesis (after tokenization and BPE)
                    print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str),
                          file=output_file)
                    # detokenized hypothesis
                    print('D-{}\t{}\t{}'.format(sample_id, score,
                                                detok_hypo_str),
                          file=output_file)
                    print(
                        'P-{}\t{}'.format(
                            sample_id,
                            ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    # convert from base e to base 2
                                    hypo['positional_scores'].div_(math.log(2)
                                                                   ).tolist(),
                                ))),
                        file=output_file)

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id, ' '.join([
                                '{}-{}'.format(src_idx, tgt_idx)
                                for src_idx, tgt_idx in alignment
                            ])),
                              file=output_file)

                    if args.print_step:
                        print('I-{}\t{}'.format(sample_id, hypo['steps']),
                              file=output_file)

                    if 'enc_selection' in hypo:
                        print('Menc-{}\t{}'.format(sample_id,
                                                   hypo['enc_selection']),
                              file=output_file)
                    if 'dec_selection' in hypo:
                        print('Mdec-{}\t{}'.format(sample_id,
                                                   hypo['dec_selection']),
                              file=output_file)
                    if args.print_attn_confidence:
                        print('C-{}\t{}'.format(sample_id,
                                                hypo['enc_self_attn_conf']),
                              file=output_file)

                    if getattr(args, 'retain_iter_history', False):
                        for step, h in enumerate(hypo['history']):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h['tokens'].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print('E-{}_{}\t{}'.format(sample_id, step, h_str),
                                  file=output_file)

                # Score only the top hypothesis
                if has_target and j == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tgt_dict.encode_line(
                            target_str, add_if_not_exist=True)
                        hypo_tokens = tgt_dict.encode_line(
                            detok_hypo_str, add_if_not_exist=True)
                    if hasattr(scorer, 'add_string'):
                        scorer.add_string(target_str, detok_hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({'wps': round(wps_meter.avg)})
        num_sentences += sample['nsentences']

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info(
        'Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        if args.bpe and not args.sacrebleu:
            if args.remove_bpe:
                logger.warning(
                    "BLEU score is being computed by splitting detokenized string on spaces; this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
                )
            else:
                logger.warning(
                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization"
                )
        logger.info('Generate {} with beam={}: {}'.format(
            args.gen_subset, args.beam, scorer.result_string()))

    return scorer
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Set up functions for multiturn
    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
        vocab = task.load_dictionary(dict_path(lang))
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        offsets = Tokenizer.find_offsets(input_file, num_workers)
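        # find_offsets returns byte offsets that split the input file into
        # num_workers contiguous chunks; worker w binarizes the span
        # [offsets[w], offsets[w + 1]), and the main process handles chunk 0
        # itself below, so no line is processed twice.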
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        vocab,
                        prefix,
                        lang,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin"))
        merge_result(
            Tokenizer.binarize(input_file,
                               vocab,
                               lambda t: ds.add_item(t),
                               offset=0,
                               end=offsets[1]))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        print("| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
            lang,
            input_file,
            n_seq_tok[0],
            n_seq_tok[1],
            100 * sum(replaced.values()) / n_seq_tok[1],
            vocab.unk_word,
        ))

    def make_dataset(input_prefix, output_prefix, lang, num_workers=1):
        if args.output_format == "binary":
            make_binary_dataset(input_prefix, output_prefix, lang, num_workers)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix +
                ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

    def make_all(lang):
        if args.multiturnpref:
            make_dataset(args.multiturnpref,
                         "test",
                         lang,
                         num_workers=args.workers)

    # Load dataset splits
    task = tasks.setup_task(args)

    # Multiturn tracking: prompt in test set, turn in debate
    turn = 0
    prompt = 1
    first_pass = True
    while first_pass or args.multiturn:
        if args.multiturn:
            # Set up first turn
            if turn == 0:
                multiturn_file = "{}{}".format(args.multiturnpref,
                                               ("." + args.source_lang))
                test_file = "{}{}".format(args.testpref,
                                          ("." + args.source_lang))
                if args.interactive:
                    line = input('What subject would you like to debate? ')
                else:
                    with open(test_file, 'r', encoding='utf-8') as f:
                        for i in range(prompt):
                            line = f.readline()
                with open(multiturn_file, 'w', encoding='utf-8') as f:
                    f.write(line)
                prompt += 1

            target = not args.only_source
            assert args.multiturnpref, "--multiturnpref must be set"
            if args.joined_dictionary:
                assert (
                    not args.srcdict or not args.tgtdict
                ), "cannot use both --srcdict and --tgtdict with --joined-dictionary"

                if args.srcdict:
                    src_dict = task.load_dictionary(args.srcdict)
                elif args.tgtdict:
                    src_dict = task.load_dictionary(args.tgtdict)
                else:
                    assert (
                        args.trainpref
                    ), "--trainpref must be set if --srcdict is not specified"
                    src_dict = build_dictionary(
                        {
                            train_path(lang)
                            for lang in [args.source_lang, args.target_lang]
                        },
                        src=True)
                tgt_dict = src_dict
            else:
                if args.srcdict:
                    src_dict = task.load_dictionary(args.srcdict)
                else:
                    assert (
                        args.trainpref
                    ), "--trainpref must be set if --srcdict is not specified"
                    src_dict = build_dictionary([train_path(args.source_lang)],
                                                src=True)
            if target:
                if args.tgtdict:
                    tgt_dict = task.load_dictionary(args.tgtdict)
                else:
                    assert (
                        args.trainpref
                    ), "--trainpref must be set if --tgtdict is not specified"
                    tgt_dict = build_dictionary([train_path(args.target_lang)],
                                                tgt=True)
            else:
                tgt_dict = None

            src_dict.save(dict_path(args.source_lang))
            if target and tgt_dict is not None:
                tgt_dict.save(dict_path(args.target_lang))

            make_all(args.source_lang)
            if target:
                make_all(args.target_lang)
            if first_pass:
                print("| Wrote preprocessed data to {}".format(args.destdir))
                print('| Generating multiturn debate')
            task.load_dataset('test')
        else:
            task.load_dataset(args.gen_subset)
            print('| {} {} {} examples'.format(
                args.data, args.gen_subset,
                len(task.dataset(args.gen_subset))))

        if first_pass:
            # Set dictionaries
            src_dict = task.source_dictionary
            tgt_dict = task.target_dictionary

            # Load ensemble
            print('| loading model(s) from {}'.format(args.path))
            models, _model_args = utils.load_ensemble_for_inference(
                args.path.split(':'),
                task,
                model_arg_overrides=eval(args.model_overrides),
            )

            # Optimize ensemble for generation
            for model in models:
                model.make_generation_fast_(
                    beamable_mm_beam_size=None
                    if args.no_beamable_mm else args.beam,
                    need_attn=args.print_alignment,
                )
                if args.fp16:
                    model.half()

        # Load alignment dictionary for unknown word replacement
        # (None if no unknown word replacement, empty if no path to align dictionary)
        align_dict = utils.load_align_dict(args.replace_unk)

        # Load dataset (possibly sharded)
        itr = task.get_batch_iterator(
            dataset=task.dataset(args.gen_subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[model.max_positions() for model in models]),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            num_shards=args.num_shards,
            shard_id=args.shard_id,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)

        # Initialize generator
        gen_timer = StopwatchMeter()
        if args.score_reference:
            translator = SequenceScorer(models, task.target_dictionary)
        else:
            translator = SequenceGenerator(
                models,
                task.target_dictionary,
                beam_size=args.beam,
                minlen=args.min_len,
                stop_early=(not args.no_early_stop),
                normalize_scores=(not args.unnormalized),
                len_penalty=args.lenpen,
                unk_penalty=args.unkpen,
                sampling=args.sampling,
                sampling_topk=args.sampling_topk,
                sampling_temperature=args.sampling_temperature,
                diverse_beam_groups=args.diverse_beam_groups,
                diverse_beam_strength=args.diverse_beam_strength,
                match_source_len=args.match_source_len,
                no_repeat_ngram_size=args.no_repeat_ngram_size,
            )

        if use_cuda:
            translator.cuda()

        # Generate and compute BLEU score
        if args.sacrebleu:
            scorer = bleu.SacrebleuScorer()
        else:
            scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(),
                                 tgt_dict.unk())
        num_sentences = 0
        has_target = True
        with progress_bar.build_progress_bar(args, itr) as t:
            if args.score_reference:
                translations = translator.score_batched_itr(t,
                                                            cuda=use_cuda,
                                                            timer=gen_timer)
            else:
                translations = translator.generate_batched_itr(
                    t,
                    maxlen_a=args.max_len_a,
                    maxlen_b=args.max_len_b,
                    cuda=use_cuda,
                    timer=gen_timer,
                    prefix_size=args.prefix_size,
                )

            wps_meter = TimeMeter()
            for sample_id, src_tokens, target_tokens, hypos in translations:

                # Process input and ground truth
                has_target = target_tokens is not None
                target_tokens = (target_tokens.int().cpu()
                                 if has_target else None)

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    src_str = src_dict.string(src_tokens, args.remove_bpe)
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu()
                        if hypo['alignment'] is not None else None,
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                    hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join(
                                    map(lambda x: str(utils.item(x)),
                                        alignment))))

                    if args.multiturn:
                        multiturn_file = "{}{}".format(
                            args.multiturnpref, ("." + args.source_lang))
                        output_file = "{}{}".format(args.outputpref,
                                                    ("." + args.target_lang))
                        with open(multiturn_file, 'r', encoding='utf-8') as f:
                            # Strip the newline on read so the appended
                            # interactive response is never truncated.
                            line = f.readline().rstrip('\n')
                            if args.interactive:
                                interactive_response = input('Please respond: ')
                                line += f' <EOA> {interactive_response}'
                        if turn < MAX_TURNS - 1:
                            with open(multiturn_file, 'w',
                                      encoding='utf-8') as f:
                                f.write(f'{line} <EOA> {hypo_str}')
                            turn += 1
                        elif turn == MAX_TURNS - 1:
                            with open(output_file, 'a', encoding='utf-8') as f:
                                f.write(f'{line} <EOA> {hypo_str}\n')
                            turn = 0

                    # Score only the top hypothesis
                    if has_target and i == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tokenizer.Tokenizer.tokenize(
                                target_str, tgt_dict, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

                wps_meter.update(src_tokens.size(0))
                t.log({'wps': round(wps_meter.avg)})
                num_sentences += 1

        print(
            '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
            .format(num_sentences, gen_timer.n, gen_timer.sum,
                    num_sentences / gen_timer.sum, 1. / gen_timer.avg))
        if has_target:
            print('| Generate {} with beam={}: {}'.format(
                args.gen_subset, args.beam, scorer.result_string()))

        first_pass = False
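
Note: the multiturn loop above round-trips a single growing line through multiturn_file, appending the model's hypothesis after an <EOA> marker on every turn and flushing the finished debate to the output file once MAX_TURNS is reached. A minimal standalone sketch of that file protocol (the MAX_TURNS value, file names, and the advance_turn helper are illustrative assumptions, not taken from the example):

MAX_TURNS = 3

def advance_turn(turn, hypo_str, multiturn_file='multiturn.src',
                 output_file='debates.out'):
    # Read the debate so far: one growing line of <EOA>-joined turns.
    with open(multiturn_file, 'r', encoding='utf-8') as f:
        line = f.readline().rstrip('\n')
    if turn < MAX_TURNS - 1:
        # Intermediate turn: extend the line for the next generation pass.
        with open(multiturn_file, 'w', encoding='utf-8') as f:
            f.write(f'{line} <EOA> {hypo_str}')
        return turn + 1
    # Final turn: archive the completed debate and reset the counter.
    with open(output_file, 'a', encoding='utf-8') as f:
        f.write(f'{line} <EOA> {hypo_str}\n')
    return 0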
Example #27
0
def _generate_score(models, args, task, dataset, modify_target_dict):
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load ensemble
    if not args.quiet:
        print("| loading model(s) from {}".format(", ".join(args.path.split(":"))))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=True,
        )

    translator = build_sequence_generator(args, task, models)
    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    print("seed number is" + str(args.max_examples_to_evaluate_seed))
    if args.max_examples_to_evaluate > 0:
        pytorch_translate_data.subsample_pair_dataset(
            dataset, args.max_examples_to_evaluate, args.max_examples_to_evaluate_seed
        )

    # Keep track of translations, initialized with empty strings
    # and zero probability scores
    translated_sentences = [""] * len(dataset)
    translated_scores = [0.0] * len(dataset)
    hypos_list = []

    collect_output_hypos = getattr(args, "output_hypos_binary_path", False)
    if collect_output_hypos:
        output_hypos_token_arrays = [None] * len(dataset)

    # Generate and compute BLEU score
    dst_dict = task.target_dictionary
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(dst_dict.pad(), dst_dict.eos(), dst_dict.unk())

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    oracle_scorer = None
    if args.report_oracle_bleu:
        oracle_scorer = bleu.Scorer(dst_dict.pad(), dst_dict.eos(), dst_dict.unk())

    rescorer = None
    num_sentences = 0
    translation_samples = []
    translation_info_list = []
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        gen_timer = StopwatchMeter()
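        # prefix_size=1 below seeds decoding with the first reference token
        # (presumably the target-language marker) when the setup is
        # multilingual many-to-one; otherwise no prefix is forced.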
        translations = translator.generate_batched_itr(
            t,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda=use_cuda,
            timer=gen_timer,
            prefix_size=1
            if pytorch_translate_data.is_multilingual_many_to_one(args)
            else 0,
        )

        for trans_info in _iter_translations(
            args, task, dataset, translations, align_dict, rescorer, modify_target_dict
        ):
            if hasattr(scorer, "add_string"):
                scorer.add_string(trans_info.target_str, trans_info.hypo_str)
            else:
                scorer.add(trans_info.target_tokens, trans_info.hypo_tokens)
            if oracle_scorer is not None:
                oracle_scorer.add(trans_info.target_tokens, trans_info.best_hypo_tokens)

            if getattr(args, "translation_output_file", False):
                translated_sentences[trans_info.sample_id] = trans_info.hypo_str
            if getattr(args, "translation_probs_file", False):
                translated_scores[trans_info.sample_id] = trans_info.hypo_score
            if getattr(args, "hypotheses_export_path", False):
                hypos_list.append(trans_info.hypos)
            if collect_output_hypos:
                output_hypos_token_arrays[
                    trans_info.sample_id
                ] = trans_info.best_hypo_tokens
            if args.translation_info_export_path is not None:
                # Strip expensive data from hypotheses before saving
                hypos = [
                    {k: v for k, v in hypo.items() if k in ["tokens", "score"]}
                    for hypo in trans_info.hypos
                ]
                # Make sure everything is on cpu before exporting
                hypos = [
                    {"score": hypo["score"], "tokens": hypo["tokens"].cpu()}
                    for hypo in hypos
                ]
                translation_info_list.append(
                    {
                        "src_tokens": trans_info.src_tokens.cpu(),
                        "target_tokens": trans_info.target_tokens,
                        "hypos": hypos,
                    }
                )
            translation_samples.append(
                collections.OrderedDict(
                    {
                        "sample_id": trans_info.sample_id.item(),
                        "src_str": trans_info.src_str,
                        "target_str": trans_info.target_str,
                        "hypo_str": trans_info.hypo_str,
                    }
                )
            )
            wps_meter.update(trans_info.src_tokens.size(0))
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += 1

    # If applicable, save collected hypothesis tokens to binary output file
    if collect_output_hypos:
        output_dataset = pytorch_translate_data.InMemoryIndexedDataset()
        output_dataset.load_from_sequences(output_hypos_token_arrays)
        output_dataset.save(args.output_hypos_binary_path)
    if args.output_source_binary_path:
        dataset.src.save(args.output_source_binary_path)
    if args.translation_info_export_path is not None:
        with open(args.translation_info_export_path, "wb") as f:
            pickle.dump(translation_info_list, f)

    # If applicable, save the translations and scores to the output files.
    # These two outputs are used in dual learning for weighted backtranslation.
    if getattr(args, "translation_output_file", False) and getattr(
        args, "translation_probs_file", False
    ):
        with open(args.translation_output_file, "w") as translation_file, open(
            args.translation_probs_file, "w"
        ) as score_file:
            for hypo_str, hypo_score in zip(translated_sentences, translated_scores):
                if len(hypo_str.strip()) > 0:
                    print(hypo_str, file=translation_file)
                    print(np.exp(hypo_score), file=score_file)

    # E.g., for external evaluation
    if getattr(args, "hypotheses_export_path", False):
        with open(args.hypotheses_export_path, "w") as out_file:
            for hypos in hypos_list:
                for hypo in hypos:
                    print(
                        task.tgt_dict.string(
                            hypo["tokens"], bpe_symbol=args.remove_bpe
                        ),
                        file=out_file,
                    )

    if oracle_scorer is not None:
        print(f"| Oracle BLEU (best hypo in beam): {oracle_scorer.result_string()}")

    return scorer, num_sentences, gen_timer, translation_samples
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    import_user_module(args)
    """
    MODIFIED: The GEC task uses token-labeled raw text datasets, which 
    require raw text to be used.
    """
    assert args.raw_text, \
        f"--raw-text option is required for copy-based generation."

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'),
        task,
        model_arg_overrides=eval(args.model_overrides),
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    has_copy_scores = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            """
            MODIFIED: Use copy scores to replace <unk>'s with raw source words.
            
            use_copy_scores may be False with non-copy-based transformers that
            only use edit labels (e.g., transformer_aux_el and transformer_el).
            """
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            use_copy_scores = hypos[0][0].get('copy_scores', None) is not None
            if has_copy_scores and not use_copy_scores:
                print("| generate_or_copy.py | INFO | "
                      "Model does not include copy scores. "
                      "Generating hypotheses without replacing UNKs.")
                has_copy_scores = False
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()
                """
                MODIFIED: Replace <unk>s with raw source tokens. 
                This is analogous to the case where align_dict is provided
                in the original generate.py.
                """
                rawtext_dataset = task.dataset(args.gen_subset)
                src_str = rawtext_dataset.src.get_original_text(sample_id)
                tokenized_src_str = rawtext_dataset.src_dict.string(
                    src_tokens, bpe_symbol=args.remove_bpe)
                target_str = rawtext_dataset.tgt.get_original_text(sample_id)

                if not args.quiet:
                    if src_dict is not None:
                        # Raw source text
                        print('S-{}\t{}'.format(sample_id, src_str))
                        # Tokenized source text
                        print('K-{}\t{}'.format(sample_id, tokenized_src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                for k, hypo in enumerate(
                        hypos[i][:min(len(hypos[i]), args.nbest)]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu()
                        if hypo['alignment'] is not None else None,
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )
                    """
                    MODIFIED: Replace predicted <unk>s with the source token
                    that received the highest score.
                    """
                    raw_src_tokens = src_str.split()
                    final_hypo_tokens_str = []
                    for tgt_position, hypo_token in enumerate(hypo_tokens):
                        if use_copy_scores and hypo_token == tgt_dict.unk():
                            # See sequence_copygenerator.py#L292 for details.
                            copy_scores = hypo[
                                'copy_scores'][:, tgt_position].cpu()
                            assert len(copy_scores) - 1 == len(raw_src_tokens), \
                                f"length of copy scores does not match input source tokens " \
                                f"(copy_scores: {copy_scores}, raw_src_tokens: {raw_src_tokens})"
                            src_position = torch.argmax(copy_scores).item()
                            # Don't copy if attending to an EOS (not ideal).
                            if src_position == len(raw_src_tokens):
                                print("WARNING: copy score highest at EOS.")
                            else:
                                final_hypo_tokens_str.append(
                                    raw_src_tokens[src_position])
                            print('U-{}\t{}\t{}'.format(
                                sample_id,
                                tgt_position,
                                ' '.join(
                                    map(
                                        lambda x: '{:.4f}'.format(x),
                                        copy_scores.tolist(),
                                    )),
                            ))
                        else:
                            final_hypo_tokens_str.append(tgt_dict[hypo_token])

                    # Note: raw input tokens could be included here.
                    final_hypo_str = ' '.join([
                        token for token in final_hypo_tokens_str
                        if token != tgt_dict.eos_word
                    ])

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                    final_hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join(
                                    map(lambda x: str(utils.item(x)),
                                        alignment))))

                    # Score only the top hypothesis
                    if has_target and k == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))
    return scorer
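
Note: the copy mechanism in the example above replaces each predicted <unk> with the source token whose copy score is highest at that target position, and warns instead of copying when the argmax lands on the trailing EOS slot. A toy reproduction of the argmax step (the scores below are made up; the one-extra-column-for-EOS shape follows the assert in the code):

import torch

raw_src_tokens = ['the', 'cat', 'sat']
# One copy score per source token plus a trailing EOS slot.
copy_scores = torch.tensor([0.10, 0.70, 0.15, 0.05])

src_position = torch.argmax(copy_scores).item()
if src_position == len(raw_src_tokens):
    print('WARNING: copy score highest at EOS.')
else:
    print(raw_src_tokens[src_position])  # -> 'cat'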
Example #29
0
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    tgt_file = None
    hypo_file = None
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
        tgt_fn = os.path.join(args.output_dir, 'gold')
        hypo_fn = os.path.join(args.output_dir, 'candidate')
        tgt_file = open(tgt_fn, 'w', encoding='utf-8')
        hypo_file = open(hypo_fn, 'w', encoding='utf-8')

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'),
        task,
        model_arg_overrides=eval(args.model_overrides),
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(
                            sample_id, src_str.encode(encoding='utf-8')))
                    if has_target:
                        print('T-{}\t{}'.format(
                            sample_id, target_str.encode(encoding='utf-8')))

                # Process top predictions
                for j, hypo in enumerate(
                        hypos[i][:min(len(hypos[i]), args.nbest)]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu()
                        if hypo['alignment'] is not None else None,
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(
                            sample_id, hypo['score'],
                            hypo_str.encode(encoding='utf-8')))
                        print('P-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join(
                                    map(lambda x: str(utils.item(x)),
                                        alignment))))

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

                        if args.output_dir:
                            tgt_file.write(target_str + '\n')
                            hypo_file.write(hypo_str + '\n')

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    # Only close the output files if --output-dir was set.
    if tgt_file is not None:
        tgt_file.close()
    if hypo_file is not None:
        hypo_file.close()

    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))
    return scorer
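
Note: Example #29 mirrors the references and hypotheses into output_dir/gold and output_dir/candidate, which makes offline rescoring straightforward. A minimal sketch of scoring those two files afterwards (using the sacrebleu package directly is an assumption about the environment, not something the example itself does):

import sacrebleu

with open('gold', encoding='utf-8') as ref_f, \
        open('candidate', encoding='utf-8') as hyp_f:
    refs = [line.rstrip('\n') for line in ref_f]
    hyps = [line.rstrip('\n') for line in hyp_f]

# corpus_bleu applies SacreBLEU's standard 13a tokenization internally.
print(sacrebleu.corpus_bleu(hyps, [refs]).score)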
Example #30
0
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset, args=args)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    args.unk_idx = task.src_dict.indices['<unk>']
    args.dict_len = len(task.src_dict.indices)
    # Look up optional special symbols, defaulting to -1 when absent.
    for attr, symbol in [('APPEND_ID', '[APPEND]'), ('SRC_ID', '[SRC]'),
                         ('TGT_ID', '[TGT]'), ('SEP_ID', '[SEP]'),
                         ('EOS_ID', '</s>'), ('PAD_ID', '<pad>')]:
        if symbol in task.src_dict.indices:
            setattr(args, attr, task.src_dict.indices[symbol])
            if symbol.startswith('['):
                print("{} ID: {}".format(symbol, getattr(args, attr)))
        else:
            setattr(args, attr, -1)

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'),
        task,
        model_arg_overrides=eval(args.model_overrides),
    )
    _model_args.avgpen = args.avgpen
    task.datasets['test'].args = _model_args

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    select_retrieve_tokens = []
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        trans_results = []
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos, encoder_outs = task.inference_step(generator, models,
                                                      sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Interleave the retrieved source/target segments with the
                # query tokens, then remove padding
                src_tokens, retrieve_source_tokens, retrieve_target_tokens = sample[
                    'net_input']['src_tokens']
                retrieve_tokens = list(
                    itertools.chain.from_iterable(
                        zip(retrieve_source_tokens, retrieve_target_tokens)))
                retrieve_tokens = torch.cat(retrieve_tokens, dim=1)
                all_tokens = torch.cat([src_tokens, retrieve_tokens], dim=1)
                src_tokens = utils.strip_pad(all_tokens[i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # add select tokens
                select_retrieve_tokens.append([
                    sample_id, src_str, target_str,
                    sample['predict_ground_truth'][i, :],
                    retrieve_tokens[i, :],
                    encoder_outs[0]['new_retrieve_tokens'][i, :],
                    utils.strip_pad(retrieve_tokens[i, :],
                                    src_dict.pad()).tolist(),
                    utils.strip_pad(
                        encoder_outs[0]['new_retrieve_tokens'][i, :],
                        src_dict.pad()).tolist()
                ])
                # Process top predictions
                for j, hypo in enumerate(
                        hypos[i][:min(len(hypos[i]), args.nbest)]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu()
                        if hypo['alignment'] is not None else None,
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    trans_results.append((sample_id, hypo_str))
                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                    hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))))
                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join(
                                    map(lambda x: str(utils.item(x)),
                                        alignment))))

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))

    trans_results.sort(key=lambda key: key[0])
    print("saving translation result to {}...".format(args.output))
    with open(args.output, "w", encoding="utf-8") as w:
        for item in trans_results:
            w.write("{}\n".format(item[1].replace("<<unk>>", "")))
    select_retrieve_tokens.sort(key=lambda key: key[0])
    orig_retrieve_tokens_length = 0
    select_retrieve_tokens_length = 0
    correct_tokens = 0
    with open(args.output + ".select", "w", encoding="utf-8") as w_select:
        for item in select_retrieve_tokens:
            sample_id, src_str, target_str, sample_predict_ground_truth, sample_orig_id, sample_select_retrieve_id, sample_orig_retrieve_tokens, sample_select_retrieve_tokens = item
            retrieve_str = src_dict.string(sample_orig_retrieve_tokens,
                                           args.remove_bpe)
            select_str = src_dict.string(sample_select_retrieve_tokens,
                                         args.remove_bpe)
            w_select.write("{}\n{}\n{}\n{}\n\n".format(src_str, target_str,
                                                       retrieve_str,
                                                       select_str))
            orig_retrieve_tokens_length += len(sample_orig_retrieve_tokens)
            select_retrieve_tokens_length += len(sample_select_retrieve_tokens)
            # Calculate selection accuracy against the ground-truth mask
            correct_tokens += (
                (sample_select_retrieve_id != _model_args.PAD_ID
                 ).long() == sample_predict_ground_truth).masked_fill(
                     (sample_orig_id == _model_args.PAD_ID).byte(), 0).sum()

    ratio = select_retrieve_tokens_length / float(orig_retrieve_tokens_length)
    accuracy = correct_tokens.item() / float(orig_retrieve_tokens_length)
    print("Selective Tokens: {}".format(ratio))
    print("Correct Tokens: {}".format(accuracy))

    with open("{}.RetrieveNMT.BLEU".format(args.output), "a",
              encoding="utf-8") as w:
        w.write(
            '{}->{}: Generate {} with beam={} and lenpen={}: {};\tSelection Ratio: {};\tAccuracy: {}\n'
            .format(args.source_lang, args.target_lang,
                    args.gen_subset, args.beam, args.lenpen,
                    scorer.result_string(), ratio, accuracy))

    return scorer
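
Note: the selection bookkeeping at the end of Example #30 treats every non-pad retrieved token as a keep/drop decision and counts how often the selector's decision matches the ground-truth mask. A toy reproduction of that masked comparison (PAD_ID and all tensors below are illustrative assumptions):

import torch

PAD_ID = 1
orig_ids = torch.tensor([5, 9, 4, PAD_ID])         # retrieved tokens, right-padded
select_ids = torch.tensor([5, PAD_ID, 4, PAD_ID])  # tokens kept by the selector
ground_truth = torch.tensor([1, 0, 1, 0])          # 1 = token should be kept

predicted_keep = (select_ids != PAD_ID).long()
# Padding positions in the original retrieval never count as decisions.
correct = ((predicted_keep == ground_truth)
           .masked_fill(orig_ids == PAD_ID, 0)
           .sum())
print(correct.item())  # -> 3 of the 3 real tokens judged correctly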