def _cdf(self, k):
  """Evaluates the cumulative distribution function at the event(s) `k`.

  The value is read out of the running sum of the category probabilities at
  index `k`. Events below the support (`k < 0`) map to zero, and events past
  the last category are clamped to the last index, whose running sum is the
  total probability mass.

  Args:
    k: Event(s) to evaluate; broadcast against the probabilities by
      `_broadcast_cat_event_and_params`.

  Returns:
    A `Tensor` with the shape of the broadcast `k` holding the CDF values.
  """
  # TODO(b/135263541): Improve numerical precision of categorical.cdf.
  probs = self.probs_parameter()
  num_categories = self._num_categories(probs)
  k, probs = _broadcast_cat_event_and_params(
      k, probs, base_dtype=dtype_util.base_dtype(self.dtype))

  # The support starts at 0, so negative events must yield exactly zero.
  # Record them now, before `k` is clamped into a valid index.
  is_below_support = k < 0

  # The event doubles as a gather index, so force it into {0, ..., K-1}.
  index = tf.clip_by_value(tf.cast(k, tf.int32), 0, num_categories - 1)
  out_shape = tf.shape(index)

  # `tf.gather(..., batch_dims=...)` needs a static `batch_dims`. Collapsing
  # every batch dimension into one makes `batch_dims=1` always correct, even
  # when the batch shape is only known dynamically.
  flat_index = tf.reshape(index, [-1])
  flat_probs = tf.reshape(
      probs, tf.concat(([-1], [num_categories]), axis=0))
  flat_cumulative = tf.cumsum(flat_probs, axis=-1)
  flat_cdf = tf.gather(
      flat_cumulative, flat_index[..., tf.newaxis], batch_dims=1)

  cdf = tf.reshape(flat_cdf, shape=out_shape)
  zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
  return tf.where(is_below_support, zero, cdf)
def clip_by_value_preserve_gradient(t, clip_value_min, clip_value_max,
                                    name=None):
  """Clamps `t` to `[clip_value_min, clip_value_max]` with identity gradient.

  The forward value matches `tf.clip_by_value`: a tensor of the same type and
  shape as `t`, with every entry no smaller than `clip_value_min` and no
  larger than `clip_value_max`. The backward pass differs: the clamping is
  hidden from differentiation, so the gradient with respect to `t` passes
  through unchanged, i.e.,

  ```python
  tf.gradients(
      tfp.math.clip_by_value_preserve_gradient(x, lo, hi), x)[0]
  # ==> ones_like(x)
  ```

  Note: results are only correct when `clip_value_min <= clip_value_max`.

  Args:
    t: A `Tensor`.
    clip_value_min: A scalar `Tensor`, or a `Tensor` with the same shape
      as `t`. The minimum value to clip by.
    clip_value_max: A scalar `Tensor`, or a `Tensor` with the same shape
      as `t`. The maximum value to clip by.
    name: A name for the operation (optional).
      Default value: `'clip_by_value_preserve_gradient'`.

  Returns:
    clipped_t: A clipped `Tensor`.
  """
  with tf.name_scope(name or 'clip_by_value_preserve_gradient'):
    t = tf.convert_to_tensor(t, name='t')
    clipped = tf.clip_by_value(t, clip_value_min, clip_value_max)
    # Adding the (clipped - t) offset reproduces the clipped value, while
    # `stop_gradient` keeps the offset out of the backward pass so that only
    # the identity path through `t` is differentiated.
    offset = tf.stop_gradient(clipped - t)
    return t + offset
