def __init__(self,
             mu,
             diag_large,
             v,
             diag_small=None,
             validate_args=False,
             allow_nan_stats=True,
             name="MultivariateNormalDiagPlusVDVT"):
    """Multivariate Normal distributions on `R^k`.

    For every batch member, this distribution represents `k` random variables
    `(X_1,...,X_k)`, with mean `E[X_i] = mu[i]`, and covariance matrix
    `C_{ij} := E[(X_i - mu[i])(X_j - mu[j])]`

    The user initializes this class by providing the mean `mu`, and a
    lightweight definition of `C`:

    ```
    C = SS^T = SS = (M + V D V^T) (M + V D V^T)
    M is diagonal (k x k)
    V is shape (k x r), typically r << k
    D is diagonal (r x r), optional (defaults to identity).
    ```

    Args:
      mu:  Rank `n + 1` floating point tensor with shape `[N1,...,Nn, k]`,
        `n >= 0`.  The means.
      diag_large:  Rank `n + 1` floating point tensor, shape
        `[N1,...,Nn, k]` `n >= 0`.  Defines the diagonal matrix `M`.
        Required (it has no default value).
      v:  Rank `n + 2` floating point tensor, shape `[N1,...,Nn, k, r]`
        `n >= 0`.  Defines the matrix `V`.
      diag_small:  Rank `n + 1` floating point tensor, shape
        `[N1,...,Nn, r]` `n >= 0`.  Defines the diagonal matrix `D`.  Default
        is `None`, which means `D` will be the identity matrix.
      validate_args: `Boolean`, default `False`.  Whether to validate input
        with asserts.  If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.  The flag is also
        forwarded as `verify_pd` / `verify_shapes` to the operators built
        here.
      allow_nan_stats:  `Boolean`, default `True`.  If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member.  If `True`, batch members with valid parameters leading
        to undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.
    """
    # Snapshot the constructor arguments first, before any other local is
    # bound, so that only the arguments (minus `self`) are recorded.
    parameters = locals()
    parameters.pop("self")
    with ops.name_scope(name, values=[diag_large, v, diag_small]) as ns:
        # `M` is wrapped as an OperatorPDDiag and then updated with `V D V^T`.
        cov = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
            operator_pd_diag.OperatorPDDiag(diag_large, verify_pd=validate_args),
            v,
            diag=diag_small,
            verify_pd=validate_args,
            verify_shapes=validate_args)
    # The resolved scope name `ns` is passed on as the distribution's name.
    super(MultivariateNormalDiagPlusVDVT, self).__init__(
        mu,
        cov,
        allow_nan_stats=allow_nan_stats,
        validate_args=validate_args,
        name=ns)
    self._parameters = parameters
def test_to_dense_placeholder(self):
    # Smoke test: the operator must be usable when every input is fed
    # through a placeholder.
    matrix_shape = [3, 3]
    update_rank = 2
    with self.test_session():
        # Base operator: an OperatorPDFull built on a matrix placeholder.
        mat_ph = tf.placeholder(tf.float64, name='mat_ph')
        mat_value = self._random_pd_matrix(matrix_shape)
        base_operator = operator_pd_full.OperatorPDFull(mat_ph)

        # Placeholder and concrete arrays for the low-rank update.
        v_ph = tf.placeholder(tf.float64, name='v_ph')
        v_value, diag_value = self._random_v_and_diag(
            matrix_shape, update_rank)

        # `diag` is only fed when this test configuration uses one.
        feeds = {v_ph: v_value, mat_ph: mat_value}
        diag_ph = None
        if not self._diag_is_none:
            diag_ph = tf.placeholder(tf.float64, name='diag_ph')
            feeds[diag_ph] = diag_value

        # The updated operator, built from the v and diag placeholders.
        updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
            base_operator, v_ph, diag=diag_ph)

        # Neither evaluation should fail.
        updated.to_dense().eval(feed_dict=feeds)
        updated.log_det().eval(feed_dict=feeds)
def test_operator_not_subclass_of_operator_pd_raises(self):
    # `operator` must be an `OperatorPDBase`; anything else is expected to
    # be rejected with a TypeError at construction.
    with self.test_session():
        update_v, update_diag = self._random_v_and_diag((3, 3), 2)
        not_an_operator = 'I am not a subclass of OperatorPDBase'
        with self.assertRaisesRegexp(TypeError, 'not instance'):
            operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                not_an_operator, update_v, update_diag)
