def testMinEventNdimsWithJointMap(self):
    """Validates min_event_ndims of several Chains that contain JointMaps."""
    jm_0 = tfb.JointMap([ShapeChanging(1, 1), ShapeChanging(3, 1)])
    split = ShapeChanging(1, [1, 1])
    concat = ShapeChanging([1, 1], 1)
    jm_1 = tfb.JointMap([ShapeChanging(1, 0), ShapeChanging(1, 1)])
    permute = PermuteParts()

    # Each row is (chain components, expected forward_min_event_ndims,
    # expected inverse_min_event_ndims); rows are validated top to bottom.
    scenarios = (
        ([jm_0, split, concat, jm_1], [4, 3], [3, 1]),
        ([jm_0, jm_1], [2, 3], [1, 1]),
        ([jm_1, jm_0], [1, 3], [0, 1]),
        ([jm_1, permute, jm_0], [1, 3], [0, 1]),
        ([jm_0, split], 3, [3, 1]),
        ([permute, jm_1, split], 1, [1, 0]),
    )
    for chain_parts, expected_forward, expected_inverse in scenarios:
        self._validateChainMinEventNdims(
            bijectors=chain_parts,
            forward_min_event_ndims=expected_forward,
            inverse_min_event_ndims=expected_inverse)
def testBijectorWithDeepStructure(self):
    """Checks forward, inverse and log-det-Jacobians of a nested JointMap."""
    exp_part = tfb.Exp()
    scale_shift_part = tfb.JointMap([tfb.Scale(2.), tfb.Shift(3.)])
    bij = tfb.JointMap({'a': exp_part, 'bc': scale_shift_part})

    a = np.asarray([[[1, 2], [2, 3]]], dtype=np.float32)  # shape=[1, 2, 2]
    b = np.asarray([[0, 4]], dtype=np.float32)  # shape=[1, 2]
    c = np.asarray([[5, 6]], dtype=np.float32)  # shape=[1, 2]

    # The same structure is fed to both `forward` and `inverse`.
    inputs = {'a': a, 'bc': [b, c]}
    event_ndims = {'a': 1, 'bc': [0, 0]}

    self.assertStartsWith(bij.name, 'jointmap_of_exp_and_jointmap_of_')

    expected_forward = {'a': np.exp(a), 'bc': [b * 2., c + 3]}
    actual_forward = self.evaluate(bij.forward(inputs))
    self.assertAllCloseNested(expected_forward, actual_forward)

    expected_inverse = {'a': np.log(a), 'bc': [b / 2., c - 3]}
    actual_inverse = self.evaluate(bij.inverse(inputs))
    self.assertAllCloseNested(expected_inverse, actual_inverse)

    fldj = self.evaluate(bij.forward_log_det_jacobian(inputs, event_ndims))
    self.assertEqual((1, 2), fldj.shape)
    expected_fldj = np.sum(a, axis=-1) + np.log(2)
    self.assertAllClose(expected_fldj, fldj)

    ildj = self.evaluate(bij.inverse_log_det_jacobian(inputs, event_ndims))
    self.assertEqual((1, 2), ildj.shape)
    expected_ildj = -np.log(a).sum(axis=-1) - np.log(2)
    self.assertAllClose(expected_ildj, ildj)
def testNonCompositeTensor(self):
    """A JointMap holding a non-composite part must not be a CompositeTensor."""
    exp = tfb.Exp()
    scale = test_util.NonCompositeTensorScale(scale=tf.constant(3.))
    bij = tfb.JointMap(bijectors=[exp, scale])

    self.assertNotIsInstance(bij, tf.__internal__.CompositeTensor)

    # The joint forward pass must agree with applying each part separately.
    joint_result = bij.forward([1., 1.])
    per_part_result = [exp.forward(1.), scale.forward(1.)]
    self.assertAllCloseNested(joint_result, per_part_result)
def testCompositeTensor(self):
    """Checks flatten/rebuild, tf.function input and type-spec encoding."""
    parts = [tfb.Exp(), tfb.Softplus(), tfb.Scale(scale=2.)]
    bij = tfb.JointMap(bijectors=parts)
    self.assertIsInstance(bij, tf.__internal__.CompositeTensor)

    # Round-trip through `Tensor` components.
    components = tf.nest.flatten(bij, expand_composites=True)
    rebuilt = tf.nest.pack_sequence_as(
        bij, components, expand_composites=True)
    self.assertIsInstance(rebuilt, tfb.JointMap)

    # The rebuilt bijector can be passed into a `tf.function`.
    @tf.function
    def call_forward(bijector, value):
        return bijector.forward(value)

    x = [1., 2., 3.]
    traced_output = call_forward(rebuilt, x)
    eager_output = bij.forward(x)
    self.assertAllClose(traced_output, eager_output)

    # The type spec survives an encode/decode round trip.
    coder = tf.__internal__.saved_model.StructureCoder()
    encoded_spec = coder.encode_structure(bij._type_spec)
    decoded_spec = coder.decode_proto(encoded_spec)
    self.assertEqual(bij._type_spec, decoded_spec)
def testBatchShapeBroadcasts(self):
    """Checks the forward log-det-Jacobian when part batch shapes differ."""
    bij = tfb.JointMap(
        {'a': tfb.Exp(), 'b': tfb.Scale(10.)}, validate_args=True)
    self.assertStartsWith(bij.name, 'jointmap_of_exp_and_scale')

    a = np.asarray([[[1, 2]], [[2, 3]]], dtype=np.float32)  # shape=[2, 1, 2]
    b = np.asarray([[0, 1, 2]], dtype=np.float32)  # shape=[1, 3]
    inputs = {'a': a, 'b': b}  # Could be inputs to forward or inverse.

    # (event_ndims for part 'b', expected forward log-det-Jacobian).
    cases = (
        (0, a.sum(axis=-1) + np.log(10.)),
        (1, a.sum(axis=-1) + 3 * np.log(10.)),
    )
    for b_event_ndims, expected_fldj in cases:
        fldj = bij.forward_log_det_jacobian(
            inputs, {'a': 1, 'b': b_event_ndims})
        self.assertAllClose(expected_fldj, self.evaluate(fldj))
def test_batch_broadcast_vector_to_parts(self):
    """A batch shape on one JointMap part must appear on every sampled part."""
    batch_shape = [4, 2]
    true_split_sizes = [1, 3, 2]
    base_event_size = sum(true_split_sizes)

    # The base distribution has no batch shape of its own, so it has to be
    # broadcast.
    loc = tf.random.normal([base_event_size], seed=test_util.test_seed())
    scale_diag = tf.exp(
        tf.random.normal([base_event_size], seed=test_util.test_seed()))
    base_dist = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag)

    # Only the last part's bijector carries the batch shape.
    batched_shift = tfb.Shift(tf.ones(batch_shape + [true_split_sizes[-1]]))
    per_part = tfb.JointMap([tfb.Identity(), tfb.Identity(), batched_shift])
    bijector = tfb.Chain([per_part, tfb.Split(true_split_sizes, axis=-1)])

    split_dist = tfd.TransformedDistribution(base_dist, bijector)
    self.assertAllEqual(split_dist.batch_shape, batch_shape)

    # Since one branch after the split is batched, batches of base samples are
    # expected to flow *into* the split, giving every branch the batch shape.
    xs = split_dist.sample(seed=test_util.test_seed())
    sampled_shapes = tf.nest.map_structure(lambda x: x.shape, xs)
    expected_shapes = [batch_shape + [d] for d in true_split_sizes]
    self.assertAllEqualNested(sampled_shapes, expected_shapes)
