class ActorViewer(HasTraits):
    """Traits UI window embedding a Mayavi scene that shows a sinc-like surface."""

    # Mayavi scene model rendered by the SceneEditor in ``view``.
    scene = Instance(MlabSceneModel, ())

    # scene_class=MayaviScene adds a Mayavi icon to the toolbar that pops up
    # a dialog for editing the pipeline.
    view = View(
        Item(
            name='scene',
            editor=SceneEditor(scene_class=MayaviScene),
            show_label=False,
            resizable=True,
            height=500,
            width=500,
        ),
        resizable=True,
    )

    def __init__(self, **traits):
        HasTraits.__init__(self, **traits)
        # Draw the surface as soon as the viewer exists.
        self.generate_data()

    def generate_data(self):
        """Plot sin(r)/r, r = 10*sqrt(x**2 + y**2), on a 100x100 grid over [-2, 2]^2."""
        xs, ys = mgrid[-2:2:100j, -2:2:100j]
        radius = 10 * sqrt(xs**2 + ys**2)
        heights = sin(radius) / radius
        self.scene.mlab.surf(xs, ys, heights, colormap='gist_earth')
示例#2
0
 def default_traits_view(self):
     """Return the main window: the Mayavi scene above a row of playback controls."""
     scene_item = Item('scene',
                       editor=SceneEditor(scene_class=MayaviScene),
                       height=600,
                       width=600,
                       show_label=False)
     # Controls, left to right: date, spacer (presumably what Item(" ") is for),
     # day count, navigation, source/event selectors, start-day slider, stepping, play.
     controls = HGroup(
         Item("current_time", label="Date"),
         Item(" "),
         Item("num_of_shown_days", label="Show"),
         Item("_home_button", show_label=False),
         Item("_selected_source_name", show_label=False),
         Item("_selected_event_name",
              editor=CheckListEditor(name='_selected_events_list'),
              show_label=False),
         Item("_back1", show_label=False),
         # Slider bounds are read from the _low/_high_start_day_number traits.
         Item("Relative_Start_Day",
              show_label=False,
              editor=RangeEditor(mode="slider",
                                 low_name="_low_start_day_number",
                                 high_name="_high_start_day_number"),
              tooltip="Shows total number of days in data set and the currently selected day",
              springy=True,
              full_size=True),
         Item("_forward1", show_label=False),
         Item("move_step", show_label=False),
         Item("play_button", label='Play'),
     )
     return View(scene_item,
                 controls,
                 title="Visualization of Events",
                 resizable=True)
示例#3
0
文件: demo2.py 项目: purple910/python
class MyModel(HasTraits):
    """Interactive 3D curve whose shape is driven by two integer sliders."""

    # Forwarded to curve() as n_mer / n_long.
    n_meridional = Range(0, 36, 6)
    n_longitudinal = Range(0, 30, 11)

    scene = Instance(MlabSceneModel, ())

    # The plot3d pipeline object; None until update_plot() first runs.
    plot = Instance(PipelineBase)

    @on_trait_change('n_meridional,n_longitudinal,scene.activated')
    def update_plot(self):
        """Create the tube plot on first call, then update its data in place."""
        x, y, z, t = curve(n_mer=self.n_meridional, n_long=self.n_longitudinal)
        if self.plot is not None:
            # Reuse the existing pipeline instead of rebuilding it.
            self.plot.mlab_source.set(x=x, y=y, z=z, scalars=t)
            return
        self.plot = self.scene.mlab.plot3d(x, y, z, t,
                                           tube_radius=0.025,
                                           colormap='Spectral')

    view = View(
        Item('scene',
             editor=SceneEditor(scene_class=MayaviScene),
             height=250,
             width=300,
             show_label=False),
        Group('_', 'n_meridional', 'n_longitudinal'),
        resizable=True,
    )
示例#4
0
class Kit2FiffFrame(HasTraits):
    """GUI for interpolating between two KIT marker files.

    Three columns: the marker panel, the Mayavi scene above the head-view
    controls, and the Kit2Fiff panel. All panels share this frame's scene.
    """
    model = Instance(Kit2FiffModel, ())
    scene = Instance(MlabSceneModel, ())
    headview = Instance(HeadViewController)
    marker_panel = Instance(CombineMarkersPanel)
    kit2fiff_panel = Instance(Kit2FiffPanel)

    # Default factories for the Instance traits declared above.
    def _headview_default(self):
        return HeadViewController(scene=self.scene, scale=160, system='RAS')

    def _kit2fiff_panel_default(self):
        return Kit2FiffPanel(scene=self.scene, model=self.model)

    def _marker_panel_default(self):
        return CombineMarkersPanel(scene=self.scene,
                                   model=self.model.markers,
                                   trans=als_ras_trans)

    view = View(
        HGroup(
            VGroup(Item('marker_panel', style='custom'), show_labels=False),
            VGroup(
                Item('scene',
                     editor=SceneEditor(scene_class=MayaviScene),
                     dock='vertical',
                     show_label=False),
                VGroup(headview_item, show_labels=False),
            ),
            VGroup(Item('kit2fiff_panel', style='custom'), show_labels=False),
            show_labels=False,
        ),
        handler=Kit2FiffFrameHandler(),
        height=700,
        resizable=True,
        buttons=NoButtons,
    )
示例#5
0
class Visualization(HasTraits):
    """Mayavi scene showing the curve returned by curve(), with two sliders."""

    # Forwarded positionally to curve().
    meridonal = Range(1, 30, 6)
    transverse = Range(0, 30, 11)
    scene = Instance(MlabSceneModel, ())

    def __init__(self, **traits):
        # Forward trait keyword arguments so callers may set initial values
        # (e.g. Visualization(meridonal=10)); Visualization() still works.
        HasTraits.__init__(self, **traits)
        x, y, z, t = curve(self.meridonal, self.transverse)
        # Pass z (not x) as the third coordinate so the initial plot matches
        # what update_plot() draws.
        self.plot = self.scene.mlab.plot3d(x, y, z, t, colormap='Spectral')

    @on_trait_change('meridonal,transverse')
    def update_plot(self):
        """Recompute the curve and update the existing plot's data in place."""
        x, y, z, t = curve(self.meridonal, self.transverse)
        self.plot.mlab_source.set(x=x, y=y, z=z, scalars=t)

    # Layout of the dialog: the scene above a row with both sliders.
    view = View(
        Item('scene',
             editor=SceneEditor(scene_class=MayaviScene),
             height=250,
             width=300,
             show_label=False),
        HGroup(
            '_',
            'meridonal',
            'transverse',
        ),
    )
示例#6
0
class ActorViewer(HasTraits):
    """Traits UI window embedding a Mayavi scene that shows a sinc-like surface."""

    # Create the scene instance.
    scene = Instance(MlabSceneModel, ())

    # Provide the Mayavi view window.
    view = View(
        Item(name='scene',
             editor=SceneEditor(scene_class=MayaviScene),
             show_label=False,
             resizable=True,
             height=500,
             width=500),
        resizable=True,
    )

    # Override the initializer so the surface is drawn on construction.
    def __init__(self, **traits):
        HasTraits.__init__(self, **traits)
        self.generate_data()

    def generate_data(self):
        """Plot sin(r)/r, r = 10*sqrt(x**2 + y**2), on a 100x100 grid over [-2, 2]^2."""
        # Build the data.
        grid_x, grid_y = mgrid[-2:2:100j, -2:2:100j]
        r = 10 * sqrt(grid_x**2 + grid_y**2)
        # Draw the data.
        self.scene.mlab.surf(grid_x, grid_y, sin(r) / r, colormap='cool')
示例#7
0
class CombineMarkersFrame(HasTraits):
    """GUI for interpolating between two KIT marker files.

    Shows the Mayavi scene next to a column holding the head-view controls
    and the marker panel; both share this frame's scene.

    Parameters
    ----------
    mrk1, mrk2 : str
        Path to pre- and post measurement marker files (*.sqd) or empty string.
        NOTE(review): these are not declared as traits on this class; confirm
        how (or whether) they are consumed.
    """
    model = Instance(CombineMarkersModel, ())
    scene = Instance(MlabSceneModel, ())
    headview = Instance(HeadViewController)
    panel = Instance(CombineMarkersPanel)

    view = View(
        HGroup(
            Item('scene',
                 editor=SceneEditor(scene_class=MayaviScene),
                 dock='vertical'),
            VGroup(headview_borders,
                   Item('panel', style="custom"),
                   show_labels=False),
            show_labels=False,
        ),
        width=1100,
        resizable=True,
        buttons=NoButtons,
    )

    # Default factories for the Instance traits declared above.
    def _panel_default(self):
        return CombineMarkersPanel(model=self.model, scene=self.scene)

    def _headview_default(self):
        return HeadViewController(scene=self.scene, system='ALS')
示例#8
0
class MlabGui(HasTraits):
    """Mayavi scene window whose func1/func2/func3 handlers call user callables."""

    x = Bool
    scene = Instance(MlabSceneModel, ())
    # Class-level placeholder; __init__ replaces it with the caller's list.
    funcs = []

    def __init__(self, funcs):
        HasTraits.__init__(self)
        self.funcs = funcs

    def _run_func(self, index):
        """Call ``self.funcs[index]`` with no arguments."""
        self.funcs[index]()

    # Handlers receiving an ``info`` argument (presumably invoked through the
    # external ``bindings`` key bindings); ``info`` itself is not used.
    def func1(self, info):
        self._run_func(0)

    def func2(self, info):
        self._run_func(1)

    def func3(self, info):
        self._run_func(2)

    view = View(
        Item("scene",
             editor=SceneEditor(scene_class=MayaviScene),
             height=250,
             width=300,
             show_label=False),
        Item("x"),
        resizable=True,
        key_bindings=bindings,
    )
示例#9
0
class DishScene(TracerScene):
    """
    Extends TracerScene with the variables required for this example and adds
    handling of simulation-specific details, like colouring the dish elements
    and building the ray source.

    NOTE(review): the original description also mentioned "setting proper
    resolution"; no visible code here does that -- confirm whether it lives
    in TracerScene.
    """
    # Trait values read by create_dish_source(); edits to refl/concent
    # trigger recreate_dish().
    refl = t_api.Float(1., label='Edge reflections')
    concent = t_api.Float(450, label='Concentration')
    # Number of rays in the displayed bundle. Not listed in recreate_dish's
    # on_trait_change, so changing it does not rebuild the scene by itself.
    disp_num_rays = t_api.Int(10)

    def __init__(self):
        # The dish and source are built from the trait defaults *before* the
        # base-class initializer runs, because TracerScene.__init__ takes them
        # as arguments.
        dish, source = self.create_dish_source()
        TracerScene.__init__(self, dish, source)
        self.set_background((0., 0.5, 1.))

    def create_dish_source(self):
        """
        Creates the two basic elements of this simulation: the parabolic dish,
        and the pillbox-sunshape ray bundle. Uses the variables set by
        TraitsUI.

        Returns
        -------
        (dish, source) : the assembly from standard_minidish() and the ray
            bundle from solar_disk_bundle().
        """
        # W is returned by standard_minidish but not used here.
        dish, f, W, H = standard_minidish(1., self.concent, self.refl, 1., 1.)
        # Add GUI annotations to the dish assembly:
        # homogenizer surfaces are coloured red, the main reflector blue (RGB).
        for surf in dish.get_homogenizer().get_surfaces():
            surf.colour = (1., 0., 0.)
        dish.get_main_reflector().colour = (0., 0., 1.)

        # Bundle centred on the z axis at height f + H + 0.5, pointing along
        # -z. NOTE(review): 0.5 and 0.00465 are presumably the bundle radius
        # and the solar half-angle in radians -- confirm against
        # solar_disk_bundle's signature.
        source = solar_disk_bundle(self.disp_num_rays,
                                   N.c_[[0., 0., f + H + 0.5]],
                                   N.r_[0., 0., -1.], 0.5, 0.00465)
        # Split a total energy of 1000 evenly among the rays.
        source.set_energy(
            N.ones(self.disp_num_rays) * 1000. / self.disp_num_rays)

        return dish, source

    @t_api.on_trait_change('refl, concent')
    def recreate_dish(self):
        """
        Makes sure that the scene is redrawn upon dish design changes.

        Rebuilds both the dish and the source from the current trait values
        and hands them to the scene via set_assembly()/set_source().
        """
        dish, source = self.create_dish_source()
        self.set_assembly(dish)
        self.set_source(source)

    # Parameters of the form that is shown to the user:
    # the scene on top, then text fields for concentration and edge
    # reflections that are parsed with float() and applied on Enter/focus
    # loss rather than on every keystroke (auto_set=False).
    view = tui.View(
        tui.Item('_scene',
                 editor=SceneEditor(scene_class=MayaviScene),
                 height=500,
                 width=500,
                 show_label=False),
        tui.HGroup(
            '-',
            tui.Item('concent',
                     editor=tui.TextEditor(evaluate=float, auto_set=False)),
            tui.Item('refl',
                     editor=tui.TextEditor(evaluate=float, auto_set=False))))
示例#10
0
class TubeDemoApp(HasTraits):
    """Two intersecting, semi-transparent tubes with adjustable wall radii.

    Radius edits are applied to the scene when the Update button is pressed.
    """
    max_radius = Float(1.0)
    # Tube 1: inner radius, and outer radius ranging from ri1 up to max_radius.
    ri1 = Range(0.0, 1.0, 0.8)
    ro1 = Range("ri1", "max_radius", 1.0)
    # Tube 2: same layout (built with rx=90 in plot()).
    ri2 = Range(0.0, 1.0, 0.4)
    ro2 = Range("ri2", "max_radius", 0.6)
    update = Button("Update")
    scene = Instance(SceneModel, ())
    view = View(
        VGroup(
            Item(name="scene", editor=SceneEditor(scene_class=Scene)),
            HGroup("ri1", "ro1"),
            HGroup("ri2", "ro2"),
            "update",
            show_labels=False,
        ),
        resizable=True,
        height=500,
        width=500,
    )

    def __init__(self, **kw):
        super(TubeDemoApp, self).__init__(**kw)
        self.plot()

    def plot(self):
        """Build both tubes, their cut versions and their intersection line."""
        tube1, _, outer1, inner1 = make_tube(5.0, [self.ro1, self.ri1], 32)
        tube2, _, outer2, inner2 = make_tube(5.0, [self.ro2, self.ri2], 32, rx=90)
        # Presumably: each tube minus the other tube's inner cylinder.
        _, cut1_actor = difference(tube1, inner2)
        _, cut2_actor = difference(tube2, inner1)
        for translucent in (cut1_actor, cut2_actor):
            translucent.property.opacity = 0.6
        _, seam_actor = intersection(tube1, tube2)

        # Keep handles on the cylinder sources so _update_fired() can
        # change their radii later.
        self.co1 = get_source(outer1, tvtk.CylinderSource)
        self.ci1 = get_source(inner1, tvtk.CylinderSource)
        self.co2 = get_source(outer2, tvtk.CylinderSource)
        self.ci2 = get_source(inner2, tvtk.CylinderSource)
        self.scene.add_actors([cut1_actor, cut2_actor, seam_actor])

    def _update_fired(self):
        """Push the current radius traits into the sources and re-render."""
        pairs = ((self.co1, self.ro1),
                 (self.ci1, self.ri1),
                 (self.co2, self.ro2),
                 (self.ci2, self.ri2))
        for source, radius in pairs:
            source.radius = radius
        self.scene.render_window.render()

    def depth_peeling(self):
        """Enable depth peeling on the scene's renderer (for translucency)."""
        window = self.scene.render_window
        ren = self.scene.renderer
        window.alpha_bit_planes = 1
        window.multi_samples = 0
        ren.use_depth_peeling = 1
        ren.maximum_number_of_peels = 100
        ren.occlusion_ratio = 0.1
