Exemple #1
0
 def test_construct_with_materialization_disabled(self):
   """Builds Transform with `materialize=False` and verifies its outputs."""
   transform_kwargs = dict(
       examples=self.examples,
       schema=self.schema,
       preprocessing_fn='my_preprocessing_fn',
       materialize=False)
   unmaterialized = component.Transform(**transform_kwargs)
   self._verify_outputs(unmaterialized, materialize=False)
Exemple #2
0
    def testEnableCache(self):
        """Checks `enable_cache` is None by default and True when requested.

        Two Transform components are built from the same module file: one
        without the `enable_cache` argument and one with `enable_cache=True`.
        """
        module_file = '/path/to/preprocessing.py'

        # Argument omitted: the attribute is expected to be None. Identity is
        # the idiomatic check for None, so use assertIsNone, not equality.
        default_transform = component.Transform(
            examples=self.examples,
            schema=self.schema,
            module_file=module_file,
        )
        self.assertIsNone(default_transform.enable_cache)

        # Argument supplied: the attribute is expected to reflect it.
        cached_transform = component.Transform(
            examples=self.examples,
            schema=self.schema,
            module_file=module_file,
            enable_cache=True,
        )
        self.assertEqual(True, cached_transform.enable_cache)
class TransformRunner(TfxComponentRunner):
  """Component runner whose initializer constructs a Transform component.

  NOTE(review): the visible definition ends right after the component is
  constructed; the local `component` is never used and no base-class
  initializer is called here. The snippet is presumably truncated -- confirm
  against the full source before relying on it.
  """

  def __init__(self, args):
    """Constructs the Transform component from `args`.

    Args:
      args: Object exposing a `module_file` attribute, which is forwarded to
        Transform as its `module_file` argument.
    """
    # The input channels are created from the type names 'ExamplesPath' and
    # 'SchemaPath'.
    component = transform_component.Transform(
        input_data=channel.Channel('ExamplesPath'),
        schema=channel.Channel('SchemaPath'),
        module_file=args.module_file)
Exemple #4
0
 def test_construct_with_cache_disabled(self):
   """Builds Transform with `disable_analyzer_cache=True` and checks outputs."""
   examples, schema = self.examples, self.schema
   uncached = component.Transform(
       examples=examples,
       schema=schema,
       preprocessing_fn='my_preprocessing_fn',
       disable_analyzer_cache=True)
   self._verify_outputs(uncached, disable_analyzer_cache=True)
Exemple #5
0
 def test_construct_duplicate_user_module(self):
     """Expects ValueError when both user-code arguments are supplied.

     Transform is called with `module_file` and `preprocessing_fn` together.
     """
     conflicting_user_code = dict(
         module_file='/path/to/preprocessing.py',
         preprocessing_fn='path.to.my_preprocessing_fn',
     )
     with self.assertRaises(ValueError):
         component.Transform(
             input_data=self.input_data,
             schema=self.schema,
             **conflicting_user_code)
Exemple #6
0
 def testConstructDuplicateUserModule(self):
     """Expects ValueError when `module_file` and `preprocessing_fn` coexist."""

     def build_with_both_user_code_sources():
         return component.Transform(
             examples=self.examples,
             schema=self.schema,
             module_file='/path/to/preprocessing.py',
             preprocessing_fn='path.to.my_preprocessing_fn',
         )

     self.assertRaises(ValueError, build_with_both_user_code_sources)
Exemple #7
0
 def test_construct_with_cache_disabled_but_input_cache(self):
     """Expects ValueError when the cache is disabled yet a cache is passed.

     Transform is called with `disable_analyzer_cache=True` together with an
     `analyzer_cache` channel.
     """
     with self.assertRaises(ValueError):
         supplied_cache = channel_utils.as_channel(
             [standard_artifacts.TransformCache()])
         component.Transform(
             examples=self.examples,
             schema=self.schema,
             preprocessing_fn='my_preprocessing_fn',
             disable_analyzer_cache=True,
             analyzer_cache=supplied_cache)
