Example #1
# Module imports used by this snippet (defined in the enclosing file):
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_batch_ops


def decorated(*args):
  """Inner wrapper returned by `batch_function`.

  Closes over `f` and the batching parameters of the enclosing
  `batch_function(...)` call: it batches the incoming arguments, runs `f`
  once on the whole batch, then unbatches the results for each caller.
  """
  with ops.name_scope("batch") as name:
    for a in args:
      if not isinstance(a, ops.Tensor):
        raise ValueError("All arguments to functions decorated with "
                         "`batch_function` are supposed to be Tensors; "
                         "found %s" % repr(a))
    # Concatenate requests from concurrent calls into one batch; the op
    # also returns the index/id tensors needed to route results back.
    batched_tensors, batch_index, id_t = gen_batch_ops.batch(
        args,
        num_batch_threads=num_batch_threads,
        max_batch_size=max_batch_size,
        batch_timeout_micros=batch_timeout_micros,
        max_enqueued_batches=max_enqueued_batches,
        allowed_batch_sizes=allowed_batch_sizes,
        grad_timeout_micros=grad_timeout_micros,
        shared_name=name)
    outputs = f(*batched_tensors)
    if isinstance(outputs, ops.Tensor):
      outputs_list = [outputs]
    else:
      outputs_list = outputs
    # Split each batched output back into its per-call slices.
    with ops.name_scope("unbatch") as unbatch_name:
      unbatched = [
          gen_batch_ops.unbatch(t, batch_index, id_t,
                                timeout_micros=unbatch_timeout_micros,
                                shared_name=unbatch_name + "/" + t.name)
          for t in outputs_list]
    if isinstance(outputs, ops.Tensor):
      return unbatched[0]
    return unbatched
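
A hedged usage sketch: `decorated` above is the inner wrapper that `batch_function` returns, so callers only ever see the decorator. The import location and parameter names below mirror the closure variables in the snippet and are assumptions about the surrounding `tensorflow.python.ops.batch_ops` module; they may differ across TensorFlow versions.

from tensorflow.python.ops import math_ops
from tensorflow.python.ops.batch_ops import batch_function  # assumed location

# Concurrent calls to `layer` are concatenated into one batch (up to 32
# rows, waiting at most 100 ms), computed together, then split back up.
@batch_function(num_batch_threads=1,
                max_batch_size=32,
                batch_timeout_micros=100 * 1000)
def layer(a):
  return math_ops.matmul(a, a)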
Example #2
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_batch_ops


@ops.RegisterGradient("Batch")
def _BatchGrad(op, *out_grads):  # pylint: disable=invalid-name
  """Gradient for batch op."""
  # The Batch op's outputs are [batched_tensors..., batch_index, id], so
  # op.outputs[-2] and op.outputs[-1] are the routing tensors that unbatch
  # needs to send each gradient slice back to the right caller.
  gradients = []
  for i in range(len(op.inputs)):
    gradients.append(
        gen_batch_ops.unbatch(
            out_grads[i],
            op.outputs[-2],
            op.outputs[-1],
            timeout_micros=op.get_attr("grad_timeout_micros"),
            shared_name="batch_gradient_{}_{}".format(op.name, i)))
  return gradients
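
The decorator above relies on TensorFlow's gradient registry: `ops.RegisterGradient` maps an op type name to a function that receives the forward op plus one gradient per output and must return one gradient per input. A minimal sketch of that same pattern, using a hypothetical op type "MyIdentity" (not part of the real library):

from tensorflow.python.framework import ops

# "MyIdentity" is a hypothetical op type used only to illustrate the
# registration pattern; gradient functions take the forward op and the
# output gradients, and return one gradient per op input.
@ops.RegisterGradient("MyIdentity")
def _MyIdentityGrad(op, grad):  # pylint: disable=invalid-name
  del op  # An identity passes the incoming gradient straight through.
  return [grad]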