Example #1
0
def _reroute_sgv_inputs(sgv0, sgv1, mode):
  """Re-route all the inputs of two subgraphs.

  Both arguments are normalized into subgraph views and checked to belong to
  the same graph. Their input tensors are then re-routed according to `mode`.
  Ops inside either view, and ops consuming their pass-through tensors, are
  the only ops allowed to be modified.

  Args:
    sgv0: the first subgraph to have its inputs swapped. This argument is
      converted to a subgraph using the same rules as the function
      subgraph.make_view.
    sgv1: the second subgraph to have its inputs swapped. This argument is
      converted to a subgraph using the same rules as the function
      subgraph.make_view.
    mode: reroute mode, see _reroute_ts(...).
  Returns:
    A tuple `(sgv0, sgv1)` of subgraph views with their inputs swapped.
      Note that the function argument sgv0 and sgv1 are also modified in place.
  Raises:
    StandardError (as documented upstream): if sgv0 or sgv1 cannot be
      converted to a SubGraphView by subgraph.make_view.
  """
  view0, view1 = _subgraph.make_view(sgv0), _subgraph.make_view(sgv1)
  _util.check_graphs(view0, view1)
  modifiable_ops = view0.ops + view1.ops
  # Consumers of pass-through tensors must be rewirable as well.
  for view in (view0, view1):
    modifiable_ops += _util.get_consuming_ops(view.passthroughs)
  _reroute_ts(view0.inputs, view1.inputs, mode, can_modify=modifiable_ops)
  _reroute_sgv_remap(view0, view1, mode)
  return view0, view1
