def make_jd_coroutine(axis_name):
    """Build a three-component joint distribution with two sharded parts.

    Args:
      axis_name: Passed as `shard_axis_name` to each `sharded.Sharded`
        wrapper in the model.

    Returns:
      A `jd.JointDistributionCoroutine` built from the inner coroutine.
    """
    def shard(dist):
        # Wrap `dist` so it is marked as sharded along `axis_name`.
        return sharded.Sharded(dist, shard_axis_name=axis_name)

    def model_coroutine():
        # Unsharded root: a standard Normal.
        w = yield tfd.JointDistributionCoroutine.Root(tfd.Normal(0., 1.))
        # Size-1 sample of Normal(w, 1), sharded along `axis_name`.
        x = yield shard(tfd.Sample(tfd.Normal(w, 1.), 1))
        # Normal(x, 1) with one reinterpreted batch dim, also sharded.
        yield shard(tfd.Independent(tfd.Normal(x, 1.), 1))

    return jd.JointDistributionCoroutine(model_coroutine)
Beispiel #2
0
def make_jd_coroutine(axis_name):
  """Build a sharded joint distribution that also records its shard axis.

  Args:
    axis_name: Forwarded as `shard_axis_name` to both sharded components and
      to the returned `jd.JointDistributionCoroutine`.

  Returns:
    A `jd.JointDistributionCoroutine` over the inner `model_coroutine`.
  """
  # The same keyword is passed to every sharded constructor below.
  shard_kwargs = dict(shard_axis_name=axis_name)

  def model_coroutine():
    w = yield tfd.JointDistributionCoroutine.Root(tfd.Normal(0., 1.))
    # `test_lib.NUM_DEVICES` draws of Normal(w, 1), sharded along `axis_name`.
    x = yield sharded.ShardedSample(
        tfd.Normal(w, 1.), test_lib.NUM_DEVICES, **shard_kwargs)
    # Second positional arg `1` presumably mirrors tfd.Independent's
    # reinterpreted_batch_ndims -- TODO confirm against ShardedIndependent.
    yield sharded.ShardedIndependent(tfd.Normal(x, 1.), 1, **shard_kwargs)

  return jd.JointDistributionCoroutine(model_coroutine, **shard_kwargs)
Beispiel #3
0
  def __init__(self, *args, name='JointDensityCoroutine', **kwargs):
    """Construct the `JointDensityCoroutine` density.

    Builds and stores a `JointDistributionCoroutine` from the same arguments;
    see the documentation for JointDistributionCoroutine.

    Args:
      *args: Positional arguments forwarded to JointDistributionCoroutine.
      name: The name for ops managed by the density.
        Default value: `JointDensityCoroutine`.
      **kwargs: Named arguments forwarded to JointDistributionCoroutine.
    """
    with tf.name_scope(name) as scope_name:
      # Forward the scope string produced by `tf.name_scope`, not the raw
      # `name`, so the wrapped distribution shares this density's scope.
      self._joint_distribution_coroutine = jdc_lib.JointDistributionCoroutine(
          *args, name=scope_name, **kwargs)
# Presumably the number of (virtual) devices the sharded tests run on -- TODO
# confirm. Also used below as the sample size given to `sharded.ShardedSample`.
NUM_DEVICES = 4


def per_replica_to_tensor(value):
    """Stack the per-replica components of every leaf of `value`.

    Each leaf is expected to expose a `.values` sequence (presumably a
    tf.distribute per-replica object -- verify against callers). Those
    components are stacked along a new leading axis 0.

    Args:
      value: A nested structure whose leaves have a `.values` attribute.

    Returns:
      A structure of the same shape whose leaves are the stacked tensors.
    """
    def _stack_replicas(per_replica):
        return tf.stack(per_replica.values, axis=0)

    return tf.nest.map_structure(_stack_replicas, value)


def model_coroutine():
    """Model generator for a `jd.JointDistributionCoroutine`.

    Yields, in order:
      1. a root `Normal(0, 1)` (bound to `w`);
      2. a `ShardedSample` of `NUM_DEVICES` draws from `Normal(w, 1)`
         (bound to `x`);
      3. a `ShardedIndependent` over `Normal(x, 1)` with `1` as its second
         argument.
    """
    root = tfd.JointDistributionCoroutine.Root
    w = yield root(tfd.Normal(0., 1.))
    # Sample size equals NUM_DEVICES -- presumably one draw per device.
    x = yield sharded.ShardedSample(tfd.Normal(w, 1.), NUM_DEVICES)
    yield sharded.ShardedIndependent(tfd.Normal(x, 1.), 1)


distributions = (
    ('coroutine', lambda: jd.JointDistributionCoroutine(model_coroutine)),
    (
        'sequential',
        lambda: jd.JointDistributionSequential([  # pylint: disable=g-long-lambda
            tfd.Normal(0., 1.),
            lambda w: sharded.ShardedSample(tfd.Normal(w, 1.), NUM_DEVICES),
            lambda x: sharded.ShardedIndependent(tfd.Normal(x, 1.), 1),
        ])),
    (
        'named',
        lambda: jd.JointDistributionNamed(  # pylint: disable=g-long-lambda
            dict(
                w=tfd.Normal(0., 1.),
                x=lambda w: sharded.ShardedSample(tfd.Normal(w, 1.),
                                                  NUM_DEVICES),
                data=lambda x: sharded.ShardedIndependent(