Exemplo n.º 1
0
  def setUp(self):
    """Builds the shared fixtures for the driver tests.

    Sets up:
      * `self._input_base_path`: a directory created under the test output
        root, namespaced by the running test method's name.
      * `self._mock_metadata` / `self._example_gen_driver`: a driver built on
        a mocked metadata handle.
      * `self._input_channels`: an `input_base` channel wrapping one
        ExternalArtifact whose uri is `self._input_base_path`.
      * `self._exec_properties`: an `input_config` JSON describing two splits,
        `s1` and `s2`, whose patterns contain a `{SPAN}` placeholder.
    """
    super(DriverTest, self).setUp()

    # Write under TEST_UNDECLARED_OUTPUTS_DIR when it is set, otherwise under
    # the test temp dir; one sub-directory per test method.
    output_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                 self.get_temp_dir())
    per_test_dir = os.path.join(output_root, self._testMethodName)
    self._input_base_path = os.path.join(per_test_dir, 'input_base')
    tf.io.gfile.makedirs(self._input_base_path)

    # The driver under test only ever sees this mock metadata handle.
    self._mock_metadata = tf.test.mock.Mock()
    self._example_gen_driver = driver.Driver(self._mock_metadata)

    # One external artifact rooted at the input base, exposed as a channel.
    external_artifact = standard_artifacts.ExternalArtifact()
    external_artifact.uri = self._input_base_path
    self._input_channels = dict(
        input_base=channel_utils.as_channel([external_artifact]))

    # Exec properties: two splits keyed on the {SPAN} placeholder.
    span_splits = [
        example_gen_pb2.Input.Split(name='s1', pattern='span{SPAN}/split1/*'),
        example_gen_pb2.Input.Split(name='s2', pattern='span{SPAN}/split2/*'),
    ]
    input_config = example_gen_pb2.Input(splits=span_splits)
    self._exec_properties = {
        'input_config': json_format.MessageToJson(input_config),
    }
Exemplo n.º 2
0
  def setUp(self):
    """Computes a per-test output path and builds a driver over mock metadata.

    Sets `self._test_dir` (the path is only computed here, not created),
    `self._mock_metadata` and `self._example_gen_driver`.
    """
    super(DriverTest, self).setUp()

    # Use TEST_UNDECLARED_OUTPUTS_DIR when set, else the test temp dir, and
    # namespace by test method so tests don't share files.
    outputs_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
    self._test_dir = os.path.join(outputs_root, self._testMethodName)

    # The driver is wired to a mock, so no real metadata store is touched.
    self._mock_metadata = tf.compat.v1.test.mock.Mock()
    self._example_gen_driver = driver.Driver(self._mock_metadata)
Exemplo n.º 3
0
  def test_prepare_input_for_processing(self):
    """Checks how the driver resolves the `input-base` artifact.

    Cache miss: no artifacts are registered and `publish_artifacts` is
    stubbed to return the id-1 artifact, which is expected back.
    Cache hit: artifacts with ids 4, 3, 2, 1 are registered (odd ids share the
    input base uri) and the id-3 artifact is expected to be reused.
    """
    outputs_root = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)
    self._logger_config = logging_utils.LoggerConfig(
        log_root=os.path.join(outputs_root, 'log_dir'))

    # The driver talks to a mocked metadata handle.
    metadata = tf.test.mock.Mock()
    gen_driver = driver.Driver(self._logger_config, metadata)

    # Registered artifacts: odd ids get uri 'path-1' (same as the input base
    # below), even ids get 'path-0'.
    registered = []
    for artifact_id in (4, 3, 2, 1):
      registered_artifact = metadata_store_pb2.Artifact()
      registered_artifact.id = artifact_id
      registered_artifact.uri = 'path-{}'.format(artifact_id % 2)
      registered.append(registered_artifact)

    input_base = types.TfxType(type_name='ExternalPath')
    input_base.uri = 'path-1'
    input_dict = {'input-base': [input_base]}

    def resolve_input_base():
      # Deep-copy in case the driver mutates the dict it is given.
      resolved = gen_driver._prepare_input_for_processing(
          copy.deepcopy(input_dict))
      self.assertEqual(1, len(resolved))
      self.assertEqual(1, len(resolved['input-base']))
      return resolved['input-base'][0]

    # Cache not hit.
    metadata.get_all_artifacts.return_value = []
    metadata.publish_artifacts.return_value = [registered[3]]
    resolved_base = resolve_input_base()
    self.assertEqual(1, resolved_base.id)
    self.assertEqual('path-1', resolved_base.uri)

    # Cache hit.
    metadata.get_all_artifacts.return_value = registered
    metadata.publish_artifacts.return_value = []
    resolved_base = resolve_input_base()
    self.assertEqual(3, resolved_base.id)
    self.assertEqual('path-1', resolved_base.uri)
