コード例 #1
0
def test_table_model(session: Session) -> None:
    """
    Check that a flushed ``Table`` exposes the attributes it was built from.
    """
    from superset.columns.models import Column
    from superset.models.core import Database
    from superset.tables.models import Table

    # Create the schema for the models' metadata on the session's bind.
    Table.metadata.create_all(session.get_bind())  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="test://")
    ds_column = Column(name="ds", type="TIMESTAMP", expression="ds")
    my_table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=database,
        columns=[ds_column],
    )
    session.add(my_table)
    # Flush so that primary keys and other generated values are populated.
    session.flush()

    assert my_table.id == 1
    assert my_table.uuid is not None
    assert my_table.database_id == 1
    assert my_table.catalog == "my_catalog"
    assert my_table.schema == "my_schema"
    assert my_table.name == "my_table"

    column_names = [col.name for col in my_table.columns]
    assert column_names == ["ds"]
コード例 #2
0
def migrate_posted_tweets(session: Session, cur, table_posted: str):
    """Copy rows of the legacy posted-tweets table into ``PostedTweet`` objects.

    A row is skipped when a ``PostedTweet`` with the same text already
    exists. The session is committed once, after every row has been read.

    :param session: session the new ``PostedTweet`` objects are added to.
    :param cur: cursor used to read the legacy table.
    :param table_posted: name of the legacy table; it is interpolated into
        the SQL text as-is.
    """
    started = time.time()
    logger.info('Migrating %s', table_posted)
    rows_read = 0
    rows_added = 0
    cur.execute('SELECT url, line, parsed, status, edited, tweet, inserted '
                f'FROM {table_posted}')
    for row in cur:
        url, line, parsed, status_str, edited, tweet, inserted = row
        rows_read += 1
        logger.debug('%d %s', rows_read, url)

        same_text = session.query(PostedTweet).filter(
            PostedTweet.text == tweet)
        if count(same_text):
            # A tweet with this text was already migrated.
            continue

        session.add(
            PostedTweet(
                url=url,
                line=line,
                parsed=parsed,
                status=convert_status(status_str),
                edited=edited,
                text=tweet,
                inserted=inserted,
            ))
        # Periodic flush; only reached when the row that lands on a multiple
        # of 10000 was actually added.
        if rows_read % 10000 == 0:
            logger.info('%d flush', rows_read)
            session.flush()
        rows_added += 1

    logger.info('commit')
    session.commit()
    logger.info(
        'Migrated %s: %d -> %d in %ds',
        table_posted,
        rows_read,
        rows_added,
        time.time() - started,
    )
コード例 #3
0
def migrate_pages(session: Session, cur, table_pages: str):
    """Copy rows of the legacy pages table into ``Page`` objects.

    A row is skipped when a ``Page`` with the same url already exists. The
    session is committed once, after every row has been read.

    :param session: session the new ``Page`` objects are added to.
    :param cur: cursor used to read the legacy table.
    :param table_pages: name of the legacy table; it is interpolated into the
        SQL text as-is.
    """
    started = time.time()
    logger.info('Migrating %s', table_pages)
    rows_read = 0
    rows_added = 0
    cur.execute(f'SELECT url, text, inserted FROM {table_pages}')
    for row in cur:
        url, text, inserted = row
        rows_read += 1
        logger.debug('%d %s', rows_read, url)

        same_url = session.query(Page).filter(Page.url == url)
        if count(same_url):
            # This url was already migrated.
            continue

        session.add(Page(url=url, text=text, inserted=inserted))
        # Periodic flush; only reached when the row that lands on a multiple
        # of 10000 was actually added.
        if rows_read % 10000 == 0:
            logger.info('%d flush', rows_read)
            session.flush()
        rows_added += 1

    logger.info('commit')
    session.commit()
    logger.info(
        'Migrated %s: %d -> %d in %ds',
        table_pages,
        rows_read,
        rows_added,
        time.time() - started,
    )
コード例 #4
0
class TestSessions(TestCase):
    # Session-level versioning checks. The assertions expect transaction ids
    # to start at 1 for each test.
    plugins = []

    def test_multiple_connections(self):
        # Two sessions on separate connections: the second commit is expected
        # to receive the next transaction id.
        self.session2 = Session(bind=self.engine.connect())
        first_article = self.Article(name=u'Session1 article')
        second_article = self.Article(name=u'Session2 article')
        self.session.add(first_article)
        self.session2.add(second_article)

        sessions = (self.session, self.session2)
        for active_session in sessions:
            active_session.flush()
        for active_session in sessions:
            active_session.commit()

        first_version = first_article.versions[-1]
        second_version = second_article.versions[-1]
        assert first_version.transaction_id == 1
        assert second_version.transaction_id == 2

    def test_manual_transaction_creation(self):
        # A transaction created by hand through the unit of work is reused
        # for the version written afterwards.
        unit_of_work = versioning_manager.unit_of_work(self.session)
        manual_transaction = unit_of_work.create_transaction(self.session)
        self.session.flush()
        assert manual_transaction.id == 1

        new_article = self.Article(name=u'Session1 article')
        self.session.add(new_article)
        self.session.flush()
        assert unit_of_work.current_transaction.id == 1

        self.session.commit()
        latest_version = new_article.versions[-1]
        assert latest_version.transaction_id == 1

    def test_commit_without_objects(self):
        # Committing an empty session must not raise.
        self.session.commit()
