def testEventShapeStatic(self):
  """Check shape methods when shape is statically known."""
  event_in = tensor_shape.TensorShape([6])
  event_out = tensor_shape.TensorShape([2, 3])
  bijector_static = Reshape(
      event_shape_out=event_out,
      event_shape_in=event_in,
      validate_args=True)

  # With fully static shapes, forward_/inverse_event_shape should map one
  # static TensorShape straight onto the other.
  self.assertEqual(bijector_static.forward_event_shape(event_in), event_out)
  self.assertEqual(bijector_static.inverse_event_shape(event_out), event_in)

  # The *_event_shape_tensor variants are graph ops; evaluate both in a
  # single session run and compare against the known static shapes.
  with self.test_session() as sess:
    fwd = bijector_static.forward_event_shape_tensor(event_in)
    inv = bijector_static.inverse_event_shape_tensor(event_out)
    shape_out_static_, shape_in_static_ = sess.run((fwd, inv))
    self.assertAllEqual(event_out, shape_out_static_)
    self.assertAllEqual(event_in, shape_in_static_)
def testEventShapeStaticKnown(self):
  """Check forward_/inverse_event_shape with statically known shapes.

  NOTE(review): this method was originally named `testEventShape`, the
  same name as a sibling test defined later in this class. A Python class
  body keeps only the last binding of a name, so this test was silently
  shadowed and never executed by the test runner. Renamed (keeping the
  `test` prefix) so it is discovered and run; the body is unchanged.
  """
  shape_in_static = tensor_shape.TensorShape([2, 3])
  shape_out_static = tensor_shape.TensorShape([6])
  bijector = Reshape(
      event_shape_out=shape_out_static,
      event_shape_in=shape_in_static,
      validate_args=True)

  # test that forward_ and inverse_event_shape do sensible things
  # when shapes are statically known.
  self.assertEqual(
      bijector.forward_event_shape(shape_in_static), shape_out_static)
  self.assertEqual(
      bijector.inverse_event_shape(shape_out_static), shape_in_static)
def testEventShape(self):
  """Static event-shape bookkeeping round-trips through Reshape."""
  event_in = tensor_shape.TensorShape([2, 3])
  event_out = tensor_shape.TensorShape([6])
  reshaper = Reshape(
      event_shape_out=event_out,
      event_shape_in=event_in,
      validate_args=True)

  # With both event shapes statically known, forward_event_shape and
  # inverse_event_shape should return the exact opposite static shapes.
  self.assertEqual(reshaper.forward_event_shape(event_in), event_out)
  self.assertEqual(reshaper.inverse_event_shape(event_out), event_in)
def testEventShapeDynamic(self):
  """Check shape methods with static ndims but dynamic shape."""
  event_in = tensor_shape.TensorShape([6])
  partial_in = tensor_shape.TensorShape([None])
  ph_in = array_ops.placeholder(shape=[1], dtype=dtypes.int32)

  event_out = tensor_shape.TensorShape([2, 3])
  partial_out = tensor_shape.TensorShape([None, None])
  ph_out = array_ops.placeholder(shape=[2], dtype=dtypes.int32)

  bijector = Reshape(
      event_shape_out=ph_out,
      event_shape_in=ph_in,
      validate_args=True)

  # Event shapes fed as placeholders are unknown at graph-construction
  # time, so the static shape methods can only return partially-specified
  # TensorShapes (the rank comes from the placeholder's own shape).
  self.assertAllEqual(
      bijector.forward_event_shape(event_in).as_list(),
      partial_out.as_list())
  self.assertAllEqual(
      bijector.inverse_event_shape(event_out).as_list(),
      partial_in.as_list())

  # The *_tensor methods evaluate at graph runtime, where the fed values
  # make both shapes fully specified.
  with self.test_session() as sess:
    fwd = bijector.forward_event_shape_tensor(event_in)
    inv = bijector.inverse_event_shape_tensor(event_out)
    shape_out_, shape_in_ = sess.run(
        (fwd, inv),
        feed_dict={
            ph_in: event_in,
            ph_out: event_out,
        })
    self.assertAllEqual(event_out, shape_out_)
    self.assertAllEqual(event_in, shape_in_)