def _value(self, dtype=None, name=None, as_ref=False):
  """Applies `transform_fn` and checks the result against the deferred spec.

  Args:
    dtype: Forwarded to `tf.convert_to_tensor`.
    name: Forwarded to `tf.convert_to_tensor`.
    as_ref: Not referenced by this implementation.

  Returns:
    The result of `tf.convert_to_tensor` applied to
    `self.transform_fn(self.pretransformed_input)`.

  Raises:
    TypeError: if the base dtype of the transformed value differs from
      `self.dtype`.
    TypeError: if the shape of the transformed value is not compatible with
      `self.shape`.
  """
  transformed = self.transform_fn(self.pretransformed_input)  # pylint: disable=not-callable

  # The transformed value must carry the dtype this object advertises.
  actual_dtype = dtype_util.base_dtype(transformed.dtype)
  if actual_dtype != self.dtype:
    dtype_msg = 'Actual dtype ({}) does not match deferred dtype ({}).'
    raise TypeError(dtype_msg.format(
        dtype_util.name(actual_dtype), dtype_util.name(self.dtype)))

  # Likewise the shape must agree with the advertised (deferred) shape.
  actual_shape = transformed.shape
  if not tensorshape_util.is_compatible_with(actual_shape, self.shape):
    shape_msg = 'Actual shape ({}) is incompatible with deferred shape ({}).'
    raise TypeError(shape_msg.format(actual_shape, self.shape))

  return tf.convert_to_tensor(transformed, dtype=dtype, name=name)
def __init__(self,
             shift=None,
             scale=None,
             adjoint=False,
             validate_args=False,
             name="affine_linear_operator"):
  """Instantiates the `AffineLinearOperator` bijector.

  Args:
    shift: Floating-point `Tensor`.
    scale:  Subclass of `LinearOperator`. Represents the (batch) positive
      definite matrix `M` in `R^{k x k}`.
    adjoint: Python `bool` indicating whether to use the `scale` matrix as
      specified or its adjoint.
      Default value: `False`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str` name given to ops managed by this object.

  Raises:
    TypeError: if `scale` is not a `LinearOperator`.
    TypeError: if `shift.dtype` does not match `scale.dtype`.
    ValueError: if not `scale.is_non_singular`.
  """
  with tf.name_scope(name) as name:
    # In the absence of `shift` and `scale`, we'll assume `dtype` is
    # `float32`.
    dtype = tf.float32

    if shift is not None:
      shift = tf.convert_to_tensor(value=shift, name="shift")
      dtype = dtype_util.base_dtype(shift.dtype)
    self._shift = shift

    if scale is not None:
      # Validate the type before touching `scale.dtype`: an arbitrary object
      # may not have a `dtype` attribute, which would surface as an
      # `AttributeError` instead of the documented `TypeError`.
      if not isinstance(scale, tf.linalg.LinearOperator):
        raise TypeError(
            "scale is not an instance of tf.LinearOperator")
      if (shift is not None and
          not dtype_util.base_equal(shift.dtype, scale.dtype)):
        raise TypeError(
            "shift.dtype({}) is incompatible with scale.dtype({})."
            .format(shift.dtype, scale.dtype))
      if validate_args and not scale.is_non_singular:
        raise ValueError("Scale matrix must be non-singular.")
      # When both are given, the operator's dtype takes precedence over the
      # one inferred from `shift`.
      if scale.dtype is not None:
        dtype = dtype_util.base_dtype(scale.dtype)
    self._scale = scale
    self._adjoint = adjoint
    super(AffineLinearOperator, self).__init__(
        forward_min_event_ndims=1,
        is_constant_jacobian=True,
        dtype=dtype,
        validate_args=validate_args,
        name=name)
def _fn(*args):
  """Returns a convex combination of the proposal and target log-probs.

  The mixing weight is `beta = (iter_ + 1) / num_steps`, and the result is
  `beta * target + (1 - beta) * proposal`, with both log-probs evaluated at
  `*args`.

  Args:
    *args: Positional arguments forwarded unchanged to both
      `proposal_log_prob_fn` and `target_log_prob_fn`.

  Returns:
    The combined log-prob, wrapped in `tf.identity` with the name
    `'convex_combined_log_prob'`.
  """
  proposal_lp = tf.identity(
      proposal_log_prob_fn(*args), name='proposal_log_prob')
  target_lp = tf.identity(
      target_log_prob_fn(*args), name='target_log_prob')

  # Do the weight arithmetic in the proposal log-prob's base dtype.
  float_dtype = dtype_util.base_dtype(proposal_lp.dtype)
  current_step = tf.cast(iter_ + 1, float_dtype)
  total_steps = tf.cast(num_steps, float_dtype)
  weight = current_step / total_steps

  combined = weight * target_lp + (1. - weight) * proposal_lp
  return tf.identity(combined, name='convex_combined_log_prob')
def _prepare_args_with_initial_vertex(objective_function,
                                      initial_vertex,
                                      step_sizes,
                                      objective_at_initial_vertex,
                                      batch_evaluate_objective):
  """Constructs a standard axes aligned simplex.

  The simplex consists of `initial_vertex` plus one extra vertex per element
  of `initial_vertex`, obtained by adding `step_sizes` times the matching unit
  vector.

  Args:
    objective_function: Forwarded to `_evaluate_objective_multiple`.
    initial_vertex: `Tensor` used as the first vertex of the simplex.
    step_sizes: Value multiplied with the unit vectors; must broadcast with
      them.
    objective_at_initial_vertex: Optional objective value at `initial_vertex`.
      When given, the objective is evaluated only at the remaining vertices.
    batch_evaluate_objective: Forwarded to `_evaluate_objective_multiple`.

  Returns:
    Tuple `(dim, num_vertices, simplex, objective_at_simplex,
    num_evaluations)`.
  """
  dim = tf.size(initial_vertex)
  num_vertices = dim + 1

  # One unit vector per element of `initial_vertex`, each shaped like it.
  identity = tf.eye(
      dim, dim, dtype=dtype_util.base_dtype(initial_vertex.dtype))
  basis_shape = tf.concat([[dim], tf.shape(initial_vertex)], axis=0)
  unit_vectors_along_axes = tf.reshape(identity, basis_shape)

  # If step_sizes does not broadcast to initial_vertex, the multiplication
  # below will fail.
  offsets = step_sizes * unit_vectors_along_axes
  simplex_face = initial_vertex + offsets
  first_vertex = tf.expand_dims(initial_vertex, axis=0)
  simplex = tf.concat([first_vertex, simplex_face], axis=0)

  # Evaluate the objective function at the simplex vertices, reusing the
  # caller-supplied value at the initial vertex when available.
  if objective_at_initial_vertex is not None:
    face_values, num_evaluations = _evaluate_objective_multiple(
        objective_function, simplex_face, batch_evaluate_objective)
    known_value = tf.expand_dims(objective_at_initial_vertex, axis=0)
    objective_at_simplex = tf.concat([known_value, face_values], axis=0)
  else:
    objective_at_simplex, num_evaluations = _evaluate_objective_multiple(
        objective_function, simplex, batch_evaluate_objective)

  return (dim, num_vertices, simplex, objective_at_simplex, num_evaluations)