def testOperatorNotSubclassOfOperatorPdRaises(self):
    # `operator` must be an `OperatorPDBase`; anything else is expected to
    # be rejected with a TypeError at construction.
    with self.test_session():
        update_v, update_diag = self._random_v_and_diag((3, 3), 2)
        not_an_operator = "I am not a subclass of OperatorPDBase"
        with self.assertRaisesRegexp(TypeError, "not instance"):
            operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                not_an_operator, update_v, update_diag)
def test_tensor_rank_shape_mismatch_v_and_diag_raises_static(self):
    # `diag` should have rank one less than `v`; here it has rank 2 against
    # a rank-4 `v`, which is expected to be caught at construction.
    update_v = self._rng.rand(1, 2, 2, 2)
    bad_diag = self._rng.rand(5, 1)
    with self.test_session():
        # The base matrix has the same shape as `v`, so only `diag` is wrong.
        base_mat = self._random_pd_matrix((1, 2, 2, 2))
        base_operator = operator_pd_full.OperatorPDFull(base_mat)
        with self.assertRaisesRegexp(ValueError, 'diag.*rank'):
            operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                base_operator, update_v, bad_diag)
def test_batch_shape_mismatch_v_and_diag_raises_static(self):
    # To match `v` of shape (4, 3, 2), `diag` should have shape (4, 2);
    # the (5, 1) array used here is expected to be rejected at construction.
    update_v = self._rng.rand(4, 3, 2)
    bad_diag = self._rng.rand(5, 1)
    with self.test_session():
        # The base matrix is compatible with `v`, so only `diag` is wrong.
        base_mat = self._random_pd_matrix((4, 3, 3))
        base_operator = operator_pd_full.OperatorPDFull(base_mat)
        with self.assertRaisesRegexp(ValueError, 'diag.*batch shape'):
            operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                base_operator, update_v, bad_diag)
def testTensorRankShapeMismatchVAndDiagRaisesStatic(self):
    # `diag` should have rank one less than `v`; here it has rank 2 against
    # a rank-4 `v`, which is expected to be caught at construction.
    update_v = self._rng.rand(1, 2, 2, 2)
    bad_diag = self._rng.rand(5, 1)
    with self.test_session():
        # The base matrix has the same shape as `v`, so only `diag` is wrong.
        base_mat = self._random_pd_matrix((1, 2, 2, 2))
        base_operator = operator_pd_full.OperatorPDFull(base_mat)
        with self.assertRaisesRegexp(ValueError, "diag.*rank"):
            operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                base_operator, update_v, bad_diag)
def testEventShapeMismatchVAndDiagRaisesStatic(self):
    # To match `v` of shape (4, 3, 2), `diag` should have shape (4, 2);
    # here only its last dimension (1 instead of 2) disagrees.
    update_v = self._rng.rand(4, 3, 2)
    bad_diag = self._rng.rand(4, 1)
    with self.test_session():
        # The base matrix is compatible with `v`, so only `diag` is wrong.
        base_mat = self._random_pd_matrix((4, 3, 3))
        base_operator = operator_pd_full.OperatorPDFull(base_mat)
        with self.assertRaisesRegexp(ValueError, "diag.*v.*last dimension"):
            operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                base_operator, update_v, bad_diag)
def test_non_pos_def_diag_doesnt_raise_if_verify_pd_false(self):
    # With verify_pd=False a diag containing a zero entry must be accepted.
    # Nothing to check when this test configuration has no diag.
    if self._diag_is_none:
        return
    with self.test_session():
        shape = (3, 3)
        update_rank = 2
        update_v, update_diag = self._random_v_and_diag(shape, update_rank)
        base_mat = self._random_pd_matrix(shape)
        # Zero out the first entry so diag is no longer strictly positive.
        update_diag[0] = 0.0
        base_operator = operator_pd_full.OperatorPDFull(base_mat)
        updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
            base_operator, update_v, update_diag, verify_pd=False)
        # Should not raise.
        updated.to_dense().eval()
def test_non_pos_def_diag_raises(self):
    # Nothing to check when this test configuration has no diag.
    if self._diag_is_none:
        return
    # The diag is required to be positive definite; a zero entry must make
    # evaluation fail with an op error mentioning 'positive'.
    with self.test_session():
        shape = (3, 3)
        update_rank = 2
        update_v, update_diag = self._random_v_and_diag(shape, update_rank)
        base_mat = self._random_pd_matrix(shape)
        # Zero out the first entry so diag is no longer strictly positive.
        update_diag[0] = 0.0
        base_operator = operator_pd_full.OperatorPDFull(base_mat)
        updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
            base_operator, update_v, update_diag)
        with self.assertRaisesOpError('positive'):
            updated.to_dense().eval()