Example #2
0
def _reroute_sgv_inputs(sgv0, sgv1, mode):
    """Re-route all the inputs of two subgraphs.

    Args:
      sgv0: the first subgraph to have its inputs swapped. This argument is
        converted to a subgraph using the same rules as the function
        subgraph.make_view.
      sgv1: the second subgraph to have its inputs swapped. This argument is
        converted to a subgraph using the same rules as the function
        subgraph.make_view.
      mode: reroute mode, see _reroute_ts(...).

    Returns:
      A tuple `(sgv0, sgv1)` of subgraph views with their inputs swapped.
      Note that the function arguments sgv0 and sgv1 are also modified in
      place.

    Raises:
      StandardError (as documented upstream): if sgv0 or sgv1 cannot be
        converted to a SubGraphView by subgraph.make_view.
    """
    first = _subgraph.make_view(sgv0)
    second = _subgraph.make_view(sgv1)
    _util.check_graphs(first, second)

    # Everything inside both views may be rewired ...
    allowed = first.ops + second.ops
    # ... and so may whatever consumes their pass-through tensors.
    allowed += _util.get_consuming_ops(first.passthroughs)
    allowed += _util.get_consuming_ops(second.passthroughs)

    _reroute_ts(first.inputs, second.inputs, mode, can_modify=allowed)
    _reroute_sgv_remap(first, second, mode)
    return first, second
    def _do_action(self, src_ops):
        """Add swapin and swapout ops for ops that are reachable from `src_ops`.

        The graph is walked breadth-first starting at `src_ops`. Successors
        are the consumers of each op's output tensors, excluding ops in
        `self._grad_ops`, so the walk never steps into the gradient ops.
        Every visited op is passed to `self._insert_swap_nodes`. The walk
        stops as soon as `self._swapped_max_tensors()` returns True.

        Args:
          src_ops: a list of `tf.Operation`
        """
        import collections

        open_queue = collections.deque()
        # Mirror of the ops currently waiting in `open_queue`. Testing
        # membership on the queue itself was a linear scan per enqueue.
        queued = set()
        closed_set = set()

        for op in src_ops:
            # Enqueue each seed once so no op gets swap nodes inserted twice.
            if op not in queued:
                open_queue.append(op)
                queued.add(op)

        while open_queue:
            src_op = open_queue.popleft()
            queued.discard(src_op)

            # get next ops before the graph is changed
            next_ops = set()
            for t in src_op.outputs:
                frontier_ops = set(util.get_consuming_ops(t))
                next_ops |= frontier_ops - self._grad_ops

            # do action for src_op
            self._insert_swap_nodes(src_op)
            if self._swapped_max_tensors():
                return

            for op in next_ops:
                if op in closed_set or op in queued:
                    continue
                open_queue.append(op)
                queued.add(op)

            closed_set.add(src_op)
    def _get_seed_ops(self):
        """Return a list of `tf.Operation` used as a starting point for LMS
        to traverse the graph.

        Explicit starting points are used when configured:
        - ops whose names match `"^" + self._starting_scope` (a prefix
          regex), and/or
        - ops whose names fully match one of `self._starting_op_names`
          (each name is used as a regex anchored with `^...$`).

        If neither is configured, seeds are discovered automatically. The
        candidates are the non-gradient ops that have at least one output
        tensor consumed by an op in `self._grad_ops`. The returned seeds are
        the candidates whose (non-inclusive) forward walk within the
        non-gradient ops reaches the largest number of other candidates.
        Candidates that reach no other candidate are never returned.

        Returns:
          A list of `tf.Operation` (possibly empty).

        Raises:
          ValueError: if a starting scope or a starting op name is configured
            but matches no operation.
        """
        graph_ops = ge.make_list_of_op(self._graph)
        explicit_seeds = set()

        if self._starting_scope:
            in_scope = set(
                ge.filter_ops_from_regex(
                    graph_ops, "^{}".format(self._starting_scope)))
            if not in_scope:
                raise ValueError('No operations were found in starting '
                                 'scope {}.'.format(self._starting_scope))
            explicit_seeds.update(in_scope)

        for name in (self._starting_op_names or ()):
            matched = set(
                ge.filter_ops_from_regex(graph_ops, "^{}$".format(name)))
            if not matched:
                raise ValueError('No starting operation was found with '
                                 'name {}.'.format(name))
            explicit_seeds.update(matched)

        if explicit_seeds:
            return list(explicit_seeds)

        # Automatic discovery: forward ops that directly feed a gradient op.
        forward_ops = [
            op for op in self._graph.get_operations()
            if op not in self._grad_ops
        ]
        boundary_ops = set()
        for op in forward_ops:
            if any(set(util.get_consuming_ops(t)) & self._grad_ops
                   for t in op.outputs):
                boundary_ops.add(op)

        # Rank each boundary op by how many other boundary ops it reaches.
        coverage = {}
        for op in boundary_ops:
            reached = set(
                ge.get_forward_walk_ops(
                    op, within_ops=forward_ops, inclusive=False)) & boundary_ops
            if reached:
                coverage[op] = len(reached)

        if not coverage:
            return []
        best = max(coverage.values())
        # seed ops will cover most of the forward ops
        return [op for op, count in coverage.items() if count == best]
