def backward(ctx, grad_output):
    """Backward pass of the cuBLAS complex pointwise product C = A * B.

    Parameters
    ----------
    ctx : context object
        Holds the tensors ``A`` and ``B`` saved by the forward pass.
    grad_output : tensor
        Gradient with respect to the output ``C``; expected to have the
        same number of elements as ``A``.

    Returns
    -------
    gradA : tensor
        ``grad_output * conj(B)`` (complex), with ``A``'s shape.
    gradB : tensor
        ``grad_output * conj(A)`` (complex) summed over the ``n`` B-sized
        slices of ``A``, with ``B``'s shape.
    """
    A, B = ctx.saved_tensors

    # Complex conjugates: negate the imaginary part stored in the last axis.
    conjA = A.clone()
    conjB = B.clone()
    conjA[..., 1] = -A[..., 1]
    conjB[..., 1] = -B[..., 1]

    # m: number of complex entries in B; n: number of B-sized slices in A
    # (e.g. the B*C batch/channel dims when A is (B, C, M, N, 2)).
    m, n = conjB.nelement() // 2, conjA.nelement() // conjB.nelement()

    # cublasCdgmm operates on single-precision complex data; double-precision
    # tensors must go through cublasZdgmm or their bytes are misinterpreted.
    dgmm = cublas.cublasZdgmm if A.dtype == torch.float64 else cublas.cublasCdgmm

    handle = torch.cuda.current_blas_handle()
    stream = torch.cuda.current_stream()._as_parameter_
    cublas.cublasSetStream(handle, stream)

    gradC = grad_output.contiguous()

    # grad_A = grad_C * conj(B): view grad_C as an m x n column-major matrix
    # (leading dimension m) and left-multiply it by diag(conj(B)).
    gradA = conjA.new(conjA.size())
    dgmm(handle, 'l', m, n, gradC.data_ptr(), m,
         conjB.data_ptr(), 1, gradA.data_ptr(), m)

    # grad_B = sum over the n slices of grad_C * conj(A). First the
    # elementwise product, treating both tensors as one vector of m * n
    # complex entries...
    gradB_ = gradC.new(gradC.size())
    dgmm(handle, 'l', m * n, 1, gradC.data_ptr(), m * n,
         conjA.data_ptr(), 1, gradB_.data_ptr(), m * n)
    # ...then reduce over the n slices. This gives B's shape regardless of how
    # many leading batch dimensions A has (the previous two fixed sums over
    # dims 0 and 0 only worked when A had exactly two of them).
    gradB = gradB_.view(n, *B.size()).sum(0)

    return gradA, gradB
def forward(ctx, A, B):
    """Complex pointwise product of two complex CUDA tensors of equal size.

    Parameters
    ----------
    ctx : context object
        Autograd context; the contiguous ``A`` and ``B`` are saved on it for
        the backward pass.
    A, B : tensor
        Complex CUDA tensors (last dimension of size 2) with the same number
        of elements.

    Returns
    -------
    C : tensor
        New tensor with ``A``'s shape holding the complex product ``A * B``.

    Raises
    ------
    TypeError
        If ``A`` or ``B`` is not complex, or they differ in number of elements.
    RuntimeError
        If ``A`` and ``B`` differ in dtype or device, or ``A`` is not on CUDA.
    """
    A, B = A.contiguous(), B.contiguous()

    if not iscomplex(A) or not iscomplex(B):
        raise TypeError('The input, filter and output should be complex')
    if A.nelement() != B.nelement():
        raise TypeError('The input and filter should have same size')
    # ``type(A)`` is ``torch.Tensor`` for every tensor in current PyTorch, so
    # dtype and device are compared explicitly to catch real mismatches.
    if A.dtype != B.dtype or A.device != B.device:
        raise RuntimeError('A and B should be same type!')
    if not A.is_cuda:
        raise RuntimeError('Use the torch backend for cpu tensors!')

    # Only save once the inputs are known to be valid.
    ctx.save_for_backward(A, B)

    C = A.new(A.size())
    # A and B have the same number of entries, so n == 1: a single column of
    # m complex values scaled by diag(B).
    m, n = B.nelement() // 2, A.nelement() // B.nelement()
    # cublasCdgmm is single-precision complex; float64 needs cublasZdgmm.
    dgmm = cublas.cublasZdgmm if A.dtype == torch.float64 else cublas.cublasCdgmm
    handle = torch.cuda.current_blas_handle()
    stream = torch.cuda.current_stream()._as_parameter_
    cublas.cublasSetStream(handle, stream)
    dgmm(handle, 'l', m, n, A.data_ptr(), m, B.data_ptr(), 1, C.data_ptr(), m)
    return C
