def testBatchFunctionOpWithInputError(self):
        """Checks that an arity mismatch in batch_function is reported.

        `computation` is declared with two int32 inputs, but only a single
        tensor is batched in, so running the op must raise
        InvalidArgumentError naming both argument counts. Graph mode only.
        """
        if context.executing_eagerly():
            return

        with self.cached_session() as session:
            single_input = array_ops.placeholder(dtype=dtypes.int32, shape=[1])

            @function.Defun(dtypes.int32, dtypes.int32)
            def computation(in0, in1):
                return in0 + in1

            output_types = [
                arg.type for arg in computation.definition.signature.output_arg
            ]
            # Deliberately batch a single input although `computation` takes two.
            batched = gen_batch_ops.batch_function(
                in_tensors=[single_input],
                num_batch_threads=1,
                max_batch_size=10,
                batch_timeout_micros=100000,  # 100ms
                batching_queue="",
                f=computation,
                captured_tensors=computation.captured_inputs,
                Tout=output_types)

            expected_message = (
                r"Function takes 2 argument\(s\) but 1 argument\(s\) were passed")
            with self.assertRaisesRegex(InvalidArgumentError, expected_message):
                session.run([batched], feed_dict={single_input: [2]})
  def testBatchFunctionOp(self):
    """Feeds batch_function from a worker thread and the main thread.

    `computation` adds 1 to its input, so the worker's [1] must come back
    as [2] and the main thread's [2] as [3].
    """
    with self.test_session() as sess:

      @function.Defun(dtypes.int32)
      def computation(in_t):
        return in_t + 1

      feed = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      batched = gen_batch_ops.batch_function(
          in_tensors=[feed],
          num_batch_threads=1,
          max_batch_size=10,
          batch_timeout_micros=100000,
          Tout=[dtypes.int32],
          f=computation,
          captured_tensors=computation.captured_inputs)

      worker_out = []
      background = threading.Thread(
          target=lambda: worker_out.extend(
              sess.run([batched], feed_dict={feed: [1]})))
      background.start()
      foreground_out = sess.run([batched], feed_dict={feed: [2]})
      background.join()

      self.assertEqual(worker_out[0], [2])
      self.assertEqual(foreground_out[0], [3])
    def testBatchFunctionOp(self):
        """Feeds batch_function concurrently from a worker and the main thread.

        The batched computation adds 1, so the worker's [1] must come back as
        [2] and the main thread's [2] as [3]. Skipped under eager execution.
        """
        if context.executing_eagerly():
            return
        with self.cached_session() as session:

            @function.Defun(dtypes.int32)
            def computation(in_t):
                return in_t + 1

            feed = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
            batched = gen_batch_ops.batch_function(
                in_tensors=[feed],
                num_batch_threads=1,
                max_batch_size=10,
                batch_timeout_micros=100000,
                Tout=[dtypes.int32],
                f=computation,
                captured_tensors=computation.captured_inputs)

            def fetch(values):
                return session.run([batched], feed_dict={feed: values})

            background_out = []
            background = threading.Thread(
                target=lambda: background_out.extend(fetch([1])))
            background.start()
            foreground_out = fetch([2])
            background.join()

            self.assertEqual(background_out[0], [2])
            self.assertEqual(foreground_out[0], [3])
  def testBatchFunctionOpWithCapturedInput(self):
    """Checks batch_function when the batched function captures tensors.

    `computation` closes over two placeholder_with_default tensors (2 and 1),
    so it evaluates inp + 2 - 1; the worker's [1] and the main thread's [2]
    must come back as [2] and [3].
    """
    with self.test_session() as sess:
      plus_term = array_ops.placeholder_with_default(2, shape=[])
      minus_term = array_ops.placeholder_with_default(1, shape=[])
      feed = array_ops.placeholder(dtype=dtypes.int32, shape=[1])

      @function.Defun(dtypes.int32)
      def computation(inp):
        return inp + plus_term - minus_term

      output_types = [
          arg.type for arg in computation.definition.signature.output_arg]
      batched = gen_batch_ops.batch_function(
          num_batch_threads=1,
          max_batch_size=10,
          batch_timeout_micros=100000,  # 100ms
          allowed_batch_sizes=[3, 10],
          batching_queue="",
          f=computation,
          in_tensors=[feed],
          captured_tensors=computation.captured_inputs,
          Tout=output_types)

      def run_into(sink, values):
        sink.extend(sess.run([batched], feed_dict={feed: values}))

      worker_out = []
      background = threading.Thread(target=run_into, args=(worker_out, [1]))
      background.start()
      foreground_out = sess.run([batched], feed_dict={feed: [2]})
      background.join()

      self.assertEqual(worker_out[0], [2])
      self.assertEqual(foreground_out[0], [3])
