Example #1
 def forward_mode_pass(v):
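     # Seed start_node with the incoming tangent v, evaluate fun while the node
     # is marked as an active forward-mode progenitor, and return the output
     # together with its tangent (zeros if the output never depended on it).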
     ac.assert_vspace_match(v, start_node.vspace, None)
     start_node.progenitors[start_node] = v
     active_forward_progenitors[start_node] = True
     end_node = fun(*args, **kwargs)
     active_forward_progenitors.pop(start_node)
     if not ac.isnode(end_node) or start_node not in end_node.progenitors:
         warnings.warn("Output seems independent of input.")
         return (end_node, end_node.vspace.zeros() if ac.isnode(end_node) else
                 ac.vspace(end_node).zeros())
     return end_node, end_node.progenitors[start_node]
Example #2
def _watch_with_tape(tape, resource_variable):
    """Wraps a watched Tensor and keeps track of it in the implicit tape."""
    tensor = resource_variable.handle
    w = _watch_with_tape_internal(tape, tensor)
    if ag_core.isnode(tape):
        tape.value.variables[ops.tensor_id(tensor)] = resource_variable
        tape.value.tensors[ops.tensor_id(tensor)] = w
Example #3
def _watch_with_tape(tape, resource_variable):
  """Wraps a watched Tensor and keeps track of it in the implicit tape."""
  tensor = resource_variable.handle
  w = _watch_with_tape_internal(tape, tensor)
  if ag_core.isnode(tape):
    tape.value.variables[ops.tensor_id(tensor)] = resource_variable
    tape.value.tensors[ops.tensor_id(tensor)] = w
Example #4
def new_progenitor(x, fwd=False):
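    # Wrap x in an identity node so it can serve as the root of a new trace;
    # in reverse mode the fresh node also registers itself as a progenitor.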
    if ac.isnode(x):
        node = ac.new_node(x.value, (ac.identity, (x,), {}, [(0, x)]), x.progenitors)
    else:
        node = ac.new_node(x,       (ac.identity, (x,), {}, []      ), dict()       )
    if not fwd:
        node.progenitors[node] = None
    return node
Example #5
def array_from_args_fwd_gradmaker(argnum, g, ans, gvs, vs, args, kwargs):
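    # Forward-mode rule for array_from_args: the argnum-th argument contributes
    # the incoming tangent g, every other argument a zero tangent of its vspace.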
    result = list()
    for i, arg in enumerate(args):
        if i == argnum:
            result.append(g)
        else:
            result.append(
                arg.vspace.zeros() if isnode(arg) else vspace(arg).zeros())
    return nw.array_from_args(*result)
Example #6
def _record_operation_vjp(g, ans, vs, gvs, output_tensors, input_tensors,
                          side_outputs, backward_function):
  """Gradient for _record_operation."""
  del ans, vs, gvs, output_tensors, input_tensors
  backward_args = tuple(g) + tuple(side_outputs)
  if ag_core.isnode(backward_args):
    backward_args = list(backward_args)
  tensors = nest.flatten(backward_function(*backward_args))
  return _EagerList([ag_core.getval(t) for t in tensors])
Example #7
 def _run_op(a, *args):
   # pylint: disable=protected-access
   value = a._AsTensor()
   if ag_core.isnode(value):
     # This avoids autograd trying to wrap a ResourceVariable.
     value = ops.convert_to_tensor(value)
     args = [ops.convert_to_tensor(x) for x in args]
     return getattr(tensor_node.TensorNode, operator)(value, *args)
   else:
     return getattr(ops.Tensor, operator)(value, *args)
Example #8
 def _run_op(a, *args):
     # pylint: disable=protected-access
     value = a._AsTensor()
     if ag_core.isnode(value):
         # This avoids autograd trying to wrap a ResourceVariable.
         value = ops.convert_to_tensor(value)
         args = [ops.convert_to_tensor(x) for x in args]
         return getattr(tensor_node.TensorNode, operator)(value, *args)
     else:
         return getattr(ops.Tensor, operator)(value, *args)
Example #9
def _record_gradient(op_name, inputs, attrs, results, name):
    """Records gradients for a TensorFlow operation.

    Args:
      op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
        execute.
      inputs: A flat list of Tensor object inputs to the operation.
      attrs: A tuple with alternating string attr names and attr values for this
        operation.
      results: The results of the operation (as a flat list).
      name: Customized name for the operation.

    Returns:
      A list of maybe-wrapped results. Either Tensors or TensorNodes.

    Raises:
      An exception on error.
    """
    if not any(ag_core.isnode(x) for x in inputs):
        return results
    num_outputs = len(results)
    if num_outputs == 0:
        return results
    if attrs is not None:
        attrs = tuple(tuple(x) if isinstance(x, list) else x for x in attrs)

    # It is imperative we make a copy of results here as otherwise we create a
    # dependency cycle in the captured function and this can delay garbage
    # collecting of the tensors arbitrarily.
    results_size = len(results) if isinstance(results, (list, tuple)) else 1

    def grad_fn(*orig_outputs):
        """Generated gradient function."""
        tensors = inputs + list(orig_outputs)
        tensors = container_types.make_sequence(tape.EagerList, *tensors)
        result = _magic_gradient_function(op_name, attrs, len(inputs),
                                          num_outputs, *(tensors))
        if _tracing:
            print("Gradient for", (name if name else op_name), "inputs",
                  inputs, "output_grads", orig_outputs[results_size:],
                  "gradients", result)
        return result

    results = tape.record_operation(results, inputs, [], grad_fn)
    if _tracing:
        print("Computed op", (name if name else op_name), "inputs", inputs,
              "outputs", results)
    return results
Example #10
    def decorated(*args, **kwargs):
        """Decorated function with custom gradient."""
        input_tensors = [
            _watch_value_from_tape(x) for x in args
            if isinstance(x, (_tensor.Tensor,
                              tf_ops.Tensor)) or ag_core.isnode(x)
        ]
        result, grad_fn = f(*args, **kwargs)

        flat_result = nest.flatten(result)
        flat_result = [ag_core.getval(x) for x in flat_result]
        flat_result = tape.record_operation(flat_result, input_tensors, [],
                                            grad_fn)
        flat_result = list(flat_result)
        return nest.pack_sequence_as(structure=result,
                                     flat_sequence=flat_result)
Example #11
  def decorated(*args, **kwargs):
    """Decorated function with custom gradient."""
    input_tensors = [_watch_value_from_tape(x) for x in args
                     if isinstance(x, (_tensor.Tensor, tf_ops.Tensor))
                     or ag_core.isnode(x)]
    result, grad_fn = f(*args, **kwargs)

    flat_result = nest.flatten(result)
    flat_result = [ag_core.getval(x) for x in flat_result]
    flat_result = tape.record_operation(
        flat_result,
        input_tensors,
        [],
        grad_fn)
    flat_result = list(flat_result)
    return nest.pack_sequence_as(structure=result, flat_sequence=flat_result)
Example #12
def _record_gradient(op_name, inputs, attrs, results, name):
  """Records gradients for a TensorFlow operation.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    inputs: A flat list of Tensor object inputs to the operation.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    results: The results of the operation (as a flat list).
    name: Customized name for the operation.

  Returns:
    A list of maybe-wrapped results. Either Tensors or TensorNodes.

  Raises:
    An exception on error.
  """
  if not any(ag_core.isnode(x) for x in inputs):
    return results
  num_outputs = len(results)
  if num_outputs == 0:
    return results
  if attrs is not None:
    attrs = tuple(tuple(x) if isinstance(x, list) else x for x in attrs)

  # It is imperative we make a copy of results here as otherwise we create a
  # dependency cycle in the captured function and this can delay garbage
  # collecting of the tensors arbitrarily.
  results_size = len(results) if isinstance(results, (list, tuple)) else 1

  def grad_fn(*orig_outputs):
    """Generated gradient function."""
    tensors = inputs + list(orig_outputs)
    tensors = container_types.make_sequence(tape.EagerList, *tensors)
    result = _magic_gradient_function(op_name, attrs, len(inputs),
                                      num_outputs, *(tensors))
    if _tracing:
      print("Gradient for", (name if name else op_name), "inputs", inputs,
            "output_grads", orig_outputs[results_size:], "gradients", result)
    return result

  results = tape.record_operation(results, inputs, [], grad_fn)
  if _tracing:
    print("Computed op", (name if name else op_name), "inputs", inputs,
          "outputs", results)
  return results
Example #13
 def grad_fn(*args, **kwds):
   """Computes the gradient of the wrapped function."""
   tape.push_new_tape()
   end_node = f(*args)
   start_node = tape.pop_tape()
   ag_core.active_progenitors.remove(start_node)
   if not ag_core.isnode(end_node):
     raise ValueError(
         "Target not part of a computation being traced. %s" % end_node)
   if start_node not in end_node.progenitors:
     raise ValueError("Target not derived from source. %s %s" %
                      (end_node.progenitors, repr(start_node)))
   output_gradients = kwds.get("output_gradients", None)
   if output_gradients is None:
     output_gradients = _ones(end_node.shape, end_node.dtype)
   grad = ag_core.backward_pass(output_gradients, end_node, start_node)
   return end_node.value, _aggregate_grads(grad.gradients)
Example #14
 def grad_fn(*args, **kwds):
     """Computes the gradient of the wrapped function."""
     tape.push_new_tape()
     end_node = f(*args)
     start_node = tape.pop_tape()
     ag_core.active_progenitors.remove(start_node)
     if not ag_core.isnode(end_node):
         raise ValueError(
             "Target not part of a computation being traced. %s" % end_node)
     if start_node not in end_node.progenitors:
         raise ValueError("Target not derived from source. %s %s" %
                          (end_node.progenitors, repr(start_node)))
     output_gradients = kwds.get("output_gradients", None)
     if output_gradients is None:
         output_gradients = _ones(end_node.shape, end_node.dtype)
     grad = ag_core.backward_pass(output_gradients, end_node, start_node)
     return end_node.value, _aggregate_grads(grad.gradients)
Example #15
 def __init__(self, input_placeholders, extra_inputs, fdef, graph, operations,
              func_outputs, func_outputs_to_fdef_outputs, output_shapes):
   assert len(input_placeholders) == len(fdef.signature.input_arg), "%s %s" % (
       len(input_placeholders), len(fdef.signature.input_arg))
   self._input_placeholders = input_placeholders
   self._extra_inputs = list(extra_inputs)
   self._graph = graph
   self._has_backprop = False
   self._func_name = fdef.signature.name
   self._fdef = _DefinedFunction(fdef)
   self._num_outputs = len(fdef.signature.output_arg)
   self._ops = operations
   self._func_outputs = func_outputs
   if (isinstance(func_outputs, (ops.Tensor, type(None))) or
       ag_core.isnode(func_outputs)):
     self._returns = [func_outputs]
   else:
     self._returns = list(func_outputs)
   self._returns_to_fedf_outputs = func_outputs_to_fdef_outputs
   self._output_shapes = output_shapes
Example #16
 def __init__(self, input_placeholders, extra_inputs, fdef, graph, operations,
              func_outputs, func_outputs_to_fdef_outputs, output_shapes):
   assert len(input_placeholders) == len(fdef.signature.input_arg), "%s %s" % (
       len(input_placeholders), len(fdef.signature.input_arg))
   self._input_placeholders = input_placeholders
   self._extra_inputs = list(extra_inputs)
   self._graph = graph
   self._has_backprop = False
   self._func_name = fdef.signature.name
   self._fdef = _DefinedFunction(fdef)
   self._num_outputs = len(fdef.signature.output_arg)
   self._ops = operations
   self._func_outputs = func_outputs
   if (isinstance(func_outputs, (ops.Tensor, type(None))) or
       ag_core.isnode(func_outputs)):
     self._returns = [func_outputs]
   else:
     self._returns = list(func_outputs)
   self._returns_to_fedf_outputs = func_outputs_to_fdef_outputs
   self._output_shapes = output_shapes
Example #17
def find_progenitors(self, args):
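    # For every node argument, split its progenitors into forward and reverse
    # sets: active forward progenitors are collected per progenitor, and reverse
    # progenitors intersecting the globally active set mark the argument as a
    # parent of the new node.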
    argvals = list(args)
    parents = []
    rev_progenitors = set()
    fwd_progenitors = defaultdict(list)
    for argnum, arg in enumerate(args):
        if ac.isnode(arg):
            argvals[argnum] = arg.value
            if argnum in self.zero_vjps: continue

            arg_rev_progenitors, arg_fwd_progenitors = split_progenitors(arg.progenitors)
            for progenitor in arg_fwd_progenitors:
                if active_forward_progenitors.get(progenitor, False):
                    fwd_progenitors[progenitor].append((argnum, arg))

            reverse = arg_rev_progenitors & ac.active_progenitors
            if reverse:
                parents.append((argnum, arg))
                rev_progenitors.update(reverse)

    return argvals, parents, rev_progenitors, fwd_progenitors
Example #18
  def __call__(self, *args):
    """Executes the passed function in eager mode."""
    tensor_inputs = [
        x for x in nest.flatten(args)
        if isinstance(x, (tensor.Tensor, ops.Tensor,
                          tensor.LazyZero)) or ag_core.isnode(x)
    ]
    if tape.should_record(tensor_inputs) or any(
        tape.any_tape_has(t) for t in self._extra_inputs):
      if not self._has_backprop:
        self._compute_backprop()
      return self._backprop_call(tensor_inputs)

    if context.in_graph_mode():
      g = ops.get_default_graph()
      g._add_function(self._fdef)  # pylint: disable=protected-access
      signature = self._fdef.definition.signature
      args = list(tensor_inputs) + self._extra_inputs
      op = g.create_op(
          signature.name, [ops.convert_to_tensor(x) for x in args],
          [dtypes.DType(x.type) for x in signature.output_arg],
          op_def=signature,
          name="FunctionCall",
          compute_shapes=False)
      result = op.outputs
      for i, s in enumerate(self._output_shapes):
        result[i].set_shape(s)
    else:
      tensor_inputs = [
          x.tensor() if isinstance(x, tensor.LazyZero) else x
          for x in tensor_inputs
      ]
      result = execute.execute(
          self._func_name,
          num_outputs=self._num_outputs,
          inputs=tensor_inputs + self._extra_inputs)

    return self._build_call_outputs(self._returns, result)
Example #19
  def __call__(self, *args):
    """Executes the passed function in eager mode."""
    tensor_inputs = [
        x for x in nest.flatten(args)
        if isinstance(x, (tensor.Tensor, ops.Tensor,
                          tensor.LazyZero)) or ag_core.isnode(x)
    ]
    if tape.should_record(tensor_inputs) or any(
        tape.any_tape_has(t) for t in self._extra_inputs):
      if not self._has_backprop:
        self._compute_backprop()
      return self._backprop_call(tensor_inputs)

    if context.in_graph_mode():
      g = ops.get_default_graph()
      g._add_function(self._fdef)  # pylint: disable=protected-access
      signature = self._fdef.definition.signature
      args = list(tensor_inputs) + self._extra_inputs
      op = g.create_op(
          signature.name, [ops.convert_to_tensor(x) for x in args],
          [dtypes.DType(x.type) for x in signature.output_arg],
          op_def=signature,
          name="FunctionCall",
          compute_shapes=False)
      result = op.outputs
      for i, s in enumerate(self._output_shapes):
        result[i].set_shape(s)
    else:
      tensor_inputs = [
          x.tensor() if isinstance(x, tensor.LazyZero) else x
          for x in tensor_inputs
      ]
      result = execute.execute(
          self._func_name,
          num_outputs=self._num_outputs,
          inputs=tensor_inputs + self._extra_inputs)

    return self._build_call_outputs(self._returns, result)
Example #20
    def decorated(*args, **kwargs):
        """Decorated function with custom gradient."""
        input_tensors = [
            _watch_value_from_tape(x) for x in args
            if isinstance(x, (_tensor.Tensor,
                              tf_ops.Tensor)) or ag_core.isnode(x)
        ]
        result, grad_fn = f(*args, **kwargs)
        result_size = len(result) if isinstance(result, (list, tuple)) else 1

        # TODO(apassos): naive uses of custom_gradient will not get the correct
        # second derivative this way if they capture any output tensors. Change the
        # signature of custom_gradient.
        def actual_grad_fn(*outputs):
            outputs = outputs[result_size:]
            return grad_fn(*outputs)

        flat_result = nest.flatten(result)
        flat_result = [ag_core.getval(x) for x in flat_result]
        flat_result = tape.record_operation(flat_result, input_tensors, [],
                                            actual_grad_fn)
        flat_result = list(flat_result)
        return nest.pack_sequence_as(structure=result,
                                     flat_sequence=flat_result)
Example #21
  def decorated(*args, **kwargs):
    """Decorated function with custom gradient."""
    input_tensors = [_watch_value_from_tape(x) for x in args
                     if isinstance(x, (_tensor.Tensor, tf_ops.Tensor))
                     or ag_core.isnode(x)]
    result, grad_fn = f(*args, **kwargs)
    result_size = len(result) if isinstance(result, (list, tuple)) else 1

    # TODO(apassos): naive uses of custom_gradient will not get the correct
    # second derivative this way if they capture any output tensors. Change the
    # signature of custom_gradient.
    def actual_grad_fn(*outputs):
      outputs = outputs[result_size:]
      return grad_fn(*outputs)

    flat_result = nest.flatten(result)
    flat_result = [ag_core.getval(x) for x in flat_result]
    flat_result = tape.record_operation(
        flat_result,
        input_tensors,
        [],
        actual_grad_fn)
    flat_result = list(flat_result)
    return nest.pack_sequence_as(structure=result, flat_sequence=flat_result)
Example #22
def should_record(tensors):
  """Returns true if any tape in the stach watches any of these tensors."""
  return any(ag_core.isnode(x) for x in tensors)
Example #23
def fwd_grad_make_sequence(argnum, g, ans, gvs, vs, args, kwargs):
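    # args[0] is the sequence type, so the tangent for argument argnum lands at
    # index argnum - 1 of the zero-filled element list before the sequence is
    # rebuilt.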
    typ, elts = args[0], args[1:]
    zeros = list(elt.vspace.zeros() if isnode(elt) else vspace(elt).zeros()
                 for elt in elts)
    zeros[argnum - 1] = g
    return ct.make_sequence(typ, *zeros)
Example #24
def _watch_with_tape(tape, tensor):
  """Wraps a watched Tensor and keeps track of it in the implicit tape."""
  w = _watch_with_tape_internal(tape, tensor)
  if ag_core.isnode(tape):
    tape.value.tensors[ops.tensor_id(tensor)] = w
  return w
Example #25
 def __init__(self, value):
   super(EagerList, self).__init__(value)
   for v in value:
     assert not ag_core.isnode(v)
Example #26
def should_record(tensors):
  """Returns true if any tape in the stach watches any of these tensors."""
  return any(ag_core.isnode(x) for x in tensors)
Example #27
def _watch_with_tape(tape, tensor):
  """Wraps a watched Tensor and keeps track of it in the implicit tape."""
  w = _watch_with_tape_internal(tape, tensor)
  if ag_core.isnode(tape):
    tape.value.tensors[tensor_id(tensor)] = w
  return w
Example #28
def fwd_grad_sequence_extend_right(argnum, g, ans, gvs, vs, args, kwargs):
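    # Same zero-tangent pattern without a type offset: every argument gets a
    # zero of its vspace and the incoming tangent g replaces position argnum.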
    zeros = list(arg.vspace.zeros() if isnode(arg) else vspace(arg).zeros()
                 for arg in args)
    zeros[argnum] = g
    return ct.sequence_extend_right(*zeros)
Example #29
 def __init__(self, value):
   super(_EagerList, self).__init__(value)
   for v in value:
     assert not ag_core.isnode(v)
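
All of the snippets above rely on the same small slice of the old autograd core
API used by these projects: ag_core.isnode(x) (or bare isnode(x)) reports
whether x has been wrapped as a traced Node, ag_core.getval(x) unwraps it, and
a Node carries a vspace describing the shape of its gradient. A minimal sketch
of the recurring zero-tangent idiom from Examples #5, #23 and #28 follows; the
helper name zeros_like_arg is invented here for illustration, and the import
line is an assumption about how the alias used in the examples is obtained:

import autograd.core as ac

def zeros_like_arg(arg):
    # A Node already knows its vector space, whereas a plain value has to be
    # looked up through ac.vspace before a matching zero tangent can be built.
    if ac.isnode(arg):
        return arg.vspace.zeros()
    return ac.vspace(arg).zeros()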