def backward(ctx, grad_output):
        """Backward pass of the cuBLAS complex pointwise product C = A * B.

        Complex values are stored in a trailing dimension of size 2
        (real, imaginary). With B repeated along the leading (batch)
        dimensions of A, the gradients are::

            grad_A = grad_C * conj(B)
            grad_B = sum over the batch dimensions of grad_C * conj(A)

        Parameters
        ----------
        ctx : autograd context
            Holds ``(A, B)`` saved by the forward pass.
        grad_output : tensor
            Gradient with respect to C; expected to have the size of A.

        Returns
        -------
        gradA : tensor
            Gradient with respect to A, with the size of A.
        gradB : tensor
            Gradient with respect to B, with the size of B.

        Raises
        ------
        TypeError
            If the saved tensors are neither float32 nor float64.
        """
        A, B = ctx.saved_tensors

        # conj(z): copy and negate the imaginary channel.
        conjA = A.clone()
        conjB = B.clone()
        conjA[..., 1] = -A[..., 1]
        conjB[..., 1] = -B[..., 1]

        # Cdgmm reads complex64 data and Zdgmm complex128; calling the wrong
        # routine would silently reinterpret the buffers.
        if conjA.dtype == torch.float32:
            dgmm = cublas.cublasCdgmm
        elif conjA.dtype == torch.float64:
            dgmm = cublas.cublasZdgmm
        else:
            raise TypeError('cuBLAS dgmm only supports float32 and float64 '
                            'tensors')

        # m: number of complex entries in B (e.g. M*N);
        # n: how many copies of B fit in A (e.g. batch*channels).
        m, n = conjB.nelement() // 2, conjA.nelement() // conjB.nelement()
        gradC = grad_output.contiguous()  # seen by cuBLAS as (m, n), col-major

        handle = torch.cuda.current_blas_handle()
        stream = torch.cuda.current_stream()._as_parameter_
        cublas.cublasSetStream(handle, stream)
        incx = 1

        # grad_A = grad_C * conj(B): row i of the (m, n) matrix is scaled by
        # conjB[i] (left-side diagonal multiply, lda = ldc = m).
        gradA = conjA.new(conjA.size())
        dgmm(handle, 'l', m, n, gradC.data_ptr(), m,
             conjB.data_ptr(), incx, gradA.data_ptr(), m)

        # grad_C * conj(A): treat both as a single (m*n, 1) column and
        # multiply elementwise.
        gradB_ = gradC.new(gradC.size())
        dgmm(handle, 'l', m * n, 1, gradC.data_ptr(), m * n,
             conjA.data_ptr(), incx, gradB_.data_ptr(), m * n)

        # Reduce over every leading batch dimension so grad_B has exactly
        # B's size, whatever the number of batch dimensions of A.
        gradB = gradB_.reshape(-1, *conjB.size()).sum(0)

        return gradA, gradB
    def forward(ctx, A, B):
        """Complex pointwise product C = A * B computed with cuBLAS ``dgmm``.

        Parameters
        ----------
        ctx : autograd context
            ``A`` and ``B`` (made contiguous) are saved for the backward pass.
        A, B : tensor
            Complex CUDA tensors (trailing dimension of size 2) with the same
            number of elements, the same dtype (float32 or float64) and the
            same device.

        Returns
        -------
        C : tensor
            New tensor with the size of A holding the elementwise product.

        Raises
        ------
        TypeError
            If A or B is not complex, their element counts differ, or the
            dtype is neither float32 nor float64.
        RuntimeError
            If the dtypes or devices differ, or A is not a CUDA tensor.
        """
        A, B = A.contiguous(), B.contiguous()
        ctx.save_for_backward(A, B)

        if not iscomplex(A) or not iscomplex(B):
            raise TypeError('The input, filter and output should be complex')

        if A.nelement() != B.nelement():
            raise TypeError('The input and filter should have same size')

        # type(A) is torch.Tensor for every tensor, so compare dtypes instead.
        if A.dtype != B.dtype:
            raise RuntimeError('A and B should be same type!')

        if not A.is_cuda:
            raise RuntimeError('Use the torch backend for cpu tensors!')

        # cuBLAS receives raw pointers to both tensors; they must share a GPU.
        if A.device != B.device:
            raise RuntimeError('A and B must be on the same device')

        # Cdgmm reads complex64 data and Zdgmm complex128; the wrong routine
        # would silently reinterpret the buffers.
        if A.dtype == torch.float32:
            dgmm = cublas.cublasCdgmm
        elif A.dtype == torch.float64:
            dgmm = cublas.cublasZdgmm
        else:
            raise TypeError('cuBLAS dgmm only supports float32 and float64 '
                            'tensors')

        C = A.new(A.size())
        # m complex entries in B; n copies of B in A (n == 1 given the
        # element-count check above).
        m, n = B.nelement() // 2, A.nelement() // B.nelement()
        lda = ldc = m
        incx = 1
        handle = torch.cuda.current_blas_handle()
        stream = torch.cuda.current_stream()._as_parameter_
        cublas.cublasSetStream(handle, stream)
        # 'l': C = diag(B) @ A on the (m, n) column-major view.
        dgmm(handle, 'l', m, n, A.data_ptr(), lda, B.data_ptr(), incx,
             C.data_ptr(), ldc)
        return C
