Beispiel #1
0
class Demo(HasTraits):
    """Interactive demonstration of ``CSVListEditor`` on several list traits.

    The left column edits each list through a ``CSVListEditor`` configured
    with different options; the right column shows the Python ``str``
    representation of the same lists in disabled text editors.  The
    ``low``/``high`` fields feed the dynamic bounds of ``list6``/``list7``.
    """

    list1 = List(Int)

    list2 = List(Float)

    list3 = List(Str, maxlen=3)

    list4 = List(Enum('red', 'green', 'blue', 2, 3))

    list5 = List(Range(low=0.0, high=10.0))

    # 'low' and 'high' are used to demonstrate lists containing dynamic ranges.
    low = Float(0.0)
    high = Float(1.0)

    list6 = List(Range(low=-1.0, high='high'))

    list7 = List(Range(low='low', high='high'))

    pop1 = Button("Pop from first list")

    sort1 = Button("Sort first list")

    # This will be str(self.list1).
    list1str = Property(Str, depends_on='list1')

    traits_view = \
        View(
            HGroup(
                # This VGroup forms the column of CSVListEditor examples.
                VGroup(
                    Item('list1', label="List(Int)",
                         editor=CSVListEditor(ignore_trailing_sep=False),
                         tooltip='options: ignore_trailing_sep=False'),
                    Item('list1', label="List(Int)", style='readonly',
                         editor=CSVListEditor()),
                    Item('list2', label="List(Float)",
                         editor=CSVListEditor(enter_set=True, auto_set=False),
                         tooltip='options: enter_set=True, auto_set=False'),
                    Item('list3', label="List(Str, maxlen=3)",
                         editor=CSVListEditor()),
                    Item('list4',
                         label="List(Enum('red', 'green', 'blue', 2, 3))",
                         editor=CSVListEditor(sep=None),
                         tooltip='options: sep=None'),
                    Item('list5', label="List(Range(low=0.0, high=10.0))",
                         editor=CSVListEditor()),
                    Item('list6', label="List(Range(low=-1.0, high='high'))",
                         editor=CSVListEditor()),
                    Item('list7', label="List(Range(low='low', high='high'))",
                         editor=CSVListEditor()),
                    springy=True,
                ),
                # This VGroup forms the right column; it will display the
                # Python str representation of the lists.
                VGroup(
                    UItem('list1str', editor=TextEditor(),
                          enabled_when='False', width=240),
                    UItem('list1str', editor=TextEditor(),
                          enabled_when='False', width=240),
                    UItem('list2', editor=TextEditor(),
                          enabled_when='False', width=240),
                    UItem('list3', editor=TextEditor(),
                          enabled_when='False', width=240),
                    UItem('list4', editor=TextEditor(),
                          enabled_when='False', width=240),
                    UItem('list5', editor=TextEditor(),
                          enabled_when='False', width=240),
                    UItem('list6', editor=TextEditor(),
                          enabled_when='False', width=240),
                    UItem('list7', editor=TextEditor(),
                          enabled_when='False', width=240),
                ),
            ),
            '_',
            # The Low/High fields sit here, between the list columns and the
            # notes, so the last note must point *up* to them.
            HGroup('low', 'high', spring, UItem('pop1'), UItem('sort1')),
            Heading("Notes"),
            Label("Hover over a list to see which editor options are set, "
                  "if any."),
            Label("The editor of the first list, List(Int), uses "
                  "ignore_trailing_sep=False, so a trailing comma is "
                  "an error."),
            Label("The second list is a read-only view of the first list."),
            Label("The editor of the List(Float) example has enter_set=True "
                  "and auto_set=False; press Enter to validate."),
            Label("The List(Str) example will accept at most 3 elements."),
            Label("The editor of the List(Enum(...)) example uses sep=None, "
                  "i.e. whitespace acts as a separator."),
            Label("The last two List(Range(...)) examples take one or both "
                  "of their limits from the Low and High fields above."),
            width=720,
            title="CSVListEditor Demonstration",
        )

    def _list1_default(self):
        """Initial contents of ``list1``."""
        return [1, 4, 0, 10]

    def _get_list1str(self):
        """Getter for the ``list1str`` property: ``str(self.list1)``."""
        return str(self.list1)

    def _pop1_fired(self):
        """Remove the last element of ``list1`` (if any) and print it."""
        if self.list1:
            x = self.list1.pop()
            print(x)

    def _sort1_fired(self):
        """Sort ``list1`` in place."""
        self.list1.sort()
Beispiel #2
0
class QSSQLObject(__QS_Object__):
    """基于关系数据库的对象"""
    # English: "object backed by a relational database".  The docstring text
    # is deliberately left unchanged: ``cursor`` interpolates ``self.__doc__``
    # into its "not connected" error message.

    # Connection settings; each ``label=`` holds the Chinese caption:
    # DBType: database type; DBName: database name; IPAddr: server address;
    # Port: server port; User: user name; Pwd: password; TablePrefix: prefix
    # prepended to table names (see addIndex); CharSet: character set;
    # Connector: which DB-API driver connect() uses; DSN: ODBC data source
    # name (only used by the pyodbc path).
    # NOTE(review): Pwd ships with a hard-coded default password; consider an
    # empty default so credentials are not embedded in source.
    DBType = Enum("MySQL", "SQL Server", "Oracle", arg_type="SingleOption", label="数据库类型", order=0)
    DBName = Str("Scorpion", arg_type="String", label="数据库名", order=1)
    IPAddr = Str("127.0.0.1", arg_type="String", label="IP地址", order=2)
    Port = Range(low=0, high=65535, value=3306, arg_type="Integer", label="端口", order=3)
    User = Str("root", arg_type="String", label="用户名", order=4)
    Pwd = Password("shuntai11", arg_type="String", label="密码", order=5)
    TablePrefix = Str("", arg_type="String", label="表名前缀", order=6)
    CharSet = Enum("utf8", "gbk", "gb2312", "gb18030", "cp936", "big5", arg_type="SingleOption", label="字符集", order=7)
    Connector = Enum("default", "cx_Oracle", "pymssql", "mysql.connector", "pyodbc", arg_type="SingleOption", label="连接器", order=8)
    DSN = Str("", arg_type="String", label="数据源", order=9)
    def __init__(self, sys_args=None, config_file=None, **kwargs):
        """Create an unconnected object; arguments are forwarded to the base class."""
        # ``None`` replaces the former mutable ``{}`` default so one dict is
        # never shared between instances.
        if sys_args is None:
            sys_args = {}
        self._Connection = None
        super().__init__(sys_args=sys_args, config_file=config_file, **kwargs)
    def __getstate__(self):
        """Return picklable state with the live connection replaced by a bool flag."""
        state = self.__dict__.copy()
        # The DB-API connection cannot be pickled; remember only whether one was open.
        state["_Connection"] = self.isAvailable()
        return state
    def __setstate__(self, state):
        """Restore state and reopen the connection if one was open when pickled."""
        super().__setstate__(state)
        # ``_Connection`` now holds the flag written by __getstate__; clear it
        # before reconnecting so a failed reconnect never leaves a bare bool behind.
        WasConnected = bool(self._Connection)
        self._Connection = None
        if WasConnected: self.connect()
    def connect(self):
        """Open a database connection with the configured connector.

        With ``Connector == "default"`` the native driver matching ``DBType``
        (cx_Oracle / pymssql / mysql.connector) is tried first; if it cannot be
        imported or fails to connect, pyodbc is used as a fallback and
        ``Connector`` is switched to ``"pyodbc"``.  With an explicitly chosen
        driver, any failure is re-raised.

        Returns 0 on success.

        Raises
        ------
        __QS_Error__
            If ``Connector`` names an unsupported driver.
        """
        Connector = self.Connector
        Connection = None
        if (Connector=="cx_Oracle") or ((Connector=="default") and (self.DBType=="Oracle")):
            try:
                import cx_Oracle
                Connection = cx_Oracle.connect(self.User, self.Pwd, cx_Oracle.makedsn(self.IPAddr, str(self.Port), self.DBName))
            except Exception:
                # Only an explicitly requested driver is fatal; "default" falls back to pyodbc below.
                if Connector!="default": raise
        elif (Connector=="pymssql") or ((Connector=="default") and (self.DBType=="SQL Server")):
            try:
                import pymssql
                Connection = pymssql.connect(server=self.IPAddr, port=str(self.Port), user=self.User, password=self.Pwd, database=self.DBName, charset=self.CharSet)
            except Exception:
                if Connector!="default": raise
        elif (Connector=="mysql.connector") or ((Connector=="default") and (self.DBType=="MySQL")):
            try:
                import mysql.connector
                Connection = mysql.connector.connect(host=self.IPAddr, port=str(self.Port), user=self.User, password=self.Pwd, database=self.DBName, charset=self.CharSet, autocommit=True)
            except Exception:
                if Connector!="default": raise
        if Connection is None:
            # Reached for Connector "pyodbc", or for "default" when the native
            # driver failed (previously this case silently returned 0 with no
            # connection, because DBType always matched one of the branches above).
            if Connector not in ("default", "pyodbc"):
                self._Connection = None
                raise __QS_Error__("不支持该连接器(connector) : "+self.Connector)
            import pyodbc
            if self.DSN: Connection = pyodbc.connect("DSN=%s;PWD=%s" % (self.DSN, self.Pwd))
            else: Connection = pyodbc.connect("DRIVER={%s};DATABASE=%s;SERVER=%s;UID=%s;PWD=%s" % (self.DBType, self.DBName, self.IPAddr, self.User, self.Pwd))
            self.Connector = "pyodbc"
        self._Connection = Connection
        return 0
    def disconnect(self):
        """Close the connection if open; the handle is dropped even if close() raises. Returns 0."""
        if self._Connection is not None:
            try:
                self._Connection.close()
            finally:
                self._Connection = None
        return 0
    def isAvailable(self):
        """Return True if a connection handle is currently held."""
        return (self._Connection is not None)
    def cursor(self, sql_str=None):
        """Return a new cursor, optionally after executing *sql_str* on it.

        Raises ``__QS_Error__`` when not connected.  If executing *sql_str*
        fails, the cursor is closed before the error propagates.
        """
        if self._Connection is None: raise __QS_Error__("%s尚未连接!" % self.__doc__)
        Cursor = self._Connection.cursor()
        if sql_str is None: return Cursor
        try:
            Cursor.execute(sql_str)
        except Exception:
            Cursor.close()
            raise
        return Cursor
    def fetchall(self, sql_str):
        """Execute *sql_str* and return all result rows; the cursor is always closed."""
        Cursor = self.cursor(sql_str=sql_str)
        try:
            return Cursor.fetchall()
        finally:
            Cursor.close()
    def execute(self, sql_str):
        """Execute *sql_str* and commit; returns 0.

        Goes through ``cursor()`` so an unconnected object raises the same
        ``__QS_Error__`` as the other query methods; the cursor is always closed.
        """
        Cursor = self.cursor()
        try:
            Cursor.execute(sql_str)
            self._Connection.commit()
        finally:
            Cursor.close()
        return 0
    def addIndex(self, index_name, table_name, fields, index_type="BTREE"):
        """Create index *index_name* on ``TablePrefix + table_name`` over *fields*.

        Identifiers cannot be bound as query parameters, so they are
        concatenated into the SQL text; callers must pass trusted names only.
        """
        SQLStr = "CREATE INDEX "+index_name+" USING "+index_type+" ON "+self.TablePrefix+table_name+"("+", ".join(fields)+")"
        return self.execute(SQLStr)
Beispiel #3
0
class Pattern(HasTraits):
    """Base class for a 2D pattern drawn into a ``Graph``.

    Subclasses implement ``pattern_generator_factory`` (an iterable of
    ``(x, y)`` points) and ``get_parameter_group`` (their parameter UI);
    this class handles plotting, the target/overlap overlays and the
    transit-time estimate.
    """

    graph = Instance(Graph, (), transient=True)
    cx = Float(transient=True)
    cy = Float(transient=True)
    target_radius = Range(0.0, 3.0, 1)

    show_overlap = Bool(False)
    beam_radius = Range(0.0, 3.0, 1)

    path = Str
    name = Property(depends_on='path')

    xbounds = (-3, 3)
    ybounds = (-3, 3)

    velocity = Float(1)
    calculated_transit_time = Float

    niterations = Range(1, 200)

    @property
    def kind(self):
        """Name of this pattern's concrete class."""
        return self.__class__.__name__

    def _get_name(self):
        """Getter for ``name``: basename of ``path`` up to its first '.'."""
        if not self.path:
            return 'New Pattern'
        return os.path.basename(self.path).split('.')[0]

    def _anytrait_changed(self, name, new):
        # Every trait change refreshes the plot and the time estimate, except
        # the estimate itself (which would otherwise recurse).
        if name != 'calculated_transit_time':
            self.replot()
            self.calculate_transit_time()

    def calculate_transit_time(self):
        """Estimate the time to run the pattern ``niterations`` times.

        Solves ``0.5 * acc * t**2 + velocity * t - L = 0`` for its larger
        root, where ``L = _get_path_length()`` and ``acc`` is fixed at 1,
        adds ``_get_delay()`` and multiplies by ``niterations``.
        """
        n = self.niterations

        c = -self._get_path_length()
        b = self.velocity

        acceleration = 1
        a = 0.5 * acceleration

        # Quadratic formula t = (-b +/- sqrt(b**2 - 4ac)) / (2a).  The earlier
        # version dropped the square root and divided only the discriminant
        # by 2a (e.g. a zero path at velocity 2 gave 2 s instead of 0).  The
        # discriminant is clamped at 0 so an odd path length cannot yield a
        # complex number.
        root = max(b ** 2 - 4 * a * c, 0.0) ** 0.5
        t1 = (-b + root) / (2.0 * a)
        t2 = (-b - root) / (2.0 * a)

        self.calculated_transit_time = (max(t1, t2) + self._get_delay()) * n

    def _get_path_length(self):
        """Length of one pass of the pattern; 0 here, override in subclasses."""
        return 0

    def _get_delay(self):
        """Extra time per iteration; 0 here, override in subclasses."""
        return 0

    def _beam_radius_changed(self):
        # Overlay index 1 is the OverlapOverlay added in _graph_factory.
        oo = self.graph.plots[0].plots['plot0'][0].overlays[1]
        oo.beam_radius = self.beam_radius
        self.replot()

    def _show_overlap_changed(self):
        oo = self.graph.plots[0].plots['plot0'][0].overlays[1]
        oo.visible = self.show_overlap
        oo.request_redraw()

    def _target_radius_changed(self):
        # Overlay index 0 is the TargetOverlay added in _graph_factory.
        self.graph.plots[0].plots['plot0'][0].overlays[
            0].target_radius = self.target_radius

    def pattern_generator_factory(self, **kw):
        """Return an iterable of ``(x, y)`` points; must be overridden."""
        raise NotImplementedError

    def replot(self):
        """Redraw the pattern into ``graph``."""
        self.plot()

    def plot(self):
        """Push the generated points into ``graph`` and return the last point.

        Returns
        -------
        (x, y) of the final generated point.
        """
        pgen_out = self.pattern_generator_factory()
        data_out = array(list(pgen_out))
        xs, ys = transpose(data_out)

        self.graph.set_data(xs)
        self.graph.set_data(ys, axis=1)

        return data_out[-1][0], data_out[-1][1]

    def points_factory(self):
        """Return all generated points as a list."""
        return list(self.pattern_generator_factory())

    def graph_view(self):
        """View showing only the graph; relies on a subclass-provided ``handler_klass``."""
        v = View(Item(
            'graph',
            style='custom',
            show_label=False,
        ),
                 handler=self.handler_klass,
                 title=self.name)
        return v

    def clear_graph(self):
        """Empty the data of series 1 and 2 (both axes)."""
        graph = self.graph
        graph.set_data([], series=1, axis=0)
        graph.set_data([], series=1, axis=1)
        graph.set_data([], series=2, axis=0)
        graph.set_data([], series=2, axis=1)

    def reset_graph(self, **kw):
        """Replace ``graph`` with a freshly built one (kwargs go to _graph_factory)."""
        self.graph = self._graph_factory(**kw)

    def _graph_factory(self, with_image=False):
        """Build the pattern graph with target and overlap overlays.

        *with_image* is accepted for compatibility but currently unused.
        """
        g = Graph(window_height=250,
                  window_width=300,
                  container_dict=dict(padding=0))
        g.new_plot(bounds=[250, 250], resizable='', padding=[30, 0, 0, 30])

        cx = self.cx
        cy = self.cy
        cbx = self.xbounds
        cby = self.ybounds
        tr = self.target_radius

        g.set_x_limits(*cbx)
        g.set_y_limits(*cby)

        lp, _plot = g.new_series()
        # Overlay order matters: the *_changed handlers address them by index.
        t = TargetOverlay(component=lp, cx=cx, cy=cy, target_radius=tr)

        lp.overlays.append(t)
        overlap_overlay = OverlapOverlay(component=lp,
                                         visible=self.show_overlap)
        lp.overlays.append(overlap_overlay)

        g.new_series(type='scatter', marker='circle')
        g.new_series(type='line', color='red')
        return g

    def _graph_default(self):
        return self._graph_factory()

    def maker_group(self):
        """Group with the common pattern parameters plus the subclass ones."""
        return Group(self.get_parameter_group(),
                     Item('niterations'),
                     HGroup(
                         Item('velocity'),
                         Item('calculated_transit_time',
                              label='Time (s)',
                              style='readonly',
                              format_str='%0.1f')),
                     Item('target_radius'),
                     Item('show_overlap'),
                     Item('beam_radius', enabled_when='show_overlap'),
                     show_border=True,
                     label='Pattern')

    def maker_view(self):
        """Resizable view showing the parameter group next to the graph."""
        v = View(
            HGroup(
                self.maker_group(),
                Item(
                    'graph',
                    show_label=False,
                    style='custom'),
            ),
            resizable=True)
        return v

    def traits_view(self):
        """Default OK/Cancel view over ``maker_group``."""
        v = View(self.maker_group(),
                 buttons=['OK', 'Cancel'],
                 title=self.name,
                 resizable=True)
        return v

    def get_parameter_group(self):
        """Return the subclass-specific parameter group; must be overridden."""
        raise NotImplementedError
Beispiel #4
0
class Visualization(HasTraits):
    """Compare the vignetting calibrations of two cameras.

    Two Mayavi scenes show, for the selected color channel, each camera's
    measured points together with the surface produced by its fitted
    ``VignettingCalibration``; a Matplotlib figure plots the data stored in
    each camera's ``spec.pkl``.
    """
    meridional = Range(1, 30, 6)
    transverse = Range(0, 30, 11)
    scene_3D1 = Instance(MlabSceneModel, ())
    scene_3D2 = Instance(MlabSceneModel, ())

    figure = Instance(Figure, ())

    colors_list = Enum(*COLORS)
    polynomial_degree = Int(4)
    residual_threshold = Int(20)

    def __init__(self, **traits):
        # Forward trait keyword arguments (e.g. Visualization(colors_list=...));
        # the parent's __init__ must run before any trait is used.
        HasTraits.__init__(self, **traits)

        self.update_3dplot()
        self.update_plot()

    @on_trait_change('colors_list, polynomial_degree, residual_threshold')
    def update_3dplot(self):
        """Redraw both 3D scenes for the current channel and fit settings.

        For each camera the measurements (``measurements.pkl``) are drawn as
        spheres, together with ``raw2RGB(255 / applyVignetting(ones))`` for a
        1200x1600 image of ones, using a calibration refitted with the current
        ``polynomial_degree`` and ``residual_threshold``.
        """
        color = self.colors_list
        channel = COLOR_INDICES[color]

        # NOTE(review): unpickling can execute arbitrary code; these files
        # must come from a trusted location.
        with open(os.path.join(base_path1, 'measurements.pkl'), 'rb') as f:
            measurements1 = cPickle.load(f)
        with open(os.path.join(base_path2, 'measurements.pkl'), 'rb') as f:
            measurements2 = cPickle.load(f)

        vc1 = VignettingCalibration.readMeasurements(
            base_path1,
            polynomial_degree=self.polynomial_degree,
            residual_threshold=self.residual_threshold)
        vc2 = VignettingCalibration.readMeasurements(
            base_path2,
            polynomial_degree=self.polynomial_degree,
            residual_threshold=self.residual_threshold)

        # Skip measurements whose first coordinate is missing.
        x1, y1, z1 = zip(*[a for a in measurements1 if a[0] is not None])
        x2, y2, z2 = zip(*[a for a in measurements2 if a[0] is not None])

        for scene, (x, y, z), vc in zip([self.scene_3D1, self.scene_3D2],
                                        ((x1, y1, z1), (x2, y2, z2)),
                                        (vc1, vc2)):
            fig = scene.mayavi_scene
            mlab.clf(figure=fig)

            # Normalise the selected channel to 0..255.
            zs = np.array(z)[..., channel]
            zs = 255 * zs / zs.max()
            mlab.points3d(x,
                          y,
                          zs,
                          mode='sphere',
                          scale_mode='scalar',
                          scale_factor=20,
                          color=(1, 1, 1),
                          figure=fig)

            surface = np.ones((1200, 1600))
            corrected = raw2RGB(255 / vc.applyVignetting(surface))

            # Target this scene explicitly, like the neighbouring calls;
            # without figure= the surface goes to mlab's current figure,
            # which need not be this scene.
            mlab.surf(corrected[channel].T,
                      extent=(0, 1600, 0, 1200, 0, 255),
                      opacity=0.5,
                      figure=fig)
            mlab.outline(color=(0, 0, 0),
                         extent=(0, 1600, 0, 1200, 0, 255),
                         figure=fig)

    @on_trait_change('colors_list')
    def update_plot(self):
        """Re-plot the ``spec.pkl`` data of both cameras into ``figure``.

        NOTE(review): triggered by ``colors_list`` but the plotted data does
        not depend on the selected color — confirm whether that is intended.
        """
        with open(os.path.join(base_path1, 'spec.pkl'), 'rb') as f:
            spec1 = cPickle.load(f)
        with open(os.path.join(base_path2, 'spec.pkl'), 'rb') as f:
            spec2 = cPickle.load(f)

        self.figure.clear()
        axes = self.figure.add_subplot(111)

        axes.plot(spec1[0], spec1[1], label='camera1')
        axes.plot(spec2[0], spec2[1], label='camera2')
        axes.legend()

        # The canvas only exists once the figure is shown in an editor.
        canvas = self.figure.canvas
        if canvas is not None:
            canvas.draw()

    # the layout of the dialog created
    view = View(
        VGroup(
            VGroup(
                #HGroup(
                #Item('figure', editor=MPLFigureEditor(),
                #show_label=False),
                #label = 'Spectograms',
                #show_border = True
                #),
                HGroup(
                    HGroup(Item('scene_3D1',
                                editor=SceneEditor(scene_class=MayaviScene),
                                height=200,
                                width=200,
                                show_label=False),
                           label='3D',
                           show_border=True),
                    HGroup(Item('scene_3D2',
                                editor=SceneEditor(scene_class=MayaviScene),
                                height=200,
                                width=200,
                                show_label=False),
                           label='3D',
                           show_border=True),
                ),
                '_',
                HGroup(Item('colors_list', style='simple'),
                       Item('polynomial_degree', style='simple'),
                       Item('residual_threshold', style='simple'), Spring()),
            )))
