def test_construct_with_materialization_disabled(self):
    """Checks construction succeeds when materialize=False is requested."""
    transform = component.Transform(
        examples=self.examples,
        schema=self.schema,
        preprocessing_fn='my_preprocessing_fn',
        materialize=False)
    # Verification mirrors the constructor flag.
    self._verify_outputs(transform, materialize=False)
def testEnableCache(self):
    """Checks the enable_cache attribute of the constructed component.

    When enable_cache is not passed, the attribute defaults to None; an
    explicit True must be preserved.
    """
    module_file = '/path/to/preprocessing.py'
    # Default case: flag omitted, attribute should be None.
    transform_1 = component.Transform(
        examples=self.examples,
        schema=self.schema,
        module_file=module_file,
    )
    # assertIsNone replaces assertEqual(None, ...) — same check, idiomatic.
    self.assertIsNone(transform_1.enable_cache)
    # Explicitly enabled: the flag must round-trip as a truthy True.
    transform_2 = component.Transform(
        examples=self.examples,
        schema=self.schema,
        module_file=module_file,
        enable_cache=True,
    )
    # assertTrue replaces assertEqual(True, ...).
    self.assertTrue(transform_2.enable_cache)
class TransformRunner(TfxComponentRunner):
    """Runner wrapper for the TFX Transform component."""

    def __init__(self, args):
        # NOTE(review): the component is constructed but no super().__init__
        # call is visible in this snippet — sibling runners in this corpus pass
        # the component to the base class. Confirm this block is complete.
        component = transform_component.Transform(
            input_data=channel.Channel('ExamplesPath'),
            schema=channel.Channel('SchemaPath'),
            module_file=args.module_file)
def test_construct_with_cache_disabled(self):
    """Checks construction succeeds with the analyzer cache disabled."""
    transform = component.Transform(
        examples=self.examples,
        schema=self.schema,
        preprocessing_fn='my_preprocessing_fn',
        disable_analyzer_cache=True)
    # Outputs are verified against the same flag used at construction.
    self._verify_outputs(transform, disable_analyzer_cache=True)
def test_construct_duplicate_user_module(self):
    """Supplying both module_file and preprocessing_fn must be rejected."""
    with self.assertRaises(ValueError):
        _ = component.Transform(
            input_data=self.input_data,
            schema=self.schema,
            module_file='/path/to/preprocessing.py',
            preprocessing_fn='path.to.my_preprocessing_fn',
        )
def testConstructDuplicateUserModule(self):
    """Passing both a module file and a preprocessing fn raises ValueError."""
    with self.assertRaises(ValueError):
        _ = component.Transform(
            examples=self.examples,
            schema=self.schema,
            module_file='/path/to/preprocessing.py',
            preprocessing_fn='path.to.my_preprocessing_fn',
        )
def test_construct_with_cache_disabled_but_input_cache(self):
    """Providing analyzer_cache while caching is disabled raises ValueError."""
    with self.assertRaises(ValueError):
        _ = component.Transform(
            examples=self.examples,
            schema=self.schema,
            preprocessing_fn='my_preprocessing_fn',
            disable_analyzer_cache=True,
            analyzer_cache=channel_utils.as_channel(
                [standard_artifacts.TransformCache()]))
def test_construct_from_module_file(self):
    """The given module_file path ends up in exec_properties."""
    path = '/path/to/preprocessing.py'
    transform = component.Transform(
        examples=self.examples,
        schema=self.schema,
        module_file=path,
    )
    self._verify_outputs(transform)
    self.assertEqual(path, transform.exec_properties['module_file'])
def test_construct_with_materialization_disabled_but_output_examples(self):
    """transformed_examples with materialize=False must raise ValueError."""
    with self.assertRaises(ValueError):
        _ = component.Transform(
            examples=self.examples,
            schema=self.schema,
            preprocessing_fn='my_preprocessing_fn',
            materialize=False,
            transformed_examples=channel_utils.as_channel(
                [standard_artifacts.Examples()]))
