def maybe_check_quadrature_param(param, name, validate_args):
    """Validate that a mixing parameter is a (batch of) length-1 vector.

    Whatever can be decided from the static shape is checked immediately and
    raises; properties that are not statically known are turned into runtime
    assertions, but only when `validate_args` is true.

    Args:
      param: `Tensor` to validate.
      name: Python `str` used for the name scope and in error messages.
      validate_args: Python `bool`; when true, runtime assertions are created
        for properties that cannot be checked statically.

    Returns:
      `param` itself, or `param` with the runtime assertions attached through
      `distribution_util.with_dependencies` when any were created.

    Raises:
      ValueError: if `param` is statically known to have rank 0.
      NotImplementedError: if the last dimension is statically known and is
        not 1.
    """
    with tf.name_scope("check_" + name):
        runtime_checks = []

        static_rank = tensorshape_util.rank(param.shape)
        if static_rank is None:
            if validate_args:
                runtime_checks.append(
                    assert_util.assert_rank_at_least(
                        param,
                        1,
                        message=("Mixing params must be a (batch of) vector; "
                                 "{}.rank is not at least one.".format(name))))
        elif static_rank == 0:
            raise ValueError(
                "Mixing params must be a (batch of) vector; "
                "{}.rank={} is not at least one.".format(name, static_rank))

        # TODO(jvdillon): Remove once we support k-mixtures.
        if tensorshape_util.with_rank_at_least(param.shape, 1)[-1] is None:
            if validate_args:
                runtime_checks.append(
                    assert_util.assert_equal(
                        tf.shape(param)[-1],
                        1,
                        message=("Currently only bimixtures are supported; "
                                 "{}.shape[-1] is not 1.".format(name))))
        else:
            last_dim = tf.compat.dimension_value(param.shape[-1])
            if last_dim != 1:
                raise NotImplementedError(
                    "Currently only bimixtures are supported; "
                    "{}.shape[-1]={} is not 1.".format(name, last_dim))

        if not runtime_checks:
            return param
        return distribution_util.with_dependencies(runtime_checks, param)
def calculate_reshape(original_shape, new_shape, validate=False, name=None):
    """Resolve a reshape target, filling in at most one `-1` entry.

    Args:
      original_shape: shape (vector) of the tensor being reshaped.
      new_shape: requested shape; may contain a single `-1` placeholder.
      validate: Python `bool`; when true, runtime assertions are built for
        the dynamic path.
      name: optional name scope for the created ops.

    Returns:
      A triple `(resolved_shape, static_new_shape, validations)`. When
      `new_shape` is fully known statically, `resolved_shape` is its
      `np.int32` conversion and `validations` is empty.
    """
    static_new_shape = tensorshape_util.constant_value_as_shape(new_shape)
    if tensorshape_util.is_fully_defined(static_new_shape):
        # Nothing to infer: the requested shape is completely static.
        return np.int32(static_new_shape), static_new_shape, []

    with tf.name_scope(name or 'calculate_reshape'):
        total_size = tf.reduce_prod(original_shape)
        is_implicit = tf.equal(new_shape, -1)
        # With exactly one `-1`, the negated product is the product of the
        # explicit dims; `maximum` guards the no-`-1` case against dividing
        # by a non-positive number.
        explicit_size = tf.maximum(1, -tf.reduce_prod(new_shape))
        implicit_size = total_size // explicit_size
        # Assumes exactly one `-1`.
        resolved_shape = tf.where(is_implicit, implicit_size, new_shape)

        validations = []
        if validate:
            validations = [
                assert_util.assert_rank(
                    original_shape, 1,
                    message='Original shape must be a vector.'),
                assert_util.assert_rank(
                    new_shape, 1, message='New shape must be a vector.'),
                assert_util.assert_less_equal(
                    tf.math.count_nonzero(is_implicit, dtype=tf.int32),
                    1,
                    message='At most one dimension can be unknown.'),
                assert_util.assert_positive(
                    resolved_shape, message='Shape elements must be >=-1.'),
                assert_util.assert_equal(
                    tf.reduce_prod(resolved_shape),
                    total_size,
                    message='Shape sizes do not match.'),
            ]
        return resolved_shape, static_new_shape, validations
def _maybe_check_valid_map_values(map_values, validate_args):
    """Check that `map_values` is a non-empty, strictly increasing vector.

    Rank and size are checked statically when known (raising `ValueError`);
    otherwise, and for monotonicity, runtime assertions are created only when
    `validate_args` is true.

    Args:
      map_values: `Tensor` to validate.
      validate_args: Python `bool` enabling runtime assertions.

    Returns:
      A list of assertion ops (possibly empty).

    Raises:
      ValueError: if the static rank is not 1 or the static size is 0.
    """
    checks = []

    rank_message = 'Rank of map_values must be 1.'
    static_rank = tensorshape_util.rank(map_values.shape)
    if static_rank is None:
        if validate_args:
            checks.append(
                assert_util.assert_rank(map_values, 1, message=rank_message))
    elif static_rank != 1:
        raise ValueError(rank_message)

    size_message = 'Size of map_values must be greater than 0.'
    static_size = tensorshape_util.num_elements(map_values.shape)
    if static_size is None:
        if validate_args:
            checks.append(
                assert_util.assert_greater(
                    tf.size(map_values), 0, message=size_message))
    elif static_size == 0:
        raise ValueError(size_message)

    if validate_args:
        checks.append(
            assert_util.assert_equal(
                tf.math.is_strictly_increasing(map_values),
                True,
                message='map_values is not strictly increasing.'))

    return checks
def lu_reconstruct_assertions(lower_upper, perm, validate_args):
    """Returns list of assertions related to `lu_reconstruct` assumptions.

    Each property is checked statically when the needed shape information is
    available (raising `ValueError` on violation); otherwise a runtime
    assertion is added, but only when `validate_args` is true.

    Args:
      lower_upper: `Tensor` holding the LU factorization.
      perm: `Tensor` holding the permutation.
      validate_args: Python `bool` enabling runtime assertions.

    Returns:
      A list of assertion ops (possibly empty).

    Raises:
      ValueError: if `lower_upper` statically has rank < 2, if the static
        ranks violate `rank(lower_upper) == rank(perm) + 1`, or if the two
        innermost dims of `lower_upper` are statically known and differ.
    """
    assertions = []

    message = 'Input `lower_upper` must have at least 2 dimensions.'
    if tensorshape_util.rank(lower_upper.shape) is not None:
        if tensorshape_util.rank(lower_upper.shape) < 2:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank_at_least(
                lower_upper, rank=2, message=message))

    message = '`rank(lower_upper)` must equal `rank(perm) + 1`'
    if (tensorshape_util.rank(lower_upper.shape) is not None and
            tensorshape_util.rank(perm.shape) is not None):
        if (tensorshape_util.rank(lower_upper.shape) !=
                tensorshape_util.rank(perm.shape) + 1):
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank(
                lower_upper, rank=tf.rank(perm) + 1, message=message))

    message = '`lower_upper` must be square.'
    # Squareness only involves the two innermost (matrix) dims, so those are
    # the ones that must be statically known -- not the leading batch dims.
    if tensorshape_util.is_fully_defined(lower_upper.shape[-2:]):
        if lower_upper.shape[-2] != lower_upper.shape[-1]:
            raise ValueError(message)
    elif validate_args:
        m, n = tf.split(tf.shape(lower_upper)[-2:], num_or_size_splits=2)
        assertions.append(assert_util.assert_equal(m, n, message=message))

    return assertions
