示例#1
0
文件: basic.py 项目: rik0/pynetsym
class AbstractConfigurator(core.Agent):
    """Abstract agent that creates nodes and edges when it starts.

    ``_start`` calls, in order, ``create_nodes``, ``create_edges`` and
    ``initialize_nodes``. The first two raise ``NotImplementedError`` here
    and must be overridden. ``initialize_nodes`` is bound to the no-op
    ``do_not_initialize``; ``do_initialize`` is the alternative that sends
    an ``'initialize'`` message to every identifier in ``node_identifiers``.
    """
    implements(IConfigurator)

    # Agent name for this configurator (presumably used by core.Agent for
    # registration/addressing — confirm against core.Agent).
    name = 'configurator'

    options = {"full_parameters"}
    """
    Here we specify the names of the options for the configurator.
    Options are accumulated along the inheritance path
    """

    # Identifiers of the created nodes; iterated by do_initialize. Declared as
    # Any, so its concrete type is up to subclasses (presumably filled in by
    # create_nodes — TODO confirm).
    node_identifiers = Any
    # String-keyed dict; named in `options` above.
    full_parameters = Dict(key_trait=Str)

    def _start(self):
        """Create nodes, then edges, then run the node-initialization hook."""
        self.create_nodes()
        self.create_edges()
        self.initialize_nodes()

    def create_nodes(self):
        """Create the nodes. Abstract: subclasses must override."""
        raise NotImplementedError()

    def create_edges(self):
        """Create the edges. Abstract: subclasses must override."""
        raise NotImplementedError()

    def do_initialize(self):
        """Send an 'initialize' message to every node in node_identifiers."""
        for identifier in self.node_identifiers:
            self.send(identifier, 'initialize')

    def do_not_initialize(self):
        """Node-initialization hook that does nothing."""
        pass

    # Default hook: skip node initialization. Rebind to do_initialize to send
    # 'initialize' to every node instead.
    initialize_nodes = do_not_initialize
示例#2
0
        class Bar(HasTraits):
            # String-keyed dict whose values are Foo instances.
            foos = Dict(Str, Foo)
            # Event re-fired by _fire_modified_event below.
            modified = Event

            # Listens to item changes of the dict ("foos_items") and to the
            # `modified` trait reached through "foos.modified" (presumably the
            # `modified` trait of each Foo value — confirm with the traits
            # extended-name docs).
            @on_trait_change("foos_items,foos.modified")
            def _fire_modified_event(self, obj, trait_name, old, new):
                # Re-emit any of those notifications as Bar's own event.
                self.modified = True
示例#3
0
class ConfoundsInputSpec(BaseInterfaceInputSpec):
    """Inputs: a denoising pipeline, the raw confounds table, its JSON
    description, and an output directory.

    All path inputs use ``exists=True`` so nipype verifies they exist on
    assignment. A misspelled keyword such as ``exist`` is accepted silently
    as extra metadata and performs no check.
    """
    pipeline = Dict(mandatory=True, desc="Denoising pipeline")
    conf_raw = File(exists=True, mandatory=True, desc="Confounds table")
    conf_json = File(exists=True,
                     mandatory=True,
                     desc="Confounds description (aCompCor)")
    output_dir = Directory(exists=True, mandatory=True, desc="Output path")
示例#4
0
class NestedContainerClass(HasTraits):
    """Traits class holding lists/dicts of nested List, Dict and Set traits,
    including ``Union(<container>, None)`` value variants."""
    # Used in regression test for changes to nested containers

    # Nested list
    list_of_list = List(List)

    # enthought/traits#281
    dict_of_list = Dict(Str, List(Str))

    # Similar to enthought/traits#281
    dict_of_union_none_or_list = Dict(Str, Union(List(), None))

    # Nested dict
    # enthought/traits#25
    list_of_dict = List(Dict)

    dict_of_dict = Dict(Str, Dict)

    # Values are either a dict or None.
    dict_of_union_none_or_dict = Dict(Str, Union(Dict(), None))

    # Nested set
    list_of_set = List(Set)

    dict_of_set = Dict(Str, Set)

    # Values are either a set or None.
    dict_of_union_none_or_set = Dict(Str, Union(Set(), None))
class A(HasTraits):
    """Collects (trait name, new value) for every *_items notification
    fired by its list, dict and set traits."""

    # Containers with small non-empty defaults.
    alist = List(Int, [0, 1, 2, 3, 4])
    adict = Dict(Str, Int, {"a": 1, "b": 2})
    aset = Set(Int, {0, 1, 2, 3, 4})

    # Received notifications, appended in the order they arrive.
    events = List()

    @on_trait_change("alist_items,adict_items,aset_items")
    def _receive_events(self, object, name, old, new):
        record = (name, new)
        self.events.append(record)
