def testChannelParameterType(self):
    """Exercises ChannelParameter.type_check with accepted and rejected values."""

    class _FooArtifact(artifact.Artifact):
        TYPE_NAME = 'FooArtifact'

    class _BarArtifact(artifact.Artifact):
        TYPE_NAME = 'BarArtifact'

    arg_name = 'foo'
    foo_param = ChannelParameter(type=_FooArtifact)

    # A channel of the declared artifact type must pass without raising.
    matching_channel = channel.Channel(type=_FooArtifact)
    foo_param.type_check(arg_name, matching_channel)

    # A plain int is not a channel; the error must mention the argument name.
    with self.assertRaisesRegex(TypeError, arg_name):
        foo_param.type_check(arg_name, 42)

    # A channel of another artifact type is expected to be rejected ...
    with self.assertRaisesRegex(TypeError, arg_name):
        foo_param.type_check(arg_name, channel.Channel(type=_BarArtifact))

    # ... but passes once that type is listed under COMPATIBLE_TYPES_KEY.
    setattr(_FooArtifact, component_spec.COMPATIBLE_TYPES_KEY, {_BarArtifact})
    foo_param.type_check(arg_name, channel.Channel(type=_BarArtifact))
def testBuildLatestBlessedModelResolverSucceed(self):
    """Builds steps for a latest-blessed-model ResolverNode and checks the protos."""
    resolver_node = components.ResolverNode(
        instance_name='my_resolver2',
        resolver_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
        model=channel.Channel(type=standard_artifacts.Model),
        model_blessing=channel.Channel(type=standard_artifacts.ModelBlessing))
    pipeline_info = data_types.PipelineInfo(
        pipeline_name='test-pipeline', pipeline_root='gs://path/to/my/root')
    config = pipeline_pb2.PipelineDeploymentConfig()

    built_specs = step_builder.StepBuilder(
        node=resolver_node,
        deployment_config=config,
        pipeline_info=pipeline_info).build()

    # The task specs are compared positionally against their expected text.
    expected_task_texts = (
        _EXPECTED_LATEST_BLESSED_MODEL_RESOLVER_1,
        _EXPECTED_LATEST_BLESSED_MODEL_RESOLVER_2,
    )
    for position, expected_text in enumerate(expected_task_texts):
        self.assertProtoEquals(
            text_format.Parse(expected_text, pipeline_pb2.PipelineTaskSpec()),
            built_specs[position])

    # The deployment config handed to the builder is checked last.
    self.assertProtoEquals(
        text_format.Parse(_EXPECTED_LATEST_BLESSED_MODEL_EXECUTOR,
                          pipeline_pb2.PipelineDeploymentConfig()),
        config)
def testBuildLatestArtifactResolverSucceed(self):
    """Builds a latest-artifacts Resolver step and compares it to golden files."""
    resolver_node = resolver.Resolver(
        instance_name='my_resolver',
        strategy_class=latest_artifacts_resolver.LatestArtifactsResolver,
        model=channel.Channel(type=standard_artifacts.Model),
        examples=channel.Channel(type=standard_artifacts.Examples))
    config = pipeline_pb2.PipelineDeploymentConfig()
    component_defs = {}
    pipeline_info = data_types.PipelineInfo(
        pipeline_name='test-pipeline', pipeline_root='gs://path/to/my/root')

    builder = step_builder.StepBuilder(
        node=resolver_node,
        deployment_config=config,
        pipeline_info=pipeline_info,
        component_defs=component_defs)
    step_spec = self._sole(builder.build())
    component_def = self._sole(component_defs)

    # (golden file, message class to parse into, actual proto) in check order.
    golden_checks = (
        ('expected_latest_artifact_resolver_component.pbtxt',
         pipeline_pb2.ComponentSpec, component_def),
        ('expected_latest_artifact_resolver_task.pbtxt',
         pipeline_pb2.PipelineTaskSpec, step_spec),
        ('expected_latest_artifact_resolver_executor.pbtxt',
         pipeline_pb2.PipelineDeploymentConfig, config),
    )
    for golden_file, message_cls, actual in golden_checks:
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(golden_file, message_cls()),
            actual)
def testComponentSpec_JsonProto(self):
    """A JSON string passed as the proto parameter is kept as that same string."""
    json_text = '{"splits": [{"name": "name1", "pattern": "pattern1"}]}'
    in_channel = channel.Channel(type=_InputArtifact)
    out_channel = channel.Channel(type=_OutputArtifact)

    spec = _BasicComponentSpec(
        folds=10, proto=json_text, input=in_channel, output=out_channel)

    stored_proto = spec.exec_properties['proto']
    self.assertIsInstance(stored_proto, str)
    self.assertEqual(stored_proto, json_text)