Exemple #8
0
 def test_construct_from_module_file(self):
     """Builds Transform from a module file and checks the recorded path."""
     expected_path = '/path/to/preprocessing.py'
     transform = component.Transform(
         examples=self.examples, schema=self.schema, module_file=expected_path)
     self._verify_outputs(transform)
     recorded_path = transform.exec_properties['module_file']
     self.assertEqual(expected_path, recorded_path)
Exemple #9
0
 def test_construct_with_materialization_disabled_but_output_examples(self):
     """Expects ValueError for `materialize=False` plus output examples.

     Transform is called with `materialize=False` together with an explicit
     `transformed_examples` channel.
     """

     def build_conflicting_transform():
         output_examples = channel_utils.as_channel(
             [standard_artifacts.Examples()])
         return component.Transform(
             examples=self.examples,
             schema=self.schema,
             preprocessing_fn='my_preprocessing_fn',
             materialize=False,
             transformed_examples=output_examples)

     self.assertRaises(ValueError, build_conflicting_transform)
Exemple #10
0
 def testConstructFromModuleFile(self):
   """Builds Transform from a module file and checks the spec's stored path."""
   user_module_path = '/path/to/preprocessing.py'
   built = component.Transform(
       input_data=self.input_data,
       schema=self.schema,
       module_file=user_module_path,
   )
   self._verify_outputs(built)
   exec_properties = built.spec.exec_properties
   self.assertEqual(user_module_path, exec_properties['module_file'])
Exemple #11
0
 def test_construct_with_parameter(self):
   """Builds Transform with a RuntimeParameter passed as `module_file`.

   After output verification, the string form of the parameter is compared
   with the string form of the stored 'module_file' exec property.
   """
   module_file_param = data_types.RuntimeParameter(
       name='module-file', ptype=Text)
   transform = component.Transform(
       examples=self.examples,
       schema=self.schema,
       module_file=module_file_param,
   )
   self._verify_outputs(transform)
   expected = str(module_file_param)
   actual = str(transform.exec_properties['module_file'])
   self.assertJsonEqual(expected, actual)
Exemple #12
0
 def testConstructFromPreprocessingFn(self):
     """Builds Transform from a preprocessing function path and checks it."""
     fn_path = 'path.to.my_preprocessing_fn'
     transform = component.Transform(
         examples=self.examples, schema=self.schema, preprocessing_fn=fn_path)
     self._verify_outputs(transform)
     stored_fn_path = transform.spec.exec_properties['preprocessing_fn']
     self.assertEqual(fn_path, stored_fn_path)
Exemple #13
0
 def test_construct_from_preprocessing_fn(self):
     """Builds Transform from `preprocessing_fn` and checks the spec property."""
     expected_fn = 'path.to.my_preprocessing_fn'
     constructor_args = {
         'input_data': self.input_data,
         'schema': self.schema,
         'preprocessing_fn': expected_fn,
     }
     transform = component.Transform(**constructor_args)
     self._verify_outputs(transform)
     self.assertEqual(
         expected_fn, transform.spec.exec_properties['preprocessing_fn'])
    def __init__(self, input_data: str, schema: str, module_file: str):
        """Creates a Transform component and passes it to the base initializer.

        Args:
            input_data: Value stored under the "input_data" key of the mapping
                given to the base initializer.
            schema: Value stored under the "schema" key of that mapping.
            module_file: Forwarded to Transform as its `module_file` argument.
        """
        examples_channel = channel.Channel('ExamplesPath')
        schema_channel = channel.Channel('SchemaPath')
        transform = transform_component.Transform(
            input_data=examples_channel,
            schema=schema_channel,
            module_file=module_file)

        runtime_inputs = {"input_data": input_data, "schema": schema}
        super().__init__(transform, runtime_inputs)
Exemple #15
0
 def test_construct_from_preprocessing_fn(self):
     """Builds Transform from `preprocessing_fn` and checks the exec property."""
     fn_path = 'path.to.my_preprocessing_fn'
     transform = component.Transform(
         examples=self.examples,
         schema=self.schema,
         preprocessing_fn=fn_path,
     )
     self._verify_outputs(transform)
     stored_fn_path = transform.exec_properties[
         standard_component_specs.PREPROCESSING_FN_KEY]
     self.assertEqual(fn_path, stored_fn_path)
