コード例 #1
0
ファイル: mvn.py プロジェクト: sigmasharp/w266
    def __init__(self,
                 mu,
                 diag_large,
                 v,
                 diag_small=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="MultivariateNormalDiagPlusVDVT"):
        """Multivariate Normal distributions on `R^k`.

        For every batch member, this distribution represents `k` random
        variables `(X_1,...,X_k)`, with mean `E[X_i] = mu[i]`, and covariance
        matrix `C_{ij} := E[(X_i - mu[i])(X_j - mu[j])]`.

        The user initializes this class by providing the mean `mu`, and a
        lightweight definition of `C`:

        ```
        C = SS^T = SS = (M + V D V^T) (M + V D V^T)
        M is diagonal (k x k)
        V is shape (k x r), typically r << k
        D is diagonal (r x r), optional (defaults to identity).
        ```

        `S = M + V D V^T` is symmetric, which is why `SS^T = SS`.

        Args:
          mu:  Rank `n + 1` floating point tensor with shape `[N1,...,Nn, k]`,
            `n >= 0`.  The means.
          diag_large:  Rank `n + 1` floating point tensor, shape
            `[N1,...,Nn, k]` `n >= 0`.  Defines the diagonal matrix `M`.
            Required (this argument has no default).
          v:  Rank `n + 2` floating point tensor, shape `[N1,...,Nn, k, r]`
            `n >= 0`.  Defines the matrix `V`.
          diag_small:  Rank `n + 1` floating point tensor, shape
            `[N1,...,Nn, r]` `n >= 0`.  Defines the diagonal matrix `D`
            (which is `r x r`).  Default is `None`, which means `D` will be
            the identity matrix.
          validate_args: `Boolean`, default `False`.  Whether to validate input
            with asserts.  If `validate_args` is `False`,
            and the inputs are invalid, correct behavior is not guaranteed.
          allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member.  If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this statistic.
          name: The name to give Ops created by the initializer.
        """
        # Must stay the first statement: `locals()` here holds exactly the
        # constructor arguments (plus `self`, removed below).
        parameters = locals()
        parameters.pop("self")
        # Include `mu` so the name scope (and graph) is chosen consistently
        # with every tensor this distribution is built from.
        with ops.name_scope(name, values=[mu, diag_large, v, diag_small]) as ns:
            # Square root of the covariance: S = M + V D V^T, with M built
            # from `diag_large` and D from `diag_small` (identity if None).
            cov = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                operator_pd_diag.OperatorPDDiag(diag_large,
                                                verify_pd=validate_args),
                v,
                diag=diag_small,
                verify_pd=validate_args,
                verify_shapes=validate_args)
        super(MultivariateNormalDiagPlusVDVT,
              self).__init__(mu,
                             cov,
                             allow_nan_stats=allow_nan_stats,
                             validate_args=validate_args,
                             name=ns)
        self._parameters = parameters
コード例 #2
0
    def test_to_dense_placeholder(self):
        # Every input is fed through a placeholder (no static shapes), and the
        # updated operator must still evaluate its dense form and log-det.
        mat_shape = [3, 3]
        v_matrix_rank = 2
        with self.test_session():
            # Base operator whose matrix is only known at run time.
            mat_ph = tf.placeholder(tf.float64, name='mat_ph')
            mat = self._random_pd_matrix(mat_shape)
            base_operator = operator_pd_full.OperatorPDFull(mat_ph)

            # Placeholders/values for the low-rank update; diag is optional.
            v_ph = tf.placeholder(tf.float64, name='v_ph')
            v, diag = self._random_v_and_diag(mat_shape, v_matrix_rank)
            feed_dict = {v_ph: v, mat_ph: mat}
            diag_ph = None
            if not self._diag_is_none:
                diag_ph = tf.placeholder(tf.float64, name='diag_ph')
                feed_dict[diag_ph] = diag

            operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                base_operator, v_ph, diag=diag_ph)

            # Success criterion: neither evaluation raises.
            for tensor in (operator.to_dense(), operator.log_det()):
                tensor.eval(feed_dict=feed_dict)
コード例 #3
0
    def test_operator_not_subclass_of_operator_pd_raises(self):
        # The base `operator` must be an `OperatorPDBase`; anything else (a
        # plain string here) has to be rejected with a TypeError.
        with self.test_session():
            v, diag = self._random_v_and_diag((3, 3), 2)
            not_an_operator = 'I am not a subclass of OperatorPDBase'
            with self.assertRaisesRegexp(TypeError, 'not instance'):
                operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                    not_an_operator, v, diag)
