def testBatchFunctionOpWithInputError(self):
  """Tests that batch_function op works with error in the inputs.

  The batched function is declared with two int32 arguments, but only one
  input tensor is handed to the op, so running it must raise
  InvalidArgumentError describing the arity mismatch.
  """
  # Graph-only test: it relies on placeholders and Session.run.
  if context.executing_eagerly():
    return
  with self.cached_session() as sess:
    single_input = array_ops.placeholder(dtype=dtypes.int32, shape=[1])

    @function.Defun(dtypes.int32, dtypes.int32)
    def computation(in0, in1):
      return in0 + in1

    output_types = [
        arg.type for arg in computation.definition.signature.output_arg
    ]
    # Deliberately pass a single input although `computation` expects two.
    batched = gen_batch_ops.batch_function(
        [single_input],
        num_batch_threads=1,
        max_batch_size=10,
        batch_timeout_micros=100000,  # 100ms
        batching_queue="",
        f=computation,
        captured_tensors=computation.captured_inputs,
        Tout=output_types)

    expected_error = (
        r"Function takes 2 argument\(s\) but 1 argument\(s\) were passed")
    with self.assertRaisesRegex(InvalidArgumentError, expected_error):
      sess.run([batched], feed_dict={single_input: [2]})
def testBatchFunctionOp(self):
  """Tests that the batch_function op works.

  A worker thread and the main thread each run the same batched op with a
  different feed; each caller must receive its own input incremented by one.
  """
  # Placeholders and Session.run are graph-only, so skip under eager mode
  # instead of failing on placeholder creation.
  if context.executing_eagerly():
    return
  # `test_session` is deprecated; `cached_session` is its replacement.
  with self.cached_session() as sess:

    @function.Defun(dtypes.int32)
    def computation(in_t):
      return in_t + 1

    inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
    result = gen_batch_ops.batch_function(
        [inp],
        num_batch_threads=1,
        max_batch_size=10,
        batch_timeout_micros=100000,
        Tout=[dtypes.int32],
        f=computation,
        captured_tensors=computation.captured_inputs)
    thread_results = []

    def worker():
      thread_results.extend(sess.run([result], feed_dict={inp: [1]}))

    worker_thread = threading.Thread(target=worker)
    worker_thread.start()
    main_results = sess.run([result], feed_dict={inp: [2]})
    worker_thread.join()
    self.assertEqual(thread_results[0], [2])
    self.assertEqual(main_results[0], [3])
def testBatchFunctionOp(self):
  """Tests that the batch_function op works.

  Two concurrent Session.run calls (one from a background thread feeding
  [1], one from the main thread feeding [2]) go through the same batched op;
  each must get back its own value plus one.
  """
  # The test needs graph-mode placeholders; nothing to do under eager.
  if context.executing_eagerly():
    return
  with self.cached_session() as sess:

    @function.Defun(dtypes.int32)
    def computation(in_t):
      return in_t + 1

    batch_input = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
    batched_output = gen_batch_ops.batch_function(
        [batch_input],
        num_batch_threads=1,
        max_batch_size=10,
        batch_timeout_micros=100000,
        Tout=[dtypes.int32],
        f=computation,
        captured_tensors=computation.captured_inputs)

    worker_outputs = []

    def run_from_worker():
      fetched = sess.run([batched_output], feed_dict={batch_input: [1]})
      worker_outputs.extend(fetched)

    background = threading.Thread(target=run_from_worker)
    background.start()
    foreground_outputs = sess.run(
        [batched_output], feed_dict={batch_input: [2]})
    background.join()

    self.assertEqual(worker_outputs[0], [2])
    self.assertEqual(foreground_outputs[0], [3])
