示例#1
0
 def testConstructSubclassQueryBasedWithRangeConfig(self):
     # The placeholder @span_yyyymmdd_utc will be replaced with '19700103' in
     # the query, and span `2` will be recorded in the output Examples
     # artifact.
     expected_range = range_config_pb2.RangeConfig(
         static_range=range_config_pb2.StaticRange(
             start_span_number=2, end_span_number=2))
     query_split = example_gen_pb2.Input.Split(
         name='single',
         pattern='select * from table where date=@span_yyyymmdd_utc')
     component_under_test = TestQueryBasedExampleGenComponent(
         input_config=example_gen_pb2.Input(splits=[query_split]),
         range_config=expected_range)

     # A query-based ExampleGen has no input channels and uses the
     # query-based driver.
     self.assertEqual({}, component_under_test.inputs)
     self.assertEqual(driver.QueryBasedDriver,
                      component_under_test.driver_class)
     examples_channel = component_under_test.outputs[
         standard_component_specs.EXAMPLES_KEY]
     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      examples_channel.type_name)

     # The range config is persisted as JSON; parse it back and compare.
     serialized_range = component_under_test.exec_properties[
         standard_component_specs.RANGE_CONFIG_KEY]
     round_tripped = range_config_pb2.RangeConfig()
     proto_utils.json_to_proto(serialized_range, round_tripped)
     self.assertEqual(expected_range, round_tripped)
示例#2
0
    def _extract_conn_config(self, custom_config):
        """Decodes a JSON CustomConfig and unpacks its Presto connection config.

        Args:
          custom_config: JSON string of an `example_gen_pb2.CustomConfig`
            whose `custom_config` Any field holds a `PrestoConnConfig`.

        Returns:
          The unpacked `presto_config_pb2.PrestoConnConfig`.
        """
        wrapper = example_gen_pb2.CustomConfig()
        proto_utils.json_to_proto(custom_config, wrapper)
        presto_conn = presto_config_pb2.PrestoConnConfig()
        wrapper.custom_config.Unpack(presto_conn)
        return presto_conn
示例#3
0
def ResolveSplitsConfig(
        splits_config_str: Optional[str],
        examples: List[types.Artifact]) -> transform_pb2.SplitsConfig:
    """Resolve SplitsConfig proto for the transform request.

    If `splits_config_str` is set, it is parsed and returned after checking
    that it names at least one split to analyze. Otherwise a default config
    is built: analyze the 'train' split and transform every split present in
    the input artifacts.

    Args:
      splits_config_str: Optional JSON string of a `transform_pb2.SplitsConfig`.
      examples: Input Examples artifacts. Only consulted when
        `splits_config_str` is empty. All of them must share the same set of
        split names.

    Returns:
      The resolved `transform_pb2.SplitsConfig`.

    Raises:
      ValueError: If the given config has an empty `analyze` list. Also raised
        when a default config is needed but `examples` is empty, or when the
        input artifacts do not all have the same split names.
    """
    result = transform_pb2.SplitsConfig()
    if splits_config_str:
        proto_utils.json_to_proto(splits_config_str, result)
        if not result.analyze:
            raise ValueError(
                'analyze cannot be empty when splits_config is set.')
        return result

    if not examples:
        raise ValueError(
            'examples cannot be empty when splits_config is not set.')

    result.analyze.append('train')

    # Keep the first artifact's split order (de-duplicated) so the emitted
    # config is deterministic. Iterating a set of str follows hash order,
    # which varies between interpreter runs under hash randomization.
    ordered_split_names = list(dict.fromkeys(
        artifact_utils.decode_split_names(examples[0].split_names)))
    split_names = set(ordered_split_names)

    # All input artifacts should have the same set of split names.
    for artifact in examples[1:]:
        artifact_split_names = set(
            artifact_utils.decode_split_names(artifact.split_names))
        if split_names != artifact_split_names:
            raise ValueError(
                'Not all input artifacts have the same split names: (%s, %s)' %
                (split_names, artifact_split_names))

    result.transform.extend(ordered_split_names)
    logging.info("Analyze the 'train' split and transform all splits when "
                 'splits_config is not set.')
    return result
示例#4
0
    def testQueryBasedDriver(self):
        # Create exec properties: two query splits whose patterns contain a
        # {SPAN} placeholder, plus a static range selecting span 2.
        query_input = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(
                name='s1',
                pattern="select * from table where span={SPAN} and split='s1'"),
            example_gen_pb2.Input.Split(
                name='s2',
                pattern="select * from table where span={SPAN} and split='s2'"),
        ])
        static_range = range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(
                start_span_number=2, end_span_number=2))
        exec_properties = {
            standard_component_specs.INPUT_CONFIG_KEY:
                proto_utils.proto_to_json(query_input),
            standard_component_specs.RANGE_CONFIG_KEY:
                proto_utils.proto_to_json(static_range),
        }

        # A single output Examples artifact with a fixed URI.
        output_examples = standard_artifacts.Examples()
        output_examples.uri = 'my_uri'
        output_dict = {standard_component_specs.EXAMPLES_KEY: [output_examples]}

        driver_under_test = driver.QueryBasedDriver(self._mock_metadata)
        execution_info = portable_data_types.ExecutionInfo(
            output_dict=output_dict, exec_properties=exec_properties)
        run_result = driver_under_test.run(execution_info)

        # The driver writes the resolved span/version/fingerprint back into
        # exec_properties.
        self.assertEqual(exec_properties[utils.SPAN_PROPERTY_NAME], 2)
        self.assertIsNone(exec_properties[utils.VERSION_PROPERTY_NAME])
        self.assertIsNone(exec_properties[utils.FINGERPRINT_PROPERTY_NAME])

        # The {SPAN} placeholders must have been substituted.
        resolved_input = example_gen_pb2.Input()
        proto_utils.json_to_proto(
            exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
            resolved_input)
        self.assertProtoEquals(
            """
        splits {
          name: "s1"
          pattern: "select * from table where span=2 and split='s1'"
        }
        splits {
          name: "s2"
          pattern: "select * from table where span=2 and split='s2'"
        }""", resolved_input)

        emitted = run_result.output_artifacts[
            standard_component_specs.EXAMPLES_KEY].artifacts
        self.assertLen(emitted, 1)
        self.assertEqual(emitted[0].uri, output_examples.uri)
        self.assertEqual(
            emitted[0].custom_properties[
                utils.SPAN_PROPERTY_NAME].string_value, '2')
示例#5
0
文件: executor.py 项目: jay90099/tfx
def get_tune_args(
        exec_properties: Dict[str, Any]) -> Optional[tuner_pb2.TuneArgs]:
    """Parses the optional TuneArgs execution property.

    Args:
      exec_properties: Execution properties. The entry under
        `standard_component_specs.TUNE_ARGS_KEY`, when present, is a JSON
        string of `tuner_pb2.TuneArgs`.

    Returns:
      The parsed `tuner_pb2.TuneArgs`, or None when the entry is missing or
      falsy.
    """
    serialized = exec_properties.get(standard_component_specs.TUNE_ARGS_KEY)
    if serialized:
        parsed = tuner_pb2.TuneArgs()
        proto_utils.json_to_proto(serialized, parsed)
        return parsed
    return None
示例#6
0
 def _get_inference_spec(
     self, model_path: Text,
     exec_properties: Dict[Text, Any]) -> model_spec_pb2.InferenceSpecType:
   """Builds an InferenceSpecType pointing at a SavedModel.

   Args:
     model_path: Path of the exported SavedModel.
     exec_properties: Execution properties; 'model_spec' holds a JSON string
       of `bulk_inferrer_pb2.ModelSpec`.

   Returns:
     An `InferenceSpecType` whose `saved_model_spec` carries `model_path` and
     the tag / signature name taken from the ModelSpec.
   """
   parsed_spec = bulk_inferrer_pb2.ModelSpec()
   proto_utils.json_to_proto(exec_properties['model_spec'], parsed_spec)
   inference_spec = model_spec_pb2.InferenceSpecType()
   inference_spec.saved_model_spec.CopyFrom(
       model_spec_pb2.SavedModelSpec(
           model_path=model_path,
           tag=parsed_spec.tag,
           signature_name=parsed_spec.model_signature_name))
   return inference_spec
示例#7
0
 def testConstructWithStaticRangeConfig(self):
     # A static range pinned to a single span (span 1).
     expected = range_config_pb2.RangeConfig(
         static_range=range_config_pb2.StaticRange(
             start_span_number=1, end_span_number=1))
     custom_spec = executor_spec.ExecutorClassSpec(TestExampleGenExecutor)
     component_under_test = component.FileBasedExampleGen(
         input_base='path',
         range_config=expected,
         custom_executor_spec=custom_spec)

     # The component stores the range config as JSON; parse it back.
     actual = range_config_pb2.RangeConfig()
     proto_utils.json_to_proto(
         component_under_test.exec_properties['range_config'], actual)
     self.assertEqual(expected, actual)