def _maybe_validate_perm(perm, validate_args, name=None):
    """Check that `perm` is an integer permutation vector.

    Args:
      perm: integer `Tensor` to validate.
      validate_args: Python `bool` enabling runtime assertions for anything
        that cannot be decided statically.
      name: optional name scope for the created ops.

    Returns:
      A list of assertion ops (possibly empty).

    Raises:
      TypeError: if `perm` does not have an integer dtype.
      ValueError: if `perm` is statically known not to be a vector, or its
        static value is not a permutation of `0..size-1`.
    """
    with tf.name_scope(name or 'maybe_validate_perm'):
        if not dtype_util.is_integer(perm.dtype):
            raise TypeError('`perm` must be integer type')

        checks = []

        vector_msg = '`perm` must be a vector.'
        static_rank = tensorshape_util.rank(perm.shape)
        if static_rank is None:
            if validate_args:
                checks.append(
                    assert_util.assert_rank(perm, 1, message=vector_msg))
        elif static_rank != 1:
            raise ValueError(
                vector_msg[:-1] + ', saw rank: {}.'.format(static_rank))

        static_perm = tf.get_static_value(perm)
        valid_msg = '`perm` must be a valid permutation vector.'
        if static_perm is None:
            if validate_args:
                checks.append(
                    assert_util.assert_equal(
                        tf.sort(perm),
                        tf.range(tf.size(perm)),
                        message=valid_msg))
        else:
            # A permutation, once sorted, is exactly 0, 1, ..., size - 1.
            expected = np.arange(np.size(static_perm))
            if not np.all(expected == np.sort(static_perm)):
                raise ValueError(
                    valid_msg[:-1] + ', saw: {}.'.format(static_perm))

        return checks
def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
                                   validate_args, name):
    """Validate an overriding batch/event shape passed to `__init__`.

    Args:
      override_shape: the override shape, or `None` (treated as `[]`).
      base_is_scalar: whether the base distribution is scalar along the
        overridden dimension (Python `bool` or boolean `Tensor`).
      validate_args: Python `bool` enabling runtime assertions.
      name: name given to the converted shape `Tensor`.

    Returns:
      `self._empty` when the override is statically scalar; otherwise the
      `int32` override `Tensor`, with runtime assertions attached if any
      were created.

    Raises:
      TypeError: if the converted override is not of integer dtype.
      ValueError: if the override is statically not a vector, has statically
        negative entries, or both base and override are statically
        non-scalar.
    """
    if override_shape is None:
        override_shape = []
    override_shape = tf.convert_to_tensor(
        override_shape, dtype=tf.int32, name=name)

    if not dtype_util.is_integer(override_shape.dtype):
        raise TypeError("shape override must be an integer")

    override_is_scalar = _is_scalar_from_shape_tensor(override_shape)
    if tf.get_static_value(override_is_scalar):
        return self._empty

    checks = []

    vector_msg = "shape override must be a vector"
    static_rank = tensorshape_util.rank(override_shape.shape)
    if static_rank is None:
        if validate_args:
            checks.append(
                assert_util.assert_rank(
                    override_shape, 1, message=vector_msg))
    elif static_rank != 1:
        raise ValueError(vector_msg)

    nonneg_msg = "shape override must have non-negative elements"
    static_override = tf.get_static_value(override_shape)
    if static_override is None:
        if validate_args:
            checks.append(
                assert_util.assert_non_negative(
                    override_shape, message=nonneg_msg))
    elif any(s < 0 for s in static_override):
        raise ValueError(nonneg_msg)

    nonscalar_msg = "base distribution not scalar"
    is_both_nonscalar = prefer_static.logical_and(
        prefer_static.logical_not(base_is_scalar),
        prefer_static.logical_not(override_is_scalar))
    static_both_nonscalar = tf.get_static_value(is_both_nonscalar)
    if static_both_nonscalar is None:
        if validate_args:
            checks.append(
                assert_util.assert_equal(
                    is_both_nonscalar, False, message=nonscalar_msg))
    elif static_both_nonscalar:
        raise ValueError(nonscalar_msg)

    if checks:
        return distribution_util.with_dependencies(checks, override_shape)
    return override_shape
def _assertions(self, x):
    """Build runtime checks that `x` is lower triangular with positive diagonal.

    Args:
      x: input `Tensor`.

    Returns:
      An empty list when `self.validate_args` is false; otherwise four
      assertion ops: rank >= 2, square, lower triangular, positive diagonal.
    """
    if not self.validate_args:
        return []

    checks = []
    dims = tf.shape(x)
    checks.append(
        assert_util.assert_rank_at_least(
            x, 2, message="Input must have rank at least 2."))
    checks.append(
        assert_util.assert_equal(
            dims[-2], dims[-1], message="Input must be a square matrix."))

    diagonal = tf.linalg.diag_part(x)
    # `band_part(x, 0, -1)` keeps the upper triangle (diagonal included) and
    # zeros the rest; for a lower-triangular input only the diagonal remains.
    checks.append(
        assert_util.assert_equal(
            tf.linalg.band_part(x, 0, -1),
            tf.linalg.diag(diagonal),
            message="Input must be lower triangular."))
    checks.append(
        assert_util.assert_positive(
            diagonal,
            message="Input must have all positive diagonal entries."))
    return checks
def _assertions(self, t):
    """Build runtime checks that `t` is a square matrix with positive diagonal.

    Args:
      t: input `Tensor`.

    Returns:
      An empty list when `self.validate_args` is false; otherwise three
      assertion ops: rank >= 2, square, and positive diagonal (reported as
      "Input must be positive definite.").
    """
    # Validation is opt-in: skip all checks unless `validate_args` is set.
    if not self.validate_args:
        return []
    is_matrix = assert_util.assert_rank_at_least(t, 2)
    is_square = assert_util.assert_equal(tf.shape(t)[-2], tf.shape(t)[-1])
    is_positive_definite = assert_util.assert_positive(
        tf.linalg.diag_part(t), message="Input must be positive definite.")
    return [is_matrix, is_square, is_positive_definite]
