예제 #1
0
    def __init__(self, target_log_prob_fn, new_state_fn=None, name=None):
        """Builds a random-walk Metropolis transition kernel.

        A symmetric random-walk proposal (`UncalibratedRandomWalk`) is wrapped
        in a `MetropolisHastings` kernel, which supplies the accept/reject
        correction.

        Args:
          target_log_prob_fn: Python callable which takes an argument like
            `current_state` (or `*current_state` if it's a list) and returns
            its (possibly unnormalized) log-density under the target
            distribution.
          new_state_fn: Python callable which takes a list of state parts and a
            seed; returns a same-type `list` of `Tensor`s, each being a
            perturbation of the input state parts. The perturbation
            distribution is assumed to be a symmetric distribution centered at
            the input state part.
            Default value: `None`, in which case `random_walk_normal_fn()` is
            used.
          name: Python `str` name prefixed to Ops created by this function.
            It is forwarded unchanged to `UncalibratedRandomWalk`.
            Default value: `None` (i.e., 'rwm_kernel').
        """
        proposal_fn = (
            random_walk_normal_fn() if new_state_fn is None else new_state_fn)
        proposal_kernel = UncalibratedRandomWalk(
            target_log_prob_fn=target_log_prob_fn,
            new_state_fn=proposal_fn,
            name=name)
        self._impl = metropolis_hastings.MetropolisHastings(
            inner_kernel=proposal_kernel)
    def __init__(self,
                 target_log_prob_fn,
                 step_size,
                 num_leapfrog_steps,
                 momentum_distribution=None,
                 state_gradients_are_stopped=False,
                 step_size_update_fn=None,
                 store_parameters_in_results=False,
                 name=None):
        """Initializes this transition kernel.

        The kernel wraps an `UncalibratedPreconditionedHamiltonianMonteCarlo`
        proposal in a `MetropolisHastings` accept/reject step.

        Args:
          target_log_prob_fn: Python callable which takes an argument like
            `current_state` (or `*current_state` if it's a list) and returns
            its (possibly unnormalized) log-density under the target
            distribution.
          step_size: `Tensor` or Python `list` of `Tensor`s representing the
            step size for the leapfrog integrator. Must broadcast with the
            shape of `current_state`. Larger step sizes lead to faster
            progress, but too-large step sizes make rejection exponentially
            more likely. When possible, it's often helpful to match
            per-variable step sizes to the standard deviations of the target
            distribution in each variable.
          num_leapfrog_steps: Integer number of steps to run the leapfrog
            integrator for. Total progress per HMC step is roughly
            proportional to `step_size * num_leapfrog_steps`.
          momentum_distribution: A `tfp.distributions.Distribution` instance to
            draw momentum from. Defaults to isotropic normal distributions.
          state_gradients_are_stopped: Python `bool` indicating that the
            proposed new state be run through `tf.stop_gradient`. This is
            particularly useful when combining optimization over samples from
            the HMC chain.
            Default value: `False` (i.e., do not apply `stop_gradient`).
          step_size_update_fn: Python `callable` taking current `step_size`
            (typically a `tf.Variable`) and `kernel_results` (typically
            `collections.namedtuple`) and returns updated step_size
            (`Tensor`s). It is only recorded in `self._parameters` here.
            Default value: `None` (i.e., do not update `step_size`
            automatically).
          store_parameters_in_results: If `True`, then `step_size` and
            `num_leapfrog_steps` are written to and read from eponymous fields
            in the kernel results objects returned from `one_step` and
            `bootstrap_results`. This allows wrapper kernels to adjust those
            parameters on the fly. This is incompatible with
            `step_size_update_fn`, which must be set to `None`.
          name: Python `str` name prefixed to Ops created by this function.
            Default value: `None` (i.e., 'hmc_kernel').

        Raises:
          ValueError: if `step_size_update_fn` is given while
            `store_parameters_in_results` is `True`.
        """
        # Explicit `None` test: any supplied callable counts as "specified",
        # regardless of its truthiness.
        if step_size_update_fn is not None and store_parameters_in_results:
            raise ValueError('It is invalid to simultaneously specify '
                             '`step_size_update_fn` and set '
                             '`store_parameters_in_results` to `True`.')
        self._impl = metropolis_hastings.MetropolisHastings(
            inner_kernel=UncalibratedPreconditionedHamiltonianMonteCarlo(
                target_log_prob_fn=target_log_prob_fn,
                step_size=step_size,
                num_leapfrog_steps=num_leapfrog_steps,
                state_gradients_are_stopped=state_gradients_are_stopped,
                momentum_distribution=momentum_distribution,
                name=name or 'hmc_kernel',
                store_parameters_in_results=store_parameters_in_results))
        self._parameters = self._impl.inner_kernel.parameters.copy()
        # `seed` is not an argument of this constructor, so drop it from the
        # exposed parameters if the inner kernel reports one.
        self._parameters.pop('seed',
                             None)  # TODO(b/159636942): Remove this line.
        self._parameters['step_size_update_fn'] = step_size_update_fn
