Ejemplo n.º 1
0
    def run_serve_sanity_check(self):
        """Post the deserializer's example input to an in-process ``/predict`` app and print the reply.

        The composition is built with ``TESTING=True`` / ``DEBUG=True`` and queried
        through FastAPI's ``TestClient``. The response is printed, not validated.

        Raises:
            NotImplementedError: if ``self.is_servable`` is false.
        """
        if not self.is_servable:
            raise NotImplementedError(
                "This Task is not servable. Attach a Deserializer to enable serving."
            )

        # Imported only once serving is actually requested (presumably so the
        # serving extras stay optional — confirm).
        from fastapi.testclient import TestClient

        from flash.core.serve.flash_components import build_flash_serve_model_component

        print("Running serve sanity check")
        composition = Composition(
            predict=build_flash_serve_model_component(self),
            TESTING=True,
            DEBUG=True,
        )
        app = composition.serve(host="0.0.0.0", port=8000)

        with TestClient(app) as client:
            example = self.data_pipeline._deserializer.example_input
            request_body = {"session": "UUID", "payload": {"inputs": {"data": example}}}
            reply = client.post("http://0.0.0.0:8000/predict", json=request_body)
            print(f"Sanity check response: {reply.json()}")
Ejemplo n.º 2
0
def test_start_server_with_repeated_exposed(session_global_datadir,
                                            lightning_squeezenet1_1_obj):
    """Serve ``ClassificationInferenceRepeated`` and post a list-valued ``img`` payload to ``/classify``."""
    from tests.core.serve.models import ClassificationInferenceRepeated

    composit = Composition(
        comp=ClassificationInferenceRepeated(lightning_squeezenet1_1_obj),
        TESTING=True,
        DEBUG=True,
    )
    app = composit.serve(host="0.0.0.0", port=8000)
    base_url = "http://127.0.0.1:8000"

    with TestClient(app) as client:
        assert client.get(f"{base_url}/classify/meta").status_code == 200

        fish_path = session_global_datadir / "fish.jpg"
        with fish_path.open("rb") as fh:
            encoded = base64.b64encode(fh.read()).decode("UTF-8")
        request = {"session": "UUID", "payload": {"img": [{"data": encoded}]}}

        reply = client.post(f"{base_url}/classify", json=request)
        assert "result" in reply.json()

        goldfish = "goldfish, Carassius auratus"
        assert reply.json() == {
            "session": "UUID",
            "result": {"prediction": [goldfish, goldfish], "other": 21},
        }
Ejemplo n.º 3
0
def test_resnet_18_inference_class(session_global_datadir,
                                   lightning_squeezenet1_1_obj):
    """Serve one ``ClassificationInference`` component and exercise its built-in routes.

    Despite the test name, the fixture model is squeezenet1_1.
    """
    from tests.core.serve.models import ClassificationInference

    composit = Composition(
        comp=ClassificationInference(lightning_squeezenet1_1_obj),
        TESTING=True,
        DEBUG=True,
    )
    app = composit.serve(host="0.0.0.0", port=8000)
    base_url = "http://127.0.0.1:8000"

    with TestClient(app) as client:
        alive = client.get(f"{base_url}/flashserve/alive")
        assert alive.status_code == 200
        assert alive.json() == {"alive": True}

        dag = client.get(f"{base_url}/classify/dag_json")
        assert isinstance(dag.json(), dict)

        assert client.get(f"{base_url}/classify/meta").status_code == 200

        with (session_global_datadir / "fish.jpg").open("rb") as fh:
            encoded = base64.b64encode(fh.read()).decode("UTF-8")
        request = {"session": "UUID", "payload": {"img": {"data": encoded}}}

        reply = client.post(f"{base_url}/classify", json=request)
        assert "result" in reply.json()
        assert {
            "session": "UUID",
            "result": {"prediction": "goldfish, Carassius auratus"},
        } == reply.json()