示例#6
0
class BIDSDataSinkInputSpec(BaseInterfaceInputSpec):
    """Inputs for saving a temporary file into a BIDS directory."""

    # NOTE(review): without usedefault=True nipype may leave this Undefined
    # rather than {} — confirm that is intended.
    base_entities = Dict(
        key_trait=Str,
        value_trait=Str,
        value={},  # empty dict as the declared default
        mandatory=False,
        desc="Optional base entities that will overwrite values from incoming file"
    )
    in_file = File(desc="File from tmp to save in BIDS directory")
示例#7
0
    def test_validate(self):
        """Exercise ``Dict.validate`` with invalid and valid values.

        ``None`` must raise ``TraitError``; ``{}`` validated for a
        ``HasTraits`` owner must come back as a ``TraitDictObject``; with
        ``object=None`` (issue #71) the result must compare equal to ``{}``.
        """
        dict_trait = Dict()

        # Not a dict: validation fails.
        with self.assertRaises(TraitError):
            dict_trait.validate(object=HasTraits(), name="bar", value=None)

        # A dict with an owner object is wrapped.
        wrapped = dict_trait.validate(object=HasTraits(), name="bar", value={})
        self.assertIsInstance(wrapped, TraitDictObject)

        # No owner object (regression check for issue #71).
        unowned = dict_trait.validate(object=None, name="bar", value={})
        self.assertEqual(unowned, {})
示例#8
0
class ReportCreatorInputSpec(BaseInterfaceInputSpec):
    """Inputs for the report creator: pipelines and their names, the group
    data directory, excluded subjects, and pre-rendered plot files.

    Plot files use ``exists=True`` so nipype checks they exist. A misspelled
    ``exist`` keyword is stored as plain metadata and checks nothing.
    """
    pipelines = List(Dict(), mandatory=True)

    pipelines_names = List(Str(), mandatory=True)

    group_data_dir = Directory(exists=True)

    excluded_subjects = List(Str(), value=())

    plot_pipeline_edges_density = File(
        exists=True, desc="Density of edge weights (all subjects)")

    plot_pipelines_edges_density_no_high_motion = File(
        exists=True, desc="Density of edge weights (no high motion)")

    plot_pipelines_fc_fd_pearson = File(exists=True)

    plot_pipelines_fc_fd_uncorr = File(exists=True)

    plot_pipelines_distance_dependence = File(exists=True)
示例#9
0
 class TestClass(HasTraits):
     # String keys; no value trait given.
     dict_1 = Dict(key_trait=Str)
     # Integer keys mapped to string values.
     dict_2 = Dict(key_trait=Int, value_trait=Str)
示例#10
0
class PipelineSelectorOutPutSpecification(TraitedSpec):
    """Outputs: the selected pipeline dict and the strategy's name."""

    # items events are enabled by Dict's default (items=True).
    pipeline = Dict()
    pipeline_name = Str(desc="Name of denoising strategy")
示例#11
0
class ReportCreatorInputSpec(BaseInterfaceInputSpec):
    """Inputs for the report creator: pipelines, tasks/sessions/runs,
    runtime info, excluded subjects, warnings, and lists of optional plot
    files (aggregated over pipelines, or for a single pipeline)."""
    pipelines = List(
        # Dict's first positional argument is its key trait, so the
        # description must be passed as ``desc=``; keys stay restricted
        # to strings via Str().
        Dict(Str(), desc="Dictionary pipeline"),
        mandatory=True
    )
    tasks = List(
        Str(), mandatory=True)
    output_dir = Directory(exists=True)
    sessions = List(
        Str(),
        mandatory=False)
    runs = List(
        Int(), mandatory=False)
    runtime_info = Instance(RuntimeInfo, mandatory=True)
    excluded_subjects = List(
        trait=Instance(ExcludedSubjects),
        value=[],
        usedefault=True
    )

    warnings = List(
        trait=Instance(ErrorData),
        desc="ErrorData objects with all relevant entities error source and error message",
        value=[],
        usedefault=True
    )

    # Aggregated over pipelines
    plots_all_pipelines_edges_density = List(
        Optional(
            File(
                exists=True,
                desc="Density of edge weights (all pipelines) for all subjects"
            )))

    plots_all_pipelines_edges_density_no_high_motion = List(
        Optional(
            File(
                exists=True,
                desc="Density of edge weights (all pipelines) without high motion subjects"
            )))

    plots_all_pipelines_fc_fd_pearson_info = List(
        Optional(File(
            exists=True,
            desc="Barplot and violinplot showing percent of significant fc-fd correlations and distribution of Pearson's r values for all subjects"
        )))

    plots_all_pipelines_fc_fd_pearson_info_no_high_motion = List(
        Optional(
            File(
                exists=True,
                desc="Barplot and violinplot showing percent of significant fc-fd correlations and distribution of Pearson's r values without high motion subjects"
            )))

    plots_all_pipelines_distance_dependence = List(
        Optional(
            File(
                exists=True,
                desc="Barplot showing mean Spearman's rho between fd-fc correlation and Euclidean distance between ROIs for all subject"
            )))

    plots_all_pipelines_distance_dependence_no_high_motion = List(
        Optional(
            File(
                exists=True,
                desc="Barplot showing mean Spearman's rho between fd-fc correlation and Euclidean distance between ROIs without high motion subjects"
            )))

    plots_all_pipelines_tdof_loss = List(
        Optional(
            File(
                exists=True,
                desc="Barplot showing degree of freedom loss (number of regressors included in each pipeline."
            )))

    # For single pipeline
    plots_pipeline_fc_fd_pearson_matrix = List(
        Optional(
            File(
                exists=True,
                desc="Matrix showing correlation between connection strength and motion for all subjects"
            )))

    plots_pipeline_fc_fd_pearson_matrix_no_high_motion = List(
        Optional(
            File(
                exists=True,
                desc="Matrix showing correlation between connection strength and motion without high motion subjects"
            )))