Beispiel #5
0
class BaseXYPlot(AbstractPlotRenderer):
    """ Base class for simple X-vs-Y plots that consist of a single index
    data array and a single value data array.

    Subclasses handle the actual rendering, but this base class takes care of
    most of making sure events are wired up between mappers and data or screen
    space changes, etc.
    """

    #------------------------------------------------------------------------
    # Data-related traits
    #------------------------------------------------------------------------

    # The data source to use for the index coordinate.
    index = Instance(ArrayDataSource)

    # The data source to use as value points.
    value = Instance(AbstractDataSource)

    # Screen mapper for index data.
    index_mapper = Instance(AbstractMapper)
    # Screen mapper for value data
    value_mapper = Instance(AbstractMapper)

    # Convenience properties that correspond to either index_mapper or
    # value_mapper, depending on the orientation of the plot.

    # Corresponds to either **index_mapper** or **value_mapper**, depending on
    # the orientation of the plot.
    x_mapper = Property
    # Corresponds to either **value_mapper** or **index_mapper**, depending on
    # the orientation of the plot.
    y_mapper = Property

    # Convenience property for accessing the index data range.
    index_range = Property
    # Convenience property for accessing the value data range.
    value_range = Property

    # The type of hit-testing that is appropriate for this renderer.
    #
    # * 'line': Computes Euclidean distance to the line between the
    #   nearest adjacent points.
    # * 'point': Checks for adjacency to a marker or point.
    hittest_type = Enum("point", "line")

    #------------------------------------------------------------------------
    # Appearance-related traits
    #------------------------------------------------------------------------

    # The orientation of the index axis.
    orientation = Enum("h", "v")

    # Overall alpha value of the image. Ranges from 0.0 for transparent to 1.0
    alpha = Range(0.0, 1.0, 1.0)

    #------------------------------------------------------------------------
    # Convenience readonly properties for common annotations
    #------------------------------------------------------------------------

    # Read-only property for horizontal grid.
    hgrid = Property
    # Read-only property for vertical grid.
    vgrid = Property
    # Read-only property for x-axis.
    x_axis = Property
    # Read-only property for y-axis.
    y_axis = Property
    # Read-only property for labels.
    labels = Property

    #------------------------------------------------------------------------
    # Other public traits
    #------------------------------------------------------------------------

    # Does the plot use downsampling?
    # This is not used right now.  It needs an implementation of robust, fast
    # downsampling, which does not exist yet.
    use_downsampling = Bool(False)

    # Does the plot use a spatial subdivision structure for fast hit-testing?
    # This makes data updates slower, but makes hit-tests extremely fast.
    use_subdivision = Bool(False)

    # Overrides the default background color trait in PlotComponent.
    bgcolor = "transparent"

    # This just turns on a simple drawing of the X and Y axes... not a long
    # term solution, but good for testing.

    # Defines the origin axis color, for testing.
    origin_axis_color = black_color_trait
    # Defines a the origin axis width, for testing.
    origin_axis_width = Float(1.0)
    # Defines the origin axis visibility, for testing.
    origin_axis_visible = Bool(False)

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # Are the cache traits valid? If False, new ones need to be compute.
    _cache_valid = Bool(False)

    # Cached array of (x,y) data-space points; regardless of self.orientation,
    # these points are always stored as (index_pt, value_pt).
    _cached_data_pts = Array

    # Cached array of (x,y) screen-space points.
    _cached_screen_pts = Array

    # Does **_cached_screen_pts** contain the screen-space coordinates
    # of the points currently in **_cached_data_pts**?
    _screen_cache_valid = Bool(False)

    # Reference to a spatial subdivision acceleration structure.
    _subdivision = Any

    #------------------------------------------------------------------------
    # Abstract methods that subclasses must implement
    #------------------------------------------------------------------------

    def _render(self, gc, points):
        """ Draw *points* into *gc*; subclasses must override this.

        Parameters
        ----------
        gc : graphics context
            Destination that receives the drawing.
        points : List of Nx2 arrays
            Points to draw, already in screen space.
        """
        raise NotImplementedError()

    def _gather_points(self):
        """ Collect and cache the data points inside the plot's range.

        Abstract: the base class only raises; subclasses supply the logic.
        """
        raise NotImplementedError()

    def _downsample(self):
        """ Hook letting a renderer downsample its points in screen space.

        Abstract: there is no default implementation here (the base class
        simply raises), so subclasses must override it.
        """
        raise NotImplementedError()

    #------------------------------------------------------------------------
    # Concrete methods below
    #------------------------------------------------------------------------

    def __init__(self, **kwtraits):
        """ Initialize the renderer, setting data sources and mappers first.

        ``index``, ``value``, ``index_mapper`` and ``value_mapper`` are set
        before all other traits with change notification disabled.  Their
        change events are then routed to ``_either_data_changed``,
        ``_either_metadata_changed`` and ``_mapper_updated_handler``.
        """
        # Handling the setting/initialization of these traits manually because
        # they should be initialized in a certain order.
        early_traits = {}
        for trait_name in ("index", "value", "index_mapper", "value_mapper"):
            if trait_name in kwtraits:
                early_traits[trait_name] = kwtraits.pop(trait_name)
        # trait_set is the non-deprecated spelling of HasTraits.set.
        self.trait_set(trait_change_notify=False, **early_traits)
        AbstractPlotRenderer.__init__(self, **kwtraits)

        # Wire the index side, then the value side (same order as before).
        for source, mapper in ((self.index, self.index_mapper),
                               (self.value, self.value_mapper)):
            if source is not None:
                source.on_trait_change(self._either_data_changed,
                                       "data_changed")
                source.on_trait_change(self._either_metadata_changed,
                                       "metadata_changed")
            if mapper:
                mapper.on_trait_change(self._mapper_updated_handler,
                                       "updated")

        # If we are not resizable, we will not get a bounds update upon layout,
        # so we have to manually update our mappers
        if self.resizable == "":
            self._update_mappers()

    def hittest(self, screen_pt, threshold=7.0, return_distance=False):
        """ Performs proximity testing between a given screen point and the
        plot.

        Parameters
        ----------
        screen_pt : (x,y)
            A point to test.
        threshold : float
            Optional maximum screen space distance (pixels) between
            *screen_pt* and the plot.
        return_distance : Boolean
            If True, also return the distance.

        Returns
        -------
        If self.hittest_type is 'point', the screen coordinates of the
        closest data point as a tuple (x, y).

        If self.hittest_type is 'line', the screen endpoints of the line
        segment closest to *screen_pt*, as a flat tuple (x1, y1, x2, y2).

        If *return_distance* is True, the distance d between *screen_pt* and
        that point/segment (in screen coordinates) is appended, giving
        (x, y, d) or (x1, y1, x2, y2, d).

        If *screen_pt* does not fall within *threshold* of the plot, then this
        method returns None.

        Raises
        ------
        ValueError
            If hittest_type is neither 'point' nor 'line'.
        """
        if self.hittest_type == "point":
            result = self.get_closest_point(screen_pt, threshold)
        elif self.hittest_type == "line":
            result = self.get_closest_line(screen_pt, threshold)
        else:
            raise ValueError("Unknown hittest type '%s'" % self.hittest_type)

        if result is None:
            return None
        # Both helpers put the distance last; drop it unless requested.
        return result if return_distance else result[:-1]

    def get_closest_point(self, screen_pt, threshold=7.0):
        """ Find the data point nearest to *screen_pt* in screen space.

        Only the data points are considered, not the segments between them;
        use get_closest_line() for segment proximity.

        Parameters
        ----------
        screen_pt : (x,y)
            The screen-space point to test.
        threshold : float
            Optional maximum screen-space distance (pixels) between
            *screen_pt* and a data point.  A value of 0.0 disables the
            threshold test and returns the nearest point.

        Returns
        -------
        (x, y, distance) for the data point nearest to *screen_pt*, or None
        when no data point lies within *threshold*.
        """
        ndx = self.map_index(screen_pt, threshold)
        if ndx is None:
            return None
        px = self.x_mapper.map_screen(self.index.get_data()[ndx])
        py = self.y_mapper.map_screen(self.value.get_data()[ndx])
        dx = px - screen_pt[0]
        dy = py - screen_pt[1]
        return (px, py, sqrt(dx**2 + dy**2))

    def get_closest_line(self, screen_pt, threshold=7.0):
        """ Tests for proximity in screen-space against lines connecting the
        points in this plot's dataset.

        Parameters
        ----------
        screen_pt : (x,y)
            A point to test.
        threshold : float
            Optional maximum screen space distance (pixels) between
            the line and the plot.  If 0.0, then the method returns the closest
            line regardless of distance from the plot.

        Returns
        -------
        (x1, y1, x2, y2, dist) of the endpoints of the line segment
        closest to *screen_pt*, in screen coordinates.  The *dist* element is
        the distance from *screen_pt* to the line as computed by
        point_line_distance().  If there is only a single point in the
        renderer's data, then the method returns the same point twice and
        *dist* is the distance to that point.

        Returns None if *screen_pt* is outside the index range or if the
        distance exceeds a non-zero *threshold*.
        """
        ndx = self.map_index(screen_pt, threshold=0.0)
        if ndx is None:
            return None

        index_data = self.index.get_data()
        value_data = self.value.get_data()

        def to_screen(i):
            # map_screen honours self.orientation when building (x, y).
            return self.map_screen(array([[index_data[i], value_data[i]]]))[0]

        x, y = to_screen(ndx)

        datalen = len(index_data)
        if datalen == 1:
            # Only one point: measure the plain distance to it.
            dist = sqrt((x - screen_pt[0])**2 + (y - screen_pt[1])**2)
            if threshold == 0.0 or dist <= threshold:
                return (x, y, x, y, dist)
            return None

        # Pick the neighbour on the side of screen_pt along the index axis
        # (screen x for horizontal plots, screen y for vertical ones); at
        # either end of the data there is only one neighbour to choose.
        if ndx == 0:
            ndx2 = 1
        elif ndx == datalen - 1:
            ndx2 = ndx - 1
        else:
            if self.orientation == "h":
                past_point = screen_pt[0] >= x
            else:
                past_point = screen_pt[1] >= y
            ndx2 = ndx + 1 if past_point else ndx - 1

        x2, y2 = to_screen(ndx2)
        dist = point_line_distance(screen_pt, (x, y), (x2, y2))
        if threshold == 0.0 or dist <= threshold:
            return (x, y, x2, y2, dist)
        return None

    #------------------------------------------------------------------------
    # AbstractPlotRenderer interface
    #------------------------------------------------------------------------

    def map_screen(self, data_array):
        """ Convert an Nx2 array of (index, value) data points to screen space.

        Implements the AbstractPlotRenderer interface.  The result is an Nx2
        array of (x, y) screen coordinates.  For a non-horizontal
        orientation, the value coordinate becomes x and the index coordinate
        becomes y.  An empty input yields an empty list.
        """
        if len(data_array) == 0:
            return []

        index_vals, value_vals = transpose(data_array)
        screen_index = self.index_mapper.map_screen(index_vals)
        screen_value = self.value_mapper.map_screen(value_vals)

        if self.orientation != "h":
            # Vertical plots: index runs along screen y.
            screen_index, screen_value = screen_value, screen_index
        return transpose(array((screen_index, screen_value)))

    def map_data(self, screen_pt, all_values=False):
        """ Convert a screen-space point into the plot's data space.

        Implements the AbstractPlotRenderer interface.

        The screen coordinate lying along the index axis (x, or y when the
        orientation is 'v') goes through the index mapper.  If *all_values*
        is True, an array ``(index, value)`` is returned, with the other
        coordinate going through the value mapper.  Otherwise only the index
        value is returned.
        """
        screen_x, screen_y = screen_pt
        if self.orientation == 'v':
            index_coord, value_coord = screen_y, screen_x
        else:
            index_coord, value_coord = screen_x, screen_y

        index_val = self.index_mapper.map_data(index_coord)
        if not all_values:
            return index_val
        return array((index_val, self.value_mapper.map_data(value_coord)))

    def map_index(self,
                  screen_pt,
                  threshold=2.0,
                  outside_returns_none=True,
                  index_only=False):
        """ Maps a screen space point to an index into the plot's index array(s).

        Implements the AbstractPlotRenderer interface.

        Parameters
        ----------
        screen_pt :
            Screen space point

        threshold : float
            Maximum distance from screen space point to plot data point.
            A value of 0.0 means no threshold (any distance will do).

        outside_returns_none : bool
            If True, a screen space point outside the data range returns None.
            Otherwise, it returns either 0 (outside the lower range) or
            the last index (outside the upper range)

        index_only : bool
            If True, the threshold is measured on the absolute screen
            distance along the index axis only (screen x for horizontal
            plots, screen y for vertical ones).  Otherwise it is measured as
            the Euclidean distance between the (x,y) screen coordinates.

        Returns
        -------
        The index of the nearest data point, or None if there is no data,
        the point is rejected by the range/threshold tests, or the nearest
        point contains NaN (only checked when threshold testing is on).
        """
        data_pt = self.map_data(screen_pt)
        index_range = self.index_mapper.range
        if outside_returns_none and (data_pt < index_range.low or
                                     data_pt > index_range.high):
            return None

        index_data = self.index.get_data()
        value_data = self.value.get_data()
        if len(value_data) == 0 or len(index_data) == 0:
            return None

        try:
            # find the closest point to data_pt in index_data
            ndx = reverse_map_1d(index_data, data_pt, self.index.sort_order)
        except IndexError:
            # reverse_map_1d raises this when data_pt lies outside the range
            # of values in index_data.
            if outside_returns_none:
                return None
            return 0 if data_pt < index_data[0] else len(index_data) - 1

        if threshold == 0.0:
            # Don't do any threshold testing
            return ndx

        x = index_data[ndx]
        y = value_data[ndx]
        if isnan(x) or isnan(y):
            return None

        # Pass x,y to map_screen as a 1x2 array, its preferred format.  This
        # makes it robust against differences in the map_screen methods of
        # logmapper and linearmapper when passed a scalar.
        sx, sy = self.map_screen(array([[x, y]])).T

        if index_only:
            # Distance along the index axis only; abs() so points on either
            # side of the data point are treated alike.
            if self.orientation == "h":
                delta = screen_pt[0] - sx
            else:
                delta = screen_pt[1] - sy
            if abs(delta) < threshold:
                return ndx
        elif (screen_pt[0] - sx)**2 + (screen_pt[1] - sy)**2 < \
                threshold * threshold:
            return ndx
        return None

    def get_screen_points(self):
        """ Return the screen-space points currently visible.

        Intended for use with overlays.  _gather_points() is always called
        first.  When ``use_downsampling`` is set, the result comes from
        _downsample(); otherwise the cached data points are mapped to screen
        space.
        """
        self._gather_points()
        if not self.use_downsampling:
            return self.map_screen(self._cached_data_pts)
        # NOTE: the BaseXYPlot version of _downsample doesn't actually
        # downsample anything.
        return self._downsample()

    #------------------------------------------------------------------------
    # PlotComponent interface
    #------------------------------------------------------------------------

    def _draw_plot(self, gc, view_bounds=None, mode="normal"):
        """ Render the 'plot' layer by delegating to _draw_component().
        """
        self._draw_component(gc, view_bounds, mode)

    def _draw_component(self, gc, view_bounds=None, mode="normal"):
        """ Render the current screen points onto *gc* via self._render().

        *view_bounds* and *mode* are accepted but not used here.
        """
        # Logically this belongs in _draw_plot(); it is kept separate for
        # backwards compatibility with non-draw-order code.
        self._render(gc, self.get_screen_points())

    def _draw_default_axes(self, gc):
        """ Draw the origin axes (the lines at 0) when they are in view.

        Does nothing unless ``origin_axis_visible`` is set.  For each of the
        index and value ranges that straddles zero (low < 0 < high), a line
        is stroked at 0 on that axis across the full extent of the other
        range.  The line uses the origin-axis colour and width, without
        dashes.
        """
        if not self.origin_axis_visible:
            return

        index_range = self.index_mapper.range
        value_range = self.value_mapper.range

        with gc:
            gc.set_stroke_color(self.origin_axis_color_)
            gc.set_line_width(self.origin_axis_width)
            gc.set_line_dash(None)

            # Pair each range with the data-space endpoints of its zero line.
            # Iterating explicit pairs, rather than testing
            # `rng == index_range`, keeps both lines correct even when the two
            # mappers share a single range object.  It also avoids shadowing
            # the builtin `range`.
            axes = (
                (index_range,
                 [[0.0, value_range.low], [0.0, value_range.high]]),
                (value_range,
                 [[index_range.low, 0.0], [index_range.high, 0.0]]),
            )
            for data_range, endpoints in axes:
                if not (data_range.low < 0 and data_range.high > 0):
                    continue
                start, end = self.map_screen(array(endpoints))
                start = around(start)
                end = around(end)
                gc.move_to(int(start[0]), int(start[1]))
                gc.line_to(int(end[0]), int(end[1]))
                gc.stroke_path()

    def _post_load(self):
        """ After loading, re-sync the mappers and invalidate every cache.
        """
        super(BaseXYPlot, self)._post_load()
        self._update_mappers()
        self.invalidate_draw()
        self._cache_valid = self._screen_cache_valid = False

    def _update_subdivision(self):
        """ Subdivision update hook; intentionally a no-op in this class.
        """

    #------------------------------------------------------------------------
    # Properties
    #------------------------------------------------------------------------

    def _get_index_range(self):
        """ Getter for ``index_range``: the index mapper's data range. """
        mapper = self.index_mapper
        return mapper.range

    def _set_index_range(self, val):
        """ Setter for ``index_range``: store *val* on the index mapper. """
        mapper = self.index_mapper
        mapper.range = val

    def _get_value_range(self):
        """ Getter for ``value_range``: the value mapper's data range. """
        mapper = self.value_mapper
        return mapper.range

    def _set_value_range(self, val):
        """ Setter for ``value_range``: store *val* on the value mapper. """
        mapper = self.value_mapper
        mapper.range = val

    def _get_x_mapper(self):
        """ Mapper laid along screen x.

        The index mapper for horizontal ('h') plots, the value mapper
        otherwise.
        """
        horizontal = self.orientation == "h"
        return self.index_mapper if horizontal else self.value_mapper

    def _get_y_mapper(self):
        """ Mapper laid along screen y.

        The value mapper for horizontal ('h') plots, the index mapper
        otherwise.
        """
        horizontal = self.orientation == "h"
        return self.value_mapper if horizontal else self.index_mapper

    def _get_hgrid(self):
        """ Return the first horizontal PlotGrid, or None if there is none.

        Underlays are searched before overlays.
        """
        matches = (obj for obj in self.underlays + self.overlays
                   if isinstance(obj, PlotGrid)
                   and obj.orientation == "horizontal")
        return next(matches, None)

    def _get_vgrid(self):
        """ Return the first vertical PlotGrid, or None if there is none.

        Underlays are searched before overlays.
        """
        matches = (obj for obj in self.underlays + self.overlays
                   if isinstance(obj, PlotGrid)
                   and obj.orientation == "vertical")
        return next(matches, None)

    def _get_x_axis(self):
        """ Return the first PlotAxis oriented "bottom" or "top", else None.

        Underlays are searched before overlays.
        """
        sides = ("bottom", "top")
        matches = (obj for obj in self.underlays + self.overlays
                   if isinstance(obj, PlotAxis) and obj.orientation in sides)
        return next(matches, None)

    def _get_y_axis(self):
        """ Return the first PlotAxis oriented "left" or "right", else None.

        Underlays are searched before overlays.
        """
        sides = ("left", "right")
        matches = (obj for obj in self.underlays + self.overlays
                   if isinstance(obj, PlotAxis) and obj.orientation in sides)
        return next(matches, None)

    def _get_labels(self):
        """ Return every PlotLabel among the underlays then the overlays. """
        return [obj for obj in self.underlays + self.overlays
                if isinstance(obj, PlotLabel)]

    #------------------------------------------------------------------------
    # Event handlers
    #------------------------------------------------------------------------

    def _update_mappers(self):
        """ Push the component's current screen extent into the mappers.

        The mapper laid along screen x (the value mapper when orientation is
        'v', the index mapper otherwise) receives (x, x2), reversed unless
        ``origin`` contains "left".  The other mapper receives (y, y2),
        reversed unless ``origin`` contains "bottom".  Draw and point caches
        are then invalidated.
        """
        if self.orientation == "v":
            horiz, vert = self.value_mapper, self.index_mapper
        else:
            horiz, vert = self.index_mapper, self.value_mapper

        x_lo, x_hi = self.x, self.x2
        y_lo, y_hi = self.y, self.y2
        origin = self.origin

        x_bounds = (x_lo, x_hi) if "left" in origin else (x_hi, x_lo)
        y_bounds = (y_lo, y_hi) if "bottom" in origin else (y_hi, y_lo)
        horiz.screen_bounds = x_bounds
        vert.screen_bounds = y_bounds

        self.invalidate_draw()
        self._cache_valid = False
        self._screen_cache_valid = False

    def _bounds_changed(self, old, new):
        """ Let the base class handle the new bounds, then re-sync mappers.
        """
        parent = super(BaseXYPlot, self)
        parent._bounds_changed(old, new)
        self._update_mappers()

    def _bounds_items_changed(self, event):
        """ Let the base class handle the item change, then re-sync mappers.
        """
        parent = super(BaseXYPlot, self)
        parent._bounds_items_changed(event)
        self._update_mappers()

    def _position_changed(self):
        """ Re-sync the mappers' screen bounds after a position change. """
        self._update_mappers()

    def _position_items_changed(self):
        """ Re-sync the mappers' screen bounds after an in-place position
        edit.
        """
        self._update_mappers()

    def _orientation_changed(self):
        """ Re-assign screen bounds, since the x/y mapper roles swap. """
        self._update_mappers()

    def _index_changed(self, old, new):
        """ Move the data/metadata listeners from the old index source to the
        new one, then treat it as a data change.
        """
        listeners = ((self._either_data_changed, "data_changed"),
                     (self._either_metadata_changed, "metadata_changed"))
        if old is not None:
            for handler, trait_name in listeners:
                old.on_trait_change(handler, trait_name, remove=True)
        if new is not None:
            for handler, trait_name in listeners:
                new.on_trait_change(handler, trait_name)
        self._either_data_changed()

    def _either_data_changed(self):
        """ Invalidate the draw and point caches, then ask for a redraw. """
        self.invalidate_draw()
        self._cache_valid = self._screen_cache_valid = False
        self.request_redraw()

    def _either_metadata_changed(self):
        """ Metadata change on the index or value source.

        Deliberately ignored by default; a subclass may override this.
        """

    def _value_changed(self, old, new):
        """ Move the data/metadata listeners from the old value source to the
        new one, then treat it as a data change.
        """
        listeners = ((self._either_data_changed, "data_changed"),
                     (self._either_metadata_changed, "metadata_changed"))
        if old is not None:
            for handler, trait_name in listeners:
                old.on_trait_change(handler, trait_name, remove=True)
        if new is not None:
            for handler, trait_name in listeners:
                new.on_trait_change(handler, trait_name)
        self._either_data_changed()

    def _origin_changed(self, old, new):
        """ Flip mapper screen positions when the origin corner moves.

        Origin strings are split on whitespace.  The second word is the
        horizontal side and the first word the vertical side (e.g.
        "bottom left").  When a side changes, the low_pos/high_pos of the
        corresponding mapper are swapped.  The draw and screen-point caches
        are then invalidated.
        """
        old_parts, new_parts = old.split(), new.split()
        # Horizontal side changed (left <-> right)?
        if old_parts[1] != new_parts[1]:
            xm = self.x_mapper
            xm.low_pos, xm.high_pos = xm.high_pos, xm.low_pos
        # Vertical side changed (top <-> bottom)?
        if old_parts[0] != new_parts[0]:
            ym = self.y_mapper
            ym.low_pos, ym.high_pos = ym.high_pos, ym.low_pos

        self.invalidate_draw()
        self._screen_cache_valid = False

    def _index_mapper_changed(self, old, new):
        """ Rewire mapper listeners and announce the derived-mapper change.

        The index mapper backs ``x_mapper`` for horizontal plots and
        ``y_mapper`` otherwise, so that property's change is fired.
        """
        self._either_mapper_changed(self, "index_mapper", old, new)
        derived = "x_mapper" if self.orientation == "h" else "y_mapper"
        self.trait_property_changed(derived, old, new)

    def _value_mapper_changed(self, old, new):
        """ Rewire mapper listeners and announce the derived-mapper change.

        The value mapper backs ``y_mapper`` for horizontal plots and
        ``x_mapper`` otherwise, so that property's change is fired.
        """
        self._either_mapper_changed(self, "value_mapper", old, new)
        derived = "y_mapper" if self.orientation == "h" else "x_mapper"
        self.trait_property_changed(derived, old, new)

    def _either_mapper_changed(self, obj, name, old, new):
        """ Move the 'updated' listener from the *old* mapper to the *new*.

        *obj* and *name* are accepted but not used.  The draw and
        screen-point caches are invalidated afterwards.
        """
        handler = self._mapper_updated_handler
        if old is not None:
            old.on_trait_change(handler, "updated", remove=True)
        if new is not None:
            new.on_trait_change(handler, "updated")
        self.invalidate_draw()
        self._screen_cache_valid = False

    def _mapper_updated_handler(self):
        """ React to a mapper's 'updated' event: drop caches and redraw. """
        self._cache_valid = self._screen_cache_valid = False
        self.invalidate_draw()
        self.request_redraw()

    def _visible_changed(self, old, new):
        """ Flag that a layout is needed when the plot becomes visible. """
        if not new:
            return
        self._layout_needed = True

    def _bgcolor_changed(self):
        """ A new background colour invalidates the cached drawing. """
        self.invalidate_draw()

    def _use_subdivision_changed(self, old, new):
        """ Set up subdivision when it is switched on (off is a no-op). """
        if not new:
            return
        self._set_up_subdivision()

    #------------------------------------------------------------------------
    # Persistence
    #------------------------------------------------------------------------

    def __getstate__(self):
        """ Return the pickle state minus the transient point caches. """
        state = super(BaseXYPlot, self).__getstate__()
        transient = ('_cache_valid', '_cached_data_pts',
                     '_screen_cache_valid', '_cached_screen_pts')
        for key in transient:
            state.pop(key, None)
        return state

    def __setstate__(self, state):
        """ Restore pickled state and rebuild what __getstate__ dropped.

        Re-attaches the 'data_changed' listener to the index and value
        sources, invalidates the draw and point caches, and re-syncs the
        mappers.
        """
        super(BaseXYPlot, self).__setstate__(state)
        # NOTE(review): unlike _index_changed/_value_changed, the
        # 'metadata_changed' listeners are not re-attached here; confirm
        # this is intentional.
        for source in (self.index, self.value):
            if source is not None:
                source.on_trait_change(self._either_data_changed,
                                       "data_changed")

        self.invalidate_draw()
        self._cache_valid = False
        self._screen_cache_valid = False
        self._update_mappers()
