# Example #1 (0)
class CyclesPlot(HasTraits):
    """ Simple plotting class with some attached controls.

    The top plot traces "Metric 2" (index axis) against "Metric 1" (value
    axis) of a CSV time series.  The bottom timeline plots both metrics
    against the date.  Selecting a date range on the timeline highlights
    the matching part of the cycle plot and moves the endpoint marker to
    the last sample inside the selection.
    """
    plot = Instance(GridContainer)
    traits_view = View(Item('plot', editor=ComponentEditor(),
                            show_label=False),
                       width=800,
                       height=600,
                       resizable=True,
                       title="Business Cycles Plot")

    # Private Traits
    _file_path = Str
    _dates = Array
    _series1 = Array
    _series2 = Array
    _selected_s1 = Array
    _selected_s2 = Array

    def __init__(self, file_path="./biz_cycles2.csv"):
        """ Read the data and assemble the plot container.

        Parameters
        ----------
        file_path : str
            CSV file to read.  Column 0 holds dates (``%Y-%m-%d``) and is
            exposed as "Date".  The "Metric 1" and "Metric 2" columns are
            plotted.  The default is the file this one-off plot originally
            hard-wired.
        """
        super(CyclesPlot, self).__init__()

        self._file_path = file_path
        srecs = read_time_series_from_csv(self._file_path,
                                          date_col=0,
                                          date_format="%Y-%m-%d")

        dt = srecs["Date"]

        # Industrial production compared with trend (plotted on value axis)
        iprod_vs_trend = srecs["Metric 1"]

        # Industrial production change in last 6 Months (plotted on index axis)
        iprod_delta = srecs["Metric 2"]

        self._dates = dt
        # Initially the whole series counts as "selected".
        self._series1 = self._selected_s1 = iprod_delta
        self._series2 = self._selected_s2 = iprod_vs_trend

        end_x = np.array([self._selected_s1[-1]])
        end_y = np.array([self._selected_s2[-1]])

        plotdata = ArrayPlotData(x=self._series1,
                                 y=self._series2,
                                 dates=self._dates,
                                 selected_x=self._selected_s1,
                                 selected_y=self._selected_s2,
                                 endpoint_x=end_x,
                                 endpoint_y=end_y)

        cycles = Plot(plotdata, padding=20)

        # Faint full trajectory, bold selected segment, and endpoint marker.
        cycles.plot(("x", "y"), type="line", color=(.2, .4, .5, .4))

        cycles.plot(("selected_x", "selected_y"),
                    type="line",
                    marker="circle",
                    line_width=3,
                    color=(.2, .4, .5, .9))

        cycles.plot(("endpoint_x", "endpoint_y"),
                    type="scatter",
                    marker_size=4,
                    marker="circle",
                    color=(.2, .4, .5, .2),
                    outline_color=(.2, .4, .5, .6))

        cycles.index_range = DataRange1D(low_setting=80., high_setting=120.)

        cycles.value_range = DataRange1D(low_setting=80., high_setting=120.)

        # dig down to use actual Plot object
        cyc_plot = cycles.components[0]

        # Add the labels in the quadrants
        cyc_plot.overlays.append(
            PlotLabel("\nSlowdown" + 40 * " " + "Expansion",
                      component=cyc_plot,
                      font="swiss 24",
                      color=(.2, .4, .5, .6),
                      overlay_position="inside top"))

        cyc_plot.overlays.append(
            PlotLabel("Downturn" + 40 * " " + "Recovery\n ",
                      component=cyc_plot,
                      font="swiss 24",
                      color=(.2, .4, .5, .6),
                      overlay_position="inside bottom"))

        timeline = Plot(plotdata, resizable='h', height=50, padding=20)
        timeline.plot(("dates", "x"),
                      type="line",
                      color=(.2, .4, .5, .8),
                      name='x')
        timeline.plot(("dates", "y"),
                      type="line",
                      color=(.5, .4, .2, .8),
                      name='y')

        # Snap on the tools
        zoomer = ZoomTool(timeline,
                          drag_button="right",
                          always_on=True,
                          tool_mode="range",
                          axis="index",
                          max_zoom_out_factor=1.1)

        panner = PanTool(timeline, constrain=True, constrain_direction="x")

        # dig down to get Plot component I want
        x_plt = timeline.plots['x'][0]

        # Dragging on the timeline selects a date interval; update_interval
        # re-filters the highlighted cycle segment whenever it changes.
        range_selection = RangeSelection(x_plt, left_button_selects=True)
        range_selection.on_trait_change(self.update_interval, 'selection')

        x_plt.tools.append(range_selection)
        x_plt.overlays.append(RangeSelectionOverlay(x_plt))

        # Set the plot's bottom axis to use the Scales ticking system
        scale_sys = CalendarScaleSystem(
            fill_ratio=0.4,
            default_numlabels=5,
            default_numticks=10,
        )
        tick_gen = ScalesTickGenerator(scale=scale_sys)

        bottom_axis = ScalesPlotAxis(timeline,
                                     orientation="bottom",
                                     tick_generator=tick_gen)

        # Hack to remove default axis - FIXME: how do I *replace* an axis?
        # (relies on the default bottom axis being the second-to-last underlay)
        del timeline.underlays[-2]

        timeline.overlays.append(bottom_axis)

        container = GridContainer(padding=20,
                                  fill_padding=True,
                                  bgcolor="lightgray",
                                  use_backbuffer=True,
                                  shape=(2, 1),
                                  spacing=(30, 30))

        # add a central "x" and "y" axis

        x_line = LineInspector(cyc_plot,
                               is_listener=True,
                               color="gray",
                               width=2)
        y_line = LineInspector(cyc_plot,
                               is_listener=True,
                               color="gray",
                               width=2,
                               axis="value")

        cyc_plot.overlays.append(x_line)
        cyc_plot.overlays.append(y_line)

        # Pin the inspector lines at 100 on both axes.
        cyc_plot.index.metadata["selections"] = 100.0
        cyc_plot.value.metadata["selections"] = 100.0

        container.add(cycles)
        container.add(timeline)

        container.title = "Business Cycles"

        self.plot = container

    def update_interval(self, value):
        """ Restrict the highlighted cycle segment to the selected dates.

        Registered as a listener on the timeline RangeSelection's
        ``selection`` trait.  ``value`` is not used.  The selection is read
        back from the timeline index data source's "selections" metadata.
        """
        # Reaching pretty deep here to get selections
        timeline = self.plot.plot_components[1]
        sels = timeline.plots['x'][0].index.metadata['selections']

        if sels is not None:
            msk = (self._dates >= sels[0]) & (self._dates <= sels[1])

            self._selected_s1 = self._series1[msk]
            self._selected_s2 = self._series2[msk]

            # Index of the last sample inside the selection.  An interval
            # falling between two samples selects nothing.  In that case,
            # show no endpoint instead of the last sample of the whole
            # series.
            hits = np.flatnonzero(msk)
            if hits.size:
                last_idx = hits[-1]
                endpoint_x = np.array([self._series1[last_idx]])
                endpoint_y = np.array([self._series2[last_idx]])
            else:
                endpoint_x = np.array([])
                endpoint_y = np.array([])

        else:
            # No selection: highlight everything, endpoint at the last sample.
            self._selected_s1 = self._series1
            self._selected_s2 = self._series2
            endpoint_x = np.array([self._series1[-1]])
            endpoint_y = np.array([self._series2[-1]])

        data = self.plot.plot_components[0].data
        data['selected_x'] = self._selected_s1
        data['selected_y'] = self._selected_s2
        data['endpoint_x'] = endpoint_x
        data['endpoint_y'] = endpoint_y
# Example #2 (0)
class InPaintDemo(HasTraits):
    """Interactive OpenCV inpainting demo.

    Paint over the image with a circular brush to build a selection mask.
    The masked area is filled in with ``cv.inpaint`` and previewed live.
    "Apply" keeps the result and clears the selection.
    """
    plot = Instance(Plot)
    painter = Instance(CirclePainter)
    r = Range(2.0, 20.0, 10.0)  # radius passed to cv.inpaint
    method = Enum("INPAINT_NS", "INPAINT_TELEA")  # name of the cv inpaint algorithm constant
    show_mask = Bool(False)  # overlay the selection mask on the image
    clear_mask = Button("清除选区")
    apply = Button("保存结果")

    view = View(
        VGroup(
            VGroup(
                Item("object.painter.r", label="画笔半径"),
                Item("r", label="inpaint半径"),
                HGroup(
                    Item("method", label="inpaint算法"),
                    Item("show_mask", label="显示选区"),
                    Item("clear_mask", show_label=False),
                    Item("apply", show_label=False),
                ),
            ),
            Item("plot", editor=ComponentEditor(), show_label=False),
        ),
        title="inpaint Demo控制面板",
        width=500,
        height=450,
        resizable=True,
    )

    def __init__(self, *args, **kwargs):
        super(InPaintDemo, self).__init__(*args, **kwargs)
        original = cv.imread("stuff.jpg")
        self.img = original  # source image (updated by "apply")
        self.img2 = original.clone()  # inpainting preview
        img_size = original.size()
        self.mask = cv.Mat(img_size, cv.CV_8UC1)  # single-channel selection mask
        self.mask[:] = 0
        # Channels are reversed (BGR -> RGB) for display.
        self.data = ArrayPlotData(img=original[:, :, ::-1])
        aspect = float(img_size.width) / img_size.height
        self.plot = Plot(self.data, padding=10, aspect_ratio=aspect)
        self.plot.x_axis.visible = False
        self.plot.y_axis.visible = False
        image_renderer = self.plot.img_plot("img", origin="top left")[0]
        self.painter = CirclePainter(component=image_renderer)
        image_renderer.overlays.append(self.painter)

    @on_trait_change("r,method")
    def inpaint(self):
        """Recompute the preview from the source image and current mask."""
        algorithm = getattr(cv, self.method)
        cv.inpaint(self.img, self.mask, self.img2, self.r, algorithm)
        self.draw()

    @on_trait_change("painter:updated")
    def painter_updated(self):
        """Burn the brush strokes into the mask, then re-inpaint."""
        for _ignored_a, _ignored_b, px, py in self.painter.track:
            # Negative thickness draws a filled circle.
            cv.circle(self.mask,
                      cv.Point(int(px), int(py)),
                      int(self.painter.r),
                      cv.Scalar(255, 255, 255, 255),
                      thickness=-1)
        self.inpaint()
        self.painter.track = []
        self.painter.request_redraw()

    def _clear_mask_fired(self):
        self.mask[:] = 0
        self.inpaint()

    def _apply_fired(self):
        """Keep the inpainting result as the new source and clear the mask."""
        self.img[:] = self.img2[:]
        self._clear_mask_fired()

    @on_trait_change("show_mask")
    def draw(self):
        """Show either the preview or the source with the mask painted white."""
        if not self.show_mask:
            self.data["img"] = self.img2[:, :, ::-1]
            return
        overlay = self.img[:, :, ::-1].copy()
        overlay[self.mask[:] > 0] = 255
        self.data["img"] = overlay
# Example #3 (0)
    def call_mlab(self,
                  scene=None,
                  show=True,
                  is_3d=False,
                  view=None,
                  roll=None,
                  fgcolor=(0.0, 0.0, 0.0),
                  bgcolor=(1.0, 1.0, 1.0),
                  layout='rowcol',
                  scalar_mode='iso_surface',
                  vector_mode='arrows_norm',
                  rel_scaling=None,
                  clamping=False,
                  ranges=None,
                  is_scalar_bar=False,
                  is_wireframe=False,
                  opacity=None,
                  subdomains_args=None,
                  rel_text_width=None,
                  fig_filename='view.png',
                  resolution=None,
                  filter_names=None,
                  only_names=None,
                  group_names=None,
                  step=None,
                  time=None,
                  anti_aliasing=None,
                  domain_specific=None):
        """
        Set up the Mayavi scene (or viewer GUI) and plot the data of
        `self.filename`.

        By default, all data (point, cell, scalars, vectors, tensors)
        are plotted in a grid layout, except data named 'node_groups',
        'mat_id' which are usually not interesting.

        Parameters
        ----------
        scene : Mayavi scene, optional
            Scene to plot into. If given, no GUI is created. If None, the
            current scene is reused while it is still running. Otherwise a
            new one is created: an `mlab.figure()` when offscreen, else a
            `ViewerGUI` window.
        show : bool
            Call mlab.show(). In GUI mode, True opens the window with
            `configure_traits()`, False with `edit_traits()`.
        is_3d : bool
            If True, use scalar cut planes instead of surface for certain
            datasets. Also sets 3D view mode.
        view : tuple
            Azimuth, elevation angles, distance and focal point as in
            `mlab.view()`.
        roll : float
            Roll angle tuple as in mlab.roll().
        fgcolor : tuple of floats (R, G, B)
            The foreground color, that is the color of all text
            annotation labels (axes, orientation axes, scalar bar
            labels).
        bgcolor : tuple of floats (R, G, B)
            The background color.
        layout : str
            Grid layout for placing the datasets. Possible values are:
            'row', 'col', 'rowcol', 'colrow'.
        scalar_mode : str or sequence of str
             Mode for plotting scalars and tensor magnitudes, one of
             'cut_plane', 'iso_surface', 'both'. A sequence of
             'cut_plane' / 'iso_surface' is accepted too.
        vector_mode : str or sequence of str
             Mode for plotting vectors, one of 'arrows', 'norm', 'arrows_norm',
             'warp_norm', or 'cut_plane' (which falls back to 'arrows' unless
             `is_3d` is True). A sequence of the primitive modes 'arrows',
             'norm', 'warp' is accepted too.
        rel_scaling : float
            Relative scaling of glyphs for vector datasets.
        clamping : bool
            Clamping for vector datasets.
        ranges : dict
            Data ranges in the form {name : (min, max), ...}.
        is_scalar_bar : bool
            If True, show a scalar bar for each data.
        is_wireframe : bool
            If True, show a wireframe of the mesh surface for each data.
        opacity : float
            Global surface and wireframe opacity setting in [0.0, 1.0].
        subdomains_args : tuple
            Tuple of (mat_id_name, threshold_limits, single_color), see
            :func:`add_subdomains_surface`, or None.
        rel_text_width : float
            Relative text width. Defaults to 0.02 when None.
        fig_filename : str
            File name for saving the resulting scene figure.
        resolution : tuple
            Scene and figure resolution. If None, it is set
            automatically according to the layout.
        filter_names : list of strings
            Omit the listed datasets. If None, it is initialized to
            ['node_groups', 'mat_id']. Pass [] if you need no filtering.
        only_names : list of strings
            Draw only the listed datasets. If None, it is initialized to all
            names besides those in filter_names.
        group_names : list of tuples
            List of data names in the form [(name1, ..., nameN), (...)]. Plots
            of data named in each group are superimposed. Repetitions of names
            are possible.
        step : int, optional
            If not None, the time step to display. Negative values count
            back from the last step (-1 is the last one). The closest higher
            step is used if the desired one is not available. Has precedence
            over `time`.
        time : float, optional
            If not None, the time of the time step to display. The closest
            higher time is used if the desired one is not available.
        anti_aliasing : int
            Value of anti-aliasing.
        domain_specific : dict
            Domain-specific drawing functions and configurations.

        Returns
        -------
        gui : ViewerGUI or None
            The GUI driving the scene, or None if no GUI is used (offscreen
            or an explicitly passed `scene`).

        Raises
        ------
        ValueError
            If `scalar_mode` or `vector_mode` has an invalid value.
        """
        self.fgcolor = fgcolor
        self.bgcolor = bgcolor

        # Fill in defaults for options passed as None.
        if filter_names is None:
            filter_names = ['node_groups', 'mat_id']

        if rel_text_width is None:
            rel_text_width = 0.02

        # Normalize scalar_mode to a tuple of primitive modes, validating it.
        if isinstance(scalar_mode, basestr):
            if scalar_mode == 'both':
                scalar_mode = ('cut_plane', 'iso_surface')
            elif scalar_mode in ('cut_plane', 'iso_surface'):
                scalar_mode = (scalar_mode, )
            else:
                raise ValueError('bad value of scalar_mode parameter! (%s)' %
                                 scalar_mode)
        else:
            for sm in scalar_mode:
                if not sm in ('cut_plane', 'iso_surface'):
                    raise ValueError(
                        'bad value of scalar_mode parameter! (%s)' % sm)

        # Normalize vector_mode to a tuple of primitive modes, validating it.
        # NOTE(review): the string 'cut_plane' is accepted here, but a
        # sequence containing 'cut_plane' is rejected by the check in the
        # else branch. Confirm whether that asymmetry is intended.
        if isinstance(vector_mode, basestr):
            if vector_mode == 'arrows_norm':
                vector_mode = ('arrows', 'norm')
            elif vector_mode == 'warp_norm':
                vector_mode = ('warp', 'norm')
            elif vector_mode in ('arrows', 'norm'):
                vector_mode = (vector_mode, )
            elif vector_mode == 'cut_plane':
                if is_3d:
                    vector_mode = ('cut_plane', )
                else:
                    vector_mode = ('arrows', )
            else:
                raise ValueError('bad value of vector_mode parameter! (%s)' %
                                 vector_mode)
        else:
            for vm in vector_mode:
                if not vm in ('arrows', 'norm', 'warp'):
                    raise ValueError(
                        'bad value of vector_mode parameter! (%s)' % vm)

        # Apply the viewer's offscreen setting before any figure is created.
        mlab.options.offscreen = self.offscreen

        self.size_hint = self.get_size_hint(layout, resolution=resolution)

        # Set when this call switches to a different scene. The GUI window is
        # only opened (at the end of this method) for a new scene.
        is_new_scene = False

        if scene is not None:
            # Caller-supplied scene: plot into it without a GUI.
            if scene is not self.scene:
                is_new_scene = True
                self.scene = scene
            gui = None

        else:
            # Forget a previous scene that is no longer running.
            if (self.scene is not None) and (not self.scene.running):
                self.scene = None

            if self.scene is None:
                # Create a new scene: a plain figure offscreen, a GUI otherwise.
                if self.offscreen:
                    gui = None
                    scene = mlab.figure(fgcolor=fgcolor,
                                        bgcolor=bgcolor,
                                        size=self.size_hint)

                else:
                    gui = ViewerGUI(viewer=self,
                                    fgcolor=fgcolor,
                                    bgcolor=bgcolor)
                    scene = gui.scene.mayavi_scene

                if scene is not self.scene:
                    is_new_scene = True
                    self.scene = scene

            else:
                # Reuse the still-running scene and its GUI.
                gui = self.gui
                scene = self.scene

        self.engine = mlab.get_engine()
        self.engine.current_scene = self.scene

        self.gui = gui

        # Open the data file. A non-empty `times` or `steps` means several
        # time steps are available.
        self.file_source = create_file_source(self.filename,
                                              watch=self.watch,
                                              offscreen=self.offscreen)
        steps, times = self.file_source.get_ts_info()
        has_several_times = len(times) > 0
        has_several_steps = has_several_times or (len(steps) > 0)

        if gui is not None:
            gui.has_several_steps = has_several_steps

        # Helper object for reloading. It keeps back-references to this
        # viewer and the file source.
        self.reload_source = reload_source = ReloadSource()
        reload_source._viewer = self
        reload_source._source = self.file_source

        if has_several_steps:
            self.set_step = set_step = SetStep()
            set_step._viewer = self
            set_step._source = self.file_source
            # Initial step selection: `step` wins over `time`, otherwise the
            # first step. A negative `step` counts back from the last step.
            if step is not None:
                step = step if step >= 0 else steps[-1] + step + 1
                assert_(steps[0] <= step <= steps[-1],
                        msg='invalid time step! (%d <= %d <= %d)' %
                        (steps[0], step, steps[-1]))
                set_step.step = step

            elif time is not None:
                assert_(times[0] <= time <= times[-1],
                        msg='invalid time! (%e <= %e <= %e)' %
                        (times[0], time, times[-1]))
                set_step.time = time

            else:
                set_step.step = steps[0]

            # When watching the file, register set_step as the notification
            # target with the file source.
            if self.watch:
                self.file_source.setup_notification(set_step, 'file_changed')

            if gui is not None:
                gui.set_step = set_step

        else:
            # Single-step data: notifications go to the reload helper instead.
            if self.watch:
                self.file_source.setup_notification(reload_source,
                                                    'reload_source')

        # NOTE(review): get_arguments() is defined elsewhere. Its name and
        # `omit` argument suggest it collects this call's argument values as
        # they are *now* (after the normalizations above). Confirm. If so,
        # renaming or reassigning the parameters above changes what is stored
        # in self.options.
        self.options.update(get_arguments(omit=['self', 'file_source']))

        if gui is None:
            # No GUI: render directly into the scene.
            self.render_scene(scene, self.options)
            self.reset_view()
            if is_scalar_bar:
                self.show_scalar_bars(self.scalar_bars)

        else:
            # Window layout: scene editor, time-step controls (only when
            # set_step exists), snapshot/animation buttons (enabled only for
            # multi-step data), and reload/view/quit buttons.
            traits_view = View(
                Item(
                    'scene',
                    editor=SceneEditor(scene_class=MayaviScene),
                    show_label=False,
                    width=self.size_hint[0],
                    height=self.size_hint[1],
                    style='custom',
                ),
                Group(
                    Item('set_step',
                         defined_when='set_step is not None',
                         show_label=False,
                         style='custom'), ),
                HGroup(
                    spring,
                    Item('button_make_snapshots_steps',
                         show_label=False,
                         enabled_when='has_several_steps == True'),
                    Item('button_make_animation_steps',
                         show_label=False,
                         enabled_when='has_several_steps == True'),
                    spring,
                    Item('button_make_snapshots_times',
                         show_label=False,
                         enabled_when='has_several_steps == True'),
                    Item('button_make_animation_times',
                         show_label=False,
                         enabled_when='has_several_steps == True'),
                    spring,
                ),
                HGroup(spring, Item('button_reload', show_label=False),
                       Item('button_view', show_label=False),
                       Item('button_quit', show_label=False)),
                resizable=True,
                buttons=[],
                handler=ClosingHandler(),
            )

            # Only open a window for a newly created/attached scene.
            if is_new_scene:
                if show:
                    gui.configure_traits(view=traits_view)

                else:
                    gui.edit_traits(view=traits_view)

        return gui
