Ejemplo n.º 1
0
def deformable_align(
    fix,
    mov,
    fix_spacing,
    mov_spacing,
    control_point_spacing,
    control_point_levels,
    initial_transform=None,
    alignment_spacing=None,
    fix_mask=None,
    mov_mask=None,
    fix_origin=None,
    mov_origin=None,
    jaccard_filter_threshold=None,
    default=None,
    **kwargs,
):
    """
    Register moving to fixed image with a bspline parameterized deformation field

    Parameters
    ----------
    fix : ndarray
        the fixed image

    mov : ndarray
        the moving image; `fix.ndim` must equal `mov.ndim`

    fix_spacing : 1d array
        The spacing in physical units (e.g. mm or um) between voxels
        of the fixed image.
        Length must equal `fix.ndim`

    mov_spacing : 1d array
        The spacing in physical units (e.g. mm or um) between voxels
        of the moving image.
        Length must equal `mov.ndim`

    control_point_spacing : float
        The spacing in physical units (e.g. mm or um) between control
        points that parameterize the deformation. Smaller means
        more precise alignment, but also longer compute time. Larger
        means shorter compute time and smoother transform, but less
        precise.

    control_point_levels : list of type int
        The optimization scales for control point spacing. E.g. if
        `control_point_spacing` is 100.0 and `control_point_levels`
        is [1, 2, 4] then method will optimize at 400.0 units control
        points spacing, then optimize again at 200.0 units, then again
        at the requested 100.0 units control point spacing.
        The coarsest control point grid is derived from the LAST entry
        of this list, so list the levels in ascending order.

    initial_transform : 2d array, e.g. 4x4 (default: None)
        A rigid or affine matrix. It is set as the moving initial
        transform of the registration; the bspline is optimized on top
        of it. Only 2d (matrix) arrays are accepted.

    alignment_spacing : float (default: None)
        Fixed and moving images are skip sampled to a voxel spacing
        as close as possible to this value. Intended for very fast
        simple alignments (e.g. low amplitude motion correction)

    fix_mask : binary ndarray (default: None)
        A mask limiting metric evaluation region of the fixed image

    mov_mask : binary ndarray (default: None)
        A mask limiting metric evaluation region of the moving image

    fix_origin : 1d array (default: None)
        Origin of the fixed image.
        Length must equal `fix.ndim`

    mov_origin : 1d array (default: None)
        Origin of the moving image.
        Length must equal `mov.ndim`

    jaccard_filter_threshold : float in range [0, 1] (default: None)
        If `jaccard_filter_threshold`, `fix_mask`, and `mov_mask` are all
        defined (i.e. not None), then the Jaccard index between the masks
        is computed. If the index is less than this threshold the alignment
        is skipped and the default is returned. Useful for distributed piecewise
        workflows over heterogenous data.

    default : any object (default: None)
        Returned when the Jaccard filter fails, when registration raises,
        or when optimization fails to improve the image matching metric.
        If None, the parameters and displacement field of the initial
        (untouched) bspline transform are returned.

    **kwargs : any additional arguments
        Passed to `configure_irm`
        This is where you would set things like:
        metric, iterations, shrink_factors, and smooth_sigmas

    Returns
    -------
    params : 1d array
        The complete set of control point parameters concatenated
        as a 1d array (fixed parameters first, then the optimized
        parameters).

    field : ndarray
        The displacement field parameterized by the bspline control
        points, rendered at the original (pre skip sampling) fixed
        image shape

    Raises
    ------
    ValueError
        If `initial_transform` is given but is not a 2d array.
    """

    # validate up front; previously a non-2d array left the moving
    # initial transform unbound and crashed with UnboundLocalError
    if initial_transform is not None and len(initial_transform.shape) != 2:
        raise ValueError(
            "initial_transform must be a 2d matrix, got shape {}".format(
                initial_transform.shape
            )
        )

    def _params_and_field(image, bspline, shape):
        # concatenate fixed + optimized parameters and render the
        # displacement field on a grid of the requested shape
        fp = bspline.GetFixedParameters()
        pp = bspline.GetParameters()
        params = np.array(list(fp) + list(pp))
        field = ut.bspline_to_displacement_field(
            image,
            bspline,
            shape=shape,
        )
        return params, field

    # check jaccard index; the early return is deferred until `default`
    # has been built (it may depend on the bspline grid below)
    a = jaccard_filter_threshold is not None
    b = fix_mask is not None
    c = mov_mask is not None
    failed_jaccard = False
    if a and b and c:
        if not jaccard_filter(fix_mask, mov_mask, jaccard_filter_threshold):
            print("Masks failed jaccard_filter")
            print("Returning default")
            failed_jaccard = True

    # store initial fixed image shape; output fields use this shape even
    # when the images are skip sampled below
    initial_fix_shape = fix.shape

    # skip sample to alignment spacing; masks are sampled with the
    # ORIGINAL spacings, which are only overwritten afterwards
    if alignment_spacing is not None:
        fix, fix_spacing_ss = ut.skip_sample(fix, fix_spacing,
                                             alignment_spacing)
        mov, mov_spacing_ss = ut.skip_sample(mov, mov_spacing,
                                             alignment_spacing)
        if fix_mask is not None:
            fix_mask, _ = ut.skip_sample(fix_mask, fix_spacing,
                                         alignment_spacing)
        if mov_mask is not None:
            mov_mask, _ = ut.skip_sample(mov_mask, mov_spacing,
                                         alignment_spacing)
        fix_spacing = fix_spacing_ss
        mov_spacing = mov_spacing_ss

    # convert to sitk images, float32 type
    fix = ut.numpy_to_sitk(fix, fix_spacing, origin=fix_origin)
    mov = ut.numpy_to_sitk(mov, mov_spacing, origin=mov_origin)
    fix = sitk.Cast(fix, sitk.sitkFloat32)
    mov = sitk.Cast(mov, sitk.sitkFloat32)

    # set up registration object
    irm = configure_irm(**kwargs)

    # set initial moving transform (shape validated above)
    if initial_transform is not None:
        it = ut.matrix_to_affine_transform(initial_transform)
        irm.SetMovingInitialTransform(it)

    # get control point grid shape at the coarsest level
    fix_size_physical = [
        sz * sp for sz, sp in zip(fix.GetSize(), fix.GetSpacing())
    ]
    x, y = control_point_spacing, control_point_levels[-1]
    control_point_grid = [
        max(1, int(sz / (x * y))) for sz in fix_size_physical
    ]

    # set initial transform; scaleFactors refine the mesh per level
    transform = sitk.BSplineTransformInitializer(
        image1=fix,
        transformDomainMeshSize=control_point_grid,
        order=3,
    )
    irm.SetInitialTransformAsBSpline(
        transform,
        inPlace=True,
        scaleFactors=control_point_levels,
    )

    # store initial transform coordinates as default
    if default is None:
        default = _params_and_field(fix, transform, initial_fix_shape)

    # now that default is defined, determine jaccard result
    if failed_jaccard:
        return default

    # set masks
    if fix_mask is not None:
        fix_mask = ut.numpy_to_sitk(fix_mask, fix_spacing, origin=fix_origin)
        irm.SetMetricFixedMask(fix_mask)
    if mov_mask is not None:
        mov_mask = ut.numpy_to_sitk(mov_mask, mov_spacing, origin=mov_origin)
        irm.SetMetricMovingMask(mov_mask)

    # execute alignment, for any exceptions return default
    try:
        initial_metric_value = irm.MetricEvaluate(fix, mov)
        irm.Execute(fix, mov)
        final_metric_value = irm.MetricEvaluate(fix, mov)
    except Exception as e:
        print("Registration failed due to ITK exception:\n", e)
        print("\nReturning default")
        sys.stdout.flush()
        return default

    # if registration improved metric return result
    # otherwise return default
    if final_metric_value < initial_metric_value:
        sys.stdout.flush()
        # `transform` was optimized in place (inPlace=True above)
        return _params_and_field(fix, transform, initial_fix_shape)
    else:
        print("Optimization failed to improve metric")
        print("initial value: {}".format(initial_metric_value))
        print("final value: {}".format(final_metric_value))
        print("Returning default")
        sys.stdout.flush()
        return default
