def testValidateArgs(self):
    """A zero on the diagonal of `scale_tril` is rejected when validating."""
    # The (0, 0) entry is zero, so the lower-triangular scale is singular.
    singular_tril = [[0., 1],
                     [1., 1.]]
    with self.test_session():
        with self.assertRaisesOpError("diagonal part must be non-zero"):
            scale = distribution_util.make_tril_scale(
                scale_tril=singular_tril, validate_args=True)
            self.evaluate(scale.to_dense())
 def testAssertPositive(self):
     """A negative diagonal entry is rejected when `assert_positive=True`."""
     # The (0, 0) entry is negative, violating the positivity assertion.
     negative_diag_tril = [[-1., 1],
                           [1., 1.]]
     with self.test_session():
         with self.assertRaisesOpError("diagonal part must be positive"):
             scale = distribution_util.make_tril_scale(
                 scale_tril=negative_diag_tril,
                 validate_args=True,
                 assert_positive=True)
             self.evaluate(scale.to_dense())
    def _testLegalInputs(self, loc=None, shape_hint=None, scale_params=None):
        """Compare `make_tril_scale` against `_make_tril_scale` on input subsets.

        Every subset of `scale_params` (as produced by `_powerset`) is combined
        with `loc` and `shape_hint` and passed to both the reference helper
        `_make_tril_scale` and `distribution_util.make_tril_scale`.

        Args:
          loc: Optional `loc` forwarded to both implementations.
          shape_hint: Optional `shape_hint` forwarded to both implementations.
          scale_params: Optional dict of extra keyword arguments whose item
            subsets are tried. `None` is treated as an empty dict.
        """
        # Guard the default: `None.items()` would raise AttributeError.
        if scale_params is None:
            scale_params = {}
        for subset in _powerset(scale_params.items()):
            with self.test_session():
                # Entries from `subset` take precedence over loc/shape_hint,
                # matching dict({...}, **subset) semantics.
                scale_args = {"loc": loc, "shape_hint": shape_hint}
                scale_args.update(dict(subset))

                expected_scale = _make_tril_scale(**scale_args)
                if expected_scale is None:
                    # Not enough shape information was specified, so the real
                    # implementation is expected to raise ValueError.
                    # NOTE(review): assertRaisesRegexp is a deprecated alias
                    # (removed from unittest in Python 3.12); kept here for
                    # compatibility with the file's test base class.
                    with self.assertRaisesRegexp(ValueError,
                                                 r"is specified\."):
                        scale = distribution_util.make_tril_scale(**scale_args)
                        self.evaluate(scale.to_dense())
                else:
                    scale = distribution_util.make_tril_scale(**scale_args)
                    self.assertAllClose(expected_scale,
                                        self.evaluate(scale.to_dense()))
Example #4
0
    def _create_scale_operator(self, identity_multiplier, diag, tril,
                               perturb_diag, perturb_factor, shift,
                               validate_args):
        """Construct `scale` from various components.

        Args:
          identity_multiplier: floating point rank 0 `Tensor` representing a
            scaling done to the identity matrix.
          diag: Floating-point `Tensor` representing the diagonal matrix.
            `scale_diag` has shape [N1, N2, ..., k], which represents a k x k
            diagonal matrix.
          tril: Floating-point `Tensor` representing the lower triangular
            matrix. `scale_tril` has shape [N1, N2, ..., k, k], which
            represents a k x k lower triangular matrix.
          perturb_diag: Floating-point `Tensor` representing the diagonal
            matrix of the low rank update.
          perturb_factor: Floating-point `Tensor` representing the factor
            matrix of the low rank update.
          shift: Floating-point `Tensor` representing `shift` in
            `scale @ X + shift`.
          validate_args: Python `bool` indicating whether arguments should be
            checked for correctness.

        Returns:
          scale. In the case of scaling by a constant, scale is a
          floating point `Tensor`. Otherwise, scale is a `LinearOperator`.

        Raises:
          ValueError: if all of `tril`, `diag` and `identity_multiplier` are
            `None`. This is not checked in this method; it is expected to
            surface from `distribution_util.make_tril_scale`.
        """
        # NOTE(review): `_as_tensor` presumably passes `None` through unchanged;
        # the `is not None` checks below rely on that.
        identity_multiplier = _as_tensor(identity_multiplier,
                                         "identity_multiplier")
        diag = _as_tensor(diag, "diag")
        tril = _as_tensor(tril, "tril")
        perturb_diag = _as_tensor(perturb_diag, "perturb_diag")
        perturb_factor = _as_tensor(perturb_factor, "perturb_factor")

        # If possible, use the low rank update to infer the shape of
        # the identity matrix, when scale represents a scaled identity matrix
        # with a low rank update.
        shape_hint = None
        if perturb_factor is not None:
            shape_hint = distribution_util.dimension_size(perturb_factor,
                                                          axis=-2)

        # A pure multiple of the identity is returned as a plain Tensor rather
        # than a LinearOperator.
        if self._is_only_identity_multiplier:
            if validate_args:
                # Attach a runtime check that the multiplier is non-zero.
                return control_flow_ops.with_dependencies([
                    tf.assert_none_equal(
                        identity_multiplier,
                        tf.zeros([], identity_multiplier.dtype),
                        ["identity_multiplier should be non-zero."])
                ], identity_multiplier)
            return identity_multiplier

        scale = distribution_util.make_tril_scale(
            loc=shift,
            scale_tril=tril,
            scale_diag=diag,
            scale_identity_multiplier=identity_multiplier,
            validate_args=validate_args,
            assert_positive=False,
            shape_hint=shape_hint)

        if perturb_factor is not None:
            # A `diag_update` of None is treated as the identity by
            # LinearOperatorLowRankUpdate, hence known positive.
            # NOTE(review): the self-adjoint / positive-definite hints below
            # are not guaranteed when `scale` is a general (non-diagonal)
            # lower-triangular operator -- confirm this is intended.
            return tf.linalg.LinearOperatorLowRankUpdate(
                scale,
                u=perturb_factor,
                diag_update=perturb_diag,
                is_diag_update_positive=perturb_diag is None,
                is_non_singular=True,  # Implied by is_positive_definite=True.
                is_self_adjoint=True,
                is_positive_definite=True,
                is_square=True)

        return scale
 def testZeroTriU(self):
     """Entries above the diagonal of `scale_tril` are ignored (zeroed)."""
     full_matrix = [[1., 1], [1., 1.]]
     lower_part = [[1., 0], [1., 1.]]
     with self.test_session():
         scale = distribution_util.make_tril_scale(scale_tril=full_matrix)
         dense = self.evaluate(scale.to_dense())
         self.assertAllClose(lower_part, dense)