# Example #4 (0)
# -*- coding: utf-8 -*-
"""
演示如何自定义控制器类,和各种监听函数
"""
from enthought.traits.api import HasTraits, Str, Int
from enthought.traits.ui.api import View, Item, Group, Handler
from enthought.traits.ui.menu import ModalButtons

g1 = [Item('department', label=u"部门"), Item('name', label=u"姓名")]
g2 = [Item('salary', label=u"工资"), Item('bonus', label=u"奖金")]


class Employee(HasTraits):
    name = Str
    department = Str
    salary = Int
    bonus = Int

    def _department_changed(self):
        print self, "department changed to ", self.department

    def __str__(self):
        return "<Employee at 0x%x>" % id(self)


# Modal view showing the two item groups as bordered sections.
view1 = View(
    Group(*g1, label=u'个人信息', show_border=True),
    Group(*g2, label=u'收入', show_border=True),
    title=u"外部视图",
    kind="modal",
    buttons=ModalButtons,
)
# Example #5 (0)
# File: fit.py  Project: pmrup/labtools
from labtools.analysis.plot import Plot

from labtools.utils.data_viewer import StructArrayData

import numpy as np

from labtools.log import create_logger
log = create_logger(__name__)


class DlsError(Exception):
    """Error raised by the DLS analysis code (presumably for invalid DLS data)."""


# UI group: file list, constants, and the two processing buttons.
dls_analyzer_group = Group(
    Item('filenames', style='custom', show_label=False),
    'constants',
    Item('process_selected_btn', show_label=False),
    Item('process_all_btn', show_label=False),
)


class DlsFitter(DataFitter):
    """:class:`DataFitter` specialised for DLS correlation data.

    In addition to :class:`DataFitter`, it is documented as defining
    :meth:`open_dls` to open DLS data.  NOTE(review): that method is not
    part of the class body shown here; confirm where it is defined.
    """
    def _plotter_default(self):
        # g2-1 against lag time, with a logarithmic lag-time axis.
        plot_options = dict(
            xlabel='Lag time [ms]',
            ylabel='g2-1',
            xscale='log',
            title='g2 -1',
        )
        return Plot(**plot_options)
# Example #6 (0)
             label='name',
             view=no_view),
    TreeNode(node_for=[ControlPoint],
             auto_open=False,
             children='',
             label='label',
             view=no_view),
],
                         on_select=on_tree_select)

# The main view
view = View(
    Group(
        Item(
            name='patientlist',
            id='patientlist',
            editor=tree_editor,
            resizable=True,
        ),
        orientation='vertical',
        show_labels=True,
        show_left=False,
    ),
    title='Patients',
    id='dicomutils.viewer.tree',
    dock='horizontal',
    drop_class=HasTraits,
    handler=TreeHandler(),
    buttons=['Undo', 'OK', 'Cancel'],
    resizable=True,
    width=.3,
    height=.3,
)
# Example #7 (0)
class Case(HasTraits):
    '''
    A class representing an avl input file.

    The traits hold the case header (title, Mach number, symmetry flags,
    reference area/chord/span, reference point, optional profile drag)
    and the :class:`Geometry` that follows it in the file.
    '''
    name = Str()
    mach_no = Float()
    symmetry = List(minlen=3, maxlen=3)  # [iYsym, iZsym, Zsym]
    ref_area = Float()  # Sref
    ref_chord = Float()  # Cref
    ref_span = Float()  # Bref
    # Reference point (Xref, Yref, Zref).  ``numpy.float`` was only an alias
    # of the builtin ``float`` and was removed in NumPy 1.24, so use ``float``
    # directly (same float64 dtype).
    ref_cg = Array(float, (3, ))
    CD_p = Float  # profile drag coefficient; 0.0 means "not in the file"
    geometry = Instance(Geometry)
    cwd = Directory
    case_filename = File('')

    traits_view = View(Item('name'), Item('mach_no'), Item('symmetry'),
                       Item('ref_area'), Item('ref_chord'), Item('ref_span'),
                       Item('ref_cg', editor=ArrayEditor()), Item('CD_p'))

    # Delegated to the ``controls`` trait of ``geometry``.
    controls = DelegatesTo('geometry')

    def write_input_file(self, file):
        '''
        Write all the data in the case in the appropriate format as an input
        .avl file for the AVL program.

        The header (title, Mach number, symmetry, reference quantities and,
        if non-zero, CD_p) is followed by a separator line and the geometry
        section written by ``Geometry.write_to_file``.

        Parameters
        ----------
        file : file-like
            Text stream opened for writing.
        '''
        file.write(self.name + '\n')
        file.write('#Mach no\n%f\n' % self.mach_no)
        file.write('#iYsym    iZsym    Zsym\n%s    %s    %s\n' %
                   tuple(self.symmetry))
        file.write('#Sref    Cref    Bref\n%f    %f    %f\n' %
                   (self.ref_area, self.ref_chord, self.ref_span))
        file.write('#Xref    Yref    Zref\n%f    %f    %f\n' %
                   tuple(self.ref_cg))
        # The CD_p line is optional (see case_from_input_file); 0.0 omits it.
        if self.CD_p != 0.0:
            file.write('#CD_p profile drag coefficient\n%f\n' % self.CD_p)
        file.write('\n')
        file.write('#' * 70)
        file.write('\n')
        self.geometry.write_to_file(file)

    @classmethod
    def case_from_input_file(cls, file, cwd=''):
        '''
        Return an instance of ``cls`` by reading its data from an input file.

        Parameters
        ----------
        file : file-like
            Open AVL input file; its lines are passed through
            ``filter_lines`` before parsing.
        cwd : str
            Working directory forwarded to ``Geometry.create_from_lines``
            and stored on the case.
        '''
        lines = filter_lines(file.readlines())
        name = lines[0]
        mach_no = float(lines[1].split()[0])
        sym_fields = lines[2].split()
        symmetry = [int(sym_fields[0]), int(sym_fields[1]),
                    float(sym_fields[2])]
        ref_area, ref_chord, ref_span = [
            float(value) for value in lines[3].split()[:3]
        ]
        ref_cg = [float(value) for value in lines[4].split()[:3]]
        # Line 5 is either the optional CD_p value or already the start of
        # the geometry section.  A non-numeric, empty or missing line means
        # "no CD_p".
        try:
            CD_p = float(lines[5].split()[0])
        except (ValueError, IndexError):
            CD_p = 0.0
            lineno = 5
        else:
            lineno = 6
        geometry = Geometry.create_from_lines(lines, lineno, cwd=cwd)
        # Use cls so subclasses get instances of themselves.
        return cls(name=name,
                   mach_no=mach_no,
                   symmetry=symmetry,
                   ref_area=ref_area,
                   ref_chord=ref_chord,
                   ref_span=ref_span,
                   ref_cg=ref_cg,
                   CD_p=CD_p,
                   geometry=geometry,
                   cwd=cwd)
class GenerateProjectorCalibration(HasTraits):
    """Draw a projector viewport polygon over a test pattern.

    A test image (red horizontal and blue vertical stripes, plus the filled
    polygon) is shown in a Chaco plot and pushed to the display through
    ``blit_compressed_image_proxy``.  Vertices are clicked with a
    LineSegmentTool.  Once the polygon has at least three vertices,
    ``update_ROS_params`` (defined elsewhere) is called.

    NOTE(review): ``width`` and ``height`` are not declared as traits (their
    declarations were commented out) but are read as ``self.width`` /
    ``self.height``.  Presumably they are passed as keyword arguments;
    confirm.
    """
    display_id = traits.String
    plot = Instance(Component)
    linedraw = Instance(LineSegmentTool)
    viewport_id = traits.String('viewport_0')
    display_mode = traits.Trait('white on black', 'black on white')
    client = traits.Any
    blit_compressed_image_proxy = traits.Any

    set_display_server_mode_proxy = traits.Any

    traits_view = View(
        Group(Item('display_mode'),
              Item('viewport_id'),
              Item('plot', editor=ComponentEditor(), show_label=False),
              orientation="vertical"),
        resizable=True,
    )

    def __init__(self, *args, **kwargs):
        display_coords_filename = kwargs.pop('display_coords_filename')
        super(GenerateProjectorCalibration, self).__init__(*args, **kwargs)

        # The pickled display coordinates are loaded (so a missing or corrupt
        # file still fails here) but not otherwise used.
        # NOTE(review): pickle.load runs arbitrary code; only load trusted
        # files.
        with open(display_coords_filename, mode='rb') as fd:
            pickle.load(fd)

        self.param_name = 'virtual_display_config_json_string'
        self.fqdn = '/virtual_displays/' + self.display_id + '/' + self.viewport_id
        self.fqpn = self.fqdn + '/' + self.param_name
        self.client = dynamic_reconfigure.client.Client(self.fqdn)

        self._update_image()

        # Restore the viewport polygon stored in the ROS parameter, provided
        # every vertex lies inside the image.
        virtual_display_json_str = rospy.get_param(self.fqpn)
        this_virtual_display = json.loads(virtual_display_json_str)
        viewport = this_virtual_display['viewport']

        out_of_bounds = any((x >= self.width) or (y >= self.height)
                            for (x, y) in viewport)
        if not out_of_bounds:
            self.linedraw.points = viewport
        self._update_image()

    def _update_image(self):
        """Redraw the test image and push it to the plot and the display."""
        self._image = np.zeros((self.height, self.width, 3), dtype=np.uint8)
        # draw polygon
        if len(self.linedraw.points) >= 3:
            # mahotas takes (row, col) vertices; posint presumably clamps each
            # coordinate into the image.
            pts = [(posint(y, self.height - 1), posint(x, self.width - 1))
                   for (x, y) in self.linedraw.points]
            mahotas.polygon.fill_polygon(pts, self._image[:, :, 0])
            # Scale the filled mask to 255 and copy it to the G and B planes.
            self._image[:, :, 0] *= 255
            self._image[:, :, 1] = self._image[:, :, 0]
            self._image[:, :, 2] = self._image[:, :, 0]

        # draw red horizontal stripes
        for i in range(0, self.height, 100):
            self._image[i:i + 10, :, 0] = 255

        # draw blue vertical stripes
        for i in range(0, self.width, 100):
            self._image[:, i:i + 10, 2] = 255

        # _pd only exists once the plot has been created (_plot_default).
        if hasattr(self, '_pd'):
            self._pd.set_data("imagedata", self._image)
        self.send_array()
        if len(self.linedraw.points) >= 3:
            self.update_ROS_params()

    def _plot_default(self):
        self._pd = ArrayPlotData()
        self._pd.set_data("imagedata", self._image)

        plot = Plot(self._pd, default_origin="top left")
        plot.x_axis.orientation = "top"
        plot.img_plot("imagedata")

        plot.bgcolor = "white"

        # Tweak some of the plot properties
        plot.title = "Click to add points, press Enter to clear selection"
        plot.padding = 50
        plot.line_width = 1

        # Attach some tools to the plot
        pan = PanTool(plot, drag_button="right", constrain_key="shift")
        plot.tools.append(pan)
        zoom = ZoomTool(component=plot, tool_mode="box", always_on=False)
        plot.overlays.append(zoom)

        return plot

    def _linedraw_default(self):
        linedraw = LineSegmentTool(self.plot, color=(0.5, 0.5, 0.9, 1.0))
        self.plot.overlays.append(linedraw)
        # Redraw whenever the polygon's vertex list changes.
        linedraw.on_trait_change(self.points_changed, 'points[]')
        return linedraw

    def points_changed(self):
        self._update_image()

    @traits.on_trait_change('display_mode')
    def send_array(self):
        """Encode the current image as PNG and send it to the display.

        NOTE(review): this runs when ``display_mode`` changes, but the image
        sent does not depend on it.  The previous code derived foreground and
        background colours from it and never used them.  Confirm intent.
        """
        # mkstemp creates the file safely (tempfile.mktemp is racy and
        # deprecated).  Close our descriptor so imsave can reopen by name.
        fd, fname = tempfile.mkstemp('.png')
        os.close(fd)
        try:
            scipy.misc.imsave(fname, self._image)
            image = freemoovr.msg.FreemooVRCompressedImage()
            image.format = 'png'
            # PNG is binary: read in 'rb' mode and close the handle.
            with open(fname, 'rb') as png_file:
                image.data = png_file.read()
            self.blit_compressed_image_proxy(image)
        finally:
            os.unlink(fname)

    def get_viewport_verts(self):
        """Return the polygon vertices as [x, y] integer lists (JSON-friendly)."""
        # convert to integers
        pts = [(posint(x, self.width - 1), posint(y, self.height - 1))
               for (x, y) in self.linedraw.points]
        # convert to list of lists for maximal json compatibility
        return [list(x) for x in pts]