コード例 #5
0
ファイル: processor.py プロジェクト: abhinavkumar195/airflow
    def execute_callbacks(self,
                          dagbag: DagBag,
                          callback_requests: List[CallbackRequest],
                          session: Session = NEW_SESSION) -> None:
        """
        Run every callback request, dispatching on the request's concrete type.

        ``TaskCallbackRequest`` goes to ``_execute_task_callbacks``,
        ``SlaCallbackRequest`` to ``manage_slas`` and ``DagCallbackRequest`` to
        ``_execute_dag_callbacks``; any other type is ignored. An exception
        raised while handling one request is logged and the remaining requests
        are still processed. The session is flushed once at the end.

        :param dagbag: Dag Bag of dags
        :param callback_requests: callback requests to execute
        :param session: DB session.
        """
        for callback in callback_requests:
            self.log.debug("Processing Callback Request: %s", callback)
            try:
                if isinstance(callback, TaskCallbackRequest):
                    self._execute_task_callbacks(dagbag, callback, session=session)
                elif isinstance(callback, SlaCallbackRequest):
                    sla_dag = dagbag.get_dag(callback.dag_id)
                    self.manage_slas(sla_dag, session=session)
                elif isinstance(callback, DagCallbackRequest):
                    self._execute_dag_callbacks(dagbag, callback, session)
            except Exception:
                # Deliberately broad: one failing callback must not prevent
                # the others from running.
                self.log.exception(
                    "Error executing %s callback for file: %s",
                    callback.__class__.__name__,
                    callback.full_filepath,
                )

        session.flush()
コード例 #6
0
def test_auto_slug_property(session: Session) -> None:
    """``auto_slug`` follows ``name`` and only uses HYPHEN-MINUS as separator."""
    contact = DummyContact(name="a b c")
    session.add(contact)
    session.flush()
    assert contact.auto_slug == "a-b-c"

    # Accented letters and punctuation.
    # pyre-fixme[8]: Attribute has type `Column`; used as `str`.
    contact.name = "C'est l'été !"
    assert contact.auto_slug == "c-est-l-ete"

    # An underscore and a special space character.
    # pyre-fixme[8]: Attribute has type `Column`; used as `str`.
    contact.name = "a_b\u205fc"  # U+205F: MEDIUM MATHEMATICAL SPACE
    assert contact.auto_slug == "a-b-c"

    # Non-ASCII dashes: EN DASH U+2013 (–) and EM DASH U+2014 (—) must end up
    # as the standard separator \u002d (\x2d) "-" HYPHEN-MINUS. This check may
    # fail depending on the order in which Unicode normalization and character
    # substitution are applied.
    # pyre-fixme[8]: Attribute has type `Column`; used as `str`.
    contact.name = "a\u2013b\u2014c"  # u'a–b—c'
    generated = contact.auto_slug
    assert generated == "a-b-c"
    assert "\u2013" not in generated
    assert "\u002d" in generated
コード例 #7
0
def find_user(session: Session,
              platform: int,
              user_id: int,
              nickname: str = '') -> User:
    """Find a user by platform id, creating the user when none exists.

    o-what is not supported yet.
    ### Args:
    ``session``: SQLAlchemy session used to talk to the database.\n
    ``platform``: platform to search; ``1`` is Modian, ``2`` is Taoba.\n
    ``user_id``: the user's id on that platform.\n
    ``nickname``: nickname to refresh; leave empty when no refresh is needed.\n
    ### Result:
    ``user``: the user that was found or newly created.\n
    ### Raises:
    ``ValueError``: ``platform`` is not one of the supported platforms.\n
    """
    # Resolve the platform first so that an unsupported value fails with a
    # clear error instead of an UnboundLocalError on ``result`` further down.
    if platform == 1:  # Modian
        id_field = 'modian_id'
    elif platform == 2:  # Taoba
        id_field = 'taoba_id'
    else:
        raise ValueError('unsupported platform: {!r}'.format(platform))

    try:
        result = session.query(User).\
                         filter(getattr(User, id_field) == user_id).one()
    except NoResultFound:
        # Unknown user: create it with the platform id that was searched for.
        result = User(nickname=nickname, **{id_field: user_id})
        session.add(result)
        session.flush()
        logger.debug('添加用户%s', str(result))
        return result

    # The nickname is only refreshed for users that have no qq_id.
    if result.qq_id is None and nickname and result.nickname != nickname:
        logger.debug('用户%s的昵称变为%s', result.nickname, nickname)
        result.nickname = nickname
        session.flush()
    return result
コード例 #8
0
class TestSessions(TestCase):
    # Session-level versioning checks that only rely on transaction ids being
    # set and increasing, not on their absolute values.
    plugins = []

    def test_multiple_connections(self):
        # Two sessions on separate connections: the version committed second
        # must carry the larger transaction id.
        self.session2 = Session(bind=self.engine.connect())
        first_article = self.Article(name=u'Session1 article')
        second_article = self.Article(name=u'Session2 article')
        self.session.add(first_article)
        self.session2.add(second_article)

        sessions = (self.session, self.session2)
        for active_session in sessions:
            active_session.flush()
        for active_session in sessions:
            active_session.commit()

        assert first_article.versions[-1].transaction_id
        assert (second_article.versions[-1].transaction_id >
                first_article.versions[-1].transaction_id)

    def test_manual_transaction_creation(self):
        # A transaction created by hand through the unit of work gets an id
        # and the version written afterwards references a transaction.
        unit_of_work = versioning_manager.unit_of_work(self.session)
        manual_transaction = unit_of_work.create_transaction(self.session)
        self.session.flush()
        assert manual_transaction.id

        new_article = self.Article(name=u'Session1 article')
        self.session.add(new_article)
        self.session.flush()
        assert unit_of_work.current_transaction.id

        self.session.commit()
        assert new_article.versions[-1].transaction_id

    def test_commit_without_objects(self):
        # Committing an empty session must not raise.
        self.session.commit()