def _observation_mask_shape_preconditions(self, observation_tensor_shape,
                                          mask_tensor_shape):
    """Build a control-dependency context checking sequence lengths.

    Args:
      observation_tensor_shape: shape of the observations; the entry at
        index `-1 - self._underlying_event_rank` is compared against
        `self._num_steps`.
      mask_tensor_shape: shape of the mask, or `None` when there is no mask;
        when given, its last entry is compared against `self._num_steps`.

    Returns:
      The context manager produced by `tf.control_dependencies` over the
      created assertions.
    """
    # NOTE: adjacent string literals are concatenated verbatim, so each
    # fragment must carry its own separating space.
    shape_condition = [
        assert_util.assert_equal(
            observation_tensor_shape[-1 - self._underlying_event_rank],
            self._num_steps,
            message="The tensor `observations` must consist of sequences "
                    "of observations from `HiddenMarkovModel` of length "
                    "`num_steps`.")
    ]
    if mask_tensor_shape is not None:
        shape_condition.append(
            assert_util.assert_equal(
                mask_tensor_shape[-1],
                self._num_steps,
                message="The tensor `mask` must consist of sequences "
                        "of length `num_steps`."))
    return tf.control_dependencies(shape_condition)
def validate_equal_last_dim(tensor_a, tensor_b, message):
    """Check that two tensors have the same last dimension.

    NOTE(review): `validate_args` is not a parameter of this function; it is
    read from the enclosing or module scope -- confirm it is defined wherever
    this helper is used.

    Args:
      tensor_a: first `Tensor`.
      tensor_b: second `Tensor`.
      message: error message for a mismatch.

    Returns:
      A runtime assertion op when at least one last dimension is not
      statically known and `validate_args` is true; otherwise `None`.

    Raises:
      ValueError: if both last dimensions are statically known and differ.
    """
    static_a = tf.compat.dimension_value(tensor_a.shape[-1])
    static_b = tf.compat.dimension_value(tensor_b.shape[-1])

    if static_a is None or static_b is None:
        # Cannot decide statically; defer to runtime only if requested.
        if validate_args:
            return assert_util.assert_equal(
                tf.shape(tensor_a)[-1],
                tf.shape(tensor_b)[-1],
                message=message)
        return None

    if static_a != static_b:
        raise ValueError(message)
    return None
def _parameter_control_dependencies(self, is_init):
    """Collect checks that all component distributions share a batch shape.

    Args:
      is_init: Python `bool`; when true, the static batch shapes are also
        compared eagerly.

    Returns:
      A list of assertion ops; always empty when `self.validate_args` is
      false.

    Raises:
      ValueError: if `is_init` is true, every cached batch shape is fully
        defined, and the shapes are not all equal.
    """
    assertions = []

    message = 'Distributions must have the same `batch_shape`'

    # Static check: only possible when every batch shape is fully defined.
    if is_init:
        batch_shapes = tf.nest.flatten(self._cached_batch_shape)
        if all(tensorshape_util.is_fully_defined(b) for b in batch_shapes):
            # Comparing the list against itself shifted by one checks every
            # adjacent pair, i.e. that all entries are equal.
            if batch_shapes[1:] != batch_shapes[:-1]:
                raise ValueError('{}; found: {}.'.format(
                    message, batch_shapes))

    if not self.validate_args:
        assert not assertions  # Should never happen.
        return []

    # NOTE(review): this condition is always true here because of the early
    # return above.
    if self.validate_args:
        batch_shapes = self._cached_batch_shape
        if not all(
                tensorshape_util.is_fully_defined(s)
                for s in tf.nest.flatten(batch_shapes)):
            # Prefer the static shape where it is fully defined; fall back to
            # the cached shape tensor otherwise.
            batch_shapes = tf.nest.map_structure(
                lambda static_shape, shape_tensor:  # pylint: disable=g-long-lambda
                (static_shape if tensorshape_util.is_fully_defined(
                    static_shape) else shape_tensor),
                batch_shapes, self._cached_batch_shape_tensor)
            batch_shapes = tf.nest.flatten(batch_shapes)
            # Adjacent pairs must agree elementwise ...
            assertions.extend(
                assert_util.assert_equal(  # pylint: disable=g-complex-comprehension
                    b1, b2, message='{}.'.format(message))
                for b1, b2 in zip(batch_shapes[1:], batch_shapes[:-1]))
            # ... and in number of elements.
            assertions.extend(
                assert_util.assert_equal(  # pylint: disable=g-complex-comprehension
                    tf.size(b1), tf.size(b2), message='{}.'.format(message))
                for b1, b2 in zip(batch_shapes[1:], batch_shapes[:-1]))
    return assertions
def _maybe_assert_valid_sample(self, counts):
    """Optionally attach validity checks to `counts` and return it.

    Args:
      counts: sample to validate.

    Returns:
      `counts` unchanged when `self.validate_args` is false. Otherwise
      `counts` after the non-negative-integer check, with an assertion
      attached that its last dimension sums to `self.total_count`.
    """
    if self.validate_args:
        counts = distribution_util.embed_check_nonnegative_integer_form(
            counts)
        sums_to_total = assert_util.assert_equal(
            self.total_count,
            tf.reduce_sum(counts, axis=-1),
            message='counts last-dimension must sum to `self.total_count`')
        return distribution_util.with_dependencies([sums_to_total], counts)
    return counts
def _maybe_assert_valid_sample(self, counts):
    """Build runtime checks that `counts` is a valid sample.

    Args:
      counts: sample to validate.

    Returns:
      An empty list when `self.validate_args` is false; otherwise the
      assertions from `distribution_util.assert_nonnegative_integer_form`
      plus one asserting the last dimension sums to `self.total_count`.
    """
    if not self.validate_args:
        return []

    checks = distribution_util.assert_nonnegative_integer_form(counts)
    sums_to_total = assert_util.assert_equal(
        self.total_count,
        tf.reduce_sum(counts, axis=-1),
        message='counts must sum to `self.total_count`')
    checks.append(sums_to_total)
    return checks
def _assertions(self, x):
    """Build runtime checks that `x` is a nonsingular lower-triangular matrix.

    Args:
      x: input `Tensor`.

    Returns:
      An empty list when `self.validate_args` is false; otherwise four
      assertion ops: rank >= 2, square, lower triangular, and nonzero
      diagonal.
    """
    if not self.validate_args:
        return []
    shape = tf.shape(x)
    is_matrix = assert_util.assert_rank_at_least(
        x, 2, message="Input must have rank at least 2.")
    is_square = assert_util.assert_equal(
        shape[-2], shape[-1], message="Input must be a square matrix.")

    diag_part = tf.linalg.diag_part(x)
    # Zero the diagonal, then keep only the upper triangle: what is left is
    # the strictly-upper part, which must vanish for a lower-triangular
    # input. The zero diagonal is derived from `diag_part` so that it always
    # matches the dtype of `x` (a hard-coded float32 breaks other dtypes).
    above_diagonal = tf.linalg.band_part(
        tf.linalg.set_diag(x, tf.zeros_like(diag_part)), 0, -1)
    is_lower_triangular = assert_util.assert_equal(
        above_diagonal,
        tf.zeros_like(above_diagonal),
        message="Input must be lower triangular.")

    # A lower triangular matrix is nonsingular iff all its diagonal entries
    # are nonzero.
    is_nonsingular = assert_util.assert_none_equal(
        diag_part,
        tf.zeros_like(diag_part),
        message="Input must have all diagonal entries nonzero.")
    return [is_matrix, is_square, is_lower_triangular, is_nonsingular]