def _prepare_args_with_initial_vertex(objective_function,
                                      initial_vertex,
                                      step_sizes,
                                      objective_at_initial_vertex,
                                      batch_evaluate_objective):
  """Constructs a standard axes aligned simplex.

  The simplex consists of `initial_vertex` plus one extra vertex per element
  of `initial_vertex`, obtained by adding `step_sizes` times the matching unit
  vector.

  Args:
    objective_function: Forwarded to `_evaluate_objective_multiple`.
    initial_vertex: `Tensor` used as the first vertex of the simplex.
    step_sizes: Value multiplied with the unit vectors; must broadcast with
      them.
    objective_at_initial_vertex: Optional objective value at `initial_vertex`.
      When given, the objective is evaluated only at the remaining vertices.
    batch_evaluate_objective: Forwarded to `_evaluate_objective_multiple`.

  Returns:
    Tuple `(dim, num_vertices, simplex, objective_at_simplex,
    num_evaluations)`.
  """
  dim = ps.size(initial_vertex)
  # tf.eye complains about np.array(.., np.int32) num_rows, only welcomes numpy
  # scalars. TODO(b/162529062): Remove the following conversion.
  if not tf.is_tensor(dim):
    dim = int(dim)
  num_vertices = dim + 1

  # One unit vector per element of `initial_vertex`, each shaped like it.
  identity = tf.eye(
      dim, dim, dtype=dtype_util.base_dtype(initial_vertex.dtype))
  basis_shape = ps.concat([[dim], ps.shape(initial_vertex)], axis=0)
  unit_vectors_along_axes = tf.reshape(identity, basis_shape)

  # If step_sizes does not broadcast to initial_vertex, the multiplication
  # below will fail.
  offsets = step_sizes * unit_vectors_along_axes
  simplex_face = initial_vertex + offsets
  first_vertex = tf.expand_dims(initial_vertex, axis=0)
  simplex = tf.concat([first_vertex, simplex_face], axis=0)

  # Evaluate the objective function at the simplex vertices, reusing the
  # caller-supplied value at the initial vertex when available.
  if objective_at_initial_vertex is not None:
    face_values, num_evaluations = _evaluate_objective_multiple(
        objective_function, simplex_face, batch_evaluate_objective)
    known_value = tf.expand_dims(objective_at_initial_vertex, axis=0)
    objective_at_simplex = tf.concat([known_value, face_values], axis=0)
  else:
    objective_at_simplex, num_evaluations = _evaluate_objective_multiple(
        objective_function, simplex, batch_evaluate_objective)

  return (dim, num_vertices, simplex, objective_at_simplex, num_evaluations)
def _fn(state_parts, seed):
  """Adds a uniform perturbation to the input state.

  Args:
    state_parts: A list of `Tensor`s of any shape and real dtype representing
      the state parts of the `current_state` of the Markov chain.
    seed: Random seed; split into one seed per state part with
      `samplers.split_seed`.

  Returns:
    perturbed_state_parts: A Python `list` of `Tensor`s, one per entry of
      `state_parts`, each drawn uniformly from
      `[state_part - scale_part, state_part + scale_part)` with the shape and
      base dtype of the corresponding state part.

  Raises:
    ValueError: if `scale` does not broadcast with `state_parts`.
  """
  with tf.name_scope(name or 'random_walk_uniform_fn'):
    scales = scale if mcmc_util.is_list_like(scale) else [scale]
    if len(scales) == 1:
      # Build a new list instead of using `*=`: when `scale` is itself a
      # one-element list, in-place repetition would permanently grow the
      # caller's list and break later calls with a different number of
      # state parts.
      scales = list(scales) * len(state_parts)
    if len(state_parts) != len(scales):
      raise ValueError('`scale` must broadcast with `state_parts`.')
    part_seeds = samplers.split_seed(seed, n=len(state_parts))
    next_state_parts = [
        samplers.uniform(  # pylint: disable=g-complex-comprehension
            minval=state_part - scale_part,
            maxval=state_part + scale_part,
            shape=tf.shape(state_part),
            dtype=dtype_util.base_dtype(state_part.dtype),
            seed=seed_part)
        for scale_part, state_part, seed_part in zip(
            scales, state_parts, part_seeds)
    ]
    return next_state_parts
def _make_empty_queue_for(k, element):
  """Creates a `tf.Tensor` suitable to hold `k` element-shaped tensors.

  For example:

  ```python
    element = tf.constant([[0., 1., 2., 3., 4.],
                           [5., 6., 7., 8., 9.]])

    # A queue capable of holding 3 elements.
    _make_empty_queue_for(3, element)
    # => [[[ 0.,  0.,  0.,  0.,  0.],
    #      [ 0.,  0.,  0.,  0.,  0.]],
    #
    #     [[ 0.,  0.,  0.,  0.,  0.],
    #      [ 0.,  0.,  0.,  0.,  0.]],
    #
    #     [[ 0.,  0.,  0.,  0.,  0.],
    #      [ 0.,  0.,  0.,  0.,  0.]]]
  ```

  Args:
    k: A positive scalar integer, number of elements that each queue will hold.
    element: A `tf.Tensor`, only its shape and dtype information are relevant.

  Returns:
    A zero-filled `tf.Tensor` whose shape is `k` prepended to the shape of
    `element`, with the same base dtype as `element`.
  """
  element_shape = distribution_util.prefer_static_shape(element)
  queue_shape = tf.concat([[k], element_shape], axis=0)
  element_dtype = dtype_util.base_dtype(element.dtype)
  return tf.zeros(queue_shape, dtype=element_dtype)
