Example no. 1
    def backward(ctx, grad_output):
        dev = grad_output.device
        dtype = grad_output.dtype
        D, R, gamma, bandwidth = ctx.saved_tensors

        B = D.shape[0]
        N = D.shape[1]
        M = D.shape[2]
        threads_per_block = max(N, M)
        n_passes = 2 * threads_per_block - 1

        D_ = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev)
        D_[:, 1:N + 1, 1:M + 1] = D

        R[:, :, -1] = -math.inf
        R[:, -1, :] = -math.inf
        R[:, -1, -1] = R[:, -2, -2]

        E = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev)
        E[:, -1, -1] = 1

        # Grid and block sizes are set same as done above for the forward() call
        compute_softdtw_backward_cuda[B, threads_per_block](
            cuda.as_cuda_array(D_),
            cuda.as_cuda_array(R),
            1.0 / gamma.item(),
            bandwidth.item(),
            N,
            M,
            n_passes,
            cuda.as_cuda_array(E),
        )
        E = E[:, 1:N + 1, 1:M + 1]
        return grad_output.view(-1, 1, 1).expand_as(E) * E, None, None
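The pairs in these examples are the two halves of torch.autograd.Function subclasses: forward() stashes tensors with ctx.save_for_backward, and backward() rebuilds the local Jacobian and broadcasts the incoming grad_output over it. A minimal, self-contained sketch of that mechanic (a toy function, not the soft-DTW recursion):

    import torch

    class ToyFunction(torch.autograd.Function):

        @staticmethod
        def forward(ctx, D):
            ctx.save_for_backward(D)
            return (D ** 2).sum(dim=(1, 2))  # one scalar per batch element

        @staticmethod
        def backward(ctx, grad_output):
            D, = ctx.saved_tensors
            E = 2 * D  # local Jacobian of the forward map
            # grad_output has shape (B,); broadcast it over (B, N, M),
            # exactly as done above with grad_output.view(-1, 1, 1).expand_as(E)
            return grad_output.view(-1, 1, 1).expand_as(E) * E

    D = torch.randn(4, 5, 6, requires_grad=True)
    ToyFunction.apply(D).sum().backward()
    assert torch.allclose(D.grad, 2 * D.detach())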
Example no. 2
    def forward(ctx, X, Y, static_kernel, dyadic_order, _naive_solver=False):

        A = X.shape[0]
        M = X.shape[1]
        N = Y.shape[1]
        D = X.shape[2]

        MM = (2**dyadic_order) * (M - 1)
        NN = (2**dyadic_order) * (N - 1)

        # computing dsdt k(X^i_s,Y^i_t)
        G_static = static_kernel.batch_kernel(X, Y)
        G_static_ = (G_static[:, 1:, 1:] + G_static[:, :-1, :-1]
                     - G_static[:, 1:, :-1] - G_static[:, :-1, 1:])
        G_static_ = tile(tile(G_static_, 1, 2**dyadic_order) / float(2**dyadic_order),
                         2, 2**dyadic_order) / float(2**dyadic_order)

        # if on GPU
        if X.device.type == 'cuda':

            assert max(MM + 1, NN + 1) < 1024, \
                'n must be lowered or the data moved to CPU, as the current choice of n makes the kernel exceed the CUDA thread limit'

            # cuda parameters
            threads_per_block = max(MM + 1, NN + 1)
            n_anti_diagonals = 2 * threads_per_block - 1

            # Prepare the tensor of output solutions to the PDE (forward)
            K = torch.zeros((A, MM + 2, NN + 2),
                            device=G_static.device,
                            dtype=G_static.dtype)
            K[:, 0, :] = 1.
            K[:, :, 0] = 1.

            # Compute the forward signature kernel
            compute_sig_kernel_batch_varpar_from_increments_cuda[
                A, threads_per_block](cuda.as_cuda_array(G_static_.detach()),
                                      MM + 1, NN + 1, n_anti_diagonals,
                                      cuda.as_cuda_array(K), _naive_solver)
            K = K[:, :-1, :-1]

        # if on CPU
        else:
            K = torch.tensor(sig_kernel_batch_varpar(
                G_static_.detach().numpy(), _naive_solver),
                             dtype=G_static.dtype,
                             device=G_static.device)

        ctx.save_for_backward(X, Y, G_static, K)
        ctx.static_kernel = static_kernel
        ctx.dyadic_order = dyadic_order
        ctx._naive_solver = _naive_solver

        return K[:, -1, -1]
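Example no. 2 (and the examples below) relies on helpers tile and flip that are not shown. Their behavior can be inferred from usage: tile(a, dim, n) repeats each entry n times along dim, refining the time grid by a factor of 2**dyadic_order, and flip(x, dim) reverses a tensor along dim. A plausible minimal implementation, offered as an assumption rather than the library's actual code:

    import torch

    def tile(a, dim, n_tile):
        # Repeat every entry of `a` n_tile times along dimension `dim`,
        # e.g. [a, b] -> [a, a, b, b] for n_tile = 2.
        return torch.repeat_interleave(a, n_tile, dim=dim)

    def flip(x, dim):
        # Reverse `x` along a single dimension.
        return torch.flip(x, dims=[dim])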
Example no. 3
    def forward(ctx, D, gamma, bandwidth):
        dev = D.device
        dtype = D.dtype
        gamma = torch.cuda.FloatTensor([gamma])
        bandwidth = torch.cuda.FloatTensor([bandwidth])

        B = D.shape[0]
        N = D.shape[1]
        M = D.shape[2]
        threads_per_block = max(N, M)
        n_passes = 2 * threads_per_block - 1

        # Prepare the output array
        R = torch.ones((B, N + 2, M + 2), device=dev, dtype=dtype) * math.inf
        R[:, 0, 0] = 0

        # Run the CUDA kernel.
        # Set CUDA's grid size to be equal to the batch size (every CUDA block processes one sample pair)
        # Set the CUDA block size to be equal to the length of the longer sequence (equal to the size of the largest diagonal)
        compute_softdtw_cuda[B, threads_per_block](
            cuda.as_cuda_array(D.detach()),
            gamma.item(),
            bandwidth.item(),
            N,
            M,
            n_passes,
            cuda.as_cuda_array(R),
        )
        ctx.save_for_backward(D, R, gamma, bandwidth)
        return R[:, -2, -2]
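The launch compute_softdtw_cuda[B, threads_per_block](...) is a Numba CUDA kernel: one block per batch element, one thread per row, sweeping the n_passes anti-diagonals in lock step. A minimal sketch of that wavefront pattern (soft-min recursion only, omitting the bandwidth check; an illustration of the technique, not the exact kernel):

    import math
    from numba import cuda

    @cuda.jit
    def softdtw_wavefront_sketch(D, gamma, N, M, n_passes, R):
        b = cuda.blockIdx.x     # one block per batch element
        tid = cuda.threadIdx.x  # one thread per row of the cost matrix
        for p in range(n_passes):
            # this thread owns cell (i, j) of anti-diagonal p
            i = tid + 1
            j = p - tid + 1
            if tid < N and 0 <= p - tid < M:
                # numerically stable soft-min of the three predecessors
                r0 = -R[b, i - 1, j - 1] / gamma
                r1 = -R[b, i - 1, j] / gamma
                r2 = -R[b, i, j - 1] / gamma
                rmax = max(max(r0, r1), r2)
                rsum = (math.exp(r0 - rmax) + math.exp(r1 - rmax)
                        + math.exp(r2 - rmax))
                R[b, i, j] = D[b, i - 1, j - 1] - gamma * (math.log(rsum) + rmax)
            cuda.syncthreads()  # the whole diagonal must finish before the next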
Example no. 4
    def backward(ctx, grad_output):
        dev = grad_output.device
        dtype = grad_output.dtype
        X, D, R, gamma, warp, bandwidth = ctx.saved_tensors

        B = D.shape[0]
        N = D.shape[1]
        M = D.shape[2]
        H = X.shape[2]
        threads_per_block = max(N, M)
        n_passes = 2 * threads_per_block - 1

        D_ = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev)
        D_[:, 1:N + 1, 1:M + 1] = D

        R[:, :, -1] = -math.inf
        R[:, -1, :] = -math.inf
        R[:, -1, -1] = R[:, -2, -2]

        E = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev)
        E[:, -1, -1] = 1

        # Grid and block sizes are set same as done above for the forward() call
        compute_softdtw_backward_cuda[B, threads_per_block](
            cuda.as_cuda_array(D_), cuda.as_cuda_array(R), 1.0 / gamma.item(),
            warp.item(), bandwidth.item(), N, M, n_passes,
            cuda.as_cuda_array(E))
        E = E[:, 1:N + 1, 1:M + 1]  # dR_D

        # Jacobian product for the gradient w.r.t. X
        # See https://github.com/lyprince/sdtw_pytorch/blob/e509ef56374c83817bcf303bff102ca9636a1efe/sdtw.py#L222
        dR_X = E.matmul(torch.ones(B, M, H, dtype=dtype, device=dev)) * torch.sign(X)

        # Apply the chain rule with the incoming gradient
        return grad_output.view(-1, 1, 1).expand_as(dR_X) * dR_X, None, None, None, None
Example no. 5
    def backward(ctx, grad_output):
        dev = grad_output.device
        dtype = grad_output.dtype
        X, raw_D, D, R, gamma, warp, bandwidth = ctx.saved_tensors

        B = D.shape[0]
        N = D.shape[1]
        M = D.shape[2]
        H = X.shape[2]
        threads_per_block = max(N, M)
        n_passes = 2 * threads_per_block - 1

        D_ = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev)
        D_[:, 1:N + 1, 1:M + 1] = D

        E = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev)
        E[:, -1, -1] = 1

        G = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev)
        G[:, -1, -1] = 1

        compute_softdtw_backward_cuda[B, threads_per_block](
            cuda.as_cuda_array(D_), cuda.as_cuda_array(R), 1.0 / gamma.item(),
            warp.item(), bandwidth.item(), N, M, n_passes,
            cuda.as_cuda_array(E), cuda.as_cuda_array(G))
        G = G[:, 1:N + 1, 1:M + 1]  # dR_D

        tmp_G = G.unsqueeze(-1).expand(-1, -1, -1, H)
        tmp_G = tmp_G * torch.sign(raw_D)
        dR_X = tmp_G.sum(dim=2)

        return grad_output.view(-1, 1, 1).expand_as(dR_X) * dR_X, None, None, None, None, None
Example no. 6
    def forward(ctx, X, raw_D, D, gamma, warp, bandwidth):
        dev = D.device
        dtype = D.dtype
        gamma = torch.cuda.FloatTensor([gamma])
        warp = torch.cuda.FloatTensor([warp])
        bandwidth = torch.cuda.FloatTensor([bandwidth])

        B = D.shape[0]
        N = D.shape[1]
        M = D.shape[2]
        threads_per_block = max(N + 1, M + 1)
        n_passes = 2 * threads_per_block - 1

        D_ = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev)
        D_[:, 1:N + 1, 1:M + 1] = D

        # Prepare the output array
        R = torch.zeros((B, N + 2, M + 2), device=dev, dtype=dtype)
        R[:, :, 0] = torch.ones((B, 1), device=dev, dtype=dtype) * math.inf
        R[:, 0, :] = torch.ones((B, 1), device=dev, dtype=dtype) * math.inf
        R[:, 0, 0] = 0

        compute_softdtw_cuda[B, threads_per_block](
            cuda.as_cuda_array(D_.detach()), gamma.item(), warp.item(),
            bandwidth.item(), N + 1, M + 1, n_passes, cuda.as_cuda_array(R))
        ctx.save_for_backward(X, raw_D, D, R, gamma, warp, bandwidth)
        return R[:, -1, -1]
Example no. 7
    def backward(ctx, grad_output):

        X, Y, G, G_static = ctx.saved_tensors
        sym = ctx.sym
        static_kernel = ctx.static_kernel
        dyadic_order = ctx.dyadic_order
        _naive_solver = ctx._naive_solver

        G_static_ = (G_static[:, :, 1:, 1:] + G_static[:, :, :-1, :-1]
                     - G_static[:, :, 1:, :-1] - G_static[:, :, :-1, 1:])
        G_static_ = tile(tile(G_static_, 2, 2**dyadic_order) / float(2**dyadic_order),
                         3, 2**dyadic_order) / float(2**dyadic_order)

        A = X.shape[0]
        B = Y.shape[0]
        M = X.shape[1]
        N = Y.shape[1]
        D = X.shape[2]

        MM = (2**dyadic_order) * (M - 1)
        NN = (2**dyadic_order) * (N - 1)

        # Reverse paths
        X_rev = torch.flip(X, dims=[1])
        Y_rev = torch.flip(Y, dims=[1])

        # computing dsdt k(X_rev^i_s,Y_rev^j_t) for variation of parameters
        G_static_rev = flip(flip(G_static_, dim=2), dim=3)

        # if on GPU
        if X.device.type == 'cuda':

            # Prepare the tensor of output solutions to the PDE (backward)
            G_rev = torch.zeros((A, B, MM + 2, NN + 2),
                                device=G_static.device,
                                dtype=G_static.dtype)
            G_rev[:, :, 0, :] = 1.
            G_rev[:, :, :, 0] = 1.

            # cuda parameters
            threads_per_block = max(MM + 1, NN + 1)
            n_anti_diagonals = 2 * threads_per_block - 1

            # Compute signature kernel for reversed paths
            blockspergrid = (A, B)
            compute_sig_kernel_Gram_mat_varpar_from_increments_cuda[
                blockspergrid,
                threads_per_block](cuda.as_cuda_array(G_static_rev.detach()),
                                   MM + 1, NN + 1, n_anti_diagonals,
                                   cuda.as_cuda_array(G_rev), _naive_solver)

            G_rev = G_rev[:, :, :-1, :-1]

        # if on CPU
        else:
            G_rev = torch.tensor(sig_kernel_Gram_varpar(
                G_static_rev.detach().numpy(), sym, _naive_solver),
                                 dtype=G_static.dtype,
                                 device=G_static.device)

        G_rev = flip(flip(G_rev, dim=2), dim=3)
        GG = G[:, :, :-1, :-1] * G_rev[:, :, 1:, 1:]

        # finite difference step
        h = 1e-9

        Xh = X[:, :, :, None] + h * torch.eye(
            D, dtype=X.dtype, device=X.device)[None, None, :]
        Xh = Xh.permute(0, 1, 3, 2)
        Xh = Xh.reshape(A, M * D, D)

        G_h = static_kernel.Gram_matrix(Xh, Y)
        G_h = G_h.reshape(A, B, M, D, N)
        G_h = G_h.permute(0, 1, 2, 4, 3)

        Diff_1 = (G_h[:, :, 1:, 1:, :] - G_h[:, :, 1:, :-1, :]
                  - G_static[:, :, 1:, 1:][:, :, :, :, None]
                  + G_static[:, :, 1:, :-1][:, :, :, :, None])
        Diff_1 = tile(tile(Diff_1, 2, 2**dyadic_order) / float(2**dyadic_order),
                      3, 2**dyadic_order) / float(2**dyadic_order)
        Diff_2 = (G_h[:, :, 1:, 1:, :] - G_h[:, :, 1:, :-1, :]
                  - G_static[:, :, 1:, 1:][:, :, :, :, None]
                  + G_static[:, :, 1:, :-1][:, :, :, :, None])
        Diff_2 += (-G_h[:, :, :-1, 1:, :] + G_h[:, :, :-1, :-1, :]
                   + G_static[:, :, :-1, 1:][:, :, :, :, None]
                   - G_static[:, :, :-1, :-1][:, :, :, :, None])
        Diff_2 = tile(tile(Diff_2, 2, 2**dyadic_order) / float(2**dyadic_order),
                      3, 2**dyadic_order) / float(2**dyadic_order)

        grad_1 = (GG[:, :, :, :, None] * Diff_1) / h
        grad_2 = (GG[:, :, :, :, None] * Diff_2) / h

        grad_1 = torch.sum(grad_1, axis=3)
        grad_1 = torch.sum(grad_1.reshape(A, B, M - 1, 2**dyadic_order, D), axis=3)
        grad_2 = torch.sum(grad_2, axis=3)
        grad_2 = torch.sum(grad_2.reshape(A, B, M - 1, 2**dyadic_order, D), axis=3)

        grad_prev = grad_1[:, :, :-1, :] + grad_2[:, :, 1:, :]  # /¯¯
        grad_next = torch.cat([torch.zeros((A, B, 1, D), dtype=X.dtype, device=X.device),
                               grad_1[:, :, 1:, :]], dim=2)  # /
        grad_incr = grad_prev - grad_1[:, :, 1:, :]
        grad_points = torch.cat([(grad_2[:, :, 0, :] - grad_1[:, :, 0, :])[:, :, None, :],
                                 grad_incr,
                                 grad_1[:, :, -1, :][:, :, None, :]], dim=2)

        if sym:
            grad = (grad_output[:, :, None, None] * grad_points +
                    grad_output.t()[:, :, None, None] * grad_points).sum(dim=1)
            return grad, None, None, None, None, None
        else:
            grad = (grad_output[:, :, None, None] * grad_points).sum(dim=1)
            return grad, None, None, None, None, None
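The Xh construction above implements a finite-difference directional derivative: every coordinate d of every point X[:, m, :] is bumped by h, the static kernel is re-evaluated on the bumped paths, and Diff/h approximates the derivative of the increments. A shape-level sketch of the bumping step alone, under the same conventions:

    import torch

    A, M, D = 2, 5, 3
    X = torch.randn(A, M, D)
    h = 1e-9

    # X[:, :, :, None] is (A, M, D, 1); adding h * eye(D) bumps the d-th
    # coordinate in the d-th copy, giving (A, M, D, D)
    Xh = X[:, :, :, None] + h * torch.eye(D)[None, None, :, :]
    # fold the "which coordinate was bumped" axis into the time axis
    Xh = Xh.permute(0, 1, 3, 2).reshape(A, M * D, D)
    # row m*D + d of Xh is point X[a, m] with coordinate d bumped by h
    assert torch.allclose(Xh[0, 0 * D + 1], X[0, 0] + h * torch.eye(D)[1])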
Example no. 8
    def forward(ctx,
                X,
                Y,
                static_kernel,
                dyadic_order,
                sym=False,
                _naive_solver=False):

        A = X.shape[0]
        B = Y.shape[0]
        M = X.shape[1]
        N = Y.shape[1]
        D = X.shape[2]

        MM = (2**dyadic_order) * (M - 1)
        NN = (2**dyadic_order) * (N - 1)

        # computing dsdt k(X^i_s,Y^j_t)
        G_static = static_kernel.Gram_matrix(X, Y)
        G_static_ = (G_static[:, :, 1:, 1:] + G_static[:, :, :-1, :-1]
                     - G_static[:, :, 1:, :-1] - G_static[:, :, :-1, 1:])
        G_static_ = tile(tile(G_static_, 2, 2**dyadic_order) / float(2**dyadic_order),
                         3, 2**dyadic_order) / float(2**dyadic_order)

        # if on GPU
        if X.device.type == 'cuda':

            assert max(MM, NN) < 1024, \
                'n must be lowered or the data moved to CPU, as the current choice of n makes the kernel exceed the CUDA thread limit'

            # cuda parameters
            threads_per_block = max(MM + 1, NN + 1)
            n_anti_diagonals = 2 * threads_per_block - 1

            # Prepare the tensor of output solutions to the PDE (forward)
            G = torch.zeros((A, B, MM + 2, NN + 2),
                            device=G_static.device,
                            dtype=G_static.dtype)
            G[:, :, 0, :] = 1.
            G[:, :, :, 0] = 1.

            # Run the CUDA kernel.
            blockspergrid = (A, B)
            compute_sig_kernel_Gram_mat_varpar_from_increments_cuda[
                blockspergrid,
                threads_per_block](cuda.as_cuda_array(G_static_.detach()),
                                   MM + 1, NN + 1, n_anti_diagonals,
                                   cuda.as_cuda_array(G), _naive_solver)

            G = G[:, :, :-1, :-1]

        else:
            G = torch.tensor(sig_kernel_Gram_varpar(G_static_.detach().numpy(),
                                                    sym, _naive_solver),
                             dtype=G_static.dtype,
                             device=G_static.device)

        ctx.save_for_backward(X, Y, G, G_static)
        ctx.sym = sym
        ctx.static_kernel = static_kernel
        ctx.dyadic_order = dyadic_order
        ctx._naive_solver = _naive_solver

        return G[:, :, -1, -1]
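For reference, the recursion that the CUDA kernel parallelises over anti-diagonals can be written as a serial CPU loop. A sketch of the simple first-order scheme that _naive_solver presumably selects (the library's exact update rule is an assumption here), consistent with the boundary conditions G[:, :, 0, :] = G[:, :, :, 0] = 1 set above:

    import torch

    def sig_kernel_gram_reference(inc):
        # inc : (A, B, MM, NN) increments dsdt k(X^i_s, Y^j_t)
        A, B, MM, NN = inc.shape
        G = torch.ones(A, B, MM + 1, NN + 1, dtype=inc.dtype)
        for i in range(MM):
            for j in range(NN):
                G[:, :, i + 1, j + 1] = (G[:, :, i + 1, j] + G[:, :, i, j + 1]
                                         + G[:, :, i, j] * (inc[:, :, i, j] - 1.))
        return G[:, :, -1, -1]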
Example no. 9
    def forward(ctx,
                K_XX,
                K_XY,
                K_YY,
                static_kernel,
                dyadic_order,
                lambda_,
                sym=False,
                _naive_solver=False,
                inspect=False,
                centered=False):

        A = K_XX.shape[0]
        B = K_YY.shape[0]
        M = K_XX.shape[2]
        N = K_YY.shape[2]

        MM = (2**dyadic_order) * (M - 1)
        NN = (2**dyadic_order) * (N - 1)

        # computing dsdt k(X^1[i]_s,Y^1[j]_t)
        G_base = innerprodCKME(
            K_XX,
            K_XY,
            K_YY,
            lambda_,
            static_kernel,
            sym=sym,
            centered=centered
        )  # <--------------------- this is the only change compared to rank 0

        G_base_ = (G_base[:, :, 1:, 1:] + G_base[:, :, :-1, :-1]
                   - G_base[:, :, 1:, :-1] - G_base[:, :, :-1, 1:])

        G_base_ = tile(tile(G_base_, 2, 2**dyadic_order) / float(2**dyadic_order),
                       3, 2**dyadic_order) / float(2**dyadic_order)

        # if on GPU
        if K_XX.device.type == 'cuda':

            assert max(MM, NN) < 1024, \
                'n must be lowered or the data moved to CPU, as the current choice of n makes the kernel exceed the CUDA thread limit'

            # cuda parameters
            threads_per_block = max(MM + 1, NN + 1)
            n_anti_diagonals = 2 * threads_per_block - 1

            # Prepare the tensor of output solutions to the PDE (forward)
            G = torch.zeros((A, B, MM + 2, NN + 2),
                            device=G_base.device,
                            dtype=G_base.dtype)
            G[:, :, 0, :] = 1.
            G[:, :, :, 0] = 1.

            # Run the CUDA kernel.
            blockspergrid = (A, B)
            compute_sig_kernel_Gram_mat_varpar_from_increments_cuda[
                blockspergrid,
                threads_per_block](cuda.as_cuda_array(G_base_.detach()),
                                   MM + 1, NN + 1, n_anti_diagonals,
                                   cuda.as_cuda_array(G), _naive_solver)
            G = G[:, :, :-1, :-1]

        else:
            G = torch.tensor(sig_kernel_Gram_varpar(G_base_.detach().numpy(),
                                                    sym, _naive_solver),
                             dtype=G_base.dtype,
                             device=G_base.device)
        if inspect:
            return G[:, :, -1, -1], G_base
        return G[:, :, -1, -1]
Example no. 10
    def backward(ctx, grad_output):

        X, Y, K = ctx.saved_tensors
        n = ctx.n
        solver = ctx.solver
        rbf = ctx.rbf
        sigma = ctx.sigma

        A = X.shape[0]
        M = X.shape[1]
        N = Y.shape[1]
        D = X.shape[2]

        MM = (2**n) * (M - 1)
        NN = (2**n) * (N - 1)

        # Reverse paths
        X_rev = torch.flip(X, dims=[1])
        Y_rev = torch.flip(Y, dims=[1])

        # if on GPU
        if X.device.type == 'cuda':

            M_inc_rev = increment_matrix(X_rev,
                                         Y_rev,
                                         rbf=False,
                                         sigma=None,
                                         n=n)

            # Prepare the tensor of output solutions to the PDE (backward)
            K_rev = torch.zeros((A, MM + 2, NN + 2),
                                device=M_inc_rev.device,
                                dtype=M_inc_rev.dtype)
            K_rev[:, 0, :] = 1.
            K_rev[:, :, 0] = 1.

            # cuda parameters
            threads_per_block = max(MM + 1, NN + 1)  # match the forward pass: one thread per anti-diagonal entry
            n_anti_diagonals = 2 * threads_per_block - 1

            # Compute signature kernel for reversed paths
            compute_sig_kernel_batch_varpar_from_increments_cuda[
                A, threads_per_block](cuda.as_cuda_array(M_inc_rev.detach()),
                                      MM + 1, NN + 1, n_anti_diagonals,
                                      cuda.as_cuda_array(K_rev), solver)

            K_rev = K_rev[:, :-1, :-1]

        # if on CPU
        else:

            K_rev = sig_kernel_batch_varpar(X_rev.detach().numpy(),
                                            Y_rev.detach().numpy(),
                                            n=n,
                                            solver=solver,
                                            rbf=rbf,
                                            sigma=sigma)
            K_rev = torch.tensor(K_rev, dtype=X.dtype)

        inc_X = tile(X[:, 1:, :] - X[:, :-1, :], 1, 2**n) / float(2**n)  # (A,(2**n)*(M-1),D)  increments on the finer grid
        inc_Y = tile(Y[:, 1:, :] - Y[:, :-1, :], 1, 2**n) / float(2**n)  # (A,(2**n)*(N-1),D)  increments on the finer grid

        K_rev = flip(flip(K_rev, dim=1), dim=2)

        KK = K[:, :-1, :-1] * K_rev[:, 1:, 1:]  # (A,(2**n)*(M-1),(2**n)*(N-1))

        if rbf:
            # create linear interpolations of X and Y on finer discretization
            X_interp = torch.cat([torch.zeros(A, 1, D, dtype=X.dtype, device=X.device),
                                  torch.cumsum(inc_X, dim=1)], dim=1) + X[:, 0, :][:, None, :]
            Y_interp = torch.cat([torch.zeros(A, 1, D, dtype=Y.dtype, device=Y.device),
                                  torch.cumsum(inc_Y, dim=1)], dim=1) + Y[:, 0, :][:, None, :]

            # compute tensor k_rbf(x_s,y_t)
            Xs = torch.sum(X_interp**2, dim=2)
            Ys = torch.sum(Y_interp**2, dim=2)
            dist = -2. * torch.bmm(X_interp, Y_interp.permute(0, 2, 1))
            dist += Xs[:, :, None] + Ys[:, None, :]
            M_rbf = torch.exp(-dist / sigma)

            # form term required in variation of parameters formula (for rbf)
            term_1 = (Y_interp[:, None, 1:, :] * M_rbf[:, 1:, 1:, None]
                      - Y_interp[:, None, :-1, :] * M_rbf[:, 1:, :-1, None])
            term_2 = (X_interp[:, 1:, None, :] * M_rbf[:, 1:, 1:, None]
                      - X_interp[:, 1:, None, :] * M_rbf[:, 1:, :-1, None])

            grad_incr = 2. * KK[:, :, :, None] * (term_1 - term_2) / 4.

        else:
            grad_incr = KK[:, :, :, None] * inc_Y[:, None, :, :]  # (A,(2**n)*(M-1),(2**n)*(N-1),D)

        grad_incr = torch.sum(grad_incr, axis=2) / float(2**n)  # (A,(2**n)*(M-1),D)

        grad_incr = torch.sum(grad_incr.reshape(A, M - 1, 2**n, D), axis=2)  # (A,M-1,D)

        if Y.requires_grad:
            grad_incr *= 2

        grad_points = (-torch.cat([grad_incr,
                                   torch.zeros((A, 1, D), dtype=X.dtype, device=X.device)], dim=1)
                       + torch.cat([torch.zeros((A, 1, D), dtype=X.dtype, device=X.device),
                                    grad_incr], dim=1))

        return grad_output[:, None, None] * grad_points, None, None, None, None, None
Example no. 11
    def backward(ctx, grad_output):

        X, Y, G = ctx.saved_tensors
        n = ctx.n
        solver = ctx.solver
        sym = ctx.sym

        A = X.shape[0]
        B = Y.shape[0]
        M = X.shape[1]
        N = Y.shape[1]
        D = X.shape[2]

        MM = (2**n) * (M - 1)
        NN = (2**n) * (N - 1)

        # Reverse paths
        X_rev = torch.flip(X, dims=[1])
        Y_rev = torch.flip(Y, dims=[1])

        # if on GPU
        if X.device.type == 'cuda':

            M_inc_rev = increment_matrix_mmd(X_rev,
                                             Y_rev,
                                             rbf=False,
                                             sigma=None,
                                             n=n)

            # Prepare the tensor of output solutions to the PDE (backward)
            G_rev = torch.zeros((A, B, MM + 2, NN + 2),
                                device=M_inc_rev.device,
                                dtype=M_inc_rev.dtype)
            G_rev[:, :, 0, :] = 1.
            G_rev[:, :, :, 0] = 1.

            # cuda parameters
            threads_per_block = max(MM + 1, NN + 1)
            n_anti_diagonals = 2 * threads_per_block - 1

            # Compute signature kernel for reversed paths
            blockspergrid = (A, B)
            compute_sig_kernel_Gram_mat_varpar_from_increments_cuda[
                blockspergrid,
                threads_per_block](cuda.as_cuda_array(M_inc_rev.detach()),
                                   MM + 1, NN + 1, n_anti_diagonals,
                                   cuda.as_cuda_array(G_rev), solver)

            G_rev = G_rev[:, :, :-1, :-1]

        # if on CPU
        else:
            G_rev = sig_kernel_Gram_matrix(X_rev.detach().numpy(),
                                           Y_rev.detach().numpy(),
                                           n=n,
                                           solver=solver,
                                           sym=sym,
                                           full=True,
                                           rbf=False,
                                           sigma=None)
            G_rev = torch.tensor(G_rev, dtype=X.dtype)

        inc_Y = tile(Y[:, 1:, :] - Y[:, :-1, :], 1, 2**n) / float(
            2**n)  # (B,(2**n)*(N-1),D)  increments on the finer grid

        G_rev = flip(flip(G_rev, dim=2), dim=3)

        GG = G[:, :, :-1, :-1] * G_rev[:, :, 1:, 1:]  # (A,B,(2**n)*(M-1),(2**n)*(N-1))

        grad_incr = GG[:, :, :, :, None] * inc_Y[None, :, None, :, :]  # (A,B,(2**n)*(M-1),(2**n)*(N-1),D)

        grad_incr = (1. / (2**n)) * torch.sum(grad_incr, axis=3)  # (A,B,(2**n)*(M-1),D)

        grad_incr = torch.sum(grad_incr.reshape(A, B, M - 1, 2**n, D), axis=3)  # (A,B,M-1,D)

        grad_points = (-torch.cat([grad_incr,
                                   torch.zeros((A, B, 1, D), dtype=X.dtype, device=X.device)], dim=2)
                       + torch.cat([torch.zeros((A, B, 1, D), dtype=X.dtype, device=X.device),
                                    grad_incr], dim=2))

        if sym:
            grad = (grad_output[:, :, None, None] * grad_points +
                    grad_output.t()[:, :, None, None] * grad_points).sum(dim=1)
            return grad, None, None, None, None, None, None
        else:
            grad = (grad_output[:, :, None, None] * grad_points).sum(dim=1)
            return grad, None, None, None, None, None, None
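The grad_points expression above converts gradients with respect to increments into gradients with respect to points: since increment m is x_{m+1} - x_m, point m receives g_{m-1} - g_m, with out-of-range terms zero. A tiny self-contained check of that identity:

    import torch

    A, B, M, D = 2, 3, 5, 4
    g = torch.randn(A, B, M - 1, D)  # gradients w.r.t. the M-1 increments
    zeros = torch.zeros(A, B, 1, D)
    # same cat-based construction as above
    grad_points = -torch.cat([g, zeros], dim=2) + torch.cat([zeros, g], dim=2)
    assert torch.allclose(grad_points[:, :, 0], -g[:, :, 0])   # first point: -g_0
    assert torch.allclose(grad_points[:, :, 1], g[:, :, 0] - g[:, :, 1])
    assert torch.allclose(grad_points[:, :, -1], g[:, :, -1])  # last point: +g_{M-2}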
Example no. 12
    def forward(ctx, X, Y, n=0, solver=0, sym=False, rbf=False, sigma=1.):

        A = X.shape[0]
        B = Y.shape[0]
        M = X.shape[1]
        N = Y.shape[1]
        D = X.shape[2]

        MM = (2**n) * (M - 1)
        NN = (2**n) * (N - 1)

        if X.requires_grad:
            assert not rbf, 'Current backpropagation method only for linear signature kernel. For rbf signature kernel use naive implementation'

        # if on GPU
        if X.device.type == 'cuda':

            assert max(MM, NN) < 1024, \
                'n must be lowered or the data moved to CPU, as the current choice of n makes the kernel exceed the CUDA thread limit'

            M_inc = increment_matrix_mmd(X, Y, rbf, sigma, n)

            # cuda parameters
            threads_per_block = max(MM + 1, NN + 1)
            n_anti_diagonals = 2 * threads_per_block - 1

            # Prepare the tensor of output solutions to the PDE (forward)
            G = torch.zeros((A, B, MM + 2, NN + 2),
                            device=M_inc.device,
                            dtype=M_inc.dtype)
            G[:, :, 0, :] = 1.
            G[:, :, :, 0] = 1.

            # Run the CUDA kernel.
            blockspergrid = (A, B)
            compute_sig_kernel_Gram_mat_varpar_from_increments_cuda[
                blockspergrid,
                threads_per_block](cuda.as_cuda_array(M_inc.detach()), MM + 1,
                                   NN + 1, n_anti_diagonals,
                                   cuda.as_cuda_array(G), solver)

            G = G[:, :, :-1, :-1]

        else:
            G = sig_kernel_Gram_matrix(X.detach().numpy(),
                                       Y.detach().numpy(),
                                       n=n,
                                       solver=solver,
                                       sym=sym,
                                       full=True,
                                       rbf=rbf,
                                       sigma=sigma)
            G = torch.tensor(G, dtype=X.dtype)

        ctx.save_for_backward(X, Y, G)
        ctx.n = n
        ctx.solver = solver
        ctx.sym = sym

        return G[:, :, -1, -1]
Example no. 13
    def forward(ctx, X, Y, n=0, solver=0, rbf=False, sigma=1.):
        """
            Compute Signature Kernel and its gradients via variation of parameters. Supports both CPU and GPU.
         
            - X : 3-tensor of shape (batch, len, dim)
            - Y : 3-tensor of shape (batch, len, dim)       
        """

        A = X.shape[0]
        M = X.shape[1]
        N = Y.shape[1]
        D = X.shape[2]

        MM = (2**n) * (M - 1)
        NN = (2**n) * (N - 1)

        if X.requires_grad:
            assert not rbf, 'Current backpropagation method only for linear signature kernel. For rbf signature kernel use naive implementation'

        # if on GPU
        if X.device.type == 'cuda':

            assert max(MM + 1, NN + 1) < 1024, \
                'n must be lowered or the data moved to CPU, as the current choice of n makes the kernel exceed the CUDA thread limit'

            M_inc = increment_matrix(X, Y, rbf, sigma, n)

            # cuda parameters
            threads_per_block = max(MM + 1, NN + 1)
            n_anti_diagonals = 2 * threads_per_block - 1

            # Prepare the tensor of output solutions to the PDE (forward)
            K = torch.zeros((A, MM + 2, NN + 2),
                            device=M_inc.device,
                            dtype=M_inc.dtype)
            K[:, 0, :] = 1.
            K[:, :, 0] = 1.

            # Run the CUDA kernel to compute the forward signature kernel.
            # Set CUDA's grid size to be equal to the batch size (every CUDA block processes one sample pair)
            # Set the CUDA block size to be equal to the length of the longer sequence (equal to the size of the largest diagonal)
            compute_sig_kernel_batch_varpar_from_increments_cuda[
                A, threads_per_block](cuda.as_cuda_array(M_inc.detach()),
                                      MM + 1, NN + 1, n_anti_diagonals,
                                      cuda.as_cuda_array(K), solver)

            K = K[:, :-1, :-1]

        # if on CPU
        else:

            K = sig_kernel_batch_varpar(X.detach().numpy(),
                                        Y.detach().numpy(),
                                        n=n,
                                        solver=solver,
                                        rbf=rbf,
                                        sigma=sigma)
            K = torch.tensor(K, dtype=X.dtype)

        ctx.save_for_backward(X, Y, K)
        ctx.n = n
        ctx.solver = solver
        ctx.sigma = sigma
        ctx.rbf = rbf

        return K[:, -1, -1]
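For reference, the thread-limit assert above translates into a cap on the path length: with dyadic order n, a path of length M is refined to MM + 1 = (2**n) * (M - 1) + 1 grid points, and the kernel needs one thread per grid point (strictly fewer than 1024). A quick check of the resulting bound:

    for n in range(4):
        max_M = (1024 - 2) // (2 ** n) + 1  # largest M with (2**n)*(M-1) + 1 < 1024
        print(n, max_M)  # n=0 -> 1023, n=1 -> 512, n=2 -> 256, n=3 -> 128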
Example no. 14
    def backward(ctx, grad_output):

        X, Y, G_static, K = ctx.saved_tensors
        static_kernel = ctx.static_kernel
        dyadic_order = ctx.dyadic_order
        _naive_solver = ctx._naive_solver

        G_static_ = (G_static[:, 1:, 1:] + G_static[:, :-1, :-1]
                     - G_static[:, 1:, :-1] - G_static[:, :-1, 1:])
        G_static_ = tile(tile(G_static_, 1, 2**dyadic_order) / float(2**dyadic_order),
                         2, 2**dyadic_order) / float(2**dyadic_order)

        A = X.shape[0]
        M = X.shape[1]
        N = Y.shape[1]
        D = X.shape[2]

        MM = (2**dyadic_order) * (M - 1)
        NN = (2**dyadic_order) * (N - 1)

        # Reverse paths
        # X_rev = torch.flip(X, dims=[1])
        # Y_rev = torch.flip(Y, dims=[1])

        # computing dsdt k(X_rev^i_s,Y_rev^i_t) for variation of parameters
        G_static_rev = flip(flip(G_static_, dim=1), dim=2)

        # if on GPU
        if X.device.type == 'cuda':

            # Prepare the tensor of output solutions to the PDE (backward)
            K_rev = torch.zeros((A, MM + 2, NN + 2),
                                device=G_static_rev.device,
                                dtype=G_static_rev.dtype)
            K_rev[:, 0, :] = 1.
            K_rev[:, :, 0] = 1.

            # cuda parameters
            threads_per_block = max(MM + 1, NN + 1)  # match the forward pass: one thread per anti-diagonal entry
            n_anti_diagonals = 2 * threads_per_block - 1

            # Compute signature kernel for reversed paths
            compute_sig_kernel_batch_varpar_from_increments_cuda[
                A,
                threads_per_block](cuda.as_cuda_array(G_static_rev.detach()),
                                   MM + 1, NN + 1, n_anti_diagonals,
                                   cuda.as_cuda_array(K_rev), _naive_solver)

            K_rev = K_rev[:, :-1, :-1]

        # if on CPU
        else:
            K_rev = torch.tensor(sig_kernel_batch_varpar(
                G_static_rev.detach().numpy(), _naive_solver),
                                 dtype=G_static.dtype,
                                 device=G_static.device)

        K_rev = flip(flip(K_rev, dim=1), dim=2)
        KK = K[:, :-1, :-1] * K_rev[:, 1:, 1:]

        # finite difference step
        h = 1e-9

        Xh = X[:, :, :, None] + h * torch.eye(
            D, dtype=X.dtype, device=X.device)[None, None, :]
        Xh = Xh.permute(0, 1, 3, 2)
        Xh = Xh.reshape(A, M * D, D)

        G_h = static_kernel.batch_kernel(Xh, Y)
        G_h = G_h.reshape(A, M, D, N)
        G_h = G_h.permute(0, 1, 3, 2)

        Diff_1 = (G_h[:, 1:, 1:, :] - G_h[:, 1:, :-1, :]
                  - G_static[:, 1:, 1:][:, :, :, None]
                  + G_static[:, 1:, :-1][:, :, :, None])
        Diff_1 = tile(tile(Diff_1, 1, 2**dyadic_order) / float(2**dyadic_order),
                      2, 2**dyadic_order) / float(2**dyadic_order)
        Diff_2 = (G_h[:, 1:, 1:, :] - G_h[:, 1:, :-1, :]
                  - G_static[:, 1:, 1:][:, :, :, None]
                  + G_static[:, 1:, :-1][:, :, :, None])
        Diff_2 += (-G_h[:, :-1, 1:, :] + G_h[:, :-1, :-1, :]
                   + G_static[:, :-1, 1:][:, :, :, None]
                   - G_static[:, :-1, :-1][:, :, :, None])
        Diff_2 = tile(tile(Diff_2, 1, 2**dyadic_order) / float(2**dyadic_order),
                      2, 2**dyadic_order) / float(2**dyadic_order)

        grad_1 = (KK[:, :, :, None] * Diff_1) / h
        grad_2 = (KK[:, :, :, None] * Diff_2) / h

        grad_1 = torch.sum(grad_1, axis=2)
        grad_1 = torch.sum(grad_1.reshape(A, M - 1, 2**dyadic_order, D), axis=2)
        grad_2 = torch.sum(grad_2, axis=2)
        grad_2 = torch.sum(grad_2.reshape(A, M - 1, 2**dyadic_order, D), axis=2)

        grad_prev = grad_1[:, :-1, :] + grad_2[:, 1:, :]  # /¯¯
        grad_next = torch.cat([torch.zeros((A, 1, D), dtype=X.dtype, device=X.device),
                               grad_1[:, 1:, :]], dim=1)  # /
        grad_incr = grad_prev - grad_1[:, 1:, :]
        grad_points = torch.cat([(grad_2[:, 0, :] - grad_1[:, 0, :])[:, None, :],
                                 grad_incr,
                                 grad_1[:, -1, :][:, None, :]], dim=1)

        if Y.requires_grad:
            grad_points *= 2

        return grad_output[:, None, None] * grad_points, None, None, None, None