Esempio n. 1
0
class ActorViewer(HasTraits):
    """TraitsUI window embedding a Mayavi scene that shows a demo surface.

    On construction the scene is filled with the surface
    ``z = sin(r) / r`` where ``r = 10 * sqrt(x**2 + y**2)``, sampled on a
    100 x 100 grid spanning ``[-2, 2]`` in both directions.
    """

    # Scene model the surface is drawn into.
    scene = Instance(MlabSceneModel, ())

    # Passing scene_class=MayaviScene adds a Mayavi icon to the toolbar
    # that pops up a dialog for editing the pipeline.
    view = View(
        Item(
            name='scene',
            editor=SceneEditor(scene_class=MayaviScene),
            show_label=False,
            resizable=True,
            height=500,
            width=500,
        ),
        resizable=True,
    )

    def __init__(self, **traits):
        """Initialise the traits, then draw the demo surface."""
        HasTraits.__init__(self, **traits)
        self.generate_data()

    def generate_data(self):
        """Compute the sample surface and plot it into ``self.scene``."""
        grid_x, grid_y = mgrid[-2:2:100j, -2:2:100j]
        radius = 10 * sqrt(grid_x**2 + grid_y**2)
        height = sin(radius) / radius

        self.scene.mlab.surf(grid_x, grid_y, height, colormap='gist_earth')
Esempio n. 2
0
class ActorView(HasTraits):
    """TraitsUI window whose only content is an embedded TVTK scene."""

    # Scene model shown by the view; declared with a no-argument default.
    scene = Instance(SceneModel, ())
    
    # Resizable 700 x 600 view showing the scene through a SceneEditor,
    # in "custom" style and without a label.
    traits_view = View(Item("scene", style="custom", editor=SceneEditor(),
                                show_label=False),
                        resizable=True, width=700, height=600
                        )
Esempio n. 3
0
class Visualization(HasTraits):
    """Interactive mesh plot driven by two slider-style parameters.

    The mesh coordinates come from ``tens_fld(1, 1, 1, beta, alpha)``;
    whenever ``alpha`` or ``beta`` changes the existing mesh is updated in
    place rather than re-created.
    """

    # Both parameters range over [0, 4] and start at 0.25.
    alpha = Range(0.0, 4.0, 1.0 / 4)
    beta = Range(0.0, 4.0, 1.0 / 4)
    scene = Instance(MlabSceneModel, ())

    def __init__(self):
        """Initialise the traits and create the initial mesh."""
        # The parent initialiser must run before the traits are used.
        HasTraits.__init__(self)
        xs, ys, zs = tens_fld(1, 1, 1, self.beta, self.alpha)
        mesh_options = dict(colormap='copper', representation='surface')
        self.plot = self.scene.mlab.mesh(xs, ys, zs, **mesh_options)

    @on_trait_change('beta,alpha')
    def update_plot(self):
        """Recompute the coordinates and push them into the existing mesh."""
        xs, ys, zs = tens_fld(1, 1, 1, self.beta, self.alpha)
        self.plot.mlab_source.set(x=xs, y=ys, z=zs)

    # Dialog layout: the scene on top, the two parameters in a row below.
    view = View(
        Item(
            'scene',
            editor=SceneEditor(scene_class=MayaviScene),
            height=750,
            width=750,
            show_label=False,
        ),
        HGroup('_', 'beta', 'alpha'),
    )
class MyDemo(HasTraits):
    """Browser for TVTK parametric functions rendered in an embedded scene.

    ``func_name`` selects an entry of ``sources``; that object becomes the
    parametric function of ``source``, and any trait change on either the
    function or the source triggers a re-render.
    """

    scene = Instance(SceneModel, ())

    source = Instance(tvtk.ParametricFunctionSource, ())

    # Choices are the class names listed in ``source_types``.
    func_name = Enum([c.__name__ for c in source_types])

    # Function object currently selected through ``func_name``.
    func = Property(depends_on="func_name")

    traits_view = View(
        HSplit(
            VGroup(
                Item("func_name", show_label=False),
                Tabbed(
                    Item(
                        "func",
                        style="custom",
                        editor=InstanceEditor(),
                        show_label=False,
                    ),
                    Item("source", style="custom", show_label=False),
                ),
            ),
            Item(
                "scene",
                style="custom",
                show_label=False,
                editor=SceneEditor(),
            ),
        ),
        resizable=True,
        width=700,
        height=600,
    )

    def __init__(self, *args, **kwds):
        """Initialise the traits and build the rendering pipeline."""
        super(MyDemo, self).__init__(*args, **kwds)
        self._make_pipeline()

    def _get_func(self):
        """Property getter: look up the selected function in ``sources``."""
        return sources[self.func_name]

    def _make_pipeline(self):
        """Wire function -> source -> mapper -> actor and add it to the scene."""
        selected = self.func
        selected.on_trait_change(self.on_change, "anytrait")

        source = self.source
        source.on_trait_change(self.on_change, "anytrait")
        source.parametric_function = selected

        mapper = tvtk.PolyDataMapper(input_connection=source.output_port)
        self.scene.add_actor(tvtk.Actor(mapper=mapper))
        # Kept so that _func_changed can swap the function later.
        self.src = source

    def _func_changed(self, old_func, this_func):
        """Move the change listener to the new function and re-render."""
        if old_func is not None:
            old_func.on_trait_change(self.on_change, "anytrait", remove=True)
        this_func.on_trait_change(self.on_change, "anytrait")
        self.src.parametric_function = this_func
        self.scene.render()

    def on_change(self):
        """Re-render the scene after any observed trait change."""
        self.scene.render()
Esempio n. 5
0
class ActorView(HasTraits):
    """TraitsUI window showing a TVTK cube next to its property editor."""

    # Cube source whose output is rendered; editable in the right-hand pane.
    cube = Instance(tvtk.CubeSource, ())
    # Scene model; built by _scene_default the first time it is needed.
    scene = Instance(SceneModel)

    traits_view = View(HSplit(
        Item("scene", style="custom", editor=SceneEditor(), show_label=False),
        Item("cube", style="custom", show_label=False)),
                       resizable=True,
                       width=700,
                       height=600)

    def _scene_default(self):
        """Build the default scene containing a single actor for ``cube``.

        The mapper is attached through the pipeline connection
        (``input_connection`` / ``output_port``) instead of assigning the
        cube's output data object to ``input``: the connection form is the
        one used by the other TVTK examples here and does not rely on the
        ``input`` setter of older VTK releases.
        """
        mapper = tvtk.PolyDataMapper(input_connection=self.cube.output_port)
        actor = tvtk.Actor(mapper=mapper)
        scene = SceneModel()
        scene.add_actor(actor)
        return scene
Esempio n. 6
0
class ActorViewer(HasTraits):
    """TraitsUI window that swaps a single actor/widget in a TVTK scene.

    Selecting a different ``actor_type`` removes the object currently in
    the scene and adds a freshly created one of the requested kind.
    """

    # A simple trait to change the actors/widgets.
    actor_type = Enum('cone', 'sphere', 'plane_widget', 'box_widget')

    # The scene model.
    scene = Instance(SceneModel, ())

    # Object currently shown in the scene (None before the first one is added).
    _current_actor = Any

    ######################
    view = View(Item(name='actor_type'),
                Item(name='scene',
                     editor=SceneEditor(),
                     show_label=False,
                     resizable=True,
                     height=500,
                     width=500)
                )

    def __init__(self, **traits):
        """Initialise the traits and show the initially selected actor."""
        super(ActorViewer, self).__init__(**traits)
        self._actor_type_changed(self.actor_type)

    ####################################
    # Private traits.
    def _actor_type_changed(self, value):
        """Replace the displayed object with one matching ``value``.

        Parameters
        ----------
        value : str
            One of 'cone', 'sphere', 'plane_widget', 'box_widget'.

        Raises
        ------
        ValueError
            If ``value`` is not a known actor type. The check happens
            before the scene is modified, so a bad value no longer removes
            the current actor and then fails on an unbound local.
        """
        factories = {
            'cone': actors.cone_actor,
            'sphere': actors.sphere_actor,
            'plane_widget': tvtk.PlaneWidget,
            'box_widget': tvtk.BoxWidget,
        }
        if value not in factories:
            raise ValueError('unknown actor_type: %r' % (value,))

        scene = self.scene
        if self._current_actor is not None:
            scene.remove_actors(self._current_actor)

        new_actor = factories[value]()
        scene.add_actors(new_actor)
        self._current_actor = new_actor
