Example #1
 def _create_pipeline(self, pipeline_name: Text,
                      components: List[BaseComponent]):
     """Creates a pipeline given name and list of components."""
     return tfx_pipeline.Pipeline(
         pipeline_name=pipeline_name,
         pipeline_root=self._pipeline_root(pipeline_name),
         metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
         components=components,
         log_root='/var/tmp/tfx/logs',
         additional_pipeline_args={
             # Use a fixed WORKFLOW_ID (which is used as the run id) in tests
             # to make debugging easier.
             'WORKFLOW_ID': pipeline_name,
         },
     )
Example #2
 def setUp(self):
     super().setUp()
     self._connection_config = metadata_store_pb2.ConnectionConfig()
     self._connection_config.sqlite.SetInParent()
     self._metadata = self.enter_context(
         metadata.Metadata(connection_config=self._connection_config))
     self._store = self._metadata.store
     self._pipeline_info = data_types.PipelineInfo(
         pipeline_name='my_pipeline',
         pipeline_root='/tmp',
         run_id='my_run_id')
     self._component_info = data_types.ComponentInfo(
         component_type='a.b.c',
         component_id='my_component',
         pipeline_info=self._pipeline_info)
Example #3
 def setUp(self):
     super(ResolverDriverTest, self).setUp()
     self.connection_config = metadata_store_pb2.ConnectionConfig()
     self.connection_config.sqlite.SetInParent()
     self.component_info = data_types.ComponentInfo(component_type='c_type',
                                                    component_id='c_id')
     self.pipeline_info = data_types.PipelineInfo(pipeline_name='p_name',
                                                  pipeline_root='p_root',
                                                  run_id='run_id')
     self.driver_args = data_types.DriverArgs(enable_cache=True)
     self.source_channel_key = 'source_channel'
     self.source_channels = {
         self.source_channel_key:
         types.Channel(type=standard_artifacts.Examples)
     }
Example #4
def _create_pipeline():
    pipeline_name = _PIPELINE_NAME
    pipeline_root = os.path.join(_get_test_output_dir(), pipeline_name)
    components = test_utils.create_e2e_components(_get_csv_input_location())
    return tfx_pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=components[:2],  # Run two components only to reduce overhead.
        log_root='/var/tmp/tfx/logs',
        additional_pipeline_args={
            'WORKFLOW_ID': pipeline_name,
        },
    )
Example #5
    def testRun(self, mock_publisher):
        mock_publisher.return_value.publish_execution.return_value = {}

        test_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        metadata_connection = metadata.Metadata(connection_config)

        pipeline_root = os.path.join(test_dir, 'Test')
        input_path = os.path.join(test_dir, 'input')
        fileio.makedirs(os.path.dirname(input_path))
        file_io.write_string_to_file(input_path, 'test')

        input_artifact = test_utils._InputArtifact()
        input_artifact.uri = input_path

        component = test_utils._FakeComponent(
            name='FakeComponent',
            input_channel=channel_utils.as_channel([input_artifact]))

        pipeline_info = data_types.PipelineInfo(pipeline_name='Test',
                                                pipeline_root=pipeline_root,
                                                run_id='123')

        driver_args = data_types.DriverArgs(enable_cache=True)

        # We use InProcessComponentLauncher to test BaseComponentLauncher logic.
        launcher = in_process_component_launcher.InProcessComponentLauncher.create(
            component=component,
            pipeline_info=pipeline_info,
            driver_args=driver_args,
            metadata_connection=metadata_connection,
            beam_pipeline_args=[],
            additional_pipeline_args={})
        self.assertEqual(
            launcher._component_info.component_type, '.'.join([
                test_utils._FakeComponent.__module__,
                test_utils._FakeComponent.__name__
            ]))
        launcher.launch()

        output_path = component.outputs['output'].get()[0].uri
        self.assertTrue(fileio.exists(output_path))
        contents = file_io.read_file_to_string(output_path)
        self.assertEqual('test', contents)
Example #6
    def testRun(self):
        component_a = _FakeComponent(
            _FakeComponentSpecA(output=types.Channel(type=_ArtifactTypeA)),
            enable_cache=True)
        component_b = _FakeComponent(
            _FakeComponentSpecB(a=component_a.outputs['output'],
                                output=types.Channel(type=_ArtifactTypeB)),
            enable_cache=False)
        component_c = _FakeComponent(
            _FakeComponentSpecC(a=component_a.outputs['output'],
                                output=types.Channel(type=_ArtifactTypeC)),
            enable_cache=True)
        component_d = _FakeComponent(
            _FakeComponentSpecD(b=component_b.outputs['output'],
                                c=component_c.outputs['output'],
                                output=types.Channel(type=_ArtifactTypeD)),
            enable_cache=False)
        component_e = _FakeComponent(
            _FakeComponentSpecE(a=component_a.outputs['output'],
                                b=component_b.outputs['output'],
                                d=component_d.outputs['output'],
                                output=types.Channel(type=_ArtifactTypeE)))

        test_pipeline = pipeline.Pipeline(
            pipeline_name='x',
            pipeline_root='y',
            metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
            components=[
                component_d, component_c, component_a, component_b, component_e
            ])

        beam_dag_runner.BeamDagRunner().run(test_pipeline)
        self.assertCountEqual(_executed_components, [
            '_FakeComponent.a', '_FakeComponent.b', '_FakeComponent.c',
            '_FakeComponent.d', '_FakeComponent.e'
        ])
        self.assertEqual(_executed_components[0], '_FakeComponent.a')
        self.assertEqual(_executed_components[3], '_FakeComponent.d')
        self.assertEqual(_executed_components[4], '_FakeComponent.e')

        self.assertDictEqual(
            {
                '_FakeComponent.a': True,
                '_FakeComponent.b': False,
                '_FakeComponent.c': True,
                '_FakeComponent.d': False,
                '_FakeComponent.e': False,
            }, _executed_components_cached)
Example #7
def sqlite_metadata_connection_config(
    metadata_db_uri: Text) -> metadata_store_pb2.ConnectionConfig:
  """Convenience function to create file based metadata connection config.

  Args:
    metadata_db_uri: uri to metadata db.

  Returns:
    A metadata_store_pb2.ConnectionConfig based on given metadata db uri.
  """
  tf.io.gfile.makedirs(os.path.dirname(metadata_db_uri))
  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.filename_uri = metadata_db_uri
  connection_config.sqlite.connection_mode = (
      metadata_store_pb2.SqliteMetadataSourceConfig.READWRITE_OPENCREATE)
  return connection_config
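A minimal usage sketch for the helper above. metadata.Metadata is the standard TFX wrapper around MLMD, as used in the other examples here; the db path and pipeline name are illustrative only.

# Hypothetical usage; path is made up.
import os

from tfx.orchestration import metadata

config = sqlite_metadata_connection_config(
    os.path.join('/tmp/my_pipeline', 'metadata.sqlite'))
with metadata.Metadata(connection_config=config) as m:
  # m.store is the underlying ml-metadata MetadataStore.
  artifact_types = m.store.get_artifact_types()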
Example #8
def get_default_kubernetes_metadata_config(
) -> metadata_store_pb2.ConnectionConfig:
  """Returns the default metadata connection config for a kubernetes cluster.

  Returns:
    A config proto that will be serialized as JSON and passed to the running
    container so the TFX component driver is able to communicate with MLMD in
    a kubernetes cluster.
  """
  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.mysql.host = 'mysql'
  connection_config.mysql.port = 3306
  connection_config.mysql.database = 'mysql'
  connection_config.mysql.user = '******'
  connection_config.mysql.password = ''
  return connection_config
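The docstring above says this proto gets serialized as JSON for the running container. A sketch of that round trip, using the standard google.protobuf.json_format module (not shown in the snippet):

# Sketch: serialize the config to JSON and parse it back, as a container
# entrypoint might. Illustrative only.
from google.protobuf import json_format

config = get_default_kubernetes_metadata_config()
serialized = json_format.MessageToJson(config)
restored = json_format.Parse(serialized,
                             metadata_store_pb2.ConnectionConfig())
assert restored.mysql.host == 'mysql'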
Example #9
 def setUp(self):
     super(KubernetesRemoteRunnerTest, self).setUp()
     self.component_a = _FakeComponent(
         _FakeComponentSpecA(output=types.Channel(type=_ArtifactTypeA)))
     self.component_b = _FakeComponent(
         _FakeComponentSpecB(a=self.component_a.outputs['output'],
                             output=types.Channel(type=_ArtifactTypeB)))
     self.component_c = _FakeComponent(
         _FakeComponentSpecC(a=self.component_a.outputs['output'],
                             b=self.component_b.outputs['output'],
                             output=types.Channel(type=_ArtifactTypeC)))
     self.test_pipeline = tfx_pipeline.Pipeline(
         pipeline_name='x',
         pipeline_root='y',
         metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
         components=[self.component_c, self.component_a, self.component_b])
Example #10
 def test_before_generating_trial_model(self):
     connection_config = metadata_store_pb2.ConnectionConfig()
     connection_config.sqlite.filename_uri = os.path.join(
         FLAGS.test_tmpdir, "1")
     connection_config.sqlite.connection_mode = 3  # 3 = READWRITE_OPENCREATE
     handler = ml_metadata_db.MLMetaData(
         None, None, None, connection_config=connection_config)
     handler.before_generating_trial_model(trial_id=1, model_dir="/tmp/1")
     output = handler._store.get_executions_by_type("Trial")
     self.assertLen(output, 1)
     output = output[0]
     self.assertEqual(output.properties["id"].int_value, 1)
     self.assertEqual(output.properties["state"].string_value, "RUNNING")
     self.assertEqual(output.properties["serialized_data"].string_value, "")
     self.assertEqual(output.properties["model_dir"].string_value, "/tmp/1")
     self.assertEqual(output.properties["evaluation"].string_value, "")
Example #11
 def test_report(self):
     connection_config = metadata_store_pb2.ConnectionConfig()
     connection_config.sqlite.filename_uri = os.path.join(
         FLAGS.test_tmpdir, "2")
     connection_config.sqlite.connection_mode = 3  # 3 = READWRITE_OPENCREATE
     handler = ml_metadata_db.MLMetaData(
         None, None, None, connection_config=connection_config)
     handler.before_generating_trial_model(trial_id=1, model_dir="/tmp/1")
     handler.report(eval_dictionary={"loss": 0.5}, model_dir="/tmp/1")
     output = handler.get_completed_trials()
     self.assertLen(output, 1)
     output = output[0]
     self.assertEqual(output.id, 1)
     self.assertEqual(output.status, "COMPLETED")
     self.assertEqual(output.model_dir, "/tmp/1")
     self.assertEqual(output.final_measurement.objective_value, 0.5)
Example #12
 def setUp(self):
     super(MetadataTest, self).setUp()
     self._connection_config = metadata_store_pb2.ConnectionConfig()
     self._connection_config.sqlite.SetInParent()
     self._component_info = data_types.ComponentInfo(
         component_type='a.b.c', component_id='my_component')
     self._component_info2 = data_types.ComponentInfo(
         component_type='a.b.d', component_id='my_component_2')
     self._pipeline_info = data_types.PipelineInfo(
         pipeline_name='my_pipeline',
         pipeline_root='/tmp',
         run_id='my_run_id')
     self._pipeline_info2 = data_types.PipelineInfo(
         pipeline_name='my_pipeline',
         pipeline_root='/tmp',
         run_id='my_run_id2')
Example #13
    def test_run(self, mock_publisher):
        mock_publisher.return_value.publish_execution.return_value = {}

        test_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()

        pipeline_root = os.path.join(test_dir, 'Test')
        input_path = os.path.join(test_dir, 'input')
        tf.gfile.MakeDirs(os.path.dirname(input_path))
        file_io.write_string_to_file(input_path, 'test')

        input_artifact = types.TfxArtifact(type_name='InputPath')
        input_artifact.uri = input_path

        component = _FakeComponent(name='FakeComponent',
                                   input_channel=channel.as_channel(
                                       [input_artifact]))

        pipeline_info = data_types.PipelineInfo(pipeline_name='Test',
                                                pipeline_root=pipeline_root,
                                                run_id='123')

        driver_args = data_types.DriverArgs(worker_name=component.component_id,
                                            base_output_dir=os.path.join(
                                                pipeline_root,
                                                component.component_id),
                                            enable_cache=True)

        launcher = component_launcher.ComponentLauncher(
            component=component,
            pipeline_info=pipeline_info,
            driver_args=driver_args,
            metadata_connection_config=connection_config,
            additional_pipeline_args={})
        self.assertEqual(
            launcher._component_info.component_type,
            '.'.join([_FakeComponent.__module__, _FakeComponent.__name__]))
        launcher.launch()

        output_path = os.path.join(pipeline_root, 'output')
        self.assertTrue(tf.gfile.Exists(output_path))
        contents = file_io.read_file_to_string(output_path)
        self.assertEqual('test', contents)
Example #14
    def testResolverWithResolverPolicy(self):
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(self._testdata_dir,
                         'pipeline_for_input_resolver_test.pbtxt'), pipeline)
        my_example_gen = pipeline.nodes[0].pipeline_node
        my_transform = pipeline.nodes[2].pipeline_node

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        with metadata.Metadata(connection_config=connection_config) as m:
            # Publishes first ExampleGen with two output channels. `output_examples`
            # will be consumed by downstream Transform.
            output_example_1 = types.Artifact(
                my_example_gen.outputs.outputs['output_examples'].
                artifact_spec.type)
            output_example_1.uri = 'my_examples_uri_1'

            output_example_2 = types.Artifact(
                my_example_gen.outputs.outputs['output_examples'].
                artifact_spec.type)
            output_example_2.uri = 'my_examples_uri_2'

            contexts = context_lib.register_contexts_if_not_exists(
                m, my_example_gen.contexts)
            execution = execution_publish_utils.register_execution(
                m, my_example_gen.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, execution.id, contexts, {
                    'output_examples': [output_example_1, output_example_2],
                })

            my_transform.inputs.resolver_config.resolver_policy = (
                pipeline_pb2.ResolverConfig.LATEST_ARTIFACT)

            # Gets inputs for transform. Should get back what the first ExampleGen
            # published in the `output_examples` channel.
            transform_inputs = inputs_utils.resolve_input_artifacts(
                m, my_transform.inputs)
            self.assertEqual(len(transform_inputs), 1)
            self.assertEqual(len(transform_inputs['examples']), 1)
            self.assertProtoPartiallyEquals(
                transform_inputs['examples'][0].mlmd_artifact,
                output_example_2.mlmd_artifact,
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])
Example #15
def _get_metadata_store(grpc_max_receive_message_length=None):
  if FLAGS.use_grpc_backend:
    grpc_connection_config = metadata_store_pb2.MetadataStoreClientConfig()
    if grpc_max_receive_message_length:
      (grpc_connection_config.channel_arguments.max_receive_message_length
      ) = grpc_max_receive_message_length
    if FLAGS.grpc_host is None:
      raise ValueError("grpc_host argument not set.")
    grpc_connection_config.host = FLAGS.grpc_host
    if not FLAGS.grpc_port:
      raise ValueError("grpc_port argument not set.")
    grpc_connection_config.port = FLAGS.grpc_port
    return metadata_store.MetadataStore(grpc_connection_config)

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  return metadata_store.MetadataStore(connection_config)
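The FLAGS referenced above are defined elsewhere in that module. A plausible definition with absl.flags, with the flag names inferred from the snippet and the defaults guessed:

# Assumed flag definitions for the snippet above; defaults are guesses.
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_boolean('use_grpc_backend', False,
                     'Whether to use a gRPC MLMD server instead of sqlite.')
flags.DEFINE_string('grpc_host', None, 'Host of the MLMD gRPC server.')
flags.DEFINE_integer('grpc_port', None, 'Port of the MLMD gRPC server.')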
Example #16
def _create_pipeline():
  pipeline_name = _PIPELINE_NAME
  test_output_dir = 'gs://{}/test_output'.format(test_utils.BUCKET_NAME)
  pipeline_root = os.path.join(test_output_dir, pipeline_name)
  components = test_utils.create_e2e_components(pipeline_root,
                                                test_utils.DATA_ROOT,
                                                test_utils.TAXI_MODULE_FILE)
  return tfx_pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=components[:2],
      log_root='/var/tmp/tfx/logs',
      additional_pipeline_args={
          'WORKFLOW_ID': pipeline_name,
      },
  )
Example #17
 def setUp(self):
   super().setUp()
   self._connection_config = metadata_store_pb2.ConnectionConfig()
   self._connection_config.sqlite.SetInParent()
   self._module_file_path = os.path.join(self.tmp_dir, 'module_file')
   self._input_artifacts = {'input_examples': [standard_artifacts.Examples()]}
   self._output_artifacts = {'output_models': [standard_artifacts.Model()]}
   self._parameters = {'module_file': self._module_file_path}
   self._module_file_content = 'module content'
   self._pipeline_node = text_format.Parse(
       """
       executor {
         python_class_executor_spec {class_path: 'a.b.c'}
       }
       """, pipeline_pb2.PipelineNode())
   self._executor_class_path = 'a.b.c'
   self._pipeline_info = pipeline_pb2.PipelineInfo(id='pipeline_id')
Example #18
    def setUp(self):
        super().setUp()

        example_gen_output_config = data_types.RuntimeParameter(
            name='example-gen-output-config', ptype=str)

        example_gen = csv_example_gen_component.CsvExampleGen(
            input_base='data_root', output_config=example_gen_output_config)
        statistics_gen = statistics_gen_component.StatisticsGen(
            examples=example_gen.outputs['examples']).with_id('foo')

        test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')
        pipeline = tfx_pipeline.Pipeline(
            pipeline_name=self._test_pipeline_name,
            pipeline_root='test_pipeline_root',
            metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
            components=[example_gen, statistics_gen],
        )

        self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
        self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
        self._tfx_ir = pipeline_pb2.Pipeline()
        with dsl.Pipeline('test_pipeline'):
            self.example_gen = base_component.BaseComponent(
                component=example_gen,
                depends_on=set(),
                pipeline=pipeline,
                pipeline_root=test_pipeline_root,
                tfx_image='container_image',
                kubeflow_metadata_config=self._metadata_config,
                tfx_ir=self._tfx_ir,
                pod_labels_to_attach={},
                runtime_parameters=[example_gen_output_config])
            self.statistics_gen = base_component.BaseComponent(
                component=statistics_gen,
                depends_on=set(),
                pipeline=pipeline,
                pipeline_root=test_pipeline_root,
                tfx_image='container_image',
                kubeflow_metadata_config=self._metadata_config,
                tfx_ir=self._tfx_ir,
                pod_labels_to_attach={},
                runtime_parameters=[])

        self.tfx_example_gen = example_gen
        self.tfx_statistics_gen = statistics_gen
Example #19
 def setUp(self):
   super(ImporterDriverTest, self).setUp()
   self.connection_config = metadata_store_pb2.ConnectionConfig()
   self.connection_config.sqlite.SetInParent()
   self.artifact_type = 'Examples'
   self.output_dict = {
       importer_node.IMPORT_RESULT_KEY:
           types.Channel(type_name=self.artifact_type)
   }
   self.source_uri = 'm/y/u/r/i'
   self.existing_artifact = types.Artifact(type_name=self.artifact_type)
   self.existing_artifact.uri = self.source_uri
   self.component_info = data_types.ComponentInfo(
       component_type='c_type', component_id='c_id')
   self.pipeline_info = data_types.PipelineInfo(
       pipeline_name='p_name', pipeline_root='p_root', run_id='run_id')
   self.driver_args = data_types.DriverArgs(enable_cache=True)
Example #20
    def __init__(self,
                 components_to_always_add: List[BaseComponent],
                 benchmark_subpipelines: List[BenchmarkSubpipeline],
                 pipeline_name: Optional[str],
                 pipeline_root: Optional[str],
                 metadata_connection_config: Optional[
                     metadata_store_pb2.ConnectionConfig] = None,
                 beam_pipeline_args: Optional[List[str]] = None,
                 **kwargs):

        if not benchmark_subpipelines and not components_to_always_add:
            raise ValueError(
                "Requires at least one benchmark subpipeline or component to run. "
                "You may want to call `self.add(..., always=True) in order "
                "to run Components, Subpipelines, or Pipeline even without requiring "
                "a call to `self.evaluate(...)`.")

        # Set defaults.
        if not pipeline_name:
            pipeline_name = "nitroml"
        if not pipeline_root:
            tmp_root_dir = os.path.join("/tmp", pipeline_name)
            tf.io.gfile.makedirs(tmp_root_dir)
            pipeline_root = tempfile.mkdtemp(dir=tmp_root_dir)
            logging.info("Creating tmp pipeline_root at %s", pipeline_root)
        if not metadata_connection_config:
            metadata_connection_config = metadata_store_pb2.ConnectionConfig(
                sqlite=metadata_store_pb2.SqliteMetadataSourceConfig(
                    filename_uri=os.path.join(pipeline_root, "mlmd.sqlite")))

        # Ensure that pipeline dirs are created.
        _make_pipeline_dirs(pipeline_root, metadata_connection_config)

        components = set(components_to_always_add)
        for benchmark_subpipeline in benchmark_subpipelines:
            for component in benchmark_subpipeline.components:
                components.add(component)
        super().__init__(pipeline_name=pipeline_name,
                         pipeline_root=pipeline_root,
                         metadata_connection_config=metadata_connection_config,
                         components=list(components),
                         beam_pipeline_args=beam_pipeline_args,
                         **kwargs)

        self._subpipelines = benchmark_subpipelines
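The _make_pipeline_dirs helper is not shown in this example. A minimal sketch of what it could do, assuming it only needs to ensure that pipeline_root and the sqlite db's directory exist; the real implementation may differ:

# Hypothetical helper matching the call above.
import os

import tensorflow as tf
from ml_metadata.proto import metadata_store_pb2


def _make_pipeline_dirs(
        pipeline_root: str,
        metadata_connection_config: metadata_store_pb2.ConnectionConfig
) -> None:
    """Creates pipeline_root and, for sqlite configs, the db's directory."""
    tf.io.gfile.makedirs(pipeline_root)
    if metadata_connection_config.HasField('sqlite'):
        tf.io.gfile.makedirs(
            os.path.dirname(metadata_connection_config.sqlite.filename_uri))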
Example #21
    def testRun(self, mock_publisher):
        mock_publisher.return_value.publish_execution.return_value = {}

        example_gen = FileBasedExampleGen(
            custom_executor_spec=executor_spec.ExecutorClassSpec(
                avro_executor.Executor),
            input_base=self.avro_dir_path,
            input_config=self.input_config,
            output_config=self.output_config,
            instance_name='AvroExampleGen')

        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)
        pipeline_root = os.path.join(output_data_dir, 'Test')
        fileio.makedirs(pipeline_root)
        pipeline_info = data_types.PipelineInfo(pipeline_name='Test',
                                                pipeline_root=pipeline_root,
                                                run_id='123')

        driver_args = data_types.DriverArgs(enable_cache=True)

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        metadata_connection = metadata.Metadata(connection_config)

        launcher = in_process_component_launcher.InProcessComponentLauncher.create(
            component=example_gen,
            pipeline_info=pipeline_info,
            driver_args=driver_args,
            metadata_connection=metadata_connection,
            beam_pipeline_args=[],
            additional_pipeline_args={})
        self.assertEqual(
            launcher._component_info.component_type, '.'.join(
                [FileBasedExampleGen.__module__,
                 FileBasedExampleGen.__name__]))

        launcher.launch()
        mock_publisher.return_value.publish_execution.assert_called_once()

        # Check output paths.
        self.assertTrue(
            fileio.exists(os.path.join(pipeline_root, example_gen.id)))
Example #22
    def _input_artifacts(self, pipeline_name: Text,
                         input_artifacts: List[Artifact]) -> Channel:
        """Publish input artifacts for test to MLMD and return channel to them."""
        connection_config = metadata_store_pb2.ConnectionConfig(
            mysql=metadata_store_pb2.MySQLDatabaseConfig(
                host='127.0.0.1',
                port=3306,
                database=self._get_mlmd_db_name(pipeline_name),
                user='******',
                password=''))

        dummy_artifact = (input_artifacts[0].type_name, self._random_id())
        output_key = 'dummy_output_%s_%s' % dummy_artifact
        producer_component_id = 'dummy_producer_id_%s_%s' % dummy_artifact
        producer_component_type = 'dummy_producer_type_%s_%s' % dummy_artifact

        # Input artifacts must have a unique name and producer in MLMD.
        for artifact in input_artifacts:
            artifact.name = output_key
            artifact.pipeline_name = pipeline_name
            artifact.producer_component = producer_component_id

        with metadata.Metadata(connection_config=connection_config) as m:
            # Register a dummy execution to metadata store as producer execution.
            execution_id = m.register_execution(
                exec_properties={},
                pipeline_info=data_types.PipelineInfo(
                    pipeline_name=pipeline_name,
                    pipeline_root='/dummy_pipeline_root',
                    # test_utils uses pipeline_name as fixed WORKFLOW_ID.
                    run_id=pipeline_name,
                ),
                component_info=data_types.ComponentInfo(
                    component_type=producer_component_type,
                    component_id=producer_component_id))

            # Publish the test input artifact from the dummy execution.
            published_artifacts = m.publish_execution(
                execution_id=execution_id,
                input_dict={},
                output_dict={output_key: input_artifacts})

        return channel_utils.as_channel(published_artifacts[output_key])
Example #23
def _create_pipeline():
    pipeline_name = _PIPELINE_NAME
    pipeline_root = os.path.join(test_utils.get_test_output_dir(),
                                 pipeline_name)
    components = test_utils.create_e2e_components(
        pipeline_root,
        test_utils.get_csv_input_location(),
        test_utils.get_transform_module(),
        test_utils.get_trainer_module(),
    )
    return tfx_pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=components[:4],
        log_root='/var/tmp/tfx/logs',
        additional_pipeline_args={
            'WORKFLOW_ID': pipeline_name,
        },
    )
Example #24
 def test_get_completed_trials(self):
     connection_config = metadata_store_pb2.ConnectionConfig()
     connection_config.sqlite.filename_uri = os.path.join(
         FLAGS.test_tmpdir, "3")
     connection_config.sqlite.connection_mode = 3  # 3 = READWRITE_OPENCREATE
     handler = ml_metadata_db.MLMetaData(
         None, None, None, connection_config=connection_config)
     handler.before_generating_trial_model(trial_id=1, model_dir="/tmp/1")
     handler.before_generating_trial_model(trial_id=2, model_dir="/tmp/2")
     handler.report(eval_dictionary={"loss": 0.1}, model_dir="/tmp/1")
     handler.before_generating_trial_model(trial_id=3, model_dir="/tmp/3")
     handler.report(eval_dictionary={"loss": 0.3}, model_dir="/tmp/3")
     handler.report(eval_dictionary={"loss": 0.2}, model_dir="/tmp/2")
     output = handler.get_completed_trials()
     self.assertLen(output, 3)
     for i in range(3):
         self.assertEqual(output[i].status, "COMPLETED")
         self.assertEqual(output[i].model_dir, "/tmp/" + str(output[i].id))
         self.assertEqual(output[i].final_measurement.objective_value,
                          float(output[i].id) / 10)
Example #25
 def setUp(self):
     super(KubeflowMetadataAdapterTest, self).setUp()
     self._connection_config = metadata_store_pb2.ConnectionConfig()
     self._connection_config.sqlite.SetInParent()
     self._pipeline_info = data_types.PipelineInfo(
         pipeline_name='fake_pipeline_name',
         pipeline_root='/fake_pipeline_root',
         run_id='fake_run_id')
     self._pipeline_info2 = data_types.PipelineInfo(
         pipeline_name='fake_pipeline_name',
         pipeline_root='/fake_pipeline_root',
         run_id='fake_run_id2')
     self._component_info = data_types.ComponentInfo(
         component_type='fake.component.type',
         component_id='fake_component_id',
         pipeline_info=self._pipeline_info)
     self._component_info2 = data_types.ComponentInfo(
         component_type='fake.component.type',
         component_id='fake_component_id',
         pipeline_info=self._pipeline_info2)
Example #26
    def testRunWithSameSpec(self, mock_kube_utils):
        _initialize_executed_components()
        mock_kube_utils.is_inside_cluster.return_value = True

        component_a = _FakeComponent(spec=_FakeComponentSpecA(
            output=types.Channel(type=_ArtifactTypeA)))
        component_f1 = _FakeComponent(spec=_FakeComponentSpecF(
            a=component_a.outputs['output'])).with_id('f1')
        component_f2 = _FakeComponent(spec=_FakeComponentSpecF(
            a=component_a.outputs['output'])).with_id('f2')
        component_f2.add_upstream_node(component_f1)

        test_pipeline = pipeline.Pipeline(
            pipeline_name='x',
            pipeline_root='y',
            metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
            components=[component_f1, component_f2, component_a])
        kubernetes_dag_runner.KubernetesDagRunner().run(test_pipeline)
        self.assertEqual(_executed_components,
                         ['a.Wrapper', 'f1.Wrapper', 'f2.Wrapper'])
Example #27
def _two_step_pipeline() -> tfx_pipeline.Pipeline:
    default_input_config = json.dumps({
        'splits': [{
            'name': 'single_split',
            'pattern': 'SELECT * FROM default-table'
        }]
    })
    input_config = data_types.RuntimeParameter(name='input_config',
                                               ptype=str,
                                               default=default_input_config)
    example_gen = big_query_example_gen_component.BigQueryExampleGen(
        input_config=input_config, output_config=example_gen_pb2.Output())
    statistics_gen = statistics_gen_component.StatisticsGen(
        examples=example_gen.outputs['examples'])
    return tfx_pipeline.Pipeline(
        pipeline_name='two_step_pipeline',
        pipeline_root='pipeline_root',
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=[example_gen, statistics_gen],
    )
Example #28
  def testRun(self, mock_kube_utils):
    _initialize_executed_components()
    mock_kube_utils.is_inside_cluster.return_value = True

    component_a = _FakeComponent(
        spec=_FakeComponentSpecA(output=types.Channel(type=_ArtifactTypeA)))
    component_b = _FakeComponent(
        spec=_FakeComponentSpecB(
            a=component_a.outputs['output'],
            output=types.Channel(type=_ArtifactTypeB)))
    component_c = _FakeComponent(
        spec=_FakeComponentSpecC(
            a=component_a.outputs['output'],
            output=types.Channel(type=_ArtifactTypeC)))
    component_c.add_upstream_node(component_b)
    component_d = _FakeComponent(
        spec=_FakeComponentSpecD(
            b=component_b.outputs['output'],
            c=component_c.outputs['output'],
            output=types.Channel(type=_ArtifactTypeD)))
    component_e = _FakeComponent(
        spec=_FakeComponentSpecE(
            a=component_a.outputs['output'],
            b=component_b.outputs['output'],
            d=component_d.outputs['output'],
            output=types.Channel(type=_ArtifactTypeE)))

    test_pipeline = pipeline.Pipeline(
        pipeline_name='x',
        pipeline_root='y',
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=[
            component_d, component_c, component_a, component_b, component_e
        ])

    kubernetes_dag_runner.KubernetesDagRunner().run(test_pipeline)
    self.assertEqual(_executed_components, [
        '_FakeComponent.a.Wrapper', '_FakeComponent.b.Wrapper',
        '_FakeComponent.c.Wrapper', '_FakeComponent.d.Wrapper',
        '_FakeComponent.e.Wrapper'
    ])
Example #29
def mysql_metadata_connection_config(
    host: Text, port: int, database: Text, username: Text,
    password: Text) -> metadata_store_pb2.ConnectionConfig:
  """Convenience function to create mysql-based metadata connection config.

  Args:
    host: The name or network address of the instance of MySQL to connect to.
    port: The port MySQL is using to listen for connections.
    database: The name of the database to use.
    username: The MySQL login account being used.
    password: The password for the MySQL account being used.

  Returns:
    A metadata_store_pb2.ConnectionConfig pointing to the given MySQL
    database.
  """
  return metadata_store_pb2.ConnectionConfig(
      mysql=metadata_store_pb2.MySQLDatabaseConfig(host=host,
                                                   port=port,
                                                   database=database,
                                                   user=username,
                                                   password=password))
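A usage sketch for the helper above; the host, database, and credentials are placeholders, and metadata.Metadata is the usual TFX wrapper:

# Illustrative only: placeholder connection details.
from tfx.orchestration import metadata

config = mysql_metadata_connection_config(host='127.0.0.1',
                                          port=3306,
                                          database='mlmd',
                                          username='mlmd_user',
                                          password='mlmd_password')
with metadata.Metadata(connection_config=config) as m:
  contexts = m.store.get_contexts()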
Example #30
  def testRun(self):
    component_a = _FakeComponent(
        _FakeComponentSpecA(output=types.Channel(type_name='a')))
    component_b = _FakeComponent(
        _FakeComponentSpecB(
            a=component_a.outputs['output'],
            output=types.Channel(type_name='b')))
    component_c = _FakeComponent(
        _FakeComponentSpecC(
            a=component_a.outputs['output'],
            output=types.Channel(type_name='c')))
    component_d = _FakeComponent(
        _FakeComponentSpecD(
            b=component_b.outputs['output'],
            c=component_c.outputs['output'],
            output=types.Channel(type_name='d')))
    component_e = _FakeComponent(
        _FakeComponentSpecE(
            a=component_a.outputs['output'],
            b=component_b.outputs['output'],
            d=component_d.outputs['output'],
            output=types.Channel(type_name='e')))

    test_pipeline = pipeline.Pipeline(
        pipeline_name='x',
        pipeline_root='y',
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=[
            component_d, component_c, component_a, component_b, component_e
        ])

    beam_dag_runner.BeamDagRunner().run(test_pipeline)
    # assertItemsEqual was removed in Python 3; assertCountEqual is equivalent.
    self.assertCountEqual(_executed_components, [
        '_FakeComponent.a', '_FakeComponent.b', '_FakeComponent.c',
        '_FakeComponent.d', '_FakeComponent.e'
    ])
    self.assertEqual(_executed_components[0], '_FakeComponent.a')
    self.assertEqual(_executed_components[3], '_FakeComponent.d')
    self.assertEqual(_executed_components[4], '_FakeComponent.e')