def __init__(self, tril, v, diag=None, validate_args=False):
    """Creates an instance of _TriLPlusVDVTLightweightOperatorPD.

    Stores `tril`, `v`, a diagonal operator `D` and its inverse. `D` is built
    from `diag` when given, otherwise it is a `k x k` identity batched like
    `v`.

    WARNING: This object is not to be used outside of `Affine` where it is
    currently being temporarily used for refactoring purposes.

    Args:
      tril: `Tensor` of shape `[B1,...,Bb, d, d]`.
      v: `Tensor` of shape `[B1,...,Bb, d, k]`.
      diag: `Tensor` of shape `[B1,...,Bb, k]` holding the diagonal entries
        of `D`, or `None` to use an identity for `D`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
    """
    self._m = tril
    self._v = v
    self._validate_args = validate_args
    self._inputs = [tril, v]
    if diag is not None:
        self._inputs += [diag]
        self._d = operator_pd_diag.OperatorPDDiag(
            diag, verify_pd=validate_args)
        # `diag` holds only the diagonal entries, so the inverse operator is
        # built from their element-wise reciprocal.
        self._d_inv = operator_pd_diag.OperatorPDDiag(
            1. / diag, verify_pd=validate_args)
        return
    # No `diag`: D is a k x k identity, batched over all but the last two
    # dimensions of `v`. It is its own inverse.
    if v.get_shape().is_fully_defined():
        v_shape = v.get_shape().as_list()
        id_shape = v_shape[:-2] + [v_shape[-1], v_shape[-1]]
    else:
        v_shape = array_ops.shape(v)
        id_shape = array_ops.concat(
            [v_shape[:-2], [v_shape[-1], v_shape[-1]]], 0)
    # Use the argument directly, as the `diag` branch does, rather than a
    # `validate_args` property that must be defined elsewhere on the class.
    self._d = operator_pd_identity.OperatorPDIdentity(
        id_shape, v.dtype, verify_pd=validate_args)
    self._d_inv = self._d
