Пример #1
0
    def get_schema(self):
        """Copies the latest SchemaGen schema into the current directory and prints it.

        Reads the pipeline root from
        `<handler_home_dir>/<pipeline_name>/pipeline_args.json`. It then picks the
        SchemaGen `schema` output folder with the largest numeric name and
        copies its `schema.pbtxt` to `<cwd>/schema.pbtxt`, overwriting any
        existing file. Finally it echoes the destination path, a header and the
        schema text.

        Exits the process via `sys.exit` with a user-facing message if the
        pipeline root does not exist yet or no SchemaGen output can be found.
        """
        pipeline_name = self.flags_dict[labels.PIPELINE_NAME]

        # Check if pipeline exists.
        self._check_pipeline_existence(pipeline_name)

        # Path to pipeline args.
        pipeline_args_path = os.path.join(self._handler_home_dir, pipeline_name,
                                          'pipeline_args.json')

        # Get pipeline_root.
        with open(pipeline_args_path, 'r') as f:
            pipeline_args = json.load(f)

        # Check if pipeline root created. If not, it means that the user has not
        # created a run yet or the pipeline is still running for the first time.
        pipeline_root = pipeline_args[labels.PIPELINE_ROOT]
        if not tf.io.gfile.exists(pipeline_root):
            sys.exit(
                'Create a run before inferring schema. If pipeline is already running, then wait for it to successfully finish.'
            )

        schemagen_missing_msg = (
            'Either SchemaGen component does not exist or pipeline is still running. If pipeline is running, then wait for it to successfully finish.'
        )

        # If pipeline_root exists, then check if SchemaGen output exists.
        # Some filesystems (e.g. GCS) report directory entries with a trailing
        # '/', so normalize names before comparing them.
        components = [
            name.rstrip('/') for name in tf.io.gfile.listdir(pipeline_root)
        ]
        if 'SchemaGen' not in components:
            sys.exit(schemagen_missing_msg)

        # Get the latest SchemaGen output.
        component_output_dir = os.path.join(pipeline_root, 'SchemaGen')
        schema1_uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
            component_output_dir, 'schema', 1)
        # Parent of the folder generated for id 1, i.e. the directory whose
        # entries are candidates for the third argument of _generate_output_uri.
        schema_dir = os.path.join(os.path.dirname(schema1_uri), '')
        if tf.io.gfile.exists(schema_dir):
            schemagen_outputs = [
                name.rstrip('/') for name in tf.io.gfile.listdir(schema_dir)
            ]
        else:
            schemagen_outputs = []
        # Only purely numeric folder names can be ranked with int(); skip
        # anything else instead of crashing.
        schemagen_outputs = [
            name for name in schemagen_outputs if name.isdecimal()
        ]
        if not schemagen_outputs:
            sys.exit(schemagen_missing_msg)
        # Keep the original folder name (not int(name)) so the path is exact.
        latest_schema_folder = max(schemagen_outputs, key=int)

        # Copy schema to current dir.
        latest_schema_uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
            component_output_dir, 'schema', latest_schema_folder)
        latest_schema_path = os.path.join(latest_schema_uri, 'schema.pbtxt')
        curr_dir_path = os.path.join(os.getcwd(), 'schema.pbtxt')
        io_utils.copy_file(latest_schema_path, curr_dir_path, overwrite=True)

        # Print schema and path to schema
        click.echo('Path to schema: {}'.format(curr_dir_path))
        click.echo('*********SCHEMA FOR {}**********'.format(
            pipeline_name.upper()))
        with open(curr_dir_path, 'r') as f:
            click.echo(f.read())
