Code example #1
File: test_db_model.py Project: mlatief/haukka
class TrialModelTestCase(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        global transaction, connection, engine

        # Connect to the database and create the schema within a transaction
        engine = create_engine(TEST_DATABASE_URI)
        connection = engine.connect()
        transaction = connection.begin()
        Trial.metadata.create_all(connection)

        # Load test trials fixtures from xml files
        nct_ids = ['NCT02034110', 'NCT00001160', 'NCT00001163']
        cls.trials = load_sample_trials(nct_ids)

    @classmethod
    def tearDownClass(cls):
        # Roll back the top level transaction and disconnect from the database
        transaction.rollback()
        connection.close()
        engine.dispose()

    def setUp(self):
        self.__transaction = connection.begin_nested()
        self.session = Session(connection)

    def tearDown(self):
        self.session.close()
        self.__transaction.rollback()

    def test_add(self):
        trial = Trial(ct_dict=self.trials[0])
        self.session.add(trial)
Code example #2
class TestSessions(TestCase):
    plugins = []

    def test_multiple_connections(self):
        self.session2 = Session(bind=self.engine.connect())
        article = self.Article(name=u'Session1 article')
        article2 = self.Article(name=u'Session2 article')
        self.session.add(article)
        self.session2.add(article2)
        self.session.flush()
        self.session2.flush()

        self.session.commit()
        self.session2.commit()
        assert article.versions[-1].transaction_id == 1
        assert article2.versions[-1].transaction_id == 2

    def test_manual_transaction_creation(self):
        uow = versioning_manager.unit_of_work(self.session)
        transaction = uow.create_transaction(self.session)
        self.session.flush()
        assert transaction.id == 1
        article = self.Article(name=u'Session1 article')
        self.session.add(article)
        self.session.flush()
        assert uow.current_transaction.id == 1

        self.session.commit()
        assert article.versions[-1].transaction_id == 1

    def test_commit_without_objects(self):
        self.session.commit()
Code example #3
    def test_session_with_external_transaction(self):
        conn = self.engine.connect()
        t = conn.begin()
        session = Session(bind=conn)

        article = self.Article(name=u'My Session Article')
        session.add(article)
        session.flush()

        session.close()
        t.rollback()
        conn.close()
Code example #4
File: manager.py Project: BtbN/esaupload
def set_new_setting(session: Session, name: str, value):
    if isinstance(value, (str, int, float, bool)):
        value_class = str(value.__class__.__name__)
        value = str(value)
    else:
        value_class = "pickle"
        value = pickle.dumps(value)
    sett = session.query(Setting).filter_by(name=name).first()
    if not sett:
        sett = Setting(name=name, value=value, value_class=value_class)
        session.add(sett)
    else:
        sett.value = value
        sett.value_class = value_class
    return True
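For contrast, a hypothetical counterpart that reverses this encoding might look like the sketch below. The `get_setting` helper is an illustration only (it is not part of the BtbN/esaupload project) and assumes the `Setting.value` column can hold the pickled bytes unchanged.

import pickle
from sqlalchemy.orm import Session

def get_setting(session: Session, name: str, default=None):
    # Hypothetical reverse of set_new_setting: look up the row and rebuild
    # the Python value from the stored value_class marker.
    sett = session.query(Setting).filter_by(name=name).first()
    if sett is None:
        return default
    if sett.value_class == "pickle":
        return pickle.loads(sett.value)
    # str/int/float/bool were stored as strings keyed by their class name;
    # bool needs special handling because bool("False") is True.
    converters = {"str": str, "int": int, "float": float,
                  "bool": lambda v: v == "True"}
    return converters[sett.value_class](sett.value)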
Code example #5
File: usermixin.py Project: websauna/websauna
    def init_empty_site(self, dbsession: Session, user: UserMixin):
        """When the first user signs up build the admin groups and make the user member of it.

        Make the first member of the site to be admin and superuser.
        """

        # Try to reflect related group class based on User model
        i = inspection.inspect(user.__class__)
        Group = i.relationships["groups"].mapper.entity

        # Do we already have any groups... if we do, we probably don't want to init again
        if dbsession.query(Group).count() > 0:
            return

        g = Group(name=Group.DEFAULT_ADMIN_GROUP_NAME)
        dbsession.add(g)

        g.users.append(user)
Code example #6
def upgrade(migrate_engine):
    """
    Upgrade operations go here.
    Don't create your own engine; bind migrate_engine to your metadata
    """
    #==========================================================================
    # USER LOGS
    #==========================================================================
    from rhodecode.lib.dbmigrate.schema.db_1_5_0 import UserLog
    tbl = UserLog.__table__
    username = Column("username", String(255, convert_unicode=False,
                                         assert_unicode=None), nullable=True,
                      unique=None, default=None)
    # create username column
    username.create(table=tbl)

    _Session = Session()
    ## after adding that column fix all usernames
    users_log = _Session.query(UserLog)\
            .options(joinedload(UserLog.user))\
            .options(joinedload(UserLog.repository)).all()

    for entry in users_log:
        entry.username = entry.user.username
        _Session.add(entry)
    _Session.commit()

    #alter username to not null
    from rhodecode.lib.dbmigrate.schema.db_1_5_0 import UserLog
    tbl_name = UserLog.__tablename__
    tbl = Table(tbl_name,
                MetaData(bind=migrate_engine), autoload=True,
                autoload_with=migrate_engine)
    col = tbl.columns.username

    # remove nullability from the username column
    col.alter(nullable=False)
Code example #7
    def test_create_object(self):
        """
            Test the method def _create_object(self, obj_type, sync_key,
            local_id=None, pvc_id=None)
        """

        obj_type = "Network"
        sync_key = utils.gen_network_sync_key(
            self.fakePowerVCNetwork.powerNetInstance)
        local_id = self.fakeOSNetwork.fakeOSNetworkInstance['id']
        pvc_id = self.fakePowerVCNetwork.powerNetInstance['id']

        inputPowerVCMObj = model.PowerVCMapping(obj_type, sync_key)

        self.aMox.StubOutWithMock(session, 'begin')
        session.begin(subtransactions=True).AndReturn(transaction(None, None))

        self.aMox.StubOutWithMock(model, 'PowerVCMapping')
        model.PowerVCMapping(obj_type, sync_key).AndReturn(inputPowerVCMObj)

        self.aMox.StubOutWithMock(session, 'add')
        session.add(inputPowerVCMObj).AndReturn("")

        self.aMox.ReplayAll()

        self.powervcagentdb._create_object(
            obj_type, sync_key, update_data=None,
            local_id=local_id, pvc_id=pvc_id)

        self.aMox.VerifyAll()

        self.assertEqual(
            self.powerVCMapping.local_id, inputPowerVCMObj.local_id)
        self.assertEqual(self.powerVCMapping.pvc_id, inputPowerVCMObj.pvc_id)
        self.assertEqual(self.powerVCMapping.status, inputPowerVCMObj.status)
        self.aMox.UnsetStubs()
Code example #8
class CrawlProcessor(object):

    __VERSION__ = "CrawlProcessor-0.2.1"

    def __init__(self, engine, redis_server, stop_list="keyword_filter.txt"):

        if type(engine) == types.StringType:
            logging.info("Using connection string '%s'" % (engine,))
            new_engine = create_engine(engine, encoding='utf-8', isolation_level="READ COMMITTED")
            if "sqlite:" in engine:
                logging.debug("Setting text factory for unicode compat.")
                new_engine.raw_connection().connection.text_factory = str 
            self._engine = new_engine
        else:
            logging.info("Using existing engine...")
            self._engine = engine
        logging.info("Binding session...")
        self._session = Session(bind=self._engine, autocommit = False)

        if type(stop_list) == types.StringType:
            stop_list_fp = open(stop_list)
        else:
            stop_list_fp = stop_list 

        self.stop_list = set([])
        for line in stop_list_fp:
            self.stop_list.add(line.strip())

        self.cls = DocumentClassifier()
        self.dc  = DomainController(self._engine, self._session)
        self.ac  = ArticleController(self._engine, self._session)
        self.ex  = extract.TermExtractor()
        self.kwc = KeywordController(self._engine, self._session)
        self.swc = SoftwareVersionsController(self._engine, self._session)
        self.redis_kw = redis.Redis(host=redis_server, port=6379, db=1)
        self.redis_dm = redis.Redis(host=redis_server, port=6379, db=2)
        dm_session = Session(bind=self._engine, autocommit = False)
        self.drw = DomainResolutionWorker(dm_session, self.redis_dm)


    def _check_processed(self, item):
        crawl_id, record = item 
        headers, content, url, date_crawled, content_type = record

        path   = self.ac.get_path_fromurl(url)
        domain_identifier = None 
        logging.info("_check_processed: retrieving domain...")
        domain_key = self.dc.get_Domain_key(url)
        while domain_identifier == None:
            domain_identifier = self.drw.get_domain(domain_key)

        it = self._session.query(Article).filter_by(crawl_id = crawl_id).filter_by(domain_id = domain_identifier).filter_by(path = path)
        try:
            it = it.one()
            logging.error("%s: already processed", url)
            return False 
        except sqlalchemy.orm.exc.MultipleResultsFound:
            logging.error("%s: appears to have been already processed multiple times", url)
            return False 
        except sqlalchemy.orm.exc.NoResultFound:
            logging.info("%s: hasn't been processed yet", url)
            return True 

    def process_record(self, item):
        if len(item) != 2:
            raise ValueError(item)
        if not self._check_processed(item):
            return None
        ret, retries = None, 2
        while ret == None and retries > 0:
            try:
                retries -= 1
                ret = self._process_record(item)
            except Exception as ex:
                import traceback
                print >> sys.stderr, ex
                traceback.print_exc()
                raise ex 
        if ret == False:
            return None 
        return ret 


    def _process_record(self, item_arg):

        crawl_id, record = item_arg
        headers, content, url, date_crawled, content_type = record

        assert headers is not None
        assert content is not None 
        assert url is not None 
        assert date_crawled is not None 
        assert content_type is not None 

        status = "Processed"

        # Fix for a seg-fault
        if "nasa.gov" in url:
            return False

        # Sort out the domain
        domain_identifier = None 
        logging.info("Retrieving domain...")
        domain_key = self.dc.get_Domain_key(url)
        while domain_identifier == None:
            domain_identifier = self.drw.get_domain(domain_key)

        domain = self._session.query(Domain).get(domain_identifier)
        assert domain is not None

        # Build database objects 
        path   = self.ac.get_path_fromurl(url)
        article = Article(path, date_crawled, crawl_id, domain, status)
        self._session.add(article)
        classified_by = self.swc.get_SoftwareVersion_fromstr(pysen.__VERSION__)
        assert classified_by is not None

        if content_type != 'text/html':
            logging.error("Unsupported content type: %s", str(content_type))
            article.status = "UnsupportedType"
            return False

        # Start the async transaction to get the plain text
        worker_req_thread = BoilerPipeWorker(content)
        worker_req_thread.start()

        # Whilst that's executing, parse the document 
        logging.info("Parsing HTML...")
        html = BeautifulSoup(content)

        if html is None or html.body is None:
            article.status = "NoContent"
            return False

        # Extract the dates 
        date_dict = pydate.get_dates(html)

        if len(date_dict) == 0:
            status = "NoDates"

        # Detect the language
        lang, lang_certainty = langid.classify(content)

        # Wait for the BoilerPipe thread to complete
        worker_req_thread.join()
        logging.debug(worker_req_thread.result)
        logging.debug(worker_req_thread.version)

        if worker_req_thread.result == None:
            article.status = "NoContent"
            return False

        # If the language isn't English, skip it
        if lang != "en":
            logging.info("language: %s with certainty %.2f - skipping...", lang, lang_certainty)
            article.status = "LanguageError" # Replace with something appropriate
            return False

        content = worker_req_thread.result.encode('ascii', 'ignore')

        # Headline extraction 
        h_counter = 6
        headline = None
        while h_counter > 0:
            tag = "h%d" % (h_counter,)
            found = False 
            for node in html.findAll(tag):
                if node.text in content:
                    headline = node.text 
                    found = True 
                    break 
            if found:
                break
            h_counter -= 1

        # Run keyword extraction 
        keywords = self.ex(content)
        kset     = KeywordSet(self.stop_list)
        nnp_sets_scored = set([])

        for word, freq, amnt in sorted(keywords):
            try:
                nnp_sets_scored.add((word, freq))
            except ValueError:
                break 

        nnp_adj = set([])
        nnp_set = set([])
        nnp_vector = []
        for sentence in sent_tokenize(content):
            text = nltk.word_tokenize(sentence)
            pos  = nltk.pos_tag(text)
            pos_groups = itertools.groupby(pos, lambda x: x[1])
            for k, g in pos_groups:
                if k != 'NNP':
                    continue
                nnp_list = [word for word, speech in g]
                nnp_buf = []
                for item in nnp_list:
                    nnp_set.add(item)
                    nnp_buf.append(item)
                    nnp_vector.append(item)
                for i, j in zip(nnp_buf[0:-1], nnp_buf[1:]):
                    nnp_adj.add((i, j))

        nnp_vector = filter(lambda x: x.lower() not in self.stop_list, nnp_vector)
        nnp_counter = Counter(nnp_vector)
        for word in nnp_set:
            score = nnp_counter[word]
            nnp_sets_scored.add((word, score))

        for item, score in sorted(nnp_sets_scored, key=lambda x: x[1], reverse=True):
            try: 
                if type(item) == types.ListType or type(item) == types.TupleType:
                    kset.add(' '.join(item))
                else:
                    kset.add(item)
            except ValueError:
                break 

        scored_nnp_adj = []
        for item1, item2 in nnp_adj:
            score = nnp_counter[item1] + nnp_counter[item2]
            scored_nnp_adj.append((item1, item2, score))

        nnp_adj = []
        for item1, item2, score in sorted(scored_nnp_adj, key=lambda x: x[1], reverse=True):
            if len(nnp_adj) < KEYWORD_LIMIT:
                nnp_adj.append((item1, item2))
            else:
                break

        # Generate list of all keywords
        keywords = set([])
        for keyword in kset:
            try:
                k = Keyword(keyword)
                keywords.add(k)
            except ValueError as ex:
                logging.error(ex)
                continue
        for item1, item2 in nnp_adj:
            try:
                k = Keyword(item1)
                keywords.add(k)
            except ValueError as ex:
                logging.error(ex)
            try:
                k = Keyword(item2)
                keywords.add(k)
            except ValueError as ex:
                logging.error(ex)

        # Resolve keyword identifiers
        keyword_resolution_worker = KeywordResolutionWorker(set([k.word for k in keywords]), self.redis_kw)
        keyword_resolution_worker.start()
            
        # Run sentiment analysis
        trace = []
        features = self.cls.classify(worker_req_thread.result, trace) 
        label, length, classified, pos_sentences, neg_sentences,\
        pos_phrases, neg_phrases  = features[0:7]        

        # Convert Pysen's model into database models
        try:
            doc = Document(article.id, label, length, pos_sentences, neg_sentences, pos_phrases, neg_phrases, headline)
        except ValueError as ex:
            logging.error(ex)
            logging.error("Skipping this document...")
            article.status = "ClassificationError"
            return False

        self._session.add(doc)
        extracted_phrases = set([])
        for sentence, score, phrase_trace in trace:
            sentence_type = "Unknown"
            for node in html.findAll(text=True):
                if sentence.text in node.strip():
                    sentence_type = node.parent.name.upper()
                    break

            if sentence_type not in ["H1", "H2", "H3", "H4", "H5", "H6", "P", "Unknown"]:
                sentence_type = "Other"

            label, average, prob, pos, neg, probs, _scores = score 

            s = Sentence(doc, label, average, prob, sentence_type)
            self._session.add(s)
            for phrase, prob, score, label in phrase_trace:
                p = Phrase(s, score, prob, label)
                self._session.add(p)
                extracted_phrases.add((phrase, p))

        # Wait for keyword resolution to finish
        keyword_resolution_worker.join()
        keyword_mapping = keyword_resolution_worker.out_keywords

        # Associate extracted keywords with phrases
        keyword_objects, short_keywords = kset.convert(keyword_mapping, self.kwc)
        for k in keyword_objects:
            self._session.merge(k)
        for p, p_obj in extracted_phrases:
            for k in keyword_objects:
                if k.word in p.get_text():
                    nk = KeywordIncidence(k, p_obj)

        # Save the keyword adjacency list
        for i, j in kset.convert_adj_tuples(nnp_adj, keyword_mapping, self.kwc):
            self._session.merge(i)
            self._session.merge(j)
            kwa = KeywordAdjacency(i, j, doc)
            self._session.add(kwa)

        # Build date objects
        for key in date_dict:
            rec  = date_dict[key]
            if "dates" not in rec:
                logging.error("OK: 'dates' is not in a pydate result record.")
                continue
            dlen = len(rec["dates"])
            if rec["text"] not in content:
                logging.debug("'%s' is not in %s", rec["text"], content)
                continue
            if dlen > 1:
                for date, day_first, year_first in rec["dates"]:
                    try:
                        dobj = AmbiguousDate(date, doc, day_first, year_first, rec["prep"], key)
                    except ValueError as ex:
                        logging.error(ex)
                        continue
                    self._session.add(dobj)
            elif dlen == 1:
                for date, day_first, year_first in rec["dates"]:
                    dobj = CertainDate(date, doc, key)
                    self._session.add(dobj)
            else:
                logging.error("'dates' in a pydate result set contains no records.")

        # Process links
        for link in html.findAll('a'):
            if not link.has_attr("href"):
                logging.debug("skipping %s: no href", link)
                continue

            process = True 
            for node in link.findAll(text=True):
                if node not in worker_req_thread.result:
                    process = False 
                    break 
            
            if not process:
                logging.debug("skipping %s because it's not in the body text", link)
                continue

            href, junk, junk = link["href"].partition("#")
            if "http://" in href:
                try:

                    domain_id = None 
                    domain_key = self.dc.get_Domain_key(href)
                    while domain_id is None:
                        domain_id = self.drw.get_domain(domain_key)

                    assert domain_id is not None
                    href_domain = self._session.query(Domain).get(domain_id)
                except ValueError as ex:
                    logging.error(ex)
                    logging.error("Skipping this link")
                    continue
                href_path   = self.ac.get_path_fromurl(href)
                lnk = AbsoluteLink(doc, href_domain, href_path)
                self._session.add(lnk)
                logging.debug("Adding: %s", lnk)
            else:
                href_path  = href 
                try:
                    lnk = RelativeLink(doc, href_path)
                except ValueError as ex:
                    logging.error(ex)
                    logging.error("Skipping link")
                    continue
                self._session.add(lnk)
                logging.debug("Adding: %s", lnk)

        # Construct software involvment records
        self_sir = SoftwareInvolvementRecord(self.swc.get_SoftwareVersion_fromstr(self.__VERSION__), "Processed", doc)
        date_sir = SoftwareInvolvementRecord(self.swc.get_SoftwareVersion_fromstr(pydate.__VERSION__), "Dated", doc)
        clas_sir = SoftwareInvolvementRecord(self.swc.get_SoftwareVersion_fromstr(pysen.__VERSION__), "Classified", doc)
        extr_sir = SoftwareInvolvementRecord(self.swc.get_SoftwareVersion_fromstr(worker_req_thread.version), "Extracted", doc)

        for sw in [self_sir, date_sir, clas_sir, extr_sir]:
            self._session.merge(sw, load=True)

        logging.debug("Domain: %s", domain)
        logging.debug("Path: %s", path)
        article.status = status

        # Commit to database, return True on success
        try:
            self._session.commit()
        except OperationalError as ex:
            logging.error(ex)
            self._session.rollback()
            return None

        return article.id

    def finalize(self):
        self._session.commit()
Code example #9
def test_export(session: Session) -> None:
    """
    Test exporting a dataset.
    """
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.datasets.commands.export import ExportDatasetsCommand
    from superset.models.core import Database

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    session.add(database)
    session.flush()

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
        TableColumn(column_name="user_id", type="INTEGER"),
        TableColumn(column_name="revenue", type="INTEGER"),
        TableColumn(column_name="expenses", type="INTEGER"),
        TableColumn(
            column_name="profit",
            type="INTEGER",
            expression="revenue-expenses",
            extra=json.dumps({"certified_by": "User"}),
        ),
    ]
    metrics = [
        SqlMetric(
            metric_name="cnt",
            expression="COUNT(*)",
            extra=json.dumps({"warning_markdown": None}),
        ),
    ]

    sqla_table = SqlaTable(
        table_name="my_table",
        columns=columns,
        metrics=metrics,
        main_dttm_col="ds",
        database=database,
        offset=-8,
        description="This is the description",
        is_featured=1,
        cache_timeout=3600,
        schema="my_schema",
        sql=None,
        params=json.dumps(
            {
                "remote_id": 64,
                "database_name": "examples",
                "import_time": 1606677834,
            }
        ),
        perm=None,
        filter_select_enabled=1,
        fetch_values_predicate="foo IN (1, 2)",
        is_sqllab_view=0,  # no longer used?
        template_params=json.dumps({"answer": "42"}),
        schema_perm=None,
        extra=json.dumps({"warning_markdown": "*WARNING*"}),
    )

    export = list(
        ExportDatasetsCommand._export(sqla_table)  # pylint: disable=protected-access
    )
    assert export == [
        (
            "datasets/my_database/my_table.yaml",
            f"""table_name: my_table
main_dttm_col: ds
description: This is the description
default_endpoint: null
offset: -8
cache_timeout: 3600
schema: my_schema
sql: null
params:
  remote_id: 64
  database_name: examples
  import_time: 1606677834
template_params:
  answer: '42'
filter_select_enabled: 1
fetch_values_predicate: foo IN (1, 2)
extra:
  warning_markdown: '*WARNING*'
uuid: null
metrics:
- metric_name: cnt
  verbose_name: null
  metric_type: null
  expression: COUNT(*)
  description: null
  d3format: null
  extra:
    warning_markdown: null
  warning_text: null
columns:
- column_name: profit
  verbose_name: null
  is_dttm: null
  is_active: null
  type: INTEGER
  advanced_data_type: null
  groupby: null
  filterable: null
  expression: revenue-expenses
  description: null
  python_date_format: null
  extra:
    certified_by: User
- column_name: ds
  verbose_name: null
  is_dttm: 1
  is_active: null
  type: TIMESTAMP
  advanced_data_type: null
  groupby: null
  filterable: null
  expression: null
  description: null
  python_date_format: null
  extra: null
- column_name: user_id
  verbose_name: null
  is_dttm: null
  is_active: null
  type: INTEGER
  advanced_data_type: null
  groupby: null
  filterable: null
  expression: null
  description: null
  python_date_format: null
  extra: null
- column_name: expenses
  verbose_name: null
  is_dttm: null
  is_active: null
  type: INTEGER
  advanced_data_type: null
  groupby: null
  filterable: null
  expression: null
  description: null
  python_date_format: null
  extra: null
- column_name: revenue
  verbose_name: null
  is_dttm: null
  is_active: null
  type: INTEGER
  advanced_data_type: null
  groupby: null
  filterable: null
  expression: null
  description: null
  python_date_format: null
  extra: null
version: 1.0.0
database_uuid: {database.uuid}
""",
        ),
        (
            "databases/my_database.yaml",
            f"""database_name: my_database
sqlalchemy_uri: sqlite://
cache_timeout: null
expose_in_sqllab: true
allow_run_async: false
allow_ctas: false
allow_cvas: false
allow_file_upload: false
extra:
  metadata_params: {{}}
  engine_params: {{}}
  metadata_cache_timeout: {{}}
  schemas_allowed_for_file_upload: []
uuid: {database.uuid}
version: 1.0.0
""",
        ),
    ]
Code example #10
File: db_helper.py Project: Fathui/fsqlfly
    def _clean(cls, obj: DBT, session: Session, base: Type[DBT]):
        back = obj.as_dict()
        session.delete(obj)
        session.commit()
        session.add(base(**back))
        session.commit()
Code example #11
    def add(self, instance, **kwargs):
        SessionBase.add(self, instance, **kwargs)
        return instance
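A short usage note: because the override hands the instance back, a caller can create, register, and keep a reference to an object in one expression. A hedged sketch of the same idea as a standalone subclass (the `User` model and engine setup are assumptions, not part of the original snippet):

from sqlalchemy.orm import Session as SessionBase

class ChainingSession(SessionBase):
    def add(self, instance, **kwargs):
        # Register the object with the unit of work, then return it so the
        # caller can keep a handle without a separate assignment line.
        super().add(instance, **kwargs)
        return instance

# usage, assuming a mapped User class and a configured engine:
# session = ChainingSession(bind=engine)
# user = session.add(User(name="alice"))
# session.commit()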
Code example #12
def test_import_dataset(app_context: None, session: Session) -> None:
    """
    Test importing a dataset.
    """
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.datasets.commands.importers.v1.utils import import_dataset
    from superset.datasets.schemas import ImportV1DatasetSchema
    from superset.models.core import Database

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    database = Database(database_name="my_database",
                        sqlalchemy_uri="sqlite://")
    session.add(database)
    session.flush()

    dataset_uuid = uuid.uuid4()
    config = {
        "table_name": "my_table",
        "main_dttm_col": "ds",
        "description": "This is the description",
        "default_endpoint": None,
        "offset": -8,
        "cache_timeout": 3600,
        "schema": "my_schema",
        "sql": None,
        "params": {
            "remote_id": 64,
            "database_name": "examples",
            "import_time": 1606677834,
        },
        "template_params": {
            "answer": "42",
        },
        "filter_select_enabled": True,
        "fetch_values_predicate": "foo IN (1, 2)",
        "extra": {"warning_markdown": "*WARNING*"},
        "uuid": dataset_uuid,
        "metrics": [
            {
                "metric_name": "cnt",
                "verbose_name": None,
                "metric_type": None,
                "expression": "COUNT(*)",
                "description": None,
                "d3format": None,
                "extra": {"warning_markdown": None},
                "warning_text": None,
            }
        ],
        "columns": [
            {
                "column_name": "profit",
                "verbose_name": None,
                "is_dttm": None,
                "is_active": None,
                "type": "INTEGER",
                "groupby": None,
                "filterable": None,
                "expression": "revenue-expenses",
                "description": None,
                "python_date_format": None,
                "extra": {"certified_by": "User"},
            }
        ],
        "database_uuid": database.uuid,
        "database_id": database.id,
    }

    sqla_table = import_dataset(session, config)
    assert sqla_table.table_name == "my_table"
    assert sqla_table.main_dttm_col == "ds"
    assert sqla_table.description == "This is the description"
    assert sqla_table.default_endpoint is None
    assert sqla_table.offset == -8
    assert sqla_table.cache_timeout == 3600
    assert sqla_table.schema == "my_schema"
    assert sqla_table.sql is None
    assert sqla_table.params == json.dumps({
        "remote_id": 64,
        "database_name": "examples",
        "import_time": 1606677834
    })
    assert sqla_table.template_params == json.dumps({"answer": "42"})
    assert sqla_table.filter_select_enabled is True
    assert sqla_table.fetch_values_predicate == "foo IN (1, 2)"
    assert sqla_table.extra == '{"warning_markdown": "*WARNING*"}'
    assert sqla_table.uuid == dataset_uuid
    assert len(sqla_table.metrics) == 1
    assert sqla_table.metrics[0].metric_name == "cnt"
    assert sqla_table.metrics[0].verbose_name is None
    assert sqla_table.metrics[0].metric_type is None
    assert sqla_table.metrics[0].expression == "COUNT(*)"
    assert sqla_table.metrics[0].description is None
    assert sqla_table.metrics[0].d3format is None
    assert sqla_table.metrics[0].extra == '{"warning_markdown": null}'
    assert sqla_table.metrics[0].warning_text is None
    assert len(sqla_table.columns) == 1
    assert sqla_table.columns[0].column_name == "profit"
    assert sqla_table.columns[0].verbose_name is None
    assert sqla_table.columns[0].is_dttm is None
    assert sqla_table.columns[0].is_active is None
    assert sqla_table.columns[0].type == "INTEGER"
    assert sqla_table.columns[0].groupby is None
    assert sqla_table.columns[0].filterable is None
    assert sqla_table.columns[0].expression == "revenue-expenses"
    assert sqla_table.columns[0].description is None
    assert sqla_table.columns[0].python_date_format is None
    assert sqla_table.columns[0].extra == '{"certified_by": "User"}'
    assert sqla_table.database.uuid == database.uuid
    assert sqla_table.database.id == database.id
Code example #13
###############################################################################
#                              scoped_session
# The scoped_session object by default uses threading.local() as storage
###############################################################################

# initial
session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory)

# only keep one session instance for each thread
session_1 = Session()
session_2 = Session()
assert session_1 is session_2    # True

# release and recreate session instance
Session.remove()
session_3 = Session()
assert session_3 is not session_1    # Session.remove() discarded the previous instance

# as a proxy
Session.query()
Session.add()
Session.commit()


# bind to a request scope
def get_current_request():
    # return a scope key, for example per coroutine, thread or request
    pass
Session = scoped_session(session_factory, scopefunc=get_current_request)
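The `get_current_request` stub above only hints at the idea: whatever scopefunc returns is used as the registry key, so all code running under the same key shares one Session. A minimal sketch using contextvars (the variable name and the request-id handling are assumptions, not part of the original snippet):

import contextvars
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, scoped_session

# One token per logical request/task; request-handling code sets it on entry.
_current_request_id = contextvars.ContextVar("request_id", default=None)

def current_request_scope():
    # scoped_session calls this to pick the registry slot for the caller,
    # so every call made under the same request id gets the same Session.
    return _current_request_id.get()

engine = create_engine("sqlite://")
session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory, scopefunc=current_request_scope)

# at the start of a request: _current_request_id.set(unique_request_id)
# at the end of the request: Session.remove()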
Code example #14
File: models.py Project: Scotti18/Recipes
def store_new_user(username, firstname, surname, pw_hash, email):
    session = Session()
    new_user = User(username, firstname, surname, pw_hash, email)
    session.add(new_user)
    session.commit()
    session.close()
Code example #15
# The Query object features a generative interface whereby successive calls return a new Query
# object, a copy of the former with additional criteria and options associated
# with it.
#     Query objects are normally initially generated using the query() method
# of Session. 
#     query() takes a variable number of arguments, any combination of mapped  
# class, a Mapper object, an orm-enabled descriptor, or an AliasedClass object.
###############################################################################

# basic query #############################################

# select
query = Query(User)
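
# generative chaining: each call below returns a *new* Query, a copy of the
# previous one with the extra criteria attached, so the original is untouched
# (illustrative sketch; User.salary, User.department and User.id are the
#  columns assumed by the examples below)
q1 = Query(User).filter(User.salary > 1000)
q2 = q1.filter(User.department == 'engineering').order_by(User.id)
# q1 still filters on salary only; q2 carries both filters plus the ordering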

# insert
session.add(User(id=1))

# aggregation
(Query(User.department,
       func.sum(User.salary).label('salary'),
       func.count('*').label('total_number'))
 .group_by(User.department))

# case
Query(User.id,
      case([(User.salary < 1000, 1),
            (User.salary.between(1000,2000), 2)],
           else_=0).label('salary'))


# join #####################################################
Code example #16
def export_sqlite(database_path: str, api_type, datas):
    metadata_directory = os.path.dirname(database_path)
    os.makedirs(metadata_directory, exist_ok=True)
    database_name = os.path.basename(database_path).replace(".db", "")
    cwd = os.getcwd()
    alembic_location = os.path.join(cwd, "database", "databases",
                                    database_name.lower())
    db_helper.run_migrations(alembic_location, database_path)
    Session, engine = db_helper.create_database_session(database_path)
    db_collection = db_helper.database_collection()
    database = db_collection.database_picker(database_name)
    if not database:
        return
    database_session = Session()
    api_table = database.table_picker(api_type)
    if not api_table:
        return
    for post in datas:
        post_id = post["post_id"]
        postedAt = post["postedAt"]
        date_object = None
        if postedAt:
            if not isinstance(postedAt, datetime):
                date_object = datetime.strptime(postedAt, "%d-%m-%Y %H:%M:%S")
            else:
                date_object = postedAt
        result = database_session.query(api_table)
        post_db = result.filter_by(post_id=post_id).first()
        if not post_db:
            post_db = api_table()
        if api_type == "Messages":
            post_db.user_id = post["user_id"]
        post_db.post_id = post_id
        post_db.text = post["text"]
        if post["price"] is None:
            post["price"] = 0
        post_db.price = post["price"]
        post_db.paid = post["paid"]
        post_db.archived = post["archived"]
        if date_object:
            post_db.created_at = date_object
        database_session.add(post_db)
        for media in post["medias"]:
            if media["media_type"] == "Texts":
                continue
            created_at = media["created_at"]
            if not isinstance(created_at, datetime):
                date_object = datetime.strptime(created_at,
                                                "%d-%m-%Y %H:%M:%S")
            else:
                date_object = created_at
            media_id = media.get("media_id", None)
            result = database_session.query(database.media_table)
            media_db = result.filter_by(media_id=media_id).first()
            if not media_db:
                media_db = result.filter_by(filename=media["filename"],
                                            created_at=date_object).first()
                if not media_db:
                    media_db = database.media_table()
            media_db.media_id = media_id
            media_db.post_id = post_id
            if "_sa_instance_state" in post:
                media_db.size = media["size"]
                media_db.downloaded = media["downloaded"]
            media_db.link = media["links"][0]
            media_db.preview = media.get("preview", False)
            media_db.directory = media["directory"]
            media_db.filename = media["filename"]
            media_db.api_type = api_type
            media_db.media_type = media["media_type"]
            media_db.linked = media.get("linked", None)
            if date_object:
                media_db.created_at = date_object
            database_session.add(media_db)
    database_session.commit()
    database_session.close()
    return Session, api_type, database
Code example #17
    def _add_index_entry(self, digest: Digest, session: SessionType) -> None:
        """ Helper method to add an index entry """
        session.add(
            IndexEntry(digest_hash=digest.hash,
                       digest_size_bytes=digest.size_bytes,
                       accessed_timestamp=datetime.utcnow()))
Code example #18
def export_sqlite2(archive_path, datas, parent_type, legacy_fixer=False):
    metadata_directory = os.path.dirname(archive_path)
    os.makedirs(metadata_directory, exist_ok=True)
    cwd = os.getcwd()
    api_type: str = os.path.basename(archive_path).removesuffix(".db")
    database_path = archive_path
    database_name = parent_type if parent_type else api_type
    database_name = database_name.lower()
    db_collection = db_helper.database_collection()
    database = db_collection.database_picker(database_name)
    if not database:
        return
    alembic_location = os.path.join(cwd, "database", "databases",
                                    database_name)
    database_exists = os.path.exists(database_path)
    if database_exists:
        if os.path.getsize(database_path) == 0:
            os.remove(database_path)
            database_exists = False
    if not legacy_fixer:
        legacy_database_fixer(database_path, database, database_name,
                              database_exists)
    db_helper.run_migrations(alembic_location, database_path)
    Session, engine = db_helper.create_database_session(database_path)
    database_session = Session()
    api_table = database.api_table
    media_table = database.media_table

    for post in datas:
        post_id = post["post_id"]
        postedAt = post["postedAt"]
        date_object = None
        if postedAt:
            if not isinstance(postedAt, datetime):
                date_object = datetime.strptime(postedAt, "%d-%m-%Y %H:%M:%S")
            else:
                date_object = postedAt
        result = database_session.query(api_table)
        post_db = result.filter_by(post_id=post_id).first()
        if not post_db:
            post_db = api_table()
        post_db.post_id = post_id
        post_db.text = post["text"]
        if post["price"] is None:
            post["price"] = 0
        post_db.price = post["price"]
        post_db.paid = post["paid"]
        post_db.archived = post["archived"]
        if date_object:
            post_db.created_at = date_object
        database_session.add(post_db)
        for media in post["medias"]:
            if media["media_type"] == "Texts":
                continue
            media_id = media.get("media_id", None)
            result = database_session.query(media_table)
            media_db = result.filter_by(media_id=media_id).first()
            if not media_db:
                media_db = result.filter_by(filename=media["filename"],
                                            created_at=date_object).first()
                if not media_db:
                    media_db = media_table()
            if legacy_fixer:
                media_db.size = media["size"]
                media_db.downloaded = media["downloaded"]
            media_db.media_id = media_id
            media_db.post_id = post_id
            media_db.link = media["links"][0]
            media_db.preview = media.get("preview", False)
            media_db.directory = media["directory"]
            media_db.filename = media["filename"]
            media_db.api_type = api_type
            media_db.media_type = media["media_type"]
            media_db.linked = media.get("linked", None)
            if date_object:
                media_db.created_at = date_object
            database_session.add(media_db)

    database_session.commit()
    database_session.close()
    return Session, api_type, database
Code example #19
def compile_hourly_statistics(
    instance: Recorder, session: Session, start: datetime
) -> None:
    """Compile hourly statistics.

    This will summarize 5-minute statistics for one hour:
    - average, min max is computed by a database query
    - sum is taken from the last 5-minute entry during the hour
    """
    start_time = start.replace(minute=0)
    end_time = start_time + timedelta(hours=1)

    # Compute last hour's average, min, max
    summary: dict[str, StatisticData] = {}
    baked_query = instance.hass.data[STATISTICS_BAKERY](
        lambda session: session.query(*QUERY_STATISTICS_SUMMARY_MEAN)
    )

    baked_query += lambda q: q.filter(
        StatisticsShortTerm.start >= bindparam("start_time")
    )
    baked_query += lambda q: q.filter(StatisticsShortTerm.start < bindparam("end_time"))
    baked_query += lambda q: q.group_by(StatisticsShortTerm.metadata_id)
    baked_query += lambda q: q.order_by(StatisticsShortTerm.metadata_id)

    stats = execute(
        baked_query(session).params(start_time=start_time, end_time=end_time)
    )

    if stats:
        for stat in stats:
            metadata_id, _mean, _min, _max = stat
            summary[metadata_id] = {
                "start": start_time,
                "mean": _mean,
                "min": _min,
                "max": _max,
            }

    # Get last hour's last sum
    if instance._db_supports_row_number:  # pylint: disable=[protected-access]
        subquery = (
            session.query(*QUERY_STATISTICS_SUMMARY_SUM)
            .filter(StatisticsShortTerm.start >= bindparam("start_time"))
            .filter(StatisticsShortTerm.start < bindparam("end_time"))
            .subquery()
        )
        query = (
            session.query(subquery)
            .filter(subquery.c.rownum == 1)
            .order_by(subquery.c.metadata_id)
        )
        stats = execute(query.params(start_time=start_time, end_time=end_time))

        if stats:
            for stat in stats:
                metadata_id, start, last_reset, state, _sum, _ = stat
                if metadata_id in summary:
                    summary[metadata_id].update(
                        {
                            "last_reset": process_timestamp(last_reset),
                            "state": state,
                            "sum": _sum,
                        }
                    )
                else:
                    summary[metadata_id] = {
                        "start": start_time,
                        "last_reset": process_timestamp(last_reset),
                        "state": state,
                        "sum": _sum,
                    }
    else:
        baked_query = instance.hass.data[STATISTICS_BAKERY](
            lambda session: session.query(*QUERY_STATISTICS_SUMMARY_SUM_LEGACY)
        )

        baked_query += lambda q: q.filter(
            StatisticsShortTerm.start >= bindparam("start_time")
        )
        baked_query += lambda q: q.filter(
            StatisticsShortTerm.start < bindparam("end_time")
        )
        baked_query += lambda q: q.order_by(
            StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc()
        )

        stats = execute(
            baked_query(session).params(start_time=start_time, end_time=end_time)
        )

        if stats:
            for metadata_id, group in groupby(stats, lambda stat: stat["metadata_id"]):  # type: ignore[no-any-return]
                (
                    metadata_id,
                    last_reset,
                    state,
                    _sum,
                ) = next(group)
                if metadata_id in summary:
                    summary[metadata_id].update(
                        {
                            "start": start_time,
                            "last_reset": process_timestamp(last_reset),
                            "state": state,
                            "sum": _sum,
                        }
                    )
                else:
                    summary[metadata_id] = {
                        "start": start_time,
                        "last_reset": process_timestamp(last_reset),
                        "state": state,
                        "sum": _sum,
                    }

    # Insert compiled hourly statistics in the database
    for metadata_id, stat in summary.items():
        session.add(Statistics.from_stats(metadata_id, stat))
Code example #20
File: types.py Project: SmartTeleMax/iktomi
class TypeDecoratorsTest(unittest.TestCase):

    def setUp(self):
        self.engine = create_engine("sqlite://")
        Base.metadata.create_all(self.engine)
        self.db = Session(bind=self.engine)

    def tearDown(self):
        self.db.query(TypesObject).delete()
        self.db.commit()
        self.db.close()

    def test_string_list(self):
        words = ['one', 'two', 'three', 'four', 'five']
        obj = TypesObject()
        obj.words = words
        self.db.add(obj)
        self.db.commit()

        self.db.close()
        self.db = Session(bind=self.engine)
        obj = self.db.query(TypesObject).first()
        self.assertEqual(words, obj.words)

    def test_integer_list(self):
        numbers = [1, 5, 10, 15, 20]
        obj = TypesObject()
        obj.numbers = numbers
        self.db.add(obj)
        self.db.commit()

        self.db.close()
        self.db = Session(bind=self.engine)
        obj = self.db.query(TypesObject).first()
        self.assertEqual(numbers, obj.numbers)

    def test_string_wrapped_in_html(self):
        obj = TypesObject()
        obj.html_string1 = Markupable('<html>value</html>')
        self.db.add(obj)
        self.db.commit()
        self.db.close()

        self.db = Session(bind=self.engine)
        obj = self.db.query(TypesObject).first()
        self.assertIsInstance(obj.html_string1, Markup)
        self.assertEqual('<html>value</html>', obj.html_string1)

    def test_html_string(self):
        obj = TypesObject()
        obj.html_string2 = Markupable('<html>value</html>')
        self.db.add(obj)
        self.db.commit()
        self.db.close()

        self.db = Session(bind=self.engine)
        obj = self.db.query(TypesObject).first()
        self.assertIsInstance(obj.html_string2, Markup)
        self.assertEqual('<html>value</html>', obj.html_string2)

    def test_html_text(self):
        obj = TypesObject()
        text = "<html>" + "the sample_text " * 100 + "</html>"
        obj.html_text = Markupable(text)
        self.db.add(obj)
        self.db.commit()
        self.db.close()

        self.db = Session(bind=self.engine)
        obj = self.db.query(TypesObject).first()
        self.assertIsInstance(obj.html_text, Markup)
        self.assertEqual(text, obj.html_text)

    def test_html_custom_markup(self):
        obj = TypesObject()
        obj.html_custom = Markupable('<html>   value   </html>')
        self.db.add(obj)
        self.db.commit()
        self.db.close()

        self.db = Session(bind=self.engine)
        obj = self.db.query(TypesObject).first()
        self.assertIsInstance(obj.html_custom, CustomMarkup)
        self.assertEqual('<html> value </html>', obj.html_custom)
Code example #21
def user_library_state_update(
    self,
    update_task: DatabaseTask,
    session: Session,
    user_library_factory_txs,
    block_number,
    block_timestamp,
    block_hash,
    _ipfs_metadata,  # prefix unused args with underscore to prevent pylint
    _blacklisted_cids,
) -> Tuple[int, Set]:
    """Return Tuple containing int representing number of User Library model state changes found in transaction and empty Set (to align with fn signature of other _state_update functions."""
    empty_set: Set[int] = set()
    num_total_changes = 0
    if not user_library_factory_txs:
        return num_total_changes, empty_set

    challenge_bus = update_task.challenge_event_bus
    block_datetime = datetime.utcfromtimestamp(block_timestamp)

    track_save_state_changes: Dict[int, Dict[int, Save]] = {}
    playlist_save_state_changes: Dict[int, Dict[int, Save]] = {}

    for tx_receipt in user_library_factory_txs:
        try:
            add_track_save(
                self,
                update_task.user_library_contract,
                update_task,
                session,
                tx_receipt,
                block_number,
                block_datetime,
                track_save_state_changes,
            )

            add_playlist_save(
                self,
                update_task.user_library_contract,
                update_task,
                session,
                tx_receipt,
                block_number,
                block_datetime,
                playlist_save_state_changes,
            )

            delete_track_save(
                self,
                update_task.user_library_contract,
                update_task,
                session,
                tx_receipt,
                block_number,
                block_datetime,
                track_save_state_changes,
            )

            delete_playlist_save(
                self,
                update_task.user_library_contract,
                update_task,
                session,
                tx_receipt,
                block_number,
                block_datetime,
                playlist_save_state_changes,
            )
        except Exception as e:
            logger.info("Error in user library transaction")
            txhash = update_task.web3.toHex(tx_receipt.transactionHash)
            blockhash = update_task.web3.toHex(block_hash)
            raise IndexingError(
                "user_library", block_number, blockhash, txhash, str(e)
            ) from e

    for user_id, track_ids in track_save_state_changes.items():
        for track_id in track_ids:
            invalidate_old_save(session, user_id, track_id, SaveType.track)
            save = track_ids[track_id]
            session.add(save)
            dispatch_favorite(challenge_bus, save, block_number)
        num_total_changes += len(track_ids)

    for user_id, playlist_ids in playlist_save_state_changes.items():
        for playlist_id in playlist_ids:
            invalidate_old_save(
                session,
                user_id,
                playlist_id,
                playlist_ids[playlist_id].save_type,
            )
            save = playlist_ids[playlist_id]
            session.add(save)
            dispatch_favorite(challenge_bus, save, block_number)
        num_total_changes += len(playlist_ids)

    return num_total_changes, empty_set
Code example #22
    def test_create_zero_screamer_chance(self):
        session = Session(connection)
        webm = WEBM(id_="q", screamer_chance=0)
        session.add(webm)
        session.commit()
Code example #23
File: be.py Project: jjhoo/putiikki
class Catalog(object):
    def __init__(self, engine):
        self.engine = engine
        self.session = Session(bind=self.engine)

    @subtransaction
    def add_item(self, code, description, long_description=None):
        citem = models.Item(code=code,
                            description=description,
                            long_description=long_description)
        self.session.add(citem)
        return citem

    @subtransaction
    def add_items(self, items):
        for item in items:
            # KeyErrors not caught if missing required field

            try:
                long_desc = item['long description']
            except KeyError:
                long_desc = None

            citem = models.Item(code=item['code'],
                                description=item['description'],
                                description_lower=item['description'].lower(),
                                long_description=long_desc)

            primary=True
            for cat in item['categories']:
                q = self.session.query(models.Category).\
                  filter(models.Category.name == cat)
                cater = q.first()

                icater = models.ItemCategory(item=citem, category=cater,
                                             primary=primary)
                citem.categories.append(icater)
                primary=False
            self.session.add(citem)

    def add_category(self, name):
        self.add_categories([name])

    @subtransaction
    def add_categories(self, names):
        for name in names:
            cater = models.Category(name=name)
            self.session.add(cater)

    # def remove_category: should make sure that zero use
    # def rename_category

    @subtransaction
    def add_item_category(self, item_id, category_id, primary=False):
        catitem = models.ItemCategory(item_id=item_id, category_id=category_id,
                                      primary=primary)
        self.session.add(catitem)
        return catitem

    # def remove_item_category

    def get_item(self, code, as_object=False):
        q = self.session.query(models.Item).filter(models.Item.code == code)
        item = q.first()
        if item is None:
            return None
        if as_object is True:
            return item
        return (item.id, item.code, item.description, item.long_description)

    def remove_item(self, code):
        self.session.begin(subtransactions=True)
        q = self.session.query(models.Item).filter(models.Item.code == code).\
          delete()
        self.session.commit()

    @subtransaction
    def update_item(self, code, description=None, long_description=None,
                    new_code=None):
        q = self.session.query(models.Item).filter(models.Item.code == code)
        item = q.first()

        if new_code is not None:
            item.code = new_code

        if description is not None:
            item.description = description
            item.description_lower = description.lower()

        if long_description is not None:
            item.long_description = long_description

    def get_stock(self, code, as_object=False):
        q = self.session.query(models.Item, models.StockItem).\
          with_entities(models.StockItem).\
          filter(models.Item.code == code).\
          filter(models.Item.id == models.StockItem.item_id)
        res = q.first()
        if res is None:
            return None

        if as_object is True:
            return res
        return (res.id, res.price, res.count)

    @subtransaction
    def add_stock(self, items):

        for item in items:
            q = self.session.query(models.Item).\
              filter(models.Item.code == item['code'])

            citem = q.first()
            stock = models.StockItem(item=citem, count=item['count'],
                                    price=item['price'])
            self.session.add(stock)

    @subtransaction
    def update_stock(self, code, count, price):
        q = self.session.query(models.Item, models.StockItem).\
              with_entities(models.StockItem).\
              join(models.StockItem.item).\
              filter(models.Item.code == code)

        stock = q.first()
        if stock is None:
            q = self.session.query(models.Item).\
              filter(models.Item.code == code)
            item = q.first()
            if item is not None:
                stock = models.StockItem(item=item, count=count, price=price)
                self.session.add(stock)
            else:
                raise KeyError('Unknown item code')
        else:
            stock.count += count
            stock.price = price

    @subtransaction
    def add_items_with_stock(self, items):
        categories = set()
        for item in items:
            for x in item['categories']:
                categories.add(x)
        self.add_categories(categories)

        self.add_items(items)
        self.add_stock(items)

    def list_items(self, sort_key='description',
                   ascending=True, page=1, page_size=10):
        sq = self.session.query(models.Reservation.stock_item_id,
                                func.sum(models.Reservation.count).\
                                label('reserved')).\
                                group_by(models.Reservation.stock_item_id).\
                                subquery()
        q = self.session.query(models.Item, models.Category,
                               models.ItemCategory, models.StockItem,
                               sq.c.reserved).\
          with_entities(models.Item.code, models.Item.description,
                        models.Category.name, models.StockItem.price,
                        models.StockItem.count, sq.c.reserved).\
          join(models.StockItem).\
          join(models.ItemCategory,
               models.Item.id == models.ItemCategory.item_id).\
          join(models.Category,
               models.Category.id == models.ItemCategory.category_id).\
          filter(models.ItemCategory.primary == True).\
          outerjoin(sq, models.StockItem.id == sq.c.stock_item_id)

        q = ordering(q, ascending, sort_key)
        q = pagination(q, page, page_size)

        res = [item_to_json(*x) for x in q]
        return res
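
ordering, pagination, and item_to_json are module-level helpers that are not part of this excerpt. A minimal sketch of what they might look like, assuming sort_key names one of the selected columns and pages are 1-based (the names and column choices are assumptions, not the original implementation):

def ordering(q, ascending, sort_key):
    # Hypothetical sketch: map the requested sort key to a column and apply ORDER BY.
    columns = {'description': models.Item.description,
               'price': models.StockItem.price,
               'count': models.StockItem.count}
    column = columns.get(sort_key, models.Item.description)
    return q.order_by(column.asc() if ascending else column.desc())

def pagination(q, page, page_size):
    # Hypothetical sketch: 1-based page windowing via OFFSET/LIMIT.
    return q.offset((page - 1) * page_size).limit(page_size)

def item_to_json(code, description, category, price, count, reserved):
    # Hypothetical sketch: flatten one result row into the dict returned by the listing calls.
    return {'code': code, 'description': description, 'category': category,
            'price': price, 'count': count, 'reserved': reserved or 0}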

    def search_items(self, prefix, price_range, sort_key='description',
                     ascending=True, page=1, page_size=10):
        sq = self.session.query(models.Reservation.stock_item_id,
                                func.sum(models.Reservation.count).\
                                label('reserved')).\
                                group_by(models.Reservation.stock_item_id).\
                                subquery()

        q = self.session.query(models.Item, models.Category,
                               models.ItemCategory, models.StockItem,
                               sq.c.reserved).\
          with_entities(models.Item.code, models.Item.description,
                        models.Category.name, models.StockItem.price,
                        models.StockItem.count,
                        sq.c.reserved).\
          join(models.StockItem).\
          join(models.ItemCategory,
               models.Item.id == models.ItemCategory.item_id).\
          join(models.Category,
               models.Category.id == models.ItemCategory.category_id).\
          filter(models.StockItem.price.between(*price_range),
                 models.Item.description_lower.like('{:s}%'.format(prefix.lower())),
                 models.ItemCategory.primary == True).\
          outerjoin(sq, models.StockItem.id == sq.c.stock_item_id)

        q = ordering(q, ascending, sort_key)
        q = pagination(q, page, page_size)

        res = [item_to_json(*x) for x in q]
        return res

    def list_items_by_prices(self, prices, sort_key='price', prefix=None,
                             ascending=True, page=1, page_size=10):
        pgs = pg_cases(prices)
        pg_case = sqla.case(pgs, else_=-1).label('price_group')

        sq = self.session.query(models.Reservation.stock_item_id,
                                func.sum(models.Reservation.count).\
                                label('reserved')).\
                                group_by(models.Reservation.stock_item_id).\
                                subquery()

        q = self.session.query(models.Item, models.Category,
                               models.ItemCategory, models.StockItem,
                               sq.c.reserved).\
                               with_entities(pg_case, models.Item.code,
                                             models.Item.description,
                                             models.Category.name,
                                             models.StockItem.price,
                                             models.StockItem.count,
                                             sq.c.reserved)
        if prefix is not None:
            q = q.filter(models.Item.description_lower.like('{:s}%'.format(prefix.lower())))

        q = q.join(models.StockItem.item).\
          outerjoin(sq, models.StockItem.id == sq.c.stock_item_id).\
          filter(models.ItemCategory.item_id == models.Item.id,
                 models.ItemCategory.category_id == models.Category.id,
                 models.ItemCategory.primary == True,
                 pg_case >= 0)

        q = pg_ordering(q, ascending)
        q = pagination(q, page, page_size)
        def to_dict(x):
            tmp = item_to_json(*x[1:])
            tmp.update({'price_group': x[0]})
            return tmp
        res = [to_dict(x) for x in q]
        return res
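
pg_cases and pg_ordering are likewise defined elsewhere in the module. A sketch under the assumption that prices is a list of (low, high) ranges and that rows are ordered by the computed price group (the real helpers may differ):

def pg_cases(prices):
    # Hypothetical sketch: one CASE branch per (low, high) range, yielding the range's index.
    return [(models.StockItem.price.between(low, high), index)
            for index, (low, high) in enumerate(prices)]

def pg_ordering(q, ascending):
    # Hypothetical sketch: order rows by the computed price_group label, then by price.
    direction = 'ASC' if ascending else 'DESC'
    return q.order_by(sqla.text('price_group ' + direction),
                      models.StockItem.price.asc() if ascending
                      else models.StockItem.price.desc())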

    def _get_reservations(self, stock_id):
        q = self.session.query(func.sum(models.Reservation.count)).\
          filter(models.StockItem.id == stock_id).\
          filter(models.StockItem.id == models.Reservation.stock_item_id)

        res = q.first()
        if res is None or res[0] is None:
            return 0

        return res[0]

    def _get_reservation(self, basket_item_id):
        q = self.session.query(models.Basket, models.BasketItem,
                               models.Reservation).\
          with_entities(models.Reservation).\
          join(models.Reservation.basket_item).\
          filter(models.BasketItem.id == basket_item_id)
        res = q.first()
        return res

    @subtransaction
    def _update_reservation(self, stock, basket_item):
        reservations = self._get_reservations(basket_item.stock_item_id)
        # can reserve at most (stock.count - reservations)
        reservation = self._get_reservation(basket_item.id)

        if reservation is not None:
            rcount = min(basket_item.count,
                         stock.count - reservations + reservation.count)
            reservation.count = rcount
        else:
            rcount = min(basket_item.count, stock.count - reservations)
            reservation = models.Reservation(stock_item=stock,
                                             basket_item=basket_item,
                                             count=rcount)
            self.session.add(reservation)
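
The @subtransaction decorator applied throughout this class is not included in the excerpt. A minimal sketch, assuming it simply wraps the method in a nested (SAVEPOINT) transaction on self.session:

import functools

def subtransaction(method):
    # Hypothetical sketch: run the wrapped method inside a nested transaction
    # so callers can compose several operations atomically.
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        with self.session.begin_nested():
            return method(self, *args, **kwargs)
    return wrapper
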
コード例 #24
0
def test_sql_lab_insert_rls(
    mocker: MockerFixture,
    session: Session,
    app_context: None,
) -> None:
    """
    Integration test for `insert_rls`.
    """
    from flask_appbuilder.security.sqla.models import Role, User

    from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
    from superset.models.core import Database
    from superset.models.sql_lab import Query
    from superset.security.manager import SupersetSecurityManager
    from superset.sql_lab import execute_sql_statement
    from superset.utils.core import RowLevelSecurityFilterType

    engine = session.connection().engine
    Query.metadata.create_all(engine)  # pylint: disable=no-member

    connection = engine.raw_connection()
    connection.execute("CREATE TABLE t (c INTEGER)")
    for i in range(10):
        connection.execute("INSERT INTO t VALUES (?)", (i, ))

    cursor = connection.cursor()

    query = Query(
        sql="SELECT c FROM t",
        client_id="abcde",
        database=Database(database_name="test_db", sqlalchemy_uri="sqlite://"),
        schema=None,
        limit=5,
        select_as_cta_used=False,
    )
    session.add(query)
    session.commit()

    admin = User(
        first_name="Alice",
        last_name="Doe",
        email="*****@*****.**",
        username="******",
        roles=[Role(name="Admin")],
    )

    # first without RLS
    with override_user(admin):
        superset_result_set = execute_sql_statement(
            sql_statement=query.sql,
            query=query,
            session=session,
            cursor=cursor,
            log_params=None,
            apply_ctas=False,
        )
    assert (superset_result_set.to_pandas_df().to_markdown() == """
|    |   c |
|---:|----:|
|  0 |   0 |
|  1 |   1 |
|  2 |   2 |
|  3 |   3 |
|  4 |   4 |""".strip())
    assert query.executed_sql == "SELECT c FROM t\nLIMIT 6"

    # now with RLS
    rls = RowLevelSecurityFilter(
        name="sqllab_rls1",
        filter_type=RowLevelSecurityFilterType.REGULAR,
        tables=[SqlaTable(database_id=1, schema=None, table_name="t")],
        roles=[admin.roles[0]],
        group_key=None,
        clause="c > 5",
    )
    session.add(rls)
    session.flush()
    mocker.patch.object(SupersetSecurityManager,
                        "find_user",
                        return_value=admin)
    mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)

    with override_user(admin):
        superset_result_set = execute_sql_statement(
            sql_statement=query.sql,
            query=query,
            session=session,
            cursor=cursor,
            log_params=None,
            apply_ctas=False,
        )
    assert (superset_result_set.to_pandas_df().to_markdown() == """
|    |   c |
|---:|----:|
|  0 |   6 |
|  1 |   7 |
|  2 |   8 |
|  3 |   9 |""".strip())
    assert query.executed_sql == "SELECT c FROM t WHERE (t.c > 5)\nLIMIT 6"
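
The override_user helper used above is imported elsewhere in the Superset test module. A rough sketch, assuming it only swaps Flask's g.user for the duration of the block (the real helper may do more):

from contextlib import contextmanager

from flask import g

@contextmanager
def override_user(user):
    # Hypothetical sketch: temporarily act as the given user, restoring the
    # previous value even if the body raises.
    previous = getattr(g, "user", None)
    g.user = user
    try:
        yield
    finally:
        g.user = previous
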
コード例 #25
0
class DatabaseTest(object):

    engine = None
    connection = None

    @classmethod
    def get_database_connection(cls):
        url = Configuration.database_url(test=True)
        engine, connection = SessionManager.initialize(url)

        return engine, connection

    @classmethod
    def setup_class(cls):
        # Initialize a temporary data directory.
        cls.engine, cls.connection = cls.get_database_connection()
        cls.old_data_dir = Configuration.data_directory
        cls.tmp_data_dir = tempfile.mkdtemp(dir="/tmp")
        Configuration.instance[Configuration.DATA_DIRECTORY] = cls.tmp_data_dir

        # Avoid CannotLoadConfiguration errors related to CDN integrations.
        Configuration.instance[Configuration.INTEGRATIONS] = Configuration.instance.get(
            Configuration.INTEGRATIONS, {}
        )
        Configuration.instance[Configuration.INTEGRATIONS][ExternalIntegration.CDN] = {}

        os.environ['TESTING'] = 'true'
        
    @classmethod
    def teardown_class(cls):
        # Destroy the database connection and engine.
        cls.connection.close()
        cls.engine.dispose()

        if cls.tmp_data_dir.startswith("/tmp"):
            logging.debug("Removing temporary directory %s" % cls.tmp_data_dir)
            shutil.rmtree(cls.tmp_data_dir)

        else:
            logging.warn("Cowardly refusing to remove 'temporary' directory %s" % cls.tmp_data_dir)

        Configuration.instance[Configuration.DATA_DIRECTORY] = cls.old_data_dir
        if 'TESTING' in os.environ:
            del os.environ['TESTING']

    def setup(self, mock_search=True):
        # Create a new connection to the database.
        self._db = Session(self.connection)
        self.transaction = self.connection.begin_nested()
        
        # Start with a high number so it won't interfere with tests that search for an age or grade
        self.counter = 2000

        self.time_counter = datetime(2014, 1, 1)
        self.isbns = [
            "9780674368279", "0636920028468", "9781936460236", "9780316075978"
        ]
        if mock_search:
            self.search_mock = mock.patch(external_search.__name__ + ".ExternalSearchIndex", DummyExternalSearchIndex)
            self.search_mock.start()
        else:
            self.search_mock = None

        # TODO: keeping this for now, but it needs fixing: it hits _isbn,
        # which pops an ISBN off the list and breaks other tests, so
        # underscore-prefixed helpers should be excluded from participating.
        # It also attempts to stop nosetests from showing docstrings instead
        # of function names.
        #for name, obj in inspect.getmembers(self):
        #    if inspect.isfunction(obj) and obj.__name__.startswith('test_'):
        #        obj.__doc__ = None


    def teardown(self):
        # Close the session.
        self._db.close()

        # Roll back all database changes that happened during this
        # test, whether in the session that was just closed or some
        # other session.
        self.transaction.rollback()

        # Remove any database objects cached in the model classes but
        # associated with the now-rolled-back session.
        Collection.reset_cache()
        ConfigurationSetting.reset_cache()
        DataSource.reset_cache()
        DeliveryMechanism.reset_cache()
        ExternalIntegration.reset_cache()
        Genre.reset_cache()
        Library.reset_cache()
        
        # Also roll back any record of those changes in the
        # Configuration instance.
        for key in [
                Configuration.SITE_CONFIGURATION_LAST_UPDATE,
                Configuration.LAST_CHECKED_FOR_SITE_CONFIGURATION_UPDATE
        ]:
            if key in Configuration.instance:
                del(Configuration.instance[key])

        if self.search_mock:
            self.search_mock.stop()

    def shortDescription(self):
        return None # Stop nosetests displaying docstrings instead of class names when verbosity level >= 2.

    @property
    def _id(self):
        self.counter += 1
        return self.counter

    @property
    def _str(self):
        return unicode(self._id)

    @property
    def _time(self):
        v = self.time_counter 
        self.time_counter = self.time_counter + timedelta(days=1)
        return v

    @property
    def _isbn(self):
        return self.isbns.pop()

    @property
    def _url(self):
        return "http://foo.com/" + self._str

    def _patron(self, external_identifier=None, library=None):
        external_identifier = external_identifier or self._str
        library = library or self._default_library
        return get_one_or_create(
            self._db, Patron, external_identifier=external_identifier,
            library=library
        )[0]

    def _contributor(self, sort_name=None, name=None, **kw_args):
        name = sort_name or name or self._str
        return get_one_or_create(self._db, Contributor, sort_name=unicode(name), **kw_args)

    def _identifier(self, identifier_type=Identifier.GUTENBERG_ID, foreign_id=None):
        if foreign_id:
            id = foreign_id
        else:
            id = self._str
        return Identifier.for_foreign_id(self._db, identifier_type, id)[0]

    def _edition(self, data_source_name=DataSource.GUTENBERG,
                 identifier_type=Identifier.GUTENBERG_ID,
                 with_license_pool=False, with_open_access_download=False,
                 title=None, language="eng", authors=None, identifier_id=None,
                 series=None, collection=None
    ):
        id = identifier_id or self._str
        source = DataSource.lookup(self._db, data_source_name)
        wr = Edition.for_foreign_id(
            self._db, source, identifier_type, id)[0]
        if not title:
            title = self._str
        wr.title = unicode(title)
        if series:
            wr.series = series
        if language:
            wr.language = language
        if authors is None:
            authors = self._str
        if isinstance(authors, basestring):
            authors = [authors]
        if authors != []:
            wr.add_contributor(unicode(authors[0]), Contributor.PRIMARY_AUTHOR_ROLE)
            wr.author = unicode(authors[0])
        for author in authors[1:]:
            wr.add_contributor(unicode(author), Contributor.AUTHOR_ROLE)

        if with_license_pool or with_open_access_download:
            pool = self._licensepool(
                wr, data_source_name=data_source_name,
                with_open_access_download=with_open_access_download,
                collection=collection
            )

            pool.set_presentation_edition()
            return wr, pool
        return wr

    def _work(self, title=None, authors=None, genre=None, language=None,
              audience=None, fiction=True, with_license_pool=False, 
              with_open_access_download=False, quality=0.5, series=None,
              presentation_edition=None, collection=None, data_source_name=None):
        """Create a Work.

        For performance reasons, this method does not generate OPDS
        entries or calculate a presentation edition for the new
        Work. Tests that rely on this information being present
        should call _slow_work() instead, which takes more care to present
        the sort of Work that would be created in a real environment.
        """
        pools = []
        if with_open_access_download:
            with_license_pool = True
        language = language or "eng"
        title = unicode(title or self._str)
        audience = audience or Classifier.AUDIENCE_ADULT
        if audience == Classifier.AUDIENCE_CHILDREN and not data_source_name:
            # TODO: This is necessary because Gutenberg's childrens books
            # get filtered out at the moment.
            data_source_name = DataSource.OVERDRIVE
        elif not data_source_name:
            data_source_name = DataSource.GUTENBERG
        if fiction is None:
            fiction = True
        new_edition = False
        if not presentation_edition:
            new_edition = True
            presentation_edition = self._edition(
                title=title, language=language,
                authors=authors,
                with_license_pool=with_license_pool,
                with_open_access_download=with_open_access_download,
                data_source_name=data_source_name,
                series=series,
                collection=collection,
            )
            if with_license_pool:
                presentation_edition, pool = presentation_edition
                if with_open_access_download:
                    pool.open_access = True
                pools = [pool]
        else:
            pools = presentation_edition.license_pools
        work, ignore = get_one_or_create(
            self._db, Work, create_method_kwargs=dict(
                audience=audience,
                fiction=fiction,
                quality=quality), id=self._id)
        if genre:
            if not isinstance(genre, Genre):
                genre, ignore = Genre.lookup(self._db, genre, autocreate=True)
            work.genres = [genre]
        work.random = 0.5
        work.set_presentation_edition(presentation_edition)

        if pools:
            # make sure the pool's presentation_edition is set, 
            # bc loan tests assume that.
            if not work.license_pools:
                for pool in pools:
                    work.license_pools.append(pool)

            for pool in pools:
                pool.set_presentation_edition()

            # This is probably going to be used in an OPDS feed, so
            # fake that the work is presentation ready.
            work.presentation_ready = True
            work.calculate_opds_entries(verbose=False)

        return work

    def add_to_materialized_view(self, works, true_opds=False):
        """Make sure all the works in `works` show up in the materialized view.

        :param true_opds: Generate real OPDS entries for each work,
        rather than faking it.
        """
        if not isinstance(works, list):
            works = [works]
        for work in works:
            if true_opds:
                work.calculate_opds_entries(verbose=False)
            else:
                work.presentation_ready = True
                work.simple_opds_entry = "<entry>an entry</entry>"
        self._db.commit()
        SessionManager.refresh_materialized_views(self._db)

    def _lane(self, display_name=None, library=None, 
              parent=None, genres=None, languages=None,
              fiction=None
    ):
        display_name = display_name or self._str
        library = library or self._default_library
        lane, is_new = get_one_or_create(
            self._db, Lane,
            library=library,
            parent=parent, display_name=display_name,
            create_method_kwargs=dict(fiction=fiction)
        )
        if is_new and parent:
            lane.priority = len(parent.sublanes)-1
        if genres:
            if not isinstance(genres, list):
                genres = [genres]
            for genre in genres:
                if isinstance(genre, basestring):
                    genre, ignore = Genre.lookup(self._db, genre)
                lane.genres.append(genre)
        if languages:
            if not isinstance(languages, list):
                languages = [languages]
            lane.languages = languages
        return lane

    def _slow_work(self, *args, **kwargs):
        """Create a work that closely resembles one that might be found in the
        wild.

        This is significantly slower than _work() but more reliable.
        """
        work = self._work(*args, **kwargs)
        work.calculate_presentation_edition()
        work.calculate_opds_entries(verbose=False)
        return work

    def _add_generic_delivery_mechanism(self, license_pool):
        """Give a license pool a generic non-open-access delivery mechanism."""
        data_source = license_pool.data_source
        identifier = license_pool.identifier
        content_type = Representation.EPUB_MEDIA_TYPE
        drm_scheme = DeliveryMechanism.NO_DRM
        LicensePoolDeliveryMechanism.set(
            data_source, identifier, content_type, drm_scheme,
            RightsStatus.IN_COPYRIGHT
        )

    def _coverage_record(self, edition, coverage_source, operation=None,
        status=CoverageRecord.SUCCESS, collection=None, exception=None,
    ):
        if isinstance(edition, Identifier):
            identifier = edition
        else:
            identifier = edition.primary_identifier
        record, ignore = get_one_or_create(
            self._db, CoverageRecord,
            identifier=identifier,
            data_source=coverage_source,
            operation=operation,
            collection=collection,
            create_method_kwargs = dict(
                timestamp=datetime.utcnow(),
                status=status,
                exception=exception,
            )
        )
        return record

    def _work_coverage_record(self, work, operation=None, 
                              status=CoverageRecord.SUCCESS):
        record, ignore = get_one_or_create(
            self._db, WorkCoverageRecord,
            work=work,
            operation=operation,
            create_method_kwargs = dict(
                timestamp=datetime.utcnow(),
                status=status,
            )
        )
        return record

    def _licensepool(self, edition, open_access=True, 
                     data_source_name=DataSource.GUTENBERG,
                     with_open_access_download=False, 
                     set_edition_as_presentation=False,
                     collection=None):
        source = DataSource.lookup(self._db, data_source_name)
        if not edition:
            edition = self._edition(data_source_name)
        collection = collection or self._default_collection
        pool, ignore = get_one_or_create(
            self._db, LicensePool,
            create_method_kwargs=dict(
                open_access=open_access),
            identifier=edition.primary_identifier,
            data_source=source,
            collection=collection,
            availability_time=datetime.utcnow()
        )

        if set_edition_as_presentation:
            pool.presentation_edition = edition

        if with_open_access_download:
            pool.open_access = True
            url = "http://foo.com/" + self._str
            media_type = Representation.EPUB_MEDIA_TYPE
            link, new = pool.identifier.add_link(
                Hyperlink.OPEN_ACCESS_DOWNLOAD, url,
                source, media_type
            )

            # Add a DeliveryMechanism for this download
            pool.set_delivery_mechanism(
                media_type,
                DeliveryMechanism.NO_DRM,
                RightsStatus.GENERIC_OPEN_ACCESS,
                link.resource,
            )

            representation, is_new = self._representation(
                url, media_type, "Dummy content", mirrored=True)
            link.resource.representation = representation
        else:

            # Add a DeliveryMechanism for this licensepool
            pool.set_delivery_mechanism(
                Representation.EPUB_MEDIA_TYPE,
                DeliveryMechanism.ADOBE_DRM,
                RightsStatus.UNKNOWN,
                None
            )
            pool.licenses_owned = pool.licenses_available = 1

        return pool

    def _representation(self, url=None, media_type=None, content=None,
                        mirrored=False):
        url = url or "http://foo.com/" + self._str
        repr, is_new = get_one_or_create(
            self._db, Representation, url=url)
        repr.media_type = media_type
        if media_type and content:
            repr.content = content
            repr.fetched_at = datetime.utcnow()
            if mirrored:
                repr.mirror_url = "http://foo.com/" + self._str
                repr.mirrored_at = datetime.utcnow()            
        return repr, is_new

    def _customlist(self, foreign_identifier=None, 
                    name=None,
                    data_source_name=DataSource.NYT, num_entries=1,
                    entries_exist_as_works=True
    ):
        data_source = DataSource.lookup(self._db, data_source_name)
        foreign_identifier = foreign_identifier or self._str
        now = datetime.utcnow()
        customlist, ignore = get_one_or_create(
            self._db, CustomList,
            create_method_kwargs=dict(
                created=now,
                updated=now,
                name=name or self._str,
                description=self._str,
                ),
            data_source=data_source,
            foreign_identifier=foreign_identifier
        )

        editions = []
        for i in range(num_entries):
            if entries_exist_as_works:
                work = self._work(with_open_access_download=True)
                edition = work.presentation_edition
            else:
                edition = self._edition(
                    data_source_name, title="Item %s" % i)
                edition.permanent_work_id="Permanent work ID %s" % self._str
            customlist.add_entry(
                edition, "Annotation %s" % i, first_appearance=now)
            editions.append(edition)
        return customlist, editions

    def _complaint(self, license_pool, type, source, detail, resolved=None):
        complaint, is_new = Complaint.register(
            license_pool,
            type,
            source,
            detail,
            resolved
        )
        return complaint

    def _credential(self, data_source_name=DataSource.GUTENBERG,
                    type=None, patron=None):
        data_source = DataSource.lookup(self._db, data_source_name)
        type = type or self._str
        patron = patron or self._patron()
        credential, is_new = Credential.persistent_token_create(
            self._db, data_source, type, patron
        )
        return credential
    
    def _external_integration(self, protocol, goal=None, settings=None,
                              libraries=None, **kwargs
    ):
        integration = None
        if not libraries:
            integration, ignore = get_one_or_create(
                self._db, ExternalIntegration, protocol=protocol, goal=goal
            )
        else:
            if not isinstance(libraries, list):
                libraries = [libraries]

            # Try to find an existing integration for one of the given
            # libraries.
            for library in libraries:
                integration = ExternalIntegration.lookup(
                    self._db, protocol, goal, library=library
                )
                if integration:
                    break

            if not integration:
                # Otherwise, create a brand new integration specifically
                # for the library.
                integration = ExternalIntegration(
                    protocol=protocol, goal=goal,
                )
                integration.libraries.extend(libraries)
                self._db.add(integration)

        for attr, value in kwargs.items():
            setattr(integration, attr, value)

        settings = settings or dict()
        for key, value in settings.items():
            integration.set_setting(key, value)

        return integration

    def _delegated_patron_identifier(
            self, library_uri=None, patron_identifier=None,
            identifier_type=DelegatedPatronIdentifier.ADOBE_ACCOUNT_ID,
            identifier=None
    ):
        """Create a sample DelegatedPatronIdentifier"""
        library_uri = library_uri or self._url
        patron_identifier = patron_identifier or self._str
        if callable(identifier):
            make_id = identifier
        else:
            if not identifier:
                identifier = self._str
            def make_id():
                return identifier
        patron, is_new = DelegatedPatronIdentifier.get_one_or_create(
            self._db, library_uri, patron_identifier, identifier_type,
            make_id
        )
        return patron

    def _sample_ecosystem(self):
        """ Creates an ecosystem of some sample work, pool, edition, and author 
        objects that all know each other. 
        """
        # make some authors
        [bob], ignore = Contributor.lookup(self._db, u"Bitshifter, Bob")
        bob.family_name, bob.display_name = bob.default_names()
        [alice], ignore = Contributor.lookup(self._db, u"Adder, Alice")
        alice.family_name, alice.display_name = alice.default_names()

        edition_std_ebooks, pool_std_ebooks = self._edition(DataSource.STANDARD_EBOOKS, Identifier.URI, 
            with_license_pool=True, with_open_access_download=True, authors=[])
        edition_std_ebooks.title = u"The Standard Ebooks Title"
        edition_std_ebooks.subtitle = u"The Standard Ebooks Subtitle"
        edition_std_ebooks.add_contributor(alice, Contributor.AUTHOR_ROLE)

        edition_git, pool_git = self._edition(DataSource.PROJECT_GITENBERG, Identifier.GUTENBERG_ID, 
            with_license_pool=True, with_open_access_download=True, authors=[])
        edition_git.title = u"The GItenberg Title"
        edition_git.subtitle = u"The GItenberg Subtitle"
        edition_git.add_contributor(bob, Contributor.AUTHOR_ROLE)
        edition_git.add_contributor(alice, Contributor.AUTHOR_ROLE)

        edition_gut, pool_gut = self._edition(DataSource.GUTENBERG, Identifier.GUTENBERG_ID, 
            with_license_pool=True, with_open_access_download=True, authors=[])
        edition_gut.title = u"The GUtenberg Title"
        edition_gut.subtitle = u"The GUtenberg Subtitle"
        edition_gut.add_contributor(bob, Contributor.AUTHOR_ROLE)

        work = self._work(presentation_edition=edition_git)

        for p in pool_gut, pool_std_ebooks:
            work.license_pools.append(p)

        work.calculate_presentation()

        return (work, pool_std_ebooks, pool_git, pool_gut, 
            edition_std_ebooks, edition_git, edition_gut, alice, bob)


    def print_database_instance(self):
        """
        Calls the class method that examines the current state of the database model 
        (whether it's been committed or not).

        NOTE:  If you set_trace, and hit "continue", you'll start seeing console output right 
        away, without waiting for the whole test to run and the standard output section to display.
        You can also use nosetest --nocapture.
        I use:
        def test_name(self):
            [code...]
            set_trace()
            self.print_database_instance()  # TODO: remove before prod
            [code...]
        """
        if 'TESTING' not in os.environ:
            # we are on production, abort, abort!
            logging.warn("Forgot to remove call to testing.py:DatabaseTest.print_database_instance() before pushing to production.")
            return

        DatabaseTest.print_database_class(self._db)
        return


    @classmethod
    def print_database_class(cls, db_connection):
        """
        Prints to the console the entire contents of the database, as the unit test sees it.
        Exists because unit tests don't persist db information: they build an in-memory
        representation of the db state and then roll the test-derived transactions back,
        so we cannot see what's going on by going into postgres and running selects.
        This is the in-test alternative to going into postgres.

        Can be called from model and metadata classes as well as tests.

        NOTE: The purpose of this method is for debugging.  
        Be careful of leaving it in code and potentially outputting 
        vast tracts of data into your output stream on production.

        Call like this:
        set_trace()
        from testing import (
            DatabaseTest, 
        )
        _db = Session.object_session(self)
        DatabaseTest.print_database_class(_db)  # TODO: remove before prod
        """
        if 'TESTING' not in os.environ:
            # we are on production, abort, abort!
            logging.warn("Forgot to remove call to testing.py:DatabaseTest.print_database_class() before pushing to production.")
            return

        works = db_connection.query(Work).all()
        identifiers = db_connection.query(Identifier).all()
        license_pools = db_connection.query(LicensePool).all()
        editions = db_connection.query(Edition).all()
        data_sources = db_connection.query(DataSource).all()
        representations = db_connection.query(Representation).all()

        if (not works):
            print "NO Work found"
        for wCount, work in enumerate(works):
            # pipe character at end of line helps see whitespace issues
            print "Work[%s]=%s|" % (wCount, work)

            if (not work.license_pools):
                print "    NO Work.LicensePool found"
            for lpCount, license_pool in enumerate(work.license_pools):
                print "    Work.LicensePool[%s]=%s|" % (lpCount, license_pool)

            print "    Work.presentation_edition=%s|" % work.presentation_edition

        print "__________________________________________________________________\n"
        if (not identifiers):
            print "NO Identifier found"
        for iCount, identifier in enumerate(identifiers):
            print "Identifier[%s]=%s|" % (iCount, identifier)
            print "    Identifier.licensed_through=%s|" % identifier.licensed_through           

        print "__________________________________________________________________\n"
        if (not license_pools):
            print "NO LicensePool found"
        for index, license_pool in enumerate(license_pools):
            print "LicensePool[%s]=%s|" % (index, license_pool)
            print "    LicensePool.work_id=%s|" % license_pool.work_id
            print "    LicensePool.data_source_id=%s|" % license_pool.data_source_id
            print "    LicensePool.identifier_id=%s|" % license_pool.identifier_id
            print "    LicensePool.presentation_edition_id=%s|" % license_pool.presentation_edition_id            
            print "    LicensePool.superceded=%s|" % license_pool.superceded
            print "    LicensePool.suppressed=%s|" % license_pool.suppressed

        print "__________________________________________________________________\n"
        if (not editions):
            print "NO Edition found"
        for index, edition in enumerate(editions):
            # pipe character at end of line helps see whitespace issues
            print "Edition[%s]=%s|" % (index, edition)
            print "    Edition.primary_identifier_id=%s|" % edition.primary_identifier_id
            print "    Edition.permanent_work_id=%s|" % edition.permanent_work_id
            if (edition.data_source):
                print "    Edition.data_source.id=%s|" % edition.data_source.id
                print "    Edition.data_source.name=%s|" % edition.data_source.name
            else:
                print "    No Edition.data_source."
            if (edition.license_pool):
                print "    Edition.license_pool.id=%s|" % edition.license_pool.id
            else:
                print "    No Edition.license_pool."

            print "    Edition.title=%s|" % edition.title
            print "    Edition.author=%s|" % edition.author
            if (not edition.author_contributors):
                print "    NO Edition.author_contributor found"
            for acCount, author_contributor in enumerate(edition.author_contributors):
                print "    Edition.author_contributor[%s]=%s|" % (acCount, author_contributor)

        print "__________________________________________________________________\n"
        if (not data_sources):
            print "NO DataSource found"
        for index, data_source in enumerate(data_sources):
            print "DataSource[%s]=%s|" % (index, data_source)
            print "    DataSource.id=%s|" % data_source.id
            print "    DataSource.name=%s|" % data_source.name
            print "    DataSource.offers_licenses=%s|" % data_source.offers_licenses
            print "    DataSource.editions=%s|" % data_source.editions            
            print "    DataSource.license_pools=%s|" % data_source.license_pools
            print "    DataSource.links=%s|" % data_source.links

        print "__________________________________________________________________\n"
        if (not representations):
            print "NO Representation found"
        for index, representation in enumerate(representations):
            print "Representation[%s]=%s|" % (index, representation)
            print "    Representation.id=%s|" % representation.id
            print "    Representation.url=%s|" % representation.url
            print "    Representation.mirror_url=%s|" % representation.mirror_url
            print "    Representation.fetch_exception=%s|" % representation.fetch_exception   
            print "    Representation.mirror_exception=%s|" % representation.mirror_exception

        return


    def _library(self, name=None, short_name=None):
        name=name or self._str
        short_name = short_name or self._str
        library, ignore = get_one_or_create(
            self._db, Library, name=name, short_name=short_name,
            create_method_kwargs=dict(uuid=str(uuid.uuid4())),
        )
        return library
    
    def _collection(self, name=None, protocol=ExternalIntegration.OPDS_IMPORT,
                    external_account_id=None, url=None, username=None,
                    password=None, data_source_name=None):
        name = name or self._str
        collection, ignore = get_one_or_create(
            self._db, Collection, name=name
        )
        collection.external_account_id = external_account_id
        integration = collection.create_external_integration(protocol)
        integration.goal = ExternalIntegration.LICENSE_GOAL
        integration.url = url
        integration.username = username
        integration.password = password

        if data_source_name:
            collection.data_source = data_source_name
        return collection
        
    @property
    def _default_library(self):
        """A Library that will only be created once throughout a given test.

        By default, the `_default_collection` will be associated with
        the default library.
        """
        if not hasattr(self, '_default__library'):
            self._default__library = self.make_default_library(self._db)
        return self._default__library
        
    @property
    def _default_collection(self):
        """A Collection that will only be created once throughout
        a given test.

        For most tests there's no need to create a different
        Collection for every LicensePool. Using
        self._default_collection instead of calling self.collection()
        saves time.
        """
        if not hasattr(self, '_default__collection'):
            self._default__collection = self._default_library.collections[0]
        return self._default__collection

    @classmethod
    def make_default_library(cls, _db):
        """Ensure that the default library exists in the given database.

        This can be called by code intended for use in testing but not actually
        within a DatabaseTest subclass.
        """
        library, ignore = get_one_or_create(
            _db, Library, create_method_kwargs=dict(
                uuid=unicode(uuid.uuid4()),
                name="default",
            ), short_name="default"
        )
        collection, ignore = get_one_or_create(
            _db, Collection, name="Default Collection"
        )
        integration = collection.create_external_integration(
            ExternalIntegration.OPDS_IMPORT
        )
        integration.goal = ExternalIntegration.LICENSE_GOAL
        if collection not in library.collections:
            library.collections.append(collection)
        return library
            
    def _catalog(self, name=u"Faketown Public Library"):
        source, ignore = get_one_or_create(self._db, DataSource, name=name)
        
    def _integration_client(self, url=None, shared_secret=None):
        url = url or self._url
        secret = shared_secret or u"secret"
        return get_one_or_create(
            self._db, IntegrationClient, shared_secret=secret,
            create_method_kwargs=dict(url=url)
        )[0]

    def _subject(self, type, identifier):
        return get_one_or_create(
            self._db, Subject, type=type, identifier=identifier
        )[0]

    def _classification(self, identifier, subject, data_source, weight=1):
        return get_one_or_create(
            self._db, Classification, identifier=identifier, subject=subject, 
            data_source=data_source, weight=weight
        )[0]

    def sample_cover_path(self, name):
        """The path to the sample cover with the given filename."""
        base_path = os.path.split(__file__)[0]
        resource_path = os.path.join(base_path, "tests", "files", "covers")
        sample_cover_path = os.path.join(resource_path, name)
        return sample_cover_path

    def sample_cover_representation(self, name):
        """A Representation of the sample cover with the given filename."""
        sample_cover_path = self.sample_cover_path(name)
        return self._representation(
            media_type="image/png", content=open(sample_cover_path).read())[0]
コード例 #26
0
def test_import_column_extra_is_string(app_context: None,
                                       session: Session) -> None:
    """
    Test importing a dataset when the column extra is a string.
    """
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.datasets.commands.importers.v1.utils import import_dataset
    from superset.datasets.schemas import ImportV1DatasetSchema
    from superset.models.core import Database

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    database = Database(database_name="my_database",
                        sqlalchemy_uri="sqlite://")
    session.add(database)
    session.flush()

    dataset_uuid = uuid.uuid4()
    yaml_config: Dict[str, Any] = {
        "version": "1.0.0",
        "table_name": "my_table",
        "main_dttm_col": "ds",
        "description": "This is the description",
        "default_endpoint": None,
        "offset": -8,
        "cache_timeout": 3600,
        "schema": "my_schema",
        "sql": None,
        "params": {
            "remote_id": 64,
            "database_name": "examples",
            "import_time": 1606677834,
        },
        "template_params": {
            "answer": "42",
        },
        "filter_select_enabled": True,
        "fetch_values_predicate": "foo IN (1, 2)",
        "extra": '{"warning_markdown": "*WARNING*"}',
        "uuid": dataset_uuid,
        "metrics": [{
            "metric_name": "cnt",
            "verbose_name": None,
            "metric_type": None,
            "expression": "COUNT(*)",
            "description": None,
            "d3format": None,
            "extra": '{"warning_markdown": null}',
            "warning_text": None,
        }],
        "columns": [{
            "column_name": "profit",
            "verbose_name": None,
            "is_dttm": False,
            "is_active": True,
            "type": "INTEGER",
            "groupby": False,
            "filterable": False,
            "expression": "revenue-expenses",
            "description": None,
            "python_date_format": None,
            "extra": '{"certified_by": "User"}',
        }],
        "database_uuid": database.uuid,
    }

    schema = ImportV1DatasetSchema()
    dataset_config = schema.load(yaml_config)
    dataset_config["database_id"] = database.id
    sqla_table = import_dataset(session, dataset_config)

    assert sqla_table.metrics[0].extra == '{"warning_markdown": null}'
    assert sqla_table.columns[0].extra == '{"certified_by": "User"}'
    assert sqla_table.extra == '{"warning_markdown": "*WARNING*"}'
コード例 #27
0
ファイル: session.py プロジェクト: helixyte/everest
 def add(self, entity_class, data): # different signature pylint: disable=W0222
     if not IEntity.providedBy(data): # pylint: disable=E1101
         self.__run_traversal(entity_class, data, None,
                              RELATION_OPERATIONS.ADD)
     else:
         SaSession.add(self, data)
コード例 #28
0
ファイル: order.py プロジェクト: vnzinki/vin_test
 def add_fruit(self, db: Session, order: Order, item: Fruit, quantity: int):
     order.order_items.append(OrderItem(item=item, quantity=quantity))
     db.add(order)
     db.commit()
     db.refresh(order)
     return order
コード例 #29
0
 def test_create_none_screamer_chance_default(self):
     session = Session(connection)
     webm = WEBM(id_="q")  # TODO: add a 32-character length constraint
     session.add(webm)
     session.commit()
コード例 #30
0
ファイル: order.py プロジェクト: vnzinki/vin_test
 def set_total_price(self, db: Session, order: Order, total_price: float):
     order.total_price = total_price
     db.add(order)
     db.commit()
     db.refresh(order)
     return order
コード例 #31
0
 def test_create_normal_foreign_key(self):
     session = Session(connection)
     dirty_webm = DirtyWEBM(md5="a4", webm_id="q")
     session.add(dirty_webm)
     session.commit()
コード例 #32
0
ファイル: main.py プロジェクト: lhm-limux/gosa
session = Session()

# Load schema loader/object factory
factory = SchemaLoader()
factory.loadSchema('schema/Base-schema.xml')
factory.loadSchema('schema/Person-schema.xml')

# Populate metaclasses, maybe there is a nicer way to do this
for (cName, cClass) in factory.getClasses().items():
    globals()[cName] = cClass 

# Create a new object called father
father = Employee('cn=Father of Horst, ...')
father.givenName = 'The father!'
father.age = 1234
session.add(father.getObject())
session.commit()

# Create another object which references to the father object
tim = Person('cn=Horst Hackpeter, ist voll toll')
tim.givenName = 'Horst'
tim.sn = 'Hackepeter'
tim.age = 2
tim.parent = father.getObject()
tim.notes = ['tester', '44', 55, father.getObject()]
session.add(tim.getObject())
session.commit()

# Walk through database that match 'Horst'
for entry in session.query(GOsaDBObject).filter(GOsaDBObject.properties.any(GOsaDBProperty.value == u'Horst')).all():
    entry = factory.toObject(entry)
コード例 #33
0
 def add_profile_pic(self, db: Session, db_obj: User, image: Image):
     db_obj.profile_pic = image
     db.add(db_obj)
     db.commit()
     db.refresh(db_obj)
     return db_obj.profile_pic
コード例 #34
0
#!/usr/bin/python3
""" This script changes the name of a State object
in the database hbtn_0e_6_usa.

"""

from sqlalchemy import (create_engine)
from sqlalchemy.orm.session import Session
from model_state import Base, State
import sys

if __name__ == '__main__':
    engine = create_engine('mysql+mysqldb://{}:{}@localhost/{}'.format(
        sys.argv[1], sys.argv[2], sys.argv[3]),
                           pool_pre_ping=True)
    Base.metadata.create_all(engine)
    session = Session(bind=engine)
    a = session.query(State).get(2)
    # Update the name of the State with id = 2
    a.name = 'New Mexico'
    session.add(a)
    session.commit()
    session.close()
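
This script (and the similar one further below) imports Base and State from a model_state module that is not reproduced here. A plausible minimal version, offered only as an assumption:

# model_state.py (hypothetical sketch)
from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()


class State(Base):
    """A state record stored in the states table."""
    __tablename__ = 'states'
    id = Column(Integer, primary_key=True, nullable=False, autoincrement=True)
    name = Column(String(128), nullable=False)
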
コード例 #35
0
ファイル: db_helper.py プロジェクト: Fathui/fsqlfly
 def create(cls, model: str, obj: dict, *args, session: Session,
            base: Type[DBT], **kwargs) -> DBRes:
     db_obj = base(**obj)
     session.add(db_obj)
     session.commit()
     return DBRes(data=db_obj.as_dict())
コード例 #36
0
    engine = core.get_database_engine_string()
    logging.info("Using connection string '%s'" % (engine,))
    engine = create_engine(engine, encoding='utf-8', isolation_level="READ COMMITTED")
    logging.info("Binding session...")
    session = Session(bind=engine, autocommit = False)

    # Select the old raw_results 
    sql = "SELECT crawl_id, date_crawled, url, content_type, raw_article_results.status, raw_article_conversions.inserted_id FROM raw_articles JOIN raw_article_results ON raw_article_results.raw_article_id = raw_articles.id JOIN raw_article_conversions ON raw_article_conversions.raw_article_id = raw_articles.id"
    it = session.execute(sql)

    for crawl_id, date_crawled, url, content_type, status, inserted_id in it:

        # Decide if any of these have been comitted
        sub = session.query(RawArticle).filter_by(crawl_id = crawl_id, url = url, content_type = content_type, date_crawled = date_crawled)
        try:
            i = sub.one()
            logging.info("RawArticle %s has already been processed.", i)
            continue 
        except NoResultFound:  # raised by Query.one() when no row matches
            pass


        rbase = RawArticle((crawl_id, (None, None, url, date_crawled, content_type)))
        rstat = RawArticleResult(None, status)
        rstat.parent = rbase 
        rrslt = RawArticleResultLink(None, inserted_id)
        rrslt.parent = rbase 
        session.add(rbase)
        session.add(rstat)
        session.add(rrslt)
        session.commit()
コード例 #37
0
def upgrade():
    session = Session(bind=op.get_bind(), expire_on_commit=False)

    # All IRONMAN orgs need a 3 digit version
    IRONMAN_system = 'http://pcctc.org/'

    ironman_org_ids = [(id.id, id._value) for id in Identifier.query.filter(
        Identifier.system == IRONMAN_system).with_entities(
        Identifier.id, Identifier._value)]
    existing_values = [id[1] for id in ironman_org_ids]

    replacements = {}
    for io_id, io_value in ironman_org_ids:
        found = org_pattern.match(io_value)
        if found:
            # avoid probs if run again - don't add if already present
            needed = '146-0{}'.format(found.group(1))
            replacements[found.group(1)] = '0{}'.format(found.group(1))
            if needed not in existing_values:
                needed_i = Identifier(
                    use='secondary', system=IRONMAN_system, _value=needed)
            else:
                needed_i = Identifier.query.filter(
                    Identifier.system == IRONMAN_system).filter(
                    Identifier._value == needed).one()

            # add a 3 digit identifier and link with same org
            oi = OrganizationIdentifier.query.filter(
                OrganizationIdentifier.identifier_id == io_id).one()
            needed_oi = OrganizationIdentifier.query.filter(
                OrganizationIdentifier.organization_id ==
                oi.organization_id).filter(
                OrganizationIdentifier.identifier == needed_i).first()
            if not needed_oi:
                needed_i = session.merge(needed_i)
                needed_oi = OrganizationIdentifier(
                    organization_id=oi.organization_id,
                    identifier=needed_i)
                session.add(needed_oi)

    # All IRONMAN users with a 2 digit ID referencing one of the replaced
    # values needs a 3 digit version

    ironman_study_ids = Identifier.query.filter(
        Identifier.system == TRUENTH_EXTERNAL_STUDY_SYSTEM).filter(
        Identifier._value.like('170-%')).with_entities(
        Identifier.id, Identifier._value)

    for iid, ival in ironman_study_ids:
        found = study_pattern.match(ival)
        if found:
            org_segment = found.group(1)
            patient_segment = found.group(2)

            # only add if also one of the new org ids
            if org_segment not in replacements:
                continue

            needed = '170-{}-{}'.format(
                replacements[org_segment], patient_segment)

            # add a 3 digit identifier and link with same user(s),
            # if not already present
            uis = UserIdentifier.query.filter(
                UserIdentifier.identifier_id == iid)
            needed_i = Identifier.query.filter(
                Identifier.system == TRUENTH_EXTERNAL_STUDY_SYSTEM).filter(
                Identifier._value == needed).first()
            if not needed_i:
                needed_i = Identifier(
                    use='secondary', system=TRUENTH_EXTERNAL_STUDY_SYSTEM,
                    _value=needed)
            for ui in uis:
                needed_ui = UserIdentifier.query.filter(
                    UserIdentifier.user_id == ui.user_id).filter(
                    UserIdentifier.identifier == needed_i).first()
                if not needed_ui:
                    needed_ui = UserIdentifier(
                        user_id=ui.user_id, identifier=needed_i)
                    session.add(needed_ui)

    session.commit()
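
The migration refers to org_pattern and study_pattern, which are compiled at module level and not shown. Plausible sketches, assuming two-digit organization segments in the existing identifier values:

import re

# Hypothetical sketches of the patterns referenced above.
org_pattern = re.compile(r'^146-(\d{2})$')          # e.g. '146-42' -> group(1) == '42'
study_pattern = re.compile(r'^170-(\d{2})-(\d+)$')  # org segment, patient segment
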
コード例 #38
0
ファイル: storage.py プロジェクト: rmol/securedrop-client
def update_replies(
    remote_replies: List[SDKReply], local_replies: List[Reply], session: Session, data_dir: str
) -> None:
    """
    * Existing replies are updated in the local database.
    * New replies have an entry created in the local database.
    * Local replies not returned in the remote replies are deleted from the
      local database unless they are pending or failed.

    If a reply references a new journalist username, add them to the database
    as a new user.
    """
    local_replies_by_uuid = {r.uuid: r for r in local_replies}
    users: Dict[str, User] = {}
    source_cache = SourceCache(session)
    for reply in remote_replies:
        user = users.get(reply.journalist_uuid)
        if not user:
            user = find_or_create_user(reply.journalist_uuid, reply.journalist_username, session)
            users[reply.journalist_uuid] = user

        local_reply = local_replies_by_uuid.get(reply.uuid)
        if local_reply:
            lazy_setattr(local_reply, "journalist_id", user.id)
            lazy_setattr(local_reply, "size", reply.size)
            lazy_setattr(local_reply, "filename", reply.filename)

            del local_replies_by_uuid[reply.uuid]
            logger.debug("Updated reply {}".format(reply.uuid))
        else:
            # A new reply to be added to the database.
            source = source_cache.get(reply.source_uuid)
            if not source:
                logger.error(f"No source found for reply {reply.uuid}")
                continue

            nr = Reply(
                uuid=reply.uuid,
                journalist_id=user.id,
                source_id=source.id,
                filename=reply.filename,
                size=reply.size,
            )
            session.add(nr)

            # All replies fetched from the server have succeeded in being sent,
            # so we should delete the corresponding draft locally if it exists.
            try:
                draft_reply_db_object = session.query(DraftReply).filter_by(uuid=reply.uuid).one()

                update_draft_replies(
                    session,
                    draft_reply_db_object.source.id,
                    draft_reply_db_object.timestamp,
                    draft_reply_db_object.file_counter,
                    nr.file_counter,
                    commit=False,
                )
                session.delete(draft_reply_db_object)

            except NoResultFound:
                pass  # No draft locally stored corresponding to this reply.

            logger.debug("Added new reply {}".format(reply.uuid))

    # The uuids remaining in local_uuids do not exist on the remote server, so
    # delete the related records.
    for deleted_reply in local_replies_by_uuid.values():
        delete_single_submission_or_reply_on_disk(deleted_reply, data_dir)
        session.delete(deleted_reply)
        logger.debug("Deleted reply {}".format(deleted_reply.uuid))

    session.commit()
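
lazy_setattr, find_or_create_user, and SourceCache come from the same storage module. The attribute helper is simple to sketch, assuming its only job is to avoid dirtying rows whose values have not changed:

def lazy_setattr(obj, attr, value):
    # Hypothetical sketch: only write the attribute when the value differs,
    # so unchanged replies are not marked dirty in the session.
    if getattr(obj, attr) != value:
        setattr(obj, attr, value)
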
Code example #39
#!/usr/bin/python3
""" This script adds the State object “Louisiana” to the
database hbtn_0e_6_usa.

"""

from sqlalchemy import (create_engine)
from sqlalchemy.orm.session import Session
from model_state import Base, State
import sys

if __name__ == '__main__':
    engine = create_engine('mysql+mysqldb://{}:{}@localhost/{}'.format(
        sys.argv[1], sys.argv[2], sys.argv[3]),
                           pool_pre_ping=True)
    Base.metadata.create_all(engine)
    session = Session(bind=engine)
    # Create the State object
    obj1 = State(name='Louisiana')
    # Add the object to the session
    session.add(obj1)
    # Commit the session so the object is saved to the database
    session.commit()
    print(obj1.id)
    session.close()
Code example #40
File: cli.py  Project: davidbrochart/quetz
def _fill_test_database(db: Session) -> None:
    """Create dummy users and channels to allow further testing in dev mode."""

    testUsers = []
    try:
        for index, username in enumerate(['alice', 'bob', 'carol', 'dave']):
            user = User(id=uuid.uuid4().bytes, username=username)

            identity = Identity(
                provider='dummy',
                identity_id=str(index),
                username=username,
            )

            profile = Profile(name=username.capitalize(), avatar_url='/avatar.jpg')

            user.identities.append(identity)  # type: ignore
            user.profile = profile
            db.add(user)
            testUsers.append(user)

        for channel_index in range(3):
            channel = Channel(
                name=f'channel{channel_index}',
                description=f'Description of channel{channel_index}',
                private=False,
            )

            for package_index in range(random.randint(5, 10)):
                package = Package(
                    name=f'package{package_index}',
                    summary=f'package {package_index} summary text',
                    description=f'Description of package{package_index}',
                )
                channel.packages.append(package)  # type: ignore

                test_user = testUsers[random.randint(0, len(testUsers) - 1)]
                package_member = PackageMember(
                    package=package, channel=channel, user=test_user, role='owner'
                )

                db.add(package_member)

            if channel_index == 0:
                package = Package(name='xtensor', description='Description of xtensor')
                channel.packages.append(package)  # type: ignore

                test_user = testUsers[random.randint(0, len(testUsers) - 1)]
                package_member = PackageMember(
                    package=package, channel=channel, user=test_user, role='owner'
                )

                db.add(package_member)

                # create API key
                key = uuid.uuid4().hex

                key_user = User(id=uuid.uuid4().bytes)
                api_key = ApiKey(
                    key=key, description='test API key', user=test_user, owner=test_user
                )
                db.add(api_key)
                print(f'Test API key created for user "{test_user.username}": {key}')

                key_package_member = PackageMember(
                    user=key_user,
                    channel_name=channel.name,
                    package_name=package.name,
                    role='maintainer',
                )
                db.add(key_package_member)

            db.add(channel)

            channel_member = ChannelMember(
                channel=channel,
                user=test_user,
                role='owner',
            )

            db.add(channel_member)
        db.commit()
    finally:
        db.close()
Code example #41
class Address(Base):
    __tablename__ = 'address'
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey('user.id'))
    email_address = Column(String)
    user = relationship(User, backref='addresses')

################################################################################
#                              identity set
#     The session itself acts somewhat like a set-like collection. All items
# present may be accessed using the iterator interface.
################################################################################

user = User()
session.add(user)

for obj in session:
    print(obj)
    
assert user in session
    
# inspect the current state of an object
state = inspect(user)
assert state.persistent

################################################################################
#                              expunge
# Expunging removes an object from the session.
#     Pending instances are sent to the transient state.
#     Persistent instances (i.e. those that have been flushed) are sent to the detached state.
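
The comment above is cut off at the example boundary, so here is a minimal, self-contained sketch of the expunge behavior it describes. It assumes SQLAlchemy 1.4+ and uses a throwaway in-memory SQLite engine with a hypothetical User model; it is an illustration, not part of the original snippet.

# Minimal sketch: expunge() moves pending instances back to transient and
# flushed (persistent) instances to detached.
from sqlalchemy import Column, Integer, String, create_engine, inspect
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class User(Base):
    __tablename__ = 'user'
    id = Column(Integer, primary_key=True)
    name = Column(String)

engine = create_engine('sqlite://')
Base.metadata.create_all(engine)
session = Session(bind=engine)

# A pending instance returns to the transient state when expunged.
pending = User(name='pending user')
session.add(pending)
session.expunge(pending)
assert inspect(pending).transient

# A persistent (flushed) instance becomes detached when expunged.
persistent = User(name='flushed user')
session.add(persistent)
session.flush()
session.expunge(persistent)
assert inspect(persistent).detached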
Code example #42
def session_with_data(session: Session) -> Iterator[Session]:
    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.models.sql_lab import Query, SavedQuery
    from superset.tables.models import Table

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")

    columns = [
        TableColumn(column_name="a", type="INTEGER"),
    ]

    sqla_table = SqlaTable(
        table_name="my_sqla_table",
        columns=columns,
        metrics=[],
        database=db,
    )

    query_obj = Query(
        client_id="foo",
        database=db,
        tab_name="test_tab",
        sql_editor_id="test_editor_id",
        sql="select * from bar",
        select_sql="select * from bar",
        executed_sql="select * from bar",
        limit=100,
        select_as_cta=False,
        rows=100,
        error_message="none",
        results_key="abc",
    )

    saved_query = SavedQuery(database=db, sql="select * from foo")

    table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=db,
        columns=[],
    )

    dataset = Dataset(
        database=table.database,
        name="positions",
        expression="""
SELECT array_agg(array[longitude,latitude]) AS position
FROM my_catalog.my_schema.my_table
""",
        tables=[table],
        columns=[
            Column(
                name="position",
                expression="array_agg(array[longitude,latitude])",
            ),
        ],
    )

    session.add(dataset)
    session.add(table)
    session.add(saved_query)
    session.add(query_obj)
    session.add(db)
    session.add(sqla_table)
    session.flush()
    yield session
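
Assuming session_with_data is registered as a pytest fixture (the decorator is not shown in this excerpt), a test could consume it as in the hedged sketch below; the assertions simply mirror the objects the fixture creates.

# Hypothetical test built on the fixture above; names match the fixture's data.
def test_fixture_contains_dataset(session_with_data: Session) -> None:
    from superset.datasets.models import Dataset

    dataset = session_with_data.query(Dataset).one()
    assert dataset.name == "positions"
    assert [t.name for t in dataset.tables] == ["my_table"]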
Code example #43
File: books_example.py  Project: smartkiwi/pugip-demo
    "publication_date": datetime.datetime.strptime("2010-05-05", "%Y-%m-%d"),
    "pages_count": 240,
}
engine.connect().execute(insert_stmt, data)


session = Session(bind=engine)
q = session.query(Book).filter(Book.title == "Essential SQLAlchemy")
print(q)
book = q.one()
print(book.id, book.title)


author = Author(name="Rick Copeland")
author.books.append(book)
session.add(book)
session.flush()

####
# select CASE WHEN (BOOK.pages_count > 200) THEN 1 ELSE 0 END is_novel, count(*)
# from BOOK
# group by CASE WHEN (BOOK.pages_count > 200) THEN 1 ELSE 0 END
# order by CASE WHEN (BOOK.pages_count > 200) THEN 1 ELSE 0 END
#
is_novel_column = case([(Book.pages_count > 200, 1)], else_=0)
novel_query = (
    session.query(is_novel_column.label("is_alias"), count()).group_by(is_novel_column).order_by(is_novel_column)
)

print(novel_query)
print(novel_query.all())
Code example #44
def upgrade():
    session = Session(bind=op.get_bind())
    session.add(Exchange(name=BxInThExchange.name, is_active=True, weight=15))
    session.flush()
Code example #45
File: dbhelper.py  Project: ericbean/RecordSheet
def setup_module():
    global transaction, connection, engine

    engine = create_engine('postgresql:///recordsheet_test')
    connection = engine.connect()
    transaction = connection.begin()
    Base.metadata.create_all(connection)

    #insert some data
    inner_tr = connection.begin_nested()
    ses = Session(connection)
#    ses.begin_nested()
    ses.add(dbmodel.Account(name='TEST01', desc='test account 01'))
    ses.add(dbmodel.Account(name='TEST02', desc='test account 02'))
    user = dbmodel.User(username='******', name='Test T. User',
                       password=dbapi.new_pw_hash('passtestword'),
                       locked=False)
    ses.add(user)
    lockeduser = dbmodel.User(username='******', name='Test T. User',
                       password=dbapi.new_pw_hash('passtestword'),
                       locked=True)
    ses.add(lockeduser)

    batch = dbmodel.Batch(user=user)
    ses.add(batch)
    jrnl = dbmodel.Journal(memo='test', batch=batch,
                datetime='2016-06-05 14:09:00-05')
    ses.add(jrnl)
    ses.add(dbmodel.Posting(memo="test", amount=100, account_id=1,
                        journal=jrnl))
    ses.add(dbmodel.Posting(memo="test", amount=-100, account_id=2,
                        journal=jrnl))
    ses.commit()
    # mock a sessionmaker so all queries are in this transaction
    dbapi._session = lambda: ses
    ses.begin_nested()

    @event.listens_for(ses, "after_transaction_end")
    def restart_savepoint(session, transaction):
        if transaction.nested and not transaction._parent.nested:
            ses.begin_nested()
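
Only setup_module is shown above; the matching teardown is not part of this excerpt. A hedged sketch of the usual counterpart for this pattern (the original project's cleanup may differ) rolls back the outer transaction so the seeded accounts, users and postings never persist:

# Hypothetical counterpart to setup_module(); relies on the module-level
# globals created above and discards everything written during the tests.
def teardown_module():
    transaction.rollback()
    connection.close()
    engine.dispose()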
Code example #46
File: dagrun.py  Project: weiplanet/airflow
    def verify_integrity(self, session: Session = None):
        """
        Verifies the DagRun by checking for removed tasks or tasks that are not in the
        database yet. It will set state to removed or add the task if required.

        :param session: Sqlalchemy ORM Session
        :type session: Session
        """
        from airflow.settings import task_instance_mutation_hook

        dag = self.get_dag()
        tis = self.get_task_instances(session=session)

        # check for removed or restored tasks
        task_ids = set()
        for ti in tis:
            task_instance_mutation_hook(ti)
            task_ids.add(ti.task_id)
            task = None
            try:
                task = dag.get_task(ti.task_id)
            except AirflowException:
                if ti.state == State.REMOVED:
                    pass  # ti has already been removed, just ignore it
                elif self.state != State.RUNNING and not dag.partial:
                    self.log.warning(
                        "Failed to get task '%s' for dag '%s'. Marking it as removed.",
                        ti, dag)
                    Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 1, 1)
                    ti.state = State.REMOVED

            should_restore_task = (task
                                   is not None) and ti.state == State.REMOVED
            if should_restore_task:
                self.log.info(
                    "Restoring task '%s' which was previously removed from DAG '%s'",
                    ti, dag)
                Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
                ti.state = State.NONE
            session.merge(ti)

        # check for missing tasks
        for task in dag.task_dict.values():
            if task.start_date > self.execution_date and not self.is_backfill:
                continue

            if task.task_id not in task_ids:
                Stats.incr(f"task_instance_created-{task.task_type}", 1, 1)
                ti = TI(task, self.execution_date)
                task_instance_mutation_hook(ti)
                session.add(ti)

        try:
            session.flush()
        except IntegrityError as err:
            self.log.info(str(err))
            self.log.info('Hit IntegrityError while creating the TIs for '
                          f'{dag.dag_id} - {self.execution_date}.')
            self.log.info('Doing session rollback.')
            # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
            session.rollback()
Code example #47
class DatabaseTest(object):

    engine = None
    connection = None

    @classmethod
    def get_database_connection(cls):
        url = Configuration.database_url()
        engine, connection = SessionManager.initialize(url)

        return engine, connection

    @classmethod
    def setup_class(cls):
        # Initialize a temporary data directory.
        cls.engine, cls.connection = cls.get_database_connection()
        cls.old_data_dir = Configuration.data_directory
        cls.tmp_data_dir = tempfile.mkdtemp(dir="/tmp")
        Configuration.instance[Configuration.DATA_DIRECTORY] = cls.tmp_data_dir

        # Avoid CannotLoadConfiguration errors related to CDN integrations.
        Configuration.instance[Configuration.INTEGRATIONS] = Configuration.instance.get(
            Configuration.INTEGRATIONS, {}
        )
        Configuration.instance[Configuration.INTEGRATIONS][ExternalIntegration.CDN] = {}

    @classmethod
    def teardown_class(cls):
        # Destroy the database connection and engine.
        cls.connection.close()
        cls.engine.dispose()

        if cls.tmp_data_dir.startswith("/tmp"):
            logging.debug("Removing temporary directory %s" % cls.tmp_data_dir)
            shutil.rmtree(cls.tmp_data_dir)

        else:
            logging.warn("Cowardly refusing to remove 'temporary' directory %s" % cls.tmp_data_dir)

        Configuration.instance[Configuration.DATA_DIRECTORY] = cls.old_data_dir

    def setup(self, mock_search=True):
        # Create a new connection to the database.
        self._db = Session(self.connection)
        self.transaction = self.connection.begin_nested()

        # Start with a high number so it won't interfere with tests that search for an age or grade
        self.counter = 2000

        self.time_counter = datetime(2014, 1, 1)
        self.isbns = [
            "9780674368279", "0636920028468", "9781936460236", "9780316075978"
        ]
        if mock_search:
            self.search_mock = mock.patch(external_search.__name__ + ".ExternalSearchIndex", MockExternalSearchIndex)
            self.search_mock.start()
        else:
            self.search_mock = None

        # TODO:  keeping this for now, but need to fix it bc it hits _isbn,
        # which pops an isbn off the list and messes tests up.  so exclude
        # _ functions from participating.
        # also attempt to stop nosetest showing docstrings instead of function names.
        #for name, obj in inspect.getmembers(self):
        #    if inspect.isfunction(obj) and obj.__name__.startswith('test_'):
        #        obj.__doc__ = None


    def teardown(self):
        # Close the session.
        self._db.close()

        # Roll back all database changes that happened during this
        # test, whether in the session that was just closed or some
        # other session.
        self.transaction.rollback()

        # Remove any database objects cached in the model classes but
        # associated with the now-rolled-back session.
        Collection.reset_cache()
        ConfigurationSetting.reset_cache()
        DataSource.reset_cache()
        DeliveryMechanism.reset_cache()
        ExternalIntegration.reset_cache()
        Genre.reset_cache()
        Library.reset_cache()

        # Also roll back any record of those changes in the
        # Configuration instance.
        for key in [
                Configuration.SITE_CONFIGURATION_LAST_UPDATE,
                Configuration.LAST_CHECKED_FOR_SITE_CONFIGURATION_UPDATE
        ]:
            if key in Configuration.instance:
                del(Configuration.instance[key])

        if self.search_mock:
            self.search_mock.stop()

    def time_eq(self, a, b):
        "Assert that two times are *approximately* the same -- within 2 seconds."
        if a < b:
            delta = b-a
        else:
            delta = a-b
        total_seconds = delta.total_seconds()
        assert (total_seconds < 2), ("Delta was too large: %.2f seconds." % total_seconds)

    def shortDescription(self):
        return None # Stop nosetests displaying docstrings instead of class names when verbosity level >= 2.

    @property
    def _id(self):
        self.counter += 1
        return self.counter

    @property
    def _str(self):
        return unicode(self._id)

    @property
    def _time(self):
        v = self.time_counter
        self.time_counter = self.time_counter + timedelta(days=1)
        return v

    @property
    def _isbn(self):
        return self.isbns.pop()

    @property
    def _url(self):
        return "http://foo.com/" + self._str

    def _patron(self, external_identifier=None, library=None):
        external_identifier = external_identifier or self._str
        library = library or self._default_library
        return get_one_or_create(
            self._db, Patron, external_identifier=external_identifier,
            library=library
        )[0]

    def _contributor(self, sort_name=None, name=None, **kw_args):
        name = sort_name or name or self._str
        return get_one_or_create(self._db, Contributor, sort_name=unicode(name), **kw_args)

    def _identifier(self, identifier_type=Identifier.GUTENBERG_ID, foreign_id=None):
        if foreign_id:
            id = foreign_id
        else:
            id = self._str
        return Identifier.for_foreign_id(self._db, identifier_type, id)[0]

    def _edition(self, data_source_name=DataSource.GUTENBERG,
                 identifier_type=Identifier.GUTENBERG_ID,
                 with_license_pool=False, with_open_access_download=False,
                 title=None, language="eng", authors=None, identifier_id=None,
                 series=None, collection=None, publicationDate=None
    ):
        id = identifier_id or self._str
        source = DataSource.lookup(self._db, data_source_name)
        wr = Edition.for_foreign_id(
            self._db, source, identifier_type, id)[0]
        if not title:
            title = self._str
        wr.title = unicode(title)
        wr.medium = Edition.BOOK_MEDIUM
        if series:
            wr.series = series
        if language:
            wr.language = language
        if authors is None:
            authors = self._str
        if isinstance(authors, basestring):
            authors = [authors]
        if authors != []:
            wr.add_contributor(unicode(authors[0]), Contributor.PRIMARY_AUTHOR_ROLE)
            wr.author = unicode(authors[0])
        for author in authors[1:]:
            wr.add_contributor(unicode(author), Contributor.AUTHOR_ROLE)
        if publicationDate:
            wr.published = publicationDate

        if with_license_pool or with_open_access_download:
            pool = self._licensepool(
                wr, data_source_name=data_source_name,
                with_open_access_download=with_open_access_download,
                collection=collection
            )

            pool.set_presentation_edition()
            return wr, pool
        return wr

    def _work(self, title=None, authors=None, genre=None, language=None,
              audience=None, fiction=True, with_license_pool=False,
              with_open_access_download=False, quality=0.5, series=None,
              presentation_edition=None, collection=None, data_source_name=None):
        """Create a Work.

        For performance reasons, this method does not generate OPDS
        entries or calculate a presentation edition for the new
        Work. Tests that rely on this information being present
        should call _slow_work() instead, which takes more care to present
        the sort of Work that would be created in a real environment.
        """
        pools = []
        if with_open_access_download:
            with_license_pool = True
        language = language or "eng"
        title = unicode(title or self._str)
        audience = audience or Classifier.AUDIENCE_ADULT
        if audience == Classifier.AUDIENCE_CHILDREN and not data_source_name:
            # TODO: This is necessary because Gutenberg's childrens books
            # get filtered out at the moment.
            data_source_name = DataSource.OVERDRIVE
        elif not data_source_name:
            data_source_name = DataSource.GUTENBERG
        if fiction is None:
            fiction = True
        new_edition = False
        if not presentation_edition:
            new_edition = True
            presentation_edition = self._edition(
                title=title, language=language,
                authors=authors,
                with_license_pool=with_license_pool,
                with_open_access_download=with_open_access_download,
                data_source_name=data_source_name,
                series=series,
                collection=collection,
            )
            if with_license_pool:
                presentation_edition, pool = presentation_edition
                if with_open_access_download:
                    pool.open_access = True
                pools = [pool]
        else:
            pools = presentation_edition.license_pools
        work, ignore = get_one_or_create(
            self._db, Work, create_method_kwargs=dict(
                audience=audience,
                fiction=fiction,
                quality=quality), id=self._id)
        if genre:
            if not isinstance(genre, Genre):
                genre, ignore = Genre.lookup(self._db, genre, autocreate=True)
            work.genres = [genre]
        work.random = 0.5
        work.set_presentation_edition(presentation_edition)

        if pools:
            # make sure the pool's presentation_edition is set,
            # because loan tests assume that.
            if not work.license_pools:
                for pool in pools:
                    work.license_pools.append(pool)

            for pool in pools:
                pool.set_presentation_edition()

            # This is probably going to be used in an OPDS feed, so
            # fake that the work is presentation ready.
            work.presentation_ready = True
            work.calculate_opds_entries(verbose=False)

        return work

    def add_to_materialized_view(self, works, true_opds=False):
        """Make sure all the works in `works` show up in the materialized view.

        :param true_opds: Generate real OPDS entries for each work,
        rather than faking it.
        """
        if not isinstance(works, list):
            works = [works]
        for work in works:
            if true_opds:
                work.calculate_opds_entries(verbose=False)
            else:
                work.presentation_ready = True
                work.simple_opds_entry = "<entry>an entry</entry>"
        self._db.commit()
        SessionManager.refresh_materialized_views(self._db)

    def _lane(self, display_name=None, library=None,
              parent=None, genres=None, languages=None,
              fiction=None
    ):
        display_name = display_name or self._str
        library = library or self._default_library
        lane, is_new = create(
            self._db, Lane,
            library=library,
            parent=parent, display_name=display_name,
            fiction=fiction
        )
        if is_new and parent:
            lane.priority = len(parent.sublanes)-1
        if genres:
            if not isinstance(genres, list):
                genres = [genres]
            for genre in genres:
                if isinstance(genre, basestring):
                    genre, ignore = Genre.lookup(self._db, genre)
                lane.genres.append(genre)
        if languages:
            if not isinstance(languages, list):
                languages = [languages]
            lane.languages = languages
        return lane

    def _slow_work(self, *args, **kwargs):
        """Create a work that closely resembles one that might be found in the
        wild.

        This is significantly slower than _work() but more reliable.
        """
        work = self._work(*args, **kwargs)
        work.calculate_presentation_edition()
        work.calculate_opds_entries(verbose=False)
        return work

    def _add_generic_delivery_mechanism(self, license_pool):
        """Give a license pool a generic non-open-access delivery mechanism."""
        data_source = license_pool.data_source
        identifier = license_pool.identifier
        content_type = Representation.EPUB_MEDIA_TYPE
        drm_scheme = DeliveryMechanism.NO_DRM
        LicensePoolDeliveryMechanism.set(
            data_source, identifier, content_type, drm_scheme,
            RightsStatus.IN_COPYRIGHT
        )

    def _coverage_record(self, edition, coverage_source, operation=None,
        status=CoverageRecord.SUCCESS, collection=None, exception=None,
    ):
        if isinstance(edition, Identifier):
            identifier = edition
        else:
            identifier = edition.primary_identifier
        record, ignore = get_one_or_create(
            self._db, CoverageRecord,
            identifier=identifier,
            data_source=coverage_source,
            operation=operation,
            collection=collection,
            create_method_kwargs = dict(
                timestamp=datetime.utcnow(),
                status=status,
                exception=exception,
            )
        )
        return record

    def _work_coverage_record(self, work, operation=None,
                              status=CoverageRecord.SUCCESS):
        record, ignore = get_one_or_create(
            self._db, WorkCoverageRecord,
            work=work,
            operation=operation,
            create_method_kwargs = dict(
                timestamp=datetime.utcnow(),
                status=status,
            )
        )
        return record

    def _licensepool(self, edition, open_access=True,
                     data_source_name=DataSource.GUTENBERG,
                     with_open_access_download=False,
                     set_edition_as_presentation=False,
                     collection=None):
        source = DataSource.lookup(self._db, data_source_name)
        if not edition:
            edition = self._edition(data_source_name)
        collection = collection or self._default_collection
        pool, ignore = get_one_or_create(
            self._db, LicensePool,
            create_method_kwargs=dict(
                open_access=open_access),
            identifier=edition.primary_identifier,
            data_source=source,
            collection=collection,
            availability_time=datetime.utcnow()
        )

        if set_edition_as_presentation:
            pool.presentation_edition = edition

        if with_open_access_download:
            pool.open_access = True
            url = "http://foo.com/" + self._str
            media_type = MediaTypes.EPUB_MEDIA_TYPE
            link, new = pool.identifier.add_link(
                Hyperlink.OPEN_ACCESS_DOWNLOAD, url,
                source, media_type
            )

            # Add a DeliveryMechanism for this download
            pool.set_delivery_mechanism(
                media_type,
                DeliveryMechanism.NO_DRM,
                RightsStatus.GENERIC_OPEN_ACCESS,
                link.resource,
            )

            representation, is_new = self._representation(
                url, media_type, "Dummy content", mirrored=True)
            link.resource.representation = representation
        else:

            # Add a DeliveryMechanism for this licensepool
            pool.set_delivery_mechanism(
                MediaTypes.EPUB_MEDIA_TYPE,
                DeliveryMechanism.ADOBE_DRM,
                RightsStatus.UNKNOWN,
                None
            )
            pool.licenses_owned = pool.licenses_available = 1

        return pool

    def _license(self, pool, identifier=None, checkout_url=None, status_url=None,
                 expires=None, remaining_checkouts=None, concurrent_checkouts=None):
        identifier = identifier or self._str
        checkout_url = checkout_url or self._str
        status_url = status_url or self._str
        license, ignore = get_one_or_create(
            self._db, License, identifier=identifier, license_pool=pool,
            checkout_url=checkout_url,
            status_url=status_url, expires=expires,
            remaining_checkouts=remaining_checkouts,
            concurrent_checkouts=concurrent_checkouts,
        )
        return license

    def _representation(self, url=None, media_type=None, content=None,
                        mirrored=False):
        url = url or "http://foo.com/" + self._str
        repr, is_new = get_one_or_create(
            self._db, Representation, url=url)
        repr.media_type = media_type
        if media_type and content:
            repr.content = content
            repr.fetched_at = datetime.utcnow()
            if mirrored:
                repr.mirror_url = "http://foo.com/" + self._str
                repr.mirrored_at = datetime.utcnow()
        return repr, is_new

    def _customlist(self, foreign_identifier=None,
                    name=None,
                    data_source_name=DataSource.NYT, num_entries=1,
                    entries_exist_as_works=True
    ):
        data_source = DataSource.lookup(self._db, data_source_name)
        foreign_identifier = foreign_identifier or self._str
        now = datetime.utcnow()
        customlist, ignore = get_one_or_create(
            self._db, CustomList,
            create_method_kwargs=dict(
                created=now,
                updated=now,
                name=name or self._str,
                description=self._str,
                ),
            data_source=data_source,
            foreign_identifier=foreign_identifier
        )

        editions = []
        for i in range(num_entries):
            if entries_exist_as_works:
                work = self._work(with_open_access_download=True)
                edition = work.presentation_edition
            else:
                edition = self._edition(
                    data_source_name, title="Item %s" % i)
                edition.permanent_work_id="Permanent work ID %s" % self._str
            customlist.add_entry(
                edition, "Annotation %s" % i, first_appearance=now)
            editions.append(edition)
        return customlist, editions

    def _complaint(self, license_pool, type, source, detail, resolved=None):
        complaint, is_new = Complaint.register(
            license_pool,
            type,
            source,
            detail,
            resolved
        )
        return complaint

    def _credential(self, data_source_name=DataSource.GUTENBERG,
                    type=None, patron=None):
        data_source = DataSource.lookup(self._db, data_source_name)
        type = type or self._str
        patron = patron or self._patron()
        credential, is_new = Credential.persistent_token_create(
            self._db, data_source, type, patron
        )
        return credential

    def _external_integration(self, protocol, goal=None, settings=None,
                              libraries=None, **kwargs
    ):
        integration = None
        if not libraries:
            integration, ignore = get_one_or_create(
                self._db, ExternalIntegration, protocol=protocol, goal=goal
            )
        else:
            if not isinstance(libraries, list):
                libraries = [libraries]

            # Try to find an existing integration for one of the given
            # libraries.
            for library in libraries:
                integration = ExternalIntegration.lookup(
                    self._db, protocol, goal, library=libraries[0]
                )
                if integration:
                    break

            if not integration:
                # Otherwise, create a brand new integration specifically
                # for the library.
                integration = ExternalIntegration(
                    protocol=protocol, goal=goal,
                )
                integration.libraries.extend(libraries)
                self._db.add(integration)

        for attr, value in kwargs.items():
            setattr(integration, attr, value)

        settings = settings or dict()
        for key, value in settings.items():
            integration.set_setting(key, value)

        return integration

    def _delegated_patron_identifier(
            self, library_uri=None, patron_identifier=None,
            identifier_type=DelegatedPatronIdentifier.ADOBE_ACCOUNT_ID,
            identifier=None
    ):
        """Create a sample DelegatedPatronIdentifier"""
        library_uri = library_uri or self._url
        patron_identifier = patron_identifier or self._str
        if callable(identifier):
            make_id = identifier
        else:
            if not identifier:
                identifier = self._str
            def make_id():
                return identifier
        patron, is_new = DelegatedPatronIdentifier.get_one_or_create(
            self._db, library_uri, patron_identifier, identifier_type,
            make_id
        )
        return patron

    def _sample_ecosystem(self):
        """ Creates an ecosystem of some sample work, pool, edition, and author
        objects that all know each other.
        """
        # make some authors
        [bob], ignore = Contributor.lookup(self._db, u"Bitshifter, Bob")
        bob.family_name, bob.display_name = bob.default_names()
        [alice], ignore = Contributor.lookup(self._db, u"Adder, Alice")
        alice.family_name, alice.display_name = alice.default_names()

        edition_std_ebooks, pool_std_ebooks = self._edition(DataSource.STANDARD_EBOOKS, Identifier.URI,
            with_license_pool=True, with_open_access_download=True, authors=[])
        edition_std_ebooks.title = u"The Standard Ebooks Title"
        edition_std_ebooks.subtitle = u"The Standard Ebooks Subtitle"
        edition_std_ebooks.add_contributor(alice, Contributor.AUTHOR_ROLE)

        edition_git, pool_git = self._edition(DataSource.PROJECT_GITENBERG, Identifier.GUTENBERG_ID,
            with_license_pool=True, with_open_access_download=True, authors=[])
        edition_git.title = u"The GItenberg Title"
        edition_git.subtitle = u"The GItenberg Subtitle"
        edition_git.add_contributor(bob, Contributor.AUTHOR_ROLE)
        edition_git.add_contributor(alice, Contributor.AUTHOR_ROLE)

        edition_gut, pool_gut = self._edition(DataSource.GUTENBERG, Identifier.GUTENBERG_ID,
            with_license_pool=True, with_open_access_download=True, authors=[])
        edition_gut.title = u"The GUtenberg Title"
        edition_gut.subtitle = u"The GUtenberg Subtitle"
        edition_gut.add_contributor(bob, Contributor.AUTHOR_ROLE)

        work = self._work(presentation_edition=edition_git)

        for p in pool_gut, pool_std_ebooks:
            work.license_pools.append(p)

        work.calculate_presentation()

        return (work, pool_std_ebooks, pool_git, pool_gut,
            edition_std_ebooks, edition_git, edition_gut, alice, bob)


    def print_database_instance(self):
        """
        Calls the class method that examines the current state of the database model
        (whether it's been committed or not).

        NOTE:  If you set_trace, and hit "continue", you'll start seeing console output right
        away, without waiting for the whole test to run and the standard output section to display.
        You can also use nosetest --nocapture.
        I use:
        def test_name(self):
            [code...]
            set_trace()
            self.print_database_instance()  # TODO: remove before prod
            [code...]
        """
        if not 'TESTING' in os.environ:
            # we are on production, abort, abort!
            logging.warn("Forgot to remove call to testing.py:DatabaseTest.print_database_instance() before pushing to production.")
            return

        DatabaseTest.print_database_class(self._db)
        return


    @classmethod
    def print_database_class(cls, db_connection):
        """
        Prints to the console the entire contents of the database, as the unit test sees it.
        Exists because unit tests don't persist db information, they create a memory
        representation of the db state, and then roll the unit test-derived transactions back.
        So we cannot see what's going on by going into postgres and running selects.
        This is the in-test alternative to going into postgres.

        Can be called from model and metadata classes as well as tests.

        NOTE: The purpose of this method is for debugging.
        Be careful of leaving it in code and potentially outputting
        vast tracts of data into your output stream on production.

        Call like this:
        set_trace()
        from testing import (
            DatabaseTest,
        )
        _db = Session.object_session(self)
        DatabaseTest.print_database_class(_db)  # TODO: remove before prod
        """
        if not 'TESTING' in os.environ:
            # we are on production, abort, abort!
            logging.warn("Forgot to remove call to testing.py:DatabaseTest.print_database_class() before pushing to production.")
            return

        works = db_connection.query(Work).all()
        identifiers = db_connection.query(Identifier).all()
        license_pools = db_connection.query(LicensePool).all()
        editions = db_connection.query(Edition).all()
        data_sources = db_connection.query(DataSource).all()
        representations = db_connection.query(Representation).all()

        if (not works):
            print "NO Work found"
        for wCount, work in enumerate(works):
            # pipe character at end of line helps see whitespace issues
            print "Work[%s]=%s|" % (wCount, work)

            if (not work.license_pools):
                print "    NO Work.LicensePool found"
            for lpCount, license_pool in enumerate(work.license_pools):
                print "    Work.LicensePool[%s]=%s|" % (lpCount, license_pool)

            print "    Work.presentation_edition=%s|" % work.presentation_edition

        print "__________________________________________________________________\n"
        if (not identifiers):
            print "NO Identifier found"
        for iCount, identifier in enumerate(identifiers):
            print "Identifier[%s]=%s|" % (iCount, identifier)
            print "    Identifier.licensed_through=%s|" % identifier.licensed_through

        print "__________________________________________________________________\n"
        if (not license_pools):
            print "NO LicensePool found"
        for index, license_pool in enumerate(license_pools):
            print "LicensePool[%s]=%s|" % (index, license_pool)
            print "    LicensePool.work_id=%s|" % license_pool.work_id
            print "    LicensePool.data_source_id=%s|" % license_pool.data_source_id
            print "    LicensePool.identifier_id=%s|" % license_pool.identifier_id
            print "    LicensePool.presentation_edition_id=%s|" % license_pool.presentation_edition_id
            print "    LicensePool.superceded=%s|" % license_pool.superceded
            print "    LicensePool.suppressed=%s|" % license_pool.suppressed

        print "__________________________________________________________________\n"
        if (not editions):
            print "NO Edition found"
        for index, edition in enumerate(editions):
            # pipe character at end of line helps see whitespace issues
            print "Edition[%s]=%s|" % (index, edition)
            print "    Edition.primary_identifier_id=%s|" % edition.primary_identifier_id
            print "    Edition.permanent_work_id=%s|" % edition.permanent_work_id
            if (edition.data_source):
                print "    Edition.data_source.id=%s|" % edition.data_source.id
                print "    Edition.data_source.name=%s|" % edition.data_source.name
            else:
                print "    No Edition.data_source."
            if (edition.license_pool):
                print "    Edition.license_pool.id=%s|" % edition.license_pool.id
            else:
                print "    No Edition.license_pool."

            print "    Edition.title=%s|" % edition.title
            print "    Edition.author=%s|" % edition.author
            if (not edition.author_contributors):
                print "    NO Edition.author_contributor found"
            for acCount, author_contributor in enumerate(edition.author_contributors):
                print "    Edition.author_contributor[%s]=%s|" % (acCount, author_contributor)

        print "__________________________________________________________________\n"
        if (not data_sources):
            print "NO DataSource found"
        for index, data_source in enumerate(data_sources):
            print "DataSource[%s]=%s|" % (index, data_source)
            print "    DataSource.id=%s|" % data_source.id
            print "    DataSource.name=%s|" % data_source.name
            print "    DataSource.offers_licenses=%s|" % data_source.offers_licenses
            print "    DataSource.editions=%s|" % data_source.editions
            print "    DataSource.license_pools=%s|" % data_source.license_pools
            print "    DataSource.links=%s|" % data_source.links

        print "__________________________________________________________________\n"
        if (not representations):
            print "NO Representation found"
        for index, representation in enumerate(representations):
            print "Representation[%s]=%s|" % (index, representation)
            print "    Representation.id=%s|" % representation.id
            print "    Representation.url=%s|" % representation.url
            print "    Representation.mirror_url=%s|" % representation.mirror_url
            print "    Representation.fetch_exception=%s|" % representation.fetch_exception
            print "    Representation.mirror_exception=%s|" % representation.mirror_exception

        return


    def _library(self, name=None, short_name=None):
        name=name or self._str
        short_name = short_name or self._str
        library, ignore = get_one_or_create(
            self._db, Library, name=name, short_name=short_name,
            create_method_kwargs=dict(uuid=str(uuid.uuid4())),
        )
        return library

    def _collection(self, name=None, protocol=ExternalIntegration.OPDS_IMPORT,
                    external_account_id=None, url=None, username=None,
                    password=None, data_source_name=None):
        name = name or self._str
        collection, ignore = get_one_or_create(
            self._db, Collection, name=name
        )
        collection.external_account_id = external_account_id
        integration = collection.create_external_integration(protocol)
        integration.goal = ExternalIntegration.LICENSE_GOAL
        integration.url = url
        integration.username = username
        integration.password = password

        if data_source_name:
            collection.data_source = data_source_name
        return collection

    @property
    def _default_library(self):
        """A Library that will only be created once throughout a given test.

        By default, the `_default_collection` will be associated with
        the default library.
        """
        if not hasattr(self, '_default__library'):
            self._default__library = self.make_default_library(self._db)
        return self._default__library

    @property
    def _default_collection(self):
        """A Collection that will only be created once throughout
        a given test.

        For most tests there's no need to create a different
        Collection for every LicensePool. Using
        self._default_collection instead of calling self.collection()
        saves time.
        """
        if not hasattr(self, '_default__collection'):
            self._default__collection = self._default_library.collections[0]
        return self._default__collection

    @classmethod
    def make_default_library(cls, _db):
        """Ensure that the default library exists in the given database.

        This can be called by code intended for use in testing but not actually
        within a DatabaseTest subclass.
        """
        library, ignore = get_one_or_create(
            _db, Library, create_method_kwargs=dict(
                uuid=unicode(uuid.uuid4()),
                name="default",
            ), short_name="default"
        )
        collection, ignore = get_one_or_create(
            _db, Collection, name="Default Collection"
        )
        integration = collection.create_external_integration(
            ExternalIntegration.OPDS_IMPORT
        )
        integration.goal = ExternalIntegration.LICENSE_GOAL
        if collection not in library.collections:
            library.collections.append(collection)
        return library

    def _catalog(self, name=u"Faketown Public Library"):
        source, ignore = get_one_or_create(self._db, DataSource, name=name)

    def _integration_client(self, url=None, shared_secret=None):
        url = url or self._url
        secret = shared_secret or u"secret"
        return get_one_or_create(
            self._db, IntegrationClient, shared_secret=secret,
            create_method_kwargs=dict(url=url)
        )[0]

    def _subject(self, type, identifier):
        return get_one_or_create(
            self._db, Subject, type=type, identifier=identifier
        )[0]

    def _classification(self, identifier, subject, data_source, weight=1):
        return get_one_or_create(
            self._db, Classification, identifier=identifier, subject=subject,
            data_source=data_source, weight=weight
        )[0]

    def sample_cover_path(self, name):
        """The path to the sample cover with the given filename."""
        base_path = os.path.split(__file__)[0]
        resource_path = os.path.join(base_path, "tests", "files", "covers")
        sample_cover_path = os.path.join(resource_path, name)
        return sample_cover_path

    def sample_cover_representation(self, name):
        """A Representation of the sample cover with the given filename."""
        sample_cover_path = self.sample_cover_path(name)
        return self._representation(
            media_type="image/png", content=open(sample_cover_path).read())[0]
Code example #48
File: db.py  Project: kushsharma/airflow
def merge_conn(conn, session: Session = NEW_SESSION):
    """Add new Connection."""
    if not session.query(Connection).filter(
            Connection.conn_id == conn.conn_id).first():
        session.add(conn)
        session.commit()
Code example #49
class testTradingCenter(unittest.TestCase):

    def setUp(self):
        self.trans = connection.begin()
        self.session = Session(connection)

        currency = Currency(name='Pesos', code='ARG')
        self.exchange = Exchange(name='Merval', code='MERV', currency=currency)
        self.owner = Owner(name='poor owner')
        self.broker = Broker(name='broker1')
        self.account = Account(owner=self.owner, broker=self.broker)
        self.account.deposit(Money(amount=10000, currency=currency))

    def tearDown(self):
        self.trans.rollback()
        self.session.close()

    def test_open_orders_by_order_id(self):
        stock=Stock(symbol='symbol', description='a stock', ISIN='US123456789', exchange=self.exchange)
        order1=BuyOrder(account=self.account, security=stock, price=13.2, share=10)
        order2=BuyOrder(account=self.account, security=stock, price=13.25, share=10)
        self.session.add(order1)
        self.session.add(order2)
        self.session.commit()

        tc=TradingCenter(self.session)
        order=tc.open_order_by_id(order1.id)
        self.assertEquals(order1, order)

        order=tc.open_order_by_id(100)
        self.assertEquals(None, order)

    def testGetOpenOrdersBySymbol(self):

        stock=Stock(symbol='symbol', description='a stock', ISIN='US123456789', exchange=self.exchange)
        order1=BuyOrder(account=self.account, security=stock, price=13.2, share=10)
        order2=BuyOrder(account=self.account, security=stock, price=13.25, share=10)
        self.session.add(order1)
        self.session.add(order2)
        self.session.commit()

        tc=TradingCenter(self.session)
        orders=tc.open_orders_by_symbol('symbol')
        self.assertEquals([order1, order2], list(orders))

    def testCancelOrder(self):

        stock=Stock(symbol='symbol', description='a stock', ISIN='US123456789', exchange=self.exchange)
        order1=BuyOrder(account=self.account, security=stock, price=13.2, share=10)
        order2=BuyOrder(account=self.account, security=stock, price=13.25, share=10)

        self.session.add(order1)
        self.session.add(order2)
        self.session.commit()

        tc=TradingCenter(self.session)

        order1.cancel()
        self.assertEquals([order2], tc.open_orders)
        self.assertEquals([order1], tc.cancel_orders)
        self.assertEquals(CancelOrderStage, type(order1.current_stage))

        order2.cancel()
        self.assertEquals([], tc.open_orders)
        self.assertEquals([order1, order2], tc.cancel_orders)

    def testCancelAllOpenOrders(self):
        security=Stock(symbol='symbol', description='a stock', ISIN='US123456789', exchange=self.exchange)
        order1=BuyOrder(account=self.account, security=security, price=13.2, share=10)
        order2=BuyOrder(account=self.account, security=security, price=13.25, share=10)

        self.session.add(order1)
        self.session.add(order2)
        self.session.commit()

        tc=TradingCenter(self.session)

        tc.cancel_all_open_orders()

        self.assertEquals([], tc.open_orders)

    def testConsume(self):
        pass

    def testPostConsume(self):
        pass

    def testCreateAccountWithMetrix(self):
        pass
Code example #50
File: session.py  Project: BigData-Tools/everest
    def add(self, entity):
        self.begin()
        SaSession.add(self, entity)
        self.commit()
Code example #51
    # Database connection
    engine = core.get_database_engine_string()
    logging.info("Using connection string '%s'" % (engine,))
    engine = create_engine(engine, encoding='utf-8', isolation_level = 'READ COMMITTED')
    logging.info("Binding session...")
    session = Session(bind=engine, autocommit = False)

    #
    # Initial query object

    it = session.query(UserQuery).filter_by(text = query_text)
    try:
        q = it.one()
    except NoResultFound:
        q = UserQuery(query_text)
        session.add(q)
        session.commit()
    assert q.id is not None 
    logging.info("Resolving query id %d with text '%s'", q.id, q.text)

    #
    # Query keyword resolution
    keywords = set([])
    using_keywords = False 
    raw_keyword_count = 0
    for k in q.get_keywords(q.text):
        logging.debug("keyword: %s", k)
        using_keywords = True 
        raw_keyword_count += 1
        it = session.query(Keyword).filter(Keyword.word.like("%s" % (k,)))
        keywords.update(it)
Code example #52
File: service.py  Project: LostFan123/alcor
def draw(*, group_id: uuid.UUID, filtration_method: str,
         nullify_radial_velocity: bool, with_luminosity_function: bool,
         with_velocities_vs_magnitude: bool, with_velocity_clouds: bool,
         lepine_criterion: bool, heatmaps_axes: str, with_toomre_diagram: bool,
         with_ugriz_diagrams: bool, desired_stars_count: int,
         session: Session) -> None:
    entities = star_query_entities(
        filtration_method=filtration_method,
        nullify_radial_velocity=nullify_radial_velocity,
        lepine_criterion=lepine_criterion,
        with_luminosity_function=with_luminosity_function,
        with_velocities_vs_magnitude=with_velocities_vs_magnitude,
        with_velocity_clouds=with_velocity_clouds,
        heatmaps_axes=heatmaps_axes,
        with_toomre_diagram=with_toomre_diagram,
        with_ugriz_diagrams=with_ugriz_diagrams)

    if not entities:
        raise ValueError('No plotting options were chosen')

    query = (session.query(Star).filter(Star.group_id == group_id))

    if desired_stars_count:
        query = (query.order_by(func.random()).limit(desired_stars_count))

    query = query.with_entities(*entities)

    statement = query.statement

    stars = pd.read_sql_query(sql=statement,
                              con=session.get_bind(),
                              index_col='id')

    filtration_functions = stars_filtration_functions(method=filtration_method)
    eliminations_counter = stars_eliminations_counter(
        stars, filtration_functions=filtration_functions, group_id=group_id)
    session.add(eliminations_counter)
    session.commit()

    stars = filtered_stars(stars, filtration_functions=filtration_functions)

    if nullify_radial_velocity:
        set_radial_velocity_to_zero(stars)

    if with_luminosity_function:
        luminosity_function.plot(stars=stars)

    if with_velocities_vs_magnitude:
        if lepine_criterion:
            velocities_vs_magnitude.plot_lepine_case(stars=stars)
        else:
            velocities_vs_magnitude.plot(stars=stars)

    if with_velocity_clouds:
        if lepine_criterion:
            velocity_clouds.plot_lepine_case(stars=stars)
        else:
            velocity_clouds.plot(stars=stars)

    if heatmaps_axes:
        heatmaps.plot(stars=stars, axes=heatmaps_axes)

    if with_toomre_diagram:
        toomre_diagram.plot(stars=stars)

    if with_ugriz_diagrams:
        ugriz_diagrams.plot(stars=stars)
Code example #53
File: session.py  Project: GEverding/inbox
class InboxSession(object):
    """ Inbox custom ORM (with SQLAlchemy compatible API).

    Parameters
    ----------
    engine : <sqlalchemy.engine.Engine>
        A configured database engine to use for this session
    versioned : bool
        Do you want to enable the transaction log?
    ignore_soft_deletes : bool
        Whether or not to ignore soft-deleted objects in query results.
    namespace_id : int
        Namespace to limit query results with.
    """
    def __init__(self, engine, versioned=True, ignore_soft_deletes=True,
                 namespace_id=None):
        # TODO: support limiting on namespaces
        assert engine, "Must set the database engine"

        args = dict(bind=engine, autoflush=True, autocommit=False)
        self.ignore_soft_deletes = ignore_soft_deletes
        if ignore_soft_deletes:
            args['query_cls'] = InboxQuery
        self._session = Session(**args)

        if versioned:
            from inbox.models.transaction import create_revisions

            @event.listens_for(self._session, 'after_flush')
            def after_flush(session, flush_context):
                """
                Hook to log revision snapshots. Must be post-flush in order to
                grab object IDs on new objects.
                """
                create_revisions(session)

    def query(self, *args, **kwargs):
        q = self._session.query(*args, **kwargs)
        if self.ignore_soft_deletes:
            return q.options(IgnoreSoftDeletesOption())
        else:
            return q

    def add(self, instance):
        if not self.ignore_soft_deletes or not instance.is_deleted:
            self._session.add(instance)
        else:
            raise Exception("Why are you adding a deleted object?")

    def add_all(self, instances):
        if True not in [i.is_deleted for i in instances] or \
                not self.ignore_soft_deletes:
            self._session.add_all(instances)
        else:
            raise Exception("Why are you adding a deleted object?")

    def delete(self, instance):
        if self.ignore_soft_deletes:
            instance.mark_deleted()
            # just to make sure
            self._session.add(instance)
        else:
            self._session.delete(instance)

    def begin(self):
        self._session.begin()

    def commit(self):
        self._session.commit()

    def rollback(self):
        self._session.rollback()

    def flush(self):
        self._session.flush()

    def close(self):
        self._session.close()

    def expunge(self, obj):
        self._session.expunge(obj)

    def merge(self, obj):
        return self._session.merge(obj)

    @property
    def no_autoflush(self):
        return self._session.no_autoflush
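
As a hedged illustration of the wrapper's soft-delete semantics (the Message model, engine URL and variable names below are placeholders, not taken from the project):

# Illustrative only: a hypothetical Message model and a local test database.
from sqlalchemy import create_engine

engine = create_engine('mysql+pymysql://inbox:inbox@localhost/inbox_test')
db_session = InboxSession(engine, versioned=True, ignore_soft_deletes=True)

message = db_session.query(Message).first()   # soft-deleted rows are filtered out
if message is not None:
    db_session.delete(message)                # marks it deleted instead of issuing DELETE
    db_session.commit()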
Code example #54
    def call_api(self, api_client: API, session: Session) -> str:
        '''
        Override ApiJob.

        Encrypt the reply and send it to the server. If the call is successful, add the reply to
        the local database and return its uuid string. Otherwise raise a SendReplyJobException
        that carries the reply uuid, so the caller can still identify which reply failed.
        '''

        try:
            # If the reply has already made it to the server but we didn't get a 201 response back,
            # then a reply with self.reply_uuid will exist in the replies table.
            reply_db_object = session.query(Reply).filter_by(
                uuid=self.reply_uuid).one_or_none()
            if reply_db_object:
                logger.debug(
                    'Reply {} has already been sent successfully'.format(
                        self.reply_uuid))
                return reply_db_object.uuid

            # If the draft does not exist because it was deleted locally then do not send the
            # message to the source.
            draft_reply_db_object = session.query(DraftReply).filter_by(
                uuid=self.reply_uuid).one_or_none()
            if not draft_reply_db_object:
                raise Exception('Draft reply {} does not exist'.format(
                    self.reply_uuid))

            # If the source was deleted locally then do not send the message and delete the draft.
            source = session.query(Source).filter_by(
                uuid=self.source_uuid).one_or_none()
            if not source:
                session.delete(draft_reply_db_object)
                session.commit()
                raise Exception('Source {} does not exist'.format(
                    self.source_uuid))

            # Send the draft reply to the source
            encrypted_reply = self.gpg.encrypt_to_source(
                self.source_uuid, self.message)
            sdk_reply = self._make_call(encrypted_reply, api_client)

            # Create a new reply object with an updated filename and file counter
            interaction_count = source.interaction_count + 1
            filename = '{}-{}-reply.gpg'.format(interaction_count,
                                                source.journalist_designation)
            reply_db_object = Reply(
                uuid=self.reply_uuid,
                source_id=source.id,
                filename=filename,
                journalist_id=api_client.token_journalist_uuid,
                content=self.message,
                is_downloaded=True,
                is_decrypted=True)
            new_file_counter = int(sdk_reply.filename.split('-')[0])
            reply_db_object.file_counter = new_file_counter
            reply_db_object.filename = sdk_reply.filename

            # Update following draft replies for the same source to reflect the new reply count
            draft_file_counter = draft_reply_db_object.file_counter
            draft_timestamp = draft_reply_db_object.timestamp
            update_draft_replies(session, source.id, draft_timestamp,
                                 draft_file_counter, new_file_counter)

            # Add reply to replies table and increase the source interaction count by 1 and delete
            # the draft reply.
            session.add(reply_db_object)
            source.interaction_count += 1
            session.add(source)
            session.delete(draft_reply_db_object)
            session.commit()

            return reply_db_object.uuid
        except (RequestTimeoutError, ServerConnectionError) as e:
            message = "Failed to send reply for source {id} due to Exception: {error}".format(
                id=self.source_uuid, error=e)
            raise SendReplyJobTimeoutError(message, self.reply_uuid)
        except Exception as e:
            # Continue to store the draft reply
            message = '''
                Failed to send reply {uuid} for source {id} due to Exception: {error}
            '''.format(uuid=self.reply_uuid, id=self.source_uuid, error=e)
            self._set_status_to_failed(session)
            raise SendReplyJobError(message, self.reply_uuid)
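
The method above is built around an idempotency check: before making the remote call it looks up a Reply row with its own uuid, so a retried job never sends the same reply twice, and it only commits the new row after the call succeeds. A self-contained sketch of that pattern (assuming SQLAlchemy 1.4+), with a generic Record model and a send() placeholder standing in for the SecureDrop-specific types:

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Record(Base):
    __tablename__ = 'records'
    id = Column(Integer, primary_key=True)
    uuid = Column(String(36), unique=True, nullable=False)
    content = Column(String)


def send(payload):
    # Placeholder for the remote API call; may raise on network failure.
    pass


def run_once(session, record_uuid, payload):
    # If a previous attempt already persisted this record, skip the remote call.
    existing = session.query(Record).filter_by(uuid=record_uuid).one_or_none()
    if existing:
        return existing.uuid
    send(payload)  # a retry after a failure here will hit the check above
    record = Record(uuid=record_uuid, content=payload)
    session.add(record)
    session.commit()
    return record.uuid


engine = create_engine('sqlite://')
Base.metadata.create_all(engine)
with Session(engine) as session:
    print(run_once(session, 'abc-123', 'hello'))  # sends and stores the record
    print(run_once(session, 'abc-123', 'hello'))  # idempotent: no second send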
Code example #55
File: session.py Project: helixyte/everest
    def __add(self, entity, path):  # pylint: disable=W0613
        if len(path) == 0:
            SaSession.add(self, entity)
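
A self-contained sketch of the same idea: a Session subclass that only forwards to SQLAlchemy's add() when the traversal path is empty, on the assumption that nested entities are persisted through relationship cascades. RootOnlySession and add_at_path are made-up names, not part of everest:

from sqlalchemy.orm import Session as SaSession


class RootOnlySession(SaSession):
    def add_at_path(self, entity, path=()):
        # Only root entities (empty path) are added explicitly; entities reached
        # via a non-empty path are assumed to be handled by cascades.
        if len(path) == 0:
            SaSession.add(self, entity)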
Code example #56
File: manager.py Project: BtbN/esaupload
def add_video(session: Session, args):
    vid = session.query(Video).filter(Video.title == args.title).first()
    if vid:
        print("Video with same title already exists, checking status...")
        if not args.force and (vid.youtube_status or vid.twitch_status or vid.processing_status):
            print("Video processing has already started, aborting update.")
            return -1
        else:
            print("Updating existing video")
    else:
        vid = Video()
        session.add(vid)

    vid.date_added = datetime.datetime.now(datetime.timezone.utc)

    vid.title = args.title
    vid.description = args.desc

    vid.start_segment = int(args.start)
    vid.end_segment = int(args.end)

    diff = vid.end_segment - vid.start_segment
    if (diff < 0 or diff > 24 * 60 * 60) and not args.force:
        print("Video would be over 24 hours long, aborting!")
        return -1

    kws = str(get_setting(session, "common_keywords")).split(",")
    kws.extend(str(args.keywords).split(","))
    vid.keywords = ",".join(kws)

    if get_setting(session, "upload_to_twitch"):
        vid.do_twitch = True
    else:
        vid.do_twitch = False

    if get_setting(session, "upload_to_youtube"):
        vid.do_youtube = True
    else:
        vid.do_youtube = False

    if get_setting(session, "youtube_public"):
        vid.youtube_public = True
    else:
        vid.youtube_public = False

    # TODO: Potentially kill any running upload tasks
    vid.youtube_status = ""
    vid.twitch_status = ""
    vid.processing_status = str(get_setting(session, "init_proc_status"))
    vid.processing_done = False
    vid.done_twitch = False
    vid.done_youtube = False

    if args.game:
        vid.game = args.game
    else:
        vid.game = None

    if vid.do_youtube and not vid.youtube_pubdate and get_setting(session, "youtube_schedule"):
        last_pubdate = get_last_publish_date(session)
        offset = get_setting(session, "schedule_offset_hours")
        vid.youtube_pubdate = last_pubdate + datetime.timedelta(hours=offset)

    session.commit()

    print("Added as Video %s" % vid.id)

    return 0
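
add_video() expects an argparse-style namespace. A minimal invocation sketch, assuming the project's Video/Setting models and the get_setting()/get_last_publish_date() helpers are importable and the tables already exist; the database URL and argument values below are placeholders:

import argparse

from sqlalchemy import create_engine
from sqlalchemy.orm import Session

parser = argparse.ArgumentParser()
parser.add_argument("--title", required=True)
parser.add_argument("--desc", default="")
parser.add_argument("--start", required=True)
parser.add_argument("--end", required=True)
parser.add_argument("--keywords", default="")
parser.add_argument("--game", default=None)
parser.add_argument("--force", action="store_true")
args = parser.parse_args(["--title", "Run 42", "--start", "0", "--end", "3600"])

engine = create_engine("sqlite:///esaupload.db")  # placeholder URL
with Session(engine) as session:
    exit_code = add_video(session, args)
print("add_video returned %d" % exit_code)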