# Example #3
0
def cdgmm(A, B, inplace=False):
    """
        Complex pointwise multiplication between (batched) tensor A and tensor B.

        Parameters
        ----------
        A : tensor
            A is a complex tensor of size (B, C, M, N, 2)
        B : tensor
            B is a complex tensor of size (M, N, 2) or real tensor of (M, N, 1)
        inplace : boolean, optional
            if set to True, the result is written into A and A is returned

        Returns
        -------
        C : tensor
            output tensor of size (B, C, M, N, 2) such that:
            C[b, c, m, n, :] = A[b, c, m, n, :] * B[m, n, :]

        Raises
        ------
        TypeError
            If A is not complex, B is neither complex nor real, or (complex
            filter only) the dtype is neither float32 nor float64.
        RuntimeError
            If B is not a 3-tensor, the spatial sizes of A and B differ, their
            dtypes or devices differ, or (complex filter only) A is not a CUDA
            tensor.
    """
    if not iscomplex(A):
        raise TypeError('The input must be complex, indicated by a last '
                        'dimension of size 2')

    if B.ndimension() != 3:
        raise RuntimeError('The filter must be a 3-tensor, with a last '
                           'dimension of size 1 or 2 to indicate it is real '
                           'or complex, respectively')

    if not iscomplex(B) and not isreal(B):
        raise TypeError('The filter must be complex or real, indicated by a '
                        'last dimension of size 2 or 1, respectively')

    if A.size()[-3:-1] != B.size()[-3:-1]:
        raise RuntimeError('The filters are not compatible for multiplication!')

    if A.dtype != B.dtype:
        raise RuntimeError('A and B must be of the same dtype')

    if A.device != B.device:
        raise RuntimeError('A and B must be on the same device')

    if isreal(B):
        # The (M, N, 1) filter broadcasts over the real and imaginary parts.
        return A.mul_(B) if inplace else A * B

    # Complex filter: computed by cuBLAS, which only accepts device pointers.
    if not A.is_cuda:
        raise RuntimeError('Use the torch backend for cpu tensors!')

    # Cdgmm reads complex64 data and Zdgmm complex128; the wrong routine
    # would silently reinterpret the buffers.
    if A.dtype == torch.float32:
        dgmm = cublas.cublasCdgmm
    elif A.dtype == torch.float64:
        dgmm = cublas.cublasZdgmm
    else:
        raise TypeError('cuBLAS dgmm only supports float32 and float64 '
                        'tensors')

    A_c, B_c = A.contiguous(), B.contiguous()
    C = A_c if inplace else A_c.new(A_c.size())
    # A is viewed as an (m, n) column-major complex matrix: m = M*N entries
    # of B, n = number of copies of B along A's leading dims.
    m, n = B_c.nelement() // 2, A_c.nelement() // B_c.nelement()
    lda = ldc = m
    incx = 1
    handle = torch.cuda.current_blas_handle()
    stream = torch.cuda.current_stream()._as_parameter_
    cublas.cublasSetStream(handle, stream)
    dgmm(handle, 'l', m, n, A_c.data_ptr(), lda, B_c.data_ptr(), incx,
         C.data_ptr(), ldc)

    if inplace and C is not A:
        # A was not contiguous, so .contiguous() made a copy; write the result
        # back so the caller's tensor is actually updated in place.
        return A.copy_(C)
    return C
    def forward(ctx, A, B):
        """
        Complex pointwise multiplication between (batched) tensor A and tensor B.

        Parameters
        ----------
        ctx : autograd context
            A and B (made contiguous) are saved for the backward pass.
        A : tensor
            input tensor with size (B, C, M, N, 2)
        B : tensor
            B is a complex tensor of size (M, N, 2)

        Returns
        -------
        C : tensor
            output tensor of size (B, C, M, N, 2) such that:
            C[b, c, m, n, :] = A[b, c, m, n, :] * B[m, n, :]

        Raises
        ------
        RuntimeError
            If the last three dimensions of A differ from B's size, B is not a
            3-tensor, the dtypes or devices differ, or A is not a CUDA tensor.
        TypeError
            If A or B is not complex, or the dtype is neither float32 nor
            float64.
        """
        A, B = A.contiguous(), B.contiguous()
        ctx.save_for_backward(A, B)

        if A.size()[-3:] != B.size():
            raise RuntimeError(
                'The filters are not compatible for multiplication!')

        if not iscomplex(A) or not iscomplex(B):
            raise TypeError('The input, filter and output should be complex')

        if B.ndimension() != 3:
            raise RuntimeError('The filters must be simply a complex array!')

        # type(A) is torch.Tensor for every tensor, so compare dtypes instead.
        if A.dtype != B.dtype:
            raise RuntimeError('A and B should be same type!')

        if not A.is_cuda:
            raise RuntimeError('Use the torch backend for cpu tensors!')

        # cuBLAS receives raw pointers to both tensors; they must share a GPU.
        if A.device != B.device:
            raise RuntimeError('A and B must be on the same device')

        # Cdgmm reads complex64 data and Zdgmm complex128; the wrong routine
        # would silently reinterpret the buffers.
        if A.dtype == torch.float32:
            dgmm = cublas.cublasCdgmm
        elif A.dtype == torch.float64:
            dgmm = cublas.cublasZdgmm
        else:
            raise TypeError('cuBLAS dgmm only supports float32 and float64 '
                            'tensors')

        C = A.new(A.size())
        # A is viewed as an (m, n) column-major complex matrix: m = M*N
        # entries of B, n = number of copies of B along A's leading dims.
        m, n = B.nelement() // 2, A.nelement() // B.nelement()
        lda = ldc = m
        incx = 1
        handle = torch.cuda.current_blas_handle()
        stream = torch.cuda.current_stream()._as_parameter_
        cublas.cublasSetStream(handle, stream)
        # 'l': C = diag(B) @ A, i.e. entry i of every column is scaled by B[i].
        dgmm(handle, 'l', m, n, A.data_ptr(), lda, B.data_ptr(), incx,
             C.data_ptr(), ldc)
        return C
