Example #1
0
    def setUp(self):
        """Prepares a per-test output directory and a driver over mock metadata.

        Sets:
          self._test_dir: path (not created here) under
            TEST_UNDECLARED_OUTPUTS_DIR if set, else the test temp dir,
            suffixed with the running test method's name.
          self._mock_metadata: a Mock standing in for the metadata handle.
          self._file_based_driver: a FileBasedDriver built on that Mock.
        """
        # Zero-arg super() is equivalent here and avoids repeating the class
        # name; the file is Python 3 only (it relies on assertRegex).
        super().setUp()

        # Nest under the test method name so each test method gets its own
        # subdirectory inside the shared outputs/temp location.
        self._test_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Mock metadata and create driver.
        self._mock_metadata = tf.compat.v1.test.mock.Mock()
        self._file_based_driver = driver.FileBasedDriver(self._mock_metadata)
Example #2
0
    def testDriverRunFn(self):
        """Runs FileBasedDriver over one fake span and checks its outputs.

        Writes one file per split under `span01/`, runs the driver with a
        `span{SPAN:2}` input config, and asserts that:
          * exec_properties carry the resolved span (1), the concrete split
            patterns, and a fingerprint covering both splits;
          * the single output Examples artifact keeps its original URI and
            gains matching span and fingerprint custom properties.
        """
        # A local is enough: unittest creates a fresh TestCase instance per
        # test method, so no other test can observe this path via `self`.
        input_base_path = os.path.join(self._test_dir, 'input_base')
        fileio.makedirs(input_base_path)

        # Fake previous outputs: one file per split under span01.
        span1_v1_split1 = os.path.join(input_base_path, 'span01', 'split1',
                                       'data')
        io_utils.write_string_file(span1_v1_split1, 'testing11')
        span1_v1_split2 = os.path.join(input_base_path, 'span01', 'split2',
                                       'data')
        io_utils.write_string_file(span1_v1_split2, 'testing12')

        # Each split has one 9-byte file ('testing11' / 'testing12'); the
        # checksum values are matched with `.*` rather than pinned.
        expected_fingerprint = (
            r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
            r'\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*')

        # Prepare output_dict.
        example = standard_artifacts.Examples()
        example.uri = 'my_uri'  # Will verify that this uri is not changed.
        output_dict = {standard_component_specs.EXAMPLES_KEY: [example]}

        # Prepare exec_properties.
        exec_properties = {
            standard_component_specs.INPUT_BASE_KEY:
                input_base_path,
            standard_component_specs.INPUT_CONFIG_KEY:
                proto_utils.proto_to_json(
                    example_gen_pb2.Input(splits=[
                        example_gen_pb2.Input.Split(
                            name='s1', pattern='span{SPAN:2}/split1/*'),
                        example_gen_pb2.Input.Split(
                            name='s2', pattern='span{SPAN:2}/split2/*'),
                    ])),
        }

        # setUp already built a FileBasedDriver over the same mock metadata.
        result = self._file_based_driver.run(
            portable_data_types.ExecutionInfo(
                output_dict=output_dict, exec_properties=exec_properties))

        # Assert exec_properties' values.
        result_properties = result.exec_properties
        self.assertEqual(
            result_properties[utils.SPAN_PROPERTY_NAME].int_value, 1)
        updated_input_config = example_gen_pb2.Input()
        proto_utils.json_to_proto(
            result_properties[
                standard_component_specs.INPUT_CONFIG_KEY].string_value,
            updated_input_config)
        self.assertProtoEquals(
            """
        splits {
          name: "s1"
          pattern: "span01/split1/*"
        }
        splits {
          name: "s2"
          pattern: "span01/split2/*"
        }""", updated_input_config)
        self.assertRegex(
            result_properties[utils.FINGERPRINT_PROPERTY_NAME].string_value,
            expected_fingerprint)

        # Assert output_artifacts' values.
        output_artifacts = result.output_artifacts[
            standard_component_specs.EXAMPLES_KEY].artifacts
        self.assertLen(output_artifacts, 1)
        output_example = output_artifacts[0]
        self.assertEqual(output_example.uri, example.uri)
        self.assertEqual(
            output_example.custom_properties[
                utils.SPAN_PROPERTY_NAME].string_value, '1')
        self.assertRegex(
            output_example.custom_properties[
                utils.FINGERPRINT_PROPERTY_NAME].string_value,
            expected_fingerprint)