示例#1
0
 def call(self, x, y):
     """Channel-averaged mutual-information loss between `x` and `y`.

     Both inputs are split along their first axis (treated here as the
     channel axis) and the slices are paired in order. Each pair is
     given two leading singleton axes and passed to a freshly built
     `nn.MutualInfoLoss(patch_size=self.patch)`.

     NOTE: single-channel inputs are *not* broadcast. Pairing stops at
     the shorter input, while the normalisation uses the larger channel
     count.

     Parameters
     ----------
     x, y : tensor
         Inputs exposing `unbind(0)`.

     Returns
     -------
     loss
         Sum over paired channels of `MI / nb_channels`, multiplied by
         `self.factor` when that factor differs from 1. The plain
         integer 0 is returned when there is no channel pair.
     """
     x_channels = x.unbind(0)
     y_channels = y.unbind(0)
     # Dividing each term by the channel count yields an average across
     # channels, which mirrors how an MSE loss behaves.
     nb_channels = max(len(x_channels), len(y_channels))

     total = 0
     for x_chan, y_chan in zip(x_channels, y_channels):
         x_vol = x_chan[None, None]
         y_vol = y_chan[None, None]
         criterion = nn.MutualInfoLoss(patch_size=self.patch)
         total += criterion(x_vol, y_vol) / nb_channels

     if self.factor != 1:
         total = total * self.factor
     return total
示例#2
0
 def call(self, x, y):
     """Channel-averaged mutual-information loss between `x` and `y`.

     Both inputs are split along their first axis (treated here as the
     channel axis) and the slices are paired in order. Each pair is
     given two leading singleton axes and passed to a freshly built
     `nn.MutualInfoLoss` configured from `self.bins`, `self.patch`,
     `self.threshold`, `self.order` and `self.fwhm`.

     NOTE: single-channel inputs are *not* broadcast. Pairing stops at
     the shorter input, while the normalisation uses the larger channel
     count.

     Parameters
     ----------
     x, y : tensor
         Inputs exposing `unbind(0)`.

     Returns
     -------
     loss
         Sum over paired channels of `MI / nb_channels`, multiplied by
         `self.factor` when that factor differs from 1. The plain
         integer 0 is returned when there is no channel pair.
     """
     x_channels = x.unbind(0)
     y_channels = y.unbind(0)
     # Dividing each term by the channel count yields an average across
     # channels, which mirrors how an MSE loss behaves.
     nb_channels = max(len(x_channels), len(y_channels))

     total = 0
     for x_chan, y_chan in zip(x_channels, y_channels):
         x_vol = x_chan[None, None]
         y_vol = y_chan[None, None]
         criterion = nn.MutualInfoLoss(
             nb_bins=self.bins,
             patch_size=self.patch,
             mask=[None, self.threshold],
             order=self.order,
             fwhm=self.fwhm,
         )
         total += criterion(x_vol, y_vol) / nb_channels

     if self.factor != 1:
         total = total * self.factor
     return total