def cdgmm(A, B, inplace=False):
    """
    Complex pointwise multiplication between (batched) tensor A and tensor B.

    Parameters
    ----------
    A : tensor
        A is a complex tensor of size (B, C, M, N, 2)
    B : tensor
        B is a complex tensor of size (M, N, 2) or real tensor of (M, N, 1)
    inplace : boolean, optional
        if set to True, all the operations are performed inplace and the
        result is written into A, which is returned

    Returns
    -------
    C : tensor
        output tensor of size (B, C, M, N, 2) such that:
        C[b, c, m, n, :] = A[b, c, m, n, :] * B[m, n, :]

    Raises
    ------
    TypeError
        If A is not complex, or B is neither complex nor real.
    RuntimeError
        If B is not a 3-tensor, the (M, N) sizes of A and B differ, A and B
        differ in dtype or device, or B is complex and A is not a CUDA tensor
        (the complex product is computed with cuBLAS).
    """
    if not iscomplex(A):
        raise TypeError('The input must be complex, indicated by a last '
                        'dimension of size 2')
    if B.ndimension() != 3:
        raise RuntimeError('The filter must be a 3-tensor, with a last '
                           'dimension of size 1 or 2 to indicate it is real '
                           'or complex, respectively')
    if not iscomplex(B) and not isreal(B):
        raise TypeError('The filter must be complex or real, indicated by a '
                        'last dimension of size 2 or 1, respectively')
    if A.size()[-3:-1] != B.size()[-3:-1]:
        raise RuntimeError('The filters are not compatible for multiplication!')
    if A.dtype is not B.dtype:
        raise RuntimeError('A and B must be of the same dtype')
    if A.device != B.device:
        raise RuntimeError('A and B must be on the same device')

    if isreal(B):
        # A real filter scales real and imaginary parts alike via broadcasting.
        if inplace:
            return A.mul_(B)
        return A * B

    # Complex filter: computed with cuBLAS, which reads device pointers, so
    # CPU tensors must be rejected instead of being handed to the GPU library.
    if not A.is_cuda:
        raise RuntimeError('Use the torch backend for cpu tensors!')

    A_c, B_c = A.contiguous(), B.contiguous()
    C = A_c if inplace else A_c.new(A_c.size())
    # View A as an m x n column-major matrix (m complex entries per filter,
    # n filter-sized slices) and left-multiply it by diag(B).
    m, n = B_c.nelement() // 2, A_c.nelement() // B_c.nelement()
    # cublasCdgmm is single-precision complex; float64 needs cublasZdgmm.
    dgmm = cublas.cublasZdgmm if A.dtype == torch.float64 else cublas.cublasCdgmm
    handle = torch.cuda.current_blas_handle()
    stream = torch.cuda.current_stream()._as_parameter_
    cublas.cublasSetStream(handle, stream)
    dgmm(handle, 'l', m, n, A_c.data_ptr(), m, B_c.data_ptr(), 1,
         C.data_ptr(), m)

    if inplace and A_c is not A:
        # A was not contiguous, so cuBLAS wrote into a temporary copy; copy the
        # result back so the caller's tensor really is updated in place.
        A.copy_(A_c)
        return A
    return C