# Example #5
0
def cdgmm(A, B, jit=True, inplace=False):
    """Complex pointwise multiplication of a (batched) tensor A by a filter B.

    CUDA float32/float64 tensors go through cuBLAS ``dgmm`` (via the skcuda
    C-wrapper) when ``jit`` is True; every other case uses plain torch ops.

    Parameters
    ----------
    A : tensor
        Complex tensor of size (..., M, N, 2).
    B : tensor
        Complex tensor of size (M, N, 2).
    jit : bool, optional
        Use cuBLAS for eligible CUDA tensors.
    inplace : bool, optional
        Write the result into A and return A.

    Returns
    -------
    C : tensor
        Tensor of A's size with C[..., m, n, :] = A[..., m, n, :] * B[m, n, :]
        (complex product).

    Raises
    ------
    RuntimeError
        If the last three dimensions of A differ from B's size, B is not a
        3-tensor, or the dtypes or devices of A and B differ.
    TypeError
        If A or B is not complex.
    """
    A_out = A  # the caller's tensor: target of an in-place update
    A, B = A.contiguous(), B.contiguous()

    if A.size()[-3:] != B.size():
        raise RuntimeError(
            'The filters are not compatible for multiplication!')

    if not iscomplex(A) or not iscomplex(B):
        raise TypeError('The input, filter and output should be complex')

    if B.ndimension() != 3:
        raise RuntimeError('The filters must be simply a complex array!')

    # type(A) is torch.Tensor for every tensor, so compare dtypes instead.
    if A.dtype != B.dtype:
        raise RuntimeError('A and B should be same type!')

    if A.device != B.device:
        raise RuntimeError('A and B must be on the same device')

    # cuBLAS only handles device memory holding complex64/complex128 data.
    use_cublas = (jit and A.is_cuda
                  and A.dtype in (torch.float32, torch.float64))

    if not use_cublas:
        A_r, A_i = A[..., 0], A[..., 1]
        # The (M, N) filter planes broadcast over the leading dims of A.
        B_r, B_i = B[..., 0], B[..., 1]
        C = A.new(A.size())
        C[..., 0] = A_r * B_r - A_i * B_i
        C[..., 1] = A_r * B_i + A_i * B_r
        return A_out.copy_(C) if inplace else C

    # Cdgmm reads complex64 data and Zdgmm complex128.
    if A.dtype == torch.float32:
        dgmm = cublas.cublasCdgmm
    else:
        dgmm = cublas.cublasZdgmm

    C = A if inplace else A.new(A.size())
    # A is viewed as an (m, n) column-major complex matrix: m = M*N entries
    # of B, n = number of copies of B along A's leading dims.
    m, n = B.nelement() // 2, A.nelement() // B.nelement()
    lda = ldc = m
    incx = 1
    handle = torch.cuda.current_blas_handle()
    stream = torch.cuda.current_stream()._as_parameter_
    cublas.cublasSetStream(handle, stream)
    dgmm(handle, 'l', m, n, A.data_ptr(), lda, B.data_ptr(), incx,
         C.data_ptr(), ldc)

    if inplace and C is not A_out:
        # A_out was not contiguous: .contiguous() made a copy, so write the
        # result back into the caller's tensor.
        return A_out.copy_(C)
    return C
