def testEventShapeDynamicNdims(self):
        """Check the *_event_shape_tensor methods when even ndims are dynamic.

        Both event shapes arrive through placeholders declared without a
        static `shape`, so neither their entries nor their lengths are known
        while the graph is being built.
        """
        static_in = tensor_shape.TensorShape([6])
        in_ph = array_ops.placeholder(dtype=dtypes.int32)

        static_out = tensor_shape.TensorShape([2, 3])
        out_ph = array_ops.placeholder(dtype=dtypes.int32)

        bijector = Reshape(event_shape_out=out_ph,
                           event_shape_in=in_ph,
                           validate_args=True)

        feeds = {in_ph: static_in, out_ph: static_out}

        # The *_tensor methods are evaluated when the graph runs, so once the
        # placeholders are fed the results must be fully specified shapes.
        with self.test_session() as sess:
            forward_t = bijector.forward_event_shape_tensor(static_in)
            inverse_t = bijector.inverse_event_shape_tensor(static_out)
            got_out, got_in = sess.run((forward_t, inverse_t),
                                       feed_dict=feeds)
            self.assertAllEqual(static_out, got_out)
            self.assertAllEqual(static_in, got_in)
Beispiel #2
0
  def testBijector(self):
    """Sanity-check forward, inverse and both log-det-Jacobians.

    An event of shape [3, 2] is reshaped to [6]. The test asserts that both
    directions round-trip and that both log-det-Jacobians are zero.
    """
    expected_x = np.random.randn(4, 3, 2)
    expected_y = np.reshape(expected_x, [4, 6])

    with self.cached_session() as sess:
      shape_in, shape_out, feed_dict = self.build_shapes([3, 2], [6,])
      bijector = Reshape(
          event_shape_in=shape_in,
          event_shape_out=shape_out,
          validate_args=True)
      # NOTE(review): y's event shape is [6] (rank 1), yet the inverse
      # log-det-Jacobian is requested with event_ndims=2 -- confirm intent.
      fetches = [
          bijector.inverse(expected_y),
          bijector.forward(expected_x),
          bijector.forward_log_det_jacobian(expected_x, event_ndims=2),
          bijector.inverse_log_det_jacobian(expected_y, event_ndims=2),
      ]
      x_, y_, fldj_, ildj_ = sess.run(fetches, feed_dict=feed_dict)

      self.assertEqual("reshape", bijector.name)
      tolerances = {"rtol": 1e-6, "atol": 0}
      self.assertAllClose(expected_y, y_, **tolerances)
      self.assertAllClose(expected_x, x_, **tolerances)
      self.assertAllClose(0., fldj_, **tolerances)
      self.assertAllClose(0., ildj_, **tolerances)
Beispiel #3
0
  def testScalarReshape(self):
    """Reshape between a length-1 vector event and a scalar event ()."""
    expected_x = np.random.randn(4, 3, 1)
    expected_y = np.reshape(expected_x, [4, 3])

    expected_x_scalar = np.random.randn(1,)
    expected_y_scalar = expected_x_scalar[0]

    # build_shapes([], [1]) yields the scalar shape first. The two are passed
    # deliberately swapped, so the bijector consumes a [1] event and produces
    # a scalar () event.
    scalar_shape, vector_shape, feed_dict = self.build_shapes([], [1,])
    with self.cached_session() as sess:
      bijector = Reshape(
          event_shape_out=scalar_shape,
          event_shape_in=vector_shape, validate_args=True)
      batched = (bijector.inverse(expected_y), bijector.forward(expected_x))
      unbatched = (bijector.inverse(expected_y_scalar),
                   bijector.forward(expected_x_scalar))
      x_, y_, x_scalar_, y_scalar_ = sess.run(batched + unbatched,
                                              feed_dict=feed_dict)
      for want, got in ((expected_y, y_),
                        (expected_x, x_),
                        (expected_y_scalar, y_scalar_),
                        (expected_x_scalar, x_scalar_)):
        self.assertAllClose(want, got, rtol=1e-6, atol=0)
Beispiel #4
0
  def testMultipleUnspecifiedDimensionsOpError(self):
    """An `event_shape_out` containing two `-1` entries must be rejected."""
    with self.cached_session() as sess:
      shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [4, -1, -1,])
      reshape = Reshape(
          event_shape_in=shape_in,
          event_shape_out=shape_out,
          validate_args=True)

      expected_message = "elements must have at most one `-1`."
      # The op is built inside the context, so assertRaisesError would also
      # see an error raised while building it (not only while running it).
      with self.assertRaisesError(expected_message):
        sess.run(reshape.forward_event_shape_tensor(shape_in),
                 feed_dict=feed_dict)
