示例#1
0
 def _aliasize_orderby(self, orderby, copy=True):
     """Run ORDER BY criteria through ``self.aliasizer``.

     ``orderby`` is first normalized with ``util.to_list``.  When ``copy``
     is true the value returned by ``aliasizer.copy_and_process`` is handed
     back; otherwise the normalized list is passed to
     ``aliasizer.process_list`` and that same list is returned.
     """
     clauses = util.to_list(orderby)
     if not copy:
         self.aliasizer.process_list(clauses)
         return clauses
     return self.aliasizer.copy_and_process(clauses)
示例#2
0
    def group_by(self, criterion):
        """apply one or more GROUP BY criterion to the query and return the newly resulting ``Query``

        ``criterion`` may be a single item or a list; it is normalized with
        ``util.to_list``.  The criteria are stored on the object returned by
        ``self._clone()``, which is what this method returns.
        """

        q = self._clone()
        if q._group_by is False:
            q._group_by = util.to_list(criterion)
        else:
            # Build a new list rather than extending in place: the clone may
            # still reference the very list object held by the query it was
            # cloned from, and ``extend`` would then change that query too.
            q._group_by = q._group_by + list(util.to_list(criterion))
        return q
示例#3
0
    def __get_paths(self, query, raiseerr):
        """Translate ``self.key`` into a list of paths built by ``build_path``.

        Each entry of ``self.key`` (normalized with ``util.to_list``) is
        either a string, which is split on ``'.'`` into property names, or a
        single non-string token.  String tokens are resolved with
        ``mapper.get_property``; ``PropComparator`` tokens supply their own
        ``property``.  Any other token raises ``sa_exc.ArgumentError``.

        Returns the accumulated list of paths, or ``[]`` when an entity
        cannot be located for a ``PropComparator`` token, when a property
        lookup yields ``None``, or when elements of ``query._current_path``
        remain unmatched after all tokens were processed.
        """
        path = None
        entity = None
        l = []

        # _current_path implies we're in a secondary load
        # with an existing path
        current_path = list(query._current_path)
            
        # an explicitly configured mapper fixes the starting entity up front
        if self.mapper:
            entity = self.__find_entity(query, self.mapper, raiseerr)
            mapper = entity.mapper
            path_element = entity.path_entity

        for key in util.to_list(self.key):
            if isinstance(key, basestring):
                tokens = key.split('.')
            else:
                tokens = [key]
            for token in tokens:
                if isinstance(token, basestring):
                    # no entity yet: start from the query's first entity
                    if not entity:
                        entity = query._entity_zero()
                        path_element = entity.path_entity
                        mapper = entity.mapper
                    prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=raiseerr)
                    key = token
                elif isinstance(token, PropComparator):
                    prop = token.property
                    if not entity:
                        entity = self.__find_entity(query, token.parententity, raiseerr)
                        if not entity:
                            return []
                        path_element = entity.path_entity
                        # NOTE(review): ``mapper`` is not assigned in this
                        # branch.  If this token is then skipped by the
                        # ``continue`` below, a following string token would
                        # reach ``mapper.get_property`` with ``mapper``
                        # unbound -- confirm whether that can occur.
                    key = prop.key
                else:
                    raise sa_exc.ArgumentError("mapper option expects string key or list of attributes")

                # consume the leading portion of the existing path when it
                # names this key; two elements are dropped per match
                # (presumably an (entity, key) pair -- TODO confirm)
                if current_path and key == current_path[1]:
                    current_path = current_path[2:]
                    continue
                    
                if prop is None:
                    return []

                path = build_path(path_element, prop.key, path)
                l.append(path)
                # advance to the target of this property for the next token
                if getattr(token, '_of_type', None):
                    path_element = mapper = token._of_type
                else:
                    path_element = mapper = getattr(prop, 'mapper', None)
                if path_element:
                    path_element = path_element.base_mapper
        
        # if current_path tokens remain, then
        # we didn't have an exact path match.
        if current_path:
            return []
            
        return l
示例#4
0
def driver(drivername):
    """Tell whether the configured engine uses one of the given drivers.

    `drivername` may be a single name or a list of names.  Returns `True`
    when the drivername of the engine returned by `get_engine()` is among
    them, `False` otherwise.
    """
    active_engine = get_engine()
    candidates = to_list(drivername)
    current = active_engine.url.drivername
    return current in candidates
    def __init__(self, app=None, use_native_unicode=True,
                 session_extensions=None, session_options=None):
        """Set up the session, declarative base and engine lock.

        ``session_extensions`` is normalized with ``to_list`` and always
        gets a ``_SignallingSessionExtension`` appended.  ``session_options``
        defaults to an empty dict and receives a ``'scopefunc'`` entry if it
        does not already have one.  When ``app`` is given, ``init_app`` is
        called with it.
        """
        self.use_native_unicode = use_native_unicode
        extensions = to_list(session_extensions, [])
        self.session_extensions = extensions + [_SignallingSessionExtension()]

        if session_options is None:
            session_options = {}
        session_options.setdefault('scopefunc', connection_stack.__ident_func__)

        self.session = self.create_scoped_session(session_options)
        self.Model = self.make_declarative_base()
        self._engine_lock = Lock()

        # ``self.app`` mirrors the argument; only a real app is initialized
        self.app = app
        if app is not None:
            self.init_app(app)

        _include_sqlalchemy(self)
        self.Query = BaseQuery
示例#6
0
def column_mapped_collection(mapping_spec):
    """Build a factory for dictionary collections keyed on column values.

    ``mapping_spec`` may be a Column or a sequence of Columns.  The returned
    callable creates a ``MappedCollection`` whose keying function reads the
    column value(s) from the member object: a single value when exactly one
    column was given, otherwise a tuple of values.

    The key value must be immutable for the lifetime of the object.  You
    can not, for example, map on foreign key values if those key values will
    change during the session, i.e. from None to a database-assigned integer
    after a session flush.

    """
    from sqlalchemy.orm.util import _state_mapper
    from sqlalchemy.orm.attributes import instance_state

    columns = [expression._no_literals(spec)
               for spec in util.to_list(mapping_spec)]

    if len(columns) != 1:
        key_columns = tuple(columns)

        def keyfunc(value):
            state = instance_state(value)
            state_mapper = _state_mapper(state)
            return tuple(state_mapper._get_state_attr_by_column(state, column)
                         for column in key_columns)
    else:
        def keyfunc(value):
            state = instance_state(value)
            state_mapper = _state_mapper(state)
            return state_mapper._get_state_attr_by_column(state, columns[0])

    return lambda: MappedCollection(keyfunc)
示例#7
0
    def _register_attribute(self, compare_function, copy_function, mutable_scalars, 
            comparator_factory, callable_=None, proxy_property=None, active_history=False):
        """Register ``self.key`` as a scalar, non-object managed attribute.

        The attribute is registered through ``sessionlib.register_attribute``
        on every mapper yielded by ``self.parent.polymorphic_iterator()``
        that has the property and is either ``self.parent`` itself or not
        concrete.  The extensions passed along are those of
        ``self.parent_property`` plus a ``mapperutil.Validator`` when a
        validator is configured for the key.
        """
        self.logger.info("%s register managed attribute" % self)

        # Copy the list: ``util.to_list`` may hand back the property's own
        # extension list, and appending to it would grow that list on every
        # registration.
        attribute_ext = list(util.to_list(self.parent_property.extension) or [])
        if self.key in self.parent._validators:
            attribute_ext.append(mapperutil.Validator(self.key, self.parent._validators[self.key]))

        for mapper in self.parent.polymorphic_iterator():
            if (mapper is self.parent or not mapper.concrete) and mapper.has_property(self.key):
                sessionlib.register_attribute(
                    mapper.class_, 
                    self.key, 
                    uselist=False, 
                    useobject=False, 
                    copy_function=copy_function, 
                    compare_function=compare_function, 
                    mutable_scalars=mutable_scalars, 
                    comparator=comparator_factory(self.parent_property, mapper), 
                    parententity=mapper,
                    callable_=callable_,
                    extension=attribute_ext,
                    proxy_property=proxy_property,
                    active_history=active_history
                    )
示例#8
0
    def _filter_or_exclude(self, negate, kwargs):
        """Turn ``name__op=value`` style keyword lookups into query filters.

        Every key of ``kwargs`` is split on ``'__'``.  Tokens are resolved to
        entity descriptors; a descriptor whose ``impl.uses_objects`` is true
        is joined to instead of compared.  A token found in
        ``self.OPERATORS`` applies that operator to the pending column; an
        unrecognised token raises ``ValueError``.  A column left pending at
        the end of a key is compared for equality with the value.  Each
        criterion is inverted with ``~`` when ``negate`` is true.
        """
        def apply_negation(expr):
            if negate:
                return ~expr
            return expr

        query = self
        pending = None

        for lookup, value in kwargs.iteritems():
            for part in lookup.split('__'):
                if pending is None:
                    pending = _entity_descriptor(query._joinpoint_zero(), part)
                    if pending.impl.uses_objects:
                        query = query.join(pending)
                        pending = None
                    continue

                if part not in self.OPERATORS:
                    raise ValueError('No idea what to do with %r' % part)

                operator_fn = self.OPERATORS[part]
                # wrap sequences so they are passed as one argument
                if isinstance(value, (list, tuple)):
                    value = [value]
                criterion = operator_fn(pending, *to_list(value))
                query = query.filter(apply_negation(criterion))
                pending = None

            if pending is not None:
                query = query.filter(apply_negation(pending == value))
                pending = None
            query = query.reset_joinpoint()
        return query
    def _do_skips(self, cls):
        if hasattr(cls, '__requires__'):
            def test_suite(): return 'ok'
            test_suite.__name__ = cls.__name__
            for requirement in cls.__requires__:
                check = getattr(requires, requirement)
                check(test_suite)()

        if cls.__unsupported_on__:
            spec = testing.db_spec(*cls.__unsupported_on__)
            if spec(testing.db):
                raise SkipTest(
                    "'%s' unsupported on DB implementation '%s'" % (
                     cls.__name__, testing.db.name)
                    )

        if getattr(cls, '__only_on__', None):
            spec = testing.db_spec(*util.to_list(cls.__only_on__))
            if not spec(testing.db):
                raise SkipTest(
                    "'%s' unsupported on DB implementation '%s'" % (
                     cls.__name__, testing.db.name)
                    )

        if getattr(cls, '__skip_if__', False):
            for c in getattr(cls, '__skip_if__'):
                if c():
                    raise SkipTest("'%s' skipped by %s" % (
                        cls.__name__, c.__name__)
                    )

        for db, op, spec in getattr(cls, '__excluded_on__', ()):
            testing.exclude(db, op, spec, "'%s' unsupported on DB %s version %s" % (
                    cls.__name__, testing.db.name,
                    testing._server_version()))