# Example #6
0
def cdgmm3d(A, B, inplace=False):
    """
    Pointwise multiplication of complex tensors.

    Parameters
    ----------
    A: complex tensor
    B: complex tensor whose size equals the last four dimensions of A
    inplace: bool, optional
        Write the result into A and return A.

    Returns
    -------
    output : tensor of the same size as A containing the result of the
             elementwise complex multiplication of A with B

    Raises
    ------
    RuntimeError
        If the sizes are incompatible, B is not a 4-tensor, the dtypes or
        devices differ, or A is not a CUDA tensor.
    TypeError
        If A or B is not complex, or the dtype is neither float32 nor float64.
    """
    A_out = A  # the caller's tensor: target of an in-place update
    if not A.is_contiguous():
        warnings.warn("cdgmm3d: tensor A is converted to a contiguous array")
        A = A.contiguous()
    if not B.is_contiguous():
        warnings.warn("cdgmm3d: tensor B is converted to a contiguous array")
        B = B.contiguous()

    if A.size()[-4:] != B.size():
        raise RuntimeError('The filters are not compatible for multiplication.')

    if not iscomplex(A) or not iscomplex(B):
        raise TypeError('The input, filter and output should be complex.')

    if B.ndimension() != 4:
        raise RuntimeError('The filters must be simply a complex array.')

    # type(A) is torch.Tensor for every tensor, so compare dtypes instead.
    if A.dtype != B.dtype:
        raise RuntimeError('A and B should be same type.')

    if not A.is_cuda:
        raise RuntimeError('Use the torch backend for cpu tensors.')

    # cuBLAS receives raw pointers to both tensors; they must share a GPU.
    if A.device != B.device:
        raise RuntimeError('A and B must be on the same device.')

    # Cdgmm reads complex64 data and Zdgmm complex128; the wrong routine
    # would silently reinterpret the buffers.
    if A.dtype == torch.float32:
        dgmm = cublas.cublasCdgmm
    elif A.dtype == torch.float64:
        dgmm = cublas.cublasZdgmm
    else:
        raise TypeError('cuBLAS dgmm only supports float32 and float64 '
                        'tensors.')

    C = A if inplace else A.new(A.size())
    # A is viewed as an (m, n) column-major complex matrix: m = complex
    # entries of B, n = number of copies of B along A's leading dims.
    m, n = B.nelement() // 2, A.nelement() // B.nelement()
    lda = ldc = m
    incx = 1
    handle = torch.cuda.current_blas_handle()
    stream = torch.cuda.current_stream()._as_parameter_
    cublas.cublasSetStream(handle, stream)
    dgmm(handle, 'l', m, n, A.data_ptr(), lda, B.data_ptr(), incx,
         C.data_ptr(), ldc)

    if inplace and C is not A_out:
        # A_out was not contiguous: write the result back into it.
        return A_out.copy_(C)
    return C
