示例#1
0
def test_matches_wildcards():
    """Check `matches` with `*` wildcards against a rank-4 tensor."""
    checker = ShapeGuard()
    tensor = torch.ones([1, 2, 4, 8])

    # Four-entry templates with a wildcard in any position are expected to match.
    for template in ("1, 2, 4, *", "*, *, *, 8"):
        assert checker.matches(tensor, template)

    # Templates with fewer entries than the tensor has dimensions must not match.
    for template in ("*", "*, *, *"):
        assert not checker.matches(tensor, template)
示例#2
0
def test_guard_dynamic_shape():
    """Check how `guard` treats a shape whose leading entry is `None`."""
    checker = ShapeGuard()

    # A plain named dimension is expected to reject the unknown size.
    with pytest.raises(ShapeError):
        checker.guard([None, 2, 3], "C, B, A")

    # The `?` forms are expected to pass, for known and unknown sizes alike.
    accepted = (
        ([None, 2, 3], "?, B, A"),
        ([1, 2, 3], "C?, B, A"),
        ([None, 2, 3], "C?, B, A"),
    )
    for shape, template in accepted:
        checker.guard(shape, template)
示例#3
0
def test_matches_basic_numerical():
    """Check `matches` against purely numeric templates."""
    checker = ShapeGuard()
    tensor = torch.ones([1, 2, 3])

    # Integer literals and integral floats are both expected to match.
    for template in ("1, 2, 3", "1, 2.0, 3.0"):
        assert checker.matches(tensor, template)

    # A non-integral size is expected to raise rather than return False.
    with pytest.raises(ShapeError):
        assert checker.matches(tensor, "1, 2.0, 3.1")

    # Wrong value, too many entries, too few entries.
    for template in ("1, 2, 4", "1, 2, 3, 4", "1, 2"):
        assert not checker.matches(tensor, template)
示例#4
0
def test_guard_raises_inferred():
    """A later `guard` call must fail when it contradicts inferred dims."""
    checker = ShapeGuard()
    first = torch.ones([1, 2, 3])
    second = torch.ones([3, 2, 5])

    # Establishes the named dims from the first tensor.
    checker.guard(first, "A, B, C")

    # The second tensor's last size (5) is expected to clash with the first call.
    with pytest.raises(ShapeError):
        checker.guard(second, "C, B, A")
示例#5
0
def test_guard_infers_assign():
    """Check `NAME=expr` entries: the name is recorded alongside inferred dims."""
    checker = ShapeGuard()
    tensor = torch.ones([1, 2, 3])

    checker.guard(tensor, "A, D=B*2, A+C")
    expected = {"A": 1, "B": 1, "C": 2, "D": 2}
    assert checker.dims == expected

    # With D == 2 the middle entry evaluates to 1, but the tensor has size 2 there.
    with pytest.raises(ShapeError):
        checker.guard(tensor, "1, E=D/2, 3")
示例#6
0
    def preprocess(self, data):
        """Cast, optionally crop, resize and optionally grayscale an image batch.

        Args:
            data: mapping whose "image" entry is guarded as "B, h, w, C".

        Returns:
            A dict with one key, "image", guarded as "B, T, H, W, C". The
            "T" axis is the new axis inserted at position 1.
        """
        sg = ShapeGuard(dims={
            "B": self.batch_size,
            "H": self.image_dim[0],
            "W": self.image_dim[1]
        })
        image = sg.guard(data["image"], "B, h, w, C")

        # to float
        image = tf.cast(image, tf.float32) / 255.0

        # crop
        if self.crop_region is not None:
            height_slice = slice(self.crop_region[0][0],
                                 self.crop_region[0][1])
            width_slice = slice(self.crop_region[1][0], self.crop_region[1][1])
            # This method never loads a mask, so only the image is cropped.
            image = image[:, height_slice, width_slice, :]

        # rescale
        size = tf.constant(self.image_dim,
                           dtype=tf.int32,
                           shape=[2],
                           verify_shape=True)
        image = tf.image.resize_images(image,
                                       size,
                                       method=tf.image.ResizeMethod.BILINEAR)

        if self.grayscale:
            # Average over the last axis, keeping it with size 1.
            image = tf.reduce_mean(image, axis=-1, keepdims=True)

        output = {
            "image": sg.guard(image[:, None], "B, T, H, W, C"),
        }

        return output
