def setUp(self):
  super(TFServingRpcRequestBuilderTest, self).setUp()
  self._examples = standard_artifacts.Examples()
  self._examples.uri = _CSV_EXAMPLE_GEN_URI
  self._examples.split_names = artifact_utils.encode_split_names(
      ['train', 'eval'])
def testDriverRunFn(self):
  # Create input dir.
  self._input_base_path = os.path.join(self._test_dir, 'input_base')
  fileio.makedirs(self._input_base_path)

  # Fake previous outputs.
  span1_v1_split1 = os.path.join(self._input_base_path, 'span01', 'split1',
                                 'data')
  io_utils.write_string_file(span1_v1_split1, 'testing11')
  span1_v1_split2 = os.path.join(self._input_base_path, 'span01', 'split2',
                                 'data')
  io_utils.write_string_file(span1_v1_split2, 'testing12')

  ir_driver = driver.Driver(self._mock_metadata)
  example = standard_artifacts.Examples()

  # Prepare output_dic.
  example.uri = 'my_uri'  # Will verify that this uri is not changed.
  output_dic = {utils.EXAMPLES_KEY: [example]}

  # Prepare exec_properties.
  exec_properties = {
      utils.INPUT_BASE_KEY:
          self._input_base_path,
      utils.INPUT_CONFIG_KEY:
          json_format.MessageToJson(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='s1', pattern='span{SPAN}/split1/*'),
                  example_gen_pb2.Input.Split(
                      name='s2', pattern='span{SPAN}/split2/*')
              ]),
              preserving_proto_field_name=True),
  }
  result = ir_driver.run(
      portable_data_types.ExecutionInfo(
          output_dict=output_dic, exec_properties=exec_properties))

  # Assert exec_properties' values.
  exec_properties = result.exec_properties
  self.assertEqual(exec_properties[utils.SPAN_PROPERTY_NAME].int_value, 1)
  updated_input_config = example_gen_pb2.Input()
  json_format.Parse(exec_properties[utils.INPUT_CONFIG_KEY].string_value,
                    updated_input_config)
  self.assertProtoEquals(
      """
      splits {
        name: "s1"
        pattern: "span01/split1/*"
      }
      splits {
        name: "s2"
        pattern: "span01/split2/*"
      }""", updated_input_config)
  self.assertRegex(
      exec_properties[utils.FINGERPRINT_PROPERTY_NAME].string_value,
      r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\n'
      r'split:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*')

  # Assert output_artifacts' values.
  self.assertLen(result.output_artifacts[utils.EXAMPLES_KEY].artifacts, 1)
  output_example = result.output_artifacts[utils.EXAMPLES_KEY].artifacts[0]
  self.assertEqual(output_example.uri, example.uri)
  self.assertEqual(
      output_example.custom_properties[utils.SPAN_PROPERTY_NAME].string_value,
      '1')
  self.assertRegex(
      output_example.custom_properties[
          utils.FINGERPRINT_PROPERTY_NAME].string_value,
      r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\n'
      r'split:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*')
def testUpdateExecution(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    contexts = m.register_pipeline_contexts_if_not_exists(self._pipeline_info)
    m.register_execution(
        input_artifacts={},
        exec_properties={'k': 'v1'},
        pipeline_info=self._pipeline_info,
        component_info=self._component_info,
        contexts=contexts)
    [execution] = m.store.get_executions_by_context(
        m.get_component_run_context(self._component_info).id)
    self.assertEqual(execution.properties['k'].string_value, 'v1')
    self.assertEqual(execution.properties['state'].string_value,
                     metadata.EXECUTION_STATE_NEW)
    self.assertEqual(execution.last_known_state,
                     metadata_store_pb2.Execution.RUNNING)

    m.update_execution(
        execution,
        self._component_info,
        input_artifacts={'input_a': [standard_artifacts.Examples()]},
        exec_properties={'k': 'v2'},
        contexts=contexts)

    [execution] = m.store.get_executions_by_context(
        m.get_component_run_context(self._component_info).id)
    self.assertEqual(execution.properties['k'].string_value, 'v2')
    self.assertEqual(execution.properties['state'].string_value,
                     metadata.EXECUTION_STATE_NEW)
    self.assertEqual(execution.last_known_state,
                     metadata_store_pb2.Execution.RUNNING)
    [event] = m.store.get_events_by_execution_ids([execution.id])
    self.assertEqual(event.artifact_id, 1)
    [artifact] = m.store.get_artifacts_by_context(
        m.get_component_run_context(self._component_info).id)
    self.assertEqual(artifact.id, 1)

    aa = standard_artifacts.Examples()
    aa.set_mlmd_artifact(artifact)
    m.update_execution(
        execution, self._component_info, input_artifacts={'input_a': [aa]})
    [event] = m.store.get_events_by_execution_ids([execution.id])
    self.assertEqual(event.type, metadata_store_pb2.Event.INPUT)

    m.publish_execution(
        self._component_info,
        output_artifacts={'output': [standard_artifacts.Model()]},
        exec_properties={'k': 'v3'})
    [execution] = m.store.get_executions_by_context(
        m.get_component_run_context(self._component_info).id)
    self.assertEqual(execution.properties['k'].string_value, 'v3')
    self.assertEqual(execution.properties['state'].string_value,
                     metadata.EXECUTION_STATE_COMPLETE)
    self.assertEqual(execution.last_known_state,
                     metadata_store_pb2.Execution.COMPLETE)
    [_, event_b] = m.store.get_events_by_execution_ids([execution.id])
    self.assertEqual(event_b.artifact_id, 2)
    self.assertEqual(event_b.type, metadata_store_pb2.Event.OUTPUT)
    [_, artifact_b] = m.store.get_artifacts_by_context(
        m.get_component_run_context(self._component_info).id)
    self.assertEqual(artifact_b.id, 2)
    self._check_artifact_state(m, artifact_b, ArtifactState.PUBLISHED)
def setUp(self):
  super().setUp()
  self._source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  self._output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  self.component_id = 'test_component'

  # Create input dict.
  self._examples = standard_artifacts.Examples()
  unlabelled_path = os.path.join(self._source_data_dir, 'csv_example_gen',
                                 'unlabelled')
  self._examples.uri = os.path.join(self._output_data_dir, 'csv_example_gen')
  io_utils.copy_dir(unlabelled_path,
                    os.path.join(self._examples.uri, 'Split-unlabelled'))
  io_utils.copy_dir(unlabelled_path,
                    os.path.join(self._examples.uri, 'Split-unlabelled2'))
  self._examples.split_names = artifact_utils.encode_split_names(
      ['unlabelled', 'unlabelled2'])
  self._model = standard_artifacts.Model()
  self._model.uri = os.path.join(self._source_data_dir, 'trainer/current')
  self._model_blessing = standard_artifacts.ModelBlessing()
  self._model_blessing.uri = os.path.join(self._source_data_dir,
                                          'model_validator/blessed')
  self._model_blessing.set_int_custom_property('blessed', 1)
  self._input_dict = {
      standard_component_specs.EXAMPLES_KEY: [self._examples],
      standard_component_specs.MODEL_KEY: [self._model],
      standard_component_specs.MODEL_BLESSING_KEY: [self._model_blessing],
  }

  # Create output dict.
  self._inference_result = standard_artifacts.InferenceResult()
  self._prediction_log_dir = os.path.join(self._output_data_dir,
                                          'prediction_logs')
  self._inference_result.uri = self._prediction_log_dir
  self._output_examples = standard_artifacts.Examples()
  self._output_examples_dir = os.path.join(self._output_data_dir,
                                           'output_examples')
  self._output_examples.uri = self._output_examples_dir
  self._output_dict_ir = {
      standard_component_specs.INFERENCE_RESULT_KEY: [self._inference_result],
  }
  self._output_dict_oe = {
      standard_component_specs.OUTPUT_EXAMPLES_KEY: [self._output_examples],
  }

  # Create exec properties.
  self._exec_properties = {
      standard_component_specs.DATA_SPEC_KEY:
          proto_utils.proto_to_json(bulk_inferrer_pb2.DataSpec()),
      standard_component_specs.MODEL_SPEC_KEY:
          proto_utils.proto_to_json(bulk_inferrer_pb2.ModelSpec()),
      'component_id':
          self.component_id,
  }

  # Create context.
  self._tmp_dir = os.path.join(self._output_data_dir, '.temp')
  self._context = executor.Executor.Context(
      tmp_dir=self._tmp_dir, unique_id='2')
def setUp(self):
  super().setUp()
  self._project_id = 'my-project'
  self._job_id = 'my-job-123'
  self._labels = ['label1', 'label2']
  self._mock_api_client = mock.Mock()
  examples_artifact = standard_artifacts.Examples()
  examples_artifact.split_names = artifact_utils.encode_split_names(
      ['train', 'eval'])
  examples_artifact.uri = _EXAMPLE_LOCATION
  self._inputs = {'examples': [examples_artifact]}
  model_artifact = standard_artifacts.Model()
  model_artifact.uri = _MODEL_LOCATION
  self._outputs = {'model': [model_artifact]}
  training_input = {
      'scaleTier': 'CUSTOM',
      'region': 'us-central1',
      'masterType': 'n1-standard-8',
      'masterConfig': {
          'imageUri': 'gcr.io/my-project/caip-training-test:latest'
      },
      'workerType': 'n1-standard-8',
      'workerCount': 8,
      'workerConfig': {
          'imageUri': 'gcr.io/my-project/caip-training-test:latest'
      },
      'args': [
          '--examples',
          placeholders.InputUriPlaceholder('examples'),
          '--n-steps',
          placeholders.InputValuePlaceholder('n_step'),
          '--model-dir',
          placeholders.OutputUriPlaceholder('model')
      ]
  }
  aip_training_config = {
      ai_platform_training_executor.PROJECT_CONFIG_KEY: self._project_id,
      ai_platform_training_executor.TRAINING_INPUT_CONFIG_KEY: training_input,
      ai_platform_training_executor.JOB_ID_CONFIG_KEY: self._job_id,
      ai_platform_training_executor.LABELS_CONFIG_KEY: self._labels,
  }
  self._exec_properties = {
      ai_platform_training_executor.CONFIG_KEY:
          json_utils.dumps(aip_training_config),
      'n_step': 100
  }
  resolved_training_input = copy.deepcopy(training_input)
  resolved_training_input['args'] = [
      '--examples', _EXAMPLE_LOCATION, '--n-steps', '100', '--model-dir',
      _MODEL_LOCATION
  ]
  self._expected_job_spec = {
      'jobId': self._job_id,
      'trainingInput': resolved_training_input,
      'labels': self._labels,
  }
def testPublishSuccessExecutionExecutorEditedOutputDict(self):
  # There is one artifact in the system provided output_dict, while there are
  # two artifacts in executor output. We expect that two artifacts are
  # published.
  with metadata.Metadata(connection_config=self._connection_config) as m:
    contexts = self._generate_contexts(m)
    execution_id = execution_publish_utils.register_execution(
        m, self._execution_type, contexts).id

    output_example = standard_artifacts.Examples()
    output_example.uri = '/original_path'

    executor_output = execution_result_pb2.ExecutorOutput()
    output_key = 'examples'
    text_format.Parse(
        """
        uri: '/original_path/subdir_1'
        custom_properties {
          key: 'prop'
          value {int_value: 1}
        }
        """, executor_output.output_artifacts[output_key].artifacts.add())
    text_format.Parse(
        """
        uri: '/original_path/subdir_2'
        custom_properties {
          key: 'prop'
          value {int_value: 2}
        }
        """, executor_output.output_artifacts[output_key].artifacts.add())

    output_dict = execution_publish_utils.publish_succeeded_execution(
        m, execution_id, contexts, {output_key: [output_example]},
        executor_output)
    [execution] = m.store.get_executions()
    self.assertProtoPartiallyEquals(
        """
        id: 1
        type_id: 3
        last_known_state: COMPLETE
        """,
        execution,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
    artifacts = m.store.get_artifacts()
    self.assertLen(artifacts, 2)
    self.assertProtoPartiallyEquals(
        """
        id: 1
        type_id: 4
        state: LIVE
        uri: '/original_path/subdir_1'
        custom_properties {
          key: 'prop'
          value {int_value: 1}
        }""",
        artifacts[0],
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
    self.assertProtoPartiallyEquals(
        """
        id: 2
        type_id: 4
        state: LIVE
        uri: '/original_path/subdir_2'
        custom_properties {
          key: 'prop'
          value {int_value: 2}
        }""",
        artifacts[1],
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
    events = m.store.get_events_by_execution_ids([execution.id])
    self.assertLen(events, 2)
    self.assertProtoPartiallyEquals(
        """
        artifact_id: 1
        execution_id: 1
        path {
          steps {
            key: 'examples'
          }
          steps {
            index: 0
          }
        }
        type: OUTPUT
        """,
        events[0],
        ignored_fields=['milliseconds_since_epoch'])
    self.assertProtoPartiallyEquals(
        """
        artifact_id: 2
        execution_id: 1
        path {
          steps {
            key: 'examples'
          }
          steps {
            index: 1
          }
        }
        type: OUTPUT
        """,
        events[1],
        ignored_fields=['milliseconds_since_epoch'])
    # Verifies the context-execution edges are set up.
    self.assertCountEqual(
        [c.id for c in contexts],
        [c.id for c in m.store.get_contexts_by_execution(execution.id)])
    for artifact_list in output_dict.values():
      for output_example in artifact_list:
        self.assertCountEqual(
            [c.id for c in contexts],
            [c.id for c in m.store.get_contexts_by_artifact(output_example.id)])
def __init__(
    self,
    examples: types.Channel = None,
    schema: types.Channel = None,
    module_file: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    preprocessing_fn: Optional[Union[Text,
                                     data_types.RuntimeParameter]] = None,
    transform_graph: Optional[types.Channel] = None,
    transformed_examples: Optional[types.Channel] = None,
    input_data: Optional[types.Channel] = None,
    instance_name: Optional[Text] = None,
    enable_cache: Optional[bool] = None):
  """Construct a Transform component.

  Args:
    examples: A Channel of type `standard_artifacts.Examples` (required).
      This should contain the two splits 'train' and 'eval'.
    schema: A Channel of type `standard_artifacts.Schema`. This should
      contain a single schema artifact.
    module_file: The file path to a python module file, from which the
      'preprocessing_fn' function will be loaded. The function must have the
      following signature.

      def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
        ...

      where the values of input and returned Dict are either tf.Tensor or
      tf.SparseTensor. Exactly one of 'module_file' or 'preprocessing_fn'
      must be supplied.
    preprocessing_fn: The path to python function that implements a
      'preprocessing_fn'. See 'module_file' for expected signature of the
      function. Exactly one of 'module_file' or 'preprocessing_fn' must be
      supplied.
    transform_graph: Optional output 'TransformPath' channel for output of
      'tf.Transform', which includes an exported TensorFlow graph suitable
      for both training and serving.
    transformed_examples: Optional output 'ExamplesPath' channel for
      materialized transformed examples, which includes both 'train' and
      'eval' splits.
    input_data: Backwards compatibility alias for the 'examples' argument.
    instance_name: Optional unique instance name. Necessary iff multiple
      transform components are declared in the same pipeline.
    enable_cache: Optional boolean to indicate if cache is enabled for the
      Transform component. If not specified, defaults to the value specified
      for pipeline's enable_cache parameter.

  Raises:
    ValueError: When both or neither of 'module_file' and 'preprocessing_fn'
      is supplied.
  """
  if input_data:
    absl.logging.warning(
        'The "input_data" argument to the Transform component has '
        'been renamed to "examples" and is deprecated. Please update your '
        'usage as support for this argument will be removed soon.')
    examples = input_data
  if bool(module_file) == bool(preprocessing_fn):
    raise ValueError(
        "Exactly one of 'module_file' or 'preprocessing_fn' must be supplied."
    )

  transform_graph = transform_graph or types.Channel(
      type=standard_artifacts.TransformGraph,
      artifacts=[standard_artifacts.TransformGraph()])
  if not transformed_examples:
    example_artifact = standard_artifacts.Examples()
    example_artifact.split_names = artifact_utils.encode_split_names(
        artifact.DEFAULT_EXAMPLE_SPLITS)
    transformed_examples = types.Channel(
        type=standard_artifacts.Examples, artifacts=[example_artifact])
  spec = TransformSpec(
      examples=examples,
      schema=schema,
      module_file=module_file,
      preprocessing_fn=preprocessing_fn,
      transform_graph=transform_graph,
      transformed_examples=transformed_examples)
  super(Transform, self).__init__(
      spec=spec, instance_name=instance_name, enable_cache=enable_cache)
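# Illustrative sketch (not from this file): a minimal 'preprocessing_fn'
# matching the signature described in the Transform docstring above. The
# feature name 'x' and the tensorflow_transform import are assumptions made
# for this example only.
import tensorflow_transform as tft


def preprocessing_fn(inputs):
  """Scales the assumed float feature 'x' to z-scores via tf.Transform."""
  outputs = {}
  outputs['x_scaled'] = tft.scale_to_z_score(inputs['x'])
  return outputs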
def testRunExecutor_with_InplaceUpdateExecutor(self):
  executor_spec = text_format.Parse(
      """
      class_path: "tfx.orchestration.portable.python_executor_operator_test.InplaceUpdateExecutor"
      """,
      local_deployment_config_pb2.ExecutableSpec.PythonClassExecutableSpec())
  operator = python_executor_operator.PythonExecutorOperator(executor_spec)
  input_dict = {'input_key': [standard_artifacts.Examples()]}
  output_dict = {'output_key': [standard_artifacts.Model()]}
  exec_properties = {
      'string': 'value',
      'int': 1,
      'float': 0.0,
      # This should not happen in production and will be dropped.
      'proto': execution_result_pb2.ExecutorOutput()
  }
  stateful_working_dir = os.path.join(self.tmp_dir, 'stateful_working_dir')
  executor_output_uri = os.path.join(self.tmp_dir, 'executor_output')
  executor_output = operator.run_executor(
      base_executor_operator.ExecutionInfo(
          input_dict=input_dict,
          output_dict=output_dict,
          exec_properties=exec_properties,
          stateful_working_dir=stateful_working_dir,
          executor_output_uri=executor_output_uri))
  self.assertProtoPartiallyEquals(
      """
      execution_properties {
        key: "float"
        value {
          double_value: 0.0
        }
      }
      execution_properties {
        key: "int"
        value {
          int_value: 1
        }
      }
      execution_properties {
        key: "string"
        value {
          string_value: "value"
        }
      }
      output_artifacts {
        key: "output_key"
        value {
          artifacts {
            custom_properties {
              key: "name"
              value {
                string_value: "my_model"
              }
            }
          }
        }
      }""", executor_output)
def _createExamples(self, span: int) -> standard_artifacts.Examples:
  artifact = standard_artifacts.Examples()
  artifact.uri = f'uri{span}'
  artifact.set_int_custom_property(utils.SPAN_PROPERTY_NAME, span)
  return artifact
def __init__(
    self,
    examples: types.Channel = None,
    schema: types.Channel = None,
    module_file: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    preprocessing_fn: Optional[Union[Text,
                                     data_types.RuntimeParameter]] = None,
    transform_graph: Optional[types.Channel] = None,
    transformed_examples: Optional[types.Channel] = None,
    input_data: Optional[types.Channel] = None,
    analyzer_cache: Optional[types.Channel] = None,
    instance_name: Optional[Text] = None,
    materialize: bool = True,
    disable_analyzer_cache: bool = False,
    custom_config: Optional[Dict[Text, Any]] = None):
  """Construct a Transform component.

  Args:
    examples: A Channel of type `standard_artifacts.Examples` (required).
      This should contain the two splits 'train' and 'eval'.
    schema: A Channel of type `standard_artifacts.Schema`. This should
      contain a single schema artifact.
    module_file: The file path to a python module file, from which the
      'preprocessing_fn' function will be loaded. Exactly one of
      'module_file' or 'preprocessing_fn' must be supplied. The function
      needs to have the following signature:

      ```
      def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
        ...
      ```

      where the values of input and returned Dict are either tf.Tensor or
      tf.SparseTensor. If additional inputs are needed for preprocessing_fn,
      they can be passed in custom_config:

      ```
      def preprocessing_fn(inputs: Dict[Text, Any],
                           custom_config: Dict[Text, Any]) -> Dict[Text, Any]:
        ...
      ```
    preprocessing_fn: The path to python function that implements a
      'preprocessing_fn'. See 'module_file' for expected signature of the
      function. Exactly one of 'module_file' or 'preprocessing_fn' must be
      supplied.
    transform_graph: Optional output 'TransformPath' channel for output of
      'tf.Transform', which includes an exported TensorFlow graph suitable
      for both training and serving.
    transformed_examples: Optional output 'ExamplesPath' channel for
      materialized transformed examples, which includes both 'train' and
      'eval' splits.
    input_data: Backwards compatibility alias for the 'examples' argument.
    analyzer_cache: Optional input 'TransformCache' channel containing
      cached information from previous Transform runs. When provided,
      Transform will try to use the cached calculation if possible.
    instance_name: Optional unique instance name. Necessary iff multiple
      transform components are declared in the same pipeline.
    materialize: If True, write transformed examples as an output. If False,
      `transformed_examples` must not be provided.
    disable_analyzer_cache: If False, Transform will use input cache if
      provided and write cache output. If True, `analyzer_cache` must not be
      provided.
    custom_config: A dict which contains additional parameters that will be
      passed to preprocessing_fn.

  Raises:
    ValueError: When both or neither of 'module_file' and 'preprocessing_fn'
      is supplied.
  """
  if input_data:
    absl.logging.warning(
        'The "input_data" argument to the Transform component has '
        'been renamed to "examples" and is deprecated. Please update your '
        'usage as support for this argument will be removed soon.')
    examples = input_data
  if bool(module_file) == bool(preprocessing_fn):
    raise ValueError(
        "Exactly one of 'module_file' or 'preprocessing_fn' must be supplied."
    )

  transform_graph = transform_graph or types.Channel(
      type=standard_artifacts.TransformGraph,
      artifacts=[standard_artifacts.TransformGraph()])
  if materialize and transformed_examples is None:
    transformed_examples = types.Channel(
        type=standard_artifacts.Examples,
        # TODO(b/161548528): remove the hardcoded artifact.
        artifacts=[standard_artifacts.Examples()],
        matching_channel_name='examples')
  elif not materialize and transformed_examples is not None:
    raise ValueError(
        'Must not specify transformed_examples when materialize is False.')
  if disable_analyzer_cache:
    updated_analyzer_cache = None
    if analyzer_cache:
      raise ValueError(
          '`analyzer_cache` is set when disable_analyzer_cache is True.')
  else:
    updated_analyzer_cache = types.Channel(
        type=standard_artifacts.TransformCache,
        artifacts=[standard_artifacts.TransformCache()])
  spec = TransformSpec(
      examples=examples,
      schema=schema,
      module_file=module_file,
      preprocessing_fn=preprocessing_fn,
      transform_graph=transform_graph,
      transformed_examples=transformed_examples,
      analyzer_cache=analyzer_cache,
      updated_analyzer_cache=updated_analyzer_cache,
      custom_config=json.dumps(custom_config))
  super(Transform, self).__init__(spec=spec, instance_name=instance_name)
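# Illustrative usage sketch for the constructor above. It assumes upstream
# 'example_gen' and 'schema_gen' components and a hypothetical module path;
# it shows custom_config being surfaced to preprocessing_fn.
transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file='path/to/preprocessing.py',  # hypothetical module file
    custom_config={'vocab_size': 1000})  # passed through to preprocessing_fn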
def setUp(self):
  super(ExecutorTest, self).setUp()

  # Setup Mocks
  runner_patcher = mock.patch('tfx.components.infra_validator'
                              '.model_server_runners.local_docker_runner'
                              '.LocalDockerModelServerRunner')
  self.model_server = runner_patcher.start().return_value
  self.addCleanup(runner_patcher.stop)

  build_request_patcher = mock.patch(
      'tfx.components.infra_validator.request_builder'
      '.build_requests')
  self.build_requests_mock = build_request_patcher.start()
  self.addCleanup(build_request_patcher.stop)

  self.model_server.client = mock.MagicMock()

  # Setup directories
  source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  base_output_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                   self.get_temp_dir())
  output_data_dir = os.path.join(base_output_dir, self._testMethodName)

  # Setup input_dict.
  model = standard_artifacts.Model()
  model.uri = os.path.join(source_data_dir, 'trainer', 'current')
  examples = standard_artifacts.Examples()
  examples.uri = os.path.join(source_data_dir, 'transform',
                              'transformed_examples', 'eval')
  examples.split_names = artifact_utils.encode_split_names(['eval'])

  self.input_dict = {
      'model': [model],
      'examples': [examples],
  }

  # Setup output_dict.
  self.blessing = standard_artifacts.InfraBlessing()
  self.blessing.uri = os.path.join(output_data_dir, 'blessing')
  self.output_dict = {'blessing': [self.blessing]}

  # Setup Context
  temp_dir = os.path.join(output_data_dir, '.temp')
  self.context = executor.Executor.Context(tmp_dir=temp_dir, unique_id='1')

  # Setup exec_properties
  self.exec_properties = {
      'serving_spec':
          json.dumps({
              'tensorflow_serving': {
                  'tags': ['1.15.0']
              },
              'local_docker': {}
          }),
      'validation_spec':
          json.dumps({'max_loading_time_seconds': 10}),
      'request_spec':
          json.dumps({
              'tensorflow_serving': {
                  'rpc_kind': 'CLASSIFY'
              },
              'max_examples': 10
          })
  }
def setUp(self):
  super().setUp()
  self._examples = standard_artifacts.Examples()
  self._examples.uri = _CSV_EXAMPLE_GEN_URI
  self._examples.split_names = artifact_utils.encode_split_names(
      ['train', 'eval'])
def examples(self):
  examples = standard_artifacts.Examples()
  return channel_utils.as_channel([examples])
def setUp(self):
  super(ExecutorTest, self).setUp()

  # Setup Mocks
  patcher = mock.patch.object(request_builder, 'build_requests')
  self.build_requests_mock = patcher.start()
  self.addCleanup(patcher.stop)

  # Setup directories
  source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  base_output_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                   self.get_temp_dir())
  output_data_dir = os.path.join(base_output_dir, self._testMethodName)

  # Setup input_dict.
  self._model = standard_artifacts.Model()
  self._model.uri = os.path.join(source_data_dir, 'trainer', 'current')
  self._model_path = path_utils.serving_model_path(self._model.uri)
  examples = standard_artifacts.Examples()
  examples.uri = os.path.join(source_data_dir, 'transform',
                              'transformed_examples', 'eval')
  examples.split_names = artifact_utils.encode_split_names(['eval'])

  self._input_dict = {
      'model': [self._model],
      'examples': [examples],
  }
  self._blessing = standard_artifacts.InfraBlessing()
  self._blessing.uri = os.path.join(output_data_dir, 'blessing')
  self._output_dict = {'blessing': [self._blessing]}
  temp_dir = os.path.join(output_data_dir, '.temp')
  self._context = executor.Executor.Context(tmp_dir=temp_dir, unique_id='1')
  self._serving_spec = _make_serving_spec({
      'tensorflow_serving': {
          'tags': ['1.15.0']
      },
      'local_docker': {},
      'model_name': 'chicago-taxi',
  })
  self._serving_binary = serving_bins.parse_serving_binaries(
      self._serving_spec)[0]
  self._validation_spec = _make_validation_spec({
      'max_loading_time_seconds': 10,
      'num_tries': 3
  })
  self._request_spec = _make_request_spec({
      'tensorflow_serving': {
          'signature_names': ['serving_default'],
      },
      'num_examples': 1
  })
  self._exec_properties = {
      'serving_spec': json_format.MessageToJson(self._serving_spec),
      'validation_spec': json_format.MessageToJson(self._validation_spec),
      'request_spec': json_format.MessageToJson(self._request_spec),
  }
def __init__(self,
             input_data: types.Channel = None,
             schema: types.Channel = None,
             module_file: Optional[Text] = None,
             preprocessing_fn: Optional[Text] = None,
             transform_output: Optional[types.Channel] = None,
             transformed_examples: Optional[types.Channel] = None,
             examples: Optional[types.Channel] = None,
             name: Optional[Text] = None):
  """Construct a Transform component.

  Args:
    input_data: A Channel of 'ExamplesPath' type (required). This should
      contain the two splits 'train' and 'eval'.
    schema: A Channel of 'SchemaPath' type. This should contain a single
      schema artifact.
    module_file: The file path to a python module file, from which the
      'preprocessing_fn' function will be loaded. The function must have the
      following signature.

      def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
        ...

      where the values of input and returned Dict are either tf.Tensor or
      tf.SparseTensor. Exactly one of 'module_file' or 'preprocessing_fn'
      must be supplied.
    preprocessing_fn: The path to python function that implements a
      'preprocessing_fn'. See 'module_file' for expected signature of the
      function. Exactly one of 'module_file' or 'preprocessing_fn' must be
      supplied.
    transform_output: Optional output 'TransformPath' channel for output of
      'tf.Transform', which includes an exported TensorFlow graph suitable
      for both training and serving.
    transformed_examples: Optional output 'ExamplesPath' channel for
      materialized transformed examples, which includes both 'train' and
      'eval' splits.
    examples: Forwards compatibility alias for the 'input_data' argument.
    name: Optional unique name. Necessary iff multiple transform components
      are declared in the same pipeline.

  Raises:
    ValueError: When both or neither of 'module_file' and 'preprocessing_fn'
      is supplied.
  """
  input_data = input_data or examples
  if bool(module_file) == bool(preprocessing_fn):
    raise ValueError(
        "Exactly one of 'module_file' or 'preprocessing_fn' must be supplied."
    )

  transform_output = transform_output or types.Channel(
      type=standard_artifacts.TransformGraph,
      artifacts=[standard_artifacts.TransformGraph()])
  transformed_examples = transformed_examples or types.Channel(
      type=standard_artifacts.Examples,
      artifacts=[
          standard_artifacts.Examples(split=split)
          for split in artifact.DEFAULT_EXAMPLE_SPLITS
      ])
  spec = TransformSpec(
      input_data=input_data,
      schema=schema,
      module_file=module_file,
      preprocessing_fn=preprocessing_fn,
      transform_output=transform_output,
      transformed_examples=transformed_examples)
  super(Transform, self).__init__(spec=spec, name=name)
def testDo(self):
  source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create input dict.
  train_examples = standard_artifacts.Examples(split='train')
  eval_examples = standard_artifacts.Examples(split='eval')
  eval_examples.uri = os.path.join(source_data_dir, 'csv_example_gen/eval/')
  model_exports = standard_artifacts.Model()
  model_exports.uri = os.path.join(source_data_dir, 'trainer/current/')
  input_dict = {
      'examples': [train_examples, eval_examples],
      'model_exports': [model_exports],
  }

  # Create output dict.
  eval_output = standard_artifacts.ModelEvaluation()
  eval_output.uri = os.path.join(output_data_dir, 'eval_output')
  output_dict = {'output': [eval_output]}

  # Create exec properties.
  exec_properties = {
      'feature_slicing_spec':
          json_format.MessageToJson(
              evaluator_pb2.FeatureSlicingSpec(specs=[
                  evaluator_pb2.SingleSlicingSpec(
                      column_for_slicing=['trip_start_hour']),
                  evaluator_pb2.SingleSlicingSpec(
                      column_for_slicing=['trip_start_day', 'trip_miles']),
              ]),
              preserving_proto_field_name=True)
  }

  try:
    # Need to import the following module so that the fairness indicator
    # post-export metric is registered. This may raise an ImportError if the
    # currently-installed version of TFMA does not support fairness
    # indicators.
    import tensorflow_model_analysis.addons.fairness.post_export_metrics.fairness_indicators  # pylint: disable=g-import-not-at-top, unused-variable
    exec_properties['fairness_indicator_thresholds'] = [
        0.1, 0.3, 0.5, 0.7, 0.9
    ]
  except ImportError:
    absl.logging.warning(
        'Not testing fairness indicators because a compatible TFMA version '
        'is not installed.')

  # Run executor.
  evaluator = executor.Executor()
  evaluator.Do(input_dict, output_dict, exec_properties)

  # Check evaluator outputs.
  self.assertTrue(
      # TODO(b/141490237): Update to only check eval_config.json after TFMA
      # is released with the corresponding change.
      tf.io.gfile.exists(os.path.join(eval_output.uri, 'eval_config')) or
      tf.io.gfile.exists(os.path.join(eval_output.uri, 'eval_config.json')))
  self.assertTrue(
      tf.io.gfile.exists(os.path.join(eval_output.uri, 'metrics')))
  self.assertTrue(tf.io.gfile.exists(os.path.join(eval_output.uri, 'plots')))
def setUp(self):
  super(ExecutorTest, self).setUp()
  self._source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  self._output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create input dict.
  e1 = standard_artifacts.Examples()
  e1.uri = os.path.join(self._source_data_dir,
                        'transform/transformed_examples')
  e1.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  e2 = copy.deepcopy(e1)
  self._single_artifact = [e1]
  self._multiple_artifacts = [e1, e2]

  transform_output = standard_artifacts.TransformGraph()
  transform_output.uri = os.path.join(self._source_data_dir,
                                      'transform/transform_graph')
  schema = standard_artifacts.Schema()
  schema.uri = os.path.join(self._source_data_dir, 'schema_gen')
  previous_model = standard_artifacts.Model()
  previous_model.uri = os.path.join(self._source_data_dir, 'trainer/previous')

  self._input_dict = {
      constants.EXAMPLES_KEY: self._single_artifact,
      constants.TRANSFORM_GRAPH_KEY: [transform_output],
      constants.SCHEMA_KEY: [schema],
      constants.BASE_MODEL_KEY: [previous_model]
  }

  # Create output dict.
  self._model_exports = standard_artifacts.Model()
  self._model_exports.uri = os.path.join(self._output_data_dir,
                                         'model_export_path')
  self._model_run_exports = standard_artifacts.ModelRun()
  self._model_run_exports.uri = os.path.join(self._output_data_dir,
                                             'model_run_path')
  self._output_dict = {
      constants.MODEL_KEY: [self._model_exports],
      constants.MODEL_RUN_KEY: [self._model_run_exports]
  }

  # Create exec properties skeleton.
  self._exec_properties = {
      'train_args':
          json_format.MessageToJson(
              trainer_pb2.TrainArgs(num_steps=1000),
              preserving_proto_field_name=True),
      'eval_args':
          json_format.MessageToJson(
              trainer_pb2.EvalArgs(num_steps=500),
              preserving_proto_field_name=True),
      'warm_starting':
          False,
  }

  self._module_file = os.path.join(self._source_data_dir, 'module_file',
                                   'trainer_module.py')
  self._trainer_fn = '%s.%s' % (trainer_module.trainer_fn.__module__,
                                trainer_module.trainer_fn.__name__)

  # Executors for test.
  self._trainer_executor = executor.Executor()
  self._generic_trainer_executor = executor.GenericExecutor()
def __init__(self,
             query: Optional[Text] = None,
             beam_transform: beam.PTransform = None,
             bucket_name: Optional[Text] = None,
             output_schema: Optional[Text] = None,
             table_name: Optional[Text] = None,
             use_bigquery_source: Optional[Any] = False,
             input_config: Optional[example_gen_pb2.Input] = None,
             output_config: Optional[example_gen_pb2.Output] = None,
             example_artifacts: Optional[types.Channel] = None,
             instance_name: Optional[Text] = None):
  """Constructs a TCGAPreprocessing component.

  Args:
    query: BigQuery sql string, query result will be treated as a single
      split, can be overwritten by input_config.
    beam_transform: beam.PTransform pipeline. Will be used to process data
      ingested by the BigQuery query.
    bucket_name: string containing a GCS bucket name. Will be used as
      temporary storage for the query and pickle file.
    output_schema: Optional schema for the BigQuery output table.
    table_name: string containing the BigQuery output table name.
    use_bigquery_source: Whether to use BigQuerySource instead of the
      experimental `ReadFromBigQuery` PTransform (required by the
      BigQueryExampleGen executor).
    input_config: An example_gen_pb2.Input instance with Split.pattern as
      BigQuery sql string. If set, it overwrites the 'query' arg, and allows
      different queries per split. If any field is provided as a
      RuntimeParameter, input_config should be constructed as a dict with
      the same field names as Input proto message.
    output_config: An example_gen_pb2.Output instance, providing output
      configuration. If unset, default splits will be 'train' and 'eval'
      with size 2:1. If any field is provided as a RuntimeParameter,
      output_config should be constructed as a dict with the same field
      names as Output proto message.
    example_artifacts: Optional channel of 'ExamplesPath' for output train
      and eval examples.
    instance_name: Optional unique instance name. Necessary if multiple
      BigQueryExampleGen components are declared in the same pipeline.

  Raises:
    RuntimeError: Only one of query and input_config should be set.
  """
  # Configure inputs and outputs.
  input_config = input_config or utils.make_default_input_config()
  output_config = output_config or utils.make_default_output_config(
      input_config)
  if not example_artifacts:
    example_artifacts = channel_utils.as_channel(
        [standard_artifacts.Examples()])

  # Upload the Beam transform to a GCS bucket.
  beam_transform_uri = upload_beam_to_gcs(beam_transform, bucket_name)

  spec = TCGAPreprocessingSpec(
      # custom parameters
      query=query,
      output_schema=output_schema,
      table_name=table_name,
      use_bigquery_source=use_bigquery_source,
      # default parameters
      input_config=input_config,
      output_config=output_config,
      input_base=beam_transform_uri,
      # outputs
      examples=example_artifacts)
  super(TCGAPreprocessing, self).__init__(
      spec=spec, instance_name=instance_name)
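# Illustrative sketch of the per-split input_config described in the
# docstring above: one BigQuery SQL string per split. The table name and the
# MOD-based partitioning are assumptions made for this example only.
from tfx.proto import example_gen_pb2

input_config = example_gen_pb2.Input(splits=[
    example_gen_pb2.Input.Split(
        name='train',
        pattern='SELECT * FROM `project.dataset.table` WHERE MOD(key, 3) > 0'),
    example_gen_pb2.Input.Split(
        name='eval',
        pattern='SELECT * FROM `project.dataset.table` WHERE MOD(key, 3) = 0'),
])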
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     metadata_path: Text,
                     direct_num_workers: int) -> pipeline.Pipeline:
  input_data = external_input(data_root)

  input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name='train', pattern='train.tfrecord'),
      example_gen_pb2.Input.Split(name='eval', pattern='eval.tfrecord')
  ])

  example_gen = ImportExampleGen(input=input_data, input_config=input_config)

  identify_examples = IdentifyExamples(
      orig_examples=example_gen.outputs['examples'],
      component_name=u'IdentifyExamples',
      id_feature_name=u'id')

  # Computes statistics over data for visualization and example validation.
  statistics_gen = StatisticsGen(
      examples=identify_examples.outputs['identified_examples'])

  schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'])

  # Performs anomaly detection based on statistics and data schema.
  validate_stats = ExampleValidator(
      statistics=statistics_gen.outputs['statistics'],
      schema=schema_gen.outputs['schema'])

  synthesize_graph = SynthesizeGraph(
      identified_examples=identify_examples.outputs['identified_examples'],
      component_name=u'SynthesizeGraph',
      similarity_threshold=0.99)

  transform = Transform(
      examples=identify_examples.outputs['identified_examples'],
      schema=schema_gen.outputs['schema'],
      # TODO(b/169218106): Remove transformed_examples kwargs after bugfix is
      # released.
      transformed_examples=channel.Channel(
          type=standard_artifacts.Examples,
          artifacts=[standard_artifacts.Examples()]),
      module_file=_transform_module_file)

  # Augments training data with graph neighbors.
  graph_augmentation = GraphAugmentation(
      identified_examples=transform.outputs['transformed_examples'],
      synthesized_graph=synthesize_graph.outputs['synthesized_graph'],
      component_name=u'GraphAugmentation',
      num_neighbors=3)

  trainer = Trainer(
      module_file=_trainer_module_file,
      transformed_examples=graph_augmentation.outputs['augmented_examples'],
      schema=schema_gen.outputs['schema'],
      transform_graph=transform.outputs['transform_graph'],
      train_args=trainer_pb2.TrainArgs(num_steps=10000),
      eval_args=trainer_pb2.EvalArgs(num_steps=5000))

  model_validator = ModelValidator(
      examples=example_gen.outputs['examples'], model=trainer.outputs['model'])

  pusher = Pusher(
      model=trainer.outputs['model'],
      model_blessing=model_validator.outputs['blessing'],
      push_destination=pusher_pb2.PushDestination(
          filesystem=pusher_pb2.PushDestination.Filesystem(
              base_directory=serving_model_dir)))

  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=[
          example_gen, identify_examples, statistics_gen, schema_gen,
          validate_stats, synthesize_graph, transform, graph_augmentation,
          trainer, model_validator, pusher
      ],
      enable_cache=True,
      metadata_connection_config=metadata.sqlite_metadata_connection_config(
          metadata_path),
      beam_pipeline_args=['--direct_num_workers=%d' % direct_num_workers])
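# Illustrative sketch of launching a pipeline like the one above on the
# local Beam runner; all argument values are placeholders.
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner

BeamDagRunner().run(
    _create_pipeline(
        pipeline_name='graph_pipeline',
        pipeline_root='/tmp/pipeline_root',
        data_root='/tmp/data',
        module_file='path/to/module.py',
        serving_model_dir='/tmp/serving_model',
        metadata_path='/tmp/metadata.db',
        direct_num_workers=1))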
def testPublishSuccessfulExecution(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    contexts = self._generate_contexts(m)
    execution_id = execution_publish_utils.register_execution(
        m, self._execution_type, contexts).id
    output_key = 'examples'
    output_example = standard_artifacts.Examples()
    output_example.uri = '/examples_uri'
    executor_output = execution_result_pb2.ExecutorOutput()
    text_format.Parse(
        """
        uri: '/examples_uri'
        custom_properties {
          key: 'prop'
          value {int_value: 1}
        }
        """, executor_output.output_artifacts[output_key].artifacts.add())
    output_dict = execution_publish_utils.publish_succeeded_execution(
        m, execution_id, contexts, {output_key: [output_example]},
        executor_output)
    [execution] = m.store.get_executions()
    self.assertProtoPartiallyEquals(
        """
        id: 1
        type_id: 3
        last_known_state: COMPLETE
        """,
        execution,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
    [artifact] = m.store.get_artifacts()
    self.assertProtoPartiallyEquals(
        """
        id: 1
        type_id: 4
        state: LIVE
        uri: '/examples_uri'
        custom_properties {
          key: 'prop'
          value {int_value: 1}
        }""",
        artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
    [event] = m.store.get_events_by_execution_ids([execution.id])
    self.assertProtoPartiallyEquals(
        """
        artifact_id: 1
        execution_id: 1
        path {
          steps {
            key: 'examples'
          }
          steps {
            index: 0
          }
        }
        type: OUTPUT
        """,
        event,
        ignored_fields=['milliseconds_since_epoch'])
    # Verifies the context-execution edges are set up.
    self.assertCountEqual(
        [c.id for c in contexts],
        [c.id for c in m.store.get_contexts_by_execution(execution.id)])
    for artifact_list in output_dict.values():
      for output_example in artifact_list:
        self.assertCountEqual(
            [c.id for c in contexts],
            [c.id for c in m.store.get_contexts_by_artifact(output_example.id)])
def __init__(
    self,
    # TODO(b/159467778): deprecate this, use input_base instead.
    input: Optional[types.Channel] = None,  # pylint: disable=redefined-builtin
    input_base: Optional[Text] = None,
    input_config: Optional[Union[example_gen_pb2.Input, Dict[Text,
                                                             Any]]] = None,
    output_config: Optional[Union[example_gen_pb2.Output, Dict[Text,
                                                               Any]]] = None,
    custom_config: Optional[Union[example_gen_pb2.CustomConfig,
                                  Dict[Text, Any]]] = None,
    output_data_format: Optional[int] = example_gen_pb2.FORMAT_TF_EXAMPLE,
    example_artifacts: Optional[types.Channel] = None,
    custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
    instance_name: Optional[Text] = None):
  """Construct a FileBasedExampleGen component.

  Args:
    input: A Channel of type `standard_artifacts.ExternalArtifact`, which
      includes one artifact whose uri is an external directory containing
      the data files. (Deprecated by input_base)
    input_base: an external directory containing the data files.
    input_config: An
      [`example_gen_pb2.Input`](https://github.com/tensorflow/tfx/blob/master/tfx/proto/example_gen.proto)
      instance, providing input configuration. If unset, input files will
      be treated as a single split.
    output_config: An example_gen_pb2.Output instance, providing the output
      configuration. If unset, default splits will be 'train' and 'eval'
      with size 2:1.
    custom_config: An optional example_gen_pb2.CustomConfig instance,
      providing custom configuration for executor.
    output_data_format: Payload format of generated data in output artifact,
      one of example_gen_pb2.PayloadFormat enum.
    example_artifacts: Channel of 'ExamplesPath' for output train and eval
      examples.
    custom_executor_spec: Optional custom executor spec overriding the
      default executor spec specified in the component attribute.
    instance_name: Optional unique instance name. Required only if multiple
      ExampleGen components are declared in the same pipeline.
  """
  if input:
    logging.warning(
        'The "input" argument to the ExampleGen component has been '
        'deprecated by "input_base". Please update your usage as support for '
        'this argument will be removed soon.')
    input_base = artifact_utils.get_single_uri(list(input.get()))
  # Configure inputs and outputs.
  input_config = input_config or utils.make_default_input_config()
  output_config = output_config or utils.make_default_output_config(
      input_config)
  if not example_artifacts:
    artifact = standard_artifacts.Examples()
    artifact.split_names = artifact_utils.encode_split_names(
        utils.generate_output_split_names(input_config, output_config))
    example_artifacts = channel_utils.as_channel([artifact])
  spec = FileBasedExampleGenSpec(
      input_base=input_base,
      input_config=input_config,
      output_config=output_config,
      custom_config=custom_config,
      output_data_format=output_data_format,
      examples=example_artifacts)
  super(FileBasedExampleGen, self).__init__(
      spec=spec,
      custom_executor_spec=custom_executor_spec,
      instance_name=instance_name)
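# Illustrative sketch of the input_config described in the docstring above:
# two named splits selected by file patterns relative to input_base. The
# pattern values are assumptions made for this example only.
from tfx.proto import example_gen_pb2

input_config = example_gen_pb2.Input(splits=[
    example_gen_pb2.Input.Split(name='train', pattern='train/*'),
    example_gen_pb2.Input.Split(name='eval', pattern='eval/*'),
])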
def testTrainerFn(self):
  temp_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  schema_file = os.path.join(self._testdata_path, 'schema_gen/schema.pbtxt')
  data_accessor = DataAccessor(
      tf_dataset_factory=tfxio_utils.get_tf_dataset_factory_from_artifact(
          [standard_artifacts.Examples()], []),
      record_batch_factory=None)
  trainer_fn_args = trainer_executor.TrainerFnArgs(
      train_files=os.path.join(self._testdata_path,
                               'transform/transformed_examples/train/*.gz'),
      transform_output=os.path.join(self._testdata_path,
                                    'transform/transform_graph'),
      serving_model_dir=os.path.join(temp_dir, 'serving_model_dir'),
      eval_files=os.path.join(self._testdata_path,
                              'transform/transformed_examples/eval/*.gz'),
      schema_file=schema_file,
      train_steps=1,
      eval_steps=1,
      base_model=None,
      data_accessor=data_accessor)
  schema = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema())
  training_spec = taxi_utils.trainer_fn(trainer_fn_args, schema)

  estimator = training_spec['estimator']
  train_spec = training_spec['train_spec']
  eval_spec = training_spec['eval_spec']
  eval_input_receiver_fn = training_spec['eval_input_receiver_fn']

  self.assertIsInstance(estimator, tf.estimator.DNNLinearCombinedClassifier)
  self.assertIsInstance(train_spec, tf.estimator.TrainSpec)
  self.assertIsInstance(eval_spec, tf.estimator.EvalSpec)
  self.assertIsInstance(eval_input_receiver_fn, types.FunctionType)

  # Test keep_max_checkpoint in RunConfig.
  self.assertGreater(estimator._config.keep_checkpoint_max, 1)

  # Train for one step, then eval for one step.
  eval_result, exports = tf.estimator.train_and_evaluate(
      estimator, train_spec, eval_spec)
  self.assertGreater(eval_result['loss'], 0.0)
  self.assertEqual(len(exports), 1)
  self.assertGreaterEqual(len(fileio.listdir(exports[0])), 1)

  # Export the eval saved model.
  eval_savedmodel_path = tfma.export.export_eval_savedmodel(
      estimator=estimator,
      export_dir_base=path_utils.eval_model_dir(temp_dir),
      eval_input_receiver_fn=eval_input_receiver_fn)
  self.assertGreaterEqual(len(fileio.listdir(eval_savedmodel_path)), 1)

  # Test the exported serving graph.
  with tf.compat.v1.Session() as sess:
    metagraph_def = tf.compat.v1.saved_model.loader.load(
        sess, [tf.saved_model.SERVING], exports[0])
    self.assertIsInstance(metagraph_def, tf.compat.v1.MetaGraphDef)
def setUp(self):
  super(TunerTest, self).setUp()
  self.examples = channel_utils.as_channel([standard_artifacts.Examples()])
  self.schema = channel_utils.as_channel([standard_artifacts.Schema()])
def testExecution(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    contexts = m.register_pipeline_contexts_if_not_exists(self._pipeline_info)

    # Test prepare_execution.
    exec_properties = {'arg_one': 1}
    input_artifact = standard_artifacts.Examples()
    output_artifact = standard_artifacts.Examples()
    input_artifacts = {'input': [input_artifact]}
    output_artifacts = {'output': [output_artifact]}
    m.register_execution(
        input_artifacts=input_artifacts,
        exec_properties=exec_properties,
        pipeline_info=self._pipeline_info,
        component_info=self._component_info,
        contexts=contexts)
    [execution] = m.store.get_executions_by_context(contexts[0].id)
    # Skip verifying time sensitive fields.
    execution.ClearField('create_time_since_epoch')
    execution.ClearField('last_update_time_since_epoch')
    self.assertProtoEquals(
        """
        id: 1
        type_id: 3
        properties {
          key: "state"
          value {
            string_value: "new"
          }
        }
        properties {
          key: "pipeline_name"
          value {
            string_value: "my_pipeline"
          }
        }
        properties {
          key: "pipeline_root"
          value {
            string_value: "/tmp"
          }
        }
        properties {
          key: "run_id"
          value {
            string_value: "my_run_id"
          }
        }
        properties {
          key: "component_id"
          value {
            string_value: "my_component"
          }
        }
        properties {
          key: "arg_one"
          value {
            string_value: "1"
          }
        }""", execution)

    # Test publish_execution.
    m.publish_execution(
        component_info=self._component_info, output_artifacts=output_artifacts)
    # Make sure artifacts in output_dict are published.
    self.assertEqual(ArtifactState.PUBLISHED, output_artifact.state)
    # Make sure the execution state is changed.
    [execution] = m.store.get_executions_by_id([execution.id])
    self.assertEqual(metadata.EXECUTION_STATE_COMPLETE,
                     execution.properties['state'].string_value)
    # Make sure events are published.
    events = m.store.get_events_by_execution_ids([execution.id])
    self.assertEqual(2, len(events))
    self.assertEqual(input_artifact.id, events[0].artifact_id)
    self.assertEqual(metadata_store_pb2.Event.INPUT, events[0].type)
    self.assertProtoEquals(
        """
        steps {
          key: "input"
        }
        steps {
          index: 0
        }""", events[0].path)
    self.assertEqual(output_artifact.id, events[1].artifact_id)
    self.assertEqual(metadata_store_pb2.Event.OUTPUT, events[1].type)
    self.assertProtoEquals(
        """
        steps {
          key: "output"
        }
        steps {
          index: 0
        }""", events[1].path)
def testGetArtifactsDict(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    # Create and shuffle a few artifacts. The shuffled order should be
    # retained in the output of `execution_lib.get_artifacts_dict`.
    input_artifact_keys = ('input1', 'input2', 'input3')
    input_artifacts_dict = collections.OrderedDict()
    for input_key in input_artifact_keys:
      input_examples = []
      for i in range(10):
        input_example = standard_artifacts.Examples()
        input_example.uri = f'{input_key}/example{i}'
        input_example.type_id = common_utils.register_type_if_not_exist(
            m, input_example.artifact_type).id
        input_examples.append(input_example)
      random.shuffle(input_examples)
      input_artifacts_dict[input_key] = input_examples

    output_models = []
    for i in range(8):
      output_model = standard_artifacts.Model()
      output_model.uri = f'model{i}'
      output_model.type_id = common_utils.register_type_if_not_exist(
          m, output_model.artifact_type).id
      output_models.append(output_model)
    random.shuffle(output_models)
    output_artifacts_dict = {'model': output_models}

    # Store input artifacts only. Outputs will be saved in put_execution().
    input_mlmd_artifacts = [
        a.mlmd_artifact
        for a in itertools.chain(*input_artifacts_dict.values())
    ]
    artifact_ids = m.store.put_artifacts(input_mlmd_artifacts)
    for artifact_id, mlmd_artifact in zip(artifact_ids, input_mlmd_artifacts):
      mlmd_artifact.id = artifact_id

    execution = execution_lib.prepare_execution(
        m,
        metadata_store_pb2.ExecutionType(name='my_execution_type'),
        state=metadata_store_pb2.Execution.RUNNING)
    contexts = self._generate_contexts(m)

    # Change the order of the OrderedDict to shuffle the order of input keys.
    input_artifacts_dict.move_to_end('input1')
    execution = execution_lib.put_execution(
        m,
        execution,
        contexts,
        input_artifacts=input_artifacts_dict,
        output_artifacts=output_artifacts_dict)

    # Verify that the same artifacts are returned in the correct order.
    artifacts_dict = execution_lib.get_artifacts_dict(
        m, execution.id, [metadata_store_pb2.Event.INPUT])
    self.assertEqual(set(input_artifact_keys), set(artifacts_dict.keys()))
    for key in artifacts_dict:
      self.assertEqual([ex.uri for ex in input_artifacts_dict[key]],
                       [a.uri for a in artifacts_dict[key]], f'for key={key}')
    artifacts_dict = execution_lib.get_artifacts_dict(
        m, execution.id, [metadata_store_pb2.Event.OUTPUT])
    self.assertEqual({'model'}, set(artifacts_dict.keys()))
    self.assertEqual([model.uri for model in output_models],
                     [a.uri for a in artifacts_dict['model']])
def testGetCachedOutputArtifacts(self):
  # Output artifacts that will be used by the first execution with the same
  # cache key.
  output_model_one = standard_artifacts.Model()
  output_model_one.uri = 'model_one'
  output_model_two = standard_artifacts.Model()
  output_model_two.uri = 'model_two'
  output_example_one = standard_artifacts.Examples()
  output_example_one.uri = 'example_one'
  # Output artifacts that will be used by the second execution with the same
  # cache key.
  output_model_three = standard_artifacts.Model()
  output_model_three.uri = 'model_three'
  output_model_four = standard_artifacts.Model()
  output_model_four.uri = 'model_four'
  output_example_two = standard_artifacts.Examples()
  output_example_two.uri = 'example_two'
  output_models_key = 'output_models'
  output_examples_key = 'output_examples'
  with metadata.Metadata(connection_config=self._connection_config) as m:
    cache_context = context_lib.register_context_if_not_exists(
        m, context_lib.CONTEXT_TYPE_EXECUTION_CACHE, 'cache_key')
    execution_one = execution_publish_utils.register_execution(
        m, metadata_store_pb2.ExecutionType(name='my_type'), [cache_context])
    execution_publish_utils.publish_succeeded_execution(
        m,
        execution_one.id, [cache_context],
        output_artifacts={
            output_models_key: [output_model_one, output_model_two],
            output_examples_key: [output_example_one]
        })
    execution_two = execution_publish_utils.register_execution(
        m, metadata_store_pb2.ExecutionType(name='my_type'), [cache_context])
    execution_publish_utils.publish_succeeded_execution(
        m,
        execution_two.id, [cache_context],
        output_artifacts={
            output_models_key: [output_model_three, output_model_four],
            output_examples_key: [output_example_two]
        })
    # The cached outputs should be the artifacts produced by the most recent
    # execution under the given cache context.
    cached_output = cache_utils.get_cached_outputs(m, cache_context)
    self.assertLen(cached_output, 2)
    self.assertLen(cached_output[output_models_key], 2)
    self.assertLen(cached_output[output_examples_key], 1)
    self.assertProtoPartiallyEquals(
        cached_output[output_models_key][0].mlmd_artifact,
        output_model_three.mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
    self.assertProtoPartiallyEquals(
        cached_output[output_models_key][1].mlmd_artifact,
        output_model_four.mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
    self.assertProtoPartiallyEquals(
        cached_output[output_examples_key][0].mlmd_artifact,
        output_example_two.mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
def __init__(
    self,
    input: types.Channel = None,  # pylint: disable=redefined-builtin
    input_config: Optional[Union[example_gen_pb2.Input, Dict[Text,
                                                             Any]]] = None,
    output_config: Optional[Union[example_gen_pb2.Output, Dict[Text,
                                                               Any]]] = None,
    custom_config: Optional[Union[example_gen_pb2.CustomConfig,
                                  Dict[Text, Any]]] = None,
    example_artifacts: Optional[types.Channel] = None,
    custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
    input_base: Optional[types.Channel] = None,
    instance_name: Optional[Text] = None):
  """Construct a FileBasedExampleGen component.

  Args:
    input: A Channel of type `standard_artifacts.ExternalArtifact`, which
      includes one artifact whose uri is an external directory containing
      the data files. _required_
    input_config: An
      [`example_gen_pb2.Input`](https://github.com/tensorflow/tfx/blob/master/tfx/proto/example_gen.proto)
      instance, providing input configuration. If unset, the files under
      input_base will be treated as a single dataset.
    output_config: An example_gen_pb2.Output instance, providing the output
      configuration. If unset, default splits will be 'train' and 'eval'
      with size 2:1.
    custom_config: An optional example_gen_pb2.CustomConfig instance,
      providing custom configuration for executor.
    example_artifacts: Channel of 'ExamplesPath' for output train and eval
      examples.
    custom_executor_spec: Optional custom executor spec overriding the
      default executor spec specified in the component attribute.
    input_base: Backwards compatibility alias for the 'input' argument.
    instance_name: Optional unique instance name. Required only if multiple
      ExampleGen components are declared in the same pipeline. Either
      `input_base` or `input` must be present in the input arguments.
  """
  if input_base:
    absl.logging.warning(
        'The "input_base" argument to the ExampleGen component has '
        'been renamed to "input" and is deprecated. Please update your '
        'usage as support for this argument will be removed soon.')
    input = input_base
  # Configure inputs and outputs.
  input_config = input_config or utils.make_default_input_config()
  output_config = output_config or utils.make_default_output_config(
      input_config)
  if not example_artifacts:
    artifact = standard_artifacts.Examples()
    artifact.split_names = artifact_utils.encode_split_names(
        utils.generate_output_split_names(input_config, output_config))
    example_artifacts = channel_utils.as_channel([artifact])
  spec = FileBasedExampleGenSpec(
      input=input,
      input_config=input_config,
      output_config=output_config,
      custom_config=custom_config,
      examples=example_artifacts)
  super(FileBasedExampleGen, self).__init__(
      spec=spec,
      custom_executor_spec=custom_executor_spec,
      instance_name=instance_name)
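# Illustrative sketch of a non-default output_config for the constructor
# above: hash-partitioning a single input into 'train' and 'eval' with the
# 2:1 ratio mentioned in the docstring.
from tfx.proto import example_gen_pb2

output_config = example_gen_pb2.Output(
    split_config=example_gen_pb2.SplitConfig(splits=[
        example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
        example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
    ]))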