def pull1d(img, grid, grad=False, **kwargs):
    """Pull an image by a transform along the last dimension

    Parameters
    ----------
    img : (K, *spatial) tensor, Image
    grid : (*spatial) tensor, Sampling grid
    grad : bool, Sample gradients

    Returns
    -------
    warped_img : (K, *spatial) tensor
    warped_grad : (K, *spatial) tensor, if `grad`

    """
    if grid is None:
        if grad:
            bound = kwargs.get('bound', 'dft')
            return img, diff1d(img, dim=-1, bound=bound, side='c')
        else:
            return img, None
    kwargs.setdefault('extrapolate', True)
    kwargs.setdefault('bound', 'dft')
    img, grid = img.unsqueeze(-2), grid.unsqueeze(-1)
    warped = grid_pull(img, grid, **kwargs).squeeze(-2)
    if not grad:
        return warped
    grad = grid_grad(img, grid, **kwargs)
    grad = grad.squeeze(-1).squeeze(-2)
    return warped, grad
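# Illustrative usage sketch (not part of the original module). It assumes
# `grid_pull`, `grid_grad` and `diff1d` resolve to their `nitorch.spatial`
# counterparts, as in the calls above.
def _example_pull1d():
    import torch
    img = torch.randn(3, 64)                         # K=3 signals of length 64
    grid = torch.arange(64, dtype=img.dtype) + 0.5   # sample half a voxel to the right
    warped, dwarped = pull1d(img, grid, grad=True)
    return warped, dwarped                           # both (3, 64)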
def smart_grad(tensor, grid, **opt):
    """Pull gradients iff grid is defined (+ add/remove batch dim).

    Parameters
    ----------
    tensor : (channels, *input_shape) tensor
        Input volume
    grid : (*output_shape, D) tensor or None
        Sampling grid

    Returns
    -------
    grad : (channels, *output_shape, D) tensor
        Sampled spatial gradients

    """
    # if grid is None:
    #     opt.pop('extrapolate', None)
    #     opt.pop('interpolation', None)
    #     return spatial.diff(tensor, dim=3, **opt)
    if grid is None:
        grid = spatial.identity_grid(tensor.shape[-3:], dtype=tensor.dtype,
                                     device=tensor.device)
    out = spatial.grid_grad(tensor, grid, **opt)
    return out
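# Illustrative usage sketch (not part of the original module). It assumes
# `spatial` is `nitorch.spatial`, as in the call above; with `grid=None` the
# function falls back to an identity grid, i.e. plain spatial gradients.
def _example_smart_grad():
    import torch
    vol = torch.randn(1, 8, 8, 8)   # one channel, 8x8x8 volume
    g = smart_grad(vol, None, bound='dct2', extrapolate=True)
    return g                        # (1, 8, 8, 8, 3)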
def _deform_1d(img, disp, grad=False):
    """Warp `img` along its last dimension by the displacement `disp`
    (and optionally sample the spatial gradients along that dimension)."""
    img = utils.movedim(img, 0, -2)
    disp = disp.unsqueeze(-1)
    disp = spatial.add_identity_grid(disp)
    wrp = spatial.grid_pull(img, disp, bound=BND, extrapolate=True)
    wrp = utils.movedim(wrp, -2, 0)
    if not grad:
        return wrp, None
    grd = spatial.grid_grad(img, disp, bound=BND, extrapolate=True)
    grd = utils.movedim(grd.squeeze(-1), -2, 0)
    return wrp, grd
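# Illustrative usage sketch (not part of the original module). It assumes
# `utils`/`spatial` are the nitorch modules used above and `BND` is the
# module-level boundary condition that `_deform_1d` relies on.
def _example_deform_1d():
    import torch
    img = torch.randn(2, 16, 16, 32)   # (C, X, Y, Z)
    disp = torch.zeros(16, 16, 32)     # zero displacement along the last axis
    wrp, grd = _deform_1d(img, disp, grad=True)
    return wrp, grd                    # both (2, 16, 16, 32)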
def grad_position(self, t):
    """Gradient of the evaluated position wrt time"""
    # convert (0, 1) to (0, n)
    shape = t.shape
    t = t.flatten()
    t = t.clamp(0, 1) * (len(self.waypoints) - 1)
    # interpolate
    y = self.coeff.T    # [D, K]
    t = t.unsqueeze(-1)  # [N, 1]
    g = grid_grad(y, t, interpolation=self.order, bound=self.bound)
    g = g.squeeze(-1).T  # [N, D]
    g = g.reshape([*shape, g.shape[-1]])
    g *= (len(self.waypoints) - 1)
    return g
def _jhistc_backward(g, x, w=None, order=0, bound='replicate',
                     extrapolate=True, gradx=True, gradw=False):
    """Compute derivative of the joint histogram.

    The input must already be a soft mapping to bins indices.

    Parameters
    ----------
    g : (b, bins, bins) tensor
    x : (b, n, 2) tensor
    w : ([b], n) tensor, optional
    order : int, default=0
    bound : {'zero', 'replicate'}, default='replicate'
    extrapolate : bool, default=True
    gradx : bool, default=True
    gradw : bool, default=False

    Returns
    -------
    gx : (b, n, 2) tensor, if gradx
    gw : ([b], n) tensor, if gradw

    """
    extrapolate = 1 if extrapolate else 2
    opt = dict(interpolation=order, bound=bound, extrapolate=extrapolate)
    x = x.unsqueeze(-3)  # make 2d spatial
    g = g.unsqueeze(-3)  # add channel dimension
    out = []
    if gradx:
        gx = grid_grad(g, x, **opt)
        gx = gx.squeeze(-3).squeeze(-3)
        if w is not None:
            gx *= w.unsqueeze(-1)
        out.append(gx)
    if gradw and w is not None:
        gw = grid_pull(g, x, **opt)
        gw = gw.squeeze(-2).squeeze(-2)  # drop spatial + channel
        out.append(gw)
    elif gradw:
        out.append(None)
    return out[0] if len(out) == 1 else tuple(out)
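# Illustrative usage sketch (not part of the original module), following the
# shapes documented above; `grid_grad`/`grid_pull` are assumed to be the
# `nitorch.spatial` functions used by `_jhistc_backward`.
def _example_jhistc_backward():
    import torch
    b, bins, n = 2, 32, 100
    g = torch.randn(b, bins, bins)        # gradient wrt the joint histogram
    x = torch.rand(b, n, 2) * (bins - 1)  # soft bin coordinates
    w = torch.rand(b, n)                  # voxel weights
    gx, gw = _jhistc_backward(g, x, w, order=1, gradx=True, gradw=True)
    return gx, gw                         # (b, n, 2) and (b, n)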
def pull1d(img, grid, dim, grad=False, **kwargs):
    """Pull an image by a 1D transform along dimension `dim`
    (and optionally sample its gradients along that dimension)."""
    if grid is None:
        if grad:
            bound = kwargs.get('bound', 'dft')
            return img, spatial.diff1d(img, dim=dim, bound=bound, side='c')
        else:
            return img, None
    kwargs.setdefault('extrapolate', True)
    kwargs.setdefault('bound', 'dft')
    img = core.utils.movedim(img, dim, -1).unsqueeze(-2)
    grid = core.utils.movedim(grid, dim, -1).unsqueeze(-1)
    warped = spatial.grid_pull(img, grid, **kwargs)
    warped = core.utils.movedim(warped.squeeze(-2), -1, dim)
    if not grad:
        return warped, None
    grad = spatial.grid_grad(img, grid, **kwargs)
    grad = core.utils.movedim(grad.squeeze(-1).squeeze(-2), -1, dim)
    return warped, grad
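# Illustrative usage sketch (not part of the original module): 1D resampling
# of a 3D volume along its Y axis. Assumes `core`/`spatial` are the nitorch
# modules used above; `grid` holds sampling coordinates (in voxels) along `dim`.
def _example_pull1d_dim():
    import torch
    img = torch.randn(1, 8, 8, 8)                              # (K, X, Y, Z)
    grid = torch.arange(8.).reshape(1, 8, 1).expand(8, 8, 8) + 0.25
    warped, dwarped = pull1d(img, grid, dim=-2, grad=True)
    return warped, dwarped                                     # both (1, 8, 8, 8)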
def eval_grad_position(self, t):
    """Evaluate position and its gradient wrt time"""
    # convert (0, 1) to (0, n)
    shape = t.shape
    t = t.flatten()
    t = t.clamp(0, 1) * (len(self.waypoints) - 1)
    # interpolate
    y = self.coeff.T    # [D, K]
    t = t.unsqueeze(-1)  # [N, 1]
    x = grid_pull(y, t, interpolation=self.order, bound=self.bound)
    x = x.T              # [N, D]
    g = grid_grad(y, t, interpolation=self.order, bound=self.bound)
    g = g.squeeze(-1).T  # [N, D]
    x = x.reshape([*shape, x.shape[-1]])
    g = g.reshape([*shape, g.shape[-1]])
    g *= (len(self.waypoints) - 1)
    return x, g
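# Hypothetical helper (not part of the original class): finite-difference
# sanity check of `eval_grad_position`, assuming `curve` is an instance of the
# spline class that defines it (with `waypoints`, `coeff`, `order`, `bound`).
# Only meaningful away from the endpoints, where `t` is clamped.
def _check_eval_grad_position(curve, t, eps=1e-4):
    x0, g = curve.eval_grad_position(t)
    x1, _ = curve.eval_grad_position(t + eps)
    return ((x1 - x0) / eps - g).abs().max()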
def pull_grad(self, grid, rotate=False):
    """Sample the image gradients at dense coordinates.

    Parameters
    ----------
    grid : (*spatial, dim) tensor or None
        Dense transformation field.
    rotate : bool, default=False
        Rotate the gradients using the Jacobian of the transformation.

    Returns
    -------
    grad : ([C], *spatial, dim) tensor

    """
    if grid is None:
        return self.grad()
    grad = spatial.grid_grad(self.dat, grid, bound=self.bound,
                             extrapolate=self.extrapolate)
    if rotate:
        jac = spatial.grid_jacobian(grid)
        jac = jac.transpose(-1, -2)
        grad = linalg.matvec(jac, grad)
    return grad
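# Hypothetical driver (not part of the original class): sample rotated image
# gradients through an affine transform, assuming `image` is an instance of
# the class that defines `pull_grad` and `spatial` is `nitorch.spatial`.
def _example_pull_grad(image, affine, shape):
    grid = spatial.affine_grid(affine.to(image.dat), shape)  # (*shape, dim)
    return image.pull_grad(grid, rotate=True)                # ([C], *shape, dim)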
def __call__(self, logaff, grad=False, hess=False, gradmov=False,
             hessmov=False, in_line_search=False):
    """
    logaff : (..., nb) tensor, Lie parameters
    grad : Whether to compute and return the gradient wrt `logaff`
    hess : Whether to compute and return the Hessian wrt `logaff`
    gradmov : Whether to compute and return the gradient wrt `moving`
    hessmov : Whether to compute and return the Hessian wrt `moving`

    Returns
    -------
    ll : () tensor, loss value (objective to minimize)
    g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
    h : (..., logaff) tensor, optional, Hessian wrt Lie parameters
    gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
    hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

    """
    # This loop performs the forward pass, and computes
    # derivatives along the way.

    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
              and (self.n_iter - 1) % logplot == 0

    # jitter
    # if not hasattr(self, '_fixed'):
    #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
    #                                 jitter=True,
    #                                 **utils.backend(self.fixed))
    #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
    #     del idj
    # fixed = self._fixed
    fixed = self.fixed

    # forward
    if not torch.is_tensor(self.basis):
        self.basis = spatial.affine_basis(self.basis, self.dim,
                                          **utils.backend(logaff))
    aff = linalg.expm(logaff, self.basis)
    with torch.no_grad():
        _, gaff = linalg._expm(logaff, self.basis,
                               grad_X=True, hess_X=False)

    aff = spatial.affine_matmul(aff, self.affine_fixed)
    aff = spatial.affine_lmdiv(self.affine_moving, aff)
    # /!\ derivatives are not "homogeneous" (they do not have a one
    # on the bottom right): we should *not* use affine_matmul and
    # such (I only lost a day...)
    gaff = torch.matmul(gaff, self.affine_fixed)
    gaff = linalg.lmdiv(self.affine_moving, gaff)
    # haff = torch.matmul(haff, self.affine_fixed)
    # haff = linalg.lmdiv(self.affine_moving, haff)

    if self.id is None:
        shape = self.fixed.shape[-self.dim:]
        self.id = spatial.identity_grid(shape, **utils.backend(logaff),
                                        jitter=False)
    grid = spatial.affine_matvec(aff, self.id)
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, cat=iscat,
                    dim=self.dim)

    # gradient/Hessian of the log-likelihood in observed space
    if not grad and not hess:
        llx = self.loss.loss(warped, fixed)
    elif not hess:
        llx, grad = self.loss.loss_grad(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
    else:
        llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
        if hessmov:
            hessmov = spatial.grid_push(hess, grid, **pullopt)
    del warped

    # compose with spatial gradients + dot product with grid
    if grad is not False or hess is not False:
        mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
        grad = jg(mugrad, grad)
        if hess is not False:
            hess = jhj(mugrad, hess)
            grad, hess = regutils.affine_grid_backward(grad, hess,
                                                       grid=self.id)
        else:
            grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
        dim2 = self.dim * (self.dim + 1)
        grad = grad.reshape([*grad.shape[:-2], dim2])
        gaff = gaff[..., :-1, :]
        gaff = gaff.reshape([*gaff.shape[:-2], dim2])
        grad = linalg.matvec(gaff, grad)
        if hess is not False:
            hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
            # haff = haff[..., :-1, :, :-1, :]
            # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
            hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
            hess = hess.abs().sum(-1).diag_embed()
        del mugrad

    # print objective
    llx = llx.item()
    ll = llx
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                  end='\n')
        else:
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                  end='\n')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    out = [ll]
    if grad is not False:
        out.append(grad)
    if hess is not False:
        out.append(hess)
    if gradmov is not False:
        out.append(gradmov)
    if hessmov is not False:
        out.append(hessmov)
    return tuple(out) if len(out) > 1 else out[0]
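# Hypothetical driver (a sketch, not the original optimizer): one damped
# Gauss-Newton update on the Lie parameters using the (loss, gradient, Hessian)
# triplet returned by `__call__`. `linalg.lmdiv` and `utils.backend` are the
# nitorch helpers already used above; `closure` stands for an instance of this
# class.
def _affine_gauss_newton_step(closure, logaff, lam=1e-5):
    import torch
    ll, g, h = closure(logaff, grad=True, hess=True)
    # Levenberg-style damping before solving h \ g
    h = h + lam * torch.eye(h.shape[-1], **utils.backend(h))
    delta = linalg.lmdiv(h, g.unsqueeze(-1)).squeeze(-1)
    return ll, logaff - delta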
def backward(self, x, g, min=None, max=None, hess=False, mask=None):
    """
    Parameters
    ----------
    x : (..., N, 2) tensor
        Input multidimensional vector
    g : (..., B, B) tensor
        Gradient with respect to the histogram
    min : (..., 2) tensor, optional
    max : (..., 2) tensor, optional
    hess : bool, default=False
        Backpropagate a (diagonal) Hessian approximation instead of a gradient
    mask : (..., N) tensor, optional
        Mask of voxels to include

    Returns
    -------
    g : (..., N, 2) tensor
        Gradient with respect to x

    """
    if self.fwhm:
        g = spatial.smooth(g, fwhm=self.fwhm, bound=self.bound, dim=2)

    shape = x.shape
    x, min, max = self._prepare(x, min, max)
    nvox = x.shape[-2]
    min = min.unsqueeze(-2)
    max = max.unsqueeze(-2)

    g = g.reshape([-1, *g.shape[-2:]])
    extrapolate = self.extrapolate or 2
    if not hess:
        g = spatial.grid_grad(g[:, None], x[:, None], self.order,
                              self.bound, extrapolate)
        g = g[:, 0].reshape(shape)
    else:
        # 1) Absolute value of adjoint of gradient
        # we want shapes
        #   o : [batch=1, channel=1, spatial=[1, vox], dim=2]
        #   g : [batch=1, channel=1, spatial=[B(mov), B(fix)]]
        #   x : [batch=1, spatial=[1, vox], dim=2]
        #   -> [batch=1, channel=1, spatial=[B(mov), B(fix)]]
        order = _spatial.inter_to_nitorch([self.order], True)
        bound = _spatial.bound_to_nitorch([self.bound], True)
        o = torch.ones_like(x)
        g.requires_grad_()  # triggers push
        o, = _spatial.grid_grad_backward(o[:, None, None], g[:, None],
                                         x[:, None], bound, order,
                                         extrapolate)
        g.requires_grad_(False)
        g *= o[:, 0]
        # 2) Absolute value of gradient
        #   g : [batch=1, channel=1, spatial=[B(mov), B(fix)]]
        #   x : [batch=1, spatial=[1, vox], dim=2]
        #   -> [batch=1, channel=1, spatial=[1, vox], 2]
        g = _spatial.grid_grad(g[:, None], x[:, None], bound, order,
                               extrapolate)
        g = g.reshape(shape)

    # adjoint of affine function
    nn = torch.as_tensor(self.n, dtype=x.dtype, device=x.device)
    factor = nn / (max - min)
    if hess:
        factor = factor.square_()
    g = g.mul_(factor)
    if mask is not None:
        g *= mask[..., None]
    return g
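# Hypothetical driver (a sketch, not the original code): backpropagate a
# histogram-space gradient to the input coordinates, assuming `hist` is an
# instance of the class that defines `backward` above.
def _example_hist_backward(hist, x, g_hist, mask=None):
    gx = hist.backward(x, g_hist, mask=mask)             # (..., N, 2)
    hx = hist.backward(x, g_hist, hess=True, mask=mask)  # diagonal Hessian term
    return gx, hx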
def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
                   basis='affine', fwhm=None, joint=False, prm=None,
                   max_iter_gn=100, max_iter_em=32, max_line_search=6,
                   progressive=False, verbose=1):
    """
    Parameters
    ----------
    dat : (B, J|1, *spatial) tensor
    tpm : (B|1, K, *spatial) tensor
    affine : (4, 4) tensor
    affine_tpm : (4, 4) tensor
    weights : (B, 1, *spatial) tensor
    basis : {'translation', 'rotation', 'rigid', 'similitude', 'affine'}
    fwhm : float, default=J/32
    joint : bool, default=False
    prm : (B|1, F) tensor, optional, Initial parameters
    max_iter_gn : int, default=100
    max_iter_em : int, default=32
    max_line_search : int, default=6
    progressive : bool, default=False
    verbose : int, default=1

    Returns
    -------
    mi : (B,) tensor
    aff : (B, 4, 4) tensor
    prm : (B, F) tensor

    """
    dim = dat.dim() - 2

    # ------------------------------------------------------------------
    #       RECURSIVE PROGRESSIVE FIT
    # ------------------------------------------------------------------
    if progressive:
        nb_se = dim * (dim + 1) // 2
        nb_aff = dim * (dim + 1)
        basis_recursion = {'Aff+': 'CSO', 'CSO': 'SE', 'SE': 'T'}
        basis_nb_feat = {'Aff+': nb_aff, 'CSO': nb_se + 1, 'SE': nb_se}
        basis = convert_basis(basis)
        next_basis = basis_recursion.get(basis, None)
        if next_basis:
            *_, prm = fit_affine_tpm(dat, tpm, affine, affine_tpm, weights,
                                     basis=next_basis, fwhm=fwhm,
                                     joint=joint, prm=prm,
                                     max_iter_gn=max_iter_gn,
                                     max_iter_em=max_iter_em,
                                     max_line_search=max_line_search)
            B = len(dat)
            F = basis_nb_feat[basis]
            prm0 = prm
            prm = prm0.new_zeros([1 if joint else B, F])
            if basis == 'SE':
                prm[:, :dim] = prm0[:, :dim]
            else:
                nb_se = dim * (dim + 1) // 2
                prm[:, :nb_se] = prm0[:, :nb_se]
                if basis == 'Aff+':
                    # distribute the isotropic zoom over the anisotropic zooms
                    prm[:, nb_se:nb_se+dim] = prm0[:, nb_se:nb_se+1] * (dim ** (-0.5))

    basis_name = basis

    # ------------------------------------------------------------------
    #       PREPARE
    # ------------------------------------------------------------------
    B = len(dat)

    if affine is None:
        affine = spatial.affine_default(dat.shape[-dim:])
    if affine_tpm is None:
        affine_tpm = spatial.affine_default(tpm.shape[-dim:])
    affine = affine.to(**utils.backend(tpm))
    affine_tpm = affine_tpm.to(**utils.backend(tpm))
    shape = dat.shape[-dim:]

    tpm = tpm.to(dat.device)
    basis = make_basis(basis, dim, **utils.backend(tpm))
    F = len(basis)

    if prm is None:
        prm = tpm.new_zeros([1 if joint else B, F])
    aff, gaff = linalg._expm(prm, basis, grad_X=True)

    em_opt = dict(fwhm=fwhm, max_iter=max_iter_em, weights=weights,
                  verbose=verbose - 2)
    drv_opt = dict(weights=weights)
    pull_opt = dict(bound='replicate', extrapolate=True)

    # ------------------------------------------------------------------
    #       OPTIMIZE
    # ------------------------------------------------------------------
    prior = None
    mi = torch.as_tensor(-float('inf'))
    delta = torch.zeros_like(prm)
    for n_iter in range(max_iter_gn):

        # --------------------------------------------------------------
        #       LINE SEARCH
        # --------------------------------------------------------------
        prior0, prm0, mi0 = prior, prm, mi
        armijo = 1
        success = False
        for n_ls in range(max_line_search):

            # --- take a step ------------------------------------------
            prm = prm0 - armijo * delta

            # --- build transformation field ---------------------------
            aff, gaff = linalg._expm(prm, basis, grad_X=True)
            phi = lmdiv(affine_tpm, mm(aff, affine))
            phi = spatial.affine_grid(phi, shape)

            # --- warp TPM ---------------------------------------------
            mov = spatial.grid_pull(tpm, phi, **pull_opt)

            # --- mutual info ------------------------------------------
            mi, Nm, prior = em_prior(mov, dat, prior0, **em_opt)
            mi = mi / Nm

            success = mi.sum() > mi0.sum()
            if verbose >= 2:
                end = '\n' if verbose >= 3 else '\r'
                happy = ':D' if success else ':('
                print(f'(search) | {n_ls:02d} | {mi.mean():12.6g} | {happy}',
                      end=end)
            if success:
                break
            armijo *= 0.5
        # if verbose == 2:
        #     print('')

        # --------------------------------------------------------------
        #       DID IT WORK?
        # --------------------------------------------------------------
        if not success:
            prior, prm, mi = prior0, prm0, mi0
            break

        # DEBUG
        # plot_registration(dat, mov, f'{basis_name} | {n_iter}')

        space = ' ' * max(0, 6 - len(basis_name))
        if verbose >= 1:
            end = '\n' if verbose >= 2 else '\r'
            print(f'({basis_name[:6]}){space} | {n_iter:02d} | {mi.mean():12.6g}',
                  end=end)

        if mi.mean() - mi0.mean() < 1e-5:
            break

        # --------------------------------------------------------------
        #       GAUSS-NEWTON
        # --------------------------------------------------------------
        # --- derivatives ----------------------------------------------
        g, h = derivatives_intensity(mov, dat, prior, **drv_opt)

        # --- chain rule -----------------------------------------------
        gmov = spatial.grid_grad(tpm, phi, **pull_opt)
        if joint and len(mov) == 1:
            g = g.sum(0, keepdim=True)
            h = h.sum(0, keepdim=True)
        else:
            gmov = gmov.expand([B, *gmov.shape[1:]])
        gaff = lmdiv(affine_tpm, mm(gaff, affine))
        g, h = chain_rule(g, h, gmov, gaff, maj=False)
        del gmov
        if joint and len(g) > 1:
            g = g.sum(0, keepdim=True)
            h = h.sum(0, keepdim=True)

        # --- Gauss-Newton ---------------------------------------------
        delta = lmdiv(h, g.unsqueeze(-1)).squeeze(-1)

    if verbose == 1:
        print('')
    return mi, aff, prm
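# Hypothetical driver (a sketch, not part of the original module): progressive
# rigid-then-affine fit of a batch of images to a set of tissue probability
# maps, following the shapes documented in the docstring above.
def _example_fit_affine_tpm(dat, tpm, affine_dat=None, affine_tpm=None):
    mi, aff, prm = fit_affine_tpm(dat, tpm, affine_dat, affine_tpm,
                                  basis='affine', progressive=True, verbose=1)
    return mi, aff, prm   # (B,), (B, 4, 4), (B, F)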
def __call__(self, vel, grad=False, hess=False, gradmov=False,
             hessmov=False, in_line_search=False):
    """
    vel : (..., *spatial, dim) tensor, Displacement
    grad : Whether to compute and return the gradient wrt `vel`
    hess : Whether to compute and return the Hessian wrt `vel`
    gradmov : Whether to compute and return the gradient wrt `moving`
    hessmov : Whether to compute and return the Hessian wrt `moving`

    Returns
    -------
    ll : () tensor, loss value (objective to minimize)
    g : (..., *spatial, dim) tensor, optional, Gradient wrt velocity
    h : (..., *spatial, ?) tensor, optional, Hessian wrt velocity
    gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
    hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

    """
    # This loop performs the forward pass, and computes
    # derivatives along the way.

    dim = vel.shape[-1]
    pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

    # in_line_search = not grad and not hess
    logplot = max(self.max_iter // 20, 1)
    do_plot = (not in_line_search) and self.plot \
              and (self.n_iter - 1) % logplot == 0

    # forward
    if self.id is None:
        self.id = spatial.identity_grid(vel.shape[-dim - 1:-1],
                                        **utils.backend(vel))
    grid = self.id + vel
    warped = spatial.grid_pull(self.moving, grid, **pullopt)
    if do_plot:
        iscat = isinstance(self.loss, losses.Cat)
        plt.mov2fix(self.fixed, self.moving, warped, vel, cat=iscat,
                    dim=dim)

    # gradient/Hessian of the log-likelihood in observed space
    if not grad and not hess and not hessmov:
        llx = self.loss.loss(warped, self.fixed)
    elif not hess and not hessmov:
        llx, grad = self.loss.loss_grad(warped, self.fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
    else:
        llx, grad, hess = self.loss.loss_grad_hess(warped, self.fixed)
        if gradmov:
            gradmov = spatial.grid_push(grad, grid, **pullopt)
        if hessmov:
            hessmov = spatial.grid_push(hess, grid, **pullopt)
    del warped

    # compose with spatial gradients
    if grad is not False or hess is not False:
        mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
        if grad is not False:
            grad = jg(mugrad, grad)
        if hess is not False:
            hess = jhj(mugrad, hess)

    # add regularization term
    vgrad = spatial.regulariser_grid(vel, **self.prm, kernel=False)
    llv = 0.5 * (vel * vgrad).sum()
    if grad is not False:
        grad += vgrad
    del vgrad

    # print objective
    llx = llx.item()
    llv = llv.item()
    ll = llx + llv
    if self.verbose and not in_line_search:
        self.n_iter += 1
        if self.ll_prev is None:
            print(f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g}',
                  end='\r')
        else:
            gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
            print(f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g} | {gain:12.6g}',
                  end='\r')
        self.ll_prev = ll
        self.ll_max = max(self.ll_max, ll)

    out = [ll]
    if grad is not False:
        out.append(grad)
    if hess is not False:
        out.append(hess)
    if gradmov is not False:
        out.append(gradmov)
    if hessmov is not False:
        out.append(hessmov)
    return tuple(out) if len(out) > 1 else out[0]
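# Hypothetical driver (a sketch, not the original optimizer): a plain gradient
# step on the displacement field using the (loss, gradient) pair returned by
# `__call__` above; `closure` stands for an instance of this class.
def _velocity_gradient_step(closure, vel, lr=0.1):
    ll, g = closure(vel, grad=True)
    return ll, vel - lr * g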
def _rigid_match(dat_x, dat_y, po, tau, rigid, sett, CtC=None, diff=False,
                 verbose=0):
    """ Computes the rigid matching term, and its gradient and Hessian
        (if requested).

    Args:
        dat_x (torch.tensor): Observed data (X0, Y0, Z0).
        dat_y (torch.tensor): Reconstructed data (X1, Y1, Z1).
        po (ProjOp): Projection operator.
        tau (torch.tensor): Noise precision.
        rigid (torch.tensor): Rigid transformation matrix (4, 4).
        CtC (torch.tensor, optional): CtC(ones), used for super-res gradient
            calculation. Defaults to None.
        diff (bool, optional): Compute derivatives, defaults to False.
        verbose (bool, optional): Show registration results, defaults to 0.
            0: No verbose
            1: Print convergence info to console
            2: Plot registration results using matplotlib

    Returns:
        ll (torch.tensor): Log-likelihood.
        gr (torch.tensor): Gradient (dim_x, 3).
        Hes (torch.tensor): Hessian (dim_x, 6).

    """
    # Projection info
    mat_x = po.mat_x
    mat_y = po.mat_y
    mat_yx = po.mat_yx
    dim_x = po.dim_x
    dim_yx = po.dim_yx
    ratio = po.ratio
    smo_ker = po.smo_ker
    dim_thick = po.dim_thick
    scl = po.scl

    # Init output
    ll = None
    gr = None
    Hes = None

    if sett.method == 'super-resolution':
        extrapolate = False
        dim = dim_yx
        mat = mat_yx
    elif sett.method == 'denoising':
        extrapolate = False
        dim = dim_x
        mat = mat_x

    # Get grid
    mat = rigid.mm(mat).solve(mat_y)[0]  # mat_y\rigid*mat
    grid = affine_grid(mat.type(torch.float32), dim, jitter=False)

    # Warp y and compute spatial derivatives
    dat_yx = grid_pull(dat_y, grid[None, ...], bound=sett.bound,
                       extrapolate=extrapolate,
                       interpolation=sett.interpolation)[0, 0, ...]
    if sett.method == 'super-resolution':
        dat_yx = F.conv3d(dat_yx[None, None, ...], smo_ker,
                          stride=ratio)[0, 0, ...]
        if scl != 0:
            dat_yx = _apply_scaling(dat_yx, scl, dim_thick)
    if diff:
        gr = grid_grad(dat_y, grid[None, ...], bound=sett.bound,
                       extrapolate=extrapolate,
                       interpolation=sett.interpolation)[0, 0, ...]

    if verbose >= 2:  # Show images
        show_slices(torch.stack((dat_x, dat_yx, (dat_x - dat_yx) ** 2), 3),
                    fig_num=666, colorbar=False, flip=False)

    # Mask of observed voxels
    msk = (dat_x != 0)

    # Compute matching term
    ll = 0.5 * tau * torch.sum((dat_x[msk] - dat_yx[msk]) ** 2,
                               dtype=torch.float64)

    if diff:
        # Difference
        diff = dat_yx - dat_x
        msk = msk & (dat_yx != 0)
        diff[~msk] = 0

        # Hessian
        Hes = torch.zeros(dim + (6,), device=dat_x.device,
                          dtype=torch.float32)
        Hes[:, :, :, 0] = gr[:, :, :, 0] * gr[:, :, :, 0]
        Hes[:, :, :, 1] = gr[:, :, :, 1] * gr[:, :, :, 1]
        Hes[:, :, :, 2] = gr[:, :, :, 2] * gr[:, :, :, 2]
        Hes[:, :, :, 3] = gr[:, :, :, 0] * gr[:, :, :, 1]
        Hes[:, :, :, 4] = gr[:, :, :, 0] * gr[:, :, :, 2]
        Hes[:, :, :, 5] = gr[:, :, :, 1] * gr[:, :, :, 2]

        if sett.method == 'super-resolution':
            Hes *= CtC[..., None]
            diff = F.conv_transpose3d(diff[None, None, ...], smo_ker,
                                      stride=ratio)[0, 0, ...]

        # Gradient
        gr *= diff[..., None]

    return ll, gr, Hes