예제 #3
0
    def __init__(self,
                 target_log_prob_fn,
                 step_size,
                 volatility_fn=None,
                 parallel_iterations=10,
                 experimental_shard_axis_names=None,
                 name=None):
        """Initializes MALA transition kernel.

        The kernel wraps an `UncalibratedLangevin` proposal in a
        `MetropolisHastings` accept/reject step and applies the requested
        shard axes via `experimental_with_shard_axes`.

        Args:
          target_log_prob_fn: Python callable which takes an argument like
            `current_state` (or `*current_state` if it's a list) and returns
            its (possibly unnormalized) log-density under the target
            distribution.
          step_size: `Tensor` or Python `list` of `Tensor`s representing the
            step size for the leapfrog integrator. Must broadcast with the
            shape of `current_state`. Larger step sizes lead to faster
            progress, but too-large step sizes make rejection exponentially
            more likely. When possible, it's often helpful to match
            per-variable step sizes to the standard deviations of the target
            distribution in each variable.
          volatility_fn: Python callable which takes an argument like
            `current_state` (or `*current_state` if it's a list) and returns
            volatility value at `current_state`. Should return a `Tensor` or
            Python `list` of `Tensor`s that must broadcast with the shape of
            `current_state`. Defaults to the identity function.
          parallel_iterations: the number of coordinates for which the
            gradients of the volatility matrix `volatility_fn` can be computed
            in parallel.
            Default value: `10`.
          experimental_shard_axis_names: A structure of string names indicating
            how members of the state are sharded.
            Default value: `None`.
          name: Python `str` name prefixed to Ops created by this function.
            Default value: `None` (i.e., 'mala_kernel').

        Raises:
          NOTE(review): this constructor performs no validation itself. The
          `ValueError` (mismatched `step_size` length) and `TypeError`
          (non-callable `volatility_fn`) documented for MALA presumably
          propagate from `UncalibratedLangevin`; confirm.
        """
        impl = metropolis_hastings.MetropolisHastings(
            inner_kernel=UncalibratedLangevin(
                target_log_prob_fn=target_log_prob_fn,
                step_size=step_size,
                volatility_fn=volatility_fn,
                parallel_iterations=parallel_iterations,
                name=name)).experimental_with_shard_axes(
                    experimental_shard_axis_names)

        self._impl = impl
        parameters = impl.inner_kernel.parameters.copy()
        # `compute_acceptance` belongs to the inner `UncalibratedLangevin`
        # kernel, not to this MALA kernel's `__init__`, so it is not exposed.
        # `pop` with a default tolerates the key being absent instead of
        # raising `KeyError`.
        parameters.pop('compute_acceptance', None)
        self._parameters = parameters