Ejemplo n.º 4
0
def test_composition_recieve_wrong_arg_type(lightning_squeezenet1_1_obj):
    """Composition rejects bad keyword arguments.

    A plain value that is neither a component nor an endpoint raises ``TypeError``.
    Several components with no endpoint raise ``ValueError``.
    """
    # Neither an endpoint nor a component.
    with pytest.raises(TypeError):
        Composition(hello="world")

    from tests.core.serve.models import ClassificationInferenceComposable

    first, second = (
        ClassificationInferenceComposable(lightning_squeezenet1_1_obj) for _idx in range(2)
    )

    # Multiple components but no endpoint.
    with pytest.raises(ValueError):
        Composition(c1=first, c2=second)
Ejemplo n.º 5
0
def test_composit_endpoint_data(lightning_squeezenet1_1_obj):
    """Check endpoint metadata for both an auto-generated endpoint and an explicit ``Endpoint``."""
    from tests.core.serve.models import ClassificationInferenceComposable

    def endpoints_as_dicts(composition):
        # Endpoints are dataclasses; compare them through their dict form.
        return {name: asdict(endpoint) for name, endpoint in composition.endpoints.items()}

    comp = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)

    # No Endpoint given: a ``classify_ENDPOINT`` entry routed at /classify is generated.
    implicit = Composition(comp=comp)
    assert implicit.component_uid_names == {"callnum_1": "comp"}
    assert implicit.connections == []
    assert endpoints_as_dicts(implicit) == {
        "classify_ENDPOINT": {
            "inputs": {
                "img": "callnum_1.inputs.img",
                "tag": "callnum_1.inputs.tag",
            },
            "outputs": {
                "cropped_img": "callnum_1.outputs.cropped_img",
                "predicted_tag": "callnum_1.outputs.predicted_tag",
            },
            "route": "/classify",
        }
    }

    # An explicit Endpoint renames the exposed keys and route.
    predict_endpoint = Endpoint(
        route="/predict",
        inputs={"label_1": comp.inputs.img, "tag_1": comp.inputs.tag},
        outputs={"prediction": comp.outputs.predicted_tag, "cropped": comp.outputs.cropped_img},
    )
    explicit = Composition(comp=comp, predict_ep=predict_endpoint)
    assert endpoints_as_dicts(explicit) == {
        "predict_ep": {
            "inputs": {
                "label_1": "callnum_1.inputs.img",
                "tag_1": "callnum_1.inputs.tag",
            },
            "outputs": {
                "cropped": "callnum_1.outputs.cropped_img",
                "prediction": "callnum_1.outputs.predicted_tag",
            },
            "route": "/predict",
        }
    }
Ejemplo n.º 6
0
    def serve(self,
              host: str = "127.0.0.1",
              port: int = 8000,
              sanity_check: bool = True) -> "Composition":
        """Serve this task through a Composition whose single component is named ``predict``.

        Args:
            host: address forwarded to ``Composition.serve``.
            port: port forwarded to ``Composition.serve``.
            sanity_check: when true, ``run_serve_sanity_check`` runs before serving.

        Returns:
            The ``Composition`` that was served.

        Raises:
            NotImplementedError: if ``self.is_servable`` is false.
        """
        if not self.is_servable:
            raise NotImplementedError(
                "This Task is not servable. Attach a Deserializer to enable serving."
            )

        from flash.core.serve.flash_components import build_flash_serve_model_component

        if sanity_check:
            self.run_serve_sanity_check()

        served = Composition(
            predict=build_flash_serve_model_component(self),
            TESTING=flash._IS_TESTING,
        )
        served.serve(host=host, port=port)
        return served
Ejemplo n.º 7
0
def test_servable_mapping(tmp_path, lightning_squeezenet1_1_obj, squeezenet_servable):
    """A dict of servables is kept in the component's meta and reachable as ``model1`` / ``model2``."""
    from tests.core.serve.models import ClassificationInferenceModelMapping

    servable, _unused = squeezenet_servable
    model_map = {"model_one": servable, "model_two": servable}

    composit = Composition(comp=ClassificationInferenceModelMapping(model_map))
    component = composit.components["callnum_1"]

    assert component._flashserve_meta_.models == model_map
    assert component.model1 == model_map["model_one"]
    assert component.model2 == model_map["model_two"]