示例#10
0
    def _get(self, key, ident=None, reload=False, lockmode=None):
        lockmode = lockmode or self.lockmode
        if not reload and not self.always_refresh and lockmode is None:
            try:
                return self.session._get(key)
            except KeyError:
                pass

        if ident is None:
            ident = key[1]
        else:
            ident = util.to_list(ident)
        i = 0
        params = {}
        for primary_key in self.primary_key_columns:
            params[primary_key._label] = ident[i]
            # if there are not enough elements in the given identifier, then
            # use the previous identifier repeatedly.  this is a workaround for the issue
            # in [ticket:185], where a mapper that uses joined table inheritance needs to specify
            # all primary keys of the joined relationship, which includes even if the join is joining
            # two primary key (and therefore synonymous) columns together, the usual case for joined table inheritance.
            if len(ident) > i + 1:
                i += 1
        try:
            statement = self.compile(self._get_clause, lockmode=lockmode)
            return self._select_statement(statement, params=params, populate_existing=reload, version_check=(lockmode is not None))[0]
        except IndexError:
            return None
示例#11
0
 def _register_attribute(self, class_, callable_=None, impl_class=None, **kwargs):
     """Register ``self.key`` as an object-valued managed attribute on ``class_``.

     The extensions passed to ``sessionlib.register_attribute`` are those of
     ``self.parent_property``, followed by the backref's extension when a
     backref is configured and a ``mapperutil.Validator`` when a validator
     exists for the key.  Extra keyword arguments are forwarded unchanged.
     """
     self.logger.info("%s register managed %s attribute" % (self, (self.uselist and "collection" or "scalar")))

     # Copy the list: ``util.to_list`` may hand back the property's own
     # extension list, and appending to it would grow that list on every
     # registration.
     attribute_ext = list(util.to_list(self.parent_property.extension) or [])

     if self.parent_property.backref:
         attribute_ext.append(self.parent_property.backref.extension)

     if self.key in self.parent._validators:
         attribute_ext.append(mapperutil.Validator(self.key, self.parent._validators[self.key]))

     sessionlib.register_attribute(
         class_, 
         self.key, 
         uselist=self.uselist, 
         useobject=True, 
         extension=attribute_ext, 
         trackparent=True, 
         typecallable=self.parent_property.collection_class, 
         callable_=callable_, 
         comparator=self.parent_property.comparator, 
         parententity=self.parent,
         impl_class=impl_class,
         **kwargs
         )
示例#12
0
    def __should_skip_for(self, cls):
        """Return True when ``cls`` should be skipped, False otherwise.

        Checks, in order: ``__requires__``, ``__unsupported_on__``,
        ``__only_on__``, ``__skip_if__`` and ``__excluded_on__``.  Every
        check except ``__requires__`` prints a message before returning
        True.
        """
        if hasattr(cls, '__requires__'):
            # stub passed through each requirement; a requirement that lets
            # the call through yields 'ok'
            def test_suite(): return 'ok'
            for requirement in cls.__requires__:
                check = getattr(requires, requirement)
                if check(test_suite)() != 'ok':
                    # The requirement will perform messaging.
                    return True

        # NOTE(review): the messages below use ``cls.__class__.__name__``.
        # If ``cls`` is a class object (as the parameter name suggests) this
        # prints the metaclass name rather than the test class name --
        # confirm whether ``cls.__name__`` was intended.
        if cls.__unsupported_on__:
            spec = testing.db_spec(*cls.__unsupported_on__)
            if spec(testing.db):
                print "'%s' unsupported on DB implementation '%s'" % (
                     cls.__class__.__name__, testing.db.name)
                return True
        if getattr(cls, '__only_on__', None):
            spec = testing.db_spec(*util.to_list(cls.__only_on__))
            if not spec(testing.db):
                print "'%s' unsupported on DB implementation '%s'" % (
                     cls.__class__.__name__, testing.db.name)
                return True                    

        # each entry of __skip_if__ is called without arguments; the first
        # truthy result causes a skip
        if (getattr(cls, '__skip_if__', False)):
            for c in getattr(cls, '__skip_if__'):
                if c():
                    print "'%s' skipped by %s" % (
                        cls.__class__.__name__, c.__name__)
                    return True
        # each rule is unpacked as the arguments of testing._is_excluded
        for rule in getattr(cls, '__excluded_on__', ()):
            if testing._is_excluded(*rule):
                print "'%s' unsupported on DB %s version %s" % (
                    cls.__class__.__name__, testing.db.name,
                    _server_version())
                return True
        return False
示例#13
0
    def _get_paths(self, query, raiseerr):
        """Translate ``self.key`` into parallel lists of paths and mappers.

        ``self.key`` is normalized with ``util.to_list``; string entries are
        split on ``'.'`` and everything is flattened into one token list.
        String tokens are resolved with ``mapper.get_property``,
        ``PropComparator`` tokens supply their own ``property``, and any
        other token raises ``sa_exc.ArgumentError``.

        Returns ``(paths, mappers)``.  ``([], [])`` is returned when an
        entity cannot be found for a ``PropComparator`` token, when a
        property lookup yields ``None``, or when elements of
        ``query._current_path`` are left over at the end.
        """
        path = None
        entity = None
        l = []
        mappers = []

        # _current_path implies we're in a secondary load with an
        # existing path

        current_path = list(query._current_path)
        tokens = []
        for key in util.to_list(self.key):
            if isinstance(key, basestring):
                tokens += key.split(".")
            else:
                tokens += [key]
        for token in tokens:
            if isinstance(token, basestring):
                if not entity:
                    # while no entity is established, tokens matching the
                    # head of the existing path are consumed and skipped
                    if current_path:
                        if current_path[1] == token:
                            current_path = current_path[2:]
                            continue
                    entity = query._entity_zero()
                    path_element = entity.path_entity
                    mapper = entity.mapper
                mappers.append(mapper)
                prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=raiseerr)
                key = token
            elif isinstance(token, PropComparator):
                prop = token.property
                if not entity:
                    # same skipping rule, matched on (parent entity, key)
                    if current_path:
                        if current_path[0:2] == [token.parententity, prop.key]:
                            current_path = current_path[2:]
                            continue
                    entity = self._find_entity(query, token.parententity, raiseerr)
                    if not entity:
                        return [], []
                    path_element = entity.path_entity
                    mapper = entity.mapper
                mappers.append(prop.parent)
                key = prop.key
            else:
                raise sa_exc.ArgumentError("mapper option expects " "string key or list of attributes")
            if prop is None:
                return [], []
            path = build_path(path_element, prop.key, path)
            l.append(path)
            # advance to the target of this property for the next token
            if getattr(token, "_of_type", None):
                path_element = mapper = token._of_type
            else:
                path_element = mapper = getattr(prop, "mapper", None)
            # NOTE(review): this assignment is a no-op as written; it looks
            # like a transformation of path_element was intended -- confirm.
            if path_element:
                path_element = path_element

        # leftover elements mean the existing path was not matched exactly
        if current_path:
            return [], []
        return l, mappers
示例#14
0
def mapper(*args, **kwargs):
    """
    Add our own database mapper, not the new sqlalchemy 0.4
    session aware mapper.

    Any ``extension`` keyword (a single extension or a sequence of them) is
    kept and a ``ManagerExtension`` is appended before delegating to
    ``orm.mapper``.
    """
    supplied = kwargs.get('extension')
    if supplied is None:
        extensions = []
    else:
        # Copy so the caller's own list is never appended to, and so that a
        # tuple of extensions is accepted as well.
        extensions = list(to_list(supplied))
    extensions.append(ManagerExtension())
    kwargs['extension'] = extensions
    return orm.mapper(*args, **kwargs)
示例#15
0
    def __init__(self, stmt):
        """Wrap ``stmt`` as a nested selectable.

        A ``ScalarSelect`` is unwrapped to its ``element``; a ``SelectBase``
        is used as given; anything else is turned into a select over
        ``util.to_list(stmt)``.  The result type is a ``NestedResult``.
        """
        if isinstance(stmt, expression.ScalarSelect):
            selectable = stmt.element
        elif isinstance(stmt, expression.SelectBase):
            selectable = stmt
        else:
            selectable = expression.select(util.to_list(stmt))

        super(nested, self).__init__(selectable)
        self.type = NestedResult()
示例#16
0
    def setup_query(self, context, eagertable=None, parentclauses=None, parentmapper=None, **kwargs):
        """Add a left outer join to the statement that is being constructed.

        The join target is taken from ``statement._outerjoin`` when present,
        otherwise located among ``statement.froms`` (plain ``Table`` parent)
        or taken to be the parent's ``mapped_table``.  The resulting join is
        stored on ``statement._outerjoin`` and appended to the statement,
        ordering is applied, and ``setup`` is then invoked on every property
        of ``self.select_mapper``.

        Returns ``None`` immediately when ``self.mapper`` is already in
        ``context.recursion_stack``.  Raises
        ``exceptions.InvalidRequestError`` when no FROM clause containing the
        parent table can be found.  ``eagertable`` and ``**kwargs`` are not
        read in this method.
        """
        
        if parentmapper is None:
            localparent = context.mapper
        else:
            localparent = parentmapper
        
        # NOTE(review): membership is tested with ``self.mapper`` but
        # ``self.parent`` is what gets added -- confirm this is intended.
        if self.mapper in context.recursion_stack:
            return
        else:
            context.recursion_stack.add(self.parent)

        statement = context.statement
        
        if hasattr(statement, '_outerjoin'):
            # a previous eager join exists; keep chaining onto it
            towrap = statement._outerjoin
        elif isinstance(localparent.mapped_table, schema.Table):
            # if the mapper is against a plain Table, look in the from_obj of the select statement
            # to join against whats already there.
            for (fromclause, finder) in [(x, sql_util.TableFinder(x)) for x in statement.froms]:
                # dont join against an Alias'ed Select.  we are really looking either for the 
                # table itself or a Join that contains the table.  this logic still might need
                # adjustments for scenarios not thought of yet.
                if not isinstance(fromclause, sql.Alias) and localparent.mapped_table in finder:
                    towrap = fromclause
                    break
            else:
                # NOTE(review): the message reads ``self.localparent`` while
                # the rest of the method uses the local ``localparent``;
                # verify that attribute exists, otherwise this raises
                # AttributeError instead of the intended error.
                raise exceptions.InvalidRequestError("EagerLoader cannot locate a clause with which to outer join to, in query '%s' %s" % (str(statement), self.localparent.mapped_table))
        else:
            # if the mapper is against a select statement or something, we cant handle that at the
            # same time as a custom FROM clause right now.
            towrap = localparent.mapped_table
        
        # aliased clauses are cached per parentclauses value
        try:
            clauses = self.clauses[parentclauses]
        except KeyError:
            clauses = EagerLoader.AliasedClauses(self, parentclauses)
            self.clauses[parentclauses] = clauses
            
        # remember the first clauses object seen for each lead mapper
        if context.mapper not in self.clauses_by_lead_mapper:
            self.clauses_by_lead_mapper[context.mapper] = clauses

        if self.secondaryjoin is not None:
            # two-step join through the secondary table
            statement._outerjoin = sql.outerjoin(towrap, clauses.eagersecondary, clauses.eagerprimary).outerjoin(clauses.eagertarget, clauses.eagersecondaryjoin)
            if self.order_by is False and self.secondary.default_order_by() is not None:
                statement.order_by(*clauses.eagersecondary.default_order_by())
        else:
            statement._outerjoin = towrap.outerjoin(clauses.eagertarget, clauses.eagerprimary)
            if self.order_by is False and clauses.eagertarget.default_order_by() is not None:
                statement.order_by(*clauses.eagertarget.default_order_by())

        if clauses.eager_order_by:
            statement.order_by(*util.to_list(clauses.eager_order_by))
                
        statement.append_from(statement._outerjoin)
        # let each property of the target mapper set itself up against the
        # aliased target
        for value in self.select_mapper.props.values():
            value.setup(context, eagertable=clauses.eagertarget, parentclauses=clauses, parentmapper=self.select_mapper)
