Esempio n. 1
0
 def upgrade(self, connection: Connection):
     """Create the ``migrations`` bookkeeping table.

     Columns: an auto-incremented ``id`` primary key, an integer ``version``
     and a ``created`` timestamp defaulting to the current time. The DDL uses
     ``auto_increment`` (MySQL/MariaDB syntax).
     """
     create_migrations_ddl = """create table migrations(
         id int not null auto_increment,
         version int not null,
         created timestamp not null default current_timestamp,            
         primary key (id)
     );"""
     connection.execute(create_migrations_ddl)
Esempio n. 2
0
def _get_net_income_quater(connection: engine.Connection, meta: MetaData,
                           name: str, b_date: date) -> float:
    """
    Return a bank's cumulative (running-total) balance-sheet profit.

    Reads table ``f102``: the income total (row ``CODE == 19999``) minus the
    expense total (row ``CODE == 29999``) for the given bank and date.

    :param connection: database connection
    :param meta: database metadata; must contain the ``f102`` table
    :param name: bank name, matched against ``f102.NAME_B``
    :param b_date: start date of the quarter, matched against ``f102.DT``
    :return: income minus expenses as of the start of the quarter
    :raises LookupError: if either total is missing for this bank and date
    """
    bank = meta.tables['f102']

    def _fetch_total(code: int):
        # Income and expense totals share the SIM_ITOGO column and differ
        # only by row code.
        statement = select(
            bank.c.SIM_ITOGO).filter(
            and_(
                bank.c.NAME_B == name,
                bank.c.DT == b_date,
                bank.c.CODE == code))
        row = connection.execute(statement).fetchone()
        if row is None:
            raise LookupError(
                f"No f102 row with CODE={code} for bank {name!r} on {b_date}")
        return row[0]

    incomes = _fetch_total(19999)
    consumptions = _fetch_total(29999)
    return incomes - consumptions
def insert_multiple(application_id: str, account_name: str,
                    followers: Iterable, source: str, cursor: str,
                    conn: Connection) -> None:
    """Add followers for one account into the account relationships table.

    For every follower, its account id is looked up (an account row is
    inserted via ``insert_one`` when the lookup returns nothing), and an
    ``account_relationship`` row linking it to ``account_name`` is inserted.
    Finally ``cursor`` is stored for the account. Everything runs in a single
    transaction.

    Failures are not propagated to the caller (best-effort, as before).
    Instead, the transaction is rolled back and the exception is logged
    rather than silently discarded.
    """
    import logging

    trans = conn.begin()
    try:
        account_id = select_one_id(application_id, account_name, source, conn)
        for follower in followers:
            follower_id = select_one_id(application_id, follower, source, conn)
            if not follower_id:
                # Unknown follower: create it first.
                # NOTE(review): the meaning of the ``False`` flag is defined
                # by insert_one — confirm.
                follower_id = insert_one(application_id, follower, False,
                                         source, conn)
            conn.execute(
                insert(account_relationship).values(
                    account=account_id, follower=follower_id))

        update_one_cursor(account_id, cursor, conn)
        trans.commit()
    except Exception:
        trans.rollback()
        logging.getLogger(__name__).exception(
            "Failed to insert followers for account %r; rolled back",
            account_name)
    def _add_column(
        self,
        connection: Connection,
        table: SqlaTable,
        column_name: str,
        column_type: str,
    ) -> None:
        """
        Add new column to table

        Emits ``ALTER TABLE <from clause> ADD <name> <type>``, rendered for
        the table's database dialect. It then appends a matching
        ``TableColumn`` to ``table.columns``.

        :param connection: The connection to work with
        :param table: The SqlaTable
        :param column_name: The name of the column
        :param column_type: The type of the column.
            NOTE(review): annotated as ``str``, but it is passed to
            ``Column()`` and compiled through ``column.type.compile()``, which
            need a SQLAlchemy type object — confirm what callers pass.
        """
        # Resolve the dialect once rather than building the engine twice.
        dialect = table.database.get_sqla_engine().dialect
        column = Column(column_name, column_type)
        # Render the column name (quoted if the dialect requires it). The
        # name must not be passed positionally: that slot is ``bind``.
        name = column.compile(dialect=dialect)
        col_type = column.type.compile(dialect)
        sql = text(
            f"ALTER TABLE {self._get_from_clause(table)} ADD {name} {col_type}"
        )
        connection.execute(sql)

        table.columns.append(
            TableColumn(column_name=column_name, type=col_type))
Esempio n. 5
0
def _insert_bank_list(connection: engine.Connection, meta: MetaData,
                      ratios: dict, b_date: date):
    """
    Insert one data-mart row per bank for a single date.

    :param connection: database connection
    :param meta: database metadata
    :param ratios: mapping of bank name -> list of ratios computed for that
        bank on ``b_date``; positions 0..5 are stored as CA, LR, CR, LTR, CP
        and ROA respectively
    :param b_date: date the ratios refer to
    """
    mart = meta.tables[config.mart_name]
    req = meta.tables['req']
    for bank_name, values in ratios.items():
        # The mart references banks by their index in table ``req``.
        index_row = connection.execute(
            select(req.c.index).filter(req.c.NAME_B == bank_name)
        ).fetchone()
        connection.execute(
            mart.insert().values(
                bank_index=index_row[0],
                date=b_date,
                CA=values[0],
                LR=values[1],
                CR=values[2],
                LTR=values[3],
                CP=values[4],
                ROA=values[5],
            )
        )
Esempio n. 6
0
def _insert_bank(connection: engine.Connection, meta: MetaData, ratios: dict,
                 name: str):
    """
    Insert data-mart rows for one bank, one row per date.

    :param connection: database connection
    :param meta: database metadata
    :param ratios: mapping of date -> list of ratios computed for bank
        ``name`` on that date; positions 0..5 are stored as CA, LR, CR, LTR,
        CP and ROA respectively
    :param name: bank name, matched against ``req.NAME_B``
    """
    mart = meta.tables[config.mart_name]
    req = meta.tables['req']
    # The bank's index in ``req`` is the same for every date: look it up once.
    bank_index = connection.execute(
        select(req.c.index).filter(req.c.NAME_B == name)
    ).fetchone()[0]
    for b_date, values in ratios.items():
        row = dict(
            bank_index=bank_index,
            date=b_date,
            CA=values[0],
            LR=values[1],
            CR=values[2],
            LTR=values[3],
            CP=values[4],
            ROA=values[5],
        )
        connection.execute(mart.insert().values(**row))
Esempio n. 7
0
    def get_columns(
        self,
        connection: Connection,
        table_name: str,
        schema: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Dict[str, Any]]:
        """
        Describe the columns of ``table_name``.

        Two kinds of columns are skipped. The first are those reported by the
        custom ``SHOW ARRAY_COLUMNS`` command. The second are those whose
        mapping is in ``self._not_supported_column_types``. Every returned
        column is marked nullable with no default. ``schema`` and ``kwargs``
        are accepted for interface compatibility but are not used.
        """
        # Custom SQL: names of array-typed columns, excluded below.
        array_rows = connection.execute(
            f"SHOW ARRAY_COLUMNS FROM {table_name}").fetchall()
        excluded_names = {array_row.name for array_row in array_rows}

        described = []
        for row in connection.execute(f'SHOW COLUMNS FROM "{table_name}"'):
            if row.mapping in self._not_supported_column_types:
                continue
            if row.column in excluded_names:
                continue
            described.append({
                "name": row.column,
                "type": basesqlalchemy.get_type(row.mapping),
                "nullable": True,
                "default": None,
            })
        return described
Esempio n. 8
0
def get_free_title(connection: Connection,
                   title: str,
                   auth_user_id: str) -> str:
    """
    Get a good version of the title to be inserted into the survey table. If
    the title as given already exists, this function will append a number.
    For example, when the title is "survey":
    1. "survey" not in table -> "survey"
    2. "survey" in table     -> "survey(1)"
    3. "survey(1)" in table  -> "survey(2)"

    Only the given user's surveys are considered. Candidate titles are found
    with a literal prefix match: LIKE wildcards in ``title`` are escaped.

    :param connection: a SQLAlchemy Connection
    :param title: the survey title
    :param auth_user_id: the user's UUID
    :return: a title that can be inserted safely
    """

    (does_exist, ), = connection.execute(
        select((exists().where(
            survey_table.c.survey_title == title
        ).where(survey_table.c.auth_user_id == auth_user_id),)))
    if not does_exist:
        return title
    # Escape the escape character and the LIKE wildcards so that, e.g., the
    # title "a_b" does not also match "axb(1)".
    like_prefix = (title.replace('/', '//')
                   .replace('%', '/%')
                   .replace('_', '/_'))
    similar_surveys = connection.execute(
        select([survey_table]).where(
            survey_table.c.survey_title.like(like_prefix + '%', escape='/')
        ).where(
            survey_table.c.auth_user_id == auth_user_id
        )
    ).fetchall()
    conflicts = list(_conflicting(title, similar_surveys))
    free_number = max(conflicts) + 1 if conflicts else 1
    return title + '({})'.format(free_number)