def forward(ctx, A, B):
    """
    Complex pointwise multiplication between (batched) tensor A and tensor B.

    Parameters
    ----------
    ctx : context object
        Autograd context; the contiguous A and B are saved on it for the
        backward pass.
    A : tensor
        input tensor with size (B, C, M, N, 2)
    B : tensor
        B is a complex tensor of size (M, N, 2)

    Returns
    -------
    C : tensor
        output tensor of size (B, C, M, N, 2) such that:
        C[b, c, m, n, :] = A[b, c, m, n, :] * B[m, n, :]

    Raises
    ------
    RuntimeError
        If the last three dimensions of A do not match B, B is not a
        3-tensor, A and B differ in dtype or device, or A is not on CUDA.
    TypeError
        If A or B is not complex.
    """
    A, B = A.contiguous(), B.contiguous()

    if A.size()[-3:] != B.size():
        raise RuntimeError(
            'The filters are not compatible for multiplication!')
    if not iscomplex(A) or not iscomplex(B):
        raise TypeError('The input, filter and output should be complex')
    if B.ndimension() != 3:
        raise RuntimeError('The filters must be simply a complex array!')
    # ``type(A)`` is ``torch.Tensor`` for every tensor in current PyTorch, so
    # dtype and device are compared explicitly to catch real mismatches.
    if A.dtype != B.dtype or A.device != B.device:
        raise RuntimeError('A and B should be same type!')
    if not A.is_cuda:
        raise RuntimeError('Use the torch backend for cpu tensors!')

    # Only save once the inputs are known to be valid.
    ctx.save_for_backward(A, B)

    C = A.new(A.size())
    # View A as an m x n column-major matrix (m complex entries per filter,
    # n = B*C filter-sized slices) and left-multiply it by diag(B).
    m, n = B.nelement() // 2, A.nelement() // B.nelement()
    # cublasCdgmm is single-precision complex; float64 needs cublasZdgmm.
    dgmm = cublas.cublasZdgmm if A.dtype == torch.float64 else cublas.cublasCdgmm
    handle = torch.cuda.current_blas_handle()
    stream = torch.cuda.current_stream()._as_parameter_
    cublas.cublasSetStream(handle, stream)
    dgmm(handle, 'l', m, n, A.data_ptr(), m, B.data_ptr(), 1, C.data_ptr(), m)
    return C
def cdgmm(A, B, jit=True, inplace=False):
    """Complex pointwise multiplication of a batched tensor A by a filter B.

    For CUDA tensors with ``jit=True`` the product is computed with cuBLAS
    through the C-wrapper; otherwise plain torch arithmetic is used.

    Parameters
    ----------
    A : tensor
        Complex tensor whose last three dimensions equal ``B``'s size,
        e.g. (B, C, M, N, 2).
    B : tensor
        Complex 3-tensor of size (M, N, 2).
    jit : boolean, optional
        If False, always use the torch-arithmetic path.
    inplace : boolean, optional
        If True, the result is written into A and A is returned.

    Returns
    -------
    C : tensor
        Tensor shaped like A with C[..., m, n, :] = A[..., m, n, :] * B[m, n, :].

    Raises
    ------
    RuntimeError
        If the last three dimensions of A do not match B, B is not a
        3-tensor, or A and B differ in dtype or device.
    TypeError
        If A or B is not complex.
    """
    # Keep the caller's tensor: ``contiguous()`` may return a copy, and an
    # in-place result must land in the original tensor.
    A_caller = A
    A, B = A.contiguous(), B.contiguous()
    if A.size()[-3:] != B.size():
        raise RuntimeError(
            'The filters are not compatible for multiplication!')
    if not iscomplex(A) or not iscomplex(B):
        raise TypeError('The input, filter and output should be complex')
    if B.ndimension() != 3:
        raise RuntimeError('The filters must be simply a complex array!')
    # ``type(A)`` is ``torch.Tensor`` for every tensor in current PyTorch, so
    # dtype and device are compared explicitly to catch real mismatches.
    if A.dtype != B.dtype or A.device != B.device:
        raise RuntimeError('A and B should be same type!')

    if not jit or not A.is_cuda:
        # Torch arithmetic: B's (M, N) real/imaginary planes broadcast over
        # A's leading batch dimensions (the former flattened views could not
        # be copied back into C's shape under broadcasting rules).
        A_r, A_i = A[..., 0], A[..., 1]
        B_r, B_i = B[..., 0], B[..., 1]
        C = A.new(A.size())
        C[..., 0] = A_r * B_r - A_i * B_i
        C[..., 1] = A_r * B_i + A_i * B_r
    else:
        C = A if inplace else A.new(A.size())
        # View A as an m x n column-major matrix (m complex entries per
        # filter, n filter-sized slices) and left-multiply it by diag(B).
        m, n = B.nelement() // 2, A.nelement() // B.nelement()
        # cublasCdgmm is single-precision complex; float64 needs cublasZdgmm.
        dgmm = cublas.cublasZdgmm if A.dtype == torch.float64 else cublas.cublasCdgmm
        handle = torch.cuda.current_blas_handle()
        stream = torch.cuda.current_stream()._as_parameter_
        cublas.cublasSetStream(handle, stream)
        dgmm(handle, 'l', m, n, A.data_ptr(), m, B.data_ptr(), 1,
             C.data_ptr(), m)

    if not inplace:
        return C
    if C is not A_caller:
        A_caller.copy_(C)
    return A_caller