Exemple #16
0
 def test_construct_with_force_tf_compat_v1_override(self):
     """Builds Transform with `force_tf_compat_v1=True` and checks the flag.

     The stored exec property is converted with `bool` before comparison.
     """
     transform = component.Transform(
         examples=self.examples,
         schema=self.schema,
         preprocessing_fn='my_preprocessing_fn',
         force_tf_compat_v1=True,
     )
     self._verify_outputs(transform)
     stored_flag = transform.spec.exec_properties[
         standard_component_specs.FORCE_TF_COMPAT_V1_KEY]
     self.assertEqual(True, bool(stored_flag))
Exemple #17
0
 def test_construct_with_stats_disabled(self):
     """Builds Transform with `disable_statistics=True` and checks the flag.

     The stored exec property is converted with `bool` before comparison.
     """
     stats_free = component.Transform(
         examples=self.examples,
         schema=self.schema,
         preprocessing_fn='my_preprocessing_fn',
         disable_statistics=True,
     )
     self._verify_outputs(stats_free, disable_statistics=True)
     stored_flag = stats_free.spec.exec_properties[
         standard_component_specs.DISABLE_STATISTICS_KEY]
     self.assertEqual(True, bool(stored_flag))
Exemple #18
0
    def __init__(self, input_data: dsl.PipelineParam,
                 schema: dsl.PipelineParam, module_file: dsl.PipelineParam):
        """Creates a Transform component and passes it to the base initializer.

        Args:
            input_data: Value stored under the "input_data" key of the mapping
                given to the base initializer.
            schema: Value stored under the "schema" key of that mapping.
            module_file: Value stored under the "module_file" key of that
                mapping.
        """
        # Transform itself is built with an empty `module_file`; the
        # `module_file` argument is only forwarded through the mapping below.
        transform = transform_component.Transform(
            input_data=channel.Channel('ExamplesPath'),
            schema=channel.Channel('SchemaPath'),
            module_file='')

        pipeline_params = {
            "input_data": input_data,
            "schema": schema,
            "module_file": module_file,
        }
        super().__init__(transform, pipeline_params)
Exemple #19
0
 def test_construct_from_preprocessing_fn_with_custom_config(self):
     """Builds Transform with a custom config and checks both properties.

     The 'custom_config' exec property is compared with the JSON dump of the
     dictionary that was passed in.
     """
     fn_path = 'path.to.my_preprocessing_fn'
     config = {'param': 1}
     transform = component.Transform(
         examples=self.examples,
         schema=self.schema,
         preprocessing_fn=fn_path,
         custom_config=config,
     )
     self._verify_outputs(transform)
     stored_fn_path = transform.spec.exec_properties['preprocessing_fn']
     self.assertEqual(fn_path, stored_fn_path)
     stored_config = transform.spec.exec_properties['custom_config']
     self.assertEqual(json.dumps(config), stored_config)
Exemple #20
0
 def test_construct_with_splits_config(self):
   """Builds Transform with a SplitsConfig and checks its JSON exec property.

   The config analyzes the 'train' split and transforms the 'eval' split.
   """
   splits = transform_pb2.SplitsConfig(analyze=['train'], transform=['eval'])
   transform = component.Transform(
       examples=self.examples,
       schema=self.schema,
       module_file='/path/to/preprocessing.py',
       splits_config=splits,
   )
   self._verify_outputs(transform)
   expected_json = json_format.MessageToJson(
       splits, sort_keys=True, preserving_proto_field_name=True)
   self.assertEqual(expected_json, transform.exec_properties['splits_config'])
