Example #1
    def test_op_forward_pass(self, on_gpu, dtype, tol):
        """Compare snt.resampler's forward output with a numpy reference.

        Random data is resampled at the coordinates returned by
        ``_make_warp``; the result is checked against a per-batch,
        per-channel reference computed with ``_bilinearly_interpolate``.

        Args:
          on_gpu: passed as ``use_gpu`` to ``test_session``.
          dtype: TensorFlow dtype used for the data and warp placeholders.
          tol: used as both ``rtol`` and ``atol`` in the final comparison.
        """
        data_width = 7
        data_height = 9
        data_channels = 5
        warp_width = 4
        warp_height = 8
        batch_size = 10

        # NOTE(review): the last argument of this call was truncated in the
        # source; the numpy dtype matches how `data` is cast below.
        warp = _make_warp(batch_size, warp_height, warp_width,
                          dtype.as_numpy_dtype)
        data_shape = (batch_size, data_height, data_width, data_channels)
        data = np.random.rand(*data_shape).astype(dtype.as_numpy_dtype)

        with self.test_session(use_gpu=on_gpu, force_gpu=False) as sess:
            # The batch dimension is left unknown (None) in both placeholders.
            data_ph = tf.placeholder(dtype, shape=(None, ) + data.shape[1:])
            warp_ph = tf.placeholder(dtype, shape=(None, ) + warp.shape[1:])
            outputs = snt.resampler(data=data_ph, warp=warp_ph)
            # Static output shape: unknown batch, spatial size taken from the
            # warp, channel count taken from the data.
            self.assertEqual(outputs.get_shape().as_list(),
                             [None, warp_height, warp_width, data_channels])
            out = sess.run(outputs, feed_dict={data_ph: data, warp_ph: warp})

        # Generate reference output via bilinear interpolation in numpy.
        # Channel 0 of the warp is passed as the first coordinate and
        # channel 1 as the second.
        reference_output = np.zeros_like(out)
        for batch in xrange(batch_size):
            for c in xrange(data_channels):
                reference_output[batch, :, :, c] = _bilinearly_interpolate(
                    data[batch, :, :, c], warp[batch, :, :, 0],
                    warp[batch, :, :, 1])

        self.assertAllClose(out, reference_output, rtol=tol, atol=tol)
Example #2
    def test_op_backward_pass(self, on_gpu, dtype, tol):
        """Check snt.resampler's analytic gradients against numerical ones.

        Gradients are taken w.r.t. both the data and the warp tensor. On CPU
        the analytic gradient at ``dtype`` is compared with a float64
        numerical estimate; otherwise it is compared with the numerical
        estimate at ``dtype``.

        Args:
          on_gpu: passed as ``use_gpu`` to ``test_session``; also selects
            which comparison is performed.
          dtype: TensorFlow dtype whose numpy counterpart is used for the
            data and warp arrays.
          tol: upper bound on the max absolute gradient difference.
        """
        data_width = 5
        data_height = 4
        data_channels = 3
        warp_width = 2
        warp_height = 6
        batch_size = 10

        # NOTE(review): the last argument of this call was truncated in the
        # source; the numpy dtype matches how `data` is cast below.
        warp = _make_warp(batch_size, warp_height, warp_width,
                          dtype.as_numpy_dtype)
        data_shape = (batch_size, data_height, data_width, data_channels)
        data = np.random.rand(*data_shape).astype(dtype.as_numpy_dtype)

        with self.test_session(use_gpu=on_gpu, force_gpu=False):
            data_tensor = tf.constant(data)
            warp_tensor = tf.constant(warp)
            output_tensor = snt.resampler(data=data_tensor, warp=warp_tensor)

            # Shapes of the inputs we differentiate with respect to; the
            # float64 copies below share these shapes.
            input_shapes = [data_tensor.get_shape().as_list(),
                            warp_tensor.get_shape().as_list()]
            # One (theoretical, numerical) Jacobian pair per input tensor.
            grads = tf.test.compute_gradient(
                [data_tensor, warp_tensor], input_shapes, output_tensor,
                output_tensor.get_shape().as_list(), [data, warp])

            if not on_gpu:
                # On CPU we perform numerical differentiation at the best available
                # precision, and compare against that. This is necessary for test to
                # pass for float16.
                data_tensor_64 = tf.constant(data, dtype=tf.float64)
                warp_tensor_64 = tf.constant(warp, dtype=tf.float64)
                output_tensor_64 = snt.resampler(data=data_tensor_64,
                                                 warp=warp_tensor_64)
                grads_64 = tf.test.compute_gradient(
                    [data_tensor_64, warp_tensor_64], input_shapes,
                    output_tensor_64,
                    output_tensor.get_shape().as_list(), [data, warp])

                # Analytic gradient at `dtype` vs. float64 numerical estimate.
                for g, g_64 in zip(grads, grads_64):
                    self.assertLess(np.fabs(g[0] - g_64[1]).max(), tol)
            else:
                # Analytic vs. numerical gradient, both at `dtype`.
                for g in grads:
                    self.assertLess(np.fabs(g[0] - g[1]).max(), tol)
