Example No. 1
 def setUp(self):
     super(TFServingRpcRequestBuilderTest, self).setUp()
     self._examples = standard_artifacts.Examples()
     self._examples.uri = _CSV_EXAMPLE_GEN_URI
     self._examples.split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
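A minimal sketch (not part of the original test) of how the encoded value round-trips; 'decode_split_names' is the counterpart of 'encode_split_names' in tfx.types.artifact_utils:

from tfx.types import artifact_utils, standard_artifacts

examples = standard_artifacts.Examples()
# encode_split_names serializes the list to a JSON string property.
examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
assert artifact_utils.decode_split_names(examples.split_names) == ['train', 'eval']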
Example No. 2
    def testDriverRunFn(self):
        # Create input dir.
        self._input_base_path = os.path.join(self._test_dir, 'input_base')
        fileio.makedirs(self._input_base_path)

        # Fake previous outputs
        span1_v1_split1 = os.path.join(self._input_base_path, 'span01',
                                       'split1', 'data')
        io_utils.write_string_file(span1_v1_split1, 'testing11')
        span1_v1_split2 = os.path.join(self._input_base_path, 'span01',
                                       'split2', 'data')
        io_utils.write_string_file(span1_v1_split2, 'testing12')

        ir_driver = driver.Driver(self._mock_metadata)
        example = standard_artifacts.Examples()

        # Prepare output_dict.
        example.uri = 'my_uri'  # Will verify that this uri is not changed.
        output_dict = {utils.EXAMPLES_KEY: [example]}

        # Prepare exec_properties.
        exec_properties = {
            utils.INPUT_BASE_KEY:
            self._input_base_path,
            utils.INPUT_CONFIG_KEY:
            json_format.MessageToJson(example_gen_pb2.Input(splits=[
                example_gen_pb2.Input.Split(name='s1',
                                            pattern='span{SPAN}/split1/*'),
                example_gen_pb2.Input.Split(name='s2',
                                            pattern='span{SPAN}/split2/*')
            ]),
                                      preserving_proto_field_name=True),
        }
        result = ir_driver.run(
            portable_data_types.ExecutionInfo(output_dict=output_dict,
                                              exec_properties=exec_properties))
        # Assert exec_properties' values
        exec_properties = result.exec_properties
        self.assertEqual(exec_properties[utils.SPAN_PROPERTY_NAME].int_value,
                         1)
        updated_input_config = example_gen_pb2.Input()
        json_format.Parse(exec_properties[utils.INPUT_CONFIG_KEY].string_value,
                          updated_input_config)
        self.assertProtoEquals(
            """
        splits {
          name: "s1"
          pattern: "span01/split1/*"
        }
        splits {
          name: "s2"
          pattern: "span01/split2/*"
        }""", updated_input_config)
        self.assertRegex(
            exec_properties[utils.FINGERPRINT_PROPERTY_NAME].string_value,
            r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
        )
        # Assert output_artifacts' values
        self.assertLen(result.output_artifacts[utils.EXAMPLES_KEY].artifacts,
                       1)
        output_example = result.output_artifacts[
            utils.EXAMPLES_KEY].artifacts[0]
        self.assertEqual(output_example.uri, example.uri)
        self.assertEqual(
            output_example.custom_properties[
                utils.SPAN_PROPERTY_NAME].string_value, '1')
        self.assertRegex(
            output_example.custom_properties[
                utils.FINGERPRINT_PROPERTY_NAME].string_value,
            r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
        )
Example No. 3
    def testUpdateExecution(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            contexts = m.register_pipeline_contexts_if_not_exists(
                self._pipeline_info)
            m.register_execution(input_artifacts={},
                                 exec_properties={'k': 'v1'},
                                 pipeline_info=self._pipeline_info,
                                 component_info=self._component_info,
                                 contexts=contexts)
            [execution] = m.store.get_executions_by_context(
                m.get_component_run_context(self._component_info).id)
            self.assertEqual(execution.properties['k'].string_value, 'v1')
            self.assertEqual(execution.properties['state'].string_value,
                             metadata.EXECUTION_STATE_NEW)
            self.assertEqual(execution.last_known_state,
                             metadata_store_pb2.Execution.RUNNING)

            m.update_execution(
                execution,
                self._component_info,
                input_artifacts={'input_a': [standard_artifacts.Examples()]},
                exec_properties={'k': 'v2'},
                contexts=contexts)

            [execution] = m.store.get_executions_by_context(
                m.get_component_run_context(self._component_info).id)
            self.assertEqual(execution.properties['k'].string_value, 'v2')
            self.assertEqual(execution.properties['state'].string_value,
                             metadata.EXECUTION_STATE_NEW)
            self.assertEqual(execution.last_known_state,
                             metadata_store_pb2.Execution.RUNNING)
            [event] = m.store.get_events_by_execution_ids([execution.id])
            self.assertEqual(event.artifact_id, 1)
            [artifact] = m.store.get_artifacts_by_context(
                m.get_component_run_context(self._component_info).id)
            self.assertEqual(artifact.id, 1)

            aa = standard_artifacts.Examples()
            aa.set_mlmd_artifact(artifact)
            m.update_execution(execution,
                               self._component_info,
                               input_artifacts={'input_a': [aa]})
            [event] = m.store.get_events_by_execution_ids([execution.id])
            self.assertEqual(event.type, metadata_store_pb2.Event.INPUT)

            m.publish_execution(
                self._component_info,
                output_artifacts={'output': [standard_artifacts.Model()]},
                exec_properties={'k': 'v3'})

            [execution] = m.store.get_executions_by_context(
                m.get_component_run_context(self._component_info).id)
            self.assertEqual(execution.properties['k'].string_value, 'v3')
            self.assertEqual(execution.properties['state'].string_value,
                             metadata.EXECUTION_STATE_COMPLETE)
            self.assertEqual(execution.last_known_state,
                             metadata_store_pb2.Execution.COMPLETE)
            [_, event_b] = m.store.get_events_by_execution_ids([execution.id])
            self.assertEqual(event_b.artifact_id, 2)
            self.assertEqual(event_b.type, metadata_store_pb2.Event.OUTPUT)
            [_, artifact_b] = m.store.get_artifacts_by_context(
                m.get_component_run_context(self._component_info).id)
            self.assertEqual(artifact_b.id, 2)
            self._check_artifact_state(m, artifact_b, ArtifactState.PUBLISHED)
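The metadata assertions above run against a throwaway MLMD store. A minimal sketch of a connection_config suitable for such tests, assuming an in-memory SQLite backend (the original fixture's exact configuration is not shown):

from ml_metadata.proto import metadata_store_pb2

connection_config = metadata_store_pb2.ConnectionConfig()
# An empty SqliteMetadataSourceConfig yields an unnamed in-memory database.
connection_config.sqlite.SetInParent()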
Example No. 4
    def setUp(self):
        super().setUp()
        self._source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        self._output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)
        self.component_id = 'test_component'

        # Create input dict.
        self._examples = standard_artifacts.Examples()
        unlabelled_path = os.path.join(self._source_data_dir,
                                       'csv_example_gen', 'unlabelled')
        self._examples.uri = os.path.join(self._output_data_dir,
                                          'csv_example_gen')
        io_utils.copy_dir(unlabelled_path,
                          os.path.join(self._examples.uri, 'Split-unlabelled'))
        io_utils.copy_dir(
            unlabelled_path,
            os.path.join(self._examples.uri, 'Split-unlabelled2'))
        self._examples.split_names = artifact_utils.encode_split_names(
            ['unlabelled', 'unlabelled2'])
        self._model = standard_artifacts.Model()
        self._model.uri = os.path.join(self._source_data_dir,
                                       'trainer/current')

        self._model_blessing = standard_artifacts.ModelBlessing()
        self._model_blessing.uri = os.path.join(self._source_data_dir,
                                                'model_validator/blessed')
        self._model_blessing.set_int_custom_property('blessed', 1)

        self._input_dict = {
            standard_component_specs.EXAMPLES_KEY: [self._examples],
            standard_component_specs.MODEL_KEY: [self._model],
            standard_component_specs.MODEL_BLESSING_KEY:
            [self._model_blessing],
        }

        # Create output dict.
        self._inference_result = standard_artifacts.InferenceResult()
        self._prediction_log_dir = os.path.join(self._output_data_dir,
                                                'prediction_logs')
        self._inference_result.uri = self._prediction_log_dir

        self._output_examples = standard_artifacts.Examples()
        self._output_examples_dir = os.path.join(self._output_data_dir,
                                                 'output_examples')
        self._output_examples.uri = self._output_examples_dir

        self._output_dict_ir = {
            standard_component_specs.INFERENCE_RESULT_KEY:
            [self._inference_result],
        }
        self._output_dict_oe = {
            standard_component_specs.OUTPUT_EXAMPLES_KEY:
            [self._output_examples],
        }

        # Create exec properties.
        self._exec_properties = {
            standard_component_specs.DATA_SPEC_KEY:
            proto_utils.proto_to_json(bulk_inferrer_pb2.DataSpec()),
            standard_component_specs.MODEL_SPEC_KEY:
            proto_utils.proto_to_json(bulk_inferrer_pb2.ModelSpec()),
            'component_id':
            self.component_id,
        }

        # Create context
        self._tmp_dir = os.path.join(self._output_data_dir, '.temp')
        self._context = executor.Executor.Context(tmp_dir=self._tmp_dir,
                                                  unique_id='2')
Example No. 5
    def setUp(self):
        super().setUp()
        self._project_id = 'my-project'
        self._job_id = 'my-job-123'
        self._labels = ['label1', 'label2']
        self._mock_api_client = mock.Mock()
        examples_artifact = standard_artifacts.Examples()
        examples_artifact.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        examples_artifact.uri = _EXAMPLE_LOCATION
        self._inputs = {'examples': [examples_artifact]}
        model_artifact = standard_artifacts.Model()
        model_artifact.uri = _MODEL_LOCATION
        self._outputs = {'model': [model_artifact]}

        training_input = {
            'scaleTier':
            'CUSTOM',
            'region':
            'us-central1',
            'masterType':
            'n1-standard-8',
            'masterConfig': {
                'imageUri': 'gcr.io/my-project/caip-training-test:latest'
            },
            'workerType':
            'n1-standard-8',
            'workerCount':
            8,
            'workerConfig': {
                'imageUri': 'gcr.io/my-project/caip-training-test:latest'
            },
            'args': [
                '--examples',
                placeholders.InputUriPlaceholder('examples'), '--n-steps',
                placeholders.InputValuePlaceholder('n_step'), '--model-dir',
                placeholders.OutputUriPlaceholder('model')
            ]
        }

        aip_training_config = {
            ai_platform_training_executor.PROJECT_CONFIG_KEY: self._project_id,
            ai_platform_training_executor.TRAINING_INPUT_CONFIG_KEY:
            training_input,
            ai_platform_training_executor.JOB_ID_CONFIG_KEY: self._job_id,
            ai_platform_training_executor.LABELS_CONFIG_KEY: self._labels,
        }

        self._exec_properties = {
            ai_platform_training_executor.CONFIG_KEY:
            json_utils.dumps(aip_training_config),
            'n_step':
            100
        }

        resolved_training_input = copy.deepcopy(training_input)
        resolved_training_input['args'] = [
            '--examples', _EXAMPLE_LOCATION, '--n-steps', '100', '--model-dir',
            _MODEL_LOCATION
        ]

        self._expected_job_spec = {
            'jobId': self._job_id,
            'trainingInput': resolved_training_input,
            'labels': self._labels,
        }
Example No. 6
    def testPublishSuccessExecutionExecutorEditedOutputDict(self):
        # There is one artifact in the system provided output_dict, while there are
        # two artifacts in executor output. We expect that two artifacts are
        # published.
        with metadata.Metadata(connection_config=self._connection_config) as m:
            contexts = self._generate_contexts(m)
            execution_id = execution_publish_utils.register_execution(
                m, self._execution_type, contexts).id

            output_example = standard_artifacts.Examples()
            output_example.uri = '/original_path'

            executor_output = execution_result_pb2.ExecutorOutput()
            output_key = 'examples'
            text_format.Parse(
                """
          uri: '/original_path/subdir_1'
          custom_properties {
            key: 'prop'
            value {int_value: 1}
          }
          """, executor_output.output_artifacts[output_key].artifacts.add())
            text_format.Parse(
                """
          uri: '/original_path/subdir_2'
          custom_properties {
            key: 'prop'
            value {int_value: 2}
          }
          """, executor_output.output_artifacts[output_key].artifacts.add())

            output_dict = execution_publish_utils.publish_succeeded_execution(
                m, execution_id, contexts, {output_key: [output_example]},
                executor_output)
            [execution] = m.store.get_executions()
            self.assertProtoPartiallyEquals("""
          id: 1
          type_id: 3
          last_known_state: COMPLETE
          """,
                                            execution,
                                            ignored_fields=[
                                                'create_time_since_epoch',
                                                'last_update_time_since_epoch'
                                            ])
            artifacts = m.store.get_artifacts()
            self.assertLen(artifacts, 2)
            self.assertProtoPartiallyEquals("""
          id: 1
          type_id: 4
          state: LIVE
          uri: '/original_path/subdir_1'
          custom_properties {
            key: 'prop'
            value {int_value: 1}
          }""",
                                            artifacts[0],
                                            ignored_fields=[
                                                'create_time_since_epoch',
                                                'last_update_time_since_epoch'
                                            ])
            self.assertProtoPartiallyEquals("""
          id: 2
          type_id: 4
          state: LIVE
          uri: '/original_path/subdir_2'
          custom_properties {
            key: 'prop'
            value {int_value: 2}
          }""",
                                            artifacts[1],
                                            ignored_fields=[
                                                'create_time_since_epoch',
                                                'last_update_time_since_epoch'
                                            ])
            events = m.store.get_events_by_execution_ids([execution.id])
            self.assertLen(events, 2)
            self.assertProtoPartiallyEquals(
                """
          artifact_id: 1
          execution_id: 1
          path {
            steps {
              key: 'examples'
            }
            steps {
              index: 0
            }
          }
          type: OUTPUT
          """,
                events[0],
                ignored_fields=['milliseconds_since_epoch'])
            self.assertProtoPartiallyEquals(
                """
          artifact_id: 2
          execution_id: 1
          path {
            steps {
              key: 'examples'
            }
            steps {
              index: 1
            }
          }
          type: OUTPUT
          """,
                events[1],
                ignored_fields=['milliseconds_since_epoch'])
            # Verifies the context-execution edges are set up.
            self.assertCountEqual([c.id for c in contexts], [
                c.id for c in m.store.get_contexts_by_execution(execution.id)
            ])
            for artifact_list in output_dict.values():
                for output_example in artifact_list:
                    self.assertCountEqual([c.id for c in contexts], [
                        c.id for c in m.store.get_contexts_by_artifact(
                            output_example.id)
                    ])
Example No. 7
    def __init__(
            self,
            examples: types.Channel = None,
            schema: types.Channel = None,
            module_file: Optional[Union[Text,
                                        data_types.RuntimeParameter]] = None,
            preprocessing_fn: Optional[Union[
                Text, data_types.RuntimeParameter]] = None,
            transform_graph: Optional[types.Channel] = None,
            transformed_examples: Optional[types.Channel] = None,
            input_data: Optional[types.Channel] = None,
            instance_name: Optional[Text] = None,
            enable_cache: Optional[bool] = None):
        """Construct a Transform component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples` (required).
        This should contain the two splits 'train' and 'eval'.
      schema: A Channel of type `standard_artifacts.Schema`. This should
        contain a single schema artifact.
      module_file: The file path to a python module file, from which the
        'preprocessing_fn' function will be loaded. The function must have the
        following signature.

        def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
          ...

        where the values of input and returned Dict are either tf.Tensor or
        tf.SparseTensor.  Exactly one of 'module_file' or 'preprocessing_fn'
        must be supplied.
      preprocessing_fn: The path to python function that implements a
        'preprocessing_fn'. See 'module_file' for expected signature of the
        function. Exactly one of 'module_file' or 'preprocessing_fn' must be
        supplied.
      transform_graph: Optional output 'TransformPath' channel for output of
        'tf.Transform', which includes an exported Tensorflow graph suitable for
        both training and serving;
      transformed_examples: Optional output 'ExamplesPath' channel for
        materialized transformed examples, which includes both 'train' and
        'eval' splits.
      input_data: Backwards compatibility alias for the 'examples' argument.
      instance_name: Optional unique instance name. Necessary iff multiple
        transform components are declared in the same pipeline.
      enable_cache: Optional boolean to indicate if cache is enabled for the
        Transform component. If not specified, defaults to the value
        specified for pipeline's enable_cache parameter.
    Raises:
      ValueError: When both or neither of 'module_file' and 'preprocessing_fn'
        is supplied.
    """
        if input_data:
            absl.logging.warning(
                'The "input_data" argument to the Transform component has '
                'been renamed to "examples" and is deprecated. Please update your '
                'usage as support for this argument will be removed soon.')
            examples = input_data
        if bool(module_file) == bool(preprocessing_fn):
            raise ValueError(
                "Exactly one of 'module_file' or 'preprocessing_fn' must be supplied."
            )

        transform_graph = transform_graph or types.Channel(
            type=standard_artifacts.TransformGraph,
            artifacts=[standard_artifacts.TransformGraph()])
        if not transformed_examples:
            example_artifact = standard_artifacts.Examples()
            example_artifact.split_names = artifact_utils.encode_split_names(
                artifact.DEFAULT_EXAMPLE_SPLITS)
            transformed_examples = types.Channel(
                type=standard_artifacts.Examples, artifacts=[example_artifact])
        spec = TransformSpec(examples=examples,
                             schema=schema,
                             module_file=module_file,
                             preprocessing_fn=preprocessing_fn,
                             transform_graph=transform_graph,
                             transformed_examples=transformed_examples)
        super(Transform, self).__init__(spec=spec,
                                        instance_name=instance_name,
                                        enable_cache=enable_cache)
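For context, a typical instantiation of this component passes exactly one of 'module_file' or 'preprocessing_fn'. A sketch, assuming upstream example_gen and schema_gen components and an illustrative module path:

transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file='path/to/preprocessing.py')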
Example No. 8
 def testRunExecutor_with_InplaceUpdateExecutor(self):
     executor_spec = text_format.Parse(
         """
   class_path: "tfx.orchestration.portable.python_executor_operator_test.InplaceUpdateExecutor"
 """, local_deployment_config_pb2.ExecutableSpec.PythonClassExecutableSpec(
         ))
     operator = python_executor_operator.PythonExecutorOperator(
         executor_spec)
     input_dict = {'input_key': [standard_artifacts.Examples()]}
     output_dict = {'output_key': [standard_artifacts.Model()]}
     exec_properties = {
         'string': 'value',
         'int': 1,
         'float': 0.0,
         # This should not happen on production and will be
         # dropped.
         'proto': execution_result_pb2.ExecutorOutput()
     }
     stateful_working_dir = os.path.join(self.tmp_dir,
                                         'stateful_working_dir')
     executor_output_uri = os.path.join(self.tmp_dir, 'executor_output')
     executor_output = operator.run_executor(
         base_executor_operator.ExecutionInfo(
             input_dict=input_dict,
             output_dict=output_dict,
             exec_properties=exec_properties,
             stateful_working_dir=stateful_working_dir,
             executor_output_uri=executor_output_uri))
     self.assertProtoPartiallyEquals(
         """
       execution_properties {
         key: "float"
         value {
           double_value: 0.0
         }
       }
       execution_properties {
         key: "int"
         value {
           int_value: 1
         }
       }
       execution_properties {
         key: "string"
         value {
           string_value: "value"
         }
       }
       output_artifacts {
         key: "output_key"
         value {
           artifacts {
             custom_properties {
               key: "name"
               value {
                 string_value: "my_model"
               }
             }
           }
         }
       }""", executor_output)
Example No. 9
 def _createExamples(self, span: int) -> standard_artifacts.Examples:
     artifact = standard_artifacts.Examples()
     artifact.uri = f'uri{span}'
     artifact.set_int_custom_property(utils.SPAN_PROPERTY_NAME, span)
     return artifact
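A sketch of how a factory like this is typically used in span-resolution tests (the loop is illustrative, not from the original):

# Build one Examples artifact per span and hand them to the code under test.
artifacts = [self._createExamples(span) for span in (1, 2, 3)]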
Example No. 10
  def __init__(
      self,
      examples: types.Channel = None,
      schema: types.Channel = None,
      module_file: Optional[Union[Text, data_types.RuntimeParameter]] = None,
      preprocessing_fn: Optional[Union[Text,
                                       data_types.RuntimeParameter]] = None,
      transform_graph: Optional[types.Channel] = None,
      transformed_examples: Optional[types.Channel] = None,
      input_data: Optional[types.Channel] = None,
      analyzer_cache: Optional[types.Channel] = None,
      instance_name: Optional[Text] = None,
      materialize: bool = True,
      disable_analyzer_cache: bool = False,
      custom_config: Optional[Dict[Text, Any]] = None):
    """Construct a Transform component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples` (required).
        This should contain the two splits 'train' and 'eval'.
      schema: A Channel of type `standard_artifacts.Schema`. This should
        contain a single schema artifact.
      module_file: The file path to a python module file, from which the
        'preprocessing_fn' function will be loaded.
        Exactly one of 'module_file' or 'preprocessing_fn' must be supplied.

        The function needs to have the following signature:
        ```
        def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
          ...
        ```
        where the values of input and returned Dict are either tf.Tensor or
        tf.SparseTensor.

        If additional inputs are needed for preprocessing_fn, they can be passed
        in custom_config:

        ```
        def preprocessing_fn(inputs: Dict[Text, Any], custom_config:
                             Dict[Text, Any]) -> Dict[Text, Any]:
          ...
        ```
      preprocessing_fn: The path to python function that implements a
        'preprocessing_fn'. See 'module_file' for expected signature of the
        function. Exactly one of 'module_file' or 'preprocessing_fn' must be
        supplied.
      transform_graph: Optional output 'TransformPath' channel for output of
        'tf.Transform', which includes an exported Tensorflow graph suitable for
        both training and serving;
      transformed_examples: Optional output 'ExamplesPath' channel for
        materialized transformed examples, which includes both 'train' and
        'eval' splits.
      input_data: Backwards compatibility alias for the 'examples' argument.
      analyzer_cache: Optional input 'TransformCache' channel containing
        cached information from previous Transform runs. When provided,
        Transform will try use the cached calculation if possible.
      instance_name: Optional unique instance name. Necessary iff multiple
        transform components are declared in the same pipeline.
      materialize: If True, write transformed examples as an output. If False,
        `transformed_examples` must not be provided.
      disable_analyzer_cache: If False, Transform will use input cache if
        provided and write cache output. If True, `analyzer_cache` must not be
        provided.
      custom_config: A dict which contains additional parameters that will be
        passed to preprocessing_fn.

    Raises:
      ValueError: When both or neither of 'module_file' and 'preprocessing_fn'
        is supplied.
    """
    if input_data:
      absl.logging.warning(
          'The "input_data" argument to the Transform component has '
          'been renamed to "examples" and is deprecated. Please update your '
          'usage as support for this argument will be removed soon.')
      examples = input_data
    if bool(module_file) == bool(preprocessing_fn):
      raise ValueError(
          "Exactly one of 'module_file' or 'preprocessing_fn' must be supplied."
      )

    transform_graph = transform_graph or types.Channel(
        type=standard_artifacts.TransformGraph,
        artifacts=[standard_artifacts.TransformGraph()])

    if materialize and transformed_examples is None:
      transformed_examples = types.Channel(
          type=standard_artifacts.Examples,
          # TODO(b/161548528): remove the hardcode artifact.
          artifacts=[standard_artifacts.Examples()],
          matching_channel_name='examples')
    elif not materialize and transformed_examples is not None:
      raise ValueError(
          'Must not specify transformed_examples when materialize is False.')

    if disable_analyzer_cache:
      updated_analyzer_cache = None
      if analyzer_cache:
        raise ValueError(
            '`analyzer_cache` is set when disable_analyzer_cache is True.')
    else:
      updated_analyzer_cache = types.Channel(
          type=standard_artifacts.TransformCache,
          artifacts=[standard_artifacts.TransformCache()])

    spec = TransformSpec(
        examples=examples,
        schema=schema,
        module_file=module_file,
        preprocessing_fn=preprocessing_fn,
        transform_graph=transform_graph,
        transformed_examples=transformed_examples,
        analyzer_cache=analyzer_cache,
        updated_analyzer_cache=updated_analyzer_cache,
        custom_config=json.dumps(custom_config))
    super(Transform, self).__init__(spec=spec, instance_name=instance_name)
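To make the custom_config contract described above concrete, a minimal module-file sketch; the 'dense_float_keys' config key is an illustrative assumption:

import tensorflow_transform as tft

def preprocessing_fn(inputs, custom_config):
  # Scale the dense float features named in the (hypothetical) config key.
  outputs = {}
  for key in custom_config['dense_float_keys']:
    outputs[key + '_scaled'] = tft.scale_to_z_score(inputs[key])
  return outputs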
Example No. 11
    def setUp(self):
        super(ExecutorTest, self).setUp()

        # Setup Mocks

        runner_patcher = mock.patch('tfx.components.infra_validator'
                                    '.model_server_runners.local_docker_runner'
                                    '.LocalDockerModelServerRunner')
        self.model_server = runner_patcher.start().return_value
        self.addCleanup(runner_patcher.stop)

        build_request_patcher = mock.patch(
            'tfx.components.infra_validator.request_builder'
            '.build_requests')
        self.build_requests_mock = build_request_patcher.start()
        self.addCleanup(build_request_patcher.stop)

        self.model_server.client = mock.MagicMock()

        # Setup directories

        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        base_output_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                         self.get_temp_dir())
        output_data_dir = os.path.join(base_output_dir, self._testMethodName)

        # Setup input_dict.

        model = standard_artifacts.Model()
        model.uri = os.path.join(source_data_dir, 'trainer', 'current')
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(source_data_dir, 'transform',
                                    'transformed_examples', 'eval')
        examples.split_names = artifact_utils.encode_split_names(['eval'])

        self.input_dict = {
            'model': [model],
            'examples': [examples],
        }

        # Setup output_dict.

        self.blessing = standard_artifacts.InfraBlessing()
        self.blessing.uri = os.path.join(output_data_dir, 'blessing')
        self.output_dict = {'blessing': [self.blessing]}

        # Setup Context

        temp_dir = os.path.join(output_data_dir, '.temp')
        self.context = executor.Executor.Context(tmp_dir=temp_dir,
                                                 unique_id='1')

        # Setup exec_properties

        self.exec_properties = {
            'serving_spec':
            json.dumps({
                'tensorflow_serving': {
                    'tags': ['1.15.0']
                },
                'local_docker': {}
            }),
            'validation_spec':
            json.dumps({'max_loading_time_seconds': 10}),
            'request_spec':
            json.dumps({
                'tensorflow_serving': {
                    'rpc_kind': 'CLASSIFY'
                },
                'max_examples': 10
            })
        }
Example No. 12
 def setUp(self):
     super().setUp()
     self._examples = standard_artifacts.Examples()
     self._examples.uri = _CSV_EXAMPLE_GEN_URI
     self._examples.split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
Example No. 13
 def examples(self):
     examples = standard_artifacts.Examples()
     return channel_utils.as_channel([examples])
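Any component that takes an examples channel can consume the result. A minimal sketch, assuming this helper lives on a test fixture:

from tfx.components import StatisticsGen

statistics_gen = StatisticsGen(examples=self.examples())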
Example No. 14
    def setUp(self):
        super(ExecutorTest, self).setUp()

        # Setup Mocks

        patcher = mock.patch.object(request_builder, 'build_requests')
        self.build_requests_mock = patcher.start()
        self.addCleanup(patcher.stop)

        # Setup directories

        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        base_output_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                         self.get_temp_dir())
        output_data_dir = os.path.join(base_output_dir, self._testMethodName)

        # Setup input_dict.

        self._model = standard_artifacts.Model()
        self._model.uri = os.path.join(source_data_dir, 'trainer', 'current')
        self._model_path = path_utils.serving_model_path(self._model.uri)
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(source_data_dir, 'transform',
                                    'transformed_examples', 'eval')
        examples.split_names = artifact_utils.encode_split_names(['eval'])

        self._input_dict = {
            'model': [self._model],
            'examples': [examples],
        }
        self._blessing = standard_artifacts.InfraBlessing()
        self._blessing.uri = os.path.join(output_data_dir, 'blessing')
        self._output_dict = {'blessing': [self._blessing]}
        temp_dir = os.path.join(output_data_dir, '.temp')
        self._context = executor.Executor.Context(tmp_dir=temp_dir,
                                                  unique_id='1')
        self._serving_spec = _make_serving_spec({
            'tensorflow_serving': {
                'tags': ['1.15.0']
            },
            'local_docker': {},
            'model_name': 'chicago-taxi',
        })
        self._serving_binary = serving_bins.parse_serving_binaries(
            self._serving_spec)[0]
        self._validation_spec = _make_validation_spec({
            'max_loading_time_seconds':
            10,
            'num_tries':
            3
        })
        self._request_spec = _make_request_spec({
            'tensorflow_serving': {
                'signature_names': ['serving_default'],
            },
            'num_examples': 1
        })
        self._exec_properties = {
            'serving_spec': json_format.MessageToJson(self._serving_spec),
            'validation_spec':
            json_format.MessageToJson(self._validation_spec),
            'request_spec': json_format.MessageToJson(self._request_spec),
        }
Example No. 15
    def __init__(self,
                 input_data: types.Channel = None,
                 schema: types.Channel = None,
                 module_file: Optional[Text] = None,
                 preprocessing_fn: Optional[Text] = None,
                 transform_output: Optional[types.Channel] = None,
                 transformed_examples: Optional[types.Channel] = None,
                 examples: Optional[types.Channel] = None,
                 name: Optional[Text] = None):
        """Construct a Transform component.

    Args:
      input_data: A Channel of 'ExamplesPath' type (required). This should
        contain the two splits 'train' and 'eval'.
      schema: A Channel of 'SchemaPath' type. This should contain a single
        schema artifact.
      module_file: The file path to a python module file, from which the
        'preprocessing_fn' function will be loaded. The function must have the
        following signature.

        def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
          ...

        where the values of input and returned Dict are either tf.Tensor or
        tf.SparseTensor.  Exactly one of 'module_file' or 'preprocessing_fn'
        must be supplied.
      preprocessing_fn: The path to python function that implements a
         'preprocessing_fn'. See 'module_file' for expected signature of the
         function. Exactly one of 'module_file' or 'preprocessing_fn' must
         be supplied.
      transform_output: Optional output 'TransformPath' channel for output of
        'tf.Transform', which includes an exported Tensorflow graph suitable for
        both training and serving;
      transformed_examples: Optional output 'ExamplesPath' channel for
        materialized transformed examples, which includes both 'train' and
        'eval' splits.
      examples: Forwards compatibility alias for the 'input_data' argument.
      name: Optional unique name. Necessary iff multiple transform components
        are declared in the same pipeline.

    Raises:
      ValueError: When both or neither of 'module_file' and 'preprocessing_fn'
        is supplied.
    """
        input_data = input_data or examples
        if bool(module_file) == bool(preprocessing_fn):
            raise ValueError(
                "Exactly one of 'module_file' or 'preprocessing_fn' must be supplied."
            )

        transform_output = transform_output or types.Channel(
            type=standard_artifacts.TransformGraph,
            artifacts=[standard_artifacts.TransformGraph()])
        transformed_examples = transformed_examples or types.Channel(
            type=standard_artifacts.Examples,
            artifacts=[
                standard_artifacts.Examples(split=split)
                for split in artifact.DEFAULT_EXAMPLE_SPLITS
            ])
        spec = TransformSpec(input_data=input_data,
                             schema=schema,
                             module_file=module_file,
                             preprocessing_fn=preprocessing_fn,
                             transform_output=transform_output,
                             transformed_examples=transformed_examples)
        super(Transform, self).__init__(spec=spec, name=name)
Example No. 16
    def testDo(self):
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create input dict.
        train_examples = standard_artifacts.Examples(split='train')
        eval_examples = standard_artifacts.Examples(split='eval')
        eval_examples.uri = os.path.join(source_data_dir,
                                         'csv_example_gen/eval/')
        model_exports = standard_artifacts.Model()
        model_exports.uri = os.path.join(source_data_dir, 'trainer/current/')
        input_dict = {
            'examples': [train_examples, eval_examples],
            'model_exports': [model_exports],
        }

        # Create output dict.
        eval_output = standard_artifacts.ModelEvaluation()
        eval_output.uri = os.path.join(output_data_dir, 'eval_output')
        output_dict = {'output': [eval_output]}

        # Create exec properties.
        exec_properties = {
            'feature_slicing_spec':
            json_format.MessageToJson(evaluator_pb2.FeatureSlicingSpec(specs=[
                evaluator_pb2.SingleSlicingSpec(
                    column_for_slicing=['trip_start_hour']),
                evaluator_pb2.SingleSlicingSpec(
                    column_for_slicing=['trip_start_day', 'trip_miles']),
            ]),
                                      preserving_proto_field_name=True)
        }

        try:
            # Need to import the following module so that the fairness indicator
            # post-export metric is registered.  This may raise an ImportError if the
            # currently-installed version of TFMA does not support fairness
            # indicators.
            import tensorflow_model_analysis.addons.fairness.post_export_metrics.fairness_indicators  # pylint: disable=g-import-not-at-top, unused-variable
            exec_properties['fairness_indicator_thresholds'] = [
                0.1, 0.3, 0.5, 0.7, 0.9
            ]
        except ImportError:
            absl.logging.warning(
                'Not testing fairness indicators because a compatible TFMA version '
                'is not installed.')

        # Run executor.
        evaluator = executor.Executor()
        evaluator.Do(input_dict, output_dict, exec_properties)

        # Check evaluator outputs.
        self.assertTrue(
            # TODO(b/141490237): Update to only check eval_config.json after TFMA
            # released with corresponding change.
            tf.io.gfile.exists(os.path.join(eval_output.uri, 'eval_config'))
            or tf.io.gfile.exists(
                os.path.join(eval_output.uri, 'eval_config.json')))
        self.assertTrue(
            tf.io.gfile.exists(os.path.join(eval_output.uri, 'metrics')))
        self.assertTrue(
            tf.io.gfile.exists(os.path.join(eval_output.uri, 'plots')))
Example No. 17
    def setUp(self):
        super(ExecutorTest, self).setUp()
        self._source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        self._output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create input dict.
        e1 = standard_artifacts.Examples()
        e1.uri = os.path.join(self._source_data_dir,
                              'transform/transformed_examples')
        e1.split_names = artifact_utils.encode_split_names(['train', 'eval'])

        e2 = copy.deepcopy(e1)

        self._single_artifact = [e1]
        self._multiple_artifacts = [e1, e2]

        transform_output = standard_artifacts.TransformGraph()
        transform_output.uri = os.path.join(self._source_data_dir,
                                            'transform/transform_graph')

        schema = standard_artifacts.Schema()
        schema.uri = os.path.join(self._source_data_dir, 'schema_gen')
        previous_model = standard_artifacts.Model()
        previous_model.uri = os.path.join(self._source_data_dir,
                                          'trainer/previous')

        self._input_dict = {
            constants.EXAMPLES_KEY: self._single_artifact,
            constants.TRANSFORM_GRAPH_KEY: [transform_output],
            constants.SCHEMA_KEY: [schema],
            constants.BASE_MODEL_KEY: [previous_model]
        }

        # Create output dict.
        self._model_exports = standard_artifacts.Model()
        self._model_exports.uri = os.path.join(self._output_data_dir,
                                               'model_export_path')
        self._model_run_exports = standard_artifacts.ModelRun()
        self._model_run_exports.uri = os.path.join(self._output_data_dir,
                                                   'model_run_path')
        self._output_dict = {
            constants.MODEL_KEY: [self._model_exports],
            constants.MODEL_RUN_KEY: [self._model_run_exports]
        }

        # Create exec properties skeleton.
        self._exec_properties = {
            'train_args':
            json_format.MessageToJson(trainer_pb2.TrainArgs(num_steps=1000),
                                      preserving_proto_field_name=True),
            'eval_args':
            json_format.MessageToJson(trainer_pb2.EvalArgs(num_steps=500),
                                      preserving_proto_field_name=True),
            'warm_starting':
            False,
        }

        self._module_file = os.path.join(self._source_data_dir, 'module_file',
                                         'trainer_module.py')
        self._trainer_fn = '%s.%s' % (trainer_module.trainer_fn.__module__,
                                      trainer_module.trainer_fn.__name__)

        # Executors for test.
        self._trainer_executor = executor.Executor()
        self._generic_trainer_executor = executor.GenericExecutor()
Example No. 18
    def __init__(self,
                 query: Optional[Text] = None,
                 beam_transform: beam.PTransform = None,
                 bucket_name: Optional[Text] = None,
                 output_schema: Optional[Text] = None,
                 table_name: Optional[Text] = None,
                 use_bigquery_source: Optional[Any] = False,
                 input_config: Optional[example_gen_pb2.Input] = None,
                 output_config: Optional[example_gen_pb2.Output] = None,
                 example_artifacts: Optional[types.Channel] = None,
                 instance_name: Optional[Text] = None):
        """Constructs a BigQueryExampleGen component.

        Args:
            query: BigQuery SQL string; the query result will be treated as a
                single split. Can be overwritten by input_config.
            beam_transform: A beam.PTransform pipeline, used to process the
                data ingested by the BigQuery query.
            bucket_name: Name of a GCS bucket, used as temporary storage for
                the query results and the pickled beam_transform.
            table_name: string containing the BigQuery output table name.
            use_bigquery_source: Whether to use BigQuerySource instead of the
                experimental `ReadFromBigQuery` PTransform (required by the
                BigQueryExampleGen executor).
            input_config: An example_gen_pb2.Input instance with Split.pattern as
                BigQuery sql string. If set, it overwrites the 'query' arg, and allows
                different queries per split. If any field is provided as a
                RuntimeParameter, input_config should be constructed as a dict with the
                same field names as Input proto message.
            output_config: An example_gen_pb2.Output instance, providing output
                configuration. If unset, default splits will be 'train' and
                'eval' with size 2:1. If any field is provided as a
                RuntimeParameter, output_config should be constructed as a dict
                with the same field names as the Output proto message.
            example_artifacts: Optional channel of 'ExamplesPath' for output
                train and eval examples.
            instance_name: Optional unique instance name. Necessary if multiple
                TCGAPreprocessing components are declared in the same pipeline.

        Raises:
            RuntimeError: Only one of query and input_config should be set.
        """

        # Configure inputs and outputs
        input_config = input_config or utils.make_default_input_config()
        output_config = output_config or utils.make_default_output_config(
            input_config)

        if not example_artifacts:
            example_artifacts = channel_utils.as_channel(
                [standard_artifacts.Examples()])

        # Upload Beam Transform to a GCS Bucket
        beam_transform_uri = upload_beam_to_gcs(beam_transform, bucket_name)

        spec = TCGAPreprocessingSpec(
            # custom parameters
            query=query,
            output_schema=output_schema,
            table_name=table_name,
            use_bigquery_source=use_bigquery_source,
            # default parameters
            input_config=input_config,
            output_config=output_config,
            input_base=beam_transform_uri,
            # outputs
            examples=example_artifacts)
        super(TCGAPreprocessing, self).__init__(spec=spec,
                                                instance_name=instance_name)
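A hypothetical instantiation of this custom component; every argument value below is a placeholder, and MyTransform stands in for a user-defined beam.PTransform:

preprocessing = TCGAPreprocessing(
    query='SELECT * FROM `my-project.my_dataset.my_table`',
    beam_transform=MyTransform(),
    bucket_name='my-temp-bucket',
    table_name='my_output_table')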
Example No. 19
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     metadata_path: Text,
                     direct_num_workers: int) -> pipeline.Pipeline:
    """Creates a pipeline that augments training data with graph neighbors."""
    input_data = external_input(data_root)

    input_config = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='train', pattern='train.tfrecord'),
        example_gen_pb2.Input.Split(name='eval', pattern='eval.tfrecord')
    ])

    example_gen = ImportExampleGen(input=input_data, input_config=input_config)

    identify_examples = IdentifyExamples(
        orig_examples=example_gen.outputs['examples'],
        component_name=u'IdentifyExamples',
        id_feature_name=u'id')

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(
        examples=identify_examples.outputs["identified_examples"])

    schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'])

    # Performs anomaly detection based on statistics and data schema.
    validate_stats = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])

    synthesize_graph = SynthesizeGraph(
        identified_examples=identify_examples.outputs['identified_examples'],
        component_name=u'SynthesizeGraph',
        similarity_threshold=0.99)

    transform = Transform(
        examples=identify_examples.outputs['identified_examples'],
        schema=schema_gen.outputs['schema'],
        # TODO(b/169218106): Remove transformed_examples kwargs after bugfix is released.
        transformed_examples=channel.Channel(
            type=standard_artifacts.Examples,
            artifacts=[standard_artifacts.Examples()]),
        module_file=_transform_module_file)

    # Augments training data with graph neighbors.
    graph_augmentation = GraphAugmentation(
        identified_examples=transform.outputs['transformed_examples'],
        synthesized_graph=synthesize_graph.outputs['synthesized_graph'],
        component_name=u'GraphAugmentation',
        num_neighbors=3)

    trainer = Trainer(
        module_file=_trainer_module_file,
        transformed_examples=graph_augmentation.outputs['augmented_examples'],
        schema=schema_gen.outputs['schema'],
        transform_graph=transform.outputs['transform_graph'],
        train_args=trainer_pb2.TrainArgs(num_steps=10000),
        eval_args=trainer_pb2.EvalArgs(num_steps=5000))

    model_validator = ModelValidator(examples=example_gen.outputs['examples'],
                                     model=trainer.outputs['model'])

    pusher = Pusher(model=trainer.outputs['model'],
                    model_blessing=model_validator.outputs['blessing'],
                    push_destination=pusher_pb2.PushDestination(
                        filesystem=pusher_pb2.PushDestination.Filesystem(
                            base_directory=serving_model_dir)))

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[
            example_gen, identify_examples, statistics_gen, schema_gen,
            validate_stats, synthesize_graph, transform, graph_augmentation,
            trainer, model_validator, pusher
        ],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        beam_pipeline_args=['--direct_num_workers=%d' % direct_num_workers])
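A sketch of driving this pipeline locally with TFX's Beam-based runner; all paths are placeholders:

from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner

BeamDagRunner().run(
    _create_pipeline(
        pipeline_name='nsl_pipeline',
        pipeline_root='/tmp/pipeline_root',
        data_root='/tmp/data',
        module_file='module.py',
        serving_model_dir='/tmp/serving_model',
        metadata_path='/tmp/metadata.db',
        direct_num_workers=1))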
Example No. 20
 def testPublishSuccessfulExecution(self):
     with metadata.Metadata(connection_config=self._connection_config) as m:
         contexts = self._generate_contexts(m)
         execution_id = execution_publish_utils.register_execution(
             m, self._execution_type, contexts).id
         output_key = 'examples'
         output_example = standard_artifacts.Examples()
         output_example.uri = '/examples_uri'
         executor_output = execution_result_pb2.ExecutorOutput()
         text_format.Parse(
             """
       uri: '/examples_uri'
       custom_properties {
         key: 'prop'
         value {int_value: 1}
       }
       """, executor_output.output_artifacts[output_key].artifacts.add())
         output_dict = execution_publish_utils.publish_succeeded_execution(
             m, execution_id, contexts, {output_key: [output_example]},
             executor_output)
         [execution] = m.store.get_executions()
         self.assertProtoPartiallyEquals("""
       id: 1
       type_id: 3
       last_known_state: COMPLETE
       """,
                                         execution,
                                         ignored_fields=[
                                             'create_time_since_epoch',
                                             'last_update_time_since_epoch'
                                         ])
         [artifact] = m.store.get_artifacts()
         self.assertProtoPartiallyEquals("""
       id: 1
       type_id: 4
       state: LIVE
       uri: '/examples_uri'
       custom_properties {
         key: 'prop'
         value {int_value: 1}
       }""",
                                         artifact,
                                         ignored_fields=[
                                             'create_time_since_epoch',
                                             'last_update_time_since_epoch'
                                         ])
         [event] = m.store.get_events_by_execution_ids([execution.id])
         self.assertProtoPartiallyEquals(
             """
       artifact_id: 1
       execution_id: 1
       path {
         steps {
           key: 'examples'
         }
         steps {
           index: 0
         }
       }
       type: OUTPUT
       """,
             event,
             ignored_fields=['milliseconds_since_epoch'])
         # Verifies the context-execution edges are set up.
         self.assertCountEqual([c.id for c in contexts], [
             c.id for c in m.store.get_contexts_by_execution(execution.id)
         ])
         for artifact_list in output_dict.values():
             for output_example in artifact_list:
                 self.assertCountEqual([c.id for c in contexts], [
                     c.id for c in m.store.get_contexts_by_artifact(
                         output_example.id)
                 ])
Example No. 21
    def __init__(
            self,
            # TODO(b/159467778): deprecate this, use input_base instead.
            input: Optional[types.Channel] = None,  # pylint: disable=redefined-builtin
            input_base: Optional[Text] = None,
            input_config: Optional[Union[example_gen_pb2.Input,
                                         Dict[Text, Any]]] = None,
            output_config: Optional[Union[example_gen_pb2.Output,
                                          Dict[Text, Any]]] = None,
            custom_config: Optional[Union[example_gen_pb2.CustomConfig,
                                          Dict[Text, Any]]] = None,
            output_data_format: Optional[int] = example_gen_pb2.FORMAT_TF_EXAMPLE,
            example_artifacts: Optional[types.Channel] = None,
            custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
            instance_name: Optional[Text] = None):
        """Construct a FileBasedExampleGen component.

    Args:
      input: A Channel of type `standard_artifacts.ExternalArtifact`, which
        includes one artifact whose uri is an external directory containing the
        data files. (Deprecated by input_base)
      input_base: an external directory containing the data files.
      input_config: An
        [`example_gen_pb2.Input`](https://github.com/tensorflow/tfx/blob/master/tfx/proto/example_gen.proto)
          instance, providing input configuration. If unset, input files will be
          treated as a single split.
      output_config: An example_gen_pb2.Output instance, providing the output
        configuration. If unset, default splits will be 'train' and
        'eval' with size 2:1.
      custom_config: An optional example_gen_pb2.CustomConfig instance,
        providing custom configuration for executor.
      output_data_format: Payload format of generated data in output artifact,
        one of example_gen_pb2.PayloadFormat enum.
      example_artifacts: Channel of 'ExamplesPath' for output train and eval
        examples.
      custom_executor_spec: Optional custom executor spec overriding the default
        executor spec specified in the component attribute.
      instance_name: Optional unique instance name. Required only if multiple
        ExampleGen components are declared in the same pipeline.
    """
        if input:
            logging.warning(
                'The "input" argument to the ExampleGen component has been '
                'deprecated by "input_base". Please update your usage as support for '
                'this argument will be removed soon.')
            input_base = artifact_utils.get_single_uri(list(input.get()))
        # Configure inputs and outputs.
        input_config = input_config or utils.make_default_input_config()
        output_config = output_config or utils.make_default_output_config(
            input_config)
        if not example_artifacts:
            artifact = standard_artifacts.Examples()
            artifact.split_names = artifact_utils.encode_split_names(
                utils.generate_output_split_names(input_config, output_config))
            example_artifacts = channel_utils.as_channel([artifact])
        spec = FileBasedExampleGenSpec(input_base=input_base,
                                       input_config=input_config,
                                       output_config=output_config,
                                       custom_config=custom_config,
                                       output_data_format=output_data_format,
                                       examples=example_artifacts)
        super(FileBasedExampleGen, self).__init__(
            spec=spec,
            custom_executor_spec=custom_executor_spec,
            instance_name=instance_name)
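
For context, here is a minimal usage sketch of this constructor. It is hedged: the Avro executor, the module paths, and the '/data/avro' directory are illustrative assumptions about a compatible TFX release, not part of the example above.

# A minimal usage sketch, assuming a TFX version exposing these paths.
from tfx.components.example_gen.component import FileBasedExampleGen
from tfx.components.example_gen.custom_executors import avro_executor
from tfx.dsl.components.base import executor_spec
from tfx.proto import example_gen_pb2

# '/data/avro' is a hypothetical directory of Avro files.
example_gen = FileBasedExampleGen(
    input_base='/data/avro',
    input_config=example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='single_split', pattern='*.avro'),
    ]),
    custom_executor_spec=executor_spec.ExecutorClassSpec(
        avro_executor.Executor))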
Example #22
    def testTrainerFn(self):
        temp_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        schema_file = os.path.join(self._testdata_path,
                                   'schema_gen/schema.pbtxt')
        data_accessor = DataAccessor(
            tf_dataset_factory=tfxio_utils.get_tf_dataset_factory_from_artifact(
                [standard_artifacts.Examples()], []),
            record_batch_factory=None)
        trainer_fn_args = trainer_executor.TrainerFnArgs(
            train_files=os.path.join(
                self._testdata_path,
                'transform/transformed_examples/train/*.gz'),
            transform_output=os.path.join(self._testdata_path,
                                          'transform/transform_graph'),
            serving_model_dir=os.path.join(temp_dir, 'serving_model_dir'),
            eval_files=os.path.join(
                self._testdata_path,
                'transform/transformed_examples/eval/*.gz'),
            schema_file=schema_file,
            train_steps=1,
            eval_steps=1,
            base_model=None,
            data_accessor=data_accessor)
        schema = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema())
        training_spec = taxi_utils.trainer_fn(trainer_fn_args, schema)

        estimator = training_spec['estimator']
        train_spec = training_spec['train_spec']
        eval_spec = training_spec['eval_spec']
        eval_input_receiver_fn = training_spec['eval_input_receiver_fn']

        self.assertIsInstance(estimator,
                              tf.estimator.DNNLinearCombinedClassifier)
        self.assertIsInstance(train_spec, tf.estimator.TrainSpec)
        self.assertIsInstance(eval_spec, tf.estimator.EvalSpec)
        self.assertIsInstance(eval_input_receiver_fn, types.FunctionType)

        # Test keep_max_checkpoint in RunConfig
        self.assertGreater(estimator._config.keep_checkpoint_max, 1)

        # Train for one step, then eval for one step.
        eval_result, exports = tf.estimator.train_and_evaluate(
            estimator, train_spec, eval_spec)
        self.assertGreater(eval_result['loss'], 0.0)
        self.assertLen(exports, 1)
        self.assertGreaterEqual(len(fileio.listdir(exports[0])), 1)

        # Export the eval saved model.
        eval_savedmodel_path = tfma.export.export_eval_savedmodel(
            estimator=estimator,
            export_dir_base=path_utils.eval_model_dir(temp_dir),
            eval_input_receiver_fn=eval_input_receiver_fn)
        self.assertGreaterEqual(len(fileio.listdir(eval_savedmodel_path)), 1)

        # Test exported serving graph.
        with tf.compat.v1.Session() as sess:
            metagraph_def = tf.compat.v1.saved_model.loader.load(
                sess, [tf.saved_model.SERVING], exports[0])
            self.assertIsInstance(metagraph_def, tf.compat.v1.MetaGraphDef)
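
The test above pins down the contract of taxi_utils.trainer_fn: it must return a dict with the keys 'estimator', 'train_spec', 'eval_spec', and 'eval_input_receiver_fn'. A minimal sketch of that contract follows; the toy input_fn, the single feature column, and the stubbed receiver function are placeholder assumptions, not the real taxi implementation.

import tensorflow as tf


def trainer_fn(trainer_fn_args, schema):
    # Placeholder input_fn; the real code builds datasets from
    # trainer_fn_args.train_files / eval_files via the data accessor.
    def _input_fn():
        features = {'x': tf.constant([[0.0], [1.0]])}
        labels = tf.constant([0, 1])
        return tf.data.Dataset.from_tensors((features, labels)).repeat()

    estimator = tf.estimator.DNNLinearCombinedClassifier(
        dnn_feature_columns=[tf.feature_column.numeric_column('x')],
        dnn_hidden_units=[4],
        # keep_checkpoint_max > 1, as the test above asserts.
        config=tf.estimator.RunConfig(keep_checkpoint_max=5))
    return {
        'estimator': estimator,
        'train_spec': tf.estimator.TrainSpec(
            _input_fn, max_steps=trainer_fn_args.train_steps),
        'eval_spec': tf.estimator.EvalSpec(
            _input_fn, steps=trainer_fn_args.eval_steps),
        # Stub only; a real implementation builds this with tfma helpers.
        'eval_input_receiver_fn': lambda: None,
    }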
Example #23
  def setUp(self):
    super(TunerTest, self).setUp()

    self.examples = channel_utils.as_channel([standard_artifacts.Examples()])
    self.schema = channel_utils.as_channel([standard_artifacts.Schema()])
Example #24
    def testExecution(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            contexts = m.register_pipeline_contexts_if_not_exists(
                self._pipeline_info)
            # Test prepare_execution.
            exec_properties = {'arg_one': 1}
            input_artifact = standard_artifacts.Examples()
            output_artifact = standard_artifacts.Examples()
            input_artifacts = {'input': [input_artifact]}
            output_artifacts = {'output': [output_artifact]}
            m.register_execution(input_artifacts=input_artifacts,
                                 exec_properties=exec_properties,
                                 pipeline_info=self._pipeline_info,
                                 component_info=self._component_info,
                                 contexts=contexts)
            [execution] = m.store.get_executions_by_context(contexts[0].id)
            # Skip verifying time sensitive fields.
            execution.ClearField('create_time_since_epoch')
            execution.ClearField('last_update_time_since_epoch')
            self.assertProtoEquals(
                """
        id: 1
        type_id: 3
        properties {
          key: "state"
          value {
            string_value: "new"
          }
        }
        properties {
          key: "pipeline_name"
          value {
            string_value: "my_pipeline"
          }
        }
        properties {
          key: "pipeline_root"
          value {
            string_value: "/tmp"
          }
        }
        properties {
          key: "run_id"
          value {
            string_value: "my_run_id"
          }
        }
        properties {
          key: "component_id"
          value {
            string_value: "my_component"
          }
        }
        properties {
          key: "arg_one"
          value {
            string_value: "1"
          }
        }""", execution)

            # Test publish_execution.
            m.publish_execution(component_info=self._component_info,
                                output_artifacts=output_artifacts)
            # Make sure artifacts in output_dict are published.
            self.assertEqual(ArtifactState.PUBLISHED, output_artifact.state)
            # Make sure the execution state is changed.
            [execution] = m.store.get_executions_by_id([execution.id])
            self.assertEqual(metadata.EXECUTION_STATE_COMPLETE,
                             execution.properties['state'].string_value)
            # Make sure events are published.
            events = m.store.get_events_by_execution_ids([execution.id])
            self.assertLen(events, 2)
            self.assertEqual(input_artifact.id, events[0].artifact_id)
            self.assertEqual(metadata_store_pb2.Event.INPUT, events[0].type)
            self.assertProtoEquals(
                """
          steps {
            key: "input"
          }
          steps {
            index: 0
          }""", events[0].path)
            self.assertEqual(output_artifact.id, events[1].artifact_id)
            self.assertEqual(metadata_store_pb2.Event.OUTPUT, events[1].type)
            self.assertProtoEquals(
                """
          steps {
            key: "output"
          }
          steps {
            index: 0
          }""", events[1].path)
Example #25
    def testGetArtifactsDict(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            # Create and shuffle a few artifacts. The shuffled order should be
            # retained in the output of `execution_lib.get_artifacts_dict`.
            input_artifact_keys = ('input1', 'input2', 'input3')
            input_artifacts_dict = collections.OrderedDict()
            for input_key in input_artifact_keys:
                input_examples = []
                for i in range(10):
                    input_example = standard_artifacts.Examples()
                    input_example.uri = f'{input_key}/example{i}'
                    input_example.type_id = common_utils.register_type_if_not_exist(
                        m, input_example.artifact_type).id
                    input_examples.append(input_example)
                random.shuffle(input_examples)
                input_artifacts_dict[input_key] = input_examples

            output_models = []
            for i in range(8):
                output_model = standard_artifacts.Model()
                output_model.uri = f'model{i}'
                output_model.type_id = common_utils.register_type_if_not_exist(
                    m, output_model.artifact_type).id
                output_models.append(output_model)
            random.shuffle(output_models)
            output_artifacts_dict = {'model': output_models}

            # Store input artifacts only. Outputs will be saved in put_execution().
            input_mlmd_artifacts = [
                a.mlmd_artifact
                for a in itertools.chain(*input_artifacts_dict.values())
            ]
            artifact_ids = m.store.put_artifacts(input_mlmd_artifacts)
            for artifact_id, mlmd_artifact in zip(artifact_ids,
                                                  input_mlmd_artifacts):
                mlmd_artifact.id = artifact_id

            execution = execution_lib.prepare_execution(
                m,
                metadata_store_pb2.ExecutionType(name='my_execution_type'),
                state=metadata_store_pb2.Execution.RUNNING)
            contexts = self._generate_contexts(m)

            # Change the order of the OrderedDict to shuffle the order of input keys.
            input_artifacts_dict.move_to_end('input1')
            execution = execution_lib.put_execution(
                m,
                execution,
                contexts,
                input_artifacts=input_artifacts_dict,
                output_artifacts=output_artifacts_dict)

            # Verify that the same artifacts are returned in the correct order.
            artifacts_dict = execution_lib.get_artifacts_dict(
                m, execution.id, [metadata_store_pb2.Event.INPUT])
            self.assertEqual(set(input_artifact_keys),
                             set(artifacts_dict.keys()))
            for key in artifacts_dict:
                self.assertEqual([ex.uri for ex in input_artifacts_dict[key]],
                                 [a.uri for a in artifacts_dict[key]],
                                 f'for key={key}')
            artifacts_dict = execution_lib.get_artifacts_dict(
                m, execution.id, [metadata_store_pb2.Event.OUTPUT])
            self.assertEqual({'model'}, set(artifacts_dict.keys()))
            self.assertEqual([model.uri for model in output_models],
                             [a.uri for a in artifacts_dict['model']])
Example #26
 def testGetCachedOutputArtifacts(self):
   # Output artifacts that will be used by the first execution with the same
   # cache key.
   output_model_one = standard_artifacts.Model()
   output_model_one.uri = 'model_one'
   output_model_two = standard_artifacts.Model()
   output_model_two.uri = 'model_two'
   output_example_one = standard_artifacts.Examples()
   output_example_one.uri = 'example_one'
   # Output artifacts that will be used by the second execution with the same
   # cache key.
   output_model_three = standard_artifacts.Model()
   output_model_three.uri = 'model_three'
   output_model_four = standard_artifacts.Model()
   output_model_four.uri = 'model_four'
   output_example_two = standard_artifacts.Examples()
   output_example_two.uri = 'example_two'
   output_models_key = 'output_models'
   output_examples_key = 'output_examples'
   with metadata.Metadata(connection_config=self._connection_config) as m:
     cache_context = context_lib.register_context_if_not_exists(
         m, context_lib.CONTEXT_TYPE_EXECUTION_CACHE, 'cache_key')
     execution_one = execution_publish_utils.register_execution(
         m, metadata_store_pb2.ExecutionType(name='my_type'), [cache_context])
     execution_publish_utils.publish_succeeded_execution(
         m,
         execution_one.id, [cache_context],
         output_artifacts={
             output_models_key: [output_model_one, output_model_two],
             output_examples_key: [output_example_one]
         })
     execution_two = execution_publish_utils.register_execution(
         m, metadata_store_pb2.ExecutionType(name='my_type'), [cache_context])
     execution_publish_utils.publish_succeeded_execution(
         m,
         execution_two.id, [cache_context],
         output_artifacts={
             output_models_key: [output_model_three, output_model_four],
             output_examples_key: [output_example_two]
         })
      # The cached outputs returned should be the artifacts produced by the
      # most recent execution under the given cache context.
     cached_output = cache_utils.get_cached_outputs(m, cache_context)
     self.assertLen(cached_output, 2)
     self.assertLen(cached_output[output_models_key], 2)
     self.assertLen(cached_output[output_examples_key], 1)
     self.assertProtoPartiallyEquals(
         cached_output[output_models_key][0].mlmd_artifact,
         output_model_three.mlmd_artifact,
         ignored_fields=[
             'create_time_since_epoch', 'last_update_time_since_epoch'
         ])
     self.assertProtoPartiallyEquals(
         cached_output[output_models_key][1].mlmd_artifact,
         output_model_four.mlmd_artifact,
         ignored_fields=[
             'create_time_since_epoch', 'last_update_time_since_epoch'
         ])
     self.assertProtoPartiallyEquals(
         cached_output[output_examples_key][0].mlmd_artifact,
         output_example_two.mlmd_artifact,
         ignored_fields=[
             'create_time_since_epoch', 'last_update_time_since_epoch'
         ])
Example #27
    def __init__(
            self,
            input: Optional[types.Channel] = None,  # pylint: disable=redefined-builtin
            input_config: Optional[Union[example_gen_pb2.Input,
                                         Dict[Text, Any]]] = None,
            output_config: Optional[Union[example_gen_pb2.Output,
                                          Dict[Text, Any]]] = None,
            custom_config: Optional[Union[example_gen_pb2.CustomConfig,
                                          Dict[Text, Any]]] = None,
            example_artifacts: Optional[types.Channel] = None,
            custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
            input_base: Optional[types.Channel] = None,
            instance_name: Optional[Text] = None):
        """Construct a FileBasedExampleGen component.

    Args:
      input: A Channel of type `standard_artifacts.ExternalArtifact`, which
        includes one artifact whose uri is an external directory containing the
        data files. _required_
      input_config: An
        [`example_gen_pb2.Input`](https://github.com/tensorflow/tfx/blob/master/tfx/proto/example_gen.proto)
          instance, providing input configuration. If unset, the files under
          input_base will be treated as a single dataset.
      output_config: An example_gen_pb2.Output instance, providing the output
        configuration. If unset, the default splits will be 'train' and
        'eval', with a 2:1 size ratio.
      custom_config: An optional example_gen_pb2.CustomConfig instance,
        providing custom configuration for the executor.
      example_artifacts: Channel of 'ExamplesPath' for output train and eval
        examples.
      custom_executor_spec: Optional custom executor spec overriding the default
        executor spec specified in the component attribute.
      input_base: Backwards compatibility alias for the 'input' argument.
      instance_name: Optional unique instance name. Required only if multiple
        ExampleGen components are declared in the same pipeline.

    Either `input_base` or `input` must be present in the input arguments.
    """
        if input_base:
            absl.logging.warning(
                'The "input_base" argument to the ExampleGen component has '
                'been renamed to "input" and is deprecated. Please update your '
                'usage as support for this argument will be removed soon.')
            input = input_base
        # Configure inputs and outputs.
        input_config = input_config or utils.make_default_input_config()
        output_config = output_config or utils.make_default_output_config(
            input_config)
        if not example_artifacts:
            artifact = standard_artifacts.Examples()
            artifact.split_names = artifact_utils.encode_split_names(
                utils.generate_output_split_names(input_config, output_config))
            example_artifacts = channel_utils.as_channel([artifact])
        spec = FileBasedExampleGenSpec(input=input,
                                       input_config=input_config,
                                       output_config=output_config,
                                       custom_config=custom_config,
                                       examples=example_artifacts)
        super(FileBasedExampleGen, self).__init__(
            spec=spec,
            custom_executor_spec=custom_executor_spec,
            instance_name=instance_name)
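
And a hedged usage sketch for this older, channel-based signature; the Parquet executor, the import paths, and the data directory are assumptions about a TFX release that still ships tfx.utils.dsl_utils.external_input:

# A sketch only; paths and executor choice are illustrative assumptions.
from tfx.components.example_gen.component import FileBasedExampleGen
from tfx.components.example_gen.custom_executors import parquet_executor
from tfx.components.base import executor_spec
from tfx.utils.dsl_utils import external_input

# '/data/parquet' is a hypothetical directory of Parquet files.
example_gen = FileBasedExampleGen(
    input=external_input('/data/parquet'),
    custom_executor_spec=executor_spec.ExecutorClassSpec(
        parquet_executor.Executor))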