def cdgmm(A, B, jit=True, inplace=False):
    """Complex pointwise multiplication of a (batched) tensor A by a filter B.

    CUDA float32/float64 tensors go through cuBLAS ``dgmm`` (via the skcuda
    C-wrapper) when ``jit`` is True; every other case uses plain torch ops.

    Parameters
    ----------
    A : tensor
        Complex tensor of size (..., M, N, 2).
    B : tensor
        Complex tensor of size (M, N, 2).
    jit : bool, optional
        Use cuBLAS for eligible CUDA tensors.
    inplace : bool, optional
        Write the result into A and return A.

    Returns
    -------
    C : tensor
        Tensor of A's size with C[..., m, n, :] = A[..., m, n, :] * B[m, n, :]
        (complex product).

    Raises
    ------
    RuntimeError
        If the last three dimensions of A differ from B's size, B is not a
        3-tensor, or the dtypes or devices of A and B differ.
    TypeError
        If A or B is not complex.
    """
    A_out = A  # the caller's tensor: target of an in-place update
    A, B = A.contiguous(), B.contiguous()

    if A.size()[-3:] != B.size():
        raise RuntimeError('The filters are not compatible for multiplication!')

    if not iscomplex(A) or not iscomplex(B):
        raise TypeError('The input, filter and output should be complex')

    if B.ndimension() != 3:
        raise RuntimeError('The filters must be simply a complex array!')

    # type(A) is torch.Tensor for every tensor, so compare dtypes instead.
    if A.dtype != B.dtype:
        raise RuntimeError('A and B should be same type!')

    if A.device != B.device:
        raise RuntimeError('A and B must be on the same device')

    # cuBLAS only handles device memory holding complex64/complex128 data.
    use_cublas = (jit and A.is_cuda
                  and A.dtype in (torch.float32, torch.float64))

    if not use_cublas:
        A_r, A_i = A[..., 0], A[..., 1]
        # The (M, N) filter planes broadcast over the leading dims of A.
        B_r, B_i = B[..., 0], B[..., 1]
        C = A.new(A.size())
        C[..., 0] = A_r * B_r - A_i * B_i
        C[..., 1] = A_r * B_i + A_i * B_r
        return A_out.copy_(C) if inplace else C

    # Cdgmm reads complex64 data and Zdgmm complex128.
    if A.dtype == torch.float32:
        dgmm = cublas.cublasCdgmm
    else:
        dgmm = cublas.cublasZdgmm

    C = A if inplace else A.new(A.size())
    # A is viewed as an (m, n) column-major complex matrix: m = M*N entries
    # of B, n = number of copies of B along A's leading dims.
    m, n = B.nelement() // 2, A.nelement() // B.nelement()
    lda = ldc = m
    incx = 1
    handle = torch.cuda.current_blas_handle()
    stream = torch.cuda.current_stream()._as_parameter_
    cublas.cublasSetStream(handle, stream)
    dgmm(handle, 'l', m, n, A.data_ptr(), lda, B.data_ptr(), incx,
         C.data_ptr(), ldc)

    if inplace and C is not A_out:
        # A_out was not contiguous: .contiguous() made a copy, so write the
        # result back into the caller's tensor.
        return A_out.copy_(C)
    return C
