コード例 #1
0
def _prepare_args(target_log_prob_fn, state, step_size,
                  target_log_prob=None, maybe_expand=False,
                  description='target_log_prob'):
  """Processes input args to meet list-like assumptions."""
  # Normalize `state` to a list of parts; `list(...)` copies so the caller's
  # list is not mutated below.
  state_parts = list(state) if mcmc_util.is_list_like(state) else [state]
  state_parts = [
      tf.convert_to_tensor(value=s, name='current_state') for s in state_parts
  ]

  # Compute the target log-prob at `state`, or reuse it if already supplied.
  target_log_prob = _maybe_call_fn(
      target_log_prob_fn,
      state_parts,
      target_log_prob,
      description)
  step_sizes = (list(step_size) if mcmc_util.is_list_like(step_size)
                else [step_size])
  step_sizes = [
      tf.convert_to_tensor(
          value=s, name='step_size', dtype=target_log_prob.dtype)
      for s in step_sizes
  ]
  # Broadcast a single step size across all state parts.
  if len(step_sizes) == 1:
    step_sizes *= len(state_parts)
  if len(state_parts) != len(step_sizes):
    raise ValueError('There should be exactly one `step_size` or it should '
                     'have same length as `current_state`.')
  def maybe_flatten(x):
    # Unwrap the single-element list unless the caller passed a list or
    # explicitly asked for list outputs.
    return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]
  return [
      maybe_flatten(state_parts),
      maybe_flatten(step_sizes),
      target_log_prob
  ]
コード例 #2
0
        def _loop_body(iter_, ais_weights_lower, ais_weights_upper,
                       current_state):
            """Closure which implements `tf.while_loop` body."""
            # Annealed densities bracketing this step of the schedule.
            fcurrent = _make_convex_combined_log_prob_fn(iter_ - 1)
            fnext = _make_convex_combined_log_prob_fn(iter_)

            x = (current_state
                 if mcmc_util.is_list_like(current_state) else [current_state])
            # Lower-bound increment: evaluated at the *current* state.
            fcurrent_log_prob = fcurrent(*x)
            fnext_log_prob = fnext(*x)
            ais_weights_lower += fnext_log_prob - fcurrent_log_prob

            # Draw the next state directly from the exact intermediate
            # density for this step; sample size is the leading (chain) dim.
            q_inter = _find_exact_intermediate_density(iter_)
            next_state = q_inter.sample(x[0].shape[0])

            x = (next_state
                 if mcmc_util.is_list_like(next_state) else [next_state])
            # Upper-bound increment: evaluated at the *next* state.
            fcurrent_log_prob = fcurrent(*x)
            fnext_log_prob = fnext(*x)
            ais_weights_upper += fnext_log_prob - fcurrent_log_prob

            return [
                iter_ + 1,
                ais_weights_lower,
                ais_weights_upper,
                next_state,
            ]
コード例 #3
0
ファイル: gibbs_kernel.py プロジェクト: ThomFNC/covid19uk
    def one_step(self, current_state, previous_results, seed=None):
        """We iterate over the state elements, calling each kernel in turn.

        The `target_log_prob` is forwarded to the next `previous_results`
        such that each kernel has a current `target_log_prob` value.
        Transformations are automatically performed if the kernel is of
        type tfp.mcmc.TransformedTransitionKernel.

        In graph and XLA modes, the for loop should be unrolled.
        """
        # Normalize the chain state to a list of tensor parts.
        if mcmc_util.is_list_like(current_state):
            state_parts = list(current_state)
        else:
            state_parts = [current_state]

        state_parts = [
            tf.convert_to_tensor(s, name="current_state") for s in state_parts
        ]

        next_results = []
        untransformed_target_log_prob = previous_results.target_log_prob
        for i, (state_part_idx, kernel_fn) in enumerate(self.kernel_list):

            # The closure over the loop state is deliberate: writing the
            # candidate part into `state_parts` evaluates the joint
            # `target_log_prob_fn` conditioned on all other (current)
            # parts, Gibbs-style.
            def target_log_prob_fn(state_part):
                state_parts[state_part_idx  # pylint: disable=cell-var-from-loop
                            ] = state_part
                return self.target_log_prob_fn(*state_parts)

            kernel = kernel_fn(target_log_prob_fn, state_parts)

            # Forward the most recent target_log_prob into this kernel's
            # previous results, mapped into the kernel's (possibly
            # transformed) space.
            previous_kernel_results = update_target_log_prob(
                previous_results.inner_results[i],
                maybe_transform_value(
                    tlp=untransformed_target_log_prob,
                    state=state_parts[state_part_idx],
                    kernel=kernel,
                    direction="inverse",
                ),
            )

            state_parts[state_part_idx], next_kernel_results = kernel.one_step(
                state_parts[state_part_idx], previous_kernel_results, seed)

            next_results.append(next_kernel_results)
            # Map the kernel's reported log-prob back to untransformed
            # space for the next kernel in the sweep.
            untransformed_target_log_prob = maybe_transform_value(
                tlp=get_target_log_prob(next_kernel_results),
                state=state_parts[state_part_idx],
                kernel=kernel,
                direction="forward",
            )

        # Mirror the input structure: bare tensor in, bare tensor out.
        return (
            state_parts
            if mcmc_util.is_list_like(current_state) else state_parts[0],
            GibbsKernelResults(
                target_log_prob=untransformed_target_log_prob,
                inner_results=next_results,
            ),
        )
