示例#1
0
    def setUp(self):
        """Create an 'env' session, choose a schema and a unique table name.

        Before the test runs this also tears down any fixture table left over
        under the same name and calls ``purge_old_tables`` with the
        ``"itest_"`` prefix.
        """
        # Pin the session type so DWIM behavior of the CLI session cannot
        # leak into the integration tests.
        os.environ['RECORDS_MOVER_SESSION_TYPE'] = 'itest'

        this_dir = os.path.dirname(os.path.abspath(__file__))
        self.resources_dir = this_dir + '/../../resources'
        self.session = Session(session_type='env',
                               default_db_creds_name=None,
                               default_aws_creds_name=None)
        self.engine = self.session.get_default_db_engine()
        self.driver = self.session.db_driver(self.engine)

        # Engines without an entry here fall back to the 'public' schema.
        schema_for_engine = {
            'bigquery': 'bq_itest',
            'mysql': 'mysqlitest',
        }
        self.schema_name = schema_for_engine.get(self.engine.name, 'public')

        # Unique per CI build number and per second of wall-clock time.
        prefix = "itest_"
        build_label = os.environ.get("CIRCLE_BUILD_NUM", "local")
        started_at = int(time.time())
        self.table_name = f"{prefix}{build_label}_{started_at}"

        self.fixture = RecordsDatabaseFixture(self.engine,
                                              schema_name=self.schema_name,
                                              table_name=self.table_name)
        self.fixture.tear_down()
        purge_old_tables(self.schema_name, prefix)

        logger.debug("Initialized class!")

        self.meta = MetaData()
        self.records = self.session.records
示例#2
0
 def test_get_default_db_engine_from_name(self, mock_engine_from_db_facts,
                                          mock_CredsViaLastPass, mock_os):
     """A named default db creds entry is resolved via the LastPass creds."""
     lastpass_creds = mock_CredsViaLastPass.return_value
     session = Session(session_type='lpass',
                       default_db_creds_name='foo',
                       default_aws_creds_name=None)
     engine = session.get_default_db_engine()
     self.assertEqual(engine, mock_engine_from_db_facts.return_value)
     expected_facts = lastpass_creds.default_db_facts.return_value
     mock_engine_from_db_facts.assert_called_with(expected_facts)
示例#3
0
 def test_get_default_db_engine_no_default(
         self, mock_engine_from_db_facts, mock_db_facts_from_env,
         mock_creds_via_env_os, mock_os, mock_get_config,
         mock_google_auth_default, mock_google_cloud_storage_Client):
     """With no db creds name given, the engine is built from env db facts."""
     session = Session()
     engine = session.get_default_db_engine()
     self.assertEqual(engine, mock_engine_from_db_facts.return_value)
     mock_db_facts_from_env.assert_called_with()
     env_db_facts = mock_db_facts_from_env.return_value
     mock_engine_from_db_facts.assert_called_with(env_db_facts)
