Ejemplo n.º 1
0
class WindowLauncher(_traits.HasTraits):
    """Trait container describing a window to be launched.

    Only trait declarations live here; the class defines no behavior of
    its own. The exact use of each field is up to the callers (names
    below are the only hint - confirm against the launching code).
    """
    # Name of the node the window belongs to (string trait).
    node_name = _traits.Str()
    # Callable trait; presumably builds the view - TODO confirm with callers.
    view_creator = _traits.Callable()
    # Weak reference, so the launcher does not keep its owner alive.
    owner_ref = _traits.WeakRef()
    # Name of the associated loop (string trait).
    loop_name = _traits.Str()
    # Presumably the arguments handed to view_creator - verify.
    # FIXME: Rename to creator_parms
    func_parms = _traits.Tuple()
Ejemplo n.º 2
0
def vtk_color_trait(default, **metadata):
    """Return a trait for an RGB color whose components lie in [0, 1].

    Parameters
    ----------
    default : sequence of floats
        Default color (three components for a normal RGB color). If the
        first component is -1.0, a compound trait is built that accepts
        either the literal ``default`` tuple or any 3-tuple of values in
        [0, 1].
    **metadata
        Extra metadata forwarded to ``traits.Trait``.
    """
    unit_range = traits.Range
    if default[0] == -1.0:
        # The -1.0 default occurs for vtkTextProperty's color trait; work
        # around it by also accepting the literal default tuple.
        literal = traits.Tuple(*default)
        any_color = traits.Tuple(unit_range(0.0, 1.0),
                                 unit_range(0.0, 1.0),
                                 unit_range(0.0, 1.0),
                                 editor=RGBColorEditor)
        return traits.Trait(default, literal, any_color, **metadata)

    channels = traits.Tuple(unit_range(0.0, 1.0, default[0]),
                            unit_range(0.0, 1.0, default[1]),
                            unit_range(0.0, 1.0, default[2]))
    return traits.Trait(channels, editor=RGBColorEditor, **metadata)
Ejemplo n.º 3
0
class Model(BMCSTreeNode):
    '''Contains the primary unknown variables U_k
    '''
    # Types of the collaborating components. The instances below are
    # created lazily from these types and rebuilt when a type changes.
    tstep_type = tr.Type(TStep)
    tloop_type = tr.Type(TLoop)
    hist_type = tr.Type(Hist)
    sim_type = tr.Type(SimControler)

    tstep = tr.Property(depends_on='tstep_type')

    @tr.cached_property
    def _get_tstep(self):
        return self.tstep_type(model=self)

    tloop = tr.Property(depends_on='tloop_type')

    @tr.cached_property
    def _get_tloop(self):
        # The time loop is bound to the time stepper, not to the model.
        return self.tloop_type(tstep=self.tstep)

    hist = tr.Property(depends_on='hist_type')

    @tr.cached_property
    def _get_hist(self):
        return self.hist_type(model=self)

    sim = tr.Property(depends_on='sim_type')

    @tr.cached_property
    def _get_sim(self):
        return self.sim_type(model=self)

    bc = tr.List(tr.Callable)

    # Shape used by init_state to allocate U_k.
    U_shape = tr.Tuple(1,)

    def init_state(self):
        '''Reset the state vectors and the history.

        U_k is set to zeros of shape U_shape, U_n becomes an independent
        copy of it, and hist.init_state() is called.
        '''
        self.U_k = np.zeros(self.U_shape, dtype=np.float64)
        # Copy the freshly zeroed trial state. The previous
        # np.copy(self.U_n) merely copied U_n onto itself, leaving it stale.
        self.U_n = np.copy(self.U_k)
        self.hist.init_state()

    def get_plot_sheet(self):
        # The base model provides no plot sheet.
        return

    # np.float64 replaces the np.float / np.float_ aliases, which were
    # removed in NumPy 1.24 / 2.0 (same dtype).
    U_k = tr.Array(np.float64, TRIAL_STATE=True)
    U_n = tr.Array(np.float64, FUND_STATE=True)

    S = tr.Dict(tr.Str, tr.Array(np.float64), STATE=True)

    F = tr.Property(depends_on='+TRIAL_STATE,+INPUT')

    @tr.cached_property
    def _get_F(self):
        # Abstract: subclasses must provide F. NotImplemented is not an
        # exception, so raising it would produce a TypeError instead.
        raise NotImplementedError

    d_F_U = tr.Property(depends_on='+TRIAL_STATE,+INPUT')

    @tr.cached_property
    def _get_d_F_U(self):
        # Abstract: subclasses must provide d_F_U.
        raise NotImplementedError
Ejemplo n.º 4
0
class BMCSVizSheet(ROutputSection):
    '''Visualization sheet
    - controls the time displayed
    - contains several visualization adapters.
    This class could be called BMCSTV - for watching the time
    dependent response. It can have several channels - in 2D and 3D
    '''
    def __init__(self, *args, **kw):
        super(BMCSVizSheet, self).__init__(*args, **kw)
        # Replot whenever entries are added to or removed from viz2d_list.
        self.on_trait_change(self.viz2d_list_items_changed, 'viz2d_list_items')

    name = Str

    hist = Instance(IHist)

    min = Float(0.0)
    '''Simulation start is always 0.0
    '''
    max = Float(1.0)
    '''Upper range limit of the current simulator.
    This range is determined by the time-loop range
    of the model.
    '''
    vot = Float

    def _vot_default(self):
        return self.min

    def _vot_changed(self):
        # Forward the visualized time to the history object, if present.
        if self.hist:
            self.hist.vot = self.vot

    vot_slider = Range(low='min',
                       high='max',
                       step=0.01,
                       enter_set=True,
                       auto_set=False)
    '''Time line controlling the current state of the simulation.
    this value is synchronized with the control time of the
    time loop setting the tline. The vot_max = tline.max.
    The value of vot follows the value of tline.val in monitoring mode.
    By default, the monitoring mode is active with vot = tline.value.
    When sliding to a value vot < tline.value, the browser mode is activated.
    When sliding into the range vot > tline.value the monitoring mode
    is reactivated.
    '''

    def _vot_slider_default(self):
        return 0.0

    mode = Enum('monitor', 'browse')

    def _mode_changed(self):
        # Browsing requires the sheet to be online so that replot is active.
        if self.mode == 'browse':
            self.offline = False

    time = Float(0.0)

    def time_range_changed(self, max_):
        '''Set the upper limit of the time slider.'''
        self.max = max_

    def time_changed(self, time):
        '''Record the current simulation time; in monitor mode the
        visualized time and the slider follow it.'''
        self.time = time
        if self.mode == 'monitor':
            self.vot = time
            self.vot_slider = time

    def _vot_slider_changed(self):
        # Sliding below the current time switches to browse mode; sliding
        # to or past it switches back to monitor mode and clamps to time.
        if self.mode == 'browse':
            if self.vot_slider >= self.time:
                self.mode = 'monitor'
                self.vot_slider = self.time
                self.vot = self.time
            else:
                self.vot = self.vot_slider
        elif self.mode == 'monitor':
            if self.vot_slider < self.time:
                self.mode = 'browse'
                self.vot = self.vot_slider
            else:
                self.vot_slider = self.time
                self.vot = self.time

    offline = Bool(True)
    '''If the sheet is offline, the plot refresh is inactive.
    The sheet starts in offline mode and is activated once the signal
    run_started has been received. Upon run_finished the
    sheet goes directly into the offline mode again.

    If the user switches to browser mode, the vizsheet gets online
    and replotting is activated.
    '''

    running = Bool(False)

    def run_started(self):
        '''Go online, clear all plot perspectives, switch to monitor mode
        and reset the reference graph.'''
        self.running = True
        self.offline = False
        for pp in self.pp_list:
            pp.clear()
        self.mode = 'monitor'
        if self.reference_viz2d:
            ax = self.reference_axes
            ax.clear()
            self.reference_viz2d.reset(ax)

    def run_finished(self):
        '''Force a final replot, then go offline.'''
        # Setting skipped_steps to the chunk size guarantees that replot
        # does not skip this last refresh.
        self.skipped_steps = self.monitor_chunk_size
        self.replot()
        self.running = False
        self.offline = True

    monitor_chunk_size = Int(10, label='Monitor each # steps')

    skipped_steps = Int(1)

    @on_trait_change('vot,n_cols')
    def replot(self):
        '''Redraw the plot perspectives and the reference graph at vot.

        Does nothing while offline. While running in monitor mode only
        every monitor_chunk_size-th call actually redraws. The 3D pipeline
        is updated synchronously in browse mode and in a RunThread
        otherwise.
        '''
        if self.offline:
            return
        if self.running and self.mode == 'monitor' and \
                self.skipped_steps < (self.monitor_chunk_size - 1):
            self.skipped_steps += 1
            return
        for pp in self.pp_list:
            pp.replot(self.vot)

        if self.reference_viz2d:
            ax = self.reference_axes
            ax.clear()
            self.reference_viz2d.clear()
            self.reference_viz2d.plot(ax, self.vot)
        self.data_changed = True
        self.skipped_steps = 0
        if self.mode == 'browse':
            self.update_pipeline(self.vot)
        else:
            up = RunThread(self, self.vot)
            up.start()

    viz2d_list = List(Viz2D)
    '''List of visualization adaptors for 2D.
    '''
    viz2d_dict = Property

    def _get_viz2d_dict(self):
        return {viz2d.name: viz2d for viz2d in self.viz2d_list}

    viz2d_names = Property
    '''Names to be supplied to the selector of the
    reference graph.
    '''

    def _get_viz2d_names(self):
        return list(self.viz2d_dict.keys())

    viz2d_list_editor_clicked = Tuple
    viz2d_list_changed = Event

    def _viz2d_list_editor_clicked_changed(self, *args, **kw):
        # Clicking the 'visible' column leaves single-plot mode; clicking
        # any other column enters it. Otherwise just replot.
        _, column = self.viz2d_list_editor_clicked
        self.offline = False
        self.viz2d_list_changed = True
        if self.plot_mode == 'single':
            if column.name == 'visible':
                self.selected_viz2d.visible = True
                self.plot_mode = 'multiple'
            else:
                self.replot()
        elif self.plot_mode == 'multiple':
            if column.name != 'visible':
                self.plot_mode = 'single'
            else:
                self.replot()

    plot_mode = Enum('multiple', 'single')

    def _plot_mode_changed(self):
        if self.plot_mode == 'single':
            self.replot_selected_viz2d()
        elif self.plot_mode == 'multiple':
            self.replot()

    def replot_selected_viz2d(self):
        '''Hide every viz2d except the selected one and replot in a single
        column.'''
        for viz2d in self.viz2d_list:
            viz2d.visible = False
        self.selected_viz2d.visible = True
        self.n_cols = 1
        self.viz2d_list_changed = True
        self.replot()

    def viz2d_list_items_changed(self):
        self.replot()

    def get_subrecords(self):
        '''What is this good for?
        '''
        return self.viz2d_list

    export_button = Button(label='Export selected diagram')

    def plot_in_window(self):
        '''Plot the selected viz2d at vot into a new pyplot figure.'''
        fig = plt.figure(figsize=(self.fig_width, self.fig_height))
        ax = fig.add_subplot(111)
        self.selected_viz2d.plot(ax, self.vot)
        fig.show()

    def _export_button_fired(self, vot=0):
        print('in export button fired')
        Thread(target=self.plot_in_window).start()
        print('thread started')

    fig_width = Float(8.0, auto_set=False, enter_set=True)
    fig_height = Float(5.0, auto_set=False, enter_set=True)

    save_button = Button(label='Save selected diagram')

    animate_button = Button(label='Animate selected diagram')

    def _animate_button_fired(self):
        ad = AnimationDialog(sheet=self)
        ad.edit_traits()
        return

    #=========================================================================
    # Reference figure serving for orientation.
    #=========================================================================
    reference_viz2d_name = Enum('', values="viz2d_names")
    '''Current name of the reference graphs.
    '''

    def _reference_viz2d_name_changed(self):
        self.replot()

    reference_viz2d_cumulate = Bool(False, label='cumulate')
    reference_viz2d = Property(Instance(Viz2D),
                               depends_on='reference_viz2d_name')
    '''Visualization of a graph showing the time context of the
    current visualization state.
    '''

    def _get_reference_viz2d(self):
        # Without a selected name, fall back to the first viz2d (if any).
        if self.reference_viz2d_name is None:
            if len(self.viz2d_dict):
                return self.viz2d_list[0]
            else:
                return None
        return self.viz2d_dict[self.reference_viz2d_name]

    reference_figure = Instance(Figure)

    def _reference_figure_default(self):
        figure = Figure(facecolor='white')
        figure.set_tight_layout(True)
        return figure

    reference_axes = Property(List, depends_on='reference_viz2d_name')
    '''Derived axes objects reflecting the layout of plot pane
    and the individual.
    '''

    @cached_property
    def _get_reference_axes(self):
        return self.reference_figure.add_subplot(1, 1, 1)

    selected_viz2d = Instance(Viz2D)

    def _selected_viz2d_changed(self):
        if self.plot_mode == 'single':
            self.replot_selected_viz2d()

    n_cols = Range(low=1,
                   high=3,
                   value=2,
                   label='Number of columns',
                   tooltip='Defines a number of columns within the plot pane',
                   enter_set=True,
                   auto_set=False)

    figure = Instance(Figure)

    tight_layout = Bool(True)

    def _figure_default(self):
        figure = Figure(facecolor='white')
        figure.set_tight_layout(self.tight_layout)
        return figure

    visible_viz2d_list = Property(
        List,
        depends_on='viz2d_list,viz2d_list_items,n_cols,viz2d_list_changed')
    '''Derived axes objects reflecting the layout of plot pane
    and the individual.
    '''

    @cached_property
    def _get_visible_viz2d_list(self):
        return [viz2d for viz2d in self.viz2d_list if viz2d.visible]

    pp_list = List(PlotPerspective)

    selected_pp = Instance(PlotPerspective)

    xaxes = Property(List, depends_on='selected_pp')
    '''Derived axes objects reflecting the layout of plot pane
    and the individual.
    '''

    @cached_property
    def _get_xaxes(self):
        self.figure.clear()
        if self.selected_pp:
            self.selected_pp.figure = self.figure
            ad = self.selected_pp.axes
        else:
            n_fig = len(self.visible_viz2d_list)
            n_cols = self.n_cols
            # Ceiling division. It must stay an int: add_subplot rejects
            # a float number of rows, which true division (/) produced.
            n_rows = (n_fig + n_cols - 1) // n_cols
            ad = {
                viz2d: self.figure.add_subplot(n_rows, n_cols, i + 1)
                for i, viz2d in enumerate(self.visible_viz2d_list)
            }
        return ad

    data_changed = Event

    bgcolor = tr.Tuple(1.0, 1.0, 1.0)
    fgcolor = tr.Tuple(0.0, 0.0, 0.0)

    scene = Instance(MlabSceneModel)

    def _scene_default(self):
        return MlabSceneModel()

    mlab = Property(depends_on='input_change')
    '''Get the mlab handle'''

    def _get_mlab(self):
        return self.scene.mlab

    fig = Property()
    '''Figure for 3D visualization.
    '''

    @cached_property
    def _get_fig(self):
        fig = self.mlab.gcf()
        bgcolor = tuple(self.bgcolor)
        fgcolor = tuple(self.fgcolor)
        self.mlab.figure(fig, fgcolor=fgcolor, bgcolor=bgcolor)
        return fig

    def show(self, *args, **kw):
        '''Render the visualization.
        '''
        self.mlab.show(*args, **kw)

    def add_viz3d(self, viz3d, order=1):
        '''Add a new visualization object.

        Raises KeyError if an object with the same derived label is
        already registered.
        '''
        viz3d.ftv = self
        vis3d = viz3d.vis3d
        name = viz3d.name
        label = '%s[%s:%s]-%s' % (name, str(
            vis3d.__class__), str(viz3d.__class__), vis3d)
        if label in self.viz3d_dict:
            raise KeyError('viz3d object named %s already registered' % label)
        viz3d.order = order
        self.viz3d_dict[label] = viz3d

    viz3d_dict = tr.Dict(tr.Str, tr.Instance(Viz3D))
    '''Dictionary of visualization objects.
    '''

    viz3d_list = tr.Property

    def _get_viz3d_list(self):
        # Sort numerically by (order, registration index). The former
        # '%5g%5g' string key mis-ordered negative or large order values.
        indexed = list(enumerate(self.viz3d_dict.values()))
        indexed.sort(key=lambda item: (item[1].order, item[0]))
        return [viz3d for _, viz3d in indexed]

    pipeline_ready = Bool(False)

    def setup_pipeline(self):
        '''Set up all viz3d objects once, with rendering disabled.'''
        if self.pipeline_ready:
            return
        self.fig  # evaluated for its side effect of configuring the figure
        fig = self.mlab.gcf()
        fig.scene.disable_render = True
        for viz3d in self.viz3d_list:
            viz3d.setup()
        fig.scene.disable_render = False
        self.pipeline_ready = True

    def update_pipeline(self, vot):
        '''Plot every viz3d at time vot, with rendering disabled meanwhile.'''
        self.setup_pipeline()
        # get the current constrain information
        self.vot = vot
        fig = self.mlab.gcf()
        fig.scene.disable_render = True
        for viz3d in self.viz3d_list:
            viz3d.plot(vot)
        fig.scene.disable_render = False

    selected_viz3d = Instance(Viz3D)

    def _selected_viz3d_changed(self):
        print('selection done')

    # Traits view definition:
    traits_view = View(
        VSplit(
            HSplit(
                Tabbed(
                    UItem(
                        'pp_list',
                        id='notebook',
                        style='custom',
                        resizable=True,
                        editor=ListEditor(
                            use_notebook=True,
                            deletable=False,
                            # selected='selected_pp',
                            export='DockWindowShell',
                            page_name='.name')),
                    UItem('scene',
                          label='3d scene',
                          editor=SceneEditor(scene_class=MayaviScene)),
                    scrollable=True,
                    label='Plot panel'),
                VGroup(
                    Item('n_cols', width=250),
                    Item('plot_mode@', width=250),
                    VSplit(
                        UItem('viz2d_list@',
                              editor=viz2d_list_editor,
                              width=100),
                        UItem('selected_viz2d@', width=200),
                        UItem('pp_list@', editor=pp_list_editor, width=100),
                        #                         UItem('selected_pp@',
                        #                               width=200),
                        UItem('viz3d_list@',
                              editor=viz3d_list_editor,
                              width=100),
                        UItem('selected_viz3d@', width=200),
                        VGroup(
                            #                             UItem('export_button',
                            #                                   springy=False, resizable=True),
                            #                             VGroup(
                            #                                 HGroup(
                            #                                     UItem('fig_width', springy=True,
                            #                                           resizable=False),
                            #                                     UItem('fig_height', springy=True),
                            #                                 ),
                            #                                 label='Figure size'
                            #                             ),
                            UItem('animate_button',
                                  springy=False,
                                  resizable=True), ),
                        VGroup(
                            UItem('reference_viz2d_name', resizable=True),
                            UItem(
                                'reference_figure',
                                editor=MPLFigureEditor(),
                                width=200,
                                # springy=True
                            ),
                            label='Reference graph',
                        )),
                    label='Plot configure',
                    scrollable=True),
            ),
            VGroup(
                HGroup(
                    Item('mode', resizable=False, springy=False),
                    Item('monitor_chunk_size', resizable=False, springy=False),
                ),
                Item('vot_slider', height=40),
            )),
        resizable=True,
        width=0.8,
        height=0.8,
        buttons=['OK', 'Cancel'])