Beispiel #5
0
  def _testInvalidDimensionsOpError(self, expected_error_message):
    """Shared check: an `event_shape_out` entry of -2 must be rejected.

    Args:
      expected_error_message: Message passed to `self.assertRaisesError`.
    """
    with self.cached_session() as sess:
      shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 2, -2,])
      reshape = Reshape(
          event_shape_in=shape_in,
          event_shape_out=shape_out,
          validate_args=True)

      with self.assertRaisesError(expected_error_message):
        forward_shape = reshape.forward_event_shape_tensor(shape_in)
        sess.run(forward_shape, feed_dict=feed_dict)
  def _testInvalidDimensionsOpError(self, expected_error_message):
    """Shared check: an `event_shape_out` entry of -2 must be rejected.

    Args:
      expected_error_message: Message passed to `self.assertRaisesError`.
    """
    with self.test_session() as sess:
      shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 2, -2,])
      reshape = Reshape(
          event_shape_in=shape_in,
          event_shape_out=shape_out,
          validate_args=True)

      with self.assertRaisesError(expected_error_message):
        forward_shape = reshape.forward_event_shape_tensor(shape_in)
        sess.run(forward_shape, feed_dict=feed_dict)
  def testMultipleUnspecifiedDimensionsOpError(self):
    """An `event_shape_out` containing two `-1` entries must be rejected."""
    with self.test_session() as sess:
      shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [4, -1, -1,])
      reshape = Reshape(
          event_shape_in=shape_in,
          event_shape_out=shape_out,
          validate_args=True)

      expected_message = "elements must have at most one `-1`."
      # The op is built inside the context, so assertRaisesError would also
      # see an error raised while building it (not only while running it).
      with self.assertRaisesError(expected_message):
        sess.run(reshape.forward_event_shape_tensor(shape_in),
                 feed_dict=feed_dict)
  def testValidButNonMatchingInputPartiallySpecifiedOpError(self):
    """Reject input whose event shape misses a partial `event_shape_in`.

    `x`'s trailing dims are [3, 2]. They cannot match `event_shape_in`
    [2, -1] because the leading entry differs.
    """
    x = np.random.randn(4, 3, 2)

    with self.test_session() as sess:
      shape_in, shape_out, feed_dict = self.build_shapes([2, -1], [1, 6, 1,])
      reshape = Reshape(
          event_shape_in=shape_in,
          event_shape_out=shape_out,
          validate_args=True)

      expected_message = "Input `event_shape` does not match `event_shape_in`."
      with self.assertRaisesError(expected_message):
        sess.run(reshape.forward(x), feed_dict=feed_dict)
  def testInvalidDimensionsOpError(self):
    """An `event_shape_out` entry of -2 is neither positive nor `-1`."""
    with self.test_session() as sess:
      shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 2, -2,])
      reshape = Reshape(
          event_shape_in=shape_in,
          event_shape_out=shape_out,
          validate_args=True)

      expected_message = "elements must be either positive integers or `-1`."
      with self.assertRaisesError(expected_message):
        sess.run(reshape.forward_event_shape_tensor(shape_in),
                 feed_dict=feed_dict)
Beispiel #10
0
  def testValidButNonMatchingInputPartiallySpecifiedOpError(self):
    """Reject input whose event shape misses a partial `event_shape_in`.

    `x`'s trailing dims are [3, 2]. They cannot match `event_shape_in`
    [2, -1] because the leading entry differs.
    """
    x = np.random.randn(4, 3, 2)

    with self.cached_session() as sess:
      shape_in, shape_out, feed_dict = self.build_shapes([2, -1], [1, 6, 1,])
      reshape = Reshape(
          event_shape_in=shape_in,
          event_shape_out=shape_out,
          validate_args=True)

      expected_message = "Input `event_shape` does not match `event_shape_in`."
      with self.assertRaisesError(expected_message):
        sess.run(reshape.forward(x), feed_dict=feed_dict)