コード例 #4
0
    def bootstrap_results(self, init_state):
        """Returns an object with the same type as returned by `one_step`.

    Args:
      init_state: `Tensor` or Python `list` of `Tensor`s representing the
        initial state(s) of the Markov chain(s).

    Returns:
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """
        with tf1.name_scope(name=mcmc_util.make_name(self.name, 'remc',
                                                     'bootstrap_results'),
                            values=[init_state]):
            # Bootstrap every replica kernel from the same initial state.
            replica_results = [
                self.replica_kernels[i].bootstrap_results(init_state)
                for i in range(self.num_replica)
            ]

            init_state_parts = (list(init_state)
                                if mcmc_util.is_list_like(init_state) else
                                [init_state])

            # Convert all states parts to tensor...
            # (one identical copy of the state per replica).
            replica_states = [[
                tf.convert_to_tensor(value=s) for s in init_state_parts
            ] for i in range(self.num_replica)]

            # Mirror the caller's structure: unwrap the single part when the
            # input was not list-like.
            if not mcmc_util.is_list_like(init_state):
                replica_states = [s[0] for s in replica_states]

            # Batch shape of the target log-prob, extended by the
            # (num_replica - 1) adjacent replica pairs.
            batch_plus_replica_shape = tf.concat([
                tf.shape(_get_field(replica_results[0], 'target_log_prob')),
                [self.num_replica - 1],
            ],
                                                 axis=0)

            if self._exchange_between_adjacent_only:
                is_exchange_proposed = tf.cast(
                    tf.zeros([self.num_replica - 1]), tf.bool)
                is_exchange_accepted = tf.cast(
                    tf.zeros(batch_plus_replica_shape), tf.bool)
            else:
                # NaN placeholders keep the results structure fixed when
                # adjacent-only exchange diagnostics are not tracked.
                is_exchange_proposed = tf.convert_to_tensor(np.nan)
                is_exchange_accepted = tf.convert_to_tensor(np.nan)

            return ReplicaExchangeMCKernelResults(
                replica_states=replica_states,
                replica_results=replica_results,
                sampled_replica_states=replica_states,
                sampled_replica_results=replica_results,
                is_exchange_proposed=is_exchange_proposed,
                is_exchange_accepted=is_exchange_accepted,
            )