コード例 #9
0
def test_quote_expressions(app_context: None, session: Session) -> None:
    """
    Check that the ``Dataset`` found after flushing a ``SqlaTable`` has quoted
    expressions where the table or column name needs quoting.
    """
    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    Dataset.metadata.create_all(session.get_bind())  # pylint: disable=no-member

    # One column name that needs quoting (it has a space) and one that does not.
    table_columns = [
        TableColumn(column_name="has space", type="INTEGER"),
        TableColumn(column_name="no_need", type="INTEGER"),
    ]
    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    legacy_table = SqlaTable(
        table_name="old dataset",
        columns=table_columns,
        metrics=[],
        database=database,
    )
    session.add(legacy_table)
    session.flush()

    new_dataset = session.query(Dataset).one()
    assert new_dataset.expression == '"old dataset"'

    spaced_column, plain_column = new_dataset.columns[0], new_dataset.columns[1]
    assert spaced_column.expression == '"has space"'
    assert plain_column.expression == "no_need"
コード例 #10
0
ファイル: test_models.py プロジェクト: dodopizza/superset
def test_column_model(app_context: None, session: Session) -> None:
    """
    Check the explicit and default attributes of a flushed ``Column``.
    """
    from superset.columns.models import Column

    Column.metadata.create_all(session.get_bind())  # pylint: disable=no-member

    ds_column = Column(name="ds", type="TIMESTAMP", expression="ds")
    session.add(ds_column)
    session.flush()

    # Generated values.
    assert ds_column.id == 1
    assert ds_column.uuid is not None

    # Values passed to the constructor.
    assert ds_column.name == "ds"
    assert ds_column.type == "TIMESTAMP"
    assert ds_column.expression == "ds"

    # Defaults: text fields are unset ...
    assert ds_column.description is None
    assert ds_column.warning_text is None
    assert ds_column.unit is None
    # ... and every flag is off except ``is_increase_desired``.
    assert ds_column.is_temporal is False
    assert ds_column.is_spatial is False
    assert ds_column.is_partition is False
    assert ds_column.is_aggregation is False
    assert ds_column.is_additive is False
    assert ds_column.is_increase_desired is True
コード例 #11
0
    async def update_item_list(self, db: Session, *, order_db: Order,
                               order_update: OrderUpdate) -> Order:
        """Replace the items of an order with the ones in ``order_update``.

        The order's existing ``ItemOrder`` rows are deleted, one new
        ``ItemOrder`` is built per entry of ``order_update.items_in_order``
        and the change is committed.

        :param db: session used for the delete, the commit and the reload.
        :param order_db: order whose item list is replaced.
        :param order_update: payload whose ``items_in_order`` entries provide
            ``tot_price``, ``quantity`` and ``item_id``.
        :return: ``order_db``, reloaded from the database after the commit.
        """
        db.query(ItemOrder).filter(ItemOrder.order_id == order_db.id).delete()

        order_db.items = [
            ItemOrder(
                tot_price=item.tot_price,
                quantity=item.quantity,
                order_id=order_db.id,
                item_id=item.item_id,
            )
            for item in order_update.items_in_order
        ]

        db.add(order_db)
        db.commit()
        # A flush right after the commit has nothing left to write (and its
        # argument is meant to be a collection of objects, not one instance).
        # Reload the committed row instead so the caller gets current values.
        db.refresh(order_db)

        return order_db
コード例 #12
0
ファイル: users.py プロジェクト: PeterChain/fastapi_template
def create_user(
    user_to_create: schemas.UserIn,
    user: schemas.UserOut = Depends(get_logged_user),
    db: Session = Depends(get_db)
):
    """Create a new user on behalf of the logged-in admin.

    :param user_to_create: data of the user to create, including the plain
        password that gets hashed before storing.
    :param user: the logged-in user performing the request.
    :param db: database session.
    :return: the created user as ``schemas.UserOut``.
    :raises HTTPException: 403 when the logged-in user is not an admin (or has
        no stored user record), 409 when the username is already taken.
    """
    # Only an admin user can create other users. A caller without a stored
    # user record is rejected the same way instead of failing on ``None``.
    admin_user = db.query(User).get(user.username)
    if admin_user is None or not admin_user.admin:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Only admin users can create users"
        )

    # Username must be unique
    if db.query(User).get(user_to_create.username):
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Username already exists"
        )

    new_user = User(**user_to_create.dict())
    new_user.hashed_password = hash_password(user_to_create.password)
    db.add(new_user)
    # commit() flushes pending changes itself; no separate flush is needed.
    db.commit()
    return schemas.UserOut(**new_user.dict())
コード例 #13
0
def test_posts_are_listed_in_publication_order(web_server: str,
                                               browser: DriverAPI,
                                               dbsession: Session,
                                               fakefactory):
    """The tag-roll lists posts in publication order."""

    range_start = arrow.get(2013, 5, 5, 0, 0)
    range_end = arrow.get(2013, 5, 5, 19, 0)
    # Hourly dates for the range, in reversed order.
    publication_dates = arrow.Arrow.range("hour", range_start, range_end)[::-1]

    with transaction.manager:
        tag = fakefactory.TagFactory()
        posts = fakefactory.PostFactory.create_batch(
            len(publication_dates), public=True, tags=[tag])
        # Make sure that creation order is not the same as publication order.
        random.shuffle(posts)
        for post, published in zip(posts, publication_dates):
            post.published_at = published.datetime
        dbsession.flush()
        dbsession.expunge_all()

    expected_posts_titles = [post.title for post in posts]
    browser.visit(web_server + "/blog/tag/{}".format(tag))

    headings = browser.find_by_css(".post h2")
    rendered_posts_titles = [heading.text for heading in headings]

    assert expected_posts_titles == rendered_posts_titles
