def testEventShapeStatic(self):
        """Verify event-shape methods when both shapes are known statically.

        The ``Reshape`` bijector is built from a static [6] input event shape
        and a static [2, 3] output event shape. The shapes must map onto each
        other through the static ``forward_event_shape`` /
        ``inverse_event_shape`` methods and through the evaluated
        ``*_event_shape_tensor`` ops.
        """
        event_in = tensor_shape.TensorShape([6])
        event_out = tensor_shape.TensorShape([2, 3])

        bij = Reshape(
            event_shape_out=event_out,
            event_shape_in=event_in,
            validate_args=True,
        )

        # Static path: the TensorShape-returning methods should reproduce the
        # configured shapes exactly.
        self.assertEqual(bij.forward_event_shape(event_in), event_out)
        self.assertEqual(bij.inverse_event_shape(event_out), event_in)

        # Tensor path: evaluating the shape ops should yield the same values.
        with self.test_session() as sess:
            fetches = (
                bij.forward_event_shape_tensor(event_in),
                bij.inverse_event_shape_tensor(event_out),
            )
            fwd_val, inv_val = sess.run(fetches)
            self.assertAllEqual(event_out, fwd_val)
            self.assertAllEqual(event_in, inv_val)
Example #2
0
    def testEventShape(self):
        """Check the static event-shape mapping of a [2, 3] -> [6] Reshape."""
        in_shape = tensor_shape.TensorShape([2, 3])
        out_shape = tensor_shape.TensorShape([6])
        bij = Reshape(
            event_shape_out=out_shape,
            event_shape_in=in_shape,
            validate_args=True,
        )

        # Both shapes are known up front, so the static shape methods should
        # map each configured shape exactly onto the other.
        forward = bij.forward_event_shape(in_shape)
        self.assertEqual(forward, out_shape)
        inverse = bij.inverse_event_shape(out_shape)
        self.assertEqual(inverse, in_shape)
Example #3
0
  def testEventShape(self):
    """Check the static event-shape mapping of a [2, 3] -> [6] Reshape."""
    src_shape = tensor_shape.TensorShape([2, 3])
    dst_shape = tensor_shape.TensorShape([6])
    reshaper = Reshape(
        event_shape_out=dst_shape,
        event_shape_in=src_shape,
        validate_args=True,
    )

    # With statically known shapes, forward should produce the output shape
    # and inverse should recover the input shape.
    mapped_forward = reshaper.forward_event_shape(src_shape)
    self.assertEqual(mapped_forward, dst_shape)
    mapped_inverse = reshaper.inverse_event_shape(dst_shape)
    self.assertEqual(mapped_inverse, src_shape)
    def testEventShapeDynamic(self):
        """Check shape methods when ndims are static but dimensions are fed.

        The bijector is built from int32 placeholders of shape [1] and [2].
        The static ``*_event_shape`` methods are expected to report only the
        rank (every dimension ``None``). The ``*_event_shape_tensor`` ops are
        expected to resolve to the full shapes once the placeholders are fed.
        """
        in_full = tensor_shape.TensorShape([6])
        in_unknown = tensor_shape.TensorShape([None])
        in_placeholder = array_ops.placeholder(shape=[1], dtype=dtypes.int32)

        out_full = tensor_shape.TensorShape([2, 3])
        out_unknown = tensor_shape.TensorShape([None, None])
        out_placeholder = array_ops.placeholder(shape=[2], dtype=dtypes.int32)

        bij = Reshape(
            event_shape_out=out_placeholder,
            event_shape_in=in_placeholder,
            validate_args=True,
        )

        # Static path: without concrete event shapes, only partially-specified
        # TensorShapes (right rank, unknown dims) can be returned.
        forward_static = bij.forward_event_shape(in_full)
        self.assertAllEqual(forward_static.as_list(), out_unknown.as_list())
        inverse_static = bij.inverse_event_shape(out_full)
        self.assertAllEqual(inverse_static.as_list(), in_unknown.as_list())

        # Runtime path: the *_tensor methods are evaluated in the session with
        # the placeholders fed, so the results should be fully specified.
        feeds = {
            in_placeholder: in_full,
            out_placeholder: out_full,
        }
        with self.test_session() as sess:
            fetches = (
                bij.forward_event_shape_tensor(in_full),
                bij.inverse_event_shape_tensor(out_full),
            )
            fwd_val, inv_val = sess.run(fetches, feed_dict=feeds)
            self.assertAllEqual(out_full, fwd_val)
            self.assertAllEqual(in_full, inv_val)