Ejemplo n.º 5
0
class AxesManager(t.HasTraits):

    """Contains and manages the data axes.

    It supports indexing, slicing, subscription and iteration. As an iterator,
    it iterates over the navigation coordinates, returning the current indices.
    It can only be indexed and sliced to access the DataAxis objects that it
    contains. Standard indexing and slicing follow the "natural order" as in
    Signal, i.e. [nX, nY, ..., sX, sY, ...] where `n` indicates a navigation
    axis and `s` a signal axis. In addition, AxesManager supports indexing
    using complex numbers a + bj, where b can be one of 0, 1, 2 and 3 and a is
    a valid index. If b is 3, AxesManager is indexed using the order of the
    axes in the array. If b is 1 (2), only the navigation (signal) axes are
    indexed, in natural order. AxesManager also supports subscription using
    an axis name.

    Attributes
    ----------

    coordinates : tuple
        Get and set the current coordinates if the navigation dimension
        is not 0. If the navigation dimension is 0 it raises
        AttributeError when attempting to set its value.


    indices : tuple
        Get and set the current indices if the navigation dimension
        is not 0. If the navigation dimension is 0 it raises
        AttributeError when attempting to set its value.

    signal_axes, navigation_axes : list
        Contain the corresponding DataAxis objects

    Examples
    --------

    >>> import numpy as np

    Create a spectrum with random data

    >>> s = signals.Spectrum(np.random.random((2,3,4,5)))
    >>> s.axes_manager
    <Axes manager, axes: (<axis2 axis, size: 4, index: 0>, <axis1 axis, size: 3, index: 0>, <axis0 axis, size: 2, index: 0>, <axis3 axis, size: 5>)>
    >>> s.axes_manager[0]
    <axis2 axis, size: 4, index: 0>
    >>> s.axes_manager[3j]
    <axis0 axis, size: 2, index: 0>
    >>> s.axes_manager[1j]
    <axis2 axis, size: 4, index: 0>
    >>> s.axes_manager[2j]
    <axis3 axis, size: 5>
    >>> s.axes_manager[1].name="y"
    >>> s.axes_manager['y']
    <y axis, size: 3, index: 0>
    >>> for i in s.axes_manager:
    ...     print i, s.axes_manager.indices
    (0, 0, 0) (0, 0, 0)
    (1, 0, 0) (1, 0, 0)
    (2, 0, 0) (2, 0, 0)
    (3, 0, 0) (3, 0, 0)
    (0, 1, 0) (0, 1, 0)
    (1, 1, 0) (1, 1, 0)
    (2, 1, 0) (2, 1, 0)
    (3, 1, 0) (3, 1, 0)
    (0, 2, 0) (0, 2, 0)
    (1, 2, 0) (1, 2, 0)
    (2, 2, 0) (2, 2, 0)
    (3, 2, 0) (3, 2, 0)
    (0, 0, 1) (0, 0, 1)
    (1, 0, 1) (1, 0, 1)
    (2, 0, 1) (2, 0, 1)
    (3, 0, 1) (3, 0, 1)
    (0, 1, 1) (0, 1, 1)
    (1, 1, 1) (1, 1, 1)
    (2, 1, 1) (2, 1, 1)
    (3, 1, 1) (3, 1, 1)
    (0, 2, 1) (0, 2, 1)
    (1, 2, 1) (1, 2, 1)
    (2, 2, 1) (2, 2, 1)
    (3, 2, 1) (3, 2, 1)

    """

    # All DataAxis objects in array order (``self[b*1j]`` with b == 3
    # indexes this list directly).
    _axes = t.List(DataAxis)
    # Signal and navigation axes in natural order; both are rebuilt by
    # ``_update_attributes``.
    signal_axes = t.Tuple()
    navigation_axes = t.Tuple()
    # Index increment used by ``key_navigator``; changed with pageup/pagedown.
    _step = t.Int(1)

    def __init__(self, axes_list):
        """Build the manager from a list of axis dictionaries.

        If any axis leaves ``navigate`` undefined, the signal dimension is
        set to 1 (spectrum view) before the derived attributes are computed.
        """
        super(AxesManager, self).__init__()
        self.create_axes(axes_list)
        # An axis with an undefined ``navigate`` means the view is not fully
        # specified, so default to a one-dimensional signal.
        if t.Undefined in [axis.navigate for axis in self._axes]:
            self.set_signal_dimension(1)

        self._update_attributes()
        # Recompute the derived attributes whenever an axis changes.
        for watched in ('_axes.slice', '_axes.index', '_axes.size'):
            self.on_trait_change(self._update_attributes, watched)
        # Iterator position; None means iteration has not started.
        self._index = None

    def _get_positive_index(self, axis):
        """Convert a possibly negative axis index into a non-negative one.

        Raises IndexError when a negative index reaches before the first axis.
        """
        if not axis < 0:
            return axis
        wrapped = len(self._axes) + axis
        if wrapped < 0:
            raise IndexError("index out of bounds")
        return wrapped

    def _array_indices_generator(self):
        """Iterate navigation indices with ``np.ndindex`` over the reversed
        navigation shape ([1] when there is no navigation)."""
        if self.navigation_size > 0:
            dims = self.navigation_shape[::-1]
        else:
            dims = [1, ]
        return np.ndindex(*dims)

    def _am_indices_generator(self):
        """Iterate navigation indices with ``ndindex_nat`` over the reversed
        navigation shape ([1] when there is no navigation)."""
        if self.navigation_size > 0:
            dims = self.navigation_shape
        else:
            dims = [1, ]
        return ndindex_nat(*dims[::-1])

    def __getitem__(self, y):
        """x.__getitem__(y) <==> x[y]

        Parameters
        ----------
        y : str, int, complex or slice
            * str: name of an axis.
            * int, or complex with imaginary part 0: index in natural order.
            * complex with imaginary part 1 (2): index into the navigation
              (signal) axes in natural order.
            * complex with imaginary part 3: index in array order.
            * slice: slice of the axes in natural order.

        Raises
        ------
        ValueError
            If no axis has the given name.
        TypeError
            If the real or imaginary part is a non-integral float.
        IndexError
            If the imaginary part is not 0, 1, 2 or 3, or the index is out
            of range.

        """
        if isinstance(y, basestring):
            axes = list(self._get_axes_in_natural_order())
            while axes:
                axis = axes.pop()
                if y == axis.name:
                    return axis
            raise ValueError("There is no DataAxis named %s" % y)
        elif isinstance(y, slice):
            # Extended slices (and all slices on Python 3) arrive here
            # instead of __getslice__; they have no .real/.imag.
            return self._get_axes_in_natural_order()[y]
        elif (isfloat(y.real) and not y.real.is_integer() or
                isfloat(y.imag) and not y.imag.is_integer()):
            raise TypeError("axesmanager indices must be integers, "
                            "complex integers or strings")
        if y.imag == 0:  # Natural order
            # int() so that complex keys such as 1 + 0j can index a tuple.
            return self._get_axes_in_natural_order()[int(y.real)]
        elif y.imag == 3:  # Array order
            return self._axes[int(y.real)]
        elif y.imag == 1:  # Navigation natural order
            return self.navigation_axes[int(y.real)]
        elif y.imag == 2:  # Signal natural order
            return self.signal_axes[int(y.real)]
        else:
            raise IndexError("axesmanager imaginary part of complex indices "
                             "must be 0, 1, 2 or 3")

    def __getslice__(self, i=None, j=None):
        """x.__getslice__(i, j) <==> x[i:j]

        Slices the axes in natural order (navigation axes, then signal axes).
        """
        ordered = self._get_axes_in_natural_order()
        return ordered[i:j]

    def _get_axes_in_natural_order(self):
        """Return the navigation axes followed by the signal axes."""
        nav, sig = self.navigation_axes, self.signal_axes
        return nav + sig

    @property
    def _navigation_shape_in_array(self):
        """Navigation shape reversed (array order)."""
        natural = self.navigation_shape
        return natural[::-1]

    @property
    def _signal_shape_in_array(self):
        """Signal shape reversed (array order)."""
        natural = self.signal_shape
        return natural[::-1]

    @property
    def shape(self):
        """Navigation shape followed by signal shape (natural order).

        A part equal to ``(0,)`` contributes nothing to the result.
        """
        def _drop_placeholder(dims):
            return dims if dims != (0,) else tuple()

        return (_drop_placeholder(self.navigation_shape) +
                _drop_placeholder(self.signal_shape))

    def remove(self, axis):
        """Remove the given Axis.

        Parameters
        ----------
        axis : DataAxis, int, complex or str
            Either a DataAxis instance or any key accepted by
            ``__getitem__`` (natural-order index, complex index or name).

        Raises
        ------
        ValueError if the Axis is not present.

        """
        # A DataAxis cannot be resolved through __getitem__ (it reads
        # ``.real``), so only look up keys that are not already axes.
        if not isinstance(axis, DataAxis):
            axis = self[axis]
        axis.axes_manager = None
        self._axes.remove(axis)

    def __delitem__(self, i):
        """Delete the axis selected by key ``i`` (see ``__getitem__``)."""
        # Hand the key to remove(), which resolves it. Passing the resolved
        # DataAxis made remove() index the manager with an axis object.
        self.remove(i)

    def _get_data_slice(self, fill=None):
        """Return a tuple of slice objects, one per axis, to slice the data.

        Parameters
        ----------
        fill: None or iterable of (int, slice)
            Pairs ``(index, slice_)``; each ``slice_`` replaces the full
            slice at position ``index``.

        """
        slices = [slice(None)] * len(self._axes)
        if fill is None:
            return tuple(slices)
        for position, replacement in fill:
            slices[position] = replacement
        return tuple(slices)

    def create_axes(self, axes_list):
        """Given a list of dictionaries defining the axes properties
        create the DataAxis instances and add them to the AxesManager.

        The index of the axis in the array and in the `_axes` lists
        can be defined by the index_in_array keyword if given
        for all axes. Otherwise it is defined by their index in the
        list. The caller's list is not modified.

        See also
        --------
        append_axis

        """
        # Reorder using index_in_array only when every axis defines it and
        # the indices are unique. Key membership must be tested with
        # ``get``/``in``: hasattr() is always False for dict keys, so the
        # reordering previously never happened.
        indices = set(axis['index_in_array'] for axis in axes_list
                      if axis.get('index_in_array') is not None)
        if len(indices) == len(axes_list):
            # sorted() instead of list.sort() leaves the caller's list intact.
            axes_list = sorted(axes_list, key=lambda x: x['index_in_array'])
        for axis_dict in axes_list:
            self.append_axis(**axis_dict)

    def _update_max_index(self):
        """Store in ``_max_index`` the product of ``navigation_shape`` minus
        one (or 0 when that product is 0)."""
        count = 1
        for size in self.navigation_shape:
            count *= size
        self._max_index = count - 1 if count != 0 else count

    def next(self):
        """
        Standard iterator method, updates the index and returns the
        current indices.

        The first call sets all indices to 0; each later call advances the
        flat index by one until ``_max_index`` is reached.

        Returns
        -------
        val : tuple of ints
            Returns a tuple containing the indices of the current
            iteration (also assigned to ``self.indices``).

        Raises
        ------
        StopIteration
            When the last navigation position has been returned.

        """
        if self._index is None:
            self._index = 0
            val = (0,) * self.navigation_dimension
            self.indices = val
        elif (self._index >= self._max_index):
            raise StopIteration
        else:
            self._index += 1
            # Unravel in array order, then reverse to natural order.
            val = np.unravel_index(
                self._index,
                tuple(self._navigation_shape_in_array)
            )[::-1]
            self.indices = val
        return val

    # Python 3's iterator protocol calls __next__, not next.
    __next__ = next

    def __iter__(self):
        """Restart navigation iteration and return self as the iterator."""
        # A previous loop abandoned before StopIteration may have left a
        # stale position behind; forget it.
        self._index = None
        return self

    def append_axis(self, *args, **kwargs):
        """Create a DataAxis from the arguments, attach it and append it."""
        new_axis = DataAxis(*args, **kwargs)
        new_axis.axes_manager = self
        self._axes.append(new_axis)

    def _update_attributes(self):
        """Recompute the attributes derived from ``self._axes``.

        Axes whose ``slice`` is None become navigation axes, the others
        signal axes. Both tuples are stored reversed relative to
        ``_axes`` (i.e. in natural order). Dimensions, shapes, sizes,
        ``_getitem_tuple`` and ``_max_index`` are refreshed as well.
        """
        getitem_tuple = ()
        self.signal_axes = ()
        self.navigation_axes = ()
        for axis in self._axes:
            # Until we find a better place, take property of the axes
            # here to avoid difficult to debug bugs.
            axis.axes_manager = self
            if axis.slice is None:
                getitem_tuple += axis.index,
                self.navigation_axes += axis,
            else:
                getitem_tuple += axis.slice,
                self.signal_axes += axis,

        self.signal_axes = self.signal_axes[::-1]
        self.navigation_axes = self.navigation_axes[::-1]
        self._getitem_tuple = getitem_tuple
        self.signal_dimension = len(self.signal_axes)
        self.navigation_dimension = len(self.navigation_axes)
        if self.navigation_dimension != 0:
            self.navigation_shape = tuple([
                axis.size for axis in self.navigation_axes])
        else:
            self.navigation_shape = ()

        if self.signal_dimension != 0:
            self.signal_shape = tuple([
                axis.size for axis in self.signal_axes])
        else:
            self.signal_shape = ()
        self.navigation_size = (np.cumprod(self.navigation_shape)[-1]
                                if self.navigation_shape else 0)
        self.signal_size = (np.cumprod(self.signal_shape)[-1]
                            if self.signal_shape else 0)
        self._update_max_index()

    def set_signal_dimension(self, value):
        """Set the dimension of the signal.

        Sets ``navigate=False`` on the last ``value`` entries of ``_axes``
        and ``navigate=True`` on the remaining ones. Does nothing when
        there are no axes.

        Parameters
        ----------
        value : int

        Raises
        ------
        ValueError if value is greater than the number of axes or
        is negative

        """
        if len(self._axes) == 0:
            return
        elif value > len(self._axes):
            raise ValueError(
                "The signal dimension cannot be greater"
                " than the number of axes which is %i" % len(self._axes))
        elif value < 0:
            # 0 is accepted, so the requirement is "non-negative".
            raise ValueError(
                "The signal dimension must be a non-negative integer")

        tl = [True] * len(self._axes)
        if value != 0:
            tl[-value:] = (False,) * value

        for axis, navigate in zip(self._axes, tl):
            axis.navigate = navigate

    def connect(self, f):
        """Register ``f`` on the ``index`` trait of every axis whose slice
        is None (navigation axes)."""
        navigation = [axis for axis in self._axes if axis.slice is None]
        for nav_axis in navigation:
            nav_axis.on_trait_change(f, 'index')

    def disconnect(self, f):
        """Unregister ``f`` from the ``index`` trait of every axis whose
        slice is None (navigation axes)."""
        navigation = [axis for axis in self._axes if axis.slice is None]
        for nav_axis in navigation:
            nav_axis.on_trait_change(f, 'index', remove=True)

    def key_navigator(self, event):
        """Move the navigation indices according to ``event.key``.

        Only acts with one or two navigation axes. right/6 and left/4 move
        the first axis; up/8 and down/2 move the second one (if present);
        pageup/pagedown change the step size (never below 1). A TraitError
        raised while moving (e.g. out of range) is ignored.
        """
        n_nav = len(self.navigation_axes)
        if n_nav not in (1, 2):
            return
        x = self.navigation_axes[0]
        key = event.key
        try:
            if key in ("right", "6"):
                x.index += self._step
            elif key in ("left", "4"):
                x.index -= self._step
            elif key == "pageup":
                self._step += 1
            elif key == "pagedown" and self._step > 1:
                self._step -= 1
            if n_nav == 2:
                y = self.navigation_axes[1]
                if key in ("up", "8"):
                    y.index -= self._step
                elif key in ("down", "2"):
                    y.index += self._step
        except TraitError:
            pass

    def gui(self):
        """Open a traits editor (``data_axis_view``) for each axis."""
        from hyperspy.gui.axes import data_axis_view
        for each_axis in self._axes:
            each_axis.edit_traits(view=data_axis_view)

    def copy(self):
        """Return a shallow copy made with ``copy.copy``."""
        duplicate = copy.copy(self)
        return duplicate

    def deepcopy(self):
        """Return a deep copy made with ``copy.deepcopy``."""
        duplicate = copy.deepcopy(self)
        return duplicate

    def __deepcopy__(self, *args):
        """Deep copy by rebuilding a new manager from the axis dictionaries."""
        axes_dicts = self._get_axes_dicts()
        return AxesManager(axes_dicts)

    def _get_axes_dicts(self):
        """Return ``get_axis_dictionary()`` of every axis in ``_axes`` order."""
        return [axis.get_axis_dictionary() for axis in self._axes]

    def as_dictionary(self):
        """Map ``'axis-<i>'`` to the dictionary of the i-th axis in ``_axes``."""
        return dict(('axis-%i' % position, axis.get_axis_dictionary())
                    for position, axis in enumerate(self._axes))

    def _get_signal_axes_dicts(self):
        """Return the signal axes' dictionaries, last signal axis first."""
        return [ax.get_axis_dictionary() for ax in reversed(self.signal_axes)]

    def _get_navigation_axes_dicts(self):
        """Return the navigation axes' dictionaries, last navigation axis first."""
        return [ax.get_axis_dictionary()
                for ax in reversed(self.navigation_axes)]

    def show(self):
        """Open one traits window with a group per axis, in natural order.

        Each axis is exposed to the view under the context key
        ``'axis<n>'``, where ``n`` is its position in natural order.
        """
        from hyperspy.gui.axes import get_axis_group
        import traitsui.api as tui
        natural = self._get_axes_in_natural_order()
        groups = tuple(get_axis_group(pos, str(ax))
                       for pos, ax in enumerate(natural))
        context = {'axis%i' % pos: ax for pos, ax in enumerate(natural)}
        self.edit_traits(view=tui.View(*groups), context=context)

    def _get_axes_str(self):
        """Return ``"(<nav reprs>|<signal reprs>)"`` with ", "-separated reprs.

        Built with ``str.join``. The previous ``+=``/``rstrip(", ")``
        construction was quadratic. It also stripped trailing commas and
        spaces that belonged to an axis repr itself.
        """
        nav = ", ".join(repr(axis) for axis in self.navigation_axes)
        sig = ", ".join(repr(axis) for axis in self.signal_axes)
        return "(%s|%s)" % (nav, sig)

    def _get_dimension_str(self):
        """Return ``"(<nav sizes>|<signal sizes>)"``, e.g. ``"(4, 3|5)"``.

        Sizes are ", "-separated and joined with ``str.join`` instead of
        repeated string concatenation followed by ``rstrip``.
        """
        nav = ", ".join(str(axis.size) for axis in self.navigation_axes)
        sig = ", ".join(str(axis.size) for axis in self.signal_axes)
        return "(%s|%s)" % (nav, sig)

    def __repr__(self):
        """Return ``'<Axes manager, axes: ...>'`` built from ``_get_axes_str``."""
        axes_str = self._get_axes_str()
        return '<Axes manager, axes: %s>' % axes_str

    @property
    def coordinates(self):
        """Get the coordinates of the navigation axes.

        Returns
        -------
        tuple
            The ``value`` of each navigation axis, in navigation order.

        """
        return tuple(ax.value for ax in self.navigation_axes)

    @coordinates.setter
    def coordinates(self, coordinates):
        """Set the coordinates of the navigation axes.

        Parameters
        ----------
        coordinates : tuple
            One value per navigation axis; its length must equal the
            navigation dimension.

        Raises
        ------
        AttributeError
            If the number of values differs from the navigation dimension.

        """
        expected = self.navigation_dimension
        if len(coordinates) != expected:
            raise AttributeError(
                "The number of coordinates must be equal to the "
                "navigation dimension that is %i" % expected)
        for ax, val in zip(self.navigation_axes, coordinates):
            ax.value = val

    @property
    def indices(self):
        """Get the index of the navigation axes.

        Returns
        -------
        tuple
            The ``index`` of each navigation axis, in navigation order.

        """
        return tuple(ax.index for ax in self.navigation_axes)

    @indices.setter
    def indices(self, indices):
        """Set the index of the navigation axes.

        Parameters
        ----------
        indices : tuple
            One index per navigation axis; its length must equal the
            navigation dimension.

        Raises
        ------
        AttributeError
            If the number of indices differs from the navigation dimension.

        """
        expected = self.navigation_dimension
        if len(indices) != expected:
            raise AttributeError(
                "The number of indices must be equal to the "
                "navigation dimension that is %i" % expected)
        for ax, idx in zip(self.navigation_axes, indices):
            ax.index = idx

    def _get_axis_attribute_values(self, attr):
        """Return ``getattr(axis, attr)`` for every axis, in ``_axes`` order."""
        return [getattr(ax, attr) for ax in self._axes]

    def _set_axis_attribute_values(self, attr, values):
        """Set the given attribute of all the axes to the given value(s).

        Parameters
        ----------
        attr : string
            The DataAxis attribute to set.
        values : any
            If iterable (as judged by ``isiterable``), it must have the same
            number of items as there are axes in this AxesManager. The
            items are assigned to the axes in ``_axes`` order. If not
            iterable, the attribute of every axis is set to this value.

        Raises
        ------
        ValueError
            If ``values`` is iterable and its length differs from the number
            of axes.

        """
        if not isiterable(values):
            values = [values] * len(self._axes)
        elif len(values) != len(self._axes):
            # The previous message lacked a space ("numberof") and read
            # "items are axes are".
            raise ValueError("Values must have the same number "
                             "of items as there are axes in this "
                             "AxesManager")
        for axis, value in zip(self._axes, values):
            setattr(axis, attr, value)
