Пример #1
0
  def setUp(self):
    """Builds single-artifact input channels and train/eval args for tests."""
    super(ComponentTest, self).setUp()

    def _channel_of(artifact_type):
      # Every input channel wraps exactly one freshly constructed artifact.
      return channel_utils.as_channel([artifact_type()])

    self.examples = _channel_of(standard_artifacts.Examples)
    self.transform_output = _channel_of(standard_artifacts.TransformGraph)
    self.schema = _channel_of(standard_artifacts.Schema)
    self.hyperparameters = _channel_of(standard_artifacts.HyperParameters)

    # Step budgets: 100 steps over the 'train' split, 50 over 'eval'.
    self.train_args = trainer_pb2.TrainArgs(num_steps=100, splits=['train'])
    self.eval_args = trainer_pb2.EvalArgs(num_steps=50, splits=['eval'])
    def setUp(self):
        """Creates artifacts that point at pre-generated intermediate data.

        Every artifact URI is rooted at `self._intermediate_data_root`; the
        resulting single- or two-element lists are stored on the instance.
        """
        super(KubeflowGCPIntegrationTest, self).setUp()

        def _located(artifact_cls, relative_uri, **ctor_kwargs):
            # Build an artifact and place it under the intermediate data root.
            located = artifact_cls(**ctor_kwargs)
            located.uri = os.path.join(self._intermediate_data_root,
                                       relative_uri)
            return located

        # Raw Example artifacts for testing.
        self._test_raw_examples = [
            _located(standard_artifacts.Examples,
                     'csv_example_gen/examples/test-pipeline/train/',
                     split='train'),
            _located(standard_artifacts.Examples,
                     'csv_example_gen/examples/test-pipeline/eval/',
                     split='eval'),
        ]

        # Transformed Example artifacts for testing.
        self._test_transformed_examples = [
            _located(standard_artifacts.Examples,
                     'transform/transformed_examples/test-pipeline/train/',
                     split='train'),
            _located(standard_artifacts.Examples,
                     'transform/transformed_examples/test-pipeline/eval/',
                     split='eval'),
        ]

        # Single-artifact lists: schema, transform graph, model, blessing.
        self._test_schema = [
            _located(standard_artifacts.Schema,
                     'schema_gen/output/test-pipeline/')
        ]
        self._test_transform_graph = [
            _located(standard_artifacts.TransformGraph,
                     'transform/transform_output/test-pipeline/')
        ]
        self._test_model = [
            _located(standard_artifacts.Model, 'trainer/output/test-pipeline/')
        ]
        self._test_model_blessing = [
            _located(standard_artifacts.ModelBlessing,
                     'model_validator/blessing/test-pipeline/')
        ]
Пример #3
0
    def _make_base_do_params(self, source_data_dir, output_data_dir):
        """Populates input/output dicts and exec properties for Do() tests.

        Two Examples artifacts (each declaring 'train' and 'eval' splits) are
        kept in `self._example_artifacts`, but only the first is wired into
        `self._input_dict`. Likewise only the first transformed-examples
        artifact is placed in `self._output_dict`.

        Args:
          source_data_dir: Directory containing the test inputs; the schema
            artifact points at its 'schema_gen' subdirectory.
          output_data_dir: Directory under which every output artifact created
            here (graph, transformed examples, analyzer cache) is placed.
        """
        # Create input dict.
        example1 = standard_artifacts.Examples()
        example1.uri = self._ARTIFACT1_URI
        example1.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        example2 = copy.deepcopy(example1)
        example2.uri = self._ARTIFACT2_URI

        self._example_artifacts = [example1, example2]

        schema_artifact = standard_artifacts.Schema()
        schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen')

        self._input_dict = {
            standard_component_specs.EXAMPLES_KEY: self._example_artifacts[:1],
            standard_component_specs.SCHEMA_KEY: [schema_artifact],
        }

        # Create output dict.
        self._transformed_output = standard_artifacts.TransformGraph()
        self._transformed_output.uri = os.path.join(output_data_dir,
                                                    'transformed_graph')
        transformed1 = standard_artifacts.Examples()
        transformed1.uri = os.path.join(output_data_dir,
                                        'transformed_examples', '0')
        transformed2 = standard_artifacts.Examples()
        transformed2.uri = os.path.join(output_data_dir,
                                        'transformed_examples', '1')

        self._transformed_example_artifacts = [transformed1, transformed2]

        temp_path_output = _TempPath()
        temp_path_output.uri = tempfile.mkdtemp()
        # Root the cache under the `output_data_dir` argument like every other
        # output above. Reading self._output_data_dir here silently ignored
        # the argument and required that attribute to already exist.
        self._updated_analyzer_cache_artifact = (
            standard_artifacts.TransformCache())
        self._updated_analyzer_cache_artifact.uri = os.path.join(
            output_data_dir, 'CACHE')

        self._output_dict = {
            standard_component_specs.TRANSFORM_GRAPH_KEY:
            [self._transformed_output],
            standard_component_specs.TRANSFORMED_EXAMPLES_KEY:
            self._transformed_example_artifacts[:1],
            executor.TEMP_PATH_KEY: [temp_path_output],
            standard_component_specs.UPDATED_ANALYZER_CACHE_KEY:
            [self._updated_analyzer_cache_artifact],
        }

        # Create exec properties skeleton.
        self._exec_properties = {}