def testNonCompositeTensor(self):
    """A JointMap with a locally defined non-composite part is not composite."""

    # TODO(b/182603117): Move NonComposite* into test_util.
    class NonCompositeScale(tfb.Bijector):
        """Bijector that is not a `CompositeTensor`."""

        def __init__(self, scale):
            # Must stay the first statement so that only the constructor
            # arguments (and closure names) are captured.
            parameters = dict(locals())
            self.scale = scale
            # Keep the explicit two-argument `super`: the zero-argument form
            # would add `__class__` to the `locals()` captured above.
            super(NonCompositeScale, self).__init__(
                validate_args=True,
                forward_min_event_ndims=0.,
                parameters=parameters,
                name='non_composite_scale')

        def _forward(self, x):
            return x * self.scale

    exp = tfb.Exp()
    scale = NonCompositeScale(scale=tf.constant(3.))
    bij = tfb.JointMap(bijectors=[exp, scale])

    self.assertNotIsInstance(bij, tf.__internal__.CompositeTensor)

    # The joint forward pass must agree with applying each part separately.
    joint_result = bij.forward([1., 1.])
    per_part_result = [exp.forward(1.), scale.forward(1.)]
    self.assertAllClose(joint_result, per_part_result)
def testLDJRatio(self):
    """Compares the ldj_ratio helpers with explicit log-det-Jacobian differences."""
    q = tfb.JointMap({
        'a': tfb.Exp(),
        'b': tfb.Scale(2.),
        'c': tfb.Shift(3.),
    })
    p = tfb.JointMap({
        'a': tfb.Exp(),
        'b': tfb.Scale(3.),
        'c': tfb.Shift(4.),
    })

    a = np.asarray([[[1, 2], [2, 3]]], dtype=np.float32)  # shape=[1, 2, 2]
    b = np.asarray([[0, 4]], dtype=np.float32)  # shape=[1, 2]
    c = np.asarray([[5, 6]], dtype=np.float32)  # shape=[1, 2]

    x = {'a': a, 'b': b, 'c': c}
    y = {'a': a + 1, 'b': b + 1, 'c': c + 1}

    # The same four checks are run for each event_ndims configuration.
    for event_ndims in ({'a': 1, 'b': 0, 'c': 0}, {'a': 1, 'b': 2, 'c': 0}):
        fldj_ratio_true = (
            p.forward_log_det_jacobian(x, event_ndims)
            - q.forward_log_det_jacobian(y, event_ndims))
        fldj_ratio = ldj_ratio.forward_log_det_jacobian_ratio(
            p, x, q, y, event_ndims)
        self.assertAllClose(fldj_ratio_true, fldj_ratio)

        ildj_ratio_true = (
            p.inverse_log_det_jacobian(x, event_ndims)
            - q.inverse_log_det_jacobian(y, event_ndims))
        ildj_ratio = ldj_ratio.inverse_log_det_jacobian_ratio(
            p, x, q, y, event_ndims)
        self.assertAllClose(ildj_ratio_true, ildj_ratio)
def test_slice_single_param_bijector_composition(self):
    """Slices a JointMap-of-Chain parameter and checks its batch shape."""
    inner_chain = tfb.Chain([tfb.Invert(tfb.Scale(tf.ones([4, 3, 1])))])
    joint = tfb.JointMap({'a': inner_chain})

    sliced = slicing._slice_single_param(
        joint,
        param_event_ndims={'a': 1},
        slices=make_slices[..., tf.newaxis, 2:, tf.newaxis],
        batch_shape=tf.constant([7, 4, 3]))

    # Apply the same slice to a plain tensor to obtain the reference shape.
    reference = tf.zeros([1, 4, 3])[..., tf.newaxis, 2:, tf.newaxis]
    expected_batch_shape = list(reference.shape)
    actual_batch_shape = sliced.experimental_batch_shape_tensor(
        x_event_ndims={'a': 1})
    self.assertAllEqual(expected_batch_shape, actual_batch_shape)
def testMixedDtypeLogDetJacobian(self):
    """The joint log-det-Jacobian of mixed-precision parts is float64."""
    part_specs = (
        ('a', 1, tf.float16),
        ('b', 2, tf.float32),
        ('c', 3, tf.float64),
    )
    bij = tfb.JointMap({
        key: tfb.Scale(tf.constant(value, dtype=dtype))
        for key, value, dtype in part_specs
    })

    fldj = bij.forward_log_det_jacobian(
        x={'a': 4, 'b': 5, 'c': 6},
        event_ndims={key: 0 for key in 'abc'})

    self.assertDTypeEqual(fldj, np.float64)
    expected = np.log(1) + np.log(2) + np.log(3)
    self.assertAllClose(expected, self.evaluate(fldj))
def get_trainable_shift_bijector(flat_event_size,
                                 init_loc_unconstrained,
                                 dtype=DEFAULT_FLOAT_DTYPE_TF):
    """Builds a `JointMap` of `Shift` bijectors backed by `tf.Variable`s.

    Args:
      flat_event_size: structure of sizes; each leaf `s` gives the length of
        the 1-D shift vector for that part.
      init_loc_unconstrained: structure parallel to `flat_event_size`. A leaf
        that is `None` is replaced by a uniform draw from [-2, 2) of shape
        `(s,)`; any other leaf is used as the variable's initial value as-is.
      dtype: dtype of the random initial values.

    Returns:
      A `tfb.JointMap` with one `tfb.Shift` per leaf.
    """

    def _make_shift(size, init):
        # Only draw a random start when the caller did not provide one.
        if init is None:
            initial_value = tf.random.uniform(
                (size,), minval=-2.0, maxval=2.0, dtype=dtype)
        else:
            initial_value = init
        return tfb.Shift(tf.Variable(initial_value))

    shifts = tf.nest.map_structure(
        _make_shift, flat_event_size, init_loc_unconstrained)
    return tfb.JointMap(shifts)
def testMinEventNdimsWithJointMap(self):
    """Checks when a Chain with JointMaps has static min_event_ndims."""
    jm_0 = tfb.JointMap([ShapeChanging(1, 1), ShapeChanging(3, 1)])
    split = ShapeChanging(1, [1, 1])
    concat = ShapeChanging([1, 1], 1)
    jm_1 = tfb.JointMap([ShapeChanging(1, 0), ShapeChanging(1, 1)])

    for joint_map in (jm_0, jm_1):
        self.assertFalse(joint_map.has_static_min_event_ndims)
    for multipart in (split, concat):
        self.assertTrue(multipart.has_static_min_event_ndims)

    # Decidable: the inner bijectors have static min_event_ndims.
    decidable = tfb.Chain([jm_0, split, concat, jm_1])
    self.assertTrue(decidable.has_static_min_event_ndims)
    self.assertAllEqualNested([4, 3], decidable.forward_min_event_ndims)
    self.assertAllEqualNested([3, 1], decidable.inverse_min_event_ndims)

    # Undecidable: none of the nested bijectors have known event_ndims.
    undecidable = tfb.Chain([jm_0, jm_1])
    self.assertFalse(undecidable.has_static_min_event_ndims)
    unknown = [None, None]
    self.assertAllEqualNested(unknown, undecidable.forward_min_event_ndims)
    self.assertAllEqualNested(unknown, undecidable.inverse_min_event_ndims)