class DataFramePlotManagerExporter(HasStrictTraits):
    """ Exporter of a DataFramePlotManager content to various formats.

    Depending on ``export_format``, the contained plots are written in one of
    three ways.  ``to_folder`` writes them as separate image files in a
    folder.  ``to_pptx`` writes them as slides of a PowerPoint presentation.
    ``to_vega`` writes a JSON description of the plots.  The data behind the
    plots can optionally be exported too (see ``export_data``).
    """
    #: Plot manager whose contained plots are exported
    df_plotter = Instance("pybleau.app.api.DataFramePlotManager")

    #: Target format of the export
    export_format = Enum([IMG_FORMAT, PPT_FORMAT, VEGA_FORMAT])

    #: Whether to, and how to, export the data behind the plots
    export_data = Enum(values="_export_data_options")

    #: When exporting data (EXPORT_YES), also export each plot's own data
    #: next to the data source
    export_each_plot_data = Bool

    #: The ways to export the data depend on the target format
    _export_data_options = Property(List(Str), depends_on="export_format")

    #: Name (without extension) of the exported data file
    data_filename = Str(EXTERNAL_DATA_FNAME)

    #: File format (extension) of the exported data
    data_format = Enum([".csv", ".xlsx", ".h5"])

    #: Whether to skip plots whose visible flag is off
    skip_hidden = Bool(True)

    #: Target directory (image export)
    target_dir = Directory

    #: Target file (PowerPoint and JSON exports)
    target_file = File

    #: Class used to build the export dialog's view
    view_klass = Any(View)

    #: Whether to show dialogs to the user and open the exported result
    interactive = Bool(True)

    # Image format specific parameters ----------------------------------------

    image_format = Enum(["PNG", "JPG", "BMP", "TIFF", "EPS"])

    image_dpi = Range(low=50, high=100, value=72)

    filename_from_title = Bool(True)

    # Output file parameters --------------------------------------------------

    overwrite_file_if_exists = Bool(False)

    #: Indentation used when writing the JSON description file
    json_index = Range(low=0, high=4, value=2)

    #: Whether there are enough plots to warn about inline data size
    _many_plots = Bool

    # PPT format specific parameters ------------------------------------------

    presentation_title = Str("New Presentation")

    presentation_subtitle = Str

    def traits_view(self):
        """ Build the export-parameters dialog.

        Groups are shown or hidden depending on ``export_format`` and
        ``export_data``.
        """
        is_ppt = "export_format == '{}'".format(PPT_FORMAT)
        is_vega = "export_format == '{}'".format(VEGA_FORMAT)
        is_img = "export_format=='{}'".format(IMG_FORMAT)

        num_plots = len(self.df_plotter.contained_plots)
        inline_warning = "Warning: exporting data inline can use a lot of " \
                         "file storage because the same data is stored inside"\
                         " each of the {} plots.".format(num_plots)
        inline_visible_when = "_many_plots and export_data=='{}'".format(
            EXPORT_INLINE)
        dat_file_visible_when = "export_data in ['{}', '{}']".format(
            EXPORT_YES, EXPORT_SEPARATE)

        view = self.view_klass(
            VGroup(
                HGroup(Spring(), Item("export_format"), Spring()),
                VGroup(Item('target_dir', visible_when=is_img),
                       HGroup(Item('target_file',
                                   editor=FileEditor(dialog_style='save')),
                              Item("overwrite_file_if_exists",
                                   label="Overwrite report file if exists?"),
                              visible_when=is_ppt + " or " + is_vega),
                       Item("overwrite_file_if_exists",
                            label="Overwrite image files if exist?",
                            visible_when=is_img),
                       Item('skip_hidden', label="Skip hidden plots"),
                       label="General Parameters",
                       show_border=True),
                VGroup(
                    HGroup(
                        Item('export_data', label="Export data?"),
                        Item("export_each_plot_data",
                             label="Export each plot's data?",
                             visible_when="export_data=='{}'".format(
                                 EXPORT_YES))  # noqa
                    ),
                    HGroup(Item("data_filename"),
                           Item("data_format"),
                           visible_when=dat_file_visible_when),
                    HGroup(Label(inline_warning),
                           visible_when=inline_visible_when),
                    label="Data Parameters",
                    show_border=True),
                VGroup(HGroup(Item('image_format'), Item('image_dpi')),
                       Item("filename_from_title"),
                       visible_when=is_img,
                       label="Image Parameters",
                       show_border=True),
                VGroup(Item("presentation_title"),
                       Item("presentation_subtitle"),
                       label="Powerpoint Parameters",
                       show_border=True,
                       visible_when=is_ppt),
            ),
            buttons=OKCancelButtons,
            title="Export plots",
            resizable=True,
            width=600)
        return view

    # Public interface --------------------------------------------------------

    def export(self):
        """ Launch view to select parameters and export content to file.

        In interactive mode, a modal dialog is shown first, and nothing is
        exported if it is cancelled.  After a successful export the target is
        opened, and export failures are shown in an error dialog.  Errors
        raised while exporting are logged rather than propagated.
        """
        if self.interactive:
            ui = self.edit_traits(kind="livemodal")
            if not ui.result:
                return

        msg = f"Exporting plot content to {self.export_format}."
        logger.log(ACTION_LEVEL, msg)

        to_meth = getattr(self, METHOD_MAP[self.export_format])
        try:
            to_meth()
            if self.interactive:
                if self.export_format != PPT_FORMAT:
                    target = self.target_dir
                else:
                    target = self.target_file
                open_file(target)

        except Exception as e:
            msg = f"Failed to export the plot list. Error was {e}."
            logger.exception(msg)
            if self.interactive:
                error(None, msg)

    def to_folder(self, **kwargs):
        """ Export all plots as separate images files PNG, JPEG, ....

        Images are written to ``target_dir`` (created if missing).  They are
        named ``{i}_{title}.{ext}`` when ``filename_from_title`` is set and
        ``plot_{i}.{ext}`` otherwise.  Extra keyword arguments are forwarded
        to ``save_plot_to_file``.

        Raises
        ------
        IOError
            If an image file already exists and ``overwrite_file_if_exists``
            is off.
        """
        if not isdir(self.target_dir):
            os.makedirs(self.target_dir)

        plot_list = self.df_plotter.contained_plots

        if self.export_data == EXPORT_YES:
            if not self.export_each_plot_data:
                self._export_data_source_to_file(target=self.target_dir)
            else:
                fname = self.data_filename + self.data_format
                data_path = join(self.target_dir, fname)
                self._export_plot_data_to_file(plot_list, data_path=data_path)

        if self.filename_from_title:
            filename_patt = "{i}_{title}.{ext}"
        else:
            filename_patt = "plot_{i}.{ext}"

        for i, desc in enumerate(plot_list):
            if self.skip_hidden and not desc.visible:
                continue

            if self.filename_from_title:
                title = string2filename(desc.plot_title)
            else:
                title = ""

            filename = filename_patt.format(i=i,
                                            ext=self.image_format,
                                            title=title)
            filepath = join(self.target_dir, filename)
            if isfile(filepath) and not self.overwrite_file_if_exists:
                msg = "Target image file path specified already exists" \
                      ": {}. Move the file or select the 'overwrite' checkbox."
                msg = msg.format(filepath)
                # Not inside an except block: logger.exception would attach
                # an empty traceback.
                logger.error(msg)
                raise IOError(msg)

            save_plot_to_file(desc.plot,
                              filepath=filepath,
                              dpi=self.image_dpi,
                              **kwargs)

    def to_pptx(self, **kwargs):
        """ Export all plots as a PPTX presentation with a plot per slide.

        The presentation is written to ``target_file`` (".pptx" is appended
        if missing).  It starts with a title slide, followed by one slide per
        exported plot.  Extra keyword arguments are forwarded to
        ``save_plot_to_file``.

        Raises
        ------
        IOError
            If the target file exists and ``overwrite_file_if_exists`` is
            off.
        """
        # Protect imports so pptx remains an optional import
        from pybleau.reporting.pptx_utils import image_to_slide, Presentation,\
            title_slide

        # abspath keeps a bare file name (no directory part) from yielding
        # the empty string, which os.makedirs cannot create.
        target_dir = abspath(dirname(self.target_file))
        data_fname = self.data_filename + self.data_format
        target_data_file = join(target_dir, data_fname)
        if not isdir(target_dir):
            os.makedirs(target_dir)

        plot_list = self.df_plotter.contained_plots

        if self.export_data == EXPORT_YES:
            if not self.export_each_plot_data:
                self._export_data_source_to_file(target=target_data_file)
            else:
                self._export_plot_data_to_file(plot_list,
                                               data_path=target_data_file)

        if splitext(self.target_file)[1] != ".pptx":
            self.target_file += ".pptx"

        if isfile(self.target_file) and not self.overwrite_file_if_exists:
            msg = "Target description file path specified already exists" \
                  ": {}. Move the file or select the 'overwrite' checkbox."
            msg = msg.format(self.target_file)
            logger.error(msg)
            raise IOError(msg)

        presentation = Presentation()
        title_slide(presentation,
                    title_text=self.presentation_title,
                    sub_title_text=self.presentation_subtitle)

        # mkstemp returns an open OS-level descriptor along with the path:
        # close it right away (it used to leak, along with the suffix-less
        # temp file), and reuse the path as a scratch image for every slide.
        fd, img_path = mkstemp(suffix=".png")
        os.close(fd)
        try:
            for i, desc in enumerate(plot_list):
                if self.skip_hidden and not desc.visible:
                    continue

                save_plot_to_file(desc.plot,
                                  filepath=img_path,
                                  dpi=self.image_dpi,
                                  **kwargs)
                title = "plot_{}".format(i)
                image_to_slide(presentation, img_path=img_path,
                               slide_title=title)
        finally:
            # Remove the scratch image even if rendering a plot failed.
            if isfile(img_path):
                os.remove(img_path)

        presentation.save(self.target_file)

    def to_vega(self):
        """ Export plot content to dict (option. file) in Vega-Lite format.

        This is useful to recreate this plotter's content, in a separate
        process, using any plotting library. See the pybleau.reporting for
        tools leveraging this export.

        Returns
        -------
        dict:
            Description of the content of the data plotter.

        Raises
        ------
        IOError
            If the target file exists and ``overwrite_file_if_exists`` is
            off.

        TODO: Add filter information for each plot.
        """
        content = {CONTENT_KEY: [], DATASETS_KEY: {}}

        if splitext(self.target_file)[1] != ".json":
            self.target_file += ".json"

        export_dir = abspath(dirname(self.target_file))

        if not isdir(export_dir):
            os.makedirs(export_dir)

        if self.export_data != EXPORT_NO:
            df_data = {}
            if self.export_data == EXPORT_SEPARATE:
                df_data[DATA_FILE_KEY] = self.data_filename + self.data_format
                if self.data_format == ".h5":
                    # NOTE(review): the data is written under the default key
                    # (DEFAULT_DATASET_NAME) while DATA_KEY is recorded here;
                    # confirm the two constants agree.
                    df_data[DATA_FILE_KEY_KEY] = DATA_KEY
                    comp_info = dict(complib="blosc", complevel=9)
                    df_data[DATA_FILE_COMP_KEY] = comp_info
                    self._export_data_source_to_file(export_dir, **comp_info)
                else:
                    self._export_data_source_to_file(export_dir)
                # Record where the separate data file lives; this description
                # used to be built and then silently dropped.
                content[DATASETS_KEY][DEFAULT_DATASET_NAME] = df_data

            elif self.export_data == EXPORT_IN_FILE:
                df_data[DATA_KEY] = df_to_vega(self.df_plotter.data_source)
                df_data[IDX_NAME_KEY] = self.df_plotter.data_source.index.name
                content[DATASETS_KEY][DEFAULT_DATASET_NAME] = df_data

        for desc in self.df_plotter.contained_plots:
            if self.export_data == EXPORT_INLINE:
                plot_desc = chaco2vega(desc.plot_config, export_data="inline")
            elif self.export_data == EXPORT_IN_FILE:
                plot_desc = chaco2vega(desc.plot_config,
                                       export_data=DEFAULT_DATASET_NAME)
            else:
                plot_desc = chaco2vega(desc.plot_config, export_data=False)

            content[CONTENT_KEY].append(plot_desc)

        if self.target_file:
            if isfile(self.target_file) and not self.overwrite_file_if_exists:
                msg = f"Target description file path specified already " \
                    f"exists: {self.target_file}. Move the file or select " \
                    f"the 'overwrite' checkbox."
                logger.error(msg)
                raise IOError(msg)

            # Use a context manager so the file is flushed and closed.
            with open(self.target_file, "w") as f:
                json.dump(content, f, indent=self.json_index)

        return content

    # Private interface -------------------------------------------------------

    def _export_data_source_to_file(self,
                                    target,
                                    key=DEFAULT_DATASET_NAME,
                                    **kwargs):
        """ Export the plotter's data source to file.

        Parameters
        ----------
        target : str or pd.ExcelWriter or pd.HDFStore
            Path to the file or folder or store object to write the source data
            to. A folder path is completed with ``data_filename`` and
            ``data_format``.

        key : str
            Tag for the dataset. Used only for Excel files (sheet name) or HDF5
            (dataset node name).

        kwargs
            Forwarded to ``to_excel``/``to_hdf`` (ignored for CSV).

        Returns
        -------
        The (possibly completed) target written to.

        Raises
        ------
        NotImplementedError
            If ``data_format`` is not supported.
        """
        data_format = self.data_format

        if isinstance(target, string_types) and isfile(target):
            msg = "Target data file path specified already exists: {}. It " \
                  "will be overwritten!".format(target)
            logger.info(msg)
        elif isinstance(target, string_types) and isdir(target):
            target = join(target, self.data_filename + data_format)

        df = self.df_plotter.data_source
        if data_format == ".csv":
            df.to_csv(target)
        elif data_format == ".xlsx":
            df.to_excel(target, sheet_name=key, **kwargs)
        elif data_format == ".h5":
            df.to_hdf(target, key=key, **kwargs)
        else:
            msg = "Format {} not implemented. Please report this issue."
            msg = msg.format(data_format)
            logger.error(msg)
            raise NotImplementedError(msg)

        return target

    def _export_plot_data_to_file(self, plot_list, data_path, **kwargs):
        """ Export the plots' PlotData to a file.

        Supported formats include zipped .csv, multi-tab .xlsx and multi-key
        HDF5. The plotter's data source is always exported first, under
        DEFAULT_DATASET_NAME.

        Parameters
        ----------
        plot_list : list
            List of PlotDescriptor instances, containing the plot whose data
            need to be exported.

        data_path : str
            Path to the data file to be generated. ``data_format`` is
            appended if the path has no extension.

        kwargs
            Forwarded to ``to_excel``/``to_hdf`` for the plot tables.
        """
        data_format = self.data_format

        if not splitext(data_path)[1]:
            data_path += data_format

        if isfile(data_path):
            msg = "Target data path specified already exists: {}. It will be" \
                  " overwritten.".format(data_path)
            logger.warning(msg)

        data_dir = dirname(data_path)

        # Excel and HDF5 outputs collect every table in a single file through
        # ONE writer object, which is closed in the finally clause below.
        # CSV output writes one file per table and zips them afterwards.
        writer = None
        if data_format == ".xlsx":
            writer = pd.ExcelWriter(data_path)
        elif data_format == ".h5":
            writer = pd.HDFStore(data_path)

        created_csv_files = []
        try:
            if data_format == ".csv":
                data_path = join(
                    data_dir,
                    string2filename(DEFAULT_DATASET_NAME) + ".csv")
                self._export_data_source_to_file(target=data_path)
            else:
                self._export_data_source_to_file(target=writer,
                                                 key=DEFAULT_DATASET_NAME)

            created_csv_files.append(data_path)
            for i, desc in enumerate(plot_list):
                df_dict = plot_data2dataframes(desc)
                for name, df in df_dict.items():
                    key = "plot_{}_{}".format(i, name)
                    if data_format == ".csv":
                        target_fpath = join(data_dir,
                                            string2filename(key) + ".csv")
                        df.to_csv(target_fpath)
                        created_csv_files.append(target_fpath)
                    elif data_format == ".xlsx":
                        df.to_excel(writer, sheet_name=key, **kwargs)
                    elif data_format == ".h5":
                        # Write through the open store instead of reopening
                        # the same file by path while it is held open.
                        df.to_hdf(writer, key=key, **kwargs)
                    else:
                        msg = "Data format {} not implemented. Please report" \
                              " this issue.".format(data_format)
                        logger.error(msg)
                        raise NotImplementedError(msg)
        finally:
            if writer is not None:
                writer.close()

        if data_format == ".csv" and len(created_csv_files) > 1:
            # zip up all csv files:
            data_path = join(data_dir, self.data_filename + ".zip")
            with ZipFile(data_path, "w") as f:
                for f_path in created_csv_files:
                    f.write(f_path, basename(f_path))

            for f_path in created_csv_files:
                os.remove(f_path)

        if self.interactive:
            msg = "Plot data is stored in: {}".format(data_path)
            information(None, msg)

    # Property getters/setters ------------------------------------------------

    @cached_property
    def _get__export_data_options(self):
        if self.export_format in {IMG_FORMAT, PPT_FORMAT}:
            return [EXPORT_NO, EXPORT_YES]
        elif self.export_format == VEGA_FORMAT:
            return [EXPORT_NO, EXPORT_SEPARATE, EXPORT_IN_FILE, EXPORT_INLINE]
        else:
            msg = "List of export data options not set for format {}."
            msg = msg.format(self.export_format)
            logger.error(msg)
            raise NotImplementedError(msg)

    # Traits initialization methods -------------------------------------------

    def _target_file_default(self):
        return join(self.target_dir, DEFAULT_EXPORT_FILENAME)

    def _target_dir_default(self):
        return expanduser("~")

    # Default for the private `_many_plots` trait; Traits resolves the
    # name-mangled `__<name>_default` spelling for private traits.
    def __many_plots_default(self):
        return len(self.df_plotter.contained_plots) >= 3