示例#11
0
 def scene_view_item(height=400, width=300):
     """
     Return a TraitsUI Item showing the '_scene' trait in a MayaviScene editor.

     Callers get a ready-made Item without importing SceneEditor and
     MayaviScene themselves; the label is hidden.

     Arguments:
     height, width - size of the tui.Item, passed directly to its constructor.
     """
     editor = SceneEditor(scene_class=MayaviScene)
     return tui.Item('_scene',
                     editor=editor,
                     height=height,
                     width=width,
                     show_label=False)
class Plot3dPane(TraitsTaskPane):
    """Task pane that plots the points of the selected 3D model as a line."""

    #### 'ITaskPane' interface ################################################

    id = "example.attractors.plot_3d_pane"
    name = "Plot 3D Pane"

    #### 'Plot3dPane' interface ###############################################

    active_model = Instance(IModel3d)
    models = List(IModel3d)

    scene = Instance(MlabSceneModel, ())

    view = View(
        HGroup(
            Label("Model: "),
            Item("active_model", editor=EnumEditor(name="_enum_map")),
            show_labels=False,
        ),
        Item(
            "scene",
            editor=SceneEditor(scene_class=MayaviScene),
            show_label=False,
        ),
        resizable=True,
    )

    #### Private traits #######################################################

    # Model -> display name, used by the EnumEditor for active_model.
    _enum_map = Dict(IModel3d, Str)

    ###########################################################################
    # Protected interface.
    ###########################################################################

    #### Trait change handlers ################################################

    @on_trait_change("active_model.points")
    def _update_scene(self):
        """Clear the scene and, if a model is active, plot its points."""
        mlab = self.scene.mlab
        mlab.clf()
        model = self.active_model
        if not model:
            return
        # Assumes points is an (N, 3) array -- TODO confirm.
        x, y, z = model.points.swapaxes(0, 1)
        mlab.plot3d(x, y, z, line_width=1.0, tube_radius=None)

    @on_trait_change("models[]")
    def _update_models(self):
        """Keep active_model valid and rebuild the EnumEditor name map."""
        if self.active_model not in self.models:
            self.active_model = self.models[0] if self.models else None
        self._enum_map = {model: model.name for model in self.models}
示例#13
0
    def default_traits_view(self):  # pylint: disable=no-self-use
        """
        Build the default traits View object for the model.

        Returns
        -------
        default_traits_view : :py:class:`traitsui.view.View`
            A View containing only the unlabeled ``scene`` item, edited
            with a SceneEditor that uses MayaviScene.
        """
        scene_item = Item(
            'scene',
            show_label=False,
            editor=SceneEditor(scene_class=MayaviScene),
        )
        return View(scene_item)
示例#14
0
    class TestScene(HasTraits):
        """Minimal tvtk scene that renders the geometry of an empty grid."""

        scene = Instance(SceneModel, ())

        view = View(Item("scene", show_label=False, editor=SceneEditor()))

        def show(self):
            """Add an actor for a 20x30x40 EmptyGridSource, then call configure_traits()."""
            grid = EmptyGridSource(dimensions=(20, 30, 40))
            # NOTE(review): this value is overwritten on the next line. The
            # access may be deliberate (exercising the ``output_port``
            # property) -- confirm before removing it.
            src_output_port = grid.output_port
            src_output_port = grid._vtk_obj.GetOutputPort()
            assert src_output_port is not None
            # grid -> geometry filter -> mapper -> actor.
            extract = tvtk.ImageDataGeometryFilter(
                input_connection=src_output_port)
            mapper = tvtk.PolyDataMapper(input_connection=extract.output_port)
            act = tvtk.Actor(mapper=mapper)
            scene = self.scene
            # NOTE(review): other code uses scene.add_actors(...); confirm
            # that SceneModel also provides add_actor.
            scene.add_actor(act)
            self.configure_traits()
示例#15
0
class ActorViewer(HasTraits):
    """Scene showing one tvtk actor or 3D widget at a time, chosen by ``actor_type``."""

    # A simple trait to change the actors/widgets.
    actor_type = Enum('cone', 'sphere', 'plane_widget', 'box_widget')

    # The scene model.
    scene = Instance(SceneModel, ())

    # Object currently shown (an actor or a widget); None before the first update.
    _current_actor = Any

    ######################
    view = View(
        Item(name='actor_type'),
        Item(name='scene',
             editor=SceneEditor(),
             show_label=False,
             resizable=True,
             height=500,
             width=500))

    def __init__(self, **traits):
        super(ActorViewer, self).__init__(**traits)
        # Make sure the initial selection is displayed.
        self._actor_type_changed(self.actor_type)

    ####################################
    # Private traits.
    def _actor_type_changed(self, value):
        """Replace whatever is displayed with the object selected by ``value``.

        tvtk 3D widgets (PlaneWidget, BoxWidget) are interactor observers,
        not renderable props, so they are attached with add_widgets /
        remove_widgets; plain actors use add_actors / remove_actors.
        """
        scene = self.scene
        current = self._current_actor
        if current is not None:
            if isinstance(current, tvtk.InteractorObserver):
                scene.remove_widgets(current)
            else:
                scene.remove_actors(current)
        if value == 'cone':
            a = actors.cone_actor()
            scene.add_actors(a)
        elif value == 'sphere':
            a = actors.sphere_actor()
            scene.add_actors(a)
        elif value == 'plane_widget':
            a = tvtk.PlaneWidget()
            scene.add_widgets(a)
        elif value == 'box_widget':
            a = tvtk.BoxWidget()
            scene.add_widgets(a)
        self._current_actor = a
示例#16
0
class PreviewWindow(HasTraits):
    """ A window with a mayavi engine, to preview pipeline elements.
    """

    # The engine that manages the preview view
    _engine = Instance(Engine)

    # Scene model shown in the window; the engine's scene is created on it.
    _scene = Instance(SceneModel, ())

    view = View(Item('_scene',
                     editor=SceneEditor(scene_class=Scene),
                     show_label=False),
                width=500,
                height=500)

    #-----------------------------------------------------------------------
    # Public API
    #-----------------------------------------------------------------------

    def add_source(self, src):
        """Add a data source to the preview engine."""
        self._engine.add_source(src)

    def add_module(self, module):
        """Add a visualization module to the preview engine."""
        self._engine.add_module(module)

    def add_filter(self, filter):
        """Add a filter to the preview engine (via Engine.add_filter)."""
        self._engine.add_filter(filter)

    def clear(self):
        """Remove all children of the current scene with rendering paused."""
        current = self._engine.current_scene
        current.scene.disable_render = True
        try:
            current.children[:] = []
        finally:
            # Re-enable rendering even if removing a child raises.
            current.scene.disable_render = False

    #-----------------------------------------------------------------------
    # Private API
    #-----------------------------------------------------------------------

    def __engine_default(self):
        # Default factory for the ``_engine`` trait (Traits' _<name>_default
        # convention): start an engine and create its scene on self._scene.
        e = Engine()
        e.start()
        e.new_scene(self._scene)
        return e
示例#17
0
class TubeDemoApp(HasTraits):
    """Tube rendered as the boolean difference of two coaxial cylinders.

    The larger of radius1/radius2 is the outer wall, the smaller the bore.
    plot() must be called before the sliders are moved, because
    update_radius() uses the cylinder sources that plot() creates.
    """
    radius1 = Range(0, 1.0, 0.8)
    radius2 = Range(0, 1.0, 0.4)
    scene = Instance(SceneModel, ())  #❶
    view = View(
        VGroup(
            Item(name="scene", editor=SceneEditor(scene_class=Scene)),  #❷
            HGroup("radius1", "radius2"),
            show_labels=False,
        ),
        resizable=True,
        height=500,
        width=500,
    )

    def plot(self):
        """Build the cylinder-difference pipeline and add the actor to the scene."""
        bore = min(self.radius1, self.radius2)
        wall = max(self.radius1, self.radius2)
        # The bore cylinder is slightly taller (1.1 vs 1), presumably so the
        # subtraction cuts cleanly through both end caps.
        self.cs1 = outer_src = tvtk.CylinderSource(height=1,
                                                   radius=wall,
                                                   resolution=32)
        self.cs2 = inner_src = tvtk.CylinderSource(height=1.1,
                                                   radius=bore,
                                                   resolution=32)
        outer_tri = tvtk.TriangleFilter(input_connection=outer_src.output_port)
        inner_tri = tvtk.TriangleFilter(input_connection=inner_src.output_port)
        boolean = tvtk.BooleanOperationPolyDataFilter()
        boolean.operation = "difference"
        boolean.set_input_connection(0, outer_tri.output_port)
        boolean.set_input_connection(1, inner_tri.output_port)
        mapper = tvtk.PolyDataMapper(input_connection=boolean.output_port,
                                     scalar_visibility=False)
        actor = tvtk.Actor(mapper=mapper)
        actor.property.color = 0.5, 0.5, 0.5
        self.scene.add_actors([actor])
        self.scene.background = 1, 1, 1
        self.scene.reset_zoom()

    @on_trait_change("radius1, radius2")  #❹
    def update_radius(self):
        """Apply the current radii to the cylinder sources and re-render."""
        small, large = sorted((self.radius1, self.radius2))
        self.cs1.radius = large
        self.cs2.radius = small
        self.scene.render_window.render()
示例#18
0
        class MlabApp(HasTraits):
            """Traits UI window embedding a Mayavi scene with a sinc-like surface."""

            # The scene model.
            scene = Instance(MlabSceneModel, ())

            view = View(Item(name='scene',
                             editor=SceneEditor(scene_class=MayaviScene),
                             show_label=False,
                             resizable=True,
                             height=500,
                             width=500),
                        resizable=True)

            def __init__(self, **traits):
                # Initialise HasTraits first so keyword traits are applied;
                # the previous version never called it and silently dropped
                # **traits.
                HasTraits.__init__(self, **traits)
                self.generate_data()

            def generate_data(self):
                """Plot sin(r)/r, r = 10*sqrt(x**2 + y**2), on a 100x100 grid."""
                X, Y = mgrid[-2:2:100j, -2:2:100j]
                R = 10 * sqrt(X**2 + Y**2)
                Z = sin(R) / R
                self.scene.mlab.surf(X, Y, Z, colormap='gist_earth')
示例#19
0
class ActorView(HasTraits):
    """Traits UI window embedding a Mayavi scene that shows a sinc-like surface."""

    scene = Instance(MlabSceneModel, ())

    # NOTE: this attribute is named ``View`` (shadowing the TraitsUI class
    # inside the class body). The name is kept for compatibility with any
    # caller referring to ActorView.View; as the class's only View it is
    # presumably what TraitsUI displays -- consider renaming to ``view``.
    View = View(Item(name="scene",
                     editor=SceneEditor(scene_class=MayaviScene),
                     show_label=False,
                     resizable=True,
                     height=500,
                     width=500),
                resizable=True)

    def __init__(self, **traits):
        HasTraits.__init__(self, **traits)
        self.generate_data()

    def generate_data(self):
        """Plot sin(r)/r, r = 10*sqrt(x**2 + y**2), on a 100x100 grid."""
        x, y = mgrid[-2:2:100j, -2:2:100j]
        r = 10 * sqrt(x**2 + y**2)
        z = sin(r) / r

        self.scene.mlab.surf(x, y, z, colormap="cool")
示例#20
0
    def traits_view(self):
        """Return a View with the 'scene' item shown in a 400x400 SceneEditor.

        The two flags pick the editor's scene class: MayaviScene (Mayavi
        toolbar), SceneEditor's own default (raw toolbar), or Scene.
        """
        use_mayavi_toolbar = True
        use_raw_toolbar = False

        if use_mayavi_toolbar:
            scene_class = MayaviScene
        elif use_raw_toolbar:
            scene_class = None
        else:
            scene_class = Scene

        editor_kwargs = {}
        if scene_class is not None:
            editor_kwargs['scene_class'] = scene_class

        scene_item = Item('scene',
                          show_label=False,
                          height=400,
                          width=400,
                          resizable=True,
                          editor=SceneEditor(**editor_kwargs))
        return View(scene_item)
示例#21
0
class Visualization(HasTraits):
    """Mayavi mesh of tens_fld(1, 1, 1, beta, alpha), driven by two sliders."""

    alpha = Range(0.0, 4.0, 1.0 / 4)
    beta = Range(0.0, 4.0, 1.0 / 4)
    scene = Instance(MlabSceneModel, ())

    def __init__(self):
        # The parent initializer must run before traits are used.
        HasTraits.__init__(self)
        x, y, z = self._field()
        self.plot = self.scene.mlab.mesh(x, y, z,
                                         colormap='copper',
                                         representation='surface')

    def _field(self):
        """Evaluate tens_fld for the current slider values."""
        return tens_fld(1, 1, 1, self.beta, self.alpha)

    @on_trait_change('beta,alpha')
    def update_plot(self):
        """Recompute the surface and update the mesh data in place."""
        x, y, z = self._field()
        self.plot.mlab_source.set(x=x, y=y, z=z)

    # Layout of the dialog: the scene above a row with both sliders.
    view = View(
        Item('scene',
             editor=SceneEditor(scene_class=MayaviScene),
             height=550,
             width=550,
             show_label=False),
        HGroup('_', 'beta', 'alpha'),
    )
