class LineProfiler(Viewer):
    """LineProfiler widget class.

    A Viewer that additionally syncs the two physical-space end points
    (``point1`` and ``point2``) of a line profile with the front-end.
    """
    _view_name = Unicode('LineProfilerView').tag(sync=True)
    _model_name = Unicode('LineProfilerModel').tag(sync=True)
    _view_module = Unicode('itk-jupyter-widgets').tag(sync=True)
    _model_module = Unicode('itk-jupyter-widgets').tag(sync=True)
    _view_module_version = Unicode('^0.17.0').tag(sync=True)
    _model_module_version = Unicode('^0.17.0').tag(sync=True)
    point1 = NDArray(dtype=np.float64, default_value=np.zeros((3,), dtype=np.float64),
                     help="First point in physical space that defines the line profile")\
        .tag(sync=True, **array_serialization)\
        .valid(shape_constraints(3,))
    point2 = NDArray(dtype=np.float64, default_value=np.ones((3,), dtype=np.float64),
                     help="Second point in physical space that defines the line profile")\
        .tag(sync=True, **array_serialization)\
        .valid(shape_constraints(3,))
    _select_initial_points = CBool(
        default_value=False,
        help="We will select the initial points for the line profile.").tag(
            sync=True)

    def __init__(self, **kwargs):
        """Create the profiler; keyword arguments are forwarded to Viewer.

        If either end point is missing, the points are flagged for
        interactive selection, and the view defaults to the z plane instead
        of the 3D volume (unless ``mode`` was given). The built-in UI is
        collapsed unless ``ui_collapsed`` is passed explicitly.
        """
        if 'point1' not in kwargs or 'point2' not in kwargs:
            self._select_initial_points = True
            # Selecting points is done on a slicing plane, not in the volume.
            kwargs.setdefault('mode', 'z')
        kwargs.setdefault('ui_collapsed', True)
        super(LineProfiler, self).__init__(**kwargs)
Ejemplo n.º 2
0
class Viewer3d(DOMWidget):
    """DOM widget backed by the ``Viewer3dModel`` / ``Viewer3dView`` front-end.

    The only data trait is ``image``, a float64 array that is synced to the
    front-end through ``array_serialization`` (empty by default).
    """

    # Front-end model registration.
    _model_module = Unicode(MODULE_NAME).tag(sync=True)
    _model_module_version = Unicode(EXTENSION_SPEC_VERSION).tag(sync=True)
    _model_name = Unicode('Viewer3dModel').tag(sync=True)

    # Front-end view registration.
    _view_module = Unicode(MODULE_NAME).tag(sync=True)
    _view_module_version = Unicode(EXTENSION_SPEC_VERSION).tag(sync=True)
    _view_name = Unicode('Viewer3dView').tag(sync=True)

    image = NDArray(
        dtype=np.float64,
        default_value=np.zeros(0, dtype=np.float64),
    ).tag(sync=True, **array_serialization)

    def __init__(self, *args, **kwargs):
        # Positional arguments are accepted for call compatibility but are
        # deliberately not forwarded; only keyword arguments reach DOMWidget.
        super(Viewer3d, self).__init__(**kwargs)
