def testDefaultFilling(self):
        """Checks that missing values are padded with each feature's default.

        Judging by the expected values here and in testAllDTypes, the first
        record in `_ELWCS` has no `ctx.bytes` value and its second example
        has no `example_float` value. With `default_value`/`length` set, those
        gaps decode as `length` copies of the default instead of empty lists.
        """
        ctx_specs = [
            struct2tensor_parsing_utils.Feature(
                'ctx.bytes', tf.string, default_value=b'g', length=1),
        ]
        per_example_specs = [
            struct2tensor_parsing_utils.Feature(
                'example_float', tf.float32, default_value=-1.0, length=2),
        ]
        decoder = struct2tensor_parsing_utils.ELWCDecoder(
            'test_decoder',
            ctx_specs,
            per_example_specs,
            size_feature_name=None,
            label_feature=None)
        decoded = decoder.decode_record(tf.convert_to_tensor(_ELWCS))

        # One output per requested feature, each returned as a RaggedTensor.
        all_specs = ctx_specs + per_example_specs
        self.assertLen(decoded, len(all_specs))
        for spec in all_specs:
            self.assertIn(spec.name, decoded)
            self.assertIsInstance(decoded[spec.name], tf.RaggedTensor)

        actual = {name: tensor.to_list() for name, tensor in decoded.items()}
        self.assertEqual(actual, {
            'ctx.bytes': [[b'g'], [b'c']],
            'example_float': [[[11.0, 12.0], [-1.0, -1.0]], [[14.0, 15.0]]],
        })
    def testAllDTypes(self):
        """Decodes int64, float32 and string features at context and example level.

        No `default_value`/`length` is configured. A value missing from a
        record, or from one of its examples, therefore decodes to an empty
        list (see 'ctx.bytes' and 'example_float' in the expected output).
        """
        ctx_specs = [
            struct2tensor_parsing_utils.Feature(name, dtype)
            for name, dtype in (
                ('ctx.int', tf.int64),
                ('ctx.float', tf.float32),
                ('ctx.bytes', tf.string),
            )
        ]
        per_example_specs = [
            struct2tensor_parsing_utils.Feature(name, dtype)
            for name, dtype in (
                ('example_int', tf.int64),
                ('example_float', tf.float32),
                ('example_bytes', tf.string),
            )
        ]
        decoder = struct2tensor_parsing_utils.ELWCDecoder(
            'test_decoder',
            ctx_specs,
            per_example_specs,
            size_feature_name=None,
            label_feature=None)
        decoded = decoder.decode_record(tf.convert_to_tensor(_ELWCS))

        # One output per requested feature, each returned as a RaggedTensor.
        all_specs = ctx_specs + per_example_specs
        self.assertLen(decoded, len(all_specs))
        for spec in all_specs:
            self.assertIn(spec.name, decoded)
            self.assertIsInstance(decoded[spec.name], tf.RaggedTensor)

        actual = {name: tensor.to_list() for name, tensor in decoded.items()}
        self.assertEqual(actual, {
            'ctx.int': [[1, 2], [3]],
            'ctx.float': [[1.0, 2.0], [3.0]],
            'ctx.bytes': [[], [b'c']],
            'example_int': [[[11], [22]], [[33]]],
            'example_float': [[[11.0, 12.0], []], [[14.0, 15.0]]],
            'example_bytes': [[[b'u', b'v'], [b'w']], [[b'x', b'y', b'z']]],
        })
    def testLabelFeature(self):
        """Decodes only the label feature when no other features are requested.

        Compared with testAllDTypes, the label comes back with one value per
        example rather than a one-element list per example. NOTE(review): the
        expected values are float literals, but Python treats 11 == 11.0 as
        equal, so this assertion alone does not pin the output dtype.
        """
        label_spec = struct2tensor_parsing_utils.Feature(
            'example_int', tf.int64)
        decoder = struct2tensor_parsing_utils.ELWCDecoder(
            'test_decoder', [], [],
            size_feature_name=None,
            label_feature=label_spec)
        decoded = decoder.decode_record(tf.convert_to_tensor(_ELWCS))

        self.assertLen(decoded, 1)
        labels = decoded['example_int'].to_list()
        self.assertEqual(labels, [[11.0, 22.0], [33.0]])
# Example #4
# 0
def make_decoder():
    """Creates a data decoder that decodes ELWC records to tensors.

    A DataView (see the "TfGraphDataViewProvider" component in the pipeline)
    will refer to this decoder, and any component that consumes the data with
    the DataView applied will use this decoder.

    The context, example and label feature specs are taken from
    `features.get_features()`. The list-size feature name is taken from
    `features.LIST_SIZE_FEATURE_NAME`.

    Returns:
      An ELWC decoder (`struct2tensor_parsing_utils.ELWCDecoder`) named
      'ELWCDecoder'.
    """
    context_features, example_features, label_feature = features.get_features()

    return struct2tensor_parsing_utils.ELWCDecoder(
        name='ELWCDecoder',
        context_features=context_features,
        example_features=example_features,
        size_feature_name=features.LIST_SIZE_FEATURE_NAME,
        label_feature=label_feature)
 def testSizeFeature(self):
     """Decodes only the per-record list-size feature when it is requested by name."""
     decoder = struct2tensor_parsing_utils.ELWCDecoder(
         'test_decoder', [], [], size_feature_name='example_list_size')
     serialized = tf.convert_to_tensor(_ELWCS)
     decoded = decoder.decode_record(serialized)

     self.assertLen(decoded, 1)
     list_sizes = decoded['example_list_size'].to_list()
     # Two examples in the first record and one in the second (these counts
     # match the example-level values in testAllDTypes).
     self.assertEqual(list_sizes, [[2], [1]])