def testMatchWithAffineTransform(self):
  """Tanh agrees with the identity tanh(x) = 2 * sigmoid(2 * x) - 1."""
  f64 = tf.float64
  tanh_bijector = tfb.Tanh()
  # Chain applies right to left: scale by 2, sigmoid, scale by 2, shift by -1.
  composed_bijector = tfb.Chain([
      tfb.Shift(tf.cast(-1.0, dtype=f64)),
      tfb.Scale(tf.cast(2.0, dtype=f64)),
      tfb.Sigmoid(),
      tfb.Scale(tf.cast(2.0, dtype=f64)),
  ])
  xs = np.linspace(-3.0, 3.0, 100)
  ys = np.tanh(xs)

  def _check(method_name, arg, **kwargs):
    # Evaluate the reference (Tanh) first, then the composed bijector.
    expected = self.evaluate(getattr(tanh_bijector, method_name)(arg, **kwargs))
    actual = self.evaluate(
        getattr(composed_bijector, method_name)(arg, **kwargs))
    self.assertAllClose(expected, actual)

  _check('forward', xs)
  _check('inverse', ys)
  _check('inverse_log_det_jacobian', ys, event_ndims=0)
  _check('forward_log_det_jacobian', xs, event_ndims=0)
def testTransformedKLDifferentBijectorFails(self):
  """KL between transformed distributions with unequal bijectors raises."""

  def _scaled_exponential(factor):
    return self._cls()(
        tfd.Exponential(rate=0.25),
        bijector=tfb.Scale(scale=factor),
        validate_args=True)

  first = _scaled_exponential(2.)
  second = _scaled_exponential(3.)
  with self.assertRaisesRegex(NotImplementedError,
                              r'their bijectors are not equal'):
    tfd.kl_divergence(first, second)
def testMixedDtypeLogDetJacobian(self):
  """FLDJ of a JointMap over float16/32/64 parts comes out as float64."""
  part_dtypes = {'a': tf.float16, 'b': tf.float32, 'c': tf.float64}
  part_scales = {'a': 1, 'b': 2, 'c': 3}
  bij = tfb.JointMap({
      key: tfb.Scale(tf.constant(part_scales[key], dtype=part_dtypes[key]))
      for key in ('a', 'b', 'c')
  })
  fldj = bij.forward_log_det_jacobian(
      x={'a': 4, 'b': 5, 'c': 6},
      event_ndims={key: 0 for key in 'abc'})
  self.assertDTypeEqual(fldj, np.float64)
  expected = np.log(1) + np.log(2) + np.log(3)
  self.assertAllClose(expected, self.evaluate(fldj))
def testCdfDescendingChained(self):
  """CDF stays increasing through chained affine maps with negative scales."""

  def _affine(scale):
    return tfb.Shift(shift=1.)(tfb.Scale(scale=scale))

  bij1 = _affine([1., -2.])
  bij2 = _affine([[3.], [-5.]])
  bij3 = _affine([[[7.]], [[-11.]]])
  chains = (bij2(bij1), bij3(bij2(bij1)))
  for chain in chains:
    td = self._cls()(
        distribution=tfd.Normal(loc=0., scale=tf.ones([2, 2, 2])),
        bijector=chain,
        validate_args=True)
    nd = tfd.Normal(loc=1., scale=3., validate_args=True)
    all_true = tf.ones(td.batch_shape, dtype=tf.bool)
    lower = td.cdf(nd.quantile(.4))
    upper = td.cdf(nd.quantile(.6))
    self.assertAllEqual(all_true, lower < upper, msg=chain.name)
def testTinyScale(self, dtype):
  """A log_scale parameterization keeps the FLDJ exact for tiny scales."""
  log_scale = tf.cast(-2000., dtype)
  x = tf.cast(1., dtype)
  linear_bij = tfb.Scale(scale=tf.math.exp(log_scale))
  log_bij = tfb.Scale(log_scale=log_scale)
  fldj_linear_, fldj_log_ = self.evaluate([
      linear_bij.forward_log_det_jacobian(x, event_ndims=0),
      log_bij.forward_log_det_jacobian(x, event_ndims=0),
  ])
  # exp(-2000) saturates to 0, so the linear scale produces a bad log-det
  # Jacobian, while the log-scale version recovers -2000 directly.
  self.assertNotEqual(fldj_linear_, fldj_log_)
  self.assertAllClose(-2000., fldj_log_)
