def testBuildLatestBlessedModelResolverSucceed(self):
  """StepBuilder emits exactly two task specs for LatestBlessedModelResolver.

  Builds a ResolverNode wrapping LatestBlessedModelResolver and compares:
    * the two PipelineTaskSpecs returned by StepBuilder.build() with
      _EXPECTED_LATEST_BLESSED_MODEL_RESOLVER_1 / _2 (in that order), and
    * the PipelineDeploymentConfig handed to the builder with
      _EXPECTED_LATEST_BLESSED_MODEL_EXECUTOR after the build.
  """
  latest_blessed_resolver = components.ResolverNode(
      instance_name='my_resolver2',
      resolver_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
      model=channel.Channel(type=standard_artifacts.Model),
      model_blessing=channel.Channel(type=standard_artifacts.ModelBlessing))
  test_pipeline_info = data_types.PipelineInfo(
      pipeline_name='test-pipeline', pipeline_root='gs://path/to/my/root')
  # Starts empty; it is compared against the expected executor config after
  # build(), so the builder is expected to fill it in place.
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()
  my_builder = step_builder.StepBuilder(
      node=latest_blessed_resolver,
      deployment_config=deployment_config,
      pipeline_info=test_pipeline_info)
  actual_step_specs = my_builder.build()

  # Pin the number of emitted specs so an unexpected extra step fails the
  # test instead of being silently ignored by the index-based checks below.
  self.assertEqual(2, len(actual_step_specs))
  self.assertProtoEquals(
      text_format.Parse(_EXPECTED_LATEST_BLESSED_MODEL_RESOLVER_1,
                        pipeline_pb2.PipelineTaskSpec()),
      actual_step_specs[0])
  self.assertProtoEquals(
      text_format.Parse(_EXPECTED_LATEST_BLESSED_MODEL_RESOLVER_2,
                        pipeline_pb2.PipelineTaskSpec()),
      actual_step_specs[1])
  self.assertProtoEquals(
      text_format.Parse(_EXPECTED_LATEST_BLESSED_MODEL_EXECUTOR,
                        pipeline_pb2.PipelineDeploymentConfig()),
      deployment_config)
def testBuildLatestArtifactResolverSucceed(self):
  """StepBuilder output for a LatestArtifactsResolver node matches golden data.

  The result of StepBuilder.build() is unwrapped with `self._sole`
  (presumably asserting a single element — confirm) and compared with
  'expected_latest_artifact_resolver.pbtxt'. The deployment config passed to
  the builder is then compared with
  'expected_latest_artifact_resolver_executor.pbtxt'.
  """
  resolver_node = components.ResolverNode(
      instance_name='my_resolver',
      resolver_class=latest_artifacts_resolver.LatestArtifactsResolver,
      model=channel.Channel(type=standard_artifacts.Model),
      examples=channel.Channel(type=standard_artifacts.Examples))
  config = pipeline_pb2.PipelineDeploymentConfig()
  info = data_types.PipelineInfo(
      pipeline_name='test-pipeline', pipeline_root='gs://path/to/my/root')

  built_specs = step_builder.StepBuilder(
      node=resolver_node, deployment_config=config, pipeline_info=info).build()
  task_spec = self._sole(built_specs)

  expected_task_spec = test_utils.get_proto_from_test_data(
      'expected_latest_artifact_resolver.pbtxt',
      pipeline_pb2.PipelineTaskSpec())
  self.assertProtoEquals(expected_task_spec, task_spec)

  expected_config = test_utils.get_proto_from_test_data(
      'expected_latest_artifact_resolver_executor.pbtxt',
      pipeline_pb2.PipelineDeploymentConfig())
  self.assertProtoEquals(expected_config, config)