def cdgmm3d(A, B, inplace=False):
    """
    Pointwise multiplication of complex tensors.

    Parameters
    ----------
    A: complex tensor
    B: complex 4-tensor matching the last four dimensions of A
    inplace: boolean, optional
        If True, the result is written into A, which is returned.

    Returns
    -------
    output : tensor of the same size as A containing the result of the
        elementwise complex multiplication of A with B (broadcast over
        A's leading dimensions)

    Raises
    ------
    RuntimeError
        If the last four dimensions of A do not match B, B is not a
        4-tensor, A and B differ in dtype or device, or A is not on CUDA.
    TypeError
        If A or B is not complex.
    """
    # Keep the caller's tensor so an in-place result reaches it even when a
    # contiguous copy has to be made below.
    A_caller = A
    if not A.is_contiguous():
        warnings.warn("cdgmm3d: tensor A is converted to a contiguous array")
        A = A.contiguous()
    if not B.is_contiguous():
        warnings.warn("cdgmm3d: tensor B is converted to a contiguous array")
        B = B.contiguous()

    if A.size()[-4:] != B.size():
        raise RuntimeError('The filters are not compatible for multiplication.')
    if not iscomplex(A) or not iscomplex(B):
        raise TypeError('The input, filter and output should be complex.')
    if B.ndimension() != 4:
        raise RuntimeError('The filters must be simply a complex array.')
    # ``type(A)`` is ``torch.Tensor`` for every tensor in current PyTorch, so
    # dtype and device are compared explicitly to catch real mismatches.
    if A.dtype != B.dtype or A.device != B.device:
        raise RuntimeError('A and B should be same type.')
    if not A.is_cuda:
        raise RuntimeError('Use the torch backend for cpu tensors.')

    C = A if inplace else A.new(A.size())
    # View A as an m x n column-major matrix (m complex entries per filter,
    # n filter-sized slices) and left-multiply it by diag(B).
    m, n = B.nelement() // 2, A.nelement() // B.nelement()
    # cublasCdgmm is single-precision complex; float64 needs cublasZdgmm.
    dgmm = cublas.cublasZdgmm if A.dtype == torch.float64 else cublas.cublasCdgmm
    handle = torch.cuda.current_blas_handle()
    stream = torch.cuda.current_stream()._as_parameter_
    cublas.cublasSetStream(handle, stream)
    dgmm(handle, 'l', m, n, A.data_ptr(), m, B.data_ptr(), 1, C.data_ptr(), m)

    if inplace and C is not A_caller:
        # The product was written into a contiguous copy of A; copy it back.
        A_caller.copy_(C)
        return A_caller
    return C