Ejemplo n.º 8
0
def test_servable_sequence(tmp_path, lightning_squeezenet1_1_obj, squeezenet_servable):
    """A list of servables is kept in the component's meta and reachable as ``model1`` / ``model2``."""
    from tests.core.serve.models import ClassificationInferenceModelSequence

    servable, _unused = squeezenet_servable
    model_seq = [servable, servable]

    composit = Composition(comp=ClassificationInferenceModelSequence(model_seq))
    component = composit.components["callnum_1"]

    assert component._flashserve_meta_.models == model_seq
    first_model, second_model = model_seq
    assert component.model1 == first_model
    assert component.model2 == second_model
Ejemplo n.º 9
0
def test_complex_spec_single_endpoint(tmp_path, lightning_squeezenet1_1_obj):
    """Wire three components into a DAG and check connections, uid names and endpoint metadata."""
    from tests.core.serve.models import ClassificationInferenceComposable

    comp1, comp2, comp3 = (
        ClassificationInferenceComposable(lightning_squeezenet1_1_obj) for _idx in range(3)
    )

    # ``>>`` registers a connection; the expression result is intentionally discarded.
    comp1.outputs.predicted_tag >> comp3.inputs.tag  # skipcq: PYL-W0104
    comp2.outputs.cropped_img >> comp3.inputs.img  # skipcq: PYL-W0104
    comp1.outputs.predicted_tag >> comp2.inputs.tag  # skipcq: PYL-W0104

    predict_endpoint = Endpoint(
        route="/predict",
        inputs={
            "img_1": comp1.inputs.img,
            "img_2": comp2.inputs.img,
            "tag_1": comp1.inputs.tag,
        },
        outputs={"prediction": comp3.outputs.predicted_tag},
    )

    composit = Composition(comp1=comp1,
                           comp2=comp2,
                           comp3=comp3,
                           predict_compositon_ep=predict_endpoint)

    assert [str(conn) for conn in composit.connections] == [
        "callnum_1.outputs.predicted_tag >> callnum_3.inputs.tag",
        "callnum_1.outputs.predicted_tag >> callnum_2.inputs.tag",
        "callnum_2.outputs.cropped_img >> callnum_3.inputs.img",
    ]
    assert composit.component_uid_names == {f"callnum_{i}": f"comp{i}" for i in (1, 2, 3)}

    endpoint_dicts = {name: asdict(ep) for name, ep in composit.endpoints.items()}
    assert endpoint_dicts == {
        "predict_compositon_ep": {
            "inputs": {
                "img_1": "callnum_1.inputs.img",
                "img_2": "callnum_2.inputs.img",
                "tag_1": "callnum_1.inputs.tag",
            },
            "outputs": {"prediction": "callnum_3.outputs.predicted_tag"},
            "route": "/predict",
        }
    }
