def test_enqueue_many(self):
    """Shuffle `pack_lt` in pairs, re-batch to 10, and check the order changed.

    The first `ops.shuffle_batch` call dequeues batches of 2 from `pack_lt`
    (with `enqueue_many=True`). `ops.batch` then regroups those pairs into a
    batch of 10. The regrouped result must have the same axes as `pack_lt`,
    but its evaluated values must not be element-wise identical to `pack_lt`.
    """
    # NOTE(review): assumes pack_lt's 'batch' axis has length 10, so that one
    # batch of 10 reproduces its axes exactly -- confirm against setUp.
    [pairs_lt] = ops.shuffle_batch(
        [self.pack_lt],
        batch_size=2,
        enqueue_many=True,
        min_after_dequeue=8,
        seed=0)
    self.assertEqual(len(pairs_lt.axes['batch']), 2)

    [rebatched_lt] = ops.batch([pairs_lt], batch_size=10, enqueue_many=True)
    self.assertEqual(rebatched_lt.axes, self.pack_lt.axes)

    rebatched, original = self.eval([rebatched_lt.tensor, self.pack_lt.tensor])
    # If the values matched everywhere, the shuffle would have had no effect.
    self.assertFalse((rebatched == original).all())
def test_allow_smaller_final_batch(self):
    """With allow_smaller_final_batch=True, the 'batch' axis size is None.

    The final batch may be smaller than `batch_size`, so the size of the
    output's 'batch' axis is expected to be unknown.
    """
    [batch_2_op] = ops.shuffle_batch(
        [self.original_lt], batch_size=2, allow_smaller_final_batch=True)
    # assertIsNone is the idiomatic unittest check for None and reports the
    # actual value on failure.
    self.assertIsNone(batch_2_op.axes['batch'].size)
def test_name(self):
    """Every tensor returned by ops.shuffle_batch is named under 'lt_shuffle_batch'."""
    inputs = [self.pack_lt, self.pack_lt]
    outputs = ops.shuffle_batch(inputs, batch_size=2, enqueue_many=True)
    for output_lt in outputs:
        self.assertIn('lt_shuffle_batch', output_lt.name)
def test_allow_smaller_final_batch(self):
    """With allow_smaller_final_batch=True, the 'batch' axis size is None.

    The final batch may be smaller than `batch_size`, so the size of the
    output's 'batch' axis is expected to be unknown.
    """
    # NOTE(review): a method with this same name appears earlier in this
    # chunk; if both are in the same class, only the later definition runs.
    # Confirm which class each belongs to.
    [batch_2_op] = ops.shuffle_batch(
        [self.original_lt], batch_size=2, allow_smaller_final_batch=True)
    # assertIsNone is the idiomatic unittest check for None and reports the
    # actual value on failure.
    self.assertIsNone(batch_2_op.axes['batch'].size)
def test_name(self):
    """Every tensor returned by ops.shuffle_batch is named under 'lt_shuffle_batch'."""
    # NOTE(review): a method with this same name appears earlier in this
    # chunk; if both are in the same class, only the later definition runs.
    # Confirm which class each belongs to.
    sources = [self.pack_lt, self.pack_lt]
    shuffled = ops.shuffle_batch(sources, batch_size=2, enqueue_many=True)
    for shuffled_lt in shuffled:
        self.assertIn('lt_shuffle_batch', shuffled_lt.name)