Ejemplo n.º 6
0
class AxesManager(t.HasTraits):
    """Contains and manages the data axes.

    It supports indexing, slicing, subscription and iteration. As an
    iterator, it iterates over the navigation coordinates, returning the
    current indices. It can only be indexed and sliced to access the
    DataAxis objects that it contains. Standard indexing and slicing follow
    the "natural order" as in Signal, i.e. [nX, nY, ...,sX, sY,...] where
    `n` indicates a navigation axis and `s` a signal axis. In addition,
    AxesManager supports indexing using complex numbers a + bj, where b can
    be one of 0, 1, 2 and 3 and a is a valid index. If b is 3, AxesManager
    is indexed using the order of the axes in the array. If b is 1(2), it
    indexes only the navigation(signal) axes in the natural order. Finally,
    AxesManager supports subscription using the axis name.

    Attributes
    ----------

    coordinates : tuple
        Get and set the current coordinates of the navigation axes.
        Setting raises AttributeError if the number of values differs from
        the navigation dimension.

    indices : tuple
        Get and set the current indices of the navigation axes. Setting
        raises AttributeError if the number of values differs from the
        navigation dimension.

    signal_axes, navigation_axes : tuple
        Contain the corresponding DataAxis objects

    Examples
    --------

    >>> %hyperspy
    HyperSpy imported!
    The following commands were just executed:
    ---------------
    import numpy as np
    import hyperspy.api as hs
    %matplotlib qt
    import matplotlib.pyplot as plt

    >>> # Create a spectrum with random data

    >>> s = hs.signals.Signal1D(np.random.random((2,3,4,5)))
    >>> s.axes_manager
    <Axes manager, axes: (<axis2 axis, size: 4, index: 0>, <axis1 axis, size: 3, index: 0>, <axis0 axis, size: 2, index: 0>, <axis3 axis, size: 5>)>
    >>> s.axes_manager[0]
    <axis2 axis, size: 4, index: 0>
    >>> s.axes_manager[3j]
    <axis0 axis, size: 2, index: 0>
    >>> s.axes_manager[1j]
    <axis2 axis, size: 4, index: 0>
    >>> s.axes_manager[2j]
    <axis3 axis, size: 5>
    >>> s.axes_manager[1].name="y"
    >>> s.axes_manager['y']
    <y axis, size: 3, index: 0>
    >>> for i in s.axes_manager:
    ...     print(i, s.axes_manager.indices)
    (0, 0, 0) (0, 0, 0)
    (1, 0, 0) (1, 0, 0)
    (2, 0, 0) (2, 0, 0)
    (3, 0, 0) (3, 0, 0)
    (0, 1, 0) (0, 1, 0)
    (1, 1, 0) (1, 1, 0)
    (2, 1, 0) (2, 1, 0)
    (3, 1, 0) (3, 1, 0)
    (0, 2, 0) (0, 2, 0)
    (1, 2, 0) (1, 2, 0)
    (2, 2, 0) (2, 2, 0)
    (3, 2, 0) (3, 2, 0)
    (0, 0, 1) (0, 0, 1)
    (1, 0, 1) (1, 0, 1)
    (2, 0, 1) (2, 0, 1)
    (3, 0, 1) (3, 0, 1)
    (0, 1, 1) (0, 1, 1)
    (1, 1, 1) (1, 1, 1)
    (2, 1, 1) (2, 1, 1)
    (3, 1, 1) (3, 1, 1)
    (0, 2, 1) (0, 2, 1)
    (1, 2, 1) (1, 2, 1)
    (2, 2, 1) (2, 2, 1)
    (3, 2, 1) (3, 2, 1)

    """

    _axes = t.List(DataAxis)
    signal_axes = t.Tuple()
    navigation_axes = t.Tuple()
    _step = t.Int(1)

    def __init__(self, axes_list):
        """Create the axes described by `axes_list` and set up the events.

        Parameters
        ----------
        axes_list : list of dict
            One dictionary of DataAxis keyword arguments per axis.

        """
        super(AxesManager, self).__init__()
        self.events = Events()
        self.events.indices_changed = Event("""
            Event that triggers when the indices of the `AxesManager` changes

            Triggers after the internal state of the `AxesManager` has been
            updated.

            Arguments:
            ----------
            obj : The AxesManager that the event belongs to.
            """,
                                            arguments=['obj'])
        self.events.any_axis_changed = Event("""
            Event that trigger when the space defined by the axes transforms.

            Specifically, it triggers when one or more of the folloing
            attributes changes on one or more of the axes:
                `offset`, `size`, `scale`

            Arguments:
            ----------
            obj : The AxesManager that the event belongs to.
            """,
                                             arguments=['obj'])
        self.create_axes(axes_list)
        # If the `navigate` flag of any axis is still undefined, fall back
        # to a default view in which every axis is a signal axis.
        navigates = [i.navigate for i in self._axes]
        if t.Undefined in navigates:
            self.set_signal_dimension(len(axes_list))

        self._update_attributes()
        self._update_trait_handlers()
        self._index = None  # index for the iterator

    def _update_trait_handlers(self, remove=False):
        """Attach (or, with `remove=True`, detach) the per-axis handlers."""
        things = {
            self._on_index_changed: '_axes.index',
            self._on_slice_changed: '_axes.slice',
            self._on_size_changed: '_axes.size',
            self._on_scale_changed: '_axes.scale',
            self._on_offset_changed: '_axes.offset'
        }

        for k, v in things.items():
            self.on_trait_change(k, name=v, remove=remove)

    def _get_positive_index(self, axis):
        """Convert a negative axis position into its positive equivalent.

        Raises
        ------
        IndexError
            If the position is still negative after wrapping.

        """
        if axis < 0:
            axis += len(self._axes)
            if axis < 0:
                raise IndexError("index out of bounds")
        return axis

    def _array_indices_generator(self):
        """Iterate over navigation indices in array (reversed) order.

        Yields a single ``(0,)`` when there is no navigation space.
        """
        shape = (self.navigation_shape[::-1]
                 if self.navigation_size > 0 else [1])
        return np.ndindex(*shape)

    def _am_indices_generator(self):
        """Iterate over navigation indices in natural order via `ndindex_nat`.

        Uses shape ``[1]`` when there is no navigation space.
        """
        shape = (self.navigation_shape
                 if self.navigation_size > 0 else [1])[::-1]
        return ndindex_nat(*shape)

    def __getitem__(self, y):
        """x.__getitem__(y) <==> x[y]

        `y` may be a slice (natural order), a single key (see
        `_axes_getter`) or an iterable of keys. In the last case a tuple
        of the distinct axes is returned in first-seen order.
        """
        if isinstance(y, slice):
            # Python 3 routes x[i:j] here (``__getslice__`` is Python 2 only).
            return self._get_axes_in_natural_order()[y]
        if isinstance(y, str) or not np.iterable(y):
            return self[(y, )][0]
        axes = [self._axes_getter(ax) for ax in y]
        _, indices = np.unique([_id for _id in map(id, axes)],
                               return_index=True)
        ans = tuple(axes[i] for i in sorted(indices))
        return ans

    def _axes_getter(self, y):
        """Resolve a single key `y` to one DataAxis.

        `y` may be a DataAxis held by this manager, an axis name, an
        integer (natural order) or a complex integer whose imaginary part
        selects the ordering: 0 natural, 1 navigation, 2 signal, 3 array.

        Raises
        ------
        ValueError
            If no axis has the given name.
        TypeError
            If the real or imaginary part is a non-integral float.
        IndexError
            If the imaginary part is not 0, 1, 2 or 3.

        """
        if y in self._axes:
            return y
        if isinstance(y, str):
            axes = list(self._get_axes_in_natural_order())
            # Searched from the end: with duplicated names, the last axis
            # in natural order wins.
            while axes:
                axis = axes.pop()
                if y == axis.name:
                    return axis
            raise ValueError("There is no DataAxis named %s" % y)
        elif (isfloat(y.real) and not y.real.is_integer() or
              isfloat(y.imag) and not y.imag.is_integer()):
            raise TypeError("axesmanager indices must be integers, "
                            "complex intergers or strings")
        if y.imag == 0:  # Natural order
            # Index with int(y.real): tuples reject complex/float indices,
            # which made keys such as ``1 + 0j`` fail.
            return self._get_axes_in_natural_order()[int(y.real)]
        elif y.imag == 3:  # Array order
            return self._axes[int(y.real)]
        elif y.imag == 1:  # Navigation natural order
            return self.navigation_axes[int(y.real)]
        elif y.imag == 2:  # Signal natural order
            return self.signal_axes[int(y.real)]
        else:
            raise IndexError("axesmanager imaginary part of complex indices "
                             "must be 0, 1, 2 or 3")

    def __getslice__(self, i=None, j=None):
        """x.__getslice__(i, j) <==> x[i:j]

        Kept for Python 2; Python 3 slicing goes through `__getitem__`.
        """
        return self._get_axes_in_natural_order()[i:j]

    def _get_axes_in_natural_order(self):
        """Return navigation axes followed by signal axes."""
        return self.navigation_axes + self.signal_axes

    @property
    def _navigation_shape_in_array(self):
        return self.navigation_shape[::-1]

    @property
    def _signal_shape_in_array(self):
        return self.signal_shape[::-1]

    @property
    def shape(self):
        """Navigation shape followed by signal shape, omitting ``(0,)`` parts."""
        nav_shape = (self.navigation_shape
                     if self.navigation_shape != (0, ) else tuple())
        sig_shape = (self.signal_shape
                     if self.signal_shape != (0, ) else tuple())
        return nav_shape + sig_shape

    def remove(self, axes):
        """Remove one or more axes, given as any key accepted by `__getitem__`.
        """
        axes = self[axes]
        if not np.iterable(axes):
            axes = (axes, )
        for ax in axes:
            self._remove_one_axis(ax)

    def _remove_one_axis(self, axis):
        """Remove the given Axis.

        Raises
        ------
        ValueError if the Axis is not present.

        """
        axis = self._axes_getter(axis)
        axis.axes_manager = None
        self._axes.remove(axis)

    def __delitem__(self, i):
        self.remove(self[i])

    def _get_data_slice(self, fill=None):
        """Return a tuple of slice objects to slice the data.

        Parameters
        ----------
        fill: None or iterable of (int, slice)
            If not None, fill the tuple of index int with the given
            slice.

        """
        cslice = [slice(None)] * len(self._axes)
        if fill is not None:
            for index, slice_ in fill:
                cslice[index] = slice_
        return tuple(cslice)

    def create_axes(self, axes_list):
        """Given a list of dictionaries defining the axes properties
        create the DataAxis instances and add them to the AxesManager.

        The index of the axis in the array and in the `_axes` lists
        can be defined by the index_in_array keyword if given
        for all axes. Otherwise it is defined by their index in the
        list. The caller's list is not modified.

        See also
        --------
        _append_axis

        """
        # Reorder using index_in_array only if it is defined (not None) for
        # every axis and no index is repeated. The axes are dictionaries, so
        # presence must be tested on the keys; `hasattr` never matched.
        indices = {
            axis['index_in_array'] for axis in axes_list
            if axis.get('index_in_array') is not None
        }
        if len(indices) == len(axes_list):
            axes_list = sorted(axes_list, key=lambda x: x['index_in_array'])
        for axis_dict in axes_list:
            self._append_axis(**axis_dict)

    def _update_max_index(self):
        """Store the last flat navigation position reachable by `__next__`."""
        self._max_index = 1
        for i in self.navigation_shape:
            self._max_index *= i
        if self._max_index != 0:
            self._max_index -= 1

    def __next__(self):
        """
        Standard iterator method, updates the index and returns the
        current coordinates.

        The first call sets every navigation index to 0. Each later call
        moves to the next position, with the first navigation axis varying
        fastest. StopIteration is raised once `_max_index` is reached.

        Returns
        -------
        val : tuple of ints
            Returns a tuple containing the coordinates of the current
            iteration.

        """
        if self._index is None:
            self._index = 0
            val = (0, ) * self.navigation_dimension
            self.indices = val
        elif self._index >= self._max_index:
            raise StopIteration
        else:
            self._index += 1
            val = np.unravel_index(
                self._index,
                tuple(self._navigation_shape_in_array))[::-1]
            self.indices = val
        return val

    def __iter__(self):
        # Reset the _index that can have a value != None due to
        # a previous iteration that did not hit a StopIteration
        self._index = None
        return self

    def _append_axis(self, *args, **kwargs):
        """Create a DataAxis from the arguments and append it to `_axes`."""
        axis = DataAxis(*args, **kwargs)
        axis.axes_manager = self
        self._axes.append(axis)

    def _on_index_changed(self):
        self._update_attributes()
        self.events.indices_changed.trigger(obj=self)

    def _on_slice_changed(self):
        self._update_attributes()

    def _on_size_changed(self):
        self._update_attributes()
        self.events.any_axis_changed.trigger(obj=self)

    def _on_scale_changed(self):
        self.events.any_axis_changed.trigger(obj=self)

    def _on_offset_changed(self):
        self.events.any_axis_changed.trigger(obj=self)

    def update_axes_attributes_from(self,
                                    axes,
                                    attributes=("scale", "offset", "units")):
        """Update the axes attributes to match those given.

        The axes are matched by their index in the array. The purpose of this
        method is to update multiple axes triggering `any_axis_changed` only
        once.

        Parameters
        ----------
        axes: iterable of `DataAxis` instances.
            The axes to copy the attributes from.
        attributes: iterable of strings.
            The attributes to copy. The default is an immutable tuple to
            avoid a shared mutable default argument.

        """

        # To only trigger once even with several changes, we suppress here
        # and trigger manually below if there were any changes.
        changes = False
        with self.events.any_axis_changed.suppress():
            for axis in axes:
                changed = self._axes[axis.index_in_array].update_from(
                    axis=axis, attributes=attributes)
                changes = changes or changed
        if changes:
            self.events.any_axis_changed.trigger(obj=self)

    def _update_attributes(self):
        """Recompute the axis groupings, getitem tuple, shapes and sizes.

        Axes whose `slice` is None are navigation axes (indexed by their
        `index`). The others are signal axes (indexed by their `slice`).
        """
        getitem_tuple = []
        self.signal_axes = ()
        self.navigation_axes = ()
        for axis in self._axes:
            # Until we find a better place, take property of the axes
            # here to avoid difficult to debug bugs.
            axis.axes_manager = self
            if axis.slice is None:
                getitem_tuple += axis.index,
                self.navigation_axes += axis,
            else:
                getitem_tuple += axis.slice,
                self.signal_axes += axis,
        if not self.signal_axes and self.navigation_axes:
            # With no signal axis, index the last axis with a length-1
            # slice (presumably so the indexed data keeps one dimension).
            last_axis = self._axes[-1]
            getitem_tuple[-1] = slice(last_axis.index, last_axis.index + 1)

        self.signal_axes = self.signal_axes[::-1]
        self.navigation_axes = self.navigation_axes[::-1]
        self._getitem_tuple = tuple(getitem_tuple)
        self.signal_dimension = len(self.signal_axes)
        self.navigation_dimension = len(self.navigation_axes)
        self.navigation_shape = tuple(
            axis.size for axis in self.navigation_axes)
        self.signal_shape = tuple(axis.size for axis in self.signal_axes)
        self.navigation_size = (np.cumprod(self.navigation_shape)[-1]
                                if self.navigation_shape else 0)
        self.signal_size = (np.cumprod(self.signal_shape)[-1]
                            if self.signal_shape else 0)
        self._update_max_index()

    def set_signal_dimension(self, value):
        """Set the dimension of the signal.

        The last `value` axes of `_axes` become signal axes and the rest
        navigation axes. Does nothing if there are no axes.

        Parameters
        ----------
        value : int

        Raises
        ------
        ValueError if value is greater than the number of axes or
        is negative

        """
        if len(self._axes) == 0:
            return
        elif value > len(self._axes):
            raise ValueError("The signal dimension cannot be greater"
                             " than the number of axes which is %i" %
                             len(self._axes))
        elif value < 0:
            raise ValueError("The signal dimension must be a positive integer")

        tl = [True] * len(self._axes)
        if value != 0:
            tl[-value:] = (False, ) * value

        for axis in self._axes:
            axis.navigate = tl.pop(0)

    def key_navigator(self, event):
        """Step the navigation indices according to `event.key`.

        Acts only with one or two navigation axes. "right"/"6" and
        "left"/"4" move the first axis. "pageup"/"pagedown" change the step
        (minimum 1). With two axes, "up"/"8" and "down"/"2" move the
        second. Out-of-range moves raising TraitError are ignored.
        """
        if len(self.navigation_axes) not in (1, 2):
            return
        x = self.navigation_axes[0]
        try:
            if event.key == "right" or event.key == "6":
                x.index += self._step
            elif event.key == "left" or event.key == "4":
                x.index -= self._step
            elif event.key == "pageup":
                self._step += 1
            elif event.key == "pagedown":
                if self._step > 1:
                    self._step -= 1
            if len(self.navigation_axes) == 2:
                y = self.navigation_axes[1]
                if event.key == "up" or event.key == "8":
                    y.index -= self._step
                elif event.key == "down" or event.key == "2":
                    y.index += self._step
        except TraitError:
            pass

    def gui(self):
        """Open one traits editor per axis using `data_axis_view`."""
        from hyperspy.gui.axes import data_axis_view
        for axis in self._axes:
            axis.edit_traits(view=data_axis_view)

    def copy(self):
        return copy.copy(self)

    def deepcopy(self):
        return copy.deepcopy(self)

    def __deepcopy__(self, *args):
        # Rebuild from the axis dictionaries; `memo` is ignored.
        return AxesManager(self._get_axes_dicts())

    def _get_axes_dicts(self):
        """Return `get_axis_dictionary()` of every axis, in `_axes` order."""
        return [axis.get_axis_dictionary() for axis in self._axes]

    def as_dictionary(self):
        """Map 'axis-<i>' to the dictionary of the i-th axis in `_axes`."""
        return {'axis-%i' % i: axis.get_axis_dictionary()
                for i, axis in enumerate(self._axes)}

    def _get_signal_axes_dicts(self):
        return [axis.get_axis_dictionary() for axis in self.signal_axes[::-1]]

    def _get_navigation_axes_dicts(self):
        return [
            axis.get_axis_dictionary() for axis in self.navigation_axes[::-1]
        ]

    def show(self):
        """Open a traits window with one group per axis in natural order."""
        from hyperspy.gui.axes import get_axis_group
        import traitsui.api as tui
        context = {}
        ag = []
        for n, axis in enumerate(self._get_axes_in_natural_order()):
            ag.append(get_axis_group(n, str(axis)))
            context['axis%i' % n] = axis
        ag = tuple(ag)
        self.edit_traits(view=tui.View(*ag), context=context)

    def _get_dimension_str(self):
        """Return "(<nav sizes>|<signal sizes>)", e.g. "(4, 3|5)"."""
        nav = ", ".join(str(axis.size) for axis in self.navigation_axes)
        sig = ", ".join(str(axis.size) for axis in self.signal_axes)
        return "(%s|%s)" % (nav, sig)

    def __repr__(self):
        text = ('<Axes manager, axes: %s>\n' % self._get_dimension_str())
        ax_signature = "% 16s | %6g | %6s | %7.2g | %7.2g | %6s "
        signature = "% 16s | %6s | %6s | %7s | %7s | %6s "
        text += signature % ('Name', 'size', 'index', 'offset', 'scale',
                             'units')
        text += '\n'
        text += signature % ('=' * 16, '=' * 6, '=' * 6, '=' * 7, '=' * 7,
                             '=' * 6)
        for ax in self.navigation_axes:
            text += '\n'
            text += ax_signature % (str(ax.name)[:16], ax.size, str(
                ax.index), ax.offset, ax.scale, ax.units)
        text += '\n'
        text += signature % ('-' * 16, '-' * 6, '-' * 6, '-' * 7, '-' * 7,
                             '-' * 6)
        for ax in self.signal_axes:
            text += '\n'
            text += ax_signature % (str(
                ax.name)[:16], ax.size, ' ', ax.offset, ax.scale, ax.units)

        return text

    def _repr_html_(self):
        """Return an HTML summary with navigation and signal axis tables."""
        text = ("<style>\n"
                "table, th, td {\n\t"
                "border: 1px solid black;\n\t"
                "border-collapse: collapse;\n}"
                "\nth, td {\n\t"
                "padding: 5px;\n}"
                "\n</style>")
        text += ('\n<p><b>< Axes manager, axes: %s ></b></p>\n' %
                 self._get_dimension_str())

        def format_row(*args, tag='td', bold=False):
            # One <tr> whose cells are "\n<tag>value</tag>", separated by
            # single spaces.
            row = "\n<tr class='bolder_row'> " if bold else "\n<tr> "
            cells = ("\n<{0}>{1}</{0}>".format(tag, x) for x in args)
            return row + " ".join(cells) + " </tr>"

        if self.navigation_axes:
            text += "<table style='width:100%'>\n"
            text += format_row('Navigation axis name',
                               'size',
                               'index',
                               'offset',
                               'scale',
                               'units',
                               tag='th')
            for ax in self.navigation_axes:
                text += format_row(ax.name, ax.size, ax.index, ax.offset,
                                   ax.scale, ax.units)
            text += "</table>\n"
        if self.signal_axes:
            text += "<table style='width:100%'>\n"
            text += format_row('Signal axis name',
                               'size',
                               'offset',
                               'scale',
                               'units',
                               tag='th')
            for ax in self.signal_axes:
                text += format_row(ax.name, ax.size, ax.offset, ax.scale,
                                   ax.units)
            text += "</table>\n"
        return text

    @property
    def coordinates(self):
        """Get the coordinates of the navigation axes.

        Returns
        -------
        tuple

        """
        return tuple([axis.value for axis in self.navigation_axes])

    @coordinates.setter
    def coordinates(self, coordinates):
        """Set the coordinates of the navigation axes.

        Parameters
        ----------
        coordinates : tuple
            The length of the tuple must equal the navigation dimension.

        Raises
        ------
        AttributeError
            If the length does not match the navigation dimension.

        """

        if len(coordinates) != self.navigation_dimension:
            raise AttributeError(
                "The number of coordinates must be equal to the "
                "navigation dimension that is %i" % self.navigation_dimension)
        for value, axis in zip(coordinates, self.navigation_axes):
            axis.value = value

    @property
    def indices(self):
        """Get the index of the navigation axes.

        Returns
        -------
        tuple

        """
        return tuple([axis.index for axis in self.navigation_axes])

    @indices.setter
    def indices(self, indices):
        """Set the index of the navigation axes.

        Parameters
        ----------
        indices : tuple
            The length of the tuple must equal the navigation dimension.

        Raises
        ------
        AttributeError
            If the length does not match the navigation dimension.

        """

        if len(indices) != self.navigation_dimension:
            raise AttributeError("The number of indices must be equal to the "
                                 "navigation dimension that is %i" %
                                 self.navigation_dimension)
        for index, axis in zip(indices, self.navigation_axes):
            axis.index = index

    def _get_axis_attribute_values(self, attr):
        """Return `getattr(axis, attr)` for every axis, in `_axes` order."""
        return [getattr(axis, attr) for axis in self._axes]

    def _set_axis_attribute_values(self, attr, values):
        """Set the given attribute of all the axes to the given
        value(s)

        Parameters
        ----------
        attr : string
            The DataAxis attribute to set.
        values: any
            If iterable, it must have the same number of items
            as there are axes in this AxesManager instance. If not
            iterable, the attribute of all the axes is set to the given
            value.

        Raises
        ------
        ValueError
            If `values` is iterable with the wrong number of items.

        """
        if not isiterable(values):
            values = [values] * len(self._axes)
        elif len(values) != len(self._axes):
            raise ValueError("Values must have the same number "
                             "of items as there are axes in this "
                             "AxesManager")
        for axis, value in zip(self._axes, values):
            setattr(axis, attr, value)

    @property
    def navigation_indices_in_array(self):
        return tuple([axis.index_in_array for axis in self.navigation_axes])

    @property
    def signal_indices_in_array(self):
        return tuple([axis.index_in_array for axis in self.signal_axes])

    @property
    def axes_are_aligned_with_data(self):
        """Verify if the data axes are aligned with the signal axes.

        When the data are aligned with the axes the axes order in `self._axes`
        is [nav_n, nav_n-1, ..., nav_0, sig_m, sig_m-1 ..., sig_0].

        Returns
        -------
        aligned : bool

        """
        nav_iia_r = self.navigation_indices_in_array[::-1]
        sig_iia_r = self.signal_indices_in_array[::-1]
        iia_r = nav_iia_r + sig_iia_r
        aligned = iia_r == tuple(range(len(iia_r)))
        return aligned

    def _sort_axes(self):
        """Sort _axes to align them.

        When the data are aligned with the axes the axes order in `self._axes`
        is [nav_n, nav_n-1, ..., nav_0, sig_m, sig_m-1 ..., sig_0]. This method
        sorts the axes in this way. Warning: this doesn't sort the `data` axes.

        """
        new_axes = self.navigation_axes[::-1] + self.signal_axes[::-1]
        self._axes = list(new_axes)
