Пример #1
0
    def __init__(self, tril, v, diag=None, validate_args=False):
        """Creates an instance of _TriLPlusVDVTLightweightOperatorPD.

        WARNING: This object is not to be used outside of `Affine` where it is
        currently being temporarily used for refactoring purposes.

        Stores the base matrix `tril` and the factor `v`, and builds a pair of
        operators for the inner matrix `D` and its inverse. When `diag` is
        `None`, `D` is an identity operator (which is its own inverse) whose
        two trailing dimensions both equal the last dimension of `v`.

        Args:
          tril: `Tensor` of shape `[B1,...,Bb, d, d]`.
          v: `Tensor` of shape `[B1,...,Bb, d, k]`.
          diag: `Tensor` or `None`. When given, both `diag` and `1. / diag`
            are wrapped in `OperatorPDDiag`.
            NOTE(review): this was previously documented as shape
            `[B1,...,Bb, k, k]`; since it is handed to `OperatorPDDiag` it is
            presumably the diagonal itself, shape `[B1,...,Bb, k]` - confirm.
          validate_args: Python `bool` indicating whether arguments should be
            checked for correctness. Forwarded as `verify_pd` to every
            operator built here.
        """
        self._m = tril
        self._v = v
        self._validate_args = validate_args
        self._inputs = [tril, v]

        if diag is not None:
            self._inputs += [diag]
            self._d = operator_pd_diag.OperatorPDDiag(
                diag, verify_pd=validate_args)
            self._d_inv = operator_pd_diag.OperatorPDDiag(
                1. / diag, verify_pd=validate_args)
            return

        # No diagonal supplied: D is the identity. Prefer the static shape of
        # `v` when it is fully known, otherwise assemble the shape dynamically.
        if v.get_shape().is_fully_defined():
            v_shape = v.get_shape().as_list()
            id_shape = v_shape[:-2] + [v_shape[-1], v_shape[-1]]
        else:
            v_shape = array_ops.shape(v)
            id_shape = array_ops.concat(
                [v_shape[:-2], [v_shape[-1], v_shape[-1]]], 0)

        # Use the argument directly (as the `diag` branch does) instead of
        # reading `self.validate_args`, which this constructor never assigns.
        self._d = operator_pd_identity.OperatorPDIdentity(
            id_shape, v.dtype, verify_pd=validate_args)
        self._d_inv = self._d
    def __init__(self,
                 operator,
                 v,
                 diag=None,
                 verify_pd=True,
                 verify_shapes=True,
                 name="OperatorPDSqrtVDVTUpdate"):
        """Initialize an `OperatorPDSqrtVDVTUpdate`.

        Args:
          operator:  Subclass of `OperatorPDBase`.  Represents the (batch)
            positive definite matrix `M` in `R^{k x k}`.
          v: `Tensor` defining batch matrix of same `dtype` and `batch_shape`
            as `operator`, and last two dimensions of shape `(k, r)`.
          diag:  Optional `Tensor` defining batch vector of same `dtype` and
            `batch_shape` as `operator`, and last dimension of size `r`.  If
            `None`, the update becomes `VV^T` rather than `VDV^T`.
          verify_pd:  `Boolean`.  If `True`, add asserts that `diag > 0`,
            which, along with the positive definiteness of `operator`, is
            sufficient to make the resulting operator positive definite.
          verify_shapes:  `Boolean`.  If `True`, check that `operator`, `v`,
            and `diag` have compatible shapes.
          name:  A name to prepend to `Op` names.

        Raises:
          TypeError: if `operator` is not an instance of `OperatorPDBase`.
        """
        if not isinstance(operator, operator_pd.OperatorPDBase):
            raise TypeError("operator was not instance of OperatorPDBase.")

        with ops.name_scope(name), ops.name_scope(
                "init", values=operator.inputs + [v, diag]):
            self._operator = operator
            self._v = ops.convert_to_tensor(v, name="v")
            self._name = name
            self._verify_shapes = verify_shapes
            self._verify_pd = verify_pd

            # A PSD diag is enough for this operator to be PD, but the
            # Woodbury and determinant lemmas need diag to be PD, so that is
            # what gets enforced whenever "verify_pd" is requested.
            if diag is None:
                # Without a diag the inner matrix is the identity, which also
                # serves as its own inverse.
                self._diag = None
                self._diag_operator = self._get_identity_operator(self._v)
                self._diag_inv_operator = self._diag_operator
            else:
                self._diag = ops.convert_to_tensor(diag, name="diag")
                self._diag_operator = operator_pd_diag.OperatorPDDiag(
                    diag, verify_pd=self.verify_pd)
                # The inverse of a PD matrix is PD; skip verifying it again.
                self._diag_inv_operator = operator_pd_diag.OperatorPDDiag(
                    1 / self._diag, verify_pd=False)

            self._check_types(operator, self._v, self._diag)

            # The static shape check always runs; the dynamic one is only a
            # fallback for when the static check did not settle things and
            # the caller asked for shape verification.
            shapes_checked_statically = self._check_shapes_static(
                operator, self._v, self._diag)
            if not shapes_checked_statically and self._verify_shapes:
                self._v, self._diag = self._check_shapes_dynamic(
                    operator, self._v, self._diag)
