def setUp(self) -> None:
    self.database_key = SQLAlchemyDatabaseKey.canonical_for_schema(SchemaType.STATE)
    local_postgres_helpers.use_on_disk_postgresql_database(self.database_key)
    self.state_code = "US_XX"
def test_canonical_for_schema_local_only(self) -> None:
    # Succeeds when running locally.
    _ = SQLAlchemyDatabaseKey.canonical_for_schema(schema_type=SchemaType.STATE)

    with patch(
        "recidiviz.utils.environment.get_gcp_environment",
        Mock(return_value="production"),
    ):
        # Raises when we appear to be running in a GCP environment.
        with self.assertRaises(RuntimeError):
            _ = SQLAlchemyDatabaseKey.canonical_for_schema(schema_type=SchemaType.STATE)

    # Succeeds again once the environment patch is removed.
    _ = SQLAlchemyDatabaseKey.canonical_for_schema(schema_type=SchemaType.STATE)
def _get_all_schema_objects_in_db(
    self,
    schema_person_type: SchemaPersonType,
    schema: ModuleType,
    schema_object_type_names_to_ignore: List[str],
) -> List[DatabaseEntity]:
    """Generates a list of all schema objects stored in the database that
    can be reached from an object with the provided type.

    Args:
        schema_person_type: Class type of the root of the schema object graph
            (e.g. StatePerson).
        schema: The schema module that root_object_type is defined in.
        schema_object_type_names_to_ignore: type names for objects defined in
            the schema that we shouldn't assert are included in the object
            graph.

    Returns:
        A list of all schema objects that can be reached from the object graph
        rooted at the singular object of type |schema_person_type|.

    Throws:
        If more than one object of type |schema_person_type| exists in the DB.
    """
    with SessionFactory.using_database(
        SQLAlchemyDatabaseKey.canonical_for_schema(
            schema_type_for_schema_module(schema)
        ),
        autocommit=False,
    ) as session:
        person = one(session.query(schema_person_type).all())
        schema_objects: Set[DatabaseEntity] = {person}
        unprocessed = [person]
        while unprocessed:
            schema_object = unprocessed.pop()

            related_entities = []
            for relationship_name in schema_object.get_relationship_property_names():
                related = getattr(schema_object, relationship_name)

                # Relationship can return either a list or a single item
                if isinstance(related, DatabaseEntity):
                    related_entities.append(related)
                if isinstance(related, list):
                    related_entities.extend(related)

            for obj in related_entities:
                if obj not in schema_objects:
                    schema_objects.add(obj)
                    unprocessed.append(obj)

        self._check_all_non_history_schema_object_types_in_list(
            list(schema_objects), schema, schema_object_type_names_to_ignore
        )

        return list(schema_objects)
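# Hedged usage sketch (not from the original source): how a state-schema test
# might call the helper above. `state_schema` and `state_schema.StatePerson` are
# assumed stand-ins for the state schema module and its root person type.
reachable = self._get_all_schema_objects_in_db(
    schema_person_type=state_schema.StatePerson,
    schema=state_schema,
    schema_object_type_names_to_ignore=[],
)
assert all(isinstance(obj, DatabaseEntity) for obj in reachable)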
def setUp(self) -> None:
    self.db_dir = local_postgres_helpers.start_on_disk_postgresql_database()
    self.database_key = SQLAlchemyDatabaseKey.canonical_for_schema(self.schema_type)
    self.overridden_env_vars = (
        local_postgres_helpers.update_local_sqlalchemy_postgres_env_vars()
    )
    self.engine = create_engine(
        local_postgres_helpers.postgres_db_url_from_env_vars()
    )
def test_state_legacy_db(self) -> None:
    db_key_1 = SQLAlchemyDatabaseKey(schema_type=SchemaType.STATE)
    db_key_1_dup = SQLAlchemyDatabaseKey.canonical_for_schema(schema_type=SchemaType.STATE)
    self.assertEqual(db_key_1, db_key_1_dup)

    # TODO(#7984): Once we have cut over all traffic to non-legacy state DBs and
    #  removed the LEGACY database version, remove this part of the test.
    db_key_legacy = SQLAlchemyDatabaseKey.for_state_code(
        StateCode.US_AK, SQLAlchemyStateDatabaseVersion.LEGACY
    )
    self.assertEqual(db_key_1, db_key_legacy)
def for_state(
    cls,
    region: str,
    enum_overrides: Optional[EnumOverrides] = None,
) -> IngestMetadata:
    return IngestMetadata(
        region=region,
        jurisdiction_id="",
        ingest_time=datetime.datetime(2020, 4, 14, 12, 31, 00),
        enum_overrides=enum_overrides or EnumOverrides.empty(),
        system_level=SystemLevel.STATE,
        database_key=SQLAlchemyDatabaseKey.canonical_for_schema(SchemaType.STATE),
    )
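# Hedged usage sketch (illustrative only): assuming the classmethod above lives on
# a test helper class, here hypothetically named FakeIngestMetadata, a test could
# build state-level metadata keyed to the canonical STATE database like this. The
# "us_xx" region string is a placeholder.
metadata = FakeIngestMetadata.for_state(region="us_xx")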
def setUp(self) -> None:
    self.database_key = SQLAlchemyDatabaseKey.canonical_for_schema(SchemaType.STATE)
    local_postgres_helpers.use_on_disk_postgresql_database(self.database_key)

    # State persistence ends up having to instantiate the us_nd_controller to
    # get enum overrides, and the controller goes on to create bigquery,
    # storage, and tasks clients.
    self.bq_client_patcher = patch("google.cloud.bigquery.Client")
    self.storage_client_patcher = patch("google.cloud.storage.Client")
    self.task_client_patcher = patch("google.cloud.tasks_v2.CloudTasksClient")

    self.bq_client_patcher.start()
    self.storage_client_patcher.start()
    self.task_client_patcher.start()
def main(schema_type: SchemaType, message: str, use_local_db: bool) -> None:
    """Runs the script to autogenerate migrations."""
    database_key = SQLAlchemyDatabaseKey.canonical_for_schema(schema_type)
    if use_local_db:
        # TODO(#4619): We should eventually move this from a local postgres
        #  instance to running postgres from a docker container.
        if not local_postgres_helpers.can_start_on_disk_postgresql_database():
            logging.error(
                "pg_ctl is not installed, so the script cannot be run locally. "
                "--project-id must be specified to run against staging or production."
            )
            logging.error("Exiting...")
            sys.exit(1)
        logging.info("Starting local postgres database for autogeneration...")
        tmp_db_dir = local_postgres_helpers.start_on_disk_postgresql_database()
        original_env_vars = (
            local_postgres_helpers.update_local_sqlalchemy_postgres_env_vars()
        )
    else:
        # TODO(Recidiviz/zenhub-tasks#134): This code path will throw when pointed
        #  at staging because we haven't created valid read-only users there just yet.
        try:
            original_env_vars = SQLAlchemyEngineManager.update_sqlalchemy_env_vars(
                database_key=database_key, readonly_user=True
            )
        except ValueError as e:
            logging.warning("Error fetching SQLAlchemy credentials: %s", e)
            logging.warning(
                "Until readonly users are created, we cannot autogenerate "
                "migrations against staging."
            )
            logging.warning("See https://github.com/Recidiviz/zenhub-tasks/issues/134")
            sys.exit(1)

    try:
        config = alembic.config.Config(database_key.alembic_file)
        if use_local_db:
            upgrade(config, "head")
        revision(config, autogenerate=True, message=message)
    except Exception as e:
        logging.error("Automigration generation failed: %s", e)

    local_postgres_helpers.restore_local_env_vars(original_env_vars)
    if use_local_db:
        logging.info("Stopping local postgres database...")
        local_postgres_helpers.stop_and_clear_on_disk_postgresql_database(tmp_db_dir)
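# Hedged example (not part of the original script): invoking the autogeneration
# entry point directly against a local throwaway postgres. The migration message
# is a placeholder; the real script presumably gathers these values from CLI args.
if __name__ == "__main__":
    main(SchemaType.STATE, message="example_migration_message", use_local_db=True)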
def _commit_person(
    person: SchemaPersonType,
    system_level: SystemLevel,
    ingest_time: datetime.datetime,
):
    db_key = SQLAlchemyDatabaseKey.canonical_for_schema(system_level.schema_type())
    with SessionFactory.using_database(db_key) as act_session:
        merged_person = act_session.merge(person)

        metadata = IngestMetadata(
            region="somewhere",
            jurisdiction_id="12345",
            ingest_time=ingest_time,
            system_level=system_level,
            database_key=db_key,
        )
        update_historical_snapshots(act_session, [merged_person], [], metadata)
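# Hedged usage sketch: assumes `person` is a schema person object built elsewhere
# in the test (e.g. a state schema StatePerson); the ingest time is illustrative.
_commit_person(
    person=person,
    system_level=SystemLevel.STATE,
    ingest_time=datetime.datetime(2020, 1, 1),
)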
def _get_old_enum_values(schema_type: SchemaType, enum_name: str) -> List[str]:
    """Fetches the current enum values for the given schema and enum name."""
    # Setup temp pg database
    db_dir = local_postgres_helpers.start_on_disk_postgresql_database()
    database_key = SQLAlchemyDatabaseKey.canonical_for_schema(schema_type)
    overridden_env_vars = (
        local_postgres_helpers.update_local_sqlalchemy_postgres_env_vars()
    )
    engine = create_engine(local_postgres_helpers.postgres_db_url_from_env_vars())

    try:
        # Fetch enums
        default_config = {
            "file": database_key.alembic_file,
            "script_location": database_key.migrations_location,
        }
        with runner(default_config, engine) as r:
            r.migrate_up_to("head")
            conn = engine.connect()
            rows = conn.execute(
                f"""
                SELECT e.enumlabel as enum_value
                FROM pg_type t
                JOIN pg_enum e ON t.oid = e.enumtypid
                JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
                WHERE n.nspname = 'public' AND t.typname = '{enum_name}';
                """
            )
            enums = [row[0] for row in rows]
    finally:
        # Teardown temp pg database
        local_postgres_helpers.restore_local_env_vars(overridden_env_vars)
        local_postgres_helpers.stop_and_clear_on_disk_postgresql_database(db_dir)

    return enums
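# Hedged usage sketch: the enum type name below is illustrative, not taken from
# the original source; callers pass whichever postgres enum the migration touches.
old_values = _get_old_enum_values(SchemaType.STATE, "state_charge_status")
print(old_values)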
def tearDown(self) -> None:
    local_postgres_helpers.teardown_on_disk_postgresql_database(
        SQLAlchemyDatabaseKey.canonical_for_schema(SchemaType.STATE)
    )
def setUp(self) -> None:
    local_postgres_helpers.use_on_disk_postgresql_database(
        SQLAlchemyDatabaseKey.canonical_for_schema(SchemaType.STATE)
    )
def _assert_expected_snapshots_for_schema_object(
    self,
    expected_schema_object: DatabaseEntity,
    ingest_times: List[datetime.date],
) -> None:
    """Assert that we have expected history snapshots for the given schema
    object that has been ingested at the provided |ingest_times|.
    """
    history_table_class = historical_table_class_from_obj(expected_schema_object)
    self.assertIsNotNone(history_table_class)

    schema_obj_primary_key_col_name = primary_key_name_from_obj(expected_schema_object)
    schema_obj_primary_key_value = primary_key_value_from_obj(expected_schema_object)

    self.assertIsNotNone(schema_obj_primary_key_value)
    self.assertEqual(type(schema_obj_primary_key_value), int)

    schema_obj_foreign_key_column_in_history_table = getattr(
        history_table_class, schema_obj_primary_key_col_name, None
    )
    self.assertIsNotNone(schema_obj_foreign_key_column_in_history_table)

    with SessionFactory.using_database(
        SQLAlchemyDatabaseKey.canonical_for_schema(
            schema_type_for_object(expected_schema_object)
        ),
        autocommit=False,
    ) as assert_session:
        history_snapshots: List[DatabaseEntity] = (
            assert_session.query(history_table_class)
            .filter(
                schema_obj_foreign_key_column_in_history_table
                == schema_obj_primary_key_value
            )
            .all()
        )

        self.assertEqual(
            len(history_snapshots),
            len(ingest_times),
            f"History snapshots do not correspond to ingest times "
            f"for object of type [{expected_schema_object.__class__}]",
        )

        self.assertTrue(
            all(isinstance(s, HistoryTableSharedColumns) for s in history_snapshots)
        )

        def as_history_cols(snapshot: DatabaseEntity) -> HistoryTableSharedColumns:
            if not isinstance(snapshot, HistoryTableSharedColumns):
                self.fail(
                    f"Snapshot class [{type(snapshot)}] must be a "
                    f"subclass of [{HistoryTableSharedColumns.__name__}]"
                )
            return snapshot

        history_snapshots.sort(
            key=lambda snapshot: as_history_cols(snapshot).valid_from
        )

        for i, history_snapshot in enumerate(history_snapshots):
            expected_valid_from = ingest_times[i]
            expected_valid_to = (
                ingest_times[i + 1] if i < len(ingest_times) - 1 else None
            )
            self.assertEqual(
                expected_valid_from, as_history_cols(history_snapshot).valid_from
            )
            self.assertEqual(
                expected_valid_to, as_history_cols(history_snapshot).valid_to
            )

        last_history_snapshot = history_snapshots[-1]
        assert last_history_snapshot is not None

        self._assert_schema_object_and_historical_snapshot_match(
            expected_schema_object, last_history_snapshot
        )
def main(
    schema_type: SchemaType,
    repo_root: str,
    ssl_cert_path: str,
    dry_run: bool,
    skip_db_name_check: bool,
    confirm_hash: Optional[str],
) -> None:
    """Invokes the main code path for running migrations.

    This checks for user validations that the database and branches are correct
    and then runs existing pending migrations.
    """
    if dry_run:
        if not local_postgres_helpers.can_start_on_disk_postgresql_database():
            logging.error("pg_ctl is not installed. Cannot perform a dry-run.")
            logging.error("Exiting...")
            sys.exit(1)
        logging.info("Creating a dry-run...\n")
    else:
        if not ssl_cert_path:
            logging.error(
                "SSL certificates are required when running against live databases"
            )
            logging.error("Exiting...")
            sys.exit(1)
        logging.info("Using SSL certificate path: %s", ssl_cert_path)

    is_prod = metadata.project_id() == GCP_PROJECT_PRODUCTION
    if is_prod:
        logging.info("RUNNING AGAINST PRODUCTION\n")

    if not skip_db_name_check:
        confirm_correct_db_instance(schema_type)
    confirm_correct_git_branch(repo_root, confirm_hash=confirm_hash)

    if dry_run:
        db_keys = [SQLAlchemyDatabaseKey.canonical_for_schema(schema_type)]
    else:
        db_keys = [
            key
            for key in SQLAlchemyDatabaseKey.all()
            if key.schema_type == schema_type
        ]

    # Run migrations
    for key in db_keys:
        if dry_run:
            overridden_env_vars = (
                local_postgres_helpers.update_local_sqlalchemy_postgres_env_vars()
            )
        else:
            overridden_env_vars = SQLAlchemyEngineManager.update_sqlalchemy_env_vars(
                database_key=key,
                ssl_cert_path=ssl_cert_path,
                migration_user=True,
            )
        try:
            logging.info(
                "*** Starting postgres migrations for schema [%s], db_name [%s] ***",
                key.schema_type,
                key.db_name,
            )
            if dry_run:
                db_dir = local_postgres_helpers.start_on_disk_postgresql_database()
            config = alembic.config.Config(key.alembic_file)
            alembic.command.upgrade(config, "head")
        except Exception as e:
            logging.error("Migrations failed to run: %s", e)
            sys.exit(1)
        finally:
            local_postgres_helpers.restore_local_env_vars(overridden_env_vars)
            if dry_run:
                try:
                    logging.info("Stopping local postgres database")
                    local_postgres_helpers.stop_and_clear_on_disk_postgresql_database(
                        db_dir
                    )
                except Exception as e2:
                    logging.error("Error cleaning up postgres: %s", e2)
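# Hedged example (not from the original script): a dry run applies all pending
# migrations for the canonical STATE database against a temporary local postgres.
# The repo path is a placeholder; the real script presumably derives these values
# from command-line arguments.
main(
    schema_type=SchemaType.STATE,
    repo_root="/path/to/repo",
    ssl_cert_path="",
    dry_run=True,
    skip_db_name_check=True,
    confirm_hash=None,
)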
    OVERALL_THRESHOLD,
)
from recidiviz.tools.postgres import local_postgres_helpers

EXTERNAL_ID = "EXTERNAL_ID"
EXTERNAL_ID_2 = "EXTERNAL_ID_2"
FULL_NAME_1 = "TEST_FULL_NAME_1"
STATE_CODE = "US_ND"
COUNTY_CODE = "COUNTY"
DEFAULT_METADATA = IngestMetadata(
    region="us_nd",
    jurisdiction_id="12345678",
    system_level=SystemLevel.STATE,
    ingest_time=datetime(year=1000, month=1, day=1),
    database_key=SQLAlchemyDatabaseKey.canonical_for_schema(
        schema_type=SchemaType.STATE
    ),
)
ID_TYPE = "ID_TYPE"
ID = 1
ID_2 = 2
ID_3 = 3
ID_4 = 4
SENTENCE_GROUP_ID = "SG1"
SENTENCE_GROUP_ID_2 = "SG2"
SENTENCE_GROUP_ID_3 = "SG3"
SENTENCE_GROUP_ID_4 = "SG4"
STATE_ERROR_THRESHOLDS_WITH_FORTY_PERCENT_RATIOS = {
    SystemLevel.STATE: {
        OVERALL_THRESHOLD: 0.4,