Example #1
def async_setup(hass: HomeAssistant) -> None:
    """Set up the history hooks."""
    hass.data[STATISTICS_BAKERY] = baked.bakery()
    hass.data[STATISTICS_META_BAKERY] = baked.bakery()
    hass.data[STATISTICS_SHORT_TERM_BAKERY] = baked.bakery()

    def entity_id_changed(event: Event) -> None:
        """Handle entity_id changed."""
        old_entity_id = event.data["old_entity_id"]
        entity_id = event.data["entity_id"]
        with session_scope(hass=hass) as session:
            session.query(StatisticsMeta).filter(
                (StatisticsMeta.statistic_id == old_entity_id)
                & (StatisticsMeta.source == DOMAIN)
            ).update({StatisticsMeta.statistic_id: entity_id})

    @callback
    def entity_registry_changed_filter(event: Event) -> bool:
        """Handle entity_id changed filter."""
        return event.data["action"] == "update" and "old_entity_id" in event.data

    if hass.is_running:
        hass.bus.async_listen(
            entity_registry.EVENT_ENTITY_REGISTRY_UPDATED,
            entity_id_changed,
            event_filter=entity_registry_changed_filter,
        )
Example #2
 def __init__(self, database, logLevel='INFO'):
     self.db = database
     self.logger = Logger(__name__, level=logLevel)
     super(BaseListener, self).__init__()
     self.session = database.session
     self.bakery = baked.bakery()
     self.bake_common_queries()
Example #3
    def test_subqueryload_post_context(self):
        User = self.classes.User
        Address = self.classes.Address

        assert_result = [
            User(
                id=7, addresses=[Address(id=1, email_address="*****@*****.**")]
            )
        ]

        self.bakery = baked.bakery(size=3)

        bq = self.bakery(lambda s: s.query(User))

        bq += lambda q: q.options(subqueryload(User.addresses))
        bq += lambda q: q.order_by(User.id)
        bq += lambda q: q.filter(User.name == bindparam("name"))
        sess = Session()

        def set_params(q):
            return q.params(name="jack")

        # test that the changes we make using with_post_criteria()
        # are also applied to the subqueryload query.
        def go():
            result = bq(sess).with_post_criteria(set_params).all()
            eq_(assert_result, result)

        self.assert_sql_count(testing.db, go, 2)
Example #4
 def get_tender_by_id(self, tender_id: int):
     with self.session_scope() as session:
         bakery = baked.bakery()
         bq = bakery(lambda s: s.query(Tender))
         bq += lambda q: q.filter(Tender.id == bindparam("tender_id"))
         result = bq(session).params(tender_id=tender_id).all()
     return result
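
A note on this snippet and the next one: creating the bakery inside the
function throws the compiled-query cache away on every call, so baking buys
nothing. A minimal sketch of the cache-friendly arrangement, assuming the same
Tender model and a plain Session passed in by the caller:

from sqlalchemy import bindparam
from sqlalchemy.ext import baked

BAKERY = baked.bakery()  # module scope, so the cache survives across calls

def get_tender_by_id(session, tender_id):
    bq = BAKERY(lambda s: s.query(Tender))
    bq += lambda q: q.filter(Tender.id == bindparam("tender_id"))
    return bq(session).params(tender_id=tender_id).all()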
Example #5
 def get_tenders(self):
     with self.session_scope() as session:
         bakery = baked.bakery()
         bq = bakery(lambda s: s.query(Tender))
         bq += lambda q: q.limit(10)
         result = bq(session).all()
     return result
Example #6
def cats(data, session):

    if data:
        try:
            attr, order, offset, limit = validate_cats_params(data)
            if order and not attr:
                raise ValueError
        except ValueError:
            return (
                400,
                {"status": "Bad request.", "exception": ""},
                {"Content-type": "application/json"},
            )

        bakery = baked.bakery()
        cat = bakery(lambda session: session.query(Cats))

        if attr:
            sort_func = desc if order == "desc" else asc
            cat += lambda c: c.order_by(sort_func(getattr(Cats, attr)))

        if offset:
            cat += lambda c: c.offset(offset)
        if limit:
            cat += lambda c: c.limit(limit)

        cats = cat(session).all()
    else:
        cats = session.query(Cats)

    cats_list = [c.to_dict() for c in cats]

    return (200, cats_list, {"Content-type": "application/json"})
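
Because baked lambdas are cached by their code location rather than by the
values they close over, closing over `attr`, `offset` and `limit` as above is
only safe here because the bakery itself is rebuilt on every call. A sketch of
the cache-safe variant for the pagination part, following the bindparam
pattern from the SQLAlchemy baked documentation (the Cats model is assumed
from the snippet above):

from sqlalchemy import bindparam
from sqlalchemy.ext import baked

bakery = baked.bakery()

def page_of_cats(session, offset, limit):
    bq = bakery(lambda s: s.query(Cats))
    bq += lambda q: q.offset(bindparam("offset")).limit(bindparam("limit"))
    return bq(session).params(offset=offset, limit=limit).all()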
Example #7
def test_advanced_examples():
    # Specify exactly what to return by accessing the underlying query.
    print_query(
        "session(User)['jack'].addresses._query.add_columns(User.id, Address.id)",
        [(1, 1), (1, 2)], globals())

    # If `QueryMakerSession` isn't used, the session can be provided at the end of the query. However, this means the ``.q`` property won't be useful (since it has no assigned session).
    print_query("User['jack'].to_query(session)", [jack], globals())

    # If the `QueryMakerDeclarativeMeta` metaclass wasn't used, this performs the equivalent of ``User['jack']`` manually.
    print_query("QueryMaker(User)['jack'].to_query(session)", [jack],
                globals())

    # Add to an existing query: first, find the User named jack.
    q = session.query().select_from(User).filter(User.name == 'jack')
    # Then ask for the Address for [email protected].
    print_query("q.query_maker().addresses['*****@*****.**']",
                [jack.addresses[0]], globals(), locals())
    # Do the same manually (without relying on the `QueryMakerQuery` ``query_maker`` method).
    print_query("QueryMaker(query=q).addresses['*****@*****.**']",
                [jack.addresses[0]], globals(), locals())

    # `Baked queries <http://docs.sqlalchemy.org/en/latest/orm/extensions/baked.html>`_ are supported.
    bakery = baked.bakery()
    baked_query = bakery(lambda session: session(User))
    baked_query += lambda query: query[User.name == bindparam('username')]
    # The last item in the query must end with a ``.q``. Note that this doesn't print nicely. Using ``.to_query()`` instead fixes this.
    baked_query += lambda query: query.q.order_by(User.id).q
    print_query(
        "baked_query(session).params(username='******', email='*****@*****.**')",
        [jack], globals(), locals())
Example #8
    def test_subqueryload_post_context_w_cancelling_event(
        self, before_compile_nobake_fixture
    ):
        User = self.classes.User
        Address = self.classes.Address

        assert_result = [
            User(
                id=7, addresses=[Address(id=1, email_address="*****@*****.**")]
            )
        ]

        self.bakery = baked.bakery(size=3)

        bq = self.bakery(lambda s: s.query(User))

        bq += lambda q: q.options(subqueryload(User.addresses))
        bq += lambda q: q.order_by(User.id)
        bq += lambda q: q.filter(User.name == bindparam("name"))
        sess = fixture_session()

        def set_params(q):
            return q.params(name="jack")

        # test that the changes we make using with_post_criteria()
        # are also applied to the subqueryload query.
        def go():
            result = bq(sess).with_post_criteria(set_params).all()
            eq_(assert_result, result)

        self.assert_sql_count(testing.db, go, 2)
Example #9
def test_baked_query(n):
    """test a baked query of the full entity."""
    bakery = baked.bakery()
    s = Session(bind=engine)
    for id_ in random.sample(ids, n):
        q = bakery(lambda s: s.query(Customer))
        q += lambda q: q.filter(Customer.id == bindparam('id'))
        q(s).params(id=id_).one()
Example #10
 def get_last_modified_date(self):
     with self.session_scope() as session:
         bakery = baked.bakery()
         max_date_modified = func.max(Tender.date_modified).label('max_date_modified')
         bq = bakery(lambda s: s.query(max_date_modified))
         result = bq(session).one()
     return result
Example #11
def test_baked_query(n):
    """test a baked query of the full entity."""
    bakery = baked.bakery()
    s = Session(bind=engine)
    for id_ in random.sample(ids, n):
        q = bakery(lambda s: s.query(Customer))
        q += lambda q: q.filter(Customer.id == bindparam("id"))
        q(s).params(id=id_).one()
Example #12
def test_baked_query_cols_only(n):
    """test a baked query of only the entity columns."""
    bakery = baked.bakery()
    s = Session(bind=engine)
    for id_ in random.sample(ids, n):
        q = bakery(lambda s: s.query(Customer.id, Customer.name,
                                     Customer.description))
        q += lambda q: q.filter(Customer.id == bindparam("id"))
        q(s).params(id=id_).one()
Example #13
def test_baked_query_cols_only(n):
    """test a baked query of only the entity columns."""
    bakery = baked.bakery()
    s = Session(bind=engine)
    for id_ in random.sample(ids, n):
        q = bakery(
            lambda s: s.query(
                Customer.id, Customer.name, Customer.description))
        q += lambda q: q.filter(Customer.id == bindparam('id'))
        q(s).params(id=id_).one()
Example #14
 def __init__(self):
     self.session = create_session()
     self.clear_tables()
     self.os_names_db_objs = list()
     self.add_default_values()
     self.add_triggers()
     self.add_views()
     # inspector = reflection.Inspector.from_engine(get_engine())
     # print("Tables:", inspector.get_table_names())
     # print("Views:", inspector.get_view_names())
     self.baked_queries_map = self.bake_baked_queries()
     self.bakery = baked.bakery()
Example #15
    def __init__(self,
                 db_endpoint: str,
                 fs_root: str,
                 max_cache_mem: int,
                 ttl: int,
                 engine_kwargs: dict = None):
        """
        Initialize a new instance of SQLAlchemyModelRepository.

        :param db_endpoint: SQLAlchemy connection string.
        :param fs_root: Root directory where to store the models.
        :param max_cache_mem: Maximum memory size to use for model cache (in bytes).
        :param ttl: Time-to-live for each model in the cache (in seconds).
        :param engine_kwargs: Passed directly to SQLAlchemy's `create_engine()`.
        """
        self.fs_root = fs_root
        # A version of db_endpoint that never contains the password is needed for logging
        db_endpoint_components = urlparse(db_endpoint)
        if db_endpoint_components.password is not None:
            password, netloc = db_endpoint_components.password, db_endpoint_components.netloc
            password_index = netloc.rindex(password)
            safe_netloc = "%s%s%s" % (
                db_endpoint_components.netloc[:password_index], "<PASSWORD>",
                db_endpoint_components.netloc[password_index + len(password):])
            safe_db_endpoint_components = list(db_endpoint_components)
            safe_db_endpoint_components[1] = safe_netloc
            self._safe_db_endpoint = urlunparse(safe_db_endpoint_components)
        else:
            self._safe_db_endpoint = db_endpoint
        must_initialize = not database_exists(db_endpoint)
        if must_initialize:
            self._log.debug("%s does not exist, creating",
                            self._safe_db_endpoint)
            create_database(db_endpoint)
            self._log.warning("created a new database at %s",
                              self._safe_db_endpoint)
        self._engine = create_engine(
            db_endpoint,
            **(engine_kwargs if engine_kwargs is not None else {}))
        must_initialize |= not self._engine.has_table(Model.__tablename__)
        if must_initialize:
            Model.metadata.create_all(self._engine)
        self._sessionmaker = ContextSessionMaker(
            sessionmaker(bind=self._engine))
        bakery = baked.bakery()
        self._get_query = bakery(lambda session: session.query(Model))
        self._get_query += lambda query: query.filter(
            and_(Model.analyzer == bindparam("analyzer"),
                 Model.repository == bindparam("repository")))
        self._cache = cachetools.TTLCache(maxsize=max_cache_mem,
                                          ttl=ttl,
                                          getsizeof=asizeof)
        self._cache_lock = threading.Lock()
Example #16
    def __init__(self):
        self.session = create_session()
        self.read_func_by_format = {"info": self.read_from_svn_info,
                                    "text": self.read_from_text,
                                    "props": self.read_props,
                                    "file-sizes": self.read_file_sizes
                                    }

        self.write_func_by_format = {"text": self.write_as_text,}
        self.files_read_list = list()
        self.files_written_list = list()
        self.comments = list()
        self.baked_queries_map = self.bake_baked_queries()
        self.bakery = baked.bakery()
Example #17
async def async_setup(hass, config):
    """Set up the history hooks."""
    conf = config.get(DOMAIN, {})

    filters = sqlalchemy_filter_from_include_exclude_conf(conf)

    hass.data[HISTORY_BAKERY] = baked.bakery()

    use_include_order = conf.get(CONF_ORDER)

    hass.http.register_view(HistoryPeriodView(filters, use_include_order))
    hass.components.frontend.async_register_built_in_panel(
        "history", "history", "hass:poll-box")

    return True
Example #18
def _default_recorder(hass):
    """Return a recorder with reasonable defaults."""
    return Recorder(
        hass,
        auto_purge=True,
        auto_repack=True,
        keep_days=7,
        commit_interval=1,
        uri="sqlite://",
        db_max_retries=10,
        db_retry_wait=3,
        entity_filter=CONFIG_SCHEMA({DOMAIN: {}}),
        exclude_t=[],
        exclude_attributes_by_domain={},
        bakery=baked.bakery(),
    )
Example #19
    def test_subquery_eagerloading(self):
        User = self.classes.User
        Address = self.classes.Address
        Order = self.classes.Order

        # Override the default bakery for one with a smaller size. This used to
        # trigger a bug when unbaking subqueries.
        self.bakery = baked.bakery(size=3)
        base_bq = self.bakery(lambda s: s.query(User))

        base_bq += lambda q: q.options(subqueryload(User.addresses),
                                       subqueryload(User.orders))
        base_bq += lambda q: q.order_by(User.id)

        assert_result = [
            User(id=7,
                 addresses=[Address(id=1, email_address='*****@*****.**')],
                 orders=[Order(id=1), Order(id=3),
                         Order(id=5)]),
            User(id=8,
                 addresses=[
                     Address(id=2, email_address='*****@*****.**'),
                     Address(id=3, email_address='*****@*****.**'),
                     Address(id=4, email_address='*****@*****.**'),
                 ]),
            User(id=9,
                 addresses=[Address(id=5)],
                 orders=[Order(id=2), Order(id=4)]),
            User(id=10, addresses=[])
        ]

        for i in range(4):
            for cond1, cond2 in itertools.product(*[(False, True)
                                                    for j in range(2)]):
                bq = base_bq._clone()

                sess = Session()

                if cond1:
                    bq += lambda q: q.filter(User.name == 'jack')
                else:
                    bq += lambda q: q.filter(User.name.like('%ed%'))

                if cond2:
                    ct = func.count(Address.id).label('count')
                    subq = sess.query(
                        ct,
                        Address.user_id).group_by(Address.user_id).\
                        having(ct > 2).subquery()

                    bq += lambda q: q.join(subq)

                if cond2:
                    if cond1:

                        def go():
                            result = bq(sess).all()
                            eq_([], result)

                        self.assert_sql_count(testing.db, go, 1)
                    else:

                        def go():
                            result = bq(sess).all()
                            eq_(assert_result[1:2], result)

                        self.assert_sql_count(testing.db, go, 3)
                else:
                    if cond1:

                        def go():
                            result = bq(sess).all()
                            eq_(assert_result[0:1], result)

                        self.assert_sql_count(testing.db, go, 3)
                    else:

                        def go():
                            result = bq(sess).all()
                            eq_(assert_result[1:3], result)

                        self.assert_sql_count(testing.db, go, 3)

                sess.close()
Example #20
 def getbakery(cls):
     if cls._bakery is None:
         cls._bakery = staticmethod(bakery())
     return cls._bakery
Example #21
from sqlalchemy import func, bindparam, select
from sqlalchemy.orm import make_transient
from sqlalchemy.ext.baked import bakery


from glycresoft_sqlalchemy.data_model import (
    Decon2LSPeak, Decon2LSPeakGroup, Decon2LSPeakToPeakGroupMap,
    PipelineModule, MSScan, ScanBase)

from glycresoft_sqlalchemy.utils.common_math import ppm_error

from .common import (
    expanding_window, expected_a_peak_regression, centroid_scan_error_regression)

TDecon2LSPeakGroup = Decon2LSPeakGroup.__table__

query_oven = bakery()
get_group_id_by_mass_window = query_oven(lambda session: session.query(Decon2LSPeakGroup.id))
get_group_id_by_mass_window += lambda q: q.filter(Decon2LSPeakGroup.weighted_monoisotopic_mass.between(
            bindparam("lower"), bindparam("upper")))
get_group_id_by_mass_window += lambda q: q.filter(Decon2LSPeakGroup.sample_run_id == bindparam("sample_run_id"))


class Decon2LSPeakGrouper(PipelineModule):
    '''
    Pipeline Step to post-process Decon2LSPeaks, clustering them by mass
    and calculating trends across groups.
    '''
    def __init__(self, database_path, sample_run_id=1, grouping_error_tolerance=8e-5,
                 minimum_scan_count=1, max_charge_state=8,
                 minimum_abundance_ratio=0.01, minimum_mass=1200.,
                 maximum_mass=15000., minimum_signal_to_noise=1.,
Example #22
def async_setup(opp):
    """Set up the history hooks."""
    opp.data[STATISTICS_BAKERY] = baked.bakery()
Example #23
from neutron_lib.callbacks import registry
from neutron_lib.callbacks import resources
from neutron_lib.plugins import directory
from oslo_log import log as logging
import sqlalchemy as sa
from sqlalchemy.ext import baked

from gbpservice.neutron.plugins.ml2plus.drivers.apic_aim import constants
from gbpservice.neutron.services.grouppolicy.common import exceptions as exc
from gbpservice.neutron.services.sfc.aim import constants as sfc_cts
from gbpservice.neutron.services.sfc.aim import exceptions as sfc_exc

LOG = logging.getLogger(__name__)
flowclassifier.SUPPORTED_L7_PARAMETERS.update(sfc_cts.AIM_FLC_L7_PARAMS)

BAKERY = baked.bakery(_size_alert=lambda c: LOG.warning(
    "sqlalchemy baked query cache size exceeded in %s" % __name__))


class FlowclassifierAIMDriverBase(base.FlowClassifierDriverBase):
    def create_flow_classifier_precommit(self, context):
        pass

    def create_flow_classifier(self, context):
        pass

    def update_flow_classifier(self, context):
        pass

    def delete_flow_classifier(self, context):
        pass
Example #24
    def test_subquery_eagerloading(self):
        User = self.classes.User
        Address = self.classes.Address
        Order = self.classes.Order

        self.bakery = baked.bakery()
        base_bq = self.bakery(lambda s: s.query(User))

        base_bq += lambda q: q.options(
            subqueryload(User.addresses), subqueryload(User.orders)
        )
        base_bq += lambda q: q.order_by(User.id)

        assert_result = [
            User(
                id=7,
                addresses=[Address(id=1, email_address="*****@*****.**")],
                orders=[Order(id=1), Order(id=3), Order(id=5)],
            ),
            User(
                id=8,
                addresses=[
                    Address(id=2, email_address="*****@*****.**"),
                    Address(id=3, email_address="*****@*****.**"),
                    Address(id=4, email_address="*****@*****.**"),
                ],
            ),
            User(
                id=9,
                addresses=[Address(id=5)],
                orders=[Order(id=2), Order(id=4)],
            ),
            User(id=10, addresses=[]),
        ]

        for i in range(4):
            for cond1, cond2 in itertools.product(
                *[(False, True) for j in range(2)]
            ):
                bq = base_bq._clone()

                sess = fixture_session()

                if cond1:
                    bq += lambda q: q.filter(User.name == "jack")
                else:
                    bq += lambda q: q.filter(User.name.like("%ed%"))

                if cond2:
                    ct = func.count(Address.id).label("count")
                    subq = (
                        sess.query(ct, Address.user_id)
                        .group_by(Address.user_id)
                        .having(ct > 2)
                        .subquery()
                    )

                    bq += lambda q: q.join(subq)

                if cond2:
                    if cond1:

                        def go():
                            result = bq(sess).all()
                            eq_([], result)

                        self.assert_sql_count(testing.db, go, 1)
                    else:

                        def go():
                            result = bq(sess).all()
                            eq_(assert_result[1:2], result)

                        self.assert_sql_count(testing.db, go, 3)
                else:
                    if cond1:

                        def go():
                            result = bq(sess).all()
                            eq_(assert_result[0:1], result)

                        self.assert_sql_count(testing.db, go, 3)
                    else:

                        def go():
                            result = bq(sess).all()
                            eq_(assert_result[1:3], result)

                        self.assert_sql_count(testing.db, go, 3)

                sess.close()
Example #25
# coding: utf-8

from enum import IntEnum, auto
from dataclasses import dataclass, astuple
import sqlalchemy as sa
from sqlalchemy.ext import baked
from bookworm import config
from bookworm.logger import logger
from bookworm.database.models import Book
from .annotation_models import Bookmark, Note, Quote


log = logger.getChild(__name__)
# The bakery caches query objects to avoid recompiling them into strings in every call
BAKERY = baked.bakery()


@dataclass
class AnnotationFilterCriteria:
    book_id: int = 0
    tag: str = ""
    section_title: str = ""
    content_snip: str = ""

    def any(self):
        return any(astuple(self))

    def filter_query(self, model, query):
        if not self.any():
            return query
        clauses = []
Example #26
import operator

from sqlalchemy import (PickleType, Numeric, Unicode, Table, bindparam,
                        Column, Integer, ForeignKey, UnicodeText, Boolean)
from sqlalchemy.ext.baked import bakery

import numpy as np

from .base import Base
from .connection import DatabaseManager
from .generic import MutableDict, ParameterStore

from ..structure.composition import Composition
from ..utils.common_math import DPeak
from ..utils.collectiontools import groupby

PROTON = Composition("H+").mass

observed_ions_bakery = bakery()


class HasPeakChromatogramData(object):

    peak_data = Column(MutableDict.as_mutable(PickleType))

    def get_chromatogram(self):
        peak_data = self.peak_data
        scans = peak_data['scan_times']
        intensity = peak_data['intensities']

        scan_groups = groupby(
            zip(scans, intensity), key_fn=operator.itemgetter(0),
            transform_fn=operator.itemgetter(1))
        scans = []
Example #27
def iterate_db_families(session, tax_db, families_query):
    """Returns an iterator over families in the Dfam MySQL database."""
    class_db = load_classification(session)

    # A "bakery" caches queries. The performance gains are worth it here, where
    # the queries are done many times with only the id changing. Another
    # approach that could be used is to make each of these queries once instead
    # of in a loop, but that would require a more significant restructuring.
    bakery = baked.bakery()

    clade_query = bakery(lambda s: s.query(dfam.t_family_clade.c.dfam_taxdb_tax_id))
    clade_query += lambda q: q.filter(dfam.t_family_clade.c.family_id == bindparam("id"))

    search_stage_query = bakery(lambda s: s.query(dfam.t_family_has_search_stage.c.repeatmasker_stage_id))
    search_stage_query += lambda q: q.filter(dfam.t_family_has_search_stage.c.family_id == bindparam("id"))

    buffer_stage_query = bakery(lambda s: s.query(
        dfam.FamilyHasBufferStage.repeatmasker_stage_id,
        dfam.FamilyHasBufferStage.start_pos,
        dfam.FamilyHasBufferStage.end_pos,
    ))
    buffer_stage_query += lambda q: q.filter(dfam.FamilyHasBufferStage.family_id == bindparam("id"))

    assembly_data_query = bakery(lambda s: s.query(
        dfam.Assembly.dfam_taxdb_tax_id,
        dfam.FamilyAssemblyDatum.hmm_hit_GA,
        dfam.FamilyAssemblyDatum.hmm_hit_TC,
        dfam.FamilyAssemblyDatum.hmm_hit_NC,
        dfam.FamilyAssemblyDatum.hmm_fdr,
    ))
    assembly_data_query += lambda q: q.filter(dfam.FamilyAssemblyDatum.family_id == bindparam("id"))
    assembly_data_query += lambda q: q.filter(dfam.Assembly.id == dfam.FamilyAssemblyDatum.assembly_id)

    feature_query = bakery(lambda s: s.query(dfam.FamilyFeature))
    feature_query += lambda q: q.filter(dfam.FamilyFeature.family_id == bindparam("id"))

    feature_attr_query = bakery(lambda s: s.query(dfam.FeatureAttribute))
    feature_attr_query += lambda q: q.filter(dfam.FeatureAttribute.family_feature_id == bindparam("id"))

    cds_query = bakery(lambda s: s.query(dfam.CodingSequence))
    cds_query += lambda q: q.filter(dfam.CodingSequence.family_id == bindparam("id"))

    alias_query = bakery(lambda s: s.query(dfam.FamilyDatabaseAlia))
    alias_query += lambda q: q.filter(dfam.FamilyDatabaseAlia.family_id == bindparam("id"))

    citation_query = bakery(lambda s: s.query(
        dfam.Citation.title,
        dfam.Citation.authors,
        dfam.Citation.journal,
        dfam.FamilyHasCitation.order_added,
    ))
    citation_query += lambda q: q.filter(dfam.Citation.pmid == dfam.FamilyHasCitation.citation_pmid)
    citation_query += lambda q: q.filter(dfam.FamilyHasCitation.family_id == bindparam("id"))

    hmm_query = bakery(lambda s: s.query(dfam.HmmModelDatum.hmm))
    hmm_query += lambda q: q.filter(dfam.HmmModelDatum.family_id == bindparam("id"))

    for record in families_query:
        family = famdb.Family()

        # REQUIRED FIELDS
        family.name = record.name
        family.accession = record.accession
        family.title = record.title
        family.version = record.version
        family.consensus = record.consensus
        family.length = record.length

        # RECOMMENDED FIELDS
        family.description = record.description
        family.author = record.author
        family.date_created = record.date_created
        family.date_modified = record.date_modified
        family.refineable = record.refineable
        family.target_site_cons = record.target_site_cons
        family.general_cutoff = record.hmm_general_threshold

        if record.classification_id in class_db:
            cls = class_db[record.classification_id]
            family.classification = cls.full_name()
            family.repeat_type = cls.type_name
            family.repeat_subtype = cls.subtype_name

        # clades and taxonomy links
        family.clades = []
        for (clade_id,) in clade_query(session).params(id=record.id).all():
            family.clades += [clade_id]

        # "SearchStages: A,B,C,..."
        ss_values = []
        for (stage_id,) in search_stage_query(session).params(id=record.id).all():
            ss_values += [str(stage_id)]

        if ss_values:
            family.search_stages = ",".join(ss_values)

        # "BufferStages:A,B,C[D-E],..."
        bs_values = []
        for (stage_id, start_pos, end_pos) in buffer_stage_query(session).params(id=record.id).all():
            if start_pos == 0 and end_pos == 0:
                bs_values += [str(stage_id)]
            else:
                bs_values += ["{}[{}-{}]".format(stage_id, start_pos, end_pos)]

        if bs_values:
            family.buffer_stages = ",".join(bs_values)

        # Taxa-specific thresholds. "ID, GA, TC, NC, fdr"
        th_values = []

        for (tax_id, spec_ga, spec_tc, spec_nc, spec_fdr) in assembly_data_query(session).params(id=record.id).all():
            if None in (spec_ga, spec_tc, spec_nc, spec_fdr):
                raise Exception("Found value of None for a threshold value for "
                                + record.accession + " in tax_id " + str(tax_id))
            th_values += ["{}, {}, {}, {}, {}".format(tax_id, spec_ga, spec_tc, spec_nc, spec_fdr)]
            tax_db[tax_id].mark_ancestry_used()

        if th_values:
            family.taxa_thresholds = "\n".join(th_values)

        feature_values = []
        for feature in feature_query(session).params(id=record.id).all():
            obj = {
                "type": feature.feature_type,
                "description": feature.description,
                "model_start_pos": feature.model_start_pos,
                "model_end_pos": feature.model_end_pos,
                "label": feature.label,
                "attributes": [],
            }

            for attribute in feature_attr_query(session).params(id=feature.id).all():
                obj["attributes"] += [{"attribute": attribute.attribute, "value": attribute.value}]
            feature_values += [obj]

        if feature_values:
            family.features = json.dumps(feature_values)

        cds_values = []
        for cds in cds_query(session).params(id=record.id).all():
            obj = {
                "product": cds.product,
                "translation": cds.translation,
                "cds_start": cds.cds_start,
                "cds_end": cds.cds_end,
                "exon_count": cds.exon_count,
                "exon_starts": str(cds.exon_starts),
                "exon_ends": str(cds.exon_ends),
                "external_reference": cds.external_reference,
                "reverse": (cds.reverse == 1),
                "stop_codons": cds.stop_codons,
                "frameshifts": cds.frameshifts,
                "gaps": cds.gaps,
                "percent_identity": cds.percent_identity,
                "left_unaligned": cds.left_unaligned,
                "right_unaligned": cds.right_unaligned,
                "description": cds.description,
                "protein_type": cds.protein_type,
            }

            cds_values += [obj]

        if cds_values:
            family.coding_sequences = json.dumps(cds_values)

        # External aliases

        alias_values = []
        for alias in alias_query(session).params(id=record.id).all():
            alias_values += ["%s: %s" % (alias.db_id, alias.db_link)]

        if alias_values:
            family.aliases = "\n".join(alias_values)

        citation_values = []
        for citation in citation_query(session).params(id=record.id).all():
            obj = {
                "title": citation.title,
                "authors": citation.authors,
                "journal": citation.journal,
                "order_added": citation.order_added,
            }
            citation_values += [obj]

        if citation_values:
            family.citations = json.dumps(citation_values)

        # MODEL DATA + METADATA

        hmm = hmm_query(session).params(id=record.id).one_or_none()
        if hmm:
            family.model = gzip.decompress(hmm[0]).decode()

        if record.hmm_maxl:
            family.max_length = record.hmm_maxl
        family.is_model_masked = record.model_mask

        seed_count = session.execute("SELECT COUNT(*) from seed_region where family_id=:id", {"id": record.id}).fetchone()[0]
        family.seed_count = seed_count

        yield family
Example #28
class Transcript(TranscriptBase):
    # Query baking to minimize overhead
    bakery = baked.bakery()
    blast_baked = bakery(lambda session: session.query(Hit))
    blast_baked += lambda q: q.filter(
        and_(Hit.query == bindparam("query"),
             Hit.evalue <= bindparam("evalue")))

    blast_baked += lambda q: q.order_by(asc(Hit.evalue))

    # blast_baked += lambda q: q.limit(bindparam("max_target_seqs"))
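
    # A hedged usage sketch (not part of the original): the baked query above
    # would typically be executed by supplying both bind parameters, e.g.
    #     self.blast_baked(session).params(query=..., evalue=...).all()
    # where the session object and the parameter values are assumptions.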

    def __init__(self, *args, configuration=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.__configuration = None
        self.configuration = configuration

    @property
    def configuration(self):
        """
        Configuration dictionary. It can be None.
        :return:
        """
        if self.__configuration is None:
            self.__configuration = default_config.copy()

        return self.__configuration

    @configuration.setter
    def configuration(self, configuration):
        """
        Setter for the configuration dictionary.
        :param configuration: None or a dictionary
        :type configuration: (None | MikadoConfiguration | DaijinConfiguration)
        :return:
        """

        if configuration is None:
            configuration = default_config.copy()

        assert isinstance(configuration,
                          (MikadoConfiguration, DaijinConfiguration))
        self.__configuration = configuration

    def __getstate__(self):
        state = super().__getstate__()
        if hasattr(self, "configuration") and self.configuration is not None:
            state["configuration"] = self.configuration.copy()
            assert isinstance(state["configuration"],
                              (MikadoConfiguration,
                               DaijinConfiguration)), type(self.configuration)
            if isinstance(state["configuration"].reference.genome,
                          pysam.FastaFile):
                state["configuration"]["reference"]["genome"] = state[
                    "configuration"].reference.genome.filename
        return state

    def __setstate__(self, state):
        self.configuration = state.pop("configuration", None)
        self.__dict__.update(state)
        self._calculate_cds_tree()
        self._calculate_segment_tree()
        self.logger = None

    def split_by_cds(self) -> List:
        """This method is used for transcripts that have multiple ORFs.
        It will split them according to the CDS information into multiple transcripts.
        UTR information will be retained only if no ORF is down/upstream.
        """

        for new_transcript in splitting.split_by_cds(self):
            yield new_transcript

        return

    def load_information_from_db(self,
                                 configuration,
                                 introns=None,
                                 data_dict=None):
        """This method will load information regarding the transcript from the provided database.

        :param configuration: Necessary configuration file
        :type configuration: (MikadoConfiguration|DaijinConfiguration)

        :param introns: the verified introns in the Locus
        :type introns: None,set

        :param data_dict: a dictionary containing the information directly
        :type data_dict: dict

        Verified introns can be provided from outside using the keyword.
        Otherwise, they will be extracted from the database directly.
        """

        retrieval.load_information_from_db(self,
                                           configuration,
                                           introns=introns,
                                           data_dict=data_dict)

    def load_orfs(self, candidate_orfs):
        """
        Thin layer over the load_orfs method from the retrieval module.
        :param candidate_orfs: list of candidate ORFs in BED12 format.
        :return:
        """

        retrieval.load_orfs(self, candidate_orfs)

    def find_overlapping_cds(self, candidate_orfs):
        """
        Thin wrapper for the homonym function in retrieval
        :param candidate_orfs: List of candidate ORFs
        :return:
        """

        return retrieval.find_overlapping_cds(self, candidate_orfs)

    # We need to overload this because otherwise we won't get the metrics from the base class.
    @classmethod
    @functools.lru_cache(maxsize=None, typed=True)
    def get_available_metrics(cls) -> list:
        """This function retrieves all metrics available for the class."""

        metrics = TranscriptBase.get_available_metrics()
        for member in inspect.getmembers(cls):
            if (not member[0].startswith("__")
                    and member[0] in cls.__dict__
                    and isinstance(cls.__dict__[member[0]], Metric)):
                metrics.append(member[0])

        _metrics = sorted(set(metrics))
        final_metrics = (["tid", "alias", "parent", "original_source", "score"]
                         + _metrics)
        return final_metrics

    # We need to overload this because otherwise we won't get the metrics from the base class.
    @classmethod
    @functools.lru_cache(maxsize=None, typed=True)
    def get_modifiable_metrics(cls) -> set:

        metrics = TranscriptBase.get_modifiable_metrics()
        for member in inspect.getmembers(cls):
            not_private = (not member[0].startswith("_" + cls.__name__ + "__")
                           and not member[0].startswith("__"))
            in_dict = (member[0] in cls.__dict__)
            if in_dict:
                is_metric = isinstance(cls.__dict__[member[0]], Metric)
                has_fset = (getattr(cls.__dict__[member[0]], "fset", None)
                            is not None)
            else:
                is_metric = None
                has_fset = None
            if all([not_private, in_dict, is_metric, has_fset]):
                metrics.append(member[0])
        return set(metrics)
Example #29
    registry[STORAGE] = RDBStorage(registry[DBSESSION])
    global _DBSESSION
    _DBSESSION = registry[DBSESSION]
    if registry.settings.get('blob_bucket'):
        registry[BLOBS] = S3BlobStorage(
            registry.settings['blob_bucket'],
            read_profile_name=registry.settings.get('blob_read_profile_name'),
            store_profile_name=registry.settings.get('blob_store_profile_name')
        )
    else:
        registry[BLOBS] = RDBBlobStorage(registry[DBSESSION])


Base = declarative_base()

bakery = baked.bakery()
baked_query_resource = bakery(lambda session: session.query(Resource))
baked_query_unique_key = bakery(
    lambda session: session.query(Key).options(
        orm.joinedload_all(
            Key.resource,
            Resource.data,
            CurrentPropertySheet.propsheet,
            innerjoin=True,
        ),
    ).filter(Key.name == bindparam('name'), Key.value == bindparam('value'))
)
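
A hedged usage sketch (not part of the original): executing the baked
unique-key lookup above means supplying its two bind parameters, for example:

def lookup_key(session, name, value):
    # name/value correspond to the bindparams declared in the query above
    return baked_query_unique_key(session).params(
        name=name, value=value).one_or_none()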


class RDBStorage(object):
    batchsize = 1000
Example #30
import logging
import pprint

import tornado.web
import tornado.httpserver
from tornado.options import define, options, parse_command_line
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext import baked
from sqlalchemy.orm import Session


# models

BAKERY = baked.bakery()


Base = declarative_base()


ENGINE = create_engine('sqlite:///:memory:', echo=False)


class User(Base):
    __tablename__ = 'users'

    id = Column(Integer, primary_key=True)
    name = Column(String)
    fullname = Column(String)
    password = Column(String)
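
A minimal sketch (not part of the original) of how the module-level BAKERY and
the User model above might be combined; bindparam would additionally need to
be imported from sqlalchemy:

from sqlalchemy import bindparam

def get_user_by_name(name):
    session = Session(bind=ENGINE)
    bq = BAKERY(lambda s: s.query(User))
    bq += lambda q: q.filter(User.name == bindparam("name"))
    try:
        return bq(session).params(name=name).one_or_none()
    finally:
        session.close()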
Example #31
class ModelViewWithBakedQueryMixin:

    bakery = baked.bakery()

    def _apply_path_joins(self, query, joins, path, inner_join=True):
        """
            Apply join path to the query.

            :param query:
                Query to add joins to
            :param joins:
                List of current joins. Used to avoid joining on same relationship more than once
            :param path:
                Path to be joined
            :param fn:
                Join function
        """
        last = None

        if path:
            for item in path:
                key = (inner_join, item)
                alias = joins.get(key)

                if key not in joins:
                    if not isinstance(item, Table):
                        alias = aliased(item.property.mapper.class_)

                    fn = query.join if inner_join else query.outerjoin

                    if last is None:
                        query = fn(item) if alias is None else fn(alias, item)
                    else:
                        prop = getattr(last, item.key)
                        query = fn(prop) if alias is None else fn(alias, prop)

                    joins[key] = alias

                last = alias

        return query, joins, last

    def _order_by(self, query, joins, sort_joins, sort_field, sort_desc):
        """
            Apply order_by to the query

            :param query:
                Query
            :param joins:
                Current joins
            :param sort_joins:
                Sort joins (properties or tables)
            :param sort_field:
                Sort field
            :param sort_desc:
                Ascending or descending
        """
        if sort_field is not None:
            # Handle joins
            query, joins, alias = self._apply_path_joins(
                query, joins, sort_joins, inner_join=False)

            column = sort_field if alias is None else getattr(
                alias, sort_field.key)

            # if sort_desc:
            #     query += lambda q: q.order_by(desc(column))
            # else:
            #     query += lambda q: q.order_by(column)

            if sort_desc:
                query += lambda q: q.order_by(desc(bindparam('order_by')))
            else:
                query += lambda q: q.order_by(bindparam('order_by'))
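            # NOTE: on most backends "ORDER BY <bind parameter>" orders by a
            # constant value, so this sorts by the literal column name rather
            # than by the column itself.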
            self.bakery_query_params['order_by'] = column.name

        return query, joins

    def _apply_filters(self, query, count_query, joins, count_joins, filters):
        for idx, flt_name, value in filters:
            flt = self._filters[idx]

            alias = None
            count_alias = None

            # Figure out joins
            if isinstance(flt, sqla_filters.BaseSQLAFilter):
                # If no key_name is specified, use filter column as filter key
                filter_key = flt.key_name or flt.column
                path = self._filter_joins.get(filter_key, [])

                query, joins, alias = self._apply_path_joins(
                    query, joins, path, inner_join=False)

                if count_query is not None:
                    count_query, count_joins, count_alias = self._apply_path_joins(
                        count_query, count_joins, path, inner_join=False)

            # Clean value .clean() and apply the filter
            clean_value = flt.clean(value)

            try:
                query = flt.apply(query, clean_value, alias)
            except TypeError:
                spec = inspect.getargspec(flt.apply)

                if len(spec.args) == 3:
                    warnings.warn('Please update your custom filter %s to '
                                  'include additional `alias` parameter.' % repr(flt))
                else:
                    raise

                query = flt.apply(query, clean_value)

            if count_query is not None:
                try:
                    count_query = flt.apply(
                        count_query, clean_value, count_alias)
                except TypeError:
                    count_query = flt.apply(count_query, clean_value)

        return query, count_query, joins, count_joins

    def _apply_search(self, query, count_query, joins, count_joins, search):
        """
            Apply search to a query.
        """
        terms = search.split(' ')

        for term in terms:
            if not term:
                continue

            stmt = tools.parse_like_term(term)

            filter_stmt = []
            count_filter_stmt = []

            for field, path in self._search_fields:
                query, joins, alias = self._apply_path_joins(
                    query, joins, path, inner_join=False)

                count_alias = None

                if count_query is not None:
                    count_query, count_joins, count_alias = self._apply_path_joins(
                        count_query, count_joins, path, inner_join=False)

                column = field if alias is None else getattr(alias, field.key)
                filter_stmt.append(
                    cast(column, Unicode).ilike(stmt))  # case-insensitive match via ilike

                if count_filter_stmt is not None:
                    column = field if count_alias is None else getattr(
                        count_alias, field.key)
                    count_filter_stmt.append(cast(column, Unicode).ilike(stmt))

            query += lambda q: q.filter(or_(*filter_stmt))  # OR the per-field clauses together

            if count_query is not None:
                count_query += lambda q: q.filter(or_(*count_filter_stmt))

        return query, count_query, joins, count_joins

    def _apply_pagination(self, query, page, page_size):
        if page_size is None:
            page_size = self.page_size

        if page_size:
            query += lambda q: q.limit(bindparam('page_size'))
            self.bakery_query_params['page_size'] = page_size

        if page and page_size:
            # query += lambda q: q.offset(page * page_size)
            query += lambda q: q.offset(bindparam('offset'))
            self.bakery_query_params['offset'] = page * page_size

        return query

    def get_list(self, page, sort_column, sort_desc, search, filters,
                 execute=True, page_size=None):
        """
            Return records from the database.

            :param page:
                Page number
            :param sort_column:
                Sort column name
            :param sort_desc:
                Descending or ascending sort
            :param search:
                Search query
            :param execute:
                Execute query immediately? Default is `True`
            :param filters:
                List of filter tuples
            :param page_size:
                Number of results. Defaults to ModelView's page_size. Can be
                overridden to change the page_size limit. Removing the page_size
                limit requires setting page_size to 0 or False.
        """

        # Will contain join paths with optional aliased object
        joins = {}
        count_joins = {}

        query = self.get_query()
        count_query = self.get_count_query() if not self.simple_list_pager else None

        # Ignore eager-loaded relations (prevent unnecessary joins)
        # TODO: Separate join detection for query and count query?
        # eager-loading: https://docs.sqlalchemy.org/en/latest/orm/tutorial.html#eager-loading
        if hasattr(query, '_join_entities'):
            for entity in query._join_entities:
                for table in entity.tables:
                    joins[table] = None

        # Apply search criteria
        if self._search_supported and search:
            query, count_query, joins, count_joins = self._apply_search(
                query, count_query, joins, count_joins, search)

        # Apply filters
        if filters and self._filters:
            query, count_query, joins, count_joins = self._apply_filters(
                query, count_query, joins, count_joins, filters)

        # Calculate number of rows if necessary
        count = count_query(self.session()).scalar() if count_query else None

        # Auto join
        for j in self._auto_joins:
            # joinedload: https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#sqlalchemy.orm.joinedload
            query += lambda q: q.options(joinedload(j))

        # Sorting
        query, joins = self._apply_sorting(
            query, joins, sort_column, sort_desc)

        # Pagination
        query = self._apply_pagination(query, page, page_size)

        # Execute if needed
        if execute:
            query = query(self.session()).params(
                **self.bakery_query_params
            ).all()

        return count, query

    def get_one(self, id):
        """override"""
        return super().get_one(id)
        # TODO:
        # q = self.bakery(lambda s: s.query(self.model))
        # q += lambda q: q.filter(self.model.id == bindparam("id"))
        # return q(self.session()).params(id=id).one()

    def get_query(self):
        """override"""
        # return super().get_query()
        q = self.bakery(lambda s: s.query(self.model))
        self.bakery_query_params = {}
        return q

    def get_count_query(self):
        """override"""
        # return super().get_count_query()
        q = self.bakery(lambda s: s.query(
            func.count('*')).select_from(self.model))
        return q
Example #32
def async_setup(hass):
    """Set up the history hooks."""
    hass.data[HISTORY_BAKERY] = baked.bakery()
Example #33
def async_setup(opp):
    """Set up the history hooks."""
    opp.data[HISTORY_BAKERY] = baked.bakery()
Example #34
 def setup(self):
     self.bakery = baked.bakery()
Example #35
def async_setup(hass):
    """Set up the history hooks."""
    hass.data[STATISTICS_BAKERY] = baked.bakery()
Example #36
 def setup_test(self):
     self.bakery = baked.bakery()
Example #37
from .base import Base, Namespace
from .generic import (
    MutableDict, HasTaxonomy,
    HasReferenceAccessionNumber,
    HasClassBakedQueries)

from ..utils.database_utils import get_or_create
from ..utils.memoize import memoclone

FrozenGlycanComposition = glycan_composition.FrozenGlycanComposition
crossring_pattern = re.compile(r"\d,\d")
glycoct_parser = memoclone(100)(glycoct.loads)


glycan_bakery = bakery()


class MassShift(Base):
    __tablename__ = "MassShift"

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(Unicode(128), index=True)
    mass = Column(Numeric(10, 6, asdecimal=False))

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        return (self.name == other.name) and (self.mass == other.mass)
Example #38
from sqlalchemy.orm import relationship, backref
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy import Numeric, Unicode, Column, Integer, ForeignKey, Table, UnicodeText, bindparam

from ..base import Base
from sqlalchemy.ext.baked import bakery

from glycresoft_sqlalchemy.structure.sequence_composition import SequenceComposition, Composition


sequencing_bakery = bakery()


class SequenceBuildingBlock(Base):
    __tablename__ = "SequenceBuildingBlock"

    id = Column(Integer, primary_key=True)
    name = Column(Unicode(120), index=True)
    mass = Column(Numeric(10, 6), index=True)
    hypothesis_id = Column(Integer, ForeignKey("Hypothesis.id"), index=True)


class SequenceSegment(Base):
    __tablename__ = "SequenceSegment"
    id = Column(Integer, primary_key=True)
    sequence = Column(Unicode(128))
    mass = Column(Numeric(12, 6), index=True)
    count_n_glycosylation = Column(Integer)


class AminoAcidComposition(Base):
Example #39
from sqlalchemy.orm import relationship, backref, make_transient, Query, validates
from sqlalchemy.ext.hybrid import hybrid_method, hybrid_property
from sqlalchemy import (PickleType, Numeric, Unicode, Table, bindparam,
                        Column, Integer, ForeignKey, UnicodeText, Boolean,
                        event)
from sqlalchemy.orm.exc import DetachedInstanceError
from ..generic import MutableDict, MutableList, HasClassBakedQueries
from ..base import Base
from ..hypothesis import Hypothesis
from ..glycomics import (
    with_glycan_composition, TheoreticalGlycanCombination, has_glycan_composition_listener)

from ...structure import sequence, residue


peptide_bakery = bakery()


class Protein(Base):
    __tablename__ = "Protein"

    id = Column(Integer, primary_key=True, autoincrement=True)
    protein_sequence = Column(UnicodeText, default=u"")
    name = Column(Unicode(128), default=u"", index=True)
    other = Column(MutableDict.as_mutable(PickleType))
    hypothesis_id = Column(Integer, ForeignKey("Hypothesis.id", ondelete="CASCADE"))

    _n_glycan_sequon_sites = None

    @property
    def n_glycan_sequon_sites(self):
Example #40
    def test_subquery_eagerloading(self):
        User = self.classes.User
        Address = self.classes.Address
        Order = self.classes.Order

        # Override the default bakery for one with a smaller size. This used to
        # trigger a bug when unbaking subqueries.
        self.bakery = baked.bakery(size=3)
        base_bq = self.bakery(lambda s: s.query(User))

        base_bq += lambda q: q.options(subqueryload(User.addresses),
                                       subqueryload(User.orders))
        base_bq += lambda q: q.order_by(User.id)

        assert_result = [
            User(id=7,
                addresses=[Address(id=1, email_address='*****@*****.**')],
                orders=[Order(id=1), Order(id=3), Order(id=5)]),
            User(id=8, addresses=[
                Address(id=2, email_address='*****@*****.**'),
                Address(id=3, email_address='*****@*****.**'),
                Address(id=4, email_address='*****@*****.**'),
            ]),
            User(id=9,
                addresses=[Address(id=5)],
                orders=[Order(id=2), Order(id=4)]),
            User(id=10, addresses=[])
        ]

        for i in range(4):
            for cond1, cond2 in itertools.product(
                    *[(False, True) for j in range(2)]):
                bq = base_bq._clone()

                sess = Session()

                if cond1:
                    bq += lambda q: q.filter(User.name == 'jack')
                else:
                    bq += lambda q: q.filter(User.name.like('%ed%'))

                if cond2:
                    ct = func.count(Address.id).label('count')
                    subq = sess.query(
                        ct,
                        Address.user_id).group_by(Address.user_id).\
                        having(ct > 2).subquery()

                    bq += lambda q: q.join(subq)

                if cond2:
                    if cond1:
                        def go():
                            result = bq(sess).all()
                            eq_([], result)
                        self.assert_sql_count(testing.db, go, 1)
                    else:
                        def go():
                            result = bq(sess).all()
                            eq_(assert_result[1:2], result)
                        self.assert_sql_count(testing.db, go, 3)
                else:
                    if cond1:
                        def go():
                            result = bq(sess).all()
                            eq_(assert_result[0:1], result)
                        self.assert_sql_count(testing.db, go, 3)
                    else:
                        def go():
                            result = bq(sess).all()
                            eq_(assert_result[1:3], result)
                        self.assert_sql_count(testing.db, go, 3)

                sess.close()
Example #41
           'ModelAPIMixin']


class BaseCls(object):
    created_time = Column(DateTime, default=datetime.datetime.now)
    modified_time = Column(DateTime,
                           onupdate=datetime.datetime.now,
                           default=datetime.datetime.now)


# Clients that get reused in many places, plus the base classes
Base = declarative_base(cls=BaseCls)        # Base that inherits the two timestamp columns
NativeBase = declarative_base()             # plain declarative Base
metadata = Base.metadata
Session = sessionmaker()
sql_bakery = bakery()
redis_cli = redis.StrictRedis(
    host=CommonConfig.REDIS_HOST,
    port=CommonConfig.REDIS_PORT,
    password=CommonConfig.REDIS_PASSWORD,
    decode_responses=True
)


@contextlib.contextmanager
def session_context(uri=CommonConfig.SQLALCHEMY_URI):
    engine = create_engine(uri)
    session = Session(bind=engine)
    yield session
    try:
        session.commit()
Example #42
from collections import namedtuple

from neutron_lib.api.definitions import portbindings
from opflexagent import rpc as o_rpc
from oslo_log import log
from sqlalchemy.ext import baked

from gbpservice.neutron.db.grouppolicy.extensions import (
    apic_auto_ptg_db as auto_ptg_db)
from gbpservice.neutron.db.grouppolicy.extensions import (
    apic_segmentation_label_db as seg_label_db)
from gbpservice.neutron.db.grouppolicy import group_policy_mapping_db as gpmdb
from gbpservice.neutron.plugins.ml2plus.drivers.apic_aim import (
    constants as md_const)


LOG = log.getLogger(__name__)

BAKERY = baked.bakery(_size_alert=lambda c: LOG.warning(
    "sqlalchemy baked query cache size exceeded in %s" % __name__))

EndpointPtInfo = namedtuple(
    'EndpointPtInfo',
    ['pt_id',
     'ptg_id',
     'apg_id',
     'inject_default_route',
     'l3p_project_id',
     'is_auto_ptg'])


class AIMMappingRPCMixin(object):
    """RPC mixin for AIM mapping.

    Collection of all the RPC methods consumed by the AIM mapping.
Example #43
def async_setup(hass: HomeAssistant) -> None:
    """Set up the history hooks."""
    hass.data[HISTORY_BAKERY] = baked.bakery()
Example #44
 def setup(self):
     self.bakery = baked.bakery()
Example #45
)
from sqlalchemy.ext import baked
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import Session
from sqlalchemy.pool import NullPool, Pool, QueuePool
from sqlalchemy.sql import func, select
from sqlalchemy.sql.expression import Insert

from ichnaea.config import (
    DB_DDL_URI,
    DB_RW_URI,
    DB_RO_URI,
)

BAKERY = baked.bakery(size=500)

DB_TYPE = {
    'ddl': DB_DDL_URI,
    'ro': DB_RO_URI,
    'rw': DB_RW_URI,
}


@compiles(Insert, 'mysql')
def on_duplicate(insert, compiler, **kw):
    """Custom MySQL insert on_duplicate support."""
    stmt = compiler.visit_insert(insert, **kw)
    my_var = insert.dialect_kwargs.get('mysql_on_duplicate', None)
    if my_var:
        stmt += ' ON DUPLICATE KEY UPDATE %s' % my_var