示例#12
0
        class Foo(HasTraits):
            # Dict trait declared with items=False (presumably disabling the
            # companion `mapping_items` trait/notifications — see the traits
            # Dict docs).
            mapping = Dict(items=False)
示例#13
0
        class Foo(HasTraits):
            # Keys of any type; values must be strings.
            mapping = Dict(key_trait=Any, value_trait=Str)
示例#14
0
class SystemObject(HasStrictTraits):

    """
        Base class for Programs, Sensors and Actuators.

        Trait initialization is postponed: ``__init__`` only records the
        constructor arguments in ``_passed_arguments``; ``HasStrictTraits.__init__``
        runs later from ``setup_system`` (immediately, from ``__init__``, when a
        ``system`` keyword argument is given).
    """

    #: Names of attributes that accept Callables. If there are custom callables being used, they must be added here.
    #: The purpose of this list is that these Callables will be initialized properly.
    #: :class:`~automate.program.ProgrammableSystemObject` introduces 5 basic callables
    #: (see also :ref:`automate-programs`).
    callables = []

    def get_default_callables(self):
        """ Return default callables as a dict {name: callable}. Empty here; subclasses override."""
        return {}

    #: Reference to System object
    system = Instance(SystemBase, transient=True)

    #: Description of the object (shown in WEB interface)
    description = CUnicode

    #: Python Logger instance for this object. System creates each object its own logger instance.
    logger = Instance(logging.Logger, transient=True)

    #: Tags are used for (for example) grouping objects. See :ref:`groups`.
    tags = TagSet(trait=CUnicode)

    #: Name property is determined by System namespace. Can be read/written.
    name = Property(trait=Unicode, depends_on='name_changed_event')

    #: Logging level applied to this object's logger.
    log_level = Int(logging.INFO)

    def _log_level_changed(self, new_value):
        """Propagate log_level changes to the logger, once one exists."""
        if self.logger:
            self.logger.setLevel(new_value)

    @cached_property
    def _get_name(self):
        """Look up this object's name in ``system.reverse``.

        Returns a placeholder string when there is no system yet
        (AttributeError) or the object is not registered (KeyError).
        Re-evaluated when ``name_changed_event`` fires (``depends_on``).
        """
        try:
            return self.system.reverse[self]
        except (KeyError, AttributeError):
            return 'System not initialized!'

    def _set_name(self, new_name):
        """Register this object in the system namespace under ``new_name``.

        Raises NameError for names rejected by ``is_valid_variable_name``.
        Removes the previous namespace entry (if this object is in the
        namespace) and creates a child logger named after the new name.
        """
        if not is_valid_variable_name(new_name):
            raise NameError('Illegal name %s' % new_name)
        try:
            if self in list(self.system.namespace.values()):
                del self.system.namespace[self.name]
        except NameError:
            # NOTE(review): presumably the namespace raises NameError on
            # deletion of an unknown name — confirm; this is ignored on purpose.
            pass
        self.system.namespace[new_name] = self
        self.logger = self.system.logger.getChild('%s.%s' % (self.__class__.__name__, new_name))

    #: If set to *True*, current SystemObject is hidden in the UML diagram of WEB interface.
    hide_in_uml = CBool(False)

    # Creation order of this instance, taken from the class-wide _count
    # counter in __init__.
    _order = Int
    _count = 0

    #: Attributes that can be edited by user in WEB interface
    view = ['hide_in_uml']

    #: The data type name (as string) of the object. This is written in the initialization, and is used by WEB
    #: interface Django templates.
    data_type = ''

    #: If editable=True, a quick edit widget will appear in the web interface. Define in subclasses.
    editable = False

    # Namespace triggers this event when the object's name is changed
    name_changed_event = Event

    # (name, traits) captured by __init__ / __setstate__ and consumed by
    # setup_system.
    _passed_arguments = Tuple(transient=True)
    # Callables waiting to be installed by setup_callables, keyed by name.
    _postponed_callables = Dict(transient=True)

    @property
    def class_name(self):
        # For Django templates
        return self.__class__.__name__

    @property
    def object_type(self):
        """
            Read-only object type as a string: 'sensor', 'actuator', 'program' or 'other'.
            Used by WEB interface templates.
        """

        # Imported here rather than at module level (presumably to avoid an
        # import cycle — confirm).
        from .statusobject import AbstractSensor, AbstractActuator
        from .program import Program
        if isinstance(self, AbstractSensor):
            return 'sensor'
        elif isinstance(self, AbstractActuator):
            return 'actuator'
        elif isinstance(self, Program):
            return 'program'
        else:
            return 'other'

    def __init__(self, name='', **traits):
        # Postpone traits initialization to be launched by System
        self.logger = logging.getLogger('automate.%s' % self.__class__.__name__)
        # Record creation order and bump the class-wide counter.
        self._order = SystemObject._count
        SystemObject._count += 1

        self._passed_arguments = name, traits
        # With an explicit system, finish initialization right away.
        if 'system' in traits:
            self.setup_system(traits.pop('system'))
            self.setup_callables()

    def __setstate__(self, state, trait_change_notify=True):
        """Unpickling hook: restore logger and order, and keep the state as
        pending traits for setup_system (name=None marks this path)."""
        self.logger = logging.getLogger('automate.%s' % self.__class__.__name__)
        self._order = state.pop('_order')
        self._passed_arguments = None, state

    def get_status_display(self, **kwargs):
        """
            Human-readable status: str(kwargs['value']) if given, else the class name.
            Redefine in subclasses to add units etc.
        """
        if 'value' in kwargs:
            return str(kwargs['value'])
        return self.class_name

    def get_as_datadict(self):
        """
            Return {'type': class name, 'tags': list of tags}. Used by the WebSocket
            interface to pass relevant information to client applications.
        """
        return dict(type=self.__class__.__name__, tags=list(self.tags))

    def setup(self, *args, **kwargs):
        """
            Hook for initializing services etc. Does nothing here; define in subclasses.
        """
        pass

    def setup_system(self, system, name_from_system='', **kwargs):
        """
            Attach ``system`` (if none is set yet), register a unique name, configure the
            logger, prepare callables, run the postponed trait initialization and finally
            call ``setup()``. Used by System.
        """

        if not self.system:
            self.system = system
        name, traits = self._passed_arguments
        new_name = self.system.get_unique_name(self, name, name_from_system)
        # Only assign a name when the object is not registered yet.
        if not self in self.system.reverse:
            self.name = new_name
        self.logger = self.system.logger.getChild('%s.%s' % (self.__class__.__name__, self.name))
        self.logger.setLevel(self.log_level)
        if name is None and 'name' in traits:  # Only __setstate__ sets name to None. Default is ''.
            del traits['name']

        # Callables passed as traits are moved to _postponed_callables so that
        # setup_callables installs them later.
        for cname in self.callables:
            if cname in traits:
                c = self._postponed_callables[cname] = traits.pop(cname)
                c.setup_callable_system(self.system)
            getattr(self, cname).setup_callable_system(self.system)

        # Postponed HasStrictTraits initialization (skipped if already done).
        if not self.traits_inited():
            super().__init__(**traits)
        self.name_changed_event = True
        self.setup()

    def setup_callables(self):
        """
            Install the Callable attributes listed in ``callables``.

            Defaults from ``get_default_callables`` fill in names that were not
            passed explicitly. ``pop`` raises KeyError for a name in ``callables``
            that has neither an explicit value nor a default.
        """
        defaults = self.get_default_callables()
        for key, value in list(defaults.items()):
            self._postponed_callables.setdefault(key, value)
        for key in self.callables:
            value = self._postponed_callables.pop(key)
            value.setup_callable_system(self.system, init=True)
            setattr(self, key, value)

    def cleanup(self):
        """
            Hook for cleanup actions when the object is no longer used. Does nothing here.
        """

    def __str__(self):
        return self.name

    def __repr__(self):
        return u"'%s'" % self.name