def testValidUnionChannel(self):
    """Checks that channel.union keeps the type name and flattens nested unions."""
    channel1 = channel.Channel(type=_MyType)
    channel2 = channel.Channel(type=_MyType)

    union_channel = channel.union([channel1, channel2])
    # Compare the type name by value: assertIs tests object identity, which
    # for strings only holds when the interpreter happens to share the object.
    self.assertEqual(union_channel.type_name, 'MyTypeName')
    self.assertEqual(union_channel.channels, [channel1, channel2])

    # A union nested inside another union is expected to yield the same flat
    # list of member channels.
    union_channel = channel.union([channel1, channel.union([channel2])])
    self.assertEqual(union_channel.type_name, 'MyTypeName')
    self.assertEqual(union_channel.channels, [channel1, channel2])
def testBuildLatestBlessedModelResolverSucceed(self):
    """Builds latest-blessed-model Resolver steps and compares to golden files."""
    blessed_resolver = resolver.Resolver(
        instance_name='my_resolver2',
        strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
        model=channel.Channel(type=standard_artifacts.Model),
        model_blessing=channel.Channel(type=standard_artifacts.ModelBlessing))
    pipeline_info = data_types.PipelineInfo(
        pipeline_name='test-pipeline', pipeline_root='gs://path/to/my/root')
    config = pipeline_pb2.PipelineDeploymentConfig()
    component_defs = {}

    step_specs = step_builder.StepBuilder(
        node=blessed_resolver,
        deployment_config=config,
        pipeline_info=pipeline_info,
        component_defs=component_defs).build()

    blessing_id = 'Resolver.my_resolver2-model-blessing-resolver'
    model_id = 'Resolver.my_resolver2-model-resolver'
    self.assertSameElements(step_specs.keys(), [blessing_id, model_id])

    # (golden file, message class, mapping to look in, key) in check order.
    keyed_checks = (
        ('expected_latest_blessed_model_resolver_component_1.pbtxt',
         pipeline_pb2.ComponentSpec, component_defs, blessing_id),
        ('expected_latest_blessed_model_resolver_task_1.pbtxt',
         pipeline_pb2.PipelineTaskSpec, step_specs, blessing_id),
        ('expected_latest_blessed_model_resolver_component_2.pbtxt',
         pipeline_pb2.ComponentSpec, component_defs, model_id),
        ('expected_latest_blessed_model_resolver_task_2.pbtxt',
         pipeline_pb2.PipelineTaskSpec, step_specs, model_id),
    )
    for golden_file, message_cls, mapping, node_id in keyed_checks:
        self.assertProtoEquals(
            test_utils.get_proto_from_test_data(golden_file, message_cls()),
            mapping[node_id])

    self.assertProtoEquals(
        test_utils.get_proto_from_test_data(
            'expected_latest_blessed_model_resolver_executor.pbtxt',
            pipeline_pb2.PipelineDeploymentConfig()),
        config)
def testComponentSpec_Basic(self):
    """Covers exec properties, channel wiring and type errors of a basic spec."""
    split_messages = [
        example_gen_pb2.Input.Split(name=split_name, pattern=split_pattern)
        for split_name, split_pattern in (
            ('name1', 'pattern1'),
            ('name2', 'pattern2'),
            ('name3', 'pattern3'),
        )
    ]
    proto = example_gen_pb2.Input()
    proto.splits.extend(split_messages)

    input_channel = channel.Channel(type=_InputArtifact)
    output_channel = channel.Channel(type=_OutputArtifact)
    spec = _BasicComponentSpec(
        folds=10, proto=proto, input=input_channel, output=output_channel)

    # The proto parameter is stored as a JSON string; decode it to inspect.
    self.assertIsInstance(spec.exec_properties['proto'], str)
    decoded = json.loads(spec.exec_properties['proto'])
    self.assertCountEqual(['splits'], decoded.keys())
    decoded_splits = decoded['splits']
    self.assertLen(decoded_splits, 3)
    self.assertCountEqual(['name1', 'name2', 'name3'],
                          [entry['name'] for entry in decoded_splits])
    self.assertCountEqual(['pattern1', 'pattern2', 'pattern3'],
                          [entry['pattern'] for entry in decoded_splits])

    # Remaining properties: the very same channel objects must be exposed.
    self.assertEqual(10, spec.exec_properties['folds'])
    self.assertIs(spec.inputs['input'], input_channel)
    self.assertIs(spec.outputs['output'], output_channel)

    # Wrong primitive type for an execution parameter.
    with self.assertRaisesRegex(
        TypeError,
        "Expected type <(class|type) 'int'> for parameter u?'folds' but got "
        'string.'):
        _BasicComponentSpec(
            folds='string', input=input_channel, output=output_channel)

    # Wrong artifact type on the input channel.
    with self.assertRaisesRegex(
        TypeError,
        '.*should be a Channel of .*InputArtifact.*got (.|\\s)*Examples.*'):
        _BasicComponentSpec(
            folds=10,
            input=channel.Channel(type=Examples),
            output=output_channel)

    # Wrong artifact type on the output channel.
    with self.assertRaisesRegex(
        TypeError,
        '.*should be a Channel of .*OutputArtifact.*got (.|\\s)*Examples.*'):
        _BasicComponentSpec(
            folds=10,
            input=input_channel,
            output=channel.Channel(type=Examples))