# Example #9 (0)
# File: fitgui.py  Project: eteq/pymodelfit
class FitGui(HasTraits):
    """
    This class represents the fitgui application state.

    It holds the data being fit, the per-point weights, the current model
    (wrapped in a :class:`TraitedModel`), and the Chaco plot, colorbar and
    selection tools used to interact with them.
    """

    plot = Instance(Plot)
    colorbar = Instance(ColorBar)
    plotcontainer = Instance(HPlotContainer)
    tmodel = Instance(TraitedModel,allow_none=False)
    nomodel = Property
    newmodel = Button('New Model...')
    fitmodel = Button('Fit Model')
    showerror = Button('Fit Error')
    updatemodelplot = Button('Update Model Plot')
    autoupdate = Bool(True)
    data = Array(dtype=float,shape=(2,None))
    weights = Array
    weighttype = Enum(('custom','equal','lin bins','log bins'))
    weightsvary = Property(Bool)
    weights0rem = Bool(True)
    modelselector = NewModelSelector
    ytype = Enum(('data and model','residuals'))

    zoomtool = Instance(ZoomTool)
    pantool = Instance(PanTool)

    scattertool = Enum(None,'clicktoggle','clicksingle','clickimmediate','lassoadd','lassoremove','lassoinvert')
    selectedi = Property #indices of the selected objects
    weightchangesel = Button('Set Selection To')
    weightchangeto = Float(1.0)
    delsel = Button('Delete Selected')
    unselectonaction = Bool(True)
    clearsel = Button('Clear Selections')
    lastselaction = Str('None')

    datasymb = Button('Data Symbol...')
    modline = Button('Model Line...')

    savews = Button('Save Weights')
    loadws = Button('Load Weights')
    _savedws = Array

    plotname = Property
    updatestats = Event
    chi2 = Property(Float,depends_on='updatestats')
    chi2r = Property(Float,depends_on='updatestats')


    nmod = Int(1024)
    #modelpanel = View(Label('empty'),kind='subpanel',title='model editor')
    modelpanel = View

    panel_view = View(VGroup(
                       Item('plot', editor=ComponentEditor(),show_label=False),
                       HGroup(Item('tmodel.modelname',show_label=False,style='readonly'),
                              Item('nmod',label='Number of model points'),
                              Item('updatemodelplot',show_label=False,enabled_when='not autoupdate'),
                              Item('autoupdate',label='Auto?'))
                      ),
                    title='Model Data Fitter'
                    )


    selection_view = View(Group(
                           Item('scattertool',label='Selection Mode',
                                 editor=EnumEditor(values={None:'1:No Selection',
                                                           'clicktoggle':'3:Toggle Select',
                                                           'clicksingle':'2:Single Select',
                                                           'clickimmediate':'7:Immediate',
                                                           'lassoadd':'4:Add with Lasso',
                                                           'lassoremove':'5:Remove with Lasso',
                                                           'lassoinvert':'6:Invert with Lasso'})),
                           Item('unselectonaction',label='Clear Selection on Action?'),
                           Item('clearsel',show_label=False),
                           Item('weightchangesel',show_label=False),
                           Item('weightchangeto',label='Weight'),
                           Item('delsel',show_label=False)
                         ),title='Selection Options')

    traits_view = View(VGroup(
                        HGroup(Item('object.plot.index_scale',label='x-scaling',
                                    enabled_when='object.plot.index_mapper.range.low>0 or object.plot.index_scale=="log"'),
                              spring,
                              Item('ytype',label='y-data'),
                              Item('object.plot.value_scale',label='y-scaling',
                                   enabled_when='object.plot.value_mapper.range.low>0 or object.plot.value_scale=="log"')
                              ),
                       Item('plotcontainer', editor=ComponentEditor(),show_label=False),
                       HGroup(VGroup(HGroup(Item('weighttype',label='Weights:'),
                                            Item('savews',show_label=False),
                                            Item('loadws',enabled_when='_savedws',show_label=False)),
                                Item('weights0rem',label='Remove 0-weight points for fit?'),
                                HGroup(Item('newmodel',show_label=False),
                                       Item('fitmodel',show_label=False),
                                       Item('showerror',show_label=False,enabled_when='tmodel.lastfitfailure'),
                                       VGroup(Item('chi2',label='Chi2:',style='readonly',format_str='%6.6g',visible_when='tmodel.model is not None'),
                                             Item('chi2r',label='reduced:',style='readonly',format_str='%6.6g',visible_when='tmodel.model is not None'))
                                       )#Item('selbutton',show_label=False))
                              ,springy=False),spring,
                              VGroup(HGroup(Item('autoupdate',label='Auto?'),
                              Item('updatemodelplot',show_label=False,enabled_when='not autoupdate')),
                              Item('nmod',label='Nmodel'),
                              HGroup(Item('datasymb',show_label=False),Item('modline',show_label=False)),springy=False),springy=True),
                       '_',
                       HGroup(Item('scattertool',label='Selection Mode',
                                 editor=EnumEditor(values={None:'1:No Selection',
                                                           'clicktoggle':'3:Toggle Select',
                                                           'clicksingle':'2:Single Select',
                                                           'clickimmediate':'7:Immediate',
                                                           'lassoadd':'4:Add with Lasso',
                                                           'lassoremove':'5:Remove with Lasso',
                                                           'lassoinvert':'6:Invert with Lasso'})),
                           Item('unselectonaction',label='Clear Selection on Action?'),
                           Item('clearsel',show_label=False),
                           Item('weightchangesel',show_label=False),
                           Item('weightchangeto',label='Weight'),
                           Item('delsel',show_label=False),
                         ),#layout='flow'),
                       Item('tmodel',show_label=False,style='custom',editor=InstanceEditor(kind='subpanel'))
                      ),
                    handler=FGHandler(),
                    resizable=True,
                    title='Data Fitting',
                    buttons=['OK','Cancel'],
                    width=700,
                    height=900
                    )


    def __init__(self,xdata=None,ydata=None,weights=None,model=None,
                 include_models=None,exclude_models=None,fittype=None,**traits):
        """

        :param xdata: the first dimension of the data to be fit
        :type xdata: array-like
        :param ydata: the second dimension of the data to be fit
        :type ydata: array-like
        :param weights:
            The weights to apply to the data. Statistically interpreted as inverse
            errors (*not* inverse variance). May be any of the following forms:

            * None for equal weights
            * an array of points that must match `ydata`
            * a 2-sequence of arrays (xierr,yierr) such that xierr matches the
              `xdata` and yierr matches `ydata`
            * a function called as f(params) that returns an array of weights
              that match one of the above two conditions

        :param model: the initial model to use to fit this data
        :type model:
            None, string, or :class:`pymodelfit.core.FunctionModel1D`
            instance.
        :param include_models:
            With `exclude_models`, specifies which models should be available in
            the "new model" dialog (see `models.list_models` for syntax).
        :param exclude_models:
            With `include_models`, specifies which models should be available in
            the "new model" dialog (see `models.list_models` for syntax).
        :param fittype:
            The fitting technique for the initial fit (see
            :class:`pymodelfit.core.FunctionModel`).
        :type fittype: string

        :raises ValueError:
            If `xdata` or `ydata` is missing and the model carries no data.

        kwargs are passed in as any additional traits to apply to the
        application.

        """

        self.modelpanel = View(Label('empty'),kind='subpanel',title='model editor')

        self.tmodel = TraitedModel(model)

        if model is not None and fittype is not None:
            self.tmodel.model.fittype = fittype

        if xdata is None or ydata is None:
            if not hasattr(self.tmodel.model,'data') or self.tmodel.model.data is None:
                raise ValueError('data not provided and no data in model')
            modeldata = self.tmodel.model.data
            if xdata is None:
                xdata = modeldata[0]
            if ydata is None:
                ydata = modeldata[1]
            # only take weights from the model if it actually stores them -
            # otherwise fall back to equal weights below instead of IndexError
            if weights is None and len(modeldata) > 2:
                weights = modeldata[2]

        self.on_trait_change(self._paramsChanged,'tmodel.paramchange')

        self.modelselector = NewModelSelector(include_models,exclude_models)

        self.data = [xdata,ydata]


        if weights is None:
            self.weights = np.ones_like(xdata)
            self.weighttype = 'equal'
        else:
            self.weights = np.array(weights,copy=True)
            self.savews = True

        # collapse (xierr,yierr)-style weights to one value per point for
        # the color mapping
        weights1d = self.weights
        while len(weights1d.shape)>1:
            weights1d = np.sum(weights1d**2,axis=0)

        pd = ArrayPlotData(xdata=self.data[0],ydata=self.data[1],weights=weights1d)
        self.plot = plot = Plot(pd,resizable='hv')

        self.scatter = plot.plot(('xdata','ydata','weights'),name='data',
                         color_mapper=_cmapblack if self.weights0rem else _cmap,
                         type='cmap_scatter', marker='circle')[0]

        self.errorplots = None

        if not isinstance(model,FunctionModel1D):
            self.fitmodel = True

        self.updatemodelplot = False #force plot update - generates xmod and ymod
        plot.plot(('xmod','ymod'),name='model',type='line',line_style='dash',color='black',line_width=2)
        del plot.x_mapper.range.sources[-1]  #remove the line plot from the x_mapper source so only the data is tied to the scaling

        self.on_trait_change(self._rangeChanged,'plot.index_mapper.range.updated')

        self.pantool = PanTool(plot,drag_button='left')
        plot.tools.append(self.pantool)
        self.zoomtool = ZoomTool(plot)
        self.zoomtool.prev_state_key = KeySpec('a')
        self.zoomtool.next_state_key = KeySpec('s')
        plot.overlays.append(self.zoomtool)

        self.scattertool = None
        self.scatter.overlays.append(ScatterInspectorOverlay(self.scatter,
                        hover_color = "black",
                        selection_color="black",
                        selection_outline_color="red",
                        selection_line_width=2))


        self.colorbar = colorbar = ColorBar(index_mapper=LinearMapper(range=plot.color_mapper.range),
                                            color_mapper=plot.color_mapper.range,
                                            plot=plot,
                                            orientation='v',
                                            resizable='v',
                                            width = 30,
                                            padding = 5)
        colorbar.padding_top = plot.padding_top
        colorbar.padding_bottom = plot.padding_bottom
        colorbar._axis.title = 'Weights'

        self.plotcontainer = container = HPlotContainer(use_backbuffer=True)
        container.add(plot)
        container.add(colorbar)

        super(FitGui,self).__init__(**traits)

        self.on_trait_change(self._scale_change,'plot.value_scale,plot.index_scale')

        if weights is not None and len(weights)==2:
            self.weightsChanged() #update error bars

    def _weights0rem_changed(self,old,new):
        """Swap the scatter colormap: the black variant is used when 0-weight
        points are removed from fits."""
        if new:
            self.plot.color_mapper = _cmapblack(self.plot.color_mapper.range)
        else:
            self.plot.color_mapper = _cmap(self.plot.color_mapper.range)
        self.plot.request_redraw()
#        if old and self.filloverlay in self.plot.overlays:
#            self.plot.overlays.remove(self.filloverlay)
#        if new:
#            self.plot.overlays.append(self.filloverlay)
#        self.plot.request_redraw()

    def _paramsChanged(self):
        self.updatemodelplot = True

    def _nmod_changed(self):
        self.updatemodelplot = True

    def _rangeChanged(self):
        self.updatemodelplot = True

    #@on_trait_change('object.plot.value_scale,object.plot.index_scale',post_init=True)
    def _scale_change(self):
        self.plot.request_redraw()

    def _updatemodelplot_fired(self,new):
        """Recompute the model curve (or residuals) shown on the plot.

        :param new:
            True for automatic updates (skipped unless `autoupdate` is set),
            False to force an update (e.g. button click).
        :raises ValueError: if `ytype` somehow holds an unknown value.
        """
        #If the plot has not been generated yet, just skip the update
        if self.plot is None:
            return

        #if False (e.g. button click), update regardless, otherwise check for autoupdate
        if new and not self.autoupdate:
            return

        mod = self.tmodel.model
        if self.ytype == 'data and model':
            if mod:
                #xd = self.data[0]
                #xmod = np.linspace(np.min(xd),np.max(xd),self.nmod)
                xl = self.plot.index_range.low
                xh = self.plot.index_range.high
                if self.plot.index_scale=="log":
                    xmod = np.logspace(np.log10(xl),np.log10(xh),self.nmod)
                else:
                    xmod = np.linspace(xl,xh,self.nmod)
                ymod = self.tmodel.model(xmod)

                self.plot.data.set_data('xmod',xmod)
                self.plot.data.set_data('ymod',ymod)

            else:
                self.plot.data.set_data('xmod',[])
                self.plot.data.set_data('ymod',[])
        elif self.ytype == 'residuals':
            if mod:
                self.plot.data.set_data('xmod',[])
                self.plot.data.set_data('ymod',[])
                #residuals set the ydata instead of setting the model
                res = mod.residuals(*self.data)
                self.plot.data.set_data('ydata',res)
            else:
                self.ytype = 'data and model'
        else:
            # previously `assert True, ...`, which could never fire
            raise ValueError('invalid Enum value for ytype: %r' % (self.ytype,))


    def _fitmodel_fired(self):
        """Fit the current model to the data, honoring weights and (optionally)
        dropping 0-weight points, then refresh the model plot and statistics."""
        from warnings import warn

        preaup = self.autoupdate
        try:
            # suppress intermediate plot updates while the fit runs
            self.autoupdate = False
            xd,yd = self.data
            kwd = {'x':xd,'y':yd}
            if self.weights is not None:
                w = self.weights
                if self.weights0rem:
                    if xd.shape == w.shape:
                        m = w!=0
                        w = w[m]
                        kwd['x'] = kwd['x'][m]
                        kwd['y'] = kwd['y'][m]
                    elif np.any(w==0):
                        warn("can't remove 0-weighted points if weights don't match data")
                kwd['weights'] = w
            self.tmodel.fitdata = kwd
        finally:
            self.autoupdate = preaup

        self.updatemodelplot = True
        self.updatestats = True


#    def _tmodel_changed(self,old,new):
#        #old is only None before it is initialized
#        if new is not None and old is not None and new.model is not None:
#            self.fitmodel = True

    def _newmodel_fired(self,newval):
        """Replace the model, either from `newval` directly (a name, model
        instance or model class) or via the modal model-selector dialog."""
        from inspect import isclass

        if isinstance(newval,basestring) or isinstance(newval,FunctionModel1D) \
           or (isclass(newval) and issubclass(newval,FunctionModel1D)):
            self.tmodel = TraitedModel(newval)
        else:
            if self.modelselector.edit_traits(kind='modal').result:
                cls = self.modelselector.selectedmodelclass
                if cls is None:
                    self.tmodel = TraitedModel(None)
                elif self.modelselector.isvarargmodel:
                    self.tmodel = TraitedModel(cls(self.modelselector.modelargnum))
                    self.fitmodel = True
                else:
                    self.tmodel = TraitedModel(cls())
                    self.fitmodel = True
            else: #cancelled
                return

    def _showerror_fired(self,evt):
        """Show the exception from the last failed fit in a dialog."""
        if self.tmodel.lastfitfailure:
            ex = self.tmodel.lastfitfailure
            dialog = HasTraits(s=ex.__class__.__name__+': '+str(ex))
            view = View(Item('s',style='custom',show_label=False),
                        resizable=True,buttons=['OK'],title='Fitting error message')
            dialog.edit_traits(view=view)

    @cached_property
    def _get_chi2(self):
        # best-effort display value: 0 if there is no model/fit yet
        try:
            return self.tmodel.model.chi2Data()[0]
        except Exception:
            return 0

    @cached_property
    def _get_chi2r(self):
        # best-effort display value: 0 if there is no model/fit yet
        try:
            return self.tmodel.model.chi2Data()[1]
        except Exception:
            return 0

    def _get_nomodel(self):
        return self.tmodel.model is None

    def _get_weightsvary(self):
        w = self.weights
        return np.any(w!=w[0]) if len(w)>0 else False

    def _get_plotname(self):
        xlabel = self.plot.x_axis.title
        ylabel = self.plot.y_axis.title
        if xlabel == '' and ylabel == '':
            return ''
        else:
            return xlabel+' vs '+ylabel
    def _set_plotname(self,val):
        """Set the axis titles from ``'xlabel vs ylabel'`` (or
        ``'xlabel-ylabel'``), or from an (xlabel, ylabel) sequence."""
        if isinstance(val,basestring):
            if not val.strip():
                # mirror _get_plotname, which returns '' for two empty titles
                val = ['','']
            else:
                parts = val.split('vs')
                if len(parts) == 1:
                    parts = val.split('-')
                val = [v.strip() for v in parts]
        # titles live on the plot's axes, not on this object
        self.plot.x_axis.title = val[0]
        self.plot.y_axis.title = val[1]


    #selection-related
    def _scattertool_changed(self,old,new):
        """Swap the scatter selection tools/overlays to match the new mode.

        With no selection tool the pan tool drags with the left button;
        otherwise panning moves to the right button.
        """
        # the "No Selection" choice is stored as None (its label is only
        # shown in the EnumEditor)
        if new is None:
            self.plot.tools[0].drag_button='left'
        else:
            self.plot.tools[0].drag_button='right'
        if old is not None and 'lasso' in old:
            if new is not None and 'lasso' in new:
                #connect correct callbacks
                self.lassomode = new.replace('lasso','')
                return
            else:
                #TODO:test
                self.scatter.tools[-1].on_trait_change(self._lasso_handler,
                                            'selection_changed',remove=True)
                del self.scatter.overlays[-1]
                del self.lassomode
        elif old == 'clickimmediate':
            self.scatter.index.on_trait_change(self._immediate_handler,
                                            'metadata_changed',remove=True)

        self.scatter.tools = []
        if new is None:
            pass
        elif 'click' in new:
            smodemap = {'clickimmediate':'single','clicksingle':'single',
                        'clicktoggle':'toggle'}
            self.scatter.tools.append(ScatterInspector(self.scatter,
                                      selection_mode=smodemap[new]))
            if new == 'clickimmediate':
                self.clearsel = True
                self.scatter.index.on_trait_change(self._immediate_handler,
                                                    'metadata_changed')
        elif 'lasso' in new:
            lasso_selection = LassoSelection(component=self.scatter,
                                    selection_datasource=self.scatter.index)
            self.scatter.tools.append(lasso_selection)
            lasso_overlay = LassoOverlay(lasso_selection=lasso_selection,
                                         component=self.scatter)
            self.scatter.overlays.append(lasso_overlay)
            self.lassomode = new.replace('lasso','')
            lasso_selection.on_trait_change(self._lasso_handler,
                                            'selection_changed')
            lasso_selection.on_trait_change(self._lasso_handler,
                                            'selection_completed')
            lasso_selection.on_trait_change(self._lasso_handler,
                                            'updated')
        else:
            raise TraitsError('invalid scattertool value')

    def _weightchangesel_fired(self):
        self.weights[self.selectedi] = self.weightchangeto
        if self.unselectonaction:
            self.clearsel = True

        self._sel_alter_weights()
        self.lastselaction = 'weightchangesel'

    def _delsel_fired(self):
        # "deleting" a point sets its weight to 0
        self.weights[self.selectedi] = 0
        if self.unselectonaction:
            self.clearsel = True

        self._sel_alter_weights()
        self.lastselaction = 'delsel'

    def _sel_alter_weights(self):
        """After an in-place weight edit, switch to 'custom' weights and
        refresh the plot."""
        if self.weighttype != 'custom':
            self._customweights = self.weights
            self.weighttype = 'custom'
        self.weightsChanged()

    def _clearsel_fired(self,event):
        if isinstance(event,list):
            self.scatter.index.metadata['selections'] = event
        else:
            self.scatter.index.metadata['selections'] = list()

    def _lasso_handler(self,name,new):
        """Combine lasso and click selections according to `lassomode`
        ('add', 'remove' or 'invert') and toggle the lasso overlay."""
        if name == 'selection_changed':
            lassomask = self.scatter.index.metadata['selection'].astype(int)
            clickmask = np.zeros_like(lassomask)
            clickmask[self.scatter.index.metadata['selections']] = 1

            if self.lassomode == 'add':
                mask = clickmask | lassomask
            elif self.lassomode == 'remove':
                mask = clickmask & ~lassomask
            elif self.lassomode == 'invert':
                mask = np.logical_xor(clickmask,lassomask)
            else:
                raise TraitsError('lassomode is in invalid state')

            self.scatter.index.metadata['selections'] = list(np.where(mask)[0])
        elif name == 'selection_completed':
            self.scatter.overlays[-1].visible = False
        elif name == 'updated':
            self.scatter.overlays[-1].visible = True
        else:
            raise ValueError('traits event name %s invalid'%name)

    def _immediate_handler(self):
        """In 'clickimmediate' mode, re-apply the last selection action to a
        single clicked point, then drop it from the selection."""
        sel = self.selectedi
        if len(sel) > 1:
            self.clearsel = True
            raise TraitsError('selection error in immediate mode - more than 1 selection')
        elif len(sel)==1:
            if self.lastselaction != 'None':
                setattr(self,self.lastselaction,True)
            del sel[0]

    def _savews_fired(self):
        self._savedws = self.weights.copy()

    def _loadws_fired(self):
        self.weights = self._savedws
        self._savews_fired()

    def _get_selectedi(self):
        return self.scatter.index.metadata['selections']


    @on_trait_change('data,ytype',post_init=True)
    def dataChanged(self):
        """
        Updates the application state if the fit data are altered - the GUI will
        know if you give it a new data array, but not if the data is changed
        in-place.
        """
        pd = self.plot.data
        #TODO:make set_data apply to both simultaneously?
        pd.set_data('xdata',self.data[0])
        pd.set_data('ydata',self.data[1])

        self.updatemodelplot = False

    @on_trait_change('weights',post_init=True)
    def weightsChanged(self):
        """
        Updates the application state if the weights/error bars for this model
        are changed - the GUI will automatically do this if you give it a new
        set of weights array, but not if they are changed in-place.
        """
        weights = self.weights
        if 'errorplots' in self.trait_names():
            #TODO:switch this to updating error bar data/visibility changing
            if self.errorplots is not None:
                self.plot.remove(self.errorplots[0])
                self.plot.remove(self.errorplots[1])
                # reset the attribute actually checked above; resetting a
                # differently-named one left stale plots to be removed again
                self.errorplots = None

            if len(weights.shape)==2 and weights.shape[0]==2:
                # weights are inverse errors
                xerr,yerr = 1/weights

                high = ArrayDataSource(self.scatter.index.get_data()+xerr)
                low = ArrayDataSource(self.scatter.index.get_data()-xerr)
                ebpx = ErrorBarPlot(orientation='v',
                                   value_high = high,
                                   value_low = low,
                                   index = self.scatter.value,
                                   value = self.scatter.index,
                                   index_mapper = self.scatter.value_mapper,
                                   value_mapper = self.scatter.index_mapper
                                )
                self.plot.add(ebpx)

                high = ArrayDataSource(self.scatter.value.get_data()+yerr)
                low = ArrayDataSource(self.scatter.value.get_data()-yerr)
                ebpy = ErrorBarPlot(value_high = high,
                                   value_low = low,
                                   index = self.scatter.index,
                                   value = self.scatter.value,
                                   index_mapper = self.scatter.index_mapper,
                                   value_mapper = self.scatter.value_mapper
                                )
                self.plot.add(ebpy)

                self.errorplots = (ebpx,ebpy)

        # collapse multi-dimensional weights to one value per point for colors
        while len(weights.shape)>1:
            weights = np.sum(weights**2,axis=0)
        self.plot.data.set_data('weights',weights)
        self.plot.plots['data'][0].color_mapper.range.refresh()

        # only show the colorbar when the weights actually differ
        if self.weightsvary:
            if self.colorbar not in self.plotcontainer.components:
                self.plotcontainer.add(self.colorbar)
                self.plotcontainer.request_redraw()
        elif self.colorbar in self.plotcontainer.components:
            self.plotcontainer.remove(self.colorbar)
            self.plotcontainer.request_redraw()


    def _weighttype_changed(self, name, old, new):
        """Recompute `weights` for the selected scheme, stashing/restoring the
        user's custom weights when leaving/entering 'custom'."""
        if old == 'custom':
            self._customweights = self.weights

        if new == 'custom':
            self.weights = self._customweights #if hasattr(self,'_customweights') else np.ones_like(self.data[0])
        elif new == 'equal':
            self.weights = np.ones_like(self.data[0])
        elif new == 'lin bins':
            self.weights = binned_weights(self.data[0],10,False)
        elif new == 'log bins':
            self.weights = binned_weights(self.data[0],10,True)
        else:
            raise TraitError('Invalid Enum value on weighttype')

    def getModelInitStr(self):
        """
        Generates a python code string that can be used to generate a model with
        parameters matching the model in this :class:`FitGui`.

        :returns: initializer string

        """
        mod = self.tmodel.model
        if mod is None:
            return 'None'
        else:
            parstrs = []
            for p,v in mod.pardict.iteritems():
                parstrs.append(p+'='+str(v))
            if mod.__class__._pars is None: #varargs need to have the first argument give the right number
                varcount = len(mod.params)-len(mod.__class__._statargs)
                parstrs.insert(0,str(varcount))
            return '%s(%s)'%(mod.__class__.__name__,','.join(parstrs))

    def getModelObject(self):
        """
        Gets the underlying object representing the model for this fit.

        :returns: The :class:`pymodelfit.core.FunctionModel1D` object.
        """
        return self.tmodel.model