Пример #4
0
  def setUp(self):
    """Builds Trainer executor inputs, outputs and exec properties.

    Inputs point at the package's 'testdata' directory. The exported model is
    written under a per-test directory inside TEST_UNDECLARED_OUTPUTS_DIR, or
    inside the test temp dir when that variable is unset.
    """
    super(ExecutorTest, self).setUp()
    self._source_data_dir = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')
    self._output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Create input dict.
    train_examples = standard_artifacts.Examples(split='train')
    train_examples.uri = os.path.join(self._source_data_dir,
                                      'transform/transformed_examples/train/')
    eval_examples = standard_artifacts.Examples(split='eval')
    eval_examples.uri = os.path.join(self._source_data_dir,
                                     'transform/transformed_examples/eval/')
    transform_output = standard_artifacts.TransformGraph()
    transform_output.uri = os.path.join(self._source_data_dir,
                                        'transform/transform_output/')
    # The schema input must be a Schema artifact. It was previously built as
    # standard_artifacts.Examples, giving the 'schema' key the wrong type.
    schema = standard_artifacts.Schema()
    schema.uri = os.path.join(self._source_data_dir, 'schema_gen/')

    self._input_dict = {
        'examples': [train_examples, eval_examples],
        'transform_output': [transform_output],
        'schema': [schema],
    }

    # Create output dict.
    self._model_exports = standard_artifacts.Model()
    self._model_exports.uri = os.path.join(self._output_data_dir,
                                           'model_export_path')
    self._output_dict = {'output': [self._model_exports]}

    # Create exec properties skeleton.
    self._exec_properties = {
        'train_args':
            json_format.MessageToJson(trainer_pb2.TrainArgs(num_steps=1000)),
        'eval_args':
            json_format.MessageToJson(trainer_pb2.EvalArgs(num_steps=500)),
        'warm_starting':
            False,
    }

    self._module_file = os.path.join(self._source_data_dir, 'module_file',
                                     'trainer_module.py')
    self._trainer_fn = '%s.%s' % (trainer_module.trainer_fn.__module__,
                                  trainer_module.trainer_fn.__name__)

    # Executor for test.
    self._trainer_executor = executor.Executor()
