    def testEventShape(self):
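        """Test forward/inverse_event_shape with statically known shapes."""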
        shape_in_static = tensor_shape.TensorShape([2, 3])
        shape_out_static = tensor_shape.TensorShape([6])
        bijector = Reshape(event_shape_out=shape_out_static,
                           event_shape_in=shape_in_static,
                           validate_args=True)

        # test that forward_ and inverse_event_shape do sensible things
        # when shapes are statically known.
        self.assertEqual(bijector.forward_event_shape(shape_in_static),
                         shape_out_static)
        self.assertEqual(bijector.inverse_event_shape(shape_out_static),
                         shape_in_static)

    def _testInputOutputMismatchOpError(self, expected_error_message):
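        """Check that reshapes that change the element count raise errors."""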
        x1 = np.random.randn(4, 2, 3)
        x2 = np.random.randn(4, 1, 1, 5)

        with self.cached_session() as sess:
            shape_in, shape_out, fd_mismatched = self.build_shapes([2, 3],
                                                                   [1, 1, 5])
            bijector = Reshape(event_shape_out=shape_out,
                               event_shape_in=shape_in,
                               validate_args=True)

            with self.assertRaisesError(expected_error_message):
                sess.run(bijector.forward(x1), feed_dict=fd_mismatched)
            with self.assertRaisesError(expected_error_message):
                sess.run(bijector.inverse(x2), feed_dict=fd_mismatched)

    def testValidButNonMatchingInputPartiallySpecifiedOpError(self):
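        """Check that an input mismatching a partial event_shape_in raises."""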
        x = np.random.randn(4, 3, 2)

        with self.cached_session() as sess:
            shape_in, shape_out, feed_dict = self.build_shapes(
                [2, -1], [1, 6, 1])
            bijector = Reshape(event_shape_out=shape_out,
                               event_shape_in=shape_in,
                               validate_args=True)

            with self.assertRaisesError(
                    "Input `event_shape` does not match `event_shape_in`."):
                sess.run(bijector.forward(x), feed_dict=feed_dict)

    def _testInvalidDimensionsOpError(self, expected_error_message):
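        """Check that shape entries less than -1 raise an error."""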

        with self.cached_session() as sess:
            shape_in, shape_out, feed_dict = self.build_shapes(
                [2, 3], [1, 2, -2])
            bijector = Reshape(event_shape_out=shape_out,
                               event_shape_in=shape_in,
                               validate_args=True)

            with self.assertRaisesError(expected_error_message):
                sess.run(bijector.forward_event_shape_tensor(shape_in),
                         feed_dict=feed_dict)

    def testMultipleUnspecifiedDimensionsOpError(self):
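        """Check that a shape with more than one -1 entry raises an error."""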

        with self.cached_session() as sess:
            shape_in, shape_out, feed_dict = self.build_shapes(
                [2, 3], [4, -1, -1])
            bijector = Reshape(event_shape_out=shape_out,
                               event_shape_in=shape_in,
                               validate_args=True)

            with self.assertRaisesError(
                    "elements must have at most one `-1`."):
                sess.run(bijector.forward_event_shape_tensor(shape_in),
                         feed_dict=feed_dict)

    def testDefaultVectorShape(self):
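        """Test that event_shape_in defaults to a vector (rank-1) shape."""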
        expected_x = np.random.randn(4, 4)
        expected_y = np.reshape(expected_x, [4, 2, 2])
        with self.cached_session() as sess:
            _, shape_out, feed_dict = self.build_shapes([-1], [-1, 2])
            bijector = Reshape(shape_out, validate_args=True)
            x_, y_ = sess.run(
                (bijector.inverse(expected_y), bijector.forward(expected_x)),
                feed_dict=feed_dict)
            self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
            self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0)

    def testBothShapesPartiallySpecified(self):
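        """Test forward/inverse when both event shapes are partially known."""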
        expected_x = np.random.randn(4, 2, 3)
        expected_y = np.reshape(expected_x, [4, 3, 2])
        with self.cached_session() as sess:
            shape_in, shape_out, feed_dict = self.build_shapes(
                [-1, 3], [-1, 2])
            bijector = Reshape(event_shape_out=shape_out,
                               event_shape_in=shape_in,
                               validate_args=True)
            x_, y_ = sess.run(
                (bijector.inverse(expected_y), bijector.forward(expected_x)),
                feed_dict=feed_dict)
            self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
            self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0)

    def testValidButNonMatchingInputOpError(self):
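        """Check that a valid but non-matching input raises an error."""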
        x = np.random.randn(4, 3, 2)

        with self.cached_session() as sess:
            shape_in, shape_out, feed_dict = self.build_shapes(
                [2, 3], [1, 6, 1])
            bijector = Reshape(event_shape_out=shape_out,
                               event_shape_in=shape_in,
                               validate_args=True)

            # Here we pass in a tensor (x) whose shape is compatible with
            # the output shape, so tf.reshape will throw no error, but
            # whose shape doesn't match the expected input shape.
            with self.assertRaisesError(
                    "Input `event_shape` does not match `event_shape_in`."):
                sess.run(bijector.forward(x), feed_dict=feed_dict)

    def testScalarReshape(self):
        """Test reshaping to and from a scalar shape ()."""

        expected_x = np.random.randn(4, 3, 1)
        expected_y = np.reshape(expected_x, [4, 3])

        expected_x_scalar = np.random.randn(1, )
        expected_y_scalar = expected_x_scalar[0]

        shape_in, shape_out, feed_dict = self.build_shapes([], [1])
        with self.cached_session() as sess:
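            # Note: shape_in and shape_out are deliberately swapped below, so
            # the bijector reshapes a [1] event shape to a scalar shape ().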
            bijector = Reshape(event_shape_out=shape_in,
                               event_shape_in=shape_out,
                               validate_args=True)
            x_, y_, x_scalar_, y_scalar_ = sess.run(
                (bijector.inverse(expected_y),
                 bijector.forward(expected_x),
                 bijector.inverse(expected_y_scalar),
                 bijector.forward(expected_x_scalar)),
                feed_dict=feed_dict)
            self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
            self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0)
            self.assertAllClose(
                expected_y_scalar, y_scalar_, rtol=1e-6, atol=0)
            self.assertAllClose(
                expected_x_scalar, x_scalar_, rtol=1e-6, atol=0)

    def testBijectiveAndFinite(self):
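        """Test that the Reshape bijector is bijective and finite."""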
        x = np.random.randn(4, 2, 3)
        y = np.reshape(x, [4, 1, 2, 3])
        with self.cached_session():
            bijector = Reshape(event_shape_in=[2, 3],
                               event_shape_out=[1, 2, 3],
                               validate_args=True)
            assert_bijective_and_finite(
                bijector, x, y, event_ndims=2, rtol=1e-6, atol=0)

    def testEventShapeTensor(self):
        """Test event_shape_tensor methods when even ndims may be dynamic."""

        shape_in_static = [2, 3]
        shape_out_static = [6]
        shape_in, shape_out, feed_dict = self.build_shapes(
            shape_in_static, shape_out_static)
        bijector = Reshape(event_shape_out=shape_out,
                           event_shape_in=shape_in,
                           validate_args=True)

        # using the _tensor methods, we should always get a fully-specified
        # result since these are evaluated at graph runtime.
        with self.cached_session() as sess:
            shape_out_, shape_in_ = sess.run(
                (bijector.forward_event_shape_tensor(shape_in),
                 bijector.inverse_event_shape_tensor(shape_out)),
                feed_dict=feed_dict)
            self.assertAllEqual(shape_out_static, shape_out_)
            self.assertAllEqual(shape_in_static, shape_in_)

    def testBijector(self):
        """Do a basic sanity check of forward, inverse, jacobian."""
        expected_x = np.random.randn(4, 3, 2)
        expected_y = np.reshape(expected_x, [4, 6])

        with self.cached_session() as sess:
            shape_in, shape_out, feed_dict = self.build_shapes([3, 2], [6])
            bijector = Reshape(event_shape_out=shape_out,
                               event_shape_in=shape_in,
                               validate_args=True)
            x_, y_, fldj_, ildj_ = sess.run(
                (bijector.inverse(expected_y),
                 bijector.forward(expected_x),
                 bijector.forward_log_det_jacobian(expected_x, event_ndims=2),
                 bijector.inverse_log_det_jacobian(expected_y, event_ndims=2)),
                feed_dict=feed_dict)
            self.assertEqual("reshape", bijector.name)
            self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
            self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0)
            self.assertAllClose(0., fldj_, rtol=1e-6, atol=0)
            self.assertAllClose(0., ildj_, rtol=1e-6, atol=0)