def _record_operation_vjp(g, ans, vs, gvs, output_tensors, input_tensors,
                          side_outputs, backward_function):
  """Gradient for _record_operation."""
  del vs, gvs, input_tensors, output_tensors  # unused
  # The backward function is called with the op's recorded outputs first,
  # then the incoming output gradients, then the side outputs.
  backward_args = tuple(g) + tuple(side_outputs)
  backward_args = container_types.make_sequence(
      EagerList, *(tuple(ans) + backward_args))
  tensors = nest.flatten(backward_function(*backward_args))
  return container_types.make_sequence(EagerList, *tensors)

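# A toy reconstruction of the argument layout assembled above, to make the
# calling convention concrete. Hypothetical names; plain tuples stand in
# for EagerList. This is an illustration, not part of the original module.
def _assemble_backward_args(ans, g, side_outputs):
  # outputs first, then output gradients, then side outputs
  return tuple(ans) + tuple(g) + tuple(side_outputs)

assert _assemble_backward_args((1,), (2,), (3,)) == (1, 2, 3)
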
def grad_fn(*orig_outputs):
  """Generated gradient function."""
  # orig_outputs holds the op's recorded outputs followed by their
  # gradients; the slice at results_size below separates the two.
  tensors = inputs + list(orig_outputs)
  tensors = container_types.make_sequence(tape.EagerList, *tensors)
  result = _magic_gradient_function(op_name, attrs, len(inputs), num_outputs,
                                    *tensors)
  if _tracing:
    print("Gradient for", (name if name else op_name), "inputs", inputs,
          "output_grads", orig_outputs[results_size:], "gradients", result)
  return result

def record_operation(o, i, s, b):
  """Primitive to trigger autograd tracing on outputs from inputs."""
  inputs = container_types.make_sequence(EagerList, *i)
  return _record_operation(o, inputs, s, b)

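# Minimal usage sketch (assumption-laden, not from the original source):
# recording y = x * x on the tape, where `x` and `y` are assumed to be
# eager tensors with y already computed as x * x. Per _record_operation_vjp
# above, the backward function receives the recorded outputs, then the
# output gradients, then the side outputs, in that order.
def _square_backward(y_out, dy, x_saved):
  del y_out  # the VJP of y = x * x needs only x and the output gradient
  return [2.0 * x_saved * dy]

# Signature: record_operation(outputs, inputs, side_outputs, backward_fn);
# the call below is left commented because x and y are assumed, not defined.
# record_operation([y], [x], [x], _square_backward)
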
def fwd_grad_make_sequence(argnum, g, ans, gvs, vs, args, kwargs):
  """Forward-mode gradient of make_sequence: the tangent g flows through
  the single element being differentiated; every other slot is zero."""
  typ, elts = args[0], args[1:]
  zeros = [elt.vspace.zeros() if isnode(elt) else vspace(elt).zeros()
           for elt in elts]
  # args[0] is the sequence type, so element argnum maps to index argnum - 1.
  zeros[argnum - 1] = g
  return ct.make_sequence(typ, *zeros)

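# Standalone illustration (plain Python, no autograd machinery) of the rule
# above: the tangent of sequence construction is zero in every slot except
# the element being differentiated. All names here are illustrative.
def _toy_fwd_grad_make_sequence(argnum, g, elts):
  zeros = [0.0 for _ in elts]  # stand-in for vspace zeros
  zeros[argnum - 1] = g        # args[0] is the type, so element 1 is index 0
  return tuple(zeros)

assert _toy_fwd_grad_make_sequence(2, 1.5, (10.0, 20.0, 30.0)) == (0.0, 1.5, 0.0)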