Code Example #1
    def setUp(self):
        super().setUp()
        self._home = self.tmp_dir
        self.enter_context(test_case_utils.override_env_var(
            'HOME', self._home))
        self.enter_context(
            test_case_utils.override_env_var(
                'AIRFLOW_HOME', os.path.join(os.environ['HOME'], 'airflow')))

        # Flags for handler.
        self.engine = 'airflow'
        self.pipeline_path = os.path.join(_testdata_dir,
                                          'test_pipeline_airflow_1.py')
        self.pipeline_root = os.path.join(self._home, 'tfx', 'pipelines')
        self.pipeline_name = 'chicago_taxi_simple'
        self.run_id = 'manual__2019-07-19T19:56:02+00:00'
        self.runtime_parameter = {'a': '1', 'b': '2'}
        self.runtime_parameter_json = json.dumps(self.runtime_parameter)

        self.enter_context(test_case_utils.change_working_dir(self.tmp_dir))

        # Pipeline args for mocking subprocess
        self.pipeline_args = {labels.PIPELINE_NAME: self.pipeline_name}
        self._mock_get_airflow_version = self.enter_context(
            mock.patch.object(airflow_handler.AirflowHandler,
                              '_get_airflow_version',
                              return_value='2.0.1'))
Code Example #2
    def setUp(self):
        super().setUp()
        self.chicago_taxi_pipeline_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            'testdata')
        self._home = self.tmp_dir
        self.enter_context(test_case_utils.change_working_dir(self.tmp_dir))
        self.enter_context(test_case_utils.override_env_var(
            'HOME', self._home))
        self._beam_home = os.path.join(os.environ['HOME'], 'beam')
        self.enter_context(
            test_case_utils.override_env_var('BEAM_HOME', self._beam_home))

        # Flags for handler.
        self.engine = 'beam'
        self.pipeline_path = os.path.join(self.chicago_taxi_pipeline_dir,
                                          'test_pipeline_beam_1.py')
        self.pipeline_name = 'chicago_taxi_beam'
        self.pipeline_root = os.path.join(self._home, 'tfx', 'pipelines',
                                          self.pipeline_name)
        self.run_id = 'dummyID'

        self.pipeline_args = {
            labels.PIPELINE_NAME: self.pipeline_name,
            labels.PIPELINE_DSL_PATH: self.pipeline_path,
        }
Code Example #3
File: kubeflow_handler_test.py  Project: joy-jj/tfx
    def setUp(self):
        super(KubeflowHandlerTest, self).setUp()
        self._home = self.tmp_dir

        self.enter_context(test_case_utils.override_env_var(
            'HOME', self._home))
        self.enter_context(
            test_case_utils.override_env_var(
                'KUBEFLOW_HOME', os.path.join(self._home, 'kubeflow')))

        # Flags for handler.
        self.engine = 'kubeflow'
        self.chicago_taxi_pipeline_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            'testdata')
        self.pipeline_path = os.path.join(self.chicago_taxi_pipeline_dir,
                                          'test_pipeline_kubeflow_1.py')
        self.pipeline_name = 'chicago_taxi_pipeline_kubeflow'
        self.pipeline_package_path = os.path.abspath(
            'chicago_taxi_pipeline_kubeflow.tar.gz')
        self.pipeline_root = os.path.join(self._home, 'tfx', 'pipelines')

        # Kubeflow client params.
        self.endpoint = 'dummyEndpoint'
        self.namespace = 'kubeflow'
        self.iap_client_id = 'dummyID'

        # Pipeline args for mocking subprocess.
        self.pipeline_args = {
            'pipeline_name': 'chicago_taxi_pipeline_kubeflow'
        }
Code Example #4
File: local_handler_test.py  Project: majiang/tfx
    def setUp(self):
        super(LocalHandlerTest, self).setUp()
        self.chicago_taxi_pipeline_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            'testdata')
        self._home = self.tmp_dir
        self.enter_context(test_case_utils.override_env_var(
            'HOME', self._home))
        self._local_home = os.path.join(os.environ['HOME'], 'local')
        self.enter_context(
            test_case_utils.override_env_var('LOCAL_HOME', self._local_home))

        # Flags for handler.
        self.engine = 'local'
        self.pipeline_path = os.path.join(self.chicago_taxi_pipeline_dir,
                                          'test_pipeline_local_1.py')
        self.pipeline_name = 'chicago_taxi_local'
        self.pipeline_root = os.path.join(self._home, 'tfx', 'pipelines',
                                          self.pipeline_name)
        self.run_id = 'dummyID'

        # Pipeline args for mocking subprocess
        self.pipeline_args = {
            'pipeline_name': 'chicago_taxi_local',
            'pipeline_dsl_path': self.pipeline_path
        }
Code Example #5
File: vertex_handler_test.py  Project: jay90099/tfx
    def setUp(self):
        super().setUp()
        self.chicago_taxi_pipeline_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            'testdata')

        self._home = self.tmp_dir
        self.enter_context(test_case_utils.change_working_dir(self.tmp_dir))
        self.enter_context(test_case_utils.override_env_var(
            'HOME', self._home))
        self._vertex_home = os.path.join(self._home, 'vertex')
        self.enter_context(
            test_case_utils.override_env_var('VERTEX_HOME', self._vertex_home))

        # Flags for handler.
        self.engine = 'vertex'
        self.pipeline_path = os.path.join(self.chicago_taxi_pipeline_dir,
                                          'test_pipeline_kubeflow_v2_1.py')
        self.pipeline_name = _TEST_PIPELINE_NAME
        self.pipeline_root = os.path.join(self._home, 'tfx', 'pipelines',
                                          self.pipeline_name)
        self.run_id = 'dummyID'
        self.project = 'gcp_project_1'
        self.region = 'us-central1'

        self.runtime_parameter = {'a': '1', 'b': '2'}

        # Setting up Mock for API client, so that this Python test is hermetic.
        # subprocess Mock will be set up per-test.
        self.addCleanup(mock.patch.stopall)
Code Example #6
  def setUp(self):
    super(KubeflowV2HandlerTest, self).setUp()
    self.chicago_taxi_pipeline_dir = os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'testdata')

    self._home = self.tmp_dir
    self.enter_context(test_case_utils.override_env_var('HOME', self._home))
    self._kubeflow_v2_home = os.path.join(self._home, 'kubeflow_v2')
    self.enter_context(
        test_case_utils.override_env_var('KUBEFLOW_V2_HOME',
                                         self._kubeflow_v2_home))

    # Flags for handler.
    self.engine = 'kubeflow_v2'
    self.pipeline_path = os.path.join(self.chicago_taxi_pipeline_dir,
                                      'test_pipeline_1.py')
    self.bad_pipeline_path = os.path.join(self.chicago_taxi_pipeline_dir,
                                          'test_pipeline_bad.py')
    self.pipeline_name = _TEST_PIPELINE_NAME
    self.pipeline_root = os.path.join(self._home, 'tfx', 'pipelines',
                                      self.pipeline_name)
    self.run_id = 'dummyID'

    # Pipeline args for mocking subprocess
    self.pipeline_args = {
        'pipeline_name': _TEST_PIPELINE_NAME,
        'pipeline_dsl_path': self.pipeline_path
    }

    # Setting up Mock for API client, so that this Python test is hermetic.
    # subprocess Mock will be set up per-test.
    self.addCleanup(mock.patch.stopall)
Code Example #7
File: cli_airflow_e2e_test.py  Project: jay90099/tfx
  def setUp(self):
    super().setUp()

    # List of packages installed.
    self._pip_list = pip_utils.get_package_names()

    # Check if Apache Airflow is installed before running E2E tests.
    if labels.AIRFLOW_PACKAGE_NAME not in self._pip_list:
      sys.exit('Apache Airflow not installed.')

    # Change the encoding for Click since Python 3 is configured to use ASCII as
    # encoding for the environment.
    if codecs.lookup(locale.getpreferredencoding()).name == 'ascii':
      os.environ['LANG'] = 'en_US.utf-8'

    # Setup airflow_home in a temp directory
    self._airflow_home = os.path.join(self.tmp_dir, 'airflow')
    self.enter_context(
        test_case_utils.override_env_var('AIRFLOW_HOME', self._airflow_home))
    self.enter_context(
        test_case_utils.override_env_var('HOME', self._airflow_home))

    absl.logging.info('Using %s as AIRFLOW_HOME and HOME in this e2e test',
                      self._airflow_home)

    # Testdata path.
    self._testdata_dir = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')

    self._pipeline_name = 'chicago_taxi_simple'
    self._pipeline_path = os.path.join(self._testdata_dir,
                                       'test_pipeline_airflow_1.py')

    # Copy data.
    chicago_taxi_pipeline_dir = os.path.join(
        os.path.dirname(
            os.path.dirname(
                os.path.dirname(os.path.dirname(os.path.abspath(__file__))))),
        'examples', 'chicago_taxi_pipeline')
    data_dir = os.path.join(chicago_taxi_pipeline_dir, 'data', 'simple')
    content = fileio.listdir(data_dir)
    assert content, 'content in {} is empty'.format(data_dir)
    target_data_dir = os.path.join(self._airflow_home, 'taxi', 'data', 'simple')
    io_utils.copy_dir(data_dir, target_data_dir)
    assert fileio.isdir(target_data_dir)
    content = fileio.listdir(target_data_dir)
    assert content, 'content in {} is {}'.format(target_data_dir, content)
    io_utils.copy_file(
        os.path.join(chicago_taxi_pipeline_dir, 'taxi_utils.py'),
        os.path.join(self._airflow_home, 'taxi', 'taxi_utils.py'))

    # Initialize CLI runner.
    self.runner = click_testing.CliRunner()
Code Example #8
File: cli_airflow_e2e_test.py  Project: jay90099/tfx
  def _prepare_airflow_with_mysql(self):
    self._mysql_container_name = 'airflow_' + test_utils.generate_random_id()
    db_port = airflow_test_utils.create_mysql_container(
        self._mysql_container_name)
    self.addCleanup(self._cleanup_mysql_container)
    self.enter_context(
        test_case_utils.override_env_var(
            'AIRFLOW__CORE__SQL_ALCHEMY_CONN',
            'mysql://[email protected]:%d/airflow' % db_port))
    # Do not load examples to make this a bit faster.
    self.enter_context(
        test_case_utils.override_env_var('AIRFLOW__CORE__LOAD_EXAMPLES',
                                         'False'))

    self._airflow_initdb()
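
The cleanup callback registered above with addCleanup is not shown on this page. Based on Code Example #14 below, which registers airflow_test_utils.delete_mysql_container directly, a plausible sketch of that helper is the following (an assumption for illustration, not the verbatim TFX code):

  def _cleanup_mysql_container(self):
    # Assumed body: tear down the MySQL container started in
    # _prepare_airflow_with_mysql; the real test method may do more.
    airflow_test_utils.delete_mysql_container(self._mysql_container_name)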
Code Example #9
    def setUp(self):
        super().setUp()
        self._namespace = 'kubeflow'
        self._endpoint = self._get_endpoint(self._namespace)
        self._kfp_client = kfp.Client(host=self._endpoint)
        logging.info('ENDPOINT: %s', self._endpoint)
        self.enter_context(
            test_case_utils.override_env_var(
                'KUBEFLOW_HOME', os.path.join(self._temp_dir, 'kubeflow')))
Code Example #10
    def setUp(self):
        super(AirflowHandlerTest, self).setUp()
        self._home = self.tmp_dir
        self.enter_context(test_case_utils.override_env_var(
            'HOME', self._home))
        self.enter_context(
            test_case_utils.override_env_var(
                'AIRFLOW_HOME', os.path.join(os.environ['HOME'], 'airflow')))

        # Flags for handler.
        self.engine = 'airflow'
        self.pipeline_path = os.path.join(_testdata_dir,
                                          'test_pipeline_airflow_1.py')
        self.pipeline_root = os.path.join(self._home, 'tfx', 'pipelines')
        self.pipeline_name = 'chicago_taxi_simple'
        self.run_id = 'manual__2019-07-19T19:56:02+00:00'

        # Pipeline args for mocking subprocess
        self.pipeline_args = {'pipeline_name': 'chicago_taxi_simple'}
Code Example #11
    def setUp(self):
        super().setUp()

        # Change the encoding for Click since Python 3 is configured to use ASCII as
        # encoding for the environment.
        if codecs.lookup(locale.getpreferredencoding()).name == 'ascii':
            os.environ['LANG'] = 'en_US.utf-8'

        # Setup beam_home in a temp directory
        self._home = self.tmp_dir
        self._beam_home = os.path.join(self._home, 'beam')
        self.enter_context(
            test_case_utils.override_env_var('BEAM_HOME', self._beam_home))
        self.enter_context(test_case_utils.override_env_var(
            'HOME', self._home))

        # Testdata path.
        self._testdata_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')

        # Copy data.
        chicago_taxi_pipeline_dir = os.path.join(
            os.path.dirname(
                os.path.dirname(
                    os.path.dirname(os.path.dirname(
                        os.path.abspath(__file__))))), 'examples',
            'chicago_taxi_pipeline', '')
        data_dir = os.path.join(chicago_taxi_pipeline_dir, 'data', 'simple')
        content = fileio.listdir(data_dir)
        assert content, 'content in {} is empty'.format(data_dir)
        target_data_dir = os.path.join(self._home, 'taxi', 'data', 'simple')
        io_utils.copy_dir(data_dir, target_data_dir)
        assert fileio.isdir(target_data_dir)
        content = fileio.listdir(target_data_dir)
        assert content, 'content in {} is {}'.format(target_data_dir, content)
        io_utils.copy_file(
            os.path.join(chicago_taxi_pipeline_dir, 'taxi_utils.py'),
            os.path.join(self._home, 'taxi', 'taxi_utils.py'))

        # Initialize CLI runner.
        self.runner = click_testing.CliRunner()
Code Example #12
    def setUp(self):
        super().setUp()
        self.enter_context(
            test_case_utils.override_env_var(
                'VERTEX_HOME', os.path.join(self._temp_dir, 'vertex')))
Code Example #13
    def setUp(self):
        super().setUp()

        # Flags for handler.
        self.engine = 'kubeflow'
        self.chicago_taxi_pipeline_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            'testdata')

        self.enter_context(test_case_utils.change_working_dir(self.tmp_dir))
        self.enter_context(
            test_case_utils.override_env_var('KFP_E2E_BASE_CONTAINER_IMAGE',
                                             'dummy-image'))
        self.enter_context(
            test_case_utils.override_env_var('KFP_E2E_BUCKET_NAME',
                                             'dummy-bucket'))
        self.enter_context(
            test_case_utils.override_env_var('KFP_E2E_TEST_DATA_ROOT',
                                             'dummy-root'))

        self.pipeline_path = os.path.join(self.chicago_taxi_pipeline_dir,
                                          'test_pipeline_kubeflow_1.py')
        self.pipeline_name = 'chicago_taxi_pipeline_kubeflow'

        # Kubeflow client params.
        self.endpoint = 'dummyEndpoint'
        self.namespace = 'kubeflow'
        self.iap_client_id = 'dummyID'

        self.runtime_parameter = {'a': '1', 'b': '2'}

        default_flags = {
            labels.ENGINE_FLAG: self.engine,
            labels.ENDPOINT: self.endpoint,
            labels.IAP_CLIENT_ID: self.iap_client_id,
            labels.NAMESPACE: self.namespace,
        }

        self.flags_with_name = {
            **default_flags,
            labels.PIPELINE_NAME: self.pipeline_name,
        }

        self.flags_with_runtime_param = {
            **default_flags,
            labels.PIPELINE_NAME: self.pipeline_name,
            labels.RUNTIME_PARAMETER: self.runtime_parameter,
        }

        self.flags_with_dsl_path = {
            **default_flags,
            labels.PIPELINE_DSL_PATH: self.pipeline_path,
        }

        # Pipeline args for mocking subprocess.
        self.pipeline_args = {
            'pipeline_name': 'chicago_taxi_pipeline_kubeflow'
        }
        self.pipeline_id = 'the_pipeline_id'
        self.experiment_id = 'the_experiment_id'
        self.pipeline_version_id = 'the_pipeline_version_id'

        mock_client_cls = self.enter_context(
            mock.patch.object(kfp, 'Client', autospec=True))
        self.mock_client = mock_client_cls.return_value
        # Required to access generated apis.
        self.mock_client._experiment_api = mock.MagicMock()

        self.mock_client.get_pipeline_id.return_value = self.pipeline_id
        self.mock_client.get_experiment.return_value.id = self.experiment_id
        versions = [mock.MagicMock()]
        versions[0].id = self.pipeline_version_id
        self.mock_client.list_pipeline_versions.return_value.versions = versions
Code Example #14
    def setUp(self):
        super(AirflowEndToEndTest, self).setUp()
        # setup airflow_home in a temp directory, config and init db.
        self._airflow_home = self.tmp_dir
        self.enter_context(
            test_case_utils.override_env_var('AIRFLOW_HOME',
                                             self._airflow_home))
        self.enter_context(
            test_case_utils.override_env_var('HOME', self._airflow_home))
        absl.logging.info('Using %s as AIRFLOW_HOME and HOME in this e2e test',
                          self._airflow_home)

        self._mysql_container_name = (
            'airflow_' + test_utils.generate_random_id())
        db_port = airflow_test_utils.create_mysql_container(
            self._mysql_container_name)
        self.addCleanup(airflow_test_utils.delete_mysql_container,
                        self._mysql_container_name)
        os.environ['AIRFLOW__CORE__SQL_ALCHEMY_CONN'] = (
            'mysql://[email protected]:%d/airflow' % db_port)

        # Set a couple of important environment variables. See
        # https://airflow.apache.org/howto/set-config.html for details.
        os.environ['AIRFLOW__CORE__DAGS_FOLDER'] = os.path.join(
            self._airflow_home, 'dags')
        os.environ['AIRFLOW__CORE__BASE_LOG_FOLDER'] = os.path.join(
            self._airflow_home, 'logs')
        os.environ['AIRFLOW__CORE__DAGBAG_IMPORT_TIMEOUT'] = '300'
        # Do not load examples to make this a bit faster.
        os.environ['AIRFLOW__CORE__LOAD_EXAMPLES'] = 'False'
        # Following environment variables make scheduler process dags faster.
        os.environ['AIRFLOW__SCHEDULER__JOB_HEARTBEAT_SEC'] = '1'
        os.environ['AIRFLOW__SCHEDULER__SCHEDULER_HEARTBEAT_SEC'] = '1'
        os.environ['AIRFLOW__SCHEDULER__RUN_DURATION'] = '-1'
        os.environ['AIRFLOW__SCHEDULER__MIN_FILE_PROCESS_INTERVAL'] = '1'
        os.environ['AIRFLOW__SCHEDULER__PRINT_STATS_INTERVAL'] = '30'

        # Following fields are specific to the chicago_taxi_simple example.
        self._dag_id = 'chicago_taxi_simple'
        self._run_id = 'manual_run_id_1'
        # This execution date must be after the start_date in chicago_taxi_simple
        # but before current execution date.
        self._execution_date = '2019-02-01T01:01:01'
        self._all_tasks = [
            'CsvExampleGen',
            'Evaluator',
            'ExampleValidator',
            'Pusher',
            'SchemaGen',
            'StatisticsGen',
            'Trainer',
            'Transform',
        ]
        # Copy dag file and data.
        chicago_taxi_pipeline_dir = os.path.dirname(__file__)
        simple_pipeline_file = os.path.join(chicago_taxi_pipeline_dir,
                                            'taxi_pipeline_simple.py')

        io_utils.copy_file(
            simple_pipeline_file,
            os.path.join(self._airflow_home, 'dags',
                         'taxi_pipeline_simple.py'))

        data_dir = os.path.join(chicago_taxi_pipeline_dir, 'data', 'simple')
        content = fileio.listdir(data_dir)
        assert content, 'content in {} is empty'.format(data_dir)
        target_data_dir = os.path.join(self._airflow_home, 'taxi', 'data',
                                       'simple')
        io_utils.copy_dir(data_dir, target_data_dir)
        assert fileio.isdir(target_data_dir)
        content = fileio.listdir(target_data_dir)
        assert content, 'content in {} is {}'.format(target_data_dir, content)
        io_utils.copy_file(
            os.path.join(chicago_taxi_pipeline_dir, 'taxi_utils.py'),
            os.path.join(self._airflow_home, 'taxi', 'taxi_utils.py'))

        # Initialize database.
        subprocess.run(['airflow', 'db', 'init'], check=True)
        subprocess.run(['airflow', 'dags', 'unpause', self._dag_id],
                       check=True)
Code Example #15
File: cli_kubeflow_e2e_test.py  Project: joy-jj/tfx
    def setUp(self):
        super(CliKubeflowEndToEndTest, self).setUp()
        random.seed(datetime.datetime.now())

        # List of packages installed.
        self._pip_list = pip_utils.get_package_names()

        # Check if Kubeflow is installed before running E2E tests.
        if labels.KUBEFLOW_PACKAGE_NAME not in self._pip_list:
            sys.exit('Kubeflow not installed.')

        # Change the encoding for Click since Python 3 is configured to use ASCII as
        # encoding for the environment.
        if codecs.lookup(locale.getpreferredencoding()).name == 'ascii':
            os.environ['LANG'] = 'en_US.utf-8'

        # Initialize CLI runner.
        self.runner = click_testing.CliRunner()

        # Testdata path.
        self._testdata_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            'testdata')
        self._testdata_dir_updated = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)
        fileio.makedirs(self._testdata_dir_updated)

        self._pipeline_name = ('cli-kubeflow-e2e-test-' +
                               test_utils.generate_random_id())
        absl.logging.info('Pipeline name is %s' % self._pipeline_name)
        self._pipeline_name_v2 = self._pipeline_name + '_v2'

        orig_pipeline_path = os.path.join(self._testdata_dir,
                                          'test_pipeline_kubeflow_1.py')
        self._pipeline_path = os.path.join(self._testdata_dir_updated,
                                           'test_pipeline_kubeflow_1.py')
        self._pipeline_path_v2 = os.path.join(self._testdata_dir_updated,
                                              'test_pipeline_kubeflow_2.py')

        test_utils.copy_and_change_pipeline_name(
            orig_pipeline_path, self._pipeline_path,
            'chicago_taxi_pipeline_kubeflow', self._pipeline_name)
        self.assertTrue(fileio.exists(self._pipeline_path))
        test_utils.copy_and_change_pipeline_name(
            orig_pipeline_path, self._pipeline_path_v2,
            'chicago_taxi_pipeline_kubeflow', self._pipeline_name_v2)
        self.assertTrue(fileio.exists(self._pipeline_path_v2))

        # Endpoint URL
        self._endpoint = self._get_endpoint(
            subprocess.check_output(
                'kubectl describe configmap inverse-proxy-config -n kubeflow'.
                split()))
        absl.logging.info('ENDPOINT: ' + self._endpoint)

        # Change home directories
        self._kubeflow_home = self.tmp_dir
        self.enter_context(
            test_case_utils.override_env_var('KUBEFLOW_HOME',
                                             self._kubeflow_home))

        self._handler_pipeline_path = os.path.join(self._kubeflow_home,
                                                   self._pipeline_name)
        self._handler_pipeline_args_path = os.path.join(
            self._handler_pipeline_path, 'pipeline_args.json')
        self._pipeline_package_path = '{}.tar.gz'.format(self._pipeline_name)

        try:
            # Create a kfp client for cleanup after running commands.
            self._client = kfp.Client(host=self._endpoint)
        except kfp_server_api.rest.ApiException as err:
            absl.logging.info(err)
Code Example #16
    def testOverrideEnvVar(self):
        old_home = os.getenv('HOME')
        new_home = self.get_temp_dir()
        with test_case_utils.override_env_var('HOME', new_home):
            self.assertEqual(os.environ['HOME'], new_home)
        self.assertEqual(os.getenv('HOME'), old_home)
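
Code Example #16 exercises the core contract of override_env_var: the new value is visible inside the with block, and the previous value of HOME is restored once the block exits. A minimal sketch of a helper with that behavior, assuming a plain contextlib context manager (the actual tfx.utils.test_case_utils implementation may differ), could look like this:

import contextlib
import os


@contextlib.contextmanager
def override_env_var(name, value):
    """Temporarily sets an environment variable and restores it on exit."""
    old_value = os.environ.get(name)  # None if the variable was unset before.
    os.environ[name] = value
    try:
        yield
    finally:
        if old_value is None:
            os.environ.pop(name, None)  # The variable did not exist before.
        else:
            os.environ[name] = old_value

Because such a helper behaves as a context manager, it works both in a with statement, as in this test, and with self.enter_context(...) inside setUp, as in the other examples on this page.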
Code Example #17
    def setUp(self):
        super().setUp()
        self.enter_context(
            test_case_utils.override_env_var('OVERWRITE_ENV', 'bar'))
Code Example #18
    def setUp(self):
        super().setUp()
        self.enter_context(test_case_utils.override_env_var('NEW_ENV', 'foo'))
        self.enter_context(
            test_case_utils.override_env_var('OVERWRITE_ENV', 'baz'))
        self.enter_context(test_case_utils.change_working_dir(self.tmp_dir))
Code Example #19
    def _set_required_env_vars(self, env_vars):
        for k, v in env_vars.items():
            self.enter_context(test_case_utils.override_env_var(k, v))
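
A hypothetical call site for this helper, using the KFP_E2E_* variables that appear in Code Example #13 purely for illustration (the real tests pass whatever variables they require):

    def setUp(self):
        super().setUp()
        # Illustrative values only; see Code Example #13 for the flags the
        # actual Kubeflow tests set.
        self._set_required_env_vars({
            'KFP_E2E_BASE_CONTAINER_IMAGE': 'dummy-image',
            'KFP_E2E_BUCKET_NAME': 'dummy-bucket',
            'KFP_E2E_TEST_DATA_ROOT': 'dummy-root',
        })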
Code Example #20
    def setUp(self):
        super(CliAirflowEndToEndTest, self).setUp()

        # List of packages installed.
        self._pip_list = pip_utils.get_package_names()

        # Check if Apache Airflow is installed before running E2E tests.
        if labels.AIRFLOW_PACKAGE_NAME not in self._pip_list:
            sys.exit('Apache Airflow not installed.')

        # Change the encoding for Click since Python 3 is configured to use ASCII as
        # encoding for the environment.
        if codecs.lookup(locale.getpreferredencoding()).name == 'ascii':
            os.environ['LANG'] = 'en_US.utf-8'

        # Setup airflow_home in a temp directory
        self._airflow_home = os.path.join(self.tmp_dir, 'airflow')
        self.enter_context(
            test_case_utils.override_env_var('AIRFLOW_HOME',
                                             self._airflow_home))
        self.enter_context(
            test_case_utils.override_env_var('HOME', self._airflow_home))

        absl.logging.info('Using %s as AIRFLOW_HOME and HOME in this e2e test',
                          self._airflow_home)

        # Testdata path.
        self._testdata_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')

        self._pipeline_name = 'chicago_taxi_simple'
        self._pipeline_path = os.path.join(self._testdata_dir,
                                           'test_pipeline_airflow_1.py')

        # Copy data.
        chicago_taxi_pipeline_dir = os.path.join(
            os.path.dirname(
                os.path.dirname(
                    os.path.dirname(os.path.dirname(
                        os.path.abspath(__file__))))), 'examples',
            'chicago_taxi_pipeline')
        data_dir = os.path.join(chicago_taxi_pipeline_dir, 'data', 'simple')
        content = fileio.listdir(data_dir)
        assert content, 'content in {} is empty'.format(data_dir)
        target_data_dir = os.path.join(self._airflow_home, 'taxi', 'data',
                                       'simple')
        io_utils.copy_dir(data_dir, target_data_dir)
        assert fileio.isdir(target_data_dir)
        content = fileio.listdir(target_data_dir)
        assert content, 'content in {} is {}'.format(target_data_dir, content)
        io_utils.copy_file(
            os.path.join(chicago_taxi_pipeline_dir, 'taxi_utils.py'),
            os.path.join(self._airflow_home, 'taxi', 'taxi_utils.py'))

        self._mysql_container_name = (
            'airflow_' + test_utils.generate_random_id())
        db_port = airflow_test_utils.create_mysql_container(
            self._mysql_container_name)
        self.addCleanup(self._cleanup_mysql_container)
        self.enter_context(
            test_case_utils.override_env_var(
                'AIRFLOW__CORE__SQL_ALCHEMY_CONN',
                'mysql://[email protected]:%d/airflow' % db_port))
        # Do not load examples to make this a bit faster.
        self.enter_context(
            test_case_utils.override_env_var('AIRFLOW__CORE__LOAD_EXAMPLES',
                                             'False'))

        self._airflow_initdb()

        # Initialize CLI runner.
        self.runner = click_testing.CliRunner()