Пример #5
0
    def testGetCommonFnArgs(self):
        """get_common_fn_args derives FnArgs fields from inputs and props."""
        testdata_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')

        # Inputs: one two-split Examples artifact, a TransformGraph, a Schema.
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(testdata_dir,
                                    'transform/transformed_examples')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        transform_output = standard_artifacts.TransformGraph()
        transform_output.uri = os.path.join(testdata_dir,
                                            'transform/transform_graph')
        schema = standard_artifacts.Schema()
        schema.uri = os.path.join(testdata_dir, 'schema_gen')
        input_dict = {
            constants.EXAMPLES_KEY: [examples],
            constants.TRANSFORM_GRAPH_KEY: [transform_output],
            constants.SCHEMA_KEY: [schema],
        }

        def _to_json(message):
            # Serialize with the proto's own (snake_case) field names.
            return json_format.MessageToJson(
                message, preserving_proto_field_name=True)

        exec_properties = {
            'train_args': _to_json(trainer_pb2.TrainArgs(num_steps=1000)),
            'eval_args': _to_json(trainer_pb2.EvalArgs(num_steps=500)),
        }

        fn_args = fn_args_utils.get_common_fn_args(input_dict, exec_properties,
                                                   'tempdir')

        self.assertEqual(fn_args.working_dir, 'tempdir')
        self.assertEqual(fn_args.train_steps, 1000)
        self.assertEqual(fn_args.eval_steps, 500)
        # Each split resolves to exactly one glob under the examples URI.
        for split_files, split_name in ((fn_args.train_files, 'train'),
                                        (fn_args.eval_files, 'eval')):
            self.assertLen(split_files, 1)
            self.assertEqual(split_files[0],
                             os.path.join(examples.uri, split_name, '*'))
        self.assertEqual(fn_args.schema_path,
                         os.path.join(schema.uri, 'schema.pbtxt'))
        self.assertEqual(fn_args.transform_graph_path, transform_output.uri)
        self.assertIsInstance(fn_args.data_accessor,
                              fn_args_utils.DataAccessor)
Пример #6
0
 def testConstructTransformGraph(self):
     """Checks TransformGraphPusher's input and output channel type names."""
     push_root = os.path.join(
         os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
         self._testMethodName)
     graph_channel = channel_utils.as_channel(
         [standard_artifacts.TransformGraph()])
     destination = pusher_pb2.PushDestination(
         filesystem=pusher_pb2.PushDestination.Filesystem(
             base_directory=push_root))
     pusher = component.TransformGraphPusher(
         artifact=graph_channel, push_destination=destination)
     # Both the consumed and the produced channel carry TransformGraph.
     for channel in (pusher.inputs.artifact, pusher.outputs.pushed_artifact):
         self.assertEqual('TransformGraph', channel.type_name)
Пример #7
0
    def setUp(self):
        """Builds executor inputs and outputs over the local 'testdata' dir.

        Outputs go to a per-test directory under TEST_UNDECLARED_OUTPUTS_DIR,
        or under a new temporary directory inside FLAGS.test_tmpdir when that
        environment variable is not set.
        """
        super(ExecutorTest, self).setUp()

        # Create input_dict.
        self._input_data_dir = os.path.join(os.path.dirname(__file__),
                                            'testdata')
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(self._input_data_dir, 'example_gen')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        schema_artifact = standard_artifacts.Schema()
        schema_artifact.uri = os.path.join(self._input_data_dir, 'schema_gen')
        self._input_dict = {
            standard_component_specs.EXAMPLES_KEY: [examples],
            standard_component_specs.SCHEMA_KEY: [schema_artifact],
        }

        # Create output_dict.
        # Only create the fallback temp dir when it is needed. Passing
        # tempfile.mkdtemp(...) as the os.environ.get() default evaluates it
        # eagerly, leaving an unused directory behind whenever
        # TEST_UNDECLARED_OUTPUTS_DIR is set.
        outputs_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR')
        if outputs_root is None:
            outputs_root = tempfile.mkdtemp(dir=flags.FLAGS.test_tmpdir)
        output_data_dir = os.path.join(outputs_root, self._testMethodName)
        self._transformed_output = standard_artifacts.TransformGraph()
        self._transformed_output.uri = os.path.join(output_data_dir,
                                                    'transformed_output')
        self._transformed_examples = standard_artifacts.Examples()
        self._transformed_examples.uri = output_data_dir
        self._transformed_examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        temp_path_output = _TempPath()
        temp_path_output.uri = tempfile.mkdtemp()
        self._output_dict = {
            standard_component_specs.TRANSFORM_GRAPH_KEY:
            [self._transformed_output],
            standard_component_specs.TRANSFORMED_EXAMPLES_KEY:
            [self._transformed_examples],
            tfx_executor.TEMP_PATH_KEY: [temp_path_output],
        }

        # Create exec properties.
        self._exec_properties = {
            'custom_config':
            json.dumps({'problem_statement_path': '/some/fake/path'})
        }