Esempio n. 9
0
def test_saves_auction_changes(
    connection: Connection,
    another_bidder_id: int,
    bid_model: RowProxy,
    auction_model_with_a_bid: RowProxy,
    ends_at: datetime,
    event_bus_mock: Mock,
) -> None:
    """Saving an ended auction with an extra bid stores both bids and the new state."""
    new_bid_price = get_dollars(bid_model.amount * 2)
    starting_price = get_dollars(auction_model_with_a_bid.starting_price)
    bids_to_save = [
        Bid(bid_model.id, bid_model.bidder_id, get_dollars(bid_model.amount)),
        Bid(None, another_bidder_id, new_bid_price),
    ]
    auction = Auction(
        id=auction_model_with_a_bid.id,
        title=auction_model_with_a_bid.title,
        starting_price=starting_price,
        ends_at=ends_at,
        bids=bids_to_save,
        ended=True,
    )

    repo = SqlAlchemyAuctionsRepo(connection, event_bus_mock)
    repo.save(auction)

    bids_count = connection.execute(
        select([func.count()]).select_from(bids)).scalar()
    assert bids_count == 2

    saved_auction = connection.execute(
        select([auctions]).where(
            auctions.c.id == auction_model_with_a_bid.id)
    ).first()
    assert saved_auction.current_price == new_bid_price.amount
    assert saved_auction.ended
Esempio n. 10
0
 def create_table(self, schema_sql: str, conn: Connection,
                  driver: DBDriver) -> None:
     """Execute ``schema_sql`` on ``conn``, then apply permissions.

     Permissions are applied through ``self.add_permissions(conn, driver)``.
     The executed SQL is logged with lazy %-style arguments, so the message
     is only formatted when INFO logging is enabled.
     """
     logger.info('Creating table...')
     conn.execute(schema_sql)
     logger.info("Just ran %s", schema_sql)
     self.add_permissions(conn, driver)
     logger.info("Table prepped")
Esempio n. 11
0
def diff_tables(
    connection: Connection, master: Table, copy: Table,
    result_columns: Iterable[Column]
) -> Tuple[List[Tuple], List[Tuple], List[Tuple]]:
    """
    Compute the differences in the contents of two tables with identical
    columns.

    The master table must have at least one PrimaryKeyConstraint or
    UniqueConstraint with only non-null columns defined.

    If several such constraints are defined, the one with the fewest columns
    is used to match rows between the tables.

    A matched row counts as modified when any other column differs. A change
    between NULL and a value also counts (``IS DISTINCT FROM``); a plain
    ``!=`` evaluates to NULL there and would miss it.

    :param connection: DB connection
    :param master: Master table
    :param copy: Copy of master table
    :param result_columns: columns to return
    :return: a tuple ``(added, deleted, modified)`` of row lists. ``added``
             holds rows whose key exists only in ``master``. ``deleted``
             holds rows whose key exists only in ``copy``. ``modified`` holds
             rows present in both whose other columns differ.
    :raises AssertionError: if ``master`` has no suitable constraint
    """
    logger.debug('Calculating diff between "%s" and "%s"', master.name,
                 copy.name)
    result_columns = tuple(result_columns)
    unique_columns = min(
        (constraint.columns for constraint in master.constraints
         if isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint))
         and constraint.columns and
         not any(map(operator.attrgetter('nullable'), constraint.columns))),
        key=len,
        default=[])
    if not unique_columns:
        raise AssertionError("To diff table {} it must have at least one "
                             "PrimaryKeyConstraint/UniqueConstraint with only "
                             "NOT NULL columns defined on it.".format(
                                 master.name))
    unique_column_names = tuple(col.name for col in unique_columns)
    other_column_names = tuple(col.name for col in master.c
                               if col.name not in unique_column_names)
    on_clause = and_(
        *(getattr(master.c, column_name) == getattr(copy.c, column_name)
          for column_name in unique_column_names))
    # Rows of one side without a partner on the other side (outer join, key NULL).
    added = connection.execute(
        select(result_columns).select_from(master.outerjoin(
            copy, on_clause)).where(
                or_(*(getattr(copy.c, column_name).is_(null())
                      for column_name in unique_column_names)))).fetchall()
    deleted = connection.execute(
        select(result_columns).select_from(copy.outerjoin(
            master, on_clause)).where(
                or_(*(getattr(master.c, column_name).is_(null())
                      for column_name in unique_column_names)))).fetchall()
    # NULL-safe comparison so NULL <-> value changes are reported as well.
    modified = connection.execute(
        select(result_columns).select_from(master.join(copy, on_clause)).where(
            or_(*(
                getattr(master.c, column_name).is_distinct_from(
                    getattr(copy.c, column_name))
                for column_name in other_column_names
            )))).fetchall() if other_column_names else []
    logger.debug('Diff found %d added, %d deleted, and %d modified records',
                 len(added), len(deleted), len(modified))
    return added, deleted, modified
Esempio n. 12
0
def bulk_update(conn: Connection, table: Table, col_name: str,
                ids_values: List[Tuple]):
    """
    Set ``col_name`` on many rows of ``table`` with one executemany call.

    Rows are matched on ``table.c.id``. An empty ``ids_values`` is a no-op.

    :param conn: database connection
    :param table: table to update
    :param col_name: name of the column to set
    :param ids_values: ``(id, new_value)`` pairs
    """
    params = [{'obj_id': eid, 'val': val} for eid, val in ids_values]
    if not params:
        # With an empty parameter list the statement would be executed
        # without values for its required bind parameters and fail.
        return
    stmt = table.update().where(table.c.id == bindparam('obj_id')).values(
        **{col_name: bindparam('val')})
    conn.execute(stmt, params)
Esempio n. 13
0
def update_cluster(account_id: str, topic_iteration_id: str, cluster: int,
                   x: float, y: float, conn: Connection) -> None:
    """Update the cluster and the x and y values of the topic.

    Every ``topic`` row matching both ``account_id`` and
    ``topic_iteration_id`` receives the new values.
    """
    matches_topic = and_(topic.c.account == account_id,
                         topic.c.topic_iteration == topic_iteration_id)
    new_values = {'cluster': cluster, 'x': x, 'y': y}
    conn.execute(update(topic).where(matches_topic).values(**new_values))
Esempio n. 14
0
def import_nivo(con: Connection, csv_file: ANivoCsv) -> None:
    """Normalize ``csv_file`` and insert its cleaned rows into ``NivoRecordTable``.

    The CSV is prepared first. The insert then runs in its own transaction
    on ``con``.
    """
    csv_file.normalize()
    csv_file.find_and_replace_foreign_key_value()
    with con.begin():
        # NOTE: no ON CONFLICT clause, so duplicate rows are not skipped.
        statement = insert(NivoRecordTable).values(csv_file.cleaned_csv)
        con.execute(statement)
Esempio n. 15
0
def _execute_sql_stream(conn: Connection, sql: str) -> None:
    """Run the SQL statements in a stream against a database.

    ``sql`` is split into statements with ``split_sql`` and each is executed
    in order. When ``_should_escape_percents(conn)`` is true, every ``%`` is
    doubled first. (This is presumably so the DB-API driver does not read it
    as a parameter placeholder — confirm.)
    """
    # The decision depends only on the connection, so make it once.
    escape_percents = _should_escape_percents(conn)
    for query in split_sql(sql):
        conn.execute(query.replace("%", "%%") if escape_percents else query)