예제 #4
0
  def __init__(self,
               target_log_prob_fn,
               step_size,
               num_leapfrog_steps,
               state_gradients_are_stopped=False,
               step_size_update_fn=None,
               seed=None,
               name=None):
    """Builds a Hamiltonian Monte Carlo transition kernel.

    An `UncalibratedHamiltonianMonteCarlo` proposal is wrapped in a
    `MetropolisHastings` accept/reject step; both receive the same `seed`.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      step_size: `Tensor` or Python `list` of `Tensor`s representing the step
        size for the leapfrog integrator. Must broadcast with the shape of
        `current_state`. Larger step sizes lead to faster progress, but
        too-large step sizes make rejection exponentially more likely. When
        possible, it's often helpful to match per-variable step sizes to the
        standard deviations of the target distribution in each variable.
      num_leapfrog_steps: Integer number of steps to run the leapfrog
        integrator for. Total progress per HMC step is roughly proportional
        to `step_size * num_leapfrog_steps`.
      state_gradients_are_stopped: Python `bool` indicating that the proposed
        new state be run through `tf.stop_gradient`. This is particularly
        useful when combining optimization over samples from the HMC chain.
        Default value: `False` (i.e., do not apply `stop_gradient`).
      step_size_update_fn: Python `callable` taking current `step_size`
        (typically a `tf.Variable`) and `kernel_results` (typically
        `collections.namedtuple`) and returns updated step_size (`Tensor`s).
        Here it is only recorded in `self._parameters`.
        Default value: `None` (i.e., do not update `step_size` automatically).
      seed: Python integer to seed the random number generator.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None`, which is replaced by 'hmc_kernel'.
    """
    kernel_name = name if name is not None else 'hmc_kernel'
    leapfrog_kernel = UncalibratedHamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob_fn,
        step_size=step_size,
        num_leapfrog_steps=num_leapfrog_steps,
        state_gradients_are_stopped=state_gradients_are_stopped,
        seed=seed,
        name=kernel_name)
    mh_kernel = metropolis_hastings.MetropolisHastings(
        inner_kernel=leapfrog_kernel, seed=seed)
    # Expose the inner kernel's parameters plus the extra update hook.
    exposed = mh_kernel.inner_kernel.parameters.copy()
    exposed['step_size_update_fn'] = step_size_update_fn
    self._impl = mh_kernel
    self._parameters = exposed
예제 #5
0
    def __init__(self,
                 target_log_prob_fn,
                 step_size,
                 num_leapfrog_steps,
                 seed=None,
                 name=None):
        """Builds a Hamiltonian Monte Carlo transition kernel.

        The constructor arguments are recorded on the instance. An
        `UncalibratedHamiltonianMonteCarlo` proposal is then wrapped in a
        `MetropolisHastings` accept/reject step, stored as `self._hmc_impl`.

        Args:
          target_log_prob_fn: Python callable which takes an argument like
            `current_state` (or `*current_state` if it's a list) and returns
            its (possibly unnormalized) log-density under the target
            distribution.
          step_size: `Tensor` or Python `list` of `Tensor`s representing the
            step size for the leapfrog integrator. Must broadcast with the
            shape of `current_state`. Larger step sizes lead to faster
            progress, but too-large step sizes make rejection exponentially
            more likely. When possible, it's often helpful to match
            per-variable step sizes to the standard deviations of the target
            distribution in each variable.
          num_leapfrog_steps: Integer number of steps to run the leapfrog
            integrator for. Total progress per HMC step is roughly
            proportional to `step_size * num_leapfrog_steps`.
          seed: Python integer to seed the random number generator. It is
            passed to both the inner kernel and `MetropolisHastings`.
          name: Python `str` name prefixed to Ops created by this function.
            Default value: `None`. NOTE(review): this value is forwarded
            unchanged (including `None`). Confirm that the inner kernel maps
            `None` to 'hmc_kernel' as previously documented.
        """
        # Remember the raw constructor arguments on the instance.
        self._target_log_prob_fn = target_log_prob_fn
        self._step_size, self._num_leapfrog_steps = step_size, num_leapfrog_steps
        self._seed, self._name = seed, name

        proposal = UncalibratedHamiltonianMonteCarlo(
            target_log_prob_fn=target_log_prob_fn,
            step_size=step_size,
            num_leapfrog_steps=num_leapfrog_steps,
            seed=seed,
            name=name)
        self._hmc_impl = metropolis_hastings.MetropolisHastings(
            inner_kernel=proposal, seed=seed)
