def testCorrectlyAssertsSmallestPossibleInteger(self):
  with self.assertRaisesOpError('Elements cannot be smaller than 0.'):
    x = tf1.placeholder_with_default(
        np.array([1, -1], dtype=np.int32), shape=None)
    x_checked = distribution_util.embed_check_integer_casting_closed(
        x, target_dtype=tf.uint16, assert_nonnegative=False)
    self.evaluate(x_checked)
def testCorrectlyAssertsIntegerForm(self):
  with self.assertRaisesOpError('Elements must be int16-equivalent.'):
    x = tf1.placeholder_with_default(
        np.array([1, 1.5], dtype=np.float16), shape=None)
    x_checked = distribution_util.embed_check_integer_casting_closed(
        x, target_dtype=tf.int16)
    self.evaluate(x_checked)
def testCorrectlyAssertsLargestPossibleInteger(self):
  with self.assertRaisesOpError('Elements cannot exceed 32767.'):
    x = tf1.placeholder_with_default(
        np.array([1, 2**15], dtype=np.int32), shape=None)
    x_checked = distribution_util.embed_check_integer_casting_closed(
        x, target_dtype=tf.int16)
    self.evaluate(x_checked)
def testCorrectlyAssertsPositive(self):
  with self.assertRaisesOpError('Elements must be positive'):
    x = tf1.placeholder_with_default(
        np.array([1, 0], dtype=np.float16), shape=None)
    x_checked = distribution_util.embed_check_integer_casting_closed(
        x, target_dtype=tf.int16, assert_positive=True)
    self.evaluate(x_checked)
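# The four tests above pin down the contract of
# `embed_check_integer_casting_closed`: a cast into `target_dtype` is "closed"
# only if every element is integer-formed and lies within the target dtype's
# range (optionally strictly positive). A minimal NumPy sketch of that
# contract follows; `check_casting_closed` is a hypothetical helper for
# illustration, not the TFP implementation.
import numpy as np

def check_casting_closed(x, target_dtype, assert_positive=False):
  """Raises if casting `x` to `target_dtype` would lose information."""
  x = np.asarray(x)
  info = np.iinfo(target_dtype)
  if np.any(x != np.floor(x)):
    raise ValueError(
        'Elements must be %s-equivalent.' % np.dtype(target_dtype).name)
  if np.any(x > info.max):
    raise ValueError('Elements cannot exceed %d.' % info.max)
  smallest = 1 if assert_positive else info.min
  if np.any(x < smallest):
    raise ValueError('Elements cannot be smaller than %d.' % smallest)
  return x.astype(target_dtype)

check_casting_closed([1, 2**15 - 1], np.int16)    # passes: in range
# check_casting_closed([1, 2**15], np.int16)      # raises: exceeds 32767
# check_casting_closed([1.0, 1.5], np.int16)      # raises: not integer-formed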
def _log_prob(self, event):
  if self.validate_args:
    event = distribution_util.embed_check_integer_casting_closed(
        event, target_dtype=tf.bool)
  log_probs0, log_probs1 = self._outcome_log_probs()
  event = tf.cast(event, log_probs0.dtype)
  return event * (log_probs1 - log_probs0) + log_probs0
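# The return line above is the standard two-point blend: for event in {0, 1},
# event * (log_probs1 - log_probs0) + log_probs0 selects log P(X=1) when
# event == 1 and log P(X=0) when event == 0, with no branching. A quick NumPy
# check (the probabilities are made up for illustration):
import numpy as np

log_p0, log_p1 = np.log(0.3), np.log(0.7)
for event in (0., 1.):
  blended = event * (log_p1 - log_p0) + log_p0
  expected = log_p1 if event else log_p0
  assert np.isclose(blended, expected)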
def _mean(self):
  probs = self.probs
  outcomes = self.outcomes
  if dtype_util.is_integer(outcomes.dtype):
    if self._validate_args:
      outcomes = dist_util.embed_check_integer_casting_closed(
          outcomes, target_dtype=probs.dtype)
    outcomes = tf.cast(outcomes, dtype=probs.dtype)
  return tf.tensordot(outcomes, probs, axes=[[0], [-1]])
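# `tf.tensordot(outcomes, probs, axes=[[0], [-1]])` contracts the single
# outcomes axis against the last (category) axis of probs, i.e. it computes
# sum_i outcomes[i] * probs[..., i] for every batch element. NumPy equivalent
# for a single batch (the values are made up for illustration):
import numpy as np

outcomes = np.array([1., 2., 4.])
probs = np.array([0.2, 0.3, 0.5])
mean = np.tensordot(outcomes, probs, axes=[[0], [-1]])
assert np.isclose(mean, (outcomes * probs).sum())  # 1*0.2 + 2*0.3 + 4*0.5 = 2.8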
def _log_prob(self, k):
  logits = self.logits_parameter()
  if self.validate_args:
    k = distribution_util.embed_check_integer_casting_closed(
        k, target_dtype=tf.int32)
  k, logits = _broadcast_cat_event_and_params(
      k, logits, base_dtype=dtype_util.base_dtype(self.dtype))
  return -tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=k, logits=logits)
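# The negated sparse softmax cross entropy is exactly the log of the softmax
# probability of class k: -xent(k, logits) = log softmax(logits)[k], which is
# why this variant (and the gather-based variant below) are equivalent. A
# NumPy sketch of the identity (the logits are made up for illustration):
import numpy as np

logits = np.array([2.0, 1.0, 0.1])
k = 1
log_softmax = logits - np.log(np.sum(np.exp(logits)))
probs = np.exp(logits) / np.exp(logits).sum()
assert np.isclose(log_softmax[k], np.log(probs[k]))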
def _log_prob(self, k):
  with tf.name_scope("Cat2log_prob"):
    logits = self.logits_parameter()
    if self.validate_args:
      k = distribution_util.embed_check_integer_casting_closed(
          k, target_dtype=self.dtype)
    k, logits = _broadcast_cat_event_and_params(
        k, logits, base_dtype=dtype_util.base_dtype(self.dtype))
    # Normalize the logits into log-probabilities, then pick out the entry
    # for class k in each batch row.
    logits_normalised = tf.math.log(tf.math.softmax(logits))
    return tf.gather(logits_normalised, k, batch_dims=1)
def _log_prob(self, k):
  k = tf.convert_to_tensor(value=k, name="k")
  if self.validate_args:
    k = util.embed_check_integer_casting_closed(k, target_dtype=tf.int32)
  k, logits = _broadcast_cat_event_and_params(
      k, self.logits, base_dtype=self.dtype.base_dtype)
  return -tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=k, logits=logits)
def _variance(self):
  probs = self._categorical.probs
  outcomes = tf.broadcast_to(
      self.outcomes, shape=dist_util.prefer_static_shape(probs))
  if dtype_util.is_integer(outcomes.dtype):
    if self._validate_args:
      outcomes = dist_util.embed_check_integer_casting_closed(
          outcomes, target_dtype=probs.dtype)
    outcomes = tf.cast(outcomes, dtype=probs.dtype)
  square_d = tf.math.squared_difference(
      outcomes, tf.expand_dims(self.mean(), axis=-1))
  return tf.reduce_sum(input_tensor=probs * square_d, axis=-1)
def _variance(self):
  probs = self._categorical.probs_parameter()
  outcomes = tf.broadcast_to(self.outcomes, shape=ps.shape(probs))
  if dtype_util.is_integer(outcomes.dtype):
    if self._validate_args:
      outcomes = dist_util.embed_check_integer_casting_closed(
          outcomes, target_dtype=probs.dtype)
    outcomes = tf.cast(outcomes, dtype=probs.dtype)
  square_d = tf.math.squared_difference(
      outcomes, self._mean(probs)[..., tf.newaxis])
  return tf.reduce_sum(probs * square_d, axis=-1)
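# Both `_variance` variants implement Var[X] = sum_i p_i * (x_i - mean)^2,
# with broadcasting handling the batch dimensions. NumPy equivalent for one
# batch, checked against the E[X^2] - E[X]^2 identity (values made up for
# illustration):
import numpy as np

outcomes = np.array([1., 2., 4.])
probs = np.array([0.2, 0.3, 0.5])
mean = (outcomes * probs).sum()
variance = (probs * (outcomes - mean) ** 2).sum()
assert np.isclose(variance, (probs * outcomes**2).sum() - mean**2)  # 1.56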
def _log_prob(self, x):
  # The log probability at positive integer points x is log(x^(-power) / Z)
  # where Z is the normalization constant. For x < 1 and non-integer points,
  # the log-probability is -inf.
  #
  # However, if interpolate_nondiscrete is True, we return the natural
  # continuous relaxation for x >= 1 which agrees with the log probability at
  # positive integer points.
  #
  # If interpolate_nondiscrete is False and validate_args is True, we check
  # that the sample point x is in the support. That is, x is equivalent to a
  # positive integer.
  x = tf.cast(x, self.power.dtype)
  if self.validate_args and not self.interpolate_nondiscrete:
    x = distribution_util.embed_check_integer_casting_closed(
        x, target_dtype=self.dtype, assert_positive=True)
  return self._log_unnormalized_prob(x) - self._log_normalization()
def _cdf(self, k):
  k = tf.convert_to_tensor(k, name="k")
  if self.validate_args:
    k = util.embed_check_integer_casting_closed(k, target_dtype=tf.int32)
  k, probs = _broadcast_cat_event_and_params(
      k, self.probs, base_dtype=self.dtype.base_dtype)

  # Batch-flatten everything in order to use `sequence_mask()`.
  batch_flattened_probs = tf.reshape(probs, (-1, self._event_size))
  batch_flattened_k = tf.reshape(k, [-1])

  to_sum_over = tf.where(
      tf.sequence_mask(batch_flattened_k, self._event_size),
      batch_flattened_probs,
      tf.zeros_like(batch_flattened_probs))
  batch_flattened_cdf = tf.reduce_sum(to_sum_over, axis=-1)
  # Reshape back to the shape of the argument.
  return tf.reshape(batch_flattened_cdf, tf.shape(k))
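# The `sequence_mask` trick above sums exactly the probabilities of outcomes
# strictly below k: sequence_mask(k, n) is True for indices 0..k-1 and False
# from index k on, so the masked sum is P(X=0) + ... + P(X=k-1). NumPy
# equivalent for one batch row (values made up for illustration):
import numpy as np

probs = np.array([0.1, 0.2, 0.3, 0.4])
k = 2
mask = np.arange(probs.shape[-1]) < k   # [True, True, False, False]
cdf = np.where(mask, probs, 0.).sum()   # 0.1 + 0.2 = 0.3
assert np.isclose(cdf, probs[:k].sum())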
def _log_prob(self, event):
  if self.validate_args:
    event = util.embed_check_integer_casting_closed(
        event, target_dtype=tf.bool)

  # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
  # inconsistent behavior for logits = inf/-inf.
  event = tf.cast(event, self.logits.dtype)
  logits = self.logits
  # sigmoid_cross_entropy_with_logits doesn't broadcast shape,
  # so we do this here.
  def _broadcast(logits, event):
    return (tf.ones_like(event) * logits,
            tf.ones_like(logits) * event)

  if not (event.shape.is_fully_defined() and
          logits.shape.is_fully_defined() and
          event.shape == logits.shape):
    logits, event = _broadcast(logits, event)
  return -tf.nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)
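# As with the categorical case, the negated sigmoid cross entropy gives the
# Bernoulli log-likelihood in terms of logits:
# -xent(y, z) = y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z)). A NumPy check
# against the numerically stable form TF documents for this op,
# max(z, 0) - z*y + log(1 + exp(-|z|)) (the logit value is made up):
import numpy as np

def sigmoid(z):
  return 1. / (1. + np.exp(-z))

z, y = 0.7, 1.0
log_prob = y * np.log(sigmoid(z)) + (1. - y) * np.log(1. - sigmoid(z))
xent = max(z, 0.) - z * y + np.log1p(np.exp(-abs(z)))
assert np.isclose(log_prob, -xent)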
def _sample_n(self, n, seed=None):
  power = tf.convert_to_tensor(self.power)
  shape = tf.concat([[n], tf.shape(power)], axis=0)
  has_seed = seed is not None
  seed = SeedStream(seed, salt='zipf')

  minval_u = self._hat_integral(0.5, power=power) + 1.
  maxval_u = self._hat_integral(tf.int64.max - 0.5, power=power)

  def loop_body(should_continue, k):
    """Resample the non-accepted points."""
    # The range of U is chosen so that the resulting sample K lies in
    # [0, tf.int64.max). The final sample, if accepted, is K + 1.
    u = tf.random.uniform(
        shape,
        minval=minval_u,
        maxval=maxval_u,
        dtype=power.dtype,
        seed=seed())
    # Sample the point X from the continuous density h(x) \propto x^(-power).
    x = self._hat_integral_inverse(u, power=power)

    # Rejection-inversion requires a `hat` function, h(x) such that
    # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
    # support. A natural hat function for us is h(x) = x^(-power).
    #
    # After sampling X from h(x), suppose it lies in the interval
    # (K - .5, K + .5) for integer K. Then the corresponding K is accepted if
    # it lies to the left of x_K, where x_K is defined by:
    #   \int_{x_K}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
    # where H(x) = \int_x^inf h(x) dx.

    # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
    # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
    # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

    # Update the non-accepted points.
    # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
    k = tf.where(should_continue, tf.floor(x + 0.5), k)
    accept = (u <= self._hat_integral(k + .5, power=power) +
              tf.exp(self._log_prob(k + 1, power=power)))

    return [should_continue & (~accept), k]

  should_continue, samples = tf.while_loop(
      cond=lambda should_continue, *ignore: tf.reduce_any(should_continue),
      body=loop_body,
      loop_vars=[
          tf.ones(shape, dtype=tf.bool),  # should_continue
          tf.zeros(shape, dtype=power.dtype),  # k
      ],
      parallel_iterations=1 if has_seed else 10,
      maximum_iterations=self.sample_maximum_iterations,
  )
  samples = samples + 1.

  if self.validate_args and dtype_util.is_integer(self.dtype):
    samples = distribution_util.embed_check_integer_casting_closed(
        samples, target_dtype=self.dtype, assert_positive=True)

  samples = tf.cast(samples, self.dtype)

  if self.validate_args:
    npdt = dtype_util.as_numpy_dtype(self.dtype)
    v = npdt(dtype_util.min(npdt) if dtype_util.is_integer(npdt) else np.nan)
    samples = tf.where(should_continue, v, samples)

  return samples
def _sample_n(self, n, seed=None):
  power = tf.convert_to_tensor(self.power)
  shape = ps.concat([[n], ps.shape(power)], axis=0)
  numpy_dtype = dtype_util.as_numpy_dtype(power.dtype)
  seed = samplers.sanitize_seed(seed, salt='zipf')

  # Because `_hat_integral` is monotonically decreasing, the bounds for u will
  # switch.
  # Compute the hat_integral explicitly here since we can calculate the log of
  # the inputs statically in float64 with numpy.
  maxval_u = tf.math.exp(-(power - 1.) * numpy_dtype(np.log1p(0.5)) -
                         tf.math.log(power - 1.)) + 1.
  minval_u = tf.math.exp(
      -(power - 1.) * numpy_dtype(np.log1p(dtype_util.max(self.dtype) - 0.5)) -
      tf.math.log(power - 1.))

  def loop_body(should_continue, k, seed):
    """Resample the non-accepted points."""
    u_seed, next_seed = samplers.split_seed(seed)
    # Uniform variates must be sampled from the open-interval `(0, 1)` rather
    # than `[0, 1)`. To do so, we use
    # `np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny`
    # because it is the smallest, positive, 'normal' number. A 'normal' number
    # is such that the mantissa has an implicit leading 1. Normal, positive
    # numbers x, y have the reasonable property that, `x + y >= max(x, y)`. In
    # this case, a subnormal number (i.e., np.nextafter) can cause us to
    # sample 0.
    u = samplers.uniform(
        shape,
        minval=np.finfo(dtype_util.as_numpy_dtype(power.dtype)).tiny,
        maxval=numpy_dtype(1.),
        dtype=power.dtype,
        seed=u_seed)
    # We use (1 - u) * maxval_u + u * minval_u rather than the other way
    # around, since we want to draw samples in (minval_u, maxval_u].
    u = maxval_u + (minval_u - maxval_u) * u
    # set_shape needed here because of b/139013403
    tensorshape_util.set_shape(u, should_continue.shape)

    # Sample the point X from the continuous density h(x) \propto x^(-power).
    x = self._hat_integral_inverse(u, power=power)

    # Rejection-inversion requires a `hat` function, h(x) such that
    # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
    # support. A natural hat function for us is h(x) = x^(-power).
    #
    # After sampling X from h(x), suppose it lies in the interval
    # (K - .5, K + .5) for integer K. Then the corresponding K is accepted if
    # it lies to the left of x_K, where x_K is defined by:
    #   \int_{x_K}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
    # where H(x) = \int_x^inf h(x) dx.

    # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
    # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
    # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

    # Update the non-accepted points.
    # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
    k = tf.where(should_continue, tf.floor(x + 0.5), k)
    accept = (u <= self._hat_integral(k + .5, power=power) +
              tf.exp(self._log_prob(k + 1, power=power)))

    return [should_continue & (~accept), k, next_seed]

  should_continue, samples, _ = tf.while_loop(
      cond=lambda should_continue, *ignore: tf.reduce_any(should_continue),
      body=loop_body,
      loop_vars=[
          tf.ones(shape, dtype=tf.bool),  # should_continue
          tf.zeros(shape, dtype=power.dtype),  # k
          seed,  # seed
      ],
      maximum_iterations=self.sample_maximum_iterations,
  )
  samples = samples + 1.

  if self.validate_args and dtype_util.is_integer(self.dtype):
    samples = distribution_util.embed_check_integer_casting_closed(
        samples, target_dtype=self.dtype, assert_positive=True)

  samples = tf.cast(samples, self.dtype)

  if self.validate_args:
    npdt = dtype_util.as_numpy_dtype(self.dtype)
    v = npdt(dtype_util.min(npdt) if dtype_util.is_integer(npdt) else np.nan)
    samples = tf.where(should_continue, v, samples)

  return samples
def _sample_n(self, n, seed=None):
  shape = tf.concat([[n], self.batch_shape_tensor()], axis=0)
  has_seed = seed is not None
  seed = SeedStream(seed, salt="zipf")

  minval_u = self._hat_integral(0.5) + 1.
  maxval_u = self._hat_integral(tf.int64.max - 0.5)

  def loop_body(should_continue, k):
    """Resample the non-accepted points."""
    # The range of U is chosen so that the resulting sample K lies in
    # [0, tf.int64.max). The final sample, if accepted, is K + 1.
    u = tf.random_uniform(
        shape,
        minval=minval_u,
        maxval=maxval_u,
        dtype=self.power.dtype,
        seed=seed())
    # Sample the point X from the continuous density h(x) \propto x^(-power).
    x = self._hat_integral_inverse(u)

    # Rejection-inversion requires a `hat` function, h(x) such that
    # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
    # support. A natural hat function for us is h(x) = x^(-power).
    #
    # After sampling X from h(x), suppose it lies in the interval
    # (K - .5, K + .5) for integer K. Then the corresponding K is accepted if
    # it lies to the left of x_K, where x_K is defined by:
    #   \int_{x_K}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
    # where H(x) = \int_x^inf h(x) dx.

    # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
    # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
    # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

    # Update the non-accepted points.
    # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
    k = tf.where(should_continue, tf.floor(x + 0.5), k)
    accept = (u <= self._hat_integral(k + .5) + tf.exp(self._log_prob(k + 1)))

    return [should_continue & (~accept), k]

  should_continue, samples = tf.while_loop(
      cond=lambda should_continue, *ignore: tf.reduce_any(should_continue),
      body=loop_body,
      loop_vars=[
          tf.ones(shape, dtype=tf.bool),  # should_continue
          tf.zeros(shape, dtype=self.power.dtype),  # k
      ],
      parallel_iterations=1 if has_seed else 10,
      maximum_iterations=self.sample_maximum_iterations,
  )
  samples = samples + 1.

  if self.validate_args and self.dtype.is_integer:
    samples = distribution_util.embed_check_integer_casting_closed(
        samples, target_dtype=self.dtype, assert_positive=True)

  samples = tf.cast(samples, self.dtype)

  if self.validate_args:
    dt = self.dtype.as_numpy_dtype
    if self.dtype.is_integer:
      mask = tf.fill(shape, value=np.array(np.iinfo(dt).min, dtype=dt))
    else:
      mask = tf.fill(shape, value=np.array(np.nan, dtype=dt))
    samples = tf.where(should_continue, mask, samples)

  return samples
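# A self-contained NumPy sketch of the rejection-inversion scheme the
# `_sample_n` variants above implement, using the hat h(x) = x^(-power) with
# survival integral H(x) = \int_x^inf h(t) dt = x^(1-power) / (power - 1).
# This is an illustrative re-derivation under those assumptions, not the TFP
# implementation; `power` must be > 1 and `k_max` caps the support in this
# toy version.
import numpy as np

def zipf_rejection_inversion(power, n, k_max=10**6, rng=None):
  if rng is None:
    rng = np.random.default_rng(0)
  H = lambda x: x ** (1. - power) / (power - 1.)        # decreasing in x
  H_inv = lambda u: (u * (power - 1.)) ** (1. / (1. - power))
  pmf = lambda k: k ** (-power)                          # unnormalized Zipf pmf
  lo, hi = H(k_max + 0.5), H(0.5)                        # U = H(X), X in (0.5, k_max + 0.5)
  samples = np.empty(n)
  for i in range(n):
    while True:
      u = rng.uniform(lo, hi)
      x = H_inv(u)
      k = np.floor(x + 0.5)                              # nearest integer to X
      # Accept when X falls in the sub-interval of (K - .5, K + .5) whose hat
      # area equals pmf(K); equivalently, U <= H(K + .5) + pmf(K).
      if u <= H(k + 0.5) + pmf(k):
        samples[i] = k
        break
  return samples

# e.g. zipf_rejection_inversion(2.0, 5) draws five Zipf(2) variates; roughly
# half the mass should land on k = 1.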