Exemple #21
0
 def test_construct_with_splits_config(self):
     """Builds Transform with a SplitsConfig and checks its JSON exec property.

     The config analyzes the 'train' split and transforms the 'eval' split.
     """
     splits = transform_pb2.SplitsConfig(
         analyze=['train'], transform=['eval'])
     transform = component.Transform(
         examples=self.examples,
         schema=self.schema,
         module_file='/path/to/preprocessing.py',
         splits_config=splits,
     )
     self._verify_outputs(transform)
     expected_json = proto_utils.proto_to_json(splits)
     stored_json = transform.exec_properties[
         standard_component_specs.SPLITS_CONFIG_KEY]
     self.assertEqual(expected_json, stored_json)
Exemple #22
0
 def test_construct_from_preprocessing_fn_with_custom_config(self):
     """Builds Transform with a custom config and checks both properties.

     The custom config exec property is compared with the JSON dump of the
     dictionary that was passed in.
     """
     fn_path = 'path.to.my_preprocessing_fn'
     config = {'param': 1}
     transform = component.Transform(
         examples=self.examples,
         schema=self.schema,
         preprocessing_fn=fn_path,
         custom_config=config,
     )
     self._verify_outputs(transform)
     stored_fn_path = transform.spec.exec_properties[
         standard_component_specs.PREPROCESSING_FN_KEY]
     self.assertEqual(fn_path, stored_fn_path)
     expected_config = json.dumps(config)
     stored_config = transform.spec.exec_properties[
         standard_component_specs.CUSTOM_CONFIG_KEY]
     self.assertEqual(expected_config, stored_config)
Exemple #23
0
 def test_construct_from_preprocessing_fn_with_stats_options_updater_fn(
         self):
     """Builds Transform with a stats-options updater and checks properties.

     Both the preprocessing function path and the updater function path are
     compared with the matching exec properties.
     """
     fn_path = 'path.to.my_preprocessing_fn'
     updater_path = 'path.to.my.stats_options_updater_fn'
     transform = component.Transform(
         examples=self.examples,
         schema=self.schema,
         preprocessing_fn=fn_path,
         stats_options_updater_fn=updater_path)
     self._verify_outputs(transform)
     stored_fn_path = transform.exec_properties[
         standard_component_specs.PREPROCESSING_FN_KEY]
     self.assertEqual(fn_path, stored_fn_path)
     stored_updater_path = transform.exec_properties[
         standard_component_specs.STATS_OPTIONS_UPDATER_FN_KEY]
     self.assertEqual(updater_path, stored_updater_path)
Exemple #24
0
 def test_construct(self):
     """Builds Transform from taxi test data and checks output type names.

     The input channel holds one 'ExamplesPath' artifact per split ('train'
     and 'eval'); the schema channel holds a single 'SchemaPath' artifact.
     """
     taxi_dir = os.path.join(os.path.dirname(__file__), 'testdata', 'taxi')
     module_path = os.path.join(taxi_dir, 'module', 'preprocess.py')
     example_artifacts = [
         types.TfxType(type_name='ExamplesPath', split=split_name)
         for split_name in ('train', 'eval')
     ]
     schema_artifacts = [types.TfxType(type_name='SchemaPath')]
     transform = component.Transform(
         input_data=channel.as_channel(example_artifacts),
         schema=channel.as_channel(schema_artifacts),
         module_file=module_path,
     )
     transform_output_type = transform.outputs.transform_output.type_name
     self.assertEqual('TransformPath', transform_output_type)
     examples_output_type = transform.outputs.transformed_examples.type_name
     self.assertEqual('ExamplesPath', examples_output_type)
Exemple #25
0
 def test_construct_missing_user_module(self):
     """Expects ValueError when only `input_data` and `schema` are passed."""

     def build_without_user_code():
         return component.Transform(
             input_data=self.input_data,
             schema=self.schema,
         )

     self.assertRaises(ValueError, build_without_user_code)
Exemple #26
0
 def testConstructMissingUserModule(self):
     """Expects ValueError when only `examples` and `schema` are passed."""
     with self.assertRaises(ValueError):
         required_only = dict(examples=self.examples, schema=self.schema)
         component.Transform(**required_only)
Exemple #27
0
 def testConstructMissingUserModule(self):
   """Expects ValueError when only `input_data` and `schema` are passed."""
   with self.assertRaises(ValueError):
     component.Transform(input_data=self.input_data, schema=self.schema)