示例#17
0
 def __getstate__(self):
     """Return a copy of the instance dict with ``key`` made picklable.

     Each ``PropComparator`` token in ``self.key`` is replaced by a
     ``(mapped class, attribute key)`` tuple; other tokens are kept as is.
     """
     def picklable(token):
         if isinstance(token, PropComparator):
             return (token.mapper.class_, token.key)
         return token

     state = self.__dict__.copy()
     state['key'] = [picklable(token) for token in util.to_list(self.key)]
     return state
示例#18
0
    def select_from(self, from_obj):
        """Return a clone of the query with `from_obj` added to its FROM list.

        `from_obj` is one table or a list of tables; it is appended to a
        copy of the clone's existing `_from_obj`.
        """
        new = self._clone()
        existing = list(new._from_obj)
        additions = util.to_list(from_obj)
        new._from_obj = existing + additions
        return new
示例#19
0
def identity_key(*args, **kwargs):
    """Get an identity key.

    Valid call signatures:

    * ``identity_key(class, ident)``

      class
          mapped class (must be a positional argument)

      ident
          primary key, if the key is composite this is a tuple
          (may also be given as the keyword ``ident``)


    * ``identity_key(instance=instance)``

      instance
          object instance (must be given as a keyword arg)

    * ``identity_key(class, row=row)``

      class
          mapped class (must be a positional argument)

      row
          result proxy row (must be given as a keyword arg)

    Raises ``sa_exc.ArgumentError`` when more than two positional arguments
    or any unexpected keyword arguments are given, and ``KeyError`` when a
    required keyword (``row``/``ident`` or ``instance``) is missing.

    """
    if args:
        # explicit flag instead of probing locals() for the name
        have_ident = False
        if len(args) == 1:
            class_ = args[0]
            try:
                row = kwargs.pop("row")
            except KeyError:
                ident = kwargs.pop("ident")
                have_ident = True
        elif len(args) == 2:
            class_, ident = args
            have_ident = True
        else:
            # Three positional arguments used to reach a two-name unpack
            # and fail with an opaque ValueError; no such signature exists,
            # so report it like any other surplus argument.
            raise sa_exc.ArgumentError("expected up to two "
                "positional arguments, got %s" % len(args))
        if kwargs:
            raise sa_exc.ArgumentError("unknown keyword arguments: %s"
                % ", ".join(kwargs.keys()))
        mapper = class_mapper(class_)
        if have_ident:
            return mapper.identity_key_from_primary_key(util.to_list(ident))
        return mapper.identity_key_from_row(row)
    instance = kwargs.pop("instance")
    if kwargs:
        raise sa_exc.ArgumentError("unknown keyword arguments: %s"
            % ", ".join(kwargs.keys()))
    mapper = object_mapper(instance)
    return mapper.identity_key_from_instance(instance)
示例#20
0
    def __call__(self):
        """Load and return the value of attribute ``self.key`` for ``self.state``.

        Returns ``None`` when the state has no identity.  Raises
        ``sa_exc.UnboundExecutionError`` when the state is not attached to a
        session.  When the strategy's ``use_get`` is set, the related object
        is fetched with ``Query.get`` (or ``None`` is returned if every key
        value is ``None``).  Otherwise a filtered query is run and either the
        full result list (``uselist``) or its first element / ``None`` is
        returned.
        """
        state = self.state
        if not mapper._state_has_identity(state):
            return None

        instance_mapper = mapper._state_mapper(state)
        prop = instance_mapper.get_property(self.key)
        strategy = prop._get_strategy(LazyLoader)

        if strategy._should_log_debug:
            strategy.logger.debug("loading %s" % mapperutil.state_attribute_str(state, self.key))

        session = sessionlib._state_session(state)
        if session is None:
            raise sa_exc.UnboundExecutionError(
                "Parent instance %s is not bound to a Session; "
                "lazy load operation of attribute '%s' cannot proceed" % 
                (mapperutil.state_str(state), self.key)
            )

        query = session.query(prop.mapper)._adapt_all_clauses()
        if self.path:
            query = query._with_current_path(self.path)

        # if we have a simple primary key load, use mapper.get()
        # to possibly save a DB round trip
        if strategy.use_get:
            ident = [
                instance_mapper._get_committed_state_attr_by_column(
                    state, strategy._equated_columns[primary_key])
                for primary_key in prop.mapper.primary_key
            ]
            # nothing to look up when every key value is None
            if not [value for value in ident if value is not None]:
                return None
            if self.options:
                query = query._conditional_options(*self.options)
            return query.get(ident)

        if prop.order_by:
            query = query.order_by(*util.to_list(prop.order_by))
        if self.options:
            query = query._conditional_options(*self.options)
        query = query.filter(strategy.lazy_clause(state))

        rows = query.all()
        if strategy.uselist:
            return rows
        if rows:
            return rows[0]
        return None
示例#21
0
def mapper(cls, *arg, **options):
    """A mapper that hooks in our standard extensions.

    An ``AutoAddExt`` is appended to whatever ``extension`` option was
    supplied, and ``cls.query`` is set to ``session.query_property()`` when
    the class does not already have a ``query`` attribute.
    """

    # Copy so a list or tuple owned by the caller is never appended to.
    extensions = list(to_list(options.pop('extension', None), []))
    extensions.append(AutoAddExt())
    options['extension'] = extensions

    if not hasattr(cls, 'query'):
        cls.query = session.query_property()

    return orm.mapper(cls, *arg, **options)
示例#22
0
def mapper(model, table, **options):
    """Map ``model`` to ``table`` with the standard extensions hooked in.

    The ``extension`` option is normalized to a list, and ``model.__init__``
    is replaced by a wrapper that runs the previous initializer and then
    adds the new instance to ``db.session``.
    """
    options['extension'] = to_list(options.pop('extension', None), [])

    # automatically register the model to the session
    original_init = getattr(model, '__init__', lambda s: None)

    def register_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        db.session.add(self)

    model.__init__ = register_init
    return orm.mapper(model, table, **options)
 def get(self, ident, **kwargs):
     """Look up ``ident``, trying each candidate shard when none is set.

     With a shard already selected the lookup is delegated to the parent
     class.  Otherwise every shard id produced by ``self.id_chooser`` is
     tried in turn and the first non-``None`` result is returned; ``None``
     is returned when no shard has the object.
     """
     if self._shard_id is None:
         ident = util.to_list(ident)
         for shard_id in self.id_chooser(self, ident):
             found = self.set_shard(shard_id).get(ident, **kwargs)
             if found is not None:
                 return found
         return None
     return super(ShardedQuery, self).get(ident)
示例#24
0
    def __init__(self, prop, primaryjoin, secondaryjoin, parentclauses=None, alias=None):
        """Build aliased join and ordering clauses for the property ``prop``.

        ``alias`` defaults to an alias of the target mapper's
        ``_with_polymorphic_selectable()``.  ``primaryjoin`` and
        ``secondaryjoin`` are traversed (with ``clone=True``) through
        ``ClauseAdapter`` objects and stored as ``self.primaryjoin`` /
        ``self.secondaryjoin``; the latter is ``None`` when the property has
        no secondary table.  ``self.order_by`` holds the adapted
        ``prop.order_by``, or ``None`` when the property defines none.
        """
        self.prop = prop
        self.mapper = self.prop.mapper
        self.table = self.prop.table
        self.parentclauses = parentclauses

        if not alias:
            from_obj = self.mapper._with_polymorphic_selectable()
            alias = from_obj.alias()

        super(PropertyAliasedClauses, self).__init__(alias, equivalents=self.mapper._equivalent_columns, chain_to=parentclauses)
        
        if prop.secondary:
            # the secondary table gets its own alias; the primary join is
            # adapted to it, the secondary join to the target alias chained
            # with a second adapter over the secondary alias
            self.secondary = prop.secondary.alias()
            primary_aliasizer = sql_util.ClauseAdapter(self.secondary)
            secondary_aliasizer = sql_util.ClauseAdapter(self.alias, equivalents=self.equivalents).chain(sql_util.ClauseAdapter(self.secondary))

            if parentclauses is not None:
                primary_aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, equivalents=parentclauses.equivalents))

            self.secondaryjoin = secondary_aliasizer.traverse(secondaryjoin, clone=True)
            self.primaryjoin = primary_aliasizer.traverse(primaryjoin, clone=True)
        else:
            # no secondary table: adapt the primary join to the target
            # alias, excluding the property's local-side columns
            primary_aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side, equivalents=self.equivalents)
            if parentclauses is not None: 
                primary_aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, exclude=prop.remote_side, equivalents=parentclauses.equivalents))
            
            self.primaryjoin = primary_aliasizer.traverse(primaryjoin, clone=True)
            self.secondary = None
            self.secondaryjoin = None
        
        if prop.order_by:
            if prop.secondary:
                # usually this is not used but occasionally someone has a sort key in their secondary
                # table, even tho SA does not support writing this column directly
                self.order_by = secondary_aliasizer.copy_and_process(util.to_list(prop.order_by))
            else:
                self.order_by = primary_aliasizer.copy_and_process(util.to_list(prop.order_by))
                
        else:
            self.order_by = None
 def get(self, ident, **kwargs):
     """Look up ``ident``, searching the candidate shards when needed.

     The lookup is delegated to the parent class when a shard is already
     selected or when the session reports the query's table as not sharded.
     Otherwise each shard id produced by ``self.id_chooser`` is tried in
     turn and the first non-``None`` result is returned; ``None`` is
     returned when no shard has the object.
     """
     if self._shard_id is None and \
             self.session.table_is_sharded(self.session.get_table(query=self)):
         ident = util.to_list(ident)
         for shard_id in self.id_chooser(self, ident):
             found = self.set_shard(shard_id).get(ident, **kwargs)
             if found is not None:
                 return found
         return None
     return super(ShardingQueryMixin, self).get(ident)
