コード例 #1
0
ファイル: mvn.py プロジェクト: xordlxordl/tensorflow
    def __init__(self,
                 mu,
                 sigma,
                 validate_args=True,
                 allow_nan_stats=False,
                 name="MultivariateNormalFull"):
        """Multivariate Normal distributions on `R^k`.

        The distribution is parameterized by a mean `mu` and a full
        covariance matrix `sigma`.

        Args:
          mu: `(N+1)-D` floating point tensor with shape `[N1,...,Nb, k]`,
            `b >= 0`.
          sigma: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
            `[N1,...,Nb, k, k]`.  Each batch member must be positive definite.
          validate_args: Whether to validate input with asserts.  If
            `validate_args` is `False`, and the inputs are invalid, correct
            behavior is not guaranteed.  Also forwarded as `verify_pd` to the
            covariance operator.
          allow_nan_stats: `Boolean`, default `False`.  If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member.  If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this statistic.
          name: The name to give Ops created by the initializer.

        Raises:
          TypeError: If `mu` and `sigma` are different dtypes.
        """
        # Wrap the dense covariance in an operator; the positive-definiteness
        # verification is tied to `validate_args`.
        covariance_operator = operator_pd_full.OperatorPDFull(
            sigma, verify_pd=validate_args)
        base_init = super(MultivariateNormalFull, self).__init__
        base_init(mu,
                  covariance_operator,
                  allow_nan_stats=allow_nan_stats,
                  validate_args=validate_args,
                  name=name)
コード例 #2
0
  def __init__(
      self,
      mu,
      sigma,
      strict=True,
      strict_statistics=True,
      name="MultivariateNormalFull"):
    """Multivariate Normal distributions on `R^k`.

    The distribution is defined by a mean `mu` and a covariance `sigma`.

    Args:
      mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
        `b >= 0`.
      sigma: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
        `[N1,...,Nb, k, k]`.
      strict: Whether to validate input with asserts.  If `strict` is `False`,
        and the inputs are invalid, correct behavior is not guaranteed.  Also
        passed as `verify_pd` to the covariance operator.
      strict_statistics: Boolean, default True.  If True, raise an exception if
        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
        If False, batch members with valid parameters leading to undefined
        statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      TypeError: If `mu` and `sigma` are different dtypes.
    """
    full_cov = operator_pd_full.OperatorPDFull(sigma, verify_pd=strict)
    # Options forwarded unchanged to the base distribution.
    base_kwargs = dict(
        strict_statistics=strict_statistics, strict=strict, name=name)
    super(MultivariateNormalFull, self).__init__(mu, full_cov, **base_kwargs)
コード例 #3
0
 def testNonSymmetricMatrixRaises(self):
     with self.test_session():
         # Random positive definite batch (presumably shape (3, 2, 2)); nudge
         # one off-diagonal entry so the first matrix is no longer symmetric.
         nonsymmetric = self._random_positive_def_array(3, 2, 2)
         nonsymmetric[0, 0, 1] += 0.001
         pd_operator = operator_pd_full.OperatorPDFull(
             nonsymmetric, verify_pd=True)
         # With verification on, the symmetry (equality) assert must fire
         # when the dense matrix is evaluated.
         with self.assertRaisesOpError("x == y"):
             pd_operator.to_dense().eval()
コード例 #4
0
    def test_to_dense_placeholder(self):
        # Every input to the update operator is a placeholder; to_dense() and
        # log_det() must still evaluate once concrete values are fed.
        shape_of_mat = [3, 3]
        low_rank = 2
        with self.test_session():
            # Base operator wraps a placeholder; its PD value is fed later.
            mat_ph = tf.placeholder(tf.float64, name='mat_ph')
            mat_value = self._random_pd_matrix(shape_of_mat)
            base_operator = operator_pd_full.OperatorPDFull(mat_ph)

            # Placeholders and values for the low-rank update terms.
            v_ph = tf.placeholder(tf.float64, name='v_ph')
            v_value, diag_value = self._random_v_and_diag(
                shape_of_mat, low_rank)
            feeds = {v_ph: v_value, mat_ph: mat_value}
            diag_ph = None
            if not self._diag_is_none:
                diag_ph = tf.placeholder(tf.float64, name='diag_ph')
                feeds[diag_ph] = diag_value

            updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                base_operator, v_ph, diag=diag_ph)

            # Neither evaluation should raise.
            updated.to_dense().eval(feed_dict=feeds)
            updated.log_det().eval(feed_dict=feeds)
