Exemplo n.º 1
0
 def __init__(self, dat=None, basis=None, dim=3, **backend):
     """Store log-affine parameters expressed in a named affine basis.

     Parameters
     ----------
     dat : tensor or str, optional
         Pre-allocated log-affine parameters. A string is interpreted
         as the basis name, in which case `basis` must not be given.
         When None, a zero tensor of length
         `spatial.affine_basis_size(basis, dim)` is allocated.
     basis : {'translation', 'rotation', 'rigid', 'similitude', 'affine'}
         Name of an Affine basis. Mandatory unless passed through `dat`.
     dim : int, default=3
         Number of spatial dimensions
     **backend
         Options forwarded to `torch.zeros` (only used when `dat` is None).

     Raises
     ------
     ValueError
         If the basis is given both through `dat` and `basis`, or not
         given at all.
     """
     # Shortcut: the first positional argument may be the basis name.
     if isinstance(dat, str):
         if basis is not None:
             raise ValueError('`basis` provided but `dat` looks like '
                              'a basis.')
         dat, basis = None, dat
     elif basis is None:
         raise ValueError('A basis should be provided')

     # `_basis` is cleared before `basis` is assigned -- presumably
     # `basis` is a property backed by `_basis` (confirm in the class).
     self._basis = None
     self.basis = basis
     self.dim = dim
     self.dat = (torch.zeros(spatial.affine_basis_size(basis, dim), **backend)
                 if dat is None else dat)
     self._cache = self._icache = None