def _log_prob(self, x):
  """Computes the log-probability of `x` from two log-survival values.

  The result is `log_sub_exp` of the augmented log-survival values gathered at
  indices `x` and `x + 1` (each clipped to `[0, num_categories]`), with
  entries where `x > num_categories - 1` forced to `-inf`.

  Args:
    x: Event values; broadcast against the parameters by
      `_broadcast_cat_event_and_params`.

  Returns:
    A `Tensor` with the shape of the broadcast `x`.
  """
  # TODO(b/149334734): Consider using QuantizedDistribution for the log_prob
  # computation for better precision.
  num_categories = self._num_categories()
  centered = self.loc[..., tf.newaxis] - self._augmented_cutpoints()
  x, log_survival = _broadcast_cat_event_and_params(
      event=x,
      params=tf.math.log_sigmoid(centered),
      base_dtype=dtype_util.base_dtype(self.dtype))

  # Flatten the batch so that `tf.gather` can use a static `batch_dims=1`.
  flat_x = tf.reshape(x, [-1, 1])
  flat_log_survival = tf.reshape(log_survival, [-1, num_categories + 1])

  def _lookup(indices):
    """Gathers per-row log-survival values at the clipped `indices`."""
    return tf.gather(
        params=flat_log_survival,
        indices=tf.clip_by_value(indices, 0, num_categories),
        batch_dims=1)

  log_survival_below = _lookup(flat_x)
  log_survival_at = _lookup(flat_x + 1)
  flat_log_prob = tfp_math.log_sub_exp(log_survival_below, log_survival_at)

  # Deal with case where both survival probabilities are -inf, which gives
  # `flat_log_prob = nan` when it should be -inf.
  neg_inf = tf.constant(-np.inf, dtype=flat_log_prob.dtype)
  out_of_support = flat_x > num_categories - 1
  flat_log_prob = tf.where(out_of_support, neg_inf, flat_log_prob)
  return tf.reshape(flat_log_prob, shape=ps.shape(x))
def _fn(state_parts, seed):
  """Adds a normal perturbation to the input state.

  Args:
    state_parts: A list of `Tensor`s of any shape and real dtype representing
      the state parts of the `current_state` of the Markov chain.
    seed: Random seed used to build a `SeedStream` from which one seed per
      state part is drawn.

  Returns:
    perturbed_state_parts: A Python `list` of `Tensor`s, one per entry of
      `state_parts`, each drawn from a normal with mean `state_part` and
      standard deviation `scale_part`, with the shape and base dtype of the
      corresponding state part.

  Raises:
    ValueError: if `scale` does not broadcast with `state_parts`.
  """
  with tf.name_scope(name or 'random_walk_normal_fn'):
    scales = scale if mcmc_util.is_list_like(scale) else [scale]
    if len(scales) == 1:
      # Build a new list instead of using `*=`: when `scale` is itself a
      # one-element list, in-place repetition would permanently grow the
      # caller's list and break later calls with a different number of
      # state parts.
      scales = list(scales) * len(state_parts)
    if len(state_parts) != len(scales):
      raise ValueError('`scale` must broadcast with `state_parts`.')
    seed_stream = SeedStream(seed, salt='RandomWalkNormalFn')
    next_state_parts = [
        tf.random.normal(  # pylint: disable=g-complex-comprehension
            mean=state_part,
            stddev=scale_part,
            shape=tf.shape(state_part),
            dtype=dtype_util.base_dtype(state_part.dtype),
            seed=seed_stream())
        for scale_part, state_part in zip(scales, state_parts)
    ]
    return next_state_parts
def __init__(self, transform_fn, pretransformed_input, dtype=None,
             shape=NONE_SPECIFIED, name=None):
  """Creates the `DeferredTensor` object.

  Args:
    transform_fn: Python `callable` taking `pretransformed_input` and
      returning a `Tensor` (represented by this object).
    pretransformed_input: object with `shape`, `dtype` properties (typically a
      `tf.Variable`) passed into `transform_fn` when this object is acted upon
      in a `Tensor` context, eg, `tf.convert_to_tensor`, `+`, `tf.math.exp`,
      etc.
    dtype: Equivalent to what would otherwise be
      `transform_fn(pretransformed_input).dtype`.
       Default value: `None` (i.e., `pretransformed_input.dtype`).
    shape: Equivalent to what would otherwise be
      `transform_fn(pretransformed_input).shape`.
       Default value: `'None'` (i.e., `pretransformed_input.shape`).
    name: Python `str` representing this object's `name`; used only in graph
      mode.
      Default value: `None` (i.e.,
      `transform_fn.__name__ + '_' + pretransformed_input.name`).

  Raises:
    TypeError: if `transform_fn` is not `callable`.
    TypeError: if `pretransformed_input` lacks `dtype` and/or `shape`
      properties (and `dtype` and/or `shape` arguments are unspecified).
  """
  if not callable(transform_fn):
    raise TypeError('Argument `transform_fn` must be a Python `callable`.')

  # An unspecified `shape` is the string sentinel `'None'`, not Python `None`
  # (an explicit `None` is a legitimate "unknown shape"). Testing against the
  # sentinel makes a missing `shape` attribute raise the documented
  # `TypeError` rather than an `AttributeError` further down, and stops an
  # explicit `shape=None` from being rejected.
  shape_unspecified = shape == 'None'
  if ((dtype is None and not hasattr(pretransformed_input, 'dtype')) or
      (shape_unspecified and not hasattr(pretransformed_input, 'shape'))):
    raise TypeError('Argument `pretransformed_input` must have `dtype` and '
                    '`shape` properties (unless `dtype`, `shape` arguments '
                    'are explicitly provided.')

  if not name:
    # Derive a default name from the transform and, when present, the input.
    name = '_'.join([
        transform_fn.__name__,
        getattr(pretransformed_input, 'name', '')])
    name = name_util.strip_invalid_chars(name)
    name = name_util.camel_to_lower_snake(name)
  name = name_util.get_name_scope_name(name)
  name = name_util.strip_invalid_chars(name)
  super(DeferredTensor, self).__init__(name=name)
  self._name = name

  self._transform_fn = transform_fn
  self._pretransformed_input = pretransformed_input
  self._dtype = dtype_util.base_dtype(dtype or pretransformed_input.dtype)
  self._shape = tf.TensorShape(
      pretransformed_input.shape if shape_unspecified else shape)

  # Secret handshake with tf.is_tensor to return True for DT.
  #
  # Works around an exception in LinearOperator (which in 2.0.0 checks only
  # `tf.is_tensor`, not also `linear_operator_util.is_ref`:
  # ValueError: Graph parent item 0 is not a Tensor;
  #   <DeferredTensor: dtype=float32, shape=[2], fn=exp>.
  # TODO(b/140157055): Remove this shim after LinOp is patched in 2.0.
  self.is_tensor_like = True