예제 #10
0
class ContourGridPlane(Module):
    """Module showing a grid plane of the input data, optionally contoured.

    The pipeline is grid_plane -> contour -> actor when contours are
    enabled, and grid_plane -> actor when they are disabled.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The grid plane component.
    grid_plane = Instance(GridPlane, allow_none=False, record=True)

    # Specifies if contouring is to be done or not.
    enable_contours = Bool(True, desc='if contours are generated')

    # The contour component that contours the data.
    contour = Instance(Contour, allow_none=False, record=True)

    # The actor component that represents the visualization.
    actor = Instance(Actor, allow_none=False, record=True)

    input_info = PipelineInfo(
        datasets=['image_data', 'structured_grid', 'rectilinear_grid'],
        attribute_types=['any'],
        attributes=['any'])

    view = View([
        Group(Item(name='grid_plane', style='custom'), show_labels=False),
        Group(Item(name='enable_contours')),
        Group(Item(name='contour',
                   style='custom',
                   enabled_when='object.enable_contours'),
              Item(name='actor', style='custom'),
              show_labels=False)
    ])

    ######################################################################
    # `Module` interface
    ######################################################################
    def setup_pipeline(self):
        """Override this method so that it *creates* the tvtk
        pipeline.

        This method is invoked when the object is initialized via
        `__init__`.  Note that at the time this method is called, the
        tvtk data pipeline will *not* yet be setup.  So upstream data
        will not be available.  The idea is that you simply create the
        basic objects and setup those parts of the pipeline not
        dependent on upstream sources and filters.  You should also
        set the `actors` attribute up at this point.
        """
        # Create the components
        self.grid_plane = GridPlane()
        self.contour = Contour(auto_contours=True, number_of_contours=10)
        self.actor = Actor()

    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        This method is invoked (automatically) when any of the inputs
        sends a `pipeline_changed` event.
        """
        mm = self.module_manager
        if mm is None:
            return

        # Data is available, so set the input for the grid plane.
        self.grid_plane.inputs = [mm.source]

        # This makes sure that any changes made to enable_contours
        # when the module is not running are updated when it is
        # started.
        self._enable_contours_changed(self.enable_contours)
        # Set the LUT for the mapper.
        self.actor.set_lut(mm.scalar_lut_manager.lut)

        self.pipeline_changed = True

    def update_data(self):
        """Override this method so that it flushes the vtk pipeline if
        that is necessary.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        # Just set data_changed, the components should do the rest if
        # they are connected.
        self.data_changed = True

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _filled_contours_changed(self, value):
        """When filled contours are enabled, the mapper should use the
        cell data, otherwise it should use the default scalar mode.

        This only applies while contours are enabled. When they are
        disabled the actor renders the grid plane directly, and
        `_enable_contours_changed` picks the right mode once contours are
        switched back on.
        """
        if not self.enable_contours:
            return
        if value:
            self.actor.mapper.scalar_mode = 'use_cell_data'
        else:
            self.actor.mapper.scalar_mode = 'default'
        self.render()

    def _enable_contours_changed(self, value):
        """Turns on and off the contours."""
        if self.module_manager is None:
            return
        if value:
            self.actor.inputs = [self.contour]
            if self.contour.filled_contours:
                self.actor.mapper.scalar_mode = 'use_cell_data'
        else:
            self.actor.inputs = [self.grid_plane]
            self.actor.mapper.scalar_mode = 'default'
        self.render()

    def _grid_plane_changed(self, old, new):
        cont = self.contour
        if cont is not None:
            cont.inputs = [new]
        self._change_components(old, new)

    def _contour_changed(self, old, new):
        # Move the filled_contours listener from the old contour to the new.
        if old is not None:
            old.on_trait_change(self._filled_contours_changed,
                                'filled_contours',
                                remove=True)
        new.on_trait_change(self._filled_contours_changed, 'filled_contours')
        # Setup the contours input.
        gp = self.grid_plane
        if gp is not None:
            new.inputs = [gp]

        # Setup the actor.
        actor = self.actor
        if actor is not None:
            actor.inputs = [new]
        self._change_components(old, new)

    def _actor_changed(self, old, new):
        if old is None:
            # First time this is set.
            new.property.set(line_width=2.0)

        # Set the actors scene and input.
        new.scene = self.scene
        cont = self.contour
        if cont is not None:
            new.inputs = [cont]
        self._change_components(old, new)
예제 #11
0
class SceneModel(TVTKScene):
    """A TVTKScene stand-in that delegates to an attached scene editor.

    ``__init__`` deliberately skips ``TVTKScene.__init__`` (which would
    create a new window).  Actor/widget bookkeeping is kept in
    ``actor_list`` and ``enabled_info``, render requests are signalled via
    the ``do_render`` event, and the export/size methods forward to
    ``scene_editor``, raising ``SceneModelError`` when none is attached.
    Presumably the scene editor observes these traits and mirrors them into
    a real renderer -- that side is not visible here.
    """

    ########################################
    # TVTKScene traits.

    light_manager = Property

    picker = Property

    ########################################
    # SceneModel traits.

    # A convenient dictionary based interface to add/remove actors and widgets.
    # This is similar to the interface provided for the ActorEditor.
    actor_map = Dict()

    # This is used primarily to implement the add_actor/remove_actor methods.
    actor_list = List()

    # The actual scene being edited.
    scene_editor = Instance(TVTKScene)

    # Fired by render() to request a render.
    do_render = Event()

    # Fired when this is activated.
    activated = Event()

    # Fired when this widget is closed.
    closing = Event()

    # This exists just to mirror the TVTKWindow api.
    scene = Property

    ###################################
    # View related traits.

    # Render_window's view.
    _stereo_view = Group(
        Item(name='stereo_render'),
        Item(name='stereo_type'),
        show_border=True,
        label='Stereo rendering',
    )

    # The default view of this object.
    default_view = View(
        Group(Group(
            Item(name='background'),
            Item(name='foreground'),
            Item(name='parallel_projection'),
            Item(name='disable_render'),
            Item(name='off_screen_rendering'),
            Item(name='jpeg_quality'),
            Item(name='jpeg_progressive'),
            Item(name='magnification'),
            Item(name='anti_aliasing_frames'),
        ),
              Group(
                  Item(name='render_window',
                       style='custom',
                       visible_when='object.stereo',
                       editor=InstanceEditor(view=View(_stereo_view)),
                       show_label=False), ),
              label='Scene'),
        Group(Item(name='light_manager',
                   style='custom',
                   editor=InstanceEditor(),
                   show_label=False),
              label='Lights'))

    ###################################
    # Private traits.

    # Used by the editor to determine if the widget was enabled or not.
    enabled_info = Dict()

    def __init__(self, parent=None, **traits):
        """ Initializes the object.  ``parent`` is accepted but unused. """
        # Base class constructor.  We call TVTKScene's super here on purpose.
        # Calling TVTKScene's init will create a new window which we do not
        # want.
        super(TVTKScene, self).__init__(**traits)
        self.control = None

    ######################################################################
    # TVTKScene API.
    ######################################################################
    def render(self):
        """ Request a render by firing the ``do_render`` event.

        This method itself does not check ``disable_render``; honoring it
        is left to whatever listens to ``do_render`` (presumably the scene
        editor -- confirm there).
        """

        self.do_render = True

    def add_actors(self, actors):
        """ Adds a single actor or a tuple or list of actors to
        ``actor_list``."""
        if hasattr(actors, '__iter__'):
            self.actor_list.extend(actors)
        else:
            self.actor_list.append(actors)

    def remove_actors(self, actors):
        """ Removes a single actor or a tuple or list of actors from
        ``actor_list``.  Raises ValueError for an actor that is not
        present."""
        my_actors = self.actor_list
        if hasattr(actors, '__iter__'):
            for actor in actors:
                my_actors.remove(actor)
        else:
            my_actors.remove(actors)

    # Convenience methods.
    add_actor = add_actors
    remove_actor = remove_actors

    def add_widgets(self, widgets, enabled=True):
        """Adds a single widget or an iterable of widgets to the scene.

        Each widget's ``enabled`` flag is recorded in ``enabled_info`` and
        the widgets are appended to ``actor_list``.
        """
        if hasattr(widgets, '__iter__'):
            # Materialize once: the widgets are walked twice below and a
            # one-shot iterator would be exhausted by the first pass.
            widgets = list(widgets)
        else:
            widgets = [widgets]
        for widget in widgets:
            self.enabled_info[widget] = enabled
        self.add_actors(widgets)

    def remove_widgets(self, widgets):
        """Removes a single widget or an iterable of widgets from the scene.

        Each widget is removed from ``actor_list`` and its ``enabled_info``
        entry is deleted.
        """
        if hasattr(widgets, '__iter__'):
            # Materialize once: a one-shot iterator would otherwise be
            # consumed by remove_actors, leaving stale enabled_info entries.
            widgets = list(widgets)
        else:
            widgets = [widgets]
        self.remove_actors(widgets)
        for widget in widgets:
            del self.enabled_info[widget]

    def reset_zoom(self):
        """Reset the camera so everything in the scene fits.  Does nothing
        when no scene editor is attached."""
        if self.scene_editor is not None:
            self.scene_editor.reset_zoom()

    def save(self, file_name, size=None, **kw_args):
        """Saves rendered scene to one of several image formats
        depending on the specified extension of the filename.

        If an additional size (2-tuple) argument is passed the window
        is resized to the specified size in order to produce a
        suitably sized output image.  Please note that when the window
        is resized, the window may be obscured by other widgets and
        the camera zoom is not reset which is likely to produce an
        image that does not reflect what is seen on screen.

        Any extra keyword arguments are passed along to the respective
        image format's save method.
        """
        self._check_scene_editor()
        self.scene_editor.save(file_name, size, **kw_args)

    def save_ps(self, file_name):
        """Saves the rendered scene to a rasterized PostScript image.
        For vector graphics use the save_gl2ps method."""
        self._check_scene_editor()
        self.scene_editor.save_ps(file_name)

    def save_bmp(self, file_name):
        """Save to a BMP image file."""
        self._check_scene_editor()
        self.scene_editor.save_bmp(file_name)

    def save_tiff(self, file_name):
        """Save to a TIFF image file."""
        self._check_scene_editor()
        self.scene_editor.save_tiff(file_name)

    def save_png(self, file_name):
        """Save to a PNG image file."""
        self._check_scene_editor()
        self.scene_editor.save_png(file_name)

    def save_jpg(self, file_name, quality=None, progressive=None):
        """Save to a JPEG image file.

        ``quality`` is the JPEG quality (10-100 are valid values) and
        ``progressive`` toggles progressive JPEGs; both are forwarded
        unchanged to the scene editor."""
        self._check_scene_editor()
        self.scene_editor.save_jpg(file_name, quality, progressive)

    def save_iv(self, file_name):
        """Save to an OpenInventor file."""
        self._check_scene_editor()
        self.scene_editor.save_iv(file_name)

    def save_vrml(self, file_name):
        """Save to a VRML file."""
        self._check_scene_editor()
        self.scene_editor.save_vrml(file_name)

    def save_oogl(self, file_name):
        """Saves the scene to a Geomview OOGL file. Requires VTK 4 to
        work."""
        self._check_scene_editor()
        self.scene_editor.save_oogl(file_name)

    def save_rib(self, file_name, bg=0, resolution=None, resfactor=1.0):
        """Save scene to a RenderMan RIB file.

        Keyword Arguments:

        file_name -- File name to save to.

        bg -- Background option (defaults to 0).  If 0 then no
        background is saved; any other non-None value saves a
        background.  Passing None explicitly is forwarded to the scene
        editor, which is documented to ask via a yes/no pop-up.

        resolution -- Specify the resolution of the generated image in
        the form of a tuple (nx, ny).

        resfactor -- The resolution factor which scales the resolution.
        """
        self._check_scene_editor()
        self.scene_editor.save_rib(file_name, bg, resolution, resfactor)

    def save_wavefront(self, file_name):
        """Save scene to a Wavefront OBJ file.  Two files are
        generated.  One with a .obj extension and another with a .mtl
        extension which contains the material properties.

        Keyword Arguments:

        file_name -- File name to save to
        """
        self._check_scene_editor()
        self.scene_editor.save_wavefront(file_name)

    def save_gl2ps(self, file_name, exp=None):
        """Save scene to a vector PostScript/EPS/PDF/TeX file using
        GL2PS.  If you choose to use a TeX file then note that only
        the text output is saved to the file.  You will need to save
        the graphics separately.

        Keyword Arguments:

        file_name -- File name to save to.

        exp -- Optionally configured vtkGL2PSExporter object.
        Defaults to None and this will use the default settings with
        the output file type chosen based on the extension of the file
        name.
        """
        self._check_scene_editor()
        self.scene_editor.save_gl2ps(file_name, exp)

    def get_size(self):
        """Return size of the render window."""
        self._check_scene_editor()
        return self.scene_editor.get_size()

    def set_size(self, size):
        """Set the size of the window."""
        self._check_scene_editor()
        self.scene_editor.set_size(size)

    def _update_view(self, x, y, z, vx, vy, vz):
        """Used internally to set the view.  No-op without a scene
        editor."""
        if self.scene_editor is not None:
            self.scene_editor._update_view(x, y, z, vx, vy, vz)

    def _check_scene_editor(self):
        """Raise SceneModelError unless a scene editor is attached."""
        if self.scene_editor is None:
            msg = """
            This method requires that there be an active scene editor.
            To do this, you will typically need to invoke::
              object.edit_traits()
            where object is the object that contains the SceneModel.
            """
            raise SceneModelError(msg)

    def _scene_editor_changed(self, old, new):
        """Mirror the editor's renderer, render window and interactor (or
        clear them when the editor is removed)."""
        if new is None:
            self._renderer = None
            self._renwin = None
            self._interactor = None
        else:
            self._renderer = new._renderer
            self._renwin = new._renwin
            self._interactor = new._interactor

    def _get_picker(self):
        """Getter for the picker: the editor's picker, or None."""
        se = self.scene_editor
        if se is not None and hasattr(se, 'picker'):
            return se.picker
        return None

    def _get_light_manager(self):
        """Getter for the light manager: the editor's, or None."""
        se = self.scene_editor
        if se is not None:
            return se.light_manager
        return None

    ######################################################################
    # SceneModel API.
    ######################################################################
    def _get_scene(self):
        """Getter for the scene property."""
        return self
예제 #12
0
"""
Traits View definition file.

The view trait of the parent class is extracted from the model definition 
file.  This file can either be exec()ed or imported.  See 
core/base.py:Base.trait_view() for what is currently used.  Using exec() 
allows view changes without needing to restart Mayavi, but is slower than 
importing.
"""
# Authors: Prabhu Ramachandran <*****@*****.**>
#          Vibha Srinivasan <*****@*****.**>
#          Judah De Paula <*****@*****.**>
# Copyright (c) 2005-2008, Enthought, Inc.
# License: BSD Style.
from enthought.traits.ui.api import Item, Group, View

