class FitVariable(traits.HasTraits):
    """One variable of a fit, e.g. the standard deviation of a gaussian fit.

    Shown as a single row: the variable name, an editable initial value,
    the calculated value and its standard error (prefixed with a plus/minus
    sign).
    """
    # variable name, shown read-only
    name = traits.Str
    # starting value, editable in the view
    initialValue = traits.Float
    # result value, shown read-only with %G formatting
    calculatedValue = traits.Float
    # uncertainty on calculatedValue, shown read-only
    stdevError = traits.Float

    traits_view = traitsui.View(
        traitsui.HGroup(
            traitsui.Item("name",
                          style="readonly",
                          show_label=False,
                          resizable=True,
                          width=0.2),
            traitsui.Item("initialValue",
                          show_label=True,
                          label="initial",
                          resizable=True),
            traitsui.Item("calculatedValue",
                          style="readonly",
                          show_label=True,
                          label="calculated",
                          format_str="%G",
                          resizable=True,
                          width=0.2),
            traitsui.Item("stdevError",
                          style="readonly",
                          show_label=False,
                          format_str=u"\u00B1%G",
                          resizable=True)))
def make_view(self):
    """Build a View with a dropdown over ``self.v_list`` next to the plot.

    The dropdown choices are the ``str()`` of each entry of ``self.v_list``;
    the plot is rendered through an enable ``ComponentEditor`` without a label.
    """
    choices = [str(entry) for entry in self.v_list]
    selector = ui.Item(name='v_list',
                       editor=ui.EnumEditor(values=choices))
    plot_item = ui.Item(name='plot',
                        label="",
                        show_label=False,
                        editor=enable.ComponentEditor())
    return ui.View(ui.HGroup(selector, plot_item))
class CalculatedParameter(traits.HasTraits):
    """A number derived from a fit result (e.g. atom number).

    Displayed read-only as a name followed by its value in %G format.
    """
    name = traits.Str
    value = traits.Float

    traits_view = traitsui.View(
        traitsui.HGroup(
            traitsui.Item("name", style="readonly", show_label=False),
            traitsui.Item("value",
                          style="readonly",
                          format_str="%G",
                          show_label=False),
        ))
class IntRangeFeature(traits.HasTraits):
    """A feature whose integer value is settable with a slider.

    ``value`` is a Range whose bounds come from the ``low``/``high`` traits
    and whose default comes from ``value_``.
    """
    value = traits.Range('low', 'high', 'value_')
    # Integer literals (the originals were 0., -10000., 10000.) so these CInt
    # traits hold ints even if Traits returns the declared default unvalidated;
    # a float bound could otherwise make the dynamic Range float-typed.
    value_ = traits.CInt(0)
    low = traits.CInt(-10000)
    high = traits.CInt(10000)
    is_settable = traits.Bool(False)
    # Feature id, looked up from ``name`` in _get_id.
    id = traits.Property(depends_on='name')
    index = traits.Int(0)
    # Plain class attribute (not declared as a trait).
    name = 'gain'
    view = ui.View(ui.Item('value', show_label=False, style='custom'))

    def _get_id(self):
        """Return ``_SINGLE_VALUED_FEATURES.get(self.name)``.

        NOTE(review): assumes _SINGLE_VALUED_FEATURES is a dict, in which case
        an unknown name yields None.
        """
        return _SINGLE_VALUED_FEATURES.get(self.name)
class ROI(traits.HasTraits):
    """Region of interest made of four features: top, left, width, height.

    Each feature is created by create_int_multirange_feature and is expected
    to expose a ``value`` attribute (read by _get_values).
    """
    top = create_int_multirange_feature('top', 0)
    left = create_int_multirange_feature('left', 1)
    width = create_int_multirange_feature('width', 2)
    height = create_int_multirange_feature('height', 3)

    # _get_values returns a 4-tuple, so the property is declared as a Tuple
    # (it was declared Int, which mis-describes the value to Traits/editors).
    values = traits.Property(
        traits.Tuple,
        depends_on='top.value,left.value,width.value,height.value')

    def _get_values(self):
        """Return ``(top, left, width, height)`` from each feature's ``value``."""
        return self.top.value, self.left.value, self.width.value, self.height.value

    view = ui.View(
        ui.Item('top', style='custom'),
        ui.Item('left', style='custom'),
        ui.Item('width', style='custom'),
        ui.Item('height', style='custom'),
    )
class ExampleScene(TracerScene):
    """Two one-sided rectangular mirrors traced by a two-ray bundle.

    Both rays start at the point ``(0, source_y, source_z)``; changing either
    slider moves that point and re-runs the ray trace.
    """
    source_y = t_api.Range(0., 5., 2.)
    source_z = t_api.Range(0., 5., 1.)

    def __init__(self):
        # Ray bundle: two rays from a shared origin, one along
        # (0, -1, 1)/sqrt(2) and one along -z, each with energy 1.
        half_diag = 1 / (N.sqrt(2))
        directions = N.c_[[0, -half_diag, half_diag], [0, 0, -1]]
        origin = N.c_[[0, self.source_y, self.source_z]]
        vertices = N.tile(origin, (1, 2))
        self.bund = RayBundle(vertices=vertices,
                              directions=directions,
                              energy=N.r_[1, 1])

        # Assembly: a mirror rotated by rotx(pi/4) . roty(pi), plus an
        # unrotated one.
        tilt = N.dot(G.rotx(N.pi / 4)[:3, :3], G.roty(N.pi)[:3, :3])
        tilted_mirror = rect_one_sided_mirror(width=10, height=10)
        tilted_mirror.set_rotation(tilt)
        flat_mirror = rect_one_sided_mirror(width=10, height=10)
        self.assembly = Assembly(objects=[tilted_mirror, flat_mirror])

        TracerScene.__init__(self, self.assembly, self.bund)

    @t_api.on_trait_change('_scene.activated')
    def initialize_camere(self):
        """Set the camera view and roll once the scene is activated."""
        self._scene.mlab.view(0, -90)
        self._scene.mlab.roll(0)

    @t_api.on_trait_change('source_y, source_z')
    def bundle_move(self):
        """Move both rays' origin to the new slider position and re-trace."""
        origin = N.c_[[0, self.source_y, self.source_z]]
        self.bund.set_vertices(N.tile(origin, (1, 2)))
        self.plot_ray_trace()

    view = tui.View(
        tui.Item('_scene',
                 editor=SceneEditor(scene_class=MayaviScene),
                 show_label=False,
                 width=300,
                 height=400),
        tui.HGroup('-', 'source_y', 'source_z'))
class PrtEvent(traits.HasTraits):
    """A baseclass for events.

    Named PrtEvent to easily distinguish from other sorts of events, such as
    SimPy's SimEvents. Equality and hashing are by ID; ordering (Python 2
    ``__cmp__``) is by time.
    """
    time = traits.Float  # Time that event starts or arrives
    ID = traits.Int
    label = traits.Str

    # default view
    traits_view = ui.View(ui.Item(name='label'))

    def __init__(self, time, ID, label='', **tr):
        """Create an event at ``time`` with identifier ``ID``.

        ``label`` defaults to the class name followed by the ID; any extra
        keyword arguments are passed to HasTraits.__init__.
        """
        traits.HasTraits.__init__(self, **tr)
        self.time = time
        self.ID = ID
        self.label = (label if label else self.__class__.__name__ + str(ID))

    def __eq__(self, other):
        """Equal iff ``other`` is a PrtEvent with the same ID."""
        if isinstance(other, PrtEvent):
            return self.ID == other.ID
        else:
            return False

    def __ne__(self, other):
        # Without __ne__, Python 2 derives != from __cmp__, so two different
        # events with equal times compared as "not unequal" even though ==
        # reported them unequal. Keep != the exact negation of ==.
        return not self.__eq__(other)

    def __cmp__(self, other):
        """Compare based on time (and ID for equality); strings compare to the label."""
        if isinstance(other, PrtEvent):
            if self.ID == other.ID:
                return 0
            else:
                return cmp(self.time, other.time)
        elif isinstance(other, basestring):  # compared to a string
            return cmp(self.label, other)
        # Unsupported type: returning None raised TypeError on Python 2;
        # NotImplemented lets Python fall back to its default comparison.
        return NotImplemented

    def __str__(self):
        """Use custom label, or use class's name + ID. Will work for subclasses."""
        return self.label

    def __hash__(self):
        # Consistent with __eq__, which also compares by ID.
        return hash(self.ID)
def chooseVariables(self):
    """Open a modal checklist dialog for choosing which physics variables to log.

    The options are the keys of ``self.physics.variables`` in sorted order;
    names already listed in ``self.xmlLogVariables`` start out ticked.

    Returns:
        list: the variable names at the indices held in ``col.columns`` once
        the dialog has closed.
    """
    # sorted() rather than keys().sort(): keys() is not a list on Python 3.
    columns = sorted(self.physics.variables.keys())
    values = list(enumerate(columns))  # (index, name) pairs for CheckListEditor
    checklist_group = traitsui.Group(
        '10',  # insert vertical space
        traitsui.Label('Select the additional variables you wish to log'),
        traitsui.UItem('columns',
                       style='custom',
                       editor=traitsui.CheckListEditor(values=values, cols=6)),
        traitsui.UItem('selectAllButton'))

    traits_view = traitsui.View(checklist_group,
                                title='CheckListEditor',
                                buttons=['OK'],
                                resizable=True,
                                kind='livemodal')

    col = ColumnEditor(numberOfColumns=len(columns))
    try:
        # Pre-select previously chosen variables; list.index raises ValueError
        # if a remembered name is no longer among the physics variables.
        col.columns = [
            columns.index(varName) for varName in self.xmlLogVariables
        ]
    except Exception as e:
        logger.error(
            "couldn't selected correct variable names. Returning empty selection"
        )
        # str(e) instead of e.message: .message is deprecated in Python 2,
        # often empty, and absent in Python 3.
        logger.error("%s " % e)
        col.columns = []
    col.edit_traits(view=traits_view)
    selected = [columns[i] for i in col.columns]
    logger.debug("value of columns selected = %s ", col.columns)
    logger.debug("value of columns selected = %s ", selected)
    return selected
class Passenger(PrtEvent):
    """A passenger travelling from ``src_station`` to ``dest_station``.

    Tracks the current location (``loc``) and accumulates the time spent
    waiting at stations, riding in vehicles, and walking between stations.
    """
    mass = traits.Int()
    # Current location: a Station, a BaseVehicle, or None while walking.
    _loc = traits.Either(traits.Instance('pyprt.sim.station.Station'),
                         traits.Instance('pyprt.sim.vehicle.BaseVehicle'),
                         None)

    traits_view = ui.View(
        ui.Item(name='label'),
        ui.Item(name='ID'),
        ui.Item(name='loc'),
        ui.Item(name='mass'),  # in kg
        ui.Item(name='trip_success'),
        ui.Item(name='wait_time', format_func=sec_to_hms),
        ui.Item(name='walk_time', format_func=sec_to_hms),
        ui.Item(name='ride_time', format_func=sec_to_hms),
        ui.Item(name='will_share'),
        ui.Item(name='src_station'),
        ui.Item(name='dest_station'),
        ui.Item(name='load_delay', format_func=sec_to_hms),
        ui.Item(name='unload_delay', format_func=sec_to_hms),
        style='readonly',
        handler=NoWritebackOnCloseHandler())

    # NOTE: a large commented-out TableEditor (a tabular subset of passenger
    # data) used to live here as dead code; recover it from version control
    # if a table view is wanted again.

    def __init__(self, time, ID, src_station, dest_station, load_delay,
                 unload_delay, will_share, mass):
        super(Passenger, self).__init__(time, ID)
        self.src_station = src_station
        self.dest_station = dest_station
        self.load_delay = load_delay
        self.unload_delay = unload_delay
        self.will_share = will_share  # Willing to share pod (if same dest)
        self.mass = mass
        self.trip_success = False
        self._loc = src_station

        # For the following, where start and end are times in seconds, with 0
        # being the start of the sim. An end of None marks an ongoing interval.
        self._wait_times = [[time, None, self._loc]]  # triples: [[start, end, loc], ...]
        self._walk_times = []  # pairs: [[start, end], ...]
        self._ride_times = []  # triples: [[start, end, vehicle], ...]
        self._start_time = time
        self._end_time = None

    @staticmethod
    def _sum_intervals(intervals):
        """Total seconds covered by ``[start, end, ...]`` intervals.

        An interval whose end is None is still open and counts up to Sim.now().
        """
        total = 0
        for interval in intervals:
            start, end = interval[0], interval[1]
            if end is None:
                total += Sim.now() - start
            else:
                total += end - start
        return total

    @property
    def wait_time(self):  # in seconds
        return self._sum_intervals(self._wait_times)

    @property
    def ride_time(self):  # in seconds
        return self._sum_intervals(self._ride_times)

    @property
    def walk_time(self):  # in seconds
        return self._sum_intervals(self._walk_times)

    @property
    def total_time(self):
        """Seconds since creation, or the full trip duration once completed."""
        if self._end_time is None:
            return Sim.now() - self._start_time
        else:
            return self._end_time - self._start_time

    def get_loc(self):
        return self._loc

    def set_loc(self, loc):
        """Changes the loc, and keeps track of how much time is spent in each
        mode of transit: waiting, riding, or walking.

        Raises:
            Exception: if the old or new loc is not None, a vehicle (has
                ``vehicle_mass``) or a station (has ``platforms``).
        """
        ### Track time spent in each mode of transit ###
        # NOTE(review): after the trip succeeds no new interval is opened, so a
        # later set_loc would overwrite the end of an already-closed interval
        # here. Confirm set_loc is never called after arrival.
        if self._loc is None:  # Was walking
            self._walk_times[-1][1] = Sim.now()
        elif hasattr(self._loc, 'vehicle_mass'):  # was in vehicle
            self._ride_times[-1][1] = Sim.now()
        elif hasattr(self._loc, 'platforms'):  # was at station
            self._wait_times[-1][1] = Sim.now()
        else:
            raise Exception("Unknown loc type")

        ### Note if trip is completed. ###
        if loc is self.dest_station:
            self._end_time = Sim.now()
            self.trip_success = True

        ### More time tracking ###
        if not self.trip_success:
            if loc is None:
                self._walk_times.append([Sim.now(), None])
            elif hasattr(loc, 'vehicle_mass'):  # isinstance(loc, BaseVehicle)
                self._ride_times.append([Sim.now(), None, loc])
            elif hasattr(loc, 'platforms'):  # isinstance(loc, TrackSegment)
                self._wait_times.append([Sim.now(), None, loc])
            else:
                raise Exception("Unknown loc type")

        self._loc = loc

    loc = property(
        get_loc,
        set_loc,
        doc="loc is expected to be a Station, a "
        "Vehicle, or None (which indicates walking from one station "
        "to another). Setting the loc has side-effects, see set_loc.")

    def walk(self, origin_station, dest_station, travel_time, cmd_msg, cmd_id):
        """Start walking from ``origin_station``; arrive after ``travel_time`` seconds.

        Sets loc to None now and schedules _post_walk via common.AlarmClock.
        """
        assert self._loc is origin_station
        assert travel_time >= 0
        assert isinstance(cmd_msg, api.CtrlCmdPassengerWalk)
        assert isinstance(cmd_id, int)

        self.loc = None
        common.AlarmClock(Sim.now() + travel_time, self._post_walk,
                          dest_station, cmd_msg, cmd_id)

    def _post_walk(self, dest_station, cmd_msg, cmd_id):
        """Updates stats, changes location, and sends a SimCompletePassengerWalk
        message. To be called once the walk is complete."""
        assert self._loc is None
        self.loc = dest_station

        msg = api.SimCompletePassengerWalk()
        msg.msgID = cmd_id
        msg.cmd.CopyFrom(cmd_msg)
        msg.time = Sim.now()
        common.interface.send(api.SIM_COMPLETE_PASSENGER_WALK, msg)

    def fill_PassengerStatus(self, ps):
        """Copy this passenger's id, location, stations, creation time, mass
        and trip_success into the status message ``ps``."""
        ps.pID = self.ID
        # I'd much rather use isinstance checks, but circular imports are killing me
        if self._loc is None:
            ps.loc_type = api.WALKING
            ps.locID = api.NONE_ID
        elif hasattr(self._loc, 'vehicle_mass'):  # a vehicle
            ps.loc_type = api.VEHICLE
            ps.locID = self._loc.ID
        elif hasattr(self._loc, 'platforms'):  # a station
            ps.loc_type = api.STATION
            ps.locID = self._loc.ID
        else:
            # Call syntax works on both Python 2 and 3 (the old
            # "raise Exception, msg" form is a SyntaxError on Python 3).
            raise Exception("Unknown passenger location type: %s" % self._loc)

        ps.src_stationID = self.src_station.ID
        ps.dest_stationID = self.dest_station.ID
        ps.creation_time = self._start_time
        ps.mass = self.mass
        ps.trip_success = self.trip_success