def testGetInidividualChannels(self):
    """get_individual_channels returns the member channels of its argument."""
    first_channel = channel.Channel(_MyArtifact).set_artifacts([_MyArtifact()])
    second_channel = channel.Channel(_MyArtifact).set_artifacts([_MyArtifact()])

    # A plain channel comes back as a one-element list.
    self.assertEqual(
        channel_utils.get_individual_channels(first_channel),
        [first_channel])

    # A union comes back as the list of channels it was built from.
    merged = channel.union([first_channel, second_channel])
    self.assertEqual(
        channel_utils.get_individual_channels(merged),
        [first_channel, second_channel])
def testExecutionParameterUseProto(self):
    """Checks exec_properties values for parameters declared with use_proto."""

    class SpecWithNonPrimitiveTypes(ComponentSpec):
        PARAMETERS = {
            'config_proto':
                ExecutionParameter(type=example_gen_pb2.Input, use_proto=True),
            'boolean':
                ExecutionParameter(type=bool, use_proto=True),
            'list_config_proto':
                ExecutionParameter(
                    type=List[example_gen_pb2.Input], use_proto=True),
            'list_boolean':
                ExecutionParameter(type=List[bool], use_proto=True),
        }
        INPUTS = {'input': ChannelParameter(type=_InputArtifact)}
        OUTPUTS = {'output': ChannelParameter(type=_OutputArtifact)}

    trainer_input = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='trainer', pattern='train.data')
    ])
    eval_input = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='eval', pattern='*eval.data')
    ])
    spec = SpecWithNonPrimitiveTypes(
        config_proto='{"splits": [{"name": "name", "pattern": "pattern"}]}',
        boolean=True,
        list_config_proto=[trainer_input, eval_input],
        list_boolean=[False, True],
        input=channel.Channel(type=_InputArtifact),
        output=channel.Channel(type=_OutputArtifact))

    # With use_proto=True the JSON string is expected back as a parsed message.
    expected_proto = text_format.Parse(
        """ splits { name: "name" pattern: "pattern" } """,
        example_gen_pb2.Input())
    properties = spec.exec_properties
    self.assertProtoEquals(expected_proto, properties['config_proto'])
    self.assertEqual(True, properties['boolean'])
    self.assertIsInstance(properties['list_config_proto'], list)
    self.assertEqual(properties['list_boolean'], [False, True])
def testComponentSpec_WithUnionChannel(self):
    """A union channel can be passed as a component spec input."""
    first_input = channel.Channel(type=_InputArtifact)
    second_input = channel.Channel(type=_InputArtifact)
    output_channel = channel.Channel(type=_OutputArtifact)
    combined_input = channel.union([first_input, second_input])

    spec = _BasicComponentSpec(
        folds=10, input=combined_input, output=output_channel)

    self.assertEqual(10, spec.exec_properties['folds'])
    # The stored input exposes the union's type and its member channels.
    stored_input = spec.inputs['input']
    self.assertEqual(stored_input.type, _InputArtifact)
    self.assertEqual(stored_input.channels, [first_input, second_input])
    self.assertIs(spec.outputs['output'], output_channel)
def testFutureProducesPlaceholder(self):
    """Channel.future() and what is derived from it are wrapped placeholders."""
    source_channel = channel.Channel(type=_MyType)
    future = source_channel.future()
    wrapped_cls = placeholder.ChannelWrappedPlaceholder

    self.assertIsInstance(future, wrapped_cls)
    # The placeholder keeps a reference to the channel that produced it.
    self.assertIs(future.channel, source_channel)
    # Indexing and `.value` must also yield wrapped placeholders.
    self.assertIsInstance(future[0], wrapped_cls)
    self.assertIsInstance(future.value, wrapped_cls)