Esempio n. 7
0
class PreviewWindow(HasTraits):
    """ A window with a mayavi engine, to preview pipeline elements.
    """

    # The engine that manages the preview view
    _engine = Instance(Engine)

    # The scene model displayed by the view and registered with the engine.
    _scene = Instance(SceneModel, ())

    view = View(Item('_scene',
                     editor=SceneEditor(scene_class=Scene),
                     show_label=False),
                width=500,
                height=500)

    #-----------------------------------------------------------------------
    # Public API
    #-----------------------------------------------------------------------

    def add_source(self, src):
        """Add the data source ``src`` to the preview engine."""
        self._engine.add_source(src)

    def add_module(self, module):
        """Add the visualization module ``module`` to the preview engine."""
        self._engine.add_module(module)

    def add_filter(self, filter):
        """Add the filter ``filter`` to the preview engine.

        Uses the engine's ``add_filter`` rather than ``add_module``, which
        the previous implementation called by mistake.
        """
        self._engine.add_filter(filter)

    def clear(self):
        """Remove everything from the current scene.

        Rendering is disabled while the children are dropped and is always
        re-enabled afterwards, even if clearing raises.
        """
        current = self._engine.current_scene
        current.scene.disable_render = True
        try:
            current.children[:] = []
        finally:
            current.scene.disable_render = False

    #-----------------------------------------------------------------------
    # Private API
    #-----------------------------------------------------------------------

    def __engine_default(self):
        """Create and start the engine, registering ``_scene`` with it.

        The double underscore follows from the Traits ``_<name>_default``
        naming convention applied to the trait named ``_engine``.
        """
        e = Engine()
        e.start()
        e.new_scene(self._scene)
        return e
Esempio n. 8
0
class ExampleScene(TracerScene):
    """Ray-tracing demo: two rays hitting a pair of rectangular mirrors.

    ``source_y`` and ``source_z`` set the common origin of both rays;
    changing either one moves the bundle and re-plots the trace.
    """

    source_y = t_api.Range(0., 5., 2.)
    source_z = t_api.Range(0., 5., 1.)

    def __init__(self):
        """Build the two-ray bundle and the mirror assembly, then start."""
        # Two rays leaving the same point: one along (0, -1, 1)/sqrt(2),
        # the other along (0, 0, -1); one direction per column.
        inv_sqrt2 = 1 / (N.sqrt(2))
        directions = N.c_[[0, -inv_sqrt2, inv_sqrt2], [0, 0, -1]]
        origin = N.c_[[0, self.source_y, self.source_z]]
        self.bund = RayBundle(
            vertices=N.tile(origin, (1, 2)),
            directions=directions,
            energy=N.r_[1, 1],
        )

        # The assembly for ray tracing: one rotated mirror, one unrotated.
        tilt = N.dot(G.rotx(N.pi / 4)[:3, :3], G.roty(N.pi)[:3, :3])
        tilted_mirror = rect_one_sided_mirror(width=10, height=10)
        tilted_mirror.set_rotation(tilt)
        plain_mirror = rect_one_sided_mirror(width=10, height=10)
        self.assembly = Assembly(objects=[tilted_mirror, plain_mirror])

        TracerScene.__init__(self, self.assembly, self.bund)

    @t_api.on_trait_change('_scene.activated')
    def initialize_camere(self):
        """Set the initial camera view and roll once the scene is activated."""
        mlab = self._scene.mlab
        mlab.view(0, -90)
        mlab.roll(0)

    @t_api.on_trait_change('source_y, source_z')
    def bundle_move(self):
        """Move both ray origins to the new source point and re-trace."""
        origin = N.c_[[0, self.source_y, self.source_z]]
        self.bund.set_vertices(N.tile(origin, (1, 2)))
        self.plot_ray_trace()

    view = tui.View(
        tui.Item(
            '_scene',
            editor=SceneEditor(scene_class=MayaviScene),
            height=400,
            width=300,
            show_label=False,
        ),
        tui.HGroup('-', 'source_y', 'source_z'),
    )
Esempio n. 9
0
class DemoApp(HasTraits):
    """TraitsUI demo embedding a Mayavi scene with a button that plots."""

    # Button shown under the scene; by the Traits naming convention its
    # click is handled by _plotbutton_fired below.
    plotbutton = Button("绘图")
    # The Mayavi scene
    scene = Instance(MlabSceneModel, ())

    view = View(
        VGroup(
            # Configure the Mayavi editor for the scene
            Item(name='scene',
                 editor=SceneEditor(scene_class=MayaviScene),
                 resizable=True,
                 height=250,
                 width=400),
            'plotbutton',
            show_labels=False),
        title="在TraitsUI中嵌入Mayavi")

    def _plotbutton_fired(self):
        """Handle a click on ``plotbutton`` by plotting."""
        self.plot()

    def plot(self):
        """Draw mlab's built-in test mesh.

        NOTE(review): this calls the module-level ``mlab`` rather than
        ``self.scene.mlab`` -- presumably it draws into the embedded scene
        because that is the current figure; confirm if several scenes exist.
        """
        mlab.test_mesh()
Esempio n. 10
0
class TVTKViewer(Viewer):
    """Viewer that renders a list of primitives in a TVTK scene.

    NOTE(review): ``variables`` and ``refresh_rate`` are not defined here;
    presumably they are provided by the ``Viewer`` base class -- confirm.
    """

    name = Str('TVTK Viewer')
    # Primitives currently shown; filled in by start().
    primitives = List(Primitive)
    scene = Instance(SceneModel, ())

    # View showing only the scene.
    view = View(
        Item(name='scene', height=400, show_label=False, editor=SceneEditor()))

    # Configuration view: name, refresh rate and the list of primitives.
    traits_view = View(Item(name='name'),
                       Item(name='refresh_rate'),
                       Item(name='primitives',
                            editor=ListEditor(),
                            style='custom'),
                       title='Viewer')

    def start(self):
        """Load the plot configuration and add every primitive's actor."""
        # Imported here rather than at module level.
        from plotconfig import TVTKconfig
        self.config = TVTKconfig(self.variables)

        self.primitives = self.config.getPrimitives()
        for prim in self.primitives:
            self.scene.add_actors(prim.actor)

    def stop(self):
        """Do nothing."""
        pass

    def show(self):
        """Do nothing."""
        pass

    def hide(self):
        """Do nothing."""
        pass

    def update(self):
        """Update every primitive, then schedule a scene render."""
        for prim in self.primitives:
            prim.update()
        # The render is deferred through GUI.invoke_later instead of being
        # called directly.
        GUI.invoke_later(self.scene.render)