Exemplo n.º 2
0
    def __init__(self,
                 dim,
                 group='CSO',
                 mode='lie',
                 encoder='leastsquares',
                 unet=None,
                 pull=None):
        """Build the UNet, affine encoder, exponential and resampling modules.

        Parameters
        ----------
        dim : {1, 2, 3}
            Dimensionality of the input.

        group : {'T', 'SO', 'SE', 'D', 'CSO', 'SL', 'GL+', 'Aff+'}, default='CSO'
            Affine group encoded:

            - 'T'   : Translations
            - 'SO'  : Special Orthogonal (rotations)
            - 'SE'  : Special Euclidean (translations + rotations)
            - 'D'   : Dilations (translations + isotropic scalings)
            - 'CSO' : Conformal Special Orthogonal
                      (translations + rotations + isotropic scalings)
            - 'SL'  : Special Linear (rotations + isovolumic zooms + shears)
            - 'GL+' : General Linear [det>0] (rotations + zooms + shears)
            - 'Aff+': Affine [det>0] (translations + rotations + zooms + shears)

        mode : {'lie', 'classic'}, default='lie'
            Encoding of the affine parameters: `AffineExp` if 'lie',
            `AffineClassic` for any other value.
            The basis 'SL' is only available in mode 'lie'.

        encoder : {'leastsquares', 'cnn'} or dict, default='leastsquares'
            Encoder used to transform the dense displacement field
            into affine parameters. If 'leastsquares', compute the
            least squares affine matrix and convert it into parameters
            using `AffineLog` or `AffineClassicInverse`. If 'cnn' (or
            a dictionary of CNN options), use an encoding CNN that takes
            as input the displacement field and an identity grid and
            directly outputs the affine parameters. CNN options:

            - encoder : default=None (forwarded to `CNN`)
            - stack : default=None (forwarded to `CNN`)
            - kernel_size : int, default=3
            - batch_norm : bool, default=False
            - reduction : str, default='max'
            - activation : callable, default=LeakyReLU(0.2)
            - final_activation : default='same'

        unet : dict, optional
            Dictionary of UNet parameters, with fields:

            - kernel_size : int, default=3
                Kernel size of the UNet.
            - encoder : list[int], default=None
                Number of channels after each encoding layer
                (None lets `UNet` use its own default).
            - decoder : list[int], default=None
                Number of channels after each decoding layer
                (None lets `UNet` use its own default).
            - batch_norm : bool, default=False
                Use batch normalization in the UNet.
            - activation : callable, default=LeakyReLU(0.2)
                Activation function.

        pull : dict, optional
            Dictionary of GridPull parameters, with fields:

            - interpolation : int, default=1
                Interpolation order.
            - bound : bound_type, default='dct2'
                Boundary conditions of the image.
            - extrapolate : bool, default=False
                Extrapolate data outside of the field of view.

        Raises
        ------
        ValueError
            If `encoder` is neither 'leastsquares', 'cnn' nor a dict.
        """

        # Default parameters for the submodules. Work on copies so that
        # the dictionaries provided by the caller are never mutated.

        unet = dict(unet or {})
        unet.setdefault('encoder', None)
        unet.setdefault('decoder', None)
        unet.setdefault('kernel_size', 3)
        unet.setdefault('batch_norm', False)
        unet.setdefault('activation', tnn.LeakyReLU(0.2))

        pull = dict(pull or {})
        pull.setdefault('interpolation', 1)
        pull.setdefault('bound', 'dct2')
        pull.setdefault('extrapolate', False)

        cnn = {}
        if isinstance(encoder, dict):
            cnn = dict(encoder)
            encoder = 'cnn'
        if encoder not in ('leastsquares', 'cnn'):
            raise ValueError("Unknown encoder {!r}: expected 'leastsquares', "
                             "'cnn' or a dict of CNN options."
                             .format(encoder))
        if encoder == 'cnn':
            # CNN options are read from the CNN dictionary itself
            # (not from the UNet one).
            cnn.setdefault('encoder', None)
            cnn.setdefault('stack', None)
            cnn.setdefault('kernel_size', 3)
            cnn.setdefault('batch_norm', False)
            cnn.setdefault('reduction', 'max')
            cnn.setdefault('activation', tnn.LeakyReLU(0.2))
            cnn.setdefault('final_activation', 'same')

        # instantiate submodules

        super().__init__()
        self.unet = UNet(dim,
                         input_channels=2,
                         output_channels=dim,
                         encoder=unet['encoder'],
                         decoder=unet['decoder'],
                         batch_norm=unet['batch_norm'],
                         kernel_size=unet['kernel_size'],
                         activation=unet['activation'])
        if encoder == 'leastsquares':
            self.dense2aff = DenseToAffine(shift=True)
            if mode == 'lie':
                self.log = AffineLog(basis=group)
            else:
                self.log = AffineClassicInverse(basis=group)
            self.dense2prm = self._dense2prm_ls
        else:
            nb_prm = spatial.affine_basis_size(group, dim)
            # Input: displacement field + identity grid (dim channels each).
            self.cnn = CNN(dim,
                           input_channels=2 * dim,
                           output_channels=nb_prm,
                           encoder=cnn['encoder'],
                           stack=cnn['stack'],
                           batch_norm=cnn['batch_norm'],
                           kernel_size=cnn['kernel_size'],
                           reduction=cnn['reduction'],
                           activation=cnn['activation'],
                           final_activation=cnn['final_activation'])
            self.dense2prm = self._dense2prm_cnn

        if mode == 'lie':
            self.exp = AffineExp(dim, basis=group)
        else:
            self.exp = AffineClassic(dim, basis=group)
        self.grid = AffineGrid(shift=True)
        self.pull = GridPull(interpolation=pull['interpolation'],
                             bound=pull['bound'],
                             extrapolate=pull['extrapolate'])
        self.dim = dim
        self.group = group
        self.encoder = encoder
        self.mode = mode

        # register losses/metrics
        self.tags = ['image', 'dense', 'affine']