Example #6
0
  def _create_scale_operator(self, identity_multiplier, diag, tril,
                             perturb_diag, perturb_factor, shift,
                             validate_args):
    """Construct `scale` from various components.

    Args:
      identity_multiplier: floating point rank 0 `Tensor` representing a
        scaling done to the identity matrix.
      diag: Floating-point `Tensor` representing the diagonal matrix.
        `scale_diag` has shape [N1, N2, ..., k], which represents a k x k
        diagonal matrix.
      tril: Floating-point `Tensor` representing the lower triangular matrix.
        `scale_tril` has shape [N1, N2, ..., k, k], which represents a k x k
        lower triangular matrix.
      perturb_diag: Floating-point `Tensor` representing the diagonal matrix of
        the low rank update.
      perturb_factor: Floating-point `Tensor` representing the factor matrix of
        the low rank update.
      shift: Floating-point `Tensor` representing `shift` in
        `scale @ X + shift`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.

    Returns:
      scale. In the case of scaling by a constant, scale is a
      floating point `Tensor`. Otherwise, scale is a `LinearOperator`.

    Raises:
      ValueError: if all of `tril`, `diag` and `identity_multiplier` are
        `None`. This is not checked in this method; it is expected to surface
        from `distribution_util.make_tril_scale`.
    """
    # NOTE(review): `_as_tensor` presumably passes `None` through unchanged;
    # the `is not None` checks below rely on that.
    identity_multiplier = _as_tensor(identity_multiplier, "identity_multiplier")
    diag = _as_tensor(diag, "diag")
    tril = _as_tensor(tril, "tril")
    perturb_diag = _as_tensor(perturb_diag, "perturb_diag")
    perturb_factor = _as_tensor(perturb_factor, "perturb_factor")

    # If possible, use the low rank update to infer the shape of
    # the identity matrix, when scale represents a scaled identity matrix
    # with a low rank update.
    shape_hint = None
    if perturb_factor is not None:
      shape_hint = distribution_util.dimension_size(perturb_factor, axis=-2)

    # A pure multiple of the identity is returned as a plain Tensor rather
    # than a LinearOperator.
    if self._is_only_identity_multiplier:
      if validate_args:
        # Attach a runtime check that the multiplier is non-zero.
        return control_flow_ops.with_dependencies([
            tf.assert_none_equal(identity_multiplier,
                                 tf.zeros([], identity_multiplier.dtype),
                                 ["identity_multiplier should be non-zero."])
        ], identity_multiplier)
      return identity_multiplier

    scale = distribution_util.make_tril_scale(
        loc=shift,
        scale_tril=tril,
        scale_diag=diag,
        scale_identity_multiplier=identity_multiplier,
        validate_args=validate_args,
        assert_positive=False,
        shape_hint=shape_hint)

    if perturb_factor is not None:
      # A `diag_update` of None is treated as the identity by
      # LinearOperatorLowRankUpdate, hence known positive.
      # NOTE(review): the self-adjoint / positive-definite hints below are not
      # guaranteed when `scale` is a general (non-diagonal) lower-triangular
      # operator -- confirm this is intended.
      return tf.linalg.LinearOperatorLowRankUpdate(
          scale,
          u=perturb_factor,
          diag_update=perturb_diag,
          is_diag_update_positive=perturb_diag is None,
          is_non_singular=True,  # Implied by is_positive_definite=True.
          is_self_adjoint=True,
          is_positive_definite=True,
          is_square=True)

    return scale