示例#8
0
  def testConstructWithCustomConfig(self):
    # An empty Any payload is enough to exercise CustomConfig round-tripping.
    expected = example_gen_pb2.CustomConfig(custom_config=any_pb2.Any())
    beam_spec = executor_spec.BeamExecutorSpec(TestExampleGenExecutor)
    component_under_test = component.FileBasedExampleGen(
        input_base='path',
        custom_config=expected,
        custom_executor_spec=beam_spec)

    # The stored JSON must decode back to an equal proto.
    serialized = component_under_test.exec_properties[
        standard_component_specs.CUSTOM_CONFIG_KEY]
    actual = example_gen_pb2.CustomConfig()
    proto_utils.json_to_proto(serialized, actual)
    self.assertEqual(expected, actual)
示例#9
0
 def _get_inference_spec(
     self, model_path: str,
     exec_properties: Dict[str, Any]) -> model_spec_pb2.InferenceSpecType:
   """Returns an InferenceSpecType for the SavedModel at `model_path`.

   Args:
     model_path: Path of the exported SavedModel.
     exec_properties: Execution properties; the entry under
       `standard_component_specs.MODEL_SPEC_KEY` holds a JSON string of
       `bulk_inferrer_pb2.ModelSpec`.

   Returns:
     An `InferenceSpecType` whose `saved_model_spec` uses `model_path` plus
     the ModelSpec's tag and model signature name.
   """
   model_spec = bulk_inferrer_pb2.ModelSpec()
   serialized = exec_properties[standard_component_specs.MODEL_SPEC_KEY]
   proto_utils.json_to_proto(serialized, model_spec)
   result = model_spec_pb2.InferenceSpecType()
   saved_model = result.saved_model_spec
   saved_model.CopyFrom(
       model_spec_pb2.SavedModelSpec(
           model_path=model_path,
           tag=model_spec.tag,
           signature_name=model_spec.model_signature_name))
   return result
示例#10
0
    def testConstructWithInputConfig(self):
        # Three input splits, each matched by its own glob pattern.
        split_specs = (('train', 'train/*'), ('eval', 'eval/*'),
                       ('test', 'test/*'))
        input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name=split_name, pattern=glob)
            for split_name, glob in split_specs
        ])
        example_gen = TestFileBasedExampleGenComponent(
            input_base='path', input_config=input_config)
        examples_channel = example_gen.outputs['examples']
        self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                         examples_channel.type_name)

        # input_config is persisted as JSON; round-trip it back to a proto.
        round_tripped = example_gen_pb2.Input()
        proto_utils.json_to_proto(example_gen.exec_properties['input_config'],
                                  round_tripped)
        self.assertEqual(input_config, round_tripped)
示例#11
0
    def _dict_to_object(self, dict_data: Dict[Text, Any]) -> Any:
        """Converts a dictionary to an object.

        Dictionaries without the `_TFX_OBJECT_TYPE_KEY` marker are returned
        unchanged. Otherwise the marker is popped from `dict_data`, and for
        class-based types so are the module and class keys, so the dict is
        mutated in place. The rest is decoded according to the marker.

        Args:
          dict_data: A decoded JSON object.

        Returns:
          One of: `dict_data` itself, a `Jsonable` instance built with
          `from_json_dict`, a class object, or a proto message parsed from
          the stored JSON value.

        Raises:
          ValueError: If the referenced class is not a subclass of the
            expected base, if a proto value is missing, or if the object type
            marker is not recognized.
        """
        if _TFX_OBJECT_TYPE_KEY not in dict_data:
            return dict_data

        object_type = dict_data.pop(_TFX_OBJECT_TYPE_KEY)

        def _extract_class(d):
            # Pops the module/class keys and resolves the named attribute.
            # NOTE: this imports whatever module the JSON names, so only
            # decode trusted JSON with this decoder.
            module_name = d.pop(_MODULE_KEY)
            class_name = d.pop(_CLASS_KEY)
            return getattr(importlib.import_module(module_name), class_name)

        if object_type == _ObjectType.JSONABLE:
            jsonable_class_type = _extract_class(dict_data)
            if not issubclass(jsonable_class_type, Jsonable):
                raise ValueError('Class %s must be a subclass of Jsonable' %
                                 jsonable_class_type)
            return jsonable_class_type.from_json_dict(dict_data)

        if object_type == _ObjectType.CLASS:
            return _extract_class(dict_data)

        if object_type == _ObjectType.PROTO:
            proto_class_type = _extract_class(dict_data)
            if not issubclass(proto_class_type, message.Message):
                raise ValueError(
                    'Class %s must be a subclass of proto.Message' %
                    proto_class_type)
            if _PROTO_VALUE_KEY not in dict_data:
                raise ValueError('Missing proto value in json dict')
            return proto_utils.json_to_proto(dict_data[_PROTO_VALUE_KEY],
                                             proto_class_type())

        # Previously an unrecognized marker fell through and silently decoded
        # to None; fail loudly instead.
        raise ValueError('Unknown object type %r in json dict' %
                         (object_type,))
示例#12
0
    def resolve_exec_properties(
        self,
        exec_properties: Dict[Text, Any],
        pipeline_info: data_types.PipelineInfo,
        component_info: data_types.ComponentInfo,
    ) -> Dict[Text, Any]:
        """Overrides BaseDriver.resolve_exec_properties().

        Resolves the span, version and fingerprint of the input data. These
        values, together with the span-resolved input config, are written
        back into `exec_properties`.

        Args:
          exec_properties: Execution properties. The input config and input
            base entries are required; the range config entry is optional.
          pipeline_info: Unused.
          component_info: Unused.

        Returns:
          The same `exec_properties` dict, updated in place.

        Raises:
          ValueError: If a `static_range` has differing start and end span
            numbers.
        """
        del pipeline_info, component_info

        input_config = example_gen_pb2.Input()
        proto_utils.json_to_proto(
            exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
            input_config)

        input_base = exec_properties[standard_component_specs.INPUT_BASE_KEY]
        logging.debug('Processing input %s.', input_base)

        serialized_range = exec_properties.get(
            standard_component_specs.RANGE_CONFIG_KEY)
        range_config = None
        if serialized_range:
            range_config = range_config_pb2.RangeConfig()
            proto_utils.json_to_proto(serialized_range, range_config)
            # ExampleGen processes exactly one span per run, so a StaticRange
            # has to name a single span.
            if range_config.HasField('static_range'):
                static = range_config.static_range
                if static.start_span_number != static.end_span_number:
                    raise ValueError(
                        'Start and end span numbers for RangeConfig.static_range must '
                        'be equal: (%s, %s)' %
                        (static.start_span_number, static.end_span_number))

        # This call also rewrites input_config.splits[*].pattern in place.
        fingerprint, span, version = (
            utils.calculate_splits_fingerprint_span_and_version(
                input_base, input_config.splits, range_config))

        exec_properties.update({
            standard_component_specs.INPUT_CONFIG_KEY:
                proto_utils.proto_to_json(input_config),
            utils.SPAN_PROPERTY_NAME: span,
            utils.VERSION_PROPERTY_NAME: version,
            utils.FINGERPRINT_PROPERTY_NAME: fingerprint,
        })
        return exec_properties
示例#13
0
  def _parse_parameters(self, raw_args: Mapping[str, Any]) -> None:
    """Parse arguments to ComponentSpec.

    Validates `raw_args` against this spec's PARAMETERS, INPUTS and OUTPUTS
    declarations. The values are then distributed into
    `self.exec_properties`, `self.inputs` and `self.outputs`.

    Proto-typed PARAMETERS values are normalized:
      - With `use_proto`: dict and JSON-string values are parsed into an
        instance of the declared message type.
      - Without `use_proto`: dict values are dumped with
        `json_utils.dumps`, and non-string values go through
        `proto_utils.proto_to_json`.
    Falsy values and values for which `_is_runtime_param` is true are stored
    as given.

    Args:
      raw_args: Mapping from argument name to the value supplied for it.

    Raises:
      ValueError: If a non-optional PARAMETERS, INPUTS or OUTPUTS entry is
        missing from `raw_args`. `arg.type_check` presumably raises on a type
        mismatch as well (defined elsewhere; confirm its exception type).
    """
    # Argument names not yet matched against a declaration.
    # NOTE(review): `unparsed_args` is not inspected after the loop below in
    # this method, so unrecognized argument names are not rejected here.
    unparsed_args = set(raw_args.keys())
    inputs = {}
    outputs = {}
    self.exec_properties = {}

    # First, check that the arguments are set.
    for arg_name, arg in itertools.chain(self.PARAMETERS.items(),
                                         self.INPUTS.items(),
                                         self.OUTPUTS.items()):
      if arg_name not in unparsed_args:
        if arg.optional:
          continue
        else:
          raise ValueError('Missing argument %r to %s.' %
                           (arg_name, self.__class__))
      unparsed_args.remove(arg_name)

      # Type check the argument. An explicit None for an optional argument is
      # accepted without type checking.
      value = raw_args[arg_name]
      if arg.optional and value is None:
        continue
      arg.type_check(arg_name, value)

    # Populate the appropriate dictionary for each parameter type.
    # PARAMETERS go to exec_properties; optional ones that were not supplied
    # are omitted entirely rather than stored as None.
    for arg_name, arg in self.PARAMETERS.items():
      if arg.optional and arg_name not in raw_args:
        continue
      value = raw_args[arg_name]

      # Only proto-typed, truthy, non-runtime-parameter values are converted.
      if (inspect.isclass(arg.type) and
          issubclass(arg.type, message.Message) and value and
          not _is_runtime_param(value)):
        if arg.use_proto:
          # Convert dict / JSON-string values into the declared message type;
          # any other value is kept as given.
          if isinstance(value, dict):
            value = proto_utils.dict_to_proto(value, arg.type())
          elif isinstance(value, str):
            value = proto_utils.json_to_proto(value, arg.type())
        else:
          # Create deterministic json string as it will be stored in metadata
          # for cache check.
          if isinstance(value, dict):
            value = json_utils.dumps(value)
          elif not isinstance(value, str):
            value = proto_utils.proto_to_json(value)

      self.exec_properties[arg_name] = value

    # INPUTS and OUTPUTS are copied as-is; optional entries that are absent
    # or falsy are skipped.
    for arg_dict, param_dict in ((self.INPUTS, inputs), (self.OUTPUTS,
                                                         outputs)):
      for arg_name, arg in arg_dict.items():
        if arg.optional and not raw_args.get(arg_name):
          continue
        value = raw_args[arg_name]
        param_dict[arg_name] = value

    self.inputs = inputs
    self.outputs = outputs