Esempio n. 11
0
    def call_mlab(self,
                  scene=None,
                  show=True,
                  is_3d=False,
                  view=None,
                  roll=None,
                  parallel_projection=False,
                  fgcolor=(0.0, 0.0, 0.0),
                  bgcolor=(1.0, 1.0, 1.0),
                  colormap='blue-red',
                  layout='rowcol',
                  scalar_mode='iso_surface',
                  vector_mode='arrows_norm',
                  rel_scaling=None,
                  clamping=False,
                  ranges=None,
                  is_scalar_bar=False,
                  is_wireframe=False,
                  opacity=None,
                  subdomains_args=None,
                  rel_text_width=None,
                  fig_filename='view.png',
                  resolution=None,
                  filter_names=None,
                  only_names=None,
                  group_names=None,
                  step=None,
                  time=None,
                  anti_aliasing=None,
                  domain_specific=None):
        """
        Prepare the scene, the file source and the step controls, then
        either render the data directly or open the viewer GUI.

        By default, all data (point, cell, scalars, vectors, tensors)
        are plotted in a grid layout, except data named 'node_groups',
        'mat_id' which are usually not interesting.

        Parameters
        ----------
        scene : scene object, optional
            If not None, this scene is used (and stored in `self.scene`)
            and no GUI is created. If None, the viewer's current scene is
            reused while it is still running; otherwise a new one is
            created.
        show : bool
            If True (and the viewer is not offscreen), a newly created
            scene is shown in a `ViewerGUI` window; otherwise a plain
            `mlab.figure()` is created and rendered into.
        is_3d : bool
            If True, use scalar cut planes instead of surface for certain
            datasets. Also sets 3D view mode.
        view : tuple
            Azimuth, elevation angles, distance and focal point as in
            `mlab.view()`.
        roll : float
            Roll angle as in `mlab.roll()`.
        parallel_projection: bool
            If True, use parallel projection.
        fgcolor : tuple of floats (R, G, B)
            The foreground color, that is the color of all text
            annotation labels (axes, orientation axes, scalar bar
            labels).
        bgcolor : tuple of floats (R, G, B)
            The background color.
        colormap : str
            The colormap name.
        layout : str
            Grid layout for placing the datasets. Possible values are:
            'row', 'col', 'rowcol', 'colrow'.
        scalar_mode : str or sequence of str
             Mode for plotting scalars and tensor magnitudes, one of
             'cut_plane', 'iso_surface', 'both'. A sequence containing
             only 'cut_plane' and/or 'iso_surface' is also accepted.
        vector_mode : str or sequence of str
             Mode for plotting vectors, one of 'arrows', 'norm', 'arrows_norm',
             'warp_norm', or 'cut_plane' (which falls back to 'arrows'
             unless `is_3d` is True). A sequence containing only 'arrows',
             'norm' and/or 'warp' is also accepted.
        rel_scaling : float
            Relative scaling of glyphs for vector datasets.
        clamping : bool
            Clamping for vector datasets.
        ranges : dict
            List of data ranges in the form {name : (min, max), ...}.
        is_scalar_bar : bool
            If True, show a scalar bar for each data.
        is_wireframe : bool
            If True, show a wireframe of mesh surface bar for each data.
        opacity : float
            Global surface and wireframe opacity setting in [0.0, 1.0].
        subdomains_args : tuple
            Tuple of (mat_id_name, threshold_limits, single_color), see
            :func:`add_subdomains_surface`, or None.
        rel_text_width : float
            Relative text width. If None, 0.02 is used.
        fig_filename : str
            File name for saving the resulting scene figure.
        resolution : tuple
            Scene and figure resolution. If None, it is set
            automatically according to the layout.
        filter_names : list of strings
            Omit the listed datasets. If None, it is initialized to
            ['node_groups', 'mat_id']. Pass [] if you need no filtering.
        only_names : list of strings
            Draw only the listed datasets. If None, it is initialized to
            all names besides those in filter_names.
        group_names : list of tuples
            List of data names in the form [(name1, ..., nameN), (...)]. Plots
            of data named in each group are superimposed. Repetitions of names
            are possible.
        step : int, optional
            If not None, the time step to display. The closest higher step is
            used if the desired one is not available. Has precedence over
            `time`. A negative value counts back from the last step.
        time : float, optional
            If not None, the time of the time step to display. The closest
            higher time is used if the desired one is not available.
        anti_aliasing : int
            Value of anti-aliasing.
        domain_specific : dict
            Domain-specific drawing functions and configurations.

        Returns
        -------
        gui : ViewerGUI or None
            The GUI object in use, or None when the data were rendered
            without a GUI.

        Raises
        ------
        ValueError
            If `scalar_mode` or `vector_mode` contains an unknown mode.
        """
        self.fgcolor = fgcolor
        self.bgcolor = bgcolor
        self.colormap = colormap

        # Fill in the defaults of the optional arguments.
        if filter_names is None:
            filter_names = ['node_groups', 'mat_id']

        if rel_text_width is None:
            rel_text_width = 0.02

        # Normalize scalar_mode to a tuple of elementary mode names; a
        # non-string value is only validated, not converted.
        if isinstance(scalar_mode, basestr):
            if scalar_mode == 'both':
                scalar_mode = ('cut_plane', 'iso_surface')
            elif scalar_mode in ('cut_plane', 'iso_surface'):
                scalar_mode = (scalar_mode, )
            else:
                raise ValueError('bad value of scalar_mode parameter! (%s)' %
                                 scalar_mode)
        else:
            for sm in scalar_mode:
                if not sm in ('cut_plane', 'iso_surface'):
                    raise ValueError(
                        'bad value of scalar_mode parameter! (%s)' % sm)

        # Same normalization for vector_mode; the combined names expand to
        # two elementary modes.
        if isinstance(vector_mode, basestr):
            if vector_mode == 'arrows_norm':
                vector_mode = ('arrows', 'norm')
            elif vector_mode == 'warp_norm':
                vector_mode = ('warp', 'norm')
            elif vector_mode in ('arrows', 'norm'):
                vector_mode = (vector_mode, )
            elif vector_mode == 'cut_plane':
                # Cut planes are only kept for 3D data.
                if is_3d:
                    vector_mode = ('cut_plane', )
                else:
                    vector_mode = ('arrows', )
            else:
                raise ValueError('bad value of vector_mode parameter! (%s)' %
                                 vector_mode)
        else:
            for vm in vector_mode:
                if not vm in ('arrows', 'norm', 'warp'):
                    raise ValueError(
                        'bad value of vector_mode parameter! (%s)' % vm)

        mlab.options.offscreen = self.offscreen

        self.size_hint = self.get_size_hint(layout, resolution=resolution)

        # Set below whenever self.scene is replaced by a different scene.
        is_new_scene = False

        if scene is not None:
            # An explicitly passed scene is used as is, without a GUI.
            if scene is not self.scene:
                is_new_scene = True
                self.scene = scene
            gui = None

        else:
            # Forget a previous scene that is no longer running.
            if (self.scene is not None) and (not self.scene.running):
                self.scene = None

            if self.scene is None:
                if self.offscreen or not show:
                    gui = None
                    scene = mlab.figure(fgcolor=fgcolor,
                                        bgcolor=bgcolor,
                                        size=self.size_hint)

                else:
                    gui = ViewerGUI(viewer=self,
                                    fgcolor=fgcolor,
                                    bgcolor=bgcolor)
                    scene = gui.scene.mayavi_scene

                if scene is not self.scene:
                    is_new_scene = True
                    self.scene = scene

            else:
                # Reuse the running scene together with its GUI.
                gui = self.gui
                scene = self.scene

        self.engine = mlab.get_engine()
        self.engine.current_scene = self.scene

        self.gui = gui

        self.file_source = create_file_source(self.filename,
                                              watch=self.watch,
                                              offscreen=self.offscreen)
        steps, times = self.file_source.get_ts_info()
        has_several_times = len(times) > 1
        has_several_steps = has_several_times or (len(steps) > 1)

        if gui is not None:
            gui.has_several_steps = has_several_steps

        self.reload_source = reload_source = ReloadSource()
        reload_source._viewer = self
        reload_source._source = self.file_source

        if has_several_steps:
            self.set_step = set_step = SetStep()
            set_step._viewer = self
            set_step._source = self.file_source
            if step is not None:
                # A negative step counts back from the last available step.
                step = step if step >= 0 else steps[-1] + step + 1
                assert_(steps[0] <= step <= steps[-1],
                        msg='invalid time step! (%d <= %d <= %d)' %
                        (steps[0], step, steps[-1]))
                set_step.step = step

            elif time is not None:
                assert_(times[0] <= time <= times[-1],
                        msg='invalid time! (%e <= %e <= %e)' %
                        (times[0], time, times[-1]))
                set_step.time = time

            else:
                set_step.step = steps[0]

            if self.watch:
                self.file_source.setup_notification(set_step, 'file_changed')

            if gui is not None:
                gui.set_step = set_step

        else:
            if self.watch:
                self.file_source.setup_notification(reload_source,
                                                    'reload_source')

        # NOTE(review): get_arguments() presumably collects this call's
        # argument values as they are at this point, i.e. after the
        # defaulting and normalization above -- confirm before renaming or
        # reordering any of the argument handling.
        self.options.update(get_arguments(omit=['self', 'file_source']))

        if gui is None:
            # No GUI: render straight into the scene.
            self.render_scene(scene, self.options)
            self.reset_view()
            if is_scalar_bar:
                self.show_scalar_bars(self.scalar_bars)

        else:
            # GUI layout: the scene, the step controls (only when set_step
            # is defined), snapshot/animation buttons (enabled only with
            # several steps) and the reload/view/quit buttons.
            traits_view = View(
                Item(
                    'scene',
                    editor=SceneEditor(scene_class=MayaviScene),
                    show_label=False,
                    width=self.size_hint[0],
                    height=self.size_hint[1],
                    style='custom',
                ),
                Group(
                    Item('set_step',
                         defined_when='set_step is not None',
                         show_label=False,
                         style='custom'), ),
                HGroup(
                    spring,
                    Item('button_make_snapshots_steps',
                         show_label=False,
                         enabled_when='has_several_steps == True'),
                    Item('button_make_animation_steps',
                         show_label=False,
                         enabled_when='has_several_steps == True'),
                    spring,
                    Item('button_make_snapshots_times',
                         show_label=False,
                         enabled_when='has_several_steps == True'),
                    Item('button_make_animation_times',
                         show_label=False,
                         enabled_when='has_several_steps == True'),
                    spring,
                ),
                HGroup(spring, Item('button_reload', show_label=False),
                       Item('button_view', show_label=False),
                       Item('button_quit', show_label=False)),
                resizable=True,
                buttons=[],
                handler=ClosingHandler(),
            )

            # The window is only opened for a newly created scene.
            if is_new_scene:
                if show:
                    gui.configure_traits(view=traits_view)

                else:
                    gui.edit_traits(view=traits_view)

        return gui