コード例 #5
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Runs one iteration of the Transformed Kernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s
        representing the current state(s) of the Markov chain(s),
        _after_ application of `bijector.forward`. The first `r`
        dimensions index independent chains,
        `r = tf.rank(target_log_prob_fn(*current_state))`. The
        `inner_kernel.one_step` does not actually use `current_state`,
        rather it takes as input
        `previous_kernel_results.transformed_state` (because
        `TransformedTransitionKernel` creates a copy of the input
        inner_kernel with a modified `target_log_prob_fn` which
        internally applies the `bijector.forward`).
      previous_kernel_results: `collections.namedtuple` containing `Tensor`s
        representing values from previous calls to this function (or from the
        `bootstrap_results` function.)
      seed: Optional, a seed for reproducible sampling.

    Returns:
      next_state: Tensor or Python list of `Tensor`s representing the state(s)
        of the Markov chain(s) after taking exactly one step. Has same type and
        shape as `current_state`.
      kernel_results: `collections.namedtuple` of internal calculations used to
        advance the chain.
    """
        with tf.name_scope(
                mcmc_util.make_name(self.name, 'transformed_kernel',
                                    'one_step')):
            inner_kwargs = {} if seed is None else dict(seed=seed)
            # Step the inner kernel in the unconstrained (transformed) space.
            transformed_prev_state = previous_kernel_results.transformed_state
            transformed_next_state, kernel_results = self._inner_kernel.one_step(
                transformed_prev_state, previous_kernel_results.inner_results,
                **inner_kwargs)
            transformed_next_state_parts = (
                transformed_next_state
                if mcmc_util.is_list_like(transformed_next_state) else
                [transformed_next_state])
            # Map the proposal back to the target's support.
            next_state_parts = self._transform_unconstrained_to_target_support(
                transformed_next_state_parts)
            next_state = (next_state_parts
                          if mcmc_util.is_list_like(transformed_next_state)
                          else next_state_parts[0])
            # Preserve the (possibly nested) structure of the previous
            # transformed state.
            if mcmc_util.is_list_like(transformed_prev_state):
                transformed_next_state = tf.nest.pack_sequence_as(
                    transformed_prev_state, transformed_next_state)
            kernel_results = TransformedTransitionKernelResults(
                transformed_state=transformed_next_state,
                inner_results=kernel_results)
            return next_state, kernel_results
コード例 #6
0
def _maybe_call_volatility_fn_and_grads(volatility_fn,
                                        state,
                                        volatility_fn_results=None,
                                        grads_volatility_fn=None,
                                        sample_shape=None,
                                        parallel_iterations=10):
  """Helper which computes `volatility_fn` results and grads, if needed."""
  state_parts = list(state) if mcmc_util.is_list_like(state) else [state]
  # Remember whether gradients were supplied: if not, `diag_jacobian` fills
  # them in below and they are then rescaled to grads of volatility**2.
  needs_volatility_fn_gradients = grads_volatility_fn is None

  # Convert `volatility_fn_results` to a list
  if volatility_fn_results is None:
    volatility_fn_results = volatility_fn(*state_parts)

  volatility_fn_results = (list(volatility_fn_results)
                           if mcmc_util.is_list_like(volatility_fn_results)
                           else [volatility_fn_results])
  # A single volatility result broadcasts across all state parts.
  if len(volatility_fn_results) == 1:
    volatility_fn_results *= len(state_parts)
  if len(state_parts) != len(volatility_fn_results):
    raise ValueError('`volatility_fn` should return a tensor or a list '
                     'of the same length as `current_state`.')

  # The shape of 'volatility_parts' needs to have the number of chains as a
  # leading dimension. For determinism we broadcast 'volatility_parts' to the
  # shape of `state_parts` since each dimension of `state_parts` could have a
  # different volatility value.

  volatility_fn_results = _maybe_broadcast_volatility(volatility_fn_results,
                                                      state_parts)
  if grads_volatility_fn is None:
    # Diagonal Jacobian of the volatility w.r.t. each state part.
    [
        _,
        grads_volatility_fn,
    ] = diag_jacobian(
        xs=state_parts,
        ys=volatility_fn_results,
        sample_shape=sample_shape,
        parallel_iterations=parallel_iterations,
        fn=volatility_fn)

  # Compute gradient of `volatility_parts**2`
  if needs_volatility_fn_gradients:
    # d(v**2)/dx = 2 * v * dv/dx; `None` gradients become zeros.
    grads_volatility_fn = [
        2. * g * volatility if g is not None else tf.zeros_like(
            fn_arg, dtype=fn_arg.dtype.base_dtype)
        for g, volatility, fn_arg in zip(
            grads_volatility_fn, volatility_fn_results, state_parts)
    ]

  return volatility_fn_results, grads_volatility_fn
コード例 #7
0
    def _fn(state_parts, seed):
        """Adds a uniform perturbation to the input state.

    Args:
      state_parts: A list of `Tensor`s of any shape and real dtype representing
        the state parts of the `current_state` of the Markov chain.
      seed: `int` or None. The random seed for this `Op`. If `None`, no seed is
        applied.
        Default value: `None`.

    Returns:
      perturbed_state_parts: A Python `list` of the `Tensor`s. Has the same
        shape and type as the `state_parts`.

    Raises:
      ValueError: if `scale` does not broadcast with `state_parts`.
    """
        with tf.compat.v1.name_scope(name,
                                     'random_walk_uniform_fn',
                                     values=[state_parts, scale, seed]):
            # Copy a list-like `scale`: `scales *= n` below operates in
            # place, and without the copy it would grow the caller's
            # (closed-over) list on every invocation.
            scales = list(scale) if mcmc_util.is_list_like(scale) else [scale]
            # Broadcast a single scale across all state parts.
            if len(scales) == 1:
                scales *= len(state_parts)
            if len(state_parts) != len(scales):
                raise ValueError('`scale` must broadcast with `state_parts`.')
            seed_stream = SeedStream(seed, salt='RandomWalkUniformFn')
            # Draw each part uniformly from [state - scale, state + scale).
            next_state_parts = [
                tf.random.uniform(minval=state_part - scale_part,
                                  maxval=state_part + scale_part,
                                  shape=tf.shape(input=state_part),
                                  dtype=state_part.dtype.base_dtype,
                                  seed=seed_stream())
                for scale_part, state_part in zip(scales, state_parts)
            ]
            return next_state_parts
コード例 #8
0
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  momentum_distribution,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False):
    """Helper which processes input args to meet list-like assumptions."""
    state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
    if state_gradients_are_stopped:
        state_parts = [tf.stop_gradient(x) for x in state_parts]
    # Compute (or reuse, if supplied) the target log-prob and its gradients.
    target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
        target_log_prob_fn, state_parts, target_log_prob,
        grads_target_log_prob)
    step_sizes, _ = mcmc_util.prepare_state_parts(step_size,
                                                  dtype=target_log_prob.dtype,
                                                  name='step_size')

    # Default momentum distribution is None, but if `store_parameters_in_results`
    # is true, then `momentum_distribution` defaults to an empty list
    if momentum_distribution is None or isinstance(momentum_distribution,
                                                   list):
        batch_rank = ps.rank(target_log_prob)

        def _batched_isotropic_normal_like(state_part):
            # Standard normal over the event dims of this state part.
            event_ndims = ps.rank(state_part) - batch_rank
            return independent.Independent(
                normal.Normal(ps.zeros_like(state_part, tf.float32), 1.),
                reinterpreted_batch_ndims=event_ndims)

        momentum_distribution = jds.JointDistributionSequential([
            _batched_isotropic_normal_like(state_part)
            for state_part in state_parts
        ])

    # The momentum will get "maybe listified" to zip with the state parts,
    # and this step makes sure that the momentum distribution will have the
    # same "maybe listified" underlying shape.
    if not mcmc_util.is_list_like(momentum_distribution.dtype):
        momentum_distribution = jds.JointDistributionSequential(
            [momentum_distribution])

    # Broadcast a single step size across all state parts.
    if len(step_sizes) == 1:
        step_sizes *= len(state_parts)
    if len(state_parts) != len(step_sizes):
        raise ValueError(
            'There should be exactly one `step_size` or it should '
            'have same length as `current_state`.')

    def maybe_flatten(x):
        # Unwrap single-element lists unless the caller passed a list or
        # explicitly asked for list outputs.
        return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]

    return [
        maybe_flatten(state_parts),
        maybe_flatten(step_sizes),
        momentum_distribution,
        target_log_prob,
        grads_target_log_prob,
    ]
コード例 #9
0
ファイル: tf_support.py プロジェクト: yongheshinian/pymc4
    def bootstrap_results(self, init_state):
        """
        Returns an object with the same type as returned by `one_step(...)[1]`
        Compound bootstrap step
        """
        with tf.name_scope(mcmc_util.make_name(self.name, "compound", "bootstrap_results")):
            if not mcmc_util.is_list_like(init_state):
                init_state = [init_state]
            init_state = [tf.convert_to_tensor(x) for x in init_state]

            init_results = []
            # Each sampler owns a contiguous slice of the state parts:
            # `curri` is the cumulative offset, `setli` the slice length.
            for sampleri, setli, curri in zip(
                self._compound_samplers,
                self._compound_set_lengths,
                self._cumulative_lengths,
            ):
                kernel = self.kernel_create_object(
                    sampleri, curri, setli, init_state, self._target_log_prob_fn
                )
                # bootstrap results in list
                init_results.append(
                    kernel.bootstrap_results(init_state[slice(curri, curri + setli)])
                )

        return CompoundGibbsStepResults(compound_results=init_results)
コード例 #10
0
def make_transform_fn(bijector, direction):
  """Builds a callable applying each bijector's `direction` method
  (`forward` or `inverse`) to the matching state part."""
  bijectors = bijector if mcmc_util.is_list_like(bijector) else [bijector]

  def fn(state_parts):
    transformed = []
    for b, part in zip(bijectors, state_parts):
      transformed.append(getattr(b, direction)(part))
    return transformed

  return fn
コード例 #11
0
ファイル: gibbs_kernel.py プロジェクト: ThomFNC/covid19uk
    def bootstrap_results(self, current_state):
        """Builds initial `GibbsKernelResults` for `current_state`.

        Bootstraps each kernel in `self.kernel_list` on its own state part
        and threads the (possibly transformed) target log-prob through,
        mirroring `one_step`.
        """
        # Normalize to a list of parts; the comprehension below performs the
        # (single, named) tensor conversion, so the original extra
        # `tf.convert_to_tensor` in the scalar branch was redundant.
        if mcmc_util.is_list_like(current_state):
            current_state = list(current_state)
        else:
            current_state = [current_state]
        current_state = [
            tf.convert_to_tensor(s, name="current_state")
            for s in current_state
        ]

        inner_results = []
        untransformed_target_log_prob = 0.0
        for state_part_idx, kernel_fn in self.kernel_list:

            # The argument is deliberately ignored: bootstrapping only needs
            # the joint log-prob at the current (full) state.
            def target_log_prob_fn(_):
                return self.target_log_prob_fn(*current_state)

            kernel = kernel_fn(target_log_prob_fn, current_state)
            kernel_results = kernel.bootstrap_results(
                current_state[state_part_idx])
            inner_results.append(kernel_results)
            untransformed_target_log_prob = maybe_transform_value(
                tlp=get_target_log_prob(kernel_results),
                state=current_state[state_part_idx],
                kernel=kernel,
                direction="forward",
            )

        return GibbsKernelResults(
            target_log_prob=untransformed_target_log_prob,
            inner_results=inner_results,
        )
コード例 #12
0
    def _fn(state_parts, seed):
        """Adds a uniform perturbation to the input state.

    Args:
      state_parts: A list of `Tensor`s of any shape and real dtype representing
        the state parts of the `current_state` of the Markov chain.
      seed: `int` or None. The random seed for this `Op`. If `None`, no seed is
        applied.
        Default value: `None`.

    Returns:
      perturbed_state_parts: A Python `list` of the `Tensor`s. Has the same
        shape and type as the `state_parts`.

    Raises:
      ValueError: if `scale` does not broadcast with `state_parts`.
    """
        with tf.name_scope(name or 'random_walk_uniform_fn'):
            # Copy a list-like `scale`: `scales *= n` below operates in
            # place, and without the copy it would grow the caller's
            # (closed-over) list on every invocation.
            scales = list(scale) if mcmc_util.is_list_like(scale) else [scale]
            # Broadcast a single scale across all state parts.
            if len(scales) == 1:
                scales *= len(state_parts)
            if len(state_parts) != len(scales):
                raise ValueError('`scale` must broadcast with `state_parts`.')
            # One independent seed per state part.
            part_seeds = samplers.split_seed(seed, n=len(state_parts))
            next_state_parts = [
                samplers.uniform(  # pylint: disable=g-complex-comprehension
                    minval=state_part - scale_part,
                    maxval=state_part + scale_part,
                    shape=tf.shape(state_part),
                    dtype=dtype_util.base_dtype(state_part.dtype),
                    seed=seed_part)
                for scale_part, state_part, seed_part in zip(
                    scales, state_parts, part_seeds)
            ]
            return next_state_parts
コード例 #13
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Wrapper over the normal HMC step.

    Runs the inherited HMC step, then corrects the log-acceptance by the
    integral of the target score along the straight path between the input
    and proposed states.
    """
        next_state_parts, new_kernel_results = super().one_step(
            current_state, previous_kernel_results, seed)
        # We need to integrate the score over a path between input and output points
        # Direction of integration
        if mcmc_util.is_list_like(current_state):
            # NOTE(review): only the first state part defines the path here;
            # confirm a single-part state is assumed.
            v = next_state_parts[0] - current_state[0]
            cs = current_state[0]
        else:
            v = next_state_parts - current_state
            cs = current_state

        @jax.vmap
        def integrand(t):
            # Score at points cs + t*v, projected onto the displacement v.
            return jnp.sum(self._parameters['target_score_fn'](t * v + cs) * v,
                           axis=-1)

        # Quadrature over t in [0, 1] with the configured number of steps.
        delta_logp = simps(integrand, 0., 1.,
                           self._parameters['num_delta_logp_steps'])
        new_kernel_results2 = new_kernel_results._replace(
            log_acceptance_correction=new_kernel_results.
            log_acceptance_correction + delta_logp)
        return next_state_parts, new_kernel_results2