Beispiel #11
0
  def testEventShape(self):
    """Static forward/inverse event shapes for a fully known [2, 3] <-> [6]."""
    static_in = tensor_shape.TensorShape([2, 3])
    static_out = tensor_shape.TensorShape([6,])
    reshape = Reshape(
        event_shape_in=static_in,
        event_shape_out=static_out,
        validate_args=True)

    # With both shapes known statically, the non-tensor shape methods should
    # map each shape exactly onto the other.
    forward_shape = reshape.forward_event_shape(static_in)
    inverse_shape = reshape.inverse_event_shape(static_out)
    self.assertEqual(forward_shape, static_out)
    self.assertEqual(inverse_shape, static_in)
Beispiel #12
0
    def testEventShape(self):
        """Static forward/inverse event shapes for a fully known [2, 3] <-> [6]."""
        static_in = tensor_shape.TensorShape([2, 3])
        static_out = tensor_shape.TensorShape([6])
        reshape = Reshape(event_shape_in=static_in,
                          event_shape_out=static_out,
                          validate_args=True)

        # With both shapes known statically, the non-tensor shape methods
        # should map each shape exactly onto the other.
        forward_shape = reshape.forward_event_shape(static_in)
        inverse_shape = reshape.inverse_event_shape(static_out)
        self.assertEqual(forward_shape, static_out)
        self.assertEqual(inverse_shape, static_in)
 def testDefaultVectorShape(self):
   """Round-trip when only the output event shape is supplied.

   `shape_out` is passed positionally (presumably binding to
   `event_shape_out`; confirm). `event_shape_in` is left at Reshape's
   default, which is assumed to be the vector shape [-1] that this test
   discards.
   """
   expected_x = np.random.randn(4, 4)
   expected_y = np.reshape(expected_x, [4, 2, 2])
   with self.test_session() as sess:
     _, shape_out, feed_dict = self.build_shapes([-1,], [-1, 2])
     bijector = Reshape(shape_out, validate_args=True)
     inverse_y = bijector.inverse(expected_y)
     forward_x = bijector.forward(expected_x)
     x_, y_ = sess.run((inverse_y, forward_x), feed_dict=feed_dict)
     self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
     self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0)
Beispiel #14
0
 def testDefaultVectorShape(self):
   """Round-trip when only the output event shape is supplied.

   `shape_out` is passed positionally (presumably binding to
   `event_shape_out`; confirm). `event_shape_in` is left at Reshape's
   default, which is assumed to be the vector shape [-1] that this test
   discards.
   """
   expected_x = np.random.randn(4, 4)
   expected_y = np.reshape(expected_x, [4, 2, 2])
   with self.cached_session() as sess:
     _, shape_out, feed_dict = self.build_shapes([-1,], [-1, 2])
     bijector = Reshape(shape_out, validate_args=True)
     inverse_y = bijector.inverse(expected_y)
     forward_x = bijector.forward(expected_x)
     x_, y_ = sess.run((inverse_y, forward_x), feed_dict=feed_dict)
     self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
     self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0)
Beispiel #15
0
    def _testInputOutputMismatchOpError(self, expected_error_message):
        """Shared check: forward and inverse both fail on mismatched sizes.

        `event_shape_in` [2, 3] holds 6 elements while `event_shape_out`
        [1, 1, 5] holds 5, so neither direction can reshape its input.

        Args:
          expected_error_message: Message passed to `self.assertRaisesError`.
        """
        forward_input = np.random.randn(4, 2, 3)
        inverse_input = np.random.randn(4, 1, 1, 5)

        with self.test_session() as sess:
            shape_in, shape_out, fd_mismatched = self.build_shapes(
                [2, 3], [1, 1, 5])
            bijector = Reshape(event_shape_in=shape_in,
                               event_shape_out=shape_out,
                               validate_args=True)

            # Forward first, then inverse; each op is built inside its own
            # assertion context.
            for method, value in ((bijector.forward, forward_input),
                                  (bijector.inverse, inverse_input)):
                with self.assertRaisesError(expected_error_message):
                    sess.run(method(value), feed_dict=fd_mismatched)
