def test_get_default_db_engine_from_name(self,
                                         mock_engine_from_db_facts,
                                         mock_CredsViaLastPass,
                                         mock_os):
    session = Session(session_type='lpass',
                      default_db_creds_name='foo',
                      default_aws_creds_name=None)
    mock_creds = mock_CredsViaLastPass.return_value
    out = session.get_default_db_engine()
    self.assertEqual(out, mock_engine_from_db_facts.return_value)
    mock_engine_from_db_facts.assert_called_with(
        mock_creds.default_db_facts.return_value)
def test_get_default_db_engine_no_default(
        self,
        mock_engine_from_db_facts,
        mock_db_facts_from_env,
        mock_creds_via_env_os,
        mock_os,
        mock_get_config,
        mock_google_auth_default,
        mock_google_cloud_storage_Client):
    session = Session()
    self.assertEqual(session.get_default_db_engine(),
                     mock_engine_from_db_facts.return_value)
    mock_db_facts_from_env.assert_called_with()
    mock_db_facts = mock_db_facts_from_env.return_value
    mock_engine_from_db_facts.assert_called_with(mock_db_facts)
def test_file_url(self,
                  mock_UrlResolver,
                  mock_creds_via_env_os,
                  mock_os,
                  mock_get_config,
                  mock_google_auth_default,
                  mock_google_cloud_storage_Client):
    mock_credentials = Mock(name='credentials')
    mock_project = Mock(name='project')
    mock_google_auth_default.return_value = (mock_credentials, mock_project)
    session = Session(scratch_s3_url='s3://bar/baz')
    self.assertEqual(session.file_url('s3://bar/baz'),
                     mock_UrlResolver.return_value.file_url.return_value)
def test_get_db_engine_use_session_creds(
        self,
        mock_engine_from_db_facts,
        mock_os,
        mock_get_config,
        mock_google_auth_default,
        mock_google_cloud_storage_Client):
    mock_db_creds_name = Mock(name='db_creds_name')
    mock_creds = Mock(name='creds')
    session = Session(creds=mock_creds)
    out = session.get_db_engine(mock_db_creds_name)
    mock_creds.db_facts.assert_called_with(mock_db_creds_name)
    mock_engine_from_db_facts.assert_called_with(
        mock_creds.db_facts.return_value)
    self.assertEqual(mock_engine_from_db_facts.return_value, out)
def test_env_type_uses_creds_via_env(self,
                                     mock_CredsViaEnv,
                                     mock_os,
                                     mock_get_config,
                                     mock_google_auth_default,
                                     mock_google_cloud_storage_Client):
    mock_creds = mock_CredsViaEnv.return_value
    session = Session(session_type='env')
    self.assertEqual(session.creds, mock_creds)
def test_db_driver(self,
                   mock_db_driver,
                   mock_UrlResolver,
                   mock_os,
                   mock_get_config,
                   mock_google_auth_default,
                   mock_google_cloud_storage_Client):
    mock_creds = Mock(name='creds')
    mock_db = Mock(name='db')
    mock_url_resolver = mock_UrlResolver.return_value
    mock_scratch_s3_url = mock_creds.default_scratch_s3_url.return_value
    mock_s3_temp_base_loc = mock_url_resolver.directory_url.return_value
    session = Session(creds=mock_creds)
    out = session.db_driver(mock_db)
    self.assertEqual(out, mock_db_driver.return_value)
    mock_url_resolver.directory_url.assert_called_with(mock_scratch_s3_url)
    mock_db_driver.assert_called_with(db=mock_db,
                                      url_resolver=mock_url_resolver,
                                      s3_temp_base_loc=mock_s3_temp_base_loc)
def run_records_mover_job(source_method_name: str,
                          target_method_name: str,
                          job_name: str,
                          config: JobConfig) -> MoveResult:
    session = Session()
    try:
        source_method = getattr(session.records.sources, source_method_name)
        target_method = getattr(session.records.targets, target_method_name)
        logger.info('Starting...')
        source_kwargs = config_to_args(config=config['source'],
                                       method=source_method,
                                       session=session)
        target_kwargs = config_to_args(config=config['target'],
                                       method=target_method,
                                       session=session)
        pi_config = {k: config[k]
                     for k in config
                     if k not in ['source', 'target', 'func']}
        pi_kwargs = config_to_args(pi_config,
                                   method=ProcessingInstructions,
                                   session=session)
        processing_instructions = ProcessingInstructions(**pi_kwargs)
        records = session.records
        source = source_method(**source_kwargs)
        target = target_method(**target_kwargs)
        return records.move(source, target, processing_instructions)
    except Exception:
        logger.error('', exc_info=True)
        raise
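# Illustration only: a minimal sketch of the config shape run_records_mover_job()
# consumes -- a dict with 'source' and 'target' sub-dicts (whose keys become
# keyword arguments to the chosen factory methods) plus top-level keys that are
# folded into ProcessingInstructions.  The specific key names below are
# assumptions, not a documented schema.
example_config = {
    'source': {'url': 'file:///tmp/input.csv'},   # assumed kwargs for sources.data_url
    'target': {'db_name': 'mydb',                 # assumed kwargs for targets.table
               'schema_name': 'public',
               'table_name': 'itest_example'},
    'fail_if_cant_handle_hint': True,             # becomes a ProcessingInstructions kwarg
}

result = run_records_mover_job(source_method_name='data_url',
                               target_method_name='table',
                               job_name='url2table',
                               config=example_config)
print(result.move_count)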
def test_session_boto3_session_via_url_resolver_default(
        self,
        mock_boto3_session,
        mock_os,
        mock_get_config,
        mock_google_auth_default,
        mock_google_cloud_storage_Client):
    session = Session()
    boto3_session = session.url_resolver.boto3_session_getter()
    self.assertEqual(boto3_session, mock_boto3_session.Session.return_value)
def test_set_stream_logging(self,
                            mock_set_stream_logging,
                            mock_os,
                            mock_get_config,
                            mock_google_auth_default,
                            mock_google_cloud_storage_Client):
    session = Session()
    mock_name = Mock(name='name')
    mock_level = Mock(name='level')
    mock_stream = Mock(name='stream')
    mock_fmt = Mock(name='fmt')
    mock_datefmt = Mock(name='datefmt')
    session.set_stream_logging(name=mock_name,
                               level=mock_level,
                               stream=mock_stream,
                               fmt=mock_fmt,
                               datefmt=mock_datefmt)
    mock_set_stream_logging.assert_called_with(name=mock_name,
                                               level=mock_level,
                                               stream=mock_stream,
                                               fmt=mock_fmt,
                                               datefmt=mock_datefmt)
def test_session_gcp_creds_via_url_resolver_default(
        self,
        mock_creds_via_env_os,
        mock_os,
        mock_get_config,
        mock_google_auth_default,
        mock_google_cloud_storage_Client):
    mock_credentials = Mock(name='credentials')
    mock_project = Mock(name='project')
    mock_google_auth_default.return_value = (mock_credentials, mock_project)
    session = Session()
    gcp_credentials = session.url_resolver.gcp_credentials_getter()
    self.assertEqual(gcp_credentials, mock_credentials)
def test_session_boto3_session_via_url_resolver_cached(
        self,
        mock_boto3_session,
        mock_creds_via_env_os,
        mock_os,
        mock_get_config,
        mock_google_auth_default,
        mock_google_cloud_storage_Client):
    session = Session()
    boto3_session = session.url_resolver.boto3_session_getter()
    self.assertEqual(boto3_session, mock_boto3_session.Session.return_value)
    second_boto3_session = session.url_resolver.boto3_session_getter()
    self.assertEqual(second_boto3_session,
                     mock_boto3_session.Session.return_value)
    mock_boto3_session.Session.assert_called_once_with()
def test_records(self,
                 mock_Records,
                 mock_google_auth_default,
                 mock_google_cloud_storage_Client):
    mock_credentials = Mock(name='credentials')
    mock_project = Mock(name='project')
    mock_google_auth_default.return_value = (mock_credentials, mock_project)
    session = Session(session_type='cli',
                      default_db_creds_name=None,
                      default_aws_creds_name=None,
                      default_gcp_creds_name=None)
    self.assertEqual(mock_Records.return_value, session.records)
    mock_Records.assert_called_with(db_driver=ANY, url_resolver=ANY)
def test_s3_url_from_get_config(self,
                                mock_set_stream_logging,
                                mock_os,
                                mock_get_config,
                                mock_google_auth_default,
                                mock_google_cloud_storage_Client):
    mock_os.environ = {}
    mock_config_result = mock_get_config.return_value
    mock_config_result.config = {
        'aws': {
            's3_scratch_url': 's3://foundit/'
        }
    }
    session = Session()
    self.assertEqual(session.creds.default_scratch_s3_url(), 's3://foundit/')
def move_and_verify(self, source_dbname: str, target_dbname: str) -> None:
    session = Session()
    records = session.records
    targets = records.targets
    sources = records.sources
    source_engine = session.get_db_engine(source_dbname)
    target_engine = session.get_db_engine(target_dbname)
    source_schema_name = schema_name(source_dbname)
    target_schema_name = schema_name(target_dbname)
    source_table_name = f'itest_source_{BUILD_NUM}_{CURRENT_EPOCH}'
    records_database_fixture = RecordsDatabaseFixture(source_engine,
                                                      source_schema_name,
                                                      source_table_name)
    records_database_fixture.tear_down()
    records_database_fixture.bring_up()
    existing = ExistingTableHandling.DROP_AND_RECREATE
    source = sources.table(schema_name=source_schema_name,
                           table_name=source_table_name,
                           db_engine=source_engine)
    target = targets.table(schema_name=target_schema_name,
                           table_name=TARGET_TABLE_NAME,
                           db_engine=target_engine,
                           existing_table_handling=existing)
    out = records.move(source, target)
    # Redshift doesn't give reliable info on load results, so this
    # will be None or 1:
    self.assertNotEqual(0, out.move_count)
    validator = RecordsTableValidator(target_engine,
                                      source_db_engine=source_engine)
    validator.validate(schema_name=target_schema_name,
                       table_name=TARGET_TABLE_NAME)
    quoted_target = quote_schema_and_table(target_engine,
                                           target_schema_name,
                                           TARGET_TABLE_NAME)
    sql = f"DROP TABLE {quoted_target}"
    target_engine.execute(sql)
    records_database_fixture.tear_down()
def purge_old_tables(schema_name: str, table_name_prefix: str,
                     db_name: Optional[str] = None) -> None:
    session = Session()
    if db_name is None:
        db_engine = session.get_default_db_engine()
    else:
        db_engine = session.get_db_engine(db_name)
    inspector = inspect(db_engine)
    table_names = inspector.get_table_names(schema=schema_name)
    print(f"All table names in {schema_name}: {table_names}")
    purgable_table_names = [
        table_name for table_name in table_names
        if table_name.startswith(f"{table_name_prefix}_") and is_old(table_name)
    ]
    print(f"Tables to purge matching {schema_name}.{table_name_prefix}_: "
          f"{purgable_table_names}")
    for table_name in purgable_table_names:
        sql = f"DROP TABLE {quote_schema_and_table(db_engine, schema_name, table_name)}"
        print(sql)
        db_engine.execute(sql)
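# is_old() is referenced above but not included in this section.  Purely as an
# assumption about how such a check could work: the fixture table names created
# in setUp() end in an epoch timestamp (f"{prefix}{build_num}_{epoch}"), so age
# can be derived from that suffix.  This is a hypothetical sketch, not the
# project's actual helper.
import time

def is_old_example(table_name: str, max_age_seconds: int = 6 * 60 * 60) -> bool:
    """Hypothetical helper: a table is old if its trailing epoch suffix is
    more than max_age_seconds in the past; names without a numeric suffix
    are left alone."""
    suffix = table_name.rsplit('_', 1)[-1]
    try:
        created_at = int(suffix)
    except ValueError:
        return False
    return (time.time() - created_at) > max_age_seconds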
def test_session_gcs_client_via_url_resolver_cached(
        self,
        mock_os,
        mock_get_config,
        mock_google_auth_default,
        mock_google_cloud_storage_Client):
    mock_credentials = Mock(name='credentials')
    mock_project = Mock(name='project')
    mock_google_auth_default.return_value = (mock_credentials, mock_project)
    session = Session()
    gcs_client = session.url_resolver.gcs_client_getter()
    self.assertEqual(gcs_client, mock_google_cloud_storage_Client.return_value)
    second_gcs_client = session.url_resolver.gcs_client_getter()
    self.assertEqual(second_gcs_client,
                     mock_google_cloud_storage_Client.return_value)
    mock_google_auth_default.assert_called_once_with()
def test_session_gcp_creds_via_url_resolver_cached(
        self,
        mock_creds_via_env_os,
        mock_os,
        mock_get_config,
        mock_google_auth_default,
        mock_google_cloud_storage_Client):
    mock_credentials = Mock(name='credentials')
    mock_project = Mock(name='project')
    mock_google_auth_default.return_value = (mock_credentials, mock_project)
    session = Session()
    gcp_credentials = session.url_resolver.gcp_credentials_getter()
    self.assertEqual(gcp_credentials, mock_credentials)
    second_gcp_credentials = session.url_resolver.gcp_credentials_getter()
    self.assertEqual(second_gcp_credentials, mock_credentials)
    expected_scopes = ('https://www.googleapis.com/auth/devstorage.full_control',
                       'https://www.googleapis.com/auth/devstorage.read_only',
                       'https://www.googleapis.com/auth/devstorage.read_write')
    mock_google_auth_default.assert_called_once_with(scopes=expected_scopes)
def purge_old_test_sheets(cred_name: str, spreadsheet_id: str) -> None:
    session = Session()
    creds = session.creds
    gsheet_creds = creds.google_sheets(cred_name)
    http_authorized = authorize(gsheet_creds)
    service = build('sheets', 'v4', http=http_authorized)
    sheet_ids = find_old_test_sheets(service, spreadsheet_id)
    for sheet_id in sheet_ids:
        try:
            delete_sheet_by_id(service, spreadsheet_id, sheet_id)
        except googleapiclient.errors.HttpError:
            print(f"Could not delete sheet {sheet_id} "
                  "(another purger running at the same time?)")
def test_session_boto3_session_via_url_resolver_specified(
        self,
        mock_boto3_session,
        mock_CredsViaEnv,
        mock_os,
        mock_get_config,
        mock_google_auth_default,
        mock_google_cloud_storage_Client):
    mock_default_aws_creds_name = Mock(name='default_aws_creds_name')
    session = Session(default_aws_creds_name=mock_default_aws_creds_name,
                      session_type='env')
    boto3_session = session.url_resolver.boto3_session_getter()
    self.assertEqual(
        boto3_session,
        mock_CredsViaEnv.return_value.default_boto3_session.return_value)
    mock_CredsViaEnv.assert_called_with(
        default_db_creds_name=None,
        default_aws_creds_name=mock_default_aws_creds_name,
        default_gcp_creds_name=None,
        default_db_facts=PleaseInfer.token,
        default_boto3_session=PleaseInfer.token,
        default_gcp_creds=PleaseInfer.token,
        default_gcs_client=PleaseInfer.token,
        scratch_s3_url=PleaseInfer.token)
    mock_CredsViaEnv.return_value.default_boto3_session.assert_called()
def __init__(self,
             db_driver: Union[Callable[[Union['Engine', 'Connection']], 'DBDriver'],
                              PleaseInfer] = PleaseInfer.token,
             url_resolver: Union[UrlResolver, PleaseInfer] = PleaseInfer.token,
             session: Union['Session', PleaseInfer] = PleaseInfer.token) -> None:
    if db_driver is PleaseInfer.token or url_resolver is PleaseInfer.token:
        if session is PleaseInfer.token:
            from records_mover import Session  # noqa
            session = Session()
        if db_driver is PleaseInfer.token:
            db_driver = session.db_driver
        if url_resolver is PleaseInfer.token:
            url_resolver = session.url_resolver
    self.move = move  # type: ignore
    self.sources = RecordsSources(db_driver=db_driver,
                                  url_resolver=url_resolver)
    self.targets = RecordsTargets(url_resolver=url_resolver,
                                  db_driver=db_driver)
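# A minimal usage sketch of the constructor above, assuming database
# credentials are resolvable from the environment; the schema and table
# names are placeholders, not part of the API.
from records_mover import Session

session = Session()
records = session.records  # built via the __init__ above
db_engine = session.get_default_db_engine()

source = records.sources.table(schema_name='public',
                               table_name='my_source_table',
                               db_engine=db_engine)
target = records.targets.directory_from_url(output_url='file:///tmp/records_out/')
records.move(source, target)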
def test_records_with_overridden_scratch_bucket(self, mock_Records):
    session = Session(session_type='cli',
                      default_db_creds_name=None,
                      default_aws_creds_name=None)
    self.assertEqual(session.creds.default_scratch_s3_url(),
                     's3://different-scratch-bucket/')
def test_creds(self, mock_CredsViaLastPass, mock_os):
    session = Session(session_type='lpass',
                      default_db_creds_name=None,
                      default_aws_creds_name=None)
    self.assertEqual(mock_CredsViaLastPass.return_value, session.creds)
class BaseRecordsIntegrationTest(unittest.TestCase):
    def setUp(self):
        # Ensure we're not getting any DWIM behavior out of the CLI
        # session:
        os.environ['RECORDS_MOVER_SESSION_TYPE'] = 'itest'
        self.resources_dir = os.path.dirname(
            os.path.abspath(__file__)) + '/../../resources'
        self.session = Session(session_type='env',
                               default_db_creds_name=None,
                               default_aws_creds_name=None)
        self.engine = self.session.get_default_db_engine()
        self.driver = self.session.db_driver(self.engine)
        if self.engine.name == 'bigquery':
            self.schema_name = 'bq_itest'  # avoid per-table rate limits
        elif self.engine.name == 'mysql':
            self.schema_name = 'mysqlitest'
        else:
            self.schema_name = 'public'
        table_name_prefix = "itest_"
        build_num = os.environ.get("CIRCLE_BUILD_NUM", "local")
        current_epoch = int(time.time())
        self.table_name = f"{table_name_prefix}{build_num}_{current_epoch}"
        self.fixture = RecordsDatabaseFixture(self.engine,
                                              schema_name=self.schema_name,
                                              table_name=self.table_name)
        self.fixture.tear_down()
        purge_old_tables(self.schema_name, table_name_prefix)
        logger.debug("Initialized class!")
        self.meta = MetaData()
        self.records = self.session.records

    def tearDown(self):
        self.session = None
        self.fixture.tear_down()

    def table(self, schema, table):
        return Table(table, self.meta, schema=schema,
                     autoload=True, autoload_with=self.engine)

    def variant_has_header(self, variant):
        return variant in ['csv', 'bigquery']

    def resource_name(self, format_type, variant, hints):
        if hints.get('header-row', self.variant_has_header(variant)):
            return f"{format_type}-{variant}-with-header"
        else:
            return f"{format_type}-{variant}-no-header"

    def has_scratch_s3_bucket(self):
        return os.environ.get('SCRATCH_S3_URL') is not None

    def has_scratch_gcs_bucket(self):
        return os.environ.get('SCRATCH_GCS_URL') is not None

    def has_pandas(self):
        try:
            import pandas  # noqa
            logger.info("Just imported pandas")
            return True
        except ModuleNotFoundError:
            logger.info("Could not find pandas")
            return False

    def unload_column_to_string(self,
                                column_name: str,
                                records_format: BaseRecordsFormat) -> str:
        targets = self.records.targets
        sources = self.records.sources
        with tempfile.TemporaryDirectory() as directory_name:
            source = sources.table(schema_name=self.schema_name,
                                   table_name=self.table_name,
                                   db_engine=self.engine)
            directory_url = pathlib.Path(directory_name).as_uri() + '/'
            target = targets.directory_from_url(output_url=directory_url,
                                                records_format=records_format)
            self.records.move(source, target)
            directory_loc = self.session.directory_url(directory_url)
            records_dir = RecordsDirectory(records_loc=directory_loc)
            with tempfile.NamedTemporaryFile() as t:
                output_url = pathlib.Path(t.name).as_uri()
                output_loc = self.session.file_url(output_url)
                records_dir.save_to_url(output_loc)
                return output_loc.string_contents()
def main() -> None:
    # https://github.com/googleapis/google-auth-library-python/issues/271
    import warnings
    warnings.filterwarnings("ignore",
                            "Your application has authenticated using end user credentials")
    # Skip in-memory sources/targets like dataframes that don't make
    # sense from the command-line:
    source_method_name_by_cli_name = {
        'table': 'table',
        'gsheet': 'google_sheet',
        'recordsdir': 'directory_from_url',
        'url': 'data_url',
        'file': 'local_file'
    }
    target_method_name_by_cli_name = {
        'gsheet': 'google_sheet',
        'table': 'table',
        'recordsdir': 'directory_from_url',
        'url': 'data_url',
        'file': 'local_file',
        'spectrum': 'spectrum',
    }
    sources = source_method_name_by_cli_name.keys()
    targets = target_method_name_by_cli_name.keys()
    description = 'Move tabular data ("records") from one place to another'
    parser = argparse.ArgumentParser(description=description,
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    pi_config_schema =\
        method_signature_to_json_schema(ProcessingInstructions.__init__,
                                        special_handling={},
                                        parameters_to_ignore=['self'])
    JobConfigSchemaAsArgsParser(config_json_schema=pi_config_schema,
                                argument_parser=parser).configure_arg_parser()
    # https://stackoverflow.com/questions/15405636/pythons-argparse-to-show-programs-version-with-prog-and-version-string-formatt
    parser.add_argument('-V', '--version',
                        action='version',
                        version="%(prog)s (" + __version__ + ")")
    subparsers = parser.add_subparsers(help='subcommand_help')
    from records_mover import Session
    bootstrap_session = Session()
    for source in sources:
        for target in targets:
            name = f"{source}2{target}"
            sub_parser = subparsers.add_parser(name,
                                               help=f"Copy from {source} to {target}")
            source_method_name = source_method_name_by_cli_name[source]
            target_method_name = target_method_name_by_cli_name[target]
            job_config_schema = \
                populate_subparser(bootstrap_session,
                                   sub_parser,
                                   source_method_name,
                                   target_method_name,
                                   subjob_name=name)
            sub_parser.set_defaults(func=make_job_fn(source_method_name=source_method_name,
                                                     target_method_name=target_method_name,
                                                     name=name,
                                                     job_config_schema=job_config_schema))
    args = parser.parse_args()
    raw_config = vars(args)
    func = getattr(args, 'func', None)
    if func is None:
        parser.print_help()
    else:
        set_stream_logging()
        try:
            func(raw_config)
        except Exception:
            # This is logged above using a redacting logger
            sys.exit(1)
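# For orientation, a few hypothetical invocations of the parser built above.
# The console-script name "mvrec" is an assumption; the subcommand names follow
# the f"{source}2{target}" pattern generated in main():
#
#   mvrec --version
#   mvrec file2table --help
#   mvrec table2recordsdir --help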