コード例 #4
0
    def testOperatorNotSubclassOfOperatorPdRaises(self):
        # Passing something that is not an `OperatorPDBase` as the base
        # operator must fail construction with a TypeError.
        with self.test_session():
            v, diag = self._random_v_and_diag((3, 3), 2)
            bogus_base = "I am not a subclass of OperatorPDBase"
            expected_message = "not instance"
            with self.assertRaisesRegexp(TypeError, expected_message):
                operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                    bogus_base, v, diag)
コード例 #5
0
    def test_tensor_rank_shape_mismatch_v_and_diag_raises_static(self):
        # diag must have rank one less than v. Here v has rank 4 and diag has
        # rank 2, a mismatch visible from the static shapes alone.
        v = self._rng.rand(1, 2, 2, 2)
        diag = self._rng.rand(5, 1)
        with self.test_session():
            # The base matrix agrees with v, so only diag is at fault.
            base_operator = operator_pd_full.OperatorPDFull(
                self._random_pd_matrix((1, 2, 2, 2)))
            with self.assertRaisesRegexp(ValueError, 'diag.*rank'):
                operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                    base_operator, v, diag)
コード例 #6
0
    def test_batch_shape_mismatch_v_and_diag_raises_static(self):
        # A v of shape (4, 3, 2) calls for a diag of shape (4, 2). A (5, 1)
        # diag has the wrong batch shape, which is known statically.
        v = self._rng.rand(4, 3, 2)
        diag = self._rng.rand(5, 1)
        with self.test_session():
            # The base matrix agrees with v, so only diag is at fault.
            base_operator = operator_pd_full.OperatorPDFull(
                self._random_pd_matrix((4, 3, 3)))
            with self.assertRaisesRegexp(ValueError, 'diag.*batch shape'):
                operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                    base_operator, v, diag)
コード例 #7
0
    def testTensorRankShapeMismatchVAndDiagRaisesStatic(self):
        # The required rank relation (rank(diag) == rank(v) - 1) is violated:
        # v is rank 4 and diag is rank 2.
        v = self._rng.rand(1, 2, 2, 2)
        diag = self._rng.rand(5, 1)
        with self.test_session():
            pd_matrix = self._random_pd_matrix((1, 2, 2, 2))  # Agrees with v.
            base_operator = operator_pd_full.OperatorPDFull(pd_matrix)
            expected_message = "diag.*rank"
            with self.assertRaisesRegexp(ValueError, expected_message):
                operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                    base_operator, v, diag)
コード例 #8
0
    def testEventShapeMismatchVAndDiagRaisesStatic(self):
        # v of shape (4, 3, 2) calls for a diag of shape (4, 2). A (4, 1) diag
        # has the right batch shape but the wrong last dimension.
        v = self._rng.rand(4, 3, 2)
        diag = self._rng.rand(4, 1)
        with self.test_session():
            pd_matrix = self._random_pd_matrix((4, 3, 3))  # Agrees with v.
            base_operator = operator_pd_full.OperatorPDFull(pd_matrix)
            expected_message = "diag.*v.*last dimension"
            with self.assertRaisesRegexp(ValueError, expected_message):
                operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                    base_operator, v, diag)
コード例 #9
0
    def test_non_pos_def_diag_doesnt_raise_if_verify_pd_false(self):
        # A diag with a zero entry is not positive definite. With
        # verify_pd=False that check must be skipped, so evaluating the dense
        # operator succeeds. There is nothing to test when the variant under
        # test has no diag.
        if self._diag_is_none:
            return
        with self.test_session():
            matrix_shape = (3, 3)
            v_rank = 2
            v, diag = self._random_v_and_diag(matrix_shape, v_rank)
            mat = self._random_pd_matrix(matrix_shape)
            diag[0] = 0.0  # Make D non-positive-definite on purpose.

            operator_m = operator_pd_full.OperatorPDFull(mat)
            operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                operator_m, v, diag, verify_pd=False)

            operator.to_dense().eval()  # Should not raise.
コード例 #10
0
    def test_non_pos_def_diag_raises(self):
        # verify_pd is not passed here, and a diag with a zero entry is
        # expected to be rejected as not positive definite when evaluated.
        if self._diag_is_none:
            return
        with self.test_session():
            shape = (3, 3)
            v, diag = self._random_v_and_diag(shape, 2)
            mat = self._random_pd_matrix(shape)
            diag[0] = 0.0  # Break positive definiteness of D.

            updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                operator_pd_full.OperatorPDFull(mat), v, diag)
            with self.assertRaisesOpError('positive'):
                updated.to_dense().eval()