示例#15
0
class ExecutedAutomatedRunSpecAdapter(TabularAdapter, ConfigurableMixin):
    """Tabular adapter for executed automated-run specs.

    Identifier/script ``*_text`` properties are display-only (their setters
    are deliberate no-ops). The numeric columns (ramp duration, extract
    value, beam diameter, duration, cleanup) write straight back to the row
    item after ``_validate_*_text`` conversion. Aliquot and CDD-warming
    edits are also written back.
    """
    all_columns = [('Idx', 'idx'), ('-', 'result'), ('Labnumber', 'labnumber'),
                   ('Aliquot', 'aliquot'), ('Sample', 'sample'),
                   ('Project', 'project'), ('Material', 'material'),
                   ('RepositoryID', 'repository_identifier'),
                   ('Position', 'position'), ('Extract', 'extract_value'),
                   ('Units', 'extract_units'), ('Ramp (s)', 'ramp_duration'),
                   ('Duration (s)', 'duration'), ('Cleanup (s)', 'cleanup'),
                   ('Overlap (s)', 'overlap'), ('Beam (mm)', 'beam_diameter'),
                   ('Pattern', 'pattern'), ('Extraction', 'extraction_script'),
                   ('T_o Offset', 'collection_time_zero_offset'),
                   ('Measurement', 'measurement_script'),
                   ('Conditionals', 'conditionals'),
                   ('SynExtraction', 'syn_extraction'),
                   ('CDDWarm', 'use_cdd_warming'),
                   ('Post Eq.', 'post_equilibration_script'),
                   ('Post Meas.', 'post_measurement_script'),
                   ('Options', 'script_options'), ('Comment', 'comment'),
                   ('Delay After', 'delay_after')]

    # All columns visible initially. A separate list object, so mutating one
    # list never affects the other.
    columns = list(all_columns)
    font = 'arial 10'
    # ===========================================================================
    # widths
    # ===========================================================================
    result_width = Int(25)
    repository_identifier_width = Int(90)
    labnumber_width = Int(80)
    aliquot_width = Int(60)
    sample_width = Int(50)
    position_width = Int(50)
    extract_value_width = Int(50)
    extract_units_width = Int(40)
    duration_width = Int(70)
    ramp_duration_width = Int(50)
    cleanup_width = Int(70)
    pattern_width = Int(80)
    beam_diameter_width = Int(65)

    overlap_width = Int(50)
    extraction_script_width = Int(80)
    measurement_script_width = Int(90)
    conditionals_width = Int(80)
    syn_extraction_width = Int(80)
    use_cdd_warming_width = Int(80)
    post_measurement_script_width = Int(90)
    post_equilibration_script_width = Int(90)

    position_text = Property
    comment_width = Int(125)
    # ===========================================================================
    # number values
    # ===========================================================================
    ramp_duration_text = Property
    extract_value_text = Property
    beam_diameter_text = Property
    duration_text = Property
    cleanup_text = Property

    aliquot_text = Property
    overlap_text = Property

    # ===========================================================================
    # non cell editable
    # ===========================================================================
    labnumber_text = Property
    result_text = Property
    extraction_script_text = Property
    measurement_script_text = Property
    post_measurement_script_text = Property
    post_equilibration_script_text = Property
    sample_text = Property
    use_cdd_warming_text = Property
    # Row background colors keyed by run state.
    colors = Dict(COLORS)

    image = Property
    menu = Property
    tooltip = Property

    def _get_tooltip(self):
        """Tooltip for the current cell.

        The 'result' column shows the run summary for 'success'/'truncated'
        runs and nothing otherwise. Other columns show the cell's attribute
        value together with the run state.
        """
        name = self.column_id
        item = self.item
        if name == 'result':
            if item.state in ('success', 'truncated'):
                return item.result.summary
            return None
        return '{}= {}\nstate= {}'.format(name, getattr(item, name),
                                          item.state)

    def get_bg_color(self, obj, trait, row, column=0):
        """Background color for ``row``; the first matching rule wins.

        The rules are checked in this order: non-executable (red), skipped
        (blue), a state listed in ``colors``, end-after (grey), and finally
        alternating even/odd row colors.
        """
        item = getattr(obj, trait)[row]
        if not item.executable:
            return 'red'
        if item.skip:
            return 'blue'  # '#33CCFF'  # light blue
        if item.state in self.colors:
            return self.colors[item.state]
        if item.end_after:
            return 'grey'
        if row % 2 == 0:
            return self.even_bg_color
        return self.odd_bg_color  # '#E6F2FF'  # light gray blue

    def _get_image(self):
        """Status ball for the 'result' column (None for other cells)."""
        if self.column_id == 'result':
            if self.item.state == 'success':
                return GREEN_BALL
            elif self.item.state == 'truncated':
                return ORANGE_BALL

    def _get_menu(self):
        """Context menu for finished runs ('success'/'truncated').

        The menu offers a summary plus an 'Evolutions' submenu: four
        all-isotope actions and one submenu per isotope. Returns None for
        runs in other states.
        """
        item = self.item
        if item.state in ('success', 'truncated'):

            evo_actions = [
                Action(name='Show All', action='show_evolutions'),
                Action(name='Show All w/Equilibration',
                       action='show_evolutions_w_eq'),
                Action(name='Show All w/Equilibration+Baseline',
                       action='show_evolutions_w_eq_bs'),
                Action(name='Show All w/Baseline',
                       action='show_evolutions_w_bs')
            ]
            for iso in item.result.isotope_group.iter_isotopes():
                actions = [
                    Action(name='Signal',
                           action='show_evolution_{}'.format(iso.name)),
                    Action(name='Equilibration/Signal',
                           action='show_evolution_eq_{}'.format(iso.name)),
                    Action(name='Equilibration/Signal/Baseline',
                           action='show_evolution_eq_bs_{}'.format(iso.name)),
                    Action(name='Signal/Baseline',
                           action='show_evolution_bs_{}'.format(iso.name))
                ]
                m = MenuManager(*actions, name=iso.name)
                evo_actions.append(m)

            evo = MenuManager(*evo_actions, name='Evolutions')

            success = MenuManager(
                Action(name='Summary', action='show_summary'), evo)
            return success

    # ============ non cell editable ============
    def _get_result_text(self):
        # The result column is rendered by _get_image, so no text.
        return ''

    def _set_result_text(self, v):
        pass

    def _get_position_text(self):
        """Position is shown only for unknown/degas runs, and for
        blank_unknown runs whose position contains a comma."""
        at = self.item.analysis_type
        p = self.item.position
        if at in ('unknown', 'degas'):
            return p
        if at == 'blank_unknown' and ',' in p:
            return p
        return ''

    def _get_labnumber_text(self):
        return self.item.labnumber

    def _set_labnumber_text(self, v):
        pass

    def _set_sample_text(self, v):
        pass

    def _get_sample_text(self):
        return self.item.sample

    def _get_extraction_script_text(self):
        return self.item.extraction_script

    def _get_measurement_script_text(self):
        return self.item.measurement_script

    def _get_post_measurement_script_text(self):
        return self.item.post_measurement_script

    def _get_post_equilibration_script_text(self):
        return self.item.post_equilibration_script

    def _set_extraction_script_text(self, v):
        pass

    def _set_measurement_script_text(self, v):
        pass

    def _set_post_measurement_script_text(self, v):
        pass

    def _set_post_equilibration_script_text(self, v):
        pass

    def _set_position_text(self, v):
        pass

    # ============================================

    def _get_overlap_text(self):
        """'overlap,min' when the second component is set, else the overlap
        alone if non-zero, else ''."""
        o, m = self.item.overlap
        if m:
            return '{},{}'.format(o, m)
        if int(o):
            return '{}'.format(o)
        return ''

    def _get_aliquot_text(self):
        """Aliquot/step label, or '' while the aliquot is still 0."""
        it = self.item
        if it.aliquot != 0:
            return make_aliquot_step(it.aliquot, it.step)
        return ''

    def _get_ramp_duration_text(self):
        return self._get_number('ramp_duration', fmt='{:n}')

    def _get_beam_diameter_text(self):
        return self._get_number('beam_diameter')

    def _get_extract_value_text(self):
        return self._get_number('extract_value')

    def _get_duration_text(self):
        return self._get_number('duration')

    def _get_cleanup_text(self):
        return self._get_number('cleanup')

    def _get_use_cdd_warming_text(self):
        return 'Yes' if self.item.use_cdd_warming else 'No'

    # ===============set================
    def _set_ramp_duration_text(self, v):
        self._set_number(v, 'ramp_duration')

    def _set_beam_diameter_text(self, v):
        self._set_number(v, 'beam_diameter')

    def _set_extract_value_text(self, v):
        self._set_number(v, 'extract_value')

    def _set_duration_text(self, v):
        self._set_number(v, 'duration')

    def _set_cleanup_text(self, v):
        self._set_number(v, 'cleanup')

    def _set_use_cdd_warming_text(self, v):
        self.item.use_cdd_warming = to_bool(v)

    def _set_aliquot_text(self, v):
        self.item.user_defined_aliquot = int(v)

    # ==============validate================
    def _validate_aliquot_text(self, v):
        return self._validate_number(v, 'aliquot', kind=int)

    def _validate_extract_value_text(self, v):
        return self._validate_number(v, 'extract_value')

    def _validate_ramp_duration_text(self, v):
        return self._validate_number(v, 'ramp_duration')

    def _validate_beam_diameter_text(self, v):
        return self._validate_number(v, 'beam_diameter')

    def _validate_duration_text(self, v):
        return self._validate_number(v, 'duration')

    def _validate_cleanup_text(self, v):
        return self._validate_number(v, 'cleanup')

    # ==========helpers==============
    def _set_number(self, v, attr):
        """Store the (already validated) value on the row item."""
        setattr(self.item, attr, v)

    def _validate_number(self, v, attr, kind=float):
        """Convert ``v`` with ``kind``; on ValueError keep the item's current
        value of ``attr``."""
        try:
            return kind(v)
        except ValueError:
            return getattr(self.item, attr)

    def _get_number(self, attr, fmt='{:0.2f}'):
        """Format ``self.item.<attr>`` with ``fmt``, hiding zero values.

        Falsy values, including numeric strings equal to zero such as
        ``'0'``, render as ''. Non-empty strings are converted with
        ``float`` first, so a non-numeric string raises ValueError.
        """
        v = getattr(self.item, attr)
        if isinstance(v, str) and v:
            v = float(v)
        if not v:
            return ''
        return fmt.format(v)