class Fit(traits.HasTraits):
    """Base class for fitting a 2D image ``zs`` sampled on axes ``xs``/``ys``.

    Provides the traitsui panel, threaded fitting via an lmfit Model, and
    logging of results to a CSV file. Subclasses must implement ``fitFunc``,
    ``_getIntelligentInitialValues`` and ``_deriveCalculatedParameters``.
    """
    name = traits.Str(desc="name of fit")
    function = traits.Str(desc="function we are fitting with all parameters")
    variablesList = traits.List(Parameter)
    calculatedParametersList = traits.List(CalculatedParameter)
    xs = None  # will be a scipy array
    ys = None  # will be a scipy array
    zs = None  # will be a scipy array
    performFitButton = traits.Button("Perform Fit")
    getInitialParametersButton = traits.Button("Guess Initial Values")
    usePreviousFitValuesButton = traits.Button("Use Previous Fit")
    drawRequestButton = traits.Button("Draw Fit")
    setSizeButton = traits.Button("Set Initial Size")
    chooseVariablesButtons = traits.Button("choose logged variables")
    logLibrarianButton = traits.Button("librarian")
    logLastFitButton = traits.Button("log current fit")
    removeLastFitButton = traits.Button("remove last fit")
    autoFitBool = traits.Bool(
        False,
        desc=
        "Automatically perform this Fit with current settings whenever a new image is loaded"
    )
    autoGuessBool = traits.Bool(
        False,
        desc=
        "Whenever a fit is completed replace the guess values with the calculated values (useful for increasing speed of the next fit)"
    )
    autoDrawBool = traits.Bool(
        False,
        desc=
        "Once a fit is complete update the drawing of the fit or draw the fit for the first time"
    )
    autoSizeBool = traits.Bool(
        False,
        desc=
        "If TOF variable is read from latest XML and is equal to 0.11ms (or time set in Physics) then it will automatically update the physics sizex and sizey with the Sigma x and sigma y from the gaussian fit"
    )
    logBool = traits.Bool(
        False, desc="Log the calculated and fitted values with a timestamp")
    logName = traits.String(
        desc="name of the scan - will be used in the folder name")
    logDirectory = os.path.join("\\\\ursa", "AQOGroupFolder",
                                "Experiment Humphry", "Data", "eagleLogs")
    latestSequence = os.path.join("\\\\ursa", "AQOGroupFolder",
                                  "Experiment Humphry",
                                  "Experiment Control And Software",
                                  "currentSequence", "latestSequence.xml")

    logFile = traits.File(desc="file path of logFile")

    logAnalyserBool = traits.Bool(
        False, desc="only use log analyser script when True")
    # List of full paths to each logAnalyser file to run. NOTE: class-level
    # list shared by all instances; it is only ever rebound (never mutated)
    # in this class, which keeps that safe.
    logAnalysers = []
    logAnalyserDisplayString = traits.String(
        desc=
        "comma separated read only string that is a list of all logAnalyser python scripts to run. Use button to choose files"
    )
    logAnalyserSelectButton = traits.Button("sel. analyser",
                                            image='@icons:function_node',
                                            style="toolbar")
    # Names of physics variables to log; same sharing note as logAnalysers.
    xmlLogVariables = []
    imageInspectorReference = None  #will be a reference to the image inspector
    fitting = traits.Bool(False)  #true when performing fit
    fitted = traits.Bool(
        False)  #true when current data displayed has been fitted
    # True when only the sub space [startY:endY, startX:endX] is fitted.
    fitSubSpace = traits.Bool(False)
    startX = traits.Int
    startY = traits.Int
    endX = traits.Int
    endY = traits.Int
    fittingStatus = traits.Str()
    fitThread = None
    fitTimeLimit = traits.Float(
        10.0,
        desc=
        "Time limit in seconds for fitting function. Only has an effect when fitTimeLimitBool is True"
    )
    fitTimeLimitBool = traits.Bool(
        True,
        desc=
        "If True then fitting functions will be limited to time limit defined by fitTimeLimit "
    )
    physics = traits.Instance(
        physicsProperties.physicsProperties.PhysicsProperties)
    #status strings
    notFittedForCurrentStatus = "Not Fitted for Current Image"
    fittedForCurrentImageStatus = "Fit Complete for Current Image"
    currentlyFittingStatus = "Currently Fitting..."
    failedFitStatus = "Failed to finish fit. See logger"
    timeExceededStatus = "Fit exceeded user time limit"

    lmfitModel = traits.Instance(
        lmfit.Model
    )  #reference to the lmfit model must be initialised in subclass
    mostRecentModelResult = None  # updated to the most recent ModelResult object from lmfit when a fit thread is performed

    fitSubSpaceGroup = traitsui.VGroup(
        traitsui.Item("fitSubSpace", label="Fit Sub Space", resizable=True),
        traitsui.VGroup(traitsui.HGroup(
            traitsui.Item("startX", resizable=True),
            traitsui.Item("startY", resizable=True)),
                        traitsui.HGroup(traitsui.Item("endX", resizable=True),
                                        traitsui.Item("endY",
                                                      resizable=True)),
                        visible_when="fitSubSpace"),
        label="Fit Sub Space",
        show_border=True)

    generalGroup = traitsui.VGroup(traitsui.Item("name",
                                                 label="Fit Name",
                                                 style="readonly",
                                                 resizable=True),
                                   traitsui.Item("function",
                                                 label="Fit Function",
                                                 style="readonly",
                                                 resizable=True),
                                   fitSubSpaceGroup,
                                   label="Fit",
                                   show_border=True)

    variablesGroup = traitsui.VGroup(traitsui.Item(
        "variablesList",
        editor=traitsui.ListEditor(style="custom"),
        show_label=False,
        resizable=True),
                                     show_border=True,
                                     label="parameters")

    derivedGroup = traitsui.VGroup(traitsui.Item(
        "calculatedParametersList",
        editor=traitsui.ListEditor(style="custom"),
        show_label=False,
        resizable=True),
                                   show_border=True,
                                   label="derived values")

    buttons = traitsui.VGroup(
        traitsui.HGroup(
            traitsui.Item("autoFitBool", label="Auto fit?", resizable=True),
            traitsui.Item("performFitButton", show_label=False,
                          resizable=True)),
        traitsui.HGroup(
            traitsui.Item("autoGuessBool", label="Auto guess?",
                          resizable=True),
            traitsui.Item("getInitialParametersButton",
                          show_label=False,
                          resizable=True)),
        traitsui.HGroup(
            traitsui.Item("autoDrawBool", label="Auto draw?", resizable=True),
            traitsui.Item("drawRequestButton",
                          show_label=False,
                          resizable=True)),
        traitsui.HGroup(
            traitsui.Item("autoSizeBool", label="Auto size?", resizable=True),
            traitsui.Item("setSizeButton", show_label=False, resizable=True)),
        traitsui.HGroup(
            traitsui.Item("usePreviousFitValuesButton",
                          show_label=False,
                          resizable=True)))

    logGroup = traitsui.VGroup(traitsui.HGroup(
        traitsui.Item("logBool", resizable=True),
        traitsui.Item("chooseVariablesButtons",
                      show_label=False,
                      resizable=True)),
                               traitsui.HGroup(
                                   traitsui.Item("logName", resizable=True)),
                               traitsui.HGroup(
                                   traitsui.Item("removeLastFitButton",
                                                 show_label=False,
                                                 resizable=True),
                                   traitsui.Item("logLastFitButton",
                                                 show_label=False,
                                                 resizable=True)),
                               traitsui.HGroup(
                                   traitsui.Item("logAnalyserBool",
                                                 label="analyser?",
                                                 resizable=True),
                                   traitsui.Item("logAnalyserDisplayString",
                                                 show_label=False,
                                                 style="readonly",
                                                 resizable=True),
                                   traitsui.Item("logAnalyserSelectButton",
                                                 show_label=False,
                                                 resizable=True)),
                               label="Logging",
                               show_border=True)

    actionsGroup = traitsui.VGroup(traitsui.Item("fittingStatus",
                                                 style="readonly",
                                                 resizable=True),
                                   logGroup,
                                   buttons,
                                   label="Fit Actions",
                                   show_border=True)
    traits_view = traitsui.View(traitsui.VGroup(generalGroup, variablesGroup,
                                                derivedGroup, actionsGroup),
                                kind="subpanel")

    def __init__(self, **traitsDict):
        super(Fit, self).__init__(**traitsDict)
        self.startX = 0
        self.startY = 0
        self.lmfitModel = lmfit.Model(self.fitFunc)

    def _set_xs(self, xs):
        self.xs = xs

    def _set_ys(self, ys):
        self.ys = ys

    def _set_zs(self, zs):
        self.zs = zs

    def _fittingStatus_default(self):
        return self.notFittedForCurrentStatus

    def _getInitialValues(self):
        """returns ordered list of initial values from variables List """
        return [_.initialValue for _ in self.variablesList]

    def _getParameters(self):
        """creates an lmfit parameters object based on the user input in variablesList """
        # Build by item assignment: lmfit.Parameters' first positional
        # argument is asteval (0.9.x) / usersyms (1.x), so passing a dict to
        # the constructor does not add any parameters on those versions.
        # This also keeps the variablesList order.
        # NOTE(review): assumes each variable's .parameter is an lmfit.Parameter.
        params = lmfit.Parameters()
        for variable in self.variablesList:
            params[variable.name] = variable.parameter
        return params

    def _getCalculatedValues(self):
        """returns ordered list of fitted values from variables List """
        return [_.calculatedValue for _ in self.variablesList]

    def _intelligentInitialValues(self):
        """If possible we can auto set the initial parameters to intelligent guesses user can always overwrite them """
        self._setInitialValues(self._getIntelligentInitialValues())

    def _get_subSpaceArrays(self):
        """returns the arrays of the selected sub space. If subspace is not activated then returns the full arrays"""
        if self.fitSubSpace:
            xs = self.xs[self.startX:self.endX]
            ys = self.ys[self.startY:self.endY]
            logger.info("xs array sliced length %s " % (xs.shape))
            logger.info("ys array sliced length %s " % (ys.shape))
            zs = self.zs[self.startY:self.endY, self.startX:self.endX]
            logger.info("zs sub space array %s,%s " % (zs.shape))

            return xs, ys, zs
        else:
            return self.xs, self.ys, self.zs

    def _getIntelligentInitialValues(self):
        """If possible we can auto set the initial parameters to intelligent guesses user can always overwrite them """
        logger.debug("Dummy function should not be called directly")
        return  # subclasses are expected to override this

    def fitFunc(self, data, *p):
        """Function that we are trying to fit to. """
        logger.error("Dummy function should not be called directly")
        return  # subclasses are expected to override this

    def _setCalculatedValues(self, modelFitResult):
        """copy each fitted parameter value from the lmfit result into its variable's calculatedValue"""
        parametersResult = modelFitResult.params
        for variable in self.variablesList:
            variable.calculatedValue = parametersResult[variable.name].value

    def _setCalculatedValuesErrors(self, modelFitResult):
        """copy each fitted parameter's standard error (lmfit ``stderr``) into its variable's stdevError"""
        parametersResult = modelFitResult.params
        for variable in self.variablesList:
            variable.stdevError = parametersResult[variable.name].stderr

    def _setInitialValues(self, guesses):
        """set each variable's initialValue from ``guesses`` (same order as variablesList)"""
        # Indexing (not zip) so a too-short guesses list still raises IndexError.
        for index, variable in enumerate(self.variablesList):
            variable.initialValue = guesses[index]

    def deriveCalculatedParameters(self):
        """Wrapper for subclass definition of deriving calculated parameters
        can put more general calls in here"""
        if self.fitted:
            self._deriveCalculatedParameters()

    def _deriveCalculatedParameters(self):
        """Should be implemented by subclass. should update all variables in calculate parameters list"""
        logger.error("Should only be called by subclass")
        return

    def _fit_routine(self):
        """This function performs the fit in an appropriate thread and
        updates necessary values when the fit has been performed"""
        self.fitting = True
        # is_alive() rather than isAlive(): the camelCase alias was removed in
        # Python 3.9; is_alive() also exists on Python 2.6+.
        if self.fitThread and self.fitThread.is_alive():
            logger.warning(
                "Fitting is already running. You should wait till this fit has timed out before a new thread is started...."
            )
            return
        self.fitThread = FitThread()  #new fitting thread
        self.fitThread.fitReference = self
        self.fitThread.isCurrentFitThread = True  # user can create multiple fit threads on a particular fit but only the latest one will have an effect in the GUI
        self.fitThread.start()
        self.fittingStatus = self.currentlyFittingStatus

    def _perform_fit(self):
        """Perform the fit using the lmfit model.

        We must supply x and y as one argument and zs as another, in the form
            xs: 0 1 2 0 1 2 0
            ys: 0 0 0 1 1 1 2
            zs: 1 5 6 1 9 8 2
        Hence the use of repeat and tile in positions and ravel for zs.
        Initially xs, ys are linspace arrays and zs is a 2d image array.

        Returns the lmfit ModelResult, or ``([], [])`` when no data is loaded.
        """
        if self.xs is None or self.ys is None or self.zs is None:
            logger.warning(
                "attempted to fit data but had no data inside the Fit object. set xs,ys,zs first"
            )
            return ([], [])
        params = self._getParameters()
        # Returns the selected sub space, or the full arrays when
        # fitSubSpace is False.
        xs, ys, zs = self._get_subSpaceArrays()

        positions = scipy.array([
            scipy.tile(xs, len(ys)),
            scipy.repeat(ys, len(xs))
        ])  #for creating data necessary for gauss2D function
        if self.fitTimeLimitBool:
            modelFitResult = self.lmfitModel.fit(scipy.ravel(zs),
                                                 positions=positions,
                                                 params=params,
                                                 iter_cb=self.getFitCallback(
                                                     time.time()))
        else:  #no iter callback
            modelFitResult = self.lmfitModel.fit(scipy.ravel(zs),
                                                 positions=positions,
                                                 params=params)
        return modelFitResult

    def getFitCallback(self, startTime):
        """returns the callback function that is called at every iteration of fit to check if it
        has been running too long"""

        def fitCallback(params, iter, resid, *args, **kws):
            """check the time and compare to start time """
            if time.time() - startTime > self.fitTimeLimit:
                raise FitException("Fit time exceeded user limit")

        return fitCallback

    def _performFitButton_fired(self):
        self._fit_routine()

    def _getInitialParametersButton_fired(self):
        self._intelligentInitialValues()

    def _drawRequestButton_fired(self):
        """tells the imageInspector to try and draw this fit as an overlay contour plot"""
        self.imageInspectorReference.addFitPlot(self)

    def _setSizeButton_fired(self):
        """use the sigmaX and sigmaY from the current fit to overwrite the
        inTrapSizeX and inTrapSizeY parameters in the Physics Instance"""
        self.physics.inTrapSizeX = abs(self.sigmax.calculatedValue)
        self.physics.inTrapSizeY = abs(self.sigmay.calculatedValue)

    def _getFitFuncData(self):
        """if data has been fitted, this returns the zs data for the ideal
        fitted function using the calculated paramters"""
        positions = [
            scipy.tile(self.xs, len(self.ys)),
            scipy.repeat(self.ys, len(self.xs))
        ]  #for creating data necessary for gauss2D function
        zsravelled = self.fitFunc(positions, *self._getCalculatedValues())
        return zsravelled.reshape(self.zs.shape)

    def _logAnalyserSelectButton_fired(self):
        """open a fast file editor for selecting many files """
        fileDialog = FileDialog(action="open files")
        fileDialog.open()
        if fileDialog.return_code == pyface.constant.OK:
            self.logAnalysers = fileDialog.paths
            logger.info("selected log analysers: %s " % self.logAnalysers)
        self.logAnalyserDisplayString = str(
            [os.path.split(path)[1] for path in self.logAnalysers])

    def runSingleAnalyser(self, module):
        """runs the logAnalyser module calling the run function and returns the
        columnNames and values as a list"""
        import importlib
        # importlib instead of exec'ing a formatted import statement: no code
        # is built from the file name, and it works on Python 2.7 and 3.
        # NOTE(review): the module is always imported from the logAnalysers
        # package; the directory the user picked the file from is not used.
        currentAnalyser = importlib.import_module("logAnalysers.%s" % module)
        reload(currentAnalyser)  #in case it has changed..#could make this only when user requests
        #now the array also contains the raw image as this may be different to zs if you are using a processor
        if hasattr(self.imageInspectorReference, "rawImage"):
            rawImage = self.imageInspectorReference.rawImage
        else:
            rawImage = None
        return currentAnalyser.run([self.xs, self.ys, self.zs, rawImage],
                                   self.physics.variables, self.variablesList,
                                   self.calculatedParametersList)

    def runAnalyser(self):
        """ if logAnalyserBool is true we perform runAnalyser at the end of _log_fit
        runAnalyser checks that logAnalyser exists and is a python script with a valid run()function
        it then performs the run method and passes to the run function:
        -the image data as a numpy array
        -the xml variables dictionary
        -the fitted paramaters
        -the derived values

        Returns (columnNames, values), or None if any analyser file is missing."""
        for logAnalyser in self.logAnalysers:
            if not os.path.isfile(logAnalyser):
                logger.error(
                    "attempted to runAnalyser but could not find the logAnalyser File: %s"
                    % logAnalyser)
                return
        #these will contain the final column names and values
        finalColumns = []
        finalValues = []
        #iterate over each selected logAnalyser get the column names and values and add them to the master lists
        for logAnalyser in self.logAnalysers:
            directory, module = os.path.split(logAnalyser)
            module, ext = os.path.splitext(module)
            if ext != ".py":
                logger.error("file was not a python module. %s" % logAnalyser)
            else:
                columns, values = self.runSingleAnalyser(module)
                finalColumns.extend(columns)
                finalValues.extend(values)
        return finalColumns, finalValues

    def mostRecentModelFitReport(self):
        """returns the lmfit fit report of the most recent
        lmfit model results object"""
        if self.mostRecentModelResult is not None:
            return lmfit.fit_report(self.mostRecentModelResult) + "\n\n"
        else:
            return "No fit performed"

    def getCalculatedParameters(self):
        """useful for print returns tuple list of calculated parameter name and value """
        return [(_.name, _.value) for _ in self.calculatedParametersList]

    def _log_fit(self):
        """Append the current fit to ``<logDirectory>/<logName>/<logName>.csv``.

        Creates the log folder, its images folder, a comments file and a copy
        of the latest sequence on first use; copies the current image; writes
        a header row when the CSV does not exist yet; then appends one data row.
        """
        if self.logName == "":
            logger.warning("no log file defined. Will not log")
            return
        #generate folders if they don't exist
        logFolder = os.path.join(self.logDirectory, self.logName)
        if not os.path.isdir(logFolder):
            logger.info("creating a new log folder %s" % logFolder)
            os.mkdir(logFolder)

        imagesFolder = os.path.join(logFolder, "images")
        if not os.path.isdir(imagesFolder):
            logger.info("creating a new images Folder %s" % imagesFolder)
            os.mkdir(imagesFolder)

        commentsFile = os.path.join(logFolder, "comments.txt")
        if not os.path.exists(commentsFile):
            logger.info("creating a comments file %s" % commentsFile)
            open(commentsFile,
                 "a+").close()  #create a comments file in every folder!

        firstSequenceCopy = os.path.join(logFolder,
                                         "copyOfInitialSequence.ctr")
        if not os.path.exists(firstSequenceCopy):
            logger.info("creating a copy of the first sequence %s -> %s" %
                        (self.latestSequence, firstSequenceCopy))
            shutil.copy(self.latestSequence, firstSequenceCopy)

        if self.imageInspectorReference.model.imageMode == "process raw image":  #if we are using a processor, save the details of the processor used to the log folder
            processorParamtersFile = os.path.join(logFolder,
                                                  "processorOptions.txt")
            # TODO: also save a copy of the processor script (usedProcessor.py)
            if not os.path.exists(processorParamtersFile):
                with open(processorParamtersFile,
                          "a+") as processorParamsFile:
                    string = str(self.imageInspectorReference.model.
                                 chosenProcessor) + "\n"
                    string += str(self.imageInspectorReference.model.processor.
                                  optionsDict)
                    processorParamsFile.write(string)

        logger.debug("finished all checks on log folder")
        #copy current image
        try:
            shutil.copy(self.imageInspectorReference.selectedFile,
                        imagesFolder)
        except IOError as e:
            # Log the exception itself: IOError(errno, strerror).message is ''.
            logger.error("Could not copy image. Got IOError: %s " % e)
        except Exception as e:
            logger.error("Could not copy image. Got %s: %s " % (type(e), e))
            raise  # bare raise keeps the original traceback on Python 2
        logger.info("copying current image")
        self.logFile = os.path.join(logFolder, self.logName + ".csv")

        #analyser logic
        if self.logAnalyserBool:  #run the analyser script as requested
            logger.info(
                "log analyser bool enabled... will attempt to run analyser script"
            )
            analyserResult = self.runAnalyser()
            if analyserResult is None:
                # Analyser failed (runAnalyser logged why); continue as if
                # nothing happened. Checked before logging because
                # list(None) would raise TypeError.
                analyserColumnNames = []
                analyserValues = []
            else:
                logger.info("analyser result = %s " % list(analyserResult))
                analyserColumnNames, analyserValues = analyserResult
        else:  #no analyser enabled
            analyserColumnNames = []
            analyserValues = []

        if not os.path.exists(self.logFile):
            variables = [_.name for _ in self.variablesList]
            calculated = [_.name for _ in self.calculatedParametersList]
            times = ["datetime", "epoch seconds"]
            info = ["img file name"]
            xmlVariables = self.xmlLogVariables
            columnNames = times + info + variables + calculated + xmlVariables + analyserColumnNames
            with open(
                    self.logFile, 'ab+'
            ) as logFile:  # note use of binary file so that windows doesn't write too many /r
                writer = csv.writer(logFile)
                writer.writerow(columnNames)
        #column names already exist so...
        logger.debug("copying current image")
        variables = [_.calculatedValue for _ in self.variablesList]
        calculated = [_.value for _ in self.calculatedParametersList]
        now = time.time()  #epoch seconds
        timeTuple = time.localtime(now)
        date = time.strftime("%Y-%m-%dT%H:%M:%S", timeTuple)
        times = [date, now]
        info = [self.imageInspectorReference.selectedFile]
        xmlVariables = [
            self.physics.variables[varName] for varName in self.xmlLogVariables
        ]
        data = times + info + variables + calculated + xmlVariables + analyserValues

        with open(self.logFile, 'ab+') as logFile:
            writer = csv.writer(logFile)
            writer.writerow(data)

    def _logLastFitButton_fired(self):
        """logs the fit. User can use this for non automated logging. i.e. log particular fits"""
        self._log_fit()

    def _removeLastFitButton_fired(self):
        """removes the last line in the log file """
        # Check logName (as _log_fit does): the joined logFile path below can
        # never be the empty string, so the old logFile == "" test never fired.
        if self.logName == "":
            logger.warning("no log file defined. Will not log")
            return
        logFolder = os.path.join(self.logDirectory, self.logName)
        self.logFile = os.path.join(logFolder, self.logName + ".csv")
        if not os.path.exists(self.logFile):
            logger.error(
                "cant remove a line from a log file that doesn't exist")
            return  # previously fell through to open() and raised IOError
        # Read in binary mode to match the binary write (and the binary
        # appends in _log_fit): text-mode reading on Windows turned every
        # "\r\n" into "\n" and rewrote the whole file's line endings.
        with open(self.logFile, 'rb') as logFile:
            lines = logFile.readlines()
        with open(self.logFile, 'wb') as logFile:
            logFile.writelines(lines[:-1])

    def saveLastFit(self):
        """saves result of last fit to a txt/csv file. This can be useful for live analysis
        or for generating sequences based on result of last fit"""
        try:
            with open(
                    self.imageInspectorReference.cameraModel + "-" +
                    self.physics.species + "-" + "lastFit.csv",
                    "wb") as lastFitFile:
                writer = csv.writer(lastFitFile)
                writer.writerow(["time", time.time()])
                for variable in self.variablesList:
                    writer.writerow([variable.name, variable.calculatedValue])
                for variable in self.calculatedParametersList:
                    writer.writerow([variable.name, variable.value])
        except Exception as e:
            # str(e) instead of e.message (deprecated/often empty on Python 2,
            # absent on Python 3).
            logger.error("failed to save last fit to text file. message %s " %
                         e)

    def _chooseVariablesButtons_fired(self):
        self.xmlLogVariables = self.chooseVariables()

    def _usePreviousFitValuesButton_fired(self):
        """update the guess initial values with the value from the last fit """
        logger.info(
            "use previous fit values button fired. loading previous initial values"
        )
        self._setInitialValues(self._getCalculatedValues())

    def chooseVariables(self):
        """Open a modal checklist dialog for choosing which physics variables to log.

        The options are the keys of ``self.physics.variables`` in sorted
        order; names already in ``self.xmlLogVariables`` start out ticked.
        Returns the list of variable names at the indices held in
        ``col.columns`` once the dialog has closed.
        """
        # sorted() rather than keys().sort(): keys() is not a list on Python 3.
        columns = sorted(self.physics.variables.keys())
        values = list(enumerate(columns))  # (index, name) pairs
        checklist_group = traitsui.Group(
            '10',  # insert vertical space
            traitsui.Label('Select the additional variables you wish to log'),
            traitsui.UItem('columns',
                           style='custom',
                           editor=traitsui.CheckListEditor(values=values,
                                                           cols=6)),
            traitsui.UItem('selectAllButton'))

        traits_view = traitsui.View(checklist_group,
                                    title='CheckListEditor',
                                    buttons=['OK'],
                                    resizable=True,
                                    kind='livemodal')

        col = ColumnEditor(numberOfColumns=len(columns))
        try:
            # list.index raises ValueError for a remembered name that is no
            # longer among the physics variables.
            col.columns = [
                columns.index(varName) for varName in self.xmlLogVariables
            ]
        except Exception as e:
            logger.error(
                "couldn't selected correct variable names. Returning empty selection"
            )
            logger.error("%s " % e)
            col.columns = []
        col.edit_traits(view=traits_view)
        selected = [columns[i] for i in col.columns]
        logger.debug("value of columns selected = %s ", col.columns)
        logger.debug("value of columns selected = %s ", selected)
        return selected

    def _logLibrarianButton_fired(self):
        """opens log librarian for current folder in logName box. """
        logFolder = os.path.join(self.logDirectory, self.logName)
        if not os.path.isdir(logFolder):
            logger.error(
                "cant open librarian on a log that doesn't exist.... Could not find %s"
                % logFolder)
            return
        librarian = plotObjects.logLibrarian.Librarian(logFolder=logFolder)
        librarian.edit_traits()