Пример #2
0
  def _prepare_output_artifacts(
      self,
      input_artifacts: Dict[Text, List[types.Artifact]],
      output_dict: Dict[Text, types.Channel],
      exec_properties: Dict[Text, Any],
      execution_id: int,
      pipeline_info: data_types.PipelineInfo,
      component_info: data_types.ComponentInfo,
  ) -> Dict[Text, List[types.Artifact]]:
    """Overrides BaseDriver._prepare_output_artifacts().

    Unwraps the output channels and takes the single Examples artifact. It then
    sets the artifact's URI via `base_driver._generate_output_uri` under
    `<pipeline_root>/<component_id>` and records the fingerprint and span from
    `exec_properties` as string custom properties. Finally it prepares the
    artifact's output paths.

    Args:
      input_artifacts: Unused.
      output_dict: Output channels; must unwrap to exactly one entry, keyed by
        `utils.EXAMPLES_KEY`.
      exec_properties: Must contain `utils.FINGERPRINT_PROPERTY_NAME` and
        `utils.SPAN_PROPERTY_NAME`.
      execution_id: Passed to `_generate_output_uri` to build the output URI.
      pipeline_info: Supplies the pipeline root.
      component_info: Supplies the component id.

    Returns:
      The unwrapped output dict containing the prepared Examples artifact.

    Raises:
      RuntimeError: If `output_dict` does not unwrap to exactly one entry.
    """
    del input_artifacts

    result = channel_utils.unwrap_channel_dict(output_dict)
    if len(result) != 1:
      raise RuntimeError('Multiple output artifacts are not supported.')

    base_output_dir = os.path.join(pipeline_info.pipeline_root,
                                   component_info.component_id)

    example_artifact = artifact_utils.get_single_instance(
        result[utils.EXAMPLES_KEY])
    example_artifact.uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
        base_output_dir, utils.EXAMPLES_KEY, execution_id)
    example_artifact.set_string_custom_property(
        utils.FINGERPRINT_PROPERTY_NAME,
        exec_properties[utils.FINGERPRINT_PROPERTY_NAME])
    # The span exec property may be an int (e.g. 2); a string custom property
    # needs a str, so convert explicitly.
    example_artifact.set_string_custom_property(
        utils.SPAN_PROPERTY_NAME,
        str(exec_properties[utils.SPAN_PROPERTY_NAME]))

    base_driver._prepare_output_paths(example_artifact)  # pylint: disable=protected-access

    return result
Пример #3
0
  def testPrepareOutputArtifacts(self):
    """Checks the URI and custom properties set by _prepare_output_artifacts."""
    output_dict = {
        utils.EXAMPLES_KEY:
            channel_utils.as_channel([standard_artifacts.Examples()])
    }
    exec_properties = {
        utils.SPAN_PROPERTY_NAME: 2,
        utils.VERSION_PROPERTY_NAME: 1,
        utils.FINGERPRINT_PROPERTY_NAME: 'fp',
    }
    pipeline_info = data_types.PipelineInfo(
        pipeline_name='name', pipeline_root=self._test_dir, run_id='rid')
    component_info = data_types.ComponentInfo(
        component_type='type', component_id='cid', pipeline_info=pipeline_info)

    # Execution id 1, no input artifacts.
    prepared = self._example_gen_driver._prepare_output_artifacts(
        {}, output_dict, exec_properties, 1, pipeline_info, component_info)
    artifact = artifact_utils.get_single_instance(prepared[utils.EXAMPLES_KEY])

    expected_uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
        os.path.join(self._test_dir, component_info.component_id), 'examples',
        1)
    self.assertEqual(artifact.uri, expected_uri)

    # Every exec property is expected back as a string custom property.
    expected_properties = (
        (utils.FINGERPRINT_PROPERTY_NAME, 'fp'),
        (utils.SPAN_PROPERTY_NAME, '2'),
        (utils.VERSION_PROPERTY_NAME, '1'),
    )
    for key, expected in expected_properties:
      self.assertEqual(artifact.get_string_custom_property(key), expected)
