Code example #1
    def _cdf(self, k):
        # TODO(b/135263541): Improve numerical precision of categorical.cdf.
        probs = self.probs_parameter()
        num_categories = self._num_categories(probs)

        k, probs = _broadcast_cat_event_and_params(
            k, probs, base_dtype=dtype_util.base_dtype(self.dtype))

        # Since the lowest number in the support is 0, any k < 0 should be zero in
        # the output.
        should_be_zero = k < 0

        # Will use k as an index in the gather below, so clip it to {0,...,K-1}.
        k = tf.clip_by_value(tf.cast(k, tf.int32), 0, num_categories - 1)

        batch_shape = tf.shape(k)

        # tf.gather(..., batch_dims=batch_dims) requires static batch_dims kwarg, so
        # to handle the case where the batch shape is dynamic, flatten the batch
        # dims (so we know batch_dims=1).
        k_flat_batch = tf.reshape(k, [-1])
        probs_flat_batch = tf.reshape(
            probs, tf.concat(([-1], [num_categories]), axis=0))

        cdf_flat = tf.gather(tf.cumsum(probs_flat_batch, axis=-1),
                             k_flat_batch[..., tf.newaxis],
                             batch_dims=1)

        cdf = tf.reshape(cdf_flat, shape=batch_shape)

        zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
        return tf.where(should_be_zero, zero, cdf)
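The flatten-then-gather trick above (reshape the batch to a single dimension so the `batch_dims=1` kwarg is static, gather, then restore the batch shape) also works on its own. A minimal standalone sketch with illustrative tensor names:

import tensorflow as tf

# Illustrative inputs: probs is [batch, K], k holds one category index per batch member.
probs = tf.constant([[0.1, 0.2, 0.7],
                     [0.5, 0.25, 0.25]])
k = tf.constant([1, 2])

num_categories = probs.shape[-1]
k_flat = tf.reshape(k, [-1])                                 # flatten the batch dims
probs_flat = tf.reshape(probs, [-1, num_categories])         # [prod(batch), K]
cdf_flat = tf.gather(tf.cumsum(probs_flat, axis=-1),         # per-row cumulative sums
                     k_flat[..., tf.newaxis], batch_dims=1)  # pick entry k in each row
cdf = tf.reshape(cdf_flat, tf.shape(k))                      # restore the batch shape
# cdf is approximately [0.3, 1.0]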
Code example #2
def clip_by_value_preserve_gradient(t,
                                    clip_value_min,
                                    clip_value_max,
                                    name=None):
    """Clips values to a specified min and max while leaving gradient unaltered.

  Like `tf.clip_by_value`, this function returns a tensor of the same type and
  shape as input `t` but with values clamped to be no smaller than to
  `clip_value_min` and no larger than `clip_value_max`. Unlike
  `tf.clip_by_value`, the gradient is unaffected by this op, i.e.,

  ```python
  tf.gradients(tfp.math.clip_by_value_preserve_gradient(x), x)[0]
  # ==> ones_like(x)
  ```

  Note: `clip_value_min` needs to be smaller or equal to `clip_value_max` for
  correct results.

  Args:
    t: A `Tensor`.
    clip_value_min: A scalar `Tensor`, or a `Tensor` with the same shape
      as `t`. The minimum value to clip by.
    clip_value_max: A scalar `Tensor`, or a `Tensor` with the same shape
      as `t`. The maximum value to clip by.
    name: A name for the operation (optional).
      Default value: `'clip_by_value_preserve_gradient'`.

  Returns:
    clipped_t: A clipped `Tensor`.
  """
    with tf.name_scope(name or 'clip_by_value_preserve_gradient'):
        t = tf.convert_to_tensor(t, name='t')
        clip_t = tf.clip_by_value(t, clip_value_min, clip_value_max)
        return t + tf.stop_gradient(clip_t - t)
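A quick sanity check of the `t + tf.stop_gradient(clip_t - t)` identity above: the forward value equals the clipped tensor, while the gradient flows as if no clipping happened. A minimal sketch:

import tensorflow as tf
import tensorflow_probability as tfp

x = tf.constant([-2., 0.5, 3.])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tfp.math.clip_by_value_preserve_gradient(x, 0., 1.)
print(y)                    # ==> [0., 0.5, 1.]  (same values as tf.clip_by_value)
print(tape.gradient(y, x))  # ==> [1., 1., 1.]   (gradient unaffected by the clip)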
Code example #3
            def grad(dy):
                """Computes a derivative for the min and max parameters.

        This function implements the derivative wrt the truncation bounds, which
        get blocked by the sampler. We use a custom expression for numerical
        stability instead of automatic differentiation on CDF for implicit
        gradients.

        Args:
          dy: output gradients

        Returns:
           The standard normal samples and the gradients wrt the upper
           bound and lower bound.
        """
                # std_samples has an extra dimension (the sample dimension), expand
                # lower and upper so they broadcast along this dimension.
                # See note above regarding parameterized_truncated_normal, the sample
                # dimension is the final dimension.
                lower_broadcast = lower[..., tf.newaxis]
                upper_broadcast = upper[..., tf.newaxis]

                cdf_samples = ((special_math.ndtr(std_samples) -
                                special_math.ndtr(lower_broadcast)) /
                               (special_math.ndtr(upper_broadcast) -
                                special_math.ndtr(lower_broadcast)))

                # tiny and eps are tolerance parameters that keep the arguments of the
                # log and log1p expressions away from exactly 0 and 1.
                tiny = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny
                eps = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).eps
                cdf_samples = tf.clip_by_value(cdf_samples, tiny, 1 - eps)

                du = tf.exp(0.5 * (std_samples**2 - upper_broadcast**2) +
                            tf.math.log(cdf_samples))
                dl = tf.exp(0.5 * (std_samples**2 - lower_broadcast**2) +
                            tf.math.log1p(-cdf_samples))

                # Reduce the gradient across the samples
                grad_u = tf.reduce_sum(dy * du, axis=-1)
                grad_l = tf.reduce_sum(dy * dl, axis=-1)
                return [grad_l, grad_u]
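The clamp to `[tiny, 1 - eps]` above exists purely so that `log(cdf_samples)` and `log1p(-cdf_samples)` stay finite. The same guard in isolation, as a small sketch with illustrative values:

import numpy as np
import tensorflow as tf

dtype = np.float32
tiny = np.finfo(dtype).tiny   # smallest positive normal float32 (~1.18e-38)
eps = np.finfo(dtype).eps     # machine epsilon (~1.19e-7)

cdf_samples = tf.constant([0.0, 0.5, 1.0], dtype)
safe = tf.clip_by_value(cdf_samples, tiny, 1. - eps)
tf.math.log(safe)      # finite everywhere; log(0) would be -inf
tf.math.log1p(-safe)   # finite everywhere; log1p(-1) would be -inf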
Code example #4
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one step of the TransitionKernel.
    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
      seed: Optional, a seed for reproducible sampling.
    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This inculdes replica states.
    """

        with tf.name_scope(mcmc_util.make_name(self.name, 'tmc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.post_tempering_inverse_temperatures,
                name='inverse_temperatures')

            steps_at_temperature = tf.convert_to_tensor(
                previous_kernel_results.steps_at_temperature,
                name='number of steps')

            target_score_for_inner_kernel = partial(self.target_score_fn,
                                                    sigma=inverse_temperatures)
            target_log_prob_for_inner_kernel = partial(
                self.target_log_prob_fn, sigma=inverse_temperatures)

            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel,
                    target_score_for_inner_kernel, inverse_temperatures)
            except TypeError as e:
                if 'argument' not in str(e):
                    raise
                warnings.warn(
                    'The `seed` argument to `ReplicaExchangeMC`s `make_kernel_fn` is '
                    'deprecated. `TransitionKernel` instances now receive seeds via '
                    '`one_step`.')
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel,
                    target_score_for_inner_kernel, inverse_temperatures,
                    self._seed_stream())

            if seed is not None:
                seed = samplers.sanitize_seed(seed)
                inner_seed, swap_seed, logu_seed = samplers.split_seed(
                    seed, n=3, salt='tmc_one_step')
                inner_kwargs = dict(seed=inner_seed)
            else:
                if self._seed_stream.original_seed is not None:
                    warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
                inner_kwargs = {}
                swap_seed, logu_seed = samplers.split_seed(self._seed_stream())

            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = current_state
            else:
                states = [current_state]
            [
                new_state,
                pre_tempering_results,
            ] = inner_kernel.one_step(
                states, previous_kernel_results.post_tempering_results,
                **inner_kwargs)

            # Now that we have run one step, consider whether to lower the temperature.
            # Proposed new temperature:
            proposed_inverse_temperatures = tf.clip_by_value(
                self.gamma * inverse_temperatures, self.min_temp, 1e6)
            dtype = inverse_temperatures.dtype

            # We will lower the temperature if this new proposed step is compatible with
            # a temperature swap
            v = new_state[0] - states[0]
            cs = states[0]

            @jax.vmap
            def integrand(t):
                return jnp.sum(self._parameters['target_score_fn'](
                    t * v + cs, inverse_temperatures) * v,
                               axis=-1)

            delta_logp1 = simps(integrand, 0., 1.,
                                self._parameters['num_delta_logp_steps'])

            # Now we compute the reverse
            v = -v
            cs = new_state[0]

            @jax.vmap
            def integrand(t):
                return jnp.sum(self._parameters['target_score_fn'](
                    t * v + cs, proposed_inverse_temperatures) * v,
                               axis=-1)

            delta_logp2 = simps(integrand, 0., 1.,
                                self._parameters['num_delta_logp_steps'])

            log_accept_ratio = (delta_logp1 + delta_logp2)

            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce Log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                samplers.uniform(shape=log_accept_ratio.shape,
                                 dtype=dtype,
                                 seed=logu_seed))

            is_tempering_accepted_mask = tf.less(
                log_uniform,
                log_accept_ratio,
                name='is_tempering_accepted_mask')

            is_min_steps_satisfied = tf.greater(
                steps_at_temperature,
                self.min_steps_per_temp * tf.ones_like(steps_at_temperature),
                name='is_min_steps_satisfied')

            # Only propose tempering if the chain was going to accept this point anyway
            is_tempering_accepted_mask = tf.math.logical_and(
                is_tempering_accepted_mask, pre_tempering_results.is_accepted)

            is_tempering_accepted_mask = tf.math.logical_and(
                is_tempering_accepted_mask, is_min_steps_satisfied)

            # Updating accepted inverse temperatures
            post_tempering_inverse_temperatures = mcmc_util.choose(
                is_tempering_accepted_mask, proposed_inverse_temperatures,
                inverse_temperatures)

            steps_at_temperature = mcmc_util.choose(
                is_tempering_accepted_mask,
                tf.zeros_like(steps_at_temperature), steps_at_temperature + 1)

            # Invalidating and recomputing results
            [
                new_target_log_prob,
                new_grads_target_log_prob,
            ] = mcmc_util.maybe_call_fn_and_grads(
                partial(self.target_log_prob_fn,
                        sigma=post_tempering_inverse_temperatures), new_state)

            # Updating inner kernel results
            post_tempering_results = pre_tempering_results._replace(
                proposed_results=tf.convert_to_tensor(np.nan, dtype=dtype),
                proposed_state=tf.convert_to_tensor(np.nan, dtype=dtype),
            )

            if isinstance(post_tempering_results.accepted_results,
                          hmc.UncalibratedHamiltonianMonteCarloKernelResults):
                post_tempering_results = post_tempering_results._replace(
                    accepted_results=post_tempering_results.accepted_results.
                    _replace(target_log_prob=new_target_log_prob,
                             grads_target_log_prob=new_grads_target_log_prob))
            elif isinstance(
                    post_tempering_results.accepted_results,
                    random_walk_metropolis.UncalibratedRandomWalkResults):
                post_tempering_results = post_tempering_results._replace(
                    accepted_results=post_tempering_results.accepted_results.
                    _replace(target_log_prob=new_target_log_prob))
            else:
                # TODO(b/143702650) Handle other kernels.
                raise NotImplementedError(
                    'Only HMC and RWMH Kernels are handled at this time. Please file a '
                    'request with the TensorFlow Probability team.')

            new_kernel_results = TemperedMCKernelResults(
                pre_tempering_results=pre_tempering_results,
                post_tempering_results=post_tempering_results,
                pre_tempering_inverse_temperatures=inverse_temperatures,
                post_tempering_inverse_temperatures=
                post_tempering_inverse_temperatures,
                tempering_log_accept_ratio=log_accept_ratio,
                steps_at_temperature=steps_at_temperature,
                seed=samplers.zeros_seed() if seed is None else seed,
            )

            return new_state[0], new_kernel_results
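The `simps(integrand, 0., 1., num_steps)` calls above integrate the tempered score along the straight path between the old and new states. The helper itself is not shown in this excerpt, so the following is only a hypothetical composite Simpson's rule sketch with the same call shape; its name, signature, and behavior are assumptions:

import jax
import jax.numpy as jnp

def simps_sketch(f, a, b, n):
    # Composite Simpson's rule over [a, b] with n (even) intervals; f is a
    # vmapped integrand like the ones above, so it accepts a vector of points.
    n = n if n % 2 == 0 else n + 1
    t = jnp.linspace(a, b, n + 1)
    h = (b - a) / n
    w = jnp.ones(n + 1).at[1:-1:2].set(4.).at[2:-1:2].set(2.)
    return h / 3. * jnp.tensordot(w, f(t), axes=1)

# e.g. simps_sketch(jax.vmap(jnp.sin), 0., jnp.pi, 64) is close to 2.0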
Code example #5
    def _mode(self):
        # mode = { loc:  for low <= loc <= high
        #          low:  for loc < low
        #          high: for loc > high
        #        }
        return tf.clip_by_value(self.loc, self.low, self.high)
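Assuming these last two snippets come from `tfp.distributions.TruncatedNormal` (the `loc`/`low`/`high` parameters match), the clipping shows up directly in the public `mode()`. A small usage sketch:

import tensorflow_probability as tfp

d = tfp.distributions.TruncatedNormal(loc=2., scale=1., low=-1., high=1.)
d.mode()  # ==> 1.0, i.e. loc clipped into [low, high]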
Code example #6
    def _cdf(self, x):
        cdf_in_support = ((special_math.ndtr((x - self.loc) / self.scale) -
                           special_math.ndtr(self._standardized_low)) /
                          self._normalizer)
        return tf.clip_by_value(cdf_in_support, 0., 1.)
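The final clip exists because the ratio of two `ndtr` values can drift slightly outside [0, 1] in floating point. The same guard in isolation, with illustrative names:

import tensorflow as tf

def clamp_to_unit_interval(cdf_values):
    # Cancellation can produce values like -1e-9 or 1.0000001; clipping
    # restores a valid CDF value.
    return tf.clip_by_value(cdf_values, 0., 1.)

clamp_to_unit_interval(tf.constant([-1e-9, 0.5, 1.0000001]))  # ==> [0., 0.5, 1.]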