示例#3
0
def diffeo(source, target, group='SE', origin='center',
           image_loss=None, vel_loss=None, pull=None, optim_affine=True,
           max_iter=1000, lr=0.1, min_lr=1e-7, init=None, device=None):
    """Diffeomorphic registration

    Affine parameters and a velocity field are optimised jointly with
    Adam. The learning rates are reduced on plateaus of the running
    loss average, and the optimisation stops once every learning rate
    is below its `min_lr` (or after `max_iter` iterations).

    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across
       the channel dimension. The loss function is then responsible
       for unstacking the tensor and computing the appropriate
       losses. The drawback of this approach is that all inputs
       must share the same lattice and orientation matrix, as
       well as the same interpolation order. The advantage is that
       it simplifies the signature of this function.

    Parameters
    ----------
    source : tensor or (tensor, affine)
        The source (moving) image, with shape (batch, channel, *spatial).
    target : tensor or (tensor, affine)
        The target (fixed) image, with shape (batch, channel, *spatial).
    group : str, default='SE'
        Affine sub-group to optimize. The value is passed through
        `affine_group_converter`, so aliases such as
        {'tr', 'rot', 'rigid', 'sim', 'lin', 'aff'} are presumably
        accepted as well.
    origin : {'native', 'center'}, default='center'
        Whether to rotate about the origin of the world-space ('native')
        or the center of the target field-of-view ('center').
        When the origin of the world-space is far off (say you are
        registering smaller blocks cropped from a larger MRI), it can
        be beneficial to rotate about the center of the FOV.
    image_loss : float or callable(mov, fix) -> loss, default=MutualInfoLoss()
        Either a factor to multiply the mutual information loss with,
        or a loss function that takes two inputs of shape
        (batch, channel, *spatial).
    vel_loss : float or callable(vel) -> loss, default=BendingLoss()
        Either a factor to multiply the bending loss with, or a loss
        function that takes the velocity field, with shape
        (batch, *spatial, D).
    pull : dict, optional
        interpolation : int or str, default='linear'
            Interpolation order
        bound : bound_like, default='dct2'
            Boundary condition
        extrapolate : bool, default=False
            Extrapolate out-of-bound data using the boundary conditions.
        The dictionary is copied; the caller's object is not modified.
    optim_affine : bool, default=True
        Whether the affine parameters are optimised along with the
        velocity. If False, they stay at their initial value.
    max_iter : int, default=1000
        Maximum number of iterations
    lr : float or (float, float), default=0.1
        Initial learning rate(s), for the affine parameters and for
        the velocity (in that order).
    min_lr : float or (float, float), default=1e-7
        Minimum learning rate(s), in the same order as `lr`. The
        optimization is stopped once every optimised group has reached
        its minimum learning rate.
    init : ([batch], nb_prm) tensor_like, default=0
        Initial guess for the affine parameters.
    device : {'cpu', 'cuda', 'cuda:<id>'}, optional
        Backend to use

    Returns
    -------
    q : (batch, nb_prm) tensor
        Parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    vel : (batch, *shape, D) tensor
        Initial velocity
    moved : tensor
        Source image moved to target space.
    """
    group = affine_group_converter(group)

    # Work on a copy so that the caller's dictionary is left untouched.
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dct2')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2
    spatial_shape = target.shape[2:]

    # Shift origin
    if origin == 'center':
        # Only the spatial axes define the field of view: the batch and
        # channel axes must not enter the voxel-space center.
        shift = torch.as_tensor(spatial_shape, **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        # Clone so that the in-place shift does not leak to the caller.
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=optim_affine)
    velocity = torch.zeros([batch, *spatial_shape, dim], **backend)
    velocity = nn.Parameter(velocity, requires_grad=True)

    # Index that inserts `dim` singleton axes between the batch axis and
    # the matrix axes, so that one affine per batch element broadcasts
    # against a (batch, *spatial, dim) grid.
    expd = (slice(None), ) + (None, ) * dim + (slice(None), slice(None))

    def warp(q, vel):
        """Pull the source through exp(vel) composed with the affine."""
        grid = spatial.exp(vel)
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff[expd], grid)
        return spatial.grid_pull(source, grid, **pull_opt)

    # Prepare loss and optimizer
    # Each default loss gets its *own* factor variable: closures bind
    # names late, so sharing a single `factor` would make the image loss
    # use the velocity factor.
    if not callable(image_loss):
        image_loss_fn = nni.MutualInfoLoss()
        image_factor = 1. if image_loss is None else image_loss
        image_loss = lambda x, y: image_factor * image_loss_fn(x, y)

    if not callable(vel_loss):
        vel_loss_fn = nni.BendingLoss(bound='dft')
        vel_factor = 1. if vel_loss is None else vel_loss
        vel_loss = lambda x: (vel_factor
                              * vel_loss_fn(core.utils.last2channel(x)))

    lr = core.utils.make_list(lr, 2)
    min_lr = core.utils.make_list(min_lr, 2)
    # One parameter group per optimised quantity, each paired with its
    # own minimum learning rate (affine first, velocity second).
    opt_prm = []
    group_min_lr = []
    if optim_affine:
        opt_prm.append({'params': parameters, 'lr': lr[0]})
        group_min_lr.append(min_lr[0])
    opt_prm.append({'params': velocity, 'lr': lr[1]})
    group_min_lr.append(min_lr[1])
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    scheduler = ReduceLROnPlateau(optim)

    def forward():
        moved = warp(parameters, velocity)
        return image_loss(moved, target) + vel_loss(velocity)

    # Optim loop
    loss_avg = 0
    for n_iter in range(1, max_iter + 1):

        optim.zero_grad(set_to_none=True)
        loss_val = forward()
        loss_val.backward()
        # Gradients are already available: no closure is needed (passing
        # one would only trigger a second, wasted, forward pass).
        optim.step()

        with torch.no_grad():
            loss_avg += loss_val
            if n_iter % 10 == 0:
                # The scheduler tracks the average loss over 10 iterations
                loss_avg /= 10
                scheduler.step(loss_avg)

                print('{:4d} {:12.6f} | lr={:g} '
                      .format(n_iter, loss_avg.item(),
                              optim.param_groups[0]['lr']),
                      end='\r')
                loss_avg = 0

        if all(prm_group['lr'] < lower for prm_group, lower
               in zip(optim.param_groups, group_min_lr)):
            print('\nConverged.')
            break

    print('')
    with torch.no_grad():
        moved = warp(parameters, velocity)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            # Undo the change of origin applied before optimisation.
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
    return (parameters.detach(),
            aff.detach(),
            velocity.detach(),
            moved.detach())