def _inverse_log_det_jacobian(self, y):
  """Returns a pair of scalar zeros, one per branch of the inverse.

  Args:
    y: Input whose base dtype determines the dtype of the result and which is
      passed to `self._assertions`.

  Returns:
    A 2-tuple containing the same scalar zero `Tensor` twice.
  """
  # If event_ndims = 2,
  #   F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1),
  #   so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0].
  assertions = self._assertions(y)
  with tf.control_dependencies(assertions):
    log_det = tf.zeros([], dtype=dtype_util.base_dtype(y.dtype))
  return log_det, log_det
def _cdf(self, k):
  """Evaluates the cumulative probability at `k`.

  Args:
    k: Event values; broadcast against the probabilities by
      `_broadcast_cat_event_and_params`.

  Returns:
    A `Tensor` with the shape of the broadcast `k`, holding the cumulative sum
    of probabilities up to and including index `k` (clipped to the last
    category), and zero wherever `k < 0`.
  """
  # TODO(b/135263541): Improve numerical precision of categorical.cdf.
  probs = self.probs_parameter()
  num_categories = self._num_categories(probs)
  k, probs = _broadcast_cat_event_and_params(
      k, probs, base_dtype=dtype_util.base_dtype(self.dtype))

  # Since the lowest number in the support is 0, any k < 0 should be zero in
  # the output. Record this before clipping destroys the information.
  below_support = k < 0

  # `k` is used as a gather index below, so clip it to {0,...,K-1}.
  index = tf.clip_by_value(tf.cast(k, tf.int32), 0, num_categories - 1)
  result_shape = tf.shape(index)

  # tf.gather(..., batch_dims=batch_dims) requires static batch_dims kwarg, so
  # to handle the case where the batch shape is dynamic, flatten the batch
  # dims (so we know batch_dims=1).
  flat_index = tf.reshape(index, [-1])
  flat_probs = tf.reshape(
      probs, tf.concat(([-1], [num_categories]), axis=0))
  running_total = tf.cumsum(flat_probs, axis=-1)
  flat_cdf = tf.gather(
      running_total, flat_index[..., tf.newaxis], batch_dims=1)

  cdf = tf.reshape(flat_cdf, shape=result_shape)
  zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
  return tf.where(below_support, zero, cdf)
def dense_to_sparse(x, ignore_value=None, name=None):
  """Converts dense `Tensor` to `SparseTensor`, dropping `ignore_value` cells.

  Args:
    x: A `Tensor`.
    ignore_value: Entries in `x` equal to this value will be absent from the
      return `SparseTensor`. If `None`, a default is chosen from the dtype of
      `x`: `''` for strings, otherwise the dtype's zero.
    name: Python `str` prefix for ops created by this function.

  Returns:
    sparse_x: A `tf.SparseTensor` with the same shape as `x`.
  """
  # Copied (with modifications) from:
  # tensorflow/contrib/layers/python/ops/sparse_ops.py.
  with tf.name_scope(name or 'dense_to_sparse'):
    x = tf.convert_to_tensor(x, name='x')

    if ignore_value is None:
      is_string = dtype_util.base_dtype(x.dtype) == tf.string
      # Strings are special-cased because TF strings are converted to numpy
      # objects by default, so the numpy dtype cannot build an empty value.
      default_value = (
          '' if is_string else dtype_util.as_numpy_dtype(x.dtype)(0))
      ignore_value = tf.cast(default_value, x.dtype, name='ignore_value')

    kept = tf.not_equal(x, ignore_value)
    indices = tf.where(kept, name='indices')
    values = tf.gather_nd(x, indices, name='values')
    dense_shape = tf.shape(x, out_type=tf.int64, name='dense_shape')
    return tf.SparseTensor(
        indices=indices, values=values, dense_shape=dense_shape)
def _sample_n(self, n, seed):
  """Draws `n` samples.

  Builds lower-triangular matrices with normal off-diagonal entries and
  square-rooted gamma draws on the diagonal, multiplies them by `self._scale`,
  and (unless `self.input_output_cholesky` is set) returns each result times
  its own adjoint.

  Args:
    n: Number of samples; becomes the leading dimension of the result.
    seed: Seed passed to `samplers.split_seed`.

  Returns:
    A `Tensor` whose shape is `[n]` followed by the batch and event shapes.
  """
  df = tf.convert_to_tensor(self.df)
  batch_shape = self._batch_shape_tensor(df)
  event_shape = self._event_shape_tensor()
  batch_ndims = tf.shape(batch_shape)[0]
  total_ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2

  draw_shape = tf.concat([[n], batch_shape, event_shape], 0)
  normal_seed, gamma_seed = samplers.split_seed(seed, salt='Wishart')

  # Complexity: O(nbk**2)
  normal_draws = samplers.normal(
      shape=draw_shape,
      mean=0.,
      stddev=1.,
      dtype=self.dtype,
      seed=normal_seed)

  # Complexity: O(nbk)
  # This parameterization is equivalent to Chi2, i.e.,
  # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
  expanded_df = df * tf.ones(
      self._scale.batch_shape_tensor(),
      dtype=dtype_util.base_dtype(df.dtype))
  concentration = self._multi_gamma_sequence(
      0.5 * expanded_df, self._dimension())
  gamma_draws = gamma_lib.random_gamma(
      shape=[n], concentration=concentration, rate=0.5, seed=gamma_seed)

  # Complexity: O(nbk**2)
  lower = tf.linalg.band_part(normal_draws, -1, 0)  # Tri-lower.

  # Complexity: O(nbk)
  factor = tf.linalg.set_diag(lower, tf.sqrt(gamma_draws))

  # Make batch-op ready: move the sample axis last and fold it into the
  # trailing event axis.
  # Complexity: O(nbk**2)
  sample_axis_to_back = tf.concat([tf.range(1, total_ndims), [0]], 0)
  factor = tf.transpose(a=factor, perm=sample_axis_to_back)
  folded_shape = tf.concat(
      [batch_shape, [event_shape[0]], [event_shape[1] * n]], 0)
  factor = tf.reshape(factor, folded_shape)

  # Complexity: O(nbM) where M is the complexity of the operator solving a
  # vector system. For LinearOperatorLowerTriangular, each matmul is O(k^3) so
  # this step has complexity O(nbk^3).
  scaled = self._scale.matmul(factor)

  # Undo make batch-op ready.
  # Complexity: O(nbk**2)
  unfolded_shape = tf.concat([batch_shape, event_shape, [n]], 0)
  scaled = tf.reshape(scaled, unfolded_shape)
  sample_axis_to_front = tf.concat(
      [[total_ndims - 1], tf.range(0, total_ndims - 1)], 0)
  samples = tf.transpose(a=scaled, perm=sample_axis_to_front)

  if self.input_output_cholesky:
    return samples
  # Complexity: O(nbk**3)
  return tf.matmul(samples, samples, adjoint_b=True)