def testComponentSpec_MissingArguments(self):
    """Omitting a required parameter or input raises; optional ones may be left out."""

    class SimpleComponentSpec(ComponentSpec):
        PARAMETERS = {
            'x': ExecutionParameter(type=int),
            'y': ExecutionParameter(type=int, optional=True),
        }
        INPUTS = {'z': ChannelParameter(type=_Z)}
        OUTPUTS = {}

    missing_argument = 'Missing argument'

    # Input channel `z` is not supplied.
    with self.assertRaisesRegex(ValueError, missing_argument):
        SimpleComponentSpec(x=10)

    # Execution parameter `x` is not supplied.
    with self.assertRaisesRegex(ValueError, missing_argument):
        SimpleComponentSpec(z=channel.Channel(type=_Z))

    # Only the optional `y` is left out, so construction must succeed.
    SimpleComponentSpec(x=10, z=channel.Channel(type=_Z))
def testBuildOutputArtifactSpec(self):
    """Checks build_output_artifact_spec for a populated and a bare channel."""
    artifact_spec_cls = pipeline_pb2.ComponentOutputsSpec.ArtifactSpec

    # Channel carrying one Examples artifact with a span and custom properties.
    examples_artifact = standard_artifacts.Examples()
    examples_artifact.span = 1
    examples_artifact.set_int_custom_property(key='int_param', value=42)
    examples_artifact.set_string_custom_property(key='str_param', value='42')
    populated_channel = channel.Channel(type=standard_artifacts.Examples)
    populated_channel = populated_channel.set_artifacts([examples_artifact])

    actual = compiler_utils.build_output_artifact_spec(populated_channel)
    expected = text_format.Parse(
        """ artifact_type { instance_schema: "title: tfx.Examples\\ntype: object\\nproperties:\\n span:\\n type: integer\\n description: Span for an artifact.\\n version:\\n type: integer\\n description: Version for an artifact.\\n split_names:\\n type: string\\n description: JSON-encoded list of splits for an artifact. Empty string means artifact has no split.\\n" } metadata { fields { key: "int_param" value { number_value: 42.0 } } fields { key: "span" value { number_value: 1.0 } } fields { key: "str_param" value { string_value: "42" } } } """,
        artifact_spec_cls())
    self.assertProtoEquals(actual, expected)

    # Channel that only declares a type and holds no artifacts.
    bare_channel = channel.Channel(type=standard_artifacts.Model)
    actual = compiler_utils.build_output_artifact_spec(bare_channel)
    expected = text_format.Parse(
        """ artifact_type { instance_schema: "title: tfx.Model\\ntype: object\\n" } """,
        artifact_spec_cls())
    self.assertProtoEquals(actual, expected)
def testBuildInputArtifactSpec(self):
    """Checks build_input_artifact_spec for types without and with properties."""
    artifact_spec_cls = pipeline_pb2.ComponentInputsSpec.ArtifactSpec

    # Model: the expected schema lists no properties.
    model_channel = channel.Channel(type=standard_artifacts.Model)
    actual = compiler_utils.build_input_artifact_spec(model_channel)
    expected = text_format.Parse(
        r'artifact_type { instance_schema: "title: tfx.Model\ntype: object\n" }',
        artifact_spec_cls())
    self.assertProtoEquals(actual, expected)

    # Examples: the expected schema lists span, version and split_names.
    examples_channel = channel.Channel(type=standard_artifacts.Examples)
    actual = compiler_utils.build_input_artifact_spec(examples_channel)
    expected = text_format.Parse(
        """ artifact_type { instance_schema: "title: tfx.Examples\\ntype: object\\nproperties:\\n span:\\n type: integer\\n description: Span for an artifact.\\n version:\\n type: integer\\n description: Version for an artifact.\\n split_names:\\n type: string\\n description: JSON-encoded list of splits for an artifact. Empty string means artifact has no split.\\n" } """,
        artifact_spec_cls())
    self.assertProtoEquals(actual, expected)
def testUnwrapChannelDict(self):
    """unwrap_channel_dict maps each key to the artifacts held by its channel."""
    held_artifacts = [_MyArtifact(), _MyArtifact()]
    wrapped = {
        'id': channel.Channel(_MyArtifact).set_artifacts(
            [held_artifacts[0], held_artifacts[1]]),
    }

    unwrapped = channel_utils.unwrap_channel_dict(wrapped)

    self.assertDictEqual(
        unwrapped, {'id': [held_artifacts[0], held_artifacts[1]]})
def testJsonRoundTripUnknownArtifactClass(self):
    """Deserializing a channel whose type name is unknown still yields a type."""
    original = channel.Channel(type=_MyType)
    payload = original.to_json_dict()
    # Swap in a type name the test does not define before rehydrating.
    payload['type']['name'] = 'UnknownTypeName'

    restored = channel.Channel.from_json_dict(payload)

    self.assertEqual('UnknownTypeName', restored.type_name)
    original_properties = original.type._get_artifact_type().properties
    restored_properties = restored.type._get_artifact_type().properties
    self.assertEqual(original_properties, restored_properties)
    self.assertTrue(restored.type._AUTOGENERATED)