示例#4
0
def affine(source,
           target,
           group='SE',
           loss=None,
           pull=None,
           preproc=True,
           max_iter=1000,
           device=None,
           origin='center',
           init=None,
           lr=0.1,
           scheduler=ReduceLROnPlateau):
    """Affine registration

    The affine parameters are optimised with Adam for `max_iter`
    iterations; the scheduler (if any) is stepped every 10 iterations.

    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across
       the channel dimension. The loss function is then responsible
       for unstacking the tensor and computing the appropriate
       losses. The drawback of this approach is that all inputs
       must share the same lattice and orientation matrix, as
       well as the same interpolation order. The advantage is that
       it simplifies the signature of this function.

    Parameters
    ----------
    source : tensor or (tensor, affine)
        The source (moving) image.
    target : tensor or (tensor, affine)
        The target (fixed) image.
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
        Affine sub-group to optimize.
    loss : callable(mov, fix) -> loss, default=MutualInfoLoss()
        Image similarity loss.
    pull : dict, optional
        interpolation : int or str, default='linear'
        bound : bound_like, default='dct2'
        extrapolate : bool, default=False
        The dictionary is copied; the caller's object is not modified.
    preproc : bool, default=True
        Apply `rescale` to both images before registration.
    max_iter : int, default=1000
        Number of iterations.
    device : device, optional
    origin : {'native', 'center'}, default='center'
        Whether to rotate about the origin of the world-space ('native')
        or the center of the target field-of-view ('center').
    init : tensor_like, default=0
        Initial guess for the affine parameters; it is reshaped to
        (batch, nb_prm).
    lr : float, default=0.1
        Initial learning rate.
    scheduler : Scheduler class or None, default=ReduceLROnPlateau
        Called as `scheduler(optim)` to build the scheduler.

    Returns
    -------
    q : (batch, nb_prm) tensor
        Parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    moved : tensor
        Source image moved to target space.

    """
    # Work on a copy so that the caller's dictionary is left untouched.
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dct2')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2
    spatial_shape = target.shape[2:]

    # Rescale to [0, 1]
    if preproc:
        source = rescale(source)
        target = rescale(target)

    # Shift origin
    if origin == 'center':
        # Only the spatial axes define the field of view: the batch and
        # channel axes must not enter the voxel-space center.
        shift = torch.as_tensor(spatial_shape, **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        # Clone so that the in-place shift does not leak to the caller.
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=True)
    identity = spatial.identity_grid(spatial_shape, **backend)

    # Index that inserts `dim` singleton axes between the batch axis and
    # the matrix axes, so that one affine per batch element broadcasts
    # against the (*spatial, dim) identity grid.
    expd = (slice(None), ) + (None, ) * dim + (slice(None), slice(None))

    def warp(q):
        """Pull the source through the affine encoded by `q`."""
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff[expd], identity)
        return spatial.grid_pull(source, grid, **pull_opt)

    # Prepare loss and optimizer
    if loss is None:
        loss = nni.MutualInfoLoss()

    optim = torch.optim.Adam([parameters], lr=lr)
    if scheduler is not None:
        scheduler = scheduler(optim)

    # Optim loop
    for n_iter in range(1, max_iter + 1):

        optim.zero_grad(set_to_none=True)
        moved = warp(parameters)
        loss_val = loss(moved, target)
        loss_val.backward()
        optim.step()

        if n_iter % 10 == 0:
            if scheduler is not None:
                # Plateau schedulers need the monitored value.
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(loss_val)
                else:
                    scheduler.step()
            print('{:4d} {:12.6f} | lr={:g}'.format(
                n_iter, loss_val.item(), optim.param_groups[0]['lr']),
                  end='\r')

    print('')
    with torch.no_grad():
        moved = warp(parameters)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            # Undo the change of origin applied before optimisation.
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
    return parameters, aff, moved
