class AbstractConfigurator(core.Agent):
    """Base agent that builds a graph of nodes and edges on start-up.

    Subclasses implement ``create_nodes`` and ``create_edges``. The
    ``initialize_nodes`` hook defaults to the no-op ``do_not_initialize``; a
    subclass can rebind it to ``do_initialize`` to send ``'initialize'`` to
    every identifier in ``node_identifiers``.
    """

    implements(IConfigurator)

    name = 'configurator'

    # Here we specify the names of the options for the configurator.
    # Options are accumulated along the inheritance path.
    options = {"full_parameters"}

    node_identifiers = Any
    full_parameters = Dict(key_trait=Str)

    def _start(self):
        """Create nodes, then edges, then run the node-initialization hook."""
        self.create_nodes()
        self.create_edges()
        self.initialize_nodes()

    def create_nodes(self):
        """Abstract: build the graph nodes."""
        raise NotImplementedError()

    def create_edges(self):
        """Abstract: connect the graph nodes."""
        raise NotImplementedError()

    def do_initialize(self):
        """Send an ``'initialize'`` message to each known node identifier."""
        for node_id in self.node_identifiers:
            self.send(node_id, 'initialize')

    def do_not_initialize(self):
        """Leave the nodes uninitialized (default hook)."""

    # Default hook used by _start; rebind in subclasses to change behavior.
    initialize_nodes = do_not_initialize
class Bar(HasTraits):
    """Holds named ``Foo`` objects and re-announces their changes."""

    # Name -> Foo mapping.
    foos = Dict(key_trait=Str, value_trait=Foo)

    # Fired when the dict's items change or a contained Foo fires `modified`.
    modified = Event

    @on_trait_change("foos_items,foos.modified")
    def _fire_modified_event(self, obj, trait_name, old, new):
        """Relay any observed change by firing our own ``modified`` event."""
        self.modified = True
class ConfoundsInputSpec(BaseInterfaceInputSpec):
    """Input specification for the confounds interface."""

    pipeline = Dict(mandatory=True, desc="Denoising pipeline")
    # The keyword File understands is `exists`; the previous `exist=True` was
    # stored as plain metadata, so the existence check never ran.
    conf_raw = File(exists=True, mandatory=True, desc="Confounds table")
    conf_json = File(exists=True, mandatory=True,
                     desc="Confounds description (aCompCor)")
    output_dir = Directory(exists=True, mandatory=True, desc="Output path")
class NestedContainerClass(HasTraits):
    """Fixture whose traits are containers nested inside other containers.

    Used in regression tests for changes to nested containers.
    """

    # -- Nested list -----------------------------------------------------
    list_of_list = List(trait=List)
    # enthought/traits#281
    dict_of_list = Dict(key_trait=Str, value_trait=List(trait=Str))
    # Similar to enthought/traits#281
    dict_of_union_none_or_list = Dict(
        key_trait=Str, value_trait=Union(List(), None))

    # -- Nested dict (enthought/traits#25) -------------------------------
    list_of_dict = List(trait=Dict)
    dict_of_dict = Dict(key_trait=Str, value_trait=Dict)
    dict_of_union_none_or_dict = Dict(
        key_trait=Str, value_trait=Union(Dict(), None))

    # -- Nested set ------------------------------------------------------
    list_of_set = List(trait=Set)
    dict_of_set = Dict(key_trait=Str, value_trait=Set)
    dict_of_union_none_or_set = Dict(
        key_trait=Str, value_trait=Union(Set(), None))
class A(HasTraits):
    """Records item-level change events of its list, dict and set traits."""

    alist = List(Int, value=[0, 1, 2, 3, 4])
    adict = Dict(Str, Int, value={"a": 1, "b": 2})
    aset = Set(Int, value={0, 1, 2, 3, 4})

    # (trait name, new) pairs appended by _receive_events, in arrival order.
    events = List()

    @on_trait_change("alist_items,adict_items,aset_items")
    def _receive_events(self, object, name, old, new):
        """Append the event's trait name and its ``new`` payload."""
        record = (name, new)
        self.events.append(record)
class BIDSDataSinkInputSpec(BaseInterfaceInputSpec):
    """Input specification for saving a temporary file into a BIDS directory."""

    # str -> str entities; defaults to an empty mapping.
    base_entities = Dict(
        key_trait=Str,
        value_trait=Str,
        value={},  # default value
        mandatory=False,
        desc="Optional base entities that will overwrite values from incoming file",
    )
    in_file = File(desc="File from tmp to save in BIDS directory")