Beispiel #7
0
class ToolkitEditorFactory(EditorFactory):
    """ Editor factory for range editors.

    The concrete editor class for each style is chosen by the
    ``*_editor_class`` property getters from the range end points, whether
    the range is floating point, and the requested ``mode``.
    """

    # -------------------------------------------------------------------------
    #  Trait definitions:
    # -------------------------------------------------------------------------

    #: Number of columns when displayed as an enumeration
    cols = Range(1, 20)

    #: Is user input set on every keystroke?
    auto_set = Bool(True)

    #: Is user input set when the Enter key is pressed?
    enter_set = Bool(False)

    #: Label for the low end of the range
    low_label = Str()

    #: Label for the high end of the range
    high_label = Str()

    #: FIXME: This is supported only in the wx backend so far.
    #: The width of the low and high labels
    label_width = Int()

    #: The name of an [object.]trait that defines the low value for the range
    low_name = Str()

    #: The name of an [object.]trait that defines the high value for the range
    high_name = Str()

    #: Formatting string used to format value and labels
    format = Str("%s")

    #: Is the range for floating pointer numbers (vs. integers)?
    is_float = Bool(Undefined)

    #: Function to evaluate floats/ints when they are assigned to an object
    #: trait
    evaluate = Any()

    #: The object trait containing the function used to evaluate floats/ints
    evaluate_name = Str()

    #: Low end of range
    low = Property()

    #: High end of range
    high = Property()

    #: Display mode to use
    mode = Enum("auto", "slider", "xslider", "spinner", "enum", "text",
                "logslider")

    # -------------------------------------------------------------------------
    #  Traits view definition:
    # -------------------------------------------------------------------------

    traits_view = View([
        ["low", "high", "|[Range]"],
        ["low_label{Low}", "high_label{High}", "|[Range Labels]"],
        [
            "auto_set{Set automatically}",
            "enter_set{Set on enter key pressed}",
            "is_float{Is floating point range}",
            "-[Options]>",
        ],
        ["cols", "|[Number of columns for integer custom style]<>"],
    ])

    def init(self, handler=None):
        """ Performs any initialization needed after all constructor traits
            have been set.

        Parameters
        ----------
        handler : CTrait or trait handler, optional
            When given, each end point without a configured ``low_name`` /
            ``high_name`` is copied from the handler's ``_low`` / ``_high``
            attribute, which is evaluated first if it is a compiled code
            object. When omitted, unset end points default to 0.0 and 1.0.
        """
        if handler is not None:
            if isinstance(handler, CTrait):
                handler = handler.handler

            if self.low_name == "":
                if isinstance(handler._low, CodeType):
                    self.low = eval(handler._low)
                else:
                    self.low = handler._low

            if self.high_name == "":
                # Test ``_high`` itself: the high end may be a code object
                # independently of the low end (previously ``_low`` was
                # tested here by mistake).
                if isinstance(handler._high, CodeType):
                    self.high = eval(handler._high)
                else:
                    self.high = handler._high
        else:
            if (self.low is None) and (self.low_name == ""):
                self.low = 0.0

            if (self.high is None) and (self.high_name == ""):
                self.high = 1.0

    def _get_low(self):
        return self._low

    def _set_low(self, low):
        """ Store the low end (string values are cast to int/float) and keep
        ``low_label`` in sync unless the user customized it.
        """
        old_low = self._low
        self._low = low = self._cast(low)
        if self.is_float is Undefined:
            self.is_float = isinstance(low, float)

        # Only overwrite the label if it is empty or still shows the old value.
        if (self.low_label == "") or (self.low_label == str(old_low)):
            self.low_label = str(low)

    def _get_high(self):
        return self._high

    def _set_high(self, high):
        """ Store the high end (string values are cast to int/float) and keep
        ``high_label`` in sync unless the user customized it.
        """
        old_high = self._high
        self._high = high = self._cast(high)
        if self.is_float is Undefined:
            self.is_float = isinstance(high, float)

        # Only overwrite the label if it is empty or still shows the old value.
        if (self.high_label == "") or (self.high_label == str(old_high)):
            self.high_label = str(high)

    def _cast(self, value):
        """ Convert a string to int (preferred) or float; return any other
        value unchanged. Raises ValueError for non-numeric strings.
        """
        if not isinstance(value, str):
            return value

        try:
            return int(value)
        except ValueError:
            return float(value)

    # -- Private Methods ------------------------------------------------------

    def _get_low_high(self, ui):
        """ Returns the low and high values used to determine the initial range.

        End points that are unset but named via ``low_name``/``high_name``
        are looked up through ``named_value``. ``is_float`` is inferred from
        the first value found and defaults to True.

        Returns
        -------
        tuple
            ``(low, high, is_float)``
        """
        low, high = self.low, self.high

        if (low is None) and (self.low_name != ""):
            low = self.named_value(self.low_name, ui)
            if self.is_float is Undefined:
                self.is_float = isinstance(low, float)

        if (high is None) and (self.high_name != ""):
            high = self.named_value(self.high_name, ui)
            if self.is_float is Undefined:
                self.is_float = isinstance(high, float)

        if self.is_float is Undefined:
            self.is_float = True

        return (low, high, self.is_float)

    # -------------------------------------------------------------------------
    #  Property getters.
    # -------------------------------------------------------------------------
    def _get_simple_editor_class(self):
        """ Returns the editor class to use for a simple style.

        The type of editor depends on the type and extent of the range being
        edited:

        * One end of range is unspecified: RangeTextEditor
        * Integer range with extent > 1000000000: RangeTextEditor
        * **mode** is specified and not 'auto': editor corresponding to **mode**
        * Floating point range with extent > 100: LargeRangeSliderEditor
        * Integer range or floating point range with extent <= 100:
          SimpleSliderEditor
        * All other cases: SimpleSpinEditor
        """
        low, high, is_float = self._low_value, self._high_value, self.is_float

        if (low is None) or (high is None):
            return toolkit_object("range_editor:RangeTextEditor")

        if (not is_float) and (abs(high - low) > 1000000000):
            return toolkit_object("range_editor:RangeTextEditor")

        if self.mode != "auto":
            return toolkit_object("range_editor:SimpleEditorMap")[self.mode]

        if is_float and (abs(high - low) > 100):
            return toolkit_object("range_editor:LargeRangeSliderEditor")

        if is_float or (abs(high - low) <= 100):
            return toolkit_object("range_editor:SimpleSliderEditor")

        return toolkit_object("range_editor:SimpleSpinEditor")

    def _get_custom_editor_class(self):
        """ Creates a custom style of range editor

        The type of editor depends on the type and extent of the range being
        edited:

        * One end of range is unspecified: RangeTextEditor
        * **mode** is specified and not 'auto': editor corresponding to **mode**
        * Floating point range: Same as "simple" style
        * Integer range with extent > 15: Same as "simple" style
        * Integer range with extent <= 15: CustomEnumEditor

        """
        low, high, is_float = self._low_value, self._high_value, self.is_float
        if (low is None) or (high is None):
            return toolkit_object("range_editor:RangeTextEditor")

        if self.mode != "auto":
            return toolkit_object("range_editor:CustomEditorMap")[self.mode]

        if is_float or (abs(high - low) > 15):
            return self.simple_editor_class

        return toolkit_object("range_editor:CustomEnumEditor")

    def _get_text_editor_class(self):
        """Returns the editor class to use for a text style.
        """
        return toolkit_object("range_editor:RangeTextEditor")

    # -------------------------------------------------------------------------
    #  'Editor' factory methods:
    # -------------------------------------------------------------------------

    def simple_editor(self, ui, object, name, description, parent):
        """ Generates an editor using the "simple" style.
        Overridden to set the values of the _low_value, _high_value and
        is_float traits.

        """
        self._low_value, self._high_value, self.is_float = self._get_low_high(
            ui)
        return super(RangeEditor, self).simple_editor(ui, object, name,
                                                      description, parent)

    def custom_editor(self, ui, object, name, description, parent):
        """ Generates an editor using the "custom" style.
        Overridden to set the values of the _low_value, _high_value and
        is_float traits.

        """
        self._low_value, self._high_value, self.is_float = self._get_low_high(
            ui)
        return super(RangeEditor, self).custom_editor(ui, object, name,
                                                      description, parent)
Beispiel #8
0
class XYScatterOptions(BasePlotterOptions):
    """ Options for an X/Y scatter plot.

    The X ("index") and Y ("value") attributes are chosen from ``attrs``,
    which holds ``DATABASE_ATTRS`` for the database source or the column
    names of a CSV file (read with ``CSVColumnParser``) for the file source.
    """
    # update_needed = Event
    auto_refresh = Bool

    index_attr = Str('Ar40')
    value_attr = Str('Ar39')

    marker_size = Range(0.0, 10., 2.0)
    marker = MarkerTrait
    marker_color = ColorTrait

    attrs = Dict(DATABASE_ATTRS)
    index_error = Bool
    value_error = Bool

    index_end_caps = Bool
    value_end_caps = Bool

    index_nsigma = Enum(1, 2, 3)
    value_nsigma = Enum(1, 2, 3)

    index_time_units = Enum('h', 'm', 's', 'days')
    index_time_scalar = Property

    value_time_units = Enum('h', 'm', 's', 'days')
    value_time_scalar = Property
    fit = Enum([NULL_STR] + FIT_TYPES)

    datasource = Enum('Database', 'File')

    file_source_path = Str
    datasource_name = Property(depends_on='file_source_path')
    use_file_source = Property(depends_on='datasource')
    open_file_button = Button
    # Cached CSV parser and the path it was loaded from. The path is tracked
    # so that changing ``file_source_path`` forces a reload instead of
    # silently reusing the columns of the previous file.
    _parser = None
    _parser_path = None

    def get_marker_dict(self):
        """ Return the marker keyword arguments for the scatter renderer. """
        kw = dict(marker=self.marker,
                  marker_size=self.marker_size,
                  color=self.marker_color)
        return kw

    def get_parser(self):
        """ Return a CSVColumnParser loaded from ``file_source_path``.

        The parser is cached and reused until ``file_source_path`` changes.
        """
        path = self.file_source_path
        p = self._parser
        if p is None or self._parser_path != path:
            p = CSVColumnParser()
            p.load(path)
            self._parser = p
            self._parser_path = path
        return p

    def _load_hook(self):
        if self.use_file_source:
            self._load_file_source()

    def _load_file_source(self):
        """ Populate ``attrs`` from the CSV columns and select the first two
        columns as X and Y.
        """
        p = self.get_parser()
        keys = p.list_attributes()
        self.attrs = {
            ai: '{:02d}:{}'.format(i, ai)
            for i, ai in enumerate(keys)
        }
        self.index_attr = keys[0]
        self.value_attr = keys[1]

    def _datasource_changed(self, new):
        if new == 'Database':
            self.attrs = DATABASE_ATTRS

    def _get_dump_attrs(self):
        return [
            'index_attr', 'index_error', 'index_end_caps', 'index_nsigma',
            'index_time_units', 'value_attr', 'value_error', 'value_end_caps',
            'value_nsigma', 'value_time_units', 'fit', 'marker_color',
            'marker', 'marker_size', 'auto_refresh', 'datasource',
            'file_source_path'
        ]

    def _get_index_time_scalar(self):
        return TIME_SCALARS[self.index_time_units]

    def _get_value_time_scalar(self):
        return TIME_SCALARS[self.value_time_units]

    def _get_use_file_source(self):
        return self.datasource == 'File'

    def _get_datasource_name(self):
        """ Base name of ``file_source_path``, or '' if it is not a file. """
        p = ''
        if os.path.isfile(self.file_source_path):
            p = os.path.basename(self.file_source_path)
        return p

    def _open_file_button_fired(self):
        # NOTE(review): a developer-specific path is hard-coded here in place
        # of the file dialog below; restore the dialog before release.
        # dlg=FileDialog(action='open', default_directory=paths.data_dir)
        # dlg.open()
        # if dlg.path:
        #     self.file_source_path=dlg.path
        #     self._load_file_source()

        p = '/Users/ross/Sandbox/xy_scatter_test.csv'
        self.file_source_path = p
        self._load_file_source()

    @on_trait_change('index_+, value_+, marker+')
    def _refresh(self):
        if self.auto_refresh:
            # self.update_needed = True
            self.refresh_plot_needed = True

    def traits_view(self):
        index_grp = VGroup(HGroup(
            Item('index_attr',
                 editor=EnumEditor(name='attrs'),
                 label='X Attr.'),
            Item('index_time_units',
                 label='Units',
                 visible_when='index_attr=="timestamp"')),
                           HGroup(
                               Item('index_error', label='Show'),
                               Item('index_end_caps',
                                    label='End Caps',
                                    enabled_when='index_error'),
                               Item('index_nsigma',
                                    label='NSigma',
                                    enabled_when='index_error')),
                           label='X Error',
                           show_border=True)
        value_grp = VGroup(HGroup(
            Item('value_attr',
                 editor=EnumEditor(name='attrs'),
                 label='Y Attr.'),
            Item('value_time_units',
                 label='Units',
                 visible_when='value_attr=="timestamp"')),
                           HGroup(
                               Item('value_error', label='Show'),
                               Item('value_end_caps',
                                    label='End Caps',
                                    enabled_when='value_error'),
                               Item('value_nsigma',
                                    label='NSigma',
                                    enabled_when='value_error')),
                           label='Y Error',
                           show_border=True)

        marker_grp = VGroup(Item('marker'),
                            Item('marker_size', label='Size'),
                            Item('marker_color', label='Color'),
                            show_border=True)

        datasource_grp = HGroup(
            Item('datasource'),
            UItem('datasource_name',
                  style='readonly',
                  visible_when='use_file_source'),
            # The condition is evaluated as a Python expression; it must not
            # carry the stray '"' it previously had.
            icon_button_editor('open_file_button',
                               'document-open',
                               visible_when='use_file_source'))
        v = View(HGroup(Item('auto_refresh'),
                        icon_button_editor('refresh_plot_needed', 'refresh')),
                 datasource_grp,
                 index_grp,
                 value_grp,
                 marker_grp,
                 Item('fit'),
                 resizable=True)
        return v