def _cdf(self, k):
  """Evaluates the cumulative probability at `k`.

  Args:
    k: Event values; converted to a `Tensor` and broadcast against
      `self.probs` by `_broadcast_cat_event_and_params`.

  Returns:
    A `Tensor` with the shape of the broadcast `k`, holding the cumulative sum
    of probabilities up to and including index `k` (clipped to the last
    category), and zero wherever `k < 0`.
  """
  k = tf.convert_to_tensor(value=k, name="k")
  k, probs = _broadcast_cat_event_and_params(
      k, self.probs, base_dtype=dtype_util.base_dtype(self.dtype))

  # Since the lowest number in the support is 0, any k < 0 should be zero in
  # the output. Record this before clipping destroys the information.
  below_support = k < 0

  # `k` is used as a gather index below, so clip it to {0,...,K-1}.
  index = tf.clip_by_value(
      tf.cast(k, tf.int32), 0, self.num_categories - 1)
  result_shape = tf.shape(input=index)

  # tf.gather(..., batch_dims=batch_dims) requires static batch_dims kwarg, so
  # to handle the case where the batch shape is dynamic, flatten the batch
  # dims (so we know batch_dims=1).
  flat_index = tf.reshape(index, [-1])
  flat_probs = tf.reshape(
      probs, tf.concat(([-1], [self.num_categories]), axis=0))
  running_total = tf.cumsum(flat_probs, axis=-1)
  flat_cdf = tf.gather(
      running_total, flat_index[..., tf.newaxis], batch_dims=1)

  cdf = tf.reshape(flat_cdf, shape=result_shape)
  return tf.where(below_support, tf.zeros_like(cdf), cdf)
def bootstrap_results(self, init_state):
  """Creates initial `state`.

  Args:
    init_state: A `Tensor` (or list-like of `Tensor`s) holding the initial
      chain state. It is passed unchanged to the wrapped kernel's
      `bootstrap_results`.

  Returns:
    The result of `self.extra_setter_fn` applied to the wrapped kernel's
    bootstrap results, a zero second argument, the initial covariance
    scaling divided by the last dimension of the stacked state, and the
    initial covariance, accumulated covariance and `initial_u` attributes.
  """
  with tf.name_scope(
      mcmc_util.make_name(self.name,
                          "AdaptiveRandomWalkMetropolisHastings",
                          "bootstrap_results")):
    if mcmc_util.is_list_like(init_state):
      initial_state_parts = list(init_state)
    else:
      initial_state_parts = [init_state]
    initial_state_parts = [
        tf.convert_to_tensor(s, name="init_state")
        for s in initial_state_parts
    ]

    # Stack once and reuse the result for both the shape and the dtype,
    # rather than creating two identical `tf.stack` ops.
    stacked_state = tf.stack(initial_state_parts)
    shape = stacked_state.shape
    dtype = dtype_util.base_dtype(stacked_state.dtype)

    # One copy of the scalar scaling per state part.
    init_covariance_scaling = tf.cast(
        tf.repeat([self.initial_covariance_scaling],
                  repeats=[shape[0]],
                  axis=0),
        dtype=dtype,
    )

    inner_results = self._impl.bootstrap_results(init_state)
    return self.extra_setter_fn(
        inner_results,
        0,
        init_covariance_scaling / shape[-1],
        self.initial_covariance,
        self._accum_covar,
        self.initial_u,
    )
def __init__(self,
             transform_fn,
             pretransformed_input,
             dtype=None,
             shape=NONE_SPECIFIED,
             name=None):
  """Creates the `DeferredTensor` object.

  Args:
    transform_fn: Python `callable` taking `pretransformed_input` and
      returning a `Tensor` (represented by this object).
    pretransformed_input: object with `shape`, `dtype` properties (typically a
      `tf.Variable`) passed into `transform_fn` when this object is acted upon
      in a `Tensor` context, eg, `tf.convert_to_tensor`, `+`, `tf.math.exp`,
      etc.
    dtype: Equivalent to what would otherwise be
      `transform_fn(pretransformed_input).dtype`.
       Default value: `None` (i.e., `pretransformed_input.dtype`).
    shape: Equivalent to what would otherwise be
      `transform_fn(pretransformed_input).shape`.
       Default value: `'None'` (i.e., `pretransformed_input.shape`).
    name: Python `str` representing this object's `name`; used only in graph
      mode.
      Default value: `None` (i.e.,
      `transform_fn.__name__ + '_' + pretransformed_input.name`).

  Raises:
    TypeError: if `transform_fn` is not `callable`.
    TypeError: if `pretransformed_input` lacks `dtype` and/or `shape`
      properties (and `dtype` and/or `shape` arguments are unspecified).
  """
  if not callable(transform_fn):
    raise TypeError(
        'Argument `transform_fn` must be a Python `callable`.')

  # An unspecified `shape` is the string sentinel `'None'`, not Python `None`
  # (an explicit `None` is a legitimate "unknown shape"). Testing against the
  # sentinel makes a missing `shape` attribute raise the documented
  # `TypeError` rather than an `AttributeError` further down, and stops an
  # explicit `shape=None` from being rejected.
  shape_unspecified = shape == 'None'
  if ((dtype is None and not hasattr(pretransformed_input, 'dtype')) or
      (shape_unspecified and not hasattr(pretransformed_input, 'shape'))):
    raise TypeError(
        'Argument `pretransformed_input` must have `dtype` and '
        '`shape` properties (unless `dtype`, `shape` arguments '
        'are explicitly provided.')

  if not name:
    # Derive a default name from the transform and, when present, the input.
    name = '_'.join([
        transform_fn.__name__,
        getattr(pretransformed_input, 'name', '')
    ])
    name = name_util.strip_invalid_chars(name)
    name = name_util.camel_to_lower_snake(name)
  name = name_util.get_name_scope_name(name)
  name = name_util.strip_invalid_chars(name)
  super(DeferredTensor, self).__init__(name=name)
  self._name = name

  self._transform_fn = transform_fn
  self._pretransformed_input = pretransformed_input
  self._dtype = dtype_util.base_dtype(dtype or pretransformed_input.dtype)
  self._shape = tf.TensorShape(
      pretransformed_input.shape if shape_unspecified else shape)
