Example No. 1
0
 def test_resolve_payload_format_and_data_view_uri(
         self,
         payload_formats,
         data_view_uris=None,
         data_view_ids=None,
         expected_payload_format=None,
         expected_data_view_uri=None,
         expected_error_type=None,
         expected_error_msg_regex=None):
     """Builds Examples artifacts and checks payload-format/DataView resolution.

     Each position in `payload_formats`, `data_view_uris` and `data_view_ids`
     describes one Examples artifact. Depending on the `expected_*` arguments,
     resolution is expected to either succeed with a specific
     (payload_format, data_view_uri) pair, or raise `expected_error_type`.
     """
     count = len(payload_formats)
     # Absent uri/id lists mean "property not set" for every artifact.
     uris = data_view_uris if data_view_uris is not None else [None] * count
     ids = data_view_ids if data_view_ids is not None else [None] * count

     def _make_examples_artifact(fmt, uri, view_id):
         # One artifact per input triple; unset properties are simply skipped.
         result = standard_artifacts.Examples()
         examples_utils.set_payload_format(result, fmt)
         if uri is not None:
             result.set_string_custom_property(
                 constants.DATA_VIEW_URI_PROPERTY_KEY, uri)
         if view_id is not None:
             result.set_int_custom_property(
                 constants.DATA_VIEW_ID_PROPERTY_KEY, view_id)
         return result

     examples = [
         _make_examples_artifact(fmt, uri, view_id)
         for fmt, uri, view_id in zip(payload_formats, uris, ids)
     ]

     if expected_error_type is not None:
         # Failure case: resolution must raise the expected error.
         with self.assertRaisesRegex(expected_error_type,
                                     expected_error_msg_regex):
             _ = tfxio_utils.resolve_payload_format_and_data_view_uri(
                 examples)
         return

     # Success case: both resolved values must match the expectations.
     resolved_format, resolved_uri = (
         tfxio_utils.resolve_payload_format_and_data_view_uri(examples))
     self.assertEqual(resolved_format, expected_payload_format)
     self.assertEqual(resolved_uri, expected_data_view_uri)
