def make_jd_named(axis_name):
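  """Builds a sharded `JointDistributionNamed` over the given `axis_name`."""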
  return jd.JointDistributionNamed(  # pylint: disable=g-long-lambda
      dict(
          w=tfd.Normal(0., 1.),
          x=lambda w: sharded.Sharded(  # pylint: disable=g-long-lambda
              tfd.Sample(tfd.Normal(w, 1.), 1),
              shard_axis_name=axis_name),
          data=lambda x: sharded.Sharded(  # pylint: disable=g-long-lambda
              tfd.Independent(tfd.Normal(x, 1.), 1),
              shard_axis_name=axis_name),
      ))
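

# Each entry pairs a test-case name with a zero-argument factory that builds
# the same w -> x -> data model as one of the three JointDistribution flavors,
# with the `x` and `data` variables sharded across NUM_DEVICES devices.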
distributions = (
    ('coroutine', lambda: jd.JointDistributionCoroutine(model_coroutine)),
    (
        'sequential',
        lambda: jd.JointDistributionSequential([  # pylint: disable=g-long-lambda
            tfd.Normal(0., 1.),
            lambda w: sharded.ShardedSample(tfd.Normal(w, 1.), NUM_DEVICES),
            lambda x: sharded.ShardedIndependent(tfd.Normal(x, 1.), 1),
        ])),
    (
        'named',
        lambda: jd.JointDistributionNamed(  # pylint: disable=g-long-lambda
            dict(
                w=tfd.Normal(0., 1.),
                x=lambda w: sharded.ShardedSample(tfd.Normal(w, 1.),
                                                  NUM_DEVICES),
                data=lambda x: sharded.ShardedIndependent(
                    tfd.Normal(x, 1.), 1),
            ))),
)
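# Note: `make_jd_named` above builds an axis-name based variant of the 'named'
# model using `sharded.Sharded`, rather than the device-count based
# `ShardedSample`/`ShardedIndependent` wrappers used in these factories.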


@test_util.test_all_tf_execution_regimes
class JointDistributionTest(test_util.TestCase):
    def setUp(self):
        super(JointDistributionTest, self).setUp()
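        # Mirror the test computation across all available logical devices,
        # with one replica per device.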
        self.strategy = tf.distribute.MirroredStrategy(
            devices=tf.config.list_logical_devices())

    def shard_values(self, values):
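        """Converts `values` into per-replica values."""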
        def value_fn(ctx):