Beispiel #16
0
  def _testInputOutputMismatchOpError(self, expected_error_message):
    """Shared check: forward and inverse both fail on mismatched sizes.

    `event_shape_in` [2, 3] holds 6 elements while `event_shape_out`
    [1, 1, 5] holds 5, so neither direction can reshape its input.

    Args:
      expected_error_message: Message passed to `self.assertRaisesError`.
    """
    forward_input = np.random.randn(4, 2, 3)
    inverse_input = np.random.randn(4, 1, 1, 5)

    with self.cached_session() as sess:
      shape_in, shape_out, fd_mismatched = self.build_shapes([2, 3],
                                                             [1, 1, 5])
      bijector = Reshape(
          event_shape_in=shape_in,
          event_shape_out=shape_out,
          validate_args=True)

      # Forward first, then inverse; each op is built inside its own
      # assertion context.
      for method, value in ((bijector.forward, forward_input),
                            (bijector.inverse, inverse_input)):
        with self.assertRaisesError(expected_error_message):
          sess.run(method(value), feed_dict=fd_mismatched)
Beispiel #17
0
    def testInputOutputMismatchOpError(self):
        """Forward and inverse both fail when event sizes disagree (6 vs 5)."""
        forward_input = np.random.randn(4, 2, 3)
        inverse_input = np.random.randn(4, 1, 1, 5)

        with self.test_session() as sess:
            shape_in, shape_out, fd_mismatched = self.build_shapes(
                [2, 3], [1, 1, 5])
            bijector = Reshape(event_shape_in=shape_in,
                               event_shape_out=shape_out,
                               validate_args=True)

            # test that *all* methods check basic assertions
            expected_message = "Input to reshape is a tensor with"
            for method, value in ((bijector.forward, forward_input),
                                  (bijector.inverse, inverse_input)):
                with self.assertRaisesError(expected_message):
                    sess.run(method(value), feed_dict=fd_mismatched)
Beispiel #18
0
 def testBothShapesPartiallySpecified(self):
   """Round-trip when both event shapes contain a `-1` wildcard."""
   expected_x = np.random.randn(4, 2, 3)
   expected_y = np.reshape(expected_x, [4, 3, 2])
   with self.cached_session() as sess:
     shape_in, shape_out, feed_dict = self.build_shapes([-1, 3], [-1, 2])
     bijector = Reshape(
         event_shape_in=shape_in,
         event_shape_out=shape_out,
         validate_args=True)
     fetches = (bijector.inverse(expected_y), bijector.forward(expected_x))
     x_, y_ = sess.run(fetches, feed_dict=feed_dict)
     for want, got in ((expected_y, y_), (expected_x, x_)):
       self.assertAllClose(want, got, rtol=1e-6, atol=0)
Beispiel #19
0
  def testValidButNonMatchingInputOpError(self):
    """Input with the right event size but the wrong event shape.

    `x`'s [3, 2] event has the same element count as `event_shape_out`
    [1, 6, 1], so tf.reshape itself would not complain. It still differs from
    `event_shape_in` [2, 3], which the bijector must detect.
    """
    x = np.random.randn(4, 3, 2)

    with self.cached_session() as sess:
      shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 6, 1,])
      reshape = Reshape(
          event_shape_in=shape_in,
          event_shape_out=shape_out,
          validate_args=True)

      expected_message = "Input `event_shape` does not match `event_shape_in`."
      with self.assertRaisesError(expected_message):
        sess.run(reshape.forward(x), feed_dict=feed_dict)
  def testValidButNonMatchingInputOpError(self):
    """Input with the right event size but the wrong event shape.

    `x`'s [3, 2] event has the same element count as `event_shape_out`
    [1, 6, 1], so tf.reshape itself would not complain. It still differs from
    `event_shape_in` [2, 3], which the bijector must detect.
    """
    x = np.random.randn(4, 3, 2)

    with self.test_session() as sess:
      shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 6, 1,])
      reshape = Reshape(
          event_shape_in=shape_in,
          event_shape_out=shape_out,
          validate_args=True)

      expected_message = "Input `event_shape` does not match `event_shape_in`."
      with self.assertRaisesError(expected_message):
        sess.run(reshape.forward(x), feed_dict=feed_dict)
 def testBothShapesPartiallySpecified(self):
   """Round-trip when both event shapes contain a `-1` wildcard."""
   expected_x = np.random.randn(4, 2, 3)
   expected_y = np.reshape(expected_x, [4, 3, 2])
   with self.test_session() as sess:
     shape_in, shape_out, feed_dict = self.build_shapes([-1, 3], [-1, 2])
     bijector = Reshape(
         event_shape_in=shape_in,
         event_shape_out=shape_out,
         validate_args=True)
     fetches = (bijector.inverse(expected_y), bijector.forward(expected_x))
     x_, y_ = sess.run(fetches, feed_dict=feed_dict)
     for want, got in ((expected_y, y_), (expected_x, x_)):
       self.assertAllClose(want, got, rtol=1e-6, atol=0)