def test_transform_joint_to_joint(self, split_sizes):
    """Transforms a joint distribution part-by-part and checks its shapes."""
    raw_batch_shapes = [[2, 3], [2, 1], [1, 3]]
    dist_batch_shape = tf.nest.pack_sequence_as(
        split_sizes,
        [tensorshape_util.constant_value_as_shape(s)
         for s in raw_batch_shapes])
    bijector_batch_shape = [1, 3]

    # Build a joint distribution with parts of the specified sizes.
    seed = test_util.test_seed_stream()

    def _make_component(size, batch_shape):
        # `loc` is drawn before `scale_diag` so the seed stream is consumed in
        # the same order as with inline keyword arguments.
        loc = tf.random.normal(batch_shape + [size], seed=seed())
        scale_diag = tf.random.uniform(
            minval=1., maxval=2., shape=batch_shape + [size], seed=seed())
        return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag)

    component_dists = tf.nest.map_structure(
        _make_component, split_sizes, dist_batch_shape)
    joint_cls = (tfd.JointDistributionNamed
                 if isinstance(split_sizes, dict)
                 else tfd.JointDistributionSequential)
    base_dist = joint_cls(component_dists)

    # Transform the distribution by applying a separate bijector to each part.
    bijectors = [
        tfb.Exp(),
        tfb.Scale(tf.random.uniform(
            minval=1., maxval=2., shape=bijector_batch_shape, seed=seed())),
        tfb.Reshape([2, 1]),
    ]
    bijector = tfb.JointMap(
        tf.nest.pack_sequence_as(split_sizes, bijectors), validate_args=True)

    # The parts of the base distribution have differing batch shapes.
    transformed_dist = tfd.TransformedDistribution(base_dist, bijector)

    name_pattern = '{}.*batch_shape.*event_shape.*dtype'.format(
        transformed_dist.name)
    self.assertRegex(str(transformed_dist), name_pattern)

    # Static event shape.
    self.assertAllEqualNested(
        transformed_dist.event_shape,
        bijector.forward_event_shape(base_dist.event_shape))

    # Dynamic event shape.
    actual_event_shape, expected_event_shape = self.evaluate((
        transformed_dist.event_shape_tensor(),
        bijector.forward_event_shape_tensor(base_dist.event_shape_tensor())))
    self.assertAllEqualNested(actual_event_shape, expected_event_shape)

    # Batch shape components must pass through the transformation unchanged.
    self.assertAllEqualNested(transformed_dist.batch_shape, dist_batch_shape)
    self.assertAllEqualNested(
        self.evaluate(transformed_dist.batch_shape_tensor()),
        dist_batch_shape)
    self.assertAllEqualNested(dist_batch_shape, base_dist.batch_shape)
def testMinEventNdimsWithPartiallyDependentJointMap(self):
    """Wrapping a multipart bijector in a JointMap keeps its event-ndims info."""
    dependent = tfb.Chain([tfb.Split(2), tfb.Invert(tfb.Split(2))])
    wrap_in_list = tfb.Restructure(
        input_structure=[0, 1], output_structure=[[0, 1]])
    dependent_as_chain = tfb.Chain([
        tfb.Invert(wrap_in_list),
        tfb.JointMap([dependent]),
        wrap_in_list,
    ])

    # The wrapped chain must report the same values as the bare bijector.
    for attribute in ('forward_min_event_ndims',
                      'inverse_min_event_ndims',
                      '_parts_interact'):
        self.assertAllEqualNested(
            getattr(dependent, attribute),
            getattr(dependent_as_chain, attribute))
def get_mean_field_approximation(
    model: tfd.JointDistribution,
    init_loc=None,
    dtype=DEFAULT_FLOAT_DTYPE_TF,
    joint_bijector_func: tp.Callable[
        [tfd.JointDistribution],
        tfb.Composition] = get_default_event_space_bijector,
    event_shape_fn: tp.Callable[[tfd.JointDistribution],
                                object] = event_shape_fn,
) -> tfd.Distribution:
    """Builds a trainable transformed distribution over the events of `model`.

    A base distribution from `get_base_distribution` is pushed through, in
    order of application: a block linear-operator scale, a trainable shift, a
    per-part reshape, a restructure back to the model's event structure and
    the bijector returned by `joint_bijector_func`.

    Args:
      model: joint distribution whose event structure is approximated.
      init_loc: optional mapping of initial locations; keys that are missing
        resolve to `None`. When `None`, every part gets a `None` entry.
      dtype: dtype passed to the trainable operator, shift and base
        distribution builders.
      joint_bijector_func: callable mapping `model` to its event-space
        bijector.
      event_shape_fn: callable mapping `model` to its (nested) event shape.

    Returns:
      A `tfd.TransformedDistribution`.
    """
    event_shape = event_shape_fn(model)
    flat_shapes = tf.nest.flatten(event_shape)
    flat_sizes = tf.nest.map_structure(tf.reduce_prod, flat_shapes)

    # Scale part of the approximation.
    scale_bijector = tfb.ScaleMatvecLinearOperatorBlock(
        build_trainable_linear_operator_block(
            get_mean_field_operator_classes(flat_sizes),
            flat_sizes,
            dtype=dtype))

    if init_loc is not None:
        init_loc = defaultdict(lambda: None, init_loc)  # TODO: Handle nesting
    else:
        init_loc = tf.nest.map_structure(lambda _: None, flat_shapes)

    event_space_bijector = joint_bijector_func(model)
    unflatten_bijector = tfb.Restructure(
        tf.nest.pack_sequence_as(event_shape, range(len(flat_shapes))))
    reshape_bijector = tfb.JointMap(
        tf.nest.map_structure(tfb.Reshape, flat_shapes))

    # Pull the initial locations back through each layer, skipping `None`s.
    loc_unconstrained = joint_inverse_with_nones(
        event_space_bijector, init_loc)
    loc_flat = unflatten_bijector.inverse(loc_unconstrained)
    loc_1d = joint_inverse_with_nones(reshape_bijector, loc_flat)
    loc_bijector = get_trainable_shift_bijector(
        flat_sizes, loc_1d, dtype=dtype)

    base_standard_dist = get_base_distribution(flat_sizes, dtype=dtype)
    layers = [
        event_space_bijector,
        unflatten_bijector,
        reshape_bijector,
        loc_bijector,
        scale_bijector,
    ]
    return tfd.TransformedDistribution(base_standard_dist, tfb.Chain(layers))
def test_composition_str_and_repr_match_expected_dynamic_shape(self):
    """Checks `str`/`repr` of Chains built from `self._tensor` parameters."""
    # Single-part chain.
    bij = tfb.Chain([
        tfb.Exp(),
        tfb.Shift(self._tensor([1., 2.])),
        tfb.SoftmaxCentered(),
    ])
    expected_str_parts = [
        'tfp.bijectors.Chain(',
        ('min_event_ndims=1, bijectors=[Exp, Shift, SoftmaxCentered])'),
    ]
    self.assertContainsInOrder(expected_str_parts, str(bij))
    expected_repr_parts = [
        '<tfp.bijectors.Chain ',
        ('batch_shape=? forward_min_event_ndims=1 inverse_min_event_ndims=1 '
         'dtype_x=float32 dtype_y=float32 bijectors=[<tfp.bijectors.Exp'),
        '>, <tfp.bijectors.Shift',
        '>, <tfp.bijectors.SoftmaxCentered',
        '>]>',
    ]
    self.assertContainsInOrder(expected_repr_parts, repr(bij))

    # Chain ending in a dict-structured JointMap.
    bij = tfb.Chain([
        tfb.JointMap({
            'a': tfb.Exp(),
            'b': tfb.ScaleMatvecDiag(self._tensor([2., 2.])),
        }),
        tfb.Restructure({'a': 0, 'b': 1}, [0, 1]),
        tfb.Split(2),
        tfb.Invert(tfb.SoftmaxCentered()),
    ])
    expected_str_parts = [
        'tfp.bijectors.Chain(',
        ('forward_min_event_ndims=1, '
         'inverse_min_event_ndims={a: 1, b: 1}, '
         'bijectors=[JointMap({a: Exp, b: ScaleMatvecDiag}), '
         'Restructure, Split, Invert(SoftmaxCentered)])'),
    ]
    self.assertContainsInOrder(expected_str_parts, str(bij))
    expected_repr_parts = [
        '<tfp.bijectors.Chain ',
        ('batch_shape=? forward_min_event_ndims=1 '
         "inverse_min_event_ndims={'a': 1, 'b': 1} dtype_x=float32 "
         "dtype_y={'a': ?, 'b': float32} "
         "bijectors=[<tfp.bijectors.JointMap "),
        '>, <tfp.bijectors.Restructure',
        '>, <tfp.bijectors.Split',
        '>, <tfp.bijectors.Invert',
        '>]>',
    ]
    self.assertContainsInOrder(expected_repr_parts, repr(bij))