コード例 #14
0
ファイル: dagrun.py プロジェクト: arybin93/airflow
    def task_instance_scheduling_decisions(self, session: Session = None) -> TISchedulingDecision:
        """
        Gather this run's task instances and work out which ones can be scheduled.

        Task instances whose task can no longer be found in the DAG are marked
        ``State.REMOVED`` and flushed immediately. The remaining ones are split
        into unfinished and finished; the unfinished ones in a schedulable
        state are handed to ``_get_ready_tis``.

        :param session: DB session used for loading and flushing.
        :return: a ``TISchedulingDecision`` describing the outcome.
        """
        ready_tis: List[TI] = []
        any_changed = False

        task_instances = list(self.get_task_instances(session=session, state=State.task_states))
        self.log.debug("number of tis tasks for %s: %s task(s)", self, len(task_instances))
        for task_instance in task_instances:
            try:
                task_instance.task = self.get_dag().get_task(task_instance.task_id)
            except TaskNotFound:
                self.log.warning(
                    "Failed to get task '%s' for dag '%s'. Marking it as removed.",
                    task_instance,
                    task_instance.dag_id,
                )
                task_instance.state = State.REMOVED
                session.flush()

        unfinished = [ti for ti in task_instances if ti.state in State.unfinished]
        finished = [ti for ti in task_instances if ti.state in State.finished]
        if unfinished:
            candidates = [ti for ti in unfinished if ti.state in SCHEDULEABLE_STATES]
            self.log.debug("number of scheduleable tasks for %s: %s task(s)", self, len(candidates))
            ready_tis, any_changed = self._get_ready_tis(candidates, finished, session)

        return TISchedulingDecision(
            tis=task_instances,
            schedulable_tis=ready_tis,
            changed_tis=any_changed,
            unfinished_tasks=unfinished,
            finished_tasks=finished,
        )
コード例 #15
0
def test_delete_sqlatable(app_context: None, session: Session) -> None:
    """
    Check that removing a ``SqlaTable`` also removes its ``Dataset``.
    """
    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    Dataset.metadata.create_all(session.get_bind())  # pylint: disable=no-member

    ds_column = TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP")
    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    legacy_table = SqlaTable(
        table_name="old_dataset",
        columns=[ds_column],
        metrics=[],
        database=database,
    )
    session.add(legacy_table)
    session.flush()

    # Exactly one dataset exists once the table has been flushed.
    found_before = session.query(Dataset).all()
    assert len(found_before) == 1

    session.delete(legacy_table)
    session.flush()

    # The dataset must be gone together with the table.
    found_after = session.query(Dataset).all()
    assert len(found_after) == 0
コード例 #16
0
def pfs_visit_id(db: Session):
    """Add a placeholder ``pfs_visit`` row (ids ``-1``, empty description) and flush."""
    placeholder_visit = models.pfs_visit(
        pfs_visit_id=-1,
        pfs_visit_description='',
        pfs_design_id=-1,
        issued_at=datetime.datetime.now(),
    )
    db.add(placeholder_visit)
    db.flush()
コード例 #17
0
ファイル: db.py プロジェクト: odracci/airflow
def synchronize_log_template(*, session: Session = NEW_SESSION) -> None:
    """Synchronize log template configs with table.

    The row with the highest id is compared against the configured values and
    a new row is added when they differ. When the filename template was
    upgraded in place, the pre-upgrade value is synchronized first.
    """
    def add_row_if_changed(filename, elasticsearch_id):
        # Compare against the most recently inserted template row.
        latest = session.query(LogTemplate).order_by(
            LogTemplate.id.desc()).first()

        needs_new_row = (
            not latest
            or latest.filename != filename
            or latest.elasticsearch_id != elasticsearch_id
        )
        if needs_new_row:
            session.add(
                LogTemplate(filename=filename,
                            elasticsearch_id=elasticsearch_id))

    current_filename = conf.get("logging", "log_filename_template")
    current_elasticsearch_id = conf.get("elasticsearch", "log_id_template")

    # Before checking if the _current_ value exists, we need to check if the
    # old config value we upgraded in place exists!
    upgraded_key = ('logging', 'log_filename_template')
    pre_upgrade_filename = conf.upgraded_values.get(upgraded_key, None)
    if pre_upgrade_filename is not None:
        add_row_if_changed(pre_upgrade_filename, current_elasticsearch_id)
        # Flush so the second check below sees the row that was just added.
        session.flush()

    add_row_if_changed(current_filename, current_elasticsearch_id)
コード例 #18
0
def add_users(db: Session) -> None:
    """Add three demo users to db; the second and third follow the first."""

    user_one = User(
        id=USER_ONE_ID,
        email="*****@*****.**",
        username="******",
        password_hash=SECRET,
    )
    db.add(user_one)
    logger.info("User added", username=user_one.username)

    user_two = User(
        id=USER_TWO_ID,
        email="*****@*****.**",
        username="******",
        password_hash=SECRET,
    )
    db.add(user_two)
    user_two.follows.append(user_one)
    logger.info("User added", username=user_two.username)

    # Postman tests expect this user to be present
    user_johnjacob = User(
        id=USER_JOHNJACOB_ID,
        email="*****@*****.**",
        username="******",
        password_hash=SECRET,
    )
    db.add(user_johnjacob)
    user_johnjacob.follows.append(user_one)
    logger.info("User added", username=user_johnjacob.username)

    db.flush()