Exemplo n.º 3
0
def diffeo(source, target, group='SE', origin='center',
           image_loss=None, vel_loss=None, pull=None, optim_affine=True,
           max_iter=1000, lr=0.1, min_lr=1e-7, init=None, device=None):
    """Diffeomorphic registration (affine + velocity field)

    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across
       the channel dimension. The loss function is then responsible
       for unstacking the tensor and computing the appropriate
       losses. The drawback of this approach is that all inputs
       must share the same lattice and orientation matrix, as
       well as the same interpolation order. The advantage is that
       it simplifies the signature of this function.

    Parameters
    ----------
    source : tensor or (tensor, affine)
        The source (moving) image, with shape (batch, channel, *spatial).
    target : tensor or (tensor, affine)
        The target (fixed) image, with shape (batch, channel, *spatial).
    group : str, default='SE'
        Affine sub-group to optimize. Any name accepted by
        `affine_group_converter` (e.g. 'tr', 'rot', 'rigid', 'sim',
        'lin', 'aff', or 'SE').
    origin : {'native', 'center'}, default='center'
        Whether to rotate about the origin of the world-space ('native')
        or the center of the target field-of-view ('center').
        When the origin of the world-space is far off (say you are
        registering smaller blocks cropped from a larger MRI), it can
        be beneficial to rotate about the center of the FOV.
    image_loss : callable(mov, fix) -> loss or float, default=MutualInfoLoss()
        A loss function that takes two inputs of shape
        (batch, channel, *spatial), or a factor that multiplies the
        default mutual information loss.
    vel_loss : callable(vel) -> loss or float, default=BendingLoss()
        A loss function applied to the velocity field, of shape
        (batch, *spatial, D), or a factor that multiplies the default
        bending loss.
    pull : dict, optional
        interpolation : int or str, default='linear'
            Interpolation order
        bound : bound_like, default='dct2'
            Boundary condition
        extrapolate : bool, default=False
            Extrapolate out-of-bound data using the boundary conditions.
        The caller's dictionary is not modified.
    optim_affine : bool, default=True
        Optimize the affine parameters jointly with the velocity.
        If False, the affine parameters stay at their initial value.
    max_iter : int, default=1000
        Maximum number of iterations
    lr : float or (float, float), default=0.1
        Initial learning rate of the (affine, velocity) parameters.
    min_lr : float or (float, float), default=1e-7
        Minimum learning rate of the (affine, velocity) parameters.
        The optimization is stopped once every optimized group has
        reached it.
    init : ([batch], nb_prm) tensor_like, default=0
        Initial guess for the affine parameters.
    device : {'cpu', 'cuda', 'cuda:<id>'}, optional
        Backend to use

    Returns
    -------
    q : (batch, nb_prm) tensor
        Parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    vel : (batch, *shape, D) tensor
        Initial velocity
    moved : tensor
        Source image moved to target space.
    """
    group = affine_group_converter(group)
    # Copy so that the caller's dictionary is not modified.
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dct2')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2
    spatial_shape = target.shape[2:]

    # Shift origin to the center of the target field of view
    if origin == 'center':
        # Only the spatial dimensions define the FOV (skip batch/channel).
        shift = torch.as_tensor(spatial_shape, **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=optim_affine)
    velocity = torch.zeros([batch, *spatial_shape, dim], **backend)
    velocity = nn.Parameter(velocity, requires_grad=True)

    def warp(q, vel):
        """Resample `source` with affine parameters `q` and velocity `vel`."""
        grid = spatial.exp(vel)
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff, grid)
        return spatial.grid_pull(source, grid, **pull_opt)

    # Prepare losses. Each default loss binds its own factor: a single
    # shared `factor` variable would be late-bound by both closures, and
    # the image loss would end up scaled by the velocity factor.
    if not callable(image_loss):
        image_loss_fn = nni.MutualInfoLoss()
        image_factor = 1. if image_loss is None else image_loss
        image_loss = lambda x, y: image_factor * image_loss_fn(x, y)

    if not callable(vel_loss):
        vel_loss_fn = nni.BendingLoss(bound='dft')
        vel_factor = 1. if vel_loss is None else vel_loss
        vel_loss = lambda x: vel_factor * vel_loss_fn(
            core.utils.last2channel(x))

    # Parameter groups are (affine, velocity) when the affine part is
    # optimized, (velocity,) otherwise. The velocity always uses lr[1]
    # and min_lr[1], whatever the value of `optim_affine`.
    lr = core.utils.make_list(lr, 2)
    min_lr = core.utils.make_list(min_lr, 2)
    opt_prm = [{'params': velocity, 'lr': lr[1]}]
    group_min_lr = [min_lr[1]]
    if optim_affine:
        opt_prm.insert(0, {'params': parameters, 'lr': lr[0]})
        group_min_lr.insert(0, min_lr[0])
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    scheduler = ReduceLROnPlateau(optim)

    def forward():
        moved = warp(parameters, velocity)
        return image_loss(moved, target) + vel_loss(velocity)

    # Optim loop
    loss_avg = 0
    for n_iter in range(1, max_iter + 1):

        optim.zero_grad(set_to_none=True)
        loss_val = forward()
        loss_val.backward()
        # No closure: passing `forward` would make Adam evaluate the
        # loss a second time for nothing.
        optim.step()

        with torch.no_grad():
            loss_avg += loss_val
            if n_iter % 10 == 0:
                loss_avg /= 10
                scheduler.step(loss_avg)

                print('{:4d} {:12.6f} | lr={:g} '
                      .format(n_iter, loss_avg.item(),
                              optim.param_groups[0]['lr']),
                      end='\r')
                loss_avg = 0

        if all(pg['lr'] < pg_min_lr
               for pg, pg_min_lr in zip(optim.param_groups, group_min_lr)):
            print('\nConverged.')
            break

    print('')
    with torch.no_grad():
        moved = warp(parameters, velocity)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            # Express the affine about the native origin:
            # T(-s) @ A @ T(s) has translation t + R @ s - s.
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
    return (parameters.detach(),
            aff.detach(),
            velocity.detach(),
            moved.detach())
