def test_run_request_with_v2_config():
    """Check run_request when the service is configured with a v2 config.

    A request-level ``config`` override must be rejected with ValueError.
    A plain request goes through the whitespace-splitting preprocessor, the
    dummy translator and the space-joining postprocessor. Neither processor
    should ever receive a per-request config.
    """

    class Preprocessor:
        def process_input(self, source, target=None, config=None, **kwargs):
            assert config is None
            return source.split(), None, None

    class Postprocessor:
        def process_input(self, source, target=None, config=None, **kwargs):
            assert config is None
            return " ".join(target[0])

    def translate(source_tokens, target_tokens, options=None):
        # Dummy model: the optional target prefix followed by the source
        # tokens in reverse order, wrapped as a single hypothesis per input.
        outputs = []
        for src, tgt in zip(source_tokens, target_tokens):
            prefix = [] if tgt is None else tgt
            outputs.append([_make_output(prefix + list(reversed(src)))])
        return outputs

    config = {
        "source": "en",
        "target": "fr",
        "preprocess": [
            {
                "op": "tokenization",
                "source": {"mode": "space"},
                "target": {"mode": "space"},
            },
        ],
    }

    def _run(request):
        # Every call shares the same pipeline and batching settings.
        return serving.run_request(
            request,
            translate,
            Preprocessor(),
            Postprocessor(),
            config=config,
            rebatch_request=False,
            max_batch_size=1,
        )

    overriding_request = {"src": [{"text": "a b c", "config": {"override": 42}}]}
    with pytest.raises(ValueError, match="override is not supported"):
        _run(overriding_request)

    result = _run({"src": [{"text": "a b c"}]})
    assert result == {"tgt": [[{"text": "c b a"}]]}
def test_run_request():
    """Check request validation, per-request config and option forwarding.

    Malformed requests must raise ``serving.InvalidRequest`` and an empty
    ``src`` list yields an empty ``tgt`` list. For valid requests, the
    separator comes from the service config, which can be overridden by
    request-level options and again by a per-example ``config``.
    """
    # Malformed requests are rejected before any translation happens.
    for bad_request in (["abc"], {"input": "abc"}, {"src": "abc"}):
        with pytest.raises(serving.InvalidRequest):
            serving.run_request(bad_request, None)
    assert serving.run_request({"src": []}, None) == {"tgt": []}

    class Preprocessor:
        def process_input(self, source, target=None, config=None, **kwargs):
            separator = config["separator"]
            source_tokens = source.split(separator)
            target_tokens = None if target is None else target.split(separator)
            return source_tokens, target_tokens, None

    class Postprocessor:
        def process_input(self, source, target=None, config=None, **kwargs):
            return config["separator"].join(target[0])

    def translate(source_tokens, target_tokens, options=None):
        assert options is not None
        # Request options are forwarded to the translation function.
        assert "config" in options
        assert "mode" in options
        assert options["max_batch_size"] == 1
        outputs = []
        for src, tgt in zip(source_tokens, target_tokens):
            prefix = [] if tgt is None else tgt
            outputs.append([_make_output(prefix + list(reversed(src)))])
        return outputs

    config = {"separator": "-"}
    first_example = {
        "text": "a b c",
        "target_prefix": "1 2",
        "mode": "alternatives",
    }
    second_example = {
        "text": "x_y_z",
        "config": {"separator": "_"},
    }
    request = {
        "src": [first_example, second_example],
        "options": {"config": {"separator": " "}},
    }

    result = serving.run_request(
        request,
        translate,
        Preprocessor(),
        Postprocessor(),
        config=config,
        rebatch_request=False,
        max_batch_size=1,
    )
    assert result == {"tgt": [[{"text": "1 2 c b a"}], [{"text": "z_y_x"}]]}
def test_serve(tmpdir):
    """Serve a stateless DummyFramework and translate two prefixed sentences.

    The ``tmpdir`` fixture is requested but not used by this test.
    """
    framework = DummyFramework(stateless=True)
    _, model_info = framework.serve(config_base, None)

    forward = functools.partial(framework.forward_request, model_info)
    preprocessor = framework._get_preprocessor(config_base, train=False)
    postprocessor = framework._get_postprocessor(config_base)

    request = {
        "src": [
            {"text": "Hello world!", "target_prefix": "Bonjour"},
            {"text": "How are you?", "target_prefix": "Comment"},
        ]
    }
    result = run_request(
        request,
        forward,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        config=config_base,
    )

    # Dummy translation does "target + reversed(source)".
    expected_texts = ["Bonjour! world Hello", "Comment? you are How"]
    for hypotheses, expected in zip(result["tgt"], expected_texts):
        assert hypotheses[0]["text"] == expected
def test_run_request():
    """Check run_request with plain preprocess/translate/postprocess callables.

    NOTE(review): this ``def`` reuses the name ``test_run_request``, which is
    already defined earlier in this file. The later definition rebinds the
    module-level name, so pytest only collects this version and the earlier
    test never runs. This version also calls
    ``serving.run_request(request, preprocess, translate, postprocess, ...)``,
    whereas the other tests in this file pass the translate function second
    and expect ``serving.InvalidRequest``. Confirm which API is current, then
    rename or drop the stale test.
    """
    for bad_request in (["abc"], {"input": "abc"}, {"src": "abc"}):
        with pytest.raises(ValueError):
            serving.run_request(bad_request, None, None, None)
    assert serving.run_request({"src": []}, None, None, None) == {"tgt": []}

    def preprocess(src, tgt, config):
        separator = config["separator"]
        src_tokens = src.split(separator)
        if tgt is None:
            return src_tokens, None
        assert config.get("target_type") is not None
        return src_tokens, tgt.split(separator)

    def translate(source_tokens, target_tokens, options=None):
        assert options is not None
        # Request options are forwarded to the translation function.
        assert "config" in options
        assert "mode" in options
        outputs = []
        for src, tgt in zip(source_tokens, target_tokens):
            prefix = [] if tgt is None else tgt
            outputs.append([_make_output(prefix + list(reversed(src)))])
        return outputs

    def postprocess(src, tgt, config):
        return config["separator"].join(tgt)

    config = {"separator": "-"}
    first_example = {
        "text": "a b c",
        "target_prefix": "1 2",
        "mode": "alternatives",
    }
    second_example = {
        "text": "x_y_z",
        "config": {"separator": "_"},
    }
    request = {
        "src": [first_example, second_example],
        "options": {"config": {"separator": " "}},
    }

    result = serving.run_request(
        request, preprocess, translate, postprocess, config=config
    )
    assert result == {"tgt": [[{"text": "1 2 c b a"}], [{"text": "z_y_x"}]]}
def test_serve_cloud_translation_framework():
    """Serve a CloudTranslationFramework subclass that reverses each text.

    The subclass checks that the en->fr language pair from the config reaches
    ``translate_batch``. The single-sentence request should come back with
    its characters reversed.
    """

    class _ReverseTranslationFramework(CloudTranslationFramework):
        def translate_batch(self, batch, source_lang, target_lang):
            assert source_lang == "en"
            assert target_lang == "fr"
            return ["".join(list(text)[::-1]) for text in batch]

    framework = _ReverseTranslationFramework()
    _, service_info = framework.serve({"source": "en", "target": "fr"}, None)

    forward = functools.partial(framework.forward_request, service_info)
    result = serving.run_request({"src": [{"text": "Hello"}]}, forward)
    assert result["tgt"][0][0]["text"] == "olleH"