コード例 #19
0
def test_import_dataset_managed_externally(app_context: None,
                                           session: Session) -> None:
    """
    Test importing a dataset that is managed externally.

    The shared ``dataset_config`` fixture is copied, flagged as externally
    managed and pointed at a freshly flushed database before being imported.
    """
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.datasets.commands.importers.v1.utils import import_dataset
    from superset.datasets.schemas import ImportV1DatasetSchema
    from superset.models.core import Database
    from tests.integration_tests.fixtures.importexport import dataset_config

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    database = Database(database_name="my_database",
                        sqlalchemy_uri="sqlite://")
    session.add(database)
    # Flush so that ``database.id`` is available for the config below.
    session.flush()

    # Deep copy so the shared fixture is not modified by this test.
    config = copy.deepcopy(dataset_config)
    config["is_managed_externally"] = True
    config["external_url"] = "https://example.org/my_table"
    config["database_id"] = database.id

    sqla_table = import_dataset(session, config)
    assert sqla_table.is_managed_externally is True
    assert sqla_table.external_url == "https://example.org/my_table"
コード例 #20
0
def execute_sql(session: Session, sql: str) -> ResultProxy:
    """
    Execute an sql statement on the database and commit it.

    Parameters
    ----------
    session
        sqlalchemy.orm.session.Session
        session object to be used
    sql
        str
        the sql statement

    Return
    ------
    sqlalchemy.engine.ResultProxy
        DB-API cursor wrapper for results of the query

    Raises
    ------
    TypeError
        if ``sql`` is not a ``str``
    Exception
        any error raised while executing or committing is re-raised after the
        session has been rolled back
    """
    if not isinstance(sql, str):
        raise TypeError('sql must be of type str')
    try:
        results = session.execute(sql)
        session.commit()
    except Exception:
        # Undo the failed transaction, then re-raise with the original
        # traceback. No flush is needed after a commit or a rollback: the
        # session has nothing pending at that point.
        session.rollback()
        raise
    return results
コード例 #21
0
 def test_check_auth_redis_miss(self):
     # ``_check_auth`` with an empty Redis cache (``hget`` returns None): the
     # account and phone number are written to the database first, and the
     # test then expects the Redis pipeline to be filled from them.
     db_handler = self.auth_handler.db_handler
     db_conn = db_handler.getEngine().connect()
     # Outer transaction that is rolled back at the end of the test.
     db_txn = db_conn.begin()
     try:
         db_session = Session(bind=db_conn)
         try:
             account = Account(auth_id='some_auth_id', username='******')
             db_session.add(account)
             # Flush so that ``account.id`` is available for the phone number.
             db_session.flush()
             phonenumber = PhoneNumber(number='9740171794',
                                       account_id=account.id)
             db_session.add(phonenumber)
             db_session.commit()
             self.auth_handler.redis_client.hget.return_value = None
             redis_pipeline_mock = MagicMock()
             self.auth_handler.redis_client.pipeline.return_value = redis_pipeline_mock
             status, phonenums = self.auth_handler._check_auth(
                 'some_user', 'some_auth_id', db_session)
             self.assertTrue(status)
             # assertEqual: the assertEquals alias is deprecated and was
             # removed in Python 3.12.
             self.assertEqual({'9740171794'}, phonenums)
             self.assertEqual(redis_pipeline_mock.hset.call_count, 1)
             redis_pipeline_mock.hset.assert_called_with(
                 self.auth_handler._REDIS_AUTH_HASH, 'some_user',
                 'some_auth_id')
             self.assertEqual(redis_pipeline_mock.sadd.call_count, 1)
             redis_pipeline_mock.sadd.assert_called_with(
                 'some_user', '9740171794')
             self.assertEqual(redis_pipeline_mock.execute.call_count, 1)
         finally:
             db_session.close()
     finally:
         db_txn.rollback()
         db_conn.close()
コード例 #22
0
def test_cascade_delete_table(app_context: None, session: Session) -> None:
    """
    Check that removing a ``Table`` also removes the columns attached to it.
    """
    from superset.columns.models import Column
    from superset.models.core import Database
    from superset.tables.models import Table

    Table.metadata.create_all(session.get_bind())  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    coordinate_columns = [
        Column(name="longitude", expression="longitude"),
        Column(name="latitude", expression="latitude"),
    ]
    my_table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=database,
        columns=coordinate_columns,
    )
    session.add(my_table)
    session.flush()

    # Both columns exist once the table has been flushed.
    stored_before = session.query(Column).all()
    assert len(stored_before) == 2

    session.delete(my_table)
    session.flush()

    # No column may survive the deletion of its table.
    stored_after = session.query(Column).all()
    assert len(stored_after) == 0
