Example #1
def ResolveSplitsConfig(
        splits_config_str: Optional[str],
        examples: List[types.Artifact]) -> transform_pb2.SplitsConfig:
    """Resolve SplitsConfig proto for the transfrom request."""
    result = transform_pb2.SplitsConfig()
    if splits_config_str:
        proto_utils.json_to_proto(splits_config_str, result)
        if not result.analyze:
            raise ValueError(
                'analyze cannot be empty when splits_config is set.')
        return result

    result.analyze.append('train')

    # All input artifacts should have the same set of split names.
    split_names = set(
        artifact_utils.decode_split_names(examples[0].split_names))

    for artifact in examples:
        artifact_split_names = set(
            artifact_utils.decode_split_names(artifact.split_names))
        if split_names != artifact_split_names:
            raise ValueError(
                'Not all input artifacts have the same split names: (%s, %s)' %
                (split_names, artifact_split_names))

    result.transform.extend(split_names)
    logging.info("Analyze the 'train' split and transform all splits when "
                 'splits_config is not set.')
    return result
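All of these examples lean on the same round trip: encode_split_names serializes a list of split names into the artifact's split_names string property, and decode_split_names parses it back. Below is a minimal self-contained sketch, assuming (as the usage on this page suggests) that the encoding is a JSON list; the real helpers live in tfx.types.artifact_utils:

import json
from typing import List

def encode_split_names(splits: List[str]) -> str:
    # Simplified stand-in for artifact_utils.encode_split_names.
    return json.dumps(splits)

def decode_split_names(split_names: str) -> List[str]:
    # Simplified stand-in for artifact_utils.decode_split_names.
    if not split_names:  # An unset property decodes to no splits.
        return []
    return json.loads(split_names)

assert decode_split_names(encode_split_names(['train', 'eval'])) == ['train', 'eval']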
Example #2
    def testDoWithOutputExamplesTwoSplitsSampledOneSplitCopied(self):
        self._exec_properties[SPLITS_TO_TRANSFORM_KEY] = json.dumps(
            ['train', 'eval'])
        self._exec_properties[SPLITS_TO_COPY_KEY] = json.dumps(['unlabelled'])

        # Run executor.
        stratified_sampler = executor.Executor(self._context)
        stratified_sampler.Do(self._input_dict, self._output_dict_sr,
                              self._exec_properties)

        # Check outputs.
        self.assertTrue(fileio.exists(self._stratified_examples_dir))
        self.assertIn(
            'train',
            artifact_utils.decode_split_names(
                self._sampling_result.split_names))
        self.assertIn(
            'eval',
            artifact_utils.decode_split_names(
                self._sampling_result.split_names))
        self.assertIn(
            'unlabelled',
            artifact_utils.decode_split_names(
                self._sampling_result.split_names))
        self.assertLen(
            artifact_utils.decode_split_names(
                self._sampling_result.split_names), 3)
        self._verify_stratified_example_split('train')
        self._verify_copied_example_split('unlabelled')
        self._verify_stratified_example_split('eval')
Example #3
    def _run_model_inference(
            self, data_spec: bulk_inferrer_pb2.DataSpec,
            examples: List[types.Artifact], output_uri: Text,
            inference_endpoint: model_spec_pb2.InferenceSpecType) -> bool:
        """Runs model inference on given example data.

        Args:
          data_spec: bulk_inferrer_pb2.DataSpec instance.
          examples: List of example artifacts.
          output_uri: Output artifact uri.
          inference_endpoint: Model inference endpoint.

        Returns:
          Whether the inference job succeeded.
        """

        example_uris = {}
        if data_spec.example_splits:
            for example in examples:
                for split in artifact_utils.decode_split_names(
                        example.split_names):
                    if split in data_spec.example_splits:
                        example_uris[split] = os.path.join(example.uri, split)
        else:
            for example in examples:
                for split in artifact_utils.decode_split_names(
                        example.split_names):
                    example_uris[split] = os.path.join(example.uri, split)
        output_path = os.path.join(output_uri, _PREDICTION_LOGS_DIR_NAME)
        logging.info('BulkInferrer generates prediction log to %s',
                     output_path)

        with self._make_beam_pipeline() as pipeline:
            data_list = []
            for split, example_uri in example_uris.items():
                data = (
                    pipeline
                    | 'ReadData[{}]'.format(split) >> beam.io.ReadFromTFRecord(
                        file_pattern=io_utils.all_files_pattern(example_uri)))
                data_list.append(data)
            _ = (
                data_list
                | 'FlattenExamples' >> beam.Flatten(pipeline=pipeline)
                # TODO(b/131873699): Use the correct Example type here, which
                # is either Example or SequenceExample.
                | 'ParseExamples' >> beam.Map(tf.train.Example.FromString)
                | 'RunInference' >>
                run_inference.RunInference(inference_endpoint)
                | 'WritePredictionLogs' >> beam.io.WriteToTFRecord(
                    output_path,
                    file_name_suffix='.gz',
                    coder=beam.coders.ProtoCoder(
                        prediction_log_pb2.PredictionLog)))
        logging.info('Inference result written to %s.', output_path)
        return True
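The split-to-URI mapping built above recurs throughout these examples (see also Examples #24 and #27). Here is a small helper distilled from that loop; the name split_uri_map is hypothetical, and the <artifact.uri>/<split> directory layout is assumed from the surrounding code:

import os
from typing import Dict, List, Optional

from tfx import types
from tfx.types import artifact_utils

def split_uri_map(
        artifacts: List[types.Artifact],
        wanted_splits: Optional[List[str]] = None) -> Dict[str, str]:
    """Maps each (wanted) split name to its directory under the artifact URI."""
    uris = {}
    for artifact in artifacts:
        for split in artifact_utils.decode_split_names(artifact.split_names):
            if wanted_splits is None or split in wanted_splits:
                uris[split] = os.path.join(artifact.uri, split)
    return uris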
Example #4
  def testDoWithOutputExamplesAllSplits(self):
    self._exec_properties[SPLITS_TO_TRANSFORM_KEY] = json.dumps(
        ['eval', 'train'])

    # Run executor.
    stratified_sampler = executor.Executor(self._context)
    stratified_sampler.Do(self._input_dict, self._output_dict_sr,
                          self._exec_properties)

    # Check outputs.
    self.assertTrue(fileio.exists(self._filtered_examples_dir))
    self.assertIn(
        'train',
        artifact_utils.decode_split_names(self._filtering_result.split_names))
    self.assertIn(
        'eval',
        artifact_utils.decode_split_names(self._filtering_result.split_names))
    self.assertLen(
        artifact_utils.decode_split_names(self._filtering_result.split_names),
        2)
    self._verify_filtered_example_split('train')
    self._verify_filtered_example_split('eval')
Example #5
  def testDoWithOutputExamplesEvalSplit(self):
    self._exec_properties['splits_to_transform'] = json.dumps(['eval'])

    # Run executor.
    stratified_sampler = executor.Executor(self._context)
    stratified_sampler.Do(self._input_dict, self._output_dict_sr,
                          self._exec_properties)

    # Check outputs.
    self.assertTrue(fileio.exists(self._filtered_examples_dir))
    self.assertNotIn(
        'train',
        artifact_utils.decode_split_names(self._filtering_result.split_names))
    self.assertIn(
        'eval',
        artifact_utils.decode_split_names(self._filtering_result.split_names))
    self.assertLen(
        artifact_utils.decode_split_names(self._filtering_result.split_names),
        1)
    self._verify_filtered_example_split('eval')
Example #6
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]) -> None:
        self._log_startup(input_dict, output_dict, exec_properties)
        logging.info('Validating schema against the computed statistics.')

        split_names: List[Text] = []
        for artifact in input_dict[executor.STATISTICS_KEY]:
            for split in artifact_utils.decode_split_names(
                    artifact.split_names):
                split_names.append(split)

        label_inputs = {
            labels.STATS:
            tfdv.load_statistics(
                io_utils.get_only_uri_in_dir(
                    artifact_utils.get_split_uri(
                        input_dict[executor.STATISTICS_KEY],
                        split_names[0]))),
            labels.SCHEMA:
            io_utils.SchemaReader().read(
                io_utils.get_only_uri_in_dir(
                    artifact_utils.get_single_uri(
                        input_dict[executor.SCHEMA_KEY])))
        }
        output_uri = artifact_utils.get_single_uri(
            output_dict[executor.ANOMALIES_KEY])
        label_outputs = {labels.SCHEMA_DIFF_PATH: output_uri}
        self._Validate(label_inputs, label_outputs)
        logging.info(
            'Validation complete. Anomalies written to {}.'.format(output_uri))
Example #7
def SynthesizeGraph(identified_examples: InputArtifact[Examples],
                    synthesized_graph: OutputArtifact[SynthesizedGraph],
                    similarity_threshold: Parameter[float],
                    component_name: Parameter[str]) -> None:

    # Get a list of the splits in input_data
    splits_list = artifact_utils.decode_split_names(
        split_names=identified_examples.split_names)

    # We build a graph only based on the 'train' split which includes both
    # labeled and unlabeled examples.
    train_input_examples_uri = os.path.join(identified_examples.uri, 'train')
    print('Train input examples:', train_input_examples_uri)
    output_graph_uri = os.path.join(synthesized_graph.uri, 'train')
    os.mkdir(output_graph_uri)

    print('Creating embeddings...')
    create_embeddings(train_input_examples_uri, output_graph_uri)

    print('Synthesizing graph...')
    build_graph(output_graph_uri, similarity_threshold)

    synthesized_graph.split_names = artifact_utils.encode_split_names(
        splits=['train'])

    return
Example #8
def IdentifyExamples(orig_examples: InputArtifact[Examples],
                     identified_examples: OutputArtifact[Examples],
                     id_feature_name: Parameter[str],
                     component_name: Parameter[str]) -> None:

    # Get a list of the splits in input_data
    splits_list = artifact_utils.decode_split_names(
        split_names=orig_examples.split_names)

    for split in splits_list:
        input_dir = os.path.join(orig_examples.uri, split)
        output_dir = os.path.join(identified_examples.uri, split)
        os.mkdir(output_dir)
        with beam.Pipeline() as pipeline:
            (pipeline
             | 'ReadExamples' >> beam.io.ReadFromTFRecord(
                 os.path.join(input_dir, '*'),
                 coder=beam.coders.coders.ProtoCoder(tf.train.Example))
             | 'AddUniqueId' >> beam.Map(make_example_with_unique_id,
                                         id_feature_name)
             | 'WriteIdentifiedExamples' >> beam.io.WriteToTFRecord(
                 file_path_prefix=os.path.join(output_dir, 'data_tfrecord'),
                 coder=beam.coders.coders.ProtoCoder(tf.train.Example),
                 file_name_suffix='.gz'))
    identified_examples.split_names = artifact_utils.encode_split_names(
        splits=splits_list)

    return
Example #9
def _prepare_output_paths(artifact: types.Artifact):
    """Create output directories for output artifact."""
    if tf.io.gfile.exists(artifact.uri):
        msg = 'Output artifact uri %s already exists' % artifact.uri
        absl.logging.warning(msg)
        # TODO(b/158689199): We currently simply return as a short-term workaround
        # to unblock execution retries. A comprehensive solution to guarantee
        # idempotent executions is needed.
        return

    # TODO(b/147242148): Introduce principled artifact structure (directory
    # or file) definition.
    if isinstance(artifact, types.ValueArtifact):
        artifact_dir = os.path.dirname(artifact.uri)
    else:
        artifact_dir = artifact.uri

    # TODO(zhitaoli): Consider refactoring this out into something
    # which can handle permission bits.
    absl.logging.debug('Creating output artifact uri %s as directory',
                       artifact_dir)
    tf.io.gfile.makedirs(artifact_dir)
    # TODO(b/147242148): Avoid special-casing the "split_names" property.
    if artifact.type.PROPERTIES and 'split_names' in artifact.type.PROPERTIES:
        split_names = artifact_utils.decode_split_names(artifact.split_names)
        for split in split_names:
            split_dir = os.path.join(artifact.uri, split)
            absl.logging.debug('Creating output split %s as directory',
                               split_dir)
            tf.io.gfile.makedirs(split_dir)
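A hypothetical usage sketch for _prepare_output_paths as defined above; the URI is made up, and standard_artifacts.Examples is used because its type declares the 'split_names' property that triggers the per-split directories:

from tfx.types import artifact_utils, standard_artifacts

examples = standard_artifacts.Examples()
examples.uri = '/tmp/pipeline/example_gen/examples/1'  # Made-up URI.
examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])

# Creates .../examples/1/train and .../examples/1/eval (plus the base dir).
_prepare_output_paths(examples)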
Example #10
    def _callImporterDriver(self, reimport: bool):
        with metadata.Metadata(connection_config=self.connection_config) as m:
            m.publish_artifacts(self.existing_artifacts)
            driver = importer_node.ImporterDriver(metadata_handler=m)
            execution_result = driver.pre_execution(
                component_info=self.component_info,
                pipeline_info=self.pipeline_info,
                driver_args=self.driver_args,
                input_dict={},
                output_dict=self.output_dict,
                exec_properties={
                    importer_node.SOURCE_URI_KEY: self.source_uri,
                    importer_node.REIMPORT_OPTION_KEY: reimport,
                    importer_node.SPLIT_KEY: self.split,
                })
            self.assertFalse(execution_result.use_cached_results)
            self.assertEmpty(execution_result.input_dict)
            self.assertEqual(
                execution_result.output_dict[importer_node.IMPORT_RESULT_KEY]
                [0].uri, self.source_uri[0])
            self.assertEqual(
                execution_result.output_dict[importer_node.IMPORT_RESULT_KEY]
                [0].id, 3 if reimport else 1)

            self.assertNotEmpty(
                self.output_dict[importer_node.IMPORT_RESULT_KEY].get())

            results = self.output_dict[importer_node.IMPORT_RESULT_KEY].get()
            for res, uri, split in zip(results, self.source_uri, self.split):
                self.assertEqual(res.uri, uri)
                self.assertEqual(
                    artifact_utils.decode_split_names(res.split_names)[0],
                    split)
Example #11
    def ReadExamplesArtifact(self,
                             examples: types.Artifact,
                             num_examples: int,
                             split_name: Optional[Text] = None):
        """Read records from Examples artifact.

        Currently it assumes the Examples artifact contains serialized
        tf.Example records in gzipped TFRecord files.

        Args:
          examples: `Examples` artifact.
          num_examples: Number of examples to read. If the specified value is
            larger than the actual number of examples, all examples will be
            read.
          split_name: Name of the split to read from the Examples artifact.

        Raises:
          RuntimeError: If read twice.
        """
        if self._records:
            raise RuntimeError('Cannot read records twice.')

        if num_examples < 1:
            raise ValueError('num_examples < 1 (got {})'.format(num_examples))

        available_splits = artifact_utils.decode_split_names(
            examples.split_names)
        if not available_splits:
            raise ValueError(
                'No split_name is available in given Examples artifact.')
        if split_name is None:
            split_name = available_splits[0]
        if split_name not in available_splits:
            raise ValueError(
                'No split_name {}; available split names: {}'.format(
                    split_name, ', '.join(available_splits)))

        # ExampleGen generates artifacts under each split_name directory.
        glob_pattern = os.path.join(examples.uri, split_name, '*')
        tfxio_factory = tfxio_utils.get_tfxio_factory_from_artifact(
            examples=[examples],
            telemetry_descriptors=_TELEMETRY_DESCRIPTORS,
            schema=None,
            read_as_raw_records=True,
            raw_record_column_name=_RAW_RECORDS_COLUMN)
        try:
            filenames = fileio.glob(glob_pattern)
        except tf.errors.NotFoundError:
            filenames = []
        if not filenames:
            raise ValueError(
                'Unable to find examples matching {}.'.format(glob_pattern))

        self._payload_format = examples_utils.get_payload_format(examples)
        tfxio = tfxio_factory(filenames)

        self._ReadFromDataset(
            tfxio.TensorFlowDataset(
                dataset_options.TensorFlowDatasetOptions(
                    batch_size=num_examples)))
Example #12
 def display(self, artifact: types.Artifact):
     from IPython.core.display import display  # pylint: disable=g-import-not-at-top
     from IPython.core.display import HTML  # pylint: disable=g-import-not-at-top
     for split in artifact_utils.decode_split_names(artifact.split_names):
         display(HTML('<div><b>%r split:</b></div><br/>' % split))
         stats_path = os.path.join(artifact.uri, split, 'stats_tfrecord')
         stats = tfdv.load_statistics(stats_path)
         tfdv.visualize_statistics(stats)
Example #13
  def Do(self, input_dict: Dict[Text, List[types.Artifact]],
         output_dict: Dict[Text, List[types.Artifact]],
         exec_properties: Dict[Text, Any]) -> None:
    """TensorFlow SchemaGen executor entrypoint.

    This infers the schema using tensorflow_data_validation on the precomputed
    stats of 'train' split.

    Args:
      input_dict: Input dict from input key to a list of artifacts, including:
        - 'stats': A list of 'ExampleStatistics' type which must contain
          split 'train'. Stats on other splits are ignored.
        - 'statistics': Synonym for 'stats'.
      output_dict: Output dict from key to a list of artifacts, including:
        - output: A list of 'Schema' artifact of size one.
      exec_properties: A dict of execution properties, including:
        - infer_feature_shape: Whether or not to infer the shape of the feature.
        - exclude_splits: Names of splits that will not be taken into
          consideration when auto-generating a schema.

    Returns:
      None
    """
    # TODO(zhitaoli): Move constants between this file and component.py to a
    # constants.py.
    infer_feature_shape = exec_properties.get(INFER_FEATURE_SHAPE_KEY)

    # Load and deserialize exclude splits from execution properties.
    exclude_splits = json_utils.loads(
        exec_properties.get(EXCLUDE_SPLITS_KEY, 'null')) or []
    if not isinstance(exclude_splits, list):
      raise ValueError('exclude_splits in execution properties needs to be a '
                       'list. Got %s instead.' % type(exclude_splits))

    # Only one schema is generated for all splits.
    schema = None
    stats_artifact = artifact_utils.get_single_instance(
        input_dict[STATISTICS_KEY])
    for split in artifact_utils.decode_split_names(stats_artifact.split_names):
      if split in exclude_splits:
        continue

      logging.info('Processing schema from statistics for split %s.', split)
      stats_uri = io_utils.get_only_uri_in_dir(
          os.path.join(stats_artifact.uri, split))
      if not schema:
        schema = tfdv.infer_schema(
            tfdv.load_statistics(stats_uri), infer_feature_shape)
      else:
        schema = tfdv.update_schema(schema, tfdv.load_statistics(stats_uri),
                                    infer_feature_shape)

    output_uri = os.path.join(
        artifact_utils.get_single_uri(output_dict[SCHEMA_KEY]),
        _DEFAULT_FILE_NAME)
    io_utils.write_pbtxt_file(output_uri, schema)
    logging.info('Schema written to %s.', output_uri)
Example #14
 def testConstruct(self):
   big_query_example_gen = component.BigQueryExampleGen(query='query')
   self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                    big_query_example_gen.outputs['examples'].type_name)
   artifact_collection = big_query_example_gen.outputs['examples'].get()
   self.assertEqual(1, len(artifact_collection))
   self.assertEqual(['train', 'eval'],
                    artifact_utils.decode_split_names(
                        artifact_collection[0].split_names))
Example #15
 def display(self, artifact: types.Artifact):
     from IPython.core.display import display  # pylint: disable=g-import-not-at-top
     from IPython.core.display import HTML  # pylint: disable=g-import-not-at-top
     for split in artifact_utils.decode_split_names(artifact.split_names):
         display(HTML('<div><b>%r split:</b></div><br/>' % split))
         anomalies_path = os.path.join(artifact.uri, split,
                                       'anomalies.pbtxt')
         anomalies = tfdv.load_anomalies_text(anomalies_path)
         tfdv.display_anomalies(anomalies)
Example #16
 def testConstruct(self):
     import_example_gen = component.ImportExampleGen(input_base='path')
     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      import_example_gen.outputs['examples'].type_name)
     artifact_collection = import_example_gen.outputs['examples'].get()
     self.assertEqual(1, len(artifact_collection))
     self.assertEqual(['train', 'eval'],
                      artifact_utils.decode_split_names(
                          artifact_collection[0].split_names))
Example #17
  def __init__(self,
               statistics: types.Channel = None,
               schema: types.Channel = None,
               exclude_splits: Optional[List[Text]] = None,
               output: Optional[types.Channel] = None,
               stats: Optional[types.Channel] = None,
               instance_name: Optional[Text] = None):
    """Construct an ExampleValidator component.

    Args:
      statistics: A Channel of type `standard_artifacts.ExampleStatistics`. This
        should contain at least 'eval' split. Other splits are currently
        ignored.
      schema: A Channel of type `standard_artifacts.Schema`. _required_
      exclude_splits: Names of splits that the example validator should not
        validate. Default behavior (when exclude_splits is set to None)
        is excluding no splits.
      output: Output channel of type `standard_artifacts.ExampleAnomalies`.
      stats: Backwards compatibility alias for the 'statistics' argument.
      instance_name: Optional name assigned to this specific instance of
        ExampleValidator. Required only if multiple ExampleValidator components
        are declared in the same pipeline.  Either `stats` or `statistics` must
        be present in the arguments.
    """
    if stats:
      logging.warning(
          'The "stats" argument to the StatisticsGen component has '
          'been renamed to "statistics" and is deprecated. Please update your '
          'usage as support for this argument will be removed soon.')
      statistics = stats
    if exclude_splits is None:
      exclude_splits = []
      logging.info('Excluding no splits because exclude_splits is not set.')
    anomalies = output
    if not anomalies:
      anomalies_artifact = standard_artifacts.ExampleAnomalies()
      statistics_split_names = artifact_utils.decode_split_names(
          artifact_utils.get_single_instance(list(
              statistics.get())).split_names)
      split_names = [
          split for split in statistics_split_names
          if split not in exclude_splits
      ]
      anomalies_artifact.split_names = artifact_utils.encode_split_names(
          split_names)
      anomalies = types.Channel(
          type=standard_artifacts.ExampleAnomalies,
          artifacts=[anomalies_artifact])
    spec = ExampleValidatorSpec(
        statistics=statistics,
        schema=schema,
        exclude_splits=json_utils.dumps(exclude_splits),
        anomalies=anomalies)
    super(ExampleValidator, self).__init__(
        spec=spec, instance_name=instance_name)
Example #18
 def testConstruct(self):
   input_base = standard_artifacts.ExternalArtifact()
   csv_example_gen = component.CsvExampleGen(
       input=channel_utils.as_channel([input_base]))
   self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                    csv_example_gen.outputs['examples'].type_name)
   artifact_collection = csv_example_gen.outputs['examples'].get()
   self.assertEqual(1, len(artifact_collection))
   self.assertEqual(['train', 'eval'],
                    artifact_utils.decode_split_names(
                        artifact_collection[0].split_names))
Example #19
 def testConstructSubclassFileBased(self):
     example_gen = TestFileBasedExampleGenComponent(input_base='path')
     self.assertIn('input_base', example_gen.exec_properties)
     self.assertEqual(driver.Driver, example_gen.driver_class)
     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      example_gen.outputs['examples'].type_name)
     self.assertIsNone(example_gen.exec_properties.get('custom_config'))
     artifact_collection = example_gen.outputs['examples'].get()
     self.assertEqual(1, len(artifact_collection))
     self.assertEqual(['train', 'eval'],
                      artifact_utils.decode_split_names(
                          artifact_collection[0].split_names))
Example #20
 def testConstruct(self):
   big_query_elwc_example_gen = component.BigQueryElwcExampleGen(
       elwc_config=example_gen_pb2.ElwcConfig(
           context_feature_fields=['query_id', 'query_content']),
        query='query')
   self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                    big_query_elwc_example_gen.outputs['examples'].type_name)
   artifact_collection = big_query_elwc_example_gen.outputs['examples'].get()
   self.assertEqual(1, len(artifact_collection))
   self.assertEqual(['train', 'eval'],
                    artifact_utils.decode_split_names(
                        artifact_collection[0].split_names))
Example #21
    def ReadExamplesArtifact(self,
                             examples: types.Artifact,
                             num_examples: int,
                             split_name: Optional[Text] = None):
        """Read records from Examples artifact.

        Currently it assumes the Examples artifact contains serialized
        tf.Example records in gzipped TFRecord files.

        Args:
          examples: `Examples` artifact.
          num_examples: Number of examples to read. If the specified value is
            larger than the actual number of examples, all examples will be
            read.
          split_name: Name of the split to read from the Examples artifact.

        Raises:
          RuntimeError: If read twice.
        """
        if self._records:
            raise RuntimeError('Cannot read records twice.')

        if num_examples < 1:
            raise ValueError('num_examples < 1 (got {})'.format(num_examples))

        available_splits = artifact_utils.decode_split_names(
            examples.split_names)
        if not available_splits:
            raise ValueError(
                'No split_name is available in given Examples artifact.')
        if split_name is None:
            split_name = available_splits[0]
        if split_name not in available_splits:
            raise ValueError(
                'No split_name {}; available split names: {}'.format(
                    split_name, ', '.join(available_splits)))

        # ExampleGen generates artifacts under each split_name directory.
        glob_pattern = os.path.join(examples.uri, split_name, '*.gz')
        try:
            filenames = tf.io.gfile.glob(glob_pattern)
        except tf.errors.NotFoundError:
            filenames = []
        if not filenames:
            raise ValueError(
                'Unable to find examples matching {}.'.format(glob_pattern))

        # Assume we have a tf.Example logical format.
        self._record_format = _LogicalFormat.TF_EXAMPLE

        self._ReadFromDataset(tf.data.TFRecordDataset(filenames,
                                                      compression_type='GZIP'),
                              num_examples=num_examples)
Example #22
 def testConstructCustomExecutor(self):
     example_gen = component.FileBasedExampleGen(
         input_base='path',
         custom_executor_spec=executor_spec.ExecutorClassSpec(
             TestExampleGenExecutor))
     self.assertEqual(driver.Driver, example_gen.driver_class)
     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      example_gen.outputs['examples'].type_name)
     artifact_collection = example_gen.outputs['examples'].get()
     self.assertEqual(1, len(artifact_collection))
     self.assertEqual(['train', 'eval'],
                      artifact_utils.decode_split_names(
                          artifact_collection[0].split_names))
Example #23
 def display(self, artifact: types.Artifact):
   from IPython.core.display import display  # pylint: disable=g-import-not-at-top
   from IPython.core.display import HTML  # pylint: disable=g-import-not-at-top
   for split in artifact_utils.decode_split_names(artifact.split_names):
     display(HTML('<div><b>%r split:</b></div><br/>' % split))
     stats_path = io_utils.get_only_uri_in_dir(
         artifact_utils.get_split_uri([artifact], split))
     if artifact_utils.is_artifact_version_older_than(
         artifact, artifact_utils._ARTIFACT_VERSION_FOR_STATS_UPDATE):  # pylint: disable=protected-access
       stats = tfdv.load_statistics(stats_path)
     else:
       stats = tfdv.load_stats_binary(stats_path)
     tfdv.visualize_statistics(stats)
Example #24
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]) -> None:

        source = exec_properties[StepKeys.SOURCE]
        args = exec_properties[StepKeys.ARGS]

        c = source_utils.load_source_path_class(source)
        tokenizer_step: BaseTokenizer = c(**args)

        tokenizer_location = artifact_utils.get_single_uri(
            output_dict["tokenizer"])

        split_uris, split_names, all_files = [], [], []
        for artifact in input_dict["examples"]:
            for split in artifact_utils.decode_split_names(
                    artifact.split_names):
                split_names.append(split)
                uri = os.path.join(artifact.uri, split)
                split_uris.append((split, uri))
                all_files += path_utils.list_dir(uri)

        # Get output split path
        output_examples = artifact_utils.get_single_instance(
            output_dict["output_examples"])
        output_examples.split_names = artifact_utils.encode_split_names(
            split_names)

        if not tokenizer_step.skip_training:
            tokenizer_step.train(files=all_files)

            tokenizer_step.save(output_dir=tokenizer_location)

        with self._make_beam_pipeline() as p:
            for split, uri in split_uris:
                input_uri = io_utils.all_files_pattern(uri)

                _ = (p
                     | 'ReadData.' + split >> beam.io.ReadFromTFRecord(
                            file_pattern=input_uri)
                     | "ParseTFExFromString." + split >> beam.Map(
                            tf.train.Example.FromString)
                     | "AddTokens." + split >> beam.Map(
                            append_tf_example,
                            tokenizer_step=tokenizer_step)
                     | 'Serialize.' + split >> beam.Map(
                            lambda x: x.SerializeToString())
                     | 'WriteSplit.' + split >> WriteSplit(
                            get_split_uri(
                                output_dict["output_examples"],
                                split)))
Example #25
 def testConstruct(self):
   presto_example_gen = component.PrestoExampleGen(
       self.conn_config, query='query')
   self.assertEqual(
       self.conn_config,
       self._extract_conn_config(
           presto_example_gen.exec_properties['custom_config']))
   self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                    presto_example_gen.outputs['examples'].type_name)
   artifact_collection = presto_example_gen.outputs['examples'].get()
   self.assertEqual(1, len(artifact_collection))
   self.assertEqual(['train', 'eval'],
                    artifact_utils.decode_split_names(
                        artifact_collection[0].split_names))
Example #26
 def testConstructWithInputConfig(self):
   big_query_example_gen = component.BigQueryExampleGen(
       input_config=example_gen_pb2.Input(splits=[
           example_gen_pb2.Input.Split(name='train', pattern='query1'),
           example_gen_pb2.Input.Split(name='eval', pattern='query2'),
           example_gen_pb2.Input.Split(name='test', pattern='query3')
       ]))
   self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                    big_query_example_gen.outputs['examples'].type_name)
   artifact_collection = big_query_example_gen.outputs['examples'].get()
   self.assertEqual(1, len(artifact_collection))
   self.assertEqual(['train', 'eval', 'test'],
                    artifact_utils.decode_split_names(
                        artifact_collection[0].split_names))
Example #27
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]) -> None:
        """Computes stats for each split of input using tensorflow_data_validation.

        Args:
          input_dict: Input dict from input key to a list of Artifacts.
            - input_data: A list of type `standard_artifacts.Examples`. This
              should contain both the 'train' and 'eval' splits.
          output_dict: Output dict from output key to a list of Artifacts.
            - output: A list of type `standard_artifacts.ExampleStatistics`.
              This should contain both the 'train' and 'eval' splits.
          exec_properties: A dict of execution properties. Not used yet.

        Returns:
          None
        """
        self._log_startup(input_dict, output_dict, exec_properties)

        split_uris = []
        for artifact in input_dict['input_data']:
            for split in artifact_utils.decode_split_names(
                    artifact.split_names):
                uri = os.path.join(artifact.uri, split)
                split_uris.append((split, uri))
        with self._make_beam_pipeline() as p:
            # TODO(b/126263006): Support more stats_options through config.
            stats_options = options.StatsOptions()
            for split, uri in split_uris:
                absl.logging.info(
                    'Generating statistics for split {}'.format(split))
                input_uri = io_utils.all_files_pattern(uri)
                output_uri = artifact_utils.get_split_uri(
                    output_dict['output'], split)
                output_path = os.path.join(output_uri, _DEFAULT_FILE_NAME)
                _ = (p
                     | 'ReadData.' + split >>
                     beam.io.ReadFromTFRecord(file_pattern=input_uri)
                     | 'DecodeData.' + split >>
                     tf_example_decoder.DecodeTFExample()
                     | 'GenerateStatistics.' + split >>
                     stats_api.GenerateStatistics(stats_options)
                     | 'WriteStatsOutput.' + split >> beam.io.WriteToTFRecord(
                         output_path,
                         shard_name_template='',
                         coder=beam.coders.ProtoCoder(
                             statistics_pb2.DatasetFeatureStatisticsList)))
                absl.logging.info(
                    'Statistics for split {} written to {}.'.format(
                        split, output_uri))
Example #28
 def testConstructSubclassQueryBased(self):
     example_gen = TestQueryBasedExampleGenComponent(
         input_config=example_gen_pb2.Input(splits=[
             example_gen_pb2.Input.Split(name='single', pattern='query'),
         ]))
     self.assertEqual({}, example_gen.inputs.get_all())
     self.assertEqual(base_driver.BaseDriver, example_gen.driver_class)
     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      example_gen.outputs['examples'].type_name)
     self.assertIsNone(example_gen.exec_properties.get('custom_config'))
     artifact_collection = example_gen.outputs['examples'].get()
     self.assertEqual(1, len(artifact_collection))
     self.assertEqual(['train', 'eval'],
                      artifact_utils.decode_split_names(
                          artifact_collection[0].split_names))
Example #29
 def display(self, artifact: types.Artifact):
   from IPython.core.display import display  # pylint: disable=g-import-not-at-top
   from IPython.core.display import HTML  # pylint: disable=g-import-not-at-top
   for split in artifact_utils.decode_split_names(artifact.split_names):
     display(HTML('<div><b>%r split:</b></div><br/>' % split))
     anomalies_path = io_utils.get_only_uri_in_dir(
         artifact_utils.get_split_uri([artifact], split))
     if artifact_utils.is_artifact_version_older_than(
         artifact, artifact_utils._ARTIFACT_VERSION_FOR_ANOMALIES_UPDATE):  # pylint: disable=protected-access
       anomalies = tfdv.load_anomalies_text(anomalies_path)
     else:
       anomalies = anomalies_pb2.Anomalies()
       anomalies_bytes = io_utils.read_bytes_file(anomalies_path)
       anomalies.ParseFromString(anomalies_bytes)
     tfdv.display_anomalies(anomalies)
Example #30
 def testConstructWithInputConfig(self):
     example_gen = TestFileBasedExampleGenComponent(
         input_base='path',
         input_config=example_gen_pb2.Input(splits=[
             example_gen_pb2.Input.Split(name='train', pattern='train/*'),
             example_gen_pb2.Input.Split(name='eval', pattern='eval/*'),
             example_gen_pb2.Input.Split(name='test', pattern='test/*')
         ]))
     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      example_gen.outputs['examples'].type_name)
     artifact_collection = example_gen.outputs['examples'].get()
     self.assertEqual(1, len(artifact_collection))
     self.assertEqual(['train', 'eval', 'test'],
                      artifact_utils.decode_split_names(
                          artifact_collection[0].split_names))