Ejemplo n.º 1
0
def test_two_component_invalid_dependencies_fail(lightning_squeezenet1_1_obj):
    """Invalid ``<<`` / ``>>`` wirings between components must raise.

    Each case is exercised twice: once with subscript access
    (``inputs["tag"]``) and once with attribute access (``inputs.tag``).
    """
    first = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)
    second = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)

    # A component's own output cannot be pulled into its own input.
    self_cycles = (
        lambda: first.inputs["tag"] << first.outputs.predicted_tag,
        lambda: first.inputs.tag << first.outputs["predicted_tag"],
    )
    for attempt in self_cycles:
        with pytest.raises(RuntimeError, match="Cannot create cycle"):
            attempt()

    # Input parameters used as the left operand of ``>>``.
    # NOTE(review): "label" is not among the parameter names checked in
    # test_component_initialization (img/tag, predicted_tag/cropped_img), so the
    # AttributeError may come from the key lookup itself rather than from ``>>``
    # -- confirm this is the intended failure.
    input_sources = (
        lambda: first.inputs["tag"] >> second.inputs["label"],
        lambda: first.inputs.tag >> second.inputs.label,
        lambda: first.inputs["tag"] >> second.outputs["label"],
        lambda: first.inputs.tag >> second.outputs.label,
    )
    for attempt in input_sources:
        with pytest.raises(AttributeError):
            attempt()

    # Output-to-output wiring between two different components.
    output_to_output = (
        lambda: second.outputs["predicted_tag"] >> first.outputs["predicted_tag"],
        lambda: second.outputs.predicted_tag >> first.outputs.predicted_tag,
    )
    for attempt in output_to_output:
        with pytest.raises(TypeError):
            attempt()

    # The right operand of ``>>`` must be a parameter, not an arbitrary object.
    class _NotAParameter:
        pass

    stranger = _NotAParameter()
    with pytest.raises(TypeError):
        first.inputs["tag"] >> stranger  # skipcq: PYL-W0104
Ejemplo n.º 2
0
def test_component_initialization(lightning_squeezenet1_1_obj):
    """Unknown constructor kwargs are rejected; a valid component exposes its parameters."""
    with pytest.raises(TypeError):
        ClassificationInferenceComposable(wrongname=lightning_squeezenet1_1_obj)

    component = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)
    # NOTE(review): the expected uid assumes the call counter starts fresh for this
    # test (presumably reset by a fixture) -- confirm.
    assert component.uid == "callnum_1"

    for input_name in ("img", "tag"):
        assert hasattr(component.inputs, input_name)
    for output_name in ("predicted_tag", "cropped_img"):
        assert hasattr(component.outputs, output_name)

    # Containers also support membership tests by parameter name.
    assert "img" in component.inputs
    assert "predicted_tag" in component.outputs
Ejemplo n.º 3
0
def test_composition_recieve_wrong_arg_type(lightning_squeezenet1_1_obj):
    """``Composition`` rejects invalid keyword arguments.

    * a non-component value with no endpoint -> ``TypeError``
    * several components but no endpoint -> ``ValueError``
    """
    with pytest.raises(TypeError):
        Composition(hello="world")

    from tests.core.serve.models import ClassificationInferenceComposable

    # Built in key order, so c1 is created before c2.
    components = {
        "c1": ClassificationInferenceComposable(lightning_squeezenet1_1_obj),
        "c2": ClassificationInferenceComposable(lightning_squeezenet1_1_obj),
    }
    with pytest.raises(ValueError):
        Composition(**components)