Ejemplo n.º 10
0
    def serve(self, host: str = "127.0.0.1", port: int = 8000) -> 'Composition':
        """Wrap this task in a ``ModelComponent`` and serve it under a ``predict`` Composition.

        Args:
            host: address forwarded to ``Composition.serve``.
            port: port forwarded to ``Composition.serve``.

        Returns:
            The ``Composition`` that was served.
        """
        from flash.core.serve.flash_components import FlashInputs, FlashOutputs

        class FlashServeModelComponent(ModelComponent):
            """Serving adapter running the task's predict-stage pipeline around ``predict_step``."""

            def __init__(self, model):
                # ``self`` is the component here; ``model`` is the outer task (see the
                # ``FlashServeModelComponent(self)`` call below).
                self.model = model
                self.model.eval()
                self.data_pipeline = self.model.build_data_pipeline()
                # Predict-stage processors. ``is_serving=True`` is passed to the worker
                # preprocessor and the postprocessor only.
                self.worker_preprocessor = self.data_pipeline.worker_preprocessor(
                    RunningStage.PREDICTING, is_serving=True
                )
                self.device_preprocessor = self.data_pipeline.device_preprocessor(RunningStage.PREDICTING)
                self.postprocessor = self.data_pipeline.postprocessor(RunningStage.PREDICTING, is_serving=True)
                # todo (tchaton) Remove this hack
                # True when the bound ``transfer_batch_to_device`` takes three parameters;
                # ``predict`` then passes an extra ``0`` (presumably a dataloader index —
                # TODO confirm against the Lightning version in use).
                self.extra_arguments = len(inspect.signature(self.model.transfer_batch_to_device).parameters) == 3
                self.device = self.model.device

            # NOTE: this decorator is evaluated while the class body executes inside
            # ``serve``, so ``self`` below is the outer task (``serve``'s receiver), not a
            # component instance. Its data pipeline's deserialize/serialize processors
            # back the ``inputs`` / ``outputs`` types.
            @expose(
                inputs={"inputs": FlashInputs(self.data_pipeline.deserialize_processor())},
                outputs={"outputs": FlashOutputs(self.data_pipeline.serialize_processor())},
            )
            def predict(self, inputs):
                # Inference only: autograd tracking is disabled for the whole chain.
                with torch.no_grad():
                    # worker preprocess -> move to device -> device preprocess
                    # -> predict_step -> postprocess
                    inputs = self.worker_preprocessor(inputs)
                    if self.extra_arguments:
                        inputs = self.model.transfer_batch_to_device(inputs, self.device, 0)
                    else:
                        inputs = self.model.transfer_batch_to_device(inputs, self.device)
                    inputs = self.device_preprocessor(inputs)
                    preds = self.model.predict_step(inputs, 0)
                    preds = self.postprocessor(preds)
                    return preds

        comp = FlashServeModelComponent(self)
        composition = Composition(predict=comp)
        composition.serve(host=host, port=port)
        return composition