Пример #8
0
    def setUp(self):
        """Creates the input channels and argument protos used by the tests."""
        super(ComponentTest, self).setUp()

        def _channel(single_artifact):
            # Wrap one artifact instance in a channel.
            return channel_utils.as_channel([single_artifact])

        multi_split_examples = standard_artifacts.Examples()
        multi_split_examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        graph_artifact = standard_artifacts.TransformGraph()

        self.examples = _channel(multi_split_examples)
        self.schema = _channel(standard_artifacts.Schema())
        self.transform_graph = _channel(graph_artifact)
        self.custom_config = {'some': 'thing', 'some other': 1, 'thing': 2}

        # Argument protos: 100 train steps, 50 eval steps, 3 parallel trials.
        self.train_args = trainer_pb2.TrainArgs(num_steps=100)
        self.eval_args = trainer_pb2.EvalArgs(num_steps=50)
        self.tune_args = tuner_pb2.TuneArgs(num_parallel_trials=3)

        self.warmup_hyperparams = _channel(
            artifacts.KCandidateHyperParameters())
        self.meta_model = _channel(standard_artifacts.Model())
Пример #9
0
    def _make_base_do_params(self, source_data_dir, output_data_dir):
        """Fills per-split input/output dicts and an empty exec-property dict.

        Inputs and outputs each hold one Examples artifact per split ('train'
        and 'eval'), keyed by plain string names.

        Args:
          source_data_dir: Root of the CSV example splits and the schema.
          output_data_dir: Root under which the transform graph and the
            transformed example splits are placed.
        """

        def _examples_at(split, root, relative_uri):
            # One Examples artifact tagged with `split`, located at root/rel.
            examples_artifact = standard_artifacts.Examples(split=split)
            examples_artifact.uri = os.path.join(root, relative_uri)
            return examples_artifact

        # Inputs: CSV example splits plus a schema.
        raw_splits = [
            _examples_at('train', source_data_dir, 'csv_example_gen/train/'),
            _examples_at('eval', source_data_dir, 'csv_example_gen/eval/'),
        ]
        schema_artifact = standard_artifacts.Schema()
        schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen/')
        self._input_dict = {
            'input_data': raw_splits,
            'schema': [schema_artifact],
        }

        # Outputs: transform graph, transformed splits and a TempPath artifact.
        self._transformed_output = standard_artifacts.TransformGraph()
        self._transformed_output.uri = os.path.join(output_data_dir,
                                                    'transformed_output')
        self._transformed_train_examples = _examples_at(
            'train', output_data_dir, 'train')
        self._transformed_eval_examples = _examples_at(
            'eval', output_data_dir, 'eval')
        temp_dir_artifact = types.Artifact('TempPath')
        temp_dir_artifact.uri = tempfile.mkdtemp()
        self._output_dict = {
            'transform_output': [self._transformed_output],
            'transformed_examples': [
                self._transformed_train_examples,
                self._transformed_eval_examples,
            ],
            'temp_path': [temp_dir_artifact],
        }

        # Create exec properties skeleton.
        self._exec_properties = {}
Пример #10
0
    def _make_base_do_params(self, source_data_dir, output_data_dir):
        """Populates input/output dicts and exec properties for Do() tests.

        Args:
          source_data_dir: Directory containing the test inputs; examples are
            read from its 'csv_example_gen' subdirectory (with 'train' and
            'eval' splits) and the schema from 'schema_gen'.
          output_data_dir: Directory under which every output artifact created
            here (graph, transformed examples, analyzer cache) is placed.
        """
        # Create input dict.
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        schema_artifact = standard_artifacts.Schema()
        schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen')

        self._input_dict = {
            executor.EXAMPLES_KEY: [examples],
            executor.SCHEMA_KEY: [schema_artifact],
        }

        # Create output dict.
        self._transformed_output = standard_artifacts.TransformGraph()
        self._transformed_output.uri = os.path.join(output_data_dir,
                                                    'transformed_graph')
        self._transformed_examples = standard_artifacts.Examples()
        self._transformed_examples.uri = os.path.join(output_data_dir,
                                                      'transformed_examples')
        temp_path_output = _TempPath()
        temp_path_output.uri = tempfile.mkdtemp()
        # Root the cache under the `output_data_dir` argument like every other
        # output above. Reading self._output_data_dir here silently ignored
        # the argument and required that attribute to already exist.
        self._updated_analyzer_cache_artifact = (
            standard_artifacts.TransformCache())
        self._updated_analyzer_cache_artifact.uri = os.path.join(
            output_data_dir, 'CACHE')

        self._output_dict = {
            executor.TRANSFORM_GRAPH_KEY: [self._transformed_output],
            executor.TRANSFORMED_EXAMPLES_KEY: [self._transformed_examples],
            executor.TEMP_PATH_KEY: [temp_path_output],
            executor.UPDATED_ANALYZER_CACHE_KEY:
            [self._updated_analyzer_cache_artifact],
        }

        # Create exec properties skeleton.
        self._exec_properties = {}
