def make_batch_of_event_sample_matrices(
    self, x, expand_batch_dim=True,
    name="make_batch_of_event_sample_matrices"):
  """Reshapes/transposes `Distribution` `Tensor` from S+B+E to B_+E_+S_.

  Where:
    - `B_ = B if B or not expand_batch_dim else [1]`,
    - `E_ = E if E else [1]`,
    - `S_ = [tf.reduce_prod(S)]`.

  Args:
    x: `Tensor`.
    expand_batch_dim: Python `bool`. If `True` the batch dims will be expanded
      such that `batch_ndims >= 1`.
    name: Python `str`. The name to give this op.

  Returns:
    x: `Tensor`. Input transposed/reshaped to `B_+E_+S_`.
    sample_shape: `Tensor` (1D, `int32`).
  """
  with self._name_scope(name, values=[x]):
    x = tf.convert_to_tensor(x, name="x")
    # x.shape: S+B+E
    sample_shape, batch_shape, event_shape = self.get_shape(x)
    event_shape = distribution_util.pick_vector(
        self._event_ndims_is_0, [1], event_shape)
    if expand_batch_dim:
      batch_shape = distribution_util.pick_vector(
          self._batch_ndims_is_0, [1], batch_shape)
    new_shape = tf.concat([[-1], batch_shape, event_shape], 0)
    x = tf.reshape(x, shape=new_shape)
    # x.shape: [prod(S)]+B_+E_
    x = distribution_util.rotate_transpose(x, shift=-1)
    # x.shape: B_+E_+[prod(S)]
    return x, sample_shape
def make_batch_of_event_sample_matrices(
    self, x, expand_batch_dim=True,
    name="make_batch_of_event_sample_matrices"):
  """Reshapes/transposes `Distribution` `Tensor` from S+B+E to B_+E_+S_.

  Where:
    - `B_ = B if B or not expand_batch_dim else [1]`,
    - `E_ = E if E else [1]`,
    - `S_ = [tf.reduce_prod(S)]`.

  Args:
    x: `Tensor`.
    expand_batch_dim: Python `bool`. If `True` the batch dims will be expanded
      such that `batch_ndims >= 1`.
    name: Python `str`. The name to give this op.

  Returns:
    x: `Tensor`. Input transposed/reshaped to `B_+E_+S_`.
    sample_shape: `Tensor` (1D, `int32`).
  """
  with self._name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    # x.shape: S+B+E
    sample_shape, batch_shape, event_shape = self.get_shape(x)
    event_shape = distribution_util.pick_vector(
        self._event_ndims_is_0, [1], event_shape)
    if expand_batch_dim:
      batch_shape = distribution_util.pick_vector(
          self._batch_ndims_is_0, [1], batch_shape)
    new_shape = array_ops.concat([[-1], batch_shape, event_shape], 0)
    x = array_ops.reshape(x, shape=new_shape)
    # x.shape: [prod(S)]+B_+E_
    x = distribution_util.rotate_transpose(x, shift=-1)
    # x.shape: B_+E_+[prod(S)]
    return x, sample_shape
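# A minimal NumPy sketch of the reshape/rotate bookkeeping above, assuming
# fully static shapes S=[2, 3], B=[4], E=[5]; the array and shapes here are
# illustrative only, not part of the original code.
import numpy as np

x = np.zeros([2, 3, 4, 5])     # x.shape: S + B + E
x = np.reshape(x, [-1, 4, 5])  # x.shape: [prod(S)] + B + E = [6, 4, 5]
x = np.moveaxis(x, 0, -1)      # rotate_transpose(shift=-1): B + E + [prod(S)]
assert x.shape == (4, 5, 6)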
def _sample_n(self, n, seed=None):
  sample_shape = _concat_vectors(
      distribution_util.pick_vector(self._needs_rotation, self._empty, [n]),
      self._override_batch_shape,
      self._override_event_shape,
      distribution_util.pick_vector(self._needs_rotation, [n], self._empty))
  x = self.distribution.sample(sample_shape=sample_shape, seed=seed)
  x = self._maybe_rotate_dims(x)
  return self.bijector.forward(x)
def _sample_n(self, n, seed=None):
  sample_shape = _concat_vectors(
      distribution_util.pick_vector(self._needs_rotation, self._empty, [n]),
      self._override_batch_shape,
      self._override_event_shape,
      distribution_util.pick_vector(self._needs_rotation, [n], self._empty))
  x = self.distribution.sample(sample_shape=sample_shape, seed=seed)
  x = self._maybe_rotate_dims(x)
  # We'll apply the bijector in the `_call_sample_n` function.
  return x
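# A hedged NumPy sketch of the sample-shape construction used by both
# `_sample_n` variants above; `needs_rotation`, `batch`, and `event` are
# illustrative stand-ins for `self._needs_rotation` and the override shapes.
import numpy as np

def concat_sample_shape(needs_rotation, n, batch, event):
  empty = np.array([], dtype=np.int32)
  head = empty if needs_rotation else np.array([n], dtype=np.int32)
  tail = np.array([n], dtype=np.int32) if needs_rotation else empty
  return np.concatenate([head, batch, event, tail])

print(concat_sample_shape(False, 7, [2], [3]))  # -> [7 2 3]
print(concat_sample_shape(True, 7, [2], [3]))   # -> [2 3 7]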
def _make_columnar(self, x):
  """Ensures non-scalar input has at least two dimensions.

  Example:
    If `x = [1, 2, 3]` then the output is `[[1, 2, 3]]`.
    If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged.
    If `x = 1` then the output is unchanged.

  Args:
    x: `Tensor`.

  Returns:
    columnar_x: `Tensor` with at least two dimensions.
  """
  if x.get_shape().ndims is not None:
    if x.get_shape().ndims == 1:
      x = x[array_ops.newaxis, :]
    return x
  shape = array_ops.shape(x)
  maybe_expanded_shape = array_ops.concat([
      shape[:-1],
      distribution_util.pick_vector(
          math_ops.equal(array_ops.rank(x), 1),
          [1], np.array([], dtype=np.int32)),
      shape[-1:],
  ], 0)
  return array_ops.reshape(x, maybe_expanded_shape)
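# A minimal NumPy model of `_make_columnar`, assuming the static-shape path:
# rank-1 input gains a leading axis (shape [k] -> [1, k]); everything else
# passes through unchanged.
import numpy as np

def make_columnar(x):
  x = np.asarray(x)
  return x[np.newaxis, :] if x.ndim == 1 else x

assert make_columnar([1, 2, 3]).shape == (1, 3)
assert make_columnar([[1, 2, 3], [4, 5, 6]]).shape == (2, 3)
assert make_columnar(1).shape == ()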
def _sample_n(self, n, seed=None):
  # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
  # ids as a [n]-shaped vector.
  batch_size = (np.prod(self.batch_shape.as_list(), dtype=np.int32)
                if self.batch_shape.is_fully_defined()
                else math_ops.reduce_prod(self.batch_shape_tensor()))
  ids = self._mixture_distribution.sample(
      sample_shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              self.is_scalar_batch(), np.int32([]), [batch_size])),
      seed=distribution_util.gen_new_seed(
          seed, "poisson_lognormal_quadrature_compound"))
  # Stride `quadrature_polynomial_degree` for `batch_size` number of times.
  offset = math_ops.range(start=0,
                          limit=batch_size * self._degree,
                          delta=self._degree,
                          dtype=ids.dtype)
  ids += offset
  rate = array_ops.gather(
      array_ops.reshape(self.distribution.rate, shape=[-1]), ids)
  rate = array_ops.reshape(
      rate, shape=concat_vectors([n], self.batch_shape_tensor()))
  return random_ops.random_poisson(
      lam=rate, shape=[], dtype=self.dtype, seed=seed)
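# An illustrative NumPy check of the stride-and-gather trick above, assuming
# batch_size=2 and quadrature degree=3: after flattening `rate`, component
# `ids[b]` of batch member `b` lives at flat index `b * degree + ids[b]`.
import numpy as np

degree, batch_size = 3, 2
rate = np.arange(batch_size * degree).reshape([batch_size, degree])
ids = np.array([2, 0])                              # one component per batch member
offset = np.arange(0, batch_size * degree, degree)  # [0, 3]
picked = rate.reshape([-1])[ids + offset]
assert (picked == rate[np.arange(batch_size), ids]).all()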
def _expand_sample_shape_to_vector(self, x, name):
  """Helper to `sample` which ensures input is 1D."""
  x_static_val = tensor_util.constant_value(x)
  if x_static_val is None:
    prod = math_ops.reduce_prod(x)
  else:
    prod = np.prod(x_static_val, dtype=x.dtype.as_numpy_dtype())

  ndims = x.get_shape().ndims  # != sample_ndims
  if ndims is None:
    # Maybe expand_dims.
    ndims = array_ops.rank(x)
    expanded_shape = util.pick_vector(
        math_ops.equal(ndims, 0),
        np.array([1], dtype=np.int32), array_ops.shape(x))
    x = array_ops.reshape(x, expanded_shape)
  elif ndims == 0:
    # Definitely expand_dims.
    if x_static_val is not None:
      x = ops.convert_to_tensor(
          np.array([x_static_val], dtype=x.dtype.as_numpy_dtype()),
          name=name)
    else:
      x = array_ops.reshape(x, [1])
  elif ndims != 1:
    raise ValueError("Input is neither scalar nor vector.")

  return x, prod
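# A hedged NumPy analogue of the static-shape path above: scalar sample
# shapes become length-1 vectors, and the total sample count is the product
# of the (possibly expanded) shape.
import numpy as np

def expand_sample_shape_to_vector(x):
  x = np.asarray(x, dtype=np.int32)
  if x.ndim == 0:
    x = x[np.newaxis]
  elif x.ndim != 1:
    raise ValueError("Input is neither scalar nor vector.")
  return x, int(np.prod(x))

assert expand_sample_shape_to_vector(5)[1] == 5       # shape -> [5]
assert expand_sample_shape_to_vector([2, 3])[1] == 6  # shape -> [2, 3]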
def _sample_n(self, n, seed=None):
  # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
  # ids as a [n]-shaped vector.
  batch_size = (np.prod(self.batch_shape.as_list(), dtype=np.int32)
                if self.batch_shape.is_fully_defined()
                else math_ops.reduce_prod(self.batch_shape_tensor()))
  ids = self._mixture_distribution.sample(
      sample_shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              self.is_scalar_batch(), np.int32([]), [batch_size])),
      seed=distribution_util.gen_new_seed(
          seed, "poisson_lognormal_quadrature_compound"))
  # Stride `quadrature_degree` for `batch_size` number of times.
  offset = math_ops.range(start=0,
                          limit=batch_size * len(self.quadrature_probs),
                          delta=len(self.quadrature_probs),
                          dtype=ids.dtype)
  ids += offset
  rate = array_ops.gather(
      array_ops.reshape(self.distribution.rate, shape=[-1]), ids)
  rate = array_ops.reshape(
      rate, shape=concat_vectors([n], self.batch_shape_tensor()))
  return random_ops.random_poisson(
      lam=rate, shape=[], dtype=self.dtype, seed=seed)
def testCorrectlyPicksVector(self):
  with self.test_session():
    x = np.arange(10, 12)
    y = np.arange(15, 18)
    self.assertAllEqual(
        x, distribution_util.pick_vector(math_ops.less(0, 5), x, y).eval())
    self.assertAllEqual(
        y, distribution_util.pick_vector(math_ops.less(5, 0), x, y).eval())
    self.assertAllEqual(
        x, distribution_util.pick_vector(
            constant_op.constant(True), x, y))  # No eval.
    self.assertAllEqual(
        y, distribution_util.pick_vector(
            constant_op.constant(False), x, y))  # No eval.
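# For reference, a NumPy model of the concat-and-slice trick `pick_vector`
# uses when `cond` is not a statically known constant (illustrative only;
# as the "No eval" cases above suggest, the real helper short-circuits to
# plain Python selection for constant conds).
import numpy as np

def pick_vector(cond, x, y):
  x, y = np.asarray(x), np.asarray(y)
  xy = np.concatenate([x, y], 0)
  begin = 0 if cond else x.size
  return xy[begin:begin + (x.size if cond else y.size)]

assert (pick_vector(True, [10, 11], [15, 16, 17]) == [10, 11]).all()
assert (pick_vector(False, [10, 11], [15, 16, 17]) == [15, 16, 17]).all()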
def _pad_mix_dims(self, x):
  with ops.name_scope("pad_mix_dims", values=[x]):
    def _get_ndims(d):
      if d.batch_shape.ndims is not None:
        return d.batch_shape.ndims
      return array_ops.shape(d.batch_shape_tensor())[0]
    dist_batch_ndims = _get_ndims(self)
    cat_batch_ndims = _get_ndims(self.mixture_distribution)
    bnd = distribution_util.pick_vector(
        self.mixture_distribution.is_scalar_batch(),
        [dist_batch_ndims], [cat_batch_ndims])[0]
    s = array_ops.shape(x)
    x = array_ops.reshape(x, shape=array_ops.concat([
        s[:-1],
        array_ops.ones([bnd], dtype=dtypes.int32),
        s[-1:],
        array_ops.ones([self._event_ndims], dtype=dtypes.int32),
    ], axis=0))
    return x
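# Illustrative only: with `x` of shape [5, 3] (3 mixture components), one
# batch dim to pad (bnd=1), and event_ndims=2, the reshape above yields
# [5, 1, 3, 1, 1], so `x` broadcasts against component batch and event dims.
import numpy as np

x = np.zeros([5, 3])
bnd, event_ndims = 1, 2
s = x.shape
x = x.reshape(list(s[:-1]) + [1] * bnd + [s[-1]] + [1] * event_ndims)
assert x.shape == (5, 1, 3, 1, 1)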
def _forward_log_det_jacobian(self, x):
  # Let Y be a symmetric, positive definite matrix and write:
  #   Y = X X.T
  # where X is lower-triangular.
  #
  # Observe that,
  #   dY[i,j]/dX[a,b]
  #   = d/dX[a,b] { X[i,:] X[j,:] }
  #   = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] }
  #
  # To compute the Jacobian dY/dX we must represent X,Y as vectors. Since Y is
  # symmetric and X is lower-triangular, we need vectors of dimension:
  #   d = p (p + 1) / 2
  # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e.,
  #   k = { i (i + 1) / 2 + j   i>=j
  #       { undef               i<j
  # and assume zero-based indexes. When k is undef, the element is dropped.
  # Example:
  #           j  k
  #        0 1 2 3  /
  #    0 [ 0 . . . ]
  #  i 1 [ 1 2 . . ]
  #    2 [ 3 4 5 . ]
  #    3 [ 6 7 8 9 ]
  # Write vec[.] to indicate transforming a matrix to vector via k(i,j). (With
  # slight abuse: k(i,j)=undef means the element is dropped.)
  #
  # We now show d vec[Y] / d vec[X] is lower triangular. Assuming both are
  # defined, observe that k(i,j) < k(a,b) iff (1) i<a or (2) i=a and j<b.
  # In both cases dvec[Y]/dvec[X]@[k(i,j),k(a,b)] = 0 since:
  # (1) j<=i<a thus i,j!=a.
  # (2) i=a>j thus i,j!=a.
  #
  # Since the Jacobian is lower-triangular, we need only compute the product
  # of diagonal elements:
  #   d vec[Y] / d vec[X] @[k(i,j), k(i,j)]
  #   = X[j,j] + I[i=j] X[i,j]
  #   = 2 X[j,j].
  # Since there is a 2 X[j,j] term for every lower-triangular element of X we
  # conclude:
  #   |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.
  diag = array_ops.matrix_diag_part(x)

  # We now ensure diag is at least a matrix. E.g., if `diag = [1, 2, 3]` then
  # the output is `[[1, 2, 3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then
  # the output is unchanged.
  diag = self._make_columnar(diag)

  if self.validate_args:
    is_matrix = check_ops.assert_rank_at_least(
        x, 2, message="Input must be a (batch of) matrix.")
    shape = array_ops.shape(x)
    is_square = check_ops.assert_equal(
        shape[-2], shape[-1],
        message="Input must be a (batch of) square matrix.")
    # Assuming lower-triangular means we only need check diag>0.
    is_positive_definite = check_ops.assert_positive(
        diag, message="Input must be positive definite.")
    x = control_flow_ops.with_dependencies(
        [is_matrix, is_square, is_positive_definite], x)

  # Create a vector equal to: [p, p-1, ..., 2, 1].
  if x.get_shape().ndims is None or x.get_shape().dims[-1].value is None:
    p_int = array_ops.shape(x)[-1]
    p_float = math_ops.cast(p_int, dtype=x.dtype)
  else:
    p_int = x.get_shape().dims[-1].value
    p_float = np.array(p_int, dtype=x.dtype.as_numpy_dtype)
  exponents = math_ops.linspace(p_float, 1., p_int)

  sum_weighted_log_diag = array_ops.squeeze(
      math_ops.matmul(math_ops.log(diag), exponents[..., array_ops.newaxis]),
      axis=-1)
  fldj = p_float * np.log(2.) + sum_weighted_log_diag

  # We finally need to undo adding an extra dimension in non-scalar cases
  # where there is a single matrix as input.
  if x.get_shape().ndims is not None:
    if x.get_shape().ndims == 2:
      fldj = array_ops.squeeze(fldj, axis=-1)
    return fldj

  shape = array_ops.shape(fldj)
  maybe_squeeze_shape = array_ops.concat([
      shape[:-1],
      distribution_util.pick_vector(
          math_ops.equal(array_ops.rank(x), 2),
          np.array([], dtype=np.int32), shape[-1:])], 0)
  return array_ops.reshape(fldj, maybe_squeeze_shape)
def _batch_shape_tensor(self):
  return distribution_util.pick_vector(
      self._is_batch_override,
      self._override_batch_shape,
      self.distribution.batch_shape_tensor())
def _event_shape_tensor(self):
  return self.bijector.forward_event_shape_tensor(
      distribution_util.pick_vector(
          self._is_event_override,
          self._override_event_shape,
          self.distribution.event_shape_tensor()))
def _forward_log_det_jacobian(self, x):
  # Let Y be a symmetric, positive definite matrix and write:
  #   Y = X X.T
  # where X is lower-triangular.
  #
  # Observe that,
  #   dY[i,j]/dX[a,b]
  #   = d/dX[a,b] { X[i,:] X[j,:] }
  #   = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] }
  #
  # To compute the Jacobian dY/dX we must represent X,Y as vectors. Since Y is
  # symmetric and X is lower-triangular, we need vectors of dimension:
  #   d = p (p + 1) / 2
  # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e.,
  #   k = { i (i + 1) / 2 + j   i>=j
  #       { undef               i<j
  # and assume zero-based indexes. When k is undef, the element is dropped.
  # Example:
  #           j  k
  #        0 1 2 3  /
  #    0 [ 0 . . . ]
  #  i 1 [ 1 2 . . ]
  #    2 [ 3 4 5 . ]
  #    3 [ 6 7 8 9 ]
  # Write vec[.] to indicate transforming a matrix to vector via k(i,j). (With
  # slight abuse: k(i,j)=undef means the element is dropped.)
  #
  # We now show d vec[Y] / d vec[X] is lower triangular. Assuming both are
  # defined, observe that k(i,j) < k(a,b) iff (1) i<a or (2) i=a and j<b.
  # In both cases dvec[Y]/dvec[X]@[k(i,j),k(a,b)] = 0 since:
  # (1) j<=i<a thus i,j!=a.
  # (2) i=a>j thus i,j!=a.
  #
  # Since the Jacobian is lower-triangular, we need only compute the product
  # of diagonal elements:
  #   d vec[Y] / d vec[X] @[k(i,j), k(i,j)]
  #   = X[j,j] + I[i=j] X[i,j]
  #   = 2 X[j,j].
  # Since there is a 2 X[j,j] term for every lower-triangular element of X we
  # conclude:
  #   |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.
  diag = tf.matrix_diag_part(x)

  # We now ensure diag is at least a matrix. E.g., if `diag = [1, 2, 3]` then
  # the output is `[[1, 2, 3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then
  # the output is unchanged.
  diag = self._make_columnar(diag)

  if self.validate_args:
    is_matrix = tf.assert_rank_at_least(
        x, 2, message="Input must be a (batch of) matrix.")
    shape = tf.shape(x)
    is_square = tf.assert_equal(
        shape[-2], shape[-1],
        message="Input must be a (batch of) square matrix.")
    # Assuming lower-triangular means we only need check diag>0.
    is_positive_definite = tf.assert_positive(
        diag, message="Input must be positive definite.")
    x = control_flow_ops.with_dependencies(
        [is_matrix, is_square, is_positive_definite], x)

  # Create a vector equal to: [p, p-1, ..., 2, 1].
  if x.get_shape().ndims is None or x.get_shape()[-1].value is None:
    p_int = tf.shape(x)[-1]
    p_float = tf.cast(p_int, dtype=x.dtype)
  else:
    p_int = x.get_shape()[-1].value
    p_float = np.array(p_int, dtype=x.dtype.as_numpy_dtype)
  exponents = tf.linspace(p_float, 1., p_int)

  sum_weighted_log_diag = tf.squeeze(
      tf.matmul(tf.log(diag), exponents[..., tf.newaxis]), axis=-1)
  fldj = p_float * np.log(2.) + sum_weighted_log_diag

  # We finally need to undo adding an extra dimension in non-scalar cases
  # where there is a single matrix as input.
  if x.get_shape().ndims is not None:
    if x.get_shape().ndims == 2:
      fldj = tf.squeeze(fldj, axis=-1)
    return fldj

  shape = tf.shape(fldj)
  maybe_squeeze_shape = tf.concat([
      shape[:-1],
      distribution_util.pick_vector(
          tf.equal(tf.rank(x), 2),
          np.array([], dtype=np.int32), shape[-1:])], 0)
  return tf.reshape(fldj, maybe_squeeze_shape)
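# A numeric sanity check of the closed form derived above, assuming a single
# 2x2 lower-triangular X: log|Jac| = p*log(2) + sum_j (p - j) * log(X[j, j]).
import numpy as np

X = np.array([[2., 0.],
              [1., 3.]])
p = X.shape[-1]
diag = np.diag(X)
fldj = p * np.log(2.) + np.sum(np.linspace(p, 1., p) * np.log(diag))
# Equivalently: log(2**p * prod_j diag[j]**(p - j)) = log(4 * 4 * 3).
assert np.isclose(fldj, np.log(2.**p * np.prod(diag ** np.arange(p, 0, -1))))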
def __init__(self,
             df,
             loc=None,
             scale_identity_multiplier=None,
             scale_diag=None,
             scale_tril=None,
             scale_perturb_factor=None,
             scale_perturb_diag=None,
             validate_args=False,
             allow_nan_stats=True,
             name="VectorStudentT"):
  """Instantiates the vector Student's t-distributions on `R^k`.

  The `batch_shape` is the broadcast between `df.batch_shape` and
  `Affine.batch_shape` where `Affine` is constructed from `loc` and
  `scale_*` arguments.

  The `event_shape` is given by `Affine.event_shape`.

  Args:
    df: Floating-point `Tensor`. The degrees of freedom of the
      distribution(s). `df` must contain only positive values. Must be
      scalar if `loc`, `scale_*` imply non-scalar batch_shape or must have
      the same `batch_shape` implied by `loc`, `scale_*`.
    loc: Floating-point `Tensor`. If this is set to `None`, no `loc` is
      applied.
    scale_identity_multiplier: floating point rank 0 `Tensor` representing
      a scaling done to the identity matrix. When
      `scale_identity_multiplier = scale_diag = scale_tril = None` then
      `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is
      added to `scale`.
    scale_diag: Floating-point `Tensor` representing the diagonal matrix.
      `scale_diag` has shape [N1, N2, ..., k], which represents a k x k
      diagonal matrix. When `None` no diagonal term is added to `scale`.
    scale_tril: Floating-point `Tensor` representing the lower triangular
      matrix. `scale_tril` has shape [N1, N2, ..., k, k], which represents
      a k x k lower triangular matrix. When `None` no `scale_tril` term is
      added to `scale`. The upper triangular elements above the diagonal
      are ignored.
    scale_perturb_factor: Floating-point `Tensor` representing factor
      matrix with last two dimensions of shape `(k, r)`. When `None`, no
      rank-r update is added to `scale`.
    scale_perturb_diag: Floating-point `Tensor` representing the diagonal
      matrix. `scale_perturb_diag` has shape [N1, N2, ..., r], which
      represents an r x r diagonal matrix. When `None` low rank updates
      will take the form `scale_perturb_factor * scale_perturb_factor.T`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading
      runtime performance. When `False` invalid inputs may silently render
      incorrect outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  parameters = locals()
  graph_parents = [df, loc, scale_identity_multiplier, scale_diag,
                   scale_tril, scale_perturb_factor, scale_perturb_diag]
  with ops.name_scope(name):
    with ops.name_scope("init", values=graph_parents):
      # The shape of the _VectorStudentT distribution is governed by the
      # relationship between df.batch_shape and affine.batch_shape. In
      # pseudocode the basic procedure is:
      #   if df.batch_shape is scalar:
      #     if affine.batch_shape is not scalar:
      #       # broadcast distribution.sample so
      #       # it has affine.batch_shape.
      #       self.batch_shape = affine.batch_shape
      #   else:
      #     if affine.batch_shape is scalar:
      #       # let affine broadcasting do its thing.
      #     self.batch_shape = df.batch_shape
      # All of the above magic is actually handled by
      # TransformedDistribution. Here we really only need to collect the
      # affine.batch_shape and decide what we're going to pass in to
      # TransformedDistribution's (override) batch_shape arg.
      affine = bijectors.Affine(
          shift=loc,
          scale_identity_multiplier=scale_identity_multiplier,
          scale_diag=scale_diag,
          scale_tril=scale_tril,
          scale_perturb_factor=scale_perturb_factor,
          scale_perturb_diag=scale_perturb_diag,
          validate_args=validate_args)
      distribution = student_t.StudentT(
          df=df,
          loc=array_ops.zeros([], dtype=affine.dtype),
          scale=array_ops.ones([], dtype=affine.dtype))
      batch_shape, override_event_shape = _infer_shapes(
          affine.scale, affine.shift)
      override_batch_shape = distribution_util.pick_vector(
          distribution.is_scalar_batch(),
          batch_shape,
          constant_op.constant([], dtype=dtypes.int32))
      super(_VectorStudentT, self).__init__(
          distribution=distribution,
          bijector=affine,
          batch_shape=override_batch_shape,
          event_shape=override_event_shape,
          validate_args=validate_args,
          name=name)
      self._parameters = parameters
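# Illustrative stand-in for the pick_vector call above: the batch-shape
# override is empty (no override) unless `df` has scalar batch shape, in
# which case the affine's batch shape wins.
import numpy as np

def override_batch_shape(df_is_scalar_batch, affine_batch_shape):
  return (np.asarray(affine_batch_shape, dtype=np.int32)
          if df_is_scalar_batch
          else np.array([], dtype=np.int32))

assert (override_batch_shape(True, [4]) == [4]).all()
assert override_batch_shape(False, [4]).size == 0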
def __init__(self,
             mixture_distribution,
             components_distribution,
             validate_args=False,
             allow_nan_stats=True,
             name="MixtureSameFamily"):
  """Construct a `MixtureSameFamily` distribution.

  Args:
    mixture_distribution: `tf.distributions.Categorical`-like instance.
      Manages the probability of selecting components. The number of
      categories must match the rightmost batch dimension of the
      `components_distribution`. Must have either scalar `batch_shape` or
      `batch_shape` matching `components_distribution.batch_shape[:-1]`.
    components_distribution: `tf.distributions.Distribution`-like instance.
      Right-most batch dimension indexes components.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading
      runtime performance. When `False` invalid inputs may silently render
      incorrect outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: `if not mixture_distribution.dtype.is_integer`.
    ValueError: if `mixture_distribution` does not have scalar `event_shape`.
    ValueError: if `mixture_distribution.batch_shape` and
      `components_distribution.batch_shape[:-1]` are both fully defined and
      the former is neither scalar nor equal to the latter.
    ValueError: if the number of `mixture_distribution` categories does not
      equal the `components_distribution` rightmost batch dimension.
  """
  parameters = distribution_util.parent_frame_arguments()
  with ops.name_scope(name) as name:
    self._mixture_distribution = mixture_distribution
    self._components_distribution = components_distribution
    self._runtime_assertions = []

    s = components_distribution.event_shape_tensor()
    self._event_ndims = (
        s.shape[0].value
        if s.shape.with_rank_at_least(1)[0].value is not None
        else array_ops.shape(s)[0])

    if not mixture_distribution.dtype.is_integer:
      raise ValueError(
          "`mixture_distribution.dtype` ({}) is not over integers".format(
              mixture_distribution.dtype.name))

    if (mixture_distribution.event_shape.ndims is not None
        and mixture_distribution.event_shape.ndims != 0):
      raise ValueError("`mixture_distribution` must have scalar `event_dim`s")
    elif validate_args:
      self._runtime_assertions += [
          control_flow_ops.assert_has_rank(
              mixture_distribution.event_shape_tensor(), 0,
              message="`mixture_distribution` must have scalar `event_dim`s"),
      ]

    mdbs = mixture_distribution.batch_shape
    cdbs = components_distribution.batch_shape.with_rank_at_least(1)[:-1]
    if mdbs.is_fully_defined() and cdbs.is_fully_defined():
      if mdbs.ndims != 0 and mdbs != cdbs:
        raise ValueError(
            "`mixture_distribution.batch_shape` (`{}`) is not "
            "compatible with `components_distribution.batch_shape` "
            "(`{}`)".format(mdbs.as_list(), cdbs.as_list()))
    elif validate_args:
      mdbs = mixture_distribution.batch_shape_tensor()
      cdbs = components_distribution.batch_shape_tensor()[:-1]
      self._runtime_assertions += [
          control_flow_ops.assert_equal(
              distribution_util.pick_vector(
                  mixture_distribution.is_scalar_batch(), cdbs, mdbs),
              cdbs,
              message=(
                  "`mixture_distribution.batch_shape` is not "
                  "compatible with `components_distribution.batch_shape`"))]

    km = mixture_distribution.logits.shape.with_rank_at_least(1)[-1].value
    kc = components_distribution.batch_shape.with_rank_at_least(1)[-1].value
    if km is not None and kc is not None and km != kc:
      raise ValueError(
          "`mixture_distribution components` ({}) does not "
          "equal `components_distribution.batch_shape[-1]` "
          "({})".format(km, kc))
    elif validate_args:
      km = array_ops.shape(mixture_distribution.logits)[-1]
      kc = components_distribution.batch_shape_tensor()[-1]
      self._runtime_assertions += [
          control_flow_ops.assert_equal(
              km, kc,
              message=("`mixture_distribution components` does not equal "
                       "`components_distribution.batch_shape[-1:]`")),
      ]
    elif km is None:
      km = array_ops.shape(mixture_distribution.logits)[-1]

    self._num_components = km

    super(MixtureSameFamily, self).__init__(
        dtype=self._components_distribution.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=(
            self._mixture_distribution._graph_parents  # pylint: disable=protected-access
            + self._components_distribution._graph_parents),  # pylint: disable=protected-access
        name=name)
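# A hedged usage sketch of `MixtureSameFamily` in the tf.contrib.distributions
# API of this era: a scalar-event mixture of two Gaussians.
import tensorflow as tf
tfd = tf.contrib.distributions

gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(
        loc=[-1., 1.],      # component means; rightmost batch dim
        scale=[0.1, 0.5]))  # indexes the 2 mixture components
samples = gm.sample(5)      # shape: [5]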
def __init__(self,
             mixture_distribution,
             components_distribution,
             validate_args=False,
             allow_nan_stats=True,
             name="MixtureSameFamily"):
  """Construct a `MixtureSameFamily` distribution.

  Args:
    mixture_distribution: `tf.distributions.Categorical`-like instance.
      Manages the probability of selecting components. The number of
      categories must match the rightmost batch dimension of the
      `components_distribution`. Must have either scalar `batch_shape` or
      `batch_shape` matching `components_distribution.batch_shape[:-1]`.
    components_distribution: `tf.distributions.Distribution`-like instance.
      Right-most batch dimension indexes components.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading
      runtime performance. When `False` invalid inputs may silently render
      incorrect outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: `if not mixture_distribution.dtype.is_integer`.
    ValueError: if `mixture_distribution` does not have scalar `event_shape`.
    ValueError: if `mixture_distribution.batch_shape` and
      `components_distribution.batch_shape[:-1]` are both fully defined and
      the former is neither scalar nor equal to the latter.
    ValueError: if the number of `mixture_distribution` categories does not
      equal the `components_distribution` rightmost batch dimension.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    self._mixture_distribution = mixture_distribution
    self._components_distribution = components_distribution
    self._runtime_assertions = []

    s = components_distribution.event_shape_tensor()
    self._event_ndims = (
        s.shape[0].value
        if s.shape.with_rank_at_least(1)[0].value is not None
        else tf.shape(s)[0])

    if not mixture_distribution.dtype.is_integer:
      raise ValueError(
          "`mixture_distribution.dtype` ({}) is not over integers".format(
              mixture_distribution.dtype.name))

    if (mixture_distribution.event_shape.ndims is not None
        and mixture_distribution.event_shape.ndims != 0):
      raise ValueError("`mixture_distribution` must have scalar `event_dim`s")
    elif validate_args:
      self._runtime_assertions += [
          control_flow_ops.assert_has_rank(
              mixture_distribution.event_shape_tensor(), 0,
              message="`mixture_distribution` must have scalar `event_dim`s"),
      ]

    mdbs = mixture_distribution.batch_shape
    cdbs = components_distribution.batch_shape.with_rank_at_least(1)[:-1]
    if mdbs.is_fully_defined() and cdbs.is_fully_defined():
      if mdbs.ndims != 0 and mdbs != cdbs:
        raise ValueError(
            "`mixture_distribution.batch_shape` (`{}`) is not "
            "compatible with `components_distribution.batch_shape` "
            "(`{}`)".format(mdbs.as_list(), cdbs.as_list()))
    elif validate_args:
      mdbs = mixture_distribution.batch_shape_tensor()
      cdbs = components_distribution.batch_shape_tensor()[:-1]
      self._runtime_assertions += [
          control_flow_ops.assert_equal(
              distribution_util.pick_vector(
                  mixture_distribution.is_scalar_batch(), cdbs, mdbs),
              cdbs,
              message=(
                  "`mixture_distribution.batch_shape` is not "
                  "compatible with `components_distribution.batch_shape`"))]

    km = mixture_distribution.logits.shape.with_rank_at_least(1)[-1].value
    kc = components_distribution.batch_shape.with_rank_at_least(1)[-1].value
    if km is not None and kc is not None and km != kc:
      raise ValueError(
          "`mixture_distribution components` ({}) does not "
          "equal `components_distribution.batch_shape[-1]` "
          "({})".format(km, kc))
    elif validate_args:
      km = tf.shape(mixture_distribution.logits)[-1]
      kc = components_distribution.batch_shape_tensor()[-1]
      self._runtime_assertions += [
          control_flow_ops.assert_equal(
              km, kc,
              message=("`mixture_distribution components` does not equal "
                       "`components_distribution.batch_shape[-1:]`")),
      ]
    elif km is None:
      km = tf.shape(mixture_distribution.logits)[-1]

    self._num_components = km

    super(MixtureSameFamily, self).__init__(
        dtype=self._components_distribution.dtype,
        reparameterization_type=tf.distributions.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=(
            self._mixture_distribution._graph_parents  # pylint: disable=protected-access
            + self._components_distribution._graph_parents),  # pylint: disable=protected-access
        name=name)