def test_validate(self):
    """Exercise ``Dict.validate`` on bad, good and owner-less input."""
    trait = Dict()

    # None is not a dict: validation must raise.
    with self.assertRaises(TraitError):
        trait.validate(object=HasTraits(), name="bar", value=None)

    # A dict validated against a HasTraits owner comes back as a TraitDictObject.
    validated = trait.validate(object=HasTraits(), name="bar", value={})
    self.assertIsInstance(validated, TraitDictObject)

    # With no owner object (check for issue #71) the result still equals {}.
    validated = trait.validate(object=None, name="bar", value={})
    self.assertEqual(validated, {})
class ReportCreatorInputSpec(BaseInterfaceInputSpec):
    """Input specification for the report-creator interface."""

    pipelines = List(Dict(), mandatory=True)
    pipelines_names = List(Str(), mandatory=True)
    group_data_dir = Directory(exists=True)
    excluded_subjects = List(Str(), value=())
    plot_pipeline_edges_density = File(
        exists=True,
        desc="Density of edge weights (all subjects)")
    # `exists` (not `exist`) is the keyword File honours; the misspelling was
    # stored as metadata and silently skipped the existence check.
    plot_pipelines_edges_density_no_high_motion = File(
        exists=True,
        desc="Density of edge weights (no high motion)")
    plot_pipelines_fc_fd_pearson = File(exists=True)
    plot_pipelines_fc_fd_uncorr = File(exists=True)
    plot_pipelines_distance_dependence = File(exists=True)
class TestClass(HasTraits):
    """Fixture with a key-typed dict and a key/value-typed dict."""

    dict_1 = Dict(key_trait=Str)
    dict_2 = Dict(key_trait=Int, value_trait=Str)
class PipelineSelectorOutPutSpecification(TraitedSpec):
    """Output specification of the pipeline-selector interface."""

    # Selected pipeline; items=True also enables item-level change tracking.
    pipeline = Dict(items=True)
    pipeline_name = Str(
        desc="Name of denoising strategy",
    )
class ReportCreatorInputSpec(BaseInterfaceInputSpec):
    """Input specification for the report-creator interface."""

    # The description was previously passed as Dict's first positional
    # argument, i.e. as its key trait. It belongs in `desc`; keys stay Str.
    pipelines = List(
        Dict(Str),
        mandatory=True,
        desc="Dictionary pipeline")
    tasks = List(
        Str(),
        mandatory=True)
    output_dir = Directory(exists=True)
    sessions = List(
        Str(),
        mandatory=False)
    runs = List(
        Int(),
        mandatory=False)
    runtime_info = Instance(RuntimeInfo, mandatory=True)
    excluded_subjects = List(
        trait=Instance(ExcludedSubjects),
        value=[],
        usedefault=True
    )
    warnings = List(
        trait=Instance(ErrorData),
        desc="ErrorData objects with all relevant entities error source and error message",
        value=[],
        usedefault=True
    )

    # Aggregated over pipelines
    plots_all_pipelines_edges_density = List(
        Optional(
            File(
                exists=True,
                desc="Density of edge weights (all pipelines) for all subjects"
            )))
    plots_all_pipelines_edges_density_no_high_motion = List(
        Optional(
            File(
                exists=True,
                desc="Density of edge weights (all pipelines) without high motion subjects"
            )))
    plots_all_pipelines_fc_fd_pearson_info = List(
        Optional(
            File(
                exists=True,
                desc="Barplot and violinplot showing percent of significant fc-fd correlations and distribution of Pearson's r values for all subjects"
            )))
    plots_all_pipelines_fc_fd_pearson_info_no_high_motion = List(
        Optional(
            File(
                exists=True,
                desc="Barplot and violinplot showing percent of significant fc-fd correlations and distribution of Pearson's r values without high motion subjects"
            )))
    plots_all_pipelines_distance_dependence = List(
        Optional(
            File(
                exists=True,
                desc="Barplot showing mean Spearman's rho between fd-fc correlation and Euclidean distance between ROIs for all subjects"
            )))
    plots_all_pipelines_distance_dependence_no_high_motion = List(
        Optional(
            File(
                exists=True,
                desc="Barplot showing mean Spearman's rho between fd-fc correlation and Euclidean distance between ROIs without high motion subjects"
            )))
    plots_all_pipelines_tdof_loss = List(
        Optional(
            File(
                exists=True,
                desc="Barplot showing degree of freedom loss (number of regressors included in each pipeline)."
            )))

    # For single pipeline
    plots_pipeline_fc_fd_pearson_matrix = List(
        Optional(
            File(
                exists=True,
                desc="Matrix showing correlation between connection strength and motion for all subjects"
            )))
    plots_pipeline_fc_fd_pearson_matrix_no_high_motion = List(
        Optional(
            File(
                exists=True,
                desc="Matrix showing correlation between connection strength and motion without high motion subjects"
            )))
