def _inverse(self, y):
  """Inverts the forward map; a shifted/scaled logit on each unit interval."""
  # Undo the [0, 1] -> [0, 1] transformation. The inverse is a shifted and
  # scaled logit; we work in log-space and form log(a / b) as
  # log_top - log_bot below, which is more numerically stable.
  temperature = tf.convert_to_tensor(self.temperature)
  whole = tf.math.floor(y)
  frac = y - whole
  log_frac = tf.math.log(frac)
  # Pin `0` and `0.5` to an explicit dtype so the numpy backend does not
  # upcast; declared once and reused everywhere below.
  zero = tf.zeros([], self.dtype)
  half = tf.constant(0.5, self.dtype)
  at_integer = tf.equal(frac, 0.)

  log_top = tfp_math.log_sub_exp(half / temperature + log_frac, log_frac)
  log_top = tfp_math.log_add_exp(zero, log_top)
  # A zero fractional part makes the numerator exactly 1.
  log_top = tf.where(at_integer, zero, log_top)

  log_bot = tfp_math.log_sub_exp(
      half / temperature, log_frac + half / temperature)
  log_bot = tfp_math.log_add_exp(log_frac, log_bot)
  # A zero fractional part makes the denominator exactly 0.5 / t.
  log_bot = tf.where(at_integer, half / temperature, log_bot)

  new_frac = temperature * (log_top - log_bot) + half
  # Shift up: the original transformation mapped [0.5, 1.5] onto [0, 1].
  new_frac = new_frac + half
  return whole + new_frac
def _inverse(self, y):
  """Numerically stable inverse of the smoothed-floor forward transform."""
  # We undo the transformation from [0, 1] -> [0, 1]. The inverse of the
  # transformation is a shifted and scaled logit function. We rewrite this
  # to be more numerically stable, producing a term log(a / b);
  # log_{numerator, denominator} below are log(a) and log(b) respectively.
  t = tf.convert_to_tensor(self.temperature)
  fractional_part = y - tf.math.floor(y)
  log_f = tf.math.log(fractional_part)
  # Wrap `0` and `0.5` in dtype-explicit constants so that Python float
  # literals do not upcast the computation in the numpy backend. Declare
  # once and use everywhere.
  zero = tf.zeros([], self.dtype)
  one_half = tf.constant(0.5, self.dtype)
  log_numerator = tfp_math.log_sub_exp(one_half / t + log_f, log_f)
  log_numerator = tfp_math.log_add_exp(zero, log_numerator)
  # When the fractional part is zero, the numerator is 1.
  log_numerator = tf.where(
      tf.equal(fractional_part, 0.), zero, log_numerator)
  log_denominator = tfp_math.log_sub_exp(one_half / t, log_f + one_half / t)
  log_denominator = tfp_math.log_add_exp(log_f, log_denominator)
  # When the fractional part is zero, the denominator is 0.5 / t.
  log_denominator = tf.where(
      tf.equal(fractional_part, 0.), one_half / t, log_denominator)
  new_fractional_part = t * (log_numerator - log_denominator) + one_half
  # Shift up since the original transformation was from [0.5, 1.5] to [0, 1].
  new_fractional_part = new_fractional_part + one_half
  return tf.math.floor(y) + new_fractional_part