示例#4
0
 def test_file_url(self, mock_UrlResolver, mock_creds_via_env_os, mock_os,
                   mock_get_config, mock_google_auth_default,
                   mock_google_cloud_storage_Client):
     """Session.file_url() returns whatever the UrlResolver's file_url gives."""
     gcp_credentials = Mock(name='credentials')
     gcp_project = Mock(name='project')
     mock_google_auth_default.return_value = (gcp_credentials, gcp_project)
     session = Session(scratch_s3_url='s3://bar/baz')
     resolved = session.file_url('s3://bar/baz')
     expected = mock_UrlResolver.return_value.file_url.return_value
     self.assertEqual(resolved, expected)
 def test_get_db_engine_use_sesssion_creds(
         self, mock_engine_from_db_facts, mock_os, mock_get_config,
         mock_google_auth_default, mock_google_cloud_storage_Client):
     """get_db_engine() looks db facts up on the creds given to Session()."""
     creds_name = Mock(name='db_creds_name')
     session_creds = Mock(name='creds')
     session = Session(creds=session_creds)
     engine = session.get_db_engine(creds_name)
     session_creds.db_facts.assert_called_with(creds_name)
     looked_up_facts = session_creds.db_facts.return_value
     mock_engine_from_db_facts.assert_called_with(looked_up_facts)
     self.assertEqual(mock_engine_from_db_facts.return_value, engine)
 def test_env_type_uses_creds_via_env(self, mock_CredsViaEnv, mock_os,
                                      mock_get_config,
                                      mock_google_auth_default,
                                      mock_google_cloud_storage_Client):
     """session_type='env' makes session.creds the CredsViaEnv instance."""
     env_session = Session(session_type='env')
     self.assertEqual(env_session.creds, mock_CredsViaEnv.return_value)
 def test_db_driver(self, mock_db_driver, mock_UrlResolver, mock_os,
                    mock_get_config, mock_google_auth_default,
                    mock_google_cloud_storage_Client):
     """db_driver() passes the resolver and scratch S3 directory to the driver."""
     creds = Mock(name='creds')
     db = Mock(name='db')
     resolver = mock_UrlResolver.return_value
     session = Session(creds=creds)
     driver = session.db_driver(db)
     self.assertEqual(driver, mock_db_driver.return_value)
     scratch_url = creds.default_scratch_s3_url.return_value
     resolver.directory_url.assert_called_with(scratch_url)
     mock_db_driver.assert_called_with(
         db=db,
         url_resolver=resolver,
         s3_temp_base_loc=resolver.directory_url.return_value)
示例#8
0
def run_records_mover_job(source_method_name: str,
                          target_method_name: str,
                          job_name: str,
                          config: JobConfig) -> MoveResult:
    """Build a source and a target from ``config`` and move records between them.

    :param source_method_name: Name of a method on ``session.records.sources``.
    :param target_method_name: Name of a method on ``session.records.targets``.
    :param job_name: Accepted for the caller's convenience; this function
        body does not use it.
    :param config: Job configuration. ``config['source']`` and
        ``config['target']`` hold the source/target arguments. Every other key
        except ``'func'`` is treated as a ProcessingInstructions argument.
    :return: The result of ``records.move``.
    :raises Exception: Any failure is logged with its traceback, then re-raised.
    """
    session = Session()
    try:
        source_method = getattr(session.records.sources, source_method_name)
        target_method = getattr(session.records.targets, target_method_name)
        logger.info('Starting...')

        # Evaluated in order: source first, then target.
        source_kwargs, target_kwargs = (
            config_to_args(config=config[section], method=method, session=session)
            for section, method in (('source', source_method),
                                    ('target', target_method)))

        # Everything outside the source/target sections (and the argparse
        # 'func' entry) configures ProcessingInstructions.
        reserved_keys = ('source', 'target', 'func')
        pi_config = {key: value for key, value in config.items()
                     if key not in reserved_keys}
        processing_instructions = ProcessingInstructions(
            **config_to_args(pi_config,
                             method=ProcessingInstructions,
                             session=session))

        records = session.records
        source = source_method(**source_kwargs)
        target = target_method(**target_kwargs)
        return records.move(source, target, processing_instructions)
    except Exception:
        logger.error('', exc_info=True)
        raise
 def test_session_boto3_session_via_url_resolver_default(
         self, mock_boto3_session, mock_os, mock_get_config,
         mock_google_auth_default, mock_google_cloud_storage_Client):
     """With no AWS creds name, the resolver yields a default boto3 Session."""
     session = Session()
     getter = session.url_resolver.boto3_session_getter
     self.assertEqual(getter(), mock_boto3_session.Session.return_value)
 def test_set_stream_logging(self, mock_set_stream_logging, mock_os,
                             mock_get_config, mock_google_auth_default,
                             mock_google_cloud_storage_Client):
     """Session.set_stream_logging() forwards every keyword unchanged."""
     session = Session()
     logging_kwargs = {
         arg: Mock(name=arg)
         for arg in ('name', 'level', 'stream', 'fmt', 'datefmt')
     }
     session.set_stream_logging(**logging_kwargs)
     mock_set_stream_logging.assert_called_with(**logging_kwargs)