Esempio n. 16
0
def add_columns(table: str, df: pandas.DataFrame, con: Connection):
    """Add columns eventually missing in the table, if the table exists.

    Every DataFrame column that the table lacks is added as a ``FLOAT``
    column. Nothing happens when the table does not exist.

    :param table: table name; interpolated into the SQL unquoted, so it must
        be a trusted identifier
    :param df: frame whose columns should exist in the table
    :param con: database connection
    """
    if not table_exists(table, con):
        return
    # Fetch at most one row just to learn the column names, then close the
    # result so its cursor is not left open.
    result = con.execute(f'SELECT * FROM {table} LIMIT 1')
    try:
        existing = set(result.keys())
    finally:
        result.close()
    # Sorted (by text) so the ALTER statements run in a deterministic order.
    for col in sorted(set(df.columns) - existing, key=str):
        # Double embedded quotes so the name stays one quoted identifier.
        quoted = '"{}"'.format(str(col).replace('"', '""'))
        con.execute(f'ALTER TABLE {table} ADD COLUMN {quoted} FLOAT')
Esempio n. 17
0
 def _add_or_update_records(
         cls, conn: Connection, table: Table,
         records: List["I2B2CoreWithUploadId"]) -> Tuple[int, int]:
     """
     Add or update the supplied table as needed to reflect the contents of records

     A record whose key fields (``cls.key_fields``) match an existing row is
     UPDATEd. This happens only when at least one of its non-None values
     outside the key fields and ``cls.no_update_fields`` differs from the
     stored row. All other records are INSERTed in chunks of 500, after
     optional duplicate filtering.

     :param conn: database connection
     :param table: i2b2 table to apply the records to
     :param records: records to apply
     :return: (number of records inserted, number of records updated)
     """
     num_updates = 0
     num_inserts = 0
     inserts = []
     # Iterate over the records doing updates
     # Note: This is slow as molasses - definitely not optimal for batch work, but hopefully we'll be dealing with
     #    thousands to tens of thousands of records.  May want to move to ORM model if this gets to be an issue
     for record in records:
         keys = [(table.c[k] == getattr(record, k)) for k in cls.key_fields]
         key_filter = I2B2CoreWithUploadId._nested_fcn(and_, keys)
         # Probe for an existing row by actually fetching it: the DB-API
         # ``rowcount`` of a SELECT is driver-dependent (it may be -1, which
         # is truthy and would make every record look like an update).
         rec_exists = conn.execute(
             select([table.c.upload_id]).where(key_filter)).first() is not None
         if rec_exists:
             known_values = {
                 k: v
                 for k, v in record._freeze().items()
                 if v is not None and k not in cls.no_update_fields
                 and k not in cls.key_fields
             }
             # Only touch the row if some known value actually differs.
             changed = [table.c[k] != v for k, v in known_values.items()]
             val_filter = I2B2CoreWithUploadId._nested_fcn(or_, changed)
             known_values['update_date'] = record.update_date
             upd = update(table).where(and_(
                 key_filter, val_filter)).values(known_values)
             num_updates += conn.execute(upd).rowcount
         else:
             inserts.append(record._freeze())
     if inserts:
         if cls._check_dups:
             dups = cls._check_for_dups(inserts)
             nprints = 0
             # TODO: Figure out why duplicates are occuring -- they are very specific
             if dups:
                 print("{} duplicate records encountered".format(len(dups)))
                 for k, dup_records in dups.items():
                     if len(dup_records) == 2 and dup_records[0] == dup_records[1]:
                         # Identical pair: drop one copy.
                         inserts.remove(dup_records[1])
                     else:
                         if nprints < 20:
                             print("Key: {} has a non-identical dup".format(
                                 k))
                         elif nprints == 20:
                             print(".... more ...")
                         nprints += 1
                         # Drop the extra records (all but dup_records[0]).
                         for v in dup_records[1:]:
                             inserts.remove(v)
         # TODO: refactor this to load on a per-resource basis.  Temporary fix
         for chunk in ListChunker(inserts, 500):
             num_inserts += conn.execute(table.insert(), chunk).rowcount
     return num_inserts, num_updates
Esempio n. 18
0
def bulk_update_column(
    conn: Connection, table: Table, col_name: str, ids_values: List[Tuple], key_col="id"
):
    """
    Set ``col_name`` on many rows of ``table`` with one executemany call.

    An empty ``ids_values`` is a no-op.

    :param conn: database connection
    :param table: table to update
    :param col_name: name of the column to set
    :param ids_values: ``(key, new_value)`` pairs
    :param key_col: name of the column used to match rows (default ``"id"``)
    """
    params = [{"obj_id": eid, "val": val} for eid, val in ids_values]
    if not params:
        # With an empty parameter list the statement would be executed
        # without values for its required bind parameters and fail.
        return
    stmt = (
        table.update()
        .where(getattr(table.c, key_col) == bindparam("obj_id"))
        .values(**{col_name: bindparam("val")})
    )
    conn.execute(stmt, params)
Esempio n. 19
0
def insert_one(application_id: str, source: str, topics: Iterable,
               conn: Connection) -> None:
    """Insert one ``topic_model`` row for an application and source.

    Uses ``insert(table).values(...)``: the ``values=`` keyword of
    ``insert()`` is deprecated in SQLAlchemy 1.4 and removed in 2.0.
    """
    stmt = insert(topic_model).values(
        application=application_id,
        source=source,
        topics=topics,
    )
    conn.execute(stmt)
Esempio n. 20
0
def insert_one(account_id: str, weights: Iterable, topic_iteration: str,
               conn: Connection) -> None:
    """Insert one ``topic`` row holding an account's weights for an iteration.

    Uses ``insert(table).values(...)``: the ``values=`` keyword of
    ``insert()`` is deprecated in SQLAlchemy 1.4 and removed in 2.0.
    """
    stmt = insert(topic).values(
        account=account_id,
        weights=weights,
        topic_iteration=topic_iteration,
    )
    conn.execute(stmt)
Esempio n. 21
0
def delete(connection: Connection, survey_id: str):
    """
    Remove one survey and report success.

    The deletion is executed inside its own transaction.

    :param connection: a SQLAlchemy connection
    :param survey_id: the UUID of the survey
    :return: ``json_response('Survey deleted')``
    """
    with connection.begin():
        deletion = delete_record(survey_table, 'survey_id', survey_id)
        connection.execute(deletion)
    return json_response('Survey deleted')
Esempio n. 22
0
def nabel_update_recent(con: Connection, from_time: datetime):
    """Refresh ``NABEL_TABLE`` data from ``from_time`` onwards.

    If the table exists, rows with ``date >= from_time`` are deleted first.
    Then ``nabel_download`` is called for the range from ``from_time`` to now.
    NOTE(review): the progress message mentions Zurich — confirm the wording.
    """
    current_year = datetime.now(tz_local).year
    print(f'Retrieving data for the city of Zurich for {current_year}')
    if table_exists(NABEL_TABLE, con):
        cutoff = from_time.isoformat()
        con.execute(f"DELETE FROM {NABEL_TABLE} WHERE date >= '{cutoff}'")

    nabel_download(con, from_time, datetime.now(tz_local))
    print('Done!')
Esempio n. 23
0
def persist_department(con: Connection, name: str, number: str,
                       zone: str) -> UUID:
    """Ensure a department named ``name`` exists and return its ``d_id``.

    The zone is persisted first via ``persist_zone``. The insert does nothing
    on a ``d_name`` conflict, so an existing department is kept as is. Its id
    is then read back by name.
    """
    zone_id = persist_zone(con, zone)
    con.execute(
        insert(DepartmentTable)
        .values(d_name=name, d_number=number, d_zone=zone_id)
        .on_conflict_do_nothing(index_elements=["d_name"])
    )
    lookup = select([DepartmentTable.c.d_id]).where(
        DepartmentTable.c.d_name == name)
    department = con.execute(lookup).first()
    return department.d_id
Esempio n. 24
0
def start_new_payment(payment_uuid: UUID, customer_id: int, amount: Money,
                      description: str, conn: Connection) -> None:
    """Insert a new payment row with status ``PaymentStatus.NEW``.

    :param payment_uuid: payment identifier, stored as ``str(payment_uuid)``
    :param customer_id: id of the paying customer
    :param amount: payment amount. It is stored as an integer number of
        hundredths of ``amount.currency``, together with that currency's
        ISO code.
    :param description: free-text description
    :param conn: database connection
    """
    # Scale before converting to int so the fractional part survives:
    # 12.34 must become 1234, whereas int(12.34) * 100 gives 1200.
    # round() also absorbs binary-float error (12.34 * 100 == 1233.99...).
    # NOTE(review): assumes two decimal places for every currency — confirm.
    amount_in_cents = int(round(amount.amount * 100))
    conn.execute(
        payments.insert({
            "uuid": str(payment_uuid),
            "customer_id": customer_id,
            "amount": amount_in_cents,
            "currency": amount.currency.iso_code,
            "description": description,
            "status": PaymentStatus.NEW.value,
        }))
