Example 1
    def image_filter_wrapper(*args, **kwargs):
        have_array_input = False
        have_xarray_input = False

        args_list = list(args)
        for index, arg in enumerate(args):
            if _HAVE_XARRAY and isinstance(arg, xr.DataArray):
                have_xarray_input = True
                image = itk.image_from_xarray(arg)
                args_list[index] = image
            elif is_arraylike(arg):
                have_array_input = True
                array = np.asarray(arg)
                image = itk.image_view_from_array(array)
                args_list[index] = image

        potential_image_input_kwargs = ('input', 'input1', 'input2', 'input3')
        for key, value in kwargs.items():
            if (key.lower() in potential_image_input_kwargs
                    or "image" in key.lower()):
                if _HAVE_XARRAY and isinstance(value, xr.DataArray):
                    have_xarray_input = True
                    image = itk.image_from_xarray(value)
                    kwargs[key] = image
                elif is_arraylike(value):
                    have_array_input = True
                    array = np.asarray(value)
                    image = itk.image_view_from_array(array)
                    kwargs[key] = image

        if have_xarray_input or have_array_input:
            # Convert output itk.Image's to numpy.ndarray's
            output = image_filter(*tuple(args_list), **kwargs)
            if isinstance(output, tuple):
                output_list = list(output)
                for index, value in enumerate(output_list):
                    if isinstance(value, itk.Image):
                        if have_xarray_input:
                            data_array = itk.xarray_from_image(value)
                            output_list[index] = data_array
                        else:
                            array = itk.array_from_image(value)
                            output_list[index] = array
                return tuple(output_list)
            else:
                if isinstance(output, itk.Image):
                    if have_xarray_input:
                        output = itk.xarray_from_image(output)
                    else:
                        output = itk.array_from_image(output)
                return output
        else:
            return image_filter(*args, **kwargs)
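
A minimal usage sketch of what this wrapper enables from the caller's side, assuming it decorates ITK's functional filter interface (itk.median_image_filter is used purely as an illustrative decorated filter): a NumPy array passed in comes back as a NumPy array.

# Sketch only: assumes the wrapper above decorates itk.median_image_filter.
import numpy as np
import itk

arr = np.random.rand(64, 64).astype(np.float32)
smoothed = itk.median_image_filter(arr, radius=2)   # ndarray in ...
assert isinstance(smoothed, np.ndarray)             # ... ndarray out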
Example 2
def test_create_atlas():
    import hasi.align

    # Read in images
    images = list()
    for image_file in IMAGE_FILES:
        image_name = str(Path(image_file).stem)
        cid = DATA_INDEX[image_name]
        store = IPFS_FS.get_mapper(f'ipfs://{cid}')
        image_ds = xr.open_zarr(store)
        image_da = image_ds[image_name]
        image = itk.image_from_xarray(image_da)
        images.append(image)

    # Paste into standard space
    images = hasi.align.paste_to_common_space(images)

    # Downsample template image
    TEMPLATE_IDX = 0
    SPARSE_DOWNSAMPLE_RATIO = 14
    template_image = hasi.align.downsample_images(
        [images[TEMPLATE_IDX]], SPARSE_DOWNSAMPLE_RATIO)[0]

    # Downsample dense images
    DENSE_DOWNSAMPLE_RATIO = 2
    images = hasi.align.downsample_images(images, DENSE_DOWNSAMPLE_RATIO)

    # Generate initial template mesh
    FEMUR_OBJECT_PIXEL_VALUE = 1
    template_mesh = hasi.align.binary_image_list_to_meshes(
        [template_image], object_pixel_value=FEMUR_OBJECT_PIXEL_VALUE)[0]

    # Generate meshes
    meshes = hasi.align.binary_image_list_to_meshes(
        images, object_pixel_value=FEMUR_OBJECT_PIXEL_VALUE)

    # Write out sample meshes for later shape analysis
    for idx in range(len(IMAGE_FILES)):
        mesh_file = IMAGE_FILES[idx].replace('.nrrd', '.vtk')
        mesh = meshes[idx]
        itk.meshwrite(mesh, OUTPUT_DIRECTORY + MESH_DIRECTORY + mesh_file)

    # Iteratively refine atlas
    NUM_ITERATIONS = 3
    for iteration in range(NUM_ITERATIONS):
        updated_mesh = hasi.align.refine_template_from_population(
            template_mesh=template_mesh,
            target_meshes=meshes,
            registration_iterations=200)
        distance = hasi.align.get_pairwise_hausdorff_distance(
            updated_mesh, template_mesh)
        template_mesh = updated_mesh

    # Verify final alignment distance
    assert 0.1 < distance < 0.5

    itk.meshwrite(updated_mesh,
                  OUTPUT_DIRECTORY + MESH_DIRECTORY + 'atlas.vtk')
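
The test above relies on module-level fixtures defined elsewhere. A hypothetical sketch of what they could look like is shown below; every name and value here is an illustrative placeholder, not the project's actual configuration (IPFS_FS only needs to be an fsspec-compatible filesystem, e.g. from ipfsspec).

# Hypothetical placeholders for the fixtures the test assumes.
from ipfsspec import IPFSFileSystem  # assumption: fsspec-compatible IPFS backend

IMAGE_FILES = ['femur-901-L.nrrd', 'femur-901-R.nrrd']  # placeholder file names
DATA_INDEX = {
    'femur-901-L': 'Qm...',  # placeholder CIDs; real values come from a data index
    'femur-901-R': 'Qm...',
}
IPFS_FS = IPFSFileSystem()
OUTPUT_DIRECTORY = 'test/Output/'
MESH_DIRECTORY = 'meshes/'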
Example 3
    assert np.array_equal(data_array.values, itk.array_from_image(image))
    assert len(data_array.coords["x"]) == 256
    assert len(data_array.coords["y"]) == 256
    assert len(data_array.coords["c"]) == 3
    assert data_array.coords["x"][0] == 30.0
    assert data_array.coords["x"][1] == 30.1
    assert data_array.coords["y"][0] == 44.0
    assert data_array.coords["y"][1] == 44.2
    assert data_array.coords["c"][0] == 0
    assert data_array.coords["c"][1] == 1
    assert data_array.attrs["direction"][0, 0] == cosine
    assert data_array.attrs["direction"][0, 1] == sine
    assert data_array.attrs["direction"][1, 0] == -sine
    assert data_array.attrs["direction"][1, 1] == cosine

    round_trip = itk.image_from_xarray(data_array)
    assert np.array_equal(itk.array_from_image(round_trip),
                          itk.array_from_image(image))
    spacing = round_trip.GetSpacing()
    assert np.isclose(spacing[0], 0.1)
    assert np.isclose(spacing[1], 0.2)
    origin = round_trip.GetOrigin()
    assert np.isclose(origin[0], 30.0)
    assert np.isclose(origin[1], 44.0)
    direction = round_trip.GetDirection()
    assert np.isclose(direction(0, 0), cosine)
    assert np.isclose(direction(0, 1), -sine)
    assert np.isclose(direction(1, 0), sine)
    assert np.isclose(direction(1, 1), cosine)

    wrong_order = data_array.swap_dims({"y": "z"})
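
The assertions above imply a setup roughly like the following sketch; the actual test constructs image, cosine, and sine before the excerpt shown, so the values here are only a plausible reconstruction.

# Sketch of the implied setup: a 2D, 3-component image with anisotropic
# spacing, a shifted origin, and a rotated direction matrix.
import numpy as np
import itk

theta = np.radians(30.0)
cosine, sine = np.cos(theta), np.sin(theta)

array = np.random.uniform(0.0, 1.0, size=(256, 256, 3)).astype(np.float32)
image = itk.image_from_array(array, is_vector=True)
image.SetSpacing([0.1, 0.2])
image.SetOrigin([30.0, 44.0])
image.SetDirection(itk.matrix_from_array(np.array([[cosine, -sine],
                                                   [sine, cosine]])))

data_array = itk.xarray_from_image(image)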
Example 5
    def image_filter_wrapper(*args, **kwargs):
        have_array_input = False
        have_xarray_input = False
        have_torch_input = False

        args_list = list(args)
        for index, arg in enumerate(args):
            if _HAVE_XARRAY and isinstance(arg, xr.DataArray):
                have_xarray_input = True
                image = itk.image_from_xarray(arg)
                args_list[index] = image
            elif _HAVE_TORCH and isinstance(arg, torch.Tensor):
                have_torch_input = True
                image = itk.image_view_from_array(np.asarray(arg))
                args_list[index] = image
            elif not isinstance(arg, itk.Object) and is_arraylike(arg):
                have_array_input = True
                array = np.asarray(arg)
                image = itk.image_view_from_array(array)
                args_list[index] = image

        potential_image_input_kwargs = ("input", "input1", "input2", "input3")
        for key, value in kwargs.items():
            if (key.lower() in potential_image_input_kwargs
                    or "image" in key.lower()):
                if _HAVE_XARRAY and isinstance(value, xr.DataArray):
                    have_xarray_input = True
                    image = itk.image_from_xarray(value)
                    kwargs[key] = image
                elif _HAVE_TORCH and isinstance(value, torch.Tensor):
                    have_torch_input = True
                    image = itk.image_view_from_array(np.asarray(value))
                    kwargs[key] = image
                elif not isinstance(value, itk.Object) and is_arraylike(value):
                    have_array_input = True
                    array = np.asarray(value)
                    image = itk.image_view_from_array(array)
                    kwargs[key] = image

        if have_xarray_input or have_torch_input or have_array_input:
            # Convert output itk.Image's to numpy.ndarray's
            output = image_filter(*tuple(args_list), **kwargs)
            if isinstance(output, tuple):
                output_list = list(output)
                for index, value in enumerate(output_list):
                    if isinstance(value, itk.Image):
                        if have_xarray_input:
                            data_array = itk.xarray_from_image(value)
                            output_list[index] = data_array
                        elif have_torch_input:
                            data_array = itk.array_view_from_image(value)
                            torch_tensor = torch.from_numpy(data_array)
                            output_list[index] = torch_tensor
                        else:
                            array = itk.array_view_from_image(value)
                            output_list[index] = array
                return tuple(output_list)
            else:
                if isinstance(output, itk.Image):
                    if have_xarray_input:
                        output = itk.xarray_from_image(output)
                    elif have_torch_input:
                        output = itk.array_view_from_image(output)
                        output = torch.from_numpy(output)
                    else:
                        output = itk.array_view_from_image(output)
                return output
        else:
            return image_filter(*args, **kwargs)
Example 6
    def __init__(self,
                 tree=None,
                 swc_morphologies=[],
                 markers=[],
                 marker_sizes=[],
                 marker_opacities=[],
                 marker_colors=[],
                 selected_allen_ids=None,
                 selected_acronyms=None,
                 rotate=False,
                 **kwargs):
        """Create a 3D CCF visualization ipywidget.

        Parameters
        ----------
        tree: None or 'ipytree', optional, default: None
            Structure tree visualization to include.

        swc_morphologies: List, optional, default: []
            List of Allen SWC morphologies to render.

        markers: List of Nx3 arrays, optional, default: []
            Points locations to visualize in the CCF. Each element in the list
            corresponds to a different set of markers. Each marker set has N points,
            with point locations in CCF coordinates:
                [anterior_posterior, dorsal_ventral, left_right]

        marker_sizes: List of integers in [1, 10], optional, default: []
            Size of the markers.

        marker_opacities: array of floats, default: [1.0,]*n
            Opacity for the markers, in the range (0.0, 1.0].

        marker_colors: list of (r, g, b) colors
            Colors for the N markers. See help(matplotlib.colors) for
            specification. Defaults to the Glasbey series of categorical colors.

        selected_allen_ids: List of Allen ids to highlight, optional, default: None
            List of integer Allen Structure Graph ids to highlight. Specify
            selected_allen_ids or selected_acronyms.

        selected_acronyms: List of Allen acronyms to highlight, optional, default: None
            List of string Allen Structure Graph acronyms to highlight. Specify
            selected_allen_ids or selected_acronyms.

        rotate: bool, optional, default: False
            Make the CCF continuously rotate.
        """
        self._image = itk.image_from_xarray(_image_da)
        self._label_image = itk.image_from_xarray(_label_image_da)
        self.swc_point_sets = []
        self.swc_geometries = []
        opacity_gaussians = [[{
            'position': 0.28094135802469133,
            'height': 0.3909090909090909,
            'width': 0.44048611111111113,
            'xBias': 0.21240499194846996,
            'yBias': 0.5416908212560397
        }, {
            'position': 0.2787808641975309,
            'height': 1,
            'width': 0.1,
            'xBias': 0,
            'yBias': 0
        }]]
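        # Note: the assignment below replaces the opacity gaussians defined above.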
        opacity_gaussians = [[{
            'position': 0.32816358024691356,
            'height': 0.5,
            'width': 0.29048611111111106,
            'xBias': 0.20684943639291442,
            'yBias': 1.1235090030742216
        }]]
        camera = np.array([[1.3441567e+03, -2.1723846e+04, 1.7496496e+04],
                           [6.5500000e+03, 3.9750000e+03, 5.6750000e+03],
                           [3.6606243e-01, -4.4908229e-01, -8.1506038e-01]],
                          dtype=np.float32)
        size_limit_3d = [256, 256, 256]
        self.itk_viewer = view(image=self._image,
                               label_image=self._label_image,
                               opacity_gaussians=opacity_gaussians,
                               label_image_blend=0.65,
                               point_sets=markers.copy(),
                               camera=camera,
                               ui_collapsed=True,
                               shadow=False,
                               size_limit_3d=size_limit_3d,
                               background=(0.85, ) * 3,
                               units="μm",
                               gradient_opacity=0.1)
        # Todo: initialization should work
        self.itk_viewer.opacity_gaussians = opacity_gaussians
        self.itk_viewer.rotate = rotate
        self.itk_viewer.label_image_blend = 0.65

        mode_buttons = RadioButtons(options=['x', 'y', 'z', 'v'],
                                    value='v',
                                    description='View mode:')
        link((mode_buttons, 'value'), (self.itk_viewer, 'mode'))

        rotate_checkbox = Checkbox(value=rotate, description='Rotate')
        link((rotate_checkbox, 'value'), (self.itk_viewer, 'rotate'))
        viewer_controls = HBox([mode_buttons, rotate_checkbox])

        viewer = VBox([self.itk_viewer, viewer_controls])

        children = [viewer]

        self._validating_allen_ids = False
        self._validating_acronyms = False
        self._validating_tree = False

        self.tree_widget = None
        if tree is not None:
            if tree == 'ipytree':
                from .ipytree_widget import IPyTreeWidget
                self.tree_widget = IPyTreeWidget(structure_graph)
                children.append(self.tree_widget)
                self.tree_widget.observe(self._ipytree_on_selected_change,
                                         names=['selected_nodes'])

                def open_parent(node):
                    if hasattr(node, 'parent_structure_id') and \
                            node.parent_structure_id in self.tree_widget.allen_id_to_node:
                        parent = self.tree_widget.allen_id_to_node[
                            node.parent_structure_id]
                        open_parent(parent)
                    else:
                        node.opened = True

                self.last_selected_tree_nodes = []

                def ipytree_on_allen_ids_changed(change):
                    self._validating_tree = True
                    with self.tree_widget.hold_sync():
                        for node in self.last_selected_tree_nodes:
                            node.selected = False
                        tree_nodes = [
                            self.tree_widget.allen_id_to_node[allen_id]
                            for allen_id in change.new
                        ]
                        for node in tree_nodes:
                            open_parent(node)
                            node.selected = True
                            node.opened = True
                    self.last_selected_tree_nodes = tree_nodes
                    self._validating_tree = False

                # self.observe(ipytree_on_allen_ids_changed, names='selected_allen_ids')
            else:
                raise RuntimeError('Invalid tree type')

        self.labels = np.unique(self.itk_viewer.rendered_label_image)

        super(CCFWidget, self).__init__(children, **kwargs)

        for morphology in swc_morphologies:
            soma_point_set, geometry = swc_morphology_geometry(morphology)
            self.swc_point_sets.append(soma_point_set)
            self.swc_geometries.append(geometry)
        if swc_morphologies:
            self.itk_viewer.point_sets = self.swc_point_sets
            self.itk_viewer.point_set_sizes = [
                3,
            ] * len(self.swc_point_sets)
            self.itk_viewer.point_set_opacities = [
                1.0,
            ] * len(self.swc_point_sets)
            self.itk_viewer.point_set_colors = [
                (1.0, 0.0, 0.0),
            ] * len(self.swc_point_sets)
            self.itk_viewer.geometries = self.swc_geometries

        if selected_acronyms:
            self.selected_acronyms = selected_acronyms
        if selected_allen_ids:
            self.selected_allen_ids = selected_allen_ids

        self.markers = markers
        if marker_sizes:
            self.marker_sizes = marker_sizes
        if marker_opacities:
            self.marker_opacities = marker_opacities
        if marker_colors:
            self.marker_colors = marker_colors
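
A hypothetical usage sketch for this widget in a Jupyter notebook, following the constructor's docstring; the point coordinates below are illustrative only.

# Sketch only: marker coordinates are made-up CCF positions in micrometers.
import numpy as np

points = np.array([[8000.0, 4000.0, 5700.0],
                   [9000.0, 3500.0, 6200.0]])  # [anterior_posterior, dorsal_ventral, left_right]

widget = CCFWidget(markers=[points],
                   marker_sizes=[4],
                   marker_opacities=[0.8],
                   tree='ipytree',
                   rotate=True)
widget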
Example 7
    def image_filter_wrapper(*args, **kwargs):
        have_array_input = False
        have_xarray_input = False
        have_torch_input = False

        args_list = list(args)
        for index, arg in enumerate(args):
            if _HAVE_XARRAY and isinstance(arg, xr.DataArray):
                have_xarray_input = True
                image = itk.image_from_xarray(arg)
                args_list[index] = image
            elif _HAVE_TORCH and isinstance(arg, torch.Tensor):
                have_torch_input = True
                channels = arg.shape[0]  # assume first dimension is channels
                arr = np.asarray(arg)
                if channels > 1:  # change from contiguous to interleaved channel order
                    arr = move_last_dimension_to_first(arr)
                image = itk.image_view_from_array(arr, is_vector=channels > 1)
                args_list[index] = image
            elif not isinstance(arg, itk.Object) and is_arraylike(arg):
                have_array_input = True
                array = np.asarray(arg)
                image = itk.image_view_from_array(array)
                args_list[index] = image

        potential_image_input_kwargs = ("input", "input1", "input2", "input3")
        for key, value in kwargs.items():
            if (key.lower() in potential_image_input_kwargs
                    or "image" in key.lower()):
                if _HAVE_XARRAY and isinstance(value, xr.DataArray):
                    have_xarray_input = True
                    image = itk.image_from_xarray(value)
                    kwargs[key] = image
                elif _HAVE_TORCH and isinstance(value, torch.Tensor):
                    have_torch_input = True
                    channels = value.shape[0]  # assume first dimension is channels
                    arr = np.asarray(value)
                    if channels > 1:  # change from contiguous to interleaved channel order
                        arr = move_last_dimension_to_first(arr)
                    image = itk.image_view_from_array(arr, is_vector=channels > 1)
                    kwargs[key] = image
                elif not isinstance(value, itk.Object) and is_arraylike(value):
                    have_array_input = True
                    array = np.asarray(value)
                    image = itk.image_view_from_array(array)
                    kwargs[key] = image

        if have_xarray_input or have_torch_input or have_array_input:
            # Convert output itk.Image's to numpy.ndarray's
            output = image_filter(*tuple(args_list), **kwargs)
            if isinstance(output, tuple):
                output_list = list(output)
                for index, value in enumerate(output_list):
                    if isinstance(value, itk.Image):
                        if have_xarray_input:
                            data_array = itk.xarray_from_image(value)
                            output_list[index] = data_array
                        elif have_torch_input:
                            channels = value.GetNumberOfComponentsPerPixel()
                            data_array = itk.array_view_from_image(value)
                            if channels > 1:  # change from interleaved to contiguous channel order
                                data_array = move_first_dimension_to_last(data_array)
                            torch_tensor = torch.from_numpy(data_array)
                            output_list[index] = torch_tensor
                        else:
                            array = itk.array_view_from_image(value)
                            output_list[index] = array
                return tuple(output_list)
            else:
                if isinstance(output, itk.Image):
                    if have_xarray_input:
                        output = itk.xarray_from_image(output)
                    elif have_torch_input:
                        channels = output.GetNumberOfComponentsPerPixel()
                        output = itk.array_view_from_image(output)
                        if channels > 1:  # change from interleaved to contiguous channel order
                            output = move_first_dimension_to_last(output)
                        output = torch.from_numpy(output)
                    else:
                        output = itk.array_view_from_image(output)
                return output
        else:
            return image_filter(*args, **kwargs)
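
A rough usage sketch for the tensor path, assuming this wrapper decorates ITK's functional filters (itk.median_image_filter is again only an illustrative decorated filter). Because the wrapper treats the first tensor dimension as the channel axis, a single-channel image is passed with an explicit leading 1.

# Sketch only: assumes the wrapper above decorates itk.median_image_filter.
import torch
import itk

tensor = torch.rand(1, 64, 64)              # [channels, rows, cols]
filtered = itk.median_image_filter(tensor, radius=1)
assert isinstance(filtered, torch.Tensor)   # converted back to a torch.Tensor
assert filtered.shape == tensor.shape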