def dot(mat: Optional[np.ndarray],
        batch_ndim: int,
        other1: Optional[np.ndarray] = None,
        other2: Optional[np.ndarray] = None) -> Optional[np.ndarray]:
    """Contract `mat` with `other1` and/or `other2` via a single `np.einsum`.

    `other1` (if given) is placed before `mat` in the einsum operand list and
    `other2` (if given) after it. Their axis labels come from
    `get_other_dims(batch_ndim, True)` and `get_other_dims(batch_ndim, False)`
    respectively. The output labels come from `get_out_dims(batch_ndim)`.

    NOTE(review): this helper reads `self`, `get_other_dims` and
    `get_out_dims` from an enclosing scope that is not shown here, so it is
    only usable where those names are bound.

    Args:
      mat: Array to contract, or `None`.
      batch_ndim: Forwarded to `get_other_dims`, `get_out_dims` and (when
        `self.is_reversed`) to `utils.reverse_zipped`.
      other1: Optional operand placed before `mat` in the contraction.
      other2: Optional operand placed after `mat` in the contraction.

    Returns:
      `mat` itself, unchanged, when `mat` is `None`, `mat` is 0-dimensional,
      or both `other1` and `other2` are `None`. Otherwise, the result of
      `np.einsum(..., optimize=True)`.
    """
    # Nothing to contract (or nothing to contract with): pass `mat` through.
    if mat is None or mat.ndim == 0 or (other1 is None and other2 is None):
        return mat

    # Interleaved einsum form: array, axis-label list, array, axis-label list...
    operands = []

    if other1 is not None:
        operands += [other1, get_other_dims(batch_ndim, True)]

    mat_dims = list(range(mat.ndim))
    if self.is_reversed:
        # Relabel `mat`'s axes with `utils.reverse_zipped`, starting at
        # `batch_ndim`. Presumably this keeps labels aligned with a kernel whose
        # spatial axes were flipped by `reverse()` — TODO confirm.
        mat_dims = utils.reverse_zipped(mat_dims, batch_ndim)
    operands += [mat, mat_dims]

    if other2 is not None:
        operands += [other2, get_other_dims(batch_ndim, False)]

    return np.einsum(*operands, get_out_dims(batch_ndim), optimize=True)
def reverse(self) -> 'Kernel':
    """Flip the order of the spatial axes in every covariance matrix.

    Each of `cov1`, `cov2`, `nngp` and `ntk` is passed through
    `utils.reverse_zipped`. The `is_reversed` flag is toggled, and all
    updates are applied through `self.replace`.

    Returns:
      A `Kernel` whose covariance matrices have their spatial axes in reverse
      order. For example, if `kernel.nngp` has shape
      `(batch_size_1, batch_size_2, H, H, W, W, D, D, ...)`, then
      `kernel.reverse().nngp` has shape
      `(batch_size_1, batch_size_2, ..., D, D, W, W, H, H)`.
    """
    # `cov1`/`cov2` use one leading batch axis when `diagonal_batch` is set
    # and two otherwise; `nngp`/`ntk` always use two.
    cov_batch_ndim = 2 if not self.diagonal_batch else 1

    flipped_cov1, flipped_cov2 = [
        utils.reverse_zipped(cov, cov_batch_ndim)
        for cov in (self.cov1, self.cov2)
    ]
    flipped_nngp, flipped_ntk = [
        utils.reverse_zipped(k, 2) for k in (self.nngp, self.ntk)
    ]

    return self.replace(
        cov1=flipped_cov1,
        nngp=flipped_nngp,
        cov2=flipped_cov2,
        ntk=flipped_ntk,
        is_reversed=not self.is_reversed,
    )