コード例 #14
0
    def one_step(self, current_state, previous_kernel_results):
        """Proposes a new state via `new_state_fn` and evaluates its log-prob."""
        with tf.name_scope(mcmc_util.make_name(self.name, 'rwm', 'one_step')):
            with tf.name_scope('initialize'):
                # Normalize the incoming state to a list of tensors.
                if mcmc_util.is_list_like(current_state):
                    current_state_parts = list(current_state)
                else:
                    current_state_parts = [current_state]
                current_state_parts = [
                    tf.convert_to_tensor(s, name='current_state')
                    for s in current_state_parts
                ]

            # Draw the proposal from the configured proposal function.
            next_state_parts = self.new_state_fn(
                current_state_parts,  # pylint: disable=not-callable
                self._seed_stream())
            # Compute `target_log_prob` so it's available to MetropolisHastings.
            next_target_log_prob = self.target_log_prob_fn(*next_state_parts)  # pylint: disable=not-callable

            def maybe_flatten(x):
                # Mirror the input structure: bare tensor in, bare tensor out.
                return x if mcmc_util.is_list_like(current_state) else x[0]

            return [
                maybe_flatten(next_state_parts),
                UncalibratedRandomWalkResults(
                    log_acceptance_correction=tf.zeros_like(
                        next_target_log_prob),
                    target_log_prob=next_target_log_prob,
                ),
            ]
