def test_run_request_with_v2_config():
    """Exercise ``serving.run_request`` with a config listing ``preprocess`` operators.

    Two scenarios are covered:

    * a request whose sentence carries its own ``config`` entry is expected
      to raise ``ValueError`` matching "override is not supported";
    * a plain request is run through the stub pipeline below, whose
      processors both assert that they are called with ``config=None``.
    """

    # Stub preprocessor: whitespace-splits the source, returns no target.
    class Preprocessor:
        def process_input(self, source, target=None, config=None, **kwargs):
            assert config is None
            tokens = source.split()
            return tokens, None, None

    # Stub postprocessor: joins the first element of ``target`` with spaces.
    class Postprocessor:
        def process_input(self, source, target=None, config=None, **kwargs):
            assert config is None
            first = target[0]
            return " ".join(first)

    def translate(source_tokens, target_tokens, options=None):
        # Fake translation: optional target prefix followed by reversed source.
        outputs = []
        for source, target in zip(source_tokens, target_tokens):
            prefix = [] if target is None else target
            outputs.append([_make_output(prefix + list(reversed(source)))])
        return outputs

    tokenization = {
        "op": "tokenization",
        "source": {"mode": "space"},
        "target": {"mode": "space"},
    }
    config = {"source": "en", "target": "fr", "preprocess": [tokenization]}

    def run(request):
        # Every scenario uses the same pipeline and serving options.
        return serving.run_request(
            request,
            translate,
            Preprocessor(),
            Postprocessor(),
            config=config,
            rebatch_request=False,
            max_batch_size=1)

    overriding_request = {"src": [{"text": "a b c", "config": {"override": 42}}]}
    with pytest.raises(ValueError, match="override is not supported"):
        run(overriding_request)

    result = run({"src": [{"text": "a b c"}]})
    assert result == {'tgt': [[{'text': 'c b a'}]]}
# Example #2
# 0
def test_run_request():
    """Check request validation and config handling of ``serving.run_request``.

    Malformed requests must raise ``serving.InvalidRequest`` and an empty
    ``src`` list must yield an empty ``tgt`` list. The main scenario sends
    two sentences: the expected output shows the first one processed with
    the request-level separator (" ") and the second with its own
    per-sentence separator ("_").
    """
    malformed_requests = (["abc"], {"input": "abc"}, {"src": "abc"})
    for malformed in malformed_requests:
        with pytest.raises(serving.InvalidRequest):
            serving.run_request(malformed, None)

    empty_result = serving.run_request({"src": []}, None)
    assert empty_result == {"tgt": []}

    # Stub preprocessor: splits source (and target, if any) on the
    # separator found in the config it is given.
    class Preprocessor:
        def process_input(self, source, target=None, config=None, **kwargs):
            separator = config["separator"]
            source_tokens = source.split(separator)
            target_tokens = None if target is None else target.split(separator)
            return source_tokens, target_tokens, None

    # Stub postprocessor: joins the first element of ``target`` with the
    # configured separator.
    class Postprocessor:
        def process_input(self, source, target=None, config=None, **kwargs):
            separator = config["separator"]
            return separator.join(target[0])

    def translate(source_tokens, target_tokens, options=None):
        assert options is not None
        assert "config" in options  # Request options are forwarded.
        assert "mode" in options
        assert options["max_batch_size"] == 1
        # Fake translation: optional target prefix followed by reversed source.
        outputs = []
        for source, target in zip(source_tokens, target_tokens):
            prefix = [] if target is None else target
            outputs.append([_make_output(prefix + list(reversed(source)))])
        return outputs

    config = {"separator": "-"}
    prefixed_sentence = {
        "text": "a b c",
        "target_prefix": "1 2",
        "mode": "alternatives",
    }
    overriding_sentence = {
        "text": "x_y_z",
        "config": {"separator": "_"},
    }
    request = {
        "src": [prefixed_sentence, overriding_sentence],
        "options": {"config": {"separator": " "}},
    }

    result = serving.run_request(
        request,
        translate,
        Preprocessor(),
        Postprocessor(),
        config=config,
        rebatch_request=False,
        max_batch_size=1)
    assert result == {'tgt': [[{'text': '1 2 c b a'}], [{'text': 'z_y_x'}]]}
