示例#1
0
class BaseImageController(ControllerBase):
    """Controller that shows one image at a time from the treasure chest.

    The displayed image is the node at ``selected_index`` among the nodes
    listed under '/rawdata' in ``self.chest``; ``update_image`` keeps the plot
    in sync whenever ``selected_index`` changes.
    """

    # Chaco container currently rendering the active image.
    plot = Instance(BasePlotContainer)
    # Data store backing ``plot``; the active image is kept under 'imagedata'.
    plotdata = Instance(ArrayPlotData)

    def __init__(self, parent, treasure_chest=None, data_path='/rawdata', *args, **kw):
        super(BaseImageController, self).__init__(
            parent, treasure_chest, data_path, *args, **kw)
        self.plotdata = ArrayPlotData()
        self._can_save = True
        self._can_change_idx = True

    def init_plot(self):
        """Load the active image into ``plotdata`` and build a new image plot."""
        self.plotdata.set_data('imagedata', self.get_active_image())
        title = self.get_active_name()
        self.plot = self.get_simple_image_plot(array_plot_data=self.plotdata,
                                               title=title)

    def data_updated(self):
        """Reinitialize by re-running ``__init__`` with the current parent,
        chest and data path."""
        self.__init__(parent=self.parent, treasure_chest=self.chest,
                      data_path=self.data_path)

    def get_active_image(self):
        """Return the selected 2D image as an array, or None when '/rawdata'
        has no nodes."""
        # NOTE(review): this reads '/rawdata' directly rather than
        # self.data_path; confirm whether subclasses rely on that before
        # changing it.
        nodes = self.chest.list_nodes('/rawdata')
        if len(nodes) == 0:
            return None
        return nodes[self.selected_index][:]

    def get_active_name(self):
        """Return the ``name`` of the selected node under '/rawdata'."""
        return self.chest.list_nodes('/rawdata')[self.selected_index].name

    @on_trait_change("selected_index")
    def update_image(self):
        """Show the image at the new ``selected_index``.

        Does nothing when there is no chest or no files. When the new image
        has a different shape from the previous one, the grid data source is
        resized, the plot is rebuilt and its aspect ratio is set to
        columns / rows.
        """
        if self.chest is None or self.numfiles < 1:
            return
        # Keep the previous image so the shapes can be compared below.
        previous = self.plotdata.get_data('imagedata')
        image = self.get_active_image()
        self.plotdata.set_data("imagedata", image)
        self.set_plot_title(self.get_active_name())
        if previous.shape == image.shape:
            return
        rows, cols = image.shape[0], image.shape[1]
        grid_source = self._base_plot.range2d.sources[0]
        grid_source.set_data(np.arange(cols), np.arange(rows))
        self.plot = self.get_simple_image_plot(array_plot_data=self.plotdata,
                                               title=self.get_active_name())
        self.plot.aspect_ratio = float(cols) / rows

    def open_save_UI(self, plot_id='plot'):
        """Register and start a 'save' session showing a SavePlotDialog for
        the plot identified by *plot_id*."""
        controller = SaveFileController(plot=self.get_plot(plot_id), parent=self)
        dialog = simple_session('save', 'Save dialog', SavePlotDialog,
                                controller=controller)
        app = Application.instance()
        app.add_factories([dialog])
        controller._session_id = app.start_session('save')
示例#2
0
class ChacoReporter(StateDataReporter, HasTraits):
    """OpenMM ``StateDataReporter`` that also plots every reported quantity.

    On the first report the reporter's column headers are used to build one
    Chaco scatter plot per quantity against the step/time axis; each report
    then appends the new values to the plot data.
    """
    # Vertical stack holding one scatter plot per reported quantity.
    plots = Instance(VPlotContainer)
    # Column headers reported by the parent StateDataReporter.
    labels = List

    traits_view = View(
        Group(Item('plots', editor=ComponentEditor(), show_label=False)),
        width=800, height=600, resizable=True,
        title='OpenMM')

    def construct_plots(self):
        """Build the Chaco plots. This will be run on the first report.

        The x axis is 'Time (ps)' if reported, otherwise 'Step'; every other
        column gets its own scatter plot, stacked in the reporter's column
        order.

        Raises
        ------
        ValueError
            If neither 'Time (ps)' nor 'Step' is among the reported columns.
        """
        self.labels = super(ChacoReporter, self)._headers()

        self.plots = VPlotContainer(resizable="hv", bgcolor="lightgray",
                                    fill_padding=True, padding=10)
        # One empty series per column, i.e. ArrayPlotData(a=[], b=[], ...).
        # Each key gets its own list so the series never alias one object.
        self.plotdata = ArrayPlotData(**{label: [] for label in self.labels})

        # Figure out which key will be the x axis (first match wins).
        x_labels = ['Time (ps)', 'Step']
        x = next((label for label in x_labels if label in self.labels), None)
        if x is None:
            raise ValueError("The reporter published neither the step nor time "
                             "count, so I don't know what to plot on the x-axis!")

        colors = itertools.cycle(['blue', 'green', 'silver', 'pink', 'lightblue',
                                  'red', 'darkgray', 'lightgreen'])

        # Iterate the labels list rather than a set so the plots are stacked
        # in a deterministic order (the reporter's column order).
        for y in [label for label in self.labels if label not in x_labels]:
            self.plots.add(chaco_scatter(self.plotdata, x_name=x, y_name=y,
                                         color=next(colors)))

    def _constructReportValues(self, simulation, state):
        """Append each reported value to its series and return the values.

        Every value is converted with ``float``, so the reported columns are
        expected to be numeric.
        """
        values = super(ChacoReporter, self)._constructReportValues(simulation, state)

        for i, label in enumerate(self.labels):
            current = self.plotdata.get_data(label)
            self.plotdata.set_data(label, np.r_[current, float(values[i])])

        return values

    def report(self, simulation, state):
        """Build the plots on the first call, then defer to the parent report."""
        if not self._hasInitialized:
            self.construct_plots()
        super(ChacoReporter, self).report(simulation, state)
    def test_data_changed_events(self):
        """Adding, replacing and deleting a key each fire exactly one event.

        NOTE(review): a second method with this same name follows later in
        the same class body, so this definition is shadowed and never runs;
        consider renaming or removing one of the copies.
        """
        key = 'Grumpy'
        first_array = numpy.ones((3, 4))
        second_array = numpy.zeros(16)
        store = ArrayPlotData()

        with self.monitor_events(store) as events:
            store.set_data(key, first_array)
            self.assertEqual(events, [{'added': [key]}])

        # get_data must hand back the very object that was stored.
        self.assertIs(store.get_data(key), first_array)

        with self.monitor_events(store) as events:
            store.set_data(key, second_array)
            self.assertEqual(events, [{'changed': [key]}])

        with self.monitor_events(store) as events:
            store.del_data(key)
            self.assertEqual(events, [{'removed': [key]}])
    def test_data_changed_events(self):
        """Adding, replacing and deleting a key each fire exactly one event."""
        key = "Grumpy"
        first_array = numpy.ones((3, 4))
        second_array = numpy.zeros(16)
        store = ArrayPlotData()

        with self.monitor_events(store) as events:
            store.set_data(key, first_array)
            self.assertEqual(events, [{"added": [key]}])

        # get_data must hand back the very object that was stored.
        self.assertIs(store.get_data(key), first_array)

        with self.monitor_events(store) as events:
            store.set_data(key, second_array)
            self.assertEqual(events, [{"changed": [key]}])

        with self.monitor_events(store) as events:
            store.del_data(key)
            self.assertEqual(events, [{"removed": [key]}])
示例#5
0
class TwoDimensionalPlot(ChacoPlot):
    """
    A 2D plot.

    Draws one line from the 'x' and 'y' arrays held in ``self.data``.
    """

    # Palette cycled through by auto_color(); the current position is stored
    # as a class attribute on the class auto_color() is called through.
    auto_color_idx = 0
    auto_color_list = ['green', 'brown', 'blue', 'red', 'black']

    @classmethod
    def auto_color(cls):
        """
        Choose the next color.

        Returns the color at ``auto_color_idx`` and advances the index,
        wrapping around at the end of ``auto_color_list``.
        """
        color = cls.auto_color_list[cls.auto_color_idx]
        cls.auto_color_idx = (cls.auto_color_idx + 1) % len(cls.auto_color_list)
        return color

    def __init__(self, parent, color=None, *args, **kwargs):
        """
        Create the plot.

        parent: passed to Window when building ``control``.
        color: line color; when None, the next auto_color() is used.
        Remaining arguments are forwarded to ChacoPlot.__init__.
        """
        self.parent = parent

        if color is None:
            color = self.auto_color()

        self.data = ArrayPlotData()
        # Single-point placeholder until real values are assigned.
        self.data.set_data('x', [0])
        self.data.set_data('y', [0])

        ChacoPlot.__init__(self, self.data, *args, **kwargs)

        self.plot(('x', 'y'), color=color)

        self.configure()

    @property
    def control(self):
        """
        A drawable control.

        Each access wraps this plot in a new Window and returns its control.
        """
        return Window(self.parent, component=self).control

    def get_data(self, axis):
        """
        Values stored under *axis* (this class stores 'x' and 'y').
        """
        return self.data.get_data(axis)

    def set_data(self, values, axis):
        """
        Replace the values stored under *axis*.
        """
        self.data.set_data(axis, values)

    # property() calls fget(instance) and fset(instance, value), so binding
    # ``axis`` with partial yields get_data(self, axis=...) and
    # set_data(self, value, axis=...).
    x_data = property(partial(get_data, axis='x'), partial(set_data, axis='x'))
    y_data = property(partial(get_data, axis='y'), partial(set_data, axis='y'))

    def _first_renderer(self):
        """
        The first renderer of the first entry in ``self.plots``.
        """
        # Indexing dict.values() only works on Python 2; next(iter(...))
        # works for both Python 2 lists and Python 3 views.
        return next(iter(self.plots.values()))[0]

    def x_autoscale(self):
        """
        Enable autoscaling for the x axis.
        """
        x_range = self._first_renderer().index_mapper.range
        x_range.low = x_range.high = 'auto'

    def y_autoscale(self):
        """
        Enable autoscaling for the y axis.
        """
        y_range = self._first_renderer().value_mapper.range
        y_range.low = y_range.high = 'auto'
示例#6
0
class OpenMMScriptRunner(HasTraits):
    """Run an OpenMM script in a worker thread and plot its reports live.

    ``run_openmm_script`` is started on its own thread with
    ``openmm_script_code`` and a queue. A second thread drains that queue and
    adds each message (a dict of column name -> value) to the plots; a
    ``None`` message ends the consumer.
    """
    # Vertical stack of one scatter plot per reported quantity.
    plots = Instance(VPlotContainer)
    # True once create_plots has built the plots from the first message.
    plots_created = Bool
    # Source code handed to run_openmm_script.
    openmm_script_code = String
    # Progress text shown in the view ('Running...' then 'Done').
    status = String

    traits_view = View(
        Group(
            HGroup(spring, Item('status', style='readonly'), spring),
            Item('plots', editor=ComponentEditor(),
                 show_label=False)
        ),
        width=800, height=600, resizable=True,
        title='OpenMM Script Runner'
    )

    def __init__(self, **traits):
        super(OpenMMScriptRunner, self).__init__(**traits)

        # Plots are built lazily from the keys of the first message (see
        # update_plot, which checks this trait).
        self.plots_created = False

        q = Queue.Queue()
        # Start two threads: t1 runs the script and places state data into
        # the queue; t2 removes elements from the queue and updates the plots.
        # NOTE(review): t2 mutates traits and plot data off the GUI thread;
        # confirm the UI toolkit tolerates that.
        t1 = threading.Thread(target=run_openmm_script,
                              args=(self.openmm_script_code, q))
        t2 = threading.Thread(target=self.queue_consumer, args=(q,))
        t1.start()
        t2.start()

    def queue_consumer(self, q):
        """Main loop for a thread that consumes the messages from the queue
        and plots them, stopping at the first ``None`` message.

        NOTE(review): if the producer never sends ``None`` (e.g. the script
        dies), this thread waits forever; confirm run_openmm_script always
        sends the sentinel.
        """
        self.status = 'Running...'

        # Block on q.get() instead of polling with get_nowait() + sleep;
        # iter(callable, sentinel) stops as soon as q.get() returns None.
        for msg in iter(q.get, None):
            self.update_plot(msg)

        self.status = 'Done'

    def create_plots(self, keys):
        """Create the plots.

        Parameters
        ----------
        keys : list of strings
            A list of all of the keys in the msg dict. This should be something
            like ['Step', 'Temperature', 'Potential Energy']. We'll create the
            ArrayPlotData container in which each of these timeseries will
            get put.

        Raises
        ------
        ValueError
            If neither 'Step' nor 'Time (ps)' is among *keys*.
        """
        self.plots = VPlotContainer(resizable="hv", bgcolor="lightgray",
                                    fill_padding=True, padding=10)
        # One empty series per key, i.e. ArrayPlotData(a=[], b=[], ...).
        # Each key gets its own list so the series never alias one object.
        self.plotdata = ArrayPlotData(**{key: [] for key in keys})

        # Figure out which key will be the x axis ('Step' is preferred).
        if 'Step' in keys:
            x = 'Step'
        elif 'Time (ps)' in keys:
            x = 'Time (ps)'
        else:
            raise ValueError("The reporter published neither the step nor time "
                             "count, so I don't know what to plot on the x-axis!")

        colors = itertools.cycle(['blue', 'green', 'silver', 'pink', 'lightblue',
                                  'red', 'darkgray', 'lightgreen'])
        for y in [key for key in keys if key != x]:
            self.plots.add(chaco_scatter(self.plotdata, x_name=x, y_name=y,
                                         color=next(colors)))

    def update_plot(self, msg):
        """Add data points from the message to the plots.

        Parameters
        ----------
        msg : dict
            This is the message sent over the Queue from the script; the
            plots are created from its keys the first time this is called.
        """
        if not self.plots_created:
            self.create_plots(msg.keys())
            self.plots_created = True

        for k, v in msg.items():
            current = self.plotdata.get_data(k)
            self.plotdata.set_data(k, np.r_[current, v])
示例#7
0
class OneDViewer(BaseViewer):
    """ Holds the 1D data array and cursor position that are updated by the
    Controller. The visualization/editor for this class is a Chaco plot of
    the data, with the cursor position drawn as a red marker.
    """
    ndim = Int(1)
    num_ticks = Int(0)
    resolution = Float(1.)
    start_pos = Float(0.)
    # Index (or indices) into ``data`` marking the cursor position.
    csr_pos = Array
    # Plot data backing ``plot`` (marked transient).
    pd = Instance(ArrayPlotData, transient=True)
    # The Chaco Plot built by create_plot_element.
    plot = Any()

    traits_view = View(
        Item('plot', editor=ComponentEditor(), show_label=False),
        resizable=True,
    )

    def __init__(self, *args, **kargs):
        super(OneDViewer, self).__init__(*args, **kargs)
        self.create_plot_element()

    def set_data(self, new_data, idx=None):
        """Store new data and refresh the plot.

        With *idx* None, *new_data* replaces ``data`` and the cursor is reset
        to index 0. Otherwise *new_data* is written into ``data[idx]`` and
        the cursor moves to *idx*. The test is ``is not None`` so that index
        0 (or ``np.array([0])``) counts as a real index, and multi-element
        index arrays are not truth-tested.
        """
        if idx is not None:
            self.data[idx] = new_data
            # csr_pos is an Array trait; atleast_1d also accepts a plain int.
            self.csr_pos = np.atleast_1d(idx)
        else:
            self.data = new_data
            self.csr_pos = np.array([0])
        self.refresh()

    def _data_default(self):
        """Default ``data``: ``self.max_size`` NaNs."""
        return np.full((self.max_size,), np.nan)

    def _csr_pos_default(self):
        """Default cursor position: index 0."""
        return np.array([0])

    def _plot_arrays(self):
        """Return (name, array) pairs for the plot data, in update order."""
        return [
            ('x', np.arange(self.data.size)),
            ('y', self.data),
            ('posx', self.csr_pos),
            ('posy', np.array([self.data[self.csr_pos]])),
        ]

    def create_plot_element(self):
        """Build a new ArrayPlotData and Plot for the current data and cursor."""
        self.pd = ArrayPlotData(**dict(self._plot_arrays()))

        plot = Plot(self.pd)
        plot.plot(("x", "y"),
                  title='',
                  y_label="Signal",
                  color=tuple(cbrewer[np.random.randint(0, 10)]),
                  bgcolor="grey",
                  border_visible=True,
                  border_width=1,
                  width=800,
                  height=380,
                  marker_size=2,
                  show_label=False)

        # Cursor marker.
        plot.plot_1d("posx",
                     type="line_scatter_1d",
                     name="dot",
                     color="red",
                     marker="circle",
                     marker_size=4)

        self.plot = plot

    def refresh(self):
        """Push the current data and cursor to the plot.

        If the data length is unchanged the existing plot data is updated in
        place; otherwise the whole plot is rebuilt.
        """
        if self.data.size == self.pd.get_data('y').size:
            for name, values in self._plot_arrays():
                self.pd.set_data(name, values)
        else:
            self.create_plot_element()

    # NOTE(review): no ``positions`` trait is declared in this class; this
    # getter only takes effect if BaseViewer declares a ``positions`` Property.
    @property_depends_on('data[]')
    def _get_positions(self):
        if self.data is not None:
            return np.arange(self.data.size)
        else:
            return np.array([0, 1])