Ejemplo n.º 2
0
def affine_align(
        fix,
        mov,
        fix_spacing,
        mov_spacing,
        rigid=False,
        initial_transform=None,
        initialize_with_centering=False,
        alignment_spacing=None,
        fix_mask=None,
        mov_mask=None,
        fix_origin=None,
        mov_origin=None,
        jaccard_filter_threshold=None,
        default=np.eye(4),
        **kwargs,
):
    """
    Affine or rigid alignment of a fixed/moving image pair.
    Lots of flexibility in speed/accuracy trade off.
    Highly configurable and useful in many contexts.

    Parameters
    ----------
    fix : ndarray
        the fixed image

    mov : ndarray
        the moving image; `fix.ndim` must equal `mov.ndim`

    fix_spacing : 1d array
        The spacing in physical units (e.g. mm or um) between voxels
        of the fixed image.
        Length must equal `fix.ndim`

    mov_spacing : 1d array
        The spacing in physical units (e.g. mm or um) between voxels
        of the moving image.
        Length must equal `mov.ndim`

    rigid : bool (default: False)
        Restrict the alignment to rigid motion only

    initial_transform : 4x4 array (default: None)
        An initial rigid or affine matrix from which to initialize
        the optimization

    initialize_with_centering : bool (default: False)
        Initialize the optimization with a centering translation
        (`sitk.CenteredTransformInitializer`). Ignored when
        `initial_transform` is not None.

    alignment_spacing : float (default: None)
        Fixed and moving images are skip sampled to a voxel spacing
        as close as possible to this value. Intended for very fast
        simple alignments (e.g. low amplitude motion correction)

    fix_mask : binary ndarray (default: None)
        A mask limiting metric evaluation region of the fixed image

    mov_mask : binary ndarray (default: None)
        A mask limiting metric evaluation region of the moving image

    fix_origin : 1d array (default: None)
        Origin of the fixed image.
        Length must equal `fix.ndim`

    mov_origin : 1d array (default: None)
        Origin of the moving image.
        Length must equal `mov.ndim`

    jaccard_filter_threshold : float in range [0, 1] (default: None)
        If `jaccard_filter_threshold`, `fix_mask`, and `mov_mask` are all
        defined (i.e. not None), then the Jaccard index between the masks
        is computed. If the index is less than this threshold the alignment
        is skipped and the default is returned. Useful for distributed piecewise
        workflows over heterogenous data.

    default : 4x4 array (default: identity matrix)
        If the optimization fails, print error message but return this value.
        If left as the identity and `initial_transform` is given,
        `initial_transform` is returned instead. Array defaults are
        copied, so mutating a returned default never affects later calls.

    **kwargs : any additional arguments
        Passed to `configure_irm`
        This is where you would set things like:
        metric, iterations, shrink_factors, and smooth_sigmas

    Returns
    -------
    transform : 4x4 array
        The affine or rigid transform matrix matching moving to fixed
    """

    # the `np.eye(4)` default is built once at definition time and shared
    # by every call; hand out a private copy so callers mutating the
    # returned matrix cannot corrupt the default for later calls
    if isinstance(default, np.ndarray):
        default = default.copy()

    # update default if an initial transform is provided; array_equal is
    # False (rather than raising) for defaults of other shapes or types
    if initial_transform is not None and np.array_equal(default, np.eye(4)):
        default = initial_transform

    # check jaccard index
    a = jaccard_filter_threshold is not None
    b = fix_mask is not None
    c = mov_mask is not None
    if a and b and c:
        if not jaccard_filter(fix_mask, mov_mask, jaccard_filter_threshold):
            print("Masks failed jaccard_filter")
            print("Returning default")
            return default

    # skip sample to alignment spacing; masks use the ORIGINAL spacings,
    # which are only overwritten afterwards
    if alignment_spacing is not None:
        fix, fix_spacing_ss = ut.skip_sample(fix, fix_spacing,
                                             alignment_spacing)
        mov, mov_spacing_ss = ut.skip_sample(mov, mov_spacing,
                                             alignment_spacing)
        if fix_mask is not None:
            fix_mask, _ = ut.skip_sample(fix_mask, fix_spacing,
                                         alignment_spacing)
        if mov_mask is not None:
            mov_mask, _ = ut.skip_sample(mov_mask, mov_spacing,
                                         alignment_spacing)
        fix_spacing = fix_spacing_ss
        mov_spacing = mov_spacing_ss

    # convert to float32 sitk images
    fix = ut.numpy_to_sitk(fix, fix_spacing, origin=fix_origin)
    mov = ut.numpy_to_sitk(mov, mov_spacing, origin=mov_origin)
    fix = sitk.Cast(fix, sitk.sitkFloat32)
    mov = sitk.Cast(mov, sitk.sitkFloat32)

    # set up registration object
    irm = configure_irm(**kwargs)

    # select initial transform type
    if rigid and initial_transform is None:
        transform = sitk.Euler3DTransform()
    elif rigid and initial_transform is not None:
        transform = ut.matrix_to_euler_transform(initial_transform)
    elif not rigid and initial_transform is None:
        transform = sitk.AffineTransform(3)
    elif not rigid and initial_transform is not None:
        transform = ut.matrix_to_affine_transform(initial_transform)

    # consider initializing with centering
    if initial_transform is None and initialize_with_centering:
        transform = sitk.CenteredTransformInitializer(
            fix,
            mov,
            transform,
        )

    # set initial transform; it is optimized in place
    irm.SetInitialTransform(transform, inPlace=True)

    # set masks
    if fix_mask is not None:
        fix_mask = ut.numpy_to_sitk(fix_mask, fix_spacing, origin=fix_origin)
        irm.SetMetricFixedMask(fix_mask)
    if mov_mask is not None:
        mov_mask = ut.numpy_to_sitk(mov_mask, mov_spacing, origin=mov_origin)
        irm.SetMetricMovingMask(mov_mask)

    # execute alignment, for any exceptions return default
    try:
        initial_metric_value = irm.MetricEvaluate(fix, mov)
        irm.Execute(fix, mov)
        final_metric_value = irm.MetricEvaluate(fix, mov)
    except Exception as e:
        print("Registration failed due to ITK exception:\n", e)
        print("\nReturning default")
        sys.stdout.flush()
        return default

    # if centered, convert back to Euler3DTransform object
    if rigid and initialize_with_centering:
        transform = sitk.Euler3DTransform(transform)

    # if registration improved metric return result
    # otherwise return default
    if final_metric_value < initial_metric_value:
        sys.stdout.flush()
        return ut.affine_transform_to_matrix(transform)
    else:
        print("Optimization failed to improve metric")
        print("initial value: {}".format(initial_metric_value))
        print("final value: {}".format(final_metric_value))
        print("Returning default")
        sys.stdout.flush()
        return default