Esempio n. 25
0
def persist_flowcapt_station(con: Connection, station: Feature) -> None:
    """Insert one FlowCapt station, leaving an existing ``fcs_id`` untouched.

    Id, site, country and altitude are read from ``station.properties``. The
    geometry comes from ``_get_geom(station)``.
    """
    geom = _get_geom(station)
    props = station.properties
    statement = (
        insert(FlowCaptStationTable)
        .values(
            fcs_id=props["id"],
            fcs_site=props["site"],
            fcs_country=props["country"],
            fcs_altitude=props["altitude"],
            the_geom=geom,
        )
        .on_conflict_do_nothing(index_elements=["fcs_id"])
    )
    con.execute(statement)
Esempio n. 26
0
    def _re_encrypt_row(
        self,
        conn: Connection,
        row: RowProxy,
        table_name: str,
        columns: Dict[str, EncryptedType],
    ) -> None:
        """
        Re encrypts all columns in a Row

        Every column in ``columns`` is decrypted with the previous secret key
        and encrypted again with the current one. The row is then updated by
        its ``id``. If a value cannot be decrypted with the previous key but
        can be with the current one, the row is considered already migrated
        and is left untouched.

        :param conn: connection used to run the UPDATE
        :param row: Current row to reencrypt
        :param table_name: table the row belongs to; interpolated into the SQL
            as-is, so it must be a trusted identifier
        :param columns: Meta info from columns
        :raises Exception: if a value can be decrypted with neither key
        """
        re_encrypted_columns = {}

        for column_name, encrypted_type in columns.items():
            previous_encrypted_type = EncryptedType(
                type_in=encrypted_type.underlying_type,
                key=self._previous_secret_key)
            try:
                unencrypted_value = previous_encrypted_type.process_result_value(
                    self._read_bytes(column_name, row[column_name]),
                    self._dialect)
            except ValueError as exc:
                # Failed to unencrypt with the previous key: check whether the
                # current key already works before giving up.
                try:
                    encrypted_type.process_result_value(
                        self._read_bytes(column_name, row[column_name]),
                        self._dialect)
                    logger.info(
                        "Current secret is able to decrypt value on column [%s.%s],"
                        " nothing to do",
                        table_name,
                        column_name,
                    )
                    return
                except Exception:
                    # Keep the generic Exception type, but say what failed.
                    raise Exception(
                        f"Unable to decrypt column [{table_name}.{column_name}]"
                        " with either the previous or the current secret key"
                    ) from exc

            re_encrypted_columns[
                column_name] = encrypted_type.process_bind_param(
                    unencrypted_value,
                    self._dialect,
                )

        if not re_encrypted_columns:
            # No columns: "UPDATE ... SET  WHERE ..." would be invalid SQL.
            return

        set_cols = ",".join(
            f"{name} = :{name}" for name in re_encrypted_columns)
        logger.info("Processing table: %s", table_name)
        conn.execute(
            text(f"UPDATE {table_name} SET {set_cols} WHERE id = :id"),
            id=row["id"],
            **re_encrypted_columns,
        )
Esempio n. 27
0
def zurich_download_all(con: Connection):
    """Rebuild ``ZURICH_TABLE`` from scratch with hourly data.

    The table is dropped. Then the data for every year from ``START_YEAR``
    to ``END_YEAR`` (inclusive) is fetched and appended.
    """
    print(
        f'Retrieving data for the city of Zurich from {START_YEAR} to {END_YEAR}'
    )
    con.execute(f'DROP TABLE IF EXISTS {ZURICH_TABLE}')

    years = range(START_YEAR, END_YEAR + 1)
    for year in tqdm(years):
        yearly = zurich_hourly_data(year)
        # Add any columns this year introduces (presumably so the append
        # below does not fail on unknown columns).
        add_columns(ZURICH_TABLE, yearly, con)
        yearly.to_sql(ZURICH_TABLE, if_exists='append', con=con)

    print('Done!')
Esempio n. 28
0
def nabel_download_all(con: Connection):
    """Rebuild ``NABEL_TABLE`` from ``START_YEAR`` up to now.

    The table is dropped. Each past year is then downloaded as one range,
    followed by the current year from 1 January up to the current time.
    """
    print(f'Retrieving data for the city of Zurich starting in {START_YEAR}')
    con.execute(f'DROP TABLE IF EXISTS {NABEL_TABLE}')
    this_year = datetime.now(tz_local).year

    for past_year in tqdm(range(START_YEAR, this_year), desc='Past years'):
        # NOTE(review): the range ends at 31 Dec 00:00 — confirm that
        # nabel_download includes the whole last day.
        year_start = datetime(past_year, 1, 1, 0, 0, 0, 0)
        year_end = datetime(past_year, 12, 31, 0, 0, 0, 0)
        nabel_download(con, year_start, year_end)

    # NOTE(review): this start is naive while the end is tz-aware (tz_local).
    nabel_download(con, datetime(this_year, 1, 1, 0, 0, 0, 0),
                   datetime.now(tz_local))
    print('Done!')
Esempio n. 29
0
def example_auction(connection: Connection, another_user: str) -> Generator[int, None, None]:
    """Yield the id of a freshly inserted auction that already has one bid.

    The auction ends three days from now. When the generator is resumed, the
    bid and then the auction are deleted again.
    """
    auction_row = {
        "title": "Super aukcja",
        "starting_price": "0.99",
        "current_price": "1.00",
        "ends_at": datetime.now() + timedelta(days=3),
    }
    auction_id = connection.execute(auctions.insert(auction_row)).lastrowid
    connection.execute(
        bids.insert({"auction_id": auction_id, "amount": "1.00", "bidder_id": another_user})
    )
    yield int(auction_id)
    # Teardown: bids first, since they point at the auction via auction_id.
    connection.execute(bids.delete(bids.c.auction_id == auction_id))
    connection.execute(auctions.delete(auctions.c.id == auction_id))
Esempio n. 30
0
def _create_choices(connection: Connection,
                    values: dict,
                    question_id: str,
                    submission_map: dict,
                    existing_question_id: str=None) -> Iterator:
    """
    Create the choices of a survey question. If this is an update to an
    existing survey, it will also copy over answers to the questions.

    This is a generator, so nothing is written until it is iterated. For each
    new choice, the choice row is inserted (and, for updates, the matching
    answers are copied) just before its id is yielded.

    :param connection: the SQLAlchemy Connection object for the transaction
    :param values: the dictionary of values associated with the question;
                   the keys read here are 'choices', 'survey_id',
                   'type_constraint_name', 'sequence_number' and
                   'allow_multiple'
    :param question_id: the UUID of the question
    :param submission_map: a dictionary mapping old submission_id to new
    :param existing_question_id: the UUID of the existing question (if this is
                                 an update)
    :return: an iterator over the ids of the inserted choices (the first
             element of each inserted primary key)
    :raises RepeatedChoiceError: when an insert fails on the
                                 'unique_choice_names' constraint
    """
    choices = values['choices']
    # updates maps a choice to the id of an existing choice whose answers are
    # carried over (it is passed to get_answer_choices_for_choice_id below).
    new_choices, updates = _determine_choices(connection, existing_question_id,
                                              choices)

    for number, choice in enumerate(new_choices):
        choice_dict = {
            'question_id': question_id,
            'survey_id': values['survey_id'],
            'choice': choice,
            'choice_number': number,
            'type_constraint_name': values['type_constraint_name'],
            'question_sequence_number': values['sequence_number'],
            'allow_multiple': values['allow_multiple']}
        executable = question_choice_insert(**choice_dict)
        # An IntegrityError mentioning 'unique_choice_names' is re-raised as
        # RepeatedChoiceError for this choice.
        exc = [('unique_choice_names', RepeatedChoiceError(choice))]
        result = execute_with_exceptions(connection, executable, exc)
        result_ipk = result.inserted_primary_key
        question_choice_id = result_ipk[0]

        if choice in updates:
            # NOTE(review): positions 2-4 of the inserted primary key are read
            # as type_constraint_name, sequence_number and allow_multiple, so
            # the choice table's key appears to be composite — confirm against
            # the table definition.
            question_fields = {'question_id': question_id,
                               'type_constraint_name': result_ipk[2],
                               'sequence_number': result_ipk[3],
                               'allow_multiple': result_ipk[4],
                               'survey_id': values['survey_id']}
            for answer in get_answer_choices_for_choice_id(connection,
                                                           updates[choice]):
                # Copy the answer onto the new choice and the new submission.
                answer_values = question_fields.copy()
                new_submission_id = submission_map[answer.submission_id]
                answer_values['question_choice_id'] = question_choice_id
                answer_values['submission_id'] = new_submission_id
                answer_metadata = answer.answer_choice_metadata
                answer_values['answer_choice_metadata'] = answer_metadata
                connection.execute(answer_choice_insert(**answer_values))

        yield question_choice_id
