def test_sequence_coder(self):
    """Checks that SequenceCoder survives two full encode/decode round trips.

    `self.example` is encoded and decoded, then the decoded result is encoded
    and decoded again. The second decoding is compared against the original
    example and against the first decoding.
    """
    field_types = {
        'seq': tf.float32,
        'name': tf.string,
        'multi': tf.float32,
        'ints': tf.int64,
    }
    sequence_fields = ['seq', 'multi', 'ints']
    seq_coder = serialization.SequenceCoder(
        specs=field_types, sequence_keys=sequence_fields)

    first_pass = seq_coder.decode(seq_coder.encode(self.example))
    second_pass = seq_coder.decode(seq_coder.encode(first_pass))

    def check_close(actual, expected):
        # Every numeric comparison in this test uses the same tight tolerance.
        self.assertAllClose(actual, expected, atol=1e-7)

    # 'seq' and 'ints' are compared via column 0 of the decoded value.
    # Presumably decoding adds a trailing axis for these keys (TODO confirm).
    check_close(second_pass['seq'][:, 0], self.example['seq'])
    check_close(second_pass['seq'], first_pass['seq'])
    check_close(second_pass['multi'], self.example['multi'])
    check_close(second_pass['ints'][:, 0], self.example['ints'])
    self.assertEqual(second_pass['name'], self.example['name'])
def get_sequence_coder(specs):
    """Builds a SequenceCoder whose sequence keys are inferred from `specs`.

    An entry is left out of the sequence keys when its spec compares equal to
    `tf.string` or is a `tf.io.FixedLenFeature` instance. Every other entry is
    treated as a sequence key, in the iteration order of `specs`.

    Args:
      specs: Mapping from feature name to its spec. The values are presumably
        TF dtypes or `tf.io.FixedLenFeature` instances (verify against callers).
        The mapping is passed unchanged to the coder.

    Returns:
      A `serialization.SequenceCoder` built from `specs` and the inferred
      sequence keys.
    """
    sequence_keys = []
    for key, spec in specs.items():
        is_string = spec == tf.string
        if is_string or isinstance(spec, tf.io.FixedLenFeature):
            # Non-sequence entry; skip it.
            continue
        sequence_keys.append(key)
    return serialization.SequenceCoder(specs=specs, sequence_keys=sequence_keys)