示例#14
0
    def testConstructWithOutputConfig(self):
        # Hash-bucket split of 2:1:1 into train/eval/test.
        bucket_counts = (('train', 2), ('eval', 1), ('test', 1))
        output_config = example_gen_pb2.Output(
            split_config=example_gen_pb2.SplitConfig(splits=[
                example_gen_pb2.SplitConfig.Split(name=split_name,
                                                  hash_buckets=buckets)
                for split_name, buckets in bucket_counts
            ]))
        example_gen = TestFileBasedExampleGenComponent(
            input_base='path', output_config=output_config)
        examples_channel = example_gen.outputs['examples']
        self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                         examples_channel.type_name)

        # output_config is persisted as JSON; round-trip it back to a proto.
        round_tripped = example_gen_pb2.Output()
        proto_utils.json_to_proto(example_gen.exec_properties['output_config'],
                                  round_tripped)
        self.assertEqual(output_config, round_tripped)
示例#15
0
    def resolve_exec_properties(
        self,
        exec_properties: Dict[Text, Any],
        pipeline_info: data_types.PipelineInfo,
        component_info: data_types.ComponentInfo,
    ) -> Dict[Text, Any]:
        """Overrides BaseDriver.resolve_exec_properties().

        Uses the input processor from `self.get_input_processor` to resolve
        the span, version and fingerprint, and substitutes them into each
        split pattern. The results are written back into `exec_properties`.

        Args:
          exec_properties: Execution properties. The input config entry is
            required; the input base and range config entries are optional.
          pipeline_info: Unused.
          component_info: Unused.

        Returns:
          The same `exec_properties` dict, updated in place.
        """
        del pipeline_info, component_info

        input_config = example_gen_pb2.Input()
        proto_utils.json_to_proto(
            exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
            input_config)

        input_base = exec_properties.get(
            standard_component_specs.INPUT_BASE_KEY)
        logging.debug('Processing input %s.', input_base)

        serialized_range = exec_properties.get(
            standard_component_specs.RANGE_CONFIG_KEY)
        if serialized_range:
            range_config = range_config_pb2.RangeConfig()
            proto_utils.json_to_proto(serialized_range, range_config)
        else:
            range_config = None

        processor = self.get_input_processor(
            splits=input_config.splits,
            range_config=range_config,
            input_base_uri=input_base)
        span, version = processor.resolve_span_and_version()
        fingerprint = processor.get_input_fingerprint(span, version)

        # Substitute the resolved span/version into every split pattern.
        for split in input_config.splits:
            split.pattern = processor.get_pattern_for_span_version(
                split.pattern, span, version)

        exec_properties.update({
            standard_component_specs.INPUT_CONFIG_KEY:
                proto_utils.proto_to_json(input_config),
            utils.SPAN_PROPERTY_NAME: span,
            utils.VERSION_PROPERTY_NAME: version,
            utils.FINGERPRINT_PROPERTY_NAME: fingerprint,
        })
        return exec_properties
示例#16
0
 def testConstructWithRangeConfig(self):
     # A static range pinned to span 2.
     expected = range_config_pb2.RangeConfig(
         static_range=range_config_pb2.StaticRange(
             start_span_number=2, end_span_number=2))
     # @span_yyyymmdd_utc will be replaced with '19700103' in the query, and
     # span `2` will be recorded in the output Examples artifact.
     bq_example_gen = component.BigQueryExampleGen(
         query='select * from table where date=@span_yyyymmdd_utc',
         range_config=expected)

     examples_channel = bq_example_gen.outputs[
         standard_component_specs.EXAMPLES_KEY]
     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      examples_channel.type_name)

     # The range config is persisted as JSON; parse it back and compare.
     actual = range_config_pb2.RangeConfig()
     proto_utils.json_to_proto(
         bq_example_gen.exec_properties[
             standard_component_specs.RANGE_CONFIG_KEY], actual)
     self.assertEqual(expected, actual)
示例#17
0
文件: executor.py 项目: sycdesign/tfx
  def Do(self, input_dict: Dict[Text, List[types.Artifact]],
         output_dict: Dict[Text, List[types.Artifact]],
         exec_properties: Dict[Text, Any]) -> None:
    """Contract for running InfraValidator Executor.

    Parses the serving/validation/request specs and fills in defaults for
    unset fields, then delegates to `self._Do` under the graceful-shutdown
    handler.

    Args:
      input_dict:
        - `model`: Single `Model` artifact that we're validating.
        - `examples`: `Examples` artifacts to be used for test requests.
      output_dict:
        - `blessing`: Single `InfraBlessing` artifact containing the validated
          result. It is an empty file with the name either of INFRA_BLESSED or
          INFRA_NOT_BLESSED.
      exec_properties:
        - `serving_spec`: Serialized `ServingSpec` configuration.
        - `validation_spec`: Serialized `ValidationSpec` configuration.
        - `request_spec`: Serialized `RequestSpec` configuration.
    """
    self._log_startup(input_dict, output_dict, exec_properties)

    model = artifact_utils.get_single_instance(input_dict[MODEL_KEY])
    blessing = artifact_utils.get_single_instance(output_dict[BLESSING_KEY])
    examples = None
    if input_dict.get(EXAMPLES_KEY):
      examples = artifact_utils.get_single_instance(input_dict[EXAMPLES_KEY])

    # ServingSpec is required; fall back to a default model name.
    serving_spec = infra_validator_pb2.ServingSpec()
    proto_utils.json_to_proto(exec_properties[SERVING_SPEC_KEY], serving_spec)
    if not serving_spec.model_name:
      serving_spec.model_name = _DEFAULT_MODEL_NAME

    # ValidationSpec is optional; unset (zero) limits get defaults.
    validation_spec = infra_validator_pb2.ValidationSpec()
    validation_json = exec_properties.get(VALIDATION_SPEC_KEY)
    if validation_json:
      proto_utils.json_to_proto(validation_json, validation_spec)
    if not validation_spec.num_tries:
      validation_spec.num_tries = _DEFAULT_NUM_TRIES
    if not validation_spec.max_loading_time_seconds:
      validation_spec.max_loading_time_seconds = _DEFAULT_MAX_LOADING_TIME_SEC

    # RequestSpec is optional and stays None when not provided.
    request_spec = None
    request_json = exec_properties.get(REQUEST_SPEC_KEY)
    if request_json:
      request_spec = infra_validator_pb2.RequestSpec()
      proto_utils.json_to_proto(request_json, request_spec)

    with self._InstallGracefulShutdownHandler():
      self._Do(model=model,
               examples=examples,
               blessing=blessing,
               serving_spec=serving_spec,
               validation_spec=validation_spec,
               request_spec=request_spec)
示例#18
0
文件: executor.py 项目: jay90099/tfx
def _PrestoToExample(  # pylint: disable=invalid-name
    pipeline: beam.Pipeline, exec_properties: Dict[str, Any],
    split_pattern: str) -> beam.pvalue.PCollection:
  """Read from Presto and transform to TF examples.

  The Presto connection settings come from the `PrestoConnConfig` packed
  inside the 'custom_config' execution property.

  Args:
    pipeline: beam pipeline.
    exec_properties: A dict of execution properties.
    split_pattern: Split.pattern in Input config, a Presto sql string.

  Returns:
    PCollection of TF examples.
  """
  custom_config = example_gen_pb2.CustomConfig()
  proto_utils.json_to_proto(exec_properties['custom_config'], custom_config)
  presto_conn_config = presto_config_pb2.PrestoConnConfig()
  custom_config.custom_config.Unpack(presto_conn_config)
  presto_client = _deserialize_conn_config(presto_conn_config)

  queries = pipeline | 'Query' >> beam.Create([split_pattern])
  rows = queries | 'QueryTable' >> beam.ParDo(_ReadPrestoDoFn(presto_client))
  return rows | 'ToTFExample' >> beam.Map(_row_to_example)