def spatial_transformer(img_tensor, transform_params, crop_size):
    """Extract affine-warped crops (glimpses) from a batch of images.

    :param img_tensor: tf.Tensor of shape (batch_size, Height, Width) or
        (batch_size, Height, Width, channels). A rank-3 tensor is treated as
        single-channel and gets a trailing channel axis before resampling.
    :param transform_params: tf.Tensor of size (batch_size, 4), where params
        are (scale_y, shift_y, scale_x, shift_x)
    :param crop_size: tuple of 2 ints, size of the resulting crop
    :return: the tensor produced by ``snt.resampler`` — presumably
        (batch_size, crop_size[0], crop_size[1], channels); confirm.
    """
    constraints = snt.AffineWarpConstraints.no_shear_2d()
    # The warper needs only the 2-D spatial extent (Height, Width): [1:3]
    # drops the batch axis and, for rank-4 input, the channel axis.
    img_size = img_tensor.shape.as_list()[1:3]
    warper = snt.AffineGridWarper(img_size, crop_size, constraints)
    grid_coords = warper(transform_params)
    if len(img_tensor.shape) == 3:
        # The resampler takes data with a trailing channel axis.
        img_tensor = img_tensor[..., tf.newaxis]
    glimpse = snt.resampler(img_tensor, grid_coords)
    return glimpse
Example #4
    def _build(self, img, transform_params):
        """Resample ``img`` on the grid produced by ``self._warper``.

        Args:
          img: image tensor; when it has exactly 3 dimensions a trailing
            axis is appended (``img[..., tf.newaxis]``) before resampling.
          transform_params: forwarded unchanged to ``self._warper``.

        Returns:
          The result of ``snt.resampler`` on the (possibly expanded) image
          and the warper's grid coordinates.
        """
        source = img[..., tf.newaxis] if len(img.get_shape()) == 3 else img
        coords = self._warper(transform_params)
        return snt.resampler(source, coords)
Example #5
    def test_op_errors(self):
        data_width = 7
        data_height = 9
        data_depth = 3
        data_channels = 5
        warp_width = 4
        warp_height = 8
        batch_size = 10

        # Input data shape is not defined over a 2D grid, i.e. its shape is not like
        # (batch_size, data_height, data_width, data_channels).
        with self.test_session() as sess:
            data_shape = (batch_size, data_height, data_width, data_depth,
            data = np.zeros(data_shape)
            warp_shape = (batch_size, warp_height, warp_width, 2)
            warp = np.zeros(warp_shape)
            outputs = snt.resampler(tf.constant(data), tf.constant(warp))

            with self.assertRaisesRegexp(
                    "Only bilinear interpolation is currently "

        # Warp tensor must be at least a matrix, with shape [batch_size, 2].
        with self.test_session() as sess:
            data_shape = (batch_size, data_height, data_width, data_channels)
            data = np.zeros(data_shape)
            warp_shape = (batch_size, )
            warp = np.zeros(warp_shape)
            outputs = snt.resampler(tf.constant(data), tf.constant(warp))

            with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
                                         "warp should be at least a matrix"):

        # The batch size of the data and warp tensors must be the same.
        with self.test_session() as sess:
            data_shape = (batch_size, data_height, data_width, data_channels)
            data = np.zeros(data_shape)
            warp_shape = (batch_size + 1, warp_height, warp_width, 2)
            warp = np.zeros(warp_shape)
            outputs = snt.resampler(tf.constant(data), tf.constant(warp))

            with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
                                         "Batch size of data and warp tensor"):

        # The warp tensor must contain 2D coordinates, i.e. its shape last dimension
        # must be 2.
        with self.test_session() as sess:
            data_shape = (batch_size, data_height, data_width, data_channels)
            data = np.zeros(data_shape)
            warp_shape = (batch_size, warp_height, warp_width, 3)
            warp = np.zeros(warp_shape)
            outputs = snt.resampler(tf.constant(data), tf.constant(warp))

            with self.assertRaisesRegexp(
                    "Only bilinear interpolation is supported, "