def test_end_to_end_works_correctly(self):
  """HMC inside a TransformedTransitionKernel recovers a correlated 2-D MVN."""
  true_mean = self.dtype([0, 0])
  true_cov = self.dtype([[1, 0.5], [0.5, 1]])
  num_results = 500
  chol = np.linalg.cholesky(true_cov)

  def target_log_prob(x, y):
    # Unnormalized MVN log-density: with
    #   z = matmul(inv(chol(true_cov)), [x, y] - true_mean),
    # return -0.5 * sum(z**2).
    centered = tf.stack([x, y], axis=-1) - true_mean
    whitened = tf.squeeze(
        tf.linalg.triangular_solve(chol, centered[..., tf.newaxis]), axis=-1)
    return -0.5 * tf.reduce_sum(whitened**2., axis=-1)

  hmc = tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=tf.function(target_log_prob, autograph=False),
      # The affine bijectors rescale the space, so the step size is rescaled
      # too, to get ~60% acceptance as in mcmc/hmc_test.py.
      step_size=[1.23 / 0.75, 1.23 / 0.5],
      num_leapfrog_steps=2)
  transformed_hmc = tfp.mcmc.TransformedTransitionKernel(
      inner_kernel=hmc,
      bijector=[tfb.Scale(scale=0.75), tfb.Scale(scale=0.5)])

  # tfp.mcmc.sample_chain also calls transformed_hmc.bootstrap_results. The
  # initial state is consumed by inner_kernel.bootstrap_results and is given
  # *after* `bijector.forward`.
  states, kernel_results = tfp.mcmc.sample_chain(
      num_results=num_results,
      current_state=[self.dtype(-2), self.dtype(2)],
      kernel=transformed_hmc,
      num_burnin_steps=200,
      num_steps_between_results=1,
      seed=test_util.test_seed())
  states = tf.stack(states, axis=-1)
  self.assertEqual(num_results, tf.compat.dimension_value(states.shape[0]))

  sample_mean = tf.reduce_mean(states, axis=0)
  deviations = states - sample_mean
  sample_cov = (tf.matmul(deviations, deviations, transpose_a=True) /
                self.dtype(num_results))
  sample_mean_, sample_cov_, is_accepted_ = self.evaluate([
      sample_mean,
      sample_cov,
      kernel_results.inner_results.is_accepted,
  ])
  self.assertAllClose(0.6, is_accepted_.mean(), atol=0.15, rtol=0.)
  self.assertAllClose(sample_mean_, true_mean, atol=0.2, rtol=0.)
  self.assertAllClose(sample_cov_, true_cov, atol=0., rtol=0.4)
def test_dist_fn_takes_kwargs(self):
  """Distribution-making callables may receive their parents via **kwargs."""
  model = {
      'positive': tfd.Exponential(rate=1.),
      'negative': tfb.Scale(-1.)(tfd.Exponential(rate=1.)),
      'b': lambda **kwargs: tfd.Normal(  # pylint: disable=g-long-lambda
          loc=kwargs['negative'],
          scale=kwargs['positive'],
          validate_args=True),
      'a': lambda **kwargs: tfb.Scale(kwargs['b'])(  # pylint: disable=g-long-lambda
          tfd.Gamma(concentration=-kwargs['negative'],
                    rate=kwargs['positive'],
                    validate_args=True)),
  }
  dist = tfd.JointDistributionNamed(model, validate_args=True)
  draws = dist.sample(5, seed=test_util.test_seed())
  lp = dist.log_prob(draws)
  self.assertAllEqual(lp.shape, [5])