示例#19
0
def extract_properties(
    execution: metadata_store_pb2.Execution
) -> Dict[str, types.ExecPropertyTypes]:
    """Extracts execution properties from mlmd Execution.

    Both `properties` and `custom_properties` are scanned, and schema entries
    themselves are skipped. For every other key, the matching schema stored
    in `custom_properties` (if any) is parsed and used to decode the value.

    Raises:
      ValueError: If a property decodes to None.
    """
    custom = execution.custom_properties

    def _schema_for(name):
        # Parsed Value.Schema stored alongside `name`, or None if absent.
        schema_name = execution_lib.get_schema_key(name)
        if schema_name not in custom:
            return None
        return proto_utils.json_to_proto(
            data_types_utils.get_metadata_value(custom[schema_name]),
            pipeline_pb2.Value.Schema())

    extracted = {}
    for name, raw_value in itertools.chain(execution.properties.items(),
                                           custom.items()):
        if execution_lib.is_schema_key(name):
            continue
        parsed = data_types_utils.get_parsed_value(raw_value,
                                                   _schema_for(name))
        if parsed is None:
            raise ValueError(
                f'Unexpected property with empty value; key: {name}')
        extracted[name] = parsed
    return extracted
示例#20
0
def _run_driver(exec_properties: Dict[str, Any],
                outputs_dict: Dict[str, List[artifact.Artifact]],
                output_metadata_uri: str,
                name_from_id: Optional[Dict[int, str]] = None) -> None:
    """Runs the driver, writing its output as an ExecutorOutput proto.

    The main goal of this driver is to calculate the span and fingerprint of
    the input data. This allows the executor invocation to be skipped if the
    ExampleGen component has previously run on the same data with the same
    configuration.

    The span, fingerprint and version are stored in `exec_properties`
    (mutated in place), which is then passed to
    `driver.update_output_artifact` together with the Examples artifact.

    The span, fingerprint and resolved input config are also written as
    parameters of an ExecutorOutput proto, along with the Examples artifact.
    That proto is serialized as JSON to `output_metadata_uri`. The CAIP
    pipelines system reads this file and updates MLMD with the new execution
    properties.

    Range config is not supported yet (see the TODO in the body).

    Args:
      exec_properties: Execution properties. These must contain:
        - `utils.INPUT_BASE_KEY`: A path from which files will be read and
          their span/fingerprint calculated.
        - `utils.INPUT_CONFIG_KEY`: A json-serialized
          `example_gen_pb2.Input` proto message. See
          https://www.tensorflow.org/tfx/guide/examplegen for more details.
        No other entry (e.g. an output config) is read directly here, though
        the whole dict is passed on to `driver.update_output_artifact`.
      outputs_dict: The mapping of the output artifacts. The entry under
        `utils.EXAMPLES_KEY` is unwrapped with
        `artifact_utils.get_single_instance`.
      output_metadata_uri: A path at which an ExecutorOutput message will be
        written with updated execution properties and output artifacts. The
        CAIP Pipelines service will update the task's properties and
        artifacts prior to running the executor.
      name_from_id: Optional. Mapping from the converted int-typed id to
        str-typed runtime artifact name, which should be unique.

    Raises:
      ValueError: If `outputs_dict` has no `utils.EXAMPLES_KEY` entry.
    """
    if name_from_id is None:
        name_from_id = {}

    logging.set_verbosity(logging.INFO)
    logging.info('exec_properties = %s\noutput_metadata_uri = %s',
                 exec_properties, output_metadata_uri)

    input_base_uri = exec_properties[utils.INPUT_BASE_KEY]

    input_config = example_gen_pb2.Input()
    proto_utils.json_to_proto(exec_properties[utils.INPUT_CONFIG_KEY],
                              input_config)

    # TODO(b/161734559): Support range config.
    # NOTE(review): this helper presumably also rewrites
    # input_config.splits[*].pattern for the resolved span; the resulting
    # input_config is emitted in the ExecutorOutput below. Confirm.
    fingerprint, select_span, version = utils.calculate_splits_fingerprint_span_and_version(
        input_base_uri, input_config.splits)
    logging.info('Calculated span: %s', select_span)
    logging.info('Calculated fingerprint: %s', fingerprint)

    exec_properties[utils.SPAN_PROPERTY_NAME] = select_span
    exec_properties[utils.FINGERPRINT_PROPERTY_NAME] = fingerprint
    exec_properties[utils.VERSION_PROPERTY_NAME] = version

    if utils.EXAMPLES_KEY not in outputs_dict:
        raise ValueError(
            'Example artifact was missing in the ExampleGen outputs.')
    example_artifact = artifact_utils.get_single_instance(
        outputs_dict[utils.EXAMPLES_KEY])

    # Presumably copies the resolved properties onto the artifact's MLMD
    # record; confirm in driver.update_output_artifact.
    driver.update_output_artifact(
        exec_properties=exec_properties,
        output_artifact=example_artifact.mlmd_artifact)

    # Build the ExecutorOutput metadata that is written to output_metadata_uri.
    # NOTE(review): `version` is stored in exec_properties above but is not
    # emitted as an ExecutorOutput parameter here.
    output_metadata = pipeline_pb2.ExecutorOutput()
    output_metadata.parameters[
        utils.FINGERPRINT_PROPERTY_NAME].string_value = fingerprint
    output_metadata.parameters[utils.SPAN_PROPERTY_NAME].string_value = str(
        select_span)
    output_metadata.parameters[
        utils.INPUT_CONFIG_KEY].string_value = json_format.MessageToJson(
            input_config)
    output_metadata.artifacts[utils.EXAMPLES_KEY].artifacts.add().CopyFrom(
        kubeflow_v2_entrypoint_utils.to_runtime_artifact(
            example_artifact, name_from_id))

    fileio.makedirs(os.path.dirname(output_metadata_uri))
    # NOTE(review): MessageToJson returns a str, but the file is opened in
    # 'wb' mode; confirm the active fileio backend accepts str writes in
    # binary mode.
    with fileio.open(output_metadata_uri, 'wb') as f:
        f.write(json_format.MessageToJson(output_metadata, sort_keys=True))
示例#21
0
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]) -> None:
        """Runs batch inference on a given model with given input examples.

        Args:
          input_dict: Input dict from input key to a list of Artifacts.
            - examples: examples for inference.
            - model: exported model.
            - model_blessing: model blessing result, optional.
          output_dict: Output dict from output key to a list of Artifacts.
            Both entries are optional:
            - `standard_component_specs.INFERENCE_RESULT_KEY`: bulk inference
              results.
            - `standard_component_specs.OUTPUT_EXAMPLES_KEY`: output examples.
          exec_properties: A dict of execution properties.
            - model_spec: JSON string of bulk_inferrer_pb2.ModelSpec instance
              (read by `self._get_inference_spec`).
            - data_spec: JSON string of bulk_inferrer_pb2.DataSpec instance.
            - output_example_spec: optional JSON string of
              bulk_inferrer_pb2.OutputExampleSpec instance.

        Returns:
          None. Returns early, without running inference, when a model
          blessing is provided and the model is not blessed.

        Raises:
          ValueError: If the examples or model input is missing.
        """
        self._log_startup(input_dict, output_dict, exec_properties)

        if output_dict.get(standard_component_specs.INFERENCE_RESULT_KEY):
            inference_result = artifact_utils.get_single_instance(
                output_dict[standard_component_specs.INFERENCE_RESULT_KEY])
        else:
            inference_result = None
        if output_dict.get(standard_component_specs.OUTPUT_EXAMPLES_KEY):
            output_examples = artifact_utils.get_single_instance(
                output_dict[standard_component_specs.OUTPUT_EXAMPLES_KEY])
        else:
            output_examples = None

        # Validate with the same keys used to read the inputs below, so the
        # presence check and the later lookups cannot disagree.
        if standard_component_specs.EXAMPLES_KEY not in input_dict:
            raise ValueError('\'examples\' is missing in input dict.')
        if standard_component_specs.MODEL_KEY not in input_dict:
            raise ValueError('Input models are not valid, model '
                             'need to be specified.')
        if standard_component_specs.MODEL_BLESSING_KEY in input_dict:
            model_blessing = artifact_utils.get_single_instance(
                input_dict[standard_component_specs.MODEL_BLESSING_KEY])
            if not model_utils.is_model_blessed(model_blessing):
                logging.info('Model on %s was not blessed', model_blessing.uri)
                return
        else:
            logging.info(
                'Model blessing is not provided, exported model will be '
                'used.')

        model = artifact_utils.get_single_instance(
            input_dict[standard_component_specs.MODEL_KEY])
        model_path = path_utils.serving_model_path(
            model.uri, path_utils.is_old_model_artifact(model))
        logging.info('Use exported model from %s.', model_path)

        data_spec = bulk_inferrer_pb2.DataSpec()
        proto_utils.json_to_proto(
            exec_properties[standard_component_specs.DATA_SPEC_KEY], data_spec)

        # OutputExampleSpec is optional; an empty spec is used when unset.
        output_example_spec = bulk_inferrer_pb2.OutputExampleSpec()
        if exec_properties.get(
                standard_component_specs.OUTPUT_EXAMPLE_SPEC_KEY):
            proto_utils.json_to_proto(
                exec_properties[
                    standard_component_specs.OUTPUT_EXAMPLE_SPEC_KEY],
                output_example_spec)

        self._run_model_inference(
            data_spec, output_example_spec,
            input_dict[standard_component_specs.EXAMPLES_KEY], output_examples,
            inference_result,
            self._get_inference_spec(model_path, exec_properties))