示例#26
0
def only_on(dbs, reason):
    """Decorator factory: run the test only on databases matching ``dbs``.

    ``dbs`` is one database spec or a list of them.  When ``config.db`` does
    not match, the decorated function raises ``SkipTest`` with a message
    that includes ``reason`` instead of running.
    """
    # result is not used below; the call is kept as in the original
    carp = _should_carp_about_exclusion(reason)
    spec = db_spec(*util.to_list(dbs))

    @decorator
    def decorate(fn, *args, **kw):
        if not spec(config.db):
            raise SkipTest(
                "'%s' unsupported on DB implementation '%s+%s': %s" % (
                    fn.__name__, config.db.name, config.db.driver, reason))
        return fn(*args, **kw)
    return decorate
示例#27
0
    def __init__(self, engine_url, echo=False, pool_recycle=7200, pool_size=10,
                 session_extensions=None, session_options=None):
        """Create the signal sender id, scoped session, base model and engine.

        ``session_extensions`` is normalized with ``to_list`` and always gets
        a ``_SignallingSessionExtension`` appended.  The engine is created
        from ``engine_url`` with the given ``echo``, ``pool_recycle`` and
        ``pool_size`` settings.
        """
        # create signals sender
        self.sender = str(uuid.uuid4())

        extensions = to_list(session_extensions, [])
        self.session_extensions = extensions + [_SignallingSessionExtension()]
        self.session = self.create_scoped_session(session_options)
        self.Model = self.make_declarative_base()

        engine_options = dict(echo=echo, pool_recycle=pool_recycle,
                              pool_size=pool_size)
        self.engine = sqlalchemy.create_engine(engine_url, **engine_options)

        _include_sqlalchemy(self)
示例#28
0
    def mapper(self, *args, **kwargs):
        """Create and return a mapper with this ScopedSession's extension attached.

        Keyword arguments accepted by ``_ScopedExt`` are removed from
        ``kwargs`` and, if any were given, used to configure a copy of the
        extension via ``self.extension.configure``; otherwise
        ``self.extension`` itself is used.  The extension is appended to any
        ``extension`` value supplied by the caller, and everything else is
        passed through to ``sqlalchemy.orm.mapper``.
        """

        from sqlalchemy.orm import mapper

        extension_args = dict([(arg, kwargs.pop(arg)) for arg in get_cls_kwargs(_ScopedExt) if arg in kwargs])

        supplied = kwargs.get("extension")
        if supplied is None:
            extension = []
        else:
            # Copy so the caller's own list is never appended to, and so
            # that a tuple of extensions is accepted as well.
            extension = list(to_list(supplied))
        kwargs["extension"] = extension

        if extension_args:
            extension.append(self.extension.configure(**extension_args))
        else:
            extension.append(self.extension)
        return mapper(*args, **kwargs)
示例#29
0
def _possible_configs_for_cls(cls):
    """Return the set of configs on which ``cls`` may run.

    Starts from every config reported by ``config.Config.all_configs()``,
    drops those matching ``cls.__unsupported_on__`` and, when
    ``__only_on__`` is set, those not matching it.
    """
    candidates = set(config.Config.all_configs())

    if cls.__unsupported_on__:
        unsupported = exclusions.db_spec(*cls.__unsupported_on__)
        candidates.difference_update(
            [cfg for cfg in candidates if unsupported(cfg)])

    only_on = getattr(cls, '__only_on__', None)
    if only_on:
        allowed = exclusions.db_spec(*util.to_list(only_on))
        candidates.difference_update(
            [cfg for cfg in candidates if not allowed(cfg)])

    return candidates
示例#30
0
    def _do_skips(self, cls):
        """Raise ``SkipTest`` when ``cls`` should not run on ``config.db``.

        Checks, in order: each requirement named in ``__requires__`` (a
        disabled requirement skips with its own reason when it has one),
        ``__unsupported_on__``, ``__only_on__``, the callables in
        ``__skip_if__``, and finally the ``(db, op, spec)`` triples in
        ``__excluded_on__``, which are handed to ``exclusions.exclude``.
        """
        from sqlalchemy.testing import config

        def unsupported_message():
            return "'%s' unsupported on DB implementation '%s' == %s" % (
                cls.__name__, config.db.name,
                config.db.dialect.server_version_info)

        if hasattr(cls, '__requires__'):
            def test_suite():
                return 'ok'
            test_suite.__name__ = cls.__name__
            for requirement in cls.__requires__:
                check = getattr(config.requirements, requirement)
                if check.enabled:
                    continue
                if check.reason:
                    raise SkipTest(check.reason)
                raise SkipTest(unsupported_message())

        if cls.__unsupported_on__:
            if exclusions.db_spec(*cls.__unsupported_on__)(config.db):
                raise SkipTest(unsupported_message())

        only_on = getattr(cls, '__only_on__', None)
        if only_on:
            if not exclusions.db_spec(*util.to_list(only_on))(config.db):
                raise SkipTest(unsupported_message())

        skip_conditions = getattr(cls, '__skip_if__', False)
        if skip_conditions:
            for condition in skip_conditions:
                if condition():
                    raise SkipTest("'%s' skipped by %s" % (
                        cls.__name__, condition.__name__))

        for db, op, spec in getattr(cls, '__excluded_on__', ()):
            message = "'%s' unsupported on DB %s version %s" % (
                cls.__name__, config.db.name,
                exclusions._server_version(config.db))
            exclusions.exclude(db, op, spec, message)
示例#31
0
    def _get_paths(self, query, raiseerr):
        """Translate ``self.key`` into a list of paths built by ``build_path``.

        ``self.key`` is either a dotted string of property names or one or
        more tokens (strings or ``PropComparator`` objects).  The starting
        mapper is ``self.mapper`` when configured (a class is resolved with
        ``class_mapper``), otherwise ``query.mapper``.

        Returns the list of paths, or ``[]`` when a property lookup yields
        ``None``.  Raises ``exceptions.ArgumentError`` when the configured
        mapper is not among the query's entities or when a token is neither
        a string nor a ``PropComparator``.
        """
        path = None
        l = []
        current_path = list(query._current_path)

        if self.mapper:
            # class_mapper is imported lazily into the module namespace
            global class_mapper
            if class_mapper is None:
                from sqlalchemy.orm import class_mapper
            mapper = self.mapper
            if isinstance(self.mapper, type):
                mapper = class_mapper(mapper)
            if mapper is not query.mapper and mapper not in [q.mapper for q in query._entities]:
                raise exceptions.ArgumentError("Can't find entity %s in Query.  Current list: %r" % (str(mapper), [str(m) for m in query._entities]))
        else:
            mapper = query.mapper
        if isinstance(self.key, basestring):
            tokens = self.key.split('.')
        else:
            tokens = util.to_list(self.key)

        for token in tokens:
            # Keep the original token intact and track the property name in
            # ``key``: overwriting ``token`` with the name made the
            # ``_of_type`` check below always see a plain string.
            if isinstance(token, basestring):
                prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=raiseerr)
                key = token
            elif isinstance(token, PropComparator):
                prop = token.property
                key = prop.key
            else:
                raise exceptions.ArgumentError("mapper option expects string key or list of attributes")

            # skip tokens already covered by the existing path
            if current_path and key == current_path[1]:
                current_path = current_path[2:]
                continue

            if prop is None:
                return []
            path = build_path(mapper, prop.key, path)
            l.append(path)
            # advance to the target mapper for the next token
            if getattr(token, '_of_type', None):
                mapper = token._of_type
            else:
                mapper = getattr(prop, 'mapper', None)
        return l
示例#32
0
    def batch_get(self, *idents):
        """Fetch several instances by primary key identifier.

        Returns a list aligned positionally with ``idents``; an entry stays
        ``None`` when no instance was found for that identifier.  Each
        identifier may be a scalar, a sequence, or an object providing
        ``__composite_values__``.

        Raises ``exc.InvalidRequestError`` when an identifier does not have
        as many values as the mapper has primary key columns.
        """
        mapper = self._only_full_mapper_zero('batch_get')
        # identity tuple -> positions in return_list still waiting for a row
        lazyload_idents = {}
        or_list = []
        return_list = [None] * len(idents)
        for idx, ident in enumerate(idents):
            if hasattr(ident, '__composite_values__'):
                ident = ident.__composite_values__()
            ident = to_list(ident)
            if len(ident) != len(mapper.primary_key):
                raise exc.InvalidRequestError(
                    "Incorrect number of values in identifier to formulate "
                    "primary key for query.batch_get(); "
                    "primary key columns are %s" %
                    ','.join("'%s'" % c for c in mapper.primary_key))

            key = mapper.identity_key_from_primary_key(ident)
            # NOTE(review): if ``with_lockmode`` is a method on this class
            # (as on SQLAlchemy's Query), ``self.with_lockmode is None`` is
            # never true and the identity-map shortcut below is never taken
            # -- confirm which attribute was intended.
            if not self._populate_existing and \
                    not mapper.always_refresh and \
                    self.with_lockmode is None:

                instance = loading.get_from_identity(
                    self.session, key, attributes.PASSIVE_OFF)
                if instance is not None:
                    # reject calls for id in indentity map but class
                    # mismatch.
                    if not issubclass(instance.__class__, mapper.class_):
                        instance = None
                    return_list[idx] = instance
                    continue

            lazyload_idents.setdefault(key[1], []).append(idx)
            and_list = [col == ide for col, ide in
                        zip(mapper.primary_key, ident)]
            or_list.append(sql.and_(*and_list))

        if or_list:
            # load the objects that were not cached into return_list
            for instance in self.filter(sql.or_(*or_list)):
                ident = mapper.primary_key_from_instance(instance)
                # NOTE(review): raises KeyError if the key read back from the
                # instance does not compare equal to ``key[1]`` recorded
                # above -- TODO confirm that cannot happen.
                for idx in lazyload_idents[tuple(ident)]:
                    return_list[idx] = instance

        return return_list