Example No. 2
0
    def _run_model_inference(
        self,
        data_spec: bulk_inferrer_pb2.DataSpec,
        output_example_spec: bulk_inferrer_pb2.OutputExampleSpec,
        examples: List[types.Artifact],
        output_examples: Optional[types.Artifact],
        inference_result: Optional[types.Artifact],
        inference_endpoint: model_spec_pb2.InferenceSpecType,
    ) -> None:
        """Runs model inference on given examples data.

        Reads raw example records per split, runs inference against
        `inference_endpoint` in a Beam pipeline, and optionally materializes
        output examples and/or prediction logs.

        Args:
          data_spec: bulk_inferrer_pb2.DataSpec instance; if its
            `example_splits` field is non-empty, only those splits are read.
          output_example_spec: bulk_inferrer_pb2.OutputExampleSpec instance.
          examples: List of `standard_artifacts.Examples` artifacts.
          output_examples: Optional output `standard_artifacts.Examples`
            artifact; when set, inference outputs are also written back out
            as examples, one directory per split.
          inference_result: Optional output `standard_artifacts.InferenceResult`
            artifact; when set, prediction logs are written as gzipped
            TFRecord files under its URI.
          inference_endpoint: Model inference endpoint.
        """

        # Map each selected split name to its URI. An empty
        # data_spec.example_splits means "include every split".
        example_uris = {}
        for example_artifact in examples:
            for split in artifact_utils.decode_split_names(
                    example_artifact.split_names):
                if data_spec.example_splits:
                    if split in data_spec.example_splits:
                        example_uris[split] = artifact_utils.get_split_uri(
                            [example_artifact], split)
                else:
                    example_uris[split] = artifact_utils.get_split_uri(
                        [example_artifact], split)

        # The payload format tells _RunInference how to interpret the raw
        # records; the DataView URI (second tuple element) is not needed here.
        payload_format, _ = tfxio_utils.resolve_payload_format_and_data_view_uri(
            examples)

        tfxio_factory = tfxio_utils.get_tfxio_factory_from_artifact(
            examples,
            _TELEMETRY_DESCRIPTORS,
            schema=None,
            read_as_raw_records=True,
            # We have to specify this parameter in order to create a RawRecord TFXIO
            # but we won't use the RecordBatches so the column name of the raw
            # records does not matter.
            raw_record_column_name='unused')

        if output_examples:
            # Declare up front which splits the output artifact will contain,
            # before the pipeline writes them.
            output_examples.split_names = artifact_utils.encode_split_names(
                sorted(example_uris.keys()))

        with self._make_beam_pipeline() as pipeline:
            data_list = []
            for split, example_uri in example_uris.items():
                tfxio = tfxio_factory(
                    [io_utils.all_files_pattern(example_uri)])
                # Inference consumes serialized records directly, so the TFXIO
                # must support raw-record reading.
                assert isinstance(
                    tfxio, record_based_tfxio.RecordBasedTFXIO
                ), ('Unable to use TFXIO {} as it does not support reading raw records.'
                    .format(type(tfxio)))
                # pylint: disable=no-value-for-parameter
                data = (pipeline
                        | 'ReadData[{}]'.format(split) >>
                        tfxio.RawRecordBeamSource()
                        | 'RunInference[{}]'.format(split) >> _RunInference(
                            payload_format, inference_endpoint))
                if output_examples:
                    output_examples_split_uri = artifact_utils.get_split_uri(
                        [output_examples], split)
                    logging.info('Path of output examples split `%s` is %s.',
                                 split, output_examples_split_uri)
                    _ = (data
                         | 'WriteExamples[{}]'.format(split) >> _WriteExamples(
                             output_example_spec, output_examples_split_uri))
                    # pylint: enable=no-value-for-parameter

                data_list.append(data)

            if inference_result:
                # Merge per-split prediction logs into one gzipped TFRecord
                # output under the inference_result artifact URI.
                _ = (
                    data_list
                    |
                    'FlattenInferenceResult' >> beam.Flatten(pipeline=pipeline)
                    | 'WritePredictionLogs' >> beam.io.WriteToTFRecord(
                        os.path.join(inference_result.uri,
                                     _PREDICTION_LOGS_FILE_NAME),
                        file_name_suffix='.gz',
                        coder=beam.coders.ProtoCoder(
                            prediction_log_pb2.PredictionLog)))

        # The `with` block above blocks until the pipeline finishes, so the
        # outputs logged below are fully written at this point.
        if output_examples:
            logging.info('Output examples written to %s.', output_examples.uri)
        if inference_result:
            logging.info('Inference result written to %s.',
                         inference_result.uri)