class SegmentPlot(BaseXYPlot):
    """ A plot consisting of disconnected line segments.

    Consecutive data points are paired: points 0 and 1 form the first
    segment, points 2 and 3 the second, and so on.
    """

    # The color of the line.
    color = black_color_trait

    # The color to use to highlight the line when selected.
    selected_color = ColorTrait("lightyellow")

    # The style of the selected line.
    selected_line_style = LineStyle("solid")

    # The name of the key in self.metadata that holds the selection mask
    metadata_name = Str("selections")

    # The thickness of the line.
    line_width = Float(1.0)

    # The line dash style.
    line_style = LineStyle

    # Traits UI View for customizing the plot.
    traits_view = tui.View(tui.Item("color", style="custom"), "line_width",
                           "line_style", buttons=["OK", "Cancel"])

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # Cached list of non-NaN arrays of (x,y) data-space points; regardless of
    # self.orientation, this is always stored as (index_pt, value_pt). This is
    # different from the default BaseXYPlot definition.
    _cached_data_pts = List

    # Cached list of non-NaN arrays of (x,y) screen-space points.
    _cached_screen_pts = List

    def hittest(self, screen_pt, threshold=7.0):
        """Hit-testing is not implemented for segment plots; always None."""
        # NotImplemented
        return None

    def get_screen_points(self):
        """Gather (if needed) and return the cached points in screen space."""
        self._gather_points()
        return [self.map_screen(ary) for ary in self._cached_data_pts]

    #------------------------------------------------------------------------
    # Private methods; implements the BaseXYPlot stub methods
    #------------------------------------------------------------------------

    def _gather_points(self):
        """ Collects the data points that are within the bounds of the plot and
        caches them.

        The result is stored in self._cached_data_pts as a one-element list
        holding an (N, 2) array of (index, value) pairs, N even, with any
        segment that has a non-finite endpoint removed.
        """
        if self._cache_valid or not self.index or not self.value:
            return

        index = self.index.get_data()
        value = self.value.get_data()

        # If the data is completely outside the view region there is nothing
        # to draw.  Drop previously cached points so stale segments are not
        # rendered; the cache stays invalid so it is rebuilt on the next draw.
        for ds, rng in ((self.index, self.index_range),
                        (self.value, self.value_range)):
            low, high = ds.get_bounds()
            if low > rng.high or high < rng.low:
                self._cached_data_pts = []
                return

        if len(index) == 0 or len(value) == 0:
            self._cached_data_pts = []
            self._cache_valid = True
            return

        # Mismatched lengths: warn and truncate both arrays to the shorter one.
        size_diff = len(value) - len(index)
        if size_diff != 0:
            warnings.warn('len(value) %d - len(index) %d = %d'
                          % (len(value), len(index), size_diff))
        index_max = min(len(index), len(value))
        index = index[:index_max]
        value = value[:index_max]

        if index_max % 2:
            # We need an even number of points. Exclude the final one and
            # continue.
            warnings.warn('need an even number of points; got %d' % index_max)
            index = index[:index_max - 1]
            value = value[:index_max - 1]

        # TODO: restore the functionality of rendering highlighted portions
        # of the line
        #selection = self.index.metadata.get(self.metadata_name, None)
        #if selection is not None and type(selection) in (ndarray, list) and \
        #        len(selection) > 0:

        # Exclude NaNs and Infs.
        finite_mask = np.isfinite(value) & np.isfinite(index)
        # Since the line segment ends are paired, we need to exclude the whole
        # pair if one is not finite.
        finite_mask[::2] &= finite_mask[1::2]
        finite_mask[1::2] &= finite_mask[::2]
        self._cached_data_pts = [
            np.column_stack([index[finite_mask], value[finite_mask]])
        ]
        self._cache_valid = True

    def _render(self, gc, points, selected_points=None):
        """Draw the segments (and an optional wider highlight underneath)."""
        if len(points) == 0:
            return
        gc.save_state()
        try:
            gc.set_antialias(True)
            gc.clip_to_rect(self.x, self.y, self.width, self.height)
            if selected_points is not None:
                self._render_segments(gc, selected_points,
                                      self.selected_color_,
                                      self.line_width + 10.0,
                                      self.selected_line_style_)
            # Render using the normal style
            self._render_segments(gc, points, self.color_, self.line_width,
                                  self.line_style_)
        finally:
            gc.restore_state()

    def _render_segments(self, gc, points, color, line_width, line_style):
        """Stroke every (even, odd) point pair in each array as one segment."""
        gc.set_stroke_color(color)
        gc.set_line_width(line_width)
        gc.set_line_dash(line_style)
        gc.begin_path()
        for ary in points:
            if len(ary) > 0:
                gc.line_set(ary[::2], ary[1::2])
        gc.stroke_path()

    @on_trait_change('color,line_style,line_width')
    def _redraw(self):
        """Repaint whenever an appearance trait changes."""
        self.invalidate_draw()
        self.request_redraw()
class DataAxis(t.HasTraits):
    """One calibrated axis of a data array.

    ``index`` and ``value`` are kept in sync: changing one updates the other
    through the handlers registered in ``__init__``.  ``value2index`` assumes
    a linear calibration ``value = offset + index * scale``.
    """
    name = t.Str()
    units = t.Str()
    scale = t.Float()
    offset = t.Float()
    size = t.Int()
    index_in_array = t.Int()
    low_value = t.Float()
    high_value = t.Float()
    value = t.Range('low_value', 'high_value')
    low_index = t.Int(0)
    high_index = t.Int()
    slice = t.Instance(slice)
    slice_bool = t.Bool(False)

    index = t.Range('low_index', 'high_index')
    axis = t.Array()

    def __init__(self, size, index_in_array, name='', scale=1., offset=0.,
                 units='undefined', slice_bool=False):
        super(DataAxis, self).__init__()

        self.name = name
        self.units = units
        self.scale = scale
        self.offset = offset
        self.size = size
        self.high_index = self.size - 1
        self.low_index = 0
        self.index = 0
        self.index_in_array = index_in_array
        self.update_axis()

        # Handlers are registered only after the initial state is set, so
        # the assignments above do not trigger them.
        self.on_trait_change(self.update_axis, ['scale', 'offset', 'size'])
        self.on_trait_change(self.update_value, 'index')
        self.on_trait_change(self.set_index_from_value, 'value')
        self.on_trait_change(self._update_slice, 'slice_bool')
        self.on_trait_change(self.update_index_bounds, 'size')
        self.slice_bool = slice_bool

    def __repr__(self):
        # __repr__ must always return a string; the previous version fell
        # through and returned None when name was None.
        name = '' if self.name is None else self.name
        return name + ' index: ' + str(self.index_in_array)

    def update_index_bounds(self):
        """Keep the upper bound of ``index`` at ``size - 1``."""
        self.high_index = self.size - 1

    def update_axis(self):
        """Regenerate the axis array and refresh the value bounds."""
        self.axis = generate_axis(self.offset, self.scale, self.size)
        self.low_value, self.high_value = self.axis.min(), self.axis.max()
#        self.update_value()

    def _update_slice(self, value):
        """Set ``slice`` to ``slice(None)`` when slicing is on, else None."""
        if value is True:
            self.slice = slice(None)
        else:
            self.slice = None

    def get_axis_dictionary(self):
        """Return the axis settings as a plain dict."""
        adict = {
            'name': self.name,
            'scale': self.scale,
            'offset': self.offset,
            'size': self.size,
            'units': self.units,
            'index_in_array': self.index_in_array,
            'slice_bool': self.slice_bool
        }
        return adict

    def update_value(self):
        """Set ``value`` from the axis entry at the current ``index``."""
        self.value = self.axis[self.index]

    def value2index(self, value):
        """Return the closest index to the given value if between the limits,
        otherwise it will return either the upper or lower limits

        Parameters
        ----------
        value : float

        Returns
        -------
        int
        """
        if value is None:
            return None
        else:
            index = int(round((value - self.offset) / self.scale))
            if self.size > index >= 0:
                return index
            elif index < 0:
                messages.warning("The given value is below the axis limits")
                return 0
            else:
                messages.warning("The given value is above the axis limits")
                return int(self.size - 1)

    def index2value(self, index):
        """Return the calibrated axis value at ``index``."""
        return self.axis[index]

    def set_index_from_value(self, value):
        """Move ``index`` to the sample closest to ``value``."""
        self.index = self.value2index(value)
        # If the value is above the limits we must correct the value
        self.value = self.index2value(self.index)

    def calibrate(self, value_tuple, index_tuple, modify_calibration=True):
        """Compute a linear calibration from two (index, value) points.

        When ``modify_calibration`` is True the axis offset/scale are updated
        and None is returned; otherwise ``(offset, scale)`` is returned.
        """
        scale = (value_tuple[1] - value_tuple[0]) / \
            (index_tuple[1] - index_tuple[0])
        offset = value_tuple[0] - scale * index_tuple[0]
        if modify_calibration is True:
            self.offset = offset
            self.scale = scale
        else:
            return offset, scale

    traits_view = tui.View(
        tui.Group(
            tui.Group(
                tui.Item(name='name'),
                tui.Item(name='size', style='readonly'),
                tui.Item(name='index_in_array', style='readonly'),
                tui.Item(name='index'),
                tui.Item(name='value', style='readonly'),
                tui.Item(name='units'),
                tui.Item(name='slice_bool', label='slice'),
                show_border=True,
            ),
            tui.Group(
                tui.Item(name='scale'),
                tui.Item(name='offset'),
                label='Calibration',
                show_border=True,
            ),
            label="Data Axis properties",
            show_border=True,
        ),
    )
class SplineExplorer(traits.HasTraits):
    """A simple UI to adjust the parameters and view the resulting splines."""

    v_min = traits.Float(0)
    v_max = traits.Float(15)
    a_min = traits.Float(-5)
    a_max = traits.Float(5)
    j_min = traits.Float(-2.5)
    j_max = traits.Float(2.5)
    mass = traits.Float(400)

    q_i = traits.Float
    v_i = traits.Float
    a_i = traits.Float
    t_i = traits.Float

    q_f = traits.Float(100)
    v_f = traits.Float(0)
    a_f = traits.Float(0)
    t_f = traits.Float(18)

    plot_names = traits.List(
        ["Position", "Jerk", "Velocity", "Power", "Acceleration"])
    active_plots = traits.List

    target_type = traits.Enum(('Position', 'Velocity', 'Acceleration', 'Time'))

    plot_container = traits.Instance(Container)
    recalculate = menu.Action(name="Recalculate", action="recalc")
    dump = menu.Action(name="Print", action="dump")
    # NOTE: the ``save`` method defined below rebinds this name in the class
    # namespace; the Action object is captured by the View's buttons first.
    save = menu.Action(name="Save", action="save")
    trait_view = ui.View(
        ui.HGroup(
            ui.VGroup(
                ui.Item(name='target_type', label='Target'),
                ui.VGroup(
                    ui.Item(name='active_plots', show_label=False,
                            editor=ui.CheckListEditor(cols=3,
                                                      name='plot_names'),
                            style='custom'),
                    label='Show Plots', show_border=True),
                ui.VGroup(
                    ui.Item(name='q_i', label='Position'),
                    ui.Item(name='v_i', label='Velocity'),
                    ui.Item(name='a_i', label='Acceleration'),
                    ui.Item(name='t_i', label='Time'),
                    label='Initial Conditions', show_border=True),
                ui.VGroup(
                    ui.Item(name='q_f', label='Position',
                            enabled_when="target_type not in ('Velocity', 'Acceleration')"),
                    ui.Item(name='v_f', label='Velocity',
                            enabled_when="target_type != 'Acceleration'"),
                    ui.Item(name='a_f', label='Acceleration'),
                    ui.Item(name='t_f', label='Time',
                            enabled_when="target_type == 'Time'"),
                    label='Final Conditions:', show_border=True),
                ui.VGroup(
                    ui.Item(name='v_min', label='Min Velocity'),
                    ui.Item(name='v_max', label='Max Velocity'),
                    ui.Item(name='a_min', label='Min Acceleration'),
                    ui.Item(name='a_max', label='Max Acceleration'),
                    ui.Item(name='j_min', label='Min Jerk'),
                    ui.Item(name='j_max', label='Max Jerk'),
                    ui.Item(name='mass', label='Vehicle Mass'),
                    label='Constraints', show_border=True)),
            ui.Item('plot_container', editor=ComponentEditor(),
                    show_label=False)),
        title='Cubic Spline Explorer',
        handler=SEButtonHandler(),
        buttons=[recalculate, dump, save],
        resizable=True,
        width=1000)

    def __init__(self):
        super(SplineExplorer, self).__init__()
        self.active_plots = self.plot_names[:]
        self.active_plots.remove("Power")
        self.calc()

    def calc(self):
        """Solve for a spline from the current settings and rebuild the plots.

        On failure the error is logged, ``initial``, ``final`` and ``spline``
        are cleared, and an empty Container is displayed so the UI stays
        usable.
        """
        import logging
        try:
            self.solver = TrajectorySolver(self.v_max, self.a_max, self.j_max,
                                           self.v_min, self.a_min, self.j_min)
            self.initial = Knot(self.q_i, self.v_i, self.a_i, self.t_i)
            self.final = Knot(self.q_f, self.v_f, self.a_f, self.t_f)

            if self.target_type == 'Position':
                self.spline = self.solver.target_position(
                    self.initial, self.final)
            elif self.target_type == 'Velocity':
                self.spline = self.solver.target_velocity(
                    self.initial, self.final)
            elif self.target_type == 'Acceleration':
                self.spline = self.solver.target_acceleration(
                    self.initial, self.final)
            elif self.target_type == 'Time':
                self.spline = self.solver.target_time(self.initial, self.final)

            active = self.active_plots
            self.plotter = CSplinePlotter(self.spline,
                                          self.v_max, self.a_max, self.j_max,
                                          self.v_min, self.a_min, self.j_min,
                                          mass=self.mass,
                                          plot_pos="Position" in active,
                                          plot_vel="Velocity" in active,
                                          plot_accel="Acceleration" in active,
                                          plot_jerk="Jerk" in active,
                                          plot_power="Power" in active)
            self.plot_container = self.plotter.container
        except Exception:
            # Previously a bare ``except:`` hid every error (including
            # KeyboardInterrupt); keep the best-effort fallback but log why.
            logging.getLogger(__name__).exception(
                "Spline calculation failed")
            self.initial = None
            self.final = None
            self.spline = None
            self.plot_container = Container()

    def display(self):
        """Open the explorer window."""
        self.configure_traits()

    def get_save_filename(self):
        """Get a filename from the user via a FileDialog.

        Returns the chosen path, or None if the dialog was cancelled.
        """
        dialog = FileDialog(action="save as", default_filename="spline_00",
                            wildcard="*.png")
        dialog.open()
        if dialog.return_code == OK:
            return dialog.path

    def save(self, path):
        """Save an image of the plot. Does not catch any exceptions."""
        # Create a graphics context of the right size
        win_size = self.plot_container.outer_bounds
        plot_gc = chaco.PlotGraphicsContext(win_size)
        #plot_gc.set_fill_color("transparent")
        # Place the plot component into it
        plot_gc.render_component(self.plot_container)
        # Save out to the user supplied filename
        plot_gc.save(path)

    def _active_plots_changed(self):
        self.calc()

    def _target_type_changed(self):
        self.calc()
