Example #1
    def score_batched_itr(self, data_itr, cuda=False, timer=None):
        """Iterate over a batched dataset and yield scored translations."""
        for sample in data_itr:
            s = utils.move_to_cuda(sample) if cuda else sample
            if timer is not None:
                timer.start()
            pos_scores, attn = self.score(s)
            for i, id in enumerate(s['id'].data):
                # remove padding from ref
                src = utils.strip_pad(s['net_input']['src_tokens'].data[i, :], self.pad)
                ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                tgt_len = ref.numel()
                pos_scores_i = pos_scores[i][:tgt_len]
                score_i = pos_scores_i.sum() / tgt_len
                if attn is not None:
                    attn_i = attn[i]
                    _, alignment = attn_i.max(dim=0)
                else:
                    attn_i = alignment = None
                hypos = [{
                    'tokens': ref,
                    'score': score_i,
                    'attention': attn_i,
                    'alignment': alignment,
                    'positional_scores': pos_scores_i,
                }]
                if timer is not None:
                    timer.stop(s['ntokens'])
                # return results in the same format as SequenceGenerator
                yield id, src, ref, hypos
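
For reference, a minimal standalone sketch of the scoring arithmetic above (PAD is a stand-in for self.pad): utils.strip_pad drops padding positions, and the sentence score is the mean of the remaining positional log-probabilities.

import torch

PAD = 1  # stand-in for self.pad
target = torch.tensor([5, 9, 7, 2, PAD, PAD])            # padded reference
pos_scores = torch.tensor([-0.3, -1.2, -0.8, -0.1, 0.0, 0.0])

ref = target[target.ne(PAD)]      # what utils.strip_pad computes
tgt_len = ref.numel()             # 4
score = pos_scores[:tgt_len].sum() / tgt_len
print(score.item())               # -0.6, mean log-prob per reference token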
Example #2
    def generate_batched_itr(
        self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
        cuda=False, timer=None, prefix_size=0,
    ):
        """Iterate over a batched dataset and yield individual translations.
        Args:
            maxlen_a/b: generate sequences of maximum length ax + b,
                where x is the source sentence length.
            cuda: use GPU for generation
            timer: StopwatchMeter for timing generations.
        """
        if maxlen_b is None:
            maxlen_b = self.maxlen

        for sample in data_itr:
            s = utils.move_to_cuda(sample) if cuda else sample
            if 'net_input' not in s:
                continue
            input = s['net_input']
            srclen = input['src_tokens'].size(1)
            if timer is not None:
                timer.start()
            with torch.no_grad():
                hypos = self.generate(
                    input['src_tokens'],
                    input['src_lengths'],
                    beam_size=beam_size,
                    maxlen=int(maxlen_a*srclen + maxlen_b),
                    prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
                )
            if timer is not None:
                timer.stop(sum(len(h[0]['tokens']) for h in hypos))
            for i, id in enumerate(s['id'].data):
                # remove padding
                src = utils.strip_pad(input['src_tokens'].data[i, :], self.pad)
                ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                yield id, src, ref, hypos[i]
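
A quick check of the maxlen formula passed to generate above, maxlen = maxlen_a * srclen + maxlen_b, with illustrative values for both coefficients:

maxlen_a, maxlen_b = 0.5, 10  # illustrative values only
for srclen in (8, 20, 100):
    print(srclen, int(maxlen_a * srclen + maxlen_b))  # -> 14, 20, 60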
Example #3
    def valid_step(self, sample, raise_oom=False, domain=None):
        """Do forward pass in evaluation mode."""
        with torch.no_grad():
            self.model.eval()
            self.criterion.eval()
            sample = self._prepare_sample(sample)
            if sample is None:
                sample = self._prepare_sample(self._dummy_batch)
                ignore_results = True
            else:
                ignore_results = False

            try:
                _loss, sample_size, logging_output = self.task.valid_step(
                    sample,
                    self.model,
                    self.criterion,
                    domain=domain,
                    num_updates=self._num_updates)
                src_target_hypo_str = []
                prefix_tokens = None
                if self.args.prefix_size > 0:
                    prefix_tokens = sample['target'][:, :self.args.prefix_size]

                hypos = self.task.inference_step(self._generator, [self.model],
                                                 sample,
                                                 prefix_tokens,
                                                 domain=domain)
                for i, sample_id in enumerate(sample['id'].tolist()):
                    assert self.args.nbest == 1, 'bleu valid needs nbest == 1'

                    has_target = sample['target'] is not None

                    # Remove padding
                    src_tokens = utils.strip_pad(
                        sample['net_input']['src_tokens'][i, :],
                        self._tgt_dict.pad())
                    target_tokens = None
                    if has_target:
                        target_tokens = utils.strip_pad(
                            sample['target'][i, :],
                            self._tgt_dict.pad()).int().cpu()

                    # Either retrieve the original sentences or regenerate them from tokens.

                    src_str = self._src_dict.string(src_tokens,
                                                    self.args.remove_bpe)

                    if has_target:
                        target_str = self._tgt_dict.string(
                            target_tokens,
                            self.args.remove_bpe,
                            escape_unk=True)

                    # Process top predictions
                    for j, hypo in enumerate(
                            hypos[i][:min(len(hypos[i]), self.args.nbest)]):

                        hypo_str = self._tgt_dict.string(
                            hypo['tokens'].int().cpu(), self.args.remove_bpe)

                        src_target_hypo_str.append(
                            (sample_id, src_str, target_str, hypo_str))

            except RuntimeError as e:
                if 'out of memory' in str(e) and not raise_oom:
                    print('| WARNING: ran out of memory, retrying batch')
                    for p in self.model.parameters():
                        if p.grad is not None:
                            del p.grad  # free some memory
                    if self.cuda:
                        torch.cuda.empty_cache()
                    return self.valid_step(sample, raise_oom=True, domain=domain)
                else:
                    raise e

            if ignore_results:
                logging_output, sample_size = {}, 0
                src_target_hypo_str = []

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_output, sample_size, src_target_hypo_str = zip(
                *distributed_utils.all_gather_list_big(
                    [logging_output, sample_size, src_target_hypo_str]))

            logging_output = list(logging_output)
            sample_size = list(sample_size)
            src_target_hypo_str = list(src_target_hypo_str)

        else:
            logging_output = [logging_output]
            sample_size = [sample_size]
            src_target_hypo_str = [src_target_hypo_str]

        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_output, self.criterion)
        sample_size = self.task.grad_denom(sample_size, self.criterion)

        # update meters for validation
        ntokens = logging_output.get('ntokens', 0)

        self.meters['valid_loss_' + domain].update(
            logging_output.get('loss', 0), sample_size)

        if 'nll_loss' in logging_output:
            self.meters['valid_nll_loss_' + domain].update(
                logging_output.get('nll_loss', 0), ntokens)

        return logging_output, src_target_hypo_str
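
The except branch above is a general OOM-retry pattern: drop gradients, empty the CUDA cache, and retry the same batch exactly once. A self-contained sketch of just that pattern, with step_fn standing in for task.valid_step:

import torch

def run_with_oom_retry(step_fn, model, sample, raise_oom=False):
    """Run step_fn(sample); on CUDA OOM, free memory and retry once."""
    try:
        return step_fn(sample)
    except RuntimeError as e:
        if 'out of memory' in str(e) and not raise_oom:
            print('| WARNING: ran out of memory, retrying batch')
            for p in model.parameters():
                if p.grad is not None:
                    del p.grad  # free some memory
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return run_with_oom_retry(step_fn, model, sample, raise_oom=True)
        raise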
Example #4
    def _inference_with_bleu(self, generator, sample, model):
        import sacrebleu

        def decode(toks, escape_unk=False, dicts=self.tgt_dict):
            s = dicts.string(
                toks.int().cpu(),
                self.args.eval_bleu_remove_bpe,
                # The default unknown string in fairseq is `<unk>`, but
                # this is tokenized by sacrebleu as `< unk >`, inflating
                # BLEU scores. Instead, we use a somewhat more verbose
                # alternative that is unlikely to appear in the real
                # reference, but doesn't get split into multiple tokens.
                unk_string=("UNKNOWNTOKENINREF"
                            if escape_unk else "UNKNOWNTOKENINHYP"),
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        prefix_tokens = None
        # if self.args.prefix_lang > 0:
        #     prefix_tokens = sample["target"][:, :1]
        gen_out = self.inference_step(generator, [model],
                                      sample,
                                      prefix_tokens=prefix_tokens)

        hyps, refs = [], []
        srcs = []
        for i in range(len(gen_out)):
            hyps_item = decode(gen_out[i][0]["tokens"])
            hyps_item_bpe = self.remove_sentencepiece(
                hyps_item, (self.args.bpe == 'sentencepiece'))
            if hyps_item_bpe[0] == "[":  # rm lang-tok id.
                hyps_item_bpe_split = " ".join(hyps_item_bpe.split(" ")[1:])
            else:
                hyps_item_bpe_split = hyps_item_bpe
            hyps.append(hyps_item_bpe_split)

            lang_tokens = ""
            ref_item = decode(
                utils.strip_pad(sample["target"][i], self.tgt_dict.pad()),
                escape_unk=True,  # don't count <unk> as matches to the hypo
            )
            ref_item_bpe = self.remove_sentencepiece(
                ref_item, (self.args.bpe == 'sentencepiece'))

            if ref_item_bpe[0] == "[":  # rm lang-tok id.
                ref_item_bpe = ref_item_bpe.split(" ")
                ref_item_bpe_split = " ".join(ref_item_bpe[1:])
                lang_tokens = ref_item_bpe[0]
            else:
                ref_item_bpe_split = ref_item_bpe
            refs.append(ref_item_bpe_split)

            src_item = decode(
                utils.strip_pad(sample["net_input"]["src_tokens"][i],
                                self.tgt_dict.pad()),
                escape_unk=True,  # don't count <unk> as matches to the hypo
                dicts=self.src_dict,
            )
            src_item_bpe = self.remove_sentencepiece(
                src_item, (self.args.bpe == 'sentencepiece'))
            srcs.append(lang_tokens + "| |" + src_item_bpe)

        if self.args.eval_bleu_print_samples:
            logger.info("example sources: " + srcs[0])
            logger.info("example hypothesis: " + hyps[0])
            logger.info("example reference: " + refs[0])

        if self.args.eval_tokenized_bleu:
            return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none")
        else:
            if hasattr(self.args,
                       "target_lang") and self.args.target_lang == 'zh':
                return sacrebleu.corpus_bleu(hyps, [refs], tokenize="zh")
            return sacrebleu.corpus_bleu(hyps, [refs])
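
The corpus_bleu calls at the end can be exercised in isolation; a toy example of the same sacrebleu API (requires pip install sacrebleu):

import sacrebleu

hyps = ["the cat sat on the mat", "hello world"]
refs = ["the cat sat on a mat", "hello world"]  # one reference stream
bleu = sacrebleu.corpus_bleu(hyps, [refs])
print(bleu.score)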
Example #5
def main(args, task=None, model_state=None):
    check_args(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    if task is None:
        # Load dataset splits
        task = tasks.setup_task(args)
        task.load_dataset(args.gen_subset)
        logger.info("| {} {} {} examples".format(
            args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info("| decoding with criterion {}".format(args.criterion))

    # Load ensemble
    print(args)
    if args.load_emissions:
        models, criterions = [], []
    else:
        logger.info("| loading model(s) from {}".format(args.path))
        models, criterions, _ = load_models_and_criterions(
            args.path,
            data_path=args.data,
            arg_overrides=eval(args.model_overrides),  # noqa
            task=task,
            model_state=model_state,
        )
        optimize_models(args, use_cuda, models)

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        trans = criterions[0].asg.trans.data
        args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(args):
        w2l_decoder = getattr(args, "w2l_decoder", None)
        if w2l_decoder == "viterbi":
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, task.target_dictionary)
        elif w2l_decoder == "kenlm":
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
            print(task.target_dictionary.symbols)
            return W2lKenLMDecoder(args, task.target_dictionary)
        elif w2l_decoder == "fairseqlm":
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder
            print(task.target_dictionary.symbols)
            return W2lFairseqLMDecoder(args, task.target_dictionary)
        else:
            # `super()` has no meaning inside this nested function; fail
            # loudly for unsupported decoders instead
            raise RuntimeError(
                "only viterbi, kenlm and fairseqlm w2l decoders are supported")

    generator = build_generator(args)

    if args.load_emissions:
        generator = ExistingEmissionsDecoder(
            generator, np.load(args.load_emissions, allow_pickle=True))
        logger.info("loaded emissions from " + args.load_emissions)

    num_sentences = 0

    if args.results_path is not None and not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    max_source_pos = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models])

    if max_source_pos is not None:
        max_source_pos = max_source_pos[0] - 1

    if args.dump_emissions:
        emissions = {}
    if args.dump_features:
        features = {}
        models[0].bert.proj = None
    else:
        res_files = prepare_result_files(args)
    errs_t = 0
    lengths_t = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, :args.prefix_size]

            gen_timer.start()
            if args.dump_emissions:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    emm = models[0].get_normalized_probs(encoder_out,
                                                         log_probs=True)
                    emm = emm.transpose(0, 1).cpu().numpy()
                    for i, id in enumerate(sample["id"]):
                        emissions[id.item()] = emm[i]
                    continue
            elif args.dump_features:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    feat = encoder_out["encoder_out"].transpose(
                        0, 1).cpu().numpy()
                    for i, id in enumerate(sample["id"]):
                        padding = encoder_out["encoder_padding_mask"][i].cpu(
                        ).numpy() if encoder_out[
                            "encoder_padding_mask"] is not None else None
                        features[id.item()] = (feat[i], padding)
                    continue
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                speaker = None
                # id = task.dataset(args.gen_subset).ids[int(sample_id)]
                id = sample_id
                toks = (sample["target"][i, :]
                        if 'target_label' not in sample
                        else sample["target_label"][i, :])
                target_tokens = (utils.strip_pad(toks,
                                                 tgt_dict.pad()).int().cpu())
                # Process top predictions
                errs, length = process_predictions(args, hypos[i], None,
                                                   tgt_dict, target_tokens,
                                                   res_files, speaker, id)
                errs_t += errs
                lengths_t += length

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += (sample["nsentences"]
                              if "nsentences" in sample else sample["id"].numel())

    wer = None
    if args.dump_emissions:
        emm_arr = []
        for i in range(len(emissions)):
            emm_arr.append(emissions[i])
        np.save(args.dump_emissions, emm_arr)
        logger.info(
            f"saved {len(emissions)} emissions to {args.dump_emissions}")
    elif args.dump_features:
        feat_arr = []
        for i in range(len(features)):
            feat_arr.append(features[i])
        np.save(args.dump_features, feat_arr)
        logger.info(f"saved {len(features)} emissions to {args.dump_features}")
    else:
        if lengths_t > 0:
            wer = errs_t * 100.0 / lengths_t
            logger.info(f"WER: {wer}")
            print(f"WER: {wer}")

        logger.info("| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}"
                    "sentences/s, {:.2f} tokens/s)".format(
                        num_sentences,
                        gen_timer.n,
                        gen_timer.sum,
                        num_sentences / gen_timer.sum,
                        1.0 / gen_timer.avg,
                    ))
        logger.info("| Generate {} with beam={}".format(
            args.gen_subset, args.beam))
    return task, wer
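
The final WER is plain accumulation arithmetic: errs_t sums per-utterance word errors (edit distance), lengths_t sums reference lengths, and wer = errs_t * 100.0 / lengths_t. A hedged sketch with a minimal edit-distance helper standing in for what process_predictions returns:

def edit_distance(ref, hyp):
    """Word-level Levenshtein distance with a single-row DP table."""
    d = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, d[0] = d[0], i
        for j, h in enumerate(hyp, 1):
            prev, d[j] = d[j], min(d[j] + 1, d[j - 1] + 1, prev + (r != h))
    return d[-1]

errs_t = edit_distance("a b c d".split(), "a x c".split())  # 2 errors
lengths_t = 4                                               # reference words
print(errs_t * 100.0 / lengths_t)                           # WER: 50.0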
Example #6
def main(cfg: FairseqConfig):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    start_time = time.time()
    total_translate_time = 0

    utils.import_user_module(cfg.common)

    if cfg.interactive.buffer_size < 1:
        cfg.interactive.buffer_size = 1
    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.batch_size = 1

    assert (not cfg.generation.sampling
            or cfg.generation.nbest == cfg.generation.beam
            ), "--sampling requires --nbest to be equal to --beam"
    assert (not cfg.dataset.batch_size
            or cfg.dataset.batch_size <= cfg.interactive.buffer_size
            ), "--batch-size cannot be larger than --buffer-size"

    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(cfg.task)

    # Load ensemble
    overrides = ast.literal_eval(cfg.common_eval.model_overrides)
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Initialize generator
    generator = task.build_generator(models, cfg.generation)

    # Handle tokenization and BPE
    tokenizer = task.build_tokenizer(cfg.tokenizer)
    bpe = task.build_bpe(cfg.bpe)

    def encode_fn(x):
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models])

    if cfg.generation.constraints:
        logger.warning(
            "NOTE: Constrained decoding currently assumes a shared subword vocabulary."
        )

    if cfg.interactive.buffer_size > 1:
        logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size)
    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info("Type the input sentence and press return:")
    start_id = 0
    for inputs in buffered_read(cfg.interactive.input,
                                cfg.interactive.buffer_size):
        results = []
        for batch in make_batches(inputs, cfg, task, max_positions, encode_fn):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            constraints = batch.constraints
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if constraints is not None:
                    constraints = constraints.cuda()

            sample = {
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": src_lengths,
                },
            }
            translate_start_time = time.time()
            translations = task.inference_step(generator,
                                               models,
                                               sample,
                                               constraints=constraints)
            translate_time = time.time() - translate_start_time
            total_translate_time += translate_time
            list_constraints = [[] for _ in range(bsz)]
            if cfg.generation.constraints:
                list_constraints = [unpack_constraints(c) for c in constraints]
            for i, (id, hypos) in enumerate(
                    zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                constraints = list_constraints[i]
                results.append((
                    start_id + id,
                    src_tokens_i,
                    hypos,
                    {
                        "constraints": constraints,
                        "time": translate_time / len(translations),
                    },
                ))

        # sort output to match input order
        for id_, src_tokens, hypos, info in sorted(results,
                                                   key=lambda x: x[0]):
            src_str = ''
            if src_dict is not None:
                src_str = src_dict.string(src_tokens,
                                          cfg.common_eval.post_process)
                print("S-{}\t{}".format(id_, src_str))
                print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
                for constraint in info["constraints"]:
                    print("C-{}\t{}".format(
                        id_,
                        tgt_dict.string(constraint,
                                        cfg.common_eval.post_process)))

            # Process top predictions
            for hypo in hypos[:min(len(hypos), cfg.generation.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=cfg.common_eval.post_process,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                        generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                score = hypo["score"] / math.log(2)  # convert to base 2
                # original hypothesis (after tokenization and BPE)
                print("H-{}\t{}\t{}".format(id_, score, hypo_str))
                # detokenized hypothesis
                print("D-{}\t{}\t{}".format(id_, score, detok_hypo_str))
                print("P-{}\t{}".format(
                    id_,
                    " ".join(
                        map(
                            lambda x: "{:.4f}".format(x),
                            # convert from base e to base 2
                            hypo["positional_scores"].div_(math.log(2)
                                                           ).tolist(),
                        )),
                ))
                if cfg.generation.print_alignment:
                    alignment_str = " ".join(
                        ["{}-{}".format(src, tgt) for src, tgt in alignment])
                    print("A-{}\t{}".format(id_, alignment_str))

        # update running id_ counter
        start_id += len(inputs)

    logger.info("Total time: {:.3f} seconds; translation time: {:.3f}".format(
        time.time() - start_time, total_translate_time))
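
The repeated division by math.log(2) is the base change the log message announces: fairseq scores are natural-log probabilities, and log2(p) = ln(p) / ln(2).

import math

ln_prob = math.log(0.25)      # natural-log probability, as the model reports
print(ln_prob / math.log(2))  # -2.0 == log2(0.25)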
Example #7
    def infer(self, lines_of_text):
        context = self.context

        bpe = context['bpe']
        tokenizer = context['tokenizer']
        cfg = context['cfg']
        task = context['task']
        max_positions = context['max_positions']
        use_cuda = context['use_cuda']
        generator = context['generator']
        models = context['models']
        src_dict = context['src_dict']
        tgt_dict = context['tgt_dict']
        align_dict = context['align_dict']

        def encode_fn(x):
            if tokenizer is not None:
                x = tokenizer.encode(x)
            if bpe is not None:
                x = bpe.encode(x)
            return x

        def decode_fn(x):
            if bpe is not None:
                x = bpe.decode(x)
            if tokenizer is not None:
                x = tokenizer.decode(x)
            return x

        start_id = 0
        for inputs in [lines_of_text]:
            results = []
            for batch in make_batches(inputs, cfg, task, max_positions,
                                      encode_fn):
                src_tokens = batch.src_tokens
                src_lengths = batch.src_lengths
                if use_cuda:
                    src_tokens = src_tokens.cuda()
                    src_lengths = src_lengths.cuda()
                sample = {
                    "net_input": {
                        "src_tokens": src_tokens,
                        "src_lengths": src_lengths,
                    },
                }
                translations = task.inference_step(generator, models, sample)
                for i, (id, hypos) in enumerate(
                        zip(batch.ids.tolist(), translations)):
                    src_tokens_i = utils.strip_pad(src_tokens[i],
                                                   tgt_dict.pad())
                    results.append((
                        start_id + id,
                        src_tokens_i,
                        hypos,
                    ))

            # sort output to match input order
            for id_, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
                src_str = ''
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens,
                                              cfg.common_eval.post_process)

                # Process top predictions
                for hypo in hypos[:min(len(hypos), cfg.generation.nbest)]:
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo["tokens"].int().cpu(),
                        src_str=src_str,
                        alignment=hypo["alignment"],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=cfg.common_eval.post_process,
                        extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                            generator),
                    )
                    detok_hypo_str = decode_fn(hypo_str)
                    score = hypo["score"] / math.log(2)  # convert to base 2
                    yield (detok_hypo_str, hypo_str, hypo_tokens)

            # update running id_ counter
            start_id += len(inputs)
Example #8
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        lm_logits, output_metadata = model(**sample["net_input"])

        # reshape lm_logits from (N,T,C) to (N*T,C)
        lm_logits = lm_logits.view(-1, lm_logits.size(-1))
        lm_targets = sample['lm_target'].view(-1)
        lm_loss = compute_cross_entropy_loss(
            lm_logits, lm_targets, self.padding_idx)

        # compute the number of tokens for which loss is computed. This is used
        # to normalize the loss
        ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel()
        loss = lm_loss / ntokens
        nsentences = sample['nsentences']
        # nsentences = 0

        # Compute sentence loss if masked_lm_only is False
        sentence_loss = None
        if not self.args.masked_lm_only:
            sentence_logits = output_metadata['sentence_logits']
            sentence_targets = sample['sentence_target'].view(-1)
            # This needs to be recomputed due to some differences between
            # TokenBlock and BlockPair dataset. This can be resolved with a
            # refactor of BERTModel which we will do in the future.
            # TODO: Remove this after refactor of BERTModel
            nsentences = sentence_targets.size(0)

            # Check for logits being none which can happen when remove_heads
            # is set to true in the BERT model. Ideally we should set
            # masked_lm_only to true in this case, but that requires some
            # refactor in the BERT model.
            if sentence_logits is not None:
                sentence_loss = compute_cross_entropy_loss(
                    sentence_logits, sentence_targets)

                loss += self.args.nsp_loss_weight * (sentence_loss / nsentences)

        # NOTE: as we are summing up per token mlm loss and per sentence nsp loss
        # we don't need to use sample_size as denominator for the gradient
        # here sample_size is just used for logging
        sample_size = 1
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'lm_loss': utils.item(lm_loss.data) if reduce else lm_loss.data,
            # sentence loss is not always computed
            'sentence_loss': (
                (
                    utils.item(sentence_loss.data) if reduce
                    else sentence_loss.data
                ) if sentence_loss is not None else 0.0
            ),
            'ntokens': ntokens,
            'nsentences': nsentences,
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
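
The reshaping at the top of forward can be checked in isolation: (N, T, C) logits flatten to (N*T, C), padding is excluded from the token count, and the summed loss is normalized by that count. A sketch using F.cross_entropy with reduction='sum' as a stand-in for compute_cross_entropy_loss:

import torch
import torch.nn.functional as F

PAD = 1  # assumed padding index
N, T, C = 2, 5, 11
lm_logits = torch.randn(N, T, C)
lm_targets = torch.randint(2, C, (N, T))
lm_targets[0, 3:] = PAD                      # simulate padded positions

flat_logits = lm_logits.view(-1, C)          # (N*T, C)
flat_targets = lm_targets.view(-1)           # (N*T,)
lm_loss = F.cross_entropy(flat_logits, flat_targets,
                          ignore_index=PAD, reduction='sum')
ntokens = flat_targets.ne(PAD).sum().item()  # what strip_pad(...).numel() counts
print((lm_loss / ntokens).item())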
Example #9
def main(args):
    import_user_module(args)

    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    ctx_dict = task.context_dictionary  # [CONTEXT]
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Initialize generator
    generator = task.build_generator(args)

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )

    if args.buffer_size > 1:
        print('| Sentence buffer size:', args.buffer_size)
    print('| Type the input sentence and press return:')
    start_id = 0
    for inputs in buffered_read(args.input, args.buffer_size):
        results = []
        for batch in make_batches(inputs, args, task, max_positions):
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            # [CONTEXT]
            ctx_tokens = batch.ctx_tokens
            ctx_lengths = batch.ctx_lengths
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                # [CONTEXT]
                ctx_tokens = ctx_tokens.cuda()
                ctx_lengths = ctx_lengths.cuda()

            sample = {
                'net_input': {
                    'src_tokens': src_tokens,
                    'src_lengths': src_lengths,
                    # [CONTEXT]
                    'ctx_tokens': ctx_tokens,
                    'ctx_lengths': ctx_lengths,
                },
            }
            translations = task.inference_step(generator, models, sample)
            for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                ctx_tokens_i = utils.strip_pad(ctx_tokens[i], tgt_dict.pad())  # [CONTEXT]
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                # [CONTEXT]
                # results.append((start_id + id, src_tokens_i, hypos))
                results.append((start_id + id, ctx_tokens_i, src_tokens_i, hypos))

        # sort output to match input order
        # [CONTEXT]
        # for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
        for id, ctx_tokens, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
            # [CONTEXT]
            if ctx_dict is not None:
                ctx_str = ctx_dict.string(ctx_tokens, args.remove_bpe)
                print('C-{}\t{}'.format(id, ctx_str))

            if src_dict is not None:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                print('S-{}\t{}'.format(id, src_str))

            # Process top predictions
            for hypo in hypos[:min(len(hypos), args.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )
                print('H-{}\t{}\t{}'.format(id, hypo['score'], hypo_str))
                print('P-{}\t{}'.format(
                    id,
                    ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
                ))
                if args.print_alignment:
                    print('A-{}\t{}'.format(
                        id,
                        ' '.join(map(lambda x: str(utils.item(x)), alignment))
                    ))

        # update running id counter
        start_id += len(results)
Example #10
    def generate(self, models, sample, **kwargs):
        """Score a batch of translations."""
        net_input = sample['net_input']

        def batch_for_softmax(dec_out, target):
            # assumes decoder_out[0] is the only thing needed (may not be correct for future models!)
            first, rest = dec_out[0], dec_out[1:]
            bsz, tsz, dim = first.shape
            if bsz * tsz < self.softmax_batch:
                yield dec_out, target, True
            else:
                flat = first.contiguous().view(1, -1, dim)
                flat_tgt = target.contiguous().view(flat.shape[:-1])
                s = 0
                while s < flat.size(1):
                    e = s + self.softmax_batch
                    yield (flat[:, s:e], ) + rest, flat_tgt[:, s:e], False
                    s = e

        def gather_target_probs(probs, target):
            probs = probs.gather(
                dim=2,
                index=target.unsqueeze(-1),
            )
            return probs

        def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff):
            combine_probs = torch.stack([vocab_p, knn_p], dim=0)
            coeffs = torch.ones_like(combine_probs)
            coeffs[0] = np.log(1 - coeff)
            coeffs[1] = np.log(coeff)
            curr_prob = torch.logsumexp(combine_probs + coeffs, dim=0)

            return curr_prob

        orig_target = sample['target']

        # compute scores for each model in the ensemble
        avg_probs = None
        avg_attn = None
        extra = None
        for i_model, model in enumerate(models):
            assert extra is None
            model.eval()
            decoder_out = model(**net_input)
            attn = decoder_out[1]
            if type(attn) is dict:
                attn = attn.get('attn', None)

            batched = batch_for_softmax(decoder_out, orig_target)
            probs, idx = None, 0
            for i, (bd, tgt, is_single) in enumerate(batched):
                sample['target'] = tgt
                curr_prob = model.get_normalized_probs(
                    bd, log_probs=len(models) == 1, sample=sample).data

                if is_single:
                    probs = gather_target_probs(curr_prob, orig_target)
                else:
                    if probs is None:
                        probs = curr_prob.new(orig_target.numel())
                    step = curr_prob.size(0) * curr_prob.size(1)
                    end = step + idx
                    tgt_probs = gather_target_probs(
                        curr_prob.view(tgt.shape + (curr_prob.size(-1), )),
                        tgt)
                    probs[idx:end] = tgt_probs.view(-1)
                    idx = end
                sample['target'] = orig_target

            probs = probs.view(sample['target'].shape)

            extra = {}
            extra['probs'] = probs[orig_target != self.pad].clone()
            extra['src_tokens'] = sample['net_input']['src_tokens'][
                orig_target != self.pad].clone()
            extra['target'] = orig_target[orig_target != self.pad]
            extra['keys'] = decoder_out[1][self.args.knn_keytype].permute(
                1, 0, 2)[orig_target != self.pad]
            #d0, d1 = orig_target.shape
            #extra['src_id'] = sample['id'].view(d0, 1).expand(d0, d1)[orig_target != self.pad].clone()

            if 'knn_dstore' in kwargs:
                dstore = kwargs['knn_dstore']
                # TxBxC
                queries = bd[1][self.args.knn_keytype]
                if len(models) != 1:
                    raise ValueError('Only knn *log* probs are supported.')

                yhat_knn_prob, _extra = dstore.get_knn_log_prob(
                    queries, orig_target.permute(1, 0), pad_idx=self.pad)
                for k, v in _extra.items():
                    extra[k] = v
                yhat_knn_prob = yhat_knn_prob.permute(1, 0, 2).squeeze(-1)
                if self.args.fp16:
                    yhat_knn_prob = yhat_knn_prob.half()
                    probs = probs.half()

                probs = combine_knn_and_vocab_probs(yhat_knn_prob, probs,
                                                    self.args.lmbda)

            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)
            if attn is not None and torch.is_tensor(attn):
                attn = attn.data
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
        if len(models) > 1:
            avg_probs.div_(len(models))
            avg_probs.log_()
            if avg_attn is not None:
                avg_attn.div_(len(models))

        # save_extra = self.collate(save_extra)
        # TODO: This needs to be written to file.

        bsz = avg_probs.size(0)
        hypos = []
        start_idxs = (sample['start_indices']
                      if 'start_indices' in sample else [0] * bsz)
        for i in range(bsz):
            # remove padding from ref
            ref = utils.strip_pad(sample['target'][i, start_idxs[i]:], self.pad) \
                if sample['target'] is not None else None
            tgt_len = ref.numel()
            avg_probs_i = avg_probs[i][start_idxs[i]:start_idxs[i] + tgt_len]
            score_i = avg_probs_i.sum() / tgt_len
            if avg_attn is not None:
                avg_attn_i = avg_attn[i]
                if self.compute_alignment:
                    alignment = utils.extract_hard_alignment(
                        avg_attn_i,
                        sample['net_input']['src_tokens'][i],
                        sample['target'][i],
                        self.pad,
                        self.eos,
                    )
                else:
                    alignment = None
            else:
                avg_attn_i = alignment = None
            hypos.append([{
                'tokens': ref,
                'score': score_i,
                'attention': avg_attn_i,
                'alignment': alignment,
                'positional_scores': avg_probs_i,
                'dstore_keys':
                    decoder_out[1][self.args.knn_keytype][start_idxs[i]:, i, :]
                    if self.args.save_knnlm_dstore else None,
            }])
        return hypos, extra
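
combine_knn_and_vocab_probs mixes the two distributions in probability space while staying in log space for stability: logsumexp over log-probs plus log-coefficients yields log((1 - λ) p_vocab + λ p_knn). A standalone check:

import math
import torch

coeff = 0.25  # λ, the kNN interpolation weight (self.args.lmbda)
vocab_p = torch.log(torch.tensor([0.7, 0.2, 0.1]))
knn_p = torch.log(torch.tensor([0.1, 0.8, 0.1]))

combine_probs = torch.stack([vocab_p, knn_p], dim=0)
coeffs = torch.ones_like(combine_probs)
coeffs[0] = math.log(1 - coeff)
coeffs[1] = math.log(coeff)
mixed = torch.logsumexp(combine_probs + coeffs, dim=0)

print(mixed.exp()[0].item())    # ~0.55
print(0.75 * 0.7 + 0.25 * 0.1)  # 0.55, the same mix in probability space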
Example #11
def main(args):
    utils.import_user_module(args)

    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.batch_size is None:
        args.batch_size = 1

    assert (not args.batch_size or args.batch_size <= args.buffer_size
            ), "--batch-size cannot be larger than --buffer-size"

    logger.info(args)

    # Fix seed for stochastic decoding
    if args.seed is not None and not args.no_seed_provided:
        np.random.seed(args.seed)
        utils.set_torch_seed(args.seed)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble
    logger.info("loading model(s) from {}".format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
        suffix=getattr(args, "checkpoint_suffix", ""),
        strict=(args.checkpoint_shard_count == 1),
        num_shards=args.checkpoint_shard_count,
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        if args.fp16:
            model.half()
        if use_cuda and not args.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(args)

    # Initialize generator
    generator = task.build_generator(models, args)

    def encode_fn(x):
        x = task.spm.encode(x, out_type=str)
        x = segmenter_utils.to_character_sequence("".join(x))
        return x

    def decode_fn(chars, tags):
        x = ""
        for c, t in zip(chars.replace(" ", ""), tags.replace(" ", "")):
            if t == "B":
                x += " "
            x += c
        return x[1:]

    max_positions = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models])

    if args.buffer_size > 1:
        logger.info("Sentence buffer size: %s", args.buffer_size)
    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info("Type the input sentence and press return:")
    start_id = 0
    for inputs in buffered_read(args.input, args.buffer_size):
        results = []
        for batch in make_batches(inputs, args, task, max_positions,
                                  encode_fn):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()

            sample = {
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": src_lengths,
                },
                "src_chars": batch.src_chars,
            }
            segmentations = task.inference_step(generator, models, sample)
            for i, (id, hypos) in enumerate(
                    zip(batch.ids.tolist(), segmentations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                results.append((
                    start_id + id,
                    src_tokens_i,
                    batch.src_chars[i],
                    hypos,
                ))

        # sort output to match input order
        for id_, src_tokens, src_str, hypos in sorted(results,
                                                      key=lambda x: x[0]):
            print("S-{}\t{}".format(id_, src_str))

            # Process top predictions
            for hypo in hypos[:min(len(hypos), args.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=None,
                    align_dict=None,
                    tgt_dict=tgt_dict,
                )
                detok_hypo_str = decode_fn(src_str, hypo_str)
                score = hypo["score"] / math.log(2)  # convert to base 2
                # original hypothesis (after tokenization and BPE)
                print("H-{}\t{}\t{}".format(id_, score, hypo_str))
                # detokenized hypothesis
                print("D-{}\t{}\t{}".format(id_, score, detok_hypo_str))
                print("P-{}\t{}".format(
                    id_,
                    " ".join(
                        map(
                            lambda x: "{:.4f}".format(x),
                            # convert from base e to base 2
                            hypo["positional_scores"].div_(math.log(2)
                                                           ).tolist(),
                        )),
                ))

        # update running id_ counter
        start_id += len(inputs)
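
decode_fn reconstructs word boundaries from a B/I tag sequence: a B tag opens a new word, anything else continues the current one. A quick standalone check:

def decode_fn(chars, tags):
    x = ""
    for c, t in zip(chars.replace(" ", ""), tags.replace(" ", "")):
        if t == "B":
            x += " "  # "B" begins a new word
        x += c
    return x[1:]      # drop the leading separator

print(decode_fn("a b c d e", "B I B I I"))  # -> "ab cde"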
Example #12
def _main(args, output_file):
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.predict')

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.pred_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation TODO
    for model in models:
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.pred_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )


    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x


    model = models[0]  # use the first model only; TODO: support ensembles
    y_true = []
    y_pred = []

    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        for i, sample_id in enumerate(sample['id'].tolist()):

            # Remove padding
            src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
            target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu()

            ### Predict
            _, _, y_true_, y_pred_ = task._predict_with_seqeval(sample, model) # get labels and predictions
            y_true_ = y_true_[0]
            y_pred_ = y_pred_[0]

            y_true.append(y_true_) # append to all labels and predictions
            y_pred.append(y_pred_)
        
            if src_dict is not None:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
            else:
                src_str = ""

            target_str = tgt_dict.string(
                target_tokens,
                args.remove_bpe,
                escape_unk=True,
            )

            src_str = decode_fn(src_str)
            target_str = decode_fn(target_str)
            pred_str = ' '.join(y_pred_)

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str), file=output_file)
                print('P-{}\t{}'.format(sample_id, pred_str), file=output_file)

    clf_report = classification_report(y_true, y_pred)
    logger.info("Eval results on {} subset: \n\n{} ".format(args.pred_subset, clf_report))
Example #13
    def forward(self, model, sample, reduce=True):
        # sample mode
        model.eval()
        # src_dict = self.task.source_dictionary
        tgt_dict = self.task.target_dictionary
        eos_idx = self.task.target_dictionary.eos()
        sample_beam = self.args.sample_beam
        translator = SequenceGenerator(
            [model],
            tgt_dict=tgt_dict,
            sampling=self.args.multinomial_sample_train,
            beam_size=sample_beam,
            minlen=1)
        translator.cuda()
        ct = 0
        translations = []

        s = utils.move_to_cuda(sample)
        input = s['net_input']
        max_len = 200
        with torch.no_grad():
            hypos = translator.generate(
                input['src_tokens'],
                input['src_lengths'],
                beam_size=sample_beam,
                maxlen=max_len,
            )
        for i, id in enumerate(s['id'].data):
            src = input['src_tokens'].data[i, :]
            # remove padding from ref
            ref = utils.strip_pad(
                s['target'].data[i, :],
                tgt_dict.pad()) if s['target'] is not None else None
            translations.append((id, src, ref, hypos[i]))
            ct += 1
        # print("sample batch size:", ct)

        model.train()

        # MLE loss
        mle_net_output = model(**sample['net_input'])
        mle_lprobs = model.get_normalized_probs(mle_net_output, log_probs=True)
        mle_lprobs = mle_lprobs.view(-1, mle_lprobs.size(-1))
        mle_target = model.get_targets(sample, mle_net_output).view(-1)
        mle_loss = F.nll_loss(mle_lprobs,
                              mle_target,
                              size_average=False,
                              ignore_index=self.padding_idx,
                              reduce=reduce)
        mle_tokens = sample['ntokens']
        avg_mle_loss = mle_loss / mle_tokens
        print('avg_mle_loss:', avg_mle_loss)
        # RL loss
        batch_rl_loss = 0
        batch_tokens = 0
        sample_ind = 0
        for sample_id, src_tokens, tgt_tokens, hypos in translations:
            # calculate bleu
            sample_ind += 1
            rewards = torch.Tensor(sample_beam).float().cuda()
            logprobs = torch.Tensor(sample_beam).float().cuda()
            for i in range(sample_beam):
                hypo = hypos[i]
                trans_tokens = hypo['tokens']
                rewards[i] = self.compute_gleu(tgt_tokens.cpu(),
                                               trans_tokens.cpu(),
                                               max_order=self.args.max_order,
                                               gram=self.args.gram).cuda()
                # one_sample loss calculation
                tgt_input_tokens = trans_tokens.new(
                    trans_tokens.shape).fill_(0)
                assert trans_tokens[-1] == eos_idx
                tgt_input_tokens[0] = eos_idx
                tgt_input_tokens[1:] = trans_tokens[:-1]
                train_sample = {
                    'net_input': {
                        'src_tokens': src_tokens.view(1, -1),
                        # src_lengths must hold the length value itself, not
                        # be an uninitialized tensor of that size
                        'src_lengths':
                            torch.LongTensor([src_tokens.numel()]).view(1, -1),
                        'prev_output_tokens': tgt_input_tokens.view(1, -1),
                    },
                    'target': trans_tokens.view(1, -1)
                }
                train_sample = utils.move_to_cuda(train_sample)
                net_output = model(**train_sample['net_input'])
                lprobs = model.get_normalized_probs(net_output, log_probs=True)
                lprobs = lprobs.view(-1, lprobs.size(-1))
                target = model.get_targets(train_sample,
                                           net_output).view(-1, 1)
                non_pad_mask = target.ne(tgt_dict.pad())
                lprob = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
                logprobs[i] = torch.sum(lprob)
                ntokens = train_sample['target'].numel()  # token count, not batch dim
                batch_tokens += ntokens
            rl_loss = torch.sum(logprobs *
                                (rewards - rewards.mean()))  # one sample loss
            batch_rl_loss += rl_loss

        avg_rl_loss = batch_rl_loss / batch_tokens
        print('avg_rl_loss:', avg_rl_loss)
        if self.args.mle_weight:
            assert self.args.rl_weight
            total_loss = self.args.mle_weight * avg_mle_loss + self.args.rl_weight * avg_rl_loss
            total_tokens = batch_tokens + mle_tokens
        else:
            total_loss = avg_rl_loss
            total_tokens = batch_tokens
        logging_output = {
            'loss': utils.item(total_loss.data),
            'ntokens': total_tokens,
            'sample_size': total_tokens,
        }
        print('total: ', total_loss)
        return total_loss, total_tokens, logging_output
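
The RL term above is a REINFORCE-style objective with the batch-mean reward as baseline: hypotheses whose GLEU exceeds the mean get their summed negative log-probabilities weighted up, the rest down. A minimal sketch of just that reduction, with hypothetical toy values in place of model outputs:

import torch

# toy beam of 4 sampled hypotheses (values are made up)
rewards = torch.tensor([0.31, 0.55, 0.12, 0.47])    # sentence-level GLEU per hypothesis
logprobs = torch.tensor([12.3, 10.8, 14.1, 11.5])   # summed token NLL per hypothesis

# subtracting the mean reward is a baseline: it leaves the expected
# gradient unchanged but reduces its variance
advantages = rewards - rewards.mean()
rl_loss = torch.sum(logprobs * advantages)
print(rl_loss.item())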
Example #14
    def generate(self, models, sample, **kwargs):
        """Score a batch of translations."""
        net_input = sample['net_input']

        def batch_for_softmax(dec_out, target):
            # assumes decoder_out[0] is the only thing needed (may not be correct for future models!)
            first, rest = dec_out[0], dec_out[1:]
            bsz, tsz, dim = first.shape
            if bsz * tsz < self.softmax_batch:
                yield dec_out, target, True
            else:
                flat = first.contiguous().view(1, -1, dim)
                flat_tgt = target.contiguous().view(flat.shape[:-1])
                s = 0
                while s < flat.size(1):
                    e = s + self.softmax_batch
                    yield (flat[:, s:e],) + rest, flat_tgt[:, s:e], False
                    s = e

        def gather_target_probs(probs, target):
            probs = probs.gather(
                dim=2,
                index=target.unsqueeze(-1),
            )
            return probs

        orig_target = sample['target']

        # compute scores for each model in the ensemble
        avg_probs = None
        avg_attn = None
        for model in models:
            model.eval()
            decoder_out = model.forward(**net_input)
            attn = decoder_out[1]
            if type(attn) is dict:
                attn = attn.get('attn', None)

            batched = batch_for_softmax(decoder_out, orig_target)
            probs, idx = None, 0
            for bd, tgt, is_single in batched:
                sample['target'] = tgt

                # divide the logits by temperature prior to softmax
                # for example, see https://github.com/pytorch/fairseq/blob/master/fairseq/sequence_generator.py:
                #   decoder_out[0][:, -1:, :].div_(temperature)
                bd[0].div_(self.temperature)

                curr_prob = model.get_normalized_probs(bd, log_probs=len(models) == 1, sample=sample).data
                if is_single:
                    probs = gather_target_probs(curr_prob, orig_target)
                else:
                    if probs is None:
                        probs = curr_prob.new(orig_target.numel())
                    step = curr_prob.size(0) * curr_prob.size(1)
                    end = step + idx
                    tgt_probs = gather_target_probs(curr_prob.view(tgt.shape + (curr_prob.size(-1),)), tgt)
                    probs[idx:end] = tgt_probs.view(-1)
                    idx = end
                sample['target'] = orig_target

            probs = probs.view(sample['target'].shape)

            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)
            if attn is not None and torch.is_tensor(attn):
                attn = attn.data
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
        if len(models) > 1:
            avg_probs.div_(len(models))
            avg_probs.log_()
            if avg_attn is not None:
                avg_attn.div_(len(models))

        bsz = avg_probs.size(0)
        hypos = []
        start_idxs = sample['start_indices'] if 'start_indices' in sample else [0] * bsz
        for i in range(bsz):
            # remove padding from ref
            ref = utils.strip_pad(sample['target'][i, start_idxs[i]:], self.pad) \
                if sample['target'] is not None else None
            tgt_len = ref.numel()
            avg_probs_i = avg_probs[i][start_idxs[i]:start_idxs[i] + tgt_len]
            score_i = avg_probs_i.sum() / tgt_len
            if avg_attn is not None:
                avg_attn_i = avg_attn[i]
                alignment = utils.extract_hard_alignment(avg_attn_i, sample['net_input']['src_tokens'][i],
                                                         sample['target'][i], self.pad, self.eos)
            else:
                avg_attn_i = alignment = None
            hypos.append([{
                'tokens': ref,
                'score': score_i,
                'attention': avg_attn_i,
                'alignment': alignment,
                'positional_scores': avg_probs_i,
            }])
        return hypos
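
batch_for_softmax above exists so the vocabulary-sized softmax is applied only softmax_batch positions at a time: the (bsz, tsz, dim) decoder output is flattened along time and sliced. A standalone sketch of the same chunking with toy sizes (all numbers hypothetical):

import torch

bsz, tsz, dim, softmax_batch = 2, 6, 4, 5
first = torch.randn(bsz, tsz, dim)          # stand-in for decoder_out[0]
target = torch.randint(0, dim, (bsz, tsz))

flat = first.contiguous().view(1, -1, dim)  # (1, bsz*tsz, dim)
flat_tgt = target.contiguous().view(1, -1)  # (1, bsz*tsz)
s = 0
while s < flat.size(1):
    e = s + softmax_batch
    chunk, chunk_tgt = flat[:, s:e], flat_tgt[:, s:e]
    # ... normalize and gather target probabilities per chunk here ...
    s = e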
Example #15
    def _detokenize(self, sample, hypos):
        """ detokenize and compute BELU """
        message_list = []
        for i, sample_id in enumerate(sample['id'].tolist()):
            has_target = sample['target'] is not None

            # Remove padding
            src_tokens = utils.strip_pad(
                sample['net_input']['src_tokens'][i, :], self.tgt_dict.pad())
            target_tokens = None
            if has_target:
                target_tokens = utils.strip_pad(sample['target'][i, :],
                                                self.tgt_dict.pad()).int()

            # Either retrieve the original sentences or regenerate them from tokens.
            if self.align_dict is not None:
                src_str = self.task.dataset(
                    self.args.gen_subset).src.get_original_text(sample_id)
                target_str = self.task.dataset(
                    self.args.gen_subset).tgt.get_original_text(sample_id)
            else:
                if self.src_dict is not None:
                    src_str = self.src_dict.string(src_tokens,
                                                   self.args.remove_bpe)
                else:
                    src_str = ""
                if has_target:
                    target_str = self.tgt_dict.string(target_tokens,
                                                      self.args.remove_bpe,
                                                      escape_unk=True)

            if not self.args.quiet:
                if self.src_dict is not None:
                    if self.args.decode_hypothesis:
                        message_list.append('S-{}\t{}'.format(
                            sample_id, self._decode(src_str)))
                    else:
                        message_list.append('S-{}\t{}'.format(
                            sample_id, src_str))
                if has_target:
                    if self.args.decode_hypothesis:
                        message_list.append('T-{}\t{}'.format(
                            sample_id, self._decode(target_str)))
                    else:
                        message_list.append('T-{}\t{}'.format(
                            sample_id, target_str))

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:self.args.nbest]):
                hypo_tokens, hypo_str, alignment = \
                    utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=self.align_dict,
                        tgt_dict=self.tgt_dict,
                        remove_bpe=self.args.remove_bpe,
                    )

                if not self.args.quiet:
                    if self.args.decode_hypothesis:
                        detok_hypo_str = self._decode(hypo_str)
                        message_list.append('D-{}\t{}\t{}'.format(
                            sample_id, hypo['score'], detok_hypo_str))
                    else:
                        message_list.append('H-{}\t{}\t{}'.format(
                            sample_id, hypo['score'], hypo_str))
                    message_list.append('P-{}\t{}'.format(
                        sample_id, ' '.join(
                            map(
                                lambda x: '{:.4f}'.format(x),
                                hypo['positional_scores'].tolist(),
                            ))))

                    if self.args.print_alignment:
                        message_list.append('A-{}\t{}'.format(
                            sample_id, ' '.join([
                                '{}-{}'.format(src_idx, tgt_idx)
                                for src_idx, tgt_idx in alignment
                            ])))

                    if self.args.print_step:
                        message_list.append('I-{}\t{}'.format(
                            sample_id, hypo['steps']))

                    if getattr(self.args, 'retain_iter_history', False):
                        message_list.append("\n".join([
                            'E-{}_{}\t{}'.format(
                                sample_id, step,
                                utils.post_process_prediction(
                                    h['tokens'].int(), src_str, None,
                                    None, self.tgt_dict, None)[1])
                            for step, h in enumerate(hypo['history'])
                        ]))

                # Score only the top hypothesis
                if has_target and j == 0:
                    if (self.align_dict is not None
                            or self.args.remove_bpe is not None):
                        # Convert back to tokens for evaluation with unk
                        # replacement and/or without BPE
                        target_tokens = self.tgt_dict.encode_line(
                            target_str, add_if_not_exist=True)
                    if hasattr(self.scorer, 'add_string'):
                        self.message_queue.put((target_str, hypo_str))
                    else:
                        self.message_queue.put((target_tokens, hypo_tokens))

        self.message_queue.put('\n'.join(message_list))
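
Every example on this page leans on utils.strip_pad to drop padding before scoring or detokenizing. Its behavior can be reproduced with a one-line boolean mask; this is a sketch consistent with how the function is used here, not necessarily the exact library source:

import torch

def strip_pad(tensor, pad):
    # keep only positions whose value differs from the pad index
    return tensor[tensor.ne(pad)]

pad = 1
tokens = torch.tensor([5, 9, 23, 2, 1, 1, 1])  # sentence followed by padding
print(strip_pad(tokens, pad))                  # tensor([ 5,  9, 23,  2])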
Example #16
def main(args):
    utils.import_user_module(args)

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    logger.info(args)
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    for arg in vars(_model_args).keys():
        # if arg not in {
        #     'self_target', 'future_target', 'past_target', 'tokens_per_sample',
        #     'output_size_dictionary', 'add_bos_token',
        # }:
        if arg in {'decoder_embed_dim', 'vocab_size'}:
            setattr(args, arg, getattr(_model_args, arg))

    logger.info(args)

    if args.knnlm:
        knn_dstore = KNN_Dstore(args)
        print("Loading training tokens...")
        with open(args.input_tokens_file) as infile:
            train_tokens = infile.read().split()

        print("TODO, REMOVE ME\n\n\n !!!!Skipping first training tokens...")
        train_tokens = train_tokens[3072:]  # TODO, remove this

    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Initialize generator
    generator = task.build_generator(args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def encode_fn(x):
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models])

    if args.buffer_size > 1:
        logger.info('Sentence buffer size: %s', args.buffer_size)
    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info('Type the input sentence and press return:')
    start_id = 0
    for inputs in buffered_read(args.input, args.buffer_size):
        results = []
        for batch in make_batches(inputs, args, task, max_positions,
                                  encode_fn):
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()

            sample = {
                'net_input': {
                    'src_tokens': src_tokens,
                    'src_lengths': src_lengths,
                },
            }
            if args.knnlm:
                translations = task.inference_step(generator,
                                                   models,
                                                   sample,
                                                   None,
                                                   knn_dstore=knn_dstore)
            else:
                translations = task.inference_step(generator, models, sample,
                                                   None)
            for i, (id,
                    hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                results.append((start_id + id, src_tokens_i, hypos))

        # sort output to match input order
        for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
            if src_dict is not None:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                print('S-{}\t{}'.format(id, src_str))

            # Process top predictions
            for hypo in hypos[:min(len(hypos), args.nbest)]:
                assert hypo['dists_full'] is not None
                dists_full = hypo['dists_full'].float()
                knns_full = hypo['knns_full']
                word_tokens = [
                    task.target_dictionary[token] for token in hypo['tokens']
                ]
                #assert len(yhat_scores.tolist()) == len(word_tokens)

                # TODO, trim off padding when its batched

                context_size = 20
                num_neighbors = 10
                print("Example:", " ".join(word_tokens))
                print(dists_full)
                best_dist_indices = np.argsort(
                    dists_full)[-num_neighbors:][::-1]
                for j, neighbor_index in enumerate(best_dist_indices):
                    distance = dists_full[neighbor_index]
                    knn_index = knns_full[neighbor_index]
                    print(
                        "Best neighbor {} (distance {:.2f}):".format(
                            j, distance),
                        " ".join(train_tokens[knn_index -
                                              context_size:knn_index]), "[[",
                        train_tokens[knn_index], "]]")

                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )
                hypo_str = decode_fn(hypo_str)
                score = hypo['score'] / math.log(2)  # convert to base 2
                print('H-{}\t{}\t{}'.format(id, score, hypo_str))
                print('P-{}\t{}'.format(
                    id,
                    ' '.join(
                        map(
                            lambda x: '{:.4f}'.format(x),
                            # convert from base e to base 2
                            hypo['positional_scores'].div_(math.log(2)
                                                           ).tolist(),
                        ))))
                if args.print_alignment:
                    alignment_str = " ".join(
                        ["{}-{}".format(src, tgt) for src, tgt in alignment])
                    print('A-{}\t{}'.format(id, alignment_str))

        # update running id counter
        start_id += len(inputs)
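
Model scores arrive in nats; the division by math.log(2) above converts them to bits, which is what the "scores are output in base 2" note refers to. A quick check of the identity with a hypothetical score:

import math

nats = -3.2                # hypothetical log-probability in base e
bits = nats / math.log(2)  # the same quantity in base 2
# log2(x) = ln(x) / ln(2), so exponentiating either form recovers the same probability
assert abs(math.exp(nats) - 2 ** bits) < 1e-12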
Example #17
    def get_costs(self, sample, scorer=None):
        """Get costs for hypotheses using the specified *scorer*."""
        if scorer is None:
            scorer = self.get_sequence_scorer(self.args.seq_scorer)

        bsz = len(sample['hypos'])
        nhypos = len(sample['hypos'][0])
        target = sample['target'].int()
        source = sample['net_input']['src_tokens'].int()

        pad_idx = self.target_dictionary.pad()
        assert pad_idx == self.source_dictionary.pad()

        costs = torch.zeros(bsz, nhypos).to(sample['target'].device)

        if self.args.seq_scorer == "simile":
            for i, hypos_i in enumerate(sample['hypos']):
                ref = utils.strip_pad(target[i, :], pad_idx).cpu()
                ref = scorer.preprocess_ref(ref)
                ref_len = len(ref.split())
                hypos = []
                hypo_lens = []

                for j, hypo in enumerate(hypos_i):
                    hyp = scorer.preprocess_hypo(hypo)
                    hypos.append(hyp)
                    hypo_lens.append(len(hyp.split()))

                _costs = scorer.get_costs(ref, hypos)

                for j, _ in enumerate(hypos_i):
                    lp = np.exp(1 - max(ref_len, hypo_lens[j]) /
                                float(min(ref_len, hypo_lens[j])))
                    costs[i, j] = 1 - lp**self.args.simile_lenpen * _costs[
                        j].item()
        elif self.args.seq_scorer == "cl-simile":
            for i, hypos_i in enumerate(sample['hypos']):
                ref = utils.strip_pad(target[i, :], pad_idx).cpu()
                ref = scorer.preprocess_ref(ref)
                src = utils.strip_pad(source[i, :], pad_idx).cpu()
                src = scorer.preprocess_src(src)

                ref_len = len(ref.split())
                hypos = []
                hypo_lens = []

                for j, hypo in enumerate(hypos_i):
                    hyp = scorer.preprocess_hypo(hypo)
                    hypos.append(hyp)
                    hypo_lens.append(len(hyp.split()))

                _costs = scorer.get_costs(ref, hypos, src)

                for j, _ in enumerate(hypos_i):
                    lp = np.exp(1 - max(ref_len, hypo_lens[j]) /
                                float(min(ref_len, hypo_lens[j])))
                    costs[i, j] = 1 - lp**self.args.simile_lenpen * _costs[
                        j].item()
        else:
            for i, hypos_i in enumerate(sample['hypos']):
                ref = utils.strip_pad(target[i, :], pad_idx).cpu()
                ref = scorer.preprocess_ref(ref)
                for j, hypo in enumerate(hypos_i):
                    costs[i, j] = scorer.get_cost(ref,
                                                  scorer.preprocess_hypo(hypo))
        return scorer.postprocess_costs(costs)
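
The SimiLe-style cost above pairs a similarity score with the length penalty lp = exp(1 - max(len_r, len_h) / min(len_r, len_h)), which is 1 for equal lengths and decays as they diverge; the hypothesis cost is then 1 - lp**simile_lenpen * sim. A worked toy example (the lenpen value is hypothetical; the task reads it from args.simile_lenpen):

import numpy as np

def simile_cost(sim, ref_len, hyp_len, lenpen=0.25):
    lp = np.exp(1 - max(ref_len, hyp_len) / float(min(ref_len, hyp_len)))
    return 1 - lp ** lenpen * sim

print(simile_cost(0.9, ref_len=10, hyp_len=10))  # lp == 1, cost == 0.1
print(simile_cost(0.9, ref_len=10, hyp_len=20))  # mismatch raises the cost to ~0.30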
Example #18
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    args.beam = 1
    utils.import_user_module(args)

    if args.max_tokens is None:
        args.max_tokens = 12000
    print(args)
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    num_sentences = 0
    source_sentences = []
    shard_id = 0
    all_avg_pool = None
    encoder_has_langtok = (
        hasattr(task.args, 'encoder_langtok')
        and task.args.encoder_langtok is not None
        and hasattr(task.args, 'lang_tok_replacing_bos_eos')
        and not task.args.lang_tok_replacing_bos_eos
    )
    with progress_bar.build_progress_bar(args, itr) as t:
        for sample in t:
            if sample is None:
                print("Skipping None")
                continue
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            with torch.no_grad():
                avg_pool = get_avg_pool(
                        models, sample, prefix_tokens, src_dict,
                        args.remove_bpe,
                        has_langtok=encoder_has_langtok)
                if all_avg_pool is not None:
                    all_avg_pool = np.concatenate((all_avg_pool, avg_pool))
                else:
                    all_avg_pool = avg_pool

            if not isinstance(sample['id'], list):
                sample_ids = sample['id'].tolist()
            else:
                sample_ids = sample['id']
            for i, sample_id in enumerate(sample_ids):
                # Remove padding
                src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str))

                source_sentences.append(f"{sample_id}\t{src_str}")

            num_sentences += sample['nsentences']
            if all_avg_pool.shape[0] >= 1000000:
                with open(f'{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}',
                        'wb') as avg_pool_file:
                    all_avg_pool.tofile(avg_pool_file)
                with open(f'{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}', 'w') as sentence_file:
                    sentence_file.writelines(f'{line}\n' for line in source_sentences)
                all_avg_pool = None
                source_sentences = []
                shard_id += 1

    if all_avg_pool is not None:
        with open(f'{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}',
                'wb') as avg_pool_file:
            all_avg_pool.tofile(avg_pool_file)
        with open(f'{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}', 'w') as sentence_file:
            sentence_file.writelines(f'{line}\n' for line in source_sentences)
    return None
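
ndarray.tofile writes raw bytes with no header, so reading the pooled-encoder shards back requires remembering the dtype and the embedding width. A round-trip sketch (file name and width are placeholders):

import numpy as np

dim = 1024  # embedding width; must match what the encoder produced
pool = np.random.rand(3, dim).astype(np.float32)

with open('all_avg_pool.example.0', 'wb') as f:  # binary mode for raw dumps
    pool.tofile(f)

restored = np.fromfile('all_avg_pool.example.0', dtype=np.float32).reshape(-1, dim)
assert np.array_equal(pool, restored)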
Example #19
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample; periodically print out
        randomly sampled predictions if model is in training mode, otherwise
        aggregate word error stats for validation.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        dictionary = self.scorer.dictionary
        if model.training:
            net_output = model(**sample['net_input'], epoch=self.epoch)
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            target = model.get_targets(sample, net_output)
            if (
                self.num_updates // self.args.print_interval >
                (self.num_updates - 1) // self.args.print_interval
            ):  # print a randomly sampled result every print_interval updates
                pred = lprobs.argmax(-1).cpu()  # bsz x len
                assert pred.size() == target.size()
                with data_utils.numpy_seed(self.num_updates):
                    i = np.random.randint(0, len(sample['id']))
                ref_tokens = sample['target_raw_text'][i]
                length = utils.strip_pad(target.data[i], self.padding_idx).size(0)
                ref_one = dictionary.tokens_to_sentence(
                    ref_tokens, use_unk_sym=False, bpe_symbol=self.args.remove_bpe,
                )
                pred_one = dictionary.tokens_to_sentence(
                    dictionary.string(pred.data[i][:length]), use_unk_sym=True,
                    bpe_symbol=self.args.remove_bpe,
                )
                print('| sample REF: ' + ref_one)
                print('| sample PRD: ' + pred_one)
        else:
            tokens, lprobs, _ = self.decoder_for_validation.decode([model], sample)
            pred = tokens[:, 1:].data.cpu()  # bsz x len
            target = sample['target']
            # compute word error stats
            assert pred.size(0) == target.size(0)
            self.scorer.reset()
            for i in range(target.size(0)):
                utt_id = sample['utt_id'][i]
                ref_tokens = sample['target_raw_text'][i]
                pred_tokens = dictionary.string(pred.data[i])
                self.scorer.add_evaluation(
                    utt_id, ref_tokens, pred_tokens, bpe_symbol=self.args.remove_bpe,
                )

        lprobs = lprobs.view(-1, lprobs.size(-1))
        loss = F.nll_loss(
            lprobs,
            target.view(-1),
            ignore_index=self.padding_idx,
            reduction='sum' if reduce else 'none',
        )
        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }
        if not model.training:  # do not compute word error in training mode
            logging_output['word_error'] = self.scorer.tot_word_error()
            logging_output['word_count'] = self.scorer.tot_word_count()
            logging_output['char_error'] = self.scorer.tot_char_error()
            logging_output['char_count'] = self.scorer.tot_char_count()
        return loss, sample_size, logging_output
Example #20
def eval_single_token_prediction(model,
                                 itr,
                                 dictionary,
                                 singletoken_topp=0.0,
                                 singletoken_topk=1):
    predicted_tokens = []
    target_tokens = []

    mle_loss_sum = 0
    num_samples_sum = 0
    wrong_mass_sum = 0

    logging_outputs = []

    for n, sample in tqdm(enumerate(itr)):
        sample = utils.move_to_cuda(sample)
        net_output = model(**sample['net_input'])
        logits = net_output[0][0]
        logits[:, dictionary.pad()] = -1e19
        predicted_tokens.append(logits.argmax(1).tolist())
        target = sample['target'].view(-1)
        target_tokens.append(target.tolist())

        # -- mle loss
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        true_token_lprobs = F.nll_loss(
            lprobs,
            target,
            ignore_index=dictionary.pad_index,
            reduction='none',
        )
        true_token_logits = -F.nll_loss(
            logits,
            target,
            ignore_index=dictionary.pad_index,
            reduction='none',
        )
        mle_loss = true_token_lprobs.sum()
        orig = utils.strip_pad(target, dictionary.pad_index)
        ntokens = orig.numel()

        mle_loss_sum += mle_loss.item()
        num_samples_sum += ntokens

        logging_output = TrainingMetrics.ranking_metrics(logits,
                                                         true_token_logits,
                                                         sample,
                                                         ntokens,
                                                         target,
                                                         topk=singletoken_topk,
                                                         topp=singletoken_topp)

        negative_targets = (logits > true_token_logits[:, None]).float()
        wrong_mass_sum += (negative_targets * (F.softmax(logits, dim=1))).sum()

        logging_outputs.append(logging_output)

    ppl = math.pow(2, mle_loss_sum / num_samples_sum / math.log(2))
    custom_metrics = TrainingMetrics.aggregate_and_normalize(logging_outputs)
    custom_metrics['ppl'] = ppl
    avg_wrong_mass = wrong_mass_sum / num_samples_sum
    custom_metrics['avg_wrong_mass'] = avg_wrong_mass.item()
    return predicted_tokens, target_tokens, custom_metrics
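
The perplexity line above, 2 ** (avg_nll / ln 2), is just exp of the average negative log-likelihood routed through base 2; both forms agree, as a hypothetical value shows:

import math

avg_nll = 3.7  # hypothetical average NLL in nats
assert abs(math.pow(2, avg_nll / math.log(2)) - math.exp(avg_nll)) < 1e-9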
Example #21
def main(args):
    utils.import_user_module(args)

    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Initialize generator
    generator = task.build_generator(models, args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def encode_fn(x):
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )

    if args.buffer_size > 1:
        logger.info('Sentence buffer size: %s', args.buffer_size)
    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info('Type the input sentence and press return:')
    start_id = 0
    for inputs in buffered_read(args.input, args.buffer_size):
        results = []
        for batch in make_batches(inputs, args, task, max_positions, encode_fn):
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()

            sample = {
                'net_input': {
                    'src_tokens': src_tokens,
                    'src_lengths': src_lengths,
                },
            }
            translations = task.inference_step(generator, models, sample)
            for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                results.append((start_id + id, src_tokens_i, hypos))

        # sort output to match input order
        for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
            if src_dict is not None:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                print('S-{}\t{}'.format(id, src_str))

            # Process top predictions
            for hypo in hypos[:min(len(hypos), args.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )
                detok_hypo_str = decode_fn(hypo_str)
                score = hypo['score'] / math.log(2)  # convert to base 2
                # original hypothesis (after tokenization and BPE)
                print('H-{}\t{}\t{}'.format(id, score, hypo_str))
                # detokenized hypothesis
                print('D-{}\t{}\t{}'.format(id, score, detok_hypo_str))
                print('P-{}\t{}'.format(
                    id,
                    ' '.join(map(
                        lambda x: '{:.4f}'.format(x),
                        # convert from base e to base 2
                        hypo['positional_scores'].div_(math.log(2)).tolist(),
                    ))
                ))
                if args.print_alignment:
                    alignment_str = " ".join(["{}-{}".format(src, tgt) for src, tgt in alignment])
                    print('A-{}\t{}'.format(
                        id,
                        alignment_str
                    ))

        # update running id counter
        start_id += len(inputs)
Example #22
    def generate_cand_translations(self, sample):
        def decode_fn(x):
            return (x + ' ').replace('@@ ', '').rstrip()

        #print ("id: ",sample['id'][0].item())
        #hypos hold all the candidates for one batch, [batch_size,beamsize]
        hypos = self.task.inference_step(self.generator, [self.model], sample)
        #print ("---before{}----".format(sample['cand']))

        pad_idx = 1
        eos_idx = 2
        left_pad_target = False

        tgt_dict = self.task.target_dictionary
        src_dict = self.task.source_dictionary

        src_str_for_bert = []
        cand_str_for_bert = []
        cand_tokens = []

        for i, sample_id in enumerate(sample['id'].tolist()):
            random_list = list(range(self.args.beam))
            random.shuffle(random_list)

            src_tokens = utils.strip_pad(
                sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
            src_str = src_dict.string(src_tokens, None)
            src_str = decode_fn(src_str)
            src_splittokens = src_str.split(' ')
            src_str_for_bert.append(self.mose_de.detokenize(src_splittokens))

            # Remove padding
            has_target = sample['target'] is not None
            target_tokens = None
            if has_target:
                target_tokens = utils.strip_pad(sample['target'][i, :],
                                                tgt_dict.pad()).int().cpu()

                target_str = tgt_dict.string(
                    target_tokens,
                    None,
                    escape_unk=True,
                    #extra_symbols_to_ignore={tgt_dict.eos()},
                )

                target_str = decode_fn(target_str)

            # Process random {cands_into_model} predictions
            one_sent_bleu = []
            detoken_cands = []  #this is for bert
            candidates = []
            for j in range(self.args.cands_into_model):
                hypo = hypos[i][random_list[j]]
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=None,
                    tgt_dict=tgt_dict,
                    remove_bpe=None,
                    #extra_symbols_to_ignore={tgt_dict.eos()},
                )
                hypo_str = decode_fn(hypo_str)
                candidates.append(hypo['tokens'].cpu())
                #print (hypo_str)

                #input for bert
                hypo_splittokens = hypo_str.split(' ')
                detoken_hypo = self.mose_en.detokenize(hypo_splittokens)
                detoken_cands.append(detoken_hypo)
                #detoken_cands.insert(i+j*(i+1),detoken_hypo)

                smooth_bleuscore = sacrebleu.corpus_bleu(
                    [hypo_str], [[target_str]],
                    use_effective_order=True,
                    tokenize="none")
                #cands_bleu.append(smooth_bleuscore.score)
                one_sent_bleu.append(smooth_bleuscore.score)

            max_cand = detoken_cands[one_sent_bleu.index(max(one_sent_bleu))]
            cand_str_for_bert.append(max_cand)
            #max_cand_token = candidates[one_sent_bleu.index(max(one_sent_bleu))]
            #cand_tokens.append(max_cand_token)

        #temp_tgt.extend(candidates)

        #padding
        with torch.no_grad():
            '''
            cands = data_utils.collate_tokens(
                    cand_tokens,
                    pad_idx,
                    eos_idx,
                    left_pad = left_pad_target,
                    move_eos_to_beginning=False,)
            sample ['cands'] = cands.cuda()
            '''
            src_bert_ids, src_bert = get_bert_out(src_str_for_bert,
                                                  self.bert_de_token,
                                                  self.bert_de_model)
            cand_bert_ids, cand_bert = get_bert_out(cand_str_for_bert,
                                                    self.bert_tokenizer,
                                                    self.bert_model)
            sample['cand_bert_encoded'] = cand_bert
            sample['src_bert_encoded'] = src_bert
            sample['cands'] = cand_bert_ids
            #print ("cands shape",sample['cands'].shape)
        del hypos, one_sent_bleu, candidates, random_list, cand_bert, src_bert, cand_str_for_bert, \
                src_str_for_bert
        #print (sample)
        return sample
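
Candidate selection above keys on sentence-level BLEU, where use_effective_order=True keeps short hypotheses from scoring zero whenever some high n-gram order has no match. A minimal standalone call with toy strings:

import sacrebleu

hyp = "the cat sat on the mat"
ref = "the cat sat on a mat"
score = sacrebleu.corpus_bleu(
    [hyp], [[ref]],
    use_effective_order=True,  # fall back to the n-gram orders that matched
    tokenize="none",           # inputs are treated as already tokenized
).score
print(score)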
Example #23
def generate(cfg: UnsupGenerateConfig, models, saved_cfg, use_cuda):
    task = tasks.setup_task(cfg.fairseq.task)
    saved_cfg.task.labels = cfg.fairseq.task.labels
    task.load_dataset(cfg.fairseq.dataset.gen_subset, task_cfg=saved_cfg.task)
    # Set dictionary
    tgt_dict = task.target_dictionary
    logger.info(
        "| {} {} {} examples".format(
            cfg.fairseq.task.data,
            cfg.fairseq.dataset.gen_subset,
            len(task.dataset(cfg.fairseq.dataset.gen_subset)),
        )
    )
    # Load dataset (possibly sharded)
    itr = get_dataset_itr(cfg, task)
    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(cfg: UnsupGenerateConfig):
        w2l_decoder = cfg.w2l_decoder
        if w2l_decoder == DecoderType.VITERBI:
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(cfg, task.target_dictionary)
        elif w2l_decoder == DecoderType.KENLM:
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(cfg, task.target_dictionary)
        elif w2l_decoder == DecoderType.FAIRSEQ:
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(cfg, task.target_dictionary)
        elif w2l_decoder == DecoderType.KALDI:
            from examples.speech_recognition.kaldi.kaldi_decoder import KaldiDecoder

            assert cfg.kaldi_decoder_config is not None

            return KaldiDecoder(
                cfg.kaldi_decoder_config,
                cfg.beam,
            )
        else:
            raise NotImplementedError(
                "only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment but found "
                + str(w2l_decoder)
            )

    generator = build_generator(cfg)

    kenlm = None
    fairseq_lm = None
    if cfg.lm_model is not None:
        import kenlm

        kenlm = kenlm.Model(cfg.lm_model)

    num_sentences = 0
    if cfg.results_path is not None and not os.path.exists(cfg.results_path):
        os.makedirs(cfg.results_path)

    res_files = prepare_result_files(cfg)
    errs_t = 0
    lengths_hyp_t = 0
    lengths_hyp_unit_t = 0
    lengths_t = 0
    count = 0
    num_feats = 0
    all_hyp_pieces = []
    all_hyp_words = []

    num_symbols = (
        len([s for s in tgt_dict.symbols if not s.startswith("madeup")])
        - tgt_dict.nspecial
    )
    targets = None
    if cfg.targets is not None:
        tgt_path = os.path.join(
            cfg.fairseq.task.data, cfg.fairseq.dataset.gen_subset + "." + cfg.targets
        )
        if os.path.exists(tgt_path):
            with open(tgt_path, "r") as f:
                targets = f.read().splitlines()
    viterbi_transcript = None
    if cfg.viterbi_transcript is not None and len(cfg.viterbi_transcript) > 0:
        logger.info(f"loading viterbi transcript from {cfg.viterbi_transcript}")
        with open(cfg.viterbi_transcript, "r") as vf:
            viterbi_transcript = vf.readlines()
            viterbi_transcript = [v.rstrip().split() for v in viterbi_transcript]

    gen_timer.start()

    start = 0
    end = len(itr)

    hypo_futures = None
    if cfg.w2l_decoder == DecoderType.KALDI:
        logger.info("Extracting features")
        hypo_futures = []
        samples = []
        with progress_bar.build_progress_bar(cfg.fairseq.common, itr) as t:
            for i, sample in enumerate(t):
                if "net_input" not in sample or i < start or i >= end:
                    continue
                if "padding_mask" not in sample["net_input"]:
                    sample["net_input"]["padding_mask"] = None

                hypos, num_feats = gen_hypos(
                    generator, models, num_feats, sample, task, use_cuda
                )
                hypo_futures.append(hypos)
                samples.append(sample)
        itr = list(zip(hypo_futures, samples))
        start = 0
        end = len(itr)
        logger.info("Finished extracting features")

    with progress_bar.build_progress_bar(cfg.fairseq.common, itr) as t:
        for i, sample in enumerate(t):
            if i < start or i >= end:
                continue

            if hypo_futures is not None:
                hypos, sample = sample
                hypos = [h.result() for h in hypos]
            else:
                if "net_input" not in sample:
                    continue

                hypos, num_feats = gen_hypos(
                    generator, models, num_feats, sample, task, use_cuda
                )

            for i, sample_id in enumerate(sample["id"].tolist()):
                if targets is not None:
                    target_tokens = targets[sample_id]
                elif "target" in sample or "target_label" in sample:
                    toks = (
                        sample["target"][i, :]
                        if "target_label" not in sample
                        else sample["target_label"][i, :]
                    )

                    target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
                else:
                    target_tokens = None

                # Process top predictions
                (
                    errs,
                    length_hyp,
                    length,
                    hyp_pieces,
                    hyp_words,
                ) = process_predictions(
                    cfg,
                    hypos[i],
                    tgt_dict,
                    target_tokens,
                    res_files,
                )
                errs_t += errs
                lengths_hyp_t += length_hyp
                lengths_hyp_unit_t += (
                    len(hyp_pieces) if len(hyp_pieces) > 0 else len(hyp_words)
                )
                lengths_t += length
                count += 1
                all_hyp_pieces.append(hyp_pieces)
                all_hyp_words.append(hyp_words)

            num_sentences += (
                sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
            )

    lm_score_sum = 0
    if kenlm is not None:

        if cfg.unit_lm:
            lm_score_sum = sum(kenlm.score(w) for w in all_hyp_pieces)
        else:
            lm_score_sum = sum(kenlm.score(w) for w in all_hyp_words)
    elif fairseq_lm is not None:
        lm_score_sum = sum(fairseq_lm.score([h.split() for h in all_hyp_words])[0])

    vt_err_t = 0
    vt_length_t = 0
    if viterbi_transcript is not None:
        unit_hyps = []
        if cfg.targets is not None and cfg.lexicon is not None:
            lex = {}
            with open(cfg.lexicon, "r") as lf:
                for line in lf:
                    items = line.rstrip().split()
                    lex[items[0]] = items[1:]
            for h in all_hyp_pieces:
                hyp_ws = []
                for w in h.split():
                    assert w in lex, w
                    hyp_ws.extend(lex[w])
                unit_hyps.append(hyp_ws)

        else:
            unit_hyps.extend([h.split() for h in all_hyp_words])

        vt_err_t = sum(
            editdistance.eval(vt, h) for vt, h in zip(viterbi_transcript, unit_hyps)
        )

        vt_length_t = sum(len(h) for h in viterbi_transcript)

    if res_files is not None:
        for r in res_files.values():
            r.close()

    gen_timer.stop(lengths_hyp_t)

    return GenResult(
        count,
        errs_t,
        gen_timer,
        lengths_hyp_unit_t,
        lengths_hyp_t,
        lengths_t,
        lm_score_sum,
        num_feats,
        num_sentences,
        num_symbols,
        vt_err_t,
        vt_length_t,
    )
Example #24
def main(args, task=None, model_state=None):
    check_args(args)

    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    logger.info("| decoding with criterion {}".format(args.criterion))

    #import cloudpickle as pickle
    #pickle.dump(args, open('ARGS.bin', 'wb'))
    #args = pickle.load(open('ARGS.bin', 'rb'))
    #model_state = None

    task = tasks.setup_task(args)

    # Load ensemble
    if args.load_emissions:
        models, criterions = [], []
        task.load_dataset(args.gen_subset)
    else:
        logger.info("| loading model(s) from {}".format(args.path))
        #models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        models, saved_cfg = checkpoint_utils.load_model_ensemble(
            utils.split_paths(args.path),
            arg_overrides=ast.literal_eval(args.model_overrides),
            task=task,
            suffix=args.checkpoint_suffix,
            strict=(args.checkpoint_shard_count == 1),
            num_shards=args.checkpoint_shard_count,
            state=model_state,
        )
        optimize_models(args, use_cuda, models)
        task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info("| {} {} {} examples".format(
        args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        raise NotImplementedError("asg_loss is currently not supported")
        # trans = criterions[0].asg.trans.data
        # args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(args):
        w2l_decoder = getattr(args, "w2l_decoder", None)
        if w2l_decoder == "viterbi":
            from w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, task.target_dictionary)
        elif w2l_decoder == "kenlm":
            from w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(args, task.target_dictionary)
        elif w2l_decoder == "fairseqlm":
            from w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(args, task.target_dictionary)
        else:
            print(
                "only flashlight decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
            )

    # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
    generator = build_generator(args)

    if args.load_emissions:
        generator = ExistingEmissionsDecoder(
            generator, np.load(args.load_emissions, allow_pickle=True))
        logger.info("loaded emissions from " + args.load_emissions)

    num_sentences = 0

    if args.results_path is not None and not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    max_source_pos = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models])

    if max_source_pos is not None:
        # max_positions() yields a (source, target) pair here; keep the source bound
        max_source_pos = max_source_pos[0] - 1

    if args.dump_emissions:
        emissions = {}
    if args.dump_features:
        features = {}
        models[0].bert.proj = None
    else:
        res_files = prepare_result_files(args)
    errs_t = 0
    lengths_t = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, :args.prefix_size]

            gen_timer.start()
            if args.dump_emissions:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    emm = models[0].get_normalized_probs(encoder_out,
                                                         log_probs=True)
                    emm = emm.transpose(0, 1).cpu().numpy()
                    for i, id in enumerate(sample["id"]):
                        emissions[id.item()] = emm[i]
                    continue
            elif args.dump_features:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    feat = encoder_out["encoder_out"].transpose(
                        0, 1).cpu().numpy()
                    for i, id in enumerate(sample["id"]):
                        padding = (encoder_out["encoder_padding_mask"][i].cpu(
                        ).numpy() if encoder_out["encoder_padding_mask"]
                                   is not None else None)
                        features[id.item()] = (feat[i], padding)
                    continue
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                speaker = None
                # id = task.dataset(args.gen_subset).ids[int(sample_id)]
                id = sample_id
                toks = (sample["target"][i, :] if "target_label" not in sample
                        else sample["target_label"][i, :])
                target_tokens = utils.strip_pad(toks,
                                                tgt_dict.pad()).int().cpu()
                # Process top predictions
                errs, length = process_predictions(
                    args,
                    hypos[i],
                    None,
                    tgt_dict,
                    target_tokens,
                    res_files,
                    speaker,
                    id,
                )
                errs_t += errs
                lengths_t += length

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += (sample["nsentences"] if "nsentences" in sample
                              else sample["id"].numel())

    wer = None
    if args.dump_emissions:
        emm_arr = []
        for i in range(len(emissions)):
            emm_arr.append(emissions[i])
        np.save(args.dump_emissions, emm_arr)
        logger.info(
            f"saved {len(emissions)} emissions to {args.dump_emissions}")
    elif args.dump_features:
        feat_arr = []
        for i in range(len(features)):
            feat_arr.append(features[i])
        np.save(args.dump_features, feat_arr)
        logger.info(f"saved {len(features)} emissions to {args.dump_features}")
    else:
        if lengths_t > 0:
            wer = errs_t * 100.0 / lengths_t
            logger.info(f"WER: {wer}")

        logger.info("| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}"
                    "sentences/s, {:.2f} tokens/s)".format(
                        num_sentences,
                        gen_timer.n,
                        gen_timer.sum,
                        num_sentences / gen_timer.sum,
                        1.0 / gen_timer.avg,
                    ))
        logger.info("| Generate {} with beam={}".format(
            args.gen_subset, args.beam))
    return task, wer
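Side note: the errs_t / lengths_t bookkeeping above amounts to a corpus-level word error rate, i.e. total edit-distance errors over total reference length. A minimal self-contained sketch of that computation (function names are ours, not fairseq's):

def edit_distance(ref, hyp):
    # one-row dynamic-programming Levenshtein distance over token lists
    d = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, d[0] = d[0], i
        for j, h in enumerate(hyp, 1):
            prev, d[j] = d[j], min(d[j] + 1,         # deletion
                                   d[j - 1] + 1,     # insertion
                                   prev + (r != h))  # substitution
    return d[-1]

def corpus_wer(pairs):
    # pairs: iterable of (ref_tokens, hyp_tokens); mirrors errs_t / lengths_t
    errs = sum(edit_distance(r, h) for r, h in pairs)
    length = sum(len(r) for r, _ in pairs)
    return errs * 100.0 / length if length > 0 else None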
Ejemplo n.º 25
0
def _main(cfg: DictConfig, output_file):
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
    )
    logger = logging.getLogger("fairseq_cli.generate")

    utils.import_user_module(cfg.common)

    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.max_tokens = 12000
    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Load dataset splits
    task = tasks.setup_task(cfg.task)

    # Set dictionaries
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    overrides = ast.literal_eval(cfg.common_eval.model_overrides)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    if cfg.generation.lm_path is not None:
        overrides["data"] = cfg.task.data

        try:
            lms, _ = checkpoint_utils.load_model_ensemble(
                [cfg.generation.lm_path], arg_overrides=overrides, task=None)
        except Exception:
            logger.warning(
                f"Failed to load language model! Please make sure that the language model dict is the same "
                f"as target dict and is located in the data dir ({cfg.task.data})"
            )
            raise

        assert len(lms) == 1
    else:
        lms = [None]

    # Optimize ensemble for generation
    for model in chain(models, lms):
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm"
                            if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()

    extra_gen_cls_kwargs = {
        "lm_model": lms[0],
        "lm_weight": cfg.generation.lm_weight
    }
    generator = task.build_generator(models,
                                     cfg.generation,
                                     extra_gen_cls_kwargs=extra_gen_cls_kwargs)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(cfg.tokenizer)
    bpe = encoders.build_bpe(cfg.bpe)

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    scorer = scoring.build_scorer(cfg.scoring, tgt_dict)

    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if "net_input" not in sample:
            continue

        prefix_tokens = None
        if cfg.generation.prefix_size > 0:
            prefix_tokens = sample["target"][:, :cfg.generation.prefix_size]

        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]

        gen_timer.start()
        hypos = task.inference_step(
            generator,
            models,
            sample,
            prefix_tokens=prefix_tokens,
            constraints=constraints,
        )
        num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample["id"].tolist()):
            has_target = sample["target"] is not None

            # Remove padding
            if "src_tokens" in sample["net_input"]:
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad())
            else:
                src_tokens = None

            target_tokens = None
            if has_target:
                target_tokens = (utils.strip_pad(sample["target"][i, :],
                                                 tgt_dict.pad()).int().cpu())

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(
                    cfg.dataset.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(
                    cfg.dataset.gen_subset).tgt.get_original_text(sample_id)
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens,
                                              cfg.common_eval.post_process)
                else:
                    src_str = ""
                if has_target:
                    target_str = tgt_dict.string(
                        target_tokens,
                        cfg.common_eval.post_process,
                        escape_unk=True,
                        extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                            generator),
                    )

            src_str = decode_fn(src_str)
            if has_target:
                target_str = decode_fn(target_str)

            if not cfg.common_eval.quiet:
                if src_dict is not None:
                    print("S-{}\t{}".format(sample_id, src_str),
                          file=output_file)
                if has_target:
                    print("T-{}\t{}".format(sample_id, target_str),
                          file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:cfg.generation.nbest]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=cfg.common_eval.post_process,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                        generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                if not cfg.common_eval.quiet:
                    score = hypo["score"] / math.log(2)  # convert to base 2
                    # original hypothesis (after tokenization and BPE)
                    print(
                        "H-{}\t{}\t{}".format(sample_id, score, hypo_str),
                        file=output_file,
                    )
                    # detokenized hypothesis
                    print(
                        "D-{}\t{}\t{}".format(sample_id, score,
                                              detok_hypo_str),
                        file=output_file,
                    )
                    print(
                        "P-{}\t{}".format(
                            sample_id,
                            " ".join(
                                map(
                                    lambda x: "{:.4f}".format(x),
                                    # convert from base e to base 2
                                    hypo["positional_scores"].div_(math.log(2)
                                                                   ).tolist(),
                                )),
                        ),
                        file=output_file,
                    )

                    if cfg.generation.print_alignment:
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join([
                                    "{}-{}".format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in alignment
                                ]),
                            ),
                            file=output_file,
                        )
                        # also dump the attention matrix for this hypothesis
                        src_len = len(src_str.split()) + 1
                        tgt_len = len(hypo_str.split()) + 1
                        print(hypo['attention'].size())
                        print(hypo['attention'][-src_len:, :tgt_len]
                              .transpose(0, 1))

                    if cfg.generation.print_step:
                        print(
                            "I-{}\t{}".format(sample_id, hypo["steps"]),
                            file=output_file,
                        )

                    if cfg.generation.retain_iter_history:
                        for step, h in enumerate(hypo["history"]):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h["tokens"].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print(
                                "E-{}_{}\t{}".format(sample_id, step, h_str),
                                file=output_file,
                            )

                # Score only the top hypothesis
                if has_target and j == 0:
                    if align_dict is not None or cfg.common_eval.post_process is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tgt_dict.encode_line(
                            target_str, add_if_not_exist=True)
                        hypo_tokens = tgt_dict.encode_line(
                            detok_hypo_str, add_if_not_exist=True)
                    if hasattr(scorer, "add_string"):
                        scorer.add_string(target_str, detok_hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({"wps": round(wps_meter.avg)})
        num_sentences += (sample["nsentences"]
                          if "nsentences" in sample else sample["id"].numel())

    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info(
        "Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)"
        .format(
            num_sentences,
            gen_timer.n,
            gen_timer.sum,
            num_sentences / gen_timer.sum,
            1.0 / gen_timer.avg,
        ))
    if has_target:
        if cfg.bpe and not cfg.generation.sacrebleu:
            if cfg.common_eval.post_process:
                logger.warning(
                    "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
                )
            else:
                logger.warning(
                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words.  Use --sacrebleu for standard 13a BLEU tokenization"
                )
        # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
        print(
            "Generate {} with beam={}: {}".format(cfg.dataset.gen_subset,
                                                  cfg.generation.beam,
                                                  scorer.result_string()),
            file=output_file,
        )

    return scorer
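As the NOTE above says, all H-/D-/P- scores are reported in base 2; since the generator produces natural-log probabilities, the conversion is a single division by ln 2. A tiny illustration with made-up values:

import math
import torch

# natural-log token scores, as a generator would produce (made-up values)
positional_scores = torch.tensor([-0.1054, -0.2231, -1.6094])
score = positional_scores.sum().item()

print(score / math.log(2))                         # sentence score in base 2
print((positional_scores / math.log(2)).tolist())  # per-token scores in base 2

Note the listing uses the in-place div_, which mutates hypo["positional_scores"]; the out-of-place division above avoids accidentally converting twice.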
Ejemplo n.º 26
0
def main(args, task=None, model_state=None):
    check_args(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    if task is None:
        # Load dataset splits
        task = tasks.setup_task(args)
        task.load_dataset(args.gen_subset)
        logger.info("| {} {} {} examples".format(
            args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    all_trans = []
    if 'audio' in args.task:
        """
            tasks that load tsv data
            trans_path: raw trans (before bpe)
        """
        trans_path = os.path.join(args.data, "{}.word".format(args.gen_subset))
        with open(trans_path, "r") as f:
            for line in f:
                all_trans.append(line)

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info("| decoding with criterion {}".format(args.criterion))

    # Load ensemble

    logger.info("| loading model(s) from {}".format(args.path))
    models, criterions, _ = load_models_and_criterions(
        args.path,
        data_path=args.data,
        arg_overrides=eval(args.model_overrides),  # noqa
        task=task,
        model_state=model_state,
    )
    optimize_models(args, use_cuda, models)

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()

    generator = CIF_BERT_Decoder(args, task.target_dictionary)

    num_sentences = 0

    if args.results_path is not None and not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    res_files = prepare_result_files(args)
    errs_t = 0
    lengths_t = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                speaker = None
                # id = task.dataset(args.gen_subset).ids[int(sample_id)]
                id = sample_id
                toks = sample["target"][
                    i, :] if 'target_label' not in sample else sample[
                        "target_label"][i, :]
                target_tokens = (utils.strip_pad(toks,
                                                 tgt_dict.pad()).int().cpu())
                trans = (all_trans[id] if all_trans else
                         task.dataset(args.gen_subset)
                         .ids[sample_id][1]['output']['text'].strip())
                # Process top predictions
                errs, length = process_predictions(args, hypos[i], None,
                                                   tgt_dict, target_tokens,
                                                   res_files, speaker, id,
                                                   trans)
                errs_t += errs
                lengths_t += length

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += (sample["nsentences"] if "nsentences" in sample
                              else sample["id"].numel())

    wer = None

    if lengths_t > 0:
        wer = errs_t * 100.0 / lengths_t
        logger.info(f"WER: {wer}")

    logger.info("| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}"
                "sentences/s, {:.2f} tokens/s)".format(
                    num_sentences,
                    gen_timer.n,
                    gen_timer.sum,
                    num_sentences / gen_timer.sum,
                    1.0 / gen_timer.avg,
                ))
    logger.info("| Generate {} with beam={}".format(args.gen_subset,
                                                    args.beam))

    return task, wer
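The throughput figures logged above come from StopwatchMeter: stop(n) adds the elapsed interval to sum and the token count to n, so gen_timer.avg is seconds per token and 1.0 / gen_timer.avg is tokens per second. A minimal stand-in with the same accounting (a sketch, not fairseq's implementation):

import time

class MiniStopwatch:
    """Accumulate elapsed seconds (sum) and a token count (n)."""

    def __init__(self):
        self.sum = 0.0    # total seconds across start/stop intervals
        self.n = 0        # total tokens reported via stop()
        self._start = None

    def start(self):
        self._start = time.perf_counter()

    def stop(self, n_tokens=1):
        if self._start is not None:
            self.sum += time.perf_counter() - self._start
            self.n += n_tokens
            self._start = None

    @property
    def avg(self):
        # seconds per token; 1.0 / avg gives tokens per second
        return self.sum / self.n if self.n else 0.0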
Ejemplo n.º 27
0
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample; periodically print out
        randomly sampled predictions if model is in training mode, otherwise
        aggregate word error stats for validation.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        dict = self.scorer.dict  # note: shadows the builtin dict in this method
        if model.training:
            if ((len(self.args.scheduled_sampling_probs) > 1
                 or self.args.scheduled_sampling_probs[0] < 1.0)
                    and self.epoch >= self.args.start_scheduled_sampling_epoch):
                # scheduled sampling
                ss_prob = self.args.scheduled_sampling_probs[min(
                    self.epoch - self.args.start_scheduled_sampling_epoch,
                    len(self.args.scheduled_sampling_probs) - 1)]
                assert isinstance(model.decoder, FairseqIncrementalDecoder)
                incremental_states = {}
                encoder_input = {
                    k: v
                    for k, v in sample['net_input'].items()
                    if k != 'prev_output_tokens'
                }
                encoder_out = model.encoder(**encoder_input)
                target = sample['target']
                tokens = sample['net_input']['prev_output_tokens']
                lprobs = []
                for step in range(target.size(1)):
                    if step > 0:
                        sampling_mask = torch.rand(
                            [target.size(0), 1],
                            device=target.device).lt(ss_prob)
                        feed_tokens = torch.where(
                            sampling_mask, tokens[:, step:step + 1], pred)
                    else:
                        feed_tokens = tokens[:, step:step + 1]
                    log_probs, _ = self._decode(feed_tokens, model,
                                                encoder_out,
                                                incremental_states)
                    pred = log_probs.argmax(-1, keepdim=True)
                    lprobs.append(log_probs)
                lprobs = torch.stack(lprobs, dim=1)
            else:
                # normal training
                net_output = model(**sample['net_input'])
                lprobs = model.get_normalized_probs(net_output, log_probs=True)
                target = model.get_targets(sample, net_output)
        else:
            assert isinstance(model.decoder, FairseqIncrementalDecoder)
            incremental_states = {}
            encoder_input = {
                k: v
                for k, v in sample['net_input'].items()
                if k != 'prev_output_tokens'
            }
            encoder_out = model.encoder(**encoder_input)
            target = sample['target']
            # make the maximum decoding length equal to at least the length of
            # target, and the length of encoder_out if possible
            maxlen = max(encoder_out['encoder_out'][0].size(1), target.size(1))
            tokens = target.new_full([target.size(0), maxlen + 2],
                                     self.padding_idx)
            tokens[:, 0] = dict.eos()
            lprobs = []
            attn = [] if getattr(model.decoder, 'need_attn', False) else None
            dummy_log_probs = encoder_out['encoder_out'][0].new_full(
                [target.size(0), len(dict)], -np.log(len(dict)))
            for step in range(maxlen + 1):  # one extra step for EOS marker
                is_eos = tokens[:, step].eq(dict.eos())
                # if all predictions are finished (i.e., ended with eos),
                # pad lprobs to target length with dummy log probs,
                # truncate tokens up to this step and break
                if step > 0 and is_eos.sum() == is_eos.size(0):
                    for _ in range(step, target.size(1)):
                        lprobs.append(dummy_log_probs)
                    tokens = tokens[:, :step + 1]
                    break
                log_probs, attn_scores = self._decode(tokens[:, :step + 1],
                                                      model, encoder_out,
                                                      incremental_states)
                tokens[:, step + 1] = log_probs.argmax(-1)
                if step > 0:  # deal with finished predictions
                    # make log_probs uniform if the previous output token is EOS
                    # and add consecutive EOS to the end of prediction
                    log_probs[is_eos, :] = -np.log(log_probs.size(1))
                    tokens[is_eos, step + 1] = dict.eos()
                if step < target.size(1):
                    lprobs.append(log_probs)
                if getattr(model.decoder, 'need_attn', False):
                    attn.append(attn_scores)
            # bsz x min(tgtlen, maxlen + 1) x vocab_size
            lprobs = torch.stack(lprobs, dim=1)
            if getattr(model.decoder, 'need_attn', False):
                # bsz x (maxlen + 1) x (length of encoder_out)
                attn = torch.stack(attn, dim=1)
        # word error stats code starts
        if not model.training or (
                self.num_updates // self.args.print_interval
                > (self.num_updates - 1) // self.args.print_interval):
            pred = (lprobs.argmax(-1).cpu() if model.training
                    else tokens[:, 1:].data.cpu())  # bsz x len

            if not model.training:  # validation step, compute WER stats with scorer
                assert pred.size(0) == target.size(0)
                self.scorer.reset()
                for i in range(target.size(0)):
                    utt_id = sample['utt_id'][i]
                    id = sample['id'].data[i].item()
                    # if it is a dummy batch (e.g., a "padding" batch in a
                    # sharded dataset), id might exceed the dataset size; in
                    # that case we just skip it
                    if id < len(self.valid_tgt_dataset):
                        ref_tokens = self.valid_tgt_dataset.get_original_tokens(
                            id)
                        pred_tokens = dict.string(pred.data[i])
                        self.scorer.add_evaluation(
                            utt_id,
                            ref_tokens,
                            pred_tokens,
                            bpe_symbol=self.args.remove_bpe)
            else:  # print a randomly sampled result every print_interval updates
                assert pred.size() == target.size()
                with data_utils.numpy_seed(self.num_updates):
                    i = np.random.randint(0, len(sample['id']))
                id = sample['id'].data[i].item()
                length = utils.strip_pad(target.data[i],
                                         self.padding_idx).size(0)
                #ref_one = dict.tokens_to_sentence(dict.string(target.data[i]))
                ref_one = self.train_tgt_dataset.get_original_text(
                    id, dict, bpe_symbol=self.args.remove_bpe)
                pred_one = dict.tokens_to_sentence(
                    dict.string(pred.data[i][:length]),
                    bpe_symbol=self.args.remove_bpe,
                )
                print('| sample REF: ' + ref_one)
                print('| sample PRD: ' + pred_one)
        # word error stats code ends
        lprobs = lprobs.view(-1, lprobs.size(-1))
        loss = F.nll_loss(
            lprobs,
            target.view(-1),
            ignore_index=self.padding_idx,
            reduction='sum' if reduce else 'none',
        )
        sample_size = (sample['target'].size(0)
                       if self.args.sentence_avg else sample['ntokens'])
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'nll_loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }
        if not model.training:  # do not compute word error in training mode
            logging_output['word_error'] = self.scorer.tot_word_error()
            logging_output['word_count'] = self.scorer.tot_word_count()
            logging_output['char_error'] = self.scorer.tot_char_error()
            logging_output['char_count'] = self.scorer.tot_char_count()
        return loss, sample_size, logging_output
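The scheduled-sampling branch of the criterion above decides, per sentence and per step, whether to feed the gold token or the model's previous argmax: torch.rand(...).lt(ss_prob) keeps the gold token with probability ss_prob. The coin flip in isolation (shapes and values are made up):

import torch

bsz, ss_prob = 4, 0.75
gold = torch.tensor([[5], [6], [7], [8]])  # gold tokens for this step (bsz x 1)
pred = torch.tensor([[9], [9], [9], [9]])  # model argmax from the previous step

# one draw per sentence; True -> keep the gold token for that sentence
sampling_mask = torch.rand(bsz, 1).lt(ss_prob)
feed_tokens = torch.where(sampling_mask, gold, pred)
print(feed_tokens.squeeze(1))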
Ejemplo n.º 28
0
    def gen_candidates(self, config, sample):
        use_cuda = torch.cuda.is_available() and not self.args.cpu

        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            return

        src_dict = self.task.source_dictionary
        tgt_dict = self.task.target_dictionary

        regex = r" ##"
        args = self.make_fit_args(config)
        generator = self.task.build_generator(args)

        prefix_tokens = None
        if args.prefix_size > 0:
            prefix_tokens = sample['target'][:, :args.prefix_size]

        hypos = self.task.inference_step(generator, self.models, sample,
                                         prefix_tokens)

        num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)

        to_score = []
        for i, sample_id in enumerate(sample['id'].tolist()):
            # Remove padding
            src_tokens = utils.strip_pad(
                sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
            target_tokens = utils.strip_pad(sample['target'][i, :],
                                            tgt_dict.pad()).int().cpu()
            src_str = src_dict.string(src_tokens, args.remove_bpe)
            target_str = tgt_dict.string(target_tokens,
                                         args.remove_bpe,
                                         escape_unk=True)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:args.nbest]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=self.align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                # Collect only the top hypothesis; scoring happens in the caller
                if j == 0:
                    # strip WordPiece continuation markers (" ##") before scoring
                    hypo_str = re.sub(regex, "", hypo_str, 0, re.MULTILINE)
                    target_str = re.sub(regex, "", target_str, 0, re.MULTILINE)
                    to_score.append((target_str.split(), hypo_str.split()))

        return to_score
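The regex r" ##" in gen_candidates above undoes BERT-style WordPiece segmentation before scoring: a continuation piece is written as " ##piece", so deleting the separator re-joins it with the preceding piece. For example:

import re

hypo_str = "the quick br ##own fox jump ##ed"
print(re.sub(r" ##", "", hypo_str))  # -> the quick brown fox jumped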
Ejemplo n.º 29
0
def _main(args, output_file):
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.generate')

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(models, args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        prefix_tokens = None
        if args.prefix_size > 0:
            prefix_tokens = sample['target'][:, :args.prefix_size]

        gen_timer.start()
        hypos = task.inference_step(generator, models, sample, prefix_tokens)
        num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample['id'].tolist()):
            has_target = sample['target'] is not None

            # Remove padding
            src_tokens = utils.strip_pad(
                sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
            target_tokens = None
            if has_target:
                target_tokens = utils.strip_pad(sample['target'][i, :],
                                                tgt_dict.pad()).int().cpu()

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(
                    args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(
                    args.gen_subset).tgt.get_original_text(sample_id)
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, args.remove_bpe)
                else:
                    src_str = ""
                if has_target:
                    target_str = tgt_dict.string(target_tokens,
                                                 args.remove_bpe,
                                                 escape_unk=True,
                                                 extra_symbols_to_ignore={
                                                     generator.eos,
                                                 })

            src_str = decode_fn(src_str)
            if has_target:
                target_str = decode_fn(target_str)

            if not args.quiet:
                if src_dict is not None:
                    print('S-{}\t{}'.format(sample_id, src_str),
                          file=output_file)
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str),
                          file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:args.nbest]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                    extra_symbols_to_ignore={
                        generator.eos,
                    })
                detok_hypo_str = decode_fn(hypo_str)
                if not args.quiet:
                    score = hypo['score'] / math.log(2)  # convert to base 2
                    # original hypothesis (after tokenization and BPE)
                    print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str),
                          file=output_file)
                    # detokenized hypothesis
                    print('D-{}\t{}\t{}'.format(sample_id, score,
                                                detok_hypo_str),
                          file=output_file)
                    print(
                        'P-{}\t{}'.format(
                            sample_id,
                            ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    # convert from base e to base 2
                                    hypo['positional_scores']
                                    .div_(math.log(2)).tolist(),
                                ))),
                        file=output_file)

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id, ' '.join([
                                '{}-{}'.format(src_idx, tgt_idx)
                                for src_idx, tgt_idx in alignment
                            ])),
                              file=output_file)

                    if args.print_step:
                        print('I-{}\t{}'.format(sample_id, hypo['steps']),
                              file=output_file)

                    if 'enc_selection' in hypo:
                        print('Menc-{}\t{}'.format(sample_id,
                                                   hypo['enc_selection']),
                              file=output_file)
                    if 'dec_selection' in hypo:
                        print('Mdec-{}\t{}'.format(sample_id,
                                                   hypo['dec_selection']),
                              file=output_file)
                    if args.print_attn_confidence:
                        print('C-{}\t{}'.format(sample_id,
                                                hypo['enc_self_attn_conf']),
                              file=output_file)

                    if getattr(args, 'retain_iter_history', False):
                        for step, h in enumerate(hypo['history']):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h['tokens'].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print('E-{}_{}\t{}'.format(sample_id, step, h_str),
                                  file=output_file)

                # Score only the top hypothesis
                if has_target and j == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tgt_dict.encode_line(
                            target_str, add_if_not_exist=True)
                        hypo_tokens = tgt_dict.encode_line(
                            detok_hypo_str, add_if_not_exist=True)
                    if hasattr(scorer, 'add_string'):
                        scorer.add_string(target_str, detok_hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({'wps': round(wps_meter.avg)})
        num_sentences += sample['nsentences']

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info(
        'Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        if args.bpe and not args.sacrebleu:
            if args.remove_bpe:
                logger.warning(
                    "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
                )
            else:
                logger.warning(
                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words.  Use --sacrebleu for standard 13a BLEU tokenization"
                )
        logger.info('Generate {} with beam={}: {}'.format(
            args.gen_subset, args.beam, scorer.result_string()))

    return scorer
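Each element of hypos[i] in these listings is a dict carrying at least 'tokens', 'score', 'alignment' and 'positional_scores'. As a hedged sketch (not part of fairseq), this is how one could re-rank an n-best list by score per token; note that depending on generator flags the 'score' may already be length-normalized:

def rerank_by_length_norm(nbest):
    # nbest: list of hypothesis dicts as produced by the generator
    return sorted(
        nbest,
        key=lambda h: float(h["score"]) / max(len(h["tokens"]), 1),
        reverse=True,  # higher (less negative) log-prob per token first
    )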
Ejemplo n.º 30
0
def _main(args, output_file):
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.generate')

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    avg_ranks = AverageMeter()
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        all_ents = []
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            if 'ents' in sample:
                all_ents.extend(sample['ents'])
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str),
                              file=output_file)
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str),
                              file=output_file)

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        score = hypo['score'] / math.log(2)  # convert to base 2
                        print('H-{}\t{}\t{}'.format(sample_id, score,
                                                    hypo_str),
                              file=output_file)
                        print(
                            'P-{}\t{}'.format(
                                sample_id,
                                ' '.join(
                                    map(
                                        lambda x: '{:.4f}'.format(x),
                                        # convert from base e to base 2
                                        hypo['positional_scores'].div_(
                                            math.log(2)).tolist(),
                                    ))),
                            file=output_file)

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join([
                                    '{}-{}'.format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in alignment
                                ])),
                                  file=output_file)

                        if args.print_step:
                            print('I-{}\t{}'.format(sample_id, hypo['steps']),
                                  file=output_file)

                        if getattr(args, 'retain_iter_history', False):
                            for step, h in enumerate(hypo['history']):
                                _, h_str, _ = utils.post_process_prediction(
                                    hypo_tokens=h['tokens'].int().cpu(),
                                    src_str=src_str,
                                    alignment=None,
                                    align_dict=None,
                                    tgt_dict=tgt_dict,
                                    remove_bpe=None,
                                )
                                print('E-{}_{}\t{}'.format(
                                    sample_id, step, h_str),
                                      file=output_file)

                        if getattr(args, 'score_reference', False):
                            print('R-{}\t{}'.format(
                                sample_id, '{:.4f}'.format(hypo['avg_ranks'])),
                                  file=output_file)

                    # Score only the top hypothesis
                    if getattr(args, 'score_reference', False):
                        avg_ranks.update(hypo['avg_ranks'])
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)
            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info(
        'Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        logger.info('Generate {} with beam={}: {}'.format(
            args.gen_subset, args.beam, scorer.result_string()))
    if getattr(args, 'score_reference', False):
        logger.info('Average rank of reference={:.4f}, Entropy={:.4f}'.format(
            avg_ranks.avg,
            torch.cat(all_ents, dim=0).mean()))

    return scorer
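avg_ranks above is an AverageMeter; all this code needs from it is update(value) and a running .avg. A minimal equivalent (fairseq's version also supports weighted updates, which the optional n argument mimics):

class MiniAverageMeter:
    """Running mean: update(v[, n]) adds a value, .avg is the mean so far."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.sum += float(val) * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count else 0.0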
Ejemplo n.º 31
0
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(
                                lambda x: '{:.4f}'.format(x),
                                hypo['positional_scores'].tolist(),
                            ))
                        ))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id,
                                ' '.join(map(lambda x: str(utils.item(x)), alignment))
                            ))

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
    return scorer
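Side note: a minimal sketch of driving the scorer returned above, assuming fairseq's bleu module and pad-stripped int CPU token tensors as produced by utils.strip_pad(...).int().cpu() in the loop; this is illustrative, not code from the source.

from fairseq import bleu

# Token-level corpus BLEU; special-symbol ids come from the target dictionary.
scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
scorer.add(target_tokens, hypo_tokens)  # accumulate one (reference, hypothesis) pair
print(scorer.result_string())           # cumulative corpus-level BLEU summary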
Example No. 32
0
def validate_translation(args, trainer, task, epoch_itr, generator):
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    models = [trainer.get_model()]
    bleu_dict = {key: None for key in task.eval_lang_pairs}

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer_dict = {
            key: bleu.SacrebleuScorer()
            for key in task.eval_lang_pairs
        }
    else:
        scorer_dict = {
            key: bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
            for key in task.eval_lang_pairs
        }
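    # SacrebleuScorer consumes detokenized strings via add_string(), while
    # bleu.Scorer consumes integer token tensors via add(); the scoring branch
    # below picks the right call with hasattr(scorer, 'add_string').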

    itr = task.get_batch_iterator(
        dataset=task.dataset('valid'),
        max_tokens=args.max_tokens_valid,
        max_sentences=args.max_sentences_valid,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            trainer.get_model().max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
        noskip=True,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.build_progress_bar(args,
                                               itr,
                                               epoch_itr.epoch,
                                               prefix='translate subset',
                                               no_progress_bar='simple')

    num_sentences = 0
    has_target = True
    for samples in progress:
        if torch.cuda.is_available() and not args.cpu:
            samples = utils.move_to_cuda(samples)

        prefix_tokens = None
        for key, sample in samples.items():
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding. src_tokens is kept for parity with the
                # generation loop above, although source printing is disabled here.
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()
                if args.sde:
                    # With args.sde, the target tokens stand in for the source.
                    src_tokens = target_tokens
                else:
                    src_tokens = utils.strip_pad(
                        sample['net_input']['src_tokens'][i, :],
                        tgt_dict.pad())

                # Regenerate the target string from tokens for scoring.
                if has_target:
                    target_str = tgt_dict.string(target_tokens,
                                                 args.remove_bpe,
                                                 escape_unk=True)

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str="",
                        alignment=hypo['alignment'].int().cpu()
                        if hypo['alignment'] is not None else None,
                        align_dict=None,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    # Score only the top hypothesis; this block must sit inside
                    # the loop over hypotheses, otherwise nothing is scored
                    # whenever args.nbest > 1.
                    if has_target and j == 0:
                        if args.remove_bpe is not None:
                            # Convert back to tokens for evaluation without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer_dict[key], 'add_string'):
                            scorer_dict[key].add_string(target_str, hypo_str)
                        else:
                            scorer_dict[key].add(target_tokens, hypo_tokens)

            num_sentences += sample['nsentences']
    for key, scorer in scorer_dict.items():
        bleu_dict[key] = scorer.score()
    return bleu_dict
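A possible way to consume the result (an assumption for illustration, not part of the source): since validate_translation returns one corpus BLEU per language pair, a single validation signal can be formed by averaging.

bleu_dict = validate_translation(args, trainer, task, epoch_itr, generator)
# Hypothetical aggregation: average corpus BLEU over the evaluated pairs.
avg_bleu = sum(bleu_dict.values()) / max(len(bleu_dict), 1)
print('| epoch {:03d} | valid BLEU over {} pairs: {:.2f}'.format(
    epoch_itr.epoch, len(bleu_dict), avg_bleu))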