Example No. 1
 def setUp(self):
     self._temp_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                     self.get_temp_dir())
     dummy_dag = models.DAG(dag_id='my_component',
                            start_date=datetime.datetime(2019, 1, 1))
     self.checkcache_op = dummy_operator.DummyOperator(
         task_id='my_component.checkcache', dag=dummy_dag)
     self.tfx_python_op = dummy_operator.DummyOperator(
         task_id='my_component.pythonexec', dag=dummy_dag)
     self.noop_sink_op = dummy_operator.DummyOperator(
         task_id='my_component.noop_sink', dag=dummy_dag)
     self.publishexec_op = dummy_operator.DummyOperator(
         task_id='my_component.publishexec', dag=dummy_dag)
     self._logger_config = logging_utils.LoggerConfig()
     self.parent_dag = airflow_pipeline.AirflowPipeline(
         pipeline_name='pipeline_name',
         start_date=datetime.datetime(2018, 1, 1),
         schedule_interval=None,
         pipeline_root='pipeline_root',
         metadata_db_root=self._temp_dir,
         metadata_connection_config=None,
         additional_pipeline_args=None,
         enable_cache=True)
     self.input_dict = {'i': [TfxArtifact('i')]}
     self.output_dict = {'o': [TfxArtifact('o')]}
     self.exec_properties = {'e': 'e'}
     self.driver_options = {'d': 'd'}
Example No. 2
 def setUp(self):
     self._mock_metadata = tf.test.mock.Mock()
     self._input_dict = {
         'input_data': [types.TfxArtifact(type_name='InputType')],
     }
     input_dir = os.path.join(
         os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
         self._testMethodName, 'input_dir')
     # Valid input artifacts must have a URI pointing to an existing directory.
     for key, input_list in self._input_dict.items():
         for index, artifact in enumerate(input_list):
             artifact.id = index + 1
             uri = os.path.join(input_dir, key, str(artifact.id), '')
             artifact.uri = uri
             tf.gfile.MakeDirs(uri)
     self._output_dict = {
         'output_data': [types.TfxArtifact(type_name='OutputType')],
     }
     self._exec_properties = {
         'key': 'value',
     }
     self._base_output_dir = os.path.join(
         os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
         self._testMethodName, 'base_output_dir')
     self._driver_options = base_driver.DriverOptions(
         worker_name='worker_name',
         base_output_dir=self._base_output_dir,
         enable_cache=True)
     self._execution_id = 100
     log_root = os.path.join(self._base_output_dir, 'log_dir')
     logger_config = logging_utils.LoggerConfig(log_root=log_root)
     self._logger = logging_utils.get_logger(logger_config)
Example No. 3
 def testDefaultSettings(self):
     """Ensure log defaults are set correctly."""
     config = logging_utils.LoggerConfig()
     self.assertEqual(config.log_root, '/var/tmp/tfx/logs')
     self.assertEqual(config.log_level, logging.INFO)
     self.assertEqual(config.pipeline_name, '')
     self.assertEqual(config.worker_name, '')
Example No. 4
  def __init__(self,
               pipeline_name,
               start_date,
               schedule_interval,
               pipeline_root,
               metadata_db_root,
               metadata_connection_config=None,
               additional_pipeline_args=None,
               docker_operator_cfg=None,
               enable_cache=False):
    super(AirflowPipeline, self).__init__(
        dag_id=pipeline_name,
        schedule_interval=schedule_interval,
        start_date=start_date)
    self.project_path = os.path.join(pipeline_root, pipeline_name)
    self.additional_pipeline_args = additional_pipeline_args
    self.docker_operator_cfg = docker_operator_cfg
    self.enable_cache = enable_cache

    if additional_pipeline_args is None:
      additional_pipeline_args = {}

    # Configure logging
    self.logger_config = logging_utils.LoggerConfig(pipeline_name=pipeline_name)
    if 'logger_args' in additional_pipeline_args:
      self.logger_config.update(additional_pipeline_args.get('logger_args'))

    self._logger = logging_utils.get_logger(self.logger_config)
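    # Fall back to a default metadata connection config derived from
    # metadata_db_root when no explicit config is provided.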
    self.metadata_connection_config = metadata_connection_config or _get_default_metadata_connection_config(
        metadata_db_root, pipeline_name)
    self._producer_map = {}
    self._consumer_map = {}
    self._upstreams_map = collections.defaultdict(set)
Example No. 5
    def test_fetch_last_blessed_model(self):
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)
        self._logger_config = logging_utils.LoggerConfig(
            log_root=os.path.join(output_data_dir, 'log_dir'))

        # Mock metadata.
        mock_metadata = tf.test.mock.Mock()
        model_validator_driver = driver.Driver(self._logger_config,
                                               mock_metadata)

        # No blessed model.
        mock_metadata.get_all_artifacts.return_value = []
        self.assertEqual((None, None),
                         model_validator_driver._fetch_last_blessed_model())

        # Mock blessing artifacts.
        artifacts = []
        for span in [4, 3, 2, 1]:
            model_blessing = types.TfxArtifact(type_name='ModelBlessingPath')
            model_blessing.span = span
            model_blessing.set_string_custom_property('current_model',
                                                      'uri-%d' % span)
            model_blessing.set_int_custom_property('current_model_id', span)
            # Only odd spans are "blessed"
            model_blessing.set_int_custom_property('blessed', span % 2)
            artifacts.append(model_blessing.artifact)
        mock_metadata.get_all_artifacts.return_value = artifacts
        self.assertEqual(('uri-3', 3),
                         model_validator_driver._fetch_last_blessed_model())
Example No. 6
 def setUp(self):
     dummy_dag = models.DAG(dag_id='my_component',
                            start_date=datetime.datetime(2019, 1, 1))
     self.checkcache_op = dummy_operator.DummyOperator(
         task_id='my_component.checkcache', dag=dummy_dag)
     self.tfx_python_op = dummy_operator.DummyOperator(
         task_id='my_component.pythonexec', dag=dummy_dag)
     self.tfx_docker_op = dummy_operator.DummyOperator(
         task_id='my_component.dockerexec', dag=dummy_dag)
     self.publishcache_op = dummy_operator.DummyOperator(
         task_id='my_component.publishcache', dag=dummy_dag)
     self.publishexec_op = dummy_operator.DummyOperator(
         task_id='my_component.publishexec', dag=dummy_dag)
     self._logger_config = logging_utils.LoggerConfig()
     self.parent_dag = airflow_pipeline.AirflowPipeline(
         pipeline_name='pipeline_name',
         start_date=datetime.datetime(2018, 1, 1),
         schedule_interval=None,
         pipeline_root='pipeline_root',
         metadata_db_root='metadata_db_root',
         metadata_connection_config=None,
         additional_pipeline_args=None,
         docker_operator_cfg=None,
         enable_cache=True)
     self.input_dict = {'i': [TfxType('i')]}
     self.output_dict = {'o': [TfxType('o')]}
     self.exec_properties = {'e': 'e'}
     self.driver_options = {'d': 'd'}
Example No. 7
 def test_override_settings(self):
   """Ensure log overrides are set correctly."""
   config = logging_utils.LoggerConfig(log_root='path', log_level=logging.WARN,
                                       pipeline_name='pipe', worker_name='wrk')
   self.assertEqual(config.log_root, 'path')
   self.assertEqual(config.log_level, logging.WARN)
   self.assertEqual(config.pipeline_name, 'pipe')
   self.assertEqual(config.worker_name, 'wrk')
Example No. 8
 def setUp(self):
     self.input_one = types.TfxArtifact('INPUT_ONE')
     self.input_one.source = airflow_component._OrchestrationSource(
         'input_one_key', 'input_one_component_id')
     self.output_one = types.TfxArtifact('OUTPUT_ONE')
     self.output_one.source = airflow_component._OrchestrationSource(
         'output_one_key', 'output_one_component_id')
     self.input_one_json = json.dumps([self.input_one.json_dict()])
     self.output_one_json = json.dumps([self.output_one.json_dict()])
     self._logger_config = logging_utils.LoggerConfig()
Example No. 9
  def __init__(self,
               parent_dag,
               component_name,
               unique_name,
               driver,
               executor,
               input_dict,
               output_dict,
               exec_properties):
    # Prepare parameters to create TFX worker.
    if unique_name:
      worker_name = component_name + '.' + unique_name
    else:
      worker_name = component_name
    task_id = parent_dag.dag_id + '.' + worker_name

    # Compute the working directory for this component's outputs.
    output_dir = self._get_working_dir(parent_dag.project_path, component_name,
                                       unique_name or '')

    # Update the output dict before providing it to downstream components.
    for k, output_list in output_dict.items():
      for single_output in output_list:
        single_output.source = _OrchestrationSource(key=k, component_id=task_id)

    my_logger_config = logging_utils.LoggerConfig(
        log_root=parent_dag.logger_config.log_root,
        log_level=parent_dag.logger_config.log_level,
        pipeline_name=parent_dag.logger_config.pipeline_name,
        worker_name=worker_name)
    driver_options = base_driver.DriverOptions(
        worker_name=worker_name,
        base_output_dir=output_dir,
        enable_cache=parent_dag.enable_cache)

    worker = _TfxWorker(
        component_name=component_name,
        task_id=task_id,
        parent_dag=parent_dag,
        input_dict=input_dict,
        output_dict=output_dict,
        exec_properties=exec_properties,
        driver_options=driver_options,
        driver_class=driver,
        executor_class=executor,
        additional_pipeline_args=parent_dag.additional_pipeline_args,
        metadata_connection_config=parent_dag.metadata_connection_config,
        logger_config=my_logger_config)
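    # Wrap the TFX worker DAG in a SubDagOperator and register it, together
    # with the artifacts it consumes and produces, in the parent pipeline graph.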
    subdag = subdag_operator.SubDagOperator(
        subdag=worker, task_id=worker_name, dag=parent_dag)

    parent_dag.add_node_to_graph(
        node=subdag,
        consumes=input_dict.values(),
        produces=output_dict.values())
Example No. 10
  def test_prepare_input_for_processing(self):
    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)
    self._logger_config = logging_utils.LoggerConfig(
        log_root=os.path.join(output_data_dir, 'log_dir'))

    # Mock metadata.
    mock_metadata = tf.test.mock.Mock()
    csv_example_gen = driver.Driver(self._logger_config, mock_metadata)

    # Mock artifact.
    artifacts = []
    for i in [4, 3, 2, 1]:
      artifact = metadata_store_pb2.Artifact()
      artifact.id = i
      # Only odd ids will be matched to input_base.uri.
      artifact.uri = 'path-{}'.format(i % 2)
      artifacts.append(artifact)

    # Create input dict.
    input_base = types.TfxType(type_name='ExternalPath')
    input_base.uri = 'path-1'
    input_dict = {'input-base': [input_base]}

    # Cache not hit.
    mock_metadata.get_all_artifacts.return_value = []
    mock_metadata.publish_artifacts.return_value = [artifacts[3]]
    updated_input_dict = csv_example_gen._prepare_input_for_processing(
        copy.deepcopy(input_dict))
    self.assertEqual(1, len(updated_input_dict))
    self.assertEqual(1, len(updated_input_dict['input-base']))
    updated_input_base = updated_input_dict['input-base'][0]
    self.assertEqual(1, updated_input_base.id)
    self.assertEqual('path-1', updated_input_base.uri)

    # Cache hit.
    mock_metadata.get_all_artifacts.return_value = artifacts
    mock_metadata.publish_artifacts.return_value = []
    updated_input_dict = csv_example_gen._prepare_input_for_processing(
        copy.deepcopy(input_dict))
    self.assertEqual(1, len(updated_input_dict))
    self.assertEqual(1, len(updated_input_dict['input-base']))
    updated_input_base = updated_input_dict['input-base'][0]
    self.assertEqual(3, updated_input_base.id)
    self.assertEqual('path-1', updated_input_base.uri)
Example No. 11
 def test_fetch_warm_starting_model(self):
     mock_metadata = tf.test.mock.Mock()
     artifacts = []
     for span in [3, 2, 1]:
         model = types.TfxArtifact(type_name='ModelExportPath')
         model.span = span
         model.uri = 'uri-%d' % span
         artifacts.append(model.artifact)
     mock_metadata.get_all_artifacts.return_value = artifacts
     output_data_dir = os.path.join(
         os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
         self._testMethodName)
     log_root = os.path.join(output_data_dir, 'log_dir')
     logger_config = logging_utils.LoggerConfig(log_root=log_root)
     logger = logging_utils.get_logger(logger_config)
     trainer_driver = driver.Driver(logger, mock_metadata)
     result = trainer_driver._fetch_latest_model()
     self.assertEqual('uri-3', result)
Example No. 12
 def setUp(self):
     super(LoggingUtilsTest, self).setUp()
     self._log_root = os.path.join(self.get_temp_dir(), 'log_dir')
     self._logger_config = logging_utils.LoggerConfig(
         log_root=self._log_root)
Example No. 13
 def setUp(self):
     self._connection_config = metadata_store_pb2.ConnectionConfig()
     self._connection_config.sqlite.SetInParent()
     log_root = os.path.join(self.get_temp_dir(), 'log_dir')
     logger_config = logging_utils.LoggerConfig(log_root=log_root)
     self._logger = logging_utils.get_logger(logger_config)