class Berth(Sim.Process, traits.HasTraits):
    """A single berth on a station platform, run as a SimPy process.

    ``unload``/``load`` set a request flag and wake the process; ``run``
    performs the passenger transfers one at a time, holding for each
    passenger's unload/load delay.
    """
    label = traits.Str
    platform_index = traits.Int
    station_id = traits.Int
    vehicle = traits.Instance('pyprt.sim.vehicle.BaseVehicle', None)

    _busy = traits.Bool
    _unload = traits.Bool
    _load = traits.Bool
    _msg_id = traits.Int
    _pax = traits.List(traits.Instance('pyprt.sim.events.Passenger'))

    # NOTE(review): 'busy' is not a trait of this class (the flag is
    # '_busy'); confirm whether this view is ever displayed.
    traits_view = ui.View(ui.HGroup(ui.Item(name='vehicle',
                                            editor=ui.TextEditor()),
                                    ui.Item('busy')))

    def __init__(self, label, station_id, vehicle, **tr):
        Sim.Process.__init__(self, name=label)
        traits.HasTraits.__init__(self, **tr)
        self.label = label
        self.station_id = station_id
        self.vehicle = vehicle

        # Control flags/settings for the run loop
        self._busy = False
        self._unload = False
        self._load = False
        self._msg_id = api.NONE_ID
        self._pax = []

    def __str__(self):
        return str((self.label, str(self.vehicle), str(self._busy)))

    def is_empty(self):
        """Returns True if the berth is not occupied by a vehicle."""
        return not self.vehicle

    def unload(self, msg_id, passengers):
        """Request that the berth's vehicle be unloaded; wake the process."""
        self._unload = True
        self._msg_id = msg_id
        self._pax = passengers
        # passive() is a method; testing the bound method itself was always
        # true and could interrupt a hold that was in progress.
        if self.passive():
            Sim.reactivate(self, prior=True)

    def load(self, msg_id, passengers):
        """Request that ``passengers`` be loaded; wake the process."""
        self._load = True
        self._msg_id = msg_id
        self._pax = passengers
        if self.passive():
            Sim.reactivate(self, prior=True)

    def is_busy(self):
        return self._busy

    def run(self):
        """SimPy process body: service unload/load requests, else passivate."""
        station = common.stations[self.station_id]
        while True:
            # Unloading
            if self._unload:
                # Passengers leave from the end of the list; deleting the last
                # element while iterating reversed() visits each one once.
                for pax in reversed(self.vehicle.passengers):
                    self._unload = False
                    self._busy = True
                    yield Sim.hold, self, pax.unload_delay
                    del self.vehicle.passengers[-1]  # pax that left vehicle
                    pax.loc = station
                    pax.trip_end = Sim.now()
                    if self.station_id == pax.dest_station.ID:
                        pax.trip_success = True
                        common.delivered_pax.add(pax)
                        logging.info("T=%4.3f %s delivered to %s by %s. Unloaded in berth %s",
                                     Sim.now(), pax, self.station_id,
                                     self.vehicle, self.label)
                self._busy = False
                if station.passive():
                    Sim.reactivate(station, prior=True)

            # Loading
            elif self._load:
                for pax in self._pax:
                    self._load = False
                    self._busy = True
                    s_notify = api.SimNotifyPassengerLoadStart()
                    s_notify.vID = self.vehicle.ID
                    s_notify.sID = self.station_id
                    s_notify.pID = pax.ID
                    common.interface.send(api.SIM_NOTIFY_PASSENGER_LOAD_START,
                                          s_notify)
                    yield Sim.hold, self, pax.load_delay
                    self.vehicle.passengers.append(pax)
                    pax.trip_boarded = Sim.now()
                    pax.loc = self.vehicle
                    logging.info("T=%4.3f %s loaded into %s at station %s",
                                 Sim.now(), pax, self.vehicle, self.station_id)
                    e_notify = api.SimNotifyPassengerLoadEnd()
                    e_notify.vID = self.vehicle.ID
                    e_notify.sID = self.station_id
                    e_notify.pID = pax.ID
                    common.interface.send(api.SIM_NOTIFY_PASSENGER_LOAD_END,
                                          e_notify)
                    # If using the LOBBY policy, notify that passenger load
                    # command has completed.  The message id is cleared after
                    # the first notification so it is only sent once.
                    if self._msg_id != api.NONE_ID:
                        cmd_notify = api.SimCompletePassengerLoadVehicle()
                        cmd_notify.msgID = self._msg_id
                        cmd_notify.vID = self.vehicle.ID
                        cmd_notify.sID = self.station_id
                        cmd_notify.pID = pax.ID
                        common.interface.send(
                            api.SIM_COMPLETE_PASSENGER_LOAD_VEHICLE,
                            cmd_notify)
                        self._msg_id = api.NONE_ID
                self._busy = False
                if station.passive():
                    Sim.reactivate(station, prior=True)

            else:
                assert not self._busy
                yield Sim.passivate, self
class StationReport(Report):
    """Tabular summary of every station: berth mix, passenger counts, waits."""
    s_list = traits.List

    traits_view = ui.View(
        ui.Group(
            ui.Item('s_list',
                    editor=ui.TabularEditor(
                        adapter=StationTabularAdapater(),
                        operations=[],
                        images=[],
                        editable=False,
                        column_clicked='handler.column_clicked'),
                    show_label=False)
        ),
        handler=SortHandler(),
        kind='live'
    )

    def __init__(self):
        super(StationReport, self).__init__(title='Stations')
        self._header = ["id", "Label", "Platforms", "Berths",
                        "Unload", "Load", "Load|Unload", "Queue",
                        "Current Pax", "Pax Created", "Pax Arrived",
                        "Pax Departed", "Min Pax Wait", "Mean Pax Wait",
                        "Max Pax Wait", "Vehicles Arrived",
                        "Min Vehicle Dwell", "Mean Vehicle Dwell",
                        "Max Vehicle Dwell"]
        # TODO: Berth specific stats?
        self._lines = []

    def update(self):
        """Rebuild one row per station from the current simulation state."""
        # Re-read the station list only when its size no longer matches.
        if len(self.s_list) != len(common.stations):
            self.s_list = common.stations.values()
            self.s_list.sort()
        self._lines = [self._format_station(stn) for stn in self.s_list]

    def _format_station(self, stn):
        """Return the table row (list of strings) describing ``stn``."""
        assert isinstance(stn, Station)

        n_berths = n_unload = n_load = n_both = n_queue = 0
        for platform in stn.platforms:
            for berth in platform.berths:
                n_berths += 1
                if berth.unloading:
                    if berth.loading:
                        n_both += 1
                    else:
                        n_unload += 1
                elif berth.loading:
                    n_load += 1
                else:
                    # no loading or unloading capability
                    n_queue += 1

        waits = stn.all_pax_wait_times()
        if waits:
            wait_cols = [sec_to_hms(min(waits)),
                         sec_to_hms(sum(waits)/len(waits)),
                         sec_to_hms(max(waits))]
        else:
            wait_cols = ["N/A", "N/A", "N/A"]

        n_created = sum(1 for pax in stn._all_passengers
                        if pax.src_station is stn)
        counts = (len(stn.platforms), n_berths, n_unload, n_load, n_both,
                  n_queue, len(stn._passengers), n_created,
                  stn._pax_arrivals_count, stn._pax_departures_count)

        row = ["%d" % stn.ID, stn.label]
        row.extend("%d" % count for count in counts)
        row.extend(wait_cols)
        row.extend(["inc"] * 4)  # TODO: Vehicle-related stats
        return row

    def __str__(self):
        rows = [self.title, self.FIELD_DELIMITER.join(self._header)]
        rows.extend(self.FIELD_DELIMITER.join(line) for line in self._lines)
        return self.LINE_DELIMETER.join(rows)
class VehicleReport(Report):
    """Tabular per-vehicle statistics (meters and seconds)."""
    v_list = traits.List

    traits_view = ui.View(
        ui.Group(
            ui.Item('v_list',
                    editor=ui.TabularEditor(
                        adapter=VehicleTabularAdapater(),
                        operations=[],
                        images=[],
                        editable=False,
                        column_clicked='handler.column_clicked'),
                    show_label=False)
        ),
        handler=SortHandler(),
        kind='live'
    )

    def __init__(self):
        super(VehicleReport, self).__init__(title='Vehicles')
        self._units_notice = "All values reported in units of meters and seconds."
        self._header = ["id", "Label", "LocId", "Position", "Velocity",
                        "Accel", "TotalPassengers", "MaxPassengers",
                        "TimeWeightedAvePax", "DistWeightedAvePax",
                        "DistTravelled", "EmptyDist", "PassengerMeters",
                        "MaxVelocity", "MinVelocity", "MaxAccel", "MinAccel",
                        "MaxJerk", "MinJerk"]
        self._lines = []

    def update(self):
        """Refresh the stats of every vehicle and rebuild the table rows."""
        # Re-read the vehicle list only when its size no longer matches.
        if len(self.v_list) != len(common.vehicles):
            self.v_list = common.vehicles.values()
            self.v_list.sort()
        self._lines = [self._format_vehicle(veh) for veh in self.v_list]

    def _format_vehicle(self, veh):
        """Update ``veh``'s statistics and return its row as strings."""
        assert isinstance(veh, BaseVehicle)
        veh.update_stats()
        spline = veh._spline
        extrema_vels, _ = spline.get_extrema_velocities()
        top_vel, bottom_vel = max(extrema_vels), min(extrema_vels)
        jerks = spline.j
        top_jerk, bottom_jerk = max(jerks), min(jerks)
        return [str(veh.ID),
                veh.label,
                str(veh.loc.ID),
                "%.3f" % veh.pos,
                "%.3f" % veh.vel,
                "%.3f" % veh.accel,
                str(veh.total_pax),
                str(veh.max_pax),
                "%.2f" % veh.time_ave_pax,
                "%.2f" % veh.dist_ave_pax,
                "%d" % veh.dist_travelled,
                "%d" % veh.empty_dist,
                "%d" % veh.pax_dist,
                "%.3f" % top_vel,
                "%.3f" % bottom_vel,
                "%.3f" % spline.get_max_acceleration(),
                "%.3f" % spline.get_min_acceleration(),
                "%.3f" % top_jerk,
                "%.3f" % bottom_jerk]

    def __str__(self):
        rows = [self.title, self._units_notice,
                self.FIELD_DELIMITER.join(self._header)]
        rows.extend(self.FIELD_DELIMITER.join(line) for line in self._lines)
        return self.LINE_DELIMETER.join(rows)
class PaxReport(Report):
    """List of details for all passengers in a gridview"""
    passengers = traits.List

    traits_view = ui.View(
        ui.Group(
            ui.Item('passengers',
                    editor=ui.TabularEditor(
                        adapter=PassengerTabularAdapter(),
                        operations=[],
                        images=[],
                        editable=False,
                        column_clicked='handler.column_clicked'),
                    show_label=False)
        ),
        handler=SortHandler(),
        kind='live'
    )

    def __init__(self):
        super(PaxReport, self).__init__(title="Passengers")
        self._header = ["id", "CreationTime", "SrcStatId", "DestStatId",
                        "CurrentLocId", "CurrentLocType", "TimeWaiting",
                        "TimeWalking", "TimeRiding", "TotalTime", "Success",
                        "Mass", "WillShare", "UnloadDelay", "LoadDelay",
                        "SrcStationLabel", "DestStationLabel", "CurrLocLabel"]
        self._lines = []

    def update(self):
        """Rebuild one row per passenger from the current simulation state."""
        # Re-read the passenger list only when its size no longer matches.
        if len(self.passengers) != len(common.passengers):
            self.passengers = common.passengers.values()
            self.passengers.sort()
        self._lines = [self._format_pax(pax) for pax in self.passengers]

    def _format_pax(self, pax):
        """Return the table row (list of strings) describing ``pax``."""
        assert isinstance(pax, Passenger)
        src, dest, loc = pax.src_station, pax.dest_station, pax.loc
        return [str(pax.ID),
                "%.3f" % pax.time,
                str(src.ID),
                str(dest.ID),
                str(loc.ID),
                self.type_str(loc),
                sec_to_hms(pax.wait_time),
                sec_to_hms(pax.walk_time),
                sec_to_hms(pax.ride_time),
                sec_to_hms(pax.total_time),
                str(pax.trip_success),
                str(pax.mass),
                str(pax.will_share),
                str(pax.unload_delay),
                str(pax.load_delay),
                src.label,
                dest.label,
                loc.label]

    def __str__(self):
        rows = [self.title, self.FIELD_DELIMITER.join(self._header)]
        rows.extend(self.FIELD_DELIMITER.join(line) for line in self._lines)
        return self.LINE_DELIMETER.join(rows)
class Station(traits.HasTraits):
    """A station made of platforms and berths that holds waiting passengers.

    Tracks the current and historical passengers so wait-time statistics can
    be reported.
    """
    platforms = traits.List(traits.Instance(Platform))
    track_segments = traits.Set(traits.Instance(TrackSegment))

    # Passengers waiting at the station.
    _passengers = traits.List(traits.Instance(Passenger))

    traits_view = ui.View(
        ui.VGroup(
            ui.Group(
                ui.Label('Waiting Passengers'),
##                ui.Item(name='_passengers',
##                        show_label = False,
##                        editor=Passenger.table_editor
##                        ),
                show_border=True)),
        title='Station',  # was: self.label
        # scrollable = True,
        resizable=True,
        height=700,
        width=470,
        handler=NoWritebackOnCloseHandler())

    table_editor = ui.TableEditor(
        columns=[
            ui_tc.ObjectColumn(name='ID', label='ID', tooltip='Station ID'),
            ui_tc.ObjectColumn(name='label', label='Label',
                               tooltip='Non-unique identifier'),
            ui_tc.ExpressionColumn(
                label='Current Pax', format='%d',
                expression='len(object._passengers)',
                tooltip='Number of passengers currently at station.')
            # TODO: The rest...
        ],
        deletable=False,
        editable=False,
        sortable=True,
        sort_model=False,
        auto_size=True,
        orientation='vertical',
        show_toolbar=True,
        reorderable=False,
        rows=15,
        row_factory=traits.This)

    def __init__(self, ID, label, track_segments,
                 storage_entrance_delay, storage_exit_delay, storage_dict):
        traits.HasTraits.__init__(self)
        self.ID = ID
        self.label = label
        self.platforms = []
        self.track_segments = track_segments
        self.storage_entrance_delay = storage_entrance_delay
        self.storage_exit_delay = storage_exit_delay

        # Keyed by the VehicleModel name (string) with FIFO queues as the values.
        self._storage_dict = storage_dict

        self._pax_arrivals_count = 0
        self._pax_departures_count = 0
        self._pax_times = [(Sim.now(), len(self._passengers))]  # elements are (time, num_pax)
        self._all_passengers = []

    def __str__(self):
        return self.label if self.label else str(self.ID)

    def __hash__(self):
        return hash(self.ID)

    def __eq__(self, other):
        if not isinstance(other, Station):
            return False
        return self.ID == other.ID

    def __ne__(self, other):
        if not isinstance(other, Station):
            return True
        return self.ID != other.ID

    def __cmp__(self, other):
        return cmp(self.ID, other.ID)

    def startup(self):
        """Activates all the berths"""
        for platform in self.platforms:
            for berth in platform.berths:
                Sim.activate(berth, berth.run())

    def add_passenger(self, pax):
        """Add a passenger to this station."""
        assert pax not in self._passengers
        self._passengers.append(pax)
        self._all_passengers.append(pax)
        self._pax_times.append((Sim.now(), len(self._passengers)))

    def remove_passenger(self, pax):
        """Remove a passenger from this station, such as when they load into a
        vehicle, or when they storm off in disgust..."""
        self._passengers.remove(pax)
        self._pax_times.append((Sim.now(), len(self._passengers)))

    def get_num_passengers(self):
        return len(self._passengers)
    num_passengers = property(get_num_passengers)

    def get_stored_vehicle_count(self, vehicle_model):
        """Number of vehicles of ``vehicle_model`` held in this station's storage."""
        return 0 + self._storage_dict[vehicle_model].get_stored_vehicle_count()

    def _wait_times(self, passengers):
        """Return the durations of every wait at this station by ``passengers``.

        Each passenger's ``_wait_times`` entries are (start, end, loc); only
        entries whose loc is this station count, and an entry whose end is
        None (still waiting) is measured up to the current simulation time.
        """
        times = []
        for pax in passengers:
            for start, end, loc in pax._wait_times:
                if loc is self:
                    if end is None:
                        times.append(Sim.now() - start)
                    else:
                        times.append(end - start)
        return times

    def all_pax_wait_times(self):
        """Returns a list of wait times for all passengers, not just the current ones."""
        return self._wait_times(self._all_passengers)

    def curr_pax_wait_times(self):
        """Returns a list of wait times for passengers currently waiting in the station."""
        return self._wait_times(self._passengers)

    # The statistics below return 0 when there are no recorded waits.

    def get_min_all_pax_wait(self):
        times = self.all_pax_wait_times()
        return min(times) if times else 0
    min_all_pax_wait = property(get_min_all_pax_wait)

    def get_mean_all_pax_wait(self):
        times = self.all_pax_wait_times()
        return sum(times) / len(times) if times else 0
    mean_all_pax_wait = property(get_mean_all_pax_wait)

    def get_max_all_pax_wait(self):
        times = self.all_pax_wait_times()
        return max(times) if times else 0
    max_all_pax_wait = property(get_max_all_pax_wait)

    def get_min_curr_pax_wait(self):
        times = self.curr_pax_wait_times()
        return min(times) if times else 0
    min_curr_pax_wait = property(get_min_curr_pax_wait)

    def get_mean_curr_pax_wait(self):
        times = self.curr_pax_wait_times()
        return sum(times) / len(times) if times else 0
    mean_curr_pax_wait = property(get_mean_curr_pax_wait)

    def get_max_curr_pax_wait(self):
        times = self.curr_pax_wait_times()
        return max(times) if times else 0
    max_curr_pax_wait = property(get_max_curr_pax_wait)
#
# You should have received a copy of the GNU General Public License
# along with Hyperspy. If not, see <http://www.gnu.org/licenses/>.

import enthought.traits.api as t
import enthought.traits.ui.api as tui
from enthought.traits.ui.menu import OKButton


class Message(t.HasTraits):
    """Trait container for the text shown by ``information``."""
    text = t.Str


# Modal view showing a Message's text read-only, with a single OK button.
information_view = tui.View(
    tui.Item('text',
             show_label=False,
             style='readonly',
             springy=True,
             width=300),
    kind='modal',
    buttons=[OKButton],
)


def information(text):
    """Display ``text`` to the user in a modal dialog with an OK button."""
    dialog = Message(text=text)
    dialog.edit_traits(view=information_view)