def test_nested_transform(self):
  """Nested TransformedTransitionKernels behave like one kernel over a Chain."""
  target_dist = tfd.Normal(loc=0., scale=1.)
  b1 = tfb.Scale(0.5)
  b2 = tfb.Exp()
  # Chain applies bijectors right to left: b1 first, then b2.
  chain = tfb.Chain([b2, b1])

  def _make_hmc():
    return tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_dist.log_prob,
        num_leapfrog_steps=27,
        step_size=10)

  inner_kernel = tfp.mcmc.TransformedTransitionKernel(
      inner_kernel=_make_hmc(), bijector=b1)
  outer_kernel = tfp.mcmc.TransformedTransitionKernel(
      inner_kernel=inner_kernel, bijector=b2)
  chain_kernel = tfp.mcmc.TransformedTransitionKernel(
      inner_kernel=_make_hmc(), bijector=chain)

  outer_pkr_one, outer_pkr_two = self.evaluate([
      outer_kernel.bootstrap_results(2.),
      outer_kernel.bootstrap_results(9.),
  ])
  # The outermost kernel only applies the outermost bijector (b2 = Exp).
  self.assertNear(np.log(2.), outer_pkr_one.transformed_state, err=1e-6)
  self.assertNear(np.log(9.), outer_pkr_two.transformed_state, err=1e-6)

  chain_pkr_one, chain_pkr_two = self.evaluate([
      chain_kernel.bootstrap_results(2.),
      chain_kernel.bootstrap_results(9.),
  ])
  # Every bijector reaches the inner kernel, innermost to outermost, exactly
  # as a bijector Chain composes them.
  paired_pkrs = ((chain_pkr_one, outer_pkr_one),
                 (chain_pkr_two, outer_pkr_two))
  for chain_pkr, outer_pkr in paired_pkrs:
    self.assertNear(chain_pkr.transformed_state,
                    outer_pkr.inner_results.transformed_state,
                    err=1e-6)
    self.assertEqual(
        chain_pkr.inner_results.accepted_results,
        outer_pkr.inner_results.inner_results.accepted_results)

  seed = test_util.test_seed(sampler_type='stateless')
  outer_results_one, outer_results_two = self.evaluate([
      outer_kernel.one_step(2., outer_pkr_one, seed=seed),
      outer_kernel.one_step(9., outer_pkr_two, seed=seed),
  ])
  chain_results_one, chain_results_two = self.evaluate([
      chain_kernel.one_step(2., chain_pkr_one, seed=seed),
      chain_kernel.one_step(9., chain_pkr_two, seed=seed),
  ])
  # Element 0 of each one_step result is the next state.
  self.assertNear(chain_results_one[0], outer_results_one[0], err=1e-6)
  self.assertNear(chain_results_two[0], outer_results_two[0], err=1e-6)
def testCompositeTensor(self):
  """JointMap is a CompositeTensor: flattenable, tf.function-able, encodable."""
  exp = tfb.Exp()
  sp = tfb.Softplus()
  aff = tfb.Scale(scale=2.)
  bij = tfb.JointMap(bijectors=[exp, sp, aff])
  self.assertIsInstance(bij, tf.__internal__.CompositeTensor)

  # Bijector may be flattened into `Tensor` components and rebuilt.
  flat = tf.nest.flatten(bij, expand_composites=True)
  unflat = tf.nest.pack_sequence_as(bij, flat, expand_composites=True)
  self.assertIsInstance(unflat, tfb.JointMap)

  # Bijector may be input to a `tf.function`-decorated callable.
  @tf.function
  def call_forward(bij, x):
    return bij.forward(x)

  x = [1., 2., 3.]
  self.assertAllClose(call_forward(unflat, x), bij.forward(x))

  # Type spec can be encoded/decoded. Use the module-level coder functions:
  # the `StructureCoder` class was removed from
  # `tf.__internal__.saved_model` in newer TF releases.
  enc = tf.__internal__.saved_model.encode_structure(bij._type_spec)
  dec = tf.__internal__.saved_model.decode_proto(enc)
  self.assertEqual(bij._type_spec, dec)