示例#11
0
 def test_session_gcp_creds_via_url_resolver_default(
         self, mock_creds_via_env_os, mock_os, mock_get_config,
         mock_google_auth_default, mock_google_cloud_storage_Client):
     """The resolver's GCP credentials are the ones google.auth.default gives."""
     credentials = Mock(name='credentials')
     project = Mock(name='project')
     mock_google_auth_default.return_value = (credentials, project)
     session = Session()
     resolved = session.url_resolver.gcp_credentials_getter()
     self.assertEqual(resolved, credentials)
示例#12
0
 def test_session_boto3_session_via_url_resolver_cached(
         self, mock_boto3_session, mock_creds_via_env_os, mock_os,
         mock_get_config, mock_google_auth_default,
         mock_google_cloud_storage_Client):
     """Repeated boto3 getter calls reuse a single boto3 Session."""
     session = Session()
     expected = mock_boto3_session.Session.return_value
     for _ in range(2):
         self.assertEqual(session.url_resolver.boto3_session_getter(),
                          expected)
     mock_boto3_session.Session.assert_called_once_with()
示例#13
0
 def test_records(self, mock_Records, mock_google_auth_default,
                  mock_google_cloud_storage_Client):
     """session.records is a Records built with a db driver and a resolver."""
     credentials = Mock(name='credentials')
     project = Mock(name='project')
     mock_google_auth_default.return_value = (credentials, project)
     cli_session = Session(session_type='cli',
                           default_db_creds_name=None,
                           default_aws_creds_name=None,
                           default_gcp_creds_name=None)
     self.assertEqual(mock_Records.return_value, cli_session.records)
     mock_Records.assert_called_with(db_driver=ANY, url_resolver=ANY)
 def test_s3_url_from_get_config(self, mock_set_stream_logging, mock_os,
                                 mock_get_config, mock_google_auth_default,
                                 mock_google_cloud_storage_Client):
     """With an empty environment, the scratch S3 URL comes from config."""
     mock_os.environ = {}
     aws_section = {'s3_scratch_url': 's3://foundit/'}
     mock_get_config.return_value.config = {'aws': aws_section}
     session = Session()
     scratch_url = session.creds.default_scratch_s3_url()
     self.assertEqual(scratch_url, 's3://foundit/')
    def move_and_verify(self, source_dbname: str, target_dbname: str) -> None:
        """Move a fixture table from one database to another and validate it.

        A fresh source table is brought up in ``source_dbname``, then moved
        into ``TARGET_TABLE_NAME`` in ``target_dbname`` (drop-and-recreate).
        The result is checked with ``RecordsTableValidator``, and the target
        table is dropped.

        The source fixture is torn down in a ``finally`` so that a failed
        move or validation does not leak the per-build/per-epoch source
        table. The target table uses a fixed name and is recreated by the
        next run, so it is only dropped on success.

        :param source_dbname: Creds name of the database to read from.
        :param target_dbname: Creds name of the database to write to.
        """
        session = Session()
        records = session.records
        targets = records.targets
        sources = records.sources
        source_engine = session.get_db_engine(source_dbname)
        target_engine = session.get_db_engine(target_dbname)
        source_schema_name = schema_name(source_dbname)
        target_schema_name = schema_name(target_dbname)
        source_table_name = f'itest_source_{BUILD_NUM}_{CURRENT_EPOCH}'
        records_database_fixture = RecordsDatabaseFixture(
            source_engine, source_schema_name, source_table_name)
        records_database_fixture.tear_down()
        records_database_fixture.bring_up()
        try:
            existing = ExistingTableHandling.DROP_AND_RECREATE
            source = sources.table(schema_name=source_schema_name,
                                   table_name=source_table_name,
                                   db_engine=source_engine)
            target = targets.table(schema_name=target_schema_name,
                                   table_name=TARGET_TABLE_NAME,
                                   db_engine=target_engine,
                                   existing_table_handling=existing)
            out = records.move(source, target)
            # redshift doesn't give reliable info on load results, so this
            # will be None or 1
            self.assertNotEqual(0, out.move_count)
            validator = RecordsTableValidator(target_engine,
                                              source_db_engine=source_engine)
            validator.validate(schema_name=target_schema_name,
                               table_name=TARGET_TABLE_NAME)

            quoted_target = quote_schema_and_table(target_engine,
                                                   target_schema_name,
                                                   TARGET_TABLE_NAME)
            sql = f"DROP TABLE {quoted_target}"
            target_engine.execute(sql)
        finally:
            # Always clean up the uniquely named source table, even if the
            # move or the validation above failed.
            records_database_fixture.tear_down()