view = View(
    # 'function' plus the custom editor of 'parametric_function'.
    Group(
        Item(name='function'),
        Item(name='parametric_function', style='custom', resizable=True),
        label='Function',
        show_labels=False,
    ),
    # Custom editor of the 'source' object.
    Group(
        Item(name='source', style='custom', resizable=True),
        label='Source',
        show_labels=False,
    ),
    resizable=True,
)
예제 #13
0
파일: gui.py 프로젝트: danginsburg/cmp
class CMPGUI(PipelineConfiguration):
    """ The Graphical User Interface for the CMP (Connectome Mapper).

    Adds Traits UI buttons, grouped views and change handlers on top of
    ``PipelineConfiguration``, plus helpers to save/load the configuration
    state as a pickle without opening the GUI.
    """
    def __init__(self, **kwargs):
        # NOTE: In python 2.6, object.__init__ no longer accepts input
        # arguments.  HasTraits does not define an __init__ and
        # therefore these args were being ignored.
        super(CMPGUI, self).__init__(**kwargs)

    about = Button
    run = Button
    save = Button
    load = Button
    help = Button

    inspect_registration = Button
    inspect_segmentation = Button
    inspect_whitemattermask = Button
    inspect_parcellation = Button
    inspect_reconstruction = Button
    inspect_tractography = Button
    inspect_tractography_filtered = Button
    inspect_fiberfilter = Button
    inspect_connectionmatrix = Button

    main_group = Group(
        VGroup(Item('project_dir',
                    label='Project Directory:',
                    tooltip='Please select the root folder of your project'),
               Item(
                   'generator',
                   label='Generator',
               ),
               Item('diffusion_imaging_model', label='Imaging Modality'),
               label="Project Settings"),
        HGroup(
            VGroup(
                Item('active_dicomconverter',
                     label='DICOM Converter',
                     tooltip="converts DICOM to the Nifti format"),
                Item('active_registration', label='Registration'),
                Item('active_segmentation', label='Segmentation'),
                Item('active_parcellation', label='Parcellation'),
                Item('active_applyregistration', label='Apply registration'),
                Item('active_reconstruction', label='Reconstruction'),
                Item('active_tractography',
                     label='Tractography',
                     tooltip='performs tractography'),
                Item('active_fiberfilter',
                     label='Fiber Filtering',
                     tooltip='applies filtering operation to the fibers'),
                Item('active_connectome',
                     label='Connectome Creation',
                     tooltip='creates the connectivity matrices'),
                # Item('active_statistics', label = 'Statistics'),
                Item('active_cffconverter',
                     label='CFF Converter',
                     tooltip='converts processed files to a connectome file'),
                Item('skip_completed_stages',
                     label='Skip Previously Completed Stages:'),
                label="Stages"),
            VGroup(
                #Item('inspect_rawT1', label = 'Inspect Raw T1', show_label = False),
                #Item('inspect_rawdiff', label = 'Inspect Raw Diffusion', show_label = False),
                Item('inspect_registration',
                     label='Registration',
                     show_label=False),
                Item('inspect_segmentation',
                     label='Segmentation',
                     show_label=False),
                #Item('inspect_whitemattermask', label = 'White Matter Mask', show_label = False),
                Item('inspect_parcellation',
                     label='Parcellation',
                     show_label=False),
                #Item('inspect_reconstruction', label = 'Reconstruction', show_label = False), # DTB_viewer
                Item('inspect_tractography',
                     label='Tractography Original',
                     show_label=False),
                Item('inspect_tractography_filtered',
                     label='Tractography Filtered',
                     show_label=False),
                Item('inspect_connectionmatrix',
                     label='Connection Matrix',
                     show_label=False),
                label="Inspector")
            #VGroup(
            #label="Status",
            #)
        ),
        label="Main",
        show_border=False)

    metadata_group = Group(VGroup(
        Item('creator', label="Creator"),
        Item('email', label="E-Mail"),
        Item('publisher', label="Publisher"),
        Item('created', label="Creation Date"),
        Item('modified', label="Modification Date"),
        Item('license', label="License"),
        Item('rights', label="Rights"),
        Item('reference', label="References"),
        Item('relation', label="Relations"),
        Item('species', label="Species"),
        Item('description', label="Project Description"),
    ),
                           label="Metadata",
                           show_border=False)

    subject_group = Group(VGroup(Item('subject_name', label="Name"),
                                 Item('subject_timepoint', label="Timepoint"),
                                 Item('subject_workingdir',
                                      label="Working Directory"),
                                 Item('subject_metadata',
                                      label='Metadata',
                                      editor=table_editor),
                                 show_border=True),
                          label="Subject")

    dicomconverter_group = Group(VGroup(
        Item('do_convert_diffusion', label="Convert Diffusion data?"),
        Item('subject_raw_glob_diffusion',
             label="Diffusion File Pattern",
             enabled_when='do_convert_diffusion'),
        Item('do_convert_T1', label="Convert T1 data?"),
        Item('subject_raw_glob_T1',
             label="T1 File Pattern",
             enabled_when='do_convert_T1'),
        Item('do_convert_T2', label="Convert T2 data?"),
        Item('subject_raw_glob_T2',
             label="T2 File Pattern",
             enabled_when='do_convert_T2'),
        Item('extract_diffusion_metadata',
             label="Try extracting Diffusion metadata"),
        show_border=True),
                                 visible_when="active_dicomconverter",
                                 label="DICOM Converter")

    registration_group = Group(
        VGroup(Item('registration_mode', label="Registration"),
               VGroup(Item('lin_reg_param', label='FLIRT Parameters'),
                      enabled_when='registration_mode == "Linear"',
                      label="Linear Registration"),
               VGroup(Item('nlin_reg_bet_T2_param', label="BET T2 Parameters"),
                      Item('nlin_reg_bet_b0_param', label="BET b0 Parameters"),
                      Item('nlin_reg_fnirt_param', label="FNIRT Parameters"),
                      enabled_when='registration_mode == "Nonlinear"',
                      label="Nonlinear Registration"),
               show_border=True,
               enabled_when="active_registration"),
        visible_when="active_registration",
        label="Registration",
    )

    parcellation_group = Group(
        VGroup(
            Item('parcellation_scheme', label="Parcellation Scheme"),
            #               VGroup(
            #                      Item('custompar_nrroi', label="Number of ROI"),
            #                      Item('custompar_nodeinfo', label="Node Information (GraphML)"),
            #                      Item('custompar_volumeparcell', label="Volumetric parcellation"),
            #                      enabled_when = 'parcellation_scheme == "custom"',
            #                      label = "Custom Parcellation"
            #               ),
            #               show_border = True,
            #               enabled_when = "active_registration"
        ),
        visible_when="active_parcellation",
        label="Parcellation",
    )

    reconstruction_group = Group(
        VGroup(Item('nr_of_gradient_directions',
                    label="Number of Gradient Directions"),
               Item('nr_of_sampling_directions',
                    label="Number of Sampling Directions"),
               Item('nr_of_b0', label="Number of b0 volumes"),
               Item('odf_recon_param', label="odf_recon Parameters"),
               Item('dtb_dtk2dir_param', label="DTB_dtk2dir Parameters"),
               show_border=True,
               visible_when="diffusion_imaging_model == 'DSI'"),
        VGroup(Item('gradient_table', label="Gradient Table"),
               Item('gradient_table_file', label="Gradient Table File"),
               Item('nr_of_b0', label="Number of b0 volumes"),
               Item('max_b0_val', label="Maximum b value"),
               Item('dti_recon_param', label="dti_recon Parameters"),
               Item('dtb_dtk2dir_param', label="DTB_dtk2dir Parameters"),
               show_border=True,
               visible_when="diffusion_imaging_model == 'DTI'"),
        VGroup(
            Item('gradient_table', label="Gradient Table"),
            Item('gradient_table_file', label="Gradient Table File"),
            Item('nr_of_gradient_directions',
                 label="Number of Gradient Directions"),
            Item('nr_of_sampling_directions',
                 label="Number of Sampling Directions"),
            Item('nr_of_b0', label="Number of b0 volumes"),
            #Item('max_b0_val', label="Maximumb b value"),
            Item('hardi_recon_param', label="odf_recon Parameters"),
            Item('dtb_dtk2dir_param', label="DTB_dtk2dir Parameters"),
            show_border=True,
            visible_when="diffusion_imaging_model == 'QBALL'"),
        visible_when="active_reconstruction",
        label="Reconstruction",
    )

    segementation_group = Group(
        VGroup(
            Item('recon_all_param', label="recon_all Parameters"),
            show_border=True,
        ),
        enabled_when="active_segmentation",
        visible_when="active_segmentation",
        label="Segmentation",
    )

    tractography_group = Group(
        VGroup(
            Item('streamline_param', label="DTB_streamline Parameters"),
            show_border=True,
            visible_when=
            "diffusion_imaging_model == 'DSI' or diffusion_imaging_model == 'QBALL'",
        ),
        VGroup(Item('streamline_param_dti', label="dti_tracker Parameters"),
               show_border=True,
               visible_when="diffusion_imaging_model == 'DTI'"),
        enabled_when="active_tractography",
        visible_when="active_tractography",
        label="Tractography",
    )

    fiberfilter_group = Group(
        VGroup(Item('apply_splinefilter', label="Apply spline filter"),
               Item('apply_fiberlength', label="Apply cutoff filter"),
               Item('fiber_cutoff_lower',
                    label='Lower cutoff length (mm)',
                    enabled_when='apply_fiberlength'),
               Item('fiber_cutoff_upper',
                    label='Upper cutoff length (mm)',
                    enabled_when='apply_fiberlength'),
               show_border=True,
               enabled_when="active_fiberfilter"),
        visible_when="active_fiberfilter",
        label="Fiber Filtering",
    )

    connectioncreation_group = Group(
        VGroup(Item('compute_curvature', label="Compute curvature"),
               show_border=True,
               enabled_when="active_connectome"),
        visible_when="active_connectome",
        label="Connectome Creation",
    )

    cffconverter_group = Group(
        VGroup(
            Item('cff_fullnetworkpickle', label="All connectomes"),
            # Item('cff_cmatpickle', label='cmat.pickle'),
            Item('cff_originalfibers', label="Original Tractography"),
            Item('cff_filteredfibers', label="Filtered Tractography"),
            Item('cff_fiberarr', label="Filtered fiber arrays"),
            Item('cff_finalfiberlabels',
                 label="Final Tractography and Labels"),
            Item('cff_scalars', label="Scalar maps"),
            Item('cff_rawdiffusion', label="Raw Diffusion Data"),
            Item('cff_rawT1', label="Raw T1 data"),
            Item('cff_rawT2', label="Raw T2 data"),
            Item('cff_roisegmentation', label="Parcellation Volumes"),
            Item('cff_surfaces',
                 label="Surfaces",
                 tooltip='stores individually generated surfaces'),
            #Item('cff_surfacelabels', label="Surface labels", tooltip = 'stores the labels on the surfaces'),
            show_border=True,
        ),
        visible_when="active_cffconverter",
        label="CFF Converter",
    )

    configuration_group = Group(
        VGroup(
            Item('emailnotify', label='E-Mail Notification'),
            #Item('wm_handling', label='White Matter Mask Handling', tooltip = """1: run through the freesurfer step without stopping
            #2: prepare whitematter mask for correction (store it in subject dir/NIFTI
            #3: rerun freesurfer part with corrected white matter mask"""),
            Item('freesurfer_home', label="Freesurfer Home"),
            Item('fsl_home', label="FSL Home"),
            Item('dtk_home', label="DTK Home"),
            show_border=True,
        ),
        label="Configuration",
    )

    view = View(
        Group(
            HGroup(main_group,
                   metadata_group,
                   subject_group,
                   dicomconverter_group,
                   registration_group,
                   segementation_group,
                   parcellation_group,
                   reconstruction_group,
                   tractography_group,
                   fiberfilter_group,
                   connectioncreation_group,
                   cffconverter_group,
                   configuration_group,
                   orientation='horizontal',
                   layout='tabbed',
                   springy=True),
            spring,
            HGroup(
                Item('about', label='About', show_label=False),
                Item('help', label='Help', show_label=False),
                Item('save', label='Save State', show_label=False),
                Item('load', label='Load State', show_label=False),
                spring,
                Item('run', label='Map Connectome!', show_label=False),
            ),
        ),
        resizable=True,
        width=0.3,
        handler=CMPGUIHandler,
        title='Connectome Mapper',
    )

    def _about_fired(self):
        """Show the help dialog modally."""
        a = HelpDialog()
        a.configure_traits(kind='livemodal')

    def _help_fired(self):
        """Show the help dialog modally."""
        a = HelpDialog()
        a.configure_traits(kind='livemodal')

    def load_state(self, cmpconfigfile):
        """ Load CMP Configuration state directly.
        Useful if you do not want to invoke the GUI.

        Parameters
        ----------
        cmpconfigfile : string
            Path of a configuration state written by ``save_state``.

        After loading, ``dtk_matrices`` is recomputed from ``dtk_home`` and,
        if ``project_dir`` exists, ``subject_workingdir`` is rebuilt from
        the project directory, subject name and timepoint.
        """
        import enthought.sweet_pickle as sp
        # NOTE(review): unpickling can execute arbitrary code; only load
        # configuration files from trusted sources.
        with open(cmpconfigfile, 'rb') as infile:
            data = sp.load(infile)
        self.__setstate__(data.__getstate__())
        # make sure that dtk_matrices is set
        self.dtk_matrices = os.path.join(self.dtk_home, 'matrices')
        # update the subject directory
        if os.path.exists(self.project_dir):
            self.subject_workingdir = os.path.join(self.project_dir,
                                                   self.subject_name,
                                                   self.subject_timepoint)

    def save_state(self, cmpconfigfile):
        """ Save CMP Configuration state directly.
        Useful if you do not want to invoke the GUI

        Parameters
        ----------
        cmpconfigfile : string
            Absolute path and filename to store the CMP configuration
            pickled object.  Missing parent directories are created; a
            bare filename is saved in the current working directory.

        """
        # Resolve to an absolute path first so a bare filename yields the
        # cwd (which exists) instead of '' (which os.makedirs rejects).
        config_dir = os.path.dirname(os.path.abspath(cmpconfigfile))
        if not os.path.exists(config_dir):
            os.makedirs(config_dir)

        import enthought.sweet_pickle as sp
        # Pickle a copy of this configuration rather than self, using the
        # highest protocol available (-1).  The copy is built before the
        # file is opened so a failure does not leave a truncated file.
        tmpconf = CMPGUI()
        tmpconf.copy_traits(self)
        with open(cmpconfigfile, 'wb') as outfile:
            sp.dump(tmpconf, outfile, -1)

    def show(self):
        """ Show the GUI """
        #self.configure_traits()
        self.edit_traits(kind='livemodal')

#    def _gradient_table_file_default(self):
#        return self.get_gradient_table_file()

# XXX this is not automatically invoked!

    def _get_gradient_table_file(self):
        """Return the path of the selected gradient table file.

        Raises Exception if that file does not exist.
        """

        if self.gradient_table == 'custom':
            gradfile = self.get_custom_gradient_table()
        else:
            gradfile = self.get_cmp_gradient_table(self.gradient_table)

        if not os.path.exists(gradfile):
            msg = 'Selected gradient table %s does not exist!' % gradfile
            raise Exception(msg)

        return gradfile

    def _project_dir_changed(self, value):
        """Reset the working directory to the new project directory."""
        self.subject_workingdir = value

    def _subject_name_changed(self, value):
        """Rebuild the working directory as project/subject/timepoint."""
        self.subject_workingdir = os.path.join(self.project_dir, value,
                                               self.subject_timepoint)

    def _subject_timepoint_changed(self, value):
        """Rebuild the working directory as project/subject/timepoint."""
        self.subject_workingdir = os.path.join(self.project_dir,
                                               self.subject_name, value)

    def _gradient_table_changed(self, value):
        """Point ``gradient_table_file`` at the chosen table; raise
        Exception if that file does not exist."""
        if value == 'custom':
            self.gradient_table_file = self.get_custom_gradient_table()
        else:
            self.gradient_table_file = self.get_cmp_gradient_table(value)

        if not os.path.exists(self.gradient_table_file):
            msg = 'Selected gradient table %s does not exist!' % self.gradient_table_file
            raise Exception(msg)

    def _parcellation_scheme_changed(self, value):
        """Load the Lausanne2008 parcellation, or NativeFreesurfer for any
        other scheme."""
        if value == "Lausanne2008":
            self.parcellation = self._get_lausanne_parcellation(
                parcel="Lausanne2008")
        else:
            self.parcellation = self._get_lausanne_parcellation(
                parcel="NativeFreesurfer")

    def _inspect_registration_fired(self):
        cmp.registration.inspect(self)

    def _inspect_tractography_fired(self):
        cmp.tractography.inspect(self)

    def _inspect_tractography_filtered_fired(self):
        cmp.fiberfilter.inspect(self)

    def _inspect_segmentation_fired(self):
        cmp.freesurfer.inspect(self)

    def _inspect_parcellation_fired(self):
        cmp.maskcreation.inspect(self)

    def _inspect_connectionmatrix_fired(self):
        cmp.connectionmatrix.inspect(self)

    def _run_fired(self):
        pass
        # execute the pipeline thread

        # first do a consistency check
        #self.consistency_check()

        # otherwise store the pickle
        #self.save_state(os.path.join(self.get_log(), self.get_logname(suffix = '.pkl')) )

        # hide the gui
        # run the pipeline
        #print "mapit"
        #cmp.connectome.mapit(self)
        # show the gui

        #cmpthread = CMPThread(self)
        #cmpthread.start()

    def _load_fired(self):
        """Ask the user for a configuration state file and load it."""
        from enthought.pyface.api import FileDialog, OK

        wildcard = ("CMP Configuration State (*.pkl)|*.pkl|"
                    "All files (*.*)|*.*")
        dlg = FileDialog(wildcard=wildcard,
                         title="Select a configuration state to load",
                         resizeable=False,
                         default_directory=self.project_dir)

        # Silently ignore a cancelled dialog or a path that is not a file.
        if dlg.open() == OK and os.path.isfile(dlg.path):
            self.load_state(dlg.path)

    def _save_fired(self):
        """Ask the user for a target file and save the configuration state,
        appending a '.pkl' extension when missing."""
        from enthought.pyface.api import FileDialog, OK

        wildcard = ("CMP Configuration State (*.pkl)|*.pkl|"
                    "All files (*.*)|*.*")
        dlg = FileDialog(wildcard=wildcard,
                         title="Filename to store configuration state",
                         resizeable=False,
                         action='save as',
                         default_directory=self.subject_workingdir)

        if dlg.open() == OK:
            path = dlg.path
            if not path.endswith('.pkl'):
                path = path + '.pkl'
            self.save_state(path)
예제 #14
0
class FitProcessor(HasTraits):
    """ A traits based class for simplifying multidimensional nonlinear
        least squares function fitting.
    """
    fit_data = Instance(FitData)  #holds the data
    fit_model = Instance(
        FitModel
    )  #holds the data selection, fit parameters, and functional model

    error_func = Function  #automatically generated error function
    optimizer = Instance(
        NLSOptimizer
    )  #optimizer for error function, obtains best fit parameters
    iter_num = Int(0)  #keep track of the optimizer iterations
    fit_log = Str("")  #stores the info from fitting in a YAML format

    view = View(Item('optimizer', label='Optimizer', style='custom'),
                resizable=True,
                height=0.75,
                width=0.25)

    #--------------------------------------------------------------------------
    # NOTE: the @on_trait_change('fit_model') hook is disabled, so this must
    # presumably be called explicitly before fit().
    def update_error_func(self):
        """Build ``self.error_func`` for the current model and data selection.

        The free parameter names, the parameter dictionary and the selected
        data ``(X, Y, W)`` are captured once. The resulting closure maps a
        vector ``p`` of free parameter values to the lumped vector of weighted
        residuals ``(y - f) * w`` over all data rows.
        """
        fp_names = self.fit_model.get_free_param_names()
        pdict = self.fit_model.get_params_dict()
        X, Y, W = self.fit_data.get_selection()
        func = self.fit_model.evaluate

        def varied_func(p):
            # overwrite the free parameters in the captured dict; fixed
            # parameters keep their values
            pdict.update(dict(zip(fp_names, p)))
            return func(X=X, pdict=pdict)

        def error_func(p):
            F = varied_func(p)
            # pair each data row with its function evaluation and weighting,
            # then lump all deviations into a single row vector
            return hstack([(y - f) * w for y, f, w in zip(Y, F, W)])

        self.error_func = error_func

    #--------------------------------------------------------------------------
    def fit(self):
        """Run the optimizer on the error function to obtain best fit parameters.

        On success the free parameters of ``fit_model`` are updated with the
        optimized values and, when the optimizer provides a covariance
        matrix, with their errors (square root of the diagonal of the
        covariance rescaled by the reduced chi-square). A YAML-like report is
        written to ``fit_log``. Nothing is fitted when there are no free
        parameters.
        """
        error_func = self.error_func
        P0 = self.fit_model.get_free_param_values()
        self._clear_log()  #empty the fitting log
        self._print_log("## Starting Fit ##")
        if len(P0) == 0:
            # the free parameter set cannot be empty for fitting
            return
        self.iter_num = 0  #reset iteration counter
        self.optimizer = NLSOptimizer(cost_map=error_func, P0=P0)
        self.optimizer.optimize()
        #determine if fitting was successful
        success = self.optimizer.success
        msg = self.optimizer.message
        if success:
            #update the parameters
            fp_values = self.optimizer.P
            fp_names = self.fit_model.get_free_param_names()
            for name, value in zip(fp_names, fp_values):
                self.fit_model.update_param(name, value=value)
            #compute the error on the parameters, if possible
            err = self.optimizer.cost
            ndf = self.optimizer.ndf
            reduced_chisqr = (err * err).sum() / ndf
            covar = self.optimizer.covar
            if covar is not Undefined:
                # rescale a copy: scaling in place would silently alter the
                # optimizer's own covariance matrix
                covar = covar * reduced_chisqr
                p_err = sqrt(covar.diagonal())
                for name, error in zip(fp_names, p_err):
                    self.fit_model.update_param(name, error=error)
            self._print_log("## Fitting Completed ##")
            self._print_log("---")
            self._print_log("parameters:")
            self._print_log(self.fit_model.params, level=2, indent='  ')
            self._print_log("ndf: %d" % ndf)
            self._print_log("reduced_chisqr: %g" % reduced_chisqr)
        else:
            self._print_log("## Fitting Failed! ##")
            self._print_log("---")
        self._print_log("ierr: %s" % self.optimizer.ier)
        self._print_log("message: %s" % msg)
        self._print_log("...")

    def _clear_log(self):
        """Empty the fitting log."""
        self.fit_log = ""

    def _print_log(self, text, indent="\t", level=0, newline='\n'):
        """Append ``str(text)`` plus ``newline`` to the fitting log.

        When ``level >= 1`` every line of the text is prefixed with
        ``indent * level``.
        """
        text = str(text)
        if level >= 1:  #reformat the text to indent it
            space = indent * level
            text = newline.join(space + line for line in text.split(newline))
        self.fit_log += text + newline
예제 #15
0
class CheckListTest(Handler):
    """ Handler whose 'value' check list offers either the color names or the
        numbers, depending on the current 'case'.
    """

    #-- Trait definitions ------------------------------------------------------

    # Which set of choices the check list should offer:
    case = Enum('Colors', 'Numbers')

    # The checked items; the editor starts out offering the color names:
    value = List(editor=CheckListEditor(values=colors, cols=5))

    #-- Event handlers ---------------------------------------------------------

    def object_case_changed(self, info):
        """ Swap the candidate values of the check list editor when 'case'
            changes (``info.value`` is presumably the editor of the item with
            id 'value' — confirm against the view).
        """
        choices = colors if self.case == 'Colors' else numbers
        info.value.factory.values = choices