Esempio n. 31
0
def lock_table(connection: Connection, target_table: Table):
    """
    Lock a table using a PostgreSQL advisory lock

    The OID of the table in the pg_class relation is used as lock id.
    ``pg_advisory_xact_lock`` is a transaction-level lock: it is held until
    the current transaction ends.

    :param connection: DB connection
    :param target_table: Table object
    :raises LookupError: if no relation named ``target_table.name`` exists
    """
    logger.debug('Locking table "%s"', target_table.name)
    oid = connection.execute(
        select([column("oid")]).select_from(table("pg_class")).where(
            (column("relname") == target_table.name))).scalar()
    if oid is None:
        # With a NULL id no advisory lock would be taken, leaving the table
        # silently unprotected.
        raise LookupError(
            'Cannot lock table "{}": no such relation in pg_class'.format(
                target_table.name))
    # NOTE(review): relname is matched without a schema; if several schemas
    # contain a relation with this name, an arbitrary one's OID is used.
    connection.execute(select([func.pg_advisory_xact_lock(oid)])).scalar()
Esempio n. 32
0
 def upgrade(self, connection: Connection):
     """Create ``villain_templates`` and a unique index on its ``name``.

     Both statements run inside ``connection.begin()`` used as a context
     manager. It commits on success and rolls back (then re-raises) on error,
     instead of leaving the transaction open.
     NOTE(review): MySQL commits DDL implicitly, so a rollback may not undo a
     table that was already created.
     """
     with connection.begin():
         connection.execute("""create table villain_templates(
             id int not null auto_increment,
             name text not null,
             face_image_url text not null,
             primary key (id)
         );""")
         connection.execute("""
         alter table villain_templates
         add unique index `villain_template_name_idx` (`name`);
     """)
Esempio n. 33
0
 def insert_profile(conn: Connection, insert: str, p: Profile):
     """Execute the ``insert`` statement with the fields of profile ``p``.

     Positional parameters, in order: identifier, the id generated from the
     unified profile name, first name, last name, display name and link. All
     of them except the generated id pass through ``sanitize_text``.
     """
     unified_name, _unused = unify_profile_name(p.first_name, p.last_name)
     generated_id = generate_id(unified_name)
     params = (
         sanitize_text(p.identifier),
         generated_id,
         sanitize_text(p.first_name),
         sanitize_text(p.last_name),
         sanitize_text(p.display_name),
         sanitize_text(p.link),
     )
     conn.execute(insert, params)
Esempio n. 34
0
def init_db(connection: Connection, alembic_ini: str=None, force: bool=False, test: bool=False) -> None:
    """
    Create the main and static schemas and their tables.

    Schema names come from ``config['schema']`` and
    ``config['schema_static']``. They are double-quoted (embedded quotes
    doubled) in both DROP and CREATE, so the two statements refer to the same
    schema even for mixed-case names.

    :param connection: DB connection
    :param alembic_ini: path of an alembic ini file. When given, its ``main``
        and ``static`` sections are each upgraded to ``head``; otherwise the
        tables are created with ``Base.metadata.create_all``.
    :param force: drop existing schemas (CASCADE) before recreating them
    :param test: also load test data via ``setup_test_data``
    """
    # Imported only for their side effects (presumably registering the models
    # on Base.metadata); the names are unused.
    import c2cgeoportal_commons.models.main  # noqa: F401
    import c2cgeoportal_commons.models.static  # noqa: F401

    def quote_schema(name: str) -> str:
        """Return ``name`` as a double-quoted SQL identifier."""
        return '"{}"'.format(str(name).replace('"', '""'))

    schema = config['schema']
    schema_static = config['schema_static']
    if force:
        for name in (schema, schema_static):
            if schema_exists(connection, name):
                connection.execute('DROP SCHEMA {} CASCADE;'.format(quote_schema(name)))

    for name in (schema, schema_static):
        if not schema_exists(connection, name):
            connection.execute('CREATE SCHEMA {};'.format(quote_schema(name)))

    if alembic_ini is None:
        Base.metadata.create_all(connection)
    else:
        def upgrade(schema: str) -> None:
            cfg = alembic_config.Config(alembic_ini, ini_section=schema)
            cfg.attributes['connection'] = connection  # pylint: disable=unsupported-assignment-operation
            command.upgrade(cfg, 'head')
        upgrade('main')
        upgrade('static')

    session_factory = get_session_factory(connection)

    with transaction.manager:
        dbsession = get_tm_session(session_factory, transaction.manager)
        if test:
            setup_test_data(dbsession)
Esempio n. 35
0
def init_db(connection: Connection, force: bool=False, test: bool=False) -> None:
    """
    Create the main schema, its ``<schema>_static`` companion, and all tables.

    The schema name comes from ``c2cgeoportal_commons.models.schema``. Names
    are double-quoted (embedded quotes doubled) in both DROP and CREATE, so
    the two statements refer to the same schema even for mixed-case names.

    :param connection: DB connection
    :param force: drop existing schemas (CASCADE) before recreating them
    :param test: also load test data via ``setup_test_data``
    """
    # Imported only for their side effects (presumably registering the models
    # on Base.metadata); the names are unused.
    import c2cgeoportal_commons.models.main  # noqa: F401
    import c2cgeoportal_commons.models.static  # noqa: F401
    from c2cgeoportal_commons.models import schema

    # Check before the static schema name is derived from it.
    assert schema is not None
    schema_static = '{}_static'.format(schema)

    def quote_schema(name: str) -> str:
        """Return ``name`` as a double-quoted SQL identifier."""
        return '"{}"'.format(str(name).replace('"', '""'))

    if force:
        for name in (schema, schema_static):
            if schema_exists(connection, name):
                connection.execute('DROP SCHEMA {} CASCADE;'.format(quote_schema(name)))

    for name in (schema, schema_static):
        if not schema_exists(connection, name):
            connection.execute('CREATE SCHEMA {};'.format(quote_schema(name)))

    Base.metadata.create_all(connection)

    session_factory = get_session_factory(connection)

    with transaction.manager:
        dbsession = get_tm_session(session_factory, transaction.manager)
        if test:
            setup_test_data(dbsession)
Esempio n. 36
0
File: db.py Progetto: agdsn/hades
def lock_table(connection: Connection, target_table: Table):
    """
    Lock a table using a PostgreSQL advisory lock

    The OID of the table in the pg_class relation is used as lock id.
    ``pg_advisory_xact_lock`` is a transaction-level lock: it is held until
    the current transaction ends.

    :param connection: DB connection
    :param target_table: Table object
    :raises LookupError: if no relation named ``target_table.name`` exists
    """
    logger.debug('Locking table "%s"', target_table.name)
    oid = connection.execute(select([column("oid")])
                             .select_from(table("pg_class"))
                             .where((column("relname") == target_table.name))
                             ).scalar()
    if oid is None:
        # With a NULL id no advisory lock would be taken, leaving the table
        # silently unprotected.
        raise LookupError(
            'Cannot lock table "{}": no such relation in pg_class'.format(
                target_table.name))
    # NOTE(review): relname is matched without a schema; if several schemas
    # contain a relation with this name, an arbitrary one's OID is used.
    connection.execute(select([func.pg_advisory_xact_lock(oid)])).scalar()
Esempio n. 37
0
def _copy_submission_entries(connection: Connection,
                             existing_survey_id: str,
                             new_survey_id: str,
                             email: str) -> Iterator[tuple]:
    """
    Copy submissions from an existing survey to its updated copy.

    This is a generator: nothing is inserted until the caller iterates it.
    Each submission is inserted just before its ID pair is yielded, so a
    partially consumed iterator leaves a partial copy.

    :param connection: the SQLAlchemy connection used for the transaction
    :param existing_survey_id: the UUID of the existing survey
    :param new_survey_id: the UUID of the survey's updated copy
    :param email: the user's e-mail address
    :return: an iterator of (old submission ID, new submission ID) tuples
    """
    submissions = get_submissions_by_email(
        connection, email,
        survey_id=existing_survey_id
    )
    for sub in submissions:
        values = {'submitter': sub.submitter,
                  'submitter_email': sub.submitter_email,
                  'submission_time': sub.submission_time,
                  'save_time': sub.save_time,
                  'survey_id': new_survey_id}
        result = connection.execute(submission_insert(**values))
        yield sub.submission_id, result.inserted_primary_key[0]