示例#16
0
def purge_old_tables(schema_name: str,
                     table_name_prefix: str,
                     db_name: Optional[str] = None) -> None:
    """Drop old test tables from ``schema_name``.

    A table is dropped when its name starts with ``table_name_prefix``
    followed by exactly one underscore and ``is_old()`` returns true for it.

    :param schema_name: Schema to scan for tables.
    :param table_name_prefix: Prefix that identifies test tables. It may be
        passed with or without a trailing underscore: ``"itest"`` and
        ``"itest_"`` both select names starting with ``"itest_"``.
    :param db_name: Creds name of the database to use. When ``None``, the
        session's default database is used.
    """
    session = Session()
    if db_name is None:
        db_engine = session.get_default_db_engine()
    else:
        db_engine = session.get_db_engine(db_name)

    # Appending "_" unconditionally turned a prefix such as "itest_" into
    # "itest__". A generated name like "itest_local_1600000000" never starts
    # with that, so nothing was purged. Ensure exactly one separator instead.
    # NOTE(review): more names now reach is_old(); confirm it tolerates every
    # table name that can match the prefix.
    if table_name_prefix.endswith('_'):
        match_prefix = table_name_prefix
    else:
        match_prefix = f"{table_name_prefix}_"

    inspector = inspect(db_engine)
    table_names = inspector.get_table_names(schema=schema_name)
    print(f"All tables name in {schema_name}: {table_names}")
    purgable_table_names = [
        table_name for table_name in table_names
        if table_name.startswith(match_prefix) and is_old(table_name)
    ]
    print(
        f"Tables to purge matching {schema_name}.{match_prefix}: {purgable_table_names}"
    )
    for table_name in purgable_table_names:
        sql = f"DROP TABLE {quote_schema_and_table(db_engine, schema_name, table_name)}"
        print(sql)
        db_engine.execute(sql)
    def test_session_gcs_client_via_url_resolver_cached(
            self, mock_os, mock_get_config, mock_google_auth_default,
            mock_google_cloud_storage_Client):
        """The GCS client getter authenticates once and reuses the client."""
        credentials = Mock(name='credentials')
        project = Mock(name='project')
        mock_google_auth_default.return_value = (credentials, project)
        session = Session()
        expected_client = mock_google_cloud_storage_Client.return_value
        for _ in range(2):
            self.assertEqual(session.url_resolver.gcs_client_getter(),
                             expected_client)

        mock_google_auth_default.assert_called_once_with()