示例#8
0
class ColormappedPlot(ChacoPlot):
    """
    A colormapped plot.

    Shows the 2D array stored under 'color' in ``self.data`` as an image
    colored with the ``jet`` colormap.
    """

    def __init__(self, parent, x_bounds, y_bounds, *args, **kwargs):
        """
        Create the plot.

        parent: passed to Window when building ``control``.
        x_bounds, y_bounds: forwarded to img_plot as xbounds / ybounds.
        Remaining arguments are forwarded to ChacoPlot.__init__.
        """
        self.parent = parent

        self.data = ArrayPlotData()
        # 1x1 placeholder image until real values are assigned to color_data.
        self.data.set_data('color', [[0]])

        ChacoPlot.__init__(self, self.data, *args, **kwargs)

        self.img_plot('color', colormap=jet, xbounds=x_bounds, ybounds=y_bounds)

        self.configure()

    @property
    def plot_obj(self):
        """
        The actual plot object: the first renderer of the first entry in
        ``self.plots``.
        """
        # Indexing dict.values() only works on Python 2; next(iter(...))
        # works for both Python 2 lists and Python 3 views.
        return next(iter(self.plots.values()))[0]

    @property
    def control(self):
        """
        A drawable control with a color bar.

        Each access builds a new color bar (with a range-selection tool that
        notifies the plot) and a new Window around this plot and the bar.
        """
        color_map = self.plot_obj.color_mapper
        linear_mapper = LinearMapper(range=color_map.range)
        color_bar = ColorBar(index_mapper=linear_mapper, color_mapper=color_map,
                             plot=self.plot_obj, orientation='v', resizable='v',
                             width=30)
        color_bar._axis.tick_label_formatter = self.sci_formatter
        color_bar.padding_top = self.padding_top
        color_bar.padding_bottom = self.padding_bottom
        color_bar.padding_left = 50  # Room for labels.
        color_bar.padding_right = 10

        range_selection = RangeSelection(component=color_bar)
        range_selection.listeners.append(self.plot_obj)
        color_bar.tools.append(range_selection)

        range_selection_overlay = RangeSelectionOverlay(component=color_bar)
        color_bar.overlays.append(range_selection_overlay)

        container = HPlotContainer(use_backbuffer=True)
        container.add(self)
        container.add(color_bar)

        return Window(self.parent, component=container).control

    @property
    def color_data(self):
        """
        Plotted values.
        """
        return self.data.get_data('color')

    @color_data.setter
    def color_data(self, values):
        self.data.set_data('color', values)

    @property
    def low_setting(self):
        """
        Lowest color value.

        Reading returns the color range's current ``low``; assigning sets
        the range's ``low_setting``.
        """
        return self.plot_obj.color_mapper.range.low

    @low_setting.setter
    def low_setting(self, value):
        self.plot_obj.color_mapper.range.low_setting = value

    @property
    def high_setting(self):
        """
        Highest color value.

        Reading returns the color range's current ``high``; assigning sets
        the range's ``high_setting``.
        """
        return self.plot_obj.color_mapper.range.high

    @high_setting.setter
    def high_setting(self, value):
        self.plot_obj.color_mapper.range.high_setting = value
示例#9
0
class Plot2D(DataView):

    #------------------------------------------------------------------------
    # Data-related traits
    #------------------------------------------------------------------------

    # The PlotData instance that drives this plot.
    data = Instance(AbstractPlotData)

    # Mapping of data names from self.data to their respective datasources.
    datasources = Dict(Str, Instance(AbstractDataSource))

    #------------------------------------------------------------------------
    # General plotting traits
    #------------------------------------------------------------------------

    # Mapping of plot names to *lists* of plot renderers.
    plots = Dict(Str, List)

    index2d = Instance(GridDataSource)

    # Optional mapper for the color axis.  Not instantiated until first use;
    # destroyed if no color plots are on the plot.
    color_mapper = Instance(AbstractColormap)

    # Mapping of renderer type string to renderer class
    # This can be overriden to customize what renderer type the Plot
    # will instantiate for its various plotting methods.
    renderer_map = Dict(dict(img_plot = ImagePlot,
                             cmap_img_plot = CMapImagePlot,
                             contour_line_plot = ContourLinePlot,
                             contour_poly_plot = ContourPolyPlot,
                             ))

    #------------------------------------------------------------------------
    # Annotations and decorations
    #------------------------------------------------------------------------

    # The legend on the plot.
    legend = Instance(Legend)

    # Convenience attribute for legend.align; can be "ur", "ul", "ll", "lr".
    legend_alignment = Property

    #------------------------------------------------------------------------
    # Public methods
    #------------------------------------------------------------------------

    def __init__(self, data=None, **kwtraits):
        """ Create the plot, optionally attaching *data*.

        Parameters
        ----------
        data : AbstractPlotData, ndarray, tuple, list or None
            Used as ``self.data`` directly when it is an AbstractPlotData;
            arrays, tuples and lists are wrapped in an ArrayPlotData.
        **kwtraits
            Trait values for the plot. An ``origin`` entry is stored as
            ``default_origin``; ``bgcolor`` defaults to 'black'.

        Raises
        ------
        ValueError
            If *data* is of any other type.
        """
        if 'origin' in kwtraits:
            self.default_origin = kwtraits.pop('origin')
        kwtraits.setdefault('bgcolor', 'black')

        super(Plot2D, self).__init__(**kwtraits)

        if data is not None:
            if isinstance(data, AbstractPlotData):
                self.data = data
            elif isinstance(data, (ndarray, tuple, list)):
                self.data = ArrayPlotData(data)
            else:
                # Function-call form of raise: valid on Python 2 and 3.
                raise ValueError("Don't know how to create PlotData for data "
                                 "of type " + str(type(data)))

        if not self.legend:
            self.legend = Legend(visible=False, align="ur", error_icon="blank",
                                 padding=10, component=self)

        # ensure that we only get displayed once by new_window()
        self._plot_ui_info = None

    def img_plot(self, data, name=None, colormap=None,
                 xbounds=None, ybounds=None, origin=None, hide_grids=True, **styles):
        """ Adds an image plot to this Plot object.

        An array with three dimensions is drawn as an RGB/RGBA image (its
        last dimension must be 3 or 4) and *colormap* is ignored; any other
        array is drawn as a colormapped image.

        Parameters
        ----------
        data : string
            Name of the image array in ``self.data``. It is row-major, so
            *xbounds* follows its second axis and *ybounds* its first.
        name : string
            Name of the plot; generated when omitted.
        colormap : None, AbstractColormap, or callable
            Used for colormapped images only. None reuses
            ``self.color_mapper`` or, if there is none, builds a 'Spectral'
            map; an AbstractColormap instance is used as-is (its range is
            filled in if unset); anything else is called with a DataRange1D
            over the data. The chosen map is stored in ``self.color_mapper``.
        xbounds, ybounds : string, tuple, or ndarray
            Where the image resides: a) names of data in the plot data;
            b) (low, high) tuples in data space; c) 1D arrays of pixel
            boundaries (one element longer than the data axis); or
            d) 2D arrays as obtained from a meshgrid operation.
        origin : string
            Which corner the origin of this plot should occupy:
            "bottom left", "top left", "bottom right", "top right".
            Defaults to ``self.default_origin``.
        hide_grids : bool, default True
            Whether to hide the plot's grid lines.
        styles : keyword arguments
            Extra attributes passed to the renderer.

        Returns
        -------
        list
            ``self.plots[name]``, holding the new renderer.

        Raises
        ------
        ValueError
            If a 3D array's last dimension is not 3 or 4.
        """
        name = self._make_new_plot_name() if name is None else name
        origin = self.default_origin if origin is None else origin

        value = self._get_or_create_datasource(data)
        image = value.get_data()

        if len(image.shape) == 3:
            # RGB(A) image: colors come straight from the data.
            if image.shape[2] not in (3, 4):
                raise ValueError("Image plots require color depth of 3 or 4.")
            renderer_cls = self.renderer_map["img_plot"]
            renderer_kwargs = dict(styles)
        else:
            # Scalar image: resolve which colormap to draw it with.
            if colormap is None:
                colormap = self.color_mapper
                if colormap is None:
                    colormap = Spectral(DataRange1D(value))
            elif isinstance(colormap, AbstractColormap):
                if colormap.range is None:
                    colormap.range = DataRange1D(value)
            else:
                colormap = colormap(DataRange1D(value))
            self.color_mapper = colormap
            renderer_cls = self.renderer_map["cmap_img_plot"]
            renderer_kwargs = dict(value_mapper=colormap, **styles)

        return self._create_2d_plot(renderer_cls, name, origin, xbounds, ybounds,
                                    value, hide_grids, **renderer_kwargs)


    def contour_plot(self, data, type="line", name=None, poly_cmap=None,
                     xbounds=None, ybounds=None, origin=None, hide_grids=True, **styles):
        """ Adds a contour plot to this Plot object.

        Parameters
        ==========
        data : string
            The name of the data array in ``self.data``; it must be a 2D
            scalar field (value depth 1).
        type : "line" or "poly"
            The kind of contour plot to add: contour lines ("line") or
            filled contour polygons ("poly"). Any other value raises
            ValueError.
        name : string
            The name of the plot; if omitted, then a name is generated.
        poly_cmap : None, function, or colormap instance
            Colormap for "poly" plots (ignored for "line" plots). None uses
            a 'Spectral' colormap; a plain Python function is called with a
            DataRange1D over the data to build the colormap; an object whose
            ``range`` attribute is None gets that DataRange1D as its range.
            Any other value is passed to the renderer unchanged.
        xbounds, ybounds : string, tuple, or ndarray
            Bounds where this image resides. Bound may be: a) names of
            data in the plot data; b) tuples of (low, high) in data space,
            c) 1D arrays of values representing the pixel boundaries (must
            be 1 element larger than underlying data), or
            d) 2D arrays as obtained from a meshgrid operation
        origin : string
            Which corner the origin of this plot should occupy:
                "bottom left", "top left", "bottom right", "top right"
            Defaults to ``self.default_origin``.
        hide_grids : bool, default True
            Whether or not to automatically hide the grid lines on the plot
        styles : series of keyword arguments
            Attributes passed to the renderer, e.g. 'line_color' or
            'line_width'. For "line" plots a 'colors' entry may be a plain
            function (called with a DataRange1D over the data) or a
            colormap whose unset ``range`` is filled in.

        Returns
        =======
        list
            ``self.plots[name]``, holding the new renderer.

        Raises
        ======
        ValueError
            If the data is not a 2D scalar field, or *type* is not "line"
            or "poly".
        """
        if name is None:
            name = self._make_new_plot_name()
        if origin is None:
            origin = self.default_origin

        value = self._get_or_create_datasource(data)
        if value.value_depth != 1:
            raise ValueError("Contour plots require 2D scalar field")
        if type == "line":
            cls = self.renderer_map["contour_line_plot"]
            kwargs = dict(styles)
            # if colors is given as a factory func, use it to make a
            # concrete colormapper. Better way to do this?
            if "colors" in kwargs:
                cmap = kwargs["colors"]
                if isinstance(cmap, FunctionType):
                    kwargs["colors"] = cmap(DataRange1D(value))
                elif getattr(cmap, 'range', 'dummy') is None:
                    cmap.range = DataRange1D(value)
        elif type == "poly":
            if poly_cmap is None:
                poly_cmap = Spectral(DataRange1D(value))
            elif isinstance(poly_cmap, FunctionType):
                poly_cmap = poly_cmap(DataRange1D(value))
            elif getattr(poly_cmap, 'range', 'dummy') is None:
                # Colormap instance without a range: give it the data range.
                poly_cmap.range = DataRange1D(value)
            cls = self.renderer_map["contour_poly_plot"]
            kwargs = dict(color_mapper=poly_cmap, **styles)
        else:
            raise ValueError("Unhandled contour plot type: " + type)

        return self._create_2d_plot(cls, name, origin, xbounds, ybounds, value,
                                    hide_grids, **kwargs)


    def _process_2d_bounds(self, bounds, array_data, axis):
        """Transform an arbitrary bounds definition into a linspace.

        Process all the ways the user could have defined the x- or y-bounds
        of a 2d plot and return ``array_data.shape[axis] + 1`` evenly spaced
        pixel-edge coordinates between the lower and upper bound.

        Parameters
        ----------
        bounds : None, tuple, or ndarray
            User bounds definition: None (pixel indices 0..n), a
            ``(low, high)`` tuple, a 1D array of pixel edges (only its first
            and last values are used), or a 2D meshgrid array.
        array_data : ndarray
            The plot data. Its first two dimensions are the image rows and
            columns; a trailing color dimension (RGB/RGBA) is allowed.
        axis : int
            The axis along which the bounds are to be set (0 for y, 1 for x).

        Raises
        ------
        ValueError
            If a 1D array does not have one more element than the data axis,
            a 2D array does not match the data's first two dimensions, or
            *bounds* is of any other type.
        """
        num_ticks = array_data.shape[axis] + 1

        if bounds is None:
            return arange(num_ticks)

        if type(bounds) is tuple:
            # create a linspace with the bounds limits
            return linspace(bounds[0], bounds[1], num_ticks)

        if type(bounds) is ndarray and bounds.ndim == 1:
            # bounds is 1D: it must list every pixel boundary.
            if len(bounds) != num_ticks:
                msg = ("1D bounds of an image plot needs to have 1 more "
                       "element than its corresponding data shape, because "
                       "they represent the locations of pixel boundaries.")
                raise ValueError(msg)
            return linspace(bounds[0], bounds[-1], num_ticks)

        if type(bounds) is ndarray and bounds.ndim == 2:
            # bounds is 2D, assumed to be a meshgrid
            # This is triggered when doing something like
            # >>> xbounds, ybounds = meshgrid(...)
            # >>> z = f(xbounds, ybounds)
            # Compare against the first two dimensions only, so RGB(A)
            # images of shape (N, M, 3|4) can use meshgrid bounds too.
            if bounds.shape != array_data.shape[:2]:
                msg = ("2D bounds of an image plot needs to have the same "
                       "shape as the underlying data, because "
                       "they are assumed to be generated from meshgrids.")
                raise ValueError(msg)
            line = bounds[:, 0] if axis == 0 else bounds[0, :]
            # Meshgrid values are pixel starts; add one step for the last edge.
            interval = line[1] - line[0]
            return linspace(line[0], line[-1] + interval, num_ticks)

        raise ValueError("bounds must be None, a tuple, an array, "
                         "or a PlotData name")


    def _create_2d_plot(self, cls, name, origin, xbounds, ybounds, value_ds,
                        hide_grids, **kwargs):
        """Build a 2D renderer of class *cls* over *value_ds* and register it.

        The bounds are converted to pixel-edge coordinates; a new
        GridDataSource built from them becomes ``self.index`` and is added to
        ``self.range2d``. The renderer is added to this container and stored
        as ``self.plots[name]``.

        Returns
        -------
        list
            ``self.plots[name]``, holding the new renderer.
        """
        if name is None:
            name = self._make_new_plot_name()
        if origin is None:
            origin = self.default_origin

        image = value_ds.get_data()

        def edges(bounds, axis):
            # A string names an array in the plot data; resolve it first.
            if isinstance(bounds, basestring):
                bounds = self._get_or_create_datasource(bounds).get_data()
            return self._process_2d_bounds(bounds, image, axis)

        x_edges = edges(xbounds, 1)
        y_edges = edges(ybounds, 0)

        # Create the index and add its datasources to the appropriate ranges
        self.index = GridDataSource(x_edges, y_edges,
                                    sort_order=('ascending', 'ascending'))
        self.range2d.add(self.index)
        grid_mapper = GridMapper(range=self.range2d,
                                 stretch_data_x=self.x_mapper.stretch_data,
                                 stretch_data_y=self.y_mapper.stretch_data)

        renderer = cls(index=self.index, value=value_ds,
                       index_mapper=grid_mapper,
                       orientation=self.orientation, origin=origin,
                       **kwargs)

        if hide_grids:
            self.x_grid.visible = self.y_grid.visible = False

        self.add(renderer)
        self.plots[name] = [renderer]
        return self.plots[name]

    def delplot(self, *names):
        """ Removes the named sub-plots.

        The renderers are removed from this container, and any of their
        index/value data sources that no remaining plot uses are removed from
        the data ranges.

        Raises
        ------
        KeyError
            If any of *names* is not a current plot. Every name is checked
            before anything is modified, so a bad name leaves the plot
            unchanged instead of half-deleted.
        """
        for name in names:
            if name not in self.plots:
                raise KeyError(name)

        # Remove all the renderers from us (container) and collect the
        # datasources that we might have to remove from the ranges.
        removed = [self.plots.pop(name) for name in names]
        deleted_sources = set()
        for renderer in itertools.chain.from_iterable(removed):
            self.remove(renderer)
            deleted_sources.add(renderer.index)
            deleted_sources.add(renderer.value)

        # Cull the candidate list by keeping sources the other plots still use.
        sources_in_use = set()
        for renderer in itertools.chain.from_iterable(self.plots.values()):
            sources_in_use.add(renderer.index)
            sources_in_use.add(renderer.value)

        unused_sources = deleted_sources - sources_in_use - {None}

        # Remove the unused sources from all ranges
        for source in unused_sources:
            if source.index_dimension == "scalar":
                # Try both index and range, it doesn't hurt
                self.index_range.remove(source)
                self.value_range.remove(source)
            elif source.index_dimension == "image":
                self.range2d.remove(source)
            else:
                warnings.warn("Couldn't remove datasource from datarange.")

    def hideplot(self, *names):
        """ Convenience function to make the named plots invisible.

        The renderers stay in the container and in ``self.plots``; only their
        ``visible`` flag is cleared.  All names are looked up first, so an
        unknown name raises KeyError before anything is hidden.
        """
        groups = [self.plots[plot_name] for plot_name in names]
        for group in groups:
            for renderer in group:
                renderer.visible = False

    def showplot(self, *names):
        """ Convenience function to make the named plots visible.

        All names are looked up first, so an unknown name raises KeyError
        before any renderer is changed.
        """
        groups = [self.plots[plot_name] for plot_name in names]
        for group in groups:
            for renderer in group:
                renderer.visible = True

    #------------------------------------------------------------------------
    # Private methods
    #------------------------------------------------------------------------



    def _make_new_plot_name(self):
        """ Returns a string that is not already used as a plot title.

        Candidates have the form "plot<N>", starting at N = len(self.plots)
        and counting upward until a name absent from ``self.plots`` is found.
        """
        suffix = len(self.plots)
        while ("plot%d" % suffix) in self.plots:
            suffix += 1
        return "plot%d" % suffix

    def _get_or_create_datasource(self, name, sort_order = 'none'):
        """ Returns the data source associated with the given name, or creates
        it if it doesn't exist.

        The raw value comes from ``self.data.get_data(name)``; lists and
        tuples are first converted to arrays.  Arrays become:

        * 1-D            -> ArrayDataSource (with *sort_order*)
        * 2-D            -> ImageData with value_depth 1
        * 3-D, last axis 3 or 4 -> ImageData with that value_depth

        An AbstractDataSource is used as-is.  The result is cached in
        ``self.datasources`` under *name*.

        Raises
        ------
        ValueError
            For any other array shape (including 0-d and >3-d arrays, which
            previously fell through and raised UnboundLocalError) or any
            other data type.
        """
        if name not in self.datasources:
            data = self.data.get_data(name)

            if isinstance(data, (list, tuple)):
                data = array(data)

            if isinstance(data, ndarray):
                ndim = len(data.shape)
                if ndim == 1:
                    ds = ArrayDataSource(data, sort_order=sort_order)
                elif ndim == 2:
                    ds = ImageData(data=data, value_depth=1)
                elif ndim == 3 and data.shape[2] in (3, 4):
                    ds = ImageData(data=data, value_depth=int(data.shape[2]))
                else:
                    raise ValueError("Unhandled array shape in creating new plot: "
                                     + str(data.shape))

            elif isinstance(data, AbstractDataSource):
                ds = data

            else:
                raise ValueError("Couldn't create datasource for data of type " +
                                 str(type(data)))

            self.datasources[name] = ds

        return self.datasources[name]

    #------------------------------------------------------------------------
    # Event handlers
    #------------------------------------------------------------------------

    def _color_mapper_changed(self):
        # Trait handler: push the new colour mapper onto every renderer of
        # every named plot, then request a repaint.
        for renderer in itertools.chain.from_iterable(self.plots.values()):
            renderer.color_mapper = self.color_mapper
        self.invalidate_draw()

    def _data_changed(self, old, new):
        # Trait handler: move the "data_changed" subscription from the
        # previous plot-data object (if any) to the new one (if any).
        listener = self._data_update_handler
        if old:
            old.on_trait_change(listener, "data_changed", remove=True)
        if new:
            new.on_trait_change(listener, "data_changed")

    def _data_update_handler(self, name, event):
        """ Push changed arrays from ``self.data`` into their cached datasources.

        Parameters
        ----------
        name : str
            Name of the trait that fired (unused).
        event : dict
            Per the comments in AbstractPlotData, a dict with (some of) the
            keys "added", "removed" and "changed", each mapping to a list of
            data names.  Only "changed" is acted on; "added"/"removed" are
            ignored here.
        """
        # ``in`` / ``.get`` instead of dict.has_key, which is deprecated in
        # Python 2 and removed in Python 3.
        for data_name in event.get("changed", ()):
            if data_name in self.datasources:
                source = self.datasources[data_name]
                source.set_data(self.data.get_data(data_name))

    def _plots_items_changed(self, event):
        # Trait handler for edits to ``plots``: keep the legend's plot
        # mapping in sync (no-op when there is no legend).
        legend = self.legend
        if not legend:
            return
        legend.plots = self.plots

    def _legend_changed(self, old, new):
        # Trait handler: delegate overlay bookkeeping to
        # _overlay_change_helper (defined outside this view), then give the
        # new legend, if any, this plot's name -> renderers mapping.
        self._overlay_change_helper(old, new)
        if not new:
            return
        new.plots = self.plots

    def _handle_range_changed(self, name, old, new):
        """ Overrides the DataView default behavior.

        Primarily changes how the list of renderers is looked up: renderers
        are taken from every list in ``self.plots``.

        *name* is the axis prefix (e.g. "index" or "value").  The
        ``<name>_mapper`` is re-pointed at *new* if it used *old*.  Every
        datasource is moved from *old* to *new* (or just dropped from *old*
        when *new* is None).  Finally, each renderer that has a
        ``<name>_range`` attribute gets it set to *new*.
        """
        range_attr = name + "_range"
        mapper = getattr(self, name + "_mapper")
        if mapper.range == old:
            mapper.range = new

        if old is not None:
            # Iterate over a snapshot because old.remove mutates old.sources.
            for datasource in list(old.sources):
                old.remove(datasource)
                if new is not None:
                    new.add(datasource)

        for renderer in itertools.chain.from_iterable(self.plots.values()):
            if hasattr(renderer, range_attr):
                setattr(renderer, range_attr, new)

    #------------------------------------------------------------------------
    # Property getters and setters
    #------------------------------------------------------------------------

    def _set_legend_alignment(self, align):
        # Property setter: forward to legend.align; silently ignored when
        # there is no legend.
        legend = self.legend
        if not legend:
            return
        legend.align = align

    def _get_legend_alignment(self):
        # Property getter: mirror legend.align, or None without a legend.
        legend = self.legend
        return legend.align if legend else None