コード例 #5
0
 def testNegativeDefiniteMatrixRaises(self):
     with self.test_session():
         # Negating a positive definite batch yields a negative definite one.
         neg_def = -1 * self._random_positive_def_array(3, 2, 2)
         pd_operator = operator_pd_full.OperatorPDFull(neg_def, verify_pd=True)
         # The failure may come from the Cholesky factorization ("LLT") or
         # from the later positivity check on the diagonal ("x > 0").
         with self.assertRaisesOpError("x > 0|LLT"):
             pd_operator.to_dense().eval()
コード例 #6
0
    def test_tensor_rank_shape_mismatch_v_and_diag_raises_static(self):
        # v has rank 4, so diag should have rank 3; this diag has rank 2.
        v_arr = self._rng.rand(1, 2, 2, 2)
        bad_diag = self._rng.rand(5, 1)
        with self.test_session():

            # Base matrix shape is consistent with v; only diag is wrong.
            base = operator_pd_full.OperatorPDFull(
                self._random_pd_matrix((1, 2, 2, 2)))
            # Static shapes are known, so construction itself must fail.
            with self.assertRaisesRegexp(ValueError, 'diag.*rank'):
                operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                    base, v_arr, bad_diag)
コード例 #7
0
    def test_batch_shape_mismatch_v_and_diag_raises_static(self):
        # To match v, diag would need shape (4, 2); (5, 1) breaks the batch.
        v_arr = self._rng.rand(4, 3, 2)
        bad_diag = self._rng.rand(5, 1)
        with self.test_session():

            # Base matrix shape agrees with v.
            base = operator_pd_full.OperatorPDFull(
                self._random_pd_matrix((4, 3, 3)))
            # Static shapes are known, so construction itself must fail.
            with self.assertRaisesRegexp(ValueError, 'diag.*batch shape'):
                operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                    base, v_arr, bad_diag)
コード例 #8
0
    def testTensorRankShapeMismatchVAndDiagRaisesStatic(self):
        # v has rank 4, so diag should have rank 3; this diag has rank 2.
        v_arr = self._rng.rand(1, 2, 2, 2)
        bad_diag = self._rng.rand(5, 1)
        with self.test_session():

            # Base matrix shape is consistent with v; only diag is wrong.
            base = operator_pd_full.OperatorPDFull(
                self._random_pd_matrix((1, 2, 2, 2)))
            # Static shapes are known, so construction itself must fail.
            with self.assertRaisesRegexp(ValueError, "diag.*rank"):
                operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                    base, v_arr, bad_diag)
コード例 #9
0
    def testEventShapeMismatchVAndDiagRaisesStatic(self):
        # To match v, diag would need shape (4, 2); its last dimension is 1.
        v_arr = self._rng.rand(4, 3, 2)
        bad_diag = self._rng.rand(4, 1)
        with self.test_session():

            # Base matrix shape agrees with v.
            base = operator_pd_full.OperatorPDFull(
                self._random_pd_matrix((4, 3, 3)))
            # Static shapes are known, so construction itself must fail.
            with self.assertRaisesRegexp(ValueError,
                                         "diag.*v.*last dimension"):
                operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                    base, v_arr, bad_diag)
コード例 #10
0
    def test_non_pos_def_diag_doesnt_raise_if_verify_pd_false(self):
        # With verify_pd=False, a diag that is not positive definite must not
        # be rejected.  Nothing to check when this case has no diag.
        if self._diag_is_none:
            return
        with self.test_session():
            shape = (3, 3)
            v, diag = self._random_v_and_diag(shape, 2)
            pd_mat = self._random_pd_matrix(shape)
            # Zero one entry so diag is no longer positive definite.
            diag[0] = 0.0

            base = operator_pd_full.OperatorPDFull(pd_mat)
            unchecked = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                base, v, diag, verify_pd=False)

            # Should not raise.
            unchecked.to_dense().eval()