def _forward_log_det_jacobian(self, x):
  """Returns `log(abs(scale))`, or a scalar zero when there is no scale.

  Args:
    x: Input; only its base dtype is used, and only when `self.scale` is
      `None`.

  Returns:
    A `Tensor` holding the log of the absolute value of `self.scale`, or a
    constant `0.` in the base dtype of `x`.
  """
  # is_constant_jacobian = True for this bijector, hence the
  # `log_det_jacobian` need only be specified for a single input, as this will
  # be tiled to match `event_ndims`.
  scale = self.scale
  if scale is not None:
    return tf.math.log(tf.abs(scale))
  return tf.constant(0., dtype=dtype_util.base_dtype(x.dtype))
def _forward_log_det_jacobian(self, x):
  """Computes `-(n + 1) * sum(log|diag(x)|)` with `n` the last dim of `x`.

  Args:
    x: Input whose diagonal (over the last two axes) is used; also passed to
      `self._assertions`.

  Returns:
    A `Tensor` reduced over the diagonal axis.
  """
  # For a discussion of this (non-obvious) result, see Note 7.2.2 (and the
  # sections leading up to it, for context) in
  # http://neutrino.aquaphoenix.com/ReactionDiffusion/SERC5chap7.pdf
  assertions = self._assertions(x)
  with tf.control_dependencies(assertions):
    matrix_dim = tf.cast(tf.shape(x)[-1], dtype_util.base_dtype(x.dtype))
    log_abs_diag = tf.math.log(tf.abs(tf.linalg.diag_part(x)))
    sum_log_abs_diag = tf.reduce_sum(log_abs_diag, axis=-1)
    return -(matrix_dim + 1) * sum_log_abs_diag
def _log_prob(self, k):
  """Returns the negated sparse softmax cross-entropy of `k` under the logits.

  Args:
    k: Event values used as labels. When `self.validate_args` is set they are
      first passed through
      `distribution_util.embed_check_integer_casting_closed` with an `int32`
      target.

  Returns:
    A `Tensor` equal to
    `-tf.nn.sparse_softmax_cross_entropy_with_logits(labels=k, logits=logits)`
    after broadcasting `k` against the logits.
  """
  logits = self.logits_parameter()
  labels = k
  if self.validate_args:
    labels = distribution_util.embed_check_integer_casting_closed(
        labels, target_dtype=tf.int32)
  labels, logits = _broadcast_cat_event_and_params(
      labels, logits, base_dtype=dtype_util.base_dtype(self.dtype))
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  return -cross_entropy
def histogram(self, x, value_range=None, nbins=None, name=None):
  """Return histogram of values.

  Given the tensor `x`, this operation returns a rank 1 histogram counting the
  number of entries in `x` that fell into every bin. The bins are equal width
  and determined by the arguments `value_range` and `nbins`.

  Args:
    x: 1D numeric `Tensor` of items to count.
    value_range:  Shape [2] `Tensor`. Values `<= value_range[0]` will be mapped
      to `hist[0]`, values `>= value_range[1]` will be mapped to `hist[-1]`.
      Must be same dtype as `x`. If `None`, `[min(x), 1 + max(x)]` is used.
    nbins:  Scalar `int32 Tensor`.  Number of histogram bins. If `None`,
      `value_range[1] - value_range[0]` cast to `int32` is used.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    counts: 1D `Tensor` of counts, as returned by
      `tf.histogram_fixed_width`.
    edges: 1D `Tensor` produced by `tf.range` from `value_range[0]` up to (not
      including) `value_range[1]` in steps of one bin width.
  """
  with tf.compat.v2.name_scope(name or 'histogram'):
    x = tf.convert_to_tensor(value=x, name='x')

    if value_range is None:
      smallest = tf.reduce_min(input_tensor=x)
      largest_plus_one = 1 + tf.reduce_max(input_tensor=x)
      value_range = [smallest, largest_plus_one]
    value_range = tf.convert_to_tensor(value=value_range, name='value_range')
    lower = value_range[0]
    upper = value_range[1]

    if nbins is None:
      nbins = tf.cast(upper - lower, dtype=tf.int32)

    range_dtype = dtype_util.base_dtype(value_range.dtype)
    bin_width = (upper - lower) / tf.cast(nbins, dtype=range_dtype)
    edges = tf.range(
        start=lower,
        limit=upper,
        delta=bin_width,
        dtype=dtype_util.base_dtype(x.dtype))
    counts = tf.histogram_fixed_width(
        x, value_range=value_range, nbins=nbins)
    return counts, edges
def _forward(self, x):
  """Returns the Cholesky factor of `inv(x)^H inv(x)`.

  The inverse is obtained by a triangular solve of `x` against a (batched)
  identity matrix.

  Args:
    x: Input matrix (or batch of matrices); also passed to
      `self._assertions`.

  Returns:
    The result of `tf.linalg.cholesky` on the product described above.
  """
  assertions = self._assertions(x)
  with tf.control_dependencies(assertions):
    dims = tf.shape(input=x)
    identity = tf.eye(
        dims[-1],
        batch_shape=dims[:-2],
        dtype=dtype_util.base_dtype(x.dtype))
    # Note `matrix_triangular_solve` implicitly zeros upper triangular of `x`.
    x_inverse = tf.linalg.triangular_solve(x, identity)
    gram = tf.matmul(x_inverse, x_inverse, adjoint_a=True)
    return tf.linalg.cholesky(gram)
def _log_prob(self, k):
  """Gathers the normalised log-probabilities at the event indices `k`.

  Args:
    k: Event values used as gather indices. When `self.validate_args` is set
      they are first passed through
      `distribution_util.embed_check_integer_casting_closed` with
      `self.dtype` as the target.

  Returns:
    A `Tensor` of log-softmax values of the logits gathered at `k` with
    `batch_dims=1`.
  """
  with tf.name_scope("Cat2log_prob"):
    logits = self.logits_parameter()
    if self.validate_args:
      k = distribution_util.embed_check_integer_casting_closed(
          k, target_dtype=self.dtype)
    k, logits = _broadcast_cat_event_and_params(
        k, logits, base_dtype=dtype_util.base_dtype(self.dtype))
    # Use the fused, numerically stable log-softmax. Taking `log` of
    # `softmax` underflows to `-inf` for strongly negative logits even though
    # the true log-probability is finite.
    logits_normalised = tf.math.log_softmax(logits)
    return tf.gather(logits_normalised, k, batch_dims=1)