def cdgmm(A, B, jit=True, inplace=False):
    """Complex pointwise multiplication of a batched tensor A by a filter B.

    For CUDA tensors with ``jit=True`` the product is computed with cuBLAS
    through the C-wrapper; otherwise plain torch arithmetic is used.

    Parameters
    ----------
    A : tensor
        Complex tensor whose last three dimensions equal ``B``'s size,
        e.g. (B, C, M, N, 2).
    B : tensor
        Complex 3-tensor of size (M, N, 2).
    jit : boolean, optional
        If False, always use the torch-arithmetic path.
    inplace : boolean, optional
        If True, the result is written into A and A is returned.

    Returns
    -------
    C : tensor
        Tensor shaped like A with C[..., m, n, :] = A[..., m, n, :] * B[m, n, :].

    Raises
    ------
    RuntimeError
        If the last three dimensions of A do not match B, B is not a
        3-tensor, or A and B differ in dtype or device.
    TypeError
        If A or B is not complex.
    """
    # Keep the caller's tensor: ``contiguous()`` may return a copy, and an
    # in-place result must land in the original tensor.
    A_caller = A
    A, B = A.contiguous(), B.contiguous()
    if A.size()[-3:] != B.size():
        raise RuntimeError('The filters are not compatible for multiplication!')
    if not iscomplex(A) or not iscomplex(B):
        raise TypeError('The input, filter and output should be complex')
    if B.ndimension() != 3:
        raise RuntimeError('The filters must be simply a complex array!')
    # ``type(A)`` is ``torch.Tensor`` for every tensor in current PyTorch, so
    # dtype and device are compared explicitly to catch real mismatches.
    if A.dtype != B.dtype or A.device != B.device:
        raise RuntimeError('A and B should be same type!')

    if not jit or not A.is_cuda:
        # Torch arithmetic: B's (M, N) real/imaginary planes broadcast over
        # A's leading batch dimensions (the former flattened views could not
        # be copied back into C's shape under broadcasting rules).
        A_r, A_i = A[..., 0], A[..., 1]
        B_r, B_i = B[..., 0], B[..., 1]
        C = A.new(A.size())
        C[..., 0] = A_r * B_r - A_i * B_i
        C[..., 1] = A_r * B_i + A_i * B_r
    else:
        C = A if inplace else A.new(A.size())
        # View A as an m x n column-major matrix (m complex entries per
        # filter, n filter-sized slices) and left-multiply it by diag(B).
        m, n = B.nelement() // 2, A.nelement() // B.nelement()
        # cublasCdgmm is single-precision complex; float64 needs cublasZdgmm.
        dgmm = cublas.cublasZdgmm if A.dtype == torch.float64 else cublas.cublasCdgmm
        handle = torch.cuda.current_blas_handle()
        stream = torch.cuda.current_stream()._as_parameter_
        cublas.cublasSetStream(handle, stream)
        dgmm(handle, 'l', m, n, A.data_ptr(), m, B.data_ptr(), 1,
             C.data_ptr(), m)

    if not inplace:
        return C
    if C is not A_caller:
        A_caller.copy_(C)
    return A_caller