class Foo(HasTraits):
    """Fixture holding a dict trait created with ``items=False``."""

    # items=False: no companion item-change trait is generated for `mapping`.
    mapping = Dict(
        items=False,
    )
class Foo(HasTraits):
    """Fixture holding a dict with arbitrary keys and string values."""

    mapping = Dict(key_trait=Any, value_trait=Str)
class SystemObject(HasStrictTraits):

    """
    Baseclass for Programs, Sensor, Actuators.

    Construction is two-phase: ``__init__`` (and ``__setstate__``) only record
    the passed name and traits in ``_passed_arguments``; the actual
    ``HasTraits`` initialization happens later in :meth:`setup_system`, which
    ``__init__`` calls itself only when a ``system`` keyword is supplied
    (otherwise presumably the System calls it — TODO confirm).
    """

    #: Names of attributes that accept Callables. If there are custom callables being used, they must be added here.
    #: The purpose of this list is that these Callables will be initialized properly.
    #: :class:`~automate.program.ProgrammableSystemObject` introduces 5 basic callables
    #: (see also :ref:`automate-programs`).
    callables = []

    def get_default_callables(self):
        """ Get a dictionary of default callables, in form {name:callable}. Re-defined in subclasses."""
        return {}

    #: Reference to System object
    system = Instance(SystemBase, transient=True)

    #: Description of the object (shown in WEB interface)
    description = CUnicode

    #: Python Logger instance for this object. System creates each object its own logger instance.
    logger = Instance(logging.Logger, transient=True)

    #: Tags are used for (for example) grouping objects. See :ref:`groups`.
    tags = TagSet(trait=CUnicode)

    #: Name property is determined by System namespace. Can be read/written.
    name = Property(trait=Unicode, depends_on='name_changed_event')

    #: Level applied to :attr:`logger` (default ``logging.INFO``).
    log_level = Int(logging.INFO)

    def _log_level_changed(self, new_value):
        # Static change handler for log_level: keep the logger's level in sync.
        if self.logger:
            self.logger.setLevel(new_value)

    @cached_property
    def _get_name(self):
        # Reverse lookup (object -> name) in the System; cached until
        # name_changed_event fires (see depends_on of `name`).
        try:
            return self.system.reverse[self]
        except (KeyError, AttributeError):
            # AttributeError covers `system` still being None.
            return 'System not initialized!'

    def _set_name(self, new_name):
        # Validate, drop any old namespace entry, register under the new name
        # and replace the logger so its name follows the object name.
        if not is_valid_variable_name(new_name):
            raise NameError('Illegal name %s' % new_name)
        try:
            if self in list(self.system.namespace.values()):
                del self.system.namespace[self.name]
        except NameError:
            # NOTE(review): presumably the namespace raises NameError for an
            # unknown/invalid name; verify against the namespace class.
            pass
        self.system.namespace[new_name] = self
        # NOTE(review): unlike setup_system, this does not apply log_level to
        # the new child logger — confirm that is intended.
        self.logger = self.system.logger.getChild('%s.%s' % (self.__class__.__name__, new_name))

    #: If set to *True*, current SystemObject is hidden in the UML diagram of WEB interface.
    hide_in_uml = CBool(False)

    # Creation order of this object, taken from the class-wide counter
    # `_count` in __init__ (restored from state in __setstate__).
    _order = Int
    _count = 0

    #: Attributes that can be edited by user in WEB interface
    view = ['hide_in_uml']

    #: The data type name (as string) of the object. This is written in the initialization, and is used by WEB
    #: interface Django templates.
    data_type = ''

    #: If editable=True, a quick edit widget will appear in the web interface. Define in subclasses.
    editable = False

    # Namespace triggers this event when the object name is changed
    name_changed_event = Event

    # (name, traits) recorded by __init__/__setstate__ and consumed by
    # setup_system; name is None when coming from __setstate__.
    _passed_arguments = Tuple(transient=True)
    # name -> callable waiting to be installed by setup_callables.
    _postponed_callables = Dict(transient=True)

    @property
    def class_name(self):
        # For Django templates
        return self.__class__.__name__

    @property
    def object_type(self):
        """
        A read-only property that gives the object type as string; sensor, actuator, program, other.
        Used by WEB interface templates.
        """

        # Imported here, presumably to avoid a circular import with these
        # modules (which subclass SystemObject).
        from .statusobject import AbstractSensor, AbstractActuator
        from .program import Program
        if isinstance(self, AbstractSensor):
            return 'sensor'
        elif isinstance(self, AbstractActuator):
            return 'actuator'
        elif isinstance(self, Program):
            return 'program'
        else:
            return 'other'

    def __init__(self, name='', **traits):
        # Postpone traits initialization to be launched by System
        self.logger = logging.getLogger('automate.%s' % self.__class__.__name__)
        self._order = SystemObject._count
        SystemObject._count += 1

        # `traits` is the same dict object that is stored here, so popping
        # 'system' below also removes it from _passed_arguments.
        self._passed_arguments = name, traits
        if 'system' in traits:
            self.setup_system(traits.pop('system'))
            self.setup_callables()

    def __setstate__(self, state, trait_change_notify=True):
        # Unpickling: same deferred scheme as __init__; name None marks that
        # the name must be resolved by setup_system.
        self.logger = logging.getLogger('automate.%s' % self.__class__.__name__)
        self._order = state.pop('_order')
        self._passed_arguments = None, state

    def get_status_display(self, **kwargs):
        """
        Redefine this in subclasses if status can be represented in human-readable way (units etc.)

        Returns ``str(kwargs['value'])`` when a ``value`` keyword is given,
        otherwise the class name.
        """
        if 'value' in kwargs:
            return str(kwargs['value'])
        return self.class_name

    def get_as_datadict(self):
        """
        Get information about this object as a dictionary.  Used by WebSocket interface to pass some
        relevant information to client applications.

        Returns a dict with keys ``type`` (class name) and ``tags`` (list).
        """
        return dict(type=self.__class__.__name__, tags=list(self.tags))

    def setup(self, *args, **kwargs):
        """
        Initialize necessary services etc. here. Define this in subclasses.

        Called at the end of :meth:`setup_system`.
        """
        pass

    def setup_system(self, system, name_from_system='', **kwargs):
        """
        Set system attribute and do some initialization. Used by System.

        Resolves a unique name, creates the per-object logger, stashes any
        callables passed as traits, runs the deferred ``HasTraits.__init__``
        (only once), fires ``name_changed_event`` and finally calls
        :meth:`setup`.
        """

        if not self.system:
            self.system = system
        name, traits = self._passed_arguments
        new_name = self.system.get_unique_name(self, name, name_from_system)
        if not self in self.system.reverse:
            self.name = new_name
        self.logger = self.system.logger.getChild('%s.%s' % (self.__class__.__name__, self.name))
        self.logger.setLevel(self.log_level)
        if name is None and 'name' in traits:  # Only __setstate__ sets name to None. Default is ''.
            del traits['name']

        for cname in self.callables:
            if cname in traits:
                # Keep passed callables aside so setup_callables installs them
                # instead of the defaults.
                c = self._postponed_callables[cname] = traits.pop(cname)
                c.setup_callable_system(self.system)
            getattr(self, cname).setup_callable_system(self.system)

        # Deferred from __init__; guard so a repeated setup_system call does
        # not re-run HasTraits initialization.
        if not self.traits_inited():
            super().__init__(**traits)
        # Invalidate the cached `name` property.
        self.name_changed_event = True

        self.setup()

    def setup_callables(self):
        """
        Setup Callable attributes that belong to this object.

        Defaults from :meth:`get_default_callables` fill in any callable not
        already postponed; every name in :attr:`callables` is then installed.
        Raises KeyError if a name in :attr:`callables` has neither a default
        nor a passed value.
        """
        defaults = self.get_default_callables()
        for key, value in list(defaults.items()):
            self._postponed_callables.setdefault(key, value)
        for key in self.callables:
            value = self._postponed_callables.pop(key)
            value.setup_callable_system(self.system, init=True)
            setattr(self, key, value)

    def cleanup(self):
        """
        Write here whatever cleanup actions are needed when object is no longer used.
        """

    def __str__(self):
        return self.name

    def __repr__(self):
        return u"'%s'" % self.name