コード例 #15
0
    def _fn(state_parts, seed):
        """Adds a normal perturbation to the input state.

    Args:
      state_parts: A list of `Tensor`s of any shape and real dtype representing
        the state parts of the `current_state` of the Markov chain.
      seed: `int` or None. The random seed for this `Op`. If `None`, no seed is
        applied.
        Default value: `None`.

    Returns:
      perturbed_state_parts: A Python `list` of the `Tensor`s. Has the same
        shape and type as the `state_parts`.

    Raises:
      ValueError: if `scale` does not broadcast with `state_parts`.
    """
        with tf.name_scope(name or 'random_walk_normal_fn'):
            # Copy a list-like `scale`: `scales *= n` below operates in
            # place, and without the copy it would grow the caller's
            # (closed-over) list on every invocation.
            scales = list(scale) if mcmc_util.is_list_like(scale) else [scale]
            # Broadcast a single scale across all state parts.
            if len(scales) == 1:
                scales *= len(state_parts)
            if len(state_parts) != len(scales):
                raise ValueError('`scale` must broadcast with `state_parts`.')
            seed_stream = SeedStream(seed, salt='RandomWalkNormalFn')
            # Gaussian perturbation centered at the current state.
            next_state_parts = [
                tf.random.normal(  # pylint: disable=g-complex-comprehension
                    mean=state_part,
                    stddev=scale_part,
                    shape=tf.shape(state_part),
                    dtype=dtype_util.base_dtype(state_part.dtype),
                    seed=seed_stream())
                for scale_part, state_part in zip(scales, state_parts)
            ]

            return next_state_parts