def cublas_cdgmm(A, x, out=None):
    """Multiply a complex matrix by a complex vector along one of its axes.

    Wraps cuBLAS ``Cdgmm``/``Zdgmm`` through ``skcuda``. ``A`` is a complex
    matrix of shape (R, K, 2) and ``x`` a complex vector of shape (L, 2):

    * if ``L == K``: ``out[r, k] = A[r, k] * x[k]``
    * otherwise (``L == R``): ``out[r, k] = A[r, k] * x[r]``

    (complex products). When ``R == K`` the first rule is used.

    Parameters
    ----------
    A : tensor
        Contiguous float32 or float64 CUDA tensor of size (R, K, 2).
    x : tensor
        Tensor of size (K, 2) or (R, 2) with the same type as A. It is made
        contiguous if needed, since cuBLAS reads it with unit stride.
    out : tensor, optional
        Contiguous output buffer with A's size and type; allocated if None.

    Returns
    -------
    out : tensor
        The product, with A's size.

    Raises
    ------
    AssertionError
        If the shape, type or contiguity requirements above are not met.
    NotImplementedError
        If A is not a float32/float64 CUDA tensor.
    """
    # Explicit raises instead of ``assert`` so the checks survive
    # ``python -O``; AssertionError is kept so callers see the same type.
    if out is not None:
        if not (out.is_contiguous() and out.size() == A.size()):
            raise AssertionError('out must be contiguous and of the same size as A')
    else:
        out = A.new(A.size())
    if not (x.dim() == 2 and x.size(-1) == 2 and A.size(-1) == 2):
        raise AssertionError('A and x must be complex, with x a vector')
    if A.dim() != 3:
        raise AssertionError('A must be a complex matrix (a 3-tensor)')
    if x.size(0) not in (A.size(1), A.size(0)):
        raise AssertionError('x must match one of the matrix dimensions of A')
    if not A.type() == x.type() == out.type():
        raise AssertionError('A, x and out must have the same type')
    if not A.is_contiguous():
        raise AssertionError('A must be contiguous')

    if not A.is_cuda or A.dtype not in (torch.float32, torch.float64):
        raise NotImplementedError

    # cuBLAS reads x with incx=1, so a strided view would be silently misread.
    x = x.contiguous()

    # Row-major (R, K) is column-major K x R with leading dimension K.
    m, n = A.size(1), A.size(0)
    # 'l': out = diag(x) @ A (x has m entries); 'r': out = A @ diag(x).
    mode = 'l' if x.size(0) == m else 'r'

    handle = torch.cuda.current_blas_handle()
    stream = torch.cuda.current_stream()._as_parameter_
    from skcuda import cublas
    cublas.cublasSetStream(handle, stream)
    dgmm = cublas.cublasCdgmm if A.dtype == torch.float32 else cublas.cublasZdgmm
    dgmm(handle, mode, m, n, A.data_ptr(), m, x.data_ptr(), 1,
         out.data_ptr(), m)
    return out
def cdgmm3d(A, B, inplace=False):
    """Complex pointwise multiplication.

    Complex pointwise multiplication between (batched) tensor A and tensor B.

    Parameters
    ----------
    A : torch tensor
        Complex torch tensor.
    B : torch tensor
        Complex 4-tensor matching the final four dimensions of A.
    inplace : boolean, optional
        If set True, all the operations are performed inplace and the
        result is written into A, which is returned.

    Raises
    ------
    RuntimeError
        In the event that the tensors are not compatible for
        multiplication (i.e. the final four dimensions of A do not match
        with the dimensions of B), or in the event that B is not a
        4-tensor, or in the event that A and B differ in dtype, or in the
        event that A is not a CUDA tensor.
    TypeError
        In the event that A or B is not complex i.e. does not have a final
        dimension of 2, or in the event that both tensors are not on the
        same device.

    Returns
    -------
    output : torch tensor
        Torch tensor of the same size as A containing the result of the
        elementwise complex multiplication of A with B.
    """
    # Keep the caller's tensor so an in-place result reaches it even when a
    # contiguous copy has to be made below.
    A_caller = A
    if not A.is_contiguous():
        warnings.warn("cdgmm3d: tensor A is converted to a contiguous array")
        A = A.contiguous()
    if not B.is_contiguous():
        warnings.warn("cdgmm3d: tensor B is converted to a contiguous array")
        B = B.contiguous()

    if A.shape[-4:] != B.shape:
        raise RuntimeError(
            'The filters are not compatible for multiplication.')
    if not _is_complex(A) or not _is_complex(B):
        raise TypeError('The input, filter and output should be complex.')
    if B.ndimension() != 4:
        raise RuntimeError('The filters must be simply a complex array.')
    # ``type(A)`` is ``torch.Tensor`` for every tensor in current PyTorch, so
    # the dtype is compared explicitly.
    if A.dtype != B.dtype:
        raise RuntimeError('A and B should be same type.')
    if not A.is_cuda:
        raise RuntimeError('Use the torch backend for CPU tensors.')
    if A.device != B.device:
        raise TypeError('A and B should be on the same device.')

    C = A if inplace else A.new(A.shape)
    # View A as an m x n column-major matrix (m complex entries per filter,
    # n filter-sized slices) and left-multiply it by diag(B).
    m, n = B.nelement() // 2, A.nelement() // B.nelement()
    # cublasCdgmm is single-precision complex; float64 needs cublasZdgmm.
    dgmm = cublas.cublasZdgmm if A.dtype == torch.float64 else cublas.cublasCdgmm
    handle = torch.cuda.current_blas_handle()
    stream = torch.cuda.current_stream()._as_parameter_
    cublas.cublasSetStream(handle, stream)
    dgmm(handle, 'l', m, n, A.data_ptr(), m, B.data_ptr(), 1, C.data_ptr(), m)

    if inplace and C is not A_caller:
        # The product was written into a contiguous copy of A; copy it back.
        A_caller.copy_(C)
        return A_caller
    return C
