Esempio n. 1
0
        def dot(mat: Optional[np.ndarray],
                batch_ndim: int,
                other1: Optional[np.ndarray] = None,
                other2: Optional[np.ndarray] = None) -> Optional[np.ndarray]:
            """Contract `mat` with up to two optional operands via `np.einsum`.

            Args:
              mat: array to contract. Returned unchanged when it is `None` or
                0-dimensional.
              batch_ndim: forwarded to `get_other_dims`, `get_out_dims` and,
                when `self.is_reversed` is set, to `utils.reverse_zipped`.
              other1: optional operand placed before `mat` in the einsum call;
                its axis labels come from `get_other_dims(batch_ndim, True)`.
              other2: optional operand placed after `mat`; its axis labels come
                from `get_other_dims(batch_ndim, False)`.

            Returns:
              `mat` itself if it is `None`, 0-dimensional, or if both `other1`
              and `other2` are `None`. Otherwise, the `np.einsum` contraction of
              the given operands with output labels `get_out_dims(batch_ndim)`.
            """
            # Parenthesized for readability only: `and` already binds tighter
            # than `or`, so this is the same test as the unparenthesized form.
            if (mat is None or mat.ndim == 0 or
                    (other1 is None and other2 is None)):
                return mat

            # np.einsum "sublist" form: alternating (array, axis-label list)
            # arguments, followed by the output axis-label list.
            operands = []

            if other1 is not None:
                operands += [other1, get_other_dims(batch_ndim, True)]

            mat_dims = list(range(mat.ndim))
            if self.is_reversed:
                # NOTE(review): presumably relabels `mat`'s axes to undo a prior
                # spatial-axis reversal so they line up with other1/other2 and
                # the output labels — confirm against `utils.reverse_zipped`.
                mat_dims = utils.reverse_zipped(mat_dims, batch_ndim)
            operands += [mat, mat_dims]

            if other2 is not None:
                operands += [other2, get_other_dims(batch_ndim, False)]

            return np.einsum(*operands,
                             get_out_dims(batch_ndim),
                             optimize=True)
Esempio n. 2
0
    def reverse(self) -> 'Kernel':
        """Return a copy of this kernel with its spatial axes in reverse order.

        The zipped spatial axis pairs of every covariance matrix (`cov1`,
        `cov2`, `nngp`, `ntk`) are flipped, and `is_reversed` is toggled.
        `cov1`/`cov2` use 1 leading batch axis when `diagonal_batch` is set,
        otherwise 2. `nngp`/`ntk` always use 2.

        Returns:
          A `Kernel` object with spatial axes order flipped in all covariance
          matrices. For example, if `kernel.nngp` has shape
          `(batch_size_1, batch_size_2, H, H, W, W, D, D, ...)`, then
          `kernel.reverse().nngp` has shape
          `(batch_size_1, batch_size_2, ..., D, D, W, W, H, H)`.
        """
        flip = utils.reverse_zipped
        cov_batch_ndim = 2 if not self.diagonal_batch else 1

        # Computed in the same order as before: cov1, cov2, nngp, ntk.
        flipped = {
            'cov1': flip(self.cov1, cov_batch_ndim),
            'cov2': flip(self.cov2, cov_batch_ndim),
            'nngp': flip(self.nngp, 2),
            'ntk': flip(self.ntk, 2),
        }
        return self.replace(**flipped, is_reversed=not self.is_reversed)