def test_single_part_str_repr_match_expected(self):
  """str() and repr() of single-part bijectors render the expected summary."""

  def _check(bij, expected_str, expected_repr):
    self.assertContainsInOrder([expected_str], str(bij))
    self.assertContainsInOrder([expected_repr], repr(bij))

  _check(
      tfb.Exp(),
      'tfp.bijectors.Exp("exp", batch_shape=[], min_event_ndims=0)',
      "<tfp.bijectors.Exp 'exp' batch_shape=[] forward_min_event_ndims=0 "
      "inverse_min_event_ndims=0 dtype_x=? dtype_y=?>")
  _check(
      tfb.Scale([1., 1.], name='myscale'),
      'tfp.bijectors.Scale("myscale", batch_shape=[2], min_event_ndims=0, '
      'dtype=float32)',
      "<tfp.bijectors.Scale 'myscale' batch_shape=[2] "
      "forward_min_event_ndims=0 inverse_min_event_ndims=0 dtype_x=float32 "
      "dtype_y=float32>")
  _check(
      tfb.Split([3, 4, 2], name='s_p_l_i_t'),
      'tfp.bijectors.Split("s_p_l_i_t", batch_shape=[], '
      'forward_min_event_ndims=1, inverse_min_event_ndims=[1, 1, 1])',
      "<tfp.bijectors.Split 's_p_l_i_t' batch_shape=[] "
      "forward_min_event_ndims=1 inverse_min_event_ndims=[1, 1, 1] "
      "dtype_x=? dtype_y=[?, ?, ?]>")
def testNoBatchScale(self, is_static, dtype):
  """A scalar scale multiplies on forward and divides on inverse."""
  scale = dtype(2.)
  bijector = tfb.Scale(scale=scale)
  x = self.maybe_static(np.array([1., 2, 3], dtype))
  self.assertAllClose([2., 4, 6], bijector.forward(x))
  self.assertAllClose([.5, 1, 1.5], bijector.inverse(x))
  # ILDJ of y -> y / 2 is -log(2) per element.
  ildj = bijector.inverse_log_det_jacobian(x, event_ndims=0)
  self.assertAllClose(-np.log(2.), ildj)
def testModifiedVariableScaleAssertion(self):
  """Zeroing a variable scale after construction trips the runtime check."""
  scale_var = tf.Variable(1.)
  self.evaluate(scale_var.initializer)
  bij = tfb.Scale(scale=scale_var, validate_args=True)
  with self.assertRaisesOpError('Argument `scale` must be non-zero'):
    zero_out = scale_var.assign(0.)
    with tf.control_dependencies([zero_out]):
      _ = self.evaluate(bij.forward(1.))
def test_unknown_event_rank(self):
  """Static shapes stay unknown, but dynamic shapes resolve, for unknown rank."""
  if tf.executing_eagerly():
    self.skipTest('Eager execution.')
  # `reinterpreted_batch_ndims` is a placeholder, so the static split between
  # batch and event dimensions cannot be determined.
  unknown_rank_dist = tfd.Independent(
      tfd.Normal(loc=tf.ones([2, 1, 3]), scale=2.),
      reinterpreted_batch_ndims=tf1.placeholder_with_default(1, shape=[]))
  unknown_shape = tf.TensorShape(None)

  td = tfd.TransformedDistribution(
      distribution=unknown_rank_dist,
      bijector=tfb.Scale(1.),
      validate_args=True)
  self.assertEqual(td.batch_shape, unknown_shape)
  self.assertEqual(td.event_shape, unknown_shape)
  self.assertAllEqual(td.batch_shape_tensor(), [2, 1])
  self.assertAllEqual(td.event_shape_tensor(), [3])

  joint_td = tfd.TransformedDistribution(
      distribution=tfd.JointDistributionSequentialAutoBatched(
          [unknown_rank_dist, unknown_rank_dist]),
      bijector=tfb.Invert(tfb.Split(2)),
      validate_args=True)
  # The current behavior is conservative; a batch shape of `[]` could also
  # be returned correctly here.
  self.assertEqual(joint_td.batch_shape, unknown_shape)
  self.assertEqual(joint_td.event_shape, unknown_shape)
  self.assertAllEqual(joint_td.batch_shape_tensor(), [])
  self.assertAllEqual(joint_td.event_shape_tensor(), [2, 1, 6])
def testScalarCongruency(self, dtype):
  """Scale(0.42) passes the generic scalar-congruency check on [-2, 2]."""
  scale_bijector = tfb.Scale(scale=dtype(0.42))
  bijector_test_util.assert_scalar_congruency(
      scale_bijector,
      lower_x=dtype(-2.),
      upper_x=dtype(2.),
      eval_func=self.evaluate)