Пример #4
0
    def _prepare_output_artifacts(
        self,
        input_artifacts: Dict[Text, List[types.Artifact]],
        output_dict: Dict[Text, types.Channel],
        exec_properties: Dict[Text, Any],
        execution_id: int,
        pipeline_info: data_types.PipelineInfo,
        component_info: data_types.ComponentInfo,
    ) -> Dict[Text, List[types.Artifact]]:
        """Overrides BaseDriver._prepare_output_artifacts().

        Creates a new artifact of the Examples output channel's type and sets
        its URI via `base_driver._generate_output_uri` under
        `<pipeline_root>/<component_id>`. It records fingerprint, span and (when
        present) version from `exec_properties` as string custom properties,
        then prepares the artifact's output paths.

        Args:
          input_artifacts: Unused.
          output_dict: Output channels; must contain `utils.EXAMPLES_KEY`.
          exec_properties: Must contain `utils.FINGERPRINT_PROPERTY_NAME` and
            `utils.SPAN_PROPERTY_NAME`. `utils.VERSION_PROPERTY_NAME` is
            optional; it is recorded only when present and not None.
          execution_id: Passed to `_generate_output_uri` to build the output URI.
          pipeline_info: Supplies the pipeline root.
          component_info: Supplies the component id.

        Returns:
          A dict mapping `utils.EXAMPLES_KEY` to a one-element list holding the
          new artifact.
        """
        del input_artifacts

        example_artifact = output_dict[utils.EXAMPLES_KEY].type()
        base_output_dir = os.path.join(pipeline_info.pipeline_root,
                                       component_info.component_id)

        example_artifact.uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
            base_output_dir, utils.EXAMPLES_KEY, execution_id)
        example_artifact.set_string_custom_property(
            utils.FINGERPRINT_PROPERTY_NAME,
            exec_properties[utils.FINGERPRINT_PROPERTY_NAME])
        example_artifact.set_string_custom_property(
            utils.SPAN_PROPERTY_NAME,
            str(exec_properties[utils.SPAN_PROPERTY_NAME]))
        # TODO(b/162622803): add default behavior for when version spec not present.
        # Compare against None rather than truthiness so that version 0 is still
        # recorded, and use .get() so a missing key means "no version" instead
        # of raising KeyError.
        version = exec_properties.get(utils.VERSION_PROPERTY_NAME)
        if version is not None:
            example_artifact.set_string_custom_property(
                utils.VERSION_PROPERTY_NAME, str(version))

        base_driver._prepare_output_paths(example_artifact)  # pylint: disable=protected-access

        return {utils.EXAMPLES_KEY: [example_artifact]}
Пример #5
0
    def artifact_dir(self,
                     artifact_root: str,
                     artifact_subdir: str = "") -> str:
        """Returns the full path to the artifact subdir under the pipeline root.

        For example, to get the transformed examples for the train split:

          `self.artifact_dir('Transform.AutoData/transformed_examples', 'train/*')`

        would return the path

          '<pipeline_root>/Transform.AutoData/transformed_examples/4/train/*'.

        Assumes there is a single execution per component.

        Args:
          artifact_root: Root subdirectory to specify the component's artifacts.
          artifact_subdir: Optional subdirectory to append after the execution
            ID.

        Returns:
          The full path to the artifact subdir under the pipeline root.

        Raises:
          ValueError: If `<pipeline_root>/<artifact_root>` does not contain
            exactly one entry.
        """
        entries = tf.io.gfile.listdir(
            os.path.join(self.pipeline_root, artifact_root))
        if len(entries) != 1:
            raise ValueError(f"Expected a single artifact dir, got: {entries}")
        # Exactly one entry: the sole execution's output folder.
        (execution_dir,) = entries
        execution_uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
            self.pipeline_root, artifact_root, execution_dir)
        return os.path.join(execution_uri, artifact_subdir)