コード例 #16
0
    def bootstrap_results(self, init_state):
        """Creates initial `state`."""
        with tf.name_scope(
                mcmc_util.make_name(self.name,
                                    "AdaptiveRandomWalkMetropolisHastings",
                                    "bootstrap_results")):
            # Normalize the initial state to a list of tensors.
            if mcmc_util.is_list_like(init_state):
                initial_state_parts = list(init_state)
            else:
                initial_state_parts = [init_state]
            initial_state_parts = [
                tf.convert_to_tensor(s, name="init_state")
                for s in initial_state_parts
            ]

            # NOTE(review): `tf.stack` requires all state parts to share a
            # shape — confirm that assumption holds for callers.
            shape = tf.stack(initial_state_parts).shape
            dtype = dtype_util.base_dtype(tf.stack(initial_state_parts).dtype)

            # One covariance-scaling entry per state part (leading dim).
            init_covariance_scaling = tf.cast(
                tf.repeat([self.initial_covariance_scaling],
                          repeats=[shape[0]],
                          axis=0),
                dtype=dtype,
            )

            inner_results = self._impl.bootstrap_results(init_state)
            # Seed the adaptive fields: step counter 0, scaling divided by
            # the last state dimension, plus initial covariance/accumulators.
            return self.extra_setter_fn(
                inner_results,
                0,
                init_covariance_scaling / shape[-1],
                self.initial_covariance,
                self._accum_covar,
                self.initial_u,
            )
コード例 #17
0
 def one_step(current_state, previous_kernel_results):
     """Deterministically scales the state and halves kernel-result fields."""
     # Make next_state.
     if is_list_like(current_state):
         next_state = []
         for i, s in enumerate(current_state):
             # Part i is scaled by (i + 2) so each part changes distinctly.
             # `dtype` is closed over; presumably casts the scalar — confirm.
             next_state.append(
                 tf.identity(s * dtype(i + 2), name='next_state'))
     else:
         next_state = tf.identity(2. * current_state, name='next_state')
     # Make kernel_results.
     kernel_results = {}
     for fn in sorted(previous_kernel_results._fields):
         if fn == 'grads_target_log_prob':
             # Gradients are a list; halve each entry.
             kernel_results['grads_target_log_prob'] = [
                 tf.identity(0.5 * g, name='grad_target_log_prob')
                 for g in previous_kernel_results.grads_target_log_prob
             ]
         elif fn == 'extraneous':
             # Passed through untouched.
             kernel_results[fn] = getattr(previous_kernel_results, fn, None)
         else:
             # All other numeric fields are halved.
             kernel_results[fn] = tf.identity(
                 0.5 * getattr(previous_kernel_results, fn, None), name=fn)
     kernel_results = type(previous_kernel_results)(**kernel_results)
     # Done.
     return next_state, kernel_results
コード例 #18
0
        def _loop_body(iter_, seed, ais_weights, current_state,
                       kernel_results):
            """Closure which implements `tf.while_loop` body."""
            # Split the seed only when the kernel accepts one.
            iter_seed, next_seed = samplers.split_seed(
                seed,
                salt='ais_chain.seeded_one_step') if is_seeded else (seed,
                                                                     seed)

            x = (current_state
                 if mcmc_util.is_list_like(current_state) else [current_state])
            proposal_log_prob = proposal_log_prob_fn(*x)
            target_log_prob = target_log_prob_fn(*x)
            # Accumulate this step's AIS increment, averaged over num_steps.
            ais_weights += ((target_log_prob - proposal_log_prob) /
                            tf.cast(num_steps, ais_weights.dtype))
            # Transition under the annealed density for this step.
            kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_))
            # TODO(b/147676843): Should we warn if the kernel is not calibrated?
            one_step_kwargs = dict(seed=iter_seed) if is_seeded else {}
            next_state, inner_results = kernel.one_step(
                current_state, kernel_results.inner_results, **one_step_kwargs)
            kernel_results = AISResults(
                proposal_log_prob=proposal_log_prob,
                target_log_prob=target_log_prob,
                inner_results=inner_results,
            )
            return [
                iter_ + 1, next_seed, ais_weights, next_state, kernel_results
            ]
コード例 #19
0
ファイル: hmc.py プロジェクト: awajinokami/probability
 def build_assign_op():
     """Returns op(s) scaling `step_size_var` by (1 + adjustment)."""
     if mcmc_util.is_list_like(step_size_var):
         # One assign op per step-size variable, each cast to its dtype.
         return [
             ss.assign_add(ss * tf.cast(adjustment, ss.dtype))
             for ss in step_size_var
         ]
     return step_size_var.assign_add(
         step_size_var * tf.cast(adjustment, step_size_var.dtype))
コード例 #20
0
def inverse_transform_fn(bijector):
  """Builds a callable mapping state parts through each bijector's
  `inverse`."""
  bijectors = bijector if mcmc_util.is_list_like(bijector) else [bijector]

  def fn(state_parts):
    inverted = []
    for b, part in zip(bijectors, state_parts):
      inverted.append(b.inverse(part))
    return inverted

  return fn