コード例 #23
0
def queue_reward(
    *,
    deposit: Deposit,
    dbsession: Session,
    web3: Web3,
    reward_amount_rbtc: Decimal,
    deposit_thresholds: RewardThresholdMap,
):
    """Queue a reward for the depositing user when every check passes.

    No reward is queued (and ``None`` is returned) when there is no threshold
    for the deposit's side token symbol, the deposit is below the threshold,
    a reward already exists for the user address (compared case-insensitively)
    or the user already has a balance or transactions.

    :param deposit: deposit that may trigger the reward.
    :param dbsession: session used to look up and store rewards.
    :param web3: passed to ``_get_user_balance_and_transaction_count``.
    :param reward_amount_rbtc: reward size in RBTC; stored multiplied by 10**18.
    :param deposit_thresholds: minimum deposit amount per side token symbol.
    :return: the flushed ``Reward`` with status ``queued``, or ``None``.
    """
    min_amount = deposit_thresholds.get(deposit.side_token_symbol)
    if not min_amount:
        # TODO: maybe these should be added somewhere for post processing?
        logger.warning('Threshold not found for deposit %s -- cannot process',
                       deposit)
        return
    if deposit.amount_decimal < min_amount:
        logger.info('Threshold %s not met for deposit %s -- not rewarding',
                    min_amount, deposit)
        return

    # One reward per user address, regardless of letter case.
    same_user = (func.lower(Reward.user_address) ==
                 deposit.user_address.lower())
    if dbsession.query(Reward).filter(same_user).first():
        logger.info('User %s has already been rewarded.', deposit.user_address)
        return

    balance, transaction_count = _get_user_balance_and_transaction_count(
        web3=web3,
        user_address=deposit.user_address.lower(),
    )
    if balance > 0:
        logger.info(
            'User %s has an existing balance of %s RBTC - not rewarding',
            deposit.user_address, from_wei(balance, 'ether'))
        return
    if transaction_count > 0:
        logger.info(
            'User %s already has %s transactions in RSK - not rewarding',
            deposit.user_address, transaction_count)
        return

    logger.info('Rewarding user %s with %s RBTC', deposit.user_address,
                str(reward_amount_rbtc))

    queued_reward = Reward(
        status=RewardStatus.queued,
        reward_rbtc_wei=int(reward_amount_rbtc * 10**18),
        user_address=deposit.user_address,
        deposit_side_token_address=deposit.side_token_address,
        deposit_side_token_symbol=deposit.side_token_symbol,
        deposit_main_token_address=deposit.main_token_address,
        deposit_amount_minus_fees_wei=deposit.amount_minus_fees_wei,
        deposit_log_index=deposit.log_index,
        deposit_block_hash=deposit.block_hash,
        deposit_transaction_hash=deposit.transaction_hash,
        deposit_contract_address=deposit.contract_address,
    )
    dbsession.add(queued_reward)
    dbsession.flush()
    return queued_reward
コード例 #24
0
ファイル: dagrun.py プロジェクト: bgeng777/flink-ai-extended
    def verify_integrity(self, session: Session = None):
        """
        Verifies the DagRun by checking for removed tasks or tasks that are not in the
        database yet. It will set state to removed or add the task if required.

        An ``IntegrityError`` raised by the final flush is logged and answered
        with a session rollback instead of being propagated.

        :param session: Sqlalchemy ORM Session
        :type session: Session
        """
        dag = self.get_dag()
        tis = self.get_task_instances(session=session)

        # check for removed or restored tasks
        task_ids = set()
        for ti in tis:
            task_instance_mutation_hook(ti)
            task_ids.add(ti.task_id)
            task = None
            try:
                task = dag.get_task(ti.task_id)
            except AirflowException:
                if ti.state == State.REMOVED:
                    pass  # ti has already been removed, just ignore it
                # Compare states by value: an identity check (``is not``) is
                # true for any equal state value that is a different object.
                elif self.state != State.RUNNING and not dag.partial:
                    self.log.warning(
                        "Failed to get task '%s' for dag '%s'. Marking it as removed.",
                        ti, dag)
                    Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 1, 1)
                    ti.state = State.REMOVED

            # The task is back in the DAG although its instance was marked removed.
            should_restore_task = (task
                                   is not None) and ti.state == State.REMOVED
            if should_restore_task:
                self.log.info(
                    "Restoring task '%s' which was previously removed from DAG '%s'",
                    ti, dag)
                Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
                ti.state = State.NONE
            session.merge(ti)

        # check for missing tasks
        for task in dag.task_dict.values():
            # Tasks starting after this run's execution date are skipped,
            # except for backfill and manual runs.
            if task.start_date > self.execution_date and not self.is_backfill and not self.is_manual:
                continue

            if task.task_id not in task_ids:
                Stats.incr(f"task_instance_created-{task.task_type}", 1, 1)
                ti = TI(task, self.execution_date)
                task_instance_mutation_hook(ti)
                session.add(ti)

        try:
            session.flush()
        except IntegrityError as err:
            self.log.info(str(err))
            self.log.info('Hit IntegrityError while creating the TIs for '
                          f'{dag.dag_id} - {self.execution_date}.')
            self.log.info('Doing session rollback.')
            # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
            session.rollback()
コード例 #25
0
ファイル: query.py プロジェクト: badock/rome
 def delete(self, synchronize_session='evaluate'):
     """Delete every object matched by this query and return how many there were.

     The deletions go through a session created just for this call, which is
     flushed before returning. ``synchronize_session`` is accepted but not
     used by this implementation.
     """
     from rome.core.session.session import Session
     scratch_session = Session()
     matched = self.matching_objects(filter_deleted=False)
     for matched_object in matched:
         scratch_session.delete(matched_object)
     scratch_session.flush()
     deleted_count = len(matched)
     return deleted_count
コード例 #26
0
def downgrade():
    """Delete all rows of the three XML node models, then flush and commit."""
    session = Session(bind=op.get_bind())
    # Same deletion order as listed here.
    models_to_clear = (
        XmlNodeModel,
        XmlNodeAttrModel,
        XmlNodeAttrWidgetModel,
    )
    for model in models_to_clear:
        session.query(model).delete()
    session.flush()
    session.commit()