Esempio n. 38
0
def execute_with_exceptions(connection: Connection,
                            executable: [Insert, Update],
                            exceptions: Iterable) -> ResultProxy:
    """
    Execute a SQLAlchemy Insert or Update and translate integrity errors.

    connection.execute() raises a generic IntegrityError. The ``exceptions``
    parameter maps a substring of that error's message to a more meaningful
    exception to raise instead. The original IntegrityError is kept as the
    ``__cause__`` of the raised exception.

    :param connection: the SQLAlchemy connection (for transaction purposes)
    :param executable: the object to pass to connection.execute()
    :param exceptions: an iterable of (name: str, exception: Exception) tuples.
                       name is the string to look for in the IntegrityError,
                       and exception is the Exception to raise instead of
                       IntegrityError. The first matching name wins.
    :return: a SQLAlchemy ResultProxy
    :raise: the mapped exception, or the original IntegrityError if no name
            matches
    """
    try:
        return connection.execute(executable)
    except IntegrityError as exc:
        error = str(exc.orig)
        for name, exception in exceptions:
            if name in error:
                # Chain so the database-level detail stays in the traceback.
                raise exception from exc
        raise
Esempio n. 39
0
def get_questions(connection: Connection,
                  survey_id: str,
                  auth_user_id: [str, None]=None,
                  email: [str, None]=None) -> ResultProxy:
    """
    Get all the questions for a survey identified by survey_id ordered by
    sequence number restricted by auth_user.

    Exactly one of auth_user_id or email must be given.

    :param connection: a SQLAlchemy Connection
    :param survey_id: the UUID of the survey
    :param auth_user_id: the UUID of the user
    :param email: the user's e-mail address
    :return: an iterable of the questions (RowProxy)
    :raise TypeError: if both or neither of auth_user_id and email are given
    """

    table = question_table.join(survey_table)
    conds = [question_table.c.survey_id == survey_id]

    if auth_user_id is not None:
        if email is not None:
            raise TypeError('You cannot specify both auth_user_id and email')
        conds.append(survey_table.c.auth_user_id == auth_user_id)
    elif email is not None:
        table = table.join(auth_user_table)
        conds.append(auth_user_table.c.email == email)
    else:
        raise TypeError('You must specify either auth_user_id or email')

    # Order by the column object instead of a raw 'sequence_number asc'
    # string, which SQLAlchemy only accepts as implicitly coerced text.
    questions = connection.execute(
        select([question_table]).select_from(table).where(
            and_(*conds)).order_by(question_table.c.sequence_number.asc()))
    return questions
Esempio n. 40
0
File: db.py Progetto: agdsn/hades
def get_sessions_of_mac(connection: Connection, mac: netaddr.EUI,
                        when: Optional[DatetimeRange]=None,
                        limit: Optional[int]=None) -> Iterable[
        Tuple[netaddr.IPAddress, str, datetime, datetime]]:
    """
    Return accounting sessions of a particular MAC address ordered by
    Session-Start-Time descending.

    :param connection: A SQLAlchemy connection
    :param mac: MAC address
    :param when: Range in which Session-Start-Time must be within; its items
                 are passed as arguments to PostgreSQL ``tstzrange()``
                 (presumably a (start, end) pair — confirm DatetimeRange)
    :param limit: Maximum number of records
    :return: An iterable that yields (NAS-IP-Address, NAS-Port-Id,
    Session-Start-Time, Session-Stop-Time)-tuples ordered by Session-Start-Time
    descending
    """
    logger.debug('Getting all sessions for MAC "%s"', mac)
    query = (
        select([radacct.c.NASIPAddress, radacct.c.NASPortId,
                radacct.c.AcctStartTime,
                radacct.c.AcctStopTime])
        .where(radacct.c.UserName == mac)
        .order_by(radacct.c.AcctStartTime.desc())
    )
    if when is not None:
        # Select.where() returns a new query, so it must be reassigned.
        # op('<@') returns a callable that builds "AcctStartTime <@ range".
        query = query.where(
            radacct.c.AcctStartTime.op('<@')(func.tstzrange(*when)))
    if limit is not None:
        query = query.limit(limit)
    return iter(connection.execute(query))
Esempio n. 41
0
def _jsonify(connection: Connection,
             answer: object,
             question_id: str) -> object:
    """
    Turn a submitted answer into a value that can be serialized as JSON.

    The conversion depends on the question's type constraint:
    location/facility -> GeoJSON coordinates (computed by ST_AsGeoJSON),
    date/time -> maybe_isoformat(), decimal -> float, multiple_choice -> the
    chosen choice's text. Any other type is returned unchanged.

    :param connection: a SQLAlchemy Connection
    :param answer: a submitted value
    :param question_id: the UUID of the question the answer belongs to
    :return: the JSON-friendly representation
    """
    question = question_select(connection, question_id)
    kind = question.type_constraint_name

    if kind in {'location', 'facility'}:
        as_geojson = func.ST_AsGeoJSON(answer)
        geojson_text = connection.execute(as_geojson).scalar()
        return json_decode(geojson_text)['coordinates']
    if kind in {'date', 'time'}:
        return maybe_isoformat(answer)
    if kind == 'decimal':
        return float(answer)
    if kind == 'multiple_choice':
        return question_choice_select(connection, answer).choice
    return answer
Esempio n. 42
0
def get_stats(connection: Connection,
              survey_id: str,
              email: str) -> dict:
    """
    Get statistics about the specified survey: creation time, number of
    submissions, time of the earliest submission, and time of the latest
    submission.

    :param connection: a SQLAlchemy Connection
    :param survey_id: the UUID of the survey
    :param email: the e-mail address of the user
    :return: a JSON representation of the statistics.
    :raise SurveyDoesNotExistError: if no survey with this ID belongs to the
                                    user with this e-mail address
    """
    row = connection.execute(
        select([
            survey_table.c.created_on,
            count(submission_table.c.submission_id),
            sqlmin(submission_table.c.submission_time),
            sqlmax(submission_table.c.submission_time)
        ]).select_from(
            # Outer join so a survey with no submissions still yields a row.
            auth_user_table.join(survey_table).outerjoin(submission_table)
        ).where(
            survey_table.c.survey_id == survey_id
        ).where(
            auth_user_table.c.email == email
        ).group_by(
            survey_table.c.survey_id
        )
    ).first()
    if row is None:
        # Previously this surfaced as a TypeError from indexing None.
        raise SurveyDoesNotExistError(survey_id)
    return json_response({
        'created_on': maybe_isoformat(row[0]),
        'num_submissions': row[1],
        'earliest_submission_time': maybe_isoformat(row[2]),
        'latest_submission_time': maybe_isoformat(row[3])
    })
Esempio n. 43
0
File: db.py Progetto: agdsn/hades
def get_auth_attempts_at_port(connection: Connection,
                              nas_ip_address: netaddr.IPAddress,
                              nas_port_id: str,
                              when: Optional[DatetimeRange]=None,
                              limit: Optional[int]=None)-> Iterable[
        Tuple[str, str, Groups, Attributes, datetime]]:
    """
    Return auth attempts at a particular port of an NAS ordered by Auth-Date
    descending.

    :param connection: A SQLAlchemy connection
    :param nas_ip_address: NAS IP address
    :param nas_port_id: NAS Port ID
    :param when: Range in which Auth-Date must be within; its items are
                 passed as arguments to PostgreSQL ``tstzrange()``
                 (presumably a (start, end) pair — confirm DatetimeRange)
    :param limit: Maximum number of records
    :return: An iterable that yields (User-Name, Packet-Type, Groups, Reply,
             Auth-Date)-tuples ordered by Auth-Date descending
    """
    # Python %-formatting has no positional "%2$s" syntax; pass the args in
    # the order they appear in the message.
    logger.debug('Getting all auth attempts at port %s of %s',
                 nas_port_id, nas_ip_address)
    query = (
        select([radpostauth.c.UserName, radpostauth.c.PacketType,
                radpostauth.c.Groups, radpostauth.c.Reply,
                radpostauth.c.AuthDate])
        .where(and_(radpostauth.c.NASIPAddress == nas_ip_address,
                    radpostauth.c.NASPortId == nas_port_id))
        .order_by(radpostauth.c.AuthDate.desc())
    )
    if when is not None:
        # Select.where() returns a new query, so it must be reassigned.
        # op('<@') returns a callable that builds "AuthDate <@ range".
        query = query.where(
            radpostauth.c.AuthDate.op('<@')(func.tstzrange(*when)))
    if limit is not None:
        query = query.limit(limit)
    return iter(connection.execute(query))