Ejemplo n.º 3
0
def apply_transform(
    fix,
    mov,
    fix_spacing,
    mov_spacing,
    transform_list,
    transform_spacing=None,
    fix_origin=None,
    mov_origin=None,
    interpolate_with_nn=False,
    extrapolate_with_nn=False,
):
    """
    Resample a moving image onto the fixed image grid through a list of
    transforms.

    Parameters
    ----------
    fix : ndarray
        The fixed image. It is the resampler's reference image, so it
        defines the output grid. Its dtype is also the output dtype.

    mov : ndarray
        The moving image to resample.

    fix_spacing : 1d array
        Voxel spacing of `fix` in physical units.

    mov_spacing : 1d array
        Voxel spacing of `mov` in physical units.

    transform_list : list of ndarray
        Transforms added, in list order, to a 3d `sitk.CompositeTransform`.
        Each entry is interpreted by its number of dimensions:
        2d -> affine matrix, 1d -> bspline parameters,
        4d -> displacement field.

    transform_spacing : None, tuple, or tuple of tuples (default: None)
        Spacing used to build displacement-field transforms. If None,
        `fix_spacing` is used. If `transform_spacing[i]` is a tuple, it
        is the spacing for `transform_list[i]`. Otherwise
        `transform_spacing` itself is used for every field.

    fix_origin : 1d array (default: None)
        Origin of the fixed image.

    mov_origin : 1d array (default: None)
        Origin of the moving image.

    interpolate_with_nn : bool (default: False)
        Use nearest neighbor interpolation (e.g. for masks/label images).

    extrapolate_with_nn : bool (default: False)
        Use nearest neighbor extrapolation for points outside `mov`.

    Returns
    -------
    ndarray
        The resampled moving image, cast to `fix.dtype`.
        NOTE(review): this is the fixed image's dtype, not the moving
        image's — confirm this is intended.

    Raises
    ------
    ValueError
        If an entry of `transform_list` is not 1d, 2d, or 4d.
    """

    # set global number of threads. LSB_DJOB_NUMPROC is the core count
    # allocated by the LSF scheduler. psutil.cpu_count(logical=False) can
    # return None when the physical core count is undeterminable, so fall
    # back to os.cpu_count() and finally 1.
    if "LSB_DJOB_NUMPROC" in os.environ:
        ncores = int(os.environ["LSB_DJOB_NUMPROC"])
    else:
        ncores = psutil.cpu_count(logical=False) or os.cpu_count() or 1
    sitk.ProcessObject.SetGlobalDefaultNumberOfThreads(2 * ncores)

    # convert images to sitk objects
    dtype = fix.dtype
    fix = ut.numpy_to_sitk(fix, fix_spacing, fix_origin)
    mov = ut.numpy_to_sitk(mov, mov_spacing, mov_origin)

    # construct transform
    transform = sitk.CompositeTransform(3)
    for i, t in enumerate(transform_list):

        # affine transforms
        if len(t.shape) == 2:
            t = ut.matrix_to_affine_transform(t)

        # bspline parameters
        elif len(t.shape) == 1:
            t = ut.bspline_parameters_to_transform(t)

        # fields
        elif len(t.shape) == 4:
            # set transform_spacing; guard the index so a single spacing
            # tuple shorter than transform_list does not raise IndexError
            if transform_spacing is None:
                sp = fix_spacing
            elif (
                i < len(transform_spacing)
                and isinstance(transform_spacing[i], tuple)
            ):
                sp = transform_spacing[i]
            else:
                sp = transform_spacing
            # create field
            t = ut.field_to_displacement_field_transform(t, sp)

        # anything else would otherwise reach AddTransform as a raw array
        else:
            raise ValueError(
                "transform_list[{}] has unsupported shape {}; expected a "
                "1d, 2d, or 4d array".format(i, t.shape)
            )

        # add to composite transform
        transform.AddTransform(t)

    # set up resampler object
    resampler = sitk.ResampleImageFilter()
    resampler.SetNumberOfThreads(2 * ncores)
    resampler.SetReferenceImage(sitk.Cast(fix, sitk.sitkFloat32))
    resampler.SetTransform(transform)

    # check for NN interpolation
    if interpolate_with_nn:
        resampler.SetInterpolator(sitk.sitkNearestNeighbor)

    # check for NN extrapolation
    if extrapolate_with_nn:
        resampler.SetUseNearestNeighborExtrapolator(True)

    # execute, return as numpy array
    resampled = resampler.Execute(sitk.Cast(mov, sitk.sitkFloat32))
    return sitk.GetArrayFromImage(resampled).astype(dtype)
