def test_construct(self):
  examples = types.TfxType(type_name='ExamplesPath')
  model_exports = types.TfxType(type_name='ModelExportPath')
  evaluator = component.Evaluator(
      examples=channel.as_channel([examples]),
      model_exports=channel.as_channel([model_exports]))
  self.assertEqual('ModelEvalPath', evaluator.outputs.output.type_name)
def test_do(self):
  input_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
      'testdata')
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create input dict.
  input_base = types.TfxType(type_name='ExternalPath')
  input_base.uri = os.path.join(input_data_dir, 'external/csv/')
  input_dict = {'input-base': [input_base]}

  # Create output dict.
  train_examples = types.TfxType(type_name='ExamplesPath', split='train')
  train_examples.uri = os.path.join(output_data_dir, 'train')
  eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
  eval_examples.uri = os.path.join(output_data_dir, 'eval')
  output_dict = {'examples': [train_examples, eval_examples]}

  # Run executor.
  csv_example_gen = executor.Executor()
  csv_example_gen.Do(input_dict, output_dict, {})

  # Check CSV example gen outputs.
  train_output_file = os.path.join(train_examples.uri,
                                   'data_tfrecord-00000-of-00001.gz')
  eval_output_file = os.path.join(eval_examples.uri,
                                  'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(tf.gfile.Exists(train_output_file))
  self.assertTrue(tf.gfile.Exists(eval_output_file))
  self.assertGreater(
      tf.gfile.GFile(train_output_file).size(),
      tf.gfile.GFile(eval_output_file).size())
def test_channel_as_channel_success(self):
  instance_a = types.TfxType('MyTypeName')
  instance_b = types.TfxType('MyTypeName')
  chnl_original = channel.Channel(
      'MyTypeName', static_artifact_collection=[instance_a, instance_b])
  chnl_result = channel.as_channel(chnl_original)
  self.assertEqual(chnl_original, chnl_result)
def testDo(self, mock_client):
  # Mock query result schema for _BigQueryConverter.
  mock_client.return_value.query.return_value.result.return_value.schema = self._schema

  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create output dict.
  train_examples = types.TfxType(type_name='ExamplesPath', split='train')
  train_examples.uri = os.path.join(output_data_dir, 'train')
  eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
  eval_examples.uri = os.path.join(output_data_dir, 'eval')
  output_dict = {'examples': [train_examples, eval_examples]}

  # Run executor.
  big_query_example_gen = executor.Executor()
  big_query_example_gen.Do({}, output_dict, self._exec_properties)

  # Check BigQuery example gen outputs.
  train_output_file = os.path.join(train_examples.uri,
                                   'data_tfrecord-00000-of-00001.gz')
  eval_output_file = os.path.join(eval_examples.uri,
                                  'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(tf.gfile.Exists(train_output_file))
  self.assertTrue(tf.gfile.Exists(eval_output_file))
  self.assertGreater(
      tf.gfile.GFile(train_output_file).size(),
      tf.gfile.GFile(eval_output_file).size())
def test_construct(self):
  train_examples = types.TfxType(type_name='ExamplesPath', split='train')
  eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
  statistics_gen = component.StatisticsGen(
      input_data=channel.as_channel([train_examples, eval_examples]))
  self.assertEqual('ExampleStatisticsPath',
                   statistics_gen.outputs.output.type_name)
def setUp(self):
  self._source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  self._output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  tf.gfile.MakeDirs(self._output_data_dir)
  self._model_export = types.TfxType(type_name='ModelExportPath')
  self._model_export.uri = os.path.join(self._source_data_dir,
                                        'trainer/current/')
  self._model_blessing = types.TfxType(type_name='ModelBlessingPath')
  self._input_dict = {
      'model_export': [self._model_export],
      'model_blessing': [self._model_blessing],
  }
  self._model_push = types.TfxType(type_name='ModelPushPath')
  self._model_push.uri = os.path.join(self._output_data_dir, 'model_push')
  tf.gfile.MakeDirs(self._model_push.uri)
  self._output_dict = {
      'model_push': [self._model_push],
  }
  self._serving_model_dir = os.path.join(self._output_data_dir,
                                         'serving_model_dir')
  tf.gfile.MakeDirs(self._serving_model_dir)
  self._exec_properties = {
      'push_destination':
          json_format.MessageToJson(
              pusher_pb2.PushDestination(
                  filesystem=pusher_pb2.PushDestination.Filesystem(
                      base_directory=self._serving_model_dir))),
  }
  self._executor = executor.Executor()
def test_valid_channel(self):
  instance_a = types.TfxType('MyTypeName')
  instance_b = types.TfxType('MyTypeName')
  chnl = channel.Channel(
      'MyTypeName', static_artifact_collection=[instance_a, instance_b])
  self.assertEqual(chnl.type_name, 'MyTypeName')
  self.assertItemsEqual(chnl.get(), [instance_a, instance_b])
def test_fetch_previous_result(self):
  with Metadata(
      connection_config=self._connection_config, logger=self._logger) as m:
    # Create a 'previous' execution.
    exec_properties = {'log_root': 'path'}
    eid = m.prepare_execution('Test', exec_properties)
    input_artifact = types.TfxType(type_name='ExamplesPath')
    m.publish_artifacts([input_artifact])
    output_artifact = types.TfxType(type_name='ExamplesPath')
    input_dict = {'input': [input_artifact]}
    output_dict = {'output': [output_artifact]}
    m.publish_execution(eid, input_dict, output_dict)

    # Test previous_run.
    self.assertEqual(None, m.previous_run('Test', input_dict, {}))
    self.assertEqual(None, m.previous_run('Test', {}, exec_properties))
    self.assertEqual(None,
                     m.previous_run('Test2', input_dict, exec_properties))
    self.assertEqual(eid,
                     m.previous_run('Test', input_dict, exec_properties))

    # Test fetch_previous_result_artifacts.
    new_output_artifact = types.TfxType(type_name='ExamplesPath')
    self.assertNotEqual(types.ARTIFACT_STATE_PUBLISHED,
                        new_output_artifact.state)
    new_output_dict = {'output': [new_output_artifact]}
    updated_output_dict = m.fetch_previous_result_artifacts(
        new_output_dict, eid)
    previous_artifact = output_dict['output'][-1].artifact
    current_artifact = updated_output_dict['output'][-1].artifact
    self.assertEqual(types.ARTIFACT_STATE_PUBLISHED,
                     current_artifact.properties['state'].string_value)
    self.assertEqual(previous_artifact.id, current_artifact.id)
    self.assertEqual(previous_artifact.type_id, current_artifact.type_id)
def test_invalid_channel_type(self):
  instance_a = types.TfxType('MyTypeName')
  instance_b = types.TfxType('MyTypeName')
  with self.assertRaises(ValueError):
    channel.Channel(
        'AnotherTypeName',
        static_artifact_collection=[instance_a, instance_b])
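
# A minimal usage sketch consolidating what the channel tests above verify,
# assuming only the behavior they exercise: as_channel() wraps a list of
# same-typed artifacts into a Channel and passes an existing Channel through
# unchanged, while a declared type name that disagrees with the artifacts
# raises ValueError. The helper name _channel_usage_sketch is illustrative
# and not part of the test suite.
def _channel_usage_sketch():
  artifact = types.TfxType('MyTypeName')
  # Wrapping a list of artifacts yields a Channel of their shared type name.
  chnl = channel.as_channel([artifact])
  assert chnl.type_name == 'MyTypeName'
  assert artifact in chnl.get()
  # Passing an existing Channel through as_channel returns it as-is.
  assert channel.as_channel(chnl) == chnl
  # A mismatched type name is rejected at construction time.
  try:
    channel.Channel(
        'AnotherTypeName', static_artifact_collection=[artifact])
  except ValueError:
    pass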
def setUp(self):
  self._source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create input dict.
  eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
  eval_examples.uri = os.path.join(self._source_data_dir,
                                   'csv_example_gen/eval/')
  model = types.TfxType(type_name='ModelExportPath')
  model.uri = os.path.join(self._source_data_dir, 'trainer/current/')
  self._input_dict = {
      'examples': [eval_examples],
      'model': [model],
  }

  # Create output dict.
  self._blessing = types.TfxType('ModelBlessingPath')
  self._blessing.uri = os.path.join(output_data_dir, 'blessing')
  self._output_dict = {'blessing': [self._blessing]}

  # Create context.
  self._tmp_dir = os.path.join(output_data_dir, '.temp')
  self._context = executor.Executor.Context(
      tmp_dir=self._tmp_dir, unique_id='2')
def test_do(self):
  source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  train_stats_artifact = types.TfxType('ExampleStatsPath', split='train')
  train_stats_artifact.uri = os.path.join(source_data_dir,
                                          'statistics_gen/train/')
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  schema_output = types.TfxType('SchemaPath')
  schema_output.uri = os.path.join(output_data_dir, 'schema_output')

  input_dict = {
      'stats': [train_stats_artifact],
  }
  output_dict = {
      'output': [schema_output],
  }
  exec_properties = {}

  schema_gen_executor = executor.Executor()
  schema_gen_executor.Do(input_dict, output_dict, exec_properties)
  self.assertNotEqual(0, len(tf.gfile.ListDirectory(schema_output.uri)))
def setUp(self):
  self._mock_metadata = tf.test.mock.Mock()
  self._input_dict = {
      'input_data': [types.TfxType(type_name='InputType')],
  }
  input_dir = os.path.join(
      os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
      self._testMethodName, 'input_dir')
  # Valid input artifacts must have a uri pointing to an existing directory.
  for key, input_list in self._input_dict.items():
    for index, artifact in enumerate(input_list):
      artifact.id = index + 1
      uri = os.path.join(input_dir, key, str(artifact.id), '')
      artifact.uri = uri
      tf.gfile.MakeDirs(uri)
  self._output_dict = {
      'output_data': [types.TfxType(type_name='OutputType')],
  }
  self._exec_properties = {
      'key': 'value',
  }
  self._base_output_dir = os.path.join(
      os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
      self._testMethodName, 'base_output_dir')
  self._driver_options = base_driver.DriverOptions(
      worker_name='worker_name',
      base_output_dir=self._base_output_dir,
      enable_cache=True)
  self._execution_id = 100
  log_root = os.path.join(self._base_output_dir, 'log_dir')
  logger_config = logging_utils.LoggerConfig(log_root=log_root)
  self._logger = logging_utils.get_logger(logger_config)
def setUp(self):
  self._source_data_dir = os.path.join(
      os.path.dirname(
          os.path.dirname(os.path.dirname(os.path.dirname(__file__)))),
      'components', 'testdata')
  self._output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  tf.gfile.MakeDirs(self._output_data_dir)
  self._model_export = types.TfxType(type_name='ModelExportPath')
  self._model_export.uri = os.path.join(self._source_data_dir,
                                        'trainer/current/')
  self._model_blessing = types.TfxType(type_name='ModelBlessingPath')
  self._input_dict = {
      'model_export': [self._model_export],
      'model_blessing': [self._model_blessing],
  }
  self._model_push = types.TfxType(type_name='ModelPushPath')
  self._model_push.uri = os.path.join(self._output_data_dir, 'model_push')
  tf.gfile.MakeDirs(self._model_push.uri)
  self._output_dict = {
      'model_push': [self._model_push],
  }
  self._exec_properties = {
      'custom_config': {
          'ai_platform_serving_args': {
              'model_name': 'model_name',
              'project_id': 'project_id'
          },
      },
  }
  self._executor = Executor()
def test_do(self):
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create output dict.
  train_examples = types.TfxType(type_name='ExamplesPath', split='train')
  train_examples.uri = os.path.join(output_data_dir, 'train')
  eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
  eval_examples.uri = os.path.join(output_data_dir, 'eval')
  output_dict = {'examples': [train_examples, eval_examples]}

  # Run executor.
  example_gen = TestExampleGenExecutor()
  example_gen.Do({}, output_dict, {})

  # Check example gen outputs.
  train_output_file = os.path.join(train_examples.uri,
                                   'data_tfrecord-00000-of-00001.gz')
  eval_output_file = os.path.join(eval_examples.uri,
                                  'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(tf.gfile.Exists(train_output_file))
  self.assertTrue(tf.gfile.Exists(eval_output_file))
  self.assertGreater(
      tf.gfile.GFile(train_output_file).size(),
      tf.gfile.GFile(eval_output_file).size())
def test_do(self):
  source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create input dict.
  train_examples = types.TfxType(type_name='ExamplesPath', split='train')
  train_examples.uri = os.path.join(
      source_data_dir, 'transform/transformed_examples/train/')
  eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
  eval_examples.uri = os.path.join(
      source_data_dir, 'transform/transformed_examples/eval/')
  transform_output = types.TfxType(type_name='TransformPath')
  transform_output.uri = os.path.join(source_data_dir,
                                      'transform/transform_output/')
  schema = types.TfxType(type_name='SchemaPath')
  schema.uri = os.path.join(source_data_dir, 'schema_gen/')
  input_dict = {
      'transformed_examples': [train_examples, eval_examples],
      'transform_output': [transform_output],
      'schema': [schema],
  }

  # Create output dict.
  model_exports = types.TfxType(type_name='ModelExportPath')
  model_exports.uri = os.path.join(output_data_dir, 'model_export_path')
  output_dict = {'output': [model_exports]}

  # Create exec properties.
  module_file_path = os.path.join(source_data_dir, 'module_file',
                                  'trainer_module.py')
  exec_properties = {
      'train_args':
          json_format.MessageToJson(trainer_pb2.TrainArgs(num_steps=1000)),
      'eval_args':
          json_format.MessageToJson(trainer_pb2.EvalArgs(num_steps=500)),
      'module_file': module_file_path,
      'warm_starting': False,
  }

  # Run executor.
  pipeline = beam.Pipeline()
  trainer_executor = executor.Executor(pipeline)
  trainer_executor.Do(
      input_dict=input_dict,
      output_dict=output_dict,
      exec_properties=exec_properties)

  # Check outputs.
  self.assertTrue(
      tf.gfile.Exists(os.path.join(model_exports.uri, 'eval_model_dir')))
  self.assertTrue(
      tf.gfile.Exists(os.path.join(model_exports.uri, 'serving_model_dir')))
def test_construct(self):
  examples = types.TfxType(type_name='ExamplesPath')
  model = types.TfxType(type_name='ModelExportPath')
  model_validator = component.ModelValidator(
      examples=channel.as_channel([examples]),
      model=channel.as_channel([model]))
  self.assertEqual('ModelBlessingPath',
                   model_validator.outputs.blessing.type_name)
def test_construct(self):
  example_validator = component.ExampleValidator(
      stats=channel.as_channel([
          types.TfxType(type_name='ExampleStatisticsPath', split='eval')
      ]),
      schema=channel.as_channel([types.TfxType(type_name='SchemaPath')]),
  )
  self.assertEqual('ExampleValidationPath',
                   example_validator.outputs.output.type_name)
def setUp(self):
  self.input_one = types.TfxType('INPUT_ONE')
  self.input_one.source = airflow_component._OrchestrationSource(
      'input_one_key', 'input_one_component_id')
  self.output_one = types.TfxType('OUTPUT_ONE')
  self.output_one.source = airflow_component._OrchestrationSource(
      'output_one_key', 'output_one_component_id')
  self.input_one_json = json.dumps([self.input_one.json_dict()])
  self.output_one_json = json.dumps([self.output_one.json_dict()])
  self._logger_config = logging_utils.LoggerConfig()
def test_execution(self):
  with Metadata(
      connection_config=self._connection_config, logger=self._logger) as m:
    # Test prepare_execution.
    exec_properties = {}
    eid = m.prepare_execution('Test', exec_properties)
    [execution] = m.store.get_executions()
    self.assertProtoEquals(
        """
        id: 1
        type_id: 1
        properties {
          key: "state"
          value {
            string_value: "new"
          }
        }""", execution)

    # Test publish_execution.
    input_artifact = types.TfxType(type_name='ExamplesPath')
    m.publish_artifacts([input_artifact])
    output_artifact = types.TfxType(type_name='ExamplesPath')
    input_dict = {'input': [input_artifact]}
    output_dict = {'output': [output_artifact]}
    m.publish_execution(eid, input_dict, output_dict)
    # Make sure artifacts in output_dict are published.
    self.assertEqual(types.ARTIFACT_STATE_PUBLISHED, output_artifact.state)
    # Make sure the execution state is changed.
    [execution] = m.store.get_executions_by_id([eid])
    self.assertEqual('complete', execution.properties['state'].string_value)
    # Make sure events are published.
    events = m.store.get_events_by_execution_ids([eid])
    self.assertEqual(2, len(events))
    self.assertEqual(input_artifact.id, events[0].artifact_id)
    self.assertEqual(metadata_store_pb2.Event.DECLARED_INPUT,
                     events[0].type)
    self.assertProtoEquals(
        """
        steps {
          key: "input"
        }
        steps {
          index: 0
        }""", events[0].path)
    self.assertEqual(output_artifact.id, events[1].artifact_id)
    self.assertEqual(metadata_store_pb2.Event.DECLARED_OUTPUT,
                     events[1].type)
    self.assertProtoEquals(
        """
        steps {
          key: "output"
        }
        steps {
          index: 0
        }""", events[1].path)
def test_construct_with_slice_spec(self):
  examples = types.TfxType(type_name='ExamplesPath')
  model_exports = types.TfxType(type_name='ModelExportPath')
  evaluator = component.Evaluator(
      examples=channel.as_channel([examples]),
      model_exports=channel.as_channel([model_exports]),
      feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(specs=[
          evaluator_pb2.SingleSlicingSpec(
              column_for_slicing=['trip_start_hour'])
      ]))
  self.assertEqual('ModelEvalPath', evaluator.outputs.output.type_name)
def test_construct(self):
  transformed_examples = types.TfxType(type_name='ExamplesPath')
  transform_output = types.TfxType(type_name='TransformPath')
  schema = types.TfxType(type_name='SchemaPath')
  trainer = component.Trainer(
      module_file='/path/to/module/file',
      transformed_examples=channel.as_channel([transformed_examples]),
      transform_output=channel.as_channel([transform_output]),
      schema=channel.as_channel([schema]),
      train_args=trainer_pb2.TrainArgs(num_steps=100),
      eval_args=trainer_pb2.EvalArgs(num_steps=50))
  self.assertEqual('ModelExportPath', trainer.outputs.output.type_name)
def testDo(self, mock_client):
  # Mock query result schema for _BigQueryConverter.
  mock_client.return_value.query.return_value.result.return_value.schema = self._schema

  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create output dict.
  train_examples = types.TfxType(type_name='ExamplesPath', split='train')
  train_examples.uri = os.path.join(output_data_dir, 'train')
  eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
  eval_examples.uri = os.path.join(output_data_dir, 'eval')
  output_dict = {'examples': [train_examples, eval_examples]}

  # Create exec properties.
  exec_properties = {
      'input':
          json_format.MessageToJson(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='bq', pattern='SELECT i, f, s FROM `fake`'),
              ])),
      'output':
          json_format.MessageToJson(
              example_gen_pb2.Output(
                  split_config=example_gen_pb2.SplitConfig(splits=[
                      example_gen_pb2.SplitConfig.Split(
                          name='train', hash_buckets=2),
                      example_gen_pb2.SplitConfig.Split(
                          name='eval', hash_buckets=1)
                  ])))
  }

  # Run executor.
  big_query_example_gen = executor.Executor()
  big_query_example_gen.Do({}, output_dict, exec_properties)

  # Check BigQuery example gen outputs.
  train_output_file = os.path.join(train_examples.uri,
                                   'data_tfrecord-00000-of-00001.gz')
  eval_output_file = os.path.join(eval_examples.uri,
                                  'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(tf.gfile.Exists(train_output_file))
  self.assertTrue(tf.gfile.Exists(eval_output_file))
  self.assertGreater(
      tf.gfile.GFile(train_output_file).size(),
      tf.gfile.GFile(eval_output_file).size())
def test_fetch_last_blessed_model(self):
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  log_root = os.path.join(output_data_dir, 'log_root')
  # Mock metadata.
  mock_metadata = tf.test.mock.Mock()
  model_validator_driver = driver.Driver(log_root, mock_metadata)

  # No blessed model.
  mock_metadata.get_all_artifacts.return_value = []
  self.assertEqual((None, None),
                   model_validator_driver._fetch_last_blessed_model())

  # Mock blessing artifacts.
  artifacts = []
  for span in [4, 3, 2, 1]:
    model_blessing = types.TfxType(type_name='ModelBlessingPath')
    model_blessing.span = span
    model_blessing.set_string_custom_property('current_model',
                                              'uri-%d' % span)
    model_blessing.set_int_custom_property('current_model_id', span)
    # Only odd spans are "blessed".
    model_blessing.set_int_custom_property('blessed', span % 2)
    artifacts.append(model_blessing.artifact)
  mock_metadata.get_all_artifacts.return_value = artifacts
  self.assertEqual(('uri-3', 3),
                   model_validator_driver._fetch_last_blessed_model())
def setUp(self):
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  # Create output dict.
  train_examples = types.TfxType(type_name='ExamplesPath', split='train')
  train_examples.uri = os.path.join(output_data_dir, 'train')
  eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
  eval_examples.uri = os.path.join(output_data_dir, 'eval')
  self._output_dict = {'examples': [train_examples, eval_examples]}
  self._train_output_file = os.path.join(train_examples.uri,
                                         'data_tfrecord-00000-of-00001.gz')
  self._eval_output_file = os.path.join(eval_examples.uri,
                                        'data_tfrecord-00000-of-00001.gz')
def test_do(self):
  source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create input dict.
  train_examples = types.TfxType(type_name='ExamplesPath', split='train')
  eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
  eval_examples.uri = os.path.join(source_data_dir, 'csv_example_gen/eval/')
  model_exports = types.TfxType(type_name='ModelExportPath')
  model_exports.uri = os.path.join(source_data_dir, 'trainer/current/')
  input_dict = {
      'examples': [train_examples, eval_examples],
      'model_exports': [model_exports],
  }

  # Create output dict.
  eval_output = types.TfxType('ModelEvalPath')
  eval_output.uri = os.path.join(output_data_dir, 'eval_output')
  output_dict = {'output': [eval_output]}

  # Create exec properties.
  exec_properties = {
      'feature_slicing_spec':
          json_format.MessageToJson(
              evaluator_pb2.FeatureSlicingSpec(specs=[
                  evaluator_pb2.SingleSlicingSpec(
                      column_for_slicing=['trip_start_hour']),
                  evaluator_pb2.SingleSlicingSpec(
                      column_for_slicing=['trip_start_day', 'trip_miles']),
              ]))
  }

  # Run executor.
  evaluator = executor.Executor()
  evaluator.Do(input_dict, output_dict, exec_properties)

  # Check evaluator outputs.
  self.assertTrue(
      tf.gfile.Exists(os.path.join(eval_output.uri, 'eval_config')))
  self.assertTrue(tf.gfile.Exists(os.path.join(eval_output.uri, 'metrics')))
  self.assertTrue(tf.gfile.Exists(os.path.join(eval_output.uri, 'plots')))
def setUp(self):
  input_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
      'testdata')

  # Create input dict.
  input_base = types.TfxType(type_name='ExternalPath')
  input_base.uri = os.path.join(input_data_dir, 'external/csv/')
  self._input_dict = {'input-base': [input_base]}
def test_tfxtype_deprecated(self):
  with mock.patch.object(tf_logging, 'warning'):
    warn_mock = mock.MagicMock()
    tf_logging.warning = warn_mock
    types.TfxType('FakeType')
    warn_mock.assert_called_once()
    self.assertIn('TfxType has been renamed to TfxArtifact',
                  warn_mock.call_args[0][5])
def test_construct(self):
  source_data_dir = os.path.join(
      os.path.dirname(__file__), 'testdata', 'taxi')
  preprocessing_fn_file = os.path.join(source_data_dir, 'module',
                                       'preprocess.py')
  transform = component.Transform(
      input_data=channel.as_channel([
          types.TfxType(type_name='ExamplesPath', split='train'),
          types.TfxType(type_name='ExamplesPath', split='eval'),
      ]),
      schema=channel.as_channel([types.TfxType(type_name='SchemaPath')]),
      module_file=preprocessing_fn_file,
  )
  self.assertEqual('TransformPath',
                   transform.outputs.transform_output.type_name)
  self.assertEqual('ExamplesPath',
                   transform.outputs.transformed_examples.type_name)
def test_tfx_type(self):
  instance = types.TfxType('MyTypeName', split='eval')

  # Test property getters.
  self.assertEqual('', instance.uri)
  self.assertEqual(0, instance.id)
  self.assertEqual(0, instance.type_id)
  self.assertEqual('MyTypeName', instance.type_name)
  self.assertIsNone(instance.state)
  self.assertEqual('eval', instance.split)
  self.assertIsNone(instance.span)

  # Test property setters.
  instance.uri = '/tmp/uri2'
  self.assertEqual('/tmp/uri2', instance.uri)
  instance.id = 1
  self.assertEqual(1, instance.id)
  instance.type_id = 2
  self.assertEqual(2, instance.type_id)
  instance.state = types.ARTIFACT_STATE_DELETED
  self.assertEqual(types.ARTIFACT_STATE_DELETED, instance.state)
  instance.split = ''
  self.assertEqual('', instance.split)
  instance.span = 20190101
  self.assertEqual(20190101, instance.span)
  instance.set_int_custom_property('int_key', 20)
  self.assertEqual(
      20, instance.artifact.custom_properties['int_key'].int_value)
  instance.set_string_custom_property('string_key', 'string_value')
  self.assertEqual(
      'string_value',
      instance.artifact.custom_properties['string_key'].string_value)
  self.assertEqual('MyTypeName:/tmp/uri2.1', str(instance))

  # Test json serialization.
  json_dict = instance.json_dict()
  s = json.dumps(json_dict)
  other_instance = types.TfxType.parse_from_json_dict(json.loads(s))
  self.assertEqual(instance.artifact, other_instance.artifact)
  self.assertEqual(instance.artifact_type, other_instance.artifact_type)

  # Test pickling.
  dumped_instance = pickle.dumps(instance)
  loaded_instance = pickle.loads(dumped_instance)
  self.assertEqual(instance.artifact, loaded_instance.artifact)
  self.assertEqual(instance.artifact_type, loaded_instance.artifact_type)

  # Test the source property.
  self.assertIsNone(instance.source)
  instance.source = 'hello_world'
  self.assertEqual('hello_world', instance.source)
def testDo(self):
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create output dict.
  train_examples = types.TfxType(type_name='ExamplesPath', split='train')
  train_examples.uri = os.path.join(output_data_dir, 'train')
  eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
  eval_examples.uri = os.path.join(output_data_dir, 'eval')
  output_dict = {'examples': [train_examples, eval_examples]}

  # Create exec properties.
  exec_properties = {
      'input':
          json_format.MessageToJson(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='tfrecord', pattern='tfrecord/*'),
              ])),
      'output':
          json_format.MessageToJson(
              example_gen_pb2.Output(
                  split_config=example_gen_pb2.SplitConfig(splits=[
                      example_gen_pb2.SplitConfig.Split(
                          name='train', hash_buckets=2),
                      example_gen_pb2.SplitConfig.Split(
                          name='eval', hash_buckets=1)
                  ])))
  }

  # Run executor.
  import_example_gen = executor.Executor()
  import_example_gen.Do(self._input_dict, output_dict, exec_properties)

  # Check import_example_gen outputs.
  train_output_file = os.path.join(train_examples.uri,
                                   'data_tfrecord-00000-of-00001.gz')
  eval_output_file = os.path.join(eval_examples.uri,
                                  'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(tf.gfile.Exists(train_output_file))
  self.assertTrue(tf.gfile.Exists(eval_output_file))
  self.assertGreater(
      tf.gfile.GFile(train_output_file).size(),
      tf.gfile.GFile(eval_output_file).size())