def _make_columnar(self, x):
  """Ensures non-scalar input has at least two dimensions.

  Example:
    If `x = [1, 2, 3]` then the output is `[[1, 2, 3]]`.

    If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged.

    If `x = 1` then the output is unchanged.

  Args:
    x: `Tensor`.

  Returns:
    columnar_x: `Tensor` with at least two dimensions.
  """
  if tensorshape_util.rank(x.shape) is not None:
    if tensorshape_util.rank(x.shape) == 1:
      x = x[tf.newaxis, :]
    return x
  # Dynamic-rank fallback: insert a singleton dim iff `x` is a vector.
  shape = tf.shape(x)
  maybe_expanded_shape = tf.concat([
      shape[:-1],
      distribution_util.pick_vector(
          tf.equal(tf.rank(x), 1),
          [1], np.array([], dtype=np.int32)),
      shape[-1:],
  ], 0)
  return tf.reshape(x, maybe_expanded_shape)
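# The helpers above and below lean heavily on `distribution_util.pick_vector`,
# a TFP internal. The following is a rough standalone sketch of its behavior
# (not the library's exact implementation): select one of two possibly
# different-length shape vectors by a boolean, without `tf.cond`, via a
# concat-and-slice.
import tensorflow as tf

def pick_vector_sketch(cond, true_vector, false_vector):
  """Returns `true_vector` when `cond` holds, else `false_vector`."""
  if isinstance(cond, bool):  # Static fast path.
    return true_vector if cond else false_vector
  true_vector = tf.convert_to_tensor(true_vector)
  false_vector = tf.convert_to_tensor(false_vector, dtype=true_vector.dtype)
  n = tf.shape(true_vector)[0]
  m = tf.shape(false_vector)[0]
  joined = tf.concat([true_vector, false_vector], axis=0)
  # Slice the winner out of the concatenation; this works even when the two
  # vectors have different lengths, which `tf.where` cannot express.
  return tf.slice(joined, [tf.where(cond, 0, n)], [tf.where(cond, n, m)])

# E.g., prepending a batch dim only when the batch is dynamically scalar:
# pick_vector_sketch(is_scalar_batch, [batch_size], tf.constant([], tf.int32))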
def _sample_n(self, n, seed=None, **distribution_kwargs):
  # When the batch/event shape is overridden, the `n` sample dims are drawn
  # on the right and later rolled to the left by `_maybe_rotate_dims`.
  sample_shape = prefer_static.concat([
      distribution_util.pick_vector(self._needs_rotation, self._empty, [n]),
      self._override_batch_shape,
      self._override_event_shape,
      distribution_util.pick_vector(self._needs_rotation, [n], self._empty),
  ], axis=0)
  x = self.distribution.sample(sample_shape=sample_shape, seed=seed,
                               **distribution_kwargs)
  x = self._maybe_rotate_dims(x)
  # We'll apply the bijector in the `_call_sample_n` function.
  return x
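# A minimal standalone sketch of the axis roll that `_maybe_rotate_dims`
# performs (`rotate_left` and `shift` are illustrative names, not the
# class's API):
import tensorflow as tf

def rotate_left(x, shift):
  # Move the rightmost `shift` axes to the front, e.g. with shift=1:
  # [batch..., n] -> [n, batch...].
  ndims = tf.rank(x)
  perm = tf.concat(
      [tf.range(ndims - shift, ndims), tf.range(0, ndims - shift)], axis=0)
  return tf.transpose(x, perm=perm)

# rotate_left(tf.zeros([2, 3, 5]), 1).shape == [5, 2, 3]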
def _sample_n(self, n, seed=None):
  # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then
  # get ids as a [n]-shaped vector.
  distributions = self.poisson_and_mixture_distributions()
  dist, mixture_dist = distributions
  batch_size = tensorshape_util.num_elements(self.batch_shape)
  if batch_size is None:
    batch_size = tf.reduce_prod(
        self._batch_shape_tensor(distributions=distributions))
  # We need to 'sample extra' from the mixture distribution if it doesn't
  # already specify a probs vector for each batch coordinate.
  # We only support this kind of reduced broadcasting, i.e., there is exactly
  # one probs vector for all batch dims or one for each.
  stream = SeedStream(seed, salt='PoissonLogNormalQuadratureCompound')
  ids = mixture_dist.sample(
      sample_shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              mixture_dist.is_scalar_batch(),
              [batch_size],
              np.int32([]))),
      seed=stream())
  # We need to flatten batch dims in case mixture_dist has its own
  # batch dims.
  ids = tf.reshape(
      ids,
      shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              self.is_scalar_batch(),
              np.int32([]),
              np.int32([-1]))))
  # Stride `quadrature_size` for `batch_size` number of times.
  offset = tf.range(
      start=0,
      limit=batch_size * self._quadrature_size,
      delta=self._quadrature_size,
      dtype=ids.dtype)
  ids = ids + offset
  rate = tf.gather(tf.reshape(dist.rate, shape=[-1]), ids)
  rate = tf.reshape(
      rate,
      shape=concat_vectors(
          [n], self._batch_shape_tensor(distributions=distributions)))
  # Draw a fresh seed from the stream; reusing the raw `seed` here would
  # correlate this draw with the mixture ids drawn above.
  return tf.random.poisson(
      lam=rate, shape=[], dtype=self.dtype, seed=stream())
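# Toy illustration of the stride/offset gather above (shapes invented for
# the example): each batch member owns a contiguous block of
# `quadrature_size` rates in the flattened table, and `offset` shifts the
# sampled ids into the right block.
import tensorflow as tf

batch_size, quadrature_size, n = 3, 4, 5
table = tf.range(batch_size * quadrature_size, dtype=tf.float32)  # flat rates
ids = tf.random.uniform(
    [n, batch_size], maxval=quadrature_size, dtype=tf.int32)
offset = tf.range(
    0, batch_size * quadrature_size, quadrature_size, dtype=ids.dtype)
picked = tf.gather(table, ids + offset)  # [n, batch_size]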
def _batch_shape_tensor(self):
  return distribution_util.pick_vector(
      self._is_batch_override,
      self._override_batch_shape,
      self.distribution.batch_shape_tensor())
def _event_shape_tensor(self):
  return self.bijector.forward_event_shape_tensor(
      distribution_util.pick_vector(
          self._is_event_override,
          self._override_event_shape,
          self.distribution.event_shape_tensor()))
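# Usage sketch of the shape overrides these accessors implement. The
# `batch_shape`/`event_shape` constructor args below existed in the legacy
# TFP releases this code is from (they were later removed); `tfd`/`tfb`
# aliases are assumed.
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

d = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),  # scalar batch/event
    bijector=tfb.Exp(),
    batch_shape=[3],   # `_is_batch_override` -> True
    event_shape=[2])   # `_is_event_override` -> True
# d.batch_shape == [3]; d.event_shape == [2]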
def _sample_n(self, n, seed=None):
  stream = SeedStream(seed, salt="VectorDiffeomixture")
  x = self.distribution.sample(
      sample_shape=concat_vectors(
          [n],
          self.batch_shape_tensor(),
          self.event_shape_tensor()),
      seed=stream())  # shape: [n, B, e]
  x = [aff.forward(x) for aff in self.endpoint_affine]
  # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then
  # get ids as a [n]-shaped vector.
  batch_size = tensorshape_util.num_elements(self.batch_shape)
  if batch_size is None:
    batch_size = tf.reduce_prod(self.batch_shape_tensor())
  mix_batch_size = tensorshape_util.num_elements(
      self.mixture_distribution.batch_shape)
  if mix_batch_size is None:
    mix_batch_size = tf.reduce_prod(
        self.mixture_distribution.batch_shape_tensor())
  ids = self.mixture_distribution.sample(
      sample_shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              self.is_scalar_batch(),
              np.int32([]),
              [batch_size // mix_batch_size])),
      seed=stream())
  # We need to flatten batch dims in case mixture_distribution has its own
  # batch dims.
  ids = tf.reshape(
      ids,
      shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              self.is_scalar_batch(),
              np.int32([]),
              np.int32([-1]))))
  # Stride `components * quadrature_size` for `batch_size` number of times.
  stride = tensorshape_util.num_elements(
      tensorshape_util.with_rank_at_least(self.grid.shape, 2)[-2:])
  if stride is None:
    stride = tf.reduce_prod(tf.shape(self.grid)[-2:])
  offset = tf.range(
      start=0,
      limit=batch_size * stride,
      delta=stride,
      dtype=ids.dtype)
  weight = tf.gather(tf.reshape(self.grid, shape=[-1]), ids + offset)
  # At this point, weight flattened all batch dims into one.
  # We also need to append a singleton to broadcast with event dims.
  if tensorshape_util.is_fully_defined(self.batch_shape):
    new_shape = [-1] + tensorshape_util.as_list(self.batch_shape) + [1]
  else:
    new_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]), axis=0)
  weight = tf.reshape(weight, shape=new_shape)
  if len(x) != 2:
    # This should already have been raised during construction; as a policy
    # we re-raise wherever the bimixture assumption is exploited.
    raise NotImplementedError(
        "Currently only bimixtures are supported; "
        "len(scale)={} is not 2.".format(len(x)))
  # Alternatively: x = weight * x[0] + (1. - weight) * x[1]
  x = weight * (x[0] - x[1]) + x[1]
  return x
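# Sanity check of the algebraic rewrite in the final interpolation above:
# w * x0 + (1 - w) * x1 == w * (x0 - x1) + x1, with one fewer multiply.
import tensorflow as tf

w = tf.constant([[0.25], [0.75]])
x0 = tf.constant([[1., 2.]])
x1 = tf.constant([[3., 4.]])
tf.debugging.assert_near(w * x0 + (1. - w) * x1, w * (x0 - x1) + x1)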
def _forward_log_det_jacobian(self, x):
  # Let Y be a symmetric, positive definite matrix and write:
  #   Y = X X.T
  # where X is lower-triangular.
  #
  # Observe that,
  #   dY[i,j]/dX[a,b]
  #   = d/dX[a,b] { X[i,:] . X[j,:] }
  #   = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] }
  #
  # To compute the Jacobian dY/dX we must represent X,Y as vectors. Since Y
  # is symmetric and X is lower-triangular, we need vectors of dimension:
  #   d = p (p + 1) / 2
  # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e.,
  #   k = { i (i + 1) / 2 + j   i >= j
  #       { undef               i < j
  # and assume zero-based indexes. When k is undef, the element is dropped.
  # Example (entries show k at each (i, j)):
  #             j
  #          0 1 2 3
  #      0 [ 0 . . . ]
  #  i   1 [ 1 2 . . ]
  #      2 [ 3 4 5 . ]
  #      3 [ 6 7 8 9 ]
  # Write vec[.] to indicate transforming a matrix to vector via k(i,j).
  # (With slight abuse: k(i,j)=undef means the element is dropped.)
  #
  # We now show d vec[Y] / d vec[X] is lower triangular. Assuming both are
  # defined, observe that k(i,j) < k(a,b) iff (1) i<a or (2) i=a and j<b.
  # In both cases dvec[Y]/dvec[X]@[k(i,j),k(a,b)] = 0 since:
  # (1) j <= i < a, so i != a and j != a and both terms vanish.
  # (2) i = a and j < b <= a, so j != a (killing the second term) while the
  #     first term is X[j,b] = 0 because j < b and X is lower-triangular.
  #
  # Since the Jacobian is lower-triangular, we need only compute the product
  # of diagonal elements:
  #   d vec[Y] / d vec[X] @[k(i,j), k(i,j)]
  #   = X[j,j] + I[i=j] X[i,j],
  # i.e., 2 X[j,j] on the p diagonal entries (i = j) and X[j,j] on the
  # p - 1 - j entries below them in column j. Column j thus contributes
  # 2 X[j,j]^{p-j}, and we conclude:
  #   |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.
  diag = tf.linalg.diag_part(x)

  # We now give `diag` at least two dims. E.g., if `diag = [1, 2, 3]` then
  # the output is `[[1, 2, 3]]`; if `diag = [[1, 2, 3], [4, 5, 6]]` then the
  # output is unchanged.
  diag = self._make_columnar(diag)

  with tf.control_dependencies(self._assertions(x)):
    # Create a vector equal to: [p, p-1, ..., 2, 1].
    if tf.compat.dimension_value(x.shape[-1]) is None:
      p_int = tf.shape(x)[-1]
      p_float = tf.cast(p_int, dtype=x.dtype)
    else:
      p_int = tf.compat.dimension_value(x.shape[-1])
      p_float = dtype_util.as_numpy_dtype(x.dtype)(p_int)
    exponents = tf.linspace(p_float, 1., p_int)

    sum_weighted_log_diag = tf.squeeze(
        tf.matmul(tf.math.log(diag), exponents[..., tf.newaxis]), axis=-1)
    fldj = p_float * np.log(2.) + sum_weighted_log_diag

    # We finally need to undo the extra dimension added by `_make_columnar`
    # when the input was a single matrix.
    if tensorshape_util.rank(x.shape) is not None:
      if tensorshape_util.rank(x.shape) == 2:
        fldj = tf.squeeze(fldj, axis=-1)
      return fldj

    shape = tf.shape(fldj)
    maybe_squeeze_shape = tf.concat([
        shape[:-1],
        distribution_util.pick_vector(
            tf.equal(tf.rank(x), 2),
            np.array([], dtype=np.int32),
            shape[-1:])], 0)
    return tf.reshape(fldj, maybe_squeeze_shape)
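# Quick numeric sanity check of the closed form above (illustrative, NumPy
# only): finite-difference the p = 2 lower-triangle vectorization and compare
# log|det| against p log 2 + sum_j (p - j) log X[j,j].
import numpy as np

p = 2
X = np.array([[1.3, 0.0],
              [0.4, 2.1]])
tril = np.tril_indices(p)

def forward(v):
  L = np.zeros((p, p))
  L[tril] = v          # unpack vec -> lower-triangular matrix
  return (L @ L.T)[tril]  # Y = X X.T, re-vectorized

v0 = X[tril]
eps = 1e-6
J = np.stack(
    [(forward(v0 + eps * e) - forward(v0 - eps * e)) / (2. * eps)
     for e in np.eye(v0.size)],
    axis=1)  # J[r, c] = d forward_r / d v_c
closed_form = p * np.log(2.) + sum(
    (p - j) * np.log(X[j, j]) for j in range(p))
print(np.log(abs(np.linalg.det(J))), closed_form)  # should agree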
def __init__(self,
             mixture_distribution,
             components_distribution,
             reparameterize=False,
             validate_args=False,
             allow_nan_stats=True,
             name="MixtureSameFamily"):
  """Construct a `MixtureSameFamily` distribution.

  Args:
    mixture_distribution: `tfp.distributions.Categorical`-like instance.
      Manages the probability of selecting components. The number of
      categories must match the rightmost batch dimension of the
      `components_distribution`. Must have either scalar `batch_shape` or
      `batch_shape` matching `components_distribution.batch_shape[:-1]`.
    components_distribution: `tfp.distributions.Distribution`-like instance.
      Right-most batch dimension indexes components.
    reparameterize: Python `bool`, default `False`. Whether to reparameterize
      samples of the distribution using implicit reparameterization gradients
      [(Figurnov et al., 2018)][1]. The gradients for the mixture logits are
      equivalent to the ones described by [(Graves, 2016)][2]. The gradients
      for the components parameters are also computed using implicit
      reparameterization (as opposed to ancestral sampling), meaning that
      all components are updated every step.
      Only works when:
        (1) components_distribution is fully reparameterized;
        (2) components_distribution is either a scalar distribution or
            fully factorized (tfd.Independent applied to a scalar
            distribution);
        (3) batch shape has a known rank.
      Experimental, may be slow and produce infs/NaNs.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: `if not dtype_util.is_integer(mixture_distribution.dtype)`.
    ValueError: if `mixture_distribution` does not have scalar `event_shape`.
    ValueError: if `mixture_distribution.batch_shape` and
      `components_distribution.batch_shape[:-1]` are both fully defined and
      the former is neither scalar nor equal to the latter.
    ValueError: if the number of `mixture_distribution` categories does not
      equal the rightmost batch dimension of `components_distribution`.

  #### References

  [1]: Michael Figurnov, Shakir Mohamed and Andriy Mnih. Implicit
       reparameterization gradients. In _Neural Information Processing
       Systems_, 2018. https://arxiv.org/abs/1805.08498
  [2]: Alex Graves. Stochastic Backpropagation through Mixture Density
       Distributions. _arXiv_, 2016. https://arxiv.org/abs/1607.05690
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    self._mixture_distribution = mixture_distribution
    self._components_distribution = components_distribution
    self._runtime_assertions = []

    s = components_distribution.event_shape_tensor()
    self._event_ndims = tf.compat.dimension_value(s.shape[0])
    if self._event_ndims is None:
      self._event_ndims = tf.size(s)
    self._event_size = tf.reduce_prod(s)

    if not dtype_util.is_integer(mixture_distribution.dtype):
      raise ValueError(
          "`mixture_distribution.dtype` ({}) is not over integers".format(
              dtype_util.name(mixture_distribution.dtype)))

    if (tensorshape_util.rank(mixture_distribution.event_shape) is not None
        and tensorshape_util.rank(mixture_distribution.event_shape) != 0):
      raise ValueError("`mixture_distribution` must have scalar `event_dim`s")
    elif validate_args:
      self._runtime_assertions += [
          assert_util.assert_equal(
              tf.size(mixture_distribution.event_shape_tensor()),
              0,
              message="`mixture_distribution` must have scalar `event_dim`s"),
      ]

    mdbs = mixture_distribution.batch_shape
    cdbs = tensorshape_util.with_rank_at_least(
        components_distribution.batch_shape, 1)[:-1]
    if (tensorshape_util.is_fully_defined(mdbs) and
        tensorshape_util.is_fully_defined(cdbs)):
      if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
        raise ValueError(
            "`mixture_distribution.batch_shape` (`{}`) is not "
            "compatible with `components_distribution.batch_shape` "
            "(`{}`)".format(tensorshape_util.as_list(mdbs),
                            tensorshape_util.as_list(cdbs)))
    elif validate_args:
      mdbs = mixture_distribution.batch_shape_tensor()
      cdbs = components_distribution.batch_shape_tensor()[:-1]
      self._runtime_assertions += [
          assert_util.assert_equal(
              distribution_util.pick_vector(
                  mixture_distribution.is_scalar_batch(), cdbs, mdbs),
              cdbs,
              message=(
                  "`mixture_distribution.batch_shape` is not "
                  "compatible with `components_distribution.batch_shape`"))
      ]

    mixture_dist_param = (mixture_distribution.probs
                          if mixture_distribution.logits is None
                          else mixture_distribution.logits)
    km = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(mixture_dist_param.shape, 1)[-1])
    kc = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(
            components_distribution.batch_shape, 1)[-1])
    if km is not None and kc is not None and km != kc:
      raise ValueError(
          "`mixture_distribution` components ({}) does not "
          "equal `components_distribution.batch_shape[-1]` "
          "({})".format(km, kc))
    elif validate_args:
      km = tf.shape(mixture_dist_param)[-1]
      kc = components_distribution.batch_shape_tensor()[-1]
      self._runtime_assertions += [
          assert_util.assert_equal(
              km, kc,
              message=("`mixture_distribution` components does not equal "
                       "`components_distribution.batch_shape[-1]`")),
      ]
    elif km is None:
      km = tf.shape(mixture_dist_param)[-1]
    self._num_components = km

    self._reparameterize = reparameterize
    if reparameterize:
      # Note: tfd.Independent passes through the reparameterization type,
      # hence we do not need separate logic for Independent.
      if (self._components_distribution.reparameterization_type !=
          reparameterization.FULLY_REPARAMETERIZED):
        raise ValueError("Cannot reparameterize a mixture of "
                         "non-reparameterized components.")
      reparameterization_type = reparameterization.FULLY_REPARAMETERIZED
    else:
      reparameterization_type = reparameterization.NOT_REPARAMETERIZED

    super(MixtureSameFamily, self).__init__(
        dtype=self._components_distribution.dtype,
        reparameterization_type=reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
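# Example usage: a two-component scalar Gaussian mixture, where the number of
# `Categorical` categories matches `components_distribution.batch_shape[-1]`.
import tensorflow_probability as tfp
tfd = tfp.distributions

gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(
        loc=[-1., 1.],      # component means; batch_shape[-1] == 2 components
        scale=[0.1, 0.5]))  # component stddevs
gm.sample(5)      # shape [5]
gm.log_prob(0.)   # scalar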
def __init__(self,
             df,
             loc=None,
             scale_identity_multiplier=None,
             scale_diag=None,
             scale_tril=None,
             scale_perturb_factor=None,
             scale_perturb_diag=None,
             validate_args=False,
             allow_nan_stats=True,
             name="VectorStudentT"):
  """Instantiates the vector Student's t-distributions on `R^k`.

  The `batch_shape` is the broadcast between `df.batch_shape` and
  `Affine.batch_shape` where `Affine` is constructed from `loc` and
  `scale_*` arguments.

  The `event_shape` is the event shape of `Affine.event_shape`.

  Args:
    df: Floating-point `Tensor`. The degrees of freedom of the
      distribution(s). `df` must contain only positive values. Must be
      scalar if `loc`, `scale_*` imply non-scalar batch_shape or must have
      the same `batch_shape` implied by `loc`, `scale_*`.
    loc: Floating-point `Tensor`. If this is set to `None`, no `loc` is
      applied.
    scale_identity_multiplier: floating point rank 0 `Tensor` representing
      a scaling done to the identity matrix. When
      `scale_identity_multiplier = scale_diag = scale_tril = None` then
      `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is
      added to `scale`.
    scale_diag: Floating-point `Tensor` representing the diagonal matrix.
      `scale_diag` has shape [N1, N2, ..., k], which represents a k x k
      diagonal matrix. When `None` no diagonal term is added to `scale`.
    scale_tril: Floating-point `Tensor` representing the lower triangular
      matrix. `scale_tril` has shape [N1, N2, ..., k, k], which represents
      a k x k lower triangular matrix. When `None` no `scale_tril` term is
      added to `scale`. The upper triangular elements above the diagonal
      are ignored.
    scale_perturb_factor: Floating-point `Tensor` representing factor
      matrix with last two dimensions of shape `(k, r)`. When `None`, no
      rank-r update is added to `scale`.
    scale_perturb_diag: Floating-point `Tensor` representing the diagonal
      matrix. `scale_perturb_diag` has shape [N1, N2, ..., r], which
      represents an r x r diagonal matrix. When `None` low rank updates
      will take the form `scale_perturb_factor * scale_perturb_factor.T`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading
      runtime performance. When `False` invalid inputs may silently render
      incorrect outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  parameters = dict(locals())
  args = [
      df, loc, scale_identity_multiplier, scale_diag, scale_tril,
      scale_perturb_factor, scale_perturb_diag
  ]
  with tf.name_scope(name) as name:
    with tf.name_scope("init"):
      dtype = dtype_util.common_dtype(args, tf.float32)
      df = tf.convert_to_tensor(df, name="df", dtype=dtype)
      # The shape of the _VectorStudentT distribution is governed by the
      # relationship between df.batch_shape and affine.batch_shape. In
      # pseudocode the basic procedure is:
      #   if df.batch_shape is scalar:
      #     if affine.batch_shape is not scalar:
      #       # broadcast distribution.sample so it has affine.batch_shape.
      #       self.batch_shape = affine.batch_shape
      #   else:
      #     if affine.batch_shape is scalar:
      #       # let affine broadcasting do its thing.
      #       self.batch_shape = df.batch_shape
      # All of the above magic is actually handled by TransformedDistribution.
      # Here we really only need to collect the affine.batch_shape and decide
      # what we're going to pass in to TransformedDistribution's (override)
      # batch_shape arg.
      affine = affine_bijector.Affine(
          shift=loc,
          scale_identity_multiplier=scale_identity_multiplier,
          scale_diag=scale_diag,
          scale_tril=scale_tril,
          scale_perturb_factor=scale_perturb_factor,
          scale_perturb_diag=scale_perturb_diag,
          validate_args=validate_args,
          dtype=dtype)
      distribution = student_t.StudentT(
          df=df,
          loc=tf.zeros([], dtype=affine.dtype),
          scale=tf.ones([], dtype=affine.dtype))
      batch_shape, override_event_shape = (
          distribution_util.shapes_from_loc_and_scale(
              affine.shift, affine.scale))
      override_batch_shape = distribution_util.pick_vector(
          distribution.is_scalar_batch(),
          batch_shape,
          tf.constant([], dtype=tf.int32))
      super(_VectorStudentT, self).__init__(
          distribution=distribution,
          bijector=affine,
          batch_shape=override_batch_shape,
          event_shape=override_event_shape,
          validate_args=validate_args,
          name=name)
      self._parameters = parameters
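# For reference, a rough equivalent built from later public TFP APIs
# (`Affine` was eventually decomposed into `Shift`/`ScaleMatvecTriL`; the
# names and parameter values below are illustrative, not this module's API):
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

vector_t = tfd.TransformedDistribution(
    distribution=tfd.Sample(
        tfd.StudentT(df=3., loc=0., scale=1.), sample_shape=[2]),
    bijector=tfb.Chain([
        tfb.Shift([1., -1.]),                       # plays the role of `loc`
        tfb.ScaleMatvecTriL(scale_tril=[[2., 0.],   # plays the role of
                                        [0.5, 1.]])]))  # `scale_tril`
# vector_t.sample(4).shape == [4, 2]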