Exemplo n.º 4
0
def affine(source,
           target,
           group='SE',
           loss=None,
           pull=None,
           preproc=True,
           max_iter=1000,
           device=None,
           origin='center',
           init=None,
           lr=0.1,
           scheduler=ReduceLROnPlateau):
    """Affine registration

    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across
       the channel dimension. The loss function is then responsible
       for unstacking the tensor and computing the appropriate
       losses. The drawback of this approach is that all inputs
       must share the same lattice and orientation matrix, as
       well as the same interpolation order. The advantage is that
       it simplifies the signature of this function.

    Parameters
    ----------
    source : tensor or (tensor, affine)
        The source (moving) image, with shape (batch, channel, *spatial).
    target : tensor or (tensor, affine)
        The target (fixed) image, with shape (batch, channel, *spatial).
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
        Affine sub-group to optimize.
    loss : callable(mov, fix) -> loss, default=MutualInfoLoss()
        Image matching loss.
    pull : dict, optional
        interpolation : int or str, default='linear'
        bound : bound_like, default='dct2'
        extrapolate : bool, default=False
        The caller's dictionary is not modified.
    preproc : bool, default=True
        Rescale both images to [0, 1] before registration.
    max_iter : int, default=1000
        Number of iterations.
    device : device, optional
    origin : {'native', 'center'}, default='center'
        Rotate about the world-space origin ('native') or about the
        center of the target field of view ('center').
    init : tensor_like, default=0
        Initial affine parameters, reshaped to (batch, nb_prm).
    lr : float, default=0.1
        Learning rate of the Adam optimizer.
    scheduler : Scheduler class or None, default=ReduceLROnPlateau
        Learning-rate scheduler class, instantiated with the optimizer
        and stepped every 10 iterations.

    Returns
    -------
    q : (batch, nb_prm) tensor
        Parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    moved : tensor
        Source image moved to target space.
    """
    # Copy so that the caller's dictionary is not modified.
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dct2')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2
    spatial_shape = target.shape[2:]

    # Rescale to [0, 1]
    if preproc:
        source = rescale(source)
        target = rescale(target)

    # Shift origin to the center of the target field of view
    if origin == 'center':
        # Only the spatial dimensions define the FOV (skip batch/channel).
        shift = torch.as_tensor(spatial_shape, **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        # Clone before editing in place: `prepare` may return the
        # caller's own affine tensors.
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=True)
    identity = spatial.identity_grid(spatial_shape, **backend)
    # Index inserting one singleton axis per spatial dim between the
    # batch axis and the matrix axes of the affine.
    expd = (slice(None), ) + (None, ) * dim + (slice(None), slice(None))

    def warp(q):
        """Resample `source` into target space with affine parameters `q`."""
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff[expd], identity)
        return spatial.grid_pull(source, grid, **pull_opt)

    # Prepare loss and optimizer
    if loss is None:
        loss = nni.MutualInfoLoss()

    optim = torch.optim.Adam([parameters], lr=lr)
    if scheduler is not None:
        scheduler = scheduler(optim)

    # Optim loop
    for n_iter in range(1, max_iter + 1):

        optim.zero_grad(set_to_none=True)
        moved = warp(parameters)
        loss_val = loss(moved, target)
        loss_val.backward()
        optim.step()

        if n_iter % 10 == 0:
            if scheduler is not None:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(loss_val)
                else:
                    scheduler.step()
            print('{:4d} {:12.6f} | lr={:g}'.format(
                n_iter, loss_val.item(), optim.param_groups[0]['lr']),
                  end='\r')

    print('')
    with torch.no_grad():
        moved = warp(parameters)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            # Express the affine about the native origin:
            # T(-s) @ A @ T(s) has translation t + R @ s - s.
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
        aff.requires_grad_(False)
    return parameters, aff, moved
