def _inverse_log_det_jacobian(self, y):
  # If event_ndims = 2,
  #   F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1),
  #   so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0].
  with tf.control_dependencies(self._assertions(y)):
    zero = tf.zeros([], dtype=dtype_util.base_dtype(y.dtype))
    return zero, zero
def _cdf(self, k):
  # TODO(b/135263541): Improve numerical precision of categorical.cdf.
  probs = self.probs_parameter()
  num_categories = self._num_categories(probs)

  k, probs = _broadcast_cat_event_and_params(
      k, probs, base_dtype=dtype_util.base_dtype(self.dtype))

  # Since the lowest number in the support is 0, any k < 0 should be zero in
  # the output.
  should_be_zero = k < 0

  # Will use k as an index in the gather below, so clip it to {0,...,K-1}.
  k = tf.clip_by_value(tf.cast(k, tf.int32), 0, num_categories - 1)

  batch_shape = tf.shape(k)

  # tf.gather(..., batch_dims=batch_dims) requires static batch_dims kwarg, so
  # to handle the case where the batch shape is dynamic, flatten the batch
  # dims (so we know batch_dims=1).
  k_flat_batch = tf.reshape(k, [-1])
  probs_flat_batch = tf.reshape(
      probs, tf.concat(([-1], [num_categories]), axis=0))

  cdf_flat = tf.gather(
      tf.cumsum(probs_flat_batch, axis=-1),
      k_flat_batch[..., tf.newaxis],
      batch_dims=1)

  cdf = tf.reshape(cdf_flat, shape=batch_shape)

  zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
  return tf.where(should_be_zero, zero, cdf)
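
# Illustrative sketch (not library code): the clip-then-gather-on-cumsum trick
# above, replayed in NumPy with a made-up probability vector and query points.
import numpy as np

probs = np.array([0.2, 0.5, 0.3])               # categorical over {0, 1, 2}
k = np.array([-1, 0, 1, 2])                     # includes one point below the support

should_be_zero = k < 0
k_clipped = np.clip(k, 0, probs.shape[-1] - 1)  # safe gather index in {0, ..., K-1}
cdf = np.cumsum(probs)[k_clipped]               # P(K <= k) for the clipped indices
cdf = np.where(should_be_zero, 0., cdf)
print(cdf)                                      # [0.  0.2 0.7 1. ]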
def _log_prob(self, k):
  logits = self.logits_parameter()
  if self.validate_args:
    k = distribution_util.embed_check_integer_casting_closed(
        k, target_dtype=tf.int32)
  k, logits = _broadcast_cat_event_and_params(
      k, logits, base_dtype=dtype_util.base_dtype(self.dtype))
  return -tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=k, logits=logits)
def _forward_log_det_jacobian(self, x):
  # For a discussion of this (non-obvious) result, see Note 7.2.2 (and the
  # sections leading up to it, for context) in
  # http://neutrino.aquaphoenix.com/ReactionDiffusion/SERC5chap7.pdf
  with tf.control_dependencies(self._assertions(x)):
    matrix_dim = tf.cast(
        tf.shape(x)[-1], dtype_util.base_dtype(x.dtype))
    return -(matrix_dim + 1) * tf.reduce_sum(
        tf.math.log(tf.abs(tf.linalg.diag_part(x))), axis=-1)
def _inverse(self, y):
  # As specified in the Stan reference manual, the procedure is as follows:
  # N = y.shape[-1]
  # z_k = y_k / (1 - sum_{i=1 to k-1} y_i)
  # x_k = logit(z_k) - log(1 / (N - k))
  offset = tf.math.log(
      tf.cast(tf.range(tf.shape(y)[-1] - 1, 0, delta=-1),
              dtype=dtype_util.base_dtype(y.dtype)))
  z = y / (1. - tf.math.cumsum(y, axis=-1, exclusive=True))
  return tf.math.log(z[..., :-1]) - tf.math.log1p(-z[..., :-1]) + offset
def _forward_log_det_jacobian(self, x):
  # is_constant_jacobian = True for this bijector, hence the
  # `log_det_jacobian` need only be specified for a single input, as this will
  # be tiled to match `event_ndims`.
  if self.scale is None:
    return tf.constant(0., dtype=dtype_util.base_dtype(x.dtype))
  with tf.control_dependencies(self._maybe_collect_assertions()
                               if self.validate_args else []):
    return self.scale.log_abs_determinant()
def _forward(self, x):
  with tf.control_dependencies(self._assertions(x)):
    x_shape = tf.shape(x)
    identity_matrix = tf.eye(
        x_shape[-1],
        batch_shape=x_shape[:-2],
        dtype=dtype_util.base_dtype(x.dtype))
    # Note `matrix_triangular_solve` implicitly zeros upper triangular of `x`.
    y = tf.linalg.triangular_solve(x, identity_matrix)
    y = tf.matmul(y, y, adjoint_a=True)
    return tf.linalg.cholesky(y)
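
# Illustrative sketch (not library code): a NumPy check that the forward pass
# above maps the Cholesky factor of M to the Cholesky factor of M^{-1}; the
# test matrix is made up.
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(3, 3))
m = a @ a.T + 3. * np.eye(3)                     # a positive-definite matrix M
chol_m = np.linalg.cholesky(m)                   # input: Cholesky factor of M

y = np.linalg.solve(chol_m, np.eye(3))           # chol_m^{-1}
out = np.linalg.cholesky(y.T @ y)                # mirrors the bijector's forward

expected = np.linalg.cholesky(np.linalg.inv(m))  # Cholesky factor of M^{-1}
print(np.allclose(out, expected))                # True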
def _forward(self, x):
  # As specified in the Stan reference manual, the procedure is as follows:
  # N = x.shape[-1] + 1
  # z_k = sigmoid(x + log(1 / (N - k)))
  # y_1 = z_1
  # y_k = (1 - sum_{i=1 to k-1} y_i) * z_k
  # y_N = 1 - sum_{i=1 to N-1} y_i

  # TODO(b/128857065): The numerics can possibly be improved here with a
  # log-space computation.
  offset = -tf.math.log(
      tf.cast(tf.range(tf.shape(x)[-1], 0, delta=-1),
              dtype=dtype_util.base_dtype(x.dtype)))
  z = tf.math.sigmoid(x + offset)
  y = z * tf.math.cumprod(1 - z, axis=-1, exclusive=True)
  return tf.concat([y, 1. - tf.reduce_sum(y, axis=-1, keepdims=True)],
                   axis=-1)
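
# Illustrative sketch (not library code): the Stan stick-breaking recipe above
# and the corresponding `_inverse` earlier in this section, replayed in NumPy
# on a made-up point.
import numpy as np

def sigmoid(t):
  return 1. / (1. + np.exp(-t))

x = np.array([0.5, -1.0, 2.0])                   # point in R^{N-1}, N = 4
offset = -np.log(np.arange(x.shape[-1], 0, -1))  # log(1 / (N - k)), k = 1..N-1

z = sigmoid(x + offset)
y = z * np.cumprod(np.concatenate([[1.], 1. - z[:-1]]))  # exclusive cumprod
y = np.concatenate([y, [1. - y.sum()]])
print(y, y.sum())                                # lands on the simplex, sums to 1

z_back = y / (1. - np.concatenate([[0.], np.cumsum(y)[:-1]]))  # exclusive cumsum
x_back = np.log(z_back[:-1]) - np.log1p(-z_back[:-1]) - offset
print(np.allclose(x_back, x))                    # True: the inverse recovers x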
def _entropy(self, **kwargs):
  if not self.bijector.is_constant_jacobian:
    raise NotImplementedError("entropy is not implemented")
  if not self.bijector._is_injective:  # pylint: disable=protected-access
    raise NotImplementedError("entropy is not implemented when "
                              "bijector is not injective.")
  distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
  # Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It
  # can be shown that:
  #   H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)].
  # If is_constant_jacobian then:
  #   E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c)
  # where c can be anything.
  entropy = self.distribution.entropy(**distribution_kwargs)
  if self._is_maybe_event_override:
    # H[X] = sum_i H[X_i] if X_i are mutually independent.
    # This means that a reduce_sum is a simple rescaling.
    entropy = entropy * tf.cast(
        tf.reduce_prod(self._override_event_shape),
        dtype=dtype_util.base_dtype(entropy.dtype))
  if self._is_maybe_batch_override:
    new_shape = tf.concat([
        prefer_static.ones_like(self._override_batch_shape),
        self.distribution.batch_shape_tensor()
    ], 0)
    entropy = tf.reshape(entropy, new_shape)
    multiples = tf.concat([
        self._override_batch_shape,
        prefer_static.ones_like(self.distribution.batch_shape_tensor())
    ], 0)
    entropy = tf.tile(entropy, multiples)
  dummy = prefer_static.zeros(
      shape=tf.concat(
          [self.batch_shape_tensor(), self.event_shape_tensor()], 0),
      dtype=self.dtype)
  event_ndims = (tensorshape_util.rank(self.event_shape)
                 if tensorshape_util.rank(self.event_shape) is not None
                 else tf.size(self.event_shape_tensor()))
  ildj = self.bijector.inverse_log_det_jacobian(
      dummy, event_ndims=event_ndims, **bijector_kwargs)
  entropy = entropy - tf.cast(ildj, entropy.dtype)
  tensorshape_util.set_shape(entropy, self.batch_shape)
  return entropy
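
# Illustrative sketch (not library code): a SciPy check of the identity
# H[Y] = H[X] + (log o abs o det o J o g)(c) for a constant-Jacobian g, using
# the made-up affine map g(x) = 2x + 3 and X ~ Normal(0, 1).
import numpy as np
from scipy import stats

h_x = stats.norm(loc=0., scale=1.).entropy()  # H[X]
h_y = stats.norm(loc=3., scale=2.).entropy()  # H[Y], Y = g(X) = 2X + 3
print(np.isclose(h_y, h_x + np.log(2.)))      # True: shift by log|det Dg| = log 2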
def _forward_log_det_jacobian(self, x, **kwargs):
  x = tf.convert_to_tensor(x, name="x")

  fldj = tf.cast(0., dtype=dtype_util.base_dtype(x.dtype))

  if not self.bijectors:
    return fldj

  event_ndims = self._maybe_get_static_event_ndims(
      self.forward_min_event_ndims)

  if _use_static_shape(x, event_ndims):
    event_shape = x.shape[tensorshape_util.rank(x.shape) - event_ndims:]
  else:
    event_shape = tf.shape(x)[tf.rank(x) - event_ndims:]

  # TODO(b/129973548): Document and simplify.
  for b in reversed(self.bijectors):
    fldj = fldj + b.forward_log_det_jacobian(
        x, event_ndims=event_ndims, **kwargs.get(b.name, {}))
    if _use_static_shape(x, event_ndims):
      event_shape = b.forward_event_shape(event_shape)
      event_ndims = self._maybe_get_static_event_ndims(
          tensorshape_util.rank(event_shape))
    else:
      event_shape = b.forward_event_shape_tensor(event_shape)
      event_shape_ = distribution_util.maybe_get_static_value(event_shape)
      event_ndims = tf.size(event_shape)
      event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
      if event_ndims_ is not None and event_shape_ is not None:
        event_ndims = event_ndims_
        event_shape = event_shape_
    x = b.forward(x, **kwargs.get(b.name, {}))

  return fldj
def _inverse_log_det_jacobian(self, y, **kwargs):
  y = tf.convert_to_tensor(y, name="y")

  ildj = tf.cast(0., dtype=dtype_util.base_dtype(y.dtype))

  if not self.bijectors:
    return ildj

  event_ndims = self._maybe_get_static_event_ndims(
      self.inverse_min_event_ndims)

  if _use_static_shape(y, event_ndims):
    event_shape = y.shape[tensorshape_util.rank(y.shape) - event_ndims:]
  else:
    event_shape = tf.shape(y)[tf.rank(y) - event_ndims:]

  # TODO(b/129973548): Document and simplify.
  for b in self.bijectors:
    ildj = ildj + b.inverse_log_det_jacobian(
        y, event_ndims=event_ndims, **kwargs.get(b.name, {}))

    if _use_static_shape(y, event_ndims):
      event_shape = b.inverse_event_shape(event_shape)
      event_ndims = self._maybe_get_static_event_ndims(
          tensorshape_util.rank(event_shape))
    else:
      event_shape = b.inverse_event_shape_tensor(event_shape)
      event_shape_ = distribution_util.maybe_get_static_value(event_shape)
      event_ndims = tf.size(event_shape)
      event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
      if event_ndims_ is not None and event_shape_ is not None:
        event_ndims = event_ndims_
        event_shape = event_shape_
    y = b.inverse(y, **kwargs.get(b.name, {}))
  return ildj
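
# Illustrative sketch (not library code): the accumulation the two chain loops
# above perform, for two made-up scalar bijectors g1(x) = exp(x) and g2(u) = 3u.
import numpy as np

x = 0.7
fldj = 0.
for fwd, log_det in [(np.exp, lambda t: t),                      # log|g1'(x)| = x
                     (lambda u: 3. * u, lambda u: np.log(3.))]:  # log|g2'(u)| = log 3
  fldj += log_det(x)
  x = fwd(x)                                 # push the point forward before the next term
print(np.isclose(fldj, 0.7 + np.log(3.)))    # True: log-dets of a chain add up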
def __init__(self,
             mix_loc,
             temperature,
             distribution,
             loc=None,
             scale=None,
             quadrature_size=8,
             quadrature_fn=quadrature_scheme_softmaxnormal_quantiles,
             validate_args=False,
             allow_nan_stats=True,
             name="VectorDiffeomixture"):
  """Constructs the VectorDiffeomixture on `R^d`.

  The vector diffeomixture (VDM) approximates the compound distribution

  ```none
  p(x) = int p(x | z) p(z) dz,
  where z is in the K-simplex, and
  p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k])
  ```

  Args:
    mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`. In terms
      of samples, larger `mix_loc[..., k]` ==> `Z` is more likely to put more
      weight on its `kth` component.
    temperature: `float`-like `Tensor`. Broadcastable with `mix_loc`. In
      terms of samples, smaller `temperature` means one component is more
      likely to dominate. I.e., smaller `temperature` makes the VDM look more
      like a standard mixture of `K` components.
    distribution: `tfp.distributions.Distribution`-like instance. Distribution
      from which `d` iid samples are used as input to the selected affine
      transformation. Must be a scalar-batch, scalar-event distribution.
      Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED`
      or it is a function of non-trainable parameters. WARNING: If you
      backprop through a VectorDiffeomixture sample and the `distribution` is
      not `FULLY_REPARAMETERIZED` yet is a function of trainable variables,
      then the gradient will be incorrect!
    loc: Length-`K` list of `float`-type `Tensor`s. The `k`-th element
      represents the `shift` used for the `k`-th affine transformation. If
      the `k`-th item is `None`, `loc` is implicitly `0`. When specified,
      must have shape `[B1, ..., Bb, d]` where `b >= 0` and `d` is the event
      size.
    scale: Length-`K` list of `LinearOperator`s. Each should be
      positive-definite and operate on a `d`-dimensional vector space. The
      `k`-th element represents the `scale` used for the `k`-th affine
      transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`,
      `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices.
    quadrature_size: Python `int` scalar representing number of quadrature
      points. Larger `quadrature_size` means `q_N(x)` better approximates
      `p(x)`.
    quadrature_fn: Python callable taking `normal_loc`, `normal_scale`,
      `quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
      representing the SoftmaxNormal grid and corresponding normalized
      weights.
      Default value: `quadrature_scheme_softmaxnormal_quantiles`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if `not scale or len(scale) < 2`.
    ValueError: if `len(loc) != len(scale)`
    ValueError: if `quadrature_grid_and_probs is not None` and
      `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
    ValueError: if `validate_args` and any not scale.is_positive_definite.
    TypeError: if any scale.dtype != scale[0].dtype.
    TypeError: if any loc.dtype != scale[0].dtype.
    NotImplementedError: if `len(scale) != 2`.
    ValueError: if `not distribution.is_scalar_batch`.
    ValueError: if `not distribution.is_scalar_event`.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    if not scale or len(scale) < 2:
      raise ValueError("Must specify list (or list-like object) of scale "
                       "LinearOperators, one for each component with "
                       "num_component >= 2.")

    if loc is None:
      loc = [None] * len(scale)

    if len(loc) != len(scale):
      raise ValueError("loc/scale must be same-length lists "
                       "(or same-length list-like objects).")

    dtype = dtype_util.base_dtype(scale[0].dtype)

    loc = [
        tf.convert_to_tensor(loc_, dtype=dtype, name="loc{}".format(k))
        if loc_ is not None else None for k, loc_ in enumerate(loc)
    ]

    for k, scale_ in enumerate(scale):
      if validate_args and not scale_.is_positive_definite:
        raise ValueError(
            "scale[{}].is_positive_definite = {} != True".format(
                k, scale_.is_positive_definite))
      if dtype_util.base_dtype(scale_.dtype) != dtype:
        raise TypeError(
            "dtype mismatch; scale[{}].base_dtype=\"{}\" != \"{}\"".format(
                k, dtype_util.name(scale_.dtype), dtype_util.name(dtype)))

    self._endpoint_affine = [
        affine_linear_operator_bijector.AffineLinearOperator(  # pylint: disable=g-complex-comprehension
            shift=loc_,
            scale=scale_,
            validate_args=validate_args,
            name="endpoint_affine_{}".format(k))
        for k, (loc_, scale_) in enumerate(zip(loc, scale))
    ]

    # TODO(jvdillon): Remove once we support k-mixtures.
    # We make this assertion here because otherwise `grid` would need to be a
    # vector not a scalar.
    if len(scale) != 2:
      raise NotImplementedError("Currently only bimixtures are supported; "
                                "len(scale)={} is not 2.".format(len(scale)))

    mix_loc = tf.convert_to_tensor(mix_loc, dtype=dtype, name="mix_loc")
    temperature = tf.convert_to_tensor(
        temperature, dtype=dtype, name="temperature")
    self._grid, probs = tuple(
        quadrature_fn(mix_loc / temperature, 1. / temperature,
                      quadrature_size, validate_args))

    # Note: by creating the logits as `log(prob)` we ensure that
    # `self.mixture_distribution.logits` is equivalent to
    # `math_ops.log(self.mixture_distribution.probs)`.
    self._mixture_distribution = categorical.Categorical(
        logits=tf.math.log(probs),
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats)

    asserts = distribution_util.maybe_check_scalar_distribution(
        distribution, dtype, validate_args)
    if asserts:
      self._grid = distribution_util.with_dependencies(asserts, self._grid)
    self._distribution = distribution

    self._interpolated_affine = [
        affine_linear_operator_bijector.AffineLinearOperator(  # pylint: disable=g-complex-comprehension
            shift=loc_,
            scale=scale_,
            validate_args=validate_args,
            name="interpolated_affine_{}".format(k))
        for k, (loc_, scale_) in enumerate(
            zip(interpolate_loc(self._grid, loc),
                interpolate_scale(self._grid, scale)))
    ]

    [
        self._batch_shape_,
        self._batch_shape_tensor_,
        self._event_shape_,
        self._event_shape_tensor_,
    ] = determine_batch_event_shapes(self._grid, self._endpoint_affine)

    super(VectorDiffeomixture, self).__init__(
        dtype=dtype,
        # We hard-code `FULLY_REPARAMETERIZED` because when
        # `validate_args=True` we verify that indeed
        # `distribution.reparameterization_type == FULLY_REPARAMETERIZED`. A
        # distribution which is a function of only non-trainable parameters
        # also implies we can use `FULLY_REPARAMETERIZED`. However, we cannot
        # easily test for that possibility thus we use `validate_args=False`
        # as a "back-door" to allow users a way to use non
        # `FULLY_REPARAMETERIZED` distribution. In such cases IT IS THE USERS
        # RESPONSIBILITY to verify that the base distribution is a function of
        # non-trainable parameters.
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
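
# Illustrative sketch (not library code): a hypothetical bimixture construction
# matching the argument spec documented above. It assumes the class is exported
# as `tfp.distributions.VectorDiffeomixture`; all parameter values are made up.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
dims = 5
vdm = tfd.VectorDiffeomixture(
    mix_loc=[[0.], [1.]],
    temperature=[1.],
    distribution=tfd.Normal(loc=0., scale=1.),
    loc=[None,                                 # first component: loc is implicitly 0
         np.float32([2.] * dims)],
    scale=[
        tf.linalg.LinearOperatorScaledIdentity(
            num_rows=dims, multiplier=np.float32(1.1),
            is_positive_definite=True),
        tf.linalg.LinearOperatorDiag(
            diag=np.linspace(2.5, 3.5, dims, dtype=np.float32),
            is_positive_definite=True),
    ],
    validate_args=True)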
def quadrature_scheme_softmaxnormal_quantiles(
    normal_loc, normal_scale, quadrature_size, validate_args=False,
    name=None):
  """Use SoftmaxNormal quantiles to form quadrature on `K - 1` simplex.

  A `SoftmaxNormal` random variable `Y` may be generated via

  ```
  Y = SoftmaxCentered(X),
  X = Normal(normal_loc, normal_scale)
  ```

  Args:
    normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
      The location parameter of the Normal used to construct the SoftmaxNormal.
    normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`.
      The scale parameter of the Normal used to construct the SoftmaxNormal.
    quadrature_size: Python `int` scalar representing the number of quadrature
      points.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    grid: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
      convex combination of affine parameters for `K` components.
      `grid[..., :, n]` is the `n`-th grid point, living in the `K - 1`
      simplex.
    probs: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
      probabilities associated with each grid point.
  """
  with tf.name_scope(name or "softmax_normal_grid_and_probs"):
    normal_loc = tf.convert_to_tensor(normal_loc, name="normal_loc")
    dt = dtype_util.base_dtype(normal_loc.dtype)
    normal_scale = tf.convert_to_tensor(
        normal_scale, dtype=dt, name="normal_scale")

    normal_scale = maybe_check_quadrature_param(
        normal_scale, "normal_scale", validate_args)

    dist = normal.Normal(loc=normal_loc, scale=normal_scale)

    def _get_batch_ndims():
      """Helper to get rank(dist.batch_shape), statically if possible."""
      ndims = tensorshape_util.rank(dist.batch_shape)
      if ndims is None:
        ndims = tf.shape(dist.batch_shape_tensor())[0]
      return ndims
    batch_ndims = _get_batch_ndims()

    def _get_final_shape(qs):
      """Helper to build `TensorShape`."""
      bs = tensorshape_util.with_rank_at_least(dist.batch_shape, 1)
      num_components = tf.compat.dimension_value(bs[-1])
      if num_components is not None:
        num_components += 1
      tail = tf.TensorShape([num_components, qs])
      return bs[:-1].concatenate(tail)

    def _compute_quantiles():
      """Helper to build quantiles."""
      # Omit {0, 1} since they might lead to Inf/NaN.
      zero = tf.zeros([], dtype=dist.dtype)
      edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1]
      # Expand edges so it broadcasts across batch dims.
      edges = tf.reshape(
          edges,
          shape=tf.concat(
              [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0))
      quantiles = dist.quantile(edges)
      quantiles = softmax_centered_bijector.SoftmaxCentered().forward(
          quantiles)
      # Cyclically permute left by one.
      perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
      quantiles = tf.transpose(a=quantiles, perm=perm)
      tensorshape_util.set_shape(
          quantiles, _get_final_shape(quadrature_size + 1))
      return quantiles
    quantiles = _compute_quantiles()

    # Compute grid as quantile midpoints.
    grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2.
    # Set shape hints.
    tensorshape_util.set_shape(grid, _get_final_shape(quadrature_size))

    # By construction probs is constant, i.e., `1 / quadrature_size`. This is
    # important, because non-constant probs leads to non-reparameterizable
    # samples.
    probs = tf.fill(
        dims=[quadrature_size],
        value=1. / tf.cast(quadrature_size, dist.dtype))

    return grid, probs
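
# Illustrative sketch (not library code): a scalar SciPy analogue of the
# quantile-midpoint grid built above, with a made-up quadrature size.
import numpy as np
from scipy import stats

quadrature_size = 8
edges = np.linspace(0., 1., quadrature_size + 3)[1:-1]    # omit {0, 1} -> no +/-Inf
quantiles = stats.norm(loc=0., scale=1.).ppf(edges)       # quadrature_size + 1 quantiles
grid = (quantiles[:-1] + quantiles[1:]) / 2.               # midpoints -> grid points
probs = np.full([quadrature_size], 1. / quadrature_size)  # constant, reparameterizable weights
print(grid.shape, probs.sum())                             # (8,) 1.0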
def _sample_n(self, n, seed):
  batch_shape = self.batch_shape_tensor()
  event_shape = self.event_shape_tensor()
  batch_ndims = tf.shape(batch_shape)[0]

  ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
  shape = tf.concat([[n], batch_shape, event_shape], 0)
  stream = SeedStream(seed, salt="Wishart")

  # Complexity: O(nbk**2)
  x = tf.random.normal(
      shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=stream())

  # Complexity: O(nbk)
  # This parametrization is equivalent to Chi2, i.e.,
  # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
  expanded_df = self.df * tf.ones(
      self.scale_operator.batch_shape_tensor(),
      dtype=dtype_util.base_dtype(self.df.dtype))

  g = tf.random.gamma(
      shape=[n],
      alpha=self._multi_gamma_sequence(0.5 * expanded_df, self.dimension),
      beta=0.5,
      dtype=self.dtype,
      seed=stream())

  # Complexity: O(nbk**2)
  x = tf.linalg.band_part(x, -1, 0)  # Tri-lower.

  # Complexity: O(nbk)
  x = tf.linalg.set_diag(x, tf.sqrt(g))

  # Make batch-op ready.
  # Complexity: O(nbk**2)
  perm = tf.concat([tf.range(1, ndims), [0]], 0)
  x = tf.transpose(a=x, perm=perm)
  shape = tf.concat([batch_shape, [event_shape[0]], [event_shape[1] * n]], 0)
  x = tf.reshape(x, shape)

  # Complexity: O(nbM) where M is the complexity of the operator solving a
  # vector system. For LinearOperatorLowerTriangular, each matmul is O(k^3) so
  # this step has complexity O(nbk^3).
  x = self.scale_operator.matmul(x)

  # Undo make batch-op ready.
  # Complexity: O(nbk**2)
  shape = tf.concat([batch_shape, event_shape, [n]], 0)
  x = tf.reshape(x, shape)
  perm = tf.concat([[ndims - 1], tf.range(0, ndims - 1)], 0)
  x = tf.transpose(a=x, perm=perm)

  if not self.input_output_cholesky:
    # Complexity: O(nbk**3)
    x = tf.matmul(x, x, adjoint_b=True)

  return x
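
# Illustrative sketch (not library code): the Bartlett-style construction above,
# replayed in NumPy with made-up df, dimension, and scale; it also exercises the
# ChiSquared(k) == Gamma(alpha=k/2, beta=1/2) identity via `rng.gamma`.
import numpy as np

rng = np.random.default_rng(0)
k, df, n = 3, 5., 100000
scale = np.array([[2., .3, 0.], [.3, 1., .2], [0., .2, 1.5]])
scale_chol = np.linalg.cholesky(scale)

a = np.tril(rng.normal(size=(n, k, k)), k=-1)        # strictly lower: N(0, 1)
chi2 = rng.gamma(shape=0.5 * (df - np.arange(k)), scale=2., size=(n, k))
a[:, np.arange(k), np.arange(k)] = np.sqrt(chi2)     # diag: sqrt(ChiSquared(df - i))
x = scale_chol @ a                                   # apply the scale operator
samples = x @ np.swapaxes(x, -1, -2)                 # Wishart(df, scale) draws
print(np.round(samples.mean(axis=0) / df, 2))        # ~= scale (Wishart mean is df * scale)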
def _inverse_log_det_jacobian(self, y):
  return tf.constant(0., dtype=dtype_util.base_dtype(y.dtype))
def _forward_log_det_jacobian(self, x):
  return tf.constant(0., dtype=dtype_util.base_dtype(x.dtype))
def _inverse_log_det_jacobian(self, y):
  # is_constant_jacobian = True for this bijector, hence the
  # `log_det_jacobian` need only be specified for a single input, as this will
  # be tiled to match `event_ndims`.
  return tf.constant(0., dtype=dtype_util.base_dtype(y.dtype))
def __init__(self,
             distributions,
             dtype_override=None,
             validate_args=False,
             allow_nan_stats=False,
             name='Blockwise'):
  """Construct the `Blockwise` distribution.

  Args:
    distributions: Python `list` of `tfp.distributions.Distribution`
      instances. All distribution instances must have the same `batch_shape`
      and all must have `event_ndims==1`, i.e., be vector-variate
      distributions.
    dtype_override: samples of `distributions` will be cast to this `dtype`.
      If unspecified, all `distributions` must have the same `dtype`.
      Default value: `None` (i.e., do not cast).
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `False`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    self._distributions = distributions
    if dtype_override is not None:
      distributions = tf.nest.map_structure(
          lambda d: _Cast(d, dtype_override), distributions)
    if _is_iterable(distributions):
      self._distribution = (
          joint_distribution_sequential.JointDistributionSequential(
              list(distributions)))
    else:
      self._distribution = distributions

    # Need to cache these for JointDistributions as the batch shape of that
    # distribution can change after `_sample` calls.
    self._cached_batch_shape_tensor = self._distribution.batch_shape_tensor()
    self._cached_batch_shape = self._distribution.batch_shape

    if dtype_override is not None:
      dtype = dtype_override
    else:
      dtype = set(
          dtype_util.base_dtype(dtype)
          for dtype in tf.nest.flatten(self._distribution.dtype)
          if dtype is not None)
      if len(dtype) == 0:  # pylint: disable=g-explicit-length-test
        dtype = tf.float32
      elif len(dtype) == 1:
        dtype = dtype.pop()
      else:
        raise TypeError(
            'Distributions must have same dtype; found: {}.'.format(
                self._distribution.dtype))

    reparameterization_type = set(
        tf.nest.flatten(self._distribution.reparameterization_type))
    reparameterization_type = (
        reparameterization_type.pop() if len(reparameterization_type) == 1
        else reparameterization.NOT_REPARAMETERIZED)

    super(Blockwise, self).__init__(
        dtype=dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=reparameterization_type,
        parameters=parameters,
        name=name)
def convert_nonref_to_tensor(value, dtype=None, dtype_hint=None, name=None):
  """Converts the given `value` to a `Tensor` if input is nonreference type.

  This function converts Python objects of various types to `Tensor` objects
  only if the input has nonreference semantics. Reference semantics are
  characterized by `tensor_util.is_ref` and apply to any object which is a
  `tf.Variable` or an instance of `tf.Module`. This function accepts any
  input which `tf.convert_to_tensor` would also.

  Note: This function diverges from default Numpy behavior for `float` and
    `string` types when `None` is present in a Python list or scalar. Rather
    than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the
      type is inferred from the type of `value`.
    dtype_hint: Optional element type for the returned tensor, used when
      dtype is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so dtype_hint can be used as a soft preference.
      If the conversion to `dtype_hint` is not possible, this argument has no
      effect.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    tensor: A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.

  #### Examples:

  ```python
  from tensorflow_probability.python.experimental.substrates.jax.internal import tensor_util

  x = tf.Variable(0.)
  y = tensor_util.convert_nonref_to_tensor(x)
  x is y
  # ==> True

  x = tf.constant(0.)
  y = tensor_util.convert_nonref_to_tensor(x)
  x is y
  # ==> True

  x = np.array(0.)
  y = tensor_util.convert_nonref_to_tensor(x)
  x is y
  # ==> False
  tf.is_tensor(y)
  # ==> True

  x = tfp.util.DeferredTensor(lambda x: x, 13.37)
  y = tensor_util.convert_nonref_to_tensor(x)
  x is y
  # ==> True
  tf.is_tensor(y)
  # ==> True
  tf.equal(y, 13.37)
  # ==> True
  ```
  """
  # We explicitly do not use a tf.name_scope to avoid graph clutter.
  if value is None:
    return None
  if is_ref(value):
    if dtype is None:
      return value
    dtype_base = dtype_util.base_dtype(dtype)
    value_dtype_base = dtype_util.base_dtype(value.dtype)
    if dtype_base != value_dtype_base:
      raise TypeError('Mutable type must be of dtype "{}" but is "{}".'.format(
          dtype_util.name(dtype_base), dtype_util.name(value_dtype_base)))
    return value
  return tf.convert_to_tensor(
      value, dtype=dtype, dtype_hint=dtype_hint, name=name)