def _forward_log_det_jacobian(self, x):
  """Returns `scale.log_abs_determinant()`, or zero when there is no scale.

  Args:
    x: Input; only its base dtype is used, and only when `self.scale` is
      `None`.

  Returns:
    A `Tensor`: the log absolute determinant reported by `self.scale`, or a
    constant `0.` in the base dtype of `x`.
  """
  # is_constant_jacobian = True for this bijector, hence the
  # `log_det_jacobian` need only be specified for a single input, as this will
  # be tiled to match `event_ndims`.
  if self.scale is None:
    return tf.constant(0., dtype=dtype_util.base_dtype(x.dtype))

  # Assertions are only collected (and attached) when validation is enabled.
  if self.validate_args:
    dependencies = self._maybe_collect_assertions()
  else:
    dependencies = []
  with tf.control_dependencies(dependencies):
    return self.scale.log_abs_determinant()
def _inverse(self, y):
  """Maps `y` back to one fewer coordinates along the last axis.

  Args:
    y: Input `Tensor`; the transformation acts along its last axis.

  Returns:
    A `Tensor` whose last dimension is one smaller than that of `y`.
  """
  # As specified in the Stan reference manual, the procedure is as follows:
  # N = y.shape[-1]
  # z_k = y_k / (1 - sum_{i=1 to k-1} y_i)
  # x_k = logit(z_k) - log(1 / (N - k))
  countdown = tf.range(ps.shape(y)[-1] - 1, 0, delta=-1)
  offset = tf.math.log(
      tf.cast(countdown, dtype=dtype_util.base_dtype(y.dtype)))

  remaining_mass = 1. - tf.math.cumsum(y, axis=-1, exclusive=True)
  z = y / remaining_mass

  # logit(z) written as log(z) - log1p(-z), dropping the final coordinate.
  log_z = tf.math.log(z[..., :-1])
  log_one_minus_z = tf.math.log1p(-z[..., :-1])
  return log_z - log_one_minus_z + offset
def _entropy(self, **kwargs):
  """Computes the entropy for a constant-Jacobian, injective bijector.

  Args:
    **kwargs: Keyword arguments split by `self._kwargs_split_fn` into those
      for the base distribution's `entropy` and those for the bijector's
      `inverse_log_det_jacobian`.

  Returns:
    The base distribution's entropy (rescaled / tiled for any event / batch
    override) minus the bijector's inverse log-det-Jacobian, with static
    shape set to `self.batch_shape`.

  Raises:
    NotImplementedError: if the bijector does not have a constant Jacobian.
    NotImplementedError: if the bijector is not injective.
  """
  bijector = self.bijector
  if not bijector.is_constant_jacobian:
    raise NotImplementedError('`entropy` is not implemented.')
  if not bijector._is_injective:  # pylint: disable=protected-access
    raise NotImplementedError('`entropy` is not implemented when '
                              '`bijector` is not injective.')

  distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)

  event_override = tf.convert_to_tensor(self._override_event_shape)
  batch_override = tf.convert_to_tensor(self._override_batch_shape)
  base_batch_shape = self.distribution.batch_shape_tensor()
  base_event_shape = self.distribution.event_shape_tensor()

  # Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It
  # can be shown that:
  #   H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)].
  # If is_constant_jacobian then:
  #   E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c)
  # where c can by anything.
  entropy = self.distribution.entropy(**distribution_kwargs)

  if self._is_maybe_event_override:
    # H[X] = sum_i H[X_i] if X_i are mutually independent.
    # This means that a reduce_sum is a simple rescaling.
    num_event_elements = tf.cast(
        tf.reduce_prod(event_override),
        dtype=dtype_util.base_dtype(entropy.dtype))
    entropy = entropy * num_event_elements

  if self._is_maybe_batch_override:
    # Prepend singleton axes for the overridden batch dims, then tile them
    # out to the override's extent.
    padded_shape = tf.concat(
        [prefer_static.ones_like(batch_override), base_batch_shape], 0)
    entropy = tf.reshape(entropy, padded_shape)
    tile_multiples = tf.concat(
        [batch_override, prefer_static.ones_like(base_batch_shape)], 0)
    entropy = tf.tile(entropy, tile_multiples)

  # Any input works for a constant Jacobian; use zeros of the full shape.
  full_shape = tf.concat([
      self._batch_shape_tensor(batch_override, base_batch_shape),
      self._event_shape_tensor(event_override, base_event_shape)
  ], 0)
  dummy = prefer_static.zeros(shape=full_shape, dtype=self.dtype)

  # Prefer the static event rank; fall back to a dynamic one.
  if tensorshape_util.rank(self.event_shape) is not None:
    event_ndims = tensorshape_util.rank(self.event_shape)
  else:
    event_ndims = tf.size(
        self._event_shape_tensor(event_override, base_event_shape))

  ildj = bijector.inverse_log_det_jacobian(
      dummy, event_ndims=event_ndims, **bijector_kwargs)
  entropy = entropy - tf.cast(ildj, entropy.dtype)
  tensorshape_util.set_shape(entropy, self.batch_shape)
  return entropy
def _maybe_validate_distributions(distributions, dtype_override,
                                  validate_args):
  """Checks that `distributions` satisfies all assumptions.

  Static violations raise immediately; properties that cannot be determined
  statically are turned into assertion ops when `validate_args` is set.

  Args:
    distributions: Non-empty iterable of distributions.
    dtype_override: When `None`, the distributions' dtypes are required to
      agree; otherwise the dtype check is skipped.
    validate_args: Python `bool`; whether to emit runtime assertions for
      checks that cannot be done statically.

  Returns:
    assertions: Python `list` of assertion ops (possibly empty).

  Raises:
    ValueError: if `distributions` is not iterable or is empty.
    TypeError: if the distributions' (non-`None`) dtypes differ and
      `dtype_override` is `None`.
    ValueError: if a distribution's static event rank is not 1.
    ValueError: if the fully-defined batch shapes are not all the same.
  """
  if not (_is_iterable(distributions) and distributions):
    raise ValueError('`distributions` must be a list of one or more '
                     'distributions.')

  assertions = []

  if dtype_override is None:
    dts = [
        dtype_util.base_dtype(dist.dtype)
        for dist in distributions
        if dist.dtype is not None
    ]
    # Comparing the list against itself shifted by one checks that every
    # adjacent pair is equal.
    if dts[1:] != dts[:-1]:
      raise TypeError(
          'Distributions must have same dtype; found: {}.'.format(
              set(dtype_util.name(dt) for dt in dts)))

  # Validate event_ndims.
  for dist in distributions:
    static_event_rank = tensorshape_util.rank(dist.event_shape)
    if static_event_rank is None:
      if validate_args:
        assertions.append(
            assert_util.assert_equal(
                1,
                tf.size(dist.event_shape_tensor()),
                message='`Distribution` must be vector variate.'))
    elif static_event_rank != 1:
      raise ValueError('`Distribution` must be vector variate, '
                       'found event nimds: {}.'.format(
                           tensorshape_util.rank(dist.event_shape)))

  # Validate batch shapes: statically when possible, else via assertions.
  batch_shapes = [dist.batch_shape for dist in distributions]
  if all(tensorshape_util.is_fully_defined(b) for b in batch_shapes):
    if batch_shapes[1:] != batch_shapes[:-1]:
      raise ValueError('Distributions must have the same `batch_shape`; '
                       'found: {}.'.format(batch_shapes))
  elif validate_args:
    batch_shapes = [
        tensorshape_util.as_list(dist.batch_shape)  # pylint: disable=g-complex-comprehension
        if tensorshape_util.is_fully_defined(dist.batch_shape)
        else dist.batch_shape_tensor()
        for dist in distributions
    ]
    assertions.extend(
        assert_util.assert_equal(  # pylint: disable=g-complex-comprehension
            later, earlier,
            message='Distribution `batch_shape`s must be identical.')
        for later, earlier in zip(batch_shapes[1:], batch_shapes[:-1]))

  return assertions