class ExecutedAutomatedRunSpecAdapter(TabularAdapter, ConfigurableMixin):
    """Tabular adapter for automated-run specs in the executed queue.

    Most cells are read-only views of the run spec (their ``_set_*`` handlers
    are no-ops). Numeric cells are rendered with :meth:`_get_number`, parsed
    back with :meth:`_validate_number` and stored with :meth:`_set_number`.
    Row background color reflects executability, skip flag and run state.
    """

    all_columns = [('Idx', 'idx'),
                   ('-', 'result'),
                   ('Labnumber', 'labnumber'),
                   ('Aliquot', 'aliquot'),
                   ('Sample', 'sample'),
                   ('Project', 'project'),
                   ('Material', 'material'),
                   ('RepositoryID', 'repository_identifier'),
                   ('Position', 'position'),
                   ('Extract', 'extract_value'),
                   ('Units', 'extract_units'),
                   ('Ramp (s)', 'ramp_duration'),
                   ('Duration (s)', 'duration'),
                   ('Cleanup (s)', 'cleanup'),
                   ('Overlap (s)', 'overlap'),
                   ('Beam (mm)', 'beam_diameter'),
                   ('Pattern', 'pattern'),
                   ('Extraction', 'extraction_script'),
                   ('T_o Offset', 'collection_time_zero_offset'),
                   ('Measurement', 'measurement_script'),
                   ('Conditionals', 'conditionals'),
                   ('SynExtraction', 'syn_extraction'),
                   ('CDDWarm', 'use_cdd_warming'),
                   ('Post Eq.', 'post_equilibration_script'),
                   ('Post Meas.', 'post_measurement_script'),
                   ('Options', 'script_options'),
                   ('Comment', 'comment'),
                   ('Delay After', 'delay_after')]

    # Visible columns start out identical to `all_columns`; keep a separate
    # list object so changing one does not affect the other.
    columns = list(all_columns)

    font = 'arial 10'

    # ===========================================================================
    # widths
    # ===========================================================================
    result_width = Int(25)
    repository_identifier_width = Int(90)
    labnumber_width = Int(80)
    aliquot_width = Int(60)
    sample_width = Int(50)
    position_width = Int(50)
    extract_value_width = Int(50)
    extract_units_width = Int(40)
    duration_width = Int(70)
    ramp_duration_width = Int(50)
    cleanup_width = Int(70)
    pattern_width = Int(80)
    beam_diameter_width = Int(65)
    overlap_width = Int(50)
    extraction_script_width = Int(80)
    measurement_script_width = Int(90)
    conditionals_width = Int(80)
    syn_extraction_width = Int(80)
    use_cdd_warming_width = Int(80)
    post_measurement_script_width = Int(90)
    post_equilibration_script_width = Int(90)

    position_text = Property
    comment_width = Int(125)

    # ===========================================================================
    # number values
    # ===========================================================================
    ramp_duration_text = Property
    extract_value_text = Property
    beam_diameter_text = Property
    duration_text = Property
    cleanup_text = Property
    aliquot_text = Property
    overlap_text = Property

    # ===========================================================================
    # non cell editable
    # ===========================================================================
    labnumber_text = Property
    result_text = Property
    extraction_script_text = Property
    measurement_script_text = Property
    post_measurement_script_text = Property
    post_equilibration_script_text = Property
    sample_text = Property
    use_cdd_warming_text = Property

    colors = Dict(COLORS)
    image = Property
    menu = Property
    tooltip = Property

    def _get_tooltip(self):
        """Summary for finished runs in the result column, else 'name= value'."""
        name = self.column_id
        item = self.item
        if name != 'result':
            return '{}= {}\nstate= {}'.format(name, getattr(item, name), item.state)
        if item.state in ('success', 'truncated'):
            return item.result.summary
        # Result column of an unfinished run: no tooltip.
        return None

    def get_bg_color(self, obj, trait, row, column=0):
        """Row color: non-executable > skipped > state color > end_after > zebra."""
        item = getattr(obj, trait)[row]
        if not item.executable:
            return 'red'
        if item.skip:
            return 'blue'  # '#33CCFF' # light blue
        if item.state in self.colors:
            return self.colors[item.state]
        if item.end_after:
            return 'grey'
        return self.even_bg_color if row % 2 == 0 else self.odd_bg_color

    def _get_image(self):
        """Status ball for the result column of successful/truncated runs."""
        if self.column_id == 'result':
            if self.item.state == 'success':
                return GREEN_BALL
            elif self.item.state == 'truncated':
                return ORANGE_BALL

    def _get_menu(self):
        """Context menu (summary + isotope evolutions) for finished runs only."""
        item = self.item
        if item.state not in ('success', 'truncated'):
            return None

        evo_actions = [
            Action(name='Show All', action='show_evolutions'),
            Action(name='Show All w/Equilibration', action='show_evolutions_w_eq'),
            Action(name='Show All w/Equilibration+Baseline', action='show_evolutions_w_eq_bs'),
            Action(name='Show All w/Baseline', action='show_evolutions_w_bs')
        ]

        # One submenu per isotope in the run's result.
        for iso in item.result.isotope_group.iter_isotopes():
            actions = [
                Action(name='Signal', action='show_evolution_{}'.format(iso.name)),
                Action(name='Equilibration/Signal', action='show_evolution_eq_{}'.format(iso.name)),
                Action(name='Equilibration/Signal/Baseline',
                       action='show_evolution_eq_bs_{}'.format(iso.name)),
                Action(name='Signal/Baseline', action='show_evolution_bs_{}'.format(iso.name))
            ]
            evo_actions.append(MenuManager(*actions, name=iso.name))

        evo = MenuManager(*evo_actions, name='Evolutions')
        return MenuManager(Action(name='Summary', action='show_summary'), evo)

    # ============ non cell editable ============
    def _get_result_text(self):
        return ''

    def _set_result_text(self, v):
        pass

    def _get_position_text(self):
        """Position is shown only for unknown/degas runs, and for
        blank_unknown runs whose position contains a comma."""
        at = self.item.analysis_type
        p = self.item.position
        if at not in ('unknown', 'degas'):
            if at == 'blank_unknown':
                if ',' not in p:
                    p = ''
            else:
                p = ''
        return p

    def _get_labnumber_text(self):
        return self.item.labnumber

    def _set_labnumber_text(self, v):
        pass

    def _set_sample_text(self, v):
        pass

    def _get_sample_text(self):
        return self.item.sample

    def _get_extraction_script_text(self):
        return self.item.extraction_script

    def _get_measurement_script_text(self):
        return self.item.measurement_script

    def _get_post_measurement_script_text(self):
        return self.item.post_measurement_script

    def _get_post_equilibration_script_text(self):
        return self.item.post_equilibration_script

    def _set_extraction_script_text(self, v):
        pass

    def _set_measurement_script_text(self, v):
        pass

    def _set_post_measurement_script_text(self, v):
        pass

    def _set_post_equilibration_script_text(self, v):
        pass

    def _set_position_text(self, v):
        pass

    # ============================================
    def _get_overlap_text(self):
        """'o,m' when a second overlap value is set, 'o' if non-zero, else ''."""
        o, m = self.item.overlap
        if m:
            return '{},{}'.format(o, m)
        if int(o):
            return '{}'.format(o)
        return ''

    def _get_aliquot_text(self):
        al = ''
        it = self.item
        if it.aliquot != 0:
            al = make_aliquot_step(it.aliquot, it.step)
        return al

    def _get_ramp_duration_text(self):
        return self._get_number('ramp_duration', fmt='{:n}')

    def _get_beam_diameter_text(self):
        return self._get_number('beam_diameter')

    def _get_extract_value_text(self):
        return self._get_number('extract_value')

    def _get_duration_text(self):
        return self._get_number('duration')

    def _get_cleanup_text(self):
        return self._get_number('cleanup')

    def _get_use_cdd_warming_text(self):
        return 'Yes' if self.item.use_cdd_warming else 'No'

    # ===============set================
    def _set_ramp_duration_text(self, v):
        self._set_number(v, 'ramp_duration')

    def _set_beam_diameter_text(self, v):
        self._set_number(v, 'beam_diameter')

    def _set_extract_value_text(self, v):
        self._set_number(v, 'extract_value')

    def _set_duration_text(self, v):
        self._set_number(v, 'duration')

    def _set_cleanup_text(self, v):
        self._set_number(v, 'cleanup')

    def _set_use_cdd_warming_text(self, v):
        self.item.use_cdd_warming = to_bool(v)

    def _set_aliquot_text(self, v):
        self.item.user_defined_aliquot = int(v)

    # ==============validate================
    # (A second, identical _validate_extract_value_text definition was removed.)
    def _validate_aliquot_text(self, v):
        return self._validate_number(v, 'aliquot', kind=int)

    def _validate_extract_value_text(self, v):
        return self._validate_number(v, 'extract_value')

    def _validate_ramp_duration_text(self, v):
        return self._validate_number(v, 'ramp_duration')

    def _validate_beam_diameter_text(self, v):
        return self._validate_number(v, 'beam_diameter')

    def _validate_duration_text(self, v):
        return self._validate_number(v, 'duration')

    def _validate_cleanup_text(self, v):
        return self._validate_number(v, 'cleanup')

    # ==========helpers==============
    def _set_number(self, v, attr):
        """Store the (already validated) value *v* on the item's *attr*."""
        setattr(self.item, attr, v)

    def _validate_number(self, v, attr, kind=float):
        """Convert *v* with *kind*; on ValueError keep the item's current value."""
        try:
            return kind(v)
        except ValueError:
            return getattr(self.item, attr)

    def _get_number(self, attr, fmt='{:0.2f}'):
        """
            dont display 0.0's
        """
        v = getattr(self.item, attr)
        if v:
            # Values may be stored as strings; format them as numbers.
            if isinstance(v, str):
                v = float(v)
            return fmt.format(v)
        else:
            return ''