# Example #8
0
def cublas_cdgmm(A, x, out=None):
    """Multiply a complex matrix by a complex diagonal with cuBLAS ``?dgmm``.

    ``A`` is a contiguous (R, K, 2) complex tensor (the trailing dimension
    holds real and imaginary parts). If ``x`` has K entries, column ``k`` of A
    is scaled by ``x[k]``; otherwise, if it has R entries, row ``r`` is scaled
    by ``x[r]``. When R == K the column rule is used.

    Parameters
    ----------
    A : tensor
        CUDA float32/float64 tensor of size (R, K, 2), contiguous.
    x : tensor
        Tensor of size (K, 2) or (R, 2) with the same type and device as A.
    out : tensor, optional
        Contiguous tensor of A's size that receives the result.

    Returns
    -------
    out : tensor
        out[r, k] = A[r, k] * x[k] (or A[r, k] * x[r]), as complex numbers.

    Raises
    ------
    AssertionError
        If a size, type, device or contiguity requirement is violated.
    NotImplementedError
        If A is not a CUDA float or double tensor.
    """
    def _check(cond, msg):
        # Explicit raise rather than ``assert``: ``python -O`` strips asserts,
        # and a bad size here would let cuBLAS read/write out of bounds.
        # AssertionError is kept so existing callers see the same type.
        if not cond:
            raise AssertionError(msg)

    if out is not None:
        _check(out.is_contiguous() and out.size() == A.size(),
               'out must be contiguous and have the size of A')
    else:
        out = A.new(A.size())
    _check(x.dim() == 2 and x.size(-1) == 2 and A.size(-1) == 2,
           'x must be a 2-tensor and A, x complex (last dimension of size 2)')
    _check(A.dim() == 3, 'A must be a 3-tensor')
    _check(x.size(0) == A.size(1) or x.size(0) == A.size(0),
           'x must have as many entries as A has columns or rows')
    _check(A.type() == x.type() == out.type(),
           'A, x and out must have the same tensor type')
    _check(A.device == x.device == out.device,
           'A, x and out must be on the same device')
    _check(A.is_contiguous(), 'A must be contiguous')

    if not isinstance(A, (torch.cuda.FloatTensor, torch.cuda.DoubleTensor)):
        raise NotImplementedError

    # cuBLAS reads x with unit stride (incx = 1), so it must be packed.
    x = x.contiguous()
    # A's row-major (R, K) buffer is a column-major (K, R) matrix for cuBLAS.
    m, n = A.size(1), A.size(0)
    # 'l' scales the m rows of that view (A's columns); 'r' scales its n
    # columns (A's rows).
    mode = 'l' if x.size(0) == m else 'r'
    lda = ldc = m
    incx = 1
    handle = torch.cuda.current_blas_handle()
    stream = torch.cuda.current_stream()._as_parameter_
    from skcuda import cublas
    cublas.cublasSetStream(handle, stream)
    # Cdgmm reads complex64 data, Zdgmm complex128.
    if isinstance(A, torch.cuda.FloatTensor):
        dgmm = cublas.cublasCdgmm
    else:
        dgmm = cublas.cublasZdgmm
    dgmm(handle, mode, m, n,
         A.data_ptr(), lda,
         x.data_ptr(), incx,
         out.data_ptr(), ldc)
    return out
# Example #9
0
def cdgmm3d(A, B, inplace=False):
    """Complex pointwise multiplication.

        Complex pointwise multiplication between (batched) tensor A and tensor B.

        Parameters
        ----------
        A : torch tensor
            Complex torch tensor.
        B : torch tensor
            Complex torch tensor whose shape equals the last four dimensions
            of A's shape.
        inplace : boolean, optional
            If set True, the result is written into A and A is returned.

        Raises
        ------
        RuntimeError
            In the event that the tensors are not compatible for
            multiplication (i.e. the final four dimensions of A do not match
            the dimensions of B), or in the event that B is not a 4-tensor,
            or in the event that the dtypes of A and B are not the same, or in
            the event that A is not a CUDA tensor.
        TypeError
            In the event that A or B is not complex, i.e. does not have a
            final dimension of 2, or in the event that both tensors are not on
            the same device, or in the event that the dtype is neither float32
            nor float64.

        Returns
        -------
        output : torch tensor
            Torch tensor of the same size as A containing the result of the
            elementwise complex multiplication of A with B.

    """
    A_out = A  # the caller's tensor: target of an in-place update
    if not A.is_contiguous():
        warnings.warn("cdgmm3d: tensor A is converted to a contiguous array")
        A = A.contiguous()
    if not B.is_contiguous():
        warnings.warn("cdgmm3d: tensor B is converted to a contiguous array")
        B = B.contiguous()

    if A.shape[-4:] != B.shape:
        raise RuntimeError(
            'The filters are not compatible for multiplication.')

    if not _is_complex(A) or not _is_complex(B):
        raise TypeError('The input, filter and output should be complex.')

    if B.ndimension() != 4:
        raise RuntimeError('The filters must be simply a complex array.')

    # type(A) is torch.Tensor for every tensor, so compare dtypes instead.
    if A.dtype != B.dtype:
        raise RuntimeError('A and B should be same type.')

    if not A.is_cuda:
        raise RuntimeError('Use the torch backend for CPU tensors.')

    # cuBLAS receives raw pointers to both tensors; they must share a GPU.
    if A.device != B.device:
        raise TypeError('A and B must be on the same device.')

    # Cdgmm reads complex64 data and Zdgmm complex128; the wrong routine
    # would silently reinterpret the buffers.
    if A.dtype == torch.float32:
        dgmm = cublas.cublasCdgmm
    elif A.dtype == torch.float64:
        dgmm = cublas.cublasZdgmm
    else:
        raise TypeError('cuBLAS dgmm only supports float32 and float64 '
                        'tensors.')

    C = A if inplace else A.new(A.shape)
    # A is viewed as an (m, n) column-major complex matrix: m = complex
    # entries of B, n = number of copies of B along A's leading dims.
    m, n = B.nelement() // 2, A.nelement() // B.nelement()
    lda = ldc = m
    incx = 1
    handle = torch.cuda.current_blas_handle()
    stream = torch.cuda.current_stream()._as_parameter_
    cublas.cublasSetStream(handle, stream)
    dgmm(handle, 'l', m, n, A.data_ptr(), lda, B.data_ptr(),
         incx, C.data_ptr(), ldc)

    if inplace and C is not A_out:
        # A_out was not contiguous: write the result back into it.
        return A_out.copy_(C)
    return C