示例#33
0
def assign_mapper(ctx, class_, *args, **kwargs):
    """Map ``class_`` and attach query/session conveniences to it.

    The keyword arguments ``extension`` and ``validate`` are consumed here;
    everything else is forwarded to ``mapper()``.  ``ctx.mapper_extension``
    is always part of the extension chain.  If ``class_`` has no
    ``__init__`` of type ``types.MethodType``, a keyword-only constructor is
    installed that sets each keyword as an attribute (optionally checking it
    against the mapper's properties when ``validate`` is true).

    :returns: the mapper created for ``class_`` (also stored on
        ``class_.mapper``).
    """
    validate = kwargs.pop('validate', False)
    extension = kwargs.pop('extension', None)
    if extension is None:
        extension = ctx.mapper_extension
    else:
        extension = util.to_list(extension)
        extension.append(ctx.mapper_extension)

    if not isinstance(getattr(class_, '__init__'), types.MethodType):

        def __init__(self, **kwargs):
            for key, value in kwargs.items():
                if validate and not self.mapper.get_property(
                        key, resolve_synonyms=False, raiseerr=False):
                    raise exceptions.ArgumentError(
                        "Invalid __init__ argument: '%s'" % key)
                setattr(self, key, value)

        class_.__init__ = __init__

    class query(object):
        # Proxy that builds a fresh query for ``class_`` from the current
        # context on every attribute access or call.
        def __call__(self):
            return ctx.current.query(class_)

        def __getattr__(self, key):
            return getattr(ctx.current.query(class_), key)

    if not hasattr(class_, 'query'):
        class_.query = query()

    query_methods = (
        'get', 'filter', 'filter_by', 'select', 'select_by',
        'selectfirst', 'selectfirst_by', 'selectone', 'selectone_by',
        'get_by', 'join_to', 'join_via', 'count', 'count_by',
        'options', 'instances',
    )
    session_methods = ('refresh', 'expire', 'delete', 'expunge', 'update')
    for method_name in query_methods:
        _monkeypatch_query_method(method_name, ctx, class_)
    for method_name in session_methods:
        _monkeypatch_session_method(method_name, ctx, class_)

    m = mapper(class_, extension=extension, *args, **kwargs)
    class_.mapper = m
    return m
示例#34
0
    def _do_skips(self, cls):
        """Raise ``SkipTest`` if ``cls`` must not run against ``config.db``.

        The class-level markers are consulted in this order:
        ``__requires__``, ``__unsupported_on__``, ``__only_on__`` and
        ``__skip_if__``.  Entries of ``__excluded_on__`` are then handed to
        ``exclusions.exclude``.

        :param cls: the test class being examined.
        :raises SkipTest: as soon as one marker rules the class out.
        """
        from sqlalchemy.testing import config

        def unsupported():
            # Built lazily so the DB attributes are only read when skipping.
            return ("'%s' unsupported on DB implementation '%s' == %s" %
                    (cls.__name__, config.db.name,
                     config.db.dialect.server_version_info))

        for requirement in getattr(cls, '__requires__', ()):
            check = getattr(config.requirements, requirement)
            if not check.enabled:
                # Prefer the requirement's own reason when it has one.
                raise SkipTest(
                    check.reason if check.reason else unsupported())

        if cls.__unsupported_on__:
            spec = exclusions.db_spec(*cls.__unsupported_on__)
            if spec(config.db):
                raise SkipTest(unsupported())

        if getattr(cls, '__only_on__', None):
            spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
            if not spec(config.db):
                raise SkipTest(unsupported())

        if getattr(cls, '__skip_if__', False):
            for c in getattr(cls, '__skip_if__'):
                if c():
                    raise SkipTest("'%s' skipped by %s" %
                                   (cls.__name__, c.__name__))

        for db, op, spec in getattr(cls, '__excluded_on__', ()):
            exclusions.exclude(
                db, op, spec, "'%s' unsupported on DB %s version %s" %
                (cls.__name__, config.db.name,
                 exclusions._server_version(config.db)))
示例#35
0
def _register_attribute(strategy,
                        useobject,
                        compare_function=None,
                        typecallable=None,
                        copy_function=None,
                        mutable_scalars=False,
                        uselist=False,
                        callable_=None,
                        proxy_property=None,
                        active_history=False,
                        impl_class=None,
                        **kw):
    """Register the attribute impl for ``strategy.parent_property``.

    The extension chain is assembled from, in order: the property's own
    ``extension``, its backref's extension (when a backref is set), a
    ``Validator`` (when the parent mapper has a validator for the key) and
    a ``UOWEventHandler`` (when ``useobject`` is true).  The implementation
    is then registered on the parent mapper's class and on every
    non-concrete mapper from ``polymorphic_iterator()`` that has the
    property.

    ``proxy_property`` is accepted but not used by this function; the
    remaining keyword arguments are forwarded to
    ``attributes.register_attribute_impl``.
    """
    prop = strategy.parent_property

    # Work on a copy: util.to_list() can hand back the very list stored on
    # ``prop.extension``, and appending to it would grow the property's own
    # extension list on every registration.
    attribute_ext = list(util.to_list(prop.extension) or [])
    if getattr(prop, 'backref', None):
        attribute_ext.append(prop.backref.extension)

    if prop.key in prop.parent._validators:
        attribute_ext.append(
            mapperutil.Validator(prop.key, prop.parent._validators[prop.key]))

    if useobject:
        attribute_ext.append(sessionlib.UOWEventHandler(prop.key))

    for mapper in prop.parent.polymorphic_iterator():
        # Concrete sub-mappers are skipped; the parent itself always counts.
        if (mapper is prop.parent
                or not mapper.concrete) and mapper.has_property(prop.key):
            attributes.register_attribute_impl(
                mapper.class_,
                prop.key,
                parent_token=prop,
                mutable_scalars=mutable_scalars,
                uselist=uselist,
                copy_function=copy_function,
                compare_function=compare_function,
                useobject=useobject,
                extension=attribute_ext,
                trackparent=useobject,
                typecallable=typecallable,
                callable_=callable_,
                active_history=active_history,
                impl_class=impl_class,
                **kw)
示例#36
0
def _possible_configs_for_cls(cls, reasons=None):
    """Return the set of configs on which test class ``cls`` may run.

    Starts from ``config.Config.all_configs()`` and narrows it using the
    class markers ``__unsupported_on__``, ``__only_on__``,
    ``__only_on_config__``, ``__requires__`` and ``__prefer_requires__``.

    :param cls: the test class being examined.
    :param reasons: optional list; skip reasons reported by failing
        ``__requires__`` checks are appended to it.
    :returns: the remaining set of config objects (possibly empty).
    """
    candidates = set(config.Config.all_configs())

    if cls.__unsupported_on__:
        is_unsupported = exclusions.db_spec(*cls.__unsupported_on__)
        candidates -= {cfg for cfg in candidates if is_unsupported(cfg)}

    if getattr(cls, '__only_on__', None):
        is_allowed = exclusions.db_spec(*util.to_list(cls.__only_on__))
        candidates -= {cfg for cfg in candidates if not is_allowed(cfg)}

    if getattr(cls, '__only_on_config__', None):
        candidates.intersection_update([cls.__only_on_config__])

    if hasattr(cls, '__requires__'):
        requirements = config.requirements
        for cfg in list(candidates):
            for name in cls.__requires__:
                found = getattr(requirements, name) \
                    .matching_config_reasons(cfg)
                if not found:
                    continue
                # First failing requirement removes the config.
                candidates.remove(cfg)
                if reasons is not None:
                    reasons.extend(found)
                break

    if hasattr(cls, '__prefer_requires__'):
        requirements = config.requirements
        non_preferred = set()
        for cfg in list(candidates):
            for name in cls.__prefer_requires__:
                if not getattr(requirements, name).enabled_for_config(cfg):
                    non_preferred.add(cfg)
        # Preferences only narrow the set when something would remain.
        if candidates.difference(non_preferred):
            candidates.difference_update(non_preferred)

    return candidates
示例#37
0
文件: database.py 项目: zofuthan/july
    def __init__(self, engine_url, echo=False, pool_recycle=3600,
                 pool_size=10, session_extensions=None, session_options=None):
        """Set up the session, declarative base and engine for ``engine_url``.

        ``session_extensions`` may be a single extension or a list; a
        ``_SignallingSessionExtension`` is always appended.  URLs starting
        with ``'sqlite'`` get an engine created with ``echo`` only; any other
        URL also receives ``pool_recycle`` and ``pool_size``.
        """
        # create signals sender
        self.sender = str(uuid.uuid4())

        extensions = to_list(session_extensions, [])
        self.session_extensions = extensions + [_SignallingSessionExtension()]
        self.session = self.create_scoped_session(session_options)
        self.Model = self.make_declarative_base()

        engine_options = {'echo': echo}
        if not engine_url.startswith('sqlite'):
            engine_options['pool_recycle'] = pool_recycle
            engine_options['pool_size'] = pool_size
        self.engine = sqlalchemy.create_engine(engine_url, **engine_options)

        _include_sqlalchemy(self)
示例#38
0
文件: scoping.py 项目: pguenth/xsbs
    def mapper(self, *args, **kwargs):
        """Call ``sqlalchemy.orm.mapper`` with this object's extension added.

        Keyword arguments accepted by ``_ScopedExt`` are removed from
        ``kwargs`` and, when any are present, used to configure the
        extension via ``self.extension.configure``; otherwise
        ``self.extension`` itself is appended to the ``extension`` list.

        DEPRECATED.

        """

        from sqlalchemy.orm import mapper

        extension_args = {}
        for arg in get_cls_kwargs(_ScopedExt):
            if arg in kwargs:
                extension_args[arg] = kwargs.pop(arg)

        extension = to_list(kwargs.get('extension', []))
        kwargs['extension'] = extension

        if extension_args:
            scoped_extension = self.extension.configure(**extension_args)
        else:
            scoped_extension = self.extension
        extension.append(scoped_extension)

        return mapper(*args, **kwargs)