Esempio n. 44
0
File: db.py Progetto: agdsn/hades
def get_auth_attempts_of_mac(connection: Connection, mac: netaddr.EUI,
                             when: Optional[DatetimeRange]=None,
                             limit: Optional[int]=None) -> Iterable[
        Tuple[netaddr.IPAddress, str, str, Groups, Attributes, datetime]]:
    """
    Return auth attempts of a particular MAC address ordered by Auth-Date
    descending.

    :param connection: A SQLAlchemy connection
    :param mac: MAC address
    :param when: Range in which Auth-Date must be within; its items are
                 passed as arguments to PostgreSQL ``tstzrange()``
                 (presumably a (start, end) pair — confirm DatetimeRange)
    :param limit: Maximum number of records
    :return: An iterable that yields (NAS-IP-Address, NAS-Port-Id, Packet-Type,
    Groups, Reply, Auth-Date)-tuples ordered by Auth-Date descending
    """
    logger.debug('Getting all auth attempts of MAC %s', mac)
    query = (
        select([radpostauth.c.NASIPAddress, radpostauth.c.NASPortId,
                radpostauth.c.PacketType, radpostauth.c.Groups,
                radpostauth.c.Reply, radpostauth.c.AuthDate])
        .where(radpostauth.c.UserName == mac)
        .order_by(radpostauth.c.AuthDate.desc())
    )
    if when is not None:
        # Select.where() returns a new query, so it must be reassigned.
        # op('<@') returns a callable that builds "AuthDate <@ range".
        query = query.where(
            radpostauth.c.AuthDate.op('<@')(func.tstzrange(*when)))
    if limit is not None:
        query = query.limit(limit)
    return iter(connection.execute(query))
Esempio n. 45
0
def survey_select(connection: Connection,
                  survey_id: str,
                  auth_user_id: str=None,
                  email: str=None) -> RowProxy:
    """
    Fetch one survey row, restricted to surveys owned by the given user.

    Identify the owner by exactly one of ``auth_user_id`` or ``email``.

    :param connection: a SQLAlchemy Connection
    :param survey_id: the UUID of the survey
    :param auth_user_id: the UUID of the user
    :param email: the user's e-mail address
    :return: the matching survey record
    :raise TypeError: if both or neither of auth_user_id and email are given
    :raise SurveyDoesNotExistError: if no matching survey is found
    """
    if auth_user_id is not None and email is not None:
        raise TypeError('You cannot specify both auth_user_id and email')
    if auth_user_id is None and email is None:
        raise TypeError('You must specify either auth_user_id or email')

    id_matches = survey_table.c.survey_id == survey_id
    if auth_user_id is not None:
        source = survey_table
        owner_matches = survey_table.c.auth_user_id == auth_user_id
    else:
        # Filtering by e-mail needs the user table joined in.
        source = survey_table.join(auth_user_table)
        owner_matches = auth_user_table.c.email == email

    query = select([survey_table]).select_from(source).where(
        and_(id_matches, owner_matches))
    found = connection.execute(query).first()
    if found is None:
        raise SurveyDoesNotExistError(survey_id)
    return found
Esempio n. 46
0
def _return_sql(connection: Connection,
                result: object,
                survey_id: str,
                auth_user_id: str,
                question_id: str) -> object:
    """
    Pass through a non-empty SQL result, or explain why it is empty.

    An empty result (``None`` or ``[]``) means one of two things. If
    ``auth_user_id`` owns the survey, the question simply has no
    submissions. Otherwise the user is treated as not having access.

    :param connection: a SQLAlchemy Connection
    :param result: the result of the SQL function
    :param survey_id: the UUID of the survey
    :param auth_user_id: the UUID of the user
    :param question_id: the UUID of the question
    :return: ``result`` unchanged when it is non-empty
    :raise NoSubmissionsToQuestionError: if the owner's question has no
                                         submissions
    :raise QuestionDoesNotExistError: if the user does not own the survey
    """
    if result is not None and result != []:
        return result

    owner_query = select([survey_table]).where(
        survey_table.c.survey_id == survey_id)
    owner_id = connection.execute(owner_query).first().auth_user_id
    if auth_user_id != owner_id:
        raise QuestionDoesNotExistError(question_id)
    raise NoSubmissionsToQuestionError(question_id)
Esempio n. 47
0
def update(connection: Connection, data: dict):
    """
    Create a new version of a survey from updated data.

    The new title and questions come from ``data``. The old survey is kept
    but renamed to end with "(new version created on <time>)". A brand-new
    survey is then created from ``data``. Both steps run in one transaction.
    If ``data`` has no 'survey_metadata', the old survey's metadata is
    written into ``data`` (the caller's dict is modified).

    :param connection: a SQLAlchemy Connection
    :param data: JSON containing the UUID of the survey, the owner's e-mail
                 and the fields to update
    :return: the newly created survey, as returned by ``get_one``
    """
    survey_id = data['survey_id']
    email = data['email']
    old_survey = survey_select(connection, survey_id, email=email)
    if 'survey_metadata' not in data:
        data['survey_metadata'] = old_survey.survey_metadata
    stamp = datetime.datetime.now()

    with connection.begin():
        retired_title = '{} (new version created on {})'.format(
            old_survey.survey_title, stamp.isoformat())
        rename_old = update_record(survey_table, 'survey_id', survey_id,
                                   survey_title=retired_title)
        # The renamed title must still be unique per owner.
        title_clash = [('survey_title_survey_owner_key',
                        SurveyAlreadyExistsError(retired_title))]
        execute_with_exceptions(connection, rename_old, title_clash)

        created_id = _create_survey(connection, data)

    return get_one(connection, created_id, email=email)
Esempio n. 48
0
def create_user(connection: Connection, data: dict) -> dict:
    """
    Register a new user account unless one already exists for the e-mail.

    :param connection: a SQLAlchemy Connection
    :param data: a dict holding the user's e-mail under the 'email' key
    :return: a JSON response with the e-mail and either 'Created' or
             'Already exists'
    """
    email = data['email']
    try:
        get_auth_user_by_email(connection, email)
    except UserDoesNotExistError:
        with connection.begin():
            connection.execute(create_auth_user(email=email))
        outcome = 'Created'
    else:
        outcome = 'Already exists'
    return json_response({'email': email, 'response': outcome})
Esempio n. 49
0
def schema_exists(connection: Connection, schema_name: str) -> bool:
    """
    Tell whether a schema with exactly this name exists in the database.

    :param connection: the SQLAlchemy connection to query
    :param schema_name: the schema name, compared literally against
                        information_schema.schemata.schema_name
    :return: True if exactly one matching schema row is found
    """
    from sqlalchemy import text

    # Bind the name as a parameter rather than formatting it into the SQL,
    # so a name containing a quote cannot break (or inject into) the query.
    sql = text("""
SELECT count(*) AS count
FROM information_schema.schemata
WHERE schema_name = :schema_name;
""")
    result = connection.execute(sql, {'schema_name': schema_name})
    row = result.first()
    return row[0] == 1
Esempio n. 50
0
def bar_graph(connection: Connection,
              question_id: str,
              auth_user_id: str=None,
              email: str=None,
              limit: [int, None]=None,
              count_order: bool=False) -> dict:
    """
    Count how many times each submitted value appears for a question.

    Identify the user by either ``auth_user_id`` or ``email``.

    :param connection: a SQLAlchemy Connection
    :param question_id: the UUID of the question
    :param auth_user_id: the UUID of the user
    :param email: the e-mail address of the user.
    :param limit: a limit on the number of results (None for no limit)
    :param count_order: order from largest count to smallest instead of by
                        value
    :return: the list of [value, count] pairs passed through json_response,
             with 'query' set to 'bar_graph'
    """
    user_id = _get_user_id(connection, auth_user_id, email)

    allowable_types = {'text', 'integer', 'decimal', 'multiple_choice', 'date',
                       'time', 'location', 'facility'}
    question = question_select(connection, question_id)
    tcn = _get_type_constraint_name(allowable_types, question)

    # Only the non-"other" answers are considered.
    answers_table, column_name = _table_and_column(tcn)
    joined = answers_table.join(
        question_table,
        answers_table.c.question_id == question_table.c.question_id
    ).join(survey_table)
    value_column = get_column(answers_table, column_name)

    if count_order:
        ordering = desc(sqlcount(value_column))
    else:
        ordering = value_column

    query = (
        select([value_column, sqlcount(value_column)])
        .select_from(joined)
        .group_by(value_column)
        .order_by(ordering)
        .where(and_(question_table.c.question_id == question_id,
                    survey_table.c.auth_user_id == user_id))
        .limit(limit)
    )

    rows = _return_sql(connection, connection.execute(query),
                       question.survey_id, user_id, question_id)
    pairs = [[_jsonify(connection, row[0], question_id), row[1]]
             for row in rows]
    response = json_response(
        _return_sql(connection, pairs, question.survey_id, user_id,
                    question_id))
    response['query'] = 'bar_graph'
    return response