def _log_survival_function(self, x):
  """Looks up the augmented log-survival value at index `x + 1`.

  Args:
    x: Event values; broadcast against the parameters by
      `_broadcast_cat_event_and_params`.

  Returns:
    A `Tensor` with the shape of the broadcast `x`.
  """
  num_categories = self._num_categories()
  centered = self.loc[..., tf.newaxis] - self._augmented_cutpoints()
  x, log_survival = _broadcast_cat_event_and_params(
      event=x,
      params=tf.math.log_sigmoid(centered),
      base_dtype=dtype_util.base_dtype(self.dtype))

  # Flatten the batch so that `tf.gather` can use a static `batch_dims=1`.
  flat_x = tf.reshape(x, [-1, 1])
  flat_log_survival = tf.reshape(log_survival, [-1, num_categories + 1])

  # Indices are clipped so out-of-range events map onto the end entries.
  lookup_index = tf.clip_by_value(flat_x + 1, 0, num_categories)
  flat_result = tf.gather(
      params=flat_log_survival, indices=lookup_index, batch_dims=1)
  return tf.reshape(flat_result, shape=ps.shape(x))
def _forward(self, x):
  """Maps `x` to one more coordinate along the last axis.

  The appended final coordinate is one minus the sum of the others.

  Args:
    x: Input `Tensor`; the transformation acts along its last axis.

  Returns:
    A `Tensor` whose last dimension is one larger than that of `x`.
  """
  # As specified in the Stan reference manual, the procedure is as follows:
  # N = x.shape[-1] + 1
  # z_k = sigmoid(x + log(1 / (N - k)))
  # y_1 = z_1
  # y_k = (1 - sum_{i=1 to k-1} y_i) * z_k
  # y_N = 1 - sum_{i=1 to N-1} y_i
  # TODO(b/128857065): The numerics can possibly be improved here with a
  # log-space computation.
  countdown = tf.range(ps.shape(x)[-1], 0, delta=-1)
  offset = -tf.math.log(
      tf.cast(countdown, dtype=dtype_util.base_dtype(x.dtype)))

  z = tf.math.sigmoid(x + offset)
  remaining_fraction = tf.math.cumprod(1 - z, axis=-1, exclusive=True)
  y = z * remaining_fraction

  last_coordinate = 1. - tf.reduce_sum(y, axis=-1, keepdims=True)
  return tf.concat([y, last_coordinate], axis=-1)
def _parameter_control_dependencies(self, is_init):
  """Validates `self.sample_shape` and collects runtime assertions.

  Three properties are checked in turn: rank (at most 1), dtype (`int32` or
  `int64`) and values (non-negative). Violations that can be detected
  statically raise immediately; otherwise, when `self.validate_args` is set,
  an assertion op is appended to the returned list.

  Args:
    is_init: Python `bool`. Compared against "rank is statically unknown" and
      against `tensor_util.is_ref(self.sample_shape)` to decide whether the
      rank and value checks run in this call; the dtype check runs only when
      it is `True`.

  Returns:
    assertions: Python `list` of assertion ops (possibly empty).

  Raises:
    ValueError: if the static rank of `sample_shape` exceeds 1.
    TypeError: if the dtype of `sample_shape` is not `int32` or `int64`.
    ValueError: if the static value of `sample_shape` has a negative entry.
  """
  assertions = []
  # Tensor form of `self.sample_shape`; created at most once, on first need.
  sample_shape = None  # Memoize concretization.

  # Check valid shape.
  ndims_ = tensorshape_util.rank(self.sample_shape.shape)
  # Runs when `is_init` is True and the rank is known statically, or when
  # `is_init` is False and the rank is not known statically.
  if is_init != (ndims_ is None):
    msg = 'Argument `sample_shape` must be either a scalar or a vector.'
    if ndims_ is not None:
      if ndims_ > 1:
        raise ValueError(msg)
    elif self.validate_args:
      if sample_shape is None:
        sample_shape = tf.convert_to_tensor(self.sample_shape)
      assertions.append(
          assert_util.assert_less(tf.rank(sample_shape), 2, message=msg))

  # Check valid dtype.
  if is_init:  # No xor check because `dtype` cannot change.
    dtype_ = self.sample_shape.dtype
    if dtype_ is None:
      # Fall back to the dtype of the converted tensor.
      if sample_shape is None:
        sample_shape = tf.convert_to_tensor(self.sample_shape)
      dtype_ = sample_shape.dtype
    if dtype_util.base_dtype(dtype_) not in {tf.int32, tf.int64}:
      raise TypeError(
          'Argument `sample_shape` must be integer type; '
          'saw {}.'.format(dtype_util.name(dtype_)))

  # Check valid "value".
  # Runs when exactly one of `is_init` and "`sample_shape` is a ref" holds.
  if is_init != tensor_util.is_ref(self.sample_shape):
    sample_shape_ = tf.get_static_value(self.sample_shape)
    msg = 'Argument `sample_shape` must have non-negative values.'
    if sample_shape_ is not None:
      if np.any(np.array(sample_shape_) < 0):
        raise ValueError('{} Saw: {}'.format(msg, sample_shape_))
    elif self.validate_args:
      if sample_shape is None:
        sample_shape = tf.convert_to_tensor(self.sample_shape)
      assertions.append(
          assert_util.assert_greater(sample_shape, -1, message=msg))
  return assertions