Ejemplo n.º 4
0
def test_complex_spec_single_endpoint(tmp_path, lightning_squeezenet1_1_obj):
    """Serialize a three-component graph exposed through a single endpoint.

    Wiring built below (component uids in brackets)::

        comp1 [callnum_1].predicted_tag -> comp3 [callnum_3].tag
        comp2 [callnum_2].cropped_img   -> comp3 [callnum_3].img
        comp1 [callnum_1].predicted_tag -> comp2 [callnum_2].tag

    ``tmp_path`` is requested but not used.
    """
    from tests.core.serve.models import ClassificationInferenceComposable

    first, second, third = [
        ClassificationInferenceComposable(lightning_squeezenet1_1_obj) for _ in range(3)
    ]

    first.outputs.predicted_tag >> third.inputs.tag  # skipcq: PYL-W0104
    second.outputs.cropped_img >> third.inputs.img  # skipcq: PYL-W0104
    first.outputs.predicted_tag >> second.inputs.tag  # skipcq: PYL-W0104

    endpoint = Endpoint(
        route="/predict",
        inputs={
            "img_1": first.inputs.img,
            "img_2": second.inputs.img,
            "tag_1": first.inputs.tag,
        },
        outputs={"prediction": third.outputs.predicted_tag},
    )
    composit = Composition(
        comp1=first,
        comp2=second,
        comp3=third,
        predict_compositon_ep=endpoint,
    )

    # The reported order differs from the order the links were created above.
    expected_connections = [
        "callnum_1.outputs.predicted_tag >> callnum_3.inputs.tag",
        "callnum_1.outputs.predicted_tag >> callnum_2.inputs.tag",
        "callnum_2.outputs.cropped_img >> callnum_3.inputs.img",
    ]
    assert [str(conn) for conn in composit.connections] == expected_connections
    assert composit.component_uid_names == {f"callnum_{n}": f"comp{n}" for n in (1, 2, 3)}

    # Endpoint parameters are serialized as "<uid>.<inputs|outputs>.<key>" strings.
    expected_endpoints = {
        "predict_compositon_ep": {
            "inputs": {
                "img_1": "callnum_1.inputs.img",
                "img_2": "callnum_2.inputs.img",
                "tag_1": "callnum_1.inputs.tag",
            },
            "outputs": {
                "prediction": "callnum_3.outputs.predicted_tag",
            },
            "route": "/predict",
        }
    }
    serialized = {name: asdict(ep) for name, ep in composit.endpoints.items()}
    assert serialized == expected_endpoints
Ejemplo n.º 5
0
def test_model_compute_dependencies(lightning_squeezenet1_1_obj):
    """``<<`` records the connection on the left-hand (target) component only.

    ``comp1.inputs.tag << comp2.outputs.predicted_tag`` must appear in
    ``comp1``'s metadata connections and leave ``comp2``'s empty.
    """
    comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)
    comp2 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)

    comp1.inputs.tag << comp2.outputs.predicted_tag  # skipcq: PYL-W0104
    expected = [{
        "source_component": "callnum_2",
        "source_key": "predicted_tag",
        "target_component": "callnum_1",
        "target_key": "tag",
    }]
    # Connections expose ``_asdict()``, so compare them as plain dicts.
    actual = [conn._asdict() for conn in comp1._flashserve_meta_.connections]
    assert actual == expected
    assert list(comp2._flashserve_meta_.connections) == []
Ejemplo n.º 6
0
def test_create_invalid_endpoint(lightning_squeezenet1_1_obj):
    """``Endpoint`` validates the route's type and prefix and its parameter values."""
    from flash.core.serve import Endpoint

    comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)

    def make_endpoint(route, inp=None, out=None):
        # Fresh dicts per call; any argument left as None falls back to a valid parameter.
        return Endpoint(
            route=route,
            inputs={"inp": comp1.inputs.img if inp is None else inp},
            outputs={"out": comp1.outputs.cropped_img if out is None else out},
        )

    # The route must be a str, not bytes.
    with pytest.raises(TypeError, match="route parameter must be type"):
        make_endpoint(b"/INVALID")

    # The route must start with a slash.
    with pytest.raises(ValueError, match="route must begin with"):
        make_endpoint("hello")

    # Input and output values must be component parameters.
    with pytest.raises(TypeError, match="inputs k=inp, v=b'INVALID'"):
        make_endpoint("/hello", inp=b"INVALID")

    with pytest.raises(TypeError, match="k=out, v=b'INVALID'"):
        make_endpoint("/hello", out=b"INVALID")
Ejemplo n.º 7
0
def test_cycle_in_connection_fails(session_global_datadir,
                                   lightning_squeezenet1_1_obj):
    """Feeding a component's output back into its own input raises ``RuntimeError``.

    ``session_global_datadir`` is requested but not used.
    """
    from tests.core.serve.models import ClassificationInferenceComposable

    component = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)

    with pytest.raises(RuntimeError):
        component.outputs.cropped_img >> component.inputs.img  # skipcq: PYL-W0104
Ejemplo n.º 8
0
def test_component_parameters(lightning_squeezenet1_1_obj):
    """Parameter containers are immutable, and ``>>`` records the link on its source."""
    upstream = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)
    downstream = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)

    # Immutability: a new key cannot be assigned into a component's inputs.
    with pytest.raises(TypeError):
        upstream.inputs["newkey"] = downstream.inputs["tag"]

    source = upstream.outputs["predicted_tag"]
    target = downstream.inputs["tag"]
    assert isinstance(source.datatype, Label)
    assert source.connections == []

    source >> target  # skipcq: PYL-W0104
    assert str(source.connections[0]) == "callnum_1.outputs.predicted_tag >> callnum_2.inputs.tag"
    # The target parameter stays untouched; the source's list is the same as its
    # component's metadata connections.
    assert target.connections == []
    assert source.connections == upstream._flashserve_meta_.connections