示例#18
0
    def test_session_gcp_creds_via_url_resolver_cached(
            self, mock_creds_via_env_os, mock_os, mock_get_config,
            mock_google_auth_default, mock_google_cloud_storage_Client):
        """GCP credentials are fetched once, with the devstorage scopes."""
        credentials = Mock(name='credentials')
        project = Mock(name='project')
        mock_google_auth_default.return_value = (credentials, project)
        session = Session()
        for _ in range(2):
            self.assertEqual(session.url_resolver.gcp_credentials_getter(),
                             credentials)

        devstorage_scopes = (
            'https://www.googleapis.com/auth/devstorage.full_control',
            'https://www.googleapis.com/auth/devstorage.read_only',
            'https://www.googleapis.com/auth/devstorage.read_write',
        )
        mock_google_auth_default.assert_called_once_with(
            scopes=devstorage_scopes)
示例#19
0
def purge_old_test_sheets(cred_name: str, spreadsheet_id: str) -> None:
    """Delete the sheets that ``find_old_test_sheets`` reports as old.

    Uses the Google Sheets creds named ``cred_name`` from a fresh Session.
    A sheet whose deletion fails with an ``HttpError`` is reported on stdout
    and skipped. Deletion is best-effort, since another purger may be
    running concurrently.

    :param cred_name: Name of the Google Sheets creds to use.
    :param spreadsheet_id: Spreadsheet whose old test sheets are removed.
    """
    mover_session = Session()
    sheets_creds = mover_session.creds.google_sheets(cred_name)
    sheets_service = build('sheets', 'v4', http=authorize(sheets_creds))

    for stale_sheet_id in find_old_test_sheets(sheets_service, spreadsheet_id):
        try:
            delete_sheet_by_id(sheets_service, spreadsheet_id, stale_sheet_id)
        except googleapiclient.errors.HttpError:
            print(
                f"Could not delete sheet {stale_sheet_id} (another purger running at the same time?)"
            )
 def test_session_boto3_session_via_url_resolver_specified(
         self, mock_boto3_session, mock_CredsViaEnv, mock_os,
         mock_get_config, mock_google_auth_default,
         mock_google_cloud_storage_Client):
     """An explicit AWS creds name is forwarded to CredsViaEnv for boto3."""
     aws_creds_name = Mock(name='default_aws_creds_name')
     session = Session(default_aws_creds_name=aws_creds_name,
                       session_type='env')
     env_creds = mock_CredsViaEnv.return_value
     self.assertEqual(session.url_resolver.boto3_session_getter(),
                      env_creds.default_boto3_session.return_value)
     mock_CredsViaEnv.assert_called_with(
         default_db_creds_name=None,
         default_aws_creds_name=aws_creds_name,
         default_gcp_creds_name=None,
         default_db_facts=PleaseInfer.token,
         default_boto3_session=PleaseInfer.token,
         default_gcp_creds=PleaseInfer.token,
         default_gcs_client=PleaseInfer.token,
         scratch_s3_url=PleaseInfer.token)
     env_creds.default_boto3_session.assert_called()
示例#21
0
    def __init__(
            self,
            db_driver: Union[Callable[[Union['Engine', 'Connection']],
                                      'DBDriver'],
                             PleaseInfer] = PleaseInfer.token,
            url_resolver: Union[UrlResolver, PleaseInfer] = PleaseInfer.token,
            session: Union['Session',
                           PleaseInfer] = PleaseInfer.token) -> None:
        """Expose ``move`` plus ``sources`` and ``targets`` built from one driver/resolver.

        :param db_driver: Callable taking an Engine or Connection and
            returning a DBDriver. Left as ``PleaseInfer.token``, it becomes
            ``session.db_driver``.
        :param url_resolver: UrlResolver shared by sources and targets. Left
            as ``PleaseInfer.token``, it becomes ``session.url_resolver``.
        :param session: Session that missing values are taken from. It is
            only consulted when ``db_driver`` or ``url_resolver`` needs
            inferring. If it is also ``PleaseInfer.token``, a new
            ``Session()`` is created.
        """
        if db_driver is PleaseInfer.token or url_resolver is PleaseInfer.token:
            if session is PleaseInfer.token:
                # Imported lazily; presumably to avoid a circular import with
                # the top-level package — TODO confirm.
                from records_mover import Session  # noqa

                session = Session()
            if db_driver is PleaseInfer.token:
                # Bound method: called later with an Engine/Connection.
                db_driver = session.db_driver
            if url_resolver is PleaseInfer.token:
                url_resolver = session.url_resolver
        # Re-export the module-level move function as an attribute.
        self.move = move  # type: ignore
        self.sources = RecordsSources(db_driver=db_driver,
                                      url_resolver=url_resolver)
        self.targets = RecordsTargets(url_resolver=url_resolver,
                                      db_driver=db_driver)