Пример #6
0
    def testPipelineSchemaSuccessfulRun(self):
        """get_schema copies the SchemaGen schema to cwd and prints headers."""
        # First create a pipeline.
        flags_dict = {
            labels.ENGINE_FLAG: self.engine,
            labels.PIPELINE_DSL_PATH: self.pipeline_path
        }
        handler = beam_handler.BeamHandler(flags_dict)
        handler.create_pipeline()

        flags_dict = {
            labels.ENGINE_FLAG: self.engine,
            labels.PIPELINE_NAME: self.pipeline_name,
        }
        handler = beam_handler.BeamHandler(flags_dict)
        # Create fake schema in pipeline root.
        component_output_dir = os.path.join(self.pipeline_root, 'SchemaGen')
        schema_path = base_driver._generate_output_uri(  # pylint: disable=protected-access
            component_output_dir, 'schema', 3)

        tf.io.gfile.makedirs(schema_path)
        with open(os.path.join(schema_path, 'schema.pbtxt'), 'w') as f:
            f.write('SCHEMA')
        with self.captureWritesToStream(sys.stdout) as captured:
            handler.get_schema()
            curr_dir_path = os.path.join(os.getcwd(), 'schema.pbtxt')
            self.assertIn('Path to schema: {}'.format(curr_dir_path),
                          captured.contents())
            self.assertIn(
                '*********SCHEMA FOR {}**********'.format(
                    self.pipeline_name.upper()), captured.contents())
            self.assertTrue(tf.io.gfile.exists(curr_dir_path))
        # Existence alone does not prove the right file was copied; check that
        # the content is the fake schema written above.
        with open(curr_dir_path, 'r') as f:
            self.assertEqual(f.read(), 'SCHEMA')
 def assertInfraValidatorPassed(self) -> None:
   """Asserts each InfraValidator blessing output contains an INFRA_BLESSED file."""
   validator_dir = os.path.join(self._pipeline_root, 'InfraValidator')
   execution_ids = tf.io.gfile.listdir(os.path.join(validator_dir, 'blessing'))
   # At least one InfraValidator execution must have produced a blessing dir.
   self.assertGreaterEqual(len(execution_ids), 1)
   for execution_id in execution_ids:
     marker = os.path.join(
         base_driver._generate_output_uri(  # pylint: disable=protected-access
             validator_dir, 'blessing', execution_id),
         'INFRA_BLESSED')
     self.assertTrue(tf.io.gfile.exists(marker))
Пример #8
0
    def _prepare_output_artifacts(
        self,
        input_artifacts: Dict[Text, List[types.Artifact]],
        output_dict: Dict[Text, types.Channel],
        exec_properties: Dict[Text, Any],
        execution_id: int,
        pipeline_info: data_types.PipelineInfo,
        component_info: data_types.ComponentInfo,
    ) -> Dict[Text, List[types.Artifact]]:
        """Overrides BaseDriver._prepare_output_artifacts().

        Builds a new artifact of the Examples output channel's type. Its URI is
        set via `base_driver._generate_output_uri` under
        `<pipeline_root>/<component_id>`. `_update_output_artifact` (defined
        elsewhere) is then called with `exec_properties` and the artifact's MLMD
        proto, and the artifact's output paths are prepared.

        Args:
          input_artifacts: Unused.
          output_dict: Output channels; must contain `utils.EXAMPLES_KEY`.
          exec_properties: Forwarded to `_update_output_artifact`.
          execution_id: Passed to `_generate_output_uri` to build the output URI.
          pipeline_info: Supplies the pipeline root.
          component_info: Supplies the component id.

        Returns:
          A dict mapping `utils.EXAMPLES_KEY` to a one-element list holding the
          new artifact.
        """
        del input_artifacts

        artifact = output_dict[utils.EXAMPLES_KEY].type()
        output_dir = os.path.join(pipeline_info.pipeline_root,
                                  component_info.component_id)
        artifact.uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
            output_dir, utils.EXAMPLES_KEY, execution_id)

        _update_output_artifact(exec_properties, artifact.mlmd_artifact)
        base_driver._prepare_output_paths(artifact)  # pylint: disable=protected-access

        return {utils.EXAMPLES_KEY: [artifact]}