Ejemplo n.º 4
0
def random_affine_search(
    fix,
    mov,
    fix_spacing,
    mov_spacing,
    max_translation,
    max_rotation,
    max_scale,
    max_shear,
    random_iterations,
    affine_align_best=0,
    alignment_spacing=None,
    fix_mask=None,
    mov_mask=None,
    fix_origin=None,
    mov_origin=None,
    jaccard_filter_threshold=None,
    use_patch_mutual_information=False,
    print_running_improvements=False,
    **kwargs,
):
    """
    Apply random affine matrices within given bounds to moving image. The best
    scoring affines can be further refined with gradient descent based affine
    alignment. The single best result is returned. This function is intended
    to find good initialization for a full affine alignment obtained by calling
    `affine_align`

    Parameters
    ----------
    fix : ndarray
        the fixed image

    mov : ndarray
        the moving image; `fix.ndim` must equal `mov.ndim`

    fix_spacing : 1d array
        The spacing in physical units (e.g. mm or um) between voxels
        of the fixed image.
        Length must equal `fix.ndim`

    mov_spacing : 1d array
        The spacing in physical units (e.g. mm or um) between voxels
        of the moving image.
        Length must equal `mov.ndim`

    max_translation : float
        The maximum amplitude translation allowed in random sampling.
        Each translation component is drawn uniformly from
        [-max_translation, max_translation], in physical units (e.g. um or mm)

    max_rotation : float
        The maximum amplitude rotation allowed in random sampling.
        Each rotation-vector component is drawn uniformly from
        [-max_rotation, max_rotation], in radians

    max_scale : float
        The maximum scaling factor allowed in random sampling. Scales are
        drawn log-uniformly from [1 / max_scale, max_scale]; 1 disables scaling.

    max_shear : float
        The maximum amplitude shearing allowed in random sampling.
        Shear parameters are drawn uniformly from [-max_shear, max_shear].

    random_iterations : int
        The number of random affine matrices to sample. The identity is
        always scored in addition to these.

    affine_align_best : int (default: 0)
        The best `affine_align_best` random affine matrices are refined
        by calling `affine_align` setting the random affine as the
        `initial_transform`. This is parameterized through **kwargs.
        The best *refined* matrix is returned. Values <= 0 disable refinement.

    alignment_spacing : float (default: None)
        Fixed and moving images are skip sampled to a voxel spacing
        as close as possible to this value. Intended for very fast
        simple alignments (e.g. low amplitude motion correction)

    fix_mask : binary ndarray (default: None)
        A mask limiting metric evaluation region of the fixed image

    mov_mask : binary ndarray (default: None)
        A mask limiting metric evaluation region of the moving image

    fix_origin : 1d array (default: None)
        Origin of the fixed image.
        Length must equal `fix.ndim`

    mov_origin : 1d array (default: None)
        Origin of the moving image.
        Length must equal `mov.ndim`

    jaccard_filter_threshold : float in range [0, 1] (default: None)
        If `jaccard_filter_threshold`, `fix_mask`, and `mov_mask` are all
        defined (i.e. not None), then the Jaccard index between the masks
        is computed. If the index is less than this threshold the alignment
        is skipped and the identity is returned. Useful for distributed piecewise
        workflows over heterogenous data.

    use_patch_mutual_information : bool (default: False)
        Score candidates with `patch_mutual_information` instead of a
        whole-image metric. **kwargs must then carry its arguments
        (e.g. `radius`, `stride`).

    print_running_improvements : bool (default: False)
        Print each candidate that improves on the best score so far.

    **kwargs : any additional arguments
        Passed to `configure_irm` to score random affines
        Also passed to `affine_align` for gradient descent
        based refinement

    Returns
    -------
    transform : 4x4 array
        The (refined) random affine matrix best initializing a match of
        the moving image to the fixed. Should be further refined by calling
        `affine_align`. The identity is returned if the Jaccard filter fails
        or if scoring fails too often due to ITK exceptions.
    """

    # check jaccard index
    a = jaccard_filter_threshold is not None
    b = fix_mask is not None
    c = mov_mask is not None
    if a and b and c:
        if not jaccard_filter(fix_mask, mov_mask, jaccard_filter_threshold):
            print("Masks failed jaccard_filter")
            print("Returning default")
            return np.eye(4)

    # define conversion from params to affine transform
    # NOTE: `fix` and `fix_spacing` are read from the enclosing scope at
    # call time, i.e. after any skip sampling below
    def params_to_affine_matrix(params):

        # translation
        translation = np.eye(4)
        translation[:3, -1] = params[:3]

        # rotation about the physical center of the fixed image
        rotation = np.eye(4)
        rotation[:3, :3] = Rotation.from_rotvec(params[3:6]).as_matrix()
        center = np.array(fix.shape) / 2 * fix_spacing
        tl, tr = np.eye(4), np.eye(4)
        tl[:3, -1], tr[:3, -1] = center, -center
        rotation = np.matmul(tl, np.matmul(rotation, tr))

        # scale
        scale = np.diag(list(params[6:9]) + [
            1,
        ])

        # shear
        shx, shy, shz = np.eye(4), np.eye(4), np.eye(4)
        shx[1, 0], shx[2, 0] = params[10], params[11]
        shy[0, 1], shy[2, 1] = params[9], params[11]
        shz[0, 2], shz[1, 2] = params[9], params[10]
        shear = np.matmul(shz, np.matmul(shy, shx))

        # compose
        aff = np.matmul(rotation, translation)
        aff = np.matmul(scale, aff)
        aff = np.matmul(shear, aff)
        return aff

    # generate random parameters, first row is always identity
    params = np.zeros((random_iterations + 1, 12))
    params[:, 6:9] = 1  # default for scale params
    F = lambda mx: 2 * mx * np.random.rand(random_iterations, 3) - mx
    if max_translation != 0: params[1:, 0:3] = F(max_translation)
    if max_rotation != 0: params[1:, 3:6] = F(max_rotation)
    if max_scale != 1: params[1:, 6:9] = np.e**F(np.log(max_scale))
    if max_shear != 0: params[1:, 9:] = F(max_shear)

    # skip sample to alignment spacing; masks use the ORIGINAL spacings
    if alignment_spacing is not None:
        fix, fix_spacing_ss = ut.skip_sample(fix, fix_spacing,
                                             alignment_spacing)
        mov, mov_spacing_ss = ut.skip_sample(mov, mov_spacing,
                                             alignment_spacing)
        if fix_mask is not None:
            fix_mask, _ = ut.skip_sample(fix_mask, fix_spacing,
                                         alignment_spacing)
        if mov_mask is not None:
            mov_mask, _ = ut.skip_sample(mov_mask, mov_spacing,
                                         alignment_spacing)
        fix_spacing = fix_spacing_ss
        mov_spacing = mov_spacing_ss

    # keep track of poor alignments later; incremented by score_affine
    fail_count = 0

    # define metric evaluation
    if use_patch_mutual_information:

        # wrap patch_mi metric
        def score_affine(affine):

            # apply transform
            aligned = apply_transform(
                fix,
                mov,
                fix_spacing,
                mov_spacing,
                transform_list=[
                    affine,
                ],
                fix_origin=fix_origin,
                mov_origin=mov_origin,
            )

            # mov mask
            mov_mask_aligned = None
            if mov_mask is not None:
                mov_mask_aligned = apply_transform(
                    fix,
                    mov_mask,
                    fix_spacing,
                    mov_spacing,
                    transform_list=[
                        affine,
                    ],
                    fix_origin=fix_origin,
                    mov_origin=mov_origin,
                    interpolate_with_nn=True,
                )

            return patch_mutual_information(
                fix,
                aligned,
                fix_spacing,
                fix_mask=fix_mask,
                mov_mask=mov_mask_aligned,
                return_metric_image=False,
                **kwargs,
            )

    # otherwise score entire image domain
    else:

        # convert to float32 sitk images
        fix_sitk = ut.numpy_to_sitk(fix, fix_spacing, origin=fix_origin)
        mov_sitk = ut.numpy_to_sitk(mov, mov_spacing, origin=mov_origin)
        fix_sitk = sitk.Cast(fix_sitk, sitk.sitkFloat32)
        mov_sitk = sitk.Cast(mov_sitk, sitk.sitkFloat32)

        # create irm to use metric object
        irm = configure_irm(**kwargs)

        # set masks
        if fix_mask is not None:
            fix_mask_sitk = ut.numpy_to_sitk(fix_mask,
                                             fix_spacing,
                                             origin=fix_origin)
            irm.SetMetricFixedMask(fix_mask_sitk)
        if mov_mask is not None:
            mov_mask_sitk = ut.numpy_to_sitk(mov_mask,
                                             mov_spacing,
                                             origin=mov_origin)
            irm.SetMetricMovingMask(mov_mask_sitk)

        # wrap full image mi
        def score_affine(affine):
            # `fail_count` belongs to the enclosing function; without
            # `nonlocal` the increment below raises UnboundLocalError
            nonlocal fail_count

            # reformat affine as transform and give to irm
            affine = ut.matrix_to_affine_transform(affine)
            irm.SetMovingInitialTransform(affine)

            # get the metric value; failures get the worst possible score
            try:
                return irm.MetricEvaluate(fix_sitk, mov_sitk)
            except Exception:
                fail_count += 1
                return np.finfo(np.float64).max

    # score all random affines; lower scores are better (argsort/argmin)
    current_best_score = np.inf
    scores = np.empty(random_iterations + 1)
    for iii, ppp in enumerate(params):
        aff = params_to_affine_matrix(ppp)
        scores[iii] = score_affine(aff)

        # print running improvements
        if print_running_improvements:
            if scores[iii] < current_best_score:
                current_best_score = scores[iii]
                print(iii,
                      ': ',
                      current_best_score,
                      '\n',
                      aff,
                      '\n',
                      flush=True)

        # check for excessive failure
        if fail_count >= 10 or fail_count >= random_iterations + 1:
            print("Random search failed due to ITK exceptions")
            print("Returning default")
            return np.eye(4)

    # sort, best first
    params = params[np.argsort(scores)]

    # gradient descent based refinements
    if affine_align_best <= 0:
        return params_to_affine_matrix(params[0])

    # never try to refine more candidates than were generated
    n_refine = min(affine_align_best, len(params))
    scores = np.empty(n_refine)
    refined = []
    fail_count = 0  # keep track of failures
    for iii in range(n_refine):

        # gradient descent affine alignment
        aff = params_to_affine_matrix(params[iii])
        aff = affine_align(
            fix,
            mov,
            fix_spacing,
            mov_spacing,
            initial_transform=aff,
            fix_mask=fix_mask,
            mov_mask=mov_mask,
            fix_origin=fix_origin,
            mov_origin=mov_origin,
            alignment_spacing=None,  # already done in this function
            **kwargs,
        )
        refined.append(aff)

        # score the result
        scores[iii] = score_affine(aff)
        if fail_count >= n_refine:
            print("Random search failed due to ITK exceptions")
            print("Returning default")
            return np.eye(4)

    # return the best REFINED matrix (previously the unrefined random
    # candidate was returned, discarding the refinement)
    return refined[int(np.argmin(scores))]