示例#10
0
class CellCropController(BaseImageController):
    """ Controller for cropping cells by template matching.

    Workflow implemented here:

    * a square template (``template_size`` pixels, top-left corner at
      ``template_left`` / ``template_top``) is cut from the active image;
    * ``locate_peaks`` cross-correlates the template with every image
      (``cv_funcs.xcorr``) and stores the detected peaks per image name in
      ``peaks``.  Columns 0/1 are plotted as scatter index/value, and
      column 2 holds the correlation height;
    * ``crop_cells`` cuts a template-sized patch at every peak whose height
      lies inside the selected threshold and stores the patches, the
      template and their average via ``self.parent.add_cell_data``.
    """
    # Constant lower bound used by the template_top/template_left ranges.
    zero = Int(0)
    template_plot = Instance(BasePlotContainer)
    template_data = Instance(ArrayPlotData)
    template_size = Range(low=2, high=512, value=64, cols=4)
    template_top = Range(low='zero', high='max_pos_y', value=20, cols=4)
    template_left = Range(low='zero', high='max_pos_x', value=20, cols=4)
    # image name -> array of peaks (one row per peak, see class docstring)
    peaks = Dict({})
    # When True the main plot shows the cross-correlation instead of the image.
    ShowCC = Bool(False)
    # Upper bounds for template_left/template_top, kept in sync with the
    # active image by _get_max_positions.
    max_pos_x = Int(256)
    max_pos_y = Int(256)
    is_square = Bool(True)
    peak_width = Range(low=2, high=200, value=10)
    numpeaks_total = Int(0, cols=5)
    numpeaks_img = Int(0, cols=5)
    _session_id = String('')

    def __init__(self, parent, treasure_chest=None, data_path='/rawdata',
                 *args, **kw):
        super(CellCropController, self).__init__(parent, treasure_chest,
                                                 data_path, *args, **kw)

        if self.chest is not None:
            self.numfiles = len(self.nodes)
            if self.numfiles > 0:
                self.init_plot()
                print("initialized plot for data in %s" % data_path)

    def data_updated(self):
        # reinitialize data
        self.__init__(parent=self.parent, treasure_chest=self.chest,
                      data_path=self.data_path)

    def _template_from(self, image):
        """ Return the template_size x template_size window of *image* whose
        top-left corner is (template_top, template_left). """
        top, left = self.template_top, self.template_left
        size = self.template_size
        return image[top:top + size, left:left + size]

    def init_plot(self):
        """ Build the main scatter-overlay plot and the template preview. """
        # Read the image once; get_active_image() slices the node with [:],
        # so each call copies the whole image out of the chest.
        active_image = self.get_active_image()
        self.plotdata.set_data('imagedata', active_image)
        self.plot = self.get_scatter_overlay_plot(array_plot_data=self.plotdata,
                                                  title=self.get_active_name(),
                                                  tools=['csr', 'colorbar', 'zoom', 'pan'])
        # pick an initial template with default parameters
        self.template_data = ArrayPlotData()
        self.template_plot = Plot(self.template_data, default_origin="top left")
        self.template_data.set_data('imagedata', self._template_from(active_image))
        self.template_plot.img_plot('imagedata', title="Template")
        self.template_plot.aspect_ratio = 1  # square templates
        self.template_filename = self.get_active_name()
        self._get_max_positions()

    @on_trait_change("selected_index, ShowCC")
    def update_image(self):
        """ Redraw the main plot with either the active image or, when
        ShowCC is set, its cross-correlation with the template. """
        if self.ShowCC:
            displayed = cv_funcs.xcorr(self.template_data.get_data('imagedata'),
                                       self.get_active_image())
            title = "Cross correlation of " + self.get_active_name()
        else:
            displayed = self.get_active_image()
            title = self.get_active_name()
        self.plotdata.set_data("imagedata", displayed)
        self.plot = self.get_scatter_overlay_plot(array_plot_data=self.plotdata,
                                                  title=self.get_active_name(),
                                                  tools=['csr', 'zoom', 'pan', 'colorbar'],
                                                  )
        self.plot.aspect_ratio = float(displayed.shape[1]) / displayed.shape[0]
        self.set_plot_title(title)
        # Resize the image grid to match what is now displayed.
        grid_data_source = self._base_plot.range2d.sources[0]
        grid_data_source.set_data(np.arange(displayed.shape[1]),
                                  np.arange(displayed.shape[0]))

    def update_CC(self):
        """ Refresh the displayed cross-correlation (only while ShowCC). """
        if self.ShowCC:
            CC = cv_funcs.xcorr(self.template_data.get_data('imagedata'),
                                self.get_active_image())
            self.plotdata.set_data("imagedata", CC)

    @on_trait_change('template_left, template_top, template_size')
    def update_template_data(self):
        """ Re-cut the template; previously found peaks are discarded. """
        self.template_data.set_data('imagedata',
                                    self._template_from(self.get_active_image()))
        self.template_filename = self.get_active_name()
        if self.numpeaks_total > 0:
            print("clearing peaks")
            self.peaks = {}
        # when template data changes, we should check whether to update the
        #    cross correlation plot, which depends on the template
        self.update_CC()

    @on_trait_change('selected_index, template_size')
    def _get_max_positions(self):
        """ Recompute max_pos_x/max_pos_y from the active image size and the
        template size; non-positive results leave the old bound in place. """
        shape = self.get_active_image().shape
        max_pos_x = shape[-1] - self.template_size - 1
        if max_pos_x > 0:
            self.max_pos_x = int(max_pos_x)
        max_pos_y = shape[-2] - self.template_size - 1
        if max_pos_y > 0:
            self.max_pos_y = int(max_pos_y)

    @on_trait_change('template_left, template_top')
    def update_csr_position(self):
        # Move the cursor to the template's top-left corner.
        # NOTE(review): skipped whenever template_left == 0, even if
        # template_top changed -- confirm this is intended.
        if self.template_left > 0:
            self._csr.current_position = self.template_left, self.template_top

    @on_trait_change('_csr:current_position')
    def update_top_left(self):
        """ Make the template's top-left corner follow the cursor.

        A cursor at (0, 0) is ignored.  If x exceeds max_pos_x, only the top
        is updated (or, when y is out of range too, the cursor is snapped
        back to (max_pos_x, max_pos_y)); if only y exceeds max_pos_y, the
        top is clamped to max_pos_y.
        """
        pos = self._csr.current_position
        if not (pos[0] > 0 or pos[1] > 0):
            return
        if pos[0] > self.max_pos_x:
            if pos[1] < self.max_pos_y:
                self.template_top = pos[1]
            else:
                self._csr.current_position = self.max_pos_x, self.max_pos_y
        elif pos[1] > self.max_pos_y:
            self.template_left, self.template_top = pos[0], self.max_pos_y
        else:
            self.template_left, self.template_top = pos

    @on_trait_change('_colorbar_selection:selection')
    def update_thresh(self):
        """ Copy the colour-bar selection into the threshold traits and the
        scatter renderer's selection metadata, then redraw. """
        # Deliberately best effort: any failure (presumably no selection or
        # no scatter plot yet) is ignored.
        try:
            thresh = self._colorbar_selection.selection
            self.thresh = thresh
            scatter_renderer = self._scatter_plot.plots['scatter_plot'][0]
            scatter_renderer.color_data.metadata['selections'] = thresh
            self.thresh_lower = thresh[0]
            self.thresh_upper = thresh[1]
            scatter_renderer.color_data.metadata_changed = {'selections': thresh}
            self.plot.request_redraw()
        except Exception:
            pass

    @on_trait_change('thresh_upper,thresh_lower')
    def manual_thresh_update(self):
        """ Apply manually edited threshold bounds to the scatter renderer. """
        self.thresh = [self.thresh_lower, self.thresh_upper]
        scatter_renderer = self._scatter_plot.plots['scatter_plot'][0]
        scatter_renderer.color_data.metadata['selections'] = self.thresh
        scatter_renderer.color_data.metadata_changed = {'selections': self.thresh}
        self.plot.request_redraw()

    @on_trait_change('peaks, _colorbar_selection:selection, selected_index')
    def calc_numpeaks(self):
        """ Count peaks whose height (column 2) lies inside the threshold,
        over all images (numpeaks_total) and for the active one
        (numpeaks_img).  Without a selection the interval (-1, 1) is used. """
        try:
            thresh = self._colorbar_selection.selection
            self.thresh = thresh
        except Exception:
            thresh = []
        if thresh is None or thresh == [] or thresh == ():
            thresh = (-1, 1)

        def count_inside(image_peaks):
            # masked_inside masks the values inside the interval, so the
            # number of True mask entries is the number of accepted peaks.
            return np.sum(np.ma.masked_inside(image_peaks[:, 2],
                                              thresh[0], thresh[1]).mask)

        self.numpeaks_total = int(np.sum([count_inside(image_peaks)
                                          for image_peaks in self.peaks.values()]))
        try:
            self.numpeaks_img = int(count_inside(self.peaks[self.get_active_name()]))
        except Exception:
            self.numpeaks_img = 0

    @on_trait_change('peaks, selected_index')
    def update_scatter_plot(self):
        """ Overlay the active image's peaks (if any) on the main plot. """
        active_name = self.get_active_name()
        if active_name in self.peaks:
            image_peaks = self.peaks[active_name]
            self.plotdata.set_data("index", image_peaks[:, 0])
            self.plotdata.set_data("value", image_peaks[:, 1])
            self.plotdata.set_data("color", image_peaks[:, 2])
            self.plot = self.get_scatter_overlay_plot(array_plot_data=self.plotdata,
                                                      tools=['zoom', 'pan', 'colorbar'])
            scatter_renderer = self._scatter_plot.plots['scatter_plot'][0]
            scatter_renderer.color_data.metadata['selections'] = self.thresh
            scatter_renderer.color_data.metadata_changed = {'selections': self.thresh}
        else:
            if 'index' in self.plotdata.arrays:
                self.plotdata.del_data('index')
                # 'value' is always set together with 'index' (see above).
                self.plotdata.del_data('value')
            if 'color' in self.plotdata.arrays:
                self.plotdata.del_data('color')
            self.plot = self.get_scatter_overlay_plot(array_plot_data=self.plotdata)

    def locate_peaks(self):
        """ Cross-correlate the template with every image and store the
        detected peaks per image name in ``self.peaks``.

        Each image is selected in turn via set_active_index, so the last
        image is left selected afterwards.
        """
        peaks = {}
        for idx in xrange(self.numfiles):
            self.set_active_index(idx)
            CC = cv_funcs.xcorr(self.template_data.get_data("imagedata"),
                                self.get_active_image())
            # The correlation is shifted to start at 0 and scaled by 255
            # before peak finding (presumably the range two_dim_findpeaks
            # expects -- confirm); peak heights are mapped back below.
            pks = pc.two_dim_findpeaks((CC - CC.min()) * 255, medfilt_radius=None,
                                       alpha=1, coords_list=[])
            pks = pc.flatten_peak_list(pks)
            pks[:, 2] = pks[:, 2] / 255 + CC.min()
            peaks[self.get_active_name()] = pks
        self.peaks = peaks

    def mask_peaks(self, image_id):
        """ Return the peaks of *image_id* as a masked array whose height
        column is masked outside the current threshold. """
        mpeaks = np.ma.asarray(self.peaks[image_id])
        mpeaks[:, 2] = np.ma.masked_outside(mpeaks[:, 2], self.thresh[0], self.thresh[1])
        return mpeaks

    def crop_cells(self):
        """ Crop a template-sized patch at every in-threshold peak.

        When ``/cell_description`` already has rows, previous results are
        discarded: that table is recreated empty, ``/cell_peaks`` is removed
        if present, and every node under ``/cells`` is deleted.  Then, via
        ``self.parent.add_cell_data``, the method stores the template, one
        (n_cells, size, size) stack per image with at least one accepted
        peak, and the mean patch over all images ("average", only when at
        least one cell was cropped).  One table row is appended per patch.
        Finally the action is logged and the session ended.
        """
        rows = self.chest.root.cell_description.nrows
        if rows > 0:
            # remove the table
            self.chest.removeNode('/cell_description')
            try:
                # remove the table of peak characteristics - they are not valid.
                self.chest.removeNode('/cell_peaks')
            except Exception:
                # best effort: the table may not exist
                pass
            # recreate it
            self.chest.createTable('/', 'cell_description', CellsTable)
            # remove all existing entries in the data group
            for node in self.chest.listNodes('/cells'):
                self.chest.removeNode('/cells/' + node.name)
        # store the template
        template_data = self.template_data.get_data('imagedata')
        self.parent.add_cell_data(template_data, name="template")
        # TODO: set attribute that tells where the template came from
        row = self.chest.root.cell_description.row
        files = []
        tmp_sz = self.template_size
        # Per-image stacks, kept so the average covers every cropped cell
        # (it used to be computed from the last image's stack only).
        cell_stacks = []
        for idx in xrange(self.numfiles):
            self.set_active_index(idx)
            active_image = self.get_active_image()
            active_name = self.get_active_name()
            # filter the peaks that are outside the selected threshold
            peaks = np.ma.compress_rows(self.mask_peaks(active_name))
            files.append(active_name)
            data = np.zeros((peaks.shape[0], tmp_sz, tmp_sz),
                            dtype=active_image.dtype)
            if data.shape[0] == 0:
                continue
            for i in xrange(peaks.shape[0]):
                # store the peak in the table
                # NOTE(review): file_idx receives the cell's index within its
                # image, and column 1 is stored as x although the crop below
                # uses it as the row (first-axis) offset -- confirm naming.
                row['file_idx'] = i
                row['input_data'] = self.data_path
                row['filename'] = active_name
                row['x_coordinate'] = peaks[i, 1]
                row['y_coordinate'] = peaks[i, 0]
                row.append()
                # crop the cells from the given locations; peak coordinates
                # may be floats, so convert before using them as slice bounds
                r0 = int(peaks[i, 1])
                c0 = int(peaks[i, 0])
                data[i, :, :] = active_image[r0:r0 + tmp_sz, c0:c0 + tmp_sz]
            self.chest.root.cell_description.flush()
            # insert the data (one 3d array per file)
            self.parent.add_cell_data(data, name=active_name)
            self.chest.flush()
            cell_stacks.append(data)

        if cell_stacks:
            # These attributes are identical for every image, so set them once.
            self.chest.setNodeAttr('/cell_description', 'threshold',
                                   (self.thresh_lower, self.thresh_upper))
            self.chest.setNodeAttr('/cell_description', 'template_position',
                                   (self.template_left, self.template_top))
            self.chest.setNodeAttr('/cell_description', 'template_filename',
                                   self.template_filename)
            self.chest.setNodeAttr('/cell_description', 'template_size',
                                   (self.template_size))
            self.chest.root.cell_description.flush()
            self.chest.flush()
            all_cells = np.concatenate(cell_stacks, axis=0)
            average_data = np.average(all_cells, axis=0).squeeze()
            self.parent.add_cell_data(average_data, name="average")
        self.parent.update_cell_data()
        self.log_action(action="crop cells", files=files, thresh=self.thresh,
                        template_position=(self.template_left, self.template_top),
                        template_size=self.template_size,
                        template_filename=self.template_filename)
        Application.instance().end_session(self._session_id)