示例#5
0
def diffeo(source,
           target,
           group='SE',
           image_loss=None,
           vel_loss=None,
           pull=None,
           preproc=False,
           max_iter=1000,
           device=None,
           origin='center',
           init=None,
           lr=1e-4,
           optim_affine=True,
           scheduler=ReduceLROnPlateau):
    """Diffeomorphic registration

    Affine parameters and a velocity field are optimised jointly with
    Adam for `max_iter` iterations. The scheduler (if any) is stepped
    every 10 iterations, using the loss averaged over those iterations.

    Parameters
    ----------
    source : path or tensor or (tensor, affine)
        The source (moving) image.
    target : path or tensor or (tensor, affine)
        The target (fixed) image.
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
        Affine sub-group to optimize.
    image_loss : float or callable(mov, fix) -> loss, default=MutualInfoLoss()
        Either a factor to multiply the mutual information loss with,
        or a loss function of the moved and fixed images.
    vel_loss : float or callable(vel) -> loss, default=BendingLoss()
        Either a factor to multiply the bending loss with, or a loss
        function of the velocity field, with shape (batch, *spatial, D).
    pull : dict, optional
        interpolation : int or str, default='linear'
        bound : bound_like, default='dct2'
        extrapolate : bool, default=False
        The dictionary is copied; the caller's object is not modified.
    preproc : bool, default=False
        Apply `rescale` to both images before registration.
    max_iter : int, default=1000
        Number of iterations.
    device : device, optional
    origin : {'native', 'center'}, default='center'
        Whether to rotate about the origin of the world-space ('native')
        or the center of the target field-of-view ('center').
    init : tensor_like, default=0
        Initial guess for the affine parameters; it is reshaped to
        (batch, nb_prm).
    lr : float or (float, float), default=1e-4
        Initial learning rate(s), for the affine parameters and for
        the velocity (in that order).
    optim_affine : bool, default=True
        Whether the affine parameters are optimised along with the
        velocity.
    scheduler : Scheduler class or None, default=ReduceLROnPlateau
        Called as `scheduler(optim, cooldown=5)` to build the scheduler.

    Returns
    -------
    q : (batch, nb_prm) tensor
        Parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    vel : (batch, *spatial, D) tensor
        Initial velocity of the diffeomorphic transform.
        The full warp is `(aff @ aff_src).inv() @ aff_trg @ exp(vel)`
    moved : tensor
        Source image moved to target space.

    """
    # Work on a copy so that the caller's dictionary is left untouched.
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dct2')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2
    spatial_shape = target.shape[2:]

    # Rescale to [0, 1]
    # Both images are rescaled, and only when requested.
    if preproc:
        source = rescale(source)
        target = rescale(target)

    # Shift origin
    if origin == 'center':
        # Only the spatial axes define the field of view: the batch and
        # channel axes must not enter the voxel-space center.
        shift = torch.as_tensor(spatial_shape, **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        # Clone so that the in-place shift does not leak to the caller.
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=optim_affine)
    velocity = torch.zeros([batch, *spatial_shape, dim], **backend)
    velocity = nn.Parameter(velocity, requires_grad=True)

    # Index that inserts `dim` singleton axes between the batch axis and
    # the matrix axes, so that one affine per batch element broadcasts
    # against a (batch, *spatial, dim) grid.
    expd = (slice(None), ) + (None, ) * dim + (slice(None), slice(None))

    def warp(q, vel):
        """Pull the source through exp(vel) composed with the affine."""
        grid = spatial.exp(vel)
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff[expd], grid)
        return spatial.grid_pull(source, grid, **pull_opt)

    # Prepare loss and optimizer
    # Each default loss gets its *own* factor variable: closures bind
    # names late, so sharing a single `factor` would make the image loss
    # use the velocity factor.
    if not callable(image_loss):
        image_loss_fn = nni.MutualInfoLoss()
        image_factor = 1. if image_loss is None else image_loss
        image_loss = lambda x, y: image_factor * image_loss_fn(x, y)

    if not callable(vel_loss):
        vel_loss_fn = nni.BendingLoss(bound='dft')
        vel_factor = 1. if vel_loss is None else vel_loss
        vel_loss = lambda x: (vel_factor
                              * vel_loss_fn(core.utils.last2channel(x)))

    lr = core.utils.make_list(lr, 2)
    # The velocity always uses the second learning rate, whether or not
    # the affine parameters are optimised as well.
    opt_prm = []
    if optim_affine:
        opt_prm.append({'params': parameters, 'lr': lr[0]})
    opt_prm.append({'params': velocity, 'lr': lr[1]})
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    if scheduler is not None:
        scheduler = scheduler(optim, cooldown=5)

    # Optim loop
    loss_avg = 0
    for n_iter in range(1, max_iter + 1):

        optim.zero_grad(set_to_none=True)
        moved = warp(parameters, velocity)
        loss_val = image_loss(moved, target) + vel_loss(velocity)
        loss_val.backward()
        optim.step()
        with torch.no_grad():
            loss_avg += loss_val

        if n_iter % 10 == 0:
            # The scheduler tracks the average loss over 10 iterations
            loss_avg /= 10
            if scheduler is not None:
                # Plateau schedulers need the monitored value.
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(loss_avg)
                else:
                    scheduler.step()

            print('{:4d} {:12.6f} | lr={:g}'.format(
                n_iter, loss_avg.item(), optim.param_groups[0]['lr']),
                  end='\r')

            loss_avg = 0

    print('')
    with torch.no_grad():
        moved = warp(parameters, velocity)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            # Undo the change of origin applied before optimisation.
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
    return parameters, aff, velocity, moved