Beispiel #2
0
 def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
     """Returns an identity operator together with the dense matrix it mimics.

     Args:
       batch_shape: Iterable of leading batch dimensions.
       k: Size of each square identity matrix.
       dtype: Element dtype for both the operator and the dense matrix
         (defaults to `np.float64`).

     Returns:
       A pair `(operator, matrix)`. `operator` is an `OperatorPDIdentity` of
       shape `batch_shape + [k, k]`. `matrix` is the batch of k x k identity
       matrices, evaluated via `.eval()`, that the operator should match.
     """
     dims = list(batch_shape)
     # Dense reference: an all-ones diagonal expanded into identity matrices.
     ones_on_diagonal = tf.ones(dims + [k], dtype=dtype)
     dense_identity = tf.batch_matrix_diag(ones_on_diagonal)
     # The operator under test, expected to behave like `dense_identity`.
     op = operator_pd_identity.OperatorPDIdentity(dims + [k, k], dtype)
     return op, dense_identity.eval()
 def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
     """Returns a scaled identity operator and the dense matrix it mimics.

     The scale is fixed at 2.0, so the operator should act like `2 * I`.

     Args:
       batch_shape: Iterable of leading batch dimensions.
       k: Size of each square matrix.
       dtype: Element dtype for the operator, the scale and the dense matrix
         (defaults to `np.float64`).

     Returns:
       A pair `(operator, matrix)`. `operator` is an `OperatorPDIdentity` of
       shape `batch_shape + [k, k]` with `scale=2.0`. `matrix` is the batch of
       `2.0 * I` matrices, evaluated via `.eval()`.
     """
     dims = list(batch_shape)
     ones_on_diagonal = array_ops.ones(dims + [k], dtype=dtype)
     scale = constant_op.constant(2.0, dtype=dtype)
     # Dense reference: identity matrices multiplied by the shared scale.
     dense_scaled_identity = scale * array_ops.matrix_diag(ones_on_diagonal)
     op = operator_pd_identity.OperatorPDIdentity(
         dims + [k, k], dtype, scale=scale)
     return op, dense_scaled_identity.eval()
 def _get_identity_operator(self, v):
     """Builds an `OperatorPDIdentity` to serve as `D` in `VDV^T`.

     The identity is `r x r`, where `r` is the last dimension of `v`. It is
     batched over every dimension of `v` except the last two.

     Args:
       v: `Tensor`, assumed to have rank >= 2 (TODO confirm against callers).

     Returns:
       `OperatorPDIdentity` with dtype `v.dtype` and `verify_pd` taken from
       `self._verify_pd`.
     """
     with ops.op_scope([v], 'get_identity_operator'):
         static_shape = v.get_shape()
         if static_shape.is_fully_defined():
             # The shape is known while building the graph, so plain Python
             # lists suffice.
             dims = static_shape.as_list()
             last = dims[-1]
             id_shape = dims[:-2] + [last, last]
         else:
             # The shape is only known at run time, so assemble it from
             # tensor ops.
             dyn_shape = array_ops.shape(v)
             rank = array_ops.rank(v)
             batch_dims = array_ops.slice(dyn_shape, [0], [rank - 2])
             last = array_ops.gather(dyn_shape, rank - 1)
             id_shape = array_ops.concat(0, (batch_dims, [last, last]))
         return operator_pd_identity.OperatorPDIdentity(
             id_shape, v.dtype, verify_pd=self._verify_pd)
    def _create_scale_operator(self, identity_multiplier, diag, tril,
                               perturb_diag, perturb_factor, event_ndims,
                               validate_args):
        """Builds `scale` from its optional components.

        A single base matrix is used, chosen in the order TriL, Diag,
        Identity. When `perturb_factor` is given, a low-rank `V D V^T` update
        is layered on top of that base.

        Args:
          identity_multiplier: Floating-point rank 0 `Tensor` scaling the
            identity matrix.
          diag: Floating-point `Tensor` of shape `[N1, N2, ..., k]` holding
            the diagonal of a `k x k` diagonal matrix.
          tril: Floating-point `Tensor` of shape `[N1, N2, ..., k, k]`
            representing a `k x k` lower triangular matrix.
          perturb_diag: Floating-point `Tensor` representing the diagonal
            matrix of the low rank update.
          perturb_factor: Floating-point `Tensor` representing the factor
            matrix of the low rank update.
          event_ndims: Scalar `int32` `Tensor` indicating the number of
            dimensions associated with a particular draw from the
            distribution. Must be 0 or 1.
          validate_args: Python `bool` indicating whether arguments should be
            checked for correctness.

        Returns:
          scale. When only `identity_multiplier` applies and there is no
          low-rank update, `scale` is that floating-point `Tensor`. Otherwise
          it is an `OperatorPD` (or the lightweight TriL-plus-VDV^T stand-in).

        Raises:
          ValueError: if all of `tril`, `diag` and `identity_multiplier` are
            `None`.
        """
        (identity_multiplier, diag, tril, perturb_diag,
         perturb_factor) = [
             _as_tensor(value, label) for value, label in (
                 (identity_multiplier, "identity_multiplier"),
                 (diag, "diag"),
                 (tril, "tril"),
                 (perturb_diag, "perturb_diag"),
                 (perturb_factor, "perturb_factor"))]

        identity_multiplier = self._maybe_validate_identity_multiplier(
            identity_multiplier, validate_args)

        # The update factor V needs rank >= 2; its diagonal D needs rank >= 1.
        if perturb_factor is not None:
            perturb_factor = self._process_matrix(
                perturb_factor, min_rank=2, event_ndims=event_ndims)
        if perturb_diag is not None:
            perturb_diag = self._process_matrix(
                perturb_diag, min_rank=1, event_ndims=event_ndims)
        has_update = perturb_factor is not None

        # Base matrices are tried from the weakest to the strongest
        # structural assumption: TriL, then Diag, then (scaled) Identity.
        if tril is not None:
            tril = self._preprocess_tril(identity_multiplier, diag, tril,
                                         event_ndims)
            if not has_update:
                return operator_pd_cholesky.OperatorPDCholesky(
                    tril, verify_pd=validate_args)
            return _TriLPlusVDVTLightweightOperatorPD(
                tril=tril, v=perturb_factor, diag=perturb_diag,
                validate_args=validate_args)

        if diag is not None:
            diag = self._preprocess_diag(identity_multiplier, diag,
                                         event_ndims)
            if not has_update:
                return operator_pd_diag.OperatorPDSqrtDiag(
                    diag, verify_pd=validate_args)
            base = operator_pd_diag.OperatorPDDiag(
                diag, verify_pd=validate_args)
        elif identity_multiplier is not None:
            if not has_update:
                return identity_multiplier
            # Square identity shape inferred from V: keep all of V's
            # dimensions except the last, then repeat its second-to-last.
            v_shape = array_ops.shape(perturb_factor)
            identity_shape = array_ops.concat(
                [v_shape[:-1], [v_shape[-2]]], 0)
            base = operator_pd_identity.OperatorPDIdentity(
                identity_shape,
                perturb_factor.dtype.base_dtype,
                scale=identity_multiplier,
                verify_pd=validate_args)
        else:
            raise ValueError(
                "One of tril, diag and/or identity_multiplier must be "
                "specified.")

        return operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
            operator=base,
            v=perturb_factor,
            diag=perturb_diag,
            verify_pd=validate_args)