Exemplo n.º 4
0
    def test_prepare_input_for_processing(self):
        """Checks input-base resolution on a cache miss and then a cache hit.

        Registered artifacts have ids 4, 3, 2, 1; odd ids share the input
        base uri 'path-1'. On a miss, the id-1 artifact returned by the
        stubbed `publish_artifacts` is expected; on a hit, id 3.
        """
        metadata = tf.test.mock.Mock()
        gen_driver = driver.Driver(metadata)

        # Odd ids map to 'path-1' (matching the input base), even to 'path-0'.
        registered = [
            metadata_store_pb2.Artifact(id=artifact_id,
                                        uri='path-{}'.format(artifact_id % 2))
            for artifact_id in (4, 3, 2, 1)
        ]

        input_base = types.TfxArtifact(type_name='ExternalPath')
        input_base.uri = 'path-1'
        input_dict = {'input_base': [input_base]}

        # (registered artifacts, publish result, expected resolved id):
        # first a cache miss, then a cache hit.
        scenarios = (
            ([], [registered[3]], 1),
            (registered, [], 3),
        )
        for known_artifacts, published, expected_id in scenarios:
            metadata.get_all_artifacts.return_value = known_artifacts
            metadata.publish_artifacts.return_value = published
            resolved = gen_driver._prepare_input_for_processing(
                copy.deepcopy(input_dict))
            self.assertEqual(1, len(resolved))
            self.assertEqual(1, len(resolved['input_base']))
            resolved_base = resolved['input_base'][0]
            self.assertEqual(expected_id, resolved_base.id)
            self.assertEqual('path-1', resolved_base.uri)
