def generate_models(self, args):
  """Runs a short Chicago Taxi TFX pipeline and saves the trained models.

  Modified version of the Chicago Taxi Example pipeline
  (tfx/examples/chicago_taxi_pipeline/taxi_pipeline_beam.py). The pipeline is
  executed with the Beam DAG runner inside a fresh temporary directory; the
  resulting serving and eval saved models are then copied into the fixture
  locations reported by self.trained_saved_model_path() and
  self.tfma_saved_model_path().

  Args:
    args: unused by this method — NOTE(review): presumably part of a shared
      generator interface; confirm against callers.
  """
  # All pipeline outputs and the ML Metadata sqlite db live under one
  # throwaway temp dir.
  root = tempfile.mkdtemp()
  pipeline_root = os.path.join(root, "pipeline")
  metadata_path = os.path.join(root, "metadata/metadata.db")
  # Transform/Trainer user code is borrowed from the canonical taxi example.
  module_file = os.path.join(
      os.path.dirname(__file__),
      "../../../examples/chicago_taxi_pipeline/taxi_utils.py")

  examples = external_input(os.path.dirname(self.dataset_path()))
  example_gen = components.ImportExampleGen(input=examples)
  statistics_gen = components.StatisticsGen(
      examples=example_gen.outputs["examples"])
  schema_gen = components.SchemaGen(
      statistics=statistics_gen.outputs["statistics"],
      infer_feature_shape=False)
  transform = components.Transform(
      examples=example_gen.outputs["examples"],
      schema=schema_gen.outputs["schema"],
      module_file=module_file)
  # Deliberately tiny step counts: the goal is producing model artifacts
  # for tests, not model quality.
  trainer = components.Trainer(
      module_file=module_file,
      transformed_examples=transform.outputs["transformed_examples"],
      schema=schema_gen.outputs["schema"],
      transform_graph=transform.outputs["transform_graph"],
      train_args=trainer_pb2.TrainArgs(num_steps=100),
      eval_args=trainer_pb2.EvalArgs(num_steps=50))
  p = pipeline.Pipeline(
      pipeline_name="chicago_taxi_beam",
      pipeline_root=pipeline_root,
      components=[
          example_gen, statistics_gen, schema_gen, transform, trainer
      ],
      enable_cache=True,
      metadata_connection_config=metadata.sqlite_metadata_connection_config(
          metadata_path))
  # Blocks until the whole pipeline has run.
  BeamDagRunner().run(p)

  def join_unique_subdir(path):
    # Each output dir below is expected to contain exactly one run/export
    # subdirectory; anything else means the pipeline layout changed.
    dirs = os.listdir(path)
    if len(dirs) != 1:
      raise ValueError(
          "expecting there to be only one subdirectory in %s, but "
          "subdirectories were: %s" % (path, dirs))
    return os.path.join(path, dirs[0])

  trainer_output_dir = join_unique_subdir(
      os.path.join(pipeline_root, "Trainer/output"))
  eval_model_dir = join_unique_subdir(
      os.path.join(trainer_output_dir, "eval_model_dir"))
  serving_model_dir = join_unique_subdir(
      os.path.join(trainer_output_dir,
                   "serving_model_dir/export/chicago-taxi"))

  # Replace any previously generated fixtures with the freshly trained
  # models; ignore_errors covers the first-run case where nothing exists.
  shutil.rmtree(self.trained_saved_model_path(), ignore_errors=True)
  shutil.rmtree(self.tfma_saved_model_path(), ignore_errors=True)
  shutil.copytree(serving_model_dir, self.trained_saved_model_path())
  shutil.copytree(eval_model_dir, self.tfma_saved_model_path())
def _build_schema_gen(self, statistics: types.Channel,
                      instance_name: Optional[str]) -> tfx.SchemaGen:
  """Constructs the SchemaGen component for this pipeline.

  Args:
    statistics: Channel carrying the upstream StatisticsGen output.
    instance_name: Optional unique name for this component instance.

  Returns:
    A configured `tfx.SchemaGen` component.
  """
  # Whether feature shapes should be inferred is dictated by the
  # preprocessor this builder was configured with.
  infer_shapes = self._preprocessor.requires_inferred_feature_shapes
  return tfx.SchemaGen(
      statistics=statistics,
      infer_feature_shape=infer_shapes,
      instance_name=instance_name)
def create_e2e_components(csv_input_location: str,) -> List[BaseComponent]:
  """Creates components for a simple Chicago Taxi TFX pipeline for testing.

  Because we don't need to run the whole pipeline, this builds a very short
  toy pipeline: ingestion, statistics, and schema inference only.

  Args:
    csv_input_location: The location of the input data directory.

  Returns:
    A list of TFX components that constitutes an end-to-end test pipeline.
  """
  ingest = components.CsvExampleGen(input_base=csv_input_location)
  stats = components.StatisticsGen(examples=ingest.outputs['examples'])
  schema = components.SchemaGen(
      statistics=stats.outputs['statistics'],
      infer_feature_shape=False)
  return [ingest, stats, schema]
