Example #1
0
class Polarization(ManagedJob, GetSetItemsMixin):
    """
    Record a polarization curve.

    The rotation stage is homed and then stepped from 0 deg up to (but not
    including) 360 deg in increments of ``angle_step``. At every angle the
    sum of the TimeTagger count-rate channels 0 and 1 is taken after waiting
    ``seconds_per_point`` seconds, and the power meter reading is stored.

    written by: [email protected]
    last modified: 2012-08-17
    """

    seconds_per_point = Range(low=1e-4,
                              high=100.,
                              value=1.,
                              desc='integration time for one point',
                              label='seconds per point',
                              mode='text',
                              auto_set=False,
                              enter_set=True)
    angle_step = Range(low=1e-3,
                       high=100.,
                       value=1.,
                       desc='angular step',
                       label='angle step',
                       mode='text',
                       auto_set=False,
                       enter_set=True)

    # acquired data, one entry per measured angle
    angle = Array()
    intensity = Array()
    power = Array()

    plot = Instance(Plot)
    plot_data = Instance(ArrayPlotData)

    # traits persisted by GetSetItemsMixin
    get_set_items = [
        '__doc__', 'seconds_per_point', 'angle_step', 'angle', 'intensity',
        'power'
    ]

    def __init__(self):
        super(Polarization, self).__init__()
        self._create_plot()
        # keep the plot in sync with the data; updates run on the UI thread
        self.on_trait_change(self._update_index, 'angle', dispatch='ui')
        self.on_trait_change(self._update_value, 'intensity', dispatch='ui')

    def _run(self):
        """Acquire data.

        Leaves ``state`` as 'done' after a complete sweep, 'idle' when a stop
        was requested, and 'error' if any step raised (the exception is
        logged).
        """
        # Bind the counters before the try block so that the cleanup in
        # ``finally`` cannot itself fail (and hide the original error) when
        # homing the stage or creating a counter raises.
        c1 = c2 = None
        try:  # run the acquisition
            self.state = 'run'

            RotationStage().go_home()

            self.angle = np.array(())
            self.intensity = np.array(())
            self.power = np.array(())

            c1 = TimeTagger.Countrate(0)
            c2 = TimeTagger.Countrate(1)

            for phi in np.arange(0., 360., self.angle_step):
                RotationStage().set_angle(phi)
                c1.clear()
                c2.clear()
                # wait() doubles as the integration time and as a
                # responsive stop check
                self.thread.stop_request.wait(self.seconds_per_point)
                if self.thread.stop_request.isSet():
                    logging.getLogger().debug('Caught stop signal. Exiting.')
                    self.state = 'idle'
                    break
                self.angle = np.append(self.angle, phi)
                self.intensity = np.append(self.intensity,
                                           c1.getData() + c2.getData())
                self.power = np.append(self.power, PowerMeter().getPower())
            else:
                # only reached when the sweep was not interrupted
                self.state = 'done'

        except Exception:  # if anything fails, recover
            logging.getLogger().exception('Error in polarization.')
            self.state = 'error'
        finally:
            # drop the counter references so the counters can be released
            del c1
            del c2

    def _create_plot(self):
        """Create the intensity-vs-angle plot and its (empty) data source."""
        plot_data = ArrayPlotData(
            angle=np.array(()),
            intensity=np.array(()),
        )
        plot = Plot(plot_data, padding=8, padding_left=64, padding_bottom=64)
        plot.plot(('angle', 'intensity'), color='blue')
        plot.index_axis.title = 'angle [deg]'
        plot.value_axis.title = 'intensity [count/s]'
        self.plot_data = plot_data
        self.plot = plot

    def _update_index(self, new):
        """Push new angle data into the plot."""
        self.plot_data.set_data('angle', new)

    def _update_value(self, new):
        """Push new intensity data into the plot."""
        self.plot_data.set_data('intensity', new)

    def save_plot(self, filename):
        """Save the current plot to ``filename``."""
        self.save_figure(self.plot, filename)

    traits_view = View(
        HGroup(
            Item('submit_button', show_label=False),
            Item('remove_button', show_label=False),
            Item('priority'),
            Item('state', style='readonly'),
        ),
        HGroup(
            Item('seconds_per_point'),
            Item('angle_step'),
        ),
        Item('plot', editor=ComponentEditor(), show_label=False),
        menubar=MenuBar(
            Menu(Action(action='save', name='Save (.pyd or .pys)'),
                 Action(action='load', name='Load'),
                 Action(action='save_plot', name='Save Plot (.png)'),
                 Action(action='_on_close', name='Quit'),
                 name='File')),
        title='Polarization',
        width=640,
        height=640,
        buttons=[],
        resizable=True,
        handler=PolarizationHandler)
Example #2
0
class SurveyLine(HasTraits):
    """ A class representing a single survey line """

    #: the user-visible name for the line
    name = Str

    #: sample locations, an Nx2 array (example: easting/northing?)
    locations = Array(shape=(None, 2))

    #: specifies unit for values in locations array
    locations_unit = Str('feet')

    #: array of associated lat/long available for display
    lat_long = Array(shape=(None, 2))

    #: a dictionary mapping frequencies to intensity arrays
    frequencies = Dict

    #: complete trace_num set. array = combined freq_trace_num arrays
    trace_num = Array

    #: array of trace numbers corresponding to each intensity pixel/column
    #: ! NOTE ! starts at 1, not 0, so need to subtract 1 to use as index
    freq_trace_num = Dict

    #: relevant core samples
    core_samples = List(Supports(ICoreSample))

    #: depth of the lake at each location as generated by various sources
    lake_depths = Dict(Str, Supports(IDepthLine))

    #: name of final choice for line used as current lake depth for volume calculations
    final_lake_depth = Str

    #: an event fired when the lake depths are updated
    lake_depths_updated = Event

    #: The navigation track of the survey line in map coordinates
    navigation_line = Instance(LineString)

    #: pre-impoundment depth at each location as generated by various sources
    preimpoundment_depths = Dict(Str, Supports(IDepthLine))

    #: name of final choice for pre-impoundment depth to track sedimentation
    final_preimpoundment_depth = Str

    #: an event fired when the pre-impoundment depths are updated
    preimpoundment_depths_updated = Event

    #: power values for entire trace set
    power = Array

    #: gain values for entire trace set
    gain = Array

    #: Depth corrections:
    #:  depth = (pixel_number_from_top * pixel_resolution) + draft - heave
    #: distance from sensor to water. Constant offset added to depth
    draft = CFloat

    #: array of depth corrections.  Changes vertical offset of each column.
    heave = Array

    #: pixel resolution, depth/pixel
    pixel_resolution = CFloat

    # XXX probably other metadata should be here
    #: if some check results in a bad survey line then some text stating why
    #: or where the check failed is put here.
    bad_survey_line = Str('')

    def load_data(self, hdf5_file):
        ''' Called by UI to load this survey line when selected to edit.

        Reads the per-frequency intensity data and the unseparated SDI
        arrays for ``self.name`` from ``hdf5_file`` and fills the traits.
        It then checks and repairs the array sizes with ``array_sizes_ok``.
        The 'depth_r1' depth from the SDI data is written back to the file
        as the 'current_surface_from_bin' depth line. Finally the current
        and pre-impoundment pick lines are re-read from the file.
        '''
        from ..io import survey_io

        # read the unseparated SDI data and the per-frequency data.
        sdi_dict_raw = survey_io.read_sdi_data_unseparated_from_hdf(hdf5_file,
                                                                    self.name)
        freq_dict_list = survey_io.read_frequency_data_from_hdf(hdf5_file,
                                                                self.name)

        # fill frequencies and freq_trace_num dictionaries with freqs as keys.
        for freq_dict in freq_dict_list:
            key = str(freq_dict['kHz'])
            # transpose array to go into image plot correctly oriented
            self.frequencies[key] = freq_dict['intensity'].T
            self.freq_trace_num[key] = freq_dict['trace_num']

        # for all other traits, use un-freq-sorted values
        self.trace_num = sdi_dict_raw['trace_num']
        self.locations = np.vstack([sdi_dict_raw['interpolated_easting'],
                                   sdi_dict_raw['interpolated_northing']]).T
        self.lat_long = np.vstack([sdi_dict_raw['latitude'],
                                  sdi_dict_raw['longitude']]).T
        # draft and pixel resolution are reduced to a single scalar each
        self.draft = np.mean(sdi_dict_raw['draft'])
        self.heave = sdi_dict_raw['heave']
        self.pixel_resolution = np.mean(sdi_dict_raw['pixel_resolution'])
        self.power = sdi_dict_raw['power']
        self.gain = sdi_dict_raw['gain']
        # may repair trace_num, so it must run before trace_num is used below
        self.array_sizes_ok()
        filename = os.path.basename(sdi_dict_raw['filepath'])
        sdi_surface = DepthLine(
            name='current_surface_from_bin',
            survey_line_name=self.name,
            line_type='current surface',
            source='sdi_file',
            source_name=filename,
            index_array=self.trace_num - 1,  # trace numbers are 1-based
            depth_array=sdi_dict_raw['depth_r1']
        )
        survey_io.write_depth_line_to_hdf(hdf5_file, sdi_surface, self.name)
        # depth lines stored separately
        self.lake_depths = survey_io.read_pick_lines_from_hdf(
                                     hdf5_file, self.name, 'current')
        self.preimpoundment_depths = survey_io.read_pick_lines_from_hdf(
                                     hdf5_file, self.name, 'preimpoundment')

    def nearby_core_samples(self, core_samples, dist_tol=100):
        """ Find core samples from a list of CoreSample instances
        that lie within dist_tol units of this survey line.

        The distance is measured from each core's ``location`` to
        ``self.navigation_line``. The result is a list of the cores strictly
        closer than ``dist_tol``, in input order.
        """
        from shapely.geometry import Point
        line = self.navigation_line
        return [core for core in core_samples
                if line.distance(Point(core.location)) < dist_tol]

    def array_sizes_ok(self):
        ''' Check that the arrays for this line are mutually consistent.

        ``trace_num`` is first checked and repaired with
        ``survey_io.check_trace_num_array`` and ``fix_trace_num_arrays``,
        which also return an updated ``freq_trace_num``. Every non-separated
        array must then have as many rows as the repaired ``trace_num``.
        On a mismatch a warning is logged and ``bad_survey_line`` is set.
        '''
        name = self.name
        logger.info('Checking all array integrity for line {}'.format(name))
        arrays = ['trace_num', 'locations', 'lat_long', 'heave', 'power',
                  'gain']

        from ..io import survey_io
        bad_indices, bad_vals = survey_io.check_trace_num_array(self.trace_num,
                                                                name)
        tn, fn = survey_io.fix_trace_num_arrays(self.trace_num,
                                                bad_indices,
                                                self.freq_trace_num)
        self.trace_num = tn
        self.freq_trace_num = fn
        N = len(tn)
        # now check rest of arrays
        for a in arrays:
            if getattr(self, a).shape[0] != N:
                logger.warning('{} is not size {}'.format(a, N))
                self.bad_survey_line = "Array sizes don't match on load"

    def fix_trace_num(self, N, bad_traces, values):
        ''' Legacy repair of non-contiguous trace numbers.

        For every bad trace number ``t`` whose bad value ``v`` occurs in a
        frequency's trace array, ``t`` is written at index
        ``floor((t - 1) / 3)`` of that array. Afterwards ``trace_num`` is
        reset to ``1..N``. The divide-by-3 presumably assumes three
        interleaved frequencies -- TODO confirm.
        '''
        for freq, trace_array in self.freq_trace_num.items():
            for t, v in zip(bad_traces, values):
                if v in trace_array:
                    # numpy rejects float indices, so convert explicitly
                    i = int(np.floor((t - 1) / 3.0))
                    logger.debug('freq trace i value is %s %s %s %s',
                                 freq, i, trace_array[i], i + 1)
                    trace_array[i] = t
            self.freq_trace_num[freq] = trace_array
        for f, v in self.freq_trace_num.items():
            logger.debug('max is %s %s %s', f, v.max(), v.shape)
        self.trace_num = np.arange(1, N + 1)
Example #3
0
class FETS2D4Q8U(FETSEval):
    '''
    Eight-node serendipity quadrilateral (Q8) with two displacement dofs
    per node.

    The field nodes in the parent domain are listed in ``dof_r``: corners
    0-3, followed by the mid-side nodes 4-7. The geometry uses the
    bilinear four-node map given by ``_node_coord_map_geo``.
    '''
    debug_on = True

    mats_eval = Instance(MATSEval)

    # Dimensional mapping
    dim_slice = slice(0, 2)

    n_e_dofs = Int(2 * 8)
    t = Float(1.0, label='thickness')

    # Integration parameters
    #
    ngp_r = 3
    ngp_s = 3

    dof_r = Array(value=[[-1., -1.], [1., -1.], [1., 1.], [-1., 1.], [0., -1.],
                         [1., 0.], [0., 1.], [-1., 0.]])
    geo_r = Array(value=[[-1, -1], [1, -1], [1, 1], [-1, 1]])
    #
    vtk_r = Array(value=[[-1., -1.], [0., -1.], [1., -1.], [-1., 0.], [1., 0.],
                         [-1., 1.], [0., 1.], [1., 1.]])
    vtk_cells = [[0, 2, 7, 5, 1, 4, 6, 3]]
    vtk_cell_types = 'QuadraticQuad'

    n_nodal_dofs = 2

    # Ordering of the nodes of the parent element used for the geometry
    # approximation
    _node_coord_map_geo = Array(Float, (4, 2),
                                [[-1., -1.], [1., -1.], [1., 1.], [-1., 1.]])

    #---------------------------------------------------------------------
    # Method required to represent the element geometry
    #---------------------------------------------------------------------
    def get_N_geo_mtx(self, r_pnt):
        '''
        Return the bilinear geometry shape functions evaluated at the local
        coordinate r_pnt as a (1, 4) array.
        '''
        cx = self._node_coord_map_geo
        N_geo = 0.25 * (1 + r_pnt[0] * cx[:, 0]) * (1 + r_pnt[1] * cx[:, 1])
        return N_geo[np.newaxis, :]

    def get_dNr_geo_mtx(self, r_pnt):
        '''
        Return the (2, 4) matrix of geometry shape function derivatives with
        respect to the local coordinates. Row 0 holds d/dr and row 1 holds
        d/ds. It is used for the construction of the Jacobi matrix.
        '''
        cx = self._node_coord_map_geo
        return np.array([
            0.25 * cx[:, 0] * (1 + r_pnt[1] * cx[:, 1]),
            0.25 * cx[:, 1] * (1 + r_pnt[0] * cx[:, 0]),
        ])

    #-------------------------------------------------------------------------
    # Method delivering the shape functions for the field variables and their derivatives
    #-------------------------------------------------------------------------
    def get_N_mtx(self, r_pnt):
        '''
        Return the (n_nodal_dofs, 8 * n_nodal_dofs) matrix of the Q8 field
        shape functions at the local coordinate r_pnt. Each nodal value
        N_i multiplies an identity block, so the matrix contains the zero
        entries coupling the dofs.

        Corner node i at (xi_i, eta_i):
            N_i = (1 + xi xi_i)(1 + eta eta_i)(xi xi_i + eta eta_i - 1) / 4
        Mid-side nodes:
            N_i = (1 - xi^2)(1 + eta eta_i) / 2    (xi_i = 0)
            N_i = (1 + xi xi_i)(1 - eta^2) / 2     (eta_i = 0)
        '''
        xi, eta = r_pnt[0], r_pnt[1]
        N_dof = np.array([[
            (1 - xi) * (1 - eta) * (-xi - eta - 1) / 4.0,
            (1 + xi) * (1 - eta) * (xi - eta - 1) / 4.0,
            (1 + xi) * (1 + eta) * (xi + eta - 1) / 4.0,
            (1 - xi) * (1 + eta) * (-xi + eta - 1) / 4.0,
            (1 - xi * xi) * (1 - eta) / 2.0,
            (1 + xi) * (1 - eta * eta) / 2.0,
            (1 - xi * xi) * (1 + eta) / 2.0,
            (1 - xi) * (1 - eta * eta) / 2.0,
        ]], dtype=float)

        I_mtx = np.identity(self.n_nodal_dofs, float)
        # [N_0 * I, N_1 * I, ...] laid out side by side
        return np.kron(N_dof, I_mtx)

    def get_dNr_mtx(self, r_pnt):
        '''
        Return the (2, 8) matrix of derivatives of the Q8 field shape
        functions with respect to the local coordinates. Row 0 holds
        d/dxi and row 1 holds d/deta.
        '''
        xi, eta = r_pnt[0], r_pnt[1]
        return np.array([
            [  # d/dxi
                (1 - eta) * (2 * xi + eta) / 4.0,
                (1 - eta) * (2 * xi - eta) / 4.0,
                (1 + eta) * (2 * xi + eta) / 4.0,
                (1 + eta) * (2 * xi - eta) / 4.0,
                -xi * (1 - eta),
                (1 - eta * eta) / 2.0,
                -xi * (1 + eta),
                -(1 - eta * eta) / 2.0,
            ],
            [  # d/deta
                (1 - xi) * (xi + 2 * eta) / 4.0,
                (1 + xi) * (2 * eta - xi) / 4.0,
                (1 + xi) * (xi + 2 * eta) / 4.0,
                (1 - xi) * (2 * eta - xi) / 4.0,
                -(1 - xi * xi) / 2.0,
                -eta * (1 + xi),
                (1 - xi * xi) / 2.0,
                -eta * (1 - xi),
            ],
        ], dtype=float)

    def get_B_mtx(self, r_pnt, X_mtx):
        '''
        Return the (3, 2 * n_nodes) strain-displacement matrix at r_pnt.

        Rows correspond to eps_xx, eps_yy and gamma_xy. Columns follow the
        dof ordering [u_0, v_0, u_1, v_1, ...]. This assumes that
        inv(J) * dNr yields the global derivatives, i.e. that get_J_mtx
        returns J = dNr_geo . X -- TODO confirm against FETSEval.
        '''
        J_mtx = self.get_J_mtx(r_pnt, X_mtx)
        dNr_mtx = self.get_dNr_mtx(r_pnt)
        dNx_mtx = np.dot(inv(J_mtx), dNr_mtx)
        Bx_mtx = np.zeros((3, 2 * dNx_mtx.shape[1]), dtype=float)
        Bx_mtx[0, 0::2] = dNx_mtx[0]
        Bx_mtx[1, 1::2] = dNx_mtx[1]
        Bx_mtx[2, 0::2] = dNx_mtx[1]
        Bx_mtx[2, 1::2] = dNx_mtx[0]
        return Bx_mtx
Example #4
0
class GridCell( SDomain ):
    '''
    A single mgrid cell for geometrical representation of the domain.

    Based on the grid_cell_spec attribute,
    the node distribution is determined.

    '''
    # Everything depends on the grid_cell_specification
    #
    grid_cell_spec = Instance( CellSpec )
    def _grid_cell_spec_default( self ):
        return CellSpec()

    # Generated grid cell coordinates as they come from mgrid.
    # The dimensionality of the mgrid comes from the
    # grid_cell_spec_attribute
    #
    grid_cell_coords = Property( depends_on = 'grid_cell_spec' )
    @cached_property
    def _get_grid_cell_coords( self ):
        '''Return one row of coordinates per mgrid point.'''
        grid_cell = mgrid[ self.grid_cell_spec.get_cell_slices() ]
        return c_[ tuple( [ x.flatten() for x in grid_cell ] ) ]

    n_nodes = Property( depends_on = 'grid_cell_spec' )
    @cached_property
    def _get_n_nodes( self ):
        '''Return the number of all nodes within the cell.
        '''
        return self.grid_cell_coords.shape[0]

    # Node map lists the active nodes within the grid cell
    # in the specified order
    #
    node_map = Property( Array( int ), depends_on = 'grid_cell_spec' )
    @cached_property
    def _get_node_map( self ):
        '''Return, for each node of the cell spec, the index of the first
        grid point that matches it within an absolute tolerance of 1e-3.
        '''
        n_map = []
        for node in self.grid_cell_spec._node_array:
            for idx, grid_cell_node in enumerate( self.grid_cell_coords ):
                if allclose( node , grid_cell_node , atol = 1.0e-3 ):
                    n_map.append( idx )
                    # one match per node: stop searching so near-duplicate
                    # grid points cannot add extra entries
                    break
        return array( n_map, int )

    # #-----------------------------------------------------------------
    # # Visualization related methods
    # #-----------------------------------------------------------------
    # mvp_mgrid_ngeo_labels = Trait( MVPointLabels )
    # def _mvp_mgrid_ngeo_labels_default( self ):
    #     return MVPointLabels( name = 'Geo node numbers',
    #                               points = self._get_points,
    #                               scalars = self._get_node_distribution )

    refresh_button = Button( 'Draw' )
    @on_trait_change( 'refresh_button' )
    def redraw( self ):
        '''
        Redraw the geo node number labels.

        The label object ``mvp_mgrid_ngeo_labels`` is currently disabled
        (its definition above is commented out). The redraw is skipped when
        the attribute does not exist, instead of raising AttributeError.
        '''
        labels = getattr( self, 'mvp_mgrid_ngeo_labels', None )
        if labels is not None:
            labels.redraw( 'label_scalars' )

    def _get_points( self ):
        '''Return the coordinates of the mapped nodes padded to 3D.'''
        points = self.grid_cell_coords[ ix_( self.node_map ) ]
        shape = points.shape
        if shape[1] < 3:
            _points = zeros( ( shape[0], 3 ), dtype = float )
            _points[:, 0:shape[1]] = points
            return _points
        else:
            return points

    def _get_node_distribution( self ):
        '''Return per grid point its position in node_map, or -1 if unused.'''
        n_points = self.grid_cell_coords.shape[0]
        full_node_map = ones( n_points, dtype = float ) * -1.
        full_node_map[ ix_( self.node_map ) ] = arange( len( self.node_map ) )
        return full_node_map

    def __getitem__( self, idx ):
        '''Return a boolean mask over node_map selecting the nodes hit by
        the slice ``idx`` applied to the grid cell shape.
        '''
        # construct the full boolean map of the grid cell
        node_bool_map = repeat( False, self.n_nodes ).reshape( self.grid_cell_spec.get_cell_shape() )
        # put true at the sliced positions
        node_bool_map[idx] = True
        # extract the used nodes using the node map
        node_selection = node_bool_map.flatten()[ self.node_map ]
        return node_selection

    #------------------------------------------------------------------
    # UI - related methods
    #------------------------------------------------------------------
    traits_view = View( Item( 'grid_cell_spec' ),
                       Item( 'refresh_button' ),
                       Item( 'node_map' ),
                       resizable = True,
                       height = 0.5,
                       width = 0.5 )