class B(HasTraits):
    """Fixture mapping string keys to ``A`` instances."""

    # Note: this trait name shadows the builtin `dict` inside the class body.
    dict = Dict(key_trait=Str, value_trait=Instance(A))
class ClassWithDict(HasTraits):
    """Fixture with an untyped dict and a dict of dicts keyed by string."""

    values = Dict()
    dict_of_dict = Dict(key_trait=Str, value_trait=Dict)
class AdaptationManager(HasTraits):
    """ Manages all registered adaptations. """

    #### 'AdaptationManager' class protocol ###################################

    @staticmethod
    def mro_distance_to_protocol(from_type, to_protocol):
        """ Return the distance in the MRO from 'from_type' to 'to_protocol'.

        If `from_type` provides `to_protocol`, returns the distance between
        `from_type` and the super-most class in the MRO hierarchy providing
        `to_protocol` (that's where the protocol was provided in the first
        place).

        If `from_type` does not provide `to_protocol`, return None.
        """
        if not AdaptationManager.provides_protocol(from_type, to_protocol):
            return None

        # We walk up the MRO hierarchy until the point where the `to_protocol`
        # is *no longer* provided. When we reach that point we know that the
        # previous class in the MRO is the one that provided the protocol in
        # the first place (e.g., the first super-class implementing an
        # interface).
        supertypes = inspect.getmro(from_type)[1:]

        distance = 0
        for t in supertypes:
            if AdaptationManager.provides_protocol(t, to_protocol):
                distance += 1
            else:
                # We have reached the point in the MRO where the protocol is
                # no longer provided.
                break

        return distance

    @staticmethod
    def provides_protocol(type_, protocol):
        """ Does the given type provide (i.e implement) a given protocol?

        'type_' is a Python 'type'.
        'protocol' is either a regular Python class or a traits Interface.

        Return True if the object provides the protocol, otherwise False.
        """
        # We do the 'is' check first as a performance improvement to save us
        # a call to 'issubclass'.
        return type_ is protocol or issubclass(type_, protocol)

    #### 'AdaptationManager' protocol ##########################################

    def adapt(self, adaptee, to_protocol, default=AdaptationError):
        """ Attempt to adapt an object to a given protocol.

        `adaptee` is the object that we want to adapt.
        `to_protocol` is the protocol that we want to adapt the object to.

        If `adaptee` already provides (i.e. implements) the given protocol
        then it is simply returned unchanged.

        Otherwise, we try to build a chain of adapters that adapt `adaptee`
        to `to_protocol`.

        If no such adaptation is possible then either an AdaptationError is
        raised (if `default` is AdaptationError), or `default` is returned
        (as in the default value passed to 'getattr' etc).
        """
        # If the object already provides the given protocol then it is
        # simply returned.
        # We use adaptee.__class__ instead of type(adaptee) as a courtesy to
        # old-style classes.
        if self.provides_protocol(adaptee.__class__, to_protocol):
            result = adaptee

        # Otherwise, try adapting the object.
        else:
            result = self._adapt(adaptee, to_protocol)

        if result is None:
            if default is AdaptationError:
                raise AdaptationError(
                    'Could not adapt %r to %r' % (adaptee, to_protocol))
            else:
                result = default

        return result

    def register_offer(self, offer):
        """ Register an offer to adapt from one protocol to another. """
        offers = self._adaptation_offers.setdefault(
            offer.from_protocol_name, []
        )
        offers.append(offer)

    def register_factory(self, factory, from_protocol, to_protocol):
        """ Register an adapter factory.

        This is simply a convenience method that creates and registers an
        'AdaptationOffer' from the given arguments.
        """
        from traits.adaptation.adaptation_offer import AdaptationOffer

        self.register_offer(
            AdaptationOffer(
                factory=factory,
                from_protocol=from_protocol,
                to_protocol=to_protocol
            )
        )

    def register_provides(self, provider_protocol, protocol):
        """ Register that a protocol provides another. """
        self.register_factory(no_adapter_necessary, provider_protocol, protocol)

    def supports_protocol(self, obj, protocol):
        """ Does the object support a given protocol?

        An object "supports" a protocol if either it "provides" it directly,
        or it can be adapted to it.
        """
        return self.adapt(obj, protocol, None) is not None

    #### Private protocol #####################################################

    #: All registered adaptation offers.
    #: Keys are the type name of the offer's from_protocol; values are a
    #: list of adaptation offers.
    _adaptation_offers = Dict(Str, List)

    def _adapt(self, adaptee, to_protocol):
        """ Returns an adapter that adapts an object to the target class.

        Returns None if no such adapter exists.
        """
        # The algorithm for finding a sequence of adapters adapting 'adaptee'
        # to 'to_protocol' is based on a weighted graph.

        # Nodes on the graphs are protocols (types or interfaces).
        # Edges are adaptation offers that connect a offer.from_protocol to a
        # offer.to_protocol.
        # Edges connect protocol A to protocol B and are weighted by two
        # numbers in this priority:
        # 1) a unit weight (1) representing the fact that we use 1 adaptation
        #    offer to go from A to B
        # 2) the number of steps up the type hierarchy that we need to take
        #    to go from A to offer.from_protocol, so that more specific
        #    adapters are always preferred

        # The algorithm finds the shortest weighted path between 'adaptee'
        # and 'to_protocol'. Once a candidate path is found, it tries to
        # create the adapters using the factories in the adaptation offers
        # that compose the path. If this fails because of conditional
        # adaptation (i.e., an adapter factory returns None), the path
        # is discarded and the algorithm looks for the next shortest path.

        # Cycles in adaptation are avoided by only considering paths where
        # every adaptation offer is used at most once.

        # Unique sequence counter to make the priority list stable
        # w.r.t the sequence of insertion.
        counter = itertools.count()

        # The priority queue containing entries of the form
        # (cumulative weight, path, current protocol) describing an
        # adaptation path starting at `adaptee`, following a sequence
        # of adaptation offers, `path`, and having weight `cumulative_weight`.
        #
        # 'cumulative weight' is a tuple of the form
        #  (number of traversed adapters,
        #   number of steps up protocol hierarchies,
        #   counter)
        #
        # The counter is an increasing number, and is used to make the
        # priority queue stable w.r.t insertion time
        # (see http://bit.ly/13VxILn).
        offer_queue = [((0, 0, next(counter)), [], type(adaptee))]

        while len(offer_queue) > 0:
            # Get the most specific candidate path for adaptation.
            weight, path, current_protocol = heappop(offer_queue)

            edges = self._get_applicable_offers(current_protocol, path)

            # Sort by weight first, then by from_protocol type.
            if sys.version_info[0] < 3:
                edges.sort(cmp=_by_weight_then_from_protocol_specificity)
            else:
                # functools.cmp_to_key is available from 2.7 and 3.2
                edges.sort(key=functools.cmp_to_key(
                    _by_weight_then_from_protocol_specificity))

            # At this point, the first edges are the shortest ones. Within
            # edges with the same distance, interfaces which are subclasses
            # of other interfaces in that group come first. The rest of
            # the order is unspecified.
            for mro_distance, offer in edges:
                new_path = path + [offer]

                # Check if we arrived at the target protocol.
                if self.provides_protocol(offer.to_protocol, to_protocol):
                    # Walk path and create adapters. A distinct loop variable
                    # keeps the outer `offer` from being shadowed.
                    adapter = adaptee
                    for path_offer in new_path:
                        adapter = path_offer.factory(adapter)
                        if adapter is None:
                            # This adaptation attempt failed (e.g. because of
                            # conditional adaptation).
                            # Discard this path and continue.
                            break
                    else:
                        # We're done!
                        return adapter

                else:
                    # Push the new path on the priority queue.
                    adapter_weight, mro_weight, _ = weight
                    new_weight = (adapter_weight + 1,
                                  mro_weight + mro_distance,
                                  next(counter))
                    heappush(
                        offer_queue, (new_weight, new_path, offer.to_protocol)
                    )

        return None

    def _get_applicable_offers(self, current_protocol, path):
        """ Find all adaptation offers that can be applied to a protocol.

        Return all the applicable offers together with the number of steps
        up the MRO hierarchy that need to be taken from the protocol
        to the offer's from_protocol.
        The returned object is a list of tuples (mro_distance, offer) .

        In terms of our graph algorithm, we're looking for all outgoing edges
        from the current node.
        """
        edges = []

        # Only the offer lists are needed; the dict keys are just the
        # from_protocol names.
        for offers in self._adaptation_offers.values():
            # Every offer in a list shares the same from_protocol.
            from_protocol = offers[0].from_protocol
            mro_distance = self.mro_distance_to_protocol(
                current_protocol, from_protocol
            )

            if mro_distance is not None:
                for offer in offers:
                    # Avoid cycles by checking that we did not consider this
                    # offer in this path.
                    if offer not in path:
                        edges.append((mro_distance, offer))

        return edges
class PipelineSelectorOutPutSpecification(TraitedSpec):
    """Output specification of the pipeline-selector interface."""

    # Selected pipeline; items=True also enables item-level change tracking.
    pipeline = Dict(
        items=True,
    )