Example #1
  def __init__(self,
               statistics: types.Channel = None,
               schema: types.Channel = None,
               exclude_splits: Optional[List[Text]] = None,
               anomalies: Optional[types.Channel] = None):
    """Construct an ExampleValidator component.

    Args:
      statistics: A Channel of type `standard_artifacts.ExampleStatistics`.
      schema: A Channel of type `standard_artifacts.Schema`. _required_
      exclude_splits: Names of splits that the example validator should not
        validate. Default behavior (when exclude_splits is set to None)
        is excluding no splits.
      anomalies: Output channel of type `standard_artifacts.ExampleAnomalies`.
    """
    if exclude_splits is None:
      exclude_splits = []
      logging.info('Excluding no splits because exclude_splits is not set.')
    if not anomalies:
      anomalies = types.Channel(type=standard_artifacts.ExampleAnomalies)
    spec = ExampleValidatorSpec(
        statistics=statistics,
        schema=schema,
        exclude_splits=json_utils.dumps(exclude_splits),
        anomalies=anomalies)
    super(ExampleValidator, self).__init__(spec=spec)
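
A minimal wiring sketch for this constructor (assuming `statistics_gen` and
`schema_gen` are upstream StatisticsGen and SchemaGen components defined
elsewhere in the pipeline):

# exclude_splits is serialized with json_utils.dumps inside __init__, as shown above.
example_validator = ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'],
    exclude_splits=['test'])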
Example #2
  def __init__(self,
               statistics: types.Channel = None,
               schema: types.Channel = None,
               exclude_splits: Optional[List[Text]] = None,
               anomalies: Optional[types.Channel] = None,
               instance_name: Optional[Text] = None):
    """Construct an ExampleValidator component.

    Args:
      statistics: A Channel of type `standard_artifacts.ExampleStatistics`.
      schema: A Channel of type `standard_artifacts.Schema`. _required_
      exclude_splits: Names of splits that the example validator should not
        validate. Default behavior (when exclude_splits is set to None)
        is excluding no splits.
      anomalies: Output channel of type `standard_artifacts.ExampleAnomalies`.
      instance_name: Optional name assigned to this specific instance of
        ExampleValidator. Required only if multiple ExampleValidator components
        are declared in the same pipeline.
    """
    if exclude_splits is None:
      exclude_splits = []
      logging.info('Excluding no splits because exclude_splits is not set.')
    if not anomalies:
      anomalies = types.Channel(type=standard_artifacts.ExampleAnomalies)
    spec = ExampleValidatorSpec(
        statistics=statistics,
        schema=schema,
        exclude_splits=json_utils.dumps(exclude_splits),
        anomalies=anomalies)
    super(ExampleValidator, self).__init__(
        spec=spec, instance_name=instance_name)
Example #3
  def __init__(self,
               examples: types.Channel,
               schema: Optional[types.Channel] = None,
               stats_options: Optional[tfdv.StatsOptions] = None,
               exclude_splits: Optional[List[Text]] = None):
    """Construct a StatisticsGen component.

    Args:
      examples: A Channel of `ExamplesPath` type, likely generated by the
        [ExampleGen component](https://www.tensorflow.org/tfx/guide/examplegen).
        This needs to contain two splits labeled `train` and `eval`. _required_
      schema: A `Schema` channel to use for automatically configuring the value
        of stats options passed to TFDV.
      stats_options: The StatsOptions instance to configure optional TFDV
        behavior. When stats_options.schema is set, it will be used instead of
        the `schema` channel input. Due to the requirement that stats_options be
        serialized, the slicer functions and custom stats generators are dropped
        and are therefore not usable.
      exclude_splits: Names of splits where statistics and sample should not
        be generated. Default behavior (when exclude_splits is set to None)
        is excluding no splits.
    """
    if exclude_splits is None:
      exclude_splits = []
      logging.info('Excluding no splits because exclude_splits is not set.')
    statistics = types.Channel(type=standard_artifacts.ExampleStatistics)
    # TODO(b/150802589): Move jsonable interface to tfx_bsl and use json_utils.
    stats_options_json = stats_options.to_json() if stats_options else None
    spec = StatisticsGenSpec(
        examples=examples,
        schema=schema,
        stats_options_json=stats_options_json,
        exclude_splits=json_utils.dumps(exclude_splits),
        statistics=statistics)
    super(StatisticsGen, self).__init__(spec=spec)
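
A minimal wiring sketch (assuming `example_gen` is an upstream ExampleGen
component defined elsewhere in the pipeline):

statistics_gen = StatisticsGen(
    examples=example_gen.outputs['examples'],
    exclude_splits=['test'])  # serialized with json_utils.dumps in __init__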
Example #4
    def __init__(
            self,
            model: types.Channel = None,
            model_blessing: Optional[types.Channel] = None,
            infra_blessing: Optional[types.Channel] = None,
            push_destination: Optional[Union[pusher_pb2.PushDestination,
                                             Dict[Text, Any]]] = None,
            custom_config: Optional[Dict[Text, Any]] = None,
            custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
            output: Optional[types.Channel] = None,
            model_export: Optional[types.Channel] = None,
            instance_name: Optional[Text] = None):
        """Construct a Pusher component.

    Args:
      model: A Channel of type `standard_artifacts.Model`, usually produced by
        a Trainer component.
      model_blessing: An optional Channel of type
        `standard_artifacts.ModelBlessing`, usually produced from an Evaluator
        component.
      infra_blessing: An optional Channel of type
        `standard_artifacts.InfraBlessing`, usually produced from an
        InfraValidator component.
      push_destination: A pusher_pb2.PushDestination instance, providing info
        for tensorflow serving to load models. Optional if executor_class
        doesn't require push_destination. If any field is provided as a
        RuntimeParameter, push_destination should be constructed as a dict with
        the same field names as PushDestination proto message.
      custom_config: A dict which contains the deployment job parameters to be
        passed to cloud-based training platforms. The [Kubeflow example](
          https://github.com/tensorflow/tfx/blob/6ff57e36a7b65818d4598d41e584a42584d361e6/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_kubeflow_gcp.py#L278-L285)
          contains an example of how this can be used by custom executors.
      custom_executor_spec: Optional custom executor spec.
      output: Optional output `standard_artifacts.PushedModel` channel with
        result of push.
      model_export: Backwards compatibility alias for the 'model' argument.
      instance_name: Optional unique instance name. Necessary if multiple Pusher
        components are declared in the same pipeline.
    """
        if model_export:
            absl.logging.warning(
                'The "model_export" argument to the Pusher component has '
                'been renamed to "model" and is deprecated. Please update your '
                'usage as support for this argument will be removed soon.')
            model = model_export
        output = output or types.Channel(type=standard_artifacts.PushedModel)
        if push_destination is None and not custom_executor_spec:
            raise ValueError(
                'push_destination is required unless a '
                'custom_executor_spec is supplied that does not require '
                'it.')
        spec = PusherSpec(model=model,
                          model_blessing=model_blessing,
                          infra_blessing=infra_blessing,
                          push_destination=push_destination,
                          custom_config=json_utils.dumps(custom_config),
                          pushed_model=output)
        super(Pusher, self).__init__(spec=spec,
                                     custom_executor_spec=custom_executor_spec,
                                     instance_name=instance_name)
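
A minimal wiring sketch (assuming `trainer` and `evaluator` components and a
`serving_model_dir` path defined elsewhere in the pipeline):

pusher = Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=serving_model_dir)))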
Example #5
  def _parse_parameter_from_component(
      self, component: base_component.BaseComponent) -> None:
    """Extract embedded RuntimeParameter placeholders from a component.

    Extract embedded RuntimeParameter placeholders from a component, then append
    the corresponding dsl.PipelineParam to KubeflowDagRunner.

    Args:
      component: a TFX component.
    """

    serialized_component = json_utils.dumps(component)
    placeholders = re.findall(data_types.RUNTIME_PARAMETER_PATTERN,
                              serialized_component)
    for placeholder in placeholders:
      placeholder = placeholder.replace('\\', '')  # Clean escapes.
      placeholder = utils.fix_brackets(placeholder)  # Fix brackets if needed.
      parameter = json_utils.loads(placeholder)
      # Skip pipeline root because it will be added later.
      if parameter.name == tfx_pipeline.ROOT_PARAMETER.name:
        continue
      if parameter.name not in self._deduped_parameter_names:
        self._deduped_parameter_names.add(parameter.name)
        # TODO(b/178436919): Create a test to cover default value rendering
        # and move the external code reference over there.
        # The default needs to be serialized then passed to dsl.PipelineParam.
        # See
        # https://github.com/kubeflow/pipelines/blob/f65391309650fdc967586529e79af178241b4c2c/sdk/python/kfp/dsl/_pipeline_param.py#L154
        dsl_parameter = dsl.PipelineParam(
            name=parameter.name, value=str(parameter.default))
        self._params.append(dsl_parameter)
Example #6
    def testDoWithMajorityVoting(self):

        exec_properties = self._exec_properties.copy()
        exec_properties['tuner_fn'] = '%s.%s' % (
            tuner_module.tuner_fn.__module__, tuner_module.tuner_fn.__name__)
        exec_properties['metalearning_algorithm'] = 'majority_voting'

        input_dict = self._input_dict.copy()

        ps_type = ps_pb2.Type(
            binary_classification=ps_pb2.BinaryClassification(label='class'))
        ps = ps_pb2.ProblemStatement(
            owner=['nitroml'],
            tasks=[ps_pb2.Task(
                name='mockdata_1',
                type=ps_type,
            )])

        exec_properties['custom_config'] = json_utils.dumps({
            'problem_statement':
            text_format.MessageToString(message=ps, as_utf8=True),
        })
        hps_artifact = artifacts.KCandidateHyperParameters()
        hps_artifact.uri = os.path.join(self._testdata_dir,
                                        'MetaLearner.majority_voting',
                                        'hparams_out')
        input_dict['warmup_hyperparameters'] = [hps_artifact]

        tuner = executor.Executor(self._context)
        tuner.Do(input_dict=input_dict,
                 output_dict=self._output_dict,
                 exec_properties=exec_properties)
        self._verify_output()
Example #7
 def get_value_and_set_type(
     value: types.ExecPropertyTypes,
     value_type: pipeline_pb2.Value.Schema.ValueType) -> types.Property:
   """Returns serialized value and sets value_type."""
   if isinstance(value, bool):
     if set_schema:
       value_type.boolean_type.SetInParent()
     return value
   elif isinstance(value, message.Message):
      # TODO(b/171794016): Investigate if a file descriptor set is needed for
      # TFX-owned protos already built into the launcher binary.
     if set_schema:
       proto_type = value_type.proto_type
       proto_type.message_type = type(value).DESCRIPTOR.full_name
       proto_utils.build_file_descriptor_set(value,
                                             proto_type.file_descriptors)
     return proto_utils.proto_to_json(value)
   elif isinstance(value, list) and len(value):
     if set_schema:
       value_type.list_type.SetInParent()
     value = [
         get_value_and_set_type(val, value_type.list_type) for val in value
     ]
     return json_utils.dumps(value)
   elif isinstance(value, (int, float, str)):
     return value
   else:
     raise ValueError('Unexpected type %s' % type(value))
Example #8
    def testDo(self):
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')

        eval_stats_artifact = standard_artifacts.ExampleStatistics()
        eval_stats_artifact.uri = os.path.join(source_data_dir,
                                               'statistics_gen')
        eval_stats_artifact.split_names = artifact_utils.encode_split_names(
            ['train', 'eval', 'test'])

        schema_artifact = standard_artifacts.Schema()
        schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen')

        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        validation_output = standard_artifacts.ExampleAnomalies()
        validation_output.uri = os.path.join(output_data_dir, 'output')

        input_dict = {
            STATISTICS_KEY: [eval_stats_artifact],
            SCHEMA_KEY: [schema_artifact],
        }

        exec_properties = {
            # List needs to be serialized before being passed into Do function.
            EXCLUDE_SPLITS_KEY: json_utils.dumps(['test'])
        }

        output_dict = {
            ANOMALIES_KEY: [validation_output],
        }

        example_validator_executor = executor.Executor()
        example_validator_executor.Do(input_dict, output_dict, exec_properties)

        self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                         validation_output.split_names)

        # Check example_validator outputs.
        train_anomalies_path = os.path.join(validation_output.uri,
                                            'Split-train', 'SchemaDiff.pb')
        eval_anomalies_path = os.path.join(validation_output.uri, 'Split-eval',
                                           'SchemaDiff.pb')
        self.assertTrue(fileio.exists(train_anomalies_path))
        self.assertTrue(fileio.exists(eval_anomalies_path))
        train_anomalies_bytes = io_utils.read_bytes_file(train_anomalies_path)
        train_anomalies = anomalies_pb2.Anomalies()
        train_anomalies.ParseFromString(train_anomalies_bytes)
        eval_anomalies_bytes = io_utils.read_bytes_file(eval_anomalies_path)
        eval_anomalies = anomalies_pb2.Anomalies()
        eval_anomalies.ParseFromString(eval_anomalies_bytes)
        self.assertEqual(0, len(train_anomalies.anomaly_info))
        self.assertEqual(0, len(eval_anomalies.anomaly_info))

        # Assert 'test' split is excluded.
        test_anomalies_path = os.path.join(validation_output.uri, 'Split-test',
                                           'SchemaDiff.pb')
        self.assertFalse(fileio.exists(test_anomalies_path))
Example #9
    def testDoWithTunerFn(self):

        self._exec_properties['tuner_fn'] = '%s.%s' % (
            tuner_module.tuner_fn.__module__, tuner_module.tuner_fn.__name__)

        ps_type = ps_pb2.Type(
            binary_classification=ps_pb2.BinaryClassification(label='class'))
        ps = ps_pb2.ProblemStatement(
            owner=['nitroml'],
            tasks=[ps_pb2.Task(
                name='mockdata_1',
                type=ps_type,
            )])

        self._exec_properties['custom_config'] = json_utils.dumps({
            'problem_statement':
            text_format.MessageToString(message=ps, as_utf8=True),
        })

        tuner = executor.Executor(self._context)
        tuner.Do(input_dict=self._input_dict,
                 output_dict=self._output_dict,
                 exec_properties=self._exec_properties)

        self._verify_output()
Example #10
    def _parse_parameter_from_component(
            self, component: base_component.BaseComponent) -> None:
        """Extract embedded RuntimeParameter placeholders from a component.

    Extract embedded RuntimeParameter placeholders from a component, then append
    the corresponding dsl.PipelineParam to KubeflowDagRunner.

    Args:
      component: a TFX component.
    """

        serialized_component = json_utils.dumps(component)
        placeholders = re.findall(data_types.RUNTIME_PARAMETER_PATTERN,
                                  serialized_component)
        for placeholder in placeholders:
            placeholder = placeholder.replace('\\', '')  # Clean escapes.
            placeholder = utils.fix_brackets(
                placeholder)  # Fix brackets if needed.
            parameter = json_utils.loads(placeholder)
            # Skip pipeline root because it will be added later.
            if parameter.name == tfx_pipeline.ROOT_PARAMETER.name:
                continue
            if parameter.name not in self._deduped_parameter_names:
                self._deduped_parameter_names.add(parameter.name)
                dsl_parameter = dsl.PipelineParam(name=parameter.name,
                                                  value=parameter.default)
                self._params.append(dsl_parameter)
Example #11
    def testDoValidation(self, exec_properties, blessed, has_baseline):
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create input dict.
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        model = standard_artifacts.Model()
        baseline_model = standard_artifacts.Model()
        model.uri = os.path.join(source_data_dir, 'trainer/current')
        baseline_model.uri = os.path.join(source_data_dir, 'trainer/previous/')
        schema = standard_artifacts.Schema()
        schema.uri = os.path.join(source_data_dir, 'schema_gen')
        input_dict = {
            EXAMPLES_KEY: [examples],
            MODEL_KEY: [model],
            SCHEMA_KEY: [schema],
        }
        if has_baseline:
            input_dict[BASELINE_MODEL_KEY] = [baseline_model]

        # Create output dict.
        eval_output = standard_artifacts.ModelEvaluation()
        eval_output.uri = os.path.join(output_data_dir, 'eval_output')
        blessing_output = standard_artifacts.ModelBlessing()
        blessing_output.uri = os.path.join(output_data_dir, 'blessing_output')
        output_dict = {
            EVALUATION_KEY: [eval_output],
            BLESSING_KEY: [blessing_output],
        }

        # List needs to be serialized before being passed into Do function.
        exec_properties[EXAMPLE_SPLITS_KEY] = json_utils.dumps(None)

        # Run executor.
        evaluator = executor.Executor()
        evaluator.Do(input_dict, output_dict, exec_properties)

        # Check evaluator outputs.
        self.assertTrue(
            fileio.exists(os.path.join(eval_output.uri, 'eval_config.json')))
        self.assertTrue(fileio.exists(os.path.join(eval_output.uri,
                                                   'metrics')))
        self.assertTrue(fileio.exists(os.path.join(eval_output.uri, 'plots')))
        self.assertTrue(
            fileio.exists(os.path.join(eval_output.uri, 'validations')))
        if blessed:
            self.assertTrue(
                fileio.exists(os.path.join(blessing_output.uri, 'BLESSED')))
        else:
            self.assertTrue(
                fileio.exists(os.path.join(blessing_output.uri,
                                           'NOT_BLESSED')))
Example #12
def value_converter(
        tfx_value: Any) -> Optional[pipeline_pb2.ValueOrRuntimeParameter]:
    """Converts TFX/MLMD values into Kubeflow pipeline ValueOrRuntimeParameter."""
    if tfx_value is None:
        return None

    result = pipeline_pb2.ValueOrRuntimeParameter()
    if isinstance(tfx_value, (int, float, str, Text)):
        result.constant_value.CopyFrom(get_kubeflow_value(tfx_value))
    elif isinstance(tfx_value, (Dict, List)):
        result.constant_value.CopyFrom(
            pipeline_pb2.Value(string_value=json.dumps(tfx_value)))
    elif isinstance(tfx_value, data_types.RuntimeParameter):
        # Attach the runtime parameter to the context.
        parameter_utils.attach_parameter(tfx_value)
        result.runtime_parameter = tfx_value.name
    elif isinstance(tfx_value, metadata_store_pb2.Value):
        if tfx_value.WhichOneof('value') == 'int_value':
            result.constant_value.CopyFrom(
                pipeline_pb2.Value(int_value=tfx_value.int_value))
        elif tfx_value.WhichOneof('value') == 'double_value':
            result.constant_value.CopyFrom(
                pipeline_pb2.Value(double_value=tfx_value.double_value))
        elif tfx_value.WhichOneof('value') == 'string_value':
            result.constant_value.CopyFrom(
                pipeline_pb2.Value(string_value=tfx_value.string_value))
    elif isinstance(tfx_value, message.Message):
        result.constant_value.CopyFrom(
            pipeline_pb2.Value(string_value=json_format.MessageToJson(
                message=tfx_value, sort_keys=True)))
    else:
        # By default will attempt to encode the object using json_utils.dumps.
        result.constant_value.CopyFrom(
            pipeline_pb2.Value(string_value=json_utils.dumps(tfx_value)))
    return result
Example #13
    def __init__(self,
                 statistics: Optional[types.Channel] = None,
                 infer_feature_shape: Optional[Union[
                     bool, data_types.RuntimeParameter]] = True,
                 exclude_splits: Optional[List[Text]] = None):
        """Constructs a SchemaGen component.

    Args:
      statistics: A Channel of `ExampleStatistics` type (required if spec is not
        passed). This should contain at least a `train` split. Other splits are
        currently ignored. _required_
      infer_feature_shape: Boolean (or RuntimeParameter) value indicating
        whether or not to infer the shape of features. If the feature shape is
        not inferred, downstream TensorFlow Transform components using the
        schema will parse input as tf.SparseTensor. Defaults to True if not
        set.
      exclude_splits: Names of splits that will not be taken into consideration
        when auto-generating a schema. Default behavior (when exclude_splits is
        set to None) is excluding no splits.
    """
        if exclude_splits is None:
            exclude_splits = []
            logging.info(
                'Excluding no splits because exclude_splits is not set.')
        schema = types.Channel(type=standard_artifacts.Schema)
        if isinstance(infer_feature_shape, bool):
            infer_feature_shape = int(infer_feature_shape)
        spec = SchemaGenSpec(statistics=statistics,
                             infer_feature_shape=infer_feature_shape,
                             exclude_splits=json_utils.dumps(exclude_splits),
                             schema=schema)
        super(SchemaGen, self).__init__(spec=spec)
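
A minimal wiring sketch (assuming `statistics_gen` is an upstream StatisticsGen
component defined elsewhere in the pipeline):

schema_gen = SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=True)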
Example #14
    def __init__(
            self,
            examples: types.Channel = None,
            transformed_examples: Optional[types.Channel] = None,
            transform_graph: Optional[types.Channel] = None,
            schema: Optional[types.Channel] = None,
            base_model: Optional[types.Channel] = None,
            hyperparameters: Optional[types.Channel] = None,
            module_file: Optional[Union[Text,
                                        data_types.RuntimeParameter]] = None,
            run_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
            trainer_fn: Optional[Union[Text,
                                       data_types.RuntimeParameter]] = None,
            train_args: Union[trainer_pb2.TrainArgs, Dict[Text, Any]] = None,
            eval_args: Union[trainer_pb2.EvalArgs, Dict[Text, Any]] = None,
            custom_config: Optional[Dict[Text, Any]] = None,
            custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
            output: Optional[types.Channel] = None,
            model_run: Optional[types.Channel] = None,
            test_results: Optional[types.Channel] = None,
            instance_name: Optional[Text] = None):

        if [bool(module_file),
                bool(run_fn),
                bool(trainer_fn)].count(True) != 1:
            raise ValueError(
                "Exactly one of 'module_file', 'trainer_fn', or 'run_fn' must be "
                "supplied.")

        if bool(examples) == bool(transformed_examples):
            raise ValueError(
                "Exactly one of 'example' or 'transformed_example' must be supplied."
            )

        if transformed_examples and not transform_graph:
            raise ValueError("If 'transformed_examples' is supplied, "
                             "'transform_graph' must be supplied too.")

        examples = examples or transformed_examples
        output = output or types.Channel(type=Model)
        model_run = model_run or types.Channel(type=ModelRun)
        test_results = test_results or types.Channel(type=Examples)
        spec = ZenMLTrainerSpec(examples=examples,
                                transform_graph=transform_graph,
                                schema=schema,
                                base_model=base_model,
                                hyperparameters=hyperparameters,
                                train_args=train_args,
                                eval_args=eval_args,
                                module_file=module_file,
                                run_fn=run_fn,
                                trainer_fn=trainer_fn,
                                custom_config=json_utils.dumps(custom_config),
                                model=output,
                                model_run=model_run,
                                test_results=test_results)
        super(Trainer,
              self).__init__(spec=spec,
                             custom_executor_spec=custom_executor_spec,
                             instance_name=instance_name)
Example #15
 def testDoWithModuleFileWithTFXIO(self):
     self._exec_properties['custom_config'] = json_utils.dumps(
         {'use_tfxio_input_fn': True})
     self._exec_properties['module_file'] = self._module_file
     self._do(self._trainer_executor)
     self._verify_model_exports()
     self._verify_model_run_exports()
Example #16
  def testDo(self):
    source_data_dir = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')

    statistics_artifact = standard_artifacts.ExampleStatistics()
    statistics_artifact.uri = os.path.join(source_data_dir, 'statistics_gen')
    statistics_artifact.split_names = artifact_utils.encode_split_names(
        ['train', 'eval', 'test'])

    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    schema_output = standard_artifacts.Schema()
    schema_output.uri = os.path.join(output_data_dir, 'schema_output')

    input_dict = {
        standard_component_specs.STATISTICS_KEY: [statistics_artifact],
    }

    exec_properties = {
        # List needs to be serialized before being passed into Do function.
        standard_component_specs.EXCLUDE_SPLITS_KEY:
            json_utils.dumps(['test'])
    }

    output_dict = {
        standard_component_specs.SCHEMA_KEY: [schema_output],
    }

    schema_gen_executor = executor.Executor()
    schema_gen_executor.Do(input_dict, output_dict, exec_properties)
    self.assertNotEqual(0, len(fileio.listdir(schema_output.uri)))
Example #17
  def testDoSkippedModelCreation(self, mock_runner, mock_run_model_inference,
                                 _):
    input_dict = {
        'examples': [self._examples],
        'model': [self._model],
        'model_blessing': [self._model_blessing],
    }
    output_dict = {
        'inference_result': [self._inference_result],
    }
    ai_platform_serving_args = {
        'model_name': 'model_name',
        'project_id': 'project_id'
    }
    # Create exec properties.
    exec_properties = {
        'data_spec':
            proto_utils.proto_to_json(bulk_inferrer_pb2.DataSpec()),
        'custom_config':
            json_utils.dumps(
                {executor.SERVING_ARGS_KEY: ai_platform_serving_args}),
    }
    mock_runner.get_service_name_and_api_version.return_value = ('ml', 'v1')
    mock_runner.create_model_for_aip_prediction_if_not_exist.return_value = False

    # Run executor.
    bulk_inferrer = executor.Executor(self._context)
    bulk_inferrer.Do(input_dict, output_dict, exec_properties)

    ai_platform_prediction_model_spec = (
        model_spec_pb2.AIPlatformPredictionModelSpec(
            project_id='project_id',
            model_name='model_name',
            version_name=self._model_version))
    ai_platform_prediction_model_spec.use_serialization_config = True
    inference_endpoint = model_spec_pb2.InferenceSpecType()
    inference_endpoint.ai_platform_prediction_model_spec.CopyFrom(
        ai_platform_prediction_model_spec)
    mock_run_model_inference.assert_called_once_with(mock.ANY, mock.ANY,
                                                     mock.ANY, mock.ANY,
                                                     mock.ANY,
                                                     inference_endpoint)
    executor_class_path = '%s.%s' % (bulk_inferrer.__class__.__module__,
                                     bulk_inferrer.__class__.__name__)
    with telemetry_utils.scoped_labels(
        {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
      job_labels = telemetry_utils.make_labels_dict()
    mock_runner.deploy_model_for_aip_prediction.assert_called_once_with(
        serving_path=path_utils.serving_model_path(self._model.uri),
        model_version_name=mock.ANY,
        ai_platform_serving_args=ai_platform_serving_args,
        labels=job_labels,
        api=mock.ANY,
        skip_model_endpoint_creation=True,
        set_default=False)
    mock_runner.delete_model_from_aip_if_exists.assert_called_once_with(
        model_version_name=mock.ANY,
        ai_platform_serving_args=ai_platform_serving_args,
        api=mock.ANY,
        delete_model_endpoint=False)
Example #18
    def __init__(self,
                 examples: types.Channel = None,
                 schema: Optional[types.Channel] = None,
                 transform_graph: Optional[types.Channel] = None,
                 module_file: Optional[Text] = None,
                 tuner_fn: Optional[Text] = None,
                 train_args: trainer_pb2.TrainArgs = None,
                 eval_args: trainer_pb2.EvalArgs = None,
                 tune_args: Optional[tuner_pb2.TuneArgs] = None,
                 custom_config: Optional[Dict[Text, Any]] = None):
        """Construct a Tuner component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples`, serving as the
        source of examples that are used in tuning (required).
      schema:  An optional Channel of type `standard_artifacts.Schema`, serving
        as the schema of training and eval data. This is used when raw examples
        are provided.
      transform_graph: An optional Channel of type
        `standard_artifacts.TransformGraph`, serving as the input transform
        graph if present. This is used when transformed examples are provided.
      module_file: A path to python module file containing UDF tuner definition.
        The module_file must implement a function named `tuner_fn` at its top
        level. The function must have the following signature.
            def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
        Exactly one of 'module_file' or 'tuner_fn' must be supplied.
      tuner_fn:  A python path to UDF model definition function. See
        'module_file' for the required signature of the UDF. Exactly one of
        'module_file' or 'tuner_fn' must be supplied.
      train_args: A trainer_pb2.TrainArgs instance, containing args used for
        training. Currently only splits and num_steps are available. Default
        behavior (when splits is empty) is train on `train` split.
      eval_args: A trainer_pb2.EvalArgs instance, containing args used for eval.
        Currently only splits and num_steps are available. Default behavior
        (when splits is empty) is evaluate on `eval` split.
      tune_args: A tuner_pb2.TuneArgs instance, containing args used for tuning.
        Currently only num_parallel_trials is available.
      custom_config: A dict which contains additional training job parameters
        that will be passed into user module.
    """
        if bool(module_file) == bool(tuner_fn):
            raise ValueError(
                "Exactly one of 'module_file' or 'tuner_fn' must be supplied")

        best_hyperparameters = types.Channel(
            type=standard_artifacts.HyperParameters)
        spec = TunerSpec(
            examples=examples,
            schema=schema,
            transform_graph=transform_graph,
            module_file=module_file,
            tuner_fn=tuner_fn,
            train_args=train_args,
            eval_args=eval_args,
            tune_args=tune_args,
            best_hyperparameters=best_hyperparameters,
            custom_config=json_utils.dumps(custom_config),
        )
        super(Tuner, self).__init__(spec=spec)
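
A minimal wiring sketch (assuming `transform` and `schema_gen` components and a
`module_file` path defining `tuner_fn` are available elsewhere in the pipeline):

tuner = Tuner(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=20),
    eval_args=trainer_pb2.EvalArgs(num_steps=5))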
Example #19
  def _parse_parameters(self, raw_args: Mapping[str, Any]):
    """Parse arguments to ComponentSpec."""
    unparsed_args = set(raw_args.keys())
    inputs = {}
    outputs = {}
    self.exec_properties = {}

    # First, check that the arguments are set.
    for arg_name, arg in itertools.chain(self.PARAMETERS.items(),
                                         self.INPUTS.items(),
                                         self.OUTPUTS.items()):
      if arg_name not in unparsed_args:
        if arg.optional:
          continue
        else:
          raise ValueError('Missing argument %r to %s.' %
                           (arg_name, self.__class__))
      unparsed_args.remove(arg_name)

      # Type check the argument.
      value = raw_args[arg_name]
      if arg.optional and value is None:
        continue
      arg.type_check(arg_name, value)

    # Populate the appropriate dictionary for each parameter type.
    for arg_name, arg in self.PARAMETERS.items():
      if arg.optional and arg_name not in raw_args:
        continue
      value = raw_args[arg_name]

      if (inspect.isclass(arg.type) and
          issubclass(arg.type, message.Message) and value and
          not _is_runtime_param(value)):
        if arg.use_proto:
          if isinstance(value, dict):
            value = proto_utils.dict_to_proto(value, arg.type())
          elif isinstance(value, str):
            value = proto_utils.json_to_proto(value, arg.type())
        else:
          # Create deterministic json string as it will be stored in metadata
          # for cache check.
          if isinstance(value, dict):
            value = json_utils.dumps(value)
          elif not isinstance(value, str):
            value = proto_utils.proto_to_json(value)

      self.exec_properties[arg_name] = value

    for arg_dict, param_dict in ((self.INPUTS, inputs), (self.OUTPUTS,
                                                         outputs)):
      for arg_name, arg in arg_dict.items():
        if arg.optional and not raw_args.get(arg_name):
          continue
        value = raw_args[arg_name]
        param_dict[arg_name] = value

    self.inputs = inputs
    self.outputs = outputs
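
For context, a sketch of the kind of ComponentSpec this method parses; the
class and field names below are illustrative, not taken from the snippet above:

from tfx.proto import trainer_pb2
from tfx.types import standard_artifacts
from tfx.types.component_spec import (ChannelParameter, ComponentSpec,
                                      ExecutionParameter)


class MyComponentSpec(ComponentSpec):
  """Hypothetical spec used only to illustrate PARAMETERS/INPUTS/OUTPUTS."""
  PARAMETERS = {
      # A proto-typed parameter passed in as a dict is serialized with
      # json_utils.dumps by _parse_parameters when use_proto is False.
      'train_args': ExecutionParameter(type=trainer_pb2.TrainArgs),
      'name': ExecutionParameter(type=str),
  }
  INPUTS = {
      'examples': ChannelParameter(type=standard_artifacts.Examples),
  }
  OUTPUTS = {
      'model': ChannelParameter(type=standard_artifacts.Model),
  }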
Example #20
 def testBigQueryServingArgs(self):
     temp_exec_properties = {
         'custom_config': json_utils.dumps({}),
         'push_destination': None,
     }
     with self.assertRaises(ValueError):
         self._executor.Do(self._input_dict, self._output_dict,
                           temp_exec_properties)
Example #21
 def testMultipleArtifactsWithTFXIO(self):
     self._exec_properties['custom_config'] = json_utils.dumps(
         {'use_tfxio_input_fn': True})
     self._input_dict[constants.EXAMPLES_KEY] = self._multiple_artifacts
     self._exec_properties['module_file'] = self._module_file
     self._do(self._generic_trainer_executor)
     self._verify_model_exports()
     self._verify_model_run_exports()
Example #22
    def __init__(
            self,
            model: Optional[types.Channel] = None,
            model_blessing: Optional[types.Channel] = None,
            infra_blessing: Optional[types.Channel] = None,
            push_destination: Optional[Union[pusher_pb2.PushDestination,
                                             Dict[Text, Any]]] = None,
            custom_config: Optional[Dict[Text, Any]] = None,
            custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
            pushed_model: Optional[types.Channel] = None):
        """Construct a Pusher component.

    Args:
      model: An optional Channel of type `standard_artifacts.Model`, usually
        produced by a Trainer component.
      model_blessing: An optional Channel of type
        `standard_artifacts.ModelBlessing`, usually produced from an Evaluator
        component.
      infra_blessing: An optional Channel of type
        `standard_artifacts.InfraBlessing`, usually produced from an
        InfraValidator component.
      push_destination: A pusher_pb2.PushDestination instance, providing info
        for tensorflow serving to load models. Optional if executor_class
        doesn't require push_destination. If any field is provided as a
        RuntimeParameter, push_destination should be constructed as a dict with
        the same field names as PushDestination proto message.
      custom_config: A dict which contains the deployment job parameters to be
        passed to cloud-based training platforms. The [Kubeflow example](
          https://github.com/tensorflow/tfx/blob/6ff57e36a7b65818d4598d41e584a42584d361e6/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_kubeflow_gcp.py#L278-L285)
          contains an example of how this can be used by custom executors.
      custom_executor_spec: Optional custom executor spec. This is experimental
        and is subject to change in the future.
      pushed_model: Optional output `standard_artifacts.PushedModel` channel
        with result of push.
    """
        pushed_model = pushed_model or types.Channel(
            type=standard_artifacts.PushedModel)
        if push_destination is None and not custom_executor_spec:
            raise ValueError(
                'push_destination is required unless a '
                'custom_executor_spec is supplied that does not require '
                'it.')
        if model is None and infra_blessing is None:
            raise ValueError(
                'Either one of model or infra_blessing channel should be given. '
                'If infra_blessing is used in place of model, it must have been '
                'created with InfraValidator with RequestSpec.make_warmup = True. '
                'This cannot be checked during pipeline construction time but will '
                'raise runtime error if infra_blessing does not contain a model.'
            )
        spec = PusherSpec(model=model,
                          model_blessing=model_blessing,
                          infra_blessing=infra_blessing,
                          push_destination=push_destination,
                          custom_config=json_utils.dumps(custom_config),
                          pushed_model=pushed_model)
        super(Pusher, self).__init__(spec=spec,
                                     custom_executor_spec=custom_executor_spec)
Example #23
    def testDoLegacySingleEvalSavedModelWFairness(self, exec_properties):
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create input dict.
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        model = standard_artifacts.Model()
        model.uri = os.path.join(source_data_dir, 'trainer/current')
        input_dict = {
            EXAMPLES_KEY: [examples],
            MODEL_KEY: [model],
        }

        # Create output dict.
        eval_output = standard_artifacts.ModelEvaluation()
        eval_output.uri = os.path.join(output_data_dir, 'eval_output')
        blessing_output = standard_artifacts.ModelBlessing()
        blessing_output.uri = os.path.join(output_data_dir, 'blessing_output')
        output_dict = {
            EVALUATION_KEY: [eval_output],
            BLESSING_KEY: [blessing_output],
        }

        try:
            # Need to import the following module so that the fairness indicator
            # post-export metric is registered.  This may raise an ImportError if the
            # currently-installed version of TFMA does not support fairness
            # indicators.
            import tensorflow_model_analysis.addons.fairness.post_export_metrics.fairness_indicators  # pylint: disable=g-import-not-at-top, unused-variable
            exec_properties['fairness_indicator_thresholds'] = [
                0.1, 0.3, 0.5, 0.7, 0.9
            ]
        except ImportError:
            logging.warning(
                'Not testing fairness indicators because a compatible TFMA version '
                'is not installed.')

        # List needs to be serialized before being passed into Do function.
        exec_properties[EXAMPLE_SPLITS_KEY] = json_utils.dumps(None)

        # Run executor.
        evaluator = executor.Executor()
        evaluator.Do(input_dict, output_dict, exec_properties)

        # Check evaluator outputs.
        self.assertTrue(
            fileio.exists(os.path.join(eval_output.uri, 'eval_config.json')))
        self.assertTrue(fileio.exists(os.path.join(eval_output.uri,
                                                   'metrics')))
        self.assertTrue(fileio.exists(os.path.join(eval_output.uri, 'plots')))
        self.assertFalse(
            fileio.exists(os.path.join(blessing_output.uri, 'BLESSED')))
Example #24
    def testDumpsNestedClass(self):
        obj = _DefaultJsonableObject(_DefaultJsonableObject, None, None)

        json_text = json_utils.dumps(obj)

        actual_obj = json_utils.loads(json_text)
        self.assertEqual(_DefaultJsonableObject, actual_obj.a)
        self.assertIsNone(actual_obj.b)
        self.assertIsNone(actual_obj.c)
Example #25
    def testDumpsJsonableObjectRoundtrip(self):
        obj = _DefaultJsonableObject(1, {'a': 'b'}, [True])

        json_text = json_utils.dumps(obj)

        actual_obj = json_utils.loads(json_text)
        self.assertEqual(1, actual_obj.a)
        self.assertDictEqual({'a': 'b'}, actual_obj.b)
        self.assertItemsEqual([True], actual_obj.c)
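
The same round-trip also works for plain JSON-compatible values, which is how
exec_properties such as exclude_splits are serialized in the component examples
above; a short sketch:

from tfx.utils import json_utils

payload = {'exclude_splits': ['test'], 'num_steps': 100}
serialized = json_utils.dumps(payload)  # plain JSON string
assert json_utils.loads(serialized) == payload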
Example #26
    def testArtifact(self):
        instance = _MyArtifact()

        # Test property getters.
        self.assertEqual('', instance.uri)
        self.assertEqual(0, instance.id)
        self.assertEqual(0, instance.type_id)
        self.assertEqual('MyTypeName', instance.type_name)
        self.assertEqual('', instance.state)

        # Default property does not have span or split_names.
        with self.assertRaisesRegexp(AttributeError, "has no property 'span'"):
            instance.span  # pylint: disable=pointless-statement
        with self.assertRaisesRegexp(AttributeError,
                                     "has no property 'split_names'"):
            instance.split_names  # pylint: disable=pointless-statement

        # Test property setters.
        instance.uri = '/tmp/uri2'
        self.assertEqual('/tmp/uri2', instance.uri)

        instance.id = 1
        self.assertEqual(1, instance.id)

        instance.type_id = 2
        self.assertEqual(2, instance.type_id)

        instance.state = artifact.ArtifactState.DELETED
        self.assertEqual(artifact.ArtifactState.DELETED, instance.state)

        # Default artifact does not have span.
        with self.assertRaisesRegexp(AttributeError,
                                     "unknown property 'span'"):
            instance.span = 20190101
        # Default artifact does not have span.
        with self.assertRaisesRegexp(AttributeError,
                                     "unknown property 'split_names'"):
            instance.split_names = ''

        instance.set_int_custom_property('int_key', 20)
        self.assertEqual(
            20, instance.mlmd_artifact.custom_properties['int_key'].int_value)

        instance.set_string_custom_property('string_key', 'string_value')
        self.assertEqual(
            'string_value', instance.mlmd_artifact.
            custom_properties['string_key'].string_value)

        self.assertEqual(
            'Artifact(type_name: MyTypeName, uri: /tmp/uri2, id: 1)',
            str(instance))

        # Test json serialization.
        json_dict = json_utils.dumps(instance)
        other_instance = json_utils.loads(json_dict)
        self.assertEqual(instance.mlmd_artifact, other_instance.mlmd_artifact)
        self.assertEqual(instance.artifact_type, other_instance.artifact_type)
Example #27
    def testConstructSingleVmJob(self):
        training = ai_platform_training_component.create_ai_platform_training(
            name='my_training_step',
            project_id='my-project',
            region='us-central1',
            image_uri='gcr.io/my-project/caip-training-test:latest',
            args=[
                '--examples',
                placeholders.InputUriPlaceholder('examples'), '--n-steps',
                placeholders.InputValuePlaceholder('n_step'), '--model-dir',
                placeholders.OutputUriPlaceholder('model')
            ],
            scale_tier='BASIC_GPU',
            inputs={
                'examples': self.examples,
            },
            outputs={'model': standard_artifacts.Model},
            parameters={'n_step': 100})

        expected_aip_config = {
            ai_platform_training_executor.PROJECT_CONFIG_KEY: 'my-project',
            ai_platform_training_executor.TRAINING_JOB_CONFIG_KEY: {
                'training_input': {
                    'scaleTier': 'BASIC_GPU',
                    'region': 'us-central1',
                    'masterConfig': {
                        'imageUri': 'gcr.io/my-project/caip-training-test:latest'
                    },
                    'args': [
                        '--examples',
                        placeholders.InputUriPlaceholder('examples'),
                        '--n-steps',
                        placeholders.InputValuePlaceholder('n_step'),
                        '--model-dir',
                        placeholders.OutputUriPlaceholder('model')
                    ]
                },
                ai_platform_training_executor.LABELS_CONFIG_KEY: None,
            },
            ai_platform_training_executor.JOB_ID_CONFIG_KEY: None,
            ai_platform_training_executor.LABELS_CONFIG_KEY: None,
        }

        # exec_properties has two entries: one is the user-defined 'n_step', another
        # is the aip_training_config.
        self.assertLen(training.exec_properties, 2)
        self.assertEqual(training.outputs['model'].type_name,
                         standard_artifacts.Model.TYPE_NAME)
        self.assertEqual(training.inputs['examples'].type_name,
                         standard_artifacts.Examples.TYPE_NAME)
        self.assertEqual(training.exec_properties['n_step'], 100)
        self.assertEqual(
            training.exec_properties[ai_platform_training_executor.CONFIG_KEY],
            json_utils.dumps(expected_aip_config))
Example #28
    def __init__(self,
                 examples: types.Channel,
                 to_key_fn: Optional[Text] = None,
                 to_key_fn_key: Optional[Text] = 'to_key_fn',
                 pipeline_configuration: Optional[types.Channel] = None,
                 stratified_examples: Optional[types.Channel] = None,
                 splits_to_transform: Optional[List[Text]] = None,
                 splits_to_copy: Optional[List[Text]] = None,
                 samples_per_key: Optional[int] = None):
        """Construct an StratifiedSampler component.
    Args:
      examples: A Channel of 'Examples' type, usually produced by ExampleGen
        component. _required_
      pipeline_configuration: A Channel of 'PipelineConfiguration' type, usually produced by FromCustomConfig
        component.
      stratified_examples: Channel of `Examples` to store the inference
        results.
      splits_to_transform: Optional list of split names to transform.
      splits_to_copy: Optional list of split names to copy.
      samples_per_key: Number of samples per key.
      to_key_fn_key: The name of the key that holds to_key_fn. Defaults to
        'to_key_fn'.
      to_key_fn: The key-extraction function, provided as source text. It must
        define a function with signature `to_key: Example -> key`, for example:
                 >>> def to_key(m):
                 >>>   return m.features.feature['trip_miles'].float_list.value[0] > 42.
    """
        stratified_examples = stratified_examples or types.Channel(
            type=standard_artifacts.Examples)

        spec = StratifiedSamplerSpec(
            examples=examples,
            pipeline_configuration=pipeline_configuration,
            stratified_examples=stratified_examples,
            splits_to_transform=json_utils.dumps(splits_to_transform),
            splits_to_copy=json_utils.dumps(splits_to_copy),
            to_key_fn=to_key_fn,
            to_key_fn_key=to_key_fn_key,
            samples_per_key=samples_per_key)
        super(StratifiedSampler, self).__init__(spec=spec)
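
A usage sketch based only on the docstring above (`example_gen` is an assumed
upstream ExampleGen component; the to_key_fn source text is illustrative):

stratified_sampler = StratifiedSampler(
    examples=example_gen.outputs['examples'],
    splits_to_transform=['eval'],
    splits_to_copy=['train'],
    samples_per_key=1000,
    to_key_fn="""
def to_key(m):
  return m.features.feature['trip_miles'].float_list.value[0] > 42.
""")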
Example #29
    def __init__(
            self,
            model: Optional[types.Channel] = None,
            model_blessing: Optional[types.Channel] = None,
            infra_blessing: Optional[types.Channel] = None,
            push_destination: Optional[Union[pusher_pb2.PushDestination,
                                             Dict[Text, Any]]] = None,
            custom_config: Optional[Dict[Text, Any]] = None,
            custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None):
        """Construct a Pusher component.

    Args:
      model: An optional Channel of type `standard_artifacts.Model`, usually
        produced by a Trainer component.
      model_blessing: An optional Channel of type
        `standard_artifacts.ModelBlessing`, usually produced from an Evaluator
        component.
      infra_blessing: An optional Channel of type
        `standard_artifacts.InfraBlessing`, usually produced from an
        InfraValidator component.
      push_destination: A pusher_pb2.PushDestination instance, providing info
        for tensorflow serving to load models. Optional if executor_class
        doesn't require push_destination. If any field is provided as a
        RuntimeParameter, push_destination should be constructed as a dict with
        the same field names as PushDestination proto message.
      custom_config: A dict which contains the deployment job parameters to be
        passed to Cloud platforms.
      custom_executor_spec: Optional custom executor spec. This is experimental
        and is subject to change in the future.
    """
        pushed_model = types.Channel(type=standard_artifacts.PushedModel)
        if (push_destination is None and not custom_executor_spec
                and self.EXECUTOR_SPEC.executor_class == executor.Executor):
            raise ValueError(
                'push_destination is required unless a '
                'custom_executor_spec is supplied that does not require '
                'it.')
        if custom_executor_spec:
            logging.warning(
                '`custom_executor_spec` is going to be deprecated.')
        if model is None and infra_blessing is None:
            raise ValueError(
                'Either one of model or infra_blessing channel should be given. '
                'If infra_blessing is used in place of model, it must have been '
                'created with InfraValidator with RequestSpec.make_warmup = True. '
                'This cannot be checked during pipeline construction time but will '
                'raise runtime error if infra_blessing does not contain a model.'
            )
        spec = PusherSpec(model=model,
                          model_blessing=model_blessing,
                          infra_blessing=infra_blessing,
                          push_destination=push_destination,
                          custom_config=json_utils.dumps(custom_config),
                          pushed_model=pushed_model)
        super(Pusher, self).__init__(spec=spec,
                                     custom_executor_spec=custom_executor_spec)
Example #30
    def __init__(self,
                 examples: types.Channel,
                 model: Optional[types.Channel] = None,
                 model_blessing: Optional[types.Channel] = None,
                 data_spec: Optional[Union[bulk_inferrer_pb2.DataSpec,
                                           Dict[Text, Any]]] = None,
                 output_example_spec: Optional[Union[
                     bulk_inferrer_pb2.OutputExampleSpec, Dict[Text,
                                                               Any]]] = None,
                 custom_config: Optional[Dict[Text, Any]] = None):
        """Construct an BulkInferrer component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples`, usually
        produced by an ExampleGen component. _required_
      model: A Channel of type `standard_artifacts.Model`, usually produced by
        a Trainer component.
      model_blessing: A Channel of type `standard_artifacts.ModelBlessing`,
        usually produced by a ModelValidator component.
      data_spec: bulk_inferrer_pb2.DataSpec instance that describes data
        selection. If any field is provided as a RuntimeParameter, data_spec
        should be constructed as a dict with the same field names as DataSpec
        proto message.
      output_example_spec: bulk_inferrer_pb2.OutputExampleSpec instance, specify
        if you want BulkInferrer to output examples instead of inference result.
        If any field is provided as a RuntimeParameter, output_example_spec
        should be constructed as a dict with the same field names as
        OutputExampleSpec proto message.
      custom_config: A dict which contains the deployment job parameters to be
        passed to Google Cloud AI Platform.
        custom_config.ai_platform_serving_args need to contain the serving job
        parameters. For the full set of parameters, refer to
        https://cloud.google.com/ml-engine/reference/rest/v1/projects.models

    Raises:
      ValueError: If `inference_result` or `output_examples` is specified
        inconsistently with whether `output_example_spec` is set.
    """
        if output_example_spec:
            output_examples = types.Channel(type=standard_artifacts.Examples)
            inference_result = None
        else:
            inference_result = types.Channel(
                type=standard_artifacts.InferenceResult)
            output_examples = None

        spec = CloudAIBulkInferrerComponentSpec(
            examples=examples,
            model=model,
            model_blessing=model_blessing,
            data_spec=data_spec or bulk_inferrer_pb2.DataSpec(),
            output_example_spec=output_example_spec,
            custom_config=json_utils.dumps(custom_config),
            inference_result=inference_result,
            output_examples=output_examples)
        super(CloudAIBulkInferrerComponent, self).__init__(spec=spec)
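
A minimal wiring sketch (assuming `example_gen`, `trainer`, and `evaluator`
components defined elsewhere; the custom_config key follows
executor.SERVING_ARGS_KEY used in the AI Platform test above, and the project
and model names are placeholders):

bulk_inferrer = CloudAIBulkInferrerComponent(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    custom_config={
        'ai_platform_serving_args': {
            'model_name': 'my_model',
            'project_id': 'my-gcp-project',
        },
    })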