# Example #10
0
def cdgmm(A, B, inplace=False):
    """Complex pointwise multiplication.

        Complex pointwise multiplication between (batched) tensor A and tensor
        B.

        Parameters
        ----------
        A : tensor
            A is a complex tensor of size (B, C, M, N, 2).
        B : tensor
            B is a complex tensor of size (M, N, 2) or real tensor of (M, N,
            1).
        inplace : boolean, optional
            If set to True, all the operations are performed in place.

        Raises
        ------
        RuntimeError
            In the event that the filter B is not a 3-tensor with a last
            dimension of size 1 or 2, or A and B are not compatible for
            multiplication, or if A or B are not contiguous.
        TypeError
            In the event that A is not complex, or B does not have a final
            dimension of 1 or 2, or A and B are not of the same dtype, or
            if A or B are not cuda tensors, or if A and B are not on the same
            device, or (complex filter only) if the dtype is neither float32
            nor float64.

        Returns
        -------
        C : tensor
            Output tensor of size (B, C, M, N, 2) such that:
            C[b, c, m, n, :] = A[b, c, m, n, :] * B[m, n, :].

    """
    if not _is_complex(A):
        raise TypeError(
            'The input should be complex (i.e. last dimension is 2).')

    if not _is_complex(B) and not _is_real(B):
        raise TypeError('The filter should be complex or real, indicated by a '
                        'last dimension of size 2 or 1, respectively.')

    # Every dim of B except the last (real/complex channel) must match A.
    if A.shape[-len(B.shape):-1] != B.shape[:-1]:
        raise RuntimeError(
            'The filters are not compatible for multiplication.')

    if A.dtype != B.dtype:
        raise TypeError('Input and filter must be of the same dtype.')

    if not A.is_cuda or not B.is_cuda:
        raise TypeError('Input and filter must be CUDA tensors.')

    if A.device.index != B.device.index:
        raise TypeError('Input and filter must be on the same GPU.')

    if _is_real(B):
        # The real filter broadcasts over the real and imaginary parts.
        return A.mul_(B) if inplace else A * B

    if not A.is_contiguous() or not B.is_contiguous():
        raise RuntimeError('Tensors must be contiguous.')

    # Cdgmm reads complex64 data and Zdgmm complex128; the wrong routine
    # would silently reinterpret the buffers.
    if A.dtype == torch.float32:
        dgmm = cublas.cublasCdgmm
    elif A.dtype == torch.float64:
        dgmm = cublas.cublasZdgmm
    else:
        raise TypeError('cuBLAS dgmm only supports float32 and float64 '
                        'tensors.')

    C = A.new(A.shape) if not inplace else A
    # A is viewed as an (m, n) column-major complex matrix: m = complex
    # entries of B, n = number of copies of B along A's leading dims.
    m, n = B.nelement() // 2, A.nelement() // B.nelement()
    lda = ldc = m
    incx = 1
    handle = torch.cuda.current_blas_handle()
    stream = torch.cuda.current_stream()._as_parameter_
    cublas.cublasSetStream(handle, stream)
    dgmm(handle, 'l', m, n, A.data_ptr(), lda, B.data_ptr(),
         incx, C.data_ptr(), ldc)
    return C