示例#11
0
class Plot1D(DataView):

#------------------------------------------------------------------------
    # Data-related traits
    #------------------------------------------------------------------------

    # The PlotData instance that drives this plot.
    data = Instance(AbstractPlotData)

    # Mapping of data names from self.data to their respective datasources
    # (a cache filled by _get_or_create_datasource).
    datasources = Dict(Str, Instance(AbstractDataSource))

    #------------------------------------------------------------------------
    # General plotting traits
    #------------------------------------------------------------------------

    # Mapping of plot names to *lists* of plot renderers.
    plots = Dict(Str, List)

    # The default index to use when adding new subplots.
    default_index = Instance(AbstractDataSource)

    # List of colors to cycle through when auto-coloring (color="auto" in
    # plot()) is requested.
    # NOTE(review): an earlier comment called this ordering red-green
    # colour-blind friendly, but the list contains both "red" and "green";
    # "white" first presumably contrasts with the default black background
    # set in __init__ -- confirm the intended ordering.
    auto_colors = List(["white", "red" , "blue","green", "lightblue",
                        "pink", "silver"])

    # Index into auto_colors of the most recently used colour; -1 means none
    # yet, so the first auto colour is auto_colors[0].
    _auto_color_idx = Int(-1)
    # NOTE(review): the two cursors below are not used by the methods
    # visible in this chunk.
    _auto_edge_color_idx = Int(-1)
    _auto_face_color_idx = Int(-1)

    # Mapping of renderer type string to renderer class.
    # This can be overridden to customize what renderer type the Plot
    # will instantiate for its various plotting methods; plot() only
    # accepts the "line" and "scatter" keys.
    renderer_map = Dict(dict(line = LinePlot,
                             scatter = ScatterPlot))

    #------------------------------------------------------------------------
    # Annotations and decorations
    #------------------------------------------------------------------------
    # The legend on the plot.
    legend = Instance(Legend)

    # Convenience attribute for legend.align; can be "ur", "ul", "ll", "lr".
    legend_alignment = Property

    def __init__(self, data=None, grid_color='yellow',  **kwtraits):
        """ Create a 1-D plot (black background unless ``bgcolor`` is given).

        Parameters
        ----------
        data : AbstractPlotData, ndarray, tuple or list, optional
            Source of the plotted arrays.  An AbstractPlotData is used
            directly; an ndarray/tuple/list is wrapped in an ArrayPlotData.
        grid_color : color, optional
            Line colour of both the x and y grids (default 'yellow').
        **kwtraits
            Trait values forwarded to DataView.  An ``origin`` keyword is
            consumed and stored as ``default_origin``.

        Raises
        ------
        ValueError
            If *data* is of any other type.
        """
        if 'origin' in kwtraits:
            self.default_origin = kwtraits.pop('origin')
        kwtraits.setdefault('bgcolor', 'black')
        super(Plot1D, self).__init__(**kwtraits)

        self.x_grid.line_color = grid_color
        self.y_grid.line_color = grid_color
        # Presumably (left, right, top, bottom) per Enable's Component.padding
        # -- confirm.
        self.padding = (65, 10, 10, 50)

        if data is not None:
            if isinstance(data, AbstractPlotData):
                self.data = data
            elif type(data) in (ndarray, tuple, list):
                self.data = ArrayPlotData(data)
            else:
                # Call-style raise works on both Python 2 and 3; the message
                # previously read "...for dataof type" (missing space).
                raise ValueError("Don't know how to create PlotData for data "
                                 "of type " + str(type(data)))

        if not self.legend:
            self.legend = Legend(visible=False, align="ur", error_icon="blank",
                                 padding=10, component=self)

        # ensure that we only get displayed once by new_window()
        self._plot_ui_info = None

    def plot(self, data, type="line", name=None, index_scale="linear",
             value_scale="linear", origin=None, **styles):
        """ Adds a new sub-plot using the given data and plot style.

        Parameters
        ==========
        data : string, tuple(string), list(string)
            Names of arrays in ``self.data``.  The number of items
            determines how they are interpreted:

            one item: (line/scatter)
                The data is treated as the value and self.default_index is
                used as the index.  If **default_index** does not exist, one is
                created from arange(len(*data*))
            two or more items: (line/scatter)
                Interpreted as (index, value1, value2, ...).  One renderer of
                the requested type is created per value, all sharing the
                index.

        type : string
            "line" or "scatter"; the renderer class is looked up in
            ``renderer_map``.
        name : string
            The name of the plot.  If None, then a default one is created
            (usually "plotNNN").
        index_scale : string
            The type of scale to use for the index axis. If not "linear", then
            a log scale is used.
        value_scale : string
            The type of scale to use for the value axis. If not "linear", then
            a log scale is used.
        origin : string
            Which corner the origin of this plot should occupy:
                "bottom left", "top left", "bottom right", "top right"
            Defaults to ``self.default_origin``.
        styles : series of keyword arguments
            attributes and values that apply to the created renderers,
            e.g. 'line_color' or 'line_width'.  ``color="auto"`` assigns
            each series the next colour from ``auto_colors``.

        Examples
        ========
        ::

            plot("my_data", type="line", name="myplot", color="lightblue")

            plot(("x-data", "y-data"), type="scatter")

            plot(("x", "y1", "y2", "y3"))

        Returns
        =======
        [renderers] -> list of renderers created in response to this call to
        plot(), or None when *data* is empty.

        Raises
        ======
        ValueError if *type* is not "line" or "scatter".
        """

        if len(data) == 0:
            return

        if isinstance(data, basestring):
            data = (data,)

        self.index_scale = index_scale
        self.value_scale = value_scale

        # ``type`` shadows the builtin but is part of the public signature.
        plot_type = type
        if name is None:
            name = self._make_new_plot_name()
        if origin is None:
            origin = self.default_origin
        if plot_type not in ("line", "scatter"):
            raise ValueError("Unknown plot type: " + plot_type)
        cls = self.renderer_map[plot_type]

        if len(data) == 1:
            if self.default_index is None:
                # Create the default index 0..N-1 from the length of the only
                # value series.  The value itself keeps the default "none"
                # sort order: an "ascending" ArrayDataSource reports its
                # bounds from its first/last element, which is wrong for
                # arbitrary y data.
                value = self._get_or_create_datasource(data[0])
                self.default_index = ArrayDataSource(arange(len(value.get_data())),
                                                     sort_order="ascending")
            index = self.default_index
        else:
            index = self._get_or_create_datasource(data[0])
            if self.default_index is None:
                self.default_index = index
            data = data[1:]
        # Register the index on every call, not only when it is created: a
        # reused default index may have been detached from the range when
        # the plots using it were deleted.  (The multi-series path already
        # re-adds a shared index on every call.)
        self.index_range.add(index)

        def make_mapper(scale, data_range, template):
            # "linear" -> LinearMapper; any other scale name -> LogMapper.
            mapper_cls = LinearMapper if scale == "linear" else LogMapper
            return mapper_cls(range=data_range, stretch_data=template.stretch_data)

        new_plots = []
        for value_name in data:
            value = self._get_or_create_datasource(value_name)
            self.value_range.add(value)

            # Copy per series: writing the chosen colour back into ``styles``
            # made every later series in this call reuse the first colour.
            series_styles = dict(styles)
            if series_styles.get("color") == "auto":
                self._auto_color_idx = \
                    (self._auto_color_idx + 1) % len(self.auto_colors)
                series_styles["color"] = self.auto_colors[self._auto_color_idx]

            plot = cls(index=index,
                       value=value,
                       index_mapper=make_mapper(self.index_scale, self.index_range,
                                                self.index_mapper),
                       value_mapper=make_mapper(self.value_scale, self.value_range,
                                                self.value_mapper),
                       orientation=self.orientation,
                       origin=origin,
                       **series_styles)
            self.add(plot)
            new_plots.append(plot)

        self.plots[name] = new_plots
        return self.plots[name]

    def delplot(self, *names):
        """ Removes the named sub-plots.

        The renderers stored under each name are removed from this container.
        Datasources that no remaining renderer uses are then detached from
        the index/value (or 2-D) ranges and dropped from ``self.datasources``.
        If the default index is among them it is reset to None, so the next
        single-series plot() builds a fresh one.  The auto-colour cursor is
        stepped back once per deleted name.

        Raises KeyError (from dict.pop) if a name is not a current plot.
        """

        # Remove all the renderers from us (container) and collect the
        # datasources that we might have to remove from the ranges.
        removed_groups = [self.plots.pop(name) for name in names]
        deleted_sources = set()
        for renderer in itertools.chain.from_iterable(removed_groups):
            self.remove(renderer)
            deleted_sources.add(renderer.index)
            deleted_sources.add(renderer.value)

        # Go back in the auto-coloring index, once per deleted name.
        # NOTE(review): this happens whether or not the plot was auto-coloured.
        # The guard avoids a ZeroDivisionError when auto_colors is empty.
        if names and self.auto_colors:
            self._auto_color_idx = \
                (self._auto_color_idx - len(names)) % len(self.auto_colors)

        # Cull the candidate list of sources to remove by checking the other plots
        sources_in_use = set()
        for p in itertools.chain.from_iterable(self.plots.values()):
            sources_in_use.add(p.index)
            sources_in_use.add(p.value)

        unused_sources = deleted_sources - sources_in_use - set([None])

        # Remove the unused sources from all ranges
        for source in unused_sources:
            if source.index_dimension == "scalar":
                # Try both index and value range, it doesn't hurt
                self.index_range.remove(source)
                self.value_range.remove(source)
            elif source.index_dimension == "image":
                self.range2d.remove(source)
            else:
                warnings.warn("Couldn't remove datasource from datarange.")

        # datasources is keyed by *data* name (see _get_or_create_datasource),
        # not by plot name, so find the stale entries by value.  Looking them
        # up by plot name raised KeyError for default "plotN" names.
        stale_keys = [key for key, source in self.datasources.items()
                      if source in unused_sources]
        for key in stale_keys:
            del self.datasources[key]

        # An orphaned default index is no longer registered with index_range.
        if self.default_index in unused_sources:
            self.default_index = None

    #------------------------------------------------------------------------
    # Private methods
    #------------------------------------------------------------------------



    def _make_new_plot_name(self):
        """ Returns a string that is not already used as a plot title.

        Tries "plot<N>" for N = len(self.plots), len(self.plots) + 1, ...
        and returns the first candidate not present in ``self.plots``.
        """
        for number in itertools.count(len(self.plots)):
            candidate = "plot%d" % number
            if candidate not in self.plots:
                return candidate

    def _get_or_create_datasource(self, name, sort_order = 'none'):
        """ Returns the data source associated with the given name, or creates
        it if it doesn't exist.

        Only 1-D data is supported here: lists/tuples are converted to an
        array and wrapped in an ArrayDataSource with *sort_order*, and an
        AbstractDataSource is used as-is.  New sources are cached in
        ``self.datasources`` under *name*.

        Raises ValueError for arrays that are not 1-D and for any other type.
        """
        if name in self.datasources:
            return self.datasources[name]

        raw = self.data.get_data(name)
        if type(raw) in (list, tuple):
            raw = array(raw)

        if isinstance(raw, ndarray):
            if len(raw.shape) != 1:
                raise ValueError("Unhandled array shape in creating new plot: " \
                                 + str(raw.shape))
            source = ArrayDataSource(raw, sort_order=sort_order)
        elif isinstance(raw, AbstractDataSource):
            source = raw
        else:
            raise ValueError("Couldn't create datasource for data of type " + \
                             str(type(raw)))

        self.datasources[name] = source
        return self.datasources[name]

    #------------------------------------------------------------------------
    # Event handlers
    #------------------------------------------------------------------------

    def _data_changed(self, old, new):
        # Trait handler: move the "data_changed" subscription from the
        # previous plot-data object to the new one.  dispatch='ui' asks
        # Traits to deliver the notifications on the UI thread.
        listener = self._data_update_handler
        if old:
            old.on_trait_change(listener, "data_changed",
                                remove=True, dispatch='ui')
        if new:
            new.on_trait_change(listener, "data_changed", dispatch='ui')

    def _data_update_handler(self, name, event):
        """ Push changed arrays from ``self.data`` into their datasources.

        Parameters
        ----------
        name : str
            Name of the trait that fired (unused).
        event : dict
            Per the comments in AbstractPlotData, a dict with (some of) the
            keys "added", "removed" and "changed".  Only "changed" is acted
            on.

        For each changed name that has a cached datasource, any range
        containing that source with exactly one bound set to 'auto' gets its
        other bound set to 'auto' as well, then the new data is set on the
        source.
        """
        def unpin_half_auto(data_range):
            # If exactly one bound is 'auto', make the other one 'auto' too.
            if (data_range.low_setting == 'auto' and
                    data_range.high_setting != 'auto'):
                data_range.set_high('auto')
            elif (data_range.low_setting != 'auto' and
                    data_range.high_setting == 'auto'):
                data_range.set_low('auto')

        # ``in`` / ``.get`` instead of dict.has_key, which is deprecated in
        # Python 2 and removed in Python 3.
        for data_name in event.get("changed", ()):
            if data_name not in self.datasources:
                continue
            source = self.datasources[data_name]
            if source in self.index_range.sources:
                unpin_half_auto(self.index_range)
            if source in self.value_range.sources:
                unpin_half_auto(self.value_range)
            source.set_data(self.data.get_data(data_name))

    def _plots_items_changed(self, event):
        """Re-point the legend (when present) at ``self.plots`` after it changes."""
        legend = self.legend
        if not legend:
            return
        legend.plots = self.plots

    def _index_scale_changed(self, old, new):
        """Rebuild the index mappers when ``index_scale`` changes.

        The plot-level ``index_mapper`` and the ``index_mapper`` of every
        renderer in ``self.plots`` are replaced. The replacement is a
        LinearMapper when ``index_scale == "linear"`` and a LogMapper
        otherwise. Existing screen bounds and the ``stretch_data`` setting
        are kept. Nothing happens on the initial assignment (``old is None``),
        when the value did not change, or when there is no ``range2d`` yet.

        Raises:
            ValueError: if any renderer is not a BaseXYPlot. All renderers
                are checked before any mapper is replaced, so a failure no
                longer leaves the mappers half-converted.
        """
        if old is None or new == old or not self.range2d:
            return
        renderers = [plot for plots in self.plots.values() for plot in plots]
        for plot in renderers:
            if not isinstance(plot, BaseXYPlot):
                raise ValueError("log scale only supported on XY plots")
        if self.index_scale == "linear":
            mapper_cls = LinearMapper
        else:
            mapper_cls = LogMapper
        self.index_mapper = mapper_cls(
            range=self.index_range,
            screen_bounds=self.index_mapper.screen_bounds,
            stretch_data=self.index_mapper.stretch_data)
        # Renderer mappers take stretch_data from the plot-level mapper.
        stretch = self.index_mapper.stretch_data
        for plot in renderers:
            plot.index_mapper = mapper_cls(
                range=plot.index_range,
                screen_bounds=plot.index_mapper.screen_bounds,
                stretch_data=stretch)

    def _value_scale_changed(self, old, new):
        """Rebuild the value mappers when ``value_scale`` changes.

        The plot-level ``value_mapper`` and the ``value_mapper`` of every
        renderer in ``self.plots`` are replaced. The replacement is a
        LinearMapper when ``value_scale == "linear"`` and a LogMapper
        otherwise. Existing screen bounds and the ``stretch_data`` setting
        are kept. Nothing happens on the initial assignment (``old is None``),
        when the value did not change, or when there is no ``range2d`` yet.

        Raises:
            ValueError: if any renderer is not a BaseXYPlot. All renderers
                are checked before any mapper is replaced, so a failure no
                longer leaves the mappers half-converted.
        """
        if old is None or new == old or not self.range2d:
            return
        renderers = [plot for plots in self.plots.values() for plot in plots]
        for plot in renderers:
            if not isinstance(plot, BaseXYPlot):
                raise ValueError("log scale only supported on XY plots")
        if self.value_scale == "linear":
            mapper_cls = LinearMapper
        else:
            mapper_cls = LogMapper
        self.value_mapper = mapper_cls(
            range=self.value_range,
            screen_bounds=self.value_mapper.screen_bounds,
            stretch_data=self.value_mapper.stretch_data)
        # Renderer mappers take stretch_data from the plot-level mapper.
        stretch = self.value_mapper.stretch_data
        for plot in renderers:
            plot.value_mapper = mapper_cls(
                range=plot.value_range,
                screen_bounds=plot.value_mapper.screen_bounds,
                stretch_data=stretch)


    def _legend_changed(self, old, new):
        """Swap the legend overlay, then hand the new legend our plot mapping."""
        self._overlay_change_helper(old, new)
        if not new:
            return
        new.plots = self.plots

    def _handle_range_changed(self, name, old, new):
        """Move datasources and renderers from the *old* range to the *new* one.

        Overrides the DataView default behavior; the difference is how the
        renderers are found (they are taken from ``self.plots``). *name* is
        the prefix of the ``<name>_mapper`` / ``<name>_range`` attributes.
        """
        mapper = getattr(self, name + "_mapper")
        if mapper.range == old:
            mapper.range = new

        if old is not None:
            # Take a snapshot first: old.remove() mutates old.sources.
            for source in list(old.sources):
                old.remove(source)
                if new is not None:
                    new.add(source)

        attr_name = name + "_range"
        for plots in list(self.plots.values()):
            for renderer in plots:
                if hasattr(renderer, attr_name):
                    setattr(renderer, attr_name, new)

    #------------------------------------------------------------------------
    # Property getters and setters
    #------------------------------------------------------------------------

    def _set_legend_alignment(self, align):
        """Forward *align* to the legend; without a legend it is ignored."""
        legend = self.legend
        if not legend:
            return
        legend.align = align

    def _get_legend_alignment(self):
        """Return the legend's ``align`` value, or None when there is no legend."""
        legend = self.legend
        return legend.align if legend else None