Esempio n. 12
0
class DLS_Data(HasTraits):
    """Interactive selection and averaging of correlation curves.

    Column 0 of ``data`` is the common x axis and every further column is
    one curve. ``plot_data`` draws the curves in a Mayavi scene with one
    pickable glyph per curve; clicking a glyph toggles whether that curve
    takes part in the average computed by ``calculate``.
    """

    scene = Instance(MlabSceneModel, ())

    # Column 0: x axis; columns 1..n: one curve each.
    data = Array
    # Column 2 * i + 1 is averaged to weight curve i in calculate().
    cr = Array
    # If True, each curve is shifted and scaled to [0, 1] before plotting.
    renormalized = Bool(True)
    corr_fig = Instance(Figure2D, ())
    # Result of calculate(): column 0 is the x axis, column 1 the average.
    correlation = Array

    possible_usable_data = List(Str)
    usable_data = List(Str)

    view = View(Item('scene',
                     editor=SceneEditor(scene_class=MayaviScene),
                     height=250,
                     width=300,
                     show_label=False),
                'corr_fig',
                handler=DLS_DataHandler(),
                resizable=True)

    def plot_data(self):
        """Draw all curves and install the glyph picking callback.

        Each curve is plotted as a 3D line at its own z offset, labelled
        with its index, and gets a glyph at its last point. Picking a glyph
        toggles the curve between selected and deselected (its scalar is
        set to ``hi_color`` and its line shifted by -1 in y, or restored)
        and recomputes the average.
        """
        figure = self.scene.mlab.gcf()
        self.scene.mlab.clf()
        # Suspend rendering while the scene is being populated.
        figure.scene.disable_render = True
        self.lines = []
        x = np.log10(self.data[:, 0])
        x = x / x.max()
        # Glyphs sit at the last sample of every curve.
        index = -1
        n_curves = self.data.shape[1] - 1
        x1 = np.array([x[index]] * n_curves)
        y1 = []
        z1 = []
        self.scalars = list(range(n_curves))
        # Scalar value marking a deselected curve; it lies above every
        # regular scalar whenever there is more than one curve.
        hi_color = 2 * max(self.scalars)
        self.hi_color = hi_color

        for i in range(n_curves):
            if self.renormalized:
                y = self.data[:, 1 + i] - self.data[:, 1 + i].min()
                y = y / y.max()
            else:
                y = self.data[:, 1 + i]
            y1.append(y[index])
            z = np.ones(self.data.shape[0]) * i * 1.0 / 10
            z1.append(z[index])

            self.lines.append(
                self.scene.mlab.plot3d(x,
                                       y,
                                       z, [self.scalars[i]] *
                                       (self.data.shape[0]),
                                       vmin=0,
                                       vmax=hi_color))
            self.scene.mlab.text3d(x[-1] + 0.1,
                                   y[-1],
                                   z[-1] + 0.05,
                                   str(i),
                                   scale=0.1)
        y1 = np.array(y1)
        z1 = np.array(z1)

        red_glyphs = self.scene.mlab.points3d(x1,
                                              y1,
                                              z1,
                                              self.scalars,
                                              scale_mode='none',
                                              scale_factor=0.1,
                                              vmin=0,
                                              vmax=hi_color,
                                              resolution=20)

        self.red_glyphs = red_glyphs
        outline = self.scene.mlab.outline(line_width=3)
        outline.outline_mode = 'cornered'
        outline.bounds = (x1[0] - 0.01, x1[0] + 0.01, y1[0] - 0.01,
                          y1[0] + 0.01, z1[0] - 0.01, z1[0] + 0.01)

        figure.scene.disable_render = False
        glyph_points = (red_glyphs.glyph.glyph_source.glyph_source.output
                        .points.to_array())

        def picker_callback(picker):
            """Toggle the curve whose glyph was picked and recompute."""
            if picker.actor not in red_glyphs.actor.actors:
                return
            # Each data point is drawn as a glyph made of several points,
            # so the picked point id has to be mapped back to a glyph
            # index. Floor division keeps the result an integer (usable as
            # an index) and maps the "nothing picked" id -1 to -1.
            point_id = picker.point_id // glyph_points.shape[0]
            if point_id == -1:
                return

            # Move the outline to the picked data point.
            x, y, z = x1[point_id], y1[point_id], z1[point_id]
            outline.bounds = (x - 0.01, x + 0.01, y - 0.01, y + 0.01,
                              z - 0.01, z + 0.01)

            scalars = red_glyphs.mlab_source.scalars
            if scalars[point_id] != hi_color:
                # Deselect: mark the glyph and shift the line.
                scalars[point_id] = hi_color
                self.lines[point_id].mlab_source.y -= 1.
            else:
                # Reselect: restore the original scalar and position.
                scalars[point_id] = self.scalars[point_id]
                self.lines[point_id].mlab_source.y += 1.

            red_glyphs.mlab_source.scalars = scalars
            self.calculate()

        picker = figure.on_mouse_pick(picker_callback)
        picker.tolerance = 0.01

    def calculate(self):
        """Calculate the average correlation of the selected curves.

        Curves whose glyph scalar equals ``hi_color`` are skipped. The
        result is stored in ``self.correlation`` and plotted on a
        logarithmic x axis in ``corr_fig``.

        Returns
        -------
        corr : array
            The averaged correlation, one value per row of ``data``.

        Raises
        ------
        ZeroDivisionError
            If no curve is selected (the sum of the weights is then zero).
        """
        corr = 0.
        n = 0
        cr_sum = 0.

        for i in range((self.data.shape[1] - 1)):
            if self.red_glyphs.mlab_source.scalars[i] != self.hi_color:
                print(i)
                cr_mean = self.cr[:, 2 * i + 1].mean()
                n += 1
                cr_sum += cr_mean
                corr += (self.data[:, 1 + i] + 1.) * cr_mean**2.
        corr = corr * n / cr_sum**2. - 1.
        self.correlation = np.empty(shape=(self.data.shape[0], 2),
                                    dtype='float')
        self.correlation[:, 1] = corr
        self.correlation[:, 0] = self.data[:, 0]
        self.corr_fig.ax.cla()
        self.corr_fig.ax.semilogx(self.data[:, 0], corr)
        self.corr_fig.update = True
        return corr

    def save(self, fname):
        """Save ``self.correlation`` to ``fname`` with ``np.save``."""
        np.save(fname, self.correlation)