Пример #3
0
    def __init__(self,
                 mu,
                 diag_large,
                 v,
                 diag_small=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="MultivariateNormalDiagPlusVDVT"):
        """Multivariate Normal distributions on `R^k`.

        For every batch member, this distribution represents `k` random
        variables `(X_1,...,X_k)`, with mean `E[X_i] = mu[i]`, and covariance
        matrix `C_{ij} := E[(X_i - mu[i])(X_j - mu[j])]`

        The user initializes this class by providing the mean `mu`, and a
        lightweight definition of `C`:

        ```
        C = SS^T = SS = (M + V D V^T) (M + V D V^T)
        M is diagonal (k x k)
        V = is shape (k x r), typically r << k
        D = is diagonal (r x r), optional (defaults to identity).
        ```

        Args:
          mu:  Rank `n + 1` floating point tensor with shape `[N1,...,Nn, k]`,
            `n >= 0`.  The means.
          diag_large:  Rank `n + 1` floating point tensor, shape
            `[N1,...,Nn, k]` `n >= 0`.  Defines the diagonal matrix `M`.
            Required (the signature gives it no default).
          v:  Rank `n + 2` floating point tensor, shape `[N1,...,Nn, k, r]`
            `n >= 0`.  Defines the matrix `V`.
          diag_small:  Optional rank `n + 1` floating point tensor defining
            the diagonal matrix `D`.  Default is `None`, which means `D` will
            be the identity matrix.
            NOTE(review): earlier docs gave the shape as `[N1,...,Nn, k]`;
            since `D` is `r x r` it is presumably `[N1,...,Nn, r]` - confirm.
          validate_args: `Boolean`, default `False`.  Whether to validate
            input with asserts.  If `validate_args` is `False`, and the inputs
            are invalid, correct behavior is not guaranteed.
          allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member. If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this
            statistic.
          name: The name to give Ops created by the initializer.
        """
        # Must stay the first statement so only the constructor arguments are
        # captured. Copy the mapping: on CPython before 3.13 `locals()` hands
        # back the live frame dict, which any later `locals()` call or
        # debugger/tracer access would refresh, re-adding `self` and leaking
        # `ns`, `cov`, and `parameters` itself into the stored parameters.
        parameters = dict(locals())
        parameters.pop("self")

        with ops.name_scope(name, values=[diag_large, v, diag_small]) as ns:
            # The covariance factor is the diagonal operator M plus the
            # low-rank V D V^T update.
            cov = operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                operator_pd_diag.OperatorPDDiag(diag_large,
                                                verify_pd=validate_args),
                v,
                diag=diag_small,
                verify_pd=validate_args,
                verify_shapes=validate_args)
        super(MultivariateNormalDiagPlusVDVT,
              self).__init__(mu,
                             cov,
                             allow_nan_stats=allow_nan_stats,
                             validate_args=validate_args,
                             name=ns)
        # Assigned after the base constructor runs, as in the original.
        self._parameters = parameters