class PowerReport(enable.Component):
    """Network-wide power and energy usage, drawn as overlapping Chaco plots."""
    SAMPLE_INTERVAL = 1  # seconds

    v_list = traits.List
    plot_data = traits.Instance(chaco.ArrayPlotData)
    plot_container = traits.Instance(enable.Component)
    plots = traits.Dict

    traits_view = ui.View(
        ui.HGroup(
##            ui.Item(name='v_list', editor=ui.EnumEditor(values=[str(v) for v in self.v_list])),
            ui.Item(name='plot_container', editor=enable.ComponentEditor(),
                    show_label=False)
        ),
        kind='live'
    )

    def __init__(self):
        super(PowerReport, self).__init__(title='Power')

    def update(self):
        """Refresh the cached vehicle list, then rebuild the data and plots."""
        # Check if the locally cached vehicle list has gotten stale.
        if len(self.v_list) != len(common.vehicles):
            self.v_list[:] = common.vehicles.values()
            self.v_list.sort()

        self.plot_data = self.make_plot_data(self.v_list)
        self.plots, self.plot_container = self.make_plots(self.plot_data)

    def make_plot_data(self, v_list):
        """Returns a chaco.ArrayPlotData containing the following:
          v_power -- A 2D array where each row is a vehicle (indexes match
                     self.v_list), and each column is a time point. (KW)
          positive_power / negative_power -- v_power clipped to >= 0 / <= 0.
          positive_total_power, negative_total_power, net_total_power --
                     1D arrays of network-wide power at each time point. (KW)
          energy_array -- cumulative energy per vehicle. (KW-hours)
          total_energy -- cumulative network-wide energy. (KW-hours)
          sample_times -- the time points, SAMPLE_INTERVAL seconds apart.

        Parameters:
          v_list -- a sequence of Vehicle objects, sorted by ID

        Does not support negative velocities.
        """
        end_time = min(common.Sim.now(),
                       common.config_manager.get_sim_end_time())
        sample_times = numpy.arange(0, end_time + self.SAMPLE_INTERVAL,
                                    self.SAMPLE_INTERVAL)
        power_array = numpy.zeros((len(v_list), len(sample_times)),
                                  dtype=numpy.float32)

        air_density = common.air_density
        wind_speed = common.wind_speed
        wind_angle = common.wind_direction  # 0 is blowing FROM the East

        for v_idx, v in enumerate(v_list):
            power_array[v_idx, :] = self._vehicle_power(
                v, sample_times, air_density, wind_speed, wind_angle)

        power_array = numpy.divide(power_array, 1000.0)  # convert from Watts to KW
        positive_power = numpy.clip(power_array, 0, numpy.inf)
        positive_total_power = numpy.sum(positive_power, axis=0)
        negative_power = numpy.clip(power_array, -numpy.inf, 0)
        negative_total_power = numpy.sum(negative_power, axis=0)
        net_total_power = positive_total_power + negative_total_power

        energy_array = numpy.cumsum(power_array, axis=1)
        # KW summed once per sample -> KW-hours (float divisor avoids Py2
        # integer division when SAMPLE_INTERVAL does not divide 3600).
        energy_array = numpy.divide(energy_array,
                                    3600.0 / self.SAMPLE_INTERVAL)
        total_energy_array = numpy.sum(energy_array, axis=0)

        return chaco.ArrayPlotData(
            sample_times=chaco.ArrayDataSource(sample_times,
                                               sort_order="ascending"),
            positive_total_power=chaco.ArrayDataSource(positive_total_power),
            negative_total_power=chaco.ArrayDataSource(negative_total_power),
            net_total_power=chaco.ArrayDataSource(net_total_power),
            total_energy=chaco.ArrayDataSource(total_energy_array),
            v_power=power_array,
            positive_power=positive_power,
            negative_power=negative_power,
            energy_array=energy_array
            )

    def _vehicle_power(self, v, sample_times, air_density, wind_speed,
                       wind_angle):
        """Return net power (Watts) used by vehicle ``v`` at each sample time.

        Samples outside the vehicle spline's time range are left at zero.
        Positive power is divided by the powertrain efficiency; negative
        (regenerated) power is multiplied by the regenerative braking
        efficiency.
        """
        g = 9.80665  # m/s^2
        power = numpy.zeros(len(sample_times))

        masses = v.get_total_masses(sample_times)

        # The sample times may be out of the vehicle spline's valid range,
        # since the vehicle may not have been created at the beginning of
        # the simulation. start_idx is the first sample at or after the
        # spline start; end_idx is one past the last sample at or before
        # its end.
        start_idx = numpy.searchsorted(sample_times, v._spline.t[0],
                                       side='left')
        end_idx = numpy.searchsorted(sample_times, v._spline.t[-1],
                                     side='right')
        if start_idx >= end_idx:
            return power  # vehicle did not exist at any sample time

        v_knots = v._spline.evaluate_sequence(sample_times[start_idx:end_idx])
        knots = [None] * len(sample_times)
        knots[start_idx:end_idx] = v_knots

        CdA = v.frontal_area * v.drag_coefficient
        path = v._path
        path_idx = 0
        path_sum = 0
        loc = path[path_idx]
        last_elevation = loc.get_elevation(v_knots[0].pos)

        for sample_idx, (mass, knot) in enumerate(itertools.izip(masses,
                                                                 knots)):
            if knot is None:
                continue  # outside the spline's range; power stays 0

            # Track where we are on the vehicle's path. Several short
            # segments may be crossed within one sample interval, so keep
            # advancing, but never step past the final segment.
            pos = knot.pos - path_sum
            while pos >= loc.length and path_idx + 1 < len(path):
                path_sum += loc.length
                path_idx += 1
                loc = path[path_idx]
                pos = knot.pos - path_sum

            # Power required to overcome rolling resistance. Ignores effect of
            # track slope and assumes that rolling resistance is constant
            # at different velocities.
            if v.rolling_coefficient:
                rolling_power = v.rolling_coefficient * g * mass * knot.vel  # Force * velocity
            else:
                rolling_power = 0  # Rolling resistance not modelled

            # Power to accelerate / decelerate (change in kinetic energy)
            accel_power = mass * knot.accel * knot.vel

            # Power to overcome aero drag
            if wind_speed and knot.vel != 0:  # No power use when stopped
                travel_angle = loc.get_direction(pos)  # 0 is travelling TOWARDS the East
                incidence_angle = wind_angle - travel_angle
                # cos > 0 is a head wind (adds to airspeed); cos < 0 is a
                # tail wind (subtracts).  The old tail-wind branch had the
                # sign inverted and increased airspeed.
                air_vel = knot.vel + math.cos(incidence_angle) * wind_speed
            else:
                air_vel = knot.vel
            aero_power = 0.5 * air_density * air_vel * air_vel * air_vel * CdA

            # Power from elevation changes: change in potential energy
            # (m * g * dh) over one sample interval.
            elevation = loc.get_elevation(pos)
            delta_elevation = elevation - last_elevation
            elevation_power = mass * g * delta_elevation / self.SAMPLE_INTERVAL
            last_elevation = elevation

            # Adjust power usages by efficiency
            net_power = accel_power + rolling_power + aero_power + elevation_power
            if net_power > 0:
                net_power /= v.powertrain_efficiency  # low efficiency increases power required
            elif net_power < 0:
                net_power *= v.regenerative_braking_efficiency  # low efficiency decreases power recovered

            power[sample_idx] = net_power
        return power

    def make_plots(self, plot_data):
        """Create overlapping power and energy plots from the supplied plot_data.

        Parameters:
          plot_data -- A chaco.ArrayPlotData object. Expected to be created
                       by self.make_plot_data.

        Return:
          A 2-tuple containing:
            - A dict containing plots, keyed by the plot name.
            - A chaco.OverlayPlotContainer containing the plots.
        """
        times_mapper = chaco.LinearMapper(
            range=chaco.DataRange1D(plot_data.get_data('sample_times')))

        graph_colors = {'positive_total_power': 'black',
                        'negative_total_power': 'red',
                        'net_total_power': 'purple',
                        'total_energy': 'green'}

        plots = {}  # Dict of all plots

        # Power graphs
        power_names = ['positive_total_power', 'negative_total_power',
                       'net_total_power']
        power_data_range = chaco.DataRange1D(
            *[plot_data.get_data(name) for name in power_names])
        power_mapper = chaco.LinearMapper(range=power_data_range)
        power_plots = {}
        for plot_name in power_names:
            plot = chaco.LinePlot(index=plot_data.get_data('sample_times'),
                                  value=plot_data.get_data(plot_name),
                                  index_mapper=times_mapper,
                                  value_mapper=power_mapper,
                                  border_visible=False,
                                  bg_color='transparent',
                                  line_style='solid',
                                  color=graph_colors[plot_name],
                                  line_width=2)
            power_plots[plot_name] = plot
            plots[plot_name] = plot

        # Energy graphs -- use a different value scale than power
        energy_plot_names = ['total_energy']
        energy_data_range = chaco.DataRange1D(
            *[plot_data.get_data(name) for name in energy_plot_names])
        energy_mapper = chaco.LinearMapper(range=energy_data_range)
        energy_plots = {}
        for plot_name in energy_plot_names:
            plot = chaco.LinePlot(index=plot_data.get_data('sample_times'),
                                  value=plot_data.get_data(plot_name),
                                  index_mapper=times_mapper,
                                  value_mapper=energy_mapper,
                                  border_visible=False,
                                  bg_color='transparent',  # was misspelled 'transarent'
                                  line_style='solid',
                                  color=graph_colors[plot_name],
                                  line_width=2)
            energy_plots[plot_name] = plot
            plots[plot_name] = plot

        # Blank plot -- Holds the grid and axis, and acts as a placeholder when
        # no other graphs are activated.
        blank_values = chaco.ArrayDataSource(
            numpy.zeros(plot_data.get_data('sample_times').get_size()))
        blank_plot = chaco.LinePlot(index=plot_data.get_data('sample_times'),
                                    value=blank_values,
                                    index_mapper=times_mapper,
                                    value_mapper=power_mapper,
                                    border_visible=True,
                                    bg_color='transparent',
                                    line_width=0)
        # Previously stored the last energy plot under this key by mistake.
        plots['blank_plot'] = blank_plot
        times_axis = chaco.PlotAxis(orientation='bottom',
                                    title="Time (seconds)",
                                    mapper=times_mapper,
                                    component=blank_plot)
        power_axis = chaco.PlotAxis(orientation='left',
                                    title="Power (KW)",
                                    mapper=power_mapper,
                                    component=blank_plot)
        energy_axis = chaco.PlotAxis(orientation='right',
                                     title="Energy (KW-hrs)",
                                     mapper=energy_mapper,
                                     component=blank_plot)
        blank_plot.underlays.append(times_axis)
        blank_plot.underlays.append(power_axis)
        blank_plot.underlays.append(energy_axis)

        # Add zoom capability
        blank_plot.overlays.append(tools.ZoomTool(blank_plot,
                                                  tool_mode='range',
                                                  axis='index',
                                                  always_on=True,
                                                  drag_button='left'))

        plot_container = chaco.OverlayPlotContainer()
        for plot in power_plots.itervalues():
            plot_container.add(plot)
        for plot in energy_plots.itervalues():
            plot_container.add(plot)
        plot_container.add(blank_plot)
        plot_container.padding_left = 60
        plot_container.padding_right = 60
        plot_container.padding_top = 20
        plot_container.padding_bottom = 50

        # Legend
        legend = chaco.Legend(component=plot_container, padding=20,
                              align="ur")
        legend.tools.append(tools.LegendTool(legend, drag_button="right"))
        legend.plots = {}
        legend.plots.update(power_plots)
        legend.plots.update(energy_plots)
        plot_container.overlays.append(legend)

        return plots, plot_container
class EgertonPanel(t.HasTraits):
    """Interactive panel to remove a power-law background and map a signal.

    Two span selectors can be drawn on the spectrum plot of ``SI``:

    * a background window, used to estimate a power law (stored in
      ``self.pl``, a ``components.PowerLaw``) through
      ``utils.two_area_powerlaw_estimation``;
    * a signal window, whose (optionally background-subtracted) integral is
      shown as a map in a separate matplotlib figure.

    The results can be stored in ``interactive_ns`` under user-chosen names.
    """
    define_background_window = t.Bool(False)
    bg_window_size_variation = t.Button()
    background_substracted_spectrum_name = t.Str('signal')
    extract_background = t.Button()

    define_signal_window = t.Bool(False)
    signal_window_size_variation = t.Button()
    signal_name = t.Str('signal')
    extract_signal = t.Button()

    view = tu.View(tu.Group(
        tu.Group('define_background_window',
                 tu.Item('bg_window_size_variation',
                         label='window size effect', show_label=False),
                 tu.Item('background_substracted_spectrum_name'),
                 tu.Item('extract_background', show_label=False),
                 ),
        tu.Group('define_signal_window',
                 tu.Item('signal_window_size_variation',
                         label='window size effect', show_label=False),
                 tu.Item('signal_name', show_label=True),
                 tu.Item('extract_signal', show_label=False)),))

    def __init__(self, SI):
        """
        Parameters
        ----------
        SI : spectrum image
            Object providing ``hse.spectrum_plot.left_ax``, ``energy_axis``,
            ``energy2index``, ``data_cube`` and calibration attributes.
        """
        # HasTraits subclasses must initialise their base class.
        super(EgertonPanel, self).__init__()
        self.SI = SI
        # Background
        self.bg_span_selector = None
        self.pl = components.PowerLaw()
        self.bg_line = None
        self.bg_cube = None
        # Signal
        self.signal_span_selector = None
        self.signal_line = None
        self.signal_map = None
        self.map_ax = None

    def store_current_spectrum_bg_parameters(self, *args, **kwards):
        """Fit the power law in the background window of the current spectrum.

        Stores the fitted ``r`` and ``A`` in ``self.pl`` and refreshes the
        signal map when a signal window is active. Positional/keyword
        arguments coming from the span-selector callback are ignored.
        """
        if (self.define_background_window is False or
                self.bg_span_selector.range is None):
            return
        pars = utils.two_area_powerlaw_estimation(
            self.SI, *self.bg_span_selector.range, only_current_spectrum=True)
        self.pl.r.value = pars['r']
        self.pl.A.value = pars['A']

        if (self.define_signal_window is True and
                self.signal_span_selector.range is not None):
            self.plot_signal_map()

    def _define_background_window_changed(self, old, new):
        """Create the background span selector, or tear it down with its lines."""
        if new is True:
            self.bg_span_selector = drawing.widgets.ModifiableSpanSelector(
                self.SI.hse.spectrum_plot.left_ax,
                onselect=self.store_current_spectrum_bg_parameters,
                onmove_callback=self.plot_bg_removed_spectrum)
        elif self.bg_span_selector is not None:
            # Artist.remove() works on every matplotlib version, whereas
            # mutating ax.lines directly is no longer supported.
            if self.bg_line is not None:
                self.bg_line.remove()
                self.bg_line = None
            if self.signal_line is not None:
                self.signal_line.remove()
                self.signal_line = None
            self.bg_span_selector.turn_off()
            self.bg_span_selector = None

    def _bg_window_size_variation_fired(self):
        """Study how the background fit depends on the window width."""
        if self.define_background_window is False:
            return
        left = self.bg_span_selector.rect.get_x()
        right = left + self.bg_span_selector.rect.get_width()
        energy_window_dependency(self.SI, left, right, min_width=10)

    def _extract_background_fired(self):
        """Store the background-subtracted spectrum in ``interactive_ns``.

        Channels below the right edge of the background window are zeroed.
        Does nothing when no background window has been selected.
        """
        if self.pl is None:
            return
        if self.bg_span_selector is None or self.bg_span_selector.range is None:
            return
        signal = self.SI() - self.pl.function(self.SI.energy_axis)
        i = self.SI.energy2index(self.bg_span_selector.range[1])
        signal[:i] = 0.
        s = Spectrum({'calibration': {'data_cube': signal}})
        s.get_calibration_from(self.SI)
        interactive_ns[self.background_substracted_spectrum_name] = s

    def _define_signal_window_changed(self, old, new):
        """Create (blue) or remove the signal-window span selector."""
        if new is True:
            self.signal_span_selector = drawing.widgets.ModifiableSpanSelector(
                self.SI.hse.spectrum_plot.left_ax,
                onselect=self.store_current_spectrum_bg_parameters,
                onmove_callback=self.plot_signal_map)
            self.signal_span_selector.rect.set_color('blue')
        elif self.signal_span_selector is not None:
            self.signal_span_selector.turn_off()
            self.signal_span_selector = None

    def plot_bg_removed_spectrum(self, *args, **kwards):
        """Overlay the fitted background and the background-removed signal."""
        if self.bg_span_selector.range is None:
            return
        self.store_current_spectrum_bg_parameters()
        ileft = self.SI.energy2index(self.bg_span_selector.range[0])
        iright = self.SI.energy2index(self.bg_span_selector.range[1])
        ea = self.SI.energy_axis[ileft:]
        if self.bg_line is not None:
            self.bg_line.remove()
            if self.signal_line is not None:
                self.signal_line.remove()
        ax = self.SI.hse.spectrum_plot.left_ax
        self.bg_line, = ax.plot(ea, self.pl.function(ea), color='black')
        self.signal_line, = ax.plot(
            self.SI.energy_axis[iright:],
            self.SI()[iright:] - self.pl.function(self.SI.energy_axis[iright:]),
            color='black')
        ax.figure.canvas.draw()

    def plot_signal_map(self, *args, **kwargs):
        """Integrate the signal window over energy and display it as a map.

        When a background window is also defined and selected, a power law
        estimated for every pixel is subtracted before integration.
        """
        if not (self.define_signal_window is True and
                self.signal_span_selector.range is not None):
            return
        ileft = self.SI.energy2index(self.signal_span_selector.range[0])
        iright = self.SI.energy2index(self.signal_span_selector.range[1])
        signal_sp = self.SI.data_cube[ileft:iright, ...].squeeze().copy()
        # Only subtract a background when its window actually has a range;
        # unpacking a None range would raise a TypeError.
        if (self.define_background_window is True and
                self.bg_span_selector is not None and
                self.bg_span_selector.range is not None):
            pars = utils.two_area_powerlaw_estimation(
                self.SI, *self.bg_span_selector.range,
                only_current_spectrum=False)
            x = self.SI.energy_axis[ileft:iright, np.newaxis, np.newaxis]
            A = pars['A'][np.newaxis, ...]
            r = pars['r'][np.newaxis, ...]
            self.bg_sp = (A * x ** (-r)).squeeze()
            signal_sp -= self.bg_sp
        self.signal_map = signal_sp.sum(0)
        is_image = len(self.signal_map.squeeze().shape) == 2
        if self.map_ax is None:
            f = plt.figure()
            self.map_ax = f.add_subplot(111)
            if is_image:
                self.map = self.map_ax.imshow(self.signal_map.T,
                                              interpolation='nearest')
            else:
                self.map, = self.map_ax.plot(self.signal_map.squeeze())
        if is_image:
            self.map.set_data(self.signal_map.T)
            self.map.autoscale()
        else:
            self.map.set_ydata(self.signal_map.squeeze())
        self.map_ax.figure.canvas.draw()

    def _extract_signal_fired(self):
        """Store the current signal map in ``interactive_ns`` as Image/Spectrum."""
        if self.signal_map is None:
            return
        data = self.signal_map.squeeze()
        if len(data.shape) == 2:
            s = Image({'calibration': {'data_cube': data}})
            s.xscale = self.SI.xscale
            s.yscale = self.SI.yscale
            s.xunits = self.SI.xunits
            s.yunits = self.SI.yunits
        else:
            s = Spectrum({'calibration': {'data_cube': data}})
            s.energyscale = self.SI.xscale
            s.energyunits = self.SI.xunits
        interactive_ns[self.signal_name] = s