def create_pipeline_components(
    pipeline_root: Text,
    transform_module: Text,
    trainer_module: Text,
    bigquery_query: Text = '',
    csv_input_location: Text = '',
) -> List[base_node.BaseNode]:
  """Creates components for a simple Chicago Taxi TFX pipeline for testing.

  Exactly one data source must be given: when `bigquery_query` is non-empty a
  BigQueryExampleGen is used, otherwise a CsvExampleGen reads from
  `csv_input_location`.

  Args:
    pipeline_root: The root of the pipeline output. The Pusher's filesystem
      destination is `<pipeline_root>/model_serving`.
    transform_module: The location of the transform module file.
    trainer_module: The location of the trainer module file.
    bigquery_query: The query to get input data from BigQuery. If not empty,
      BigQueryExampleGen will be used.
    csv_input_location: The location of the input data directory. If not
      empty, CsvExampleGen will be used.

  Returns:
    A list of TFX components that constitute an end-to-end test pipeline, in
    this order: example gen, StatisticsGen, SchemaGen, ExampleValidator,
    Transform, latest model resolver, Trainer, latest blessed model resolver,
    Evaluator, Pusher.

  Raises:
    ValueError: If both or neither of `bigquery_query` and
      `csv_input_location` are provided.
  """
  # Both empty or both set is ambiguous: exactly one example gen is allowed.
  if bool(bigquery_query) == bool(csv_input_location):
    # Adjacent literals (no comma) so the exception carries one string message.
    raise ValueError(
        'Exactly one example gen is expected. '
        'Please provide either bigquery_query or csv_input_location.')

  if bigquery_query:
    example_gen = big_query_example_gen_component.BigQueryExampleGen(
        query=bigquery_query)
  else:
    examples = dsl_utils.external_input(csv_input_location)
    example_gen = components.CsvExampleGen(input=examples)

  statistics_gen = components.StatisticsGen(
      examples=example_gen.outputs['examples'])
  schema_gen = components.SchemaGen(
      statistics=statistics_gen.outputs['statistics'],
      infer_feature_shape=False)
  example_validator = components.ExampleValidator(
      statistics=statistics_gen.outputs['statistics'],
      schema=schema_gen.outputs['schema'])
  transform = components.Transform(
      examples=example_gen.outputs['examples'],
      schema=schema_gen.outputs['schema'],
      module_file=transform_module)

  # The most recent Model artifact is fed to the Trainer as `base_model`.
  latest_model_resolver = components.ResolverNode(
      instance_name='latest_model_resolver',
      resolver_class=latest_artifacts_resolver.LatestArtifactsResolver,
      model=channel.Channel(type=standard_artifacts.Model))
  trainer = components.Trainer(
      transformed_examples=transform.outputs['transformed_examples'],
      schema=schema_gen.outputs['schema'],
      base_model=latest_model_resolver.outputs['model'],
      transform_graph=transform.outputs['transform_graph'],
      train_args=trainer_pb2.TrainArgs(num_steps=10),
      eval_args=trainer_pb2.EvalArgs(num_steps=5),
      module_file=trainer_module,
  )

  # Get the latest blessed model for model validation.
  model_resolver = components.ResolverNode(
      instance_name='latest_blessed_model_resolver',
      resolver_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
      model=channel.Channel(type=standard_artifacts.Model),
      model_blessing=channel.Channel(type=standard_artifacts.ModelBlessing))

  # Set the TFMA config for Model Evaluation and Validation.
  # NOTE(review): the threshold is keyed on 'binary_accuracy', which is not in
  # `metrics` here; presumably it comes from the model's own eval metrics
  # under the 'eval' signature — confirm.
  eval_config = tfma.EvalConfig(
      model_specs=[tfma.ModelSpec(signature_name='eval')],
      metrics_specs=[
          tfma.MetricsSpec(
              metrics=[tfma.MetricConfig(class_name='ExampleCount')],
              thresholds={
                  'binary_accuracy':
                      tfma.MetricThreshold(
                          value_threshold=tfma.GenericValueThreshold(
                              lower_bound={'value': 0.5}),
                          change_threshold=tfma.GenericChangeThreshold(
                              direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                              absolute={'value': -1e-10}))
              })
      ],
      slicing_specs=[
          tfma.SlicingSpec(),
          tfma.SlicingSpec(feature_keys=['trip_start_hour'])
      ])
  # The newly trained model is evaluated against the latest blessed model
  # (used as the baseline).
  evaluator = components.Evaluator(
      examples=example_gen.outputs['examples'],
      model=trainer.outputs['model'],
      baseline_model=model_resolver.outputs['model'],
      eval_config=eval_config)

  # The Pusher consumes the Evaluator's blessing and targets a local/GCS
  # directory under the pipeline root.
  pusher = components.Pusher(
      model=trainer.outputs['model'],
      model_blessing=evaluator.outputs['blessing'],
      push_destination=pusher_pb2.PushDestination(
          filesystem=pusher_pb2.PushDestination.Filesystem(
              base_directory=os.path.join(pipeline_root, 'model_serving'))))

  return [
      example_gen, statistics_gen, schema_gen, example_validator, transform,
      latest_model_resolver, trainer, model_resolver, evaluator, pusher
  ]