def test_repeated_serialize(session_global_datadir):
    """Round-trip ImageNet class names through ``Repeated(Label)``.

    Names deserialize to one index tensor each. Indices serialize back to
    names whether they are given as a tuple of scalar tensors or as a single
    1-D tensor.
    """
    labels_file = str(session_global_datadir / "imagenet_labels.txt")
    repeated = Repeated(dtype=Label(path=labels_file))

    chickadee = {"label": "chickadee"}
    stingray = {"label": "stingray"}
    assert repeated.deserialize(chickadee, stingray) == (
        torch.tensor(19),
        torch.tensor(6),
    )

    names = ("chickadee", "stingray")
    # A tuple of scalar index tensors ...
    assert repeated.serialize((torch.tensor(19), torch.tensor(6))) == names
    # ... and a single 1-D tensor of indices serialize identically.
    assert repeated.serialize(torch.tensor([19, 6])) == names
def test_repeated_deserialize():
    """Each ``{"label": name}`` payload deserializes to its class index."""
    repeated = Repeated(dtype=Label(classes=["classA", "classB"]))
    payloads = [{"label": name} for name in ("classA", "classA", "classB")]
    res = repeated.deserialize(*payloads)
    assert res == (torch.tensor(0), torch.tensor(0), torch.tensor(1))
class ObjectDetection(ModelComponent):
    """Serve component that runs ``model`` on an image and exposes boxes and labels."""

    def __init__(self, model):
        self.model = model

    @expose(
        inputs={"img": Image()},
        outputs={
            "boxes": Repeated(BBox()),
            # NOTE(review): relative path. Whether and when Label reads it is
            # not visible here; the sibling component anchors its label file
            # at CWD. Confirm this resolves in the test environment.
            "labels": Repeated(Label("classes.txt")),
        },
    )
    def detect(self, img):
        # Reorder axes with permute(0, 3, 2, 1) and scale pixel values by 1/255.
        scaled = img.permute(0, 3, 2, 1).float() / 255
        # Presumably the model returns one dict per image (torchvision-style);
        # only the first entry is used. TODO confirm.
        first = self.model(scaled)[0]
        return first["boxes"], first["labels"]
class ClassificationInferenceRepeated(ModelComponent):
    """Serve component that classifies the first image of a repeated image input."""

    def __init__(self, model):
        self.model = model

    @expose(
        inputs={"img": Repeated(Image(extension="JPG"))},
        outputs={
            "prediction": Repeated(Label(path=str(CWD / "imagenet_labels.txt"))),
            "other": Number(),
        },
    )
    def classify(self, img):
        # Only the first element of the repeated input is fed to the model.
        pixels = img[0].float() / 255

        # ImageNet mean/std values, shaped (1, 1, 3) to broadcast over the last axis.
        channel_mean = torch.tensor([[[0.485, 0.456, 0.406]]]).float()
        channel_std = torch.tensor([[[0.229, 0.224, 0.225]]]).float()
        normalized = ((pixels - channel_mean) / channel_std).permute(0, 3, 2, 1)

        logits = self.model(normalized)
        # The same top-class index is reported twice (two separate argmax
        # tensors), plus a constant "other" value.
        predictions = [logits.argmax() for _ in range(2)]
        return (predictions, torch.Tensor([21]))
def test_not_allow_nested_repeated():
    """Nesting a ``Repeated`` inside another ``Repeated`` must raise TypeError.

    The inner ``Repeated`` is built with a valid dtype *outside* the
    ``pytest.raises`` block, so the only thing under test is the outer
    construction. The previous form, ``Repeated(dtype=Repeated())``, built the
    inner instance with no dtype inside the raises block. Any TypeError from
    that inner construction satisfied the assertion, so the nesting check was
    not guaranteed to be exercised.
    """
    inner = Repeated(dtype=Label(classes=["classA", "classB"]))
    with pytest.raises(TypeError):
        Repeated(dtype=inner)
def test_repeated_non_grid_dtype():
    """``Repeated`` rejects a dtype that is a plain, unrelated object."""

    class NonGridDtype:
        pass

    plain_instance = NonGridDtype()
    with pytest.raises(TypeError):
        Repeated(plain_instance)
def test_repeated_max_len():
    """``max_len`` bounds both directions and is itself validated."""

    def make_label():
        # Fresh Label (and fresh classes list) for every Repeated under test.
        return Label(classes=["classA", "classB"])

    repeated = Repeated(dtype=make_label(), max_len=2)

    # deserialize: three payloads exceed max_len=2; two are accepted.
    three_payloads = ({"label": "classA"}, {"label": "classA"}, {"label": "classB"})
    with pytest.raises(ValueError):
        repeated.deserialize(*three_payloads)
    assert repeated.deserialize({"label": "classA"}, {"label": "classB"}) == (
        torch.tensor(0),
        torch.tensor(1),
    )

    # serialize: the same bound applies on the way out.
    with pytest.raises(ValueError):
        repeated.serialize((torch.tensor(0), torch.tensor(0), torch.tensor(1)))
    assert repeated.serialize((torch.tensor(1), torch.tensor(0))) == ("classB", "classA")

    # max_len < 1 is rejected; 1 is accepted.
    with pytest.raises(ValueError):
        Repeated(dtype=make_label(), max_len=0)
    assert Repeated(dtype=make_label(), max_len=1) is not None

    # A non-int max_len (here the ``str`` type object itself) is rejected.
    with pytest.raises(TypeError):
        Repeated(dtype=make_label(), max_len=str)
def test_repeated_non_serve_dtype():
    """``Repeated`` raises TypeError for a dtype that is not a serve type."""

    class NonServeDtype:
        pass

    not_a_serve_type = NonServeDtype()
    with pytest.raises(TypeError):
        Repeated(not_a_serve_type)