def _validate_block_sizes(block_sizes, bijectors, validate_args):
    """Check that `block_sizes` is a vector with one entry per bijector.

    Args:
      block_sizes: `Tensor` of block sizes.
      bijectors: sequence of bijectors; only its length is used.
      validate_args: Python `bool` enabling runtime assertions when the
        shape is not fully defined statically.

    Returns:
      `block_sizes`, wrapped in `tf.identity` under the runtime assertions
      when those were created.

    Raises:
      ValueError: if the static shape is fully defined and is not a vector
        of length `len(bijectors)`.
    """
    static_shape = block_sizes.shape

    if tensorshape_util.is_fully_defined(static_shape):
        is_vector = tensorshape_util.rank(static_shape) == 1
        if (not is_vector or
                tensorshape_util.num_elements(static_shape) != len(bijectors)):
            raise ValueError(
                '`block_sizes` must be `None`, or a vector of the same length as '
                '`bijectors`. Got a `Tensor` with shape {} and `bijectors` of '
                'length {}'.format(static_shape, len(bijectors)))
        return block_sizes

    if not validate_args:
        return block_sizes

    message = ('`block_sizes` must be `None`, or a vector of the same length '
               'as `bijectors`.')
    size_check = assert_util.assert_equal(
        tf.size(block_sizes), len(bijectors), message=message)
    rank_check = assert_util.assert_equal(tf.rank(block_sizes), 1)
    with tf.control_dependencies([size_check, rank_check]):
        return tf.identity(block_sizes)
def _z(self, x, scale, concentration):
    """Standardize `x` as `(x - loc) / scale`, optionally checking support.

    Args:
      x: value(s) to standardize.
      scale: scale parameter.
      concentration: concentration parameter; only used for the support
        check.

    Returns:
      `(x - loc) / scale`. When `self.validate_args` is true, `x` first
      passes through an assertion that it lies in the support.
    """
    loc = tf.convert_to_tensor(self.loc)
    if self.validate_args:
        above_lower = x >= loc
        # For non-negative concentration there is no upper bound; otherwise
        # the support ends at `loc - scale / concentration`.
        unbounded_above = concentration >= 0
        upper_bound = loc - scale / concentration
        within_upper = unbounded_above | (x <= upper_bound)
        valid = above_lower & within_upper
        in_support = assert_util.assert_equal(
            valid, True, message='`x` outside distribution\'s support.')
        with tf.control_dependencies([in_support]):
            x = tf.identity(x)
    return (x - loc) / scale
def _forward(self, x):
    """Map indices `x` to the corresponding entries of `self.map_values`.

    Args:
      x: integer indices.

    Returns:
      `tf.gather(map_values, indices=x)`. When `self.validate_args` is true,
      `x` first passes through an assertion that every index lies in
      `[0, size(map_values))`.
    """
    lookup = tf.convert_to_tensor(self.map_values)
    if self.validate_args:
        in_range = (0 <= x) & (x < tf.size(lookup))
        bounds_check = assert_util.assert_equal(
            in_range, True, message='indices out of bound')
        with tf.control_dependencies([bounds_check]):
            x = tf.identity(x)
    # If we want batch dims in self.map_values, we can (after broadcasting),
    # use:
    # tf.gather(self.map_values, x, batch_dims=-1, axis=-1)
    return tf.gather(lookup, indices=x)
def _assert_compatible_shape(self, index, sample_shape, samples):
    """Build checks that `samples` is consistent with `sample_shape`.

    Two properties are expected of yielded distributions:
      (1) the rank of `samples` is at least the requested rank, and
      (2) the requested shape is a prefix of the shape of `samples`.

    Args:
      index: position of the yielding distribution; used in the message.
      sample_shape: the requested sample shape.
      samples: the generated samples.

    Returns:
      A single-element list holding the prefix-equality assertion.

    Raises:
      ValueError: if both ranks are statically known and the rank of
        `samples` is smaller than the requested rank.
    """
    requested_shape, _ = self._expand_sample_shape_to_vector(
        tf.convert_to_tensor(sample_shape, dtype=tf.int32),
        name='requested_shape')
    actual_shape = prefer_static.shape(samples)
    actual_rank = prefer_static.rank_from_shape(actual_shape)
    requested_rank = prefer_static.rank_from_shape(requested_shape)

    # Property (1) is tried statically first. Property (2) needs no explicit
    # static path because `assert_equal` evaluates statically if it can.
    static_actual = tf.get_static_value(actual_rank)
    static_requested = tf.get_static_value(requested_rank)

    message = ('Samples yielded by distribution #{} are not '
               'consistent with `sample_shape` passed to '
               '`JointDistributionCoroutine` '
               'distribution.'.format(index))

    # TODO Remove this static check (b/138738650)
    if static_actual is None or static_requested is None:
        # We're not able to statically check the rank.
        rank_checks = [
            assert_util.assert_greater_equal(
                actual_rank, requested_rank, message=message)
        ]
    elif static_actual < static_requested:
        raise ValueError(message)
    else:
        rank_checks = []

    with tf.control_dependencies(rank_checks):
        shape_prefix = actual_shape[:requested_rank]

    return [
        assert_util.assert_equal(
            requested_shape, shape_prefix, message=message)
    ]
def _validate_dimension(self, x):
    """Check that the two innermost dims of `x` both equal `self.dimension`.

    Args:
      x: input, converted to a `Tensor`.

    Returns:
      `x` as a `Tensor`; wrapped in `tf.identity` under runtime assertions
      when the innermost dims are not static and `self.validate_args` is
      true.

    Raises:
      ValueError: if the two innermost dims are statically known and are not
        both equal to `self.dimension`.
    """
    x = tf.convert_to_tensor(x, name='x')

    if tensorshape_util.is_fully_defined(x.shape[-2:]):
        static_dims = tensorshape_util.dims(x.shape)
        if not (static_dims[-2] == static_dims[-1] == self.dimension):
            raise ValueError(
                'Input dimension mismatch: expected [..., {}, {}], got {}'
                .format(self.dimension, self.dimension, static_dims))
        return x

    if self.validate_args:
        msg = 'Input dimension mismatch: expected [..., {}, {}], got {}'.format(
            self.dimension, self.dimension, tf.shape(x))
        rows_check = assert_util.assert_equal(
            tf.shape(x)[-2], self.dimension, message=msg)
        cols_check = assert_util.assert_equal(
            tf.shape(x)[-1], self.dimension, message=msg)
        with tf.control_dependencies([rows_check, cols_check]):
            x = tf.identity(x)
    return x