Ejemplo n.º 11
0
def test_endpoint_overwrite_connection_dag(session_global_datadir,
                                           lightning_squeezenet1_1_obj):
    """Three endpoints over one connected pair of components.

    Checks that every DAG page renders and that each endpoint returns its expected
    result. The third endpoint feeds ``stadium`` directly instead of through the
    connection.
    """
    from tests.core.serve.models import ClassificationInference, SeatClassifier

    resnet_comp = ClassificationInference(lightning_squeezenet1_1_obj)
    seat_comp = SeatClassifier(lightning_squeezenet1_1_obj,
                               config={"sport": "football"})

    # Registers the connection; the expression value is discarded on purpose.
    resnet_comp.outputs.prediction >> seat_comp.inputs.stadium  # skipcq: PYL-W0104

    ep = Endpoint(
        route="/predict_seat",
        inputs={
            "image": resnet_comp.inputs.img,
            "isle": seat_comp.inputs.isle,
            "section": seat_comp.inputs.section,
            "row": seat_comp.inputs.row,
        },
        outputs={
            "seat_number": seat_comp.outputs.seat_number,
            "team": seat_comp.outputs.team,
        },
    )
    ep2 = Endpoint(
        route="/predict_seat_img",
        inputs={
            "image": resnet_comp.inputs.img,
            "isle": seat_comp.inputs.isle,
            "section": seat_comp.inputs.section,
            "row": seat_comp.inputs.row,
        },
        outputs={
            "seat_number": seat_comp.outputs.seat_number,
            "team": seat_comp.outputs.team,
            "img_out": resnet_comp.outputs.prediction,
        },
    )
    ep3 = Endpoint(
        route="/predict_seat_img_two",
        inputs={
            "stadium": seat_comp.inputs.stadium,
            "isle": seat_comp.inputs.isle,
            "section": seat_comp.inputs.section,
            "row": seat_comp.inputs.row,
        },
        outputs={
            "seat_number": seat_comp.outputs.seat_number,
            "team": seat_comp.outputs.team,
        },
    )

    composit = Composition(
        resnet_comp=resnet_comp,
        seat_comp=seat_comp,
        seat_prediction_ep=ep,
        seat_image_prediction_ep=ep2,
        seat_image_prediction_two_ep=ep3,
        TESTING=True,
        DEBUG=True,
    )
    app = composit.serve(host="0.0.0.0", port=8000)
    base_url = "http://127.0.0.1:8000"
    team = "buffalo bills, the ralph"

    with TestClient(app) as client:
        dag_pages = (
            "flashserve/component_dags",
            "predict_seat/dag",
            "predict_seat_img/dag",
            "predict_seat_img_two/dag",
        )
        for page in dag_pages:
            page_resp = client.get(f"{base_url}/{page}")
            assert page_resp.headers["content-type"] == "text/html; charset=utf-8"
            assert page_resp.template.name == "dag.html"

        with (session_global_datadir / "cat.jpg").open("rb") as fh:
            encoded = base64.b64encode(fh.read()).decode("UTF-8")

        # Seat location fields shared by every request below.
        seat_fields = {
            "section": {"num": 10},
            "isle": {"num": 4},
            "row": {"num": 53},
        }
        image_request = {
            "session": "UUID",
            "payload": {"image": {"data": encoded}, **seat_fields},
        }

        reply = client.post(f"{base_url}/predict_seat", json=image_request)
        assert reply.json() == {
            "result": {"seat_number": 4799680, "team": team},
            "session": "UUID",
        }

        reply = client.post(f"{base_url}/predict_seat_img", json=image_request)
        assert reply.json() == {
            "result": {"seat_number": 4799680, "team": team, "img_out": "Persian cat"},
            "session": "UUID",
        }

        stadium_request = {
            "session": "UUID",
            "payload": {"stadium": {"label": team}, **seat_fields},
        }
        reply = client.post(f"{base_url}/predict_seat_img_two", json=stadium_request)
        assert reply.json() == {
            "result": {"seat_number": 16960000, "team": team},
            "session": "UUID",
        }
Ejemplo n.º 12
0
    "NOX",
    "RM",
    "AGE",
    "DIS",
    "RAD",
    "TAX",
    "PTRATIO",
    "B",
    "LSTAT",
]


class PricePrediction(ModelComponent):
    """Expose a model mapping a ``Table`` with ``feature_names`` columns to a ``Number``."""

    def __init__(self, model):  # skipcq: PYL-W0621
        # The parameter deliberately shares its name with the module-level ``model``.
        self.model = model

    @expose(
        inputs={"table": Table(column_names=feature_names)},
        outputs={"pred": Number()},
    )
    def predict(self, table):
        """Call the wrapped model on ``table`` and return its output as-is."""
        prediction = self.model(table)
        return prediction


# Fit ordinary least squares on the Boston housing data, convert it to torch with
# hummingbird, and serve it through PricePrediction.
# NOTE(review): ``load_boston`` was removed in scikit-learn 1.2, so this example
# needs an older scikit-learn release.
data = sklearn.datasets.load_boston()
model = sklearn.linear_model.LinearRegression().fit(data.data, data.target)