#-------------------------------------------------------------------------------
#  Run the tests:
#-------------------------------------------------------------------------------

if __name__ == '__main__':
    clt = CheckListTest()
    # Show the check list twice, first from the plain 'value' item spec and
    # then from the 'value@' spec, reporting the selection after each dialog.
    for value_spec in ('value', 'value@'):
        clt.configure_traits(view=View('case', '_', Item(value_spec, id='value')))
        print('value: %s' % (clt.value,))
예제 #16
0
class Figure(HasTraits):
    """ Shows an image in a matplotlib figure and reports rectangle
        selections made on it through 'process_selection'.
    """
    figure = Instance(MPL_Figure, transient=True)
    # Called as process_selection(pos1, pos2) with the (xdata, ydata) of the
    # press and release events of a rectangle selection; a no-op by default.
    process_selection = Function
    view = View(Item('figure',
                     editor=MPLFigureEditor(),
                     width=400,
                     show_label=False,
                     height=300),
                resizable=True)

    def _process_selection_default(self):
        def f(point0, point1):
            pass

        return f

    def _figure_default(self):
        """ Create the figure, initialised with a blank 300x400 image. """
        figure = MPL_Figure()
        # update_image() works through self.figure, so publish it first
        self.figure = figure
        image = zeros(shape=(300, 400), dtype='uint8')
        self.update_image(image)
        return figure

    def append_selector(self):
        """ Attach a box-style RectangleSelector to the figure's subplot that
            forwards each completed selection to 'process_selection'.
        """
        def line_select_callback(event1, event2):
            'event1 and event2 are the press and release events'
            pos1 = event1.xdata, event1.ydata
            pos2 = event2.xdata, event2.ydata
            self.process_selection(pos1, pos2)

        ax = self.figure.add_subplot(111)
        # Keep a reference: a matplotlib widget stops responding once the
        # widget object itself has been garbage collected.
        self._selector = RectangleSelector(ax,
                                           line_select_callback,
                                           drawtype='box',
                                           useblit=True,
                                           minspanx=0,
                                           minspany=0,
                                           spancoords='pixels')

    def update_image(self, data):
        """ Replace the displayed image(s) with 'data' and redraw.

            Best effort: if 'data' cannot be shown the axes are left without
            an image instead of raising.
        """
        ax = self.figure.add_subplot(111)
        ax.set_autoscale_on(True)
        ax.images = []
        try:
            ax.imshow(data, interpolation='nearest')
            self.append_selector()
        except Exception:
            pass
        finally:
            ax.set_autoscale_on(False)
            self._redraw()

    def plot_data(self, x, y, name='data 0', color='black'):
        """ Plot y against x and label the curve with 'name' at its first point. """
        ax = self.figure.add_subplot(111)
        ax.plot(x, y, color)
        ax.text(x[0], y[0], name, color=color)
        self._redraw()

    def del_plot(self, name):
        """ Remove plotted curves; only name == 'all' (every line and text)
            is currently supported, other names are ignored.
        """
        if name == 'all':
            ax = self.figure.add_subplot(111)
            ax.lines = []
            ax.texts = []
            # make the removal visible, like plot_data() does for additions
            self._redraw()

    def _redraw(self):
        """ Redraw the canvas if the figure is attached to one yet. """
        if self.figure.canvas is not None:
            self.figure.canvas.draw()
예제 #17
0
def _checklist_group(trait_name, label):
    """ Return a tab group showing 'trait_name' in the simple, custom, text
        and read-only CheckListEditor styles, separated by rules.
    """
    return Group(Item(trait_name, style='simple', label='Simple'),
                 Item('_'),
                 Item(trait_name, style='custom', label='Custom'),
                 Item('_'),
                 Item(trait_name, style='text', label='Text'),
                 Item('_'),
                 Item(trait_name, style='readonly', label='ReadOnly'),
                 label=label)


class CheckListEditorDemo(HasTraits):
    """ Define the main CheckListEditor demo class. """

    # One trait per column layout, each choosing from the same four words:
    checklist_4col = List(
        editor=CheckListEditor(values=['one', 'two', 'three', 'four'], cols=4))

    checklist_2col = List(
        editor=CheckListEditor(values=['one', 'two', 'three', 'four'], cols=2))

    checklist_1col = List(
        editor=CheckListEditor(values=['one', 'two', 'three', 'four'], cols=1))

    # One group per layout, each showing all four editor styles:
    cl_4_group = _checklist_group('checklist_4col', '4-column')
    cl_2_group = _checklist_group('checklist_2col', '2-column')
    cl_1_group = _checklist_group('checklist_1col', '1-column')

    # Each group becomes its own tabbed panel in the view.
    view1 = View(cl_4_group,
                 cl_2_group,
                 cl_1_group,
                 title='CheckListEditor',
                 buttons=['OK'],
                 resizable=True)
예제 #18
0
class BasePlot(HasTraits):
    """ An interface defining an object which can render a plot on a
        figure object
    """
    implements(IPlot)
    n_x = Int(1)  # expected number of independent variables per data set
    n_y = Int(1)  # expected number of dependent variables per data set
    figure = Instance(Figure, ())
    view = View(
        Item(
            'figure',
            style='custom',
            show_label=False,
            # this editor will automatically find and connect the
            # _handle_onpick method for handling matplotlib's object
            # picking events
            editor=MPLFigureEditor(),
        ), )

    def clear(self):
        """Remove everything drawn on the figure."""
        self.figure.clear()

    def render(self, Xs, Ys, fmts=None, labels=None, pickable=None, **kwargs):
        '''Plot data from 'Xs', 'Ys' on 'figure' and return the figure object.

        Xs, Ys   -- data sets, normalised by _convert_data(); the i-th X is
                    plotted against the i-th Y, and sets lacking either one
                    are skipped
        fmts     -- optional matplotlib format string for each data set
        labels   -- optional legend label for each data set; a legend is
                    drawn when any labels are given
        pickable -- indices into the axes' lines that should respond to
                    pick events (picker tolerance 5.0)
        kwargs   -- forwarded to Axes.plot()
        '''
        try:
            from itertools import zip_longest
        except ImportError:  # Python 2
            from itertools import izip_longest as zip_longest

        data = self._convert_data(Xs, Ys)
        Xs = data['Xs']
        Ys = data['Ys']
        if fmts is None:
            fmts = []
        if labels is None:
            labels = []
        if pickable is None:
            pickable = []
        axes = self.figure.add_subplot(111)
        # zip_longest pads the shorter sequences with None, exactly like the
        # Python 2 only idiom map(None, ...)
        for X, Y, fmt, label in zip_longest(Xs, Ys, fmts, labels):
            if X is not None and Y is not None:
                kwargs['label'] = label
                self._plot(X, Y, fmt, axes=axes, **kwargs)
        if labels:
            axes.legend()
        #set up the plot point object selection
        for ind in pickable:
            axes.lines[ind].set_picker(5.0)
        return self.figure

    def redraw(self):
        """Redraw the canvas, if the figure is attached to one."""
        if self.figure.canvas is not None:
            self.figure.canvas.draw()

    def register_onpick_handler(self, handler):
        """Install 'handler' as this plot's _handle_onpick callback."""
        self._handle_onpick = handler

    def _plot(self, x, y, fmt=None, axes=None, **kwargs):
        """Plot one data set on 'axes' (required), with 'fmt' if given."""
        if axes is None:
            raise TypeError("an 'axes' object must be supplied")
        if fmt is None:
            axes.plot(x, y, **kwargs)
        else:
            axes.plot(x, y, fmt, **kwargs)

    def _convert_data(self, Xs=None, Ys=None):
        """Normalise 'Xs' and 'Ys' into numpy arrays indexed by data set.

        For each of Xs (checked against n_x) and Ys (checked against n_y):
          - None               -> empty array
          - 1D, n == 1         -> shape (1, len)
          - 2D, n == 1         -> unchanged (each row is a data set)
          - 2D (n, m), n > 1   -> shape (1, n, m)
          - 3D (k, 1, m), n==1 -> shape (k, m)
          - 3D (k, n, m)       -> unchanged
        Returns a dict with keys 'Xs' and 'Ys'.
        Raises TypeError for data whose shape does not match n.
        """
        data_args = {'Xs': (Xs, self.n_x), 'Ys': (Ys, self.n_y)}
        data = {}
        for name, args in data_args.items():
            D, n = args  #data array, expected number of variables
            if D is None:
                #default to an empty array
                data[name] = array([])
                continue
            D = array(D)  #convert to a numpy array
            dim = len(D.shape)
            if dim == 1:
                if n == 1:
                    #upconvert 1D array to 2D
                    D = D.reshape((1, -1))
                else:
                    raise TypeError(
                        "'%s' dimension must be 2 or 3 for n > 1, detected incommensurate data of dimension %d"
                        % (name, dim))
            elif dim == 2:
                d1, d2 = D.shape
                if n == 1:
                    pass  #no conversion needed
                elif not (d1 == n):
                    raise TypeError("'%s' shape (%d,%d) must match (n=%d,:)" %
                                    (name, d1, d2, n))
                else:
                    #up convert 2D array to 3D
                    D = D.reshape((1, d1, d2))
            elif dim == 3:
                d1, d2, d3 = D.shape
                if n == 1 and d2 == 1:
                    #down convert 3D array to 2D
                    D = D.reshape((d1, d3))
                elif not (d2 == n):
                    raise TypeError(
                        "'%s' shape (%d,%d,%d) must match (:,n=%d,:)" %
                        (name, d1, d2, d3, n))
            else:
                raise TypeError(
                    "'%s' dimension must be 1, 2 or 3, detected incommensurate data of dimension %d"
                    % (name, dim))
            data[name] = D
        return data
예제 #19
0
from enthought.traits.api import HasTraits, Str, Int, Bool
from enthought.traits.ui.api import View, Group, Item

#--[Code]-----------------------------------------------------------------------


# Sample class
class House(HasTraits):
    """ A house listing: address, bedroom count, pool flag and price. """
    address = Str()   # street address
    bedrooms = Int()  # number of bedrooms
    pool = Bool()     # whether there is a pool
    price = Int()     # price (units not specified)


def _house_group(prefix):
    """ Return a bordered group with the House traits of context object 'prefix'. """
    return Group(Item(prefix + '.address', resizable=True),
                 Item(prefix + '.bedrooms'),
                 Item(prefix + '.pool'),
                 Item(prefix + '.price'),
                 show_border=True)


# View comparing two 'House' objects side by side; it expects the UI context
# to provide them under the names 'h1' and 'h2'.
comp_view = View(Group(_house_group('h1'),
                       _house_group('h2'),
                       orientation='horizontal'),
                 title='House Comparison')