示例#7
0
    def preprocess(self, data):
        """Prepare image, mask, factors and visibility tensors for one batch.

        Args:
            data: mapping with "image" (guarded as "B, h, w, C") and "mask"
                (guarded as "B, L, h, w, 1"); "visibility" is optional.

        Returns:
            A dict with "image" ("B, T, H, W, C"), "mask"
            ("B, T, L, H, W, 1"), "factors" (from `preprocess_factors`) and
            "visibility" ("B, L"; all ones when absent from `data`).
        """
        known_dims = {
            "B": self.batch_size,
            "H": self.image_dim[0],
            "W": self.image_dim[1],
        }
        shapes = ShapeGuard(dims=known_dims)
        raw_image = shapes.guard(data["image"], "B, h, w, C")
        raw_mask = shapes.guard(data["mask"], "B, L, h, w, 1")

        # Cast both inputs to float32 and divide by 255.
        image = tf.cast(raw_image, tf.float32) / 255.0
        mask = tf.cast(raw_mask, tf.float32) / 255.0

        # Optional crop, applied to the height/width axes of both tensors.
        region = self.crop_region
        if region is not None:
            rows = slice(region[0][0], region[0][1])
            cols = slice(region[1][0], region[1][1])
            image = image[:, rows, cols, :]
            mask = mask[:, :, rows, cols, :]

        # The flattened mask is what gets resized; `unflatten` is applied to
        # the resized result below.
        flat_mask, unflatten = flatten_all_but_last(mask, n_dims=3)

        # Resize both tensors to `self.image_dim`.
        target_size = tf.constant(
            self.image_dim, dtype=tf.int32, shape=[2], verify_shape=True)
        image = tf.image.resize_images(
            image, target_size, method=tf.image.ResizeMethod.BILINEAR)
        resized_mask = tf.image.resize_images(
            flat_mask, target_size,
            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

        if self.grayscale:
            # Average over the last axis, keeping it with size 1.
            image = tf.reduce_mean(image, axis=-1, keepdims=True)

        # Entries are added in a fixed order because each guard call runs
        # against the same ShapeGuard instance.
        output = {}
        output["image"] = shapes.guard(image[:, None], "B, T, H, W, C")
        output["mask"] = shapes.guard(
            unflatten(resized_mask)[:, None], "B, T, L, H, W, 1")
        output["factors"] = self.preprocess_factors(data, shapes)

        if "visibility" not in data:
            output["visibility"] = tf.ones(shapes["B, L"], dtype=tf.float32)
        else:
            output["visibility"] = shapes.guard(data["visibility"], "B, L")

        return output
示例#8
0
 def get_placeholders(self, batch_size=None):
     """Create float32 placeholders for image, mask, visibility and factors.

     Args:
         batch_size: value for the "B" dimension; `self.batch_size` is used
             when this is falsy.

     Returns:
         A dict with "image", "mask" and "visibility" placeholders, plus
         "factors", a dict with one placeholder per entry of `self.factors`.
     """
     shapes = ShapeGuard(
         dims={
             "B": batch_size or self.batch_size,
             "H": self.image_dim[0],
             "W": self.image_dim[1],
             "L": self.num_true_objects,
             "C": 3,
             "T": 1,
         })

     # Placeholders are created in this fixed order.
     float_templates = (
         ("image", "B, T, H, W, C"),
         ("mask", "B, T, L, H, W, 1"),
         ("visibility", "B, L"),
     )
     placeholders = {}
     for key, template in float_templates:
         placeholders[key] = tf.placeholder(dtype=tf.float32,
                                            shape=shapes[template])

     # Each factor gets its own dtype and a trailing dimension of its size.
     factor_placeholders = {}
     for name, (dtype, size) in self.factors:
         factor_shape = shapes["B, L, {}".format(size)]
         factor_placeholders[name] = tf.placeholder(dtype=dtype,
                                                    shape=factor_shape)
     placeholders["factors"] = factor_placeholders

     return placeholders
示例#9
0
def test_guard_infers_dimensions_operator_priority():
    """Inference must respect `*` binding tighter than `+`."""
    checker = ShapeGuard()
    tensor = torch.ones([1, 2, 8])

    checker.guard(tensor, "A, B, A+C*2+1")

    # With A == 1, 8 == A + C*2 + 1 holds for C == 3.
    inferred = checker.dims
    assert inferred == {"A": 1, "B": 2, "C": 3}
示例#10
0
def test_matches_named_dims():
    """Check `matches` against dims supplied to the constructor."""
    checker = ShapeGuard(dims={"N": 24, "Z": 16})
    tensor = torch.ones([24, 16])

    # Names and literals may be mixed freely.
    for template in ("N, Z", "24, Z"):
        assert checker.matches(tensor, template)

    # N is 24, so it cannot also describe the size-16 axis.
    assert not checker.matches(tensor, "N, N")
示例#11
0
def test_guard_ellipsis_infer_dims():
    """Named dims on either side of `...` are still inferred."""
    checker = ShapeGuard()
    tensor = torch.ones([1, 2, 3, 4, 5])

    checker.guard(tensor, "A, B, ..., C")

    # Only the named positions are expected to be recorded.
    inferred = checker.dims
    assert inferred == {"A": 1, "B": 2, "C": 5}
示例#12
0
def test_guard_ellipsis():
    """Check `guard` with `...` at the start, middle and end of a template."""
    checker = ShapeGuard()
    tensor = torch.ones([1, 2, 3, 4, 5])

    # Templates expected to pass for the rank-5 tensor.
    accepted = (
        "...",
        "..., 5",
        "..., 4, 5",
        "1, ...",
        "1, 2, ...",
        "1, 2, ..., 4, 5",
        "1, 2, 3, ..., 4, 5",
    )
    for template in accepted:
        checker.guard(tensor, template)

    # Six explicit entries exceed the tensor's rank on either side of `...`.
    rejected = (
        "1, 2, 3, 4, 5, 6,...",
        "..., 1, 2, 3, 4, 5, 6",
    )
    for template in rejected:
        with pytest.raises(ShapeError):
            checker.guard(tensor, template)
示例#13
0
def test_guard_ignores_wildcard():
    """`*` entries must not record anything in `dims`."""
    checker = ShapeGuard()
    tensor = torch.ones([1, 2, 3])

    checker.guard(tensor, "*, *, 3")

    recorded = checker.dims
    assert recorded == {}
示例#14
0
def test_guard_ignores_underscore():
    """Names starting with an underscore must not be recorded in `dims`."""
    checker = ShapeGuard()
    tensor = torch.ones([1, 2, 3])

    checker.guard(tensor, "_A, _b, 3")

    recorded = checker.dims
    assert recorded == {}
示例#15
0
def test_guard_raises():
    """`guard` must raise when literal sizes disagree with the tensor."""
    checker = ShapeGuard()
    tensor = torch.ones([1, 2, 3])

    # The template lists the sizes in reverse order.
    with pytest.raises(ShapeError):
        checker.guard(tensor, "3, 2, 1")
示例#16
0
def test_guard_raises_complex():
    """`guard` must raise when one name is used for two different sizes."""
    checker = ShapeGuard()
    tensor = torch.ones([1, 2, 3])

    # B would have to be both 2 and 3.
    with pytest.raises(ShapeError):
        checker.guard(tensor, "A, B, B")
示例#17
0
def test_matches_ignores_spaces():
    """Whitespace around template entries must not affect matching."""
    checker = ShapeGuard()
    tensor = torch.ones([1, 2, 3])

    # No spaces, irregular inner spaces, and trailing spaces.
    spacing_variants = ("1,2,3", "1 ,  2, 3   ", "1,  2,3 ")
    for template in spacing_variants:
        assert checker.matches(tensor, template)
示例#18
0
def test_guard_infers_dimensions():
    """Plain names in a template are bound to the corresponding sizes."""
    checker = ShapeGuard()
    tensor = torch.ones([1, 2, 3])

    checker.guard(tensor, "A, B, C")

    inferred = checker.dims
    assert inferred == {"A": 1, "B": 2, "C": 3}
示例#19
0
def test_guard_infers_dimensions_complex():
    """Names inside arithmetic expressions are solved for and recorded."""
    checker = ShapeGuard()
    tensor = torch.ones([1, 2, 3])

    checker.guard(tensor, "A, B*2, A+C")

    # B*2 == 2 gives B == 1; A+C == 3 with A == 1 gives C == 2.
    inferred = checker.dims
    assert inferred == {"A": 1, "B": 1, "C": 2}