# NOTE: the import paths below assume the TFX experimental Kubernetes
# orchestrator package (tfx.orchestration.experimental.kubernetes); adjust
# them for your TFX version.
import argparse
import logging
import sys

from tfx.orchestration.experimental.kubernetes import kubernetes_dag_runner
from tfx.orchestration.experimental.kubernetes import kubernetes_remote_runner


def main():
    # Log to the container's stdout so it can be streamed by the client.
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    logging.getLogger().setLevel(logging.INFO)

    parser = argparse.ArgumentParser()

    # Pipeline is serialized via a json format.
    # See kubernetes_remote_runner._serialize_pipeline for details.
    parser.add_argument('--serialized_pipeline', type=str, required=True)
    parser.add_argument('--tfx_image', type=str, required=True)
    args = parser.parse_args()

    kubernetes_dag_runner.KubernetesDagRunner(
        config=kubernetes_dag_runner.KubernetesDagRunnerConfig(
            tfx_image=args.tfx_image)).run(
                kubernetes_remote_runner.deserialize_pipeline(
                    args.serialized_pipeline))
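
For context, this entrypoint is meant to run inside the orchestrator container. Below is a hedged sketch of how it might be invoked as a subprocess; the module name container_entrypoint and the JSON payload are placeholders, and the real payload format is produced by kubernetes_remote_runner._serialize_pipeline.

import subprocess

# Placeholder payload; the real format is produced by
# kubernetes_remote_runner._serialize_pipeline.
serialized_pipeline = '{"pipeline_name": "my_pipeline"}'

subprocess.run(
    [
        'python', '-m', 'container_entrypoint',  # hypothetical module name
        '--serialized_pipeline', serialized_pipeline,
        '--tfx_image', 'gcr.io/my-project/my-tfx-image:latest',
    ],
    check=True)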
Example #2
    def testRunWithSameSpec(self, mock_kube_utils):
        _initialize_executed_components()
        mock_kube_utils.is_inside_cluster.return_value = True

        component_a = _FakeComponent(spec=_FakeComponentSpecA(
            output=types.Channel(type=_ArtifactTypeA)))
        component_f1 = _FakeComponent(spec=_FakeComponentSpecF(
            a=component_a.outputs['output'])).with_id('f1')
        component_f2 = _FakeComponent(spec=_FakeComponentSpecF(
            a=component_a.outputs['output'])).with_id('f2')
        component_f2.add_upstream_node(component_f1)

        test_pipeline = pipeline.Pipeline(
            pipeline_name='x',
            pipeline_root='y',
            metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
            components=[component_f1, component_f2, component_a])
        kubernetes_dag_runner.KubernetesDagRunner().run(test_pipeline)
        self.assertEqual(_executed_components,
                         ['a.Wrapper', 'f1.Wrapper', 'f2.Wrapper'])
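
The test above depends on fake helpers defined elsewhere in the test module (_ArtifactTypeA, _FakeComponentSpecA, _FakeComponentSpecF, _FakeComponent). A minimal sketch of what they might look like is shown below, assuming current TFX module paths under tfx.dsl.components.base; this is an illustration, not the actual test fixture.

from tfx import types
from tfx.dsl.components.base import base_component
from tfx.dsl.components.base import base_executor
from tfx.dsl.components.base import executor_spec
from tfx.types import component_spec


class _ArtifactTypeA(types.Artifact):
    TYPE_NAME = 'ArtifactTypeA'


class _FakeComponentSpecA(types.ComponentSpec):
    # Produces a single output channel of _ArtifactTypeA.
    PARAMETERS = {}
    INPUTS = {}
    OUTPUTS = {'output': component_spec.ChannelParameter(type=_ArtifactTypeA)}


class _FakeComponentSpecF(types.ComponentSpec):
    # Consumes the output of a component with spec _FakeComponentSpecA.
    PARAMETERS = {}
    INPUTS = {'a': component_spec.ChannelParameter(type=_ArtifactTypeA)}
    OUTPUTS = {}


class _FakeComponent(base_component.BaseComponent):
    # Accepts any spec subclass; the executor is a no-op placeholder because
    # the test patches execution and only records which components ran.
    SPEC_CLASS = types.ComponentSpec
    EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(base_executor.BaseExecutor)

    def __init__(self, spec: types.ComponentSpec):
        super().__init__(spec=spec)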
Example #3
  def testRun(self, mock_kube_utils):
    _initialize_executed_components()
    mock_kube_utils.is_inside_cluster.return_value = True

    component_a = _FakeComponent(
        spec=_FakeComponentSpecA(output=types.Channel(type=_ArtifactTypeA)))
    component_b = _FakeComponent(
        spec=_FakeComponentSpecB(
            a=component_a.outputs['output'],
            output=types.Channel(type=_ArtifactTypeB)))
    component_c = _FakeComponent(
        spec=_FakeComponentSpecC(
            a=component_a.outputs['output'],
            output=types.Channel(type=_ArtifactTypeC)))
    component_c.add_upstream_node(component_b)
    component_d = _FakeComponent(
        spec=_FakeComponentSpecD(
            b=component_b.outputs['output'],
            c=component_c.outputs['output'],
            output=types.Channel(type=_ArtifactTypeD)))
    component_e = _FakeComponent(
        spec=_FakeComponentSpecE(
            a=component_a.outputs['output'],
            b=component_b.outputs['output'],
            d=component_d.outputs['output'],
            output=types.Channel(type=_ArtifactTypeE)))

    test_pipeline = pipeline.Pipeline(
        pipeline_name='x',
        pipeline_root='y',
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=[
            component_d, component_c, component_a, component_b, component_e
        ])

    kubernetes_dag_runner.KubernetesDagRunner().run(test_pipeline)
    self.assertEqual(_executed_components, [
        '_FakeComponent.a.Wrapper', '_FakeComponent.b.Wrapper',
        '_FakeComponent.c.Wrapper', '_FakeComponent.d.Wrapper',
        '_FakeComponent.e.Wrapper'
    ])
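
The assertion above relies on the runner respecting dependency order even though the components list is deliberately shuffled. As an illustration, here is a hedged sketch of checking that ordering directly on the pipeline object, assuming Pipeline.components returns nodes in topologically sorted order (the behavior this test exercises).

def _assert_topologically_sorted(tfx_pipeline):
    """Checks that every upstream node precedes its downstream nodes."""
    position = {node.id: i for i, node in enumerate(tfx_pipeline.components)}
    for node in tfx_pipeline.components:
        for upstream in node.upstream_nodes:
            assert position[upstream.id] < position[node.id], (
                '%s should be scheduled before %s' % (upstream.id, node.id))


_assert_topologically_sorted(test_pipeline)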
Example #4
  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=[
          example_gen,
          statistics_gen,
          schema_gen,
          example_validator,
          transform,
          trainer,
          model_resolver,
          evaluator,
          pusher,
      ],
      enable_cache=False,
      metadata_connection_config=config,
      beam_pipeline_args=beam_pipeline_args)


if __name__ == '__main__':
  absl.logging.set_verbosity(absl.logging.INFO)

  kubernetes_dag_runner.KubernetesDagRunner().run(
      create_pipeline(
          pipeline_name=_pipeline_name,
          pipeline_root=_pipeline_root,
          data_root=_data_root,
          module_file=_module_file,
          serving_model_dir=_serving_model_dir,
          beam_pipeline_args=_beam_pipeline_args))
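
The __main__ block above runs with the runner's default configuration. If a custom TFX image is needed, a KubernetesDagRunnerConfig can be passed in, mirroring Example #1; the image URI below is a placeholder.

kubernetes_dag_runner.KubernetesDagRunner(
    config=kubernetes_dag_runner.KubernetesDagRunnerConfig(
        tfx_image='gcr.io/my-project/my-tfx-image:latest')).run(
            create_pipeline(
                pipeline_name=_pipeline_name,
                pipeline_root=_pipeline_root,
                data_root=_data_root,
                module_file=_module_file,
                serving_model_dir=_serving_model_dir,
                beam_pipeline_args=_beam_pipeline_args))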
Example #5
  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=[
          example_gen,
          statistics_gen,
          schema_gen,
          example_validator,
          transform,
          trainer,
          model_resolver,
          evaluator,
          pusher,
      ],
      enable_cache=False,
      metadata_connection_config=config,
      # TODO(b/142684737): The multi-processing API might change.
      beam_pipeline_args=['--direct_num_workers=%d' % direct_num_workers])


if __name__ == '__main__':
  absl.logging.set_verbosity(absl.logging.INFO)

  kubernetes_dag_runner.KubernetesDagRunner().run(
      create_pipeline(
          pipeline_name=_pipeline_name,
          pipeline_root=_pipeline_root,
          data_root=_data_root,
          module_file=_module_file,
          serving_model_dir=_serving_model_dir,
          # 0 means auto-detect based on the number of CPUs available during
          # execution time.
          direct_num_workers=0))
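
The config object passed as metadata_connection_config in the snippets above is defined outside the excerpt. One plausible definition for an in-cluster deployment is an MLMD gRPC client config; the service host and port below are placeholders, not values from the original code.

from ml_metadata.proto import metadata_store_pb2

# Hypothetical address of an MLMD gRPC service running in the cluster.
config = metadata_store_pb2.MetadataStoreClientConfig(
    host='metadata-grpc-service.kubeflow.svc.cluster.local',
    port=8080)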
Example #6
def main():
    # First, create the tfx pipeline instance.
    pipeline = _create_pipeline()
    # Use kubernetes dag runner to run the pipeline.
    kubernetes_dag_runner.KubernetesDagRunner().run(pipeline=pipeline)
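
_create_pipeline() is not shown in this excerpt. Below is a minimal hedged sketch of the shape such a function typically has, reusing only constructs that appear in the examples above; the component, paths, and names are placeholders.

from ml_metadata.proto import metadata_store_pb2
from tfx.components import CsvExampleGen
from tfx.orchestration import pipeline


def _create_pipeline() -> pipeline.Pipeline:
    # Hypothetical single-component pipeline; a real pipeline would include
    # the rest of the components listed in Example #4.
    example_gen = CsvExampleGen(input_base='/path/to/csv_data')
    return pipeline.Pipeline(
        pipeline_name='my_pipeline',
        pipeline_root='gs://my-bucket/pipeline_root',
        components=[example_gen],
        enable_cache=False,
        metadata_connection_config=metadata_store_pb2.ConnectionConfig())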