def test_slice_transformed_distribution_with_chain(self):
    """Slicing a multipart TransformedDistribution slices its batch shape."""
    base = tfd.MultivariateNormalDiag(
        loc=tf.zeros([4]), scale_diag=tf.ones([1, 4]))
    chain = tfb.Chain([
        tfb.JointMap([tfb.Identity(), tfb.Shift(tf.ones([4, 3, 2]))]),
        tfb.Split(2),
        tfb.ScaleMatvecDiag(tf.ones([5, 1, 3, 4])),
        tfb.Exp(),
    ])
    dist = tfd.TransformedDistribution(distribution=base, bijector=chain)

    def _sample_shapes(distribution):
        # Static shapes of one sample, with the same structure as the sample.
        draw = distribution.sample(seed=test_util.test_seed())
        return tf.nest.map_structure(lambda x: x.shape, draw)

    self.assertAllEqual(dist.batch_shape_tensor(), [5, 4, 3])
    self.assertAllEqualNested(
        _sample_shapes(dist), [[5, 4, 3, 2], [5, 4, 3, 2]])

    sliced = dist[tf.newaxis, ..., 0, :, :-1]
    self.assertAllEqual(sliced.batch_shape_tensor(), [1, 4, 2])
    self.assertAllEqualNested(
        _sample_shapes(sliced), [[1, 4, 2, 2], [1, 4, 2, 2]])
def build_factored_surrogate_posterior(
        event_shape=None,
        bijector=None,
        constraining_bijectors=None,
        initial_unconstrained_loc=_sample_uniform_initial_loc,
        initial_unconstrained_scale=1e-2,
        trainable_distribution_fn=_build_trainable_normal_dist,
        seed=None,
        validate_args=False,
        name=None):
    """Builds a joint variational posterior that factors over model variables.

    By default, this method creates an independent trainable Normal
    distribution for each variable, transformed using a bijector (if provided)
    to match the support of that variable. This makes extremely strong
    assumptions about the posterior: that it is approximately normal (or
    transformed normal), and that all model variables are independent.

    Args:
      event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
        specifying the event shape(s) of the posterior variables.
      bijector: Optional `tfb.Bijector` instance, or nested structure of such
        instances, defining support(s) of the posterior variables. The
        structure must match that of `event_shape` and may contain `None`
        values. A posterior variable will be modeled as
        `tfd.TransformedDistribution(underlying_dist, bijector)` if a
        corresponding constraining bijector is specified, otherwise it is
        modeled as supported on the unconstrained real line.
      constraining_bijectors: Deprecated alias for `bijector`.
      initial_unconstrained_loc: Optional Python `callable` with signature
        `tensor = initial_unconstrained_loc(shape, seed)` used to sample
        real-valued initializations for the unconstrained representation of
        each variable. May alternately be a nested structure of
        `Tensor`s, giving specific initial locations for each variable; these
        must have structure matching `event_shape` and shapes determined by
        the inverse image of `event_shape` under `bijector`, which may
        optionally be prefixed with a common batch shape.
        Default value: `functools.partial(tf.random.uniform,
          minval=-2., maxval=2., dtype=tf.float32)`.
      initial_unconstrained_scale: Optional scalar float `Tensor` initial
        scale for the unconstrained distributions, or a nested structure of
        `Tensor` initial scales for each variable.
        Default value: `1e-2`.
      trainable_distribution_fn: Optional Python `callable` with signature
        `trainable_dist = trainable_distribution_fn(initial_loc,
        initial_scale, event_ndims, validate_args)`. This is called for each
        model variable to build the corresponding factor in the surrogate
        posterior. It is expected that the distribution returned is supported
        on unconstrained real values.
        Default value: `functools.partial(
          tfp.experimental.vi.build_trainable_location_scale_distribution,
          distribution_fn=tfd.Normal)`, i.e., a trainable Normal distribution.
      seed: Python integer to seed the random number generator. This is used
        only when `initial_unconstrained_loc` is a callable, i.e. when the
        initial locations are sampled rather than given explicitly.
      validate_args: Python `bool`. Whether to validate input with asserts.
        This imposes a runtime cost. If `validate_args` is `False`, and the
        inputs are invalid, correct behavior is not guaranteed.
        Default value: `False`.
      name: Python `str` name prefixed to ops created by this function.
        Default value: `None` (i.e., 'build_factored_surrogate_posterior').

    Returns:
      surrogate_posterior: A `tfd.Distribution` instance whose samples have
        shape and structure matching that of `event_shape` or
        `initial_unconstrained_loc`.

    ### Examples

    Consider a Gamma model with unknown parameters, expressed as a joint
    Distribution:

    ```python
    Root = tfd.JointDistributionCoroutine.Root
    def model_fn():
      concentration = yield Root(tfd.Exponential(1.))
      rate = yield Root(tfd.Exponential(1.))
      y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                           sample_shape=4)
    model = tfd.JointDistributionCoroutine(model_fn)
    ```

    Let's use variational inference to approximate the posterior over the
    data-generating parameters for some observed `y`. We'll build a
    surrogate posterior distribution by specifying the shapes of the latent
    `concentration` and `rate` parameters, and that both are constrained to
    be positive.

    ```python
    surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
      event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
      bijector=[tfb.Softplus(),   # Concentration is positive.
                tfb.Softplus()])  # Rate is positive.
    ```

    This creates a trainable joint distribution, defined by variables in
    `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
    to fit this distribution by minimizing a divergence to the true posterior.

    ```python
    y = [0.2, 0.5, 0.3, 0.7]
    losses = tfp.vi.fit_surrogate_posterior(
      lambda concentration, rate: model.log_prob([concentration, rate, y]),
      surrogate_posterior=surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)

    # After optimization, samples from the surrogate will approximate
    # samples from the true posterior.
    samples = surrogate_posterior.sample(100)
    posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
    posterior_std = [tf.math.reduce_std(x) for x in samples]  # std ~= [0.3, 0.8]
    ```

    If we wanted to initialize the optimization at a specific location, we can
    specify one when we build the surrogate posterior. This function requires
    the initial location to be specified in *unconstrained* space; we do this
    by inverting the constraining bijectors (note this section also
    demonstrates the creation of a dict-structured model).

    ```python
    initial_loc = {'concentration': 0.4, 'rate': 0.2}
    bijector={'concentration': tfb.Softplus(),   # Concentration is positive.
              'rate': tfb.Softplus()}            # Rate is positive.
    initial_unconstrained_loc = tf.nest.map_structure(
      lambda b, x: b.inverse(x) if b is not None else x, bijector, initial_loc)
    surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
      event_shape=tf.nest.map_structure(tf.shape, initial_loc),
      bijector=bijector,
      initial_unconstrained_loc=initial_unconstrained_loc,
      initial_unconstrained_scale=1e-4)
    ```

    """
    with tf.name_scope(name or 'build_factored_surrogate_posterior'):
        bijector = deprecation.deprecated_argument_lookup(
            'bijector', bijector,
            'constraining_bijectors', constraining_bijectors)

        seed = tfp_util.SeedStream(
            seed, salt='build_factored_surrogate_posterior')

        # Convert event shapes to Tensors.
        shallow_structure = _get_event_shape_shallow_structure(event_shape)
        event_shape = nest.map_structure_up_to(
            shallow_structure,
            lambda s: tf.convert_to_tensor(s, dtype=tf.int32),
            event_shape)

        if nest.is_nested(bijector):
            # `None` entries mean "no constraint", i.e. the identity map.
            bijector = nest.map_structure(
                lambda b: identity_bijector.Identity() if b is None else b,
                bijector)

            # Support mismatched nested structures for backwards
            # compatibility (e.g. non-nested `event_shape` and a
            # single-element list of `bijector`s).
            bijector = nest.pack_sequence_as(
                event_shape, nest.flatten(bijector))

            event_space_bijector = tfb.JointMap(
                bijector, validate_args=validate_args)
        else:
            event_space_bijector = bijector

        if event_space_bijector is None:
            unconstrained_event_shape = event_shape
        else:
            unconstrained_event_shape = (
                event_space_bijector.inverse_event_shape_tensor(event_shape))

        # Construct initial locations for the internal unconstrained dists.
        if callable(initial_unconstrained_loc):
            # Sample random initialization. The callable is bound to its own
            # name so the lambda does not close over a variable that is
            # rebound by this very statement.
            sample_loc = initial_unconstrained_loc
            initial_unconstrained_loc = nest.map_structure(
                lambda s: sample_loc(shape=s, seed=seed()),
                unconstrained_event_shape)

        if not nest.is_nested(initial_unconstrained_scale):
            # Share one scale value across all variables.
            shared_scale = initial_unconstrained_scale
            initial_unconstrained_scale = nest.map_structure(
                lambda _: shared_scale,
                unconstrained_event_shape)

        # Extract the rank of each event, so that we build distributions with
        # the correct event shapes.
        unconstrained_event_ndims = nest.map_structure(
            ps.rank_from_shape,
            unconstrained_event_shape)

        # Build the component surrogate posteriors.
        unconstrained_distributions = nest.map_structure_up_to(
            unconstrained_event_shape,
            lambda loc, scale, ndims: trainable_distribution_fn(  # pylint: disable=g-long-lambda
                loc, scale, ndims, validate_args=validate_args),
            initial_unconstrained_loc,
            initial_unconstrained_scale,
            unconstrained_event_ndims)

        base_distribution = (
            joint_distribution_util
            .independent_joint_distribution_from_structure(
                unconstrained_distributions, validate_args=validate_args))
        if event_space_bijector is None:
            return base_distribution
        return transformed_distribution.TransformedDistribution(
            base_distribution, event_space_bijector)
