def _inverse(self, y):
  """Maps `y` back through the inverse of the forward transformation.

  Only the fractional part of `y` is transformed; `floor(y)` is added back
  unchanged at the end. The inverse looks like a shifted and scaled logit,
  which is evaluated as `t * log(a / b)` with `log(a)` and `log(b)` built
  separately in log-space for numerical stability.

  Args:
    y: Tensor of values to invert.

  Returns:
    `floor(y)` plus the transformed fractional part of `y`.
  """
  temperature = tf.convert_to_tensor(self.temperature)
  whole = tf.math.floor(y)
  frac = y - whole
  log_frac = tf.math.log(frac)
  frac_is_zero = tf.equal(frac, 0.)

  # Scalars are created with an explicit dtype so the numpy backend does not
  # upcast. They are declared once here and reused everywhere below.
  zero = tf.zeros([], self.dtype)
  half = tf.constant(0.5, self.dtype)
  half_over_t = half / temperature

  # log(a) = log(1 + f * (exp(0.5 / t) - 1)).
  log_a = tfp_math.log_add_exp(
      zero, tfp_math.log_sub_exp(half_over_t + log_frac, log_frac))
  # With a zero fractional part the numerator is 1, i.e. log(a) = 0.
  log_a = tf.where(frac_is_zero, zero, log_a)

  # log(b) = log(f + (1 - f) * exp(0.5 / t)).
  log_b = tfp_math.log_add_exp(
      log_frac, tfp_math.log_sub_exp(half_over_t, log_frac + half_over_t))
  # With a zero fractional part log(b) collapses to 0.5 / t.
  log_b = tf.where(frac_is_zero, half_over_t, log_b)

  restored_frac = temperature * (log_a - log_b) + half
  # Shift up once more, since the original transformation was from
  # [0.5, 1.5] to [0, 1].
  restored_frac = restored_frac + half
  return whole + restored_frac
def _inverse(self, y):
  """Maps `y` back through the inverse of the forward transformation.

  The integer part of `y` passes through untouched. The fractional part is
  sent through what amounts to a shifted and scaled logit, computed as
  `t * log(a / b)` where `log(a)` and `log(b)` are assembled in log-space
  to keep the evaluation numerically stable.

  Args:
    y: Tensor of values to invert.

  Returns:
    `floor(y)` plus the transformed fractional part of `y`.
  """
  temperature = tf.convert_to_tensor(self.temperature)
  floor_y = tf.math.floor(y)
  remainder = y - floor_y
  log_remainder = tf.math.log(remainder)
  at_integer = tf.equal(remainder, 0.)
  half_over_t = 0.5 / temperature

  # log(a) = log(1 + f * (exp(0.5 / t) - 1)).
  log_a = tfp_math.log_add_exp(
      0., tfp_math.log_sub_exp(half_over_t + log_remainder, log_remainder))
  # A zero fractional part means the numerator is 1, so log(a) is 0.
  log_a = tf.where(
      at_integer, dtype_util.as_numpy_dtype(self.dtype)(0.), log_a)

  # log(b) = log(f + (1 - f) * exp(0.5 / t)).
  log_b = tfp_math.log_add_exp(
      log_remainder,
      tfp_math.log_sub_exp(half_over_t, log_remainder + half_over_t))
  # A zero fractional part means log(b) reduces to 0.5 / t.
  log_b = tf.where(at_integer, half_over_t, log_b)

  shifted = temperature * (log_a - log_b) + 0.5
  # Shift up again, since the original transformation was from
  # [0.5, 1.5] to [0, 1].
  shifted = shifted + 0.5
  return floor_y + shifted
