Пример #1
0
import os
import random
from tempfile import TemporaryDirectory
from typing import Optional, List, Tuple

import mxnet as mx
import numpy as np
import pytest

from sockeye import constants as C
from sockeye import data_io
from sockeye import vocab
from sockeye.utils import SockeyeError, get_tokens, seed_rngs
from sockeye.test_utils import tmp_digits_dataset

# Seed the RNGs with a fixed value at import time; presumably so any randomized
# test data generated in this module is reproducible across runs.
seed_rngs(12)

# Cases for test_define_buckets, each laid out as
# (max_seq_len, step, expected_buckets).
define_bucket_tests = [
    (50, 10, [10, 20, 30, 40, 50]),
    (50, 20, [20, 40, 50]),
    (50, 50, [50]),
    (5, 10, [5]),
    (11, 5, [5, 10, 11]),
    (19, 10, [10, 19]),
]


@pytest.mark.parametrize("max_seq_len, step, expected_buckets",
                         define_bucket_tests)
def test_define_buckets(max_seq_len, step, expected_buckets):
    """Check that ``data_io.define_buckets(max_seq_len, step=step)`` returns
    exactly ``expected_buckets`` for every case in ``define_bucket_tests``."""
    assert data_io.define_buckets(max_seq_len, step=step) == expected_buckets


