Code example #1
File: dao.py Project: Leo-Ryu/pulse-data
def get_file_metadata_row_for_path(
    session: Session, region_code: str, path: GcsfsFilePath
) -> Union[schema.DirectIngestRawFileMetadata,
           schema.DirectIngestIngestFileMetadata]:
    """Returns metadata information for the provided path. If the file has not yet been registered in the
    appropriate metadata table, this function will generate a file_id to return with the metadata.
    """

    parts = filename_parts_from_path(path)

    if parts.file_type == GcsfsDirectIngestFileType.INGEST_VIEW:
        results = session.query(
            schema.DirectIngestIngestFileMetadata).filter_by(
                region_code=region_code,
                is_invalidated=False,
                normalized_file_name=path.file_name).all()
    elif parts.file_type == GcsfsDirectIngestFileType.RAW_DATA:
        results = session.query(schema.DirectIngestRawFileMetadata).filter_by(
            region_code=region_code,
            normalized_file_name=path.file_name).all()
    else:
        raise ValueError(f'Unexpected path type: {parts.file_type}')

    if len(results) != 1:
        raise ValueError(
            f'Unexpected number of metadata results for path {path.abs_path()}: [{len(results)}]'
        )

    return one(results)
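The next lines are not part of the original listing: a minimal, hypothetical calling sketch for get_file_metadata_row_for_path. The GcsfsFilePath.from_absolute_path constructor is an assumption about the pulse-data API; the other names come from the example above.

# Hypothetical usage sketch (not from the pulse-data source). Assumes the names used in
# the example above are importable, and that GcsfsFilePath exposes a from_absolute_path
# constructor (an assumption).
def print_file_id_for_path(session: Session, region_code: str, abs_path: str) -> None:
    path = GcsfsFilePath.from_absolute_path(abs_path)
    metadata = get_file_metadata_row_for_path(session, region_code, path)
    # Depending on the path's file type, this is either a raw-file or an ingest-view
    # metadata row; both expose file_id.
    print(metadata.file_id)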
Code example #2
def check_people_do_not_have_multiple_ids_same_type(
        session: Session, region_code: str,
        output_people: List[schema.StatePerson]) -> bool:
    """Validates that person has two ids of the same type (for states configured to enforce this invariant)."""

    check_not_dirty(session)

    logging.info(
        "[Invariant validation] Checking that no person has multiple external ids of the same type."
    )
    if state_allows_multiple_ids_same_type(region_code):
        logging.info(
            "[Invariant validation] Multiple external ids of the same type allowed for [%s] - skipping.",
            region_code)
        return True

    person_ids = {p.person_id for p in output_people}
    if not person_ids:
        logging.warning(
            "[Invariant validation] No StatePersonExternalIds in the output set - skipping validations."
        )
        return True

    counts_subquery = session.query(
        schema.StatePersonExternalId.state_code.label('state_code'),
        schema.StatePersonExternalId.person_id.label('person_id'),
        schema.StatePersonExternalId.id_type.label('id_type'),
        func.count().label('cnt')
    ).filter(schema.StatePersonExternalId.state_code == region_code.upper(
    )).filter(
        # Ideally we would not filter by person_ids, but that query takes ~10s on a test of US_PA external ids. Since
        # this will be run for every file, that sort of performance is prohibitive. We instead filter by just the person
        # ids we think we have touched this session.
        schema.StatePersonExternalId.person_id.in_(person_ids)).group_by(
            schema.StatePersonExternalId.state_code,
            schema.StatePersonExternalId.person_id,
            schema.StatePersonExternalId.id_type,
        ).subquery()

    query = session.query(
        counts_subquery.c.state_code,
        counts_subquery.c.person_id,
        counts_subquery.c.id_type,
        counts_subquery.c.cnt,
    ).filter(counts_subquery.c.cnt > 1).limit(1)

    results = query.all()

    if results:
        _state_code, person_id, id_type, count = one(results)
        logging.error(
            '[Invariant validation] Found people with multiple ids of the same type. First example: '
            'person_id=[%s], id_type=[%s] is used [%s] times.', person_id,
            id_type, count)
        return False

    logging.info(
        "[Invariant validation] Found no people with multiple external ids of the same type."
    )
    return True
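Not in the original listing: a hypothetical sketch of how a validator like the one above could be wired into a persistence flow, shown only to illustrate the return-value contract (True means the invariant holds or the check was skipped).

# Hypothetical wiring sketch (not from the pulse-data source). Runs the invariant check
# after a flush and aborts the transaction if it fails; names come from the example above.
def flush_and_validate_external_ids(
        session: Session, region_code: str,
        output_people: List[schema.StatePerson]) -> None:
    session.flush()  # ensure the check runs against a non-dirty session
    if not check_people_do_not_have_multiple_ids_same_type(
            session, region_code, output_people):
        session.rollback()
        raise ValueError(
            "Found a person with multiple external ids of the same type.")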
Code example #3
    def _write_snapshots_for_existing_entities(
        self,
        session: Session,
        context: "_SnapshotContext",
        snapshot_time: datetime,
        schema: ModuleType,
    ) -> None:
        """Writes snapshot updates for entities that already have snapshots
        present in the database
        """
        historical_class = _get_historical_class(type(context.schema_object),
                                                 schema)
        new_historical_snapshot = historical_class()
        self._copy_entity_fields_to_historical_snapshot(
            context.schema_object, new_historical_snapshot)
        new_historical_snapshot.valid_from = snapshot_time

        # Snapshot must be merged separately from record tree, as they are not
        # included in the ORM model relationships (to avoid needing to load
        # the entire snapshot chain at once)
        session.merge(new_historical_snapshot)

        # Close last snapshot if one is present
        if context.most_recent_snapshot is not None:
            if not isinstance(context.most_recent_snapshot,
                              HistoryTableSharedColumns):
                raise ValueError(
                    f"Snapshot class [{type(context.most_recent_snapshot)}] "
                    f"must be a subclass of "
                    f"[{HistoryTableSharedColumns.__name__}]")

            context.most_recent_snapshot.valid_to = snapshot_time
            session.merge(context.most_recent_snapshot)
Code example #4
    def _write_snapshots_for_new_entities(
        self,
        session: Session,
        context: "_SnapshotContext",
        snapshot_time: datetime,
        schema: ModuleType,
    ) -> None:
        """Writes snapshots for any new entities, including any required manual
        adjustments based on provided start and end times
        """
        historical_class = _get_historical_class(type(context.schema_object),
                                                 schema)
        new_historical_snapshot = historical_class()
        self._copy_entity_fields_to_historical_snapshot(
            context.schema_object, new_historical_snapshot)

        provided_start_time = None
        provided_end_time = None

        # Validate provided start and end times
        if (context.provided_start_time
                and context.provided_start_time.date() < snapshot_time.date()):
            provided_start_time = context.provided_start_time

        if (context.provided_end_time
                and context.provided_end_time.date() < snapshot_time.date()):
            provided_end_time = context.provided_end_time

        if (provided_start_time and provided_end_time
                and provided_start_time >= provided_end_time):
            provided_start_time = None

        if provided_start_time is not None and provided_end_time is None:
            new_historical_snapshot.valid_from = provided_start_time
        elif provided_end_time is not None:
            new_historical_snapshot.valid_from = provided_end_time
        else:
            new_historical_snapshot.valid_from = snapshot_time

        # Snapshot must be merged separately from record tree, as they are not
        # included in the ORM model relationships (to avoid needing to load
        # the entire snapshot chain at once)
        session.merge(new_historical_snapshot)

        # If both start and end time were provided, an earlier snapshot needs to
        # be created, reflecting the state of the entity before its current
        # completed state
        if provided_start_time and provided_end_time:
            initial_snapshot = historical_class()
            self._copy_entity_fields_to_historical_snapshot(
                context.schema_object, initial_snapshot)
            initial_snapshot.valid_from = provided_start_time
            initial_snapshot.valid_to = provided_end_time

            self.post_process_initial_snapshot(context, initial_snapshot)

            session.merge(initial_snapshot)
Code example #5
    def _execute_statement(self, statement: str) -> None:
        session = Session(bind=self.postgres_engine)
        try:
            session.execute(statement)
            session.commit()
        except Exception as e:
            logging.warning("Failed to cleanup: %s", e)
            session.rollback()
        finally:
            session.close()
Code example #6
    def for_prod_data_client(
        cls,
        database_key: SQLAlchemyDatabaseKey,
        ssl_cert_path: str,
        *,
        autocommit: bool = True,
    ) -> Iterator[Session]:
        """Implements a context manager for db sessions for use in prod-data-client."""
        engine = SQLAlchemyEngineManager.get_engine_for_database_with_ssl_certs(
            database_key=database_key, ssl_cert_path=ssl_cert_path)
        if engine is None:
            raise ValueError(f"No engine set for key [{database_key}]")

        try:
            session = Session(bind=engine)
            cls._alter_session_variables(session)
            cls._apply_session_listener_for_schema_base(
                database_key.declarative_meta, session)
            yield session
            if autocommit:
                try:
                    session.commit()
                except Exception as e:
                    session.rollback()
                    raise e
        finally:
            session.close()
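Not part of the listing: a hypothetical usage sketch. Because the function body yields a session and is annotated as returning Iterator[Session], it is presumably exposed as a classmethod context manager (e.g. wrapped with contextlib.contextmanager) on a session-factory class; both that wrapper and the SessionFactory name are assumptions here.

# Hypothetical usage sketch (not from the pulse-data source). Assumes for_prod_data_client
# is exposed as a context manager on a SessionFactory-style class.
with SessionFactory.for_prod_data_client(
        database_key=database_key,           # an existing SQLAlchemyDatabaseKey
        ssl_cert_path="/path/to/ssl/certs",  # illustrative path only
) as session:
    # Reads/writes happen here; on a clean exit the session is committed (autocommit=True
    # by default) and close() always runs via the finally block above.
    session.query(schema.DirectIngestRawFileMetadata).count()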
Code example #7
    def _alter_session_variables(cls, session: Session) -> None:
        # Postgres uses a query cost analysis heuristic to decide what type of read to use for a particular query. It
        # sometimes chooses to use a sequential read because for hard disk drives (HDDs, as opposed to solid state
        # drives, SSDs) that may be faster than jumping around to random pages of an index. This is especially likely
        # when running over small sets of data. Setting this option changes the heuristic to almost always prefer index
        # reads.
        #
        # Our postgres instances run on SSDs, so this should increase performance for us. This is also important
        # because sequential reads lock an entire table, whereas index reads only lock the particular predicate from a
        # query. See https://www.postgresql.org/docs/12/transaction-iso.html and
        # https://stackoverflow.com/questions/42288808/why-does-postgresql-serializable-transaction-think-this-as-conflict.
        #
        # TODO(#3928): Once defined in code, set this on the SQL instance itself instead of per session.
        if session.bind.dialect.name == "postgresql":
            session.execute("SET random_page_cost=1;")
Code example #8
File: dao.py Project: pnchbck/pulse-data
def get_file_metadata_row(
        session: Session,
        file_type: GcsfsDirectIngestFileType,
        file_id: int
) -> Union[schema.DirectIngestRawFileMetadata, schema.DirectIngestIngestFileMetadata]:
    """Queries for the file metadata row by the metadata row primary key."""

    if file_type == GcsfsDirectIngestFileType.INGEST_VIEW:
        results = session.query(schema.DirectIngestIngestFileMetadata).filter_by(file_id=file_id).all()
    elif file_type == GcsfsDirectIngestFileType.RAW_DATA:
        results = session.query(schema.DirectIngestRawFileMetadata).filter_by(file_id=file_id).all()
    else:
        raise ValueError(f'Unexpected path type: {file_type}')

    return one(results)
Code example #9
    def _pending_to_persistent(session: Session, instance: Any) -> None:
        """Called when a SQLAlchemy object transitions to a persistent object. If this function throws, the session
        will be rolled back and that object will not be committed."""
        if not isinstance(instance, DirectIngestIngestFileMetadata):
            return

        results = (session.query(DirectIngestIngestFileMetadata).filter_by(
            is_invalidated=False,
            is_file_split=False,
            region_code=instance.region_code,
            file_tag=instance.file_tag,
            ingest_database_name=instance.ingest_database_name,
            datetimes_contained_lower_bound_exclusive=instance.
            datetimes_contained_lower_bound_exclusive,
            datetimes_contained_upper_bound_inclusive=instance.
            datetimes_contained_upper_bound_inclusive,
        ).all())

        if len(results) > 1:
            raise IntegrityError(
                f"Attempting to commit repeated non-file split DirectIngestIngestFileMetadata row for "
                f"region_code={instance.region_code}, file_tag={instance.file_tag}, "
                f"ingest_database_name={instance.ingest_database_name}",
                f"datetimes_contained_lower_bound_exclusive={instance.datetimes_contained_lower_bound_exclusive}, "
                f"datetimes_contained_upper_bound_inclusive={instance.datetimes_contained_upper_bound_inclusive}",
            )
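Not from the listing: a sketch of how a hook with this signature is typically attached. SQLAlchemy's SessionEvents does provide a pending_to_persistent event with (session, instance) arguments; how pulse-data actually registers this particular listener is an assumption.

# Hypothetical registration sketch (not from the pulse-data source). Uses SQLAlchemy's
# real "pending_to_persistent" session event, which fires with (session, instance).
from sqlalchemy import event
from sqlalchemy.orm import Session

def apply_ingest_file_metadata_listener(session: Session) -> None:
    event.listen(session, "pending_to_persistent", _pending_to_persistent)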
Code example #10
def get_ingest_file_metadata_row_for_path(
        session: Session, region_code: str, path: GcsfsFilePath,
        ingest_database_name: str) -> schema.DirectIngestIngestFileMetadata:
    """Returns metadata information for the provided path. If the file has not yet been registered in the
    appropriate metadata table, this function will generate a file_id to return with the metadata.
    """

    parts = filename_parts_from_path(path)

    if parts.file_type != GcsfsDirectIngestFileType.INGEST_VIEW:
        raise ValueError(f"Unexpected file type [{parts.file_type}]")

    results = (session.query(schema.DirectIngestIngestFileMetadata).filter_by(
        region_code=region_code.upper(),
        is_invalidated=False,
        normalized_file_name=path.file_name,
        ingest_database_name=ingest_database_name,
    ).all())

    if len(results) != 1:
        raise ValueError(
            f"Unexpected number of metadata results for path {path.abs_path()}: [{len(results)}]"
        )

    return one(results)
Code example #11
def get_ingest_view_metadata_for_export_job(
    session: Session,
    region_code: str,
    file_tag: str,
    datetimes_contained_lower_bound_exclusive: Optional[datetime.datetime],
    datetimes_contained_upper_bound_inclusive: datetime.datetime,
    ingest_database_name: str,
) -> Optional[schema.DirectIngestIngestFileMetadata]:
    """Returns the ingest file metadata row corresponding to the export job with the provided args. Throws if such a
    row does not exist.
    """
    results = (session.query(schema.DirectIngestIngestFileMetadata).filter_by(
        region_code=region_code.upper(),
        file_tag=file_tag,
        is_invalidated=False,
        is_file_split=False,
        datetimes_contained_lower_bound_exclusive=
        datetimes_contained_lower_bound_exclusive,
        datetimes_contained_upper_bound_inclusive=
        datetimes_contained_upper_bound_inclusive,
        ingest_database_name=ingest_database_name,
    ).all())

    if not results:
        return None

    return one(results)
Code example #12
File: persistence.py Project: jazzPouls/pulse-data
def retry_transaction(
    session: Session,
    measurements: MeasurementMap,
    txn_body: Callable[[Session], bool],
    max_retries: Optional[int],
) -> bool:
    """Retries the transaction if a serialization failure occurs.

    Handles management of committing, rolling back, and closing the `session`. `txn_body` can return False to force the
    transaction to be aborted, otherwise return True.

    Returns:
        True, if the transaction succeeded.
        False, if the transaction was aborted by `txn_body`.
    """
    num_retries = 0
    try:
        while True:
            try:
                should_continue = txn_body(session)

                if not should_continue:
                    session.rollback()
                    return should_continue

                session.commit()
                return True
            except sqlalchemy.exc.DBAPIError as e:
                session.rollback()
                if max_retries and num_retries >= max_retries:
                    raise
                if (isinstance(e.orig, psycopg2.OperationalError)
                        and e.orig.pgcode == SERIALIZATION_FAILURE):
                    logging.info(
                        "Retrying transaction due to serialization failure: %s",
                        e)
                    num_retries += 1
                    continue
                raise
            except Exception:
                session.rollback()
                raise
    finally:
        measurements.measure_int_put(m_retries, num_retries)
        session.close()
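A hypothetical usage sketch, not in the original listing: the write logic goes in a txn_body callable, and retry_transaction owns commit, rollback, retry, and close. The MeasurementMap construction is left out because it is project-specific.

# Hypothetical usage sketch (not from the pulse-data source). `measurements` is assumed to
# be an already-constructed MeasurementMap; returning False from the body aborts the txn.
def persist_person(session: Session, measurements: MeasurementMap,
                   person: schema.StatePerson) -> bool:
    def txn_body(txn_session: Session) -> bool:
        txn_session.add(person)
        return True  # return False instead to roll back and abort

    return retry_transaction(session, measurements, txn_body, max_retries=3)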
Code example #13
def get_date_sorted_unprocessed_raw_files_for_region(
    session: Session,
    region_code: str,
) -> List[schema.DirectIngestRawFileMetadata]:
    """Returns metadata for all raw files that do not have a processed_time from earliest to latest"""
    return (session.query(schema.DirectIngestRawFileMetadata).filter_by(
        region_code=region_code.upper(), processed_time=None).order_by(
            schema.DirectIngestRawFileMetadata.
            datetimes_contained_upper_bound_inclusive.asc()).all())
Code example #14
File: dao.py Project: Leo-Ryu/pulse-data
def get_ingest_view_metadata_pending_export(
        session: Session,
        region_code: str) -> List[schema.DirectIngestIngestFileMetadata]:
    """Returns metadata for all ingest files have not yet been exported."""

    return session.query(schema.DirectIngestIngestFileMetadata).filter_by(
        region_code=region_code,
        is_invalidated=False,
        is_file_split=False,
        export_time=None).all()
Code example #15
    def for_schema_base(cls, schema_base: DeclarativeMeta) -> Session:
        engine = SQLAlchemyEngineManager.get_engine_for_schema_base(
            schema_base)
        if engine is None:
            raise ValueError(
                f"No engine set for base [{schema_base.__name__}]")

        session = Session(bind=engine)
        cls._apply_session_listener_for_schema_base(schema_base, session)
        return session
Code example #16
def get_ingest_file_metadata_row(
    session: Session,
    file_id: int,
    ingest_database_name: str,
) -> schema.DirectIngestIngestFileMetadata:
    """Queries for the ingest file metadata row by the metadata row primary key."""

    results = (session.query(schema.DirectIngestIngestFileMetadata).filter_by(
        file_id=file_id, ingest_database_name=ingest_database_name).all())
    return one(results)
Code example #17
File: dao.py Project: Leo-Ryu/pulse-data
def get_metadata_for_raw_files_discovered_after_datetime(
    session: Session, region_code: str, raw_file_tag: str,
    discovery_time_lower_bound_exclusive: Optional[datetime.datetime]
) -> List[schema.DirectIngestRawFileMetadata]:
    """Returns metadata for all raw files with a given tag that have been updated after the provided date."""

    query = session.query(schema.DirectIngestRawFileMetadata).filter_by(
        region_code=region_code, file_tag=raw_file_tag)
    if discovery_time_lower_bound_exclusive:
        query = query.filter(schema.DirectIngestRawFileMetadata.discovery_time
                             > discovery_time_lower_bound_exclusive)

    return query.all()
Code example #18
File: mappings.py Project: Recidiviz/pulse-data
    def to_query(self, session: Session) -> Query:
        """Create a query to SELECT each column based on the Mapping's name."""
        select_statements = []
        for new_view_column_name, source_column in attr.asdict(self).items():
            # Default unmapped columns to NULL to ensure we SELECT all columns
            select_statement = null().label(new_view_column_name)

            if source_column is not None:
                select_statement = source_column.label(new_view_column_name)

            select_statements.append(select_statement)

        return session.query(*select_statements)
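Not part of the original: a hedged sketch of what calling to_query might look like, given an open Session. The class name Mappings and its field names are purely illustrative; only the None-becomes-NULL behavior is taken from the method above.

# Hypothetical usage sketch (not from the pulse-data source). Field names are illustrative.
mappings = Mappings(
    person_id=schema.StatePerson.person_id,  # mapped: selected from the source column
    full_name=None,                          # unmapped: selected as NULL per to_query
)
query = mappings.to_query(session)
print(str(query))  # inspect the generated SELECT statement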
Code example #19
    def _for_database(cls, database_key: SQLAlchemyDatabaseKey) -> Session:
        # TODO(#8046): When the above method is deleted, move this into `using_database`
        # directly.
        engine = SQLAlchemyEngineManager.get_engine_for_database(
            database_key=database_key)
        if engine is None:
            raise ValueError(f"No engine set for key [{database_key}]")

        session = Session(bind=engine)
        cls._alter_session_variables(session)
        cls._apply_session_listener_for_schema_base(
            database_key.declarative_meta, session)
        return session
Code example #20
def get_date_sorted_unprocessed_ingest_view_files_for_region(
    session: Session,
    region_code: str,
    ingest_database_name: str,
) -> List[schema.DirectIngestIngestFileMetadata]:
    """Returns metadata for all ingest files that do not have a processed_time from earliest to latest"""
    return (session.query(schema.DirectIngestIngestFileMetadata).filter_by(
        region_code=region_code.upper(),
        processed_time=None,
        ingest_database_name=ingest_database_name,
        is_invalidated=False,
    ).order_by(schema.DirectIngestIngestFileMetadata.
               datetimes_contained_upper_bound_inclusive.asc()).all())
Code example #21
File: dao.py Project: Leo-Ryu/pulse-data
def get_ingest_view_metadata_for_most_recent_valid_job(
        session: Session, region_code: str,
        file_tag: str) -> Optional[schema.DirectIngestIngestFileMetadata]:
    """Returns most recently created export metadata row where is_invalidated is False, or None if there are no
    metadata rows for this file tag for this manager's region."""

    results = session.query(schema.DirectIngestIngestFileMetadata).filter_by(
        region_code=region_code,
        is_invalidated=False,
        is_file_split=False,
        file_tag=file_tag).order_by(schema.DirectIngestIngestFileMetadata.
                                    job_creation_time.desc()).limit(1).all()

    if not results:
        return None

    return one(results)
Code example #22
def assert_no_unexpected_entities_in_db(
    expected_entities: Sequence[DatabaseEntity], session: Session
):
    """Counts all of the entities present in the |expected_entities| graph by
    type and ensures that the same number of entities exists in the DB for each
    type.
    """
    entity_counter: Dict[Type, List[DatabaseEntity]] = defaultdict(list)
    get_entities_by_type(expected_entities, entity_counter)
    for cls, entities_of_cls in entity_counter.items():
        # Standalone classes do not need to be attached to a person by design,
        # so it is valid if some standalone entities are not reachable from the
        # provided |expected_entities|
        if is_standalone_class(cls):
            continue

        expected_ids = set()
        for entity in entities_of_cls:
            expected_ids.add(entity.get_id())
        db_entities = session.query(cls).all()
        db_ids = set()
        for entity in db_entities:
            db_ids.add(entity.get_id())

        if expected_ids != db_ids:
            print("\n********** Entities from |found_persons| **********\n")
            for entity in sorted(entities_of_cls, key=lambda x: x.get_id()):
                print_entity_tree(entity)
            print("\n********** Entities from db **********\n")
            for entity in sorted(db_entities, key=lambda x: x.get_id()):
                print_entity_tree(entity)
            raise ValueError(
                f"For cls {cls.__name__}, found difference in primary keys from"
                f"expected entities and those of entities read from db.\n"
                f"Expected ids not present in db: "
                f"{str(expected_ids - db_ids)}\n"
                f"Db ids not present in expected entities: "
                f"{str(db_ids - expected_ids)}\n"
            )
Code example #23
    def _get_status_using_session(
            self, session: Session) -> DirectIngestInstanceStatus:
        return (session.query(DirectIngestInstanceStatus).filter(
            DirectIngestInstanceStatus.region_code == self.region_code,
            DirectIngestInstanceStatus.instance == self.ingest_instance.value,
        ).one())
Code example #24
    def update_historical_snapshots(
        self,
        session: Session,
        root_people: List[SchemaPersonType],
        orphaned_entities: List[DatabaseEntity],
        ingest_metadata: IngestMetadata,
    ) -> None:
        """For all entities in all record trees rooted at |root_people| and all
        entities in |orphaned_entities|, performs any required historical
        snapshot updates.

        If any entity has no existing historical snapshots, an initial snapshot
        will be created for it.

        If any column of an entity differs from its current snapshot, the
        current snapshot will be closed with period end time of |snapshot_time|
        and a new snapshot will be opened corresponding to the updated entity
        with period start time of |snapshot_time|.

        If neither of these cases applies, no action will be taken on the
        entity.
        """
        logging.info(
            "Beginning historical snapshot updates for %s record tree(s) and %s"
            " orphaned entities",
            len(root_people),
            len(orphaned_entities),
        )

        schema: ModuleType = self.get_schema_module()

        root_entities: List[DatabaseEntity] = list()
        root_entities.extend(root_people)
        root_entities.extend(orphaned_entities)

        self._assert_all_root_entities_unique(root_entities)

        context_registry = _SnapshotContextRegistry()

        self._execute_action_for_all_entities(root_entities,
                                              context_registry.register_entity)

        logging.info(
            "%s master entities registered for snapshot check",
            len(context_registry.all_contexts()),
        )

        most_recent_snapshots = self._fetch_most_recent_snapshots_for_all_entities(
            session, root_entities, schema)
        for snapshot in most_recent_snapshots:
            context_registry.add_snapshot(snapshot, schema)

        logging.info("%s registered entities with existing snapshots",
                     len(most_recent_snapshots))

        # Provided start and end times only need to be set for root_people, not
        # orphaned entities. Provided start and end times are only relevant for
        # new entities with no existing snapshots, and orphaned entities by
        # definition are already present in the database and therefore already
        # have existing snapshots.
        self.set_provided_start_and_end_times(root_people, context_registry)

        logging.info(
            "Provided start and end times set for registered entities")

        for snapshot_context in context_registry.all_contexts():
            self._write_snapshots(session, snapshot_context,
                                  ingest_metadata.ingest_time, schema)

        logging.info("Flushing snapshots")
        session.flush()
        logging.info("All historical snapshots written")
Code example #25
    def _fetch_most_recent_snapshots_for_entity_type(
        self,
        session: Session,
        master_class: Type,
        entity_ids: Set[int],
        schema: ModuleType,
    ) -> List[DatabaseEntity]:
        """Returns a list containing the most recent snapshot for each ID in
        |entity_ids| with type |master_class|
        """

        # Get name of historical table in database (as distinct from name of ORM
        # class representing historical table in code)
        history_table_class = _get_historical_class(master_class, schema)
        history_table_name = history_table_class.__table__.name
        history_table_primary_key_col_name = (
            history_table_class.get_primary_key_column_name())
        # See module assumption #2
        master_table_primary_key_col_name = master_class.get_primary_key_column_name(
        )
        ids_list = ", ".join([str(id) for id in entity_ids])

        # Get snapshot IDs in a separate query. The subquery logic here is ugly
        # and easier to do as a raw string query than through the ORM query, but
        # the return type of a raw string query is just a collection of values
        # rather than an ORM model. Doing this step as a separate query enables
        # passing just the IDs to the second request, which allows proper ORM
        # models to be returned as a result.
        snapshot_ids_query = f"""
        SELECT
          history.{history_table_primary_key_col_name},
          history.{master_table_primary_key_col_name},
          history.valid_to
        FROM {history_table_name} history
        JOIN (
          SELECT 
            {master_table_primary_key_col_name}, 
            MAX(valid_from) AS valid_from
          FROM {history_table_name}
          WHERE {master_table_primary_key_col_name} IN ({ids_list})
          GROUP BY {master_table_primary_key_col_name}
        ) AS most_recent_valid_from
        ON history.{master_table_primary_key_col_name} = 
            most_recent_valid_from.{master_table_primary_key_col_name}
        WHERE history.valid_from = most_recent_valid_from.valid_from;
        """

        results = session.execute(text(snapshot_ids_query)).fetchall()

        # Use only results where valid_to is None to exclude any overlapping
        # non-open snapshots
        snapshot_ids = [
            snapshot_id for snapshot_id, master_id, valid_to in results
            if valid_to is None
        ]

        # Removing the below early return will pass in tests but fail in
        # production, because SQLite allows "IN ()" but Postgres does not
        if not snapshot_ids:
            return []

        filter_statement = (
            "{historical_table}.{primary_key_column} IN ({ids_list})".format(
                historical_table=history_table_name,
                primary_key_column=history_table_class.
                get_primary_key_column_name(),
                ids_list=", ".join([str(id) for id in snapshot_ids]),
            ))

        return session.query(history_table_class).filter(
            text(filter_statement)).all()