def _maybe_assert_valid_sample(self, samples):
    """Optionally attach validity checks to `samples` and return it.

    Args:
      samples: samples to validate.

    Returns:
      `samples` unchanged when `self.validate_args` is false. Otherwise
      `tf.identity(samples)` under assertions that the norm along the last
      axis is near 1 and that the innermost dimension matches
      `self.event_shape_tensor()`.
    """
    if not self.validate_args:
        return samples

    unit_norm = assert_util.assert_near(
        1.,
        tf.linalg.norm(samples, axis=-1),
        message='samples must be unit length')
    matching_dim = assert_util.assert_equal(
        tf.shape(samples)[-1:],
        self.event_shape_tensor(),
        message=('samples must have innermost dimension matching that of '
                 '`self.mean_direction`'))
    with tf.control_dependencies([unit_norm, matching_dim]):
        return tf.identity(samples)
def _prob(self, x):
    """Return 1 where `x` is within `self._slack(loc)` of `loc`, else 0.

    Args:
      x: value(s) at which to evaluate.

    Returns:
      A `Tensor` of dtype `self.dtype`: the all-reduction over the last axis
      of `|x - loc| <= self._slack(loc)`, cast to `self.dtype`.
    """
    if self.validate_args:
        rank_check = assert_util.assert_rank_at_least(x, 1)
        space_check = assert_util.assert_equal(
            self.event_shape_tensor(),
            tf.gather(tf.shape(x), tf.rank(x) - 1),
            message=
            "Argument 'x' not defined in the same space R^k as this distribution"
        )
        with tf.control_dependencies([rank_check, space_check]):
            x = tf.identity(x)

    loc = tf.convert_to_tensor(self.loc)
    within_slack = tf.abs(x - loc) <= self._slack(loc)
    return tf.cast(tf.reduce_all(within_slack, axis=-1), dtype=self.dtype)
def maybe_check_wont_broadcast(flat_xs, validate_args):
    """Verify that the given parts all have the same shape.

    Args:
      flat_xs: iterable of parts (a generator is accepted).
      validate_args: Python `bool`; when false no check is performed at all.

    Returns:
      The parts as a tuple; each wrapped in `tf.identity` under the runtime
      assertions when the shapes could not all be compared statically.

    Raises:
      ValueError: if `validate_args` is true, all shapes are statically
        available, and two adjacent shapes differ.
    """
    flat_xs = tuple(flat_xs)  # So we can receive generators.
    if not validate_args:
        # Note: we don't try static validation because it is theoretically
        # possible that a user wants to take advantage of broadcasting.
        # Only when `validate_args` is `True` do we enforce the validation.
        return flat_xs

    msg = 'Broadcasting probably indicates an error in model specification.'
    shapes = tuple(prefer_static.shape(x) for x in flat_xs)
    adjacent = list(zip(shapes[1:], shapes[:-1]))

    if all(prefer_static.is_numpy(shape) for shape in shapes):
        for later, earlier in adjacent:
            if not np.all(later == earlier):
                raise ValueError(msg)
        return flat_xs

    checks = [
        assert_util.assert_equal(later, earlier, message=msg)
        for later, earlier in adjacent
    ]
    with tf.control_dependencies(checks):
        return tuple(tf.identity(x) for x in flat_xs)
def _kl_sample(a, b, name='kl_sample'):
    """Batched KL divergence `KL(a || b)` for Sample distributions.

    We can leverage the fact that:

    ```
    KL(Sample(a) || Sample(b)) = sum(KL(a || b))
    ```

    where the sum is over the `sample_shape` dims.

    Args:
      a: Instance of `Sample` distribution.
      b: Instance of `Sample` distribution.
      name: (optional) name to use for created ops.
        Default value: `"kl_sample"`.

    Returns:
      kldiv: Batchwise `KL(a || b)`.

    Raises:
      ValueError: If the `sample_shape` of `a` and `b` don't match.
    """
    msg = '`a.sample_shape` must be identical to `b.sample_shape`.'
    static_a = tf.get_static_value(a.sample_shape)
    static_b = tf.get_static_value(b.sample_shape)

    checks = []
    if static_a is None or static_b is None:
        if a.validate_args or b.validate_args:
            checks.append(
                assert_util.assert_equal(
                    a.sample_shape, b.sample_shape, message=msg))
    elif not np.array_equal(static_a, static_b):
        raise ValueError(msg)

    with tf.control_dependencies(checks):
        kl = kullback_leibler.kl_divergence(
            a.distribution, b.distribution, name=name)
        # The samples are iid, so the total KL is the per-sample KL times
        # the number of samples.
        num_samples = prefer_static.reduce_prod(a.sample_shape)
        return tf.cast(x=num_samples, dtype=kl.dtype) * kl
def _lu_solve_assertions(lower_upper, perm, rhs, validate_args):
    """Returns list of assertions related to `lu_solve` assumptions.

    Starts from the assertions of `lu_reconstruct_assertions` and adds
    checks on `rhs`.

    Args:
      lower_upper: `Tensor` holding the LU factorization.
      perm: `Tensor` holding the permutation.
      rhs: right-hand-side `Tensor`.
      validate_args: Python `bool` enabling runtime assertions for anything
        that cannot be decided statically.

    Returns:
      A list of assertion ops (possibly empty).

    Raises:
      ValueError: if `rhs` statically has rank < 2, or if
        `lower_upper.shape[-1]` and `rhs.shape[-2]` are statically known and
        differ.
    """
    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)

    message = 'Input `rhs` must have at least 2 dimensions.'
    if rhs.shape.ndims is not None:
        if rhs.shape.ndims < 2:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank_at_least(rhs, rank=2, message=message))

    # The message names `rhs.shape[-2]` because that is the dimension the
    # check below actually compares against.
    message = '`lower_upper.shape[-1]` must equal `rhs.shape[-2]`.'
    if (tf.compat.dimension_value(lower_upper.shape[-1]) is not None and
            tf.compat.dimension_value(rhs.shape[-2]) is not None):
        if lower_upper.shape[-1] != rhs.shape[-2]:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_equal(
                tf.shape(lower_upper)[-1],
                tf.shape(rhs)[-2],
                message=message))

    return assertions