def testStddev(self):
  """stddev/variance follow |scale|; a ScaleMatvecTriL bijector raises."""
  base_stddev = 2.
  shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
  scale = np.array([[1, -2, 3], [2, -3, 2]], dtype=np.float32)
  expected_stddev = tf.abs(base_stddev * scale)

  base = tfd.Normal(loc=tf.zeros_like(shift),
                    scale=base_stddev * tf.ones_like(scale),
                    validate_args=True)
  affine = tfb.Chain([tfb.Shift(shift=shift), tfb.Scale(scale=scale)],
                     validate_args=True)
  normal = self._cls()(distribution=base, bijector=affine, validate_args=True)
  self.assertAllClose(expected_stddev, normal.stddev())
  self.assertAllClose(expected_stddev**2, normal.variance())

  # Splitting the event still supports stddev, part by part.
  split_normal = self._cls()(
      distribution=tfd.Independent(normal, reinterpreted_batch_ndims=1),
      bijector=tfb.Split(3),
      validate_args=True)
  self.assertAllCloseNested(
      tf.split(expected_stddev, num_or_size_splits=3, axis=-1),
      split_normal.stddev())

  # ScaleMatvecTriL is reported as a multivariate transformation, for which
  # stddev is not implemented.
  scaled_normal = self._cls()(
      distribution=tfd.Independent(normal, reinterpreted_batch_ndims=1),
      bijector=tfb.ScaleMatvecTriL([[1., 0.], [-1., 2.]]),
      validate_args=True)
  with self.assertRaisesRegex(NotImplementedError,
                              'is a multivariate transformation'):
    scaled_normal.stddev()
def test_bijector_constant_underlying_ildj(self):
  """Sample's event-space bijector accumulates a constant ILDJ over samples."""
  d = tfb.Scale([2., 3.])(tfd.Normal([0., 0.], 1.))
  bij = tfd.Sample(d, [3]).experimental_default_event_space_bijector()
  zeros = tf.zeros([2, 3])
  per_batch_ildj = -np.log([2., 3.])
  self.assertAllClose(
      per_batch_ildj * 3,
      bij.inverse_log_det_jacobian(zeros, event_ndims=1))
  self.assertAllClose(
      per_batch_ildj.sum() * 3,
      bij.inverse_log_det_jacobian(zeros, event_ndims=2))
def testExcessiveConcretizationOfParamsBatchShapeOverride(self):
  # Exercise methods that are not implemented when event_shape is
  # overridden, so only batch_shape is overridden here.

  def _counted_variable(value, name, dtype=tf.float32):
    return tfp_hps.defer_and_count_usage(
        tf.Variable(value, name=name, dtype=dtype, shape=self.shape))

  loc = _counted_variable(0., 'loc')
  scale = _counted_variable(2., 'scale')
  bij_scale = _counted_variable(2., 'bij_scale')
  batch_shape = _counted_variable([4, 3, 5], 'input_batch_shape',
                                  dtype=tf.int32)
  dist = tfd.TransformedDistribution(
      distribution=tfd.Normal(loc=loc, scale=scale, validate_args=True),
      bijector=tfb.Scale(scale=bij_scale, validate_args=True),
      batch_shape=batch_shape,
      validate_args=True)

  for method in ('log_cdf', 'cdf', 'survival_function',
                 'log_survival_function'):
    with tfp_hps.assert_no_excessive_var_usage(
        method, max_permissible=self.max_permissible[method]):
      getattr(dist, method)(np.ones((4, 3, 2)) / 3.)

  with tfp_hps.assert_no_excessive_var_usage(
      'quantile', max_permissible=self.max_permissible['quantile']):
    dist.quantile(.1)