def testOptionalOutputs(self):
    """An optional output may be omitted and is reported as optional."""

    class SpecWithOptionalOutput(ComponentSpec):
        PARAMETERS = {}
        INPUTS = {}
        OUTPUTS = {'x': ChannelParameter(type=_Z, optional=True)}

    # Output left out: absent from `outputs`, but known to be optional.
    without_output = SpecWithOptionalOutput()
    self.assertNotIn('x', without_output.outputs.keys())
    self.assertTrue(without_output.is_optional_output('x'))

    # Output supplied: present in `outputs`.
    with_output = SpecWithOptionalOutput(x=channel.Channel(type=_Z))
    self.assertIn('x', with_output.outputs.keys())
def testCompileDynamicExecPropTypeError(self):
    """Compiling fails when a dynamic exec property comes from a non-value channel."""
    dsl_compiler = compiler.Compiler()
    test_pipeline = self._get_test_pipeline_definition(
        dynamic_exec_properties_pipeline)
    downstream_cls = dynamic_exec_properties_pipeline.DownstreamComponent
    downstream_component = next(
        node for node in test_pipeline.components
        if isinstance(node, downstream_cls))

    # Replace the exec property with a future of a channel of _MyType.
    first_instance = _MyType()
    second_instance = _MyType()
    wrong_type_channel = channel.Channel(_MyType).set_artifacts(
        [first_instance, second_instance])
    downstream_component.exec_properties["input_num"] = (
        wrong_type_channel.future())

    with self.assertRaisesRegex(
        ValueError,
        "output channel to dynamic exec properties is not ValueArtifact"):
        dsl_compiler.compile(test_pipeline)
def testJsonRoundTrip(self):
    """to_json_dict followed by from_json_dict preserves type and properties."""
    extra_properties = {
        'string_value': metadata_store_pb2.Value(string_value='forty-two')
    }
    extra_custom_properties = {
        'int_value': metadata_store_pb2.Value(int_value=42)
    }
    original = channel.Channel(
        type=_MyType,
        additional_properties=extra_properties,
        additional_custom_properties=extra_custom_properties)

    restored = channel.Channel.from_json_dict(original.to_json_dict())

    # The known artifact class itself is recovered, not just an equal one.
    self.assertIs(original.type, restored.type)
    self.assertEqual(original.type_name, restored.type_name)
    self.assertEqual(original.additional_properties,
                     restored.additional_properties)
    self.assertEqual(original.additional_custom_properties,
                     restored.additional_custom_properties)
def as_channel(artifacts: Iterable[artifact.Artifact]) -> channel.Channel:
    """Converts artifact collection of the same artifact type into a Channel.

    The channel type is taken from the first element; only that element is
    type-checked here.

    Args:
      artifacts: An iterable of Artifact. One-shot iterators (e.g. generators)
        are supported: the input is materialized into a list exactly once.

    Returns:
      A static Channel containing the source artifact collection.

    Raises:
      ValueError: If the iterable is empty, or if its first element is not an
        Artifact.
    """
    # Materialize up front. Probing with next(iter(...)) would consume the
    # first element of a one-shot iterator, so the channel would silently be
    # populated without it.
    artifact_list = list(artifacts)
    if not artifact_list:
        raise ValueError('Cannot convert empty artifact iterable into Channel')

    first_element = artifact_list[0]
    if not isinstance(first_element, artifact.Artifact):
        raise ValueError('Invalid artifact iterable: {}'.format(artifacts))

    return channel.Channel(type=first_element.type).set_artifacts(
        artifact_list)
def _make_channel_dict(
        artifact_dict: Dict[Text, Text]) -> Dict[Text, channel.Channel]:
    """Builds one channel per named artifact list.

    NOTE(review): the parameter is annotated as Dict[Text, Text], but each
    value is indexed with [0] and read for `.type_name`, so the values look
    like non-empty lists of artifacts -- confirm and fix the annotation.

    Args:
      artifact_dict: Mapping from input/output name to its artifacts.

    Returns:
      Mapping from the same names to a Channel built from the first artifact's
      type name and the full artifact list.

    Raises:
      RuntimeError: If any name maps to an empty (falsy) artifact list.
    """
    channels_by_name = {}
    for name in artifact_dict:
        members = artifact_dict[name]
        if members:
            channels_by_name[name] = channel.Channel(
                type_name=members[0].type_name, artifacts=members)
            continue
        raise RuntimeError(
            'Found empty list of artifacts for input/output named {}: {}'.
            format(name, members))
    return channels_by_name