Пример #4
0
    def _create_scale_operator(self, identity_multiplier, diag, tril,
                               perturb_diag, perturb_factor, event_ndims,
                               validate_args):
        """Build the `scale` object from whichever components were supplied.

        The base matrix is chosen in the order TriL, Diag, Identity; if
        `perturb_factor` is given, a low-rank update is layered on top of it.

        Args:
          identity_multiplier: floating point rank 0 `Tensor` representing a
            scaling done to the identity matrix.
          diag: Floating-point `Tensor` representing the diagonal matrix.
            `scale_diag` has shape [N1, N2, ...  k], which represents a k x k
            diagonal matrix.
          tril: Floating-point `Tensor` representing the lower triangular
            base matrix.
            NOTE(review): earlier docs called this "the diagonal matrix" with
            shape [N1, N2, ...  k]; a k x k lower triangular matrix
            presumably needs shape [N1, N2, ...  k, k] - confirm.
          perturb_diag: Floating-point `Tensor` representing the diagonal
            matrix of the low rank update.
          perturb_factor: Floating-point `Tensor` representing factor matrix.
          event_ndims: Scalar `int32` `Tensor` indicating the number of
            dimensions associated with a particular draw from the
            distribution. Must be 0 or 1
          validate_args: Python `bool` indicating whether arguments should be
            checked for correctness.

        Returns:
          scale. In the case of scaling by a constant, scale is a
          floating point `Tensor`. Otherwise, scale is an `OperatorPD`.

        Raises:
          ValueError: if all of `tril`, `diag` and `identity_multiplier` are
            `None`.
        """
        identity_multiplier = _as_tensor(identity_multiplier,
                                         "identity_multiplier")
        diag = _as_tensor(diag, "diag")
        tril = _as_tensor(tril, "tril")
        perturb_diag = _as_tensor(perturb_diag, "perturb_diag")
        perturb_factor = _as_tensor(perturb_factor, "perturb_factor")

        identity_multiplier = self._maybe_validate_identity_multiplier(
            identity_multiplier, validate_args)

        if perturb_factor is not None:
            perturb_factor = self._process_matrix(
                perturb_factor, min_rank=2, event_ndims=event_ndims)

        if perturb_diag is not None:
            perturb_diag = self._process_matrix(
                perturb_diag, min_rank=1, event_ndims=event_ndims)

        no_low_rank_update = perturb_factor is None

        def _with_vdvt_update(base_operator):
            # Layers the processed low-rank term on top of `base_operator`.
            return operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                operator=base_operator,
                v=perturb_factor,
                diag=perturb_diag,
                verify_pd=validate_args)

        # Branches run from the weakest to the strongest assumption about the
        # base matrix: TriL first, then Diag, then Identity.

        if tril is not None:
            tril = self._preprocess_tril(identity_multiplier, diag, tril,
                                         event_ndims)
            if no_low_rank_update:
                return operator_pd_cholesky.OperatorPDCholesky(
                    tril, verify_pd=validate_args)
            return _TriLPlusVDVTLightweightOperatorPD(
                tril=tril,
                v=perturb_factor,
                diag=perturb_diag,
                validate_args=validate_args)

        if diag is not None:
            diag = self._preprocess_diag(identity_multiplier, diag,
                                         event_ndims)
            if no_low_rank_update:
                return operator_pd_diag.OperatorPDSqrtDiag(
                    diag, verify_pd=validate_args)
            return _with_vdvt_update(
                operator_pd_diag.OperatorPDDiag(diag,
                                                verify_pd=validate_args))

        if identity_multiplier is None:
            raise ValueError(
                "One of tril, diag and/or identity_multiplier must be "
                "specified.")

        if no_low_rank_update:
            # Pure constant scaling: hand back the multiplier itself.
            return identity_multiplier

        # The identity's shape is inferred from V: keep every dimension but
        # the last, then repeat the second-to-last one.
        factor_shape = array_ops.shape(perturb_factor)
        identity_shape = array_ops.concat(
            [factor_shape[:-1], [factor_shape[-2]]], 0)
        return _with_vdvt_update(
            operator_pd_identity.OperatorPDIdentity(
                identity_shape,
                perturb_factor.dtype.base_dtype,
                scale=identity_multiplier,
                verify_pd=validate_args))