def testBijector(self):
  """Sigmoid(low, high) matches Shift(low) o Scale(high - low) o Sigmoid()."""
  low = np.array([[-3.], [0.], [5.]]).astype(np.float32)
  high = 12.
  bijector = tfb.Sigmoid(low=low, high=high, validate_args=True)
  equivalent_bijector = tfb.Chain(
      [tfb.Shift(shift=low), tfb.Scale(scale=high - low), tfb.Sigmoid()])

  x = [[[1., 2., -5., -0.3]]]
  y = self.evaluate(equivalent_bijector.forward(x))
  self.assertAllClose(y, self.evaluate(bijector.forward(x)))
  # `low` broadcasts y to three rows along axis -2; inverting any one row
  # recovers x, so only the first row is compared.
  x_roundtrip = self.evaluate(bijector.inverse(y)[..., :1, :])
  self.assertAllClose(x, x_roundtrip, rtol=1e-5)

  def _ldj(bij, method_name, arg):
    return self.evaluate(getattr(bij, method_name)(arg, event_ndims=1))

  self.assertAllClose(
      _ldj(equivalent_bijector, 'inverse_log_det_jacobian', y),
      _ldj(bijector, 'inverse_log_det_jacobian', y),
      rtol=1e-5)
  self.assertAllClose(
      _ldj(equivalent_bijector, 'forward_log_det_jacobian', x),
      _ldj(bijector, 'forward_log_det_jacobian', x))
def testQuantileDescending(self):
  """quantile stays increasing even for a batch member with negative scale."""
  base = tfd.Normal(loc=0., scale=[1., 1.])
  affine = tfb.Shift(shift=1.)(tfb.Scale(scale=[2., -2.]))
  td = self._cls()(distribution=base, bijector=affine, validate_args=True)
  all_true = tf.ones(td.batch_shape, dtype=tf.bool)
  self.assertAllEqual(all_true, td.quantile(.8) < td.quantile(.9))
def testCompositeTensor(self):
  """Blockwise is a CompositeTensor: flattenable, tf.function-able, encodable."""
  blockwise = tfb.Blockwise(
      bijectors=[tfb.Exp(), tfb.Softplus(), tfb.Scale(scale=2.)])
  self.assertIsInstance(blockwise, tf.__internal__.CompositeTensor)

  # Round-trip through `Tensor` components.
  components = tf.nest.flatten(blockwise, expand_composites=True)
  rebuilt = tf.nest.pack_sequence_as(
      blockwise, components, expand_composites=True)
  self.assertIsInstance(rebuilt, tfb.Blockwise)

  # The rebuilt bijector can be passed into a `tf.function`.
  @tf.function
  def call_forward(bij, x):
    return bij.forward(x)

  x = tf.ones([2, 3], dtype=tf.float32)
  self.assertAllClose(call_forward(rebuilt, x), blockwise.forward(x))

  # The type spec survives an encode/decode round trip.
  spec = blockwise._type_spec
  encoded = tf.__internal__.saved_model.encode_structure(spec)
  decoded = tf.__internal__.saved_model.decode_proto(encoded)
  self.assertEqual(spec, decoded)
def testExcessiveConcretizationOfParams(self):
  """Core methods stay within their variable-usage budget with overrides."""

  def _counted_variable(value, name, dtype=tf.float32):
    return tfp_hps.defer_and_count_usage(
        tf.Variable(value, name=name, dtype=dtype, shape=self.shape))

  loc = _counted_variable(0., 'loc')
  scale = _counted_variable(2., 'scale')
  bij_scale = _counted_variable(2., 'bij_scale')
  event_shape = _counted_variable([2, 2], 'input_event_shape',
                                  dtype=tf.int32)
  batch_shape = _counted_variable([4, 3, 5], 'input_batch_shape',
                                  dtype=tf.int32)
  dist = tfd.TransformedDistribution(
      distribution=tfd.Normal(loc=loc, scale=scale, validate_args=True),
      bijector=tfb.Scale(scale=bij_scale, validate_args=True),
      event_shape=event_shape,
      batch_shape=batch_shape,
      validate_args=True)

  def _budget(method):
    return tfp_hps.assert_no_excessive_var_usage(
        method, max_permissible=self.max_permissible[method])

  for method in ('mean', 'entropy', 'event_shape_tensor',
                 'batch_shape_tensor'):
    with _budget(method):
      getattr(dist, method)()

  with _budget('sample'):
    dist.sample(seed=test_util.test_seed())

  for method in ('log_prob', 'prob'):
    with _budget(method):
      getattr(dist, method)(np.ones((4, 3, 5, 2, 2)) / 3.)