def test_batch_shape_mismatch_v_and_diag_raises_dynamic(self):
    # Same mismatch as the static variant, but the arrays arrive through
    # placeholders, so the failure is expected only when the graph is run.
    with self.test_session():
        v_value = self._rng.rand(4, 3, 2)
        # Should be shape (4, 2,) to match v.
        diag_value = self._rng.rand(5, 1)
        # The base matrix is compatible with `v`.
        mat_value = self._random_pd_matrix((4, 3, 3))

        v_ph = tf.placeholder(tf.float32, name='v_ph')
        diag_ph = tf.placeholder(tf.float32, name='diag_ph')
        mat_ph = tf.placeholder(tf.float32, name='mat_ph')

        base_operator = operator_pd_full.OperatorPDFull(mat_ph)
        updated_operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
            base_operator, v_ph, diag_ph)

        feeds = {v_ph: v_value, diag_ph: diag_value, mat_ph: mat_value}
        with self.assertRaisesOpError('x == y'):
            updated_operator.to_dense().eval(feed_dict=feeds)
def testBatchShapeMismatchVAndDiagRaisesDynamic(self):
    # Same mismatch as the static variant, but the arrays arrive through
    # placeholders, so the failure is expected only when the graph is run.
    with self.test_session():
        v_value = self._rng.rand(4, 3, 2)
        # Should be shape (4, 2,) to match v.
        diag_value = self._rng.rand(5, 1)
        # The base matrix is compatible with `v`.
        mat_value = self._random_pd_matrix((4, 3, 3))

        v_ph = tf.placeholder(tf.float32, name="v_ph")
        diag_ph = tf.placeholder(tf.float32, name="diag_ph")
        mat_ph = tf.placeholder(tf.float32, name="mat_ph")

        base_operator = operator_pd_full.OperatorPDFull(mat_ph)
        updated_operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
            base_operator, v_ph, diag_ph)

        feeds = {v_ph: v_value, diag_ph: diag_value, mat_ph: mat_value}
        with self.assertRaisesOpError("x == y"):
            updated_operator.to_dense().eval(feed_dict=feeds)
def test_tensor_rank_shape_mismatch_v_and_diag_raises_dynamic(self):
    # `diag` should have rank one less than `v`.  The arrays arrive through
    # placeholders, so the failure is expected only when the graph is run.
    with self.test_session():
        v_value = self._rng.rand(2, 2, 2, 2)
        # Rank 2, but rank 3 is needed for a rank-4 `v`.
        diag_value = self._rng.rand(2, 2)
        # The base matrix has the same shape as `v`.
        mat_value = self._random_pd_matrix((2, 2, 2, 2))

        v_ph = tf.placeholder(tf.float32, name="v_ph")
        diag_ph = tf.placeholder(tf.float32, name="diag_ph")
        mat_ph = tf.placeholder(tf.float32, name="mat_ph")

        base_operator = operator_pd_full.OperatorPDFull(mat_ph)
        updated_operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
            base_operator, v_ph, diag_ph)

        feeds = {v_ph: v_value, diag_ph: diag_value, mat_ph: mat_value}
        with self.assertRaisesOpError("rank"):
            updated_operator.to_dense().eval(feed_dict=feeds)
def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
    """Build the operator under test together with its reference matrix.

    Called by the base class, which is what enables many standard tests.

    A random matrix is updated explicitly with `v` and `diag`, and an
    `OperatorPDSqrtVDVTUpdate` is built from the same pieces; the two are
    expected to behave identically.

    Args:
      batch_shape: Iterable of leading batch dimensions.
      k: Size of the trailing square `[k, k]` matrix dimensions.
      dtype: Numpy dtype that the matrix, `v` and `diag` are cast to.

    Returns:
      Tuple `(operator, updated_mat)`: the `OperatorPDSqrtVDVTUpdate` and the
      explicitly updated matrix produced by `self._updated_mat`.
    """
    # V gets half as many columns as k, except that k == 1 keeps a single
    # column.
    update_rank = k if k == 1 else k // 2

    full_shape = list(batch_shape) + [k, k]
    # Draw the base matrix first, then v and diag, and cast everything to
    # the requested dtype.
    base_mat = self._random_pd_matrix(full_shape).astype(dtype)
    update_v, update_diag = self._random_v_and_diag(full_shape, update_rank)
    update_v = update_v.astype(dtype)
    if update_diag is not None:
        update_diag = update_diag.astype(dtype)

    # Reference matrix: (mat + v*diag*v^T) * (mat + v*diag*v^T)^T.
    # The final updated operator should behave like this.
    expected_mat = self._updated_mat(base_mat, update_v, update_diag)

    # Operator for the matrix before updating; this is what gets updated.
    base_operator = operator_pd_full.OperatorPDFull(base_mat)

    # The operator being tested, obtained by updating `base_operator`.
    updated_operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
        base_operator, update_v, update_diag)

    return updated_operator, expected_mat