Ejemplo n.º 7
0
class DOTSGrid(BMCSLeafNode):
    '''Domain time stepper on a regular 2D grid mesh.

    The kinematic operators of the grid are precomputed once and cached
    until one of the ``input`` traits changes. ``get_corr_pred`` then
    evaluates the element stiffness matrices and the internal force vector
    for a given state of the unknowns.

    NOTE(review): ``self.mats`` (the material model) is used below but is
    not declared in this class -- presumably provided by the base class or
    assigned by the caller; confirm.
    '''
    # Origin of the grid, its extent and the number of elements per direction.
    x_0 = tr.Tuple(0., 0., input=True)
    L_x = tr.Float(200, input=True, MESH=True)
    L_y = tr.Float(100, input=True, MESH=True)
    n_x = tr.Int(100, input=True, MESH=True)
    n_y = tr.Int(30, input=True, MESH=True)
    # Scalar factor multiplying both the stiffness and the internal forces.
    integ_factor = tr.Float(1.0, input=True, MESH=True)
    # Element evaluator providing the shape functions and their derivatives
    # (N_im, dN_imr, dN_inr), the integration weights w_m, n_m and
    # n_nodal_dofs.
    fets = tr.Instance(IFETSEval, input=True, MESH=True)

    # Symmetric operator distributing the derivatives of the shape
    # functions into the tensor field. ``np.float64`` is used instead of
    # the ``np.float_`` alias, which was removed in NumPy 2.0.
    D1_abcd = tr.Array(np.float64, input=True)

    def _D1_abcd_default(self):
        delta = np.identity(2)
        # symmetrization operator: 0.5 * (d_ac d_bd + d_ad d_bc)
        return 0.5 * (np.einsum('ac,bd->abcd', delta, delta) +
                      np.einsum('ad,bc->abcd', delta, delta))

    mesh = tr.Property(tr.Instance(FEGrid), depends_on='+input')

    @tr.cached_property
    def _get_mesh(self):
        return FEGrid(coord_min=self.x_0,
                      coord_max=(self.x_0[0] + self.L_x,
                                 self.x_0[1] + self.L_y),
                      shape=(self.n_x, self.n_y),
                      fets_eval=self.fets)

    cached_grid_values = tr.Property(tr.Tuple, depends_on='+input')

    @tr.cached_property
    def _get_cached_grid_values(self):
        '''Precompute all grid operators in a single pass.

        Index letters: E element, I global node, i/j element node,
        a..f spatial/tensor directions, m integration point,
        n visualization point, r reference-element coordinate.

        Returns
        -------
        tuple
            (BB_Emicjdabef, B_Eimabc, dof_Eia, x_Eia, dof_Ia, I_Ei,
            B_Einabc, det_J_Em) -- accessed by position in the properties
            below.
        '''
        x_Ia = self.mesh.X_Id
        n_I, n_a = x_Ia.shape
        # consecutive DOF numbering: dof_Ia[I, a] == I * n_a + a
        dof_Ia = np.arange(n_I * n_a, dtype=np.int_).reshape(n_I, -1)
        I_Ei = self.mesh.I_Ei
        x_Eia = x_Ia[I_Ei, :]
        dof_Eia = dof_Ia[I_Ei]
        # Jacobians at the integration (m) and visualization (n) points
        J_Emar = np.einsum('imr,Eia->Emar', self.fets.dN_imr, x_Eia)
        J_Enar = np.einsum('inr,Eia->Enar', self.fets.dN_inr, x_Eia)
        det_J_Em = np.linalg.det(J_Emar)
        inv_J_Emar = np.linalg.inv(J_Emar)
        inv_J_Enar = np.linalg.inv(J_Enar)
        B_Eimabc = np.einsum('abcd,imr,Eidr->Eimabc', self.D1_abcd,
                             self.fets.dN_imr, inv_J_Emar)
        B_Einabc = np.einsum('abcd,inr,Eidr->Einabc', self.D1_abcd,
                             self.fets.dN_inr, inv_J_Enar)
        # B^T B weighted by det(J) and w_m (not yet summed over m)
        BB_Emicjdabef = np.einsum('Eimabc,Ejmefd,Em,m->Emicjdabef',
                                  B_Eimabc, B_Eimabc, det_J_Em,
                                  self.fets.w_m)
        return (BB_Emicjdabef, B_Eimabc, dof_Eia, x_Eia, dof_Ia, I_Ei,
                B_Einabc, det_J_Em)

    # Quadratic form of the kinematic mapping.
    BB_Emicjdabef = tr.Property()

    def _get_BB_Emicjdabef(self):
        return self.cached_grid_values[0]

    # Kinematic mapping between displacements and strains in every
    # integration point.
    B_Eimabc = tr.Property()

    def _get_B_Eimabc(self):
        return self.cached_grid_values[1]

    # Kinematic mapping between displacement and strain in every
    # visualization point.
    B_Einabc = tr.Property()

    def _get_B_Einabc(self):
        return self.cached_grid_values[6]

    # Mapping [element, node, direction] -> degree of freedom.
    dof_Eia = tr.Property()

    def _get_dof_Eia(self):
        return self.cached_grid_values[2]

    # Mapping [element, node, direction] -> value of coordinate.
    x_Eia = tr.Property()

    def _get_x_Eia(self):
        return self.cached_grid_values[3]

    # [global node, direction] -> degree of freedom
    dof_Ia = tr.Property()

    def _get_dof_Ia(self):
        return self.cached_grid_values[4]

    # [element, node] -> global node
    I_Ei = tr.Property()

    def _get_I_Ei(self):
        return self.cached_grid_values[5]

    # Jacobi determinant in every element and integration point.
    det_J_Em = tr.Property()

    def _get_det_J_Em(self):
        return self.cached_grid_values[7]

    # Dictionary of state arrays. The entry names and shapes are defined
    # by the material model.
    state_arrays = tr.Property(tr.Dict(tr.Str, tr.Array),
                               depends_on='fets, mats')

    @tr.cached_property
    def _get_state_arrays(self):
        # one zero array per state variable, shaped
        # (n_active_elems, n_m) + material-specific trailing shape
        n_E = self.mesh.n_active_elems
        n_m = self.fets.n_m
        return {
            name: np.zeros((n_E, n_m) + mats_sa_shape, dtype=np.float64)
            for name, mats_sa_shape in self.mats.state_array_shapes.items()
        }

    def get_corr_pred(self, U, dU, t_n, t_n1, update_state, algorithmic):
        '''Get the corrector and predictor for the given increment
        of unknowns.

        Parameters
        ----------
        U, dU : ndarray
            Vector of unknowns and its increment; each is reshaped to
            (n_nodes, n_nodal_dofs) and gathered per element via ``I_Ei``.
        t_n, t_n1, update_state, algorithmic
            Passed through unchanged to ``self.mats.get_corr_pred``.

        Returns
        -------
        K_subdomain : SysMtxArray
            Element stiffness matrices together with their DOF maps.
        F_int : ndarray
            Assembled internal force vector, one entry per DOF in
            ``dof_Ia``.
        norm_F_int : float
            Euclidean norm of ``F_int``.
        '''
        n_c = self.fets.n_nodal_dofs
        U_Eia = U.reshape(-1, n_c)[self.I_Ei]
        dU_Eia = dU.reshape(-1, n_c)[self.I_Ei]
        # strains and strain increments at the integration points
        eps_Emab = np.einsum('Eimabc,Eic->Emab', self.B_Eimabc, U_Eia)
        deps_Emab = np.einsum('Eimabc,Eic->Emab', self.B_Eimabc, dU_Eia)
        D_Emabef, sig_Emab = self.mats.get_corr_pred(eps_Emab, deps_Emab, t_n,
                                                     t_n1, update_state,
                                                     algorithmic,
                                                     **self.state_arrays)
        K_Eicjd = self.integ_factor * np.einsum('Emicjdabef,Emabef->Eicjd',
                                                self.BB_Emicjdabef, D_Emabef)
        _, n_i, n_c, n_j, n_d = K_Eicjd.shape
        K_E = K_Eicjd.reshape(-1, n_i * n_c, n_j * n_d)
        dof_E = self.dof_Eia.reshape(-1, n_i * n_c)
        K_subdomain = SysMtxArray(mtx_arr=K_E, dof_map_arr=dof_E)
        f_Eic = self.integ_factor * np.einsum(
            'm,Eimabc,Emab,Em->Eic', self.fets.w_m, self.B_Eimabc, sig_Emab,
            self.det_J_Em)
        f_Ei = f_Eic.reshape(-1, n_i * n_c)
        # Scatter-add element forces into the global vector; minlength keeps
        # the full DOF count even if trailing DOFs receive no contribution.
        F_int = np.bincount(dof_E.flatten(), weights=f_Ei.flatten(),
                            minlength=self.dof_Ia.size)
        norm_F_int = np.linalg.norm(F_int)
        return K_subdomain, F_int, norm_F_int