示例#12
0
class CellCropController(BaseImageController):
    """Pick a square template on an image, find its matches and crop them.

    Workflow:

    - The template is the ``template_size`` x ``template_size`` window whose
      top-left corner is (``template_top``, ``template_left``) on the active
      image.
    - ``locate_peaks`` cross-correlates it (``cv_funcs.xcorr``) with every
      image.
    - ``crop_cells`` crops the regions at the peaks whose correlation value
      lies inside the selected threshold. It stores them through the parent
      and the ``/cell_description`` table.
    """
    # lower bound used by the template_top/template_left ranges
    zero = Int(0)
    template_plot = Instance(BasePlotContainer)
    template_data = Instance(ArrayPlotData)
    template_size = Range(low=2, high=512, value=64, cols=4)
    template_top = Range(low='zero', high='max_pos_y', value=20, cols=4)
    template_left = Range(low='zero', high='max_pos_x', value=20, cols=4)
    # image name -> peak array; column 0 is used as the row offset, column 1
    # as the column offset and column 2 holds the cross-correlation value
    peaks = Dict({})
    # show the cross-correlation map instead of the raw image
    ShowCC = Bool(False)
    # upper bounds for template_left/template_top (see _get_max_positions)
    max_pos_x = Int(256)
    max_pos_y = Int(256)
    is_square = Bool(True)
    peak_width = Range(low=2, high=200, value=10)
    # peaks inside the threshold: over all images / in the active image
    numpeaks_total = Int(0, cols=5)
    numpeaks_img = Int(0, cols=5)
    _session_id = String('')

    def __init__(self,
                 parent,
                 treasure_chest=None,
                 data_path='/rawdata',
                 *args,
                 **kw):
        super(CellCropController, self).__init__(parent, treasure_chest,
                                                 data_path, *args, **kw)

        if self.chest is not None:
            self.numfiles = len(self.nodes)
            if self.numfiles > 0:
                self.init_plot()

    def data_updated(self):
        """Rebuild the controller after the underlying data changed."""
        # reinitialize data
        self.__init__(parent=self.parent,
                      treasure_chest=self.chest,
                      data_path=self.data_path)

    def _crop_template(self, image):
        """Return the template window of *image*.

        The window is ``template_size`` square with its top-left corner at
        (``template_top``, ``template_left``).
        """
        top = self.template_top
        left = self.template_left
        size = self.template_size
        return image[top:top + size, left:left + size]

    def init_plot(self):
        """Build the main image plot and the initial template plot."""
        # read the image once; get_active_image() reads it from the chest
        active_image = self.get_active_image()
        self.plotdata.set_data('imagedata', active_image)
        self.plot = self.get_scatter_overlay_plot(
            array_plot_data=self.plotdata,
            title=self.get_active_name(),
            tools=['csr', 'colorbar', 'zoom', 'pan'])
        # pick an initial template with default parameters
        self.template_data = ArrayPlotData()
        self.template_plot = Plot(self.template_data,
                                  default_origin="top left")
        self.template_data.set_data('imagedata',
                                    self._crop_template(active_image))
        self.template_plot.img_plot('imagedata', title="Template")
        self.template_plot.aspect_ratio = 1  # square templates
        self.template_filename = self.get_active_name()
        self._get_max_positions()

    @on_trait_change("selected_index, ShowCC")
    def update_image(self):
        """Redraw the main plot with the active image or its cross-correlation."""
        active_image = self.get_active_image()
        if self.ShowCC:
            shown = cv_funcs.xcorr(self.template_data.get_data('imagedata'),
                                   active_image)
            title = "Cross correlation of " + self.get_active_name()
        else:
            shown = active_image
            title = self.get_active_name()
        self.plotdata.set_data("imagedata", shown)
        self.plot = self.get_scatter_overlay_plot(
            array_plot_data=self.plotdata,
            title=self.get_active_name(),
            tools=['csr', 'zoom', 'pan', 'colorbar'],
        )
        self.plot.aspect_ratio = (float(shown.shape[1]) / shown.shape[0])
        self.set_plot_title(title)
        # resize the image grid to the displayed array
        grid_data_source = self._base_plot.range2d.sources[0]
        grid_data_source.set_data(np.arange(shown.shape[1]),
                                  np.arange(shown.shape[0]))

    def update_CC(self):
        """Recompute the displayed cross-correlation when ShowCC is on."""
        if self.ShowCC:
            CC = cv_funcs.xcorr(self.template_data.get_data('imagedata'),
                                self.get_active_image())
            self.plotdata.set_data("imagedata", CC)

    @on_trait_change('template_left, template_top, template_size')
    def update_template_data(self):
        """Re-crop the template after its position or size changed."""
        self.template_data.set_data('imagedata',
                                    self._crop_template(self.get_active_image()))
        self.template_filename = self.get_active_name()
        if self.numpeaks_total > 0:
            # peaks were located with the previous template
            print("clearing peaks")
            self.peaks = {}
        # when template data changes, we should check whether to update the
        #    cross correlation plot, which depends on the template
        self.update_CC()

    @on_trait_change('selected_index, template_size')
    def _get_max_positions(self):
        """Update max_pos_x/max_pos_y from the active image and template size.

        A non-positive limit (the template does not fit) is ignored and the
        previous limit is kept.
        """
        shape = self.get_active_image().shape
        max_pos_x = shape[-1] - self.template_size - 1
        if max_pos_x > 0:
            self.max_pos_x = int(max_pos_x)
        max_pos_y = shape[-2] - self.template_size - 1
        if max_pos_y > 0:
            self.max_pos_y = int(max_pos_y)

    @on_trait_change('template_left, template_top')
    def update_csr_position(self):
        """Move the cursor tool to the template's top-left corner."""
        # NOTE(review): the cursor is not moved while template_left is 0
        # (even if template_top changes) - confirm this is intended.
        if self.template_left > 0:
            self._csr.current_position = self.template_left, self.template_top

    @on_trait_change('_csr:current_position')
    def update_top_left(self):
        """Move the template to the cursor, clamped to max_pos_x/max_pos_y.

        A cursor at (<=0, <=0) is ignored. When both coordinates are out of
        range the cursor itself is snapped back, which re-enters this
        handler with an in-range position.
        """
        x, y = self._csr.current_position
        if not (x > 0 or y > 0):
            return
        if x > self.max_pos_x:
            if y < self.max_pos_y:
                # clamp the column (previously only the row was updated,
                # leaving template_left at a stale value)
                self.template_left, self.template_top = self.max_pos_x, y
            else:
                self._csr.current_position = self.max_pos_x, self.max_pos_y
        elif y > self.max_pos_y:
            self.template_left, self.template_top = x, self.max_pos_y
        else:
            self.template_left, self.template_top = x, y

    @on_trait_change('_colorbar_selection:selection')
    def update_thresh(self):
        """Apply the colorbar range selection as the peak threshold."""
        try:
            thresh = self._colorbar_selection.selection
            self.thresh = thresh
            scatter_renderer = self._scatter_plot.plots['scatter_plot'][0]
            scatter_renderer.color_data.metadata['selections'] = thresh
            self.thresh_lower = thresh[0]
            self.thresh_upper = thresh[1]
            scatter_renderer.color_data.metadata_changed = {
                'selections': thresh
            }
            self.plot.request_redraw()
        except Exception:
            # Best effort: the selection may be cleared (None) or the scatter
            # overlay may not exist yet. ``except Exception`` (not a bare
            # except) lets KeyboardInterrupt/SystemExit through.
            pass

    @on_trait_change('thresh_upper,thresh_lower')
    def manual_thresh_update(self):
        """Apply thresh_lower/thresh_upper to the scatter overlay."""
        self.thresh = [self.thresh_lower, self.thresh_upper]
        scatter_renderer = self._scatter_plot.plots['scatter_plot'][0]
        scatter_renderer.color_data.metadata['selections'] = self.thresh
        scatter_renderer.color_data.metadata_changed = {
            'selections': self.thresh
        }
        self.plot.request_redraw()

    @staticmethod
    def _count_inside(values, thresh):
        """Return how many entries of *values* lie in [thresh[0], thresh[1]]."""
        return int(
            np.sum(np.ma.masked_inside(values, thresh[0], thresh[1]).mask))

    @on_trait_change('peaks, _colorbar_selection:selection, selected_index')
    def calc_numpeaks(self):
        """Recount peaks inside the threshold, overall and in the active image.

        Falls back to the threshold (-1, 1) when there is no selection.
        """
        try:
            thresh = self._colorbar_selection.selection
            self.thresh = thresh
        except Exception:
            # no selection tool yet, or the selection was rejected
            thresh = []
        if thresh is None or len(thresh) == 0:
            thresh = (-1, 1)
        self.numpeaks_total = sum(
            self._count_inside(pks[:, 2], thresh)
            for pks in self.peaks.values())
        try:
            self.numpeaks_img = self._count_inside(
                self.peaks[self.get_active_name()][:, 2], thresh)
        except Exception:
            # no peaks recorded for the active image (or no active image)
            self.numpeaks_img = 0

    @on_trait_change('peaks, selected_index')
    def update_scatter_plot(self):
        """Overlay the active image's peaks (if any) as a colour-coded scatter."""
        name = self.get_active_name()
        if name in self.peaks:
            pks = self.peaks[name]
            self.plotdata.set_data("index", pks[:, 1])
            self.plotdata.set_data("value", pks[:, 0])
            self.plotdata.set_data("color", pks[:, 2])
            self.plot = self.get_scatter_overlay_plot(
                array_plot_data=self.plotdata,
                tools=['zoom', 'pan', 'colorbar'])
            scatter_renderer = self._scatter_plot.plots['scatter_plot'][0]
            scatter_renderer.color_data.metadata['selections'] = self.thresh
            scatter_renderer.color_data.metadata_changed = {
                'selections': self.thresh
            }
        else:
            if 'index' in self.plotdata.arrays:
                self.plotdata.del_data('index')
                # 'value' is always set together with 'index', so it exists too
                self.plotdata.del_data('value')
            if 'color' in self.plotdata.arrays:
                self.plotdata.del_data('color')
            self.plot = self.get_scatter_overlay_plot(
                array_plot_data=self.plotdata, )

    def locate_peaks(self):
        """Cross-correlate the template with every image and store the peaks.

        Leaves the last image active. ``self.peaks`` maps each image name to
        the array from ``pc.two_dim_findpeaks``, with column 2 converted back
        to cross-correlation units.
        """
        peaks = {}
        for idx in xrange(self.numfiles):
            self.set_active_index(idx)
            CC = cv_funcs.xcorr(self.template_data.get_data("imagedata"),
                                self.get_active_image())
            # the peak finder gets a shifted, 255-scaled copy of CC; undo
            # that scaling on the peak heights
            pks = pc.two_dim_findpeaks((CC - CC.min()) * 255, xc_filter=False)
            pks[:, 2] = pks[:, 2] / 255 + CC.min()
            peaks[self.get_active_name()] = pks
        self.peaks = peaks

    def mask_peaks(self, image_id):
        """Return the peaks of *image_id* as a masked array.

        Rows whose value (column 2) lies outside ``self.thresh`` are masked.
        """
        mpeaks = np.ma.asarray(self.peaks[image_id])
        mpeaks[:, 2] = np.ma.masked_outside(mpeaks[:, 2], self.thresh[0],
                                            self.thresh[1])
        return mpeaks

    def _reset_cell_storage(self):
        """Discard previously cropped cells.

        Acts only when ``/cell_description`` already has rows. It then
        recreates the table empty, removes ``/cell_peaks`` if present and
        deletes every node under ``/cells``.
        """
        if self.chest.root.cell_description.nrows <= 0:
            return
        # remove the table
        self.chest.remove_node('/cell_description')
        try:
            # remove the table of peak characteristics - they are not valid.
            self.chest.remove_node('/cell_peaks')
        except Exception:
            # the node may not exist
            pass
        # recreate it
        self.chest.create_table('/', 'cell_description', CellsTable)
        # remove all existing entries in the data group
        for node in self.chest.list_nodes('/cells'):
            self.chest.remove_node('/cells/' + node.name)

    def _store_crop_attrs(self):
        """Record the crop parameters as attributes of /cell_description."""
        self.chest.set_node_attr('/cell_description', 'threshold',
                                 (self.thresh_lower, self.thresh_upper))
        self.chest.set_node_attr('/cell_description', 'template_position',
                                 (self.template_left, self.template_top))
        self.chest.set_node_attr('/cell_description', 'template_filename',
                                 self.template_filename)
        self.chest.set_node_attr('/cell_description', 'template_size',
                                 (self.template_size))

    def _store_average_cell(self, stacks):
        """Store the mean of all cell stacks as the "average" cell.

        Returns the mean image, or None (storing nothing) when no cell was
        cropped.
        """
        if not stacks:
            return None
        average_data = np.average(np.concatenate(stacks, axis=0),
                                  axis=0).squeeze()
        self.parent.add_cell_data(average_data, name="average")
        row = self.chest.root.cell_description.row
        row['file_idx'] = 0
        row['input_data'] = self.data_path
        row['filename'] = "average"
        row['x_coordinate'] = 0
        row['y_coordinate'] = 0
        row.append()
        self.chest.root.cell_description.flush()
        return average_data

    def crop_cells(self):
        """Crop the cell at every above-threshold peak of every image.

        Steps:

        - Reset the cell storage and store the template.
        - For each image with at least one peak inside ``self.thresh``,
          append one ``/cell_description`` row per cell and store one 3-D
          stack of cells via ``parent.add_cell_data``.
        - Store the mean over ALL cropped cells as "average". Previously
          only the last image's cells were averaged, and the method failed
          when there were no images.
        - Log the action and end the session.
        """
        self._reset_cell_storage()
        # store the template
        template_data = self.template_data.get_data('imagedata')
        self.parent.add_cell_data(template_data, name="template")
        # TODO: set attribute that tells where the template came from
        row = self.chest.root.cell_description.row
        files = []
        stacks = []
        for idx in xrange(self.numfiles):
            self.set_active_index(idx)
            active_image = self.get_active_image()
            name = self.get_active_name()
            # filter the peaks that are outside the selected threshold
            peaks = np.ma.compress_rows(self.mask_peaks(name))
            files.append(name)
            tmp_sz = self.template_size
            num_cells = peaks.shape[0]
            if num_cells == 0:
                continue
            data = np.zeros((num_cells, tmp_sz, tmp_sz),
                            dtype=active_image.dtype)
            for i in xrange(num_cells):
                top = int(peaks[i, 0])
                left = int(peaks[i, 1])
                # store the peak in the table
                row['file_idx'] = i
                row['input_data'] = self.data_path
                row['filename'] = name
                # NOTE(review): column 0 is used as the row offset when
                # cropping, yet it is stored as x_coordinate - confirm the
                # convention expected by readers of this table.
                row['x_coordinate'] = peaks[i, 0]
                row['y_coordinate'] = peaks[i, 1]
                row.append()
                # crop the cells from the given locations
                data[i, :, :] = active_image[top:top + tmp_sz,
                                             left:left + tmp_sz]
            self.chest.root.cell_description.flush()
            # insert the data (one 3d array per file)
            self.parent.add_cell_data(data, name=name)
            self._store_crop_attrs()
            self.chest.root.cell_description.flush()
            self.chest.flush()
            stacks.append(data)
        average_data = self._store_average_cell(stacks)
        self.parent.update_cell_data()
        if average_data is not None:
            self.parent.add_image_data(average_data, "average")
        self.log_action(action="crop cells",
                        files=files,
                        thresh=self.thresh,
                        template_position=(self.template_left,
                                           self.template_top),
                        template_size=self.template_size,
                        template_filename=self.template_filename)
        Application.instance().end_session(self._session_id)
