Example #1
    @staticmethod
    def forward(ctx, X, Y, static_kernel, dyadic_order, _naive_solver=False):

        A = X.shape[0]
        M = X.shape[1]
        N = Y.shape[1]
        D = X.shape[2]

        MM = (2**dyadic_order) * (M - 1)
        NN = (2**dyadic_order) * (N - 1)

        # computing dsdt k(X^i_s,Y^i_t)
        G_static = static_kernel.batch_kernel(X, Y)
        G_static_ = G_static[:, 1:, 1:] + G_static[:, :-1, :-1] \
                    - G_static[:, 1:, :-1] - G_static[:, :-1, 1:]
        G_static_ = tile(tile(G_static_, 1, 2**dyadic_order) / float(2**dyadic_order),
                         2, 2**dyadic_order) / float(2**dyadic_order)

        # if on GPU
        if X.device.type == 'cuda':

            assert max(MM + 1, NN + 1) < 1024, \
                'the refined grid exceeds the CUDA limit of 1024 threads per block: ' \
                'lower dyadic_order (or the number of time steps), or move the data to CPU'

            # cuda parameters
            threads_per_block = max(MM + 1, NN + 1)
            n_anti_diagonals = 2 * threads_per_block - 1

            # Prepare the tensor of output solutions to the PDE (forward)
            K = torch.zeros((A, MM + 2, NN + 2),
                            device=G_static.device,
                            dtype=G_static.dtype)
            K[:, 0, :] = 1.
            K[:, :, 0] = 1.

            # Compute the forward signature kernel
            compute_sig_kernel_batch_varpar_from_increments_cuda[A, threads_per_block](
                cuda.as_cuda_array(G_static_.detach()),
                MM + 1, NN + 1, n_anti_diagonals,
                cuda.as_cuda_array(K), _naive_solver)
            K = K[:, :-1, :-1]

        # if on CPU
        else:
            K = torch.tensor(
                sig_kernel_batch_varpar(G_static_.detach().numpy(), _naive_solver),
                dtype=G_static.dtype, device=G_static.device)

        ctx.save_for_backward(X, Y, G_static, K)
        ctx.static_kernel = static_kernel
        ctx.dyadic_order = dyadic_order
        ctx._naive_solver = _naive_solver

        return K[:, -1, -1]
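
The CPU branch above calls sig_kernel_batch_varpar, a finite-difference solver for the Goursat PDE satisfied by the signature kernel on the dyadically refined grid. The following is a minimal NumPy sketch of such a solver using only the naive explicit scheme; the helper name _naive_pde_solve is illustrative, and the library's non-naive (higher-order) update rule is not reproduced here.

import numpy as np

def _naive_pde_solve(G_inc):
    # Illustrative stand-in, not the library implementation.
    # G_inc: (A, MM, NN) increments dsdt k(X_s, Y_t) on the refined grid.
    # Returns K of shape (A, MM+1, NN+1) with boundary conditions
    # K[:, 0, :] = K[:, :, 0] = 1 and the explicit update
    # K[i+1, j+1] = K[i+1, j] + K[i, j+1] + K[i, j] * (G_inc[i, j] - 1).
    A, MM, NN = G_inc.shape
    K = np.ones((A, MM + 1, NN + 1), dtype=G_inc.dtype)
    for i in range(MM):
        for j in range(NN):
            K[:, i + 1, j + 1] = (K[:, i + 1, j] + K[:, i, j + 1]
                                  + K[:, i, j] * (G_inc[:, i, j] - 1.0))
    return K

# The signature kernel value for each batch element is the corner entry
# K[:, -1, -1], which is exactly what forward() returns above.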
Example #2
    @staticmethod
    def backward(ctx, grad_output):

        X, Y, G_static, K = ctx.saved_tensors
        static_kernel = ctx.static_kernel
        dyadic_order = ctx.dyadic_order
        _naive_solver = ctx._naive_solver

        G_static_ = G_static[:, 1:, 1:] + G_static[:, :-1, :-1] \
                    - G_static[:, 1:, :-1] - G_static[:, :-1, 1:]
        G_static_ = tile(tile(G_static_, 1, 2**dyadic_order) / float(2**dyadic_order),
                         2, 2**dyadic_order) / float(2**dyadic_order)

        A = X.shape[0]
        M = X.shape[1]
        N = Y.shape[1]
        D = X.shape[2]

        MM = (2**dyadic_order) * (M - 1)
        NN = (2**dyadic_order) * (N - 1)

        # Reverse paths
        X_rev = torch.flip(X, dims=[1])
        Y_rev = torch.flip(Y, dims=[1])

        # computing dsdt k(X_rev^i_s,Y_rev^i_t) for variation of parameters
        G_static_rev = flip(flip(G_static_, dim=1), dim=2)

        # if on GPU
        if X.device.type == 'cuda':

            # Prepare the tensor of output solutions to the PDE (backward)
            K_rev = torch.zeros((A, MM + 2, NN + 2),
                                device=G_static_rev.device,
                                dtype=G_static_rev.dtype)
            K_rev[:, 0, :] = 1.
            K_rev[:, :, 0] = 1.

            # cuda parameters
            threads_per_block = max(MM + 1, NN + 1)
            n_anti_diagonals = 2 * threads_per_block - 1

            # Compute signature kernel for reversed paths
            compute_sig_kernel_batch_varpar_from_increments_cuda[A, threads_per_block](
                cuda.as_cuda_array(G_static_rev.detach()),
                MM + 1, NN + 1, n_anti_diagonals,
                cuda.as_cuda_array(K_rev), _naive_solver)

            K_rev = K_rev[:, :-1, :-1]

        # if on CPU
        else:
            K_rev = torch.tensor(
                sig_kernel_batch_varpar(G_static_rev.detach().numpy(), _naive_solver),
                dtype=G_static.dtype, device=G_static.device)

        K_rev = flip(flip(K_rev, dim=1), dim=2)
        KK = K[:, :-1, :-1] * K_rev[:, 1:, 1:]

        # finite difference step
        h = 1e-9

        Xh = X[:, :, :, None] + h * torch.eye(D, dtype=X.dtype, device=X.device)[None, None, :]
        Xh = Xh.permute(0, 1, 3, 2)
        Xh = Xh.reshape(A, M * D, D)

        G_h = static_kernel.batch_kernel(Xh, Y)
        G_h = G_h.reshape(A, M, D, N)
        G_h = G_h.permute(0, 1, 3, 2)

        Diff_1 = G_h[:, 1:, 1:, :] - G_h[:, 1:, :-1, :] \
                 - (G_static[:, 1:, 1:])[:, :, :, None] + (G_static[:, 1:, :-1])[:, :, :, None]
        Diff_1 = tile(tile(Diff_1, 1, 2**dyadic_order) / float(2**dyadic_order),
                      2, 2**dyadic_order) / float(2**dyadic_order)
        Diff_2 = G_h[:, 1:, 1:, :] - G_h[:, 1:, :-1, :] \
                 - (G_static[:, 1:, 1:])[:, :, :, None] + (G_static[:, 1:, :-1])[:, :, :, None]
        Diff_2 += -G_h[:, :-1, 1:, :] + G_h[:, :-1, :-1, :] \
                  + (G_static[:, :-1, 1:])[:, :, :, None] - (G_static[:, :-1, :-1])[:, :, :, None]
        Diff_2 = tile(tile(Diff_2, 1, 2**dyadic_order) / float(2**dyadic_order),
                      2, 2**dyadic_order) / float(2**dyadic_order)

        grad_1 = (KK[:, :, :, None] * Diff_1) / h
        grad_2 = (KK[:, :, :, None] * Diff_2) / h

        grad_1 = torch.sum(grad_1, axis=2)
        grad_1 = torch.sum(grad_1.reshape(A, M - 1, 2**dyadic_order, D), axis=2)
        grad_2 = torch.sum(grad_2, axis=2)
        grad_2 = torch.sum(grad_2.reshape(A, M - 1, 2**dyadic_order, D), axis=2)

        grad_prev = grad_1[:, :-1, :] + grad_2[:, 1:, :]  # /¯¯
        grad_next = torch.cat([torch.zeros((A, 1, D), dtype=X.dtype, device=X.device),
                               grad_1[:, 1:, :]], dim=1)  # /
        grad_incr = grad_prev - grad_1[:, 1:, :]
        grad_points = torch.cat([(grad_2[:, 0, :] - grad_1[:, 0, :])[:, None, :],
                                 grad_incr,
                                 grad_1[:, -1, :][:, None, :]], dim=1)

        if Y.requires_grad:
            grad_points *= 2

        return grad_output[:, None, None] * grad_points, None, None, None, None
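
Because backward computes the gradient with respect to X analytically (variation of parameters on the reversed paths, combined with a directional finite difference of step h = 1e-9), it is worth validating it numerically. A minimal sketch, assuming the two methods above are the @staticmethods of a torch.autograd.Function subclass passed in as SigKernelFn (an assumed name) and that static_kernel exposes a batch_kernel(X, Y) method:

import torch

def check_gradient(SigKernelFn, static_kernel, dyadic_order=1, eps=1e-4):
    # small double-precision example
    A, M, N, D = 2, 6, 7, 3
    X = torch.rand(A, M, D, dtype=torch.float64, requires_grad=True)
    Y = torch.rand(A, N, D, dtype=torch.float64)

    # analytic gradient from the custom backward above
    k = SigKernelFn.apply(X, Y, static_kernel, dyadic_order)
    k.sum().backward()
    grad_analytic = X.grad.clone()

    # central finite differences on a few coordinates of X
    for idx in [(0, 0, 0), (1, 3, 2)]:
        Xp, Xm = X.detach().clone(), X.detach().clone()
        Xp[idx] += eps
        Xm[idx] -= eps
        kp = SigKernelFn.apply(Xp, Y, static_kernel, dyadic_order).sum()
        km = SigKernelFn.apply(Xm, Y, static_kernel, dyadic_order).sum()
        print(idx, float(grad_analytic[idx]), float((kp - km) / (2 * eps)))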