def make_jd_sequential(axis_name):
  """Builds a three-node sharded `JointDistributionSequential` test model.

  Nodes, in order:
    * `tfd.Normal(0., 1.)` -- the root `w`.
    * `ShardedSample(Normal(w, 1.), test_lib.NUM_DEVICES)` -- `x`.
    * `ShardedIndependent(Normal(x, 1.), 1)` -- the observed data.

  Args:
    axis_name: Passed as `shard_axis_name` to both sharded components and to
      the joint distribution itself.

  Returns:
    A `jd.JointDistributionSequential`.
  """

  # Argument names `w` / `x` are kept identical to the original lambdas;
  # presumably the sequential JD uses them when resolving node names.
  def x_given_w(w):
    return sharded.ShardedSample(
        tfd.Normal(w, 1.), test_lib.NUM_DEVICES, shard_axis_name=axis_name)

  def data_given_x(x):
    return sharded.ShardedIndependent(
        tfd.Normal(x, 1.), 1, shard_axis_name=axis_name)

  model = [tfd.Normal(0., 1.), x_given_w, data_given_x]
  return jd.JointDistributionSequential(model, shard_axis_name=axis_name)
def model_coroutine():
  """Generator model: root `w`, sharded sample `x`, sharded likelihood.

  `axis_name` is a free variable read from the enclosing scope, which is not
  visible here. It is forwarded as `shard_axis_name` to both sharded
  components.
  """
  root_dist = tfd.JointDistributionCoroutine.Root(tfd.Normal(0., 1.))
  w_draw = yield root_dist
  x_draw = yield sharded.ShardedSample(
      tfd.Normal(w_draw, 1.), test_lib.NUM_DEVICES,
      shard_axis_name=axis_name)
  yield sharded.ShardedIndependent(
      tfd.Normal(x_draw, 1.), 1, shard_axis_name=axis_name)
def make_jd_named(axis_name):
  """Builds the sharded `w -> x -> data` model as a `JointDistributionNamed`.

  Args:
    axis_name: Passed as `shard_axis_name` to both sharded components and to
      the joint distribution itself.

  Returns:
    A `jd.JointDistributionNamed` with keys `'w'`, `'x'` and `'data'`.
  """

  # Each callable's argument name matches the dict key it depends on
  # (`w`, `x`). JointDistributionNamed wires dependencies by these names,
  # so they must stay in sync with the keys below.
  def x_fn(w):
    return sharded.ShardedSample(
        tfd.Normal(w, 1.), test_lib.NUM_DEVICES, shard_axis_name=axis_name)

  def data_fn(x):
    return sharded.ShardedIndependent(
        tfd.Normal(x, 1.), 1, shard_axis_name=axis_name)

  model = {
      'w': tfd.Normal(0., 1.),
      'x': x_fn,
      'data': data_fn,
  }
  return jd.JointDistributionNamed(model, shard_axis_name=axis_name)
def model_coroutine():
  """Coroutine model `w -> x -> data` whose sharded parts omit an axis name.

  Uses `NUM_DEVICES`, which is defined outside this view. No
  `shard_axis_name` is passed to either sharded component.
  """
  root = tfd.JointDistributionCoroutine.Root
  w_draw = yield root(tfd.Normal(0., 1.))
  x_draw = yield sharded.ShardedSample(tfd.Normal(w_draw, 1.), NUM_DEVICES)
  yield sharded.ShardedIndependent(tfd.Normal(x_draw, 1.), 1)
def model_coroutine(): w = yield tfd.JointDistributionCoroutine.Root(tfd.Normal(0., 1.)) x = yield sharded.ShardedSample(tfd.Normal(w, 1.), NUM_DEVICES) yield sharded.ShardedIndependent(tfd.Normal(x, 1.), 1) distributions = ( ('coroutine', lambda: jd.JointDistributionCoroutine(model_coroutine)), ( 'sequential', lambda: jd.JointDistributionSequential([ # pylint: disable=g-long-lambda tfd.Normal(0., 1.), lambda w: sharded.ShardedSample(tfd.Normal(w, 1.), NUM_DEVICES), lambda x: sharded.ShardedIndependent(tfd.Normal(x, 1.), 1), ])), ( 'named', lambda: jd.JointDistributionNamed( # pylint: disable=g-long-lambda dict( w=tfd.Normal(0., 1.), x=lambda w: sharded.ShardedSample(tfd.Normal(w, 1.), NUM_DEVICES), data=lambda x: sharded.ShardedIndependent( tfd.Normal(x, 1.), 1), ))), ) @test_util.test_all_tf_execution_regimes
def run(key):
  """Samples a `ShardedIndependent` Normal built without a shard axis name.

  Args:
    key: Seed forwarded to `sample(seed=...)`.

  Returns:
    The result of `ShardedIndependent(...).sample(seed=key)`.
  """
  base = tfd.Normal(tf.zeros(1), tf.ones(1))
  independent = sharded.ShardedIndependent(base, 1)
  return independent.sample(seed=key)
def run(key):
  """Samples a `ShardedIndependent` Normal sharded over `self.axis_name`.

  `self` is a free variable from the enclosing method, which is not visible
  here.

  Args:
    key: Seed forwarded to `sample(seed=...)`.

  Returns:
    The result of `ShardedIndependent(...).sample(seed=key)`.
  """
  base = tfd.Normal(tf.zeros(1), tf.ones(1))
  independent = sharded.ShardedIndependent(
      base, 1, shard_axis_name=self.axis_name)
  return independent.sample(seed=key)