Example #5
0
def get_forward_walk_ops(seed_ops, inclusive=True, within_ops=None,
                         stop_at_ts=(), control_outputs=None):
  """Do a forward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the forward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the consumers of those tensors.
    inclusive: if True the given seed_ops are also part of the result.
    within_ops: an iterable of tf.Operation within which the search is
      restricted. If within_ops is None, the search is performed within
      the whole graph.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_outputs: a util.ControlOutputs instance or None.
      If not None, it will be used while walking the graph forward.
  Returns:
    A Python list (without duplicates) of all the tf.Operation ahead of
    seed_ops. Ops are appended wave by wave as the walk proceeds. When
    `inclusive` is True the seed ops come first.
  Raises:
    TypeError: if seed_ops or within_ops cannot be converted to a list of
      tf.Operation.
  """
  _, control_outputs = check_cios(False, control_outputs)
  if not util.is_iterable(seed_ops):
    seed_ops = [seed_ops]
  # Materialize so that sets and generators can be indexed below.
  seed_ops = list(seed_ops)
  if not seed_ops:
    return []
  if isinstance(seed_ops[0], tf_ops.Tensor):
    ts = util.make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = util.get_consuming_ops(ts)
  else:
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

  seed_ops = frozenset(seed_ops)
  stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
  if within_ops:
    within_ops = util.make_list_of_op(within_ops, allow_graph=False)
    within_ops = frozenset(within_ops)
    seed_ops &= within_ops

  def is_within(op):
    return within_ops is None or op in within_ops

  result = list(seed_ops)
  # Set mirror of `result` for O(1) membership tests. Testing `in result`
  # on the list made each visit linear in the number of ops found so far.
  visited = set(result)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in op.outputs:
        if new_t in stop_at_ts:
          continue
        for new_op in new_t.consumers():
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
      if control_outputs is not None:
        for new_op in control_outputs.get(op):
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
    util.concatenate_unique(result, new_wave)
    visited.update(new_wave)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result
    def _find_new_src_op(self, original_op):
        """Find a set of new operations to swap out their output tensors.

        This method is used when `original_op` produces a tensor that is
        consumed by a backward op whose order is negative. In this case the
        tensor might be consumed immediately by the backward op, depending on
        the TensorFlow runtime, so there is no need to swap it out.

        Starting from `original_op`, the graph is walked breadth-first along
        output consumers. An op is collected when at least one of its
        consumers has a topological order greater than
        `self._topo_sort.bw_starting_order`. The walk does not continue
        through those consumers, only through the remaining ones.

        Args:
          original_op: a `tf.Operation`.

        Returns:
          A set of `tf.Operation`.
        """
        import collections

        src_ops = set()
        open_queue = collections.deque([original_op])
        # Mirror of `open_queue` for O(1) membership tests.
        queued = {original_op}
        closed_set = set()

        while open_queue:
            src_op = open_queue.popleft()
            queued.discard(src_op)

            frontier_ops = set()
            for t in src_op.outputs:
                frontier_ops |= set(util.get_consuming_ops(t))
            # consumers scheduled after the start of the backward phase
            has_order_ops = {
                op
                for op in frontier_ops
                if (self._topo_sort.get_order(op) >
                    self._topo_sort.bw_starting_order)
            }
            if has_order_ops:
                src_ops.add(src_op)

            for op in frontier_ops - has_order_ops:
                if op in closed_set or op in queued:
                    continue
                open_queue.append(op)
                queued.add(op)

            closed_set.add(src_op)
        return src_ops
Example #7
0
    def _build_dependency_dict(self):
        """Build a dictionary of dependencies among nodes.

        The graph is walked breadth-first from `self._seed_ops` along output
        consumers. For every visited op, this records the ops it depends on:
        its control inputs plus the ops generating its data inputs. That set
        is restricted to `reachable_ops`, the ops returned by
        `ge.get_walks_intersection_ops(seed_ops, grad_ops)`.

        Returns:
          A dict mapping each visited `tf.Operation` to the set of
          `tf.Operation` it depends on.
        """
        import collections

        open_queue = collections.deque()
        # Mirror of `open_queue` for O(1) membership tests.
        queued = set()
        closed_set = set()

        dep_dict = {}
        for op in self._seed_ops:
            if op not in queued:
                open_queue.append(op)
                queued.add(op)

        reachable_ops = set(
            ge.get_walks_intersection_ops(list(self._seed_ops),
                                          list(self._grad_ops)))

        # traversal in the fw phase
        while open_queue:
            src_op = open_queue.popleft()
            queued.discard(src_op)

            # do action for src_op
            dep_ops = set(src_op.control_inputs)
            for t in src_op.inputs:
                dep_ops |= set(util.get_generating_ops(t))
            # Filter once, after all inputs have been collected. Filtering
            # inside the loop skipped ops without data inputs, leaving their
            # control inputs unfiltered.
            dep_ops &= reachable_ops
            dep_dict[src_op] = dep_ops

            next_ops = set()
            for t in src_op.outputs:
                next_ops |= set(util.get_consuming_ops(t))
            for op in next_ops:
                if op in closed_set or op in queued:
                    continue
                open_queue.append(op)
                queued.add(op)

            closed_set.add(src_op)

        return dep_dict