示例#22
0
  def Do(
      self,
      input_dict: Dict[Text, List[types.Artifact]],
      output_dict: Dict[Text, List[types.Artifact]],
      exec_properties: Dict[Text, Any],
  ) -> None:
    """Take input data source and generates serialized data splits.

    The output is intended to be serialized tf.train.Examples or
    tf.train.SequenceExamples protocol buffer in gzipped TFRecord format,
    but subclasses can choose to override to write to any serialized records
    payload into gzipped TFRecord as specified, so long as downstream
    component can consume it. The format of payload is added to
    `payload_format` custom property of the output Example artifact.

    Args:
      input_dict: Input dict from input key to a list of Artifacts. Depends on
        detailed example gen implementation.
      output_dict: Output dict from output key to a list of Artifacts.
        - examples: splits of serialized records.
      exec_properties: A dict of execution properties. Depends on detailed
        example gen implementation.
        - input_base: an external directory containing the data files.
        - input_config: JSON string of example_gen_pb2.Input instance,
          providing input configuration.
        - output_config: JSON string of example_gen_pb2.Output instance,
          providing output configuration.
        - output_data_format: Payload format of generated data in output
          artifact, one of example_gen_pb2.PayloadFormat enum.

    Returns:
      None
    """
    self._log_startup(input_dict, output_dict, exec_properties)

    input_config = example_gen_pb2.Input()
    proto_utils.json_to_proto(exec_properties[utils.INPUT_CONFIG_KEY],
                              input_config)
    output_config = example_gen_pb2.Output()
    proto_utils.json_to_proto(exec_properties[utils.OUTPUT_CONFIG_KEY],
                              output_config)

    # Record the split names on the (single) output Examples artifact.
    examples_list = output_dict[utils.EXAMPLES_KEY]
    examples_artifact = artifact_utils.get_single_instance(examples_list)
    examples_artifact.split_names = artifact_utils.encode_split_names(
        utils.generate_output_split_names(input_config, output_config))

    logging.info('Generating examples.')
    with self._make_beam_pipeline() as pipeline:
      splits_by_name = self.GenerateExamplesByBeam(pipeline, exec_properties)
      for name, split_pcoll in splits_by_name.items():
        split_uri = artifact_utils.get_split_uri(examples_list, name)
        # pylint: disable=expression-not-assigned, no-value-for-parameter
        split_pcoll | 'WriteSplit[{}]'.format(name) >> _WriteSplit(split_uri)
        # pylint: enable=expression-not-assigned, no-value-for-parameter

    # Tag every output artifact with the payload format, when one is given.
    payload_format = exec_properties.get(utils.OUTPUT_DATA_FORMAT_KEY)
    if payload_format:
      for artifact in examples_list:
        examples_utils.set_payload_format(artifact, payload_format)
    logging.info('Examples generated.')
示例#23
0
  def GenerateExamplesByBeam(
      self,
      pipeline: beam.Pipeline,
      exec_properties: Dict[Text, Any],
  ) -> Dict[Text, beam.pvalue.PCollection]:
    """Converts input source to serialized record splits based on configs.

    Custom ExampleGen executors should provide
    GetInputSourceToExamplePTransform for converting an input split to
    serialized records. Override this 'GenerateExamplesByBeam' method instead
    if more complex logic is needed, e.g., custom splitting logic.

    As a side effect, `exec_properties['_beam_pipeline_args']` is set to this
    executor's beam pipeline args (or an empty list), because certain
    ExampleGen executors need that information.

    Args:
      pipeline: Beam pipeline.
      exec_properties: A dict of execution properties. Depends on detailed
        example gen implementation.
        - input_base: an external directory containing the data files.
        - input_config: JSON string of example_gen_pb2.Input instance, providing
          input configuration.
        - output_config: JSON string of example_gen_pb2.Output instance,
          providing output configuration.
        - output_data_format: Payload format of generated data in output
          artifact, one of example_gen_pb2.PayloadFormat enum.

    Returns:
      Dict of beam PCollection with split name as key, each PCollection is a
      single output split that contains serialized records.

    Raises:
      AssertionError: if output splits are configured but the input config
        does not contain exactly one split.
    """
    # Get input split information.
    input_config = example_gen_pb2.Input()
    proto_utils.json_to_proto(exec_properties[utils.INPUT_CONFIG_KEY],
                              input_config)
    # Get output split information.
    output_config = example_gen_pb2.Output()
    proto_utils.json_to_proto(exec_properties[utils.OUTPUT_CONFIG_KEY],
                              output_config)
    # Get output split names.
    split_names = utils.generate_output_split_names(input_config, output_config)
    # Make beam_pipeline_args available in exec_properties since certain
    # example_gen executors need this information.
    exec_properties['_beam_pipeline_args'] = self._beam_pipeline_args or []

    example_splits = []
    input_to_record = self.GetInputSourceToExamplePTransform()
    if output_config.split_config.splits:
      # Use output splits; input must have only one split. An explicit raise
      # (instead of an `assert` statement) keeps this check active under
      # `python -O`, where all but the first input split would otherwise be
      # silently ignored. AssertionError is kept for backward compatibility.
      if len(input_config.splits) != 1:
        raise AssertionError(
            'input must have only one split when output split is specified.')
      # Cumulative hash bucket counts: buckets[i] is the sum of hash_buckets
      # of output splits 0..i.
      buckets = []
      total_buckets = 0
      for split in output_config.split_config.splits:
        total_buckets += split.hash_buckets
        buckets.append(total_buckets)
      example_splits = (
          pipeline
          | 'InputToRecord' >>
          # pylint: disable=no-value-for-parameter
          input_to_record(exec_properties, input_config.splits[0].pattern)
          | 'SplitData' >> beam.Partition(_PartitionFn, len(buckets), buckets,
                                          output_config.split_config))
    else:
      # Use input splits.
      for split in input_config.splits:
        examples = (
            pipeline
            | 'InputToRecord[{}]'.format(split.name) >>
            # pylint: disable=no-value-for-parameter
            input_to_record(exec_properties, split.pattern))
        example_splits.append(examples)

    # Pair each produced PCollection with its output split name by position.
    result = {}
    for index, example_split in enumerate(example_splits):
      result[split_names[index]] = example_split
    return result
