def test_enqueue_many(self):
    """Re-batching enqueue_many batches of 2 into a batch of 10 reproduces pack_lt.

    Assumes (from setUp, not visible here) that ``self.pack_lt`` carries a
    ten-element ``'batch'`` axis -- TODO confirm.
    """
    (pairs_lt,) = ops.batch([self.pack_lt], batch_size=2, enqueue_many=True)
    self.assertEqual(len(pairs_lt.axes['batch']), 2)

    # Feed the size-2 batches back through batch() with enqueue_many so they
    # are regrouped into one batch of 10, then compare with the original pack.
    (regrouped_lt,) = ops.batch([pairs_lt], batch_size=10, enqueue_many=True)
    self.assertLabeledTensorsEqual(self.pack_lt, regrouped_lt)
def test_no_enqueue_many(self):
    """Batching one example at a time, then regrouping, yields ten stacked copies."""
    (pairs_lt,) = ops.batch([self.original_lt], batch_size=2)
    self.assertEqual(len(pairs_lt.axes['batch']), 2)

    (regrouped_lt,) = ops.batch([pairs_lt], batch_size=10, enqueue_many=True)
    # The same input is enqueued on every step, so the expected result is
    # original_lt repeated ten times and packed along a new 'batch' axis.
    expected_lt = ops.pack([self.original_lt] * 10, 'batch')
    self.assertLabeledTensorsEqual(expected_lt, regrouped_lt)
def test_enqueue_many(self):
    """shuffle_batch with enqueue_many reorders the packed examples.

    The shuffled size-2 batches are regrouped into one batch of 10. Its axes
    must match ``self.pack_lt``, but its values must not all line up with it.
    """
    (shuffled_lt,) = ops.shuffle_batch(
        [self.pack_lt],
        batch_size=2,
        enqueue_many=True,
        min_after_dequeue=8,
        seed=0,  # presumably pins the shuffle order so the test is repeatable
    )
    self.assertEqual(len(shuffled_lt.axes['batch']), 2)

    (regrouped_lt,) = ops.batch([shuffled_lt], batch_size=10, enqueue_many=True)
    self.assertEqual(regrouped_lt.axes, self.pack_lt.axes)

    regrouped, packed = self.eval([regrouped_lt.tensor, self.pack_lt.tensor])
    # At least one element must differ, i.e. the order really was shuffled.
    self.assertFalse((regrouped == packed).all())
def test_allow_smaller_final_batch(self):
    """With allow_smaller_final_batch=True the 'batch' axis has no static size."""
    [batch_2_op] = ops.batch(
        [self.original_lt], batch_size=2, allow_smaller_final_batch=True)
    # The last batch may hold fewer than batch_size elements, so the static
    # size is expected to be unknown (None) rather than 2. assertIsNone checks
    # identity and reports a clearer failure than assertEqual(..., None).
    self.assertIsNone(batch_2_op.axes['batch'].size)
def test_invalid_input(self):
    """ops.batch with enqueue_many=True rejects original_lt with ValueError.

    NOTE(review): presumably because original_lt has no leading 'batch' axis
    to split into individual examples -- confirm against ops.batch.
    """
    self.assertRaises(
        ValueError, ops.batch, [self.original_lt], 3, enqueue_many=True)
def test_name(self):
    """Every tensor returned by ops.batch has 'lt_batch' in its name."""
    outputs = ops.batch(
        [self.pack_lt, self.pack_lt], batch_size=2, enqueue_many=True)
    for output_lt in outputs:
        self.assertIn('lt_batch', output_lt.name)
def test_allow_smaller_final_batch(self):
    """With allow_smaller_final_batch=True the 'batch' axis has no static size."""
    # NOTE(review): this exercises ops.batch. If it lives in a shuffle_batch
    # test class it may have been meant to call ops.shuffle_batch -- confirm.
    [batch_2_op] = ops.batch(
        [self.original_lt], batch_size=2, allow_smaller_final_batch=True)
    # The last batch may hold fewer than batch_size elements, so the static
    # size is expected to be unknown (None). assertIsNone checks identity and
    # reports a clearer failure than assertEqual(..., None).
    self.assertIsNone(batch_2_op.axes['batch'].size)
def test_name(self):
    """Every tensor returned by ops.batch has 'lt_batch' in its name."""
    # NOTE(review): this exercises ops.batch. If it lives in a shuffle_batch
    # test class it may have been meant to call ops.shuffle_batch -- confirm.
    outputs = ops.batch(
        [self.pack_lt, self.pack_lt], batch_size=2, enqueue_many=True)
    for output_lt in outputs:
        self.assertIn('lt_batch', output_lt.name)