def _sample_n(self, n, seed=None):
  """Samples `n` draws from the base distribution, before the bijector.

  The shape requested from `self.distribution` is the override batch shape
  followed by the override event shape, with `[n]` attached at the front, or
  at the back when `self._needs_rotation` is true. The draw is then passed
  through `self._maybe_rotate_dims`.

  Args:
    n: Number of samples to draw.
    seed: Seed forwarded unchanged to `self.distribution.sample`.

  Returns:
    The base sample after `_maybe_rotate_dims`; the bijector is not applied.
  """
  rotate = self._needs_rotation
  # Exactly one of these two holds `[n]`; the other is `self._empty`.
  n_in_front = distribution_util.pick_vector(rotate, self._empty, [n])
  n_at_back = distribution_util.pick_vector(rotate, [n], self._empty)
  base_sample_shape = _concat_vectors(
      n_in_front,
      self._override_batch_shape,
      self._override_event_shape,
      n_at_back)
  base_sample = self.distribution.sample(
      sample_shape=base_sample_shape, seed=seed)
  # The bijector is applied afterwards, in `_call_sample_n`.
  return self._maybe_rotate_dims(base_sample)
def _sample_n(self, n, seed=None):
  """Draws `n` base-distribution samples; the bijector is applied elsewhere.

  Args:
    n: Number of samples to draw.
    seed: Seed passed through to `self.distribution.sample`.

  Returns:
    The base draw after `self._maybe_rotate_dims`.
  """
  needs_rotation = self._needs_rotation
  # `[n]` leads the requested shape in the ordinary case and trails it when
  # rotation is needed; the unused slot contributes `self._empty`.
  prefix = distribution_util.pick_vector(needs_rotation, self._empty, [n])
  suffix = distribution_util.pick_vector(needs_rotation, [n], self._empty)
  requested_shape = _concat_vectors(
      prefix, self._override_batch_shape, self._override_event_shape, suffix)
  draw = self.distribution.sample(sample_shape=requested_shape, seed=seed)
  draw = self._maybe_rotate_dims(draw)
  # We'll apply the bijector in the `_call_sample_n` function.
  return draw
def _sample_n(self, n, seed=None, **distribution_kwargs):
  """Draws `n` base-distribution samples; the bijector is applied elsewhere.

  Args:
    n: Number of samples to draw.
    seed: Seed passed through to `self.distribution.sample`.
    **distribution_kwargs: Extra keyword arguments forwarded to
      `self.distribution.sample`.

  Returns:
    The base draw after `self._maybe_rotate_dims`.
  """
  needs_rotation = self._needs_rotation
  # `[n]` goes in front unless rotation is needed, in which case it goes last.
  front = distribution_util.pick_vector(needs_rotation, self._empty, [n])
  back = distribution_util.pick_vector(needs_rotation, [n], self._empty)
  shape_parts = [
      front,
      self._override_batch_shape,
      self._override_event_shape,
      back,
  ]
  requested_shape = prefer_static.concat(shape_parts, axis=0)
  draw = self.distribution.sample(
      sample_shape=requested_shape, seed=seed, **distribution_kwargs)
  # We'll apply the bijector in the `_call_sample_n` function.
  return self._maybe_rotate_dims(draw)
def testCorrectlyPicksVector(self):
  """`pick_vector` yields its first vector on True, its second on False."""
  first = np.arange(10, 12)
  second = np.arange(15, 18)

  # Conditions produced by an op: the result is run through `self.evaluate`.
  picked_when_true = distribution_util.pick_vector(
      tf.less(0, 5), first, second)
  self.assertAllEqual(first, self.evaluate(picked_when_true))
  picked_when_false = distribution_util.pick_vector(
      tf.less(5, 0), first, second)
  self.assertAllEqual(second, self.evaluate(picked_when_false))

  # Constant conditions: the result is compared directly.  No eval.
  self.assertAllEqual(
      first, distribution_util.pick_vector(tf.constant(True), first, second))
  self.assertAllEqual(
      second, distribution_util.pick_vector(tf.constant(False), first, second))
def _make_columnar(self, x):
  """Adds a leading singleton dimension to rank-1 input only.

  Example:
    If `x = [1, 2, 3]` then the output is `[[1, 2, 3]]`.
    If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged.
    If `x = 1` then the output is unchanged.

  Args:
    x: `Tensor`.

  Returns:
    columnar_x: `x` reshaped from `[k]` to `[1, k]` when its rank is 1,
      otherwise `x` with its shape unchanged.
  """
  static_rank = x.shape.ndims
  if static_rank is None:
    # Rank is only known at run time: splice `[1]` (rank 1) or `[]` (any other
    # rank) in front of the last dimension, then reshape.
    dynamic_shape = tf.shape(x)
    leading = dynamic_shape[:-1]
    spliced = distribution_util.pick_vector(
        tf.equal(tf.rank(x), 1), [1], np.array([], dtype=np.int32))
    trailing = dynamic_shape[-1:]
    return tf.reshape(x, tf.concat([leading, spliced, trailing], 0))
  if static_rank == 1:
    return x[tf.newaxis, :]
  return x
def _expand_sample_shape_to_vector(self, x, name):
  """Helper to `sample`: returns `x` as a vector together with its product.

  Args:
    x: `Tensor` holding a scalar or a vector.
    name: Name given to the tensor created when a statically known scalar is
      wrapped into a one-element vector.

  Returns:
    A pair `(x, prod)`: `x` reshaped to one dimension where it was a scalar,
    and the product of its elements (a NumPy value when `x` is statically
    known, a `Tensor` otherwise).

  Raises:
    ValueError: if the static rank of `x` is neither 0 nor 1.
  """
  static_value = tensor_util.constant_value(x)
  if static_value is not None:
    total = np.prod(static_value, dtype=x.dtype.as_numpy_dtype())
  else:
    total = tf.reduce_prod(x)

  static_rank = x.get_shape().ndims  # != sample_ndims
  if static_rank is None:
    # Rank unknown until run time: maybe expand_dims.
    is_scalar = tf.equal(tf.rank(x), 0)
    target_shape = util.pick_vector(
        is_scalar, np.array([1], dtype=np.int32), tf.shape(x))
    return tf.reshape(x, target_shape), total

  if static_rank == 0:
    # Definitely expand_dims.
    if static_value is None:
      return tf.reshape(x, [1]), total
    wrapped = np.array([static_value], dtype=x.dtype.as_numpy_dtype())
    return tf.convert_to_tensor(wrapped, name=name), total

  if static_rank != 1:
    raise ValueError("Input is neither scalar nor vector.")
  return x, total
def _batch_shape_tensor(self, override_batch_shape=None,
                        base_batch_shape_tensor=None):
  """Returns the batch shape tensor, honouring any batch-shape override.

  Args:
    override_batch_shape: Optional override batch shape; defaults to
      `self._override_batch_shape` converted to a tensor.
    base_batch_shape_tensor: Optional batch shape of the base distribution;
      defaults to `self.distribution.batch_shape_tensor()`.

  Returns:
    The override batch shape when `self._has_nonzero_rank` reports it as
    non-empty, otherwise the (possibly broadcast) base batch shape. For a
    joint transformed distribution with a nested base batch shape, the nested
    structure is returned as-is.
  """
  if override_batch_shape is None:
    override_batch_shape = tf.convert_to_tensor(self._override_batch_shape)
  if base_batch_shape_tensor is None:
    base_batch_shape_tensor = self.distribution.batch_shape_tensor()

  # The transformed distribution normally shares the base distribution's
  # `batch_shape_tensor`. The exception is a joint base distribution with a
  # structured (component-wise) `batch_shape_tensor` under a transformed
  # distribution that is not joint: the components are then broadcast
  # together. Components that do not broadcast are not supported. (A joint
  # distribution may carry either one `batch_shape_tensor` for all components
  # or one per component, nested like its dtype.)
  if tf.nest.is_nested(base_batch_shape_tensor):
    if self._is_joint:
      return base_batch_shape_tensor
    component_shapes = tf.nest.flatten(base_batch_shape_tensor)
    base_batch_shape_tensor = functools.reduce(
        prefer_static.broadcast_shape, component_shapes)

  # An overridden batch shape takes precedence over the base one.
  is_overridden = self._has_nonzero_rank(override_batch_shape)
  return distribution_util.pick_vector(
      is_overridden, override_batch_shape, base_batch_shape_tensor)