示例#22
0
 def test_records_with_overridden_scratch_bucket(self, mock_Records):
     """A CLI session reports the overridden scratch S3 bucket."""
     cli_session = Session(session_type='cli',
                           default_db_creds_name=None,
                           default_aws_creds_name=None)
     scratch_url = cli_session.creds.default_scratch_s3_url()
     self.assertEqual(scratch_url, 's3://different-scratch-bucket/')
示例#23
0
 def test_creds(self, mock_CredsViaLastPass, mock_os):
     """session_type='lpass' makes session.creds the CredsViaLastPass instance."""
     lastpass_session = Session(session_type='lpass',
                                default_db_creds_name=None,
                                default_aws_creds_name=None)
     expected_creds = mock_CredsViaLastPass.return_value
     self.assertEqual(expected_creds, lastpass_session.creds)
示例#24
0
class BaseRecordsIntegrationTest(unittest.TestCase):
    """Shared fixture and helpers for records-mover database integration tests."""

    def setUp(self):
        """Create an 'env' session, choose a schema and a unique table name.

        Before the test runs this also tears down any fixture table left over
        under the same name and calls ``purge_old_tables`` with the
        ``"itest_"`` prefix.
        """
        # Pin the session type so DWIM behavior of the CLI session cannot
        # leak into the integration tests.
        os.environ['RECORDS_MOVER_SESSION_TYPE'] = 'itest'

        this_dir = os.path.dirname(os.path.abspath(__file__))
        self.resources_dir = this_dir + '/../../resources'
        self.session = Session(session_type='env',
                               default_db_creds_name=None,
                               default_aws_creds_name=None)
        self.engine = self.session.get_default_db_engine()
        self.driver = self.session.db_driver(self.engine)

        # Engines without an entry here fall back to the 'public' schema.
        schema_for_engine = {
            'bigquery': 'bq_itest',
            'mysql': 'mysqlitest',
        }
        self.schema_name = schema_for_engine.get(self.engine.name, 'public')

        # Unique per CI build number and per second of wall-clock time.
        prefix = "itest_"
        build_label = os.environ.get("CIRCLE_BUILD_NUM", "local")
        started_at = int(time.time())
        self.table_name = f"{prefix}{build_label}_{started_at}"

        self.fixture = RecordsDatabaseFixture(self.engine,
                                              schema_name=self.schema_name,
                                              table_name=self.table_name)
        self.fixture.tear_down()
        purge_old_tables(self.schema_name, prefix)

        logger.debug("Initialized class!")

        self.meta = MetaData()
        self.records = self.session.records

    def tearDown(self):
        """Drop the session reference, then tear down this test's fixture."""
        self.session = None
        self.fixture.tear_down()

    def table(self, schema, table):
        """Reflect ``schema.table`` from the test engine into ``self.meta``."""
        return Table(
            table,
            self.meta,
            schema=schema,
            autoload=True,
            autoload_with=self.engine,
        )

    def variant_has_header(self, variant):
        """Return True if ``variant`` includes a header row by default."""
        return variant in ('csv', 'bigquery')

    def resource_name(self, format_type, variant, hints):
        """Return the resource name for a format/variant/header combination.

        An explicit ``'header-row'`` hint overrides the variant's default.
        """
        has_header = hints.get('header-row', self.variant_has_header(variant))
        header_suffix = 'with-header' if has_header else 'no-header'
        return f"{format_type}-{variant}-{header_suffix}"

    def has_scratch_s3_bucket(self):
        """Return True if SCRATCH_S3_URL is set in the environment."""
        return 'SCRATCH_S3_URL' in os.environ

    def has_scratch_gcs_bucket(self):
        """Return True if SCRATCH_GCS_URL is set in the environment."""
        return 'SCRATCH_GCS_URL' in os.environ

    def has_pandas(self):
        """Return True if pandas can be imported, logging the outcome."""
        try:
            import pandas  # noqa
        except ModuleNotFoundError:
            logger.info("Could not find pandas")
            return False
        else:
            logger.info("Just imported pandas")
            return True

    def unload_column_to_string(self, column_name: str,
                                records_format: BaseRecordsFormat) -> str:
        """Unload the test table in ``records_format`` and return the file text.

        The table is moved into a temporary records directory. That
        directory is then saved to a temporary file, and the file's contents
        are returned.

        NOTE(review): ``column_name`` is not used; the whole table is moved.
        """
        records = self.records
        with tempfile.TemporaryDirectory() as workdir:
            source = records.sources.table(schema_name=self.schema_name,
                                           table_name=self.table_name,
                                           db_engine=self.engine)
            directory_url = pathlib.Path(workdir).as_uri() + '/'
            target = records.targets.directory_from_url(
                output_url=directory_url, records_format=records_format)
            records.move(source, target)
            records_dir = RecordsDirectory(
                records_loc=self.session.directory_url(directory_url))
            with tempfile.NamedTemporaryFile() as scratch_file:
                output_loc = self.session.file_url(
                    pathlib.Path(scratch_file.name).as_uri())
                records_dir.save_to_url(output_loc)
                return output_loc.string_contents()