Exemplo n.º 5
0
def diffeo(source,
           target,
           group='SE',
           image_loss=None,
           vel_loss=None,
           pull=None,
           preproc=False,
           max_iter=1000,
           device=None,
           origin='center',
           init=None,
           lr=1e-4,
           optim_affine=True,
           scheduler=ReduceLROnPlateau):
    """Diffeomorphic registration (affine + velocity field).

    Parameters
    ----------
    source : path or tensor or (tensor, affine)
    target : path or tensor or (tensor, affine)
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
        Affine sub-group to optimize.
    image_loss : callable(mov, fix) -> loss or float, default=MutualInfoLoss()
        Image matching loss, or a factor multiplying the default loss.
    vel_loss : callable(vel) -> loss or float, default=BendingLoss()
        Regularization of the velocity field (batch, *spatial, D), or a
        factor multiplying the default bending loss.
    pull : dict, optional
        interpolation : int or str, default='linear'
        bound : bound_like, default='dct2'
        extrapolate : bool, default=False
        The caller's dictionary is not modified.
    preproc : bool, default=False
        Currently not consulted: both images are always rescaled
        to [0, 1].
    max_iter : int, default=1000
    device : device, optional
    origin : {'native', 'center'}, default='center'
        Rotate about the world-space origin ('native') or about the
        center of the target field of view ('center').
    init : tensor_like, default=0
        Initial affine parameters, reshaped to (batch, nb_prm).
    lr : float or (float, float), default=1e-4
        Learning rates of the (affine, velocity) parameters.
    optim_affine : bool, default=True
        Optimize the affine parameters jointly with the velocity.
    scheduler : Scheduler class or None, default=ReduceLROnPlateau
        Instantiated with the optimizer and `cooldown=5`, and stepped
        every 10 iterations.

    Returns
    -------
    q : (batch, nb_prm) tensor
        Parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    vel : (batch, *spatial, D) tensor
        Initial velocity of the diffeomorphic transform.
        The full warp is `(aff @ aff_src).inv() @ aff_trg @ exp(vel)`
    moved : tensor
        Source image moved to target space.
    """
    # Copy so that the caller's dictionary is not modified.
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dct2')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2
    spatial_shape = target.shape[2:]

    # Rescale to [0, 1]
    # NOTE(review): unlike `affine` and `ffd`, `preproc` is not consulted
    # here and rescaling always happens -- confirm this is intended.
    source = rescale(source)
    target = rescale(target)

    # Shift origin to the center of the target field of view
    if origin == 'center':
        # Only the spatial dimensions define the FOV (skip batch/channel).
        shift = torch.as_tensor(spatial_shape, **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=optim_affine)
    velocity = torch.zeros([batch, *spatial_shape, dim], **backend)
    velocity = nn.Parameter(velocity, requires_grad=True)

    def warp(q, vel):
        """Resample `source` with affine parameters `q` and velocity `vel`."""
        grid = spatial.exp(vel)
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff, grid)
        return spatial.grid_pull(source, grid, **pull_opt)

    # Prepare losses. Each default loss binds its own factor: a single
    # shared `factor` variable would be late-bound by both closures, and
    # the image loss would end up scaled by the velocity factor.
    if not callable(image_loss):
        image_loss_fn = nni.MutualInfoLoss()
        image_factor = 1. if image_loss is None else image_loss
        image_loss = lambda x, y: image_factor * image_loss_fn(x, y)

    if not callable(vel_loss):
        vel_loss_fn = nni.BendingLoss(bound='dft')
        vel_factor = 1. if vel_loss is None else vel_loss
        vel_loss = lambda x: vel_factor * vel_loss_fn(
            core.utils.last2channel(x))

    # The velocity always uses lr[1], whether or not the affine
    # parameters are optimized as well.
    lr = core.utils.make_list(lr, 2)
    opt_prm = [{'params': velocity, 'lr': lr[1]}]
    if optim_affine:
        opt_prm.insert(0, {'params': parameters, 'lr': lr[0]})
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    if scheduler is not None:
        scheduler = scheduler(optim, cooldown=5)

    # Optim loop
    loss_avg = 0
    for n_iter in range(1, max_iter + 1):

        optim.zero_grad(set_to_none=True)
        moved = warp(parameters, velocity)
        loss_val = image_loss(moved, target) + vel_loss(velocity)
        loss_val.backward()
        optim.step()
        with torch.no_grad():
            loss_avg += loss_val

        if n_iter % 10 == 0:
            # Average over the last 10 iterations
            loss_avg /= 10
            if scheduler is not None:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(loss_avg)
                else:
                    scheduler.step()
            print('{:4d} {:12.6f} | lr={:g}'.format(
                n_iter, loss_avg.item(), optim.param_groups[0]['lr']),
                  end='\r')
            loss_avg = 0

    print('')
    with torch.no_grad():
        moved = warp(parameters, velocity)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            # Express the affine about the native origin:
            # T(-s) @ A @ T(s) has translation t + R @ s - s.
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
        aff.requires_grad_(False)
    return parameters, aff, velocity, moved
