def _sample_n(self, n, seed=None):
  power = tf.convert_to_tensor(self.power)
  shape = ps.concat([[n], ps.shape(power)], axis=0)
  numpy_dtype = dtype_util.as_numpy_dtype(power.dtype)
  seed = samplers.sanitize_seed(seed, salt='zipf')

  # Because `_hat_integral` is monotonically decreasing, the bounds for u
  # will switch.
  # Compute the hat_integral explicitly here since we can calculate the log
  # of the inputs statically in float64 with numpy.
  maxval_u = tf.math.exp(-(power - 1.) * numpy_dtype(np.log1p(0.5)) -
                         tf.math.log(power - 1.)) + 1.
  minval_u = tf.math.exp(
      -(power - 1.) * numpy_dtype(np.log1p(dtype_util.max(self.dtype) - 0.5))
      - tf.math.log(power - 1.))

  def loop_body(should_continue, k, seed):
    """Resample the non-accepted points."""
    u_seed, next_seed = samplers.split_seed(seed)
    # Uniform variates must be sampled from the open-interval `(0, 1)` rather
    # than `[0, 1)`. To do so, we use
    # `np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny` because it is
    # the smallest, positive, 'normal' number. A 'normal' number is such that
    # the mantissa has an implicit leading 1. Normal, positive numbers x, y
    # have the reasonable property that `x + y >= max(x, y)`. In this case, a
    # subnormal number (i.e., np.nextafter) can cause us to sample 0.
    u = samplers.uniform(
        shape,
        minval=np.finfo(dtype_util.as_numpy_dtype(power.dtype)).tiny,
        maxval=numpy_dtype(1.),
        dtype=power.dtype,
        seed=u_seed)
    # We use (1 - u) * maxval_u + u * minval_u rather than the other way
    # around, since we want to draw samples in (minval_u, maxval_u].
    u = maxval_u + (minval_u - maxval_u) * u
    # set_shape needed here because of b/139013403.
    tensorshape_util.set_shape(u, should_continue.shape)

    # Sample the point X from the continuous density h(x) \propto x^(-power).
    x = self._hat_integral_inverse(u, power=power)

    # Rejection-inversion requires a `hat` function, h(x) such that
    # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
    # support. A natural hat function for us is h(x) = x^(-power).
    #
    # After sampling X from h(x), suppose it lies in the interval
    # (K - .5, K + .5) for integer K. Then the corresponding K is accepted if
    # it lies to the left of x_K, where x_K is defined by:
    #   \int_{x_K}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
    # where H(x) = \int_x^inf h(x) dx.

    # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
    # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
    # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

    # Update the non-accepted points.
    # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
    k = tf.where(should_continue, tf.floor(x + 0.5), k)
    accept = (u <= self._hat_integral(k + .5, power=power) +
              tf.exp(self._log_prob(k + 1, power=power)))

    return [should_continue & (~accept), k, next_seed]

  should_continue, samples, _ = tf.while_loop(
      cond=lambda should_continue, *ignore: tf.reduce_any(should_continue),
      body=loop_body,
      loop_vars=[
          tf.ones(shape, dtype=tf.bool),  # should_continue
          tf.zeros(shape, dtype=power.dtype),  # k
          seed,  # seed
      ],
      maximum_iterations=self.sample_maximum_iterations,
  )
  samples = samples + 1.

  if self.validate_args and dtype_util.is_integer(self.dtype):
    samples = distribution_util.embed_check_integer_casting_closed(
        samples, target_dtype=self.dtype, assert_positive=True)

  samples = tf.cast(samples, self.dtype)

  if self.validate_args:
    npdt = dtype_util.as_numpy_dtype(self.dtype)
    v = npdt(dtype_util.min(npdt) if dtype_util.is_integer(npdt) else np.nan)
    samples = tf.where(should_continue, v, samples)

  return samples
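The comment thread above compresses Hörmann-and-Derflinger-style rejection-inversion into a few identities. As a reading aid, here is a minimal standalone NumPy sketch of the same scheme, not the library's implementation: `zipf_rejection_inversion` and the `2.**53` right-tail cutoff are illustrative stand-ins for the dtype-dependent bounds, and SciPy's `zeta` supplies the normalizer.

```python
import numpy as np
from scipy.special import zeta


def zipf_rejection_inversion(power, n, seed=0):
  """Sketch: rejection-inversion sampling from Zipf(power), power > 1."""
  a = power
  # Shifted hat: H(x) = \int_x^inf (t + 1)^(-a) dt = (x + 1)^(1 - a) / (a - 1)
  hat_integral = lambda x: (x + 1.)**(1. - a) / (a - 1.)
  hat_integral_inv = lambda u: (u * (a - 1.))**(1. / (1. - a)) - 1.
  pmf = lambda k: k**(-a) / zeta(a)  # Zipf pmf on k = 1, 2, ...

  rng = np.random.default_rng(seed)
  lo = hat_integral(2.**53 - 0.5)  # Illustrative right-tail cutoff.
  hi = hat_integral(0.5) + 1.      # The +1 makes k = 0 reachable.
  samples = np.empty(n)
  for i in range(n):
    while True:
      u = lo + (hi - lo) * rng.uniform()  # U in (lo, hi).
      x = hat_integral_inv(u)             # X ~ density proportional to hat.
      k = np.floor(x + 0.5)               # Candidate integer.
      # Accept iff U <= H(K + .5) + pmf(K + 1), as derived above.
      if u <= hat_integral(k + 0.5) + pmf(k + 1.):
        samples[i] = k + 1.               # Shift to the support {1, 2, ...}.
        break
  return samples
```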
def _cast_dtype(dtype):
  if dtype_util.as_numpy_dtype(dtype) is np.int64:
    return tf.float64
  elif dtype_util.is_integer(dtype):
    return tf.float32
  return dtype
def testIsInteger(self):
  self.assertFalse(dtype_util.is_integer(np.float64))
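A natural companion assertion, sketched under the assumption that `dtype_util.is_integer` accepts the same NumPy type objects as the test above:

```python
def testIsIntegerTrueForInts(self):
  self.assertTrue(dtype_util.is_integer(np.int32))
  self.assertTrue(dtype_util.is_integer(np.int64))
```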
def __init__(self, permutation, axis=-1, validate_args=False, name=None):
  """Creates the `Permute` bijector.

  Args:
    permutation: An `int`-like vector-shaped `Tensor` representing the
      permutation to apply to the `axis` dimension of the transformed
      `Tensor`.
    axis: Scalar `int` `Tensor` representing the dimension over which to
      `tf.gather`. `axis` must be relative to the end (reading left to right)
      thus must be negative.
      Default value: `-1` (i.e., right-most).
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str`, name given to ops managed by this object.

  Raises:
    TypeError: if `not dtype_util.is_integer(permutation.dtype)`.
    ValueError: if `permutation` does not contain exactly one of each of
      `{0, 1, ..., d}`.
    NotImplementedError: if `axis` is not known prior to graph execution.
    NotImplementedError: if `axis` is not negative.
  """
  with tf.name_scope(name or "permute") as name:
    axis = tf.convert_to_tensor(axis, name="axis")
    if not dtype_util.is_integer(axis.dtype):
      raise TypeError("axis.dtype ({}) should be `int`-like.".format(
          dtype_util.name(axis.dtype)))
    permutation = tf.convert_to_tensor(permutation, name="permutation")
    if not dtype_util.is_integer(permutation.dtype):
      raise TypeError("permutation.dtype ({}) should be `int`-like.".format(
          dtype_util.name(permutation.dtype)))
    p = tf.get_static_value(permutation)
    if p is not None:
      if set(p) != set(np.arange(p.size)):
        raise ValueError("Permutation over `d` must contain exactly one of "
                         "each of `{0, 1, ..., d}`.")
    elif validate_args:
      p, _ = tf.math.top_k(
          -permutation, k=tf.shape(permutation)[-1], sorted=True)
      permutation = distribution_util.with_dependencies([
          assert_util.assert_equal(
              -p, tf.range(tf.size(p)),
              message=("Permutation over `d` must contain exactly one of "
                       "each of `{0, 1, ..., d}`.")),
      ], permutation)
    axis_ = tf.get_static_value(axis)
    if axis_ is None:
      raise NotImplementedError(
          "`axis` must be known prior to graph execution.")
    elif axis_ >= 0:
      raise NotImplementedError(
          "`axis` must be relative to the rightmost dimension, i.e., "
          "negative.")
    else:
      forward_min_event_ndims = int(np.abs(axis_))
    self._permutation = permutation
    self._axis = axis
    super(Permute, self).__init__(
        forward_min_event_ndims=forward_min_event_ndims,
        is_constant_jacobian=True,
        validate_args=validate_args,
        name=name)
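For reference, a short usage sketch of the resulting bijector (the standard `tfp.bijectors` alias is assumed):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# Reverse the rightmost dimension of a batch of 3-vectors; `inverse`
# applies the inverse permutation, so the round trip is the identity.
reverse = tfb.Permute(permutation=[2, 1, 0])
x = tf.constant([[-1., 0., 1.]])
y = reverse.forward(x)        # ==> [[1., 0., -1.]]
x_again = reverse.inverse(y)  # ==> [[-1., 0., 1.]]
```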
def percentile(x,
               q,
               axis=None,
               interpolation=None,
               keepdims=False,
               validate_args=False,
               preserve_gradients=True,
               keep_dims=None,
               name=None):
  """Compute the `q`-th percentile(s) of `x`.

  Given a vector `x`, the `q`-th percentile of `x` is the value `q / 100` of
  the way from the minimum to the maximum in a sorted copy of `x`.

  The values and distances of the two nearest neighbors as well as the
  `interpolation` parameter will determine the percentile if the normalized
  ranking does not match the location of `q` exactly.

  This function is the same as the median if `q = 50`, the same as the
  minimum if `q = 0` and the same as the maximum if `q = 100`.

  Multiple percentiles can be computed at once by using `1-D` vector `q`.
  Dimension zero of the returned `Tensor` will index the different
  percentiles.

  Compare to `numpy.percentile`.

  Args:
    x: Numeric `N-D` `Tensor` with `N > 0`. If `axis` is not `None`, `x` must
      have statically known number of dimensions.
    q: Scalar or vector `Tensor` with values in `[0, 100]`. The percentile(s).
    axis: Optional `0-D` or `1-D` integer `Tensor` with constant values. The
      axes that index independent samples over which to return the desired
      percentile. If `None` (the default), treat every dimension as a sample
      dimension, returning a scalar.
    interpolation: {'nearest', 'linear', 'lower', 'higher', 'midpoint'}.
      Default value: 'nearest'. This specifies the interpolation method to
      use when the desired quantile lies between two data points `i < j`:
        * linear: i + (j - i) * fraction, where fraction is the fractional
          part of the index surrounded by i and j.
        * lower: `i`.
        * higher: `j`.
        * nearest: `i` or `j`, whichever is nearest.
        * midpoint: (i + j) / 2.
      `linear` and `midpoint` interpolation do not work with integer dtypes.
    keepdims: Python `bool`. If `True`, the last dimension is kept with size
      1. If `False`, the last dimension is removed from the output shape.
    validate_args: Whether to add runtime checks of argument validity. If
      False, and arguments are incorrect, correct behavior is not guaranteed.
    preserve_gradients: Python `bool`. If `True`, ensure that the gradient
      w.r.t the percentile `q` is preserved in the case of linear
      interpolation. If `False`, the gradient will be (incorrectly) zero when
      `q` corresponds to a point in `x`.
    keep_dims: Deprecated; use `keepdims` instead.
    name: A Python string name to give this `Op`. Default is 'percentile'.

  Returns:
    A `(rank(q) + N - len(axis))` dimensional `Tensor` of same dtype as `x`,
    or, if `axis` is `None`, a `rank(q)` `Tensor`. The first `rank(q)`
    dimensions index quantiles for different values of `q`.

  Raises:
    ValueError: If argument 'interpolation' is not an allowed type.
    ValueError: If interpolation type not compatible with `dtype`.

  #### Examples

  ```python
  # Get 30th percentile with default ('nearest') interpolation.
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=30.)
  ==> 2.0

  # Get 30th percentile with 'linear' interpolation.
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=30., interpolation='linear')
  ==> 1.9

  # Get 30th and 70th percentiles with 'lower' interpolation.
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=[30., 70.], interpolation='lower')
  ==> [1., 3.]

  # Get 100th percentile (maximum). By default, this is computed over every
  # dim.
  x = [[1., 2.],
       [3., 4.]]
  tfp.stats.percentile(x, q=100.)
  ==> 4.

  # Treat the leading dim as indexing samples, and find the 100th quantile
  # (max) over all such samples.
  x = [[1., 2.],
       [3., 4.]]
  tfp.stats.percentile(x, q=100., axis=[0])
  ==> [3., 4.]
  ```
  """
  keepdims = keepdims if keep_dims is None else keep_dims
  del keep_dims
  name = name or 'percentile'
  allowed_interpolations = {'linear', 'lower', 'higher', 'nearest',
                            'midpoint'}
  if interpolation is None:
    interpolation = 'nearest'
  elif interpolation not in allowed_interpolations:
    raise ValueError(
        'Argument `interpolation` must be in {}. Found {}.'.format(
            allowed_interpolations, interpolation))

  with tf.name_scope(name):
    x = tf.convert_to_tensor(x, name='x')

    if (interpolation in {'linear', 'midpoint'} and
        dtype_util.is_integer(x.dtype)):
      raise TypeError('{} interpolation not allowed with dtype {}'.format(
          interpolation, x.dtype))

    # Double is needed here and below, else we get the wrong index if the
    # array is huge along axis.
    q = tf.cast(q, tf.float64)
    _get_static_ndims(q, expect_ndims_no_more_than=1)

    if validate_args:
      q = distribution_util.with_dependencies([
          assert_util.assert_rank_in(q, [0, 1]),
          assert_util.assert_greater_equal(q, tf.cast(0., tf.float64)),
          assert_util.assert_less_equal(q, tf.cast(100., tf.float64))
      ], q)

    # Move `axis` dims of `x` to the rightmost, call it `y`.
    if axis is None:
      y = tf.reshape(x, [-1])
    else:
      x_ndims = _get_static_ndims(
          x, expect_static=True, expect_ndims_at_least=1)
      axis = _make_static_axis_non_negative_list(axis, x_ndims)
      y = _move_dims_to_flat_end(x, axis, x_ndims, right_end=True)

    frac_at_q_or_above = 1. - q / 100.

    # Sort everything, not just the top 'k' entries, which allows multiple
    # calls to sort only once (under the hood) and use CSE.
    sorted_y = _sort_tensor(y)

    d = tf.cast(tf.shape(y)[-1], tf.float64)

    def _get_indices(interp_type):
      """Get values of y at the indices implied by interp_type."""
      # Note `lower` <--> ceiling. Confusing, huh? Due to the fact that
      # `_sort_tensor` sorts highest to lowest, tf.ceil corresponds to the
      # higher index, but the lower value of y!
      if interp_type == 'lower':
        indices = tf.math.ceil((d - 1) * frac_at_q_or_above)
      elif interp_type == 'higher':
        indices = tf.floor((d - 1) * frac_at_q_or_above)
      elif interp_type == 'nearest':
        indices = tf.round((d - 1) * frac_at_q_or_above)
      # d - 1 will be distinct from d in int32, but not necessarily double.
      # So clip to avoid out of bounds errors.
      return tf.clip_by_value(
          tf.cast(indices, tf.int32), 0, tf.shape(y)[-1] - 1)

    if interpolation in ['nearest', 'lower', 'higher']:
      gathered_y = tf.gather(sorted_y, _get_indices(interpolation), axis=-1)
    elif interpolation == 'midpoint':
      gathered_y = 0.5 * (
          tf.gather(sorted_y, _get_indices('lower'), axis=-1) +
          tf.gather(sorted_y, _get_indices('higher'), axis=-1))
    elif interpolation == 'linear':
      # Copy-paste of docstring on interpolation:
      # linear: i + (j - i) * fraction, where fraction is the fractional part
      # of the index surrounded by i and j.
      larger_y_idx = _get_indices('lower')
      exact_idx = (d - 1) * frac_at_q_or_above
      if preserve_gradients:
        # If q corresponds to a point in x, we will initially have
        # larger_y_idx == smaller_y_idx.
        # This results in the gradient w.r.t. fraction being zero (recall `q`
        # enters only through `fraction`...and see that things cancel).
        # The fix is to ensure that smaller_y_idx and larger_y_idx are always
        # separated by exactly 1.
        smaller_y_idx = tf.maximum(larger_y_idx - 1, 0)
        larger_y_idx = tf.minimum(smaller_y_idx + 1, tf.shape(y)[-1] - 1)
        fraction = tf.cast(larger_y_idx, tf.float64) - exact_idx
      else:
        smaller_y_idx = _get_indices('higher')
        fraction = tf.math.ceil((d - 1) * frac_at_q_or_above) - exact_idx

      fraction = tf.cast(fraction, y.dtype)
      gathered_y = (
          tf.gather(sorted_y, larger_y_idx, axis=-1) * (1 - fraction) +
          tf.gather(sorted_y, smaller_y_idx, axis=-1) * fraction)

    # Propagate NaNs.
    if x.dtype in (tf.bfloat16, tf.float16, tf.float32, tf.float64):
      # Apparently tf.is_nan doesn't like other dtypes.
      nan_batch_members = tf.reduce_any(tf.math.is_nan(x), axis=axis)
      right_rank_matched_shape = tf.pad(
          tf.shape(nan_batch_members),
          paddings=[[0, tf.rank(q)]],
          constant_values=1)
      nan_batch_members = tf.reshape(
          nan_batch_members, shape=right_rank_matched_shape)
      nan = np.array(np.nan, dtype_util.as_numpy_dtype(gathered_y.dtype))
      gathered_y = tf.where(nan_batch_members, nan, gathered_y)

    # Expand dimensions if requested.
    if keepdims:
      if axis is None:
        ones_vec = tf.ones(
            shape=[_get_best_effort_ndims(x) + _get_best_effort_ndims(q)],
            dtype=tf.int32)
        gathered_y *= tf.ones(ones_vec, dtype=x.dtype)
      else:
        gathered_y = _insert_back_keepdims(gathered_y, axis)

    # If q is a scalar, then result has the right shape.
    # If q is a vector, then result has trailing dim of shape q.shape, which
    # needs to be rotated to dim 0.
    return distribution_util.rotate_transpose(gathered_y, tf.rank(q))
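The `preserve_gradients` branch is the subtle part: when `exact_idx` lands exactly on an integer, the naive index pair collapses and the gradient w.r.t. `q` vanishes. A minimal NumPy sketch of just that branch (a hypothetical standalone helper; the descending sort mirrors `_sort_tensor`):

```python
import numpy as np


def linear_percentile(x, q):
  """Sketch of 'linear' interpolation with the index-separation trick."""
  y = np.sort(x)[::-1]              # Descending, like `_sort_tensor`.
  d = y.shape[-1]
  frac_at_q_or_above = 1. - q / 100.
  exact_idx = (d - 1) * frac_at_q_or_above
  larger = int(np.ceil(exact_idx))  # ceil <--> the *lower* value of y.
  smaller = max(larger - 1, 0)      # Force the pair to differ by exactly 1,
  larger = min(smaller + 1, d - 1)  # so `fraction` keeps a nonzero gradient.
  fraction = larger - exact_idx
  return y[larger] * (1 - fraction) + y[smaller] * fraction


print(linear_percentile(np.array([1., 2., 3., 4.]), 30.))  # ==> 1.9
```

The printed value matches the `interpolation='linear'` docstring example above.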
def _parameter_control_dependencies(self, is_init):
  assertions = []

  # Check num_steps is a scalar that's at least 1.
  if is_init != tensor_util.is_ref(self.num_steps):
    num_steps = tf.convert_to_tensor(self.num_steps)
    num_steps_ = tf.get_static_value(num_steps)
    if num_steps_ is not None:
      if np.ndim(num_steps_) != 0:
        raise ValueError(
            '`num_steps` must be a scalar but it has rank {}'.format(
                np.ndim(num_steps_)))
      if num_steps_ < 1:
        raise ValueError('`num_steps` must be at least 1.')
    elif self.validate_args:
      message = '`num_steps` must be a scalar'
      assertions.append(
          assert_util.assert_rank_at_most(self.num_steps, 0, message=message))
      assertions.append(
          assert_util.assert_greater_equal(
              num_steps, 1, message='`num_steps` must be at least 1.'))

  # Check that the initial distribution has scalar events over the integers.
  if is_init and not dtype_util.is_integer(self.initial_distribution.dtype):
    raise ValueError(
        '`initial_distribution.dtype` ({}) is not over integers'.format(
            dtype_util.name(self.initial_distribution.dtype)))

  if tensorshape_util.rank(self.initial_distribution.event_shape) is not None:
    if tensorshape_util.rank(self.initial_distribution.event_shape) != 0:
      raise ValueError('`initial_distribution` must have scalar `event_dim`s')
  elif self.validate_args:
    assertions += [
        assert_util.assert_equal(
            tf.size(self.initial_distribution.event_shape_tensor()),
            0,
            message='`initial_distribution` must have scalar `event_dim`s'),
    ]

  # Check that the transition distribution is over the integers.
  if (is_init and
      not dtype_util.is_integer(self.transition_distribution.dtype)):
    raise ValueError(
        '`transition_distribution.dtype` ({}) is not over integers'.format(
            dtype_util.name(self.transition_distribution.dtype)))

  # Check observations have non-scalar batches.
  # The graph version of this assertion is incorporated as
  # a control dependency of the transition/observation
  # compatibility test.
  if tensorshape_util.rank(self.observation_distribution.batch_shape) == 0:
    raise ValueError("`observation_distribution` can't have scalar batches")

  # Check transitions have non-scalar batches.
  # The graph version of this assertion is incorporated as
  # a control dependency of the transition/observation
  # compatibility test.
  if tensorshape_util.rank(self.transition_distribution.batch_shape) == 0:
    raise ValueError("`transition_distribution` can't have scalar batches")

  # Check compatibility of transition distribution and observation
  # distribution.
  tdbs = self.transition_distribution.batch_shape
  odbs = self.observation_distribution.batch_shape
  if (tensorshape_util.dims(tdbs) is not None and
      tf.compat.dimension_value(odbs[-1]) is not None):
    if (tf.compat.dimension_value(tdbs[-1]) !=
        tf.compat.dimension_value(odbs[-1])):
      raise ValueError('`transition_distribution` and '
                       '`observation_distribution` must agree on last '
                       'dimension of batch size')
  elif self.validate_args:
    tdbs = self.transition_distribution.batch_shape_tensor()
    odbs = self.observation_distribution.batch_shape_tensor()
    transition_precondition = assert_util.assert_greater(
        tf.size(tdbs), 0,
        message=('`transition_distribution` can\'t have scalar batches'))
    observation_precondition = assert_util.assert_greater(
        tf.size(odbs), 0,
        message=('`observation_distribution` can\'t have scalar batches'))
    with tf.control_dependencies(
        [transition_precondition, observation_precondition]):
      assertions += [
          assert_util.assert_equal(
              tdbs[-1], odbs[-1],
              message=('`transition_distribution` and '
                       '`observation_distribution` must agree on last '
                       'dimension of batch size'))
      ]

  return assertions
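As a concrete instance that passes every check above, here is the familiar two-state weather model (a sketch using the public `tfd` alias): the transition and observation distributions share a final batch dimension of 2, and the initial distribution has scalar integer events.

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

# Two hidden states; batch_shape[-1] == 2 for both the transition and the
# observation distributions, and `initial_distribution` has scalar events.
hmm = tfd.HiddenMarkovModel(
    initial_distribution=tfd.Categorical(probs=[0.8, 0.2]),
    transition_distribution=tfd.Categorical(probs=[[0.7, 0.3],
                                                   [0.2, 0.8]]),
    observation_distribution=tfd.Normal(loc=[0., 15.], scale=[5., 10.]),
    num_steps=7)
hmm.sample()  # A length-7 sequence of observations.
```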
def _sample_n(self, n, seed=None):
  shape = tf.concat([[n], self.batch_shape_tensor()], axis=0)
  has_seed = seed is not None
  seed = SeedStream(seed, salt="zipf")

  minval_u = self._hat_integral(0.5) + 1.
  maxval_u = self._hat_integral(tf.int64.max - 0.5)

  def loop_body(should_continue, k):
    """Resample the non-accepted points."""
    # The range of U is chosen so that the resulting sample K lies in
    # [0, tf.int64.max). The final sample, if accepted, is K + 1.
    u = tf.random.uniform(
        shape,
        minval=minval_u,
        maxval=maxval_u,
        dtype=self.power.dtype,
        seed=seed())

    # Sample the point X from the continuous density h(x) \propto x^(-power).
    x = self._hat_integral_inverse(u)

    # Rejection-inversion requires a `hat` function, h(x) such that
    # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
    # support. A natural hat function for us is h(x) = x^(-power).
    #
    # After sampling X from h(x), suppose it lies in the interval
    # (K - .5, K + .5) for integer K. Then the corresponding K is accepted if
    # it lies to the left of x_K, where x_K is defined by:
    #   \int_{x_K}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
    # where H(x) = \int_x^inf h(x) dx.

    # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
    # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
    # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

    # Update the non-accepted points.
    # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
    k = tf.where(should_continue, tf.floor(x + 0.5), k)
    accept = (u <= self._hat_integral(k + .5) + tf.exp(self._log_prob(k + 1)))

    return [should_continue & (~accept), k]

  should_continue, samples = tf.while_loop(
      cond=lambda should_continue, *ignore: tf.reduce_any(should_continue),
      body=loop_body,
      loop_vars=[
          tf.ones(shape, dtype=tf.bool),  # should_continue
          tf.zeros(shape, dtype=self.power.dtype),  # k
      ],
      parallel_iterations=1 if has_seed else 10,
      maximum_iterations=self.sample_maximum_iterations,
  )
  samples = samples + 1.

  if self.validate_args and dtype_util.is_integer(self.dtype):
    samples = distribution_util.embed_check_integer_casting_closed(
        samples, target_dtype=self.dtype, assert_positive=True)

  samples = tf.cast(samples, self.dtype)

  if self.validate_args:
    npdt = dtype_util.as_numpy_dtype(self.dtype)
    v = npdt(dtype_util.min(npdt) if dtype_util.is_integer(npdt) else np.nan)
    samples = tf.where(should_continue, v, samples)

  return samples
def _parameter_control_dependencies(self, is_init):
  assertions = []

  if is_init and not dtype_util.is_integer(self.mixture_distribution.dtype):
    raise ValueError(
        '`mixture_distribution.dtype` ({}) is not over integers'.format(
            dtype_util.name(self.mixture_distribution.dtype)))

  if tensorshape_util.rank(self.mixture_distribution.event_shape) is not None:
    if tensorshape_util.rank(self.mixture_distribution.event_shape) != 0:
      raise ValueError('`mixture_distribution` must have scalar `event_dim`s')
  elif self.validate_args:
    assertions += [
        assert_util.assert_equal(
            tf.size(self.mixture_distribution.event_shape_tensor()),
            0,
            message='`mixture_distribution` must have scalar `event_dim`s'),
    ]

  # pylint: disable=protected-access
  mixture_dist_param = (self.mixture_distribution._probs
                        if self.mixture_distribution._logits is None
                        else self.mixture_distribution._logits)
  km = tf.compat.dimension_value(
      tensorshape_util.with_rank_at_least(mixture_dist_param.shape, 1)[-1])
  kc = tf.compat.dimension_value(
      tensorshape_util.with_rank_at_least(
          self.components_distribution.batch_shape, 1)[-1])
  component_bst = None
  if km is not None and kc is not None:
    if km != kc:
      raise ValueError('`mixture_distribution` components ({}) does not '
                       'equal `components_distribution.batch_shape[-1]` '
                       '({})'.format(km, kc))
  elif self.validate_args:
    if km is None:
      mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
      km = tf.shape(mixture_dist_param)[-1]
    if kc is None:
      component_bst = self.components_distribution.batch_shape_tensor()
      kc = component_bst[-1]
    assertions += [
        assert_util.assert_equal(
            km,
            kc,
            message=('`mixture_distribution` components does not equal '
                     '`components_distribution.batch_shape[-1]`')),
    ]

  mdbs = self.mixture_distribution.batch_shape
  cdbs = tensorshape_util.with_rank_at_least(
      self.components_distribution.batch_shape, 1)[:-1]
  if (tensorshape_util.is_fully_defined(mdbs) and
      tensorshape_util.is_fully_defined(cdbs)):
    if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
      raise ValueError(
          '`mixture_distribution.batch_shape` (`{}`) is not '
          'compatible with `components_distribution.batch_shape` '
          '(`{}`)'.format(tensorshape_util.as_list(mdbs),
                          tensorshape_util.as_list(cdbs)))
  elif self.validate_args:
    if not tensorshape_util.is_fully_defined(mdbs):
      mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
      mdbs = tf.shape(mixture_dist_param)[:-1]
    if not tensorshape_util.is_fully_defined(cdbs):
      if component_bst is None:
        component_bst = self.components_distribution.batch_shape_tensor()
      cdbs = component_bst[:-1]
    assertions += [
        assert_util.assert_equal(
            distribution_utils.pick_vector(
                tf.equal(tf.shape(mdbs)[0], 0), cdbs, mdbs),
            cdbs,
            message=('`mixture_distribution.batch_shape` is not '
                     'compatible with `components_distribution.batch_shape`'))
    ]

  return assertions
def __init__(self,
             mixture_distribution,
             components_distribution,
             reparameterize=False,
             validate_args=False,
             allow_nan_stats=True,
             name="MixtureSameFamily"):
  """Construct a `MixtureSameFamily` distribution.

  Args:
    mixture_distribution: `tfp.distributions.Categorical`-like instance.
      Manages the probability of selecting components. The number of
      categories must match the rightmost batch dimension of the
      `components_distribution`. Must have either scalar `batch_shape` or
      `batch_shape` matching `components_distribution.batch_shape[:-1]`.
    components_distribution: `tfp.distributions.Distribution`-like instance.
      Right-most batch dimension indexes components.
    reparameterize: Python `bool`, default `False`. Whether to reparameterize
      samples of the distribution using implicit reparameterization gradients
      [(Figurnov et al., 2018)][1]. The gradients for the mixture logits are
      equivalent to the ones described by [(Graves, 2016)][2]. The gradients
      for the components parameters are also computed using implicit
      reparameterization (as opposed to ancestral sampling), meaning that all
      components are updated every step.
      Only works when:
        (1) components_distribution is fully reparameterized;
        (2) components_distribution is either a scalar distribution or fully
            factorized (tfd.Independent applied to a scalar distribution);
        (3) batch shape has a known rank.
      Experimental, may be slow and produce infs/NaNs.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: `if not dtype_util.is_integer(mixture_distribution.dtype)`.
    ValueError: if mixture_distribution does not have scalar `event_shape`.
    ValueError: if `mixture_distribution.batch_shape` and
      `components_distribution.batch_shape[:-1]` are both fully defined and
      the former is neither scalar nor equal to the latter.
    ValueError: if the number of `mixture_distribution` categories does not
      equal the `components_distribution` rightmost batch shape.

  #### References

  [1]: Michael Figurnov, Shakir Mohamed and Andriy Mnih. Implicit
       reparameterization gradients. In _Neural Information Processing
       Systems_, 2018. https://arxiv.org/abs/1805.08498
  [2]: Alex Graves. Stochastic Backpropagation through Mixture Density
       Distributions. _arXiv_, 2016. https://arxiv.org/abs/1607.05690
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    self._mixture_distribution = mixture_distribution
    self._components_distribution = components_distribution
    self._runtime_assertions = []

    s = components_distribution.event_shape_tensor()
    self._event_ndims = tf.compat.dimension_value(s.shape[0])
    if self._event_ndims is None:
      self._event_ndims = tf.size(input=s)
    self._event_size = tf.reduce_prod(input_tensor=s)

    if not dtype_util.is_integer(mixture_distribution.dtype):
      raise ValueError(
          "`mixture_distribution.dtype` ({}) is not over integers".format(
              dtype_util.name(mixture_distribution.dtype)))

    if (tensorshape_util.rank(mixture_distribution.event_shape) is not None
        and tensorshape_util.rank(mixture_distribution.event_shape) != 0):
      raise ValueError("`mixture_distribution` must have scalar `event_dim`s")
    elif validate_args:
      self._runtime_assertions += [
          assert_util.assert_equal(
              tf.size(input=mixture_distribution.event_shape_tensor()),
              0,
              message="`mixture_distribution` must have scalar `event_dim`s"),
      ]

    mdbs = mixture_distribution.batch_shape
    cdbs = tensorshape_util.with_rank_at_least(
        components_distribution.batch_shape, 1)[:-1]
    if (tensorshape_util.is_fully_defined(mdbs) and
        tensorshape_util.is_fully_defined(cdbs)):
      if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
        raise ValueError(
            "`mixture_distribution.batch_shape` (`{}`) is not "
            "compatible with `components_distribution.batch_shape` "
            "(`{}`)".format(tensorshape_util.as_list(mdbs),
                            tensorshape_util.as_list(cdbs)))
    elif validate_args:
      mdbs = mixture_distribution.batch_shape_tensor()
      cdbs = components_distribution.batch_shape_tensor()[:-1]
      self._runtime_assertions += [
          assert_util.assert_equal(
              distribution_utils.pick_vector(
                  mixture_distribution.is_scalar_batch(), cdbs, mdbs),
              cdbs,
              message=("`mixture_distribution.batch_shape` is not "
                       "compatible with "
                       "`components_distribution.batch_shape`"))
      ]

    km = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(
            mixture_distribution.logits.shape, 1)[-1])
    kc = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(
            components_distribution.batch_shape, 1)[-1])
    if km is not None and kc is not None and km != kc:
      raise ValueError(
          "`mixture_distribution components` ({}) does not "
          "equal `components_distribution.batch_shape[-1]` "
          "({})".format(km, kc))
    elif validate_args:
      km = tf.shape(input=mixture_distribution.logits)[-1]
      kc = components_distribution.batch_shape_tensor()[-1]
      self._runtime_assertions += [
          assert_util.assert_equal(
              km,
              kc,
              message=("`mixture_distribution components` does not equal "
                       "`components_distribution.batch_shape[-1:]`")),
      ]
    elif km is None:
      km = tf.shape(input=mixture_distribution.logits)[-1]

    self._num_components = km

    self._reparameterize = reparameterize
    if reparameterize:
      # Note: tfd.Independent passes through the reparameterization type
      # hence we do not need separate logic for Independent.
      if (self._components_distribution.reparameterization_type !=
          reparameterization.FULLY_REPARAMETERIZED):
        raise ValueError("Cannot reparameterize a mixture of "
                         "non-reparameterized components.")
      reparameterization_type = reparameterization.FULLY_REPARAMETERIZED
    else:
      reparameterization_type = reparameterization.NOT_REPARAMETERIZED

    super(MixtureSameFamily, self).__init__(
        dtype=self._components_distribution.dtype,
        reparameterization_type=reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=(
            self._mixture_distribution._graph_parents  # pylint: disable=protected-access
            + self._components_distribution._graph_parents),  # pylint: disable=protected-access
        name=name)
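A minimal construction that satisfies the constraints documented above (a sketch; the `tfd` alias is assumed): the `Categorical` has two categories, matching `components_distribution.batch_shape[-1]`.

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

# Mixture of two scalar Gaussians: scalar batch_shape for the mixture,
# components batch_shape[-1] == 2 == number of mixture categories.
gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(
        loc=[-1., 1.],        # component means
        scale=[0.1, 0.5]))    # component standard deviations
gm.mean()  # ==> 0.3 * (-1.) + 0.7 * 1. = 0.4
```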
def _is_equal_or_close(self, a, b):
  if dtype_util.is_integer(self.outcomes.dtype):
    return tf.equal(a, b)
  return tf.abs(a - b) < self._atol + self._rtol * tf.abs(b)
def _float_dtype_like(dtype):
  if dtype_util.as_numpy_dtype(dtype) == np.int64:
    return tf.float64
  if dtype_util.is_integer(dtype):
    return tf.float32
  return dtype
def __init__(self,
             perm=None,
             rightmost_transposed_ndims=None,
             validate_args=False,
             name='transpose'):
  """Instantiates the `Transpose` bijector.

  Args:
    perm: Positive `int32` vector-shaped `Tensor` representing permutation of
      rightmost dims (for forward transformation). Note that the `0`th index
      represents the first of the rightmost dims and the largest value must
      be `rightmost_transposed_ndims - 1` and corresponds to
      `tf.rank(x) - 1`. Only one of `perm` and `rightmost_transposed_ndims`
      can (and must) be specified. The number of elements in a permutation
      must have a value that can be determined statically.
      Default value:
      `tf.range(start=rightmost_transposed_ndims, limit=-1, delta=-1)`.
    rightmost_transposed_ndims: Positive `int32` scalar-shaped `Tensor`
      representing the number of rightmost dimensions to permute. Only one of
      `perm` and `rightmost_transposed_ndims` can (and must) be specified. If
      `rightmost_transposed_ndims` is specified, the rightmost dims are
      reversed. This argument must have a value that can be determined
      statically.
      Default value: `tf.size(perm)`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str` name given to ops managed by this object.

  Raises:
    ValueError: if both or neither `perm` and `rightmost_transposed_ndims`
      are specified.
    NotImplementedError: if `rightmost_transposed_ndims` is not known prior
      to graph execution.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    # We need to determine `forward_min_event_ndims` statically, which
    # requires that we know `rightmost_transposed_ndims` statically. So the
    # corresponding assertions go here rather than in
    # `_parameter_control_dependencies`.
    if (rightmost_transposed_ndims is None) == (perm is None):
      raise ValueError('Must specify exactly one of '
                       '`rightmost_transposed_ndims` and `perm`.')
    if rightmost_transposed_ndims is not None:
      rightmost_transposed_ndims = tensor_util.convert_nonref_to_tensor(
          rightmost_transposed_ndims, dtype_hint=np.int32)
      if not dtype_util.is_integer(rightmost_transposed_ndims.dtype):
        raise TypeError('`rightmost_transposed_ndims` must be integer type.')
      rightmost_transposed_ndims_ = tf.get_static_value(
          rightmost_transposed_ndims)
      if rightmost_transposed_ndims_ is None:
        raise NotImplementedError('`rightmost_transposed_ndims` must be '
                                  'known prior to graph execution.')
      msg = '`rightmost_transposed_ndims` must be non-negative.'
      if rightmost_transposed_ndims_ < 0:
        raise ValueError(
            msg[:-1] + ', saw: {}.'.format(rightmost_transposed_ndims_))
      perm_start = (
          distribution_util.prefer_static_value(rightmost_transposed_ndims) -
          1)
      perm = tf.range(start=perm_start, limit=-1, delta=-1, name='perm')
    else:  # perm is not None
      perm = tensor_util.convert_nonref_to_tensor(
          perm, dtype_hint=np.int32, name='perm')
      rightmost_transposed_ndims = tf.size(
          perm, name='rightmost_transposed_ndims')
      rightmost_transposed_ndims_ = tf.get_static_value(
          rightmost_transposed_ndims)

    # TODO(b/110828604): If bijector base class ever supports dynamic
    # `min_event_ndims`, then this class already works dynamically and the
    # following five lines can be removed.
    if rightmost_transposed_ndims_ is None:
      raise NotImplementedError('`rightmost_transposed_ndims` must be '
                                'known prior to graph execution.')
    else:
      rightmost_transposed_ndims_ = int(rightmost_transposed_ndims_)

    self._perm = perm
    self._rightmost_transposed_ndims = rightmost_transposed_ndims
    self._initial_rightmost_transposed_ndims = rightmost_transposed_ndims_
    super(Transpose, self).__init__(
        forward_min_event_ndims=rightmost_transposed_ndims_,
        is_constant_jacobian=True,
        validate_args=validate_args,
        parameters=parameters,
        name=name)
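A short usage sketch (the `tfb` alias is assumed); per the docstring defaults, `perm=[1, 0]` and `rightmost_transposed_ndims=2` construct the same bijector:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

swap = tfb.Transpose(perm=[1, 0])  # Same as rightmost_transposed_ndims=2.
x = tf.constant([[1., 2.],
                 [3., 4.]])
swap.forward(x)  # ==> [[1., 3.], [2., 4.]]
```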
def _float_dtype_like(dtype):
  if dtype is tf.int64:
    return tf.float64
  if dtype_util.is_integer(dtype):
    return tf.float32
  return dtype
def _potential_scale_reduction_single_state(state, independent_chain_ndims,
                                            split_chains, validate_args):
  """potential_scale_reduction for one single state `Tensor`."""
  # Cast integers to floats for floating-point division; check the dtype via
  # `dtype_util` so this also works when `state` is a numpy object (as in the
  # numpy test suite).
  if dtype_util.as_numpy_dtype(state.dtype) is np.int64:
    state = tf.cast(state, tf.float64)
  elif dtype_util.is_integer(state.dtype):
    state = tf.cast(state, tf.float32)

  with tf.name_scope('potential_scale_reduction_single_state'):
    # We assume exactly one leading dimension indexes e.g. correlated samples
    # from each Markov chain.
    state = tf.convert_to_tensor(state, name='state')

    n_samples_ = tf.compat.dimension_value(state.shape[0])
    if n_samples_ is not None:  # If available statically.
      if split_chains and n_samples_ < 4:
        raise ValueError(
            'Must provide at least 4 samples when splitting chains. '
            'Found {}'.format(n_samples_))
      if not split_chains and n_samples_ < 2:
        raise ValueError(
            'Must provide at least 2 samples. Found {}'.format(n_samples_))
    elif validate_args:
      # `assert_greater_equal` keeps the dynamic checks consistent with the
      # static checks above: "at least 4" means >= 4, not > 4.
      if split_chains:
        assertions = [
            assert_util.assert_greater_equal(
                ps.shape(state)[0], 4,
                message='Must provide at least 4 samples when splitting '
                        'chains.')
        ]
        with tf.control_dependencies(assertions):
          state = tf.identity(state)
      else:
        assertions = [
            assert_util.assert_greater_equal(
                ps.shape(state)[0], 2,
                message='Must provide at least 2 samples.')
        ]
        with tf.control_dependencies(assertions):
          state = tf.identity(state)

    # Define so it's not a magic number.
    # Warning! `if split_chains` logic assumes this is 1!
    sample_ndims = 1

    if split_chains:
      # Split the sample dimension in half, doubling the number of
      # independent chains.

      # For odd number of samples, keep all but the last sample.
      state_shape = ps.shape(state)
      n_samples = state_shape[0]
      state = state[:n_samples - n_samples % 2]

      # Suppose state = [0, 1, 2, 3, 4, 5].
      # Step 1: reshape into [[0, 1, 2], [3, 4, 5]].
      # E.g. reshape states of shape [a, b] into [2, a//2, b].
      state = tf.reshape(
          state, ps.concat([[2, n_samples // 2], state_shape[1:]], axis=0))
      # Step 2: Put the size `2` dimension in the right place to be treated
      # as a chain, changing [[0, 1, 2], [3, 4, 5]] into
      # [[0, 3], [1, 4], [2, 5]], reshaping [2, a//2, b] into [a//2, 2, b].
      state = tf.transpose(
          a=state,
          perm=ps.concat([[1, 0], tf.range(2, tf.rank(state))], axis=0))

      # We're treating the new dim as indexing 2 chains, so increment.
      independent_chain_ndims += 1

    sample_axis = tf.range(0, sample_ndims)
    chain_axis = tf.range(sample_ndims,
                          sample_ndims + independent_chain_ndims)
    sample_and_chain_axis = tf.range(
        0, sample_ndims + independent_chain_ndims)

    n = _axis_size(state, sample_axis)
    m = _axis_size(state, chain_axis)

    # In the language of Brooks and Gelman (1998),
    # B / n is the between chain variance, the variance of the chain means.
    # W is the within sequence variance, the mean of the chain variances.
    b_div_n = _reduce_variance(
        tf.reduce_mean(state, axis=sample_axis, keepdims=True),
        sample_and_chain_axis,
        biased=False)
    w = tf.reduce_mean(
        _reduce_variance(state, sample_axis, keepdims=True, biased=False),
        axis=sample_and_chain_axis)

    # sigma^2_+ is an estimate of the true variance, which would be unbiased
    # if each chain was drawn from the target. Cf. "law of total variance."
    sigma_2_plus = ((n - 1) / n) * w + b_div_n

    return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n)
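For the simplest case (`independent_chain_ndims=1`, no splitting), the final expression reduces to the Brooks-Gelman form of the potential scale reduction statistic. A NumPy sketch under those assumptions (`rhat` is a hypothetical helper, not the library function):

```python
import numpy as np


def rhat(chains):
  """Potential scale reduction for `chains` of shape [n_samples, n_chains]."""
  n, m = chains.shape
  chain_means = chains.mean(axis=0)
  b_div_n = chain_means.var(ddof=1)          # Between-chain variance of means.
  w = chains.var(axis=0, ddof=1).mean()      # Mean within-chain variance.
  sigma_2_plus = (n - 1) / n * w + b_div_n   # "Law of total variance" estimate.
  return (m + 1) / m * sigma_2_plus / w - (n - 1) / (m * n)


# Chains drawn from the same distribution should give rhat close to 1.
rng = np.random.default_rng(0)
print(rhat(rng.normal(size=(1000, 4))))  # ==> ~1.0
```

Here `ddof=1` mirrors `biased=False` in `_reduce_variance` above.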