def testInvalidChannelType(self):
    """Setting artifacts of another type on a channel raises ValueError."""
    mismatched_artifacts = [_MyType(), _MyType()]
    with self.assertRaises(ValueError):
        channel.Channel(_AnotherType).set_artifacts(mismatched_artifacts)
def setUp(self):
    """Stores a fresh channel of _MyArtifactWithProperty on the test instance."""
    super().setUp()
    fixture_channel = channel.Channel(type=_MyArtifactWithProperty)
    self._test_channel = fixture_channel
def testExecutionParameterTypeCheck(self):
    """Exercises ExecutionParameter.type_check for each supported kind of type."""
    # --- int ---
    int_param = ExecutionParameter(type=int)
    int_param.type_check('int_parameter', 8)
    with self.assertRaisesRegex(
        TypeError,
        "Expected type <(class|type) 'int'>"
        " for parameter u?'int_parameter'"):
        int_param.type_check('int_parameter', 'string')

    # --- List[int] ---
    list_param = ExecutionParameter(type=List[int])
    for accepted_list in ([], [42]):
        list_param.type_check('list_parameter', accepted_list)
    with self.assertRaisesRegex(TypeError, 'Expecting a list for parameter'):
        list_param.type_check('list_parameter', 42)
    with self.assertRaisesRegex(
        TypeError,
        "Expecting item type <(class|type) "
        "'int'> for parameter u?'list_parameter'"):
        list_param.type_check('list_parameter', [42, 'wrong item'])

    # --- Dict[str, int] ---
    dict_param = ExecutionParameter(type=Dict[str, int])
    for accepted_dict in ({}, {'key1': 1, 'key2': 2}):
        dict_param.type_check('dict_parameter', accepted_dict)
    with self.assertRaisesRegex(TypeError, 'Expecting a dict for parameter'):
        dict_param.type_check('dict_parameter', 'simple string')
    with self.assertRaisesRegex(
        TypeError,
        "Expecting value type "
        "<(class|type) 'int'>"):
        dict_param.type_check('dict_parameter', {'key1': '1'})

    # --- proto message: message instance, JSON string and dicts pass ---
    proto_param = ExecutionParameter(type=example_gen_pb2.Input)
    proto_param.type_check('proto_parameter', example_gen_pb2.Input())
    proto_param.type_check(
        'proto_parameter',
        proto_utils.proto_to_json(example_gen_pb2.Input()))
    proto_param.type_check('proto_parameter',
                           {'splits': [{'name': 'hello'}]})
    proto_param.type_check('proto_parameter', {'wrong_field': 42})
    with self.assertRaisesRegex(
        TypeError,
        "Expected type <class 'tfx.proto.example_gen_pb2.Input'>"):
        proto_param.type_check('proto_parameter', 42)
    with self.assertRaises(json_format.ParseError):
        proto_param.type_check('proto_parameter', {'splits': 42})

    # --- placeholders given for a str parameter ---
    output_channel = channel.Channel(type=_OutputArtifact)
    str_param = ExecutionParameter(type=str)
    str_param.type_check('wrapped_channel_placeholder_parameter',
                         output_channel.future()[0].value)
    str_param.type_check(
        'placeholder_parameter',
        placeholder.runtime_info('platform_config').base_dir)
    with self.assertRaisesRegex(
        TypeError, 'Only simple RuntimeInfoPlaceholders are supported'):
        # The concatenation is built inside the context, as it was before.
        str_param.type_check(
            'placeholder_parameter',
            placeholder.runtime_info('platform_config').base_dir +
            placeholder.exec_property('version'))