def init_near_unconstrained_zero(
        model=None, constraining_bijector=None, event_shapes=None,
        event_shape_tensors=None, batch_shapes=None,
        batch_shape_tensors=None, dtypes=None):
    """Returns an initialization Distribution for starting a Markov chain.

    This initialization scheme follows Stan: we sample every latent
    independently, uniformly from -2 to 2 in its unconstrained space,
    and then transform into constrained space to construct an initial
    state that can be passed to `sample_chain` or other MCMC drivers.

    The argument signature is arranged to let the user pass either a
    `JointDistribution` describing their model, if it's in that form,
    or the essential information necessary for the sampling, namely a
    bijector (from unconstrained to constrained space) and the desired
    shape and dtype of each sample (specified in constrained space).

    Note: when `model` is supplied and no batch shape override is given, the
    model's `batch_shape` and `batch_shape_tensor()` are applied to every part
    (via `nest_util.broadcast_structure` over the structure of `dtypes`).

    Args:
      model: A `Distribution` (typically a `JointDistribution`) giving the
        model to be initialized.  If supplied, it is queried for
        its default event space bijector, its event shape, and its
        dtype.  If not supplied, those three elements must be supplied
        instead.
      constraining_bijector: A (typically multipart) `Bijector` giving
        the mapping from unconstrained to constrained space.  If
        supplied together with a `model`, acts as an override.  A nested
        structure of `Bijector`s is accepted, and interpreted as
        applying in parallel to a corresponding structure of state parts
        (see `JointMap` for details).
      event_shapes: A structure of shapes giving the (constrained) event
        space shape of the desired samples.  Must be an acceptable input to
        `constraining_bijector.inverse_event_shape`.  If supplied together
        with `model`, acts as an override.
      event_shape_tensors: A structure of tensors giving the (constrained)
        event space shape of the desired samples.  Must be an acceptable
        input to `constraining_bijector.inverse_event_shape_tensor`.  If
        supplied together with `model`, acts as an override. Required if any
        of `event_shapes` are not fully-defined.
      batch_shapes: A structure of shapes giving the batch shape of the
        desired samples.  If supplied together with `model`, acts as an
        override.  If unspecified, we assume scalar batch `[]`.
      batch_shape_tensors: A structure of tensors giving the batch shape of
        the desired samples.  If supplied together with `model`, acts as an
        override. Required if any of `batch_shapes` are not fully-defined.
      dtypes: A structure of dtypes giving the (constrained) dtypes of the
        desired samples.  Must be an acceptable input to
        `constraining_bijector.inverse_dtype`.  If supplied together
        with `model`, acts as an override.

    Returns:
      init_dist: A `Distribution` representing the initialization
        distribution, in constrained space.  Samples from this
        `Distribution` are valid initial states for a Markov chain
        targeting the model.

    Raises:
      ValueError: if `model` is `None` and any of `constraining_bijector`,
        `event_shapes` or `dtypes` is also `None`.
      ValueError: if some `event_shapes` are not fully defined and no
        `event_shape_tensors` are available.
      ValueError: if some `batch_shapes` are not fully defined and no
        `batch_shape_tensors` are available.

    #### Example

    Initialize 100 chains from the unconstrained -2, 2 distribution
    for a model expressed as a `JointDistributionCoroutine`:

    ```python
    @tfp.distributions.JointDistributionCoroutine
    def model():
      ...

    init_dist = tfp.experimental.mcmc.init_near_unconstrained_zero(model)

    states = tfp.mcmc.sample_chain(
      current_state=init_dist.sample(100, seed=[4, 8]),
      ...)
    ```

    """
    # Canonicalize arguments into the parts we need, namely
    # the constraining_bijector, the event_shapes, and the dtypes.
    if model is not None:
        # Got a Distribution model; treat other arguments as overrides if
        # present.
        if constraining_bijector is None:
            # pylint: disable=protected-access
            constraining_bijector = model.experimental_default_event_space_bijector()
        if event_shapes is None:
            event_shapes = model.event_shape
        if event_shape_tensors is None:
            event_shape_tensors = model.event_shape_tensor()
        if dtypes is None:
            dtypes = model.dtype
        # The model's single batch shape is replicated across the structure
        # of `dtypes` so that every part has its own entry.
        if batch_shapes is None:
            batch_shapes = nest_util.broadcast_structure(dtypes, model.batch_shape)
        if batch_shape_tensors is None:
            batch_shape_tensors = nest_util.broadcast_structure(
                dtypes, model.batch_shape_tensor())

    else:
        # NOTE(review): the nesting of the checks below under this `else` was
        # reconstructed from whitespace-collapsed source — confirm.
        if constraining_bijector is None or event_shapes is None or dtypes is None:
            msg = ('Must pass either a Distribution (typically a JointDistribution), '
                   'or a bijector, a structure of event shapes, and a '
                   'structure of dtypes')
            raise ValueError(msg)
        event_shapes_fully_defined = all(tensorshape_util.is_fully_defined(s)
                                         for s in tf.nest.flatten(event_shapes))
        if not event_shapes_fully_defined and event_shape_tensors is None:
            raise ValueError('Must specify `event_shape_tensors` when `event_shapes` '
                             f'are not fully-defined: {event_shapes}')
        # Default to a scalar batch, then give every part its own entry.
        if batch_shapes is None:
            batch_shapes = tf.TensorShape([])
        batch_shapes = nest_util.broadcast_structure(dtypes, batch_shapes)
        batch_shapes_fully_defined = all(tensorshape_util.is_fully_defined(s)
                                         for s in tf.nest.flatten(batch_shapes))
        if batch_shape_tensors is None:
            if not batch_shapes_fully_defined:
                raise ValueError(
                    'Must specify `batch_shape_tensors` when `batch_shapes` are not '
                    f'fully-defined: {batch_shapes}')
            # Static shapes are fully known here, so they can stand in for
            # the dynamic ones.
            batch_shape_tensors = tf.nest.map_structure(
                tf.convert_to_tensor, batch_shapes)

    # Interpret a structure of Bijectors as the joint multipart bijector.
    if not isinstance(constraining_bijector, tfb.Bijector):
        constraining_bijector = tfb.JointMap(constraining_bijector)

    # Actually initialize
    def one_term(event_shape, event_shape_tensor, batch_shape,
                 batch_shape_tensor, dtype):
        # Builds the Uniform(-2, 2) initializer for one unconstrained part,
        # preferring static shapes and falling back to the shape tensors.
        if not tensorshape_util.is_fully_defined(event_shape):
            event_shape = event_shape_tensor
        result = tfd.Sample(
            tfd.Uniform(low=tf.constant(-2., dtype=dtype),
                        high=tf.constant(2., dtype=dtype)),
            sample_shape=event_shape)
        if not tensorshape_util.is_fully_defined(batch_shape):
            batch_shape = batch_shape_tensor
            needs_bcast = True
        else:  # Only batch broadcast when batch ndims > 0.
            needs_bcast = bool(tensorshape_util.as_list(batch_shape))
        if needs_bcast:
            result = tfd.BatchBroadcast(result, batch_shape)
        return result

    # Map the constrained-space shapes and dtypes back to unconstrained space.
    inv_shapes = constraining_bijector.inverse_event_shape(event_shapes)
    if event_shape_tensors is not None:
        inv_shape_tensors = constraining_bijector.inverse_event_shape_tensor(
            event_shape_tensors)
    else:
        # Placeholder structure so that `map_structure` below still lines up.
        inv_shape_tensors = tf.nest.map_structure(lambda _: None, inv_shapes)
    inv_dtypes = constraining_bijector.inverse_dtype(dtypes)
    terms = tf.nest.map_structure(
        one_term, inv_shapes, inv_shape_tensors,
        batch_shapes, batch_shape_tensors, inv_dtypes)
    # Build a flat joint distribution, then restore the structure of
    # `inv_shapes` before applying the constraining bijector.
    unconstrained = tfb.pack_sequence_as(inv_shapes)(
        tfd.JointDistributionSequential(tf.nest.flatten(terms)))
    return tfd.TransformedDistribution(
        unconstrained, bijector=constraining_bijector)