# Example #3
# 0
def test_serve(tmpdir):
    """Send a two-sentence request through a stateless ``DummyFramework``.

    The framework's ``forward_request`` is bound to the model info returned
    by ``serve`` and used as the translation function, together with the
    framework's own preprocessor and postprocessor for ``config_base``.
    The ``tmpdir`` argument is accepted but not used in this body.
    """
    framework = DummyFramework(stateless=True)
    _, model_info = framework.serve(config_base, None)

    # (source text, target prefix) pairs making up the request.
    sentences = [
        ("Hello world!", "Bonjour"),
        ("How are you?", "Comment"),
    ]
    request = {
        "src": [
            {"text": text, "target_prefix": prefix}
            for text, prefix in sentences
        ]
    }

    translate_fn = functools.partial(framework.forward_request, model_info)
    result = run_request(
        request,
        translate_fn,
        preprocessor=framework._get_preprocessor(config_base, train=False),
        postprocessor=framework._get_postprocessor(config_base),
        config=config_base)

    # Dummy translation does "target + reversed(source)".
    expected_texts = ["Bonjour! world Hello", "Comment? you are How"]
    for index, expected in enumerate(expected_texts):
        assert result["tgt"][index][0]["text"] == expected
# Example #4
# 0
def test_run_request():
    """Check ``serving.run_request`` when given plain pipeline functions.

    Here the pipeline is passed positionally as ``preprocess``,
    ``translate`` and ``postprocess`` callables. Malformed requests must
    raise ``ValueError`` and an empty ``src`` list must yield an empty
    ``tgt`` list. In the main scenario the expected output shows the first
    sentence handled with the request-level separator (" ") and the second
    with its own per-sentence separator ("_").
    """
    for malformed in (["abc"], {"input": "abc"}, {"src": "abc"}):
        with pytest.raises(ValueError):
            serving.run_request(malformed, None, None, None)

    empty_result = serving.run_request({"src": []}, None, None, None)
    assert empty_result == {"tgt": []}

    def preprocess(src, tgt, config):
        # Split on the configured separator; a target is only accepted
        # when the config also carries a "target_type" entry.
        separator = config["separator"]
        src_tokens = src.split(separator)
        if tgt is None:
            return src_tokens, None
        assert config.get("target_type") is not None
        return src_tokens, tgt.split(separator)

    def translate(source_tokens, target_tokens, options=None):
        assert options is not None
        assert "config" in options  # Request options are forwarded.
        assert "mode" in options
        # Fake translation: optional target prefix followed by reversed source.
        outputs = []
        for source, target in zip(source_tokens, target_tokens):
            prefix = [] if target is None else target
            outputs.append([_make_output(prefix + list(reversed(source)))])
        return outputs

    def postprocess(src, tgt, config):
        separator = config["separator"]
        return separator.join(tgt)

    config = {"separator": "-"}
    prefixed_sentence = {
        "text": "a b c",
        "target_prefix": "1 2",
        "mode": "alternatives",
    }
    overriding_sentence = {
        "text": "x_y_z",
        "config": {"separator": "_"},
    }
    request = {
        "src": [prefixed_sentence, overriding_sentence],
        "options": {"config": {"separator": " "}},
    }

    result = serving.run_request(
        request, preprocess, translate, postprocess, config=config)
    assert result == {'tgt': [[{'text': '1 2 c b a'}], [{'text': 'z_y_x'}]]}
def test_serve_cloud_translation_framework():
    """Serve a ``CloudTranslationFramework`` subclass through ``run_request``.

    The subclass replaces ``translate_batch`` with a stub that checks the
    language pair it receives and returns every text reversed, so the
    request "Hello" is expected to come back as "olleH".
    """

    # Stub framework whose "translation" is the character-reversed input.
    class _ReverseTranslationFramework(CloudTranslationFramework):
        def translate_batch(self, batch, source_lang, target_lang):
            assert source_lang == "en"
            assert target_lang == "fr"
            reversed_texts = []
            for text in batch:
                characters = list(text)
                reversed_texts.append("".join(characters[::-1]))
            return reversed_texts

    framework = _ReverseTranslationFramework()
    config = {"source": "en", "target": "fr"}
    _, service_info = framework.serve(config, None)

    translate_fn = functools.partial(framework.forward_request, service_info)
    request = {"src": [{"text": "Hello"}]}
    result = serving.run_request(request, translate_fn)

    first_hypothesis = result["tgt"][0][0]
    assert first_hypothesis["text"] == "olleH"