Beispiel #22
0
 def testBijectiveAndFinite(self):
     """Run the generic bijective/finite check on a [2, 3] -> [1, 2, 3] map."""
     x = np.random.randn(4, 2, 3)
     y = np.reshape(x, [4, 1, 2, 3])
     with self.test_session():
         reshape_bijector = Reshape(event_shape_in=[2, 3],
                                    event_shape_out=[1, 2, 3],
                                    validate_args=True)
         assert_bijective_and_finite(reshape_bijector, x, y,
                                     rtol=1e-6, atol=0)
Beispiel #23
0
    def testInvalidDimensionsOpError(self):
        """An `event_shape_out` entry of -2 is neither positive nor `-1`."""
        with self.test_session() as sess:
            shape_in, shape_out, feed_dict = self.build_shapes([2, 3],
                                                               [1, 2, -2])
            reshape = Reshape(event_shape_in=shape_in,
                              event_shape_out=shape_out,
                              validate_args=True)

            expected_message = (
                "elements must be either positive integers or `-1`.")
            with self.assertRaisesError(expected_message):
                forward_shape = reshape.forward_event_shape_tensor(shape_in)
                sess.run(forward_shape, feed_dict=feed_dict)
  def testInputOutputMismatchOpError(self):
    """Forward and inverse both fail when event sizes disagree (6 vs 5)."""
    forward_input = np.random.randn(4, 2, 3)
    inverse_input = np.random.randn(4, 1, 1, 5)

    with self.test_session() as sess:
      shape_in, shape_out, fd_mismatched = self.build_shapes([2, 3],
                                                             [1, 1, 5])
      bijector = Reshape(
          event_shape_in=shape_in,
          event_shape_out=shape_out,
          validate_args=True)

      # test that *all* methods check basic assertions
      expected_message = "Input to reshape is a tensor with"
      for method, value in ((bijector.forward, forward_input),
                            (bijector.inverse, inverse_input)):
        with self.assertRaisesError(expected_message):
          sess.run(method(value), feed_dict=fd_mismatched)
  def testOneShapePartiallySpecified(self):
    """Round-trip when only `event_shape_in` contains a `-1` wildcard."""
    expected_x = np.random.randn(4, 6)
    expected_y = np.reshape(expected_x, [4, 2, 3])

    with self.test_session() as sess:
      shape_in, shape_out, feed_dict = self.build_shapes([-1,], [2, 3])
      bijector = Reshape(
          event_shape_in=shape_in,
          event_shape_out=shape_out,
          validate_args=True)
      x_, y_ = sess.run(
          (bijector.inverse(expected_y), bijector.forward(expected_x)),
          feed_dict=feed_dict)
      for want, got in ((expected_y, y_), (expected_x, x_)):
        self.assertAllClose(want, got, rtol=1e-6, atol=0)
Beispiel #26
0
  def testEventShapeTensor(self):
    """Test event_shape_tensor methods when event ndims may be dynamic."""
    static_in = [2, 3]
    static_out = [6,]
    shape_in, shape_out, feed_dict = self.build_shapes(static_in, static_out)
    bijector = Reshape(
        event_shape_in=shape_in,
        event_shape_out=shape_out,
        validate_args=True)

    # The *_tensor methods are evaluated when the graph runs, so they should
    # always produce fully specified shapes.
    with self.cached_session() as sess:
      fetches = (bijector.forward_event_shape_tensor(shape_in),
                 bijector.inverse_event_shape_tensor(shape_out))
      forward_shape, inverse_shape = sess.run(fetches, feed_dict=feed_dict)
      self.assertAllEqual(static_out, forward_shape)
      self.assertAllEqual(static_in, inverse_shape)
  def testEventShapeTensor(self):
    """Test event_shape_tensor methods when event ndims may be dynamic."""
    static_in = [2, 3]
    static_out = [6,]
    shape_in, shape_out, feed_dict = self.build_shapes(static_in, static_out)
    bijector = Reshape(
        event_shape_in=shape_in,
        event_shape_out=shape_out,
        validate_args=True)

    # The *_tensor methods are evaluated when the graph runs, so they should
    # always produce fully specified shapes.
    with self.test_session() as sess:
      fetches = (bijector.forward_event_shape_tensor(shape_in),
                 bijector.inverse_event_shape_tensor(shape_out))
      forward_shape, inverse_shape = sess.run(fetches, feed_dict=feed_dict)
      self.assertAllEqual(static_out, forward_shape)
      self.assertAllEqual(static_in, inverse_shape)