def create_pipeline_components(
    pipeline_root: Text,
    transform_module: Text,
    trainer_module: Text,
    bigquery_query: Text = '',
    csv_input_location: Text = '',
) -> List[base_node.BaseNode]:
  """Creates components for a simple Chicago Taxi TFX pipeline for testing.

  Args:
    pipeline_root: The root of the pipeline output.
    transform_module: The location of the transform module file.
    trainer_module: The location of the trainer module file.
    bigquery_query: The query to get input data from BigQuery. If not empty,
      BigQueryExampleGen will be used.
    csv_input_location: The location of the input data directory.

  Returns:
    A list of TFX components that constitutes an end-to-end test pipeline.

  Raises:
    ValueError: If neither or both of bigquery_query and csv_input_location
      are provided.
  """
  if bool(bigquery_query) == bool(csv_input_location):
    # BUG FIX: the two message fragments were previously separated by a
    # comma, so ValueError received two positional arguments and the
    # rendered message was a tuple repr. Adjacent literals concatenate
    # into the single intended message.
    raise ValueError(
        'Exactly one example gen is expected. '
        'Please provide either bigquery_query or csv_input_location.')

  if bigquery_query:
    example_gen = big_query_example_gen_component.BigQueryExampleGen(
        query=bigquery_query)
  else:
    example_gen = components.CsvExampleGen(input_base=csv_input_location)

  statistics_gen = components.StatisticsGen(
      examples=example_gen.outputs['examples'])
  schema_gen = components.SchemaGen(
      statistics=statistics_gen.outputs['statistics'],
      infer_feature_shape=False)
  example_validator = components.ExampleValidator(
      statistics=statistics_gen.outputs['statistics'],
      schema=schema_gen.outputs['schema'])
  transform = components.Transform(
      examples=example_gen.outputs['examples'],
      schema=schema_gen.outputs['schema'],
      module_file=transform_module)
  # Resolve the most recent model (if any) so Trainer can warm-start.
  latest_model_resolver = resolver.Resolver(
      strategy_class=latest_artifacts_resolver.LatestArtifactsResolver,
      model=channel.Channel(type=standard_artifacts.Model)).with_id(
          'Resolver.latest_model_resolver')
  # Short train/eval runs — this is a test pipeline, not a real training.
  trainer = components.Trainer(
      custom_executor_spec=executor_spec.ExecutorClassSpec(Executor),
      transformed_examples=transform.outputs['transformed_examples'],
      schema=schema_gen.outputs['schema'],
      base_model=latest_model_resolver.outputs['model'],
      transform_graph=transform.outputs['transform_graph'],
      train_args=trainer_pb2.TrainArgs(num_steps=10),
      eval_args=trainer_pb2.EvalArgs(num_steps=5),
      module_file=trainer_module,
  )
  # Get the latest blessed model for model validation.
  model_resolver = resolver.Resolver(
      strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
      model=channel.Channel(type=standard_artifacts.Model),
      model_blessing=channel.Channel(
          type=standard_artifacts.ModelBlessing)).with_id(
              'Resolver.latest_blessed_model_resolver')
  # Set the TFMA config for Model Evaluation and Validation: the candidate
  # must beat 0.5 binary accuracy and not regress vs. the baseline (the
  # tiny negative absolute threshold tolerates float noise).
  eval_config = tfma.EvalConfig(
      model_specs=[tfma.ModelSpec(signature_name='eval')],
      metrics_specs=[
          tfma.MetricsSpec(
              metrics=[tfma.MetricConfig(class_name='ExampleCount')],
              thresholds={
                  'binary_accuracy':
                      tfma.MetricThreshold(
                          value_threshold=tfma.GenericValueThreshold(
                              lower_bound={'value': 0.5}),
                          change_threshold=tfma.GenericChangeThreshold(
                              direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                              absolute={'value': -1e-10}))
              })
      ],
      slicing_specs=[
          tfma.SlicingSpec(),
          tfma.SlicingSpec(feature_keys=['trip_start_hour'])
      ])
  evaluator = components.Evaluator(
      examples=example_gen.outputs['examples'],
      model=trainer.outputs['model'],
      baseline_model=model_resolver.outputs['model'],
      eval_config=eval_config)
  pusher = components.Pusher(
      model=trainer.outputs['model'],
      model_blessing=evaluator.outputs['blessing'],
      push_destination=pusher_pb2.PushDestination(
          filesystem=pusher_pb2.PushDestination.Filesystem(
              base_directory=os.path.join(pipeline_root, 'model_serving'))))
  return [
      example_gen, statistics_gen, schema_gen, example_validator, transform,
      latest_model_resolver, trainer, model_resolver, evaluator, pusher
  ]