def wrapper(actual: Tensor, expected: Tensor, **kwargs: Any) -> Optional[_TestingErrorMeta]:
    """Compare ``actual`` and ``expected``, splitting sparse COO tensors into their members.

    Whether the inputs are treated as sparse is decided by ``actual.is_sparse`` alone; non-sparse
    inputs are handed to ``check_tensors`` untouched. For sparse inputs three checks run in order:
    the number of specified values, the indices (compared exactly, i.e. ``rtol=0`` and ``atol=0``),
    and the values (compared with the caller's ``kwargs``).

    Returns:
        ``None`` when every check passes, otherwise the error meta of the first failing check,
        with a postfix naming the member (indices / values) that failed.
    """
    if not actual.is_sparse:
        return check_tensors(actual, expected, **kwargs)

    num_actual, num_expected = actual._nnz(), expected._nnz()
    if num_actual != num_expected:
        return _TestingErrorMeta(
            AssertionError,
            f"The number of specified values does not match: {num_actual} != {num_expected}",
        )

    # Indices are integral positions, so any tolerance is overridden to zero for them only.
    exact_kwargs = {**kwargs, "rtol": 0, "atol": 0}
    member_checks = (
        ("_indices", exact_kwargs, "\n\nThe failure occurred for the indices."),
        ("_values", kwargs, "\n\nThe failure occurred for the values."),
    )
    for accessor, member_kwargs, failure_postfix in member_checks:
        # Members are fetched lazily so the values are only compared after the indices matched.
        error_meta = check_tensors(
            getattr(actual, accessor)(), getattr(expected, accessor)(), **member_kwargs
        )
        if error_meta:
            return error_meta.amend_msg(postfix=failure_postfix)

    return None
def kronecker_torch(t1: Tensor, t2: Tensor) -> Tensor:
    r"""
    Compute the batched kronecker product of :math:`\mathbf{T}_1` and :math:`\mathbf{T}_2`.

    This function is implemented in torch API and is not efficient for sparse {0, 1} matrix.

    Both inputs are indexed as ``(batch, rows, cols)`` and the product is taken independently for
    every batch element, so ``t1`` and ``t2`` must have the same batch size. If both inputs are
    sparse COO tensors (with all three dimensions sparse), the result is a coalesced sparse COO
    tensor. Otherwise any sparse input is densified and the result is dense. The sparse path
    enumerates all pairs of non-zeros, so its memory grows with ``nnz(t1) * nnz(t2)``.

    :param t1: input tensor 1, of shape ``(b, m, n)``
    :param t2: input tensor 2, of shape ``(b, p, q)``
    :return: kronecker product of :math:`\mathbf{T}_1` and :math:`\mathbf{T}_2`,
        of shape ``(b, m * p, n * q)``
    """
    batch_num = t1.shape[0]
    t1dim1, t1dim2 = t1.shape[1], t1.shape[2]
    t2dim1, t2dim2 = t2.shape[1], t2.shape[2]
    out_shape = (batch_num, t1dim1 * t2dim1, t1dim2 * t2dim2)

    if t1.is_sparse and t2.is_sparse:
        t1 = t1.coalesce()
        t2 = t2.coalesce()
        idx1, val1 = t1.indices(), t1.values()  # idx rows: (batch, row, col)
        idx2, val2 = t2.indices(), t2.values()
        nnz1, nnz2 = val1.shape[0], val2.shape[0]

        # Enumerate every (entry of t1, entry of t2) pair: p1 indexes t1's non-zeros, p2 t2's.
        p1 = torch.arange(nnz1, device=idx1.device).repeat_interleave(nnz2)
        p2 = torch.arange(nnz2, device=idx2.device).repeat(nnz1)
        # The product is per batch element, so pairs from different batches contribute nothing.
        same_batch = idx1[0, p1] == idx2[0, p2]
        p1, p2 = p1[same_batch], p2[same_batch]

        # Entry (i, j) of t1 times entry (k, l) of t2 lands at (i * p + k, j * q + l).
        tt_idx = torch.stack((
            idx1[0, p1],
            idx1[1, p1] * t2dim1 + idx2[1, p2],
            idx1[2, p1] * t2dim2 + idx2[2, p2],
        ))
        tt_val = val1[p1] * val2[p2]
        tt = torch.sparse_coo_tensor(tt_idx, tt_val, out_shape).coalesce()
    else:
        # Mixed sparse/dense input: the dense path below needs reshape, which sparse lacks.
        if t1.is_sparse:
            t1 = t1.to_dense()
        if t2.is_sparse:
            t2 = t2.to_dense()
        # Batched outer product of the flattened matrices, then interleave the axes so that
        # element [b, i, j, k, l] = t1[b, i, j] * t2[b, k, l] ends up at [b, i*p + k, j*q + l].
        t1 = t1.reshape(batch_num, -1, 1)
        t2 = t2.reshape(batch_num, 1, -1)
        tt = torch.bmm(t1, t2)
        tt = tt.reshape(batch_num, t1dim1, t1dim2, t2dim1, t2dim2)
        tt = tt.permute([0, 1, 3, 2, 4])
        tt = tt.reshape(out_shape)
    return tt
def wrapper(
    actual: Tensor,
    expected: Tensor,
    msg: Optional[Union[str, Callable[[Tensor, Tensor, Diagnostics], str]]] = None,
    **kwargs: Any,
) -> Optional[_TestingErrorMeta]:
    """Compare ``actual`` and ``expected``, checking sparse COO tensors member by member.

    When ``actual`` is not sparse, both tensors and ``msg`` go straight to ``check_tensors``.
    For sparse COO inputs, the number of specified values is compared first. The indices are
    then compared exactly (``rtol=0``, ``atol=0``), and finally the values with the caller's
    tolerances.

    Args:
        actual: Actual tensor.
        expected: Expected tensor.
        msg: Optional user message (or message factory). If falsy, a mismatch message labelled
            with the failing member ("Sparse COO indices" / "Sparse COO values") is generated.
        **kwargs: Forwarded to ``check_tensors``.

    Returns:
        ``None`` if all checks pass, otherwise the error meta of the first failing check.
    """
    if not actual.is_sparse:
        return check_tensors(actual, expected, msg=msg, **kwargs)

    nnz_actual = actual._nnz()
    nnz_expected = expected._nnz()
    if nnz_actual != nnz_expected:
        return _TestingErrorMeta(
            AssertionError,
            "The number of specified values in sparse COO tensors does not match: "
            f"{nnz_actual} != {nnz_expected}",
        )

    members = (
        # Indices are integral coordinates, so they must match exactly.
        ("_indices", "Sparse COO indices", dict(kwargs, rtol=0, atol=0)),
        ("_values", "Sparse COO values", kwargs),
    )
    for accessor, identifier, member_kwargs in members:
        error_meta = check_tensors(
            getattr(actual, accessor)(),
            getattr(expected, accessor)(),
            # A caller-supplied message always wins over the generated one.
            msg=msg or functools.partial(_make_mismatch_msg, identifier=identifier),
            **member_kwargs,
        )
        if error_meta:
            return error_meta

    return None