def upgrade(self, connection: Connection):
    connection.execute("""create table migrations(
        id int not null auto_increment,
        version int not null,
        created timestamp not null default current_timestamp,
        primary key (id)
    );""")
def _get_net_income_quater(connection: engine.Connection, meta: MetaData,
                           name: str, b_date: date) -> float:
    """
    Arguments: a DB connection, the DB metadata, a bank name and the start
    date of a quarter.
    Computes the cumulative balance-sheet profit and returns its value as of
    the start of the quarter.
    """
    # Reference to the f102 table
    bank = meta.tables['f102']
    # Look up income
    select_statement = select(
        bank.c.SIM_ITOGO).filter(
        and_(
            bank.c.NAME_B == name,
            bank.c.DT == b_date,
            bank.c.CODE == 19999))
    result = connection.execute(select_statement)
    incomes = result.fetchone()[0]
    # Look up expenses
    select_statement = select(
        bank.c.SIM_ITOGO).filter(
        and_(
            bank.c.NAME_B == name,
            bank.c.DT == b_date,
            bank.c.CODE == 29999))
    result = connection.execute(select_statement)
    consumptions = result.fetchone()[0]
    return incomes - consumptions
def insert_multiple(application_id: str, account_name: str, followers: Iterable,
                    source: str, cursor: str, conn: Connection) -> None:
    """Add followers for one account into the account relationships table"""
    trans = conn.begin()
    try:
        account_id = select_one_id(application_id, account_name, source, conn)
        for follower in followers:
            follower_id = select_one_id(application_id, follower, source, conn)
            if not follower_id:
                follower_id = insert_one(application_id, follower, False,
                                         source, conn)
            stmt_relationship = insert(account_relationship, values={
                'account': account_id,
                'follower': follower_id
            })
            conn.execute(stmt_relationship)
        update_one_cursor(account_id, cursor, conn)
        trans.commit()
    except Exception:
        # Roll back, then re-raise instead of silently swallowing the error
        trans.rollback()
        raise
def _add_column(
        self,
        connection: Connection,
        table: SqlaTable,
        column_name: str,
        column_type: str,
) -> None:
    """
    Add a new column to a table

    :param connection: The connection to work with
    :param table: The SqlaTable
    :param column_name: The name of the column
    :param column_type: The type of the column
    """
    column = Column(column_name, column_type)
    dialect = table.database.get_sqla_engine().dialect
    name = column.compile(dialect=dialect)
    col_type = column.type.compile(dialect)
    sql = text(
        f"ALTER TABLE {self._get_from_clause(table)} ADD {name} {col_type}"
    )
    connection.execute(sql)
    table.columns.append(
        TableColumn(column_name=column_name, type=col_type))
def _insert_bank_list(connection: engine.Connection, meta: MetaData,
                      ratios: dict, b_date: date):
    """
    Arguments: a DB connection, the DB metadata, a date and a dict whose
    keys are bank names and whose values are lists of ratios computed for
    each bank for the given date.
    Inserts the data into the data mart.
    """
    # Reference to the data-mart table
    table_mart = meta.tables[config.mart_name]
    # Reference to the req table
    table_req = meta.tables['req']
    for name in ratios:
        # Find the bank's index in the req table
        select_statement = select(
            table_req.c.index).filter(table_req.c.NAME_B == name)
        result = connection.execute(select_statement)
        bank_index = result.fetchone()[0]
        # Insert the data into the data mart
        insert_statement = table_mart.\
            insert().values(bank_index=bank_index,
                            date=b_date,
                            CA=ratios[name][0],
                            LR=ratios[name][1],
                            CR=ratios[name][2],
                            LTR=ratios[name][3],
                            CP=ratios[name][4],
                            ROA=ratios[name][5])
        connection.execute(insert_statement)
def _insert_bank(connection: engine.Connection, meta: MetaData,
                 ratios: dict, name: str):
    """
    Arguments: a DB connection, the DB metadata, a bank name and a dict
    whose keys are dates and whose values are lists of ratios computed for
    that date for the given bank.
    Inserts the data into the data mart.
    """
    # Reference to the data-mart table
    table_mart = meta.tables[config.mart_name]
    # Reference to the req table
    table_req = meta.tables['req']
    # Find the bank's index in the req table
    select_statement = select(
        table_req.c.index).filter(table_req.c.NAME_B == name)
    result = connection.execute(select_statement)
    bank_index = result.fetchone()[0]
    # Insert the data into the data mart
    for b_date in ratios:
        insert_statement = table_mart.\
            insert().values(bank_index=bank_index,
                            date=b_date,
                            CA=ratios[b_date][0],
                            LR=ratios[b_date][1],
                            CR=ratios[b_date][2],
                            LTR=ratios[b_date][3],
                            CP=ratios[b_date][4],
                            ROA=ratios[b_date][5])
        connection.execute(insert_statement)
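# A minimal usage sketch for the data-mart loaders above, assuming an engine
# bound to the bank database, reflected metadata, and `from datetime import
# date`; the bank name and ratio values are illustrative only.
meta = MetaData()
meta.reflect(bind=engine)
with engine.connect() as connection:
    ratios = {'Some Bank': [0.12, 0.45, 1.10, 0.80, 0.30, 0.02]}
    _insert_bank_list(connection, meta, ratios, date(2020, 1, 1))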
def get_columns(
        self,
        connection: Connection,
        table_name: str,
        schema: Optional[str] = None,
        **kwargs: Any,
) -> List[Dict[str, Any]]:
    query = f'SHOW COLUMNS FROM "{table_name}"'
    # Custom SQL
    array_columns_ = connection.execute(
        f"SHOW ARRAY_COLUMNS FROM {table_name}").fetchall()
    # Convert cursor rows: List[Tuple[str]] to List[str]
    if not array_columns_:
        array_columns = []
    else:
        array_columns = [col_name.name for col_name in array_columns_]
    all_columns = connection.execute(query)
    return [{
        "name": row.column,
        "type": basesqlalchemy.get_type(row.mapping),
        "nullable": True,
        "default": None,
    } for row in all_columns
        if row.mapping not in self._not_supported_column_types
        and row.column not in array_columns]
def get_free_title(connection: Connection, title: str, auth_user_id: str) -> str:
    """
    Get a good version of the title to be inserted into the survey table.
    If the title as given already exists, this function will append a number.
    For example, when the title is "survey":
    1. "survey" not in table -> "survey"
    2. "survey" in table -> "survey(1)"
    3. "survey(1)" in table -> "survey(2)"

    :param connection: a SQLAlchemy Connection
    :param title: the survey title
    :param auth_user_id: the user's UUID
    :return: a title that can be inserted safely
    """
    (does_exist, ), = connection.execute(
        select((exists().where(
            survey_table.c.survey_title == title
        ).where(survey_table.c.auth_user_id == auth_user_id),)))
    if not does_exist:
        return title
    similar_surveys = connection.execute(
        select([survey_table]).where(
            survey_table.c.survey_title.like(title + '%')
        ).where(
            survey_table.c.auth_user_id == auth_user_id
        )
    ).fetchall()
    conflicts = list(_conflicting(title, similar_surveys))
    free_number = max(conflicts) + 1 if len(conflicts) else 1
    return title + '({})'.format(free_number)
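# _conflicting is referenced above but not shown here. A minimal sketch of
# what it could look like, assuming conflicting titles take the form
# "title(N)"; this is an illustrative guess, not the project's actual
# implementation.
import re

def _conflicting(title: str, similar_surveys) -> Iterator:
    """Yield N for every existing survey title of the form "title(N)"."""
    pattern = re.compile(re.escape(title) + r'\((\d+)\)$')
    for survey in similar_surveys:
        match = pattern.match(survey.survey_title)
        if match:
            yield int(match.group(1))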
def test_saves_auction_changes(
        connection: Connection,
        another_bidder_id: int,
        bid_model: RowProxy,
        auction_model_with_a_bid: RowProxy,
        ends_at: datetime,
        event_bus_mock: Mock,
) -> None:
    new_bid_price = get_dollars(bid_model.amount * 2)
    auction = Auction(
        id=auction_model_with_a_bid.id,
        title=auction_model_with_a_bid.title,
        starting_price=get_dollars(auction_model_with_a_bid.starting_price),
        ends_at=ends_at,
        bids=[
            Bid(bid_model.id, bid_model.bidder_id,
                get_dollars(bid_model.amount)),
            Bid(None, another_bidder_id, new_bid_price),
        ],
        ended=True,
    )

    SqlAlchemyAuctionsRepo(connection, event_bus_mock).save(auction)

    assert connection.execute(
        select([func.count()]).select_from(bids)).scalar() == 2
    proxy = connection.execute(
        select([auctions]).where(
            auctions.c.id == auction_model_with_a_bid.id)).first()
    assert proxy.current_price == new_bid_price.amount
    assert proxy.ended
def create_table(self, schema_sql: str, conn: Connection, driver: DBDriver) -> None:
    logger.info('Creating table...')
    conn.execute(schema_sql)
    logger.info(f"Just ran {schema_sql}")
    self.add_permissions(conn, driver)
    logger.info("Table prepped")
def diff_tables(
        connection: Connection,
        master: Table,
        copy: Table,
        result_columns: Iterable[Column]
) -> Tuple[List[Tuple], List[Tuple], List[Tuple]]:
    """
    Compute the differences in the contents of two tables with identical
    columns.

    The master table must have at least one PrimaryKeyConstraint or
    UniqueConstraint with only non-null columns defined. If multiple such
    constraints are defined, the constraint with the fewest columns is used.

    :param connection: DB connection
    :param master: Master table
    :param copy: Copy of master table
    :param result_columns: columns to return
    :return: a tuple of (added, deleted, modified) records
    """
    logger.debug('Calculating diff between "%s" and "%s"',
                 master.name, copy.name)
    result_columns = tuple(result_columns)
    unique_columns = min(
        (constraint.columns for constraint in master.constraints
         if isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint))
         and constraint.columns
         and not any(map(operator.attrgetter('nullable'),
                         constraint.columns))),
        key=len, default=[])
    if not unique_columns:
        raise AssertionError("To diff table {} it must have at least one "
                             "PrimaryKeyConstraint/UniqueConstraint with only "
                             "NOT NULL columns defined on it.".format(
                                 master.name))
    unique_column_names = tuple(col.name for col in unique_columns)
    other_column_names = tuple(col.name for col in master.c
                               if col.name not in unique_column_names)
    on_clause = and_(
        *(getattr(master.c, column_name) == getattr(copy.c, column_name)
          for column_name in unique_column_names))
    added = connection.execute(
        select(result_columns).select_from(master.outerjoin(
            copy, on_clause)).where(
                or_(*(getattr(copy.c, column_name).is_(null())
                      for column_name in unique_column_names)))).fetchall()
    deleted = connection.execute(
        select(result_columns).select_from(copy.outerjoin(
            master, on_clause)).where(
                or_(*(getattr(master.c, column_name).is_(null())
                      for column_name in unique_column_names)))).fetchall()
    modified = connection.execute(
        select(result_columns).select_from(
            master.join(copy, on_clause)).where(
                or_(*(
                    getattr(master.c, column_name) !=
                    getattr(copy.c, column_name)
                    for column_name in other_column_names
                )))).fetchall() if other_column_names else []
    logger.debug('Diff found %d added, %d deleted, and %d modified records',
                 len(added), len(deleted), len(modified))
    return added, deleted, modified
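# A minimal usage sketch of diff_tables, assuming `accounts` and
# `accounts_copy` are Tables with identical columns and a primary key on
# `id`, plus an existing engine; all names here are illustrative. The
# returned rows carry only the requested result columns.
with engine.connect() as connection:
    added, deleted, modified = diff_tables(
        connection, accounts, accounts_copy,
        result_columns=[accounts.c.id, accounts.c.name])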
def bulk_update(conn: Connection, table: Table, col_name: str,
                ids_values: List[Tuple]):
    stmt = table.update().where(table.c.id == bindparam('obj_id')).values(
        **{col_name: bindparam('val')})
    conn.execute(stmt, [{
        'obj_id': eid,
        'val': val
    } for eid, val in ids_values])
def update_cluster(account_id: str, topic_iteration_id: str, cluster: int,
                   x: float, y: float, conn: Connection) -> None:
    """Update the cluster and the x and y values of the topic"""
    stmt = update(topic).where(
        and_(topic.c.account == account_id,
             topic.c.topic_iteration == topic_iteration_id)).values(
                 cluster=cluster, x=x, y=y)
    conn.execute(stmt)
def import_nivo(con: Connection, csv_file: ANivoCsv) -> None:
    csv_file.normalize()
    csv_file.find_and_replace_foreign_key_value()
    with con.begin():
        ins = insert(NivoRecordTable).values(
            csv_file.cleaned_csv
        )  # .on_conflict_do_nothing(index_elements=['nss_name'])
        con.execute(ins)
def _execute_sql_stream(conn: Connection, sql: str) -> None:
    """Run the SQL statements in a stream against a database."""
    for query in split_sql(sql):
        if _should_escape_percents(conn):
            escaped_query = query.replace("%", "%%")
        else:
            escaped_query = query
        conn.execute(escaped_query)
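# split_sql and _should_escape_percents are helpers referenced above but not
# shown here. A plausible sketch of the latter, assuming the intent is that
# literal "%" collides with DBAPI drivers that use %-style parameter
# formatting; hypothetical, not the project's actual code.
def _should_escape_percents(conn: Connection) -> bool:
    """Percent signs must be doubled for drivers using %-style parameters."""
    return conn.dialect.paramstyle in ("pyformat", "format")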
def add_columns(table: str, df: pandas.DataFrame, con: Connection):
    """Add any columns missing from the table, if the table exists"""
    if table_exists(table, con):
        cols = set(con.execute(f'SELECT * FROM {table} LIMIT 1').keys())
        new_cols = set(df.columns) - cols
        for col in new_cols:
            alter = f'ALTER TABLE {table} ADD COLUMN "{col}" FLOAT'
            con.execute(alter)
def _add_or_update_records(
        cls, conn: Connection, table: Table,
        records: List["I2B2CoreWithUploadId"]) -> Tuple[int, int]:
    """
    Add or update the supplied table as needed to reflect the contents of
    records

    :param conn: i2b2 sql connection
    :param table: table to add to or update
    :param records: records to apply
    :return: tuple of (number of records inserted, number of records updated)
    """
    num_updates = 0
    num_inserts = 0
    inserts = []
    # Iterate over the records doing updates
    # Note: This is slow as molasses - definitely not optimal for batch work,
    # but hopefully we'll be dealing with thousands to tens of thousands of
    # records. May want to move to ORM model if this gets to be an issue
    for record in records:
        keys = [(table.c[k] == getattr(record, k)) for k in cls.key_fields]
        key_filter = I2B2CoreWithUploadId._nested_fcn(and_, keys)
        rec_exists = conn.execute(
            select([table.c.upload_id]).where(key_filter)).rowcount
        if rec_exists:
            known_values = {
                k: v
                for k, v in record._freeze().items()
                if v is not None and k not in cls.no_update_fields
                and k not in cls.key_fields
            }
            vals = [table.c[k] != v for k, v in known_values.items()]
            val_filter = I2B2CoreWithUploadId._nested_fcn(or_, vals)
            known_values['update_date'] = record.update_date
            upd = update(table).where(and_(
                key_filter, val_filter)).values(known_values)
            num_updates += conn.execute(upd).rowcount
        else:
            inserts.append(record._freeze())
    if inserts:
        if cls._check_dups:
            dups = cls._check_for_dups(inserts)
            nprints = 0
            # TODO: Figure out why duplicates are occurring -- they are very
            # specific
            if dups:
                print("{} duplicate records encountered".format(len(dups)))
                for k, vals in dups.items():
                    if len(vals) == 2 and vals[0] == vals[1]:
                        inserts.remove(vals[1])
                    else:
                        if nprints < 20:
                            print("Key: {} has a non-identical dup".format(k))
                        elif nprints == 20:
                            print(".... more ...")
                        nprints += 1
                        for v in vals[1:]:
                            inserts.remove(v)
        # TODO: refactor this to load on a per-resource basis. Temporary fix
        for insert in ListChunker(inserts, 500):
            num_inserts += conn.execute(table.insert(), insert).rowcount
    return num_inserts, num_updates
def bulk_update_column(
        conn: Connection, table: Table, col_name: str,
        ids_values: List[Tuple], key_col="id"
):
    stmt = (
        table.update()
        .where(getattr(table.c, key_col) == bindparam("obj_id"))
        .values(**{col_name: bindparam("val")})
    )
    conn.execute(stmt, [{"obj_id": eid, "val": val}
                        for eid, val in ids_values])
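# A minimal usage sketch for the bulk-update helpers above, assuming a
# `users` Table with `id` and `score` columns and an existing engine; the
# names are illustrative. Passing a list of parameter dicts makes SQLAlchemy
# use executemany, so every row goes through one prepared statement.
with engine.connect() as conn:
    bulk_update_column(conn, users, "score", [(1, 9.5), (2, 7.25)])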
def insert_one(application_id: str, source: str, topics: Iterable,
               conn: Connection) -> None:
    stmt = insert(topic_model, values={
        'application': application_id,
        'source': source,
        'topics': topics
    })
    conn.execute(stmt)
def insert_one(account_id: str, weights: Iterable, topic_iteration: str,
               conn: Connection) -> None:
    stmt = insert(topic, values={
        'account': account_id,
        'weights': weights,
        'topic_iteration': topic_iteration
    })
    conn.execute(stmt)
def delete(connection: Connection, survey_id: str):
    """
    Delete the survey specified by the given survey_id

    :param connection: a SQLAlchemy connection
    :param survey_id: the UUID of the survey
    """
    with connection.begin():
        connection.execute(delete_record(survey_table, 'survey_id', survey_id))
    return json_response('Survey deleted')
def nabel_update_recent(con: Connection, from_time: datetime):
    year = datetime.now(tz_local).year
    print(f'Retrieving data for the city of Zurich for {year}')
    if table_exists(NABEL_TABLE, con):
        con.execute(
            f"DELETE FROM {NABEL_TABLE} WHERE date >= '{from_time.isoformat()}'"
        )
    nabel_download(con, from_time, datetime.now(tz_local))
    print('Done!')
def persist_department(con: Connection, name: str, number: str, zone: str) -> UUID:
    zone_id = persist_zone(con, zone)
    ins = insert(DepartmentTable).values(d_name=name,
                                         d_number=number,
                                         d_zone=zone_id)
    ins = ins.on_conflict_do_nothing(index_elements=["d_name"])
    con.execute(ins)
    res = select([DepartmentTable.c.d_id
                  ]).where(DepartmentTable.c.d_name == name)
    return con.execute(res).first().d_id
def start_new_payment(payment_uuid: UUID, customer_id: int, amount: Money,
                      description: str, conn: Connection) -> None:
    conn.execute(
        payments.insert({
            "uuid": str(payment_uuid),
            "customer_id": customer_id,
            # Multiply before truncating so fractional dollars survive the
            # conversion to cents (e.g. 10.99 -> 1099, not 1000)
            "amount": int(amount.amount * 100),
            "currency": amount.currency.iso_code,
            "description": description,
            "status": PaymentStatus.NEW.value,
        }))
def persist_flowcapt_station(con: Connection, station: Feature) -> None:
    geom = _get_geom(station)
    ins = insert(FlowCaptStationTable).values(
        fcs_id=station.properties["id"],
        fcs_site=station.properties["site"],
        fcs_country=station.properties["country"],
        fcs_altitude=station.properties["altitude"],
        the_geom=geom,
    )
    ins = ins.on_conflict_do_nothing(index_elements=["fcs_id"])
    con.execute(ins)
def _re_encrypt_row(
        self,
        conn: Connection,
        row: RowProxy,
        table_name: str,
        columns: Dict[str, EncryptedType],
) -> None:
    """
    Re-encrypts all columns in a row

    :param conn: Connection used to run the UPDATE
    :param row: Current row to re-encrypt
    :param table_name: Name of the table being processed
    :param columns: Meta info from columns
    """
    re_encrypted_columns = {}
    for column_name, encrypted_type in columns.items():
        previous_encrypted_type = EncryptedType(
            type_in=encrypted_type.underlying_type,
            key=self._previous_secret_key)
        try:
            unencrypted_value = previous_encrypted_type.process_result_value(
                self._read_bytes(column_name, row[column_name]),
                self._dialect)
        except ValueError as exc:
            # Failed to decrypt with the previous key
            try:
                encrypted_type.process_result_value(
                    self._read_bytes(column_name, row[column_name]),
                    self._dialect)
                logger.info(
                    "Current secret is able to decrypt value on column"
                    " [%s.%s], nothing to do",
                    table_name,
                    column_name,
                )
                return
            except Exception:
                raise Exception from exc
        re_encrypted_columns[
            column_name] = encrypted_type.process_bind_param(
                unencrypted_value,
                self._dialect,
            )
    set_cols = ",".join([
        f"{name} = :{name}" for name in list(re_encrypted_columns.keys())
    ])
    logger.info("Processing table: %s", table_name)
    conn.execute(
        text(f"UPDATE {table_name} SET {set_cols} WHERE id = :id"),
        id=row["id"],
        **re_encrypted_columns,
    )
def zurich_download_all(con: Connection):
    print(
        f'Retrieving data for the city of Zurich from {START_YEAR} to {END_YEAR}'
    )
    con.execute(f'DROP TABLE IF EXISTS {ZURICH_TABLE}')
    for year in tqdm(range(START_YEAR, END_YEAR + 1)):
        df = zurich_hourly_data(year)
        add_columns(ZURICH_TABLE, df, con)
        df.to_sql(ZURICH_TABLE, if_exists='append', con=con)
    print('Done!')
def nabel_download_all(con: Connection):
    print(f'Retrieving data for the city of Zurich starting in {START_YEAR}')
    con.execute(f'DROP TABLE IF EXISTS {NABEL_TABLE}')
    current_year = datetime.now(tz_local).year
    for year in tqdm(range(START_YEAR, current_year), desc='Past years'):
        nabel_download(con, datetime(year, 1, 1, 0, 0, 0, 0),
                       datetime(year, 12, 31, 0, 0, 0, 0))
    nabel_download(con, datetime(current_year, 1, 1, 0, 0, 0, 0),
                   datetime.now(tz_local))
    print('Done!')
def example_auction(connection: Connection, another_user: str) -> Generator[int, None, None]:
    ends_at = datetime.now() + timedelta(days=3)
    result_proxy = connection.execute(
        auctions.insert(
            {"title": "Super aukcja", "starting_price": "0.99",
             "current_price": "1.00", "ends_at": ends_at}
        )
    )
    auction_id = result_proxy.lastrowid
    connection.execute(bids.insert(
        {"auction_id": auction_id, "amount": "1.00",
         "bidder_id": another_user}))
    yield int(auction_id)
    connection.execute(bids.delete(bids.c.auction_id == auction_id))
    connection.execute(auctions.delete(auctions.c.id == auction_id))
def _create_choices(connection: Connection,
                    values: dict,
                    question_id: str,
                    submission_map: dict,
                    existing_question_id: str=None) -> Iterator:
    """
    Create the choices of a survey question. If this is an update to an
    existing survey, it will also copy over answers to the questions.

    :param connection: the SQLAlchemy Connection object for the transaction
    :param values: the dictionary of values associated with the question
    :param question_id: the UUID of the question
    :param submission_map: a dictionary mapping old submission_id to new
    :param existing_question_id: the UUID of the existing question (if this
                                 is an update)
    :return: an iterable of the resultant choice fields
    """
    choices = values['choices']
    new_choices, updates = _determine_choices(connection,
                                              existing_question_id, choices)
    for number, choice in enumerate(new_choices):
        choice_dict = {
            'question_id': question_id,
            'survey_id': values['survey_id'],
            'choice': choice,
            'choice_number': number,
            'type_constraint_name': values['type_constraint_name'],
            'question_sequence_number': values['sequence_number'],
            'allow_multiple': values['allow_multiple']}
        executable = question_choice_insert(**choice_dict)
        exc = [('unique_choice_names', RepeatedChoiceError(choice))]
        result = execute_with_exceptions(connection, executable, exc)
        result_ipk = result.inserted_primary_key
        question_choice_id = result_ipk[0]
        if choice in updates:
            question_fields = {'question_id': question_id,
                               'type_constraint_name': result_ipk[2],
                               'sequence_number': result_ipk[3],
                               'allow_multiple': result_ipk[4],
                               'survey_id': values['survey_id']}
            for answer in get_answer_choices_for_choice_id(
                    connection, updates[choice]):
                answer_values = question_fields.copy()
                new_submission_id = submission_map[answer.submission_id]
                answer_values['question_choice_id'] = question_choice_id
                answer_values['submission_id'] = new_submission_id
                answer_metadata = answer.answer_choice_metadata
                answer_values['answer_choice_metadata'] = answer_metadata
                connection.execute(answer_choice_insert(**answer_values))
        yield question_choice_id
def lock_table(connection: Connection, target_table: Table):
    """
    Lock a table using a PostgreSQL advisory lock

    The OID of the table in the pg_class relation is used as lock id.

    :param connection: DB connection
    :param target_table: Table object
    """
    logger.debug('Locking table "%s"', target_table.name)
    oid = connection.execute(
        select([column("oid")]).select_from(table("pg_class")).where(
            (column("relname") == target_table.name))).scalar()
    connection.execute(select([func.pg_advisory_xact_lock(oid)])).scalar()
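# A minimal usage sketch, assuming a `reports` Table and an existing engine
# (names are illustrative). pg_advisory_xact_lock is transaction-scoped, so
# the lock is taken inside an explicit transaction and released on commit.
with engine.connect() as connection:
    with connection.begin():
        lock_table(connection, reports)
        # ... statements that need exclusive access to `reports` ...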
def upgrade(self, connection: Connection):
    tx = connection.begin()
    connection.execute("""create table villain_templates(
        id int not null auto_increment,
        name text not null,
        face_image_url text not null,
        primary key (id)
    );""")
    connection.execute("""
        alter table villain_templates
        add unique index `villain_template_name_idx` (`name`);
    """)
    tx.commit()
def insert_profile(conn: Connection, insert: str, p: Profile):
    u, _ = unify_profile_name(p.first_name, p.last_name)
    b64u = generate_id(u)
    conn.execute(
        insert,
        (
            sanitize_text(p.identifier),
            b64u,
            sanitize_text(p.first_name),
            sanitize_text(p.last_name),
            sanitize_text(p.display_name),
            sanitize_text(p.link),
        ),
    )
def init_db(connection: Connection, alembic_ini: str=None,
            force: bool=False, test: bool=False) -> None:
    import c2cgeoportal_commons.models.main  # noqa: F401
    import c2cgeoportal_commons.models.static  # noqa: F401
    schema = config['schema']
    schema_static = config['schema_static']
    if force:
        if schema_exists(connection, schema):
            connection.execute('DROP SCHEMA {} CASCADE;'.format(schema))
        if schema_exists(connection, schema_static):
            connection.execute('DROP SCHEMA {} CASCADE;'.format(schema_static))

    if not schema_exists(connection, schema):
        connection.execute('CREATE SCHEMA "{}";'.format(schema))

    if not schema_exists(connection, schema_static):
        connection.execute('CREATE SCHEMA "{}";'.format(schema_static))

    if alembic_ini is None:
        Base.metadata.create_all(connection)
    else:
        def upgrade(schema: str) -> None:
            cfg = alembic_config.Config(alembic_ini, ini_section=schema)
            cfg.attributes['connection'] = connection  # pylint: disable=unsupported-assignment-operation
            command.upgrade(cfg, 'head')
        upgrade('main')
        upgrade('static')

    session_factory = get_session_factory(connection)
    with transaction.manager:
        dbsession = get_tm_session(session_factory, transaction.manager)
        if test:
            setup_test_data(dbsession)
def init_db(connection: Connection, force: bool=False, test: bool=False) -> None:
    import c2cgeoportal_commons.models.main  # noqa: F401
    import c2cgeoportal_commons.models.static  # noqa: F401
    from c2cgeoportal_commons.models import schema
    schema_static = '{}_static'.format(schema)
    assert schema is not None
    if force:
        if schema_exists(connection, schema):
            connection.execute('DROP SCHEMA {} CASCADE;'.format(schema))
        if schema_exists(connection, schema_static):
            connection.execute('DROP SCHEMA {} CASCADE;'.format(schema_static))

    if not schema_exists(connection, schema):
        connection.execute('CREATE SCHEMA "{}";'.format(schema))

    if not schema_exists(connection, schema_static):
        connection.execute('CREATE SCHEMA "{}";'.format(schema_static))

    Base.metadata.create_all(connection)

    session_factory = get_session_factory(connection)
    with transaction.manager:
        dbsession = get_tm_session(session_factory, transaction.manager)
        if test:
            setup_test_data(dbsession)
def _copy_submission_entries(connection: Connection,
                             existing_survey_id: str,
                             new_survey_id: str,
                             email: str) -> tuple:
    """
    Copy submissions from an existing survey to its updated copy.

    :param connection: the SQLAlchemy connection used for the transaction
    :param existing_survey_id: the UUID of the existing survey
    :param new_survey_id: the UUID of the survey's updated copy
    :param email: the user's e-mail address
    :return: an iterator of (old submission ID, new submission ID) tuples
    """
    submissions = get_submissions_by_email(
        connection, email, survey_id=existing_survey_id
    )
    for sub in submissions:
        values = {'submitter': sub.submitter,
                  'submitter_email': sub.submitter_email,
                  'submission_time': sub.submission_time,
                  'save_time': sub.save_time,
                  'survey_id': new_survey_id}
        result = connection.execute(submission_insert(**values))
        yield sub.submission_id, result.inserted_primary_key[0]
def execute_with_exceptions(connection: Connection,
                            executable: [Insert, Update],
                            exceptions: Iterator) -> ResultProxy:
    """
    Execute the given executable (a SQLAlchemy Insert or Update) within a
    transaction (provided by the Connection object), and raise meaningful
    exceptions. Normally connection.execute() will raise a generic
    IntegrityError, so use the exceptions parameter to specify which
    exceptions to raise instead.

    :param connection: the SQLAlchemy connection (for transaction purposes)
    :param executable: the object to pass to connection.execute()
    :param exceptions: an iterable of (name: str, exception: Exception)
                       tuples. name is the string to look for in the
                       IntegrityError, and exception is the Exception to
                       raise instead of IntegrityError
    :return: a SQLAlchemy ResultProxy
    """
    try:
        return connection.execute(executable)
    except IntegrityError as exc:
        error = str(exc.orig)
        for name, exception in exceptions:
            if name in error:
                raise exception
        raise
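# A minimal usage sketch, assuming a `users` table with a unique constraint
# named "users_email_key" and a custom DuplicateEmailError; all of these
# names are illustrative. Each constraint name is mapped to the domain
# exception that should replace the raw IntegrityError.
executable = users.insert().values(email='a@example.com')
exceptions = [('users_email_key', DuplicateEmailError('a@example.com'))]
result = execute_with_exceptions(connection, executable, exceptions)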
def get_questions(connection: Connection,
                  survey_id: str,
                  auth_user_id: [str, None]=None,
                  email: [str, None]=None) -> ResultProxy:
    """
    Get all the questions for a survey identified by survey_id ordered by
    sequence number restricted by auth_user.

    :param connection: a SQLAlchemy Connection
    :param survey_id: the UUID of the survey
    :param auth_user_id: the UUID of the user
    :param email: the user's e-mail address
    :return: an iterable of the questions (RowProxy)
    """
    table = question_table.join(survey_table)
    conds = [question_table.c.survey_id == survey_id]
    if auth_user_id is not None:
        if email is not None:
            raise TypeError('You cannot specify both auth_user_id and email')
        conds.append(survey_table.c.auth_user_id == auth_user_id)
    elif email is not None:
        table = table.join(auth_user_table)
        conds.append(auth_user_table.c.email == email)
    else:
        raise TypeError('You must specify either auth_user_id or email')
    questions = connection.execute(
        select([question_table]).select_from(table).where(
            and_(*conds)).order_by('sequence_number asc'))
    return questions
def get_sessions_of_mac(connection: Connection, mac: netaddr.EUI,
                        when: Optional[DatetimeRange]=None,
                        limit: Optional[int]=None) -> Iterable[
        Tuple[netaddr.IPAddress, str, datetime, datetime]]:
    """
    Return accounting sessions of a particular MAC address ordered by
    Session-Start-Time descending.

    :param connection: A SQLAlchemy connection
    :param mac: MAC address
    :param when: Range in which Session-Start-Time must be within
    :param limit: Maximum number of records
    :return: An iterable that yields (NAS-IP-Address, NAS-Port-Id,
        Session-Start-Time, Session-Stop-Time)-tuples ordered by
        Session-Start-Time descending
    """
    logger.debug('Getting all sessions for MAC "%s"', mac)
    query = (
        select([radacct.c.NASIPAddress, radacct.c.NASPortId,
                radacct.c.AcctStartTime, radacct.c.AcctStopTime])
        .where(radacct.c.UserName == mac)
        .order_by(radacct.c.AcctStartTime.desc())
    )
    if when is not None:
        # where() returns a new select, so the result must be reassigned;
        # the PostgreSQL <@ operator tests containment in the range
        query = query.where(
            radacct.c.AcctStartTime.op('<@')(func.tstzrange(*when)))
    if limit is not None:
        query = query.limit(limit)
    return iter(connection.execute(query))
def _jsonify(connection: Connection,
             answer: object,
             question_id: str) -> object:
    """
    This function returns a "nice" representation of an answer which can be
    serialized as JSON.

    :param connection: a SQLAlchemy Connection
    :param answer: a submitted value
    :param question_id: the UUID of the question
    :return: the nice representation
    """
    type_constraint_name = question_select(connection,
                                           question_id).type_constraint_name
    if type_constraint_name in {'location', 'facility'}:
        geo_json = connection.execute(func.ST_AsGeoJSON(answer)).scalar()
        return json_decode(geo_json)['coordinates']
    elif type_constraint_name in {'date', 'time'}:
        return maybe_isoformat(answer)
    elif type_constraint_name == 'decimal':
        return float(answer)
    elif type_constraint_name == 'multiple_choice':
        question_choice = question_choice_select(connection, answer)
        return question_choice.choice
    else:
        return answer
def get_stats(connection: Connection, survey_id: str, email: str) -> dict:
    """
    Get statistics about the specified survey: creation time, number of
    submissions, time of the earliest submission, and time of the latest
    submission.

    :param connection: a SQLAlchemy Connection
    :param survey_id: the UUID of the survey
    :param email: the e-mail address of the user
    :return: a JSON representation of the statistics
    """
    result = connection.execute(
        select([
            survey_table.c.created_on,
            count(submission_table.c.submission_id),
            sqlmin(submission_table.c.submission_time),
            sqlmax(submission_table.c.submission_time)
        ]).select_from(
            auth_user_table.join(survey_table).outerjoin(submission_table)
        ).where(
            survey_table.c.survey_id == survey_id
        ).where(
            auth_user_table.c.email == email
        ).group_by(
            survey_table.c.survey_id
        )
    ).first()
    return json_response({
        'created_on': maybe_isoformat(result[0]),
        'num_submissions': result[1],
        'earliest_submission_time': maybe_isoformat(result[2]),
        'latest_submission_time': maybe_isoformat(result[3])
    })
def get_auth_attempts_at_port(connection: Connection,
                              nas_ip_address: netaddr.IPAddress,
                              nas_port_id: str,
                              when: Optional[DatetimeRange]=None,
                              limit: Optional[int]=None) -> Iterable[
        Tuple[str, str, Groups, Attributes, datetime]]:
    """
    Return auth attempts at a particular port of a NAS ordered by Auth-Date
    descending.

    :param connection: A SQLAlchemy connection
    :param nas_ip_address: NAS IP address
    :param nas_port_id: NAS Port ID
    :param when: Range in which Auth-Date must be within
    :param limit: Maximum number of records
    :return: An iterable that yields (User-Name, Packet-Type, Groups, Reply,
        Auth-Date)-tuples ordered by Auth-Date descending
    """
    # Python's %-formatting has no positional %1$s/%2$s specifiers, so the
    # arguments are simply passed in the order they appear in the message
    logger.debug('Getting all auth attempts at port %s of %s',
                 nas_port_id, nas_ip_address)
    query = (
        select([radpostauth.c.UserName, radpostauth.c.PacketType,
                radpostauth.c.Groups, radpostauth.c.Reply,
                radpostauth.c.AuthDate])
        .where(and_(radpostauth.c.NASIPAddress == nas_ip_address,
                    radpostauth.c.NASPortId == nas_port_id))
        .order_by(radpostauth.c.AuthDate.desc())
    )
    if when is not None:
        # where() returns a new select, so the result must be reassigned;
        # the PostgreSQL <@ operator tests containment in the range
        query = query.where(
            radpostauth.c.AuthDate.op('<@')(func.tstzrange(*when)))
    if limit is not None:
        query = query.limit(limit)
    return iter(connection.execute(query))
def get_auth_attempts_of_mac(connection: Connection, mac: netaddr.EUI,
                             when: Optional[DatetimeRange]=None,
                             limit: Optional[int]=None) -> Iterable[
        Tuple[netaddr.IPAddress, str, str, Groups, Attributes, datetime]]:
    """
    Return auth attempts of a particular MAC address ordered by Auth-Date
    descending.

    :param connection: A SQLAlchemy connection
    :param mac: MAC address
    :param when: Range in which Auth-Date must be within
    :param limit: Maximum number of records
    :return: An iterable that yields (NAS-IP-Address, NAS-Port-Id,
        Packet-Type, Groups, Reply, Auth-Date)-tuples ordered by Auth-Date
        descending
    """
    logger.debug('Getting all auth attempts of MAC %s', mac)
    query = (
        select([radpostauth.c.NASIPAddress, radpostauth.c.NASPortId,
                radpostauth.c.PacketType, radpostauth.c.Groups,
                radpostauth.c.Reply, radpostauth.c.AuthDate])
        .where(radpostauth.c.UserName == mac)
        .order_by(radpostauth.c.AuthDate.desc())
    )
    if when is not None:
        # where() returns a new select, so the result must be reassigned;
        # the PostgreSQL <@ operator tests containment in the range
        query = query.where(
            radpostauth.c.AuthDate.op('<@')(func.tstzrange(*when)))
    if limit is not None:
        query = query.limit(limit)
    return iter(connection.execute(query))
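# A minimal usage sketch for the RADIUS query helpers above, assuming an
# existing engine and that DatetimeRange is a (start, end) pair of aware
# datetimes; the MAC address and limit are illustrative only.
from datetime import datetime, timedelta, timezone

with engine.connect() as connection:
    now = datetime.now(timezone.utc)
    attempts = get_auth_attempts_of_mac(
        connection, netaddr.EUI('00-11-22-33-44-55'),
        when=(now - timedelta(days=7), now), limit=10)
    for nas_ip, port, packet_type, groups, reply, auth_date in attempts:
        print(nas_ip, port, packet_type, auth_date)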
def survey_select(connection: Connection,
                  survey_id: str,
                  auth_user_id: str=None,
                  email: str=None) -> RowProxy:
    """
    Get a record from the survey table. You must supply either the
    auth_user_id or the email.

    :param connection: a SQLAlchemy Connection
    :param survey_id: the UUID of the survey
    :param auth_user_id: the UUID of the user
    :param email: the user's e-mail address
    :return: the corresponding record
    :raise SurveyDoesNotExistError: if the UUID is not in the table
    """
    table = survey_table
    conds = [survey_table.c.survey_id == survey_id]
    if auth_user_id is not None:
        if email is not None:
            raise TypeError('You cannot specify both auth_user_id and email')
        conds.append(survey_table.c.auth_user_id == auth_user_id)
    elif email is not None:
        table = table.join(auth_user_table)
        conds.append(auth_user_table.c.email == email)
    else:
        raise TypeError('You must specify either auth_user_id or email')
    survey = connection.execute(select([survey_table]).select_from(
        table).where(and_(*conds))).first()
    if survey is None:
        raise SurveyDoesNotExistError(survey_id)
    return survey
def _return_sql(connection: Connection,
                result: object,
                survey_id: str,
                auth_user_id: str,
                question_id: str) -> object:
    """
    Get the result for a _scalar-y function.

    :param connection: a SQLAlchemy Connection
    :param result: the result of the SQL function
    :param survey_id: the UUID of the survey
    :param auth_user_id: the UUID of the user
    :param question_id: the UUID of the question
    :return: the result of the SQL function
    :raise NoSubmissionsToQuestionError: if there are no submissions
    :raise QuestionDoesNotExistError: if the user is not authorized
    """
    if result is None or result == []:
        condition = survey_table.c.survey_id == survey_id
        stmt = select([survey_table]).where(condition)
        proper_id = connection.execute(stmt).first().auth_user_id
        if auth_user_id == proper_id:
            raise NoSubmissionsToQuestionError(question_id)
        raise QuestionDoesNotExistError(question_id)
    return result
def update(connection: Connection, data: dict):
    """
    Update a survey (title, questions). You can also add or modify questions
    here. Note that this creates a new survey (with new submissions, etc),
    copying everything from the old survey. The old survey's title will be
    changed to end with "(new version created on <time>)".

    :param connection: a SQLAlchemy Connection
    :param data: JSON containing the UUID of the survey and fields to update
    """
    survey_id = data['survey_id']
    email = data['email']
    existing_survey = survey_select(connection, survey_id, email=email)
    if 'survey_metadata' not in data:
        data['survey_metadata'] = existing_survey.survey_metadata
    update_time = datetime.datetime.now()
    with connection.begin():
        new_title = '{} (new version created on {})'.format(
            existing_survey.survey_title, update_time.isoformat())
        executable = update_record(survey_table, 'survey_id', survey_id,
                                   survey_title=new_title)
        exc = [('survey_title_survey_owner_key',
                SurveyAlreadyExistsError(new_title))]
        execute_with_exceptions(connection, executable, exc)
        new_survey_id = _create_survey(connection, data)
    return get_one(connection, new_survey_id, email=email)
def create_user(connection: Connection, data: dict) -> dict:
    """
    Registers a new user account.

    :param connection: a SQLAlchemy Connection
    :param data: the user's e-mail
    :return: a response containing the e-mail and whether it was created or
             already exists in the database
    """
    email = data['email']
    try:
        get_auth_user_by_email(connection, email)
    except UserDoesNotExistError:
        with connection.begin():
            connection.execute(create_auth_user(email=email))
        return json_response({'email': email, 'response': 'Created'})
    return json_response({'email': email, 'response': 'Already exists'})
def schema_exists(connection: Connection, schema_name: str) -> bool:
    sql = """
    SELECT count(*) AS count
    FROM information_schema.schemata
    WHERE schema_name = '{}';
    """.format(schema_name)
    result = connection.execute(sql)
    row = result.first()
    return row[0] == 1
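# A parameterized variant sketch of schema_exists that avoids interpolating
# schema_name into the SQL string; `schema_exists_param` is a hypothetical
# name introduced here for illustration, using SQLAlchemy's text() construct.
from sqlalchemy import text

def schema_exists_param(connection: Connection, schema_name: str) -> bool:
    result = connection.execute(
        text("SELECT count(*) FROM information_schema.schemata "
             "WHERE schema_name = :name"),
        name=schema_name)
    return result.scalar() == 1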
def bar_graph(connection: Connection,
              question_id: str,
              auth_user_id: str=None,
              email: str=None,
              limit: [int, None]=None,
              count_order: bool=False) -> dict:
    """
    Get a list of the number of times each submission value appears. You
    must provide either an auth_user_id or e-mail address.

    :param connection: a SQLAlchemy Connection
    :param question_id: the UUID of the question
    :param auth_user_id: the UUID of the user
    :param email: the e-mail address of the user
    :param limit: a limit on the number of results
    :param count_order: whether to order from largest count to smallest
    :return: a JSON dict containing the result [[values], [counts]]
    """
    user_id = _get_user_id(connection, auth_user_id, email)
    allowable_types = {'text', 'integer', 'decimal', 'multiple_choice',
                       'date', 'time', 'location', 'facility'}
    question = question_select(connection, question_id)
    tcn = _get_type_constraint_name(allowable_types, question)
    # Assume that you only want to consider the non-other answers
    original_table, column_name = _table_and_column(tcn)
    table = original_table.join(
        question_table,
        original_table.c.question_id == question_table.c.question_id
    ).join(survey_table)
    conds = [question_table.c.question_id == question_id,
             survey_table.c.auth_user_id == user_id]
    column = get_column(original_table, column_name)
    column_query = select(
        [column, sqlcount(column)]
    ).select_from(table).group_by(column)
    ordering = desc(sqlcount(column)) if count_order else column
    ordered_query = column_query.order_by(ordering)
    result = connection.execute(
        ordered_query.where(and_(*conds)).limit(limit)
    )
    result = _return_sql(connection, result, question.survey_id, user_id,
                         question_id)
    bar_graph_result = [[_jsonify(connection, r[0], question_id), r[1]]
                        for r in result]
    response = json_response(
        _return_sql(connection, bar_graph_result, question.survey_id,
                    user_id, question_id))
    response['query'] = 'bar_graph'
    return response
def refresh_and_diff_materialized_view(
        connection: Connection, view: Table, copy: Table,
        result_columns: Iterable[Column]) -> Tuple[
            List[Tuple], List[Tuple], List[Tuple]]:
    with connection.begin():
        lock_table(connection, view)
        create_temp_copy(connection, view, copy)
        refresh_materialized_view(connection, view)
        return diff_tables(connection, view, copy, result_columns)
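# A minimal usage sketch, assuming `my_view` is a materialized-view Table,
# `my_view_copy` is a Table with the same columns, and an engine exists; all
# names are illustrative. The lock, temp copy, and refresh run in a single
# transaction, and the three lists describe what the refresh changed.
with engine.connect() as connection:
    added, deleted, modified = refresh_and_diff_materialized_view(
        connection, my_view, my_view_copy, my_view.c)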
def get_number_of_submissions(connection: Connection, survey_id: str) -> int:
    """
    Return the number of submissions for a given survey

    :param connection: a SQLAlchemy Connection
    :param survey_id: the UUID of the survey
    :return: the corresponding number of submissions
    """
    return connection.execute(select([count()]).where(
        submission_table.c.survey_id == survey_id)).scalar()
def time_series(connection: Connection,
                question_id: str,
                auth_user_id: str=None,
                email: str=None) -> dict:
    """
    Get a list of submissions to the specified question over time. You must
    provide either an auth_user_id or e-mail address.

    :param connection: a SQLAlchemy Connection
    :param question_id: the UUID of the question
    :param auth_user_id: the UUID of the user
    :param email: the e-mail address of the user
    :return: a JSON dict containing the result [[times], [values]]
    """
    user_id = _get_user_id(connection, auth_user_id, email)
    allowable_types = {'text', 'integer', 'decimal', 'multiple_choice',
                       'date', 'time', 'location'}
    question = question_select(connection, question_id)
    tcn = _get_type_constraint_name(allowable_types, question)
    # Assume that you only want to consider the non-other answers
    original_table, column_name = _table_and_column(tcn)
    table = original_table.join(
        survey_table,
        original_table.c.survey_id == survey_table.c.survey_id
    ).join(
        submission_table,
        original_table.c.submission_id == submission_table.c.submission_id
    )
    column = get_column(original_table, column_name)
    where_stmt = select(
        [column, submission_table.c.submission_time]
    ).select_from(table).where(
        original_table.c.question_id == question_id
    ).where(
        survey_table.c.auth_user_id == user_id
    )
    # Pass the resolved user_id (not the raw auth_user_id, which may be
    # None when an email was supplied), matching bar_graph above
    result = _return_sql(
        connection,
        connection.execute(where_stmt.order_by('submission_time asc')),
        question.survey_id, user_id, question_id)
    time_series_result = [
        [r.submission_time.isoformat(),
         _jsonify(connection, r[column_name], question_id)]
        for r in result]
    response = json_response(
        _return_sql(connection, time_series_result, question.survey_id,
                    user_id, question_id))
    response['query'] = 'time_series'
    return response
def get_all_alternative_dns_ips(connection: Connection) -> Iterable[
        netaddr.IPAddress]:
    """
    Return all IPs for alternative DNS configuration.

    :param connection: A SQLAlchemy connection
    :return: An iterable that yields ip addresses
    """
    logger.debug("Getting all alternative DNS clients")
    result = connection.execute(select([alternative_dns.c.IPAddress]))
    return map(operator.itemgetter(0), result)
def get_answers_for_question(connection: Connection,
                             question_id: str) -> ResultProxy:
    """
    Get all the records from the answer table identified by question_id.

    :param connection: a SQLAlchemy Connection
    :param question_id: foreign key
    :return: an iterable of the answers (RowProxy)
    """
    select_stmt = select([answer_table])
    where_stmt = select_stmt.where(answer_table.c.question_id == question_id)
    return connection.execute(where_stmt)
def get_branches(connection: Connection, question_id: str) -> ResultProxy:
    """
    Get all the branches for a question identified by question_id.

    :param connection: a SQLAlchemy Connection
    :param question_id: foreign key
    :return: an iterable of the branches (RowProxy)
    """
    select_stmt = select([question_branch_table])
    where_stmt = select_stmt.where(
        question_branch_table.c.from_question_id == question_id)
    return connection.execute(where_stmt)
def get_all_dhcp_hosts(connection: Connection) -> Iterable[
        Tuple[netaddr.EUI, netaddr.IPAddress]]:
    """
    Return all DHCP host configurations.

    :param connection: A SQLAlchemy connection
    :return: An iterable that yields (mac, ip)-tuples
    """
    logger.debug("Getting all DHCP hosts")
    result = connection.execute(select([dhcphost.c.MAC, dhcphost.c.IPAddress]))
    return iter(result)
def create(connection: Connection, data: dict) -> dict:
    """
    Create a survey with questions.

    :param connection: a SQLAlchemy Connection
    :param data: a JSON representation of the survey to be created
    :return: a JSON representation of the created survey
    """
    with connection.begin():
        survey_id = _create_survey(connection, data)
    return get_one(connection, survey_id, email=data['email'])
def get_answers(connection: Connection, submission_id: str) -> ResultProxy:
    """
    Get all the records from the answer table identified by submission_id
    ordered by sequence number.

    :param connection: a SQLAlchemy Connection
    :param submission_id: foreign key
    :return: an iterable of the answers (RowProxy)
    """
    select_stmt = select([answer_table])
    where_stmt = select_stmt.where(
        answer_table.c.submission_id == submission_id)
    return connection.execute(where_stmt.order_by("sequence_number asc"))
def create_temp_copy(connection: Connection, source: Table, destination: Table):
    """
    Create a temporary table as a copy of a source table that will be
    dropped at the end of the running transaction.

    :param connection: DB connection
    :param source: Source table
    :param destination: Destination table
    """
    logger.debug('Creating temporary table "%s" as copy of "%s"',
                 destination.name, source.name)
    if not connection.in_transaction():
        raise RuntimeError("must be executed in a transaction to have any "
                           "effect")
    preparer = connection.dialect.identifier_preparer
    connection.execute(
        'CREATE TEMPORARY TABLE {destination} ON COMMIT DROP AS '
        'SELECT * FROM {source}'.format(
            source=preparer.format_table(source),
            destination=preparer.format_table(destination),
        )
    )