def testBatchFunctionOpWithCapturedInput(self):
  """Tests that batch_function op works with captured input.

  The batched function closes over two placeholder_with_default tensors
  (defaults 2 and 1), which reach the op through `captured_tensors`; the net
  effect is `inp + 1` for each concurrent caller.
  """
  # `test_session` is deprecated; `cached_session` is its replacement.
  with self.cached_session() as sess:
    captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
    captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
    inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])

    @function.Defun(dtypes.int32)
    def computation(inp):
      return inp + captured_inp0 - captured_inp1

    result = gen_batch_ops.batch_function(
        num_batch_threads=1,
        max_batch_size=10,
        batch_timeout_micros=100000,  # 100ms
        allowed_batch_sizes=[3, 10],
        batching_queue="",
        f=computation,
        in_tensors=[inp],
        captured_tensors=computation.captured_inputs,
        Tout=[
            o.type for o in computation.definition.signature.output_arg
        ])

    thread_results = []

    def worker():
      thread_results.extend(sess.run([result], feed_dict={inp: [1]}))

    worker_thread = threading.Thread(target=worker)
    worker_thread.start()
    main_results = sess.run([result], feed_dict={inp: [2]})
    worker_thread.join()
    self.assertEqual(thread_results[0], [2])
    self.assertEqual(main_results[0], [3])
def testBatchFunctionOpWithCapturedInput(self):
  """Tests that batch_function op works with captured input.

  `computation` reads two tensors from its enclosing scope
  (placeholder_with_default values 2 and 1); they are supplied to the op via
  `captured_tensors`, so every caller should observe `input + 1`.
  """
  with self.cached_session() as sess:
    offset_plus = array_ops.placeholder_with_default(2, shape=[])
    offset_minus = array_ops.placeholder_with_default(1, shape=[])
    batch_input = array_ops.placeholder(dtype=dtypes.int32, shape=[1])

    @function.Defun(dtypes.int32)
    def computation(inp):
      return inp + offset_plus - offset_minus

    result_types = [
        o.type for o in computation.definition.signature.output_arg
    ]
    batched_output = gen_batch_ops.batch_function(
        num_batch_threads=1,
        max_batch_size=10,
        batch_timeout_micros=100000,  # 100ms
        allowed_batch_sizes=[3, 10],
        batching_queue="",
        f=computation,
        in_tensors=[batch_input],
        captured_tensors=computation.captured_inputs,
        Tout=result_types)

    worker_outputs = []

    def run_from_worker():
      fetched = sess.run([batched_output], feed_dict={batch_input: [1]})
      worker_outputs.extend(fetched)

    background = threading.Thread(target=run_from_worker)
    background.start()
    foreground_outputs = sess.run(
        [batched_output], feed_dict={batch_input: [2]})
    background.join()

    self.assertEqual(worker_outputs[0], [2])
    self.assertEqual(foreground_outputs[0], [3])
def decorated(*args):
  """Runs `fn` on `args` through a `batch_function` op.

  Every positional argument must be a Tensor. `fn` is wrapped in a
  `function.defun` (with the enclosing `autograph` setting) and traced to a
  concrete function whose input specs mirror each argument's dtype and shape.
  The op is emitted under a "batch" name scope, and that scope's name is
  used as the op's `shared_name`.

  Raises:
    ValueError: if any argument is not an `ops.Tensor`.
  """
  # Validate before tracing: building the TensorSpecs below reads `.dtype`
  # and `.shape`, so a non-Tensor argument (e.g. a Python int) would
  # otherwise fail with an unhelpful AttributeError instead of this message.
  for a in args:
    if not isinstance(a, ops.Tensor):
      raise ValueError("All arguments to functions decorated with "
                       "`batch_function` are supposed to be Tensors; "
                       "found %s" % repr(a))

  @function.defun(autograph=autograph)
  def computation(*computation_args):
    return fn(*computation_args)

  computation = computation.get_concrete_function(
      *[tensor_spec.TensorSpec(dtype=x.dtype, shape=x.shape, name=str(i))
        for i, x in enumerate(args)])

  with ops.name_scope("batch") as name:
    return gen_batch_ops.batch_function(
        num_batch_threads=num_batch_threads,
        max_batch_size=max_batch_size,
        batch_timeout_micros=batch_timeout_micros,
        allowed_batch_sizes=allowed_batch_sizes,
        max_enqueued_batches=max_enqueued_batches,
        shared_name=name,
        f=computation,
        in_tensors=list(args),
        captured_tensors=computation.captured_inputs,
        Tout=[o.dtype for o in computation.outputs])