示例#13
0
class Plotter2D(HasPreferenceTraits):
    """Colour-mapped 2-D image plot with a colorbar and zoom/pan/range bars.

    The image is the ``'c'`` array of ``data``, drawn over
    [x_min, x_max] x [y_min, y_max]. The handlers for the axis labels, axis
    formatters and ``update_index`` are bound with ``dispatch='ui'`` in
    ``__init__`` because those traits are likely to be updated from another
    thread.
    """

    plot = Instance(Plot2D)
    colorbar = Instance(ColorBar)
    container = Instance(HPlotContainer)

    zoom_bar_plot = Instance(ZoomBar)
    zoom_bar_colorbar = Instance(ZoomBar)
    pan_bar = Instance(PanBar)
    range_bar = Instance(RangeBar)

    data = Instance(ArrayPlotData,())
    x_min = Float(0.0)
    x_max = Float(1.0)
    y_min = Float(0.0)
    y_max = Float(1.0)

    add_contour = Bool(False)
    x_axis_label = Str
    y_axis_label = Str
    c_axis_label = Str

    x_axis_formatter = Instance(AxisFormatter)
    y_axis_formatter = Instance(AxisFormatter)
    c_axis_formatter = Instance(AxisFormatter)

    # list() keeps this a plain sequence on Python 3 (dict views) as well
    colormap = Enum(list(color_map_name_dict.keys()), preference = 'async')
    _cmap = Trait(Greys, Callable)

    # set by request_update_plots_index; handled on the UI thread
    update_index = Event

    traits_view = View(
                    Group(
                        Group(
                            UItem('container', editor=ComponentEditor()),
                            VGroup(
                                UItem('zoom_bar_colorbar',style = 'custom'),
                                ),
                            orientation = 'horizontal',
                            ),
                        Group(
                            Group(
                                UItem('zoom_bar_plot', style = 'custom'),
                                UItem('pan_bar', style = 'custom'),
                                UItem('range_bar', style = 'custom'),
                                Group(
                                    UItem('colormap'),
                                    label = 'Color map',
                                    ),
                                orientation = 'horizontal',
                                ),
                            orientation = 'vertical',
                            ),
                        orientation = 'vertical',
                        ),
                    resizable=True
                    )

    preference_view = View(
                        HGroup(
                            VGroup(
                                Item('x_axis_formatter', style = 'custom',
                                     editor = InstanceEditor(
                                                 view = 'preference_view'),
                                     label = 'X axis',
                                     ),
                                Item('y_axis_formatter', style = 'custom',
                                     editor = InstanceEditor(
                                                 view = 'preference_view'),
                                     label = 'Y axis',
                                     ),
                                ),
                            Item('c_axis_formatter', style = 'custom',
                                 editor = InstanceEditor(
                                             view = 'preference_view'),
                                 label = 'C axis',
                                 ),
                            show_border = True,
                            label = 'Axis format',
                            ),
                        )

    def __init__(self, **kwargs):

        super(Plotter2D, self).__init__(**kwargs)

        self.x_axis_formatter = AxisFormatter(pref_name = 'X axis format',
                                              pref_parent = self)
        self.y_axis_formatter = AxisFormatter(pref_name = 'Y axis format',
                                              pref_parent = self)
        self.c_axis_formatter = AxisFormatter(pref_name = 'C axis format',
                                              pref_parent = self)

        self.data = ArrayPlotData()
        self.plot = Plot2D(self.data)
        self.plot.padding = (80, 50, 10, 40)
        self.plot.x_axis.tick_label_formatter = \
            self.x_axis_formatter.float_format
        self.plot.y_axis.tick_label_formatter = \
            self.y_axis_formatter.float_format
        self.pan_bar = PanBar(self.plot)
        self.zoom_bar_plot = zoom_bar(self.plot, x = True,
                                      y = True, reset = True
                                      )

        # Dummy plot so that the color bar can be correctly initialized
        xs = linspace(-2, 2, 600)
        ys = linspace(-1.2, 1.2, 300)
        self.x_min = xs[0]
        self.x_max = xs[-1]
        self.y_min = ys[0]
        self.y_max = ys[-1]
        x, y = meshgrid(xs, ys)
        # simple saddle surface as placeholder data (a tanh/cosh surface was
        # computed here before but immediately overwritten by this one)
        z = x * y
        self.data.set_data('c', z)
        self.plot.img_plot('c',
                           name = 'c',
                           colormap = self._cmap,
                           xbounds = (self.x_min, self.x_max),
                           ybounds = (self.y_min, self.y_max),
                           )

        # Create the colorbar, the appropriate range and colormap are handled
        # at the plot creation

        self.colorbar = ColorBar(
                            index_mapper = LinearMapper(
                                range = self.plot.color_mapper.range),
                            color_mapper=self.plot.color_mapper,
                            plot = self.plot,
                            orientation='v',
                            resizable='v',
                            width=20,
                            padding=10)

        self.colorbar.padding_top = self.plot.padding_top
        self.colorbar.padding_bottom = self.plot.padding_bottom

        self.colorbar._axis.tick_label_formatter = \
            self.c_axis_formatter.float_format

        self.container = HPlotContainer(self.plot,
                                        self.colorbar,
                                        use_backbuffer=True,
                                        bgcolor="lightgray")

        # Add pan and zoom tools to the colorbar
        self.colorbar.tools.append(PanTool(self.colorbar,
                                           constrain_direction="y",
                                           constrain=True)
                                   )
        self.zoom_bar_colorbar = zoom_bar(self.colorbar,
                                          box = False,
                                          reset=True,
                                          orientation = 'vertical'
                                          )

        # Add the range bar now that we are sure that we have a color_mapper
        self.range_bar = RangeBar(self.plot)
        self.x_axis_label = 'X'
        self.y_axis_label = 'Y'
        self.c_axis_label = 'C'
        self.sync_trait('x_axis_label', self.range_bar, alias = 'x_name')
        self.sync_trait('y_axis_label', self.range_bar, alias = 'y_name')
        self.sync_trait('c_axis_label', self.range_bar, alias = 'c_name')

        # Dynamically bind the update methods for traits likely to be updated
        # from another thread, so the handlers run on the UI thread
        self.on_trait_change(self.new_x_label, 'x_axis_label',
                             dispatch = 'ui')
        self.on_trait_change(self.new_y_label, 'y_axis_label',
                             dispatch = 'ui')
        self.on_trait_change(self.new_c_label, 'c_axis_label',
                             dispatch = 'ui')
        self.on_trait_change(self.new_x_axis_format, 'x_axis_formatter.+',
                             dispatch = 'ui')
        self.on_trait_change(self.new_y_axis_format, 'y_axis_formatter.+',
                             dispatch = 'ui')
        self.on_trait_change(self.new_c_axis_format, 'c_axis_formatter.+',
                             dispatch = 'ui')
        self.on_trait_change(self._update_plots_index, 'update_index',
                             dispatch = 'ui')

        # set the default colormap in the editor
        self.colormap = 'Blues'

        self.preference_init()

    def new_x_label(self, new):
        """Set the x-axis title (bound to x_axis_label in __init__)."""
        self.plot.x_axis.title = new

    def new_y_label(self, new):
        """Set the y-axis title (bound to y_axis_label in __init__)."""
        self.plot.y_axis.title = new

    def new_c_label(self, new):
        """Set the colorbar axis title (bound to c_axis_label in __init__)."""
        self.colorbar._axis.title = new

    @on_trait_change('colormap')
    def new_colormap(self, new):
        """Apply the colour map named *new* to every colour-mapped renderer."""
        self._cmap = color_map_name_dict[new]
        for plots in self.plot.plots.values():
            for plot in plots:
                if isinstance(plot, (ImagePlot, CMapImagePlot,
                                     ContourPolyPlot)):
                    # keep the current data range, change only the colours
                    value_range = plot.color_mapper.range
                    plot.color_mapper = self._cmap(value_range)
                    self.plot.color_mapper = self._cmap(value_range)

        self.container.request_redraw()

    def new_x_axis_format(self):
        """Redraw after an x-axis formatter change (bound in __init__)."""
        self.plot.x_axis._invalidate()
        self.plot.invalidate_and_redraw()

    def new_y_axis_format(self):
        """Redraw after a y-axis formatter change (bound in __init__)."""
        self.plot.y_axis._invalidate()
        self.plot.invalidate_and_redraw()

    def new_c_axis_format(self):
        """Redraw after a colorbar-axis formatter change (bound in __init__)."""
        self.colorbar._axis._invalidate()
        self.plot.invalidate_and_redraw()

    def request_update_plots_index(self):
        """Ask for the plot index to be rebuilt (handled on the UI thread)."""
        self.update_index = True

    def _update_plots_index(self):
        """Rebuild the grid index from the 'c' array shape and the x/y extent.

        The grid has one more point than the array along each axis
        (presumably cell edges). Bound to ``update_index`` in __init__.
        """
        if 'c' not in self.data.list_data():
            return
        array = self.data.get_data('c')
        xs = linspace(self.x_min, self.x_max, array.shape[1] + 1)
        ys = linspace(self.y_min, self.y_max, array.shape[0] + 1)
        sort_order = ('ascending', 'ascending')
        self.plot.range2d.remove(self.plot.index)
        self.plot.index = GridDataSource(xs, ys, sort_order=sort_order)
        self.plot.range2d.add(self.plot.index)
        # NOTE(review): every renderer gets its own GridDataSource that is not
        # added to range2d - confirm the 2-D range does not need them.
        for plots in self.plot.plots.values():
            for plot in plots:
                plot.index = GridDataSource(xs, ys, sort_order=sort_order)
示例#14
0
class Demo(HasTraits):
    """Two side-by-side image panes with contour overlays.

    Both panes live in an HPlotContainer (``plot``). Each one shows the image
    held in ``pd['imagedata']`` with filled and line contour plots of a
    synthetic surface drawn over it. ``_save`` renders the container to
    ``_save_file``; ``_load`` reads ``_load_file`` into ``pd``.
    """
    pd = Instance(ArrayPlotData, ())
    plot = Instance(HPlotContainer)

    _load_file = File(
        find_resource('imageAlignment', '../images/GIRLS-IN-SPACE.jpg',
        '../images/GIRLS-IN-SPACE.jpg', return_path=True))
    _save_file = File

    load_file_view = View(
        Item('_load_file'),
        buttons=OKCancelButtons,
        kind='livemodal',
        width=400,
        resizable=True,
    )

    save_file_view = View(
        Item('_save_file'),
        buttons=OKCancelButtons,
        kind='livemodal',
        width=400,
        resizable=True,
    )

    def __init__(self, *args, **kwargs):
        super(Demo, self).__init__(*args, **kwargs)

        from imread import imread
        imarray = imread(find_resource('imageAlignment', '../images/GIRLS-IN-SPACE.jpg',
            '../images/GIRLS-IN-SPACE.jpg', return_path=True))

        self.pd = ArrayPlotData(imagedata=imarray)
        self.plot = HPlotContainer()

        # Try to replace the bundled image with ``_load_file``; on failure
        # _load() only prints the error, so the image read above is kept.
        self._load()

        # Synthetic surface for the contour overlays. It is identical for both
        # panes, so it is computed once rather than per pane.
        xs = linspace(0, 334*pi, 333)
        ys = linspace(0, 334*pi, 333)
        x, y = meshgrid(xs, ys)
        z = x*y

        titles = ["I KEEP DANCE", "ING ON MY OWN"]
        for title in titles:
            self._add_pane(title, z)

    def _add_pane(self, title, z):
        """Append one pane to ``self.plot``: the current image plus contours of *z*."""
        pane_data = ArrayPlotData()
        pane_data.set_data("drawdata", z)
        pane_data.set_data("imagedata", self.pd.get_data('imagedata'))

        pane = Plot(pane_data,
            title=title,
            padding=50, border_visible=True, overlay_border=True)

        self.plot.add(pane)

        pane.img_plot("imagedata",
            alpha=0.95)

        # Semi-transparent filled contour polygons of the surface ...
        pane.contour_plot("drawdata",
                          type="poly",
                          poly_cmap=jet,
                          xbounds=(0, 499),
                          ybounds=(0, 582),
                          alpha=0.35)

        # ... with contour lines of the same data on top.
        pane.contour_plot("drawdata",
                          type="line",
                          xbounds=(0, 499),
                          ybounds=(0, 582),
                          alpha=0.35)

        pane.legend.visible = True

        # Pan tool plus a box-mode zoom tool that is not always on.
        pane.tools.append(PanTool(pane))
        zoom = ZoomTool(component=pane, tool_mode="box", always_on=False)
        pane.overlays.append(zoom)

        pane.bg_color = None
        pane.fill_padding = True

    def default_traits_view(self):
        """Window showing the plot container with a File menu (save, load, close).

        NOTE(review): ``size`` and ``title`` are not defined in this class;
        presumably module-level globals elsewhere in the file -- confirm.
        """
        traits_view = View(
            Group(
                Item('plot',
                    editor=ComponentEditor(size=size),
                    show_label=False),
                orientation="vertical"),
            menubar=MenuBar(
                Menu(Action(name="Save Plot", action="save"),
                     Action(name="Load Plot", action="load"),
                     Separator(), CloseAction,
                     name="File")),
            resizable=True,
            title=title,
            handler=ImageFileController)
        return traits_view

    def _save(self):
        """Render the whole container at its outer size and save it to ``_save_file``."""
        win_size = self.plot.outer_bounds
        plot_gc = PlotGraphicsContext(win_size)
        plot_gc.render_component(self.plot)
        plot_gc.save(self._save_file)

    def _load(self):
        """Read ``_load_file`` into ``pd['imagedata']`` and request a redraw.

        Any error is printed and otherwise ignored, so a bad file leaves the
        previous image in place.

        NOTE(review): panes built in __init__ hold their own ArrayPlotData
        referencing the previous image array, so loading after construction
        does not change what they display -- confirm whether that is intended.
        """
        try:
            image = ImageData.fromfile(self._load_file)
            self.pd.set_data('imagedata', image._data)
            self.plot.title = "YO DOGG: %s" % os.path.basename(self._load_file)
            self.plot.request_redraw()
        except Exception as exc:
            print("YO DOGG: %s" % exc)