def _create_scale_operator(self, identity_multiplier, diag, tril,
                           perturb_diag, perturb_factor, event_ndims,
                           validate_args):
    """Construct `scale` from various components.

    The base matrix is taken from the first of `tril`, `diag` and
    `identity_multiplier` that is not `None`, in that order.  When
    `perturb_factor` is given, the base is combined with the low rank update
    defined by `perturb_factor` and `perturb_diag`.

    Args:
      identity_multiplier: floating point rank 0 `Tensor` representing a
        scaling done to the identity matrix.
      diag: Floating-point `Tensor` representing the diagonal matrix.
        `scale_diag` has shape [N1, N2, ...  k], which represents a k x k
        diagonal matrix.
      tril: Floating-point `Tensor` representing a k x k lower triangular
        matrix (presumably with shape [N1, N2, ...  k, k]; confirm against
        callers).
      perturb_diag: Floating-point `Tensor` representing the diagonal matrix
        of the low rank update.
      perturb_factor: Floating-point `Tensor` representing factor matrix.
      event_ndims: Scalar `int32` `Tensor` indicating the number of
        dimensions associated with a particular draw from the distribution.
        Must be 0 or 1
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.

    Returns:
      scale. In the case of scaling by a constant, scale is a
      floating point `Tensor`. Otherwise, scale is an `OperatorPD`.

    Raises:
      ValueError: if all of `tril`, `diag` and `identity_multiplier` are
        `None`.
    """
    identity_multiplier = _as_tensor(identity_multiplier,
                                     "identity_multiplier")
    diag = _as_tensor(diag, "diag")
    tril = _as_tensor(tril, "tril")
    perturb_diag = _as_tensor(perturb_diag, "perturb_diag")
    perturb_factor = _as_tensor(perturb_factor, "perturb_factor")

    identity_multiplier = self._maybe_validate_identity_multiplier(
        identity_multiplier, validate_args)

    if perturb_factor is not None:
        perturb_factor = self._process_matrix(
            perturb_factor, min_rank=2, event_ndims=event_ndims)
    if perturb_diag is not None:
        perturb_diag = self._process_matrix(
            perturb_diag, min_rank=1, event_ndims=event_ndims)

    # Evaluated after processing, exactly where the branches below need it.
    has_low_rank_update = perturb_factor is not None

    # The following if-statements are ordered by increasingly stronger
    # assumptions in the base matrix, i.e., we process in the order:
    # TriL, Diag, Identity.

    if tril is not None:
        tril = self._preprocess_tril(
            identity_multiplier, diag, tril, event_ndims)
        if has_low_rank_update:
            return _TriLPlusVDVTLightweightOperatorPD(
                tril=tril,
                v=perturb_factor,
                diag=perturb_diag,
                validate_args=validate_args)
        return operator_pd_cholesky.OperatorPDCholesky(
            tril, verify_pd=validate_args)

    if diag is not None:
        diag = self._preprocess_diag(identity_multiplier, diag, event_ndims)
        if not has_low_rank_update:
            return operator_pd_diag.OperatorPDSqrtDiag(
                diag, verify_pd=validate_args)
        diag_operator = operator_pd_diag.OperatorPDDiag(
            diag, verify_pd=validate_args)
        return operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
            operator=diag_operator,
            v=perturb_factor,
            diag=perturb_diag,
            verify_pd=validate_args)

    if identity_multiplier is None:
        raise ValueError(
            "One of tril, diag and/or identity_multiplier must be "
            "specified.")

    # Scaling by a constant: without a low rank update the multiplier itself
    # is the scale.
    if not has_low_rank_update:
        return identity_multiplier

    # Infer the shape from the V and D.
    factor_shape = array_ops.shape(perturb_factor)
    identity_shape = array_ops.concat(
        [factor_shape[:-1], [factor_shape[-2]]], 0)
    scaled_identity = operator_pd_identity.OperatorPDIdentity(
        identity_shape,
        perturb_factor.dtype.base_dtype,
        scale=identity_multiplier,
        verify_pd=validate_args)
    return operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
        operator=scaled_identity,
        v=perturb_factor,
        diag=perturb_diag,
        verify_pd=validate_args)