def _make_reshaped_bijector(b, s):
  return tfb.Chain([
      tfb.Reshape(event_shape_in=s, event_shape_out=[ps.reduce_prod(s)]),
      b,
      tfb.Reshape(event_shape_out=b.inverse_event_shape(s)),
  ])

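# A usage sketch of the helper above (not from the test suite; the choice of
# `tfb.FillTriangular` and the example function name are illustrative). The
# chain flattens the wrapped bijector's non-vector output back to a vector,
# so a bijector with a matrix event shape can be used where `tfb.Blockwise`
# requires scalar or vector events.
def _example_reshaped_bijector_usage():
  b = tfb.FillTriangular()  # maps vector [3] -> lower-triangular matrix [2, 2]
  wrapped = _make_reshaped_bijector(b, s=[2, 2])
  y = wrapped.forward(tf.zeros([3]))
  return y.shape  # ==> (4,): a flat vector event, as `tfb.Blockwise` needs
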
def testRaisesBadBijectors(self):
  with self.assertRaisesRegex(NotImplementedError,
                              'Only scalar and vector event-shape'):
    tfb.Blockwise(bijectors=[tfb.Reshape(event_shape_out=[1, 1])])

  with self.assertRaisesRegex(NotImplementedError,
                              'Only scalar and vector event-shape'):
    tfb.Blockwise(
        bijectors=[tfb.Reshape(event_shape_out=[1], event_shape_in=[])])

def testUnknownShapeRank(self):
  if tf.executing_eagerly():
    return
  unknown_shape = tf1.placeholder_with_default([2, 2], shape=None)
  known_shape = [2, 2]

  with self.assertRaisesRegex(NotImplementedError,
                              'must be statically known.'):
    tfb.Reshape(event_shape_out=unknown_shape)

  with self.assertRaisesRegex(NotImplementedError,
                              'must be statically known.'):
    tfb.Reshape(event_shape_out=known_shape, event_shape_in=unknown_shape)

def _testInputOutputMismatchOpError(self, expected_error_message):
  x1 = np.random.randn(4, 2, 3)
  x2 = np.random.randn(4, 1, 1, 5)
  shape_in, shape_out = self.build_shapes([2, 3], [1, 1, 5])

  with self.assertRaisesError(expected_error_message):
    bijector = tfb.Reshape(
        event_shape_out=shape_out,
        event_shape_in=shape_in,
        validate_args=True)
    self.evaluate(bijector.forward(x1))
  with self.assertRaisesError(expected_error_message):
    bijector = tfb.Reshape(
        event_shape_out=shape_out,
        event_shape_in=shape_in,
        validate_args=True)
    self.evaluate(bijector.inverse(x2))

def testLogProb(self, event_shape, event_dims, training, layer_cls):
  training = tf.compat.v1.placeholder_with_default(training, (), 'training')
  layer = layer_cls(axis=event_dims, epsilon=0.)
  batch_norm = tfb.BatchNormalization(batchnorm_layer=layer,
                                      training=training)
  base_dist = distributions.MultivariateNormalDiag(
      loc=np.zeros(np.prod(event_shape), dtype=np.float32))
  # Reshape the events.
  if isinstance(event_shape, int):
    event_shape = [event_shape]
  base_dist = distributions.TransformedDistribution(
      distribution=base_dist,
      bijector=tfb.Reshape(event_shape_out=event_shape))
  dist = distributions.TransformedDistribution(
      distribution=base_dist, bijector=batch_norm, validate_args=True)
  samples = dist.sample(int(1e5))
  # No volume distortion since training=False; the bijector is initialized
  # to the identity transformation.
  base_log_prob = base_dist.log_prob(samples)
  dist_log_prob = dist.log_prob(samples)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  base_log_prob_, dist_log_prob_ = self.evaluate(
      [base_log_prob, dist_log_prob])
  self.assertAllClose(base_log_prob_, dist_log_prob_)

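# The reshaping step above in isolation, as a standalone sketch (the shapes
# and the example function name are assumptions for illustration): wrapping
# a vector-event base distribution in a `Reshape` bijector is what gives the
# transformed distribution a multi-dimensional event shape.
def _example_reshape_base_dist():
  base = distributions.MultivariateNormalDiag(loc=np.zeros(6, np.float32))
  reshaped = distributions.TransformedDistribution(
      distribution=base, bijector=tfb.Reshape(event_shape_out=[2, 3]))
  # Reshape has zero log-det-Jacobian, so log_prob values are unchanged.
  return reshaped.event_shape  # ==> (2, 3)
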
def testScalarReshape(self):
  """Test reshaping to and from a scalar shape ()."""
  expected_x = np.random.randn(4, 3, 1)
  expected_y = np.reshape(expected_x, [4, 3])

  expected_x_scalar = np.random.randn(1,)
  expected_y_scalar = expected_x_scalar[0]

  shape_in, shape_out = self.build_shapes([], [1,])
  # Note the deliberate swap: `shape_in` ([]) is this bijector's output
  # event shape, and `shape_out` ([1]) is its input event shape.
  bijector = tfb.Reshape(
      event_shape_out=shape_in,
      event_shape_in=shape_out,
      validate_args=True)
  (x_,
   y_,
   x_scalar_,
   y_scalar_) = self.evaluate((
       bijector.inverse(expected_y),
       bijector.forward(expected_x),
       bijector.inverse(expected_y_scalar),
       bijector.forward(expected_x_scalar),
   ))
  self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
  self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0)
  self.assertAllClose(expected_y_scalar, y_scalar_, rtol=1e-6, atol=0)
  self.assertAllClose(expected_x_scalar, x_scalar_, rtol=1e-6, atol=0)

def testCovarianceNotImplemented(self):
  mvn = tfd.MultivariateNormalDiag(loc=[0., 0.], scale_diag=[1., 2.])

  # Non-affine bijector.
  with self.assertRaisesRegex(
      NotImplementedError, '`covariance` is not implemented'):
    tfd.TransformedDistribution(
        distribution=mvn, bijector=tfb.Exp()).covariance()

  # Non-injective bijector.
  with self.assertRaisesRegex(
      NotImplementedError, '`covariance` is not implemented'):
    tfd.TransformedDistribution(
        distribution=mvn, bijector=tfb.AbsoluteValue()).covariance()

  # Non-vector event shape.
  with self.assertRaisesRegex(
      NotImplementedError, '`covariance` is only implemented'):
    tfd.TransformedDistribution(
        distribution=mvn,
        bijector=tfb.Reshape(event_shape_out=[2, 1],
                             event_shape_in=[2])).covariance()

  # Multipart bijector.
  with self.assertRaisesRegex(
      NotImplementedError, '`covariance` is only implemented'):
    tfd.TransformedDistribution(
        distribution=mvn, bijector=tfb.Split(2)).covariance()

def testInvalidDimensionsOpError(self):
  shape_in, shape_out = self.build_shapes([2, 3], [1, 2, -2])
  with self.assertRaises(ValueError):
    tfb.Reshape(
        event_shape_out=shape_out,
        event_shape_in=shape_in,
        validate_args=True)

def testMultipleUnspecifiedDimensionsOpError(self):
  shape_in, shape_out = self.build_shapes([2, 3], [4, -1, -1])
  with self.assertRaises(ValueError):
    tfb.Reshape(
        event_shape_out=shape_out,
        event_shape_in=shape_in,
        validate_args=True)

def testEventShape(self):
  shape_in_static = tf.TensorShape([2, 3])
  shape_out_static = tf.TensorShape([6])
  bijector = tfb.Reshape(
      event_shape_out=shape_out_static,
      event_shape_in=shape_in_static,
      validate_args=True)

  # Test that forward_ and inverse_event_shape are correct when
  # event_shape_in/_out are statically known, even when the input shapes
  # are only partially specified.
  self.assertEqual(
      bijector.forward_event_shape(tf.TensorShape([4, 2, 3])).as_list(),
      [4, 6])
  self.assertEqual(
      bijector.inverse_event_shape(tf.TensorShape([4, 6])).as_list(),
      [4, 2, 3])

  # Shape is always known for reshaping in eager mode, so we skip these
  # tests.
  if tf.executing_eagerly():
    return

  self.assertEqual(
      bijector.forward_event_shape(tf.TensorShape([None, 2, 3])).as_list(),
      [None, 6])
  self.assertEqual(
      bijector.inverse_event_shape(tf.TensorShape([None, 6])).as_list(),
      [None, 2, 3])

  # If the input shape is totally unknown, there's nothing we can do!
  self.assertIsNone(
      bijector.forward_event_shape(tf.TensorShape(None)).ndims)

def testCheckingVariableShape(self):
  shape_out = tf.Variable([-2, 10])
  self.evaluate(shape_out.initializer)
  with self.assertRaisesOpError(
      'elements must be either positive integers or `-1`'):
    self.evaluate(tfb.Reshape(shape_out, validate_args=True).forward([0]))

def testEventShape(self):
  # Shape is always known for reshaping in eager mode, so we skip these
  # tests.
  if tf.executing_eagerly():
    return
  event_shape_in, event_shape_out = self.build_shapes([2, 3], [6])
  bijector = tfb.Reshape(
      event_shape_out=event_shape_out,
      event_shape_in=event_shape_in,
      validate_args=True)

  self.assertEqual(
      bijector.forward_event_shape(tf.TensorShape([4, 2, 3])).as_list(),
      [4, None])
  self.assertEqual(
      bijector.forward_event_shape(tf.TensorShape([None, 2, 3])).as_list(),
      [None, None])
  self.assertEqual(
      bijector.inverse_event_shape(tf.TensorShape([4, 6])).as_list(),
      [4, None, None])
  self.assertEqual(
      bijector.inverse_event_shape(tf.TensorShape([None, 6])).as_list(),
      [None, None, None])

  # If the input shape is totally unknown, there's nothing we can do!
  self.assertIsNone(
      bijector.forward_event_shape(tf.TensorShape(None)).ndims)

def testBijector(self):
  """Do a basic sanity check of forward, inverse, jacobian."""
  expected_x = np.random.randn(4, 3, 2)
  expected_y = np.reshape(expected_x, [4, 6])

  shape_in, shape_out = self.build_shapes([3, 2], [6,])
  bijector = tfb.Reshape(
      event_shape_out=shape_out,
      event_shape_in=shape_in,
      validate_args=True)
  [x_,
   y_,
   fldj_,
   ildj_,
   fest_,
   iest_] = self.evaluate([
       bijector.inverse(expected_y),
       bijector.forward(expected_x),
       bijector.forward_log_det_jacobian(expected_x, event_ndims=2),
       bijector.inverse_log_det_jacobian(expected_y, event_ndims=2),
       bijector.forward_event_shape_tensor(expected_x.shape),
       bijector.inverse_event_shape_tensor(expected_y.shape),
   ])
  self.assertStartsWith(bijector.name, 'reshape')
  self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
  self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0)
  self.assertAllClose(0., fldj_, rtol=1e-6, atol=0)
  self.assertAllClose(0., ildj_, rtol=1e-6, atol=0)
  # Test that the event_shape_tensors match the fwd/inv result shapes.
  self.assertAllEqual(y_.shape, fest_)
  self.assertAllEqual(x_.shape, iest_)

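# Why the zero log-det-Jacobian assertions above hold: `Reshape` only
# relabels coordinates, so its Jacobian determinant is identically 1. A
# standalone sketch (the shapes and the example function name are
# illustrative):
def _example_reshape_is_volume_preserving():
  r = tfb.Reshape(event_shape_in=[3, 2], event_shape_out=[6])
  fldj = r.forward_log_det_jacobian(tf.zeros([4, 3, 2]), event_ndims=2)
  return fldj  # ==> 0. for every batch member
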
def testMultipleUnspecifiedDimensionsOpError(self):
  with self.assertRaisesError('must have at most one `-1`.'):
    shape_in, shape_out = self.build_shapes([2, 3], [4, -1, -1])
    bijector = tfb.Reshape(
        event_shape_out=shape_out,
        event_shape_in=shape_in,
        validate_args=True)
    self.evaluate(bijector.forward_event_shape_tensor(shape_in))

def testScalarInVectorOut(self):
  bijector = tfb.Reshape(event_shape_in=[], event_shape_out=[-1])
  self.assertAllEqual(
      np.zeros([3, 4, 5, 1]),
      self.evaluate(bijector.forward(np.zeros([3, 4, 5]))))
  self.assertAllEqual(
      np.zeros([3, 4, 5]),
      self.evaluate(bijector.inverse(np.zeros([3, 4, 5, 1]))))

def testBijectiveAndFinite(self):
  x = np.random.randn(4, 2, 3)
  y = np.reshape(x, [4, 1, 2, 3])
  bijector = tfb.Reshape(
      event_shape_in=[2, 3], event_shape_out=[1, 2, 3], validate_args=True)
  bijector_test_util.assert_bijective_and_finite(
      bijector, x, y, eval_func=self.evaluate, event_ndims=2, rtol=1e-6,
      atol=0)

def testCheckingMutatedVariableShape(self):
  shape_out = tf.Variable([1, 1])
  self.evaluate(shape_out.initializer)
  reshape = tfb.Reshape(shape_out, validate_args=True)
  self.evaluate(reshape.forward([0]))
  with self.assertRaisesOpError(
      'elements must be either positive integers or `-1`'):
    with tf.control_dependencies([shape_out.assign([-2, 10])]):
      self.evaluate(reshape.forward([0]))

def testInvalidDimensionsOpError(self):
  shape_in, shape_out = self.build_shapes([2, 3], [1, 2, -2])
  with self.assertRaisesError(
      'elements must be either positive integers or `-1`.'):
    bijector = tfb.Reshape(
        event_shape_out=shape_out,
        event_shape_in=shape_in,
        validate_args=True)
    self.evaluate(bijector.forward_event_shape_tensor(shape_in))

def _default_event_space_bijector(self):
  """The bijector maps a zero-dimensional null Tensor input to `self.loc`."""
  # The shape of the pulled back null tensor will be `self.loc.shape + (0,)`.
  # First we pad to a tensor of zeros with shape `self.loc.shape + (1,)`.
  pad_zero = tfb.Pad([(1, 0)])
  # Next, we squeeze to a tensor of zeros with shape matching `self.loc`.
  zeros_squeezed = tfb.Reshape([], event_shape_in=[1])(pad_zero)
  # Finally, we shift the zeros by `self.loc`.
  return tfb.Shift(self.loc)(zeros_squeezed)

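# A sketch of the shape flow through that composed bijector (hypothetical
# and standalone, not part of the class above; assumes `loc` has shape [2]
# and the example function name is illustrative):
def _example_null_tensor_to_loc():
  loc = tf.constant([1., 2.])
  bij = tfb.Shift(loc)(tfb.Reshape([], event_shape_in=[1])(tfb.Pad([(1, 0)])))
  x = tf.zeros([2, 0])    # the zero-dimensional "null" input
  # Pad: [2, 0] -> [2, 1] (zeros); Reshape: [2, 1] -> [2]; Shift: adds loc.
  return bij.forward(x)   # ==> [1., 2.], i.e. `loc` itself
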
def testValidButNonMatchingInputPartiallySpecifiedOpError(self):
  x = np.random.randn(4, 3, 2)
  shape_in, shape_out = self.build_shapes([2, -1], [1, 6, 1])
  bijector = tfb.Reshape(
      event_shape_out=shape_out,
      event_shape_in=shape_in,
      validate_args=True)
  with self.assertRaisesError('Input `event_shape` does not match'):
    self.evaluate(bijector.forward(x))

def testDefaultVectorShape(self):
  expected_x = np.random.randn(4, 4)
  expected_y = np.reshape(expected_x, [4, 2, 2])
  _, shape_out = self.build_shapes([-1,], [-1, 2])
  bijector = tfb.Reshape(shape_out, validate_args=True)
  x_, y_ = self.evaluate([
      bijector.inverse(expected_y),
      bijector.forward(expected_x),
  ])
  self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
  self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0)

def testConcretizationLimits(self):
  shape_out = tfp_hps.defer_and_count_usage(tf.Variable([1]))
  reshape = tfb.Reshape(shape_out, validate_args=True)
  x = [1]  # Pun: valid input or output, and valid input or output shape.
  for method in ['forward', 'inverse', 'forward_event_shape',
                 'inverse_event_shape', 'forward_event_shape_tensor',
                 'inverse_event_shape_tensor']:
    with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=7):
      getattr(reshape, method)(x)
  for method in ['forward_log_det_jacobian', 'inverse_log_det_jacobian']:
    with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=4):
      getattr(reshape, method)(x, event_ndims=1)

def testBothShapesPartiallySpecified(self):
  expected_x = np.random.randn(4, 2, 3)
  expected_y = np.reshape(expected_x, [4, 3, 2])
  shape_in, shape_out = self.build_shapes([-1, 3], [-1, 2])
  bijector = tfb.Reshape(
      event_shape_out=shape_out,
      event_shape_in=shape_in,
      validate_args=True)
  x_, y_ = self.evaluate([
      bijector.inverse(expected_y),
      bijector.forward(expected_x),
  ])
  self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
  self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0)

def testBijectiveAndFinite(self):
  x = np.random.randn(4, 2, 3)
  y = np.reshape(x, [4, 1, 2, 3])
  with self.cached_session():
    bijector = tfb.Reshape(
        event_shape_in=[2, 3],
        event_shape_out=[1, 2, 3],
        validate_args=True)
    assert_bijective_and_finite(
        bijector, x, y, event_ndims=2, rtol=1e-6, atol=0)

def test_transform_joint_to_joint(self, split_sizes):
  dist_batch_shape = tf.nest.pack_sequence_as(
      split_sizes,
      [tensorshape_util.constant_value_as_shape(s)
       for s in [[2, 3], [2, 1], [1, 3]]])
  bijector_batch_shape = [1, 3]

  # Build a joint distribution with parts of the specified sizes.
  seed = test_util.test_seed_stream()
  component_dists = tf.nest.map_structure(
      lambda size, batch_shape: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
          loc=tf.random.normal(batch_shape + [size], seed=seed()),
          scale_diag=tf.random.uniform(
              minval=1., maxval=2., shape=batch_shape + [size],
              seed=seed())),
      split_sizes, dist_batch_shape)
  if isinstance(split_sizes, dict):
    base_dist = tfd.JointDistributionNamed(component_dists)
  else:
    base_dist = tfd.JointDistributionSequential(component_dists)

  # Transform the distribution by applying a separate bijector to each part.
  bijectors = [
      tfb.Exp(),
      tfb.Scale(
          tf.random.uniform(
              minval=1., maxval=2., shape=bijector_batch_shape,
              seed=seed())),
      tfb.Reshape([2, 1])
  ]
  bijector = tfb.JointMap(
      tf.nest.pack_sequence_as(split_sizes, bijectors), validate_args=True)

  # Transform a joint distribution that has different batch shape components.
  transformed_dist = tfd.TransformedDistribution(base_dist, bijector)

  self.assertRegex(
      str(transformed_dist),
      '{}.*batch_shape.*event_shape.*dtype'.format(transformed_dist.name))

  self.assertAllEqualNested(
      transformed_dist.event_shape,
      bijector.forward_event_shape(base_dist.event_shape))
  self.assertAllEqualNested(*self.evaluate((
      transformed_dist.event_shape_tensor(),
      bijector.forward_event_shape_tensor(base_dist.event_shape_tensor()))))

  # Test that the batch shape components of the input are the same as those
  # of the output.
  self.assertAllEqualNested(transformed_dist.batch_shape, dist_batch_shape)
  self.assertAllEqualNested(
      self.evaluate(transformed_dist.batch_shape_tensor()), dist_batch_shape)
  self.assertAllEqualNested(dist_batch_shape, base_dist.batch_shape)

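# The per-part pattern exercised above, in miniature (the shapes and the
# example function name are hypothetical): `JointMap` applies one bijector
# to each part, so a `Reshape` can change one part's event shape without
# touching the joint batch shape.
def _example_jointmap_reshape():
  jm = tfb.JointMap([tfb.Exp(), tfb.Reshape([2, 1])])
  ys = jm.forward([tf.zeros([5, 3]), tf.zeros([5, 2])])
  return ys[0].shape, ys[1].shape  # ==> (5, 3), (5, 2, 1)
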
def testMultipleUnspecifiedDimensionsOpError(self):
  shape_in, shape_out = self.build_shapes([2, 3], [4, -1, -1])
  bijector = tfb.Reshape(
      event_shape_out=shape_out,
      event_shape_in=shape_in,
      validate_args=True)
  with self.cached_session() as sess:
    with self.assertRaisesError('elements must have at most one `-1`.'):
      sess.run(bijector.forward_event_shape_tensor(shape_in))

def testValidButNonMatchingInputOpError(self):
  x = np.random.randn(4, 3, 2)
  shape_in, shape_out = self.build_shapes([2, 3], [1, 6, 1])
  bijector = tfb.Reshape(
      event_shape_out=shape_out,
      event_shape_in=shape_in,
      validate_args=True)
  # Here we pass in a tensor (x) whose shape is compatible with the output
  # shape, so tf.reshape raises no error, but which doesn't match the
  # expected input shape.
  with self.assertRaisesError('Input `event_shape` does not match'):
    self.evaluate(bijector.forward(x))

def _testInputOutputMismatchOpError(self, expected_error_message):
  x1 = np.random.randn(4, 2, 3)
  x2 = np.random.randn(4, 1, 1, 5)

  with self.cached_session() as sess:
    shape_in, shape_out, fd_mismatched = self.build_shapes([2, 3],
                                                           [1, 1, 5])
    bijector = tfb.Reshape(
        event_shape_out=shape_out,
        event_shape_in=shape_in,
        validate_args=True)
    with self.assertRaisesError(expected_error_message):
      sess.run(bijector.forward(x1), feed_dict=fd_mismatched)
    with self.assertRaisesError(expected_error_message):
      sess.run(bijector.inverse(x2), feed_dict=fd_mismatched)

def testEventShape(self):
  shape_in_static = tf.TensorShape([2, 3])
  shape_out_static = tf.TensorShape([6,])
  bijector = tfb.Reshape(
      event_shape_out=shape_out_static,
      event_shape_in=shape_in_static,
      validate_args=True)

  # Test that forward_ and inverse_event_shape do sensible things
  # when shapes are statically known.
  self.assertEqual(bijector.forward_event_shape(shape_in_static),
                   shape_out_static)
  self.assertEqual(bijector.inverse_event_shape(shape_out_static),
                   shape_in_static)