Ejemplo n.º 3
0
class Viewer(ViewerParent):
    """Viewer widget class.

    Syncs an ITK ``image`` with the ``ViewerView`` front-end. When the image
    is larger than ``size_limit_2d`` / ``size_limit_3d`` the widget sends a
    bin-shrunk copy of the current ``roi`` as ``rendered_image`` instead of
    the full image. Point sets and geometries may be displayed alongside it;
    their per-item colors, opacities and (for point sets) representations
    are padded or truncated to the number of items.
    """
    _view_name = Unicode('ViewerView').tag(sync=True)
    _model_name = Unicode('ViewerModel').tag(sync=True)
    _view_module = Unicode('itkwidgets').tag(sync=True)
    _model_module = Unicode('itkwidgets').tag(sync=True)
    _view_module_version = Unicode('^0.26.1').tag(sync=True)
    _model_module_version = Unicode('^0.26.1').tag(sync=True)
    image = ITKImage(
        default_value=None,
        allow_none=True,
        help="Image to visualize.").tag(
        sync=False,
        **itkimage_serialization)
    rendered_image = ITKImage(
        default_value=None,
        allow_none=True).tag(
        sync=True,
        **itkimage_serialization)
    _rendering_image = CBool(
        default_value=False,
        help="We are currently volume rendering the image.").tag(sync=True)
    interpolation = CBool(
        default_value=True,
        help="Use linear interpolation in slicing planes.").tag(sync=True)
    cmap = Colormap('Viridis (matplotlib)').tag(sync=True)
    _custom_cmap = NDArray(dtype=np.float32, default_value=None, allow_none=True,
                           help="RGB triples from 0.0 to 1.0 that define a custom linear, sequential colormap")\
        .tag(sync=True, **array_serialization)\
        .valid(shape_constraints(None, 3))
    shadow = CBool(
        default_value=True,
        help="Use shadowing in the volume rendering.").tag(sync=True)
    slicing_planes = CBool(
        default_value=False,
        help="Display the slicing planes in volume rendering view mode.").tag(
        sync=True)
    x_slice = CFloat(
        default_value=None,
        allow_none=True,
        help="World-space position of the X slicing plane.").tag(sync=True)
    y_slice = CFloat(
        default_value=None,
        allow_none=True,
        help="World-space position of the Y slicing plane.").tag(sync=True)
    z_slice = CFloat(
        default_value=None,
        allow_none=True,
        help="World-space position of the Z slicing plane.").tag(sync=True)
    gradient_opacity = CFloat(
        default_value=0.2,
        help="Volume rendering gradient opacity, from (0.0, 1.0]").tag(sync=True)
    blend = CaselessStrEnum(
        ('composite',
         'max',
         'min',
         'average'),
        default_value='composite',
        help="Volume rendering blend mode").tag(sync=True)
    roi = NDArray(dtype=np.float64, default_value=np.zeros((2, 3), dtype=np.float64),
                  help="Region of interest: [[lower_x, lower_y, lower_z), (upper_x, upper_y, upper_z]]")\
        .tag(sync=True, **array_serialization)\
        .valid(shape_constraints(2, 3))
    vmin = CFloat(
        default_value=None,
        allow_none=True,
        help="Value that maps to the minimum of image colormap.").tag(
        sync=True)
    vmax = CFloat(
        default_value=None,
        allow_none=True,
        help="Value that maps to the maximum of image colormap.").tag(
        sync=True)
    _largest_roi = NDArray(dtype=np.float64, default_value=np.zeros((2, 3), dtype=np.float64),
                           help="Largest possible region of interest: "
                           "[[lower_x, lower_y, lower_z), (upper_x, upper_y, upper_z]]")\
        .tag(sync=True, **array_serialization)\
        .valid(shape_constraints(2, 3))
    select_roi = CBool(
        default_value=False,
        help="Enable an interactive region of interest widget for the image.").tag(
        sync=True)
    size_limit_2d = NDArray(dtype=np.int64, default_value=np.array([1024, 1024], dtype=np.int64),
                            help="Size limit for 2D image visualization.").tag(sync=False)
    size_limit_3d = NDArray(dtype=np.int64, default_value=np.array([192, 192, 192], dtype=np.int64),
                            help="Size limit for 3D image visualization.").tag(sync=False)
    _scale_factors = NDArray(dtype=np.uint8, default_value=np.array([1, 1, 1], dtype=np.uint8),
                             help="Image downscaling factors.").tag(sync=True, **array_serialization)
    _downsampling = CBool(default_value=False,
                          help="We are downsampling the image to meet the size limits.").tag(sync=True)
    _reset_crop_requested = CBool(default_value=False,
                                  help="The user requested a reset of the roi.").tag(sync=True)
    units = Unicode(
        '',
        help="Units to display in the scale bar.").tag(
        sync=True)
    # Declared once; an identical duplicate declaration further down the
    # class body was removed.
    point_set_representations = List(
        trait=Unicode(),
        default_value=[],
        help="Point set representation").tag(
        sync=True)
    point_sets = PointSetList(
        default_value=None,
        allow_none=True,
        help="Point sets to visualize").tag(
        sync=True,
        **polydata_list_serialization)
    point_set_colors = NDArray(dtype=np.float32, default_value=np.zeros((0, 3), dtype=np.float32),
                               help="RGB colors for the points sets")\
        .tag(sync=True, **array_serialization)\
        .valid(shape_constraints(None, 3))
    point_set_opacities = NDArray(dtype=np.float32, default_value=np.zeros((0,), dtype=np.float32),
                                  help="Opacities for the points sets")\
        .tag(sync=True, **array_serialization)\
        .valid(shape_constraints(None,))
    geometries = PolyDataList(
        default_value=None,
        allow_none=True,
        help="Geometries to visualize").tag(
        sync=True,
        **polydata_list_serialization)
    geometry_colors = NDArray(dtype=np.float32, default_value=np.zeros((0, 3), dtype=np.float32),
                              help="RGB colors for the geometries")\
        .tag(sync=True, **array_serialization)\
        .valid(shape_constraints(None, 3))
    geometry_opacities = NDArray(dtype=np.float32, default_value=np.zeros((0,), dtype=np.float32),
                                 help="Opacities for the geometries")\
        .tag(sync=True, **array_serialization)\
        .valid(shape_constraints(None,))
    ui_collapsed = CBool(
        default_value=False,
        help="Collapse the built in user interface.").tag(
        sync=True)
    rotate = CBool(
        default_value=False,
        help="Rotate the camera around the scene.").tag(
        sync=True)
    annotations = CBool(
        default_value=True,
        help="Show annotations.").tag(
        sync=True)
    mode = CaselessStrEnum(
        ('x',
         'y',
         'z',
         'v'),
        default_value='v',
        help="View mode: x: x plane, y: y plane, z: z plane, v: volume rendering").tag(
        sync=True)
    camera = NDArray(dtype=np.float32, default_value=np.zeros((3, 3), dtype=np.float32),
                     help="Camera parameters: [[position_x, position_y, position_z], "
                     "[focal_point_x, focal_point_y, focal_point_z], "
                     "[view_up_x, view_up_y, view_up_z]]")\
        .tag(sync=True, **array_serialization)\
        .valid(shape_constraints(3, 3))

    def __init__(self, **kwargs):  # noqa: C901
        """Create the viewer; keyword arguments initialize the traits.

        Per-item display settings passed as keyword arguments are normalized
        with the same logic as their validators before the traits are set.
        If an image is given, the default roi, downsampling pipeline and
        initial ``rendered_image`` are set up as well.
        """
        normalizers = (
            ('point_set_colors', self._validate_point_set_colors),
            ('point_set_opacities', self._validate_point_set_opacities),
            ('point_set_representations',
             self._validate_point_set_representations),
            ('geometry_colors', self._validate_geometry_colors),
            ('geometry_opacities', self._validate_geometry_opacities),
        )
        for name, normalize in normalizers:
            if name in kwargs:
                kwargs[name] = normalize({'value': kwargs[name]})
        # Keep per-item settings the same length as the item lists.
        self.observe(self._on_point_sets_changed, ['point_sets'])
        self.observe(self._on_geometries_changed, ['geometries'])

        super(Viewer, self).__init__(**kwargs)

        # NOTE(review): the setup below, including the observers that
        # re-render on roi/image changes, only happens when an image is
        # passed at construction. An image assigned later is not re-rendered
        # by this class -- confirm this is intended.
        if self.image is None:
            return
        dimension = self.image.GetImageDimension()
        size = self.image.GetLargestPossibleRegion().GetSize()

        # Cache this so we do not need to recompute on it when resetting the
        # roi
        self._largest_roi_rendered_image = None
        self._largest_roi = np.zeros((2, 3), dtype=np.float64)
        if not np.any(self.roi):
            # No roi supplied: default it to the whole image. The trait array
            # is modified in place, i.e. not through the trait setter.
            lower, upper = self._largest_region_bounds()
            self.roi[0][:dimension] = lower
            self.roi[1][:dimension] = upper
            self._largest_roi = self.roi.copy()

        size_limit = self.size_limit_2d if dimension == 2 else self.size_limit_3d
        if any(size[dim] > size_limit[dim] for dim in range(dimension)):
            self._downsampling = True
        if self._downsampling:
            # Crop to the roi, then bin-shrink the crop to fit the size limit.
            self.extractor = itk.ExtractImageFilter.New(self.image)
            self.shrinker = itk.BinShrinkImageFilter.New(self.extractor)
        self._update_rendered_image()
        if self._downsampling:
            self.observe(self._on_roi_changed, ['roi'])

        self.observe(self._on_reset_crop_requested, ['_reset_crop_requested'])
        self.observe(self.update_rendered_image, ['image'])

    def _largest_region_bounds(self):
        """Return physical (lower, upper) corners of the image's largest region.

        ``lower`` is the physical point of the region's start index and
        ``upper`` that of ``index + size`` (one past the last pixel).
        """
        largest_region = self.image.GetLargestPossibleRegion()
        lower_index = largest_region.GetIndex()
        upper_index = lower_index + largest_region.GetSize()
        lower = np.array(self.image.TransformIndexToPhysicalPoint(lower_index))
        upper = np.array(self.image.TransformIndexToPhysicalPoint(upper_index))
        return lower, upper

    def _on_roi_changed(self, change=None):
        """Re-render the (downsampled) image for the new roi."""
        if self._downsampling:
            self._update_rendered_image()

    def _on_reset_crop_requested(self, change=None):
        """Reset the roi to the full image and clear the request flag."""
        if change.new is True and self._downsampling:
            dimension = self.image.GetImageDimension()
            lower, upper = self._largest_region_bounds()
            new_roi = self.roi.copy()
            new_roi[0][:dimension] = lower
            new_roi[1][:dimension] = upper
            self._largest_roi = new_roi.copy()
            self.roi = new_roi
        if change.new is True:
            self._reset_crop_requested = False

    @debounced(delay_seconds=0.2, method=True)
    def update_rendered_image(self, change=None):
        """Drop the cached full-roi render and render the current image."""
        self._largest_roi_rendered_image = None
        self._largest_roi = np.zeros((2, 3), dtype=np.float64)
        self._update_rendered_image()

    @staticmethod
    def _find_scale_factors(limit, dimension, size):
        """Return the smallest integer shrink factor per axis that fits ``limit``.

        For each of the first ``dimension`` axes, the factor ``s`` is the
        smallest integer >= 1 with ``floor(size / s) <= limit``, which is
        ``size // (limit + 1) + 1``. Remaining axes (of 3) keep factor 1.
        """
        scale_factors = [1, ] * 3
        for dim in range(dimension):
            scale_factors[dim] = max(
                1, int(size[dim]) // (int(limit[dim]) + 1) + 1)
        return scale_factors

    def _update_rendered_image(self):
        """Set ``rendered_image`` from ``image``, shrinking the roi if needed."""
        if self.image is None:
            return
        if self._rendering_image:
            # Presumably waits (via yield_for_change) for _rendering_image to
            # change; the generator expects the new value to be False.
            @yield_for_change(self, '_rendering_image')
            def f():
                x = yield
                assert(x is False)
            f()
        self._rendering_image = True

        if not self._downsampling:
            self.rendered_image = self.image
            return

        dimension = self.image.GetImageDimension()
        index = self.image.TransformPhysicalPointToIndex(
            self.roi[0][:dimension])
        upper_index = self.image.TransformPhysicalPointToIndex(
            self.roi[1][:dimension])
        size = upper_index - index

        size_limit = self.size_limit_2d if dimension == 2 else self.size_limit_3d
        scale_factors = self._find_scale_factors(size_limit, dimension, size)
        self._scale_factors = np.array(scale_factors, dtype=np.uint8)
        self.shrinker.SetShrinkFactors(scale_factors[:dimension])

        region = itk.ImageRegion[dimension]()
        region.SetIndex(index)
        region.SetSize(tuple(size))
        # Pad by one pixel to absorb rounding/truncation in the physical to
        # index conversion, then clamp to the image.
        region.PadByRadius(1)
        region.Crop(self.image.GetLargestPossibleRegion())

        self.extractor.SetInput(self.image)
        self.extractor.SetExtractionRegion(region)

        is_largest = bool(np.any(self._largest_roi) and
                          np.all(self._largest_roi == self.roi))
        if is_largest and self._largest_roi_rendered_image is not None:
            # Reuse the cached render of the full-extent roi.
            self.rendered_image = self._largest_roi_rendered_image
            return

        self.shrinker.UpdateLargestPossibleRegion()
        shrunk = self.shrinker.GetOutput()
        # Detach from the filter so a later shrinker update is not expected
        # to modify this image (ITK DisconnectPipeline).
        shrunk.DisconnectPipeline()
        shrunk.SetOrigin(self.roi[0][:dimension])
        if is_largest:
            self._largest_roi_rendered_image = shrunk
        self.rendered_image = shrunk

    @validate('gradient_opacity')
    def _validate_gradient_opacity(self, proposal):
        """Enforce 0 < value <= 1.0."""
        value = proposal['value']
        if value <= 0.0:
            return 0.01
        if value > 1.0:
            return 1.0
        return value

    @staticmethod
    def _normalize_colors(value, items):
        """Return an (n, 3) float32 RGB array for ``value``.

        ``n`` is ``len(items)`` when ``items`` is non-empty, else
        ``len(value)``. Extra colors are dropped; missing colors are filled
        from the colorcet glasbey palette.
        """
        n_colors = len(items) if items else len(value)
        result = np.zeros((n_colors, 3), dtype=np.float32)
        for index in range(n_colors):
            if index < len(value):
                color = value[index]
            else:
                color = colorcet.glasbey[index % len(colorcet.glasbey)]
            result[index, :] = matplotlib.colors.to_rgb(color)
        return result

    @staticmethod
    def _normalize_opacities(value, items):
        """Return a float32 opacity vector for ``value`` (scalar or sequence).

        The length is ``len(items)`` when ``items`` is non-empty, else the
        number of given values. Extra values are dropped; missing ones are 1.
        """
        values = np.atleast_1d(np.asarray(value, dtype=np.float32))
        n_values = len(values)
        n_opacities = len(items) if items else n_values
        result = np.ones((n_opacities,), dtype=np.float32)
        n_kept = min(n_values, n_opacities)
        result[:n_kept] = values[:n_kept]
        return result

    @validate('point_set_colors')
    def _validate_point_set_colors(self, proposal):
        return self._normalize_colors(proposal['value'], self.point_sets)

    @validate('point_set_opacities')
    def _validate_point_set_opacities(self, proposal):
        return self._normalize_opacities(proposal['value'], self.point_sets)

    @validate('point_set_representations')
    def _validate_point_set_representations(self, proposal):
        """Return one representation per point set, defaulting to 'points'."""
        value = proposal['value']
        if isinstance(value, str):
            # A lone string is one representation, not a list of characters.
            value = [value]
        n_values = len(value)
        n_representations = len(self.point_sets) if self.point_sets else n_values
        result = ['points'] * n_representations
        n_kept = min(n_values, n_representations)
        result[:n_kept] = list(value[:n_kept])
        return result

    def _on_point_sets_changed(self, change=None):
        """Resize per-point-set settings to the number of point sets."""
        n_point_sets = len(self.point_sets) if self.point_sets else 0
        # Truncate here; the validators pad back up to len(point_sets).
        self.point_set_colors = self.point_set_colors[:n_point_sets]
        self.point_set_opacities = self.point_set_opacities[:n_point_sets]
        self.point_set_representations = \
            self.point_set_representations[:n_point_sets]

    @validate('geometry_colors')
    def _validate_geometry_colors(self, proposal):
        return self._normalize_colors(proposal['value'], self.geometries)

    @validate('geometry_opacities')
    def _validate_geometry_opacities(self, proposal):
        return self._normalize_opacities(proposal['value'], self.geometries)

    def _on_geometries_changed(self, change=None):
        """Resize per-geometry settings to the number of geometries."""
        n_geometries = len(self.geometries) if self.geometries else 0
        # Truncate here; the validators pad back up to len(geometries).
        self.geometry_colors = self.geometry_colors[:n_geometries]
        self.geometry_opacities = self.geometry_opacities[:n_geometries]

    def roi_region(self):
        """Return the itk.ImageRegion corresponding to the roi.

        The region spans the indices of ``roi[0]`` through ``roi[1]``
        inclusive, cropped to the image's largest possible region.
        """
        dimension = self.image.GetImageDimension()
        index = self.image.TransformPhysicalPointToIndex(
            tuple(self.roi[0][:dimension]))
        upper_index = self.image.TransformPhysicalPointToIndex(
            tuple(self.roi[1][:dimension]))
        size = upper_index - index
        for dim in range(dimension):
            size[dim] += 1
        region = itk.ImageRegion[dimension]()
        region.SetIndex(index)
        region.SetSize(tuple(size))
        region.Crop(self.image.GetLargestPossibleRegion())
        return region

    def roi_slice(self):
        """Return the numpy array slice corresponding to the roi.

        Slices are in reverse index order (numpy axis order) and cover exactly
        the pixels of ``roi_region()``.
        """
        dimension = self.image.GetImageDimension()
        region = self.roi_region()
        index = np.array(region.GetIndex())
        # index + size is already one past the last pixel (exclusive end).
        upper_index = index + np.array(region.GetSize())
        return tuple(slice(int(index[dim]), int(upper_index[dim]))
                     for dim in reversed(range(dimension)))
Ejemplo n.º 4
0
class LineProfiler(Viewer):
    """LineProfiler widget class.

    A Viewer that syncs two physical-space end points (``point1`` and
    ``point2``) with the front-end and can compute the intensity profile
    along the segment between them (``get_profile``).
    """
    _view_name = Unicode('LineProfilerView').tag(sync=True)
    _model_name = Unicode('LineProfilerModel').tag(sync=True)
    _view_module = Unicode('itkwidgets').tag(sync=True)
    _model_module = Unicode('itkwidgets').tag(sync=True)
    _view_module_version = Unicode('^0.32.0').tag(sync=True)
    _model_module_version = Unicode('^0.32.0').tag(sync=True)
    point1 = NDArray(dtype=np.float64, default_value=np.zeros((3,), dtype=np.float64),
                     help="First point in physical space that defines the line profile")\
        .tag(sync=True, **array_serialization)\
        .valid(shape_constraints(3,))
    point2 = NDArray(dtype=np.float64, default_value=np.ones((3,), dtype=np.float64),
                     help="Second point in physical space that defines the line profile")\
        .tag(sync=True, **array_serialization)\
        .valid(shape_constraints(3,))
    _select_initial_points = CBool(
        default_value=False,
        help="We will select the initial points for the line profile.").tag(
            sync=True)

    def __init__(self, image, order, **kwargs):
        """Create the profiler for ``image``.

        Parameters
        ----------
        image
            Stored on the ``image`` trait before Viewer initialization.
        order : int
            Default spline order used by ``get_profile``.
        **kwargs
            Forwarded to Viewer. ``mode`` defaults to 'z' when an end point is
            missing (points are then selected interactively), and
            ``ui_collapsed`` defaults to True.
        """
        self.image = image
        self.order = order
        if 'point1' not in kwargs or 'point2' not in kwargs:
            self._select_initial_points = True
            # Selecting points is done on a slicing plane, not in the volume.
            kwargs.setdefault('mode', 'z')
        kwargs.setdefault('ui_collapsed', True)
        super(LineProfiler, self).__init__(**kwargs)

    def get_profile(self,
                    image_or_array=None,
                    point1=None,
                    point2=None,
                    order=None):
        """Calculate the line profile.

        Calculate the pixel intensity values along the line that connects
        the given two points.

        The image can be 2D or 3D. Any parameter that is None defaults to the
        widget's corresponding value (``image``, ``point1``, ``point2``,
        ``order``).

        Parameters
        ----------
        image_or_array : array_like, itk.Image, or vtk.vtkImageData
            The 2D or 3D image to visualize.

        point1 : list of float
            List elements represent the 2D/3D coordinate of the point1.

        point2 : list of float
            List elements represent the 2D/3D coordinate of the point2.

        order : int, optional
            Spline order for line profile interpolation. The order has to be in the
            range 0-5.

        Returns
        -------
        distances : numpy.ndarray
            Evenly spaced physical distances from point1, from 0.0 to the
            point1-point2 distance.
        values : numpy.ndarray
            Interpolated image values at the corresponding samples.

        Raises
        ------
        RuntimeError
            If scipy or itk is not available.
        """

        if not have_scipy:
            raise RuntimeError(
                'The scipy package is necessary for the line_profiler widget.')
        if not have_itk:
            raise RuntimeError(
                'The itk package is necessary for the line_profiler widget.')

        if image_or_array is None:
            image_or_array = self.image
        if point1 is None:
            point1 = self.point1
        if point2 is None:
            point2 = self.point2
        if order is None:
            order = self.order
        image = to_itk_image(image_or_array)
        image_array = itk.array_view_from_image(image)
        dimension = image.GetImageDimension()
        # Physical length of the segment (only the image's dimensions count).
        distance = np.sqrt(
            sum((point1[ii] - point2[ii])**2 for ii in range(dimension)))
        index1 = tuple(
            image.TransformPhysicalPointToIndex(tuple(point1[:dimension])))
        index2 = tuple(
            image.TransformPhysicalPointToIndex(tuple(point2[:dimension])))
        # Sample about 2.1 points per unit of index-space length between the
        # two end-point indices.
        num_points = int(
            np.round(
                np.sqrt(
                    sum((index1[ii] - index2[ii])**2
                        for ii in range(dimension))) * 2.1))
        coords = [
            np.linspace(index1[ii], index2[ii], num_points)
            for ii in range(dimension)
        ]
        # The coordinate rows are reversed to match the numpy view's axis
        # order, which is the reverse of the itk index component order.
        mapped = scipy.ndimage.map_coordinates(image_array,
                                               np.vstack(coords[::-1]),
                                               order=order,
                                               mode='nearest')
        return np.linspace(0.0, distance, num_points), mapped
class Viewer(ViewerParent):
    """Viewer widget class."""
    _view_name = Unicode('ViewerView').tag(sync=True)
    _model_name = Unicode('ViewerModel').tag(sync=True)
    _view_module = Unicode('itk-jupyter-widgets').tag(sync=True)
    _model_module = Unicode('itk-jupyter-widgets').tag(sync=True)
    _view_module_version = Unicode('^0.15.2').tag(sync=True)
    _model_module_version = Unicode('^0.15.2').tag(sync=True)
    image = ITKImage(default_value=None,
                     allow_none=True,
                     help="Image to visualize.").tag(sync=False,
                                                     **itkimage_serialization)
    rendered_image = ITKImage(default_value=None,
                              allow_none=True).tag(sync=True,
                                                   **itkimage_serialization)
    _rendering_image = CBool(
        default_value=False,
        help="We are currently volume rendering the image.").tag(sync=True)
    ui_collapsed = CBool(
        default_value=False,
        help="Collapse the built in user interface.").tag(sync=True)
    annotations = CBool(default_value=True,
                        help="Show annotations.").tag(sync=True)
    mode = CaselessStrEnum(
        ('x', 'y', 'z', 'v'),
        default_value='v',
        help="View mode: x: x plane, y: y plane, z: z plane, v: volume rendering"
    ).tag(sync=True)
    interpolation = CBool(
        default_value=True,
        help="Use linear interpolation in slicing planes.").tag(sync=True)
    cmap = Unicode('Viridis (matplotlib)').tag(sync=True)
    shadow = CBool(
        default_value=True,
        help="Use shadowing in the volume rendering.").tag(sync=True)
    slicing_planes = CBool(
        default_value=False,
        help="Display the slicing planes in volume rendering view mode.").tag(
            sync=True)
    gradient_opacity = CFloat(
        default_value=0.2,
        help="Volume rendering gradient opacity, from (0.0, 1.0]").tag(
            sync=True)
    roi = NDArray(dtype=np.float64, default_value=np.zeros((2, 3), dtype=np.float64),
                help="Region of interest: ((lower_x, lower_y, lower_z), (upper_x, upper_y, upper_z))")\
            .tag(sync=True, **array_serialization)\
            .valid(shape_constraints(2, 3))
    _largest_roi = NDArray(dtype=np.float64, default_value=np.zeros((2, 3), dtype=np.float64),
                help="Largest possible region of interest: ((lower_x, lower_y, lower_z), (upper_x, upper_y, upper_z))")\
            .tag(sync=True, **array_serialization)\
            .valid(shape_constraints(2, 3))
    select_roi = CBool(
        default_value=False,
        help="Enable an interactive region of interest widget for the image."
    ).tag(sync=True)
    size_limit_2d = NDArray(
        dtype=np.int64,
        default_value=np.array([1024, 1024], dtype=np.int64),
        help="Size limit for 2D image visualization.").tag(sync=False)
    size_limit_3d = NDArray(
        dtype=np.int64,
        default_value=np.array([192, 192, 192], dtype=np.int64),
        help="Size limit for 3D image visualization.").tag(sync=False)
    _downsampling = CBool(
        default_value=False,
        help="We are downsampling the image to meet the size limits.").tag(
            sync=True)
    _reset_crop_requested = CBool(
        default_value=False,
        help="The user requested a reset of the roi.").tag(sync=True)

    def __init__(self, **kwargs):
        """Create the viewer and prepare the rendered (possibly downsampled) image.

        If no roi was supplied (all zeros), it is initialized to the full
        physical extent of the image. When the image is larger than
        ``size_limit_2d`` / ``size_limit_3d`` in any dimension, an
        extract + bin-shrink pipeline is built and re-run on roi changes.
        When ``image`` is None, only the observers are registered.
        """
        super(Viewer, self).__init__(**kwargs)

        # Cache this so we do not need to recompute on it when resetting the roi
        self._largest_roi_rendered_image = None
        self._largest_roi = np.zeros((2, 3), dtype=np.float64)

        # The image trait allows None; there is nothing to size or crop yet.
        if self.image is not None:
            dimension = self.image.GetImageDimension()
            largest_region = self.image.GetLargestPossibleRegion()
            size = largest_region.GetSize()

            if not np.any(self.roi):
                largest_index = largest_region.GetIndex()
                new_roi = self.roi.copy()
                new_roi[0][:dimension] = np.array(
                    self.image.TransformIndexToPhysicalPoint(largest_index))
                largest_index_upper = largest_index + size
                new_roi[1][:dimension] = np.array(
                    self.image.TransformIndexToPhysicalPoint(
                        largest_index_upper))
                # Assign a new array rather than mutating self.roi in place:
                # in-place changes do not trigger trait notification/sync.
                self.roi = new_roi
                self._largest_roi = new_roi.copy()

            if dimension == 2:
                limit = self.size_limit_2d
            else:
                limit = self.size_limit_3d
            if any(size[dim] > limit[dim] for dim in range(dimension)):
                self._downsampling = True

            if self._downsampling:
                self.extractor = itk.ExtractImageFilter.New(self.image)
                self.extractor.InPlaceOn()
                self.shrinker = itk.BinShrinkImageFilter.New(self.extractor)
            self._update_rendered_image()

        if self._downsampling:
            self.observe(self._on_roi_changed, ['roi'])

        self.observe(self._on_reset_crop_requested, ['_reset_crop_requested'])
        self.observe(self.update_rendered_image, ['image'])

    @debounced(delay_seconds=1.5, method=True)
    def _on_roi_changed(self, change=None):
        """Re-render the cropped/shrunk image after the roi trait changes.

        Does nothing unless the viewer is downsampling. The ``debounced``
        decorator (1.5 s delay) presumably coalesces rapid roi edits into a
        single call -- confirm against its definition.
        """
        if not self._downsampling:
            return
        self._update_rendered_image()

    def _on_reset_crop_requested(self, change=None):
        """Handle a request to reset the roi to the whole image.

        When downsampling, the roi is reset to the full physical extent of the
        image, and that extent is also recorded in ``_largest_roi``. Whenever
        the request flag became True, it is cleared again afterwards.
        """
        if not (change.new == True):
            return
        if self._downsampling:
            img = self.image
            ndim = img.GetImageDimension()
            full_region = img.GetLargestPossibleRegion()
            lower_index = full_region.GetIndex()
            upper_index = lower_index + full_region.GetSize()
            reset_roi = self.roi.copy()
            reset_roi[0][:ndim] = np.array(
                img.TransformIndexToPhysicalPoint(lower_index))
            reset_roi[1][:ndim] = np.array(
                img.TransformIndexToPhysicalPoint(upper_index))
            self._largest_roi = reset_roi.copy()
            self.roi = reset_roi
        self._reset_crop_requested = False

    @debounced(delay_seconds=0.2, method=True)
    def update_rendered_image(self, change=None):
        """Re-render after the ``image`` trait has been replaced.

        Drops the cached full-extent rendering and the recorded full-extent
        roi. When downsampling, it also rebuilds the extract/shrink pipeline.
        ``__init__`` constructed that pipeline from the image present at that
        time, so without a rebuild it would keep reading the old image.
        """
        self._largest_roi_rendered_image = None
        self._largest_roi = np.zeros((2, 3), dtype=np.float64)
        if self._downsampling and self.image is not None:
            # Rebuild rather than SetInput so a new pixel type / dimension
            # gets a correctly-typed filter.
            self.extractor = itk.ExtractImageFilter.New(self.image)
            self.extractor.InPlaceOn()
            self.shrinker = itk.BinShrinkImageFilter.New(self.extractor)
        self._update_rendered_image()

    @staticmethod
    def _find_shrink_factors(limit, dimension, size):
        """Return per-axis integer shrink factors that fit ``size`` into ``limit``.

        For each of the first ``dimension`` axes, the factor is the smallest
        integer ``f >= 1`` with ``floor(size[axis] / f) <= limit[axis]``.
        """
        def smallest_factor(extent, cap):
            factor = 1
            while int(np.floor(float(extent) / factor)) > cap:
                factor += 1
            return factor

        return [smallest_factor(size[axis], limit[axis])
                for axis in range(dimension)]

    def _update_rendered_image(self):
        """Assign ``rendered_image`` from the current image and roi.

        Without downsampling, the input image is forwarded unchanged. With
        downsampling, the roi is mapped to an index region, padded by one
        voxel and cropped to the image. Shrink factors are chosen to fit the
        2D/3D size limits, and the shrunk crop is assigned with its origin set
        to the roi's lower corner. The rendering of the full-extent roi is
        cached in ``_largest_roi_rendered_image`` and reused.
        """
        if self.image is None:
            return
        if self._rendering_image:

            @yield_for_change(self, '_rendering_image')
            def f():
                x = yield
                assert (x == False)

            f()
        self._rendering_image = True

        if not self._downsampling:
            self.rendered_image = self.image
            return

        dimension = self.image.GetImageDimension()
        index = self.image.TransformPhysicalPointToIndex(
            tuple(self.roi[0][:dimension]))
        upper_index = self.image.TransformPhysicalPointToIndex(
            tuple(self.roi[1][:dimension]))
        size = upper_index - index

        if dimension == 2:
            limit = self.size_limit_2d
        else:
            limit = self.size_limit_3d
        self.shrinker.SetShrinkFactors(
            self._find_shrink_factors(limit, dimension, size))

        region = itk.ImageRegion[dimension]()
        region.SetIndex(index)
        region.SetSize(tuple(size))
        # Account for rounding / truncation when mapping the physical roi to
        # indices, then clamp to the image.
        region.PadByRadius(1)
        region.Crop(self.image.GetLargestPossibleRegion())
        self.extractor.SetExtractionRegion(region)

        # _largest_roi holds the full-extent roi (or zeros when unknown);
        # reuse the cached rendering when the current roi matches it.
        is_largest = False
        if np.any(self._largest_roi) and np.all(
                self._largest_roi == self.roi):
            is_largest = True
            if self._largest_roi_rendered_image is not None:
                self.rendered_image = self._largest_roi_rendered_image
                return

        self.shrinker.UpdateLargestPossibleRegion()
        shrunk = self.shrinker.GetOutput()
        # Detach so later pipeline updates do not overwrite this result.
        shrunk.DisconnectPipeline()
        shrunk.SetOrigin(self.roi[0][:dimension])
        if is_largest:
            self._largest_roi_rendered_image = shrunk
        self.rendered_image = shrunk

    @validate('gradient_opacity')
    def _validate_gradient_opacity(self, proposal):
        """Clamp the proposed opacity into (0, 1].

        Values above 1.0 become 1.0; values at or below 0.0 become 0.01.
        """
        requested = proposal['value']
        if requested > 1.0:
            return 1.0
        return 0.01 if requested <= 0.0 else requested

    @validate('cmap')
    def _validate_cmap(self, proposal):
        """Accept only colormap names listed in ``COLORMAPS``.

        Raises:
            TraitError: if the proposed name is not a known colormap.
        """
        candidate = proposal['value']
        if candidate in COLORMAPS:
            return candidate
        raise TraitError('Invalid colormap')

    def roi_region(self):
        """Return the itk.ImageRegion corresponding to the roi.

        Both roi corners are mapped to image indices. The region spans them
        inclusively and is cropped to the image's largest possible region.
        Corners are ordered per axis, so a roi whose lower physical corner
        maps to the larger index (e.g. a flipped direction matrix) still
        yields a valid, non-negative size.
        """
        dimension = self.image.GetImageDimension()
        corner_a = self.image.TransformPhysicalPointToIndex(
            tuple(self.roi[0][:dimension]))
        corner_b = self.image.TransformPhysicalPointToIndex(
            tuple(self.roi[1][:dimension]))
        lower = [min(corner_a[dim], corner_b[dim]) for dim in range(dimension)]
        upper = [max(corner_a[dim], corner_b[dim]) for dim in range(dimension)]
        # +1 so that the voxel containing the upper corner is included.
        size = [upper[dim] - lower[dim] + 1 for dim in range(dimension)]
        region = itk.ImageRegion[dimension]()
        region.SetIndex(tuple(lower))
        region.SetSize(tuple(size))
        region.Crop(self.image.GetLargestPossibleRegion())
        return region

    def roi_slice(self):
        """Return the numpy array slice corresponding to the roi.

        The slices cover exactly ``self.roi_region()``. They are returned in
        reversed axis order (last ITK axis first), the order used when
        indexing numpy views of ITK images.
        """
        dimension = self.image.GetImageDimension()
        region = self.roi_region()
        index = region.GetIndex()
        size = region.GetSize()
        # The region covers [index, index + size); Python slice stops are
        # exclusive, so no extra +1 is needed (roi_region already adds it).
        return tuple(
            slice(index[dim], index[dim] + size[dim])
            for dim in reversed(range(dimension)))