Ejemplo n.º 9
0
def test_connection_invalid_raises(lightning_squeezenet1_1_obj):
    """``>>`` rejects same-component wiring and targets that are not parameters."""
    component = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)
    same_component_msg = "Cannot compose a parameters of same components"

    with pytest.raises(RuntimeError, match=same_component_msg):
        component.outputs["predicted_tag"] >> component.outputs["predicted_tag"]  # skipcq: PYL-W0104

    # Duck-typed stand-in: it carries a ``position`` attribute but is not a Parameter.
    class FakeParam:
        position = "outputs"

    impostor = FakeParam()

    with pytest.raises(TypeError, match="Can only Compose another `Parameter`"):
        component.outputs.predicted_tag >> impostor  # skipcq: PYL-W0104
Ejemplo n.º 10
0
def test_composit_endpoint_data(lightning_squeezenet1_1_obj):
    """Endpoint metadata of a single-component ``Composition``.

    First checks the endpoint the composition exposes when none is passed,
    then an explicitly supplied endpoint that uses custom input/output keys.
    """
    from tests.core.serve.models import ClassificationInferenceComposable

    def endpoints_as_dicts(composition):
        return {name: asdict(endpoint) for name, endpoint in composition.endpoints.items()}

    comp = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)

    # No endpoint given.
    default_composition = Composition(comp=comp)
    assert default_composition.component_uid_names == {"callnum_1": "comp"}
    assert default_composition.connections == []
    assert endpoints_as_dicts(default_composition) == {
        "classify_ENDPOINT": {
            "inputs": {
                "img": "callnum_1.inputs.img",
                "tag": "callnum_1.inputs.tag"
            },
            "outputs": {
                "cropped_img": "callnum_1.outputs.cropped_img",
                "predicted_tag": "callnum_1.outputs.predicted_tag",
            },
            "route": "/classify",
        }
    }

    # Explicit endpoint with renamed keys.
    custom_endpoint = Endpoint(
        route="/predict",
        inputs={"label_1": comp.inputs.img, "tag_1": comp.inputs.tag},
        outputs={"prediction": comp.outputs.predicted_tag, "cropped": comp.outputs.cropped_img},
    )
    custom_composition = Composition(comp=comp, predict_ep=custom_endpoint)
    assert endpoints_as_dicts(custom_composition) == {
        "predict_ep": {
            "inputs": {
                "label_1": "callnum_1.inputs.img",
                "tag_1": "callnum_1.inputs.tag"
            },
            "outputs": {
                "cropped": "callnum_1.outputs.cropped_img",
                "prediction": "callnum_1.outputs.predicted_tag",
            },
            "route": "/predict",
        }
    }
Ejemplo n.º 11
0
def test_model_compute_call_method(lightning_squeezenet1_1_obj):
    """Calling a component directly runs its compute step on a raw tensor.

    The input is a deterministic integer ramp, so the predicted class index
    is reproducible.
    """
    comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)
    # Derive the element count from the shape so the two cannot drift apart
    # (numel == 1 * 255 * 255 * 3 == 195075).
    shape = (1, 255, 255, 3)
    img = torch.arange(torch.Size(shape).numel()).reshape(shape)
    tag = None
    # The component returns (result, image); only the result is checked here.
    out_res, _ = comp1(img, tag)
    assert out_res.item() == 753
Ejemplo n.º 12
0
def test_endpoint_errors_on_wrong_key_name(lightning_squeezenet1_1_obj):
    """Unknown parameter or component names raise ``AttributeError``.

    Attribute references (``comp1.inputs.X``) fail while the ``Endpoint`` is
    being built. String references ("<uid>.<inputs|outputs>.<key>") are
    accepted by ``Endpoint`` and only fail when a ``Composition`` resolves them.
    """
    from tests.core.serve.models import ClassificationInferenceComposable

    comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)

    def predict_endpoint(tag_1, cropped):
        # Only the "tag_1" input and the "cropped" output vary between cases.
        return Endpoint(
            route="/predict",
            inputs={"label_1": comp1.inputs.img, "tag_1": tag_1},
            outputs={"prediction": comp1.outputs.predicted_tag, "cropped": cropped},
        )

    # Unknown input key, referenced as an attribute.
    with pytest.raises(AttributeError):
        predict_endpoint(comp1.inputs.DOESNOTEXIST, comp1.outputs.cropped_img)

    # Unknown output key, referenced as an attribute.
    with pytest.raises(AttributeError):
        predict_endpoint(comp1.inputs.tag, comp1.outputs.DOESNOTEXIST)

    # Unknown output key, referenced by string: fails only at composition time.
    endpoint = predict_endpoint(comp1.inputs.tag, "callnum_1.outputs.DOESNOTEXIST")
    with pytest.raises(AttributeError):
        Composition(comp1=comp1, predict_ep=endpoint)

    # Unknown component in an input string reference.
    endpoint = predict_endpoint("DOESNOTEXIST.inputs.tag", comp1.outputs.cropped_img)
    with pytest.raises(AttributeError):
        Composition(comp1=comp1, predict_ep=endpoint)

    # Unknown component in an output string reference.
    endpoint = predict_endpoint(comp1.inputs.tag, "DOESNOTEXIST.outputs.cropped_img")
    with pytest.raises(AttributeError):
        Composition(comp1=comp1, predict_ep=endpoint)