class FieldExplorer(HasTraits):
    """Explore the field of a wire loop on an interactive probe plane.

    The points of a plane widget are passed through a programmable filter
    that evaluates ``calc_wire_B_field`` for the current ``wire``; the
    resulting vectors are drawn as cone glyphs, scaled and coloured by the
    vector.
    """

    scene = Instance(SceneModel, ())
    wire = Instance(WireLoop)

    # Switches the plane widget on and off.
    interact = Bool(False)
    ipl = Instance(tvtk.PlaneWidget, (), {
        'resolution': 50,
        'normal': [1., 0., 0.]
    })
    #plane_src = Instance(tvtk.PlaneSource, ())
    calc_B = Instance(tvtk.ProgrammableFilter, ())

    glyph = Instance(tvtk.Glyph3D, (), {'scale_factor': 0.02})
    scale_factor = DelegatesTo("glyph")

    lm = Instance(LUTManager, ())

    traits_view = View(HSplit(
        Item("scene", style="custom", editor=SceneEditor(), show_label=False),
        VGroup(Item("wire", style="custom", show_label=False),
               Item("interact"), Item("scale_factor"), Item("lm")),
    ),
                       resizable=True,
                       width=700,
                       height=600)

    def _interact_changed(self, i):
        """Attach the plane widget to the scene and switch it on or off."""
        self.ipl.interactor = self.scene.interactor
        self.ipl.place_widget()
        if i:
            self.ipl.on()
        else:
            self.ipl.off()

    def make_probe(self):
        """Build the plane -> field filter -> glyph -> actor pipeline."""
        src = self.ipl.poly_data_algorithm

        mapper = tvtk.PolyDataMapper(lookup_table=self.lm.lut)
        act = tvtk.Actor(mapper=mapper)

        calc_B = self.calc_B
        # NOTE(review): assigning a data object to ``input`` here (and to
        # ``glyph.source`` below) looks like the older VTK pipeline API --
        # confirm it is supported by the VTK version in use.
        calc_B.input = src.output

        def execute():
            """Fill the filter output with the field vectors of the wire."""
            # Parenthesised so the statement is valid on Python 2 and 3.
            print("calc fields!")
            output = calc_B.poly_data_output
            points = output.points.to_array().astype('d')
            nodes = self.wire.nodes.astype('d')
            vectors = calc_wire_B_field(nodes, points, self.wire.radius)
            output.point_data.vectors = vectors
            # Colour range follows the field magnitude.
            mag = np.sqrt((vectors**2).sum(axis=1))
            mapper.scalar_range = (mag.min(), mag.max())

        calc_B.set_execute_method(execute)

        cone = tvtk.ConeSource(height=0.05, radius=0.01, resolution=15)
        cone.update()

        glyph = self.glyph
        glyph.input_connection = calc_B.output_port
        glyph.source = cone.output
        glyph.scale_mode = 'scale_by_vector'
        glyph.color_mode = 'color_by_vector'

        mapper.input_connection = glyph.output_port
        self.scene.add_actor(act)

    def on_update(self):
        """Mark the field filter as modified and re-render the scene."""
        self.calc_B.modified()
        self.scene.render()

    def _wire_changed(self, anew):
        """Listen for updates of the new wire and add its actor."""
        anew.on_trait_change(self.on_update, "update")
        self.scene.add_actor(anew.actor)
Esempio n. 14
0
class OverlayMap(HasTraits):
    """
    Display an fMRI volume and a stat-map overlay with mayavi, using
    three image cut planes (one per axis) for each of the two volumes.
    """

    # Main scene
    scene = Instance(MlabSceneModel, ())

    # The image plane widgets and the lookup tables they use
    overlays = List(Instance(ImagePlaneWidget))
    underlays = List(Instance(ImagePlaneWidget))
    over_lut = Instance(HasTraits)
    under_lut = Instance(HasTraits)

    # Lower end of the overlay lookup table range (slider + its bounds)
    _over_low_min = Float(0.0)
    _over_low_max = Float(0.1)
    over_low = Range(low='_over_low_min', high='_over_low_max',
                     value=0.0, mode='slider')

    # Upper end of the overlay lookup table range (slider + its bounds)
    _over_hi_min = Float(0.0)
    _over_hi_max = Float(0.1)
    over_hi = Range(low='_over_hi_min', high='_over_hi_max',
                    value=0.01, mode='slider')

    # Visibility toggles for the x, y and z planes
    x_visible = Bool(True)
    y_visible = Bool(True)
    z_visible = Bool(True)

    # Colormap choice for the overlay
    colormap = Enum("hot", "jet", "autumn")

    def __init__(self, under_image, over_image):
        """
        Store the underlay and overlay volumes and open the viewer.

        ``under_image`` is a NiftiImage or a filename string;
        ``over_image`` is a NiftiImage, a filename string or an ndarray
        (an ndarray is used as given).  Anything else raises ValueError.

        Example:

        stat = OverlayMap('anat.nii.gz','stat.nii.gz')
        """
        HasTraits.__init__(self)

        # Underlay: keep the image object itself.
        if isinstance(under_image, NiftiImage):
            underlay = under_image
        elif isinstance(under_image, str):
            underlay = NiftiImage(under_image)
        else:
            raise ValueError("under_image must be a NiftiImage or a file.")
        self.__under_image = underlay

        # TODO: set the extent and spacing of the under image

        # Overlay: a filename is loaded first, then handled like an image.
        if isinstance(over_image, str):
            over_image = NiftiImage(over_image)

        if isinstance(over_image, NiftiImage):
            # TODO: make sure it matches the dims of under image
            # TODO: set the extent
            # Only the transposed data array is kept.
            overlay = over_image.data.T
        elif isinstance(over_image, np.ndarray):
            # Assumed to match the dims and extent of the under image.
            overlay = over_image
        else:
            raise ValueError("over_image must be a NiftiImage, ndarray, or file.")

        # Invalid entries (NaN/inf) are masked out.
        self.__over_image = np.ma.masked_invalid(overlay)

        self.configure_traits()

    def _plane_callback1(self, widget, event):
        """InteractionEvent observer for the x plane."""
        self._update_planes(0)

    def _plane_callback2(self, widget, event):
        """InteractionEvent observer for the y plane."""
        self._update_planes(1)

    def _plane_callback3(self, widget, event):
        """InteractionEvent observer for the z plane."""
        self._update_planes(2)

    def _update_planes(self, num):
        """Copy the placement of underlay plane ``num`` onto its overlay."""
        # TODO: it may make more sense to do this in the callback for
        # each individual plane when it is called instead of all at
        # once

        # from what I can tell, all these are necessary
        under_ipw = self.underlays[num].ipw
        under_ipw.update_traits()
        over_ipw = self.overlays[num].ipw
        for attr in ('origin', 'point1', 'point2'):
            setattr(over_ipw, attr, getattr(under_ipw, attr))
        over_ipw.update_traits()
        over_ipw.update_placement()

    @on_trait_change('scene.activated')
    def _create_plot(self):
        """Create the six image plane widgets and initialise the sliders."""
        mlab = self.scene.mlab
        pipeline = mlab.pipeline

        # Scalar fields for both volumes, with invalid values filled by 0.
        over = pipeline.scalar_field(
            np.ma.masked_invalid(self.__over_image).filled(0))
        under = pipeline.scalar_field(
            np.ma.masked_invalid(self.__under_image.data.T).filled(0))

        # One (orientation, interaction callback) pair per axis.
        plane_specs = (('x_axes', self._plane_callback1),
                       ('y_axes', self._plane_callback2),
                       ('z_axes', self._plane_callback3))

        self.underlays = []
        self.overlays = []
        for orient, on_interaction in plane_specs:
            # --- underlay -------------------------------------------------
            # NOTE: ``under`` is rebound here, so from the second pass on
            # the previous widget (not the scalar field) is the first
            # argument of image_plane_widget.
            # TODO: fix the slice_index, which is a hack
            under = pipeline.image_plane_widget(under, colormap='gray',
                                                slice_index=92,
                                                plane_opacity=0,
                                                plane_orientation=orient)
            under.ipw.user_controlled_lookup_table = True
            under_plane_lut = under.module_manager.scalar_lut_manager.lut
            if self.under_lut is None:
                # first plane: remember its lookup table
                self.under_lut = under_plane_lut
            else:
                # later planes: reuse the remembered table
                under_plane_lut.table = self.under_lut.table

            # follow interactive movement of this plane
            under.ipw.add_observer("InteractionEvent", on_interaction)
            self.underlays.append(under)

            # --- overlay --------------------------------------------------
            # ``over`` is rebound in the same way as ``under`` above.
            # TODO: fix the slice_index, which is a hack
            over = pipeline.image_plane_widget(over,
                                               colormap=self.colormap,
                                               slice_index=92,
                                               plane_opacity=0,
                                               plane_orientation=orient)
            over.ipw.user_controlled_lookup_table = True
            over_plane_lut = over.module_manager.scalar_lut_manager.lut
            if self.over_lut is None:
                # first plane: ramp the alpha of the lowest 40 entries
                table = over_plane_lut.table.to_array()
                table[:40, -1] = np.linspace(0, 255, 40)
                over_plane_lut.table = table
                self.over_lut = over.module_manager.scalar_lut_manager.lut
            else:
                over_plane_lut.table = self.over_lut.table

            # the overlay itself is not interactive
            over.ipw.interaction = False
            self.overlays.append(over)

        # Statistics of the overlay, ignoring invalid values.
        valid = np.ma.masked_invalid(self.__over_image)
        over_min = valid.min()
        over_max = valid.max()
        over_mean = valid.mean()
        print("%s %s %s %s" % ("mmm:", over_min, over_max, over_mean))

        # upper threshold: bounds span the data, start at the maximum
        self._over_hi_min = float(over_min)
        self._over_hi_max = float(over_max)
        self.over_hi = over_max

        # lower threshold: bounds span the data, start at the mean
        self._over_low_min = float(over_min)
        self._over_low_max = float(over_max)
        self.over_low = over_mean

    def _over_hi_changed(self):
        """Keep ``over_low`` <= ``over_hi``, otherwise push the new range."""
        if self.over_hi < self.over_low:
            # drag the lower threshold down with the upper one
            self.over_low = self.over_hi
            return
        self._update_overlay_range()

    def _over_low_changed(self):
        """Keep ``over_hi`` >= ``over_low``, otherwise push the new range."""
        if self.over_low > self.over_hi:
            # drag the upper threshold up with the lower one
            self.over_hi = self.over_low
            return
        self._update_overlay_range()

    def _update_overlay_range(self):
        """Apply (over_low, over_hi) as data range of every overlay LUT."""
        # XXX: Do I need to copy here?
        first_manager = self.overlays[0].module_manager.scalar_lut_manager
        new_range = first_manager.data_range.copy()
        new_range[0] = self.over_low
        new_range[1] = self.over_hi
        for overlay in self.overlays:
            overlay.module_manager.scalar_lut_manager.data_range = new_range

    def _x_visible_changed(self):
        """Show or hide the x planes."""
        self._update_plane_visible(0, self.x_visible)

    def _y_visible_changed(self):
        """Show or hide the y planes."""
        self._update_plane_visible(1, self.y_visible)

    def _z_visible_changed(self):
        """Show or hide the z planes."""
        self._update_plane_visible(2, self.z_visible)

    def _update_plane_visible(self, plane_id, bool_val):
        """Set the visibility of underlay and overlay plane ``plane_id``."""
        for planes in (self.underlays, self.overlays):
            planes[plane_id].visible = bool_val

    def _colormap_changed(self):
        """Print the selected colormap; applying it is not implemented."""
        print(self.colormap)
        for o in self.overlays:
            #TODO: Change colormap, don't know how exactly
            pass

    # define the view
    view = View(
        VSplit(
            Group(Item('scene', editor=SceneEditor(scene_class=MayaviScene),
                       height=500, width=500, show_label=False)),
            Group(
                Group(Item('over_low', label="Lower Thresh"),
                      Item('over_hi', label="Upper Thresh"),
                      label="Overlay Properties",
                      show_border=True),
            ),
            Group(
                HGroup(Item('x_visible'),
                       Item('y_visible'),
                       Item('z_visible'),
                       Item('colormap'),
                       label="Plane visibility + colormap",
                       show_border=True),
            ),
        ),
        resizable=True,
        title='Overlay Viewer')