Beispiel #9
0
class ImageButton(Widget):
    """ An image and text-based control that can be used as a normal, radio or
        toolbar button.
    """

    # Pens used to draw the 'selection' marker:
    _selectedPenDark = wx.Pen(
        wx.SystemSettings_GetColour(wx.SYS_COLOUR_3DSHADOW), 1, wx.SOLID)

    _selectedPenLight = wx.Pen(
        wx.SystemSettings_GetColour(wx.SYS_COLOUR_3DHIGHLIGHT), 1, wx.SOLID)

    #---------------------------------------------------------------------------
    #  Trait definitions:
    #---------------------------------------------------------------------------

    # The image:
    image = Instance(ImageResource, allow_none=True)

    # The (optional) label:
    label = Str

    # Extra padding to add to both the left and right sides:
    width_padding = Range(0, 31, 7)

    # Extra padding to add to both the top and bottom sides:
    height_padding = Range(0, 31, 5)

    # Presentation style:
    style = Enum('button', 'radio', 'toolbar', 'checkbox')

    # Orientation of the text relative to the image:
    orientation = Enum('vertical', 'horizontal')

    # Is the control selected ('radio' or 'checkbox' style)?
    selected = false

    # Fired when a 'button' or 'toolbar' style control is clicked:
    clicked = Event

    #---------------------------------------------------------------------------
    #  Initializes the object:
    #---------------------------------------------------------------------------

    def __init__(self, parent, **traits):
        """ Creates a new image control.

        Parameters
        ----------
        parent : wx.Window
            Parent of the ``wx.Window`` created as ``self.control``.
        **traits
            Initial trait values (``image``, ``label``, ``style``, ...).
        """
        # Initialized before the traits are assigned; assigning ``image``
        # runs _image_changed, which replaces both.
        self._image = None
        self._mono_image = None

        super(ImageButton, self).__init__(**traits)

        # Calculate the size of the button:
        idx = idy = tdx = tdy = 0
        if self._image is not None:
            idx = self._image.GetWidth()
            idy = self._image.GetHeight()

        if self.label != '':
            dc = wx.ScreenDC()
            dc.SetFont(wx.NORMAL_FONT)
            tdx, tdy = dc.GetTextExtent(self.label)

        # Floor division keeps the centering offsets whole pixels (matching
        # the Python 2 behaviour of '/' on ints) on every Python version.
        wp2 = self.width_padding + 2
        hp2 = self.height_padding + 2
        if self.orientation == 'horizontal':
            self._ix = wp2
            spacing = (idx > 0) * (tdx > 0) * 4
            self._tx = self._ix + idx + spacing
            dx = idx + tdx + spacing
            dy = max(idy, tdy)
            self._iy = hp2 + ((dy - idy) // 2)
            self._ty = hp2 + ((dy - tdy) // 2)
        else:
            self._iy = hp2
            spacing = (idy > 0) * (tdy > 0) * 2
            self._ty = self._iy + idy + spacing
            dx = max(idx, tdx)
            dy = idy + tdy + spacing
            self._ix = wp2 + ((dx - idx) // 2)
            self._tx = wp2 + ((dx - tdx) // 2)

        # Create the toolkit-specific control:
        self._dx = dx + wp2 + wp2
        self._dy = dy + hp2 + hp2
        self.control = wx.Window(parent, -1, size=wx.Size(self._dx, self._dy))
        self.control._owner = self
        self._mouse_over = self._button_down = False

        # Set up mouse event handlers:
        wx.EVT_ENTER_WINDOW(self.control, self._on_enter_window)
        wx.EVT_LEAVE_WINDOW(self.control, self._on_leave_window)
        wx.EVT_LEFT_DOWN(self.control, self._on_left_down)
        wx.EVT_LEFT_UP(self.control, self._on_left_up)
        wx.EVT_PAINT(self.control, self._on_paint)

    #---------------------------------------------------------------------------
    #  Handles the 'image' trait being changed:
    #---------------------------------------------------------------------------

    def _image_changed(self, image):
        """ Rebuild the cached bitmap (and drop the disabled-state copy) when
        the image changes.
        """
        self._image = self._mono_image = None
        if image is not None:
            self._img = image.create_image()
            self._image = self._img.ConvertToBitmap()

        if self.control is not None:
            self.control.Refresh()

    #---------------------------------------------------------------------------
    #  Handles the 'selected' trait being changed:
    #---------------------------------------------------------------------------

    def _selected_changed(self, selected):
        """ Handles the 'selected' trait being changed.

        For the 'radio' style, selecting this button deselects the first
        other selected ImageButton sharing the same parent window.
        """
        if self.control is None:
            # ``selected`` passed to the constructor is assigned before the
            # wx control exists; there is nothing to update or refresh yet.
            return

        if selected and (self.style == 'radio'):
            for control in self.control.GetParent().GetChildren():
                owner = getattr(control, '_owner', None)
                if (isinstance(owner, ImageButton) and owner.selected
                        and (owner is not self)):
                    owner.selected = False
                    break

        self.control.Refresh()

#-- wx event handlers ----------------------------------------------------------

    def _on_enter_window(self, event):
        """ Called when the mouse enters the widget. """

        if self.style != 'button':
            self._mouse_over = True
            self.control.Refresh()

    def _on_leave_window(self, event):
        """ Called when the mouse leaves the widget. """

        if self._mouse_over:
            self._mouse_over = False
            self.control.Refresh()

    def _on_left_down(self, event):
        """ Called when the left mouse button goes down on the widget. """
        self._button_down = True
        self.control.CaptureMouse()
        self.control.Refresh()

    def _on_left_up(self, event):
        """ Called when the left mouse button goes up on the widget.

        A release inside the widget selects ('radio'), toggles ('checkbox')
        or fires ``clicked`` (other styles).
        """
        control = self.control
        # The release may arrive without a preceding press on this widget;
        # releasing a capture we do not hold is an error in wx.
        if control.HasCapture():
            control.ReleaseMouse()
        self._button_down = False
        wdx, wdy = control.GetClientSizeTuple()
        x, y = event.GetX(), event.GetY()
        control.Refresh()
        if (0 <= x < wdx) and (0 <= y < wdy):
            if self.style == 'radio':
                self.selected = True
            elif self.style == 'checkbox':
                self.selected = not self.selected
            else:
                self.clicked = True

    def _on_paint(self, event):
        """ Called when the widget needs repainting.

        Draws the image (a grey-scale copy when the control is disabled), the
        label, and a raised/sunken or hover border depending on state.
        """
        wdc = wx.PaintDC(self.control)
        wdx, wdy = self.control.GetClientSizeTuple()
        ox = (wdx - self._dx) // 2
        oy = (wdy - self._dy) // 2

        disabled = (not self.control.IsEnabled())
        if self._image is not None:
            image = self._image
            if disabled:
                if self._mono_image is None:
                    # Build (once) a grey-scale bitmap using luminance weights.
                    img = self._img
                    data = reshape(fromstring(img.GetData(), dtype('uint8')),
                                   (-1, 3)) * array([[0.297, 0.589, 0.114]])
                    g = data[:, 0] + data[:, 1] + data[:, 2]
                    data[:, 0] = data[:, 1] = data[:, 2] = g
                    img.SetData(ravel(data.astype(dtype('uint8'))).tostring())
                    img.SetMaskColour(0, 0, 0)
                    self._mono_image = img.ConvertToBitmap()
                    self._img = None
                image = self._mono_image
            wdc.DrawBitmap(image, ox + self._ix, oy + self._iy, True)

        if self.label != '':
            if disabled:
                wdc.SetTextForeground(DisabledTextColor)
            wdc.SetFont(wx.NORMAL_FONT)
            wdc.DrawText(self.label, ox + self._tx, oy + self._ty)

        pens = [self._selectedPenLight, self._selectedPenDark]
        bd = self._button_down
        style = self.style
        is_rc = (style in ('radio', 'checkbox'))
        if bd or (style == 'button') or (is_rc and self.selected):
            if is_rc:
                bd = 1 - bd
            wdc.SetBrush(wx.TRANSPARENT_BRUSH)
            wdc.SetPen(pens[bd])
            wdc.DrawLine(1, 1, wdx - 1, 1)
            wdc.DrawLine(1, 1, 1, wdy - 1)
            wdc.DrawLine(2, 2, wdx - 2, 2)
            wdc.DrawLine(2, 2, 2, wdy - 2)
            wdc.SetPen(pens[1 - bd])
            wdc.DrawLine(wdx - 2, 2, wdx - 2, wdy - 1)
            wdc.DrawLine(2, wdy - 2, wdx - 2, wdy - 2)
            wdc.DrawLine(wdx - 3, 3, wdx - 3, wdy - 2)
            wdc.DrawLine(3, wdy - 3, wdx - 3, wdy - 3)

        elif self._mouse_over and (not self.selected):
            wdc.SetBrush(wx.TRANSPARENT_BRUSH)
            wdc.SetPen(pens[bd])
            wdc.DrawLine(0, 0, wdx, 0)
            wdc.DrawLine(0, 1, 0, wdy)
            wdc.SetPen(pens[1 - bd])
            wdc.DrawLine(wdx - 1, 1, wdx - 1, wdy)
            wdc.DrawLine(1, wdy - 1, wdx - 1, wdy - 1)
Beispiel #10
0
class SpectrumOptions(AgeOptions):
    """ Options for age-spectrum plots: plateau criteria and error reporting,
    the error envelope, and the labels/info text that are displayed.
    """
    label = 'Spectrum'
    step_nsigma = Int(2)
    plot_option_klass = SpectrumPlotOptions

    edit_plateau_criteria = Button
    # Plateau criteria edited via ``edit_plateau_criteria``: number of
    # contiguous steps and minimum gas release (%).
    pc_nsteps = Int(3)
    pc_gas_fraction = Float(50)

    include_j_error_in_plateau = Bool(True)
    plateau_age_error_kind = Enum(*ERROR_TYPES)

    plot_option_name = 'Age'
    display_extract_value = Bool(False)
    display_step = Bool(False)
    display_plateau_info = Bool(True)
    display_integrated_info = Bool(True)
    plateau_sig_figs = Int
    integrated_sig_figs = Int

    plateau_font_size = Enum(6, 7, 8, 10, 11, 12, 14, 15, 18, 24, 28, 32)
    integrated_font_size = Enum(6, 7, 8, 10, 11, 12, 14, 15, 18, 24, 28, 32)
    step_label_font_size = Enum(6, 7, 8, 10, 11, 12, 14, 15, 18, 24, 28, 32)
    envelope_alpha = Range(0, 100, style='simple')
    envelope_color = Color
    user_envelope_color = Bool
    center_line_style = Enum('No Line', 'solid', 'dash', 'dot dash', 'dot',
                             'long dash')
    extend_plateau_end_caps = Bool(True)

    plateau_method = Enum('Fleck 1977', 'Mahon 1996')
    # Alias for ``plateau_age_error_kind`` (see its getter/setter).
    error_calc_method = Property
    use_error_envelope_fill = Bool

    include_plateau_sample = Bool
    include_plateau_identifier = Bool

    group_editor_klass = SpectrumGroupEditor
    options_klass = SpectrumGroupOptions

    # handlers
    @on_trait_change('display_step,display_extract_value')
    def _handle_labels(self):
        """ Show step labels on the last aux plot when either the step or the
        extract-value label is enabled.
        """
        labels_enabled = self.display_extract_value or self.display_step
        # ``aux_plots`` may still be empty (e.g. while options are being
        # initialized or loaded); indexing [-1] would raise IndexError.
        if self.aux_plots:
            self.aux_plots[-1].show_labels = labels_enabled

    def _edit_plateau_criteria_fired(self):
        """ Open a modal dialog for editing the plateau criteria. """
        v = View(
            Item('pc_nsteps',
                 label='Num. Steps',
                 tooltip='Number of contiguous steps'),
            Item('pc_gas_fraction',
                 label='Min. Gas%',
                 tooltip='Plateau must represent at least Min. Gas% release'),
            buttons=['OK', 'Cancel'],
            title='Edit Plateau Criteria',
            kind='livemodal')
        self.edit_traits(v)

    def _get_error_calc_method(self):
        return self.plateau_age_error_kind

    def _set_error_calc_method(self, v):
        self.plateau_age_error_kind = v

    def _get_dump_attrs(self):
        """ Names of the attributes persisted, in addition to the parent's. """
        attrs = super(SpectrumOptions, self)._get_dump_attrs()
        return attrs + [
            'step_nsigma',
            'display_extract_value',
            'display_step',
            'display_plateau_info',
            'display_integrated_info',
            'plateau_font_size',
            'integrated_font_size',
            'step_label_font_size',
            'center_line_style',
            'extend_plateau_end_caps',
            'include_j_error_in_plateau',
            'plateau_age_error_kind',
            'plateau_sig_figs',
            'integrated_sig_figs',
            'use_error_envelope_fill',
            'plateau_method',
            'pc_nsteps',
            'pc_gas_fraction',
            'legend_location',
            'include_legend',
            'include_sample_in_legend'
        ]

    def _get_groups(self):
        """ Build the 'Options' UI group (returned as a one-element tuple). """
        lgrp = VGroup(
            Item('plateau_method', label='Method'), Item('nsigma'),
            Item('plateau_age_error_kind', width=-100, label='Error Type'),
            Item('include_j_error_in_plateau', label='Include J Error'))
        rgrp = VGroup(
            Item('center_line_style', label='Line Style'),
            Item('extend_plateau_end_caps', label='Extend End Caps'),
            icon_button_editor('edit_plateau_criteria',
                               'cog',
                               tooltip='Edit Plateau Criteria'),
        )
        plat_grp = HGroup(lgrp, rgrp, show_border=True, label='Plateau')

        grp_grp = VGroup(UItem('group',
                               style='custom',
                               editor=InstanceEditor(view='simple_view')),
                         show_border=True,
                         label='Group Attributes')

        error_grp = VGroup(
            HGroup(
                Item(
                    'step_nsigma',
                    editor=EnumEditor(values=[1, 2, 3]),
                    tooltip=
                    'Set the size of the error envelope in standard deviations',
                    label='N. Sigma')),
            show_border=True,
            label='Error Envelope')

        display_grp = Group(
            HGroup(UItem(
                'show_info',
                tooltip='Show general info in the upper right corner'),
                   show_border=True,
                   label='General'),
            VGroup(Item('include_legend', label='Show'),
                   Item('include_sample_in_legend', label='Include Sample'),
                   Item('legend_location', label='Location'),
                   label='Legend',
                   show_border=True),
            HGroup(Item('display_step', label='Step'),
                   Item('display_extract_value', label='Power/Temp'),
                   spring,
                   Item('step_label_font_size', label='Size'),
                   show_border=True,
                   label='Labels'),
            VGroup(
                HGroup(
                    UItem('display_plateau_info',
                          tooltip='Display plateau info'),
                    Item('plateau_font_size',
                         label='Size',
                         enabled_when='display_plateau_info'),
                    Item('plateau_sig_figs', label='SigFigs')),
                HGroup(
                    Item(
                        'include_plateau_sample',
                        tooltip='Add the Sample name to the Plateau indicator',
                        label='Sample'),
                    Item('include_plateau_identifier',
                         tooltip='Add the Identifier to the Plateau indicator',
                         label='Identifier')),
                show_border=True,
                label='Plateau'),
            HGroup(UItem('display_integrated_info',
                         tooltip='Display integrated age info'),
                   Item('integrated_font_size',
                        label='Size',
                        enabled_when='display_integrated_info'),
                   Item('integrated_sig_figs', label='SigFigs'),
                   show_border=True,
                   label='Integrated'),
            show_border=True,
            label='Display')
        g = Group(
            self._get_title_group(),
            grp_grp,
            plat_grp,
            error_grp,
            display_grp,
            label='Options')

        return g,

    def _load_factory_defaults(self, yd):
        """ Apply factory defaults from ``yd`` by section (legend, plateau,
        integrated, labels).
        """
        super(SpectrumOptions, self)._load_factory_defaults(yd)

        self._set_defaults(
            yd, 'legend',
            ('legend_location', 'include_legend', 'include_sample_in_legend'))

        # NOTE(review): 'plateau_line_width' and 'plateau_line_color' are no
        # longer declared as traits on this class; confirm they are still
        # wanted here.
        self._set_defaults(
            yd,
            'plateau',
            (
                'plateau_line_width',
                'plateau_line_color',
                'plateau_font_size',
                'plateau_sig_figs',
                'pc_nsteps',
                'pc_gas_fraction'))

        self._set_defaults(yd, 'integrated', (
            'integrated_font_size',
            'integrated_sig_figs',
        ))
        self._set_defaults(
            yd, 'labels',
            ('display_step', 'display_extract_value', 'step_label_font_size'))
Beispiel #11
0
class BarPlot(AbstractPlotRenderer):
    """
    A renderer for bar charts.

    Each bar spans from its starting value (zero unless **starting_value** is
    set) to the matching entry of **value**, and is placed along the index
    axis at the matching entry of **index**.
    """
    # The data source to use for the index coordinate.
    index = Instance(ArrayDataSource)

    # The data source to use as value points.
    value = Instance(ArrayDataSource)

    # The data source to use as "starting" values for the bars.
    # For instance, if the values are [10, 20] and starting_value
    # is [3, 7], BarPlot will plot two bars, one  between 3 and 10, and
    # one between 7 and 20
    starting_value = Instance(ArrayDataSource)

    # Labels for the indices.
    index_mapper = Instance(AbstractMapper)
    # Labels for the values.
    value_mapper = Instance(AbstractMapper)

    # The orientation of the index axis.
    orientation = Enum("h", "v")

    # The direction of the index axis with respect to the graphics context's
    # direction.
    index_direction = Enum("normal", "flipped")

    # The direction of the value axis with respect to the graphics context's
    # direction.
    value_direction = Enum("normal", "flipped")

    # Type of width used for bars:
    #
    # 'data'
    #     The width is in the units along the x-dimension of the data space.
    # 'screen'
    #     The width uses a fixed width of pixels.
    bar_width_type = Enum("data", "screen")

    # Width of the bars, in data or screen space (determined by
    # **bar_width_type**).
    bar_width = Float(10)

    # Round on rectangle dimensions? This is not strictly an "antialias", but
    # it has the same effect through exact pixel drawing.
    antialias = Bool(True)

    # Width of the border of the bars.
    line_width = Float(1.0)
    # Color of the border of the bars.
    line_color = black_color_trait
    # Color to fill the bars.
    fill_color = black_color_trait

    # The RGBA tuple for rendering lines.  It is always a tuple of length 4.
    # It has the same RGB values as line_color_, and its alpha value is the
    # alpha value of self.line_color multiplied by self.alpha.
    effective_line_color = Property(Tuple, depends_on=['line_color', 'alpha'])

    # The RGBA tuple for rendering the fill.  It is always a tuple of length 4.
    # It has the same RGB values as fill_color_, and its alpha value is the
    # alpha value of self.fill_color multiplied by self.alpha.
    effective_fill_color = Property(Tuple, depends_on=['fill_color', 'alpha'])

    # Overall alpha value of the image. Ranges from 0.0 for transparent to 1.0
    alpha = Range(0.0, 1.0, 1.0)

    # Corresponds to either **index_mapper** or **value_mapper**, depending on
    # the orientation of the plot.
    x_mapper = Property
    # Corresponds to either **value_mapper** or **index_mapper**, depending on
    # the orientation of the plot.
    y_mapper = Property

    # Corresponds to either **index_direction** or **value_direction**,
    # depending on the orientation of the plot.
    x_direction = Property
    # Corresponds to either **value_direction** or **index_direction**,
    # depending on the orientation of the plot
    y_direction = Property

    # Convenience property for accessing the index data range.
    index_range = Property
    # Convenience property for accessing the value data range.
    value_range = Property

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # Indicates whether or not the data cache is valid
    _cache_valid = Bool(False)

    # Cached data values from the datasources.  If **bar_width_type** is "data",
    # then this is an Nx4 array of (bar_left, bar_right, start, end) for a
    # bar plot in normal orientation.  If **bar_width_type** is "screen", then
    # this is an Nx3 array of (bar_center, start, end).
    _cached_data_pts = Any

    #------------------------------------------------------------------------
    # AbstractPlotRenderer interface
    #------------------------------------------------------------------------

    def __init__(self, *args, **kw):
        # These Traits depend on others (their change handlers swap the
        # mappers' positions), so defer setting them until after the
        # HasTraits initialization has been completed.
        later_list = ['index_direction', 'value_direction']
        postponed = {}
        for name in later_list:
            if name in kw:
                postponed[name] = kw.pop(name)

        super(BarPlot, self).__init__(*args, **kw)

        # Set any keyword Traits that were postponed.  ``trait_set`` is the
        # supported spelling of the deprecated ``HasTraits.set``.
        self.trait_set(**postponed)

    def map_screen(self, data_array):
        """ Maps an array of data points into screen space and returns it as
        an array.

        Implements the AbstractPlotRenderer interface.

        :param data_array: Nx2 array of (index, value) data points.
        :returns: Nx2 array of (screen_x, screen_y), or an empty list when
            ``data_array`` is empty.
        """
        if len(data_array) == 0:
            return []
        x_ary, y_ary = transpose(data_array)
        sx = self.index_mapper.map_screen(x_ary)
        sy = self.value_mapper.map_screen(y_ary)

        if self.orientation == "h":
            return transpose(array((sx, sy)))
        else:
            return transpose(array((sy, sx)))

    def map_data(self, screen_pt):
        """ Maps a screen space point into the "index" space of the plot.

        Implements the AbstractPlotRenderer interface.
        """
        if self.orientation == "h":
            screen_coord = screen_pt[0]
        else:
            screen_coord = screen_pt[1]
        return self.index_mapper.map_data(screen_coord)

    def map_index(self,
                  screen_pt,
                  threshold=2.0,
                  outside_returns_none=True,
                  index_only=False):
        """ Maps a screen space point to an index into the plot's index array(s).

        Implements the AbstractPlotRenderer interface.

        :param screen_pt: (x, y) point in screen space.
        :param threshold: maximum screen distance for a hit.
        :param outside_returns_none: return None when the point lies outside
            the index range.
        :param index_only: when True, only the distance along the index axis
            is compared against ``threshold``.
        :returns: the matching index, or None.
        """
        data_pt = self.map_data(screen_pt)
        if ((data_pt < self.index_mapper.range.low) or
                (data_pt > self.index_mapper.range.high)) and \
                outside_returns_none:
            return None
        index_data = self.index.get_data()
        value_data = self.value.get_data()

        if len(value_data) == 0 or len(index_data) == 0:
            return None

        try:
            ndx = reverse_map_1d(index_data, data_pt, self.index.sort_order)
        except IndexError:
            return None

        x = index_data[ndx]
        y = value_data[ndx]

        result = self.map_screen(array([[x, y]]))
        if len(result) == 0:
            return None

        sx, sy = result[0]
        # The index axis lies along screen x for "h" and screen y for "v";
        # compare the absolute distance so points on either side match.
        if self.orientation == "h":
            index_delta = screen_pt[0] - sx
        else:
            index_delta = screen_pt[1] - sy
        if index_only and abs(index_delta) < threshold:
            return ndx
        elif ((screen_pt[0] - sx)**2 +
              (screen_pt[1] - sy)**2 < threshold * threshold):
            return ndx
        else:
            return None

    #------------------------------------------------------------------------
    # PlotComponent interface
    #------------------------------------------------------------------------

    def _gather_points(self):
        """ Collects data points that are within the range of the plot, and
        caches them in **_cached_data_pts**.
        """
        # Check for missing sources *before* dereferencing them.
        if self.index is None or self.value is None:
            return

        index, index_mask = self.index.get_data_mask()
        value, value_mask = self.value.get_data_mask()

        if len(index) == 0 or len(value) == 0 or len(index) != len(value):
            logger.warning(
                "Chaco: using empty dataset; index_len=%d, value_len=%d.",
                len(index), len(value))
            self._cached_data_pts = array([])
            self._cache_valid = True
            return

        # TODO: Until we code up a better handling of value-based culling that
        # takes into account starting_value and dataspace bar widths, just use
        # the index culling for now.
        index_range_mask = self.index_mapper.range.mask_data(index)
        nan_mask = invert(isnan(index_mask))
        point_mask = index_mask & nan_mask & index_range_mask

        if self.starting_value is None:
            starting_values = zeros(len(index))
        else:
            starting_values = self.starting_value.get_data()

        if self.bar_width_type == "data":
            half_width = self.bar_width / 2.0
            points = column_stack((index - half_width, index + half_width,
                                   starting_values, value))
        else:
            points = column_stack((index, starting_values, value))
        self._cached_data_pts = compress(point_mask, points, axis=0)

        self._cache_valid = True

    def _draw_plot(self, gc, view_bounds=None, mode="normal"):
        """ Draws the 'plot' layer.
        """
        if not self._cache_valid:
            self._gather_points()

        data = self._cached_data_pts
        # The cache stays None if a data source is missing.
        if data is None or data.size == 0:
            # Nothing to draw.
            return

        with gc:
            gc.clip_to_rect(self.x, self.y, self.width, self.height)
            gc.set_antialias(self.antialias)
            gc.set_stroke_color(self.effective_line_color)
            gc.set_fill_color(self.effective_fill_color)
            gc.set_line_width(self.line_width)

            if self.bar_width_type == "data":
                # map the bar start and stop locations into screen space
                lower_left_pts = self.map_screen(data[:, (0, 2)])
                upper_right_pts = self.map_screen(data[:, (1, 3)])
            else:
                half_width = self.bar_width / 2.0
                # map the bar centers into screen space and then compute the bar
                # start and end positions
                lower_left_pts = self.map_screen(data[:, (0, 1)])
                upper_right_pts = self.map_screen(data[:, (0, 2)])
                # The bar width extends along the index axis: screen x for
                # "h" orientation, screen y for "v" (see map_screen).
                index_col = 0 if self.orientation == "h" else 1
                lower_left_pts[:, index_col] -= half_width
                upper_right_pts[:, index_col] += half_width

            bounds = upper_right_pts - lower_left_pts
            gc.rects(column_stack((lower_left_pts, bounds)))
            gc.draw_path()

    def _draw_default_axes(self, gc):
        """ Draws origin lines through 0 on each axis whose range spans 0. """
        if not self.origin_axis_visible:
            return

        idx_rng = self.index_mapper.range
        val_rng = self.value_mapper.range
        segments = []
        if idx_rng.low < 0 < idx_rng.high:
            # Line at index == 0 spanning the whole value range.
            segments.append(array([[0.0, val_rng.low], [0.0, val_rng.high]]))
        if val_rng.low < 0 < val_rng.high:
            # Line at value == 0 spanning the whole index range.
            segments.append(array([[idx_rng.low, 0.0], [idx_rng.high, 0.0]]))

        with gc:
            gc.set_stroke_color(self.origin_axis_color_)
            gc.set_line_width(self.origin_axis_width)
            gc.set_line_dash(None)

            for data_pts in segments:
                start, end = self.map_screen(data_pts)
                # Offset by half a pixel so 1-px lines land on pixel centres.
                gc.move_to(int(start[0]) + 0.5, int(start[1]) + 0.5)
                gc.line_to(int(end[0]) + 0.5, int(end[1]) + 0.5)
                gc.stroke_path()

    def _render_icon(self, gc, x, y, width, height):
        with gc:
            gc.set_fill_color(self.effective_fill_color)
            gc.set_stroke_color(self.effective_line_color)
            gc.rect(x + width / 4, y + height / 4, width / 2, height / 2)
            gc.draw_path(FILL_STROKE)

    def _post_load(self):
        super(BarPlot, self)._post_load()

    #------------------------------------------------------------------------
    # Properties
    #------------------------------------------------------------------------

    def _get_index_range(self):
        return self.index_mapper.range

    def _set_index_range(self, val):
        self.index_mapper.range = val

    def _get_value_range(self):
        return self.value_mapper.range

    def _set_value_range(self, val):
        self.value_mapper.range = val

    def _get_x_mapper(self):
        if self.orientation == "h":
            return self.index_mapper
        else:
            return self.value_mapper

    def _get_y_mapper(self):
        if self.orientation == "h":
            return self.value_mapper
        else:
            return self.index_mapper

    def _get_x_direction(self):
        if self.orientation == "h":
            return self.index_direction
        else:
            return self.value_direction

    def _get_y_direction(self):
        if self.orientation == "h":
            return self.value_direction
        else:
            return self.index_direction

    #------------------------------------------------------------------------
    # Event handlers - these are mostly copied from BaseXYPlot
    #------------------------------------------------------------------------

    def _update_mappers(self):
        """ Updates the index and value mappers. Called by trait change handlers
        for various traits.
        """
        x_mapper = self.index_mapper
        y_mapper = self.value_mapper
        x_dir = self.index_direction
        y_dir = self.value_direction

        if self.orientation == "v":
            x_mapper, y_mapper = y_mapper, x_mapper
            x_dir, y_dir = y_dir, x_dir

        x = self.x
        x2 = self.x2
        y = self.y
        y2 = self.y2

        if x_mapper is not None:
            if x_dir == "normal":
                x_mapper.low_pos = x
                x_mapper.high_pos = x2
            else:
                x_mapper.low_pos = x2
                x_mapper.high_pos = x

        if y_mapper is not None:
            if y_dir == "normal":
                y_mapper.low_pos = y
                y_mapper.high_pos = y2
            else:
                y_mapper.low_pos = y2
                y_mapper.high_pos = y

        self.invalidate_draw()
        self._cache_valid = False

    @on_trait_change('line_color, line_width, fill_color, alpha')
    def _attributes_changed(self):
        self.invalidate_draw()
        self.request_redraw()

    def _bounds_changed(self, old, new):
        super(BarPlot, self)._bounds_changed(old, new)
        self._update_mappers()

    def _bounds_items_changed(self, event):
        super(BarPlot, self)._bounds_items_changed(event)
        self._update_mappers()

    def _orientation_changed(self):
        self._update_mappers()

    def _index_changed(self, old, new):
        if old is not None:
            old.on_trait_change(self._either_data_changed,
                                "data_changed",
                                remove=True)
        if new is not None:
            new.on_trait_change(self._either_data_changed, "data_changed")
        self._either_data_changed()

    def _index_direction_changed(self):
        # Flip the screen extent of the index mapper, if there is one yet.
        m = self.index_mapper
        if m is not None:
            m.low_pos, m.high_pos = m.high_pos, m.low_pos
        self.invalidate_draw()

    def _value_direction_changed(self):
        # Flip the screen extent of the value mapper, if there is one yet.
        m = self.value_mapper
        if m is not None:
            m.low_pos, m.high_pos = m.high_pos, m.low_pos
        self.invalidate_draw()

    def _either_data_changed(self):
        self.invalidate_draw()
        self._cache_valid = False
        self.request_redraw()

    def _value_changed(self, old, new):
        if old is not None:
            old.on_trait_change(self._either_data_changed,
                                "data_changed",
                                remove=True)
        if new is not None:
            new.on_trait_change(self._either_data_changed, "data_changed")
        self._either_data_changed()

    def _index_mapper_changed(self, old, new):
        return self._either_mapper_changed(old, new)

    def _value_mapper_changed(self, old, new):
        return self._either_mapper_changed(old, new)

    def _either_mapper_changed(self, old, new):
        if old is not None:
            old.on_trait_change(self._mapper_updated_handler,
                                "updated",
                                remove=True)
        if new is not None:
            new.on_trait_change(self._mapper_updated_handler, "updated")
        self.invalidate_draw()

    def _mapper_updated_handler(self):
        self._cache_valid = False
        self.invalidate_draw()
        self.request_redraw()

    def _bar_width_changed(self):
        self._cache_valid = False
        self.invalidate_draw()
        self.request_redraw()

    def _bar_width_type_changed(self):
        self._cache_valid = False
        self.invalidate_draw()
        self.request_redraw()

    #------------------------------------------------------------------------
    # Property getters
    #------------------------------------------------------------------------

    def _effective_color(self, rgba):
        """ Return ``rgba`` as an RGBA tuple whose alpha is scaled by
        **alpha**; a 3-component colour is treated as fully opaque.
        """
        base_alpha = rgba[-1] if len(rgba) == 4 else 1.0
        # tuple() so sequences other than tuples concatenate safely.
        return tuple(rgba[:3]) + (base_alpha * self.alpha, )

    @cached_property
    def _get_effective_line_color(self):
        return self._effective_color(self.line_color_)

    @cached_property
    def _get_effective_fill_color(self):
        return self._effective_color(self.fill_color_)
Beispiel #12
0
class AnnotationEditor(HasTraits):
    """Traits editor that mirrors and edits an annotation component's look.

    Assigning ``component`` copies its border and text settings and its
    colours into this editor. Later edits here are pushed back onto the
    component, which is then asked to redraw.
    """
    # The annotation being edited (any object exposing the mirrored traits).
    component = Any

    border_visible = Bool(True)
    border_width = Range(0, 10)
    border_color = Color

    font = Font('modern 12')
    text_color = Color
    bgcolor = Color
    text = Str
    # When False the component's background is set to transparent_color.
    bg_visible = Bool(True)

    @on_trait_change('component:text')
    def _component_text_changed(self):
        self.text = self.component.text

    def _component_changed(self):
        """Copy the new component's settings into this editor."""
        if self.component:
            d = self.component.trait_get('border_visible', 'border_width',
                                         'text')
            # trait_set's first positional parameter is trait_change_notify;
            # pass only keyword traits here.
            self.trait_set(**d)
            for c in ('border_color', 'text_color', 'bgcolor'):
                v = getattr(self.component, c)
                if not isinstance(v, str):
                    # The component stores 0-1 float RGB(A); the editor's
                    # Color traits take 0-255 integer channels.
                    v = tuple(int(round(channel * 255)) for channel in v[:3])

                self.trait_set(**{c: v})

    def _bg_visible_changed(self):
        if self.component:
            if self.bg_visible:
                self.component.bgcolor = self.bgcolor
            else:
                self.component.bgcolor = transparent_color
            self.component.request_redraw()

    @on_trait_change('border_+, text_color, bgcolor, text')
    def _update(self, name, new):
        if self.component:
            if name == 'bgcolor' and not self.bg_visible:
                # Keep the background transparent while it is hidden;
                # _bg_visible_changed applies the colour when re-enabled.
                return
            self.component.trait_set(**{name: new})
            self.component.request_redraw()

    @on_trait_change('font')
    def _update_font(self):
        if self.component:
            self.component.font = str(self.font)
            self.component.request_redraw()

    def traits_view(self):
        v = View(
            VGroup(
                Item('font', width=75),
                Item('text_color', label='Text'),
                HGroup(
                    UItem('bg_visible',
                          tooltip='Is the background transparent'),
                    Item('bgcolor',
                         label='Background',
                         enabled_when='bg_visible'),
                ),
                UItem('text', style='custom'),
                Group(Item('border_visible'),
                      Item('border_width', enabled_when='border_visible'),
                      Item('border_color', enabled_when='border_visible'),
                      label='Border'),
                visible_when='component'))
        return v
Beispiel #13
0
class Elementary1DRule(NDimRule):
    """ Rule implementing an elementary 1D cellular automaton.

    Rules are identified by Wolfram's numbering scheme. Boundary handling is
    delegated to ``scipy.ndimage``.

    Notes
    -----

    See `Wikipedia
    <https://en.wikipedia.org/wiki/Elementary_cellular_automaton>`_ for further
    information on how these automata work.
    """

    # Elementary1DRule Traits ------------------------------------------------

    #: Wolfram number of the rule (0 to 255).
    rule_number = Range(0, 255)

    #: State value written for cells that end up "empty".
    empty_state = StateValue(0)

    #: State value that marks (and is written for) "filled" cells.
    filled_state = StateValue(1)

    #: Eight-entry boolean lookup table; entry ``i`` is bit ``i`` of the rule.
    bit_mask = Property(Array(shape=(8,), dtype=bool),
                        depends_on='rule_number')

    #: How cells beyond either end of the line are treated.
    boundary = Enum('empty', 'filled', 'nearest', 'wrap', 'reflect')

    # NDimRule Traits --------------------------------------------------------

    #: These rules only operate on one-dimensional state arrays.
    ndim = Constant(1)

    # ------------------------------------------------------------------------
    # Elementary1DRule interface
    # ------------------------------------------------------------------------

    def reflect(self):
        """ Mirror the automaton left-to-right. """
        self.bit_mask = self.bit_mask[REVERSE_PERMUTATION]

    def complement(self):
        """ Swap the roles of 1's and 0's throughout the automaton. """
        # Neighbourhood i maps to neighbourhood 7 - i under complement, and
        # the output bit is inverted.
        mirrored = self.bit_mask[::-1]
        self.bit_mask = ~mirrored

    # ------------------------------------------------------------------------
    # AbstractRule interface
    # ------------------------------------------------------------------------

    def step(self, states):
        """ Apply the rule once to the given states.

        Parameters
        ----------
        states : array
            The current states of the automaton.

        Returns
        -------
        states : array
            The automaton's states after one application of the rule.
        """
        # NOTE(review): super(NDimRule, ...) skips NDimRule.step and starts
        # the MRO lookup after NDimRule -- confirm this is intentional.
        states = super(NDimRule, self).step(states)

        if self.boundary in ('empty', 'filled'):
            # Pad with a constant: 1 (filled) or 0 (empty).
            filter_options = {
                'mode': 'constant',
                'cval': 1 if self.boundary == 'filled' else 0,
            }
        else:
            filter_options = {'mode': self.boundary}

        is_filled = ndimage.generic_filter1d(
            states == self.filled_state,
            self._rule_filter,
            filter_size=3,
            **filter_options)

        next_states = np.full(is_filled.shape, self.empty_state, dtype='uint8')
        next_states[is_filled] = self.filled_state
        return next_states

    # ------------------------------------------------------------------------
    # Private interface
    # ------------------------------------------------------------------------

    def _rule_filter(self, iline, oline):
        """ Filter kernel: look up each 3-cell neighbourhood in bit_mask. """
        left, centre, right = iline[:-2], iline[1:-1], iline[2:]
        neighbourhood = (4 * left + 2 * centre + right).astype('uint8')
        oline[...] = self.bit_mask[neighbourhood]

    # Trait properties -------------------------------------------------------

    @cached_property
    def _get_bit_mask(self):
        packed = np.array([self.rule_number], dtype='uint8')
        msb_first = np.unpackbits(packed)
        return msb_first[::-1].astype(bool)

    def _set_bit_mask(self, bits):
        lsb_first = np.asarray(bits, dtype=bool)
        self.rule_number = int(np.packbits(lsb_first[::-1])[0])
Beispiel #14
0
class AnOddClass(HasTraits):
    """Demonstrates traits validated by a custom ``TraitOddInteger`` handler.

    NOTE(review): ``TraitOddInteger`` is defined elsewhere; judging by its
    name it presumably accepts only odd integers -- confirm against its
    definition.
    """
    # Default value 1, validated by a TraitOddInteger handler.
    oddball = Trait(1, TraitOddInteger())
    # Default value -1, declared with two handlers. Presumably a value is
    # accepted if either TraitOddInteger or Range(-10, -1) accepts it
    # (Traits compound-trait semantics) -- confirm.
    very_odd = Trait(-1, TraitOddInteger(), Range(-10, -1))
Beispiel #15
0
class Worker(HasTraits):
    """This class basically allows you to create a data set, view it
    and modify the dataset.  This is a rather crude example but
    demonstrates how things can be done.
    """

    # Set by envisage when this is contributed as a ServiceOffer.
    window = Instance('pyface.workbench.api.WorkbenchWindow')

    create_data = Button('Create data')
    reset_data = Button('Reset data')
    view_data = Button('View data')
    # Dragging this shifts every scalar by value * 0.01 (wrapped into [0, 1)).
    scale = Range(0.0, 1.0)
    # The ArraySource created by "Create data"; None until then.
    source = Instance('mayavi.core.source.Source')

    # Our UI view.
    view = View(Item('create_data', show_label=False),
                Item('view_data', show_label=False),
                Item('reset_data', show_label=False),
                Item('scale'),
                resizable=True)

    def get_mayavi(self):
        """Return the Mayavi ``Script`` service from the workbench window."""
        from mayavi.plugins.script import Script
        return self.window.get_service(Script)

    def _make_data(self):
        """Return a 64x64x64 float32 volume of sin(x*y*z) / (x*y*z).

        The grid has 64 evenly spaced samples from -5 to 5 per axis. An even
        sample count never lands on 0, so the division is always defined.
        """
        dims = [64, 64, 64]
        x, y, z = numpy.ogrid[-5:5:dims[0] * 1j, -5:5:dims[1] * 1j,
                              -5:5:dims[2] * 1j]
        x = x.astype('f')
        y = y.astype('f')
        z = z.astype('f')
        s = (numpy.sin(x * y * z) / (x * y * z))
        s = s.transpose().copy()  # This makes the data contiguous.
        return s

    def _create_data_fired(self):
        mayavi = self.get_mayavi()
        from mayavi.sources.array_source import ArraySource
        s = self._make_data()
        src = ArraySource(transpose_input_array=False, scalar_data=s)
        self.source = src
        mayavi.add_source(src)

    def _reset_data_fired(self):
        # Nothing to reset until "Create data" has produced a source.
        if self.source is None:
            return
        self.source.scalar_data = self._make_data()

    def _view_data_fired(self):
        mayavi = self.get_mayavi()
        from mayavi.modules.outline import Outline
        from mayavi.modules.image_plane_widget import ImagePlaneWidget
        # Visualize the data.
        o = Outline()
        mayavi.add_module(o)
        ipw = ImagePlaneWidget()
        mayavi.add_module(ipw)
        ipw.module_manager.scalar_lut_manager.show_scalar_bar = True

        ipw_y = ImagePlaneWidget()
        mayavi.add_module(ipw_y)
        ipw_y.ipw.plane_orientation = 'y_axes'

    def _scale_changed(self, value):
        src = self.source
        # The slider can move before any data exists.
        if src is None:
            return
        data = src.scalar_data
        # Shift in place, then wrap back into [0, 1) in place.
        data += value * 0.01
        numpy.mod(data, 1.0, data)
        src.update()
Beispiel #16
0
class Reactor(HasTraits):
    """Minimal reactor state model with two observable traits."""

    # Bounded float in [-273.0, 100000.0]; Range defaults to its low bound.
    # NOTE(review): -273.0 suggests degrees Celsius -- confirm the unit.
    core_temperature = Range(low=-273.0, high=100000.0)

    # Unbounded float, default 0.0.
    water_level = Float()
class SelectOutput(Filter):
    """This filter lets a user select one among several of the output blocks of
    a given input. This is very useful for a multi-block data source.
    """

    # The output index in the input to choose from.
    output_index = Range(value=0,
                         enter_set=True,
                         auto_set=False,
                         low='_min_index',
                         high='_max_index')

    input_info = PipelineInfo(datasets=['any'],
                              attribute_types=['any'],
                              attributes=['any'])

    output_info = PipelineInfo(datasets=['any'],
                               attribute_types=['any'],
                               attributes=['any'])

    # The minimum output index of our input.
    _min_index = Int(0, desc='the minimum output index')
    # The maximum output index of our input.
    _max_index = Int(0, desc='the maximum output index')

    # The dataset taken from the first input's first output.
    _my_input = Any(transient=True)

    ########################################
    # Traits View.

    view = View(Group(Item('output_index', enabled_when='_max_index > 0')),
                resizable=True)

    ######################################################################
    # `object` interface.
    def __get_pure_state__(self):
        d = super(SelectOutput, self).__get_pure_state__()
        d['output_index'] = self.output_index
        return d

    def __set_pure_state__(self, state):
        super(SelectOutput, self).__set_pure_state__(state)
        # Force an update of the output index -- if not this doesn't
        # change.
        self._output_index_changed(state.output_index)

    ######################################################################
    # `Filter` interface.
    def update_pipeline(self):
        """Recompute the selectable index range and rebuild the output."""
        # Do nothing if there is no input (or it has no outputs yet).
        inputs = self.inputs
        if len(inputs) == 0 or len(inputs[0].outputs) == 0:
            return

        # Set the maximum index.
        obj = get_new_output(inputs[0].outputs[0])
        self._my_input = obj
        if hasattr(obj, 'number_of_blocks'):
            self._max_index = obj.number_of_blocks - 1
        else:
            self._max_index = len(inputs[0].outputs) - 1
        self._output_index_changed(self.output_index)

    def update_data(self):
        # Propagate the event.
        self.data_changed = True

    ######################################################################
    # Trait handlers.
    def _setup_output(self, value):
        """Publish block/output ``value`` of the input as our sole output."""
        obj = self._my_input
        tp = tvtk.TrivialProducer()
        if hasattr(obj, 'number_of_blocks'):
            tp.set_output(obj.get_block(value))
        else:
            tp.set_output(self.inputs[0].outputs[value])
        self._set_outputs([tp])

    def _output_index_changed(self, value):
        """Static trait handler: clamp the index, then rebuild the output."""
        if value > self._max_index:
            clamped = self._max_index
        elif value < self._min_index:
            clamped = self._min_index
        else:
            clamped = value

        if clamped != value and clamped != self.output_index:
            # Assigning fires this handler again with the clamped value,
            # which then takes the set-up path below.
            self.output_index = clamped
            return

        # Either the value was already in range, or the clamped value equals
        # the current trait value (so assigning it would not notify); set up
        # the output directly in both cases.
        self._setup_output(clamped)
        s = self.scene
        if s is not None:
            s.renderer.reset_camera_clipping_range()
            s.render()
Beispiel #18
0
class VUMeter(Component):
    """An analog VU meter component drawn with Kiva.

    The needle rotates about a hinge point at or below the bottom edge of
    the component (see ``_beta``) and indicates ``percent`` on a curved
    axis.  dB ticks and labels are drawn along the outside of the axis,
    ``percent_ticks`` along the inside, and the part of the axis above
    100% is drawn in red.  Values above ``max_percent`` pin the needle at
    ``max_percent``.
    """

    # Value expressed in dB (computed from, and written through to,
    # `percent`).
    db = Property(Float)

    # Value expressed as a percent.
    percent = Range(low=0.0)

    # The maximum value displayed by the VU Meter, expressed as a percent.
    max_percent = Float(150.0)

    # Angle (in degrees) from a horizontal line through the hinge of the
    # needle to the edge of the meter axis.
    angle = Float(45.0)

    # Values of the percentage-based ticks; these are drawn and labeled along
    # the bottom of the curve axis.
    percent_ticks = List(list(range(0, 101, 20)))

    # Text to write in the middle of the VU Meter.
    text = Str("VU")

    # Font used to draw `text`.
    text_font = KivaFont("modern 48")

    # Font for the db tick labels.
    db_tick_font = KivaFont("modern 16")

    # Font for the percent tick labels.
    percent_tick_font = KivaFont("modern 12")

    # beta is the fraction of the needle that is "hidden".
    # beta == 0 puts the hinge point of the needle on the bottom
    # edge of the window.  Values that result in a decent looking
    # meter are 0 < beta < .65.
    # XXX needs a better name!
    _beta = Float(0.3)

    # _outer_radial_margin is the radial extent beyond the circular axis
    # to include in calculations of the space required for the meter.
    # This allows room for the ticks and labels.
    _outer_radial_margin = Float(60.0)

    # The angle (in radians) of the span of the curve axis.
    _phi = Property(Float, observe=["angle"])

    # This is the radius of the circular axis (in screen coordinates).
    _axis_radius = Property(Float, observe=["_phi", "width", "height"])

    # ---------------------------------------------------------------------
    # Trait Property methods
    # ---------------------------------------------------------------------

    def _get_db(self):
        """Return ``percent`` converted to dB."""
        return percent_to_db(self.percent)

    def _set_db(self, value):
        """Set ``percent`` from a value given in dB."""
        self.percent = db_to_percent(value)

    def _get__phi(self):
        """Return the angular span of the curved axis, in radians."""
        return math.pi * (180.0 - 2 * self.angle) / 180.0

    def _get__axis_radius(self):
        """Return the largest axis radius that fits the component.

        R1 makes the chord of the circle of radius ``R + M`` spanning
        ``phi`` exactly as wide as the component.  R2 makes the axis top
        plus the margin, measured from the hinge (which sits
        ``beta * R * cos(phi / 2)`` below the bottom edge), exactly as tall
        as the component.  The smaller of the two fits both ways.
        """
        M = self._outer_radial_margin
        beta = self._beta
        w = self.width
        h = self.height
        phi = self._phi

        R1 = w / (2 * math.sin(phi / 2)) - M
        R2 = (h - M) / (1 - beta * math.cos(phi / 2))
        return min(R1, R2)

    # ---------------------------------------------------------------------
    # Trait change handlers
    # ---------------------------------------------------------------------

    def _anytrait_changed(self):
        """Redraw whenever any trait changes."""
        self.request_redraw()

    # ---------------------------------------------------------------------
    # Component API
    # ---------------------------------------------------------------------

    def _draw_mainlayer(self, gc, view_bounds=None, mode="default"):
        """Draw ticks, labels, text, curved axes, hinge arc and needle."""
        beta = self._beta
        phi = self._phi

        M = self._outer_radial_margin
        R = self._axis_radius

        # (ox, oy) is the position of the "hinge point" of the needle
        # (i.e. the center of rotation).  For beta > ~0, oy is negative,
        # so this point is below the visible region.
        # NOTE(review): ox is offset by self.x but oy is not offset by
        # self.y -- confirm this is intended for components not at y == 0.
        ox = self.x + self.width // 2
        oy = -beta * R * math.cos(phi / 2) + 1

        left_theta = math.radians(180 - self.angle)
        right_theta = math.radians(self.angle)

        # The angle of the 100% position.
        nominal_theta = self._percent_to_theta(100.0)

        # The color of the axis for percent > 100.
        red = (0.8, 0, 0)

        with gc:
            gc.set_antialias(True)

            # Draw everything relative to the center of the circles.
            gc.translate_ctm(ox, oy)

            # Draw the primary ticks and tick labels on the curved axis.
            gc.set_fill_color((0, 0, 0))
            gc.set_font(self.db_tick_font)
            for db in [-20, -10, -7, -5, -3, -2, -1, 0, 1, 2, 3]:
                theta = self._db_to_theta(db)
                self._draw_radial_tick(gc, theta, R, R + 0.3 * M, 2.5)

                text = str(db)
                if db > 0:
                    text = "+" + text
                self._draw_rotated_label(gc, text, theta, R + 0.4 * M)

            # Draw the secondary (unlabeled) ticks on the curve axis.
            for db in [-15, -9, -8, -6, -4, -0.5, 0.5]:
                theta = self._db_to_theta(db)
                self._draw_radial_tick(gc, theta, R, R + 0.2 * M, 1.0)

            # Draw the percent ticks and labels on the inside of the
            # curved axis; the last label gets a "%" suffix.
            gc.set_font(self.percent_tick_font)
            gc.set_fill_color((0.5, 0.5, 0.5))
            gc.set_stroke_color((0.5, 0.5, 0.5))
            percents = self.percent_ticks
            for tick_percent in percents:
                theta = self._percent_to_theta(tick_percent)
                self._draw_radial_tick(gc, theta, R - 0.15 * M, R, 2.0)

                text = str(tick_percent)
                if tick_percent == percents[-1]:
                    text = text + "%"
                self._draw_rotated_label(gc, text, theta, R - 0.3 * M)

            if self.text:
                gc.set_font(self.text_font)
                _, _, tw, _ = gc.get_text_extent(self.text)
                gc.set_fill_color((0, 0, 0, 0.25))
                gc.set_text_matrix(affine.affine_from_rotation(0))
                gc.set_text_position(-0.5 * tw, (0.75 * beta + 0.25) * R)
                gc.show_text(self.text)

            # Draw the red curved axis (from max_percent down to 100%).
            gc.set_stroke_color(red)
            red_width = 10
            gc.set_line_width(red_width)
            gc.arc(0, 0, R + 0.5 * red_width - 1, right_theta, nominal_theta)
            gc.stroke_path()

            # Draw the black curved axis (from 100% down to 0%).
            black_width = 4
            gc.set_line_width(black_width)
            gc.set_stroke_color((0, 0, 0))
            gc.arc(0, 0, R + 0.5 * black_width - 1, nominal_theta, left_theta)
            gc.stroke_path()

            # Draw the outlined, translucent-filled arc around the hinge.
            gc.set_line_width(2)
            gc.set_stroke_color((0, 0, 0))
            gc.arc(0, 0, beta * R, right_theta, left_theta)
            gc.stroke_path()
            gc.set_fill_color((0, 0, 0, 0.25))
            gc.arc(0, 0, beta * R, right_theta, left_theta)
            gc.fill_path()

            # Draw the needle; values above max_percent are drawn at
            # max_percent.
            percent = min(self.percent, self.max_percent)
            needle_theta = self._percent_to_theta(percent)
            gc.rotate_ctm(needle_theta - 0.5 * math.pi)
            self._draw_vertical_needle(gc)

    # ---------------------------------------------------------------------
    # Private methods
    # ---------------------------------------------------------------------

    def _draw_radial_tick(self, gc, theta, r_inner, r_outer, line_width):
        """Stroke a tick along the ray at ``theta`` from r_inner to r_outer."""
        cos_t = math.cos(theta)
        sin_t = math.sin(theta)
        gc.set_line_width(line_width)
        gc.move_to(r_inner * cos_t, r_inner * sin_t)
        gc.line_to(r_outer * cos_t, r_outer * sin_t)
        gc.stroke_path()

    def _draw_vertical_needle(self, gc):
        """ Draw the needle of the meter, pointing straight up.

        The needle is a thick stroke from ``beta * R`` up to a blob, then a
        tapered filled shape from the blob to a tip slightly beyond the
        axis radius.
        """
        beta = self._beta
        R = self._axis_radius
        end_y = beta * R
        blob_y = R - 0.6 * self._outer_radial_margin
        tip_y = R + 0.2 * self._outer_radial_margin
        lw = 5

        with gc:
            gc.set_alpha(1)
            gc.set_fill_color((0, 0, 0))

            # Draw the needle from the bottom to the blob.
            gc.set_line_width(lw)
            gc.move_to(0, end_y)
            gc.line_to(0, blob_y)
            gc.stroke_path()

            # Draw the thin part of the needle from the blob to the tip.
            gc.move_to(lw, blob_y)
            control_y = blob_y + 0.25 * (tip_y - blob_y)
            gc.quad_curve_to(0.2 * lw, control_y, 0, tip_y)
            gc.quad_curve_to(-0.2 * lw, control_y, -lw, blob_y)
            gc.line_to(lw, blob_y)
            gc.fill_path()

            # Draw the blob on the needle.
            gc.arc(0, blob_y, 6.0, 0, 2 * math.pi)
            gc.fill_path()

    def _draw_rotated_label(self, gc, text, theta, radius):
        """Draw ``text`` tangent to the circle of ``radius``.

        The text is rotated so its baseline is perpendicular to the ray at
        angle ``theta``, and its start point is offset by half the text
        width so the text is centered on that ray.
        """
        _, _, tw, _ = gc.get_text_extent(text)

        # Polar coordinates of the baseline start point (-tw/2, radius) in
        # the frame where the ray points straight up.
        rr = math.sqrt(radius ** 2 + (0.5 * tw) ** 2)
        dtheta = math.atan2(0.5 * tw, radius)
        text_theta = theta + dtheta
        x = rr * math.cos(text_theta)
        y = rr * math.sin(text_theta)

        rot_theta = theta - 0.5 * math.pi
        with gc:
            gc.set_text_matrix(affine.affine_from_rotation(rot_theta))
            gc.set_text_position(x, y)
            gc.show_text(text)

    def _percent_to_theta(self, percent):
        """ Convert percent to the angle theta, in radians.

        theta is the angle of the needle measured counterclockwise from
        the horizontal (i.e. the traditional angle of polar coordinates).
        ``max_percent`` maps to ``angle`` degrees (right end of the axis)
        and 0 maps to ``180 - angle`` degrees (left end).
        """
        angle = (
            self.angle
            + (180.0 - 2 * self.angle)
            * (self.max_percent - percent)
            / self.max_percent
        )
        return math.radians(angle)

    def _db_to_theta(self, db):
        """ Convert db to the angle theta, in radians. """
        return self._percent_to_theta(db_to_percent(db))
Beispiel #19
0
class Test(HasTraits):
    # Type-checker fixture exercising the ``low`` argument of ``Range``.
    # NOTE(review): the trailing "# E: arg-type" marker presumably tells the
    # stub test harness to expect an arg-type error on that exact line; keep
    # the marker on the same line as the offending assignment.
    var = Range(low=[])  # E: arg-type
    # A str bound presumably passes (string bounds name another trait,
    # i.e. a dynamic range) -- no error expected.
    var2 = Range(low="3")
    # A plain int bound -- no error expected.
    var3 = Range(low=3)
Beispiel #20
0
class RemapDemo(ImageProcessDemo):
    """Interactive image warping with ``cv2.remap``.

    Press the left button inside the axes, drag, and release: inside a
    disc of ``radius`` pixels around the current mouse position the image
    samples the pixels found at the press position, with the displacement
    smoothed by a Gaussian of standard deviation ``sigma``.  Releasing
    commits the warp.  Scrolling changes ``radius`` (``sigma`` while shift
    is held); a right click restores the previous image from a history of
    at most 10 snapshots.
    """
    TITLE = "Remap Demo"
    DEFAULT_IMAGE = "lena.jpg"

    offsetx = Array  # per-pixel x displacement added to gridx for remap
    offsety = Array  # per-pixel y displacement added to gridy for remap
    gridx = Array  # x coordinate of every pixel (float32)
    gridy = Array  # y coordinate of every pixel (float32)
    need_redraw = Event
    need_update_map = Bool  # when True, the next draw() commits the warp
    radius = Range(10, 50, 20)
    sigma = Range(10, 50, 30)
    history = List()  # undo stack of previous images
    history_len = Property(depends_on="history")

    def __init__(self, **kw):
        super(RemapDemo, self).__init__(**kw)
        self.figure.canvas_events = [
            ("button_press_event", self.on_figure_press),
            ("button_release_event", self.on_figure_release),
            ("motion_notify_event", self.on_figure_motion),
            ('scroll_event', self.on_scroll),
        ]
        self.connect_dirty("need_redraw")
        # target_pos: where the drag started; source_pos: where it is now.
        self.target_pos = self.source_pos = None

    def control_panel(self):
        """Return the settings panel shown next to the figure."""
        return VGroup(
            "radius",
            "sigma",
            Item("history_len", style="readonly"),
        )

    def _get_history_len(self):
        return len(self.history)

    def on_scroll(self, event):
        """Adjust ``sigma`` (shift held) or ``radius`` with the wheel.

        Values outside the Range bounds raise TraitError and are ignored.
        """
        try:
            if event.key == "shift":
                self.sigma = int(self.sigma + event.step)
            else:
                self.radius = int(self.radius + event.step)
        except TraitError:
            pass

    def on_figure_press(self, event):
        """Start a drag (left button in axes) or undo (right button)."""
        if event.inaxes is self.axe and event.button == 1:
            self.target_pos = int(event.xdata), int(event.ydata)
        elif event.button == 3:
            if self.history:
                self.img[:] = self.history.pop()[:]
                self.need_redraw = True

    def on_figure_release(self, event):
        """Commit the warp when a drag ends inside the axes.

        A left-button release with no drag started inside the axes
        (``target_pos`` is None) is ignored; committing it would only push
        an unchanged snapshot onto the bounded undo history and evict a
        real one.
        """
        if event.inaxes is not self.axe or event.button != 1:
            return
        if self.target_pos is None:
            self.source_pos = None
            return
        self.source_pos = int(event.xdata), int(event.ydata)
        self.need_redraw = True
        self.need_update_map = True

    def on_figure_motion(self, event):
        """Update the warp preview while dragging inside the axes."""
        if event.inaxes is self.axe and event.button == 1:
            self.source_pos = int(event.xdata), int(event.ydata)
            self.need_redraw = True

    def _img_changed(self):
        """Reallocate the displacement and coordinate grids for a new image."""
        rows, cols = self.img.shape[:2]
        self.offsetx = np.zeros((rows, cols), dtype=np.float32)
        self.offsety = np.zeros((rows, cols), dtype=np.float32)
        self.gridy, self.gridx = np.mgrid[:rows, :cols]
        self.gridx = self.gridx.astype(np.float32)
        self.gridy = self.gridy.astype(np.float32)

    def draw(self):
        """Render the (possibly warped) image; commit it if requested."""
        self.offsetx.fill(0)
        self.offsety.fill(0)

        if self.source_pos is not None and self.target_pos is not None:
            sx, sy = self.source_pos
            tx, ty = self.target_pos
            # Displace every pixel within `radius` of the current position
            # so it samples from the drag start, then smooth the field.
            mask = ((self.gridx - sx)**2 +
                    (self.gridy - sy)**2) < self.radius**2
            self.offsetx[mask] = tx - sx
            self.offsety[mask] = ty - sy
            cv2.GaussianBlur(self.offsetx, (0, 0), self.sigma, self.offsetx)
            cv2.GaussianBlur(self.offsety, (0, 0), self.sigma, self.offsety)

        img2 = cv2.remap(self.img, self.offsetx + self.gridx,
                         self.offsety + self.gridy, cv2.INTER_LINEAR)
        self.draw_image(img2)

        if self.need_update_map:
            # Snapshot the current image for undo, keeping at most 10.
            self.history.append(self.img.copy())
            if len(self.history) > 10:
                del self.history[0]
            self.img[:] = img2[:]
            self.source_pos = self.target_pos = None
            self.need_update_map = False
Beispiel #21
0
class VolumeFactory(PipeFactory):
    """ Applies the Volume mayavi module to the given VTK data
        source (Mayavi source, or VTK dataset).

        **Note**

        The range of the colormap can be changed simply using the
        vmin/vmax parameters (see below). For more complex modifications of
        the colormap, here is some pseudo code to change the ctf (color
        transfer function), or the otf (opacity transfer function)::

            vol = mlab.pipeline.volume(src)

            # Changing the ctf:
            from tvtk.util.ctf import ColorTransferFunction
            ctf = ColorTransferFunction()
            ctf.add_rgb_point(value, r, g, b)  # r, g, and b are float
                                               # between 0 and 1
            ctf.add_hsv_point(value, h, s, v)
            # ...
            vol._volume_property.set_color(ctf)
            vol._ctf = ctf
            vol.update_ctf = True

            # Changing the otf:
            from tvtk.util.ctf import PiecewiseFunction
            otf = PiecewiseFunction()
            otf.add_point(value, opacity)
            vol._otf = otf
            vol._volume_property.set_scalar_opacity(otf)

        Also, it might be useful to change the range of the ctf::

            ctf.range = [0, 1]
    """

    color = Trait(
        None,
        None,
        TraitTuple(Range(0., 1.), Range(0., 1.), Range(0., 1.)),
        help="""the color of the vtk object. Overides the colormap,
                        if any, when specified. This is specified as a
                        triplet of float ranging from 0 to 1, eg (1, 1,
                        1) for white.""",
    )

    vmin = Trait(None,
                 None,
                 CFloat,
                 help="""vmin is used to scale the transparency
                            gradient. If None, the min of the data will be
                            used""")

    vmax = Trait(None,
                 None,
                 CFloat,
                 help="""vmax is used to scale the transparency
                            gradient. If None, the max of the data will be
                            used""")

    _target = Instance(modules.Volume, ())

    # Set once the existing color transfer function has been stretched onto
    # [vmin, vmax], so that this is done at most once.
    __ctf_rescaled = Bool(False)

    ######################################################################
    # Non-public interface.
    ######################################################################
    def _make_ctf(self, range_min, range_max):
        """Return an empty ColorTransferFunction spanning the given range."""
        from tvtk.util.ctf import ColorTransferFunction
        ctf = ColorTransferFunction()
        try:
            ctf.range = (range_min, range_max)
        except Exception:
            # VTK versions < 5.2 don't seem to need this.
            pass
        return ctf

    def _color_changed(self):
        """Replace the color transfer function by the uniform ``color``.

        Does nothing while ``color`` is None.
        """
        if not self.color:
            return
        range_min, range_max = self._target.current_range
        ctf = self._make_ctf(range_min, range_max)

        r, g, b = self.color
        ctf.add_rgb_point(range_min, r, g, b)
        ctf.add_rgb_point(range_max, r, g, b)

        self._target._ctf = ctf
        self._target._volume_property.set_color(ctf)
        self._target.update_ctf = True

    def _vmin_changed(self):
        """Rebuild the transfer functions after ``vmin`` or ``vmax`` changes.

        Unset (None) bounds fall back to ``self._target.current_range``.
        The opacity ramps linearly from 0 at ``vmin`` to 0.2 at ``vmax`` and
        is flat outside that interval.  The first time an explicit bound is
        given while no uniform ``color`` is set, the existing color transfer
        function is also stretched onto ``[vmin, vmax]``.
        """
        vmin = self.vmin
        vmax = self.vmax
        range_min, range_max = self._target.current_range
        if vmin is None:
            vmin = range_min
        if vmax is None:
            vmax = range_max

        # Change the opacity function
        from tvtk.util.ctf import PiecewiseFunction, save_ctfs

        otf = PiecewiseFunction()
        if range_min < vmin:
            otf.add_point(range_min, 0.)
        if range_max > vmax:
            otf.add_point(range_max, 0.2)
        otf.add_point(vmin, 0.)
        otf.add_point(vmax, 0.2)
        self._target._otf = otf
        self._target._volume_property.set_scalar_opacity(otf)

        has_bound = (self.vmin is not None) or (self.vmax is not None)
        if self.color is None and not self.__ctf_rescaled and has_bound:
            # FIXME: We don't use 'rescale_ctfs' because it screws up the
            # nodes.

            def _rescale_value(x):
                # Map a value of the data range onto [vmin, vmax].  A
                # degenerate data range maps everything onto vmin instead
                # of dividing by zero.
                span = range_max - range_min
                nx = (x - range_min) / span if span else 0.0
                return vmin + nx * (vmax - vmin)

            # The range of the existing ctf can vary.
            scale_min, scale_max = self._target._ctf.range

            def _rescale_node(x):
                # Map a node of the existing ctf onto the data range.
                span = scale_max - scale_min
                nx = (x - scale_min) / span if span else 0.0
                return range_min + nx * (range_max - range_min)

            old_ctf = self._target._ctf
            if hasattr(old_ctf, 'nodes'):
                rgb = list()
                for value in old_ctf.nodes:
                    r, g, b = old_ctf.get_color(value)
                    rgb.append((_rescale_node(value), r, g, b))
            else:
                rgb = save_ctfs(self._target.volume_property)['rgb']

            # Nothing to rescale if the existing ctf has no color points.
            if rgb:
                ctf = self._make_ctf(range_min, range_max)
                rgb.sort()
                first, last = rgb[0], rgb[-1]
                # Pad the data-range ends with the first/last colors so the
                # ctf still covers the whole data range.
                ctf.add_rgb_point(range_min, first[1], first[2], first[3])
                for v in rgb:
                    ctf.add_rgb_point(_rescale_value(v[0]), v[1], v[2], v[3])
                ctf.add_rgb_point(range_max, last[1], last[2], last[3])

                self._target._ctf = ctf
                self._target._volume_property.set_color(ctf)
                self.__ctf_rescaled = True

        self._target.update_ctf = True

    # Changing vmax must rebuild the transfer functions exactly like
    # changing vmin does; without this static handler a vmax-only change
    # had no effect.
    _vmax_changed = _vmin_changed
Beispiel #22
0
class A(HasTraits):
    # Plain integer trait.
    i = Int
    # Integer range whose upper bound is 2**63 - 1, the largest signed
    # 64-bit integer.
    r = Range(2, 2 ** 63 - 1)
Beispiel #23
0
from ..editor_factory import EditorFactory

from ..toolkit import toolkit_object

# Currently, this trait is used only for the wx backend.
from ..helper import DockStyle

# -------------------------------------------------------------------------
#  Trait definitions:
# -------------------------------------------------------------------------

# Trait whose value is a BaseTraitHandler instance.
handler_trait = Instance(BaseTraitHandler)

# The visible number of rows displayed: 1 to 50, default 5.
rows_trait = Range(
    low=1, high=50, value=5, desc="the number of list rows to display"
)

# The visible number of columns displayed: 1 to 10, default 1.
columns_trait = Range(
    low=1, high=10, value=1, desc="the number of list columns to display"
)

# Trait whose value is an EditorFactory instance.
editor_trait = Instance(EditorFactory)

# -------------------------------------------------------------------------
#  'ToolkitEditorFactory' class:
# -------------------------------------------------------------------------


class ToolkitEditorFactory(EditorFactory):
    """ Editor factory for list editors.

    Subclasses :class:`EditorFactory`; this class body defines no
    additional traits or methods.
    """
class MultiLinePlotDemo(HasTraits):
    """Demonstrates the MultiLinePlot.

    The 2D data to plot comes from ``model``, a DataModel instance that is
    expected to be passed to the constructor and never replaced.  The
    ``amplitude`` and ``offset`` sliders drive the renderer's
    ``normalized_amplitude`` and ``offset``.
    """

    model = Instance(DataModel)

    plot = Instance(Plot)

    multi_line_plot_renderer = Instance(MultiLinePlot)

    # Drives multi_line_plot_renderer.normalized_amplitude
    amplitude = Range(-1.5, 1.5, value=-0.5)

    # Drives multi_line_plot_renderer.offset
    offset = Range(-1.0, 1.0, value=0)

    traits_view = View(
        VGroup(
            Group(
                Item('plot', editor=ComponentEditor(), show_label=False),
            ),
            HGroup(
                Item('amplitude', springy=True),
                Item('offset', springy=True),
                springy=True,
            ),
            HGroup(
                Item('object.multi_line_plot_renderer.color', springy=True),
                Item('object.multi_line_plot_renderer.line_style',
                     springy=True),
                springy=True,
            ),
        ),
        width=800,
        height=500,
        resizable=True,
    )

    #-----------------------------------------------------------------------
    # Trait defaults
    #-----------------------------------------------------------------------

    def _multi_line_plot_renderer_default(self):
        """Build the MultiLinePlot renderer for the model's data."""
        model = self.model

        x_source = ArrayDataSource(model.x_index, sort_order='ascending')
        x_range = DataRange1D()
        x_range.add(x_source)

        y_source = ArrayDataSource(model.y_index, sort_order='ascending')
        y_range = DataRange1D()
        y_range.add(y_source)

        # The data source for the MultiLinePlot.
        value_source = MultiArrayDataSource(data=model.data)

        return MultiLinePlot(
            index=x_source,
            yindex=y_source,
            index_mapper=LinearMapper(range=x_range),
            value_mapper=LinearMapper(range=y_range),
            value=value_source,
            global_max=model.data.max(),
            global_min=model.data.min(),
        )

    def _plot_default(self):
        """Build the Plot holding the renderer and its two axes."""
        plot = Plot(title="MultiLinePlot Demo")
        renderer = self.multi_line_plot_renderer
        plot.add(renderer)

        x_axis = PlotAxis(
            component=plot,
            mapper=renderer.index_mapper,
            orientation='bottom',
            title='t (seconds)',
        )
        y_axis = PlotAxis(
            component=plot,
            mapper=renderer.value_mapper,
            orientation='left',
            title='channel',
        )
        plot.overlays.extend([x_axis, y_axis])
        return plot

    #-----------------------------------------------------------------------
    # Trait change handlers
    #-----------------------------------------------------------------------

    def _amplitude_changed(self, amp):
        self.multi_line_plot_renderer.normalized_amplitude = amp

    def _offset_changed(self, off):
        renderer = self.multi_line_plot_renderer
        renderer.offset = off
        # FIXME:  The change does not trigger a redraw.  Force a redraw by
        # faking an amplitude change.
        renderer._amplitude_changed()
class FieldViewer(HasTraits):
    """Interactive viewer for a 3D scalar field.

    The expression ``function`` (in terms of x, y, z and ``np``) is
    evaluated on a ``points`` x ``points`` x ``points`` grid spanning
    [x0, x1] x [y0, y1] x [z0, z1].  Pressing the plot button draws its
    iso-surfaces plus a scalar cut plane parallel to the X-Y plane.  When
    ``autocontour`` is off, a single iso-surface is shown at ``contour``,
    whose slider range [v0, v1] is set to the field's min/max on each plot.
    """

    # Value ranges of the three axes.
    x0, x1 = Float(-5), Float(5)
    y0, y1 = Float(-5), Float(5)
    z0, z1 = Float(-5), Float(5)
    points = Int(50)  # number of grid points along each axis
    autocontour = Bool(True)  # whether iso-surface values are automatic
    v0, v1 = Float(0.0), Float(1.0)  # allowed range of the iso-surface value
    contour = Range("v0", "v1", 0.5)  # iso-surface value
    function = Str("x*x*0.5 + y*y + z*z*2.0")  # scalar field expression
    # Preset expressions offered by the `function` editor.
    function_list = [
        "x*x*0.5 + y*y + z*z*2.0", "x*y*0.5 + np.sin(2*x)*y +y*z*2.0", "x*y*z",
        "np.sin((x*x+y*y)/z)"
    ]
    plotbutton = Button("描画")
    scene = Instance(MlabSceneModel, ())  #❶

    view = View(
        HSplit(
            VGroup(
                "x0",
                "x1",
                "y0",
                "y1",
                "z0",
                "z1",
                Item('points', label="点数"),
                Item('autocontour', label="自动等值"),
                Item('plotbutton', show_label=False),
            ),
            VGroup(
                Item(
                    'scene',
                    editor=SceneEditor(scene_class=MayaviScene),  #❷
                    resizable=True,
                    height=300,
                    width=350),
                Item('function',
                     editor=EnumEditor(name='function_list',
                                       evaluate=lambda x: x)),
                Item('contour',
                     editor=RangeEditor(format="%1.2f",
                                        low_name="v0",
                                        high_name="v1")),
                show_labels=False)),
        width=500,
        resizable=True,
        title="三维标量场观察器")

    def _plotbutton_fired(self):
        self.plot()

    def plot(self):
        """(Re)draw the scalar field for the current settings."""
        # Build the 3D grid; the imaginary step makes mgrid return
        # `points` samples per axis, end points included.
        x, y, z = np.mgrid[  #❸
            self.x0:self.x1:1j * self.points, self.y0:self.y1:1j * self.points,
            self.z0:self.z1:1j * self.points]

        # Evaluate the scalar field on the grid.
        # SECURITY: `function` comes from an editor the user can change, and
        # eval() gives it full access to this scope and the module globals.
        # Only use this viewer with trusted input.
        scalars = eval(self.function)  #❹
        self.scene.mlab.clf()  # clear the current scene

        # Draw the iso-surfaces.
        g = self.scene.mlab.contour3d(x,
                                      y,
                                      z,
                                      scalars,
                                      contours=8,
                                      transparent=True)  #❺
        g.contour.auto_contours = self.autocontour
        self.scene.mlab.axes(figure=self.scene.mayavi_scene)  # add axes

        # Add a cut plane parallel to the X-Y plane through the box center.
        s = self.scene.mlab.pipeline.scalar_cut_plane(g)
        cutpoint = (
            (self.x0 + self.x1) / 2,
            (self.y0 + self.y1) / 2,
            (self.z0 + self.z1) / 2,
        )
        s.implicit_plane.normal = (0, 0, 1)  # normal along z: an X-Y cut
        s.implicit_plane.origin = cutpoint

        self.g = g  #❻
        self.scalars = scalars
        # Use the field's value range as the range of the contour slider.
        self.v0 = np.min(scalars)
        self.v1 = np.max(scalars)

        # The new pipeline does not know the user's iso-surface value yet;
        # apply it when automatic contours are off (as
        # _autocontour_changed does when the option is switched off).
        self._contour_changed()

    def _contour_changed(self):  #❼
        """Show a single iso-surface at `contour` when auto mode is off."""
        if hasattr(self, "g"):
            if not self.g.contour.auto_contours:
                self.g.contour.contours = [self.contour]

    def _autocontour_changed(self):  #❽
        """Toggle automatic iso-surfaces; re-apply `contour` when off."""
        if hasattr(self, "g"):
            self.g.contour.auto_contours = self.autocontour
            if not self.autocontour:
                self._contour_changed()
Beispiel #26
0
class Visualization(HasTraits):
    """Mayavi view of a surface computed by ``estimate`` for ``lmax``.

    The surface is sampled on a 1001 x 1001 image grid, clipped to
    [0, 255], and recomputed whenever the ``lmax`` slider changes.  The
    input samples are drawn as blue spheres for comparison.
    """
    lmax = Range(0, 30, 6)
    scene = Instance(MlabSceneModel, ())

    # the layout of the dialog created
    view = View(
        Item('scene',
             editor=SceneEditor(scene_class=MayaviScene),
             height=250,
             width=300,
             show_label=False),
        HGroup(
            '_',
            'lmax',
        ),
    )

    def __init__(self, x, y, z, theta, phi, THETAnorm, PHInorm, mask):
        """Store the samples and draw the initial scene.

        ``x`` and ``y`` are accepted but not used; the sample positions
        shown in the scene are derived from ``theta`` and ``phi``.
        ``mask`` marks the image pixels to show (it is inverted before
        being passed to ``mlab.mesh``).
        """
        # Do not forget to call the parent's __init__
        HasTraits.__init__(self)

        self.theta = theta
        self.phi = phi
        self.THETAnorm = THETAnorm
        self.PHInorm = PHInorm
        self.z = z
        self.mask = mask

        #
        # Visualize the results.
        #
        self.Ximg, self.Yimg = np.mgrid[0:1001, 0:1001]

        # Use the same clipped estimate as update_plot so the initial view
        # matches what the slider produces.
        self.mesh = self.scene.mlab.mesh(self.Ximg,
                                         self.Yimg,
                                         self._estimate_surface(),
                                         mask=~mask)

        # Map (theta, phi) of each sample onto the 1001 x 1001 image grid.
        x_undistort = (theta / (np.pi / 2) * np.sin(phi) + 1) * 1001 / 2
        y_undistort = (theta / (np.pi / 2) * np.cos(phi) + 1) * 1001 / 2
        self.scene.mlab.points3d(x_undistort,
                                 y_undistort,
                                 self.z,
                                 mode='sphere',
                                 scale_mode='none',
                                 scale_factor=5,
                                 color=(0, 0, 1))

        self.scene.mlab.outline(color=(0, 0, 0),
                                extent=(0, 1001, 0, 1001, 0, 255))

    def _estimate_surface(self):
        """Return ``estimate`` for the current ``lmax``, clipped to [0, 255]."""
        Zimg = estimate(self.z, self.theta, self.phi, self.THETAnorm,
                        self.PHInorm, self.lmax)
        return np.clip(Zimg, 0, 255)

    @on_trait_change('lmax')
    def update_plot(self):
        """Recompute the surface for the new ``lmax`` and update the mesh."""
        self.mesh.mlab_source.set(x=self.Ximg,
                                  y=self.Yimg,
                                  z=self._estimate_surface(),
                                  mask=~self.mask)
class A(HasTraits):
    # Integer traits of both flavors.
    i = Int
    l = Long
    # Range with long-integer bounds; the upper bound is 2**63 - 1, the
    # largest signed 64-bit integer.
    r = Range(LONG_TYPE(2), LONG_TYPE(2 ** 63 - 1))
Beispiel #28
0
class FileDataSource(Source):
    """Source backed by a single file or by a time series of files.

    Assigning ``base_file_name`` calls ``get_file_list`` on it to fill
    ``file_list`` (falling back to ``[base_file_name]`` when that returns
    nothing).  ``timestep`` indexes into ``file_list`` and the selected entry
    is exposed as ``file_path``.  ``play`` / ``play_delay`` / ``loop`` advance
    ``timestep`` from a timer.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The list of file names for the timeseries.
    file_list = List(Str, desc='a list of files belonging to a time series')

    # The current time step (starts with 0).  This trait is a dummy
    # and is dynamically changed when the `file_list` trait changes.
    # This is done so the timestep bounds are linked to the number of
    # the files in the file list.
    timestep = Range(value=0,
                     low='_min_timestep',
                     high='_max_timestep',
                     enter_set=True, auto_set=False,
                     desc='the current time step')

    sync_timestep = Bool(False, desc='if all dataset timesteps are synced')

    play = Bool(False, desc='if timesteps are automatically updated')
    play_delay = Float(0.2, desc='the delay between loading files')
    loop = Bool(False, desc='if animation is looped')

    update_files = Button('Rescan files')

    base_file_name=Str('', desc="the base name of the file",
                       enter_set=True, auto_set=False,
                       editor=FileEditor())

    # A timestep view group that may be included by subclasses.
    time_step_group = Group(
                          Item(name='file_path', style='readonly'),
                          Group(
                              Item(name='timestep',
                                   editor=RangeEditor(
                                       low=0, high_name='_max_timestep',
                                       mode='slider'
                                   ),
                              ),
                              Item(name='sync_timestep'),
                              HGroup(
                                  Item(name='play'),
                                  Item(name='play_delay',
                                       label='Delay'),
                                  Item(name='loop'),
                              ),
                              visible_when='len(object.file_list) > 1'
                          ),
                          Item(name='update_files', show_label=False),
                      )

    ##################################################
    # Private traits.
    ##################################################

    # The current file name.  This is not meant to be touched by the
    # user.
    file_path = Instance(FilePath, (), desc='the current file name')

    _min_timestep = Int(0)
    _max_timestep = Int(0)
    # Playback timer; None whenever playback is not active.
    _timer = Any
    # Re-entrancy guard for _update_files_fired.
    _in_update_files = Any(False)

    ######################################################################
    # `object` interface
    ######################################################################
    def __get_pure_state__(self):
        """Return the persisted state minus the dynamically derived traits."""
        d = super(FileDataSource, self).__get_pure_state__()
        # These are obtained dynamically, so don't pickle them.
        for x in ['file_list', 'timestep', 'play']:
            d.pop(x, None)
        return d

    def __set_pure_state__(self, state):
        """Restore from ``state``, re-deriving the file list from its path.

        Raises IOError if the saved file no longer exists.
        """
        # Use the saved path to initialize the file_list and timestep.
        fname = state.file_path.abs_pth
        if not isfile(fname):
            msg = 'Could not find file at %s\n'%fname
            msg += 'Please move the file there and try again.'
            raise IOError(msg)

        self.initialize(fname)
        # Now set the remaining state without touching the children.
        set_state(self, state, ignore=['children', 'file_path'])
        # Setup the children.
        handle_children_state(self.children, state.children)
        # Setup the children's state.
        set_state(self, state, first=['children'], ignore=['*'])

    ######################################################################
    # `FileDataSource` interface
    ######################################################################
    def initialize(self, base_file_name):
        """Given a single filename which may or may not be part of a
        time series, this initializes the list of files.  This method
        need not be called to initialize the data.
        """
        self.base_file_name = base_file_name

    ######################################################################
    # Non-public interface
    ######################################################################
    def _file_list_changed(self, value):
        """Clamp ``timestep`` into the new list and update its upper bound."""
        # Change the range of the timestep suitably to reflect new list.
        n_files = len(self.file_list)
        timestep = max(min(self.timestep, n_files-1), 0)
        if self.timestep == timestep:
            # Same index, but the file behind it may differ: force a reload.
            self._timestep_changed(timestep)
        else:
            self.timestep = timestep
        self._max_timestep = max(n_files -1, 0)

    def _file_list_items_changed(self, list_event):
        """Treat in-place edits of ``file_list`` like a full replacement."""
        self._file_list_changed(self.file_list)

    def _timestep_changed(self, value):
        """Point ``file_path`` at the selected file; propagate if synced."""
        file_list = self.file_list
        if len(file_list) > 0:
            self.file_path = FilePath(file_list[value])
        else:
            self.file_path = FilePath('')
        if self.sync_timestep:
            for sibling in self._find_sibling_datasets():
                sibling.timestep = value

    def _base_file_name_changed(self, value):
        """Rescan files and select ``value`` in the list (index 0 if absent)."""
        self._update_files_fired()
        try:
            self.timestep = self.file_list.index(value)
        except ValueError:
            self.timestep = 0

    def _play_changed(self, value):
        """Start or stop the playback timer when ``play`` toggles."""
        mm = getattr(self.scene, 'movie_maker', None)
        if value:
            if mm is not None:
                mm.animation_start()
            # Stop any timer still held (e.g. when _loop_changed restarts
            # playback) so two timers never drive ``timestep`` at once.
            if self._timer is not None:
                self._timer.Stop()
            self._timer = self._make_play_timer()
            if not self._timer.IsRunning():
                self._timer.Start()
        else:
            if self._timer is not None:
                self._timer.Stop()
                self._timer = None
            if mm is not None:
                mm.animation_stop()

    def _loop_changed(self, value):
        """Restart playback when looping is enabled while ``play`` is on."""
        if value and self.play:
            self._play_changed(self.play)

    def _play_event(self):
        """Timer callback: advance one step, wrapping or stopping at the end."""
        mm = getattr(self.scene, 'movie_maker', None)
        nf = self._max_timestep
        pc = self.timestep
        pc += 1
        if pc > nf:
            if self.loop:
                pc = 0
            else:
                self._timer.Stop()
                pc = nf
                if mm is not None:
                    mm.animation_stop()
        if pc != self.timestep:
            self.timestep = pc
            if mm is not None:
                mm.animation_step()

    def _play_delay_changed(self):
        """Restart a running timer with the new delay (seconds -> ms)."""
        if self.play and self._timer is not None:
            self._timer.Stop()
            self._timer.Start(self.play_delay*1000)

    def _make_play_timer(self):
        """Create the playback timer (interval ``play_delay`` in ms).

        Uses ``NoUITimer`` without a scene or with off-screen rendering,
        otherwise a pyface ``Timer``.
        """
        scene = self.scene
        if scene is None or scene.off_screen_rendering:
            timer = NoUITimer(self.play_delay*1000, self._play_event)
        else:
            from pyface.timer.api import Timer
            timer = Timer(self.play_delay*1000, self._play_event)
        return timer

    def _find_sibling_datasets(self):
        """Return the parent's children with the same ``_max_timestep``.

        The result includes ``self``.  Children lacking ``_max_timestep``
        (not file-backed sources) are skipped rather than raising.
        """
        if self.parent is not None:
            nt = self._max_timestep
            return [x for x in self.parent.children
                    if getattr(x, '_max_timestep', None) == nt]
        else:
            return []

    def _update_files_fired(self):
        """Rebuild ``file_list`` from ``base_file_name``; rescan synced siblings."""
        # First get all the siblings before we change the current file list.
        if self._in_update_files:
            return
        try:
            self._in_update_files = True
            if self.sync_timestep:
                siblings = self._find_sibling_datasets()
            else:
                siblings = []
            fname = self.base_file_name
            file_list = get_file_list(fname)
            if len(file_list) == 0:
                file_list = [fname]
            self.file_list = file_list
            for sibling in siblings:
                sibling.update_files = True
        finally:
            self._in_update_files = False
Beispiel #29
0
class Catenary(HasTraits):
    """Chain of ``N`` point masses joined by springs between two fixed ends.

    The state vector ``status`` concatenates x, y, vx and vy (each of length
    ``N``).  The two end masses never move.  Each interior mass gets a spring
    term from both neighbours (stiffness ``k``, segment rest length
    ``length / (N - 1)``), a constant ``g`` added to its y acceleration, and
    velocity damping proportional to ``dump``.
    """

    N = Int(31)
    # Damping coefficient (the trait keeps its historical name ``dump``).
    dump = Range(0.0, 0.5, 0.1)
    # Spring stiffness.
    k = Range(1.0, 100.0, 20.0)
    # Total rest length of the chain.
    length = Range(1.0, 3.0, 1.0)
    # Constant added to the y acceleration of the interior masses.
    g = Range(0.0, 0.1, 0.01)
    # Current simulation time.
    t = Float(0.0)
    x = Property()
    y = Property()

    def __init__(self, *args, **kw):
        super(Catenary, self).__init__(*args, **kw)

        # Masses start at rest, evenly spaced on y = 0 from x = 0 to x = 1.
        x0 = np.linspace(0, 1, self.N)
        y0 = np.zeros_like(x0)
        vx0 = np.zeros_like(x0)
        vy0 = np.zeros_like(x0)

        self.status0 = np.r_[x0, y0, vx0, vy0]
        # Expose the initial state so ``x`` / ``y`` are readable before
        # ode_init() is called (they used to raise AttributeError).
        self.status = self.status0

    def diff_status(self, t, status):
        """Right-hand side of the ODE: d(status)/dt for the given state."""
        x, y, vx, vy = status.reshape(4, -1)
        dvx = np.zeros_like(x)
        dvy = np.zeros_like(x)
        dx = vx
        dy = vy

        # Interior masses only; the end accelerations stay zero.
        s = np.s_[1:-1]

        l = self.length / (self.N - 1)
        k = self.k
        g = self.g
        dump = self.dump

        # Distances to the left and right neighbours and relative stretch.
        l1 = np.sqrt((x[s] - x[:-2])**2 + (y[s] - y[:-2])**2)
        l2 = np.sqrt((x[s] - x[2:])**2 + (y[s] - y[2:])**2)
        dl1 = (l1 - l) / l1
        dl2 = (l2 - l) / l2
        dvx[s] = -(x[s] - x[:-2]) * k * dl1 - (x[s] - x[2:]) * k * dl2
        dvy[s] = -(y[s] - y[:-2]) * k * dl1 - (y[s] - y[2:]) * k * dl2 + g
        dvx[s] -= vx[s] * dump
        dvy[s] -= vy[s] * dump
        return np.r_[dx, dy, dvx, dvy]

    def ode_init(self):
        """(Re)start integration from ``status0`` at t = 0 (vode, BDF)."""
        self.t = 0
        self.system = integrate.ode(self.diff_status)
        self.system.set_integrator("vode", method="bdf")
        self.system.set_initial_value(self.status0, 0)
        self.status = self.status0

    def ode_step(self, dt):
        """Integrate forward by ``dt`` and store the new time and state."""
        self.system.integrate(self.t + dt)
        self.t = self.system.t
        self.status = self.system.y

    def _get_x(self):
        # First N entries of the state are the x positions.
        return self.status[:self.N]

    def _get_y(self):
        # Next N entries are the y positions.
        return self.status[self.N:self.N * 2]
Beispiel #30
0
class ArcticDB(WritableFactorDB):
    """ArcticDB

    Factor database stored in Arctic.  Each table is an Arctic library
    (created by ``_writeNewData`` with ``lib_type=arctic.CHUNK_STORE``),
    each ID is a symbol inside that library, and per-factor information
    (``FactorName`` / ``DataType`` columns) is kept in the special
    ``_FactorInfo`` symbol.
    """
    # Database name (label: "database name").
    DBName = Str("arctic", arg_type="String", label="数据库名", order=0)
    # Server address (label: "IP address").
    IPAddr = Str("127.0.0.1", arg_type="String", label="IP地址", order=1)
    # Server port (label: "port"), 0-65535.
    Port = Range(low=0,
                 high=65535,
                 value=27017,
                 arg_type="Integer",
                 label="端口",
                 order=2)
    # User name (label: "user name").
    User = Str("", arg_type="String", label="用户名", order=3)
    # Password (label: "password").
    Pwd = Password("", arg_type="String", label="密码", order=4)

    def __init__(self, sys_args=None, config_file=None, **kwargs):
        """Create the database wrapper.

        Parameters
        ----------
        sys_args : dict, optional
            Forwarded to the base class ``__init__``.  Defaults to a fresh
            empty dict on every call (a shared ``{}`` default used to be
            reused across instances).
        config_file : str, optional
            Configuration file path.  When omitted,
            ``ArcticDBConfig.json`` under ``__QS_ConfigPath__`` is used.
        **kwargs
            Forwarded to the base class ``__init__``.
        """
        # Arctic handle; None until connect() creates one.
        self._Arctic = None
        if sys_args is None:
            sys_args = {}
        if config_file is None:
            config_file = __QS_ConfigPath__ + os.sep + "ArcticDBConfig.json"
        super().__init__(sys_args=sys_args, config_file=config_file, **kwargs)
        self.Name = "ArcticDB"

    def __getstate__(self):
        """Return a picklable copy of the instance dict.

        The live Arctic handle is replaced by a flag recording whether a
        connection was open, so that ``__setstate__`` can reconnect.
        """
        snapshot = dict(self.__dict__)
        snapshot["_Arctic"] = self.isAvailable()
        return snapshot

    def __setstate__(self, state):
        """Restore pickled state and reconnect if a connection was open.

        After the base restore, ``_Arctic`` holds the availability flag
        written by ``__getstate__``.
        """
        super().__setstate__(state)
        if not self._Arctic:
            self._Arctic = None
            return
        self.connect()

    def connect(self):
        """Create the Arctic handle for ``IPAddr`` and ``Port``; returns 0.

        The configured ``Port`` used to be ignored.  It is now appended as
        ``host:port`` unless ``IPAddr`` already contains a ``:``, which covers
        an explicit port, a ``mongodb://`` URI or a bare IPv6 literal.  Such
        values are passed through unchanged, as before.  With the default
        port 27017 the connection target is unchanged for plain hosts.
        """
        host = self.IPAddr
        if ":" not in host:
            host = "%s:%d" % (host, self.Port)
        # NOTE(review): User / Pwd are not passed here; Arctic appears to
        # resolve credentials through its own auth hooks -- confirm.
        self._Arctic = arctic.Arctic(host)
        return 0

    def disconnect(self):
        """Drop the Arctic handle and return 1.

        The handle is only dereferenced; no close/reset call is made here.
        """
        self._Arctic = None
        return 1

    def isAvailable(self):
        """Return True when an Arctic handle is currently held."""
        handle = self._Arctic
        return handle is not None

    @property
    def TableNames(self):
        """Sorted names of all libraries (tables) in the Arctic store."""
        libraries = self._Arctic.list_libraries()
        return sorted(libraries)

    def getTable(self, table_name, args=None):
        """Return a ``_FactorTable`` for ``table_name``.

        Parameters
        ----------
        table_name : str
            Name of an existing library.
        args : dict, optional
            Passed as ``sys_args`` to the table.  Defaults to a fresh empty
            dict per call instead of a shared mutable default.

        Raises
        ------
        __QS_Error__
            If the library does not exist ("table '%s' does not exist!").
        """
        if table_name not in self._Arctic.list_libraries():
            raise __QS_Error__("表 '%s' 不存在!" % table_name)
        if args is None:
            args = {}
        return _FactorTable(name=table_name, fdb=self, sys_args=args)

    def renameTable(self, old_table_name, new_table_name):
        """Rename an Arctic library; returns 0."""
        store = self._Arctic
        store.rename_library(old_table_name, new_table_name)
        return 0

    def deleteTable(self, table_name):
        """Delete an Arctic library; returns 0."""
        store = self._Arctic
        store.delete_library(table_name)
        return 0

    def setTableMetaData(self, table_name, key=None, value=None,
                         meta_data=None):
        """Update table-level metadata stored on the ``_FactorInfo`` symbol.

        ``meta_data`` (anything ``dict()`` accepts) is merged first, then
        ``key`` is set to ``value`` when ``key`` is given.  Returns 0.
        """
        library = self._Arctic[table_name]
        info = library.read_metadata("_FactorInfo")
        info = {} if info is None else info
        if meta_data is not None:
            info.update(dict(meta_data))
        if key is not None:
            info[key] = value
        library.write_metadata("_FactorInfo", info)
        return 0

    def renameFactor(self, table_name, old_factor_name, new_factor_name):
        """Rename a factor in ``_FactorInfo`` and in every ID's metadata.

        Raises ``__QS_Error__`` when the table or the old factor does not
        exist, or when the new name is already taken.  Returns 0.
        """
        if table_name not in self._Arctic.list_libraries():
            # "table '%s' does not exist!"
            raise __QS_Error__("表: '%s' 不存在!" % table_name)
        library = self._Arctic[table_name]
        info = library.read(symbol="_FactorInfo").set_index(["FactorName"])
        if old_factor_name not in info.index:
            # "factor '%s' does not exist!"
            raise __QS_Error__("因子: '%s' 不存在!" % old_factor_name)
        if new_factor_name in info.index:
            # "factor '%s' already exists!"
            raise __QS_Error__("因子: '%s' 已经存在!" % new_factor_name)
        renamed = info.index.tolist()
        renamed[renamed.index(old_factor_name)] = new_factor_name
        info.index = renamed
        info.index.name = "FactorName"
        library.write(
            "_FactorInfo",
            info.reset_index(),
            chunker=arctic.chunkstore.passthrough_chunker.PassthroughChunker())
        symbols = library.list_symbols()
        symbols.remove("_FactorInfo")
        for symbol in symbols:
            meta = library.read_metadata(symbol)
            # Edits meta["FactorNames"] in place before writing it back.
            names = meta["FactorNames"]
            if old_factor_name not in names:
                continue
            names[names.index(old_factor_name)] = new_factor_name
            library.write_metadata(symbol, meta)
        return 0

    def deleteFactor(self, table_name, factor_names):
        """Remove ``factor_names`` from every ID symbol and from ``_FactorInfo``.

        A symbol left with no factors is deleted.  If no factors remain in
        the table at all, the whole library is deleted.  Surviving columns
        are relabelled ``"0".."n-1"``.  Returns 0, or the result of
        ``deleteTable``.

        Previously the metadata ``FactorNames`` of each rewritten symbol was
        set to the stored column labels instead of the factor names.
        """
        if table_name not in self._Arctic.list_libraries():
            return 0
        Lib = self._Arctic[table_name]
        FactorInfo = Lib.read(symbol="_FactorInfo").set_index(["FactorName"])
        FactorInfo = FactorInfo.loc[FactorInfo.index.difference(factor_names)]
        if FactorInfo.shape[0] == 0:
            return self.deleteTable(table_name)
        IDs = Lib.list_symbols()
        IDs.remove("_FactorInfo")
        for iID in IDs:
            iMetaData = Lib.read_metadata(iID)
            # Factor name -> stored column label for this symbol.
            iFactorIndex = pd.Series(iMetaData["Cols"],
                                     index=iMetaData["FactorNames"])
            iFactorIndex = iFactorIndex.loc[iFactorIndex.index.difference(
                factor_names)]
            if iFactorIndex.shape[0] == 0:
                Lib.delete(iID)
                continue
            iKeepNames = iFactorIndex.index.tolist()
            iKeepCols = iFactorIndex.values.tolist()
            iData = Lib.read(symbol=iID, columns=iKeepCols)
            # Align the columns with iKeepNames before relabelling them.
            iData = iData.loc[:, iKeepCols]
            iCols = [str(i) for i in range(len(iKeepCols))]
            iData.columns = iCols
            iMetaData["FactorNames"], iMetaData["Cols"] = iKeepNames, iCols
            Lib.write(iID, iData, metadata=iMetaData)
        Lib.write(
            "_FactorInfo",
            FactorInfo.reset_index(),
            chunker=arctic.chunkstore.passthrough_chunker.PassthroughChunker())
        return 0

    def setFactorMetaData(self, table_name, ifactor_name, key=None,
                          value=None, meta_data=None):
        """Set per-factor fields in the ``_FactorInfo`` symbol.

        ``key``/``value`` is applied first, then every entry of
        ``meta_data``.  Nothing is read or written when both are None.
        Returns 0.
        """
        if key is None and meta_data is None:
            return 0
        library = self._Arctic[table_name]
        info = library.read(symbol="_FactorInfo").set_index(["FactorName"])
        if key is not None:
            info.loc[ifactor_name, key] = value
        if meta_data is not None:
            for field in meta_data:
                info.loc[ifactor_name, field] = meta_data[field]
        library.write(
            "_FactorInfo",
            info.reset_index(),
            chunker=arctic.chunkstore.passthrough_chunker.PassthroughChunker())
        return 0

    def writeData(self,
                  data,
                  table_name,
                  if_exists="update",
                  data_type=None,
                  **kwargs):
        """Write panel-shaped factor data into ``table_name``.

        ``data`` is sliced as (factor, date, ID): ``items`` are factor
        names, ``major_axis`` the dates and ``minor_axis`` the IDs (pandas
        ``Panel``-like; assumed from the attributes used below).  A missing
        library is created through ``_writeNewData``.

        Parameters
        ----------
        if_exists : str
            ``"update"`` lets incoming non-NaN values overwrite stored ones.
            Any other value only fills cells that are NaN in the stored data
            (``DataFrame.update(overwrite=...)``).
        data_type : dict, optional
            Factor name -> data type string for newly added factors.
            Otherwise the type is inferred: "string" if any column of that
            factor has object dtype, else "double".

        Returns 0.
        """
        if data_type is None:
            data_type = {}
        if data.shape[0] == 0: return 0
        if table_name not in self._Arctic.list_libraries():
            return self._writeNewData(data, table_name, data_type=data_type)
        Lib = self._Arctic[table_name]
        DataCols = [str(i) for i in range(data.shape[0])]
        DTRange = data.major_axis
        OverWrite = (if_exists == "update")
        for i, iID in enumerate(data.minor_axis):
            iData = data.iloc[:, :, i]
            if not Lib.has_symbol(iID):
                iMetaData = {
                    "FactorNames": iData.columns.tolist(),
                    "Cols": DataCols
                }
                iData.index.name, iData.columns = "date", DataCols
                Lib.write(iID, iData, metadata=iMetaData)
                continue
            iMetaData = Lib.read_metadata(symbol=iID)
            iOldFactorNames, iCols = iMetaData["FactorNames"], iMetaData[
                "Cols"]
            iNewFactorNames = iData.columns.difference(
                iOldFactorNames).tolist()
            iAllFactorNames = iOldFactorNames + iNewFactorNames
            iOldData = Lib.read(symbol=iID,
                                chunk_range=DTRange,
                                filter_data=True)
            # reindex (rather than .loc) so dates/factors missing on one side
            # become NaN instead of raising KeyError on newer pandas.
            if iOldData.shape[0] > 0:
                iOldData.columns = iOldFactorNames
                iOldData = iOldData.reindex(
                    index=iOldData.index.union(iData.index),
                    columns=iAllFactorNames)
                iOldData.update(iData, overwrite=OverWrite)
            else:
                iOldData = iData.reindex(columns=iAllFactorNames)
            if iNewFactorNames:
                # Give new factors labels above every existing one.  The old
                # code started at the already-expanded column count, which
                # could reuse an existing label on a later write.
                iNext = max((int(c) for c in iCols), default=-1) + 1
                iCols = iCols + [
                    str(iNext + j) for j in range(len(iNewFactorNames))
                ]
            iOldData.index.name, iOldData.columns = "date", iCols
            iMetaData["FactorNames"], iMetaData[
                "Cols"] = iAllFactorNames, iCols
            Lib.update(iID, iOldData, metadata=iMetaData, chunk_range=DTRange)
        FactorInfo = Lib.read(symbol="_FactorInfo").set_index("FactorName")
        NewFactorNames = data.items.difference(FactorInfo.index).tolist()
        FactorInfo = FactorInfo.reindex(FactorInfo.index.tolist() +
                                        NewFactorNames)
        FactorInfo.index.name = "FactorName"
        for iFactorName in NewFactorNames:
            if iFactorName in data_type:
                FactorInfo.loc[iFactorName,
                               "DataType"] = data_type[iFactorName]
            # Check the dtype *values*: ``x in dtypes`` would test the
            # Series index (column labels) and never detect object columns.
            elif any(iDType == np.dtype('O')
                     for iDType in data.loc[iFactorName].dtypes):
                FactorInfo.loc[iFactorName, "DataType"] = "string"
            else:
                FactorInfo.loc[iFactorName, "DataType"] = "double"
        Lib.write(
            "_FactorInfo",
            FactorInfo.reset_index(),
            chunker=arctic.chunkstore.passthrough_chunker.PassthroughChunker())
        return 0

    def _writeNewData(self, data, table_name, data_type):
        """Create ``table_name`` as a chunk-store library and write ``data``.

        Every ID in ``data.minor_axis`` becomes a symbol whose columns are
        labelled ``"0".."n-1"``.  Its metadata maps those labels back to the
        factor names.  Factor data types go to the ``_FactorInfo`` symbol.
        They come from ``data_type`` when given; otherwise "string" if any
        column of that factor has object dtype, else "double".

        ``data.items`` is temporarily relabelled for writing and is restored
        afterwards, even if a write raises.  Returns 0.
        """
        FactorNames = data.items.tolist()
        DataType = pd.Series("double", index=data.items)
        for i, iFactorName in enumerate(DataType.index):
            if iFactorName in data_type:
                DataType.iloc[i] = data_type[iFactorName]
            # Test dtype values; ``x in dtypes`` would test the Series index.
            elif any(iDType == np.dtype('O')
                     for iDType in data.iloc[i].dtypes):
                DataType.iloc[i] = "string"
        DataCols = [str(i) for i in range(data.shape[0])]
        self._Arctic.initialize_library(table_name,
                                        lib_type=arctic.CHUNK_STORE)
        Lib = self._Arctic[table_name]
        data.items = DataCols
        try:
            for i, iID in enumerate(data.minor_axis):
                iData = data.iloc[:, :, i]
                iMetaData = {"FactorNames": FactorNames, "Cols": DataCols}
                iData.index.name = "date"
                Lib.write(iID, iData, metadata=iMetaData)
        finally:
            # Never leave the caller's panel relabelled.
            data.items = FactorNames
        DataType = DataType.reset_index()
        DataType.columns = ["FactorName", "DataType"]
        Lib.write(
            "_FactorInfo",
            DataType,
            chunker=arctic.chunkstore.passthrough_chunker.PassthroughChunker())
        return 0