Ejemplo n.º 8
0
class ShmTrackingInterface(T.HasStrictTraits):
    '''Configure and run streamline tracking with an SH model.

    The traits hold the input data, model, peak-finding and tracking
    parameters. ``track_shm`` computes the streamlines, and
    ``save_streamlines`` / ``save_counts`` write results to disk.
    '''

    dwi_images = T.DelegatesTo('all_inputs')
    all_inputs = T.Instance(InputData, args=())
    min_signal = T.DelegatesTo('all_inputs')
    seed_roi = nifti_file
    seed_density = T.Array(dtype='int', shape=(3, ), value=[1, 1, 1])

    # Dict views are wrapped in list(): older Traits releases only recognise
    # a list/tuple as the collection of allowed Enum values.
    smoothing_kernel_type = T.Enum(None, list(all_kernels.keys()))
    smoothing_kernel = T.Instance(T.HasTraits)

    @T.on_trait_change('smoothing_kernel_type')
    def set_smoothing_kernel(self):
        '''Instantiate the kernel selected by ``smoothing_kernel_type``.'''
        if self.smoothing_kernel_type is None:
            self.smoothing_kernel = None
        else:
            kernel_factory = all_kernels[self.smoothing_kernel_type]
            self.smoothing_kernel = kernel_factory()

    interpolator = T.Enum('NearestNeighbor', list(all_interpolators.keys()))
    model_type = T.Enum('SlowAdcOpdf', list(all_shmodels.keys()))
    sh_order = T.Int(4)
    Lambda = T.Float(0, desc="Smoothing on the odf")
    sphere_coverage = T.Int(5)
    min_peak_spacing = T.Range(0., 1, np.sqrt(.5), desc="as a dot product")
    min_relative_peak = T.Range(0., 1, .25)

    probabilistic = T.Bool(False, label='Probabilistic (Residual Bootstrap)')
    bootstrap_input = T.Bool(False)
    bootstrap_vector = T.Array(dtype='int', value=[])

    # integrator = Enum('Boundry', all_integrators.keys())
    seed_largest_peak = T.Bool(False,
                               desc="Ignore sub-peaks and start follow "
                               "the largest peak at each seed")
    start_direction = T.Array(dtype='float',
                              shape=(3, ),
                              value=[0, 0, 1],
                              desc="Prefered direction from seeds when "
                              "multiple directions are available. "
                              "(Mostly) doesn't matter when 'seed "
                              "largest peak' and 'track two directions' "
                              "are both True",
                              label="Start direction (RAS)")
    track_two_directions = T.Bool(False)
    fa_threshold = T.Float(1.0)
    max_turn_angle = T.Range(0., 90, 0)

    stop_on_target = T.Bool(False)
    targets = T.List(nifti_file, [])

    # will be set later (by track_shm, from the input data)
    voxel_size = T.Array(dtype='float', shape=(3, ))
    affine = T.Array(dtype='float', shape=(4, 4))
    shape = T.Tuple((0, 0, 0))

    # set for io
    save_streamlines_to = T.File('')
    save_counts_to = nifti_file

    # io methods
    def save_streamlines(self, streamlines, save_streamlines_to):
        '''Write ``streamlines`` with the track writer ``write`` and pickle
        this interface next to it as ``<save_streamlines_to>.p``.

        The track header is filled from ``affine``, ``voxel_size`` and
        ``shape``.
        '''
        trk_hdr = empty_header()
        voxel_order = orientation_to_string(nib.io_orientation(self.affine))
        trk_hdr['voxel_order'] = voxel_order
        trk_hdr['voxel_size'] = self.voxel_size
        trk_hdr['vox_to_ras'] = self.affine
        trk_hdr['dim'] = self.shape
        trk_tracks = ((ii, None, None) for ii in streamlines)
        write(save_streamlines_to, trk_tracks, trk_hdr)
        # Use a context manager so the pickle file is flushed and closed
        # deterministically (it was previously left open).
        with open(save_streamlines_to + '.p', 'wb') as fh:
            pickle.dump(self, fh)

    def save_counts(self, streamlines, save_counts_to):
        '''Save the streamline density map as a NIfTI image.

        The counts are stored as int16 when their maximum is below 2**15.
        '''
        counts = density_map(streamlines, self.shape, self.voxel_size)
        if counts.max() < 2**15:
            counts = counts.astype('int16')
        nib.save(nib.Nifti1Image(counts, self.affine), save_counts_to)

    # tracking methods
    def track_shm(self, debug=False):
        '''Run the tracking with the current settings.

        Parameters
        ----------
        debug : bool
            When True, the optimized ``NND_ClosestPeakSelector`` is not used
            even if the other conditions for it hold.

        Returns
        -------
        streamlines
            Output of ``generate_streamlines``. If ``track_two_directions``
            is set, both directions are merged. The result is then filtered
            by ``target`` for each target mask.

        Raises
        ------
        ValueError
            If ``sphere_coverage`` is not between 1 and 7.
        '''
        if not 1 <= self.sphere_coverage <= 7:
            raise ValueError("sphere coverage must be between 1 and 7")
        verts, edges, faces = create_half_unit_sphere(self.sphere_coverage)
        verts, pot = disperse_charges(verts, 10, .3)

        data, voxel_size, affine, fa, bvec, bval = self.all_inputs.read_data()
        self.voxel_size = voxel_size
        self.affine = affine
        self.shape = fa.shape

        model_type = all_shmodels[self.model_type]
        model = model_type(self.sh_order, bval, bvec, self.Lambda)
        model.set_sampling_points(verts, edges)

        # optional in-place smoothing of the data
        data = np.asarray(data, dtype='float', order='C')
        if self.smoothing_kernel is not None:
            kernel = self.smoothing_kernel.get_kernel()
            convolve(data, kernel, out=data)

        normalize_data(data, bval, self.min_signal, out=data)
        dmin = data.min()
        # keep only the volumes with bval > 0
        data = data[..., lazy_index(bval > 0)]
        if self.bootstrap_input:
            if self.bootstrap_vector.size == 0:
                n = data.shape[-1]
                self.bootstrap_vector = np.random.randint(n, size=n)
            H = hat(model.B)
            R = lcr_matrix(H)
            data = bootstrap_data_array(data, H, R, self.bootstrap_vector)
            data.clip(dmin, out=data)

        # tracking mask, optionally excluding the target regions
        mask = fa > self.fa_threshold
        targets = [read_roi(tgt, shape=self.shape) for tgt in self.targets]
        if self.stop_on_target:
            for target_mask in targets:
                mask = mask & ~target_mask

        seed_mask = read_roi(self.seed_roi, shape=self.shape)
        seeds = seeds_from_mask(seed_mask, self.seed_density, voxel_size)

        if ((self.interpolator == 'NearestNeighbor' and not self.probabilistic
             and not debug)):
            using_optimize = True
            peak_finder = NND_ClosestPeakSelector(model, data, mask,
                                                  voxel_size)
        else:
            using_optimize = False
            interpolator_type = all_interpolators[self.interpolator]
            interpolator = interpolator_type(data, voxel_size, mask)
            peak_finder = ClosestPeakSelector(model, interpolator)

        # Set peak_finder parameters for start steps
        peak_finder.angle_limit = 90
        model.peak_spacing = self.min_peak_spacing
        if self.seed_largest_peak:
            model.min_relative_peak = 1
        else:
            model.min_relative_peak = self.min_relative_peak

        data_ornt = nib.io_orientation(self.affine)
        best_start = reorient_vectors(self.start_direction, 'ras', data_ornt)
        start_steps = closest_start(seeds, peak_finder, best_start)

        # The probabilistic branch always came through the `else` above,
        # so `interpolator` is defined here.
        if self.probabilistic:
            interpolator = ResidualBootstrapWrapper(interpolator,
                                                    model.B,
                                                    min_signal=dmin)
            peak_finder = ClosestPeakSelector(model, interpolator)
        elif using_optimize and self.seed_largest_peak:
            peak_finder.reset_cache()

        # Reset peak_finder parameters for tracking
        peak_finder.angle_limit = self.max_turn_angle
        model.peak_spacing = self.min_peak_spacing
        model.min_relative_peak = self.min_relative_peak

        integrator = BoundryIntegrator(voxel_size, overstep=.1)
        streamlines = generate_streamlines(peak_finder, integrator, seeds,
                                           start_steps)
        if self.track_two_directions:
            start_steps = -start_steps
            streamlinesB = generate_streamlines(peak_finder, integrator, seeds,
                                                start_steps)
            streamlines = merge_streamlines(streamlines, streamlinesB)

        for target_mask in targets:
            streamlines = target(streamlines, target_mask, voxel_size)

        return streamlines