# ``test_input`` hands hummingbird a single sample row; ``.model`` takes the
# converted torch model out of hummingbird's result (presumably — confirm).
model = hummingbird.ml.convert(model, "torch", test_input=data.data[:1]).model
comp = PricePrediction(model)
composit = Composition(comp=comp)
composit.serve()
Ejemplo n.º 13
0
def test_endpoint_errors_on_wrong_key_name(lightning_squeezenet1_1_obj):
    """Unknown component keys or component names in an endpoint raise ``AttributeError``.

    A typo in attribute access on a component fails as soon as the attribute is read.
    A typo inside a string reference only fails when a Composition resolves it.
    """
    from tests.core.serve.models import ClassificationInferenceComposable

    comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)

    # Unknown input key, accessed as an attribute.
    with pytest.raises(AttributeError):
        Endpoint(
            route="/predict",
            inputs={"label_1": comp1.inputs.img, "tag_1": comp1.inputs.DOESNOTEXIST},
            outputs={"prediction": comp1.outputs.predicted_tag, "cropped": comp1.outputs.cropped_img},
        )

    # Unknown output key, accessed as an attribute.
    with pytest.raises(AttributeError):
        Endpoint(
            route="/predict",
            inputs={"label_1": comp1.inputs.img, "tag_1": comp1.inputs.tag},
            outputs={"prediction": comp1.outputs.predicted_tag, "cropped": comp1.outputs.DOESNOTEXIST},
        )

    def assert_composition_rejects(inputs, outputs):
        # The Endpoint itself builds fine; Composition must reject the bad string reference.
        endpoint = Endpoint(route="/predict", inputs=inputs, outputs=outputs)
        with pytest.raises(AttributeError):
            Composition(comp1=comp1, predict_ep=endpoint)

    # Unknown output key inside a string reference.
    assert_composition_rejects(
        inputs={"label_1": comp1.inputs.img, "tag_1": comp1.inputs.tag},
        outputs={"prediction": comp1.outputs.predicted_tag, "cropped": "callnum_1.outputs.DOESNOTEXIST"},
    )
    # Unknown component name in an input string reference.
    assert_composition_rejects(
        inputs={"label_1": comp1.inputs.img, "tag_1": "DOESNOTEXIST.inputs.tag"},
        outputs={"prediction": comp1.outputs.predicted_tag, "cropped": comp1.outputs.cropped_img},
    )
    # Unknown component name in an output string reference.
    assert_composition_rejects(
        inputs={"label_1": comp1.inputs.img, "tag_1": comp1.inputs.tag},
        outputs={"prediction": comp1.outputs.predicted_tag, "cropped": "DOESNOTEXIST.outputs.cropped_img"},
    )
Ejemplo n.º 14
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torchvision

from flash.core.serve import Composition, expose, ModelComponent
from flash.core.serve.types import BBox, Image, Label, Repeated


class ObjectDetection(ModelComponent):
    """Serve a detection model: one image in, bounding boxes and class labels out."""

    def __init__(self, model):
        self.model = model

    @expose(
        inputs={"img": Image()},
        outputs={
            "boxes": Repeated(BBox()),
            # presumably a file of class names indexed by label id — confirm
            "labels": Repeated(Label("classes.txt"))
        },
    )
    def detect(self, img):
        """Run the model and return boxes and labels for the first image of the batch.

        Assumes ``Image()`` yields a channels-last ``(N, H, W, C)`` tensor with
        0-255 values — TODO confirm against the ``Image`` deserializer.
        """
        # Convert to channels-first (N, C, H, W) in [0, 1]. The order must be
        # (0, 3, 1, 2), not (0, 3, 2, 1), so height stays before width. Otherwise
        # the image is transposed and box x/y coordinates come back swapped.
        img = img.permute(0, 3, 1, 2).float() / 255
        # Only the first image's detections are returned.
        out = self.model(img)[0]
        return out["boxes"], out["labels"]