Exemplo n.º 5
0
  def testRun(self):
    """Runs the driver over one span/version and checks the resolved outputs.

    Writes `span01/version01/split{1,2}/data` under the input base, runs the
    driver with a `{SPAN}`/`{VERSION}` input config, and checks that:
      * span and version both resolve to 1 in the returned exec properties,
      * the input config patterns are rewritten to the concrete span/version,
      * a per-split fingerprint is recorded,
      * the output Examples artifact keeps its uri and carries span, version
        and fingerprint as custom properties.
    """
    # Both fake files hold 9 bytes, hence `total_bytes:9` for each split.
    expected_fingerprint = (
        r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
        r'\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*')

    # Create input dir.
    self._input_base_path = os.path.join(self._test_dir, 'input_base')
    tf.io.gfile.makedirs(self._input_base_path)

    # Create PipelineInfo and PipelineNode.
    pipeline_info = pipeline_pb2.PipelineInfo()
    pipeline_node = pipeline_pb2.PipelineNode()

    # Fake previous outputs: span 1, version 1, one file per split.
    span1_v1_split1 = os.path.join(self._input_base_path, 'span01', 'version01',
                                   'split1', 'data')
    io_utils.write_string_file(span1_v1_split1, 'testing11')
    span1_v1_split2 = os.path.join(self._input_base_path, 'span01', 'version01',
                                   'split2', 'data')
    io_utils.write_string_file(span1_v1_split2, 'testing12')

    ir_driver = driver.Driver(self._mock_metadata, pipeline_info, pipeline_node)
    example = standard_artifacts.Examples()

    # Prepare output_dic.
    example.uri = 'my_uri'  # Will verify that this uri is not changed.
    output_dic = {utils.EXAMPLES_KEY: [example]}

    # Prepare exec_properties.
    exec_properties = {
        utils.INPUT_BASE_KEY:
            self._input_base_path,
        utils.INPUT_CONFIG_KEY:
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='s1',
                        pattern='span{SPAN}/version{VERSION}/split1/*'),
                    example_gen_pb2.Input.Split(
                        name='s2',
                        pattern='span{SPAN}/version{VERSION}/split2/*')
                ]),
                preserving_proto_field_name=True),
    }
    result = ir_driver.run(None, output_dic, exec_properties)

    # Assert exec_properties' values.
    exec_properties = result.exec_properties
    self.assertEqual(exec_properties[utils.SPAN_PROPERTY_NAME].int_value, 1)
    self.assertEqual(exec_properties[utils.VERSION_PROPERTY_NAME].int_value, 1)
    updated_input_config = example_gen_pb2.Input()
    json_format.Parse(exec_properties[utils.INPUT_CONFIG_KEY].string_value,
                      updated_input_config)
    self.assertProtoEquals(
        """
        splits {
          name: "s1"
          pattern: "span01/version01/split1/*"
        }
        splits {
          name: "s2"
          pattern: "span01/version01/split2/*"
        }""", updated_input_config)
    self.assertRegex(
        exec_properties[utils.FINGERPRINT_PROPERTY_NAME].string_value,
        expected_fingerprint)

    # Assert output_artifacts' values.
    self.assertLen(result.output_artifacts[utils.EXAMPLES_KEY].artifacts, 1)
    output_example = result.output_artifacts[utils.EXAMPLES_KEY].artifacts[0]
    self.assertEqual(output_example.uri, example.uri)
    self.assertEqual(
        output_example.custom_properties[utils.SPAN_PROPERTY_NAME].string_value,
        '1')
    self.assertEqual(
        output_example.custom_properties[
            utils.VERSION_PROPERTY_NAME].string_value, '1')
    self.assertRegex(
        output_example.custom_properties[
            utils.FINGERPRINT_PROPERTY_NAME].string_value,
        expected_fingerprint)
Exemplo n.º 6
0
  def test_prepare_input_for_processing(self):
    """Checks input-base resolution against fingerprinted registered artifacts.

    Two split files are written with fixed mtimes. Registered artifacts with
    odd ids carry a fingerprint string matching those files; even ids carry
    'not_match'. With only the id-4 artifact registered, the stubbed publish
    result (id 1) is expected; with all registered, id 3 is expected.
    """
    outputs_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
    input_base_path = os.path.join(outputs_root, self._testMethodName,
                                   'input_base')

    # Split files with explicit (atime=0, mtime) so the fingerprint below is
    # reproducible (the mtimes appear to feed the checksum fields).
    for split_dir, contents, mtime in (('split1', 'testing', 1),
                                       ('split2', 'testing2', 3)):
      data_file = os.path.join(input_base_path, split_dir, 'data')
      io_utils.write_string_file(data_file, contents)
      os.utime(data_file, (0, mtime))

    metadata = tf.test.mock.Mock()
    gen_driver = driver.Driver(metadata)

    matching_fingerprint = (
        'split:s1,num_files:1,total_bytes:7,xor_checksum:1,sum_checksum:1\n'
        'split:s2,num_files:1,total_bytes:8,xor_checksum:3,sum_checksum:3')
    registered = []
    for artifact_id in (4, 3, 2, 1):
      registered_artifact = metadata_store_pb2.Artifact()
      registered_artifact.id = artifact_id
      registered_artifact.uri = input_base_path
      # Only odd ids carry the fingerprint that matches the files above.
      fingerprint = (
          matching_fingerprint if artifact_id % 2 == 1 else 'not_match')
      registered_artifact.custom_properties[
          'input_fingerprint'].string_value = fingerprint
      registered.append(registered_artifact)

    input_base = types.TfxArtifact(type_name='ExternalPath')
    input_base.uri = input_base_path
    input_dict = {'input_base': [input_base]}
    input_config = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='s1', pattern='split1/*'),
        example_gen_pb2.Input.Split(name='s2', pattern='split2/*'),
    ])
    exec_properties = {'input_config': json_format.MessageToJson(input_config)}

    def resolve_input_base():
      # Deep-copy in case the driver mutates the dict it is given.
      resolved = gen_driver._prepare_input_for_processing(
          copy.deepcopy(input_dict), exec_properties)
      self.assertEqual(1, len(resolved))
      self.assertEqual(1, len(resolved['input_base']))
      return resolved['input_base'][0]

    # Cache not hit: only the non-matching id-4 artifact is registered.
    metadata.get_artifacts_by_uri.return_value = [registered[0]]
    metadata.publish_artifacts.return_value = [registered[3]]
    resolved_base = resolve_input_base()
    self.assertEqual(1, resolved_base.id)
    self.assertEqual(input_base_path, resolved_base.uri)

    # Cache hit: id 3 is the first registered artifact with a matching
    # fingerprint.
    metadata.get_artifacts_by_uri.return_value = registered
    metadata.publish_artifacts.return_value = []
    resolved_base = resolve_input_base()
    self.assertEqual(3, resolved_base.id)
    self.assertEqual(input_base_path, resolved_base.uri)