Ejemplo n.º 9
0
class AppWindow(tr.HasTraits):
    '''Container class synchronizing the interaction elements with the
    plotting area.

    It is equivalent to the traitsui.View class.
    '''
    model = tr.Instance(IModel)

    figsize = tr.Tuple(8, 3)

    def __init__(self, model, **kw):
        super(AppWindow, self).__init__(**kw)
        self.model = model
        self.output = ipw.Output()

    # Available plotting backends, keyed by the name a model reports in its
    # ``plot_backend`` attribute (see select_node / set_plot_backend).
    plot_backend_table = tr.Dict

    def _plot_backend_table_default(self):
        return {'mpl': MPLBackend(), 'k3d': K3DBackend()}

    # Shared layouts -
    left_pane_layout = tr.Instance(ipw.Layout)

    def _left_pane_layout_default(self):
        return ipw.Layout(
            margin='0px 0px 0px 0px',
            padding='0px 0px 0px 0px',
            width="300px",
            flex_grow="1",
        )

    right_pane_layout = tr.Instance(ipw.Layout)

    def _right_pane_layout_default(self):
        return ipw.Layout(
            border='solid 1px black',
            margin='0px 0px 0px 5px',
            padding='1px 1px 1px 1px',
            width="100%",
            flex_grow="1",
        )

    def interact(self):
        '''Build the widget layout and display it.

        The left pane holds the model tree and the model editor. The right
        pane holds the plot area and the time editor. Finally the root tree
        node is selected, which triggers ``select_node``.
        '''
        left_pane = ipw.VBox([self.tree_pane, self.model_editor_pane],
                             layout=self.left_pane_layout)
        name = self.model.name
        self.menubar = ipw.Label(value=name,
                                 layout=ipw.Layout(width="100%",
                                                   height="150px"))
        self.empty_pane = ipw.Box(
            layout=ipw.Layout(width="100%", height="100%"))
        self.plot_pane = ipw.VBox(
            [self.menubar, self.empty_pane],
            layout=ipw.Layout(
                align_items="stretch",
                width="100%"))
        right_pane = ipw.VBox([self.plot_pane, self.time_editor_pane],
                              layout=self.right_pane_layout)
        app = ipw.HBox([left_pane, right_pane],
                       layout=ipw.Layout(align_items="stretch", width="100%"))
        app_print = ipw.VBox([app, print_output],
                             layout=ipw.Layout(align_items="stretch",
                                               width="100%"))
        self.model_tree.selected = True
        display(app_print)

    model_tree = tr.Property()

    @tr.cached_property
    def _get_model_tree(self):
        tree = self.model.get_sub_node(self.model.name)
        return self.get_tree_entries(tree)

    def get_tree_entries(self, node):
        '''Recursively convert a ``(name, model, sub_nodes)`` triple into a
        ``BMCSNode`` tree.

        Each created node notifies ``select_node`` when its ``selected``
        state changes. Each node is also rebuilt when its model fires
        ``tree_changed``.
        '''
        name, model, sub_nodes = node
        bmcs_sub_nodes = [
            self.get_tree_entries(sub_node) for sub_node in sub_nodes
        ]
        node_ = BMCSNode(name,
                         nodes=tuple(bmcs_sub_nodes),
                         controller=model.get_controller(self))
        node_.observe(self.select_node, 'selected')

        def update_node(event):
            '''upon tree change - rebuild the subnodes'''
            new_node = model.get_sub_node(model.name)
            new_node_ = self.get_tree_entries(new_node)
            node_.nodes = new_node_.nodes
            # NOTE(review): the replaced nodes and their observers are not
            # explicitly released here -- possible memory leak; confirm.

        model.observe(update_node, 'tree_changed')
        return node_

    # Called Property() -- consistent with model_tree/pb above; a bare,
    # uncalled ``tr.Property`` may be treated as a plain method by Traits.
    tree_pane = tr.Property()  # might depend on the model

    @tr.cached_property
    def _get_tree_pane(self):
        # provide a method scanning the tree of the model
        # components
        tree_layout = ipw.Layout(display='flex',
                                 overflow='scroll hidden',
                                 flex_flow='column',
                                 border='solid 1px black',
                                 margin='0px 5px 5px 0px',
                                 padding='1px 1px 15px 1px',
                                 align_items='stretch',
                                 flex_grow="2",
                                 height="30%",
                                 width='100%')

        tree_pane = ipt.Tree(layout=tree_layout)
        root_node = self.model_tree
        tree_pane.nodes = (root_node, )
        return tree_pane

    model_editor_pane = tr.Property()  # should depend on the model

    @tr.cached_property
    def _get_model_editor_pane(self):
        editor_pane_layout = ipw.Layout(display='flex',
                                        border='solid 1px black',
                                        overflow='scroll hidden',
                                        justify_content='space-between',
                                        flex_flow='column',
                                        padding='10px 5px 10px 5px',
                                        margin='0px 5px 0px 0px',
                                        align_items='flex-start',
                                        height="70%",
                                        width='100%')
        return ipw.VBox(layout=editor_pane_layout)

    time_editor_pane_layout = tr.Instance(ipw.Layout)

    def _time_editor_pane_layout_default(self):
        return ipw.Layout(
            height="35px",
            width="100%",
            margin='0px 0px 0px 0px',
            padding='0px 0px 0px 0px',
        )

    time_editor_pane = tr.Property()  # should depend on the model

    @tr.cached_property
    def _get_time_editor_pane(self):
        return ipw.VBox(layout=self.time_editor_pane_layout)

    def select_node(self, event):
        '''React to a tree node becoming selected.

        Installs the node controller's time and model editors, switches to
        the model's plot backend and redraws the plot. Events whose old
        value is truthy (i.e. de-selection) are ignored.
        '''
        if event['old']:
            return
        node = event['owner']
        controller = node.controller
        self.controller = controller
        time_editor = controller.time_editor
        self.time_editor_pane.children = time_editor
        model_editor = controller.model_editor
        self.model_editor_pane.children = model_editor.children
        backend = controller.model.plot_backend
        self.set_plot_backend(backend)
        self.setup_plot(controller.model)
        self.update_plot(controller.model)

    current_plot_backend = tr.Str

    pb = tr.Property()

    def _get_pb(self):
        '''Get the current plot backend'''
        return self.plot_backend_table[self.current_plot_backend]

    def set_plot_backend(self, backend):
        '''Switch the plot pane to ``backend`` unless it is already active.'''
        if self.current_plot_backend == backend:
            return
        self.current_plot_backend = backend
        pb = self.plot_backend_table[backend]
        self.plot_pane.children = [pb.plot_widget]

    def setup_plot(self, model):
        '''Clear the current backend's figure and set up ``model``'s plot.'''
        pb = self.pb
        pb.clear_fig()
        pb.setup_plot(model)

    def update_plot(self, model):
        '''Update ``model``'s plot in the current backend and show it.'''
        pb = self.pb
        pb.update_plot(model)
        pb.show_fig()