コード例 #27
0
def test_dataset_model(app_context: None, session: Session) -> None:
    """
    Check the basic attributes of a ``Dataset`` built on top of a ``Table``.
    """
    from superset.columns.models import Column
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    Dataset.metadata.create_all(session.get_bind())  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    source_table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=database,
        columns=[
            Column(name="longitude", expression="longitude"),
            Column(name="latitude", expression="latitude"),
        ],
    )
    session.add(source_table)
    session.flush()

    # The dataset expression, including its leading and trailing newline.
    position_sql = """
SELECT array_agg(array[longitude,latitude]) AS position
FROM my_catalog.my_schema.my_table
"""
    position_column = Column(
        name="position",
        expression="array_agg(array[longitude,latitude])",
    )
    positions = Dataset(
        database=source_table.database,
        name="positions",
        expression=position_sql,
        tables=[source_table],
        columns=[position_column],
    )
    session.add(positions)
    session.flush()

    assert positions.id == 1
    assert positions.uuid is not None

    assert positions.name == "positions"
    assert positions.expression == position_sql

    table_names = [tbl.name for tbl in positions.tables]
    column_names = [col.name for col in positions.columns]
    assert table_names == ["my_table"]
    assert column_names == ["position"]
コード例 #28
0
def update_rows(session: Session,
                model: DeclarativeMeta,
                updates: dict,
                filters: dict = None):
    """
    Bulk update the rows of ``model`` selected by ``filters`` with the values
    in ``updates`` and commit the change.

    Parameters
    ----------
    session
        sqlalchemy.orm.session.Session
        session object to be used
    model
        sqlalchemy.ext.declarative.api.DeclarativeMeta
        the sqlalchemy model to use
    updates
        dict
        the fields to update as the keys and their new values as the values
    filters
        dict
        passed on to ``read_rows``; must be in the following structure
        [  {
                'column': {
                    'comparitor': '>=' OR '==' OR '<=' OR '>' OR '<' OR !=
                    'data': str OR int OR float
                },
                join = "and" OR "or"
            },
            ...Other Columns
        ]

    Raises
    ------
    TypeError
        if ``updates`` is not a dict
    NoResultFound
        if no row matches ``filters``
    ValueError
        if the update matched no rows
    Exception
        any error raised by the commit is re-raised after a rollback
    """
    if not isinstance(updates, dict):
        raise TypeError('updates must be of type dict')

    results = read_rows(session, model, filters)

    # Identity check: ``None`` is a singleton and ``==`` could be overridden.
    if results.first() is None:
        raise NoResultFound(
            f"no rows can be updated because no rows can be found with the following filters: {json.dumps(filters)}"
        )

    matched = results.update(updates)

    if matched == 0:
        raise ValueError(
            f"bad update request, no columns could be matched updates requested: {json.dumps(updates)}"
        )

    try:
        session.commit()
    except Exception:
        # TODO Logging.log.exception()
        # Undo the failed transaction and re-raise with the original
        # traceback. No flush is needed after a commit or a rollback: the
        # session has nothing pending at that point.
        session.rollback()
        raise
コード例 #29
0
def upgrade():
    """Add the Bitkub exchange and remove the bx.in.th exchange with its rates."""
    db = Session(bind=op.get_bind())

    bitkub = Exchange(name=BitkubExchange.name, is_active=True, weight=15)
    db.add(bitkub)

    # The retired exchange is looked up by its stored (markdown link) name.
    retired_name = "[bx.in.th](https://bx.in.th/ref/s9c3HU/)"
    retired_exchange = db.query(Exchange).filter_by(name=retired_name).one()

    # Delete the dependent rates before the exchange row itself.
    db.query(Rate).filter_by(exchange_id=retired_exchange.id).delete()
    db.delete(retired_exchange)
    db.flush()
コード例 #30
0
def create_price(price: Price, sess: Session = Depends(create_session)):
    """Store a new price for an existing page and return it as a ``Price``.

    Raises ``HTTPException`` with status 400 when the payload carries no page
    id, or when no page with that id is found.
    """
    page_id = price.page_id
    if page_id is None:
        raise HTTPException(400, "page id is not given")

    existing_page = sess.query(PageORM).filter(PageORM.id == page_id).scalar()
    if not existing_page:
        raise HTTPException(400, f"page id ({price.page_id}) does not exist")

    # Only explicitly provided, non-null fields are forwarded to the ORM row.
    fields = price.dict(exclude_none=True, exclude_unset=True)
    price_row = PriceORM(**fields)
    sess.add(price_row)
    # Flush just this row before converting it back for the response.
    sess.flush([price_row])
    return Price.from_orm(price_row)
コード例 #31
0
ファイル: test_entities.py プロジェクト: tmarwen/abilian-core
def test_polymorphic_update_timestamp(session: Session) -> None:
    """``updated_at`` must advance when a flushed contact is changed and flushed again."""
    person = DummyContact(name="Pacôme Hégésippe Adélard Ladislas")
    session.add(person)
    session.flush()

    # Remember the timestamp set by the first flush.
    first_stamp = person.updated_at
    assert first_stamp

    person.email = "*****@*****.**"
    session.flush()
    assert person.updated_at > first_stamp
コード例 #32
0
def dispatch_play(offset: int, session: Session, bus: ChallengeEventBus):
    """Create a play for ``offset``, flush it, then dispatch a track-listen event."""
    new_play = create_play(offset)
    session.add(new_play)
    session.flush()

    # The event payload carries the play's creation time as a POSIX timestamp.
    payload = {"created_at": new_play.created_at.timestamp()}
    bus.dispatch(ChallengeEvent.track_listen, BLOCK_NUMBER, 1, payload)
