Example #1
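# Assumed imports for this snippet (older fairseq API; module paths may differ by version):
import numpy as np
import torch

from fairseq import data, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
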
def main(args):
    assert args.path is not None, '--path required for evaluation!'

    args.tokens_per_sample = getattr(args, 'tokens_per_sample', 1024)
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task)

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()

    itr = data.EpochBatchIterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences or 4,
        max_positions=models[0].max_positions(),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        ignore_invalid_inputs=True,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

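    # Accumulate per-token log-probabilities so we can report average NLL and perplexity.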
    score_sum = 0.
    count = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']
                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum()
                count += pos_scores.numel()
            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))
Example #2
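# Assumed imports for this snippet (older fairseq API; _decode_filename and
# _get_sorted_inputs are helpers defined elsewhere in the same source file):
import time

import numpy as np

from fairseq import bleu, data, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer
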
def decode_from_file(models, task, args, use_cuda, source_filename=None,
                     target_filename=None, output_filename=None):
    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # I/O files
    source_filename = source_filename if source_filename is not None else args.decode_source_file
    target_filename = target_filename if target_filename is not None else args.decode_target_file
    output_filename = output_filename if output_filename is not None else args.decode_output_file
    if output_filename is not None:
        base_filename = output_filename
    else:
        base_filename = source_filename
        if args.num_shards:
            base_filename += "%.2d" % args.shard_id
    decode_filename = _decode_filename(base_filename, args)
    outfile = open(decode_filename, "w")
    if args.decode_to_file:
        print("| [decode] writing decodes into {}".format(decode_filename))

    # Get sorted input (and reversed)
    sorted_inputs, sorted_keys, sorted_targets = _get_sorted_inputs(
        source_filename, args.num_shards, args.delimiter, target_filename, args.shard_id)

    # Build input iterator
    src_tokens = [
        tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
        for src_str in sorted_inputs]
    src_sizes = np.array([t.numel() for t in src_tokens])
    tgt_tokens = [
        tokenizer.Tokenizer.tokenize(tgt_str, tgt_dict, add_if_not_exist=False).long()
        for tgt_str in sorted_targets] if sorted_targets is not None else None
    tgt_sizes = np.array([t.numel() for t in tgt_tokens]) if tgt_tokens is not None else None
    print('| loading {} examples, {} tokens'.format(len(sorted_inputs), sum(src_sizes)))

    dataset = data.LanguagePairDataset(
        src_tokens, src_sizes, src_dict, tgt_tokens, tgt_sizes, tgt_dict, shuffle=False)
    itr = data.EpochBatchIterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=models[0].max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models, task.target_dictionary, beam_size=args.beam,
            stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen, unk_penalty=args.unkpen,
            sampling=args.sampling, sampling_topk=args.sampling_topk, minlen=args.min_len,
        )

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True

    if args.score_reference:
        translations = translator.score_batched_itr(itr, cuda=use_cuda, timer=gen_timer)
    else:
        translations = translator.generate_batched_itr(
            itr, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
            cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
        )

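    # Map each sample id to its best hypothesis so decodes can be written out in the original input order.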
    decodes = dict()
    sids = []
    wps_meter = TimeMeter()
    start = time.perf_counter()
    for sample_id, src_tokens, target_tokens, hypos in translations:
        # Process input and ground truth
        has_target = target_tokens is not None
        target_tokens = target_tokens.int().cpu() if has_target else None

        # Either retrieve the original sentences or regenerate them from tokens.
        if align_dict is not None:
            src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
            target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
        else:
            src_str = src_dict.string(src_tokens, args.remove_bpe)
            if has_target:
                target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

        if not args.quiet:
            try:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))
            except UnicodeEncodeError:
                print('S-{}\t{}'.format(sample_id, src_str.encode('utf-8')))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str.encode('utf-8')))

        # Process top predictions
        for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu(),
                align_dict=align_dict,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )
            if i == 0:
                decodes[sample_id.tolist()] = hypo_str
                # sids.append(sample_id.tolist())

            if not args.quiet:
                try:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                except UnicodeEncodeError:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str.encode('utf-8')))
                print('P-{}\t{}'.format(
                    sample_id,
                    ' '.join(map(
                        lambda x: '{:.4f}'.format(x),
                        hypo['positional_scores'].tolist(),
                    ))
                ))
                print('A-{}\t{}'.format(
                    sample_id,
                    ' '.join(map(lambda x: str(utils.item(x)), alignment))
                ))

            # Score only the top hypothesis
            if has_target and i == 0:
                if align_dict is not None or args.remove_bpe is not None:
                    # Convert back to tokens for evaluation with unk replacement and/or without BPE
                    target_tokens = tokenizer.Tokenizer.tokenize(
                        target_str, tgt_dict, add_if_not_exist=True)
                scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(src_tokens.size(0))

        num_sentences += 1
        if args.quiet and num_sentences % 100 == 0:
            print("| {} / {} sentences decoded ({})".format(num_sentences, len(sorted_inputs), len(decodes)))

    used_time = time.perf_counter() - start
    print("| Used time:" + repr(used_time))
    print("| Average time:" + repr(used_time / len(sorted_inputs)))

    if args.decode_to_file:
        print("| [decode] writing decodes into {}".format(decode_filename))
        # print(sids)
        for index in range(len(sorted_inputs)):
            try:
                outfile.write("{}{}".format(decodes[sorted_keys[index]], args.delimiter))
            except UnicodeEncodeError:
                outfile.write("{}{}".format(decodes[sorted_keys[index]].encode('utf-8'), args.delimiter))
        outfile.close()

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
Example #3
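# Assumed imports for this snippet (older fairseq API). SequenceGeneratorWCS,
# SequenceGeneratorWCSSepahypo, processFlatHypo and make_html_safe are
# project-specific and come from the surrounding codebase, not core fairseq:
import os

import torch

from fairseq import bleu, data, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer
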
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task)

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

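    # Optionally read a list of sentence indices that should be excluded from generation.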
    ignoredIndices = []
    if args.outindices:
        with open(args.outindices, 'r') as f:
            for line in f:
                ignoredIndices.append(int(line.strip()))
    print("{} indices to be ignored from validation set.".format(
        len(ignoredIndices)))

    # Load dataset (possibly sharded)
    itr = data.EpochBatchIterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=models[0].max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        savedir=os.path.join(args.decode_dir, "valid_"),
        ignoredIndices=ignoredIndices,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    elif args.sepahypo:
        translator = SequenceGeneratorWCSSepahypo(
            models,
            task.target_dictionary,
            beam_size=args.beam,
            stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen,
            unk_penalty=args.unkpen,
            sampling=args.sampling,
            sampling_topk=args.sampling_topk,
            minlen=args.min_len,
            maxlen=None,
            context=args.context,
            ngram=args.ngram,
            naive=args.naive,
            num_topics=args.num_topics,
            flatenc=args.flatenc,
            flatten_source=args.flatten_source,
            cov_penalty=args.covpen,
            keystop=args.keystop,
        )
    elif args.flatdec:
        translator = SequenceGenerator(
            models,
            task.target_dictionary,
            beam_size=args.beam,
            stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen,
            unk_penalty=args.unkpen,
            sampling=args.sampling,
            sampling_topk=args.sampling_topk,
            minlen=args.min_len,
            flatdec=True,
        )
    else:
        translator = SequenceGeneratorWCS(
            models,
            task.target_dictionary,
            beam_size=args.beam,
            stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen,
            unk_penalty=args.unkpen,
            sampling=args.sampling,
            sampling_topk=args.sampling_topk,
            minlen=args.min_len,
            maxlen=None,
            context=args.context,
            ngram=args.ngram,
            num_topics=args.num_topics,
            flatenc=args.flatenc,
            dechatt=args.dechatt,
            flatten_source=args.flatten_source,
        )

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    outlog = open(args.decode_dir + '/out.log', 'w', encoding='utf8')
    print(
        "* Generating target texts of max length proportional to b: {} (ax+b)".
        format(args.max_len_b))
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.score_reference:
            translations = translator.score_batched_itr(t,
                                                        cuda=use_cuda,
                                                        timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t,
                maxlen_a=args.max_len_a,
                maxlen_b=args.max_len_b,
                cuda=use_cuda,
                timer=gen_timer,
                prefix_size=args.prefix_size,
            )

        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:  # for each batch
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None

            # Either retrieve the original sentences or regenerate them from tokens.
            target_str = None
            if align_dict is not None and args.raw_text:
                src_str = task.dataset(
                    args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(
                    args.gen_subset).tgt.get_original_text(sample_id)
            else:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                if has_target and args.target_raw_text:
                    target_str_tok = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)
                    target_str = task.dataset(
                        args.gen_subset).get_target_original_text(sample_id)

            # Process top predictions
            if args.flatdec:
                processFlatHypo(sample_id, src_tokens, target_tokens, hypos,
                                src_str, align_dict, tgt_dict, args.remove_bpe,
                                has_target, target_str)
            else:
                for j in range(min(len(hypos), args.nbest)):  # for each beam
                    doc_hypo_tokens = []
                    doc_hypo_str = []
                    doc_target_str = []

                    for i in range(len(hypos[j]['beam'])):  # for each sentence of the beam
                        hypo = hypos[j]['beam'][i]
                        hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                            hypo_tokens=hypo['tokens'].int().cpu(),
                            src_str=src_str,
                            alignment=hypo['alignment'].int().cpu(),
                            align_dict=align_dict,
                            tgt_dict=tgt_dict,
                            remove_bpe=args.remove_bpe,
                        )

                        if not args.quiet:
                            print('H({})-{}\t{}\t{}'.format(
                                j, sample_id, hypo['score'], hypo_str))
                            print('P({})-{}\t{}'.format(
                                j, sample_id, ' '.join(
                                    map(
                                        lambda x: '{:.4f}'.format(x),
                                        hypo['positional_scores'].tolist(),
                                    ))))
                            print('A({})-{}\t{}'.format(
                                j, sample_id, ' '.join(
                                    map(lambda x: str(utils.item(x)),
                                        alignment))))

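                        # Drop hypotheses that are substrings of, or overlap heavily with, an earlier sentence hypothesis.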
                        subhypo = False
                        tokens_curhypo = set(hypo_str.split())
                        for hyp in doc_hypo_str:
                            tokens_hyp = set(hyp.split())

                            # if its contained in previous sentence hypothesis
                            if hypo_str.strip()[0:-1] in hyp:
                                subhypo = True
                                break

                            shorter = len(tokens_curhypo)

                            # if it overlaps on more than 80% of its tokens
                            shorter = round(shorter * 0.8)
                            if len(tokens_curhypo.intersection(
                                    tokens_hyp)) >= shorter:
                                subhypo = True

                        if not (hypo_str in doc_hypo_str or subhypo):
                            doc_hypo_str.append(hypo_str)
                        else:
                            print("repeated on {} / {}".format(sample_id, i))
                            print(hypo_str)

                        if has_target and i == 0:
                            doc_hypo_tokens.append(hypo_tokens)

                #write files for ROUGE
                with open(
                        os.path.join(args.decode_dir,
                                     "{}.dec".format(sample_id)), 'w') as f:
                    f.write(
                        make_html_safe(" ".join(doc_hypo_str).replace(
                            tgt_dict.eod_word, "").strip()))

                #TODO: call scorer for BLEU

                if target_str:
                    doc_target_str.append(target_str)
                    with open(
                            os.path.join(args.reference_dir,
                                         "{}.ref".format(sample_id)),
                            'w') as f:
                        f.write(make_html_safe(" ".join(doc_target_str)))
                    with open(
                            os.path.join(args.reference_dir + "_fromdict",
                                         "{}.ref".format(sample_id)),
                            'w') as f:
                        f.write(make_html_safe(target_str_tok))
                outlog.write("[{}] ".format(sample_id) +
                             " ".join(doc_hypo_str).replace(
                                 tgt_dict.eod_word, "").strip() + "\n")

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    outlog.close()

    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
Example #4
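# Assumed imports for this snippet (older ParlAI / fairseq APIs; exact module paths
# may differ by version). The leading-underscore helpers (_FairseqDictionary,
# _ParlaiTask, _fairseq_opt_wrapper, _is_nonempty_observation) and
# NON_OVERRIDABLE_ARGS are defined elsewhere in the same module:
import os
import pickle

import numpy as np
import torch

from fairseq import criterions, fp16_trainer, models, optim, options, trainer
from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer
from parlai.core.torch_agent import Output, TorchAgent
from parlai.core.utils import argsort, padded_tensor, round_sigfigs
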
class FairseqAgent(TorchAgent):
    """Generic wrapper around fairseq for use in ParlAI"""

    DEFAULT_OPTIONS = {
        "adam_betas": "(0.9,0.98)",
        "optimizer": "adam",
        "clip_norm": 0.1,
    }

    metrics = {}

    @classmethod
    def add_cmdline_args(cls, argparser):
        """Add command-line arguments specifically for this agent."""
        # first we need to add the general torch agent operations
        TorchAgent.add_cmdline_args(argparser)

        agent = argparser.add_argument_group('Fairseq Arguments')
        agent.add_argument('--fp16',
                           default=False,
                           type='bool',
                           help='Use fp16 training')
        agent.add_argument('--seed',
                           default=1,
                           type=int,
                           metavar='N',
                           help='pseudo random number generator seed')
        agent.add_argument(
            '--skip-generation',
            default=False,
            type='bool',
            metavar='BOOL',
            help='Skips test time beam search. Much faster if you only need PPL',
        )

        # Dictionary construction stuff. Using the subclass in case we end up
        # needing any fairseq specific things
        cls.dictionary_class().add_cmdline_args(argparser)

        # Check subargs for generation, optimizers, criterions, archs, etc
        options.add_generation_args(argparser)
        options.add_optimization_args(argparser)

        # make sure we set defaults according to the model before parsing
        argparser.set_defaults(**cls.DEFAULT_OPTIONS)
        known_args = argparser.parse_known_args(nohelp=True)[0]

        if hasattr(known_args, "optimizer"):
            optimizer = known_args.optimizer
            opt_group = argparser.add_argument_group(
                '{} optimizer arguments'.format(optimizer))
            optim.OPTIMIZER_REGISTRY[optimizer].add_args(opt_group)
        if hasattr(known_args, "lr_scheduler"):
            lr_scheduler = known_args.lr_scheduler
            lr_group = argparser.add_argument_group(
                '{} scheduler arguments'.format(lr_scheduler))
            optim.lr_scheduler.LR_SCHEDULER_REGISTRY[lr_scheduler].add_args(
                lr_group)
        # We need to find out the fairseq model-specific options, so grab the
        # architecture stuff and look up its options
        arch_group = options.add_model_args(argparser)
        # Fairseq marks the arch flag as required, but it may be specified
        # by a saved model cache, so we do some weird stuff to undo that
        for a in arch_group._actions:
            if a.dest == "arch":
                a.required = False
                a.default = None
                break

        # make sure we set defaults according to parlai model before parsing
        argparser.set_defaults(**cls.DEFAULT_OPTIONS)
        known_args = argparser.parse_known_args(nohelp=True)[0]

        if hasattr(known_args, "arch") and known_args.arch is not None:
            arch = known_args.arch
            arch_group = argparser.add_argument_group(
                "{} architecture arguments".format(arch))
            models.ARCH_MODEL_REGISTRY[arch].add_args(arch_group)

        if hasattr(known_args, "criterion"):
            crit_group = argparser.add_argument_group(
                '{} criterion arguments'.format(known_args.criterion))
            criterions.CRITERION_REGISTRY[known_args.criterion].add_args(
                crit_group)

        # As one final check, let's make sure we set defaults correctly
        argparser.set_defaults(**cls.DEFAULT_OPTIONS)

    @staticmethod
    def dictionary_class():
        # Force use of the Fairseq Dictionary
        return _FairseqDictionary

    def __init__(self, opt, shared=None):
        # In general use a basic TorchAgent wherever possible
        super().__init__(opt, shared)
        if not shared:
            # this is not a shared instance of this class, so do full initialization

            # check early if we're going to be loading the model from a checkpoint
            model_file_exists = (self.opt.get('model_file')
                                 and os.path.isfile(self.opt['model_file']))

            # fairseq expects options to be in argparse format, instead of a dict
            # We also need to do some argument postprocessing and whatnot
            # We'll skip pretrained embeddings if we're going to override them with
            # a model checkpoint anyway
            self.args, self.opt = _fairseq_opt_wrapper(opt, model_file_exists)

            # seed the RNG
            torch.manual_seed(self.args.seed)

            # Just some identifying info
            self.id = "fairseq:{}".format(self.args.arch)

            # We need a placeholder task for fairseq
            self.task = _ParlaiTask(self.dict)

            # actually construct the model and generator
            self.model = self.build_model()

            # Construct the generator and scorer
            self.generator = SequenceGenerator(
                [self.model],
                tgt_dict=self.dict,
                beam_size=self.args.beam,
                stop_early=(not self.args.no_early_stop),
                normalize_scores=(not self.args.unnormalized),
                len_penalty=self.args.lenpen,
                unk_penalty=self.args.unkpen,
                sampling=self.args.sampling,
                sampling_topk=self.args.sampling_topk,
                sampling_temperature=self.args.sampling_temperature,
            )
            self.scorer = SequenceScorer([self.model], self.dict)

            # set up the grader and the trainer
            self.criterion = criterions.build_criterion(self.args, self.task)

            if getattr(self.args, 'fp16', None):
                self.trainer = fp16_trainer.FP16Trainer(
                    self.args, self.task, self.model, self.criterion)
            else:
                # TODO: we might choose to add a --no-fp16 opt in the future to
                # explicitly disable fp16 instead
                if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 7:
                    print("Heads up: using --fp16 could be a lot faster!")
                self.trainer = trainer.Trainer(self.args, self.task,
                                               self.model, self.criterion)

            # if the model already existed, let's preload it and the trainer
            if model_file_exists:
                print('Loading existing model params from ' +
                      self.opt['model_file'])
                self.load(self.opt.get('model_file'))

            # move things to the GPU if possible
            if self.use_cuda:
                self.model = self.model.cuda()
                self.generator = self.generator.cuda()
        else:
            self.model = shared['model']
            self.trainer = shared['trainer']
            self.generator = shared['generator']
            self.dict = shared['dict']
            self.args = shared['args']

        # Start things off clean
        self.reset()

    def _check_opts_unchanged(self, saved_opts, current_opts):
        """Verify that critical options do not differ in command line vs saved model"""
        for k in NON_OVERRIDABLE_ARGS:
            if k not in saved_opts or k not in current_opts:
                # if it's not an option needed by this fairseq model, don't stress
                continue
            if saved_opts[k] != current_opts[k]:
                raise ValueError(
                    '{} cannot be overridden when --model-file is specified'.
                    format(k))

    def build_model(self):
        """
        Construct the actual Fairseq model. Default implementation is to use
        Fairseq's arch builder, but this method may be overridden to build custom
        models.
        """
        model_class = models.ARCH_MODEL_REGISTRY[self.args.arch]
        return model_class.build_model(self.args, self.task)

    def share(self):
        shared = super().share()
        shared['model'] = self.model
        shared['trainer'] = self.trainer
        shared['generator'] = self.generator
        shared['dict'] = self.dict
        shared['args'] = self.args
        return shared

    def save(self, path):
        """Save using fairseq's checkpointing."""
        if not path:
            return
        self.trainer.save_checkpoint(path, {'opt': self.opt, 'epoch': 0})
        # Parlai expects options to also be saved
        with open(path + ".opt", 'wb') as handle:
            # overridden options shouldn't be stored, only the main ones
            if 'override' in self.opt:
                del self.opt['override']
            pickle.dump(self.opt, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def load(self, path):
        """Load using fairseq's checkpointing."""
        old_options = self.trainer.load_checkpoint(path)
        self._check_opts_unchanged(old_options, self.opt)

    def shutdown(self):
        if not hasattr(self, 'trainer'):
            # looks like this is a "fake" model that isn't actually used for batch_act.
            # we don't need to save this one.
            return
        super().shutdown()

    def reset(self):
        """Reset observation and episode_done."""
        super().reset()
        self.reset_metrics()

    def batchify(self, obs_batch):
        """
        Override parent batchify to set requirements for fairseq.

        Fairseq depends on sorted batch inputs for a call to rnn.pad_packed_sequence.
        Fairseq models cannot handle zero length sentences
        """
        return super().batchify(obs_batch,
                                sort=True,
                                is_valid=_is_nonempty_observation)

    def train_step(self, batch):
        """Process batch of inputs and targets and train on them.

        :param batch: parlai.core.torch_agent.Batch, contains tensorized
                      version of observations.
        """
        if batch.text_vec is None:
            return
        self.is_training = True
        samples = self._make_sample(batch.text_vec, batch.label_vec)
        self.model.train()
        self.trainer.train_step(samples)

    def eval_step(self, batch):
        """Process batch of inputs.

        If the batch includes labels, calculate validation metrics as well.
        If --skip-generation is not set, return a prediction for each input.

        :param batch: parlai.core.torch_agent.Batch, contains tensorized
                      version of observations.
        """
        if batch.text_vec is None:
            return
        self.is_training = False
        samples = self._make_sample(batch.text_vec, batch.label_vec)
        self.model.eval()
        if batch.label_vec is not None:
            # Interactive mode won't have a gold label
            self.trainer.valid_step(samples)

        # Output placeholders
        reranked_cands = None
        generated_output = None

        # Grade each of the candidate sequences
        if batch.candidate_vecs is not None:
            bsz = len(batch.text_vec)
            reranked_cands = []
            # score the candidates for each item in the batch separately, so that
            # we can support variable number of candidates
            for i in range(bsz):
                cands = batch.candidate_vecs[i]
                if not cands:
                    reranked_cands.append(None)
                    continue
                ncand = len(cands)
                # repeat the input many times
                xs = batch.text_vec[i].unsqueeze(0).expand(ncand, -1)
                # some models crash if there's leading padding on every example
                xs = xs[:, :batch.text_lengths[i]]
                # and appropriately pack the outputs
                ys, _ = padded_tensor(cands, self.NULL_IDX, self.use_cuda)
                s = self._make_sample(xs, ys)
                # perform the actual grading, extract the scores
                scored = list(
                    self.scorer.score_batched_itr([s], cuda=self.use_cuda))
                scores = [s[3][0]['score'].item() for s in scored]
                # intentional hanging comma here; argsort returns a list
                ranked, = argsort(scores, batch.candidates[i], descending=True)
                reranked_cands.append(ranked)

        # Next generate freely to create our response
        if not self.args.skip_generation:
            generated_output = self._generate(samples)
        elif reranked_cands:
            # we're skipping generation, but we're also grading candidates,
            # so output the highest ranked candidate.
            # In the case of zero candidates there is nothing to rank,
            # so we may need to pass on that None
            generated_output = [
                ranked[0] if ranked else None for ranked in reranked_cands
            ]
        else:
            # no output at all
            pass

        return Output(generated_output, reranked_cands)

    def _generate(self, samples):
        src_tokens = samples["net_input"]["src_tokens"]
        src_lengths = samples["net_input"]["src_lengths"]
        gens = self.generator.generate(src_tokens, src_lengths, maxlen=64)
        responses = []
        for i in range(len(src_tokens)):
            beams = gens[i]
            selected = max(beams, key=lambda x: x["score"])
            tokens = selected["tokens"]
            start = 0
            end = -1
            for j, t in enumerate(tokens):
                t = t.item()
                if t == self.dict.bos_index:
                    # don't include <s> token
                    start = j + 1
                    continue
                if t == self.dict.eos_index:
                    # stop (and don't include) </s> token
                    end = j
                    break
            responses.append(self.dict.vec2txt(tokens[start:end]))
        return responses

    def report(self):
        """Return metrics calculated by the model."""
        # if we haven't initialized yet, just return a dummy object
        if not hasattr(self, "trainer"):
            return {}

        # These are the metrics we'll pass up the way, and their new names
        train_metrics = {"train_loss", "ups", "wps", "gnorm", "clip"}
        valid_metrics = {"valid_loss"}

        metrics = train_metrics if self.is_training else valid_metrics

        m = {k: self.trainer.meters[k].avg for k in metrics}

        # additionally output perplexity. note that fairseq models use base 2
        # in cross_entropy:
        # github.com/pytorch/fairseq/blob/master/fairseq/criterions/cross_entropy.py#L55
        if "train_loss" in m:
            m["train_ppl"] = np.exp2(m["train_loss"])
        if "valid_loss" in m:
            m["ppl"] = np.exp2(m["valid_loss"])

        for k, v in m.items():
            # clean up: rounds to sigfigs and converts tensors to floats
            m[k] = round_sigfigs(v, 4)

        return m

    def reset_metrics(self):
        """Reset metrics calculated by the model back to zero."""
        if not hasattr(self, "trainer"):
            # We haven't set up the trainer yet, so we don't have any metrics
            return
        # We need to reset everything
        for k in self.trainer.meters:
            self.trainer.meters[k].reset()

    def receive_metrics(self, metrics_dict):
        """Update lr scheduler with validation loss."""
        self.trainer.lr_step(-1, metrics_dict["valid_loss"])

    # Helper functions
    def _seq_length(self, xs):
        """Compute length of the sequence (non-padded size)."""
        return xs.ne(self.dict.pad_index).long().sum(dim=-1)

    def _right_shifted_ys(self, ys):
        """Replace first token with EOS and shift remaining tokens right 1."""
        result = torch.LongTensor(ys.size())
        result[:, 0] = self.dict.eos_index
        result[:, 1:] = ys[:, :-1]
        return result

    def _make_sample(self, xs, ys):
        """Generate a sample object that Fairseq expects."""
        # add extra info to samples
        # TODO: should the right/left padding thing be in torch agent?
        sample = {}
        sample["id"] = torch.arange(len(xs) - 1)
        sample["net_input"] = {
            "src_tokens": xs,
            "src_lengths": self._seq_length(xs),
        }
        if ys is not None:
            sample["target"] = ys
            sample["ntokens"] = sum(self._seq_length(ys)).item()
            sample["net_input"]["prev_output_tokens"] = self._right_shifted_ys(
                ys)
        return sample
Example #5
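# Assumed imports for this snippet (older ParlAI / fairseq APIs; exact module paths
# may differ by version). The leading-underscore helpers and NON_OVERRIDABLE_ARGS
# are defined elsewhere in the same module:
import json
import os
from collections import defaultdict

import numpy as np
import torch

from fairseq import criterions, models, optim, options, trainer
from fairseq.meters import AverageMeter
from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer
from fairseq.utils import convert_padding_direction, load_model_state
from parlai.core.torch_agent import Output, TorchAgent
from parlai.core.utils import argsort, padded_tensor, round_sigfigs
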
class FairseqAgent(TorchAgent):
    """Generic wrapper around fairseq for use in ParlAI"""

    metrics = {}

    @classmethod
    def add_cmdline_args(cls, argparser):
        """Add command-line arguments specifically for this agent."""
        # first we need to add the general torch agent operations
        super(FairseqAgent, cls).add_cmdline_args(argparser)

        # let's store any defaults that were overridden
        old_defaults = argparser._defaults
        if 'clip_norm' not in old_defaults:
            # fairseq has a few awful defaults
            old_defaults['clip_norm'] = 1.0
        if 'optimizer' not in old_defaults:
            old_defaults['optimizer'] = 'adam'
            old_defaults['adam_betas'] = '(0.9,0.98)'

        agent = argparser.add_argument_group('Fairseq Arguments')
        agent.add_argument('--fp16',
                           default=False,
                           type='bool',
                           help='Use fp16 training')
        agent.add_argument(
            '--fp16-init-scale',
            default=2**7,
            type=int,
            help='default FP16 loss scale',
        )
        agent.add_argument(
            '--seed',
            default=1,
            type=int,
            metavar='N',
            help='pseudo random number generator seed',
        )
        agent.add_argument(
            '--skip-generation',
            default=False,
            type='bool',
            metavar='BOOL',
            help='Skips test time beam search. Much faster if you only need PPL',
        )

        # Check subargs for generation, optimizers, criterions, archs, etc
        options.add_generation_args(argparser)
        options.add_optimization_args(argparser)
        options.add_checkpoint_args(argparser)

        # restore any user set defaults that fairseq possibly overrode
        argparser.set_defaults(**old_defaults)
        known_args = argparser.parse_known_args(nohelp=True)[0]

        if hasattr(known_args, "optimizer"):
            optimizer = known_args.optimizer
            opt_group = argparser.add_argument_group(
                '{} optimizer arguments'.format(optimizer))
            optim.OPTIMIZER_REGISTRY[optimizer].add_args(opt_group)
        if hasattr(known_args, "lr_scheduler"):
            lr_scheduler = known_args.lr_scheduler
            lr_group = argparser.add_argument_group(
                '{} scheduler arguments'.format(lr_scheduler))
            optim.lr_scheduler.LR_SCHEDULER_REGISTRY[lr_scheduler].add_args(
                lr_group)
        # We need to find out the fairseq model-specific options, so grab the
        # architecture stuff and look up its options
        arch_group = options.add_model_args(argparser)
        # Fairseq marks the arch flag as required, but it may be specified
        # by a saved model cache, so we do some weird stuff to undo that
        for a in arch_group._actions:
            if a.dest == "arch":
                a.required = False
                a.default = None
                break

        # once again restore any user-set defaults
        argparser.set_defaults(**old_defaults)
        known_args = argparser.parse_known_args(nohelp=True)[0]

        if hasattr(known_args, "arch") and known_args.arch is not None:
            arch = known_args.arch
            arch_group = argparser.add_argument_group(
                "{} architecture arguments".format(arch))
            models.ARCH_MODEL_REGISTRY[arch].add_args(arch_group)

        if hasattr(known_args, "criterion"):
            crit_group = argparser.add_argument_group(
                '{} criterion arguments'.format(known_args.criterion))
            criterions.CRITERION_REGISTRY[known_args.criterion].add_args(
                crit_group)

        # one last time, restore any user set defaults
        argparser.set_defaults(**old_defaults)

    @staticmethod
    def dictionary_class():
        # Force use of the Fairseq Dictionary
        return _FairseqDictionary

    def __init__(self, opt, shared=None):
        # In general use a basic TorchAgent wherever possible
        super().__init__(opt, shared)
        if not shared:
            # this is not a shared instance of this class, so do full initialization

            # check early if we're going to be loading the model from a checkpoint
            model_file_exists = self.opt.get('model_file') and os.path.isfile(
                self.opt['model_file'])

            # fairseq expects options to be in argparse format, instead of a dict
            # We also need to do some argument postprocessing and whatnot
            # We'll skip pretrained embeddings if we're going to override them with
            # a model checkpoint anyway
            self.args, self.opt = _fairseq_opt_wrapper(opt, model_file_exists)

            # seed the RNG
            torch.manual_seed(self.args.seed)

            # Just some identifying info
            self.id = "fairseq:{}".format(self.args.arch)

            # We need a placeholder task for fairseq
            self.task = _ParlaiTask(self.dict)

            # meters for keeping track of loss, ppl, etc.
            self.meters = defaultdict(AverageMeter)

            # actually construct the model and generator
            self.model = self.build_model()

            # Construct the generator and scorer
            self.generator = SequenceGenerator(
                [self.model],
                tgt_dict=self.dict,
                beam_size=self.args.beam,
                stop_early=(not self.args.no_early_stop),
                normalize_scores=(not self.args.unnormalized),
                len_penalty=self.args.lenpen,
                unk_penalty=self.args.unkpen,
                sampling=self.args.sampling,
                sampling_topk=self.args.sampling_topk,
                sampling_temperature=self.args.sampling_temperature,
            )
            self.scorer = SequenceScorer([self.model], self.dict)

            # set up the grader and the trainer
            self.criterion = criterions.build_criterion(self.args, self.task)

            # TODO: we might choose to add a --no-fp16 opt in the future to
            # explicitly disable fp16 instead
            if (not self.args.fp16 and torch.cuda.is_available()
                    and torch.cuda.get_device_capability(0)[0] >= 7):
                print("Heads up: using --fp16 could be a lot faster!")
            if self.use_cuda:
                self.trainer = trainer.Trainer(self.args, self.task,
                                               self.model, self.criterion,
                                               None)
                self.trainer._build_optimizer()
            else:
                self.trainer = None

            # if the model already existed, let's preload it and the trainer
            if model_file_exists:
                print('Loading existing model params from ' +
                      self.opt['model_file'])
                self.load(self.opt.get('model_file'))

            # move things to the GPU if possible
            if self.use_cuda:
                self.model = self.model.cuda()
                self.generator = self.generator.cuda()
        else:
            self.model = shared['model']
            self.trainer = shared['trainer']
            self.generator = shared['generator']
            self.dict = shared['dict']
            self.args = shared['args']
            self.meters = shared['meters']

        # Start things off clean
        self.reset()

    def _check_opts_unchanged(self, saved_opts, current_opts):
        """Verify that critical options do not differ in command line vs saved model"""
        for k in NON_OVERRIDABLE_ARGS:
            if k not in saved_opts or k not in current_opts:
                # if it's not an option needed by this fairseq model, don't stress
                continue
            if saved_opts[k] != current_opts[k]:
                raise ValueError(
                    '{} cannot be overridden when --model-file is specified'.
                    format(k))

    def build_model(self):
        """
        Construct the actual Fairseq model. Default implementation is to use
        Fairseq's arch builder, but this method may be overridden to build custom
        models.
        """
        model_class = models.ARCH_MODEL_REGISTRY[self.args.arch]
        model = model_class.build_model(self.args, self.task)
        if self.args.embedding_type != 'random':
            self._copy_embeddings(model.encoder.embed_tokens.weight,
                                  self.args.embedding_type)
        return model

    def share(self):
        shared = super().share()
        shared['model'] = self.model
        shared['trainer'] = self.trainer
        shared['generator'] = self.generator
        shared['dict'] = self.dict
        shared['args'] = self.args
        shared['meters'] = self.meters
        return shared

    def save(self, path):
        """Save using fairseq's checkpointing."""
        if not path:
            return
        self.trainer.save_checkpoint(path, {'opt': self.opt, 'epoch': 0})
        # Parlai expects options to also be saved
        with open(path + '.opt', 'w') as handle:
            # overridden options shouldn't be stored, only the main ones
            if 'override' in self.opt:
                del self.opt['override']
            json.dump(self.opt, handle)

        # force save the dict
        self.dict.save(path + '.dict', sort=False)

    def load(self, path):
        """Load using fairseq's checkpointing."""
        if self.trainer:
            old_options = self.trainer.load_checkpoint(
                path, self.args.reset_optimizer)
            self._check_opts_unchanged(old_options, self.opt)
        else:
            load_model_state(path, self.model)

    def shutdown(self):
        if not hasattr(self, 'trainer'):
            # looks like this is a "fake" model that isn't actually used for batch_act.
            # we don't need to save this one.
            return
        super().shutdown()

    def reset(self):
        """Reset observation and episode_done."""
        super().reset()
        self.reset_metrics()

    def is_valid(self, obs):
        """Override from TorchAgent.
        Check if an observation has no tokens in it."""
        return len(obs.get('text_vec', [])) > 0

    def batchify(self, obs_batch):
        """
        Override parent batchify to set requirements for fairseq.

        Fairseq depends on sorted batch inputs for a call to rnn.pad_packed_sequence.
        Fairseq models cannot handle zero length sentences
        """
        return super().batchify(obs_batch, sort=True)

    def _update_metrics(self, metrics, sample):
        if metrics is None:
            # probably got an overflow in fp16 mode. don't count this sample
            return

        bsz = len(sample['target'])
        ntok = sample['ntokens']
        ssize = metrics['sample_size']

        for k, v in metrics.items():
            if k in {'ntokens', 'nsentences', 'sample_size'}:
                # don't need these
                continue
            elif k == "nll_loss":
                # nll loss is always normalized by ntokens
                self.meters[k].update(v, ntok)
            elif k == "loss":
                # loss is explicitly normalized by passed up sample size
                self.meters[k].update(v, ssize)
            else:
                # assume everything else is averaged over bsz
                self.meters[k].update(v, bsz)

    def train_step(self, batch):
        """Process batch of inputs and targets and train on them.

        :param batch: parlai.core.torch_agent.Batch, contains tensorized
                      version of observations.
        """
        if batch.text_vec is None:
            return
        self.is_training = True
        sample = self._make_sample(batch)
        self.model.train()
        metrics = self.trainer.train_step([sample])
        self._update_metrics(metrics, sample)

    def eval_step(self, batch):
        """Process batch of inputs.

        If the batch includes labels, calculate validation metrics as well.
        If --skip-generation is not set, return a prediction for each input.

        :param batch: parlai.core.torch_agent.Batch, contains tensorized
                      version of observations.
        """
        if batch.text_vec is None:
            return
        self.is_training = False
        samples = self._make_sample(batch)
        self.model.eval()
        if batch.label_vec is not None and self.trainer is not None:
            # Interactive mode won't have a gold label
            metrics = self.trainer.valid_step(samples)
            self._update_metrics(metrics, samples)

        # Output placeholders
        reranked_cands = None
        generated_output = None

        # Grade each of the candidate sequences
        if batch.candidate_vecs is not None:
            bsz = len(batch.text_vec)
            reranked_cands = []
            # score the candidates for each item in the batch separately, so that
            # we can support variable number of candidates
            for i in range(bsz):
                cands = batch.candidate_vecs[i]
                if not cands:
                    reranked_cands.append(None)
                    continue
                ncand = len(cands)
                # repeat the input many times
                xs = batch.text_vec[i].unsqueeze(0).expand(ncand, -1)
                # some models crash if there's leading padding on every example
                xs = xs[:, :batch.text_lengths[i]]
                # and appropriately pack the outputs
                ys, _ = padded_tensor(cands, self.NULL_IDX, self.use_cuda)
                s = self._make_sample(xs=xs, ys=ys)
                # perform the actual grading, extract the scores
                scored = list(
                    self.scorer.score_batched_itr([s], cuda=self.use_cuda))
                scores = [s[3][0]['score'].item() for s in scored]
                # intentional hanging comma here; argsort returns a list
                ranked, = argsort(scores, batch.candidates[i], descending=True)
                reranked_cands.append(ranked)

        # Next generate freely to create our response
        if not self.args.skip_generation:
            generated_output = self._generate(samples)
        elif reranked_cands:
            # we're skipping generation, but we're also grading candidates,
            # so output the highest ranked candidate.
            # In the case of zero candidates there is nothing to rank,
            # so we may need to pass on that None
            generated_output = [
                ranked[0] if ranked else None for ranked in reranked_cands
            ]
        else:
            # no output at all
            pass

        return Output(generated_output, reranked_cands)

    def _generate(self, samples):
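        # Drop prev_output_tokens so the generator decodes freely from the encoder input only.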
        no_prev_token = {
            k: v
            for k, v in samples['net_input'].items()
            if k != 'prev_output_tokens'
        }
        gens = self.generator.generate(no_prev_token, maxlen=64)
        bsz = samples['net_input']['src_tokens'].size(0)
        responses = []
        for i in range(bsz):
            beams = gens[i]
            selected = max(beams, key=lambda x: x["score"])
            tokens = selected["tokens"]
            start = 0
            end = -1
            for j, t in enumerate(tokens):
                t = t.item()
                if t == self.dict.bos_index:
                    # don't include <s> token
                    start = j + 1
                    continue
                if t == self.dict.eos_index:
                    # stop (and don't include) </s> token
                    end = j
                    break
            responses.append(self.dict.vec2txt(tokens[start:end]))
        return responses

    def report(self):
        """Return metrics calculated by the model."""
        # if we haven't initialized yet, just return a dummy object
        if not hasattr(self, "trainer"):
            return {}

        output = {k: v.avg for k, v in self.meters.items()}

        if "nll_loss" in self.meters:
            # special case, we used sentence averaging so ppl comes from nll_loss
            output["ppl"] = np.exp2(self.meters["nll_loss"].avg)
        else:
            # normal case, just use loss
            output["ppl"] = np.exp2(self.meters["loss"].avg)

        # Fairseq trainer metrics we'll pass up the way
        trainer_metrics = {"ups", "wps", "gnorm", "clip"}
        if self.is_training:
            for k in trainer_metrics:
                output[k] = self.trainer.meters[k].avg

        # for display purposes
        output = {k: round_sigfigs(v, 4) for k, v in output.items()}
        return output

    def reset_metrics(self):
        """Reset metrics calculated by the model back to zero."""
        if not hasattr(self, "trainer"):
            # We haven't set up the trainer yet, so we don't have any metrics
            return
        # We need to reset everything
        self.meters.clear()
        if self.trainer:
            for k in self.trainer.meters:
                self.trainer.meters[k].reset()

    def receive_metrics(self, metrics_dict):
        """Update lr scheduler with validation loss."""
        # TODO: this should be smarter
        self.trainer.lr_step(-1, metrics_dict["loss"])

    # Helper functions
    def _seq_length(self, xs):
        """Compute length of the sequence (non-padded size)."""
        return xs.ne(self.dict.pad_index).long().sum(dim=-1)

    def _right_shifted_ys(self, ys):
        """Replace first token with EOS and shift remaining tokens right 1."""
        result = torch.LongTensor(ys.size())
        result[:, 0] = self.dict.eos_index
        result[:, 1:] = ys[:, :-1]
        return result

    def _make_sample(self, batch=None, xs=None, ys=None):
        """Generate a sample object that Fairseq expects."""
        # add extra info to samples
        if batch is None and xs is None:
            raise ValueError("Must supply either batch or xs")
        if batch is None and ys is None:
            raise ValueError("Must supply either batch or ys")
        if xs is None:
            xs = batch.text_vec
        if ys is None:
            ys = batch.label_vec
        repadded = convert_padding_direction(xs,
                                             self.dict.pad(),
                                             right_to_left=True)
        sample = {}
        sample["id"] = torch.arange(len(xs) - 1)
        sample["net_input"] = {
            "src_tokens": repadded,
            "src_lengths": self._seq_length(xs),
        }
        if ys is not None:
            sample["target"] = ys
            sample["ntokens"] = sum(self._seq_length(ys)).item()
            sample["net_input"]["prev_output_tokens"] = self._right_shifted_ys(
                ys)
        return sample
Exemplo n.º 6
def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task)

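    # Command-line arguments take precedence over the ones stored with the
    # checkpoint.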
    args.__dict__.update(parsed_args.__dict__)
    print(args)

    task.args = args

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()

    assert len(models) > 0

    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        ignore_invalid_inputs=True,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0

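    # When --remove-bpe is set, args.remove_bpe holds the BPE continuation
    # marker (e.g. '@@'); collect the ids of all continuation subwords so their
    # scores can be folded into the following token below.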
    if args.remove_bpe is not None:
        bpe_cont = args.remove_bpe.rstrip()
        bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']

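                # Fold the score of each BPE continuation subword into the next
                # position, so only whole words contribute to the token count.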
                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(len(hypo['tokens']) - 1):
                        if hypo['tokens'][i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += utils.item(pos_scores.sum())
                count += pos_scores.numel() - skipped_toks

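                # Reconstruct whole words from their subword pieces and
                # accumulate per-word probabilities / statistics if requested.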
                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(hypo['tokens'])):
                        w_ind = hypo['tokens'][i].item()
                        w += task.dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))
                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item())
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print('\t'.join('{} [{:.2f}]'.format(x[0], x[1]) for x in word_prob))

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)
Exemplo n.º 7
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models, task.target_dictionary, beam_size=args.beam, minlen=args.min_len,
            stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen, unk_penalty=args.unkpen,
            sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
            diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
            match_source_len=args.match_source_len, no_repeat_ngram_size=args.no_repeat_ngram_size,
        )

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    # Pre-allocate the list of output hypotheses; the size is hard-coded to the
    # number of sentences in the generation set.
    result = [''] * 21678
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.score_reference:
            translations = translator.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
                cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
            )

        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
            else:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                if has_target:
                    target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                result[sample_id] = hypo_str
                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                    print('P-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(
                            lambda x: '{:.4f}'.format(x),
                            hypo['positional_scores'].tolist(),
                        ))
                    ))

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(lambda x: str(utils.item(x)), alignment))
                        ))

                # Score only the top hypothesis
                if has_target and i == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str, tgt_dict, add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
    # Return the collected hypotheses, indexed by sample id.
    return result
Exemplo n.º 8
def main(args):
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset
    if args.replace_unk is None:
        dataset = data.load_dataset(
            args.data,
            [args.gen_subset],
            args.source_lang,
            args.target_lang,
        )
    else:
        dataset = data.load_raw_text_dataset(
            args.data,
            [args.gen_subset],
            args.source_lang,
            args.target_lang,
            args.doctopics,
            args.encoder_embed_dim,
        )
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    # Load ensemble
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict,
                                                  dataset.dst_dict)

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(dataset.splits[args.gen_subset])))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, )

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    max_positions = min(model.max_encoder_positions() for model in models)
    itr = dataset.eval_dataloader(
        args.gen_subset,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
    )
    if args.num_shards > 1:
        if args.shard_id < 0 or args.shard_id >= args.num_shards:
            raise ValueError('--shard-id must be between 0 and num_shards')
        itr = data.sharded_iterator(itr, args.num_shards, args.shard_id)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models)
    else:
        translator = SequenceGenerator(
            models,
            beam_size=args.beam,
            stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen,
            unk_penalty=args.unkpen)
    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(),
                         dataset.dst_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.score_reference:
            translations = translator.score_batched_itr(t,
                                                        cuda=use_cuda,
                                                        timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t,
                maxlen_a=args.max_len_a,
                maxlen_b=args.max_len_b,
                cuda=use_cuda,
                timer=gen_timer,
                prefix_size=args.prefix_size)
        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None
            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = dataset.splits[
                    args.gen_subset].src.get_original_text(sample_id)
                target_str = dataset.splits[
                    args.gen_subset].dst.get_original_text(sample_id)
            else:
                src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
                target_str = dataset.dst_dict.string(
                    target_tokens, args.remove_bpe,
                    escape_unk=True) if has_target else ''

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu(),
                    align_dict=align_dict,
                    dst_dict=dataset.dst_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                hypo_str))
                    print('P-{}\t{}'.format(
                        sample_id, ' '.join(
                            map(
                                lambda x: '{:.4f}'.format(x),
                                hypo['positional_scores'].tolist(),
                            ))))
                    print('A-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(lambda x: str(utils.item(x)),
                                     alignment))))

                # Score only the top hypothesis
                if has_target and i == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str,
                            dataset.dst_dict,
                            add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.
          format(num_sentences, gen_timer.n, gen_timer.sum,
                 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))
Exemplo n.º 9
def main(args):
    ####1 ARGS SETTING
    assert args.path is not None, '--path required for generation!'
    assert args.root_directory is not None, '--root-directory required for logging!'
    args.cpu = True
    args.score_reference = True
    args.print_alignment = True
    args.max_sentences = 1  #--batch-size 1
    args.beam = 1  #--beam 1
    args.no_progress_bar = True
    args.replace_unk = True
    ####2 LOAD DATASET IN THE RIGHT FORMAT
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))
    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    ####3 LOAD MODEL
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'),
                                                  task,
                                                  model_arg_overrides=eval(
                                                      args.model_overrides))
    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    ####4 CREATE AN ITERATOR
    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    ####5 INIT GENERATOR
    gen_timer = StopwatchMeter()
    translator = SequenceScorer(models, task.target_dictionary)

    ####6 SCORING
    check_root(args.root_directory)
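    # Optionally read gold (reference) source/target text so the detokenized
    # strings can be checked for <unk> tokens against the originals below.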
    if args.gold_source:
        src_gold = read_gold(args.gold_source)
    if args.gold_target:
        trg_gold = read_gold(args.gold_target)
    with progress_bar.build_progress_bar(args, itr) as t:
        # score_batched_itr returns a generator; nothing is actually scored
        # until the loop below starts iterating over it.
        translations = translator.score_batched_itr(t,
                                                     cuda=False,
                                                     timer=gen_timer)
        wps_meter = TimeMeter()
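        # Each scored sample also yields accumulated attention weights, which
        # are written out as text files for later inspection.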
        for sample_id, src_tokens, target_tokens, _, acc_attentions in translations:
            src_str = src_dict.string(src_tokens)
            target_str = tgt_dict.string(
                target_tokens)  #, args.remove_bpe, escape_unk=True)

            s_id = sample_id.item()
            assert_unk(src_str, src_gold[s_id] if args.gold_source else None,
                       "source")
            assert_unk(target_str,
                       trg_gold[s_id] if args.gold_target else None, "target")

            write_attention_textfiles(s_id, src_str, target_str,
                                      acc_attentions, args.root_directory)
            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
Exemplo n.º 10
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    src_dict_sen_piece = task.source_sen_piece_dictionary
    tgt_dict_sen_piece = task.target_sen_piece_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    prefix_path = os.path.split(args.path.split(':')[0])[0]
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'),
                                                  task,
                                                  model_arg_overrides=eval(
                                                      args.model_overrides))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,  # default need_attn=False
        )
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models,
            task.target_dictionary,
            task.target_sen_piece_dictionary,
            beam_size=args.beam,
            minlen=args.min_len,
            stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen,
            unk_penalty=args.unkpen,
            sampling=args.sampling,
            sampling_topk=args.sampling_topk,
            sampling_temperature=args.sampling_temperature,
            diverse_beam_groups=args.diverse_beam_groups,
            diverse_beam_strength=args.diverse_beam_strength,
        )

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    sp = spm.SentencePieceProcessor()
    sp.Load(args.senpiece_model)
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.score_reference:
            translations = translator.score_batched_itr(t,
                                                        cuda=use_cuda,
                                                        timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t,
                maxlen_a=args.max_len_a,
                maxlen_b=args.max_len_b,
                cuda=use_cuda,
                timer=gen_timer,
                prefix_size=args.prefix_size,
            )

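        # Output files: detokenized references, BPE-level source/hypotheses,
        # SentencePiece-detokenized source/hypotheses, and the BPE hypotheses
        # (with their sentence ids) that beat the SentencePiece ones.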
        ftgt = open(prefix_path + '/ref_tgt.txt', 'w', encoding='utf-8')
        fbpe_src = open(prefix_path + '/bpe_src.tok', 'w', encoding='utf-8')
        fbpe_hyp = open(prefix_path + '/bpe_trans.tok', 'w', encoding='utf-8')
        fsp_src = open(prefix_path + '/sp_src.detok', 'w', encoding='utf-8')
        fsp_hyp = open(prefix_path + '/trans.txt', 'w', encoding='utf-8')
        fhyp_tok = open(prefix_path + '/hyp_trans.txt', 'w', encoding='utf-8')
        fhyp_tok_ids = open(prefix_path + '/hyp_ids.txt',
                            'w',
                            encoding='utf-8')
        wps_meter = TimeMeter()
        id = 0
        for sample_id, src_tokens, target_tokens, src_sen_piece_tokens, target_sen_piece_tokens, hypos, hypos_sen_piece in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None
            target_sen_piece_tokens = target_sen_piece_tokens.int().cpu(
            ) if has_target else None

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(
                    args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(
                    args.gen_subset).tgt.get_original_text(sample_id)
                src_str_sen_piece = task.dataset(
                    args.gen_subset).src_sen_piece.get_original_text(sample_id)
                tgt_str_sen_piece = task.dataset(
                    args.gen_subset).tgt_sen_piece.get_original_text(sample_id)
            else:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                fbpe_src.write(src_str + '\n')  # write  bpe_token data
                if has_target:
                    target_str = tgt_dict.string(target_tokens,
                                                 args.remove_bpe,
                                                 escape_unk=True)

                src_str_sen_piece = src_dict_sen_piece.string(
                    src_sen_piece_tokens)  # return list, not string
                src_str_sen_piece_list = src_dict_sen_piece.to_list(
                    src_sen_piece_tokens)
                src_str_out = sp.DecodePieces(src_str_sen_piece_list)
                fsp_src.write(src_str_out + '\n')  # write sp_detok data
                if has_target:
                    tgt_str_sen_piece_list = tgt_dict_sen_piece.to_list(
                        target_sen_piece_tokens, escape_unk=True)
                    tgt_str_sen_piece = tgt_dict_sen_piece.string(
                        target_sen_piece_tokens, escape_unk=True)
                    tgt_str_out = sp.DecodePieces(tgt_str_sen_piece_list)
                    ftgt.write(tgt_str_out + '\n')

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))
                print('SS-{}\t{}'.format(sample_id, src_str_sen_piece))
                if has_target:
                    print('TS-{}\t{}'.format(sample_id, tgt_str_sen_piece))

            score1 = 0.
            hypo_str1 = ""
            # Process top predictions
            for i, hypo in enumerate(
                    hypos[:min(len(hypos), args.nbest)]):  # args.nbest=1
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu()
                    if hypo['alignment'] is not None else None,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                hypo_str))
                    print('P-{}\t{}'.format(
                        sample_id, ' '.join(
                            map(
                                lambda x: '{:.4f}'.format(x),
                                hypo['positional_scores'].tolist(),
                            ))))

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(lambda x: str(utils.item(x)), alignment))))

                # Score only the top hypothesis
                if has_target and i == 0:
                    score1 = hypo['score']
                    hypo_str1 = hypo_str
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str, tgt_dict, add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)
                # write bpe_trans to file
                fbpe_hyp.write(hypo_str + '\n')

            score2 = 0.
            # process sen_piece and save translations to file
            for i, hypo in enumerate(
                    hypos_sen_piece[:min(len(hypos_sen_piece), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str_sen_piece,
                    alignment=hypo['alignment'].int().cpu()
                    if hypo['alignment'] is not None else None,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict_sen_piece,
                    remove_bpe=None,
                    to_list=True,
                )
                if not args.quiet:
                    print('HS-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                 hypo_str))
                hypo_str_out = sp.DecodePieces(hypo_str)
                fsp_hyp.write(hypo_str_out + '\n')  # detokenized data

                # Score only the top hypothesis
                if has_target and i == 0:
                    score2 = hypo['score']
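            # Keep the BPE-based hypothesis (and its sentence id) only when it
            # scored higher than the SentencePiece-based one for this sentence.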
            if score1 > score2:
                fhyp_tok.write(hypo_str1 + '\n')
                fhyp_tok_ids.write(str(id) + '\n')
            id += 1
            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1
    ftgt.close()
    fbpe_src.close()
    fbpe_hyp.close()
    fsp_src.close()
    fsp_hyp.close()
    fhyp_tok.close()
    fhyp_tok_ids.close()
    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))
Exemplo n.º 11
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Set up functions for multiturn
    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
        dict = task.load_dictionary(dict_path(lang))
        print("| [{}] Dictionary: {} types".format(lang, len(dict) - 1))
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        offsets = Tokenizer.find_offsets(input_file, num_workers)
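        # Split the input file into num_workers byte ranges: worker processes
        # binarize ranges 1..N-1 in parallel while the main process handles the
        # first range, and the partial outputs are merged into one dataset.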
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        dict,
                        prefix,
                        lang,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin"))
        merge_result(
            Tokenizer.binarize(input_file,
                               dict,
                               lambda t: ds.add_item(t),
                               offset=0,
                               end=offsets[1]))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        print("| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
            lang,
            input_file,
            n_seq_tok[0],
            n_seq_tok[1],
            100 * sum(replaced.values()) / n_seq_tok[1],
            dict.unk_word,
        ))

    def make_dataset(input_prefix, output_prefix, lang, num_workers=1):
        if args.output_format == "binary":
            make_binary_dataset(input_prefix, output_prefix, lang, num_workers)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix +
                ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

    def make_all(lang):
        if args.multiturnpref:
            make_dataset(args.multiturnpref,
                         "test",
                         lang,
                         num_workers=args.workers)

    # Load dataset splits
    task = tasks.setup_task(args)

    # Multiturn tracking: prompt in test set, turn in debate
    turn = 0
    prompt = 1
    first_pass = True
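    # In multiturn mode, each generated hypothesis is appended to the multiturn
    # file (separated by <EOA>) and fed back in as the next source, until
    # MAX_TURNS turns have been produced for the current prompt.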
    while first_pass or args.multiturn:
        if args.multiturn:
            # Set up first turn
            if turn == 0:
                multiturn_file = "{}{}".format(args.multiturnpref,
                                               ("." + args.source_lang))
                test_file = "{}{}".format(args.testpref,
                                          ("." + args.source_lang))
                if args.interactive:
                    line = input('What subject would you like to debate?')
                else:
                    with open(test_file, 'r', encoding='utf-8') as f:
                        for i in range(prompt):
                            line = f.readline()
                with open(multiturn_file, 'w', encoding='utf-8') as f:
                    f.write(line)
                prompt += 1

            target = not args.only_source
            assert (args.multiturnpref), "--multiturnpref must be set"
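            # Build or load the vocabularies: with --joined-dictionary a single
            # dictionary is shared by source and target, otherwise each side
            # gets its own.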
            if args.joined_dictionary:
                assert (
                    not args.srcdict or not args.tgtdict
                ), "cannot use both --srcdict and --tgtdict with --joined-dictionary"

                if args.srcdict:
                    src_dict = task.load_dictionary(args.srcdict)
                elif args.tgtdict:
                    src_dict = task.load_dictionary(args.tgtdict)
                else:
                    assert (
                        args.trainpref
                    ), "--trainpref must be set if --srcdict is not specified"
                    src_dict = build_dictionary(
                        {
                            train_path(lang)
                            for lang in [args.source_lang, args.target_lang]
                        },
                        src=True)
                tgt_dict = src_dict
            else:
                if args.srcdict:
                    src_dict = task.load_dictionary(args.srcdict)
                else:
                    assert (
                        args.trainpref
                    ), "--trainpref must be set if --srcdict is not specified"
                    src_dict = build_dictionary([train_path(args.source_lang)],
                                                src=True)
            if target:
                if args.tgtdict:
                    tgt_dict = task.load_dictionary(args.tgtdict)
                else:
                    assert (
                        args.trainpref
                    ), "--trainpref must be set if --tgtdict is not specified"
                    tgt_dict = build_dictionary([train_path(args.target_lang)],
                                                tgt=True)
            else:
                tgt_dict = None

            src_dict.save(dict_path(args.source_lang))
            if target and tgt_dict is not None:
                tgt_dict.save(dict_path(args.target_lang))

            make_all(args.source_lang)
            if target:
                make_all(args.target_lang)
            if first_pass:
                print("| Wrote preprocessed data to {}".format(args.destdir))
                print('| Generating multiturn debate')
            task.load_dataset('test')
        else:
            task.load_dataset(args.gen_subset)
            print('| {} {} {} examples'.format(
                args.data, args.gen_subset,
                len(task.dataset(args.gen_subset))))

        if first_pass:
            # Set dictionaries
            src_dict = task.source_dictionary
            tgt_dict = task.target_dictionary

            # Load ensemble
            print('| loading model(s) from {}'.format(args.path))
            models, _model_args = utils.load_ensemble_for_inference(
                args.path.split(':'),
                task,
                model_arg_overrides=eval(args.model_overrides),
            )

            # Optimize ensemble for generation
            for model in models:
                model.make_generation_fast_(
                    beamable_mm_beam_size=None
                    if args.no_beamable_mm else args.beam,
                    need_attn=args.print_alignment,
                )
                if args.fp16:
                    model.half()

        # Load alignment dictionary for unknown word replacement
        # (None if no unknown word replacement, empty if no path to align dictionary)
        align_dict = utils.load_align_dict(args.replace_unk)

        # Load dataset (possibly sharded)
        itr = task.get_batch_iterator(
            dataset=task.dataset(args.gen_subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[model.max_positions() for model in models]),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            num_shards=args.num_shards,
            shard_id=args.shard_id,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)

        # Initialize generator
        gen_timer = StopwatchMeter()
        if args.score_reference:
            translator = SequenceScorer(models, task.target_dictionary)
        else:
            translator = SequenceGenerator(
                models,
                task.target_dictionary,
                beam_size=args.beam,
                minlen=args.min_len,
                stop_early=(not args.no_early_stop),
                normalize_scores=(not args.unnormalized),
                len_penalty=args.lenpen,
                unk_penalty=args.unkpen,
                sampling=args.sampling,
                sampling_topk=args.sampling_topk,
                sampling_temperature=args.sampling_temperature,
                diverse_beam_groups=args.diverse_beam_groups,
                diverse_beam_strength=args.diverse_beam_strength,
                match_source_len=args.match_source_len,
                no_repeat_ngram_size=args.no_repeat_ngram_size,
            )

        if use_cuda:
            translator.cuda()

        # Generate and compute BLEU score
        if args.sacrebleu:
            scorer = bleu.SacrebleuScorer()
        else:
            scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(),
                                 tgt_dict.unk())
        num_sentences = 0
        has_target = True
        with progress_bar.build_progress_bar(args, itr) as t:
            if args.score_reference:
                translations = translator.score_batched_itr(t,
                                                            cuda=use_cuda,
                                                            timer=gen_timer)
            else:
                translations = translator.generate_batched_itr(
                    t,
                    maxlen_a=args.max_len_a,
                    maxlen_b=args.max_len_b,
                    cuda=use_cuda,
                    timer=gen_timer,
                    prefix_size=args.prefix_size,
                )

            wps_meter = TimeMeter()
            for sample_id, src_tokens, target_tokens, hypos in translations:

                # Process input and ground truth
                has_target = target_tokens is not None
                target_tokens = target_tokens.int().cpu(
                ) if has_target else None

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    src_str = src_dict.string(src_tokens, args.remove_bpe)
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu()
                        if hypo['alignment'] is not None else None,
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                    hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join(
                                    map(lambda x: str(utils.item(x)),
                                        alignment))))

                    if args.multiturn:
                        multiturn_file = "{}{}".format(
                            args.multiturnpref, ("." + args.source_lang))
                        output_file = "{}{}".format(args.outputpref,
                                                    ("." + args.target_lang))
                        with open(multiturn_file, 'r', encoding='utf-8') as f:
                            line = f.readline()
                            if args.interactive:
                                interactive_response = input('Please respond:')
                                line += f' <EOA> {interactive_response}'
                        if turn < MAX_TURNS - 1:
                            with open(multiturn_file, 'w',
                                      encoding='utf-8') as f:
                                f.write(f'{line[:-1]} <EOA> {hypo_str}')
                            turn += 1
                        elif turn == MAX_TURNS - 1:
                            with open(output_file, 'a', encoding='utf-8') as f:
                                f.write(f'{line[:-1]} <EOA> {hypo_str}\n')
                            turn = 0

                    # Score only the top hypothesis
                    if has_target and i == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tokenizer.Tokenizer.tokenize(
                                target_str, tgt_dict, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

                wps_meter.update(src_tokens.size(0))
                t.log({'wps': round(wps_meter.avg)})
                num_sentences += 1

        print(
            '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
            .format(num_sentences, gen_timer.n, gen_timer.sum,
                    num_sentences / gen_timer.sum, 1. / gen_timer.avg))
        if has_target:
            print('| Generate {} with beam={}: {}'.format(
                args.gen_subset, args.beam, scorer.result_string()))

        first_pass = False
Exemplo n.º 12
    def test_sequence_scorer(self):
        # construct dummy dictionary
        d = test_utils.dummy_dictionary(vocab_size=2)
        self.assertEqual(d.pad(), 1)
        self.assertEqual(d.eos(), 2)
        self.assertEqual(d.unk(), 3)
        eos = d.eos()
        w1 = 4
        w2 = 5

        # construct dataloader
        data = [
            {
                'source': torch.LongTensor([w1, w2, eos]),
                'target': torch.LongTensor([w1, w2, w1, eos]),
            },
            {
                'source': torch.LongTensor([w2, eos]),
                'target': torch.LongTensor([w2, w1, eos]),
            },
            {
                'source': torch.LongTensor([w2, eos]),
                'target': torch.LongTensor([w2, eos]),
            },
        ]
        data_itr = test_utils.dummy_dataloader(data)

        # specify expected output probabilities
        args = argparse.Namespace()
        unk = 0.
        args.beam_probs = [
            # step 0:
            torch.FloatTensor([
                # eos      w1   w2
                [0.0, unk, 0.6, 0.4],  # sentence 1
                [0.0, unk, 0.4, 0.6],  # sentence 2
                [0.0, unk, 0.7, 0.3],  # sentence 3
            ]),
            # step 1:
            torch.FloatTensor([
                # eos      w1   w2
                [0.0, unk, 0.2, 0.7],  # sentence 1
                [0.0, unk, 0.8, 0.2],  # sentence 2
                [0.7, unk, 0.1, 0.2],  # sentence 3
            ]),
            # step 2:
            torch.FloatTensor([
                # eos       w1    w2
                [0.10, unk, 0.50, 0.4],  # sentence 1
                [0.15, unk, 0.15, 0.7],  # sentence 2
                [0.00, unk, 0.00, 0.0],  # sentence 3
            ]),
            # step 3:
            torch.FloatTensor([
                # eos      w1    w2
                [0.9, unk, 0.05, 0.05],  # sentence 1
                [0.0, unk, 0.00, 0.0],  # sentence 2
                [0.0, unk, 0.00, 0.0],  # sentence 3
            ]),
        ]
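        # The scorer should return, for every reference token (including the
        # final EOS), the probability the model assigned to it at that step.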
        expected_scores = [
            [0.6, 0.7, 0.5, 0.9],  # sentence 1
            [0.6, 0.8, 0.15],  # sentence 2
            [0.3, 0.7],  # sentence 3
        ]

        model = test_utils.TestModel.build_model(args, d, d)
        scorer = SequenceScorer([model])
        for id, _src, _ref, hypos in scorer.score_batched_itr(data_itr):
            self.assertHypoTokens(hypos[0], data[id]['target'])
            self.assertHypoScore(hypos[0], expected_scores[id])
Exemplo n.º 13
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset, aligned=False)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'),
                                                  task,
                                                  model_arg_overrides=eval(
                                                      args.model_overrides))
    first_model = models[0]

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models,
            task.target_dictionary,
            beam_size=args.beam,
            stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen,
            unk_penalty=args.unkpen,
            sampling=args.sampling,
            sampling_topk=args.sampling_topk,
            minlen=args.min_len,
        )

    if use_cuda:
        translator.cuda()

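    # Iterate over the two (unaligned) sides of the dataset; to_trg below
    # selects the translation direction used for each side.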
    for data_idx in [0, 1]:

        # Load dataset (possibly sharded)
        itr = data.EpochBatchIterator(
            dataset=task.dataset(args.gen_subset)[data_idx],
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=models[0].max_positions(),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            num_shards=args.num_shards,
            shard_id=args.shard_id,
        ).next_epoch_itr(shuffle=False)

        # Generate and compute BLEU score
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
        num_sentences = 0
        has_target = True
        res = []
        out_obj = []
        with progress_bar.build_progress_bar(args, itr) as t:
            if args.score_reference:
                translations = translator.score_batched_itr(t,
                                                            cuda=use_cuda,
                                                            timer=gen_timer)
            else:
                translations = translator.generate_batched_itr(
                    t,
                    maxlen_a=args.max_len_a,
                    maxlen_b=args.max_len_b,
                    cuda=use_cuda,
                    timer=gen_timer,
                    prefix_size=args.prefix_size,
                    to_trg=(data_idx == 0),
                )

            wps_meter = TimeMeter()
            for sample_id, src_tokens, target_tokens, hypos in translations:

                # sample out dict
                sample_out_dict = {}

                # Process input and ground truth
                has_target = target_tokens is not None
                target_tokens = target_tokens.int().cpu() if has_target else None

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    src_str = src_dict.string(src_tokens, args.remove_bpe)
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                sample_out_dict['source'] = src_str
                if has_target:
                    sample_out_dict['target'] = target_str

                if not args.quiet:
                    print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                preds = []

                sample_out_dict['translations'] = []
                sample_out_dict['gen_scores'] = []
                sample_out_dict['class_scores'] = []
                sample_out_dict['oracle_scores'] = []

                for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu()
                        if hypo['alignment'] is not None else None,
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )
                    sample_out_dict['translations'].append(hypo_str)
                    sample_out_dict['gen_scores'].append(hypo['score'])

                    # res.append((sample_id.item(), hypo_str, hypo['score']))
                    preds.append([hypo['score'], hypo_str, sample_id.item()])

                    # oracle_score
                    # oracle_score = sentence_bleu([target_str.split()], hypo_str.split())
                    # sample_out_dict['oracle_scores'].append(oracle_score)
                    # if args.oracle_score:
                    #     if has_target: # score the prediction
                    #         # replace the hypo score with the testing one
                    #         preds[-1][0] = oracle_score
                    #     else:
                    #         print("# WARNING: Not target to compute oracle")

                    # disc_score: classify the hypothesis with the model's
                    # discriminator (guarded so models without one do not crash).
                    if hasattr(first_model, 'discriminator'):
                        padded_hypo_tokens = collate_tokens(
                            [hypo['tokens']],
                            pad_idx=first_model.src_dict.pad(),
                            eos_idx=first_model.src_dict.eos(),
                            left_pad=False,
                            min_size=5,
                        )
                        disc_score = first_model.discriminator.pred(
                            padded_hypo_tokens)[0][0][1 - data_idx].item()
                        sample_out_dict['class_scores'].append(disc_score)
                        if args.disc_score and disc_score < 0.5:
                            # Demote hypotheses the discriminator rejects.
                            preds[-1][0] = -float("inf")
                    elif args.disc_score:
                        print("# WARNING: No discriminator to score")

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                    hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join(
                                    map(lambda x: str(utils.item(x)),
                                        alignment))))

                    # Score only the top hypothesis
                    if has_target and i == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tokenizer.Tokenizer.tokenize(
                                target_str, tgt_dict, add_if_not_exist=True)
                        scorer.add(target_tokens, hypo_tokens)

                preds = sorted(preds, reverse=True)
                res.append((preds[0][2], preds[0][1], preds[0][0]))

                wps_meter.update(src_tokens.size(0))
                t.log({'wps': round(wps_meter.avg)})
                num_sentences += 1

                out_obj.append(sample_out_dict)

        if args.output_path is not None:
            if data_idx == 0:
                output_suffix = '.' + args.source_lang + '-' + args.target_lang
            else:
                output_suffix = '.' + args.target_lang + '-' + args.source_lang
            res = sorted(res)
            # Use a context manager so the text output is flushed and closed.
            with open(args.output_path + output_suffix, 'w') as out:
                for r in res:
                    if args.score_reference:
                        out.write("{} ||| {:.4f}\n".format(r[1], r[2]))
                    else:
                        out.write(r[1] + '\n')

            with open(args.output_path + output_suffix + '.json',
                      'w') as f_out:
                f_out.write(
                    json.dumps(out_obj,
                               ensure_ascii=False,
                               sort_keys=False,
                               indent=4))

    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))
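The selection step above sorts the per-sentence `preds` triples and keeps the best surviving hypothesis; candidates the discriminator rejects have been demoted to -inf and sink to the bottom. A minimal sketch of that reranking with made-up scores (the `[score, hypo_str, sample_id]` layout mirrors the loop above):

# Toy n-best reranking, using hypothetical scores.
preds = [
    [-0.42, "a strong translation hypothesis", 7],
    [float("-inf"), "a hypothesis rejected by the discriminator", 7],
    [-0.95, "a weaker hypothesis", 7],
]
# Descending sort on the score field; the best surviving hypothesis comes first.
preds = sorted(preds, reverse=True)
best = (preds[0][2], preds[0][1], preds[0][0])  # (sample_id, hypo_str, score)
assert best[1] == "a strong translation hypothesis"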
Exemplo n.º 14
0
    def test_sequence_scorer(self):
        # construct dummy dictionary
        d = test_utils.dummy_dictionary(vocab_size=2)
        self.assertEqual(d.pad(), 1)
        self.assertEqual(d.eos(), 2)
        self.assertEqual(d.unk(), 3)
        eos = d.eos()
        w1 = 4
        w2 = 5

        # construct dataloader
        data = [
            {
                'source': torch.LongTensor([w1, w2, eos]),
                'target': torch.LongTensor([w1, w2, w1, eos]),
            },
            {
                'source': torch.LongTensor([w2, eos]),
                'target': torch.LongTensor([w2, w1, eos]),
            },
            {
                'source': torch.LongTensor([w2, eos]),
                'target': torch.LongTensor([w2, eos]),
            },
        ]
        data_itr = test_utils.dummy_dataloader(data)

        # specify expected output probabilities
        args = argparse.Namespace()
        unk = 0.
        args.beam_probs = [
            # step 0:
            torch.FloatTensor([
                # eos      w1   w2
                [0.0, unk, 0.6, 0.4],  # sentence 1
                [0.0, unk, 0.4, 0.6],  # sentence 2
                [0.0, unk, 0.7, 0.3],  # sentence 3
            ]),
            # step 1:
            torch.FloatTensor([
                # eos      w1   w2
                [0.0, unk, 0.2, 0.7],  # sentence 1
                [0.0, unk, 0.8, 0.2],  # sentence 2
                [0.7, unk, 0.1, 0.2],  # sentence 3
            ]),
            # step 2:
            torch.FloatTensor([
                # eos       w1    w2
                [0.10, unk, 0.50, 0.4],  # sentence 1
                [0.15, unk, 0.15, 0.7],  # sentence 2
                [0.00, unk, 0.00, 0.0],  # sentence 3
            ]),
            # step 3:
            torch.FloatTensor([
                # eos      w1    w2
                [0.9, unk, 0.05, 0.05],  # sentence 1
                [0.0, unk, 0.00, 0.0],  # sentence 2
                [0.0, unk, 0.00, 0.0],  # sentence 3
            ]),
        ]
        expected_scores = [
            [0.6, 0.7, 0.5, 0.9],  # sentence 1
            [0.6, 0.8, 0.15],  # sentence 2
            [0.3, 0.7],  # sentence 3
        ]

        task = test_utils.TestTranslationTask.setup_task(args, d, d)
        model = task.build_model(args)
        scorer = SequenceScorer([model], task.target_dictionary)
        for id, _src, _ref, hypos in scorer.score_batched_itr(data_itr):
            self.assertHypoTokens(hypos[0], data[id]['target'])
            self.assertHypoScore(hypos[0], expected_scores[id])
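For readability, the expected scores in this test can be read straight off `args.beam_probs`: at each step the scorer simply looks up the probability the dummy model assigns to the next reference token. A quick check of that correspondence for sentence 1 (target `w1 w2 w1 eos`), with the columns ordered `[eos, unk, w1, w2]` as in the comments above:

# Per-step rows copied from beam_probs (unk column written out as 0.0).
beam_rows = [
    [0.0, 0.0, 0.6, 0.4],    # step 0
    [0.0, 0.0, 0.2, 0.7],    # step 1
    [0.10, 0.0, 0.50, 0.4],  # step 2
    [0.9, 0.0, 0.05, 0.05],  # step 3
]
target_cols = [2, 3, 2, 0]  # w1, w2, w1, eos
scores = [row[col] for row, col in zip(beam_rows, target_cols)]
assert scores == [0.6, 0.7, 0.5, 0.9]  # == expected_scores[0]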
Exemplo n.º 15
0
def decode_from_dataset(models, task, args, use_cuda, output_filename=None):
    # Load dataset splits
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    output_filename = output_filename if output_filename is not None else args.decode_output_file
    if output_filename is not None:
        base_filename = output_filename
    else:
        base_filename = args.gen_subset
        if args.num_shards:
            base_filename += "%.2d" % args.shard_id
    decode_filename = _decode_filename(base_filename, args)
    outfile = open(decode_filename, "w")

    # Load dataset (possibly sharded)
    itr = data.EpochBatchIterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=models[0].max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models, task.target_dictionary, beam_size=args.beam,
            stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen, unk_penalty=args.unkpen,
            sampling=args.sampling, sampling_topk=args.sampling_topk, minlen=args.min_len,
        )

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True

    if args.score_reference:
        translations = translator.score_batched_itr(itr, cuda=use_cuda, timer=gen_timer)
    else:
        translations = translator.generate_batched_itr(
            itr, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
            cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
        )

    wps_meter = TimeMeter()
    for sample_id, src_tokens, target_tokens, hypos in translations:
        # Process input and ground truth
        has_target = target_tokens is not None
        target_tokens = target_tokens.int().cpu() if has_target else None

        # Either retrieve the original sentences or regenerate them from tokens.
        if align_dict is not None:
            src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
            target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
        else:
            src_str = src_dict.string(src_tokens, args.remove_bpe)
            if has_target:
                target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

        if not args.quiet:
            try:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))
            except UnicodeEncodeError:
                # Fall back to UTF-8 bytes when stdout cannot encode the string.
                print('S-{}\t{}'.format(sample_id, src_str.encode('utf-8')))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str.encode('utf-8')))

        # Process top predictions
        for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                align_dict=align_dict,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )

            if not args.quiet:
                try:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                except UnicodeEncodeError:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str.encode('utf-8')))
                print('P-{}\t{}'.format(
                    sample_id,
                    ' '.join(map(
                        lambda x: '{:.4f}'.format(x),
                        hypo['positional_scores'].tolist(),
                    ))
                ))
                if alignment is not None:
                    print('A-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(lambda x: str(utils.item(x)), alignment))
                    ))

            # Score only the top hypothesis
            if has_target and i == 0:
                if align_dict is not None or args.remove_bpe is not None:
                    # Convert back to tokens for evaluation with unk replacement and/or without BPE
                    target_tokens = tokenizer.Tokenizer.tokenize(
                        target_str, tgt_dict, add_if_not_exist=True)
                scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(src_tokens.size(0))

        num_sentences += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
Exemplo n.º 16
0
class FluencyScorer(object):
    """Evaluate sentences for fluency.
    The FluencyScorer class uses an embedded language model to score candidate
    sentences for according to how likely they would be used by a native
    speaker.
    """

    def __init__(self, path, data, use_cpu=True):
        # Create the language modeling task.
        self.args = FluencyArgs(path, data)
        self.task = tasks.setup_task(self.args)
        self.use_cuda = torch.cuda.is_available() and not use_cpu

        # Load language model ensemble.
        models, model_args = utils.load_ensemble_for_inference(self.args.path.split(':'), self.task)
        self.models = models
        self.model_args = model_args

        # Optimize ensemble for generation.
        for model in self.models:
            model.make_generation_fast_()
            if self.use_cuda and self.model_args.fp16:
                model.half()

        # Create the sequence scorer.
        self.scorer = SequenceScorer(self.models, self.task.target_dictionary)
        if self.use_cuda:
            self.scorer.cuda()

    def score_sentence(self, line):
        # Tokenize the input sentence into a batch of size one.
        tokens = tokenizer.Tokenizer.tokenize(line, self.task.dictionary, add_if_not_exist=False).long()
        lengths = np.array([tokens.numel()])
        ds = data.TokenBlockDataset(tokens, lengths, self.args.tokens_per_sample, pad=self.task.dictionary.pad(),
                                    eos=self.task.dictionary.eos(), break_mode=self.args.sample_break_mode,
                                    include_targets=True)

        # Create a batch iterator to wrap the data.
        add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'
        itr = self.task.get_batch_iterator(
            dataset=data.MonolingualDataset(ds, ds.sizes, self.task.dictionary, self.task.target_dictionary,
                                            add_eos_for_other_targets=add_eos_for_other_targets, shuffle=False,
                                            targets=self.task.targets),
            max_tokens=self.args.max_tokens or 3000,
            max_sentences=self.args.max_sentences,
            max_positions=utils.resolve_max_positions(*[
                model.max_positions() for model in self.models
            ]),
            num_shards=self.args.num_shards,
            shard_id=self.args.shard_id,
            ignore_invalid_inputs=True,
        ).next_epoch_itr(shuffle=False)

        # Evaluate the sentence and return the fluency score.
        results = self.scorer.score_batched_itr(itr, cuda=self.use_cuda)
        for _, _, _, hypos in results:
            for hypo in hypos:
                # Ignore words with infinite probability. This can happen when
                # running low-precision inference on the GPU.
                pos_scores = hypo['positional_scores']
                word_prob = [score for score in pos_scores if score != float('-inf') and score != float('inf')]
                return self._fluency_score(word_prob)
        return 0.0

    def _fluency_score(self, word_prob):
        """Calculate fluency score.
        Given the list of log-probabilities for each token, calculate the
        fluency score of the sentence.
        """

        # If there were no tokens because they were all filtered out for
        # having infinite probabilities, then give a minimum fluency score.
        if len(word_prob) == 0:
            return 0.0

        # H is the average negative log-probability (per-token cross-entropy);
        # the score 1 / (1 + H) maps it into (0, 1], higher meaning more fluent.
        H = 0.0
        for x in word_prob:
            H -= x
        H = H / len(word_prob)
        score = 1.0 / (1.0 + H)
        return score
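A minimal usage sketch for the class above, assuming a trained fairseq language-model checkpoint; the paths and sentences are placeholders, not values from the original project:

# Hypothetical checkpoint and data-bin paths for a trained fairseq LM.
scorer = FluencyScorer('checkpoints/lm/checkpoint_best.pt', 'data-bin/lm', use_cpu=True)

for line in ['He go to school every day .', 'He goes to school every day .']:
    # score_sentence returns 1.0 / (1.0 + H), where H is the average token
    # negative log-probability, so higher values indicate more fluent text.
    print('{:.4f}\t{}'.format(scorer.score_sentence(line), line))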
Exemplo n.º 17
0
def main(args):
    assert args.path is not None, '--path required for evaluation!'

    args.tokens_per_sample = getattr(args, 'tokens_per_sample', 1024)
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task)

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()

    assert len(models) > 0

    itr = data.EpochBatchIterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=models[0].max_positions(),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        ignore_invalid_inputs=True,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        bpe_cont = args.remove_bpe.rstrip()
        bpe_toks = set(i for i in range(len(task.dictionary))
                       if task.dictionary[i].endswith(bpe_cont))
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']

                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(len(hypo['tokens']) - 1):
                        if hypo['tokens'][i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(
                    float('-inf'))
                if inf_scores.any():
                    print(
                        '| Skipping tokens with inf scores:',
                        task.target_dictionary.string(
                            hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs:
                    w = ''
                    word_prob = []
                    for i in range(len(hypo['tokens'])):
                        w_ind = hypo['tokens'][i].item()
                        w += task.dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                        else:
                            word_prob.append((w, pos_scores[i].item()))
                            w = ''
                    print('\t'.join('{} [{:2f}]'.format(x[0], x[1])
                                    for x in word_prob))

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss,
                                                      np.exp(avg_nll_loss)))
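The BPE handling above folds the score of a continuation subword into the token that follows it, so that the token count reflects whole words rather than subword pieces. A small self-contained illustration of that merge, using made-up subwords and scores (assuming an '@@' continuation marker):

# Hypothetical subword hypothesis "new@@ er models" with per-token log-probs.
tokens = ['new@@', 'er', 'models']
pos_scores = [-0.5, -0.3, -1.0]
bpe_cont = '@@'

skipped_toks = 0
for i in range(len(tokens) - 1):
    if tokens[i].endswith(bpe_cont):
        # Shift the continuation's score onto the next subword and zero it out.
        skipped_toks += 1
        pos_scores[i + 1] += pos_scores[i]
        pos_scores[i] = 0
# Word-level scores: 'newer' -> -0.8, 'models' -> -1.0; two tokens are counted.
count = len(pos_scores) - skipped_toks
assert count == 2 and abs(pos_scores[1] - (-0.8)) < 1e-9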
Exemplo n.º 18
0
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models, task.target_dictionary, beam_size=args.beam, minlen=args.min_len,
            stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen, unk_penalty=args.unkpen,
            sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
            diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
        )

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.score_reference:
            translations = translator.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
                cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
            )

        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
            else:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                if has_target:
                    target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                    print('P-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(
                            lambda x: '{:.4f}'.format(x),
                            hypo['positional_scores'].tolist(),
                        ))
                    ))

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(lambda x: str(utils.item(x)), alignment))
                        ))

                # Score only the top hypothesis
                if has_target and i == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str, tgt_dict, add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
Exemplo n.º 19
0
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(
        args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(
        ':'), task, model_arg_overrides=eval(args.model_overrides))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models, task.target_dictionary, beam_size=args.beam, minlen=args.min_len,
            stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen, unk_penalty=args.unkpen,
            sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
            diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
        )

    if use_cuda:
        translator.cuda()

    # Initialize fluency scorer (and language model)
    fluency_scorer = FluencyScorer(
        args.lang_model_path, args.lang_model_data, use_cpu=False)

    en_filename = os.path.join(args.out_dir, 'errorgen.en')
    gec_filename = os.path.join(args.out_dir, 'errorgen.gec')
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t, open(en_filename, 'w') as en_file, open(gec_filename, 'w') as gec_file:
        if args.score_reference:
            translations = translator.score_batched_itr(
                t, cuda=use_cuda, timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
                cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
            )

        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(
                    args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(
                    args.gen_subset).tgt.get_original_text(sample_id)
            else:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                if has_target:
                    target_str = tgt_dict.string(
                        target_tokens, args.remove_bpe, escape_unk=True)

            # Only consider sentences with at least four words.
            if len(src_tokens) < 5:
                continue

            # Calculate the fluency score for the source sentence
            source_fluency = fluency_scorer.score_sentence(src_str)

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu(
                    ) if hypo['alignment'] is not None else None,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                # Skip if this is the original sentence.
                if hypo_str == target_str:
                    continue

                # Score the hypothesis.
                hypo_fluency = fluency_scorer.score_sentence(hypo_str)

                # Save the hypothesis if it is sufficiently disfluent
                # (guard against a zero fluency score from all-inf probabilities).
                if hypo_fluency > 0.0 and (source_fluency / hypo_fluency) > 1.05:
                    en_file.write('{}\n'.format(hypo_str))
                    gec_file.write('{}\n'.format(src_str))
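The ratio test above keeps a hypothesis as a synthetic error sentence only when it is roughly 5% or more less fluent than the source. A toy check of that criterion with made-up fluency scores:

# Hypothetical fluency scores (score = 1 / (1 + average token NLL)).
source_fluency = 0.48   # original, grammatical sentence
hypo_fluency = 0.41     # perturbed hypothesis

# Keep the pair (hypothesis as the noisy side, source as the correction)
# only when the source is sufficiently more fluent than the hypothesis.
keep = hypo_fluency > 0.0 and (source_fluency / hypo_fluency) > 1.05
assert keep  # 0.48 / 0.41 ~= 1.17 > 1.05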
Exemplo n.º 20
0
def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(
        parsed_args.path.split(':'),
        task,
        model_arg_overrides=eval(parsed_args.model_overrides))

    for arg in vars(parsed_args).keys():
        if arg not in {
                'self_target', 'future_target', 'past_target',
                'tokens_per_sample', 'output_size_dictionary'
        }:
            setattr(args, arg, getattr(parsed_args, arg))
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()

    assert len(models) > 0

    print('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        ignore_invalid_inputs=True,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        bpe_cont = args.remove_bpe.rstrip()
        bpe_toks = set(i for i in range(len(task.dictionary))
                       if task.dictionary[i].endswith(bpe_cont))
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']

                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(len(hypo['tokens']) - 1):
                        if hypo['tokens'][i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(
                    float('-inf'))
                if inf_scores.any():
                    print(
                        '| Skipping tokens with inf scores:',
                        task.target_dictionary.string(
                            hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(hypo['tokens'])):
                        w_ind = hypo['tokens'][i].item()
                        w += task.dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))

                            next_prob = None
                            ind = i + 1
                            while ind < len(hypo['tokens']):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                                pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print('\t'.join('{} [{:2f}]'.format(x[0], x[1])
                                        for x in word_prob))

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss,
                                                      np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(),
                         key=lambda x: x.count,
                         reverse=True):
            print(ws)
Exemplo n.º 21
0
def eval_from_file(models, task, args, use_cuda, source_filename=None,
                   target_filename=None, score_filename=None):
    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # I/O files
    source_filename = source_filename if source_filename is not None else args.source_file
    target_filename = target_filename if target_filename is not None else args.target_file
    score_filename = score_filename if score_filename is not None else args.score_file
    if score_filename is None:
        score_filename = target_filename + ".eval.score"
    outfile = open(score_filename, "w")

    # Get sorted input (and reversed)
    sorted_inputs, sorted_keys, sorted_targets = _get_sorted_inputs(
        source_filename, args.num_shards, args.delimiter, target_filename, args.shard_id,
        args.dup_src, args.dup_tgt)

    # Build input iterator
    src_tokens = [
        tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
        for src_str in sorted_inputs]
    tgt_tokens = [
        tokenizer.Tokenizer.tokenize(tgt_str, tgt_dict, add_if_not_exist=False).long()
        for tgt_str in sorted_targets] if sorted_targets is not None else None
    src_sizes = np.array([t.numel() for t in src_tokens])
    tgt_sizes = np.array([t.numel() for t in tgt_tokens]) if tgt_tokens is not None else None
    print('| loading {} examples, {} tokens'.format(len(sorted_inputs), sum(src_sizes)))

    dataset = data.LanguagePairDataset(
        src_tokens, src_sizes, src_dict, tgt_tokens, tgt_sizes, tgt_dict, shuffle=False)
    itr = data.EpochBatchIterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=models[0].max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    all_scores = dict()
    score_sum = 0.
    count, sen_count = 0, 0
    results = scorer.score_batched_itr(itr, cuda=use_cuda, timer=gen_timer)
    wps_meter = TimeMeter()
    for sample_id, src_tokens, target_tokens, hypos in results:
        for i, hypo in enumerate(hypos):
            pos_scores = hypo['positional_scores']
            inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
            if inf_scores.any():
                print('| Skipping tokens with inf scores:',
                      task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum()
            count += pos_scores.numel()
            sentence_score = hypo['score']
            if i == 0:
                all_scores[sample_id.tolist()] = sentence_score
        sen_count += 1
        wps_meter.update(src_tokens.size(0))

    print("| [eval] writing scores into {}".format(score_filename))
    # print(sids)
    for index in range(len(sorted_inputs)):
        outfile.write("{}{}".format(all_scores[sorted_keys[index]], args.delimiter))
    outfile.close()

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))
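For clarity, the `sorted_keys` indirection above writes scores back in the original file order after the inputs were reordered (typically by length) for batching. A toy sketch of that bookkeeping, under the assumption that `sorted_keys[i]` gives the sorted-order position of the i-th original line, which is what the writing loop implies; the helper logic below is illustrative, not the actual `_get_sorted_inputs`:

# Hypothetical inputs, sorted by length (descending) for batching.
inputs = ['a much longer input sentence', 'short', 'a mid length input']
order = sorted(range(len(inputs)), key=lambda i: -len(inputs[i].split()))
sorted_inputs = [inputs[i] for i in order]
sorted_keys = {orig_i: sorted_i for sorted_i, orig_i in enumerate(order)}

# Pretend these are the per-sentence scores produced in sorted (batch) order.
all_scores = {0: -1.2, 1: -3.4, 2: -5.6}

# Emit one score per original line, mirroring the writing loop above.
for index in range(len(inputs)):
    print(all_scores[sorted_keys[index]])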