Example #8
0
def get_forward_walk_ops(seed_ops,
                         inclusive=True,
                         within_ops=None,
                         stop_at_ts=(),
                         control_outputs=None):
  """Do a forward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the forward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the consumers of those tensors.
    inclusive: if True the given seed_ops are also part of the result.
    within_ops: an iterable of tf.Operation within which the search is
      restricted. If within_ops is None, the search is performed within
      the whole graph.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_outputs: a util.ControlOutputs instance or None.
      If not None, it will be used while walking the graph forward.
  Returns:
    A Python list (without duplicates) of all the tf.Operation ahead of
    seed_ops. Ops are appended wave by wave as the walk proceeds. When
    `inclusive` is True the seed ops come first.
  Raises:
    TypeError: if seed_ops or within_ops cannot be converted to a list of
      tf.Operation.
  """
  _, control_outputs = check_cios(False, control_outputs)
  if not util.is_iterable(seed_ops):
    seed_ops = [seed_ops]
  # Materialize so that sets and generators can be indexed below.
  seed_ops = list(seed_ops)
  if not seed_ops:
    return []
  if isinstance(seed_ops[0], tf_ops.Tensor):
    ts = util.make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = util.get_consuming_ops(ts)
  else:
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

  seed_ops = frozenset(seed_ops)
  stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
  if within_ops:
    within_ops = util.make_list_of_op(within_ops, allow_graph=False)
    within_ops = frozenset(within_ops)
    seed_ops &= within_ops

  def is_within(op):
    return within_ops is None or op in within_ops

  result = list(seed_ops)
  # Set mirror of `result` for O(1) membership tests. Testing `in result`
  # on the list made each visit linear in the number of ops found so far.
  visited = set(result)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in op.outputs:
        if new_t in stop_at_ts:
          continue
        for new_op in new_t.consumers():
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
      if control_outputs is not None:
        for new_op in control_outputs.get(op):
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
    util.concatenate_unique(result, new_wave)
    visited.update(new_wave)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result