def _forward(self, x):
  """Evaluates the forward transformation at `x`.

  The function is defined piecewise on the repeated interval `[0.5, 1.5]`.
  On each copy it is a rescaled sigmoid and therefore differentiable. Since
  `1 - sigmoid(-x) = sigmoid(x)`, the derivative is symmetric around the
  centre of the interval, so it takes the same value at both endpoints and
  is continuous across pieces.

  Args:
    x: Tensor of input values.

  Returns:
    `integer_part + rescaled_part + sigmoid(-0.5 / t)` with
    `t = self.temperature`. As `t -> infinity` the last term tends to 0.5,
    so the result behaves like the identity. As `t -> 0` it tends to 0, so
    the result approximates a floor function.
  """
  t = tf.convert_to_tensor(self.temperature)
  integer_part = tf.math.floor(x)
  fractional_part = x - integer_part

  # Constants are wrapped with an explicit dtype to avoid upcasting in the
  # numpy backend. Declare them once and use them everywhere below.
  one_half = tf.constant(0.5, self.dtype)
  one = tf.ones([], self.dtype)
  neg_inf = tf.constant(-np.inf, self.dtype)

  # x - floor(x) lies in [0, 1], but each piece lives on [0.5, 1.5]. Inputs
  # whose fractional part is below 0.5 belong to the right-hand portion of
  # the previous piece, so shift them and decrement the integer part.
  in_lower_half = fractional_part < one_half
  integer_part = tf.where(in_lower_half, integer_part - one, integer_part)
  fractional_part = tf.where(
      in_lower_half, fractional_part + one_half, fractional_part - one_half)

  # Rescale so the left tail is 0 and the right tail is 1, which guarantees
  # continuity. The quantity computed is
  #   (sigmoid((f - 0.5) / t) - sigmoid(-0.5 / t)) /
  #   (sigmoid(0.5 / t) - sigmoid(-0.5 / t))
  # evaluated through log_add_exp / log_sub_exp for numerical stability.
  is_zero = tf.equal(fractional_part, 0.)

  log_numerator = tfp_math.log_sub_exp(
      (one_half + fractional_part) / t, one_half / t)
  # fractional_part == 0 would otherwise produce log(0).
  log_numerator = tf.where(is_zero, neg_inf, log_numerator)

  log_denominator = tfp_math.log_sub_exp(
      (one_half + fractional_part) / t, fractional_part / t)
  # fractional_part == 0 would otherwise produce log(0).
  log_denominator = tf.where(is_zero, neg_inf, log_denominator)
  log_denominator = tfp_math.log_add_exp(
      log_denominator, tfp_math.log_sub_exp(one / t, one_half / t))

  rescaled_part = tf.math.exp(log_numerator - log_denominator)

  # Add sigmoid(-0.5 / t). The dtype-wrapped `one_half` is used here as well,
  # so this term follows the same no-upcasting rule as the rest of the
  # function.
  return integer_part + rescaled_part + tf.math.sigmoid(-one_half / t)
def _log_prob(self, x):
  """Computes the log-probability of the category index `x`.

  The log-survival values `log_sigmoid(loc - augmented_cutpoints)` are
  broadcast against `x` and flattened to one row per event. For each event
  the entries at (clipped) indices `x` and `x + 1` are gathered and combined
  with `log_sub_exp`.

  Args:
    x: Tensor of category indices.

  Returns:
    Tensor of log-probabilities with the shape of the broadcast `x`. Events
    with `x > num_categories - 1` get `-inf`.
  """
  # TODO(b/149334734): Consider using QuantizedDistribution for the log_prob
  # computation for better precision.
  n_cat = self._num_categories()
  log_survival = tf.math.log_sigmoid(
      self.loc[..., tf.newaxis] - self._augmented_cutpoints())
  x, log_survival = _broadcast_cat_event_and_params(
      event=x,
      params=log_survival,
      base_dtype=dtype_util.base_dtype(self.dtype))

  event_col = tf.reshape(x, [-1, 1])
  survival_rows = tf.reshape(log_survival, [-1, n_cat + 1])

  def _lookup(indices):
    # Per-row gather, with indices clipped into the valid column range.
    return tf.gather(
        params=survival_rows,
        indices=tf.clip_by_value(indices, 0, n_cat),
        batch_dims=1)

  log_survival_xm1 = _lookup(event_col)
  log_survival_x = _lookup(event_col + 1)
  log_prob = tfp_math.log_sub_exp(log_survival_xm1, log_survival_x)

  # When both survival probabilities are -inf the difference is nan, but the
  # correct answer is -inf.
  neg_inf = tf.constant(-np.inf, dtype=log_prob.dtype)
  past_last_category = event_col > n_cat - 1
  log_prob = tf.where(past_last_category, neg_inf, log_prob)
  return tf.reshape(log_prob, shape=ps.shape(x))