# Example #5
    def testBatchFunctionOpWithCapturedInput(self):
        """Runs batch_function whose function reads two captured tensors.

        The captured defaults are 2 and 1, so each input x maps to x + 1:
        the worker thread feeds [1] expecting [2], the main thread feeds [2]
        expecting [3].
        """
        with self.cached_session() as session:
            added = array_ops.placeholder_with_default(2, shape=[])
            subtracted = array_ops.placeholder_with_default(1, shape=[])
            feed = array_ops.placeholder(dtype=dtypes.int32, shape=[1])

            @function.Defun(dtypes.int32)
            def computation(inp):
                return inp + added - subtracted

            out_types = [
                arg.type
                for arg in computation.definition.signature.output_arg
            ]
            batched = gen_batch_ops.batch_function(
                num_batch_threads=1,
                max_batch_size=10,
                batch_timeout_micros=100000,  # 100ms
                allowed_batch_sizes=[3, 10],
                batching_queue="",
                f=computation,
                in_tensors=[feed],
                captured_tensors=computation.captured_inputs,
                Tout=out_types)

            def fetch(values):
                return session.run([batched], feed_dict={feed: values})

            background_out = []
            background = threading.Thread(
                target=lambda: background_out.extend(fetch([1])))
            background.start()
            foreground_out = fetch([2])
            background.join()

            self.assertEqual(background_out[0], [2])
            self.assertEqual(foreground_out[0], [3])
# Example #6
    def decorated(*args):  # pylint: disable=missing-docstring
      """Runs `fn` on `args` through a BatchFunction op.

      `fn` is traced with `function.defun` (honouring the enclosing
      `autograph` setting) into a concrete function whose input signature
      mirrors the dtype/shape of each argument. The batching parameters
      (`num_batch_threads`, `max_batch_size`, `batch_timeout_micros`,
      `allowed_batch_sizes`, `max_enqueued_batches`) come from the enclosing
      scope.

      Args:
        *args: Tensors forwarded as the op's `in_tensors`.

      Returns:
        The result of `gen_batch_ops.batch_function`, with one output per
        output of the traced concrete function.

      Raises:
        ValueError: if any argument is not an `ops.Tensor`.
      """
      # Validate before touching `.dtype`/`.shape`: for plain Python values
      # those reads would otherwise raise AttributeError and this clearer
      # error would never be reached.
      for a in args:
        if not isinstance(a, ops.Tensor):
          raise ValueError("All arguments to functions decorated with "
                           "`batch_function`  are supposed to be Tensors; "
                           "found %s" % repr(a))

      @function.defun(autograph=autograph)
      def computation(*computation_args):
        return fn(*computation_args)

      # Trace once with a signature matching the actual arguments; inputs are
      # named by their positional index.
      computation = computation.get_concrete_function(
          *[tensor_spec.TensorSpec(dtype=x.dtype, shape=x.shape, name=str(i))
            for i, x in enumerate(args)])

      with ops.name_scope("batch") as name:
        # The name-scope string doubles as the op's shared_name.
        return gen_batch_ops.batch_function(
            num_batch_threads=num_batch_threads,
            max_batch_size=max_batch_size,
            batch_timeout_micros=batch_timeout_micros,
            allowed_batch_sizes=allowed_batch_sizes,
            max_enqueued_batches=max_enqueued_batches,
            shared_name=name,
            f=computation,
            in_tensors=list(args),
            captured_tensors=computation.captured_inputs,
            Tout=[o.dtype for o in computation.outputs])
