def ResolveSplitsConfig(
    splits_config_str: Optional[str],
    examples: List[types.Artifact]) -> transform_pb2.SplitsConfig:
  """Resolves the SplitsConfig proto for the transform request."""
  result = transform_pb2.SplitsConfig()
  if splits_config_str:
    proto_utils.json_to_proto(splits_config_str, result)
    if not result.analyze:
      raise ValueError('analyze cannot be empty when splits_config is set.')
    return result

  # Default: analyze the 'train' split only.
  result.analyze.append('train')

  # All input artifacts should have the same set of split names.
  split_names = set(
      artifact_utils.decode_split_names(examples[0].split_names))
  for artifact in examples:
    artifact_split_names = set(
        artifact_utils.decode_split_names(artifact.split_names))
    if split_names != artifact_split_names:
      raise ValueError(
          'Not all input artifacts have the same split names: (%s, %s)' %
          (split_names, artifact_split_names))

  result.transform.extend(split_names)
  logging.info("Analyze the 'train' split and transform all splits when "
               'splits_config is not set.')
  return result
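# A minimal usage sketch of the default path above (not from the source):
# with no splits_config, 'train' is analyzed and every split found on the
# input artifacts is transformed. The artifact setup below is illustrative.
from tfx.types import artifact_utils, standard_artifacts

examples = standard_artifacts.Examples()
examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])

config = ResolveSplitsConfig(None, [examples])
assert list(config.analyze) == ['train']
assert set(config.transform) == {'train', 'eval'}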
def test_do_with_empty_analyze_splits(self):
  self._exec_properties['splits_config'] = proto_utils.proto_to_json(
      transform_pb2.SplitsConfig(analyze=[], transform=['train', 'eval']))
  self._exec_properties['module_file'] = self._module_file
  with self.assertRaises(ValueError):
    self._transform_executor.Do(self._input_dict, self._output_dict,
                                self._exec_properties)
def test_do_with_custom_splits(self):
  self._exec_properties['splits_config'] = proto_utils.proto_to_json(
      transform_pb2.SplitsConfig(
          analyze=['train'], transform=['train', 'eval']))
  self._exec_properties['module_file'] = self._module_file
  self._transform_executor.Do(self._input_dict, self._output_dict,
                              self._exec_properties)
  self._verify_transform_outputs()
def test_do_with_empty_analyze_splits(self):
  self._exec_properties['splits_config'] = json_format.MessageToJson(
      transform_pb2.SplitsConfig(analyze=[], transform=['train', 'eval']),
      preserving_proto_field_name=True)
  self._exec_properties['module_file'] = self._module_file
  with self.assertRaises(ValueError):
    self._transform_executor.Do(self._input_dict, self._output_dict,
                                self._exec_properties)
def test_do_with_custom_splits(self):
  self._exec_properties['splits_config'] = json_format.MessageToJson(
      transform_pb2.SplitsConfig(
          analyze=['train'], transform=['train', 'eval']),
      preserving_proto_field_name=True)
  self._exec_properties['module_file'] = self._module_file
  self._transform_executor.Do(self._input_dict, self._output_dict,
                              self._exec_properties)
  self._verify_transform_outputs()
def test_do_with_empty_analyze_splits(self):
  self._exec_properties[
      standard_component_specs.SPLITS_CONFIG_KEY] = proto_utils.proto_to_json(
          transform_pb2.SplitsConfig(analyze=[], transform=['train', 'eval']))
  self._exec_properties[
      standard_component_specs.MODULE_FILE_KEY] = self._module_file
  with self.assertRaises(ValueError):
    self._transform_executor.Do(self._input_dict, self._output_dict,
                                self._exec_properties)
def test_do_with_custom_splits(self):
  self._exec_properties[
      standard_component_specs.SPLITS_CONFIG_KEY] = proto_utils.proto_to_json(
          transform_pb2.SplitsConfig(
              analyze=['train'], transform=['train', 'eval']))
  self._exec_properties[
      standard_component_specs.MODULE_FILE_KEY] = self._module_file
  self._transform_executor.Do(self._input_dict, self._output_dict,
                              self._exec_properties)
  self._verify_transform_outputs()
def test_construct_with_splits_config(self):
  splits_config = transform_pb2.SplitsConfig(
      analyze=['train'], transform=['eval'])
  module_file = '/path/to/preprocessing.py'
  transform = component.Transform(
      examples=self.examples,
      schema=self.schema,
      module_file=module_file,
      splits_config=splits_config,
  )
  self._verify_outputs(transform)
  self.assertEqual(
      proto_utils.proto_to_json(splits_config),
      transform.exec_properties[standard_component_specs.SPLITS_CONFIG_KEY])
def test_construct_with_splits_config(self):
  splits_config = transform_pb2.SplitsConfig(
      analyze=['train'], transform=['eval'])
  module_file = '/path/to/preprocessing.py'
  transform = component.Transform(
      examples=self.examples,
      schema=self.schema,
      module_file=module_file,
      splits_config=splits_config,
  )
  self._verify_outputs(transform)
  self.assertEqual(
      json_format.MessageToJson(
          splits_config, sort_keys=True, preserving_proto_field_name=True),
      transform.exec_properties['splits_config'])
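# A sketch (not from the source) of why the two constructor tests above can
# assert against either helper: proto_utils.proto_to_json appears to wrap
# json_format.MessageToJson with sorted keys and preserved field names, so
# both expressions should serialize SplitsConfig identically. Verify this
# against your TFX version.
from google.protobuf import json_format
from tfx.proto import transform_pb2
from tfx.utils import proto_utils

config = transform_pb2.SplitsConfig(analyze=['train'], transform=['eval'])
assert proto_utils.proto_to_json(config) == json_format.MessageToJson(
    config, sort_keys=True, preserving_proto_field_name=True)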
def test_do_with_empty_transform_splits(self):
  self._exec_properties['splits_config'] = proto_utils.proto_to_json(
      transform_pb2.SplitsConfig(analyze=['train'], transform=[]))
  self._exec_properties['module_file'] = self._module_file
  self._output_dict[executor.TRANSFORMED_EXAMPLES_KEY] = (
      self._transformed_example_artifacts[:1])
  self._transform_executor.Do(self._input_dict, self._output_dict,
                              self._exec_properties)
  self.assertFalse(
      fileio.exists(
          os.path.join(self._transformed_example_artifacts[0].uri, 'train')))
  self.assertFalse(
      fileio.exists(
          os.path.join(self._transformed_example_artifacts[0].uri, 'eval')))
  path_to_saved_model = os.path.join(
      self._transformed_output.uri, tft.TFTransformOutput.TRANSFORM_FN_DIR,
      tf.saved_model.SAVED_MODEL_FILENAME_PB)
  self.assertTrue(fileio.exists(path_to_saved_model))
def test_do_with_empty_transform_splits(self):
  self._exec_properties['splits_config'] = json_format.MessageToJson(
      transform_pb2.SplitsConfig(analyze=['train'], transform=[]),
      preserving_proto_field_name=True)
  self._exec_properties['module_file'] = self._module_file
  self._transformed_examples.split_names = artifact_utils.encode_split_names(
      [])
  self._output_dict[executor.TRANSFORMED_EXAMPLES_KEY] = [
      self._transformed_examples
  ]
  self._transform_executor.Do(self._input_dict, self._output_dict,
                              self._exec_properties)
  self.assertFalse(
      tf.io.gfile.exists(
          os.path.join(self._transformed_examples.uri, 'train')))
  self.assertFalse(
      tf.io.gfile.exists(
          os.path.join(self._transformed_examples.uri, 'eval')))
  path_to_saved_model = os.path.join(
      self._transformed_output.uri, tft.TFTransformOutput.TRANSFORM_FN_DIR,
      tf.saved_model.SAVED_MODEL_FILENAME_PB)
  self.assertTrue(tf.io.gfile.exists(path_to_saved_model))
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     metadata_path: Text, beam_pipeline_args: List[Text]):
  """Creates the ranking pipeline."""
  pipeline_root = os.path.join(pipeline_root, 'pipelines', pipeline_name)

  example_gen = ImportExampleGen(
      input_base=data_root,
      # IMPORTANT: must set FORMAT_PROTO.
      payload_format=example_gen_pb2.FORMAT_PROTO)

  data_view_provider = provider_component.TfGraphDataViewProvider(
      module_file=module_file, create_decoder_func='make_decoder')

  data_view_binder = binder_component.DataViewBinder(
      example_gen.outputs['examples'],
      data_view_provider.outputs['data_view'])

  statistics_gen = StatisticsGen(
      examples=data_view_binder.outputs['output_examples'])

  schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'])

  transform = Transform(
      examples=data_view_binder.outputs['output_examples'],
      schema=schema_gen.outputs['schema'],
      module_file=module_file,
      # IMPORTANT: must disable Transform materialization and ensure the
      # transform field of the splits config is empty.
      splits_config=transform_pb2.SplitsConfig(analyze=['train']),
      materialize=False)

  trainer = Trainer(
      examples=data_view_binder.outputs['output_examples'],
      transform_graph=transform.outputs['transform_graph'],
      module_file=module_file,
      train_args=trainer_pb2.TrainArgs(num_steps=1000),
      schema=schema_gen.outputs['schema'],
      eval_args=trainer_pb2.EvalArgs(num_steps=10))

  eval_config = tfma.EvalConfig(
      model_specs=[
          tfma.ModelSpec(
              signature_name='',
              label_key='relevance',
              padding_options=tfma.config.PaddingOptions(
                  label_float_padding=-1.0, prediction_float_padding=-1.0))
      ],
      slicing_specs=[
          tfma.SlicingSpec(),
          tfma.SlicingSpec(feature_keys=['query_tokens']),
      ],
      metrics_specs=[
          tfma.MetricsSpec(
              per_slice_thresholds={
                  'metric/ndcg_10':
                      tfma.config.PerSliceMetricThresholds(thresholds=[
                          tfma.PerSliceMetricThreshold(
                              # The overall slice.
                              slicing_specs=[tfma.SlicingSpec()],
                              threshold=tfma.MetricThreshold(
                                  value_threshold=tfma.GenericValueThreshold(
                                      lower_bound={'value': 0.6})))
                      ])
              })
      ])

  evaluator = Evaluator(
      examples=data_view_binder.outputs['output_examples'],
      model=trainer.outputs['model'],
      eval_config=eval_config,
      schema=schema_gen.outputs['schema'])

  # Checks whether the model passed the validation steps and pushes the model
  # to a file destination if the check passed.
  pusher = Pusher(
      model=trainer.outputs['model'],
      model_blessing=evaluator.outputs['blessing'],
      push_destination=pusher_pb2.PushDestination(
          filesystem=pusher_pb2.PushDestination.Filesystem(
              base_directory=serving_model_dir)))

  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=[
          example_gen,
          data_view_provider,
          data_view_binder,
          statistics_gen,
          schema_gen,
          transform,
          trainer,
          evaluator,
          pusher,
      ],
      enable_cache=True,
      metadata_connection_config=metadata.sqlite_metadata_connection_config(
          metadata_path),
      beam_pipeline_args=beam_pipeline_args)
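# A minimal sketch (not from the source) of running the pipeline above
# locally with the Beam orchestrator; every path and argument below is an
# illustrative placeholder.
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner

BeamDagRunner().run(
    _create_pipeline(
        pipeline_name='ranking_pipeline',
        pipeline_root='/tmp/pipeline_root',
        data_root='/tmp/data',
        module_file='/path/to/module.py',
        serving_model_dir='/tmp/serving_model',
        metadata_path='/tmp/metadata.db',
        beam_pipeline_args=['--direct_num_workers=1'],
    ))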
def testResolveSplitsConfigOk(self):
  config = transform_pb2.SplitsConfig(
      analyze=['train'], transform=['train', 'eval'])
  config_str = proto_utils.proto_to_json(config)
  resolved = executor_utils.ResolveSplitsConfig(config_str, [])
  self.assertProtoEquals(config, resolved)
def testResolveSplitsConfigEmptyAnalyze(self):
  wrong_config = transform_pb2.SplitsConfig(transform=['train'])
  config_str = proto_utils.proto_to_json(wrong_config)
  with self.assertRaisesRegex(ValueError, 'analyze cannot be empty'):
    executor_utils.ResolveSplitsConfig(config_str, [])
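def testResolveSplitsConfigDefault(self):
  # A sketch of a default-path test (not in the source): with no
  # splits_config, analyze defaults to ['train'] and transform covers every
  # split present on the input artifacts. Assumes the test module imports
  # standard_artifacts and artifact_utils.
  examples = standard_artifacts.Examples()
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  resolved = executor_utils.ResolveSplitsConfig(None, [examples])
  self.assertEqual(list(resolved.analyze), ['train'])
  self.assertCountEqual(resolved.transform, ['train', 'eval'])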
def create_pipeline(pipeline_name: Text, pipeline_root: Text):
    """Creates a TFX training pipeline.

    Args:
        pipeline_name: Name of the pipeline.
        pipeline_root: Root directory for pipeline artifacts.

    Returns:
        pipeline: A TFX Pipeline object.
    """
    # Get train split BigQuery query.
    train_sql_query = bq_datasource_utils.get_training_source_query(
        config.GOOGLE_CLOUD_PROJECT_ID,
        config.GOOGLE_CLOUD_REGION,
        config.DATASET_DISPLAY_NAME,
        ml_use="UNASSIGNED",
        limit=int(config.TRAIN_LIMIT),
    )

    # Configure train and eval splits for model training and evaluation.
    train_output_config = example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(
            splits=[
                example_gen_pb2.SplitConfig.Split(
                    name="train", hash_buckets=int(config.NUM_TRAIN_SPLITS)
                ),
                example_gen_pb2.SplitConfig.Split(
                    name="eval", hash_buckets=int(config.NUM_EVAL_SPLITS)
                ),
            ]
        )
    )

    # Generate train split examples.
    train_example_gen = BigQueryExampleGen(
        query=train_sql_query,
        output_config=train_output_config,
    ).with_id("TrainDataGen")

    # Get test source query.
    test_sql_query = bq_datasource_utils.get_training_source_query(
        config.GOOGLE_CLOUD_PROJECT_ID,
        config.GOOGLE_CLOUD_REGION,
        config.DATASET_DISPLAY_NAME,
        ml_use="TEST",
        limit=int(config.TEST_LIMIT),
    )

    # Configure test split for model evaluation.
    test_output_config = example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(
            splits=[
                example_gen_pb2.SplitConfig.Split(name="test", hash_buckets=1),
            ]
        )
    )

    # Generate test split examples.
    test_example_gen = BigQueryExampleGen(
        query=test_sql_query,
        output_config=test_output_config,
    ).with_id("TestDataGen")

    # Schema importer.
    schema_importer = Importer(
        source_uri=SCHEMA_DIR,
        artifact_type=Schema,
    ).with_id("SchemaImporter")
    # schema_importer = ImportSchemaGen(schema_file=SCHEMA_FILE).with_id("SchemaImporter")

    # Generate dataset statistics.
    statistics_gen = StatisticsGen(
        examples=train_example_gen.outputs["examples"]
    ).with_id("StatisticsGen")

    # Generate data schema file.
    # schema_gen = SchemaGen(
    #     statistics=statistics_gen.outputs["statistics"],
    #     infer_feature_shape=True,
    # )

    # Example validation.
    example_validator = ExampleValidator(
        statistics=statistics_gen.outputs["statistics"],
        schema=schema_importer.outputs["result"],
    ).with_id("ExampleValidator")

    # Data transformation.
    transform = Transform(
        examples=train_example_gen.outputs["examples"],
        schema=schema_importer.outputs["result"],
        module_file=TRANSFORM_MODULE_FILE,
        # This is a temporary workaround to run on Dataflow.
        force_tf_compat_v1=config.BEAM_RUNNER == "DataflowRunner",
        splits_config=transform_pb2.SplitsConfig(
            analyze=["train"], transform=["train", "eval"]
        ),
    ).with_id("Transform")

    # Add dependency from example_validator to transform.
    transform.add_upstream_node(example_validator)

    # Train model on Vertex AI.
    trainer = VertexTrainer(
        module_file=TRAIN_MODULE_FILE,
        examples=transform.outputs["transformed_examples"],
        transform_graph=transform.outputs["transform_graph"],
        custom_config=config.VERTEX_TRAINING_CONFIG,
    ).with_id("ModelTrainer")

    # Get the latest blessed model (baseline) for model validation.
    baseline_model_resolver = Resolver(
        strategy_class=LatestBlessedModelStrategy,
        model=Channel(type=Model),
        model_blessing=Channel(type=ModelBlessing),
    ).with_id("BaselineModelResolver")

    # Prepare evaluation config.
    eval_config = tfma.EvalConfig(
        model_specs=[
            tfma.ModelSpec(
                signature_name="serving_tf_example",
                label_key=features.TARGET_FEATURE_NAME,
                prediction_key="probabilities",
            )
        ],
        slicing_specs=[
            tfma.SlicingSpec(),
        ],
        metrics_specs=[
            tfma.MetricsSpec(
                metrics=[
                    tfma.MetricConfig(class_name="ExampleCount"),
                    tfma.MetricConfig(
                        class_name="BinaryAccuracy",
                        threshold=tfma.MetricThreshold(
                            value_threshold=tfma.GenericValueThreshold(
                                lower_bound={
                                    "value": float(config.ACCURACY_THRESHOLD)
                                }
                            ),
                            # Change threshold will be ignored if there is no
                            # baseline model resolved from MLMD (first run).
                            change_threshold=tfma.GenericChangeThreshold(
                                direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                                absolute={"value": -1e-10},
                            ),
                        ),
                    ),
                ]
            )
        ],
    )

    # Model evaluation.
    evaluator = Evaluator(
        examples=test_example_gen.outputs["examples"],
        example_splits=["test"],
        model=trainer.outputs["model"],
        baseline_model=baseline_model_resolver.outputs["model"],
        eval_config=eval_config,
        schema=schema_importer.outputs["result"],
    ).with_id("ModelEvaluator")

    exported_model_location = os.path.join(
        config.MODEL_REGISTRY_URI, config.MODEL_DISPLAY_NAME
    )
    push_destination = pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=exported_model_location
        )
    )

    # Push custom model to model registry.
    pusher = Pusher(
        model=trainer.outputs["model"],
        model_blessing=evaluator.outputs["blessing"],
        push_destination=push_destination,
    ).with_id("ModelPusher")

    pipeline_components = [
        train_example_gen,
        test_example_gen,
        schema_importer,
        statistics_gen,
        # schema_gen,
        example_validator,
        transform,
        trainer,
        baseline_model_resolver,
        evaluator,
        pusher,
    ]

    logging.info(
        "Pipeline components: %s",
        ", ".join([component.id for component in pipeline_components]),
    )

    beam_pipeline_args = config.BEAM_DIRECT_PIPELINE_ARGS
    if config.BEAM_RUNNER == "DataflowRunner":
        beam_pipeline_args = config.BEAM_DATAFLOW_PIPELINE_ARGS
    logging.info("Beam pipeline args: %s", beam_pipeline_args)

    return Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=pipeline_components,
        beam_pipeline_args=beam_pipeline_args,
        enable_cache=int(config.ENABLE_CACHE),
    )
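# A minimal sketch (not from the source) of compiling the pipeline above for
# Vertex AI Pipelines with the Kubeflow V2 dag runner. The output filename
# and the config.PIPELINE_NAME / config.PIPELINE_ROOT attributes are
# illustrative assumptions, not names confirmed by the source.
from tfx.orchestration.kubeflow.v2 import kubeflow_v2_dag_runner

runner = kubeflow_v2_dag_runner.KubeflowV2DagRunner(
    config=kubeflow_v2_dag_runner.KubeflowV2DagRunnerConfig(),
    output_filename="pipeline.json",
)
runner.run(
    create_pipeline(
        pipeline_name=config.PIPELINE_NAME,
        pipeline_root=config.PIPELINE_ROOT,
    )
)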