def test_stack(self):
     """Round-trip a sampled batch through split/flatten and Stack(axis=0).

     The batch is split along its leading axis, each piece is flattened, and
     the pieces are handed to ``transforms.Stack(axis=0).call``; the stacked
     result must equal the original batch.
     """
     batch = self.sampler.sample([3, 256])
     flat_rows = []
     # tf.split with an integer count of len(batch) splits axis 0 evenly,
     # so every piece keeps a leading size-1 axis; reshape([-1]) drops it.
     for piece in tf.split(batch, len(batch)):
         flat_rows.append(tf.reshape(piece, [-1]))
     stacker = transforms.Stack(axis=0)
     stacked = stacker.call(*flat_rows)
     self.assertAllEqual(stacked, batch)
# Example #2
0
 def stack_and_pop(on):
     """Return a two-element list: a ``transforms.Stack`` then a ``transforms.Pop``.

     Both transforms are configured with ``paired_keys(on)`` as their ``on``
     argument, and the Stack writes its output to ``on``. Presumably the Stack
     merges the paired keys into ``on`` and the Pop then discards the paired
     keys -- NOTE(review): confirm against the ``transforms`` module.
     """
     # paired_keys(on) is evaluated separately for each transform, in the same
     # order as before, so an iterator-returning paired_keys still works.
     return [
         transforms.Stack(on=paired_keys(on), out=on),
         transforms.Pop(on=paired_keys(on)),
     ]