Example #9
0
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
    """Compute gradients of `ys` w.r.t. `xs` using checkpoints and swapping.

    A set of forward tensors is used as checkpoints. Each checkpoint with at
    most one backward consumer is swapped out (`_add_swapout`) and swapped
    back in (`_add_swapin`) before use. Checkpoints with more backward
    consumers are kept as plain stop-gradient checkpoints. The forward
    sub-graphs between checkpoints are copied and re-executed to compute
    the gradients.

    Args:
      ys: a tensor or list of tensors to differentiate.
      xs: a tensor or list of tensors to differentiate with respect to.
      grad_ys: optional initial gradients for `ys`, forwarded to
        `tf_gradients`.
      checkpoints: a list of tensors to use as checkpoints. It is intersected
        with the candidate forward tensors. Any non-list value (the default
        is the string 'collection') triggers automatic selection of
        bottleneck tensors.
      **kwargs: forwarded to every `tf_gradients` call.

    Returns:
      A list with one entry per element of `xs`: its accumulated gradient,
      or None.

    Raises:
      Exception: if no bottleneck tensors can be found automatically, or if
        no checkpoint remains after removing `xs` and `ys`.
    """
    print("-------------------------------")
    debug_print("Editing model for OME")
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True)

    for index, op in enumerate(bwd_ops):
        debug_print("bwd_ops: [{}] :{}".format(index, op.name), 1)

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    for index, op in enumerate(fwd_ops):
        debug_print("fwd_ops: [{}] : {}".format(index, op.name), 1)

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if op not in xs_ops]
    fwd_ops = [op for op in fwd_ops if '/assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/Assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/read' not in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    # construct list of tensors to checkpoint during forward pass, if not
    # given as input
    if type(checkpoints) is not list:
        # remove very small tensors and some weird ops
        def fixdims(t):
            # tf.Dimension values are not compatible with int, convert manually
            try:
                return [int(e if e.value is not None else 64) for e in t]
            except Exception:
                return [0]  # unknown shape

        ts_all = [
            t for t in ts_all
            if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE
        ]
        ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
        ts_all = [t for t in ts_all if 'entropy' not in t.name]
        ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
        ts_all = [t for t in ts_all if 'Switch' not in t.name]
        ts_all = [t for t in ts_all if 'dropout' not in t.name]
        # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
        ts_all = [t for t in ts_all if 'Cast' not in t.name]

        # filter out all tensors that are inputs of the backward graph
        # NOTE(review): this rebinds `bwd_ops` to the captured gradient ops.
        # The checkpoint disconnection below therefore intersects with these
        # ops here, but with the backward walk from `ys` when a checkpoint
        # list is supplied. Confirm that difference is intended.
        with util.capture_ops() as bwd_ops:
            tf_gradients(ys, xs, grad_ys, **kwargs)

        bwd_inputs = [t for op in bwd_ops for t in op.inputs]
        # list of tensors in forward graph that is in input to bwd graph
        ts_filtered = list(set(bwd_inputs).intersection(ts_all))
        debug_print("Using tensors {}".format(ts_filtered), 1)

        # try two slightly different ways of getting bottlenecks tensors
        # to checkpoint
        for ts in [ts_filtered, ts_all]:

            # get all bottlenecks in the graph
            bottleneck_ts = []
            for t in ts:
                b = set(
                    ge.get_backward_walk_ops(t.op,
                                             inclusive=True,
                                             within_ops=fwd_ops))
                f = set(
                    ge.get_forward_walk_ops(t.op,
                                            inclusive=False,
                                            within_ops=fwd_ops))
                # check that there are not shortcuts
                b_inp = set([inp for op in b
                             for inp in op.inputs]).intersection(ts_all)
                f_inp = set([inp for op in f
                             for inp in op.inputs]).intersection(ts_all)
                if not set(b_inp).intersection(
                        f_inp) and len(b_inp) + len(f_inp) >= len(ts_all):
                    bottleneck_ts.append(t)  # we have a bottleneck!
                else:
                    debug_print(
                        "Rejected bottleneck candidate and ops {}".format(
                            [t] + list(set(ts_all) - set(b_inp) - set(f_inp))),
                        2)

            # success? or try again without filtering?
            if len(bottleneck_ts) >= np.sqrt(
                    len(ts_filtered)):  # yes, enough bottlenecks found!
                break

        if not bottleneck_ts:
            raise Exception(
                'unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".'
            )

        # sort the bottlenecks
        bottlenecks_sorted_lists = tf_toposort(bottleneck_ts,
                                               within_ops=fwd_ops)
        sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts]

        # save an approximately optimal number ~ sqrt(N)
        N = len(ts_filtered)
        if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
            checkpoints = sorted_bottlenecks
        else:
            step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
            checkpoints = sorted_bottlenecks[step::step]

    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection happened and checkpoints is list of nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: {}".format(checkpoints), 1)
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print(
            "Warning, some input nodes are also checkpoint nodes: {}".format(
                xs_intersect_checkpoints), 2)
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print(
        "ys: {}, checkpoints: {}, intersect: {}".format(
            ys, checkpoints, ys_intersect_checkpoints), 1)
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print(
            "Warning, some output nodes are also checkpoints nodes: {}".format(
                format_ops(ys_intersect_checkpoints)), 2)

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    if not checkpoints:
        raise Exception('no checkpoints nodes found or given as input! ')

    debug_print(
        "Select {} nodes to checkpoint nodes.".format(len(checkpoints)), 0)

    # disconnect dependencies between checkpointed tensors
    checkpoints_disconnected = {}
    bwd_ops_set = set(bwd_ops)  # loop-invariant
    for x in checkpoints:
        frontier_ops = set(graph_util.get_consuming_ops(x.op.outputs))
        debug_print("my frontier ops: {}".format(frontier_ops), 1)

        bw_frontier_ops = frontier_ops & bwd_ops_set
        debug_print("my bw frontier ops: {}".format(bw_frontier_ops), 1)

        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
        else:
            grad_node = tf.stop_gradient(x)

        if len(bw_frontier_ops) > 1:
            # Only checkpoints with a single backward consumer are swapped.
            # The others still need a disconnected boundary tensor: every
            # entry of `checkpoints` is looked up in
            # `checkpoints_disconnected` and `d_checkpoints` below. Skipping
            # them outright made those lookups raise KeyError, so they are
            # kept as plain (non-swapped) checkpoints instead.
            checkpoints_disconnected[x] = grad_node
            continue

        swapout_op = _add_swapout(grad_node.op, grad_node.op.outputs)
        swapin_op = _add_swapin(swapout_op, bw_frontier_ops,
                                grad_node.op.outputs)
        checkpoints_disconnected[x] = swapin_op
        # control dependency -> swap_in
        my_add_control_inputs(x, bw_frontier_ops, swapin_op)

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints,
                                    within_ops=fwd_ops)
    debug_print(
        "Found {} ops to copy within fwd_ops {}, seed {}, stop_at {}".format(
            len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints), 1)
    debug_print("ops_to_copy = {}".format(ops_to_copy), 1)
    debug_print("Processing list {}".format(ys), 1)
    copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    debug_print("Copied {} to {}".format(ops_to_copy, copied_ops), 1)
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print(
        "Rewired {} in place of {} restricted to {}".format(
            checkpoints_disconnected.values(), checkpoints_disconnected.keys(),
            copied_ops), 1)

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys,
                      xs=boundary + xs,
                      grad_ys=grad_ys,
                      **kwargs)
    debug_print("Got gradients {}".format(dv), 1)
    debug_print("for {}".format(copied_ys), 1)
    debug_print("with respect to {}".format(boundary + xs), 1)

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {
        r: dr
        for r, dr in zip(checkpoints_disconnected.keys(),
                         dv[:len(checkpoints_disconnected)])
    }
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list {}".format(ts), 1)
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [
            checkpoints_disconnected[r] for r in checkpoints_other
        ]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoints nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print(
            "Found {} ops to copy within {}, seed {}, stop_at {}".format(
                len(ops_to_copy), fwd_ops, [r.op for r in ts],
                checkpoints_other), 1)
        debug_print("ops_to_copy = {}".format(ops_to_copy), 1)
        if not ops_to_copy:  # we're done!
            break
        copied_sgv, info = ge.copy_with_input_replacements(
            ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        debug_print("Copied {} to {}".format(ops_to_copy, copied_ops), 1)
        ge.reroute_ts(checkpoints_disconnected_other,
                      checkpoints_other,
                      can_modify=copied_ops)
        debug_print(
            "Rewired {} in place of {} restricted to {}".format(
                checkpoints_disconnected_other, checkpoints_other, copied_ops),
            1)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops,
                          **kwargs)
        debug_print("Got gradients {}".format(dv), 1)
        debug_print("for {}".format(boundary), 1)
        debug_print(
            "with respect to {}".format(checkpoints_disconnected_other + xs),
            1)
        debug_print(
            "with boundary backprop substitutions {}".format(
                substitute_backprops), 1)

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        def _unsparsify(x):
            # Convert tf.IndexedSlices to a dense tensor so it can be summed.
            if not isinstance(x, tf.IndexedSlices):
                return x
            assert x.dense_shape is not None, "memory_saving_gradients encountered sparse gradients of unknown shape"
            indices = x.indices
            while indices.shape.ndims < x.values.shape.ndims:
                indices = tf.expand_dims(indices, -1)
            return tf.scatter_nd(indices, x.values, x.dense_shape)

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])

    return d_xs
    def _do_chain_rule(self, fw_op, bw_op, lower_b, upper_b):
        """Find a control dependency operation using chain rules.

        Go down along the forward phase to find corresponding backward ops
        as candidates for control dependency ops. The forward graph is walked
        level by level from `fw_op`. After each complete level, `lower_b` and
        `upper_b` are decremented. Once `lower_b` has reached 0, consumers
        of the current level's ops that are in `self._grad_ops` are
        collected. A candidate must have a topological order strictly between
        those of `fw_op` and `bw_op`, and its name must not contain
        "/cond/". The walk stops at the first level yielding a candidate, or
        when `upper_b` reaches 0 or drops below `lower_b`. If `bw_op` is too
        close to the start of the backward phase, the search is delegated to
        `self._do_direct_order`.

        Args:
          fw_op: a `tf.Operation` that has a tensor swapped out.
          bw_op: a `tf.Operation` that consumes a tensor swapped in.
          lower_b: an `integer`. The distance in the graph between
            `fw_op` and a forward operation that has corresponding backward
            ops as candidates for control dependency ops must be greater than
            `lower_b`.
          upper_b: an `integer`. The distance in the graph between
            `fw_op` and a forward operation that has corresponding backward
            ops as candidates for control dependency ops must be smaller than
            `upper_b`.

        Returns:
          A tuple of (`tf.Operation`, an `integer`). The first item is
          the control dependency operation that triggers swapping in the input
          tensor of `bw_op`. The second item is the order of the control
          dependency operation in the topological order. `(None, -1)` is
          returned when no candidate is found.
        """
        import collections

        fw_order = self._topo_sort.get_order(fw_op)
        bw_order = self._topo_sort.get_order(bw_op)

        # check if the bw op is near the boundary between fw and bw phases
        if (bw_order - lower_b) < self._topo_sort.bw_starting_order:
            return self._do_direct_order(fw_op, bw_op, lower_b, upper_b)

        # `current_level` holds the ops at the current distance from `fw_op`.
        # `next_level` collects the following level.
        current_level = collections.deque([fw_op])
        next_level = collections.deque()
        next_level_members = set()  # O(1) mirror of `next_level`
        closed_set = set()

        result_ops = set()
        while current_level:
            # stop if reaching the upperbound
            if upper_b == 0 or (lower_b > upper_b):
                break

            src_op = current_level.popleft()

            # do action for src_op
            total_consuming_ops = set()
            for t in src_op.outputs:
                total_consuming_ops |= set(util.get_consuming_ops(t))

            if lower_b <= 0:
                # inside the range: keep backward consumers ordered strictly
                # between fw_op and bw_op, skipping ops named "*/cond/*"
                result_ops |= {
                    op
                    for op in total_consuming_ops & self._grad_ops
                    if (self._topo_sort.get_order(op) > fw_order and
                        self._topo_sort.get_order(op) < bw_order and
                        "/cond/" not in op.name)
                }

            # go to the next level
            for op in total_consuming_ops - self._grad_ops:
                if op in closed_set or op in next_level_members:
                    continue
                next_level.append(op)
                next_level_members.add(op)

            closed_set.add(src_op)
            if not current_level:
                if result_ops:
                    break
                lower_b = lower_b - 1
                upper_b = upper_b - 1
                current_level.extend(next_level)
                next_level.clear()
                next_level_members.clear()

        if result_ops:
            ctrld_op = next(iter(result_ops))
            return (ctrld_op, self._topo_sort.get_order(ctrld_op))
        return (None, -1)
    def _insert_swap_nodes(self, src_op):
        """Insert swapin and swapout ops for the given operation into the graph.

        This method does an in-place modification to the graph.

        Nothing is done if `src_op` is in `self._excl_ops`. Nothing is done
        either if `self._incl_ops` is non-empty and does not contain `src_op`.

        For each output tensor of `src_op`, the candidate destinations are:
        - its consumers that belong to `self._grad_ops`;
        - plus, when `self._swap_branches` is set, the forward branch ops
          returned by `self._get_branch_ops`.

        Candidates whose non-inclusive forward walk is empty are dropped.
        A single swapout op is created for the tensor if at least one
        destination has a non-negative topological order. When
        `self._fuse_swapins` is set, the destination set may then be
        rewritten by `self._fuse_swapin_ops`.

        Each destination with a non-negative order gets a swapin op plus a
        control dependency. Destinations with a negative order are handled
        by recursing on the ops returned by `self._find_new_src_op(dest_op)`,
        unless `src_op` itself is a gradient op.

        Processing stops early once `self._swapped_max_tensors()` is true.

        Args:
          src_op: a `tf.Operation`
        """
        self._log_info("Operation: {}".format(src_op), 2)

        # bypass excluded ops
        if src_op in self._excl_ops:
            return

        # if inclusive mode is enabled, only proceed if this op is included
        if self._incl_ops:
            if src_op not in self._incl_ops:
                return

        for t in src_op.outputs:
            if self._swapped_max_tensors():
                return

            frontier_ops = set(util.get_consuming_ops(t))
            self._log_info("my frontier ops: {}".format(frontier_ops), 2)

            bw_frontier_ops = frontier_ops & self._grad_ops
            self._log_info("my bw frontier ops: {}".format(bw_frontier_ops), 2)

            # swap branch ops if they are far enough (depending on threshold)
            if self._swap_branches:
                fw_branch_ops = self._get_branch_ops(
                    frontier_ops - self._grad_ops,
                    self._branch_threshold)
                bw_frontier_ops = bw_frontier_ops | fw_branch_ops

            # Do not swap tensors used by bw ops without outgoing ops.
            # These bw ops can be removed by Tensorflow compiler
            bw_frontier_ops = {op
                               for op in bw_frontier_ops
                               if set(self._get_forward_walk_ops(op, inclusive=False))}

            if not bw_frontier_ops:
                continue

            self._log_info("Operation: {}, order {}, type {}".format(
                src_op.name, self._topo_sort.get_order(src_op),
                src_op.type), 1)

            # create swap_out node only if there exists a real dest. operation
            # (i.e. at least one destination with a non-negative order)
            swapout_op = None
            for op in bw_frontier_ops:
                if self._topo_sort.get_order(op) >= 0:
                    swapout_op = self._add_swapout(src_op, t)
                    self._incpu_count = self._incpu_count + 1
                    break

            # create swap_in nodes
            if self._fuse_swapins and swapout_op:
                bw_frontier_ops = self._fuse_swapin_ops(
                    src_op, swapout_op, bw_frontier_ops, t)
            for dest_op in bw_frontier_ops:
                if self._topo_sort.get_order(dest_op) < 0:
                    # Negative order: no swap-in for this destination.
                    # For forward sources, look for other ops to swap from.
                    if src_op in self._grad_ops:
                        continue
                    else:
                        # Note: the search starts from dest_op, not src_op.
                        new_src_ops = self._find_new_src_op(dest_op)
                        for op in new_src_ops:
                            self._insert_swap_nodes(op)
                else:
                    # NOTE(review): swapout_op is expected to be non-None
                    # here, since it is created above whenever any
                    # destination has order >= 0 — confirm if
                    # _fuse_swapin_ops can change that.
                    # swap_in op
                    swapin_op = self._add_swapin(swapout_op, dest_op, t)
                    # control dependency -> swap_in
                    self._add_control_dependency(src_op, dest_op, swapin_op)