Beispiel #28
0
    def testScalarReshape(self):
        """Reshape between a length-1 vector event and a scalar event ()."""
        expected_x = np.random.randn(4, 3, 1)
        expected_y = np.reshape(expected_x, [4, 3])

        expected_x_scalar = np.random.randn(1)
        expected_y_scalar = expected_x_scalar[0]

        # build_shapes([], [1]) yields the scalar shape first. The two are
        # passed deliberately swapped, so the bijector consumes a [1] event
        # and produces a scalar () event.
        scalar_shape, vector_shape, feed_dict = self.build_shapes([], [1])
        with self.test_session() as sess:
            bijector = Reshape(event_shape_out=scalar_shape,
                               event_shape_in=vector_shape,
                               validate_args=True)
            batched = (bijector.inverse(expected_y),
                       bijector.forward(expected_x))
            unbatched = (bijector.inverse(expected_y_scalar),
                         bijector.forward(expected_x_scalar))
            x_, y_, x_scalar_, y_scalar_ = sess.run(batched + unbatched,
                                                    feed_dict=feed_dict)
            for want, got in ((expected_y, y_),
                              (expected_x, x_),
                              (expected_y_scalar, y_scalar_),
                              (expected_x_scalar, x_scalar_)):
                self.assertAllClose(want, got, rtol=1e-6, atol=0)
  def testBijector(self):
    """Sanity-check forward, inverse and both log-det-Jacobians.

    An event of shape [3, 2] is reshaped to [6]. The test asserts that both
    directions round-trip and that both log-det-Jacobians are zero.
    """
    expected_x = np.random.randn(4, 3, 2)
    expected_y = np.reshape(expected_x, [4, 6])

    with self.test_session() as sess:
      shape_in, shape_out, feed_dict = self.build_shapes([3, 2], [6,])
      bijector = Reshape(
          event_shape_in=shape_in,
          event_shape_out=shape_out,
          validate_args=True)
      fetches = [
          bijector.inverse(expected_y),
          bijector.forward(expected_x),
          bijector.forward_log_det_jacobian(expected_x),
          bijector.inverse_log_det_jacobian(expected_y),
      ]
      x_, y_, fldj_, ildj_ = sess.run(fetches, feed_dict=feed_dict)

      self.assertEqual("reshape", bijector.name)
      tolerances = {"rtol": 1e-6, "atol": 0}
      self.assertAllClose(expected_y, y_, **tolerances)
      self.assertAllClose(expected_x, x_, **tolerances)
      self.assertAllClose(0., fldj_, **tolerances)
      self.assertAllClose(0., ildj_, **tolerances)
Beispiel #30
0
    def testEventShapeStatic(self):
        """Check every shape method when both event shapes are static."""
        shape_in = tensor_shape.TensorShape([6])
        shape_out = tensor_shape.TensorShape([2, 3])

        reshape = Reshape(event_shape_in=shape_in,
                          event_shape_out=shape_out,
                          validate_args=True)

        # Static shape methods: each shape should map exactly onto the other.
        self.assertEqual(reshape.forward_event_shape(shape_in), shape_out)
        self.assertEqual(reshape.inverse_event_shape(shape_out), shape_in)

        # The tensor variants are run without a feed_dict, because no
        # placeholder is involved.
        with self.test_session() as sess:
            fetches = (reshape.forward_event_shape_tensor(shape_in),
                       reshape.inverse_event_shape_tensor(shape_out))
            forward_value, inverse_value = sess.run(fetches)
            self.assertAllEqual(shape_out, forward_value)
            self.assertAllEqual(shape_in, inverse_value)