def __init__(self,
             df,
             scale=None,
             scale_tril=None,
             input_output_cholesky=False,
             validate_args=False,
             allow_nan_stats=True,
             name="Wishart"):
    """Construct Wishart distributions.

    Args:
      df: `float` or `double` `Tensor`. Degrees of freedom, must be greater
        than or equal to dimension of the scale matrix.
      scale: `float` or `double` `Tensor`. The symmetric positive definite
        scale matrix of the distribution. Exactly one of `scale` and
        `scale_tril` must be passed.
      scale_tril: `float` or `double` `Tensor`. The Cholesky factorization
        of the symmetric positive definite scale matrix of the distribution.
        Exactly one of `scale` and `scale_tril` must be passed.
      input_output_cholesky: Python `bool`. If `True`, functions whose input
        or output have the semantics of samples assume inputs are in
        Cholesky form and return outputs in Cholesky form. In particular, if
        this flag is `True`, input to `log_prob` is presumed of Cholesky
        form and output from `sample`, `mean`, and `mode` are of Cholesky
        form. Setting this argument to `True` is purely a computational
        optimization and does not change the underlying distribution; for
        instance, `mean` returns the Cholesky of the mean, not the mean of
        Cholesky factors. The `variance` and `stddev` methods are unaffected
        by this flag.
        Default value: `False` (i.e., input/output does not have Cholesky
        semantics).
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is
        raised if one or more of the statistic's batch members are
        undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if zero or both of `scale` and `scale_tril` are passed in.
    """
    # Must stay the first statement so that only the constructor arguments
    # (and `self`) are captured.
    parameters = dict(locals())

    with tf.name_scope(name) as name:
        with tf.name_scope("init"):
            # True when both are `None` or both are given.
            if (scale is None) == (scale_tril is None):
                raise ValueError(
                    "Must pass scale or scale_tril, but not both.")

            dtype = dtype_util.common_dtype([df, scale, scale_tril], tf.float32)
            df = tf.convert_to_tensor(df, name="df", dtype=dtype)
            if scale is not None:
                scale = tf.convert_to_tensor(scale, name="scale", dtype=dtype)
                if validate_args:
                    scale = distribution_util.assert_symmetric(scale)
                # Either way the parent class is given the Cholesky factor.
                scale_tril = tf.linalg.cholesky(scale)
            else:  # scale_tril is not None
                scale_tril = tf.convert_to_tensor(scale_tril,
                                                  name="scale_tril",
                                                  dtype=dtype)
                if validate_args:
                    scale_tril = distribution_util.with_dependencies([
                        assert_util.assert_positive(
                            tf.linalg.diag_part(scale_tril),
                            message="scale_tril must be positive definite"
                        ),
                        assert_util.assert_equal(
                            tf.shape(scale_tril)[-1],
                            tf.shape(scale_tril)[-2],
                            message="scale_tril must be square")
                    ], scale_tril)

        super(Wishart, self).__init__(
            df=df,
            scale_operator=tf.linalg.LinearOperatorLowerTriangular(
                tril=scale_tril,
                is_non_singular=True,
                is_positive_definite=True,
                is_square=True),
            input_output_cholesky=input_output_cholesky,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
    # Overwrite whatever the parent constructor recorded with this
    # constructor's own arguments.
    self._parameters = parameters
def __init__(self,
             cat,
             components,
             validate_args=False,
             allow_nan_stats=True,
             use_static_graph=False,
             name="Mixture"):
    """Initialize a Mixture distribution.

    A `Mixture` is defined by a `Categorical` (`cat`, representing the
    mixture probabilities) and a list of `Distribution` objects
    all having matching dtype, batch shape, event shape, and continuity
    properties (the components).

    The `num_classes` of `cat` must be possible to infer at graph
    construction time and match `len(components)`.

    Args:
      cat: A `Categorical` distribution instance, representing the
        probabilities of `distributions`.
      components: A list or tuple of `Distribution` instances.
        Each instance must have the same type, be defined on the same
        domain, and have matching `event_shape` and `batch_shape`.
      validate_args: Python `bool`, default `False`. If `True`, raise a
        runtime error if batch or event ranks are inconsistent between cat
        and any of the distributions. This is only checked if the ranks
        cannot be determined statically at graph construction time.
      allow_nan_stats: Boolean, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for
        any batch member. If `True`, batch members with valid parameters
        leading to undefined statistics will return NaN for this statistic.
      use_static_graph: Calls to `sample` will not rely on dynamic tensor
        indexing, allowing for some static graph compilation optimizations,
        but at the expense of sampling all underlying distributions in the
        mixture. (Possibly useful when running on TPUs).
        Default value: `False` (i.e., use dynamic indexing).
      name: A name for this distribution (optional).

    Raises:
      TypeError: If cat is not a `Categorical`, or `components` is not
        a list or tuple, or the elements of `components` are not
        instances of `Distribution`, or do not have matching `dtype`.
      ValueError: If `components` is an empty list or tuple, or its
        elements do not have a statically known event rank.
        If `cat.num_classes` cannot be inferred at graph creation time,
        or the constant value of `cat.num_classes` is not equal to
        `len(components)`, or all `components` and `cat` do not have
        matching static batch shapes, or all components do not
        have matching static event shapes.
    """
    # Must stay the first statement so that only the constructor arguments
    # (and `self`) are captured.
    parameters = dict(locals())
    if not isinstance(cat, categorical.Categorical):
        raise TypeError(
            "cat must be a Categorical distribution, but saw: %s" % cat)
    if not components:
        raise ValueError("components must be a non-empty list or tuple")
    if not isinstance(components, (list, tuple)):
        raise TypeError("components must be a list or tuple, but saw: %s" %
                        components)
    if not all(
            isinstance(c, distribution.Distribution) for c in components):
        raise TypeError(
            "all entries in components must be Distribution instances"
            " but saw: %s" % components)

    dtype = components[0].dtype
    if not all(d.dtype == dtype for d in components):
        raise TypeError("All components must have the same dtype, but saw "
                        "dtypes: %s" % [(d.name, d.dtype) for d in components])

    # Fold every component's static shapes into running merged shapes,
    # starting from the first component's event shape and `cat`'s batch
    # shape.
    static_event_shape = components[0].event_shape
    static_batch_shape = cat.batch_shape
    for di, d in enumerate(components):
        if not tensorshape_util.is_compatible_with(static_batch_shape,
                                                   d.batch_shape):
            raise ValueError(
                "components[{}] batch shape must be compatible with cat "
                "shape and other component batch shapes".format(di))
        static_event_shape = tensorshape_util.merge_with(
            static_event_shape, d.event_shape)
        static_batch_shape = tensorshape_util.merge_with(
            static_batch_shape, d.batch_shape)
    if tensorshape_util.rank(static_event_shape) is None:
        raise ValueError(
            "Expected to know rank(event_shape) from components, but "
            "none of the components provide a static number of ndims")

    # Ensure that all batch and event ndims are consistent.
    with tf.name_scope(name) as name:
        num_components = cat._num_categories()
        static_num_components = tf.get_static_value(num_components)
        if static_num_components is None:
            raise ValueError(
                "Could not infer number of classes from cat and unable "
                "to compare this value to the number of components passed in."
            )
        # Possibly convert from numpy 0-D array.
        static_num_components = int(static_num_components)
        if static_num_components != len(components):
            raise ValueError(
                "cat.num_classes != len(components): %d vs. %d" %
                (static_num_components, len(components)))

        cat_batch_shape = cat.batch_shape_tensor()
        cat_batch_rank = tf.size(cat_batch_shape)
        if validate_args:
            # Runtime checks: each component's batch rank, then its batch
            # shape, must equal that of `cat`.
            batch_shapes = [d.batch_shape_tensor() for d in components]
            batch_ranks = [tf.size(bs) for bs in batch_shapes]
            check_message = ("components[%d] batch shape must match cat "
                             "batch shape")
            self._assertions = [
                assert_util.assert_equal(cat_batch_rank,
                                         batch_ranks[di],
                                         message=check_message % di)
                for di in range(len(components))
            ]
            self._assertions += [
                assert_util.assert_equal(cat_batch_shape,
                                         batch_shapes[di],
                                         message=check_message % di)
                for di in range(len(components))
            ]
        else:
            self._assertions = []

        self._cat = cat
        self._components = list(components)
        self._num_components = static_num_components
        self._static_event_shape = static_event_shape
        self._static_batch_shape = static_batch_shape

        self._use_static_graph = use_static_graph
        # NOTE(review): `static_num_components` was converted with `int(...)`
        # above (after a `None` check that raises), so this condition cannot
        # hold here; the message also refers to `static_sample` although the
        # argument is named `use_static_graph`.
        if use_static_graph and static_num_components is None:
            raise ValueError(
                "Number of categories must be known statically when "
                "`static_sample=True`.")

    super(Mixture, self).__init__(
        dtype=dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
def _parameter_control_dependencies(self, is_init):
    """Collect validity checks for `outcomes`, `logits` and `probs`.

    Args:
      is_init: Python `bool`; when true, shape checks are performed (raising
        immediately where the static shape already shows a violation).

    Returns:
      A list of assertion ops; always empty when `self._validate_args` is
      false.

    Raises:
      ValueError: during the `is_init` call, if `outcomes` is statically
        empty, statically not rank 1, or its static last dimension differs
        from that of `logits`/`probs`.
    """
    assertions = []

    # For `logits` and `probs`, we only want to have an assertion on what the
    # user actually passed. For now, we access the underlying categorical's
    # _logits and _probs directly. After the 2019-10-01 deprecation, it would
    # also work to use .logits() and .probs().
    logits = self._categorical._logits  # pylint: disable=protected-access
    probs = self._categorical._probs  # pylint: disable=protected-access
    outcomes = self._outcomes
    validate_args = self._validate_args

    # Build all shape and dtype checks during the `is_init` call.
    if is_init:

        def validate_equal_last_dim(tensor_a, tensor_b, message):
            """Static check if possible, else an assertion op or `None`."""
            event_size_a = tf.compat.dimension_value(tensor_a.shape[-1])
            event_size_b = tf.compat.dimension_value(tensor_b.shape[-1])
            if event_size_a is not None and event_size_b is not None:
                if event_size_a != event_size_b:
                    raise ValueError(message)
            elif validate_args:
                return assert_util.assert_equal(
                    tf.shape(tensor_a)[-1],
                    tf.shape(tensor_b)[-1],
                    message=message)
            return None

        message = 'Size of outcomes must be greater than 0.'
        if tensorshape_util.num_elements(outcomes.shape) is not None:
            if tensorshape_util.num_elements(outcomes.shape) == 0:
                raise ValueError(message)
        elif validate_args:
            # Use `assert_util` like every other assertion in this file.
            assertions.append(
                assert_util.assert_greater(
                    tf.size(outcomes), 0, message=message))

        if logits is not None:
            maybe_assert = validate_equal_last_dim(
                outcomes,
                logits,
                message=
                'Last dimension of outcomes and logits must be equal size.')
            # Compare against `None` explicitly instead of relying on the
            # truthiness of an op/tensor object.
            if maybe_assert is not None:
                assertions.append(maybe_assert)

        if probs is not None:
            maybe_assert = validate_equal_last_dim(
                outcomes,
                probs,
                message=
                'Last dimension of outcomes and probs must be equal size.')
            if maybe_assert is not None:
                assertions.append(maybe_assert)

        message = 'Rank of outcomes must be 1.'
        ndims = tensorshape_util.rank(outcomes.shape)
        if ndims is not None:
            if ndims != 1:
                raise ValueError(message)
        elif validate_args:
            assertions.append(
                assert_util.assert_rank(outcomes, 1, message=message))

    if not validate_args:
        assert not assertions  # Should never happen.
        return []

    # Add the monotonicity check exactly once: at init when `outcomes` is not
    # a ref according to `tensor_util.is_ref`, and on later (non-init) calls
    # when it is.
    if is_init != tensor_util.is_ref(outcomes):
        assertions.append(
            assert_util.assert_equal(
                tf.math.is_strictly_increasing(outcomes),
                True,
                message='outcomes is not strictly increasing.'))

    return assertions
def _kl_independent(a, b, name="kl_independent"):
    """Batched KL divergence `KL(a || b)` for Independent distributions.

    We can leverage the fact that
    ```
    KL(Independent(a) || Independent(b)) = sum(KL(a || b))
    ```
    where the sum is over the `reinterpreted_batch_ndims`.

    Args:
      a: Instance of `Independent`.
      b: Instance of `Independent`.
      name: (optional) name to use for created ops. Default "kl_independent".

    Returns:
      Batchwise `KL(a || b)`.

    Raises:
      ValueError: If the (fully defined) event shapes of `a` and `b` don't
        match.
      NotImplementedError: If the (fully defined) event shapes of `a` and `b`
        match but those of their underlying distributions don't.
    """
    p = a.distribution
    q = b.distribution

    # The KL between any two (non)-batched distributions is a scalar.
    # Given that the KL between two factored distributions is the sum, i.e.
    # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(p2 || q2), we compute
    # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions.
    if (tensorshape_util.is_fully_defined(a.event_shape) and
            tensorshape_util.is_fully_defined(b.event_shape)):
        if a.event_shape != b.event_shape:
            raise ValueError("Event shapes do not match.")
        if p.event_shape != q.event_shape:
            raise NotImplementedError(
                "KL between Independents with different "
                "event shapes not supported.")
        num_reduce_dims = (tensorshape_util.rank(a.event_shape) -
                           tensorshape_util.rank(p.event_shape))
        # The reinterpreted batch dims are the trailing `num_reduce_dims`
        # axes of the underlying KL: -1, -2, ..., -num_reduce_dims.
        reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)]
        return tf.reduce_sum(
            kullback_leibler.kl_divergence(p, q, name=name),
            axis=reduce_dims)

    with tf.control_dependencies([
            assert_util.assert_equal(a.event_shape_tensor(),
                                     b.event_shape_tensor()),
            assert_util.assert_equal(p.event_shape_tensor(),
                                     q.event_shape_tensor())
    ]):
        # Each rank is computed from that distribution's own static event
        # shape (falling back to its shape tensor).
        num_reduce_dims = (
            prefer_static.rank_from_shape(
                a.event_shape_tensor, a.event_shape) -
            prefer_static.rank_from_shape(
                p.event_shape_tensor, p.event_shape))
        # Same trailing axes as the static branch:
        # -num_reduce_dims, ..., -1.
        reduce_dims = prefer_static.range(-num_reduce_dims, 0, 1)
        return tf.reduce_sum(
            kullback_leibler.kl_divergence(p, q, name=name),
            axis=reduce_dims)
