Example #1
def nested_lists_model_fn():
  return collections.OrderedDict((
      ('abc', tfd.JointDistributionSequential([
          tfd.MultivariateNormalDiag([0., 0.], [1., 1.]),
          tfd.JointDistributionSequential(
              [tfd.StudentT(3., -2., 5.),
               tfd.Exponential(4.)])])),
      ('de', lambda abc: tfd.JointDistributionSequential([  # pylint: disable=g-long-lambda
          tfd.Normal(abc[0] * abc[1][0], abc[1][1]),
          tfd.Normal(abc[0] + abc[1][0], abc[1][1])]))))
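A minimal usage sketch (an assumption, not part of the original snippet): the OrderedDict returned by nested_lists_model_fn maps names to distributions and to lambdas of earlier names, so it can be passed directly to tfd.JointDistributionNamed; this assumes the usual tfd = tfp.distributions alias and the collections import from the surrounding module.

# Hypothetical usage; `tfd` and `collections` are assumed to be imported as in
# the snippet above.
model = tfd.JointDistributionNamed(nested_lists_model_fn())
x = model.sample(seed=42)   # dict with keys 'abc' (nested list) and 'de' (list)
lp = model.log_prob(x)      # scalar log-density of the joint sample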
Example #2
  def test_multipart_bijector(self):
    seed_stream = test_util.test_seed_stream()

    prior = tfd.JointDistributionSequential([
        tfd.Gamma(1., 1.),
        lambda scale: tfd.Uniform(0., scale),
        lambda concentration: tfd.CholeskyLKJ(4, concentration),
    ], validate_args=True)
    likelihood = lambda corr: tfd.MultivariateNormalTriL(scale_tril=corr)
    obs = self.evaluate(
        likelihood(
            prior.sample(seed=seed_stream())[-1]).sample(seed=seed_stream()))

    bij = prior.experimental_default_event_space_bijector()

    def target_log_prob(scale, conc, corr):
      return prior.log_prob(scale, conc, corr) + likelihood(corr).log_prob(obs)
    kernel = tfp.mcmc.HamiltonianMonteCarlo(target_log_prob,
                                            num_leapfrog_steps=3, step_size=.5)
    kernel = tfp.mcmc.TransformedTransitionKernel(kernel, bij)

    init = self.evaluate(
        tuple(tf.random.uniform(s, -2., 2., seed=seed_stream())
              for s in bij.inverse_event_shape(prior.event_shape)))
    state = bij.forward(init)
    kr = kernel.bootstrap_results(state)
    next_state, next_kr = kernel.one_step(state, kr, seed=seed_stream())
    self.evaluate((state, kr, next_state, next_kr))
    expected = (target_log_prob(*state) -
                bij.inverse_log_det_jacobian(state, [0, 0, 2]))
    actual = kernel._inner_kernel.target_log_prob_fn(*init)  # pylint: disable=protected-access
    self.assertAllClose(expected, actual)
Example #3
  def testDivergence(self):
    """Neals funnel with large step size."""
    strm = tfp_test_util.test_seed_stream()
    neals_funnel = tfd.JointDistributionSequential(
        [
            tfd.Normal(loc=0., scale=3.),  # b0
            lambda y: tfd.Sample(  # pylint: disable=g-long-lambda
                tfd.Normal(loc=0., scale=tf.math.exp(y / 2)),
                sample_shape=9),
        ],
        validate_args=True
    )

    @tf.function(autograph=False)
    def run_chain_and_get_divergence():
      nchains = 5
      init_states = neals_funnel.sample(nchains, seed=strm())
      _, has_divergence = tfp.mcmc.sample_chain(
          num_results=100,
          kernel=tfp.mcmc.NoUTurnSampler(
              target_log_prob_fn=lambda *args: neals_funnel.log_prob(args),
              step_size=[1., 1.],
              parallel_iterations=1,
              seed=strm()),
          current_state=init_states,
          trace_fn=lambda _, pkr: pkr.has_divergence,
          parallel_iterations=1)
      return tf.reduce_sum(tf.cast(has_divergence, dtype=tf.int32))

    divergence_count = self.evaluate(run_chain_and_get_divergence())

    # Test that we observe a fair amount of divergence.
    self.assertAllGreater(divergence_count, 100)
Example #4
    def testExample(self):
        tf1.random.set_random_seed(tfp_test_util.test_seed())
        target_dist = tfd.JointDistributionSequential([
            tfd.Normal(0., 1.5),
            tfd.Independent(tfd.Normal(tf.zeros([2, 5], dtype=tf.float32), 5.),
                            reinterpreted_batch_ndims=2),
        ])
        num_burnin_steps = 500
        num_results = 500
        num_chains = 64

        kernel = tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=lambda *args: target_dist.log_prob(args),
            num_leapfrog_steps=2,
            step_size=target_dist.stddev(),
            seed=_set_seed(tfp_test_util.test_seed()))
        kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
            inner_kernel=kernel,
            num_adaptation_steps=int(num_burnin_steps * 0.8))

        _, log_accept_ratio = tfp.mcmc.sample_chain(
            num_results=num_results,
            num_burnin_steps=num_burnin_steps,
            current_state=target_dist.sample(num_chains),
            kernel=kernel,
            trace_fn=lambda _, pkr: pkr.inner_results.log_accept_ratio)

        p_accept = tf.reduce_mean(
            input_tensor=tf.exp(tf.minimum(log_accept_ratio, 0.)))

        self.assertAllClose(0.75, self.evaluate(p_accept), atol=0.15)
Example #5
    def testMultipleStateParts(self):
        dist = tfd.JointDistributionSequential([
            tfd.MultivariateNormalDiag(tf.zeros(3), tf.ones(3)),
            tfd.MultivariateNormalDiag(tf.zeros(2), tf.ones(2))
        ])
        target_log_prob_fn = dist.log_prob
        kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
            target_log_prob_fn=target_log_prob_fn,
            num_leapfrog_steps=2,
            step_size=1.)
        initial_running_variance = [
            tfp.experimental.stats.RunningVariance.from_stats(
                num_samples=1., mean=tf.zeros(3), variance=tf.ones(3)),
            tfp.experimental.stats.RunningVariance.from_stats(
                num_samples=1., mean=tf.zeros(2), variance=tf.ones(2))
        ]
        kernel = tfp.experimental.mcmc.DiagonalMassMatrixAdaptation(
            inner_kernel=kernel,
            initial_running_variance=initial_running_variance)

        num_results = 5
        draws = tfp.mcmc.sample_chain(num_results=num_results,
                                      current_state=[tf.zeros(3),
                                                     tf.zeros(2)],
                                      kernel=kernel,
                                      seed=test_util.test_seed(),
                                      trace_fn=None)

        # Make sure the result has the correct shape
        self.assertEqual(len(draws), 2)
        self.assertEqual(draws[0].shape, (num_results, 3))
        self.assertEqual(draws[1].shape, (num_results, 2))
Example #6
    def testExample(self):
        target_dist = tfd.JointDistributionSequential([
            tfd.Normal(0., 1.5),
            tfd.Independent(tfd.Normal(tf.zeros([2, 5], dtype=tf.float32), 5.),
                            reinterpreted_batch_ndims=2),
        ])
        num_burnin_steps = 500
        num_results = 500
        num_chains = 64

        kernel = tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=lambda *args: target_dist.log_prob(args),
            num_leapfrog_steps=2,
            step_size=target_dist.stddev())
        kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
            inner_kernel=kernel,
            num_adaptation_steps=int(num_burnin_steps * 0.8),
            # Cast to int32.  Not necessary for operation since we cast internally
            # to a float type.  This is done to check that we are able to pass in
            # integer types (since they are the natural type for this).
            step_count_smoothing=tf.cast(10, tf.int32))

        seed_stream = test_util.test_seed_stream()
        _, log_accept_ratio = tfp.mcmc.sample_chain(
            num_results=num_results,
            num_burnin_steps=num_burnin_steps,
            current_state=target_dist.sample(num_chains, seed=seed_stream()),
            kernel=kernel,
            trace_fn=lambda _, pkr: pkr.inner_results.log_accept_ratio,
            seed=seed_stream())

        p_accept = tf.reduce_mean(tf.math.exp(tf.minimum(log_accept_ratio,
                                                         0.)))

        self.assertAllClose(0.75, self.evaluate(p_accept), atol=0.15)
Example #7
  def test_transform_parts_to_vector(self, known_split_sizes):
    batch_shape = [4, 2]
    true_split_sizes = [1, 3, 2]

    # Create a joint distribution with parts of the specified sizes.
    seed = test_util.test_seed_stream()
    component_dists = tf.nest.map_structure(
        lambda size: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
            loc=tf.random.normal(batch_shape + [size], seed=seed()),
            scale_diag=tf.exp(
                tf.random.normal(batch_shape + [size], seed=seed()))),
        true_split_sizes)
    base_dist = tfd.JointDistributionSequential(component_dists)

    # Transform to a vector-valued distribution by concatenating the parts.
    bijector = tfb.Invert(tfb.Split(known_split_sizes, axis=-1))

    with self.assertRaisesRegex(ValueError, 'Overriding the batch shape'):
      tfd.TransformedDistribution(base_dist, bijector, batch_shape=[3])

    with self.assertRaisesRegex(ValueError, 'Overriding the event shape'):
      tfd.TransformedDistribution(base_dist, bijector, event_shape=[3])

    concat_dist = tfd.TransformedDistribution(base_dist, bijector)
    self.assertAllEqual(concat_dist.event_shape, [sum(true_split_sizes)])
    self.assertAllEqual(self.evaluate(concat_dist.event_shape_tensor()),
                        [sum(true_split_sizes)])
    self.assertAllEqual(concat_dist.batch_shape, batch_shape)
    self.assertAllEqual(self.evaluate(concat_dist.batch_shape_tensor()),
                        batch_shape)

    # Since the Split bijector has (constant) unit Jacobian, the transformed
    # entropy and mean/mode should match the base entropy and (split) base
    # mean/mode.
    self.assertAllEqual(*self.evaluate(
        (base_dist.entropy(), concat_dist.entropy())))

    self.assertAllEqual(*self.evaluate(
        (concat_dist.mean(), bijector.forward(base_dist.mean()))))
    self.assertAllEqual(*self.evaluate(
        (concat_dist.mode(), bijector.forward(base_dist.mode()))))

    # Since the Split bijector has zero log-det Jacobian, the transformed
    # `log_prob` and `prob` should match the base distribution.
    sample_shape = [3]
    x = base_dist.sample(sample_shape, seed=seed())
    y = bijector.forward(x)
    for attr in ('log_prob', 'prob'):
      base_attr = getattr(base_dist, attr)(x)
      concat_attr = getattr(concat_dist, attr)(y)
      self.assertAllClose(*self.evaluate((base_attr, concat_attr)))

    # Test that `.sample()` works and returns a result of the expected structure
    # and shape.
    y_sampled = concat_dist.sample(sample_shape, seed=seed())
    self.assertAllEqual(y.shape, y_sampled.shape)
Example #8
    def testTwoStateParts(self):
        dtype = np.float32
        num_results, tolerance = _get_mode_dependent_settings()

        true_loc1 = 1.
        true_scale1 = 1.
        true_loc2 = -1.
        true_scale2 = 2.
        target = tfd.JointDistributionSequential([
            tfd.Normal(true_loc1, true_scale1),
            tfd.Normal(true_loc2, true_scale2),
        ])

        num_chains = 16

        init_state = [
            np.ones([num_chains], dtype=dtype),
            np.ones([num_chains], dtype=dtype)
        ]

        target_fn = tf.function(lambda *states: target.log_prob(states),
                                autograph=False)
        [states1, states2] = tfp.mcmc.sample_chain(
            num_results=num_results,
            current_state=init_state,
            kernel=tfp.mcmc.SliceSampler(target_log_prob_fn=target_fn,
                                         step_size=1.0,
                                         max_doublings=5),
            num_burnin_steps=100,
            trace_fn=None,
            seed=test_util.test_seed_stream())

        states1 = tf.reshape(states1, [-1])
        states2 = tf.reshape(states2, [-1])
        sample_mean1 = tf.reduce_mean(states1, axis=0)
        sample_stddev1 = tfp.stats.stddev(states1)
        sample_mean2 = tf.reduce_mean(states2, axis=0)
        sample_stddev2 = tfp.stats.stddev(states2)

        self.assertAllClose(true_loc1,
                            sample_mean1,
                            atol=tolerance,
                            rtol=tolerance)
        self.assertAllClose(true_scale1,
                            sample_stddev1,
                            atol=tolerance,
                            rtol=tolerance)
        self.assertAllClose(true_loc2,
                            sample_mean2,
                            atol=tolerance,
                            rtol=tolerance)
        self.assertAllClose(true_scale2,
                            sample_stddev2,
                            atol=tolerance,
                            rtol=tolerance)
Example #9
def get_base_distribution(flat_event_size, dtype=DEFAULT_FLOAT_DTYPE_TF):
    base_standard_dist = tfd.JointDistributionSequential([
        tfd.Sample(
            tfd.Normal(
                loc=tf.constant(0.0, dtype=dtype),
                scale=tf.constant(1.0, dtype=dtype),
            ),
            s,
        ) for s in flat_event_size
    ])
    return base_standard_dist
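A brief usage sketch for get_base_distribution (assumed, not from the original source): each entry of flat_event_size becomes an independent standard-Normal block, so a joint sample is a list of tensors with those event sizes.

# Hypothetical usage; DEFAULT_FLOAT_DTYPE_TF is defined elsewhere in the
# original module, so an explicit dtype is passed here instead.
base = get_base_distribution([3, 2], dtype=tf.float32)
parts = base.sample(seed=0)   # [Tensor of shape [3], Tensor of shape [2]]
lp = base.log_prob(parts)     # scalar log-density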
Example #10
  def test_transform_joint_to_joint(self, split_sizes):
    dist_batch_shape = tf.nest.pack_sequence_as(
        split_sizes,
        [tensorshape_util.constant_value_as_shape(s)
         for s in [[2, 3], [2, 1], [1, 3]]])
    bijector_batch_shape = [1, 3]

    # Build a joint distribution with parts of the specified sizes.
    seed = test_util.test_seed_stream()
    component_dists = tf.nest.map_structure(
        lambda size, batch_shape: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
            loc=tf.random.normal(batch_shape + [size], seed=seed()),
            scale_diag=tf.random.uniform(
                minval=1., maxval=2.,
                shape=batch_shape + [size], seed=seed())),
        split_sizes, dist_batch_shape)
    if isinstance(split_sizes, dict):
      base_dist = tfd.JointDistributionNamed(component_dists)
    else:
      base_dist = tfd.JointDistributionSequential(component_dists)

    # Transform the distribution by applying a separate bijector to each part.
    bijectors = [tfb.Exp(),
                 tfb.Scale(
                     tf.random.uniform(
                         minval=1., maxval=2.,
                         shape=bijector_batch_shape, seed=seed())),
                 tfb.Reshape([2, 1])]
    bijector = tfb.JointMap(tf.nest.pack_sequence_as(split_sizes, bijectors),
                            validate_args=True)

    # Transform a joint distribution that has different batch shape components
    transformed_dist = tfd.TransformedDistribution(base_dist, bijector)

    self.assertRegex(
        str(transformed_dist),
        '{}.*batch_shape.*event_shape.*dtype'.format(transformed_dist.name))

    self.assertAllEqualNested(
        transformed_dist.event_shape,
        bijector.forward_event_shape(base_dist.event_shape))
    self.assertAllEqualNested(*self.evaluate((
        transformed_dist.event_shape_tensor(),
        bijector.forward_event_shape_tensor(base_dist.event_shape_tensor()))))

    # Test that the batch shape components of the input are the same as those of
    # the output.
    self.assertAllEqualNested(transformed_dist.batch_shape, dist_batch_shape)
    self.assertAllEqualNested(
        self.evaluate(transformed_dist.batch_shape_tensor()), dist_batch_shape)
    self.assertAllEqualNested(dist_batch_shape, base_dist.batch_shape)
Example #11
  def testPartsWithUnusedInternalStructure(self):
    dist = tfd.JointDistributionSequential([
        tfd.JointDistributionNamed({'a': tfd.Normal(0., 1.)}),
        tfd.JointDistributionNamed({'b': tfd.Normal(1000., 1.)}),
    ])
    x = dist.sample(  # Shape: [{'a': []}, {'b': []}]
        seed=test_util.test_seed(sampler_type='stateless'))

    # Test that we can swap the outer list entries, even though they contain
    # internal structure (i.e., are themselves dicts).
    swap_elements = tfb.Restructure(input_structure=[1, 0],
                                    output_structure=[0, 1])
    self.assertAllEqualNested(swap_elements(x), [x[1], x[0]], check_types=True)

    swapped_dist = swap_elements(dist)
    self.assertAllEqualNested(swapped_dist.event_shape,
                              [dist.event_shape[1], dist.event_shape[0]],
                              check_types=True)
    self.assertEqual(swapped_dist.dtype,
                     [dist.dtype[1], dist.dtype[0]])
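As a complementary sketch (an illustration, not part of the original test), the same Restructure mechanism can repackage the list-structured joint distribution into a dict-structured one:

# Hypothetical illustration: map the two list entries to named keys.
to_dict = tfb.Restructure(output_structure={'first': 0, 'second': 1},
                          input_structure=[0, 1])
named_dist = to_dict(dist)   # TransformedDistribution with dict-valued events
y = named_dist.sample(seed=test_util.test_seed(sampler_type='stateless'))
# y is a dict: {'first': {'a': ...}, 'second': {'b': ...}}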
Example #12
  def testSampleEndtoEnd(self):
    """An end-to-end test of sampling using NUTS."""
    strm = tfp_test_util.test_seed_stream()
    predictors = tf.cast([
        201., 244., 47., 287., 203., 58., 210., 202., 198., 158., 165., 201.,
        157., 131., 166., 160., 186., 125., 218., 146.
    ], tf.float32)
    obs = tf.cast([
        592., 401., 583., 402., 495., 173., 479., 504., 510., 416., 393., 442.,
        317., 311., 400., 337., 423., 334., 533., 344.
    ], tf.float32)
    y_sigma = tf.cast([
        61., 25., 38., 15., 21., 15., 27., 14., 30., 16., 14., 25., 52., 16.,
        34., 31., 42., 26., 16., 22.
    ], tf.float32)

    # Robust linear regression model
    robust_lm = tfd.JointDistributionSequential(
        [
            tfd.Normal(loc=0., scale=1.),  # b0
            tfd.Normal(loc=0., scale=1.),  # b1
            tfd.HalfNormal(5.),  # df
            lambda df, b1, b0: tfd.Independent(  # pylint: disable=g-long-lambda
                tfd.StudentT(  # Likelihood
                    df=df[:, None],
                    loc=b0[:, None] + b1[:, None] * predictors[None, :],
                    scale=y_sigma[None, :])),
        ],
        validate_args=True)

    log_prob = lambda b0, b1, df: robust_lm.log_prob([b0, b1, df, obs])
    init_step_size = [1., .2, .5]
    step_size0 = [tf.cast(x, dtype=tf.float32) for x in init_step_size]

    number_of_steps, burnin, nchain = 200, 50, 10

    @tf.function(autograph=False)
    def run_chain_and_get_diagnostic():
      # Random initialization of the starting position of each chain.
      b0, b1, df, _ = robust_lm.sample(nchain, seed=strm())

      # Bijector to map constrained parameters to the real line.
      unconstraining_bijectors = [
          tfb.Identity(),
          tfb.Identity(),
          tfb.Exp(),
      ]

      def trace_fn(_, pkr):
        return (pkr.inner_results.inner_results.step_size,
                pkr.inner_results.inner_results.log_accept_ratio)

      kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
          tfp.mcmc.TransformedTransitionKernel(
              inner_kernel=tfp.mcmc.NoUTurnSampler(
                  target_log_prob_fn=log_prob,
                  step_size=step_size0,
                  parallel_iterations=1,
                  seed=strm()),
              bijector=unconstraining_bijectors),
          target_accept_prob=.8,
          num_adaptation_steps=burnin,
          step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(  # pylint: disable=g-long-lambda
              inner_results=pkr.inner_results._replace(step_size=new_step_size)
          ),
          step_size_getter_fn=lambda pkr: pkr.inner_results.step_size,
          log_accept_prob_getter_fn=(
              lambda pkr: pkr.inner_results.log_accept_ratio),
      )

      # Sample from the chain and get diagnostics.
      mcmc_trace, (step_size, log_accept_ratio) = tfp.mcmc.sample_chain(
          num_results=number_of_steps,
          num_burnin_steps=burnin,
          current_state=[b0, b1, df],
          kernel=kernel,
          trace_fn=trace_fn,
          parallel_iterations=1)
      rhat = tfp.mcmc.potential_scale_reduction(mcmc_trace)
      return (
          [s[-1] for s in step_size],  # final step size
          tf.math.exp(tfp.math.reduce_logmeanexp(log_accept_ratio)),
          [tf.reduce_mean(rhat_) for rhat_ in rhat],  # average rhat
      )

    # Sample from the posterior distribution and get diagnostics.
    [
        final_step_size, average_accept_ratio, average_rhat
    ] = self.evaluate(run_chain_and_get_diagnostic())

    # Check that step size adaptation reduced the initial step size
    self.assertAllLess(
        np.asarray(final_step_size) - np.asarray(init_step_size), 0.)
    # Check that average acceptance ratio is close to target
    self.assertAllClose(
        average_accept_ratio,
        .8 * np.ones_like(average_accept_ratio),
        atol=0.1, rtol=0.1)
    # Check that mcmc sample quality is acceptable with tuning
    self.assertAllClose(
        average_rhat, np.ones_like(average_rhat), atol=0.05, rtol=0.05)
Example #13
  def test_transform_joint_to_joint(self, split_sizes):
    dist_batch_shape = tf.nest.pack_sequence_as(
        split_sizes,
        [tensorshape_util.constant_value_as_shape(s)
         for s in [[2, 3], [2, 1], [1, 3]]])
    bijector_batch_shape = [1, 3]

    # Build a joint distribution with parts of the specified sizes.
    seed = test_util.test_seed_stream()
    component_dists = tf.nest.map_structure(
        lambda size, batch_shape: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
            loc=tf.random.normal(batch_shape + [size], seed=seed()),
            scale_diag=tf.exp(
                tf.random.normal(batch_shape + [size], seed=seed()))),
        split_sizes, dist_batch_shape)
    if isinstance(split_sizes, dict):
      base_dist = tfd.JointDistributionNamed(component_dists)
    else:
      base_dist = tfd.JointDistributionSequential(component_dists)

    # Transform the distribution by applying a separate bijector to each part.
    bijectors = [tfb.Exp(),
                 tfb.Scale(tf.random.normal(bijector_batch_shape, seed=seed())),
                 tfb.Reshape([2, 1])]
    bijector = ToyZipMap(tf.nest.pack_sequence_as(split_sizes, bijectors))

    with self.assertRaisesRegex(ValueError, 'Overriding the batch shape'):
      tfd.TransformedDistribution(base_dist, bijector, batch_shape=[3])

    with self.assertRaisesRegex(ValueError, 'Overriding the event shape'):
      tfd.TransformedDistribution(base_dist, bijector, event_shape=[3])

    # Transform a joint distribution that has different batch shape components
    transformed_dist = tfd.TransformedDistribution(base_dist, bijector)

    self.assertAllEqualNested(
        transformed_dist.event_shape,
        bijector.forward_event_shape(base_dist.event_shape))
    self.assertAllEqualNested(*self.evaluate((
        transformed_dist.event_shape_tensor(),
        bijector.forward_event_shape_tensor(base_dist.event_shape_tensor()))))

    # Test that the batch shape components of the input are the same as those of
    # the output.
    self.assertAllEqualNested(transformed_dist.batch_shape, dist_batch_shape)
    self.assertAllEqualNested(
        self.evaluate(transformed_dist.batch_shape_tensor()), dist_batch_shape)
    self.assertAllEqualNested(dist_batch_shape, base_dist.batch_shape)

    # Check transformed `log_prob` against the base distribution.
    sample_shape = [3]
    sample = base_dist.sample(sample_shape, seed=seed())
    x = tf.nest.map_structure(tf.zeros_like, sample)
    y = bijector.forward(x)
    base_logprob = base_dist.log_prob(x)
    event_ndims = tf.nest.map_structure(lambda s: s.ndims,
                                        transformed_dist.event_shape)
    ildj = bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)

    (transformed_logprob,
     base_logprob_plus_ildj,
     log_transformed_prob
    ) = self.evaluate([
        transformed_dist.log_prob(y),
        base_logprob + ildj,
        tf.math.log(transformed_dist.prob(y))
    ])
    self.assertAllClose(base_logprob_plus_ildj, transformed_logprob)
    self.assertAllClose(transformed_logprob, log_transformed_prob)

    # Test that `.sample()` works and returns a result of the expected structure
    # and shape.
    y_sampled = transformed_dist.sample(sample_shape, seed=seed())
    self.assertAllEqual(tf.nest.map_structure(lambda y: y.shape, y),
                        tf.nest.map_structure(lambda y: y.shape, y_sampled))
Example #14
    def test_transform_joint_to_joint(self, split_sizes):
        dist_batch_shape = tf.nest.pack_sequence_as(split_sizes, [
            tensorshape_util.constant_value_as_shape(s)
            for s in [[2, 3], [2, 1], [1, 3]]
        ])
        bijector_batch_shape = [1, 3]

        # Build a joint distribution with parts of the specified sizes.
        seed = test_util.test_seed_stream()
        component_dists = tf.nest.map_structure(
            lambda size, batch_shape: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
                loc=tf.random.normal(batch_shape + [size], seed=seed()),
                scale_diag=tf.random.uniform(minval=1.,
                                             maxval=2.,
                                             shape=batch_shape + [size],
                                             seed=seed())),
            split_sizes,
            dist_batch_shape)
        if isinstance(split_sizes, dict):
            base_dist = tfd.JointDistributionNamed(component_dists)
        else:
            base_dist = tfd.JointDistributionSequential(component_dists)

        # Transform the distribution by applying a separate bijector to each part.
        bijectors = [
            tfb.Exp(),
            tfb.Scale(
                tf.random.uniform(minval=1.,
                                  maxval=2.,
                                  shape=bijector_batch_shape,
                                  seed=seed())),
            tfb.Reshape([2, 1])
        ]
        bijector = tfb.JointMap(tf.nest.pack_sequence_as(
            split_sizes, bijectors),
                                validate_args=True)

        # Transform a joint distribution that has different batch shape components
        transformed_dist = tfd.TransformedDistribution(base_dist, bijector)

        self.assertRegex(
            str(transformed_dist),
            '{}.*batch_shape.*event_shape.*dtype'.format(
                transformed_dist.name))

        self.assertAllEqualNested(
            transformed_dist.event_shape,
            bijector.forward_event_shape(base_dist.event_shape))
        self.assertAllEqualNested(
            *self.evaluate((transformed_dist.event_shape_tensor(),
                            bijector.forward_event_shape_tensor(
                                base_dist.event_shape_tensor()))))

        # Test that the batch shape components of the input are the same as those of
        # the output.
        self.assertAllEqualNested(transformed_dist.batch_shape,
                                  dist_batch_shape)
        self.assertAllEqualNested(
            self.evaluate(transformed_dist.batch_shape_tensor()),
            dist_batch_shape)
        self.assertAllEqualNested(dist_batch_shape, base_dist.batch_shape)

        # Check transformed `log_prob` against the base distribution.
        sample_shape = [3]
        sample = base_dist.sample(sample_shape, seed=seed())
        x = tf.nest.map_structure(tf.zeros_like, sample)
        y = bijector.forward(x)
        base_logprob = base_dist.log_prob(x)
        event_ndims = tf.nest.map_structure(lambda s: s.ndims,
                                            transformed_dist.event_shape)
        ildj = bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)

        (transformed_logprob, base_logprob_plus_ildj,
         log_transformed_prob) = self.evaluate([
             transformed_dist.log_prob(y), base_logprob + ildj,
             tf.math.log(transformed_dist.prob(y))
         ])
        self.assertAllClose(base_logprob_plus_ildj, transformed_logprob)
        self.assertAllClose(transformed_logprob, log_transformed_prob)

        # Test that `.sample()` works and returns a result of the expected structure
        # and shape.
        y_sampled = transformed_dist.sample(sample_shape, seed=seed())
        self.assertAllEqual(
            tf.nest.map_structure(lambda y: y.shape, y),
            tf.nest.map_structure(lambda y: y.shape, y_sampled))

        # Test that a `Restructure` bijector applied to a `JointDistribution` works
        # as expected.
        num_components = len(split_sizes)
        input_keys = (split_sizes.keys() if isinstance(split_sizes, dict) else
                      range(num_components))
        output_keys = [str(i) for i in range(num_components)]
        output_structure = {k: v for k, v in zip(output_keys, input_keys)}
        restructure = tfb.Restructure(output_structure)
        restructured_dist = tfd.TransformedDistribution(base_dist,
                                                        bijector=restructure,
                                                        validate_args=True)

        # Check that attributes of the restructured distribution have the same
        # nested structure as the `output_structure` of the bijector. Pass a no-op
        # as the `assert_fn` since the contents of the structures are not
        # required to be the same.
        noop_assert_fn = lambda *_: None
        self.assertAllAssertsNested(noop_assert_fn,
                                    restructured_dist.event_shape,
                                    output_structure)
        self.assertAllAssertsNested(noop_assert_fn,
                                    restructured_dist.batch_shape,
                                    output_structure)
        self.assertAllAssertsNested(
            noop_assert_fn,
            self.evaluate(restructured_dist.event_shape_tensor()),
            output_structure)
        self.assertAllAssertsNested(
            noop_assert_fn,
            self.evaluate(restructured_dist.batch_shape_tensor()),
            output_structure)
        self.assertAllAssertsNested(
            noop_assert_fn,
            self.evaluate(
                restructured_dist.sample(seed=test_util.test_seed())))
Example #15
def build_factored_surrogate_posterior(
        event_shape=None,
        constraining_bijectors=None,
        initial_unconstrained_loc=_sample_uniform_initial_loc,
        initial_unconstrained_scale=1e-2,
        trainable_distribution_fn=_build_trainable_normal_dist,
        seed=None,
        validate_args=False,
        name=None):
    """Builds a joint variational posterior that factors over model variables.

  By default, this method creates an independent trainable Normal distribution
  for each variable, transformed using a bijector (if provided) to
  match the support of that variable. This makes extremely strong
  assumptions about the posterior: that it is approximately normal (or
  transformed normal), and that all model variables are independent.

  Args:
    event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
      specifying the event shape(s) of the posterior variables.
    constraining_bijectors: Optional `tfb.Bijector` instance, or nested
      structure of such instances, defining support(s) of the posterior
      variables. The structure must match that of `event_shape` and may
      contain `None` values. A posterior variable will
      be modeled as `tfd.TransformedDistribution(underlying_dist,
      constraining_bijector)` if a corresponding constraining bijector is
      specified, otherwise it is modeled as supported on the
      unconstrained real line.
    initial_unconstrained_loc: Optional Python `callable` with signature
      `tensor = initial_unconstrained_loc(shape, seed)` used to sample
      real-valued initializations for the unconstrained representation of each
      variable. May alternately be a nested structure of
      `Tensor`s, giving specific initial locations for each variable; these
      must have structure matching `event_shape` and shapes determined by the
      inverse image of `event_shape` under `constraining_bijectors`, which
      may optionally be prefixed with a common batch shape.
      Default value: `functools.partial(tf.random.uniform,
        minval=-2., maxval=2., dtype=tf.float32)`.
    initial_unconstrained_scale: Optional scalar float `Tensor` initial
      scale for the unconstrained distributions, or a nested structure of
      `Tensor` initial scales for each variable.
      Default value: `1e-2`.
    trainable_distribution_fn: Optional Python `callable` with signature
      `trainable_dist = trainable_distribution_fn(initial_loc, initial_scale,
      event_ndims, validate_args)`. This is called for each model variable to
      build the corresponding factor in the surrogate posterior. It is expected
      that the distribution returned is supported on unconstrained real values.
      Default value: `functools.partial(
        tfp.vi.experimental.build_trainable_location_scale_distribution,
        distribution_fn=tfd.Normal)`, i.e., a trainable Normal distribution.
    seed: Python integer to seed the random number generator. This is used
      only when `initial_loc` is not specified.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_factored_surrogate_posterior').

  Returns:
    surrogate_posterior: A `tfd.Distribution` instance whose samples have
      shape and structure matching that of `event_shape` or `initial_loc`.

  ### Examples

  Consider a Gamma model with unknown parameters, expressed as a joint
  Distribution:

  ```python
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                         sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)
  ```

  Let's use variational inference to approximate the posterior over the
  data-generating parameters for some observed `y`. We'll build a
  surrogate posterior distribution by specifying the shapes of the latent
  `rate` and `concentration` parameters, and that both are constrained to
  be positive.

  ```python
  surrogate_posterior = tfp.vi.experimental.build_factored_surrogate_posterior(
    event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
    constraining_bijectors=[tfb.Softplus(),   # Concentration is positive.
                            tfb.Softplus()])  # Rate is positive.
  ```

  This creates a trainable joint distribution, defined by variables in
  `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
  to fit this distribution by minimizing a divergence to the true posterior.

  ```python
  y = [0.2, 0.5, 0.3, 0.7]
  losses = tfp.vi.fit_surrogate_posterior(
    lambda concentration, rate: model.log_prob([concentration, rate, y]),
    surrogate_posterior=surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)

  # After optimization, samples from the surrogate will approximate
  # samples from the true posterior.
  samples = surrogate_posterior.sample(100)
  posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
  posterior_std = [tf.math.reduce_std(x) for x in samples]  # std  ~= [0.3, 0.8]
  ```

  If we wanted to initialize the optimization at a specific location, we can
  specify one when we build the surrogate posterior. This function requires the
  initial location to be specified in *unconstrained* space; we do this by
  inverting the constraining bijectors (note this section also demonstrates the
  creation of a dict-structured model).

  ```python
  initial_loc = {'concentration': 0.4, 'rate': 0.2}
  constraining_bijectors={'concentration': tfb.Softplus(),  # Concentration is positive.
                          'rate': tfb.Softplus()}           # Rate is positive.
  initial_unconstrained_loc = tf.nest.map_structure(
    lambda b, x: b.inverse(x) if b is not None else x,
    constraining_bijectors, initial_loc)
  surrogate_posterior = tfp.vi.experimental.build_factored_surrogate_posterior(
    event_shape=tf.nest.map_structure(tf.shape, initial_loc),
    constraining_bijectors=constraining_bijectors,
    initial_unconstrained_loc=initial_unconstrained_loc,
    initial_unconstrained_scale=1e-4)
  ```

  """

    with tf.name_scope(name or 'build_factored_surrogate_posterior'):
        seed = tfp_util.SeedStream(seed,
                                   salt='build_factored_surrogate_posterior')

        # Convert event shapes to Tensors.
        shallow_structure = _get_event_shape_shallow_structure(event_shape)
        event_shape = nest.map_structure_up_to(
            shallow_structure,
            lambda s: tf.convert_to_tensor(s, dtype=tf.int32), event_shape)
        flat_event_shapes = tf.nest.flatten(event_shape)

        # For simplicity, we'll work with flattened lists of state parts and
        # repack the structure at the end.
        if constraining_bijectors is not None:
            flat_bijectors = tf.nest.flatten(constraining_bijectors)
        else:
            flat_bijectors = [None for _ in flat_event_shapes]
        flat_unconstrained_event_shapes = [
            b.inverse_event_shape_tensor(s) if b is not None else s
            for s, b in zip(flat_event_shapes, flat_bijectors)
        ]

        # Construct initial locations for the internal unconstrained dists.
        if callable(
                initial_unconstrained_loc):  # Sample random initialization.
            flat_unconstrained_locs = [
                initial_unconstrained_loc(shape=s, seed=seed())
                for s in flat_unconstrained_event_shapes
            ]
        else:  # Use provided initialization.
            flat_unconstrained_locs = nest.flatten_up_to(
                shallow_structure,
                initial_unconstrained_loc,
                check_types=False)

        if nest.is_nested(initial_unconstrained_scale):
            flat_unconstrained_scales = nest.flatten_up_to(
                shallow_structure,
                initial_unconstrained_scale,
                check_types=False)
        else:
            flat_unconstrained_scales = [
                initial_unconstrained_scale for _ in flat_unconstrained_locs
            ]

        # Extract the rank of each event, so that we build distributions with the
        # correct event shapes.
        flat_unconstrained_event_ndims = [
            prefer_static.rank_from_shape(s)
            for s in flat_unconstrained_event_shapes
        ]

        # Build the component surrogate posteriors.
        flat_component_dists = []
        for initial_loc, initial_scale, event_ndims, bijector in zip(
                flat_unconstrained_locs, flat_unconstrained_scales,
                flat_unconstrained_event_ndims, flat_bijectors):
            unconstrained_dist = trainable_distribution_fn(
                initial_loc=initial_loc,
                initial_scale=initial_scale,
                event_ndims=event_ndims,
                validate_args=validate_args)
            flat_component_dists.append(
                bijector(unconstrained_dist
                         ) if bijector is not None else unconstrained_dist)
        component_distributions = tf.nest.pack_sequence_as(
            event_shape, flat_component_dists)

        # Return a `Distribution` object whose events have the specified structure.
        if hasattr(component_distributions,
                   'sample'):  # Tensor-valued posterior.
            return component_distributions
        elif hasattr(component_distributions,
                     'keys'):  # Dict-valued posterior.
            return tfd.JointDistributionNamed(component_distributions,
                                              validate_args=validate_args,
                                              name=name)
        else:
            return tfd.JointDistributionSequential(component_distributions,
                                                   validate_args=validate_args,
                                                   name=name)
Example #16
def init_near_unconstrained_zero(
    model=None, constraining_bijector=None, event_shapes=None,
    event_shape_tensors=None, batch_shapes=None, batch_shape_tensors=None,
    dtypes=None):
  """Returns an initialization Distribution for starting a Markov chain.

  This initialization scheme follows Stan: we sample every latent
  independently, uniformly from -2 to 2 in its unconstrained space,
  and then transform into constrained space to construct an initial
  state that can be passed to `sample_chain` or other MCMC drivers.

  The argument signature is arranged to let the user pass either a
  `JointDistribution` describing their model, if it's in that form, or
  the essential information necessary for the sampling, namely a
  bijector (from unconstrained to constrained space) and the desired
  shape and dtype of each sample (specified in constrained space).

  Note: As currently implemented, this function has the limitation
  that the batch shape of the supplied model is ignored, but that
  could probably be generalized if needed.

  Args:
    model: A `Distribution` (typically a `JointDistribution`) giving the
      model to be initialized.  If supplied, it is queried for
      its default event space bijector, its event shape, and its dtype.
      If not supplied, those three elements must be supplied instead.
    constraining_bijector: A (typically multipart) `Bijector` giving
      the mapping from unconstrained to constrained space.  If
      supplied together with a `model`, acts as an override.  A nested
      structure of `Bijector`s is accepted, and interpreted as
      applying in parallel to a corresponding structure of state parts
      (see `JointMap` for details).
    event_shapes: A structure of shapes giving the (unconstrained)
      event space shape of the desired samples.  Must be an acceptable
      input to `constraining_bijector.inverse_event_shape`.  If
      supplied together with `model`, acts as an override.
    event_shape_tensors: A structure of tensors giving the (unconstrained)
      event space shape of the desired samples.  Must be an acceptable
      input to `constraining_bijector.inverse_event_shape_tensor`.  If
      supplied together with `model`, acts as an override. Required if any of
      `event_shapes` are not fully-defined.
    batch_shapes: A structure of shapes giving the batch shape of the desired
      samples.  If supplied together with `model`, acts as an override.  If
      unspecified, we assume scalar batch `[]`.
    batch_shape_tensors: A structure of tensors giving the batch shape of the
      desired samples.  If supplied together with `model`, acts as an override.
      Required if any of `batch_shapes` are not fully-defined.
    dtypes: A structure of dtypes giving the (unconstrained) dtypes of
      the desired samples.  Must be an acceptable input to
      `constraining_bijector.inverse_dtype`.  If supplied together
      with `model`, acts as an override.

  Returns:
    init_dist: A `Distribution` representing the initialization
      distribution, in constrained space.  Samples from this
      `Distribution` are valid initial states for a Markov chain
      targeting the model.

  #### Example

  Initialize 100 chains from the Uniform(-2, 2) distribution in unconstrained
  space for a model expressed as a `JointDistributionCoroutine`:

  ```python
  @tfp.distributions.JointDistributionCoroutine
  def model():
    ...

  init_dist = tfp.experimental.mcmc.init_near_unconstrained_zero(model)
  states = tfp.mcmc.sample_chain(
    current_state=init_dist.sample(100, seed=[4, 8]),
    ...)
  ```

  """
  # Canonicalize arguments into the parts we need, namely
  # the constraining_bijector, the event_shapes, and the dtypes.
  if model is not None:
    # Got a Distribution model; treat other arguments as overrides if
    # present.
    if constraining_bijector is None:
      # pylint: disable=protected-access
      constraining_bijector = model.experimental_default_event_space_bijector()
    if event_shapes is None:
      event_shapes = model.event_shape
    if event_shape_tensors is None:
      event_shape_tensors = model.event_shape_tensor()
    if dtypes is None:
      dtypes = model.dtype
    if batch_shapes is None:
      batch_shapes = nest_util.broadcast_structure(dtypes, model.batch_shape)
    if batch_shape_tensors is None:
      batch_shape_tensors = nest_util.broadcast_structure(
          dtypes, model.batch_shape_tensor())

  else:
    if constraining_bijector is None or event_shapes is None or dtypes is None:
      msg = ('Must pass either a Distribution (typically a JointDistribution), '
             'or a bijector, a structure of event shapes, and a '
             'structure of dtypes')
      raise ValueError(msg)
    event_shapes_fully_defined = all(tensorshape_util.is_fully_defined(s)
                                     for s in tf.nest.flatten(event_shapes))
    if not event_shapes_fully_defined and event_shape_tensors is None:
      raise ValueError('Must specify `event_shape_tensors` when `event_shapes` '
                       f'are not fully-defined: {event_shapes}')
    if batch_shapes is None:
      batch_shapes = tf.TensorShape([])
    batch_shapes = nest_util.broadcast_structure(dtypes, batch_shapes)
    batch_shapes_fully_defined = all(tensorshape_util.is_fully_defined(s)
                                     for s in tf.nest.flatten(batch_shapes))
    if batch_shape_tensors is None:
      if not batch_shapes_fully_defined:
        raise ValueError(
            'Must specify `batch_shape_tensors` when `batch_shapes` are not '
            f'fully-defined: {batch_shapes}')
      batch_shape_tensors = tf.nest.map_structure(
          tf.convert_to_tensor, batch_shapes)

  # Interpret a structure of Bijectors as the joint multipart bijector.
  if not isinstance(constraining_bijector, tfb.Bijector):
    constraining_bijector = tfb.JointMap(constraining_bijector)

  # Actually initialize
  def one_term(event_shape, event_shape_tensor, batch_shape, batch_shape_tensor,
               dtype):
    if not tensorshape_util.is_fully_defined(event_shape):
      event_shape = event_shape_tensor
    result = tfd.Sample(
        tfd.Uniform(low=tf.constant(-2., dtype=dtype),
                    high=tf.constant(2., dtype=dtype)),
        sample_shape=event_shape)
    if not tensorshape_util.is_fully_defined(batch_shape):
      batch_shape = batch_shape_tensor
      needs_bcast = True
    else:  # Only batch broadcast when batch ndims > 0.
      needs_bcast = bool(tensorshape_util.as_list(batch_shape))
    if needs_bcast:
      result = tfd.BatchBroadcast(result, batch_shape)
    return result

  inv_shapes = constraining_bijector.inverse_event_shape(event_shapes)
  if event_shape_tensors is not None:
    inv_shape_tensors = constraining_bijector.inverse_event_shape_tensor(
        event_shape_tensors)
  else:
    inv_shape_tensors = tf.nest.map_structure(lambda _: None, inv_shapes)
  inv_dtypes = constraining_bijector.inverse_dtype(dtypes)
  terms = tf.nest.map_structure(
      one_term, inv_shapes, inv_shape_tensors, batch_shapes,
      batch_shape_tensors, inv_dtypes)
  unconstrained = tfb.pack_sequence_as(inv_shapes)(
      tfd.JointDistributionSequential(tf.nest.flatten(terms)))
  return tfd.TransformedDistribution(
      unconstrained, bijector=constraining_bijector)
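For completeness, a hedged sketch of the explicit-argument path described in the Args above (this usage is an assumption, not taken from the original source): when no model is available, the constraining bijectors, event shapes, and dtypes are passed directly.

# Hypothetical usage: initialize two state parts, a positive scalar scale
# (constrained via Softplus) and an unconstrained length-3 vector.
init_dist = init_near_unconstrained_zero(
    constraining_bijector=[tfb.Softplus(), tfb.Identity()],
    event_shapes=[tf.TensorShape([]), tf.TensorShape([3])],
    dtypes=[tf.float32, tf.float32])
initial_state = init_dist.sample(8, seed=[4, 8])  # initial states for 8 chains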