def get_file_metadata_row_for_path(
    session: Session, region_code: str, path: GcsfsFilePath
) -> Union[schema.DirectIngestRawFileMetadata, schema.DirectIngestIngestFileMetadata]:
    """Returns metadata information for the provided path.

    Raises if the file has not yet been registered in the appropriate metadata
    table, or if more than one row matches the path.
    """
    parts = filename_parts_from_path(path)

    if parts.file_type == GcsfsDirectIngestFileType.INGEST_VIEW:
        results = session.query(
            schema.DirectIngestIngestFileMetadata).filter_by(
                region_code=region_code,
                is_invalidated=False,
                normalized_file_name=path.file_name).all()
    elif parts.file_type == GcsfsDirectIngestFileType.RAW_DATA:
        results = session.query(schema.DirectIngestRawFileMetadata).filter_by(
            region_code=region_code,
            normalized_file_name=path.file_name).all()
    else:
        raise ValueError(f'Unexpected file type: {parts.file_type}')

    if len(results) != 1:
        raise ValueError(
            f'Unexpected number of metadata results for path {path.abs_path()}: [{len(results)}]'
        )

    return one(results)
def check_people_do_not_have_multiple_ids_same_type(
        session: Session,
        region_code: str,
        output_people: List[schema.StatePerson]) -> bool:
    """Validates that no person has two ids of the same type (for states configured to enforce this invariant)."""
    check_not_dirty(session)

    logging.info(
        "[Invariant validation] Checking that no person has multiple external ids of the same type."
    )

    if state_allows_multiple_ids_same_type(region_code):
        logging.info(
            "[Invariant validation] Multiple external ids of the same type allowed for [%s] - skipping.",
            region_code)
        return True

    person_ids = {p.person_id for p in output_people}
    if not person_ids:
        logging.warning(
            "[Invariant validation] No StatePersonExternalIds in the output set - skipping validations."
        )
        return True

    counts_subquery = session.query(
        schema.StatePersonExternalId.state_code.label('state_code'),
        schema.StatePersonExternalId.person_id.label('person_id'),
        schema.StatePersonExternalId.id_type.label('id_type'),
        func.count().label('cnt'),
    ).filter(
        schema.StatePersonExternalId.state_code == region_code.upper()
    ).filter(
        # Ideally we would not filter by person_ids, but that query takes ~10s on a test of US_PA external ids.
        # Since this will be run for every file, that sort of performance is prohibitive. We instead filter by
        # just the person ids we think we have touched this session.
        schema.StatePersonExternalId.person_id.in_(person_ids)
    ).group_by(
        schema.StatePersonExternalId.state_code,
        schema.StatePersonExternalId.person_id,
        schema.StatePersonExternalId.id_type,
    ).subquery()

    query = session.query(
        counts_subquery.c.state_code,
        counts_subquery.c.person_id,
        counts_subquery.c.id_type,
        counts_subquery.c.cnt,
    ).filter(counts_subquery.c.cnt > 1).limit(1)

    results = query.all()

    if results:
        _state_code, person_id, id_type, count = one(results)
        logging.error(
            '[Invariant validation] Found people with multiple ids of the same type. First example: '
            'person_id=[%s], id_type=[%s] is used [%s] times.',
            person_id, id_type, count)
        return False

    logging.info(
        "[Invariant validation] Found no people with multiple external ids of the same type."
    )
    return True
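# For reference, the invariant check above emits SQL roughly shaped like the following.
# This is approximate and illustrative only: the real statement is generated and
# parameter-bound by SQLAlchemy, and the table name is assumed to be the conventional
# snake_case form of the StatePersonExternalId model.
#
#   SELECT state_code, person_id, id_type, cnt
#   FROM (
#     SELECT state_code, person_id, id_type, COUNT(*) AS cnt
#     FROM state_person_external_id
#     WHERE state_code = :region_code
#       AND person_id IN (:touched_person_ids)
#     GROUP BY state_code, person_id, id_type
#   ) counts
#   WHERE cnt > 1
#   LIMIT 1;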
def _write_snapshots_for_existing_entities(
    self,
    session: Session,
    context: "_SnapshotContext",
    snapshot_time: datetime,
    schema: ModuleType,
) -> None:
    """Writes snapshot updates for entities that already have snapshots
    present in the database.
    """
    historical_class = _get_historical_class(type(context.schema_object),
                                             schema)
    new_historical_snapshot = historical_class()
    self._copy_entity_fields_to_historical_snapshot(
        context.schema_object, new_historical_snapshot)
    new_historical_snapshot.valid_from = snapshot_time

    # Snapshot must be merged separately from record tree, as they are not
    # included in the ORM model relationships (to avoid needing to load
    # the entire snapshot chain at once)
    session.merge(new_historical_snapshot)

    # Close last snapshot if one is present
    if context.most_recent_snapshot is not None:
        if not isinstance(context.most_recent_snapshot,
                          HistoryTableSharedColumns):
            raise ValueError(
                f"Snapshot class [{type(context.most_recent_snapshot)}] "
                f"must be a subclass of "
                f"[{HistoryTableSharedColumns.__name__}]")

        context.most_recent_snapshot.valid_to = snapshot_time
        session.merge(context.most_recent_snapshot)
def _write_snapshots_for_new_entities(
    self,
    session: Session,
    context: "_SnapshotContext",
    snapshot_time: datetime,
    schema: ModuleType,
) -> None:
    """Writes snapshots for any new entities, including any required manual
    adjustments based on provided start and end times.
    """
    historical_class = _get_historical_class(type(context.schema_object),
                                             schema)
    new_historical_snapshot = historical_class()
    self._copy_entity_fields_to_historical_snapshot(
        context.schema_object, new_historical_snapshot)

    provided_start_time = None
    provided_end_time = None

    # Validate provided start and end times
    if (context.provided_start_time
            and context.provided_start_time.date() < snapshot_time.date()):
        provided_start_time = context.provided_start_time

    if (context.provided_end_time
            and context.provided_end_time.date() < snapshot_time.date()):
        provided_end_time = context.provided_end_time

    if (provided_start_time and provided_end_time
            and provided_start_time >= provided_end_time):
        provided_start_time = None

    if provided_start_time is not None and provided_end_time is None:
        new_historical_snapshot.valid_from = provided_start_time
    elif provided_end_time is not None:
        new_historical_snapshot.valid_from = provided_end_time
    else:
        new_historical_snapshot.valid_from = snapshot_time

    # Snapshot must be merged separately from record tree, as they are not
    # included in the ORM model relationships (to avoid needing to load
    # the entire snapshot chain at once)
    session.merge(new_historical_snapshot)

    # If both start and end time were provided, an earlier snapshot needs to
    # be created, reflecting the state of the entity before its current
    # completed state
    if provided_start_time and provided_end_time:
        initial_snapshot = historical_class()
        self._copy_entity_fields_to_historical_snapshot(
            context.schema_object, initial_snapshot)
        initial_snapshot.valid_from = provided_start_time
        initial_snapshot.valid_to = provided_end_time

        self.post_process_initial_snapshot(context, initial_snapshot)

        session.merge(initial_snapshot)
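# Worked example of the branching above (dates are illustrative only): suppose
# snapshot_time falls on 2020-07-01 and a new entity arrives with
# provided_start_time=2020-01-01 and provided_end_time=2020-06-01, both strictly
# before the snapshot date. The open snapshot is written with
# valid_from=2020-06-01, and a second, closed snapshot is written covering
# valid_from=2020-01-01 through valid_to=2020-06-01 to capture the entity's
# earlier, pre-completion state. If only a start time had been provided, the
# single open snapshot would instead begin at 2020-01-01.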
def _execute_statement(self, statement: str) -> None:
    session = Session(bind=self.postgres_engine)
    try:
        session.execute(statement)
        session.commit()
    except Exception as e:
        logging.warning("Failed to cleanup: %s", e)
        session.rollback()
    finally:
        session.close()
def for_prod_data_client(
    cls,
    database_key: SQLAlchemyDatabaseKey,
    ssl_cert_path: str,
    *,
    autocommit: bool = True,
) -> Iterator[Session]:
    """Implements a context manager for db sessions for use in prod-data-client."""
    engine = SQLAlchemyEngineManager.get_engine_for_database_with_ssl_certs(
        database_key=database_key, ssl_cert_path=ssl_cert_path)
    if engine is None:
        raise ValueError(f"No engine set for key [{database_key}]")

    try:
        session = Session(bind=engine)
        cls._alter_session_variables(session)
        cls._apply_session_listener_for_schema_base(
            database_key.declarative_meta, session)
        yield session
        if autocommit:
            try:
                session.commit()
            except Exception as e:
                session.rollback()
                raise e
    finally:
        session.close()
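# A minimal usage sketch (assumption: this generator is exposed as a
# @contextmanager classmethod on a SessionFactory-style class; the database key
# construction and cert path below are illustrative, not taken from this code):
#
#   database_key = SQLAlchemyDatabaseKey.for_schema(SchemaType.OPERATIONS)
#   with SessionFactory.for_prod_data_client(
#           database_key, ssl_cert_path="/path/to/certs") as session:
#       rows = session.query(schema.DirectIngestRawFileMetadata).limit(10).all()
#
# With autocommit=True (the default), commit or rollback happens on exit from the
# with-block; the session is closed either way.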
def _alter_session_variables(cls, session: Session) -> None:
    # Postgres uses a query cost analysis heuristic to decide what type of read to use for a particular query.
    # It sometimes chooses to use a sequential read because for hard disk drives (HDDs, as opposed to solid
    # state drives, SSDs) that may be faster than jumping around to random pages of an index. This is
    # especially likely when running over small sets of data. Setting this option changes the heuristic to
    # almost always prefer index reads.
    #
    # Our postgres instances run on SSDs, so this should increase performance for us. This is also important
    # because sequential reads lock an entire table, whereas index reads only lock the particular predicate
    # from a query. See https://www.postgresql.org/docs/12/transaction-iso.html and
    # https://stackoverflow.com/questions/42288808/why-does-postgresql-serializable-transaction-think-this-as-conflict.
    #
    # TODO(#3928): Once defined in code, set this on the SQL instance itself instead of per session.
    if session.bind.dialect.name == "postgresql":
        session.execute("SET random_page_cost=1;")
def get_file_metadata_row(
    session: Session, file_type: GcsfsDirectIngestFileType, file_id: int
) -> Union[schema.DirectIngestRawFileMetadata, schema.DirectIngestIngestFileMetadata]:
    """Queries for the file metadata row by the metadata row primary key."""
    if file_type == GcsfsDirectIngestFileType.INGEST_VIEW:
        results = session.query(
            schema.DirectIngestIngestFileMetadata).filter_by(
                file_id=file_id).all()
    elif file_type == GcsfsDirectIngestFileType.RAW_DATA:
        results = session.query(
            schema.DirectIngestRawFileMetadata).filter_by(
                file_id=file_id).all()
    else:
        raise ValueError(f'Unexpected file type: {file_type}')

    return one(results)
def _pending_to_persistent(session: Session, instance: Any) -> None:
    """Called when a SQLAlchemy object transitions to a persistent object. If this
    function throws, the session will be rolled back and that object will not be
    committed."""
    if not isinstance(instance, DirectIngestIngestFileMetadata):
        return

    results = (session.query(DirectIngestIngestFileMetadata).filter_by(
        is_invalidated=False,
        is_file_split=False,
        region_code=instance.region_code,
        file_tag=instance.file_tag,
        ingest_database_name=instance.ingest_database_name,
        datetimes_contained_lower_bound_exclusive=instance.datetimes_contained_lower_bound_exclusive,
        datetimes_contained_upper_bound_inclusive=instance.datetimes_contained_upper_bound_inclusive,
    ).all())

    if len(results) > 1:
        raise IntegrityError(
            f"Attempting to commit repeated non-file split DirectIngestIngestFileMetadata row for "
            f"region_code={instance.region_code}, file_tag={instance.file_tag}, "
            f"ingest_database_name={instance.ingest_database_name}, "
            f"datetimes_contained_lower_bound_exclusive={instance.datetimes_contained_lower_bound_exclusive}, "
            f"datetimes_contained_upper_bound_inclusive={instance.datetimes_contained_upper_bound_inclusive}"
        )
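# A sketch of how a hook like this is typically registered (assumption: the
# listener is attached per session, e.g. inside
# _apply_session_listener_for_schema_base). "pending_to_persistent" is the
# standard SQLAlchemy SessionEvents hook, invoked with (session, instance) for
# every object that moves from pending to persistent during a flush; raising
# inside the listener prevents that object from being committed.
from sqlalchemy import event


def _apply_pending_to_persistent_listener(session: Session) -> None:
    # Register the duplicate-row check above so it can veto commits of repeated
    # DirectIngestIngestFileMetadata rows.
    event.listen(session, "pending_to_persistent", _pending_to_persistent)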
def get_ingest_file_metadata_row_for_path(
        session: Session, region_code: str, path: GcsfsFilePath,
        ingest_database_name: str) -> schema.DirectIngestIngestFileMetadata:
    """Returns metadata information for the provided path.

    Raises if the file has not yet been registered in the metadata table, or if
    more than one row matches the path.
    """
    parts = filename_parts_from_path(path)

    if parts.file_type != GcsfsDirectIngestFileType.INGEST_VIEW:
        raise ValueError(f"Unexpected file type [{parts.file_type}]")

    results = (session.query(schema.DirectIngestIngestFileMetadata).filter_by(
        region_code=region_code.upper(),
        is_invalidated=False,
        normalized_file_name=path.file_name,
        ingest_database_name=ingest_database_name,
    ).all())

    if len(results) != 1:
        raise ValueError(
            f"Unexpected number of metadata results for path {path.abs_path()}: [{len(results)}]"
        )

    return one(results)
def get_ingest_view_metadata_for_export_job(
    session: Session,
    region_code: str,
    file_tag: str,
    datetimes_contained_lower_bound_exclusive: Optional[datetime.datetime],
    datetimes_contained_upper_bound_inclusive: datetime.datetime,
    ingest_database_name: str,
) -> Optional[schema.DirectIngestIngestFileMetadata]:
    """Returns the ingest file metadata row corresponding to the export job with
    the provided args, or None if no such row exists.
    """
    results = (session.query(schema.DirectIngestIngestFileMetadata).filter_by(
        region_code=region_code.upper(),
        file_tag=file_tag,
        is_invalidated=False,
        is_file_split=False,
        datetimes_contained_lower_bound_exclusive=datetimes_contained_lower_bound_exclusive,
        datetimes_contained_upper_bound_inclusive=datetimes_contained_upper_bound_inclusive,
        ingest_database_name=ingest_database_name,
    ).all())

    if not results:
        return None

    return one(results)
def retry_transaction(
    session: Session,
    measurements: MeasurementMap,
    txn_body: Callable[[Session], bool],
    max_retries: Optional[int],
) -> bool:
    """Retries the transaction if a serialization failure occurs.

    Handles management of committing, rolling back, and closing the `session`.
    `txn_body` can return False to force the transaction to be aborted,
    otherwise return True.

    Returns:
        True, if the transaction succeeded.
        False, if the transaction was aborted by `txn_body`.
    """
    num_retries = 0
    try:
        while True:
            try:
                should_continue = txn_body(session)

                if not should_continue:
                    session.rollback()
                    return should_continue

                session.commit()
                return True
            except sqlalchemy.exc.DBAPIError as e:
                session.rollback()
                if max_retries and num_retries >= max_retries:
                    raise
                if (isinstance(e.orig, psycopg2.OperationalError)
                        and e.orig.pgcode == SERIALIZATION_FAILURE):
                    logging.info(
                        "Retrying transaction due to serialization failure: %s",
                        e)
                    num_retries += 1
                    continue
                raise
            except Exception:
                session.rollback()
                raise
    finally:
        measurements.measure_int_put(m_retries, num_retries)
        session.close()
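# A minimal usage sketch. The transaction body, the file_id, and the database name
# below are hypothetical; `session` and `measurements` are assumed to be supplied by
# the caller's session factory and tracing setup. The body reads and mutates rows
# through the session it is handed and returns False to abort without committing.
def _invalidate_ingest_file(txn_session: Session) -> bool:
    metadata = get_ingest_file_metadata_row(
        txn_session, file_id=1234, ingest_database_name="us_xx_primary")  # illustrative values
    if metadata.is_invalidated:
        return False  # already invalidated - abort, nothing to commit
    metadata.is_invalidated = True
    return True


committed = retry_transaction(
    session, measurements, _invalidate_ingest_file, max_retries=3)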
def get_date_sorted_unprocessed_raw_files_for_region(
    session: Session,
    region_code: str,
) -> List[schema.DirectIngestRawFileMetadata]:
    """Returns metadata for all raw files that do not have a processed_time, from earliest to latest."""
    return (session.query(schema.DirectIngestRawFileMetadata).filter_by(
        region_code=region_code.upper(),
        processed_time=None,
    ).order_by(
        schema.DirectIngestRawFileMetadata.datetimes_contained_upper_bound_inclusive.asc()
    ).all())
def get_ingest_view_metadata_pending_export(
        session: Session,
        region_code: str) -> List[schema.DirectIngestIngestFileMetadata]:
    """Returns metadata for all ingest files that have not yet been exported."""
    return session.query(schema.DirectIngestIngestFileMetadata).filter_by(
        region_code=region_code,
        is_invalidated=False,
        is_file_split=False,
        export_time=None,
    ).all()
def for_schema_base(cls, schema_base: DeclarativeMeta) -> Session:
    engine = SQLAlchemyEngineManager.get_engine_for_schema_base(schema_base)
    if engine is None:
        raise ValueError(
            f"No engine set for base [{schema_base.__name__}]")

    session = Session(bind=engine)
    cls._apply_session_listener_for_schema_base(schema_base, session)
    return session
def get_ingest_file_metadata_row(
    session: Session,
    file_id: int,
    ingest_database_name: str,
) -> schema.DirectIngestIngestFileMetadata:
    """Queries for the ingest file metadata row by the metadata row primary key."""
    results = (session.query(schema.DirectIngestIngestFileMetadata).filter_by(
        file_id=file_id, ingest_database_name=ingest_database_name).all())
    return one(results)
def get_metadata_for_raw_files_discovered_after_datetime(
    session: Session,
    region_code: str,
    raw_file_tag: str,
    discovery_time_lower_bound_exclusive: Optional[datetime.datetime],
) -> List[schema.DirectIngestRawFileMetadata]:
    """Returns metadata for all raw files with a given tag that were discovered after the provided datetime."""
    query = session.query(schema.DirectIngestRawFileMetadata).filter_by(
        region_code=region_code, file_tag=raw_file_tag)

    if discovery_time_lower_bound_exclusive:
        query = query.filter(
            schema.DirectIngestRawFileMetadata.discovery_time >
            discovery_time_lower_bound_exclusive)

    return query.all()
def to_query(self, session: Session) -> Query:
    """Create a query to SELECT each column based on the Mapping's name."""
    select_statements = []
    for new_view_column_name, source_column in attr.asdict(self).items():
        # Default unmapped columns to NULL to ensure we SELECT all columns
        select_statement = null().label(new_view_column_name)

        if source_column is not None:
            select_statement = source_column.label(new_view_column_name)

        select_statements.append(select_statement)

    return session.query(*select_statements)
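# Usage sketch (hypothetical attrs mapping; the field and column names are
# illustrative, not taken from this codebase). If the Mapping class that defines
# to_query above carries fields
#     person_id = schema.StatePerson.person_id
#     birthdate = None
# then to_query(session) builds roughly:
#     SELECT state_person.person_id AS person_id, NULL AS birthdate
# i.e. every attrs field becomes a labeled output column, with NULL standing in
# for any column that has no source mapping, so the view's column set is stable.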
def _for_database(cls, database_key: SQLAlchemyDatabaseKey) -> Session:
    # TODO(#8046): When the above method is deleted, move this into `using_database`
    # directly.
    engine = SQLAlchemyEngineManager.get_engine_for_database(
        database_key=database_key)
    if engine is None:
        raise ValueError(f"No engine set for key [{database_key}]")

    session = Session(bind=engine)
    cls._alter_session_variables(session)
    cls._apply_session_listener_for_schema_base(
        database_key.declarative_meta, session)
    return session
def get_date_sorted_unprocessed_ingest_view_files_for_region(
    session: Session,
    region_code: str,
    ingest_database_name: str,
) -> List[schema.DirectIngestIngestFileMetadata]:
    """Returns metadata for all ingest view files that do not have a processed_time, from earliest to latest."""
    return (session.query(schema.DirectIngestIngestFileMetadata).filter_by(
        region_code=region_code.upper(),
        processed_time=None,
        ingest_database_name=ingest_database_name,
        is_invalidated=False,
    ).order_by(
        schema.DirectIngestIngestFileMetadata.datetimes_contained_upper_bound_inclusive.asc()
    ).all())
def get_ingest_view_metadata_for_most_recent_valid_job(
        session: Session, region_code: str,
        file_tag: str) -> Optional[schema.DirectIngestIngestFileMetadata]:
    """Returns the most recently created export metadata row where is_invalidated
    is False, or None if there are no metadata rows for this file tag for this
    manager's region.
    """
    results = session.query(schema.DirectIngestIngestFileMetadata).filter_by(
        region_code=region_code,
        is_invalidated=False,
        is_file_split=False,
        file_tag=file_tag,
    ).order_by(
        schema.DirectIngestIngestFileMetadata.job_creation_time.desc()
    ).limit(1).all()

    if not results:
        return None

    return one(results)
def assert_no_unexpected_entities_in_db(
        expected_entities: Sequence[DatabaseEntity], session: Session):
    """Counts all of the entities present in the |expected_entities| graph by
    type and ensures that the same number of entities exists in the DB for each
    type.
    """
    entity_counter: Dict[Type, List[DatabaseEntity]] = defaultdict(list)
    get_entities_by_type(expected_entities, entity_counter)
    for cls, entities_of_cls in entity_counter.items():
        # Standalone classes do not need to be attached to a person by design,
        # so it is valid if some standalone entities are not reachable from the
        # provided |expected_entities|
        if is_standalone_class(cls):
            continue

        expected_ids = set()
        for entity in entities_of_cls:
            expected_ids.add(entity.get_id())
        db_entities = session.query(cls).all()
        db_ids = set()
        for entity in db_entities:
            db_ids.add(entity.get_id())

        if expected_ids != db_ids:
            print("\n********** Entities from |found_persons| **********\n")
            for entity in sorted(entities_of_cls, key=lambda x: x.get_id()):
                print_entity_tree(entity)
            print("\n********** Entities from db **********\n")
            for entity in sorted(db_entities, key=lambda x: x.get_id()):
                print_entity_tree(entity)
            raise ValueError(
                f"For cls {cls.__name__}, found difference in primary keys from "
                f"expected entities and those of entities read from db.\n"
                f"Expected ids not present in db: "
                f"{str(expected_ids - db_ids)}\n"
                f"Db ids not present in expected entities: "
                f"{str(db_ids - expected_ids)}\n"
            )
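# A minimal usage sketch for tests. The test method, the pipeline-run helper, and
# the use of SessionFactory.using_database as a context manager are assumptions
# about the surrounding test harness, not part of this function's contract.
def test_no_stray_entities_written(self) -> None:
    expected_people = self._run_write_under_test()  # hypothetical helper returning the persisted StatePerson rows
    with SessionFactory.using_database(self.database_key) as session:
        assert_no_unexpected_entities_in_db(expected_people, session)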
def _get_status_using_session(
        self, session: Session) -> DirectIngestInstanceStatus:
    return (session.query(DirectIngestInstanceStatus).filter(
        DirectIngestInstanceStatus.region_code == self.region_code,
        DirectIngestInstanceStatus.instance == self.ingest_instance.value,
    ).one())
def update_historical_snapshots(
    self,
    session: Session,
    root_people: List[SchemaPersonType],
    orphaned_entities: List[DatabaseEntity],
    ingest_metadata: IngestMetadata,
) -> None:
    """For all entities in all record trees rooted at |root_people| and all
    entities in |orphaned_entities|, performs any required historical snapshot
    updates.

    If any entity has no existing historical snapshots, an initial snapshot will
    be created for it.

    If any column of an entity differs from its current snapshot, the current
    snapshot will be closed with period end time of |snapshot_time| and a new
    snapshot will be opened corresponding to the updated entity with period
    start time of |snapshot_time|.

    If neither of these cases applies, no action will be taken on the entity.
    """
    logging.info(
        "Beginning historical snapshot updates for %s record tree(s) and %s"
        " orphaned entities",
        len(root_people),
        len(orphaned_entities),
    )

    schema: ModuleType = self.get_schema_module()

    root_entities: List[DatabaseEntity] = list()
    root_entities.extend(root_people)
    root_entities.extend(orphaned_entities)

    self._assert_all_root_entities_unique(root_entities)

    context_registry = _SnapshotContextRegistry()

    self._execute_action_for_all_entities(root_entities,
                                          context_registry.register_entity)

    logging.info(
        "%s master entities registered for snapshot check",
        len(context_registry.all_contexts()),
    )

    most_recent_snapshots = self._fetch_most_recent_snapshots_for_all_entities(
        session, root_entities, schema)
    for snapshot in most_recent_snapshots:
        context_registry.add_snapshot(snapshot, schema)

    logging.info("%s registered entities with existing snapshots",
                 len(most_recent_snapshots))

    # Provided start and end times only need to be set for root_people, not
    # orphaned entities. Provided start and end times are only relevant for
    # new entities with no existing snapshots, and orphaned entities by
    # definition are already present in the database and therefore already
    # have existing snapshots.
    self.set_provided_start_and_end_times(root_people, context_registry)

    logging.info("Provided start and end times set for registered entities")

    for snapshot_context in context_registry.all_contexts():
        self._write_snapshots(session, snapshot_context,
                              ingest_metadata.ingest_time, schema)

    logging.info("Flushing snapshots")
    session.flush()
    logging.info("All historical snapshots written")
def _fetch_most_recent_snapshots_for_entity_type(
    self,
    session: Session,
    master_class: Type,
    entity_ids: Set[int],
    schema: ModuleType,
) -> List[DatabaseEntity]:
    """Returns a list containing the most recent snapshot for each ID in
    |entity_ids| with type |master_class|.
    """

    # Get name of historical table in database (as distinct from name of ORM
    # class representing historical table in code)
    history_table_class = _get_historical_class(master_class, schema)
    history_table_name = history_table_class.__table__.name
    history_table_primary_key_col_name = (
        history_table_class.get_primary_key_column_name())
    # See module assumption #2
    master_table_primary_key_col_name = master_class.get_primary_key_column_name()
    ids_list = ", ".join([str(id) for id in entity_ids])

    # Get snapshot IDs in a separate query. The subquery logic here is ugly
    # and easier to do as a raw string query than through the ORM query, but
    # the return type of a raw string query is just a collection of values
    # rather than an ORM model. Doing this step as a separate query enables
    # passing just the IDs to the second request, which allows proper ORM
    # models to be returned as a result.
    snapshot_ids_query = f"""
    SELECT
      history.{history_table_primary_key_col_name},
      history.{master_table_primary_key_col_name},
      history.valid_to
    FROM {history_table_name} history
    JOIN (
      SELECT
        {master_table_primary_key_col_name},
        MAX(valid_from) AS valid_from
      FROM {history_table_name}
      WHERE {master_table_primary_key_col_name} IN ({ids_list})
      GROUP BY {master_table_primary_key_col_name}
    ) AS most_recent_valid_from
    ON history.{master_table_primary_key_col_name} = most_recent_valid_from.{master_table_primary_key_col_name}
    WHERE history.valid_from = most_recent_valid_from.valid_from;
    """

    results = session.execute(text(snapshot_ids_query)).fetchall()

    # Use only results where valid_to is None to exclude any overlapping
    # non-open snapshots
    snapshot_ids = [
        snapshot_id for snapshot_id, master_id, valid_to in results
        if valid_to is None
    ]

    # Removing the below early return will pass in tests but fail in
    # production, because SQLite allows "IN ()" but Postgres does not
    if not snapshot_ids:
        return []

    filter_statement = (
        "{historical_table}.{primary_key_column} IN ({ids_list})".format(
            historical_table=history_table_name,
            primary_key_column=history_table_class.get_primary_key_column_name(),
            ids_list=", ".join([str(id) for id in snapshot_ids]),
        ))

    return session.query(history_table_class).filter(
        text(filter_statement)).all()
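# Worked example of the snapshot selection above (illustrative rows): for a master
# entity with id 7 whose history table holds
#     (valid_from=2020-01-01, valid_to=2020-02-01) and (valid_from=2020-02-01, valid_to=NULL),
# the MAX(valid_from) join keeps only the 2020-02-01 row, and the valid_to-is-None
# filter confirms it is the open snapshot before its primary key is fed into the
# second, ORM-level query.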