Example #1
def test_guard_raises_inferred():
    sg = ShapeGuard()
    a = torch.ones([1, 2, 3])
    b = torch.ones([3, 2, 5])
    sg.guard(a, "A, B, C")
    with pytest.raises(ShapeError):
        sg.guard(b, "C, B, A")
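
The test snippets in this listing omit their imports. A minimal sketch of the header they assume is shown below; the exact export path of ShapeGuard and ShapeError within the shapeguard package is an assumption and may differ.

import pytest
import torch

# Assumed exports: ShapeGuard and ShapeError may be re-exported at the top
# level of the shapeguard package or live in a submodule; adjust as needed.
from shapeguard import ShapeGuard, ShapeError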
Example #2
def test_guard_infers_assign():
    sg = ShapeGuard()
    a = torch.ones([1, 2, 3])
    sg.guard(a, "A, D=B*2, A+C")
    assert sg.dims == {"A": 1, "B": 1, "C": 2, "D": 2}
    with pytest.raises(ShapeError):
        sg.guard(a, "1, E=D/2, 3")
Example #3
    def preprocess(self, data):
        sg = ShapeGuard(dims={
            "B": self.batch_size,
            "H": self.image_dim[0],
            "W": self.image_dim[1]
        })
        image = sg.guard(data["image"], "B, h, w, C")
        mask = sg.guard(data["mask"], "B, L, h, w, 1")

        # to float
        image = tf.cast(image, tf.float32) / 255.0
        mask = tf.cast(mask, tf.float32) / 255.0

        # crop
        if self.crop_region is not None:
            height_slice = slice(self.crop_region[0][0],
                                 self.crop_region[0][1])
            width_slice = slice(self.crop_region[1][0], self.crop_region[1][1])
            image = image[:, height_slice, width_slice, :]

            mask = mask[:, :, height_slice, width_slice, :]

        flat_mask, unflatten = flatten_all_but_last(mask, n_dims=3)

        # rescale
        size = tf.constant(self.image_dim,
                           dtype=tf.int32,
                           shape=[2],
                           verify_shape=True)
        image = tf.image.resize_images(image,
                                       size,
                                       method=tf.image.ResizeMethod.BILINEAR)
        mask = tf.image.resize_images(
            flat_mask, size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

        if self.grayscale:
            image = tf.reduce_mean(image, axis=-1, keepdims=True)

        output = {
            "image": sg.guard(image[:, None], "B, T, H, W, C"),
            "mask": sg.guard(unflatten(mask)[:, None], "B, T, L, H, W, 1"),
            "factors": self.preprocess_factors(data, sg),
        }

        if "visibility" in data:
            output["visibility"] = sg.guard(data["visibility"], "B, L")
        else:
            output["visibility"] = tf.ones(sg["B, L"], dtype=tf.float32)

        return output
Example #4
def test_guard_dynamic_shape():
    sg = ShapeGuard()
    with pytest.raises(ShapeError):
        sg.guard([None, 2, 3], "C, B, A")

    sg.guard([None, 2, 3], "?, B, A")
    sg.guard([1, 2, 3], "C?, B, A")
    sg.guard([None, 2, 3], "C?, B, A")
Example #5
    def preprocess(self, data):
        sg = ShapeGuard(dims={
            "B": self.batch_size,
            "H": self.image_dim[0],
            "W": self.image_dim[1]
        })
        image = sg.guard(data["image"], "B, h, w, C")

        # to float
        image = tf.cast(image, tf.float32) / 255.0

        # crop
        if self.crop_region is not None:
            height_slice = slice(self.crop_region[0][0],
                                 self.crop_region[0][1])
            width_slice = slice(self.crop_region[1][0], self.crop_region[1][1])
            image = image[:, height_slice, width_slice, :]

        # rescale
        size = tf.constant(self.image_dim,
                           dtype=tf.int32,
                           shape=[2],
                           verify_shape=True)
        image = tf.image.resize_images(image,
                                       size,
                                       method=tf.image.ResizeMethod.BILINEAR)

        if self.grayscale:
            image = tf.reduce_mean(image, axis=-1, keepdims=True)

        output = {
            "image": sg.guard(image[:, None], "B, T, H, W, C"),
        }

        return output
Example #6
def test_guard_ellipsis():
    sg = ShapeGuard()
    a = torch.ones([1, 2, 3, 4, 5])
    sg.guard(a, "...")
    sg.guard(a, "..., 5")
    sg.guard(a, "..., 4, 5")
    sg.guard(a, "1, ...")
    sg.guard(a, "1, 2, ...")
    sg.guard(a, "1, 2, ..., 4, 5")
    sg.guard(a, "1, 2, 3, ..., 4, 5")

    with pytest.raises(ShapeError):
        sg.guard(a, "1, 2, 3, 4, 5, 6,...")

    with pytest.raises(ShapeError):
        sg.guard(a, "..., 1, 2, 3, 4, 5, 6")
Example #7
def test_guard_ignores_underscore():
    sg = ShapeGuard()
    a = torch.ones([1, 2, 3])
    sg.guard(a, "_A, _b, 3")
    assert sg.dims == {}
Example #8
def test_guard_ignores_wildcard():
    sg = ShapeGuard()
    a = torch.ones([1, 2, 3])
    sg.guard(a, "*, *, 3")
    assert sg.dims == {}
Example #9
def test_guard_raises_complex():
    sg = ShapeGuard()
    a = torch.ones([1, 2, 3])
    with pytest.raises(ShapeError):
        sg.guard(a, "A, B, B")
Example #10
def test_guard_infers_dimensions_operator_priority():
    sg = ShapeGuard()
    a = torch.ones([1, 2, 8])
    sg.guard(a, "A, B, A+C*2+1")
    assert sg.dims == {"A": 1, "B": 2, "C": 3}
Example #11
def test_guard_infers_dimensions_complex():
    sg = ShapeGuard()
    a = torch.ones([1, 2, 3])
    sg.guard(a, "A, B*2, A+C")
    assert sg.dims == {"A": 1, "B": 1, "C": 2}
Example #12
def test_guard_infers_dimensions():
    sg = ShapeGuard()
    a = torch.ones([1, 2, 3])
    sg.guard(a, "A, B, C")
    assert sg.dims == {"A": 1, "B": 2, "C": 3}
Example #13
def test_guard_raises():
    sg = ShapeGuard()
    a = torch.ones([1, 2, 3])
    with pytest.raises(ShapeError):
        sg.guard(a, "3, 2, 1")
Example #14
def test_guard_ellipsis_infer_dims():
    sg = ShapeGuard()
    a = torch.ones([1, 2, 3, 4, 5])
    sg.guard(a, "A, B, ..., C")
    assert sg.dims == {"A": 1, "B": 2, "C": 5}