define_parallel_bucket_tests = [(50, 50, 10, True, 1.0, [(10, 10), (20, 20),
                                                         (30, 30), (40, 40),
Пример #2
0
def _load_restrict_lexicons(args: argparse.Namespace, source_vocabs, target_vocabs):
    """Load the vocabulary-restriction lexicon(s) requested via ``args.restrict_lexicon``.

    :param args: Parsed arguments; reads ``restrict_lexicon`` (a sequence of
        ``(key, path)`` entries), ``restrict_lexicon_topk`` and ``json_input``.
    :param source_vocabs: Source vocabularies; only the first one is used.
    :param target_vocabs: Target vocabularies; only the first one is used.
    :return: ``None`` if no lexicon was requested; a single lexicon (applied to
        all inputs) if exactly one was given; otherwise a dict mapping each key
        to its loaded lexicon.
    """
    if args.restrict_lexicon is None:
        return None
    logger.info(str(args.restrict_lexicon))

    if len(args.restrict_lexicon) == 1:
        # Single lexicon used for all inputs.
        # Handle a single arg of key:path or path (parsed as path:path)
        entry = args.restrict_lexicon[0]
        lexicon = load_restrict_lexicon(entry[1],
                                        source_vocabs[0],
                                        target_vocabs[0],
                                        k=args.restrict_lexicon_topk)
        logger.info(
            f"Loaded a single lexicon ({entry[0]}) that will be applied to all inputs."
        )
        return lexicon

    # With several named lexicons, the per-input choice of lexicon has to come
    # from the input itself, which only the JSON input format can carry.
    check_condition(
        args.json_input,
        "JSON input is required when using multiple lexicons for vocabulary restriction"
    )
    # Multiple lexicons with specified names
    return {key: load_restrict_lexicon(path,
                                       source_vocabs[0],
                                       target_vocabs[0],
                                       k=args.restrict_lexicon_topk)
            for key, path in args.restrict_lexicon}


def _resolve_brevity_penalty(args: argparse.Namespace, models) -> tuple:
    """Work out the brevity-penalty settings to hand to the scorer and translator.

    :param args: Parsed arguments; reads ``brevity_penalty_type``,
        ``brevity_penalty_weight`` and ``brevity_penalty_constant_length_ratio``.
    :param models: Loaded models. Their ``length_ratio_mean`` values are averaged
        when a constant penalty is requested without an explicit ratio.
    :return: ``(brevity_penalty_weight, constant_length_ratio)``.
    :raises ValueError: If ``args.brevity_penalty_type`` is not recognised.
    """
    weight = args.brevity_penalty_weight
    penalty_type = args.brevity_penalty_type
    # NOTE(review): -1.0 below appears to be a "no constant length ratio" sentinel
    # for inference.Translator. Confirm against its implementation.
    if penalty_type == C.BREVITY_PENALTY_CONSTANT:
        if args.brevity_penalty_constant_length_ratio > 0.0:
            return weight, args.brevity_penalty_constant_length_ratio
        constant_length_ratio = sum(model.length_ratio_mean
                                    for model in models) / len(models)
        logger.info(
            "Using average of constant length ratios saved in the model configs: %f",
            constant_length_ratio)
        return weight, constant_length_ratio
    if penalty_type == C.BREVITY_PENALTY_LEARNED:
        return weight, -1.0
    if penalty_type == C.BREVITY_PENALTY_NONE:
        # Disabling the penalty is expressed as a zero weight.
        return 0.0, -1.0
    raise ValueError("Unknown brevity penalty type %s" % penalty_type)


def run_translate(args: argparse.Namespace):
    """Translate the input described by ``args`` with the requested model(s).

    The steps are:
    1. Set up logging.
    2. Choose the output handler and the device.
    3. Load the models and any vocabulary-restriction lexicons.
    4. Build the candidate scorer and ``inference.Translator``.
    5. Pass the input through ``read_and_translate``.

    :param args: Parsed command-line arguments of the translate CLI. Note that
        ``args.output_type`` may be overwritten to JSON for n-best output.
    :raises ValueError: If ``args.brevity_penalty_type`` is unknown.
    """
    # Seed randomly unless a seed has been passed
    seed_rngs(args.seed if args.seed is not None else int(time.time()))

    if args.output is not None:
        setup_main_logger(console=not args.quiet,
                          file_logging=not args.no_logfile,
                          path="%s.%s" % (args.output, C.LOG_NAME),
                          level=args.loglevel)
    else:
        setup_main_logger(file_logging=False, level=args.loglevel)

    log_basic_info(args)

    # n-best output requires the JSON output handler; override any other choice.
    if args.nbest_size > 1 and args.output_type != C.OUTPUT_HANDLER_JSON:
        logger.warning(
            "For nbest translation, you must specify `--output-type '%s'`; overriding your setting of '%s'.",
            C.OUTPUT_HANDLER_JSON, args.output_type)
        args.output_type = C.OUTPUT_HANDLER_JSON
    output_handler = get_output_handler(args.output_type, args.output)

    # Fall back to CPU when CUDA is unavailable, regardless of --use-cpu.
    use_cpu = args.use_cpu
    if not pt.cuda.is_available():
        logger.info("CUDA not available, using cpu")
        use_cpu = True
    device = pt.device('cpu') if use_cpu else pt.device('cuda', args.device_id)
    logger.info(f"Translate Device: {device}")
    models, source_vocabs, target_vocabs = load_models(
        device=device,
        model_folders=args.models,
        checkpoints=args.checkpoints,
        dtype=args.dtype,
        inference_only=True)

    restrict_lexicon = _load_restrict_lexicons(args, source_vocabs, target_vocabs)  # type: Optional[Union[RestrictLexicon, Dict[str, RestrictLexicon]]]

    brevity_penalty_weight, constant_length_ratio = _resolve_brevity_penalty(args, models)

    for model in models:
        model.eval()

    scorer = inference.CandidateScorer(
        length_penalty_alpha=args.length_penalty_alpha,
        length_penalty_beta=args.length_penalty_beta,
        brevity_penalty_weight=brevity_penalty_weight)
    # Match the scorer's dtype to the models' dtype (taken from the first model).
    scorer.to(models[0].dtype)

    translator = inference.Translator(
        device=device,
        ensemble_mode=args.ensemble_mode,
        scorer=scorer,
        batch_size=args.batch_size,
        beam_size=args.beam_size,
        beam_search_stop=args.beam_search_stop,
        nbest_size=args.nbest_size,
        models=models,
        source_vocabs=source_vocabs,
        target_vocabs=target_vocabs,
        restrict_lexicon=restrict_lexicon,
        strip_unknown_words=args.strip_unknown_words,
        sample=args.sample,
        output_scores=output_handler.reports_score(),
        constant_length_ratio=constant_length_ratio,
        max_output_length_num_stds=args.max_output_length_num_stds,
        max_input_length=args.max_input_length,
        max_output_length=args.max_output_length,
        prevent_unk=args.prevent_unk,
        greedy=args.greedy)

    read_and_translate(translator=translator,
                       output_handler=output_handler,
                       chunk_size=args.chunk_size,
                       input_file=args.input,
                       input_factors=args.input_factors,
                       input_is_json=args.json_input)
Пример #3
0
import os
import random
from tempfile import TemporaryDirectory
from typing import Optional, List, Tuple

import mxnet as mx
import numpy as np
import pytest

from sockeye import constants as C
from sockeye import data_io
from sockeye import vocab
from sockeye.utils import SockeyeError, get_tokens, seed_rngs
from test.common import tmp_digits_dataset

# Fix the RNG seed when the module is imported; presumably so any randomized
# test data produced here is identical from run to run.
seed_rngs(12)

# (max_seq_len, step, expected_buckets) cases consumed by test_define_buckets.
define_bucket_tests = [
    (50, 10, [10, 20, 30, 40, 50]),  # max_seq_len is a multiple of step
    (50, 20, [20, 40, 50]),  # final bucket is max_seq_len, not a step multiple
    (50, 50, [50]),  # step equals max_seq_len
    (5, 10, [5]),  # step larger than max_seq_len
    (11, 5, [5, 10, 11]),  # final bucket is max_seq_len, not a step multiple
    (19, 10, [10, 19]),  # final bucket is max_seq_len, not a step multiple
]


@pytest.mark.parametrize("max_seq_len, step, expected_buckets", define_bucket_tests)
def test_define_buckets(max_seq_len, step, expected_buckets):
    """``data_io.define_buckets`` must reproduce each expected bucket list."""
    actual = data_io.define_buckets(max_seq_len, step=step)
    assert actual == expected_buckets