Ejemplo n.º 13
0
def test_start_server_from_composition(tmp_path, squeezenet_servable, session_global_datadir):
    """End-to-end: serve a three-component ``Composition`` and POST to ``/predict``.

    Wires three components built from the ``squeezenet_servable`` model into
    the same graph as ``test_complex_spec_single_endpoint``, exposes two
    endpoints (``/predict`` and ``/other_predict``), and checks the JSON response
    of ``/predict`` for a cat/fish image pair.

    NOTE(review): ``tmp_path`` appears unused.
    """
    from tests.core.serve.models import ClassificationInferenceComposable

    # The fixture yields a 2-tuple; only the first element is used here.
    squeezenet_gm, _ = squeezenet_servable
    comp1 = ClassificationInferenceComposable(squeezenet_gm)
    comp2 = ClassificationInferenceComposable(squeezenet_gm)
    comp3 = ClassificationInferenceComposable(squeezenet_gm)

    # comp1.predicted_tag feeds both comp3.tag and comp2.tag; comp2.cropped_img feeds comp3.img.
    comp1.outputs.predicted_tag >> comp3.inputs.tag  # skipcq: PYL-W0104
    comp2.outputs.cropped_img >> comp3.inputs.img  # skipcq: PYL-W0104
    comp1.outputs.predicted_tag >> comp2.inputs.tag  # skipcq: PYL-W0104

    # Endpoint exercised below: returns only comp3's predicted tag.
    ep1 = Endpoint(
        route="/predict",
        inputs={
            "img_1": comp1.inputs.img,
            "img_2": comp2.inputs.img,
            "tag_1": comp1.inputs.tag,
        },
        outputs={"prediction": comp3.outputs.predicted_tag},
    )

    # Second endpoint with the same inputs but two outputs; it is registered
    # but not requested in this test.
    ep2 = Endpoint(
        route="/other_predict",
        inputs={
            "img_1": comp1.inputs.img,
            "img_2": comp2.inputs.img,
            "tag_1": comp1.inputs.tag,
        },
        outputs={
            "prediction_3": comp3.outputs.predicted_tag,
            "prediction_2": comp2.outputs.cropped_img,
        },
    )

    # NOTE(review): TESTING/DEBUG are presumably server configuration flags
    # forwarded by Composition -- confirm against Composition's signature.
    composit = Composition(
        comp1=comp1,
        comp2=comp2,
        comp3=comp3,
        predict_compositon_ep=ep1,
        other_predict_ep=ep2,
        TESTING=True,
        DEBUG=True,
    )

    # Image inputs are sent base64-encoded under a "data" key; the tag input
    # takes a "label" key.
    with (session_global_datadir / "cat.jpg").open("rb") as f:
        cat_imgstr = base64.b64encode(f.read()).decode("UTF-8")
    with (session_global_datadir / "fish.jpg").open("rb") as f:
        fish_imgstr = base64.b64encode(f.read()).decode("UTF-8")
    data = {
        "session": "session_uuid",
        "payload": {
            "img_1": {
                "data": cat_imgstr
            },
            "img_2": {
                "data": fish_imgstr
            },
            "tag_1": {
                "label": "stingray"
            },
        },
    }
    # The response echoes the session id; "result" is keyed by ep1's output names.
    expected_response = {
        "result": {
            "prediction": "goldfish, Carassius auratus"
        },
        "session": "session_uuid",
    }

    # NOTE(review): with TESTING=True, serve() presumably returns the app
    # instead of starting a real server; TestClient then drives it in-process,
    # so host/port are presumably not bound -- confirm.
    app = composit.serve(host="0.0.0.0", port=8000)
    with TestClient(app) as tc:
        res = tc.post("http://127.0.0.1:8000/predict", json=data)
        assert res.status_code == 200
        assert res.json() == expected_response