Exemplo n.º 1
0
def make_jd_sequential(axis_name):
  """Builds a sharded `JointDistributionSequential` with three components.

  Components, in order:
    1. `w ~ Normal(0., 1.)`.
    2. A `ShardedSample` of `Normal(w, 1.)` with `test_lib.NUM_DEVICES`
       draws, sharded over `axis_name`.
    3. A `ShardedIndependent` of `Normal(x, 1.)` (second argument `1`,
       presumably `reinterpreted_batch_ndims` as in `tfd.Independent` --
       confirm), sharded over `axis_name`.

  Args:
    axis_name: shard axis name forwarded to every sharded component and to
      the joint distribution itself.

  Returns:
    The constructed `jd.JointDistributionSequential`.
  """

  # The parameter names `w` and `x` are kept on purpose: the joint
  # distribution infers component names from the callables' argument names.
  def _sample_given_w(w):
    return sharded.ShardedSample(
        tfd.Normal(w, 1.), test_lib.NUM_DEVICES, shard_axis_name=axis_name)

  def _independent_given_x(x):
    return sharded.ShardedIndependent(
        tfd.Normal(x, 1.), 1, shard_axis_name=axis_name)

  model = [tfd.Normal(0., 1.), _sample_given_w, _independent_given_x]
  return jd.JointDistributionSequential(model, shard_axis_name=axis_name)
Exemplo n.º 2
0
 def model_coroutine():
     """Generator model for a sharded `JointDistributionCoroutine`.

     Yields, in order: a root `Normal(0., 1.)` bound to `w`; a
     `ShardedSample` of `Normal(w, 1.)` with `test_lib.NUM_DEVICES` draws,
     bound to `x`; and a `ShardedIndependent` of `Normal(x, 1.)` (second
     argument `1`). Both sharded components use `shard_axis_name=axis_name`,
     where `axis_name` is a free variable from the enclosing scope (not
     visible here).
     """
     root = tfd.JointDistributionCoroutine.Root
     w = yield root(tfd.Normal(0., 1.))
     x_dist = sharded.ShardedSample(
         tfd.Normal(w, 1.), test_lib.NUM_DEVICES, shard_axis_name=axis_name)
     x = yield x_dist
     yield sharded.ShardedIndependent(
         tfd.Normal(x, 1.), 1, shard_axis_name=axis_name)
Exemplo n.º 3
0
def make_jd_named(axis_name):
  """Builds a sharded `JointDistributionNamed` with components w, x, data.

  - `w`: `Normal(0., 1.)`.
  - `x`: a `ShardedSample` of `Normal(w, 1.)` with `test_lib.NUM_DEVICES`
    draws, sharded over `axis_name`.
  - `data`: a `ShardedIndependent` of `Normal(x, 1.)` (second argument `1`),
    sharded over `axis_name`.

  Args:
    axis_name: shard axis name forwarded to every sharded component and to
      the joint distribution itself.

  Returns:
    The constructed `jd.JointDistributionNamed`.
  """

  # Argument names must match the dict keys they depend on (`w`, `x`):
  # JointDistributionNamed wires dependencies by parameter name.
  def make_x(w):
    return sharded.ShardedSample(
        tfd.Normal(w, 1.), test_lib.NUM_DEVICES, shard_axis_name=axis_name)

  def make_data(x):
    return sharded.ShardedIndependent(
        tfd.Normal(x, 1.), 1, shard_axis_name=axis_name)

  components = {
      'w': tfd.Normal(0., 1.),
      'x': make_x,
      'data': make_data,
  }
  return jd.JointDistributionNamed(components, shard_axis_name=axis_name)
def model_coroutine():
    """Generator model for a `JointDistributionCoroutine`.

    Yields a root `Normal(0., 1.)` bound to `w`, then a `ShardedSample` of
    `Normal(w, 1.)` with `NUM_DEVICES` draws bound to `x`, then a
    `ShardedIndependent` of `Normal(x, 1.)` (second argument `1`).
    """
    root = tfd.JointDistributionCoroutine.Root
    w = yield root(tfd.Normal(0., 1.))
    x = yield sharded.ShardedSample(
        tfd.Normal(w, 1.), NUM_DEVICES)
    yield sharded.ShardedIndependent(
        tfd.Normal(x, 1.), 1)
        lambda per_replica: tf.stack(per_replica.values, axis=0), value)


def model_coroutine():
    """Generator model for a `JointDistributionCoroutine`.

    Three components, yielded in order:
      * `w`: root `Normal(0., 1.)`;
      * `x`: `ShardedSample(Normal(w, 1.), NUM_DEVICES)`;
      * final: `ShardedIndependent(Normal(x, 1.), 1)`.
    """
    w_prior = tfd.JointDistributionCoroutine.Root(tfd.Normal(0., 1.))
    w = yield w_prior
    x = yield sharded.ShardedSample(tfd.Normal(w, 1.), NUM_DEVICES)
    data_dist = sharded.ShardedIndependent(tfd.Normal(x, 1.), 1)
    yield data_dist


def _make_coroutine_jd():
    """Return a `JointDistributionCoroutine` over `model_coroutine`."""
    return jd.JointDistributionCoroutine(model_coroutine)


def _make_sequential_jd():
    """Return the list-based joint distribution of the same w -> x model."""
    # The arg names `w` / `x` feed the joint distribution's name inference.
    return jd.JointDistributionSequential([
        tfd.Normal(0., 1.),
        lambda w: sharded.ShardedSample(tfd.Normal(w, 1.), NUM_DEVICES),
        lambda x: sharded.ShardedIndependent(tfd.Normal(x, 1.), 1),
    ])


def _make_named_jd():
    """Return the dict-based joint distribution with keys w, x and data."""
    # Lambda parameter names must match the keys they depend on.
    return jd.JointDistributionNamed({
        'w': tfd.Normal(0., 1.),
        'x': lambda w: sharded.ShardedSample(tfd.Normal(w, 1.), NUM_DEVICES),
        'data': lambda x: sharded.ShardedIndependent(tfd.Normal(x, 1.), 1),
    })


# (name, zero-argument factory) pairs: one entry per joint-distribution
# flavour. Each factory builds a fresh distribution every time it is called.
distributions = (
    ('coroutine', _make_coroutine_jd),
    ('sequential', _make_sequential_jd),
    ('named', _make_named_jd),
)

Exemplo n.º 6
0
 def run(key):
     """Draw a sample from a `ShardedSample` of `NUM_DEVICES` standard normals.

     Args:
       key: seed forwarded to `sample` (presumably a PRNG key -- confirm
         against the caller).
     """
     dist = sharded.ShardedSample(tfd.Normal(0., 1.), NUM_DEVICES)
     return dist.sample(seed=key)
Exemplo n.º 7
0
 def run(key):
     """Draw a sample from a sharded `Normal(0., 1.)` sample distribution.

     The distribution takes `test_lib.NUM_DEVICES` draws sharded over
     `self.axis_name`; `self` is captured from the enclosing method (not
     visible here).

     Args:
       key: seed forwarded to `sample` (presumably a PRNG key -- confirm).
     """
     base = tfd.Normal(0., 1.)
     dist = sharded.ShardedSample(
         base, test_lib.NUM_DEVICES, shard_axis_name=self.axis_name)
     return dist.sample(seed=key)