示例#6
0
def ffd(source,
        target,
        grid_shape=10,
        group='SE',
        image_loss=None,
        def_loss=None,
        pull=None,
        preproc=True,
        max_iter=1000,
        device=None,
        origin='center',
        init=None,
        lr=1e-4,
        optim_affine=True,
        scheduler=ReduceLROnPlateau):
    """FFD (= cubic spline) registration

    Affine parameters and a coarse grid of displacement coefficients are
    optimised jointly with Adam for `max_iter` iterations. The scheduler
    (if any) is stepped every 10 iterations, using the loss averaged
    over those iterations.

    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across
       the channel dimension. The loss function is then responsible
       for unstacking the tensor and computing the appropriate
       losses. The drawback of this approach is that all inputs
       must share the same lattice and orientation matrix, as
       well as the same interpolation order. The advantage is that
       it simplifies the signature of this function.

    Parameters
    ----------
    source : tensor or (tensor, affine)
        The source (moving) image.
    target : tensor or (tensor, affine)
        The target (fixed) image.
    grid_shape : int or sequence[int], default=10
        Shape of the grid of displacement coefficients (one value per
        spatial dimension).
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
        Affine sub-group to optimize.
    image_loss : float or callable(mov, fix) -> loss, default=MutualInfoLoss()
        Either a factor to multiply the mutual information loss with,
        or a loss function of the moved and fixed images.
    def_loss : float or callable(disp) -> loss, default=BendingLoss()
        Either a factor to multiply the bending loss with, or a loss
        function of the dense displacement. It receives `disp[0]`.
    pull : dict, optional
        interpolation : int or str, default='linear'
        bound : bound_like, default='dft'
        extrapolate : bool, default=False
        The dictionary is copied; the caller's object is not modified.
    preproc : bool, default=True
        Apply `rescale` to both images before registration.
    max_iter : int, default=1000
        Number of iterations.
    device : device, optional
    origin : {'native', 'center'}, default='center'
        Whether to rotate about the origin of the world-space ('native')
        or the center of the target field-of-view ('center').
    init : tensor_like, default=0
        Initial guess for the affine parameters; it is reshaped to
        (batch, nb_prm).
    lr : float or (float, float), default=1e-4
        Initial learning rate(s). The first value is used for the grid
        coefficients and the second one for the affine parameters.
    optim_affine : bool, default=True
        Whether the affine parameters are optimised along with the grid.
    scheduler : Scheduler class or None, default=ReduceLROnPlateau
        Called as `scheduler(optim, cooldown=5)` to build the scheduler.

    Returns
    -------
    q : (batch, nb_prm) tensor
        Affine parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    prm : (batch, *grid_shape, D) tensor
        Grid of displacement coefficients.
    moved : tensor
        Source image moved to target space.

    """
    # Work on a copy so that the caller's dictionary is left untouched.
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dft')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2
    spatial_shape = target.shape[2:]

    # Rescale to [0, 1]
    if preproc:
        source = rescale(source)
        target = rescale(target)

    # Shift origin
    if origin == 'center':
        # Only the spatial axes define the field of view: the batch and
        # channel axes must not enter the voxel-space center.
        shift = torch.as_tensor(spatial_shape, **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        # Clone so that the in-place shift does not leak to the caller.
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        affine_parameters = torch.as_tensor(init, **backend).clone().detach()
        affine_parameters = affine_parameters.reshape([batch, nb_prm])
    else:
        affine_parameters = torch.zeros([batch, nb_prm], **backend)
    affine_parameters = nn.Parameter(affine_parameters,
                                     requires_grad=optim_affine)
    grid_shape = core.pyutils.make_list(grid_shape, dim)
    grid_parameters = torch.zeros([batch, *grid_shape, dim], **backend)
    grid_parameters = nn.Parameter(grid_parameters, requires_grad=True)

    # The identity grid does not depend on the parameters: build it once.
    identity = spatial.identity_grid(spatial_shape, **backend)
    # Index that inserts `dim` singleton axes between the batch axis and
    # the matrix axes, so that one affine per batch element broadcasts
    # against a (batch, *spatial, dim) grid.
    expd = (slice(None), ) + (None, ) * dim + (slice(None), slice(None))

    def warp(q, grid):
        """Pull the source through `grid` composed with the affine."""
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff[expd], grid)
        return spatial.grid_pull(source, grid, **pull_opt)

    def exp(prm):
        """Upsample the coefficients to a dense displacement and grid."""
        disp = spatial.resize_grid(prm,
                                   type='displacement',
                                   shape=spatial_shape,
                                   interpolation=3,
                                   bound='dft')
        return disp, disp + identity

    # Prepare loss and optimizer
    # Each default loss gets its *own* factor variable: closures bind
    # names late, so sharing a single `factor` would make the image loss
    # use the deformation factor.
    if not callable(image_loss):
        image_loss_fn = nni.MutualInfoLoss()
        image_factor = 1. if image_loss is None else image_loss
        image_loss = lambda x, y: image_factor * image_loss_fn(x, y)

    if not callable(def_loss):
        def_loss_fn = nni.BendingLoss(bound='dft')
        def_factor = 1. if def_loss is None else def_loss
        def_loss = lambda x: (def_factor
                              * def_loss_fn(core.utils.last2channel(x)))

    lr = core.utils.make_list(lr, 2)
    if optim_affine:
        opt_prm = [{'params': affine_parameters, 'lr': lr[1]},
                   {'params': grid_parameters, 'lr': lr[0]}]
    else:
        opt_prm = [grid_parameters]
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    if scheduler is not None:
        scheduler = scheduler(optim, cooldown=5)

    # Optim loop
    loss_avg = 0
    # Iterations are counted from 1 so that every reported average
    # covers exactly 10 loss values (including the first one).
    for n_iter in range(1, max_iter + 1):

        optim.zero_grad(set_to_none=True)
        disp, grid = exp(grid_parameters)
        moved = warp(affine_parameters, grid)
        # NOTE(review): only `disp[0]` is regularised, whereas the
        # diffeomorphic variant passes the whole batched field to its
        # regulariser -- confirm that this is intended.
        loss_val = image_loss(moved, target) + def_loss(disp[0])
        loss_val.backward()
        optim.step()

        with torch.no_grad():
            loss_avg += loss_val

        if n_iter % 10 == 0:
            loss_avg /= 10
            if scheduler is not None:
                # Plateau schedulers need the monitored value.
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(loss_avg)
                else:
                    scheduler.step()

            print('{:4d} {:12.6f} | lr={:g}'.format(
                n_iter, loss_avg.item(), optim.param_groups[0]['lr']),
                  end='\r')

            loss_avg = 0

    print('')
    with torch.no_grad():
        # Rebuild the grid from the *final* coefficients: the grid left
        # over from the loop predates the last optimiser step (and does
        # not exist at all when `max_iter` is 0).
        _, grid = exp(grid_parameters)
        moved = warp(affine_parameters, grid)
        aff = core.linalg.expm(affine_parameters, basis)
        if origin == 'center':
            # Undo the change of origin applied before optimisation.
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
    return affine_parameters, aff, grid_parameters, moved