def testConstructFromModuleFile(self):
    """module_file is propagated to spec.exec_properties."""
    path = '/path/to/preprocessing.py'
    transform = component.Transform(
        input_data=self.input_data,
        schema=self.schema,
        module_file=path,
    )
    self._verify_outputs(transform)
    self.assertEqual(path, transform.spec.exec_properties['module_file'])
def test_construct_with_parameter(self):
    """A RuntimeParameter module_file round-trips through exec_properties."""
    param = data_types.RuntimeParameter(name='module-file', ptype=Text)
    transform = component.Transform(
        examples=self.examples,
        schema=self.schema,
        module_file=param,
    )
    self._verify_outputs(transform)
    # Compare string renderings since the stored value is a parameter object.
    self.assertJsonEqual(
        str(param), str(transform.exec_properties['module_file']))
def testConstructFromPreprocessingFn(self):
    """The preprocessing_fn path is stored in spec.exec_properties."""
    fn_path = 'path.to.my_preprocessing_fn'
    transform = component.Transform(
        examples=self.examples,
        schema=self.schema,
        preprocessing_fn=fn_path,
    )
    self._verify_outputs(transform)
    self.assertEqual(
        fn_path, transform.spec.exec_properties['preprocessing_fn'])
def test_construct_from_preprocessing_fn(self):
    """The preprocessing_fn path is recorded in spec.exec_properties."""
    fn_path = 'path.to.my_preprocessing_fn'
    transform = component.Transform(
        input_data=self.input_data,
        schema=self.schema,
        preprocessing_fn=fn_path,
    )
    self._verify_outputs(transform)
    self.assertEqual(
        fn_path, transform.spec.exec_properties['preprocessing_fn'])
def __init__(self, input_data: str, schema: str, module_file: str):
    """Builds the Transform component and hands it to the base runner."""
    transform = transform_component.Transform(
        input_data=channel.Channel('ExamplesPath'),
        schema=channel.Channel('SchemaPath'),
        module_file=module_file)
    # The base runner receives the component plus the raw input mapping.
    super().__init__(transform, {
        "input_data": input_data,
        "schema": schema
    })
def test_construct_from_preprocessing_fn(self):
    """The preprocessing_fn is stored under the standard spec key."""
    fn_path = 'path.to.my_preprocessing_fn'
    transform = component.Transform(
        examples=self.examples,
        schema=self.schema,
        preprocessing_fn=fn_path,
    )
    self._verify_outputs(transform)
    self.assertEqual(
        fn_path,
        transform.exec_properties[
            standard_component_specs.PREPROCESSING_FN_KEY])
def test_construct_with_force_tf_compat_v1_override(self):
    """Checks force_tf_compat_v1=True is recorded in spec.exec_properties."""
    transform = component.Transform(
        examples=self.examples,
        schema=self.schema,
        preprocessing_fn='my_preprocessing_fn',
        force_tf_compat_v1=True,
    )
    self._verify_outputs(transform)
    # assertTrue replaces assertEqual(True, bool(...)) — identical truthiness
    # check, idiomatic unittest style.
    self.assertTrue(
        transform.spec.exec_properties[
            standard_component_specs.FORCE_TF_COMPAT_V1_KEY])
def test_construct_with_stats_disabled(self):
    """Checks disable_statistics=True is recorded in spec.exec_properties."""
    transform = component.Transform(
        examples=self.examples,
        schema=self.schema,
        preprocessing_fn='my_preprocessing_fn',
        disable_statistics=True,
    )
    self._verify_outputs(transform, disable_statistics=True)
    # assertTrue replaces assertEqual(True, bool(...)) — identical truthiness
    # check, idiomatic unittest style.
    self.assertTrue(
        transform.spec.exec_properties[
            standard_component_specs.DISABLE_STATISTICS_KEY])
def __init__(self, input_data: dsl.PipelineParam, schema: dsl.PipelineParam,
             module_file: dsl.PipelineParam):
    """Wires pipeline parameters through to the Transform component runner."""
    # module_file is supplied at runtime via the parameter mapping below,
    # so the component is constructed with an empty placeholder path.
    transform = transform_component.Transform(
        input_data=channel.Channel('ExamplesPath'),
        schema=channel.Channel('SchemaPath'),
        module_file='')
    super().__init__(
        transform,
        {
            "input_data": input_data,
            "schema": schema,
            "module_file": module_file,
        })