class CameraUI(traits.HasTraits):
    """Camera settings defines basic camera settings
    """
    camera_control = traits.Instance(Camera, transient=True)

    cameras = traits.List([_NO_CAMERAS], transient=True)
    camera = traits.Any(value=_NO_CAMERAS, desc='camera serial number',
                        editor=ui.EnumEditor(name='cameras'))

    search = traits.Button(desc='camera search action')

    _is_initialized = traits.Bool(False, transient=True)

    play = traits.Button(desc='display preview action')
    stop = traits.Button(desc='close preview action')
    on_off = traits.Button('On/Off', desc='initiate/Uninitiate camera action')

    gain = create_range_feature('gain', desc='camera gain', transient=True)
    shutter = create_range_feature('shutter', desc='camera exposure time',
                                   transient=True)
    format = create_mapped_feature('format', _FORMAT, desc='image format',
                                   transient=True)

    roi = traits.Instance(ROI, transient=True)

    im_shape = traits.Property(depends_on='format.value,roi.values')
    im_dtype = traits.Property(depends_on='format.value')

    capture = traits.Button()
    save_button = traits.Button('Save as...')

    message = traits.Str(transient=True)

    # The enabled_when expressions must name the real trait
    # (``_is_initialized``); the former 'is_initialized' is undefined.
    view = ui.View(
        ui.Group(
            ui.HGroup(
                ui.Item('camera', springy=True),
                ui.Item('search', show_label=False, springy=True),
                ui.Item('on_off', show_label=False, springy=True),
                ui.Item('play', show_label=False,
                        enabled_when='object._is_initialized', springy=True),
                ui.Item('stop', show_label=False,
                        enabled_when='object._is_initialized', springy=True),
            ),
            ui.Group(
                ui.Item('gain', style='custom'),
                ui.Item('shutter', style='custom'),
                ui.Item('format', style='custom'),
                ui.Item('roi', style='custom'),
                ui.HGroup(ui.Item('capture', show_label=False),
                          ui.Item('save_button', show_label=False)),
                enabled_when='object._is_initialized',
            ),
        ),
        resizable=True,
        statusbar=[ui.StatusItem(name='message')],
        buttons=['OK'])

    def __init__(self, **kw):
        """Create the UI and immediately search for connected cameras."""
        super(CameraUI, self).__init__(**kw)
        self.search_cameras()

    def _camera_control_default(self):
        return Camera()

    def _roi_default(self):
        return ROI()

    def _get_im_shape(self):
        """Return (height, width) or (height, width, colors) of a frame.

        Raises NotImplementedError if the current format is not in _COLORS.
        """
        top, left, width, height = self.roi.values
        shape = (height, width)
        try:
            colors = _COLORS[self.format.value]
        except KeyError:
            raise NotImplementedError('Unsupported format')
        if colors > 1:
            shape += (colors,)
        return shape

    def _get_im_dtype(self):
        """Return the numpy dtype matching the current format."""
        try:
            return _DTYPE[self.format.value]
        except KeyError:
            raise NotImplementedError('Unsupported format')

    def _search_fired(self):
        self.search_cameras()

    def search_cameras(self):
        """ Finds cameras if any and selects first from list """
        cameras = []
        try:
            cameras = get_number_cameras()
        finally:
            # Runs on success and on failure (a bare re-raise keeps the
            # original traceback), so the UI always has a valid selection.
            if len(cameras) == 0:
                cameras = [_NO_CAMERAS]
            self.cameras = cameras
            self.camera = cameras[0]

    def _close_camera(self):
        """Close an initialized camera and report it in the status bar."""
        self._is_initialized = False
        self.camera_control.close()
        self.message = 'Camera uninitialized'

    def _camera_changed(self):
        if self._is_initialized:
            self._close_camera()

    def init_camera(self):
        """Open the selected camera (if any) and read its feature values."""
        self._is_initialized = False
        if self.camera != _NO_CAMERAS:
            self.camera_control.init(self.camera)
            self.init_features()
            self._is_initialized = True
            self.message = 'Camera initialized'

    def _on_off_fired(self):
        if self._is_initialized:
            self._close_camera()
        else:
            self.init_camera()

    def init_features(self):
        """ Initializes all features to values given by the camera """
        features = self.camera_control.get_camera_features()
        self._init_single_valued_features(features)
        self._init_roi(features)

    def _init_single_valued_features(self, features):
        """ Initializes all single valued features to camera values """
        for name, feature_id in _SINGLE_VALUED_FEATURES.items():
            feature = getattr(self, name)
            feature.low, feature.high = features[feature_id]['params'][0]
            feature.value = self.camera_control.get_feature(feature_id)[0]

    def _init_roi(self, features):
        """Copy the camera's ROI bounds and values into ``self.roi``."""
        # Loop invariants: callers hold _is_initialized False, so no ROI is
        # written to the camera while the loop runs.
        roi_params = features[FEATURE_ROI]['params']
        roi_values = self.camera_control.get_feature(FEATURE_ROI)
        for i, name in enumerate(('top', 'left', 'width', 'height')):
            feature = getattr(self.roi, name)
            low, high = roi_params[i]
            try:
                feature.value = roi_values[i]
            finally:
                feature.low, feature.high = low, high

    @traits.on_trait_change('format.value')
    def _on_format_change(self, obj, name, value):
        if self._is_initialized:
            self.camera_control.set_preview_state(STOP_PREVIEW)
            self.camera_control.set_stream_state(STOP_STREAM)
            self.set_feature(FEATURE_PIXEL_FORMAT, [value])

    @traits.on_trait_change('gain.value,shutter.value')
    def _single_valued_feature_changed(self, obj, name, value):
        if self._is_initialized:
            self.set_feature(obj.id, [value])

    def set_feature(self, id, values, flags=2):
        self.camera_control.set_feature(id, values, flags=flags)

    @traits.on_trait_change('roi.values')
    def a_roi_feature_changed(self, obj, name, value):
        """Push a new ROI to the camera, then re-read what it accepted."""
        if self._is_initialized:
            self.set_feature(FEATURE_ROI, value)
            try:
                # Suppress feature handlers while values are re-read.
                self._is_initialized = False
                self.init_features()
            finally:
                self._is_initialized = True

    def _play_fired(self):
        self.camera_control.set_preview_state(STOP_PREVIEW)
        self.camera_control.set_stream_state(STOP_STREAM)
        self.camera_control.set_stream_state(START_STREAM)
        self.camera_control.set_preview_state(START_PREVIEW)

    def _stop_fired(self):
        self.camera_control.set_preview_state(STOP_PREVIEW)
        self.camera_control.set_stream_state(STOP_STREAM)
        # NOTE(review): 'error' is not a declared trait; possibly meant
        # 'message'. Kept as-is.
        self.error = ''

    def _format_changed(self, value):
        # NOTE(review): static handler for replacement of the ``format``
        # object itself; value changes are handled by _on_format_change.
        self.camera_control.set_preview_state(STOP_PREVIEW)
        self.camera_control.set_stream_state(STOP_STREAM)
        self.camera_control.set_feature(FEATURE_PIXEL_FORMAT, [value], 2)

    def _capture_fired(self):
        self.camera_control.set_stream_state(STOP_STREAM)
        self.camera_control.set_stream_state(START_STREAM)
        im = self.capture_image()
        plt.imshow(im)
        plt.show()

    def capture_image(self):
        """Grab the next frame and return it viewed as big-endian data."""
        im = numpy.empty(shape=self.im_shape, dtype=self.im_dtype)
        self.camera_control.get_next_frame(im)
        # ndarray.newbyteorder was removed in NumPy 2; this is its
        # documented equivalent.
        return im.view(im.dtype.newbyteorder('>'))

    def save_image(self, fname):
        """Captures image and saves to format guessed from filename extension"""
        im = self.capture_image()
        ext = os.path.splitext(fname)[1]
        if ext.lower() == '.npy':
            numpy.save(fname, im)
        else:
            im = toimage(im)
            im.save(fname)

    def _save_button_fired(self):
        f = pyface.FileDialog(action='save as')
        if f.open() == pyface.OK:
            self.save_image(f.path)

    def capture_HDR(self):
        """Placeholder: HDR capture is not implemented."""
        pass

    def __del__(self):
        # Best-effort shutdown; failures here must never propagate.
        try:
            self.camera_control.set_preview_state(STOP_PREVIEW)
            self.camera_control.set_stream_state(STOP_STREAM)
        except Exception:
            pass
def make_view(self):
    """Build the traits popup window describing this station.

    The window contains a table of waiting passengers and a read-only list
    showing the station queue; it is titled with ``self.label``.
    """
    passenger_columns = [
        ui_tc.ObjectColumn(name='label', label='Name'),
        ui_tc.ObjectColumn(name='_start_time'),
        ui_tc.ObjectColumn(name='dest_station', label='Destination'),
        ui_tc.ObjectColumn(name='wait_time', label='Waiting (sec)',
                           format="%.2f"),
        ui_tc.ObjectColumn(name='will_share', label='Will Share'),
        ui_tc.ObjectColumn(name='load_delay', label='Time to Board (sec)'),
    ]
    # Only the passenger data relevant when looking at a station.
    passenger_editor = ui.TableEditor(
        columns=passenger_columns,
        deletable=True,
        auto_size=True,
        orientation='vertical',
        show_toolbar=True,
        # NOTE: unclear whether reordering rows affects the actual
        # boarding order (original author believed it does not).
        reorderable=True,
        rows=5,
        row_factory=events.Passenger,
    )

    waiting_group = ui.Group(
        ui.Label('Waiting Passengers'),
        ui.Item(name='passengers', show_label=False, editor=passenger_editor),
        show_border=True,
    )
    queue_group = ui.Group(
        ui.Label('Queue'),
        ui.Item(name='queue',
                show_label=False,
                editor=ui.ListEditor(editor=ui.TextEditor()),
                style='readonly'),
        show_border=True,
    )

    return ui.View(ui.VGroup(waiting_group, queue_group),
                   title=self.label,
                   resizable=True,
                   height=700,
                   width=470)
class Signal(t.HasTraits, MVA):
    """Container for an n-dimensional dataset, its axes and its metadata."""
    data = t.Any()
    axes_manager = t.Instance(AxesManager)
    original_parameters = t.Instance(Parameters)
    mapped_parameters = t.Instance(Parameters)
    physical_property = t.Str()

    traits_view = tui.View(
        tui.Item('name'),
        tui.Item('physical_property'),
        tui.Item('units'),
        tui.Item('offset'),
        tui.Item('scale'),
    )

    def __init__(self, file_data_dict=None, *args, **kw):
        """All data interaction is made through this class or its subclasses

        Parameters:
        -----------
        file_data_dict : dictionary or None
            see load_dictionary for the format. Anything that is not a dict
            is ignored (no data is loaded).
        *args, **kw :
            Accepted for signature compatibility; currently unused.
        """
        super(Signal, self).__init__()
        self.mapped_parameters = Parameters()
        self.original_parameters = Parameters()
        if isinstance(file_data_dict, dict):
            self.load_dictionary(file_data_dict)
        self._plot = None
        self.mva_results = MVA_Results()
        self._shape_before_unfolding = None
        self._axes_manager_before_unfolding = None

    def load_dictionary(self, file_data_dict):
        """Parameters:
        -----------
        file_data_dict : dictionary
            A dictionary containing at least a 'data' keyword with an array of
            arbitrary dimensions. Additionally the dictionary can contain the
            following keys:
                axes: a dictionary that defines the axes (see the
                    AxesManager class). Undefined axes are created if absent.
                attributes: a dictionary which keywords are stored as
                    attributes of the signal class
                mapped_parameters: a dictionary containing a set of parameters
                    that will be stored as attributes of a Parameters class.
                    For some subclasses some particular parameters might be
                    mandatory.
                original_parameters: a dictionary that will be accessible in
                    the original_parameters attribute of the signal class and
                    that typically contains all the parameters that has been
                    imported from the original data file.

        The caller's dictionary is not modified.
        """
        self.data = file_data_dict['data']
        if 'axes' in file_data_dict:
            axes = file_data_dict['axes']
        else:
            axes = self._get_undefined_axes_list()
        self.axes_manager = AxesManager(axes)
        if 'attributes' in file_data_dict:
            for key, value in file_data_dict['attributes'].items():
                setattr(self, key, value)
        self.original_parameters.load_dictionary(
            file_data_dict.get('original_parameters', {}))
        self.mapped_parameters.load_dictionary(
            file_data_dict.get('mapped_parameters', {}))

    def _get_signal_dict(self):
        """Return a dictionary (as accepted by load_dictionary) of this signal."""
        dic = {}
        dic['data'] = self.data.copy()
        dic['axes'] = self.axes_manager._get_axes_dicts()
        dic['mapped_parameters'] = \
            self.mapped_parameters._get_parameters_dictionary()
        dic['original_parameters'] = \
            self.original_parameters._get_parameters_dictionary()
        return dic

    def _get_undefined_axes_list(self):
        """Return one default axis dictionary per dimension of self.data."""
        return [{'name': 'undefined',
                 'scale': 1.,
                 'offset': 0.,
                 'size': int(size),
                 'units': 'undefined',
                 'index_in_array': i}
                for i, size in enumerate(self.data.shape)]

    def __call__(self, axes_manager=None):
        """Return the data at the current navigation position."""
        if axes_manager is None:
            axes_manager = self.axes_manager
        # NumPy requires a tuple (not a list) for multi-axis indexing.
        return self.data[tuple(axes_manager._getitem_tuple)]

    def _get_hse_1D_explorer(self, *args, **kwargs):
        islice = self.axes_manager._slicing_axes[0].index_in_array
        inslice = self.axes_manager._non_slicing_axes[0].index_in_array
        if islice > inslice:
            return self.data.squeeze()
        else:
            return self.data.squeeze().T

    def _get_hse_2D_explorer(self, *args, **kwargs):
        islice = self.axes_manager._slicing_axes[0].index_in_array
        return self.data.sum(islice)

    def _get_hie_explorer(self, *args, **kwargs):
        isslice = sorted([self.axes_manager._slicing_axes[0].index_in_array,
                          self.axes_manager._slicing_axes[1].index_in_array])
        return self.data.sum(isslice[1]).sum(isslice[0])

    def _get_explorer(self, *args, **kwargs):
        """Return the navigator image for the current view, or None."""
        nav_dim = self.axes_manager.navigation_dimension
        sig_dim = self.axes_manager.signal_dimension
        if sig_dim == 1:
            if nav_dim == 1:
                return self._get_hse_1D_explorer(*args, **kwargs)
            if nav_dim == 2:
                return self._get_hse_2D_explorer(*args, **kwargs)
            return None
        if sig_dim == 2 and nav_dim in (1, 2):
            return self._get_hie_explorer(*args, **kwargs)
        return None

    def plot(self, axes_manager=None):
        """Plot the signal (1D: hyperspectrum explorer, 2D: image explorer)."""
        if self._plot is not None:
            try:
                self._plot.close()
            except Exception:
                # If it was already closed it will raise an exception,
                # but we want to carry on...
                pass

        if axes_manager is None:
            axes_manager = self.axes_manager

        if axes_manager.signal_dimension == 1:
            # Hyperspectrum. Labels and axis come from the same axes
            # manager that selected this branch.
            signal_axis = axes_manager._slicing_axes[0]
            self._plot = mpl_hse.MPL_HyperSpectrum_Explorer()
            self._plot.spectrum_data_function = self.__call__
            self._plot.spectrum_title = self.mapped_parameters.name
            self._plot.xlabel = '%s (%s)' % (signal_axis.name,
                                             signal_axis.units)
            self._plot.ylabel = 'Intensity'
            self._plot.axes_manager = axes_manager
            self._plot.axis = signal_axis.axis

            # Image properties
            if axes_manager._non_slicing_axes:
                navigation_axis = axes_manager._non_slicing_axes[0]
                self._plot.image_data_function = self._get_explorer
                self._plot.image_title = ''
                self._plot.pixel_size = navigation_axis.scale
                self._plot.pixel_units = navigation_axis.units
            self._plot.plot()

        elif axes_manager.signal_dimension == 2:
            self._plot = mpl_hie.MPL_HyperImage_Explorer()
            self._plot.image_data_function = self.__call__
            self._plot.navigator_data_function = self._get_explorer
            self._plot.axes_manager = axes_manager
            self._plot.plot()

        else:
            messages.warning_exit('Plotting is not supported for this view')

    def plot_residual(self, axes_manager=None):
        """Plot the residual between original data and reconstructed data

        Requires you to have already run PCA or ICA, and to reconstruct data
        using either the pca_build_SI or ica_build_SI methods.
        """
        if hasattr(self, 'residual'):
            self.residual.plot(axes_manager)
        else:
            print("Object does not have any residual information. Is it a "
                  "reconstruction created using either pca_build_SI or "
                  "ica_build_SI methods?")

    def save(self, filename, only_view=False, **kwds):
        """Saves the signal in the specified format.

        The function gets the format from the extension. You can use:
            - hdf5 for HDF5
            - nc for NetCDF
            - msa for EMSA/MSA single spectrum saving.
            - bin to produce a raw binary file
            - Many image formats such as png, tiff, jpeg...

        Please note that not all the formats supports saving datasets of
        arbitrary dimensions, e.g. msa only suports 1D data.

        Parameters
        ----------
        filename : str
        msa_format : {'Y', 'XY'}
            'Y' will produce a file without the energy axis. 'XY' will also
            save another column with the energy axis. For compatibility with
            Gatan Digital Micrograph 'Y' is the default.
        only_view : bool
            If True, only the current view will be saved. Otherwise the full
            dataset is saved. NOTE: this flag is currently not forwarded to
            io.save.
        """
        io.save(filename, self, **kwds)

    def _replot(self):
        """Redraw the signal if its plot window is open."""
        if self._plot is not None and self._plot.is_active() is True:
            self.plot()

    def get_dimensions_from_data(self):
        """Get the dimension parameters from the data_cube. Useful when the
        data_cube was externally modified, or when the SI was not loaded from
        a file
        """
        dc = self.data
        for axis in self.axes_manager.axes:
            axis.size = int(dc.shape[axis.index_in_array])
            print("%s size: %i" % (axis.name, dc.shape[axis.index_in_array]))
        self._replot()

    def crop_in_pixels(self, axis, i1=None, i2=None):
        """Crops the data in a given axis. The range is given in pixels

        axis : int
        i1 : int or None
            Start index
        i2 : int or None
            End index (exclusive)

        See also:
        ---------
        crop_in_units
        """
        axis = self._get_positive_axis_index_index(axis)
        if i1 is not None:
            new_offset = self.axes_manager.axes[axis].axis[i1]
        # We take a copy to guarantee the continuity of the data
        self.data = self.data[
            (slice(None),) * axis + (slice(i1, i2), Ellipsis)].copy()

        if i1 is not None:
            self.axes_manager.axes[axis].offset = new_offset
        self.get_dimensions_from_data()

    def crop_in_units(self, axis, x1=None, x2=None):
        """Crops the data in a given axis. The range is given in the units of
        the axis

        axis : int
        x1 : float or None
            Start value, in axis units. None keeps the beginning.
        x2 : float or None
            End value, in axis units. None keeps the end.

        See also:
        ---------
        crop_in_pixels
        """
        data_axis = self.axes_manager.axes[axis]
        i1 = None if x1 is None else data_axis.value2index(x1)
        i2 = None if x2 is None else data_axis.value2index(x2)
        self.crop_in_pixels(axis, i1, i2)

    def roll_xy(self, n_x, n_y=1):
        """Roll over the x axis n_x positions and n_y positions the former rows

        This method has the purpose of "fixing" a bug in the acquisition of the
        Orsay's microscopes and probably it does not have general interest

        Parameters
        ----------
        n_x : int
        n_y : int

        Note: Useful to correct the SI column storing bug in Marcel's
        acquisition routines.
        """
        self.data = np.roll(self.data, n_x, 0)
        self.data[:n_x, ...] = np.roll(self.data[:n_x, ...], n_y, 1)
        self._replot()

    # TODO: After using this function the plotting does not work
    def swap_axis(self, axis1, axis2):
        """Swaps the axes

        Parameters
        ----------
        axis1 : positive int
        axis2 : positive int
        """
        self.data = self.data.swapaxes(axis1, axis2)
        c1 = self.axes_manager.axes[axis1]
        c2 = self.axes_manager.axes[axis2]
        c1.index_in_array, c2.index_in_array = \
            c2.index_in_array, c1.index_in_array
        self.axes_manager.axes[axis1] = c2
        self.axes_manager.axes[axis2] = c1
        self.axes_manager.set_signal_dimension()
        self._replot()

    def rebin(self, new_shape):
        """ Rebins the data to the new shape

        Parameters
        ----------
        new_shape: tuple of ints
            The new shape must be a divisor of the original shape
        """
        factors = np.array(self.data.shape) / np.array(new_shape)
        self.data = utils.rebin(self.data, new_shape)
        for axis in self.axes_manager.axes:
            axis.scale *= factors[axis.index_in_array]
        self.get_dimensions_from_data()

    def split_in(self, axis, number_of_parts=None, steps=None):
        """Splits the data

        The split can be defined either by the `number_of_parts` or by the
        `steps` size.

        Parameters
        ----------
        number_of_parts : int or None
            Number of parts in which the SI will be splitted
        steps : sequence of int or None
            Size of the splitted parts
        axis : int
            The splitting axis

        Return
        ------
        list with the splitted signals
        """
        axis = self._get_positive_axis_index_index(axis)
        if number_of_parts is None and steps is None:
            # _splitting_steps is an optional attribute (e.g. set through
            # load_dictionary 'attributes'); it may not exist at all.
            splitting_steps = getattr(self, '_splitting_steps', None)
            if not splitting_steps:
                messages.warning_exit(
                    "Please provide either number_of_parts or a steps list")
            else:
                steps = splitting_steps
                print("Splitting in  %s" % (steps,))
        elif number_of_parts is not None and steps is not None:
            print("Using the given steps list. number_of_parts dismissed")
        splitted = []
        shape = self.data.shape
        if steps is None:
            rounded = shape[axis] - (shape[axis] % number_of_parts)
            # Integer division: range() needs an int step on Python 3.
            step = rounded // number_of_parts
            cut_node = range(0, rounded + step, step)
        else:
            cut_node = np.array([0] + list(steps)).cumsum()
        for start, stop in zip(cut_node[:-1], cut_node[1:]):
            data = self.data[
                (slice(None),) * axis + (slice(start, stop), Ellipsis)]
            s = Signal({'data': data})
            # TODO: When copying plotting does not work
            # s.axes = copy.deepcopy(self.axes_manager)
            s.get_dimensions_from_data()
            splitted.append(s)
        return splitted

    def unfold_if_multidim(self):
        """Unfold the datacube if it is >2D

        Returns
        -------

        Boolean. True if the data was unfolded by the function.
        """
        if len(self.axes_manager.axes) > 2:
            print("Automatically unfolding the data")
            self.unfold()
            return True
        else:
            return False

    def _unfold(self, steady_axes, unfolded_axis):
        """Modify the shape of the data by specifying the axes the axes which
        dimension do not change and the axis over which the remaining axes will
        be unfolded

        Parameters
        ----------
        steady_axes : list
            The indexes of the axes which dimensions do not change
        unfolded_axis : int
            The index of the axis over which all the rest of the axes (except
            the steady axes) will be unfolded

        Returns False (and does nothing) when the squeezed data has fewer
        than 3 dimensions.

        See also
        --------
        fold
        """
        # It doesn't make sense unfolding when dim < 3
        if len(self.data.squeeze().shape) < 3:
            return False

        # We need to store the original shape and coordinates to be used by
        # the fold function only if it has not been already stored by a
        # previous unfold
        if self._shape_before_unfolding is None:
            self._shape_before_unfolding = self.data.shape
            self._axes_manager_before_unfolding = self.axes_manager

        new_shape = [1] * len(self.data.shape)
        for index in steady_axes:
            new_shape[index] = self.data.shape[index]
        new_shape[unfolded_axis] = -1
        self.data = self.data.reshape(new_shape)
        self.axes_manager = self.axes_manager.deepcopy()
        i = 0
        uname = ''
        uunits = ''
        to_remove = []
        for axis, dim in zip(self.axes_manager.axes, new_shape):
            if dim == 1:
                # Merged axes: accumulate their names and units on the
                # unfolded axis.
                uname += ',' + axis.name
                uunits += ',' + axis.units
                to_remove.append(axis)
            else:
                axis.index_in_array = i
                i += 1
        unfolded = self.axes_manager.axes[unfolded_axis]
        unfolded.name += uname
        unfolded.units += uunits
        unfolded.size = self.data.shape[unfolded_axis]
        for axis in to_remove:
            self.axes_manager.axes.remove(axis)

        self.data = self.data.squeeze()
        self._replot()

    def unfold(self):
        """Modifies the shape of the data by unfolding the signal and
        navigation dimensions separately
        """
        self.unfold_navigation_space()
        self.unfold_signal_space()

    def unfold_navigation_space(self):
        """Modify the shape of the data to obtain a navigation space of
        dimension 1
        """
        if self.axes_manager.navigation_dimension < 2:
            messages.information('Nothing done, the navigation dimension was '
                                 'already 1')
            return False
        steady_axes = [axis.index_in_array
                       for axis in self.axes_manager._slicing_axes]
        unfolded_axis = self.axes_manager._non_slicing_axes[-1].index_in_array
        self._unfold(steady_axes, unfolded_axis)

    def unfold_signal_space(self):
        """Modify the shape of the data to obtain a signal space of
        dimension 1
        """
        if self.axes_manager.signal_dimension < 2:
            messages.information('Nothing done, the signal dimension was '
                                 'already 1')
            return False
        steady_axes = [axis.index_in_array
                       for axis in self.axes_manager._non_slicing_axes]
        unfolded_axis = self.axes_manager._slicing_axes[-1].index_in_array
        self._unfold(steady_axes, unfolded_axis)

    def fold(self):
        """If the signal was previously unfolded, folds it back"""
        if self._shape_before_unfolding is not None:
            self.data = self.data.reshape(self._shape_before_unfolding)
            self.axes_manager = self._axes_manager_before_unfolding
            self._shape_before_unfolding = None
            self._axes_manager_before_unfolding = None
            self._replot()

    def _get_positive_axis_index_index(self, axis):
        """Convert a negative axis index into its positive equivalent."""
        if axis < 0:
            axis = len(self.data.shape) + axis
        return axis

    def iterate_axis(self, axis=-1):
        """Yield successive 1D slices of the data along `axis`."""
        # We make a copy to guarantee that the data in contiguous, otherwise
        # it will not return a view of the data
        self.data = self.data.copy()
        axis = self._get_positive_axis_index_index(axis)
        unfolded_axis = axis - 1
        new_shape = [1] * len(self.data.shape)
        new_shape[axis] = self.data.shape[axis]
        new_shape[unfolded_axis] = -1
        # Warning! if the data is not contigous it will make a copy!!
        data = self.data.reshape(new_shape)
        for i in range(data.shape[unfolded_axis]):
            getitem = [0] * len(data.shape)
            getitem[axis] = slice(None)
            getitem[unfolded_axis] = i
            # NumPy requires a tuple for multi-axis indexing.
            yield data[tuple(getitem)]

    def _reduce_axis(self, reducer_name, axis, return_signal):
        """Apply ndarray method `reducer_name` over `axis` and drop that axis.

        Shared implementation of sum and mean. The axis is normalised to a
        positive index first so that the remaining axes' index_in_array are
        renumbered correctly for negative inputs.
        """
        axis = self._get_positive_axis_index_index(axis)
        if return_signal is True:
            s = self.deepcopy()
        else:
            s = self
        s.data = getattr(s.data, reducer_name)(axis)
        s.axes_manager.axes.remove(s.axes_manager.axes[axis])
        for _axis in s.axes_manager.axes:
            if _axis.index_in_array > axis:
                _axis.index_in_array -= 1
        s.axes_manager.set_signal_dimension()
        if return_signal is True:
            return s

    def sum(self, axis, return_signal=False):
        """Sum the data over the specified axis

        Parameters
        ----------
        axis : int
            The axis over which the operation will be performed. Negative
            values count from the last axis.
        return_signal : bool
            If False the operation will be performed on the current object. If
            True, the current object will not be modified and the operation
             will be performed in a new signal object that will be returned.

        Returns
        -------
        Depending on the value of the return_signal keyword, nothing or a
        signal instance

        See also
        --------
        mean

        Usage
        -----
        >>> import numpy as np
        >>> s = Signal({'data' : np.random.random((64,64,1024))})
        >>> s.data.shape
        (64, 64, 1024)
        >>> s.sum(-1)
        >>> s.data.shape
        (64, 64)
        # If we just want to plot the result of the operation
        s.sum(-1, True).plot()
        """
        return self._reduce_axis('sum', axis, return_signal)

    def mean(self, axis, return_signal=False):
        """Average the data over the specified axis

        Parameters
        ----------
        axis : int
            The axis over which the operation will be performed. Negative
            values count from the last axis.
        return_signal : bool
            If False the operation will be performed on the current object. If
            True, the current object will not be modified and the operation
            will be performed in a new signal object that will be returned.

        Returns
        -------
        Depending on the value of the return_signal keyword, nothing or a
        signal instance

        See also
        --------
        sum

        Usage
        -----
        >>> import numpy as np
        >>> s = Signal({'data' : np.random.random((64,64,1024))})
        >>> s.data.shape
        (64, 64, 1024)
        >>> s.mean(-1)
        >>> s.data.shape
        (64, 64)
        # If we just want to plot the result of the operation
        s.mean(-1, True).plot()
        """
        return self._reduce_axis('mean', axis, return_signal)

    def copy(self):
        return copy.copy(self)

    def deepcopy(self):
        return copy.deepcopy(self)

    def _correct_navigation_mask_when_unfolded(self, navigation_mask=None,):
        """Flatten a navigation mask to 1D (for use with unfolded data)."""
        if navigation_mask is not None:
            navigation_mask = navigation_mask.reshape((-1,))
        return navigation_mask