Ejemplo n.º 10
0
class MomentCurvature(tr.HasStrictTraits):
    r'''Class returning the moment curvature relationship.

    For each prescribed curvature in ``kappa_t``, the bottom strain
    ``eps_bot_t`` is solved so that the normal force vanishes. Moments,
    reinforcement forces and concrete stresses are then evaluated in that
    state.
    '''

    # Width of the cross section as a sympy expression of the symbol ``z``.
    b_z = tr.Any
    # Called Property() like the other properties of this class; a bare,
    # uncalled ``tr.Property`` may be treated as a plain method by Traits.
    get_b_z = tr.Property()

    @tr.cached_property
    def _get_get_b_z(self):
        return sp.lambdify(z, self.b_z, 'numpy')

    # Height of the cross section.
    h = tr.Float

    # Material parameters keyed by sympy symbols.
    # NOTE: keep this above ``eps_cr = tr.Property()`` further below. The
    # key ``eps_cr`` is resolved at class creation and must refer to the
    # module-level symbol, not to the class attribute.
    model_params = tr.Dict({
        E_ct: 24000, E_cc: 25000,
        eps_cr: 0.001,
        eps_cy: -0.003,
        eps_cu: -0.01,
        mu: 0.33,
        eps_tu: 0.003
    })

    # Number of material points along the height of the cross section
    n_m = tr.Int(100)

    # Reinforcement layers j: height z_j, area A_j, modulus E_j and yield
    # strain eps_sy_j. A_j must be 1-D to match the 'j' subscript used in
    # get_N_s_tj and _get_M_s_t (the former 2-D default made einsum fail).
    # np.float64 replaces the np.float_ alias removed in NumPy 2.0.
    z_j = tr.Array(np.float64, value=[10])
    A_j = tr.Array(np.float64, value=[np.pi * (16 / 2.)**2])
    E_j = tr.Array(np.float64, value=[[210000]])
    eps_sy_j = tr.Array(np.float64, value=[[500. / 210000.]])

    # Material points from the bottom (z = 0) to the top (z = h).
    z_m = tr.Property(depends_on='n_m, h')

    @tr.cached_property
    def _get_z_m(self):
        return np.linspace(0, self.h, self.n_m)

    # (start, stop, number) of the prescribed curvatures
    kappa_range = tr.Tuple(-0.001, 0.001, 101)

    kappa_t = tr.Property(tr.Array(np.float64), depends_on='kappa_range')

    @tr.cached_property
    def _get_kappa_t(self):
        return np.linspace(*self.kappa_range)

    get_eps_z = tr.Property(depends_on='model_params_items')

    @tr.cached_property
    def _get_get_eps_z(self):
        return sp.lambdify(
            (kappa, eps_bot, z), eps_z.subs(subs_eps), 'numpy'
        )

    get_sig_c_z = tr.Property(depends_on='model_params_items')

    @tr.cached_property
    def _get_get_sig_c_z(self):
        return sp.lambdify(
            (kappa, eps_bot, z), sig_c_z_lin.subs(self.model_params), 'numpy'
        )

    get_sig_s_eps = tr.Property(depends_on='model_params_items')

    @tr.cached_property
    def _get_get_sig_s_eps(self):
        return sp.lambdify((eps, E_s, eps_sy), sig_s_eps, 'numpy')

    # Normal force

    def get_N_s_tj(self, kappa_t, eps_bot_t):
        '''Reinforcement forces, shape (n_t, n_j), for the given states.'''
        eps_z_tj = self.get_eps_z(
            kappa_t[:, np.newaxis], eps_bot_t[:, np.newaxis],
            self.z_j[np.newaxis, :]
        )
        sig_s_tj = self.get_sig_s_eps(eps_z_tj, self.E_j, self.eps_sy_j)
        return np.einsum('j,tj->tj', self.A_j, sig_s_tj)

    def get_N_c_t(self, kappa_t, eps_bot_t):
        '''Concrete normal force per state, integrated over the height.'''
        z_tm = self.z_m[np.newaxis, :]
        b_z_m = self.get_b_z(z_tm)  # self.get_b_z(self.z_m) also OK
        N_z_tm = b_z_m * self.get_sig_c_z(
            kappa_t[:, np.newaxis], eps_bot_t[:, np.newaxis], z_tm
        )
        return np.trapz(N_z_tm, x=z_tm, axis=-1)

    def get_N_t(self, kappa_t, eps_bot_t):
        '''Total (concrete + reinforcement) normal force per state.'''
        N_s_t = np.sum(self.get_N_s_tj(kappa_t, eps_bot_t), axis=-1)
        return self.get_N_c_t(kappa_t, eps_bot_t) + N_s_t

    # SOLVER: Get eps_bot to render zero force

    # Resolve the tensile strain to get zero normal force for the
    # prescribed curvature.
    eps_bot_t = tr.Property()

    def _get_eps_bot_t(self):
        res = root(lambda eps_bot_t: self.get_N_t(self.kappa_t, eps_bot_t),
                   0.0000001 + np.zeros_like(self.kappa_t), tol=1e-6)
        return res.x

    # POSTPROCESSING

    eps_cr = tr.Property()

    def _get_eps_cr(self):
        return np.array([self.model_params[eps_cr]], dtype=np.float64)

    # Curvature giving zero normal force when the bottom strain is eps_cr.
    kappa_cr = tr.Property()

    def _get_kappa_cr(self):
        res = root(lambda kappa: self.get_N_t(kappa, self.eps_cr),
                   0.0000001 + np.zeros_like(self.eps_cr), tol=1e-10)
        return res.x

    # Bending moment

    M_s_t = tr.Property()

    def _get_M_s_t(self):
        eps_z_tj = self.get_eps_z(
            self.kappa_t[:, np.newaxis], self.eps_bot_t[:, np.newaxis],
            self.z_j[np.newaxis, :]
        )
        sig_z_tj = self.get_sig_s_eps(
            eps_z_tj, self.E_j, self.eps_sy_j)
        return -np.einsum('j,tj,j->t', self.A_j, sig_z_tj, self.z_j)

    M_c_t = tr.Property()

    def _get_M_c_t(self):
        z_tm = self.z_m[np.newaxis, :]
        b_z_m = self.get_b_z(z_tm)
        N_z_tm = b_z_m * self.get_sig_c_z(
            self.kappa_t[:, np.newaxis], self.eps_bot_t[:, np.newaxis], z_tm
        )
        return -np.trapz(N_z_tm * z_tm, x=z_tm, axis=-1)

    M_t = tr.Property()

    def _get_M_t(self):
        return self.M_c_t + self.M_s_t

    N_s_tj = tr.Property()

    def _get_N_s_tj(self):
        return self.get_N_s_tj(self.kappa_t, self.eps_bot_t)

    eps_tm = tr.Property()

    def _get_eps_tm(self):
        return self.get_eps_z(
            self.kappa_t[:, np.newaxis], self.eps_bot_t[:, np.newaxis],
            self.z_m[np.newaxis, :],
        )

    sig_tm = tr.Property()

    def _get_sig_tm(self):
        return self.get_sig_c_z(
            self.kappa_t[:, np.newaxis], self.eps_bot_t[:, np.newaxis],
            self.z_m[np.newaxis, :],
        )

    # Index of the state highlighted in the plots.
    idx = tr.Int(0)

    # Elastic cracking moment W * E_ct * eps_cr, used for normalization.
    M_norm = tr.Property()

    def _get_M_norm(self):
        sig_cr = self.model_params[E_ct] * self.model_params[eps_cr]
        # The previous code read ``self.b``, which is not declared on this
        # HasStrictTraits class and therefore always raised AttributeError.
        # A ``b`` provided e.g. by a subclass is still honoured.
        b = getattr(self, 'b', None)
        if b is not None:
            W = (b * self.h**2) / 6
        else:
            W = self._section_modulus_bottom()
        return W * sig_cr

    def _section_modulus_bottom(self):
        '''Elastic section modulus with respect to the bottom fiber (z=0).

        The width b(z) is integrated numerically over ``z_m``. For a
        constant width this reproduces b * h**2 / 6 up to the quadrature
        error.
        '''
        z_m = self.z_m
        # b_z may be a constant expression; broadcast it to the z_m grid.
        b_m = self.get_b_z(z_m) * np.ones_like(z_m)
        A = np.trapz(b_m, x=z_m)
        z_c = np.trapz(b_m * z_m, x=z_m) / A
        I_c = np.trapz(b_m * (z_m - z_c)**2, x=z_m)
        return I_c / z_c

    kappa_norm = tr.Property()

    def _get_kappa_norm(self):
        return self.kappa_cr

    def _plot_stress_profile(self, ax2, bar_height):
        '''Draw the reinforcement forces of state ``idx`` as bars on ``ax2``
        and the concrete stress profile on a twin x-axis.'''
        idx = self.idx
        ax2.barh(self.z_j, self.N_s_tj[idx, :],
                 height=bar_height, color='red', align='center')
        ax3 = ax2.twiny()
        ax3.plot(self.sig_tm[idx, :], self.z_m)
        ax3.axvline(0, linewidth=0.8, color='k')
        ax3.fill_betweenx(self.z_m, self.sig_tm[idx, :], 0, alpha=0.1)
        self._align_xaxis(ax2, ax3)

    def plot_norm(self, ax1, ax2):
        '''Plot the normalized moment-curvature curve on ``ax1`` and the
        section state of ``idx`` on ``ax2``.'''
        idx = self.idx
        ax1.plot(self.kappa_t / self.kappa_norm, self.M_t / self.M_norm)
        ax1.plot(self.kappa_t[idx] / self.kappa_norm,
                 self.M_t[idx] / self.M_norm, marker='o')
        self._plot_stress_profile(ax2, bar_height=2)

    def plot(self, ax1, ax2):
        '''Plot the moment-curvature curve (moment divided by 1e6) on
        ``ax1`` and the section state of ``idx`` on ``ax2``.'''
        idx = self.idx
        ax1.plot(self.kappa_t, self.M_t / (1e6))
        ax1.set_ylabel('Moment [kN.m]')
        ax1.set_xlabel('Curvature [$m^{-1}$]')
        ax1.plot(self.kappa_t[idx], self.M_t[idx] / (1e6), marker='o')
        self._plot_stress_profile(ax2, bar_height=6)

    def _align_xaxis(self, ax1, ax2):
        """Align zeros of the two axes, zooming them out by same ratio"""
        axes = (ax1, ax2)
        extrema = [ax.get_xlim() for ax in axes]
        tops = [extr[1] / (extr[1] - extr[0]) for extr in extrema]
        # Ensure that plots (intervals) are ordered bottom to top:
        if tops[0] > tops[1]:
            axes, extrema, tops = [list(reversed(seq))
                                   for seq in (axes, extrema, tops)]

        # How much would the plot overflow if we kept current zoom levels?
        tot_span = tops[1] + 1 - tops[0]

        b_new_t = extrema[0][0] + tot_span * (extrema[0][1] - extrema[0][0])
        t_new_b = extrema[1][1] - tot_span * (extrema[1][1] - extrema[1][0])
        axes[0].set_xlim(extrema[0][0], b_new_t)
        axes[1].set_xlim(t_new_b, extrema[1][1])