예제 #6
0
    def __init__(self,
                 target_log_prob_fn,
                 step_size,
                 volatility_fn=None,
                 seed=None,
                 name=None):
        """Builds a Metropolis-adjusted Langevin (MALA) transition kernel.

        An `UncalibratedLangevin` proposal is wrapped in a `MetropolisHastings`
        accept/reject step; both receive the same `seed`.

        Args:
          target_log_prob_fn: Python callable which takes an argument like
            `current_state` (or `*current_state` if it's a list) and returns
            its (possibly unnormalized) log-density under the target
            distribution.
          step_size: `Tensor` or Python `list` of `Tensor`s representing the
            step size for the leapfrog integrator. Must broadcast with the
            shape of `current_state`. Larger step sizes lead to faster
            progress, but too-large step sizes make rejection exponentially
            more likely. When possible, it's often helpful to match
            per-variable step sizes to the standard deviations of the target
            distribution in each variable.
          volatility_fn: Python callable which takes an argument like
            `current_state` (or `*current_state` if it's a list) and returns
            volatility value at `current_state`. Should return a `Tensor` or
            Python `list` of `Tensor`s that must broadcast with the shape of
            `current_state`. Defaults to the identity function.
          seed: Python integer to seed the random number generator.
            Default value: `None` (i.e., no seed).
          name: Python `str` name prefixed to Ops created by this function.
            Default value: `None` (i.e., 'mala_kernel').

        Raises:
          NOTE(review): no validation happens in this constructor. Errors for
          a mismatched `step_size` or a non-callable `volatility_fn` would
          presumably come from `UncalibratedLangevin`; confirm.
        """
        langevin_proposal = UncalibratedLangevin(
            target_log_prob_fn=target_log_prob_fn,
            step_size=step_size,
            volatility_fn=volatility_fn,
            seed=seed,
            name=name)
        self._impl = metropolis_hastings.MetropolisHastings(
            inner_kernel=langevin_proposal, seed=seed)