Beispiel #31
0
    def testRaisesOpError(self):
        """Each run-time validation in `Reshape` surfaces as an op error."""
        x1 = np.random.randn(4, 2, 3)
        x2 = np.random.randn(4, 3, 2)
        x3 = np.random.randn(4, 5, 1, 1)

        with self.test_session() as sess:
            shape_in_ph = array_ops.placeholder(shape=[2], dtype=dtypes.int32)
            shape_out_ph = array_ops.placeholder(shape=[3], dtype=dtypes.int32)
            bijector = Reshape(event_shape_out=shape_out_ph,
                               event_shape_in=shape_in_ph,
                               validate_args=True)

            # x2's trailing dims [3, 2] have the right size for [1, 6, 1]
            # but do not equal the declared event_shape_in [2, 3].
            with self.assertRaisesOpError(
                    "Input `event_shape` does not match `event_shape_in`."):
                sess.run(bijector.forward(x2),
                         feed_dict={shape_out_ph: [1, 6, 1],
                                    shape_in_ph: [2, 3]})

            # [-1, -1, 6] contains non-positive entries.
            with self.assertRaisesOpError(
                    "event_shape_out entries must be positive."):
                sess.run(bijector.forward(x1),
                         feed_dict={shape_out_ph: [-1, -1, 6],
                                    shape_in_ph: [2, 3]})

            # test that *all* methods check basic assertions (2*3 != 1*1*5).
            fd_mismatched = {shape_out_ph: [1, 1, 5], shape_in_ph: [2, 3]}
            size_mismatch = "Input/output `event_size`s do not match."
            for make_op, value in ((bijector.forward, x1),
                                   (bijector.inverse, x3),
                                   (bijector.inverse_log_det_jacobian, x3),
                                   (bijector.forward_log_det_jacobian, x1)):
                with self.assertRaisesOpError(size_mismatch):
                    sess.run(make_op(value), feed_dict=fd_mismatched)
    def testOneShapePartiallySpecified(self):
        """Round-trip when only `event_shape_in` contains a `-1` wildcard."""
        expected_x = np.random.randn(4, 6)
        expected_y = np.reshape(expected_x, [4, 2, 3])

        with self.cached_session() as sess:
            shape_in, shape_out, feed_dict = self.build_shapes([-1], [2, 3])
            bijector = Reshape(event_shape_in=shape_in,
                               event_shape_out=shape_out,
                               validate_args=True)
            fetches = (bijector.inverse(expected_y),
                       bijector.forward(expected_x))
            x_, y_ = sess.run(fetches, feed_dict=feed_dict)
            for want, got in ((expected_y, y_), (expected_x, x_)):
                self.assertAllClose(want, got, rtol=1e-6, atol=0)
    def _testInvalidDimensionsStatic(self, expected_error_message):
        """Version of _testInvalidDimensionsOpError for static errors.

        "Statically" means the invalid shape is detected at graph
        construction time, here while `Reshape` itself is being built.

        Args:
          expected_error_message: String that should be present in the error
            message that `Reshape` raises for invalid shapes.
        """
        shape_in, shape_out, _ = self.build_shapes([2, 3], [1, 2, -2])
        with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                                 expected_error_message):
            Reshape(event_shape_out=shape_out,
                    event_shape_in=shape_in,
                    validate_args=True)
Beispiel #34
0
    def testEventShapeDynamic(self):
        """Check shape methods with static ndims but dynamic shape values."""
        shape_in = tensor_shape.TensorShape([6])
        shape_in_partial = tensor_shape.TensorShape([None])
        shape_in_ph = array_ops.placeholder(shape=[1], dtype=dtypes.int32)

        shape_out = tensor_shape.TensorShape([2, 3])
        shape_out_partial = tensor_shape.TensorShape([None, None])
        shape_out_ph = array_ops.placeholder(shape=[2], dtype=dtypes.int32)

        bijector = Reshape(event_shape_out=shape_out_ph,
                           event_shape_in=shape_in_ph,
                           validate_args=True)

        # The placeholders fix each shape's length but not its entries, so
        # the static shape methods should return partially known shapes.
        static_checks = (
            (bijector.forward_event_shape(shape_in), shape_out_partial),
            (bijector.inverse_event_shape(shape_out), shape_in_partial),
        )
        for actual, expected in static_checks:
            self.assertAllEqual(actual.as_list(), expected.as_list())

        # The *_tensor methods are evaluated when the graph runs, so once the
        # placeholders are fed the results must be fully specified.
        feeds = {shape_in_ph: shape_in, shape_out_ph: shape_out}
        with self.test_session() as sess:
            shape_out_, shape_in_ = sess.run(
                (bijector.forward_event_shape_tensor(shape_in),
                 bijector.inverse_event_shape_tensor(shape_out)),
                feed_dict=feeds)
            self.assertAllEqual(shape_out, shape_out_)
            self.assertAllEqual(shape_in, shape_in_)