class BaseRecordsIntegrationTest(unittest.TestCase):
    """Shared fixture and helpers for records-mover database integration tests."""

    def setUp(self):
        """Create an 'env' session, choose a schema and a unique table name.

        Before the test runs this also tears down any fixture table left over
        under the same name and calls ``purge_old_tables`` with the
        ``"itest_"`` prefix.
        """
        # Pin the session type so DWIM behavior of the CLI session cannot
        # leak into the integration tests.
        os.environ['RECORDS_MOVER_SESSION_TYPE'] = 'itest'

        here = os.path.dirname(os.path.abspath(__file__))
        self.resources_dir = here + '/../../resources'
        self.session = Session(session_type='env',
                               default_db_creds_name=None,
                               default_aws_creds_name=None)
        self.engine = self.session.get_default_db_engine()
        self.driver = self.session.db_driver(self.engine)

        # Engines without an entry here fall back to the 'public' schema.
        engine_schemas = {'bigquery': 'bq_itest', 'mysql': 'mysqlitest'}
        self.schema_name = engine_schemas.get(self.engine.name, 'public')

        # Unique per CI build number and per second of wall-clock time.
        name_prefix = "itest_"
        build_id = os.environ.get("CIRCLE_BUILD_NUM", "local")
        epoch_seconds = int(time.time())
        self.table_name = f"{name_prefix}{build_id}_{epoch_seconds}"

        self.fixture = RecordsDatabaseFixture(self.engine,
                                              schema_name=self.schema_name,
                                              table_name=self.table_name)
        self.fixture.tear_down()
        purge_old_tables(self.schema_name, name_prefix)

        logger.debug("Initialized class!")

        self.meta = MetaData()
        self.records = self.session.records

    def tearDown(self):
        """Drop the session reference, then tear down this test's fixture."""
        self.session = None
        self.fixture.tear_down()

    def table(self, schema, table):
        """Reflect ``schema.table`` from the test engine into ``self.meta``."""
        return Table(
            table,
            self.meta,
            schema=schema,
            autoload=True,
            autoload_with=self.engine,
        )

    def variant_has_header(self, variant):
        """Return True if ``variant`` includes a header row by default."""
        return variant in ('csv', 'bigquery')

    def resource_name(self, format_type, variant, hints):
        """Return the resource name for a format/variant/header combination.

        An explicit ``'header-row'`` hint overrides the variant's default.
        """
        has_header = hints.get('header-row', self.variant_has_header(variant))
        header_suffix = 'with-header' if has_header else 'no-header'
        return f"{format_type}-{variant}-{header_suffix}"

    def has_scratch_bucket(self):
        """Return True if SCRATCH_S3_URL is set in the environment."""
        return 'SCRATCH_S3_URL' in os.environ

    def has_pandas(self):
        """Return True if pandas can be imported, logging the outcome."""
        try:
            import pandas  # noqa
        except ModuleNotFoundError:
            logger.info("Could not find pandas")
            return False
        else:
            logger.info("Just imported pandas")
            return True