Exemplo n.º 7
0
    def testDriverRunFn(self):
        """Runs the driver over a single span and checks the resolved outputs.

        Writes `span01/split{1,2}/data` (9 bytes each) under the input base,
        runs the driver with a `{SPAN}` input config, and checks the resolved
        span, the rewritten split patterns, the fingerprint, and the custom
        properties on the output Examples artifact (whose uri must not change).
        """
        per_split = r'num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
        fingerprint_pattern = (r'split:s1,' + per_split +
                               r'\nsplit:s2,' + per_split)

        # Create input dir.
        self._input_base_path = os.path.join(self._test_dir, 'input_base')
        fileio.makedirs(self._input_base_path)

        # Fake previous outputs: one file per split under span01 (no version).
        for split_dir, contents in (('split1', 'testing11'),
                                    ('split2', 'testing12')):
            span1_split_file = os.path.join(self._input_base_path, 'span01',
                                            split_dir, 'data')
            io_utils.write_string_file(span1_split_file, contents)

        ir_driver = driver.Driver(self._mock_metadata)
        example = standard_artifacts.Examples()
        # The driver is expected to leave this uri untouched.
        example.uri = 'my_uri'
        output_dic = {standard_component_specs.EXAMPLES_KEY: [example]}

        span_input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='s1',
                                        pattern='span{SPAN}/split1/*'),
            example_gen_pb2.Input.Split(name='s2',
                                        pattern='span{SPAN}/split2/*'),
        ])
        exec_properties = {
            standard_component_specs.INPUT_BASE_KEY: self._input_base_path,
            standard_component_specs.INPUT_CONFIG_KEY:
                proto_utils.proto_to_json(span_input_config),
        }
        execution_info = portable_data_types.ExecutionInfo(
            output_dict=output_dic, exec_properties=exec_properties)
        result = ir_driver.run(execution_info)

        # Resolved exec properties.
        resolved_props = result.exec_properties
        self.assertEqual(resolved_props[utils.SPAN_PROPERTY_NAME].int_value,
                         1)
        updated_input_config = example_gen_pb2.Input()
        proto_utils.json_to_proto(
            resolved_props[
                standard_component_specs.INPUT_CONFIG_KEY].string_value,
            updated_input_config)
        self.assertProtoEquals(
            """
        splits {
          name: "s1"
          pattern: "span01/split1/*"
        }
        splits {
          name: "s2"
          pattern: "span01/split2/*"
        }""", updated_input_config)
        self.assertRegex(
            resolved_props[utils.FINGERPRINT_PROPERTY_NAME].string_value,
            fingerprint_pattern)

        # Output artifact.
        examples_out = result.output_artifacts[
            standard_component_specs.EXAMPLES_KEY].artifacts
        self.assertLen(examples_out, 1)
        output_example = examples_out[0]
        self.assertEqual(output_example.uri, example.uri)
        custom_props = output_example.custom_properties
        self.assertEqual(
            custom_props[utils.SPAN_PROPERTY_NAME].string_value, '1')
        self.assertRegex(
            custom_props[utils.FINGERPRINT_PROPERTY_NAME].string_value,
            fingerprint_pattern)