コード例 #1
0
def test_repeated_serialize(session_global_datadir):
    """Round-trip label names through ``Repeated(Label)`` backed by a label file.

    Deserializing ``{"label": name}`` dicts must yield class-index tensors, and
    serializing either a tuple of scalar tensors or one 1-D tensor must yield
    the matching label names.
    """
    labels_file = session_global_datadir / "imagenet_labels.txt"
    repeated = Repeated(dtype=Label(path=str(labels_file)))

    payload = ({"label": "chickadee"}, {"label": "stingray"})
    expected_names = ("chickadee", "stingray")

    assert repeated.deserialize(*payload) == (torch.tensor(19), torch.tensor(6))
    # A tuple of scalar tensors and a single stacked tensor must serialize alike.
    assert repeated.serialize((torch.tensor(19), torch.tensor(6))) == expected_names
    assert repeated.serialize(torch.tensor([19, 6])) == expected_names
コード例 #2
0
def test_repeated_deserialize():
    """Each deserialized label dict maps to its index in ``classes``."""
    repeated = Repeated(dtype=Label(classes=["classA", "classB"]))
    inputs = [{"label": name} for name in ("classA", "classA", "classB")]
    res = repeated.deserialize(*inputs)
    assert res == (torch.tensor(0), torch.tensor(0), torch.tensor(1))
コード例 #3
0
class ObjectDetection(ModelComponent):
    """Serve component wrapping a detection model.

    Takes one ``Image()`` input named ``img`` and produces two repeated
    outputs: ``boxes`` (``BBox``) and ``labels`` (``Label`` resolved via
    ``classes.txt``).
    """

    def __init__(self, model):
        # Callable model; ``detect`` uses the first element of its output.
        self.model = model

    @expose(
        inputs={"img": Image()},
        outputs={
            "boxes": Repeated(BBox()),
            "labels": Repeated(Label("classes.txt")),
        },
    )
    def detect(self, img):
        """Scale ``img`` to [0, 1], run the model, and return (boxes, labels).

        NOTE(review): if ``img`` arrives as (N, H, W, C), ``permute(0, 3, 2, 1)``
        produces (N, C, W, H), i.e. H and W are swapped as well — confirm this
        is what the wrapped model expects.
        """
        scaled = img.permute(0, 3, 2, 1).float() / 255
        detections = self.model(scaled)[0]
        boxes = detections["boxes"]
        labels = detections["labels"]
        return boxes, labels
コード例 #4
0
    class ClassificationInferenceRepeated(ModelComponent):
        """Classifier component with a ``Repeated`` image input and label output.

        Only the first image of the repeated input is classified. Its argmax is
        emitted twice as ``prediction``, alongside a constant ``other`` value.
        """

        def __init__(self, model):
            self.model = model

        @expose(
            inputs={"img": Repeated(Image(extension="JPG"))},
            outputs={
                "prediction": Repeated(Label(path=str(CWD / "imagenet_labels.txt"))),
                "other": Number(),
            },
        )
        def classify(self, img):
            first_image = img[0].float() / 255
            # Per-channel mean/std (the common ImageNet values), shaped (1, 1, 3)
            # so they broadcast against the trailing axis of the image tensor.
            channel_mean = torch.tensor([[[0.485, 0.456, 0.406]]]).float()
            channel_std = torch.tensor([[[0.229, 0.224, 0.225]]]).float()
            normalized = ((first_image - channel_mean) / channel_std).permute(0, 3, 2, 1)
            logits = self.model(normalized)
            # Two independent argmax results fill the repeated "prediction" output.
            predictions = [logits.argmax() for _ in range(2)]
            return (predictions, torch.Tensor([21]))
コード例 #5
0
def test_not_allow_nested_repeated():
    """``Repeated`` must reject another ``Repeated`` as its ``dtype``."""
    # Build a *valid* inner Repeated outside ``pytest.raises``: constructing it
    # inside the block (e.g. as a bare ``Repeated()``) could raise TypeError on
    # its own and satisfy the assertion without exercising the nesting check.
    inner = Repeated(dtype=Label(classes=["classA", "classB"]))
    with pytest.raises(TypeError):
        Repeated(dtype=inner)
コード例 #6
0
def test_repeated_non_grid_dtype():
    """Passing an arbitrary non-serve-type object as dtype raises ``TypeError``."""

    class NonGridDtype:
        """Bare class with no serve-type behavior."""

    bogus_dtype = NonGridDtype()
    with pytest.raises(TypeError):
        Repeated(bogus_dtype)
コード例 #7
0
def test_repeated_max_len():
    """``max_len`` limits both deserialize and serialize, and is validated itself."""

    def make_label():
        # Fresh classes list per call, matching separate literals per Label.
        return Label(classes=["classA", "classB"])

    repeated = Repeated(dtype=make_label(), max_len=2)

    # Deserialize: three items exceed max_len=2; exactly two are accepted.
    too_many = ({"label": "classA"}, {"label": "classA"}, {"label": "classB"})
    with pytest.raises(ValueError):
        repeated.deserialize(*too_many)
    assert repeated.deserialize({"label": "classA"}, {"label": "classB"}) == (
        torch.tensor(0),
        torch.tensor(1),
    )

    # Serialize: the same bound applies.
    with pytest.raises(ValueError):
        repeated.serialize((torch.tensor(0), torch.tensor(0), torch.tensor(1)))
    assert repeated.serialize((torch.tensor(1), torch.tensor(0))) == ("classB", "classA")

    # max_len < 1 is rejected; max_len == 1 is the smallest accepted value.
    with pytest.raises(ValueError):
        Repeated(dtype=make_label(), max_len=0)
    assert Repeated(dtype=make_label(), max_len=1) is not None

    # A non-int max_len (here the ``str`` type object itself) is rejected.
    with pytest.raises(TypeError):
        Repeated(dtype=make_label(), max_len=str)
コード例 #8
0
def test_repeated_non_serve_dtype():
    """``Repeated`` refuses a dtype that is not a serve type."""

    class NonServeDtype:
        """Placeholder type unrelated to the serve type hierarchy."""

    not_a_serve_type = NonServeDtype()
    with pytest.raises(TypeError):
        Repeated(not_a_serve_type)