def testBijectorWithDeepStructure(self):
  """JointMap handles a nested JointMap part in forward, inverse and LDJs."""
  bij = tfb.JointMap({
      'a': tfb.Exp(),
      'bc': tfb.JointMap([tfb.Scale(2.), tfb.Shift(3.)]),
  })
  a = np.asarray([[[1, 2], [2, 3]]], dtype=np.float32)  # shape=[1, 2, 2]
  b = np.asarray([[0, 4]], dtype=np.float32)  # shape=[1, 2]
  c = np.asarray([[5, 6]], dtype=np.float32)  # shape=[1, 2]
  # The same structure is fed to both forward and inverse below.
  inputs = {'a': a, 'bc': [b, c]}
  event_ndims = {'a': 1, 'bc': [0, 0]}

  self.assertStartsWith(bij.name, 'jointmap_of_exp_and_jointmap_of_')
  self.assertAllCloseNested(
      {'a': np.exp(a), 'bc': [b * 2., c + 3]},
      self.evaluate(bij.forward(inputs)))
  self.assertAllCloseNested(
      {'a': np.log(a), 'bc': [b / 2., c - 3]},
      self.evaluate(bij.inverse(inputs)))

  ldj_cases = (
      (bij.forward_log_det_jacobian, np.sum(a, axis=-1) + np.log(2)),
      (bij.inverse_log_det_jacobian, -np.log(a).sum(axis=-1) - np.log(2)),
  )
  for ldj_fn, expected in ldj_cases:
    ldj = self.evaluate(ldj_fn(inputs, event_ndims))
    self.assertEqual((1, 2), ldj.shape)
    self.assertAllClose(expected, ldj)
def testBatchShapeBroadcasts(self):
  """JointMap FLDJ broadcasts parts with differing batch shapes."""
  bij = tfb.JointMap({'a': tfb.Exp(), 'b': tfb.Scale(10.)},
                     validate_args=True)
  self.assertStartsWith(bij.name, 'jointmap_of_exp_and_scale')
  a = np.asarray([[[1, 2]], [[2, 3]]], dtype=np.float32)  # shape=[2, 1, 2]
  b = np.asarray([[0, 1, 2]], dtype=np.float32)  # shape=[1, 3]
  # The same structure could serve as input to forward or inverse.
  inputs = {'a': a, 'b': b}

  # With b's event_ndims=1, Scale's log(10) is summed over b's 3 entries.
  cases = (
      (0, a.sum(axis=-1) + np.log(10.)),
      (1, a.sum(axis=-1) + 3 * np.log(10.)),
  )
  for b_event_ndims, expected in cases:
    fldj = bij.forward_log_det_jacobian(inputs, {'a': 1, 'b': b_event_ndims})
    self.assertAllClose(expected, self.evaluate(fldj))
def default_bijector(cls, dtype: Any = None, **kwargs) -> tfb.Bijector:
    """
    Linear bijection between $[0, 1]^{2} <--> [0, 4]^{2}$, i.e. scaling by 4.

    :param dtype: dtype of the scale factor; ``default_float()`` when None.
    :param kwargs: accepted for interface compatibility and ignored.
    """
    scale_dtype = default_float() if dtype is None else dtype
    return tfb.Scale(tf.cast(4.0, dtype=scale_dtype))
def testCdfDescending(self):
  """CDF stays increasing even for a batch member with negative scale."""
  affine = tfb.Shift(shift=1.)(tfb.Scale(scale=[2., -2.]))
  td = tfd.TransformedDistribution(
      distribution=tfd.Normal(loc=0., scale=[1., 1.]),
      bijector=affine,
      validate_args=True)
  nd = tfd.Normal(loc=1., scale=2., validate_args=True)
  all_true = tf.ones(td.batch_shape, dtype=tf.bool)
  lower = td.cdf(nd.quantile(.8))
  upper = td.cdf(nd.quantile(.9))
  self.assertAllEqual(all_true, lower < upper)