コード例 #21
0
def make_transform_fn(bijector, direction):
  """Makes a function which applies a list of Bijectors' `forward`s."""
  if not mcmc_util.is_list_like(bijector):
    # A single bijector may still be multi-part: a list-like dtype for this
    # direction means it consumes/produces a list of state parts itself.
    dtype = getattr(bijector, '{}_dtype'.format(direction))()
    if mcmc_util.is_list_like(dtype):
      return getattr(bijector, direction)
    elif tf.nest.is_nested(dtype):
      raise ValueError(
          'Only list-like multi-part bijectors are currently supported, but '
          'got {}.'.format(tf.nest.map_structure(lambda _: '.', dtype)))
    bijector = [bijector]
  def fn(state_parts):
    if len(bijector) != len(state_parts):
      raise ValueError('State has {} parts, but bijector has {}.'.format(
          len(state_parts), len(bijector)))
    # Apply each bijector's `direction` method to its state part.
    return [getattr(b, direction)(sp) for b, sp in zip(bijector, state_parts)]
  return fn
コード例 #22
0
def forward_transform_fn(bijector):
  """Builds a callable mapping transformed state parts through each
  bijector's `forward`."""
  bijectors = bijector if mcmc_util.is_list_like(bijector) else [bijector]

  def fn(transformed_state_parts):
    forwarded = []
    for b, part in zip(bijectors, transformed_state_parts):
      forwarded.append(b.forward(part))
    return forwarded

  return fn
コード例 #23
0
 def bootstrap_results(self, init_state):
     """Creates initial kernel results from `init_state`."""
     with tf1.name_scope(self.name, 'rwm_bootstrap_results', [init_state]):
         # Normalize to a list of tensors.
         if not mcmc_util.is_list_like(init_state):
             init_state = [init_state]
         init_state = [tf.convert_to_tensor(value=x) for x in init_state]
         init_target_log_prob = self.target_log_prob_fn(*init_state)  # pylint:disable=not-callable
         return UncalibratedRandomWalkResults(
             # Starts with a zero acceptance correction.
             log_acceptance_correction=tf.zeros_like(init_target_log_prob),
             target_log_prob=init_target_log_prob)
コード例 #24
0
ファイル: ais_mcmc.py プロジェクト: thangbui/annealed_is
        def _loop_body(
            iter_,
            seed,
            ais_weights_lower,
            ais_weights_upper,
            current_state,
            kernel_results,
        ):
            """Closure which implements `tf.while_loop` body."""
            # Split the seed only when the kernel accepts one.
            iter_seed, next_seed = (samplers.split_seed(
                seed, salt="ais_chain.seeded_one_step") if is_seeded else
                                    (seed, seed))
            # Annealed densities bracketing this step of the schedule.
            fcurrent = _make_convex_combined_log_prob_fn(iter_ - 1)
            fnext = _make_convex_combined_log_prob_fn(iter_)

            x = (current_state
                 if mcmc_util.is_list_like(current_state) else [current_state])
            # Lower-bound increment: evaluated at the *current* state.
            fcurrent_log_prob = fcurrent(*x)
            fnext_log_prob = fnext(*x)
            ais_weights_lower += fnext_log_prob - fcurrent_log_prob

            # Transition under the next annealed density.
            kernel = make_kernel_fn(fnext)
            one_step_kwargs = dict(seed=iter_seed) if is_seeded else {}
            next_state, inner_results = kernel.one_step(
                current_state, kernel_results.inner_results, **one_step_kwargs)
            kernel_results = AISResults(
                proposal_log_prob=fcurrent_log_prob,
                target_log_prob=fnext_log_prob,
                inner_results=inner_results,
            )
            x = (next_state
                 if mcmc_util.is_list_like(next_state) else [next_state])
            # Upper-bound increment: evaluated at the *next* state.
            fcurrent_log_prob = fcurrent(*x)
            fnext_log_prob = fnext(*x)
            ais_weights_upper += fnext_log_prob - fcurrent_log_prob

            return [
                iter_ + 1,
                next_seed,
                ais_weights_lower,
                ais_weights_upper,
                next_state,
                kernel_results,
            ]
コード例 #25
0
def step_size_simple_update(step_size_var,
                            kernel_results,
                            target_rate=0.75,
                            decrement_multiplier=0.01,
                            increment_multiplier=0.01):
  """Adapts (a list of) `step_size_var` toward a target acceptance rate.

  Args:
    step_size_var: (List of) `tf.Variable`(s) holding the step size(s).
    kernel_results: Kernel-results structure whose
      `inner_results.log_accept_ratio` drives the adaptation, or `None`
      during bootstrap.
    target_rate: Desired acceptance rate; the step size shrinks when the
      observed rate falls below this value and grows otherwise.
    decrement_multiplier: Relative decrease applied when under-accepting.
    increment_multiplier: Relative increase applied when over-accepting.

  Returns:
    (List of) assign op(s) producing the updated step size(s), or identity
    op(s) of the unchanged step size(s) when `kernel_results is None`.
  """
  if kernel_results is None:
    # Bootstrap call: nothing has been observed yet, so leave the step
    # size(s) untouched.
    if mcmc_util.is_list_like(step_size_var):
      return [tf.identity(ss) for ss in step_size_var]
    return tf.identity(step_size_var)
  # Mean acceptance probability across the batch, computed in log space:
  # log(mean(min(1, ratio))) = logsumexp(min(log_ratio, 0)) - log(n).
  # Note: `tf.math.log` (not the TF1-only `tf.log`) works in TF 1.x and 2.x.
  log_n = tf.math.log(
      tf.cast(tf.size(kernel_results.inner_results.log_accept_ratio),
              kernel_results.inner_results.log_accept_ratio.dtype))
  log_mean_accept_ratio = tf.reduce_logsumexp(
      tf.minimum(kernel_results.inner_results.log_accept_ratio, 0.)) - log_n
  # The asymmetric decrement keeps a down-step followed by an up-step from
  # drifting: (1 - dec/(1 + dec)) * (1 + inc) == 1 when inc == dec.
  adjustment = tf.where(
      log_mean_accept_ratio < tf.math.log(target_rate),
      -decrement_multiplier / (1. + decrement_multiplier),
      increment_multiplier)
  if not mcmc_util.is_list_like(step_size_var):
    return step_size_var.assign_add(step_size_var * adjustment)
  return [ss.assign_add(ss * adjustment) for ss in step_size_var]