# -*- coding: utf-8 -*-
Esempio n. 16
0
class GeometryViewer(HasTraits):
    """Mayavi view of a ``Geometry``: its surfaces and its bodies.

    Surfaces are drawn as section lines plus a mesh, bodies as short tube
    segments.  The plot is rebuilt whenever ``geometry`` is replaced or its
    ``refresh_geometry_view`` trait fires.
    """
    meridional = Range(1, 30, 6)
    transverse = Range(0, 30, 11)
    scene = Instance(MlabSceneModel, ())
    geometry = Instance(Geometry)
    bodies = Property(List(Instance(BodyViewer)), depends_on='geometry.refresh_geometry_view,geometry.bodies[]')
    surfaces = Property(List(Instance(SurfaceViewer)), depends_on='geometry.refresh_geometry_view,geometry.surfaces[]')

    @cached_property
    def _get_surfaces(self):
        """Return one ``SurfaceViewer`` per surface (empty without geometry)."""
        if self.geometry is None:
            return []
        return [SurfaceViewer(surface=surface)
                for surface in self.geometry.surfaces]

    @cached_property
    def _get_bodies(self):
        """Return one ``BodyViewer`` per body (empty without geometry)."""
        if self.geometry is None:
            return []
        return [BodyViewer(body=body) for body in self.geometry.bodies]

    @staticmethod
    def _chord_endpoints(section):
        """Return a (2, 3) array: leading edge and chord end of ``section``.

        The chord end is the leading edge moved by ``chord`` in the x-z
        plane, rotated by ``angle`` (given in degrees).
        """
        pts = numpy.empty((2, 3))
        pts[:] = section.leading_edge
        pts[1, 0] += section.chord * numpy.cos(section.angle * numpy.pi / 180)
        pts[1, 2] -= section.chord * numpy.sin(section.angle * numpy.pi / 180)
        return pts

    def section_points(self, sections, yduplicate):
        """Return the chord end points of all ``sections`` as one array.

        The result has two rows (leading edge, chord end) per section.  When
        ``yduplicate`` is finite, the same points mirrored as
        ``y -> yduplicate - y`` are appended after the originals, in the
        same order.

        ``yduplicate`` is tested with ``numpy.isfinite`` (as ``update_plot``
        does for bodies); an identity test against ``numpy.nan`` would treat
        any NaN that is not that exact object as a real mirror plane.
        """
        ret = numpy.empty((2 * len(sections), 3))
        for sno, section in enumerate(sections):
            ret[2 * sno:2 * sno + 2, :] = self._chord_endpoints(section)
        if len(sections) > 0 and numpy.isfinite(yduplicate):
            mirrored = ret.copy()
            mirrored[:, 1] = yduplicate - mirrored[:, 1]
            ret = numpy.concatenate((ret, mirrored))
        return ret

    def section_points_old(self, sections, yduplicate):
        """Older variant of ``section_points``.

        Rows are interleaved per section: the two original points followed
        directly by their two mirrored points (when ``yduplicate`` is
        finite).  ``numpy.concatenate`` raises for an empty ``sections``.
        """
        mirror = numpy.isfinite(yduplicate)
        ret = []
        for section in sections:
            pt = self._chord_endpoints(section)
            ret.append(pt)
            if mirror:
                pt2 = numpy.copy(pt)
                pt2[:, 1] = yduplicate - pt2[:, 1]
                ret.append(pt2)
        return numpy.concatenate(ret)

    def __init__(self, **args):
        # Do not forget to call the parent's __init__
        HasTraits.__init__(self, **args)
        self.update_plot()

    @on_trait_change('geometry,geometry.refresh_geometry_view')
    def update_plot(self):
        """Clear the scene and redraw every surface and body."""
        mlab = self.scene.mlab
        self.plots = []
        mlab.clf()

        for surface in self.surfaces:
            # one line per section (no tube), then the surface mesh itself
            for section_pt in surface.sectiondata:
                self.plots.append(mlab.plot3d(section_pt[:, 0], section_pt[:, 1], section_pt[:, 2], tube_radius=None))
            mesh_data = surface.surfacedata
            self.plots.append(mlab.mesh(mesh_data[:, :, 0], mesh_data[:, :, 1], mesh_data[:, :, 2]))

        for body in self.bodies:
            width = (body.data_props[2] - body.data_props[1]) / body.num_pts * 0.15
            for data in body.bodydata:
                # short tube of radius data[3], centred on data[:3] and
                # extending width/2 to either side along x
                c = numpy.empty((2, 3))
                c[:] = data[:3]
                c[0, 0] -= width / 2
                c[1, 0] += width / 2
                self.plots.append(mlab.plot3d(c[:, 0], c[:, 1], c[:, 2], tube_radius=data[3], tube_sides=24))
                if numpy.isfinite(body.body.yduplicate):
                    # mirrored copy of the same segment
                    c[:, 1] = body.body.yduplicate - c[:, 1]
                    self.plots.append(mlab.plot3d(c[:, 0], c[:, 1], c[:, 2], tube_radius=data[3], tube_sides=24))

    def update_plot_old(self):
        """Older redraw: only the section chords, as tubes of radius 0.1."""
        mlab = self.scene.mlab
        self.plots = []
        mlab.clf()
        for surface in self.geometry.surfaces:
            section_pts = self.section_points(surface.sections, surface.yduplicate)
            # consecutive row pairs form one chord line each
            for i in range(0, section_pts.shape[0], 2):
                self.plots.append(mlab.plot3d(section_pts[i:i + 2, 0], section_pts[i:i + 2, 1], section_pts[i:i + 2, 2], tube_radius=0.1))
        print("%s %s" % ('numplots = ', len(self.plots)))

    # the layout of the dialog created
    view = View(Item('scene', editor=SceneEditor(scene_class=MayaviScene),
                    height=250, width=300, show_label=False),
                #Item('geometry', style='custom'),
                resizable=True
                )