Пример #11
0
    def setUp(self):
        """Builds Trainer executor fixtures from the shared testdata directory.

        Creates single- and multi-artifact Examples inputs, transform graph,
        schema and base-model inputs, model and model-run outputs, exec
        properties, and `executor.Executor` / `executor.GenericExecutor`
        instances.
        """
        super(ExecutorTest, self).setUp()
        self._source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        self._output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create input dict.
        e1 = standard_artifacts.Examples()
        e1.uri = os.path.join(self._source_data_dir,
                              'transform/transformed_examples')
        e1.split_names = artifact_utils.encode_split_names(['train', 'eval'])

        e2 = copy.deepcopy(e1)

        self._single_artifact = [e1]
        self._multiple_artifacts = [e1, e2]

        transform_graph = standard_artifacts.TransformGraph()
        transform_graph.uri = os.path.join(self._source_data_dir,
                                           'transform/transform_graph')

        schema = standard_artifacts.Schema()
        schema.uri = os.path.join(self._source_data_dir, 'schema_gen')
        previous_model = standard_artifacts.Model()
        previous_model.uri = os.path.join(self._source_data_dir,
                                          'trainer/previous')

        self._input_dict = {
            standard_component_specs.EXAMPLES_KEY: self._single_artifact,
            standard_component_specs.TRANSFORM_GRAPH_KEY: [transform_graph],
            standard_component_specs.SCHEMA_KEY: [schema],
            standard_component_specs.BASE_MODEL_KEY: [previous_model]
        }

        # Create output dict.
        self._model_exports = standard_artifacts.Model()
        self._model_exports.uri = os.path.join(self._output_data_dir,
                                               'model_export_path')
        self._model_run_exports = standard_artifacts.ModelRun()
        self._model_run_exports.uri = os.path.join(self._output_data_dir,
                                                   'model_run_path')
        self._output_dict = {
            standard_component_specs.MODEL_KEY: [self._model_exports],
            standard_component_specs.MODEL_RUN_KEY: [self._model_run_exports]
        }

        # Create exec properties skeleton.
        self._exec_properties = {
            standard_component_specs.TRAIN_ARGS_KEY:
            proto_utils.proto_to_json(trainer_pb2.TrainArgs(num_steps=1000)),
            standard_component_specs.EVAL_ARGS_KEY:
            proto_utils.proto_to_json(trainer_pb2.EvalArgs(num_steps=500)),
            'warm_starting':
            False,
        }

        # 'module_file' here is a testdata directory name, so spell it out
        # rather than reusing standard_component_specs.MODULE_FILE_KEY (an
        # exec-property key); renaming that key must not move the test data.
        self._module_file = os.path.join(self._source_data_dir, 'module_file',
                                         'trainer_module.py')
        self._trainer_fn = '%s.%s' % (trainer_module.trainer_fn.__module__,
                                      trainer_module.trainer_fn.__name__)

        # Executors for test.
        self._trainer_executor = executor.Executor()
        self._generic_trainer_executor = executor.GenericExecutor()