示例#24
0
    def testResolveExecProperties(self):
        """Checks span/version resolution of FileBasedDriver.

        Covers: latest spans differing between splits (error), spans aligned
        but latest versions differing (error), selection of the latest aligned
        span/version, and static RangeConfigs that either cover more than one
        span (error) or pin a single span.
        """
        # Fingerprint expected for splits s1 and s2, each one 9-byte file.
        expected_fingerprint = (
            r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
            r'\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
        )

        def _make_input_config_json():
            # Unresolved patterns; resolve_exec_properties replaces them with
            # concrete span/version paths in self._exec_properties.
            return proto_utils.proto_to_json(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='s1',
                        pattern='span{SPAN:2}/version{VERSION:2}/split1/*'),
                    example_gen_pb2.Input.Split(
                        name='s2',
                        pattern='span{SPAN:2}/version{VERSION:2}/split2/*'),
                ]))

        def _write_data(span, version, split, contents):
            io_utils.write_string_file(
                os.path.join(self._input_base_path, span, version, split,
                             'data'), contents)

        # Create input dir.
        self._input_base_path = os.path.join(self._test_dir, 'input_base')
        fileio.makedirs(self._input_base_path)

        # Create exec properties.
        self._exec_properties = {
            standard_component_specs.INPUT_BASE_KEY: self._input_base_path,
            standard_component_specs.INPUT_CONFIG_KEY:
            _make_input_config_json(),
            standard_component_specs.RANGE_CONFIG_KEY: None,
        }

        # Test align of span number: span02 exists only for split1.
        _write_data('span01', 'version01', 'split1', 'testing11')
        _write_data('span01', 'version01', 'split2', 'testing12')
        _write_data('span02', 'version01', 'split1', 'testing21')

        # Check that error raised when span does not match.
        with self.assertRaisesRegex(
                ValueError, 'Latest span should be the same for each split'):
            self._file_based_driver.resolve_exec_properties(
                self._exec_properties, None, None)

        # Spans now align, but version02 exists only for split1.
        _write_data('span02', 'version01', 'split2', 'testing22')
        _write_data('span02', 'version02', 'split1', 'testing21')

        # Check that error raised when span matches, but version does not match.
        with self.assertRaisesRegex(
                ValueError,
                'Latest version should be the same for each split'):
            self._file_based_driver.resolve_exec_properties(
                self._exec_properties, None, None)

        _write_data('span02', 'version02', 'split2', 'testing22')

        # Test if latest span and version selected when span and version aligns
        # for each split.
        self._file_based_driver.resolve_exec_properties(
            self._exec_properties, None, None)
        self.assertEqual(self._exec_properties[utils.SPAN_PROPERTY_NAME], 2)
        self.assertEqual(self._exec_properties[utils.VERSION_PROPERTY_NAME], 2)
        self.assertRegex(
            self._exec_properties[utils.FINGERPRINT_PROPERTY_NAME],
            expected_fingerprint)
        updated_input_config = example_gen_pb2.Input()
        proto_utils.json_to_proto(
            self._exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
            updated_input_config)

        # Check if latest span is selected.
        self.assertProtoEquals(
            """
        splits {
          name: "s1"
          pattern: "span02/version02/split1/*"
        }
        splits {
          name: "s2"
          pattern: "span02/version02/split2/*"
        }""", updated_input_config)

        # Test driver behavior using RangeConfig with static range. Reset the
        # input config since the previous resolution rewrote the patterns.
        self._exec_properties[standard_component_specs.INPUT_CONFIG_KEY] = (
            _make_input_config_json())

        # A static range covering more than one span is rejected.
        self._exec_properties[standard_component_specs.RANGE_CONFIG_KEY] = (
            proto_utils.proto_to_json(
                range_config_pb2.RangeConfig(
                    static_range=range_config_pb2.StaticRange(
                        start_span_number=1, end_span_number=2))))
        with self.assertRaisesRegex(
                ValueError, 'For ExampleGen, start and end span numbers'):
            self._file_based_driver.resolve_exec_properties(
                self._exec_properties, None, None)

        # A static range pinned to span 1 resolves to span01/version01.
        self._exec_properties[standard_component_specs.RANGE_CONFIG_KEY] = (
            proto_utils.proto_to_json(
                range_config_pb2.RangeConfig(
                    static_range=range_config_pb2.StaticRange(
                        start_span_number=1, end_span_number=1))))
        self._file_based_driver.resolve_exec_properties(
            self._exec_properties, None, None)
        self.assertEqual(self._exec_properties[utils.SPAN_PROPERTY_NAME], 1)
        self.assertEqual(self._exec_properties[utils.VERSION_PROPERTY_NAME], 1)
        self.assertRegex(
            self._exec_properties[utils.FINGERPRINT_PROPERTY_NAME],
            expected_fingerprint)
        updated_input_config = example_gen_pb2.Input()
        proto_utils.json_to_proto(
            self._exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
            updated_input_config)
        # Check if correct span inside static range is selected.
        self.assertProtoEquals(
            """
        splits {
          name: "s1"
          pattern: "span01/version01/split1/*"
        }
        splits {
          name: "s2"
          pattern: "span01/version01/split2/*"
        }""", updated_input_config)
示例#25
0
def run_component(full_component_class_name: str,
                  temp_directory_path: Optional[str] = None,
                  beam_pipeline_args: Optional[List[str]] = None,
                  **arguments):
    r"""Loads a component, instantiates it with arguments and runs its executor.

    The component class is instantiated, so the component code is executed,
    not just the executor code.

    To pass artifact URI, use <input_name>_uri argument name.
    To pass artifact property, use <input_name>_<property> argument name.
    Protobuf property values can be passed as JSON-serialized protobufs.
    Output artifact URIs are taken from output_<output_name>_uri,
    <output_name>_uri or <output_name>_path (first non-empty wins). After the
    executor runs, output artifact properties are written to the files named
    by <output_name>_<property>_path arguments.

    # pylint: disable=line-too-long

    Example::

      # When run as a script:
      python3 scripts/run_component.py \
        --full-component-class-name tfx.components.StatisticsGen \
        --examples-uri gs://my_bucket/chicago_taxi_simple/CsvExamplesGen/examples/1/ \
        --examples-split-names '["train", "eval"]' \
        --output-uri gs://my_bucket/chicago_taxi_simple/StatisticsGen/output/1/

      # When run as a function:
      run_component(
        full_component_class_name='tfx.components.StatisticsGen',
        examples_uri='gs://my_bucket/chicago_taxi_simple/CsvExamplesGen/examples/1/',
        examples_split_names='["train", "eval"]',
        output_uri='gs://my_bucket/chicago_taxi_simple/StatisticsGen/output/1/',
      )

    Args:
      full_component_class_name: The component class name including module name.
      temp_directory_path: Optional. Temporary directory path for the executor.
      beam_pipeline_args: Optional. Arguments to pass to the Beam pipeline.
      **arguments: Key-value pairs with component arguments.
    """
    component_class = import_utils.import_class_by_path(
        full_component_class_name)
    spec_class = component_class.SPEC_CLASS

    component_arguments = {}

    # Execution parameters: decode proto JSON and numeric strings by type.
    for name, execution_param in spec_class.PARAMETERS.items():
        argument_value = arguments.get(name)
        if argument_value is None:
            continue
        param_type = execution_param.type
        if (isinstance(param_type, type)
                and issubclass(param_type, message.Message)):
            argument_value_obj = param_type()
            proto_utils.json_to_proto(argument_value, argument_value_obj)
        elif param_type is int:
            argument_value_obj = int(argument_value)
        elif param_type is float:
            argument_value_obj = float(argument_value)
        else:
            argument_value_obj = argument_value
        component_arguments[name] = argument_value_obj

    # Input channels: one artifact per input that was given a URI.
    for input_name, channel_param in spec_class.INPUTS.items():
        uri = (arguments.get(input_name + '_uri')
               or arguments.get(input_name + '_path'))
        if not uri:
            continue
        artifact = channel_param.type()
        artifact.uri = uri
        # Setting the artifact properties
        for property_name, property_spec in (channel_param.type.PROPERTIES
                                             or {}).items():
            property_arg_name = input_name + '_' + property_name
            if property_arg_name not in arguments:
                continue
            property_value = arguments[property_arg_name]
            if property_spec.type == PropertyType.INT:
                property_value = int(property_value)
            elif property_spec.type == PropertyType.FLOAT:
                property_value = float(property_value)
            setattr(artifact, property_name, property_value)
        component_arguments[input_name] = channel_utils.as_channel([artifact])

    component_instance = component_class(**component_arguments)

    input_dict = channel_utils.unwrap_channel_dict(component_instance.inputs)
    output_dict = channel_utils.unwrap_channel_dict(component_instance.outputs)
    exec_properties = component_instance.exec_properties

    # Generating paths for output artifacts
    for output_name, channel_param in spec_class.OUTPUTS.items():
        uri = (arguments.get('output_' + output_name + '_uri')
               or arguments.get(output_name + '_uri')
               or arguments.get(output_name + '_path'))
        if uri:
            artifacts = output_dict[output_name]
            if not artifacts:
                artifacts.append(channel_param.type())
            for artifact in artifacts:
                artifact.uri = uri

    if issubclass(component_instance.executor_spec.executor_class,
                  base_beam_executor.BaseBeamExecutor):
        executor_context = base_beam_executor.BaseBeamExecutor.Context(
            beam_pipeline_args=beam_pipeline_args,
            tmp_dir=temp_directory_path,
            unique_id='',
        )
    else:
        executor_context = base_executor.BaseExecutor.Context(
            extra_flags=beam_pipeline_args,
            tmp_dir=temp_directory_path,
            unique_id='',
        )
    executor = component_instance.executor_spec.executor_class(
        executor_context)
    executor.Do(
        input_dict=input_dict,
        output_dict=output_dict,
        exec_properties=exec_properties,
    )

    # Writing out the output artifact properties
    for output_name, channel_param in spec_class.OUTPUTS.items():
        for property_name in channel_param.type.PROPERTIES or []:
            property_path = arguments.get(
                output_name + '_' + property_name + '_path')
            if not property_path:
                continue
            # os.makedirs('') raises FileNotFoundError, so only create a parent
            # directory when the path actually has one (not for 'out.txt').
            property_dir = os.path.dirname(property_path)
            for artifact in output_dict[output_name]:
                property_value = getattr(artifact, property_name)
                if property_dir:
                    os.makedirs(property_dir, exist_ok=True)
                # Each artifact overwrites the file; with several artifacts
                # the last one's value is what remains.
                with open(property_path, 'w') as f:
                    f.write(str(property_value))
