Code example #1
File: kubeflow_dag_runner.py Project: tvalentyn/tfx
def get_default_kubeflow_metadata_config(
) -> kubeflow_pb2.KubeflowMetadataConfig:
  """Returns the default metadata connection config for Kubeflow.

  Returns:
    A config proto that will be serialized as JSON and passed to the running
    container so the TFX component driver is able to communicate with MLMD in
    a Kubeflow cluster.
  """
  # The default metadata configuration for a Kubeflow Pipelines cluster is
  # codified in a pair of Kubernetes ConfigMap and Secret that can be found in
  # the following:
  # https://github.com/kubeflow/pipelines/blob/master/manifests/kustomize/base/metadata/metadata-configmap.yaml
  # https://github.com/kubeflow/pipelines/blob/master/manifests/kustomize/base/metadata/metadata-mysql-secret.yaml

  config = kubeflow_pb2.KubeflowMetadataConfig()
  # The environment variable to use to obtain the MySQL service host in the
  # cluster that is backing Kubeflow Metadata. Note that the key in the config
  # map, and therefore the environment variable used, is lower-cased.
  config.mysql_db_service_host.environment_variable = 'mysql_host'
  # The environment variable to use to obtain the MySQL service port in the
  # cluster that is backing Kubeflow Metadata.
  config.mysql_db_service_port.environment_variable = 'mysql_port'
  # The MySQL database name to use.
  config.mysql_db_name.environment_variable = 'mysql_database'
  # The MySQL database username.
  config.mysql_db_user.environment_variable = 'username'
  # The MySQL database password.
  config.mysql_db_password.environment_variable = 'password'

  return config
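As the docstring notes, the DAG runner serializes this proto to JSON and passes it to each component container, where the entrypoint parses it back (see the container_entrypoint.py examples below). A minimal round-trip sketch, assuming kubeflow_pb2 is importable as tfx.orchestration.kubeflow.proto.kubeflow_pb2:

from google.protobuf import json_format

from tfx.orchestration.kubeflow.proto import kubeflow_pb2

config = get_default_kubeflow_metadata_config()
# The DAG runner side: serialize the proto to JSON before wiring it into the
# container's --kubeflow_metadata_config argument.
serialized = json_format.MessageToJson(config)

# The container side: parse the JSON back into a proto.
parsed = kubeflow_pb2.KubeflowMetadataConfig()
json_format.Parse(serialized, parsed)
assert parsed.mysql_db_service_host.environment_variable == 'mysql_host'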
Code example #2
  def setUp(self):
    super(BaseComponentWithPipelineParamTest, self).setUp()

    test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')
    example_gen_buckets = data_types.RuntimeParameter(
        name='example-gen-buckets', ptype=int, default=10)

    examples = standard_artifacts.ExternalArtifact()
    example_gen = csv_example_gen_component.CsvExampleGen(
        input=channel_utils.as_channel([examples]),
        output_config={
            'split_config': {
                'splits': [{
                    'name': 'examples',
                    'hash_buckets': example_gen_buckets
                }]
            }
        })
    statistics_gen = statistics_gen_component.StatisticsGen(
        examples=example_gen.outputs['examples'], instance_name='foo')

    pipeline = tfx_pipeline.Pipeline(
        pipeline_name=self._test_pipeline_name,
        pipeline_root='test_pipeline_root',
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=[example_gen, statistics_gen],
    )

    self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
    self._tfx_ir = pipeline_pb2.Pipeline()
    with dsl.Pipeline('test_pipeline'):
      self.example_gen = base_component.BaseComponent(
          component=example_gen,
          component_launcher_class=in_process_component_launcher
          .InProcessComponentLauncher,
          depends_on=set(),
          pipeline=pipeline,
          pipeline_name=self._test_pipeline_name,
          pipeline_root=test_pipeline_root,
          tfx_image='container_image',
          kubeflow_metadata_config=self._metadata_config,
          component_config=None,
          tfx_ir=self._tfx_ir)
      self.statistics_gen = base_component.BaseComponent(
          component=statistics_gen,
          component_launcher_class=in_process_component_launcher
          .InProcessComponentLauncher,
          depends_on=set(),
          pipeline=pipeline,
          pipeline_name=self._test_pipeline_name,
          pipeline_root=test_pipeline_root,
          tfx_image='container_image',
          kubeflow_metadata_config=self._metadata_config,
          component_config=None,
          tfx_ir=self._tfx_ir
      )

    self.tfx_example_gen = example_gen
    self.tfx_statistics_gen = statistics_gen
Code example #3
File: base_component_test.py Project: htahir1/tfx
    def setUp(self):
        super(BaseComponentTest, self).setUp()
        example_gen = csv_example_gen_component.CsvExampleGen(
            input_base='data_input')
        statistics_gen = statistics_gen_component.StatisticsGen(
            examples=example_gen.outputs['examples']).with_id('foo')

        pipeline = tfx_pipeline.Pipeline(
            pipeline_name=self._test_pipeline_name,
            pipeline_root='test_pipeline_root',
            metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
            components=[example_gen, statistics_gen],
        )

        test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')

        self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
        self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
        self._tfx_ir = pipeline_pb2.Pipeline()
        with dsl.Pipeline('test_pipeline'):
            self.component = base_component.BaseComponent(
                component=statistics_gen,
                depends_on=set(),
                pipeline=pipeline,
                pipeline_root=test_pipeline_root,
                tfx_image='container_image',
                kubeflow_metadata_config=self._metadata_config,
                tfx_ir=self._tfx_ir,
            )
        self.tfx_component = statistics_gen
Code example #4
File: base_component_test.py Project: sycdesign/tfx
  def setUp(self):
    super(BaseComponentTest, self).setUp()
    examples = standard_artifacts.ExternalArtifact()
    example_gen = csv_example_gen_component.CsvExampleGen(
        input=channel_utils.as_channel([examples]))
    statistics_gen = statistics_gen_component.StatisticsGen(
        examples=example_gen.outputs['examples'], instance_name='foo')

    pipeline = tfx_pipeline.Pipeline(
        pipeline_name=self._test_pipeline_name,
        pipeline_root='test_pipeline_root',
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=[example_gen, statistics_gen],
    )

    test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')

    self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
    with dsl.Pipeline('test_pipeline'):
      self.component = base_component.BaseComponent(
          component=statistics_gen,
          component_launcher_class=in_process_component_launcher
          .InProcessComponentLauncher,
          depends_on=set(),
          pipeline=pipeline,
          pipeline_name=self._test_pipeline_name,
          pipeline_root=test_pipeline_root,
          tfx_image='container_image',
          kubeflow_metadata_config=self._metadata_config,
          component_config=None,
      )
    self.tfx_component = statistics_gen
Code example #5
    def testDeprecatedMysqlMetadataConnectionConfig(self):
        self._set_required_env_vars({
            'mysql_host': 'mysql',
            'mysql_port': '3306',
            'mysql_database': 'metadb',
            'mysql_user_name': 'root',
            'mysql_user_password': 'test'
        })

        metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
        metadata_config.mysql_db_service_host.environment_variable = 'mysql_host'
        metadata_config.mysql_db_service_port.environment_variable = 'mysql_port'
        metadata_config.mysql_db_name.environment_variable = 'mysql_database'
        metadata_config.mysql_db_user.environment_variable = 'mysql_user_name'
        metadata_config.mysql_db_password.environment_variable = 'mysql_user_password'

        ml_metadata_config = container_entrypoint._get_metadata_connection_config(
            metadata_config)
        self.assertIsInstance(ml_metadata_config,
                              metadata_store_pb2.ConnectionConfig)
        self.assertEqual(ml_metadata_config.mysql.host, 'mysql')
        self.assertEqual(ml_metadata_config.mysql.port, 3306)
        self.assertEqual(ml_metadata_config.mysql.database, 'metadb')
        self.assertEqual(ml_metadata_config.mysql.user, 'root')
        self.assertEqual(ml_metadata_config.mysql.password, 'test')
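This test exercises how container_entrypoint._get_metadata_connection_config resolves each ConfigValue, which carries either a literal value or the name of an environment variable to read at run time. A minimal sketch of the MySQL branch, assuming that resolution logic; the helper names here are hypothetical, and only the proto fields come from the snippets in this section:

import os

from ml_metadata.proto import metadata_store_pb2


def _resolve_config_value(config_value):
  # Hypothetical helper: read the named environment variable when one is set,
  # otherwise fall back to the literal value.
  if config_value.HasField('environment_variable'):
    return os.environ[config_value.environment_variable]
  return config_value.value


def _mysql_connection_config(metadata_config):
  # Sketch of the MySQL branch of _get_metadata_connection_config.
  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.mysql.host = _resolve_config_value(
      metadata_config.mysql_db_service_host)
  connection_config.mysql.port = int(
      _resolve_config_value(metadata_config.mysql_db_service_port))
  connection_config.mysql.database = _resolve_config_value(
      metadata_config.mysql_db_name)
  connection_config.mysql.user = _resolve_config_value(
      metadata_config.mysql_db_user)
  connection_config.mysql.password = _resolve_config_value(
      metadata_config.mysql_db_password)
  return connection_config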
Code example #6
  def setUp(self):
    super(BaseComponentTest, self).setUp()
    examples = standard_artifacts.ExternalArtifact()
    example_gen = csv_example_gen_component.CsvExampleGen(
        input_base=channel_utils.as_channel([examples]))
    statistics_gen = statistics_gen_component.StatisticsGen(
        input_data=example_gen.outputs.examples, instance_name='foo')

    pipeline = tfx_pipeline.Pipeline(
        pipeline_name='test_pipeline',
        pipeline_root='test_pipeline_root',
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=[example_gen, statistics_gen],
    )

    self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
    with dsl.Pipeline('test_pipeline'):
      self.component = base_component.BaseComponent(
          component=statistics_gen,
          depends_on=set(),
          pipeline=pipeline,
          tfx_image='container_image',
          kubeflow_metadata_config=self._metadata_config,
      )
    self.tfx_component = statistics_gen
Code example #7
def _get_kubeflow_metadata_config() -> kubeflow_pb2.KubeflowMetadataConfig:
  config = kubeflow_pb2.KubeflowMetadataConfig()
  config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
  config.mysql_db_service_port.environment_variable = 'MYSQL_SERVICE_PORT'
  config.mysql_db_name.value = 'metadb'
  config.mysql_db_user.value = 'root'
  config.mysql_db_password.value = ''
  return config
Code example #8
File: test_utils.py Project: SImtiaz/tfx
  def _get_kubeflow_metadata_config(
      self, pipeline_name: Text) -> kubeflow_pb2.KubeflowMetadataConfig:
    config = kubeflow_pb2.KubeflowMetadataConfig()
    config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
    config.mysql_db_service_port.environment_variable = 'MYSQL_SERVICE_PORT'
    config.mysql_db_name.value = self._get_mlmd_db_name(pipeline_name)
    config.mysql_db_user.value = 'root'
    config.mysql_db_password.value = ''
    return config
Code example #9
File: container_entrypoint.py Project: fsx950223/tfx
def main():
    # Log to the container's stdout so Kubeflow Pipelines UI can display logs to
    # the user.
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    logging.getLogger().setLevel(logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument('--pipeline_name', type=str, required=True)
    parser.add_argument('--pipeline_root', type=str, required=True)
    parser.add_argument('--kubeflow_metadata_config', type=str, required=True)
    parser.add_argument('--beam_pipeline_args', type=str, required=True)
    parser.add_argument('--additional_pipeline_args', type=str, required=True)
    parser.add_argument('--component_launcher_class_path',
                        type=str,
                        required=True)
    parser.add_argument('--enable_cache', action='store_true')
    parser.add_argument('--serialized_component', type=str, required=True)
    parser.add_argument('--component_config', type=str, required=True)

    args = parser.parse_args()

    component = json_utils.loads(args.serialized_component)
    component_config = json_utils.loads(args.component_config)
    component_launcher_class = import_utils.import_class_by_path(
        args.component_launcher_class_path)
    if not issubclass(component_launcher_class,
                      base_component_launcher.BaseComponentLauncher):
        raise TypeError(
            'component_launcher_class "%s" is not subclass of base_component_launcher.BaseComponentLauncher'
            % component_launcher_class)

    kubeflow_metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    json_format.Parse(args.kubeflow_metadata_config, kubeflow_metadata_config)
    metadata_connection = kubeflow_metadata_adapter.KubeflowMetadataAdapter(
        _get_metadata_connection_config(kubeflow_metadata_config))
    driver_args = data_types.DriverArgs(enable_cache=args.enable_cache)

    beam_pipeline_args = _make_beam_pipeline_args(args.beam_pipeline_args)

    additional_pipeline_args = json.loads(args.additional_pipeline_args)

    launcher = component_launcher_class.create(
        component=component,
        pipeline_info=data_types.PipelineInfo(
            pipeline_name=args.pipeline_name,
            pipeline_root=args.pipeline_root,
            run_id=os.environ['WORKFLOW_ID']),
        driver_args=driver_args,
        metadata_connection=metadata_connection,
        beam_pipeline_args=beam_pipeline_args,
        additional_pipeline_args=additional_pipeline_args,
        component_config=component_config)

    execution_info = launcher.launch()

    # Dump the UI metadata.
    _dump_ui_metadata(component, execution_info)
Code example #10
def get_kubeflow_metadata_config(
        pipeline_name: Text) -> kubeflow_pb2.KubeflowMetadataConfig:
    config = kubeflow_pb2.KubeflowMetadataConfig()
    config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
    config.mysql_db_service_port.environment_variable = 'MYSQL_SERVICE_PORT'
    # MySQL database names cannot exceed 64 characters: the 'mlmd_' prefix is
    # 5 characters, so keep at most the last 59 characters of the pipeline name.
    config.mysql_db_name.value = 'mlmd_{}'.format(pipeline_name[-59:])
    config.mysql_db_user.value = 'root'
    config.mysql_db_password.value = ''
    return config
Code example #11
File: base_component_test.py Project: tvalentyn/tfx
    def setUp(self):
        super(BaseComponentWithPipelineParamTest, self).setUp()

        test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')
        example_gen_output_name = runtime_string_parameter.RuntimeStringParameter(
            name='example-gen-output-name', default='default-to-be-discarded')

        examples = standard_artifacts.ExternalArtifact()
        example_gen = csv_example_gen_component.CsvExampleGen(
            input=channel_utils.as_channel([examples]),
            output_config=example_gen_pb2.Output(
                split_config=example_gen_pb2.SplitConfig(splits=[
                    example_gen_pb2.SplitConfig.Split(
                        name=example_gen_output_name, hash_buckets=10)
                ])))
        statistics_gen = statistics_gen_component.StatisticsGen(
            examples=example_gen.outputs['examples'], instance_name='foo')

        pipeline = tfx_pipeline.Pipeline(
            pipeline_name=self._test_pipeline_name,
            pipeline_root='test_pipeline_root',
            metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
            components=[example_gen, statistics_gen],
        )

        self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
        self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
        with dsl.Pipeline('test_pipeline'):
            self.example_gen = base_component.BaseComponent(
                component=example_gen,
                component_launcher_class=in_process_component_launcher.
                InProcessComponentLauncher,
                depends_on=set(),
                pipeline=pipeline,
                pipeline_name=self._test_pipeline_name,
                pipeline_root=test_pipeline_root,
                tfx_image='container_image',
                kubeflow_metadata_config=self._metadata_config,
                component_config=None)
            self.statistics_gen = base_component.BaseComponent(
                component=statistics_gen,
                component_launcher_class=in_process_component_launcher.
                InProcessComponentLauncher,
                depends_on=set(),
                pipeline=pipeline,
                pipeline_name=self._test_pipeline_name,
                pipeline_root=test_pipeline_root,
                tfx_image='container_image',
                kubeflow_metadata_config=self._metadata_config,
                component_config=None,
            )

        self.tfx_example_gen = example_gen
        self.tfx_statistics_gen = statistics_gen
Code example #12
File: base_component_test.py Project: jay90099/tfx
    def setUp(self):
        super().setUp()

        example_gen_output_config = data_types.RuntimeParameter(
            name='example-gen-output-config', ptype=str)

        example_gen = csv_example_gen_component.CsvExampleGen(
            input_base='data_root', output_config=example_gen_output_config)
        statistics_gen = statistics_gen_component.StatisticsGen(
            examples=example_gen.outputs['examples']).with_id('foo')

        test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')
        pipeline = tfx_pipeline.Pipeline(
            pipeline_name=self._test_pipeline_name,
            pipeline_root='test_pipeline_root',
            metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
            components=[example_gen, statistics_gen],
        )

        self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
        self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
        self._tfx_ir = pipeline_pb2.Pipeline()
        with dsl.Pipeline('test_pipeline'):
            self.example_gen = base_component.BaseComponent(
                component=example_gen,
                depends_on=set(),
                pipeline=pipeline,
                pipeline_root=test_pipeline_root,
                tfx_image='container_image',
                kubeflow_metadata_config=self._metadata_config,
                tfx_ir=self._tfx_ir,
                pod_labels_to_attach={},
                runtime_parameters=[example_gen_output_config])
            self.statistics_gen = base_component.BaseComponent(
                component=statistics_gen,
                depends_on=set(),
                pipeline=pipeline,
                pipeline_root=test_pipeline_root,
                tfx_image='container_image',
                kubeflow_metadata_config=self._metadata_config,
                tfx_ir=self._tfx_ir,
                pod_labels_to_attach={},
                runtime_parameters=[])

        self.tfx_example_gen = example_gen
        self.tfx_statistics_gen = statistics_gen
Code example #13
    def testGrpcMetadataConnectionConfig(self):
        self._set_required_env_vars({
            'METADATA_GRPC_SERVICE_HOST': 'metadata-grpc',
            'METADATA_GRPC_SERVICE_PORT': '8080',
        })

        grpc_config = kubeflow_pb2.KubeflowGrpcMetadataConfig()
        grpc_config.grpc_service_host.environment_variable = 'METADATA_GRPC_SERVICE_HOST'
        grpc_config.grpc_service_port.environment_variable = 'METADATA_GRPC_SERVICE_PORT'
        metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
        metadata_config.grpc_config.CopyFrom(grpc_config)

        ml_metadata_config = container_entrypoint._get_metadata_connection_config(
            metadata_config)
        self.assertIsInstance(ml_metadata_config,
                              metadata_store_pb2.MetadataStoreClientConfig)
        self.assertEqual(ml_metadata_config.host, 'metadata-grpc')
        self.assertEqual(ml_metadata_config.port, 8080)
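The gRPC branch resolves host and port the same way but yields a MetadataStoreClientConfig, pointing the driver at the metadata-grpc service instead of opening a direct database connection. A hedged sketch, reusing the hypothetical _resolve_config_value helper from the note after code example #5:

def _grpc_connection_config(metadata_config):
  # Sketch only; the real mapping lives in
  # container_entrypoint._get_metadata_connection_config.
  client_config = metadata_store_pb2.MetadataStoreClientConfig()
  client_config.host = _resolve_config_value(
      metadata_config.grpc_config.grpc_service_host)
  client_config.port = int(_resolve_config_value(
      metadata_config.grpc_config.grpc_service_port))
  return client_config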
Code example #14
File: kubeflow_dag_runner.py Project: suryaavala/tfx
def get_default_kubeflow_metadata_config(
) -> kubeflow_pb2.KubeflowMetadataConfig:
  """Returns the default metadata connection config for Kubeflow.

  Returns:
    A config proto that will be serialized as JSON and passed to the running
    container so the TFX component driver is able to communicate with MLMD in
    a Kubeflow cluster.
  """
  # The default metadata configuration for a Kubeflow Pipelines cluster is
  # codified as a Kubernetes ConfigMap
  # https://github.com/kubeflow/pipelines/blob/master/manifests/kustomize/base/metadata/metadata-grpc-configmap.yaml

  config = kubeflow_pb2.KubeflowMetadataConfig()
  # The environment variable to use to obtain the Metadata gRPC service host in
  # the cluster that is backing Kubeflow Metadata.
  config.grpc_config.grpc_service_host.environment_variable = 'METADATA_GRPC_SERVICE_HOST'
  # The environment variable to use to obtain the Metadata gRPC service port in
  # the cluster that is backing Kubeflow Metadata.
  config.grpc_config.grpc_service_port.environment_variable = 'METADATA_GRPC_SERVICE_PORT'

  return config
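To put such a config to work, pass it to the runner configuration, which attaches the JSON-serialized proto to every component container (code example #18 shows a fuller setup). A minimal sketch, with the runner wiring assumed from the snippets in this section:

from tfx.orchestration.kubeflow import kubeflow_dag_runner

runner_config = kubeflow_dag_runner.KubeflowDagRunnerConfig(
    kubeflow_metadata_config=get_default_kubeflow_metadata_config())
# The runner serializes the config into each component's container arguments.
runner = kubeflow_dag_runner.KubeflowDagRunner(config=runner_config)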
Code example #15
File: kubeflow_dag_runner.py Project: galgoogle/tfx
def get_default_kubeflow_metadata_config(
) -> kubeflow_pb2.KubeflowMetadataConfig:
  """Returns the default metadata connection config for Kubeflow.

  Returns:
    A config proto that will be serialized as JSON and passed to the running
    container so the TFX component driver is able to communicate with MLMD in
    a Kubeflow cluster.
  """
  # The default metadata configuration for a Kubeflow cluster can be found
  # here:
  # https://github.com/kubeflow/manifests/blob/master/metadata/base/metadata-db-deployment.yaml

  # If deploying Kubeflow Pipelines outside of Kubeflow, that configuration
  # lives here:
  # https://github.com/kubeflow/pipelines/blob/master/manifests/kustomize/base/mysql/mysql-deployment.yaml

  config = kubeflow_pb2.KubeflowMetadataConfig()
  # The environment variable to use to obtain the MySQL service host in the
  # cluster that is backing Kubeflow Metadata.
  config.mysql_db_service_host.environment_variable = 'METADATA_DB_SERVICE_HOST'
  # The environment variable to use to obtain the MySQL service port in the
  # cluster that is backing Kubeflow Metadata.
  config.mysql_db_service_port.environment_variable = 'METADATA_DB_SERVICE_PORT'
  # The MySQL database name to use.
  config.mysql_db_name.value = 'metadb'
  # The MySQL database username.
  config.mysql_db_user.value = 'root'
  # The MySQL database password. It is currently set to `test` for the
  # default install of Kubeflow Metadata:
  # https://github.com/kubeflow/manifests/blob/master/metadata/base/metadata-db-secret.yaml
  # Note that you should ideally use k8s Secrets for usernames/passwords. If you
  # do so, you can change this setting so the container obtains the value at
  # runtime from the secret mounted as an environment variable.
  config.mysql_db_password.value = 'test'

  return config
Code example #16
File: container_entrypoint.py Project: htahir1/tfx
def main():
    # Log to the container's stdout so Kubeflow Pipelines UI can display logs to
    # the user.
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    logging.getLogger().setLevel(logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument('--pipeline_root', type=str, required=True)
    parser.add_argument('--kubeflow_metadata_config', type=str, required=True)
    parser.add_argument('--serialized_component', type=str, required=True)
    parser.add_argument('--tfx_ir', type=str, required=True)
    parser.add_argument('--node_id', type=str, required=True)
    launcher._register_execution = _register_execution  # pylint: disable=protected-access

    args = parser.parse_args()

    tfx_ir = pipeline_pb2.Pipeline()
    json_format.Parse(args.tfx_ir, tfx_ir)
    # Substitute the runtime parameter to be a concrete run_id
    runtime_parameter_utils.substitute_runtime_parameter(
        tfx_ir, {
            constants.PIPELINE_RUN_ID_PARAMETER_NAME:
            os.environ['WORKFLOW_ID'],
        })

    deployment_config = runner_utils.extract_local_deployment_config(tfx_ir)

    kubeflow_metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    json_format.Parse(args.kubeflow_metadata_config, kubeflow_metadata_config)
    metadata_connection = kubeflow_metadata_adapter.KubeflowMetadataAdapter(
        _get_metadata_connection_config(kubeflow_metadata_config))

    node_id = args.node_id
    # Attach necessary labels to distinguish different runner and DSL.
    # TODO(zhitaoli): Pass this from KFP runner side when the same container
    # entrypoint can be used by a different runner.
    with telemetry_utils.scoped_labels({
            telemetry_utils.LABEL_TFX_RUNNER: 'kfp',
    }):
        custom_executor_operators = {
            executable_spec_pb2.ContainerExecutableSpec:
            kubernetes_executor_operator.KubernetesExecutorOperator
        }

        executor_spec = runner_utils.extract_executor_spec(
            deployment_config, node_id)
        custom_driver_spec = runner_utils.extract_custom_driver_spec(
            deployment_config, node_id)

        pipeline_node = _get_pipeline_node(tfx_ir, node_id)
        component_launcher = launcher.Launcher(
            pipeline_node=pipeline_node,
            mlmd_connection=metadata_connection,
            pipeline_info=tfx_ir.pipeline_info,
            pipeline_runtime_spec=tfx_ir.runtime_spec,
            executor_spec=executor_spec,
            custom_driver_spec=custom_driver_spec,
            custom_executor_operators=custom_executor_operators)
        logging.info('Component %s is running.', node_id)
        execution_info = component_launcher.launch()
        logging.info('Component %s is finished.', node_id)

    # Dump the UI metadata.
    _dump_ui_metadata(pipeline_node, execution_info)
Code example #17
def main(argv):
    # Log to the container's stdout so Kubeflow Pipelines UI can display logs to
    # the user.
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    logging.getLogger().setLevel(logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument('--pipeline_root', type=str, required=True)
    parser.add_argument('--metadata_ui_path',
                        type=str,
                        required=False,
                        default='/mlpipeline-ui-metadata.json')
    parser.add_argument('--kubeflow_metadata_config', type=str, required=True)
    parser.add_argument('--tfx_ir', type=str, required=True)
    parser.add_argument('--node_id', type=str, required=True)
    # There might be multiple runtime parameters.
    # `args.runtime_parameter` should become List[str] by using "append".
    parser.add_argument('--runtime_parameter', type=str, action='append')

    # TODO(b/196892362): Replace hooking with a more straightforward mechanism.
    launcher._register_execution = _register_execution  # pylint: disable=protected-access

    args = parser.parse_args(argv)

    tfx_ir = pipeline_pb2.Pipeline()
    json_format.Parse(args.tfx_ir, tfx_ir)

    _resolve_runtime_parameters(tfx_ir, args.runtime_parameter)

    deployment_config = runner_utils.extract_local_deployment_config(tfx_ir)

    kubeflow_metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    json_format.Parse(args.kubeflow_metadata_config, kubeflow_metadata_config)
    metadata_connection = metadata.Metadata(
        _get_metadata_connection_config(kubeflow_metadata_config))

    node_id = args.node_id
    # Attach necessary labels to distinguish different runner and DSL.
    # TODO(zhitaoli): Pass this from KFP runner side when the same container
    # entrypoint can be used by a different runner.
    with telemetry_utils.scoped_labels({
            telemetry_utils.LABEL_TFX_RUNNER: 'kfp',
    }):
        custom_executor_operators = {
            executable_spec_pb2.ContainerExecutableSpec:
            kubernetes_executor_operator.KubernetesExecutorOperator
        }

        executor_spec = runner_utils.extract_executor_spec(
            deployment_config, node_id)
        custom_driver_spec = runner_utils.extract_custom_driver_spec(
            deployment_config, node_id)

        pipeline_node = _get_pipeline_node(tfx_ir, node_id)
        component_launcher = launcher.Launcher(
            pipeline_node=pipeline_node,
            mlmd_connection=metadata_connection,
            pipeline_info=tfx_ir.pipeline_info,
            pipeline_runtime_spec=tfx_ir.runtime_spec,
            executor_spec=executor_spec,
            custom_driver_spec=custom_driver_spec,
            custom_executor_operators=custom_executor_operators)
        logging.info('Component %s is running.', node_id)
        execution_info = component_launcher.launch()
        logging.info('Component %s is finished.', node_id)

    # Dump the UI metadata.
    _dump_ui_metadata(pipeline_node, execution_info, args.metadata_ui_path)
Code example #18
        'runtimeVersion': _runtime_version,
        'pythonVersion': _python_version
    }

    # Dataflow settings.
    _beam_tmp_folder = '{}/beam/tmp'.format(_artifact_store_uri)
    _beam_pipeline_args = [
        '--runner=DataflowRunner',
        '--experiments=shuffle_mode=auto',
        '--project=' + _project_id,
        '--temp_location=' + _beam_tmp_folder,
        '--region=' + _gcp_region,
    ]

    # ML Metadata settings
    _metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    _metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
    _metadata_config.mysql_db_service_port.environment_variable = 'MYSQL_SERVICE_PORT'
    _metadata_config.mysql_db_name.value = 'metadb'
    _metadata_config.mysql_db_user.environment_variable = 'MYSQL_USERNAME'
    _metadata_config.mysql_db_password.environment_variable = 'MYSQL_PASSWORD'

    operator_funcs = [
        gcp.use_gcp_secret('user-gcp-sa'),
        use_mysql_secret('mysql-credential')
    ]

    # Compile the pipeline
    runner_config = kubeflow_dag_runner.KubeflowDagRunnerConfig(
        kubeflow_metadata_config=_metadata_config,
        pipeline_operator_funcs=operator_funcs,
Code example #19
File: container_entrypoint.py Project: yifanmai/tfx
def main():
  # Log to the container's stdout so Kubeflow Pipelines UI can display logs to
  # the user.
  logging.basicConfig(stream=sys.stdout, level=logging.INFO)
  logging.getLogger().setLevel(logging.INFO)

  parser = argparse.ArgumentParser()
  parser.add_argument('--pipeline_name', type=str, required=True)
  parser.add_argument('--pipeline_root', type=str, required=True)
  parser.add_argument('--kubeflow_metadata_config', type=str, required=True)
  parser.add_argument('--beam_pipeline_args', type=str, required=True)
  parser.add_argument('--additional_pipeline_args', type=str, required=True)
  parser.add_argument(
      '--component_launcher_class_path', type=str, required=True)
  parser.add_argument('--enable_cache', action='store_true')
  parser.add_argument('--serialized_component', type=str, required=True)
  parser.add_argument('--component_config', type=str, required=True)

  args = parser.parse_args()

  component = json_utils.loads(args.serialized_component)
  component_config = json_utils.loads(args.component_config)
  component_launcher_class = import_utils.import_class_by_path(
      args.component_launcher_class_path)
  if not issubclass(component_launcher_class,
                    base_component_launcher.BaseComponentLauncher):
    raise TypeError(
        'component_launcher_class "%s" is not subclass of base_component_launcher.BaseComponentLauncher'
        % component_launcher_class)

  kubeflow_metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  json_format.Parse(args.kubeflow_metadata_config, kubeflow_metadata_config)
  metadata_connection = kubeflow_metadata_adapter.KubeflowMetadataAdapter(
      _get_metadata_connection_config(kubeflow_metadata_config))
  driver_args = data_types.DriverArgs(enable_cache=args.enable_cache)

  beam_pipeline_args = json.loads(args.beam_pipeline_args)

  additional_pipeline_args = json.loads(args.additional_pipeline_args)

  launcher = component_launcher_class.create(
      component=component,
      pipeline_info=data_types.PipelineInfo(
          pipeline_name=args.pipeline_name,
          pipeline_root=args.pipeline_root,
          run_id=os.environ['WORKFLOW_ID']),
      driver_args=driver_args,
      metadata_connection=metadata_connection,
      beam_pipeline_args=beam_pipeline_args,
      additional_pipeline_args=additional_pipeline_args,
      component_config=component_config)

  # Attach necessary labels to distinguish different runner and DSL.
  # TODO(zhitaoli): Pass this from KFP runner side when the same container
  # entrypoint can be used by a different runner.
  with telemetry_utils.scoped_labels({
      telemetry_utils.LABEL_TFX_RUNNER: 'kfp',
  }):
    execution_info = launcher.launch()

  # Dump the UI metadata.
  _dump_ui_metadata(component, execution_info)
Code example #20
def main():
    # Log to the container's stdout so Kubeflow Pipelines UI can display logs to
    # the user.
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    logging.getLogger().setLevel(logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument('--pipeline_name', type=str, required=True)
    parser.add_argument('--pipeline_root', type=str, required=True)
    parser.add_argument('--kubeflow_metadata_config', type=str, required=True)
    parser.add_argument('--additional_pipeline_args', type=str, required=True)
    parser.add_argument('--component_id', type=str, required=True)
    parser.add_argument('--component_type', type=str, required=True)
    parser.add_argument('--driver_class_path', type=str, required=True)
    parser.add_argument('--executor_spec', type=str, required=True)
    parser.add_argument('--component_launcher_class_path',
                        type=str,
                        required=True)
    parser.add_argument('--inputs', type=str, required=True)
    parser.add_argument('--outputs', type=str, required=True)
    parser.add_argument('--exec_properties', type=str, required=True)
    parser.add_argument('--enable_cache', action='store_true')

    args = parser.parse_args()

    inputs = artifact_utils.parse_artifact_dict(args.inputs)
    input_dict = _make_channel_dict(inputs)

    outputs = artifact_utils.parse_artifact_dict(args.outputs)
    output_dict = _make_channel_dict(outputs)

    exec_properties = json.loads(args.exec_properties)

    driver_class = import_utils.import_class_by_path(args.driver_class_path)
    executor_spec = json_utils.loads(args.executor_spec)

    component_launcher_class = import_utils.import_class_by_path(
        args.component_launcher_class_path)
    if not issubclass(component_launcher_class,
                      base_component_launcher.BaseComponentLauncher):
        raise TypeError(
            'component_launcher_class "%s" is not subclass of base_component_launcher.BaseComponentLauncher'
            % component_launcher_class)

    kubeflow_metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    json_format.Parse(args.kubeflow_metadata_config, kubeflow_metadata_config)
    connection_config = _get_metadata_connection_config(
        kubeflow_metadata_config)

    component_info = data_types.ComponentInfo(
        component_type=args.component_type, component_id=args.component_id)

    driver_args = data_types.DriverArgs(enable_cache=args.enable_cache)

    additional_pipeline_args = _make_additional_pipeline_args(
        args.additional_pipeline_args)

    # TODO(hongyes): create a classmethod to create launcher from a deserialized
    # component.
    launcher = component_launcher_class(
        component_info=component_info,
        driver_class=driver_class,
        component_executor_spec=executor_spec,
        input_dict=input_dict,
        output_dict=output_dict,
        exec_properties=exec_properties,
        pipeline_info=data_types.PipelineInfo(
            pipeline_name=args.pipeline_name,
            pipeline_root=args.pipeline_root,
            run_id=os.environ['WORKFLOW_ID']),
        driver_args=driver_args,
        metadata_connection_config=connection_config,
        additional_pipeline_args=additional_pipeline_args)

    launcher.launch()