Esempio n. 17
0
class VTKDataObject(DataObject):
    """Data object that shows a ``YTScene`` together with action buttons.

    Each button handler forwards to the matching ``yt_scene`` method or
    opens an editor dialog.
    """
    yt_scene = Instance(YTScene)
    scene = DelegatesTo("yt_scene")

    # One button per action offered below the scene
    add_contours = Button
    add_isocontour = Button
    add_x_plane = Button
    add_y_plane = Button
    add_z_plane = Button
    edit_camera = Button
    edit_operators = Button
    edit_pipeline = Button
    center_on_max = Button

    operators = DelegatesTo("yt_scene")

    # Main view: the scene on top, the row of buttons underneath
    traits_view = View(
        Item("scene",
             editor=SceneEditor(scene_class=DecoratedScene),
             resizable=True,
             show_label=False),
        HGroup(
            Item("add_contours", show_label=False),
            Item("add_isocontour", show_label=False),
            Item("add_x_plane", show_label=False),
            Item("add_y_plane", show_label=False),
            Item("add_z_plane", show_label=False),
            Item("edit_camera", show_label=False),
            Item("edit_operators", show_label=False),
            Item("edit_pipeline", show_label=False),
            Item("center_on_max", show_label=False),
        ),
    )

    # Secondary view: the operators list shown as notebook pages
    operators_edit = View(
        Item("operators",
             style='custom',
             show_label=False,
             editor=ListEditor(editor=InstanceEditor(),
                               use_notebook=True),
             name="Edit Operators"),
        height=500.0,
        width=500.0,
        resizable=True)

    def _edit_camera_fired(self):
        """Open the traits editor of the scene's camera path."""
        self.yt_scene.camera_path.edit_traits()

    def _edit_operators_fired(self):
        """Open the ``operators_edit`` view of this object."""
        self.edit_traits(view='operators_edit')

    def _edit_pipeline_fired(self):
        """Show a TVTK pipeline browser for the scene."""
        from enthought.tvtk.pipeline.browser import PipelineBrowser
        PipelineBrowser(self.scene).show()

    def _add_contours_fired(self):
        """Forward to ``yt_scene.add_contour``."""
        self.yt_scene.add_contour()

    def _add_isocontour_fired(self):
        """Forward to ``yt_scene.add_isocontour``."""
        self.yt_scene.add_isocontour()

    def _add_x_plane_fired(self):
        """Forward to ``yt_scene.add_x_plane``."""
        self.yt_scene.add_x_plane()

    def _add_y_plane_fired(self):
        """Forward to ``yt_scene.add_y_plane``."""
        self.yt_scene.add_y_plane()

    def _add_z_plane_fired(self):
        """Forward to ``yt_scene.add_z_plane``."""
        self.yt_scene.add_z_plane()

    def _center_on_max_fired(self):
        """Forward to ``yt_scene.do_center_on_max``."""
        self.yt_scene.do_center_on_max()
Esempio n. 18
0
class Mayavi(HasTraits):
    """Dialog with the engine tree, the selected object's editor and a scene."""

    # The scene model.
    scene = Instance(MlabSceneModel, ())

    # The mayavi engine view.
    engine_view = Instance(EngineView)

    # The current selection in the engine tree view.
    current_selection = Property

    # Left: engine tree above the editor of the selected object.
    # Right: the scene itself.
    view = View(
        HSplit(
            VSplit(
                Item(name='engine_view',
                     style='custom',
                     resizable=True,
                     show_label=False),
                Item(name='current_selection',
                     editor=InstanceEditor(),
                     enabled_when='current_selection is not None',
                     style='custom',
                     springy=True,
                     show_label=False),
            ),
            Item(name='scene',
                 editor=SceneEditor(),
                 show_label=False,
                 resizable=True,
                 height=500,
                 width=500),
        ),
        resizable=True,
        scrollable=True)

    def __init__(self, **traits):
        HasTraits.__init__(self, **traits)
        engine = self.scene.engine
        self.engine_view = EngineView(engine=engine)

        # Re-announce ``current_selection`` whenever the engine's selection
        # changes.  This is probably unnecessary in Traits3 since you can
        # show the UI of a sub-object in T3.
        engine.on_trait_change(self._selection_change, 'current_selection')

        self.generate_data_mayavi()

    def generate_data_mayavi(self):
        """Shows how you can generate data using mayavi instead of mlab."""
        from enthought.mayavi.sources.api import ParametricSurface
        from enthought.mayavi.modules.api import Outline, Surface
        engine = self.scene.engine
        engine.add_source(ParametricSurface())
        # modules are created and added one after the other
        for module_class in (Outline, Surface):
            engine.add_module(module_class())

    def _selection_change(self, old, new):
        """Notify listeners that ``current_selection`` went from old to new."""
        self.trait_property_changed('current_selection', old, new)

    def _get_current_selection(self):
        """Return the engine's current selection."""
        engine = self.scene.engine
        return engine.current_selection
Esempio n. 19
0
class ImageViewer(HasTraits):
    """Show a programmatically generated image through a window/level map."""
    lm = Instance(LUTManager, ())
    source = Instance(tvtk.ProgrammableSource, ())
    mapwin = Instance(tvtk.ImageMapToWindowLevelColors, ())
    scene = Instance(SceneModel)

    traits_view = View(Item('scene', style="custom", editor=SceneEditor()),
                       Item('lm'),
                       Item('object.mapwin.level'),
                       Item('object.mapwin.window'),
                       resizable=True,
                       width=700,
                       height=600)

    def _scene_default(self):
        """Build the default scene: source -> window/level map -> image actor.

        The scalar bar of the LUT manager is added as a second actor.
        """
        scene = SceneModel()
        src = self.source

        def execute():
            # sin(r)/r sampled on a grid from -10 to 10 in steps of 0.1;
            # the 0.001 offset keeps r away from zero.
            grid_x, grid_y = np.ogrid[-10:10:0.1, -10:10:0.1]
            radius = 3 * np.sqrt(grid_x**2 + grid_y**2) + 0.001
            values = np.sin(radius) / radius
            print("%s %s" % ("shape", values.shape))
            rows, cols = values.shape[0], values.shape[1]
            image = src.structured_points_output
            image.whole_extent = [0, rows - 1, 0, cols - 1, 0, 0]
            image.point_data.scalars = values.reshape(-1, 1)

        src.set_execute_method(execute)

        # NOTE: this shift/scale filter is configured but nothing below
        # consumes its output.
        ss = tvtk.ImageShiftScale(input=src.structured_points_output,
                                  shift=127.0,
                                  scale=127.0)
        ss.set_output_scalar_type_to_unsigned_char()
        print("set source output")

        # The window/level map reads the source output directly.
        window_level = self.mapwin
        window_level.input = src.structured_points_output
        window_level.lookup_table = self.lm.lut

        image_actor = tvtk.ImageActor(input=window_level.output)
        scene.add_actor(image_actor)
        scene.add_actor(self.lm.scalar_bar)
        return scene