Пример #12
0
    def __init__(
            self,
            examples: Optional[types.Channel] = None,
            schema: Optional[types.Channel] = None,
            module_file: Optional[Union[Text,
                                        data_types.RuntimeParameter]] = None,
            preprocessing_fn: Optional[Union[
                Text, data_types.RuntimeParameter]] = None,
            transform_graph: Optional[types.Channel] = None,
            transformed_examples: Optional[types.Channel] = None,
            input_data: Optional[types.Channel] = None,
            instance_name: Optional[Text] = None,
            materialize: bool = True):
        """Construct a Transform component.

        Args:
          examples: A Channel of type `standard_artifacts.Examples` (required,
            unless supplied through the deprecated `input_data` alias). This
            should contain the two splits 'train' and 'eval'.
          schema: A Channel of type `standard_artifacts.Schema`. This should
            contain a single schema artifact.
          module_file: The file path to a python module file, from which the
            'preprocessing_fn' function will be loaded. The function must have
            the following signature.

            def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
              ...

            where the values of input and returned Dict are either tf.Tensor or
            tf.SparseTensor.  Exactly one of 'module_file' or
            'preprocessing_fn' must be supplied.
          preprocessing_fn: The path to python function that implements a
            'preprocessing_fn'. See 'module_file' for expected signature of the
            function. Exactly one of 'module_file' or 'preprocessing_fn' must
            be supplied.
          transform_graph: Optional output channel of
            `standard_artifacts.TransformGraph` for the output of
            'tf.Transform', which includes an exported Tensorflow graph
            suitable for both training and serving. When not supplied, a
            channel holding one new TransformGraph artifact is created.
          transformed_examples: Optional output channel of
            `standard_artifacts.Examples` for materialized transformed
            examples, which includes both 'train' and 'eval' splits. When not
            supplied and `materialize` is True, a channel holding one new
            Examples artifact is created; when `materialize` is False it stays
            None.
          input_data: Deprecated backwards compatibility alias for the
            'examples' argument. When given, it takes precedence over
            `examples` and a deprecation warning is logged.
          instance_name: Optional unique instance name. Necessary iff multiple
            transform components are declared in the same pipeline.
          materialize: If True, write transformed examples as an output. If
            False, `transformed_examples` must not be provided.

        Raises:
          ValueError: When both or neither of 'module_file' and
            'preprocessing_fn' is supplied, or when `transformed_examples` is
            supplied while `materialize` is False.
        """
        if input_data:
            absl.logging.warning(
                'The "input_data" argument to the Transform component has '
                'been renamed to "examples" and is deprecated. Please update your '
                'usage as support for this argument will be removed soon.')
            examples = input_data
        if bool(module_file) == bool(preprocessing_fn):
            raise ValueError(
                "Exactly one of 'module_file' or 'preprocessing_fn' must be supplied."
            )

        transform_graph = transform_graph or types.Channel(
            type=standard_artifacts.TransformGraph,
            artifacts=[standard_artifacts.TransformGraph()])
        if materialize and transformed_examples is None:
            transformed_examples = types.Channel(
                type=standard_artifacts.Examples,
                # TODO(b/161548528): remove the hardcode artifact.
                artifacts=[standard_artifacts.Examples()],
                matching_channel_name='examples')
        elif not materialize and transformed_examples is not None:
            # Non-materializing runs produce no examples output, so an
            # explicitly supplied channel would never be written.
            raise ValueError(
                'must not specify transformed_examples when materialize==False'
            )
        spec = TransformSpec(examples=examples,
                             schema=schema,
                             module_file=module_file,
                             preprocessing_fn=preprocessing_fn,
                             transform_graph=transform_graph,
                             transformed_examples=transformed_examples)
        super(Transform, self).__init__(spec=spec, instance_name=instance_name)