def _make_columnar(self, x):
  """Adds a leading singleton dimension to rank-1 input only.

  Example:
    If `x = [1, 2, 3]` then the output is `[[1, 2, 3]]`.
    If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged.
    If `x = 1` then the output is unchanged.

  Args:
    x: `Tensor`.

  Returns:
    columnar_x: `x` reshaped from `[k]` to `[1, k]` when its rank is 1,
      otherwise `x` with its shape unchanged.
  """
  static_rank = tensorshape_util.rank(x.shape)
  if static_rank is None:
    # Rank is only known at run time: splice `[1]` (rank 1) or `[]` (any other
    # rank) in front of the last dimension, then reshape.
    dynamic_shape = tf.shape(x)
    leading = dynamic_shape[:-1]
    spliced = distribution_util.pick_vector(
        tf.equal(tf.rank(x), 1), [1], np.array([], dtype=np.int32))
    trailing = dynamic_shape[-1:]
    return tf.reshape(x, tf.concat([leading, spliced, trailing], 0))
  if static_rank == 1:
    return x[tf.newaxis, :]
  return x
def _expand_sample_shape_to_vector(self, x, name):
  """Helper to `sample`: returns `x` as a vector together with its product.

  Args:
    x: `Tensor` holding a scalar or a vector.
    name: Name given to the tensor created when a statically known scalar is
      wrapped into a one-element vector.

  Returns:
    A pair `(x, prod)`: `x` reshaped to one dimension where it was a scalar,
    and the product of its elements (a NumPy value when `x` is statically
    known, a `Tensor` otherwise).

  Raises:
    ValueError: if the static rank of `x` is neither 0 nor 1.
  """
  static_value = tf.contrib.util.constant_value(x)
  if static_value is not None:
    total = np.prod(static_value, dtype=x.dtype.as_numpy_dtype())
  else:
    total = tf.reduce_prod(x)

  static_rank = x.shape.ndims  # != sample_ndims
  if static_rank is None:
    # Rank unknown until run time: maybe expand_dims.
    is_scalar = tf.equal(tf.rank(x), 0)
    target_shape = util.pick_vector(
        is_scalar, np.array([1], dtype=np.int32), tf.shape(x))
    return tf.reshape(x, target_shape), total

  if static_rank == 0:
    # Definitely expand_dims.
    if static_value is None:
      return tf.reshape(x, [1]), total
    wrapped = np.array([static_value], dtype=x.dtype.as_numpy_dtype())
    return tf.convert_to_tensor(wrapped, name=name), total

  if static_rank != 1:
    raise ValueError("Input is neither scalar nor vector.")
  return x, total
def _sample_n(self, n, seed=None):
  """Draws `n` samples: pick a quadrature component, then sample a Poisson.

  Args:
    n: Number of samples to draw.
    seed: Seed; split into one seed for the mixture draw and one for the
      Poisson draw.

  Returns:
    Poisson samples whose rate has shape `[n]` + batch shape.
  """
  distributions = self.poisson_and_mixture_distributions()
  poisson_dist, mixture_dist = distributions

  # The ids form an [n, batch_size]-shaped matrix, or an [n]-shaped vector
  # when batch_shape=[].
  batch_size = tensorshape_util.num_elements(self.batch_shape)
  if batch_size is None:
    batch_size = tf.reduce_prod(
        self._batch_shape_tensor(distributions=distributions))

  mixture_seed, poisson_seed = samplers.split_seed(
      seed, salt='PoissonLogNormalQuadratureCompound')

  # 'Sample extra' from the mixture distribution when it does not already
  # carry a probs vector per batch coordinate. Only this reduced form of
  # broadcasting is supported: one probs vector for all batch dims, or one
  # for each.
  extra_draws = distribution_util.pick_vector(
      mixture_dist.is_scalar_batch(), [batch_size], np.int32([]))
  ids = mixture_dist.sample(
      sample_shape=concat_vectors([n], extra_draws), seed=mixture_seed)

  # Flatten batch dims in case mixture_dist has its own batch dims.
  flat_batch = distribution_util.pick_vector(
      self.is_scalar_batch(), np.int32([]), np.int32([-1]))
  ids = tf.reshape(ids, shape=concat_vectors([n], flat_batch))

  # Stride `quadrature_size` for `batch_size` number of times.
  offset = tf.range(
      start=0,
      limit=batch_size * self._quadrature_size,
      delta=self._quadrature_size,
      dtype=ids.dtype)
  ids = ids + offset

  flat_rates = tf.reshape(poisson_dist.rate_parameter(), shape=[-1])
  rate = tf.gather(flat_rates, ids)
  rate = tf.reshape(
      rate,
      shape=concat_vectors(
          [n], self._batch_shape_tensor(distributions=distributions)))
  return samplers.poisson(
      shape=[], lam=rate, dtype=self.dtype, seed=poisson_seed)
def _batch_shape_tensor(self, override_batch_shape=None,
                        base_batch_shape_tensor=None):
  """Returns the override batch shape if one is set, else the base one.

  Args:
    override_batch_shape: Optional override batch shape; defaults to
      `self._override_batch_shape` converted to a tensor.
    base_batch_shape_tensor: Optional batch shape of the base distribution;
      defaults to `self.distribution.batch_shape_tensor()`.

  Returns:
    `override_batch_shape` when `self._has_nonzero_rank` is true for it,
    otherwise `base_batch_shape_tensor`.
  """
  if override_batch_shape is None:
    override_batch_shape = tf.convert_to_tensor(self._override_batch_shape)
  if base_batch_shape_tensor is None:
    base_batch_shape_tensor = self.distribution.batch_shape_tensor()
  is_overridden = self._has_nonzero_rank(override_batch_shape)
  return distribution_util.pick_vector(
      is_overridden, override_batch_shape, base_batch_shape_tensor)
def _sample_n(self, n, seed=None):
  """Draws `n` samples: pick a quadrature component, then sample a Poisson.

  Args:
    n: Number of samples to draw.
    seed: Seed used to build a salted `SeedStream`; every random op in this
      method takes its own seed from that stream.

  Returns:
    Poisson samples whose rate has shape `[n]` + batch shape.
  """
  # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
  # ids as a [n]-shaped vector.
  batch_size = self.batch_shape.num_elements()
  if batch_size is None:
    batch_size = tf.reduce_prod(self.batch_shape_tensor())
  # We need to "sample extra" from the mixture distribution if it doesn't
  # already specify a probs vector for each batch coordinate.
  # We only support this kind of reduced broadcasting, i.e., there is exactly
  # one probs vector for all batch dims or one for each.
  stream = seed_stream.SeedStream(
      seed, salt="PoissonLogNormalQuadratureCompound")
  ids = self._mixture_distribution.sample(
      sample_shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              self.mixture_distribution.is_scalar_batch(),
              [batch_size],
              np.int32([]))),
      seed=stream())
  # We need to flatten batch dims in case mixture_distribution has its own
  # batch dims.
  ids = tf.reshape(
      ids,
      shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              self.is_scalar_batch(),
              np.int32([]),
              np.int32([-1]))))
  # Stride `quadrature_size` for `batch_size` number of times.
  offset = tf.range(
      start=0,
      limit=batch_size * self._quadrature_size,
      delta=self._quadrature_size,
      dtype=ids.dtype)
  ids += offset
  rate = tf.gather(tf.reshape(self.distribution.rate, shape=[-1]), ids)
  rate = tf.reshape(
      rate, shape=concat_vectors([n], self.batch_shape_tensor()))
  # Take the Poisson seed from the salted stream rather than reusing the raw
  # caller seed, so this op does not share a seed with other ops the caller
  # seeds with the same value.
  return tf.random_poisson(
      lam=rate, shape=[], dtype=self.dtype, seed=stream())
def _sample_n(self, n, seed=None):
  """Draws `n` samples: pick a quadrature component, then sample a Poisson.

  Args:
    n: Number of samples to draw.
    seed: Seed used to build a salted `SeedStream`; every random op in this
      method takes its own seed from that stream.

  Returns:
    Poisson samples whose rate has shape `[n]` + batch shape.
  """
  # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
  # ids as a [n]-shaped vector.
  batch_size = self.batch_shape.num_elements()
  if batch_size is None:
    batch_size = tf.reduce_prod(input_tensor=self.batch_shape_tensor())
  # We need to "sample extra" from the mixture distribution if it doesn't
  # already specify a probs vector for each batch coordinate.
  # We only support this kind of reduced broadcasting, i.e., there is exactly
  # one probs vector for all batch dims or one for each.
  stream = seed_stream.SeedStream(
      seed, salt="PoissonLogNormalQuadratureCompound")
  ids = self._mixture_distribution.sample(
      sample_shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              self.mixture_distribution.is_scalar_batch(),
              [batch_size],
              np.int32([]))),
      seed=stream())
  # We need to flatten batch dims in case mixture_distribution has its own
  # batch dims.
  ids = tf.reshape(
      ids,
      shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              self.is_scalar_batch(),
              np.int32([]),
              np.int32([-1]))))
  # Stride `quadrature_size` for `batch_size` number of times.
  offset = tf.range(
      start=0,
      limit=batch_size * self._quadrature_size,
      delta=self._quadrature_size,
      dtype=ids.dtype)
  ids += offset
  rate = tf.gather(tf.reshape(self.distribution.rate, shape=[-1]), ids)
  rate = tf.reshape(
      rate, shape=concat_vectors([n], self.batch_shape_tensor()))
  # Take the Poisson seed from the salted stream rather than reusing the raw
  # caller seed, so this op does not share a seed with other ops the caller
  # seeds with the same value.
  return tf.random.poisson(
      lam=rate, shape=[], dtype=self.dtype, seed=stream())