def create_pipeline_components(
    pipeline_root: Text,
    transform_module: Text,
    trainer_module: Text,
    bigquery_query: Text = '',
    csv_input_location: Text = '',
) -> List[base_node.BaseNode]:
    """Creates components for a simple Chicago Taxi TFX pipeline for testing.

    Args:
      pipeline_root: The root of the pipeline output.
      transform_module: The location of the transform module file.
      trainer_module: The location of the trainer module file.
      bigquery_query: The query to get input data from BigQuery. If not empty,
        BigQueryExampleGen will be used.
      csv_input_location: The location of the input data directory.

    Returns:
      A list of TFX components that constitutes an end-to-end test pipeline.

    Raises:
      ValueError: If both or neither of `bigquery_query` and
        `csv_input_location` are provided.
    """
    # Exactly one of the two sources must be truthy.
    if bool(bigquery_query) == bool(csv_input_location):
        # The two literals are concatenated into ONE message. A comma between
        # them would pass two args and render the error as a tuple.
        raise ValueError(
            'Exactly one example gen is expected. '
            'Please provide either bigquery_query or csv_input_location.')

    if bigquery_query:
        example_gen = big_query_example_gen_component.BigQueryExampleGen(
            query=bigquery_query)
    else:
        example_gen = components.CsvExampleGen(input_base=csv_input_location)

    statistics_gen = components.StatisticsGen(
        examples=example_gen.outputs['examples'])
    schema_gen = components.SchemaGen(
        statistics=statistics_gen.outputs['statistics'],
        infer_feature_shape=False)
    example_validator = components.ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])
    transform = components.Transform(
        examples=example_gen.outputs['examples'],
        schema=schema_gen.outputs['schema'],
        module_file=transform_module)

    # Resolves the latest model; its output is the trainer's base model.
    latest_model_resolver = resolver.Resolver(
        strategy_class=latest_artifacts_resolver.LatestArtifactsResolver,
        model=channel.Channel(type=standard_artifacts.Model)).with_id(
            'Resolver.latest_model_resolver')

    trainer = components.Trainer(
        custom_executor_spec=executor_spec.ExecutorClassSpec(Executor),
        transformed_examples=transform.outputs['transformed_examples'],
        schema=schema_gen.outputs['schema'],
        base_model=latest_model_resolver.outputs['model'],
        transform_graph=transform.outputs['transform_graph'],
        train_args=trainer_pb2.TrainArgs(num_steps=10),
        eval_args=trainer_pb2.EvalArgs(num_steps=5),
        module_file=trainer_module,
    )

    # Get the latest blessed model for model validation.
    model_resolver = resolver.Resolver(
        strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
        model=channel.Channel(type=standard_artifacts.Model),
        model_blessing=channel.Channel(
            type=standard_artifacts.ModelBlessing)).with_id(
                'Resolver.latest_blessed_model_resolver')

    # Set the TFMA config for Model Evaluation and Validation.
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(signature_name='eval')],
        metrics_specs=[
            tfma.MetricsSpec(
                metrics=[tfma.MetricConfig(class_name='ExampleCount')],
                thresholds={
                    'binary_accuracy':
                        tfma.MetricThreshold(
                            value_threshold=tfma.GenericValueThreshold(
                                lower_bound={'value': 0.5}),
                            change_threshold=tfma.GenericChangeThreshold(
                                direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                                absolute={'value': -1e-10}))
                })
        ],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=['trip_start_hour'])
        ])

    evaluator = components.Evaluator(
        examples=example_gen.outputs['examples'],
        model=trainer.outputs['model'],
        baseline_model=model_resolver.outputs['model'],
        eval_config=eval_config)

    pusher = components.Pusher(
        model=trainer.outputs['model'],
        model_blessing=evaluator.outputs['blessing'],
        push_destination=pusher_pb2.PushDestination(
            filesystem=pusher_pb2.PushDestination.Filesystem(
                base_directory=os.path.join(pipeline_root, 'model_serving'))))

    return [
        example_gen, statistics_gen, schema_gen, example_validator, transform,
        latest_model_resolver, trainer, model_resolver, evaluator, pusher
    ]