def test_construct_from_preprocessing_fn_with_custom_config(self):
    """custom_config is JSON-serialized into spec.exec_properties."""
    fn_path = 'path.to.my_preprocessing_fn'
    config = {'param': 1}
    transform = component.Transform(
        examples=self.examples,
        schema=self.schema,
        preprocessing_fn=fn_path,
        custom_config=config,
    )
    self._verify_outputs(transform)
    self.assertEqual(
        fn_path, transform.spec.exec_properties['preprocessing_fn'])
    # The stored value is the JSON encoding of the config dict.
    self.assertEqual(
        json.dumps(config), transform.spec.exec_properties['custom_config'])
def test_construct_with_splits_config(self):
    """The splits_config proto is serialized to JSON in exec_properties."""
    config = transform_pb2.SplitsConfig(
        analyze=['train'], transform=['eval'])
    transform = component.Transform(
        examples=self.examples,
        schema=self.schema,
        module_file='/path/to/preprocessing.py',
        splits_config=config,
    )
    self._verify_outputs(transform)
    expected_json = json_format.MessageToJson(
        config, sort_keys=True, preserving_proto_field_name=True)
    self.assertEqual(
        expected_json, transform.exec_properties['splits_config'])
def test_construct_with_splits_config(self):
    """splits_config round-trips as proto JSON under the standard spec key."""
    config = transform_pb2.SplitsConfig(
        analyze=['train'], transform=['eval'])
    transform = component.Transform(
        examples=self.examples,
        schema=self.schema,
        module_file='/path/to/preprocessing.py',
        splits_config=config,
    )
    self._verify_outputs(transform)
    self.assertEqual(
        proto_utils.proto_to_json(config),
        transform.exec_properties[
            standard_component_specs.SPLITS_CONFIG_KEY])
def test_construct_from_preprocessing_fn_with_custom_config(self):
    """preprocessing_fn and JSON custom_config land under the standard keys."""
    fn_path = 'path.to.my_preprocessing_fn'
    config = {'param': 1}
    transform = component.Transform(
        examples=self.examples,
        schema=self.schema,
        preprocessing_fn=fn_path,
        custom_config=config,
    )
    self._verify_outputs(transform)
    self.assertEqual(
        fn_path,
        transform.spec.exec_properties[
            standard_component_specs.PREPROCESSING_FN_KEY])
    # The config dict is stored as its JSON encoding.
    self.assertEqual(
        json.dumps(config),
        transform.spec.exec_properties[
            standard_component_specs.CUSTOM_CONFIG_KEY])
def test_construct_from_preprocessing_fn_with_stats_options_updater_fn(
    self):
    """Both user-function paths are stored under their standard keys."""
    fn_path = 'path.to.my_preprocessing_fn'
    updater_path = 'path.to.my.stats_options_updater_fn'
    transform = component.Transform(
        examples=self.examples,
        schema=self.schema,
        preprocessing_fn=fn_path,
        stats_options_updater_fn=updater_path)
    self._verify_outputs(transform)
    self.assertEqual(
        fn_path,
        transform.exec_properties[
            standard_component_specs.PREPROCESSING_FN_KEY])
    self.assertEqual(
        updater_path,
        transform.exec_properties[
            standard_component_specs.STATS_OPTIONS_UPDATER_FN_KEY])
def test_construct(self):
    """Smoke test: outputs expose the expected artifact type names."""
    module_dir = os.path.join(
        os.path.dirname(__file__), 'testdata', 'taxi', 'module')
    transform = component.Transform(
        input_data=channel.as_channel([
            types.TfxType(type_name='ExamplesPath', split='train'),
            types.TfxType(type_name='ExamplesPath', split='eval'),
        ]),
        schema=channel.as_channel([types.TfxType(type_name='SchemaPath')]),
        module_file=os.path.join(module_dir, 'preprocess.py'),
    )
    self.assertEqual(
        'TransformPath', transform.outputs.transform_output.type_name)
    self.assertEqual(
        'ExamplesPath', transform.outputs.transformed_examples.type_name)
