Esempio n. 1
0
def _record_operation_vjp(g, ans, vs, gvs, output_tensors, input_tensors,
                          side_outputs, backward_function):
  """Gradient for _record_operation.

  Calls `backward_function` with the recorded outputs, then the incoming
  gradients, then the side outputs (all positionally), and returns the
  flattened result packed into an `EagerList`.

  Args:
    g: incoming gradients; converted with `tuple(g)`.
    ans: recorded outputs; converted with `tuple(ans)`.
    vs, gvs, output_tensors, input_tensors: unused (deleted on entry);
      accepted only to satisfy the calling convention.
    side_outputs: extra values passed after the gradients.
    backward_function: callable that receives all of the above.

  Returns:
    An `EagerList` built from `nest.flatten` of the backward function's
    result.
  """
  # Not needed by this gradient rule.
  del vs, gvs, input_tensors, output_tensors
  grad_part = tuple(g) + tuple(side_outputs)
  call_args = tuple(ans) + grad_part
  # Pack the arguments into an EagerList container before the call;
  # presumably this keeps them visible to autograd tracing -- confirm.
  packed = container_types.make_sequence(EagerList, *call_args)
  raw_grads = backward_function(*packed)
  return container_types.make_sequence(EagerList, *nest.flatten(raw_grads))
Esempio n. 2
0
def _record_operation_vjp(g, ans, vs, gvs, output_tensors, input_tensors,
                          side_outputs, backward_function):
    """Gradient for _record_operation.

    Effectively calls ``backward_function(*ans, *g, *side_outputs)``. The
    arguments are first packed into an ``EagerList`` via
    ``container_types.make_sequence``. The flattened return value is
    packed into another ``EagerList``.

    ``vs``, ``gvs``, ``output_tensors`` and ``input_tensors`` are accepted
    but never used.
    """
    del vs, gvs, input_tensors, output_tensors  # unused by this rule
    trailing = tuple(g) + tuple(side_outputs)
    leading = tuple(ans)
    packed_args = container_types.make_sequence(
        EagerList, *(leading + trailing))
    flat_grads = nest.flatten(backward_function(*packed_args))
    return container_types.make_sequence(EagerList, *flat_grads)
Esempio n. 3
0
 def grad_fn(*orig_outputs):
   """Generated gradient function.

   Appends `orig_outputs` to the captured `inputs`, packs the result into a
   `tape.EagerList`, and forwards it to `_magic_gradient_function` together
   with `op_name`, `attrs`, `len(inputs)` and `num_outputs`. When
   `_tracing` is truthy, a trace line is printed before returning.
   """
   packed = container_types.make_sequence(
       tape.EagerList, *(inputs + list(orig_outputs)))
   grads = _magic_gradient_function(
       op_name, attrs, len(inputs), num_outputs, *packed)
   if _tracing:
     label = name or op_name
     # The slice from `results_size` onward is what the trace calls
     # "output_grads".
     print("Gradient for", label, "inputs", inputs,
           "output_grads", orig_outputs[results_size:], "gradients", grads)
   return grads
Esempio n. 4
0
 def grad_fn(*orig_outputs):
     """Generated gradient function.

     Args:
         *orig_outputs: values appended after the captured ``inputs``
             before calling ``_magic_gradient_function``. The slice from
             ``results_size`` onward is what the trace output labels
             ``output_grads``.

     Returns:
         Whatever ``_magic_gradient_function`` returns.
     """
     all_tensors = inputs + list(orig_outputs)
     wrapped = container_types.make_sequence(tape.EagerList, *all_tensors)
     num_inputs = len(inputs)
     gradients = _magic_gradient_function(op_name, attrs, num_inputs,
                                          num_outputs, *wrapped)
     if not _tracing:
         return gradients
     display_name = name if name else op_name
     print("Gradient for", display_name, "inputs",
           inputs, "output_grads", orig_outputs[results_size:],
           "gradients", gradients)
     return gradients
Esempio n. 5
0
def record_operation(o, i, s, b):
  """Primitive to trigger autograd tracing on outputs from inputs.

  Args:
    o: the outputs; forwarded unchanged to `_record_operation`.
    i: iterable of inputs; packed into an `EagerList` via
      `container_types.make_sequence` before being forwarded.
    s: forwarded unchanged; presumably the side outputs -- confirm.
    b: forwarded unchanged; presumably the backward function -- confirm.

  Returns:
    Whatever `_record_operation` returns.
  """
  return _record_operation(
      o, container_types.make_sequence(EagerList, *i), s, b)
Esempio n. 6
0
File: tape.py Progetto: lengjia/RRL
def record_operation(o, i, s, b):
  """Primitive to trigger autograd tracing on outputs from inputs.

  The inputs `i` are wrapped in an `_EagerList` container via
  `container_types.make_sequence`. `o`, `s` and `b` are handed to
  `_record_operation` untouched.
  """
  wrapped_inputs = container_types.make_sequence(_EagerList, *i)
  result = _record_operation(o, wrapped_inputs, s, b)
  return result
Esempio n. 7
0
def fwd_grad_make_sequence(argnum, g, ans, gvs, vs, args, kwargs):
    """Forward gradient of ``make_sequence(typ, *elts)`` w.r.t. one element.

    Builds a sequence of type ``typ`` whose entries are zeros in each
    element's vector space. The exception is the entry selected by
    ``argnum``, which is replaced by the incoming tangent ``g``.

    Args:
        argnum: position of the differentiated argument within ``args``.
            ``args[0]`` is the sequence type itself, so valid values start
            at 1, and element ``argnum - 1`` of ``args[1:]`` is replaced.
        g: tangent for the selected element.
        ans, gvs, vs, kwargs: unused; accepted for the calling convention.
        args: the original positional arguments ``(typ, *elts)``.

    Returns:
        ``ct.make_sequence(typ, *tangents)``.

    Raises:
        ValueError: if ``argnum < 1``. The type argument has no tangent,
            and ``argnum - 1`` would otherwise be a negative index that
            silently overwrites an element from the end of the list.
    """
    # Reject argnum < 1 up front: a negative index into `tangents` would not
    # fail, it would just put `g` in the wrong slot.
    if argnum < 1:
        raise ValueError(
            "argnum must be >= 1 (args[0] is the sequence type), got %r"
            % (argnum,))
    typ, elts = args[0], args[1:]
    # Zero tangent for every element. Nodes expose their own vspace;
    # plain values need vspace() to look theirs up.
    tangents = [elt.vspace.zeros() if isnode(elt) else vspace(elt).zeros()
                for elt in elts]
    tangents[argnum - 1] = g
    return ct.make_sequence(typ, *tangents)