示例#22
0
class MyDialog(HasTraits):
    """Pick points on pairs of consecutive point-cloud frames.

    Right-clicks in ``scene2`` go to picker_callback_2, which collects four
    points (prompted as I, J, M, L) and, after a y/n confirmation, appends
    them as one JSON line to ``bbdd.dat``. The 'Siguiente' / 'Anterior'
    buttons move to the next / previous pair of
    ``cloud_yaw_frame_number_<n>.pcd`` files; 'limpiar pantalla' clears it.
    """
    scene2 = Instance(MlabSceneModel, ())
    # Glyph size used when drawing the point clouds.
    scale_factor = 0.10
    # Class-level placeholder only; __init__ gives each instance its own list.
    points_picked = []
    # Index (0-3) of the next point collected by picker_callback_2.
    count = 0
    point_cloud_number = 0
    # Frame index: siguiente()/anterior() draw frames cont + 1 and cont.
    cont = 0
    # Number of points handled by picker_callback.
    cont2 = 0

    # The layout of the dialog created
    view = View(
        HSplit(
            Group(
                Item('scene2',
                     editor=SceneEditor(),
                     height=250,
                     width=300,
                     show_label=False),
                'button2',
                'button1',
                'button3',
                show_labels=False,
            ), ),
        resizable=True,
    )

    button1 = Button('limpiar pantalla')
    button2 = Button('Siguiente')
    button3 = Button('Anterior')

    def __init__(self, **traits):
        HasTraits.__init__(self, **traits)
        # Per-instance list: picker_callback_2 appends to it, and appending
        # to the class-level list would share picks between instances.
        self.points_picked = []

    def picker_callback(self, picker):
        """Print, draw and log (to Bbdd_CLoud.doc) the picked world position.

        NOTE(review): no visible code registers this callback; picker_active()
        registers picker_callback_2.
        """
        self.cont2 += 1
        # print('') / print(x) behave identically under Python 2 and 3.
        print('')
        print("Coordenadas: x: %.2f, y: %.2f  , z: %.2f" %
              (picker.pick_position))
        print(picker.pick_position)
        print('')

        x, y, z = picker.pick_position

        self.scene2.mlab.points3d(x,
                                  y,
                                  z,
                                  color=(0, 0, 0),
                                  scale_factor=0.05,
                                  mode='sphere',
                                  scale_mode='none',
                                  reset_zoom=False,
                                  figure=self.scene2.mayavi_scene)

        with open('Bbdd_CLoud.doc', "a") as log_file:
            log_file.write(" BBdd Frame: %i Objeto num: %i  " %
                           (self.cont, self.cont2))
            log_file.write(
                "Coordenada x: %.2f, Coordenada y: %.2f, Coordenada z: %.2f\n\n"
                % (picker.pick_position))

    def picker_callback_2(self, picker):
        """Collect one of the four I/J/M/L points and mark it in the scene.

        After the fourth point, the collection is printed and, if the user
        answers 'y' to getkey(), appended to bbdd.dat as JSON.
        """
        scale_factor = 0.05

        labels = ('Point I', 'Point J', 'Point M', 'Point L')
        if 0 <= self.count < len(labels):
            print(labels[self.count])

        self.points_picked.append(picker.pick_position)
        if self.count == 3:
            # Becomes 0 after the increment below: start a new collection.
            self.count = -1
            data = {
                'cloud_number:': self.cont + 1,
                'i': self.points_picked[0],
                'j': self.points_picked[1],
                # Points are prompted in the order I, J, M, L, so index 2 is
                # M and index 3 is L (these keys were previously swapped).
                'm': self.points_picked[2],
                'l': self.points_picked[3],
            }
            print(data)
            print('Is the point collection correct? (y/n)')
            key = getkey()
            if key == 'y':
                print('The result has been copied to file.')
                # ``with`` closes the file even if the write fails.
                with open('bbdd.dat', 'a') as self.file_p:
                    self.file_p.write(json.dumps(data))
                    self.file_p.write('\n')
            else:
                print('The point collection is not correct.')

            self.points_picked = []

        self.count += 1

        x, y, z = picker.pick_position
        points3d(x,
                 y,
                 z,
                 color=(0, 0, 0),
                 scale_factor=scale_factor,
                 mode='sphere',
                 scale_mode='none',
                 reset_zoom=False)

    @on_trait_change('scene2.activated')
    def picker_active(self):
        """Route right-button world-coordinate picks to picker_callback_2."""
        self.scene2.mayavi_scene.on_mouse_pick(self.picker_callback_2,
                                               type='world',
                                               button='Right')

    @on_trait_change('button1')
    def clear_figure(self):
        """Remove every child of scene2's Mayavi scene."""
        # Iterate over a snapshot: child.remove() presumably detaches the
        # child from ``children``, and iterating the live list skipped
        # every other element (the old code compensated by looping 4 times).
        for child in list(self.scene2.mayavi_scene.children):
            child.remove()

    @on_trait_change('button2')
    def siguiente(self):
        """Step forward one frame and draw the new frame pair."""
        self.cont += 1
        self._show_frame_pair()

    @on_trait_change('button3')
    def anterior(self):
        """Step back one frame and draw the new frame pair."""
        self.cont -= 1
        self._show_frame_pair()

    def _show_frame_pair(self):
        """Draw frame cont + 1 (coloured by z) and frame cont (white) in scene2."""
        print("Actualmente se esta representando Cloud_yaw_frame_number_" +
              str(self.cont + 1) + " y Cloud_yaw_frame_number_" +
              str(self.cont))

        cloud_source = pcl.load('cloud_yaw_frame_number_' +
                                str(self.cont + 1) + '.pcd')
        points_array = cloud_source.to_array()

        self.points3d_draw = points3d(points_array[:, 0:1],
                                      points_array[:, 1:2],
                                      points_array[:, 2:3],
                                      points_array[:, 2:3],
                                      mode='sphere',
                                      scale_mode='none',
                                      scale_factor=self.scale_factor,
                                      reset_zoom=False,
                                      figure=self.scene2.mayavi_scene)

        cloud_source = pcl.load('cloud_yaw_frame_number_' + str(self.cont) +
                                '.pcd')
        points_array = cloud_source.to_array()

        self.points3d_draw = points3d(points_array[:, 0],
                                      points_array[:, 1],
                                      points_array[:, 2],
                                      color=(1, 1, 1),
                                      mode='sphere',
                                      scale_mode='none',
                                      scale_factor=self.scale_factor,
                                      reset_zoom=False,
                                      figure=self.scene2.mayavi_scene)
示例#23
0
class TDViz(HasTraits):
    """Interactive 3D viewer for FITS data cubes.

    A cube chosen in ``fitsfile`` is loaded with its axes reordered so that
    axis 0 is RA and axis 2 is velocity, and its values are multiplied by
    1000 (the code labels this mJy).  The user can crop it with the
    ``[xyz]start`` / ``[xyz]end`` pixel ranges, clip it to
    ``[datamin, datamax]``, stretch the velocity axis by ``zscale`` and render
    it into ``scene`` as iso-surfaces (coloured by velocity or by intensity)
    or as a volume.  Further buttons add a cut plane, spin the camera, write
    a GIF movie (via ImageMagick ``convert``) and save the scene.
    """
    fitsfile = File(filter=[u"*.fits"])
    plotbutton1 = Button(u"Plot")
    plotbutton2 = Button(u"Plot")
    plotbutton3 = Button(u"Plot")
    clearbutton = Button(u"Clear")
    scene = Instance(MlabSceneModel, ())
    rendering = Enum("Surface-Spectrum", "Surface-Intensity",
                     "Volume-Intensity")
    save_the_scene = Button(u"Save")
    save_in_file = Str("test.wrl")
    add_cut = Button(u"Cutthrough")
    remove_cut = Button(u"Remove the Last Cut")
    movie = Button(u"Movie")
    iteration = Int(0)
    quality = Int(8)
    delay = Int(0)
    angle = Int(360)
    spin = Button(u"Spin")
    zscale = Float(1.0)
    xstart = Float(0.0)
    xend = Float(1.0)
    ystart = Float(0.0)
    yend = Float(1.0)
    zstart = Float(0.0)
    zend = Float(1.0)
    datamin = Float(0.0)
    datamax = Float(1.0)
    opacity = Float(0.3)

    view = View(HSplit(
        VGroup(
            Item("fitsfile", label=u"Select a FITS datacube", show_label=True),
            Item("rendering",
                 tooltip=u"Choose the rendering type you like",
                 show_label=True),
            Item(
                'plotbutton1',
                tooltip=
                u"Plot the 3D scene with surface rendering, colored by spectrum",
                visible_when="rendering=='Surface-Spectrum'"),
            Item(
                'plotbutton2',
                tooltip=
                u"Plot the 3D scene with surface rendering, colored by intensity",
                visible_when="rendering=='Surface-Intensity'"),
            Item(
                'plotbutton3',
                tooltip=
                u"Plot the 3D scene with volume rendering, colored by intensity",
                visible_when="rendering=='Volume-Intensity'"),
            HGroup(
                Item('xstart',
                     tooltip=u"starting pixel in X axis",
                     show_label=True,
                     springy=True),
                Item('xend',
                     tooltip=u"ending pixel in X axis",
                     show_label=True,
                     springy=True)),
            HGroup(
                Item('ystart',
                     tooltip=u"starting pixel in Y axis",
                     show_label=True,
                     springy=True),
                Item('yend',
                     tooltip=u"ending pixel in Y axis",
                     show_label=True,
                     springy=True)),
            HGroup(
                Item('zstart',
                     tooltip=u"starting pixel in Z axis",
                     show_label=True,
                     springy=True),
                Item('zend',
                     tooltip=u"ending pixel in Z axis",
                     show_label=True,
                     springy=True)),
            HGroup(
                Item('datamax',
                     tooltip=u"Maximum datapoint shown",
                     show_label=True,
                     springy=True),
                Item('datamin',
                     tooltip=u"Minimum datapoint shown",
                     show_label=True,
                     springy=True)),
            Item('zscale',
                 tooltip=u"Stretch the datacube in Z axis",
                 show_label=True),
            Item('opacity', tooltip=u"Opacity of the scene", show_label=True),
            Item("add_cut", tooltip="Add a cutthrough view"),
            # Tooltip matches what the button does: only the last cut goes.
            Item("remove_cut", tooltip="Remove the last cutthrough"),
            Item("spin", tooltip=u"Spin 360 degrees"),
            "clearbutton",
            Item('_'),
            Item("movie", tooltip="Make a GIF movie", show_label=False),
            HGroup(
                Item('iteration',
                     tooltip=u"number of iterations, 0 means inf.",
                     show_label=True),
                Item('quality',
                     tooltip=u"quality of plots, 0 is worst, 8 is good.",
                     show_label=True)),
            HGroup(
                Item('delay',
                     tooltip=u"time delay between frames, in millisecond.",
                     show_label=True),
                Item('angle', tooltip=u"angle the cube spins",
                     show_label=True)),
            Item('_'),
            Item(
                "save_the_scene",
                tooltip=u"Save current scene in a .wrl file",
                visible_when=
                "rendering=='Surface-Spectrum' or rendering=='Surface-Intensity'"
            ),
            Item("save_in_file",
                 tooltip=u"3D model file name",
                 show_label=True),
            show_labels=False),
        VGroup(Item(name='scene',
                    editor=SceneEditor(scene_class=MayaviScene),
                    resizable=True,
                    height=600,
                    width=600),
               show_labels=False)),
                resizable=True,
                title=u"TDViz")

    def _fitsfile_changed(self):
        """Load the selected FITS cube and reset the range/limit traits.

        Stores the primary header in ``self.hdr`` and the cube in
        ``self.data`` (axes 0 and 2 swapped, values x1000, NaN/inf set to 0),
        then sets ``datamin``/``datamax`` to the data range and the
        ``[xyz]end`` traits to the last pixel index of each axis.

        Raises:
            ValueError: if the primary HDU is not a 3- or 4-axis cube.
        """
        img = pyfits.open(self.fitsfile)
        try:
            hdr = img[0].header
            dat = img[0].data
            naxis = hdr['NAXIS']
            # pyfits returns the axes as (velo, dec, ra); a 4-axis cube has
            # an extra leading (degenerate) axis that is dropped.
            if naxis == 4:
                cube = dat[0]
            elif naxis == 3:
                cube = dat
            else:
                raise ValueError(
                    "Expected a 3- or 4-axis FITS cube, got NAXIS=%r" % naxis)
            # Swap RA<->velo; the multiplication also makes an in-memory copy
            # so the data stay valid after the file is closed.
            data = np.swapaxes(cube, 0, 2) * 1000.0
        finally:
            img.close()

        data[~np.isfinite(data)] = 0.0

        self.hdr = hdr
        self.data = data
        self.datamax = float(np.max(data))
        self.datamin = float(np.min(data))
        self.xend = data.shape[0] - 1
        self.yend = data.shape[1] - 1
        self.zend = data.shape[2] - 1

    def loaddata(self):
        """Crop, stretch and georeference the loaded cube.

        Sets:
            self.sregion: the cropped cube, resampled along axis 2 onto
                ``int(nchan * zscale)`` channels by linear interpolation and
                flipped along that axis when CDELT3 > 0.
            self.xrang / self.yrang / self.zrang: extent of the cropped cube
                in scene units (zrang includes the stretch).
            self.extent: [ra_start, ra_end, dec_start, dec_end, vel_start,
                vel_end] computed from the CRVAL/CDELT/CRPIX header keywords;
                velocities are divided by 1e3 (presumably m/s -> km/s, to
                match the 'V (km/s)' label — TODO confirm header units).

        Crop limits outside the cube and ``datamin``/``datamax`` outside the
        cropped data are reset, printing 'Wrong number!' for each reset.
        """
        from scipy.interpolate import interp1d

        channel = self.data
        ## Reset the range if it is beyond the cube:
        bounds = (('xstart', 'xend'), ('ystart', 'yend'), ('zstart', 'zend'))
        for axis, (lo_name, hi_name) in enumerate(bounds):
            if getattr(self, lo_name) < 0:
                print('Wrong number!')
                setattr(self, lo_name, 0)
            if getattr(self, hi_name) > channel.shape[axis] - 1:
                print('Wrong number!')
                setattr(self, hi_name, channel.shape[axis] - 1)

        ## Select a region, use mJy unit.  The range traits are Floats but
        ## slice indices must be integers; the end index stays exclusive.
        x0, x1 = int(self.xstart), int(self.xend)
        y0, y1 = int(self.ystart), int(self.yend)
        z0, z1 = int(self.zstart), int(self.zend)
        region = channel[x0:x1, y0:y1, z0:z1]

        ## Stretch the cube in V axis: resample every spectrum (all pixels at
        ## once, including the last row/column) onto `stretch` times as many
        ## channels with linear interpolation.
        stretch = self.zscale
        nchan = region.shape[2]
        nchan_out = int(nchan * stretch)
        chanindex = np.linspace(0, nchan - 1, nchan)
        chanindex2 = np.linspace(0, nchan - 1, nchan_out)
        sregion = interp1d(chanindex, region, kind='linear', axis=2)(chanindex2)
        self.sregion = sregion

        # Reset the max/min values
        smin = float(np.min(sregion))
        smax = float(np.max(sregion))
        if self.datamin < smin:
            print('Wrong number!')
            self.datamin = smin
        if self.datamax > smax:
            print('Wrong number!')
            self.datamax = smax
        self.xrang = abs(self.xstart - self.xend)
        self.yrang = abs(self.ystart - self.yend)
        self.zrang = abs(self.zstart - self.zend) * stretch

        ## Keep a record of the coordinates:
        crval1 = self.hdr['crval1']
        cdelt1 = self.hdr['cdelt1']
        crpix1 = self.hdr['crpix1']
        crval2 = self.hdr['crval2']
        cdelt2 = self.hdr['cdelt2']
        crpix2 = self.hdr['crpix2']
        crval3 = self.hdr['crval3']
        cdelt3 = self.hdr['cdelt3']
        crpix3 = self.hdr['crpix3']

        # Linear WCS: world = (pixel + 1 - CRPIX) * CDELT + CRVAL
        # (+1 because FITS pixels are 1-based).
        ra_start = (self.xstart + 1 - crpix1) * cdelt1 + crval1
        ra_end = (self.xend + 1 - crpix1) * cdelt1 + crval1
        dec_start = (self.ystart + 1 - crpix2) * cdelt2 + crval2
        dec_end = (self.yend + 1 - crpix2) * cdelt2 + crval2
        vel_start = (self.zstart + 1 - crpix3) * cdelt3 + crval3
        vel_end = (self.zend + 1 - crpix3) * cdelt3 + crval3
        vel_start /= 1e3
        vel_end /= 1e3

        ## Flip the V axis
        if cdelt3 > 0:
            self.sregion = self.sregion[:, :, ::-1]
            vel_start, vel_end = vel_end, vel_start

        self.extent = [
            ra_start, ra_end, dec_start, dec_end, vel_start, vel_end
        ]

    @staticmethod
    def _format_radec(ra, dec):
        """Return ``(ra_str, dec_str)`` labels for a position in degrees.

        RA is formatted as ``<h>h<m>m<s.s>s`` and Dec as ``<d>d<m>m<s.s>s``.
        """
        c = coord.ICRS(ra=ra, dec=dec, unit=(u.degree, u.degree))
        ra_str = (str(int(c.ra.hms.h)) + 'h' + str(int(c.ra.hms.m)) + 'm' +
                  str(round(c.ra.hms.s, 1)) + 's')
        # Take the sign from the whole declination: for -1 < dec < 0 the
        # degree component is -0 and int() would drop the minus sign.
        sign = '-' if dec < 0 else ''
        dec_str = (sign + str(int(abs(c.dec.dms.d))) + 'd' +
                   str(int(abs(c.dec.dms.m))) + 'm' +
                   str(round(abs(c.dec.dms.s), 1)) + 's')
        return ra_str, dec_str

    def labels(self):
        """Add 3D text for the axis names, corner coordinates and V range.

        Uses ``self.xrang``/``yrang``/``zrang`` for placement and
        ``self.extent`` for the values, then adds invisible axes on
        ``self.field`` and an outline.
        """
        fontsize = max(self.xrang, self.yrang) / 40.
        tcolor = (1, 1, 1)
        mlab.text3d(self.xrang / 2,
                    -40,
                    self.zrang + 40,
                    'R.A.',
                    scale=fontsize,
                    orient_to_camera=True,
                    color=tcolor)
        mlab.text3d(-40,
                    self.yrang / 2,
                    self.zrang + 40,
                    'Decl.',
                    scale=fontsize,
                    orient_to_camera=True,
                    color=tcolor)
        mlab.text3d(-40,
                    -40,
                    self.zrang / 2 - 10,
                    'V (km/s)',
                    scale=fontsize,
                    orient_to_camera=True,
                    color=tcolor)

        # Label the coordinates of the corners
        # Lower left corner
        ra_ll, dec_ll = self._format_radec(self.extent[0], self.extent[2])
        mlab.text3d(0,
                    -20,
                    self.zrang + 20,
                    ra_ll,
                    scale=fontsize,
                    orient_to_camera=True,
                    color=tcolor)
        mlab.text3d(-80,
                    0,
                    self.zrang + 20,
                    dec_ll,
                    scale=fontsize,
                    orient_to_camera=True,
                    color=tcolor)
        # Upper right corner
        ra_ur, dec_ur = self._format_radec(self.extent[1], self.extent[3])
        mlab.text3d(self.xrang,
                    -20,
                    self.zrang + 20,
                    ra_ur,
                    scale=fontsize,
                    orient_to_camera=True,
                    color=tcolor)
        mlab.text3d(-80,
                    self.yrang,
                    self.zrang + 20,
                    dec_ur,
                    scale=fontsize,
                    orient_to_camera=True,
                    color=tcolor)

        # V axis: smaller velocity at the top (z = zrang), larger at z = 0.
        v0 = min(self.extent[4], self.extent[5])
        v1 = max(self.extent[4], self.extent[5])
        mlab.text3d(-20,
                    -20,
                    self.zrang,
                    str(round(v0, 1)),
                    scale=fontsize,
                    orient_to_camera=True,
                    color=tcolor)
        mlab.text3d(-20,
                    -20,
                    0,
                    str(round(v1, 1)),
                    scale=fontsize,
                    orient_to_camera=True,
                    color=tcolor)

        mlab.axes(self.field,
                  ranges=self.extent,
                  x_axis_visibility=False,
                  y_axis_visibility=False,
                  z_axis_visibility=False)
        mlab.outline()

    def _finish_plot(self, field):
        """Remember ``field``, draw the labels and reset the camera."""
        self.field = field
        self.labels()
        mlab.view(azimuth=0, elevation=0, distance='auto')
        mlab.show()

    def _plotbutton1_fired(self):
        """Surface rendering: intensity iso-surfaces coloured by velocity."""
        mlab.clf()
        self.loaddata()
        # Clip to [datamin, datamax] in place.
        np.clip(self.sregion, self.datamin, self.datamax, out=self.sregion)

        # Based on: http://docs.enthought.com/mayavi/mayavi/auto/example_atomic_orbital.html#example-atomic-orbital
        field = mlab.pipeline.scalar_field(
            self.sregion)  # Generate a scalar field

        # Separate colour array holding a velocity for every channel
        # (including the last one); self.sregion itself is left untouched.
        # NOTE(review): extent[4] is in km/s while the step uses the raw
        # CDELT3 header value; only the colour scale depends on this.
        nchan = self.sregion.shape[2]
        colored = np.empty_like(self.sregion)
        colored[...] = (self.extent[4] -
                        np.arange(nchan) * abs(self.hdr['cdelt3']))
        new = field.image_data.point_data.add_array(colored.T.ravel())
        field.image_data.point_data.get_array(new).name = 'color'
        field.image_data.point_data.update()

        # Contour on the intensity ('scalar'), then colour the surface by
        # the 'color' array.
        field2 = mlab.pipeline.set_active_attribute(field,
                                                    point_scalars='scalar')
        contour = mlab.pipeline.contour(field2)
        contour2 = mlab.pipeline.set_active_attribute(contour,
                                                      point_scalars='color')

        mlab.pipeline.surface(contour2, colormap='jet', opacity=self.opacity)

        self._finish_plot(field)

    def _plotbutton2_fired(self):
        """Surface rendering: intensity iso-surfaces coloured by intensity."""
        mlab.clf()
        self.loaddata()
        field = mlab.contour3d(self.sregion)  # Generate a scalar field
        field.contour.maximum_contour = self.datamax
        field.contour.minimum_contour = self.datamin
        field.actor.property.opacity = self.opacity

        self._finish_plot(field)

    def _plotbutton3_fired(self):
        """Volume rendering of the intensity between datamin and datamax."""
        mlab.clf()
        self.loaddata()
        field = mlab.pipeline.scalar_field(
            self.sregion)  # Generate a scalar field
        mlab.pipeline.volume(field, vmax=self.datamax, vmin=self.datamin)

        self._finish_plot(field)

    def _add_cut_fired(self):
        """Add an x-oriented cut plane with 5 contours to the current field."""
        self.cut = mlab.pipeline.scalar_cut_plane(self.field,
                                                  plane_orientation="x_axes")
        self.cut.enable_contours = True
        self.cut.contour.number_of_contours = 5

    def _remove_cut_fired(self):
        """Remove the most recently added cut plane, if there is one."""
        cut = getattr(self, 'cut', None)
        if cut is None:
            return
        cut.stop()
        self.cut = None

    def _save_the_scene_fired(self):
        mlab.savefig(self.save_in_file)

    def _movie_fired(self):
        """Render a turntable GIF of the current scene into ./tenpfigz/.

        The camera is rotated in 5-degree azimuth steps up to ``angle``
        degrees; every step is saved as a zero-padded PNG and the frames are
        joined with ImageMagick ``convert`` using ``delay`` and
        ``iteration`` as ``-delay`` / ``-loop``.
        """
        outdir = "./tenpfigz"
        if os.path.exists(outdir):
            print("The chance of you using this name is really small...")
        else:
            os.makedirs(outdir)

        # Remove frames left over from a previous movie.
        for old_png in glob.glob(os.path.join(outdir, "*.png")):
            if os.path.isfile(old_png):
                os.remove(old_png)

        nframes = self.angle // 5
        # Zero-pad so the shell glob below sorts frames numerically
        # (2 digits as before, more only when there are >= 100 frames).
        width = max(2, len(str(nframes)))
        frame_path = os.path.join(outdir, "screenshot%0*d.png")

        scene = self.field.scene
        ## Quality of the movie: 0 is the worst, 8 is ok.
        scene.anti_aliasing_frames = self.quality
        scene.disable_render = True
        try:
            mlab.savefig(frame_path % (width, 0))
            for i in range(1, nframes + 1):
                scene.camera.azimuth(5)
                scene.render()
                mlab.savefig(frame_path % (width, i))
        finally:
            # Never leave the scene with rendering disabled.
            scene.disable_render = False

        os.system("convert -delay " + str(self.delay) + " -loop " +
                  str(self.iteration) +
                  " ./tenpfigz/*.png ./tenpfigz/animation.gif")

    def _spin_fired(self):
        """Spin the camera 360 degrees (72 steps of 5 degrees) as an animation."""
        scene = self.field.scene

        @mlab.animate
        def anim():
            for _ in range(72):
                scene.camera.azimuth(5)
                scene.render()
                yield

        # Keep a reference so the animation is not garbage collected.
        self._spin_animation = anim()

    def _clearbutton_fired(self):
        mlab.clf()