class ElectrodeHysteresis(HasTraits):
    """Field-sweep hysteresis of two electrodes and the resulting DQD response.

    Each electrode's magnetisation angle is found by minimising
    ``stoner_free_energy`` as the field is swept. The difference of the two
    angles drives a 'Current' plot (its cosine) and the DQD phase-contrast
    response computed in ``compute_dqd_answer``.
    """

    # Electrode 1: coercive field and easy-axis angle. The angle is typed as
    # an expression in easy_axis_angle_input1 and evaluated by new_value1.
    coercitive_field1 = Float(0.01)
    easy_axis_angle1 = Float(0)
    easy_axis_angle_input1 = Str('0')

    # Electrode 2, same layout as electrode 1.
    coercitive_field2 = Float(0.02)
    easy_axis_angle2 = Float(0)
    easy_axis_angle_input2 = Str('Pi/6')

    # Field sweep range and number of samples per sweep.
    field_min = Float(-0.015)
    field_max = Float(0.015)
    points = Int(1000)

    # Cavity and DQD parameters (units as labelled in traits_view).
    cavity_freq = Float(6.7)
    cavity_quality = Float(3500)
    # Set once a sweep has produced angle data; gates compute_dqd_answer.
    dqd_ready = Bool(False)
    dqd_t = Float(6.2)
    dqd_coupling = Float(10)
    dqd_delta = Float(3)
    dqd_gamma12 = Float(0.1)
    dqd_gamma13 = Float(0.1)
    lande = Float(28)

    data1 = Instance(ArrayPlotData)
    plot1 = Instance(Plot)
    data2 = Instance(ArrayPlotData)
    plot2 = Instance(Plot)
    data_diff = Instance(ArrayPlotData)
    plot_diff = Instance(Plot)
    data_g = Instance(ArrayPlotData)
    plot_g = Instance(Plot)


    compute = Button('Compute')

    traits_view = View(
                    VGroup(
                        HGroup(
                            VGroup(
                                Item('coercitive_field1'),
                                Item('easy_axis_angle_input1',
                                    editor = TextEditor(auto_set = False,
                                        enter_set = True),
                                    ),
                                Item('easy_axis_angle1', style = 'readonly'),
                                show_border = True,
                                label = 'Electrode 1',
                                ),
                            VGroup(
                                Item('coercitive_field2'),
                                Item('easy_axis_angle_input2',
                                    editor = TextEditor(auto_set = False,
                                        enter_set = True),
                                    ),
                                Item('easy_axis_angle2', style = 'readonly'),
                                show_border = True,
                                label = 'Electrode 2',
                                ),
                            ),
                        HGroup(
                            Item('field_min'),
                            Item('field_max'),
                            Item('points'),
                            ),
                        HGroup(
                            Spring(),
                            Heading('Increasing = blue, decreasing = red'),
                            Spring(),
                            UItem('compute'),
                            ),
                        HGroup(
                            UItem('plot1', editor = ComponentEditor()),
                            UItem('plot2', editor = ComponentEditor()),
                            UItem('plot_diff', editor = ComponentEditor()),
                            ),
                        HGroup(
                            VGroup(
                                HGroup(Item('cavity_freq'),Label('GHz')),
                                Item('cavity_quality'),
                                HGroup(Item('dqd_t'),Label('GHz')),
                                HGroup(Item('dqd_coupling'),Label('MHz')),
                                HGroup(Item('dqd_delta'),Label('GHz')),
                                HGroup(Item('dqd_gamma12'),Label('GHz')),
                                HGroup(Item('dqd_gamma13'),Label('GHz')),
                                HGroup(Item('lande', style = 'readonly'),
                                       Label('GHz/T')),
                                ),
                            UItem('plot_g', editor = ComponentEditor()),
                            ),
                    ),
                resizable = True,
                )

    def __init__(self):

        super(ElectrodeHysteresis, self).__init__()
        # Turn the default angle expressions into numeric angles.
        self.new_value1(self.easy_axis_angle_input1)
        self.new_value2(self.easy_axis_angle_input2)

        self.data1, self.plot1 = self._make_plot('Electrode 1',
                                                 'Orientation (rad)')
        self.data2, self.plot2 = self._make_plot('Electrode 2',
                                                 'Orientation (rad)')
        self.data_diff, self.plot_diff = self._make_plot('Current',
                                                         'Current (a.u.)')
        self.data_g, self.plot_g = self._make_plot('DQD response',
                                                   'Phase contrast (rad)')

    @staticmethod
    def _make_plot(title, value_title):
        """Create a titled plot with flat placeholder curves.

        The plot draws ('x', 'y1') in blue and ('x', 'y2') in red.

        Returns:
            The ``(ArrayPlotData, Plot)`` pair.
        """
        data = ArrayPlotData()
        plot = Plot(data, title=title)
        plot.padding = (80, 10, 20, 40)
        plot.index_axis.title = 'Field (mT)'
        plot.value_axis.title = value_title

        # Zero-valued curves until the first computation replaces them.
        dummy_x = linspace(0, 10, 1000)
        data.set_data('x', dummy_x)
        data.set_data('y1', 0*dummy_x)
        data.set_data('y2', 0*dummy_x)
        plot.plot(('x', 'y1'), color='blue')
        plot.plot(('x', 'y2'), color='red')
        return data, plot

    def _hysteresis_branches(self, easy_axis_angle, inv_field_coercion):
        """Sweep one electrode and return its ``(increasing, decreasing)`` angles.

        The field is first driven from 0 up to ``field_max``. It is then swept
        down to ``field_min`` (decreasing branch) and back up to ``field_max``
        (increasing branch). At each step the angle is relaxed with
        ``optimize.fmin`` starting from the previous angle, which is what
        makes the result history-dependent. Both returned lists hold floats
        wrapped with ``% (2*Pi)``, ordered from ``field_min`` to
        ``field_max`` so they line up with the 'x' data.
        """
        def relax(phi, field):
            return optimize.fmin(stoner_free_energy, phi,
                                 args=(easy_axis_angle, field,
                                       inv_field_coercion),
                                 disp=False)

        phi = [0]
        # Floor division keeps the sample count an int on Python 3 too (it is
        # the same value Python 2's int '/' produced).
        for field in linspace(0, self.field_max, self.points // 2):
            phi = relax(phi, field)

        decreasing = []
        for field in linspace(self.field_max, self.field_min, self.points):
            phi = relax(phi, field)
            decreasing.append(float(phi[0]) % (2*Pi))

        increasing = []
        for field in linspace(self.field_min, self.field_max, self.points):
            phi = relax(phi, field)
            increasing.append(float(phi[0]) % (2*Pi))

        # The downward sweep was recorded from field_max to field_min.
        decreasing.reverse()
        return increasing, decreasing

    @staticmethod
    def _show_branches(data, plot, increasing, decreasing):
        """Store the branches as 'y1'/'y2' and let the value axis autoscale."""
        data.set_data('y1', increasing)
        data.set_data('y2', decreasing)
        plot.value_range.low_setting = 'auto'
        plot.value_range.high_setting = 'auto'

    def _compute_fired(self):
        """Recompute both hysteresis loops and refresh every plot."""
        fields = linspace(self.field_min, self.field_max, self.points)
        for data in (self.data1, self.data2, self.data_diff, self.data_g):
            data.set_data('x', fields)

        phi1_incr, phi1_decr = self._hysteresis_branches(
            self.easy_axis_angle1, 1/self.coercitive_field1)
        phi2_incr, phi2_decr = self._hysteresis_branches(
            self.easy_axis_angle2, 1/self.coercitive_field2)

        phi_diff_incr = array(phi1_incr) - array(phi2_incr)
        phi_diff_decr = array(phi1_decr) - array(phi2_decr)

        self._show_branches(self.data1, self.plot1, phi1_incr, phi1_decr)
        self._show_branches(self.data2, self.plot2, phi2_incr, phi2_decr)
        self._show_branches(self.data_diff, self.plot_diff,
                            cos(phi_diff_incr), cos(phi_diff_decr))

        # Angle differences reduced modulo Pi feed the DQD model.
        self.data_g.set_data('aux1', mod(phi_diff_incr, Pi))
        self.data_g.set_data('aux2', mod(phi_diff_decr, Pi))
        self.dqd_ready = True
        self.compute_dqd_answer()

    @on_trait_change('dqd_t, dqd_delta, dqd_gamma12, dqd_gamma13, '
                     'dqd_coupling, cavity_freq, cavity_quality, dqd_ready')
    def compute_dqd_answer(self):
        """Recompute the DQD response curves ('y1'/'y2' of ``data_g``).

        Does nothing until a sweep has set ``dqd_ready``. ``delta`` is
        ``dqd_delta + lande * x``. The two curves use the angle differences
        stored as 'aux1' (increasing) and 'aux2' (decreasing). The cavity
        parameters are listened to as well because both enter the formula.
        """
        if not self.dqd_ready:
            return
        delta = self.dqd_delta + self.lande*self.data_g.get_data('x')
        response_incr = self._dqd_response(delta, self.data_g.get_data('aux1'))
        response_decr = self._dqd_response(delta, self.data_g.get_data('aux2'))
        self.data_g.set_data('y1', response_incr)
        self.data_g.set_data('y2', response_decr)

    def _dqd_response(self, delta, theta):
        """DQD response for one branch of angle differences *theta*."""
        t = self.dqd_t
        ener1 = -energy_dqd(t, delta, theta, +1)
        ener2 = -energy_dqd(t, delta, theta, -1)
        ener3 = energy_dqd(t, delta, theta, -1)

        aux = delta*sin(theta/2)
        kmm = kappa(t, delta, theta, -1, -1)
        kpm = kappa(t, delta, theta, +1, -1)
        kmp = kappa(t, delta, theta, -1, +1)

        g12 = (aux*kmm + aux*kpm)/sqrt((aux**2+kmm**2)*(aux**2+kpm**2))
        g13 = (aux*kpm - aux*kmp)/sqrt((aux**2+kmp**2)*(aux**2+kpm**2))

        detun12 = ener2-ener1-self.cavity_freq
        detun13 = ener3-ener1-self.cavity_freq

        response = (detun12*g12**2/(detun12**2 + self.dqd_gamma12**2) +
                    detun13*g13**2/(detun13**2 + self.dqd_gamma13**2))
        response *= 2*self.cavity_quality/self.cavity_freq*\
                                                self.dqd_coupling**2/1000000
        return response

    @on_trait_change('easy_axis_angle_input1')
    def new_value1(self, new):
        """Evaluate the angle expression typed for electrode 1 (e.g. 'Pi/6')."""
        # NOTE(review): eval() executes arbitrary Python from the text field.
        # Acceptable for a local GUI tool; never route untrusted text here.
        self.easy_axis_angle1 = eval(new)

    @on_trait_change('easy_axis_angle_input2')
    def new_value2(self, new):
        """Evaluate the angle expression typed for electrode 2 (e.g. 'Pi/6')."""
        # NOTE(review): eval() executes arbitrary Python from the text field.
        # Acceptable for a local GUI tool; never route untrusted text here.
        self.easy_axis_angle2 = eval(new)
示例#16
0
class TwoDimensionalPlot(ChacoPlot):
    """
    A 2D plot of a single 'x'/'y' series.
    """

    # Round-robin palette; auto_color() advances the class-level index.
    auto_color_idx = 0
    auto_color_list = ['green', 'brown', 'blue', 'red', 'black']

    @classmethod
    def auto_color(cls):
        """
        Choose the next color, cycling through auto_color_list.
        """

        color = cls.auto_color_list[cls.auto_color_idx]
        cls.auto_color_idx = (cls.auto_color_idx + 1) % len(cls.auto_color_list)

        return color

    def __init__(self, parent, color=None, *args, **kwargs):
        self.parent = parent

        if color is None:
            color = self.auto_color()

        # One-point placeholder series until real data is assigned.
        self.data = ArrayPlotData()
        self.data.set_data('x', [0])
        self.data.set_data('y', [0])

        ChacoPlot.__init__(self, self.data, *args, **kwargs)

        self.plot(('x', 'y'), color=color)

        self.configure()

    @property
    def control(self):
        """
        A drawable control.
        """

        return Window(self.parent, component=self).control

    @property
    def plot_obj(self):
        """
        The renderer created in __init__.
        """

        # next(iter(...)) rather than values()[0]: dict views are not
        # indexable on Python 3, and this form works on Python 2 as well.
        return next(iter(self.plots.values()))[0]

    def get_data(self, axis):
        """
        Values for an axis.
        """

        return self.data.get_data(axis)

    def set_data(self, values, axis):
        """
        Replace the values for an axis.
        """

        self.data.set_data(axis, values)

    x_data = property(partial(get_data, axis='x'), partial(set_data, axis='x'))
    y_data = property(partial(get_data, axis='y'), partial(set_data, axis='y'))

    def x_autoscale(self):
        """
        Enable autoscaling for the x axis.
        """

        x_range = self.plot_obj.index_mapper.range
        x_range.low = x_range.high = 'auto'

    def y_autoscale(self):
        """
        Enable autoscaling for the y axis.
        """

        y_range = self.plot_obj.value_mapper.range
        y_range.low = y_range.high = 'auto'
class ColormappedPlot(ChacoPlot):
    """
    A colormapped image plot of the 'color' array.
    """
    def __init__(self, parent, x_bounds, y_bounds, *args, **kwargs):
        self.parent = parent

        # 1x1 placeholder image until real values arrive via color_data.
        self.data = ArrayPlotData()
        self.data.set_data('color', [[0]])

        ChacoPlot.__init__(self, self.data, *args, **kwargs)

        self.img_plot('color',
                      colormap=jet,
                      xbounds=x_bounds,
                      ybounds=y_bounds)

        self.configure()

    @property
    def plot_obj(self):
        """
        The image renderer created in __init__.
        """

        # next(iter(...)) rather than values()[0]: dict views are not
        # indexable on Python 3, and this form works on Python 2 as well.
        return next(iter(self.plots.values()))[0]

    @property
    def control(self):
        """
        A drawable control with a color bar.

        A new ColorBar, range selection and Window are built on every access.
        """

        color_map = self.plot_obj.color_mapper
        linear_mapper = LinearMapper(range=color_map.range)
        color_bar = ColorBar(index_mapper=linear_mapper,
                             color_mapper=color_map,
                             plot=self.plot_obj,
                             orientation='v',
                             resizable='v',
                             width=30)
        color_bar._axis.tick_label_formatter = self.sci_formatter
        color_bar.padding_top = self.padding_top
        color_bar.padding_bottom = self.padding_bottom
        color_bar.padding_left = 50  # Room for labels.
        color_bar.padding_right = 10

        # Selecting a range on the color bar is forwarded to the image plot.
        range_selection = RangeSelection(component=color_bar)
        range_selection.listeners.append(self.plot_obj)
        color_bar.tools.append(range_selection)

        range_selection_overlay = RangeSelectionOverlay(component=color_bar)
        color_bar.overlays.append(range_selection_overlay)

        container = HPlotContainer(use_backbuffer=True)
        container.add(self)
        container.add(color_bar)

        return Window(self.parent, component=container).control

    @property
    def color_data(self):
        """
        Plotted values.
        """

        return self.data.get_data('color')

    @color_data.setter
    def color_data(self, values):
        self.data.set_data('color', values)

    @property
    def low_setting(self):
        """
        Lowest color value.

        Reading returns the color range's current ``low``; assigning sets the
        range's ``low_setting``.
        """

        return self.plot_obj.color_mapper.range.low

    @low_setting.setter
    def low_setting(self, value):
        self.plot_obj.color_mapper.range.low_setting = value

    @property
    def high_setting(self):
        """
        Highest color value.

        Reading returns the color range's current ``high``; assigning sets
        the range's ``high_setting``.
        """

        return self.plot_obj.color_mapper.range.high

    @high_setting.setter
    def high_setting(self, value):
        self.plot_obj.color_mapper.range.high_setting = value
示例#18
0
class StdXYPlotFactory(BasePlotFactory):
    """ Factory to create a 2D plot with one of more renderers of the same kind
    """
    #: Generated chaco plot containing all requested renderers
    plot = Instance(MultiMapperPlot)

    #: List of plot_data keys to plot in pairs, one pair per renderer
    renderer_desc = List(Dict)

    #: Renderer list, mapped to their name
    renderers = Dict

    #: Optional legend object to be added to the future plot
    legend = Instance(Legend)

    def __init__(self,
                 x_arr=None,
                 y_arr=None,
                 z_arr=None,
                 hover_data=None,
                 **traits):
        """ Build the factory and, unless plot_data was supplied, its data.

        Parameters
        ----------
        x_arr, y_arr, z_arr : ndarray, dict or pandas.Series, optional
            Plot data. Series are unwrapped to their ``.values`` arrays; see
            ``initialize_plot_data`` for the array/dict forms.
        hover_data : dict, optional
            Extra named data forwarded to ``initialize_plot_data`` and stored
            in the plot data.
        **traits
            Forwarded to the base factory.
        """
        super(StdXYPlotFactory, self).__init__(**traits)

        # Unwrap pandas Series so the rest of the factory sees plain arrays.
        if isinstance(x_arr, pd.Series):
            x_arr = x_arr.values

        if isinstance(y_arr, pd.Series):
            y_arr = y_arr.values

        # Check z_arr itself: testing y_arr here crashed on z_arr=None and
        # never unwrapped a Series z_arr.
        if isinstance(z_arr, pd.Series):
            z_arr = z_arr.values

        if hover_data is None:
            hover_data = {}

        if self.plot_data is None:
            self.initialize_plot_data(x_arr=x_arr,
                                      y_arr=y_arr,
                                      z_arr=z_arr,
                                      **hover_data)

        self.adjust_plot_style()

    def adjust_plot_style(self):
        """ Hook translating general plotting style info into xy plot parameters.

        Called at the end of ``__init__``. This implementation does nothing;
        override it to customise the style.
        """

    def initialize_plot_data(self,
                             x_arr=None,
                             y_arr=None,
                             z_arr=None,
                             **adtl_arrays):
        """ Set the plot_data and the list of renderer descriptions.

        If the data arrays are dictionaries rather than straight arrays, they
        describe multiple renderers.

        Parameters
        ----------
        x_arr, y_arr : ndarray or dict
            Plain arrays (one renderer), or dicts mapping each hue value to
            its array (one renderer per key; both dicts must share keys).
        z_arr : optional
            Forwarded to the single/multi renderer builders.
        **adtl_arrays
            Extra named data stored in the plot data alongside x and y.

        Returns
        -------
        dict
            The mapping used to build ``self.plot_data``.

        Raises
        ------
        ValueError
            If x_arr or y_arr is missing, or x_arr is neither an ndarray nor
            a dict.
        """
        if x_arr is None or y_arr is None:
            msg = "2D plots require a valid plot_data or an array for both x" \
                  " and y."
            # logger.error, not logger.exception: no exception is being
            # handled here, so exception() would attach a bogus traceback.
            logger.error(msg)
            raise ValueError(msg)

        if isinstance(x_arr, np.ndarray):
            data_map = self._plot_data_single_renderer(x_arr, y_arr, z_arr,
                                                       **adtl_arrays)
        elif isinstance(x_arr, dict):
            assert set(x_arr.keys()) == set(y_arr.keys())
            data_map = self._plot_data_multi_renderer(x_arr, y_arr, z_arr,
                                                      **adtl_arrays)
        else:
            msg = "x_arr/y_arr should be either an array or a dictionary " \
                  "mapping the z/hue value to the corresponding x array, but" \
                  " {} ({}) was passed."
            msg = msg.format(x_arr, type(x_arr))
            raise ValueError(msg)

        self.plot_data = ArrayPlotData(**data_map)
        return data_map

    def _plot_data_single_renderer(self,
                                   x_arr=None,
                                   y_arr=None,
                                   z_arr=None,
                                   **adtl_arrays):
        """ Map one x/y pair (plus extras) to plot-data keys for one renderer.

        Replaces ``renderer_desc`` with a single description named
        ``DEFAULT_RENDERER_NAME``. Extra arrays may override the x/y keys;
        ``z_arr`` is accepted but not stored.
        """
        x_key = self.x_col_name
        y_key = self.y_col_name
        data_map = dict({x_key: x_arr, y_key: y_arr}, **adtl_arrays)
        self.renderer_desc = [
            {"x": x_key, "y": y_key, "name": DEFAULT_RENDERER_NAME},
        ]
        return data_map

    def _plot_data_multi_renderer(self,
                                  x_arr=None,
                                  y_arr=None,
                                  z_arr=None,
                                  **adtl_arrays):
        """ Build the data_map for one renderer per hue value.

        Hue values (the keys of ``x_arr``) are processed in sorted order. Each
        one appends its name to ``_hue_values`` and a description to
        ``renderer_desc``. ``z_arr`` is accepted but not used.
        """
        data_map = {}
        for position, hue_val in enumerate(sorted(x_arr.keys())):
            names = self._add_arrays_for_hue(data_map, x_arr, y_arr, hue_val,
                                             position, adtl_arrays)
            hue_name, x_name, y_name = names
            description = {"x": x_name, "y": y_name, "name": hue_name}
            self._hue_values.append(hue_name)
            self.renderer_desc.append(description)

        return data_map

    def _add_arrays_for_hue(self, data_map, x_arr, y_arr, hue_val, hue_val_idx,
                            adtl_arrays):
        """ Add every array belonging to *hue_val* to *data_map* (in place).

        Keys are the column name suffixed with ``str(hue_val)``. Extra arrays
        (e.g. data feeding plot tools) are indexed by *hue_val* too.
        ``hue_val_idx`` is accepted but not used.

        Returns ``(hue_name, x_name, y_name)``.
        """
        hue_name = str(hue_val)
        x_name = self._plotdata_array_key(self.x_col_name, hue_name)
        y_name = self._plotdata_array_key(self.y_col_name, hue_name)
        data_map[x_name] = x_arr[hue_val]
        data_map[y_name] = y_arr[hue_val]
        data_map.update(
            (self._plotdata_array_key(column, hue_name), values[hue_val])
            for column, values in adtl_arrays.items()
        )
        return hue_name, x_name, y_name

    def _plotdata_array_key(self, col_name, hue_name=""):
        """ Key under which a column's array is stored in the ArrayPlotData.

        Parameters
        ----------
        col_name : str
            Name of the column being displayed.

        hue_name : str
            Name of the renderer color the array will be used in. Typically the
            coloring column value, converted to string.

        Returns
        -------
        str
            ``col_name`` immediately followed by ``hue_name``.
        """
        key = col_name + hue_name
        return key

    def generate_plot(self):
        """ Build the plot and return a dict describing it.

        Creates a MultiMapperPlot (also stored on ``self.plot``), then adds
        renderers, axis labels, a legend when there is more than one
        renderer, and tools. The returned description feeds a
        PlotDescriptor; a colorbar is generated and added when the container
        style asks for one.
        """
        plot = MultiMapperPlot(**self.plot_style.container_style.to_traits())
        self.plot = plot

        # Emulate chaco.Plot interface:
        plot.data = self.plot_data

        self.add_renderers(plot)
        self.set_axis_labels(plot)
        if len(self.renderer_desc) > 1:
            self.set_legend(plot)
        self.add_tools(plot)

        desc = {
            "plot_type": self.plot_type,
            "plot": plot,
            "visible": True,
            "plot_title": self.plot_title,
            "x_col_name": self.x_col_name,
            "y_col_name": self.y_col_name,
            "x_axis_title": self.x_axis_title,
            "y_axis_title": self.y_axis_title,
            "z_col_name": self.z_col_name,
            "z_axis_title": self.z_axis_title,
            "plot_factory": self,
        }

        if self.plot_style.container_style.include_colorbar:
            self.generate_colorbar(desc)
            self.add_colorbar(desc)

        return desc

    def add_tools(self, plot):
        """ Add all tools specified in plot_tools list to provided plot.

        Pan and zoom tools for every renderer are registered on a single
        BroadcasterTool attached to the first renderer, so one interaction is
        forwarded to all of them. Legend tools are added when requested and a
        legend exists. The context menu is attached to *plot* (the
        container).
        """
        broadcaster = BroadcasterTool()

        # IMPORTANT: add the broadcast tool to one of the renderers, NOT the
        # container. Otherwise, the box zoom will crop wrong:
        first_plot = plot.components[0]
        first_plot.tools.append(broadcaster)

        # Use a distinct loop name: reusing ``plot`` left it pointing at the
        # last renderer afterwards instead of the container.
        for renderer in plot.components:
            if "pan" in self.plot_tools:
                pan = PanTool(renderer)
                broadcaster.tools.append(pan)

            if "zoom" in self.plot_tools:
                # FIXME: the zoom tool is added to the broadcaster's tools
                #  attribute because it doesn't have an overlay list. That
                #  means the box plot mode won't display the blue box!
                zoom = ZoomTool(component=renderer, zoom_factor=1.15)
                broadcaster.tools.append(zoom)

        if "legend" in self.plot_tools and self.legend:
            legend = self.legend
            legend.tools.append(
                LegendTool(component=self.legend, drag_button="right"))
            legend.tools.append(LegendHighlighter(component=legend))

        if "context_menu" in self.plot_tools:
            self.context_menu_manager.target = plot
            menu = self.context_menu_manager.build_menu()
            context_menu = ContextMenuTool(component=plot,
                                           menu_manager=menu)

            plot.tools.append(context_menu)

    def add_renderers(self, plot):
        """ Add all renderers to provided plot container.

        Renderer descriptions and renderer styles are paired in order; the
        first pair is flagged as the first renderer (it receives the default
        axes). Renderers are aligned to the plot's axes afterwards.

        Raises
        ------
        ValueError
            If the number of styles differs from the number of descriptions.
        """
        styles = self.plot_style.renderer_styles
        if len(styles) != len(self.renderer_desc):
            msg = "Something went wrong: received {} styles and {} renderer " \
                  "descriptions.".format(len(styles), len(self.renderer_desc))
            # logger.error, not logger.exception: no exception is being
            # handled here, so exception() would attach a bogus traceback.
            logger.error(msg)
            raise ValueError(msg)

        for i, (desc, style) in enumerate(zip(self.renderer_desc, styles)):
            self.add_renderer(plot, desc, style, first_renderer=(i == 0))

        self.align_all_renderers(plot)

    def add_renderer(self, plot, desc, style, first_renderer=False):
        """ Build the renderer described by *desc*/*style* and add it to *plot*.

        The renderer is also recorded in ``self.renderers`` under
        ``desc["name"]``.

        For the first renderer, the default axes from ``add_default_axes``
        become the container's ``x_axis``/``y_axis`` and its only underlays.
        For a later renderer whose style is right-oriented, a secondary right
        y axis is created (log or linear, per
        ``plot_style.second_y_axis_style.scaling``) if the container does not
        have one yet.

        Returns the new renderer.
        """
        name = desc["name"]
        # Propagate the name to the style so the style view can display it.
        style.renderer_name = name
        renderer = self._build_renderer(desc, style)
        plot.add(renderer)
        self.renderers[name] = renderer

        if first_renderer:
            left_axis, bottom_axis = add_default_axes(renderer)
            # Emulate chaco.Plot interface:
            plot.x_axis = bottom_axis
            plot.y_axis = left_axis
            # The axes belong on the container, not the renderer.
            renderer.underlays = []
            plot.underlays = [bottom_axis, left_axis]
            return renderer

        needs_right_axis = (style.orientation == STYLE_R_ORIENT and
                            plot.second_y_axis is None)
        if needs_right_axis:
            y2_style = self.plot_style.second_y_axis_style
            if y2_style.scaling == LOG_AXIS_STYLE:
                mapper_klass = LogMapper
            else:
                mapper_klass = LinearMapper

            # Give the mapper its own range up front so the axis can later be
            # aligned with every secondary-y-axis renderer.
            second_y_axis = PlotAxis(component=renderer,
                                     orientation="right",
                                     mapper=mapper_klass(range=DataRange1D()))
            plot.second_y_axis = second_y_axis
            plot.underlays.append(second_y_axis)

        return renderer

    def align_all_renderers(self, plot):
        """ Align every renderer to the plot's axes along index and value.

        All renderers share the x axis. In the value dimension:

        - if the plot has a second y axis, renderers are split by their
          style's orientation, and left ones are aligned to ``plot.y_axis``
          while right ones are aligned to ``plot.second_y_axis``;
        - otherwise, all renderers are aligned to ``plot.y_axis``.

        Renderers are paired with styles positionally, so this relies on
        ``self.renderers`` iterating in the same order as
        ``self.plot_style.renderer_styles``. With fewer than two renderers
        there is nothing to align and the method returns immediately.
        """
        renderers = self.renderers.values()
        if len(renderers) < 2:
            return

        renderer_styles = self.plot_style.renderer_styles
        align_renderers(renderers, plot.x_axis, dim="index")

        if plot.second_y_axis is None:
            align_renderers(renderers, plot.y_axis, dim="value")
            return

        pairs = list(zip(renderers, renderer_styles))
        left_side = [rend for rend, sty in pairs
                     if sty.orientation == STYLE_L_ORIENT]
        right_side = [rend for rend, sty in pairs
                      if sty.orientation == STYLE_R_ORIENT]
        align_renderers(left_side, plot.y_axis, dim="value")
        align_renderers(right_side, plot.second_y_axis, dim="value")

    def _build_renderer(self, desc, style):
        """ Create the renderer for ``desc`` with the factory of its style.

        The factory is looked up in ``RENDERER_MAKER`` by
        ``style.renderer_type``. ``desc["x"]`` and ``desc["y"]`` are keys into
        ``self.plot_data``.

        Mapper classes are chosen as follows:

        - the index mapper class follows the x axis scaling;
        - the value mapper class follows the scaling of the y axis matching
          the style's orientation (the left y axis, or else the second y
          axis).

        Extra keyword arguments come from ``style.to_plot_kwargs()``.
        """
        factory = RENDERER_MAKER[style.renderer_type]
        x_data = self.plot_data.get_data(desc["x"])
        y_data = self.plot_data.get_data(desc["y"])

        x_is_log = self.plot_style.x_axis_style.scaling == LOG_AXIS_STYLE
        index_mapper_cls = LogMapper if x_is_log else LinearMapper

        value_axis_style = (self.plot_style.y_axis_style
                            if style.orientation == STYLE_L_ORIENT
                            else self.plot_style.second_y_axis_style)
        y_is_log = value_axis_style.scaling == LOG_AXIS_STYLE
        value_mapper_cls = LogMapper if y_is_log else LinearMapper

        return factory(data=(x_data, y_data),
                       index_mapper_class=index_mapper_cls,
                       value_mapper_class=value_mapper_cls,
                       **style.to_plot_kwargs())

    def set_legend(self, plot, align="ur", padding=10):
        """ Overlay a visible Legend of ``self.renderers`` on ``plot``.

        The legend is:

        - titled with ``self.z_axis_title``;
        - appended to ``plot.overlays``;
        - stored on ``self.legend``.

        No interactive tools are attached to the legend here.

        FIXME: Add control over legend labels.
        """
        new_legend = Legend(component=plot,
                            padding=padding,
                            align=align,
                            title=self.z_axis_title)
        new_legend.plots = self.renderers
        new_legend.visible = True
        plot.overlays.append(new_legend)
        # Expose the legend under the same attribute name chaco.Plot uses:
        self.legend = new_legend

    # Post creation renderer management methods -------------------------------

    def update_renderers_from_data(self, removed=None):
        """ The plot_data was updated: update or remove existing renderers.

        Parameters
        ----------
        removed : list or None
            Keys of ``self.plot_data`` entries that were removed. Each
            renderer is handled according to its descriptor's x/y keys:

            - if both keys are listed, the renderer is removed from the plot;
            - if neither key is listed, the renderer's index and value data
              are refreshed from ``self.plot_data``.

        Raises
        ------
        ValueError
            If exactly one of a renderer's x/y keys is listed as removed.
        """
        if removed is None:
            removed = []

        rend_desc_map = {desc["name"]: desc for desc in self.renderer_desc}

        # Iterate over a snapshot of the names: remove_renderer pops entries
        # from self.renderers during the loop.
        for name in list(self.renderers.keys()):
            desc = rend_desc_map[name]
            x_removed = desc["x"] in removed
            y_removed = desc["y"] in removed

            if x_removed and y_removed:
                self.remove_renderer(desc)
            elif x_removed != y_removed:
                msg = "Unable to update the renderer {}: the data seems to be"\
                      " incomplete because x was set as removed and not y or" \
                      " vice versa. Removed keys: {}. Please report this " \
                      "issue.".format(desc["name"], removed)
                # Not inside an except block: use error(), not exception().
                logger.error(msg)
                raise ValueError(msg)
            else:
                renderer = self.renderers[name]
                renderer.index.set_data(self.plot_data.get_data(desc["x"]))
                renderer.value.set_data(self.plot_data.get_data(desc["y"]))

    def remove_renderer(self, rend_desc):
        """ Remove the renderer described by ``rend_desc`` from the plot.

        The renderer named ``rend_desc["name"]`` is removed from:

        - ``self.renderers`` and ``self.plot``;
        - ``self.renderer_desc`` and ``self.plot_style.renderer_styles``,
          using the position of the first descriptor with that name;
        - the legend's plots, when a legend exists.
        """
        target_name = rend_desc["name"]
        target = self.renderers.pop(target_name)
        self.plot.remove(target)

        match_idx = next(
            (idx for idx, desc in enumerate(self.renderer_desc)
             if desc["name"] == target_name),
            None,
        )
        if match_idx is not None:
            # Descriptors and styles are parallel lists: drop both entries.
            self.renderer_desc.pop(match_idx)
            self.plot_style.renderer_styles.pop(match_idx)

        if self.legend:
            self.legend.plots.pop(target_name)

    def append_new_renderers(self, desc_list, styles):
        """ Append new renderers to an existing factory plot.

        Each (descriptor, style) pair is added to ``self.plot`` through
        ``add_renderer``. The descriptor and style are then appended to
        ``self.renderer_desc`` and ``self.plot_style.renderer_styles``, and
        the renderer is registered in the legend when one exists.

        Only the renderer that ends up at position 0 overall (i.e. when none
        existed before) is treated as the first renderer. Pairs are formed
        with ``zip``, so extra items in the longer list are ignored.
        """
        offset = len(self.renderer_desc)
        pairs = zip(desc_list, styles)
        for position, (new_desc, new_style) in enumerate(pairs, start=offset):
            new_renderer = self.add_renderer(self.plot,
                                             new_desc,
                                             new_style,
                                             first_renderer=(position == 0))
            self.renderer_desc.append(new_desc)
            self.plot_style.renderer_styles.append(new_style)

            if self.legend:
                self.legend.plots[new_desc["name"]] = new_renderer

    # Traits initialization methods -------------------------------------------

    def _plot_tools_default(self):
        return {"zoom", "pan", "legend", "context_menu"}