def sharded_model():
  x = yield root(tfd.LogNormal(0., 1., name='x'))
  yield sharded.Sharded(
      tfd.Uniform(0., x), shard_axis_name=self.axis_name, name='y')
  yield sharded.Sharded(
      tfb.Scale(x)(tfd.Normal(0., 1.)),
      shard_axis_name=self.axis_name,
      name='z')
def test_duplicate_axes_in_jax(self):
  if not JAX_MODE:
    self.skipTest('This error is JAX-only.')
  dist = sharded.Sharded(tfd.Normal(0., 1.), shard_axis_name='i')
  with self.assertRaisesRegex(ValueError, 'Found duplicate axis name'):
    sharded.Sharded(dist, shard_axis_name='i')
  with self.assertRaisesRegex(ValueError, 'Found duplicate axis name'):
    sharded.Sharded(tfd.Normal(0., 1.), shard_axis_name=['i', 'i'])
def make_jd_sequential(axis_name):
  return jd.JointDistributionSequential([
      tfd.Normal(0., 1.),
      lambda w: sharded.Sharded(  # pylint: disable=g-long-lambda
          tfd.Sample(tfd.Normal(w, 1.), 1), shard_axis_name=axis_name),
      lambda x: sharded.Sharded(  # pylint: disable=g-long-lambda
          tfd.Independent(tfd.Normal(x, 1.), 1), shard_axis_name=axis_name),
  ])
def make_jd_named(axis_name):
  return jd.JointDistributionNamed(
      dict(
          w=tfd.Normal(0., 1.),
          x=lambda w: sharded.Sharded(  # pylint: disable=g-long-lambda
              tfd.Sample(tfd.Normal(w, 1.), 1), shard_axis_name=axis_name),
          data=lambda x: sharded.Sharded(  # pylint: disable=g-long-lambda
              tfd.Independent(tfd.Normal(x, 1.), 1),
              shard_axis_name=axis_name),
      ))
def manual_sharded_model():
  # This one has manual pbroadcasts; the goal is to get sharded_model above
  # to do this automatically.
  x = yield root(tfd.LogNormal(0., 1., name='x'))
  x = distribute_lib.pbroadcast(x, axis_name=self.axis_name)
  yield sharded.Sharded(
      tfd.Uniform(0., x), shard_axis_name=self.axis_name, name='y')
  yield sharded.Sharded(
      tfb.Scale(x)(tfd.Normal(0., 1.)),
      shard_axis_name=self.axis_name,
      name='z')
def test_nested_sharded_maintains_correct_axis_ordering(self):
  if not JAX_MODE:
    self.skipTest('Multiple axes only supported in JAX backend.')
  other_axis_name = self.axis_name + '_other'
  dist = sharded.Sharded(
      sharded.Sharded(tfd.Normal(0., 1.), self.axis_name), other_axis_name)
  self.assertListEqual(dist.experimental_shard_axis_names,
                       [self.axis_name, other_axis_name])
def test_multiple_axes_in_tensorflow_error(self):
  if JAX_MODE:
    self.skipTest('This error is TensorFlow-only.')
  dist = sharded.Sharded(tfd.Normal(0., 1.), shard_axis_name='i')
  with self.assertRaisesRegex(
      ValueError, 'TensorFlow backend does not support multiple shard axes'):
    sharded.Sharded(dist, shard_axis_name='j')
  with self.assertRaisesRegex(
      ValueError, 'TensorFlow backend does not support multiple shard axes'):
    sharded.Sharded(tfd.Normal(0., 1.), shard_axis_name=['i', 'j'])
def test_none_axis_in_jax_error(self):
  if not JAX_MODE:
    self.skipTest('This error is JAX-only.')
  with self.assertRaisesRegex(
      ValueError, 'Cannot provide a `None` axis name in JAX backend.'):
    sharded.Sharded(tfd.Normal(0., 1.))
def model_fn():
  root = tfp.experimental.distribute.JointDistributionCoroutine.Root
  _ = yield root(
      sharded.Sharded(
          tfd.Independent(
              tfd.Normal(0, tf.ones([7])),
              reinterpreted_batch_ndims=1,
              experimental_use_kahan_sum=True),
          shard_axis_name=self.axis_name))
def update_momentum_distribution(momentum_distribution,
                                 running_variance_parts):
  """Updates a momentum distribution with a new running variance.

  Args:
    momentum_distribution: Distribution arranged like a result of
      `make_momentum_distribution`.
    running_variance_parts: List of `Tensor` outputs of
      `tfp.experimental.stats.RunningVariance.variance()`.

  Returns:
    `tfd.Distribution` where `.sample` has the same structure as
      `running_variance_parts`, and `.log_prob` of the sample will have rank
      equal to `batch_ndims`.
  """
  model = []
  if len(running_variance_parts) != len(momentum_distribution.model):
    raise ValueError(
        'State size mismatch: '
        f'{len(running_variance_parts)} vs {len(momentum_distribution.model)}')
  for var, bb in zip(running_variance_parts, momentum_distribution.model):
    # TODO(b/182603117): Check public BatchBroadcast when Sharded is
    # guaranteed to be CompositeTensor.
    if not isinstance(bb, batch_broadcast._BatchBroadcast):  # pylint: disable=protected-access
      raise ValueError(f'Part dist is not a BatchBroadcast: {bb}')
    td = bb.distribution
    # Unwrap a `Sharded` part, remembering its axes so it can be re-wrapped.
    is_sharded, shard_axes = False, None
    if isinstance(td, sharded.Sharded):
      is_sharded, shard_axes = True, td.experimental_shard_axis_names
      td = td.distribution
    if not isinstance(td, transformed_distribution._TransformedDistribution):  # pylint: disable=protected-access
      raise ValueError(f'Inner dist is not a TransformedDistribution: {td}')
    mvnpfl = td.distribution
    if not isinstance(
        mvnpfl, mvn_pfl.MultivariateNormalPrecisionFactorLinearOperator):
      raise ValueError(
          'Inner dist is not a '
          f'MultivariateNormalPrecisionFactorLinearOperator: {mvnpfl}')
    # Flatten the variance through the reshape bijector, then use it as the
    # momentum's diagonal precision, with its square root as the factor.
    var_flat = td.bijector.inverse(var)
    var_flat_bc = tf.broadcast_to(var_flat,
                                  ps.shape(mvnpfl.precision.diag_part()))
    mvnpfl = mvnpfl.copy(
        precision_factor=tf.linalg.LinearOperatorDiag(
            tf.math.sqrt(var_flat_bc)),
        precision=tf.linalg.LinearOperatorDiag(var_flat_bc))
    td = td.copy(distribution=mvnpfl)
    if is_sharded:
      td = sharded.Sharded(td, shard_axis_name=shard_axes)
    model_dist = bb.copy(distribution=td)
    model.append(model_dist)
  return momentum_distribution.copy(model=model)
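# A minimal usage sketch for `update_momentum_distribution`, assuming
# `make_momentum_distribution` below; the shapes and the constant standing in
# for a `RunningVariance.variance()` output are illustrative, not from the
# original code.
def _example_update_momentum_distribution():
  state_parts = [tf.zeros([5])]
  momentum_dist = make_momentum_distribution(state_parts, batch_shape=[])
  # Stand-in for a warmup variance estimate from
  # `tfp.experimental.stats.RunningVariance.variance()`.
  running_variance_parts = [tf.fill([5], 2.)]
  # Returns a copy of `momentum_dist` whose diagonal precision is rebuilt
  # from the new variance estimate.
  return update_momentum_distribution(momentum_dist, running_variance_parts)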
def run(key, _):
  return sharded.Sharded(
      tfd.Sample(tfd.Normal(0., 1.), 1),
      shard_axis_name=[self.axis_name, other_axis_name]).sample(seed=key)
def model():
  yield Root(tfd.Normal(1., 1.))
  yield Root(tfd.Normal(1., 1.))
  yield sharded.Sharded(tfd.Normal(1., 1.), self.axis_name)
def test_none_axis_in_tensorflow(self):
  if JAX_MODE:
    self.skipTest('This feature is TensorFlow-only.')
  dist = sharded.Sharded(tfd.Normal(0., 1.))
  self.assertEqual([True], dist.experimental_shard_axis_names)
def sharded_model():
  w = yield root(tfd.Normal(prior_mean, 1.))
  yield root(
      sharded.Sharded(
          increment_log_prob.IncrementLogProb(custom_ll(w, x)),
          shard_axis_name=self.axis_name))
def surrogate():
  x = yield Root(tfd.Normal(x_loc, 1.))
  y = yield tfd.Normal(x, 1.)
  yield sharded.Sharded(tfd.Normal(x + y, 1.), self.axis_name)
def model_coroutine():
  w = yield tfd.JointDistributionCoroutine.Root(tfd.Normal(0., 1.))
  x = yield sharded.Sharded(
      tfd.Sample(tfd.Normal(w, 1.), 1), shard_axis_name=axis_name)
  yield sharded.Sharded(
      tfd.Independent(tfd.Normal(x, 1.), 1), shard_axis_name=axis_name)
def model():
  yield Root(tfd.Normal(0., 1., name='x'))
  yield tfd.Normal(0., 1., name='y')
  yield sharded.Sharded(tfd.Normal(1., 1.), self.axis_name, name='z')
def make_momentum_distribution(state_parts,
                               batch_shape,
                               running_variance_parts=None,
                               shard_axis_names=None):
  """Constructs a momentum distribution from the running variance.

  This uses a running variance to construct a momentum distribution with the
  correct batch_shape and event_shape.

  Args:
    state_parts: List of `Tensor`.
    batch_shape: Batch shape.
    running_variance_parts: Optional, list of `Tensor` outputs of
      `tfp.experimental.stats.RunningVariance.variance()`. Defaults to ones
      with the same shape as `state_parts`.
    shard_axis_names: A structure of string names indicating how members of
      the state are sharded.

  Returns:
    `tfd.Distribution` where `.sample` has the same structure as
      `state_parts`, and `.log_prob` of the sample will have rank equal to
      `batch_ndims`.
  """
  if running_variance_parts is None:
    running_variance_parts = tf.nest.map_structure(tf.ones_like, state_parts)
  distributions = []
  batch_ndims = ps.rank_from_shape(batch_shape)
  use_sharded_jd = True
  if shard_axis_names is None:
    use_sharded_jd = False
    shard_axis_names = [None] * len(state_parts)
  for variance_part, state_part, shard_axes in zip(running_variance_parts,
                                                   state_parts,
                                                   shard_axis_names):
    event_shape = state_part.shape[batch_ndims:]
    if not tensorshape_util.is_fully_defined(event_shape):
      event_shape = ps.shape(state_part, name='state_part_shp')[batch_ndims:]
    # Broadcast the variance up to the full batch + event shape, then flatten
    # the event dimensions so a diagonal operator can represent it.
    variance_tiled = tf.broadcast_to(
        variance_part, ps.concat([batch_shape, event_shape], axis=0))
    nevt = ps.cast(ps.reduce_prod(event_shape), tf.int32)
    variance_flattened = tf.reshape(
        variance_tiled, ps.concat([batch_shape, [nevt]], axis=0))

    # The variance estimate becomes the momentum's diagonal precision, and a
    # `Reshape` bijector restores the original event shape of the part.
    distribution = _CompositeTransformedDistribution(
        bijector=reshape.Reshape(
            event_shape_out=event_shape, name='reshape_mvnpfl'),
        distribution=(
            _CompositeMultivariateNormalPrecisionFactorLinearOperator(
                precision_factor=tf.linalg.LinearOperatorDiag(
                    tf.math.sqrt(variance_flattened)),
                precision=tf.linalg.LinearOperatorDiag(variance_flattened),
                name='momentum')))
    if shard_axes:
      distribution = sharded.Sharded(distribution, shard_axis_name=shard_axes)
    distributions.append(distribution)
  if use_sharded_jd:
    jd = _CompositeShardedJointDistributionSequential(distributions)
  else:
    jd = _CompositeJointDistributionSequential(distributions)
  return maybe_make_list_and_batch_broadcast(jd, batch_shape)
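# A minimal usage sketch for `make_momentum_distribution`; the shapes are
# illustrative, not from the original code.
def _example_make_momentum_distribution():
  state_parts = [tf.zeros([3]), tf.zeros([2, 4])]
  # With `running_variance_parts` left as None, the variance defaults to
  # ones, giving a standard-normal momentum distribution over both parts.
  momentum_dist = make_momentum_distribution(state_parts, batch_shape=[])
  momentum = momentum_dist.sample(seed=[0, 0])
  # `momentum` has the same structure as `state_parts`; `log_prob` reduces
  # over the event dimensions, leaving the (here, scalar) batch shape.
  return momentum_dist.log_prob(momentum)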
def run(key):
  return sharded.Sharded(
      tfd.Sample(tfd.Normal(0., 1.), 1),
      shard_axis_name=self.axis_name).sample(seed=key)