Пример #13
0
    def __init__(self,
                 input_data: types.Channel = None,
                 schema: types.Channel = None,
                 module_file: Optional[Text] = None,
                 preprocessing_fn: Optional[Text] = None,
                 transform_output: Optional[types.Channel] = None,
                 transformed_examples: Optional[types.Channel] = None,
                 examples: Optional[types.Channel] = None,
                 instance_name: Optional[Text] = None):
        """Construct a Transform component.

        Args:
          input_data: Channel of `standard_artifacts.Examples` artifacts to
            transform; expected to contain the 'train' and 'eval' splits.
            When it is not supplied (falsy), `examples` is used instead.
          schema: Channel expected to hold a single `standard_artifacts.Schema`
            artifact.
          module_file: Path to a python module file from which the
            'preprocessing_fn' function is loaded. The function must have the
            following signature.

            def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
              ...

            where the values of the input and returned Dict are either
            tf.Tensor or tf.SparseTensor. Exactly one of 'module_file' or
            'preprocessing_fn' must be supplied.
          preprocessing_fn: Path to a python function implementing
            'preprocessing_fn'; see 'module_file' for the expected signature.
            Exactly one of 'module_file' or 'preprocessing_fn' must be
            supplied.
          transform_output: Optional output channel of
            `standard_artifacts.TransformGraph` receiving the 'tf.Transform'
            output, which includes an exported Tensorflow graph suitable for
            both training and serving. Defaults to a channel holding one new
            TransformGraph artifact.
          transformed_examples: Optional output channel of
            `standard_artifacts.Examples` for the materialized transformed
            examples. Defaults to a channel with one Examples artifact per
            split in `artifact.DEFAULT_EXAMPLE_SPLITS`.
          examples: Forwards compatibility alias for the 'input_data' argument.
          instance_name: Optional unique instance name. Necessary iff multiple
            transform components are declared in the same pipeline.

        Raises:
          ValueError: When both or neither of 'module_file' and
            'preprocessing_fn' is supplied.
        """
        # Exactly one source for the preprocessing function is accepted.
        if bool(module_file) == bool(preprocessing_fn):
            raise ValueError(
                "Exactly one of 'module_file' or 'preprocessing_fn' must be supplied."
            )
        if not input_data:
            input_data = examples

        if not transform_output:
            transform_output = types.Channel(
                type=standard_artifacts.TransformGraph,
                artifacts=[standard_artifacts.TransformGraph()])
        if not transformed_examples:
            per_split_artifacts = [
                standard_artifacts.Examples(split=split_name)
                for split_name in artifact.DEFAULT_EXAMPLE_SPLITS
            ]
            transformed_examples = types.Channel(
                type=standard_artifacts.Examples,
                artifacts=per_split_artifacts)

        super(Transform, self).__init__(
            spec=TransformSpec(input_data=input_data,
                               schema=schema,
                               module_file=module_file,
                               preprocessing_fn=preprocessing_fn,
                               transform_output=transform_output,
                               transformed_examples=transformed_examples),
            instance_name=instance_name)
Пример #14
0
    def __init__(
            self,
            examples: types.Channel = None,
            schema: types.Channel = None,
            module_file: Optional[Union[Text,
                                        data_types.RuntimeParameter]] = None,
            preprocessing_fn: Optional[Union[
                Text, data_types.RuntimeParameter]] = None,
            custom_config: Optional[Dict[Text, Any]] = None,
            transform_graph: Optional[types.Channel] = None,
            transformed_examples: Optional[types.Channel] = None,
            instance_name: Optional[Text] = None):
        # pyformat: disable
        # pylint: disable=g-doc-args
        """Construct a Transform component.

    Args:
      examples: Channel of `standard_artifacts.Examples` (required), expected
        to contain the 'train' and 'eval' splits.
      schema: Channel of `standard_artifacts.Schema`, expected to hold a single
        schema artifact.
      module_file: Path to a python module file from which the
        'preprocessing_fn' function is loaded. The function must have the
        following signature.

        def preprocessing_fn(inputs: Dict[Text, Any],
                             schema: schema_pb2.Schema,
                             custom_config: Dict[Text, Any]) -> Dict[Text, Any]:
          ...

        where the values of the input and returned Dict are either tf.Tensor
        or tf.SparseTensor. The 'schema' and 'custom_config' arguments are
        optional and may be omitted. Exactly one of 'module_file' or
        'preprocessing_fn' must be supplied.
      preprocessing_fn: Path to a python function implementing
        'preprocessing_fn'; see 'module_file' for the expected signature.
        Exactly one of 'module_file' or 'preprocessing_fn' must be supplied.
      custom_config: Dict of additional transform parameters passed into the
        preprocessing_fn. It is stored on the spec as a JSON string.
      transform_graph: Optional output channel of
        `standard_artifacts.TransformGraph` for the 'tf.Transform' output,
        which includes an exported Tensorflow graph suitable for both training
        and serving. Defaults to a channel holding one new TransformGraph.
      transformed_examples: Optional output channel of
        `standard_artifacts.Examples` for the materialized transformed
        examples. Defaults to a channel holding one Examples artifact whose
        split names are `artifact.DEFAULT_EXAMPLE_SPLITS`.
      instance_name: Optional unique instance name. Necessary iff multiple
        transform components are declared in the same pipeline.

    Raises:
      ValueError: When both or neither of 'module_file' and 'preprocessing_fn'
        is supplied.
    """
        # pyformat: enable
        # pylint: enable=g-doc-args
        if bool(module_file) == bool(preprocessing_fn):
            raise ValueError(
                "Exactly one of 'module_file' or 'preprocessing_fn' must be supplied."
            )

        if not transform_graph:
            transform_graph = types.Channel(
                type=standard_artifacts.TransformGraph,
                artifacts=[standard_artifacts.TransformGraph()])

        if not transformed_examples:
            # Default output: one Examples artifact that declares the default
            # example splits.
            default_examples = standard_artifacts.Examples()
            default_examples.split_names = artifact_utils.encode_split_names(
                artifact.DEFAULT_EXAMPLE_SPLITS)
            transformed_examples = types.Channel(
                type=standard_artifacts.Examples,
                artifacts=[default_examples])

        spec_kwargs = dict(
            examples=examples,
            schema=schema,
            module_file=module_file,
            preprocessing_fn=preprocessing_fn,
            # The config is carried as a JSON string; None serializes as 'null'.
            custom_config=json.dumps(custom_config),
            transform_graph=transform_graph,
            transformed_examples=transformed_examples,
        )
        super(Transform, self).__init__(
            spec=standard_component_specs.TransformSpec(**spec_kwargs),
            instance_name=instance_name)
