def setUp(self):
  super(ExecutorTest, self).setUp()

  # Create input_dict.
  self._input_data_dir = os.path.join(os.path.dirname(__file__), 'testdata')
  examples = standard_artifacts.Examples()
  examples.uri = os.path.join(self._input_data_dir, 'example_gen')
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  schema_artifact = standard_artifacts.Schema()
  schema_artifact.uri = os.path.join(self._input_data_dir, 'schema_gen')
  self._input_dict = {
      standard_component_specs.EXAMPLES_KEY: [examples],
      standard_component_specs.SCHEMA_KEY: [schema_artifact],
  }

  # Create output_dict.
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                     tempfile.mkdtemp(dir=flags.FLAGS.test_tmpdir)),
      self._testMethodName)
  self._transformed_output = standard_artifacts.TransformGraph()
  self._transformed_output.uri = os.path.join(output_data_dir,
                                              'transformed_output')
  self._transformed_examples = standard_artifacts.Examples()
  self._transformed_examples.uri = output_data_dir
  self._transformed_examples.split_names = artifact_utils.encode_split_names(
      ['train', 'eval'])
  temp_path_output = _TempPath()
  temp_path_output.uri = tempfile.mkdtemp()
  self._output_dict = {
      standard_component_specs.TRANSFORM_GRAPH_KEY: [
          self._transformed_output
      ],
      standard_component_specs.TRANSFORMED_EXAMPLES_KEY: [
          self._transformed_examples
      ],
      tfx_executor.TEMP_PATH_KEY: [temp_path_output],
  }

  # Create exec properties.
  self._exec_properties = {
      'custom_config':
          json.dumps({'problem_statement_path': '/some/fake/path'})
  }
def testGetFromSplitsMultipleArtifacts(self):
  """Test the split retrieval utility on a list of multiple split Artifacts."""
  artifacts = [standard_artifacts.Examples(), standard_artifacts.Examples()]
  artifacts[0].uri = '/tmp1'
  artifacts[0].split_names = artifact_utils.encode_split_names(
      ['train', 'eval'])
  artifacts[1].uri = '/tmp2'
  artifacts[1].split_names = artifact_utils.encode_split_names(
      ['train', 'eval'])

  # When creating new splits, use 'Split-<split_name>' format.
  self.assertEqual(['/tmp1/Split-train', '/tmp2/Split-train'],
                   artifact_utils.get_split_uris(artifacts, 'train'))
  self.assertEqual(['/tmp1/Split-eval', '/tmp2/Split-eval'],
                   artifact_utils.get_split_uris(artifacts, 'eval'))

  # When reading artifacts without a version, use the legacy format.
  artifacts[0].mlmd_artifact.state = metadata_store_pb2.Artifact.LIVE
  artifacts[1].mlmd_artifact.state = metadata_store_pb2.Artifact.LIVE
  self.assertEqual(['/tmp1/train', '/tmp2/train'],
                   artifact_utils.get_split_uris(artifacts, 'train'))
  self.assertEqual(['/tmp1/eval', '/tmp2/eval'],
                   artifact_utils.get_split_uris(artifacts, 'eval'))

  # When reading artifacts with an old version, use the legacy format.
  artifacts[0].set_string_custom_property(
      artifact_utils.ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY, '0.1')
  artifacts[1].set_string_custom_property(
      artifact_utils.ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY, '0.1')
  self.assertEqual(['/tmp1/train', '/tmp2/train'],
                   artifact_utils.get_split_uris(artifacts, 'train'))
  self.assertEqual(['/tmp1/eval', '/tmp2/eval'],
                   artifact_utils.get_split_uris(artifacts, 'eval'))

  # When reading artifacts with the new version, use 'Split-<split_name>'.
  artifacts[0].set_string_custom_property(
      artifact_utils.ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY,
      artifact_utils._ARTIFACT_VERSION_FOR_SPLIT_UPDATE)
  artifacts[1].set_string_custom_property(
      artifact_utils.ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY,
      artifact_utils._ARTIFACT_VERSION_FOR_SPLIT_UPDATE)
  self.assertEqual(['/tmp1/Split-train', '/tmp2/Split-train'],
                   artifact_utils.get_split_uris(artifacts, 'train'))
  self.assertEqual(['/tmp1/Split-eval', '/tmp2/Split-eval'],
                   artifact_utils.get_split_uris(artifacts, 'eval'))
def testDoLegacySingleEvalSavedModelWFairness(self, exec_properties):
  source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create input dict.
  examples = standard_artifacts.Examples()
  examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  model = standard_artifacts.Model()
  model.uri = os.path.join(source_data_dir, 'trainer/current')
  input_dict = {
      constants.EXAMPLES_KEY: [examples],
      constants.MODEL_KEY: [model],
  }

  # Create output dict.
  eval_output = standard_artifacts.ModelEvaluation()
  eval_output.uri = os.path.join(output_data_dir, 'eval_output')
  blessing_output = standard_artifacts.ModelBlessing()
  blessing_output.uri = os.path.join(output_data_dir, 'blessing_output')
  output_dict = {
      constants.EVALUATION_KEY: [eval_output],
      constants.BLESSING_KEY: [blessing_output],
  }

  try:
    # Need to import the following module so that the fairness indicator
    # post-export metric is registered. This may raise an ImportError if the
    # currently-installed version of TFMA does not support fairness
    # indicators.
    import tensorflow_model_analysis.addons.fairness.post_export_metrics.fairness_indicators  # pylint: disable=g-import-not-at-top, unused-variable
    exec_properties['fairness_indicator_thresholds'] = [
        0.1, 0.3, 0.5, 0.7, 0.9
    ]
  except ImportError:
    logging.warning(
        'Not testing fairness indicators because a compatible TFMA version '
        'is not installed.')

  # List needs to be serialized before being passed into Do function.
  exec_properties[constants.EXAMPLE_SPLITS_KEY] = json_utils.dumps(None)

  # Run executor.
  evaluator = executor.Executor()
  evaluator.Do(input_dict, output_dict, exec_properties)

  # Check evaluator outputs.
  self.assertTrue(
      fileio.exists(os.path.join(eval_output.uri, 'eval_config.json')))
  self.assertTrue(fileio.exists(os.path.join(eval_output.uri, 'metrics')))
  self.assertTrue(fileio.exists(os.path.join(eval_output.uri, 'plots')))
  self.assertFalse(
      fileio.exists(os.path.join(blessing_output.uri, 'BLESSED')))
def __init__(
    self,
    input: types.Channel = None,  # pylint: disable=redefined-builtin
    input_config: Optional[Union[example_gen_pb2.Input,
                                 Dict[Text, Any]]] = None,
    output_config: Optional[Union[example_gen_pb2.Output,
                                  Dict[Text, Any]]] = None,
    custom_config: Optional[Union[example_gen_pb2.CustomConfig,
                                  Dict[Text, Any]]] = None,
    example_artifacts: Optional[types.Channel] = None,
    custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
    input_base: Optional[types.Channel] = None,
    instance_name: Optional[Text] = None):
  """Construct a FileBasedExampleGen component.

  Args:
    input: A Channel of type `standard_artifacts.ExternalArtifact`, which
      includes one artifact whose uri is an external directory containing
      the data files. _required_
    input_config: An
      [`example_gen_pb2.Input`](https://github.com/tensorflow/tfx/blob/master/tfx/proto/example_gen.proto)
      instance, providing input configuration. If unset, the files under
      input_base will be treated as a single dataset.
    output_config: An example_gen_pb2.Output instance, providing the output
      configuration. If unset, the default splits will be 'train' and 'eval'
      with a size ratio of 2:1.
    custom_config: An optional example_gen_pb2.CustomConfig instance,
      providing custom configuration for the executor.
    example_artifacts: Channel of 'ExamplesPath' for output train and eval
      examples.
    custom_executor_spec: Optional custom executor spec overriding the
      default executor spec specified in the component attribute.
    input_base: Backwards compatibility alias for the 'input' argument.
    instance_name: Optional unique instance name. Required only if multiple
      ExampleGen components are declared in the same pipeline.

  Either `input_base` or `input` must be present in the input arguments.
  """
  input = input or input_base

  # Configure inputs and outputs.
  input_config = input_config or utils.make_default_input_config()
  output_config = output_config or utils.make_default_output_config(
      input_config)
  if not example_artifacts:
    artifact = standard_artifacts.Examples()
    artifact.split_names = artifact_utils.encode_split_names(
        utils.generate_output_split_names(input_config, output_config))
    example_artifacts = channel_utils.as_channel([artifact])
  spec = FileBasedExampleGenSpec(
      input_base=input,
      input_config=input_config,
      output_config=output_config,
      custom_config=custom_config,
      examples=example_artifacts)
  super(FileBasedExampleGen, self).__init__(
      spec=spec,
      custom_executor_spec=custom_executor_spec,
      instance_name=instance_name)
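# Illustrative usage sketch (not part of this module): wiring a
# FileBasedExampleGen with a custom executor, assuming the legacy
# `external_input` helper from tfx.utils.dsl_utils and the Avro executor
# referenced elsewhere in this codebase. The data path and split patterns are
# hypothetical placeholders.
from tfx.utils.dsl_utils import external_input

example_gen = FileBasedExampleGen(
    input=external_input('/data/my_dataset'),
    input_config=example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='train', pattern='train/*.avro'),
        example_gen_pb2.Input.Split(name='eval', pattern='eval/*.avro'),
    ]),
    custom_executor_spec=executor_spec.ExecutorClassSpec(
        avro_executor.Executor))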
def testDoValidation(self, exec_properties, blessed, has_baseline):
  source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create input dict.
  examples = standard_artifacts.Examples()
  examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  model = standard_artifacts.Model()
  baseline_model = standard_artifacts.Model()
  model.uri = os.path.join(source_data_dir, 'trainer/current')
  baseline_model.uri = os.path.join(source_data_dir, 'trainer/previous/')
  schema = standard_artifacts.Schema()
  schema.uri = os.path.join(source_data_dir, 'schema_gen')
  input_dict = {
      constants.EXAMPLES_KEY: [examples],
      constants.MODEL_KEY: [model],
      constants.SCHEMA_KEY: [schema],
  }
  if has_baseline:
    input_dict[constants.BASELINE_MODEL_KEY] = [baseline_model]

  # Create output dict.
  eval_output = standard_artifacts.ModelEvaluation()
  eval_output.uri = os.path.join(output_data_dir, 'eval_output')
  blessing_output = standard_artifacts.ModelBlessing()
  blessing_output.uri = os.path.join(output_data_dir, 'blessing_output')
  output_dict = {
      constants.EVALUATION_KEY: [eval_output],
      constants.BLESSING_KEY: [blessing_output],
  }

  # List needs to be serialized before being passed into Do function.
  exec_properties[constants.EXAMPLE_SPLITS_KEY] = json_utils.dumps(None)

  # Run executor.
  evaluator = executor.Executor()
  evaluator.Do(input_dict, output_dict, exec_properties)

  # Check evaluator outputs.
  self.assertTrue(
      fileio.exists(os.path.join(eval_output.uri, 'eval_config.json')))
  self.assertTrue(fileio.exists(os.path.join(eval_output.uri, 'metrics')))
  self.assertTrue(fileio.exists(os.path.join(eval_output.uri, 'plots')))
  self.assertTrue(fileio.exists(os.path.join(eval_output.uri, 'validations')))
  if blessed:
    self.assertTrue(
        fileio.exists(os.path.join(blessing_output.uri, 'BLESSED')))
  else:
    self.assertTrue(
        fileio.exists(os.path.join(blessing_output.uri, 'NOT_BLESSED')))
def setUp(self):
  super(ExecutorTest, self).setUp()
  self._source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  self._output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create input dict.
  examples = standard_artifacts.Examples()
  examples.uri = os.path.join(self._source_data_dir,
                              'transform/transformed_examples')
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  transform_output = standard_artifacts.TransformGraph()
  transform_output.uri = os.path.join(self._source_data_dir,
                                      'transform/transform_output')
  schema = standard_artifacts.Schema()
  schema.uri = os.path.join(self._source_data_dir, 'schema_gen')
  previous_model = standard_artifacts.Model()
  previous_model.uri = os.path.join(self._source_data_dir, 'trainer/previous')
  self._input_dict = {
      executor.EXAMPLES_KEY: [examples],
      executor.TRANSFORM_GRAPH_KEY: [transform_output],
      executor.SCHEMA_KEY: [schema],
      executor.BASE_MODEL_KEY: [previous_model]
  }

  # Create output dict.
  self._model_exports = standard_artifacts.Model()
  self._model_exports.uri = os.path.join(self._output_data_dir,
                                         'model_export_path')
  self._output_dict = {executor.OUTPUT_MODEL_KEY: [self._model_exports]}

  # Create exec properties skeleton.
  self._exec_properties = {
      'train_args':
          json_format.MessageToJson(
              trainer_pb2.TrainArgs(num_steps=1000),
              preserving_proto_field_name=True),
      'eval_args':
          json_format.MessageToJson(
              trainer_pb2.EvalArgs(num_steps=500),
              preserving_proto_field_name=True),
      'warm_starting':
          False,
  }
  self._module_file = os.path.join(self._source_data_dir, 'module_file',
                                   'trainer_module.py')
  self._trainer_fn = '%s.%s' % (trainer_module.trainer_fn.__module__,
                                trainer_module.trainer_fn.__name__)

  # Executors for test.
  self._trainer_executor = executor.Executor()
  self._generic_trainer_executor = executor.GenericExecutor()
def testEnableCache(self):
  examples = standard_artifacts.Examples()
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  statistics_gen_1 = component.StatisticsGen(
      examples=channel_utils.as_channel([examples]))
  self.assertEqual(None, statistics_gen_1.enable_cache)
  statistics_gen_2 = component.StatisticsGen(
      examples=channel_utils.as_channel([examples]), enable_cache=True)
  self.assertEqual(True, statistics_gen_2.enable_cache)
def testDoWithTwoSchemas(self):
  source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  tf.io.gfile.makedirs(output_data_dir)

  # Create input dict.
  examples = standard_artifacts.Examples()
  examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  schema = standard_artifacts.Schema()
  schema.uri = os.path.join(source_data_dir, 'schema_gen')
  input_dict = {
      executor.EXAMPLES_KEY: [examples],
      executor.SCHEMA_KEY: [schema]
  }
  exec_properties = {
      executor.STATS_OPTIONS_JSON_KEY:
          tfdv.StatsOptions(
              label_feature='company', schema=schema_pb2.Schema()).to_json(),
  }

  # Create output dict.
  stats = standard_artifacts.ExampleStatistics()
  stats.uri = output_data_dir
  stats.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  output_dict = {
      executor.STATISTICS_KEY: [stats],
  }

  # Run executor.
  stats_gen_executor = executor.Executor()
  with self.assertRaises(ValueError):
    stats_gen_executor.Do(
        input_dict, output_dict, exec_properties=exec_properties)
def testConstruct(self):
  statistics_artifact = standard_artifacts.ExampleStatistics()
  statistics_artifact.split_names = artifact_utils.encode_split_names(
      ['eval'])
  example_validator = component.ExampleValidator(
      statistics=channel_utils.as_channel([statistics_artifact]),
      schema=channel_utils.as_channel([standard_artifacts.Schema()]),
  )
  self.assertEqual(standard_artifacts.ExampleAnomalies.TYPE_NAME,
                   example_validator.outputs['anomalies'].type_name)
def testConstruct(self):
  statistics_artifact = standard_artifacts.ExampleStatistics()
  statistics_artifact.split_names = artifact_utils.encode_split_names(
      ['train'])
  schema_gen = component.SchemaGen(
      statistics=channel_utils.as_channel([statistics_artifact]))
  self.assertEqual(standard_artifacts.Schema.TYPE_NAME,
                   schema_gen.outputs['schema'].type_name)
  self.assertFalse(schema_gen.spec.exec_properties['infer_feature_shape'])
def testEvaluation(self, exec_properties, model_agnostic=False):
  source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create input dict.
  examples = standard_artifacts.Examples()
  examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  baseline_model = standard_artifacts.Model()
  baseline_model.uri = os.path.join(source_data_dir, 'trainer/previous/')
  schema = standard_artifacts.Schema()
  schema.uri = os.path.join(source_data_dir, 'schema_gen')
  input_dict = {
      EXAMPLES_KEY: [examples],
      SCHEMA_KEY: [schema],
  }
  if not model_agnostic:
    model = standard_artifacts.Model()
    model.uri = os.path.join(source_data_dir, 'trainer/current')
    input_dict[MODEL_KEY] = [model]

  # Create output dict.
  eval_output = standard_artifacts.ModelEvaluation()
  eval_output.uri = os.path.join(output_data_dir, 'eval_output')
  blessing_output = standard_artifacts.ModelBlessing()
  blessing_output.uri = os.path.join(output_data_dir, 'blessing_output')
  output_dict = {
      EVALUATION_KEY: [eval_output],
      BLESSING_KEY: [blessing_output],
  }

  # Test multiple splits.
  exec_properties[EXAMPLE_SPLITS_KEY] = json_utils.dumps(['train', 'eval'])

  if MODULE_FILE_KEY in exec_properties:
    exec_properties[MODULE_FILE_KEY] = os.path.join(source_data_dir,
                                                    'module_file',
                                                    'evaluator_module.py')

  # Run executor.
  evaluator = executor.Executor()
  evaluator.Do(input_dict, output_dict, exec_properties)

  # Check evaluator outputs.
  self.assertTrue(
      fileio.exists(os.path.join(eval_output.uri, 'eval_config.json')))
  self.assertTrue(fileio.exists(os.path.join(eval_output.uri, 'metrics')))
  self.assertTrue(fileio.exists(os.path.join(eval_output.uri, 'plots')))
  self.assertFalse(
      fileio.exists(os.path.join(blessing_output.uri, 'BLESSED')))
def __init__(self,
             statistics: types.Channel = None,
             schema: types.Channel = None,
             exclude_splits: Optional[List[Text]] = None,
             output: Optional[types.Channel] = None,
             stats: Optional[types.Channel] = None,
             instance_name: Optional[Text] = None):
  """Construct an ExampleValidator component.

  Args:
    statistics: A Channel of type `standard_artifacts.ExampleStatistics`.
      This should contain at least an 'eval' split. Other splits are
      currently ignored.
    schema: A Channel of type `standard_artifacts.Schema`. _required_
    exclude_splits: Names of splits that the example validator should not
      validate. Default behavior (when exclude_splits is set to None) is
      excluding no splits.
    output: Output channel of type `standard_artifacts.ExampleAnomalies`.
    stats: Backwards compatibility alias for the 'statistics' argument.
    instance_name: Optional name assigned to this specific instance of
      ExampleValidator. Required only if multiple ExampleValidator components
      are declared in the same pipeline.

  Either `stats` or `statistics` must be present in the arguments.
  """
  if stats:
    logging.warning(
        'The "stats" argument to the ExampleValidator component has been '
        'renamed to "statistics" and is deprecated. Please update your '
        'usage as support for this argument will be removed soon.')
    statistics = stats
  if exclude_splits is None:
    exclude_splits = []
    logging.info('Excluding no splits because exclude_splits is not set.')
  anomalies = output
  if not anomalies:
    anomalies_artifact = standard_artifacts.ExampleAnomalies()
    statistics_split_names = artifact_utils.decode_split_names(
        artifact_utils.get_single_instance(list(
            statistics.get())).split_names)
    split_names = [
        split for split in statistics_split_names
        if split not in exclude_splits
    ]
    anomalies_artifact.split_names = artifact_utils.encode_split_names(
        split_names)
    anomalies = types.Channel(
        type=standard_artifacts.ExampleAnomalies,
        artifacts=[anomalies_artifact])
  spec = ExampleValidatorSpec(
      statistics=statistics,
      schema=schema,
      exclude_splits=json_utils.dumps(exclude_splits),
      anomalies=anomalies)
  super(ExampleValidator, self).__init__(
      spec=spec, instance_name=instance_name)
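# Illustrative usage sketch (not part of this module): a typical pipeline
# wiring, assuming upstream `statistics_gen` and `schema_gen` components are
# defined elsewhere.
example_validator = ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'],
    exclude_splits=['train'])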
def _import_artifacts(self, source_uri: List[Text], reimport: bool,
                      destination_channel: types.Channel,
                      split_names: List[Text]) -> List[types.Artifact]:
  """Imports external resource in MLMD."""
  results = []
  for uri, s in zip(source_uri, split_names):
    absl.logging.info('Processing source uri: %s, split: %s' %
                      (uri, s or 'NO_SPLIT'))
    result = destination_channel.type()
    # TODO(ccy): refactor importer to treat split name just like any other
    # property.
    unfiltered_previous_artifacts = (
        self._metadata_handler.get_artifacts_by_uri(uri))
    # Filter by split name.
    desired_split_names = artifact_utils.encode_split_names([s or ''])
    previous_artifacts = []
    for previous_artifact in unfiltered_previous_artifacts:
      # TODO(ccy): refactor importer to treat split name just like any other
      # property.
      if result.PROPERTIES and SPLIT_KEY in result.PROPERTIES:
        # Consider the previous artifact only if the split_names match.
        previous_split_names = previous_artifact.properties.get(
            'split_names', None)
        if (previous_split_names and
            previous_split_names.string_value == desired_split_names):
          previous_artifacts.append(previous_artifact)
      else:
        # Unconditionally add the previous artifact for consideration.
        previous_artifacts.append(previous_artifact)
    # TODO(ccy): refactor importer to treat split name just like any other
    # property.
    if SPLIT_KEY in result.artifact_type.properties:
      result.split_names = desired_split_names
    result.uri = uri

    # If any registered artifact with the same uri also has the same
    # fingerprint and user does not ask for re-import, just reuse the latest.
    # Otherwise, register the external resource into MLMD using the type info
    # in the destination channel.
    if bool(previous_artifacts) and not reimport:
      absl.logging.info('Reusing existing artifact')
      result.set_mlmd_artifact(max(previous_artifacts, key=lambda m: m.id))
    else:
      [registered_artifact] = self._metadata_handler.publish_artifacts(
          [result])
      absl.logging.info('Registered new artifact: %s' % registered_artifact)
      result.set_mlmd_artifact(registered_artifact)

    results.append(result)
  return results
def testDo(self, mock_client):
  # Mock query result schema for _BigQueryConverter.
  mock_client.return_value.query.return_value.result.return_value.schema = (
      self._schema)
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create output dict.
  examples = standard_artifacts.Examples()
  examples.uri = output_data_dir
  output_dict = {'examples': [examples]}

  # Create exec properties.
  exec_properties = {
      'input_config':
          proto_utils.proto_to_json(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='bq', pattern='SELECT i, b, f, s FROM `fake`'),
              ])),
      'output_config':
          proto_utils.proto_to_json(
              example_gen_pb2.Output(
                  split_config=example_gen_pb2.SplitConfig(splits=[
                      example_gen_pb2.SplitConfig.Split(
                          name='train', hash_buckets=2),
                      example_gen_pb2.SplitConfig.Split(
                          name='eval', hash_buckets=1)
                  ])))
  }

  # Run executor.
  big_query_example_gen = executor.Executor(
      base_beam_executor.BaseBeamExecutor.Context(
          beam_pipeline_args=['--project=test-project']))
  big_query_example_gen.Do({}, output_dict, exec_properties)

  mock_client.assert_called_with(project='test-project')
  self.assertEqual(
      artifact_utils.encode_split_names(['train', 'eval']),
      examples.split_names)

  # Check BigQuery example gen outputs.
  train_output_file = os.path.join(examples.uri, 'Split-train',
                                   'data_tfrecord-00000-of-00001.gz')
  eval_output_file = os.path.join(examples.uri, 'Split-eval',
                                  'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(fileio.exists(train_output_file))
  self.assertTrue(fileio.exists(eval_output_file))
  self.assertGreater(
      fileio.open(train_output_file).size(),
      fileio.open(eval_output_file).size())
def testEnableCache(self):
  statistics_artifact = standard_artifacts.ExampleStatistics()
  statistics_artifact.split_names = artifact_utils.encode_split_names(
      ['train'])
  schema_gen_1 = component.SchemaGen(
      statistics=channel_utils.as_channel([statistics_artifact]))
  schema_gen_2 = component.SchemaGen(
      statistics=channel_utils.as_channel([statistics_artifact]),
      enable_cache=True)
  self.assertEqual(None, schema_gen_1.enable_cache)
  self.assertEqual(True, schema_gen_2.enable_cache)
def testConstruct(self):
  examples = standard_artifacts.Examples()
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  exclude_splits = ['eval']
  statistics_gen = component.StatisticsGen(
      examples=channel_utils.as_channel([examples]),
      exclude_splits=exclude_splits)
  self.assertEqual(standard_artifacts.ExampleStatistics.TYPE_NAME,
                   statistics_gen.outputs['statistics'].type_name)
  self.assertEqual(statistics_gen.spec.exec_properties['exclude_splits'],
                   '["eval"]')
def __init__(self,
             input_config: Union[example_gen_pb2.Input, Dict[Text, Any]],
             output_config: Optional[Union[example_gen_pb2.Output,
                                           Dict[Text, Any]]] = None,
             custom_config: Optional[Union[example_gen_pb2.CustomConfig,
                                           Dict[Text, Any]]] = None,
             example_artifacts: Optional[types.Channel] = None,
             instance_name: Optional[Text] = None,
             enable_cache: Optional[bool] = None):
  """Construct a QueryBasedExampleGen component.

  Args:
    input_config: An
      [example_gen_pb2.Input](https://github.com/tensorflow/tfx/blob/master/tfx/proto/example_gen.proto)
      instance, providing input configuration. If any field is provided as a
      RuntimeParameter, input_config should be constructed as a dict with the
      same field names as the Input proto message. _required_
    output_config: An
      [example_gen_pb2.Output](https://github.com/tensorflow/tfx/blob/master/tfx/proto/example_gen.proto)
      instance, providing output configuration. If unset, the default splits
      will be labeled as 'train' and 'eval' with a distribution ratio of 2:1.
      If any field is provided as a RuntimeParameter, output_config should be
      constructed as a dict with the same field names as the Output proto
      message.
    custom_config: An
      [example_gen_pb2.CustomConfig](https://github.com/tensorflow/tfx/blob/master/tfx/proto/example_gen.proto)
      instance, providing custom configuration for ExampleGen. If any field
      is provided as a RuntimeParameter, custom_config should be constructed
      as a dict.
    example_artifacts: Channel of `standard_artifacts.Examples` for output
      train and eval examples.
    instance_name: Optional unique instance name. Required only if multiple
      ExampleGen components are declared in the same pipeline.
    enable_cache: Optional boolean to indicate whether caching is enabled for
      the QueryBasedExampleGen component. If not specified, defaults to the
      value specified for the pipeline's enable_cache parameter.
  """
  # Configure outputs.
  output_config = output_config or utils.make_default_output_config(
      input_config)
  if not example_artifacts:
    artifact = standard_artifacts.Examples()
    artifact.split_names = artifact_utils.encode_split_names(
        utils.generate_output_split_names(input_config, output_config))
    example_artifacts = channel_utils.as_channel([artifact])
  spec = QueryBasedExampleGenSpec(
      input_config=input_config,
      output_config=output_config,
      custom_config=custom_config,
      examples=example_artifacts)
  super(_QueryBasedExampleGen, self).__init__(
      spec=spec, instance_name=instance_name, enable_cache=enable_cache)
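# Illustrative usage sketch (not part of this module): query-based subclasses
# are typically constructed from a query, which becomes the split pattern of
# the input config. BigQueryExampleGen is assumed to be available from the
# tfx.extensions.google_cloud_big_query package; the query is a hypothetical
# placeholder.
example_gen = BigQueryExampleGen(
    query='SELECT * FROM `project.dataset.table`')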
def testDo(self, mock_client):
  # Mock query result schema for _BigQueryConverter.
  mock_client.return_value.query.return_value.result.return_value.schema = (
      self._schema)
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create output dict.
  examples = standard_artifacts.Examples()
  examples.uri = output_data_dir
  output_dict = {'examples': [examples]}

  # Create exec properties.
  exec_properties = {
      'input_config':
          json_format.MessageToJson(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='bq', pattern='SELECT i, b, f, s FROM `fake`'),
              ]),
              preserving_proto_field_name=True),
      'output_config':
          json_format.MessageToJson(
              example_gen_pb2.Output(
                  split_config=example_gen_pb2.SplitConfig(splits=[
                      example_gen_pb2.SplitConfig.Split(
                          name='train', hash_buckets=2),
                      example_gen_pb2.SplitConfig.Split(
                          name='eval', hash_buckets=1)
                  ])),
              preserving_proto_field_name=True)
  }

  # Run executor.
  big_query_example_gen = executor.Executor()
  big_query_example_gen.Do({}, output_dict, exec_properties)

  self.assertEqual(
      artifact_utils.encode_split_names(['train', 'eval']),
      examples.split_names)

  # Check BigQuery example gen outputs.
  train_output_file = os.path.join(examples.uri, 'train',
                                   'data_tfrecord-00000-of-00001.gz')
  eval_output_file = os.path.join(examples.uri, 'eval',
                                  'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(tf.io.gfile.exists(train_output_file))
  self.assertTrue(tf.io.gfile.exists(eval_output_file))
  self.assertGreater(
      tf.io.gfile.GFile(train_output_file).size(),
      tf.io.gfile.GFile(eval_output_file).size())
def testConstruct(self):
  statistics_artifact = standard_artifacts.ExampleStatistics()
  statistics_artifact.split_names = artifact_utils.encode_split_names(
      ['train', 'eval'])
  exclude_splits = ['eval']
  example_validator = component.ExampleValidator(
      statistics=channel_utils.as_channel([statistics_artifact]),
      schema=channel_utils.as_channel([standard_artifacts.Schema()]),
      exclude_splits=exclude_splits)
  self.assertEqual(standard_artifacts.ExampleAnomalies.TYPE_NAME,
                   example_validator.outputs['anomalies'].type_name)
  self.assertEqual(example_validator.spec.exec_properties['exclude_splits'],
                   '["eval"]')
def testGetSplitUriDeprecated(self):
  with mock.patch.object(tf_logging, 'warning'):
    warn_mock = mock.MagicMock()
    tf_logging.warning = warn_mock
    my_artifact = artifact.Artifact('TestType')
    my_artifact.uri = '123'
    my_artifact.split_names = artifact_utils.encode_split_names(['train'])
    self.assertEqual('123/train',
                     types.get_split_uri([my_artifact], 'train'))
    warn_mock.assert_called_once()
    self.assertIn('tfx.utils.types.get_split_uri has been renamed to',
                  warn_mock.call_args[0][5])
def testConstructWithParameter(self):
  statistics_artifact = standard_artifacts.ExampleStatistics()
  statistics_artifact.split_names = artifact_utils.encode_split_names(
      ['train'])
  infer_shape = data_types.RuntimeParameter(name='infer-shape', ptype=bool)
  schema_gen = component.SchemaGen(
      statistics=channel_utils.as_channel([statistics_artifact]),
      infer_feature_shape=infer_shape)
  self.assertEqual(standard_artifacts.Schema.TYPE_NAME,
                   schema_gen.outputs['schema'].type_name)
  self.assertJsonEqual(
      str(schema_gen.spec.exec_properties['infer_feature_shape']),
      str(infer_shape))
def _make_base_do_params(self, source_data_dir, output_data_dir):
  # Create input dict.
  example1 = standard_artifacts.Examples()
  example1.uri = self._ARTIFACT1_URI
  example1.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  example2 = copy.deepcopy(example1)
  example2.uri = self._ARTIFACT2_URI
  self._example_artifacts = [example1, example2]
  schema_artifact = standard_artifacts.Schema()
  schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen')
  self._input_dict = {
      standard_component_specs.EXAMPLES_KEY: self._example_artifacts[:1],
      standard_component_specs.SCHEMA_KEY: [schema_artifact],
  }

  # Create output dict.
  self._transformed_output = standard_artifacts.TransformGraph()
  self._transformed_output.uri = os.path.join(output_data_dir,
                                              'transformed_graph')
  transformed1 = standard_artifacts.Examples()
  transformed1.uri = os.path.join(output_data_dir, 'transformed_examples',
                                  '0')
  transformed2 = standard_artifacts.Examples()
  transformed2.uri = os.path.join(output_data_dir, 'transformed_examples',
                                  '1')
  self._transformed_example_artifacts = [transformed1, transformed2]
  temp_path_output = _TempPath()
  temp_path_output.uri = tempfile.mkdtemp()
  self._updated_analyzer_cache_artifact = standard_artifacts.TransformCache()
  self._updated_analyzer_cache_artifact.uri = os.path.join(
      self._output_data_dir, 'CACHE')
  self._output_dict = {
      standard_component_specs.TRANSFORM_GRAPH_KEY: [
          self._transformed_output
      ],
      standard_component_specs.TRANSFORMED_EXAMPLES_KEY:
          self._transformed_example_artifacts[:1],
      executor.TEMP_PATH_KEY: [temp_path_output],
      standard_component_specs.UPDATED_ANALYZER_CACHE_KEY: [
          self._updated_analyzer_cache_artifact
      ],
  }

  # Create exec properties skeleton.
  self._exec_properties = {}
def setUp(self):
  super(PlaceholderUtilsTest, self).setUp()
  examples = [standard_artifacts.Examples()]
  examples[0].uri = "/tmp"
  examples[0].split_names = artifact_utils.encode_split_names(
      ["train", "eval"])
  self._serving_spec = infra_validator_pb2.ServingSpec()
  self._serving_spec.tensorflow_serving.tags.extend(["latest", "1.15.0-gpu"])
  self._resolution_context = placeholder_utils.ResolutionContext(
      exec_info=data_types.ExecutionInfo(
          input_dict={
              "model": [standard_artifacts.Model()],
              "examples": examples,
          },
          output_dict={"blessing": [standard_artifacts.ModelBlessing()]},
          exec_properties={
              "proto_property":
                  json_format.MessageToJson(
                      message=self._serving_spec,
                      sort_keys=True,
                      preserving_proto_field_name=True,
                      indent=0)
          },
          execution_output_uri="test_executor_output_uri",
          stateful_working_dir="test_stateful_working_dir",
          pipeline_node=pipeline_pb2.PipelineNode(
              node_info=pipeline_pb2.NodeInfo(
                  type=metadata_store_pb2.ExecutionType(
                      name="infra_validator"))),
          pipeline_info=pipeline_pb2.PipelineInfo(id="test_pipeline_id")),
      executor_spec=executable_spec_pb2.PythonClassExecutableSpec(
          class_path="test_class_path"),
  )
  # Resolution context to simulate missing optional values.
  self._none_resolution_context = placeholder_utils.ResolutionContext(
      exec_info=data_types.ExecutionInfo(
          input_dict={
              "model": [],
              "examples": [],
          },
          output_dict={"blessing": []},
          exec_properties={},
          pipeline_node=pipeline_pb2.PipelineNode(
              node_info=pipeline_pb2.NodeInfo(
                  type=metadata_store_pb2.ExecutionType(
                      name="infra_validator"))),
          pipeline_info=pipeline_pb2.PipelineInfo(id="test_pipeline_id")),
      executor_spec=None,
      platform_config=None)
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
  source = exec_properties[StepKeys.SOURCE]
  args = exec_properties[StepKeys.ARGS]

  c = source_utils.load_source_path_class(source)
  tokenizer_step: BaseTokenizer = c(**args)

  tokenizer_location = artifact_utils.get_single_uri(
      output_dict["tokenizer"])

  split_uris, split_names, all_files = [], [], []
  for artifact in input_dict["examples"]:
    for split in artifact_utils.decode_split_names(artifact.split_names):
      split_names.append(split)
      uri = os.path.join(artifact.uri, split)
      split_uris.append((split, uri))
      all_files += path_utils.list_dir(uri)

  # Get output split path.
  output_examples = artifact_utils.get_single_instance(
      output_dict["output_examples"])
  output_examples.split_names = artifact_utils.encode_split_names(
      split_names)

  if not tokenizer_step.skip_training:
    tokenizer_step.train(files=all_files)
    tokenizer_step.save(output_dir=tokenizer_location)

  with self._make_beam_pipeline() as p:
    for split, uri in split_uris:
      input_uri = io_utils.all_files_pattern(uri)
      _ = (p
           | 'ReadData.' + split >> beam.io.ReadFromTFRecord(
               file_pattern=input_uri)
           | 'ParseTFExFromString.' + split >> beam.Map(
               tf.train.Example.FromString)
           | 'AddTokens.' + split >> beam.Map(
               append_tf_example, tokenizer_step=tokenizer_step)
           | 'Serialize.' + split >> beam.Map(lambda x: x.SerializeToString())
           | 'WriteSplit.' + split >> WriteSplit(
               get_split_uri(output_dict["output_examples"], split)))
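# Illustrative sketch (not part of this module) of the exec_properties this
# executor expects: StepKeys.SOURCE is an importable path to a BaseTokenizer
# subclass and StepKeys.ARGS its constructor kwargs. The class path and the
# kwargs below are hypothetical placeholders.
exec_properties = {
    StepKeys.SOURCE: "my_project.tokenizers.MyWordPieceTokenizer",
    StepKeys.ARGS: {"vocab_size": 8000},
}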
def testDo(self):
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create output dict.
  examples = standard_artifacts.Examples()
  examples.uri = output_data_dir
  output_dict = {utils.EXAMPLES_KEY: [examples]}

  # Create exec properties.
  exec_properties = {
      utils.INPUT_BASE_KEY:
          self._input_data_dir,
      utils.INPUT_CONFIG_KEY:
          json_format.MessageToJson(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='avro', pattern='avro/*.avro'),
              ]),
              preserving_proto_field_name=True),
      utils.OUTPUT_CONFIG_KEY:
          json_format.MessageToJson(
              example_gen_pb2.Output(
                  split_config=example_gen_pb2.SplitConfig(splits=[
                      example_gen_pb2.SplitConfig.Split(
                          name='train', hash_buckets=2),
                      example_gen_pb2.SplitConfig.Split(
                          name='eval', hash_buckets=1)
                  ])),
              preserving_proto_field_name=True)
  }

  # Run executor.
  avro_example_gen = avro_executor.Executor()
  avro_example_gen.Do({}, output_dict, exec_properties)

  self.assertEqual(
      artifact_utils.encode_split_names(['train', 'eval']),
      examples.split_names)

  # Check Avro example gen outputs.
  train_output_file = os.path.join(examples.uri, 'train',
                                   'data_tfrecord-00000-of-00001.gz')
  eval_output_file = os.path.join(examples.uri, 'eval',
                                  'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(fileio.exists(train_output_file))
  self.assertTrue(fileio.exists(eval_output_file))
  self.assertGreater(
      fileio.open(train_output_file).size(),
      fileio.open(eval_output_file).size())
def setUp(self):
  super(ExecutorTest, self).setUp()
  source_data_dir = os.path.dirname(os.path.dirname(__file__))
  input_data_dir = os.path.join(source_data_dir, 'testdata')

  statistics = standard_artifacts.ExampleStatistics()
  statistics.uri = os.path.join(input_data_dir,
                                'StatisticsGen.train_mockdata_1',
                                'statistics', '5')
  statistics.split_names = artifact_utils.encode_split_names(
      ['train', 'eval'])
  transformed_examples = standard_artifacts.Examples()
  transformed_examples.uri = os.path.join(input_data_dir,
                                          'Transform.train_mockdata_1',
                                          'transformed_examples', '10')
  transformed_examples.split_names = artifact_utils.encode_split_names(
      ['train', 'eval'])
  self._input_dict = {
      executor.EXAMPLES_KEY: [transformed_examples],
      executor.STATISTICS_KEY: [statistics],
  }

  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                     tempfile.mkdtemp(dir=flags.FLAGS.test_tmpdir)),
      self._testMethodName)
  self._metafeatures = artifacts.MetaFeatures()
  self._metafeatures.uri = output_data_dir
  self._output_dict = {
      executor.METAFEATURES_KEY: [self._metafeatures],
  }

  self._exec_properties = {
      'custom_config': {
          'problem_statement_path': '/some/fake/path'
      }
  }
def testConstruct(self):
  examples = standard_artifacts.Examples()
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  examples.span = 1
  binder = binder_component.DataViewBinder(
      input_examples=channel_utils.as_channel([examples]),
      data_view=channel_utils.as_channel([standard_artifacts.DataView()]))
  output_examples = binder.outputs['output_examples']
  self.assertIsNotNone(output_examples)
  output_examples = output_examples.get()
  self.assertLen(output_examples, 1)
  self._assert_example_artifact_equal(output_examples[0], examples)
def setUp(self):
  super().setUp()
  self._testdata_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  self._output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  self._context = executor.Executor.Context(
      tmp_dir=self._output_data_dir, unique_id='1')

  # Create input dict.
  e1 = standard_artifacts.Examples()
  e1.uri = os.path.join(self._testdata_dir, 'penguin', 'data')
  e1.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  e2 = copy.deepcopy(e1)
  self._single_artifact = [e1]
  self._multiple_artifacts = [e1, e2]

  schema = standard_artifacts.Schema()
  schema.uri = os.path.join(self._testdata_dir, 'penguin', 'schema')

  base_model = standard_artifacts.Model()
  base_model.uri = os.path.join(self._testdata_dir, 'trainer/previous')

  self._input_dict = {
      standard_component_specs.EXAMPLES_KEY: self._single_artifact,
      standard_component_specs.SCHEMA_KEY: [schema],
      standard_component_specs.BASE_MODEL_KEY: [base_model]
  }

  # Create output dict.
  self._best_hparams = standard_artifacts.Model()
  self._best_hparams.uri = os.path.join(self._output_data_dir,
                                        'best_hparams')
  self._output_dict = {
      standard_component_specs.BEST_HYPERPARAMETERS_KEY: [
          self._best_hparams
      ],
  }

  # Create exec properties.
  self._exec_properties = {
      standard_component_specs.TRAIN_ARGS_KEY:
          proto_utils.proto_to_json(trainer_pb2.TrainArgs(num_steps=100)),
      standard_component_specs.EVAL_ARGS_KEY:
          proto_utils.proto_to_json(trainer_pb2.EvalArgs(num_steps=50)),
  }
def testEnableCache(self):
  statistics_artifact = standard_artifacts.ExampleStatistics()
  statistics_artifact.split_names = artifact_utils.encode_split_names(
      ['eval'])
  example_validator_1 = component.ExampleValidator(
      statistics=channel_utils.as_channel([statistics_artifact]),
      schema=channel_utils.as_channel([standard_artifacts.Schema()]),
  )
  self.assertEqual(None, example_validator_1.enable_cache)
  example_validator_2 = component.ExampleValidator(
      statistics=channel_utils.as_channel([statistics_artifact]),
      schema=channel_utils.as_channel([standard_artifacts.Schema()]),
      enable_cache=True)
  self.assertEqual(True, example_validator_2.enable_cache)
def testConstructWithSchemaAndStatsOptions(self):
  examples = standard_artifacts.Examples()
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  schema = standard_artifacts.Schema()
  stats_options = tfdv.StatsOptions(weight_feature='weight')
  statistics_gen = component.StatisticsGen(
      examples=channel_utils.as_channel([examples]),
      schema=channel_utils.as_channel([schema]),
      stats_options=stats_options)
  self.assertEqual(
      standard_artifacts.ExampleStatistics.TYPE_NAME,
      statistics_gen.outputs[
          standard_component_specs.STATISTICS_KEY].type_name)