Example #1
 def sharded_model():
   x = yield root(tfd.LogNormal(0., 1., name='x'))
   yield sharded.Sharded(
       tfd.Uniform(0., x), shard_axis_name=self.axis_name, name='y')
   yield sharded.Sharded(
       tfb.Scale(x)(tfd.Normal(0., 1.)),
       shard_axis_name=self.axis_name,
       name='z')
Example #2
 def test_duplicate_axes_in_jax(self):
     if not JAX_MODE:
         self.skipTest('This error is JAX-only.')
     dist = sharded.Sharded(tfd.Normal(0., 1.), shard_axis_name='i')
     with self.assertRaisesRegex(ValueError, 'Found duplicate axis name'):
         sharded.Sharded(dist, shard_axis_name='i')
     with self.assertRaisesRegex(ValueError, 'Found duplicate axis name'):
         sharded.Sharded(tfd.Normal(0., 1.), shard_axis_name=['i', 'i'])
Example #3
def make_jd_sequential(axis_name):
  return jd.JointDistributionSequential([
      tfd.Normal(0., 1.),
      lambda w: sharded.Sharded(  # pylint: disable=g-long-lambda
          tfd.Sample(tfd.Normal(w, 1.), 1),
          shard_axis_name=axis_name),
      lambda x: sharded.Sharded(  # pylint: disable=g-long-lambda
          tfd.Independent(tfd.Normal(x, 1.), 1),
          shard_axis_name=axis_name),
  ])
Example #4
def make_jd_named(axis_name):
  return jd.JointDistributionNamed(  # pylint: disable=g-long-lambda
      dict(
          w=tfd.Normal(0., 1.),
          x=lambda w: sharded.Sharded(  # pylint: disable=g-long-lambda
              tfd.Sample(tfd.Normal(w, 1.), 1),
              shard_axis_name=axis_name),
          data=lambda x: sharded.Sharded(  # pylint: disable=g-long-lambda
              tfd.Independent(tfd.Normal(x, 1.), 1),
              shard_axis_name=axis_name),
      ))
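The two builders above describe the same model in sequential and named form. A minimal usage sketch, assuming the JAX backend and that the builders are in scope; the axis name, seed handling, and pmap wrapper below are illustrative and not part of the original code:

import jax
from jax import random

def run(seed):
  model = make_jd_sequential(axis_name='data')
  sample = model.sample(seed=seed)
  # With the same seed replicated to every device, the unsharded `w` is
  # identical across devices, while the Sharded parts fold the device's
  # position along 'data' into the seed and therefore differ per device.
  # Their log-prob terms are psum-ed over 'data', so each device reports
  # the same total log density.
  return sample, model.log_prob(sample)

samples, log_probs = jax.pmap(run, axis_name='data', in_axes=None)(
    random.PRNGKey(0))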
Example #5
 def manual_sharded_model():
   # This one has manual pbroadcasts; the goal is to get sharded_model above
   # to do this automatically.
   x = yield root(tfd.LogNormal(0., 1., name='x'))
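   # `pbroadcast` is (roughly) an identity in the forward pass whose gradient
   # is psum-ed over `axis_name`, so the replicated `x` receives gradient
   # contributions from every shard of `y` and `z` below.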
   x = distribute_lib.pbroadcast(x, axis_name=self.axis_name)
   yield sharded.Sharded(
       tfd.Uniform(0., x), shard_axis_name=self.axis_name, name='y')
   yield sharded.Sharded(
       tfb.Scale(x)(tfd.Normal(0., 1.)),
       shard_axis_name=self.axis_name,
       name='z')
Example #6
    def test_nested_sharded_maintains_correct_axis_ordering(self):
        if not JAX_MODE:
            self.skipTest('Multiple axes only supported in JAX backend.')

        other_axis_name = self.axis_name + '_other'

        dist = sharded.Sharded(
            sharded.Sharded(tfd.Normal(0., 1.), self.axis_name),
            other_axis_name)

        self.assertListEqual(dist.experimental_shard_axis_names,
                             [self.axis_name, other_axis_name])
Example #7
 def test_multiple_axes_in_tensorflow_error(self):
     if JAX_MODE:
         self.skipTest('This error is TensorFlow-only.')
     dist = sharded.Sharded(tfd.Normal(0., 1.), shard_axis_name='i')
     with self.assertRaisesRegex(
             ValueError,
             'TensorFlow backend does not support multiple shard axes'):
         sharded.Sharded(dist, shard_axis_name='j')
     with self.assertRaisesRegex(
             ValueError,
             'TensorFlow backend does not support multiple shard axes'):
         sharded.Sharded(tfd.Normal(0., 1.), shard_axis_name=['i', 'j'])
Example #8
 def test_none_axis_in_jax_error(self):
     if not JAX_MODE:
         self.skipTest('This error is JAX-only.')
     with self.assertRaisesRegex(
             ValueError,
             'Cannot provide a `None` axis name in JAX backend.'):
         sharded.Sharded(tfd.Normal(0., 1.))
Example #9
 def model_fn():
     root = tfp.experimental.distribute.JointDistributionCoroutine.Root
     _ = yield root(
         sharded.Sharded(tfd.Independent(
             tfd.Normal(0, tf.ones([7])),
             reinterpreted_batch_ndims=1,
             experimental_use_kahan_sum=True),
                         shard_axis_name=self.axis_name))
Example #10
def update_momentum_distribution(momentum_distribution,
                                 running_variance_parts):
    """Updates a momentum distribution with new running variance.

  Args:
    momentum_distribution: Distribution arranged like a result of
      `make_momentum_distribution`.
    running_variance_parts: List of `Tensor` outputs of
      `tfp.experimental.stats.RunningVariance.variance()`.

  Returns:
    `tfd.Distribution` where `.sample` has the same structure as `state_parts`,
    and `.log_prob` of the sample will have the rank of `batch_ndims`
  """
    model = []
    if len(running_variance_parts) != len(momentum_distribution.model):
        raise ValueError(
            'State size mismatch: '
            f'{len(running_variance_parts)} vs {len(momentum_distribution.model)}'
        )
    for var, bb in zip(running_variance_parts, momentum_distribution.model):
        # TODO(b/182603117): Check public BatchBroadcast when Sharded is
        # guaranteed to be CompositeTensor.
        if not isinstance(bb, batch_broadcast._BatchBroadcast):  # pylint: disable=protected-access
            raise ValueError(f'Part dist is not a BatchBroadcast: {bb}')
        td = bb.distribution
        is_sharded, shard_axes = False, None
        if isinstance(td, sharded.Sharded):
            is_sharded, shard_axes = True, td.experimental_shard_axis_names
            td = td.distribution
        if not isinstance(td,
                          transformed_distribution._TransformedDistribution):  # pylint:disable=protected-access
            raise ValueError(
                f'Inner dist is not a TransformedDistribution: {td}')
        mvnpfl = td.distribution
        if not isinstance(
                mvnpfl,
                mvn_pfl.MultivariateNormalPrecisionFactorLinearOperator):
            raise ValueError(
                'Inner dist is not a '
                f'MultivariateNormalPrecisionFactorLinearOperator: {mvnpfl}')
        var_flat = td.bijector.inverse(var)
        var_flat_bc = tf.broadcast_to(var_flat,
                                      ps.shape(mvnpfl.precision.diag_part()))
        mvnpfl = mvnpfl.copy(
            precision_factor=tf.linalg.LinearOperatorDiag(
                tf.math.sqrt(var_flat_bc)),
            precision=tf.linalg.LinearOperatorDiag(var_flat_bc))
        td = td.copy(distribution=mvnpfl)
        if is_sharded:
            td = sharded.Sharded(td, shard_axis_name=shard_axes)
        model_dist = bb.copy(distribution=td)
        model.append(model_dist)
    return momentum_distribution.copy(model=model)
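A short usage sketch: per the docstring, the variance parts are the outputs of `tfp.experimental.stats.RunningVariance.variance()`, one per state part; the surrounding adaptation loop and variable names here are illustrative:

variance_parts = [rv.variance() for rv in running_variance_states]
momentum_distribution = update_momentum_distribution(
    momentum_distribution, variance_parts)
# `momentum_distribution` must be arranged like a result of
# `make_momentum_distribution` (see Example #18 below).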
Example #11
 def run(key, _):
     return sharded.Sharded(
         tfd.Sample(tfd.Normal(0., 1.), 1),
         shard_axis_name=[self.axis_name,
                          other_axis_name]).sample(seed=key)
Example #12
 def model():
   yield Root(tfd.Normal(1., 1.))
   yield Root(tfd.Normal(1., 1.))
   yield sharded.Sharded(tfd.Normal(1., 1.), self.axis_name)
Example #13
 def test_none_axis_in_tensorflow(self):
     if JAX_MODE:
         self.skipTest('This feature is TensorFlow-only.')
     dist = sharded.Sharded(tfd.Normal(0., 1.))
     self.assertEqual([True], dist.experimental_shard_axis_names)
Example #14
 def sharded_model():
     w = yield root(tfd.Normal(prior_mean, 1.))
     yield root(
         sharded.Sharded(increment_log_prob.IncrementLogProb(
             custom_ll(w, x)),
                         shard_axis_name=self.axis_name))
Example #15
 def surrogate():
   x = yield Root(tfd.Normal(x_loc, 1.))
   y = yield tfd.Normal(x, 1.)
   yield sharded.Sharded(tfd.Normal(x + y, 1.), self.axis_name)
Example #16
 def model_coroutine():
   w = yield tfd.JointDistributionCoroutine.Root(tfd.Normal(0., 1.))
   x = yield sharded.Sharded(
       tfd.Sample(tfd.Normal(w, 1.), 1), shard_axis_name=axis_name)
   yield sharded.Sharded(
       tfd.Independent(tfd.Normal(x, 1.), 1), shard_axis_name=axis_name)
Example #17
 def model():
   yield Root(tfd.Normal(0., 1., name='x'))
   yield tfd.Normal(0., 1., name='y')
   yield sharded.Sharded(tfd.Normal(1., 1.), self.axis_name, name='z')
Example #18
def make_momentum_distribution(state_parts,
                               batch_shape,
                               running_variance_parts=None,
                               shard_axis_names=None):
    """Construct a momentum distribution from the running variance.

  This uses a running variance to construct a momentum distribution with the
  correct batch_shape and event_shape.

  Args:
    state_parts: List of `Tensor`.
    batch_shape: Batch shape.
    running_variance_parts: Optional, list of `Tensor`
       outputs of `tfp.experimental.stats.RunningVariance.variance()`. Defaults
       to ones with the same shape as state_parts.
    shard_axis_names: A structure of string names indicating how members of the
      state are sharded.

  Returns:
    `tfd.Distribution` where `.sample` has the same structure as `state_parts`,
    and `.log_prob` of the sample will have the rank of `batch_ndims`
  """
    if running_variance_parts is None:
        running_variance_parts = tf.nest.map_structure(tf.ones_like,
                                                       state_parts)
    distributions = []
    batch_ndims = ps.rank_from_shape(batch_shape)
    use_sharded_jd = True
    if shard_axis_names is None:
        use_sharded_jd = False
        shard_axis_names = [None] * len(state_parts)
    for variance_part, state_part, shard_axes in zip(running_variance_parts,
                                                     state_parts,
                                                     shard_axis_names):
        event_shape = state_part.shape[batch_ndims:]
        if not tensorshape_util.is_fully_defined(event_shape):
            event_shape = ps.shape(state_part,
                                   name='state_part_shp')[batch_ndims:]
        variance_tiled = tf.broadcast_to(
            variance_part, ps.concat([batch_shape, event_shape], axis=0))
        nevt = ps.cast(ps.reduce_prod(event_shape), tf.int32)
        variance_flattened = tf.reshape(
            variance_tiled, ps.concat([batch_shape, [nevt]], axis=0))

        distribution = _CompositeTransformedDistribution(
            bijector=reshape.Reshape(event_shape_out=event_shape,
                                     name='reshape_mvnpfl'),
            distribution=(
                _CompositeMultivariateNormalPrecisionFactorLinearOperator(
                    precision_factor=tf.linalg.LinearOperatorDiag(
                        tf.math.sqrt(variance_flattened)),
                    precision=tf.linalg.LinearOperatorDiag(variance_flattened),
                    name='momentum')))
        if shard_axes:
            distribution = sharded.Sharded(distribution,
                                           shard_axis_name=shard_axes)
        distributions.append(distribution)
    if use_sharded_jd:
        jd = _CompositeShardedJointDistributionSequential(distributions)
    else:
        jd = _CompositeJointDistributionSequential(distributions)
    return maybe_make_list_and_batch_broadcast(jd, batch_shape)
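For orientation, a minimal construction sketch assuming the TensorFlow backend and that the module-level helpers referenced above are in scope; the shapes, chain count, and seed are illustrative:

state_parts = [tf.zeros([4, 3]), tf.zeros([4, 2, 2])]   # four chains
momentum = make_momentum_distribution(state_parts, batch_shape=[4])
draw = momentum.sample(seed=42)   # list of Tensors shaped like `state_parts`
lp = momentum.log_prob(draw)      # shape [4], i.e. rank `batch_ndims`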
Example #19
 def run(key):
     return sharded.Sharded(
         tfd.Sample(tfd.Normal(0., 1.), 1),
         shard_axis_name=self.axis_name).sample(seed=key)