Esempio n. 1
0
    def test_sequence_coder(self):
        """Checks that `self.example` survives two SequenceCoder round trips.

        The example is encoded and decoded, then the decoded result is
        encoded and decoded again.
        - The second decode is compared against the raw example for 'seq',
          'multi', 'ints' and 'name'.
        - The second decode is compared against the first decode for 'seq'.
        """
        feature_specs = {
            'seq': tf.float32,
            'name': tf.string,
            'multi': tf.float32,
            'ints': tf.int64,
        }
        coder = serialization.SequenceCoder(
            specs=feature_specs,
            sequence_keys=['seq', 'multi', 'ints'],
        )

        # `first` is the decode of the example; `second` is the decode of
        # `first` after re-encoding it.
        first = coder.decode(coder.encode(self.example))
        second = coder.decode(coder.encode(first))

        tol = 1e-7
        # 'seq' and 'ints' are compared after slicing [:, 0], so the test
        # expects their decoded values to carry an extra trailing axis
        # relative to the values stored in self.example.
        self.assertAllClose(second['seq'][:, 0],
                            self.example['seq'],
                            atol=tol)
        self.assertAllClose(second['seq'], first['seq'], atol=tol)
        self.assertAllClose(second['multi'],
                            self.example['multi'],
                            atol=tol)
        self.assertAllClose(second['ints'][:, 0],
                            self.example['ints'],
                            atol=tol)
        self.assertEqual(second['name'], self.example['name'])
Esempio n. 2
0
def get_sequence_coder(specs):
    """Build a `serialization.SequenceCoder`, inferring its sequence keys.

    A key of `specs` is listed as a sequence key unless its spec compares
    equal to `tf.string` or is an instance of `tf.io.FixedLenFeature`.

    Args:
      specs: mapping of feature name -> spec. It is passed unchanged to
        `serialization.SequenceCoder`.

    Returns:
      A `serialization.SequenceCoder` built from `specs`. Its `sequence_keys`
      list follows the iteration order of `specs`.
    """
    sequence_keys = []
    for key, spec in specs.items():
        # Equality (not identity) is used for the dtype check. The dtype
        # comparison is evaluated before the isinstance check.
        excluded = spec == tf.string or isinstance(spec,
                                                   tf.io.FixedLenFeature)
        if excluded:
            continue
        sequence_keys.append(key)
    return serialization.SequenceCoder(
        specs=specs,
        sequence_keys=sequence_keys,
    )