コード例 #33
0
    def test_session_with_external_transaction(self):
        """Flush an article through a session bound to a connection whose
        transaction was begun by the test, then roll that transaction back."""
        connection = self.engine.connect()
        outer_transaction = connection.begin()
        bound_session = Session(bind=connection)

        bound_session.add(self.Article(name=u'My Session Article'))
        bound_session.flush()

        # Tear down in order: session, then transaction, then connection.
        bound_session.close()
        outer_transaction.rollback()
        connection.close()
コード例 #34
0
ファイル: test_sqlalchemy.py プロジェクト: polyactis/test
# 2008-04-28: these explicit saves are not necessary; the single flush
# further down persists the unsaved keyword objects too.
#session.save(k_tutorial)
#session.save(k_cool)
#session.save(k_unfinished)
#session.flush()

# Link articles and keywords, working from either side of the relationship.
a1.keywords.append(k_unfinished)
k_cool.articles.append(a1)
k_cool.articles.append(a2)
# Or:
k_cool.articles = [a1, a2]  # This works as well!
a2.keywords.append(k_tutorial)
# One flush will save all relevant unsaved objects into the database.

session.flush()
print "check stuff in db after flush:"
# Walk the Article rows one at a time, advancing the offset until a query
# comes back empty.
i = 0
row = session.query(Article).offset(i).limit(1).list()
while row:
	row = row[0]
	print row.headline
	print row.keywords
	i += 1
	row = session.query(Article).offset(i).limit(1).list()	# all() = list() returns a list of objects. first() returns the 1st object. one() would raise an error because 'Multiple rows returned for one()'

#"""
# 2008-05-07: select the body and headline of the article whose headline is
# 'Python is cool!' using the SQL expression API instead of the ORM query.
s = sqlalchemy.sql.select([Article.c.body, Article.c.headline], Article.c.headline=='Python is cool!')
# NOTE(review): ``connection`` is not created in this fragment (the line
# below is commented out) -- presumably defined earlier; confirm.
#connection = eng.connect()
result = connection.execute(s)
コード例 #35
0
ファイル: session.py プロジェクト: GEverding/inbox
class InboxSession(object):
    """ Inbox custom ORM (with SQLAlchemy compatible API).

    Wraps a SQLAlchemy ``Session`` and forwards the usual session calls to
    it. When ``ignore_soft_deletes`` is set, queries are built with
    ``InboxQuery`` and ``IgnoreSoftDeletesOption``, ``delete`` marks objects
    as deleted instead of removing them, and adding an already-deleted object
    is rejected.

    Parameters
    ----------
    engine : <sqlalchemy.engine.Engine>
        A configured database engine to use for this session
    versioned : bool
        Do you want to enable the transaction log?
    ignore_soft_deletes : bool
        Whether or not to ignore soft-deleted objects in query results.
    namespace_id : int
        Namespace to limit query results with. Not used yet (see the TODO in
        ``__init__``).
    """
    def __init__(self, engine, versioned=True, ignore_soft_deletes=True,
                 namespace_id=None):
        # TODO: support limiting on namespaces
        assert engine, "Must set the database engine"

        args = dict(bind=engine, autoflush=True, autocommit=False)
        self.ignore_soft_deletes = ignore_soft_deletes
        if ignore_soft_deletes:
            args['query_cls'] = InboxQuery
        self._session = Session(**args)

        if versioned:
            # Imported here rather than at module level, as in the original.
            from inbox.models.transaction import create_revisions

            @event.listens_for(self._session, 'after_flush')
            def after_flush(session, flush_context):
                """
                Hook to log revision snapshots. Must be post-flush in order to
                grab object IDs on new objects.
                """
                create_revisions(session)

    def query(self, *args, **kwargs):
        """Build a query on the wrapped session, hiding soft-deleted objects
        when ``ignore_soft_deletes`` is enabled."""
        q = self._session.query(*args, **kwargs)
        if self.ignore_soft_deletes:
            return q.options(IgnoreSoftDeletesOption())
        else:
            return q

    def add(self, instance):
        """Add ``instance`` to the session.

        Raises ``Exception`` if soft deletes are being ignored and the
        instance is already marked as deleted.
        """
        if not self.ignore_soft_deletes or not instance.is_deleted:
            self._session.add(instance)
        else:
            raise Exception("Why are you adding a deleted object?")

    def add_all(self, instances):
        """Add every object in ``instances`` to the session.

        ``instances`` may be any iterable, including a one-shot iterator.
        Raises ``Exception`` if soft deletes are being ignored and any of the
        objects is already marked as deleted; in that case nothing is added.
        """
        # Materialize once: the iterable is traversed both by the check below
        # and by the session, and a generator would be exhausted by the first.
        instances = list(instances)
        # Same rule as ``add``: ``is_deleted`` is only consulted when soft
        # deletes are being ignored.
        if self.ignore_soft_deletes and \
                any(i.is_deleted for i in instances):
            raise Exception("Why are you adding a deleted object?")
        self._session.add_all(instances)

    def delete(self, instance):
        """Soft-delete ``instance`` when soft deletes are being ignored,
        otherwise delete it through the wrapped session."""
        if self.ignore_soft_deletes:
            instance.mark_deleted()
            # just to make sure
            self._session.add(instance)
        else:
            self._session.delete(instance)

    # The remaining methods forward directly to the wrapped session.

    def begin(self):
        self._session.begin()

    def commit(self):
        self._session.commit()

    def rollback(self):
        self._session.rollback()

    def flush(self):
        self._session.flush()

    def close(self):
        self._session.close()

    def expunge(self, obj):
        self._session.expunge(obj)

    def merge(self, obj):
        return self._session.merge(obj)

    @property
    def no_autoflush(self):
        return self._session.no_autoflush