示例#26
0
文件: executor.py 项目: jay90099/tfx
    def Do(self, input_dict: Dict[str, List[types.Artifact]],
           output_dict: Dict[str, List[types.Artifact]],
           exec_properties: Dict[str, Any]) -> None:
        """Runs batch inference on a given model with given input examples.

        This function creates a new model (if necessary) and a new model
        version before inference, and cleans up resources after inference. It
        provides re-executability as it cleans up (only) the model resources
        that are created during the process, even if the inference job failed.

        Args:
          input_dict: Input dict from input key to a list of Artifacts.
            - examples: examples for inference (required).
            - model: exported model (required).
            - model_blessing: optional model blessing result; when present and
              the model is not blessed, the executor returns without running
              inference.
          output_dict: Output dict from output key to a list of Artifacts.
            - inference_result: optional bulk inference results artifact.
            - output_examples: optional output examples artifact.
          exec_properties: A dict of execution properties.
            - data_spec: JSON string of bulk_inferrer_pb2.DataSpec instance.
            - output_example_spec: optional JSON string of
              bulk_inferrer_pb2.OutputExampleSpec instance.
            - custom_config: JSON string of a dict whose
              ai_platform_serving_args entry contains the serving job
              parameters sent to Google Cloud AI Platform. For the full set of
              parameters, refer to
              https://cloud.google.com/ml-engine/reference/rest/v1/projects.models
              It may also hold an endpoint entry (constants.ENDPOINT_ARGS_KEY),
              which cannot be combined with ai_platform_serving_args.regions.

        Returns:
          None

        Raises:
          ValueError: if required inputs or exec properties are missing or
            malformed.
        """
        self._log_startup(input_dict, output_dict, exec_properties)

        if output_dict.get('inference_result'):
            inference_result = artifact_utils.get_single_instance(
                output_dict['inference_result'])
        else:
            inference_result = None
        if output_dict.get('output_examples'):
            output_examples = artifact_utils.get_single_instance(
                output_dict['output_examples'])
        else:
            output_examples = None

        if 'examples' not in input_dict:
            raise ValueError('`examples` is missing in input dict.')
        if 'model' not in input_dict:
            raise ValueError('Input models are not valid, model '
                             'need to be specified.')
        if 'model_blessing' in input_dict:
            model_blessing = artifact_utils.get_single_instance(
                input_dict['model_blessing'])
            if not model_utils.is_model_blessed(model_blessing):
                logging.info('Model on %s was not blessed', model_blessing.uri)
                return
        else:
            logging.info(
                'Model blessing is not provided, exported model will be '
                'used.')
        if _CUSTOM_CONFIG_KEY not in exec_properties:
            raise ValueError(
                'Input exec properties are not valid, {} '
                'need to be specified.'.format(_CUSTOM_CONFIG_KEY))

        custom_config = json_utils.loads(
            exec_properties.get(_CUSTOM_CONFIG_KEY, 'null'))
        if custom_config is None:
            # A JSON 'null' config carries no serving args; treat it as empty
            # so the missing-args ValueError below is raised instead of an
            # AttributeError on None.get().
            custom_config = {}
        elif not isinstance(custom_config, dict):
            raise ValueError(
                'custom_config in execution properties needs to be a '
                'dict.')
        ai_platform_serving_args = custom_config.get(SERVING_ARGS_KEY)
        if not ai_platform_serving_args:
            raise ValueError(
                '`ai_platform_serving_args` is missing in `custom_config`')
        service_name, api_version = runner.get_service_name_and_api_version(
            ai_platform_serving_args)
        executor_class_path = '%s.%s' % (self.__class__.__module__,
                                         self.__class__.__name__)
        with telemetry_utils.scoped_labels(
            {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
            job_labels = telemetry_utils.make_labels_dict()
        model = artifact_utils.get_single_instance(input_dict['model'])
        model_path = path_utils.serving_model_path(
            model.uri, path_utils.is_old_model_artifact(model))
        logging.info('Use exported model from %s.', model_path)
        # Use model artifact uri to generate model version to guarantee the
        # 1:1 mapping from model version to model.
        model_version = 'version_' + hashlib.sha256(
            model.uri.encode()).hexdigest()
        inference_spec = self._get_inference_spec(model_path, model_version,
                                                  ai_platform_serving_args)
        data_spec = bulk_inferrer_pb2.DataSpec()
        proto_utils.json_to_proto(exec_properties['data_spec'], data_spec)
        output_example_spec = bulk_inferrer_pb2.OutputExampleSpec()
        if exec_properties.get('output_example_spec'):
            proto_utils.json_to_proto(exec_properties['output_example_spec'],
                                      output_example_spec)
        endpoint = custom_config.get(constants.ENDPOINT_ARGS_KEY)
        if endpoint and 'regions' in ai_platform_serving_args:
            raise ValueError(
                '`endpoint` and `ai_platform_serving_args.regions` cannot be set simultaneously'
            )
        api = discovery.build(
            service_name,
            api_version,
            requestBuilder=telemetry_utils.TFXHttpRequest,
            client_options=client_options.ClientOptions(api_endpoint=endpoint),
        )
        new_model_endpoint_created = False
        try:
            new_model_endpoint_created = runner.create_model_for_aip_prediction_if_not_exist(
                job_labels, ai_platform_serving_args, api)
            runner.deploy_model_for_aip_prediction(
                serving_path=model_path,
                model_version_name=model_version,
                ai_platform_serving_args=ai_platform_serving_args,
                api=api,
                labels=job_labels,
                skip_model_endpoint_creation=True,
                set_default=False,
            )
            self._run_model_inference(data_spec, output_example_spec,
                                      input_dict['examples'], output_examples,
                                      inference_result, inference_spec)
        except Exception as e:
            logging.error(
                'Error in executing CloudAIBulkInferrerComponent: %s', str(e))
            raise
        finally:
            # Guarantee newly created resources are cleaned up even if the
            # inference job failed: always delete the deployed model version,
            # and delete the model endpoint only if this run created it.
            runner.delete_model_from_aip_if_exists(
                model_version_name=model_version,
                ai_platform_serving_args=ai_platform_serving_args,
                api=api,
                delete_model_endpoint=new_model_endpoint_created)
示例#27
0
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]) -> None:
        """Pushes the exported model to the configured destination if blessed.

        An unblessed model is only marked as not pushed. A blessed model is
        copied to a timestamp-versioned directory under the filesystem
        destination (skipped if that directory already exists), then archived
        into the pushed_model artifact's uri and marked as pushed.

        Args:
          input_dict: Input dict from input key to a list of artifacts:
            - model: exported model from trainer.
            - model_blessing: model blessing path from model_validator.
          output_dict: Output dict from key to a list of artifacts:
            - pushed_model: a list holding exactly one 'ModelPushPath'
              artifact; it records the model if this execution pushed it.
          exec_properties: A dict of execution properties:
            - push_destination: JSON string of a pusher_pb2.PushDestination
              instance describing where to push the model.

        Returns:
          None

        Raises:
          NotImplementedError: for a non-filesystem push destination or an
            unsupported versioning scheme.
        """
        self._log_startup(input_dict, output_dict, exec_properties)
        pushed_model_artifact = artifact_utils.get_single_instance(
            output_dict[standard_component_specs.PUSHED_MODEL_KEY])
        if not self.CheckBlessing(input_dict):
            self._MarkNotPushed(pushed_model_artifact)
            return

        exported_model = artifact_utils.get_single_instance(
            input_dict[standard_component_specs.MODEL_KEY])
        source_path = path_utils.serving_model_path(
            exported_model.uri,
            path_utils.is_old_model_artifact(exported_model))

        # The model validator may bless the same model more than once; the
        # push is still processed for metadata tracking, but an existing
        # serving directory is not copied over again.
        # TODO(jyzhao): support rpc push and verification.
        destination = pusher_pb2.PushDestination()
        proto_utils.json_to_proto(
            exec_properties[standard_component_specs.PUSH_DESTINATION_KEY],
            destination)

        kind = destination.WhichOneof('destination')
        if kind != 'filesystem':
            raise NotImplementedError(
                'Invalid push destination {}'.format(kind))

        filesystem = destination.filesystem
        # AUTO versioning is resolved to UNIX_TIMESTAMP, the only one handled.
        if filesystem.versioning == _Versioning.AUTO:
            filesystem.versioning = _Versioning.UNIX_TIMESTAMP
        if filesystem.versioning != _Versioning.UNIX_TIMESTAMP:
            raise NotImplementedError('Invalid Versioning {}'.format(
                filesystem.versioning))
        model_version = str(int(time.time()))
        logging.info('Model version: %s', model_version)
        serving_path = os.path.join(filesystem.base_directory, model_version)

        if fileio.exists(serving_path):
            logging.info(
                'Destination directory %s already exists, skipping current push.',
                serving_path)
        else:
            # tf.serving won't load partial model, it will retry until fully copied.
            io_utils.copy_dir(source_path, serving_path)
            logging.info('Model written to serving path %s.', serving_path)

        # Archive a copy of the model under the pushed_model artifact's uri.
        io_utils.copy_dir(source_path, pushed_model_artifact.uri)
        self._MarkPushed(pushed_model_artifact,
                         pushed_destination=serving_path,
                         pushed_version=model_version)
        logging.info('Model pushed to %s.', pushed_model_artifact.uri)