# A pair of houses to demonstrate the View
house1 = House(address='4743 Dudley Lane',
               bedrooms=3,
예제 #20
0
class ScatterPlotNM(MutableTemplate):
    """ A mutable template showing a rows x columns grid of ScatterPlot
        templates that share their marker options and one index source.
    """

    #-- Template Traits --------------------------------------------------------

    # Plot title:
    title = TStr('NxM Scatter Plots')

    # Marker type (a mapped trait keyed by strings):
    marker = marker_trait(template='copy', event='update')

    # Marker size in pixels, outline thickness excluded:
    marker_size = TRange(1, 5, 1, event='update')

    # Outline thickness in pixels; 0 means no outline:
    line_width = TRange(0.0, 5.0, 1.0)

    # Marker fill color:
    color = TColor('red', event='update')

    # Marker outline color:
    outline_color = TColor('black', event='update')

    # Number of plot rows:
    rows = TRange(1, 3, 1, event='grid')

    # Number of plot columns:
    columns = TRange(1, 5, 1, event='grid')

    # The scatter plot templates making up the grid (row-major order):
    scatter_plots = TList(ScatterPlot)

    #-- Derived Traits ---------------------------------------------------------

    plot = TDerived

    #-- Traits UI Views --------------------------------------------------------

    # Main view: title banner above the plot grid.
    template_view = View(VGroup(
        Item('title',
             show_label=False,
             style='readonly',
             editor=ThemedTextEditor(theme=Theme('@GBB', alignment='center'))),
        Item('plot',
             show_label=False,
             resizable=True,
             editor=EnableEditor(),
             item_theme=Theme('@GF5', margins=0))),
                         resizable=True)

    # Options view: controls for the shared plot options and grid size.
    options_view = View(
        VGroup(
            VGroup(Label('Scatter Plot Options',
                         item_theme=Theme('@GBB', alignment='center')),
                   show_labels=False),
            VGroup(Item('title', editor=TextEditor()),
                   Item('marker'),
                   Item('marker_size', editor=ThemedSliderEditor()),
                   Item('line_width',
                        label='Line Width',
                        editor=ThemedSliderEditor()),
                   Item('color', label='Fill Color'),
                   Item('outline_color', label='Outline Color'),
                   Item('rows', editor=ThemedSliderEditor()),
                   Item('columns', editor=ThemedSliderEditor()),
                   group_theme=Theme('@GF5', margins=(-5, -1)),
                   item_theme=Theme('@G0B', margins=0))))

    #-- ITemplate Interface Implementation -------------------------------------

    def activate_template(self):
        """ Converts all contained 'TDerived' objects to real objects using the
            template traits of the object. This method must be overridden in
            subclasses.

            Here 'plot' becomes a GridPlotContainer filled row by row with the
            'plot' of each contained scatter plot; an empty PlotComponent is
            used wherever that plot is None.

            Returns
            -------
            None
        """
        grid = []
        flat_index = 0
        for _row in range(self.rows):
            cells = []
            for _col in range(self.columns):
                component = self.scatter_plots[flat_index].plot
                if component is None:
                    component = PlotComponent()
                cells.append(component)
                flat_index += 1
            grid.append(cells)

        self.plot = GridPlotContainer(shape=(self.rows, self.columns))
        self.plot.component_grid = grid

    #-- Default Values ---------------------------------------------------------

    def _scatter_plots_default(self):
        """ Build one ScatterPlot per grid cell, with data sources prepared. """
        fresh = [ScatterPlot() for _ in range(self.rows * self.columns)]
        self._update_plots(fresh)
        return fresh

    #-- Trait Event Handlers ---------------------------------------------------

    def _update_changed(self, name, old, new):
        """ Push a changed plot option to every scatter plot and invalidate
            the derived 'plot'.
        """
        for scatter in self.scatter_plots:
            setattr(scatter, name, new)

        self.plot = Undefined

    def _grid_changed(self):
        """ Resize the scatter plot list to rows x columns and re-label it. """
        wanted = self.rows * self.columns
        current = self.scatter_plots
        if wanted < len(current):
            self.scatter_plots = current[:wanted]
        else:
            # grow the existing list in place, one plot at a time
            for _ in range(len(current), wanted):
                current.append(ScatterPlot())

        # only the first rows*columns entries are touched, so passing the
        # untrimmed list is fine
        self._update_plots(current)

        self.template_mutated = True

    #-- Private Methods --------------------------------------------------------

    def _update_plots(self, plots):
        """ Tag each plot's value source with its '[row,col]' position and
            make every plot share the first plot's index source.
        """
        shared_index = None
        cursor = 0
        for r in range(self.rows):
            for c in range(self.columns):
                sp = plots[cursor]
                cursor += 1
                # drop any previous '[r,c]' suffix before adding the new one
                base = sp.value.description.rsplit('[', 1)[0]
                sp.value.description = '%s[%d,%d]' % (base, r, c)
                sp.value.optional = True

                if shared_index is not None:
                    sp.index = shared_index
                    continue
                shared_index = sp.index
                shared_index.description = 'Shared Plot Index'
                shared_index.optional = True
예제 #21
0
파일: r50.py 프로젝트: liq07lzucn/radpy
    def get_item(self):
        """Return an Item displaying this trait's value as centimetres (2 decimals)."""
        def as_centimetres(v):
            return '%.2f cm' % v

        return Item(name=self.name, style=self.style,
                    visible_when=self.visible_when,
                    format_func=as_centimetres)
예제 #22
0
class Graph(HasTraits):
    """
    Plotting component: data selection controls on the left, plot on the right.
    """
    name = Str # plot name, shown in the tab title and the plot title
    data_source = Instance(DataSource) # data source holding the data
    figure = Instance(Figure) # Figure object driving the plot control
    selected_xaxis = Str # name of the data used for the X axis
    selected_items = List # names of the data plotted on the Y axis

    clear_button = Button(u"清除") # quickly clears all selected Y data

    view = View(
        HSplit( # left and right areas with a draggable splitter in between
            # left side: a group of selection controls
            VGroup(
                Item("name"),   # plot name editor
                Item("clear_button"), # clear button
                Heading(u"X轴数据"),  # static text
                # X axis selector: an EnumEditor (combo box) whose choices
                # come from data_source.names
                Item("selected_xaxis", editor=
                    EnumEditor(name="object.data_source.names", format_str=u"%s")),
                Heading(u"Y轴数据"), # static text
                # Y axis selector: several series may be chosen, so use a
                # check box list laid out in two columns
                Item("selected_items", style="custom",
                     editor=CheckListEditor(name="object.data_source.names",
                            cols=2, format_str=u"%s")),
                show_border = True, # draw the group border
                scrollable = True,  # scroll when the group has too many controls
                show_labels = False # no labels for any item in the group
            ),
            # right side: the plot control
            Item("figure", editor=MPLFigureEditor(), show_label=False, width=600)
        )
    )

    def _name_changed(self):
        """
        Update the plot title when the plot name changes.
        """
        axe = self.figure.axes[0]
        axe.set_title(self.name)
        self.figure.canvas.draw()

    def _clear_button_fired(self):
        """
        Clear button handler: deselect all Y data and redraw.
        """
        self.selected_items = []
        self.update()

    def _figure_default(self):
        """
        Default for the figure trait: a Figure with a single axes.
        """
        figure = Figure()
        figure.add_axes([0.1, 0.1, 0.85, 0.80]) # plotting area with margins on every side
        return figure

    def _selected_items_changed(self):
        """
        Redraw when the Y data selection changes.
        """
        self.update()

    def _selected_xaxis_changed(self):
        """
        Redraw when the X data selection changes.
        """
        self.update()

    def update(self):
        """
        Redraw all selected curves against the selected X data.

        If the X data cannot be looked up (e.g. no data source yet or no
        valid X column selected), the axes are left empty.
        """
        axe = self.figure.axes[0]
        axe.clear()
        try:
            xdata = self.data_source.data[self.selected_xaxis]
        except Exception:
            # still show the cleared axes instead of the stale curves
            self.figure.canvas.draw()
            return
        for field in self.selected_items:
            axe.plot(xdata, self.data_source.data[field], label=field)
        axe.set_xlabel(self.selected_xaxis)
        axe.set_title(self.name)
        if self.selected_items:
            # a legend without labelled curves only produces a warning
            axe.legend()
        self.figure.canvas.draw()
예제 #23
0
        except:
            return 1.

    return cmp(getNum(x), getNum(y))


# Default view for a filenames list: an editable list of file names, a
# reverse-order toggle and a button to fill the list from a directory.
filenames_view = View(
    Group(
        Item('filenames',
             editor=ListStrEditor(selected='selected',
                                  operations=['insert', 'edit', 'move',
                                              'delete', 'append'],
                                  auto_add=True,
                                  drag_move=True),
             style='custom',
             width=-300,
             height=-100),
        Item('is_reversed', style='simple'),
    ),
    Item('from_directory_bttn', show_label=False),
    resizable=True,
)


def filenames_from_list(filenames):
    """A helper function. Returns a :class:`Filenames` object from a given filenames list
    
class FieldExplorer(HasTraits):
    """ Scene for probing the field of a WireLoop with a movable plane.

        The field computed by calc_wire_B_field at the plane widget's points
        is drawn as cone glyphs, coloured through the LUT manager with the
        colour range fitted to the field magnitude.
    """
    scene = Instance(SceneModel, ())
    wire = Instance(WireLoop)

    interact = Bool(False)  # whether the probe plane widget is switched on
    ipl = Instance(tvtk.PlaneWidget, (), {
        'resolution': 50,
        'normal': [1., 0., 0.]
    })
    calc_B = Instance(tvtk.ProgrammableFilter, ())

    glyph = Instance(tvtk.Glyph3D, (), {'scale_factor': 0.02})
    scale_factor = DelegatesTo("glyph")

    lm = Instance(LUTManager, ())

    traits_view = View(HSplit(
        Item("scene", style="custom", editor=SceneEditor(), show_label=False),
        VGroup(Item("wire", style="custom", show_label=False),
               Item("interact"), Item("scale_factor"), Item("lm")),
    ),
                       resizable=True,
                       width=700,
                       height=600)

    def _interact_changed(self, i):
        """ Attach the plane widget to the scene and switch it on or off. """
        self.ipl.interactor = self.scene.interactor
        self.ipl.place_widget()
        if i:
            self.ipl.on()
        else:
            self.ipl.off()

    def make_probe(self):
        """ Build the plane -> field calculation -> glyph pipeline and add
            its actor to the scene.
        """
        src = self.ipl.poly_data_algorithm

        # named 'mapper' so the builtin map() is not shadowed
        mapper = tvtk.PolyDataMapper(lookup_table=self.lm.lut)
        act = tvtk.Actor(mapper=mapper)

        calc_B = self.calc_B
        calc_B.input = src.output

        def execute():
            # evaluate the wire's field at every point of the probe plane
            output = calc_B.poly_data_output
            points = output.points.to_array().astype('d')
            nodes = self.wire.nodes.astype('d')
            vectors = calc_wire_B_field(nodes, points, self.wire.radius)
            output.point_data.vectors = vectors
            # fit the colour range to the field magnitude
            mag = np.sqrt((vectors**2).sum(axis=1))
            mapper.scalar_range = (mag.min(), mag.max())

        calc_B.set_execute_method(execute)

        cone = tvtk.ConeSource(height=0.05, radius=0.01, resolution=15)
        cone.update()

        glyph = self.glyph
        glyph.input_connection = calc_B.output_port
        glyph.source = cone.output
        glyph.scale_mode = 'scale_by_vector'
        glyph.color_mode = 'color_by_vector'

        mapper.input_connection = glyph.output_port
        self.scene.add_actor(act)

    def on_update(self):
        """ Force the field to be recomputed and re-render the scene. """
        self.calc_B.modified()
        self.scene.render()

    def _wire_changed(self, anew):
        """ Listen to the new wire's 'update' event and show its actor. """
        # NOTE(review): a previously assigned wire keeps its actor in the
        # scene and its 'update' listener; confirm whether wires are ever
        # replaced.
        anew.on_trait_change(self.on_update, "update")
        self.scene.add_actor(anew.actor)
예제 #25
0
파일: fit.py 프로젝트: pmrup/labtools
class DlsAnalyzer(BaseFileAnalyzer):
    """
    DlsAnalyzer is used to analyze multiple dls files. First you must define a function
    that returns the x value for the data analyzed. The default function :attr:`get_x_value` just returns the index value.
    This function must take two arguments as input: the list of filenames and the index value.
    It is up to you how the return value uses these inputs. For instance:

    >>> def get_x(fnames, index):
    ...    return 100 + 0.1 * index

    Then create a :class:`Filenames` instance (optional)

    >>> filenames = Filenames(directory = '../testdata', pattern = '*.ASC')

    Now you can create the analyzer and do some analysis

    >>> fitter = create_dls_fitter('single_stretch_exp')
    >>> analyzer = DlsAnalyzer(filenames = filenames,
    ...                        fitter = fitter,
    ...                        get_x_value = get_x)

    >>> analyzer.log_name = 'analysis.rst' #specify logname to log results in reStructuredText format
    >>> analyzer.constants = (('s','n'),()) #set constant parameters in fitting process,
    >>> analyzer.x_name = 'position' #specify x data name

    When everything is set you can call process to fit all data.

    >>> analyzer.process()
    >>> analyzer.save_results('../testdata/output.npy')

    """
    #: Filenames instance
    filenames = Instance(Filenames, ())
    #: selected filename
    selected = DelegatesTo('filenames')
    #: data fitter object for data fitting
    fitter = Instance(DlsFitter)
    #: defines a list of constants tuples, one per fit run. See :meth:`process_selected`
    constants = List(List(Str))
    #: defines whether fit plots are saved
    saves_fits = Bool(False)
    #: if defined, a valid reStructuredText log file is written to this path
    log_name = Str
    #: actual log is written here
    log = Str
    #: fit results are stored here
    results = Instance(StructArrayData, ())
    #: called as get_x_value(filenames_list, index) to obtain the x value
    get_x_value = Function
    #: this specifies name of the x data of results
    x_name = Str('index')
    #: if this list has an entry for an index, it is used as that x value
    x_values = List(Float)

    view = View(Group(dls_analyzer_group, 'saves_fits', 'results'),
                Item('fitter', style='custom'),
                resizable=True)

    @on_trait_change('selected')
    def _open_dls(self, name):
        """Open the newly selected file in the fitter and plot it."""
        self.fitter.open_dls(name)
        self.fitter._plot()

    def _constants_default(self):
        return [['f', 's'], ['']]

    def _get_x_value_default(self):
        def get(fnames, index):
            return index

        return get

    def _selected_changed(self):
        self.process_selected()

    def process_selected(self):
        """Open the selected file and fit its data according to self.constants.

        Each entry of :attr:`constants` is used for one fit run, in order; a
        failing run opens the fitter's traits dialog. If :attr:`saves_fits`
        is set, the fit plot is saved as ``<name>.png`` in a ``fits``
        subdirectory next to the data file. The fitted parameters are stored
        in :attr:`results` at row :attr:`index`.

        :returns: ``self.fitter.function.get_parameters()``
        """
        fname = self.selected
        self.fitter.open_dls(fname)
        log.debug('Fit constants: %s' % (self.constants, ))
        for constants in self.constants:
            try:
                self.fitter.fit(constants=constants)
            except Exception:
                # let the user inspect/adjust the fitter interactively
                self.fitter.configure_traits()
        if self.saves_fits:
            path, fname = os.path.split(fname)
            path = os.path.join(path, 'fits')
            try:
                os.mkdir(path)
            except OSError:
                # typically the directory already exists
                pass
            fname = os.path.join(path, fname)
            imagename = fname + '.png'
            log.info('Plotting %s' % imagename)
            self.fitter.plotter.title = imagename
            self.fitter.plotter.savefig(imagename)
        result = self.fitter.function.get_parameters()
        self._process_result(result, self.selected, self.index)
        return result

    def _process_result(self, result, fname, index):
        """Store a fit result as row ``index`` of :attr:`results`.

        The row is ``(x,)`` followed by the flattened ``result``; x comes from
        :attr:`x_values` when it has an entry for ``index`` and from
        :attr:`get_x_value` otherwise.
        """
        # flatten the nested results list once, outside the try, so a
        # failure below can never leave a half-consumed generator behind
        values = tuple(i for sub in result for i in sub)
        try:
            x = self.x_values[index]
        except (IndexError, TypeError):
            x = self.get_x_value(self.filenames.filenames, index)
        self.results.data[index] = (x, ) + values
        self.results.data_updated = True

    @on_trait_change('filenames.filenames')
    def _init(self):
        """(Re)allocate :attr:`results`: one float row per file, with columns
        x_name, then each fit parameter and its ``_err``."""
        array_names = [self.x_name]
        for name in self.fitter.function.pnames:
            array_names.append(name)
            array_names.append(name + '_err')

        dtype = np.dtype([(name, 'float') for name in array_names])
        self.results = StructArrayData(
            data=np.zeros(len(self.filenames), dtype=dtype))
        self.results.data_updated = True
        return True

    def save_results(self, fname):
        """Saves results to disk

        If :attr:`log_name` is set, a reStructuredText log referencing one
        ``<data file>.png`` image per file is also written there.

        :param str fname:
            output filename
        """
        np.save(fname, self.results.data)
        if self.log_name:
            lines = ['===========\nFit results\n===========\n\n']
            for data_fname in self.filenames.filenames:
                imagename = data_fname + '.png'
                lines.append('.. image:: %s\n' % os.path.basename(imagename))
            self.log = ''.join(lines)
            with open(self.log_name, 'w') as f:
                f.write(self.log)
예제 #26
0
class RangeEditorDemo(HasTraits):
    """ This class specifies the details of the RangeEditor demo.

    Each Range trait below is shown in the 'simple', 'custom', 'text' and
    'readonly' editor styles, one Group (tab) per trait.
    """

    # One trait per variant: narrow, medium-width and wide integer ranges,
    # plus a float range.
    small_int_range = Range(1, 16)
    medium_int_range = Range(1, 25)
    large_int_range = Range(1, 150)
    float_range = Range(0.0, 150.0)

    def _all_styles_group(trait_name, group_label):
        # Class-body helper (deleted below, so it never becomes a method):
        # one Item per editor style for `trait_name`, with a '_' separator
        # Item between consecutive styles.
        styles = (('simple', 'Simple'),
                  ('custom', 'Custom'),
                  ('text', 'Text'),
                  ('readonly', 'ReadOnly'))
        contents = []
        for style_name, style_label in styles:
            if contents:
                contents.append(Item('_'))
            contents.append(Item(trait_name, style=style_name,
                                 label=style_label))
        return Group(*contents, label=group_label)

    # RangeEditor display for narrow integer Range traits (< 17 wide):
    int_range_group1 = _all_styles_group('small_int_range', "Small Int")

    # RangeEditor display for medium-width integer Range traits (17 to 100):
    int_range_group2 = _all_styles_group('medium_int_range', "Medium Int")

    # RangeEditor display for wide integer Range traits (> 100):
    int_range_group3 = _all_styles_group('large_int_range', "Large Int")

    # RangeEditor display for float Range traits:
    float_range_group = _all_styles_group('float_range', "Float")

    del _all_styles_group

    # One group per data type; they are displayed on separate tabbed panels.
    view1 = View(int_range_group1,
                 int_range_group2,
                 int_range_group3,
                 float_range_group,
                 title='RangeEditor',
                 buttons=['OK'])
예제 #27
0
class SetStep(HasTraits):
    """Panel for choosing the step/time shown by a viewer and a step range.

    ``step`` and ``time`` are plain ``None`` class attributes until the
    attached ``_source`` reports steps/times. At that point they are added
    as ``Int``/``Float`` traits edited with slider RangeEditors. The
    ``seq_*`` traits describe a sequence of steps (start/stop/step) and
    times (t0/t1/dt/n_step).
    """

    _viewer = Instance(Viewer)
    _source = Instance(FileSource)

    # Step sequence.
    seq_start = Int(0)
    seq_stop = Int(-1)
    seq_step = Int(1)

    # Time sequence; `seq_dt` and `seq_n_step` are kept consistent with each
    # other by the `_seq_*_changed` handlers.
    seq_t0 = Float
    seq_t1 = Float
    seq_dt = Float
    seq_n_step = Int

    _step_editor = RangeEditor(low_name='step_low',
                               high_name='step_high',
                               label_width=28,
                               auto_set=True,
                               mode='slider')
    # Replaced by an Int trait once the source has steps.
    step = None
    step_low = Int
    step_high = Int

    _time_editor = RangeEditor(low_name='time_low',
                               high_name='time_high',
                               label_width=28,
                               auto_set=True,
                               mode='slider')
    # Replaced by a Float trait once the source has times.
    time = None
    time_low = Float
    time_high = Float

    # Setting this to True re-reads the step/time ranges from the source;
    # the handler resets it to False.
    file_changed = Bool(False)

    # True while the step/time handlers update each other; stops them from
    # re-triggering one another.
    is_adjust = False

    traits_view = View(
        Item('step', defined_when='step is not None', editor=_step_editor),
        Item('time', defined_when='time is not None', editor=_time_editor),
        HGroup(Heading('steps:'), Item('seq_start', label='start'),
               Item('seq_stop', label='stop'), Item('seq_step', label='step'),
               Heading('times:'), Item('seq_t0', label='t0'),
               Item('seq_t1', label='t1'), Item('seq_dt', label='dt'),
               Item('seq_n_step', label='n_step')),
    )

    def _read_source_ranges(self, add_traits=False):
        """Copy the first/last step and time of `_source` into the ranges.

        :param bool add_traits:
            also (re)create the ``step``/``time`` traits, each just before
            its range is set.
        """
        steps = self._source.steps
        if len(steps):
            if add_traits:
                self.add_trait('step', Int(0))
            self.step_low, self.step_high = steps[0], steps[-1]

        times = self._source.times
        if len(times):
            if add_traits:
                self.add_trait('time', Float(0.0))
            self.time_low, self.time_high = times[0], times[-1]

    def __source_changed(self, old, new):
        self._read_source_ranges(add_traits=True)

    def _sync_step_time(self, **query):
        """Snap `step` and `time` to the source's matching pair, then refresh.

        :param query:
            ``step=...`` or ``time=...``, forwarded to
            ``self._source.get_step_time``.
        """
        step, time = self._source.get_step_time(**query)
        self.is_adjust = True
        try:
            self.step = step
            self.time = time
        finally:
            # Always re-arm the handlers, even if an assignment fails.
            self.is_adjust = False

        self._viewer.set_source_filename(self._source.filename)

    def _step_changed(self, old, new):
        if new == old or self.is_adjust:
            return
        self._sync_step_time(step=new)

    def _time_changed(self, old, new):
        if new == old or self.is_adjust:
            return
        self._sync_step_time(time=new)

    def _file_changed_changed(self, old, new):
        if new:
            self._read_source_ranges()

        self.file_changed = False

    @on_trait_change('step_high, time_high')
    def init_seq_selection(self, name, new):
        """Reset the ``seq_*`` traits to span the full step/time range.

        The ``time_high`` listener is removed after it first fires, so later
        ``time_high`` changes no longer reset the selection.
        """
        self.seq_t0 = self.time_low
        self.seq_t1 = self.time_high
        self.seq_n_step = self.step_high - self.step_low + 1
        # An inverted step range (high == low - 1) gives zero steps.
        if self.seq_n_step != 0:
            self.seq_dt = (self.seq_t1 - self.seq_t0) / self.seq_n_step

        self.seq_start = self.step_low
        self.seq_stop = self.step_high + 1

        if name == 'time_high':
            self.on_trait_change(self.init_seq_selection,
                                 'time_high',
                                 remove=True)

    def _seq_n_step_changed(self, old, new):
        # Same zero guard as `_seq_dt_changed`: n_step == 0 would divide by
        # zero.
        if new == old or self.seq_n_step == 0:
            return
        self.seq_dt = (self.seq_t1 - self.seq_t0) / self.seq_n_step

    def _seq_dt_changed(self, old, new):
        if new == old or self.seq_dt == 0.0:
            return
        n_step = int(round((self.seq_t1 - self.seq_t0) / self.seq_dt))
        self.seq_n_step = max(1, n_step)
예제 #28
0
class BuiltinImage(Source):
    """Pipeline source exposing one of several built-in tvtk image sources.

    The ``source`` trait selects an entry of ``_source_dict``. The selected
    tvtk algorithm becomes ``data_source``, and its output is published as
    this object's only output.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # Flag to set the image data type.
    source = Enum('ellipsoid', 'gaussian', 'grid', 'mandelbrot', 'noise',
                  'sinusoid', 'rt_analytic',
                  desc='which image data source to be used')

    # The active tvtk image algorithm; reassigned from `_source_dict`
    # whenever `source` changes.
    data_source = Instance(tvtk.ImageAlgorithm, allow_none=False,
                           record=True)

    # Information about what this object can produce.
    output_info = PipelineInfo(datasets=['image_data'],
                               attribute_types=['any'],
                               attributes=['any'])

    # Create the UI for the traits.
    view = View(Group(Item(name='source'),
                      Item(name='data_source',
                           style='custom',
                           resizable=True),
                      label='Image Source',
                      show_labels=False),
                resizable=True)

    ########################################
    # Private traits.

    # Maps each `source` name to its tvtk image algorithm instance.
    _source_dict = Dict(Str,
                        Instance(tvtk.ImageAlgorithm,
                                 allow_none=False))

    ######################################################################
    # `object` interface
    ######################################################################
    def __init__(self, **traits):
        # Call parent class' init.
        super(BuiltinImage, self).__init__(**traits)

        # `_source_changed` only runs when `source` is assigned, so select
        # the default source's instance explicitly if none was passed.
        if 'source' not in traits:
            self._source_changed(self.source)

    def __set_pure_state__(self, state):
        # Restore `source` (and hence `data_source`) before delegating to
        # the parent implementation.
        self.source = state.source
        super(BuiltinImage, self).__set_pure_state__(state)

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _source_changed(self, value):
        """Switch `data_source` to the instance registered for `value`.

        Invoked automatically when the `source` trait changes, and once
        from `__init__` for the default source.
        """
        self.data_source = self._source_dict[value]

    def _data_source_changed(self, old, new):
        """Publish the new algorithm's output and move the render hook.

        ``render`` is attached to trait changes of the new data source and
        detached from the previous one.
        """
        self.outputs = [new.output]

        if old is not None:
            old.on_trait_change(self.render, remove=True)
        new.on_trait_change(self.render)

    def __source_dict_default(self):
        """The default _source_dict trait."""
        sd = {
            'ellipsoid': tvtk.ImageEllipsoidSource(),
            'gaussian': tvtk.ImageGaussianSource(),
            'grid': tvtk.ImageGridSource(),
            'mandelbrot': tvtk.ImageMandelbrotSource(),
            'noise': tvtk.ImageNoiseSource(),
            'sinusoid': tvtk.ImageSinusoidSource(),
        }
        # Some tvtk builds lack RTAnalyticSource; fall back to a noise
        # source so every `source` value has an entry.
        if hasattr(tvtk, 'RTAnalyticSource'):
            sd['rt_analytic'] = tvtk.RTAnalyticSource()
        else:
            sd['rt_analytic'] = tvtk.ImageNoiseSource()
        return sd
예제 #29
0
class DataSourceWizardView(DataSourceWizard):
    """Traits UI dialog for describing a set of arrays as a data source.

    The left pane lists the available arrays (``_array_view``). The right
    pane asks for the data type and embeds the sub-view named by
    ``_suitable_traits_view`` next to a structure preview
    (``_questions_view``). OK builds the data source; Cancel closes the
    dialog.
    """

    #----------------------------------------------------------------------
    # Private traits
    #----------------------------------------------------------------------

    _top_label = Str('Describe your data')

    # Warning shown next to the alert icon while `_is_not_ok` is set.
    _info_text = Str('Array sizes do not match')

    _array_label = Str('Available arrays')

    _data_type_text = Str("What does your data represent?")

    _lines_text = Str("Connect the points with lines")

    _scalar_data_text = Str("Array giving the value of the scalars")

    _optional_scalar_data_text = Str("Associate scalars with the data points")

    _connectivity_text = Str("Array giving the triangles")

    _vector_data_text = Str("Associate vector components")

    # Caption above the coordinate selectors, looked up in
    # `_position_text_dict` from the current position type.
    _position_text = Property(depends_on="position_type_")

    _position_text_dict = {
        'explicit': 'Coordinates of the data points:',
        'orthogonal grid': 'Position of the layers along each axis:',
    }

    def _get__position_text(self):
        return self._position_text_dict.get(self.position_type_, "")

    _shown_help_text = Str

    # Name/shape rows for the table of available arrays.
    _data_sources_wrappers = Property(depends_on='data_sources')

    def _get__data_sources_wrappers(self):
        return [
            ArrayColumnWrapper(name=name,
                               shape=repr(self.data_sources[name].shape))
            for name in self._data_sources_names
        ]

    # A trait pointing to the object itself, so that its sub-views can be
    # embedded with an InstanceEditor (set in `__init__`).
    _self = Instance(DataSourceWizard)

    # Name of the sub-view for the selected data type; expected to match one
    # of the `_*_data_view` views below.
    _suitable_traits_view = Property(depends_on="data_type_")

    def _get__suitable_traits_view(self):
        return "_%s_data_view" % self.data_type_

    # UI object returned by `edit_traits` in `view_wizard`; False until then.
    ui = Any(False)

    _preview_button = Button(label='Preview structure')

    def __preview_button_fired(self):
        if self.ui:
            self.build_data_source()
            self.preview()

    _ok_button = Button(label='OK')

    def __ok_button_fired(self):
        if self.ui:
            self.ui.dispose()
            self.build_data_source()

    _cancel_button = Button(label='Cancel')

    def __cancel_button_fired(self):
        if self.ui:
            self.ui.dispose()

    # Result of `check_arrays()` and its negation, used by the views'
    # enabled_when / visible_when / invalid conditions.
    _is_ok = Bool

    _is_not_ok = Bool

    def _anytrait_changed(self):
        """ Validates if the OK button is enabled.
        """
        if self.ui:
            self._is_ok = self.check_arrays()
            self._is_not_ok = not self._is_ok

    _preview_window = Instance(PreviewWindow, ())

    _info_image = Instance(ImageResource,
                           ImageLibrary.image_resource('@std:alert16',))

    #----------------------------------------------------------------------
    # TraitsUI views
    #----------------------------------------------------------------------
    # NOTE: none of the group definitions below may end with a trailing
    # comma -- that would turn them into 1-tuples instead of Groups.

    _coordinates_group = HGroup(
        Item('position_x', label='x',
             editor=EnumEditor(name='_data_sources_names',
                               invalid='_is_not_ok')),
        Item('position_y', label='y',
             editor=EnumEditor(name='_data_sources_names',
                               invalid='_is_not_ok')),
        Item('position_z', label='z',
             editor=EnumEditor(name='_data_sources_names',
                               invalid='_is_not_ok')),
    )

    _position_group = Group(
        Item('position_type'),
        Group(
            Item('_position_text', style='readonly',
                 resizable=False,
                 show_label=False),
            _coordinates_group,
            visible_when='not position_type_=="image data"',
        ),
        Group(
            Item('grid_shape_source_',
                 label='Grid shape',
                 editor=EnumEditor(name='_grid_shape_source_labels',
                                   invalid='_is_not_ok')),
            HGroup(
                spring,
                Item('grid_shape', style='custom',
                     editor=ArrayEditor(width=-60),
                     show_label=False),
                enabled_when='grid_shape_source==""',
            ),
            visible_when='position_type_=="image data"',
        ),
        label='Position of the data points',
        show_border=True,
        show_labels=False,
    )

    _connectivity_group = Group(
        HGroup(
            Item('_connectivity_text', style='readonly',
                 resizable=False),
            spring,
            Item('connectivity_triangles',
                 editor=EnumEditor(name='_data_sources_names'),
                 show_label=False,
                 ),
            show_labels=False,
        ),
        label='Connectivity information',
        show_border=True,
        show_labels=False,
        enabled_when='position_type_=="explicit"',
    )

    _scalar_data_group = Group(
        Item('_scalar_data_text', style='readonly',
             resizable=False,
             show_label=False),
        HGroup(
            spring,
            Item('scalar_data',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok')),
            show_labels=False,
        ),
        label='Scalar value',
        show_border=True,
        show_labels=False,
    )

    _optional_scalar_data_group = Group(
        HGroup(
            'has_scalar_data',
            Item('_optional_scalar_data_text',
                 resizable=False,
                 style='readonly'),
            show_labels=False,
        ),
        Item('_scalar_data_text', style='readonly',
             resizable=False,
             enabled_when='has_scalar_data',
             show_label=False),
        HGroup(
            spring,
            Item('scalar_data',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok'),
                 enabled_when='has_scalar_data'),
            show_labels=False,
        ),
        label='Scalar data',
        show_border=True,
        show_labels=False,
    )

    _vector_data_group = VGroup(
        HGroup(
            Item('vector_u', label='u',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok')),
            Item('vector_v', label='v',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok')),
            Item('vector_w', label='w',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok')),
        ),
        label='Vector data',
        show_border=True,
    )

    _optional_vector_data_group = VGroup(
        HGroup(
            Item('has_vector_data', show_label=False),
            Item('_vector_data_text', style='readonly',
                 resizable=False,
                 show_label=False),
        ),
        HGroup(
            Item('vector_u', label='u',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok')),
            Item('vector_v', label='v',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok')),
            Item('vector_w', label='w',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok')),
            enabled_when='has_vector_data',
        ),
        label='Vector data',
        show_border=True,
    )

    _array_view = View(
        Item('_array_label', editor=TitleEditor(),
             show_label=False),
        Group(
            Item('_data_sources_wrappers',
                 editor=TabularEditor(
                     adapter=ArrayColumnAdapter(),
                 ),
                 ),
            show_border=True,
            show_labels=False,
        ))

    _questions_view = View(
        Item('_top_label', editor=TitleEditor(),
             show_label=False),
        HGroup(
            Item('_data_type_text', style='readonly',
                 resizable=False),
            spring,
            'data_type',
            spring,
            show_border=True,
            show_labels=False,
        ),
        HGroup(
            Item('_self', style='custom',
                 editor=InstanceEditor(
                     view_name='_suitable_traits_view'),
                 ),
            Group(
                # FIXME: context sensitive help (an HTMLEditor on
                # `_shown_help_text`, tabbed next to the preview) was given
                # up on because of lack of time.
                Item('_preview_button',
                     enabled_when='_is_ok'),
                Item('_preview_window', style='custom',
                     label='Preview structure'),
                show_labels=False,
            ),
            show_labels=False,
            show_border=True,
        ),
    )

    _point_data_view = View(Group(
        Group(_coordinates_group,
              label='Position of the data points',
              show_border=True,
              ),
        HGroup(
            'lines',
            Item('_lines_text', style='readonly',
                 resizable=False),
            label='Lines',
            show_labels=False,
            show_border=True,
        ),
        _optional_scalar_data_group,
        _optional_vector_data_group,
        # XXX: hack to have more vertical space
        Label('\n'),
        Label('\n'),
        Label('\n'),
    ))

    _surface_data_view = View(Group(
        _position_group,
        _connectivity_group,
        _optional_scalar_data_group,
        _optional_vector_data_group,
    ))

    _vector_data_view = View(Group(
        _vector_data_group,
        _position_group,
        _optional_scalar_data_group,
    ))

    _volumetric_data_view = View(Group(
        _scalar_data_group,
        _position_group,
        _optional_vector_data_group,
    ))

    _wizard_view = View(
        Group(
            HGroup(
                Item('_self', style='custom', show_label=False,
                     editor=InstanceEditor(view='_array_view'),
                     width=0.17,
                     ),
                '_',
                Item('_self', style='custom', show_label=False,
                     editor=InstanceEditor(view='_questions_view'),
                     ),
            ),
            HGroup(
                Item('_info_image', editor=ImageEditor(),
                     visible_when="_is_not_ok"),
                Item('_info_text', style='readonly', resizable=False,
                     visible_when="_is_not_ok"),
                spring,
                '_cancel_button',
                Item('_ok_button', enabled_when='_is_ok'),
                show_labels=False,
            ),
        ),
        title='Import arrays',
        resizable=True,
    )

    #----------------------------------------------------------------------
    # Public interface
    #----------------------------------------------------------------------

    def __init__(self, **traits):
        # Initialise through the regular MRO so every base-class
        # initialiser runs.
        super(DataSourceWizardView, self).__init__(**traits)
        self._self = self

    def view_wizard(self):
        """ Pops up the view of the wizard, and keeps a reference to it to
            be able to close it.
        """
        # FIXME: Workaround for traits bug in enabled_when: read these
        # traits once before the UI is built.
        self.position_type_
        self.data_type_
        self._suitable_traits_view
        self.grid_shape_source
        self._is_ok
        self.ui = self.edit_traits(view='_wizard_view')

    def preview(self):
        """ Display a preview of the data structure in the preview
            window.

        Adds the current data source and then:

        * glyphs at the data points (drawn as 3-pixel crosses/points unless
          vector data is involved),
        * a translucent surface (opacity 0.3) for data types other than
          'point'/'vector', or when `lines` is set,
        * extracted edges as a surface (opacity 0.2) for every data type
          except 'point'.
        """
        self._preview_window.clear()
        self._preview_window.add_source(self.data_source)
        g = Glyph()
        g.glyph.glyph_source.glyph_source = \
            g.glyph.glyph_source.glyph_list[0]
        g.glyph.scale_mode = 'data_scaling_off'
        if not (self.has_vector_data or self.data_type_ == 'vector'):
            g.glyph.glyph_source.glyph_source.glyph_type = 'cross'
            g.actor.property.representation = 'points'
            g.actor.property.point_size = 3.
        self._preview_window.add_module(g)
        if self.data_type_ not in ('point', 'vector') or self.lines:
            s = Surface()
            s.actor.property.opacity = 0.3
            self._preview_window.add_module(s)
        if self.data_type_ != 'point':
            self._preview_window.add_filter(ExtractEdges())
            s = Surface()
            s.actor.property.opacity = 0.2
            self._preview_window.add_module(s)
예제 #30
0
class GlyphSource(Component):

    # The version of this class.  Used for persistence.
    __version__ = 1

    # Glyph position.  This can be one of ['head', 'tail', 'center'],
    # and indicates the position of the glyph with respect to the
    # input point data.  Please note that this will work correctly
    # only if you do not mess with the source glyph's basic size.  For
    # example if you use a ConeSource and set its height != 1, then the
    # 'head' and 'tail' options will not work correctly.
    glyph_position = Trait('center',
                           TraitPrefixList(['head', 'tail', 'center']),
                           desc='position of glyph w.r.t. data point')

    # The Source to use for the glyph.  This is chosen from
    # `self._glyph_list` or `self.glyph_dict`.
    glyph_source = Instance(tvtk.Object, allow_none=False, record=True)

    # A dict of glyphs to use.
    glyph_dict = Dict(desc='the glyph sources to select from', record=False)

    # A list of predefined glyph sources that can be used.
    glyph_list = Property(List(tvtk.Object), record=False)

    ########################################
    # Private traits.

    # The transformation to use to place glyph appropriately.
    _trfm = Instance(tvtk.TransformFilter, args=())

    # Used for optimization.
    _updating = Bool(False)

    ########################################
    # View related traits.

    view = View(Group(
        Group(Item(name='glyph_position')),
        Group(Item(
            name='glyph_source',
            style='custom',
            resizable=True,
            editor=InstanceEditor(name='glyph_list'),
        ),
              label='Glyph Source',
              show_labels=False)),
                resizable=True)

    ######################################################################
    # `Base` interface
    ######################################################################
    def __get_pure_state__(self):
        """Return the persistable state without transient/derived entries.

        `_updating` is a private flag and `glyph_list` is a property built
        from `glyph_dict`, so neither is saved.
        """
        state = super(GlyphSource, self).__get_pure_state__()
        excluded = ('_updating', 'glyph_list')
        for key in excluded:
            state.pop(key, None)
        return state

    def __set_pure_state__(self, state):
        """Restore this component from a persisted `state`.

        Two persisted layouts are handled:

        * states containing `glyph_dict`: the dict is restored first;
        * older states carrying only `glyph_list`: each persisted entry is
          applied to the matching entry of the current `glyph_list`, and
          glyphs whose camel2enthought'ed class name is not yet a key of
          `glyph_dict` are registered there.

        In both cases `glyph_source` is then selected from `glyph_dict`
        using the class name stored in the state's metadata, and the rest
        of the state is applied (skipping the entry handled above).
        """
        if 'glyph_dict' in state:
            # Set their state.
            # NOTE(review): ignore=['*'] presumably limits this pass to
            # `glyph_dict` only — confirm against set_state's semantics.
            set_state(self, state, first=['glyph_dict'], ignore=['*'])
            ignore = ['glyph_dict']
        else:
            # Set the dict state using the persisted list.
            gd = self.glyph_dict
            gl = self.glyph_list
            # NOTE(review): handle_children_state presumably replaces items
            # of `gl` whose class differs from the persisted one — confirm.
            handle_children_state(gl, state.glyph_list)
            for g, gs in zip(gl, state.glyph_list):
                name = camel2enthought(g.__class__.__name__)
                # Register glyphs that `glyph_dict` does not know yet.
                if name not in gd:
                    gd[name] = g
                # Set the glyph source's state.
                set_state(g, gs)
            ignore = ['glyph_list']
        g_name = state.glyph_source.__metadata__['class_name']
        name = camel2enthought(g_name)
        # Set the correct glyph_source.
        self.glyph_source = self.glyph_dict[name]
        set_state(self, state, ignore=ignore)

    ######################################################################
    # `Component` interface
    ######################################################################
    def setup_pipeline(self):
        """Create the tvtk pipeline of this component.

        Called from `__init__`, before the tvtk data pipeline is connected,
        so no upstream data is available yet.  Only the parts that do not
        depend on upstream sources or filters are built here: the transform
        used to position glyphs and the initially selected glyph source.
        """
        transform = tvtk.Transform()
        self._trfm.transform = transform
        # Selecting a glyph fires `_glyph_source_changed`, which ends up in
        # `_glyph_position_changed`; that handler operates on
        # `self._trfm.transform`, so the transform must be assigned first.
        default_glyph = self.glyph_dict['glyph_source2d']
        self.glyph_source = default_glyph

    def update_pipeline(self):
        """Re-apply the glyph placement and signal a pipeline change.

        Invoked automatically when any input fires `pipeline_changed`.
        """
        current_position = self.glyph_position
        self._glyph_position_changed(current_position)
        self.pipeline_changed = True

    def update_data(self):
        """Re-fire this component's `data_changed` event.

        Invoked automatically when any input fires `data_changed`.  No
        explicit flush of the VTK pipeline is performed here.
        """
        self.data_changed = True

    def render(self):
        """Forward to the base-class `render`, except while `_updating`
        is set (i.e. while a glyph change is being applied)."""
        if self._updating:
            return
        super(GlyphSource, self).render()

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _glyph_source_changed(self, value):
        """Trait handler: wire a newly selected glyph source into the output.

        The handler performs these steps in order:

        * registers `value` in `glyph_dict` (keyed by its camel2enthought'ed
          class name) if it is not already one of the dict's values;
        * records the assignment with the script recorder, if one is set;
        * routes the glyph's output directly (`GlyphSource2D`) or through
          the `_trfm` transform filter (all other sources);
        * hooks `render` to trait changes on the glyph;
        * re-applies the current `glyph_position`.

        The `_updating` flag suppresses re-entrant handling and rendering
        while the outputs are rewired.  It is reset in a `finally` clause
        so that an error here cannot leave the component permanently
        ignoring later glyph changes.
        """
        if self._updating:
            return

        gd = self.glyph_dict
        value_cls = camel2enthought(value.__class__.__name__)
        if value not in gd.values():
            gd[value_cls] = value

        # Now change the glyph's source trait.
        self._updating = True
        try:
            recorder = self.recorder
            if recorder is not None:
                name = recorder.get_script_id(self)
                lhs = '%s.glyph_source' % name
                rhs = '%s.glyph_dict[%r]' % (name, value_cls)
                recorder.record('%s = %s' % (lhs, rhs))

            if value.__class__.__name__ == 'GlyphSource2D':
                # 2D glyphs are emitted as-is, bypassing the transform.
                self.outputs = [value.output]
            else:
                self._trfm.input = value.output
                self.outputs = [self._trfm.output]
            # Re-render whenever any trait of the glyph source changes.
            value.on_trait_change(self.render)
        finally:
            self._updating = False

        # Now update the glyph position since the transformation might
        # be different.
        self._glyph_position_changed(self.glyph_position)

    def _glyph_position_changed(self, value):
        """Trait handler: place the glyph relative to its data point.

        `value` is one of 'head', 'tail' or 'center'.  Depending on the
        glyph source class, the glyph is moved either by setting its
        `center` or by translating the `_trfm` transform (arrows).
        Cylinders are additionally rotated by 90 degrees about z.  Sources
        that have no `center` attribute (e.g. `tvtk.Axes`) are left where
        they are, consistently for all three positions.

        The `_updating` flag suppresses re-entrant handling and rendering
        while the transform is rebuilt.  It is reset in a `finally` clause
        so that an error cannot leave the component stuck in that state.
        """
        if self._updating:
            return

        self._updating = True
        try:
            tr = self._trfm.transform
            tr.identity()

            g = self.glyph_source
            name = g.__class__.__name__
            # Offset from the glyph's middle to its end along its axis.
            # It is negative for the cylinder, which is shifted along y and
            # then rotated about z below.
            if name == 'CubeSource':
                tr_factor = g.x_length / 2.0
            elif name == 'CylinderSource':
                tr_factor = -g.height / 2.0
            elif name == 'ConeSource':
                tr_factor = g.height / 2.0
            elif name == 'SphereSource':
                tr_factor = g.radius
            else:
                tr_factor = 1.
            # Translate the glyph
            if value == 'tail':
                if name == 'GlyphSource2D':
                    g.center = 0.5, 0.0, 0.0
                elif name == 'ArrowSource':
                    pass
                elif name == 'CylinderSource':
                    g.center = 0, tr_factor, 0.0
                elif hasattr(g, 'center'):
                    g.center = tr_factor, 0.0, 0.0
            elif value == 'head':
                if name == 'GlyphSource2D':
                    g.center = -0.5, 0.0, 0.0
                elif name == 'ArrowSource':
                    tr.translate(-1, 0, 0)
                elif name == 'CylinderSource':
                    g.center = 0, -tr_factor, 0.0
                elif hasattr(g, 'center'):
                    # Previously unguarded: failed for sources such as
                    # tvtk.Axes that have no `center`.
                    g.center = -tr_factor, 0.0, 0.0
            else:
                if name == 'ArrowSource':
                    tr.translate(-0.5, 0, 0)
                elif hasattr(g, 'center'):
                    g.center = 0.0, 0.0, 0.0

            if name == 'CylinderSource':
                tr.rotate_z(90)
        finally:
            self._updating = False
        self.render()

    def _get_glyph_list(self):
        """Getter for the `glyph_list` property.

        Returns the values of `glyph_dict`.  The built-in glyph sources come
        first, in the order used by the earlier implementation, followed by
        any other entries in `glyph_dict` iteration order.
        """
        preferred = ['glyph_source2d', 'arrow_source', 'cone_source',
                     'cylinder_source', 'sphere_source', 'cube_source',
                     'axes']
        glyphs = self.glyph_dict
        extras = [key for key in glyphs if key not in preferred]
        return [glyphs[key] for key in preferred + extras]

    def _glyph_dict_default(self):
        """Build the default mapping of glyph names to tvtk glyph sources."""
        glyphs = {}
        glyphs['glyph_source2d'] = tvtk.GlyphSource2D(glyph_type='arrow',
                                                      filled=False)
        glyphs['arrow_source'] = tvtk.ArrowSource()
        glyphs['cone_source'] = tvtk.ConeSource(height=1.0, radius=0.2,
                                                resolution=15)
        glyphs['cylinder_source'] = tvtk.CylinderSource(height=1.0,
                                                        radius=0.15,
                                                        resolution=10)
        glyphs['sphere_source'] = tvtk.SphereSource()
        glyphs['cube_source'] = tvtk.CubeSource()
        glyphs['axes'] = tvtk.Axes(symmetric=1)
        return glyphs