# Example #7
    def decorated(*args):  # pylint: disable=missing-docstring
      """Runs `fn` on `args` through a BatchFunction op.

      `fn` is traced with `function.defun()` into a concrete function whose
      input signature mirrors the dtype/shape of each argument. The batching
      parameters (`num_batch_threads`, `max_batch_size`,
      `batch_timeout_micros`, `allowed_batch_sizes`, `max_enqueued_batches`)
      come from the enclosing scope.

      Args:
        *args: Tensors forwarded as the op's `in_tensors`.

      Returns:
        The result of `gen_batch_ops.batch_function`, with one output per
        output of the traced concrete function.

      Raises:
        ValueError: if any argument is not an `ops.Tensor`.
      """
      # Validate before touching `.dtype`/`.shape`: for plain Python values
      # those reads would otherwise raise AttributeError and this clearer
      # error would never be reached.
      for a in args:
        if not isinstance(a, ops.Tensor):
          raise ValueError("All arguments to functions decorated with "
                           "`batch_function`  are supposed to be Tensors; "
                           "found %s" % repr(a))

      @function.defun()
      def computation(*computation_args):
        return fn(*computation_args)

      # Trace once with a signature matching the actual arguments; inputs are
      # named by their positional index.
      computation = computation.get_concrete_function(
          *[tensor_spec.TensorSpec(dtype=x.dtype, shape=x.shape, name=str(i))
            for i, x in enumerate(args)])

      with ops.name_scope("batch") as name:
        # The name-scope string doubles as the op's shared_name.
        return gen_batch_ops.batch_function(
            num_batch_threads=num_batch_threads,
            max_batch_size=max_batch_size,
            batch_timeout_micros=batch_timeout_micros,
            allowed_batch_sizes=allowed_batch_sizes,
            max_enqueued_batches=max_enqueued_batches,
            shared_name=name,
            f=computation,
            in_tensors=list(args),
            captured_tensors=computation.captured_inputs,
            Tout=[o.dtype for o in computation.outputs])
# Example #8
        def decorated(*args):  # pylint: disable=missing-docstring
            """Runs `fn` on `args` through a BatchFunction op.

            `fn` is wrapped in a `function.Defun` typed by the arguments'
            dtypes. The batching parameters (`num_batch_threads`,
            `max_batch_size`, `batch_timeout_micros`, `allowed_batch_sizes`,
            `max_enqueued_batches`) come from the enclosing scope.

            Args:
                *args: Tensors forwarded as the op's `in_tensors`.

            Returns:
                The result of `gen_batch_ops.batch_function`, with one output
                per output arg in the Defun's signature.

            Raises:
                ValueError: if any argument is not an `ops.Tensor`.
            """
            # Validate before reading `.dtype`: for plain Python values that
            # read would otherwise raise AttributeError and this clearer error
            # would never be reached.
            for a in args:
                if not isinstance(a, ops.Tensor):
                    raise ValueError(
                        "All arguments to functions decorated with "
                        "`batch_function`  are supposed to be Tensors; "
                        "found %s" % repr(a))
            types = [arg.dtype for arg in args]

            @function.Defun(*types)
            def computation(*computation_args):
                return fn(*computation_args)

            with ops.name_scope("batch") as name:
                # The name-scope string doubles as the op's shared_name.
                return gen_batch_ops.batch_function(
                    num_batch_threads=num_batch_threads,
                    max_batch_size=max_batch_size,
                    batch_timeout_micros=batch_timeout_micros,
                    allowed_batch_sizes=allowed_batch_sizes,
                    max_enqueued_batches=max_enqueued_batches,
                    shared_name=name,
                    f=computation,
                    in_tensors=list(args),
                    captured_tensors=computation.captured_inputs,
                    Tout=[
                        o.type
                        for o in computation.definition.signature.output_arg
                    ])