def decorated(*args):
  """Runs `fn` on `args` through a `batch_function` op.

  Every positional argument must be a Tensor. `fn` is wrapped in a
  `function.defun` and traced to a concrete function whose input specs
  mirror each argument's dtype and shape. The op is emitted under a "batch"
  name scope, and that scope's name is used as the op's `shared_name`.

  Raises:
    ValueError: if any argument is not an `ops.Tensor`.
  """
  # Validate before tracing: building the TensorSpecs below reads `.dtype`
  # and `.shape`, so a non-Tensor argument (e.g. a Python int) would
  # otherwise fail with an unhelpful AttributeError instead of this message.
  for a in args:
    if not isinstance(a, ops.Tensor):
      raise ValueError("All arguments to functions decorated with "
                       "`batch_function` are supposed to be Tensors; "
                       "found %s" % repr(a))

  @function.defun()
  def computation(*computation_args):
    return fn(*computation_args)

  computation = computation.get_concrete_function(
      *[tensor_spec.TensorSpec(dtype=x.dtype, shape=x.shape, name=str(i))
        for i, x in enumerate(args)])

  with ops.name_scope("batch") as name:
    return gen_batch_ops.batch_function(
        num_batch_threads=num_batch_threads,
        max_batch_size=max_batch_size,
        batch_timeout_micros=batch_timeout_micros,
        allowed_batch_sizes=allowed_batch_sizes,
        max_enqueued_batches=max_enqueued_batches,
        shared_name=name,
        f=computation,
        in_tensors=list(args),
        captured_tensors=computation.captured_inputs,
        Tout=[o.dtype for o in computation.outputs])
def decorated(*args):
  """Runs `fn` on `args` through a `batch_function` op.

  Every positional argument must be a Tensor. `fn` is wrapped in a
  `function.Defun` typed on the arguments' dtypes. The op is emitted under a
  "batch" name scope, that scope's name is used as the op's `shared_name`,
  and the output types are read from the Defun's signature.

  Raises:
    ValueError: if any argument is not an `ops.Tensor`.
  """
  # Validate first: collecting dtypes reads `.dtype` on every argument, so a
  # non-Tensor (e.g. a Python int) would otherwise fail with an AttributeError
  # before this descriptive error could be raised.
  for a in args:
    if not isinstance(a, ops.Tensor):
      raise ValueError(
          "All arguments to functions decorated with "
          "`batch_function` are supposed to be Tensors; "
          "found %s" % repr(a))

  arg_dtypes = [a.dtype for a in args]

  @function.Defun(*arg_dtypes)
  def computation(*computation_args):
    return fn(*computation_args)

  with ops.name_scope("batch") as name:
    return gen_batch_ops.batch_function(
        num_batch_threads=num_batch_threads,
        max_batch_size=max_batch_size,
        batch_timeout_micros=batch_timeout_micros,
        allowed_batch_sizes=allowed_batch_sizes,
        max_enqueued_batches=max_enqueued_batches,
        shared_name=name,
        f=computation,
        in_tensors=list(args),
        captured_tensors=computation.captured_inputs,
        Tout=[
            o.type for o in computation.definition.signature.output_arg
        ])
def decorated(*args):
  """Runs `fn` on `args` through a `batch_function` op.

  Every positional argument must be a Tensor. `fn` is wrapped in a
  `function.Defun` typed on the arguments' dtypes. The op is emitted under a
  "batch" name scope, that scope's name is used as the op's `shared_name`,
  and the output types are read from the Defun's signature.

  Raises:
    ValueError: if any argument is not an `ops.Tensor`.
  """
  # Validate first: collecting dtypes reads `.dtype` on every argument, so a
  # non-Tensor (e.g. a Python int) would otherwise fail with an AttributeError
  # before this descriptive error could be raised.
  for a in args:
    if not isinstance(a, ops.Tensor):
      raise ValueError("All arguments to functions decorated with "
                       "`batch_function` are supposed to be Tensors; "
                       "found %s" % repr(a))

  arg_dtypes = [a.dtype for a in args]

  @function.Defun(*arg_dtypes)
  def computation(*computation_args):
    return fn(*computation_args)

  with ops.name_scope("batch") as name:
    return gen_batch_ops.batch_function(
        num_batch_threads=num_batch_threads,
        max_batch_size=max_batch_size,
        batch_timeout_micros=batch_timeout_micros,
        allowed_batch_sizes=allowed_batch_sizes,
        max_enqueued_batches=max_enqueued_batches,
        shared_name=name,
        f=computation,
        in_tensors=list(args),
        captured_tensors=computation.captured_inputs,
        Tout=[o.type for o in computation.definition.signature.output_arg])