示例#39
0
def _do_skips(cls):
    """Skip ``cls`` or switch to a config on which it can run.

    Calls ``config.skip_test`` when a ``__skip_if__`` callable returns true
    or when ``_possible_configs_for_cls`` leaves no usable config.  When
    configs remain and the class declares ``__prefer_backends__``, the
    candidates are narrowed to the preferred ones if any match.  Finally,
    if the current config is not among the candidates, one of them is
    activated through ``_setup_config``.
    """
    reasons = []
    all_configs = _possible_configs_for_cls(cls, reasons)

    for skip_check in getattr(cls, '__skip_if__', False) or ():
        if skip_check():
            config.skip_test("'%s' skipped by %s" % (
                cls.__name__, skip_check.__name__))

    if not all_configs:
        if getattr(cls, '__backend__', False):
            msg = "'%s' unsupported for implementation '%s'" % (
                cls.__name__, cls.__only_on__)
        else:
            # Describe every known config as 'name(version)+driver'.
            described = [
                "'%s(%s)+%s'" % (
                    config_obj.db.name,
                    ".".join(
                        str(dig) for dig in
                        config_obj.db.dialect.server_version_info),
                    config_obj.db.driver,
                )
                for config_obj in config.Config.all_configs()
            ]
            msg = "'%s' unsupported on any DB implementation %s%s" % (
                cls.__name__, ", ".join(described), ", ".join(reasons))
        config.skip_test(msg)
    elif hasattr(cls, '__prefer_backends__'):
        is_preferred = exclusions.db_spec(
            *util.to_list(cls.__prefer_backends__))
        non_preferred = set()
        for config_obj in all_configs:
            if not is_preferred(config_obj):
                non_preferred.add(config_obj)
        # Only narrow when at least one preferred config would remain.
        if all_configs.difference(non_preferred):
            all_configs.difference_update(non_preferred)

    if config._current not in all_configs:
        _setup_config(all_configs.pop(), cls)
示例#40
0
    def test_copy_internals(self):
        """Check that ``_copy_internals()`` gives a clone its own children.

        For each fixture two equal cases are built.  The first element of
        one is cloned and ``_copy_internals()`` is called on the clone.  The
        child elements found on the clone are then deliberately broken so
        they can no longer compare as equal; afterwards the clone must stop
        comparing equal while the original element still does.
        """
        for fixture in self.fixtures:
            case_a = fixture()
            case_b = fixture()

            assert case_a[0].compare(case_b[0])

            clone = case_a[0]._clone()
            clone._copy_internals()

            assert clone.compare(case_b[0])

            stack = [clone]
            seen = {clone}
            found_elements = False
            while stack:
                obj = stack.pop(0)

                # NOTE(review): the comprehension reads ``clone.__dict__``,
                # not ``obj.__dict__``, so only attributes of the top-level
                # clone are ever inspected even though ``obj`` is popped
                # from the stack - confirm whether a deeper walk was intended.
                items = [
                    subelem
                    for key, elem in clone.__dict__.items()
                    if key != "_is_clone_of" and elem is not None
                    for subelem in util.to_list(elem)
                    if (
                        isinstance(subelem, (ColumnElement, ClauseList))
                        and subelem not in seen
                        and not isinstance(subelem, Immutable)
                        and subelem is not case_a[0]
                    )
                ]
                stack.extend(items)
                seen.update(items)

                if obj is not clone:
                    found_elements = True
                    # ensure the element will not compare as true
                    obj.compare = lambda other, **kw: False
                    obj.__visit_name__ = "dont_match"

            # Breaking the clone's children must affect the clone only.
            if found_elements:
                assert not clone.compare(case_b[0])
            assert case_a[0].compare(case_b[0])
示例#41
0
    def __init__(self, app=None, use_native_unicode=True, session_extensions=None, session_options=None):
        """Create the session and declarative base, optionally binding ``app``.

        ``session_extensions`` may be a single extension or a list; a
        ``_SignallingSessionExtension`` is always appended.  When ``app`` is
        given it is stored and passed to ``init_app``; otherwise ``self.app``
        is ``None``.
        """
        # create signals sender
        self.sender = str(uuid.uuid4())
        self.use_native_unicode = use_native_unicode

        extensions = to_list(session_extensions, [])
        self.session_extensions = extensions + [_SignallingSessionExtension()]
        self.session = self.create_scoped_session(session_options)
        self.Model = self.make_declarative_base()
        self._engine_lock = Lock()

        self.app = app
        if app is not None:
            self.init_app(app)

        _include_sqlalchemy(self)
        self.Query = BaseQuery
示例#42
0
 def __init__(self,
              manager,
              class_,
              key,
              uselist,
              callable_,
              typecallable,
              cascade=None,
              extension=None,
              **kwargs):
     """Initialise the property with a ``UOWEventHandler`` placed first.

     ``extension`` may be ``None``, a single extension or a list; the
     handler built from ``key``, ``class_`` and ``cascade`` is inserted at
     the front before delegating to the parent constructor.
     """
     handlers = util.to_list(extension or [])
     handlers.insert(0, UOWEventHandler(key, class_, cascade=cascade))
     super(UOWProperty, self).__init__(
         manager, key, uselist, callable_, typecallable,
         extension=handlers, **kwargs)
     self.class_ = class_
示例#43
0
    def __init__(self,
                 config=None,
                 use_native_unicode=True,
                 session_extensions=None,
                 session_options=None,
                 base_query_class=BaseQuery):
        """Store the configuration and build the session and model base.

        ``session_extensions`` may be a single extension or a list; a
        ``_SignallingSessionExtension`` is always appended.  A missing
        ``session_options`` is replaced by an empty dict before the scoped
        session is created.  ``base_query_class`` is exposed as
        ``self.Query``.
        """
        self.config = config
        self.use_native_unicode = use_native_unicode

        extensions = to_list(session_extensions, [])
        self.session_extensions = extensions + [_SignallingSessionExtension()]

        options = {} if session_options is None else session_options
        self.session = self.create_scoped_session(options)
        self.Model = self.make_declarative_base()
        self._engine_lock = Lock()

        _include_sqlalchemy(self)
        self.Query = base_query_class
示例#44
0
    def _create_prop(self,
                     class_,
                     key,
                     uselist,
                     callable_,
                     typecallable,
                     cascade=None,
                     extension=None,
                     **kwargs):
        """Create the property with a ``UOWEventHandler`` placed first.

        ``extension`` may be ``None``, a single extension or a list; the
        handler built from ``key``, ``class_`` and ``cascade`` is inserted
        at the front, and the parent implementation's result is returned.
        """
        handlers = util.to_list(extension or [])
        handlers.insert(0, UOWEventHandler(key, class_, cascade=cascade))

        parent = super(UOWAttributeManager, self)
        return parent._create_prop(
            class_, key, uselist, callable_, typecallable,
            extension=handlers, **kwargs)
示例#45
0
    def __init__(self,
                 app=None,
                 use_native_unicode=True,
                 session_extensions=None,
                 session_options=None):
        """Create the session and declarative base, optionally binding ``app``.

        ``session_extensions`` may be a single extension or a list; a
        ``_SignallingSessionExtension`` is always appended.  When ``app`` is
        given it is stored and passed to ``init_app``; otherwise ``self.app``
        is ``None``.
        """
        self.use_native_unicode = use_native_unicode

        extensions = to_list(session_extensions, [])
        self.session_extensions = extensions + [_SignallingSessionExtension()]
        self.session = self.create_scoped_session(session_options)
        self.Model = self.make_declarative_base()
        self._engine_lock = Lock()

        self.app = app
        if app is not None:
            self.init_app(app)

        _include_sqlalchemy(self)
        self.Query = BaseQuery
示例#46
0
    def visit_merge_tree(self, engine):
        """Render the engine clause text for ``engine``.

        The result starts with the engine name and its compiled parameters
        (``()`` when ``get_parameters()`` returns ``None``), followed by the
        optional PARTITION BY, ORDER BY, PRIMARY KEY, SAMPLE BY and SETTINGS
        parts, in that order.  Settings are emitted sorted by key.
        """
        param = engine.get_parameters()
        param = '()' if param is None else self._compile_param(to_list(param))

        pieces = ['{0}{1}\n'.format(engine.name, param)]

        if engine.partition_by:
            # Only the first partition expression is rendered.
            partition = engine.partition_by.get_expressions_or_columns()[0]
            pieces.append(
                ' PARTITION BY {0}\n'.format(self._compile_param(partition)))
        if engine.order_by:
            ordering = engine.order_by.get_expressions_or_columns()
            pieces.append(
                ' ORDER BY {0}\n'.format(self._compile_param(ordering)))
        if engine.primary_key:
            primary = engine.primary_key.get_expressions_or_columns()
            pieces.append(
                ' PRIMARY KEY {0}\n'.format(self._compile_param(primary)))
        if engine.sample_by:
            # Only the first sampling expression is rendered.
            sample = engine.sample_by.get_expressions_or_columns()[0]
            pieces.append(
                ' SAMPLE BY {0}\n'.format(self._compile_param(sample)))
        if engine.settings:
            rendered = (
                '{key}={value}'.format(key=key, value=value)
                for key, value in sorted(engine.settings.items())
            )
            pieces.append(' SETTINGS ' + ', '.join(rendered))

        return ''.join(pieces)
示例#47
0
def only_on(dbs, reason):
    """Build a decorator that runs a function only on the databases in ``dbs``.

    ``dbs`` is a single spec or a list of specs handed to ``db_spec``.  When
    the spec matches ``config.db`` the decorated function is called as
    usual; otherwise a message that includes ``reason`` is printed and
    ``True`` is returned without calling it.
    """
    carp = _should_carp_about_exclusion(reason)
    spec = db_spec(*util.to_list(dbs))

    def decorate(fn):
        fn_name = fn.__name__

        def maybe(*args, **kw):
            if spec(config.db):
                return fn(*args, **kw)
            else:
                msg = "'%s' unsupported on DB implementation '%s+%s': %s" % (
                    fn_name, config.db.name, config.db.driver, reason)
                print msg
                # When carping is enabled the message goes to stderr as well.
                if carp:
                    print >> sys.stderr, msg
                return True

        # Give the wrapper the decorated function's name.
        return function_named(maybe, fn_name)

    return decorate