Ejemplo n.º 11
0
class EnsembleTrainer(t.HasStrictTraits):
    def __init__(self, config=None, **kwargs):
        """Create an ensemble whose trainers are all built from one config.

        Args:
            config: keyword arguments forwarded to ``Trainer``, both for the
                template trainer created here and for every fold trainer built
                in ``_trainers_default``; it is also stored on ``self.config``.
                ``None`` (the default) means an empty configuration.
            **kwargs: forwarded to ``HasStrictTraits.__init__`` to set other
                traits (for example ``n_folds`` or ``name``).
        """
        # Build a fresh dict per call instead of sharing one mutable default
        # object between every ensemble constructed without a config.
        if config is None:
            config = {}
        trainer_template = Trainer(**config)
        super().__init__(trainer_template=trainer_template,
                         config=config,
                         **kwargs)

    config: dict = t.Dict()

    trainer_template: Trainer = t.Instance(Trainer)
    trainers: ty.List[Trainer] = t.List(t.Instance(Trainer))

    n_folds = t.Int(5)

    dl_test: DataLoader = t.DelegatesTo("trainer_template")
    data_spec: dict = t.DelegatesTo("trainer_template")
    cuda: bool = t.DelegatesTo("trainer_template")
    device: str = t.DelegatesTo("trainer_template")
    loss_func: str = t.DelegatesTo("trainer_template")
    batch_size: int = t.DelegatesTo("trainer_template")
    win_len: int = t.DelegatesTo("trainer_template")
    has_null_class: bool = t.DelegatesTo("trainer_template")
    predict_null_class: bool = t.DelegatesTo("trainer_template")
    name: str = t.Str()

    def _name_default(self):
        """Default ``name``: "Ensemble_" plus the local time as YYYYmmdd-HHMMSS."""
        import time

        stamp = time.strftime("%Y%m%d-%H%M%S")
        return "_".join(("Ensemble", stamp))

    X_folds = t.Tuple(transient=True)
    ys_folds = t.Tuple(transient=True)

    def _trainers_default(self):
        """Build one ``Trainer`` per cross-validation fold.

        The template trainer's data is loaded, and its train and validation
        datasets are pooled.  The pool is cut, in order, into chunks of
        ``ceil(N / n_folds)`` samples, which are stored on ``X_folds`` and
        ``ys_folds`` for later use by ``train``.  Trainer ``i`` is created from
        ``self.config`` with ``validation_fold=i`` and the name
        ``"<ensemble name>/<i>"``, and it shares the template's test loader.

        Raises:
            ValueError: if the pooled data does not yield exactly ``n_folds``
                chunks (too few samples for the requested fold count).  Without
                this check the surplus trainers would only fail later, in
                ``train``, when concatenating an empty validation fold.
        """
        # Temporary trainer used to load the datasets.
        tt = self.trainer_template
        tt.init_data()

        # Combine official train & val sets: tensors[0] is taken as X and
        # every remaining tensor as one of the ys.
        train_tensors = tt.dl_train.dataset.tensors
        val_tensors = tt.dl_val.dataset.tensors
        X = torch.cat([train_tensors[0], val_tensors[0]])
        ys = [
            torch.cat([yt, yv])
            for yt, yv in zip(train_tensors[1:], val_tensors[1:])
        ]

        # Consecutive folds of fold_len samples (the last may be shorter).
        fold_len = int(np.ceil(len(X) / self.n_folds))
        X_folds = torch.split(X, fold_len)
        if len(X_folds) != self.n_folds:
            raise ValueError(
                f"Cannot split {len(X)} samples into {self.n_folds} folds: "
                f"chunks of {fold_len} samples give {len(X_folds)} folds."
            )
        self.X_folds = X_folds
        self.ys_folds = [torch.split(y, fold_len) for y in ys]

        trainers = []
        for i_val_fold in range(self.n_folds):
            trainer = Trainer(
                validation_fold=i_val_fold,
                name=f"{self.name}/{i_val_fold}",
                **self.config,
            )
            trainer.dl_test = tt.dl_test
            trainers.append(trainer)

        return trainers

    model: models.BaseNet = t.Instance(torch.nn.Module, transient=True)

    def _model_default(self):
        """Default ``model``: a FilterNetEnsemble wrapping each trainer's model."""
        ensemble = models.FilterNetEnsemble()
        member_models = [fold_trainer.model for fold_trainer in self.trainers]
        ensemble.set_models(member_models)
        return ensemble

    model_path: str = t.Str()

    def _model_path_default(self):
        """Default checkpoint directory: ``saved_models/<name>/``."""
        return "saved_models/{}/".format(self.name)

    def init_data(self):
        """No-op at the ensemble level.

        Nothing is loaded here.  The pooled folds are built when ``trainers``
        is first accessed (``_trainers_default``).  Each fold trainer's own
        ``init_data`` is called from ``train`` once its fold loaders have been
        attached.  Presumably kept so an ensemble can be driven like a single
        ``Trainer``; confirm against callers.
        """

    def init_train(self):
        """No-op at the ensemble level.

        Each fold trainer's own ``init_train`` is called from ``train`` right
        after that trainer's ``init_data``.
        """

    def train(self, max_epochs=50):
        """Train, checkpoint and evaluate every fold trainer in turn.

        For each trainer in ``self.trainers``:

        1. Build a shuffled training loader from all folds except
           ``trainer.validation_fold``, and an unshuffled validation loader
           from that fold alone.  Copy ``data_spec`` and ``epoch_iters`` from
           the template trainer.
        2. Run the trainer's ``init_data``, ``init_train`` and
           ``train(max_epochs=max_epochs)``.
        3. Drop the train/val loaders to save RAM.
        4. Restore and re-save the trainer's checkpoint (reported as the best
           model), then print its training summary.
        5. Evaluate it with ``EvalModel`` on the test set, print the
           classification report and save the evaluation.

        Args:
            max_epochs: epoch limit passed to each fold trainer's ``train``;
                whether training may stop earlier is decided by ``Trainer``.
        """
        for trainer in self.trainers:
            X_train, ys_train, X_val, ys_val = self._split_fold_data(
                trainer.validation_fold)

            trainer.dl_train = DataLoader(
                TensorDataset(torch.Tensor(X_train), *ys_train),
                batch_size=trainer.batch_size,
                shuffle=True,
            )
            trainer.data_spec = self.trainer_template.data_spec
            trainer.epoch_iters = self.trainer_template.epoch_iters
            trainer.dl_val = DataLoader(
                TensorDataset(torch.Tensor(X_val), *ys_val),
                batch_size=trainer.batch_size,
                shuffle=False,
            )

            # Drop the local references to the fold tensors to save RAM.
            del X_train, ys_train, X_val, ys_val

            trainer.init_data()
            trainer.init_train()
            trainer.train(max_epochs=max_epochs)

            # Clear trainer train and val datasets to save RAM.
            trainer.dl_train = t.Undefined
            trainer.dl_val = t.Undefined

            print("RESTORING TO best model")
            trainer._restore()
            trainer._save()

            trainer.print_train_summary()

            em = EvalModel(trainer=trainer)
            em.run_test_set()
            em.calc_metrics()
            em.calc_ward_metrics()
            print(em.classification_report_df.to_string(float_format="%.3f"))
            em._save()

    def _split_fold_data(self, i_val_fold):
        """Return ``(X_train, ys_train, X_val, ys_val)`` for one held-out fold.

        The training tensors concatenate every fold except ``i_val_fold``, in
        fold order.  The validation tensors are fold ``i_val_fold`` alone.
        ``ys_train`` and ``ys_val`` hold one tensor per entry of
        ``self.ys_folds``.
        """
        def cat_folds(folds, want_val):
            # Concatenate the folds that are (want_val=True) or are not
            # (want_val=False) the held-out validation fold.
            return torch.cat([
                arr for i, arr in enumerate(folds)
                if (i == i_val_fold) == want_val
            ])

        X_train = cat_folds(self.X_folds, want_val=False)
        ys_train = [cat_folds(y, want_val=False) for y in self.ys_folds]
        X_val = cat_folds(self.X_folds, want_val=True)
        ys_val = [cat_folds(y, want_val=True) for y in self.ys_folds]
        return X_train, ys_train, X_val, ys_val

    def print_train_summary(self):
        """Print each fold trainer's training summary, in ``self.trainers`` order."""
        for fold_trainer in self.trainers:
            fold_trainer.print_train_summary()

    def _save(self, checkpoint_dir=None, save_model=True, save_trainer=True):
        """Write the ensemble's checkpoint files into a directory.

        Args:
            checkpoint_dir: target directory.  When given, it also replaces
                ``self.model_path``; when ``None``, ``self.model_path`` is used.
                The directory is created if missing.
            save_model: if true, write ``self.model.state_dict()`` to
                ``model.pth`` with ``torch.save``.
            save_trainer: if true, pickle ``self`` into ``trainer.pth``.

        Returns:
            The directory that was written to.
        """
        if checkpoint_dir is not None:
            # An explicit directory becomes the new default location.
            self.model_path = checkpoint_dir
        else:
            checkpoint_dir = self.model_path

        os.makedirs(checkpoint_dir, exist_ok=True)

        weights_file = os.path.join(checkpoint_dir, "model.pth")
        state_file = os.path.join(checkpoint_dir, "trainer.pth")

        if save_model:
            torch.save(self.model.state_dict(), weights_file)
        if save_trainer:
            with open(state_file, "wb") as fh:
                pickle.dump(self, fh)

        return checkpoint_dir

    def _restore(self, checkpoint_dir=None):
        """ Restores model state and training state from disk. """

        if checkpoint_dir is None:
            checkpoint_dir = self.model_path

        model_path = os.path.join(checkpoint_dir, "model.pth")
        trainer_path = os.path.join(checkpoint_dir, "trainer.pth")

        # Reconstitute old trainer and copy state to this trainer.
        with open(trainer_path, "rb") as f:
            other_trainer = pickle.load(f)

        self.__setstate__(other_trainer.__getstate__())

        # Load sub-models
        for trainer in self.trainers:
            trainer._restore()

        # Load model (after loading state in case we need to re-initialize model from config)
        self.model.load_state_dict(
            torch.load(model_path, map_location=self.device))