コード例 #11
0
    def test_non_pos_def_diag_raises(self):
        # Nothing to check when this case has no diag.
        if self._diag_is_none:
            return
        # A diag that is not positive definite must be rejected by default.
        with self.test_session():
            shape = (3, 3)
            v, diag = self._random_v_and_diag(shape, 2)
            pd_mat = self._random_pd_matrix(shape)
            # Zero one entry so diag is no longer positive definite.
            diag[0] = 0.0

            base = operator_pd_full.OperatorPDFull(pd_mat)
            checked = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                base, v, diag)

            with self.assertRaisesOpError('positive'):
                checked.to_dense().eval()
コード例 #12
0
    def test_batch_shape_mismatch_v_and_diag_raises_dynamic(self):
        with self.test_session():
            # To match v, diag would need shape (4, 2); (5, 1) breaks the batch.
            v_val = self._rng.rand(4, 3, 2)
            diag_val = self._rng.rand(5, 1)
            mat_val = self._random_pd_matrix((4, 3, 3))

            # Placeholders carry no static shape, so the mismatch can only be
            # detected at run time.
            v_ph, diag_ph, mat_ph = (
                tf.placeholder(tf.float32, name=ph_name)
                for ph_name in ('v_ph', 'diag_ph', 'mat_ph'))

            updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                operator_pd_full.OperatorPDFull(mat_ph), v_ph, diag_ph)
            feeds = {v_ph: v_val, diag_ph: diag_val, mat_ph: mat_val}
            with self.assertRaisesOpError('x == y'):
                updated.to_dense().eval(feed_dict=feeds)
コード例 #13
0
    def testBatchShapeMismatchVAndDiagRaisesDynamic(self):
        with self.test_session():
            # To match v, diag would need shape (4, 2); (5, 1) breaks the batch.
            v_val = self._rng.rand(4, 3, 2)
            diag_val = self._rng.rand(5, 1)
            mat_val = self._random_pd_matrix((4, 3, 3))

            # Placeholders carry no static shape, so the mismatch can only be
            # detected at run time.
            v_ph, diag_ph, mat_ph = (
                tf.placeholder(tf.float32, name=ph_name)
                for ph_name in ("v_ph", "diag_ph", "mat_ph"))

            updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                operator_pd_full.OperatorPDFull(mat_ph), v_ph, diag_ph)
            feeds = {v_ph: v_val, diag_ph: diag_val, mat_ph: mat_val}
            with self.assertRaisesOpError("x == y"):
                updated.to_dense().eval(feed_dict=feeds)
コード例 #14
0
    def test_tensor_rank_shape_mismatch_v_and_diag_raises_dynamic(self):
        with self.test_session():
            # v has rank 4, so diag should have rank 3; a rank-2 diag is fed.
            v_val = self._rng.rand(2, 2, 2, 2)
            diag_val = self._rng.rand(2, 2)
            mat_val = self._random_pd_matrix((2, 2, 2, 2))

            # Placeholders carry no static shape, so the rank check is
            # deferred to run time.
            v_ph, diag_ph, mat_ph = (
                tf.placeholder(tf.float32, name=ph_name)
                for ph_name in ("v_ph", "diag_ph", "mat_ph"))

            updated = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                operator_pd_full.OperatorPDFull(mat_ph), v_ph, diag_ph)
            feeds = {v_ph: v_val, diag_ph: diag_val, mat_ph: mat_val}
            with self.assertRaisesOpError("rank"):
                updated.to_dense().eval(feed_dict=feeds)
コード例 #15
0
    def __init__(self,
                 df,
                 scale,
                 cholesky_input_output_matrices=False,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="WishartFull"):
        """Construct Wishart distributions.

        Args:
          df: `float` or `double` `Tensor`. Degrees of freedom, must be greater
            than or equal to dimension of the scale matrix.
          scale: `float` or `double` `Tensor`. The symmetric positive definite
            scale matrix of the distribution.
          cholesky_input_output_matrices: Python `bool`. Any function whose
            input or output is a matrix assumes the input is Cholesky and
            returns a Cholesky factored matrix. Example `log_prob` input takes a
            Cholesky and `sample_n` returns a Cholesky when
            `cholesky_input_output_matrices=True`.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False` invalid inputs may
            silently render incorrect outputs. Also forwarded as `verify_pd` to
            the scale operator.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.
        """
        # Snapshot the constructor arguments. In CPython (< 3.13) `locals()`
        # inside a function returns the frame's live locals mapping, which can
        # be re-synced later (e.g. under a tracer/debugger) and would then pick
        # up `parameters`, `ns`, and a re-added `self`; copying freezes it.
        parameters = dict(locals())
        # Store only the constructor arguments: keeping `self` would make the
        # instance reference itself and break `type(self)(**parameters)`.
        parameters.pop("self")
        with ops.name_scope(name, values=[scale]) as ns:
            super(WishartFull, self).__init__(
                df=df,
                scale_operator_pd=operator_pd_full.OperatorPDFull(
                    scale, verify_pd=validate_args),
                cholesky_input_output_matrices=cholesky_input_output_matrices,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                name=ns)
        self._parameters = parameters