def testLDJRatio(self):
  """ldj_ratio helpers match directly differencing two JointMaps' LDJs."""
  q = tfb.JointMap({'a': tfb.Exp(), 'b': tfb.Scale(2.), 'c': tfb.Shift(3.)})
  p = tfb.JointMap({'a': tfb.Exp(), 'b': tfb.Scale(3.), 'c': tfb.Shift(4.)})
  a = np.asarray([[[1, 2], [2, 3]]], dtype=np.float32)  # shape=[1, 2, 2]
  b = np.asarray([[0, 4]], dtype=np.float32)  # shape=[1, 2]
  c = np.asarray([[5, 6]], dtype=np.float32)  # shape=[1, 2]
  x = {'a': a, 'b': b, 'c': c}
  y = {'a': a + 1, 'b': b + 1, 'c': c + 1}

  all_event_ndims = (
      {'a': 1, 'b': 0, 'c': 0},
      {'a': 1, 'b': 2, 'c': 0},
  )
  for event_ndims in all_event_ndims:
    expected_fldj_ratio = (p.forward_log_det_jacobian(x, event_ndims) -
                           q.forward_log_det_jacobian(y, event_ndims))
    self.assertAllClose(
        expected_fldj_ratio,
        ldj_ratio.forward_log_det_jacobian_ratio(p, x, q, y, event_ndims))

    expected_ildj_ratio = (p.inverse_log_det_jacobian(x, event_ndims) -
                           q.inverse_log_det_jacobian(y, event_ndims))
    self.assertAllClose(
        expected_ildj_ratio,
        ldj_ratio.inverse_log_det_jacobian_ratio(p, x, q, y, event_ndims))
def testBatchScale(self, is_static, dtype):
  """A batch of scales broadcasts against a single-element input."""
  scales = dtype([2., 3.])
  bijector = tfb.Scale(scale=scales)
  x = self.maybe_static(np.array([1.], dtype=dtype))
  self.assertAllClose([2., 3.], bijector.forward(x))
  self.assertAllClose([0.5, 1. / 3.], bijector.inverse(x))
  expected_ildj = [-np.log(2.), -np.log(3.)]
  self.assertAllClose(expected_ildj,
                      bijector.inverse_log_det_jacobian(x, event_ndims=0))
def testNestedDtype(self):
  """A Chain containing a float64 Scale maps a Python list to float64 output."""
  layers = [
      tfb.Identity(),
      tfb.Scale(tf.constant(2., tf.float64)),
      tfb.Identity(),
  ]
  chain = tfb.Chain(layers)
  expected = tf.constant([2, 4, 6], tf.float64)
  self.assertAllClose(expected, self.evaluate(chain.forward([1, 2, 3])))
def default_bijector(cls, dtype: Any = None, **kwargs) -> tfb.Bijector:
    """
    Affine bijection between $[[0, 1], [0, 1]] <--> [[-2.5, 2.5], [-1.0, 2.0]]$

    Per coordinate this is ``y = scale * (x + shift)`` with
    ``scale = [5, 3]`` and ``shift = [-0.5, -1/3]``.

    :param dtype: dtype of scale and shift; ``default_float()`` when None.
    :param kwargs: accepted for interface compatibility and ignored.
    """
    if dtype is None:
        dtype = default_float()

    def as_tensor(values):
        return tf.convert_to_tensor(values, dtype=dtype)

    # tfb.Chain applies its bijectors right to left: shift first, then scale.
    return tfb.Chain([
        tfb.Scale(as_tensor([5.0, 3.0])),
        tfb.Shift(as_tensor([-0.5, -1 / 3])),
    ])
def testScalarBatchScalarEventIdentityScale(self):
  """log_prob of a 2x-scaled Exponential applies the change of variables."""
  rate = 0.25
  exp2 = self._cls()(
      tfd.Exponential(rate=rate),
      bijector=tfb.Scale(scale=2.),
      validate_args=True)
  log_prob_ = self.evaluate(exp2.log_prob(1.))
  # Base Exponential log-density at the preimage 1 / 2 = 0.5.
  base_log_prob = -0.5 * rate + np.log(rate)
  # Forward log-det-Jacobian of x -> 2 * x; subtracting it gives log p_Y(1).
  fldj = np.log(2.)
  self.assertAllClose(base_log_prob - fldj, log_prob_, rtol=1e-6, atol=0.)