def _quantile(self, p, logits=None):
  """Quantile function, numerically stable across the full logits range."""
  if logits is None:
    logits = self._logits_parameter_no_checks()
  log_p = tf.math.log(p)
  # The quantile is log(1 + (e^s - 1) * p) / s, with s = `logits`. For
  # large s the e^s sub-term is increasingly ill-conditioned; since the
  # numerator tends to s there, the s > 0 case is reformulated as an offset
  # from 1, which is more accurate and, coincidentally, removes a ratio of
  # infinities when s == +inf.
  safe_neg = tf.where(logits < 0., logits, -1.)
  safe_pos = tf.where(logits > 0., logits, 1.)
  positive_branch = 1. + tfp_math.log_add_exp(
      log_p + tfp_math.log1mexp(safe_pos),
      tf.math.negative(safe_pos)) / safe_pos
  negative_branch = tf.math.log1p(tf.math.expm1(safe_neg) * p) / safe_neg
  result = tf.where(logits > 0., positive_branch, negative_branch)
  # At logits == 0 we may simplify
  #   log(1 + (e^s - 1) * p) / s ~= log(1 + s * p) / s ~= s * p / s = p,
  # whereas the naive computation produces NaN; patch it directly.
  result = tf.where(tf.math.equal(logits, 0.), p, result)
  # Finally handle `logits` and `p` jointly on the boundary, where the
  # expressions above can still form ratios of infs.
  on_boundary = (
      (tf.math.equal(logits, -np.inf) & tf.math.equal(log_p, 0.)) |
      (tf.math.equal(logits, np.inf) & tf.math.is_inf(log_p)))
  return tf.where(on_boundary, tf.ones_like(logits), result)
def _quantile(self, p, probs=None):
  """Quantile function expressed through the cut probabilities."""
  if probs is None:
    probs = self._probs_parameter_no_checks()
  cut_probs = self._cut_probs(probs)
  cut_logits = tf.math.log(cut_probs) - tf.math.log1p(-cut_probs)
  log_p = tf.math.log(p)
  # The quantile is log(1 + (e^s - 1) * p) / s, with s = `cut_logits`. For
  # large s the e^s sub-term is increasingly ill-conditioned; since the
  # numerator tends to s there, the s > 0 case is reformulated as an offset
  # from 1, which is more accurate and, coincidentally, removes a ratio of
  # infinities when s == +inf.
  positive_branch = 1. + tfp_math.log_add_exp(
      log_p + tfp_math.log1mexp(cut_logits), -cut_logits) / cut_logits
  negative_branch = tf.math.log1p(tf.math.expm1(cut_logits) * p) / cut_logits
  result = tf.where(cut_logits > 0., positive_branch, negative_branch)
  # Handle `cut_logits` and `p` jointly on the boundary, where the
  # expressions above can still form ratios of infs.
  on_boundary = (
      (tf.math.equal(cut_probs, 0.) & tf.math.equal(log_p, 0.)) |
      (tf.math.equal(cut_probs, 1.) & tf.math.is_inf(log_p)))
  result = tf.where(on_boundary, tf.ones_like(cut_probs), result)
  # Inside the cut limits the quantile is simply p; outside, use `result`.
  return tf.where(
      (probs < self._lims[0]) | (probs > self._lims[1]), result, p)
