import torch
from numba import cuda

# `sig_kernel_batch_varpar` (CPU solver) and
# `compute_sig_kernel_batch_varpar_from_increments_cuda` (CUDA solver) are
# assumed to be provided by the library's Cython/CUDA backend modules;
# sketches of the `tile`/`flip` helpers appear after `forward` below.


def forward(ctx, X, Y, static_kernel, dyadic_order, _naive_solver=False):
    A = X.shape[0]  # batch size
    M = X.shape[1]  # length of paths X
    N = Y.shape[1]  # length of paths Y
    D = X.shape[2]  # state-space dimension
    MM = (2**dyadic_order) * (M - 1)  # refined grid size along s
    NN = (2**dyadic_order) * (N - 1)  # refined grid size along t

    # computing dsdt k(X^i_s, Y^i_t)
    G_static = static_kernel.batch_kernel(X, Y)
    G_static_ = G_static[:, 1:, 1:] + G_static[:, :-1, :-1] \
        - G_static[:, 1:, :-1] - G_static[:, :-1, 1:]
    G_static_ = tile(tile(G_static_, 1, 2**dyadic_order) / float(2**dyadic_order),
                     2, 2**dyadic_order) / float(2**dyadic_order)

    # if on GPU
    if X.device.type == 'cuda':
        assert max(MM + 1, NN + 1) < 1024, \
            'the refined grid exceeds the CUDA limit of 1024 threads per block: ' \
            'lower dyadic_order or move the data to CPU'

        # cuda parameters
        threads_per_block = max(MM + 1, NN + 1)
        n_anti_diagonals = 2 * threads_per_block - 1

        # Prepare the tensor of output solutions to the PDE (forward)
        K = torch.zeros((A, MM + 2, NN + 2), device=G_static.device, dtype=G_static.dtype)
        K[:, 0, :] = 1.
        K[:, :, 0] = 1.

        # Compute the forward signature kernel
        compute_sig_kernel_batch_varpar_from_increments_cuda[A, threads_per_block](
            cuda.as_cuda_array(G_static_.detach()),
            MM + 1, NN + 1, n_anti_diagonals,
            cuda.as_cuda_array(K), _naive_solver)
        K = K[:, :-1, :-1]

    # if on CPU
    else:
        K = torch.tensor(sig_kernel_batch_varpar(G_static_.detach().numpy(), _naive_solver),
                         dtype=G_static.dtype, device=G_static.device)

    ctx.save_for_backward(X, Y, G_static, K)
    ctx.static_kernel = static_kernel
    ctx.dyadic_order = dyadic_order
    ctx._naive_solver = _naive_solver

    return K[:, -1, -1]
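# ---------------------------------------------------------------------------
# Helper sketch (an assumption, not necessarily the library's exact code):
# the solver above relies on `tile` and, in the backward pass, `flip`. From
# their call sites, `tile(a, dim, n)` must repeat each slice of `a` n times
# along `dim` (refining the PDE grid by a factor of 2**dyadic_order), and
# `flip(a, dim)` must reverse `a` along `dim`. Minimal implementations
# consistent with that usage:
# ---------------------------------------------------------------------------

def tile(a, dim, n_tile):
    # [a0, a1] -> [a0, a0, a1, a1] along `dim` for n_tile=2
    return torch.repeat_interleave(a, n_tile, dim=dim)


def flip(a, dim):
    # reverse a tensor along a single dimension
    return torch.flip(a, dims=[dim])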
def backward(ctx, grad_output):
    X, Y, G_static, K = ctx.saved_tensors
    static_kernel = ctx.static_kernel
    dyadic_order = ctx.dyadic_order
    _naive_solver = ctx._naive_solver

    G_static_ = G_static[:, 1:, 1:] + G_static[:, :-1, :-1] \
        - G_static[:, 1:, :-1] - G_static[:, :-1, 1:]
    G_static_ = tile(tile(G_static_, 1, 2**dyadic_order) / float(2**dyadic_order),
                     2, 2**dyadic_order) / float(2**dyadic_order)

    A = X.shape[0]
    M = X.shape[1]
    N = Y.shape[1]
    D = X.shape[2]
    MM = (2**dyadic_order) * (M - 1)
    NN = (2**dyadic_order) * (N - 1)

    # Reverse paths (not needed below: the reversed Gram increments are
    # obtained by flipping G_static_ directly)
    X_rev = torch.flip(X, dims=[1])
    Y_rev = torch.flip(Y, dims=[1])

    # computing dsdt k(X_rev^i_s, Y_rev^i_t) for variation of parameters
    G_static_rev = flip(flip(G_static_, dim=1), dim=2)

    # if on GPU
    if X.device.type == 'cuda':
        # Prepare the tensor of output solutions to the PDE (backward)
        K_rev = torch.zeros((A, MM + 2, NN + 2),
                            device=G_static_rev.device, dtype=G_static_rev.dtype)
        K_rev[:, 0, :] = 1.
        K_rev[:, :, 0] = 1.

        # cuda parameters (must match the forward pass: the kernel is launched
        # over MM + 1 and NN + 1 grid points per axis, not MM and NN)
        threads_per_block = max(MM + 1, NN + 1)
        n_anti_diagonals = 2 * threads_per_block - 1

        # Compute signature kernel for reversed paths
        compute_sig_kernel_batch_varpar_from_increments_cuda[A, threads_per_block](
            cuda.as_cuda_array(G_static_rev.detach()),
            MM + 1, NN + 1, n_anti_diagonals,
            cuda.as_cuda_array(K_rev), _naive_solver)
        K_rev = K_rev[:, :-1, :-1]

    # if on CPU
    else:
        K_rev = torch.tensor(sig_kernel_batch_varpar(G_static_rev.detach().numpy(), _naive_solver),
                             dtype=G_static.dtype, device=G_static.device)

    K_rev = flip(flip(K_rev, dim=1), dim=2)
    KK = K[:, :-1, :-1] * K_rev[:, 1:, 1:]

    # finite difference step: perturb each coordinate of X by h
    h = 1e-9
    Xh = X[:, :, :, None] + h * torch.eye(D, dtype=X.dtype, device=X.device)[None, None, :]
    Xh = Xh.permute(0, 1, 3, 2)
    Xh = Xh.reshape(A, M * D, D)
    G_h = static_kernel.batch_kernel(Xh, Y)
    G_h = G_h.reshape(A, M, D, N)
    G_h = G_h.permute(0, 1, 3, 2)

    Diff_1 = G_h[:, 1:, 1:, :] - G_h[:, 1:, :-1, :] \
        - (G_static[:, 1:, 1:])[:, :, :, None] + (G_static[:, 1:, :-1])[:, :, :, None]
    Diff_1 = tile(tile(Diff_1, 1, 2**dyadic_order) / float(2**dyadic_order),
                  2, 2**dyadic_order) / float(2**dyadic_order)

    Diff_2 = G_h[:, 1:, 1:, :] - G_h[:, 1:, :-1, :] \
        - (G_static[:, 1:, 1:])[:, :, :, None] + (G_static[:, 1:, :-1])[:, :, :, None]
    Diff_2 += -G_h[:, :-1, 1:, :] + G_h[:, :-1, :-1, :] \
        + (G_static[:, :-1, 1:])[:, :, :, None] - (G_static[:, :-1, :-1])[:, :, :, None]
    Diff_2 = tile(tile(Diff_2, 1, 2**dyadic_order) / float(2**dyadic_order),
                  2, 2**dyadic_order) / float(2**dyadic_order)

    # finite-difference directional derivatives of the kernel
    grad_1 = (KK[:, :, :, None] * Diff_1) / h
    grad_2 = (KK[:, :, :, None] * Diff_2) / h

    # sum over t and collapse the dyadic refinement along s
    grad_1 = torch.sum(grad_1, dim=2)
    grad_1 = torch.sum(grad_1.reshape(A, M - 1, 2**dyadic_order, D), dim=2)
    grad_2 = torch.sum(grad_2, dim=2)
    grad_2 = torch.sum(grad_2.reshape(A, M - 1, 2**dyadic_order, D), dim=2)

    grad_prev = grad_1[:, :-1, :] + grad_2[:, 1:, :]  # /¯¯
    grad_next = torch.cat([torch.zeros((A, 1, D), dtype=X.dtype, device=X.device),
                           grad_1[:, 1:, :]], dim=1)  # /  (computed but unused here)
    grad_incr = grad_prev - grad_1[:, 1:, :]
    grad_points = torch.cat([(grad_2[:, 0, :] - grad_1[:, 0, :])[:, None, :],
                             grad_incr,
                             grad_1[:, -1, :][:, None, :]], dim=1)

    # when X and Y are the same paths (Gram diagonal) the gradient
    # contributes twice
    if Y.requires_grad:
        grad_points *= 2

    # gradients w.r.t. (X, Y, static_kernel, dyadic_order, _naive_solver)
    return grad_output[:, None, None] * grad_points, None, None, None, None
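# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library): one way to wire the
# forward/backward pair above into torch.autograd. The class name
# `_SigKernelFunc` and the `_RBFKernel` static kernel are hypothetical
# stand-ins; running this still requires the backend solvers referenced
# above to be importable.
# ---------------------------------------------------------------------------

class _RBFKernel:
    """Hypothetical static kernel exposing the batch_kernel interface."""

    def __init__(self, sigma=1.0):
        self.sigma = sigma

    def batch_kernel(self, X, Y):
        # Gram matrices k(X^i_s, Y^i_t) of shape (A, M, N)
        return torch.exp(-torch.cdist(X, Y) ** 2 / (2 * self.sigma ** 2))


class _SigKernelFunc(torch.autograd.Function):
    # reuse the module-level forward/backward defined above
    forward = staticmethod(forward)
    backward = staticmethod(backward)


if __name__ == '__main__':
    A, M, N, D = 2, 6, 7, 3
    X = torch.rand(A, M, D, dtype=torch.float64, requires_grad=True)
    Y = torch.rand(A, N, D, dtype=torch.float64)
    # apply takes all five forward arguments, since backward returns five
    # gradients (only the one w.r.t. X is non-None)
    k = _SigKernelFunc.apply(X, Y, _RBFKernel(sigma=1.0), 1, False)  # shape (A,)
    k.sum().backward()
    print(k.shape, X.grad.shape)  # torch.Size([2]) torch.Size([2, 6, 3])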