Example No. 3
0
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]) -> None:
        """TensorFlow Transform executor entrypoint.

        This implements BaseExecutor.Do() and is invoked by orchestration
        systems. This is not intended for manual usage or further
        customization. Please use the Transform() function which takes an
        input format with no artifact dependency.

        Args:
          input_dict: Input dict from input key to a list of artifacts,
            including:
            - input_data: A list of type `standard_artifacts.Examples` which
              should contain two splits 'train' and 'eval'.
            - schema: A list of type `standard_artifacts.Schema` which should
              contain a single schema artifact.
          output_dict: Output dict from key to a list of artifacts, including:
            - transform_output: Output of 'tf.Transform', which includes an
              exported Tensorflow graph suitable for both training and
              serving;
            - transformed_examples: Materialized transformed examples, which
              includes both 'train' and 'eval' splits.
          exec_properties: A dict of execution properties, including either
            one of:
            - module_file: The file path to a python module file, from which
              the 'preprocessing_fn' function will be loaded.
            - preprocessing_fn: The module path to a python function that
              implements 'preprocessing_fn'.

        Returns:
          None
        """
        self._log_startup(input_dict, output_dict, exec_properties)

        # Resolve per-split input URIs; this executor expects exactly the
        # 'train' and 'eval' splits to exist on the examples artifact.
        train_data_uri = artifact_utils.get_split_uri(input_dict[EXAMPLES_KEY],
                                                      'train')
        eval_data_uri = artifact_utils.get_split_uri(input_dict[EXAMPLES_KEY],
                                                     'eval')
        payload_format, data_view_uri = (
            tfxio_utils.resolve_payload_format_and_data_view_uri(
                input_dict[EXAMPLES_KEY]))
        schema_file = io_utils.get_only_uri_in_dir(
            artifact_utils.get_single_uri(input_dict[SCHEMA_KEY]))

        transform_graph_uri = artifact_utils.get_single_uri(
            input_dict[TRANSFORM_GRAPH_KEY])
        transform_output = artifact_utils.get_single_uri(
            output_dict[TRANSFORM_OUTPUT_KEY])

        # tft.beam scratch space lives inside the transform output and is
        # deleted at the end of a successful run (see bottom of this method).
        temp_path = os.path.join(transform_output,
                                 _TEMP_DIR_IN_TRANSFORM_OUTPUT)
        absl.logging.debug('Using temp path %s for tft.beam', temp_path)

        # Materialization of transformed examples is optional: only done when
        # the caller wired up a TRANSFORMED_EXAMPLES output artifact.
        materialize_output_paths = []
        if output_dict.get(TRANSFORMED_EXAMPLES_KEY) is not None:
            transformed_example_artifact = artifact_utils.get_single_instance(
                output_dict[TRANSFORMED_EXAMPLES_KEY])
            # TODO(b/161490287): move the split_names setting to executor for all
            # components.
            transformed_example_artifact.split_names = (
                artifact_utils.encode_split_names(
                    artifact.DEFAULT_EXAMPLE_SPLITS))
            transformed_train_output = artifact_utils.get_split_uri(
                output_dict[TRANSFORMED_EXAMPLES_KEY], 'train')
            transformed_eval_output = artifact_utils.get_split_uri(
                output_dict[TRANSFORMED_EXAMPLES_KEY], 'eval')
            # Order matters: must line up with TRANSFORM_DATA_PATHS_LABEL
            # below (train first, then eval).
            materialize_output_paths = [
                os.path.join(transformed_train_output,
                             _DEFAULT_TRANSFORMED_EXAMPLES_PREFIX),
                os.path.join(transformed_eval_output,
                             _DEFAULT_TRANSFORMED_EXAMPLES_PREFIX)
            ]

        def _GetCachePath(label, params_dict):
            # Returns the single URI registered under `label`, or None when
            # the optional cache artifact was not provided.
            if label not in params_dict:
                return None
            else:
                return artifact_utils.get_single_uri(params_dict[label])

        # Label-keyed input descriptor consumed by self.Transform(); analysis
        # runs on the train split only, transformation on both splits.
        label_inputs = {
            'transform_graph_uri':
            transform_graph_uri,
            labels.COMPUTE_STATISTICS_LABEL:
            False,
            labels.SCHEMA_PATH_LABEL:
            schema_file,
            labels.EXAMPLES_DATA_FORMAT_LABEL:
            payload_format,
            labels.DATA_VIEW_LABEL:
            data_view_uri,
            labels.ANALYZE_DATA_PATHS_LABEL:
            io_utils.all_files_pattern(train_data_uri),
            labels.ANALYZE_PATHS_FILE_FORMATS_LABEL:
            labels.FORMAT_TFRECORD,
            labels.TRANSFORM_DATA_PATHS_LABEL: [
                io_utils.all_files_pattern(train_data_uri),
                io_utils.all_files_pattern(eval_data_uri)
            ],
            labels.TRANSFORM_PATHS_FILE_FORMATS_LABEL:
            [labels.FORMAT_TFRECORD, labels.FORMAT_TFRECORD],
            labels.CUSTOM_CONFIG:
            exec_properties.get('custom_config', None),
        }
        cache_input = _GetCachePath('cache_input_path', input_dict)
        if cache_input is not None:
            label_inputs[labels.CACHE_INPUT_PATH_LABEL] = cache_input

        label_outputs = {
            labels.TRANSFORM_METADATA_OUTPUT_PATH_LABEL: transform_output,
            labels.TRANSFORM_MATERIALIZE_OUTPUT_PATHS_LABEL:
            materialize_output_paths,
            labels.TEMP_OUTPUT_LABEL: str(temp_path),
        }
        cache_output = _GetCachePath('cache_output_path', output_dict)
        if cache_output is not None:
            label_outputs[labels.CACHE_OUTPUT_PATH_LABEL] = cache_output
        status_file = 'status_file'  # Unused
        self.Transform(label_inputs, label_outputs, status_file)
        # Only reached when Transform() succeeded, so the temp dir is safe to
        # remove (it is intentionally kept around on failure for debugging).
        absl.logging.debug('Cleaning up temp path %s on executor success',
                           temp_path)
        io_utils.delete_dir(temp_path)