class B(HasTraits):
    # String keys mapped to A instances.
    dict = Dict(key_trait=Str, value_trait=Instance(A))
示例#17
0
class ClassWithDict(HasTraits):
    """Holds an unconstrained dict and a string-keyed dict of dicts."""

    values = Dict()

    dict_of_dict = Dict(key_trait=Str, value_trait=Dict)
示例#18
0
class AdaptationManager(HasTraits):
    """ Manages all registered adaptations. """

    #### 'AdaptationManager' class protocol ###################################

    @staticmethod
    def mro_distance_to_protocol(from_type, to_protocol):
        """ Return the distance in the MRO from 'from_type' to 'to_protocol'.

        If `from_type` provides `to_protocol`, returns the distance between
        `from_type` and the super-most class in the MRO hierarchy providing
        `to_protocol` (that's where the protocol was provided in the first
        place).

        If `from_type` does not provide `to_protocol`, return None.

        """

        if not AdaptationManager.provides_protocol(from_type, to_protocol):
            return None

        # Walk up the MRO (skipping `from_type` itself), counting consecutive
        # supertypes that still provide `to_protocol`. The first supertype
        # that does *not* provide it ends the walk: the class just before it
        # is the one that provided the protocol in the first place (e.g. the
        # first super-class implementing an interface).
        distance = 0
        for supertype in inspect.getmro(from_type)[1:]:
            if not AdaptationManager.provides_protocol(supertype, to_protocol):
                break
            distance += 1

        return distance

    @staticmethod
    def provides_protocol(type_, protocol):
        """ Does the given type provide (i.e implement) a given protocol?

        'type_'    is a Python 'type'.
        'protocol' is either a regular Python class or a traits Interface.

        Return True if the object provides the protocol, otherwise False.

        """

        # We do the 'is' check first as a performance improvement to save us
        # a call to 'issubclass'.
        return type_ is protocol or issubclass(type_, protocol)

    #### 'AdaptationManager' protocol ##########################################

    def adapt(self, adaptee, to_protocol, default=AdaptationError):
        """ Attempt to adapt an object to a given protocol.

        `adaptee`     is the object that we want to adapt.
        `to_protocol` is the protocol that we want to adapt the object to.

        If `adaptee` already provides (i.e. implements) the given protocol
        then it is simply returned unchanged.

        Otherwise, we try to build a chain of adapters that adapt `adaptee`
        to `to_protocol`.

        If no such adaptation is possible then either an AdaptationError is
        raised (if default=AdaptationError), or `default` is returned (as
        in the default value passed to 'getattr' etc).

        """

        # If the object already provides the given protocol then it is
        # simply returned.
        # We use adaptee.__class__ instead of type(adaptee) as a courtesy to
        # old-style classes.
        if self.provides_protocol(adaptee.__class__, to_protocol):
            result = adaptee

        # Otherwise, try adapting the object.
        else:
            result = self._adapt(adaptee, to_protocol)

        if result is None:
            if default is AdaptationError:
                raise AdaptationError(
                    'Could not adapt %r to %r' % (adaptee, to_protocol))
            else:
                result = default

        return result

    def register_offer(self, offer):
        """ Register an offer to adapt from one protocol to another.

        Offers are grouped in `_adaptation_offers` under the offer's
        `from_protocol_name`.

        """

        offers = self._adaptation_offers.setdefault(
            offer.from_protocol_name, []
        )
        offers.append(offer)

    def register_factory(self, factory, from_protocol, to_protocol):
        """ Register an adapter factory.

        This is simply a convenience method that creates and registers an
        'AdaptationOffer' from the given arguments.

        """

        from traits.adaptation.adaptation_offer import AdaptationOffer

        self.register_offer(
            AdaptationOffer(
                factory       = factory,
                from_protocol = from_protocol,
                to_protocol   = to_protocol
            )
        )

    def register_provides(self, provider_protocol, protocol):
        """ Register that a protocol provides another.

        Implemented by registering `no_adapter_necessary` as the factory from
        `provider_protocol` to `protocol`.

        """

        self.register_factory(no_adapter_necessary, provider_protocol, protocol)

    def supports_protocol(self, obj, protocol):
        """ Does the object support a given protocol?

        An object "supports" a protocol if either it "provides" it directly,
        or it can be adapted to it.

        """

        return self.adapt(obj, protocol, None) is not None

    #### Private protocol #####################################################

    #: All registered adaptation offers.
    #: Keys are the type name of the offer's from_protocol; values are a
    #: list of adaptation offers.
    _adaptation_offers = Dict(Str, List)

    def _adapt(self, adaptee, to_protocol):
        """ Returns an adapter that adapts an object to the target class.

        Returns None if no such adapter exists.

        """

        # The algorithm for finding a sequence of adapters adapting 'adaptee'
        # to 'to_protocol' is based on a weighted graph.

        # Nodes on the graphs are protocols (types or interfaces).
        # Edges are adaptation offers that connect an offer.from_protocol to
        # an offer.to_protocol.
        # Edges connect protocol A to protocol B and are weighted by two
        # numbers in this priority:
        # 1) a unit weight (1) representing the fact that we use 1 adaptation
        #    offer to go from A to B
        # 2) the number of steps up the type hierarchy that we need to take
        #    to go from A to offer.from_protocol, so that more specific
        #    adapters are always preferred

        # The algorithm finds the shortest weighted path between 'adaptee'
        # and 'to_protocol'. Once a candidate path is found, it tries to
        # create the adapters using the factories in the adaptation offers
        # that compose the path. If this fails because of conditional
        # adaptation (i.e., an adapter factory returns None), the path
        # is discarded and the algorithm looks for the next shortest path.

        # Cycles in adaptation are avoided by only considering paths where
        # every adaptation offer is used at most once.

        # The implementation of the algorithm is based on a priority queue,
        # 'offer_queue'.
        #
        # Each value in the queue has got two parts,
        # one is the adaptation path, i.e., the sequence of adaptation offers
        # followed so far; the second value is the protocol of the last
        # visited node.
        #
        # The priority in the queue is the sum of all the weights for the
        # edges traversed in the path.

        # Unique sequence counter to make the priority list stable
        # w.r.t the sequence of insertion.
        counter = itertools.count()

        # The priority queue containing entries of the form
        # (cumulative weight, path, current protocol) describing an
        # adaptation path starting at `adaptee`, following a sequence
        # of adaptation offers, `path`, and having weight `cumulative_weight`.
        #
        # 'cumulative weight' is a tuple of the form
        # (number of traversed adapters,
        #  number of steps up protocol hierarchies,
        #  counter)
        #
        # The counter is an increasing number, and is used to make the
        # priority queue stable w.r.t insertion time
        # (see http://bit.ly/13VxILn).
        offer_queue = [((0, 0, next(counter)), [], type(adaptee))]

        while len(offer_queue) > 0:
            # Get the most specific candidate path for adaptation.
            weight, path, current_protocol = heappop(offer_queue)

            edges = self._get_applicable_offers(current_protocol, path)

            # Sort by weight first, then by from_protocol type.
            if sys.version_info[0] < 3:
                edges.sort(cmp=_by_weight_then_from_protocol_specificity)
            else:
                # functools.cmp_to_key is available from 2.7 and 3.2
                edges.sort(key=functools.cmp_to_key(
                    _by_weight_then_from_protocol_specificity))

            # At this point, the first edges are the shortest ones. Within
            # edges with the same distance, interfaces which are subclasses
            # of other interfaces in that group come first. The rest of
            # the order is unspecified.

            for mro_distance, offer in edges:
                new_path = path + [offer]

                # Check if we arrived at the target protocol.
                if self.provides_protocol(offer.to_protocol, to_protocol):
                    # Walk the path and create the chain of adapters. A
                    # distinct loop name keeps `offer` (the edge being
                    # examined) from being rebound by the walk.
                    adapter = adaptee
                    for step in new_path:
                        adapter = step.factory(adapter)
                        if adapter is None:
                            # This adaptation attempt failed (e.g. because of
                            # conditional adaptation).
                            # Discard this path and continue.
                            break

                    else:
                        # We're done!
                        return adapter

                else:
                    # Push the new path on the priority queue.
                    adapter_weight, mro_weight, _ = weight
                    new_weight = (adapter_weight + 1,
                                  mro_weight + mro_distance,
                                  next(counter))
                    heappush(
                        offer_queue,
                        (new_weight, new_path, offer.to_protocol)
                    )

        return None

    def _get_applicable_offers(self, current_protocol, path):
        """ Find all adaptation offers that can be applied to a protocol.

        Return all the applicable offers together with the number of steps
        up the MRO hierarchy that need to be taken from the protocol
        to the offer's from_protocol.
        The returned object is a list of tuples (mro_distance, offer) .

        In terms of our graph algorithm, we're looking for all outgoing edges
        from the current node.
        """

        edges = []

        # The dict key (from_protocol_name) is not needed here; only the
        # grouped offers are.
        for offers in self._adaptation_offers.values():
            # NOTE(review): assumes every offer filed under one
            # from_protocol_name shares the same from_protocol.
            from_protocol = offers[0].from_protocol
            mro_distance = self.mro_distance_to_protocol(
                current_protocol, from_protocol
            )

            if mro_distance is not None:

                for offer in offers:
                    # Avoid cycles by checking that we did not consider this
                    # offer in this path.
                    if offer not in path:
                        edges.append((mro_distance, offer))

        return edges
示例#19
0
class PipelineSelectorOutPutSpecification(TraitedSpec):
    pipeline = Dict(items=True)