Пример #15
0
    def setUp(self):
        """Creates Trainer executor inputs, outputs and exec properties.

        Inputs are read from the package's 'testdata' directory. Outputs go to
        a per-test directory under TEST_UNDECLARED_OUTPUTS_DIR, or under the
        test temp dir when that variable is unset.
        """
        super(ExecutorTest, self).setUp()
        self._source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        self._output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        def _placed(new_artifact, root, relative_uri):
            # Point a freshly built artifact at root/relative_uri.
            new_artifact.uri = os.path.join(root, relative_uri)
            return new_artifact

        source = self._source_data_dir
        output = self._output_data_dir

        # Inputs.
        examples = _placed(standard_artifacts.Examples(), source,
                           'transform/transformed_examples')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        transform_output = _placed(standard_artifacts.TransformGraph(), source,
                                   'transform/transform_graph')
        schema = _placed(standard_artifacts.Schema(), source, 'schema_gen')
        previous_model = _placed(standard_artifacts.Model(), source,
                                 'trainer/previous')
        self._input_dict = {
            constants.EXAMPLES_KEY: [examples],
            constants.TRANSFORM_GRAPH_KEY: [transform_output],
            constants.SCHEMA_KEY: [schema],
            constants.BASE_MODEL_KEY: [previous_model]
        }

        # Outputs.
        self._model_exports = _placed(standard_artifacts.Model(), output,
                                      'model_export_path')
        self._model_run_exports = _placed(standard_artifacts.ModelRun(),
                                          output, 'model_run_path')
        self._output_dict = {
            constants.MODEL_KEY: [self._model_exports],
            constants.MODEL_RUN_KEY: [self._model_run_exports]
        }

        def _proto_json(message):
            # Serialize with the proto's own (snake_case) field names.
            return json_format.MessageToJson(message,
                                             preserving_proto_field_name=True)

        # Exec properties skeleton.
        self._exec_properties = {
            'train_args': _proto_json(trainer_pb2.TrainArgs(num_steps=1000)),
            'eval_args': _proto_json(trainer_pb2.EvalArgs(num_steps=500)),
            'warm_starting': False,
        }

        self._module_file = os.path.join(source, 'module_file',
                                         'trainer_module.py')
        trainer_fn = trainer_module.trainer_fn
        self._trainer_fn = '%s.%s' % (trainer_fn.__module__,
                                      trainer_fn.__name__)

        # Executors under test.
        self._trainer_executor = executor.Executor()
        self._generic_trainer_executor = executor.GenericExecutor()