def _parameter_control_dependencies(self, is_init):
  assertions = []

  logits = self._logits
  probs = self._probs
  param, name = (probs, 'probs') if logits is None else (logits, 'logits')

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    if not dtype_util.is_floating(param.dtype):
      raise TypeError('Argument `{}` must have floating type.'.format(name))

    msg = 'Argument `{}` must have rank at least 1.'.format(name)
    shape_static = tensorshape_util.dims(param.shape)
    if shape_static is not None:
      if len(shape_static) < 1:
        raise ValueError(msg)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_rank_at_least(param, 1, message=msg))
      with tf.control_dependencies(assertions):
        param = tf.identity(param)

    msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
    msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
        name, dtype_util.max(tf.int32))
    event_size = shape_static[-1] if shape_static is not None else None
    if event_size is not None:
      if event_size < 1:
        raise ValueError(msg1)
      if event_size > dtype_util.max(tf.int32):
        raise ValueError(msg2)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(assert_util.assert_greater_equal(
          tf.shape(param)[-1], 1, message=msg1))
      # NOTE: For now, we leave out a runtime assertion that
      # `tf.shape(param)[-1] <= tf.int32.max`. An earlier `tf.shape` call
      # will fail before we get to this point.

  if not self.validate_args:
    assert not assertions  # Should never happen.
    return []

  if probs is not None:
    probs = param  # reuse tensor conversion from above
    if is_init != tensor_util.is_ref(probs):
      probs = tf.convert_to_tensor(probs)
      one = tf.ones([], dtype=probs.dtype)
      assertions.extend([
          assert_util.assert_non_negative(probs),
          assert_util.assert_less_equal(probs, one),
          assert_util.assert_near(
              tf.reduce_sum(probs, axis=-1), one,
              message='Argument `probs` must sum to 1.'),
      ])

  return assertions
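# A minimal usage sketch of the checks built above, assuming this method backs
# `tfp.distributions.Categorical` (hypothetical illustration, not a test from
# this repo). With `validate_args=True` and constant `probs`, the sum-to-1
# assertion fires at construction; for `tf.Variable`-backed `probs` it is
# deferred, per the `is_init != tensor_util.is_ref(probs)` branching above.
import tensorflow as tf
import tensorflow_probability as tfp

try:
  dist = tfp.distributions.Categorical(probs=[0.3, 0.3], validate_args=True)
  dist.sample()
except tf.errors.InvalidArgumentError as e:
  print(e)  # 'Argument `probs` must sum to 1.'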
def test_assert_all_finite_input_finite(self):
  minval = tf.constant(dtype_util.min(self.dtype), dtype=self.dtype)
  maxval = tf.constant(dtype_util.max(self.dtype), dtype=self.dtype)

  # This tests if the minimum value for the dtype is detected as finite.
  self.assertAllFinite(minval)

  # This tests if the maximum value for the dtype is detected as finite.
  self.assertAllFinite(maxval)

  # This tests if a rank 3 `Tensor` with entries in the range
  # [0.4*minval, 0.4*maxval] is detected as finite.
  # The choice of range helps to avoid overflows or underflows
  # in tf.linspace calculations.
  num_elem = 1000
  shape = (10, 10, 10)
  a = tf.reshape(tf.linspace(0.4*minval, 0.4*maxval, num_elem), shape)
  self.assertAllFinite(a)
def testMax(self, dtype, expected_maxval):
  self.assertEqual(dtype_util.max(dtype), expected_maxval)
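# For context, `dtype_util.max` is expected to return the largest finite value
# representable by a dtype. A hypothetical numpy-based stand-in (not the real
# implementation) against which `testMax` above could be checked:
import numpy as np
import tensorflow as tf

def _dtype_max(dtype):
  """Largest representable finite value of `dtype` (hypothetical sketch)."""
  np_dtype = tf.as_dtype(dtype).as_numpy_dtype
  if np.issubdtype(np_dtype, np.integer):
    return np.iinfo(np_dtype).max
  return np.finfo(np_dtype).max

assert _dtype_max(tf.int32) == 2**31 - 1
assert _dtype_max(tf.float32) == np.finfo(np.float32).max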
def _sample_n(self, n, seed=None):
  power = tf.convert_to_tensor(self.power)
  shape = ps.concat([[n], ps.shape(power)], axis=0)
  numpy_dtype = dtype_util.as_numpy_dtype(power.dtype)
  seed = samplers.sanitize_seed(seed, salt='zipf')

  # Because `_hat_integral` is monotonically decreasing, the bounds for u
  # will switch.
  # Compute the hat_integral explicitly here since we can calculate the log of
  # the inputs statically in float64 with numpy.
  maxval_u = tf.math.exp(-(power - 1.) * numpy_dtype(np.log1p(0.5)) -
                         tf.math.log(power - 1.)) + 1.
  minval_u = tf.math.exp(
      -(power - 1.) * numpy_dtype(np.log1p(
          dtype_util.max(self.dtype) - 0.5)) - tf.math.log(power - 1.))

  def loop_body(should_continue, k, seed):
    """Resample the non-accepted points."""
    u_seed, next_seed = samplers.split_seed(seed)
    # Uniform variates must be sampled from the open-interval `(0, 1)` rather
    # than `[0, 1)`. To do so, we use
    # `np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny`
    # because it is the smallest, positive, 'normal' number. A 'normal' number
    # is such that the mantissa has an implicit leading 1. Normal, positive
    # numbers x, y have the reasonable property that `x + y >= max(x, y)`. In
    # this case, a subnormal number (i.e., np.nextafter) can cause us to
    # sample 0.
    u = samplers.uniform(
        shape,
        minval=np.finfo(dtype_util.as_numpy_dtype(power.dtype)).tiny,
        maxval=numpy_dtype(1.),
        dtype=power.dtype,
        seed=u_seed)
    # We use (1 - u) * maxval_u + u * minval_u rather than the other way
    # around, since we want to draw samples in (minval_u, maxval_u].
    u = maxval_u + (minval_u - maxval_u) * u
    # set_shape needed here because of b/139013403
    tensorshape_util.set_shape(u, should_continue.shape)

    # Sample the point X from the continuous density h(x) \propto x^(-power).
    x = self._hat_integral_inverse(u, power=power)

    # Rejection-inversion requires a `hat` function, h(x) such that
    # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
    # support. A natural hat function for us is h(x) = x^(-power).
    #
    # After sampling X from h(x), suppose it lies in the interval
    # (K - .5, K + .5) for integer K. Then the corresponding K is accepted if
    # it lies to the left of x_K, where x_K is defined by:
    #   \int_{x_K}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
    # where H(x) = \int_x^inf h(x) dx.

    # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
    # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
    # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

    # Update the non-accepted points.
    # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
    k = tf.where(should_continue, tf.floor(x + 0.5), k)
    accept = (u <= self._hat_integral(k + .5, power=power) +
              tf.exp(self._log_prob(k + 1, power=power)))

    return [should_continue & (~accept), k, next_seed]

  should_continue, samples, _ = tf.while_loop(
      cond=lambda should_continue, *ignore: tf.reduce_any(should_continue),
      body=loop_body,
      loop_vars=[
          tf.ones(shape, dtype=tf.bool),  # should_continue
          tf.zeros(shape, dtype=power.dtype),  # k
          seed,  # seed
      ],
      maximum_iterations=self.sample_maximum_iterations,
  )
  samples = samples + 1.

  if self.validate_args and dtype_util.is_integer(self.dtype):
    samples = distribution_util.embed_check_integer_casting_closed(
        samples, target_dtype=self.dtype, assert_positive=True)

  samples = tf.cast(samples, self.dtype)

  if self.validate_args:
    npdt = dtype_util.as_numpy_dtype(self.dtype)
    v = npdt(
        dtype_util.min(npdt) if dtype_util.is_integer(npdt) else np.nan)
    samples = tf.where(should_continue, v, samples)

  return samples
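# The rejection-inversion scheme above (Hormann & Derflinger) can be distilled
# into a short numpy sketch. Here `H` is the analytic hat integral
# H(x) = \int_x^inf t^(-power) dt = x^(1 - power) / (power - 1), and the
# acceptance test uses the unnormalized pmf k^(-power): only proportionality
# to the true Zipf pmf matters for correctness, and with it the `+ 1.`
# widening of the u-range exactly covers pmf(1). Hypothetical illustration,
# assuming power > 1; not the TFP implementation.
import numpy as np

def zipf_rejection_inversion(power, n, rng=None, kmax=2**31 - 1):
  rng = np.random.default_rng() if rng is None else rng
  H = lambda x: x**(1. - power) / (power - 1.)     # H(x) = int_x^inf h(t) dt
  H_inv = lambda u: (u * (power - 1.))**(1. / (1. - power))
  pmf = lambda k: k**(-power)                      # unnormalized Zipf pmf
  lo, hi = H(kmax + 0.5), H(0.5) + 1.              # mirrors minval_u/maxval_u
  samples = np.empty(n)
  for i in range(n):
    while True:
      u = rng.uniform(lo, hi)
      x = H_inv(u)              # X ~ density proportional to h(x) = x^(-power)
      k = np.floor(x + 0.5)     # candidate K; the returned sample is K + 1
      if u <= H(k + 0.5) + pmf(k + 1.):  # U <= H(K + .5) + pmf(K + 1)
        samples[i] = k + 1.
        break
  return samples

print(zipf_rejection_inversion(2., 10, rng=np.random.default_rng(0)))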
def _sample_n(self, n, seed=None):
  power = tf.convert_to_tensor(self.power)
  shape = tf.concat([[n], tf.shape(power)], axis=0)

  has_seed = seed is not None
  seed = SeedStream(seed, salt='zipf')

  minval_u = self._hat_integral(0.5, power=power) + 1.
  maxval_u = self._hat_integral(dtype_util.max(tf.int64) - 0.5, power=power)

  def loop_body(should_continue, k):
    """Resample the non-accepted points."""
    # The range of U is chosen so that the resulting sample K lies in
    # [0, tf.int64.max). The final sample, if accepted, is K + 1.
    u = tf.random.uniform(
        shape,
        minval=minval_u,
        maxval=maxval_u,
        dtype=power.dtype,
        seed=seed())

    # Sample the point X from the continuous density h(x) \propto x^(-power).
    x = self._hat_integral_inverse(u, power=power)

    # Rejection-inversion requires a `hat` function, h(x) such that
    # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
    # support. A natural hat function for us is h(x) = x^(-power).
    #
    # After sampling X from h(x), suppose it lies in the interval
    # (K - .5, K + .5) for integer K. Then the corresponding K is accepted if
    # it lies to the left of x_K, where x_K is defined by:
    #   \int_{x_K}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
    # where H(x) = \int_x^inf h(x) dx.

    # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
    # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
    # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

    # Update the non-accepted points.
    # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
    k = tf.where(should_continue, tf.floor(x + 0.5), k)
    accept = (u <= self._hat_integral(k + .5, power=power) +
              tf.exp(self._log_prob(k + 1, power=power)))

    return [should_continue & (~accept), k]

  should_continue, samples = tf.while_loop(
      cond=lambda should_continue, *ignore: tf.reduce_any(should_continue),
      body=loop_body,
      loop_vars=[
          tf.ones(shape, dtype=tf.bool),  # should_continue
          tf.zeros(shape, dtype=power.dtype),  # k
      ],
      parallel_iterations=1 if has_seed else 10,
      maximum_iterations=self.sample_maximum_iterations,
  )
  samples = samples + 1.

  if self.validate_args and dtype_util.is_integer(self.dtype):
    samples = distribution_util.embed_check_integer_casting_closed(
        samples, target_dtype=self.dtype, assert_positive=True)

  samples = tf.cast(samples, self.dtype)

  if self.validate_args:
    npdt = dtype_util.as_numpy_dtype(self.dtype)
    v = npdt(
        dtype_util.min(npdt) if dtype_util.is_integer(npdt) else np.nan)
    samples = tf.where(should_continue, v, samples)

  return samples
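# End to end, both samplers back `Zipf.sample`. A hedged usage sketch via the
# public API, with constructor arguments assumed from the attributes
# referenced above (`power` and `sample_maximum_iterations`):
import tensorflow_probability as tfp

zipf = tfp.distributions.Zipf(power=2., sample_maximum_iterations=100)
print(zipf.sample(5, seed=42))  # five draws from Zipf(power=2)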