def _forward(self, x):
  """Smoothed floor: differentiable everywhere in x."""
  # On [0.5, 1.5] this is just a rescaled logit, hence differentiable.
  # Because the logit satisfies 1 - sigma(-x) = sigma(x), the derivative is
  # symmetric around the interval center (=1.), so it is continuous at the
  # endpoints as well.
  t = tf.convert_to_tensor(self.temperature)
  # Pin `0.5` to an explicit dtype so the numpy backend does not upcast;
  # declared once and reused everywhere below.
  half = tf.constant(0.5, self.dtype)
  int_part = tf.math.floor(x)
  frac_part = x - int_part
  # The function is defined on the repeated interval [0.5, 1.5];
  # x - floor(x) lands in [0, 1], but fractional parts below 0.5 belong on
  # the right-hand portion of the interval, so shift them up and decrement
  # the integer part to match.
  below_half = frac_part < half
  int_part = tf.where(
      below_half, int_part - tf.ones([], self.dtype), int_part)
  frac_part = tf.where(below_half, frac_part + half, frac_part - half)
  # Rescale so the left tail is 0. and the right tail is 1., which also
  # guarantees continuity; differentiability follows because the symmetric
  # sigmoid derivative makes both endpoint derivatives equal. The
  # calculation below is just
  #   (sigmoid((f - 0.5) / t) - sigmoid(-0.5 / t)) /
  #   (sigmoid(0.5 / t) - sigmoid(-0.5 / t))
  # computed via log_add_exp / log_sub_exp for numerical stability.
  at_integer = tf.equal(frac_part, 0.)
  neg_inf = tf.constant(-np.inf, self.dtype)
  log_num = tfp_math.log_sub_exp((half + frac_part) / t, half / t)
  # frac_part == 0 would give log(0); patch to -inf explicitly.
  log_num = tf.where(at_integer, neg_inf, log_num)
  log_den = tfp_math.log_sub_exp((half + frac_part) / t, frac_part / t)
  # frac_part == 0 would give log(0); patch to -inf explicitly.
  log_den = tf.where(at_integer, neg_inf, log_den)
  log_den = tfp_math.log_add_exp(
      log_den,
      tfp_math.log_sub_exp(tf.ones([], self.dtype) / t, half / t))
  rescaled = tf.math.exp(log_num - log_den)
  # Add sigmoid(-0.5 / t): as t -> inf this tends to 0.5, shifting the map
  # so it acts like the identity; as t -> 0 it tends to 0, so the map
  # correctly approximates a floor function.
  return int_part + rescaled + tf.math.sigmoid(-0.5 / t)
def _forward(self, x):
  """Smoothed floor: differentiable with respect to x everywhere."""
  # On [0.5, 1.5] this is just a rescaled logit and hence has a derivative.
  # Because the logit satisfies 1 - sigma(-x) = sigma(x), the derivative is
  # symmetric around the center of the interval (=1.), and hence continuous
  # at the endpoints.
  t = tf.convert_to_tensor(self.temperature)
  # Use dtype-explicit constants so Python float literals do not upcast the
  # computation in the numpy backend. Declare once, use everywhere.
  one_half = tf.constant(0.5, self.dtype)
  neg_inf = tf.constant(-np.inf, self.dtype)
  fractional_part = x - tf.math.floor(x)
  # The function is defined on the interval [0.5, 1.5], repeated;
  # x - floor(x) maps the input to [0, 1], but inputs whose fractional part
  # is < 0.5 must map to the right-hand portion of the interval, with the
  # integer part adjusted to reflect this.
  integer_part = tf.math.floor(x)
  integer_part = tf.where(
      fractional_part < one_half,
      integer_part - tf.ones([], self.dtype), integer_part)
  fractional_part = tf.where(
      fractional_part < one_half,
      fractional_part + one_half, fractional_part - one_half)
  # Rescale so the left tail is 0. and the right tail is 1., which also
  # guarantees continuity. Differentiability comes from the fact that the
  # derivative of the sigmoid is symmetric, so the two endpoints have the
  # same derivative value. The below calculations are just
  #   (sigmoid((f - 0.5) / t) - sigmoid(-0.5 / t)) /
  #   (sigmoid(0.5 / t) - sigmoid(-0.5 / t))
  # computed with log_add_exp / log_sub_exp for numerical stability.
  log_numerator = tfp_math.log_sub_exp(
      (one_half + fractional_part) / t, one_half / t)
  # If fractional_part == 0, we would get log(0); patch to -inf.
  log_numerator = tf.where(
      tf.equal(fractional_part, 0.), neg_inf, log_numerator)
  log_denominator = tfp_math.log_sub_exp(
      (one_half + fractional_part) / t, fractional_part / t)
  # If fractional_part == 0, we would get log(0); patch to -inf.
  log_denominator = tf.where(
      tf.equal(fractional_part, 0.), neg_inf, log_denominator)
  log_denominator = tfp_math.log_add_exp(
      log_denominator,
      tfp_math.log_sub_exp(tf.ones([], self.dtype) / t, one_half / t))
  rescaled_part = tf.math.exp(log_numerator - log_denominator)
  return integer_part + rescaled_part