def test_out_of_order_execution2(self):
    """Outputs set in reverse fetch order still reach the right caller.

    Mirror of test_out_of_order_execution1 with the two ``set_outputs``
    calls swapped: computation 1 gets its result before computation 0, and
    each pending ``compute`` call must still receive its own value
    (input + 42).
    """
    with self.test_session() as session:
        batcher = dynamic_batching._Batcher(
            minimum_batch_size=1, maximum_batch_size=1, timeout_ms=None)

        tp = pool.ThreadPool(10)
        try:
            # The compute() results only become available after set_outputs
            # below, so they are evaluated on worker threads. apply_async
            # unpacks the value returned by compute() as positional args of
            # session.run (presumably a one-element list for a single dtype
            # -- TODO confirm).
            r0 = tp.apply_async(session.run,
                                batcher.compute([[1]], [tf.int32]))
            (input0,), computation_id0 = session.run(
                batcher.get_inputs([tf.int32]))

            r1 = tp.apply_async(session.run,
                                batcher.compute([[2]], [tf.int32]))
            (input1,), computation_id1 = session.run(
                batcher.get_inputs([tf.int32]))

            self.assertAllEqual([1], input0)
            self.assertAllEqual([2], input1)

            # These two runs are switched from test_out_of_order_execution1.
            session.run(batcher.set_outputs([input1 + 42], computation_id1))
            session.run(batcher.set_outputs([input0 + 42], computation_id0))

            self.assertAllEqual([43], r0.get())
            self.assertAllEqual([44], r1.get())
        finally:
            # Shut the pool down explicitly instead of leaving its worker and
            # handler threads alive after the test; terminate() does not wait
            # for outstanding work, so it cannot hang if an assertion fails
            # while a compute() call is still pending.
            tp.terminate()
def test_op_shape(self):
    """The computation id returned by get_inputs has an empty (scalar) shape."""
    with self.test_session():
        batcher = dynamic_batching._Batcher(
            minimum_batch_size=1,
            maximum_batch_size=1,
            timeout_ms=None,
        )
        # Only the static shape of the id tensor is inspected; nothing is run.
        _unused_inputs, id_tensor = batcher.get_inputs([tf.int32])
        expected_shape = []
        self.assertEqual(expected_shape, id_tensor.shape)
def test_invalid_computation_id(self):
    """set_outputs with an unknown computation id raises InvalidArgumentError.

    NOTE(review): a method with this same name and an equivalent body is
    defined immediately after this one. Assuming both live in the same class,
    the later definition replaces this one in the class namespace, so this
    copy never runs. One of the two should be deleted.
    """
    with self.test_session() as session:
        batcher = dynamic_batching._Batcher(minimum_batch_size=1, maximum_batch_size=1, timeout_ms=None)
        tp = pool.ThreadPool(10)
        # Submit one computation on a worker thread so get_inputs has
        # something to return.
        tp.apply_async(session.run, batcher.compute([[1]], [tf.int32]))
        (input0, ), _ = session.run(batcher.get_inputs([tf.int32]))
        self.assertAllEqual([1], input0)
        # 42 is used as an id that the batcher never handed out.
        with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, 'Invalid computation id'):
            session.run(batcher.set_outputs([input0], 42))
def test_invalid_computation_id(self):
    """set_outputs with an unknown computation id raises InvalidArgumentError.

    One computation is submitted and fetched via get_inputs, then outputs are
    set for id 42, which the batcher never handed out in this test.
    """
    with self.test_session() as session:
        batcher = dynamic_batching._Batcher(
            minimum_batch_size=1, maximum_batch_size=1, timeout_ms=None)

        tp = pool.ThreadPool(10)
        try:
            # Submit one computation on a worker thread so get_inputs has
            # something to return.
            tp.apply_async(session.run, batcher.compute([[1]], [tf.int32]))
            (input0,), _ = session.run(batcher.get_inputs([tf.int32]))

            self.assertAllEqual([1], input0)

            with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
                                         'Invalid computation id'):
                session.run(batcher.set_outputs([input0], 42))
        finally:
            # The submitted computation never gets valid outputs here, so its
            # task is presumably still pending. terminate() shuts the pool
            # down without waiting for outstanding work, unlike close()+join().
            tp.terminate()
def test_out_of_order_execution1(self):
    """Two computations fetched one at a time each receive their own output.

    Computation 0 (input [1]) and computation 1 (input [2]) are both fetched
    via get_inputs before any outputs are set; outputs are then set in fetch
    order and each pending ``compute`` call must return input + 42.
    """
    with self.test_session() as session:
        batcher = dynamic_batching._Batcher(
            minimum_batch_size=1, maximum_batch_size=1, timeout_ms=None)

        tp = pool.ThreadPool(10)
        try:
            # The compute() results only become available after set_outputs
            # below, so they are evaluated on worker threads. apply_async
            # unpacks the value returned by compute() as positional args of
            # session.run (presumably a one-element list for a single dtype
            # -- TODO confirm).
            r0 = tp.apply_async(session.run,
                                batcher.compute([[1]], [tf.int32]))
            (input0,), computation_id0 = session.run(
                batcher.get_inputs([tf.int32]))

            r1 = tp.apply_async(session.run,
                                batcher.compute([[2]], [tf.int32]))
            (input1,), computation_id1 = session.run(
                batcher.get_inputs([tf.int32]))

            self.assertAllEqual([1], input0)
            self.assertAllEqual([2], input1)

            session.run(batcher.set_outputs([input0 + 42], computation_id0))
            session.run(batcher.set_outputs([input1 + 42], computation_id1))

            self.assertAllEqual([43], r0.get())
            self.assertAllEqual([44], r1.get())
        finally:
            # Shut the pool down explicitly instead of leaving its worker and
            # handler threads alive after the test; terminate() does not wait
            # for outstanding work, so it cannot hang if an assertion fails
            # while a compute() call is still pending.
            tp.terminate()