示例#48
0
    def __should_skip_for(self, cls):
        """Return ``True`` if ``cls`` should be skipped on ``testing.db``.

        Checks, in order, ``__requires__``, ``__unsupported_on__``,
        ``__only_on__``, ``__skip_if__`` and ``__excluded_on__``.  Except for
        ``__requires__``, a message is printed before returning ``True``.
        Returns ``False`` when nothing rules the class out.
        """
        # NOTE(review): the messages below use ``cls.__class__.__name__``,
        # i.e. the name of ``cls``'s type, while ``test_suite`` is given
        # ``cls.__name__`` - confirm which name was intended.
        if hasattr(cls, '__requires__'):

            # Stand-in function wrapped by each requirement decorator; a
            # result other than 'ok' means the requirement was not met.
            def test_suite():
                return 'ok'

            test_suite.__name__ = cls.__name__
            for requirement in cls.__requires__:
                check = getattr(requires, requirement)
                if check(test_suite)() != 'ok':
                    # The requirement will perform messaging.
                    return True

        if cls.__unsupported_on__:
            spec = testing.db_spec(*cls.__unsupported_on__)
            if spec(testing.db):
                print "'%s' unsupported on DB implementation '%s'" % (
                    cls.__class__.__name__, testing.db.name)
                return True

        if getattr(cls, '__only_on__', None):
            spec = testing.db_spec(*util.to_list(cls.__only_on__))
            if not spec(testing.db):
                print "'%s' unsupported on DB implementation '%s'" % (
                    cls.__class__.__name__, testing.db.name)
                return True

        if getattr(cls, '__skip_if__', False):
            for c in getattr(cls, '__skip_if__'):
                if c():
                    print "'%s' skipped by %s" % (cls.__class__.__name__,
                                                  c.__name__)
                    return True

        for rule in getattr(cls, '__excluded_on__', ()):
            if testing._is_excluded(*rule):
                print "'%s' unsupported on DB %s version %s" % (
                    cls.__class__.__name__, testing.db.name, _server_version())
                return True
        return False
示例#49
0
    def _process_dependent_arguments(self):
        """Resolve and normalise the arguments this property was given.

        Callables are invoked and replaced by their result, join conditions
        are type-checked and de-annotated, ``order_by`` / ``_foreign_keys``
        / ``remote_side`` are converted to column expressions, and
        ``target`` / ``table`` are set from the target mapper.  A warning is
        issued when an inherited mapper already has a property with the same
        key, and a delete-orphan cascade is registered on the target's
        primary mapper.

        Raises ``sa_exc.ArgumentError`` for a delete-orphan cascade on a
        relationship whose parent and target class are the same.
        """

        # accept callables for other attributes which may require deferred initialization
        for attr in ('order_by', 'primaryjoin', 'secondaryjoin', 'secondary', '_foreign_keys', 'remote_side'):
            if util.callable(getattr(self, attr)):
                setattr(self, attr, getattr(self, attr)())

        # in the case that InstrumentedAttributes were used to construct
        # primaryjoin or secondaryjoin, remove the "_orm_adapt" annotation so these
        # interact with Query in the same way as the original Table-bound Column objects
        for attr in ('primaryjoin', 'secondaryjoin'):
            val = getattr(self, attr)
            if val is not None:
                util.assert_arg_type(val, sql.ClauseElement, attr)
                setattr(self, attr, _orm_deannotate(val))

        # a single order_by criterion is accepted as well as a list
        if self.order_by:
            self.order_by = [expression._literal_as_column(x) for x in util.to_list(self.order_by)]

        self._foreign_keys = util.column_set(expression._literal_as_column(x) for x in util.to_column_set(self._foreign_keys))
        self.remote_side = util.column_set(expression._literal_as_column(x) for x in util.to_column_set(self.remote_side))

        # warn when a mapper further up the inheritance chain already has a
        # property under the same key (skipped for concrete parents)
        if not self.parent.concrete:
            for inheriting in self.parent.iterate_to_root():
                if inheriting is not self.parent and inheriting._get_property(self.key, raiseerr=False):
                    util.warn(
                        ("Warning: relation '%s' on mapper '%s' supercedes "
                         "the same relation on inherited mapper '%s'; this "
                         "can cause dependency issues during flush") %
                        (self.key, self.parent, inheriting))

        # TODO: remove 'self.table'
        self.target = self.table = self.mapper.mapped_table

        if self.cascade.delete_orphan:
            # delete-orphan is rejected when parent and target class coincide
            if self.parent.class_ is self.mapper.class_:
                raise sa_exc.ArgumentError("In relationship '%s', can't establish 'delete-orphan' cascade "
                            "rule on a self-referential relationship.  "
                            "You probably want cascade='all', which includes delete cascading but not orphan detection." %(str(self)))
            self.mapper.primary_mapper().delete_orphans.append((self.key, self.parent.class_))
示例#50
0
    def __init__(self,
                 engine_url,
                 echo=False,
                 pool_recycle=7200,
                 pool_size=10,
                 session_extensions=None,
                 session_options=None,
                 poolclass=QueuePool):
        """Set up the session, declarative base and engine for ``engine_url``.

        ``session_extensions`` may be a single extension or a list; a
        ``_SignallingSessionExtension`` is always appended.  The engine is
        created with ``echo``, ``pool_recycle`` and ``poolclass``.
        """
        # create signals sender
        self.sender = str(uuid.uuid4())

        extensions = to_list(session_extensions, [])
        self.session_extensions = extensions + [_SignallingSessionExtension()]
        self.session = self.create_scoped_session(session_options)
        self.Model = self.make_declarative_base()

        # NOTE(review): ``pool_size`` is accepted but never forwarded to
        # create_engine() - confirm whether that is intentional.
        engine_options = {
            'echo': echo,
            'pool_recycle': pool_recycle,
            'poolclass': poolclass,
        }
        self.engine = sqlalchemy.create_engine(engine_url, **engine_options)

        _include_sqlalchemy(self)
示例#51
0
    def __init__(self,
                 app=None,
                 use_native_unicode=True,
                 session_extensions=None,
                 session_options=None):
        """Create the session and ``Model`` base, optionally binding ``app``.

        ``session_extensions`` may be a single extension or a list; a
        ``_SignallingSessionExtension`` is always appended.  ``Model`` is a
        declarative base whose ``query`` attribute is a ``_QueryProperty``
        bound to this object.  When ``app`` is given it is stored and passed
        to ``init_app``; otherwise ``self.app`` is ``None``.
        """
        self.use_native_unicode = use_native_unicode

        extensions = to_list(session_extensions, [])
        self.session_extensions = extensions + [_SignallingSessionExtension()]

        self.session = _create_scoped_session(self, session_options)

        model_base = declarative_base(cls=Model, name='Model')
        self.Model = model_base
        self.Model.query = _QueryProperty(self)

        self._engine_lock = Lock()

        self.app = app
        if app is not None:
            self.init_app(app)

        _include_sqlalchemy(self)
    def _filter_or_exclude(self, negate, kwargs):
        q = self
        negate_if = lambda expr: expr if not negate else ~expr
        column = None

        for arg, value in kwargs.iteritems():
            for token in arg.split('__'):
                if column is None:
                    column = _entity_descriptor(q._joinpoint_zero(), token)
                    if column.impl.uses_objects:
                        q = q.join(column)
                        column = None
                elif token in self._underscore_operators:
                    op = self._underscore_operators[token]
                    q = q.filter(negate_if(op(column, *to_list(value))))
                    column = None
                else:
                    raise ValueError('No idea what to do with %r' % token)
            if column is not None:
                q = q.filter(negate_if(column == value))
                column = None
            q = q.reset_joinpoint()
        return q
    def visit_merge_tree(self, engine):
        """Render the engine clause text for ``engine``.

        The result starts with the engine name and its compiled parameters
        (``()`` when ``get_parameters()`` returns ``None``), followed by the
        optional PARTITION BY, ORDER BY, PRIMARY KEY, SAMPLE BY, TTL and
        SETTINGS parts, in that order.  Settings are emitted sorted by key.
        """
        param = engine.get_parameters()
        param = '()' if param is None else self._compile_param(to_list(param))

        pieces = ['{0}{1}\n'.format(engine.name, param)]

        if engine.partition_by:
            partition = engine.partition_by.get_expressions_or_columns()
            pieces.append(' PARTITION BY {0}\n'.format(
                self._compile_param(partition, opt_list=True)))
        if engine.order_by:
            ordering = engine.order_by.get_expressions_or_columns()
            pieces.append(' ORDER BY {0}\n'.format(
                self._compile_param(ordering, opt_list=True)))
        if engine.primary_key:
            primary = engine.primary_key.get_expressions_or_columns()
            pieces.append(' PRIMARY KEY {0}\n'.format(
                self._compile_param(primary, opt_list=True)))
        if engine.sample_by:
            # Only the first sampling expression is rendered.
            sample = engine.sample_by.get_expressions_or_columns()[0]
            pieces.append(
                ' SAMPLE BY {0}\n'.format(self._compile_param(sample)))
        if engine.ttl:
            process = self.sql_compiler.process
            ttl_sql = ',\n     '.join(
                process(i, include_table=False, literal_binds=True)
                for i in engine.ttl.get_expressions_or_columns())
            pieces.append(' TTL {0}\n'.format(ttl_sql))
        if engine.settings:
            rendered = (
                '{key}={value}'.format(key=key, value=value)
                for key, value in sorted(engine.settings.items())
            )
            pieces.append(' SETTINGS ' + ', '.join(rendered))

        return ''.join(pieces)
示例#54
0
    def setup(self, key, statement, eagertable=None, **options):
        """Add this loader's LEFT OUTER JOIN to ``statement``.

        The join is chained onto ``statement._outerjoin`` when the statement
        already carries one, otherwise onto the local parent's mapped table.
        Ordering is applied from the default ordering of the joined table
        (only when ``self.order_by`` is ``False``) and from
        ``self.eager_order_by``; failing the latter, an existing
        ``order_by_clause`` on the statement is aliasized in place.  Finally
        every property of the target mapper is set up against the eager
        target.
        """
        # initialize the "eager" chain of EagerLoader objects
        # this can't quite be done in the do_init_mapper() step
        self._create_eager_chain()

        towrap = (statement._outerjoin if hasattr(statement, '_outerjoin')
                  else self.localparent.mapped_table)

        if self.secondaryjoin is None:
            statement._outerjoin = towrap.outerjoin(
                self.eagertarget, self.eagerprimary)
            if self.order_by is False and \
                    self.eagertarget.default_order_by() is not None:
                statement.order_by(*self.eagertarget.default_order_by())
        else:
            # Association table in between: join through it to the target.
            through_secondary = sql.outerjoin(
                towrap, self.eagersecondary, self.eagerprimary)
            statement._outerjoin = through_secondary.outerjoin(
                self.eagertarget, self.eagersecondaryjoin)
            if self.order_by is False and \
                    self.secondary.default_order_by() is not None:
                statement.order_by(*self.eagersecondary.default_order_by())

        if self.eager_order_by:
            statement.order_by(*util.to_list(self.eager_order_by))
        elif getattr(statement, 'order_by_clause', None):
            self._aliasize_orderby(statement.order_by_clause, False)

        statement.append_from(statement._outerjoin)
        for prop_key, prop in self.mapper.props.iteritems():
            prop.setup(prop_key, statement, eagertable=self.eagertarget)
