Example #1
    def forward(ctx,
                X,
                Y,
                static_kernel,
                dyadic_order,
                sym=False,
                _naive_solver=False):

        A = X.shape[0]
        B = Y.shape[0]
        M = X.shape[1]
        N = Y.shape[1]
        D = X.shape[2]

        MM = (2**dyadic_order) * (M - 1)
        NN = (2**dyadic_order) * (N - 1)

        # computing dsdt k(X^i_s,Y^j_t)
        G_static = static_kernel.Gram_matrix(X, Y)
        G_static_ = (G_static[:, :, 1:, 1:] + G_static[:, :, :-1, :-1]
                     - G_static[:, :, 1:, :-1] - G_static[:, :, :-1, 1:])
        G_static_ = tile(tile(G_static_, 2, 2**dyadic_order) / float(2**dyadic_order),
                         3, 2**dyadic_order) / float(2**dyadic_order)

        # if on GPU
        if X.device.type == 'cuda':

            assert max(MM, NN) < 1024, \
                'dyadic_order must be lowered or the data must be moved to the CPU, ' \
                'as the current grid size exceeds the CUDA thread limit'

            # cuda parameters
            threads_per_block = max(MM + 1, NN + 1)
            n_anti_diagonals = 2 * threads_per_block - 1

            # Prepare the tensor of output solutions to the PDE (forward)
            G = torch.zeros((A, B, MM + 2, NN + 2),
                            device=G_static.device,
                            dtype=G_static.dtype)
            G[:, :, 0, :] = 1.
            G[:, :, :, 0] = 1.

            # Run the CUDA kernel.
            blockspergrid = (A, B)
            compute_sig_kernel_Gram_mat_varpar_from_increments_cuda[blockspergrid, threads_per_block](
                cuda.as_cuda_array(G_static_.detach()),
                MM + 1, NN + 1, n_anti_diagonals,
                cuda.as_cuda_array(G), _naive_solver)

            G = G[:, :, :-1, :-1]

        else:
            G = torch.tensor(sig_kernel_Gram_varpar(G_static_.detach().numpy(),
                                                    sym, _naive_solver),
                             dtype=G_static.dtype,
                             device=G_static.device)

        ctx.save_for_backward(X, Y, G, G_static)
        ctx.sym = sym
        ctx.static_kernel = static_kernel
        ctx.dyadic_order = dyadic_order
        ctx._naive_solver = _naive_solver

        return G[:, :, -1, -1]
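
For reference (standard signature-kernel background, not stated in the snippet itself): the grid G filled in above is a discretisation, refined by a factor of 2**dyadic_order, of the Goursat PDE satisfied by the signature kernel,

\[
\frac{\partial^2 k_{x,y}}{\partial s\,\partial t}(s,t)
    = \big\langle \dot{\varphi}(x)_s,\, \dot{\varphi}(y)_t \big\rangle\, k_{x,y}(s,t),
\qquad k_{x,y}(0,\cdot) = k_{x,y}(\cdot,0) = 1,
\]

where \varphi is the feature map of static_kernel. The double increment G_static[:, :, 1:, 1:] + G_static[:, :, :-1, :-1] - G_static[:, :, 1:, :-1] - G_static[:, :, :-1, 1:] is the discretised \langle d\varphi(x)_s, d\varphi(y)_t \rangle, the unit first row/column of G is the boundary condition, and the returned G[:, :, -1, -1] is k_{x,y}(1,1) for every pair of paths in the two batches.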
Example #2
    def backward(ctx, grad_output):

        X, Y, G, G_static = ctx.saved_tensors
        sym = ctx.sym
        static_kernel = ctx.static_kernel
        dyadic_order = ctx.dyadic_order
        _naive_solver = ctx._naive_solver

        G_static_ = (G_static[:, :, 1:, 1:] + G_static[:, :, :-1, :-1]
                     - G_static[:, :, 1:, :-1] - G_static[:, :, :-1, 1:])
        G_static_ = tile(tile(G_static_, 2, 2**dyadic_order) / float(2**dyadic_order),
                         3, 2**dyadic_order) / float(2**dyadic_order)

        A = X.shape[0]
        B = Y.shape[0]
        M = X.shape[1]
        N = Y.shape[1]
        D = X.shape[2]

        MM = (2**dyadic_order) * (M - 1)
        NN = (2**dyadic_order) * (N - 1)

        # Reverse paths
        X_rev = torch.flip(X, dims=[1])
        Y_rev = torch.flip(Y, dims=[1])

        # computing dsdt k(X_rev^i_s,Y_rev^j_t) for variation of parameters
        G_static_rev = flip(flip(G_static_, dim=2), dim=3)

        # if on GPU
        if X.device.type == 'cuda':

            # Prepare the tensor of output solutions to the PDE (backward)
            G_rev = torch.zeros((A, B, MM + 2, NN + 2),
                                device=G_static.device,
                                dtype=G_static.dtype)
            G_rev[:, :, 0, :] = 1.
            G_rev[:, :, :, 0] = 1.

            # cuda parameters
            threads_per_block = max(MM + 1, NN + 1)
            n_anti_diagonals = 2 * threads_per_block - 1

            # Compute signature kernel for reversed paths
            blockspergrid = (A, B)
            compute_sig_kernel_Gram_mat_varpar_from_increments_cuda[blockspergrid, threads_per_block](
                cuda.as_cuda_array(G_static_rev.detach()),
                MM + 1, NN + 1, n_anti_diagonals,
                cuda.as_cuda_array(G_rev), _naive_solver)

            G_rev = G_rev[:, :, :-1, :-1]

        # if on CPU
        else:
            G_rev = torch.tensor(
                sig_kernel_Gram_varpar(G_static_rev.detach().numpy(), sym, _naive_solver),
                dtype=G_static.dtype, device=G_static.device)

        G_rev = flip(flip(G_rev, dim=2), dim=3)
        GG = G[:, :, :-1, :-1] * G_rev[:, :, 1:, 1:]

        # finite difference step
        h = 1e-9

        # Perturb every coordinate of every point of X by h, one basis direction
        # at a time, giving A batches of M*D perturbed points.
        Xh = X[:, :, :, None] + h * torch.eye(D, dtype=X.dtype, device=X.device)[None, None, :]
        Xh = Xh.permute(0, 1, 3, 2)
        Xh = Xh.reshape(A, M * D, D)

        G_h = static_kernel.Gram_matrix(Xh, Y)
        G_h = G_h.reshape(A, B, M, D, N)
        G_h = G_h.permute(0, 1, 2, 4, 3)

        Diff_1 = (G_h[:, :, 1:, 1:, :] - G_h[:, :, 1:, :-1, :]
                  - (G_static[:, :, 1:, 1:])[:, :, :, :, None]
                  + (G_static[:, :, 1:, :-1])[:, :, :, :, None])
        Diff_1 = tile(tile(Diff_1, 2, 2**dyadic_order) / float(2**dyadic_order),
                      3, 2**dyadic_order) / float(2**dyadic_order)

        Diff_2 = (G_h[:, :, 1:, 1:, :] - G_h[:, :, 1:, :-1, :]
                  - (G_static[:, :, 1:, 1:])[:, :, :, :, None]
                  + (G_static[:, :, 1:, :-1])[:, :, :, :, None])
        Diff_2 += (-G_h[:, :, :-1, 1:, :] + G_h[:, :, :-1, :-1, :]
                   + (G_static[:, :, :-1, 1:])[:, :, :, :, None]
                   - (G_static[:, :, :-1, :-1])[:, :, :, :, None])
        Diff_2 = tile(tile(Diff_2, 2, 2**dyadic_order) / float(2**dyadic_order),
                      3, 2**dyadic_order) / float(2**dyadic_order)

        grad_1 = (GG[:, :, :, :, None] * Diff_1) / h
        grad_2 = (GG[:, :, :, :, None] * Diff_2) / h

        grad_1 = torch.sum(grad_1, axis=3)
        grad_1 = torch.sum(grad_1.reshape(A, B, M - 1, 2**dyadic_order, D),
                           axis=3)
        grad_2 = torch.sum(grad_2, axis=3)
        grad_2 = torch.sum(grad_2.reshape(A, B, M - 1, 2**dyadic_order, D),
                           axis=3)

        grad_prev = grad_1[:, :, :-1, :] + grad_2[:, :, 1:, :]  # /¯¯
        grad_next = torch.cat(
            [torch.zeros((A, B, 1, D), dtype=X.dtype, device=X.device),
             grad_1[:, :, 1:, :]], dim=2)  # /
        grad_incr = grad_prev - grad_1[:, :, 1:, :]
        grad_points = torch.cat(
            [(grad_2[:, :, 0, :] - grad_1[:, :, 0, :])[:, :, None, :],
             grad_incr, grad_1[:, :, -1, :][:, :, None, :]],
            dim=2)

        if sym:
            grad = (grad_output[:, :, None, None] * grad_points +
                    grad_output.t()[:, :, None, None] * grad_points).sum(dim=1)
            return grad, None, None, None, None, None
        else:
            grad = (grad_output[:, :, None, None] * grad_points).sum(dim=1)
            return grad, None, None, None, None, None
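
A minimal usage sketch, under stated assumptions: the forward of Example #1 and the backward of Example #2 are taken to be the @staticmethods of a torch.autograd.Function subclass, given here the hypothetical name _SigKernelGram, and the static kernel is stood in for by a simple linear kernel; neither name appears in the snippets themselves.

    import torch

    class LinearStaticKernel:
        # Hypothetical stand-in for `static_kernel`: it only needs to expose
        # Gram_matrix(X, Y) returning the (A, B, M, N) pairwise inner products
        # of the points of X (A, M, D) and Y (B, N, D).
        def Gram_matrix(self, X, Y):
            return torch.einsum('amd,bnd->abmn', X, Y)

    A, B, M, N, D = 4, 5, 10, 10, 3
    dyadic_order = 1

    X = torch.rand(A, M, D, dtype=torch.float64, requires_grad=True)
    Y = torch.rand(B, N, D, dtype=torch.float64)

    # Forward (Example #1) returns the (A, B) signature-kernel Gram matrix.
    K = _SigKernelGram.apply(X, Y, LinearStaticKernel(), dyadic_order, False, False)

    # Backward (Example #2) populates X.grad; the remaining inputs receive the
    # trailing None gradients.
    K.sum().backward()
    assert X.grad.shape == (A, M, D)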
Example #3
    def forward(ctx,
                K_XX,
                K_XY,
                K_YY,
                static_kernel,
                dyadic_order,
                lambda_,
                sym=False,
                _naive_solver=False,
                inspect=False,
                centered=False):

        A = K_XX.shape[0]
        B = K_YY.shape[0]
        M = K_XX.shape[2]
        N = K_YY.shape[2]

        MM = (2**dyadic_order) * (M - 1)
        NN = (2**dyadic_order) * (N - 1)

        # computing dsdt k(X^1[i]_s,Y^1[j]_t)
        G_base = innerprodCKME(
            K_XX,
            K_XY,
            K_YY,
            lambda_,
            static_kernel,
            sym=sym,
            centered=centered
        )  # <--------------------- this is the only change compared to rank 0

        G_base_ = (G_base[:, :, 1:, 1:] + G_base[:, :, :-1, :-1]
                   - G_base[:, :, 1:, :-1] - G_base[:, :, :-1, 1:])

        G_base_ = tile(tile(G_base_, 2, 2**dyadic_order) / float(2**dyadic_order),
                       3, 2**dyadic_order) / float(2**dyadic_order)

        # if on GPU
        if K_XX.device.type == 'cuda':

            assert max(MM, NN) < 1024, \
                'dyadic_order must be lowered or the data must be moved to the CPU, ' \
                'as the current grid size exceeds the CUDA thread limit'

            # cuda parameters
            threads_per_block = max(MM + 1, NN + 1)
            n_anti_diagonals = 2 * threads_per_block - 1

            # Prepare the tensor of output solutions to the PDE (forward)
            G = torch.zeros((A, B, MM + 2, NN + 2),
                            device=G_base.device,
                            dtype=G_base.dtype)
            G[:, :, 0, :] = 1.
            G[:, :, :, 0] = 1.

            # Run the CUDA kernel.
            blockspergrid = (A, B)
            compute_sig_kernel_Gram_mat_varpar_from_increments_cuda[blockspergrid, threads_per_block](
                cuda.as_cuda_array(G_base_.detach()),
                MM + 1, NN + 1, n_anti_diagonals,
                cuda.as_cuda_array(G), _naive_solver)
            G = G[:, :, :-1, :-1]

        else:
            G = torch.tensor(sig_kernel_Gram_varpar(G_base_.detach().numpy(),
                                                    sym, _naive_solver),
                             dtype=G_base.dtype,
                             device=G_base.device)
        if inspect:
            return G[:, :, -1, -1], G_base
        return G[:, :, -1, -1]
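
A hypothetical call sketch for Example #3 (the wrapper name _SigKernelGramCKME and the surrounding objects are assumptions, not shown in the snippet): the only change with respect to Example #1 is that the base Gram matrix comes from innerprodCKME applied to precomputed static Gram matrices K_XX, K_XY, K_YY with ridge parameter lambda_, and inspect=True additionally returns that base Gram matrix.

    # Shapes as read by the snippet: A = K_XX.shape[0], B = K_YY.shape[0],
    # M = K_XX.shape[2], N = K_YY.shape[2].
    k_vals = _SigKernelGramCKME.apply(K_XX, K_XY, K_YY, static_kernel,
                                      dyadic_order, lambda_, False, False,
                                      False, False)          # (A, B) kernel values
    k_vals, G_base = _SigKernelGramCKME.apply(K_XX, K_XY, K_YY, static_kernel,
                                              dyadic_order, lambda_, False, False,
                                              True, False)   # inspect=True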