示例#28
0
    def testDriverRunFn(self):
        """Runs FileBasedDriver end to end and checks the resolved outputs.

        Verifies the resolved span, the rewritten input config, the fingerprint
        exec property, and that the output Examples artifact keeps its uri and
        receives span/fingerprint custom properties.
        """
        fingerprint_pattern = (
            r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
            r'\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
        )

        # Create input dir.
        self._input_base_path = os.path.join(self._test_dir, 'input_base')
        fileio.makedirs(self._input_base_path)

        # Fake previous outputs: one data file per split under span01.
        for split_dir, contents in (('split1', 'testing11'),
                                    ('split2', 'testing12')):
            io_utils.write_string_file(
                os.path.join(self._input_base_path, 'span01', split_dir,
                             'data'), contents)

        file_driver = driver.FileBasedDriver(self._mock_metadata)
        examples = standard_artifacts.Examples()
        examples.uri = 'my_uri'  # Will verify that this uri is not changed.
        output_dict = {standard_component_specs.EXAMPLES_KEY: [examples]}

        span_input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(
                name='s1', pattern='span{SPAN:2}/split1/*'),
            example_gen_pb2.Input.Split(
                name='s2', pattern='span{SPAN:2}/split2/*'),
        ])
        run_result = file_driver.run(
            portable_data_types.ExecutionInfo(
                output_dict=output_dict,
                exec_properties={
                    standard_component_specs.INPUT_BASE_KEY:
                    self._input_base_path,
                    standard_component_specs.INPUT_CONFIG_KEY:
                    proto_utils.proto_to_json(span_input_config),
                }))

        # Resolved exec properties.
        resolved = run_result.exec_properties
        self.assertEqual(resolved[utils.SPAN_PROPERTY_NAME].int_value, 1)
        resolved_input_config = example_gen_pb2.Input()
        proto_utils.json_to_proto(
            resolved[standard_component_specs.INPUT_CONFIG_KEY].string_value,
            resolved_input_config)
        self.assertProtoEquals(
            """
        splits {
          name: "s1"
          pattern: "span01/split1/*"
        }
        splits {
          name: "s2"
          pattern: "span01/split2/*"
        }""", resolved_input_config)
        self.assertRegex(
            resolved[utils.FINGERPRINT_PROPERTY_NAME].string_value,
            fingerprint_pattern)

        # Resolved output artifacts.
        produced = run_result.output_artifacts[
            standard_component_specs.EXAMPLES_KEY].artifacts
        self.assertLen(produced, 1)
        produced_example = produced[0]
        self.assertEqual(produced_example.uri, examples.uri)
        properties = produced_example.custom_properties
        self.assertEqual(properties[utils.SPAN_PROPERTY_NAME].string_value, '1')
        self.assertRegex(
            properties[utils.FINGERPRINT_PROPERTY_NAME].string_value,
            fingerprint_pattern)
示例#29
0
def _split_file_patterns(examples: List[types.Artifact], splits) -> List[Text]:
  """Return a file glob pattern for every URI of each requested split.

  Args:
    examples: Examples artifacts whose split URIs are resolved.
    splits: Iterable of split names (e.g. the repeated `splits` field of
      TrainArgs / EvalArgs).

  Returns:
    A flat list of `io_utils.all_files_pattern(uri)` results, ordered by split
    and then by the order `artifact_utils.get_split_uris` yields URIs.
  """
  patterns = []
  for split in splits:
    patterns.extend(
        io_utils.all_files_pattern(uri)
        for uri in artifact_utils.get_split_uris(examples, split))
  return patterns


def get_common_fn_args(input_dict: Dict[Text, List[types.Artifact]],
                       exec_properties: Dict[Text, Any],
                       working_dir: Text = None) -> FnArgs:
  """Get common args of training and tuning.

  Resolves the artifact URIs and execution properties shared by training and
  tuning into a single FnArgs.

  Args:
    input_dict: Input artifacts keyed by component-spec key. EXAMPLES_KEY is
      required; TRANSFORM_GRAPH_KEY and SCHEMA_KEY are optional.
    exec_properties: Execution properties. TRAIN_ARGS_KEY and EVAL_ARGS_KEY
      must hold JSON-serialized TrainArgs / EvalArgs; CUSTOM_CONFIG_KEY is an
      optional JSON string.
    working_dir: Working directory forwarded unchanged to FnArgs; may be None.

  Returns:
    FnArgs carrying the train/eval file patterns, train/eval step counts
    (None when unset or 0), schema and transform graph paths (None when the
    corresponding input is absent), a DataAccessor built from the examples,
    and the deserialized custom config.

  Raises:
    KeyError: If EXAMPLES_KEY is missing from input_dict, or TRAIN_ARGS_KEY /
      EVAL_ARGS_KEY is missing from exec_properties.
  """
  if input_dict.get(standard_component_specs.TRANSFORM_GRAPH_KEY):
    transform_graph_path = artifact_utils.get_single_uri(
        input_dict[standard_component_specs.TRANSFORM_GRAPH_KEY])
  else:
    transform_graph_path = None

  if input_dict.get(standard_component_specs.SCHEMA_KEY):
    # NOTE(review): get_only_uri_in_dir presumably expects exactly one schema
    # file inside the artifact directory -- confirm in io_utils.
    schema_path = io_utils.get_only_uri_in_dir(
        artifact_utils.get_single_uri(
            input_dict[standard_component_specs.SCHEMA_KEY]))
  else:
    schema_path = None

  train_args = trainer_pb2.TrainArgs()
  eval_args = trainer_pb2.EvalArgs()
  proto_utils.json_to_proto(
      exec_properties[standard_component_specs.TRAIN_ARGS_KEY], train_args)
  proto_utils.json_to_proto(
      exec_properties[standard_component_specs.EVAL_ARGS_KEY], eval_args)

  # Default behavior is train on `train` split (when splits is empty in train
  # args) and evaluate on `eval` split (when splits is empty in eval args).
  if not train_args.splits:
    train_args.splits.append('train')
    absl.logging.info("Train on the 'train' split when train_args.splits is "
                      'not set.')
  if not eval_args.splits:
    eval_args.splits.append('eval')
    absl.logging.info("Evaluate on the 'eval' split when eval_args.splits is "
                      'not set.')

  # Looked up once here, at the same point the original first accessed it, so
  # a missing EXAMPLES_KEY still raises only after the args are parsed.
  examples = input_dict[standard_component_specs.EXAMPLES_KEY]

  train_files = _split_file_patterns(examples, train_args.splits)
  eval_files = _split_file_patterns(examples, eval_args.splits)

  data_accessor = DataAccessor(
      tf_dataset_factory=tfxio_utils.get_tf_dataset_factory_from_artifact(
          examples, _TELEMETRY_DESCRIPTORS),
      record_batch_factory=tfxio_utils.get_record_batch_factory_from_artifact(
          examples, _TELEMETRY_DESCRIPTORS),
      data_view_decode_fn=tfxio_utils.get_data_view_decode_fn_from_artifact(
          examples, _TELEMETRY_DESCRIPTORS),
  )

  # https://github.com/tensorflow/tfx/issues/45: Replace num_steps=0 with
  # num_steps=None.  Conversion of the proto to python will set the default
  # value of an int as 0 so modify the value here.  Tensorflow will raise an
  # error if num_steps <= 0.
  train_steps = train_args.num_steps or None
  eval_steps = eval_args.num_steps or None

  # Load and deserialize custom config from execution properties.
  # Note that in the component interface the default serialization of custom
  # config is 'null' instead of '{}'. We therefore default the json_utils.loads
  # input to 'null' when the property is absent. custom_config is then
  # presumably None (as json.loads('null') would be), NOT an empty dict, so
  # consumers of FnArgs.custom_config must handle None.
  custom_config = json_utils.loads(
      exec_properties.get(standard_component_specs.CUSTOM_CONFIG_KEY, 'null'))

  return FnArgs(
      working_dir=working_dir,
      train_files=train_files,
      eval_files=eval_files,
      train_steps=train_steps,
      eval_steps=eval_steps,
      schema_path=schema_path,
      transform_graph_path=transform_graph_path,
      data_accessor=data_accessor,
      custom_config=custom_config,
  )
示例#30
0
 def test_json_to_proto(self):
     """Checks that json_to_proto ignores an unknown field and returns the message.

     The JSON input carries `obsolete_field`, which presumably is not a field
     of foo_pb2.TestProto (TODO confirm against the test .proto). The expected
     value sets only `string_value`, so the parse must ignore that key rather
     than keep it or raise.
     """
     json_str = '{"obsolete_field":2,"string_value":"x"}'
     result = proto_utils.json_to_proto(json_str, foo_pb2.TestProto())
     self.assertEqual(result, foo_pb2.TestProto(string_value='x'))
     # Make sure that returned type is not message.Message: reading a
     # TestProto-specific field on the returned value exercises it as the
     # concrete message type (presumably guarding json_to_proto's typed return).
     self.assertEqual(result.string_value, 'x')