def testEncodeTemplatedExecutorContainerSpec(self):
    """Encodes a templated container spec and compares it to the expected proto."""
    command_parts = [
        self._text,
        self._input_value_placeholder,
        self._another_input_value_placeholder,
        self._input_uri_placeholder,
        self._output_uri_placeholder,
        self._concat_placeholder,
    ]
    container_spec = executor_specs.TemplatedExecutorContainerSpec(
        image='image', command=command_parts)

    # Component spec the command placeholders are encoded against.
    component_spec = TestComponentSpec(
        input_artifact=channel.Channel(type=standard_artifacts.Examples),
        output_artifact=channel.Channel(type=standard_artifacts.Model),
        input_parameter=42)
    encoded = container_spec.encode(component_spec=component_spec)

    self.assertProtoEquals(
        """ image: "image" commands { value { string_value: "text" } } commands { operator { artifact_value_op { expression { operator { index_op { expression { placeholder { key: "input_artifact" } } } } } } } } commands { placeholder { type: EXEC_PROPERTY key: "input_parameter" } } commands { operator { artifact_uri_op { expression { operator { index_op { expression { placeholder { key: "input_artifact" } } index: 0 } } } } } } commands { operator { artifact_uri_op { expression { operator { index_op { expression { placeholder { type: OUTPUT_ARTIFACT key: "output_artifact" } } index: 0 } } } } } } commands { operator { concat_op { expressions { value { string_value: "text" } } expressions { operator { artifact_value_op { expression { operator { index_op { expression { placeholder { key: "input_artifact" } } index: 0 } } } } } } expressions { operator { artifact_uri_op { expression { operator { index_op { expression { placeholder { key: "input_artifact" } } index: 0 } } } } } } expressions { operator { artifact_uri_op { expression { operator { index_op { expression { placeholder { type: OUTPUT_ARTIFACT key: "output_artifact" } } index: 0 } } } } } } } } }""",
        encoded)
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     metadata_path: Text,
                     direct_num_workers: int) -> pipeline.Pipeline:
    """Assembles a pipeline that augments training examples with graph neighbors.

    NOTE(review): `data_root` and `module_file` are never read in this body.
    The input location comes from `examples_path` and the module files from
    `_transform_module_file` / `_trainer_module_file`, all resolved outside
    this function -- confirm that this is intended.

    Args:
      pipeline_name: Forwarded to the Pipeline as `pipeline_name`.
      pipeline_root: Forwarded to the Pipeline as `pipeline_root`.
      data_root: Unused here (see note above).
      module_file: Unused here (see note above).
      serving_model_dir: Base directory of the Pusher's filesystem destination.
      metadata_path: Path given to the sqlite metadata connection config.
      direct_num_workers: Value of the `--direct_num_workers` Beam argument.

    Returns:
      The assembled pipeline.Pipeline with caching enabled.
    """
    train_eval_splits = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='train', pattern='train.tfrecord'),
        example_gen_pb2.Input.Split(name='eval', pattern='eval.tfrecord')
    ])
    example_gen = ImportExampleGen(
        input=external_input(examples_path), input_config=train_eval_splits)

    identify_examples = IdentifyExamples(
        orig_examples=example_gen.outputs['examples'],
        component_name=u'IdentifyExamples',
        id_feature_name=u'id')

    # Statistics over the identified examples, then a schema from them.
    statistics_gen = StatisticsGen(
        examples=identify_examples.outputs["identified_examples"])
    schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'])

    # Anomaly detection based on the statistics and the schema.
    validate_stats = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])

    synthesize_graph = SynthesizeGraph(
        identified_examples=identify_examples.outputs['identified_examples'],
        component_name=u'SynthesizeGraph',
        similarity_threshold=0.99)

    # TODO(b/169218106): Remove transformed_examples kwargs after bugfix is
    # released.
    transformed_examples_channel = channel.Channel(
        type=standard_artifacts.Examples,
        artifacts=[standard_artifacts.Examples()])
    transform = Transform(
        examples=identify_examples.outputs['identified_examples'],
        schema=schema_gen.outputs['schema'],
        transformed_examples=transformed_examples_channel,
        module_file=_transform_module_file)

    # Augment the transformed examples with graph neighbors.
    graph_augmentation = GraphAugmentation(
        identified_examples=transform.outputs['transformed_examples'],
        synthesized_graph=synthesize_graph.outputs['synthesized_graph'],
        component_name=u'GraphAugmentation',
        num_neighbors=3)

    trainer = Trainer(
        module_file=_trainer_module_file,
        transformed_examples=graph_augmentation.outputs['augmented_examples'],
        schema=schema_gen.outputs['schema'],
        transform_graph=transform.outputs['transform_graph'],
        train_args=trainer_pb2.TrainArgs(num_steps=10000),
        eval_args=trainer_pb2.EvalArgs(num_steps=5000))

    model_validator = ModelValidator(
        examples=example_gen.outputs['examples'],
        model=trainer.outputs['model'])

    serving_destination = pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=serving_model_dir))
    pusher = Pusher(
        model=trainer.outputs['model'],
        model_blessing=model_validator.outputs['blessing'],
        push_destination=serving_destination)

    all_components = [
        example_gen, identify_examples, statistics_gen, schema_gen,
        validate_stats, synthesize_graph, transform, graph_augmentation,
        trainer, model_validator, pusher
    ]
    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=all_components,
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        beam_pipeline_args=['--direct_num_workers=%d' % direct_num_workers])
def testMismatchedUnionChannelType(self):
    """A union of channels with different artifact types raises TypeError."""
    mixed_channels = [
        channel.Channel(type=_MyType),
        channel.Channel(type=_AnotherType),
    ]
    with self.assertRaises(TypeError):
        channel.union(mixed_channels)
def testStringTypeNameNotAllowed(self):
    """Passing a string as the channel type raises ValueError."""
    type_given_as_string = 'StringTypeName'
    with self.assertRaises(ValueError):
        channel.Channel(type_given_as_string)
def testValidChannel(self):
    """A channel populated with matching artifacts exposes its name and contents."""
    first_instance = _MyType()
    second_instance = _MyType()
    populated = channel.Channel(_MyType)
    populated = populated.set_artifacts([first_instance, second_instance])

    self.assertEqual(populated.type_name, 'MyTypeName')
    # Order is not part of the check; only the members are.
    self.assertCountEqual(populated.get(), [first_instance, second_instance])