class AxesManager(t.HasTraits):
    """Manage the DataAxis objects of a dataset.

    Each axis is either *slicing* (its ``slice`` is not None; it is kept
    whole when indexing the data) or *non-slicing* (it is addressed through
    its integer ``index``).  The split and the derived attributes are
    recomputed by ``set_signal_dimension`` whenever any axis' ``slice`` or
    ``index`` trait changes.
    """
    axes = t.List(DataAxis)
    _slicing_axes = t.List()
    _non_slicing_axes = t.List()
    _step = t.Int(1)

    def __init__(self, axes_list):
        """Build the manager from a list of axis dictionaries.

        Parameters
        ----------
        axes_list : list of dict
            Keyword arguments for ``DataAxis``.  Each dict must contain an
            ``'index_in_array'`` key that gives the axis position in
            ``self.axes``.
        """
        super(AxesManager, self).__init__()
        self.axes = [None] * len(axes_list)
        for axis_dict in axes_list:
            self.axes[axis_dict['index_in_array']] = DataAxis(**axis_dict)
        slices = [axis.slice_bool for axis in self.axes
                  if hasattr(axis, 'slice_bool')]
        # set_view is called only if there is no current view
        # (no axis is flagged as slicing).
        if not any(slices):
            self.set_view()
        self.set_signal_dimension()
        self.on_trait_change(self.set_signal_dimension, 'axes.slice')
        self.on_trait_change(self.set_signal_dimension, 'axes.index')

    def set_signal_dimension(self):
        """Recompute the slicing/non-slicing split of the axes.

        Updates ``_getitem_tuple`` (one entry per axis: the slice for
        slicing axes, the index otherwise), ``_indexes`` and ``_values`` of
        the non-slicing axes, ``signal_dimension``, ``navigation_dimension``
        and ``navigation_shape``.
        """
        getitem_tuple = []
        indexes = []
        values = []
        self._slicing_axes = []
        self._non_slicing_axes = []
        for axis in self.axes:
            if axis.slice is None:
                getitem_tuple.append(axis.index)
                indexes.append(axis.index)
                values.append(axis.value)
                self._non_slicing_axes.append(axis)
            else:
                getitem_tuple.append(axis.slice)
                self._slicing_axes.append(axis)
        self._getitem_tuple = getitem_tuple
        self._indexes = np.array(indexes)
        self._values = np.array(values)
        self.signal_dimension = len(self._slicing_axes)
        self.navigation_dimension = len(self._non_slicing_axes)
        self.navigation_shape = [axis.size for axis in self._non_slicing_axes]

    def set_not_slicing_indexes(self, nsi):
        """Assign the indexes in ``nsi`` to the axes, in ``self.axes`` order.

        NOTE(review): this pairs ``nsi`` with *all* axes, not only the
        non-slicing ones that the name suggests -- confirm with callers.
        """
        for index, axis in zip(nsi, self.axes):
            axis.index = index

    def set_view(self, view='hyperspectrum'):
        """view : 'hyperspectrum' or 'image'

        'hyperspectrum' flags one axis as slicing, 'image' flags two; any
        other value flags none.  Flags are consumed from the end of the flag
        list, so the flagged axes are the *last* ones in ``self.axes``.
        """
        tl = [False] * len(self.axes)
        if view == 'hyperspectrum':
            # We limit the signal_dimension to 1 to get a spectrum
            tl[0] = True
        elif view == 'image':
            tl[:2] = True, True

        for axis in self.axes:
            axis.slice_bool = tl.pop()

    def set_slicing_axes(self, slicing_axes):
        '''Easily choose which axes are slicing

        Parameters
        ----------

        slicing_axes: tuple of ints
            A list of the axis indexes that we want to slice

        '''
        for axis in self.axes:
            axis.slice_bool = axis.index_in_array in slicing_axes

    def connect(self, f):
        """Register ``f`` as a listener on the ``index`` of every non-slicing axis."""
        for axis in self.axes:
            if axis.slice is None:
                axis.on_trait_change(f, 'index')

    def disconnect(self, f):
        """Remove a listener previously registered with ``connect``."""
        for axis in self.axes:
            if axis.slice is None:
                axis.on_trait_change(f, 'index', remove=True)

    def key_navigator(self, event):
        """Move through the navigation axes with keyboard events.

        Active only with one or two non-slicing axes.  right/6 and left/4
        move the last non-slicing axis; up/8 and down/2 move the
        second-to-last one (two-axis case only).  pageup/pagedown change the
        step size, which never drops below 1.
        """
        if len(self._non_slicing_axes) not in (1, 2):
            return
        x = self._non_slicing_axes[-1]
        if event.key in ("right", "6"):
            x.index += self._step
        elif event.key in ("left", "4"):
            x.index -= self._step
        elif event.key == "pageup":
            self._step += 1
        elif event.key == "pagedown":
            if self._step > 1:
                self._step -= 1

        if len(self._non_slicing_axes) == 2:
            y = self._non_slicing_axes[-2]
            if event.key in ("up", "8"):
                y.index -= self._step
            elif event.key in ("down", "2"):
                y.index += self._step

    def edit_axes_traits(self):
        """Open a traits editor for every axis."""
        for axis in self.axes:
            axis.edit_traits()

    def copy(self):
        return copy.copy(self)

    def deepcopy(self):
        return copy.deepcopy(self)

    def __deepcopy__(self, *args):
        # Rebuild from the axes dictionaries instead of copying trait state.
        return AxesManager(self._get_axes_dicts())

    def _get_axes_dicts(self):
        """Return the dictionary of every axis, in ``self.axes`` order."""
        return [axis.get_axis_dictionary() for axis in self.axes]

    @staticmethod
    def _reindexed_axes_dicts(axes):
        """Return the axes' dictionaries with 'index_in_array' renumbered 0..n-1."""
        axes_dicts = []
        for i, axis in enumerate(axes):
            axis_dict = axis.get_axis_dictionary()
            axis_dict['index_in_array'] = i
            axes_dicts.append(axis_dict)
        return axes_dicts

    def _get_slicing_axes_dicts(self):
        """Dictionaries of the slicing axes, renumbered from 0."""
        return self._reindexed_axes_dicts(self._slicing_axes)

    def _get_non_slicing_axes_dicts(self):
        """Dictionaries of the non-slicing axes, renumbered from 0."""
        return self._reindexed_axes_dicts(self._non_slicing_axes)

    traits_view = tui.View(tui.Item('axes', style='custom'))
class Parameter(traits.HasTraits):
    """represents a lmfit variable in a fit. E.g. the standard deviation in a gaussian fit

    The UI traits (initialValue, vary, minimum/maximum and their enable
    flags) are mirrored onto ``self.parameter``, an ``lmfit.Parameter``,
    whenever they change.
    """
    parameter = traits.Instance(lmfit.Parameter)
    name = traits.Str

    initialValue = traits.Float
    calculatedValue = traits.Float
    vary = traits.Bool(True)

    minimumEnable = traits.Bool(False)
    minimum = traits.Float

    maximumEnable = traits.Bool(False)
    maximum = traits.Float

    stdevError = traits.Float

    def __init__(self, **traitsDict):
        super(Parameter, self).__init__(**traitsDict)
        # Traits passed in traitsDict are assigned above, before
        # self.parameter exists, so the *_changed handlers skip them.
        # Push the complete current state onto the new lmfit.Parameter.
        self.parameter = lmfit.Parameter(name=self.name)
        self.parameter.set(value=self.initialValue, vary=self.vary)
        if self.minimumEnable:
            self.parameter.set(min=self.minimum)
        if self.maximumEnable:
            self.parameter.set(max=self.maximum)

    def _initialValue_changed(self):
        if self.parameter is not None:
            self.parameter.set(value=self.initialValue)

    def _vary_changed(self):
        if self.parameter is not None:
            self.parameter.set(vary=self.vary)

    def _minimum_changed(self):
        if self.parameter is not None and self.minimumEnable:
            self.parameter.set(min=self.minimum)

    def _maximum_changed(self):
        if self.parameter is not None and self.maximumEnable:
            self.parameter.set(max=self.maximum)

    def _minimumEnable_changed(self):
        # Enabling applies the bound; disabling removes it
        # (lmfit treats -inf as "no lower bound").
        if self.parameter is not None:
            self.parameter.set(
                min=self.minimum if self.minimumEnable else -float("inf"))

    def _maximumEnable_changed(self):
        # Enabling applies the bound; disabling removes it
        # (lmfit treats +inf as "no upper bound").
        if self.parameter is not None:
            self.parameter.set(
                max=self.maximum if self.maximumEnable else float("inf"))

    traits_view = traitsui.View(traitsui.VGroup(
        traitsui.HGroup(
            traitsui.Item("vary", label="vary?", resizable=True),
            traitsui.Item("name", show_label=False, style="readonly", width=0.2, resizable=True),
            traitsui.Item("initialValue", label="initial", show_label=True, resizable=True),
            traitsui.Item("calculatedValue", label="calculated", show_label=True, format_str="%G", style="readonly", width=0.2, resizable=True),
            traitsui.Item("stdevError", show_label=False, format_str=u"\u00B1%G", style="readonly", resizable=True)),
        traitsui.HGroup(
            traitsui.Item("minimumEnable", label="min?", resizable=True),
            traitsui.Item("minimum", label="min", resizable=True, visible_when="minimumEnable"),
            traitsui.Item("maximumEnable", label="max?", resizable=True),
            traitsui.Item("maximum", label="max", resizable=True, visible_when="maximumEnable"))),
        kind="subpanel")
class Reports(traits.HasTraits):
    """A user interface that displays all the reports in a tabbed notebook."""
    summary_report = traits.Instance(SummaryReport)
    vehicle_report = traits.Instance(VehicleReport)
    pax_report = traits.Instance(PaxReport)
    station_report = traits.Instance(StationReport)
    power_report = traits.Instance(PowerReport)

    refresh = menu.Action(name="Refresh", action="refresh")

    view = ui.View(
        ui.Tabbed(
            ui.Item('summary_report', label='Summary', editor=ui.TextEditor(), style='readonly'),
            ui.Item('vehicle_report', label='Vehicles', editor=ui.InstanceEditor(), style='custom'),
            ui.Item('pax_report', label='Passengers', editor=ui.InstanceEditor(), style='custom'),
            ui.Item('station_report', label='Stations', editor=ui.InstanceEditor(), style='custom'),
            ui.Item('power_report', label='Power', editor=ui.InstanceEditor(), style='custom'),
            show_labels=False,
        ),
        title='Simulation Reports',
        width=1000,
        resizable=True,
        handler=ReportsHandler(),
        buttons=[],  # [refresh], #TODO: Disabling the refresh button until I can get it to refresh all reports properly
        kind='live')

    def __init__(self):
        super(Reports, self).__init__()
        self.summary_report = SummaryReport()
        self.pax_report = PaxReport()
        self.vehicle_report = VehicleReport()
        self.station_report = StationReport()
        self.power_report = PowerReport()
        # Sim.now() value at the last update(); used to skip redundant updates.
        self._last_update_time = None

    def update(self):
        """Refresh every report, at most once per simulation time."""
        if self._last_update_time == Sim.now():
            return
        self.pax_report.update()
        self.vehicle_report.update()
        self.station_report.update()
        self.power_report.update()
        # The summary is derived from the detailed reports, so it goes last.
        self.summary_report.update(self.pax_report, self.vehicle_report,
                                   self.station_report, self.power_report)
        self._last_update_time = Sim.now()

    def display(self, evt=None):
        """Update the reports and open the traits UI."""
        self.update()
        self.edit_traits()

    def write(self, report_path, update=True):
        """Writes the report to the filename specified by report_path.
        Use '-' to write to stdout."""
        if update:
            self.update()
        # NOTE(review): power_report is updated and displayed but not
        # written here; confirm whether that omission is intentional.
        sections = (self.summary_report, self.pax_report,
                    self.vehicle_report, self.station_report)
        text = '\n\n'.join(str(section) for section in sections)
        if report_path == '-':
            stdout.write(text)
        else:
            # Context manager so the report file is always closed.
            with open(report_path, 'w') as out:
                out.write(text)
class UiTest(eta.HasTraits):
    """Small traits UI for opening and plotting vertical-manipulation files."""
    # Traits must be declared on the class to be real traits (validated and
    # editable in the View); assigning trait types in __init__ only stored
    # the TraitType objects as plain attribute values.
    data = eta.Dict()
    fname = eta.File()
    plotting = eta.String()
    open_VERT = eta.Button("Open VERT")

    def __init__(self, fname=""):
        super(UiTest, self).__init__()
        self.only_I = False
        self.only_dIdV = False
        self.pointers = []
        # Honour an explicit file name; otherwise start in the working
        # directory (os.getcwdu on Python 2, os.getcwd on Python 3).
        self.fname = fname if fname else getattr(os, "getcwdu", os.getcwd)()
        self.plotting = "V"

    def _load_VERT(self):
        """Open, parse and plot ``self.fname`` using the current ``plotting`` mode."""
        # [-5:-1] compares the four characters before the last one, so any
        # name ending in ".VER" plus one character (e.g. ".VERT") matches.
        if self.fname[-5:-1] != ".VER":
            print("Warning : " + self.fname + " type not known.")
            return
        try:
            vm = VerticalManipulation(self.fname, self.plotting)
            vm.open_file()
            try:
                vm.reading_header()
                vm.load_data()
                vm.compute_data()
                vm.plot_data()
            finally:
                # Release the file even when parsing or plotting fails.
                vm.close_file()
        except Exception:
            print("Warning : Fail to open the vertical manipulation.")
            raise

    def _load_VERT_as(self, mode):
        """Set the plotting mode ('V', 'i', 't' or 'z') and load the file."""
        self.plotting = mode
        self._load_VERT()

    def _load_VERT_vs_bias(self):
        self._load_VERT_as("V")

    def _load_VERT_vs_index(self):
        self._load_VERT_as("i")

    def _load_VERT_vs_duration(self):
        self._load_VERT_as("t")

    def _load_VERT_vs_height(self):
        self._load_VERT_as("z")

    open_VERT_vs_bias = etum.Action(name='VERT vs bias', action='_load_VERT_vs_bias')
    open_VERT_vs_index = etum.Action(name='VERT vs index', action='_load_VERT_vs_index')
    open_VERT_vs_duration = etum.Action(name='VERT vs duration', action='_load_VERT_vs_duration')
    open_VERT_vs_height = etum.Action(name='VERT vs height', action='_load_VERT_vs_height')

    view = etua.View(
        # etua.Item("data", style="simple"),
        # etua.Item('fname', editor=etua.FileEditor(filter = ['*.plot'], auto_set = True), style = "custom"),
        etua.Item('fname', editor=etua.FileEditor(auto_set=True), style="custom"),
        toolbar=etum.ToolBar(open_VERT_vs_bias, open_VERT_vs_index,
                             open_VERT_vs_duration, open_VERT_vs_height),
        resizable=True,
        scrollable=True,
        title="Vertical Manipulation UI",
        height=640,
        width=800)