def _replace_event_shape_in_shape_tensor(
    input_shape, event_shape_in, event_shape_out, validate_args):
  """Swaps the trailing event dims of a shape `Tensor` for a new event shape.

  Args:
    input_shape: a rank-1 `Tensor` of integers.
    event_shape_in: the event shape expected to occupy the rightmost dims of
      `input_shape`.
    event_shape_out: the event shape that takes the place of `event_shape_in`
      in the rightmost dims of `input_shape`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.

  Returns:
    output_shape: A rank-1 integer `Tensor` with the same contents as
      `input_shape` except for the event dims, which are replaced with
      `event_shape_out`.
    output_tensorshape: The static shape produced by
      `_replace_event_shape_in_tensorshape` for the same replacement.
  """
  static_output_shape, statically_validated = (
      _replace_event_shape_in_tensorshape(
          tensorshape_util.constant_value_as_shape(input_shape),
          event_shape_in,
          event_shape_out))

  # TODO(b/124240153): Remove map(tf.identity, deps) once tf.function
  # correctly supports control_dependencies.
  if validate_args:
    shape_deps = map(tf.identity, (event_shape_in, event_shape_out))
  else:
    shape_deps = ()

  # Run-time checks are only required when validation was requested and the
  # static pass could not already confirm the shapes.
  needs_runtime_checks = validate_args and not statically_validated

  # Fast path: the answer is fully known statically and nothing remains to be
  # verified at run time.
  if (tensorshape_util.is_fully_defined(static_output_shape) and
      not needs_runtime_checks):
    with tf.control_dependencies(shape_deps):
      output_shape = tf.convert_to_tensor(
          tensorshape_util.as_list(static_output_shape),
          name='output_shape',
          dtype_hint=tf.int32)
    return output_shape, static_output_shape

  with tf.control_dependencies(shape_deps):
    # Prefer the statically known number of event dims; otherwise fall back
    # to computing it in-graph.
    static_event_ndims = tensorshape_util.num_elements(event_shape_in.shape)
    if static_event_ndims is None:
      event_ndims = tf.size(event_shape_in)
    else:
      event_ndims = static_event_ndims

    leading_shape, trailing_event_shape = tf.split(
        input_shape, num_or_size_splits=[-1, event_ndims])

  runtime_assertions = []
  if needs_runtime_checks:
    # `trailing_event_shape` and `event_shape_in` are compatible when their
    # entries agree at every position that is not a `-1` in `event_shape_in`.
    # Note that our validations at construction time ensure there is at most
    # one such entry in `event_shape_in`.
    is_explicit = event_shape_in >= 0
    explicit_actual = tf.boolean_mask(trailing_event_shape, mask=is_explicit)
    explicit_expected = tf.boolean_mask(event_shape_in, mask=is_explicit)
    runtime_assertions.append(
        assert_util.assert_equal(
            explicit_actual,
            explicit_expected,
            message='Input `event_shape` does not match `event_shape_in`.'))
    # There is no separate check that
    # `tf.size(input_shape) > tf.size(event_shape_in)`, since `tf.split`
    # already makes this assertion.

  with tf.control_dependencies(runtime_assertions):
    output_shape = tf.concat(
        [leading_shape, event_shape_out], axis=0, name='output_shape')

  return output_shape, static_output_shape