def test_construct_missing_user_module(self):
    """Omitting both module_file and preprocessing_fn raises ValueError."""
    with self.assertRaises(ValueError):
        _ = component.Transform(
            input_data=self.input_data,
            schema=self.schema,
        )
def testConstructMissingUserModule(self):
    """Construction without any user module must raise ValueError."""
    with self.assertRaises(ValueError):
        _ = component.Transform(
            examples=self.examples,
            schema=self.schema,
        )
def testConstructMissingUserModule(self):
    """Construction without a module file or preprocessing fn is an error."""
    with self.assertRaises(ValueError):
        _ = component.Transform(
            input_data=self.input_data,
            schema=self.schema,
        )
def _create_pipeline():
    """Implements the chicago east pipeline with TFX.

    Returns:
        The list of enabled pipeline components, in dependency order.
    """
    print_info("Creating pipeline")
    examples = tfrecord_input(_data_root)
    example_gen = ImportExampleGen(input_base=examples)

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(input_data=example_gen.outputs.examples)  # Step 3

    # Generates schema based on statistics files.
    infer_schema = SchemaGen(stats=statistics_gen.outputs.output)  # Step 3

    # Performs anomaly detection based on statistics and data schema.
    # (Constructed but not returned — validation is currently disabled below.)
    validate_stats = ExampleValidator(  # Step 3
        stats=statistics_gen.outputs.output,
        schema=infer_schema.outputs.output)

    # Performs transformations and feature engineering in training and serving.
    transform = component.Transform(  # Step 4
        input_data=example_gen.outputs.examples,
        schema=infer_schema.outputs.output,
        module_file=_east_module_file)

    # Uses user-provided Python function that implements a model using TF-Learn.
    trainer = Trainer(  # Step 5
        module_file=_east_module_file,
        transformed_examples=transform.outputs.transformed_examples,
        schema=infer_schema.outputs.output,
        transform_output=transform.outputs.transform_output,
        train_args=trainer_pb2.TrainArgs(num_steps=50),
        eval_args=trainer_pb2.EvalArgs(num_steps=10))

    # Uses TFMA to compute a evaluation statistics over features of a model.
    # model_analyzer = Evaluator(  # Step 6
    #     examples=example_gen.outputs.examples,
    #     model_exports=trainer.outputs.output,
    #     feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(specs=[
    #         evaluator_pb2.SingleSlicingSpec(
    #             column_for_slicing=['trip_start_hour'])
    #     ]))

    # Performs quality validation of a candidate model (compared to a baseline).
    # model_validator = ModelValidator(  # Step 7
    #     examples=example_gen.outputs.examples,
    #     model=trainer.outputs.output)

    # Checks whether the model passed the validation steps and pushes the model
    # to a file destination if check passed.
    # pusher = Pusher(  # Step 7
    #     model_export=trainer.outputs.output,
    #     model_blessing=model_validator.outputs.blessing,
    #     push_destination=pusher_pb2.PushDestination(
    #         filesystem=pusher_pb2.PushDestination.Filesystem(
    #             base_directory=_serving_model_dir)))

    # BUG FIX: the original list referenced `pusher` even though the Pusher
    # construction above is commented out, which raised NameError at pipeline
    # build time. Re-add it to this list together with Step 7.
    return [
        example_gen,
        statistics_gen,
        infer_schema,
        # validate_stats,  # Step 3
        transform,  # Step 4
        trainer,  # Step 5
        # model_analyzer,  # Step 6
        # model_validator,
        # pusher,  # Step 7
    ]
def test_construct_missing_user_module(self):
    """Construction with neither module_file nor preprocessing_fn fails."""
    with self.assertRaises(ValueError):
        _ = component.Transform(
            examples=self.examples,
            schema=self.schema,
        )