Ejemplo n.º 5
0
def patch_mutual_information(
    fix,
    mov,
    spacing,
    radius,
    stride,
    percentile_cutoff=0,
    fix_mask=None,
    mov_mask=None,
    return_metric_image=False,
    **kwargs,
):
    """
    Evaluate an image similarity metric over many small patches and
    average the per-patch values.

    Parameters
    ----------
    fix : ndarray
        The fixed image (3d; it is transposed with axes (2, 1, 0)).

    mov : ndarray
        The moving image, already resampled onto the grid of `fix`.

    spacing : 1d array
        Voxel spacing, in physical units, shared by `fix` and `mov`.

    radius : float or 1d array
        Patch half-width in physical units; converted to voxels by
        dividing by `spacing` and rounding.

    stride : float or 1d array
        Distance between patch centers in physical units; converted to
        voxels like `radius` (at least 1 voxel).

    percentile_cutoff : float in range [0, 100] (default: 0)
        If > 0, drop roughly the highest `percentile_cutoff` percent of
        patch scores before averaging.

    fix_mask : ndarray (default: None)
        Patch centers are only taken where this mask is nonzero.

    mov_mask : ndarray (default: None)
        Patch centers are only taken where this mask is nonzero.

    return_metric_image : bool (default: False)
        Also return an image where each patch region holds its score.

    **kwargs : any additional arguments
        Passed to `configure_irm` (e.g. the metric choice).

    Returns
    -------
    score : float
        Mean of the retained patch scores. This is NaN if there are no
        sample points. Lower is presumably better — scores are negated
        for the cutoff, and callers minimize.

    metric_image : ndarray, float32 (only if `return_metric_image`)
        Per-patch scores painted into an array of `fix.shape`.
    """

    # create sitk versions of data. The transpose presumably makes sitk
    # index order match numpy index order, so the numpy-derived patch
    # slices below can index the sitk images directly.
    fix_sitk = ut.numpy_to_sitk(fix.transpose(2, 1, 0), spacing[::-1])
    fix_sitk = sitk.Cast(fix_sitk, sitk.sitkFloat32)
    mov_sitk = ut.numpy_to_sitk(mov.transpose(2, 1, 0), spacing[::-1])
    mov_sitk = sitk.Cast(mov_sitk, sitk.sitkFloat32)

    # convert to voxel units. Use SIGNED ints: with the former uint16,
    # `-r` wrapped around (e.g. -3 -> 65533), so the high-end border was
    # never excluded. The stride is kept >= 1 because a zero slice step is
    # invalid. asarray lets a scalar radius/stride work with list spacings.
    spacing_arr = np.asarray(spacing, dtype=float)
    radius = np.round(radius / spacing_arr).astype(np.int64)
    stride = np.maximum(np.round(stride / spacing_arr).astype(np.int64), 1)

    # determine patch sample centers, leaving a `radius`-voxel border on
    # both ends of every axis (explicit upper bound also handles r == 0,
    # where `slice(0, -0)` would have been empty)
    samples = np.zeros_like(fix)
    sample_points = tuple(
        slice(r, sz - r, st) for r, sz, st in zip(radius, fix.shape, stride)
    )
    samples[sample_points] = 1

    # mask sample points
    if fix_mask is not None: samples = samples * fix_mask
    if mov_mask is not None: samples = samples * mov_mask

    # convert to list of coordinates
    samples = np.column_stack(np.nonzero(samples))

    # create irm for evaluating metric
    irm = configure_irm(**kwargs)

    # create container for metric image
    if return_metric_image:
        metric_image = np.zeros(fix.shape, dtype=np.float32)

    # loop over patches and evaluate
    scores = []
    for sample in samples:

        # get patches
        patch = tuple(slice(s - r, s + r + 1) for s, r in zip(sample, radius))
        f = fix_sitk[patch]
        m = mov_sitk[patch]

        # evaluate metric; a failing patch contributes 0
        try:
            scores.append(irm.MetricEvaluate(f, m))
        except Exception:
            scores.append(0)

        # update metric image
        if return_metric_image:
            metric_image[patch] = scores[-1]

    # threshold scores: keep those whose negation exceeds the cutoff
    # percentile, i.e. discard the highest-valued patch scores
    scores = np.array(scores)
    if percentile_cutoff > 0:
        cutoff = np.percentile(-scores, percentile_cutoff)
        scores = scores[-scores > cutoff]

    # return results
    if return_metric_image:
        return np.mean(scores), metric_image
    else:
        return np.mean(scores)