def _forward(self, x):
  """Evaluates the forward transformation at `x`.

  The function is defined piecewise on the repeated interval `[0.5, 1.5]`.
  Each piece is a rescaled sigmoid, hence differentiable. Because
  `1 - sigmoid(-x) = sigmoid(x)`, the derivative is symmetric around the
  centre of the interval, so it agrees at both endpoints and is continuous
  across pieces.

  Args:
    x: Tensor of input values.

  Returns:
    The adjusted integer part of `x` plus the rescaled fractional part.
  """
  temperature = tf.convert_to_tensor(self.temperature)
  whole = tf.math.floor(x)
  frac = x - whole

  # x - floor(x) lies in [0, 1], whereas each piece lives on [0.5, 1.5].
  # Inputs with a fractional part below 0.5 belong to the right-hand portion
  # of the previous piece: shift them and decrement the integer part.
  below_half = frac < 0.5
  whole = tf.where(below_half, whole - 1., whole)
  frac = tf.where(below_half, frac + 0.5, frac - 0.5)

  # Rescale so the left tail is 0 and the right tail is 1, which guarantees
  # continuity. The quantity computed is
  #   (sigmoid((f - 0.5) / t) - sigmoid(-0.5 / t)) /
  #   (sigmoid(0.5 / t) - sigmoid(-0.5 / t))
  # evaluated through log_add_exp / log_sub_exp for numerical stability.
  frac_is_zero = tf.equal(frac, 0.)
  neg_inf = dtype_util.as_numpy_dtype(self.dtype)(-np.inf)
  half_over_t = 0.5 / temperature
  shifted_over_t = (0.5 + frac) / temperature

  # A zero fractional part would otherwise give log(0) in both terms.
  log_top = tf.where(
      frac_is_zero,
      neg_inf,
      tfp_math.log_sub_exp(shifted_over_t, half_over_t))
  log_bottom = tf.where(
      frac_is_zero,
      neg_inf,
      tfp_math.log_sub_exp(shifted_over_t, frac / temperature))
  log_bottom = tfp_math.log_add_exp(
      log_bottom, tfp_math.log_sub_exp(1. / temperature, half_over_t))

  return whole + tf.math.exp(log_top - log_bottom)
def _reduce_log_l2_exp(loga, logb, axis=-1):
  """Log-sum-exp reduction of twice `log_sub_exp(loga, logb)`.

  Doubling the log-space difference squares it in linear space, so the
  result is the log of a sum of squared differences along `axis`.

  Args:
    loga: Tensor of log-values.
    logb: Tensor of log-values, combined with `loga` via `log_sub_exp`.
    axis: Axis (or axes) to reduce over. Defaults to the last axis.

  Returns:
    `reduce_logsumexp(2 * log_sub_exp(loga, logb), axis=axis)`.
  """
  log_diff = tfp_math.log_sub_exp(loga, logb)
  log_sq_diff = 2. * log_diff
  return tf.math.reduce_logsumexp(log_sq_diff, axis=axis)
def categorical_log_probs(self):
  """Log probabilities for the `K+1` ordered categories.

  Computed as the log-space difference between adjacent entries of
  `log_sigmoid(loc - augmented_cutpoints)` along the last axis.

  Returns:
    Tensor whose last axis has one entry per category.
  """
  expanded_loc = self.loc[..., tf.newaxis]
  log_survival = tf.math.log_sigmoid(
      expanded_loc - self._augmented_cutpoints())
  leading = log_survival[..., :-1]
  trailing = log_survival[..., 1:]
  return tfp_math.log_sub_exp(leading, trailing)
def _log_prob(self, x):
  """Computes the log-probability of `x` from the log-survival function.

  Args:
    x: Tensor of values at which to evaluate the log-probability.

  Returns:
    `log_sub_exp` of the log-survival function at `x - 1` and at `x`.
  """
  # TODO(b/149334734): Consider using QuantizedDistribution for the log_prob
  # computation for better precision.
  return tfp_math.log_sub_exp(
      self._log_survival_function(x - 1),
      self._log_survival_function(x))