def decorated(*args):
  """Batches `args` via `gen_batch_ops.batch`, runs `f`, and unbatches.

  `f` and the batching options (`num_batch_threads`, `max_batch_size`,
  `batch_timeout_micros`, `max_enqueued_batches`, `allowed_batch_sizes`,
  `grad_timeout_micros`, `unbatch_timeout_micros`) are free variables in
  this function; they are expected to be supplied by an enclosing scope
  that is not visible in this excerpt.

  The inputs are passed to `gen_batch_ops.batch` inside a "batch" name
  scope, `f` is called on the batched tensors, and each output of `f` is
  passed through `gen_batch_ops.unbatch` inside an "unbatch" name scope.

  Args:
    *args: The inputs to batch; every one must be an `ops.Tensor`.

  Returns:
    The single unbatched tensor when `f` returned one `ops.Tensor`;
    otherwise a list with one unbatched tensor per element of `f`'s output.

  Raises:
    ValueError: If any element of `args` is not an `ops.Tensor`.
  """
  with ops.name_scope("batch") as batch_scope:
    # Reject any non-tensor input before the batch op is created.
    for arg in args:
      if isinstance(arg, ops.Tensor):
        continue
      raise ValueError("All arguments to functions decorated with "
                       "`batch_function` are supposed to be Tensors; "
                       "found %s" % repr(arg))

    # The batch scope name doubles as the batch op's shared_name.
    batched, batch_index, batch_id = gen_batch_ops.batch(
        args,
        num_batch_threads=num_batch_threads,
        max_batch_size=max_batch_size,
        batch_timeout_micros=batch_timeout_micros,
        max_enqueued_batches=max_enqueued_batches,
        allowed_batch_sizes=allowed_batch_sizes,
        grad_timeout_micros=grad_timeout_micros,
        shared_name=batch_scope)

    result = f(*batched)
    # Treat a lone tensor as a one-element sequence so both cases share the
    # unbatching path below; remember which case it was for the return.
    is_single_tensor = isinstance(result, ops.Tensor)
    results = [result] if is_single_tensor else result

    with ops.name_scope("unbatch") as unbatch_scope:
      unbatched = []
      for out in results:
        # shared_name = unbatch scope + "/" + this output tensor's name.
        unbatched.append(
            gen_batch_ops.unbatch(
                out,
                batch_index,
                batch_id,
                timeout_micros=unbatch_timeout_micros,
                shared_name=unbatch_scope + "/" + out.name))

    return unbatched[0] if is_single_tensor else unbatched
def decorated(*args):
  """Batches `args` via `gen_batch_ops.batch`, runs `f`, and unbatches.

  `f` and the batching options (`num_batch_threads`, `max_batch_size`,
  `batch_timeout_micros`, `max_enqueued_batches`, `allowed_batch_sizes`,
  `grad_timeout_micros`, `unbatch_timeout_micros`) are free variables in
  this function; they are expected to be supplied by an enclosing scope
  that is not visible in this excerpt.

  The inputs are passed to `gen_batch_ops.batch` inside a "batch" name
  scope, `f` is called on the batched tensors, and each output of `f` is
  passed through `gen_batch_ops.unbatch` inside an "unbatch" name scope.

  Args:
    *args: The inputs to batch; every one must be an `ops.Tensor`.

  Returns:
    The single unbatched tensor when `f` returned one `ops.Tensor`;
    otherwise a list with one unbatched tensor per element of `f`'s output.

  Raises:
    ValueError: If any element of `args` is not an `ops.Tensor`.
  """
  with ops.name_scope("batch") as batch_scope:
    # Reject any non-tensor input before the batch op is created.
    for arg in args:
      if isinstance(arg, ops.Tensor):
        continue
      raise ValueError("All arguments to functions decorated with "
                       "`batch_function` are supposed to be Tensors; "
                       "found %s" % repr(arg))

    # The batch scope name doubles as the batch op's shared_name.
    batched, batch_index, batch_id = gen_batch_ops.batch(
        args,
        num_batch_threads=num_batch_threads,
        max_batch_size=max_batch_size,
        batch_timeout_micros=batch_timeout_micros,
        max_enqueued_batches=max_enqueued_batches,
        allowed_batch_sizes=allowed_batch_sizes,
        grad_timeout_micros=grad_timeout_micros,
        shared_name=batch_scope)

    result = f(*batched)
    # Treat a lone tensor as a one-element sequence so both cases share the
    # unbatching path below; remember which case it was for the return.
    is_single_tensor = isinstance(result, ops.Tensor)
    results = [result] if is_single_tensor else result

    with ops.name_scope("unbatch") as unbatch_scope:
      unbatched = []
      for out in results:
        # shared_name = unbatch scope + "/" + this output tensor's name.
        unbatched.append(
            gen_batch_ops.unbatch(
                out,
                batch_index,
                batch_id,
                timeout_micros=unbatch_timeout_micros,
                shared_name=unbatch_scope + "/" + out.name))

    return unbatched[0] if is_single_tensor else unbatched
def _BatchGrad(op, *out_grads):  # pylint: disable=invalid-name
  """Gradient for batch op.

  Builds one `gen_batch_ops.unbatch` per input of `op`, applied to the
  incoming gradient of the matching output.

  Args:
    op: The batch operation being differentiated. Its last two outputs
      (`op.outputs[-2]` and `op.outputs[-1]`) are forwarded to every
      unbatch call, and its "grad_timeout_micros" attribute is used as the
      unbatch timeout.
    *out_grads: Gradients w.r.t. the outputs of `op`; only the first
      `len(op.inputs)` entries are consumed.

  Returns:
    A list with one unbatched gradient per input of `op`, in input order.
  """

  def unbatch_grad(index):
    # NOTE(review): if out_grads[index] can be None (e.g. a
    # non-differentiable input), unbatch would receive None here — confirm.
    return gen_batch_ops.unbatch(
        out_grads[index],
        op.outputs[-2],
        op.outputs[-1],
        timeout_micros=op.get_attr("grad_timeout_micros"),
        # Unique per (op, input position).
        shared_name="batch_gradient_{}_{}".format(op.name, index))

  return [unbatch_grad(index) for index in range(len(op.inputs))]