def test_guard_raises_inferred():
    """A dimension bound by one guard call must match in later calls."""
    shape_guard = ShapeGuard()
    first = torch.ones((1, 2, 3))
    second = torch.ones((3, 2, 5))
    # Binds A=1, B=2, C=3.
    shape_guard.guard(first, "A, B, C")
    # C (3) and B (2) agree with the earlier bindings, but A was bound to 1
    # and the last dim here is 5.
    with pytest.raises(ShapeError):
        shape_guard.guard(second, "C, B, A")
def test_guard_infers_assign():
    """An ``X=expr`` entry records X alongside the names used in expr."""
    shape_guard = ShapeGuard()
    tensor = torch.ones((1, 2, 3))
    shape_guard.guard(tensor, "A, D=B*2, A+C")
    expected = dict(A=1, B=1, C=2, D=2)
    assert shape_guard.dims == expected
    # With D=2, the expression D/2 gives 1, but the second dim has size 2.
    with pytest.raises(ShapeError):
        shape_guard.guard(tensor, "1, E=D/2, 3")
def preprocess(self, data):
    """Validate, normalise, crop and resize a batch of images and masks.

    Args:
        data: mapping with ``"image"`` (guarded as ``B, h, w, C``),
            ``"mask"`` (guarded as ``B, L, h, w, 1``) and, optionally,
            ``"visibility"`` (guarded as ``B, L``). The whole mapping is also
            passed to ``self.preprocess_factors``.

    Returns:
        dict with:
          * ``"image"``: float32, shape ``B, T, H, W, C`` (T=1, inserted
            here). The channel axis is 1 when ``self.grayscale`` is set.
          * ``"mask"``: float32, shape ``B, T, L, H, W, 1``.
          * ``"factors"``: whatever ``self.preprocess_factors`` returns.
          * ``"visibility"``: the given tensor, or float32 ones of shape
            ``B, L`` when absent.
        ``B``, ``H`` and ``W`` come from ``self.batch_size`` and
        ``self.image_dim``.

    Raises:
        ShapeError: presumably raised by ``ShapeGuard.guard`` when an input
            or output does not match its shape spec.
    """
    sg = ShapeGuard(dims={
        "B": self.batch_size,
        "H": self.image_dim[0],
        "W": self.image_dim[1],
    })
    image = sg.guard(data["image"], "B, h, w, C")
    mask = sg.guard(data["mask"], "B, L, h, w, 1")

    # to float (inputs presumably uint8 in [0, 255] -- TODO confirm)
    image = tf.cast(image, tf.float32) / 255.0
    mask = tf.cast(mask, tf.float32) / 255.0

    # crop: crop_region is ((row_start, row_stop), (col_start, col_stop))
    if self.crop_region is not None:
        height_slice = slice(self.crop_region[0][0], self.crop_region[0][1])
        width_slice = slice(self.crop_region[1][0], self.crop_region[1][1])
        image = image[:, height_slice, width_slice, :]
        mask = mask[:, :, height_slice, width_slice, :]

    # Presumably collapses the leading (B, L) axes so the masks form a 4-D
    # batch that resize_images accepts; `unflatten` restores them below.
    flat_mask, unflatten = flatten_all_but_last(mask, n_dims=3)

    # rescale
    size = tf.constant(self.image_dim, dtype=tf.int32, shape=[2],
                       verify_shape=True)
    image = tf.image.resize_images(
        image, size, method=tf.image.ResizeMethod.BILINEAR)
    # Nearest-neighbour so mask values are not interpolated between pixels.
    mask = tf.image.resize_images(
        flat_mask, size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # "C" is bound to the *input* channel count above. Averaging to
    # grayscale leaves a single channel, so guarding against "C" would raise
    # ShapeError for any multi-channel input. Guard against 1 in that case.
    if self.grayscale:
        image = tf.reduce_mean(image, axis=-1, keepdims=True)
        image_spec = "B, T, H, W, 1"
    else:
        image_spec = "B, T, H, W, C"

    output = {
        "image": sg.guard(image[:, None], image_spec),
        "mask": sg.guard(unflatten(mask)[:, None], "B, T, L, H, W, 1"),
        "factors": self.preprocess_factors(data, sg),
    }
    if "visibility" in data:
        output["visibility"] = sg.guard(data["visibility"], "B, L")
    else:
        output["visibility"] = tf.ones(sg["B, L"], dtype=tf.float32)
    return output
def test_guard_dynamic_shape():
    """Unknown (``None``) sizes are rejected by plain names but accepted by ``?`` forms."""
    shape_guard = ShapeGuard()
    # A plain name cannot be matched against an unknown size.
    with pytest.raises(ShapeError):
        shape_guard.guard([None, 2, 3], "C, B, A")
    # A bare "?" accepts an unknown size.
    shape_guard.guard([None, 2, 3], "?, B, A")
    # "C?" accepts a known size...
    shape_guard.guard([1, 2, 3], "C?, B, A")
    # ...and still accepts an unknown one afterwards.
    shape_guard.guard([None, 2, 3], "C?, B, A")
def preprocess(self, data):
    """Validate, normalise, crop and resize a batch of images.

    Args:
        data: mapping whose ``"image"`` entry is guarded as ``B, h, w, C``.

    Returns:
        dict with ``"image"``: float32, shape ``B, T, H, W, C`` (T=1,
        inserted here). The channel axis is 1 when ``self.grayscale`` is set.
        ``B``, ``H`` and ``W`` come from ``self.batch_size`` and
        ``self.image_dim``.

    Raises:
        ShapeError: presumably raised by ``ShapeGuard.guard`` when the input
            or output does not match its shape spec.
    """
    sg = ShapeGuard(dims={
        "B": self.batch_size,
        "H": self.image_dim[0],
        "W": self.image_dim[1],
    })
    image = sg.guard(data["image"], "B, h, w, C")

    # to float (input presumably uint8 in [0, 255] -- TODO confirm)
    image = tf.cast(image, tf.float32) / 255.0

    # crop: crop_region is ((row_start, row_stop), (col_start, col_stop)).
    # This variant has no mask, so only the image is sliced. The original
    # also sliced an undefined `mask` here and raised NameError.
    if self.crop_region is not None:
        height_slice = slice(self.crop_region[0][0], self.crop_region[0][1])
        width_slice = slice(self.crop_region[1][0], self.crop_region[1][1])
        image = image[:, height_slice, width_slice, :]

    # rescale
    size = tf.constant(self.image_dim, dtype=tf.int32, shape=[2],
                       verify_shape=True)
    image = tf.image.resize_images(
        image, size, method=tf.image.ResizeMethod.BILINEAR)

    # "C" is bound to the *input* channel count above. Averaging to
    # grayscale leaves a single channel, so guarding against "C" would raise
    # ShapeError for any multi-channel input. Guard against 1 in that case.
    if self.grayscale:
        image = tf.reduce_mean(image, axis=-1, keepdims=True)
        image_spec = "B, T, H, W, 1"
    else:
        image_spec = "B, T, H, W, C"

    output = {
        "image": sg.guard(image[:, None], image_spec),
    }
    return output
def test_guard_ellipsis():
    """``...`` absorbs any run of dims, including an empty one, but never extra entries."""
    shape_guard = ShapeGuard()
    tensor = torch.ones((1, 2, 3, 4, 5))
    accepted_specs = (
        "...",
        "..., 5",
        "..., 4, 5",
        "1, ...",
        "1, 2, ...",
        "1, 2, ..., 4, 5",
        "1, 2, 3, ..., 4, 5",  # ellipsis matches zero dims here
    )
    for spec in accepted_specs:
        shape_guard.guard(tensor, spec)
    # Six explicit sizes cannot fit this rank-5 tensor.
    rejected_specs = (
        "1, 2, 3, 4, 5, 6,...",
        "..., 1, 2, 3, 4, 5, 6",
    )
    for spec in rejected_specs:
        with pytest.raises(ShapeError):
            shape_guard.guard(tensor, spec)
def test_guard_ignores_underscore():
    """Underscore-prefixed names are not recorded in ``dims``."""
    shape_guard = ShapeGuard()
    tensor = torch.ones((1, 2, 3))
    shape_guard.guard(tensor, "_A, _b, 3")
    assert shape_guard.dims == dict()
def test_guard_ignores_wildcard():
    """``*`` entries match any size and are not recorded in ``dims``."""
    shape_guard = ShapeGuard()
    tensor = torch.ones((1, 2, 3))
    shape_guard.guard(tensor, "*, *, 3")
    assert shape_guard.dims == dict()
def test_guard_raises_complex():
    """A name repeated within one spec must match the same size each time."""
    shape_guard = ShapeGuard()
    tensor = torch.ones((1, 2, 3))
    # B would have to be both 2 and 3.
    with pytest.raises(ShapeError):
        shape_guard.guard(tensor, "A, B, B")
def test_guard_infers_dimensions_operator_priority():
    """Spec arithmetic honours precedence: 1 + C*2 + 1 == 8 solves to C=3."""
    shape_guard = ShapeGuard()
    tensor = torch.ones((1, 2, 8))
    shape_guard.guard(tensor, "A, B, A+C*2+1")
    expected = dict(A=1, B=2, C=3)
    assert shape_guard.dims == expected
def test_guard_infers_dimensions_complex():
    """Unknowns inside expressions are solved: B*2 == 2 and A+C == 3 with A=1."""
    shape_guard = ShapeGuard()
    tensor = torch.ones((1, 2, 3))
    shape_guard.guard(tensor, "A, B*2, A+C")
    expected = dict(A=1, B=1, C=2)
    assert shape_guard.dims == expected
def test_guard_infers_dimensions():
    """Plain names are bound to the sizes of the matching dims."""
    shape_guard = ShapeGuard()
    tensor = torch.ones((1, 2, 3))
    shape_guard.guard(tensor, "A, B, C")
    expected = dict(A=1, B=2, C=3)
    assert shape_guard.dims == expected
def test_guard_raises():
    """Literal sizes that disagree with the tensor raise ShapeError."""
    shape_guard = ShapeGuard()
    tensor = torch.ones((1, 2, 3))
    with pytest.raises(ShapeError):
        shape_guard.guard(tensor, "3, 2, 1")
def test_guard_ellipsis_infer_dims():
    """Names on either side of ``...`` bind to the leading and trailing dims."""
    shape_guard = ShapeGuard()
    tensor = torch.ones((1, 2, 3, 4, 5))
    shape_guard.guard(tensor, "A, B, ..., C")
    expected = dict(A=1, B=2, C=5)
    assert shape_guard.dims == expected