def testBatchFunctionOpWithLargeBatchSplitted(self):
  """Tests that the batch_function op works with large batch splitted.

  With `enable_large_batch_splitting=True`, inputs larger than
  max(allowed_batch_sizes) are split internally; each caller must still get
  back its own input with 3 added element-wise.
  """
  # Graph-only test: it relies on placeholders and Session.run.
  if context.executing_eagerly():
    return
  with self.cached_session() as sess:

    @function.Defun(dtypes.int32)
    def computation(in_t):
      return in_t + 3

    inp = array_ops.placeholder(dtype=dtypes.int32)
    result = gen_batch_ops.batch_function(
        [inp],
        num_batch_threads=2,
        # enable_large_batch_splitting is True, so it's valid as long as
        # max('allowed_batch_sizes') <= 'max_batch_size'.
        allowed_batch_sizes=[1, 2],
        max_batch_size=5,
        batch_timeout_micros=100000,  # 100ms
        Tout=[dtypes.int32],
        enable_large_batch_splitting=True,
        f=computation,
        captured_tensors=computation.captured_inputs)
    thread1_results = []
    thread2_results = []

    # Input sizes of worker1 and main thread are larger than
    # max(allowed_batch_sizes), while input size of worker2 is smaller.
    def worker1():
      thread1_results.extend(
          sess.run([result], feed_dict={inp: [5, 6, 7, 8, 9]}))

    worker_thread1 = threading.Thread(target=worker1)
    worker_thread1.start()

    def worker2():
      thread2_results.extend(sess.run([result], feed_dict={inp: [10]}))

    worker_thread2 = threading.Thread(target=worker2)
    worker_thread2.start()
    main_results = sess.run([result], feed_dict={inp: [2, 3, 4]})
    worker_thread1.join()
    worker_thread2.join()
    # assertAllEqual checks shape as well as values and reports the
    # mismatching elements, unlike assertTrue(np.all(np.equal(...))) which
    # accepts broadcastable shape mismatches and only says "False is not true".
    self.assertAllEqual(thread2_results[0], [13])
    self.assertAllEqual(thread1_results[0], [8, 9, 10, 11, 12])
    self.assertAllEqual(main_results[0], [5, 6, 7])
def testBatchFunctionOpWithInputError(self):
  """Tests that batch_function op works with error in the inputs.

  `computation` is declared with two int32 arguments but the op receives a
  single input, so running it must raise InvalidArgumentError about the
  argument-count mismatch.
  """
  # `test_session` is deprecated; `cached_session` is its replacement.
  with self.cached_session() as sess:
    inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])

    @function.Defun(dtypes.int32, dtypes.int32)
    def computation(in0, in1):
      return in0 + in1

    result = gen_batch_ops.batch_function(
        [inp],  # computation actually expects 2 inputs.
        num_batch_threads=1,
        max_batch_size=10,
        batch_timeout_micros=100000,  # 100ms
        batching_queue="",
        f=computation,
        captured_tensors=computation.captured_inputs,
        Tout=[o.type for o in computation.definition.signature.output_arg])
    # `assertRaisesRegexp` is a deprecated alias (removed in Python 3.12).
    # The pattern accepts both "2 arguments ... but 1" and the newer
    # "2 argument(s) but 1 argument(s)" wording; re.search makes the
    # surrounding `.*` unnecessary.
    with self.assertRaisesRegex(InvalidArgumentError,
                                r"2 argument(?:s|\(s\)).*but 1"):
      sess.run([result], feed_dict={inp: [2]})