class BatchShapeInferenceTests(test_util.TestCase):
  """Tests of batch-shape inference and parameter broadcasting."""

  @parameterized.named_parameters(
      dict(
          testcase_name='_trivial',
          value_fn=lambda: tfd.Normal(loc=0., scale=1.),
          expected_batch_shape=[]),
      dict(
          testcase_name='_simple_tensor_broadcasting',
          value_fn=lambda: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
              loc=[0., 0.],
              scale_diag=tf.convert_to_tensor([[1., 1.], [1., 1.]])),
          expected_batch_shape=[2]),
      dict(
          testcase_name='_rank_deficient_tensor_broadcasting',
          value_fn=lambda: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
              loc=0.,
              scale_diag=tf.convert_to_tensor([[1., 1.], [1., 1.]])),
          expected_batch_shape=[2]),
      dict(
          testcase_name='_mixture_same_family',
          value_fn=lambda: tfd.MixtureSameFamily(  # pylint: disable=g-long-lambda
              mixture_distribution=tfd.Categorical(
                  logits=[[[1., 2., 3.], [4., 5., 6.]]]),
              components_distribution=tfd.Normal(
                  loc=0., scale=[[[1., 2., 3.], [4., 5., 6.]]])),
          expected_batch_shape=[1, 2]),
      dict(
          testcase_name='_deeply_nested',
          value_fn=lambda: tfd.Independent(  # pylint: disable=g-long-lambda
              tfd.Independent(
                  tfd.Independent(
                      tfd.Independent(
                          tfd.Normal(loc=0., scale=[[[[[[[[1.]]]]]]]]),
                          reinterpreted_batch_ndims=2),
                      reinterpreted_batch_ndims=0),
                  reinterpreted_batch_ndims=1),
              reinterpreted_batch_ndims=1),
          expected_batch_shape=[1, 1, 1, 1]))
  def test_batch_shape_inference_is_correct(
      self, value_fn, expected_batch_shape):
    """Checks the static and dynamic batch shapes of the built object.

    Args:
      value_fn: Zero-argument callable returning the object under test.
      expected_batch_shape: The batch shape the object is expected to report.
    """
    # Construction is deferred to here so it happens in the right graph.
    dist = value_fn()

    dynamic_batch_shape = dist.batch_shape_tensor()
    self.assertAllEqual(expected_batch_shape, dynamic_batch_shape)

    static_batch_shape = dist.batch_shape
    self.assertIsInstance(static_batch_shape, tf.TensorShape)
    is_compatible = static_batch_shape.is_compatible_with(
        expected_batch_shape)
    self.assertTrue(is_compatible)

  def assert_all_parameters_have_full_batch_shape(
      self, dist, expected_batch_shape):
    """Asserts `dist` and each of its batch-shape parts match the expectation.

    Args:
      dist: Object providing `batch_shape_tensor()` and accepted by
        `batch_shape_lib.batch_shape_parts`.
      expected_batch_shape: Shape that the overall batch shape and every
        per-parameter batch shape must equal.
    """
    self.assertAllEqual(expected_batch_shape, dist.batch_shape_tensor())
    shapes_by_param = batch_shape_lib.batch_shape_parts(dist)
    for part_shape in shapes_by_param.values():
      self.assertAllEqual(expected_batch_shape, part_shape)

  @parameterized.named_parameters(
      dict(
          testcase_name='_trivial',
          dist_fn=lambda: tfd.Normal(loc=0., scale=1.)),
      dict(
          testcase_name='_simple_tensor_broadcasting',
          dist_fn=lambda: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
              loc=[0., 0.],
              scale_diag=[[1., 1.], [1., 1.]])),
      dict(
          testcase_name='_rank_deficient_tensor_broadcasting',
          dist_fn=lambda: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
              loc=0.,
              scale_diag=[[1., 1.], [1., 1.]])),
      dict(
          testcase_name='_deeply_nested',
          dist_fn=lambda: tfd.Independent(  # pylint: disable=g-long-lambda
              tfd.Independent(
                  tfd.Independent(
                      tfd.Independent(
                          tfd.Normal(loc=0., scale=[[[[[[[[1.]]]]]]]]),
                          reinterpreted_batch_ndims=2),
                      reinterpreted_batch_ndims=0),
                  reinterpreted_batch_ndims=1),
              reinterpreted_batch_ndims=1)),
      dict(
          testcase_name='_transformed_dist_simple',
          dist_fn=lambda: tfd.TransformedDistribution(  # pylint: disable=g-long-lambda
              tfd.Normal(loc=[[1., 2., 3.], [3., 4., 5.]], scale=[1.]),
              tfb.Scale(scale=[2., 3., 4.]))),
      dict(
          testcase_name='_transformed_dist_with_chain',
          dist_fn=lambda: tfd.TransformedDistribution(  # pylint: disable=g-long-lambda
              tfd.Normal(loc=[[1., 2., 3.], [3., 4., 5.]], scale=[1.]),
              tfb.Shift(-4.)(tfb.Scale(scale=[2., 3., 4.])))),
      dict(
          testcase_name='_transformed_dist_multipart_nested',
          dist_fn=lambda: tfd.TransformedDistribution(  # pylint: disable=g-long-lambda
              tfd.TransformedDistribution(
                  tfd.TransformedDistribution(
                      tfd.MultivariateNormalDiag(
                          tf.zeros([4, 6]), tf.ones([6])),
                      tfb.Split([3, 3])),
                  tfb.JointMap([tfb.Identity(), tfb.Reshape([3, 1])])),
              tfb.JointMap([tfb.Scale(scale=[2., 3., 4.]), tfb.Shift(1.)]))),
  )
  def test_batch_broadcasting(self, dist_fn):
    """Checks parameter broadcasting to the own and to a larger batch shape.

    Args:
      dist_fn: Zero-argument callable returning the distribution under test.
    """
    dist = dist_fn()

    # First broadcast the parameters to the distribution's own batch shape.
    same_shape_dist = dist._broadcast_parameters_with_batch_shape(
        dist.batch_shape)
    self.assert_all_parameters_have_full_batch_shape(
        same_shape_dist,
        expected_batch_shape=same_shape_dist.batch_shape_tensor())

    # Then broadcast to a batch shape with `[7, 4]` prepended.
    expanded_batch_shape = ps.concat([[7, 4], dist.batch_shape], axis=0)
    expanded_params = batch_shape_lib.broadcast_parameters_with_batch_shape(
        dist, expanded_batch_shape)
    expanded_dist = dist.copy(**expanded_params)
    self.assert_all_parameters_have_full_batch_shape(
        expanded_dist, expected_batch_shape=expanded_batch_shape)
