def _forward_log_det_jacobian(self, x, **kwargs):
  x = tf.convert_to_tensor(x, name="x")

  fldj = tf.cast(0., dtype=dtype_util.base_dtype(x.dtype))

  if not self.bijectors:
    return fldj

  event_ndims = self._maybe_get_static_event_ndims(
      self.forward_min_event_ndims)

  if _use_static_shape(x, event_ndims):
    event_shape = x.shape[tensorshape_util.rank(x.shape) - event_ndims:]
  else:
    event_shape = tf.shape(x)[tf.rank(x) - event_ndims:]

  # TODO(b/129973548): Document and simplify.
  for b in reversed(self.bijectors):
    fldj = fldj + b.forward_log_det_jacobian(
        x, event_ndims=event_ndims, **kwargs.get(b.name, {}))
    if _use_static_shape(x, event_ndims):
      event_shape = b.forward_event_shape(event_shape)
      event_ndims = self._maybe_get_static_event_ndims(
          tensorshape_util.rank(event_shape))
    else:
      event_shape = b.forward_event_shape_tensor(event_shape)
      event_shape_ = distribution_util.maybe_get_static_value(event_shape)
      event_ndims = tf.size(event_shape)
      event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
      if event_ndims_ is not None and event_shape_ is not None:
        event_ndims = event_ndims_
        event_shape = event_shape_

    x = b.forward(x, **kwargs.get(b.name, {}))

  return fldj
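# Usage sketch (hypothetical, not part of this module; assumes the public TFP
# bijector API). The chain applies its listed bijectors right-to-left on the
# forward pass, and the log-det-Jacobians accumulate along the way:
#
#   import tensorflow_probability as tfp
#   tfb = tfp.bijectors
#   chain = tfb.Chain([tfb.Exp(), tfb.Softplus()])
#   x = tf.constant([0.5, 1.0])
#   sp_x = tfb.Softplus().forward(x)
#   expected = (tfb.Softplus().forward_log_det_jacobian(x, event_ndims=0) +
#               tfb.Exp().forward_log_det_jacobian(sp_x, event_ndims=0))
#   chain.forward_log_det_jacobian(x, event_ndims=0)  # ==> expected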
def _bessel_ive(v, z, cache=None):
  """Computes I_v(z)*exp(-abs(z)) using a recurrence relation, where z > 0."""
  # TODO(b/67497980): Switch to a more numerically faithful implementation.
  z = tf.convert_to_tensor(z)

  wrap = lambda result: tf.debugging.check_numerics(result,
                                                    'besseli{}'.format(v))

  if float(v) >= 2:
    raise ValueError(
        'Evaluating bessel_i by recurrence becomes imprecise for large v')

  cache = cache or {}
  safe_z = tf.where(z > 0, z, tf.ones_like(z))
  if v in cache:
    return wrap(cache[v])
  if v == 0:
    cache[v] = tf.math.bessel_i0e(z)
  elif v == 1:
    cache[v] = tf.math.bessel_i1e(z)
  elif v == 0.5:
    # sinh(x)*exp(-abs(x)), sinh(x) = (e^x - e^{-x}) / 2
    sinhe = lambda x: (tf.exp(x - tf.abs(x)) - tf.exp(-x - tf.abs(x))) / 2
    cache[v] = (
        np.sqrt(2 / np.pi) * sinhe(z) *
        tf.where(z > 0, tf.math.rsqrt(safe_z), tf.ones_like(safe_z)))
  elif v == -0.5:
    # cosh(x)*exp(-abs(x)), cosh(x) = (e^x + e^{-x}) / 2
    coshe = lambda x: (tf.exp(x - tf.abs(x)) + tf.exp(-x - tf.abs(x))) / 2
    cache[v] = (
        np.sqrt(2 / np.pi) * coshe(z) *
        tf.where(z > 0, tf.math.rsqrt(safe_z), tf.ones_like(safe_z)))
  if v <= 1:
    return wrap(cache[v])
  # Recurrence relation: I_v(z) = I_{v-2}(z) - (2 (v - 1) / z) I_{v-1}(z),
  # which also holds for the exp(-abs(z))-scaled values since the scaling
  # factor is common to every term.
  cache[v] = (_bessel_ive(v - 2, z, cache) -
              (2 * (v - 1)) * _bessel_ive(v - 1, z, cache) / z)
  return wrap(cache[v])
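# Worked example (hypothetical; scipy is assumed only as a reference). Since
# v >= 2 is rejected, the recurrence is only ever exercised for fractional
# orders such as v = 1.5, which bottoms out at the v = +/-0.5 closed forms:
#
#   import scipy.special
#   z = np.array([0.5, 1.0, 2.0])
#   _bessel_ive(1.5, tf.constant(z))  # ~ scipy.special.ive(1.5, z)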
def _inverse_log_det_jacobian(self, y, **kwargs):
  y = tf.convert_to_tensor(y, name="y")

  ildj = tf.cast(0., dtype=dtype_util.base_dtype(y.dtype))

  if not self.bijectors:
    return ildj

  event_ndims = self._maybe_get_static_event_ndims(
      self.inverse_min_event_ndims)

  if _use_static_shape(y, event_ndims):
    event_shape = y.shape[tensorshape_util.rank(y.shape) - event_ndims:]
  else:
    event_shape = tf.shape(y)[tf.rank(y) - event_ndims:]

  # TODO(b/129973548): Document and simplify.
  for b in self.bijectors:
    ildj = ildj + b.inverse_log_det_jacobian(
        y, event_ndims=event_ndims, **kwargs.get(b.name, {}))

    if _use_static_shape(y, event_ndims):
      event_shape = b.inverse_event_shape(event_shape)
      event_ndims = self._maybe_get_static_event_ndims(
          tensorshape_util.rank(event_shape))
    else:
      event_shape = b.inverse_event_shape_tensor(event_shape)
      event_shape_ = distribution_util.maybe_get_static_value(event_shape)
      event_ndims = tf.size(event_shape)
      event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
      if event_ndims_ is not None and event_shape_ is not None:
        event_ndims = event_ndims_
        event_shape = event_shape_

    y = b.inverse(y, **kwargs.get(b.name, {}))
  return ildj
def _sparse_block_diag(sp_a):
  """Returns a block-diagonal rank-2 SparseTensor from a batch of SparseTensors.

  Args:
    sp_a: A rank 3 `SparseTensor` representing a batch of matrices.

  Returns:
    sp_block_diag_a: matrix-shaped, `float` `SparseTensor` with the same dtype
      as `sp_a`, of shape [B * M, B * N] where `sp_a` has shape [B, M, N].
      Each [M, N] batch of `sp_a` is lined up along the diagonal.
  """
  # Construct the matrix [[M, N], [1, 0], [0, 1]] which would map the index
  # (b, i, j) to (Mb + i, Nb + j). This effectively creates a block-diagonal
  # matrix of dense shape [B * M, B * N].
  # Note that this transformation doesn't increase the number of non-zero
  # entries in the SparseTensor.
  sp_a_shape = tf.convert_to_tensor(_get_shape(sp_a, tf.int64))
  ind_mat = tf.concat([[sp_a_shape[-2:]], tf.eye(2, dtype=tf.int64)], axis=0)
  indices = tf.matmul(sp_a.indices, ind_mat)
  dense_shape = sp_a_shape[0] * sp_a_shape[1:]
  return tf.SparseTensor(
      indices=indices, values=sp_a.values, dense_shape=dense_shape)
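# Worked example (hypothetical): for B = 2 batches of 2x2 matrices
# (M = N = 2), the index map sends (b, i, j) -> (2b + i, 2b + j); e.g. entry
# (1, 0, 1) lands at (2, 3) of the 4x4 block-diagonal result:
#
#   sp_a = tf.SparseTensor(indices=[[0, 0, 0], [1, 0, 1]],
#                          values=[1., 2.],
#                          dense_shape=[2, 2, 2])
#   tf.sparse.to_dense(_sparse_block_diag(sp_a))
#   # ==> [[1., 0., 0., 0.],
#   #      [0., 0., 0., 0.],
#   #      [0., 0., 0., 2.],
#   #      [0., 0., 0., 0.]]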
def _entropy(self):
  samples = tf.convert_to_tensor(self.samples)
  num_samples = self._compute_num_samples(samples)
  entropy_shape = self._batch_shape_tensor(samples)

  # Flatten samples for each batch.
  if self._event_ndims == 0:
    samples = tf.reshape(samples, [-1, num_samples])
  else:
    event_size = tf.reduce_prod(self.event_shape_tensor())
    samples = tf.reshape(samples, [-1, num_samples, event_size])

  # Use map_fn to compute entropy for each batch separately.
  def _get_entropy(samples):
    # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
    count = gen_array_ops.unique_with_counts_v2(samples, axis=[0]).count
    prob = tf.cast(count / num_samples, dtype=self.dtype)
    entropy = tf.reduce_sum(-prob * tf.math.log(prob))
    return entropy

  entropy = tf.map_fn(_get_entropy, samples, dtype=self.dtype)
  return tf.reshape(entropy, entropy_shape)
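# Usage sketch (hypothetical; assumes this is the Empirical distribution's
# entropy). The result is the entropy of the discrete distribution over the
# distinct sample values:
#
#   import tensorflow_probability as tfp
#   tfd = tfp.distributions
#   d = tfd.Empirical(samples=[0., 0., 1., 2.])
#   # probs are {0.: 0.5, 1.: 0.25, 2.: 0.25}, so
#   d.entropy()  # ==> -(0.5*log(0.5) + 2 * 0.25*log(0.25)) ~= 1.0397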
def _sample_n(self, n, seed=None):
  temperature = tf.convert_to_tensor(self.temperature)
  logits = self._logits_parameter_no_checks()

  # Uniform variates must be sampled from the open-interval `(0, 1)` rather
  # than `[0, 1)`. To do so, we use
  # `np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny` because it is the
  # smallest, positive, 'normal' number. A 'normal' number is such that the
  # mantissa has an implicit leading 1. Normal, positive numbers x, y have
  # the reasonable property that `x + y >= max(x, y)`. In this case, a
  # subnormal number (i.e., np.nextafter) can cause us to sample 0.
  uniform_shape = tf.concat(
      [[n],
       self._batch_shape_tensor(temperature=temperature, logits=logits),
       self._event_shape_tensor(logits=logits)], 0)
  uniform = tf.random.uniform(
      shape=uniform_shape,
      minval=np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny,
      maxval=1.,
      dtype=self.dtype,
      seed=seed)
  gumbel = -tf.math.log(-tf.math.log(uniform))
  noisy_logits = (gumbel + logits) / temperature[..., tf.newaxis]
  return tf.math.log_softmax(noisy_logits)
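# Note (a sketch of the math above, not new behavior): this is the
# Gumbel-softmax trick. If U ~ Uniform(0, 1) then G = -log(-log(U)) is
# Gumbel-distributed, and softmax((G + logits) / temperature) is a relaxed
# one-hot categorical sample; the method returns its log. For example:
#
#   import tensorflow_probability as tfp
#   tfd = tfp.distributions
#   d = tfd.ExpRelaxedOneHotCategorical(temperature=0.1, logits=[-1., 1.])
#   tf.exp(d.sample(3))  # three near-one-hot probability vectors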
def _slice_params_to_dict(dist, params_event_ndims, slices):
  """Computes the override dictionary of sliced parameters.

  Args:
    dist: The tfd.Distribution being batch-sliced.
    params_event_ndims: Per-event parameter ranks, a `str->int` `dict`.
    slices: Slices as received by __getitem__.

  Returns:
    overrides: `str->Tensor` `dict` of batch-sliced parameter overrides.
  """
  override_dict = {}
  for param_name, param_event_ndims in six.iteritems(params_event_ndims):
    # Verify that either None or a legit value is in the parameters dict.
    if param_name not in dist.parameters:
      raise ValueError('Distribution {} is missing advertised '
                       'parameter {}'.format(dist, param_name))
    param = dist.parameters[param_name]
    if param is None:
      # Some distributions have multiple possible parameterizations; this
      # param was not provided.
      continue
    dtype = None
    if hasattr(dist, param_name):
      attr = getattr(dist, param_name)
      dtype = getattr(attr, 'dtype', None)
    if dtype is None:
      dtype = dist.dtype
      warnings.warn('Unable to find property getter for parameter Tensor {} '
                    'on {}, falling back to Distribution.dtype {}'.format(
                        param_name, dist, dtype))
    param = tf.convert_to_tensor(value=param, dtype=dtype)
    override_dict[param_name] = _slice_single_param(
        param, param_event_ndims, slices, dist.batch_shape_tensor())
  return override_dict
def _validate_dimension(self, x):
  x = tf.convert_to_tensor(x, name='x')
  if tensorshape_util.is_fully_defined(x.shape[-2:]):
    if (tensorshape_util.dims(x.shape)[-2] ==
        tensorshape_util.dims(x.shape)[-1] ==
        self.dimension):
      pass
    else:
      raise ValueError(
          'Input dimension mismatch: expected [..., {}, {}], got {}'.format(
              self.dimension, self.dimension, tensorshape_util.dims(x.shape)))
  elif self.validate_args:
    msg = 'Input dimension mismatch: expected [..., {}, {}], got {}'.format(
        self.dimension, self.dimension, tf.shape(x))
    with tf.control_dependencies([
        assert_util.assert_equal(
            tf.shape(x)[-2], self.dimension, message=msg),
        assert_util.assert_equal(
            tf.shape(x)[-1], self.dimension, message=msg)
    ]):
      x = tf.identity(x)
  return x
def quadrature_scheme_softmaxnormal_gauss_hermite(
    normal_loc, normal_scale, quadrature_size, validate_args=False,
    name=None):
  """Use Gauss-Hermite quadrature to form quadrature on `K - 1` simplex.

  A `SoftmaxNormal` random variable `Y` may be generated via

  ```
  Y = SoftmaxCentered(X),
  X = Normal(normal_loc, normal_scale)
  ```

  Note: for a given `quadrature_size`, this method is generally less accurate
  than `quadrature_scheme_softmaxnormal_quantiles`.

  Args:
    normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
      The location parameter of the Normal used to construct the
      SoftmaxNormal.
    normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`.
      The scale parameter of the Normal used to construct the SoftmaxNormal.
    quadrature_size: Python `int` scalar representing the number of quadrature
      points.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    grid: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
      convex combination of affine parameters for `K` components.
      `grid[..., :, n]` is the `n`-th grid point, living in the `K - 1`
      simplex.
    probs: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
      probabilities associated with each grid point.
  """
  with tf.name_scope(name or "quadrature_scheme_softmaxnormal_gauss_hermite"):
    normal_loc = tf.convert_to_tensor(normal_loc, name="normal_loc")
    npdt = dtype_util.as_numpy_dtype(normal_loc.dtype)
    normal_scale = tf.convert_to_tensor(
        normal_scale, dtype=npdt, name="normal_scale")

    normal_scale = maybe_check_quadrature_param(
        normal_scale, "normal_scale", validate_args)

    grid, probs = np.polynomial.hermite.hermgauss(deg=quadrature_size)
    grid = grid.astype(npdt)
    probs = probs.astype(npdt)
    probs /= np.linalg.norm(probs, ord=1, keepdims=True)
    probs = tf.convert_to_tensor(probs, name="probs", dtype=npdt)

    grid = softmax(
        -distribution_util.pad(
            (normal_loc[..., tf.newaxis] +
             np.sqrt(2.) * normal_scale[..., tf.newaxis] * grid),
            axis=-2,
            front=True),
        axis=-2)  # shape: [B, components, deg]

    return grid, probs
def __init__(self,
             mix_loc,
             temperature,
             distribution,
             loc=None,
             scale=None,
             quadrature_size=8,
             quadrature_fn=quadrature_scheme_softmaxnormal_quantiles,
             validate_args=False,
             allow_nan_stats=True,
             name="VectorDiffeomixture"):
  """Constructs the VectorDiffeomixture on `R^d`.

  The vector diffeomixture (VDM) approximates the compound distribution

  ```none
  p(x) = int p(x | z) p(z) dz,
  where z is in the K-simplex, and
  p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k])
  ```

  Args:
    mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`. In terms
      of samples, larger `mix_loc[..., k]` ==> `Z` is more likely to put more
      weight on its `kth` component.
    temperature: `float`-like `Tensor`. Broadcastable with `mix_loc`. In
      terms of samples, smaller `temperature` means one component is more
      likely to dominate. I.e., smaller `temperature` makes the VDM look more
      like a standard mixture of `K` components.
    distribution: `tfp.distributions.Distribution`-like instance. Distribution
      from which `d` iid samples are used as input to the selected affine
      transformation. Must be a scalar-batch, scalar-event distribution.
      Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED`
      or it is a function of non-trainable parameters. WARNING: If you
      backprop through a VectorDiffeomixture sample and the `distribution` is
      not `FULLY_REPARAMETERIZED` yet is a function of trainable variables,
      then the gradient will be incorrect!
    loc: Length-`K` list of `float`-type `Tensor`s. The `k`-th element
      represents the `shift` used for the `k`-th affine transformation. If
      the `k`-th item is `None`, `loc` is implicitly `0`. When specified,
      must have shape `[B1, ..., Bb, d]` where `b >= 0` and `d` is the event
      size.
    scale: Length-`K` list of `LinearOperator`s. Each should be
      positive-definite and operate on a `d`-dimensional vector space. The
      `k`-th element represents the `scale` used for the `k`-th affine
      transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`,
      `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices.
    quadrature_size: Python `int` scalar representing number of quadrature
      points. Larger `quadrature_size` means `q_N(x)` better approximates
      `p(x)`.
    quadrature_fn: Python callable taking `normal_loc`, `normal_scale`,
      `quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
      representing the SoftmaxNormal grid and corresponding normalized weight.
      Default value: `quadrature_scheme_softmaxnormal_quantiles`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if `not scale or len(scale) < 2`.
    ValueError: if `len(loc) != len(scale)`.
    ValueError: if `quadrature_grid_and_probs is not None` and
      `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`.
    ValueError: if `validate_args` and any not scale.is_positive_definite.
    TypeError: if any scale.dtype != scale[0].dtype.
    TypeError: if any loc.dtype != scale[0].dtype.
    NotImplementedError: if `len(scale) != 2`.
    ValueError: if `not distribution.is_scalar_batch`.
    ValueError: if `not distribution.is_scalar_event`.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    if not scale or len(scale) < 2:
      raise ValueError("Must specify list (or list-like object) of scale "
                       "LinearOperators, one for each component with "
                       "num_component >= 2.")

    if loc is None:
      loc = [None] * len(scale)

    if len(loc) != len(scale):
      raise ValueError("loc/scale must be same-length lists "
                       "(or same-length list-like objects).")

    dtype = dtype_util.base_dtype(scale[0].dtype)

    loc = [
        tf.convert_to_tensor(loc_, dtype=dtype, name="loc{}".format(k))
        if loc_ is not None else None for k, loc_ in enumerate(loc)
    ]

    for k, scale_ in enumerate(scale):
      if validate_args and not scale_.is_positive_definite:
        raise ValueError(
            "scale[{}].is_positive_definite = {} != True".format(
                k, scale_.is_positive_definite))
      if dtype_util.base_dtype(scale_.dtype) != dtype:
        raise TypeError(
            "dtype mismatch; scale[{}].base_dtype=\"{}\" != \"{}\"".format(
                k, dtype_util.name(scale_.dtype), dtype_util.name(dtype)))

    self._endpoint_affine = [
        affine_linear_operator_bijector.AffineLinearOperator(  # pylint: disable=g-complex-comprehension
            shift=loc_,
            scale=scale_,
            validate_args=validate_args,
            name="endpoint_affine_{}".format(k))
        for k, (loc_, scale_) in enumerate(zip(loc, scale))
    ]

    # TODO(jvdillon): Remove once we support k-mixtures.
    # We make this assertion here because otherwise `grid` would need to be a
    # vector not a scalar.
    if len(scale) != 2:
      raise NotImplementedError("Currently only bimixtures are supported; "
                                "len(scale)={} is not 2.".format(len(scale)))

    mix_loc = tf.convert_to_tensor(mix_loc, dtype=dtype, name="mix_loc")
    temperature = tf.convert_to_tensor(
        temperature, dtype=dtype, name="temperature")
    self._grid, probs = tuple(
        quadrature_fn(mix_loc / temperature, 1. / temperature,
                      quadrature_size, validate_args))

    # Note: by creating the logits as `log(prob)` we ensure that
    # `self.mixture_distribution.logits` is equivalent to
    # `math_ops.log(self.mixture_distribution.probs)`.
    self._mixture_distribution = categorical.Categorical(
        logits=tf.math.log(probs),
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats)

    asserts = distribution_util.maybe_check_scalar_distribution(
        distribution, dtype, validate_args)
    if asserts:
      self._grid = distribution_util.with_dependencies(asserts, self._grid)
    self._distribution = distribution

    self._interpolated_affine = [
        affine_linear_operator_bijector.AffineLinearOperator(  # pylint: disable=g-complex-comprehension
            shift=loc_,
            scale=scale_,
            validate_args=validate_args,
            name="interpolated_affine_{}".format(k))
        for k, (loc_, scale_) in enumerate(
            zip(interpolate_loc(self._grid, loc),
                interpolate_scale(self._grid, scale)))
    ]

    [
        self._batch_shape_,
        self._batch_shape_tensor_,
        self._event_shape_,
        self._event_shape_tensor_,
    ] = determine_batch_event_shapes(self._grid, self._endpoint_affine)

    super(VectorDiffeomixture, self).__init__(
        dtype=dtype,
        # We hard-code `FULLY_REPARAMETERIZED` because when
        # `validate_args=True` we verify that indeed
        # `distribution.reparameterization_type == FULLY_REPARAMETERIZED`. A
        # distribution which is a function of only non-trainable parameters
        # also implies we can use `FULLY_REPARAMETERIZED`. However, we cannot
        # easily test for that possibility thus we use `validate_args=False`
        # as a "back-door" to allow users a way to use non
        # `FULLY_REPARAMETERIZED` distribution. In such cases IT IS THE USERS
        # RESPONSIBILITY to verify that the base distribution is a function of
        # non-trainable parameters.
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
def quadrature_scheme_softmaxnormal_quantiles(
    normal_loc, normal_scale, quadrature_size, validate_args=False,
    name=None):
  """Use SoftmaxNormal quantiles to form quadrature on `K - 1` simplex.

  A `SoftmaxNormal` random variable `Y` may be generated via

  ```
  Y = SoftmaxCentered(X),
  X = Normal(normal_loc, normal_scale)
  ```

  Args:
    normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
      The location parameter of the Normal used to construct the
      SoftmaxNormal.
    normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`.
      The scale parameter of the Normal used to construct the SoftmaxNormal.
    quadrature_size: Python `int` scalar representing the number of quadrature
      points.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    grid: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
      convex combination of affine parameters for `K` components.
      `grid[..., :, n]` is the `n`-th grid point, living in the `K - 1`
      simplex.
    probs: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
      probabilities associated with each grid point.
  """
  with tf.name_scope(name or "softmax_normal_grid_and_probs"):
    normal_loc = tf.convert_to_tensor(normal_loc, name="normal_loc")
    dt = dtype_util.base_dtype(normal_loc.dtype)
    normal_scale = tf.convert_to_tensor(
        normal_scale, dtype=dt, name="normal_scale")

    normal_scale = maybe_check_quadrature_param(
        normal_scale, "normal_scale", validate_args)

    dist = normal.Normal(loc=normal_loc, scale=normal_scale)

    def _get_batch_ndims():
      """Helper to get rank(dist.batch_shape), statically if possible."""
      ndims = tensorshape_util.rank(dist.batch_shape)
      if ndims is None:
        ndims = tf.shape(dist.batch_shape_tensor())[0]
      return ndims
    batch_ndims = _get_batch_ndims()

    def _get_final_shape(qs):
      """Helper to build `TensorShape`."""
      bs = tensorshape_util.with_rank_at_least(dist.batch_shape, 1)
      num_components = tf.compat.dimension_value(bs[-1])
      if num_components is not None:
        num_components += 1
      tail = tf.TensorShape([num_components, qs])
      return bs[:-1].concatenate(tail)

    def _compute_quantiles():
      """Helper to build quantiles."""
      # Omit {0, 1} since they might lead to Inf/NaN.
      zero = tf.zeros([], dtype=dist.dtype)
      edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1]
      # Expand edges so it broadcasts across batch dims.
      edges = tf.reshape(
          edges,
          shape=tf.concat(
              [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0))
      quantiles = dist.quantile(edges)
      quantiles = softmax_centered_bijector.SoftmaxCentered().forward(
          quantiles)
      # Cyclically permute left by one.
      perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
      quantiles = tf.transpose(a=quantiles, perm=perm)
      tensorshape_util.set_shape(
          quantiles, _get_final_shape(quadrature_size + 1))
      return quantiles
    quantiles = _compute_quantiles()

    # Compute grid as quantile midpoints.
    grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2.
    # Set shape hints.
    tensorshape_util.set_shape(grid, _get_final_shape(quadrature_size))

    # By construction probs is constant, i.e., `1 / quadrature_size`. This is
    # important, because non-constant probs leads to non-reparameterizable
    # samples.
    probs = tf.fill(
        dims=[quadrature_size],
        value=1. / tf.cast(quadrature_size, dist.dtype))

    return grid, probs
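# Usage sketch (hypothetical): with scalar batch and K - 1 = 1 (i.e., a
# 2-component simplex), a size-3 scheme yields 3 grid points on the 1-simplex
# with uniform weights:
#
#   grid, probs = quadrature_scheme_softmaxnormal_quantiles(
#       normal_loc=[0.], normal_scale=[1.], quadrature_size=3)
#   # grid.shape ==> [2, 3]; each grid[..., n] sums to 1.
#   # probs      ==> [1/3, 1/3, 1/3]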
def _log_survival_function(self, value):
  # For the exponential density rate * exp(-rate * x), the log survival
  # function is log S(x) = -rate * x = log_prob(x) - log(rate).
  rate = tf.convert_to_tensor(self._rate)
  return self._log_prob(value, rate=rate) - tf.math.log(rate)
def _logits_parameter_no_checks(self):
  if self._logits is None:
    probs = tf.convert_to_tensor(self._probs)
    return tf.math.log(probs) - tf.math.log1p(-probs)
  return tf.identity(self._logits)
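# Note (a sketch of the identity above): log(p) - log1p(-p) is a numerically
# careful way to write the log-odds log(p / (1 - p)); e.g. for p = 0.75 the
# logit is log(3) ~= 1.0986.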
def _sample_n(self, num_samples, seed=None, name=None):
  """Returns a Tensor of samples from an LKJ distribution.

  Args:
    num_samples: Python `int`. The number of samples to draw.
    seed: Python integer seed for RNG.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    samples: A Tensor of correlation matrices with shape `[n, B, D, D]`,
      where `B` is the shape of the `concentration` parameter, and `D`
      is the `dimension`.

  Raises:
    ValueError: If `dimension` is negative.
  """
  if self.dimension < 0:
    raise ValueError(
        'Cannot sample negative-dimension correlation matrices.')
  # Notation below: B is the batch shape, i.e., tf.shape(concentration)
  seed = SeedStream(seed, 'sample_lkj')
  with tf.name_scope(name or 'sample_lkj'):
    concentration = tf.convert_to_tensor(self.concentration)
    if not dtype_util.is_floating(concentration.dtype):
      raise TypeError(
          'The concentration argument should have floating type, not '
          '{}'.format(dtype_util.name(concentration.dtype)))

    concentration = _replicate(num_samples, concentration)
    concentration_shape = tf.shape(concentration)
    if self.dimension <= 1:
      # For any dimension <= 1, there is only one possible correlation matrix.
      shape = tf.concat(
          [concentration_shape, [self.dimension, self.dimension]], axis=0)
      return tf.ones(shape=shape, dtype=concentration.dtype)
    beta_conc = concentration + (self.dimension - 2.) / 2.
    beta_dist = beta.Beta(concentration1=beta_conc, concentration0=beta_conc)

    # Note that the sampler below deviates from [1], by doing the sampling in
    # cholesky space. This does not change the fundamental logic of the
    # sampler, but does speed up the sampling.

    # This is the correlation coefficient between the first two dimensions.
    # This is also `r` in reference [1].
    corr12 = 2. * beta_dist.sample(seed=seed()) - 1.

    # Below we construct the Cholesky of the initial 2x2 correlation matrix,
    # which is of the form:
    # [[1, 0], [r, sqrt(1 - r**2)]], where r is the correlation between the
    # first two dimensions.
    # This is the top-left corner of the cholesky of the final sample.
    first_row = tf.concat([
        tf.ones_like(corr12)[..., tf.newaxis],
        tf.zeros_like(corr12)[..., tf.newaxis]
    ], axis=-1)
    second_row = tf.concat([
        corr12[..., tf.newaxis],
        tf.sqrt(1 - corr12**2)[..., tf.newaxis]
    ], axis=-1)

    chol_result = tf.concat([
        first_row[..., tf.newaxis, :],
        second_row[..., tf.newaxis, :]
    ], axis=-2)

    for n in range(2, self.dimension):
      # Loop invariant: on entry, result has shape B + [n, n]
      beta_conc = beta_conc - 0.5
      # norm is y in reference [1].
      norm = beta.Beta(
          concentration1=n / 2.,
          concentration0=beta_conc).sample(seed=seed())
      # distance shape: B + [1] for broadcast
      distance = tf.sqrt(norm)[..., tf.newaxis]
      # direction is u in reference [1].
      # direction shape: B + [n]
      direction = _uniform_unit_norm(
          n, concentration_shape, concentration.dtype, seed)
      # raw_correlation is w in reference [1].
      raw_correlation = distance * direction  # shape: B + [n]

      # This is the next row in the cholesky of the result,
      # which differs from the construction in reference [1].
      # In the reference, the new row `z` = chol_result @ raw_correlation^T
      # = C @ raw_correlation^T (where as short hand we use C = chol_result).
      # We prove that the below equation is the right row to add to the
      # cholesky, by showing equality with reference [1].
      # Let S be the sample constructed so far, and let `z` be as in
      # reference [1]. Then at this iteration, the new sample S' will be
      # [[S z^T]
      #  [z 1]]
      # In our case we have the cholesky decomposition factor C, so
      # we want our new row x (same size as z) to satisfy:
      #  [[S z^T]    [[C 0]  [[C^T x^T]    [[CC^T  Cx^T]
      #   [z 1]]   =  [x k]]  [0   k]]   =  [xC^T  xx^T + k**2]]
      # Since C @ raw_correlation^T = z = C @ x^T, and C is invertible,
      # we have that x = raw_correlation. Also 1 = xx^T + k**2, so k
      # = sqrt(1 - xx^T) = sqrt(1 - |raw_correlation|**2) = sqrt(1 -
      # distance**2).
      new_row = tf.concat(
          [raw_correlation, tf.sqrt(1. - norm[..., tf.newaxis])], axis=-1)

      # Finally add this new row, by growing the cholesky of the result.
      chol_result = tf.concat([
          chol_result,
          tf.zeros_like(chol_result[..., 0][..., tf.newaxis])
      ], axis=-1)

      chol_result = tf.concat(
          [chol_result, new_row[..., tf.newaxis, :]], axis=-2)

    if self.input_output_cholesky:
      return chol_result

    result = tf.matmul(chol_result, chol_result, transpose_b=True)
    # The diagonal for a correlation matrix should always be ones. Due to
    # numerical instability the matmul might not achieve that, so manually
    # set these to ones.
    result = tf.linalg.set_diag(
        result, tf.ones(shape=tf.shape(result)[:-1], dtype=result.dtype))
    # This sampling algorithm can produce near-PSD matrices on which standard
    # algorithms such as `tf.cholesky` or `tf.linalg.self_adjoint_eigvals`
    # fail. Specifically, as documented in b/116828694, around 2% of trials
    # of 900,000 5x5 matrices (distributed according to 9 different
    # concentration parameter values) contained at least one matrix on which
    # the Cholesky decomposition failed.
    return result
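# Usage sketch (hypothetical; assumes the public TFP API): drawing
# correlation matrices from an LKJ distribution:
#
#   import tensorflow_probability as tfp
#   tfd = tfp.distributions
#   lkj = tfd.LKJ(dimension=3, concentration=1.5)
#   samples = lkj.sample(2, seed=42)  # shape [2, 3, 3]
#   # Each sample is a correlation matrix: unit diagonal, symmetric, PSD.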
def _param_shapes(sample_shape):
  return dict(
      zip(('concentration', 'rate'),
          ([tf.convert_to_tensor(sample_shape, dtype=tf.int32)] * 2)))
def _stddev(self):
  samples = tf.convert_to_tensor(self._samples)
  axis = self._samples_axis
  r = samples - tf.expand_dims(self._mean(samples), axis=axis)
  var = tf.reduce_mean(tf.square(r), axis=axis)
  return tf.sqrt(var)
def _mean(self, samples=None):
  if samples is None:
    samples = tf.convert_to_tensor(self._samples)
  return tf.reduce_mean(samples, axis=self._samples_axis)
def _event_shape_tensor(self, samples=None):
  if samples is None:
    samples = tf.convert_to_tensor(self.samples)
  return tf.shape(samples)[self._samples_axis + 1:]
def _batch_shape_tensor(self, samples=None):
  if samples is None:
    samples = tf.convert_to_tensor(self.samples)
  return tf.shape(samples)[:self._samples_axis]
def _compute_num_samples(self, samples):
  samples_shape = distribution_util.prefer_static_shape(samples)
  return tf.convert_to_tensor(
      samples_shape[self._samples_axis],
      dtype_hint=tf.int32,
      name='num_samples')
def _param_shapes(sample_shape):
  return dict(
      zip(('df', 'loc', 'scale'),
          ([tf.convert_to_tensor(sample_shape, dtype=tf.int32)] * 3)))
def _param_shapes(sample_shape):
  return dict(
      zip(('low', 'high'),
          ([tf.convert_to_tensor(sample_shape, dtype=tf.int32)] * 2)))
def _mode(self):
  loc = tf.convert_to_tensor(self.loc)
  return tf.broadcast_to(loc, self._batch_shape_tensor(loc=loc))
def _convert_to_tensor(x, name, dtype=None):
  return None if x is None else tf.convert_to_tensor(x, name=name, dtype=dtype)
def _entropy(self):
  concentration = tf.convert_to_tensor(self.concentration)
  return (concentration - tf.math.log(self.rate) +
          tf.math.lgamma(concentration) +
          ((1. - concentration) * tf.math.digamma(concentration)))
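# Note (a sketch of the formula above, assuming a Gamma distribution with
# concentration alpha and rate beta):
#   H = alpha - log(beta) + lgamma(alpha) + (1 - alpha) * digamma(alpha)
# e.g. alpha = 1, beta = 1 reduces to the Exponential(1) entropy, H = 1.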
def _variance(self):
  p = self._probs_parameter_no_checks()
  k = tf.convert_to_tensor(self.total_count)
  return k[..., tf.newaxis] * p * (1. - p)
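# Note (a sketch, assuming a Multinomial-style parameterization where `p`
# carries a trailing event dimension): the per-category variance of counts
# is n * p_i * (1 - p_i); e.g. total_count = 10, p = [0.3, 0.7] gives
# variances [2.1, 2.1].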
def __init__(self,
             distribution,
             low=None,
             high=None,
             validate_args=False,
             name="QuantizedDistribution"):
  """Construct a Quantized Distribution representing `Y = ceiling(X)`.

  Some properties are inherited from the distribution defining `X`. Example:
  `allow_nan_stats` is determined for this `QuantizedDistribution` by reading
  the `distribution`.

  Args:
    distribution: The base distribution class to transform. Typically an
      instance of `Distribution`.
    low: `Tensor` with same `dtype` as this distribution and shape able to be
      added to samples. Should be a whole number. Default `None`. If
      provided, base distribution's `prob` should be defined at `low`.
    high: `Tensor` with same `dtype` as this distribution and shape able to
      be added to samples. Should be a whole number. Default `None`. If
      provided, base distribution's `prob` should be defined at `high - 1`.
      `high` must be strictly greater than `low`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    TypeError: If `distribution` is not a subclass of `Distribution` or is
      not continuous.
    NotImplementedError: If the base distribution does not implement `cdf`.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    self._dist = distribution

    if low is not None:
      low = tf.convert_to_tensor(low, name="low", dtype=distribution.dtype)
    if high is not None:
      high = tf.convert_to_tensor(high, name="high", dtype=distribution.dtype)
    dtype_util.assert_same_float_dtype(
        tensors=[self.distribution, low, high])

    checks = []
    if validate_args and low is not None and high is not None:
      message = "low must be strictly less than high."
      checks.append(assert_util.assert_less(low, high, message=message))
    self._validate_args = validate_args  # self._check_integer uses this.
    with tf.control_dependencies(checks if validate_args else []):
      if low is not None:
        self._low = self._check_integer(low)
      else:
        self._low = None
      if high is not None:
        self._high = self._check_integer(high)
      else:
        self._high = None

    super(QuantizedDistribution, self).__init__(
        dtype=self._dist.dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=self._dist.allow_nan_stats,
        parameters=parameters,
        name=name)
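# Usage sketch (hypothetical; assumes the public TFP API): quantizing a
# Normal to integer support, clipped to [-3, 3]:
#
#   import tensorflow_probability as tfp
#   tfd = tfp.distributions
#   q = tfd.QuantizedDistribution(
#       distribution=tfd.Normal(loc=0., scale=1.), low=-3., high=3.)
#   q.prob(0.)  # P(ceiling(X) == 0) = Phi(0) - Phi(-1) ~= 0.3413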
def _prob(self, x):
  loc = tf.convert_to_tensor(self.loc)
  # Enforces dtype of probability to be float, when self.dtype is not.
  prob_dtype = self.dtype if self.dtype.is_floating else tf.float32
  return tf.cast(tf.abs(x - loc) <= self._slack(loc), dtype=prob_dtype)
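# Note (a sketch, assuming a Deterministic-style distribution with an
# atol/rtol "slack"): prob(x) is the indicator 1{|x - loc| <= slack}; e.g.
# with loc = 0 and zero slack, prob(0.) == 1. and prob(0.1) == 0.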
def _parameter_control_dependencies(self, is_init):
  assertions = []

  logits = self._logits
  probs = self._probs
  param, name = (probs, 'probs') if logits is None else (logits, 'logits')

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    if not dtype_util.is_floating(param.dtype):
      raise TypeError(
          'Argument `{}` must have floating type.'.format(name))

    msg = 'Argument `{}` must have rank at least 1.'.format(name)
    shape_static = tensorshape_util.dims(param.shape)
    if shape_static is not None:
      if len(shape_static) < 1:
        raise ValueError(msg)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_rank_at_least(param, 1, message=msg))
      with tf.control_dependencies(assertions):
        param = tf.identity(param)

    msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
    msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
        name, tf.int32.max)
    event_size = shape_static[-1] if shape_static is not None else None
    if event_size is not None:
      if event_size < 1:
        raise ValueError(msg1)
      if event_size > tf.int32.max:
        raise ValueError(msg2)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(assert_util.assert_greater_equal(
          tf.shape(param)[-1], 1, message=msg1))
      # NOTE: For now, we leave out a runtime assertion that
      # `tf.shape(param)[-1] <= tf.int32.max`. An earlier `tf.shape` call
      # will fail before we get to this point.

  if not self.validate_args:
    assert not assertions  # Should never happen.
    return []

  if probs is not None:
    probs = param  # reuse tensor conversion from above
    if is_init != tensor_util.is_ref(probs):
      probs = tf.convert_to_tensor(probs)
    one = tf.ones([], dtype=probs.dtype)
    assertions.extend([
        assert_util.assert_non_negative(probs),
        assert_util.assert_less_equal(probs, one),
        assert_util.assert_near(
            tf.reduce_sum(probs, axis=-1), one,
            message='Argument `probs` must sum to 1.'),
    ])

  return assertions
def _param_shapes(sample_shape):
  return {"rate": tf.convert_to_tensor(sample_shape, dtype=tf.int32)}