コード例 #26
0
def make_log_det_jacobian_fn(bijector, direction):
  """Makes a function which applies a list of Bijectors' `log_det_jacobian`s."""
  attr = '{}_log_det_jacobian'.format(direction)
  if mcmc_util.is_list_like(bijector):
    bijectors = bijector
  else:
    dtype = getattr(bijector, '{}_dtype'.format(direction))()
    if mcmc_util.is_list_like(dtype):
      # Multipart bijector: delegate directly, handing over all parts at once.
      def multipart_fn(state_parts, event_ndims):
        return getattr(bijector, attr)(state_parts, event_ndims)
      return multipart_fn
    if tf.nest.is_nested(dtype):
      raise ValueError(
          'Only list-like multi-part bijectors are currently supported, but '
          'got {}.'.format(tf.nest.map_structure(lambda _: '.', dtype)))
    bijectors = [bijector]

  def fn(state_parts, event_ndims):
    # Accumulate per-part log-det-Jacobians into a single scalar correction.
    total = 0
    for b, nd, part in zip(bijectors, event_ndims, state_parts):
      total += getattr(b, attr)(part, event_ndims=nd)
    return total
  return fn
コード例 #27
0
def make_log_det_jacobian_fn(bijector, direction):
  """Makes a function which applies a list of Bijectors' `log_det_jacobian`s."""
  bijectors = bijector if mcmc_util.is_list_like(bijector) else [bijector]
  method_name = '{}_log_det_jacobian'.format(direction)

  def fn(state_parts, event_ndims):
    # One log-det-Jacobian per state part, returned in matching order.
    return [
        getattr(b, method_name)(part, event_ndims=nd)
        for b, nd, part in zip(bijectors, event_ndims, state_parts)
    ]
  return fn
コード例 #28
0
 def bootstrap_results(self, init_state):
     """Creates initial elliptical-slice kernel results from `init_state`."""
     with tf.compat.v1.name_scope(
             name=mcmc_util.make_name(
                 self.name, 'elliptical_slice', 'bootstrap_results'),
             values=[init_state]):
         parts = (init_state if mcmc_util.is_list_like(init_state)
                  else [init_state])
         parts = [tf.convert_to_tensor(p) for p in parts]
         init_log_likelihood = self.log_likelihood_fn(*parts)  # pylint:disable=not-callable
         # Angle and normal samples start at zero; they are populated by
         # the first call to `one_step`.
         return EllipticalSliceSamplerKernelResults(
             log_likelihood=init_log_likelihood,
             angle=tf.zeros_like(init_log_likelihood),
             normal_samples=[tf.zeros_like(p) for p in parts])
コード例 #29
0
def forward_log_det_jacobian_fn(bijector):
    """Makes a function which applies a list of Bijectors' `log_det_jacobian`s."""
    bijectors = bijector if mcmc_util.is_list_like(bijector) else [bijector]

    def fn(transformed_state_parts, event_ndims):
        # Accumulate the forward log-det-Jacobian across all state parts.
        total = 0
        for b, nd, part in zip(bijectors, event_ndims,
                               transformed_state_parts):
            total += b.forward_log_det_jacobian(part, event_ndims=nd)
        return total

    return fn
コード例 #30
0
 def bootstrap_results(self, init_state):
     """Builds initial `UncalibratedRandomWalkResults` for `init_state`."""
     with tf.name_scope(
             mcmc_util.make_name(self.name, 'rwm', 'bootstrap_results')):
         parts = (init_state if mcmc_util.is_list_like(init_state)
                  else [init_state])
         parts = [tf.convert_to_tensor(p) for p in parts]
         init_target_log_prob = self.target_log_prob_fn(*parts)  # pylint:disable=not-callable
         return UncalibratedRandomWalkResults(
             log_acceptance_correction=tf.zeros_like(init_target_log_prob),
             target_log_prob=init_target_log_prob,
             # Allow room for one_step's seed.
             seed=samplers.zeros_seed())