Exemple #28
0
def _create_pipeline():
    """Implements the chicago east pipeline with TFX.

    Builds the components in order: example import, statistics generation,
    schema inference, example validation, transform and trainer. The
    evaluator, model validator and pusher stages are present only as
    commented-out code tagged with later "Step" numbers.

    Reads the module-level names `_data_root` and `_east_module_file`.

    Returns:
        A list containing, in order: `example_gen`, `statistics_gen`,
        `infer_schema`, `transform` and `trainer`. `validate_stats` is
        constructed but is commented out of the returned list.
    """
    print_info("Creating pipeline")
    examples = tfrecord_input(_data_root)
    example_gen = ImportExampleGen(input_base=examples)

    # Computes statistics over data for visualization and example validation.
    # pylint: disable=line-too-long
    statistics_gen = StatisticsGen(input_data=example_gen.outputs.examples) # Step 3
    # pylint: enable=line-too-long

    # Generates schema based on statistics files.
    infer_schema = SchemaGen(stats=statistics_gen.outputs.output) # Step 3

    # Performs anomaly detection based on statistics and data schema.
    # NOTE: this component is built here but is not part of the returned list
    # (its entry in the return statement is commented out).
    validate_stats = ExampleValidator( # Step 3
        stats=statistics_gen.outputs.output, # Step 3
        schema=infer_schema.outputs.output) # Step 3

    # Performs transformations and feature engineering in training and serving.
    transform = component.Transform( # Step 4
        input_data=example_gen.outputs.examples, # Step 4
        schema=infer_schema.outputs.output, # Step 4
        module_file=_east_module_file) # Step 4

    # Uses user-provided Python function that implements a model using TF-Learn.
    # The same module file is passed to both Transform and Trainer.
    trainer = Trainer( # Step 5
        module_file=_east_module_file, # Step 5
        transformed_examples=transform.outputs.transformed_examples, # Step 5
        schema=infer_schema.outputs.output, # Step 5
        transform_output=transform.outputs.transform_output, # Step 5
        train_args=trainer_pb2.TrainArgs(num_steps=50), # Step 5
        eval_args=trainer_pb2.EvalArgs(num_steps=10)) # Step 5

    # Uses TFMA to compute evaluation statistics over features of a model.
    #model_analyzer = Evaluator( # Step 6
    #     examples=example_gen.outputs.examples, # Step 6
    #     model_exports=trainer.outputs.output, # Step 6
    #     feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(specs=[ # Step 6
    #         evaluator_pb2.SingleSlicingSpec( # Step 6
    #             column_for_slicing=['trip_start_hour']) # Step 6
    #     ])) # Step 6

    # Performs quality validation of a candidate model (compared to a baseline).
    # model_validator = ModelValidator( # Step 7
    #     examples=example_gen.outputs.examples, # Step 7
    #              model=trainer.outputs.output) # Step 7

    # Checks whether the model passed the validation steps and pushes the model
    # to a file destination if check passed.
    # pusher = Pusher( # Step 7
    #     model_export=trainer.outputs.output, # Step 7
    #     model_blessing=model_validator.outputs.blessing, # Step 7
    #     push_destination=pusher_pb2.PushDestination( # Step 7
    #         filesystem=pusher_pb2.PushDestination.Filesystem( # Step 7
    #             base_directory=_serving_model_dir))) # Step 7

    # Only the components enabled so far are returned; the commented-out
    # entries match the disabled stages above.
    return [
        example_gen,
        statistics_gen, infer_schema, #validate_stats, # Step 3
        transform, # Step 4
        trainer, # Step 5
        #model_analyzer, # Step 6
        # model_validator, pusher # Step 7
    ]
Exemple #29
0
 def test_construct_missing_user_module(self):
     """Expects ValueError when only `examples` and `schema` are passed."""
     self.assertRaises(
         ValueError,
         lambda: component.Transform(
             examples=self.examples,
             schema=self.schema,
         ))