示例#24
0
class FieldViewer(HasTraits):
    """Traits UI for exploring iso-surfaces of a 3D scalar field f(x, y, z).

    The expression in ``function`` is evaluated on a regular grid of
    ``points`` samples per axis over ``[x0, x1] x [y0, y1] x [z0, z1]`` and
    shown as transparent iso-surfaces plus an X-Y cut plane.  When
    ``autocontour`` is off, a single iso-surface at ``contour`` is shown.
    """

    # Value ranges of the three axes
    x0, x1 = Float(-5), Float(5)
    y0, y1 = Float(-5), Float(5)
    z0, z1 = Float(-5), Float(5)
    points = Int(50)  # number of sample points per axis
    autocontour = Bool(True)  # let Mayavi choose the iso-surface values
    v0, v1 = Float(0.0), Float(1.0)  # value range of the field (bounds of `contour`)
    contour = Range("v0", "v1", 0.5)  # manually chosen iso-surface value
    function = Str("x*x*0.5 + y*y + z*z*2.0")  # scalar field expression
    function_list = [
        "x*x*0.5 + y*y + z*z*2.0", "x*y*0.5 + np.sin(2*x)*y +y*z*2.0", "x*y*z",
        "np.sin((x*x+y*y)/z)"
    ]
    plotbutton = Button("描画")
    scene = Instance(MlabSceneModel, ())

    view = View(
        HSplit(
            VGroup(
                "x0",
                "x1",
                "y0",
                "y1",
                "z0",
                "z1",
                Item('points', label="点数"),
                Item('autocontour', label="自动等值"),
                Item('plotbutton', show_label=False),
            ),
            VGroup(
                Item(
                    'scene',
                    editor=SceneEditor(scene_class=MayaviScene),
                    resizable=True,
                    height=300,
                    width=350),
                Item('function',
                     editor=EnumEditor(name='function_list',
                                       evaluate=lambda x: x)),
                Item('contour',
                     editor=RangeEditor(format="%1.2f",
                                        low_name="v0",
                                        high_name="v1")),
                show_labels=False)),
        width=500,
        resizable=True,
        title="三维标量场观察器")

    def _plotbutton_fired(self):
        self.plot()

    def plot(self):
        """Sample the field on the grid and redraw the scene.

        Stores the contour module in ``self.g`` and the sampled values in
        ``self.scalars`` and updates ``v0``/``v1`` to their range.  With
        ``autocontour`` off, the current ``contour`` value is applied to the
        new plot.
        """
        # Build the 3D grid; an imaginary step means "number of points".
        x, y, z = np.mgrid[
            self.x0:self.x1:1j * self.points, self.y0:self.y1:1j * self.points,
            self.z0:self.z1:1j * self.points]

        # Evaluate the field from the expression (sees x, y, z and np).
        # NOTE(security): eval() runs arbitrary Python typed into the
        # `function` box; only use this viewer with trusted input.
        scalars = eval(self.function)
        self.scene.mlab.clf()  # clear the current scene

        # Draw the iso-surfaces
        g = self.scene.mlab.contour3d(x,
                                      y,
                                      z,
                                      scalars,
                                      contours=8,
                                      transparent=True)
        g.contour.auto_contours = self.autocontour
        self.scene.mlab.axes(figure=self.scene.mayavi_scene)  # add axes

        # Add an X-Y cut plane through the centre of the box
        s = self.scene.mlab.pipeline.scalar_cut_plane(g)
        cutpoint = (self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2, (
            self.z0 + self.z1) / 2
        s.implicit_plane.normal = (0, 0, 1)  # normal along z -> X-Y plane
        s.implicit_plane.origin = cutpoint

        self.g = g
        self.scalars = scalars
        # Range of the field values (bounds for the `contour` slider)
        self.v0 = np.min(scalars)
        self.v1 = np.max(scalars)

        # A freshly created contour module does not know the user's manual
        # iso-value yet; apply it just like _autocontour_changed does.
        if not self.autocontour:
            self._contour_changed()

    def _contour_changed(self):
        """Show a single iso-surface at ``contour`` when auto mode is off."""
        if hasattr(self, "g"):
            if not self.g.contour.auto_contours:
                self.g.contour.contours = [self.contour]

    def _autocontour_changed(self):
        """Toggle automatic iso-values on the current plot."""
        if hasattr(self, "g"):
            self.g.contour.auto_contours = self.autocontour
            if not self.autocontour:
                self._contour_changed()
示例#25
0
class Visualization(HasTraits):
    """Mayavi viewer overlaying per-orientation traversability masks on a heightmap.

    A grey surface shows the heightmap.  Two surfaces raised by 0.05 show the
    traversable (greenish) and non-traversable (reddish) masks for the
    orientation picked with the ``orientation`` slider.  The masks are read
    from the images in ``dir_path``, one image per orientation, in sorted
    filename order.
    """
    orientation = Range(0, 32, 0)
    scene = Instance(MlabSceneModel, ())
    vector_traversabilty_files = []
    vector_traversabilty_color_files = []

    def __init__(self, heightmap_png, dir_path, resolution,
                 height_scale_factor):
        """Load the heightmap and the masks and draw the initial surfaces.

        Args:
            heightmap_png: path of the heightmap image.
            dir_path: directory holding one traversability image per
                orientation.
            resolution: grid spacing applied to the pixel indices.
            height_scale_factor: multiplier applied to the normalised heights.
        """
        HasTraits.__init__(self)
        self.heightmap_png = heightmap_png
        self.dir_path = dir_path
        self.resolution = resolution
        self.height_scale_factor = height_scale_factor

        self.hm = self.read_image_hm(self.heightmap_png)

        self.retrieve_vector_traversabilty_from_files(self.dir_path)
        self.max_n_orientations = len(self.vector_traversabilty_files)

        self.y, self.x = np.meshgrid(
            np.arange(self.hm.shape[0]) * resolution,
            np.arange(self.hm.shape[1]) * resolution)
        self.surface_hm = self.scene.mlab.surf(
            self.x,
            self.y,
            self.hm,
            color=(0.8, 0.8, 0.8),
            figure=self.scene.mayavi_scene)

        mask_t = self.retrieve_colormask_traversability(0, True)
        mask_f = self.retrieve_colormask_traversability(0, False)

        # Overlays are lifted by 0.05 so they are drawn above the heightmap.
        self.surface_hm_t = self.scene.mlab.surf(
            self.x,
            self.y,
            self.hm + 0.05,
            mask=mask_t,
            color=(0.6, 0.8, 0.6),
            figure=self.scene.mayavi_scene)

        self.surface_hm_f = self.scene.mlab.surf(
            self.x,
            self.y,
            self.hm + 0.05,
            mask=mask_f,
            color=(0.8, 0.6, 0.6),
            figure=self.scene.mayavi_scene)

    def read_image_hm(self, heightmap_png):
        """Read a heightmap image and return it scaled by height_scale_factor.

        Multi-channel images are reduced to one channel with
        ``skimage.color.rgb2gray``.  Single-channel integer images (e.g.
        8- or 16-bit grayscale) are divided by the maximum of their dtype
        (255 for uint8, 65535 for uint16).  Single-channel float images are
        used as they are.
        """
        hm = skimage.io.imread(heightmap_png)
        if hm.ndim > 2:
            hm = skimage.color.rgb2gray(hm)
        elif hm.ndim == 2 and np.issubdtype(hm.dtype, np.integer):
            hm = hm / np.iinfo(hm.dtype).max
        # Scaled to the requested factor (mostly for testing; 1.0 for training)
        return hm * self.height_scale_factor

    def read_image_with_traversability(self, heightmap_png):
        """Read an image produced by the traversability estimation, scaled by 1/255."""
        hm = skimage.io.imread(heightmap_png) / 255.0
        return hm

    def retrieve_vector_traversabilty_from_files(self, dir_path):
        """Load every image in ``dir_path`` (sorted by name) as one orientation.

        For each image, the traversable mask is channel 0 > 0.5 and the
        non-traversable mask is channel 1 > 0.5.  The ``[mask_t, mask_f]``
        pairs go to ``self.vector_traversabilty_files`` and the full images
        to ``self.vector_traversabilty_color_files``.
        """
        # NOTE: every file in the directory is read; it should contain only
        # images of the same size as the heightmap.
        self.vector_traversabilty_files = []
        self.vector_traversabilty_color_files = []
        for filename in sorted(os.listdir(dir_path)):
            t_e_hm = self.read_image_with_traversability(
                os.path.join(dir_path, filename))
            mask_t = t_e_hm[:, :, 0] > 0.5
            mask_f = t_e_hm[:, :, 1] > 0.5
            self.vector_traversabilty_files.append([mask_t, mask_f])
            self.vector_traversabilty_color_files.append(t_e_hm)

    def traversability_color_mappping(self, idx_orientation):
        """Return a flat (row-major) colour value per cell for one orientation.

        This is an alternative to the masked surfaces, which are slow to
        render.  The values are 0.1 where both masks are set, 0.8 where only
        the traversable mask is set and 0.5 elsewhere.
        """
        mask_t = np.asarray(
            self.retrieve_colormask_traversability(idx_orientation, True),
            dtype=bool)
        mask_f = np.asarray(
            self.retrieve_colormask_traversability(idx_orientation, False),
            dtype=bool)
        colors = np.where(mask_t & mask_f, 0.1, np.where(mask_t, 0.8, 0.5))
        # ravel() flattens row by row (index i * n_cols + j).
        return colors.ravel()

    def retrieve_colormask_traversability(self, idx_orientation,
                                          istraversable):
        """Return the traversable (True) or non-traversable (False) mask."""
        mask_t, mask_f = self.vector_traversabilty_files[idx_orientation]
        return mask_t if istraversable else mask_f

    def compute_compose_min_traversability(self):
        """Combine the oriented traversability maps into a per-cell minimum.

        The maps are the ones written by
        test_fitted_model_full_map_stride_cnn.  Traversability is read from
        the alpha channel (index 3).  Margin cells are pure white (R = G = B
        = 1.0) and get -1 in the first map; later maps do not change them.
        Returns None when no maps are loaded.
        """
        compose_t = None
        for shm in self.vector_traversabilty_color_files:
            margin = np.all(shm[:, :, :3] == 1.0, axis=2)
            alpha = shm[:, :, 3]
            if compose_t is None:  # first map: initialise
                compose_t = np.where(margin, -1.0, alpha).astype('float64')
            else:
                compose_t = np.where(margin, compose_t,
                                     np.minimum(compose_t, alpha))
        return compose_t

    @on_trait_change('orientation')
    def update_plot(self):
        """Show the masks for the selected orientation, clamping the index."""
        if self.orientation < self.max_n_orientations:
            mask_t = self.retrieve_colormask_traversability(
                self.orientation, True)
            mask_f = self.retrieve_colormask_traversability(
                self.orientation, False)

            self.surface_hm_t.mlab_source.set(mask=mask_t,
                                              scalars=self.hm + 0.05)
            self.surface_hm_f.mlab_source.set(mask=mask_f,
                                              scalars=self.hm + 0.05)
        else:
            # Past the last loaded orientation: clamp, which re-triggers
            # this handler with a valid index.
            self.orientation = self.max_n_orientations - 1

    # the layout of the dialog created
    view = View(
        Item(
            'scene',
            editor=SceneEditor(scene_class=MayaviScene),
            show_label=False),
        HGroup(
            '_',
            'orientation',
        ),
        resizable=True)
示例#26
0
class FiducialsFrame(HasTraits):
    """GUI for placing the LPA, nasion and RPA fiducials on an MRI head surface.

    The left side shows a Mayavi scene with the head surface and one point
    marker per fiducial.  The right side holds the head-view controls, a
    subject selector and the fiducials panel.  Picking a cell in the scene is
    forwarded to the fiducials panel.

    Parameters
    ----------
    subject : None | str
        Subject to select initially (ignored if not among the available
        subjects).
    subjects_dir : None | str
        Override the SUBJECTS_DIR environment variable.
    """

    model = Instance(MRIHeadWithFiducialsModel, ())

    scene = Instance(MlabSceneModel, ())
    headview = Instance(HeadViewController)

    spanel = Instance(SubjectSelectorPanel)
    panel = Instance(FiducialsPanel)

    mri_obj = Instance(SurfaceObject)
    point_scale = float(defaults['mri_fid_scale'])
    lpa_obj = Instance(PointObject)
    nasion_obj = Instance(PointObject)
    rpa_obj = Instance(PointObject)

    def _headview_default(self):
        controller = HeadViewController(scene=self.scene, system='RAS')
        return controller

    def _panel_default(self):
        fid_panel = FiducialsPanel(model=self.model, headview=self.headview)
        fid_panel.trait_view('view', view2)
        return fid_panel

    def _spanel_default(self):
        subject_source = self.model.subject_source
        return SubjectSelectorPanel(model=subject_source)

    view = View(
        HGroup(
            Item('scene',
                 editor=SceneEditor(scene_class=MayaviScene),
                 dock='vertical'),
            VGroup(
                headview_borders,
                VGroup(Item('spanel', style='custom'),
                       label="Subject",
                       show_border=True,
                       show_labels=False),
                VGroup(Item('panel', style="custom"),
                       label="Fiducials",
                       show_border=True,
                       show_labels=False),
                show_labels=False),
            show_labels=False),
        resizable=True,
        buttons=NoButtons)

    def __init__(self,
                 subject=None,
                 subjects_dir=None,
                 **kwargs):  # noqa: D102
        super(FiducialsFrame, self).__init__(**kwargs)

        resolved_dir = get_subjects_dir(subjects_dir)
        if resolved_dir is not None:
            self.spanel.subjects_dir = resolved_dir

        # Only preselect a subject that actually exists.
        if subject is None or subject not in self.spanel.subjects:
            return
        self.spanel.subject = subject

    @on_trait_change('scene.activated')
    def _init_plot(self):
        """Create the scene objects once the Mayavi scene is ready."""
        _toggle_mlab_render(self, False)

        # Head surface built from the model's points and triangles.
        surface_color = defaults['mri_color']
        self.mri_obj = SurfaceObject(points=self.model.points,
                                     color=surface_color,
                                     tri=self.model.tris,
                                     scene=self.scene)
        self.model.on_trait_change(self._on_mri_src_change, 'tris')
        self.panel.hsp_obj = self.mri_obj

        # One point marker per fiducial, kept in sync with the panel's
        # coordinates and with this frame's point_scale.
        for fid_name in ('lpa', 'nasion', 'rpa'):
            obj_attr = fid_name + '_obj'
            marker = PointObject(scene=self.scene,
                                 color=defaults[fid_name + '_color'],
                                 has_norm=True,
                                 point_scale=self.point_scale)
            setattr(self, obj_attr, marker)
            point_obj = getattr(self, obj_attr)
            self.panel.sync_trait(fid_name, point_obj, 'points', mutual=False)
            self.sync_trait('point_scale', point_obj, mutual=False)

        self.headview.left = True
        _toggle_mlab_render(self, True)

        # Forward cell picks in the scene to the fiducials panel.
        self.scene.mayavi_scene.on_mouse_pick(self.panel._on_pick, type='cell')

    def _on_mri_src_change(self):
        """Redraw the head surface when the model's mesh changes (or clear it)."""
        model = self.model
        if not (np.any(model.points) and np.any(model.tris)):
            self.mri_obj.clear()
            return

        self.mri_obj.points = model.points
        self.mri_obj.tri = model.tris
        self.mri_obj.plot()
示例#27
0
class MainWindow(HasTraits):
    scene = Instance(MlabSceneModel, ())
    TrA_Raw_file = File("TrA data")
    Chirp_file = File("Chirp data")
    Load_files = Button("Load data")
    Shiftzero = Button("Shift time zero")
    Ohioloader = Button("Ohio data loader")
    DeleteTraces = Button("Delete traces")
    Delete_spectra = Button('Delete spectra')
    fft_filter = Button('FFT filter')
    PlotChirp = Button("2D plot of chirp")
    Timelim = Array(np.float, (1, 2))
    Fix_Chirp = Button("Fix for chirp")
    Fit_Trace = Button("Fit Trace")
    Fit_Spec = Button("Fit Spectra")
    mcmc = Button("MCMC fitting")
    Fit_Chirp = Button("Fit chirp")
    SVD = Enum(1, 2, 3, 4, 5)
    SVD = Button("SVD on plot")
    EFA = Button("Evolving factor analysis")
    Traces_num = 0
    Multiple_Trace = Button("Select multiple traces")
    Global = Button("Global fit")
    title = Str("Welcome to PyTrA")
    z_height = Range(1, 100)
    Plot_3D = Button("3D plot")
    Plot_2D = Button("2D plot")
    Plot_log = Button("2D log plot")
    Plot_Traces = Button("Plot traces")
    multiple_plots = Button("Multiple traces/spectra on plot")
    Normalise = Button("Normalise")
    Kinetic_Trace = Button("Kinetic trace")
    Spectra = Button("Spectra")
    Trace_Igor = Button("Send traces to Igor")
    Global = Button("Global fit")
    Save_Glo = Button("Save as Glotaran file")
    Save_csv = Button("Save csv with title as file name")
    Save_log = Button("Save log file")
    Help = Button("Help")
    log = Str(
        "PyTrA:Python based fitting of Ultra-fast Transient Absorption Data")

    #Setting up views

    buttons_group = Group(
        Item('title', show_label=False),
        Item('TrA_Raw_file', style='simple', show_label=False),
        Item('Chirp_file', style='simple', show_label=False),
        Item('Load_files', show_label=False),
        Item('Ohioloader', show_label=False),
        Item('DeleteTraces', show_label=False),
        Item('Delete_spectra', show_label=False),
        Item('fft_filter', show_label=False),
        Item('Shiftzero', show_label=False), Label('Chirp Correction'),
        Item('PlotChirp', show_label=False),
        Label('Time range for chirp corr short/long'),
        Item('Timelim', show_label=False), Item('Fix_Chirp', show_label=False),
        Label('Data Analysis'), Item('Fit_Trace', show_label=False),
        Item('Fit_Spec', show_label=False), Item('mcmc', show_label=False),
        Item('Plot_2D', show_label=False), Item('SVD', show_label=False),
        Item('EFA', show_label=False), Label('Global fitting'),
        Item('Multiple_Trace', show_label=False),
        Item('Trace_Igor', show_label=False),
        Item('Plot_Traces', show_label=False), Item('Global',
                                                    show_label=False),
        Label('Visualisation'), Item('Plot_3D', show_label=False),
        Item('z_height', show_label=False), Item('Plot_2D', show_label=False),
        Item('Plot_log', show_label=False), Item('Spectra', show_label=False),
        Item('Kinetic_Trace', show_label=False),
        Item('multiple_plots', show_label=False),
        Item('Normalise', show_label=False), Label('Export Data'),
        Item('Save_csv', show_label=False), Item('Save_log', show_label=False),
        Item('Save_Glo', show_label=False), Item('Help', show_label=False))

    threed_group = Group(Item('scene',
                              editor=SceneEditor(scene_class=MayaviScene),
                              height=600,
                              width=800,
                              show_label=False),
                         label='3d graph')

    log_group = Group(Item('log', style='custom', show_label=False),
                      label='log file')

    view = View(
        HSplit(
            buttons_group,
            Tabbed(
                threed_group,
                log_group,
            ),
        ),
        title='PyTrA',
        resizable=True,
    )

    def _Load_files_fired(self):
        """Load the TrA data file (and the optional chirp file) into ``Data``.

        ``.csv`` files are read comma-delimited and ``.txt`` files
        space-delimited, with missing values filled by ``'0'``. After
        transposing, row 0 holds the wavelengths and column 0 the times; the
        remaining block becomes ``Data.TrA_Data``, sorted by ascending
        wavelength and time. The chirp file is read the same way into
        ``Data.time_C``, ``Data.wavelength_C`` and ``Data.Chirp``; if it cannot
        be read, "No Chirp found" is appended to the log instead.
        """
        delimiters = {'.csv': ',', '.txt': ' '}

        Data.filename = self.TrA_Raw_file
        TrA_Raw_file_extension = os.path.splitext(self.TrA_Raw_file)[1]
        TrA_Raw_name = os.path.splitext(os.path.basename(self.TrA_Raw_file))[0]
        self.title = TrA_Raw_name

        if TrA_Raw_file_extension not in delimiters:
            # Previously an unknown extension fell through to an unbound
            # local variable; report it instead of crashing.
            self.log = ('%s\n---\nUnsupported file extension %r for %s' %
                        (self.log, TrA_Raw_file_extension, self.TrA_Raw_file))
            return

        TrA_Raw = genfromtxt(self.TrA_Raw_file,
                             delimiter=delimiters[TrA_Raw_file_extension],
                             filling_values='0').transpose()
        TrA_Raw_m, TrA_Raw_n = TrA_Raw.shape

        Data.time = TrA_Raw[1:TrA_Raw_m, 0]
        Data.wavelength = TrA_Raw[0, 1:TrA_Raw_n]
        Data.TrA_Data = TrA_Raw[1:TrA_Raw_m, 1:TrA_Raw_n]

        # Drop the last row when its first value is zero (original note: this
        # happens when the data was saved from Excel).
        if Data.TrA_Data[-1, 0] == 0:
            Data.TrA_Data = Data.TrA_Data[0:-1, :]
            Data.time = Data.time[0:-1]

        # Sort columns by wavelength and rows by time.
        inds = Data.wavelength.argsort()
        Data.TrA_Data = Data.TrA_Data[:, inds]
        Data.wavelength = Data.wavelength[inds]

        indst = Data.time.argsort()
        Data.TrA_Data = Data.TrA_Data[indst, :]
        Data.time = Data.time[indst]

        # Importing chirp data is best effort: any failure to read it is
        # logged, as before (but no longer via a bare ``except:``).
        Chirp_file_extension = os.path.splitext(self.Chirp_file)[1]
        try:
            Chirp_Raw = genfromtxt(self.Chirp_file,
                                   delimiter=delimiters[Chirp_file_extension],
                                   filling_values='0').transpose()
            Chirp_Raw_m, Chirp_Raw_n = Chirp_Raw.shape
        except Exception:
            self.log = ("%s\n---\nNo Chirp found" % (self.log))
        else:
            # Slice the chirp time axis with the chirp file's own row count
            # (it was previously sliced with the TrA file's row count).
            Data.time_C = Chirp_Raw[1:Chirp_Raw_m, 0]
            Data.wavelength_C = Chirp_Raw[0, 1:Chirp_Raw_n]
            Data.Chirp = Chirp_Raw[1:Chirp_Raw_m, 1:Chirp_Raw_n]

        self.log = (
            '%s\nData file imported of size t=%s and wavelength=%s name=%s' %
            (self.log, Data.TrA_Data.shape[0], Data.TrA_Data.shape[1],
             TrA_Raw_name))

    def _Ohioloader_fired(self):
        """Open the Ohio loader dialog, then log the shape of ``Data.TrA_Data``."""
        OhioLoader().edit_traits()
        n_rows = Data.TrA_Data.shape[0]
        n_cols = Data.TrA_Data.shape[1]
        self.log = ('%s\n---\nData file imported of size %s by %s' %
                    (self.log, n_rows, n_cols))

    def _fft_filter_fired(self):
        """Open the interactive FFT filter dialog."""
        FFTfilter().edit_traits()

    def _Shiftzero_fired(self):
        """Shift ``Data.time`` so that a user-picked time becomes zero.

        Rows ``Data.time[1:20]`` are shown as a contour plot; the time
        coordinate of the single clicked point is subtracted from every entry
        of ``Data.time``.
        """
        plt.figure()
        plt.contourf(Data.wavelength, Data.time[1:20], Data.TrA_Data[1:20, :],
                     100)
        plt.title('Pick time zero')
        plt.xlabel('Wavelength')
        plt.ylabel('Time')
        fittingto = np.array(ginput(1))
        plt.show()
        plt.close()

        if len(fittingto) == 0:
            # Nothing was picked (e.g. the window was closed): leave the time
            # axis untouched.
            return

        shift = fittingto[0, 1]
        Data.time = Data.time - shift

        # ginput(1) returns a single point; the previous message indexed a
        # second point (IndexError after every shift) and described deleting
        # traces.
        self.log = "%s\n---\nShifted time zero by %s" % (self.log, shift)

    def _DeleteTraces_fired(self):
        """Delete the wavelength columns between two clicked points (inclusive).

        Each click is mapped to the nearest wavelength index; all columns from
        the lower to the higher index, both included, are removed from
        ``Data.TrA_Data`` and ``Data.wavelength``. Click order does not matter.
        """
        plt.figure()
        plt.contourf(Data.wavelength, Data.time, Data.TrA_Data, 100)
        plt.title('Pick between wavelength to delete (left to right)')
        plt.xlabel('Wavelength')
        plt.ylabel('Time')
        fittingto = np.array(ginput(2))
        plt.show()
        plt.close()

        if len(fittingto) < 2:
            # Fewer than two points picked: nothing to delete.
            return

        # Order the two indices *before* making the end exclusive. The old
        # code added 1 before swapping, so right-to-left picks removed the
        # wrong columns; it also compared an int to a shape tuple and relied
        # on ``&`` precedence. np.delete covers every case uniformly.
        left, right = sorted((
            np.abs(Data.wavelength - fittingto[0, 0]).argmin(),
            np.abs(Data.wavelength - fittingto[1, 0]).argmin()))
        doomed = np.s_[left:right + 1]
        Data.TrA_Data = np.delete(Data.TrA_Data, doomed, axis=1)
        Data.wavelength = np.delete(Data.wavelength, doomed)

        self.log = "%s\n---\nDeleted traces between %s and %s" % (
            self.log, fittingto[0, 0], fittingto[1, 0])

    def _Delete_spectra_fired(self):
        """Delete the time rows between two clicked points (inclusive).

        Each click is mapped to the nearest time index; all rows from the lower
        to the higher index, both included, are removed from
        ``Data.TrA_Data`` and ``Data.time``. Click order does not matter.
        """
        plt.figure()
        plt.contourf(Data.wavelength, Data.time, Data.TrA_Data, 100)
        plt.title('Pick between times to delete (top to bottom)')
        plt.xlabel('Wavelength')
        plt.ylabel('Time')
        fittingto = np.array(ginput(2))
        plt.show()
        plt.close()

        if len(fittingto) < 2:
            # Fewer than two points picked: nothing to delete.
            return

        # Sort before making the end exclusive. The old code added 1 before
        # swapping, so an unexpected click order removed the wrong rows; it
        # also compared an int to a shape tuple and relied on ``&`` precedence.
        first, last = sorted((np.abs(Data.time - fittingto[0, 1]).argmin(),
                              np.abs(Data.time - fittingto[1, 1]).argmin()))
        doomed = np.s_[first:last + 1]
        Data.TrA_Data = np.delete(Data.TrA_Data, doomed, axis=0)
        Data.time = np.delete(Data.time, doomed)

        self.log = "%s\n---\nDeleted spectra between %s and %s" % (
            self.log, fittingto[0, 1], fittingto[1, 1])

    def _PlotChirp_fired(self):
        """Show a 100-level filled contour plot of the loaded chirp data."""
        plt.figure()
        plt.contourf(Data.wavelength_C, Data.time_C, Data.Chirp, 100)
        plt.xlabel('Wavelength')
        plt.ylabel('Time')
        plt.title('%s Chirp' % self.title)
        plt.show()

    def _Timelim_changed(self):
        """Copy the edited chirp time window into ``Data.Range``."""
        time_window = self.Timelim
        Data.Range = time_window

    def _Fix_Chirp_fired(self):
        """Correct ``Data.TrA_Data`` for chirp (wavelength-dependent time zero).

        1. The user picks 8 points along the chirp in the chirp data, and a
           quadratic ``p(x) = p[0]*x**2 + p[1]*x + p[2]`` is fitted to them.
        2. The user picks one point (x0, t0) on the wave front in the TrA data
           near time zero, and ``p[2]`` is set so that ``p(x0) == t0``.
        3. Each wavelength column is resampled as
           ``new(t) = old(t + p(wavelength))``, so the fitted front lands on
           time zero. Points falling outside the data are filled with 0.
        """
        n_points = 8

        # Pick points along the chirp.
        plt.figure(figsize=(20, 12))
        plt.title('Pick 8 points')
        plt.xlabel('Wavelength')
        plt.ylabel('Time')
        plt.contourf(Data.wavelength_C, Data.time_C, Data.Chirp, 20)
        plt.ylim((Data.Range[0][0], Data.Range[0][1]))
        polypts = np.array(ginput(n_points))
        plt.show()
        plt.close()

        # Fit p(x) = p[0]*x**2 + p[1]*x + p[2].
        fitcoeff, residuals, rank, singular_values, rcond = np.polyfit(
            polypts[:, 0], polypts[:, 1], 2, full=True)

        # With full=True, polyfit already returns the *sum of squared*
        # residuals. The RMS deviation is sqrt(SSR / n); the old code squared
        # the sum again.
        stdev = np.sqrt(np.sum(residuals) / len(polypts))

        # Show a few rows around time zero. Clamp at 0 so a zero at index 0
        # does not turn the slice start into -1 (which gave an empty slice).
        idx = np.abs(Data.time).argmin()
        lo = max(idx - 1, 0)
        plt.figure(figsize=(20, 12))
        plt.title("Pick point on wave front")
        plt.xlabel('Wavelength')
        plt.ylabel('Time')
        plt.contourf(Data.wavelength, Data.time[lo:idx + 10],
                     Data.TrA_Data[lo:idx + 10, :], 100)
        fittingto = np.array(ginput(1)[0])
        plt.show()
        plt.close()

        # Shift the constant term so the polynomial passes through the picked
        # front point: p(x0) == t0. The previous formula gave p(x0) == -t0,
        # which moved the picked front to 2*t0 instead of to time zero.
        x0, t0 = fittingto[0], fittingto[1]
        fitcoeff[2] = t0 - (fitcoeff[0] * x0**2 + fitcoeff[1] * x0)

        # Resample each wavelength column: new(t) = old(t + p(wavelength)).
        for i, wavelength in enumerate(Data.wavelength):
            correcttimeval = np.polyval(fitcoeff, wavelength)
            f = interpolate.interp1d(Data.time - correcttimeval,
                                     Data.TrA_Data[:, i],
                                     bounds_error=False,
                                     fill_value=0)
            Data.TrA_Data[:, i] = f(Data.time)

        self.log = "%s\n---\nPolynomial fit with form %s*x^2 + %s*x + %s stdev %s" % (
            self.log, fitcoeff[0], fitcoeff[1], fitcoeff[2], stdev)

    def _Fit_Trace_fired(self):
        """Fit the kinetic trace at a user-picked wavelength with the fit GUI.

        The resulting model is stored in ``Data.tracefitmodel``. Each
        parameter's value and the matching diagonal entry of ``getCov()`` are
        appended to the log.
        """
        plt.figure()
        plt.contourf(Data.wavelength, Data.time, Data.TrA_Data, 100)
        plt.title('Pick wavelength to fit')
        plt.xlabel('Wavelength')
        plt.ylabel('Time')
        clicked = np.array(ginput(1))
        plt.show()
        plt.close()

        picked_wavelength = clicked[:, 0]
        column = np.abs(Data.wavelength - picked_wavelength).argmin()
        offered_models = ('Convoluted_exp1,Convoluted_exp2,Convoluted_exp3,'
                          'Convoluted_exp4,doublegaussian,'
                          'doubleopposedgaussian,gaussian')
        # fit_data blocks until the GUI closes. For a non-blocking window, pass
        # a model instance and call Data.tracefitmodel.edit_traits() instead.
        Data.tracefitmodel = fitgui.fit_data(Data.time,
                                             Data.TrA_Data[:, column],
                                             autoupdate=False,
                                             model=Convoluted_exp1,
                                             include_models=offered_models)

        fitted = Data.tracefitmodel
        errors = fitted.getCov().diagonal()
        names = fitted.params
        values = fitted.parvals

        self.log = (
            '%s\n---\nFitted parameters at wavelength %s \nFitting parameters'
            % (self.log, picked_wavelength))
        for i, value in enumerate(values):
            self.log = ('%s\n%s = %s +- %s' %
                        (self.log, names[i], value, errors[i]))

    def _Fit_Spec_fired(self):
        """Fit the spectrum at a user-picked time with the interactive fit GUI.

        The resulting model is stored in ``Data.tracefitmodel``. Each
        parameter's value and the matching diagonal entry of ``getCov()`` are
        appended to the log.
        """
        plt.figure()
        plt.contourf(Data.wavelength, Data.time, Data.TrA_Data, 100)
        plt.title('Pick time to fit')
        plt.xlabel('Wavelength')
        plt.ylabel('Time')
        clicked = np.array(ginput(1))
        plt.show()
        plt.close()

        picked_time = clicked[:, 1]
        row = np.abs(Data.time - picked_time).argmin()
        # fit_data blocks until the GUI closes. For a non-blocking window, pass
        # a model instance and call Data.tracefitmodel.edit_traits() instead.
        Data.tracefitmodel = fitgui.fit_data(Data.wavelength,
                                             Data.TrA_Data[row, :],
                                             autoupdate=False)

        fitted = Data.tracefitmodel
        errors = fitted.getCov().diagonal()
        names = fitted.params
        values = fitted.parvals

        self.log = (
            '%s\n---\nFitted parameters at time %s \nFitting parameters' %
            (self.log, picked_time))
        for i, value in enumerate(values):
            self.log = ('%s\n---\n%s = %s +- %s' %
                        (self.log, names[i], value, errors[i]))

    def _mcmc_fired(self):
        """Run the MCMC dialog for the current fit parameters and log a summary.

        The dialog is opened live-modal with one ``mcmc.Params`` per parameter
        of ``Data.tracefitmodel``. Afterwards ``Data.mcmc['MAP']`` and the mean
        and standard deviation of each parameter from
        ``Data.mcmc['MCMC'].stats()`` are appended to the log.
        """
        mcmc_app = mcmc.MCMC_1(parameters=[
            mcmc.Params(name=i) for i in Data.tracefitmodel.params
        ])
        mcmc_app.edit_traits(kind='livemodal')
        # Kept from the original: a fresh, empty MCMC_1 replaces the local
        # reference after the dialog closes (its purpose isn't visible here).
        mcmc_app = mcmc.MCMC_1(parameters=[])
        self.log = (
            '%s\n---\n---MCMC sampler summary (pymc)---\nBayesian Information Criterion = %s'
            % (self.log, Data.mcmc['MAP']))
        # Compute the sampler statistics once rather than twice per parameter.
        stats = Data.mcmc['MCMC'].stats()
        for name in Data.tracefitmodel.params:
            self.log = ('%s\n%s,mean %s,stdev %s' %
                        (self.log, name, stats[name]['mean'],
                         stats[name]['standard deviation']))

    def _Global_fired(self):
        """Open the global-fit dialog seeded with the current fit parameters."""
        seeds = [glo.Params(name=name) for name in Data.tracefitmodel.params]
        global_app = glo.Global(parameters=seeds)
        global_app.edit_traits(kind='livemodal')
        # As before, an empty Global replaces the local reference afterwards.
        global_app = glo.Global(parameters=[])

    def _SVD_fired(self):
        """Plot the first four SVD components of the currently zoomed region.

        The region is taken from the current matplotlib axis limits, mapped
        to the nearest indices (upper bounds exclusive).

        The first figure is a 3x4 grid:
        - top row: left singular vectors ``U[:, k]`` against time;
        - middle row: right singular vectors ``V_T[k, :]`` against wavelength,
          titled with ``s[k]``;
        - bottom row: rank-1 contributions ``s[k] * outer(U[:, k], V_T[k, :])``.

        A second figure shows the first 10 singular values on a log scale.
        The leading five are written to the log.
        """
        n_components = 4

        xmin, xmax = plt.xlim()
        ymin, ymax = plt.ylim()

        wl_lo = (np.abs(Data.wavelength - xmin)).argmin()
        wl_hi = (np.abs(Data.wavelength - xmax)).argmin()
        t_lo = (np.abs(Data.time - ymin)).argmin()
        t_hi = (np.abs(Data.time - ymax)).argmin()

        times = Data.time[t_lo:t_hi]
        wavelengths = Data.wavelength[wl_lo:wl_hi]
        U, s, V_T = linalg.svd(Data.TrA_Data[t_lo:t_hi, wl_lo:wl_hi])

        f = plt.figure()
        f.text(0.5,
               0.975, ("SVD %s" % (self.title)),
               horizontalalignment='center',
               verticalalignment='top')

        # Row 1: time profiles.
        for k in range(n_components):
            plt.subplot(3, n_components, k + 1)
            plt.plot(times, U[:, k])
            plt.title("%d" % (k + 1))
            plt.xlabel("time (ps)")
            plt.ylabel("abs.")

        # Row 2: spectral profiles.
        for k in range(n_components):
            plt.subplot(3, n_components, n_components + k + 1)
            plt.plot(wavelengths, V_T[k, :])
            plt.title("%s" % (s[k]))
            plt.xlabel("wavelength (nm)")
            plt.ylabel("abs.")

        # Row 3: rank-1 reconstruction of each component.
        for k in range(n_components):
            plt.subplot(3, n_components, 2 * n_components + k + 1)
            component = np.outer(U[:, k], V_T[k, :]) * s[k]
            plt.contourf(wavelengths, times, component, 50)

        plt.subplots_adjust(left=0.03,
                            bottom=0.05,
                            right=0.99,
                            top=0.94,
                            wspace=0.2,
                            hspace=0.2)
        plt.show()

        plt.figure()
        # s[0:9] held only nine values despite the title; plot ten.
        plt.semilogy(s[:10], '*')
        plt.title("First 10 singular values")
        plt.show()

        self.log = "%s\nFirst 5 singular values %s in range wavelength %s to %s, time %s to %s" % (
            self.log, s[0:5], xmin, xmax, ymin, ymax)

    @staticmethod
    def _efa_profiles(matrix, singvals):
        """Return (forward, backward) evolving-factor profiles along axis 0.

        ``forward[i]`` holds the leading ``singvals`` singular values of
        ``matrix[:i]`` for ``singvals <= i < rows``. ``backward[j]`` holds
        those of ``matrix[j:]`` for ``0 <= j <= rows - singvals``. All other
        rows stay zero.
        """
        rows = matrix.shape[0]
        forward = np.zeros((rows, singvals))
        backward = np.zeros((rows, singvals))
        # Start at `singvals` rows so each window yields enough singular values.
        for i in range(singvals, rows):
            forward[i, :] = linalg.svdvals(matrix[:i, :])[:singvals]
        # The old backward loop stopped at j == 1, leaving backward[0] (the
        # whole-matrix window) at zero even though it is plotted.
        for j in range(rows - singvals + 1):
            backward[j, :] = linalg.svdvals(matrix[j:, :])[:singvals]
        return forward, backward

    def _EFA_fired(self):
        """Plot evolving factor analysis along time and along wavelength.

        For each axis, forward profiles (growing window from the start, blue)
        and backward profiles (shrinking window to the end, red) of the
        leading singular values are drawn on a log scale. Wavelength EFA uses
        the transposed data, which has the same singular values as the
        original column windows.
        """
        singvals = 3  # number of singular values to track

        # Time.
        rows = Data.TrA_Data.shape[0]
        forward_r, backward_r = self._efa_profiles(Data.TrA_Data, singvals)
        plt.figure()
        plt.semilogy(Data.time[singvals:], forward_r[singvals:, :], 'b',
                     Data.time[:(rows - singvals)],
                     backward_r[:(rows - singvals), :], 'r')
        plt.title("%s EFA time" % (self.title))
        plt.xlabel("Time (ps)")
        plt.ylabel("Log(EV)")
        plt.show()

        # Wavelength.
        cols = Data.TrA_Data.shape[1]
        forward_c, backward_c = self._efa_profiles(Data.TrA_Data.T, singvals)
        plt.figure()
        plt.semilogy(Data.wavelength[singvals:], forward_c[singvals:, :], 'b',
                     Data.wavelength[:cols - singvals],
                     backward_c[:cols - singvals, :], 'r')
        plt.title("%s EFA wavelength" % (self.title))
        plt.xlabel("Wavelength (nm)")
        plt.ylabel("Log(EV)")
        plt.show()

    def _Multiple_Trace_fired(self):
        """Store the kinetic traces between two clicked wavelengths.

        Both clicks are mapped to the nearest wavelength index. Every column
        from the lower to the higher index, both included, is stored
        transposed (one trace per row) in ``Data.Traces``.
        """
        self.Traces_num = 0
        Data.Traces = 0

        plt.figure(figsize=(15, 10))
        plt.contourf(Data.wavelength, Data.time, Data.TrA_Data, 100)
        plt.title('Pick between wavelength to fit (left to right)')
        plt.xlabel('Wavelength')
        plt.ylabel('Time')
        fittingto = np.array(ginput(2))
        plt.show()
        plt.close()

        if len(fittingto) < 2:
            # Fewer than two points picked: nothing to store.
            return

        # Sort so a right-to-left pick no longer yields an empty slice. Include
        # the right-hand column, as the delete handlers do.
        left, right = sorted((
            np.abs(Data.wavelength - fittingto[0, 0]).argmin(),
            np.abs(Data.wavelength - fittingto[1, 0]).argmin()))

        Data.Traces = Data.TrA_Data[:, left:right + 1].transpose()

        self.log = '%s\n\n%s Traces saved from %s to %s' % (
            self.log, Data.Traces.shape[0], fittingto[0, 0], fittingto[1, 0])

    def _Plot_3D_fired(self):
        """Render the currently zoomed data region in the embedded Mayavi scene.

        The region comes from the current matplotlib axis limits (nearest
        indices, upper bounds exclusive) and is stored in ``Data.Three_d*``.
        The scattered (time, wavelength, value) samples are linearly
        interpolated (``griddata``, fill 0) onto a regular mesh. The mesh uses:
        - the region's wavelength count for the wavelength axis;
        - the time step just after time zero for the time axis.
        The mesh is drawn as a surface warped by ``-z_height * 100``, plus an
        image, colorbar and axes.
        """
        xmin, xmax = plt.xlim()
        ymin, ymax = plt.ylim()

        index_wavelength_left = (np.abs(Data.wavelength - xmin)).argmin()
        index_wavelength_right = (np.abs(Data.wavelength - xmax)).argmin()

        index_time_left = (np.abs(Data.time - ymin)).argmin()
        index_time_right = (np.abs(Data.time - ymax)).argmin()

        Data.Three_d = Data.TrA_Data[
            index_time_left:index_time_right,
            index_wavelength_left:index_wavelength_right]
        Data.Three_d_wavelength = Data.wavelength[
            index_wavelength_left:index_wavelength_right]
        Data.Three_d_time = Data.time[index_time_left:index_time_right]

        self.scene.mlab.clf()

        # Time step just after time zero, used as the mesh spacing in time.
        zero_idx = np.abs(Data.time).argmin()
        y_step = Data.time[zero_idx + 1] - Data.time[zero_idx]

        x = np.linspace(Data.Three_d_wavelength[0],
                        Data.Three_d_wavelength[-1],
                        len(Data.Three_d_wavelength))
        y = np.arange(Data.Three_d_time[0], Data.Three_d_time[-1], y_step)
        xi, yi = np.meshgrid(x, y)

        # Scattered samples as rows (time, wavelength, value), laid out
        # wavelength by wavelength. This vectorised form replaces an O(n^2)
        # hstack loop and yields the same array. The stray Python-2-only
        # debug `print y.shape` is gone.
        n_t = len(Data.Three_d_time)
        n_wl = len(Data.Three_d_wavelength)
        Data.TrA_Data_gridded = np.vstack((
            np.tile(Data.Three_d_time, n_wl),
            np.repeat(Data.Three_d_wavelength, n_t),
            Data.Three_d.ravel(order='F'),
        ))

        zi = interpolate.griddata(
            (Data.TrA_Data_gridded[1, :], Data.TrA_Data_gridded[0, :]),
            Data.TrA_Data_gridded[2, :], (xi, yi),
            method='linear',
            fill_value=0)

        # Alternative: plot the raw matrix directly with
        # self.scene.mlab.surf(Data.time, Data.wavelength, Data.TrA_Data,
        #                      warp_scale=-self.z_height*100)
        # The gridded version below gives the correct view.
        self.scene.mlab.surf(yi, xi, zi, warp_scale=-self.z_height * 100)
        # NOTE(review): mlab.imshow is documented as imshow(s, ...); confirm
        # that passing (yi, xi, zi) does what is intended.
        self.scene.mlab.imshow(yi, xi, zi)
        self.scene.mlab.colorbar(orientation="vertical")
        self.scene.mlab.axes(nb_labels=5, )
        self.scene.mlab.ylabel("wavelength (nm)")
        self.scene.mlab.xlabel("time (ps)")

    def _z_height_changed(self):
        """Rebuild the 3D plot whenever the z-height slider moves.

        The whole plot is redrawn. Updating only the warp scale in place would
        be cheaper but is not implemented.
        """
        return self._Plot_3D_fired()

    def _Plot_2D_fired(self):
        """Show the full data set as a 200-level filled contour plot."""
        plt.figure()
        plt.contourf(Data.wavelength, Data.time, Data.TrA_Data, 200)
        plt.title(self.title)
        plt.xlabel('Wavelength (nm)')
        plt.ylabel('Times (ps)')
        plt.colorbar()
        plt.show()

    def _Plot_log_fired(self):
        """Show the data as a contour plot with a symmetric-log time axis."""
        # NOTE(review): basey/linthreshy/subsy are matplotlib's pre-3.3 keyword
        # names, and linthreshy is documented as a positive scalar rather than
        # a tuple. Confirm against the matplotlib version in use.
        symlog_options = dict(basey=10,
                              linthreshy=(-100, -0.1),
                              subsy=[0, 1, 2, 3, 4])
        plt.figure()
        plt.contourf(Data.wavelength, Data.time, Data.TrA_Data, 100)
        plt.xlabel('Wavelength (nm)')
        plt.ylabel('Times (ps)')
        plt.title(self.title)
        plt.colorbar()
        plt.yscale('symlog', **symlog_options)
        plt.show()

    def _Plot_Traces_fired(self):
        """Plot every stored trace in ``Data.Traces`` against ``Data.time``."""
        plt.figure(figsize=(15, 10))
        traces_as_columns = Data.Traces.transpose()
        plt.plot(Data.time, traces_as_columns)
        plt.title("%s Traces" % (self.title))
        plt.xlabel('Time')
        plt.ylabel('Abs')
        plt.show()

    def _Kinetic_Trace_fired(self):
        """Plot the kinetic trace at the wavelength nearest to a user click."""
        plt.figure()
        plt.contourf(Data.wavelength, Data.time, Data.TrA_Data, 100)
        plt.title('Pick wavelength')
        plt.xlabel('Wavelength')
        plt.ylabel('Time')
        click = np.array(ginput(1))
        plt.show()
        plt.close()

        column = np.abs(Data.wavelength - click[:, 0]).argmin()

        plt.figure(figsize=(20, 12))
        plt.plot(Data.time, Data.TrA_Data[:, column])
        plt.title("%s %s" % (self.title, Data.wavelength[column]))
        plt.xlabel('Time')
        plt.ylabel('Abs')
        plt.show()

    def _Spectra_fired(self):
        """Plot the spectrum at the time nearest to a user click."""
        plt.figure()
        plt.contourf(Data.wavelength, Data.time, Data.TrA_Data, 100)
        plt.title('Pick time')
        plt.xlabel('Wavelength')
        plt.ylabel('Time')
        click = np.array(ginput(1))
        plt.show()
        plt.close()

        row = np.abs(Data.time - click[:, 1]).argmin()

        plt.figure()
        plt.plot(Data.wavelength, Data.TrA_Data[row, :])
        plt.title("%s %s" % (self.title, Data.time[row]))
        plt.xlabel('Wavelength')
        plt.ylabel('Abs')
        plt.show()

    def _multiple_plots_fired(self):
        """Plot ten band-averaged kinetics, then ten band-averaged spectra.

        The region comes from the current matplotlib axis limits (nearest
        indices, upper bounds exclusive).

        - Kinetics: the wavelength range is cut into 10 consecutive bands of
          ``int(width / 10)`` columns (any remainder is ignored). Each band's
          mean trace is plotted against time.
        - Spectra: the time range is banded the same way, and each band's
          mean spectrum is plotted against wavelength.

        Legends show each band's mean wavelength or time, rounded to 0.1.
        """
        xmin, xmax = plt.xlim()
        ymin, ymax = plt.ylim()

        wl_lo = np.abs(Data.wavelength - xmin).argmin()
        wl_hi = np.abs(Data.wavelength - xmax).argmin()
        t_lo = np.abs(Data.time - ymin).argmin()
        t_hi = np.abs(Data.time - ymax).argmin()

        # --- kinetics averaged over 10 wavelength bands ---
        indexwave = int((wl_hi - wl_lo) / 10)
        n_times = Data.time[t_lo:t_hi].shape[0]
        band_traces = np.ones([n_times, 10])
        time_cols = np.ones([n_times, 10])
        wavelengthvals = np.ones(10)

        for band in range(10):
            start = wl_lo + band * indexwave
            stop = start + indexwave
            band_traces[:, band] = np.average(
                Data.TrA_Data[t_lo:t_hi, start:stop], axis=1)
            time_cols[:, band] = Data.time[t_lo:t_hi]
            wavelengthvals[band] = round(
                np.average(Data.wavelength[start:stop]), 1)

        plt.figure()
        jet = plt.cm.jet
        # NOTE(review): Axes.set_color_cycle was deprecated in matplotlib 1.5
        # in favour of set_prop_cycle; confirm the target version.
        plt.gca().set_color_cycle([jet(c) for c in np.linspace(0, 0.9, 10)])
        plt.plot(time_cols, band_traces)
        plt.legend(wavelengthvals)
        plt.xlabel('Time (ps)')
        plt.ylabel('Abs.')
        plt.title("Averaged %s %s" % (self.title, 'Wavelengths (nm)'))
        plt.show()

        # --- spectra averaged over 10 time bands ---
        indextime = int((t_hi - t_lo) / 10)
        n_wavelengths = Data.wavelength[wl_lo:wl_hi].shape[0]
        band_spectra = np.ones([n_wavelengths, 10])
        wave_cols = np.ones([n_wavelengths, 10])
        timevals = np.ones(10)

        for band in range(10):
            start = t_lo + band * indextime
            stop = start + indextime
            band_spectra[:, band] = np.average(
                Data.TrA_Data[start:stop, wl_lo:wl_hi], axis=0)
            wave_cols[:, band] = Data.wavelength[wl_lo:wl_hi]
            timevals[band] = round(np.average(Data.time[start:stop]), 1)

        plt.figure()
        jet = plt.cm.jet
        plt.gca().set_color_cycle([jet(c) for c in np.linspace(0, 0.9, 10)])
        plt.plot(wave_cols, band_spectra)
        plt.legend(timevals)
        plt.title("Averaged %s %s" % (self.title, 'Times (ps)'))
        plt.xlabel('Wavelength (nm)')
        plt.ylabel('Abs.')
        plt.show()

    def _Normalise_fired(self):
        """Plot min-max normalised spectra and kinetic traces of the zoomed region.

        The x/y limits of the current matplotlib axes are mapped onto the
        nearest indices of ``Data.wavelength`` and ``Data.time``.  Two new
        figures are then opened:

        * ten spectra, taken at evenly spaced time indices of the window;
        * ten kinetic traces, taken at evenly spaced wavelength indices.

        Each curve is rescaled independently to the range [0, 1].
        """
        xmin, xmax = plt.xlim()
        ymin, ymax = plt.ylim()

        index_wavelength_left = (np.abs(Data.wavelength - xmin)).argmin()
        index_wavelength_right = (np.abs(Data.wavelength - xmax)).argmin()

        index_time_left = (np.abs(Data.time - ymin)).argmin()
        index_time_right = (np.abs(Data.time - ymax)).argmin()

        n_curves = 10
        wavelengths = Data.wavelength[index_wavelength_left:index_wavelength_right]
        times = Data.time[index_time_left:index_time_right]

        # Spectra: one row of TrA_Data every `indextime` time points.
        indextime = int((index_time_right - index_time_left) / n_curves)
        wavevec = np.ones([wavelengths.shape[0], n_curves])
        wave = np.ones([wavelengths.shape[0], n_curves])
        timevals = np.ones(n_curves)
        for i in range(n_curves):
            row = index_time_left + i * indextime
            wavevec[:, i] = self._min_max_normalise(
                Data.TrA_Data[row, index_wavelength_left:index_wavelength_right])
            wave[:, i] = wavelengths
            timevals[i] = Data.time[row]

        plt.figure()
        colormap = plt.cm.jet
        # set_prop_cycle replaces the removed Axes.set_color_cycle.
        plt.gca().set_prop_cycle(
            color=[colormap(c) for c in np.linspace(0, 0.9, n_curves)])
        plt.plot(wave, wavevec)
        # Retained call: also makes 'jet' pyplot's global default image
        # colormap.  It must run after plt.figure() so it targets this
        # line-only figure rather than the user's zoomed figure.
        plt.jet()
        plt.legend(timevals)
        plt.title("Normalised %s %s" % (self.title, 'Times (ps)'))
        plt.xlabel('Wavelength (nm)')
        plt.ylabel('Abs.')
        plt.show()

        # Kinetic traces: one column of TrA_Data every `indexwave` wavelengths.
        indexwave = int((index_wavelength_right - index_wavelength_left) / n_curves)
        timevec = np.ones([times.shape[0], n_curves])
        time_axis = np.ones([times.shape[0], n_curves])
        wavelengthvals = np.ones(n_curves)
        for i in range(n_curves):
            col = index_wavelength_left + i * indexwave
            timevec[:, i] = self._min_max_normalise(
                Data.TrA_Data[index_time_left:index_time_right, col])
            time_axis[:, i] = times
            wavelengthvals[i] = Data.wavelength[col]

        plt.figure()
        colormap = plt.cm.jet
        plt.gca().set_prop_cycle(
            color=[colormap(c) for c in np.linspace(0, 0.9, n_curves)])
        plt.plot(time_axis, timevec)
        plt.legend(wavelengthvals)
        plt.xlabel('Time (ps)')
        plt.ylabel('Abs.')
        plt.title("Normalised %s %s" % (self.title, 'Wavelengths (nm)'))
        plt.show()

    @staticmethod
    def _min_max_normalise(values):
        """Return *values* linearly rescaled so its minimum is 0 and maximum is 1.

        An empty input is returned unchanged (as a float array).  A constant
        input has no range to divide by, so it is returned as zeros instead of
        the NaNs that ``(v - min) / (max - min)`` would produce.
        """
        values = np.asarray(values, dtype=float)
        if values.size == 0:
            return values
        low = np.min(values)
        span = np.max(values) - low
        if span == 0:
            return np.zeros_like(values)
        return (values - low) / span

    def _Trace_Igor_fired(self):
        """Export the selected traces to ``Traces.txt`` and open them in Igor Pro.

        The file is written to the directory of ``self.TrA_Raw_file``.  It has
        one comma-separated line per time point: ``time,trace_0,trace_1,...``.
        Igor Pro is then driven over COM to load the file and open its Global
        Fit panel.  This needs the pywin32 package and a running Igor
        instance.  Failures are reported through ``self.log``, not raised.
        """
        directory = os.path.dirname(self.TrA_Raw_file)
        try:
            self._write_traces(os.path.join(directory, "Traces.txt"))
        except Exception:
            # NOTE(review): this also catches I/O errors; the message assumes
            # the usual cause is an incomplete trace selection.
            self.log = '%s\n\nPlease select multiple traces' % (self.log)
            return

        try:
            import win32com.client  # Communicates with Igor; needs the pywin32 library

            # Sends traces to Igor and opens up the Global fitting GUI in Igor.
            igor = win32com.client.Dispatch("IgorPro.Application")

            # Load into Igor using LoadWave(/A=Traces/J/P=pathname);
            # /J specifies a delimited text file.
            igor.Execute('NewPath pathName, "%s"' % (directory))
            igor.Execute('Loadwave/J/P=pathName "Traces.txt"')
            igor.Execute('Rename wave0,timeval')

            # Run the global fitting GUI in Igor.
            igor.Execute('WM_NewGlobalFit1#InitNewGlobalFitPanel()')
            igor.clear()
        except Exception:
            # pywin32 missing, Igor not running, or a COM call failed.
            # Traces.txt has already been written at this point.
            self.log = '%s\n\npywin32 not installed or Igor not open. Saved traces into directory' % (
                self.log)

    @staticmethod
    def _write_traces(path):
        """Write ``Data.time`` and ``Data.Traces`` to *path* as CSV lines.

        Line *i* is ``time[i],Traces[0, i],Traces[1, i],...``.
        """
        with open(path, 'w') as f:
            for i in range(len(Data.time)):
                fields = ["%s" % (Data.time[i])]
                fields.extend("%s" % (Data.Traces[j, i])
                              for j in range(len(Data.Traces)))
                f.write(",".join(fields) + "\n")

    def _Save_Glo_fired(self):
        """Write the data set to ``Glotaran.txt`` in Glotaran's time-explicit format.

        The file goes to the directory of ``self.TrA_Raw_file``.  The layout is:

        * a short header and the ``intervalnr`` line;
        * one line listing every time point;
        * one line per wavelength with that wavelength's ``TrA_Data`` column.

        All values are space-separated.
        """
        pathname = os.path.join(os.path.dirname(self.TrA_Raw_file),
                                "Glotaran.txt")
        n_times = len(Data.time)
        # `with` guarantees the file is flushed and closed.
        with open(pathname, 'w') as f:
            f.write("#-#-#-#-#-# Made with PyTrA #-#-#-#-#-#\n")
            f.write("\n")
            f.write("Time explicit\n")
            f.write("intervalnr %d\n" % (n_times))
            f.write("".join(" %s" % (Data.time[i]) for i in range(n_times)))
            f.write("\n")
            for i in range(len(Data.wavelength)):
                f.write("%s" % (Data.wavelength[i]))
                f.write("".join(" %s" % (Data.TrA_Data[j, i])
                                for j in range(n_times)))
                f.write("\n")

        self.log = '%s \nSaved Glotaran file to TrA data file directory' % (
            self.log)

    def _Save_csv_fired(self):
        """Save ``Data.TrA_Data`` as a CSV matrix beside the raw data file.

        The file is named ``Saved<mm-dd-yy><title>.csv``.  The first line is
        ``0`` followed by the time points.  Every other line is a wavelength
        followed by the ``TrA_Data`` values at that wavelength.
        """
        now = date.today()
        filename = "Saved%s%s.csv" % (now.strftime("%m-%d-%y"), self.title)
        pathname = os.path.join(os.path.dirname(self.TrA_Raw_file), filename)
        n_times = len(Data.time)
        # `with` guarantees the file is flushed and closed.
        with open(pathname, 'w') as f:
            f.write("0")
            f.write("".join(",%s" % (Data.time[i]) for i in range(n_times)))
            f.write("\n")
            for i in range(len(Data.wavelength)):
                f.write("%s" % (Data.wavelength[i]))
                f.write("".join(",%s" % (Data.TrA_Data[j, i])
                                for j in range(n_times)))
                f.write("\n")

        self.log = '%s\n\nSaved to TrA data file directory' % (self.log)

    def _Save_log_fired(self):
        """Write ``self.log`` to ``log<mm-dd-yy>_<title>.log`` beside the raw data file."""
        now = date.today()
        directory = os.path.dirname(self.TrA_Raw_file)
        filename = "log%s_%s.log" % (now.strftime("%m-%d-%y"), self.title)
        # `with` guarantees the file is flushed and closed.
        with open(os.path.join(directory, filename), 'w') as f:
            f.write("%s" % (self.log))

        self.log = '%s\n\nSaved log file to %s' % (self.log, directory)

    def _Help_fired(self):
        """Handle the ``Help`` event by opening a ``Help`` Traits UI window."""
        Help().edit_traits()
示例#28
0
class ApplicationMain(HasTraits):
    """Main window: display tabs on the left, two calculation tabs on the right."""

    scene = Instance(MlabSceneModel, ())

    firstcalc = Instance(FirstCalc)
    secondcalc = Instance(SecondCalc)
    display = Instance(Figure)
    markercolor = ColorTrait
    markerstyle = Enum(['+', ',', '*', 's', 'p', 'd', 'o'])
    markersize = Range(0, 10, 2)

    # Left side: a "Display" tab (matplotlib figure plus marker controls)
    # and a "Mayavi" tab with the 3-D scene.
    left_panel = Tabbed(
        Group(
            VGroup(
                Item('display',
                     editor=MPLFigureEditor(),
                     show_label=False,
                     resizable=True)),
            HGroup(
                Item(name='markercolor',
                     label="Color",
                     style="custom",
                     springy=True),
                Item(name='markerstyle', label="Marker", springy=True),
                Item(name='markersize', label="Size", springy=True)),
            label='Display'),
        Item(name='scene',
             label='Mayavi',
             editor=SceneEditor(scene_class=MayaviScene)),
        show_labels=False)

    # Right side: one tab per calculation panel.
    right_panel = Tabbed(
        Item('firstcalc', style='custom', label='First Tab', show_label=False),
        Item('secondcalc',
             style='custom',
             label='Second Tab',
             show_label=False))

    view = View(HSplit(left_panel, right_panel),
                width=1280,
                height=750,
                resizable=True,
                title="My First Python GUI Interface")

    def _display_default(self):
        """Initialise the display: a single white 0..1 x 0..1 axes."""
        figure = Figure()
        ax = figure.add_subplot(111)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

        # Set matplotlib canvas colour to be white
        figure.patch.set_facecolor('w')
        return figure

    def _firstcalc_default(self):
        # Give the panel a reference to this main object (self).
        return FirstCalc(self)

    def _secondcalc_default(self):
        # Give the panel a reference to this main object (self).
        return SecondCalc(self)

    # The marker handlers restyle `display_points`, which is not created in
    # this class (presumably by a calculation panel -- confirm), and redraw
    # the canvas through wx.

    def _markercolor_changed(self):
        if hasattr(self, 'display_points'):
            self.display_points.set_color(self.markercolor)
            self.display_points.set_markeredgecolor(self.markercolor)
            wx.CallAfter(self.display.canvas.draw)

    def _markerstyle_changed(self):
        if hasattr(self, 'display_points'):
            self.display_points.set_marker(self.markerstyle)
            wx.CallAfter(self.display.canvas.draw)

    def _markersize_changed(self):
        if hasattr(self, 'display_points'):
            self.display_points.set_markersize(self.markersize)
            wx.CallAfter(self.display.canvas.draw)

    def __init__(self, **kwargs):
        """Set the marker defaults, then let HasTraits apply *kwargs*.

        The defaults are assigned first so that caller-supplied traits
        (e.g. ``markercolor='red'``) take precedence over them.
        """
        self.markercolor = 'blue'
        self.markersize = 2
        self.markerstyle = 'o'
        HasTraits.__init__(self, **kwargs)
示例#29
0
class Kit2FiffFrame(HasTraits):
    """GUI for interpolating between two KIT marker files."""

    model = Instance(Kit2FiffModel)
    scene = Instance(MlabSceneModel, ())
    headview = Instance(HeadViewController)
    marker_panel = Instance(CombineMarkersPanel)
    kit2fiff_panel = Instance(Kit2FiffPanel)

    view = View(
        HGroup(
            VGroup(Item('marker_panel', style='custom'), show_labels=False),
            VGroup(
                Item('scene',
                     editor=SceneEditor(scene_class=MayaviScene),
                     dock='vertical',
                     show_label=False),
                VGroup(Item('headview', style='custom'), show_labels=False),
            ),
            VGroup(Item('kit2fiff_panel', style='custom'), show_labels=False),
            show_labels=False,
        ),
        handler=Kit2FiffFrameHandler(),
        height=700,
        resizable=True,
        buttons=NoButtons)

    def __init__(self, *args, **kwargs):  # noqa: D102
        logger.debug("Initializing Kit2fiff-GUI with %s backend",
                     ETSConfig.toolkit)
        HasTraits.__init__(self, *args, **kwargs)

    # can't be static method due to Traits
    def _model_default(self):
        """Build the model from the saved configuration.

        Reads the ``MNE_KIT2FIFF_STIM_CHANNEL*`` config keys.  Invalid values
        trigger a warning and are replaced by a default:

        * threshold: ``1.``
        * slope: ``'-'``
        * coding: ``'>'``
        """
        config = get_config(home_dir=os.environ.get('_MNE_FAKE_HOME_DIR'))

        stim_threshold = 1.
        if 'MNE_KIT2FIFF_STIM_CHANNEL_THRESHOLD' in config:
            try:
                stim_threshold = float(
                    config['MNE_KIT2FIFF_STIM_CHANNEL_THRESHOLD'])
            except (TypeError, ValueError):
                warn("Ignoring invalid configuration value for "
                     "MNE_KIT2FIFF_STIM_CHANNEL_THRESHOLD: %r (expected "
                     "float)" %
                     (config['MNE_KIT2FIFF_STIM_CHANNEL_THRESHOLD'], ))

        stim_slope = config.get('MNE_KIT2FIFF_STIM_CHANNEL_SLOPE', '-')
        # Exact membership: a substring test against '+-' would also accept
        # '' and '+-', and raise TypeError for non-string values.
        if stim_slope not in ('+', '-'):
            warn("Ignoring invalid configuration value for "
                 "MNE_KIT2FIFF_STIM_CHANNEL_SLOPE: %s (expected + or -)" %
                 stim_slope)
            stim_slope = '-'

        stim_coding = config.get('MNE_KIT2FIFF_STIM_CHANNEL_CODING', '>')
        if stim_coding not in ('<', '>', 'channel'):
            warn("Ignoring invalid configuration value for "
                 "MNE_KIT2FIFF_STIM_CHANNEL_CODING: %s (expected <, > or "
                 "channel)" % stim_coding)
            stim_coding = '>'

        return Kit2FiffModel(stim_chs=config.get('MNE_KIT2FIFF_STIM_CHANNELS',
                                                 ''),
                             stim_coding=stim_coding,
                             stim_slope=stim_slope,
                             stim_threshold=stim_threshold,
                             show_gui=True)

    def _headview_default(self):
        return HeadViewController(scene=self.scene, scale=160, system='RAS')

    def _kit2fiff_panel_default(self):
        return Kit2FiffPanel(scene=self.scene, model=self.model)

    def _marker_panel_default(self):
        return CombineMarkersPanel(scene=self.scene,
                                   model=self.model.markers,
                                   trans=als_ras_trans)

    def save_config(self, home_dir=None):
        """Write configuration values."""
        set_config('MNE_KIT2FIFF_STIM_CHANNELS',
                   self.model.stim_chs,
                   home_dir,
                   set_env=False)
        set_config('MNE_KIT2FIFF_STIM_CHANNEL_CODING',
                   self.model.stim_coding,
                   home_dir,
                   set_env=False)
        set_config('MNE_KIT2FIFF_STIM_CHANNEL_SLOPE',
                   self.model.stim_slope,
                   home_dir,
                   set_env=False)
        set_config('MNE_KIT2FIFF_STIM_CHANNEL_THRESHOLD',
                   str(self.model.stim_threshold),
                   home_dir,
                   set_env=False)
示例#30
0
class Visualization(HasTraits):
    """Mayavi viewer that steps through a data set and shows network predictions.

    Each sample of ``data_set`` is drawn as a 3-D point cloud of its non-zero
    entries.  The ground-truth labels and the thresholded output of
    ``ActionNet`` are shown as text below the scene.
    """

    seq_num = Int(0, desc='Sequence number', auto_set=False, enter_set=True)
    seq_name = Str('0')
    label_name = Str(
        'gt    standing    sit    left_h_hold    right_h_hold    drink    walk    bending    clapping    phone_call    pointing'
    )
    gt = Str('')
    est = Str('')
    next_seq = Button('Next seq')
    prev_seq = Button('Prev seq')
    scene = Instance(MlabSceneModel, ())

    def __init__(self, data_set):
        # Do not forget to call the parent's __init__
        HasTraits.__init__(self)
        self.data_set = data_set

        self.net = ActionNet(512, 10)
        self.net.load_state_dict(
            torch.load('./models/action-net-300-140.pkl',
                       map_location=torch.device('cpu')))
        self.net.eval()

        # First draw: nothing has been plotted yet, so skip clearing.
        self._show_sequence(0, clear=False)

    def _show_sequence(self, seq_num, clear=True):
        """Run the network on sample *seq_num*, redraw it and update the labels."""
        out, lbl = self.estimate(seq_num)
        out = np.array_str(out)
        lbl = np.array_str(lbl)
        x, y, z, s = self.render(seq_num)
        if clear:
            self.scene.mlab.clf()
        self.plot = self.scene.mlab.points3d(x,
                                             y,
                                             z,
                                             s,
                                             colormap='hot',
                                             scale_factor=1,
                                             scale_mode='none')
        self.trait_set(seq_name=self.data_set.get_name(seq_num),
                       gt=lbl,
                       est=out)

    def key_down(self, vtk, event):
        vtk.GetKeyCode()

    def _next_seq_fired(self):
        """Advance to the next sample, stopping at the last valid index.

        Changing ``seq_num`` triggers ``update_seq_num``, which redraws.
        """
        last = len(self.data_set) - 1
        self.seq_num = min(self.seq_num + 1, last)

    def _prev_seq_fired(self):
        """Go back to the previous sample, stopping at index 0.

        Changing ``seq_num`` triggers ``update_seq_num``, which redraws.
        """
        self.seq_num = max(self.seq_num - 1, 0)

    @on_trait_change('seq_num')
    def update_seq_num(self):
        """Redraw for the new ``seq_num``, first clamping it to a valid index."""
        last = len(self.data_set) - 1
        seq_num = min(max(self.seq_num, 0), last)
        if seq_num != self.seq_num:
            # Re-enters this handler with the in-range value.
            self.seq_num = seq_num
            return
        self._show_sequence(seq_num)

    @animate(delay=100)
    def anim(self):
        for i in range(10):
            frame = self.seq[i]
            x, y, z = np.nonzero(frame)
            s = np.linspace(0, 1, num=x.shape[0])

            self.scene.mlab.clf()
            self.plot = self.scene.mlab.points3d(x,
                                                 y,
                                                 z,
                                                 s,
                                                 colormap='hot',
                                                 scale_factor=1,
                                                 scale_mode='none')
            yield

    def render(self, index):
        """Load sample *index* into ``self.seq`` and return its point cloud.

        Returns the coordinates ``x, y, z`` of the sample's non-zero entries.
        It also returns ``s``, a scalar ramp from 0 to 1 (one value per
        point) used for colouring.
        """
        seq, _ = self.data_set[index]
        self.seq = seq.numpy()
        x, y, z = np.nonzero(self.seq)
        s = np.linspace(0, 1, num=x.shape[0])
        return x, y, z, s

    def _LeftKeyPressed(self, event):
        """Key binding for 'Left': go to the previous sample."""
        self._prev_seq_fired()

    def _RightKeyPressed(self, event):
        """Key binding for 'Right': go to the next sample."""
        self._next_seq_fired()

    def estimate(self, index):
        """Run ``self.net`` on sample *index* and return ``(prediction, labels)``.

        The prediction is the flattened network output, thresholded at 0.99
        into 0/1 ints.  The labels are returned as a NumPy array.
        """
        data, lbls = self.data_set[index]
        data = data.view(1, 61, 61, 85)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output = self.net(data)

        output = output.detach().numpy().reshape(-1)
        output = (output > 0.99).astype(int)
        return output, lbls.detach().numpy()

    key_bindings = KeyBindings(
        KeyBinding(binding1='Left',
                   description='prev seq',
                   method_name='_LeftKeyPressed'),
        KeyBinding(binding1='Right',
                   description='next seq',
                   method_name='_RightKeyPressed'))

    # the layout of the dialog created
    view = View(Item('scene',
                     editor=SceneEditor(scene_class=MayaviScene),
                     height=1000,
                     width=1200,
                     show_label=False),
                HGroup(Item('seq_num'), Item('prev_seq'), Item('next_seq')),
                HGroup(Item('seq_name', style='readonly')),
                HGroup(Item('label_name', style='readonly')),
                HGroup(Item('gt', style='readonly')),
                HGroup(Item('est', style='readonly')),
                key_bindings=key_bindings,
                resizable=True)