def test_out_of_order_execution2(self):
        """Outputs published in reverse order still reach their own callers.

        Same setup as test_out_of_order_execution1, except that the outputs
        for the second computation are set before those of the first. Each
        compute() call must still receive the value derived from its own
        inputs (input + 42).
        """
        with self.test_session() as session:
            # With maximum_batch_size=1 each compute() call is expected to be
            # handed out by get_inputs() on its own, with its own id.
            batcher = dynamic_batching._Batcher(minimum_batch_size=1,
                                                maximum_batch_size=1,
                                                timeout_ms=None)

            tp = pool.ThreadPool(10)
            try:
                # compute() runs on worker threads because its result only
                # becomes available after set_outputs() below; this thread
                # plays the consumer side via get_inputs()/set_outputs().
                r0 = tp.apply_async(session.run,
                                    batcher.compute([[1]], [tf.int32]))
                (input0, ), computation_id0 = session.run(
                    batcher.get_inputs([tf.int32]))
                r1 = tp.apply_async(session.run,
                                    batcher.compute([[2]], [tf.int32]))
                (input1, ), computation_id1 = session.run(
                    batcher.get_inputs([tf.int32]))

                self.assertAllEqual([1], input0)
                self.assertAllEqual([2], input1)

                # These two runs are switched from
                # test_out_of_order_execution1.
                session.run(batcher.set_outputs([input1 + 42],
                                                computation_id1))
                session.run(batcher.set_outputs([input0 + 42],
                                                computation_id0))

                self.assertAllEqual([43], r0.get())
                self.assertAllEqual([44], r1.get())
            finally:
                # Release the pool. terminate() (unlike close() + join())
                # does not wait for outstanding work, so a failed assertion
                # above cannot leave cleanup waiting on a pending compute().
                tp.terminate()
    def test_op_shape(self):
        """get_inputs() must return a scalar (rank-0) computation id."""
        with self.test_session():
            batcher = dynamic_batching._Batcher(minimum_batch_size=1,
                                                maximum_batch_size=1,
                                                timeout_ms=None)

            # Only the static shape of the returned tensor is inspected; the
            # op is never run, so no producer thread is needed here.
            _, computation_id = batcher.get_inputs([tf.int32])

            self.assertEqual([], computation_id.shape)
    def test_invalid_computation_id(self):
        """set_outputs() rejects a computation id that was not handed out.

        Expects tf.errors.InvalidArgumentError whose message matches
        'Invalid computation id'.
        """
        with self.test_session() as session:
            batcher = dynamic_batching._Batcher(minimum_batch_size=1,
                                                maximum_batch_size=1,
                                                timeout_ms=None)

            tp = pool.ThreadPool(10)
            try:
                # Queue one computation so there is a pending batch. Its real
                # id is discarded and never answered, so this worker call is
                # presumably left pending for the rest of the test.
                tp.apply_async(session.run,
                               batcher.compute([[1]], [tf.int32]))
                (input0, ), _ = session.run(batcher.get_inputs([tf.int32]))

                self.assertAllEqual([1], input0)

                # 42 stands in for an id that get_inputs() did not return.
                with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
                                             'Invalid computation id'):
                    session.run(batcher.set_outputs([input0], 42))
            finally:
                # terminate() rather than close() + join(): join() would wait
                # on the compute() call above, which is never answered.
                tp.terminate()
  def test_out_of_order_execution1(self):
    """Two computations pulled before either is answered keep their results.

    Both inputs are fetched with get_inputs() before any set_outputs() call;
    outputs are then set in request order, and each compute() call must
    receive the value derived from its own inputs (input + 42).
    """
    with self.test_session() as session:
      # With maximum_batch_size=1 each compute() call is expected to be
      # handed out by get_inputs() on its own, with its own id.
      batcher = dynamic_batching._Batcher(minimum_batch_size=1,
                                          maximum_batch_size=1,
                                          timeout_ms=None)

      tp = pool.ThreadPool(10)
      try:
        # compute() runs on worker threads because its result only becomes
        # available after set_outputs() below.
        r0 = tp.apply_async(session.run, batcher.compute([[1]], [tf.int32]))
        (input0,), computation_id0 = session.run(
            batcher.get_inputs([tf.int32]))
        r1 = tp.apply_async(session.run, batcher.compute([[2]], [tf.int32]))
        (input1,), computation_id1 = session.run(
            batcher.get_inputs([tf.int32]))

        self.assertAllEqual([1], input0)
        self.assertAllEqual([2], input1)

        session.run(batcher.set_outputs([input0 + 42], computation_id0))
        session.run(batcher.set_outputs([input1 + 42], computation_id1))

        self.assertAllEqual([43], r0.get())
        self.assertAllEqual([44], r1.get())
      finally:
        # Release the pool. terminate() (unlike close() + join()) does not
        # wait for outstanding work, so a failed assertion above cannot
        # leave cleanup waiting on a pending compute().
        tp.terminate()