def _possible_configs_for_cls(cls, reasons=None, sparse=False):
    """Return the configs on which test class ``cls`` may run.

    Starts from ``config.Config.all_configs()`` and narrows it using the
    class markers ``__unsupported_on__``, ``__only_on__``,
    ``__only_on_config__``, ``__requires__`` and ``__prefer_requires__``.

    :param cls: the test class being examined.
    :param reasons: optional list; skip reasons reported by failing
        ``__requires__`` checks are appended to it.
    :param sparse: when true, keep a single config per ``db.name`` and
        return the values view of that mapping instead of a set.
    :returns: the remaining config objects (possibly none).
    """
    candidates = set(config.Config.all_configs())

    if cls.__unsupported_on__:
        is_unsupported = exclusions.db_spec(*cls.__unsupported_on__)
        candidates -= {cfg for cfg in candidates if is_unsupported(cfg)}

    if getattr(cls, "__only_on__", None):
        is_allowed = exclusions.db_spec(*util.to_list(cls.__only_on__))
        candidates -= {cfg for cfg in candidates if not is_allowed(cfg)}

    if getattr(cls, "__only_on_config__", None):
        candidates.intersection_update([cls.__only_on_config__])

    if hasattr(cls, "__requires__"):
        requirements = config.requirements
        for cfg in list(candidates):
            for name in cls.__requires__:
                found = getattr(requirements, name) \
                    .matching_config_reasons(cfg)
                if not found:
                    continue
                # First failing requirement removes the config.
                candidates.remove(cfg)
                if reasons is not None:
                    reasons.extend(found)
                break

    if hasattr(cls, "__prefer_requires__"):
        requirements = config.requirements
        non_preferred = set()
        for cfg in list(candidates):
            for name in cls.__prefer_requires__:
                if not getattr(requirements, name).enabled_for_config(cfg):
                    non_preferred.add(cfg)
        # Preferences only narrow the set when something would remain.
        if candidates.difference(non_preferred):
            candidates.difference_update(non_preferred)

    if sparse:
        # pick only one config from each base dialect
        # sorted so we get the same backend each time selecting the highest
        # server version info.
        def sort_key(cfg):
            return (
                cfg.db.name,
                cfg.db.driver,
                cfg.db.dialect.server_version_info,
            )

        per_dialect = {}
        for cfg in reversed(sorted(candidates, key=sort_key)):
            # The first config seen for a dialect wins.
            per_dialect.setdefault(cfg.db.name, cfg)
        return per_dialect.values()

    return candidates
示例#56
0
def _do_skips(cls):
    """Narrow the known database configs down to those ``cls`` may run on.

    Starts from every config returned by ``config.Config.all_configs()`` and
    filters it using the class-level markers ``__requires__``,
    ``__prefer_requires__``, ``__unsupported_on__``, ``__only_on__``,
    ``__skip_if__``, ``__excluded_on__`` and ``__prefer_backends__``.

    Raises ``SkipTest`` when a ``__skip_if__`` callable returns true or when
    no config survives the filtering.  Otherwise, if ``config._current`` is
    not among the survivors, one survivor is handed to ``_setup_config``.
    """
    candidates = set(config.Config.all_configs())
    skip_reasons = []

    def drop_where(should_drop):
        # Walk a snapshot so the live set can be mutated in place.
        for cfg in list(candidates):
            if should_drop(cfg):
                candidates.remove(cfg)

    def demote(less_preferred):
        # Preferences are soft: apply them only if something would remain.
        if candidates.difference(less_preferred):
            candidates.difference_update(less_preferred)

    # Hard requirements: the first matching predicate excludes the config.
    if hasattr(cls, '__requires__'):
        reqs = config.requirements
        for cfg in list(candidates):
            for req_name in cls.__requires__:
                rule = getattr(reqs, req_name)
                if not rule.predicate(cfg):
                    continue
                candidates.remove(cfg)
                if rule.reason:
                    skip_reasons.append(rule.reason)
                break

    # Soft requirements: every predicate is evaluated for every config.
    if hasattr(cls, '__prefer_requires__'):
        less_preferred = set()
        reqs = config.requirements
        for cfg in list(candidates):
            verdicts = [
                getattr(reqs, req_name).predicate(cfg)
                for req_name in cls.__prefer_requires__
            ]
            if any(verdicts):
                less_preferred.add(cfg)
        demote(less_preferred)

    if cls.__unsupported_on__:
        drop_where(exclusions.db_spec(*cls.__unsupported_on__))

    if getattr(cls, '__only_on__', None):
        only_on = exclusions.db_spec(*util.to_list(cls.__only_on__))
        drop_where(lambda cfg: not only_on(cfg))

    if getattr(cls, '__skip_if__', False):
        for guard in cls.__skip_if__:
            if not guard():
                continue
            raise SkipTest(
                "'%s' skipped by %s" % (cls.__name__, guard.__name__)
            )

    for db_name, operator, version in getattr(cls, '__excluded_on__', ()):
        # The predicate is rebuilt per config, exactly as before.
        drop_where(
            lambda cfg: exclusions.skip_if(
                exclusions.SpecPredicate(db_name, operator, version)
            ).predicate(cfg)
        )

    if not candidates:
        described = ", ".join(
            "'%s' = %s" % (cfg.db.name, cfg.db.dialect.server_version_info)
            for cfg in config.Config.all_configs()
        )
        raise SkipTest(
            "'%s' unsupported on DB implementation %s%s"
            % (cls.__name__, described, ", ".join(skip_reasons))
        )

    if hasattr(cls, '__prefer_backends__'):
        less_preferred = set()
        backend_spec = exclusions.db_spec(
            *util.to_list(cls.__prefer_backends__))
        for cfg in candidates:
            if not backend_spec(cfg):
                less_preferred.add(cfg)
        demote(less_preferred)

    if config._current not in candidates:
        _setup_config(candidates.pop(), cls)
示例#57
0
    def _get_paths(self, query, raiseerr):
        """Resolve this option's key(s) into attribute paths for ``query``.

        ``self.key`` is normalised with ``util.to_list``; each item is either
        a string (split on ``'.'`` into attribute names) or a
        ``PropComparator``.  The resulting tokens are walked left to right and
        one path per token is built with ``build_path``.

        Returns a 2-tuple ``(paths, mappers)``.  ``mappers`` gets one entry
        per processed token: the current mapper for string tokens,
        ``prop.parent`` for comparator tokens.  ``([], [])`` is returned when
        ``_find_entity`` yields nothing, when a property resolves to ``None``,
        or when part of ``query._current_path`` is left unconsumed.

        Raises ``sa_exc.ArgumentError`` for a token that is neither a string
        nor a ``PropComparator``.
        """
        path = None
        entity = None
        paths = []
        mappers = []

        # _current_path implies we're in a secondary load with an
        # existing path

        current_path = list(query._current_path)

        # Flatten every key into a single sequence of tokens.
        tokens = []
        for key in util.to_list(self.key):
            if isinstance(key, basestring):
                tokens.extend(key.split('.'))
            else:
                tokens.append(key)

        for token in tokens:
            if isinstance(token, basestring):
                if not entity:
                    # Leading tokens that match the existing path are
                    # consumed (two path slots each) instead of re-resolved.
                    if current_path and current_path[1] == token:
                        current_path = current_path[2:]
                        continue
                    entity = query._entity_zero()
                    path_element = entity.path_entity
                    mapper = entity.mapper
                mappers.append(mapper)
                prop = mapper.get_property(token,
                        resolve_synonyms=True, raiseerr=raiseerr)
            elif isinstance(token, PropComparator):
                prop = token.property
                if not entity:
                    if current_path and current_path[0:2] == [
                            token.parententity, prop.key]:
                        current_path = current_path[2:]
                        continue
                    entity = self._find_entity(query,
                            token.parententity, raiseerr)
                    if not entity:
                        return [], []
                    path_element = entity.path_entity
                    mapper = entity.mapper
                mappers.append(prop.parent)
            else:
                raise sa_exc.ArgumentError('mapper option expects '
                        'string key or list of attributes')
            if prop is None:
                return [], []
            path = build_path(path_element, prop.key, path)
            paths.append(path)

            # The next token is resolved against the token's ``_of_type``
            # when it is set, otherwise against the property's ``mapper``
            # attribute (``None`` if the property has none).
            of_type = getattr(token, '_of_type', None)
            if of_type:
                path_element = mapper = of_type
            else:
                path_element = mapper = getattr(prop, 'mapper', None)

        # Anything still left in current_path was never matched by a token.
        if current_path:
            return [], []
        return paths, mappers
示例#58
0
def signalling_mapper(*args, **kwargs):
    """Replacement for mapper that injects some extra extensions.

    The ``extension`` keyword (if any) is normalised with ``to_list``, a
    ``_SignalTrackingMapperExtension`` is appended, and the call is forwarded
    to ``sqlalchemy.orm.mapper`` with the combined list.
    """
    # Copy before appending: to_list() is expected to hand back a list or
    # tuple argument unchanged, so appending to its result directly would
    # either mutate the caller's own list or fail on a tuple.
    extensions = list(to_list(kwargs.pop('extension', None), []))
    extensions.append(_SignalTrackingMapperExtension())
    kwargs['extension'] = extensions
    return sqlalchemy.orm.mapper(*args, **kwargs)
示例#59
0
 def __init__(self, *args, **kwargs):
     """Initialise the mapper with signal tracking added to its extensions.

     The ``extension`` keyword (if any) is normalised with ``to_list``, a
     ``_SignalTrackingMapperExtension`` is appended, and everything is
     passed on to ``Mapper.__init__``.
     """
     # Copy before appending: to_list() is expected to hand back a list or
     # tuple argument unchanged, so appending to its result directly would
     # either mutate the caller's own list or fail on a tuple.
     extensions = list(to_list(kwargs.pop('extension', None), []))
     extensions.append(_SignalTrackingMapperExtension())
     kwargs['extension'] = extensions
     Mapper.__init__(self, *args, **kwargs)
示例#60
0
 def __init__(self, nest_on=None, hash_key=None, **kwargs):
     """Initialise legacy session state on top of the parent session.

     Remaining ``**kwargs`` are forwarded to the parent ``__init__``;
     ``nest_on`` is normalised with ``util.to_list`` before being stored.
     """
     # NOTE(review): hash_key is accepted but not referenced in the lines
     # visible here -- presumably kept for backward compatibility; confirm.
     super(LegacySession, self).__init__(**kwargs)
     self.parent_uow = None
     self.begin_count = 0
     self.nest_on = util.to_list(nest_on)
     # Double-underscore attribute: subject to Python's private name
     # mangling when this method lives inside a class body.
     self.__pushed_count = 0