class BijectorBatchShapesTest(test_util.TestCase):
  """Tests of bijector batch shapes and bijector parameter broadcasting."""

  @parameterized.named_parameters(
      ('exp', tfb.Exp, None),
      ('scale', lambda: tfb.Scale(tf.ones([4, 2])), None),
      ('sigmoid',
       lambda: tfb.Sigmoid(low=tf.zeros([3]), high=tf.ones([4, 1])),
       None),
      ('scale_matvec', lambda: tfb.ScaleMatvecDiag([[0.], [3.]]), None),
      ('invert',
       lambda: tfb.Invert(tfb.ScaleMatvecDiag(tf.ones([2, 1]))),
       None),
      ('reshape', lambda: tfb.Reshape([1], event_shape_in=[1, 1]), None),
      ('chain',
       lambda: tfb.Chain([  # pylint: disable=g-long-lambda
           tfb.Power(power=[[2.], [3.]]),
           tfb.Invert(tfb.Split(2)),
       ]),
       None),
      ('jointmap_01',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [0, 1]),
      ('jointmap_11',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [1, 1]),
      ('jointmap_20',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [2, 0]),
      ('jointmap_22',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [2, 2]),
      ('restructure_with_ragged_event_ndims',
       lambda: tfb.Restructure(  # pylint: disable=g-long-lambda
           input_structure=[0, 1],
           output_structure={'a': 0, 'b': 1}),
       [0, 1]))
  def test_batch_shape_matches_output_shapes(self,
                                             bijector_fn,
                                             override_x_event_ndims=None):
    """Checks that every way of computing the batch shape agrees.

    Args:
      bijector_fn: Zero-argument callable returning the bijector under test.
      override_x_event_ndims: Optional structure of input event ndims. When
        `None`, the bijector's minimum forward/inverse event ndims are used.
    """
    bijector = bijector_fn()
    if override_x_event_ndims is not None:
      x_event_ndims = override_x_event_ndims
      y_event_ndims = bijector.forward_event_ndims(x_event_ndims)
    else:
      x_event_ndims = bijector.forward_min_event_ndims
      y_event_ndims = bijector.inverse_min_event_ndims

    def assert_parts_broadcast_to(values, event_ndims, batch_shape):
      # The batch dims of each part must broadcast to exactly `batch_shape`.
      flat_pairs = zip(tf.nest.flatten(values), tf.nest.flatten(event_ndims))
      for part, nd in flat_pairs:
        part_batch_shape = ps.shape(part)[:ps.rank(part) - nd]
        self.assertAllEqual(
            batch_shape, ps.broadcast_shape(batch_shape, part_batch_shape))

    # Static batch shape, computed from either side of the bijector.
    batch_shape_x = bijector.experimental_batch_shape(
        x_event_ndims=x_event_ndims)
    batch_shape_y = bijector.experimental_batch_shape(
        y_event_ndims=y_event_ndims)
    self.assertEqual(batch_shape_x, batch_shape_y)

    # Dynamic batch shape, computed from either side of the bijector.
    batch_shape_tensor_x = bijector.experimental_batch_shape_tensor(
        x_event_ndims=x_event_ndims)
    batch_shape_tensor_y = bijector.experimental_batch_shape_tensor(
        y_event_ndims=y_event_ndims)
    self.assertAllEqual(batch_shape_tensor_x, batch_shape_tensor_y)
    self.assertAllEqual(batch_shape_tensor_x, batch_shape_x)

    # The same again, with the event ndims given as `np.int64` values.
    x_event_ndims_64 = tf.nest.map_structure(np.int64, x_event_ndims)
    batch_shape_tensor_x64 = bijector.experimental_batch_shape_tensor(
        x_event_ndims=x_event_ndims_64)
    y_event_ndims_64 = tf.nest.map_structure(np.int64, y_event_ndims)
    batch_shape_tensor_y64 = bijector.experimental_batch_shape_tensor(
        y_event_ndims=y_event_ndims_64)
    self.assertAllEqual(batch_shape_tensor_x64, batch_shape_tensor_y64)
    self.assertAllEqual(batch_shape_tensor_x64, batch_shape_x)

    # Forward direction: the outputs carry the expected batch shape ...
    xs = tf.nest.map_structure(lambda nd: tf.ones([1] * nd), x_event_ndims)
    ys = bijector.forward(xs)
    assert_parts_broadcast_to(ys, y_event_ndims, batch_shape_y)

    # ... and so does the forward log-det-Jacobian.
    fldj = bijector.forward_log_det_jacobian(xs, event_ndims=x_event_ndims)
    self.assertAllEqual(batch_shape_y, ps.shape(fldj))

    # Inverse direction: same checks on the inputs and the inverse
    # log-det-Jacobian.
    xs = bijector.inverse(tf.nest.map_structure(tf.identity, ys))
    assert_parts_broadcast_to(xs, x_event_ndims, batch_shape_x)
    ildj = bijector.inverse_log_det_jacobian(ys, event_ndims=y_event_ndims)
    self.assertAllEqual(batch_shape_x, ps.shape(ildj))

  @parameterized.named_parameters(
      ('scale', lambda: tfb.Scale([3.14159])),
      ('chain', lambda: tfb.Exp()(tfb.Scale([3.14159]))))
  def test_ndims_specification(self, bijector_fn):
    """Checks default event ndims and rejection of over-specified ndims.

    Args:
      bijector_fn: Zero-argument callable returning the bijector under test.
    """
    bijector = bijector_fn()
    batch_shape_fns = (bijector.experimental_batch_shape,
                       bijector.experimental_batch_shape_tensor)

    # With no `event_ndims` given, both calls are expected to return `[1]`.
    for batch_shape_fn in batch_shape_fns:
      self.assertAllEqual(batch_shape_fn(), [1])

    # Passing both `x_event_ndims` and `y_event_ndims` must raise.
    for batch_shape_fn in batch_shape_fns:
      with self.assertRaisesRegex(
          ValueError, 'Only one of `x_event_ndims` and `y_event_ndims`'):
        batch_shape_fn(x_event_ndims=0, y_event_ndims=0)

  @parameterized.named_parameters(
      ('scale', lambda: tfb.Scale(tf.ones([4, 2])), None),
      ('sigmoid',
       lambda: tfb.Sigmoid(low=tf.zeros([3]), high=tf.ones([4, 1])),
       None),
      ('invert',
       lambda: tfb.Invert(tfb.ScaleMatvecDiag(tf.ones([2, 1]))),
       None),
      ('chain',
       lambda: tfb.Chain([  # pylint: disable=g-long-lambda
           tfb.Power(power=[[2.], [3.]]),
           tfb.Invert(tfb.Split(2)),
       ]),
       None),
      ('jointmap_01',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [0, 1]),
      ('jointmap_11',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [1, 1]),
      ('jointmap_20',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [2, 0]),
      ('jointmap_22',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [2, 2]),
      ('nested_jointmap',
       lambda: tfb.JointMap([  # pylint: disable=g-long-lambda
           tfb.JointMap({'a': tfb.Scale([1.]), 'b': tfb.Exp()}),
           tfb.Scale([1, 4])(tfb.Invert(tfb.Split(2))),
       ]),
       [{'a': 0, 'b': 0}, [2, 2]]))
  def test_with_broadcast_batch_shape(self, bijector_fn, x_event_ndims=None):
    """Checks broadcasting of a bijector's parameters to a new batch shape.

    Args:
      bijector_fn: Zero-argument callable returning the bijector under test.
      x_event_ndims: Optional structure of input event ndims. When `None`,
        the bijector's `forward_min_event_ndims` is used.
    """
    bijector = bijector_fn()
    if x_event_ndims is None:
      x_event_ndims = bijector.forward_min_event_ndims

    original_batch_shape = bijector.experimental_batch_shape(
        x_event_ndims=x_event_ndims)
    original_param_batch_shapes = batch_shape_lib.batch_shape_parts(
        bijector, bijector_x_event_ndims=x_event_ndims)

    new_batch_shape = [4, 2, 1, 1, 1]
    broadcast_bijector = bijector._broadcast_parameters_with_batch_shape(
        new_batch_shape, x_event_ndims)
    actual_batch_shape = broadcast_bijector.experimental_batch_shape_tensor(
        x_event_ndims=x_event_ndims)
    self.assertAllEqual(
        actual_batch_shape,
        ps.broadcast_shape(original_batch_shape, new_batch_shape))

    # Each parameter should also end up with the expected batch shape.
    actual_param_batch_shapes = batch_shape_lib.batch_shape_parts(
        broadcast_bijector, bijector_x_event_ndims=x_event_ndims)

    def expected_param_batch_shape(param, shape):
      # An `Invert` whose wrapped bijector has no params can't be broadcast,
      # so its batch shape is expected to stay as it was.
      is_paramless_invert = (
          isinstance(param, tfb.Invert)
          and not param.bijector._params_event_ndims())
      if is_paramless_invert:
        return shape
      return ps.broadcast_shape(shape, new_batch_shape)

    params_by_name = {
        name: getattr(bijector, name) for name in original_param_batch_shapes
    }
    expected_param_batch_shapes = tf.nest.map_structure(
        expected_param_batch_shape,
        params_by_name,
        original_param_batch_shapes)
    self.assertAllEqualNested(actual_param_batch_shapes,
                              expected_param_batch_shapes)