def _event_shape_tensor(self, override_event_shape=None,
                        base_event_shape_tensor=None):
  """Returns the bijector-forwarded event shape tensor.

  Args:
    override_event_shape: Optional override event shape; defaults to
      `self._override_event_shape` converted to a tensor.
    base_event_shape_tensor: Optional event shape of the base distribution;
      defaults to `self.distribution.event_shape_tensor()`.

  Returns:
    `self.bijector.forward_event_shape_tensor` applied to the override event
    shape when `self._has_nonzero_rank` is true for it, otherwise to the base
    event shape.
  """
  if override_event_shape is None:
    override_event_shape = tf.convert_to_tensor(self._override_event_shape)
  if base_event_shape_tensor is None:
    base_event_shape_tensor = self.distribution.event_shape_tensor()
  is_overridden = self._has_nonzero_rank(override_event_shape)
  pre_bijector_shape = distribution_util.pick_vector(
      is_overridden, override_event_shape, base_event_shape_tensor)
  return self.bijector.forward_event_shape_tensor(pre_bijector_shape)
def _sample_n(self, n, seed=None, **distribution_kwargs):
  """Draws `n` base-distribution samples; the bijector is applied elsewhere.

  Args:
    n: Number of samples to draw.
    seed: Seed passed through to `self.distribution.sample`.
    **distribution_kwargs: Extra keyword arguments forwarded to
      `self.distribution.sample`.

  Returns:
    The base draw after `self._maybe_rotate_dims`.
  """
  event_override = tf.convert_to_tensor(self._override_event_shape)
  batch_override = tf.convert_to_tensor(self._override_batch_shape)
  base_scalar_batch = self.distribution.is_scalar_batch()
  rotate = self._needs_rotation(
      event_override, batch_override, base_scalar_batch)

  # `[n]` goes in front unless rotation is needed, in which case it goes last.
  front = distribution_util.pick_vector(rotate, self._empty, [n])
  back = distribution_util.pick_vector(rotate, [n], self._empty)
  requested_shape = prefer_static.concat(
      [front, batch_override, event_override, back], axis=0)

  draw = self.distribution.sample(
      sample_shape=requested_shape, seed=seed, **distribution_kwargs)
  # We'll apply the bijector in the `_call_sample_n` function.
  return self._maybe_rotate_dims(
      draw, event_override, batch_override, base_scalar_batch)
def make_batch_of_event_sample_matrices(
    self, x, expand_batch_dim=True,
    name="make_batch_of_event_sample_matrices"):
  """Reshapes/transposes `Distribution` `Tensor` from S+B+E to B_+E_+S_.

  Where:
    - `B_ = B if B or not expand_batch_dim else [1]`,
    - `E_ = E if E else [1]`,
    - `S_ = [tf.reduce_prod(S)]`.

  Args:
    x: `Tensor`.
    expand_batch_dim: Python `bool`. If `True` the batch dims will be expanded
      such that `batch_ndims >= 1`.
    name: Python `str`. The name to give this op.

  Returns:
    x: `Tensor`. Input transposed/reshaped to `B_+E_+S_`.
    sample_shape: `Tensor` (1D, `int32`).
  """
  with self._name_scope(name, values=[x]):
    x = tf.convert_to_tensor(x, name="x")
    # x.shape: S+B+E
    sample_shape, batch_shape, event_shape = self.get_shape(x)
    # Scalar events (and, optionally, scalar batches) become `[1]`.
    event_shape_ = distribution_util.pick_vector(
        self._event_ndims_is_0, [1], event_shape)
    batch_shape_ = batch_shape
    if expand_batch_dim:
      batch_shape_ = distribution_util.pick_vector(
          self._batch_ndims_is_0, [1], batch_shape)
    flattened_shape = tf.concat([[-1], batch_shape_, event_shape_], 0)
    flattened = tf.reshape(x, shape=flattened_shape)
    # flattened.shape: [prod(S)]+B_+E_
    rotated = distribution_util.rotate_transpose(flattened, shift=-1)
    # rotated.shape: B_+E_+[prod(S)]
    return rotated, sample_shape
def _event_shape_tensor(self, override_event_shape=None,
                        base_event_shape_tensor=None):
  """Returns the bijector-forwarded event shape tensor.

  Args:
    override_event_shape: Optional override event shape; defaults to
      `self._override_event_shape` converted to a tensor.
    base_event_shape_tensor: Optional event shape of the base distribution;
      defaults to `self.distribution.event_shape_tensor()`.

  Returns:
    `self.bijector.forward_event_shape_tensor` applied to the base event
    shape, after substituting the override (if any) when the base
    distribution is not joint.
  """
  if override_event_shape is None:
    override_event_shape = tf.convert_to_tensor(self._override_event_shape)
  if base_event_shape_tensor is None:
    base_event_shape_tensor = self.distribution.event_shape_tensor()

  pre_bijector_shape = base_event_shape_tensor
  if not self._base_is_joint:
    # The event shape override only applies to a non-joint base distribution.
    pre_bijector_shape = distribution_util.pick_vector(
        self._has_nonzero_rank(override_event_shape),
        override_event_shape,
        base_event_shape_tensor)
  return self.bijector.forward_event_shape_tensor(pre_bijector_shape)