コード例 #11
0
    def test_batch_shape_mismatch_v_and_diag_raises_dynamic(self):
        # Batch-shape mismatch between v and diag, hidden behind placeholders
        # so it can only be detected by a run-time assertion.
        with self.test_session():
            v = self._rng.rand(4, 3, 2)
            diag = self._rng.rand(5, 1)  # v calls for shape (4, 2).
            mat = self._random_pd_matrix((4, 3, 3))  # Consistent with v.

            v_ph = tf.placeholder(tf.float32, name='v_ph')
            diag_ph = tf.placeholder(tf.float32, name='diag_ph')
            mat_ph = tf.placeholder(tf.float32, name='mat_ph')
            feed = {v_ph: v, diag_ph: diag, mat_ph: mat}

            updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                operator_pd_full.OperatorPDFull(mat_ph), v_ph, diag_ph)
            with self.assertRaisesOpError('x == y'):
                updated.to_dense().eval(feed_dict=feed)
コード例 #12
0
    def testBatchShapeMismatchVAndDiagRaisesDynamic(self):
        # Placeholders hide all shapes, so the v/diag batch mismatch must be
        # caught by a run-time equality assertion.
        with self.test_session():
            v_value = self._rng.rand(4, 3, 2)
            diag_value = self._rng.rand(5, 1)  # Should be (4, 2) to match v.
            mat_value = self._random_pd_matrix((4, 3, 3))  # Matches v.

            v_ph = tf.placeholder(tf.float32, name="v_ph")
            diag_ph = tf.placeholder(tf.float32, name="diag_ph")
            mat_ph = tf.placeholder(tf.float32, name="mat_ph")

            base_operator = operator_pd_full.OperatorPDFull(mat_ph)
            updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                base_operator, v_ph, diag_ph)
            dense = updated.to_dense()
            feeds = {v_ph: v_value, diag_ph: diag_value, mat_ph: mat_value}
            with self.assertRaisesOpError("x == y"):
                dense.eval(feed_dict=feeds)
コード例 #13
0
    def test_tensor_rank_shape_mismatch_v_and_diag_raises_dynamic(self):
        # diag should be exactly one rank below v. Here (rank 2 vs rank 4) the
        # violation is fed through placeholders and only surfaces at run time.
        with self.test_session():
            v_value = self._rng.rand(2, 2, 2, 2)
            diag_value = self._rng.rand(2, 2)
            mat_value = self._random_pd_matrix((2, 2, 2, 2))  # Matches v.

            v_ph = tf.placeholder(tf.float32, name="v_ph")
            diag_ph = tf.placeholder(tf.float32, name="diag_ph")
            mat_ph = tf.placeholder(tf.float32, name="mat_ph")
            feeds = {v_ph: v_value, diag_ph: diag_value, mat_ph: mat_value}

            updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                operator_pd_full.OperatorPDFull(mat_ph), v_ph, diag_ph)
            with self.assertRaisesOpError("rank"):
                updated.to_dense().eval(feed_dict=feeds)
コード例 #14
0
    def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
        """Return an updated operator together with its dense equivalent.

        Called by the base test class to drive the standard operator tests.
        A random matrix `mat` (shape `batch_shape + [k, k]`) is updated in two
        ways: explicitly through `self._updated_mat`, and lazily by wrapping
        an `OperatorPDFull(mat)` in `OperatorPDSqrtVDVTUpdate`. Both should
        represent (mat + v*diag*v^T) * (mat + v*diag*v^T)^T.

        Args:
          batch_shape: Leading batch dimensions of the matrix.
          k: Size of the trailing k x k matrix.
          dtype: numpy dtype that `mat`, `v` and (if present) `diag` are cast
            to.

        Returns:
          `(operator, updated_mat)`.
        """
        # V gets rank k // 2, except when k == 1, where it keeps rank 1.
        v_matrix_rank = k if k == 1 else k // 2
        mat_shape = list(batch_shape) + [k, k]
        mat = self._random_pd_matrix(mat_shape).astype(dtype)
        v, diag = self._random_v_and_diag(mat_shape, v_matrix_rank)
        v = v.astype(dtype)
        diag = None if diag is None else diag.astype(dtype)

        # Expected dense result, computed directly.
        updated_mat = self._updated_mat(mat, v, diag)

        # The operator under test: the same update applied to an operator
        # that represents the un-updated `mat`.
        operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
            operator_pd_full.OperatorPDFull(mat), v, diag)
        return operator, updated_mat