# Example #9
    def decorated(*args):  # pylint: disable=missing-docstring
      """Runs `fn` on `args` through a BatchFunction op.

      `fn` is wrapped in a `function.Defun` typed by the arguments' dtypes.
      The batching parameters (`num_batch_threads`, `max_batch_size`,
      `batch_timeout_micros`, `allowed_batch_sizes`, `max_enqueued_batches`)
      come from the enclosing scope.

      Args:
        *args: Tensors forwarded as the op's `in_tensors`.

      Returns:
        The result of `gen_batch_ops.batch_function`, with one output per
        output arg in the Defun's signature.

      Raises:
        ValueError: if any argument is not an `ops.Tensor`.
      """
      # Validate before reading `.dtype`: for plain Python values that read
      # would otherwise raise AttributeError and this clearer error would
      # never be reached.
      for a in args:
        if not isinstance(a, ops.Tensor):
          raise ValueError("All arguments to functions decorated with "
                           "`batch_function`  are supposed to be Tensors; "
                           "found %s" % repr(a))
      types = [arg.dtype for arg in args]

      @function.Defun(*types)
      def computation(*computation_args):
        return fn(*computation_args)

      with ops.name_scope("batch") as name:
        # The name-scope string doubles as the op's shared_name.
        return gen_batch_ops.batch_function(
            num_batch_threads=num_batch_threads,
            max_batch_size=max_batch_size,
            batch_timeout_micros=batch_timeout_micros,
            allowed_batch_sizes=allowed_batch_sizes,
            max_enqueued_batches=max_enqueued_batches,
            shared_name=name,
            f=computation,
            in_tensors=list(args),
            captured_tensors=computation.captured_inputs,
            Tout=[o.type for o in computation.definition.signature.output_arg])
    def testBatchFunctionOpWithLargeBatchSplitted(self):
        """Checks batch_function with enable_large_batch_splitting=True.

        Three callers run concurrently; two of them feed more elements than
        the largest allowed batch size (2). With splitting enabled each caller
        must still get back its own inputs plus 3. Graph mode only.
        """
        if context.executing_eagerly():
            return

        with self.cached_session() as session:

            @function.Defun(dtypes.int32)
            def computation(in_t):
                return in_t + 3

            feed = array_ops.placeholder(dtype=dtypes.int32)
            batched = gen_batch_ops.batch_function(
                in_tensors=[feed],
                num_batch_threads=2,
                # enable_large_batch_splitting is True, so it's valid as long as
                # max('allowed_batch_sizes') <= 'max_batch_size'.
                allowed_batch_sizes=[1, 2],
                max_batch_size=5,
                batch_timeout_micros=100000,  # 100ms
                Tout=[dtypes.int32],
                enable_large_batch_splitting=True,
                f=computation,
                captured_tensors=computation.captured_inputs)

            def run_into(sink, values):
                sink.extend(session.run([batched], feed_dict={feed: values}))

            large_out = []
            small_out = []
            # The first worker and the main thread exceed
            # max(allowed_batch_sizes); the second worker stays below it.
            workers = [
                threading.Thread(
                    target=run_into, args=(large_out, [5, 6, 7, 8, 9])),
                threading.Thread(target=run_into, args=(small_out, [10])),
            ]
            for worker in workers:
                worker.start()
            main_out = session.run([batched], feed_dict={feed: [2, 3, 4]})
            for worker in workers:
                worker.join()

            def matches(actual, expected):
                return np.all(
                    np.equal(actual, np.array(expected, dtype=np.int32)))

            self.assertTrue(matches(small_out[0], [13]))
            self.assertTrue(matches(large_out[0], [8, 9, 10, 11, 12]))
            self.assertTrue(matches(main_out[0], [5, 6, 7]))
  def testBatchFunctionOpWithInputError(self):
    """Tests that batch_function op works with error in the inputs.

    `computation` is declared with two int32 inputs but only one tensor is
    batched in, so running the op must raise InvalidArgumentError.
    """
    with self.test_session() as sess:
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])

      @function.Defun(dtypes.int32, dtypes.int32)
      def computation(in0, in1):
        return in0 + in1

      result = gen_batch_ops.batch_function(
          [inp],  # computation actually expects 2 inputs.
          num_batch_threads=1,
          max_batch_size=10,
          batch_timeout_micros=100000,  # 100ms
          batching_queue="",
          f=computation,
          captured_tensors=computation.captured_inputs,
          Tout=[o.type for o in computation.definition.signature.output_arg])

      # assertRaisesRegexp is a deprecated alias that Python 3.12 removed.
      # "2 argument" matches both the "2 arguments" wording and the
      # "2 argument(s) but 1 argument(s)" wording; the old pattern required
      # the literal "arguments" and so missed the latter.
      with self.assertRaisesRegex(InvalidArgumentError, r"2 argument.*but 1"):
        sess.run([result], feed_dict={inp: [2]})