예제 #1
0
def _OpsBetween(graph, to_ops, from_ops):
    """Build the list of operations between two lists of Operations.

    Args:
      graph: a Graph. `graph._last_id + 1` sizes the per-op "reached" table,
        so every op id indexed into it must be <= `graph._last_id`; an op in
        `to_ops` with a larger `_id` raises IndexError (presumably all ops
        are expected to belong to `graph` -- confirm with callers).
      to_ops: list of Operations.
      from_ops: list of Operations.

    Returns:
      The list of operations between "from_ops" and "to_ops", sorted by
      decreasing operation id. This list contains all elements of to_ops.

      TODO(touts): Think about returning an empty list if from_ops are not
      reachable from to_ops.  Presently it returns to_ops in that case.
    """
    # One boolean per operation id; True means the op has been reached
    # from the outputs of "from_ops".
    reached_ops = [False] * (graph._last_id + 1)
    # We only care to reach up to "to_ops", so pre-mark them as reached to
    # keep the forward walk from recursing past them.
    for op in to_ops:
        reached_ops[op._id] = True
    gradients_impl._MarkReachedOps(from_ops, reached_ops)
    # NOTE(review): presumably collects the ops reachable backwards from
    # "to_ops" that are marked in reached_ops (helper lives in
    # gradients_impl -- confirm there).
    between_ops = gradients_impl._GatherInputs(to_ops, reached_ops)
    # Highest id first. list.sort is stable, so reverse=True on the id gives
    # the same order as sorting by the negated id.
    between_ops.sort(key=lambda op: op._id, reverse=True)
    return between_ops
예제 #2
0
def _OpsBetween(graph, to_ops, from_ops):
  """Build the list of operations between two lists of Operations.

  Args:
    graph: a Graph. `graph._last_id + 1` sizes the per-op "reached" table,
      so every op id indexed into it must be <= `graph._last_id`; an op in
      `to_ops` with a larger `_id` raises IndexError (presumably all ops are
      expected to belong to `graph` -- confirm with callers).
    to_ops: list of Operations.
    from_ops: list of Operations.

  Returns:
    The list of operations between "from_ops" and "to_ops", sorted by
    decreasing operation id. This list contains all elements of to_ops.

    TODO(touts): Think about returning an empty list if from_ops are not
    reachable from to_ops.  Presently it returns to_ops in that case.
  """
  # One boolean per operation id; True means the op has been reached
  # from the outputs of "from_ops".
  reached_ops = [False] * (graph._last_id + 1)
  # We only care to reach up to "to_ops", so pre-mark them as reached to
  # keep the forward walk from recursing past them.
  for op in to_ops:
    reached_ops[op._id] = True
  gradients_impl._MarkReachedOps(from_ops, reached_ops)
  # NOTE(review): presumably collects the ops reachable backwards from
  # "to_ops" that are marked in reached_ops (helper lives in gradients_impl
  # -- confirm there).
  between_ops = gradients_impl._GatherInputs(to_ops, reached_ops)
  # Highest id first. list.sort is stable, so reverse=True on the id gives
  # the same order as sorting by the negated id.
  between_ops.sort(key=lambda op: op._id, reverse=True)
  return between_ops
예제 #3
0
def _OpsBetween(to_ops, from_ops):
  """Build the list of operations between two lists of Operations.

  Args:
    to_ops: list of Operations. It is iterated here and then handed to
      `_GatherInputs`, so pass a re-iterable sequence rather than a one-shot
      iterator. Elements must be hashable (they are stored in a set).
    from_ops: list of Operations.

  Returns:
    The list of operations between "from_ops" and "to_ops", sorted by
    decreasing operation id. This list contains all elements of to_ops.

    TODO(touts): Think about returning an empty list if from_ops are not
    reachable from to_ops.  Presently it returns to_ops in that case.
  """
  # Ops reached from the outputs of "from_ops". Seeding the set with
  # "to_ops" marks them as already reached: we only care to reach up to
  # "to_ops", so the forward walk should not recurse past them.
  reached_ops = set(to_ops)
  gradients_impl._MarkReachedOps(from_ops, reached_ops)
  # NOTE(review): presumably collects the ops reachable backwards from
  # "to_ops" that are present in reached_ops (helper lives in gradients_impl
  # -- confirm there).
  between_ops = gradients_impl._GatherInputs(to_ops, reached_ops)
  # Highest id first. list.sort is stable, so reverse=True on the id gives
  # the same order as sorting by the negated id.
  between_ops.sort(key=lambda op: op._id, reverse=True)
  return between_ops
예제 #4
0
def _OpsBetween(to_ops, from_ops):
  """Build the list of operations between two lists of Operations.

  Args:
    to_ops: list of Operations. It is iterated here and then handed to
      `_GatherInputs`, so pass a re-iterable sequence rather than a one-shot
      iterator. Elements must be hashable (they are stored in a set).
    from_ops: list of Operations.

  Returns:
    The list of operations between "from_ops" and "to_ops", sorted by
    decreasing operation id. This list contains all elements of to_ops.

    TODO(touts): Think about returning an empty list if from_ops are not
    reachable from to_ops.  Presently it returns to_ops in that case.
  """
  # Ops reached from the outputs of "from_ops". Seeding the set with
  # "to_ops" marks them as already reached: we only care to reach up to
  # "to_ops", so the forward walk should not recurse past them.
  reached_ops = set(to_ops)
  gradients_impl._MarkReachedOps(from_ops, reached_ops)
  # NOTE(review): presumably collects the ops reachable backwards from
  # "to_ops" that are present in reached_ops (helper lives in gradients_impl
  # -- confirm there).
  between_ops = gradients_impl._GatherInputs(to_ops, reached_ops)
  # Highest id first. list.sort is stable, so reverse=True on the id gives
  # the same order as sorting by the negated id.
  between_ops.sort(key=lambda op: op._id, reverse=True)
  return between_ops