예제 #7
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Runs one adaptive random-walk Metropolis-Hastings step.

        The step proceeds in four stages:

        1. Read the adaptation state carried in `previous_kernel_results`
           through `self.extra_getter_fn`.
        2. Compute this step's covariance scaling and covariance.
        3. Rebuild `self._impl` as a `MetropolisHastings` kernel around an
           `UncalibratedRandomWalk` whose proposal comes from
           `random_walk_mvnorm_fn`, then advance the chain with it.
        4. Write the updated adaptation state back through
           `self.extra_setter_fn`.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s holding the
            current state of the chain(s).
          previous_kernel_results: kernel results from the previous step (or
            from bootstrapping). They must expose, via `self.extra_getter_fn`,
            the fields `num_steps`, `is_adaptive`, `covariance_scaling`,
            `running_covariance` and `covariance`.
          seed: accepted by the signature. NOTE(review): this argument is not
            used in the body. The adaptive/fixed indicator is drawn with
            `self.seed`, and `self._impl.one_step` is called without a seed.
            Confirm this is intended.

        Returns:
          A two-element Python `list` `[new_state, new_kernel_results]`.
          `new_kernel_results` are the results of `self._impl.one_step` with
          the extra adaptation fields set by `self.extra_setter_fn`.

        Side effects:
          Replaces `self._impl` with a freshly constructed kernel on every call.
        """
        with tf.name_scope(
                mcmc_util.make_name(self.name,
                                    "AdaptiveRandomWalkMetropolisHastings",
                                    "one_step")):
            with tf.name_scope("initialize"):
                # Normalize the state into a list of tensors. These parts are
                # only used to update the running covariance below; the
                # original `current_state` is what `self._impl.one_step` gets.
                if mcmc_util.is_list_like(current_state):
                    current_state_parts = list(current_state)
                else:
                    current_state_parts = [current_state]
                current_state_parts = [
                    tf.convert_to_tensor(s, name="current_state")
                    for s in current_state_parts
                ]

            # Note 'covariance_scaling' and 'accum_covar' are updated every step but
            # 'covariance' is not updated until 'num_steps' >= 'covariance_burnin'.
            # (The updated scaling is computed every step, but a chain only
            # takes it when its previous `is_adaptive` index is 1; see below.)
            num_steps = self.extra_getter_fn(previous_kernel_results).num_steps
            # for parallel processing efficiency use gather() rather than cond()?
            previous_is_adaptive = self.extra_getter_fn(
                previous_kernel_results).is_adaptive
            # Stack [previous scaling, updated scaling] on the last axis. Then,
            # per batch entry, pick index 0 (keep) or 1 (update) using
            # `previous_is_adaptive`. This assumes `is_adaptive` holds integer
            # 0/1 indices with a leading batch dimension; TODO confirm.
            current_covariance_scaling = tf.gather(
                tf.stack(
                    [
                        self.extra_getter_fn(
                            previous_kernel_results).covariance_scaling,
                        self.update_covariance_scaling(previous_kernel_results,
                                                       num_steps),
                    ],
                    axis=-1,
                ),
                previous_is_adaptive,
                batch_dims=1,
                axis=1,
            )
            # Fold the current state into the running covariance accumulator.
            previous_accum_covar = self.extra_getter_fn(
                previous_kernel_results).running_covariance
            current_accum_covar = self.running_covar.update(
                state=previous_accum_covar, new_sample=current_state_parts)

            # Before burn-in (num_steps < covariance_burnin) keep the previous
            # covariance. Afterwards, use the finalized running covariance
            # (ddof=1). The boolean is cast to an int32 index into the
            # two-element list.
            previous_covariance = self.extra_getter_fn(
                previous_kernel_results).covariance
            current_covariance = tf.gather(
                [
                    previous_covariance,
                    self.running_covar.finalize(current_accum_covar, ddof=1),
                ],
                tf.cast(
                    num_steps >= self.covariance_burnin,
                    dtype=tf.dtypes.int32,
                ),
            )

            # Scale the covariance by the per-batch scaling factor. Squeeze
            # away the leading axis added by `tf.stack([...])`, then split
            # along axis 0 into a Python list of tensors.
            current_scaled_covariance = tf.squeeze(
                tf.expand_dims(current_covariance_scaling, axis=1) *
                tf.stack([current_covariance]),
                axis=0,
            )

            current_scaled_covariance = tf.unstack(current_scaled_covariance)

            # NOTE(review): `tf.unstack` returns a Python list, so the `else`
            # branch below appears unreachable.
            if mcmc_util.is_list_like(current_scaled_covariance):
                current_scaled_covariance_parts = list(
                    current_scaled_covariance)
            else:
                current_scaled_covariance_parts = [current_scaled_covariance]
            current_scaled_covariance_parts = [
                tf.convert_to_tensor(s, name="current_scaled_covariance")
                for s in current_scaled_covariance_parts
            ]

            # Draw this step's adaptive/fixed indicator from `self.u`. The
            # seed used is `self.seed`, not the `seed` argument.
            current_is_adaptive = self.u.sample(seed=self.seed)
            # Rebuild the inner kernel so that its proposal uses this step's
            # scaled covariance and indicator. This overwrites `self._impl`.
            self._impl = metropolis_hastings.MetropolisHastings(
                inner_kernel=random_walk_metropolis.UncalibratedRandomWalk(
                    target_log_prob_fn=self.target_log_prob_fn,
                    new_state_fn=random_walk_mvnorm_fn(
                        covariance=current_scaled_covariance_parts,
                        pu=self.pu,
                        fixed_variance=self.fixed_variance,
                        is_adaptive=current_is_adaptive,
                        name=self.name,
                    ),
                    name=self.name,
                ),
                name=self.name,
            )
            new_state, new_inner_results = self._impl.one_step(
                current_state, previous_kernel_results)
            # Record the adaptation state for the next step: incremented step
            # count, scaling (axis 1 squeezed), covariance, running covariance
            # accumulator, and the newly drawn adaptive indicator.
            new_inner_results = self.extra_setter_fn(
                new_inner_results,
                num_steps + 1,
                tf.squeeze(current_covariance_scaling, axis=1),
                current_covariance,
                current_accum_covar,
                current_is_adaptive,
            )
            return [new_state, new_inner_results]
예제 #8
0
    def __init__(
        self,
        target_log_prob_fn,
        initial_state,
        initial_covariance=None,
        initial_covariance_scaling=2.38**2,
        covariance_scaling_reducer=0.7,
        covariance_scaling_limiter=0.01,
        covariance_burnin=100,
        target_accept_ratio=0.234,
        pu=0.95,
        fixed_variance=0.01,
        extra_getter_fn=rwm_extra_getter_fn,
        extra_setter_fn=rwm_extra_setter_fn,
        log_accept_prob_getter_fn=rwm_log_accept_prob_getter_fn,
        seed=None,
        name=None,
    ):
        """Initializes this transition kernel.

        Args:
          target_log_prob_fn: Python callable which takes an argument like
            `current_state` and returns its (possibly unnormalized) log-density
            under the target distribution.
          initial_state: Python `list` of `Tensor`s representing the initial
            state of each parameter.
          initial_covariance: Python `list` of `Tensor`s representing the
            initial covariance of the proposal. The `initial_covariance` and
            `initial_state` should have identical `dtype`s and batch
            dimensions. If `initial_covariance` is `None` then it is
            initialized to the identity matrix multiplied by 0.001, with one
            matrix per leading entry of the stacked `initial_state`. The
            covariance matrix is tuned during the evolution of the MCMC chain.
            Default value: `None`.
          initial_covariance_scaling: Python floating point number representing
            the initial value of the `covariance_scaling`. The value of
            `covariance_scaling` is tuned during the evolution of the MCMC
            chain. Let d represent the number of parameters, e.g. as given by
            the `initial_state`. The ratio given by the `covariance_scaling`
            divided by d is used to multiply the running covariance. The
            covariance scaling factor multiplied by the covariance matrix is
            used in the proposal at each step.
            Default value: 2.38**2.
          covariance_scaling_reducer: Python floating point number, bounded over
            the range (0.5, 1.0], representing the constant factor used during
            the adaptation of the `covariance_scaling`.
            Default value: 0.7.
          covariance_scaling_limiter: Python floating point number, bounded
            between 0.0 and 1.0, which places a limit on the maximum amount the
            `covariance_scaling` value can be perturbed at each iteration of
            the MCMC chain.
            Default value: 0.01.
          covariance_burnin: Python integer number of steps to take before
            starting to compute the running covariance.
            Default value: 100.
          target_accept_ratio: Python floating point number, bounded over the
            range (0.0, 1.0], representing the target acceptance probability
            of the Metropolis-Hastings algorithm.
            Default value: 0.234.
          pu: Python floating point number, bounded between 0.0 and 1.0,
            representing the bounded convergence parameter. See
            `random_walk_mvnorm_fn()` for further details.
            Default value: 0.95.
          fixed_variance: Python floating point number representing the
            variance of the fixed proposal distribution. See
            `random_walk_mvnorm_fn` for further details.
            Default value: 0.01.
          extra_getter_fn: A callable with the signature
            `(kernel_results) -> extra` where `kernel_results` are the results
            of the `inner_kernel`, and `extra` is a nested collection of
            `Tensor`s.
          extra_setter_fn: A callable with the signature
            `(kernel_results, args) -> new_kernel_results` where
            `kernel_results` are the results of the `inner_kernel`, `args`
            are a nested collection of `Tensor`s with the same structure as
            returned by the `extra_getter_fn`, and `new_kernel_results` are a
            copy of `kernel_results` with `args` in the `extra` field set.
          log_accept_prob_getter_fn: A callable with the signature
            `(kernel_results) -> log_accept_prob` where `kernel_results` are
            the results of the `inner_kernel`, and `log_accept_prob` is either
            a scalar, or has shape [num_chains].
          seed: Python integer to seed the random number generator.
            Default value: `None`.
          name: Python `str` name prefixed to Ops created by this function.
            The kernel's stored name is
            `mcmc_util.make_name(name, "AdaptiveRandomWalkMetropolisHastings", "")`.
            Default value: `None`.

        Raises:
          ValueError: if `initial_covariance_scaling` is less than or equal
            to 0.0.
          ValueError: if `covariance_scaling_reducer` is less than or equal
            to 0.5 or greater than 1.0.
          ValueError: if `covariance_scaling_limiter` is less than 0.0 or
            greater than 1.0.
          ValueError: if `covariance_burnin` is less than 0.
          ValueError: if `target_accept_ratio` is less than or equal to 0.0 or
            greater than 1.0.
          ValueError: if `pu` is less than 0.0 or greater than 1.0.
          ValueError: if `fixed_variance` is less than 0.0.
        """
        # Do not bind the scope with `as name`. `tf.name_scope(...) as x`
        # yields the scope string, and rebinding `name` here would leak the
        # "__init__" scope into the kernel's public name computed below.
        with tf.name_scope(
                mcmc_util.make_name(name,
                                    "AdaptiveRandomWalkMetropolisHastings",
                                    "__init__")):
            if initial_covariance_scaling <= 0.0:
                raise ValueError(
                    "`{}` must be a `float` greater than 0.0".format(
                        "initial_covariance_scaling"))
            if covariance_scaling_reducer <= 0.5 or covariance_scaling_reducer > 1.0:
                raise ValueError(
                    "`{}` must be a `float` greater than 0.5 and less than or equal to 1.0."
                    .format("covariance_scaling_reducer"))
            if covariance_scaling_limiter < 0.0 or covariance_scaling_limiter > 1.0:
                raise ValueError(
                    "`{}` must be a `float` between 0.0 and 1.0.".format(
                        "covariance_scaling_limiter"))
            if covariance_burnin < 0:
                raise ValueError(
                    "`{}` must be a `integer` greater or equal to 0.".format(
                        "covariance_burnin"))
            if target_accept_ratio <= 0.0 or target_accept_ratio > 1.0:
                raise ValueError(
                    "`{}` must be a `float` between 0.0 and 1.0.".format(
                        "target_accept_ratio"))
            if pu < 0.0 or pu > 1.0:
                raise ValueError(
                    "`{}` must be a `float` between 0.0 and 1.0.".format("pu"))
            if fixed_variance < 0.0:
                raise ValueError(
                    "`{}` must be a `float` greater than 0.0.".format(
                        "fixed_variance"))

        if mcmc_util.is_list_like(initial_state):
            initial_state_parts = list(initial_state)
        else:
            initial_state_parts = [initial_state]
        initial_state_parts = [
            tf.convert_to_tensor(s, name="initial_state")
            for s in initial_state_parts
        ]

        # Stack once and reuse the result for both the static shape and the dtype.
        stacked_initial_state = tf.stack(initial_state_parts)
        shape = stacked_initial_state.shape
        dtype = dtype_util.base_dtype(stacked_initial_state.dtype)

        if initial_covariance is None:
            initial_covariance = 0.001 * tf.eye(
                num_rows=shape[-1], dtype=dtype, batch_shape=[shape[0]])
        else:
            initial_covariance = tf.stack(initial_covariance)

        # NOTE(review): both branches above leave `initial_covariance` as a
        # single stacked `Tensor`, so the list-like branch below looks
        # unreachable; kept to preserve behavior.
        if mcmc_util.is_list_like(initial_covariance):
            initial_covariance_parts = list(initial_covariance)
        else:
            initial_covariance_parts = [initial_covariance]
        initial_covariance_parts = [
            tf.convert_to_tensor(s, name="initial_covariance")
            for s in initial_covariance_parts
        ]

        self._running_covar = stats.RunningCovariance(shape=(1, shape[-1]),
                                                      dtype=dtype,
                                                      event_ndims=1)
        self._accum_covar = self._running_covar.initialize()

        # Bernoulli indicator with success probability `pu` for each leading
        # entry of the stacked state. The initial indicator is all zeros.
        probs = tf.expand_dims(tf.ones([shape[0]], dtype=dtype) * pu, axis=1)
        self._u = Bernoulli(probs=probs, dtype=tf.dtypes.int32)
        self._initial_u = tf.zeros_like(self._u.sample(seed=seed),
                                        dtype=tf.dtypes.int32)

        # `name` is still the caller's argument here (not a scope string).
        name = mcmc_util.make_name(name,
                                   "AdaptiveRandomWalkMetropolisHastings", "")

        self._parameters = dict(
            target_log_prob_fn=target_log_prob_fn,
            initial_state=initial_state,
            initial_covariance=initial_covariance,
            initial_covariance_scaling=initial_covariance_scaling,
            covariance_scaling_reducer=covariance_scaling_reducer,
            covariance_scaling_limiter=covariance_scaling_limiter,
            covariance_burnin=covariance_burnin,
            target_accept_ratio=target_accept_ratio,
            pu=pu,
            fixed_variance=fixed_variance,
            extra_getter_fn=extra_getter_fn,
            extra_setter_fn=extra_setter_fn,
            log_accept_prob_getter_fn=log_accept_prob_getter_fn,
            seed=seed,
            name=name,
        )
        self._impl = metropolis_hastings.MetropolisHastings(
            inner_kernel=random_walk_metropolis.UncalibratedRandomWalk(
                target_log_prob_fn=target_log_prob_fn,
                new_state_fn=random_walk_mvnorm_fn(
                    covariance=initial_covariance_parts,
                    pu=pu,
                    fixed_variance=fixed_variance,
                    is_adaptive=self._initial_u,
                    name=name,
                ),
                name=name,
            ),
            name=name,
        )