Esempio n. 51
0
File: db.py Progetto: agdsn/hades
def refresh_and_diff_materialized_view(
        connection: Connection, view: Table, copy: Table,
        result_columns: Iterable[Column]) -> Tuple[
        List[Tuple], List[Tuple], List[Tuple]]:
    """
    Refresh a materialized view and report how its contents changed.

    In a single transaction: take the advisory lock for ``view``, snapshot
    it into the temporary table ``copy``, refresh the view, then diff the
    refreshed view against the snapshot.

    :param connection: DB connection
    :param view: the materialized view to refresh
    :param copy: table object used for the pre-refresh snapshot
    :param result_columns: columns handed to ``diff_tables``
    :return: the result of ``diff_tables`` (three lists of row tuples, per
             the annotation)
    """
    with connection.begin():
        lock_table(connection, view)
        create_temp_copy(connection, view, copy)
        refresh_materialized_view(connection, view)
        changes = diff_tables(connection, view, copy, result_columns)
    return changes
Esempio n. 52
0
def get_number_of_submissions(connection: Connection, survey_id: str) -> int:
    """
    Count the submissions recorded for one survey.

    :param connection: a SQLAlchemy Connection
    :param survey_id: the UUID of the survey
    :return: the number of submission rows for that survey
    """
    belongs_to_survey = submission_table.c.survey_id == survey_id
    count_query = select([count()]).where(belongs_to_survey)
    return connection.execute(count_query).scalar()
Esempio n. 53
0
def time_series(connection: Connection,
                question_id: str,
                auth_user_id: str=None,
                email: str=None) -> dict:
    """
    Get a list of submissions to the specified question over time. You must
    provide either an auth_user_id or e-mail address.

    :param connection: a SQLAlchemy Connection
    :param question_id: the UUID of the question
    :param auth_user_id: the UUID of the user
    :param email: the e-mail address of the user.
    :return: the list of [submission time (ISO 8601), answer] pairs ordered
             by submission time, passed through json_response, with 'query'
             set to 'time_series'
    """
    user_id = _get_user_id(connection, auth_user_id, email)

    allowable_types = {'text', 'integer', 'decimal', 'multiple_choice', 'date',
                       'time', 'location'}

    question = question_select(connection, question_id)

    tcn = _get_type_constraint_name(allowable_types, question)

    # Assume that you only want to consider the non-other answers
    original_table, column_name = _table_and_column(tcn)
    table = original_table.join(
        survey_table,
        original_table.c.survey_id == survey_table.c.survey_id
    ).join(
        submission_table,
        original_table.c.submission_id == submission_table.c.submission_id
    )
    column = get_column(original_table, column_name)

    where_stmt = select(
        [column, submission_table.c.submission_time]
    ).select_from(table).where(
        original_table.c.question_id == question_id
    ).where(
        survey_table.c.auth_user_id == user_id
    )
    # Order by the column object instead of a raw 'submission_time asc'
    # string, which SQLAlchemy only accepts as implicitly coerced text.
    ordered_stmt = where_stmt.order_by(
        submission_table.c.submission_time.asc())

    # Pass the resolved user_id, not auth_user_id (None for e-mail callers),
    # so the ownership check in _return_sql works for both kinds of caller.
    result = _return_sql(
        connection,
        connection.execute(ordered_stmt),
        question.survey_id, user_id, question_id)
    time_series_result = [
        [r.submission_time.isoformat(),
         _jsonify(connection, r[column_name], question_id)]
        for r in result]
    response = json_response(
        _return_sql(connection, time_series_result, question.survey_id,
                    user_id, question_id))
    response['query'] = 'time_series'
    return response
Esempio n. 54
0
File: db.py Progetto: agdsn/hades
def get_all_alternative_dns_ips(connection: Connection) -> Iterable[
        netaddr.IPAddress]:
    """
    List every IP address configured for alternative DNS.

    :param connection: A SQLAlchemy connection
    :return: A lazy iterator over the IPAddress column values
    """
    logger.debug("Getting all alternative DNS clients")
    rows = connection.execute(select([alternative_dns.c.IPAddress]))
    return (row[0] for row in rows)
Esempio n. 55
0
def get_answers_for_question(connection: Connection, question_id: str) -> ResultProxy:
    """
    Fetch every answer row that belongs to one question.

    :param connection: a SQLAlchemy Connection
    :param question_id: the question's UUID (foreign key on the answer table)
    :return: an iterable of the answers (RowProxy)
    """
    for_question = answer_table.c.question_id == question_id
    return connection.execute(select([answer_table]).where(for_question))
Esempio n. 56
0
def get_branches(connection: Connection, question_id: str) -> ResultProxy:
    """
    Fetch the branches that start at a given question.

    :param connection: a SQLAlchemy Connection
    :param question_id: UUID matched against the branch's from_question_id
    :return: an iterable of the branches (RowProxy)
    """
    starts_here = question_branch_table.c.from_question_id == question_id
    return connection.execute(
        select([question_branch_table]).where(starts_here))
Esempio n. 57
0
File: db.py Progetto: agdsn/hades
def get_all_dhcp_hosts(connection: Connection) -> Iterable[
        Tuple[netaddr.EUI, netaddr.IPAddress]]:
    """
    List every configured DHCP host.

    :param connection: A SQLAlchemy connection
    :return: An iterator over (MAC, IPAddress) rows
    """
    logger.debug("Getting all DHCP hosts")
    wanted = [dhcphost.c.MAC, dhcphost.c.IPAddress]
    return iter(connection.execute(select(wanted)))
Esempio n. 58
0
def create(connection: Connection, data: dict) -> dict:
    """
    Create a survey and its questions, then return it.

    The survey is inserted inside its own transaction. It is read back
    afterwards using the owner's e-mail from ``data['email']``.

    :param connection: a SQLAlchemy Connection
    :param data: a JSON representation of the survey to be created
    :return: a JSON representation of the created survey
    """
    with connection.begin():
        new_survey_id = _create_survey(connection, data)
    owner_email = data['email']
    return get_one(connection, new_survey_id, email=owner_email)
Esempio n. 59
0
def get_answers(connection: Connection, submission_id: str) -> ResultProxy:
    """
    Get all the records from the answer table identified by submission_id
    ordered by sequence number.

    :param connection: a SQLAlchemy Connection
    :param submission_id: foreign key
    :return: an iterable of the answers (RowProxy)
    """
    select_stmt = select([answer_table])
    where_stmt = select_stmt.where(answer_table.c.submission_id == submission_id)
    # Order by the column object instead of a raw "sequence_number asc"
    # string, which SQLAlchemy only accepts as implicitly coerced text.
    return connection.execute(
        where_stmt.order_by(answer_table.c.sequence_number.asc()))
Esempio n. 60
0
File: db.py Progetto: agdsn/hades
def create_temp_copy(connection: Connection, source: Table, destination: Table):
    """
    Snapshot ``source`` into a temporary table named after ``destination``.

    The table is created with ON COMMIT DROP, so it only lives until the
    surrounding transaction ends. The function therefore refuses to run
    outside a transaction.

    :param connection: DB connection
    :param source: Source table
    :param destination: Destination table
    :raise RuntimeError: if the connection is not in a transaction
    """
    logger.debug('Creating temporary table "%s" as copy of "%s"',
                 destination.name, source.name)
    if not connection.in_transaction():
        raise RuntimeError("must be executed in a transaction to have any "
                           "effect")
    # Let the dialect quote the (possibly schema-qualified) table names.
    format_table = connection.dialect.identifier_preparer.format_table
    template = ('CREATE TEMPORARY TABLE {destination} ON COMMIT DROP AS '
                'SELECT * FROM {source}')
    statement = template.format(
        source=format_table(source),
        destination=format_table(destination),
    )
    connection.execute(statement)