# Load the pretrained Faster R-CNN, put it in inference mode, and serve it.
fasterrcnn = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
fasterrcnn.eval()
composit = Composition(component=ObjectDetection(fasterrcnn))
composit.serve()
Ejemplo n.º 15
0
def test_composition_from_url_torchscript_servable(tmp_path):
    """Serve two chained components backed by a TorchScript file downloaded from a URL.

    The accompanying note describes the downloaded module as::

        # Tensor x Tensor
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()

            def forward(self, a, b):
                result_0 = a / b
                result_1 = torch.div(a, b)
                result_2 = a.div(b)

                return result_0, result_1, result_2

    The TorchScript (.pt) file is downloaded from ``TORCHSCRIPT_DOWNLOAD_URL``.
    """
    from flash.core.serve import expose, ModelComponent, Servable
    from flash.core.serve.types import Number

    TORCHSCRIPT_DOWNLOAD_URL = "https://github.com/pytorch/pytorch/raw/95489b590f00801bdee7f41783f30874883cf6bb/test/jit/fixtures/test_versioned_div_tensor_inplace_v3.pt"  # noqa E501

    class ComponentTwoModels(ModelComponent):
        def __init__(self, model):
            self.encoder = model["encoder"]
            self.decoder = model["decoder"]

        @expose(inputs={"inp": Number()}, outputs={"output": Number()})
        def do_my_predict(self, inp):
            """My predict docstring."""
            return self.decoder(self.encoder(inp, inp), inp)

    servable = Servable(TORCHSCRIPT_DOWNLOAD_URL,
                        download_path=tmp_path / "tmp_download.pt")

    c_1 = ComponentTwoModels({"encoder": servable, "decoder": servable})
    c_2 = ComponentTwoModels({"encoder": servable, "decoder": servable})

    # Registers the connection; the expression value is discarded on purpose.
    c_1.outputs.output >> c_2.inputs.inp  # skipcq: PYL-W0104

    endpoint = Endpoint(
        route="/predictr",
        inputs={"ep_in": c_1.inputs.inp},
        outputs={"ep_out": c_1.outputs.output},
    )

    composit = Composition(c_1=c_1,
                           c_2=c_2,
                           endpoints=endpoint,
                           TESTING=True,
                           DEBUG=True)
    app = composit.serve(host="0.0.0.0", port=8000)

    request = {"session": "UUID", "payload": {"ep_in": {"num": 10}}}
    with TestClient(app) as client:
        reply = client.post("http://127.0.0.1:8000/predictr", json=request)
        assert reply.json() == {"result": {"ep_out": 1.0}, "session": "UUID"}
Ejemplo n.º 16
0
def test_serving_single_component_and_endpoint_no_composition(
        session_global_datadir, lightning_squeezenet1_1_obj):
    """A custom Endpoint on one component sets the served route, request schema and DAG.

    The component's ``/classify`` route is not served once the endpoint is given.
    """
    from tests.core.serve.models import ClassificationInference

    comp = ClassificationInference(lightning_squeezenet1_1_obj)
    assert hasattr(comp.inputs, "img")
    assert hasattr(comp.outputs, "prediction")
    assert list(comp._flashserve_meta_.connections) == []

    ep = Endpoint(
        route="/different_route",
        inputs={"ep_in_image": comp.inputs.img},
        outputs={"ep_out_prediction": comp.outputs.prediction},
    )
    assert ep.route == "/different_route"

    composit = Composition(comp=comp, ep=ep, TESTING=True, DEBUG=True)
    app = composit.serve(host="0.0.0.0", port=8000)
    base_url = "http://127.0.0.1:8000"

    expected_meta = {
        "definitions": {
            "Ep_Ep_In_Image": {
                "properties": {"data": {"title": "Data", "type": "string"}},
                "required": ["data"],
                "title": "Ep_Ep_In_Image",
                "type": "object",
            },
            "Ep_Payload": {
                "properties": {"ep_in_image": {"$ref": "#/definitions/Ep_Ep_In_Image"}},
                "required": ["ep_in_image"],
                "title": "Ep_Payload",
                "type": "object",
            },
        },
        "properties": {
            "payload": {"$ref": "#/definitions/Ep_Payload"},
            "session": {"title": "Session", "type": "string"},
        },
        "required": ["payload"],
        "title": "Ep_RequestModel",
        "type": "object",
    }
    expected_dag = {
        "component_dependencies": {
            "callnum_1": {
                "callnum_1.funcout": ["callnum_1.inputs.img"],
                "callnum_1.inputs.img": [],
                "callnum_1.outputs.prediction": ["callnum_1.funcout"],
                "callnum_1.outputs.prediction.serial": ["callnum_1.outputs.prediction"],
            }
        },
        "component_dependents": {
            "callnum_1": {
                "callnum_1.funcout": ["callnum_1.outputs.prediction"],
                "callnum_1.inputs.img": ["callnum_1.funcout"],
                "callnum_1.outputs.prediction": ["callnum_1.outputs.prediction.serial"],
                "callnum_1.outputs.prediction.serial": [],
            }
        },
        "component_funcnames": {
            "callnum_1": {
                "callnum_1.funcout": ["Compose"],
                "callnum_1.inputs.img": ["packed_deserialize"],
                "callnum_1.outputs.prediction": ["get"],
                "callnum_1.outputs.prediction.serial": ["serialize"],
            }
        },
        "connections": [],
    }

    with TestClient(app) as client:
        assert client.get(f"{base_url}/different_route/meta").json() == expected_meta

        with (session_global_datadir / "fish.jpg").open("rb") as fh:
            encoded = base64.b64encode(fh.read()).decode("UTF-8")
        request = {"session": "UUID", "payload": {"ep_in_image": {"data": encoded}}}

        success = client.post(f"{base_url}/different_route", json=request)
        # Only the endpoint's route is served; these return 404.
        for unserved_route in ("/classify", "/my_test_component"):
            assert client.post(f"{base_url}{unserved_route}", json=request).status_code == 404

        assert "result" in success.json()
        assert {
            "session": "UUID",
            "result": {"ep_out_prediction": "goldfish, Carassius auratus"},
        } == success.json()

        dag = client.get(f"{base_url}/flashserve/dag_json")
        assert dag.status_code == 200
        assert dag.json() == expected_dag