def test_inverse_has_event_ndims(self):
  """Checks `inverse_event_ndims` on an inverted `JointMap` of a `Reshape`."""
  inverted_joint_map = tfb.Invert(tfb.JointMap([tfb.Reshape([])]))
  # First call discards its result (expected value: [9]); the second call
  # below is the one whose result is asserted.
  inverted_joint_map.inverse_event_ndims([10])
  actual_ndims = inverted_joint_map.inverse_event_ndims([10])
  self.assertEqual(actual_ndims, [9])
def test_transform_joint_to_joint(self, split_sizes):
  """Transforms a joint distribution with `JointMap` and `Restructure`.

  Builds a three-component joint distribution, applies one bijector per
  component via `tfb.JointMap`, and checks the transformed distribution's
  string representation, event/batch shapes, `log_prob`/`prob` consistency
  and sample shapes. Then applies a `tfb.Restructure` bijector and checks
  that shapes and samples follow the bijector's output structure.

  Args:
    split_sizes: Structure (list or dict) of three event sizes, one per
      component. A dict yields a `JointDistributionNamed` base
      distribution; anything else yields a `JointDistributionSequential`.
  """
  dist_batch_shape = tf.nest.pack_sequence_as(split_sizes, [
      tensorshape_util.constant_value_as_shape(s)
      for s in [[2, 3], [2, 1], [1, 3]]
  ])
  bijector_batch_shape = [1, 3]

  # Build a joint distribution with parts of the specified sizes.
  seed = test_util.test_seed_stream()
  component_dists = tf.nest.map_structure(
      lambda size, batch_shape: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
          loc=tf.random.normal(batch_shape + [size], seed=seed()),
          scale_diag=tf.random.uniform(
              minval=1., maxval=2.,
              shape=batch_shape + [size], seed=seed())),
      split_sizes, dist_batch_shape)
  if isinstance(split_sizes, dict):
    base_dist = tfd.JointDistributionNamed(component_dists)
  else:
    base_dist = tfd.JointDistributionSequential(component_dists)

  # Transform the distribution by applying a separate bijector to each part.
  bijectors = [
      tfb.Exp(),
      tfb.Scale(
          tf.random.uniform(
              minval=1., maxval=2.,
              shape=bijector_batch_shape, seed=seed())),
      tfb.Reshape([2, 1]),
  ]
  bijector = tfb.JointMap(
      tf.nest.pack_sequence_as(split_sizes, bijectors), validate_args=True)

  # Transform a joint distribution that has different batch shape components.
  transformed_dist = tfd.TransformedDistribution(base_dist, bijector)

  self.assertRegex(
      str(transformed_dist),
      '{}.*batch_shape.*event_shape.*dtype'.format(transformed_dist.name))

  self.assertAllEqualNested(
      transformed_dist.event_shape,
      bijector.forward_event_shape(base_dist.event_shape))
  self.assertAllEqualNested(
      *self.evaluate((transformed_dist.event_shape_tensor(),
                      bijector.forward_event_shape_tensor(
                          base_dist.event_shape_tensor()))))

  # Test that the batch shape components of the input are the same as those
  # of the output.
  self.assertAllEqualNested(transformed_dist.batch_shape, dist_batch_shape)
  self.assertAllEqualNested(
      self.evaluate(transformed_dist.batch_shape_tensor()), dist_batch_shape)
  self.assertAllEqualNested(dist_batch_shape, base_dist.batch_shape)

  # Check transformed `log_prob` against the base distribution.
  sample_shape = [3]
  sample = base_dist.sample(sample_shape, seed=seed())
  x = tf.nest.map_structure(tf.zeros_like, sample)
  y = bijector.forward(x)
  base_logprob = base_dist.log_prob(x)
  event_ndims = tf.nest.map_structure(lambda s: s.ndims,
                                      transformed_dist.event_shape)
  ildj = bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)

  (transformed_logprob,
   base_logprob_plus_ildj,
   log_transformed_prob) = self.evaluate([
       transformed_dist.log_prob(y),
       base_logprob + ildj,
       tf.math.log(transformed_dist.prob(y)),
   ])
  self.assertAllClose(base_logprob_plus_ildj, transformed_logprob)
  self.assertAllClose(transformed_logprob, log_transformed_prob)

  # Test that `.sample()` works and returns a result of the expected
  # structure and shape.
  y_sampled = transformed_dist.sample(sample_shape, seed=seed())
  self.assertAllEqual(
      tf.nest.map_structure(lambda part: part.shape, y),
      tf.nest.map_structure(lambda part: part.shape, y_sampled))

  # Test that a `Restructure` bijector applied to a `JointDistribution` works
  # as expected.
  num_components = len(split_sizes)
  input_keys = (split_sizes.keys() if isinstance(split_sizes, dict)
                else range(num_components))
  output_keys = [str(i) for i in range(num_components)]
  output_structure = dict(zip(output_keys, input_keys))
  restructure = tfb.Restructure(output_structure)
  restructured_dist = tfd.TransformedDistribution(
      base_dist, bijector=restructure, validate_args=True)

  # Check that attributes of the restructured distribution have the same
  # nested structure as the `output_structure` of the bijector. Pass a no-op
  # as the `assert_fn` since the contents of the structures are not
  # required to be the same.
  noop_assert_fn = lambda *_: None
  self.assertAllAssertsNested(
      noop_assert_fn, restructured_dist.event_shape, output_structure)
  self.assertAllAssertsNested(
      noop_assert_fn, restructured_dist.batch_shape, output_structure)
  self.assertAllAssertsNested(
      noop_assert_fn,
      self.evaluate(restructured_dist.event_shape_tensor()),
      output_structure)
  self.assertAllAssertsNested(
      noop_assert_fn,
      self.evaluate(restructured_dist.batch_shape_tensor()),
      output_structure)
  # The sample must be compared against `output_structure` as well; with a
  # no-op `assert_fn` and a single structure nothing would be checked.
  self.assertAllAssertsNested(
      noop_assert_fn,
      self.evaluate(
          restructured_dist.sample(seed=test_util.test_seed())),
      output_structure)