Ejemplo n.º 1
0
    def testRun(self, mock_publisher):
        """Runs AvroExampleGen in-process and checks its train/eval outputs.

        Args:
          mock_publisher: Mocked publisher class (presumably injected by a
            `mock.patch` decorator not visible here). Its instance's
            `publish_execution` is stubbed to return an empty dict.
        """
        mock_publisher.return_value.publish_execution.return_value = {}

        example_gen = FileBasedExampleGen(
            custom_executor_spec=executor_spec.ExecutorClassSpec(
                avro_executor.Executor),
            input=external_input(self.avro_dir_path),
            input_config=self.input_config,
            output_config=self.output_config,
            instance_name='AvroExampleGen')

        # Use TEST_UNDECLARED_OUTPUTS_DIR when the environment sets it,
        # otherwise fall back to this test's temp dir.
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)
        pipeline_root = os.path.join(output_data_dir, 'Test')
        tf.io.gfile.makedirs(pipeline_root)
        pipeline_info = data_types.PipelineInfo(pipeline_name='Test',
                                                pipeline_root=pipeline_root,
                                                run_id='123')

        driver_args = data_types.DriverArgs(enable_cache=True)

        # SQLite connection config with no file path set; presumably this
        # yields an in-memory metadata store -- confirm in ml-metadata docs.
        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        metadata_connection = metadata.Metadata(connection_config)

        launcher = in_process_component_launcher.InProcessComponentLauncher.create(
            component=example_gen,
            pipeline_info=pipeline_info,
            driver_args=driver_args,
            metadata_connection=metadata_connection,
            beam_pipeline_args=[],
            additional_pipeline_args={})
        # The launcher should record the fully-qualified component class name.
        self.assertEqual(
            launcher._component_info.component_type, '.'.join(
                [FileBasedExampleGen.__module__,
                 FileBasedExampleGen.__name__]))

        launcher.launch()
        mock_publisher.return_value.publish_execution.assert_called_once()

        # Get output paths.
        component_id = example_gen.id
        output_path = os.path.join(pipeline_root, component_id, 'examples/1')
        examples = standard_artifacts.Examples()
        examples.uri = output_path
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])

        # Check Avro example gen outputs.
        train_output_file = os.path.join(examples.uri, 'train',
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(examples.uri, 'eval',
                                        'data_tfrecord-00000-of-00001.gz')
        self.assertTrue(tf.io.gfile.exists(train_output_file))
        self.assertTrue(tf.io.gfile.exists(eval_output_file))
        # Read the sizes inside `with` blocks so the GFile handles are closed
        # rather than leaked until garbage collection.
        with tf.io.gfile.GFile(train_output_file) as train_file:
            train_size = train_file.size()
        with tf.io.gfile.GFile(eval_output_file) as eval_file:
            eval_size = eval_file.size()
        # Train is expected to be the larger split; the split ratio comes from
        # self.output_config, which is configured outside this method.
        self.assertGreater(train_size, eval_size)
Ejemplo n.º 2
0
    def test_run(self, mock_publisher):
        """Runs AvroExampleGen via ComponentLauncher and checks its outputs.

        Args:
          mock_publisher: Mocked publisher class (presumably injected by a
            `mock.patch` decorator not visible here). Its instance's
            `publish_execution` is stubbed to return an empty dict.
        """
        mock_publisher.return_value.publish_execution.return_value = {}

        example_gen = FileBasedExampleGen(
            executor_class=avro_executor.Executor,
            input_base=external_input(self.avro_dir_path),
            input_config=self.input_config,
            output_config=self.output_config,
            name='AvroExampleGenComponent')

        # Use TEST_UNDECLARED_OUTPUTS_DIR when the environment sets it,
        # otherwise fall back to this test's temp dir.
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)
        pipeline_root = os.path.join(output_data_dir, 'Test')
        tf.gfile.MakeDirs(pipeline_root)
        pipeline_info = data_types.PipelineInfo(pipeline_name='Test',
                                                pipeline_root=pipeline_root,
                                                run_id='123')

        driver_args = data_types.DriverArgs(enable_cache=True)

        # SQLite connection config with no file path set; presumably this
        # yields an in-memory metadata store -- confirm in ml-metadata docs.
        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()

        launcher = component_launcher.ComponentLauncher(
            component=example_gen,
            pipeline_info=pipeline_info,
            driver_args=driver_args,
            metadata_connection_config=connection_config,
            additional_pipeline_args={})
        # The launcher should record the fully-qualified component class name.
        self.assertEqual(
            launcher._component_info.component_type, '.'.join(
                [FileBasedExampleGen.__module__,
                 FileBasedExampleGen.__name__]))

        launcher.launch()
        mock_publisher.return_value.publish_execution.assert_called_once()

        # Get output paths. The component directory is named
        # '<component_name>.<name>'.
        component_id = '.'.join([example_gen.component_name, example_gen.name])
        output_path = os.path.join(pipeline_root, component_id, 'examples/1')
        train_examples = types.TfxArtifact(type_name='ExamplesPath',
                                           split='train')
        train_examples.uri = os.path.join(output_path, 'train')
        eval_examples = types.TfxArtifact(type_name='ExamplesPath',
                                          split='eval')
        eval_examples.uri = os.path.join(output_path, 'eval')

        # Check Avro example gen outputs.
        train_output_file = os.path.join(train_examples.uri,
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(eval_examples.uri,
                                        'data_tfrecord-00000-of-00001.gz')
        self.assertTrue(tf.gfile.Exists(train_output_file))
        self.assertTrue(tf.gfile.Exists(eval_output_file))
        # Read the sizes inside `with` blocks so the GFile handles are closed
        # rather than leaked until garbage collection.
        with tf.gfile.GFile(train_output_file) as train_file:
            train_size = train_file.size()
        with tf.gfile.GFile(eval_output_file) as eval_file:
            eval_size = eval_file.size()
        # Train is expected to be the larger split; the split ratio comes from
        # self.output_config, which is configured outside this method.
        self.assertGreater(train_size, eval_size)
Ejemplo n.º 3
0
    def testRun(self, mock_publisher):
        """Launches AvroExampleGen in-process and checks it produced output.

        The publisher is mocked, so only the launch itself, the recorded
        component type, and the presence of the component's output directory
        under the pipeline root are verified.

        Args:
          mock_publisher: Mocked publisher class (presumably injected by a
            `mock.patch` decorator not visible here).
        """
        publisher = mock_publisher.return_value
        publisher.publish_execution.return_value = {}

        avro_spec = executor_spec.ExecutorClassSpec(avro_executor.Executor)
        example_gen = FileBasedExampleGen(
            custom_executor_spec=avro_spec,
            input_base=self.avro_dir_path,
            input_config=self.input_config,
            output_config=self.output_config,
            instance_name='AvroExampleGen')

        # Output lives under TEST_UNDECLARED_OUTPUTS_DIR if set, else a temp
        # dir, namespaced by the test method name.
        base_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
        pipeline_root = os.path.join(base_dir, self._testMethodName, 'Test')
        fileio.makedirs(pipeline_root)
        pipeline_info = data_types.PipelineInfo(
            pipeline_name='Test', pipeline_root=pipeline_root, run_id='123')

        driver_args = data_types.DriverArgs(enable_cache=True)

        # SQLite config with no file path; presumably an in-memory store.
        sqlite_config = metadata_store_pb2.ConnectionConfig()
        sqlite_config.sqlite.SetInParent()
        mlmd_handle = metadata.Metadata(sqlite_config)

        launcher_cls = in_process_component_launcher.InProcessComponentLauncher
        launcher = launcher_cls.create(
            component=example_gen,
            pipeline_info=pipeline_info,
            driver_args=driver_args,
            metadata_connection=mlmd_handle,
            beam_pipeline_args=[],
            additional_pipeline_args={})

        expected_type = '.'.join(
            (FileBasedExampleGen.__module__, FileBasedExampleGen.__name__))
        self.assertEqual(launcher._component_info.component_type,
                         expected_type)

        launcher.launch()
        publisher.publish_execution.assert_called_once()

        # The component should have written something under its own id.
        component_dir = os.path.join(pipeline_root, example_gen.id)
        self.assertTrue(fileio.exists(component_dir))