コード例 #16
0
ファイル: wishart.py プロジェクト: sanketg10/tensorflow
  def __init__(self,
               df,
               scale,
               cholesky_input_output_matrices=False,
               validate_args=False,
               allow_nan_stats=True,
               name="WishartFull"):
    """Construct Wishart distributions.

    Args:
      df: `float` or `double` `Tensor`. Degrees of freedom, must be greater than
        or equal to dimension of the scale matrix.
      scale: `float` or `double` `Tensor`. The symmetric positive definite
        scale matrix of the distribution.
      cholesky_input_output_matrices: `Boolean`. Any function whose input or
        output is a matrix assumes the input is Cholesky and returns a
        Cholesky factored matrix. Example `log_pdf` input takes a Cholesky and
        `sample_n` returns a Cholesky when
        `cholesky_input_output_matrices=True`.
      validate_args: `Boolean`, default `False`.  Whether to validate input with
        asserts. If `validate_args` is `False`, and the inputs are invalid,
        correct behavior is not guaranteed. Also forwarded as `verify_pd` to
        the scale operator.
      allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
        exception if a statistic (e.g., mean, mode) is undefined for any batch
        member. If True, batch members with valid parameters leading to
        undefined statistics will return `NaN` for this statistic.
      name: The name scope to give class member ops.
    """
    # Copy the arguments instead of aliasing `locals()`: in CPython (< 3.13)
    # that is the frame's live mapping, and a later re-sync (e.g. under a
    # tracer/debugger) would re-add `self` and add `parameters`/`ns`.
    parameters = dict(locals())
    parameters.pop("self")
    with ops.name_scope(name, values=[scale]) as ns:
      super(WishartFull, self).__init__(
          df=df,
          scale_operator_pd=operator_pd_full.OperatorPDFull(
              scale, verify_pd=validate_args),
          cholesky_input_output_matrices=cholesky_input_output_matrices,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=ns)
    self._parameters = parameters
コード例 #17
0
    def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
        """Return an update operator and the dense matrix it should equal.

        Called by the base test class, which runs its standard operator tests
        against the returned pair.

        Args:
          batch_shape: Leading batch dimensions of the base matrix.
          k: Size of the trailing `k x k` matrix dimensions.
          dtype: NumPy dtype that `mat`, `v` and (if present) `diag` are cast to.

        Returns:
          `(operator, updated_mat)`: an `OperatorPDSqrtVDVTUpdate` built from
          `OperatorPDFull(mat)`, `v` and `diag`, and the reference matrix
          computed directly by `self._updated_mat(mat, v, diag)`.
        """
        # V's rank is half of k, except when k == 1, where it stays 1.
        rank_of_v = k if k == 1 else k // 2
        full_shape = list(batch_shape) + [k, k]
        mat = self._random_pd_matrix(full_shape)
        v, diag = self._random_v_and_diag(full_shape, rank_of_v)

        # Cast to the requested dtype; diag may be None (no diag term).
        mat, v = mat.astype(dtype), v.astype(dtype)
        if diag is not None:
            diag = diag.astype(dtype)

        # Reference: (mat + v*diag*v^T) * (mat + v*diag*v^T)^T, built directly.
        expected = self._updated_mat(mat, v, diag)

        # The same matrix expressed as an update of an operator wrapping `mat`;
        # this is the operator under test.
        operator = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
            operator_pd_full.OperatorPDFull(mat), v, diag)
        return operator, expected
コード例 #18
0
 def testPositiveDefiniteMatrixDoesntRaise(self):
     with self.test_session():
         pd_batch = self._random_positive_def_array(2, 3, 3)
         # Verification is enabled, and a positive definite input must pass it.
         operator_pd_full.OperatorPDFull(
             pd_batch, verify_pd=True).to_dense().eval()