示例#26
0
def main() -> None:
    """Command-line entry point: build ``<source>2<target>`` subcommands, run one.

    Every pairing of a CLI source name with a CLI target name becomes a
    subcommand (for example ``table2gsheet``). Top-level options come from
    the signature of ``ProcessingInstructions.__init__``. With no subcommand
    the help text is printed. If the selected job raises, the process exits
    with status 1.
    """
    # Silence google-auth's end-user-credentials warning; see
    # https://github.com/googleapis/google-auth-library-python/issues/271
    import warnings
    warnings.filterwarnings("ignore",
                            "Your application has authenticated using end user credentials")

    # CLI name -> method name on records.sources / records.targets.
    # In-memory sources/targets such as dataframes are left out on purpose:
    # they make no sense from the command line.
    source_methods = {
        'table': 'table',
        'gsheet': 'google_sheet',
        'recordsdir': 'directory_from_url',
        'url': 'data_url',
        'file': 'local_file'
    }
    target_methods = {
        'gsheet': 'google_sheet',
        'table': 'table',
        'recordsdir': 'directory_from_url',
        'url': 'data_url',
        'file': 'local_file',
        'spectrum': 'spectrum',
    }

    parser = argparse.ArgumentParser(
        description='Move tabular data ("records") from one place to another',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    pi_schema = method_signature_to_json_schema(ProcessingInstructions.__init__,
                                                special_handling={},
                                                parameters_to_ignore=['self'])
    JobConfigSchemaAsArgsParser(config_json_schema=pi_schema,
                                argument_parser=parser).configure_arg_parser()

    # https://stackoverflow.com/questions/15405636/pythons-argparse-to-show-programs-version-with-prog-and-version-string-formatt
    parser.add_argument('-V', '--version',
                        action='version',
                        version="%(prog)s (" + __version__ + ")")
    subparsers = parser.add_subparsers(help='subcommand_help')
    from records_mover import Session
    bootstrap_session = Session()

    for source_cli_name, source_method_name in source_methods.items():
        for target_cli_name, target_method_name in target_methods.items():
            subcommand = f"{source_cli_name}2{target_cli_name}"
            sub_parser = subparsers.add_parser(
                subcommand,
                help=f"Copy from {source_cli_name} to {target_cli_name}")
            job_config_schema = populate_subparser(bootstrap_session,
                                                   sub_parser,
                                                   source_method_name,
                                                   target_method_name,
                                                   subjob_name=subcommand)
            job_fn = make_job_fn(source_method_name=source_method_name,
                                 target_method_name=target_method_name,
                                 name=subcommand,
                                 job_config_schema=job_config_schema)
            sub_parser.set_defaults(func=job_fn)

    args = parser.parse_args()
    selected_job = getattr(args, 'func', None)
    if selected_job is None:
        parser.print_help()
        return

    set_stream_logging()
    try:
        selected_job(vars(args))
    except Exception:
        # Already logged upstream via a redacting logger; only the exit
        # status is reported here.
        sys.exit(1)