def __init__(self,
             df,
             loc=None,
             scale_identity_multiplier=None,
             scale_diag=None,
             scale_tril=None,
             scale_perturb_factor=None,
             scale_perturb_diag=None,
             validate_args=False,
             allow_nan_stats=True,
             name="VectorStudentT"):
  """Instantiates the vector Student's t-distributions on `R^k`.

  The `batch_shape` is the broadcast between `df.batch_shape` and
  `Affine.batch_shape` where `Affine` is constructed from `loc` and
  `scale_*` arguments.

  The `event_shape` is the event shape of `Affine.event_shape`.

  Args:
    df: Floating-point `Tensor`. The degrees of freedom of the
      distribution(s). `df` must contain only positive values.
      Must be scalar if `loc`, `scale_*` imply non-scalar batch_shape or must
      have the same `batch_shape` implied by `loc`, `scale_*`.
    loc: Floating-point `Tensor`. If this is set to `None`, no `loc` is
      applied.
    scale_identity_multiplier: floating point rank 0 `Tensor` representing a
      scaling done to the identity matrix. When
      `scale_identity_multiplier = scale_diag=scale_tril = None` then
      `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added
      to `scale`.
    scale_diag: Floating-point `Tensor` representing the diagonal matrix.
      `scale_diag` has shape [N1, N2, ..., k], which represents a k x k
      diagonal matrix. When `None` no diagonal term is added to `scale`.
    scale_tril: Floating-point `Tensor` representing the lower triangular
      matrix. `scale_tril` has shape [N1, N2, ..., k, k], which represents a
      k x k lower triangular matrix. When `None` no `scale_tril` term is
      added to `scale`. The upper triangular elements above the diagonal are
      ignored.
    scale_perturb_factor: Floating-point `Tensor` representing factor matrix
      with last two dimensions of shape `(k, r)`. When `None`, no rank-r
      update is added to `scale`.
    scale_perturb_diag: Floating-point `Tensor` representing the diagonal
      matrix. `scale_perturb_diag` has shape [N1, N2, ..., r], which
      represents an r x r Diagonal matrix. When `None` low rank updates will
      take the form `scale_perturb_factor * scale_perturb_factor.T`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`,
      statistics (e.g., mean, mode, variance) use the value "`NaN`" to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  # Must stay the first statement so only the constructor arguments are
  # captured.
  parameters = dict(locals())
  graph_parents = [df, loc, scale_identity_multiplier, scale_diag,
                   scale_tril, scale_perturb_factor, scale_perturb_diag]
  with tf.name_scope(name) as name:
    with tf.name_scope("init", values=graph_parents):
      # The shape of the _VectorStudentT distribution is governed by the
      # relationship between df.batch_shape and affine.batch_shape. In
      # pseudocode the basic procedure is:
      #   if df.batch_shape is scalar:
      #     if affine.batch_shape is not scalar:
      #       # broadcast distribution.sample so
      #       # it has affine.batch_shape.
      #     self.batch_shape = affine.batch_shape
      #   else:
      #     if affine.batch_shape is scalar:
      #       # let affine broadcasting do its thing.
      #     self.batch_shape = df.batch_shape
      # All of the above magic is actually handled by TransformedDistribution.
      # Here we really only need to collect the affine.batch_shape and decide
      # what we're going to pass in to TransformedDistribution's
      # (override) batch_shape arg.
      affine = bijectors.Affine(
          shift=loc,
          scale_identity_multiplier=scale_identity_multiplier,
          scale_diag=scale_diag,
          scale_tril=scale_tril,
          scale_perturb_factor=scale_perturb_factor,
          scale_perturb_diag=scale_perturb_diag,
          validate_args=validate_args)
      # Forward `validate_args` and `allow_nan_stats` to the base StudentT.
      # Previously `allow_nan_stats` was accepted but never used, and `df`
      # was never validated.
      # NOTE(review): assumes the parent class takes `allow_nan_stats` from
      # the base distribution -- confirm against TransformedDistribution.
      distribution = student_t.StudentT(
          df=df,
          loc=tf.zeros([], dtype=affine.dtype),
          scale=tf.ones([], dtype=affine.dtype),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats)
      batch_shape, override_event_shape = (
          distribution_util.shapes_from_loc_and_scale(
              affine.shift, affine.scale))
      # Only override the batch shape when `df` itself is scalar-batch.
      override_batch_shape = distribution_util.pick_vector(
          distribution.is_scalar_batch(),
          batch_shape,
          tf.constant([], dtype=tf.int32))
      super(_VectorStudentT, self).__init__(
          distribution=distribution,
          bijector=affine,
          batch_shape=override_batch_shape,
          event_shape=override_event_shape,
          validate_args=validate_args,
          name=name)
      self._parameters = parameters
def __init__(self,
             mixture_distribution,
             components_distribution,
             validate_args=False,
             allow_nan_stats=True,
             name="MixtureSameFamily"):
  """Construct a `MixtureSameFamily` distribution.

  Args:
    mixture_distribution: `tfp.distributions.Categorical`-like instance.
      Manages the probability of selecting components. The number of
      categories must match the rightmost batch dimension of the
      `components_distribution`. Must have either scalar `batch_shape` or
      `batch_shape` matching `components_distribution.batch_shape[:-1]`.
    components_distribution: `tfp.distributions.Distribution`-like instance.
      Right-most batch dimension indexes components.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: `if not mixture_distribution.dtype.is_integer`.
    ValueError: if mixture_distribution does not have scalar `event_shape`.
    ValueError: if `mixture_distribution.batch_shape` and
      `components_distribution.batch_shape[:-1]` are both fully defined and
      the former is neither scalar nor equal to the latter.
    ValueError: if `mixture_distribution` categories does not equal
      `components_distribution` rightmost batch shape.
  """
  # Must stay the first statement so only the constructor arguments are
  # captured.
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    self._mixture_distribution = mixture_distribution
    self._components_distribution = components_distribution
    self._runtime_assertions = []

    # Number of event dimensions: static when known, otherwise a `Tensor`.
    s = components_distribution.event_shape_tensor()
    self._event_ndims = (
        s.shape[0].value
        if s.shape.with_rank_at_least(1)[0].value is not None
        else tf.shape(s)[0])

    if not mixture_distribution.dtype.is_integer:
      raise ValueError(
          "`mixture_distribution.dtype` ({}) is not over integers".format(
              mixture_distribution.dtype.name))

    if (mixture_distribution.event_shape.ndims is not None
        and mixture_distribution.event_shape.ndims != 0):
      raise ValueError("`mixture_distribution` must have scalar `event_dim`s")
    elif validate_args:
      # The event-shape tensor is itself a 1-D vector, so a scalar event is
      # an *empty* vector: check its size, not its rank. (`assert_rank(.., 0)`
      # on the shape vector could never pass.)
      self._runtime_assertions += [
          tf.assert_equal(
              tf.size(mixture_distribution.event_shape_tensor()), 0,
              message="`mixture_distribution` must have scalar `event_dim`s"),
      ]

    mdbs = mixture_distribution.batch_shape
    cdbs = components_distribution.batch_shape.with_rank_at_least(1)[:-1]
    if mdbs.is_fully_defined() and cdbs.is_fully_defined():
      if mdbs.ndims != 0 and mdbs != cdbs:
        raise ValueError(
            "`mixture_distribution.batch_shape` (`{}`) is not "
            "compatible with `components_distribution.batch_shape` "
            "(`{}`)".format(mdbs.as_list(), cdbs.as_list()))
    elif validate_args:
      # Shapes not fully known statically: defer the check to run time. A
      # scalar-batch mixture is always compatible, hence the `pick_vector`.
      mdbs = mixture_distribution.batch_shape_tensor()
      cdbs = components_distribution.batch_shape_tensor()[:-1]
      self._runtime_assertions += [
          tf.assert_equal(
              distribution_utils.pick_vector(
                  mixture_distribution.is_scalar_batch(), cdbs, mdbs),
              cdbs,
              message=(
                  "`mixture_distribution.batch_shape` is not "
                  "compatible with `components_distribution.batch_shape`"))]

    km = mixture_distribution.logits.shape.with_rank_at_least(1)[-1].value
    kc = components_distribution.batch_shape.with_rank_at_least(1)[-1].value
    if km is not None and kc is not None and km != kc:
      raise ValueError("`mixture_distribution components` ({}) does not "
                       "equal `components_distribution.batch_shape[-1]` "
                       "({})".format(km, kc))
    elif validate_args:
      km = tf.shape(mixture_distribution.logits)[-1]
      kc = components_distribution.batch_shape_tensor()[-1]
      self._runtime_assertions += [
          tf.assert_equal(
              km, kc,
              message=("`mixture_distribution components` does not equal "
                       "`components_distribution.batch_shape[-1:]`")),
      ]
    elif km is None:
      km = tf.shape(mixture_distribution.logits)[-1]

    self._num_components = km

    super(MixtureSameFamily, self).__init__(
        dtype=self._components_distribution.dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=(
            self._mixture_distribution._graph_parents  # pylint: disable=protected-access
            + self._components_distribution._graph_parents),  # pylint: disable=protected-access
        name=name)
def _sample_n(self, n, seed=None):
  """Draws `n` samples by interpolating between the two affine endpoints.

  A base sample is pushed through each bijector in `self.endpoint_affine`;
  a grid weight, selected per sample by a draw from
  `self.mixture_distribution`, then blends the two results.

  Args:
    n: Number of samples to draw.
    seed: Seed for the salted `SeedStream` that seeds both draws.

  Returns:
    The interpolated sample.

  Raises:
    NotImplementedError: if there are not exactly two endpoints.
  """
  stream = seed_stream.SeedStream(seed, salt="VectorDiffeomixture")
  base_shape = concat_vectors(
      [n], self.batch_shape_tensor(), self.event_shape_tensor())
  base_sample = self.distribution.sample(
      sample_shape=base_shape, seed=stream())  # shape: [n, B, e]
  endpoints = [aff.forward(base_sample) for aff in self.endpoint_affine]

  # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
  # ids as a [n]-shaped vector.
  batch_size = self.batch_shape.num_elements()
  if batch_size is None:
    batch_size = tf.reduce_prod(self.batch_shape_tensor())
  mix_batch_size = self.mixture_distribution.batch_shape.num_elements()
  if mix_batch_size is None:
    mix_batch_size = tf.reduce_prod(
        self.mixture_distribution.batch_shape_tensor())
  extra_draws = distribution_util.pick_vector(
      self.is_scalar_batch(),
      np.int32([]),
      [batch_size // mix_batch_size])
  ids = self.mixture_distribution.sample(
      sample_shape=concat_vectors([n], extra_draws), seed=stream())

  # Flatten batch dims in case mixture_distribution has its own batch dims.
  flat_batch = distribution_util.pick_vector(
      self.is_scalar_batch(), np.int32([]), np.int32([-1]))
  ids = tf.reshape(ids, shape=concat_vectors([n], flat_batch))

  # Stride `components * quadrature_size` for `batch_size` number of times.
  stride = self.grid.shape.with_rank_at_least(2)[-2:].num_elements()
  if stride is None:
    stride = tf.reduce_prod(tf.shape(self.grid)[-2:])
  offset = tf.range(
      start=0, limit=batch_size * stride, delta=stride, dtype=ids.dtype)
  flat_grid = tf.reshape(self.grid, shape=[-1])
  weight = tf.gather(flat_grid, ids + offset)

  # At this point, weight flattened all batch dims into one.
  # We also need to append a singleton to broadcast with event dims.
  if self.batch_shape.is_fully_defined():
    weight_shape = [-1] + self.batch_shape.as_list() + [1]
  else:
    weight_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]), axis=0)
  weight = tf.reshape(weight, shape=weight_shape)

  if len(endpoints) != 2:
    # We actually should have already triggered this exception. However as a
    # policy we're putting this exception wherever we exploit the bimixture
    # assumption.
    raise NotImplementedError("Currently only bimixtures are supported; "
                              "len(scale)={} is not 2.".format(len(endpoints)))

  first, second = endpoints
  # Alternatively:
  # weight * first + (1. - weight) * second
  return weight * (first - second) + second
def _sample_n(self, n, seed=None):
  """Draws `n` samples by interpolating between the two affine endpoints.

  A base sample is pushed through each bijector in `self.endpoint_affine`;
  a grid weight, selected per sample by a draw from
  `self.mixture_distribution`, then blends the two results.

  Args:
    n: Number of samples to draw.
    seed: Seed for the salted `SeedStream` that seeds both draws.

  Returns:
    The interpolated sample.

  Raises:
    NotImplementedError: if there are not exactly two endpoints.
  """
  stream = seed_stream.SeedStream(seed, salt="VectorDiffeomixture")
  base_shape = concat_vectors(
      [n], self.batch_shape_tensor(), self.event_shape_tensor())
  base_sample = self.distribution.sample(
      sample_shape=base_shape, seed=stream())  # shape: [n, B, e]
  endpoints = [aff.forward(base_sample) for aff in self.endpoint_affine]

  # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
  # ids as a [n]-shaped vector.
  batch_size = tensorshape_util.num_elements(self.batch_shape)
  if batch_size is None:
    batch_size = tf.reduce_prod(input_tensor=self.batch_shape_tensor())
  mix_batch_size = tensorshape_util.num_elements(
      self.mixture_distribution.batch_shape)
  if mix_batch_size is None:
    mix_batch_size = tf.reduce_prod(
        input_tensor=self.mixture_distribution.batch_shape_tensor())
  extra_draws = distribution_util.pick_vector(
      self.is_scalar_batch(),
      np.int32([]),
      [batch_size // mix_batch_size])
  ids = self.mixture_distribution.sample(
      sample_shape=concat_vectors([n], extra_draws), seed=stream())

  # Flatten batch dims in case mixture_distribution has its own batch dims.
  flat_batch = distribution_util.pick_vector(
      self.is_scalar_batch(), np.int32([]), np.int32([-1]))
  ids = tf.reshape(ids, shape=concat_vectors([n], flat_batch))

  # Stride `components * quadrature_size` for `batch_size` number of times.
  stride = tensorshape_util.num_elements(
      tensorshape_util.with_rank_at_least(self.grid.shape, 2)[-2:])
  if stride is None:
    stride = tf.reduce_prod(input_tensor=tf.shape(input=self.grid)[-2:])
  offset = tf.range(
      start=0, limit=batch_size * stride, delta=stride, dtype=ids.dtype)
  flat_grid = tf.reshape(self.grid, shape=[-1])
  weight = tf.gather(flat_grid, ids + offset)

  # At this point, weight flattened all batch dims into one.
  # We also need to append a singleton to broadcast with event dims.
  if tensorshape_util.is_fully_defined(self.batch_shape):
    weight_shape = [-1] + tensorshape_util.as_list(self.batch_shape) + [1]
  else:
    weight_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]), axis=0)
  weight = tf.reshape(weight, shape=weight_shape)

  if len(endpoints) != 2:
    # We actually should have already triggered this exception. However as a
    # policy we're putting this exception wherever we exploit the bimixture
    # assumption.
    raise NotImplementedError(
        "Currently only bimixtures are supported; "
        "len(scale)={} is not 2.".format(len(endpoints)))

  first, second = endpoints
  # Alternatively:
  # weight * first + (1. - weight) * second
  return weight * (first - second) + second
def __init__(self,
             mixture_distribution,
             components_distribution,
             reparameterize=False,
             validate_args=False,
             allow_nan_stats=True,
             name="MixtureSameFamily"):
  """Construct a `MixtureSameFamily` distribution.

  Args:
    mixture_distribution: `tfp.distributions.Categorical`-like instance.
      Manages the probability of selecting components. The number of
      categories must match the rightmost batch dimension of the
      `components_distribution`. Must have either scalar `batch_shape` or
      `batch_shape` matching `components_distribution.batch_shape[:-1]`.
    components_distribution: `tfp.distributions.Distribution`-like instance.
      Right-most batch dimension indexes components.
    reparameterize: Python `bool`, default `False`. Whether to reparameterize
      samples of the distribution using implicit reparameterization gradients
      [(Figurnov et al., 2018)][1]. The gradients for the mixture logits are
      equivalent to the ones described by [(Graves, 2016)][2]. The gradients
      for the components parameters are also computed using implicit
      reparameterization (as opposed to ancestral sampling), meaning that
      all components are updated every step.
      Only works when:
        (1) components_distribution is fully reparameterized;
        (2) components_distribution is either a scalar distribution or
        fully factorized (tfd.Independent applied to a scalar distribution);
        (3) batch shape has a known rank.
      Experimental, may be slow and produce infs/NaNs.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: `if not mixture_distribution.dtype.is_integer`.
    ValueError: if mixture_distribution does not have scalar `event_shape`.
    ValueError: if `mixture_distribution.batch_shape` and
      `components_distribution.batch_shape[:-1]` are both fully defined and
      the former is neither scalar nor equal to the latter.
    ValueError: if `mixture_distribution` categories does not equal
      `components_distribution` rightmost batch shape.
    ValueError: if `reparameterize` is set but `components_distribution` is
      not fully reparameterized.

  #### References

  [1]: Michael Figurnov, Shakir Mohamed and Andriy Mnih. Implicit
       reparameterization gradients. In _Neural Information Processing
       Systems_, 2018. https://arxiv.org/abs/1805.08498

  [2]: Alex Graves. Stochastic Backpropagation through Mixture Density
       Distributions. _arXiv_, 2016. https://arxiv.org/abs/1607.05690
  """
  # Must stay the first statement so only the constructor arguments are
  # captured.
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    self._mixture_distribution = mixture_distribution
    self._components_distribution = components_distribution
    self._runtime_assertions = []

    # Event rank (static when known) and total event size.
    event_shape_vec = components_distribution.event_shape_tensor()
    self._event_ndims = tf.compat.dimension_value(event_shape_vec.shape[0])
    if self._event_ndims is None:
      self._event_ndims = tf.size(input=event_shape_vec)
    self._event_size = tf.reduce_prod(input_tensor=event_shape_vec)

    if not mixture_distribution.dtype.is_integer:
      raise ValueError(
          "`mixture_distribution.dtype` ({}) is not over integers".format(
              mixture_distribution.dtype.name))

    # The mixture must have scalar events: reject statically when possible,
    # otherwise (optionally) assert at run time.
    mixture_event_ndims = mixture_distribution.event_shape.ndims
    if mixture_event_ndims is not None and mixture_event_ndims != 0:
      raise ValueError(
          "`mixture_distribution` must have scalar `event_dim`s")
    elif validate_args:
      self._runtime_assertions += [
          assert_util.assert_equal(
              tf.size(input=mixture_distribution.event_shape_tensor()),
              0,
              message="`mixture_distribution` must have scalar `event_dim`s"),
      ]

    # The mixture batch shape must be scalar or match the components' batch
    # shape with its last (component) dimension dropped.
    static_mix_batch = mixture_distribution.batch_shape
    static_comp_batch = (
        components_distribution.batch_shape.with_rank_at_least(1)[:-1])
    if (static_mix_batch.is_fully_defined()
        and static_comp_batch.is_fully_defined()):
      if static_mix_batch.ndims != 0 and static_mix_batch != static_comp_batch:
        raise ValueError(
            "`mixture_distribution.batch_shape` (`{}`) is not "
            "compatible with `components_distribution.batch_shape` "
            "(`{}`)".format(static_mix_batch.as_list(),
                            static_comp_batch.as_list()))
    elif validate_args:
      dynamic_mix_batch = mixture_distribution.batch_shape_tensor()
      dynamic_comp_batch = components_distribution.batch_shape_tensor()[:-1]
      self._runtime_assertions += [
          assert_util.assert_equal(
              distribution_utils.pick_vector(
                  mixture_distribution.is_scalar_batch(),
                  dynamic_comp_batch,
                  dynamic_mix_batch),
              dynamic_comp_batch,
              message=(
                  "`mixture_distribution.batch_shape` is not "
                  "compatible with `components_distribution.batch_shape`")),
      ]

    # The number of mixture categories must equal the number of components.
    num_mix = tf.compat.dimension_value(
        mixture_distribution.logits.shape.with_rank_at_least(1)[-1])
    num_comp = tf.compat.dimension_value(
        components_distribution.batch_shape.with_rank_at_least(1)[-1])
    if num_mix is not None and num_comp is not None and num_mix != num_comp:
      raise ValueError(
          "`mixture_distribution components` ({}) does not "
          "equal `components_distribution.batch_shape[-1]` "
          "({})".format(num_mix, num_comp))
    elif validate_args:
      num_mix = tf.shape(input=mixture_distribution.logits)[-1]
      num_comp = components_distribution.batch_shape_tensor()[-1]
      self._runtime_assertions += [
          assert_util.assert_equal(
              num_mix,
              num_comp,
              message=("`mixture_distribution components` does not equal "
                       "`components_distribution.batch_shape[-1:]`")),
      ]
    elif num_mix is None:
      num_mix = tf.shape(input=mixture_distribution.logits)[-1]

    self._num_components = num_mix

    self._reparameterize = reparameterize
    if not reparameterize:
      reparameterization_type = reparameterization.NOT_REPARAMETERIZED
    else:
      # Note: tfd.Independent passes through the reparameterization type hence
      # we do not need separate logic for Independent.
      if (self._components_distribution.reparameterization_type !=
          reparameterization.FULLY_REPARAMETERIZED):
        raise ValueError("Cannot reparameterize a mixture of "
                         "non-reparameterized components.")
      reparameterization_type = reparameterization.FULLY_REPARAMETERIZED

    super(MixtureSameFamily, self).__init__(
        dtype=self._components_distribution.dtype,
        reparameterization_type=reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=(
            self._mixture_distribution._graph_parents  # pylint: disable=protected-access
            + self._components_distribution._graph_parents),  # pylint: disable=protected-access
        name=name)
def _event_shape_tensor(self):
  """Computes this distribution's event shape as a `Tensor`.

  `distribution_util.pick_vector` selects `self._override_event_shape` when
  `self._is_event_override` holds and the wrapped distribution's
  `event_shape_tensor()` otherwise. The selected shape is then passed through
  the bijector's `forward_event_shape_tensor`.

  Returns:
    The value of `self.bijector.forward_event_shape_tensor` applied to the
    selected base event shape.
  """
  forward_shape_fn = self.bijector.forward_event_shape_tensor
  base_event_shape = distribution_util.pick_vector(
      self._is_event_override,
      self._override_event_shape,
      self.distribution.event_shape_tensor())
  return forward_shape_fn(base_event_shape)
def _forward_log_det_jacobian(self, x):
  """Computes the forward log-det-Jacobian of `Y = X X.T`.

  Args:
    x: `Tensor` whose two trailing dimensions hold the matrix `X` (treated as
      lower-triangular in the derivation below).

  Returns:
    fldj: `Tensor` equal to `p log(2) + sum_j (p - j) log(X[j, j])`, where `p`
      is the size of the last dimension of `x`. When `x` is a single matrix
      (rank 2) the extra dimension added by `_make_columnar` is removed again.
  """
  # Let Y be a symmetric, positive definite matrix and write:
  #   Y = X X.T
  # where X is lower-triangular.
  #
  # Observe that,
  #   dY[i,j]/dX[a,b]
  #   = d/dX[a,b] { X[i,:] X[j,:] }
  #   = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] }
  #   = I[i=a] X[j,b] + I[j=a] X[i,b]
  #
  # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is
  # symmetric and X is lower-triangular, we need vectors of dimension:
  #   d = p (p + 1) / 2
  # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e.,
  #   k = { i (i + 1) / 2 + j   i>=j
  #       { undef               i<j
  # and assume zero-based indexes. When k is undef, the element is dropped.
  # Example (entries show k):
  #           j
  #        0 1 2 3
  #    0 [ 0 . . . ]
  # i  1 [ 1 2 . . ]
  #    2 [ 3 4 5 . ]
  #    3 [ 6 7 8 9 ]
  # Write vec[.] to indicate transforming a matrix to vector via k(i,j). (With
  # slight abuse: k(i,j)=undef means the element is dropped.)
  #
  # We now show d vec[Y] / d vec[X] is lower triangular. Assuming both are
  # defined, observe that k(i,j) < k(a,b) iff (1) i<a or (2) i=a and j<b.
  # In both cases dvec[Y]/dvec[X]@[k(i,j),k(a,b)] = 0 since:
  #   (1) j<=i<a thus i,j!=a.
  #   (2) i=a and j<b<=a: X[j,b] = 0 because X is lower-triangular, and j!=a.
  #
  # Since the Jacobian is lower-triangular, we need only compute the product
  # of diagonal elements:
  #   d vec[Y] / d vec[X] @[k(i,j), k(i,j)]
  #   = X[j,j] + I[i=j] X[i,j]
  # which is 2 X[j,j] when i=j and X[j,j] when i>j. Column j of X has p-j
  # lower-triangular elements (one of them on the diagonal), so we conclude:
  #   |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.

  # `_make_columnar` guarantees the diagonal has at least two dimensions so
  # that it can be fed to `tf.matmul` below.
  columnar_diag = self._make_columnar(tf.linalg.diag_part(x))

  with tf.control_dependencies(self._assertions(x)):
    # Build the exponent vector [p, p-1, ..., 2, 1], using the static size of
    # the last dimension when it is known.
    static_p = tf.compat.dimension_value(x.shape[-1])
    if static_p is None:
      p_int = tf.shape(x)[-1]
      p_float = tf.cast(p_int, dtype=x.dtype)
    else:
      p_int = static_p
      p_float = dtype_util.as_numpy_dtype(x.dtype)(p_int)
    exponents = tf.linspace(p_float, 1., p_int)

    weighted_log_diag = tf.matmul(
        tf.math.log(columnar_diag), exponents[..., tf.newaxis])
    fldj = p_float * np.log(2.) + tf.squeeze(weighted_log_diag, axis=-1)

    # Undo the extra dimension that was added when the input is a single
    # matrix. With a statically known rank this is decided in Python.
    static_rank = tensorshape_util.rank(x.shape)
    if static_rank is not None:
      return tf.squeeze(fldj, axis=-1) if static_rank == 2 else fldj

    # Otherwise drop the trailing dimension at graph-execution time.
    fldj_shape = tf.shape(fldj)
    trailing_dims = distribution_util.pick_vector(
        tf.equal(tf.rank(x), 2),
        np.array([], dtype=np.int32),
        fldj_shape[-1:])
    maybe_squeeze_shape = tf.concat([fldj_shape[:-1], trailing_dims], 0)
    return tf.reshape(fldj, maybe_squeeze_shape)
def _parameter_control_dependencies(self, is_init):
  """Checks that the mixture and components distributions are compatible.

  Conditions that can be decided from static dtype/shape information raise
  immediately. Conditions that cannot are turned into assertion ops, and only
  when `self.validate_args` is true.

  Args:
    is_init: Python `bool`. The integer-dtype check on `mixture_distribution`
      is only performed when this is true.

  Returns:
    assertions: A (possibly empty) list of assertion ops.

  Raises:
    ValueError: if `is_init` and `mixture_distribution.dtype` is not integer.
    ValueError: if `mixture_distribution` has a statically known non-scalar
      event shape.
    ValueError: if the statically known number of mixture components differs
      from `components_distribution.batch_shape[-1]`.
    ValueError: if both batch shapes are fully defined and the mixture batch
      shape is neither scalar nor equal to the components batch shape without
      its last dimension.
  """
  mixture = self.mixture_distribution
  components = self.components_distribution
  checks = []

  if is_init and not dtype_util.is_integer(mixture.dtype):
    raise ValueError(
        '`mixture_distribution.dtype` ({}) is not over integers'.format(
            dtype_util.name(mixture.dtype)))

  # The mixture distribution must be over scalar events.
  event_rank = tensorshape_util.rank(mixture.event_shape)
  if event_rank is None:
    if self.validate_args:
      checks.append(
          assert_util.assert_equal(
              tf.size(mixture.event_shape_tensor()),
              0,
              message='`mixture_distribution` must have scalar `event_dim`s'))
  elif event_rank != 0:
    raise ValueError('`mixture_distribution` must have scalar `event_dim`s')

  # The number of categories must equal the rightmost batch dimension of the
  # components distribution.
  # pylint: disable=protected-access
  if mixture._logits is None:
    mixture_dist_param = mixture._probs
  else:
    mixture_dist_param = mixture._logits
  km = tf.compat.dimension_value(
      tensorshape_util.with_rank_at_least(mixture_dist_param.shape, 1)[-1])
  kc = tf.compat.dimension_value(
      tensorshape_util.with_rank_at_least(components.batch_shape, 1)[-1])
  # Cached so `batch_shape_tensor()` is requested at most once.
  component_bst = None
  if km is None or kc is None:
    if self.validate_args:
      if km is None:
        mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
        km = tf.shape(mixture_dist_param)[-1]
      if kc is None:
        component_bst = components.batch_shape_tensor()
        kc = component_bst[-1]
      checks.append(
          assert_util.assert_equal(
              km,
              kc,
              message=('`mixture_distribution` components does not equal '
                       '`components_distribution.batch_shape[-1]`')))
  elif km != kc:
    raise ValueError(
        '`mixture_distribution` components ({}) does not '
        'equal `components_distribution.batch_shape[-1]` '
        '({})'.format(km, kc))

  # The mixture batch shape must be scalar or match the components batch shape
  # with its last (component-indexing) dimension removed.
  mdbs = mixture.batch_shape
  cdbs = tensorshape_util.with_rank_at_least(components.batch_shape, 1)[:-1]
  mdbs_is_static = tensorshape_util.is_fully_defined(mdbs)
  cdbs_is_static = tensorshape_util.is_fully_defined(cdbs)
  if mdbs_is_static and cdbs_is_static:
    if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
      raise ValueError(
          '`mixture_distribution.batch_shape` (`{}`) is not '
          'compatible with `components_distribution.batch_shape` '
          '(`{}`)'.format(tensorshape_util.as_list(mdbs),
                          tensorshape_util.as_list(cdbs)))
  elif self.validate_args:
    if not mdbs_is_static:
      mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
      mdbs = tf.shape(mixture_dist_param)[:-1]
    if not cdbs_is_static:
      if component_bst is None:
        component_bst = components.batch_shape_tensor()
      cdbs = component_bst[:-1]
    # A scalar mixture batch shape is always accepted: in that case `cdbs` is
    # compared with itself.
    expected_shape = distribution_utils.pick_vector(
        tf.equal(tf.shape(mdbs)[0], 0), cdbs, mdbs)
    checks.append(
        assert_util.assert_equal(
            expected_shape,
            cdbs,
            message=(
                '`mixture_distribution.batch_shape` is not '
                'compatible with `components_distribution.batch_shape`')))

  return checks
def __init__(self,
             mixture_distribution,
             components_distribution,
             validate_args=False,
             allow_nan_stats=True,
             name="MixtureSameFamily"):
  """Construct a `MixtureSameFamily` distribution.

  Args:
    mixture_distribution: `tfp.distributions.Categorical`-like instance.
      Manages the probability of selecting components. The number of
      categories must match the rightmost batch dimension of the
      `components_distribution`. Must have either scalar `batch_shape` or
      `batch_shape` matching `components_distribution.batch_shape[:-1]`.
    components_distribution: `tfp.distributions.Distribution`-like instance.
      Right-most batch dimension indexes components.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: `if not mixture_distribution.dtype.is_integer`.
    ValueError: if mixture_distribution does not have scalar `event_shape`.
    ValueError: if `mixture_distribution.batch_shape` and
      `components_distribution.batch_shape[:-1]` are both fully defined and
      the former is neither scalar nor equal to the latter.
    ValueError: if `mixture_distribution` categories does not equal
      `components_distribution` rightmost batch shape.
  """
  # Must stay the first statement so that only the constructor arguments are
  # captured.
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    self._mixture_distribution = mixture_distribution
    self._components_distribution = components_distribution
    self._runtime_assertions = []

    # Number of event dimensions of the components: static when the length of
    # the event-shape vector is known, otherwise a `Tensor`.
    s = components_distribution.event_shape_tensor()
    self._event_ndims = (s.shape[0].value
                         if s.shape.with_rank_at_least(1)[0].value is not None
                         else tf.shape(s)[0])

    if not mixture_distribution.dtype.is_integer:
      raise ValueError(
          "`mixture_distribution.dtype` ({}) is not over integers".format(
              mixture_distribution.dtype.name))

    if (mixture_distribution.event_shape.ndims is not None
        and mixture_distribution.event_shape.ndims != 0):
      raise ValueError("`mixture_distribution` must have scalar `event_dim`s")
    elif validate_args:
      # `event_shape_tensor()` is a shape *vector*, so a scalar event shape
      # means that vector has zero elements. Asserting that the vector itself
      # has rank 0 (as was done previously) can never succeed.
      self._runtime_assertions += [
          control_flow_ops.assert_equal(
              tf.size(mixture_distribution.event_shape_tensor()),
              0,
              message="`mixture_distribution` must have scalar `event_dim`s"),
      ]

    # The mixture batch shape must be scalar or equal the components batch
    # shape with its last (component-indexing) dimension removed.
    mdbs = mixture_distribution.batch_shape
    cdbs = components_distribution.batch_shape.with_rank_at_least(1)[:-1]
    if mdbs.is_fully_defined() and cdbs.is_fully_defined():
      if mdbs.ndims != 0 and mdbs != cdbs:
        raise ValueError(
            "`mixture_distribution.batch_shape` (`{}`) is not "
            "compatible with `components_distribution.batch_shape` "
            "(`{}`)".format(mdbs.as_list(), cdbs.as_list()))
    elif validate_args:
      mdbs = mixture_distribution.batch_shape_tensor()
      cdbs = components_distribution.batch_shape_tensor()[:-1]
      self._runtime_assertions += [
          control_flow_ops.assert_equal(
              distribution_utils.pick_vector(
                  mixture_distribution.is_scalar_batch(), cdbs, mdbs),
              cdbs,
              message=(
                  "`mixture_distribution.batch_shape` is not "
                  "compatible with `components_distribution.batch_shape`"))
      ]

    # The number of categories must equal the rightmost batch dimension of
    # the components distribution.
    km = mixture_distribution.logits.shape.with_rank_at_least(1)[-1].value
    kc = components_distribution.batch_shape.with_rank_at_least(1)[-1].value
    if km is not None and kc is not None and km != kc:
      raise ValueError("`mixture_distribution components` ({}) does not "
                       "equal `components_distribution.batch_shape[-1]` "
                       "({})".format(km, kc))
    elif validate_args:
      km = tf.shape(mixture_distribution.logits)[-1]
      kc = components_distribution.batch_shape_tensor()[-1]
      self._runtime_assertions += [
          control_flow_ops.assert_equal(
              km,
              kc,
              message=("`mixture_distribution components` does not equal "
                       "`components_distribution.batch_shape[-1:]`")),
      ]
    elif km is None:
      km = tf.shape(mixture_distribution.logits)[-1]

    self._num_components = km

    super(MixtureSameFamily, self).__init__(
        dtype=self._components_distribution.dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=(
            self._mixture_distribution._graph_parents  # pylint: disable=protected-access
            + self._components_distribution._graph_parents),  # pylint: disable=protected-access
        name=name)
def _batch_shape_tensor(self):
  """Computes this distribution's batch shape as a `Tensor`.

  Returns:
    `self._override_batch_shape` when `self._is_batch_override` holds,
    otherwise the wrapped distribution's `batch_shape_tensor()`; the choice is
    made by `distribution_util.pick_vector`.
  """
  use_override = self._is_batch_override
  override_shape = self._override_batch_shape
  base_shape = self.distribution.batch_shape_tensor()
  return distribution_util.pick_vector(
      use_override, override_shape, base_shape)
def _forward_log_det_jacobian(self, x):
  """Computes the forward log-det-Jacobian of `Y = X X.T`.

  Args:
    x: `Tensor` whose two trailing dimensions hold the matrix `X` (treated as
      lower-triangular in the derivation below).

  Returns:
    fldj: `Tensor` equal to `p log(2) + sum_j (p - j) log(X[j, j])`, where `p`
      is the size of the last dimension of `x`. When `x` is a single matrix
      (rank 2) the extra dimension added by `_make_columnar` is removed again.
  """
  # Let Y be a symmetric, positive definite matrix and write:
  #   Y = X X.T
  # where X is lower-triangular.
  #
  # Observe that,
  #   dY[i,j]/dX[a,b]
  #   = d/dX[a,b] { X[i,:] X[j,:] }
  #   = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] }
  #   = I[i=a] X[j,b] + I[j=a] X[i,b]
  #
  # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is
  # symmetric and X is lower-triangular, we need vectors of dimension:
  #   d = p (p + 1) / 2
  # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e.,
  #   k = { i (i + 1) / 2 + j   i>=j
  #       { undef               i<j
  # and assume zero-based indexes. When k is undef, the element is dropped.
  # Example (entries show k):
  #           j
  #        0 1 2 3
  #    0 [ 0 . . . ]
  # i  1 [ 1 2 . . ]
  #    2 [ 3 4 5 . ]
  #    3 [ 6 7 8 9 ]
  # Write vec[.] to indicate transforming a matrix to vector via k(i,j). (With
  # slight abuse: k(i,j)=undef means the element is dropped.)
  #
  # We now show d vec[Y] / d vec[X] is lower triangular. Assuming both are
  # defined, observe that k(i,j) < k(a,b) iff (1) i<a or (2) i=a and j<b.
  # In both cases dvec[Y]/dvec[X]@[k(i,j),k(a,b)] = 0 since:
  #   (1) j<=i<a thus i,j!=a.
  #   (2) i=a and j<b<=a: X[j,b] = 0 because X is lower-triangular, and j!=a.
  #
  # Since the Jacobian is lower-triangular, we need only compute the product
  # of diagonal elements:
  #   d vec[Y] / d vec[X] @[k(i,j), k(i,j)]
  #   = X[j,j] + I[i=j] X[i,j]
  # which is 2 X[j,j] when i=j and X[j,j] when i>j. Column j of X has p-j
  # lower-triangular elements (one of them on the diagonal), so we conclude:
  #   |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.
  diag = tf.matrix_diag_part(x)

  # `_make_columnar` guarantees the diagonal has at least two dimensions so
  # that it can be fed to `tf.matmul` below.
  diag = self._make_columnar(diag)

  if self.validate_args:
    is_matrix = tf.assert_rank_at_least(
        x, 2, message="Input must be a (batch of) matrix.")
    shape = tf.shape(x)
    is_square = tf.assert_equal(
        shape[-2], shape[-1],
        message="Input must be a (batch of) square matrix.")
    # Assuming lower-triangular means we only need check diag>0.
    is_positive_definite = tf.assert_positive(
        diag, message="Input must be positive definite.")
    assertions = [is_matrix, is_square, is_positive_definite]
    x = control_flow_ops.with_dependencies(assertions, x)
    # The result below is computed from `diag`, which was extracted from `x`
    # before the dependencies were attached. When the trailing dimension of
    # `x` is statically known nothing consumes the guarded `x`, so without
    # this line the assertions would never execute in graph mode.
    diag = control_flow_ops.with_dependencies(assertions, diag)

  # Create a vector equal to: [p, p-1, ..., 2, 1].
  if x.shape.ndims is None or x.shape[-1].value is None:
    p_int = tf.shape(x)[-1]
    p_float = tf.cast(p_int, dtype=x.dtype)
  else:
    p_int = x.shape[-1].value
    p_float = np.array(p_int, dtype=x.dtype.as_numpy_dtype)
  exponents = tf.linspace(p_float, 1., p_int)

  sum_weighted_log_diag = tf.squeeze(
      tf.matmul(tf.log(diag), exponents[..., tf.newaxis]), axis=-1)
  fldj = p_float * np.log(2.) + sum_weighted_log_diag

  # We finally need to undo adding an extra column in non-scalar cases
  # where there is a single matrix as input.
  if x.shape.ndims is not None:
    if x.shape.ndims == 2:
      fldj = tf.squeeze(fldj, axis=-1)
    return fldj

  # Rank of `x` is only known at run time: drop the trailing dimension of
  # `fldj` exactly when `x` turns out to be a single matrix.
  shape = tf.shape(fldj)
  maybe_squeeze_shape = tf.concat([
      shape[:-1],
      distribution_util.pick_vector(
          tf.equal(tf.rank(x), 2),
          np.array([], dtype=np.int32), shape[-1:])], 0)
  return tf.reshape(fldj, maybe_squeeze_shape)
def __init__(self,
             df,
             loc=None,
             scale_identity_multiplier=None,
             scale_diag=None,
             scale_tril=None,
             scale_perturb_factor=None,
             scale_perturb_diag=None,
             validate_args=False,
             allow_nan_stats=True,
             name="VectorStudentT"):
  """Instantiates the vector Student's t-distributions on `R^k`.

  The `batch_shape` is the broadcast between `df.batch_shape` and
  `Affine.batch_shape` where `Affine` is constructed from `loc` and
  `scale_*` arguments.

  The `event_shape` is the event shape of `Affine.event_shape`.

  Args:
    df: Floating-point `Tensor`. The degrees of freedom of the
      distribution(s). `df` must contain only positive values. Must be
      scalar if `loc`, `scale_*` imply non-scalar batch_shape or must have
      the same `batch_shape` implied by `loc`, `scale_*`.
    loc: Floating-point `Tensor`. If this is set to `None`, no `loc` is
      applied.
    scale_identity_multiplier: floating point rank 0 `Tensor` representing a
      scaling done to the identity matrix. When
      `scale_identity_multiplier = scale_diag = scale_tril = None` then
      `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is
      added to `scale`.
    scale_diag: Floating-point `Tensor` representing the diagonal matrix.
      `scale_diag` has shape [N1, N2, ..., k], which represents a k x k
      diagonal matrix. When `None` no diagonal term is added to `scale`.
    scale_tril: Floating-point `Tensor` representing the lower triangular
      matrix. `scale_tril` has shape [N1, N2, ..., k, k], which represents a
      k x k lower triangular matrix. When `None` no `scale_tril` term is
      added to `scale`. The upper triangular elements above the diagonal are
      ignored.
    scale_perturb_factor: Floating-point `Tensor` representing factor matrix
      with last two dimensions of shape `(k, r)`. When `None`, no rank-r
      update is added to `scale`.
    scale_perturb_diag: Floating-point `Tensor` representing the diagonal
      matrix. `scale_perturb_diag` has shape [N1, N2, ..., r], which
      represents an r x r Diagonal matrix. When `None` low rank updates will
      take the form `scale_perturb_factor * scale_perturb_factor.T`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  # Must stay the first statement so that only the constructor arguments are
  # captured.
  parameters = dict(locals())
  tensor_args = [
      df,
      loc,
      scale_identity_multiplier,
      scale_diag,
      scale_tril,
      scale_perturb_factor,
      scale_perturb_diag,
  ]
  with tf.name_scope(name) as name:
    with tf.name_scope("init"):
      common_dtype = dtype_util.common_dtype(tensor_args, tf.float32)
      df_tensor = tf.convert_to_tensor(
          value=df, name="df", dtype=common_dtype)

      # The shape of the _VectorStudentT distribution is governed by the
      # relationship between df.batch_shape and affine.batch_shape. In
      # pseudocode the basic procedure is:
      #   if df.batch_shape is scalar:
      #     if affine.batch_shape is not scalar:
      #       # broadcast distribution.sample so
      #       # it has affine.batch_shape.
      #     self.batch_shape = affine.batch_shape
      #   else:
      #     if affine.batch_shape is scalar:
      #       # let affine broadcasting do its thing.
      #     self.batch_shape = df.batch_shape
      # All of the above magic is actually handled by TransformedDistribution.
      # Here we really only need to collect the affine.batch_shape and decide
      # what we're going to pass in to TransformedDistribution's
      # (override) batch_shape arg.
      affine = affine_bijector.Affine(
          shift=loc,
          scale_identity_multiplier=scale_identity_multiplier,
          scale_diag=scale_diag,
          scale_tril=scale_tril,
          scale_perturb_factor=scale_perturb_factor,
          scale_perturb_diag=scale_perturb_diag,
          validate_args=validate_args,
          dtype=common_dtype)

      # Standard (zero location, unit scale) Student's t; location and scale
      # are applied by the affine bijector.
      zero_loc = tf.zeros([], dtype=affine.dtype)
      unit_scale = tf.ones([], dtype=affine.dtype)
      base_distribution = student_t.StudentT(
          df=df_tensor, loc=zero_loc, scale=unit_scale)

      shapes = distribution_util.shapes_from_loc_and_scale(
          affine.shift, affine.scale)
      affine_batch_shape, override_event_shape = shapes

      # Only override the batch shape when the base distribution has a scalar
      # batch; otherwise pass an empty vector.
      override_batch_shape = distribution_util.pick_vector(
          base_distribution.is_scalar_batch(),
          affine_batch_shape,
          tf.constant([], dtype=tf.int32))

      super(_VectorStudentT, self).__init__(
          distribution=base_distribution,
          bijector=affine,
          batch_shape=override_batch_shape,
          event_shape=override_event_shape,
          validate_args=validate_args,
          name=name)
      self._parameters = parameters
def _event_shape_tensor(self):
  """Computes this distribution's event shape as a `Tensor`.

  `distribution_util.pick_vector` selects `self._override_event_shape` when
  `self._is_event_override` holds and the wrapped distribution's
  `event_shape_tensor()` otherwise. The selected shape is then passed through
  the bijector's `forward_event_shape_tensor`.

  Returns:
    The value of `self.bijector.forward_event_shape_tensor` applied to the
    selected base event shape.
  """
  forward_shape_fn = self.bijector.forward_event_shape_tensor
  base_event_shape = distribution_util.pick_vector(
      self._is_event_override,
      self._override_event_shape,
      self.distribution.event_shape_tensor())
  return forward_shape_fn(base_event_shape)
def _batch_shape_tensor(self):
  """Computes this distribution's batch shape as a `Tensor`.

  Returns:
    `self._override_batch_shape` when `self._is_batch_override` holds,
    otherwise the wrapped distribution's `batch_shape_tensor()`; the choice is
    made by `distribution_util.pick_vector`.
  """
  use_override = self._is_batch_override
  override_shape = self._override_batch_shape
  base_shape = self.distribution.batch_shape_tensor()
  return distribution_util.pick_vector(
      use_override, override_shape, base_shape)