def grad(dy):
  """Computes the gradients with respect to the truncation bounds.

  The sampler blocks the derivative with respect to the lower and upper
  bounds, so it is supplied here. A hand-written expression is used for
  numerical stability rather than automatic differentiation through the CDF
  (implicit gradients).

  Args:
    dy: Output gradients, one entry per sample.

  Returns:
    A two-element list `[grad_l, grad_u]`: the gradient with respect to the
    lower bound followed by the gradient with respect to the upper bound,
    each summed over the sample dimension.
  """
  # `std_samples` has one more dimension than the bounds (the sample
  # dimension, placed last), so give the bounds a trailing axis to broadcast
  # along it.
  lower_b = lower[..., tf.newaxis]
  upper_b = upper[..., tf.newaxis]

  # CDF of each sample within the truncated support.
  ndtr_lower = special_math.ndtr(lower_b)
  mass_below_sample = special_math.ndtr(std_samples) - ndtr_lower
  mass_in_support = special_math.ndtr(upper_b) - ndtr_lower
  cdf_samples = mass_below_sample / mass_in_support

  # `tiny` and `eps` are tolerances that keep the arguments of `log` and
  # `log1p` below strictly away from zero.
  float_info = np.finfo(dtype_util.as_numpy_dtype(self.dtype))
  cdf_samples = tf.clip_by_value(
      cdf_samples, float_info.tiny, 1 - float_info.eps)

  samples_sq = std_samples**2
  du = tf.exp(0.5 * (samples_sq - upper_b**2) + tf.math.log(cdf_samples))
  dl = tf.exp(0.5 * (samples_sq - lower_b**2) + tf.math.log1p(-cdf_samples))

  # Accumulate each bound's gradient over the sample dimension.
  grad_l = tf.reduce_sum(dy * dl, axis=-1)
  grad_u = tf.reduce_sum(dy * du, axis=-1)
  return [grad_l, grad_u]
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Takes one step of the TransitionKernel.

  One inner-kernel step is run at the current inverse temperatures. A move
  to `gamma * inverse_temperatures` (clipped to `[min_temp, 1e6]`) is then
  proposed and accepted per chain only if all three of the following hold:
  a log-uniform draw falls below the tempering log accept ratio, the inner
  kernel accepted its own proposal, and more than `min_steps_per_temp` steps
  have been taken at the current temperature.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). Only the first state part
      enters the tempering accept ratio.
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by
      `bootstrap_results`).
    seed: Optional, a seed for reproducible sampling.

  Returns:
    next_state: The first state part (`new_state[0]`) produced by the inner
      kernel.
    kernel_results: A `TemperedMCKernelResults` holding the inner kernel
      results before and after tempering, the inverse temperatures before
      and after tempering, the tempering log accept ratio, the number of
      steps taken at the current temperature, and the seed.

  Raises:
    NotImplementedError: If the inner kernel's `accepted_results` are neither
      HMC nor random-walk Metropolis results.
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'tmc', 'one_step')):
    # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
    inverse_temperatures = tf.convert_to_tensor(
        previous_kernel_results.post_tempering_inverse_temperatures,
        name='inverse_temperatures')

    # The tensor name must not contain spaces: they are not valid in
    # TensorFlow graph op names.
    steps_at_temperature = tf.convert_to_tensor(
        previous_kernel_results.steps_at_temperature,
        name='steps_at_temperature')

    target_score_for_inner_kernel = partial(
        self.target_score_fn, sigma=inverse_temperatures)
    target_log_prob_for_inner_kernel = partial(
        self.target_log_prob_fn, sigma=inverse_temperatures)

    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel,
          target_score_for_inner_kernel,
          inverse_temperatures)
    except TypeError as e:
      # Only a signature mismatch (a legacy `make_kernel_fn` that still
      # expects a seed) is retried; any other `TypeError` propagates.
      if 'argument' not in str(e):
        raise
      warnings.warn(
          'The `seed` argument to `ReplicaExchangeMC`s `make_kernel_fn` is '
          'deprecated. `TransitionKernel` instances now receive seeds via '
          '`one_step`.')
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel,
          target_score_for_inner_kernel,
          inverse_temperatures,
          self._seed_stream())

    # The split counts are kept as they were so the derived seed streams do
    # not change; the extra (formerly "swap") seed is simply not used here.
    if seed is not None:
      seed = samplers.sanitize_seed(seed)
      inner_seed, _, logu_seed = samplers.split_seed(
          seed, n=3, salt='tmc_one_step')
      inner_kwargs = dict(seed=inner_seed)
    else:
      if self._seed_stream.original_seed is not None:
        warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
      inner_kwargs = {}
      _, logu_seed = samplers.split_seed(self._seed_stream())

    if mcmc_util.is_list_like(current_state):
      # We *always* canonicalize the states in the kernel results.
      states = current_state
    else:
      states = [current_state]

    [
        new_state,
        pre_tempering_results,
    ] = inner_kernel.one_step(
        states,
        previous_kernel_results.post_tempering_results,
        **inner_kwargs)

    # Now that we have run one step, we consider maybe lowering the
    # temperature. Proposed new temperature:
    proposed_inverse_temperatures = tf.clip_by_value(
        self.gamma * inverse_temperatures, self.min_temp, 1e6)
    dtype = inverse_temperatures.dtype

    score_fn = self._parameters['target_score_fn']
    num_delta_logp_steps = self._parameters['num_delta_logp_steps']

    def _path_delta_logp(start, direction, temperatures):
      """Integrates `score . direction` along `start + t * direction`."""
      @jax.vmap
      def integrand(t):
        return jnp.sum(
            score_fn(t * direction + start, temperatures) * direction,
            axis=-1)
      # `simps` integrates the integrand over t in [0, 1] (presumably with
      # Simpson's rule -- confirm against its definition).
      return simps(integrand, 0., 1., num_delta_logp_steps)

    # We will lower the temperature if this new proposed step is compatible
    # with a temperature swap: forward path at the current temperatures...
    displacement = new_state[0] - states[0]
    delta_logp1 = _path_delta_logp(
        states[0], displacement, inverse_temperatures)
    # ...and the reverse path at the proposed temperatures.
    delta_logp2 = _path_delta_logp(
        new_state[0], -displacement, proposed_inverse_temperatures)

    log_accept_ratio = delta_logp1 + delta_logp2
    # A non-finite ratio is treated as a certain rejection.
    log_accept_ratio = tf.where(
        tf.math.is_finite(log_accept_ratio),
        log_accept_ratio,
        tf.constant(-np.inf, dtype=dtype))

    # Produce Log[Uniform] draws that are identical at swapped indices.
    log_uniform = tf.math.log(
        samplers.uniform(
            shape=log_accept_ratio.shape, dtype=dtype, seed=logu_seed))

    is_tempering_accepted_mask = tf.less(
        log_uniform, log_accept_ratio, name='is_tempering_accepted_mask')

    is_min_steps_satisfied = tf.greater(
        steps_at_temperature,
        self.min_steps_per_temp * tf.ones_like(steps_at_temperature),
        name='is_min_steps_satisfied')

    # Only propose tempering if the chain was going to accept this point
    # anyway.
    is_tempering_accepted_mask = tf.math.logical_and(
        is_tempering_accepted_mask, pre_tempering_results.is_accepted)
    is_tempering_accepted_mask = tf.math.logical_and(
        is_tempering_accepted_mask, is_min_steps_satisfied)

    # Updating accepted inverse temperatures.
    post_tempering_inverse_temperatures = mcmc_util.choose(
        is_tempering_accepted_mask,
        proposed_inverse_temperatures,
        inverse_temperatures)

    # The step counter restarts wherever the temperature changed.
    steps_at_temperature = mcmc_util.choose(
        is_tempering_accepted_mask,
        tf.zeros_like(steps_at_temperature),
        steps_at_temperature + 1)

    # Invalidating and recomputing results: the cached log-prob and gradients
    # were evaluated at the pre-tempering temperatures.
    [
        new_target_log_prob,
        new_grads_target_log_prob,
    ] = mcmc_util.maybe_call_fn_and_grads(
        partial(self.target_log_prob_fn,
                sigma=post_tempering_inverse_temperatures),
        new_state)

    # Updating inner kernel results.
    post_tempering_results = pre_tempering_results._replace(
        proposed_results=tf.convert_to_tensor(np.nan, dtype=dtype),
        proposed_state=tf.convert_to_tensor(np.nan, dtype=dtype),
    )

    accepted_results = post_tempering_results.accepted_results
    if isinstance(accepted_results,
                  hmc.UncalibratedHamiltonianMonteCarloKernelResults):
      post_tempering_results = post_tempering_results._replace(
          accepted_results=accepted_results._replace(
              target_log_prob=new_target_log_prob,
              grads_target_log_prob=new_grads_target_log_prob))
    elif isinstance(accepted_results,
                    random_walk_metropolis.UncalibratedRandomWalkResults):
      post_tempering_results = post_tempering_results._replace(
          accepted_results=accepted_results._replace(
              target_log_prob=new_target_log_prob))
    else:
      # TODO(b/143702650) Handle other kernels.
      raise NotImplementedError(
          'Only HMC and RWMH Kernels are handled at this time. Please file a '
          'request with the TensorFlow Probability team.')

    new_kernel_results = TemperedMCKernelResults(
        pre_tempering_results=pre_tempering_results,
        post_tempering_results=post_tempering_results,
        pre_tempering_inverse_temperatures=inverse_temperatures,
        post_tempering_inverse_temperatures=(
            post_tempering_inverse_temperatures),
        tempering_log_accept_ratio=log_accept_ratio,
        steps_at_temperature=steps_at_temperature,
        seed=samplers.zeros_seed() if seed is None else seed,
    )

    return new_state[0], new_kernel_results
def _mode(self):
  """Returns the mode: `loc` clamped into the interval `[low, high]`."""
  #   loc < low           ->  low
  #   low <= loc <= high  ->  loc
  #   loc > high          ->  high
  center = self.loc
  lower_bound = self.low
  upper_bound = self.high
  return tf.clip_by_value(
      center, clip_value_min=lower_bound, clip_value_max=upper_bound)
def _cdf(self, x):
  """Evaluates the CDF of the truncated distribution at `x`.

  The standard-normal CDF of the standardized `x`, measured from the
  standardized lower bound, is divided by `self._normalizer` (presumably the
  normal mass between the bounds -- confirm against its definition). The
  quotient is clamped to `[0, 1]`, which covers inputs outside the bounds.

  Args:
    x: Point(s) at which to evaluate the CDF.

  Returns:
    A `Tensor` of CDF values in `[0, 1]`.
  """
  standardized_x = (x - self.loc) / self.scale
  mass_below_x = (special_math.ndtr(standardized_x) -
                  special_math.ndtr(self._standardized_low))
  cdf_in_support = mass_below_x / self._normalizer
  return tf.clip_by_value(cdf_in_support, 0., 1.)