Example #5
0
class Mayavi3DScene(Editor):  # pylint: disable=too-many-instance-attributes
    """
    A Pyface Tasks Editor for holding a Mayavi scene
    """
    #: The model object to view. If not specified, the editor is used instead.
    model = Instance(HasTraits)

    #: The UI object associated with the Traits view, if it has been
    #: constructed.
    ui = Instance("traitsui.ui.UI")

    #: The editor's user-visible name.
    name = Str('3D View')

    #: Configuration parser.
    configuration = Instance(ConfigParser)

    #: Current participant ID.
    participant_id = Str()

    #: The :py:class:`EMFields` instance containing the field data.
    fields_model = Instance(EMFields)

    #: Normal vector of the cut plane
    normal = Array()

    #: Origin point of the cut plane
    origin = Array()

    #: The :py:class:`mayavi.core.ui.api.MlabSceneModel` instance
    #: containing the 3D plot.
    scene = Instance(MlabSceneModel, ())

    #: The mayavi pipeline object containing the cut plane.
    data_set_clipper = Instance(DataSetClipper)

    #: The list of points describing the line for the line figure.
    points = List(ArrayClass,
                  value=[
                      ArrayClass(value=np.array([0, 0, -1])),
                      ArrayClass(value=np.array([0, 0, 1]))
                  ])

    #: The 3D surface object for the line.
    line = Instance(Surface)

    #: The field data source.
    src = Instance(ArraySource)

    #: The field cut plane.
    cut = Instance(CutPlane)

    #: The 3D surface for the field data.
    surf = Instance(Surface)

    #: The path to the spinal cord model file.
    csf_model = File()

    #: The mayavi file reader object to read the spinal cord model file.
    csf_model_reader = Any()

    #: The 3D surface object for the cut spinal cord model.
    csf_surface = Instance(Surface)

    #: Show the full spinal cord model?
    show_full_model = Bool()

    #: The 3D surface object for the full spinal cord model.
    full_csf_surface = Instance(Surface)

    #: Use a logarithmic scale for the field data?
    log_scale = Bool()

    def default_traits_view(self):  # pylint: disable=no-self-use
        """
        Create the default traits View object for the model

        Returns
        -------
        default_traits_view : :py:class:`traitsui.view.View`
            The default traits View object for the model
        """
        return View(
            Item('scene',
                 show_label=False,
                 editor=SceneEditor(scene_class=MayaviScene)))

    def create(self, parent):
        """
        Create and set the widget(s) for the Editor.

        Parameters
        ----------
        parent : toolkit-specific widget
            The parent widget for the Editor
        """
        self.ui = self.edit_traits(kind='subpanel', parent=parent)  # pylint: disable=invalid-name
        self.control = self.ui.control  # pylint: disable=attribute-defined-outside-init

    def destroy(self):
        """
        Destroy the Editor and clean up after
        """
        self.control = None  # pylint: disable=attribute-defined-outside-init
        if self.ui is not None:
            self.ui.dispose()
        self.ui = None

    def _apply_lut_scale(self, use_log):
        """
        Set the field surface colour map to a log10 (True) or linear (False)
        scale.
        """
        lut = self.surf.parent.scalar_lut_manager.lut
        lut.scale = 'log10' if use_log else 'linear'

    @observe('log_scale', post_init=True)
    def toggle_log_scale(self, event):
        """
        Toggle between using a logarithmic scale and a linear scale

        Parameters
        ----------
        event : A :py:class:`traits.observation.events.TraitChangeEvent` instance
            The trait change event for log_scale
        """
        self._apply_lut_scale(event.new)
        self.scene.mlab.draw()

    @observe('origin', post_init=True)
    def update_origin(self, event):
        """
        Update objects when the cut plane origin is changed.

        Parameters
        ----------
        event : A :py:class:`traits.observation.events.TraitChangeEvent` instance
            The trait change event for origin
        """
        # the widgets only exist once create_plot has built the pipeline
        if hasattr(self.data_set_clipper, 'widget'):
            self.data_set_clipper.widget.widget.origin = event.new
            self.cut.filters[0].widget.origin = event.new
        self.scene.mlab.draw()

    @observe('normal', post_init=True)
    def update_normal(self, event):
        """
        Update objects when the cut plane normal is changed.

        Parameters
        ----------
        event : A :py:class:`traits.observation.events.TraitChangeEvent` instance
            The trait change event for normal
        """
        # the widgets only exist once create_plot has built the pipeline
        if hasattr(self.data_set_clipper, 'widget'):
            self.data_set_clipper.widget.widget.normal = event.new
            self.cut.filters[0].widget.normal = event.new
        self.scene.mlab.draw()

    @observe('show_full_model', post_init=True)
    def toggle_full_model(self, event):
        """
        Toggle between showing the full spinal cord model and showing only below the cut plane.

        Parameters
        ----------
        event : A :py:class:`traits.observation.events.TraitChangeEvent` instance
            The trait change event for show_full_model.
        """
        self.csf_surface.visible = not event.new
        self.full_csf_surface.visible = event.new

        self.scene.mlab.draw()

    def reset_participant_defaults(self):
        """
        Reset the per-participant display options (model file, full-model
        toggle, log scale, cut plane normal and origin) to their default
        values via ``reset_traits``.
        """
        self.reset_traits(traits=[
            'csf_model', 'show_full_model', 'log_scale', 'normal', 'origin'
        ])

    @observe('csf_model', post_init=True)
    def change_cord_model(self, event):
        """
        Change the spinal cord model file used for the 3D display.

        Parameters
        ----------
        event : A :py:class:`traits.observation.events.TraitChangeEvent` instance
            The trait change event for csf_model
        """
        if self.csf_model_reader is not None:
            self.csf_model_reader.initialize(event.new)

    @observe('scene.activated')
    def initialize_camera(self, event=None):  # pylint: disable=unused-argument
        """
        Set the camera for the Mayavi scene to a pre-determined perspective.

        Parameters
        ----------
        event : A :py:class:`traits.observation.events.TraitChangeEvent` instance
            The trait change event for scene.activated
        """
        if self.csf_surface is not None:
            self.scene.engine.current_object = self.csf_surface
        self.scene.mlab.view(azimuth=-35, elevation=75)

        self.scene.mlab.draw()

    def create_plot(self):
        """
        Create the 3D objects to be shown.

        The cut plane origin is placed at the maximum of the masked field
        data. The spinal cord model is shown both clipped (below the plane)
        and in full (initially hidden). The field data is displayed on the
        cut plane.
        """
        normal = self.normal

        max_ind = np.unravel_index(
            np.nanargmax(self.fields_model.masked_grid_data),
            self.fields_model.masked_grid_data.shape)

        self.origin = np.array([
            self.fields_model.masked_gr_x[max_ind],
            self.fields_model.masked_gr_y[max_ind],
            self.fields_model.masked_gr_z[max_ind]
        ])

        # clipped spinal cord model
        self.csf_model_reader = self.scene.engine.open(self.csf_model)
        self.csf_surface = Surface()

        self.data_set_clipper = DataSetClipper()

        self.scene.engine.add_filter(self.data_set_clipper,
                                     self.csf_model_reader)

        self.data_set_clipper.widget.widget_mode = 'ImplicitPlane'
        self.data_set_clipper.widget.widget.normal = normal
        self.data_set_clipper.widget.widget.origin = self.origin
        self.data_set_clipper.widget.widget.enabled = False
        self.data_set_clipper.widget.widget.key_press_activation = False
        self.data_set_clipper.filter.inside_out = True
        self.csf_surface.actor.property.opacity = 0.3
        self.csf_surface.actor.property.specular_color = (0.0, 0.0, 1.0)
        self.csf_surface.actor.property.specular = 1.0
        self.csf_surface.actor.actor.use_bounds = False

        self.scene.engine.add_filter(self.csf_surface, self.data_set_clipper)

        # full (unclipped) spinal cord model, hidden by default
        self.full_csf_surface = Surface()
        self.full_csf_surface.actor.property.opacity = 0.3
        self.full_csf_surface.actor.property.specular_color = (0.0, 0.0, 1.0)
        self.full_csf_surface.actor.property.specular = 1.0
        self.full_csf_surface.actor.actor.use_bounds = False
        self.full_csf_surface.visible = False

        self.scene.engine.add_filter(self.full_csf_surface,
                                     self.csf_model_reader)

        # field data on the cut plane
        self.src = self.scene.mlab.pipeline.scalar_field(
            self.fields_model.masked_gr_x, self.fields_model.masked_gr_y,
            self.fields_model.masked_gr_z, self.fields_model.masked_grid_data)
        self.cut = self.scene.mlab.pipeline.cut_plane(self.src)
        self.cut.filters[0].widget.normal = normal
        self.cut.filters[0].widget.origin = self.origin
        self.cut.filters[0].widget.enabled = False
        self.surf = self.scene.mlab.pipeline.surface(self.cut, colormap='jet')
        self.surf.actor.actor.use_bounds = False
        # fully transparent NaN colour
        self.surf.parent.scalar_lut_manager.lut.nan_color = np.array(
            [0, 0, 0, 0])

        self.scene.mlab.draw()

    @observe(ob.trait('points').list_items().trait(
        'value', optional=True).list_items(optional=True),
             post_init=True)
    def draw_line(self, event):
        """
        Create or update the line described by the points in :ref:`line-attributes`.

        Parameters
        ----------
        event : A :py:class:`traits.observation.events.TraitChangeEvent` instance
            The trait change event for points.
        """
        # A list change that would introduce None entries is reverted.
        if None in event.new and len(event.old) == len(
                event.new) and None not in event.old:
            self.points = event.old
            return
        points = np.array([
            val.value if val is not None else np.array([0, 0, 0])
            for val in self.points
        ])

        x_positions = [point[0] for point in points]
        y_positions = [point[1] for point in points]
        z_positions = [point[2] for point in points]

        if not hasattr(self.line, 'mlab_source'):
            self.line = self.scene.mlab.plot3d(x_positions,
                                               y_positions,
                                               z_positions,
                                               tube_radius=0.2,
                                               color=(1, 0, 0),
                                               figure=self.scene.mayavi_scene)
        else:
            self.line.mlab_source.reset(x=x_positions,
                                        y=y_positions,
                                        z=z_positions)

        self.scene.mlab.draw()

    def disable_widgets(self):
        """
        Disable widgets to be hidden and set up color properties.
        """
        if self.data_set_clipper.widget.widget.enabled:
            self.cut.filters[0].widget.enabled = False
            self.data_set_clipper.widget.widget.enabled = False
            self._apply_lut_scale(self.log_scale)
            lut_manager = self.surf.parent.scalar_lut_manager
            lut_manager.show_legend = True
            lut_manager.use_default_name = False
            lut_manager.data_name = 'J (A/m^2)'
            lut_manager.shadow = True
            lut_manager.use_default_range = False
            lut_manager.data_range = np.array([
                np.nanmin(self.fields_model.masked_grid_data),
                np.nanmax(self.fields_model.masked_grid_data)
            ])
            lut_manager.lut.nan_color = np.array([0, 0, 0, 0])

    def _csf_model_default(self):
        try:
            model = self._get_default_value('csf_model')
        except KeyError:
            model = os.path.join(os.getcwd(), 'CSF.vtk')
        return model

    def _boolean_option(self, option):
        """Return configuration ``option`` interpreted as a boolean using the
        parser's BOOLEAN_STATES table (case-insensitive)."""
        return self.configuration.BOOLEAN_STATES[self._get_default_value(
            option).lower()]

    def _show_full_model_default(self):
        return self._boolean_option('full_model')

    def _log_scale_default(self):
        return self._boolean_option('log_scale')

    def _normal_default(self):
        normal = self._get_default_value('normal')
        return np.fromstring(normal.strip('()'), sep=',')

    def _origin_default(self):
        origin = self._get_default_value('origin')
        return np.fromstring(origin.strip('()'), sep=',')

    def _get_default_value(self, option):
        """
        Look up a configuration value for the current participant.

        Parameters
        ----------
        option : str
            Name of the configuration option.

        Returns
        -------
        str
            The option value from the participant's section. A value missing
            from that section falls back to the parser's default section.
            When no participant ID is set, the default section is read
            directly.

        Raises
        ------
        KeyError
            If the option is found in neither section.
        """
        config = self.configuration
        section = self.participant_id or config.default_section
        # ConfigParser does not report the default section as "contained";
        # assigning {} to it would wipe all defaults, so never create it.
        if section != config.default_section and section not in config:
            config[section] = {}
        return config[section][option]
Example #6
0
class PulsedFit(HasTraits, GetSetItemsMixin):

    """
    Base class for a pulsed fit. Provides calculation of normalized intensity.
    Derive from this to create fits for pulsed measurements.
    """

    measurement = Instance(dr.PulsedDEER, factory=dr.PulsedDEER)

    pulse = Array(value=np.array((0., 0.)))
    flank = Float(value=0.0)
    x_tau = Array(value=np.array((0., 1.)))
    spin_state = Array(value=np.array((0., 0.)))
    spin_state_error = Array(value=np.array((0., 0.)))

    integration_width = Range(low=10., high=1000., value=200., desc='time window for pulse analysis [ns]', label='integr. width [ns]', mode='text', auto_set=False, enter_set=True)
    position_signal = Range(low= -100., high=1000., value=0., desc='position of signal window relative to edge [ns]', label='pos. signal [ns]', mode='text', auto_set=False, enter_set=True)
    position_normalize = Range(low=0., high=10000., value=2200., desc='position of normalization window relative to edge [ns]', label='pos. norm. [ns]', mode='text', auto_set=False, enter_set=True)

    def __init__(self):
        super(PulsedFit, self).__init__()
        # update_spin_state assigns spin_state before x_tau, so the plot also
        # listens to x_tau; otherwise a new x axis would never reach the plot.
        self.on_trait_change(self.update_plot_spin_state, 'spin_state,x_tau', dispatch='ui')

    @on_trait_change('measurement.count_data,integration_width,position_signal,position_normalize')
    def update_spin_state(self):
        """Recompute spin state, its error, pulse profile and flank position
        from the measurement's count data and the current analysis windows."""
        if self.measurement is None:
            return
        # spin_state() returns (values, profile, flank index into time_bins).
        y, profile, flank = spin_state(c=self.measurement.count_data,
                                       dt=self.measurement.bin_width,
                                       T=self.integration_width,
                                       t0=self.position_signal,
                                       t1=self.position_normalize,)
        self.spin_state = y
        # Error estimate taken as the square root of the value.
        self.spin_state_error = y ** 0.5
        self.pulse = profile
        self.flank = self.measurement.time_bins[flank]
        self.x_tau = self.measurement.tau

    # The following is a bit tricky. Data for plots is passed to the PulsedAnalyzer through these two attributes.
    # The first one is a standard chaco ArrayPlotData instance. It is very important to specify a proper initial instance,
    # otherwise the PulsedAnalyzer will not start. The line below specifies an initial instance through the factory and kw argument.
    line_data = Instance(ArrayPlotData,
                         factory=ArrayPlotData,
                         kw={'pulse_number':np.array((0, 1)),
                             'spin_state':np.array((0, 0)),
                             }
                         )
    # The second line is a list that is interpreted to create appropriate plots through the chaco Plot.plot() method.
    # The list elements are dictionaries that are directly passed to the Plot.plot() command as keyword arguments through the **kwargs expansion
    plots = [ {'data':('pulse_number', 'spin_state'), 'color':'blue', 'name':'pulsed'} ]

    def update_plot_spin_state(self):
        """Push spin_state (y) and x_tau (x) into line_data for plotting.

        If x_tau does not have one entry per spin_state value (possible
        because update_spin_state assigns spin_state before x_tau), the
        point index 0..n-1 is used as the x axis instead. This keeps the x
        and y arrays handed to the plot the same length; the x_tau listener
        refreshes the plot once x_tau catches up.
        """
        n_points = len(self.spin_state)
        x_data = self.x_tau
        if x_data is None or len(x_data) != n_points:
            x_data = np.arange(n_points)
        current_x = self.line_data.get_data('pulse_number')
        if current_x is not None and len(current_x) != n_points:
            # Resize the x axis before the y data changes length.
            self.line_data.set_data('pulse_number', np.arange(n_points))
        self.line_data.set_data('spin_state', self.spin_state)
        self.line_data.set_data('pulse_number', x_data)

    traits_view = View(HGroup(Item('integration_width'),
                              Item('position_signal'),
                              Item('position_normalize'),
                              ),
                       title='Pulsed Fit',
                       )

    get_set_items = ['__doc__', 'integration_width', 'position_signal', 'position_normalize', 'pulse', 'flank', 'spin_state', 'spin_state_error']