Ejemplo n.º 17
0
def test_start_server_from_composition(tmp_path, squeezenet_servable, session_global_datadir):
    """Serve a three-component DAG through two endpoints and query ``/predict``."""
    from tests.core.serve.models import ClassificationInferenceComposable

    squeezenet_gm, _unused = squeezenet_servable
    comp1, comp2, comp3 = (
        ClassificationInferenceComposable(squeezenet_gm) for _idx in range(3)
    )

    # ``>>`` registers a connection; the expression result is intentionally discarded.
    comp1.outputs.predicted_tag >> comp3.inputs.tag  # skipcq: PYL-W0104
    comp2.outputs.cropped_img >> comp3.inputs.img  # skipcq: PYL-W0104
    comp1.outputs.predicted_tag >> comp2.inputs.tag  # skipcq: PYL-W0104

    ep1 = Endpoint(
        route="/predict",
        inputs={
            "img_1": comp1.inputs.img,
            "img_2": comp2.inputs.img,
            "tag_1": comp1.inputs.tag,
        },
        outputs={"prediction": comp3.outputs.predicted_tag},
    )
    ep2 = Endpoint(
        route="/other_predict",
        inputs={
            "img_1": comp1.inputs.img,
            "img_2": comp2.inputs.img,
            "tag_1": comp1.inputs.tag,
        },
        outputs={
            "prediction_3": comp3.outputs.predicted_tag,
            "prediction_2": comp2.outputs.cropped_img,
        },
    )

    composit = Composition(
        comp1=comp1,
        comp2=comp2,
        comp3=comp3,
        predict_compositon_ep=ep1,
        other_predict_ep=ep2,
        TESTING=True,
        DEBUG=True,
    )

    def encode_image(filename):
        # Base64 text of a fixture image, as the image payload fields expect.
        with (session_global_datadir / filename).open("rb") as fh:
            return base64.b64encode(fh.read()).decode("UTF-8")

    cat_imgstr = encode_image("cat.jpg")
    fish_imgstr = encode_image("fish.jpg")
    request = {
        "session": "session_uuid",
        "payload": {
            "img_1": {"data": cat_imgstr},
            "img_2": {"data": fish_imgstr},
            "tag_1": {"label": "stingray"},
        },
    }

    app = composit.serve(host="0.0.0.0", port=8000)
    with TestClient(app) as client:
        reply = client.post("http://127.0.0.1:8000/predict", json=request)
        assert reply.status_code == 200
        assert reply.json() == {
            "result": {"prediction": "goldfish, Carassius auratus"},
            "session": "session_uuid",
        }