def __init__(self, permutation, axis=-1, validate_args=False, name=None):
  """Creates the `Permute` bijector.

  Args:
    permutation: An `int`-like vector-shaped `Tensor` representing the
      permutation to apply to the `axis` dimension of the transformed
      `Tensor`.
    axis: Scalar `int` `Tensor` representing the dimension over which to
      `tf.gather`. `axis` must be relative to the end (reading left to
      right) thus must be negative.
      Default value: `-1` (i.e., right-most).
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str`, name given to ops managed by this object.

  Raises:
    TypeError: if `not dtype_util.is_integer(axis.dtype)`.
    TypeError: if `not dtype_util.is_integer(permutation.dtype)`.
    ValueError: if the statically known `permutation` of length `d` does not
      contain exactly one of each of `{0, 1, ..., d - 1}`.
    NotImplementedError: if `axis` is not known prior to graph execution.
    NotImplementedError: if `axis` is not negative.
  """
  with tf.name_scope(name or "permute") as name:
    axis = tf.convert_to_tensor(axis, name="axis")
    if not dtype_util.is_integer(axis.dtype):
      raise TypeError("axis.dtype ({}) should be `int`-like.".format(
          dtype_util.name(axis.dtype)))

    permutation = tf.convert_to_tensor(permutation, name="permutation")
    if not dtype_util.is_integer(permutation.dtype):
      raise TypeError(
          "permutation.dtype ({}) should be `int`-like.".format(
              dtype_util.name(permutation.dtype)))

    p = tf.get_static_value(permutation)
    if p is not None:
      # Static check: `p` has `p.size` entries, so set-equality with
      # {0, ..., p.size - 1} also rules out duplicates.
      if set(p) != set(np.arange(p.size)):
        raise ValueError(
            "Permutation over `d` must contain exactly one of "
            "each of `{0, 1, ..., d}`.")
    elif validate_args:
      # `top_k` of the negated values sorts descending, so `-p` is the
      # permutation sorted ascending; a valid permutation sorts to
      # 0, 1, ..., d - 1.
      p, _ = tf.math.top_k(
          -permutation, k=tf.shape(permutation)[-1], sorted=True)
      # `tf.range(tf.size(p))` is int32; cast it to the permutation's dtype
      # so the comparison also works for e.g. int64 permutations instead of
      # failing on a dtype mismatch. For int32 this cast is a no-op.
      expected = tf.cast(tf.range(tf.size(p)), p.dtype)
      permutation = distribution_util.with_dependencies([
          assert_util.assert_equal(
              -p,
              expected,
              message=(
                  "Permutation over `d` must contain exactly one of "
                  "each of `{0, 1, ..., d}`.")),
      ], permutation)

    axis_ = tf.get_static_value(axis)
    if axis_ is None:
      raise NotImplementedError(
          "`axis` must be known prior to graph "
          "execution.")
    elif axis_ >= 0:
      raise NotImplementedError(
          "`axis` must be relative the rightmost "
          "dimension, i.e., negative.")
    else:
      # A negative axis `-k` is passed to the base class as `k` minimum
      # event dims.
      forward_min_event_ndims = int(np.abs(axis_))

    self._permutation = permutation
    self._axis = axis
    super(Permute, self).__init__(
        forward_min_event_ndims=forward_min_event_ndims,
        is_constant_jacobian=True,
        validate_args=validate_args,
        name=name)