Example #7
0
class LUTManager(Base):
    """Manage a lookup table (LUT) and its scalar-bar legend.

    The LUT (``lut``) and the scalar bar actor (``scalar_bar``, shown via
    ``scalar_bar_widget``) are kept in sync with the traits below through
    the ``_*_changed`` handlers: colour map (``lut_mode``, ``file_name``,
    ``reverse_lut``), number of colours and labels, legend title, text
    shadow and data range.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The lookup table.
    lut = Instance(tvtk.LookupTable, (), record=False)
    # The scalar bar.
    scalar_bar = Instance(tvtk.ScalarBarActor, (), record=True)
    # The scalar_bar_widget
    scalar_bar_widget = Instance(tvtk.ScalarBarWidget, ())

    # The representation associated with the scalar_bar_widget.  This
    # only exists in newer VTK versions (roughly 5.2 onwards), hence the
    # hasattr check in _scalar_bar_representation_default.
    scalar_bar_representation = Instance(tvtk.Object,
                                         allow_none=True,
                                         record=True)

    # The title text property of the scalar bar.
    title_text_property = Property(record=True)

    # The label text property of the scalar bar.
    label_text_property = Property(record=True)

    # The current mode of the LUT.
    lut_mode = Enum('blue-red',
                    lut_mode_list(),
                    desc='the type of the lookup table')

    # File name of the LUT file to use.
    file_name = Str('',
                    editor=FileEditor,
                    desc='the filename containing the LUT')

    # Reverse the colors of the LUT.
    reverse_lut = Bool(False, desc='if the lut is to be reversed')

    # Turn on/off the visibility of the scalar bar.
    show_scalar_bar = Bool(False, desc='if scalar bar is shown or not')

    # This is an alias for show_scalar_bar.
    show_legend = Property(Bool, desc='if legend is shown or not')

    # The number of labels to use for the scalar bar.
    number_of_labels = Range(0,
                             64,
                             8,
                             enter_set=True,
                             auto_set=False,
                             desc='the number of labels to display')

    # Number of colors for the LUT.
    number_of_colors = Range(2,
                             2147483647,
                             256,
                             enter_set=True,
                             auto_set=False,
                             desc='the number of colors for the LUT')

    # Enable shadowing of the labels and text.
    shadow = Bool(False, desc='if the labels and text have shadows')

    # Use the default data name or the user specified one.
    use_default_name = Bool(True,
                            desc='if the default data name is to be used')

    # The default data name -- set by the module manager.
    default_data_name = Str('data',
                            enter_set=True,
                            auto_set=False,
                            desc='the default data name')

    # The optionally user specified name of the data.
    data_name = Str('',
                    enter_set=True,
                    auto_set=False,
                    desc='the title of the legend')

    # Use the default range or user specified one.
    use_default_range = Bool(True,
                             desc='if the default data range is to be used')
    # The default data range -- this is computed and set by the
    # module manager.
    default_data_range = Array(shape=(2, ),
                               value=[0.0, 1.0],
                               dtype=float,
                               enter_set=True,
                               auto_set=False,
                               desc='the default range of the data mapped')

    # The optionally user defined range of the data.
    data_range = Array(shape=(2, ),
                       value=[0.0, 1.0],
                       dtype=float,
                       enter_set=True,
                       auto_set=False,
                       desc='the range of the data mapped')

    # Create a new LUT.
    create_lut = Button('Launch LUT editor',
                        desc='if we launch a Lookup table editor in'
                        ' a separate process')

    ########################################
    ## Private traits.
    # The original range of the data.
    _orig_data_range = Array(shape=(2, ), value=[0.0, 1.0], dtype=float)
    _title_text_property = Instance(tvtk.TextProperty)
    _label_text_property = Instance(tvtk.TextProperty)

    ######################################################################
    # `object` interface
    ######################################################################
    def __init__(self, **traits):
        super(LUTManager, self).__init__(**traits)

        # Initialize the scalar bar.
        sc_bar = self.scalar_bar
        sc_bar.set(lookup_table=self.lut,
                   title=self.data_name,
                   number_of_labels=self.number_of_labels,
                   orientation='horizontal',
                   width=0.8,
                   height=0.17)
        pc = sc_bar.position_coordinate
        pc.set(coordinate_system='normalized_viewport', value=(0.1, 0.01, 0.0))
        self._shadow_changed(self.shadow)

        # Initialize the lut.
        self._lut_mode_changed(self.lut_mode)

        # Set the private traits.
        ttp = self._title_text_property = sc_bar.title_text_property
        ltp = self._label_text_property = sc_bar.label_text_property

        # Call render when the text properties are changed.
        ttp.on_trait_change(self.render)
        ltp.on_trait_change(self.render)

        # Initialize the scalar_bar_widget
        self.scalar_bar_widget.set(scalar_bar_actor=self.scalar_bar,
                                   key_press_activation=False)
        self._number_of_colors_changed(self.number_of_colors)

    ######################################################################
    # `Base` interface
    ######################################################################
    def start(self):
        """This is invoked when this object is added to the mayavi
        pipeline.
        """
        # Do nothing if we are already running.
        if self.running:
            return

        # Show the legend if necessary.
        self._show_scalar_bar_changed(self.show_scalar_bar)

        # Call parent method to set the running state.
        super(LUTManager, self).start()

    def stop(self):
        """Invoked when this object is removed from the mayavi
        pipeline.
        """
        if not self.running:
            return

        # Hide the scalar bar.
        sbw = self.scalar_bar_widget
        if sbw.interactor is not None:
            sbw.off()

        # Call parent method to set the running state.
        super(LUTManager, self).stop()

    ######################################################################
    # Non-public interface
    ######################################################################
    def _lut_mode_changed(self, value):
        """Rebuild the LUT for colour-map mode `value`.

        'file' loads from file_name, modes found in pylab_luts are loaded
        from their colour table (down-sampled to number_of_colors), and
        'blue-red' / 'black-white' are built from HSV ranges.

        Raises:
            ValueError: for a mode none of the branches handles.
        """
        if value == 'file':
            if self.file_name:
                self.load_lut_from_file(self.file_name)
            return

        reverse = self.reverse_lut
        if value in pylab_luts:
            lut = pylab_luts[value]
            if reverse:
                lut = lut[::-1, :]
            n_total = len(lut)
            n_color = self.number_of_colors
            if not n_color >= n_total:
                # Take every k-th table entry to approximate n_color colours.
                lut = lut[::int(round(n_total / float(n_color)))]
            self.load_lut_from_list(lut.tolist())
            return
        elif value == 'blue-red':
            if reverse:
                hue_range = 0.0, 0.6667
                saturation_range = 1.0, 1.0
                value_range = 1.0, 1.0
            else:
                hue_range = 0.6667, 0.0
                saturation_range = 1.0, 1.0
                value_range = 1.0, 1.0
        elif value == 'black-white':
            if reverse:
                hue_range = 0.0, 0.0
                saturation_range = 0.0, 0.0
                value_range = 1.0, 0.0
            else:
                hue_range = 0.0, 0.0
                saturation_range = 0.0, 0.0
                value_range = 0.0, 1.0
        else:
            # Previously fell through to an UnboundLocalError on hue_range.
            raise ValueError('unsupported LUT mode: %r' % (value, ))
        lut = self.lut
        lut.set(hue_range=hue_range,
                saturation_range=saturation_range,
                value_range=value_range,
                number_of_table_values=self.number_of_colors,
                ramp='sqrt')
        lut.modified()
        lut.force_build()

        self.render()

    def _scene_changed(self, value):
        sbw = self.scalar_bar_widget
        if value is None:
            return
        if sbw.interactor is not None:
            sbw.off()
        value.add_widgets(sbw, enabled=False)
        if self.show_scalar_bar:
            sbw.on()
        self._foreground_changed_for_scene(None, value.foreground)

    def _foreground_changed_for_scene(self, old, new):
        # Change the default color for the text.
        self.title_text_property.color = new
        self.label_text_property.color = new
        self.render()

    def _number_of_colors_changed(self, value):
        if self.lut_mode == 'file':
            return
        elif self.lut_mode in pylab_luts:
            # We can't interpolate these LUTs, as they are defined from a
            # table. We hack around this limitation
            reverse = self.reverse_lut
            lut = pylab_luts[self.lut_mode]
            if reverse:
                lut = lut[::-1, :]
            n_total = len(lut)
            if value > n_total:
                return
            lut = lut[::int(round(n_total / float(value)))]
            self.load_lut_from_list(lut.tolist())
        else:
            lut = self.lut
            lut.number_of_table_values = value
            lut.modified()
            lut.build()
            self.render()  # necessary to flush.
        sc_bar = self.scalar_bar
        sc_bar.maximum_number_of_colors = value
        sc_bar.modified()
        self.render()

    def _number_of_labels_changed(self, value):
        sc_bar = self.scalar_bar
        sc_bar.number_of_labels = value
        sc_bar.modified()
        self.render()

    def _file_name_changed(self, value):
        if self.lut_mode == 'file':
            self.load_lut_from_file(value)
        else:
            # This will automagically load the LUT from the file.
            self.lut_mode = 'file'

    def _reverse_lut_changed(self, value):
        # This will do the needful.
        self._lut_mode_changed(self.lut_mode)

    def _show_scalar_bar_changed(self, value):
        if self.scene is not None:
            # Without a title for scalar bar actor, vtkOpenGLTexture logs this:
            # Error: No scalar values found for texture input!
            if self.scalar_bar.title == '':
                self.scalar_bar.title = ' '
            self.scalar_bar_widget.enabled = value
            self.render()

    def _get_show_legend(self):
        return self.show_scalar_bar

    def _set_show_legend(self, value):
        old = self.show_scalar_bar
        if value != old:
            self.show_scalar_bar = value
            self.trait_property_changed('show_legend', old, value)

    def _shadow_changed(self, value):
        sc_bar = self.scalar_bar
        sc_bar.title_text_property.shadow = self.shadow
        sc_bar.label_text_property.shadow = self.shadow
        self.render()

    def _use_default_name_changed(self, value):
        self._default_data_name_changed(self.default_data_name)

    def _data_name_changed(self, value):
        sc_bar = self.scalar_bar
        sc_bar.title = value
        sc_bar.modified()
        self.render()

    def _default_data_name_changed(self, value):
        if self.use_default_name:
            self.data_name = value

    def _use_default_range_changed(self, value):
        self._default_data_range_changed(self.default_data_range)

    def _data_range_changed(self, value):
        # tvtk's LookupTable range setter signature differs between
        # versions, so try the accepted call forms in turn.
        try:
            self.lut.set_range(value[0], value[1])
        except TypeError:
            self.lut.set_range((value[0], value[1]))
        except AttributeError:
            self.lut.range = value
        self.scalar_bar.modified()
        self.render()

    def _default_data_range_changed(self, value):
        if self.use_default_range:
            self.data_range = value

    def _visible_changed(self, value):
        state = self.show_scalar_bar and value
        self._show_scalar_bar_changed(state)
        super(LUTManager, self)._visible_changed(value)

    def load_lut_from_file(self, file_name):
        """Load the LUT from `file_name` and re-render.

        An empty file name is ignored. If the file cannot be opened or
        parsed, an error is reported via ``error`` and the LUT is left
        unchanged. The list is reversed when reverse_lut is set.
        """
        if not file_name:
            return
        try:
            # Only checks that the file is readable; parsing happens below.
            with open(file_name, 'r'):
                pass
        except IOError:
            error("Cannot open Lookup Table file: %s\n" % file_name)
            return
        try:
            lut_list = parse_lut_file(file_name)
        except IOError as err:
            # str(err): the old code concatenated the exception object
            # itself, raising TypeError instead of reporting the problem.
            error("Sorry could not parse LUT file: %s\n%s" % (file_name, err))
            return
        if self.reverse_lut:
            lut_list.reverse()
        self.lut = set_lut(self.lut, lut_list)
        self.render()

    def load_lut_from_list(self, list):
        # NOTE: the parameter name shadows the builtin `list`; it is kept
        # for keyword-argument compatibility with existing callers.
        self.lut = set_lut(self.lut, list)
        self.render()

    def _get_title_text_property(self):
        return self._title_text_property

    def _get_label_text_property(self):
        return self._label_text_property

    def _create_lut_fired(self):
        # Launch tvtk's wx gradient editor script in a separate Python process.
        from tvtk import util
        script = os.path.join(os.path.dirname(util.__file__),
                              'wx_gradient_editor.py')
        subprocess.Popen([sys.executable, script])
        auto_close_message('Launching LUT editor in separate process ...')

    def _scalar_bar_representation_default(self):
        w = self.scalar_bar_widget
        if hasattr(w, 'representation'):
            r = w.representation
            r.on_trait_change(self.render)
            return r
        else:
            return None
Example #8
0
class FETS2D4Q(FETSEval):
    '''
    Four-node quadrilateral element with bilinear shape functions and
    two nodal dofs per node (8 element dofs).
    '''

    debug_on = True

    dots_class = DOTSGridEval

    # Dimensional mapping
    dim_slice = slice(0, 2)

    # Order of node positions for the formulation of shape function
    # (parent-element coordinates; same ordering for field and geometry).
    dof_r = Array(value=[[-1, -1], [1, -1], [1, 1], [-1, 1]])
    geo_r = Array(value=[[-1, -1], [1, -1], [1, 1], [-1, 1]])

    n_e_dofs = Int(8,
                   FE=True,
                   enter_set=False,
                   auto_set=True,
                   label='number of element dofs')
    t = Float(1.0, CS=True, enter_set=False, auto_set=True, label='thickness')

    # Integration parameters
    #
    ngp_r = Int(
        2,
        FE=True,
        enter_set=False,
        auto_set=True,
    )
    ngp_s = Int(
        2,
        FE=True,
        enter_set=False,
        auto_set=True,
    )

    # Corner nodes are used for visualization
    vtk_r = Array(value=[[-1., -1.], [1., -1.], [1., 1.], [-1., 1.]])
    vtk_cells = [[0, 1, 2, 3]]
    vtk_cell_types = 'Quad'

    n_nodal_dofs = Int(2)

    A_C = Property(depends_on='A_m,A_f')

    @cached_property
    def _get_A_C(self):
        # np.float64 instead of np.float_, which NumPy 2.0 removed.
        return np.array((1.0, 1.0, 1.0), dtype=np.float64)

    #---------------------------------------------------------------------
    # Method required to represent the element geometry
    #---------------------------------------------------------------------
    def get_N_geo_mtx(self, r_pnt):
        '''
        Return the value of shape functions for the specified local coordinate r
        as a (1, n_geo_nodes) array:
        N_i = 1/4 (1 + r * r_i) (1 + s * s_i).
        '''
        cx = np.array(self.geo_r, dtype=np.float64)
        Nr = 1 / 4. * (1 + r_pnt[0] * cx[:, 0]) * (1 + r_pnt[1] * cx[:, 1])
        return Nr[np.newaxis, :]

    def get_dNr_geo_mtx(self, r_pnt):
        '''
        Return the matrix of shape function derivatives.
        Used for the construction of the Jacobi matrix.

        Row 0 holds dN_i/dr, row 1 holds dN_i/ds (shape (2, n_geo_nodes)).

        @TODO - the B matrix is used
        just for uniaxial bar here with a trivial differential
        operator.
        '''
        cx = np.array(self.geo_r, dtype=np.float64)
        dNr_geo = np.array([
            1 / 4. * cx[:, 0] * (1 + r_pnt[1] * cx[:, 1]),
            1 / 4. * cx[:, 1] * (1 + r_pnt[0] * cx[:, 0]),
        ])
        return dNr_geo

    #---------------------------------------------------------------------
    # Method delivering the shape functions for the field variables and their derivatives
    #---------------------------------------------------------------------
    def get_N_mtx(self, r_pnt):
        '''
        Returns the matrix of the shape functions used for the field approximation
        containing zero entries. The number of rows corresponds to the number of nodal
        dofs. The matrix is evaluated for the specified local coordinate r.
        '''
        Nr_geo = self.get_N_geo_mtx(r_pnt)
        I_mtx = np.identity(self.n_nodal_dofs, float)
        N_mtx_list = [I_mtx * Nr_geo[0, i] for i in range(0, Nr_geo.shape[1])]
        N_mtx = np.hstack(N_mtx_list)
        return N_mtx

    def get_dNr_mtx(self, r_pnt):
        '''
        Return the derivatives of the shape functions
        '''
        return self.get_dNr_geo_mtx(r_pnt)

    def get_B_mtx(self, r_pnt, X_mtx):
        '''
        Return the (3, 2 * n_nodes) strain-displacement matrix at r_pnt.

        Rows: eps_xx, eps_yy, gamma_xy; columns alternate the x and y dof of
        each node. get_J_mtx is provided by the base class.
        '''
        J_mtx = self.get_J_mtx(r_pnt, X_mtx)
        dNr_mtx = self.get_dNr_mtx(r_pnt)
        dNx_mtx = np.dot(inv(J_mtx), dNr_mtx)
        n_nodes = dNx_mtx.shape[1]
        Bx_mtx = np.zeros((3, 2 * n_nodes), dtype=np.float64)
        Bx_mtx[0, 0::2] = dNx_mtx[0]
        Bx_mtx[1, 1::2] = dNx_mtx[1]
        Bx_mtx[2, 0::2] = dNx_mtx[1]
        Bx_mtx[2, 1::2] = dNx_mtx[0]
        return Bx_mtx
Example #9
0
class LatticeModelSim(Simulator):
    '''Simulator for a lattice model: combines time line, loading scenario,
    material model and lattice tessellation, and builds fixed/control
    boundary conditions from the given dof lists.
    '''

    node_name = 'pull out simulation'

    tree_node_list = List([])

    def _collect_tree_nodes(self):
        # Sub-objects shown as children of this node in the tree view.
        return [
            self.tline,
            self.loading_scenario,
            self.mats_eval,
            self.lattice_tessellation,
        ]

    def _tree_node_list_default(self):
        return self._collect_tree_nodes()

    def _update_node_list(self):
        self.tree_node_list = self._collect_tree_nodes()

    tree_view = View(
        Group(Item('mats_eval_type', resizable=True, full_size=True),
              Item('control_variable', resizable=True, full_size=True),
              Item('w_max', resizable=True, full_size=True),
              Group(Item('loading_scenario@', show_label=False), )))

    #=========================================================================
    # Test setup parameters
    #=========================================================================
    loading_scenario = Instance(LoadingScenario,
                                report=True,
                                desc='object defining the loading scenario')

    def _loading_scenario_default(self):
        return LoadingScenario()

    lattice_tessellation = Instance(LatticeTessellation,
                                    MESH=True,
                                    report=True,
                                    desc='cross section parameters')

    def _lattice_tessellation_default(self):
        return LatticeTessellation()

    control_variable = Enum('u', 'f', auto_set=False, enter_set=True, BC=True)

    #=========================================================================
    # Algorithmic parameters
    #=========================================================================
    # Raw strings: '\m' and '\e' are invalid escape sequences in normal
    # string literals (SyntaxWarning on Python 3.12+); values are unchanged.
    k_max = Int(400,
                unit='-',
                symbol=r'k_{\max}',
                desc='maximum number of iterations',
                ALG=True)

    tolerance = Float(1e-4,
                      unit='-',
                      symbol=r'\epsilon',
                      desc='required accuracy',
                      ALG=True)

    mats_eval_type = Trait('mats3d_ifc_elastic', {
        'mats3d_ifc_elastic': MATS3DIfcElastic,
        'mats3d_ifc_cumslip': MATS3DIfcCumSlip,
    },
                           MAT=True,
                           desc='material model type')

    @on_trait_change('mats_eval_type')
    def _set_mats_eval(self):
        # Replace the material model and refresh the tree children.
        self.mats_eval = self.mats_eval_type_()
        self._update_node_list()

    mats_eval = Instance(IMATSEval, report=True)
    '''Material model'''

    def _mats_eval_default(self):
        return self.mats_eval_type_()

    dots_lattice = Property(Instance(XDomainLattice), depends_on=itags_str)
    '''Discretization object.
    '''

    @cached_property
    def _get_dots_lattice(self):
        print('reconstruct DOTS')
        return XDomainLattice(mesh=self.lattice_tessellation)

    domains = Property(depends_on=itags_str)

    @cached_property
    def _get_domains(self):
        print('reconstruct DOMAIN')
        return [(self.dots_lattice, self.mats_eval)]

    #=========================================================================
    # Boundary conditions
    #=========================================================================
    w_max = Float(1,
                  BC=True,
                  symbol=r'w_{\max}',
                  unit='mm',
                  desc='maximum pullout slip',
                  auto_set=False,
                  enter_set=True)

    fixed_dofs = Array(np.int_, value=[], BC=True)

    fixed_bc_list = Property(depends_on=itags_str)
    r'''Fixed boundary condition'''

    @cached_property
    def _get_fixed_bc_list(self):
        # One zero-displacement condition per fixed dof.
        return [
            BCDof(node_name='fixed dof %d' % dof, var='u', dof=dof, value=0.0)
            for dof in self.fixed_dofs
        ]

    control_dofs = Array(np.int_, value=[], BC=True)

    control_bc_list = Property(depends_on=itags_str)
    r'''Control boundary condition - make it accessible directly
    for the visualization adapter as property
    '''

    @cached_property
    def _get_control_bc_list(self):
        # One condition per control dof, scaled by w_max and driven in time
        # by the loading scenario.
        return [
            BCDof(node_name='control dof %d' % dof,
                  var=self.control_variable,
                  dof=dof,
                  value=self.w_max,
                  time_function=self.loading_scenario)
            for dof in self.control_dofs
        ]

    bc = Property(depends_on=itags_str)

    @cached_property
    def _get_bc(self):
        return self.control_bc_list + self.fixed_bc_list

    def get_window(self):
        '''Register the 'Pw' and 'eps' records and return a BMCSWindow
        showing the force-displacement curve and the 3D lattice strain.
        '''
        self.record['Pw'] = LatticeRecord()
        self.record['eps'] = Vis3DLattice(var='eps')
        w = BMCSWindow(sim=self)
        fw = Viz2DLatticeFW(name='FW-curve', vis2d=self.hist['Pw'])
        w.viz_sheet.viz2d_list.append(fw)
        viz3d_u_Lb = Viz3DLattice(vis3d=self.hist['eps'])
        w.viz_sheet.add_viz3d(viz3d_u_Lb)
        w.viz_sheet.monitor_chunk_size = 10
        return w
Example #10
0
class CombineMarkersPanel(HasTraits):
    """Has two marker points sources and interpolates to a third one"""
    model = Instance(CombineMarkersModel, ())

    # model references for UI
    mrk1 = Instance(MarkerPointSource)
    mrk2 = Instance(MarkerPointSource)
    mrk3 = Instance(MarkerPointDest)
    distance = Str

    # Visualization
    scene = Instance(MlabSceneModel)
    scale = Float(5e-3)
    mrk1_obj = Instance(PointObject)
    mrk2_obj = Instance(PointObject)
    mrk3_obj = Instance(PointObject)
    trans = Array()

    view = View(VGroup(VGroup(Item('mrk1', style='custom'),
                              Item('mrk1_obj', style='custom'),
                              show_labels=False,
                              label="Source Marker 1", show_border=True),
                       VGroup(Item('mrk2', style='custom'),
                              Item('mrk2_obj', style='custom'),
                              show_labels=False,
                              label="Source Marker 2", show_border=True),
                       VGroup(Item('distance', style='readonly'),
                              label='Stats', show_border=True),
                       VGroup(Item('mrk3', style='custom'),
                              Item('mrk3_obj', style='custom'),
                              show_labels=False,
                              label="New Marker", show_border=True),
                       ))

    # The marker traits mirror the model's own marker objects.
    def _mrk1_default(self):
        return self.model.mrk1

    def _mrk2_default(self):
        return self.model.mrk2

    def _mrk3_default(self):
        return self.model.mrk3

    def __init__(self, *args, **kwargs):
        super(CombineMarkersPanel, self).__init__(*args, **kwargs)

        model = self.model
        # One-way: the model's distance string is shown read-only here.
        model.sync_trait('distance', self, 'distance', mutual=False)

        # (panel attribute, model marker, display colour) for each marker set.
        layout = (
            ('mrk1_obj', model.mrk1, (155, 55, 55)),
            ('mrk2_obj', model.mrk2, (55, 155, 55)),
            ('mrk3_obj', model.mrk3, (150, 200, 255)),
        )
        for attr_name, marker, rgb in layout:
            point_obj = PointObject(scene=self.scene, color=rgb,
                                    point_scale=self.scale)
            setattr(self, attr_name, point_obj)
            # Push the panel transform, marker points and enabled flag
            # one-way onto the point object.
            self.sync_trait('trans', point_obj, mutual=False)
            marker.sync_trait('points', point_obj, 'points', mutual=False)
            marker.sync_trait('enabled', point_obj, 'visible', mutual=False)
Example #11
0
class FETS2D4Q12U(FETSEval):
    '''
    Subparametric 2D quadrilateral element.

    The geometry is interpolated bilinearly from the 4 corner nodes
    (``geo_r``); the displacement field uses the 12-node cubic serendipity
    shape functions (nodes in ``dof_r``) with two dofs (u, v) per node,
    i.e. 24 element dofs.
    '''
    debug_on = True

    mats_eval = Instance(MATSEval)

    # Dimensional mapping
    dim_slice = slice(0, 2)

    n_e_dofs = Int(2 * 12)
    t = Float(1.0, label='thickness')

    # Integration parameters
    #
    ngp_r = 3
    ngp_s = 3

    # The order of the field approximation is higher than the order of the geometry
    # approximation (subparametric element).
    # The implemented shape functions are derived (in femple) based
    # on the following ordering of the nodes of the parent element.
    #
    dof_r = Array(
        value=[[-1., -1.], [1., -1.], [1., 1.], [-1., 1.], [-1. / 3., -1.],
               [1. / 3., -1.], [1., -1. / 3.], [1., 1. / 3.], [1. / 3., 1.],
               [-1. / 3., 1.], [-1., 1. / 3.], [-1., -1. / 3.]])

    geo_r = Array(value=[[-1., -1.], [1., -1.], [1., 1.], [-1., 1.]])

    vtk_r = Array(value=[[-1., -1.], [0., -1.], [1., -1.], [-1., 0.], [0., 0.],
                         [1., 0.], [-1., 1.], [0., 1.], [1., 1.]])

    vtk_cells = [[0, 2, 8, 6, 1, 5, 7, 3, 4]]
    vtk_cell_types = 'QuadraticQuad'

    n_nodal_dofs = Int(2)

    # Ordering of the nodes of the parent element used for the geometry
    # approximation
    _node_coord_map_geo = Array(Float, (4, 2),
                                [[-1., -1.], [1., -1.], [1., 1.], [-1., 1.]])

    #---------------------------------------------------------------------
    # Method required to represent the element geometry
    #---------------------------------------------------------------------
    def get_N_geo_mtx(self, r_pnt):
        '''
        Return the (1, 4) row of bilinear geometry shape functions evaluated
        at the local coordinate ``r_pnt = (r, s)``.
        '''
        cx = self._node_coord_map_geo
        N_geo_mtx = array([[
            1 / 4. * (1 + r_pnt[0] * cx[i, 0]) * (1 + r_pnt[1] * cx[i, 1])
            for i in range(0, 4)
        ]])
        return N_geo_mtx

    def get_dNr_geo_mtx(self, r_pnt):
        '''
        Return the (2, 4) matrix of geometry shape-function derivatives
        (row 0: d/dr, row 1: d/ds) at ``r_pnt``.
        Used for the construction of the Jacobi matrix.
        '''
        cx = self._node_coord_map_geo
        dNr_geo_mtx = array([[
            1 / 4. * cx[i, 0] * (1 + r_pnt[1] * cx[i, 1]) for i in range(0, 4)
        ], [
            1 / 4. * cx[i, 1] * (1 + r_pnt[0] * cx[i, 0]) for i in range(0, 4)
        ]])
        return dNr_geo_mtx

    #-------------------------------------------------------------------
    # Shape functions for the field variables and their derivatives
    #-------------------------------------------------------------------
    def get_N_mtx(self, r_pnt):
        '''
        Return the (2, 24) matrix of field shape functions at the local
        coordinate ``r_pnt = (r, s)``.

        Row 0 carries the 12 nodal shape functions in the u-dof columns
        (even indices), row 1 carries them in the v-dof columns (odd
        indices); all other entries are zero.
        '''
        r = r_pnt[0]
        s = r_pnt[1]

        # Factor shared by the four corner-node functions.
        corner = 9 * s * s + 9 * r * r - 10
        # The 12 nodal shape functions, in the node order of ``dof_r``;
        # computed once and scattered into both rows below.
        N = array([
            ((-1 + r) * (-1 + s) * corner) / 32.,
            -((1 + r) * (-1 + s) * corner) / 32.,
            ((1 + r) * (1 + s) * corner) / 32.,
            -((-1 + r) * (1 + s) * corner) / 32.,
            -9. / 32. * (-1 + r) * (3 * r - 1) * (1 + r) * (-1 + s),
            9. / 32. * (-1 + r) * (3 * r + 1) * (1 + r) * (-1 + s),
            9. / 32. * (-1 + s) * (3 * s - 1) * (1 + s) * (1 + r),
            -9. / 32. * (-1 + s) * (3 * s + 1) * (1 + s) * (1 + r),
            -9. / 32. * (-1 + r) * (3 * r + 1) * (1 + r) * (1 + s),
            9. / 32. * (-1 + r) * (3 * r - 1) * (1 + r) * (1 + s),
            9. / 32. * (-1 + s) * (3 * s + 1) * (1 + s) * (-1 + r),
            -9. / 32. * (-1 + s) * (3 * s - 1) * (1 + s) * (-1 + r),
        ])

        # dtype=float (float64): the 'float_' alias is gone in NumPy 2.
        N_mtx = zeros((2, 24), dtype=float)
        N_mtx[0, 0::2] = N
        N_mtx[1, 1::2] = N
        return N_mtx

    def get_dNr_mtx(self, r_pnt):
        '''
        Return the (2, 12) matrix of field shape-function derivatives with
        respect to the local coordinates (row 0: d/dr, row 1: d/ds); column i
        belongs to node i in the ordering of ``dof_r``.
        '''
        r = r_pnt[0]
        s = r_pnt[1]
        dNr = zeros((2, 12), dtype=float)
        dNr[0, 0] = ((-1 + s) * (9 * r * r + 9 * s * s - 10)) / \
            32. + 9. / 16. * (-1 + s) * (-1 + r) * r
        dNr[0, 1] = - ((-1 + s) * (9 * r * r + 9 * s * s - 10)) / \
            32. - 9. / 16. * (-1 + s) * (1 + r) * r
        dNr[0, 2] = ((1 + s) * (9 * r * r + 9 * s * s - 10)) / \
            32. + 9. / 16. * (1 + s) * (1 + r) * r
        dNr[0, 3] = - ((1 + s) * (9 * r * r + 9 * s * s - 10)) / \
            32. - 9. / 16. * (1 + s) * (-1 + r) * r
        dNr[0, 4] = -9. / 32. * (3 * r - 1) * (1 + r) * (
            -1 + s) - 27. / 32. * (-1 + r) * (1 + r) * (-1 + s) - 9. / 32. * (
                -1 + r) * (3 * r - 1) * (-1 + s)
        dNr[0, 5] = 9. / 32. * (3 * r + 1) * (1 + r) * (-1 + s) + 27. / 32. * (
            -1 + r) * (1 + r) * (-1 + s) + 9. / 32. * (-1 + r) * (3 * r +
                                                                  1) * (-1 + s)
        dNr[0, 6] = 9. / 32. * (-1 + s) * (3 * s - 1) * (1 + s)
        dNr[0, 7] = -9. / 32. * (-1 + s) * (3 * s + 1) * (1 + s)
        dNr[0, 8] = -9. / 32. * (3 * r + 1) * (1 + r) * (1 + s) - 27. / 32. * (
            -1 + r) * (1 + r) * (1 + s) - 9. / 32. * (-1 + r) * (3 * r +
                                                                 1) * (1 + s)
        dNr[0, 9] = 9. / 32. * (3 * r - 1) * (1 + r) * (1 + s) + 27. / 32. * (
            -1 + r) * (1 + r) * (1 + s) + 9. / 32. * (-1 + r) * (3 * r -
                                                                 1) * (1 + s)
        dNr[0, 10] = 9. / 32. * (-1 + s) * (3 * s + 1) * (1 + s)
        dNr[0, 11] = -9. / 32. * (-1 + s) * (3 * s - 1) * (1 + s)
        dNr[1, 0] = ((-1 + r) * (9 * r * r + 9 * s * s - 10)) / \
            32. + 9. / 16. * (-1 + s) * (-1 + r) * s
        dNr[1, 1] = - ((1 + r) * (9 * r * r + 9 * s * s - 10)) / \
            32. - 9. / 16. * (-1 + s) * (1 + r) * s
        dNr[1, 2] = ((1 + r) * (9 * r * r + 9 * s * s - 10)) / \
            32. + 9. / 16. * (1 + s) * (1 + r) * s
        dNr[1, 3] = - ((-1 + r) * (9 * r * r + 9 * s * s - 10)) / \
            32. - 9. / 16. * (1 + s) * (-1 + r) * s
        dNr[1, 4] = -9. / 32. * (-1 + r) * (3 * r - 1) * (1 + r)
        dNr[1, 5] = 9. / 32. * (-1 + r) * (3 * r + 1) * (1 + r)
        dNr[1, 6] = 9. / 32. * (3 * s - 1) * (1 + s) * (1 + r) + 27. / 32. * (
            -1 + s) * (1 + s) * (1 + r) + 9. / 32. * (-1 + s) * (3 * s -
                                                                 1) * (1 + r)
        dNr[1, 7] = -9. / 32. * (3 * s + 1) * (1 + s) * (1 + r) - 27. / 32. * (
            -1 + s) * (1 + s) * (1 + r) - 9. / 32. * (-1 + s) * (3 * s +
                                                                 1) * (1 + r)
        dNr[1, 8] = -9. / 32. * (-1 + r) * (3 * r + 1) * (1 + r)
        dNr[1, 9] = 9. / 32. * (-1 + r) * (3 * r - 1) * (1 + r)
        dNr[1, 10] = 9. / 32. * (3 * s + 1) * (1 + s) * (
            -1 + r) + 27. / 32. * (-1 + s) * (1 + s) * (-1 + r) + 9. / 32. * (
                -1 + s) * (3 * s + 1) * (-1 + r)
        dNr[1, 11] = -9. / 32. * (3 * s - 1) * (1 + s) * (
            -1 + r) - 27. / 32. * (-1 + s) * (1 + s) * (-1 + r) - 9. / 32. * (
                -1 + s) * (3 * s - 1) * (-1 + r)
        return dNr

    def get_B_mtx(self, r_pnt, X_mtx):
        '''
        Return the (3, 24) strain-displacement matrix at ``r_pnt`` for an
        element with nodal coordinates ``X_mtx``.

        The global derivatives are ``inv(J) . dNr``. This assumes
        ``get_J_mtx`` (defined elsewhere) returns J[i, j] = dx_j / dr_i, so
        that row 0 of the product is d/dx and row 1 is d/dy (TODO confirm).
        Rows of the result: [d/dx on u-dofs], [d/dy on v-dofs],
        [d/dy on u-dofs + d/dx on v-dofs].
        '''
        J_mtx = self.get_J_mtx(r_pnt, X_mtx)
        dNr_mtx = self.get_dNr_mtx(r_pnt)
        dNx_mtx = dot(inv(J_mtx), dNr_mtx)
        Bx_mtx = zeros((3, 24), dtype=float)
        # u-dofs sit at even columns, v-dofs at odd columns.
        Bx_mtx[0, 0::2] = dNx_mtx[0]
        Bx_mtx[1, 1::2] = dNx_mtx[1]
        Bx_mtx[2, 0::2] = dNx_mtx[1]
        Bx_mtx[2, 1::2] = dNx_mtx[0]
        return Bx_mtx
Example #12
0
class MarkerPointDest(MarkerPoints):
    """MarkerPoints subclass that serves for derived points.

    The five points are derived from two source marker sets ``src1`` and
    ``src2`` according to ``method``: either averaged point by point, or
    obtained by applying half of the transformation fitted between the
    sources.
    """
    src1 = Instance(MarkerPointSource)
    src2 = Instance(MarkerPointSource)

    name = Property(Str, depends_on='src1.name,src2.name')
    dir = Property(Str, depends_on='src1.dir,src2.dir')

    points = Property(Array(float, (5, 3)),
                      depends_on=['method', 'src1.points', 'src1.use',
                                  'src2.points', 'src2.use'])
    enabled = Property(Bool, depends_on=['points'])

    method = Enum('Transform', 'Average', desc="Transform: estimate a rotation"
                  "/translation from mrk1 to mrk2; Average: use the average "
                  "of the mrk1 and mrk2 coordinates for each point.")

    view = View(VGroup(Item('method', style='custom'),
                       Item('save_as', enabled_when='can_save',
                            show_label=False)))

    @cached_property
    def _get_dir(self):
        # Reuse the directory of the first source.
        return self.src1.dir

    @cached_property
    def _get_name(self):
        """Derive a name from the two source names.

        If only one source has a name, that name is used ('' if neither
        has one). Otherwise the longest common prefix of both names is
        returned (the full name when they are equal, '' when they share no
        leading characters).
        """
        n1 = self.src1.name
        n2 = self.src2.name

        if not n1:
            return n2 if n2 else ''
        if not n2:
            return n1

        # Longest common prefix. The previous index-based loop compared
        # against len(n1) - 2 instead of len(n2) - 1, which returned the
        # wrong name or raised IndexError when n2 was shorter than n1.
        prefix_len = 0
        for c1, c2 in zip(n1, n2):
            if c1 != c2:
                break
            prefix_len += 1
        return n1[:prefix_len]

    @cached_property
    def _get_enabled(self):
        # Enabled as soon as any derived coordinate is non-zero.
        return np.any(self.points)

    @cached_property
    def _get_points(self):
        """Compute the derived (5, 3) point array.

        Returns all zeros (after showing an error dialog) when the enabled
        sources cannot provide every point.
        """
        # in case only one or no source is enabled
        if not (self.src1 and self.src1.enabled):
            if (self.src2 and self.src2.enabled):
                return self.src2.points
            else:
                return np.zeros((5, 3))
        elif not (self.src2 and self.src2.enabled):
            return self.src1.points

        # Average method
        if self.method == 'Average':
            if len(np.union1d(self.src1.use, self.src2.use)) < 5:
                error(None, "Need at least one source for each point.",
                      "Marker Average Error")
                return np.zeros((5, 3))

            pts = (self.src1.points + self.src2.points) / 2.
            # Points used by only one source are taken from that source.
            for i in np.setdiff1d(self.src1.use, self.src2.use):
                pts[i] = self.src1.points[i]
            for i in np.setdiff1d(self.src2.use, self.src1.use):
                pts[i] = self.src2.points[i]

            return pts

        # Transform method
        idx = np.intersect1d(self.src1.use, self.src2.use, assume_unique=True)
        if len(idx) < 3:
            error(None, "Need at least three shared points for trans"
                  "formation.", "Marker Interpolation Error")
            return np.zeros((5, 3))

        src_pts = self.src1.points[idx]
        tgt_pts = self.src2.points[idx]
        est = fit_matched_points(src_pts, tgt_pts, out='params')
        # Halve the fitted parameters to move "half way" between the sources
        # (assumes est = 3 rotation parameters followed by 3 translation
        # parameters -- TODO confirm against fit_matched_points).
        rot = np.array(est[:3]) / 2.
        tra = np.array(est[3:]) / 2.

        # NOTE(review): translation(-t) . rotation(-r) is not in general the
        # exact inverse of translation(t) . rotation(r); confirm intended.
        if len(self.src1.use) == 5:
            trans = np.dot(translation(*tra), rotation(*rot))
            pts = apply_trans(trans, self.src1.points)
        elif len(self.src2.use) == 5:
            trans = np.dot(translation(* -tra), rotation(* -rot))
            pts = apply_trans(trans, self.src2.points)
        else:
            trans1 = np.dot(translation(*tra), rotation(*rot))
            pts = apply_trans(trans1, self.src1.points)
            trans2 = np.dot(translation(* -tra), rotation(* -rot))
            for i in np.setdiff1d(self.src2.use, self.src1.use):
                pts[i] = apply_trans(trans2, self.src2.points[i])

        return pts
Example #13
0
class DefaultPickHandler(PickHandler):
    """The default handler for the picked data.

    Each pick updates the ``ID``/``coordinate``/``scalar``/``vector``/
    ``tensor`` traits, appends their values to ``data`` and ``history`` and,
    for a valid pick, writes a summary into the pick's text actor.
    """
    # Traits.
    ID = Trait(None, None, Int, desc='the picked ID')

    coordinate = Trait(None, None, Array('d', (3,)),
                       desc='the coordinate of the picked point')

    scalar = Trait(None, None, Array, Float, desc='the scalar at picked point')

    vector = Trait(None, None, Array('d', (3,)),
                   desc='the vector at picked point')

    tensor = Trait(None, None, Array('d', (3, 3)),
                   desc='the tensor at picked point')

    # History of picked data.
    history = Str

    default_view = View(Item(name='ID', style='readonly'),
                        Item(name='coordinate', style='readonly'),
                        Item(name='scalar', style='readonly'),
                        Item(name='vector', style='readonly'),
                        Item(name='tensor', style='readonly'),
                        Item(name='history', style='custom'),
                        )

    def __init__(self, **traits):
        super(DefaultPickHandler, self).__init__(**traits)
        # This saves all the data picked earlier, one list per field.
        self.data = {'ID': [], 'coordinate': [], 'scalar': [], 'vector': [],
                     'tensor': []}

    #################################################################
    # `DefaultPickHandler` interface.
    #################################################################
    def handle_pick(self, data):
        """Called when a pick event happens.

        For a valid pick, stores the point id (or, failing that, the cell id),
        the coordinate and the scalar/vector/tensor values at that id; for an
        invalid pick every field is reset to None. Then records the pick and
        refreshes ``data.text_actor``.
        """
        if data.valid_:
            if data.point_id > -1:
                self.ID = data.point_id
            elif data.cell_id > -1:
                self.ID = data.cell_id
            self.coordinate = list(data.coordinate)

            if data.data:
                array_data = {'scalar': data.data.scalars,
                              'vector': data.data.vectors,
                              'tensor': data.data.tensors}
            else:
                array_data = {'scalar': None,
                              'vector': None,
                              'tensor': None}
            for name, values in array_data.items():
                if values:
                    setattr(self, name, values[self.ID])
                else:
                    setattr(self, name, None)
        else:
            for name in ['ID', 'coordinate', 'scalar', 'vector', 'tensor']:
                setattr(self, name, None)

        self._update_data(data.renwin, data.text_actor)

    #################################################################
    # Non-public interface.
    #################################################################

    def _update_data(self, renwin, text_actor):
        """Append the current pick to ``data``/``history`` and update the
        text shown by ``text_actor`` (``renwin`` is not used here).

        For an invalid pick (``coordinate`` is None) only the history is
        recorded and the text actor is left unchanged.
        """
        for name in ['ID', 'coordinate', 'scalar', 'vector', 'tensor']:
            value = getattr(self, name)
            self.data[name].append(value)
            self.history += '%s: %r\n' % (name, value)

        if self.coordinate is None:
            # handle_pick reset every field to None: there is nothing to
            # format, and indexing None below would raise TypeError.
            return

        x_coord = np.format_float_scientific(self.coordinate[0], precision=3)
        y_coord = np.format_float_scientific(self.coordinate[1], precision=3)
        z_coord = np.format_float_scientific(self.coordinate[2], precision=3)

        if(self.vector is not None and
           self.scalar is not None and
           self.tensor is not None):
            text_actor.trait_set(
                input=("ID : %s\nx : %s\ny : %s\nz : %s " +
                       "\nscalar : %s\nvector : %s\ntensor : %s ")
                % (self.ID, x_coord, y_coord, z_coord,
                   self.scalar, self.vector, self.tensor)
            )

        elif self.vector is not None and self.scalar is not None:
            scalar = np.format_float_scientific(self.scalar, precision=3)
            # Each formatted component string is parsed back into the float
            # array, i.e. the vector is shown rounded to 4 significant digits.
            vector = np.zeros(3)
            for i in range(3):
                vector[i] = np.format_float_scientific(self.vector[i],
                                                       precision=3)

            text_actor.trait_set(
                input=("ID : %s\nx : %s\ny : %s\nz : %s" +
                       "\nscalar : %s\nvector : %s ")
                % (self.ID, x_coord, y_coord, z_coord, scalar, vector)
            )

        elif self.scalar is not None:
            scalar = np.format_float_scientific(self.scalar, precision=3)
            text_actor.trait_set(
                input="ID : %s\nx : %s\ny : %s\nz : %s\nscalar : %s"
                % (self.ID, x_coord, y_coord, z_coord, scalar)
            )

        else:
            text_actor.trait_set(
                input="ID : %s\nx : %s\ny : %s\nz : %s "
                % (self.ID, x_coord, y_coord, z_coord)
            )
class BasicFigure(MinimalFigure):
    # Subsample large line/errorbar data before plotting (see _mask_data).
    mask_data_bool = Bool(True)
    # Data longer than this is passed through _mask_data.
    mask_length = Int(100000)

    # Normalization toggle, handled in _normalize_bool_changed.
    normalize_bool = Bool(False)
    # Largest per-line maximum found while normalizing lines.
    normalize_max = Float()
    # Per-line maxima recorded while normalizing, used to undo it.
    normalize_maxes = List()

    # Logarithmic display toggle, handled in _log_bool_changed.
    log_bool = Bool(False)
    # Legend toggle (presumably read by draw_legend -- verify).
    draw_legend_bool = Bool(True)

    # image stuff
    # Image origin; overwritten by an 'origin' kwarg passed to imshow.
    origin = Str('lower')
    # True while the figure shows an image rather than lines.
    img_bool = Bool(False)
    # Maximum the image was divided by when normalizing.
    img_max = Float(1.)
    # Currently displayed image data.
    img_data = Array()
    # Keyword arguments of the last imshow call (reused by _log_bool_changed).
    img_kwargs = Dict
    # Colorbar label.
    zlabel = Str()

    # Bounds of the vmin/vmax sliders, reset from the image by _set_cb_slider.
    vmin_lv, vmin_hv, vmax_lv, vmax_hv = Float(0.), Float(0.), Float(
        1.), Float(1.)
    vmin = Range('vmin_lv', 'vmin_hv')
    vmax = Range('vmax_lv', 'vmax_hv')
    # Available colormap names (see _cmaps_default) and the selected one.
    cmaps = List
    cmap_selector = Enum(values='cmaps')

    # Opens the 'image_slider_view' for manual color limits.
    image_slider_btn = Button('z-slider')

    errorbar_data = Dict(
    )  # errorbar containers by label; this is needed because of https://github.com/matplotlib/matplotlib/issues/8458
    # x/y error arrays by label (used when copying data to the clipboard).
    _xerr = Dict()
    _yerr = Dict()

    def __init__(self, **kwargs):
        """Initialise the base figure, then switch the grid on."""
        super(BasicFigure, self).__init__(**kwargs)
        show_grid = True
        self.grid = show_grid

    def _test_plot_kwargs(self, kwargs):
        """Split the plotting options this class handles out of ``kwargs``.

        ``fmt`` (default ''), ``text`` (default '') and ``pos`` (default
        (0, 0), the text position) are removed from ``kwargs``; ``label`` is
        mandatory and stays in ``kwargs``.

        :returns: ``(fmt, label, text, pos)``
        :raises Exception: if no ``label`` is given.
        """
        fmt = kwargs.pop('fmt', '')
        if 'label' not in kwargs:
            raise Exception(
                "BasicFigure: Please provide a label for datapoints.")
        label = kwargs['label']
        text = kwargs.pop('text', '')
        pos = kwargs.pop('pos', (0, 0))
        return fmt, label, text, pos

    def _mask_data(self, data):
        """Subsample ``data`` to at most ``mask_length`` evenly strided points.

        Fast even for very large inputs (tens of millions of points), which
        matplotlib could not draw directly. Returns ``data`` unchanged when
        ``mask_data_bool`` is off; otherwise every ``step``-th element as an
        ndarray.
        """
        if not self.mask_data_bool:
            return data
        data = np.array(data)
        # Ceiling division keeps the result at or below mask_length points
        # (truncation let up to ~2x through), and max(1, ...) avoids a zero
        # stride for short input. Slicing [::step] instead of [0:-1:step]
        # no longer always drops the final sample.
        step = max(1, -(-len(data) // self.mask_length))
        return data[::step]

    def _zlabel_changed(self):
        """Show the new ``zlabel`` on the colorbar when an image is displayed."""
        if not self.img_bool:
            return
        self.cb.set_label(self.zlabel)
        self.draw()

    def _cmap_selector_changed(self):
        """Apply the selected colormap to the displayed image, if any."""
        if not self.img_bool:
            return
        self.img.set_cmap(self.cmap_selector)
        self.draw()

    def _cmaps_default(self):
        """Default for ``cmaps``: the sorted keys of ``mpl._cm.datad``
        (matplotlib's private colormap-data table)."""
        print(self.__class__.__name__, ": Initiating colormaps")
        return sorted(mpl._cm.datad)

    def _normalize_bool_changed(self, old=None, new=None):
        """Apply or undo normalization of the displayed data.

        Runs as the Traits handler for ``normalize_bool`` (with ``old``/
        ``new``) and is also called without arguments after drawing (e.g. by
        ``imshow``, ``ax_line`` and ``errorbar``).

        Image mode: when enabling, divides ``img_data`` by its nan-max (kept in
        ``img_max``); when disabling, multiplies it back. Then calls
        ``update_imshow``.

        Line mode: when enabling, records every line's nan-max in
        ``normalize_maxes`` and divides all lines by the largest of them
        (``normalize_max``). When disabling (only if ``old != new``), rescales
        each line so its maximum equals the recorded value again.
        """
        # Function is a little bit long since it handles normalization completely by itself
        # Maybe there is a better way, but it's working and i do not have time to think about a better one

        # Line data must be drawn non-animated to be changed here.
        if old != new and self.img_bool is False:
            self.set_animation_for_lines(False)

        self.normalize_max = 0.

        if self.img_bool:
            if self.normalize_bool:
                # NOTE(review): when called again while already normalized,
                # img_max becomes the max of the normalized data (~1), so a
                # later "disable" no longer restores the original scale.
                self.img_max = np.nanmax(self.img_data)
                self.img_data = self.img_data / self.img_max
            else:
                self.img_data = self.img_data * self.img_max

            self.update_imshow(self.img_data)
        else:
            if self.normalize_bool:
                # Pass 1: collect each plotted line's maximum and the overall one.
                self.normalize_maxes = []
                line = None
                for l in self.lines_list:
                    line = self._is_line_in_axes(l)
                    if line is False:
                        continue
                    x, y = line.get_data()
                    max = np.nanmax(y)
                    self.normalize_maxes.append(max)
                    if self.normalize_max < max:
                        self.normalize_max = max

                # Pass 2: divide every line by the overall maximum.
                for l in self.lines_list:
                    line = self._is_line_in_axes(l)
                    if line is False:
                        continue
                    x, y = line.get_data()
                    line.set_data(x, y / self.normalize_max)
                    if not line.get_animated():
                        self.draw()
            else:
                line = None
                if len(self.normalize_maxes) > 0:
                    # NOTE(review): normalize_maxes only holds entries for
                    # lines found in the axes, but it is indexed here by the
                    # position in the full lines_list. A skipped (or
                    # duplicated) label shifts the maxima onto the wrong lines
                    # or raises IndexError. This lookup also passes
                    # axes_selector, unlike pass 1 above -- confirm intended.
                    for i, l in enumerate(self.lines_list):
                        line = self._is_line_in_axes(l, self.axes_selector)
                        if line is False:
                            continue
                        x, y = line.get_data()
                        max = np.nanmax(y)
                        if old != new:
                            # Rescale so the line's max equals its recorded max.
                            line.set_data(x, y / max * self.normalize_maxes[i])
                        else:
                            line.set_data(x, y)

                    if line is not None and line is not False:
                        if not line.get_animated():
                            self.draw()

    def draw(self):
        """Redraw the canvas.

        When autoscaling line plots (not images), every axis first
        recomputes its data limits and rescales to them.
        """
        rescale = self.autoscale and not self.img_bool
        if rescale:
            for axis in self.figure.get_axes():
                axis.relim()
                axis.autoscale()

        self.draw_canvas()

    def _img_bool_changed(self, val):
        """Switch between image and line mode.

        Clears the figure, shows the grid only in line mode and recreates an
        axis if none is left.
        """
        self.figure.clear()
        self.grid = not val
        self.create_axis_if_no_axis()

    def _log_bool_changed(self):
        """Switch between logarithmic and linear display.

        Image mode: clears the figure and redraws ``img_data`` through
        ``imshow`` (which adds a ``LogNorm`` while ``log_bool`` is set); when
        switching back to linear, the stored ``norm`` is dropped first.
        Line mode: toggles the y-scale of ``axes_selector`` and redraws.
        """
        if self.img_bool:
            self.clear()
            if not self.log_bool:
                # 'norm' is missing if the last imshow call returned early
                # (normalization path) or ran without log scaling; do not
                # raise KeyError in that case.
                self.img_kwargs.pop('norm', None)
            self.imshow(self.img_data, **self.img_kwargs)
        else:
            # has to be done, otherwise no datapoints
            self.set_animation_for_lines(False)
            if self.log_bool:  # TODO: Maybe add xscale log, but not needed now.
                # NOTE(review): matplotlib >= 3.3 renamed `nonposy` to
                # `nonpositive`; update once the supported version is known.
                self.axes_selector.set_yscale("log", nonposy='clip')
            else:
                self.axes_selector.set_yscale("linear")
            self.draw()

    def _image_slider_btn_fired(self):
        """Stop autoscaling and open the manual colour-limit view."""
        slider_view = 'image_slider_view'
        self.autoscale = False
        self.edit_traits(view=slider_view)

    def _clear_btn_fired(self):
        """Clear the figure.

        In image mode, switches back to line mode (``_img_bool_changed``
        clears the figure). Otherwise wipes every axis, resets labels and
        title, and re-applies the grid setting. Errorbar bookkeeping is
        always reset and the canvas redrawn.
        """
        if not self.img_bool:
            for axis in self.figure.get_axes():
                print("MinimalFigure: Clearing axis ", axis)
                axis.clear()

            self.xlabel = ''
            self.ylabel = ''
            self.title = ''
            for axis in self.figure.axes:
                axis.grid(self.grid)
        else:
            self.img_bool = False  # also triggers _img_bool_changed

        self.errorbar_data = {}
        self.draw_canvas()

    def imshow(self, z, ax=0, **kwargs):
        """Show ``z`` as an image, or update the image already displayed.

        Keyword arguments handled here (the rest go to ``Axes.imshow``):
            - label     colorbar label (stored on ``self.label``)
            - origin    also stored on ``self.origin``
            - aspect    image aspect, default 'auto'
            - extent    re-applied when updating an existing image
        With ``log_bool`` a ``LogNorm`` is added and negative values are
        clipped to 0 in a copy of ``z`` (the caller's array is not modified).
        """
        if self.normalize_bool:
            # NOTE(review): the new ``z`` is not used on this path; only the
            # stored ``img_data`` is re-normalized. Confirm this is intended.
            self._normalize_bool_changed()
            return

        # Work on a copy: the clamp below must not touch the caller's data,
        # and the comparison needs an ndarray (a list would raise).
        z = np.array(z)

        if self.log_bool:
            kwargs['norm'] = LogNorm()

            negative = z < 0.
            if np.any(negative):
                print(self.__class__.__name__,
                      ": WARNING - All values below 0. has been set to 0.")
                z[negative] = 0.

        self.img_data = z

        if 'label' in kwargs:
            self.label = kwargs.pop('label')

        if 'origin' in kwargs:
            self.origin = kwargs['origin']

        aspect = kwargs.pop('aspect', 'auto')

        if not self.img_bool:
            # Setting img_bool clears the figure and ensures an axis exists.
            self.img_bool = True
            self.img = self.axes_selector.imshow(self.img_data,
                                                 aspect=aspect,
                                                 **kwargs)

            if not hasattr(self, "label"):
                self.label = ''

            self.cb = self.figure.colorbar(self.img, label=self.label)
            self.draw()
        else:
            self.update_imshow(self.img_data, ax=ax)
            if 'extent' in kwargs:
                self.img.set_extent(kwargs['extent'])

        assert isinstance(self.img, mpl.image.AxesImage)
        self._set_cb_slider()
        self.img_kwargs = kwargs

    def update_imshow(self, z, ax=0):
        """Replace the displayed image data with ``z`` and redraw.

        ``ax`` is accepted for call compatibility but not used.
        """
        self.img.set_data(np.array(z))
        if self.autoscale:
            self.img.autoscale()
        self.draw()

    @on_trait_change('autoscale')
    def _set_cb_slider(self):
        """Reset the vmin/vmax slider ranges to the image's data range.

        Only acts while autoscaling an image. Both sliders span
        [nanmin, nanmax] rounded to 2 decimals; vmin is put at the lower and
        vmax at the upper bound.
        """
        if not (self.autoscale and self.img_bool):
            return

        lowest = float(np.nanmin(self.img_data).round(2))
        highest = float(np.nanmax(self.img_data).round(2))
        self.vmin_lv = lowest
        self.vmin_hv = highest
        self.vmax_lv = lowest
        self.vmax_hv = highest
        self.vmin = self.vmin_lv
        self.vmax = self.vmax_hv

    def _apply_slider_clim(self):
        """Push the vmin/vmax slider values to the image colour limits.

        In log mode a negative vmin is clamped to 0. Nothing happens while
        autoscaling is on.
        """
        lower = self.vmin
        if self.log_bool and self.vmin < 0.:
            lower = 0.

        if self.autoscale:
            return
        self.img.set_clim(vmin=lower, vmax=self.vmax)
        self.draw()

    def _vmin_changed(self):
        """React to the vmin slider."""
        self._apply_slider_clim()

    def _vmax_changed(self):
        """React to the vmax slider."""
        self._apply_slider_clim()

    def axvline(self, pos, ax=0, **kwargs):
        """Draw a vertical line at ``pos``; see ``ax_line`` for kwargs.

        Returns the line from ``ax_line`` (it used to be discarded, so callers
        always got None).
        """
        return self.ax_line(pos, 'axvline', ax=ax, **kwargs)

    def axhline(self, pos, ax=0, **kwargs):
        """Draw a horizontal line at ``pos``; see ``ax_line`` for kwargs.

        Returns the line from ``ax_line`` (it used to be discarded, so callers
        always got None).
        """
        return self.ax_line(pos, 'axhline', ax=ax, **kwargs)

    def ax_line(self, pos, func_str, ax=0, **kwargs):
        """Draw, or redraw, a reference line identified by its label.

        :param pos: position passed as first argument to the axes method.
        :param func_str: name of the axes method to call, e.g. ``'axhline'``
            or ``'axvline'``.
        :param ax: index into ``self.figure.get_axes()`` or an axes object.
        :param kwargs: ``label`` is required; ``nodraw=True`` skips the
            redraw; ``fmt``/``text`` are consumed and ignored; the rest go
            to the axes method.
        :returns: the created line.
        :raises TypeError: if ``ax`` is neither an int nor an axes object
            providing ``func_str``.
        """
        # _test_plot_kwargs returns four values; unpacking only two raised
        # ValueError on every call.
        _, label, _, _ = self._test_plot_kwargs(kwargs)

        axes = self.figure.get_axes()
        line = self._is_line_in_axes(label)

        nodraw = bool(kwargs.pop('nodraw', False))

        if isinstance(ax, int):
            target = axes[ax]
        elif hasattr(ax, func_str):
            target = ax
        else:
            raise TypeError('ax can be an int or the axis itself!')

        is_new = not line
        if is_new:
            print("BasicFigure: Plotting", func_str, label)
        else:
            line.remove()
        line = getattr(target, func_str)(pos, **kwargs)
        if is_new:
            # Register the label once only; appending on every redraw made
            # normalization process the same line several times.
            self.lines_list.append(label)

        self.draw_legend()

        if not nodraw:
            self._normalize_bool_changed()
            self.draw()  # draws with respect to autolim etc.

        if hasattr(line, "append"):
            return line[0]
        return line

    def _is_errorbar_plotted(self, label):
        """Return the stored errorbar container for ``label``, or False."""
        return self.errorbar_data.get(label, False)

    def errorbar(self, x, y, ax=0, **kwargs):
        """ Additional (to normal matplotlib plot method) kwargs:
                - (bool) nodraw     If True, will not draw canvas
                - (str) fmt         like in matplotlib errorbar(), but it is stupid to use it only in one function

        ``label`` is required and identifies the plot: a new label creates an
        errorbar plot on ``ax`` (index into ``self.get_axes()`` or an axes
        object), a known label updates that plot in place. Missing
        ``xerr``/``yerr`` default to zeros. Data longer than ``mask_length``
        is subsampled via ``_mask_data``.

        :returns: the errorbar container for ``label``, or None for empty ``x``.
        :raises TypeError: if ``ax`` is neither an int nor an axes object.
        """
        self.img_bool = False
        # _test_plot_kwargs returns four values; unpacking only two raised
        # ValueError on every call.
        fmt, label, _, _ = self._test_plot_kwargs(kwargs)
        axes = self.get_axes()
        container = self._is_errorbar_plotted(label)

        if len(x) == 0:
            print(self.__class__.__name__, "Length of x array is 0.")
            return

        # np.shape also accepts plain lists (x.shape did not).
        if 'xerr' not in kwargs:
            kwargs['xerr'] = np.zeros(np.shape(x))
        if 'yerr' not in kwargs:
            kwargs['yerr'] = np.zeros(np.shape(y))

        if len(x) > self.mask_length:
            x = self._mask_data(x)
            y = self._mask_data(y)
            kwargs['xerr'] = self._mask_data(kwargs['xerr'])
            kwargs['yerr'] = self._mask_data(kwargs['yerr'])

        # Store the errors after masking so they stay index-aligned with the
        # plotted data (used by _copy_data_btn_fired).
        self._xerr[label] = kwargs['xerr']
        self._yerr[label] = kwargs['yerr']

        nodraw = bool(kwargs.pop('nodraw', False))

        if container is False:
            print("BasicFigure: Plotting ", label)
            if isinstance(ax, int):
                target = axes[ax]
            elif hasattr(ax, 'errorbar'):
                target = ax
            else:
                raise TypeError('ax can be an int or the axis itself!')
            # Axes objects used to get .plot(), which rejects fmt/xerr/yerr.
            container = target.errorbar(x, y, fmt=fmt, **kwargs)
            self.errorbar_data[label] = container

            self.lines_list.append(label)
            self.draw_legend()
        else:
            if container[0].get_animated():
                self.set_animation_for_lines(
                    False)  # doesn't work otherwise, dunno why.
            self._set_errorbar_data(x, y, **kwargs)

        if not nodraw:
            self._normalize_bool_changed()
            self.draw()  # draws with respect to autolim etc.

        # Previously a newly created plot returned False here.
        return container

    def _copy_data_btn_fired(self):
        """Copy the data of the line named by ``line_selector`` to the
        clipboard as tab-separated text with a header row.

        Errorbar plots also get their x/y error columns. Prints a message and
        returns if no line with that label is plotted.
        """
        print(self.__class__.__name__, ": Trying to copy data to clipboard")
        if self.line_selector in self.errorbar_data:
            line, caplines, barlinecols = self.errorbar_data[
                self.line_selector]
            x = line.get_xdata()
            y = line.get_ydata()

            xerr = self._xerr[self.line_selector]
            yerr = self._yerr[self.line_selector]
            print("xerr = ", xerr)

            header = 'x \t y \t x_error \t y_error \n'
            rows = zip(x, y, xerr, yerr)
        else:
            line = self._is_line_in_axes(self.line_selector)
            if not line:
                print(self.__class__.__name__, ": No line found for",
                      self.line_selector)
                return
            x = line.get_xdata()
            y = line.get_ydata()

            header = 'x \t y \n'
            rows = zip(x, y)

        # zip + join instead of an xrange loop: xrange does not exist on
        # Python 3, and repeated string += is quadratic.
        text = header + ''.join(
            '\t'.join(str(value) for value in row) + '\n' for row in rows)
        self.add_to_clip_board(text)

    def _set_errorbar_data(self, *args, **kwargs):
        """Update an existing errorbar plot in place.

        Positional ``x, y``; keywords ``label`` (key into ``errorbar_data``),
        ``xerr`` and ``yerr``. Moves the data line and, unless both errors
        are None, the cap markers (left, right, bottom, top) and the
        horizontal/vertical error segments.
        """
        x, y = args
        label = kwargs['label']
        x = np.array(x)
        y = np.array(y)
        data_line, caplines, barlinecols = self.errorbar_data[label]

        data_line.set_data(x, y)
        xerr = kwargs['xerr']
        yerr = kwargs['yerr']

        if xerr is None and yerr is None:
            return

        x_lo = x - xerr
        x_hi = x + xerr
        y_lo = y - yerr
        y_hi = y + yerr

        # Update the caplines
        if caplines:
            cap_positions = ((x_lo, y), (x_hi, y), (x, y_lo), (x, y_hi))
            for idx, cap_pos in enumerate(cap_positions):
                caplines[idx].set_data(cap_pos)

        # Update the error bars: one segment per point for each direction.
        barlinecols[0].set_segments(list(zip(zip(x_lo, y), zip(x_hi, y))))
        barlinecols[1].set_segments(list(zip(zip(x, y_lo), zip(x, y_hi))))

    def plot(self, x, y, ax=0, **kwargs):
        """Plot ``y`` over ``x``, or update the line if it already exists.

        Additional (to the normal matplotlib ``plot``) kwargs:
            - (bool) nodraw     If True, the canvas is not redrawn
            - (str) fmt         format string, like in matplotlib ``errorbar()``

        ``ax`` is either an index into ``self.get_axes()`` or an Axes object.
        Returns the Line2D, or None (after printing a message) when ``x`` is
        empty. Raises TypeError for any other kind of ``ax``.
        """
        self.img_bool = False
        fmt, label, text, pos = self._test_plot_kwargs(kwargs)
        axes = self.get_axes()
        line = self._is_line_in_axes(label)

        if len(x) == 0:
            print(self.__class__.__name__, "Length of x array is 0.")
            return

        if len(x) > self.mask_length:
            x = self._mask_data(x)
            y = self._mask_data(y)

        # pop so that the flag is not forwarded to matplotlib
        nodraw = bool(kwargs.pop('nodraw', False))

        if isinstance(line, bool):  # _is_line_in_axes: False -> not plotted
            print(self.__class__.__name__, ": Plotting ", label)
            if isinstance(ax, (int, np.integer)):
                target = axes[ax]
            elif hasattr(ax, 'plot'):
                # an Axes object: plot and annotate on it directly (it cannot
                # be used as an index into ``axes``)
                target = ax
            else:
                raise TypeError('ax can be an int or the axis itself!')
            line = target.plot(x, y, fmt, **kwargs)
            self.txt = target.text(pos[0],
                                   pos[1],
                                   text,
                                   transform=target.transAxes,
                                   fontsize=12)

            self.lines_list.append(label)
            self.draw_legend()
        else:
            if line.get_animated():
                # the original author found that updates did not show
                # otherwise
                self.set_animation_for_lines(False)
            line.set_data(x, y)
            self.txt.set_text(text)

        if not nodraw:
            self._normalize_bool_changed()
            self.draw()  # draws with respect to autolim etc.

        # Axes.plot returns a list of lines; hand back the single Line2D
        if hasattr(line, "append"):
            return line[0]
        return line

    def blit(self, x, y, ax=0, **kwargs):
        """Plot or update a line using matplotlib blitting.

        The line is created with ``animated=True``. On the first call for a
        label, it is plotted on ``self.get_axes()[ax]`` and the background
        of the selected axes is cached. Later calls restore that background,
        update the line data and blit only the selected axes' bounding box.

        Raises ValueError if ``x`` is empty.
        """
        kwargs['animated'] = True

        self.img_bool = False
        # only fmt and label are needed here; slicing tolerates helpers
        # returning extra items (plot() unpacks four)
        fmt, label = self._test_plot_kwargs(kwargs)[:2]
        axes = self.get_axes()
        line = self._is_line_in_axes(label)

        if len(x) == 0:
            raise ValueError("BasicFigure: Length of x array is 0")

        if len(x) > self.mask_length:
            x = self._mask_data(x)
            y = self._mask_data(y)

        # 'nodraw' has no meaning for blitting but must not reach matplotlib
        kwargs.pop('nodraw', None)

        if not line:
            print(self.__class__.__name__, ": Plotting blitted ", label)
            axes[ax].plot(x, y, fmt, **kwargs)
            self.lines_list.append(label)
            self.draw_legend()
            self.figure.canvas.draw()
            self.background = self.canvas.copy_from_bbox(
                self.axes_selector.bbox)
            self.refresh_lines(ax)
        else:
            if not line.get_animated():
                # switch every line to animated, then redo this call on the
                # SAME axes; the recursive call finishes the update itself
                self.set_animation_for_lines(True)
                self.blit(x, y, ax=ax, **kwargs)
                return

            self.canvas.restore_region(self.background)
            self._setlinedata(x, y, ax, **kwargs)
            self.refresh_lines(ax)
            self.canvas.blit(self.axes_selector.bbox)

        self._normalize_bool_changed()

    def _setlinedata(self, x, y, ax, **kwargs):
        """Replace the data of the line labelled ``kwargs['label']``.

        ``x`` and ``y`` are converted to numpy arrays first. ``ax`` is part
        of the signature but not used here.
        """
        xs, ys = np.array(x), np.array(y)
        self._is_line_in_axes(kwargs['label']).set_data(xs, ys)

    def mpl_setup(self):
        """Store the figure canvas as ``self.canvas`` and hook mouse clicks.

        ``button_press_event`` is connected to the private click handler.
        """
        print(self.__class__.__name__,
              ": Running mpl_setup - connecting button press events")
        fig_canvas = self.figure.canvas
        self.canvas = fig_canvas  # same object, not a copy
        fig_canvas.mpl_connect('button_press_event', self.__onclick)

    def __onclick(self, event):
        """Record a mouse click on the canvas.

        Stores ``(button, x, y, xdata, ydata)`` of the event in
        ``self.clickdata``. While line plots are shown (``img_bool`` false),
        all lines are switched back to non-animated.
        """
        if event is None:
            return None
        fields = ('button', 'x', 'y', 'xdata', 'ydata')
        self.clickdata = tuple(getattr(event, f) for f in fields)
        if not self.img_bool:
            self.set_animation_for_lines(False)
        print(self.__class__.__name__, ": %s" % event)

    def set_animation_for_lines(self, TF):
        """Set the ``animated`` flag of every line on every axes to ``TF``.

        Also records the flag in ``self.animated`` and recomputes each
        axes' data limits. It then redraws the canvas and stores the
        background of the selected axes in ``self.background``, which
        ``blit`` restores.
        """
        self.animated = TF
        for axis in self.get_axes():
            for artist in axis.get_lines():
                artist.set_animated(TF)
            axis.relim()

        canvas = self.canvas
        canvas.draw()
        self.background = canvas.copy_from_bbox(self.axes_selector.bbox)

    def refresh_lines(self, ax):
        """Redraw every line of ``self.get_axes()[ax]``, then update the canvas."""
        target = self.get_axes()[ax]
        for artist in target.get_lines():
            target.draw_artist(artist)
        self.canvas.update()

    def draw_legend(self, ax=None):
        """Draw legends when ``draw_legend_bool`` is set.

        With ``ax=None`` every axes of the figure gets a legend. Otherwise
        only ``self.figure.get_axes()[ax]`` does.
        """
        if not self.draw_legend_bool:
            return

        print(self.__class__.__name__, ": Drawing Legend")
        figure_axes = self.figure.get_axes()
        selected = figure_axes if ax is None else [figure_axes[ax]]
        for target in selected:
            target.legend(loc=0, fancybox=True)

    def options_group(self):
        """Return the row of option widgets shown under the figure.

        Line-only controls are hidden while ``img_bool`` is true, and
        image-only controls are hidden while it is false.
        """
        controls = [
            UItem('options_btn'),
            UItem('clear_btn'),
            UItem('line_selector', visible_when='not img_bool'),
            UItem('copy_data_btn', visible_when='not img_bool'),
            Item('normalize_bool', label='normalize'),
            Item('log_bool', label='log scale'),
            Item('draw_legend_bool', label='draw legend'),
            Item('cmap_selector', label='cmap', visible_when='img_bool'),
            UItem('image_slider_btn', visible_when='img_bool'),
            UItem('save_fig_btn'),
        ]
        return HGroup(*controls)

    def options_group_axes_sel(self):
        """Return the option widget row including the axes selector.

        Same as ``options_group`` but with an ``axes_selector`` item placed
        after the copy-data button.
        """
        controls = [
            UItem('options_btn'),
            UItem('clear_btn'),
            UItem('line_selector', visible_when='not img_bool'),
            UItem('copy_data_btn', visible_when='not img_bool'),
            Item('axes_selector'),
            Item('normalize_bool', label='normalize'),
            Item('log_bool', label='log scale'),
            Item('draw_legend_bool', label='draw legend'),
            Item('cmap_selector', label='cmap', visible_when='img_bool'),
            UItem('image_slider_btn', visible_when='img_bool'),
            UItem('save_fig_btn'),
        ]
        return HGroup(*controls)

    def traits_view(self):
        """Return the default view: the figure above the option row."""
        return View(
            UItem('figure', editor=MPLFigureEditor(), style='custom'),
            Include('options_group'),
            handler=MPLInitHandler,
            resizable=True,
        )

    def traits_scroll_view(self):
        """Return a view that shows the figure in a scrollable editor."""
        figure_item = UItem('figure',
                            editor=ScrollableMPLFigureEditor(),
                            style='custom')
        return View(figure_item,
                    Include('options_group'),
                    handler=MPLInitHandler,
                    resizable=True)

    def traits_multiple_axes_view(self):
        """Return a view whose option row includes the axes selector."""
        figure_item = UItem('figure', editor=MPLFigureEditor(), style='custom')
        return View(figure_item,
                    Include('options_group_axes_sel'),
                    handler=MPLInitHandler,
                    resizable=True)

    def image_slider_view(self):
        """Return a small view for ``vmin``/``vmax`` and ``autoscale``.

        The min/max items are only visible while ``img_bool`` is true.
        """
        limit_items = [
            Item(trait_name, label=caption, style='custom',
                 visible_when='img_bool')
            for trait_name, caption in (('vmin', 'min'), ('vmax', 'max'))
        ]
        column = VGroup(*(limit_items + [Item('autoscale')]))
        return View(column, resizable=True)

    def traits_options_view(self):
        """Return the plot options dialog.

        The dialog shows axes selection, title/labels, font size, the grid,
        autoscale and masking toggles, and the read-only last click data.
        Masking controls are hidden while an image is shown.
        """
        label_items = [
            Item(trait_name)
            for trait_name in ('axes_selector', 'title', 'xlabel', 'ylabel',
                               'zlabel', 'fontsize')
        ]
        lines_only = 'not img_bool'
        toggles = HGroup(
            Item('grid'),
            Item('autoscale'),
            Item('mask_data_bool', label='mask data', visible_when=lines_only),
            Item('mask_length', width=-50, visible_when=lines_only),
        )
        contents = label_items + [toggles, Item('clickdata', style='readonly')]
        return View(*contents, resizable=True)
Example #15
0
class Object(HasPrivateTraits):
    """Represent a 3d object in a mayavi scene."""

    points = Array(float, shape=(None, 3))
    nn = Array(float, shape=(None, 3))
    trans = Array()
    name = Str

    # projection onto a surface
    project_to_points = Array(float, shape=(None, 3))
    project_to_tris = Array(int, shape=(None, 3))
    project_to_surface = Bool(False, label='Project',
                              desc='project points onto the surface')
    orient_to_surface = Bool(False, label='Orient',
                             desc='orient points toward the surface')
    scale_by_distance = Bool(False, label='Dist.',
                             desc='scale points by distance from the surface')
    mark_inside = Bool(False, label='Mark',
                       desc='mark points inside the surface in a different '
                            'color')

    scene = Instance(MlabSceneModel, ())
    src = Instance(VTKDataSource)

    # This should be Tuple, but it is broken on Anaconda as of 2016/12/16
    color = RGBColor((1., 1., 1.))
    inside_color = RGBColor((0., 0., 0.))
    point_scale = Float(10, label='Point Scale')
    opacity = Range(low=0., high=1., value=1.)
    visible = Bool(True)

    # don't put project_to_tris here, just always set project_to_points second
    @on_trait_change('trans,points,project_to_surface,mark_inside')
    def _update_points(self):
        """Push the transformed (and optionally projected) points to ``src``.

        ``trans`` may be empty/all-zero (no transform), a scalar or a
        length-3 / (1, 3) scale, a (3, 3) matrix or a (4, 4) affine matrix.
        When at least two ``project_to_points`` exist, each point is
        projected onto that surface. Normals and scalars of the point data
        are set from the projection.

        Returns None if ``src`` has no data yet, otherwise True. Raises
        ValueError (after showing an error dialog) for an unsupported
        ``trans`` shape.
        """
        if not hasattr(self.src, 'data'):
            return

        trans = self.trans
        if not np.any(trans):
            pts = self.points
        elif trans.ndim == 0 or trans.shape in ((3, ), (1, 3)):
            pts = self.points * trans
        elif trans.shape == (3, 3):
            pts = np.dot(self.points, trans.T)
        elif trans.shape == (4, 4):
            pts = apply_trans(trans, self.points)
        else:
            message = ("trans must be a scalar, a length 3 sequence, or an "
                       "array of shape (1,3), (3, 3) or (4, 4). "
                       "Got %s" % str(trans))
            error(None, message, "Display Error")
            raise ValueError(message)

        # Do the projection if required
        if len(self.project_to_points) > 1 and len(pts) > 0:
            surface = dict(rr=np.array(self.project_to_points),
                           tris=np.array(self.project_to_tris))
            if len(surface['rr']) <= 20484:
                method = 'accurate'
            else:
                method = 'nearest'
            projected, projected_nn = _project_onto_surface(
                pts, surface, project_rrs=True, return_nn=True,
                method=method)[2:4]
            # displacement of each point from its projection on the surface
            offset = pts - projected
            if self.project_to_surface:
                pts, normals = projected, projected_nn
            else:
                normals = offset.copy()
                _normalize_vectors(normals)
            if self.mark_inside and not self.project_to_surface:
                scalars = _points_outside_surface(pts, surface).astype(int)
            else:
                scalars = np.ones(len(pts))
            # With this, a point exactly on the surface is of size point_scale
            dist = np.linalg.norm(offset, axis=-1, keepdims=True)
            self.src.data.point_data.normals = (250 * dist + 1) * normals
            self.src.data.point_data.scalars = scalars
            self.glyph.actor.mapper.scalar_range = [0., 1.]
        self.src.data.point_data.update()
        self.src.data.points = pts
        return True
Example #16
0
class RTraceGraph(RTrace):
    '''
    Collects two response evaluators to make a response graph.

    The supplied strings for var_x and var_y are used to locate the rte in
    the current response manager. The bind method is used to navigate to
    the rte, which is stored in here as var_x_eval and var_y_eval as
    Callable objects.

    The request for new response evaluation is launched by the time loop
    and directed further by the response manager. This class is used solely
    for collecting the data, not for their visualization in the viewer.

    The timer_tick method is invoked when the visualization of the Graph
    should be synchronized with the actual contents.
    '''

    label = Str('RTraceGraph')
    var_x = Str('')
    # evaluators are looked up by bind(); transient => not pickled
    var_x_eval = Callable(transient=True)
    idx_x_arr = Array
    idx_x = Int(-1, enter_set=True, auto_set=False)
    var_y = Str('')
    var_y_eval = Callable(transient=True)
    idx_y_arr = Array
    idx_y = Int(-1, enter_set=True, auto_set=False)
    transform_x = Str(enter_set=True, auto_set=False)
    transform_y = Str(enter_set=True, auto_set=False)

    trace = Instance(MFnLineArray)

    def _trace_default(self):
        return MFnLineArray()

    print_button = ToolbarButton('Print Values',
                                 style='toolbar',
                                 transient=True)

    @on_trait_change('print_button')
    def print_values(self, event=None):
        '''Print the x and y data of the trace to stdout.'''
        # single parenthesized argument: same output under Python 2 and 3
        print('x:\t%s \ny:\t%s' % (self.trace.xdata, self.trace.ydata))

    view = View(
        VSplit(
            VGroup(
                HGroup(
                    VGroup(
                        HGroup(Spring(), Item('var_x', style='readonly'),
                               Item('idx_x', show_label=False)),
                        Item('transform_x')),
                    VGroup(
                        HGroup(Spring(), Item('var_y', style='readonly'),
                               Item('idx_y', show_label=False)),
                        Item('transform_y')),
                    VGroup('record_on', 'clear_on')),
                HGroup(Item('refresh_button', show_label=False),
                       Item('print_button', show_label=False)),
            ),
            Item('trace@',
                 editor=MFnMatplotlibEditor(adapter=MFnPlotAdapter(
                     var_x='var_x', var_y='var_y', min_size=(100, 100))),
                 show_label=False,
                 resizable=True),
        ),
        buttons=[OKButton, CancelButton],
        resizable=True,
        scrollable=True,
        height=0.5,
        width=0.5)

    # one recorded array per add_pair() call
    _xdata = List(Array(float))
    _ydata = List(Array(float))

    def bind(self):
        '''
        Locate the evaluators for var_x and var_y in the response manager.

        Raises KeyError if either name is missing from rmgr.rte_dict.
        '''
        self.var_x_eval = self.rmgr.rte_dict.get(self.var_x, None)
        if self.var_x_eval is None:
            raise KeyError('Variable %s not present in the dictionary:\n%s' %
                           (self.var_x, self.rmgr.rte_dict.keys()))

        self.var_y_eval = self.rmgr.rte_dict.get(self.var_y, None)
        if self.var_y_eval is None:
            raise KeyError('Variable %s not present in the dictionary:\n%s' %
                           (self.var_y, self.rmgr.rte_dict.keys()))

    def setup(self):
        self.clear()

    def close(self):
        self.write()

    def write(self):
        '''Write the trace as two columns (x, y) into a file in self.dir.

        The file name is built from label, var_x and var_y.
        '''
        file_base_name = 'rtrace_diagramm_%s (%s,%s).dat' % \
            (self.label, self.var_x, self.var_y)
        # full path to the data file
        file_name = os.path.join(self.dir, file_base_name)
        self.refresh()
        np.savetxt(file_name,
                   np.vstack([self.trace.xdata, self.trace.ydata]).T)

    def add_current_values(self, sctx, U_k, *args, **kw):
        '''
        Invoke the evaluators in the current context for the specified
        control vector U_k and record the flattened results.
        '''
        x = self.var_x_eval(sctx, U_k, *args, **kw)
        y = self.var_y_eval(sctx, U_k, *args, **kw)

        self.add_pair(x.flatten(), y.flatten())

    def add_pair(self, x, y):
        '''Record copies of one (x, y) pair of value arrays.'''
        self._xdata.append(np.copy(x))
        self._ydata.append(np.copy(y))

    @on_trait_change('idx_x,idx_y')
    def redraw(self, e=None):
        '''Rebuild trace.xdata / trace.ydata from the recorded values.

        For x (and likewise y), either the component idx_x of every
        recorded array is used, or, if idx_x_arr is non-empty, the sum of
        the listed components. Optional transform_x / transform_y
        expressions are then applied elementwise. Nothing happens while no
        component is selected or no data have been recorded.
        '''
        if ((self.idx_x < 0 and len(self.idx_x_arr) == 0)
                or (self.idx_y < 0 and len(self.idx_y_arr) == 0)
                or not self._xdata or not self._ydata):
            return

        if len(self.idx_x_arr) > 0:
            print('x: summation for %s' % (self.idx_x_arr,))
            xarray = np.array(self._xdata)[:, self.idx_x_arr].sum(1)
        else:
            xarray = np.array(self._xdata)[:, self.idx_x]

        if len(self.idx_y_arr) > 0:
            print('y: summation for %s' % (self.idx_y_arr,))
            yarray = np.array(self._ydata)[:, self.idx_y_arr].sum(1)
        else:
            yarray = np.array(self._ydata)[:, self.idx_y]

        if self.transform_x:

            def transform_x_fn(x):
                '''Evaluate the "transform_x" expression for one value "x".

                The expression must be written in terms of a lower case
                variable "x".
                '''
                # NOTE(review): eval() executes arbitrary code from the
                # transform_x string - only use it with trusted input.
                return eval(self.transform_x)

            xarray = np.frompyfunc(transform_x_fn, 1, 1)(xarray)

        if self.transform_y:

            def transform_y_fn(y):
                '''Evaluate the "transform_y" expression for one value "y".

                The expression must be written in terms of a lower case
                variable "y".
                '''
                # NOTE(review): eval() executes arbitrary code from the
                # transform_y string - only use it with trusted input.
                return eval(self.transform_y)

            yarray = np.frompyfunc(transform_y_fn, 1, 1)(yarray)

        self.trace.xdata = np.array(xarray)
        self.trace.ydata = np.array(yarray)
        self.trace.data_changed = True

    def timer_tick(self, e=None):
        # @todo: unify with redraw
        pass

    def clear(self):
        '''Discard all recorded values and reset the trace.'''
        self._xdata = []
        self._ydata = []
        self.trace.clear()
        self.redraw()
Example #17
0
class SurfaceObject(Object):
    """Represent a solid object in a mayavi scene.

    Notes
    -----
    Doesn't automatically update plot because update requires both
    :attr:`points` and :attr:`tri`. Call :meth:`plot` after updating both
    attributes.
    """

    rep = Enum("Surface", "Wireframe")
    tri = Array(int, shape=(None, 3))

    surf = Instance(Surface)
    surf_rear = Instance(Surface)

    view = View(
        HGroup(Item('visible', show_label=False),
               Item('color', show_label=False), Item('opacity')))

    def __init__(self, block_behind=False, **kwargs):  # noqa: D102
        # block_behind: plot() adds a second, front-face-culled, opaque copy
        # of the surface (see plot)
        self._block_behind = block_behind
        super(SurfaceObject, self).__init__(**kwargs)

    def clear(self):  # noqa: D102
        # Remove the pipeline objects that exist (unset traits are None and
        # have no ``remove``).
        for pipeline_obj in (self.src, self.surf, self.surf_rear):
            if hasattr(pipeline_obj, 'remove'):
                pipeline_obj.remove()
        # Reset surf_rear as well; otherwise a later clear() would call
        # remove() again on the already removed rear surface.
        self.reset_traits(['src', 'surf', 'surf_rear'])

    @on_trait_change('scene.activated')
    def plot(self):
        """Add the surface to the mayavi pipeline, replacing any previous one.

        Returns right after clearing when :attr:`tri` is empty (all zero).
        Otherwise the camera's parallel scale is restored after plotting.
        """
        _scale = self.scene.camera.parallel_scale
        self.clear()

        if not np.any(self.tri):
            return

        fig = self.scene.mayavi_scene
        surf = complete_surface_info(dict(rr=self.points, tris=self.tri),
                                     verbose='error')
        self.src = _create_mesh_surf(surf, fig=fig)
        rep = 'wireframe' if self.rep == 'Wireframe' else 'surface'
        surf = pipeline.surface(self.src,
                                figure=fig,
                                color=self.color,
                                representation=rep,
                                line_width=1)
        surf.actor.property.backface_culling = True
        self.surf = surf
        self.sync_trait('visible', self.surf, 'visible')
        self.sync_trait('color', self.surf.actor.property, mutual=False)
        self.sync_trait('opacity', self.surf.actor.property)
        if self._block_behind:
            # second copy showing only back faces, always fully opaque
            surf_rear = pipeline.surface(self.src,
                                         figure=fig,
                                         color=self.color,
                                         representation=rep,
                                         line_width=1)
            surf_rear.actor.property.frontface_culling = True
            self.surf_rear = surf_rear
            self.sync_trait('color',
                            self.surf_rear.actor.property,
                            mutual=False)
            self.sync_trait('visible', self.surf_rear, 'visible')
            self.surf_rear.actor.property.opacity = 1.

        self.scene.camera.parallel_scale = _scale

    @on_trait_change('trans,points')
    def _update_points(self):
        """Update the point positions and refresh the mesh source."""
        if Object._update_points(self):
            self.src.update()  # necessary for SurfaceObject since Mayavi 4.5.0
class XDomainFE(BMCSTreeNode):
    '''Finite element discretization of a (sub)domain.

    Combines the mesh ``mesh`` with the element type ``fets`` to provide
    Jacobians and kinematic (B) operators. ``map_U_to_field`` maps the
    global displacement vector to strains in the integration points.
    ``map_field_to_F`` and ``map_field_to_K`` map stresses and material
    stiffness back to global force contributions and element stiffness
    matrices (``SysMtxArray``).
    '''

    hidden = Bool(False)
    #=========================================================================
    # Type and shape specification of state variables representing the domain
    #=========================================================================
    U_var_shape = Property(Int)

    def _get_U_var_shape(self):
        return self.mesh.n_dofs

    vtk_expand_operator = Property

    def _get_vtk_expand_operator(self):
        return self.fets.vtk_expand_operator

    K_type = Type(SysMtxArray)

    state_var_shape = Property(Tuple)

    def _get_state_var_shape(self):
        # (active elements, integration points per element)
        return (self.mesh.n_active_elems, self.fets.n_m,)

    #=========================================================================
    # Methods needed by XDomain to chain the subdomains
    #=========================================================================
    dof_offset = Property

    def _get_dof_offset(self):
        return self.mesh.dof_offset

    n_active_elems = Property

    def _get_n_active_elems(self):
        return self.mesh.n_active_elems

    def set_next(self, next_):
        self.mesh.next_grid = next_.mesh

    def set_prev(self, prev):
        self.mesh.prev_grid = prev.mesh

    #=========================================================================
    # Input parameters
    #=========================================================================
    coord_min = Array(Float, value=[0., 0., 0.], GEO=True)
    '''Grid geometry specification - min corner point
    '''
    coord_max = Array(Float, value=[1., 1., 1.], MESH=True)
    '''Grid geometry specification - max corner point
    '''
    shape = Array(Int, value=[1, 1, 1], MESH=True)
    '''Number of elements in the individual dimensions
    '''
    geo_transform = Callable
    '''Geometry transformation
    '''
    integ_factor = Float(1.0, input=True, CS=True)
    '''Integration factor used to multiply the integral
    '''
    fets = Instance(IFETSEval, input=True, FE=True)
    '''Finite element type
    '''

    dim_u = Int(2)

    # np.float64 instead of the np.float_ alias removed in NumPy 2.0
    Diff1_abcd = Array(np.float64, input=True)
    '''Symmetric operator distributing the first order
    derivatives of the shape functions into the
    tensor field
    '''

    def _Diff1_abcd_default(self):
        delta = np.identity(self.dim_u)
        # symmetrization operator
        Diff1_abcd = 0.5 * (
            np.einsum('ac,bd->abcd', delta, delta) +
            np.einsum('ad,bc->abcd', delta, delta)
        )
        return Diff1_abcd

    mesh = Instance(IFEUniformDomain)

    X_Id = DelegatesTo('mesh')

    x_Eia = Property(depends_on='MESH,GEO,CS,FE')

    def _get_x_Eia(self):
        # nodal coordinates gathered per element: [element, node, coord]
        x_Ia = self.X_Id
        I_Ei = self.I_Ei
        x_Eia = x_Ia[I_Ei, :]
        return x_Eia

    x_Ema = Property(depends_on='MESH,GEO,CS,FE')

    def _get_x_Ema(self):
        # coordinates interpolated with the shape functions N_im
        return np.einsum(
            'im,Eia->Ema', self.fets.N_im, self.x_Eia
        )

    o_Ia = Property(depends_on='MESH,GEO,CS,FE')

    @cached_property
    def _get_o_Ia(self):
        # global dof numbers per node, shifted by the subdomain dof offset
        x_Ia = self.mesh.X_Id
        n_I, _ = x_Ia.shape
        n_a = self.mesh.n_nodal_dofs
        do = self.mesh.dof_offset
        return do + np.arange(n_I * n_a, dtype=np.int_).reshape(-1, n_a)

    o_Eia = Property(depends_on='MESH,GEO,CS,FE')

    @cached_property
    def _get_o_Eia(self):
        I_Ei = self.I_Ei
        return self.o_Ia[I_Ei]

    B1_Einabc = Property(depends_on='MESH,GEO,CS,FE')
    '''Kinematic mapping between displacement and strain in every
    visualization point
    '''

    @cached_property
    def _get_B1_Einabc(self):
        inv_J_Enar = np.linalg.inv(self.J_Enar)
        # NOTE(review): these subscripts bind the point axis of inv_J_Enar
        # to 'i' and use the index order 'dr', whereas _get_B1_Eimabc uses
        # 'rd'. An earlier variant was 'abcd,imr,Emrd->Eimabc'. Verify
        # before relying on B1_Einabc.
        return np.einsum(
            'abcd,imr,Eidr->Eimabc',
            self.Diff1_abcd, self.fets.dN_inr, inv_J_Enar
        )

    I_Ei = Property(depends_on='MESH,GEO,CS,FE')
    '''[element, node] -> global node
    '''

    def _get_I_Ei(self):
        return self.mesh.I_Ei

    det_J_Em = Property(depends_on='MESH,GEO,CS,FE')
    '''Determinant of the Jacobi matrix in integration points
    '''
    @cached_property
    def _get_det_J_Em(self):
        return np.linalg.det(self.J_Emar)

    J_Emar = Property(depends_on='MESH,GEO,CS,FE')
    '''Jacobi matrix in integration points
    '''
    @cached_property
    def _get_J_Emar(self):
        return np.einsum(
            'imr,Eia->Emar', self.fets.dN_imr, self.x_Eia
        )

    J_Enar = Property(depends_on='MESH,GEO,CS,FE')
    '''Jacobi matrix in nodal points
    '''
    @cached_property
    def _get_J_Enar(self):
        return np.einsum(
            'inr,Eia->Enar',
            self.fets.dN_inr, self.x_Eia
        )

    #=========================================================================
    # Conversion between linear algebra objects and field variables
    #=========================================================================
    B1_Eimabc = Property(depends_on='MESH,GEO,CS,FE')
    '''Kinematic mapping between displacements and strains in every
    integration point.
    '''
    @cached_property
    def _get_B1_Eimabc(self):
        inv_J_Emar = np.linalg.inv(self.J_Emar)
        # 'n' here labels the integration point axis (m) of dN_imr
        return np.einsum(
            'abcd,inr,Emrd->Einabc',
            self.Diff1_abcd, self.fets.dN_imr, inv_J_Emar
        )

    B_Eimabc = Property(depends_on='MESH,GEO,CS,FE')
    '''Kinematic mapping between displacements and strains in every
    integration point.
    '''
    @cached_property
    def _get_B_Eimabc(self):
        return self.B1_Eimabc

    BB_Emicjdabef = Property(depends_on='MESH,GEO,CS,FE')
    '''Quadratic form of the kinematic mapping.
    '''

    def _get_BB_Emicjdabef(self):
        # not cached: recomputed on every access
        return np.einsum(
            '...Eimabc,...Ejmefd, Em, m->...Emicjdabef',
            self.B_Eimabc, self.B_Eimabc, self.det_J_Em, self.fets.w_m
        )

    n_dofs = Property

    def _get_n_dofs(self):
        return self.mesh.n_dofs

    def U2u(self, U_Eia):
        return U_Eia

    def f2F(self, f_Eic):
        return f_Eic

    def k2K(self, K_Eicjd):
        return K_Eicjd

    def map_U_to_field(self, U):
        '''Map the global vector U to strains eps_Emab in integration points.

        Relies on ``self.xU2u``, which is not defined in this class
        (presumably provided by a subclass or mixin - TODO confirm).
        '''
        U_Eia = U[self.o_Eia]
        # coordinate transform to local
        U_Eia = self.xU2u(U_Eia)
        eps_Emab = np.einsum(
            'Eimabc,Eic->Emab',
            self.B_Eimabc, U_Eia
        )
        return eps_Emab

    def map_field_to_F(self, sig_Emab):
        '''Integrate stresses into element force vectors.

        Returns ``(dof_indices, values)`` as flat arrays. Relies on
        ``self.xf2F``, which is not defined in this class.
        '''
        f_Eic = self.integ_factor * np.einsum(
            'm,Eimabc,Emab,Em->Eic',
            self.fets.w_m, self.B_Eimabc, sig_Emab, self.det_J_Em
        )
        # coordinate transform to global
        f_Eic = self.xf2F(f_Eic)
        _, n_i, n_c = f_Eic.shape
        f_Ei = f_Eic.reshape(-1, n_i * n_c)
        o_E = self.o_Eia.reshape(-1, n_i * n_c)
        return o_E.flatten(), f_Ei.flatten()

    def map_field_to_K(self, D_Emabef):
        '''Integrate material stiffness into element stiffness matrices.

        Returns a SysMtxArray of the element matrices together with their
        dof map. Relies on ``self.xk2K``, which is not defined in this class.
        '''
        k_Eicjd = self.integ_factor * np.einsum(
            'Emicjdabef,Emabef->Eicjd',
            self.BB_Emicjdabef, D_Emabef
        )
        # coordinate transform to global
        K_Eicjd = self.xk2K(k_Eicjd)
        _, n_i, n_c, n_j, n_d = K_Eicjd.shape
        K_Eij = K_Eicjd.reshape(-1, n_i * n_c, n_j * n_d)
        o_Ei = self.o_Eia.reshape(-1, n_i * n_c)
        return SysMtxArray(mtx_arr=K_Eij, dof_map_arr=o_Ei)

    debug_cell_data = Bool(False)

    def get_vtk_cell_data(self, position, point_offset, cell_offset):
        '''Assemble the VTK cell data for all active elements.

        position: 'nodes' or 'int_pnts'. Selects
            ``fets.vtk_node_cell_data`` or ``fets.vtk_ip_cell_data`` as the
            per-element subcell template.
        point_offset, cell_offset: added to the point indices and to the
            cell offsets respectively (presumably to place this subdomain
            after previous ones in a combined dataset - TODO confirm).

        Returns flattened ``(cell_array, cell_offsets, cell_types)``.
        Raises ValueError for any other ``position``.
        '''
        if position == 'nodes':
            subcell_offsets, subcell_lengths, subcells, subcell_types = \
                self.fets.vtk_node_cell_data
        elif position == 'int_pnts':
            subcell_offsets, subcell_lengths, subcells, subcell_types = \
                self.fets.vtk_ip_cell_data
        else:
            raise ValueError(
                "position must be 'nodes' or 'int_pnts', got %r" % (position,))

        if self.debug_cell_data:
            print('subcell_offsets')
            print(subcell_offsets)
            print('subcell_lengths')
            print(subcell_lengths)
            print('subcells')
            print(subcells)
            print('subcell_types')
            print(subcell_types)

        n_subcells = subcell_types.shape[0]
        n_cell_points = self.n_cell_points
        # per element: one length entry per subcell plus its point indices
        subcell_size = subcells.shape[0] + n_subcells

        if self.debug_cell_data:
            print('n_cell_points', n_cell_points)
            print('n_cells', self.n_cells)

        vtk_cell_array = np.zeros((self.n_cells, subcell_size), dtype=int)

        # True for slots holding point indices; the slots at
        # subcell_offsets hold the subcell lengths instead
        idx_cell_pnts = np.repeat(True, subcell_size)

        if self.debug_cell_data:
            print('idx_cell_pnts')
            print(idx_cell_pnts)

        idx_cell_pnts[subcell_offsets] = False

        if self.debug_cell_data:
            print('idx_cell_pnts')
            print(idx_cell_pnts)

        idx_lengths = ~idx_cell_pnts

        if self.debug_cell_data:
            print('idx_lengths')
            print(idx_lengths)

        # first point index of every element
        point_offsets = np.arange(self.n_cells) * n_cell_points
        point_offsets += point_offset

        if self.debug_cell_data:
            print('point_offsets')
            print(point_offsets)

        vtk_cell_array[:, idx_cell_pnts] = point_offsets[
            :, None] + subcells[None, :]
        vtk_cell_array[:, idx_lengths] = subcell_lengths[None, :]

        if self.debug_cell_data:
            print('vtk_cell_array')
            print(vtk_cell_array)

        n_active_cells = self.mesh.n_active_elems

        if self.debug_cell_data:
            print('n active cells')
            print(n_active_cells)

        cell_offsets = np.arange(n_active_cells, dtype=int) * subcell_size
        cell_offsets += cell_offset
        vtk_cell_offsets = cell_offsets[:, None] + subcell_offsets[None, :]

        if self.debug_cell_data:
            print('vtk_cell_offsets')
            print(vtk_cell_offsets)

        vtk_cell_types = np.zeros(
            self.n_cells * n_subcells, dtype=int
        ).reshape(self.n_cells, n_subcells)
        vtk_cell_types += subcell_types[None, :]

        if self.debug_cell_data:
            print('vtk_cell_types')
            print(vtk_cell_types)

        return (vtk_cell_array.flatten(),
                vtk_cell_offsets.flatten(),
                vtk_cell_types.flatten())

    n_cells = Property(Int)

    def _get_n_cells(self):
        '''Return the total number of cells'''
        return self.mesh.n_active_elems

    n_cell_points = Property(Int)

    def _get_n_cell_points(self):
        '''Return the number of points defining one cell'''
        return self.fets.n_vtk_r
Example #19
0
class GPUWarp(HasTraits):
    """Resample a source image on the GPU through a per-pixel coordinate map.

    For every output pixel the OpenCL kernel samples the source image at
    ``warp_coords`` (source pixel coordinates, shifted by ``src_offset``)
    with linear filtering, multiplies the sample by ``warp_scaling`` and
    stores the result, clamped to [0, 1], as an unsigned 16 bit value in
    ``_img_warped_buf``.

    Usage: call ``set_source_image`` to upload a source frame, then
    ``warp_image`` to run the kernel.
    """
    #settings_id = 'gpuwarp'
    shape = (512, 512)  # output image shape; (1024, 1024) is an alternative
    warp_coords = Array(shape=shape + (2, ), dtype=np.float32)
    warp_scaling = Array(shape=shape, dtype=np.float32)
    # Device-side mirrors of the traits above, created in __init__ with
    # cla.zeros. They stay None until __init__ has allocated them.
    _warp_coords_buf = Python
    _warp_scaling_buf = Python
    _img_warped_buf = Python  # uint16 output image (device array)
    _img_src_buf = Python  # cl.Image holding the source ROI
    kernel = Python  #cl.Kernel

    src_shape = (0, 0)  # (roi_h, roi_w) of the last uploaded source image
    src_offset = (0, 0)  # source ROI position, subtracted in the kernel
    src_bits = 8  # bits per source pixel, 8 or 16

    kernel_src = """
    __constant sampler_t sampler_img_src = (CLK_NORMALIZED_COORDS_TRUE
                                            | CLK_ADDRESS_CLAMP 
                                            | CLK_FILTER_LINEAR
                                            );

    __kernel
    void warp(
    __global float2* warp_coords,
    __global float* warp_scaling,
    float2 warp_coords_offset,
    float2 warp_coords_scaling,
    __read_only image2d_t img_src,
    //__global uchar* img_dst
    __global ushort* img_dst
    )
    {
    const unsigned int xid = get_global_id(0);
    const unsigned int yid = get_global_id(1);
    const unsigned int id = xid + yid*get_global_size(0);
    
    const float2 pos_src = (warp_coords[id] - warp_coords_offset + 0.5f)*warp_coords_scaling;
    const float scale = warp_scaling[id];
    const float dst = clamp(scale * read_imagef(img_src, sampler_img_src, pos_src).x,
                        0.f,
                        1.f);
    //img_dst[id] = convert_uchar_sat(dst*255.f);
    img_dst[id] = convert_ushort_sat(dst*65535.f);
    }
    
"""

    def __init__(self, *args, **kwargs):
        """Allocate device buffers and build the OpenCL kernel.

        Keyword arguments are forwarded to HasTraits, so ``warp_coords`` and
        ``warp_scaling`` may be given at construction time.
        """
        super(GPUWarp, self).__init__(*args, **kwargs)
        self._warp_coords_buf = cla.zeros(cl_queue,
                                          self.shape + (2, ),
                                          dtype=np.float32)
        self._warp_scaling_buf = cla.zeros(cl_queue,
                                           self.shape,
                                           dtype=np.float32)
        self._img_warped_buf = cla.zeros(cl_queue, self.shape, dtype=np.uint16)

        # Values passed as constructor keywords were assigned (and their
        # change handlers skipped) before the buffers above existed; upload
        # the current trait values so host and device state agree.
        self._warp_coords_buf.set(self.warp_coords)
        self._warp_scaling_buf.set(self.warp_scaling)

        # init cl
        program = cl.Program(cl_context, self.kernel_src).build()
        self.kernel = program.warp

    def set_source_image(self, src):
        """Upload the ROI of ``src`` into the source image on the device.

        ``src`` must provide ``roi_h``, ``roi_w``, ``pos``, ``bits`` (8 or 16)
        and ``data_roi``. The device image is re-created only when the ROI
        size or the bit depth changed since the previous call.
        """
        src_shape = (src.roi_h, src.roi_w)  #src.data_roi.shape
        src_offset = src.pos

        if self.src_offset != src_offset:
            self.src_offset = src_offset

        if self.src_shape != src_shape or self.src_bits != src.bits:  # size changed, reinit source image buffer
            self.src_shape = src_shape
            self.src_bits = src.bits
            self._img_src_buf = cl.Image(
                cl_context,
                cl.mem_flags.HOST_WRITE_ONLY | cl.mem_flags.READ_ONLY,
                format=cl.ImageFormat(
                    cl.channel_order.R,
                    {
                        8: cl.channel_type.UNORM_INT8,
                        16: cl.channel_type.UNORM_INT16
                    }[src.bits],
                ),
                shape=(src_shape[1], src_shape[0])
                #is_array=False,
            )

        cl.enqueue_copy(
            cl_queue,
            self._img_src_buf,
            src.data_roi,
            origin=(0, 0),
            region=(src_shape[1], src_shape[0]),
        )

    def _warp_coords_changed(self, value):
        """Upload new warp coordinates to the device buffer."""
        if self._warp_coords_buf is None:
            # Assigned via a constructor keyword before __init__ allocated
            # the buffer; __init__ uploads the value afterwards.
            return
        self._warp_coords_buf.set(value)
        print('updated warp coords')

    def _warp_scaling_changed(self, value):
        """Upload new per-pixel scaling factors to the device buffer."""
        if self._warp_scaling_buf is None:
            # See _warp_coords_changed: __init__ uploads the value later.
            return
        self._warp_scaling_buf.set(value)

    def warp_image(self):
        """Run the warp kernel and block until it has finished.

        Requires a prior ``set_source_image`` call; the result is left in
        ``_img_warped_buf`` on the device.
        """
        self.kernel.set_args(
            self._warp_coords_buf.data,
            self._warp_scaling_buf.data,
            cl.cltypes.make_float2(*self.src_offset),
            # Converts pixel coordinates to the normalized coordinates the
            # sampler expects (CLK_NORMALIZED_COORDS_TRUE).
            cl.cltypes.make_float2(1. / self.src_shape[1],
                                   1. / self.src_shape[0]),
            self._img_src_buf,
            self._img_warped_buf.data,
        )
        event = cl.enqueue_nd_range_kernel(
            cl_queue,
            self.kernel,
            self.shape,
            (8, 8),
        )
        event.wait()
Example #20
0
class ODMR(ManagedJob, GetSetItemsMixin):
    """Provides ODMR measurements.

    Sweeps the microwave frequency (CW or pulsed mode) and adds the counter
    readout of every sweep to ``counts``. The most recent sweeps are kept as
    rows of ``counts_matrix`` (newest first). Optionally a sum of
    Lorentzians is fitted to the accumulated spectrum.
    """

    # starting and stopping
    keep_data = Bool(
        False)  # helper variable to decide whether to keep existing data
    resubmit_button = Button(
        label='resubmit',
        desc=
        'Submits the measurement to the job manager. Tries to keep previously acquired data. Behaves like a normal submit if sequence or time bins have changed since previous run.'
    )

    # measurement parameters
    power = Range(low=-100.,
                  high=25.,
                  value=-8,
                  desc='Power [dBm]',
                  label='Power [dBm]',
                  mode='text',
                  auto_set=False,
                  enter_set=True)
    frequency_begin = Range(low=1,
                            high=20e9,
                            value=2.85e9,
                            desc='Start Frequency [Hz]',
                            label='Begin [Hz]',
                            editor=TextEditor(auto_set=False,
                                              enter_set=True,
                                              evaluate=float,
                                              format_str='%.4e'))
    frequency_end = Range(low=1,
                          high=20e9,
                          value=2.90e9,
                          desc='Stop Frequency [Hz]',
                          label='End [Hz]',
                          editor=TextEditor(auto_set=False,
                                            enter_set=True,
                                            evaluate=float,
                                            format_str='%.4e'))
    frequency_delta = Range(low=1e-3,
                            high=20e9,
                            value=1e6,
                            desc='frequency step [Hz]',
                            label='Delta [Hz]',
                            editor=TextEditor(auto_set=False,
                                              enter_set=True,
                                              evaluate=float,
                                              format_str='%.4e'))
    t_pi = Range(low=1.,
                 high=100000.,
                 value=1000.,
                 desc='length of pi pulse [ns]',
                 label='pi [ns]',
                 mode='text',
                 auto_set=False,
                 enter_set=True)
    laser = Range(low=1.,
                  high=10000.,
                  value=300.,
                  desc='laser [ns]',
                  label='laser [ns]',
                  mode='text',
                  auto_set=False,
                  enter_set=True)
    wait = Range(low=1.,
                 high=10000.,
                 value=1000.,
                 desc='wait [ns]',
                 label='wait [ns]',
                 mode='text',
                 auto_set=False,
                 enter_set=True)
    pulsed = Bool(False, label='pulsed')
    power_p = Range(low=-100.,
                    high=25.,
                    value=-20,
                    desc='Power Pmode [dBm]',
                    label='Power[dBm]',
                    mode='text',
                    auto_set=False,
                    enter_set=True)
    frequency_begin_p = Range(low=1,
                              high=20e9,
                              value=2.87e9,
                              desc='Start Frequency Pmode[Hz]',
                              label='Begin[Hz]',
                              editor=TextEditor(auto_set=False,
                                                enter_set=True,
                                                evaluate=float,
                                                format_str='%.4e'))
    frequency_end_p = Range(low=1,
                            high=20e9,
                            value=2.88e9,
                            desc='Stop Frequency Pmode[Hz]',
                            label='End[Hz]',
                            editor=TextEditor(auto_set=False,
                                              enter_set=True,
                                              evaluate=float,
                                              format_str='%.4e'))
    frequency_delta_p = Range(low=1e-3,
                              high=20e9,
                              value=1e5,
                              desc='frequency step Pmode[Hz]',
                              label='Delta[Hz]',
                              editor=TextEditor(auto_set=False,
                                                enter_set=True,
                                                evaluate=float,
                                                format_str='%.4e'))
    seconds_per_point = Range(low=20e-3,
                              high=1,
                              value=20e-3,
                              desc='Seconds per point',
                              label='Seconds per point',
                              mode='text',
                              auto_set=False,
                              enter_set=True)
    stop_time = Range(
        low=1.,
        value=np.inf,
        desc='Time after which the experiment stops by itself [s]',
        label='Stop time [s]',
        mode='text',
        auto_set=False,
        enter_set=True)
    n_lines = Range(low=1,
                    high=10000,
                    value=50,
                    desc='Number of lines in Matrix',
                    label='Matrix lines',
                    mode='text',
                    auto_set=False,
                    enter_set=True)

    # control data fitting
    perform_fit = Bool(False, label='perform fit')
    number_of_resonances = Trait(
        'auto', String('auto', auto_set=False, enter_set=True),
        Int(10000.,
            desc='Number of Lorentzians used in fit',
            label='N',
            auto_set=False,
            enter_set=True))
    threshold = Range(
        low=-99,
        high=99.,
        value=-50.,
        desc=
        'Threshold for detection of resonances [%]. The sign of the threshold specifies whether the resonances are negative or positive.',
        label='threshold [%]',
        mode='text',
        auto_set=False,
        enter_set=True)

    # fit result
    fit_parameters = Array(value=np.array((np.nan, np.nan, np.nan, np.nan)))
    fit_frequencies = Array(value=np.array((np.nan, )), label='frequency [Hz]')
    fit_line_width = Array(value=np.array((np.nan, )), label='line_width [Hz]')
    fit_contrast = Array(value=np.array((np.nan, )), label='contrast [%]')

    # measurement data
    frequency = Array()
    counts = Array()
    counts_matrix = Array()
    run_time = Float(value=0.0, desc='Run time [s]', label='Run time [s]')

    # plotting
    line_label = Instance(PlotLabel)
    line_data = Instance(ArrayPlotData)
    matrix_data = Instance(ArrayPlotData)
    line_plot = Instance(Plot, editor=ComponentEditor())
    matrix_plot = Instance(Plot, editor=ComponentEditor())

    # added by Farida, Magnetic field calculation
    # NOTE(review): the label says GHz, but _update_magnetic_field stores the
    # mean fitted frequency in Hz here -- confirm the intended unit.
    central_dip_position = Float(value=0.0,
                                 desc='Centr freq [GHz]',
                                 label='Centr freq [GHz]',
                                 editor=TextEditor(auto_set=False,
                                                   enter_set=False,
                                                   evaluate=float,
                                                   format_str='%.4e'))
    magnetic_field = Float(value=0.0,
                           desc='Magnetic field [G]',
                           label='Magnetic field [G]',
                           editor=TextEditor(auto_set=False,
                                             enter_set=False,
                                             evaluate=float,
                                             format_str='%.1f'))
    crossing = Bool(False, label='crossing')

    def __init__(self):
        super(ODMR, self).__init__()
        self._create_line_plot()
        self._create_matrix_plot()
        self.on_trait_change(self._update_line_data_index,
                             'frequency',
                             dispatch='ui')
        self.on_trait_change(self._update_line_data_value,
                             'counts',
                             dispatch='ui')
        self.on_trait_change(self._update_line_data_fit,
                             'fit_parameters',
                             dispatch='ui')
        self.on_trait_change(self._update_matrix_data_value,
                             'counts_matrix',
                             dispatch='ui')
        self.on_trait_change(self._update_matrix_data_index,
                             'n_lines,frequency',
                             dispatch='ui')
        self.on_trait_change(
            self._update_fit,
            'counts,perform_fit,number_of_resonances,threshold',
            dispatch='ui')
        self.on_trait_change(self._update_magnetic_field,
                             'counts',
                             dispatch='ui')

    def _counts_matrix_default(self):
        return np.zeros((self.n_lines, len(self.frequency)))

    def _frequency_array(self):
        """Return the frequency grid [Hz] of the currently selected mode.

        The ``np.arange`` stop value is extended by one step so that the end
        frequency is normally part of the grid.
        """
        if self.pulsed:
            begin = self.frequency_begin_p
            end = self.frequency_end_p
            delta = self.frequency_delta_p
        else:
            begin = self.frequency_begin
            end = self.frequency_end
            delta = self.frequency_delta
        return np.arange(begin, end + delta, delta)

    def _frequency_default(self):
        return self._frequency_array()

    def _counts_default(self):
        return np.zeros(self.frequency.shape)

    # data acquisition

    def apply_parameters(self):
        """Apply the current parameters and decide whether to keep previous data.

        Previous data is discarded on a fresh submission (``keep_data`` is
        False) or whenever the frequency grid differs from the stored one.
        """
        frequency = self._frequency_array()

        # Compare shapes first: an elementwise '!=' between arrays of
        # different lengths cannot broadcast.
        if (not self.keep_data or frequency.shape != self.frequency.shape
                or np.any(frequency != self.frequency)):
            self.frequency = frequency
            self.counts = np.zeros(frequency.shape)
            # The 'frequency' listener that also resets the matrix runs on
            # the UI thread; reset it here too so _run never stacks a sweep
            # onto a matrix of the old width.
            self.counts_matrix = np.zeros((self.n_lines, len(frequency)))
            self.run_time = 0.0

        self.keep_data = False  # when job manager stops and starts the job, data should be kept. Only new submission should clear data.

    def _reset_microwave_output(self):
        """Call ``setOutput(None, ...)`` with the start frequency of the active mode.

        Presumably this switches the microwave power off -- confirm against
        the hardware API.
        """
        if self.pulsed:
            ha.Microwave().setOutput(None, self.frequency_begin_p)
        else:
            ha.Microwave().setOutput(None, self.frequency_begin)

    def _run(self):
        """Acquire sweeps until a stop is requested or ``stop_time`` is reached."""
        try:
            self.state = 'run'
            self.apply_parameters()

            if self.run_time >= self.stop_time:
                self.state = 'done'
                return

            # if pulsed, turn on sequence
            if self.pulsed:
                ha.PulseGenerator().Sequence([(['laser'], self.laser),
                                              ([], self.wait),
                                              (['mw'], self.t_pi), ([], 15)])
                power = self.power_p
            else:
                ha.PulseGenerator().Open()
                power = self.power

            n = len(self.frequency)

            ha.Microwave().setPower(power)
            ha.Microwave().initSweep(self.frequency,
                                     power * np.ones(self.frequency.shape))
            ha.Counter().configure(n, self.seconds_per_point, DutyCycle=0.8)
            time.sleep(0.5)

            while self.run_time < self.stop_time:
                start_time = time.time()
                if threading.currentThread().stop_request.isSet():
                    break
                ha.Microwave().resetListPos()
                counts = ha.Counter().run() / 1e3
                self.run_time += time.time() - start_time
                self.counts += counts
                # newest sweep on top, oldest row dropped
                self.counts_matrix = np.vstack(
                    (counts, self.counts_matrix[:-1, :]))
                # counts was modified in place, so notify listeners explicitly
                self.trait_property_changed('counts', self.counts)

            if self.run_time < self.stop_time:
                self.state = 'idle'
            else:
                self.state = 'done'
            self._reset_microwave_output()
            ha.PulseGenerator().Light()
            ha.Counter().clear()
        except Exception:
            logging.getLogger().exception('Error in odmr.')
            self.state = 'error'
        finally:
            self._reset_microwave_output()

    # fitting

    @staticmethod
    def _line_contrast(p):
        """Return the contrast [%] of every Lorentzian in the fit vector ``p``.

        ``p`` is laid out as ``(c, f1, w1, a1, f2, w2, a2, ...)``: an offset
        ``c`` followed by one (frequency, width, amplitude) triple per line.
        A positive amplitude is treated as a peak on top of ``c``, anything
        else as a dip below ``c``.
        """
        n_res = len(p) // 3
        contrast = np.empty(n_res)
        offset = p[0]
        for i, (_, width, amp) in enumerate(p[1:].reshape((n_res, 3))):
            height = np.abs(amp / (np.pi * width))
            if amp > 0:
                contrast[i] = 100 * height / (height + offset)
            else:
                contrast[i] = 100 * height / offset
        return contrast

    def _update_fit(self):
        """Refit the spectrum (or clear the fit) and publish the fit results."""
        if self.perform_fit:
            N = self.number_of_resonances
            if N != 'auto':
                N = int(N)
            try:
                p = fitting.fit_multiple_lorentzians(self.frequency,
                                                     self.counts,
                                                     N,
                                                     threshold=self.threshold *
                                                     0.01)
            except Exception:
                logging.getLogger().debug('ODMR fit failed.', exc_info=True)
                p = np.nan * np.empty(4)
        else:
            p = np.nan * np.empty(4)
        self.fit_parameters = p
        self.fit_frequencies = p[1::3]
        self.fit_line_width = p[2::3]
        self.fit_contrast = self._line_contrast(p)

    # plotting

    def _create_line_plot(self):
        line_data = ArrayPlotData(frequency=np.array((0., 1.)),
                                  counts=np.array((0., 0.)),
                                  fit=np.array((0., 0.)))
        line_plot = Plot(line_data,
                         padding=8,
                         padding_left=64,
                         padding_bottom=32)
        line_plot.plot(('frequency', 'counts'), style='line', color='blue')
        line_plot.index_axis.title = 'Frequency [MHz]'
        line_plot.value_axis.title = 'Fluorescence/k'
        line_label = PlotLabel(text='',
                               hjustify='left',
                               vjustify='bottom',
                               position=[64, 128])
        line_plot.overlays.append(line_label)
        self.line_label = line_label
        self.line_data = line_data
        self.line_plot = line_plot

    def _create_matrix_plot(self):
        matrix_data = ArrayPlotData(image=np.zeros((2, 2)))
        matrix_plot = Plot(matrix_data,
                           padding=8,
                           padding_left=64,
                           padding_bottom=32)
        matrix_plot.index_axis.title = 'Frequency [MHz]'
        matrix_plot.value_axis.title = 'line #'
        # bounds in MHz, matching the axis title and _update_matrix_data_index
        matrix_plot.img_plot('image',
                             xbounds=(self.frequency[0] * 1e-6,
                                      self.frequency[-1] * 1e-6),
                             ybounds=(0, self.n_lines),
                             colormap=Spectral)
        self.matrix_data = matrix_data
        self.matrix_plot = matrix_plot

    def _perform_fit_changed(self, new):
        plot = self.line_plot
        if new:
            plot.plot(('frequency', 'fit'),
                      style='line',
                      color='red',
                      name='fit')
            self.line_label.visible = True
        else:
            plot.delplot('fit')
            self.line_label.visible = False
        plot.request_redraw()

    def _update_line_data_index(self):
        self.line_data.set_data('frequency', self.frequency * 1e-6)
        self.counts_matrix = self._counts_matrix_default()

    def _update_line_data_value(self):
        self.line_data.set_data('counts', self.counts)

    def _update_line_data_fit(self):
        """Update the fit curve and the fit-result label of the line plot."""
        if not np.isnan(self.fit_parameters[0]):
            self.line_data.set_data(
                'fit',
                fitting.NLorentzians(*self.fit_parameters)(self.frequency))
            p = self.fit_parameters
            f = p[1::3]
            w = p[2::3]
            contrast = self._line_contrast(p)
            self.line_label.text = ''.join(
                'f %i: %.6e Hz, HWHM %.3e Hz, contrast %.1f%%\n' %
                (i + 1, fi, w[i], contrast[i]) for i, fi in enumerate(f))

    def _update_matrix_data_value(self):
        self.matrix_data.set_data('image', self.counts_matrix)

    def _update_matrix_data_index(self):
        """Pad or truncate the matrix to ``n_lines`` rows and update its bounds."""
        if self.n_lines > self.counts_matrix.shape[0]:
            self.counts_matrix = np.vstack(
                (self.counts_matrix,
                 np.zeros((self.n_lines - self.counts_matrix.shape[0],
                           self.counts_matrix.shape[1]))))
        else:
            self.counts_matrix = self.counts_matrix[:self.n_lines]
        self.matrix_plot.components[0].index.set_data(
            (self.frequency[0] * 1e-6, self.frequency[-1] * 1e-6),
            (0.0, float(self.n_lines)))

    # saving data

    def save_line_plot(self, filename):
        self.save_figure(self.line_plot, filename + '_ODMR_Line_Plot.png')

    def save_matrix_plot(self, filename):
        self.save_figure(self.matrix_plot, filename)

    def save_all(self, filename):
        self.save_line_plot(filename)
        self.save_matrix_plot(filename + '_ODMR_Matrix_Plot.png')
        self.save(filename + '_ODMR.pyd')

    # react to GUI events

    def submit(self):
        """Submit the job to the JobManager, discarding previous data."""
        self.keep_data = False
        ManagedJob.submit(self)
        print(threading.currentThread().getName())

    def resubmit(self):
        """Submit the job to the JobManager, trying to keep previous data."""
        self.keep_data = True
        ManagedJob.submit(self)

    def _resubmit_button_fired(self):
        """React to the resubmit button. Resubmit the job."""
        self.resubmit()

    # Calculation of the center between 2 lines

    def _update_magnetic_field(self):
        """Derive a magnetic field [G] and the mean line position from the fit.

        Uses the first fitted frequency f [MHz]: B = (2870 - f) / 2.8, or
        (2870 + f) / 2.8 when ``crossing`` is set (2870 MHz and 2.8 MHz/G
        presumably being the NV zero-field splitting and gyromagnetic ratio).
        """
        if len(self.fit_frequencies) > 0:
            if self.crossing:
                self.magnetic_field = (2870 +
                                       self.fit_frequencies[0] * 1e-6) / 2.8
            else:
                self.magnetic_field = (2870 -
                                       self.fit_frequencies[0] * 1e-6) / 2.8
            self.central_dip_position = self.fit_frequencies.sum() / len(
                self.fit_frequencies)
        return self.magnetic_field, self.central_dip_position * 1e-9

    traits_view = View(
        VGroup(
            HGroup(
                Item('submit_button', show_label=False),
                Item('remove_button', show_label=False),
                Item('resubmit_button', show_label=False),
                Item('priority', enabled_when='state != "run"'),
                Item('state', style='readonly'),
                Item('run_time', style='readonly', format_str='%.f'),
                Item('stop_time'),
            ),
            Group(
                HGroup(
                    Item('pulsed', enabled_when='state != "run"'),
                    Item('power_p', width=-40, enabled_when='state != "run"'),
                    Item('frequency_begin_p',
                         width=-80,
                         enabled_when='state != "run"'),
                    Item('frequency_end_p',
                         width=-80,
                         enabled_when='state != "run"'),
                    Item('frequency_delta_p',
                         width=-80,
                         enabled_when='state != "run"'),
                    Item('t_pi', width=-50, enabled_when='state != "run"'),
                    label='Pulsed Mode',
                ),
                HGroup(
                    Item('power', width=-40, enabled_when='state != "run"'),
                    Item('frequency_begin',
                         width=-80,
                         enabled_when='state != "run"'),
                    Item('frequency_end',
                         width=-80,
                         enabled_when='state != "run"'),
                    Item('frequency_delta',
                         width=-80,
                         enabled_when='state != "run"'),
                    label='CW Mode',
                ),
                layout='tabbed',
            ),
            HGroup(
                Item('seconds_per_point',
                     width=-40,
                     enabled_when='state != "run"'),
                Item('laser', width=-50, enabled_when='state != "run"'),
                Item('wait', width=-50, enabled_when='state != "run"'),
            ),
            HGroup(
                Item('perform_fit'),
                Item('number_of_resonances', width=-60),
                Item('threshold', width=-60),
                Item('n_lines', width=-60),
            ),
            HGroup(
                Item('fit_contrast', width=-90, style='readonly'),
                Item('fit_line_width', width=-90, style='readonly'),
                Item('fit_frequencies', width=-90, style='readonly'),
            ),
            HGroup(
                Item('crossing', enabled_when='state != "run"'),
                Item('central_dip_position', style='readonly'),
                Item('magnetic_field', style='readonly'),
            ),
            VSplit(
                Item('line_plot', show_label=False, resizable=True),
                Item('matrix_plot', show_label=False, resizable=True),
            ),
        ),
        menubar=MenuBar(
            Menu(Action(action='saveLinePlot', name='SaveLinePlot (.png)'),
                 Action(action='saveMatrixPlot',
                        name='SaveMatrixPlot (.png)'),
                 Action(action='save', name='Save (.pyd or .pys)'),
                 Action(action='saveAll', name='Save All (.png+.pyd)'),
                 Action(action='export', name='Export as Ascii (.asc)'),
                 Action(action='load', name='Load'),
                 Action(action='_on_close', name='Quit'),
                 name='File')),
        title='ODMR',
        width=900,
        height=800,
        buttons=[],
        resizable=True,
        handler=ODMRHandler)

    get_set_items = [
        'frequency', 'counts', 'counts_matrix', 'fit_parameters',
        'fit_contrast', 'fit_line_width', 'fit_frequencies', 'perform_fit',
        'run_time', 'power', 'frequency_begin', 'frequency_end',
        'frequency_delta', 'power_p', 'frequency_begin_p', 'frequency_end_p',
        'frequency_delta_p', 'laser', 'wait', 'pulsed', 't_pi',
        'seconds_per_point', 'stop_time', 'n_lines', 'number_of_resonances',
        'threshold', '__doc__'
    ]
Example #21
0
class CellSpec( HasTraits ):
    '''Specification of a single cell given by its node coordinates.

    The helper methods treat the distinct coordinate values found along each
    axis of ``node_coords`` as the grid lines of the cell.
    '''
    node_coords = Array( float, value = [[-1, -1, 0],
                         [ 1, -1, 0],
                         [ 1, 1, 0],
                         [ 1, 1, 1],
                         [-1, 1, -1],
                         [-1 / 2., 1, 0],
                         [ 0.  , 1, 0],
                         [ 1 / 2., 1, 0]] )

    traits_view = View( Item( 'node_coords', style = 'readonly' ),
                       resizable = True,
                       height = 0.5,
                       width = 0.5 )

    _node_array = Property( Array( 'float_' ), depends_on = 'node_coords' )
    @cached_property
    def _get__node_array( self ):
        '''Return the node coordinates as a float array.

        No validation of the node layout (e.g. equidistance) is performed.
        '''
        return array( self.node_coords, 'float_' )

    n_dims = Property( depends_on = 'node_coords' )
    @cached_property
    def _get_n_dims( self ):
        '''Return the number of dimensions of the cell (columns of the node array).
        '''
        return self._node_array.shape[1]

    def get_cell_shape( self ):
        '''Return the shape of the cell grid.

        The result is an int array with one entry per dimension (length
        ``n_dims``): the number of distinct node coordinates along that axis.
        '''
        narray = self._node_array
        return array( [len( unique( narray[:, i] ) )
                       for i in range( self.n_dims ) ], dtype = int )

    def get_cell_slices( self ):
        '''Return slices for the generation of the cell grid.

        One ``slice(min, max, n*1j)`` per dimension, where ``n`` is the number
        of distinct coordinates on that axis; the complex step makes the slice
        usable with ``numpy.mgrid`` as an inclusive point count.
        '''
        ndims = self.n_dims
        narray = self._node_array
        return tuple( [ slice(
                              min( narray[:, i] ),
                              max( narray[:, i] ),
                              complex( 0, len( unique( narray[:, i] ) ) ),
                              )
                        for i in range( ndims ) ] )

    #-------------------------------------------------------------------
    # Visualization-related specification
    #-------------------------------------------------------------------
    cell_lines = Array( int, value = [[0, 1], [1, 2], [2, 0]] )
    cell_faces = Array( int, value = [[0, 1, 2]] )
Example #22
0
class SysMtxArray(HasTraits):
    '''Class managing an array of equally sized matrices with
    enumerated equations.

    It is used as an intermediate result of finite element integration
    of the system matrix. From this format, it is converted to a sparse
    matrix representation. For example, for the coord (COO) format, the
    conversion is possible by flattening the value array using ravel
    (see e.g. coo_mtx).

    ``dof_map_arr[el]`` lists the global dof numbers of element ``el`` and
    ``mtx_arr[el]`` is that element's matrix, indexed ``[el, i, j]``.
    NOTE(review): presumably shapes (n_elems, n_e_dofs) and
    (n_elems, n_e_dofs, n_e_dofs) -- confirm against the assembler.
    '''
    dof_map_arr = Array('int_')
    # ``float`` is float64, the same dtype as the legacy 'float_' alias.
    mtx_arr = Array(float)

    n_dofs = Property(depends_on='dof_map_arr')

    @cached_property
    def _get_n_dofs(self):
        '''Number of dofs: highest dof number in the map plus one, or 0 for an
        empty map.'''
        # ``size`` rather than ``shape[0]``: maps such as shape (n, 0) are
        # empty too, and ``max()`` would raise on them.
        if self.dof_map_arr.size == 0:
            return 0
        return self.dof_map_arr.max() + 1

    def get_dof_ix_array(self, dof):
        '''Return the element numbers and local indices of ``dof``.

        For a 2-D ``dof_map_arr`` this is the ``(el_arr, row_arr)`` tuple
        produced by ``where``, with ``dof_map_arr[el_arr[k], row_arr[k]] ==
        dof``; it is the ``dof_ix_array`` argument of the methods below.
        '''
        return where(self.dof_map_arr == dof)

    def _zero_rc(self, dof_ix_array):
        '''Set the row and the column associated with the dof to zero in
        every element matrix that contains it.
        '''
        el_arr, row_arr = dof_ix_array
        for el, i_dof in zip(el_arr, row_arr):
            self.mtx_arr[el, i_dof, :] = 0.0
            self.mtx_arr[el, :, i_dof] = 0.0

    def _add_col_to_vector(self, dof_ix_array, F, factor):
        '''Add ``factor`` times the dof's column of each element matrix to
        ``F`` (modified in place) at the element's global dof rows.

        Used for the implementation of the essential boundary conditions.
        '''
        el_arr, row_arr = dof_ix_array
        for el, i_dof in zip(el_arr, row_arr):
            rows = self.dof_map_arr[el]
            # NOTE(review): fancy-index ``+=`` does not accumulate repeated
            # indices; this assumes an element's dofs are distinct.
            F[rows] += factor * self.mtx_arr[el, :, i_dof]

    def _get_diag_elem(self, dof_ix_array):
        '''Return the assembled diagonal value of the dof: the sum of its
        diagonal entries over all element matrices that contain it.
        '''
        K_aa = 0.
        el_arr, row_arr = dof_ix_array
        for el, i_dof in zip(el_arr, row_arr):
            K_aa += self.mtx_arr[el, i_dof, i_dof]
        return K_aa

    def _add_diag_elem(self, dof_ix_array, K_aa):
        '''Set the diagonal entry of the dof to ``K_aa``.

        Despite the name the value is assigned, not added, and only in the
        first element listed in ``dof_ix_array``; other elements' entries
        are left unchanged.
        '''
        el_arr, row_arr = dof_ix_array
        el = el_arr[0]
        i_dof = row_arr[0]
        self.mtx_arr[el, i_dof, i_dof] = K_aa

    def _get_col_subvector(self, dof_ix_array):
        '''Return the dof's column entries from every element matrix that
        contains it, with their global row numbers.

        Returns ``(idx_arr, val_arr)``, flat arrays holding, element after
        element, ``dof_map_arr[el]`` and ``mtx_arr[el, :, i_dof]``.
        '''
        el_arr, row_arr = dof_ix_array
        el_idx = array(el_arr, dtype=int)
        i_idx = array(row_arr, dtype=int)
        # One gather instead of ``append`` in a loop, which re-copied the
        # growing arrays on every iteration (quadratic in the hit count).
        idx_arr = self.dof_map_arr[el_idx].ravel()
        # Advanced indices separated by a slice put the broadcast
        # (el, i_dof) axis first: shape (n_hits, n_e_dofs), row-major like
        # the former element-by-element concatenation.
        val_arr = self.mtx_arr[el_idx, :, i_idx].ravel()
        return idx_arr, val_arr
Example #23
0
class VolumeSlicer(HasTraits):
    '''Volume viewer with one 3-D scene and three 2-D side views.

    The 3-D scene shows an outline of ``data`` and one image plane widget
    per axis. Each side view (``scene_x``, ``scene_y``, ``scene_z``) shows
    one image plane widget whose slice position is synchronized with the
    matching 3-D widget; interacting with a side view moves the 3-D widgets
    of the two other axes to the cursor position.
    '''
    # The data to plot
    # NOTE(review): presumably a 3-D scalar array indexed (x, y, z), as the
    # use of ``self.data.shape[self._axis_names[...]]`` suggests -- confirm.
    data = Array()

    # The 4 views displayed
    scene3d = Instance(MlabSceneModel, ())
    scene_x = Instance(MlabSceneModel, ())
    scene_y = Instance(MlabSceneModel, ())
    scene_z = Instance(MlabSceneModel, ())

    # The data source (built from ``data`` by ``_data_src3d_default``)
    data_src3d = Instance(Source)

    # The image plane widgets of the 3D scene
    ipw_3d_x = Instance(PipelineBase)
    ipw_3d_y = Instance(PipelineBase)
    ipw_3d_z = Instance(PipelineBase)

    # Axis name -> index of that axis in ``data.shape`` and in the cursor
    # position returned by the image plane widgets.
    _axis_names = dict(x=0, y=1, z=2)

    #---------------------------------------------------------------------------
    def __init__(self, **traits):
        super(VolumeSlicer, self).__init__(**traits)
        # Force the creation of the image_plane_widgets:
        # reading each trait runs its ``_ipw_3d_*_default`` initializer.
        self.ipw_3d_x
        self.ipw_3d_y
        self.ipw_3d_z

    #---------------------------------------------------------------------------
    # Default values
    #---------------------------------------------------------------------------
    def _data_src3d_default(self):
        '''Create the scalar-field source for ``data`` in the 3-D scene.'''
        return mlab.pipeline.scalar_field(self.data,
                                          figure=self.scene3d.mayavi_scene)

    def make_ipw_3d(self, axis_name):
        '''Create an image plane widget on ``data_src3d`` in the 3-D scene
        with plane orientation ``'<axis_name>_axes'``.'''
        ipw = mlab.pipeline.image_plane_widget(
            self.data_src3d,
            figure=self.scene3d.mayavi_scene,
            plane_orientation='%s_axes' % axis_name)
        return ipw

    def _ipw_3d_x_default(self):
        return self.make_ipw_3d('x')

    def _ipw_3d_y_default(self):
        return self.make_ipw_3d('y')

    def _ipw_3d_z_default(self):
        return self.make_ipw_3d('z')

    #---------------------------------------------------------------------------
    # Scene activation callbacks
    #---------------------------------------------------------------------------
    @on_trait_change('scene3d.activated')
    def display_scene3d(self):
        '''Populate and configure the 3-D scene once it is activated.'''
        # NOTE(review): ``outline`` is not used afterwards; the call is only
        # needed for adding the outline to the 3-D scene.
        outline = mlab.pipeline.outline(
            self.data_src3d,
            figure=self.scene3d.mayavi_scene,
        )
        self.scene3d.mlab.view(40, 50)
        # Interaction properties can only be changed after the scene
        # has been created, and thus the interactor exists
        for ipw in (self.ipw_3d_x, self.ipw_3d_y, self.ipw_3d_z):
            # Turn the interaction off
            ipw.ipw.interaction = 0
        self.scene3d.scene.background = (0, 0, 0)
        # Keep the view always pointing up
        self.scene3d.scene.interactor.interactor_style = \
                                 tvtk.InteractorStyleTerrain()

    def make_side_view(self, axis_name):
        '''Populate the 2-D side view ``scene_<axis_name>``.

        Adds an outline and an image plane widget (stored as attribute
        ``ipw_<axis_name>``) whose slice position is synced with
        ``ipw_3d_<axis_name>``, registers callbacks that move the other two
        axes' 3-D widgets to the cursor, centres the slice, and sets the
        camera, a 2-D interaction style and a black background.
        '''
        scene = getattr(self, 'scene_%s' % axis_name)

        # To avoid copying the data, we take a reference to the
        # raw VTK dataset, and pass it on to mlab. Mlab will create
        # a Mayavi source from the VTK without copying it.
        # We have to specify the figure so that the data gets
        # added on the figure we are interested in.
        outline = mlab.pipeline.outline(
            self.data_src3d.mlab_source.dataset,
            figure=scene.mayavi_scene,
        )
        ipw = mlab.pipeline.image_plane_widget(outline,
                                               plane_orientation='%s_axes' %
                                               axis_name)
        setattr(self, 'ipw_%s' % axis_name, ipw)

        # Synchronize positions between the corresponding image plane
        # widgets on different views.
        ipw.ipw.sync_trait('slice_position',
                           getattr(self, 'ipw_3d_%s' % axis_name).ipw)

        # Make left-clicking create a crosshair
        ipw.ipw.left_button_action = 0

        # Add a callback on the image plane widget interaction to
        # move the others
        def move_view(obj, evt):
            # Move the 3-D widget of every *other* axis to the matching
            # component of this widget's cursor position.
            position = obj.GetCurrentCursorPosition()
            for other_axis, axis_number in self._axis_names.items():
                if other_axis == axis_name:
                    continue
                ipw3d = getattr(self, 'ipw_3d_%s' % other_axis)
                ipw3d.ipw.slice_position = position[axis_number]

        ipw.ipw.add_observer('InteractionEvent', move_view)
        ipw.ipw.add_observer('StartInteractionEvent', move_view)

        # Center the image plane widget
        ipw.ipw.slice_position = 0.5 * self.data.shape[
            self._axis_names[axis_name]]

        # Position the view for the scene
        views = dict(
            x=(0, 90),
            y=(90, 90),
            z=(0, 0),
        )
        scene.mlab.view(*views[axis_name])
        # 2D interaction: only pan and zoom
        scene.scene.interactor.interactor_style = \
                                 tvtk.InteractorStyleImage()
        scene.scene.background = (0, 0, 0)

    @on_trait_change('scene_x.activated')
    def display_scene_x(self):
        return self.make_side_view('x')

    @on_trait_change('scene_y.activated')
    def display_scene_y(self):
        return self.make_side_view('y')

    @on_trait_change('scene_z.activated')
    def display_scene_z(self):
        return self.make_side_view('z')

    #---------------------------------------------------------------------------
    # The layout of the dialog created
    #---------------------------------------------------------------------------
    # Two columns: the y and z side views on the left, the x side view and
    # the 3-D scene on the right.
    view = View(
        HGroup(
            Group(
                Item('scene_y',
                     editor=SceneEditor(scene_class=Scene),
                     height=250,
                     width=300),
                Item('scene_z',
                     editor=SceneEditor(scene_class=Scene),
                     height=250,
                     width=300),
                show_labels=False,
            ),
            Group(
                Item('scene_x',
                     editor=SceneEditor(scene_class=Scene),
                     height=250,
                     width=300),
                Item('scene3d',
                     editor=SceneEditor(scene_class=MayaviScene),
                     height=250,
                     width=300),
                show_labels=False,
            ),
        ),
        resizable=True,
        title='Volume Slicer',
    )
Example #24
0
class MGridCell(SDomain):
    '''
    A single mgrid cell for geometrical representation of the domain.

    Based on the grid_cell_spec attribute, the node distribution is
    determined: the full ``mgrid`` point set of the cell is generated from
    the spec's slices, and ``node_map`` locates the spec's nodes in it.
    '''
    # Everything depends on the grid_cell_specification
    #
    grid_cell_spec = Instance(MGridCellSpec)

    def _grid_cell_spec_default(self):
        return MGridCellSpec()

    # Generated grid cell coordinates as they come from mgrid.
    # The dimensionality of the mgrid comes from the
    # grid_cell_spec_attribute
    #
    grid_cell_coords = Property(depends_on='grid_cell_spec')

    @cached_property
    def _get_grid_cell_coords(self):
        '''Return the grid points, one row per point, one column per
        dimension, in ``mgrid`` (row-major) order.'''
        grid_cell = mgrid[self.grid_cell_spec.get_cell_slices()]
        return c_[tuple([x.flatten() for x in grid_cell])]

    # Node map lists the active nodes within the grid cell
    # in the specified order
    #
    node_map = Property(Array(int), depends_on='grid_cell_spec')

    @cached_property
    def _get_node_map(self):
        '''Return, for each spec node in order, the index of the row of
        ``grid_cell_coords`` equal to it.

        Nodes are matched by exact float equality; a node without a matching
        grid point is skipped.
        '''
        n_map = []
        for node in self.grid_cell_spec._node_array:
            for idx, grid_cell_node in enumerate(self.grid_cell_coords):
                # ``ndarray.all`` instead of ``numpy.alltrue``, which was
                # removed in NumPy 2.0.
                if (node == grid_cell_node).all():
                    n_map.append(idx)
                    # Grid points are distinct: stop at the first match.
                    break
        return array(n_map, int)

    #-----------------------------------------------------------------
    # Visualization related methods
    #-----------------------------------------------------------------
    mvp_mgrid_ngeo_labels = Trait(MVPointLabels)

    def _mvp_mgrid_ngeo_labels_default(self):
        return MVPointLabels(name='Geo node numbers',
                             points=self._get_points,
                             scalars=self._get_node_distribution,
                             color=(153, 204, 0))

    refresh_button = Button('Draw')

    @on_trait_change('refresh_button')
    def redraw(self):
        '''Redraw the geometry node-number labels.'''
        # The receiver must be named ``self``: it was previously misnamed
        # ``selfMGridCell``, so the body raised NameError on every click.
        self.mvp_mgrid_ngeo_labels.redraw('label_scalars')

    def _get_points(self):
        '''Return the grid points padded with zero columns to three
        coordinates (returned unchanged if they already have three or more).
        '''
        points = self.grid_cell_coords
        n_points, n_dims = points.shape
        if n_dims >= 3:
            return points
        padded = zeros((n_points, 3), dtype=float)
        padded[:, 0:n_dims] = points
        return padded

    def _get_node_distribution(self):
        '''Return one scalar per grid point: its position in ``node_map`` if
        the point is a cell node, -1 otherwise.'''
        n_points = self.grid_cell_coords.shape[0]
        full_node_map = ones(n_points, dtype=float) * -1.
        full_node_map[ix_(self.node_map)] = arange(len(self.node_map))
        return full_node_map

    #------------------------------------------------------------------
    # UI - related methods
    #------------------------------------------------------------------
    traits_view = View(Item('grid_cell_spec'),
                       Item('refresh_button'),
                       Item('node_map'),
                       resizable=True,
                       height=0.5,
                       width=0.5)
Example #25
0
class EMFields(HasTraits):
    """ A collection of EM fields output from Sim4Life EM simulations

    Loads a ``.mat`` export (via ``loadmat``), lists the fields it contains
    and resamples the magnitude of the selected field onto a regular grid,
    trimmed of slices that contain only NaN.
    """

    # pylint: disable=too-many-instance-attributes

    #: Configuration parser.
    configuration = Instance(ConfigParser)

    #: Current participant ID.
    participant_id = Str()

    #: Path to field data file.
    data_path = File()

    #: Dictionary of data in `data_path`.
    data_dict = Dict()

    #: List of field keys that can be displayed.
    field_keys = ListStr()

    #: The currently selected field key.
    selected_field_key = Str()

    #: X values of grid in data file.
    x_vals = Array()

    #: Y values of grid in data file.
    y_vals = Array()

    #: Z values of grid in data file.
    z_vals = Array()

    #: Raw data of currently selected field from data file.
    data_arr = Array()

    #: X values of regular grid for current field.
    masked_gr_x = Array()

    #: Y values of regular grid for current field.
    masked_gr_y = Array()

    #: Z values of regular grid for current field.
    masked_gr_z = Array()

    #: Data on regular grid for current field.
    masked_grid_data = Array()

    @observe('data_path')
    def _update_data_path(self, event):
        """Load the file at the new ``data_path`` and refresh the sample
        coordinates, the displayable field keys and the selected field.

        Parameters
        ----------
        event : A :py:class:`traits.observation.events.TraitChangeEvent` instance
            The trait change event for data_path
        """
        self.data_dict = loadmat(event.new)

        # Grid-line coordinates scaled by 1000 (presumably m -> mm; TODO
        # confirm). 'Axis0' is used as x and 'Axis2' as z.
        g_x = self.data_dict['Axis0'][0] * 1000
        g_y = self.data_dict['Axis1'][0] * 1000
        g_z = self.data_dict['Axis2'][0] * 1000

        # Sample coordinates: midpoints of consecutive grid lines.
        self.x_vals = (g_x[:-1] + g_x[1:]) / 2
        self.y_vals = (g_y[:-1] + g_y[1:]) / 2
        self.z_vals = (g_z[:-1] + g_z[1:]) / 2

        # 'Snapshot' entries whose key mentions none of the scalar fields.
        self.field_keys = [
            key for key in self.data_dict
            if 'Snapshot' in key and not any(field in key
                                             for field in scalar_fields)
        ]

        if self.selected_field_key not in self.field_keys:
            # Choose the key first and assign it once: every assignment
            # fires the ``selected_field_key`` observer, which recomputes
            # the whole field.
            new_key = self.field_keys[0]
            initial = self._get_default_value('initial_field').lower()
            for key in self.field_keys:
                if key.lower().startswith(initial):
                    new_key = key
                    break
            self.selected_field_key = new_key

        # Still needed when the key is unchanged (no observer event) or the
        # observer is not active yet (post_init).
        self.calculate_field()

    @observe('selected_field_key', post_init=True)
    def calculate_field(self, event=None):  # pylint: disable=unused-argument, too-many-locals
        """
        Calculate the current selected field values and set the grid locations and values

        The magnitude of the selected field is stored in ``data_arr`` (exact
        zeros replaced by NaN), interpolated onto a regular grid spanning the
        sample extent rounded inwards to integers, and the result is trimmed
        of z, x and y slices that contain only NaN.

        Parameters
        ----------
        event : A :py:class:`traits.observation.events.TraitChangeEvent` instance
            The trait change event for selected_field_key
        """
        data_x, data_y, data_z = abs(self.data_dict[self.selected_field_key]).T

        # Field magnitude: the flat data is laid out (z, y, x) and is
        # transposed to (x, y, z) to match the sample coordinates.
        magnitude = np.sqrt(data_x**2 + data_y**2 + data_z**2).reshape(
            self.z_vals.size, self.y_vals.size, self.x_vals.size)
        magnitude[magnitude == 0] = np.nan
        self.data_arr = np.swapaxes(magnitude, 0, 2)

        x_min = int(np.ceil(self.x_vals.min()))
        x_max = int(np.floor(self.x_vals.max()))
        y_min = int(np.ceil(self.y_vals.min()))
        y_max = int(np.floor(self.y_vals.max()))
        z_min = int(np.ceil(self.z_vals.min()))
        z_max = int(np.floor(self.z_vals.max()))

        gr_x, gr_y, gr_z = np.mgrid[x_min:x_max:len(self.x_vals) * 1j,
                                    y_min:y_max:len(self.y_vals) * 1j,
                                    z_min:z_max:len(self.z_vals) * 1j]

        # (N, 3) query points. Row-major ravel gives the same i, j, k order
        # as a nested loop over the grid, without a Python-level loop.
        points = np.column_stack((gr_x.ravel(), gr_y.ravel(), gr_z.ravel()))

        interp_func = RegularGridInterpolator(
            (self.x_vals, self.y_vals, self.z_vals), self.data_arr)
        grid_data = interp_func(points).reshape(self.data_arr.shape)

        # Drop z slices that are entirely NaN ...
        mask = np.all(np.isnan(grid_data), axis=(0, 1))
        masked_grid_data = grid_data[:, :, ~mask]
        masked_gr_x = gr_x[:, :, ~mask]
        masked_gr_y = gr_y[:, :, ~mask]
        masked_gr_z = gr_z[:, :, ~mask]

        # ... then all-NaN x slices ...
        maskx = np.all(np.isnan(masked_grid_data), axis=(1, 2))
        masked_grid_data = masked_grid_data[~maskx, :, :]
        masked_gr_x = masked_gr_x[~maskx, :, :]
        masked_gr_y = masked_gr_y[~maskx, :, :]
        masked_gr_z = masked_gr_z[~maskx, :, :]

        # ... and finally all-NaN y slices.
        masky = np.all(np.isnan(masked_grid_data), axis=(0, 2))
        self.masked_grid_data = masked_grid_data[:, ~masky, :]
        self.masked_gr_x = masked_gr_x[:, ~masky, :]
        self.masked_gr_y = masked_gr_y[:, ~masky, :]
        self.masked_gr_z = masked_gr_z[:, ~masky, :]

    def _get_default_value(self, option):
        """Return ``option`` from the configuration section named after
        ``participant_id``.

        A missing section is first created empty; a ``ConfigParser`` section
        falls back to the DEFAULT section on lookup.

        Raises
        ------
        KeyError
            If the option is in neither the section nor DEFAULT.
        """
        section = self.participant_id
        # Both former branches performed this same lookup; only creating
        # the section depends on a participant id being present.
        if section is not None and section not in self.configuration:
            self.configuration[section] = {}
        return self.configuration[section][option]
Example #26
0
class MGridCellSpec(HasTraits):
    '''
    Specification of a grid cell: its node coordinates plus the lines and
    faces used to draw it, switchable between a "Triangle" and a "Diamond"
    geometry through ``geo_type``.
    '''

    geo_type = Enum("Triangle", "Diamond")

    # NOTE(review): this default matches neither geo_type; the geometry
    # specific nodes are only set once geo_type changes.
    node_coords = List([[-1, -1, 0], [1, -1, 0], [1, 1, 0], [1, 1, 1],
                        [-1, 1, -1], [-1 / 2., 1, 0], [0., 1, 0],
                        [1 / 2., 1, 0]])

    xnode_coords = List([[-1, 0, -1], [1, 0, -1], [1, 0, 1], [-1, 0, 1]])
    #-------------------------------------------------------------------
    # Visualization-related specification
    #-------------------------------------------------------------------
    # Node index pairs forming the edges of the cell.
    cell_lines = Array(int, value=[[0, 1], [1, 2], [2, 0]])
    # Node index lists forming the faces of the cell.
    cell_faces = Array(int, value=[[0, 1, 2]])

    @on_trait_change('geo_type')
    def _reset_node_coords(self):
        '''Replace nodes, lines and faces by those of the new geo_type.'''
        if self.geo_type == "Triangle":
            # three points selected like triangle
            self.node_coords = [
                [-1, -1, -1],
                [1, 0, 0],
                [1, 1, 1],
            ]
            # Triangle:
            self.cell_lines = [[0, 1], [1, 2], [2, 0]]
            # The faces used to be computed into an unused local only.
            self.cell_faces = [[0, 1, 2]]
        else:
            # Diamond
            self.node_coords = [
                [0, -1, -1],
                [1, 0, -1],
                [0, 1, -1],
                [-1, 0, -1],
                [0, -1, 1],
                [1, 0, 1],
                [0, 1, 1],
                [-1, 0, 1],
                [1, -1, 0],
                [-1, -1, 0],
                [1, 1, 0],
                [-1, 1, 0],
            ]
            # Diamond
            self.cell_lines = [[0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6],
                               [6, 7], [7, 4], [0, 8], [8, 4], [4, 9], [9, 0],
                               [2, 10], [10, 6], [6, 11], [11, 2], [9, 3],
                               [3, 11], [11, 7], [7, 9], [8, 1], [1, 10],
                               [10, 5], [5, 8]]

            self.cell_faces = [[0, 3, 2, 1], [4, 5, 6, 7], [0, 8, 4, 9],
                               [2, 10, 6, 11], [9, 3, 11, 7], [8, 1, 10, 5]]

    traits_view = View(Item('geo_type'),
                       Item('node_coords', style='readonly'),
                       Item('cell_lines', style='readonly'),
                       resizable=True,
                       height=0.5,
                       width=0.5)

    _node_array = Property(Array('float_'), depends_on='node_coords')

    @cached_property
    def _get__node_array(self):
        '''Return ``node_coords`` as a float64 array (``float`` is the same
        dtype as the legacy 'float_' alias). No equidistance check is done.
        '''
        return array(self.node_coords, dtype=float)

    def get_n_dims(self):
        '''Get the number of dimension of the cell
        '''
        return self._node_array.shape[1]

    def get_cell_shape(self):
        '''Get the shape of the cell grid: the number of distinct coordinate
        values along each of the ``get_n_dims()`` directions.
        '''
        narray = self._node_array
        return array([len(unique(narray[:, i]))
                      for i in range(self.get_n_dims())],
                     dtype=int)

    def get_cell_slices(self):
        '''Get slices for the generation of the cell grid: per dimension,
        from the min to the max node coordinate with as many (imaginary-step)
        points as there are distinct coordinate values.
        '''
        ndims = self.get_n_dims()
        narray = self._node_array
        return tuple([
            slice(
                min(narray[:, i]),
                max(narray[:, i]),
                complex(0, len(unique(narray[:, i]))),
            ) for i in range(ndims)
        ])
Example #27
0
class Saturation(ManagedJob, GetSetItemsMixin):
    """
    Measures saturation curves.

    Sweeps the laser voltage from ``v_begin`` to ``v_end`` in steps of
    ``v_delta``; at each step it records the optical power and the summed
    count rate of TimeTagger channels 0 and 1, and plots rate vs. power.

    written by: [email protected]
    last modified: 2012-08-17
    """

    # Sweep start voltage [V].
    v_begin = Range(low=0.,
                    high=5.,
                    value=0.,
                    desc='begin [V]',
                    label='begin [V]',
                    mode='text',
                    auto_set=False,
                    enter_set=True)
    # Sweep end voltage [V]; excluded from the sweep (see ``np.arange``).
    v_end = Range(low=0.,
                  high=5.,
                  value=5.,
                  desc='end [V]',
                  label='end [V]',
                  mode='text',
                  auto_set=False,
                  enter_set=True)
    # Sweep step [V].
    # NOTE(review): low=0. admits a zero step, which ``np.arange`` rejects;
    # consider a positive lower bound.
    v_delta = Range(low=0.,
                    high=5.,
                    value=.1,
                    desc='delta [V]',
                    label='delta [V]',
                    mode='text',
                    auto_set=False,
                    enter_set=True)

    # Counting time per sweep point.
    seconds_per_point = Range(low=1e-3,
                              high=1000.,
                              value=1.,
                              desc='Seconds per point',
                              label='Seconds per point',
                              mode='text',
                              auto_set=False,
                              enter_set=True)

    # Results, one entry per sweep point.
    voltage = Array()
    power = Array()
    rate = Array()

    plot_data = Instance(ArrayPlotData)
    plot = Instance(Plot)

    # Attribute names handled by GetSetItemsMixin (presumably what is
    # saved/loaded with the measurement).
    get_set_items = [
        '__doc__', 'v_begin', 'v_end', 'v_delta', 'seconds_per_point',
        'voltage', 'power', 'rate'
    ]

    traits_view = View(VGroup(
        HGroup(
            Item('submit_button', show_label=False),
            Item('remove_button', show_label=False),
            Item('priority'),
            Item('state', style='readonly'),
        ),
        HGroup(
            Item('v_begin'),
            Item('v_end'),
            Item('v_delta'),
        ),
        HGroup(Item('seconds_per_point'), ),
        Item('plot',
             editor=ComponentEditor(),
             show_label=False,
             resizable=True),
    ),
                       menubar=MenuBar(
                           Menu(Action(action='savePlot',
                                       name='Save Plot (.png)'),
                                Action(action='save',
                                       name='Save (.pyd or .pys)'),
                                Action(action='load', name='Load'),
                                Action(action='_on_close', name='Quit'),
                                name='File')),
                       title='Saturation',
                       buttons=[],
                       resizable=True,
                       handler=CustomHandler)

    def __init__(self):
        super(Saturation, self).__init__()
        self._create_plot()
        # Forward new results to the plot, dispatched on the UI thread.
        self.on_trait_change(self._update_index, 'power', dispatch='ui')
        self.on_trait_change(self._update_value, 'rate', dispatch='ui')

    def _run(self):
        """Acquire the saturation curve.

        For each voltage in ``np.arange(v_begin, v_end, v_delta)`` (``v_end``
        excluded): set the laser voltage, clear both counters, wait
        ``seconds_per_point`` (returning early on a stop request), then store
        the summed count rate and the measured power. Results are published
        to ``voltage``/``power``/``rate`` after the loop, also when stopped
        early (unmeasured points stay 0); an exception skips publishing.
        """

        try:
            self.state = 'run'
            voltage = np.arange(self.v_begin, self.v_end, self.v_delta)

            power = np.zeros_like(voltage)
            rate = np.zeros_like(voltage)

            counter_0 = TimeTagger.Countrate(0)
            counter_1 = TimeTagger.Countrate(1)

            for i, v in enumerate(voltage):
                Laser().voltage = v
                # NOTE(review): this reading is overwritten by the one taken
                # after the counting interval below.
                power[i] = PowerMeter().getPower()
                counter_0.clear()
                counter_1.clear()
                # Doubles as the counting interval and an interruptible sleep.
                self.thread.stop_request.wait(self.seconds_per_point)
                if self.thread.stop_request.isSet():
                    logging.getLogger().debug('Caught stop signal. Exiting.')
                    self.state = 'idle'
                    break
                rate[i] = counter_0.getData() + counter_1.getData()
                power[i] = PowerMeter().getPower()
            else:
                # Only reached when the loop was not interrupted.
                self.state = 'done'

            del counter_0
            del counter_1

            self.voltage = voltage
            self.power = power
            self.rate = rate

        finally:
            # NOTE(review): runs after the 'done' assignment as well, so the
            # job always ends in state 'idle'.
            self.state = 'idle'

    def _create_plot(self):
        """Create an empty rate-vs-power line plot and its data source."""
        plot_data = ArrayPlotData(
            power=np.array(()),
            rate=np.array(()),
        )
        plot = Plot(plot_data, padding=8, padding_left=64, padding_bottom=64)
        plot.plot(('power', 'rate'), color='blue')
        plot.index_axis.title = 'Power [mW]'
        plot.value_axis.title = 'rate [kcounts/s]'
        self.plot_data = plot_data
        self.plot = plot

    def _update_index(self, new):
        # Scaled by 1e3 for the 'Power [mW]' axis (presumably ``power`` is in
        # W -- confirm against PowerMeter.getPower).
        self.plot_data.set_data('power', new * 1e3)

    def _update_value(self, new):
        # Scaled by 1e-3 for the 'rate [kcounts/s]' axis (presumably the
        # counters report counts/s -- confirm).
        self.plot_data.set_data('rate', new * 1e-3)

    def save_plot(self, filename):
        # ``save_figure`` is defined outside this class (presumably a base).
        self.save_figure(self.plot, filename)
class PlotExample(HasTraits):
    '''Time-series demo plot with two ``CoordinateLineOverlay`` underlays.

    ``x1``/``y1`` and ``x2``/``y2`` hold the index and value coordinates at
    which each overlay draws its lines; reassigning one of these arrays
    updates the corresponding overlay once the plot exists.
    '''

    # 1D arrays of coordinates where a line should be drawn.
    x1 = Array(dtype=np.float64)
    y1 = Array(dtype=np.float64)
    x2 = Array(dtype=np.float64)
    y2 = Array(dtype=np.float64)

    time_plot = Instance(Plot)

    time_plot_data = Instance(ArrayPlotData)

    # These are used in `time_plot` to mark special times.
    line_overlay1 = Instance(CoordinateLineOverlay)
    line_overlay2 = Instance(CoordinateLineOverlay)

    traits_view = \
        View(
            HGroup(
                UItem('time_plot', editor=ComponentEditor(height=20)),
            ),
            title="Demo",
            width=1000, height=640, resizable=True,
        )

    #------------------------------------------------------------------------
    # Trait change handlers
    #------------------------------------------------------------------------
    # The overlays are created lazily together with ``time_plot`` (see
    # ``_time_plot_default``). A coordinate assigned before that -- e.g.
    # ``PlotExample(x1=...)`` -- must not be forwarded to the still-None
    # overlay (AttributeError); it is read from the trait at creation time.

    def _x1_changed(self):
        if self.line_overlay1 is not None:
            self.line_overlay1.index_data = self.x1

    def _y1_changed(self):
        if self.line_overlay1 is not None:
            self.line_overlay1.value_data = self.y1

    def _x2_changed(self):
        if self.line_overlay2 is not None:
            self.line_overlay2.index_data = self.x2

    def _y2_changed(self):
        if self.line_overlay2 is not None:
            self.line_overlay2.value_data = self.y2

    #------------------------------------------------------------------------
    # Trait defaults
    #------------------------------------------------------------------------

    def _x1_default(self):
        return np.array([1.8, 8.4])

    def _x2_default(self):
        return np.array([3.25])

    def _y2_default(self):
        return np.array([5.25])

    def _time_plot_data_default(self):
        # Synthetic signal sampled at 201 points on [0, 10].
        t = np.linspace(0, 10, 201)
        y = (0.5 * t + 0.1 * np.sin(0.4 * 2 * np.pi * t) + 0.3 * (t + 2) *
             (8 - t) * np.cos(0.33 * 2 * np.pi * t))
        data = ArrayPlotData(t=t, y=y)
        return data

    def _time_plot_default(self):
        '''Build the plot with pan/zoom tools and both coordinate overlays.'''
        time_plot = Plot(self.time_plot_data)

        time_plot.plot(('t', 'y'))

        time_plot.index_axis.title = "Time"

        time_plot.tools.append(PanTool(time_plot))

        zoomtool = ZoomTool(time_plot, drag_button='right', always_on=True)
        time_plot.overlays.append(zoomtool)

        lines1 = CoordinateLineOverlay(component=time_plot,
                                       index_data=self.x1,
                                       value_data=self.y1,
                                       color=(0.75, 0.25, 0.25, 0.75),
                                       line_style='dash',
                                       line_width=1)
        time_plot.underlays.append(lines1)
        self.line_overlay1 = lines1

        lines2 = CoordinateLineOverlay(component=time_plot,
                                       index_data=self.x2,
                                       value_data=self.y2,
                                       color=(0.2, 0.5, 1.0, 0.75),
                                       line_width=3)
        time_plot.underlays.append(lines2)
        self.line_overlay2 = lines2

        return time_plot
Example #29
0
class FETSLSEval(FETSEval):
    '''Element evaluator that integrates over triangulated sub-domains.

    Shape functions, dof layout and the default material model are delegated
    to ``parent_fets``. This class adds its own integration scheme:
    ``get_triangulation`` triangulates a set of point clouds, and
    ``get_ip_coords`` / ``get_ip_weights`` place integration points and
    weights on the resulting simplices.

    In ``get_mtrl_corr_pred`` the material model is selected per integration
    point from the sign of ``sctx.r_ls``. Zero uses ``mats_eval_disc``,
    positive uses ``mats_eval_pos`` and negative uses ``mats_eval_neg``. If
    the matching model is not set, it falls back to ``mats_eval``.
    '''

    x_slice = slice(0, 0)
    parent_fets = Instance(FETSEval)

    nip_disc = Int(0)  # number of integration points on the discontinuity

    def setup(self, sctx, n_ip):
        '''Set up the material state in each of the ``n_ip`` integration points.

        Overloads the default method: the material state array has to
        account for a different number of integration points per element, so
        ``sctx.elem_state_array`` is sliced into chunks of ``m_arr_size`` and
        each chunk is handed to ``mats_eval.setup``.

        TODO: the original setup can be used after adapting the ip_coords
        parameter.
        '''
        for i in range(n_ip):
            sctx.mats_state_array = sctx.elem_state_array[
                i * self.m_arr_size:(i + 1) * self.m_arr_size]
            self.mats_eval.setup(sctx)

    n_nodes = Property  # TODO: define dependencies

    @cached_property
    def _get_n_nodes(self):
        # Node count = element dofs / dofs per node; must stay an integer.
        return self.parent_fets.n_e_dofs // self.parent_fets.n_nodal_dofs

    dots_class = Class(DOTSEval)

    int_order = Int(1)

    mats_eval = Delegate('parent_fets')
    mats_eval_pos = Trait(None, Instance(IMATSEval))
    mats_eval_neg = Trait(None, Instance(IMATSEval))
    mats_eval_disc = Trait(None, Instance(IMATSEval))
    dim_slice = Delegate('parent_fets')

    dof_r = Delegate('parent_fets')
    geo_r = Delegate('parent_fets')
    n_nodal_dofs = Delegate('parent_fets')
    n_e_dofs = Delegate('parent_fets')

    get_dNr_mtx = Delegate('parent_fets')
    get_dNr_geo_mtx = Delegate('parent_fets')

    get_N_geo_mtx = Delegate('parent_fets')

    def get_B_mtx(self, r_pnt, X_mtx, node_ls_values, r_ls_value):
        '''Return the parent's B matrix; the level-set arguments are ignored.'''
        return self.parent_fets.get_B_mtx(r_pnt, X_mtx)

    def get_u(self, sctx, u):
        '''Interpolate the element displacement vector ``u`` at ``sctx.loc``.'''
        N_mtx = self.parent_fets.get_N_mtx(sctx.loc)
        return dot(N_mtx, u)

    def get_eps_eng(self, sctx, u):
        '''Engineering strain ``B(sctx.loc, sctx.X) . u``.'''
        B_mtx = self.parent_fets.get_B_mtx(sctx.loc, sctx.X)
        return dot(B_mtx, u)

    node_ls_values = Array(float)

    tri_subdivision = Int(0)

    def get_triangulation(self, point_set):
        '''Triangulate every point array in ``point_set`` and stack the results.

        :param point_set: sequence of ``(n_i, dim)`` point arrays.
        :return: ``[points, triangles]``. All output points are stacked into
            one array, and the triangle connectivity indexes into it.

        For ``dim == 1`` no triangulation is performed. Instead the minimum and
        maximum of the first two point arrays are returned together with the
        two segments ``[0, 1]`` and ``[2, 3]``.
        '''
        dim = point_set[0].shape[1]
        n_add = 3 - dim
        if dim == 1:  # sideway for 1D
            structure = [
                array([
                    min(point_set[0]),
                    max(point_set[0]),
                    min(point_set[1]),
                    max(point_set[1])
                ], dtype=float),
                array([[0, 1], [2, 3]], dtype=int)
            ]
            return structure
        points_list = []
        triangles_list = []
        point_offset = 0
        for pts in point_set:
            if self.tri_subdivision == 1:
                # Add the centroid as an extra point to refine the mesh.
                pts = vstack((pts, average(pts, 0)))
            if n_add > 0:
                # tvtk needs 3D points: pad the missing coordinates with zeros.
                points = hstack(
                    [pts, zeros([pts.shape[0], n_add], dtype=float)])
            else:
                # Already 3D (previously ``points`` was left unbound here).
                points = pts
            # Create a polydata with the points and run a 2D Delaunay
            # triangulation on them.
            profile = tvtk.PolyData(points=points)
            delny = tvtk.Delaunay2D(input=profile, offset=1.e1)
            tri = delny.output
            tri.update()  # initiate triangulation
            triangles = array(tri.polys.data, dtype=int_)
            pt_list = list(tri.points.data)
            # Each VTK cell is stored as [n_pts, i, j, k]; keep the node ids.
            tri = triangles.reshape((triangles.shape[0] // 4, 4))[:, 1:]
            points_list += pt_list
            triangles_list += list(tri + point_offset)
            # Shift by the number of points actually appended, so the next
            # cloud's connectivity stays valid even if some point is unused.
            point_offset += len(pt_list)
        return [array(points_list), array(triangles_list)]

    vtk_point_ip_map = Property(Array(Int))

    def _get_vtk_point_ip_map(self):
        '''Map every visualization point to its nearest integration point.

        For each row of ``self.vtk_r`` (placed in the ``dim_slice`` columns of
        a 3-vector), return the index of the closest row of
        ``self.ip_coords`` in local coordinates.
        '''
        vtk_pt_arr = zeros((1, 3), dtype=float)
        ip_map = zeros(self.vtk_r.shape[0], dtype='int_')
        for i, vtk_pt in enumerate(self.vtk_r):
            vtk_pt_arr[0, self.dim_slice] = vtk_pt
            # get the nearest ip_coord
            ip_map[i] = argmin(cdist(vtk_pt_arr, self.ip_coords))
        return ip_map

    def get_ip_coords(self, int_triangles, int_order):
        '''Get the array of integration point coordinates.

        :param int_triangles: ``[points, triangles]``, e.g. the output of
            ``get_triangulation``.
        :param int_order: order of the integration rule.
        :return: float array with one row per integration point, ordered
            simplex by simplex (the same order as ``get_ip_weights``).
        :raises TraitError: for a meaningless order on points, or for an
            unsupported number of simplex nodes.
        :raises NotImplementedError: for orders/shapes not implemented.
        '''
        gps = []
        points, triangles = int_triangles
        n_vertices = triangles.shape[1]
        if n_vertices == 1:  # 0D - points
            if int_order == 1:
                gps.append(points[0])
            else:
                raise TraitError('does not make sense')
        elif n_vertices == 2:  # 1D - lines
            if int_order == 1:
                # midpoint of every segment
                for nodes in triangles:
                    gps.append(average(points[ix_(nodes)], 0))
            elif int_order == 2:
                # 2-point Gauss-Legendre rule in barycentric form:
                # (1 -/+ 1/sqrt(3)) / 2
                weights = array([[0.21132486540518713, 0.78867513459481287],
                                 [0.78867513459481287, 0.21132486540518713]])
                for nodes in triangles:
                    elem_pts = points[ix_(nodes)]
                    gps.extend(average(elem_pts, 0, w) for w in weights)
            else:
                raise NotImplementedError
        elif n_vertices == 3:  # 2D - triangles
            if int_order == 1:
                # centroid of every triangle
                for nodes in triangles:
                    gps.append(average(points[ix_(nodes)], 0))
            elif int_order == 3:
                # 4-point rule: centroid plus three interior points
                weights = array([[0.6, 0.2, 0.2], [0.2, 0.6, 0.2],
                                 [0.2, 0.2, 0.6]])
                for nodes in triangles:
                    elem_pts = points[ix_(nodes)]
                    gps.append(average(elem_pts, 0))
                    gps.extend(average(elem_pts, 0, w) for w in weights)
            elif int_order == 5:
                # 7-point rule: centroid plus two orbits of three points
                weights = array([[0.0597158717, 0.4701420641, 0.4701420641],
                                 [0.4701420641, 0.0597158717, 0.4701420641],
                                 [0.4701420641, 0.4701420641, 0.0597158717],
                                 [0.7974269853, 0.1012865073, 0.1012865073],
                                 [0.1012865073, 0.7974269853, 0.1012865073],
                                 [0.1012865073, 0.1012865073, 0.7974269853]])
                for nodes in triangles:
                    elem_pts = points[ix_(nodes)]
                    gps.append(average(elem_pts, 0))
                    gps.extend(average(elem_pts, 0, w) for w in weights)
            else:
                # orders 2 and 4 (and anything else) are not implemented
                raise NotImplementedError
        elif n_vertices == 4:  # 3D - tetrahedrons
            raise NotImplementedError
        else:
            raise TraitError('unsupported geometric form with %s nodes ' %
                             n_vertices)
        return array(gps, dtype=float)

    def get_ip_weights(self, int_triangles, int_order):
        '''Get the array of integration weights (already scaled by |J|).

        The weights are returned in the same order as ``get_ip_coords`` for
        the same ``int_triangles`` / ``int_order``.

        :raises TraitError: for a meaningless order on points, or for an
            unsupported number of simplex nodes.
        :raises NotImplementedError: for orders/shapes not implemented.
        '''
        gps = []
        points, triangles = int_triangles
        n_vertices = triangles.shape[1]
        if n_vertices == 1:  # 0D - points
            if int_order == 1:
                gps.append(1.)
            else:
                raise TraitError('does not make sense')
        elif n_vertices == 2:  # 1D - lines
            if int_order == 1:
                for nodes in triangles:
                    r_pnt = points[ix_(nodes)]
                    # |J| of the map from [-1, 1] is half the segment length
                    J_det_ip = norm(r_pnt[1] - r_pnt[0]) * 0.5
                    gps.append(2. * J_det_ip)  # single Gauss weight 2
            elif int_order == 2:
                for nodes in triangles:
                    r_pnt = points[ix_(nodes)]
                    J_det_ip = norm(r_pnt[1] - r_pnt[0]) * 0.5
                    gps += [J_det_ip, J_det_ip]  # Gauss weights 1, 1
            else:
                raise NotImplementedError
        elif n_vertices == 3:  # 2D - triangles
            if int_order == 1:
                for nodes in triangles:
                    J_det_ip = self._get_J_det_ip(points[ix_(nodes)])
                    gps.append(1. * J_det_ip)
            elif int_order == 3:
                for nodes in triangles:
                    J_det_ip = self._get_J_det_ip(points[ix_(nodes)])
                    # -27/48 for the centroid, 25/48 for the other points
                    gps += [-0.5625 * J_det_ip,
                            0.52083333333333337 * J_det_ip,
                            0.52083333333333337 * J_det_ip,
                            0.52083333333333337 * J_det_ip]
            elif int_order == 5:
                for nodes in triangles:
                    J_det_ip = self._get_J_det_ip(points[ix_(nodes)])
                    gps += [0.225 * J_det_ip,
                            0.1323941527 * J_det_ip,
                            0.1323941527 * J_det_ip,
                            0.1323941527 * J_det_ip,
                            0.1259391805 * J_det_ip,
                            0.1259391805 * J_det_ip,
                            0.1259391805 * J_det_ip]
            else:
                # orders 2 and 4 (and anything else) are not implemented
                raise NotImplementedError
        elif n_vertices == 4:  # 3D - tetrahedrons
            raise NotImplementedError
        else:
            raise TraitError('unsupported geometric form with %s nodes ' %
                             n_vertices)
        return array(gps, dtype=float)

    def _get_J_det_ip(self, r_pnt):
        '''Return det(J) / 2 for a linear triangle with vertices ``r_pnt``.

        This is the signed triangle area: positive for counter-clockwise
        vertex order. Only the first two coordinates are used (2D only).
        TODO: 3D.
        '''
        dNr_geo = self.dNr_geo_triangle
        # factor 2: the reference triangle has area 1/2
        return det(dot(dNr_geo, r_pnt[:, :2])) / 2.

    dNr_geo_triangle = Property(Array(float))

    @cached_property
    def _get_dNr_geo_triangle(self):
        # d/dr and d/ds of the linear shape functions (1 - r - s, r, s)
        return array([[-1., 1., 0.], [-1., 0., 1.]], dtype=float)

    def get_corr_pred(self,
                      sctx,
                      u,
                      du,
                      tn,
                      tn1,
                      u_avg=None,
                      B_mtx_grid=None,
                      J_det_grid=None,
                      ip_coords=None,
                      ip_weights=None):
        '''Corrector and predictor evaluation.

        Assembles the element residual ``F`` and stiffness ``K`` by summing
        ``B^T sig`` and ``B^T D B``, weighted by ``wt * J_det``, over all
        integration points.

        :param u: current element displacement vector.
        :param du: displacement increment.
        :param B_mtx_grid: B matrices per integration point (required).
        :param J_det_grid: Jacobian determinants per integration point
            (required).
        :param ip_coords: integration points; default ``self.ip_coords``.
        :param ip_weights: integration weights; default ``self.ip_weights``.
        :return: ``(F, K)``.
        :raises TypeError: if ``B_mtx_grid`` or ``J_det_grid`` is missing.
        '''
        # ``is None`` rather than ``== None``: the grids are arrays, and
        # elementwise comparison with None has no single truth value.
        if J_det_grid is None or B_mtx_grid is None:
            # Evaluating B / J_det on the fly is not supported at the moment
            # ("caching cannot be switched off"), so both grids are required.
            raise TypeError('B_mtx_grid and J_det_grid must be provided')
        if ip_coords is None:
            ip_coords = self.ip_coords
        if ip_weights is None:
            ip_weights = self.ip_weights

        n_e_dofs = self.n_e_dofs
        K = zeros((n_e_dofs, n_e_dofs))
        F = zeros(n_e_dofs)
        sctx.fets_eval = self

        for ip, (r_pnt, wt) in enumerate(zip(ip_coords, ip_weights)):
            sctx.r_pnt = r_pnt
            J_det = J_det_grid[ip, ...]
            B_mtx = B_mtx_grid[ip, ...]

            eps_mtx = dot(B_mtx, u)
            d_eps_mtx = dot(B_mtx, du)
            # material state slice belonging to this integration point
            sctx.mats_state_array = sctx.elem_state_array[
                ip * self.m_arr_size:(ip + 1) * self.m_arr_size]
            sctx.r_ls = sctx.ls_val[ip]
            sig_mtx, D_mtx = self.get_mtrl_corr_pred(sctx, eps_mtx, d_eps_mtx,
                                                     tn, tn1)
            k = dot(B_mtx.T, dot(D_mtx, B_mtx))
            k *= (wt * J_det)
            K += k
            f = dot(B_mtx.T, sig_mtx)
            f *= (wt * J_det)
            F += f

        return F, K

    def get_J_det(self, r_pnt, X_mtx, ls_nodes,
                  ls_r):  # unified interface for caching
        '''Jacobian determinant at ``r_pnt``; the level-set args are ignored.'''
        return array(self._get_J_det(r_pnt, X_mtx), dtype=float)

    def get_mtrl_corr_pred(self, sctx, eps_mtx, d_eps, tn, tn1):
        '''Evaluate the material model selected by the sign of ``sctx.r_ls``.

        Zero uses ``mats_eval_disc``, positive uses ``mats_eval_pos`` and
        negative uses ``mats_eval_neg``. Each is used only if set; otherwise
        the call falls back to ``mats_eval``.

        :return: ``(sig_mtx, D_mtx)`` from the selected model's
            ``get_corr_pred``.
        '''
        ls = sctx.r_ls
        if ls == 0. and self.mats_eval_disc:
            mats_eval = self.mats_eval_disc
        elif ls > 0. and self.mats_eval_pos:
            mats_eval = self.mats_eval_pos
        elif ls < 0. and self.mats_eval_neg:
            mats_eval = self.mats_eval_neg
        else:
            mats_eval = self.mats_eval
        sig_mtx, D_mtx = mats_eval.get_corr_pred(sctx, eps_mtx, d_eps, tn, tn1)
        return sig_mtx, D_mtx
Example #30
0
class Image(Component):
    """ Component that displays a static image

    This is extremely simple right now.  By default it will draw the array into
    the entire region occupied by the component, stretching or shrinking as
    needed.  By default the bounds are set to the width and height of the data
    array, and we provide the same information to constraints-based layout
    with the layout_size_hint trait.

    """

    #: the image data as an array
    data = Array(shape=(None, None, (3, 4)), dtype='uint8')

    #: the format of the image data (eg. RGB vs. RGBA)
    format = Property(Enum('rgb24', 'rgba32'), depends_on='data')

    #: the size-hint for constraints-based layout
    layout_size_hint = Property(data, depends_on='data')

    #: the image as an Image GC
    _image = Property(Instance(GraphicsContext), depends_on='data')

    @classmethod
    def from_file(cls, filename, **traits):
        """ Create an Image component from an image file.

        The file is opened with PIL and closed again once its pixels have been
        copied into an array. Modes other than RGB/RGBA (e.g. greyscale,
        palette or CMYK) are converted to RGBA. This makes the array match the
        (height, width, 3|4) uint8 layout that ``data`` requires.

        Parameters
        ----------
        filename : str or file-like
            Anything accepted by ``PIL.Image.open``.
        **traits
            Further trait values passed to the constructor.
        """
        # Aliased so the PIL class does not shadow this component class.
        from PIL import Image as PILImage
        from numpy import asarray
        # Use a context manager so the underlying file handle is released.
        with PILImage.open(filename) as img:
            if img.mode not in ('RGB', 'RGBA'):
                img = img.convert('RGBA')
            data = asarray(img)
        return cls(data=data, **traits)

    def __init__(self, data, **traits):
        """ Create the component; ``bounds`` defaults to (width, height). """
        # the default bounds are the size of the image
        traits.setdefault('bounds', data.shape[1::-1])
        super(Image, self).__init__(data=data, **traits)

    def _draw_mainlayer(self, gc, view_bounds=None, mode="normal"):
        """ Draws the image, stretched to the component's position and size. """
        with gc:
            gc.draw_image(self._image,
                          (self.x, self.y, self.width, self.height))

    @cached_property
    def _get_format(self):
        """ 'rgb24' for 3-channel data, 'rgba32' for 4-channel data. """
        if self.data.shape[-1] == 3:
            return 'rgb24'
        elif self.data.shape[-1] == 4:
            return 'rgba32'
        else:
            raise ValueError('Data array not correct shape')

    @cached_property
    def _get_layout_size_hint(self):
        """ (width, height) of the image data. """
        return self.data.shape[1::-1]

    @cached_property
    def _get__image(self):
        """ Wrap the data in a GraphicsContext, copying it if not C-contiguous. """
        if not self.data.flags['C_CONTIGUOUS']:
            # copy() produces a C-ordered (contiguous) array
            data = self.data.copy()
        else:
            data = self.data
        image_gc = GraphicsContext(data, pix_format=self.format)
        return image_gc