def cdgmm(A, B, inplace=False):
    """Complex pointwise multiplication.

    Complex pointwise multiplication between (batched) tensor A and tensor B.

    Parameters
    ----------
    A : tensor
        A is a complex tensor of size (B, C, M, N, 2).
    B : tensor
        B is a complex tensor of size (M, N, 2) or real tensor of (M, N, 1).
    inplace : boolean, optional
        If set to True, all the operations are performed in place.

    Raises
    ------
    RuntimeError
        In the event that A and B are not compatible for multiplication,
        or if B is complex and A or B are not contiguous.
    TypeError
        In the event that A is not complex, or B does not have a final
        dimension of 1 or 2, or A and B are not of the same dtype, or if
        A or B are not cuda tensors, or if A and B are not on the same
        device.

    Returns
    -------
    C : tensor
        Output tensor of size (B, C, M, N, 2) such that:
        C[b, c, m, n, :] = A[b, c, m, n, :] * B[m, n, :].
    """
    if not _is_complex(A):
        raise TypeError(
            'The input should be complex (i.e. last dimension is 2).')
    if not _is_complex(B) and not _is_real(B):
        raise TypeError('The filter should be complex or real, indicated by a '
                        'last dimension of size 2 or 1, respectively.')
    if A.shape[-len(B.shape):-1] != B.shape[:-1]:
        raise RuntimeError(
            'The filters are not compatible for multiplication.')
    if A.dtype is not B.dtype:
        raise TypeError('Input and filter must be of the same dtype.')
    if not A.is_cuda or not B.is_cuda:
        raise TypeError('Input and filter must be CUDA tensors.')
    if A.device.index != B.device.index:
        raise TypeError('Input and filter must be on the same GPU.')

    if _is_real(B):
        # A real filter scales real and imaginary parts alike via broadcasting.
        if inplace:
            return A.mul_(B)
        return A * B

    if not A.is_contiguous() or not B.is_contiguous():
        raise RuntimeError('Tensors must be contiguous.')

    C = A if inplace else A.new(A.shape)
    # View A as an m x n column-major matrix (m complex entries per filter,
    # n filter-sized slices) and left-multiply it by diag(B).
    m, n = B.nelement() // 2, A.nelement() // B.nelement()
    # cublasCdgmm is single-precision complex; float64 data must go through
    # cublasZdgmm or the double-precision bytes are misinterpreted.
    dgmm = cublas.cublasZdgmm if A.dtype == torch.float64 else cublas.cublasCdgmm
    handle = torch.cuda.current_blas_handle()
    stream = torch.cuda.current_stream()._as_parameter_
    cublas.cublasSetStream(handle, stream)
    dgmm(handle, 'l', m, n, A.data_ptr(), m, B.data_ptr(), 1, C.data_ptr(), m)
    return C