def testRun(self, mock_publisher):
    """Launch an Avro-backed FileBasedExampleGen and check its split outputs.

    The component is run through ``InProcessComponentLauncher`` with the
    publisher mocked out. Afterwards the test asserts that a gzipped TFRecord
    file exists for both the 'train' and 'eval' splits and that the train
    file is larger than the eval file.

    Args:
        mock_publisher: Mock whose ``return_value.publish_execution`` is
            stubbed to return an empty dict (presumably injected by a patch
            decorator outside this view -- confirm).
    """
    publish_execution = mock_publisher.return_value.publish_execution
    publish_execution.return_value = {}

    avro_example_gen = FileBasedExampleGen(
        custom_executor_spec=executor_spec.ExecutorClassSpec(
            avro_executor.Executor),
        input=external_input(self.avro_dir_path),
        input_config=self.input_config,
        output_config=self.output_config,
        instance_name='AvroExampleGen')

    # Use TEST_UNDECLARED_OUTPUTS_DIR when the environment defines it,
    # otherwise fall back to the test's temp directory.
    base_dir = os.environ.get(
        'TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir())
    root = os.path.join(base_dir, self._testMethodName, 'Test')
    tf.io.gfile.makedirs(root)

    run_info = data_types.PipelineInfo(
        pipeline_name='Test', pipeline_root=root, run_id='123')
    cache_enabled_args = data_types.DriverArgs(enable_cache=True)

    sqlite_config = metadata_store_pb2.ConnectionConfig()
    sqlite_config.sqlite.SetInParent()
    mlmd_connection = metadata.Metadata(sqlite_config)

    launcher = in_process_component_launcher.InProcessComponentLauncher.create(
        component=avro_example_gen,
        pipeline_info=run_info,
        driver_args=cache_enabled_args,
        metadata_connection=mlmd_connection,
        beam_pipeline_args=[],
        additional_pipeline_args={})

    # The launcher must report the component's fully qualified class name.
    qualified_name = '.'.join(
        [FileBasedExampleGen.__module__, FileBasedExampleGen.__name__])
    self.assertEqual(launcher._component_info.component_type, qualified_name)

    launcher.launch()
    publish_execution.assert_called_once()

    # Describe the expected output artifact rooted under the component's id.
    splits = ['train', 'eval']
    examples = standard_artifacts.Examples()
    examples.uri = os.path.join(root, avro_example_gen.id, 'examples/1')
    examples.split_names = artifact_utils.encode_split_names(splits)

    # Each split must have produced its single gzipped TFRecord shard.
    shard_paths = [
        os.path.join(examples.uri, split, 'data_tfrecord-00000-of-00001.gz')
        for split in splits
    ]
    for shard_path in shard_paths:
        self.assertTrue(tf.io.gfile.exists(shard_path))

    train_shard, eval_shard = shard_paths
    self.assertGreater(
        tf.io.gfile.GFile(train_shard).size(),
        tf.io.gfile.GFile(eval_shard).size())
def test_run(self, mock_publisher):
    """Launch FileBasedExampleGen with the Avro executor and check outputs.

    The component is executed through ``component_launcher.ComponentLauncher``
    with the publisher mocked out. The test then asserts that a gzipped
    TFRecord file exists for the 'train' and 'eval' splits and that the train
    file is larger than the eval file.

    Args:
        mock_publisher: Mock whose ``return_value.publish_execution`` is
            stubbed to return an empty dict (presumably injected by a patch
            decorator outside this view -- confirm).
    """
    publish_execution = mock_publisher.return_value.publish_execution
    publish_execution.return_value = {}

    avro_example_gen = FileBasedExampleGen(
        executor_class=avro_executor.Executor,
        input_base=external_input(self.avro_dir_path),
        input_config=self.input_config,
        output_config=self.output_config,
        name='AvroExampleGenComponent')

    # Use TEST_UNDECLARED_OUTPUTS_DIR when the environment defines it,
    # otherwise fall back to the test's temp directory.
    base_dir = os.environ.get(
        'TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir())
    root = os.path.join(base_dir, self._testMethodName, 'Test')
    tf.gfile.MakeDirs(root)

    run_info = data_types.PipelineInfo(
        pipeline_name='Test', pipeline_root=root, run_id='123')
    cache_enabled_args = data_types.DriverArgs(enable_cache=True)

    sqlite_config = metadata_store_pb2.ConnectionConfig()
    sqlite_config.sqlite.SetInParent()

    launcher = component_launcher.ComponentLauncher(
        component=avro_example_gen,
        pipeline_info=run_info,
        driver_args=cache_enabled_args,
        metadata_connection_config=sqlite_config,
        additional_pipeline_args={})

    # The launcher must report the component's fully qualified class name.
    qualified_name = '.'.join(
        [FileBasedExampleGen.__module__, FileBasedExampleGen.__name__])
    self.assertEqual(launcher._component_info.component_type, qualified_name)

    launcher.launch()
    publish_execution.assert_called_once()

    # Outputs are rooted under "<component_name>.<name>" inside the pipeline
    # root.
    component_id = '.'.join(
        [avro_example_gen.component_name, avro_example_gen.name])
    output_path = os.path.join(root, component_id, 'examples/1')

    # Build one artifact per split and derive its expected shard location.
    shard_paths = []
    for split in ('train', 'eval'):
        split_examples = types.TfxArtifact(type_name='ExamplesPath', split=split)
        split_examples.uri = os.path.join(output_path, split)
        shard_paths.append(
            os.path.join(split_examples.uri, 'data_tfrecord-00000-of-00001.gz'))

    # Both splits must have produced their single gzipped TFRecord shard.
    for shard_path in shard_paths:
        self.assertTrue(tf.gfile.Exists(shard_path))

    train_shard, eval_shard = shard_paths
    self.assertGreater(
        tf.gfile.GFile(train_shard).size(),
        tf.gfile.GFile(eval_shard).size())
def testRun(self, mock_publisher):
    """Launch an Avro-backed FileBasedExampleGen and check its output dir.

    The component is run through ``InProcessComponentLauncher`` with the
    publisher mocked out. The test asserts that the launcher reports the
    component's fully qualified type, that execution is published exactly
    once, and that a directory named after the component id exists under the
    pipeline root.

    Args:
        mock_publisher: Mock whose ``return_value.publish_execution`` is
            stubbed to return an empty dict (presumably injected by a patch
            decorator outside this view -- confirm).
    """
    publish_execution = mock_publisher.return_value.publish_execution
    publish_execution.return_value = {}

    avro_example_gen = FileBasedExampleGen(
        custom_executor_spec=executor_spec.ExecutorClassSpec(
            avro_executor.Executor),
        input_base=self.avro_dir_path,
        input_config=self.input_config,
        output_config=self.output_config,
        instance_name='AvroExampleGen')

    # Use TEST_UNDECLARED_OUTPUTS_DIR when the environment defines it,
    # otherwise fall back to the test's temp directory.
    base_dir = os.environ.get(
        'TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir())
    root = os.path.join(base_dir, self._testMethodName, 'Test')
    fileio.makedirs(root)

    run_info = data_types.PipelineInfo(
        pipeline_name='Test', pipeline_root=root, run_id='123')
    cache_enabled_args = data_types.DriverArgs(enable_cache=True)

    sqlite_config = metadata_store_pb2.ConnectionConfig()
    sqlite_config.sqlite.SetInParent()
    mlmd_connection = metadata.Metadata(sqlite_config)

    launcher = in_process_component_launcher.InProcessComponentLauncher.create(
        component=avro_example_gen,
        pipeline_info=run_info,
        driver_args=cache_enabled_args,
        metadata_connection=mlmd_connection,
        beam_pipeline_args=[],
        additional_pipeline_args={})

    # The launcher must report the component's fully qualified class name.
    qualified_name = '.'.join(
        [FileBasedExampleGen.__module__, FileBasedExampleGen.__name__])
    self.assertEqual(launcher._component_info.component_type, qualified_name)

    launcher.launch()
    publish_execution.assert_called_once()

    # The run must have created a directory for this component under the
    # pipeline root.
    component_dir = os.path.join(root, avro_example_gen.id)
    self.assertTrue(fileio.exists(component_dir))