class Fit(traits.HasTraits):
    """Base class for a 2D fit of image data ``zs`` sampled on axes ``xs``/``ys``.

    Subclasses are expected to implement ``fitFunc``,
    ``_getIntelligentInitialValues`` and ``_deriveCalculatedParameters``.
    The fit itself is started in a ``FitThread`` by ``_fit_routine``.
    """
    name = traits.Str(desc="name of fit")
    function = traits.Str(desc="function we are fitting with all parameters")
    variablesList = traits.List(FitVariable)
    calculatedParametersList = traits.List(CalculatedParameter)
    xs = None  # will be a scipy array
    ys = None  # will be a scipy array
    zs = None  # will be a scipy array
    performFitButton = traits.Button("Perform Fit")
    getInitialParametersButton = traits.Button("Guess Initial Values")
    drawRequestButton = traits.Button("Draw Fit")
    autoFitBool = traits.Bool(
        False,
        desc=
        "Automatically perform this Fit with current settings whenever a new image is loaded"
    )
    autoGuessBool = traits.Bool(
        False,
        desc=
        "Whenever a fit is completed replace the guess values with the calculated values (useful for increasing speed of the next fit)"
    )
    autoDrawBool = traits.Bool(
        False,
        desc=
        "Once a fit is complete update the drawing of the fit or draw the fit for the first time"
    )
    logBool = traits.Bool(
        False, desc="Log the calculated and fitted values with a timestamp")
    logFile = traits.File(desc="file path of logFile")

    imageInspectorReference = None  # will be a reference to the image inspector
    fitting = traits.Bool(False)  # true when performing fit
    fitted = traits.Bool(
        False)  # true when current data displayed has been fitted
    fitSubSpace = traits.Bool(
        False)  # true when only the startX:endX, startY:endY window is fitted
    startX = traits.Int
    startY = traits.Int
    endX = traits.Int
    endY = traits.Int

    fittingStatus = traits.Str()
    fitThread = None
    physics = traits.Instance(physicsProperties.PhysicsProperties)
    # status strings
    notFittedForCurrentStatus = "Not Fitted for Current Image"
    fittedForCurrentImageStatus = "Fit Complete for Current Image"
    currentlyFittingStatus = "Currently Fitting..."
    failedFitStatus = "Failed to finish fit. See logger"

    fitSubSpaceGroup = traitsui.VGroup(
        traitsui.Item("fitSubSpace", label="Fit Sub Space"),
        traitsui.VGroup(traitsui.HGroup(traitsui.Item("startX"),
                                        traitsui.Item("startY")),
                        traitsui.HGroup(traitsui.Item("endX"),
                                        traitsui.Item("endY")),
                        visible_when="fitSubSpace"),
        label="Fit Sub Space",
        show_border=True)

    generalGroup = traitsui.VGroup(traitsui.Item("name",
                                                 label="Fit Name",
                                                 style="readonly",
                                                 resizable=True),
                                   traitsui.Item("function",
                                                 label="Fit Function",
                                                 style="readonly",
                                                 resizable=True),
                                   fitSubSpaceGroup,
                                   label="Fit",
                                   show_border=True)

    variablesGroup = traitsui.VGroup(traitsui.Item(
        "variablesList",
        editor=traitsui.ListEditor(style="custom"),
        show_label=False,
        resizable=True),
                                     show_border=True,
                                     label="parameters")

    derivedGroup = traitsui.VGroup(traitsui.Item(
        "calculatedParametersList",
        editor=traitsui.ListEditor(style="custom"),
        show_label=False,
        resizable=True),
                                   show_border=True,
                                   label="derived values")

    buttons = traitsui.VGroup(
        traitsui.HGroup(traitsui.Item("autoFitBool"),
                        traitsui.Item("performFitButton")),
        traitsui.HGroup(traitsui.Item("autoGuessBool"),
                        traitsui.Item("getInitialParametersButton")),
        traitsui.HGroup(traitsui.Item("autoDrawBool"),
                        traitsui.Item("drawRequestButton")))

    logGroup = traitsui.HGroup(traitsui.Item("logBool"),
                               traitsui.Item("logFile",
                                             visible_when="logBool"),
                               label="Logging",
                               show_border=True)

    actionsGroup = traitsui.VGroup(traitsui.Item("fittingStatus",
                                                 style="readonly"),
                                   logGroup,
                                   buttons,
                                   label="Fit Actions",
                                   show_border=True)
    traits_view = traitsui.View(
        traitsui.VGroup(generalGroup, variablesGroup, derivedGroup,
                        actionsGroup))

    def __init__(self, **traitsDict):
        super(Fit, self).__init__(**traitsDict)
        self.startX = 0
        self.startY = 0

    def _set_xs(self, xs):
        self.xs = xs

    def _set_ys(self, ys):
        self.ys = ys

    def _set_zs(self, zs):
        self.zs = zs

    def _fittingStatus_default(self):
        return self.notFittedForCurrentStatus

    def _getInitialValues(self):
        """returns ordered list of initial values from variables List """
        return [_.initialValue for _ in self.variablesList]

    def _getCalculatedValues(self):
        """returns ordered list of calculated values from variables List """
        return [_.calculatedValue for _ in self.variablesList]

    def _log_fit(self):
        """Append the current results as one CSV row to ``self.logFile``.

        A header row of column names is written first when the file does
        not exist yet.  Nothing is logged when no log file is set.
        """
        if self.logFile == "":
            logger.warning("no log file defined. Will not log")
            return
        if not os.path.exists(self.logFile):
            variables = [_.name for _ in self.variablesList]
            calculated = [_.name for _ in self.calculatedParametersList]
            times = ["datetime", "epoch seconds"]
            info = ["img file name"]
            columnNames = times + info + variables + calculated
            with open(self.logFile, 'a+') as logFile:
                writer = csv.writer(logFile)
                writer.writerow(columnNames)
        # column names already exist so...
        variables = [_.calculatedValue for _ in self.variablesList]
        calculated = [_.value for _ in self.calculatedParametersList]
        now = time.time()  # epoch seconds
        timeTuple = time.localtime(now)
        date = time.strftime("%Y-%m-%dT%H:%M:%S", timeTuple)
        times = [date, now]
        info = [self.imageInspectorReference.selectedFile]
        data = times + info + variables + calculated
        with open(self.logFile, 'a+') as logFile:
            writer = csv.writer(logFile)
            writer.writerow(data)

    def _intelligentInitialValues(self):
        """If possible we can auto set the initial parameters to intelligent guesses user can always overwrite them """
        self._setInitialValues(self._getIntelligentInitialValues())

    def _get_subSpaceArrays(self):
        """returns the arrays of the selected sub space. If subspace is not activated then returns the full arrays"""
        if not self.fitSubSpace:
            return self.xs, self.ys, self.zs
        xs = self.xs[self.startX:self.endX]
        ys = self.ys[self.startY:self.endY]
        logger.debug("xs array sliced length %s " % (xs.shape))
        logger.debug("ys array sliced length %s " % (ys.shape))
        # zs is indexed [row (y), column (x)].
        zs = self.zs[self.startY:self.endY, self.startX:self.endX]
        logger.debug("zs sub space array %s,%s " % (zs.shape))
        return xs, ys, zs

    @staticmethod
    def _positions(xs, ys):
        """Flatten the (xs, ys) grid into the [x, y] pair passed to fitFunc.

        x runs fastest: xs = 0 1 2 0 1 2 ..., ys = 0 0 0 1 1 1 ..., which
        matches the row-major ravel of a zs[y, x] image.
        """
        return [scipy.tile(xs, len(ys)), scipy.repeat(ys, len(xs))]

    def _getIntelligentInitialValues(self):
        """If possible we can auto set the initial parameters to intelligent guesses user can always overwrite them """
        logger.debug("Dummy function should not be called directly")
        return

    def fitFunc(self, data, *p):
        """Function that we are trying to fit to. """
        logger.error("Dummy function should not be called directly")
        return

    def _setCalculatedValues(self, calculated):
        """updates calculated values with calculated argument """
        for variable, value in zip(self.variablesList, calculated):
            variable.calculatedValue = value

    def _setCalculatedValuesErrors(self, covarianceMatrix):
        """given the covariance matrix returned by scipy optimize fit
        convert this into stdeviation errors for parameters list and updated
        the stdevError attribute of variables"""
        logger.debug("covariance matrix -> %s " % covarianceMatrix)
        parameterErrors = scipy.sqrt(scipy.diag(covarianceMatrix))
        logger.debug("parameterErrors  -> %s " % parameterErrors)
        for variable, error in zip(self.variablesList, parameterErrors):
            variable.stdevError = error

    def _setInitialValues(self, guesses):
        """updates initial values with guesses argument """
        for variable, guess in zip(self.variablesList, guesses):
            variable.initialValue = guess

    def deriveCalculatedParameters(self):
        """Wrapper for subclass definition of deriving calculated parameters
        can put more general calls in here"""
        if self.fitted:
            self._deriveCalculatedParameters()

    def _deriveCalculatedParameters(self):
        """Should be implemented by subclass. should update all variables in calculate parameters list"""
        logger.error("Should only be called by subclass")
        return

    def _fit_routine(self):
        """This function performs the fit in an appropriate thread and
        updates necessary values when the fit has been performed"""
        # Thread.is_alive() exists since Python 2.6; isAlive() was removed
        # in Python 3.9.
        if self.fitThread and self.fitThread.is_alive():
            logger.warning(
                "Fitting is already running cannot kick off a new fit until it has finished!"
            )
            return
        self.fitting = True
        self.fitThread = FitThread()
        self.fitThread.fitReference = self
        self.fitThread.start()
        self.fittingStatus = self.currentlyFittingStatus

    def _perform_fit(self):
        """Perform the fit using scipy optimise curve fit.
        We must supply x and y as one argument and zs as anothger. in the form
        xs: 0 1 2 0 1 2 0
        ys: 0 0 0 1 1 1 2
        zs: 1 5 6 1 9 8 2

        Hence the use of repeat and tile in  positions and unravel for zs
        initially xs,ys is a linspace array and zs is a 2d image array

        Returns (params, covariance) from curve_fit, or ([], []) when no
        data has been set.
        """
        if self.xs is None or self.ys is None or self.zs is None:
            logger.warning(
                "attempted to fit data but had no data inside the Fit object. set xs,ys,zs first"
            )
            return ([], [])
        p0 = self._getInitialValues()
        # Returns the sub space slices when fitSubSpace is set, otherwise the
        # full (slower) arrays.
        xs, ys, zs = self._get_subSpaceArrays()
        positions = self._positions(xs, ys)
        # curve_fit expects a flattened array of observations.
        zsRavelled = scipy.ravel(zs)
        params2D, cov2D = scipy.optimize.curve_fit(self.fitFunc,
                                                   positions,
                                                   zsRavelled,
                                                   p0=p0)
        if self.fitSubSpace:
            fittedValues = self.fitFunc(positions, *params2D)
            chi2 = scipy.sum((zsRavelled - fittedValues)**2 / fittedValues)
            logger.debug("TEMPORARY ::: CHI^2 = %s " % chi2)
        return params2D, cov2D

    def _performFitButton_fired(self):
        self._fit_routine()

    def _getInitialParametersButton_fired(self):
        self._intelligentInitialValues()

    def _drawRequestButton_fired(self):
        """tells the imageInspector to try and draw this fit as an overlay contour plot"""
        self.imageInspectorReference.addFitPlot(self)

    def _getFitFuncData(self):
        """if data has been fitted, this returns the zs data for the ideal
        fitted function using the calculated paramters"""
        positions = self._positions(self.xs, self.ys)
        zsravelled = self.fitFunc(positions, *self._getCalculatedValues())
        return zsravelled.reshape(self.zs.shape)
class CSplinePlotter(traits.HasTraits):
    """Generates and displays a plot for a cubic spline."""
    container = traits.Instance(chaco.OverlayPlotContainer)
    plotdata = traits.Instance(chaco.ArrayPlotData)
    traits_view = ui.View(ui.Item('container',
                                  editor=ComponentEditor(),
                                  show_label=False),
                          width=500,
                          height=500,
                          resizable=True,
                          title='CubicSpline Plot')

    def __init__(self,
                 cubic_spline,
                 velocity_max=0,
                 acceleration_max=0,
                 jerk_max=0,
                 velocity_min=None,
                 acceleration_min=None,
                 jerk_min=None,
                 title="",
                 start_idx=0,
                 end_idx=-1,
                 mass=None,
                 plot_pos=True,
                 plot_vel=True,
                 plot_accel=True,
                 plot_jerk=True,
                 plot_power=True):
        """If a 'mass' argument is supplied, then the power will be plotted.

        Unset minimum limits default to 0 for velocity and to the negated
        maximum for acceleration and jerk.  A negative ``end_idx`` counts
        from the end of the spline's knot list.
        """
        super(CSplinePlotter, self).__init__()
        self.cspline = cubic_spline
        self.v_max = velocity_max
        self.a_max = acceleration_max
        self.j_max = jerk_max
        self.v_min = 0 if velocity_min is None else velocity_min
        self.a_min = -self.a_max if acceleration_min is None else acceleration_min
        self.j_min = -self.j_max if jerk_min is None else jerk_min
        self.title = title
        self.mass = mass
        self.plot_pos = plot_pos
        self.plot_vel = plot_vel
        self.plot_accel = plot_accel
        self.plot_jerk = plot_jerk
        self.plot_power = plot_power and mass is not None

        self.container = chaco.OverlayPlotContainer(padding=52,
                                                    fill_padding=True,
                                                    bgcolor="transparent")
        self.make_plotdata(start_idx, end_idx)
        self.make_plots()

    @staticmethod
    def _constant_line(value, times):
        """Return an array repeating ``value`` once per entry of ``times``."""
        return numpy.array([value] * len(times))

    def make_plotdata(self, start_idx, end_idx):
        """Sample the spline between two knots and store the arrays in self.plotdata."""
        if end_idx < 0:
            end_idx = len(
                self.cspline.t) + end_idx  # convert to absolute index
        knot_times = self.cspline.t[start_idx:end_idx + 1]
        sample_times = numpy.linspace(self.cspline.t[start_idx],
                                      self.cspline.t[end_idx], 200)
        endpoint_times = numpy.array(
            [self.cspline.t[start_idx], self.cspline.t[end_idx]])

        positions = []
        velocities = []
        powers = []
        samples = self.cspline.evaluate_sequence(sample_times)
        for sample in samples:
            if self.plot_pos:
                positions.append(sample.pos)
            if self.plot_vel:
                velocities.append(sample.vel)
            if self.plot_power:
                powers.append(self.mass * sample.accel * sample.vel /
                              1000.0)  # In KWs

        if self.plot_accel:
            accelerations = numpy.array(self.cspline.a[start_idx:end_idx + 1])
        else:
            accelerations = []

        if self.plot_jerk and len(self.cspline.j):
            # One jerk value per knot: repeat the last segment's jerk so the
            # array length matches knot_times.
            jerks = numpy.array(self.cspline.j[start_idx:end_idx] +
                                [self.cspline.j[end_idx - 1]])
        else:
            jerks = []

        # Two-point horizontal lines marking the configured limits.
        self.plotdata = chaco.ArrayPlotData(
            positions=positions,
            endpoint_times=endpoint_times,
            knot_times=knot_times,
            sample_times=sample_times,
            velocities=velocities,
            accelerations=accelerations,
            jerks=jerks,
            powers=powers,
            max_vel=self._constant_line(self.v_max, endpoint_times),
            min_vel=self._constant_line(self.v_min, endpoint_times),
            max_accel=self._constant_line(self.a_max, endpoint_times),
            min_accel=self._constant_line(self.a_min, endpoint_times),
            max_jerk=self._constant_line(self.j_max, endpoint_times),
            min_jerk=self._constant_line(self.j_min, endpoint_times))

    def make_plots(self):
        """Add the curves, limit lines, position axis, legend and title to the container."""
        main_plot = chaco.Plot(self.plotdata, padding=0)
        colors = {
            'pos': 'black',
            'vel': 'blue',
            'accel': 'red',
            'jerk': 'green',
            'power': 'purple'
        }
        left_y_axis_title_list = []
        legend_dict = {}

        if self.plot_vel:
            legend_dict['vel'] = main_plot.plot(("sample_times", "velocities"),
                                                type="line",
                                                color=colors['vel'],
                                                line_width=2)
            for limit in ("max_vel", "min_vel"):
                main_plot.plot(("endpoint_times", limit),
                               color=colors['vel'],
                               line_style='dash',
                               line_width=0.60)
            left_y_axis_title_list.append("Velocity (m/s)")

        if self.plot_accel:
            legend_dict['accel'] = main_plot.plot(
                ("knot_times", "accelerations"),
                type="line",
                color=colors['accel'],
                line_width=2)
            for limit in ("max_accel", "min_accel"):
                main_plot.plot(("endpoint_times", limit),
                               color=colors['accel'],
                               line_style='dash',
                               line_width=0.55)
            left_y_axis_title_list.append("Accel (m/s2)")

        # Skip the jerk curve when make_plotdata found no jerk data, rather
        # than adding an empty line, legend entry and axis title.
        if self.plot_jerk and len(self.plotdata.arrays["jerks"]):
            legend_dict['jerk'] = main_plot.plot(("knot_times", "jerks"),
                                                 type="line",
                                                 color=colors['jerk'],
                                                 line_width=2,
                                                 render_style="connectedhold")
            for limit in ("max_jerk", "min_jerk"):
                main_plot.plot(("endpoint_times", limit),
                               color=colors['jerk'],
                               line_style='dash',
                               line_width=0.45)
            left_y_axis_title_list.append("Jerk (m/s3)")

        if self.plot_power:
            legend_dict['power'] = main_plot.plot(("sample_times", "powers"),
                                                  type="line",
                                                  color=colors['power'],
                                                  line_width=2)
            left_y_axis_title_list.append("Power (KW)")

        main_plot.y_axis.title = ", ".join(left_y_axis_title_list)
        self.container.add(main_plot)

        # plot positions (on a separate scale from the others)
        if self.plot_pos:
            pos_plot = chaco.create_line_plot([
                self.plotdata.arrays["sample_times"],
                self.plotdata.arrays["positions"]
            ],
                                              color=colors['pos'],
                                              width=2)
            legend_dict['pos'] = pos_plot
            self.container.add(pos_plot)
            # add a second y-axis for the positions
            pos_y_axis = chaco.PlotAxis(pos_plot,
                                        orientation="right",
                                        title="Position (meters)")
            self.container.overlays.append(pos_y_axis)

        # make Legend
        legend = chaco.Legend(component=self.container, padding=20, align="ul")
        legend.plots = legend_dict
        legend.tools.append(tools.LegendTool(legend, drag_button="left"))
        self.container.overlays.append(legend)

        # Add title, if any
        if self.title:
            main_plot.title = self.title
            main_plot.title_position = "inside top"

    def display_plot(self):
        """Open the plot in a traits UI window."""
        self.configure_traits()