Esempio n. 20
0
class MultiFitGui(HasTraits):
    """
    GUI holding one FitGui per pair of data axes plus an optional 3D plot.

    data should be c x N where c is the number of data columns/axes and N is
    the number of points
    """
    doplot3d = Bool(False)
    show3d = Button('Show 3D Plot')
    replot3d = Button('Replot 3D')
    scalefactor3d = Float(0)
    do3dscale = Bool(False)
    nmodel3d = Int(1024)
    usecolor3d = Bool(False)
    color3d = Color((0,0,0))
    scene3d = Instance(MlabSceneModel,())
    plot3daxes = Tuple(('x','y','z'))
    data = Array(shape=(None,None))
    weights = Array(shape=(None,))
    curveaxes = List(Tuple(Int,Int))
    axisnames = Dict(Int,Str)
    invaxisnames = Property(Dict,depends_on='axisnames')

    fgs = List(Instance(FitGui))


    traits_view = View(VGroup(Item('fgs',editor=ListEditor(use_notebook=True,page_name='.plotname'),style='custom',show_label=False),
                              Item('show3d',show_label=False)),
                              resizable=True,height=900,buttons=['OK','Cancel'],title='Multiple Model Data Fitters')

    plot3d_view = View(VGroup(Item('scene3d',editor=SceneEditor(scene_class=MayaviScene),show_label=False,resizable=True),
                              Item('plot3daxes',editor=TupleEditor(cols=3,labels=['x','y','z']),label='Axes'),
                              HGroup(Item('do3dscale',label='Scale by weight?'),
                              Item('scalefactor3d',label='Point scale'),
                              Item('nmodel3d',label='Nmodel')),
                              HGroup(Item('usecolor3d',label='Use color?'),Item('color3d',label='Relation Color',enabled_when='usecolor3d')),
                              Item('replot3d',show_label=False),springy=True),
                       resizable=True,height=800,width=800,title='Multiple Model3D Plot')

    def __init__(self,data,names=None,models=None,weights=None,dofits=True,**traits):
        """
        :param data: The data arrays
        :type data: sequence of c equal-length arrays (length N)
        :param names:
            Axis names, either as a sequence or as one comma-separated
            string.  If None, 'x','y'(,'z') are used for 2 or 3 columns and
            the column index as a string otherwise.
        :type names: sequence of strings, length c
        :param models:
            The models to fit for each pair either as strings or
            :class:`astropysics.models.ParametricModel` objects.
        :type models: sequence of models, length c-1
        :param weights: the weights for each point or None for no weights
        :type weights: array-like of size N or None
        :param dofits:
            If True, the data will be fit to the models when the object is
            created, otherwise the models will be passed in as-is (or as
            created).
        :type dofits: bool

        :raises ValueError:
            If there are fewer than 2 data columns, or `names` / `models`
            do not match the number of columns.

        extra keyword arguments get passed in as new traits

        Example call:
        ((r[finmask],m[finmask],l[finmask]),names='rh,Mh,Lh',weights=w[finmask],models=models,dofits=False)
        """
        super(MultiFitGui,self).__init__(**traits)
        self._lastcurveaxes = None

        data = np.asarray(data)
        # Validate before assigning: setting ``data`` fires change handlers
        # that build the fit GUIs, so bad input must be rejected first.
        if data.shape[0] < 2:
            raise ValueError('Must have at least 2 columns')
        ncols = len(data)

        # weights must be in place before ``data`` is set, because the
        # handlers triggered by ``data`` read ``self.weights``
        if weights is None:
            self.weights = np.ones(data.shape[1])
        else:
            self.weights = np.array(weights)

        self.data = data

        if isinstance(names,basestring):
            names = names.split(',')
        if names is None:
            if ncols == 2:
                self.axisnames = {0:'x',1:'y'}
            elif ncols == 3:
                self.axisnames = {0:'x',1:'y',2:'z'}
            else:
                # keys are the column indices, not the data rows themselves
                self.axisnames = dict((i,str(i)) for i in range(ncols))
        elif len(names) == ncols:
            self.axisnames = dict(enumerate(names))
        else:
            raise ValueError("names don't match data")

        #default to using 0th axis as parametric
        self.curveaxes = [(0,i) for i in range(1,ncols)]
        if models is not None:
            if len(models) != ncols-1:
                raise ValueError("models don't match data")
            for i,m in enumerate(models):
                fg = self.fgs[i]
                newtmodel = TraitedModel(m)
                if dofits:
                    fg.tmodel = newtmodel
                    fg.fitmodel = True #should happen automatically, but this makes sure
                else:
                    # keep the model's parameters as passed in
                    oldpard = newtmodel.model.pardict
                    fg.tmodel = newtmodel
                    fg.tmodel.model.pardict = oldpard
                # NOTE(review): fitmodel is set a second time when dofits is
                # True; kept as-is because the effect of a repeated set is
                # not visible from here - confirm whether it is needed.
                if dofits:
                    fg.fitmodel = True

    def _label_fit_axes(self,ax,fg):
        """Title the x/y plot axes of `fg` with the names of axis pair `ax`."""
        fg.plot.x_axis.title = self.axisnames.get(ax[0],'')
        fg.plot.y_axis.title = self.axisnames.get(ax[1],'')

    def _data_changed(self):
        # new data: pair every other axis with axis 0
        self.curveaxes = [(0,i) for i in range(len(self.data))[1:]]

    def _axisnames_changed(self):
        for ax,fg in zip(self.curveaxes,self.fgs):
            self._label_fit_axes(ax,fg)
        # with only two names, the second one is reused for the third axis
        self.plot3daxes = (self.axisnames[0],self.axisnames[1],self.axisnames[2] if len(self.axisnames) > 2 else self.axisnames[1])

    @on_trait_change('curveaxes[]')
    def _curveaxes_update(self,names,old,new):
        """Rebuild the FitGuis affected by a change of ``curveaxes``.

        If the pairs do not cover every data axis, ``curveaxes`` is reset to
        the last accepted value instead.
        """
        used = set()
        for t in self.curveaxes:
            used.add(t[0])
            used.add(t[1])
        if used != set(range(len(self.data))):
            self.curveaxes = self._lastcurveaxes
            return #TODO:check for recursion

        if self._lastcurveaxes is None:
            # first valid assignment: create every FitGui
            self.fgs = [FitGui(self.data[t[0]],self.data[t[1]],weights=self.weights) for t in self.curveaxes]
            for ax,fg in zip(self.curveaxes,self.fgs):
                self._label_fit_axes(ax,fg)
        else:
            # only replace the FitGuis whose axis pair changed
            for i,t in enumerate(self.curveaxes):
                if self._lastcurveaxes[i] != t:
                    self.fgs[i] = fg = FitGui(self.data[t[0]],self.data[t[1]],weights=self.weights)
                    self._label_fit_axes(t,fg)

        # Store a snapshot, not the live list: otherwise in-place edits of
        # ``curveaxes`` would compare equal to themselves and never be seen.
        self._lastcurveaxes = list(self.curveaxes)

    def _show3d_fired(self):
        self.edit_traits(view='plot3d_view')
        self.doplot3d = True
        self.replot3d = True

    def _plot3daxes_changed(self):
        self.replot3d = True

    @on_trait_change('weights',post_init=True)
    def weightsChanged(self):
        for fg in self.fgs:
            if fg.weighttype != 'custom':
                fg.weighttype = 'custom'
            fg.weights = self.weights


    # All names go in ONE comma-separated string: the decorator's second
    # positional argument is ``post_init``, not another trait name.
    @on_trait_change('data,fgs,replot3d,weights',post_init=True)
    def _do_3d(self):
        """Redraw the 3D scene (points, axes and the fitted relation)."""
        if self.doplot3d:
            M = self.scene3d.mlab
            try:
                xi = self.invaxisnames[self.plot3daxes[0]]
                yi = self.invaxisnames[self.plot3daxes[1]]
                zi = self.invaxisnames[self.plot3daxes[2]]

                x,y,z = self.data[xi],self.data[yi],self.data[zi]
                w = self.weights

                M.clf()
                if self.scalefactor3d == 0:
                    # no scale chosen yet: derive one from the data extent
                    sf = x.max()-x.min()
                    sf *= y.max()-y.min()
                    sf *= z.max()-z.min()
                    sf = sf/len(x)/5
                    self.scalefactor3d = sf
                else:
                    sf = self.scalefactor3d
                glyph = M.points3d(x,y,z,w,scale_factor=sf)
                glyph.glyph.scale_mode = 0 if self.do3dscale else 1
                M.axes(xlabel=self.plot3daxes[0],ylabel=self.plot3daxes[1],zlabel=self.plot3daxes[2])

                try:
                    xs = np.linspace(np.min(x),np.max(x),self.nmodel3d)

                    #find sequence of models to go from x to y and z
                    ymods,zmods = [],[]
                    for curri,mods in zip((yi,zi),(ymods,zmods)):
                        while curri != xi:
                            for i,(i1,i2) in enumerate(self.curveaxes):
                                if curri==i2:
                                    curri = i1
                                    mods.insert(0,self.fgs[i].tmodel.model)
                                    break
                            else:
                                raise KeyError

                    ys = xs
                    for m in ymods:
                        ys = m(ys)
                    zs = xs
                    for m in zmods:
                        zs = m(zs)

                    if self.usecolor3d:
                        # float division: with integer channels ``/255``
                        # truncates to 0 or 1 under Python 2
                        c = (self.color3d[0]/255.0,self.color3d[1]/255.0,self.color3d[2]/255.0)
                        M.plot3d(xs,ys,zs,color=c)
                    else:
                        M.plot3d(xs,ys,zs,np.arange(len(xs)))
                except (KeyError,TypeError):
                    M.text(0.5,0.75,'Underivable relation')
            except KeyError:
                M.clf()
                M.text(0.25,0.25,'Data problem')



    @cached_property
    def _get_invaxisnames(self):
        """Return ``axisnames`` inverted: axis name -> axis index."""
        return dict((v,k) for k,v in self.axisnames.items())
Esempio n. 21
0
# -*- coding: utf-8 -*-