Exemplo n.º 6
0
def ffd(source,
        target,
        grid_shape=10,
        group='SE',
        image_loss=None,
        def_loss=None,
        pull=None,
        preproc=True,
        max_iter=1000,
        device=None,
        origin='center',
        init=None,
        lr=1e-4,
        optim_affine=True,
        scheduler=ReduceLROnPlateau):
    """FFD (= cubic spline) registration

    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across
       the channel dimension. The loss function is then responsible
       for unstacking the tensor and computing the appropriate
       losses. The drawback of this approach is that all inputs
       must share the same lattice and orientation matrix, as
       well as the same interpolation order. The advantage is that
       it simplifies the signature of this function.

    Parameters
    ----------
    source : tensor or (tensor, affine)
    target : tensor or (tensor, affine)
    grid_shape : int or list[int], default=10
        Shape of the coarse grid of displacement parameters, which is
        upsampled (cubic interpolation) to the target shape.
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
        Affine sub-group to optimize.
    image_loss : callable(mov, fix) -> loss or float, default=MutualInfoLoss()
        Image matching loss, or a factor multiplying the default loss.
    def_loss : callable(disp) -> loss or float, default=BendingLoss()
        Regularization of the dense displacement (batch, *spatial, D),
        or a factor multiplying the default bending loss.
    pull : dict, optional
        interpolation : int or str, default='linear'
        bound : bound_like, default='dft'
        extrapolate : bool, default=False
        The caller's dictionary is not modified.
    preproc : bool, default=True
        Rescale both images to [0, 1] before registration.
    max_iter : int, default=1000
    device : device, optional
    origin : {'native', 'center'}, default='center'
        Rotate about the world-space origin ('native') or about the
        center of the target field of view ('center').
    init : tensor_like, default=0
        Initial affine parameters, reshaped to (batch, nb_prm).
    lr : float or (float, float), default=1e-4
        Learning rates of the (grid, affine) parameters.
    optim_affine : bool, default=True
        Optimize the affine parameters jointly with the grid.
    scheduler : Scheduler class or None, default=ReduceLROnPlateau
        Instantiated with the optimizer and `cooldown=5`, and stepped
        every 10 iterations.

    Returns
    -------
    q : (batch, nb_prm) tensor
        Affine parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    prm : (batch, *grid_shape, D) tensor
        Displacement parameters on the coarse grid.
    moved : tensor
        Source image moved to target space.
    """
    # Copy so that the caller's dictionary is not modified.
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dft')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2
    spatial_shape = target.shape[2:]

    # Rescale to [0, 1]
    if preproc:
        source = rescale(source)
        target = rescale(target)

    # Shift origin to the center of the target field of view
    if origin == 'center':
        # Only the spatial dimensions define the FOV (skip batch/channel).
        shift = torch.as_tensor(spatial_shape, **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        # Clone before editing in place: `prepare` may return the
        # caller's own affine tensors.
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        affine_parameters = torch.as_tensor(init, **backend).clone().detach()
        affine_parameters = affine_parameters.reshape([batch, nb_prm])
    else:
        affine_parameters = torch.zeros([batch, nb_prm], **backend)
    affine_parameters = nn.Parameter(affine_parameters,
                                     requires_grad=optim_affine)
    grid_shape = core.pyutils.make_list(grid_shape, dim)
    grid_parameters = torch.zeros([batch, *grid_shape, dim], **backend)
    grid_parameters = nn.Parameter(grid_parameters, requires_grad=True)
    identity = spatial.identity_grid(spatial_shape, **backend)
    # Index inserting one singleton axis per spatial dim between the
    # batch axis and the matrix axes of the affine.
    expd = (slice(None), ) + (None, ) * dim + (slice(None), slice(None))

    def warp(q, grid):
        """Resample `source` through `grid` composed with affine `q`."""
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff[expd], grid)
        return spatial.grid_pull(source, grid, **pull_opt)

    def exp(prm):
        """Upsample coarse parameters into a dense displacement and grid."""
        disp = spatial.resize_grid(prm,
                                   type='displacement',
                                   shape=spatial_shape,
                                   interpolation=3,
                                   bound='dft')
        return disp, disp + identity

    # Prepare losses. Each default loss binds its own factor: a single
    # shared `factor` variable would be late-bound by both closures, and
    # the image loss would end up scaled by the deformation factor.
    if not callable(image_loss):
        image_loss_fn = nni.MutualInfoLoss()
        image_factor = 1. if image_loss is None else image_loss
        image_loss = lambda x, y: image_factor * image_loss_fn(x, y)

    if not callable(def_loss):
        def_loss_fn = nni.BendingLoss(bound='dft')
        def_factor = 1. if def_loss is None else def_loss
        def_loss = lambda x: def_factor * def_loss_fn(
            core.utils.last2channel(x))

    lr = core.utils.make_list(lr, 2)
    opt_prm = [{
        'params': affine_parameters,
        'lr': lr[1]
    }, {
        'params': grid_parameters,
        'lr': lr[0]
    }] if optim_affine else [grid_parameters]
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    if scheduler is not None:
        scheduler = scheduler(optim, cooldown=5)

    # Optim loop (1-based so that every average spans 10 iterations)
    loss_avg = 0
    for n_iter in range(1, max_iter + 1):

        zero_grad_([affine_parameters, grid_parameters])
        disp, grid = exp(grid_parameters)
        moved = warp(affine_parameters, grid)
        # Regularize the whole (batch, *spatial, D) displacement, not
        # only its first batch element.
        loss_val = image_loss(moved, target) + def_loss(disp)
        loss_val.backward()
        optim.step()

        with torch.no_grad():
            loss_avg += loss_val

        if n_iter % 10 == 0:
            # Average over the last 10 iterations
            loss_avg /= 10
            if scheduler is not None:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(loss_avg)
                else:
                    scheduler.step()
            print('{:4d} {:12.6f} | lr={:g}'.format(
                n_iter, loss_avg.item(), optim.param_groups[0]['lr']),
                  end='\r')
            loss_avg = 0

    print('')
    with torch.no_grad():
        # Rebuild the grid from the final parameters: the loop's last
        # grid predates the last optimizer step (and does not exist
        # when max_iter == 0).
        _, grid = exp(grid_parameters)
        moved = warp(affine_parameters, grid)
        aff = core.linalg.expm(affine_parameters, basis)
        if origin == 'center':
            # Express the affine about the native origin:
            # T(-s) @ A @ T(s) has translation t + R @ s - s.
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
        aff.requires_grad_(False)
    return affine_parameters, aff, grid_parameters, moved