コード例 #15
0
    def _create_scale_operator(self, identity_multiplier, diag, tril,
                               perturb_diag, perturb_factor, event_ndims,
                               validate_args):
        """Construct `scale` from various components.

        The base operator comes from the first of `tril`, `diag`,
        `identity_multiplier` that is not `None`, checked in that order. For
        the `tril` and `diag` bases, the lower-priority components are also
        passed to `self._preprocess_tril` / `self._preprocess_diag`; how they
        are combined is decided by those helpers. If `perturb_factor` is
        given, the base operator is wrapped in a low-rank `V D V^T` update.

        Args:
          identity_multiplier: floating point rank 0 `Tensor` representing a
            scaling done to the identity matrix.
          diag: Floating-point `Tensor` representing the diagonal matrix.
            `diag` has shape [N1, N2, ..., k], which represents a k x k
            diagonal matrix.
          tril: Floating-point `Tensor` representing the lower triangular
            matrix. `tril` has shape [N1, N2, ..., k, k], which represents a
            k x k lower triangular matrix.
          perturb_diag: Floating-point `Tensor` representing the diagonal
            matrix `D` of the low rank update.
          perturb_factor: Floating-point `Tensor` representing the factor
            matrix `V` of the low rank update. When the base is a scaled
            identity, its second-to-last dimension is used as `k`.
          event_ndims: Scalar `int32` `Tensor` indicating the number of
            dimensions associated with a particular draw from the
            distribution. Must be 0 or 1.
          validate_args: Python `bool` indicating whether arguments should be
            checked for correctness.

        Returns:
          scale. In the case of scaling by a constant, scale is a
          floating point `Tensor`. Otherwise, scale is an `OperatorPD`.

        Raises:
          ValueError: if all of `tril`, `diag` and `identity_multiplier` are
            `None`.
        """
        identity_multiplier = _as_tensor(identity_multiplier,
                                         "identity_multiplier")
        diag = _as_tensor(diag, "diag")
        tril = _as_tensor(tril, "tril")
        perturb_diag = _as_tensor(perturb_diag, "perturb_diag")
        perturb_factor = _as_tensor(perturb_factor, "perturb_factor")

        identity_multiplier = self._maybe_validate_identity_multiplier(
            identity_multiplier, validate_args)

        if perturb_factor is not None:
            perturb_factor = self._process_matrix(perturb_factor,
                                                  min_rank=2,
                                                  event_ndims=event_ndims)

        if perturb_diag is not None:
            perturb_diag = self._process_matrix(perturb_diag,
                                                min_rank=1,
                                                event_ndims=event_ndims)

        # The following if-statements are ordered by increasingly stronger
        # assumptions in the base matrix, i.e., we process in the order:
        # TriL, Diag, Identity.

        if tril is not None:
            tril = self._preprocess_tril(identity_multiplier, diag, tril,
                                         event_ndims)
            if perturb_factor is None:
                return operator_pd_cholesky.OperatorPDCholesky(
                    tril, verify_pd=validate_args)
            return _TriLPlusVDVTLightweightOperatorPD(
                tril=tril,
                v=perturb_factor,
                diag=perturb_diag,
                validate_args=validate_args)

        if diag is not None:
            diag = self._preprocess_diag(identity_multiplier, diag,
                                         event_ndims)
            if perturb_factor is None:
                return operator_pd_diag.OperatorPDSqrtDiag(
                    diag, verify_pd=validate_args)
            return operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                operator=operator_pd_diag.OperatorPDDiag(
                    diag, verify_pd=validate_args),
                v=perturb_factor,
                diag=perturb_diag,
                verify_pd=validate_args)

        if identity_multiplier is not None:
            if perturb_factor is None:
                return identity_multiplier
            # Infer the identity's shape from V alone: V's leading dims plus
            # its row count k, i.e. [..., k, k].
            v_shape = array_ops.shape(perturb_factor)
            identity_shape = array_ops.concat([v_shape[:-1], [v_shape[-2]]], 0)
            scaled_identity = operator_pd_identity.OperatorPDIdentity(
                identity_shape,
                perturb_factor.dtype.base_dtype,
                scale=identity_multiplier,
                verify_pd=validate_args)
            return operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                operator=scaled_identity,
                v=perturb_factor,
                diag=perturb_diag,
                verify_pd=validate_args)

        raise ValueError(
            "One of tril, diag and/or identity_multiplier must be "
            "specified.")