Code example #1
File: core.py Project: tma15/fairseq
    def decode_target(self, hypos):
        """Decode the target string from the top hypothesis tokens."""
        # hypos[0][0] is the best-scoring hypothesis for the first sentence
        hypo_str = self.tgt_dict.string(
            hypos[0][0]["tokens"].int().cpu(),
            self.remove_bpe,
            get_symbols_to_strip_from_output(self.generator),
        )
        # undo BPE first, then detokenize (the reverse of the encoding order)
        if self.bpe is not None:
            hypo_str = self.bpe.decode(hypo_str)
        if self.tokenizer is not None:
            hypo_str = self.tokenizer.decode(hypo_str)
        return hypo_str
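
The two decode calls undo the encoding pipeline in reverse order: BPE markers are stripped first, then the detokenizer runs. A minimal standalone sketch of the BPE step, assuming the subword-nmt "@@ " continuation convention; strip_subword_nmt_bpe is a helper written here for illustration, not a fairseq API:

# Illustrative reimplementation of what self.bpe.decode does for subword-nmt,
# where "@@ " marks a subword that continues into the next token.
def strip_subword_nmt_bpe(x, separator="@@"):
    return (x + " ").replace(separator + " ", "").rstrip()

print(strip_subword_nmt_bpe("Hel@@ lo wor@@ ld !"))  # -> "Hello world !"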
Code example #2
def get_text(cfg, task, generator, model, sample, bpe):
    # `model` is an ensemble (a list of models); `task` is needed by
    # inference_step below, so it is taken as a parameter
    decoder_output = task.inference_step(generator,
                                         model,
                                         sample,
                                         prefix_tokens=None,
                                         constraints=None)
    decoder_output = decoder_output[0][0]  # keep only the top-1 hypothesis

    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
        hypo_tokens=decoder_output["tokens"].int().cpu(),
        src_str="",
        alignment=decoder_output["alignment"],
        align_dict=None,
        tgt_dict=model[0].decoder.dictionary,
        remove_bpe=cfg.common_eval.post_process,
        extra_symbols_to_ignore=generate.get_symbols_to_strip_from_output(
            generator),
    )

    detok_hypo_str = bpe.decode(hypo_str)

    return detok_hypo_str
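
For this helper to work, `task` has to be in scope alongside the other arguments; a hedged setup sketch for the objects it expects, where the checkpoint path is a placeholder:

from fairseq import checkpoint_utils
from fairseq.data import encoders

# Hedged sketch: "checkpoint_best.pt" is a placeholder path.
models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
    ["checkpoint_best.pt"])
generator = task.build_generator(models, cfg.generation)
bpe = encoders.build_bpe(cfg.bpe)
# `sample` must be a prepared fairseq batch with a "net_input" dict (see the
# other examples); then:
# text = get_text(cfg, task, generator, models, sample, bpe)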
Code example #3
File: interactive.py Project: veralily/fairseq
def main(cfg: FairseqConfig):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    start_time = time.time()
    total_translate_time = 0

    utils.import_user_module(cfg.common)

    if cfg.interactive.buffer_size < 1:
        cfg.interactive.buffer_size = 1
    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.batch_size = 1

    assert (not cfg.generation.sampling
            or cfg.generation.nbest == cfg.generation.beam
            ), "--sampling requires --nbest to be equal to --beam"
    assert (not cfg.dataset.batch_size
            or cfg.dataset.batch_size <= cfg.interactive.buffer_size
            ), "--batch-size cannot be larger than --buffer-size"

    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(cfg.task)

    # Load ensemble
    overrides = ast.literal_eval(cfg.common_eval.model_overrides)
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Initialize generator
    generator = task.build_generator(models, cfg.generation)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(cfg.tokenizer)
    bpe = encoders.build_bpe(cfg.bpe)

    def encode_fn(x):
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models])

    if cfg.generation.constraints:
        logger.warning(
            "NOTE: Constrained decoding currently assumes a shared subword vocabulary."
        )

    if cfg.interactive.buffer_size > 1:
        logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size)
    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info("Type the input sentence and press return:")
    start_id = 0
    for inputs in buffered_read(cfg.interactive.input,
                                cfg.interactive.buffer_size):
        results = []
        for batch in make_batches(inputs, cfg, task, max_positions, encode_fn):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            constraints = batch.constraints
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if constraints is not None:
                    constraints = constraints.cuda()

            sample = {
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": src_lengths,
                },
            }
            translate_start_time = time.time()
            translations = task.inference_step(generator,
                                               models,
                                               sample,
                                               constraints=constraints)
            translate_time = time.time() - translate_start_time
            total_translate_time += translate_time
            list_constraints = [[] for _ in range(bsz)]
            if cfg.generation.constraints:
                list_constraints = [unpack_constraints(c) for c in constraints]
            for i, (id,
                    hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                constraints = list_constraints[i]
                results.append((
                    start_id + id,
                    src_tokens_i,
                    hypos,
                    {
                        "constraints": constraints,
                        "time": translate_time / len(translations),
                    },
                ))

        # sort output to match input order
        for id_, src_tokens, hypos, info in sorted(results,
                                                   key=lambda x: x[0]):
            if src_dict is not None:
                src_str = src_dict.string(src_tokens,
                                          cfg.common_eval.post_process)
                print("S-{}\t{}".format(id_, src_str))
                print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
                for constraint in info["constraints"]:
                    print("C-{}\t{}".format(
                        id_,
                        tgt_dict.string(constraint,
                                        cfg.common_eval.post_process)))

            # Process top predictions
            for hypo in hypos[:min(len(hypos), cfg.generation.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=cfg.common_eval.post_process,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                        generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                score = hypo["score"] / math.log(2)  # convert to base 2
                # original hypothesis (after tokenization and BPE)
                print("H-{}\t{}\t{}".format(id_, score, hypo_str))
                # detokenized hypothesis
                print("D-{}\t{}\t{}".format(id_, score, detok_hypo_str))
                print("P-{}\t{}".format(
                    id_,
                    " ".join(
                        map(
                            lambda x: "{:.4f}".format(x),
                            # convert from base e to base 2
                            hypo["positional_scores"].div_(math.log(2)
                                                           ).tolist(),
                        )),
                ))
                if cfg.generation.print_alignment:
                    alignment_str = " ".join(
                        ["{}-{}".format(src, tgt) for src, tgt in alignment])
                    print("A-{}\t{}".format(id_, alignment_str))

        # update running id_ counter
        start_id += len(inputs)

    logger.info("Total time: {:.3f} seconds; translation time: {:.3f}".format(
        time.time() - start_time, total_translate_time))
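
The loop above is essentially what the fairseq-interactive CLI runs. For single sentences, the torch.hub interface wraps the same encode/inference/decode pipeline in one call; a hedged sketch using one of the officially published WMT'19 checkpoints (downloaded on first use):

import torch

en2de = torch.hub.load("pytorch/fairseq",
                       "transformer.wmt19.en-de.single_model",
                       tokenizer="moses", bpe="fastbpe")
print(en2de.translate("Hello world!"))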
Code example #4
File: api.py Project: HieuNgoUIT/fscustomize
    def infer(self, lines_of_text):
        context = self.context

        bpe = context['bpe']
        tokenizer = context['tokenizer']
        cfg = context['cfg']
        task = context['task']
        max_positions = context['max_positions']
        use_cuda = context['use_cuda']
        generator = context['generator']
        models = context['models']
        src_dict = context['src_dict']
        tgt_dict = context['tgt_dict']
        align_dict = context['align_dict']

        def encode_fn(x):
            if tokenizer is not None:
                x = tokenizer.encode(x)
            if bpe is not None:
                x = bpe.encode(x)
            return x

        def decode_fn(x):
            if bpe is not None:
                x = bpe.decode(x)
            if tokenizer is not None:
                x = tokenizer.decode(x)
            return x

        start_id = 0
        for inputs in [lines_of_text]:
            results = []
            for batch in make_batches(inputs, cfg, task, max_positions,
                                      encode_fn):
                src_tokens = batch.src_tokens
                src_lengths = batch.src_lengths
                if use_cuda:
                    src_tokens = src_tokens.cuda()
                    src_lengths = src_lengths.cuda()
                sample = {
                    "net_input": {
                        "src_tokens": src_tokens,
                        "src_lengths": src_lengths,
                    },
                }
                translations = task.inference_step(generator, models, sample)
                for i, (id, hypos) in enumerate(
                        zip(batch.ids.tolist(), translations)):
                    src_tokens_i = utils.strip_pad(src_tokens[i],
                                                   tgt_dict.pad())
                    results.append((
                        start_id + id,
                        src_tokens_i,
                        hypos,
                    ))

            # sort output to match input order
            for id_, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens,
                                              cfg.common_eval.post_process)

                # Process top predictions
                for hypo in hypos[:min(len(hypos), cfg.generation.nbest)]:
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo["tokens"].int().cpu(),
                        src_str=src_str,
                        alignment=hypo["alignment"],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=cfg.common_eval.post_process,
                        extra_symbols_to_ignore=
                        get_symbols_to_strip_from_output(generator),
                    )
                    detok_hypo_str = decode_fn(hypo_str)
                    score = hypo["score"] / math.log(2)  # convert to base 2
                    #print(detok_hypo_str, hypo_str, hypo_tokens)
                    yield (detok_hypo_str, hypo_str, hypo_tokens)

            # update running id_ counter
            start_id += len(inputs)
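
A hedged sketch of the `context` dict this method unpacks; every value is built with the same fairseq calls shown in code example #3, and cfg, task, models, generator, bpe, and tokenizer are assumed to exist already:

import torch
from fairseq import utils

context = {
    "cfg": cfg,
    "task": task,
    "models": models,
    "generator": generator,
    "bpe": bpe,
    "tokenizer": tokenizer,
    "src_dict": task.source_dictionary,
    "tgt_dict": task.target_dictionary,
    "align_dict": utils.load_align_dict(cfg.generation.replace_unk),
    "max_positions": utils.resolve_max_positions(
        task.max_positions(), *[m.max_positions() for m in models]),
    "use_cuda": torch.cuda.is_available() and not cfg.common.cpu,
}
# infer() is a generator, so iterate to consume results:
# for detok_hypo_str, hypo_str, hypo_tokens in api.infer(["a source sentence"]):
#     print(detok_hypo_str)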
Code example #5
    def translate(self, inputs, constraints=None):
        if self.constrained_decoding and constraints is None:
            raise ValueError(
                "Constraints cant be None in constrained decoding mode")
        if not self.constrained_decoding and constraints is not None:
            raise ValueError(
                "Cannot pass constraints during normal translation")
        if constraints:
            constrained_decoding = True
            modified_inputs = []
            for _input, constraint in zip(inputs, constraints):
                modified_inputs.append(_input + f"\t{constraint}")
            inputs = modified_inputs
        else:
            constrained_decoding = False

        start_id = 0
        results = []
        final_translations = []
        for batch in make_batches(
                inputs,
                self.cfg,
                self.task,
                self.max_positions,
                self.encode_fn,
                constrained_decoding,
        ):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            constraints = batch.constraints
            if self.use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if constraints is not None:
                    constraints = constraints.cuda()

            sample = {
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": src_lengths,
                },
            }

            translations = self.task.inference_step(self.generator,
                                                    self.models,
                                                    sample,
                                                    constraints=constraints)

            list_constraints = [[] for _ in range(bsz)]
            if constrained_decoding:
                list_constraints = [unpack_constraints(c) for c in constraints]
            for i, (id,
                    hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i],
                                               self.tgt_dict.pad())
                constraints = list_constraints[i]
                results.append((
                    start_id + id,
                    src_tokens_i,
                    hypos,
                    {
                        "constraints": constraints,
                    },
                ))

        # sort output to match input order
        for id_, src_tokens, hypos, _ in sorted(results, key=lambda x: x[0]):
            src_str = ""
            if self.src_dict is not None:
                src_str = self.src_dict.string(
                    src_tokens, self.cfg.common_eval.post_process)

            # Process top predictions
            for hypo in hypos[:min(len(hypos), self.cfg.generation.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=self.align_dict,
                    tgt_dict=self.tgt_dict,
                    remove_bpe="subword_nmt",
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                        self.generator),
                )
                detok_hypo_str = self.decode_fn(hypo_str)
                final_translations.append(detok_hypo_str)
        return final_translations
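
A hedged usage sketch: in constrained mode each input line is paired one-to-one with a constraint, which the method appends tab-separated before batching (fairseq's --constraints input format); `translator` stands in for an instance of the surrounding class:

# Hedged sketch; `translator` is an instance of the class above, built with
# constrained decoding enabled.
outputs = translator.translate(
    ["The weather is nice today ."],
    constraints=["Wetter"],  # a phrase that must appear in the output
)
print(outputs[0])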
Code example #6
    def translate():
        if request.method == 'POST':
            inputs_ = request.get_data().decode('utf-8')
        print('Original text: ', inputs_)
        inputs = json.loads(inputs_)['str']
        prefix_src = json.loads(inputs_)['prefix']
        if prefix_src != "":
            prefix_src, _ = preprocess(prefix_src, 'en')
        else:
            prefix_src = None
        print(prefix_src)
        inputs, line_sep = preprocess(inputs)
        print('Processed text', inputs)
        start_id = 0
        results = []
        for batch in make_batches(inputs, args, task, max_positions,
                                  encode_fn):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            constraints = batch.constraints
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if constraints is not None:
                    constraints = constraints.cuda()

            sample = {
                'net_input': {
                    'src_tokens': src_tokens,
                    'src_lengths': src_lengths,
                },
            }
            translate_start_time = time.time()
            prefix_tokens = None
            if prefix_src:
                prefix_tokens = (tgt_dict.encode_line(
                    prefix_src[0]).cuda().unsqueeze(0).long()[:, :-1])
            translations = task.inference_step(generator,
                                               models,
                                               sample,
                                               constraints=constraints,
                                               prefix_tokens=prefix_tokens)
            translate_time = time.time() - translate_start_time
            list_constraints = [[] for _ in range(bsz)]
            if args.constraints:
                list_constraints = [unpack_constraints(c) for c in constraints]
            for i, (id,
                    hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                constraints = list_constraints[i]
                results.append((start_id + id, src_tokens_i, hypos, {
                    "constraints": constraints,
                    "time": translate_time / len(translations)
                }))

        # sort output to match input order
        outputs = []
        for id_, src_tokens, hypos, info in sorted(results,
                                                   key=lambda x: x[0]):
            if src_dict is not None:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                print('S-{}\t{}'.format(id_, src_str))
                print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
                for constraint in info["constraints"]:
                    print("C-{}\t{}".format(
                        id_, tgt_dict.string(constraint, args.remove_bpe)))

            # Process top predictions
            for hypo in hypos[:min(len(hypos), args.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                        generator),
                )
                outputs.append(hypo_str)
                detok_hypo_str = decode_fn(hypo_str)
                score = hypo['score'] / math.log(2)  # convert to base 2
                # original hypothesis (after tokenization and BPE)
                print('H-{}\t{}\t{}'.format(id_, score, hypo_str))
                # detokenized hypothesis
                print('D-{}\t{}\t{}'.format(id_, score, detok_hypo_str))
                print('P-{}\t{}'.format(
                    id_,
                    ' '.join(
                        map(
                            lambda x: '{:.4f}'.format(x),
                            # convert from base e to base 2
                            hypo['positional_scores'].div_(math.log(2)
                                                           ).tolist(),
                        ))))
                if args.print_alignment:
                    alignment_str = " ".join(
                        ["{}-{}".format(src, tgt) for src, tgt in alignment])
                    print('A-{}\t{}'.format(id_, alignment_str))

        # update running id_ counter
        start_id += len(inputs)
        print(outputs)
        output = postprocess(outputs, line_sep)
        return output
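
A hedged client-side sketch for exercising this endpoint; host, port, and route are placeholders, and the JSON body follows the {"str": ..., "prefix": ...} schema the handler reads:

import json
import requests

payload = {"str": "Hello world .", "prefix": ""}
resp = requests.post("http://localhost:5000/translate",
                     data=json.dumps(payload).encode("utf-8"))
print(resp.text)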
Code example #7
def sari_validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask,
                  epoch_itr, subsets: List[str]) -> List[Optional[float]]:
    from pathlib import Path
    from access.resources.paths import get_data_filepath
    from access.utils.helpers import read_lines
    from access.preprocessors import load_preprocessors, ComposedPreprocessor
    from easse.report import get_all_scores
    from fairseq.data import encoders
    from fairseq_cli.interactive import buffered_read, make_batches
    from fairseq_cli.generate import get_symbols_to_strip_from_output
    from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
    import tempfile

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(cfg.task)

    # TODO: Choose parameters for the preprocessors?
    # Read the preprocessors from a pickle file
    # preprocessors = load_preprocessors(Path(cfg.task.data).parent)
    # composed_preprocessor = ComposedPreprocessor(preprocessors)
    # Get the path to turkcorpus.valid.complex
    complex_filepath = get_data_filepath('turkcorpus', 'valid', 'complex')
    # make temp files
    # encoded_complex_filepath = tempfile.mkstemp()[1]
    # encoded_pred_filepath = tempfile.mkstemp()[1]
    pred_filepath = tempfile.mkstemp()[1]
    # use preprocessors to encode complex file
    # composed_preprocessor.encode_file(complex_filepath, encoded_complex_filepath)
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        trainer.get_model().max_positions(),
    )
    parser = options.get_generation_parser(interactive=True)
    # TODO: Take args from fairseq_generate
    gen_args = options.parse_args_and_arch(
        parser, input_args=['/dummy_data', '--beam', '2'])
    # Initialize generator
    generator = task.build_generator([trainer.model], gen_args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(cfg.tokenizer)
    bpe = encoders.build_bpe(cfg.bpe)

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    def encode_fn(x):
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    with open(pred_filepath, 'w') as f:
        start_id = 0
        for inputs in buffered_read(complex_filepath, buffer_size=9999):
            results = []
            for batch in make_batches(inputs, cfg, task, max_positions,
                                      encode_fn):
                bsz = batch.src_tokens.size(0)
                src_tokens = batch.src_tokens
                src_lengths = batch.src_lengths
                constraints = batch.constraints
                if use_cuda:
                    src_tokens = src_tokens.cuda()
                    src_lengths = src_lengths.cuda()
                    if constraints is not None:
                        constraints = constraints.cuda()
                sample = {
                    "net_input": {
                        "src_tokens": src_tokens,
                        "src_lengths": src_lengths,
                    },
                }
                translations = task.inference_step(generator, [trainer.model],
                                                   sample,
                                                   constraints=constraints)
                list_constraints = [[] for _ in range(bsz)]
                if cfg.generation.constraints:
                    list_constraints = [
                        unpack_constraints(c) for c in constraints
                    ]
                for i, (id, hypos) in enumerate(
                        zip(batch.ids.tolist(), translations)):
                    src_tokens_i = utils.strip_pad(src_tokens[i],
                                                   tgt_dict.pad())
                    constraints = list_constraints[i]
                    results.append((
                        start_id + id,
                        src_tokens_i,
                        hypos,
                        {
                            "constraints": constraints,
                        },
                    ))

            # sort output to match input order
            for id_, src_tokens, hypos, info in sorted(results,
                                                       key=lambda x: x[0]):
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens,
                                              cfg.common_eval.post_process)
                    for constraint in info["constraints"]:
                        pass

                # Process top predictions
                for hypo in hypos[:min(len(hypos), cfg.generation.nbest)]:
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo["tokens"].int().cpu(),
                        src_str=src_str,
                        alignment=hypo["alignment"],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=cfg.common_eval.post_process,
                        extra_symbols_to_ignore=
                        get_symbols_to_strip_from_output(generator),
                    )
                    detok_hypo_str = decode_fn(hypo_str)
                    # detokenized hypothesis
                    f.write(f'{detok_hypo_str}\n')
                    if cfg.generation.print_alignment:
                        alignment_str = " ".join([
                            "{}-{}".format(src, tgt) for src, tgt in alignment
                        ])

            # update running id_ counter
            start_id += len(inputs)

        # composed_preprocessor.decode_file(encoded_pred_filepath, pred_filepath)
        ref_filepaths = [
            get_data_filepath('turkcorpus', 'valid', 'simple.turk', i)
            for i in range(8)
        ]
        scores = get_all_scores(
            read_lines(complex_filepath), read_lines(pred_filepath),
            [read_lines(ref_filepath) for ref_filepath in ref_filepaths])
        print(f'num_updates={trainer.get_num_updates()}')
        print(f'ts_scores={scores}')
        sari = scores['SARI']
        if not hasattr(trainer, 'best_sari'):
            trainer.best_sari = 0
        if not hasattr(trainer, 'n_validations_since_best'):
            trainer.n_validations_since_best = 0
        if sari > trainer.best_sari:
            trainer.best_sari = sari
            trainer.n_validations_since_best = 0
        else:
            trainer.n_validations_since_best += 1
            print(
                f'SARI did not improve for {trainer.n_validations_since_best} validations'
            )
            # Does not work because the scheduler will set it back to the previous value every time
            # trainer.optimizer.set_lr(0.75 * trainer.optimizer.get_lr())
            if trainer.n_validations_since_best >= cfg.validations_before_sari_early_stopping:
                print(
                    f'Early stopping because SARI did not improve for {trainer.n_validations_since_best} validations'
                )
                trainer.early_stopping = True

            def is_abort(epoch_itr, best_sari):
                if (epoch_itr.epoch >= 2 and best_sari < 19):
                    return True
                if (epoch_itr.epoch >= 5 and best_sari < 22):
                    return True
                if (epoch_itr.epoch >= 10 and best_sari < 25):
                    return True
                return False

            # if is_abort(epoch_itr, best_sari):
            #     print(f'Early stopping because best SARI is too low ({best_sari:.2f}) after {epoch_itr.epoch} epochs.')
            #     # Remove the checkpoint directory as we got nothing interesting
            #     shutil.rmtree(args.save_dir)
            #     # TODO: Abort
    return [-sari]
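
The scoring call at the heart of this validation takes the original sentences, the system outputs, and a list of reference sets; a hedged sketch of that call with toy data, reading 'SARI' from the returned dict just as sari_validate does:

from easse.report import get_all_scores

orig_sents = ["the feline perched upon the mat ."]   # toy data
sys_sents = ["the cat sat on the mat ."]
refs_sents = [["the cat sat on the mat ."]]          # one reference set
scores = get_all_scores(orig_sents, sys_sents, refs_sents)
print(scores["SARI"])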