Example #1
def test_articulation_points_resnet():
    """Make sure articulation points are found correctly in resnet."""
    tf.reset_default_graph()
    nodes = util.make_resnet(3)
    all_ops = ge.get_forward_walk_ops(seed_ops=nodes[0].op)
    graph = nx.Graph(util.tf_ops_to_graph(all_ops))
    assert util.set_equal(util.format_ops(nx.articulation_points(graph)),
                          ['a01_add'])

    tf.reset_default_graph()
    nodes = util.make_resnet(4)
    all_ops = ge.get_forward_walk_ops(seed_ops=nodes[0].op)
    graph = nx.Graph(util.tf_ops_to_graph(all_ops))
    assert util.set_equal(util.format_ops(nx.articulation_points(graph)),
                          ['a01_add', 'a02_add'])
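A minimal, self-contained sketch of the pattern this test relies on. `util.make_resnet`, `util.tf_ops_to_graph` and `util.format_ops` are project helpers whose internals are not shown on this page, so the tiny graph and the op-to-networkx conversion below are stand-ins, assuming TF 1.x with `tf.contrib.graph_editor` and the networkx package:

# A hedged stand-in for the util helpers above: tiny TF 1.x graph, forward walk
# from a seed op, then op connectivity as an undirected networkx graph.
import networkx as nx
import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

tf.reset_default_graph()
a = tf.placeholder(tf.float32, name="a")
b = tf.square(a, name="b")
c = tf.sqrt(b, name="c")
d = tf.negative(c, name="d")   # a -> b -> c -> d is a simple chain

all_ops = ge.get_forward_walk_ops(seed_ops=a.op)

graph = nx.Graph()
for op in all_ops:
    for t in op.outputs:
        for consumer in t.consumers():
            graph.add_edge(op.name, consumer.name)

# In a chain every interior node is a cut vertex.
print(sorted(nx.articulation_points(graph)))  # ['b', 'c']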
    def __init__(self,
                 ys,
                 xs,
                 grad_ys=None,
                 debug=False,
                 debug_level=1,
                 cpu_device="/cpu:0",
                 **kwargs):

        self._ys = ys
        self._xs = xs

        self._grad_ys = grad_ys
        self._kwargs = kwargs

        self._cpu_device = cpu_device
        self._debug = debug
        self._debug_level = debug_level

        self.bwd_ops = ge.get_backward_walk_ops([y.op for y in self._ys],
                                                inclusive=True)

        self._log_info("bwd_ops: {}".format(self.bwd_ops), 2)

        # forward ops are all ops that are candidates for recomputation
        self.fwd_ops = ge.get_forward_walk_ops([x.op for x in self._xs],
                                               inclusive=True,
                                               within_ops=self.bwd_ops)

        self._log_info("fwd_ops: {}".format(self.fwd_ops), 2)
Example #4
    def _get_seed_ops(self):
        """Return a list of `tf.Operation` used as a starting point for LMS
        to traverse the graph.

        If a starting scope is given, the ops in this scope will be used.
        Otherwise, this method automatically searches for starting ops.
        """
        # seed ops for search
        seed_ops = set()
        ops = ge.make_list_of_op(self._graph)
        if self._starting_scope:
            scope_ops = set(
                ge.filter_ops_from_regex(ops,
                                         "^{}".format(self._starting_scope)))
            if not scope_ops:
                raise ValueError('No operations were found in starting '
                                 'scope {}.'.format(self._starting_scope))
            seed_ops |= scope_ops

        if self._starting_op_names:
            for name in self._starting_op_names:
                name_ops = set(
                    ge.filter_ops_from_regex(ops, "^{}$".format(name)))
                if not name_ops:
                    raise ValueError('No starting operation was found with '
                                     'name {}.'.format(name))
                seed_ops |= name_ops

        seed_ops = list(seed_ops)
        if not seed_ops:
            candidates = set()
            non_grad_ops = [
                op for op in self._graph.get_operations()
                if not (op in self._grad_ops)
            ]
            for op in non_grad_ops:
                for t in op.outputs:
                    frontier_ops = set(util.get_consuming_ops(t))
                    if (frontier_ops & self._grad_ops):
                        candidates.add(op)
                        break

            # ordering an operation by how much it covers the other ops
            tmp_dict = {}
            max_nelems = -1
            for op in candidates:
                nelems = len(
                    set(
                        ge.get_forward_walk_ops(
                            op, within_ops=non_grad_ops, inclusive=False))
                    & candidates)
                if nelems > 0:
                    tmp_dict[op] = nelems
                    max_nelems = nelems if (
                        nelems > max_nelems) else max_nelems

            # seed ops will cover most of the forward ops
            seed_ops = [k for k, v in tmp_dict.items() if v == max_nelems]
        return seed_ops
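A minimal sketch of the regex-based op selection this method builds on, assuming TF 1.x with `tf.contrib.graph_editor`; the scope names are invented for illustration:

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

tf.reset_default_graph()
with tf.name_scope("encoder"):
    x = tf.placeholder(tf.float32, name="x")
    h = tf.nn.relu(x, name="h")
with tf.name_scope("decoder"):
    y = tf.identity(h, name="y")

ops = ge.make_list_of_op(tf.get_default_graph())
seed_ops = ge.filter_ops_from_regex(ops, "^encoder")
print([op.name for op in seed_ops])  # ops whose names start with "encoder/"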
Example #5
def test_resnet_structure():
    """sanity check on TF resnet structure."""
    tf.reset_default_graph()
    nodes = util.make_resnet(3)
    all_ops = ge.get_forward_walk_ops(seed_ops=nodes[0].op)
    desired_graph = {0: [1, 2], 1: [2], 2: [3, 4], 3: [4]}
    actual_graph = util.tf_ops_to_graph(all_ops)
    assert (util.graphs_isomorphic(actual_graph, desired_graph))
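`util.graphs_isomorphic` is a project helper that is not shown on this page; a hedged stand-in for the same check using networkx's built-in isomorphism test:

import networkx as nx

desired_graph = {0: [1, 2], 1: [2], 2: [3, 4], 3: [4]}
actual_graph = {"a": ["b", "c"], "b": ["c"], "c": ["d", "e"], "d": ["e"]}

# Same structure, different node labels: the graphs are isomorphic.
assert nx.is_isomorphic(nx.Graph(desired_graph), nx.Graph(actual_graph))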
Example #7
    def _clean_update_ops(self):
        """Remove ops that are in the update phase."""
        update_ops = set(
            ge.get_forward_walk_ops(list(self._grad_ops), inclusive=False))
        for i in range(0, len(self._topo_sort)):
            ops = self._topo_sort[i]
            # remove ops that are neither bw nor fw ops,
            # e.g. ops in the update phase
            self._topo_sort[i] = ops - update_ops

    def _get_forward_walk_ops(self, op, inclusive=True):
        """A wrapper of `tensorflow.contrib.graph_editor.get_forward_walk_ops`."""
        if op in self._ops_dict:
            if inclusive:
                return self._ops_dict[op]
            else:
                return list(set(self._ops_dict[op]) - {op})
        else:
            ret = ge.get_forward_walk_ops(op)
            self._ops_dict[op] = ret
            if inclusive:
                return ret
            else:
                return list(set(ret) - {op})
def gradients(ys, xs, grad_ys=None, **kwargs):
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True)

    for index, op in enumerate(bwd_ops):
        tf.logging.info("bwd_ops: [{}] :{}".format(index, op.name))

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    for index, op in enumerate(fwd_ops):
        tf.logging.info("fwd_ops: [{}] : {}".format(index, op.name))
def tf_toposort(ts, within_ops=None):
    all_ops = ge.get_forward_walk_ops([x.op for x in ts], within_ops=within_ops)

    deps = {}
    for op in all_ops:
        for o in op.outputs:
            deps[o] = set(op.inputs)
    sorted_ts = toposort(deps)

    # only keep the tensors from our original list
    ts_sorted_lists = []
    for l in sorted_ts:
        keep = list(set(l).intersection(ts))
        if keep:
            ts_sorted_lists.append(keep)

    return ts_sorted_lists
def tf_toposort(ts_inp, within_ops=None):
    """ Tensorflow topological sort """
    all_ops = ge.get_forward_walk_ops([x.op for x in ts_inp], within_ops=within_ops)

    deps = {}
    for tf_op in all_ops:
        for outp in tf_op.outputs:
            deps[outp] = set(tf_op.inputs)
    sorted_ts = toposort(deps)

    # only keep the tensors from our original list
    ts_sorted_lists = []
    for lst in sorted_ts:
        keep = list(set(lst).intersection(ts_inp))
        if keep:
            ts_sorted_lists.append(keep)
    return ts_sorted_lists
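These tf_toposort helpers lean on the third-party `toposort` package; a minimal sketch of what that package returns for a small dependency dict (pure Python, no TensorFlow needed):

from toposort import toposort

# Keys depend on the items in their value set.
deps = {
    "c": {"a", "b"},   # c needs a and b
    "d": {"c"},        # d needs c
}
# toposort yields "levels": everything in one level depends only on earlier ones.
print(list(toposort(deps)))  # [{'a', 'b'}, {'c'}, {'d'}]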
Example #14
def topological_sort(layer_names):
    """
    Given `n` names, returns a boolean matrix $L \in {0, 1}^{n \times n}$,
    where `L[i, j] == true` iff layer i is lower than layer j (i flows to j).

    Args:
      layer_names:  List of `str` names of the layers in default graph.

    Returns:
      order_matrix: A boolean matrix L as described above.
    """
    n = len(layer_names)
    # Use tf.Operation objects (not GraphDef NodeDefs) so that graph_editor can
    # walk the graph, and keep the same order as `layer_names`.
    layer_ops = [
        tf.get_default_graph().get_operation_by_name(name)
        for name in layer_names
    ]
    order_matrix = np.zeros((n, n), dtype=bool)
    for i in range(n):
        forward_ops = ge.get_forward_walk_ops(layer_ops[i], inclusive=False)
        for j in range(n):
            if layer_ops[j] in forward_ops:
                order_matrix[i, j] = True
    return order_matrix
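A hedged usage sketch for topological_sort, assuming TF 1.x and that the function above (with its numpy and graph_editor imports) is in scope; the layer names below are invented for illustration:

import tensorflow as tf

tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 8], name="inp")
h = tf.layers.dense(x, 4, name="hidden")
y = tf.layers.dense(h, 2, name="out")

order = topological_sort(["inp", "hidden/MatMul", "out/MatMul"])
# For this chain only the strict upper triangle is True:
# inp feeds hidden and out, hidden feeds out, and nothing feeds backwards.
print(order)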
Example #15
def is_computable(result, known_values):
  """Returns true if given tensor is computable from known values."""

  computable_ops = ge.get_forward_walk_ops([val.op for val in known_values])
  return result.op in computable_ops
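A hedged usage sketch for is_computable, assuming TF 1.x and that the function above (with its `ge` import) is in scope. Note the check is a forward-reachability test: the result must lie downstream of the known values.

import tensorflow as tf

tf.reset_default_graph()
a = tf.placeholder(tf.float32, name="a")
b = tf.placeholder(tf.float32, name="b")
c = tf.add(a, b, name="c")
d = tf.square(c, name="d")

print(is_computable(d, known_values=[a, b]))  # True: d is downstream of a and b
print(is_computable(a, known_values=[d]))     # False: a is upstream of d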
Example #16
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
    print("-------------------------------")
    debug_print("Editing model for OME")
    incpu_count = 0
    #    print("Calling memsaving gradients with", checkpoints)
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True)

    for index, op in enumerate(bwd_ops):
        debug_print("bwd_ops: [{}] :{}".format(index, op.name), 1)

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    for index, op in enumerate(fwd_ops):
        debug_print("fwd_ops: [{}] : {}".format(index, op.name), 1)

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if not op in xs_ops]
    fwd_ops = [op for op in fwd_ops if not '/assign' in op.name]
    fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name]
    fwd_ops = [op for op in fwd_ops if not '/read' in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    # construct list of tensors to checkpoint during forward pass, if not
    # given as input
    if type(checkpoints) is not list:
        # remove very small tensors and some weird ops
        # tf.Dimension values are not compatible with int, convert manually
        def fixdims(t):
            try:
                return [int(e if e.value is not None else 64) for e in t]
            except:
                return [0]  # unknown shape

        ts_all = [
            t for t in ts_all
            if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE
        ]
        ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
        ts_all = [t for t in ts_all if 'entropy' not in t.name]
        ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
        ts_all = [t for t in ts_all if 'Switch' not in t.name]
        ts_all = [t for t in ts_all if 'dropout' not in t.name]
        # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
        ts_all = [t for t in ts_all if 'Cast' not in t.name]

        # filter out all tensors that are inputs of the backward graph
        with util.capture_ops() as bwd_ops:
            tf_gradients(ys, xs, grad_ys, **kwargs)

        bwd_inputs = [t for op in bwd_ops for t in op.inputs]
        # list of tensors in forward graph that is in input to bwd graph
        ts_filtered = list(set(bwd_inputs).intersection(ts_all))
        debug_print("Using tensors {}".format(ts_filtered), 1)

        # try two slightly different ways of getting bottlenecks tensors
        # to checkpoint
        for ts in [ts_filtered, ts_all]:

            # get all bottlenecks in the graph
            bottleneck_ts = []
            for t in ts:
                b = set(
                    ge.get_backward_walk_ops(t.op,
                                             inclusive=True,
                                             within_ops=fwd_ops))
                f = set(
                    ge.get_forward_walk_ops(t.op,
                                            inclusive=False,
                                            within_ops=fwd_ops))
                # check that there are not shortcuts
                b_inp = set([inp for op in b
                             for inp in op.inputs]).intersection(ts_all)
                f_inp = set([inp for op in f
                             for inp in op.inputs]).intersection(ts_all)
                if not set(b_inp).intersection(
                        f_inp) and len(b_inp) + len(f_inp) >= len(ts_all):
                    bottleneck_ts.append(t)  # we have a bottleneck!
                else:
                    debug_print(
                        "Rejected bottleneck candidate and ops {}".format(
                            [t] + list(set(ts_all) - set(b_inp) - set(f_inp))),
                        2)

            # success? or try again without filtering?
            if len(bottleneck_ts) >= np.sqrt(
                    len(ts_filtered)):  # yes, enough bottlenecks found!
                break

        if not bottleneck_ts:
            raise Exception(
                'unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".'
            )

        # sort the bottlenecks
        bottlenecks_sorted_lists = tf_toposort(bottleneck_ts,
                                               within_ops=fwd_ops)
        sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts]

        # save an approximately optimal number ~ sqrt(N)
        N = len(ts_filtered)
        if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
            checkpoints = sorted_bottlenecks
        else:
            step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
            checkpoints = sorted_bottlenecks[step::step]

    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection happened and checkpoints is list of nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: {}".format(checkpoints), 1)
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print(
            "Warning, some input nodes are also checkpoint nodes: {}".format(
                xs_intersect_checkpoints), 2)
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print(
        "ys: {}, checkpoints: {}, intersect: {}".format(
            ys, checkpoints, ys_intersect_checkpoints), 1)
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print(
            "Warning, some output nodes are also checkpoints nodes: {}".format(
                format_ops(ys_intersect_checkpoints)), 2)

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    if not checkpoints:
        raise Exception('no checkpoints nodes found or given as input! ')

    debug_print(
        "Select {} nodes to checkpoint nodes.".format(len(checkpoints)), 0)

    # disconnect dependencies between checkpointed tensors
    checkpoints_disconnected = {}
    for x in checkpoints:
        frontier_ops = set(graph_util.get_consuming_ops(x.op.outputs))
        debug_print("my frontier ops: {}".format(frontier_ops), 1)

        bw_frontier_ops = frontier_ops & set(bwd_ops)
        debug_print("my bw frontier ops: {}".format(bw_frontier_ops), 1)

        if len(bw_frontier_ops) > 1:
            continue

        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
        else:
            grad_node = tf.stop_gradient(x)

        swapout_op = _add_swapout(grad_node.op, grad_node.op.outputs)
        incpu_count = incpu_count + 1
        swapin_op = _add_swapin(swapout_op, bw_frontier_ops,
                                grad_node.op.outputs)
        checkpoints_disconnected[x] = swapin_op
        my_add_control_inputs(x, bw_frontier_ops, swapin_op)
        # control dependency -> swap_in
        # self._add_control_dependency(src_op, dest_op, swapin_op)

    # g = tf.get_default_graph()
    # print(g.get_operations())

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints,
                                    within_ops=fwd_ops)
    debug_print(
        "Found {} ops to copy within fwd_ops {}, seed {}, stop_at {}".format(
            len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints), 1)
    debug_print("ops_to_copy = {}".format(ops_to_copy), 1)
    debug_print("Processing list {}".format(ys), 1)
    copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    debug_print("Copied {} to {}".format(ops_to_copy, copied_ops), 1)
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print(
        "Rewired {} in place of {} restricted to {}".format(
            checkpoints_disconnected.values(), checkpoints_disconnected.keys(),
            copied_ops), 1)

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys,
                      xs=boundary + xs,
                      grad_ys=grad_ys,
                      **kwargs)
    debug_print("Got gradients {}".format(dv), 1)
    debug_print("for {}".format(copied_ys), 1)
    debug_print("with respect to {}".format(boundary + xs), 1)

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {
        r: dr
        for r, dr in zip(checkpoints_disconnected.keys(),
                         dv[:len(checkpoints_disconnected)])
    }
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list {}".format(ts), 1)
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [
            checkpoints_disconnected[r] for r in checkpoints_other
        ]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoints nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print(
            "Found {} ops to copy within {}, seed {}, stop_at {}".format(
                len(ops_to_copy), fwd_ops, [r.op for r in ts],
                checkpoints_other), 1)
        debug_print("ops_to_copy = {}".format(ops_to_copy), 1)
        if not ops_to_copy:  # we're done!
            break
        copied_sgv, info = ge.copy_with_input_replacements(
            ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        debug_print("Copied {} to {}".format(ops_to_copy, copied_ops), 1)
        ge.reroute_ts(checkpoints_disconnected_other,
                      checkpoints_other,
                      can_modify=copied_ops)
        debug_print(
            "Rewired {} in place of {} restricted to {}".format(
                checkpoints_disconnected_other, checkpoints_other, copied_ops),
            1)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops,
                          **kwargs)
        debug_print("Got gradients {}".format(dv), 1)
        debug_print("for {}".format(boundary), 1)
        debug_print(
            "with respect to {}".format(checkpoints_disconnected_other + xs),
            1)
        debug_print(
            "with boundary backprop substitutions {}".format(
                substitute_backprops), 1)

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        def _unsparsify(x):
            if not isinstance(x, tf.IndexedSlices):
                return x
            assert x.dense_shape is not None, "memory_saving_gradients encountered sparse gradients of unknown shape"
            indices = x.indices
            while indices.shape.ndims < x.values.shape.ndims:
                indices = tf.expand_dims(indices, -1)
            return tf.scatter_nd(indices, x.values, x.dense_shape)

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])

    return d_xs
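A hedged sketch of how a caller might use the gradients() variant above as a drop-in for tf.gradients, assuming TF 1.x and that the function together with the module-level helpers it calls (the swap-out/swap-in and control-input utilities) is importable; the two-layer model and the choice of `h` as the only checkpoint are invented for illustration:

import tensorflow as tf

tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 32], name="x")
h = tf.layers.dense(x, 64, activation=tf.nn.relu, name="hidden")
logits = tf.layers.dense(h, 10, name="logits")
loss = tf.reduce_mean(tf.square(logits))

params = tf.trainable_variables()
# Checkpoint the hidden activation explicitly; tensors between checkpoints are
# recomputed during the backward pass instead of being kept in memory.
grads = gradients(loss, params, checkpoints=[h])
train_op = tf.train.GradientDescentOptimizer(0.1).apply_gradients(
    list(zip(grads, params)))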
Example #17
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
    '''
    Authors: Tim Salimans & Yaroslav Bulatov

    memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost"
    by Chen et al. 2016 (https://arxiv.org/abs/1604.06174)

    ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients
    (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients)

    'checkpoints' can either be
        - a list consisting of tensors from the forward pass of the neural net
          that we should re-use when calculating the gradients in the backward pass
          all other tensors that do not appear in this list will be re-computed
        - a string specifying how this list should be determined. currently we support
            - 'speed':  checkpoint all outputs of convolutions and matmuls. these ops are usually the most expensive,
                        so checkpointing them maximizes the running speed
                        (this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory)
            - 'memory': try to minimize the memory usage
                        (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint)
            - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint
    '''

    #    print("Calling memsaving gradients with", checkpoints)
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True)

    debug_print("bwd_ops: %s", bwd_ops)

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    debug_print("fwd_ops: %s", fwd_ops)

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if not op in xs_ops]
    fwd_ops = [op for op in fwd_ops if not '/assign' in op.name]
    fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name]
    fwd_ops = [op for op in fwd_ops if not '/read' in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    checkpoints = 'collection'
    # construct list of tensors to checkpoint during forward pass, if not
    # given as input

    stereo_checkpoints = ge.filter_ts_from_regex(fwd_ops, "add")
    motion_checkpoints = ge.filter_ts_from_regex(fwd_ops, "Conv2D")

    my_ckps = []
    for x in motion_checkpoints:
        if ("motion" in x.name) and ("BatchNorm" not in x.name):
            my_ckps.append(x)
    for x in stereo_checkpoints:
        if ("stereo" in x.name) and ("BatchNorm" not in x.name):
            my_ckps.append(x)

    checkpoints = my_ckps
    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection happened and checkpoints is list of nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: %s", checkpoints)
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print("Warning, some input nodes are also checkpoint nodes: %s",
                    xs_intersect_checkpoints)
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints,
                ys_intersect_checkpoints)
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print(
            "Warning, some output nodes are also checkpoints nodes: %s",
            format_ops(ys_intersect_checkpoints))

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    if not checkpoints:
        raise Exception('no checkpoints nodes found or given as input! ')

    # disconnect dependencies between checkpointed tensors
    checkpoints_disconnected = {}
    for x in checkpoints:
        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
        else:
            grad_node = tf.stop_gradient(x)
        checkpoints_disconnected[x] = grad_node

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints,
                                    within_ops=fwd_ops)
    debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s",
                len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints)
    debug_print("ops_to_copy = %s", ops_to_copy)
    debug_print("Processing list %s", ys)
    copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    copied_ops = info._transformed_ops.values()
    debug_print("Copied %s to %s", ops_to_copy, copied_ops)
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print("Rewired %s in place of %s restricted to %s",
                checkpoints_disconnected.values(),
                checkpoints_disconnected.keys(), copied_ops)

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys,
                      xs=boundary + xs,
                      grad_ys=grad_ys,
                      **kwargs)
    debug_print("Got gradients %s", dv)
    debug_print("for %s", copied_ys)
    debug_print("with respect to %s", boundary + xs)

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {
        r: dr
        for r, dr in zip(checkpoints_disconnected.keys(),
                         dv[:len(checkpoints_disconnected)])
    }
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list %s", ts)
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [
            checkpoints_disconnected[r] for r in checkpoints_other
        ]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoints nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print("Found %s ops to copy within %s, seed %s, stop_at %s",
                    len(ops_to_copy), fwd_ops, [r.op for r in ts],
                    checkpoints_other)
        debug_print("ops_to_copy = %s", ops_to_copy)
        if not ops_to_copy:  # we're done!
            break
        copied_sgv, info = ge.copy_with_input_replacements(
            ge.sgv(ops_to_copy), {})
        copied_ops = info._transformed_ops.values()
        debug_print("Copied %s to %s", ops_to_copy, copied_ops)
        ge.reroute_ts(checkpoints_disconnected_other,
                      checkpoints_other,
                      can_modify=copied_ops)
        debug_print("Rewired %s in place of %s restricted to %s",
                    checkpoints_disconnected_other, checkpoints_other,
                    copied_ops)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops,
                          **kwargs)
        debug_print("Got gradients %s", dv)
        debug_print("for %s", boundary)
        debug_print("with respect to %s", checkpoints_disconnected_other + xs)
        debug_print("with boundary backprop substitutions %s",
                    substitute_backprops)

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = d_xs_new[j]
                else:
                    d_xs[j] += d_xs_new[j]

    return d_xs
class OME(object):

    def __init__(self, graph=None,
                 grad_ys=None,
                 debug=False,
                 debug_level=1,
                 cpu_device="/cpu:0",
                 **kwargs):

        self._graph = graph
        self._grad_ys = grad_ys
        self._topo_sort = None
        self._cpu_device = cpu_device
        self._debug = debug
        self._debug_level = debug_level

        # keep log of tensors on host
        self._incpu_count = 0

        # store a dictionary of visited ops to avoid multiple visits
        self._ops_dict = {}
        self.kwargs = kwargs

    def run(self, graph=None):

        if graph:
            self._graph = graph

        if not self._graph:
            raise ValueError('The dataflow graph is required but has not been'
                             ' provided.')

        self._log_info("Editing model for LMS")
        start_time = time.time()

        loss_ops = tf.get_default_graph().get_operations()[-1:]
        xs_ops = tf.trainable_variables()

        # forward ops are all ops that are candidates for recomputation
        #    print("Calling memsaving gradients with", checkpoints)
        ys = loss_ops if isinstance(loss_ops, list) else [loss_ops]
        xs = xs_ops if isinstance(xs_ops, list) else [xs_ops]

        bwd_ops = ge.get_backward_walk_ops([y.op for y in loss_ops],
                                           inclusive=True)

        self._log_info("bwd_ops {}".format(bwd_ops), 2)

        # forward ops are all ops that are candidates for recomputation
        fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                          inclusive=True,
                                          within_ops=bwd_ops)
        self._log_info("fwd_ops: %s".format(fwd_ops, 2))

        # exclude ops with no inputs
        fwd_ops = [op for op in fwd_ops if op.inputs]

        # don't recompute xs, remove variables
        xs_ops = self._to_ops(xs)
        fwd_ops = [op for op in fwd_ops if not op in xs_ops]
        fwd_ops = [op for op in fwd_ops if not '/assign' in op.name]
        fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name]
        fwd_ops = [op for op in fwd_ops if not '/read' in op.name]
        ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
        ts_all = [t for t in ts_all if '/read' not in t.name]
        ts_all = set(ts_all) - set(xs) - set(ys)

        # remove very small tensors and some weird ops
        def fixdims(t):  # tf.Dimension values are not compatible with int, convert manually
            try:
                return [int(e if e.value is not None else 64) for e in t]
            except:
                return [0]  # unknown shape

        ts_all = [t for t in ts_all if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE]
        ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
        ts_all = [t for t in ts_all if 'entropy' not in t.name]
        ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
        ts_all = [t for t in ts_all if 'Switch' not in t.name]
        ts_all = [t for t in ts_all if 'dropout' not in t.name]
        # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
        ts_all = [t for t in ts_all if 'Cast' not in t.name]

        # filter out all tensors that are inputs of the backward graph
        with util.capture_ops() as bwd_ops:
            tf_gradients(ys, xs, self._grad_ys, **self.kwargs)

        bwd_inputs = [t for op in bwd_ops for t in op.inputs]
        # list of tensors in forward graph that is in input to bwd graph
        ts_filtered = list(set(bwd_inputs).intersection(ts_all))
        self._log_info("Using tensors {}".format(ts_filtered), 1)

        # try two slightly different ways of getting bottlenecks tensors
        # to checkpoint
        for ts in [ts_filtered, ts_all]:

            # get all bottlenecks in the graph
            bottleneck_ts = []
            for t in ts:
                b = set(ge.get_backward_walk_ops(t.op, inclusive=True, within_ops=fwd_ops))
                f = set(ge.get_forward_walk_ops(t.op, inclusive=False, within_ops=fwd_ops))
                # check that there are not shortcuts
                b_inp = set([inp for op in b for inp in op.inputs]).intersection(ts_all)
                f_inp = set([inp for op in f for inp in op.inputs]).intersection(ts_all)
                if not set(b_inp).intersection(f_inp) and len(b_inp) + len(f_inp) >= len(ts_all):
                    bottleneck_ts.append(t)  # we have a bottleneck!
                else:
                    self._log_info("Rejected bottleneck candidate and ops {}".format(
                        [t] + list(set(ts_all) - set(b_inp) - set(f_inp))), 2)

            # success? or try again without filtering?
            if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)):  # yes, enough bottlenecks found!
                break

        if not bottleneck_ts:
            raise Exception(
                'unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".')

        # sort the bottlenecks
        bottlenecks_sorted_lists = self.tf_toposort(bottleneck_ts, within_ops=fwd_ops)
        sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts]

        # save an approximately optimal number ~ sqrt(N)
        N = len(ts_filtered)
        if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
            checkpoints = sorted_bottlenecks
        else:
            step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
            checkpoints = sorted_bottlenecks[step::step]

        checkpoints = list(set(checkpoints).intersection(ts_all))

        # at this point automatic selection happened and checkpoints is list of nodes
        assert isinstance(checkpoints, list)

        self._log_info("Checkpoint nodes used: %s", checkpoints, 0)
        # better error handling of special cases
        # xs are already handled as checkpoint nodes, so no need to include them
        xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
        if xs_intersect_checkpoints:
            self._log_info("Warning, some input nodes are also checkpoint nodes: {}".format(
                xs_intersect_checkpoints), 2)
        ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
        self._log_info("ys: {}, checkpoints: {}, intersect: {}".format(ys, checkpoints,
                                                                       ys_intersect_checkpoints), 1)
        # saving an output node (ys) gives no benefit in memory while creating
        # new edge cases, exclude them
        if ys_intersect_checkpoints:
            self._log_info("Warning, some output nodes are also checkpoints nodes: {}".format(
                self.format_ops(ys_intersect_checkpoints)), 2)

        # remove initial and terminal nodes from checkpoints list if present
        checkpoints = list(set(checkpoints) - set(ys) - set(xs))

        # check that we have some nodes to checkpoint
        if not checkpoints:
            raise Exception('no checkpoints nodes found or given as input! ')

        # disconnect dependencies between checkpointed tensors
        checkpoints_disconnected = {}
        for x in checkpoints:
            if x.op and x.op.name is not None:
                grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
            else:
                grad_node = tf.stop_gradient(x)
            checkpoints_disconnected[x] = grad_node

        # partial derivatives to the checkpointed tensors and xs
        ops_to_copy = self.fast_backward_ops(seed_ops=[y.op for y in ys],
                                             stop_at_ts=checkpoints, within_ops=fwd_ops)
        self._log_info("Found {} ops to copy within fwd_ops {}, seed {}, stop_at {}".format(
            len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints), 1)
        self._log_info("ops_to_copy = {}".format(ops_to_copy), 2)
        self._log_info("Processing list {}".format(ys), 2)
        copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        self._log_info("Copied {} to {}".format(ops_to_copy, copied_ops), 2)
        ge.reroute_ts(checkpoints_disconnected.values(), checkpoints_disconnected.keys(), can_modify=copied_ops)
        self._log_info("Rewired {} in place of {} restricted to {}".format(
            checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops), 2)

        # get gradients with respect to current boundary + original x's
        copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
        boundary = list(checkpoints_disconnected.values())
        dv = tf_gradients(ys=copied_ys, xs=boundary + xs)
        self._log_info("Got gradients {}".format(dv), 2)
        self._log_info("for {}".format(copied_ys), 2)
        self._log_info("with respect to {}".format(boundary + xs), 2)

        inputs_to_do_before = [y.op for y in ys]
        if self._grad_ys is not None:
            inputs_to_do_before += self._grad_ys
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        # dictionary of "node: backprop" for nodes in the boundary
        d_checkpoints = {r: dr for r, dr in zip(checkpoints_disconnected.keys(),
                                                dv[:len(checkpoints_disconnected)])}
        # partial derivatives to xs (usually the params of the neural net)
        d_xs = dv[len(checkpoints_disconnected):]

        # incorporate derivatives flowing through the checkpointed nodes
        checkpoints_sorted_lists = self.tf_toposort(checkpoints, within_ops=fwd_ops)
        for ts in checkpoints_sorted_lists[::-1]:
            self._log_info("Processing list {}".format(ts), 2)
            checkpoints_other = [r for r in checkpoints if r not in ts]
            checkpoints_disconnected_other = [checkpoints_disconnected[r] for r in checkpoints_other]

            # copy part of the graph below current checkpoint node, stopping at
            # other checkpoints nodes
            ops_to_copy = self.fast_backward_ops(within_ops=fwd_ops, seed_ops=[r.op for r in ts],
                                            stop_at_ts=checkpoints_other)
            self._log_info("Found {} ops to copy within {}, seed {}, stop_at {}".format(
                        len(ops_to_copy), fwd_ops, [r.op for r in ts],
                        checkpoints_other), 2)
            self._log_info("ops_to_copy = {}".format(ops_to_copy), 2)
            if not ops_to_copy:  # we're done!
                break
            copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
            for origin_op, op in info._transformed_ops.items():
                op._set_device(origin_op.node_def.device)
            copied_ops = info._transformed_ops.values()
            self._log_info("Copied {} to {}".format(ops_to_copy, copied_ops), 2)
            ge.reroute_ts(checkpoints_disconnected_other, checkpoints_other, can_modify=copied_ops)
            self._log_info("Rewired {} in place of {} restricted to {}".format(
                        checkpoints_disconnected_other, checkpoints_other, copied_ops), 2)

            # gradient flowing through the checkpointed node
            boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
            substitute_backprops = [d_checkpoints[r] for r in ts]
            dv = tf_gradients(boundary,
                              checkpoints_disconnected_other + xs,
                              grad_ys=substitute_backprops, **self.kwargs)
            self._log_info("Got gradients {}".format(dv), 2)
            self._log_info("for {}".format(boundary), 2)
            self._log_info("with respect to {}".format(checkpoints_disconnected_other + xs), 2)
            self._log_info("with boundary backprop substitutions {}".format(substitute_backprops), 2)

            inputs_to_do_before = [d_checkpoints[r].op for r in ts]
            wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
            my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

            # partial derivatives to the checkpointed nodes
            for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
                if dr is not None:
                    if d_checkpoints[r] is None:
                        d_checkpoints[r] = dr
                    else:
                        d_checkpoints[r] += dr

            def _unsparsify(x):
                if not isinstance(x, tf.IndexedSlices):
                    return x
                assert x.dense_shape is not None, "memory_saving_gradients encountered sparse gradients of unknown shape"
                indices = x.indices
                while indices.shape.ndims < x.values.shape.ndims:
                    indices = tf.expand_dims(indices, -1)
                return tf.scatter_nd(indices, x.values, x.dense_shape)

            # partial derivatives to xs (usually the params of the neural net)
            d_xs_new = dv[len(checkpoints_other):]
            for j in range(len(xs)):
                if d_xs_new[j] is not None:
                    if d_xs[j] is None:
                        d_xs[j] = _unsparsify(d_xs_new[j])
                    else:
                        d_xs[j] += _unsparsify(d_xs_new[j])

        return d_xs

    def tf_toposort(self, ts, within_ops=None):
        all_ops = ge.get_forward_walk_ops([x.op for x in ts], within_ops=within_ops)

        deps = {}
        for op in all_ops:
            for o in op.outputs:
                deps[o] = set(op.inputs)
        sorted_ts = toposort(deps)

        # only keep the tensors from our original list
        ts_sorted_lists = []
        for l in sorted_ts:
            keep = list(set(l).intersection(ts))
            if keep:
                ts_sorted_lists.append(keep)

        return ts_sorted_lists

    def fast_backward_ops(self, within_ops, seed_ops, stop_at_ts):
        bwd_ops = set(ge.get_backward_walk_ops(seed_ops, stop_at_ts=stop_at_ts))
        ops = bwd_ops.intersection(within_ops).difference([t.op for t in stop_at_ts])
        return list(ops)

    @contextlib.contextmanager
    def capture_ops(self):
        """Decorator to capture ops created in the block.
        with capture_ops() as ops:
          # create some ops
        print(ops) # => prints ops created.
        """

        micros = int(time.time() * 10 ** 6)
        scope_name = str(micros)
        op_list = []
        with tf.name_scope(scope_name):
            yield op_list

        g = tf.get_default_graph()
        op_list.extend(ge.select_ops(scope_name + "/.*", graph=g))

    def _to_op(self, tensor_or_op):
        if hasattr(tensor_or_op, "op"):
            return tensor_or_op.op
        return tensor_or_op

    def _to_ops(self, iterable):
        if not self._is_iterable(iterable):
            return iterable
        return [self._to_op(i) for i in iterable]

    def _is_iterable(self, o):
        try:
            _ = iter(o)
        except Exception:
            return False
        return True

    def _log_info(self, message, level=0):
        """Log debug information.
        Args:
          message: a formatted string.
          level: an `integer`.
        """
        if level == 0 or (self._debug and self._debug_level >= level):
            # Use tf.logging.info instead of print, since print
            # is not thread safe, which can break tests.
            tf.logging.info("[OME][{}] {}".format(level, message))

    def format_ops(self, ops, sort_outputs=True):
        """Helper method for printing ops. Converts Tensor/Operation op to op.name,
        rest to str(op)."""

        if hasattr(ops, '__iter__') and not isinstance(ops, str):
            l = [(op.name if hasattr(op, "name") else str(op)) for op in ops]
            if sort_outputs:
                return sorted(l)
            return l
        else:
            return ops.name if hasattr(ops, "name") else str(ops)

    def my_add_control_inputs(wait_to_do_ops, inputs_to_do_before):
        for op in wait_to_do_ops:
            ci = [i for i in inputs_to_do_before if op.control_inputs is None or i not in op.control_inputs]
            ge.add_control_inputs(op, ci)

    def run(self, graph=None):
        """Edit the graph by adding swapin and swapout ops.

        Swapin and swapout ops are in the host.

        The graph is modified in-place.

        Return:
          a set of added ops.
        """
        if graph:
            self._graph = graph

        if self._n_tensors == 0:
            self._log_info("LMS is disabled and will not modify the model.")
            return  # turn off LMS
        elif self._n_tensors < 0:
            self._n_tensors = 0  # swap all tensors (default)

        if not self._graph:
            raise ValueError('The dataflow graph is required but has not been'
                             ' provided.')

        self._log_info("Editing model for LMS")
        self._print_configuration()
        start_time = time.time()

        self._build_gradient_ops()
        seed_ops = self._get_seed_ops()

        self._log_info(
            "Starting ops: {}".format(
                [(op.name, op.type) for op in seed_ops]), 1)

        reachable_ops = set()
        for seed_op in seed_ops:
            reachable_ops |= set(self._get_forward_walk_ops(seed_op))

        for op in reachable_ops:
            if 'lms/swap' in op.name:
                self._log_info('This model has already been updated with LMS '
                               'swap operations. LMS will not re-process it.')
                return
        # exclusive ops
        self._excl_ops = self._filter_scopes_and_types(reachable_ops,
                                                       self._excl_scopes,
                                                       self._excl_types)
        # inclusive ops
        self._incl_ops = self._filter_scopes_and_types(reachable_ops,
                                                       self._incl_scopes,
                                                       self._incl_types)

        reachable_ops -= self._grad_ops

        # build a topological sort
        self._topo_sort = topos.TOPOS(seed_ops, self._grad_ops)
        self._topo_sort.build()
        for i in range(0, self._topo_sort.size):
            self._log_info("[{}]: {}".format(
                i, [op.name for op in self._topo_sort.get_ops(i)]), 1)

        self._do_action(seed_ops)

        # check the validation of the new model
        new_reachable_ops = set()
        for seed_op in seed_ops:
            new_reachable_ops |= set(ge.get_forward_walk_ops(seed_op))
        new_reachable_ops -= self._grad_ops
        if (new_reachable_ops >= reachable_ops):
            self._log_info("Edited model is valid and logically equivalent to the original one")
            self._log_info("Added {} ops into the model".format(len(new_reachable_ops - reachable_ops)))
        else:
            self._log_info("Edited model is invalid. Running this may produce unexpected result")

        self._log_info("Editing model for LMS, took: {} ms".format(
            (time.time()-start_time)*1000))
        self._log_info(
            "{} tensors will be swapped out(in) to(from) the host".format(
                self._incpu_count))
        return (new_reachable_ops - reachable_ops)
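The validity check at the end of run() boils down to comparing forward-walk reachability before and after the edit. A minimal, self-contained sketch of that check, assuming TF 1.x with tf.contrib.graph_editor; the tiny graph and the "edit" are invented for illustration:

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

def reachable_from(seed_ops):
    reached = set()
    for seed_op in seed_ops:
        reached |= set(ge.get_forward_walk_ops(seed_op))
    return reached

tf.reset_default_graph()
x = tf.placeholder(tf.float32, name="x")
y = tf.square(x, name="y")
seed_ops = [x.op]

before = reachable_from(seed_ops)
z = tf.sqrt(y, name="z")          # stand-in for an edit that only adds ops
after = reachable_from(seed_ops)

assert after >= before            # nothing reachable before was lost
print("added {} op(s)".format(len(after - before)))  # added 1 op(s)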
trainable_variables = tf.trainable_variables()  # TODO - test only adding placeholders to variables in specific scope(s)
gamma_update_op = None
for var in trainable_variables:
    # Add summary
    tf.summary.histogram(var.name, var)

    # Get shape
    var_shape = var.shape
    placeholder_shapes.append(var_shape)

    # Create placeholder
    delta_placeholder = tf.placeholder('float', shape=var_shape)
    delta_placeholders.append(delta_placeholder)

    # Get ops from forward walk from var
    fw_ops = ge.get_forward_walk_ops(var.op.outputs)

    not_types = ['Assign', 'Identity', 'HistogramSummary', 'Enter']

    next_op = None

    # Select the correct op from the forward walk to connect to
    for fw_op in fw_ops:
        if fw_op.type not in not_types:
            next_op = fw_op
            break

    if next_op is None:
        raise ValueError('No suitable next op found to connect to. Try looking at the graph or full list of forward ops')

    # Add placeholder and variable
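    # NOTE: the original example is truncated at this point. What follows is
    # only a hypothetical sketch (not from the source) of one way the loop body
    # might continue: add the delta placeholder to the tensor next_op reads
    # from the variable and splice the sum in via the graph editor.
    old_t = next(t for t in next_op.inputs
                 if t.op.name.startswith(var.op.name))   # tensor next_op reads from var
    updated_t = tf.add(old_t, delta_placeholder,
                       name=var.op.name + '_plus_delta')
    # Reroute only next_op, so the add itself keeps reading the original tensor
    # (rerouting every consumer would create a cycle through updated_t).
    ge.reroute_ts([updated_t], [old_t], can_modify=[next_op])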
def gradients(ys, xs,   # pylint: disable=too-many-statements, too-many-branches
              grad_ys=None, checkpoints='collection', **kwargs):
    '''
    Authors: Tim Salimans & Yaroslav Bulatov

    memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory
    Cost" by Chen et al. 2016 (https://arxiv.org/abs/1604.06174)

    ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients
    (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients)

    'checkpoints' can either be
        - a list consisting of tensors from the forward pass of the neural net
          that we should re-use when calculating the gradients in the backward pass
          all other tensors that do not appear in this list will be re-computed
        - a string specifying how this list should be determined. currently we support
            - 'speed':  checkpoint all outputs of convolutions and matmuls. these ops are usually
                        the most expensive, so checkpointing them maximizes the running speed
                        (this is a good option if nonlinearities, concats, batchnorms, etc are
                        taking up a lot of memory)
            - 'memory': try to minimize the memory usage
                        (currently using a very simple strategy that identifies a number of
                        bottleneck tensors in the graph to checkpoint)
            - 'collection': look for a tensorflow collection named 'checkpoints', which holds the
                            tensors to checkpoint
    '''

    #    print("Calling memsaving gradients with", checkpoints)
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys],
                                       inclusive=True)

    debug_print("bwd_ops: {}".format(bwd_ops))

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    debug_print("fwd_ops: {}".format(fwd_ops))

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if op not in xs_ops]
    fwd_ops = [op for op in fwd_ops if '/assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/Assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/read' not in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    # construct list of tensors to checkpoint during forward pass, if not
    # given as input
    if type(checkpoints) is not list:
        if checkpoints == 'collection':
            checkpoints = tf.get_collection('checkpoints')

        elif checkpoints == 'speed':
            # checkpoint all expensive ops to maximize running speed
            checkpoints = ge.filter_ts_from_regex(fwd_ops, 'conv2d|Conv|MatMul')

        elif checkpoints == 'memory':

            # remove very small tensors and some weird ops
            def fixdims(t):  # tf.Dimension values are not compatible with int, convert manually
                try:
                    return [int(e if e.value is not None else 64) for e in t]
                except:
                    return [0]  # unknown shape
            ts_all = [t for t in ts_all if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE]
            ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
            ts_all = [t for t in ts_all if 'entropy' not in t.name]
            ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
            ts_all = [t for t in ts_all if 'Switch' not in t.name]
            ts_all = [t for t in ts_all if 'dropout' not in t.name]
            # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
            ts_all = [t for t in ts_all if 'Cast' not in t.name]

            # filter out all tensors that are inputs of the backward graph
            with util.capture_ops() as bwd_ops:
                tf_gradients(ys, xs, grad_ys, **kwargs)

            bwd_inputs = [t for op in bwd_ops for t in op.inputs]
    # list of tensors in the forward graph that are inputs to the bwd graph
            ts_filtered = list(set(bwd_inputs).intersection(ts_all))
            debug_print("Using tensors {}".format(ts_filtered))

            # try two slightly different ways of getting bottlenecks tensors
            # to checkpoint
            for ts in [ts_filtered, ts_all]:

                # get all bottlenecks in the graph
                bottleneck_ts = []
                for t in ts:
                    b = set(ge.get_backward_walk_ops(t.op, inclusive=True, within_ops=fwd_ops))
                    f = set(ge.get_forward_walk_ops(t.op, inclusive=False, within_ops=fwd_ops))
                    # check that there are no shortcuts
                    b_inp = set([inp for op in b for inp in op.inputs]).intersection(ts_all)
                    f_inp = set([inp for op in f for inp in op.inputs]).intersection(ts_all)
                    if not set(b_inp).intersection(f_inp) and len(b_inp)+len(f_inp) >= len(ts_all):
                        bottleneck_ts.append(t)  # we have a bottleneck!
                    else:
                        debug_print("Rejected bottleneck candidate and ops {}".format(
                            [t] + list(set(ts_all) - set(b_inp) - set(f_inp))))

                # success? or try again without filtering?
                if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)):  # enough bottlenecks found!
                    break

            if not bottleneck_ts:
                raise Exception('unable to find bottleneck tensors! please provide checkpoint '
                                'nodes manually, or use checkpoints="speed".')

            # sort the bottlenecks
            bottlenecks_sorted_lists = tf_toposort(bottleneck_ts, within_ops=fwd_ops)
            sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts]

            # save an approximately optimal number ~ sqrt(N)
            N = len(ts_filtered)
            if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
                checkpoints = sorted_bottlenecks
            else:
                step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
                checkpoints = sorted_bottlenecks[step::step]

        else:
            raise Exception('%s is unsupported input for "checkpoints"' % (checkpoints,))

    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection has happened and checkpoints is a list of nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: {}".format(checkpoints))
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print("Warning, some input nodes are also checkpoint nodes: {}".format(
            xs_intersect_checkpoints))
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print("ys: {}, checkpoints:{}, intersect: {}".format(
        ys, checkpoints, ys_intersect_checkpoints))
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print("Warning, some output nodes are also checkpoints nodes: {}".format(
            format_ops(ys_intersect_checkpoints)))

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    if not checkpoints:
        raise Exception('no checkpoint nodes found or given as input!')

    # disconnect dependencies between checkpointed tensors
    checkpoints_disconnected = {}
    for x in checkpoints:
        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name+"_sg")
        else:
            grad_node = tf.stop_gradient(x)
        checkpoints_disconnected[x] = grad_node

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints, within_ops=fwd_ops)
    debug_print("Found {} ops to copy within fwd_ops {}, seed {}, stop_at {}".format(
        len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints))
    debug_print("ops_to_copy = {}".format(ops_to_copy))
    debug_print("Processing list {}".format(ys))
    _, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    debug_print("Copied {} to {}".format(ops_to_copy, copied_ops))
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print("Rewired {} in place of {} restricted to {}".format(
        checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops))

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys, xs=boundary+xs, grad_ys=grad_ys, **kwargs)
    debug_print("Got gradients {}".format(dv))
    debug_print("for %s", copied_ys)
    debug_print("with respect to {}".format(boundary+xs))

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {r: dr for r, dr in zip(checkpoints_disconnected.keys(),
                                            dv[:len(checkpoints_disconnected)])}
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list {}".format(ts))
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [checkpoints_disconnected[r] for r in checkpoints_other]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoint nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print("Found {} ops to copy within {}, seed {}, stop_at {}".format(
            len(ops_to_copy), fwd_ops, [r.op for r in ts], checkpoints_other))
        debug_print("ops_to_copy = {}".format(ops_to_copy))
        if not ops_to_copy:  # we're done!
            break
        _, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        debug_print("Copied {} to {}".format(ops_to_copy, copied_ops))
        ge.reroute_ts(checkpoints_disconnected_other, checkpoints_other, can_modify=copied_ops)
        debug_print("Rewired %s in place of %s restricted to %s",
                    checkpoints_disconnected_other, checkpoints_other, copied_ops)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other+xs,
                          grad_ys=substitute_backprops, **kwargs)
        debug_print("Got gradients {}".format(dv))
        debug_print("for {}".format(boundary))
        debug_print("with respect to {}".format(checkpoints_disconnected_other+xs))
        debug_print("with boundary backprop substitutions {}".format(substitute_backprops))

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        def _unsparsify(var_x):
            if not isinstance(var_x, tf.IndexedSlices):
                return var_x
            assert var_x.dense_shape is not None, \
                "memory_saving_gradients encountered sparse gradients of unknown shape"
            indices = var_x.indices
            while indices.shape.ndims < var_x.values.shape.ndims:
                indices = tf.expand_dims(indices, -1)
            return tf.scatter_nd(indices, var_x.values, var_x.dense_shape)

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])

    return d_xs
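# Minimal usage sketch, not part of the original example: the gradients()
# defined above is meant as a drop-in replacement for tf.gradients, so a
# training setup only swaps the gradient call. The small network and the
# optimizer choice below are illustrative; 'memory' and 'collection' are used
# at the call site in exactly the same way as 'speed'.
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 32])
y = tf.placeholder(tf.float32, [None, 1])
h = tf.layers.dense(x, 64, activation=tf.nn.relu)
pred = tf.layers.dense(h, 1)
loss = tf.reduce_mean(tf.square(pred - y))

params = tf.trainable_variables()
grads = gradients(loss, params, checkpoints='speed')   # instead of tf.gradients(loss, params)
train_op = tf.train.AdamOptimizer(1e-3).apply_gradients(zip(grads, params))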
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
    '''
    Authors: Tim Salimans & Yaroslav Bulatov

    memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost"
    by Chen et al. 2016 (https://arxiv.org/abs/1604.06174)

    ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients
    (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients)

    'checkpoints' can either be
        - a list consisting of tensors from the forward pass of the neural net
          that we should re-use when calculating the gradients in the backward pass
          all other tensors that do not appear in this list will be re-computed
        - a string specifying how this list should be determined. currently we support
            - 'speed':  checkpoint all outputs of convolutions and matmuls. these ops are usually the most expensive,
                        so checkpointing them maximizes the running speed
                        (this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory)
            - 'memory': try to minimize the memory usage
                        (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint)
            - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint
    '''

    #    print("Calling memsaving gradients with", checkpoints)
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True)

    debug_print("bwd_ops: %s", bwd_ops)

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    debug_print("fwd_ops: %s", fwd_ops)

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if op not in xs_ops]
    fwd_ops = [op for op in fwd_ops if '/assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/Assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/read' not in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    # construct list of tensors to checkpoint during forward pass, if not
    # given as input
    if type(checkpoints) is not list:
        if checkpoints == 'collection':
            checkpoints = tf.get_collection('checkpoints')

        elif checkpoints == 'speed':
            # checkpoint all expensive ops to maximize running speed
            checkpoints = ge.filter_ts_from_regex(fwd_ops,
                                                  'conv2d|Conv|MatMul')

        elif checkpoints == 'memory':

            # remove very small tensors and some weird ops
            # tf.Dimension values are not compatible with int, convert manually
            def fixdims(t):
                try:
                    return [int(e if e.value is not None else 64) for e in t]
                except:
                    return [0]  # unknown shape

            ts_all = [
                t for t in ts_all
                if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE
            ]
            ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
            ts_all = [t for t in ts_all if 'entropy' not in t.name]
            ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
            ts_all = [t for t in ts_all if 'Switch' not in t.name]
            ts_all = [t for t in ts_all if 'dropout' not in t.name]
            # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
            ts_all = [t for t in ts_all if 'Cast' not in t.name]

            # filter out all tensors that are inputs of the backward graph
            with util.capture_ops() as bwd_ops:
                tf_gradients(ys, xs, grad_ys, **kwargs)

            bwd_inputs = [t for op in bwd_ops for t in op.inputs]
            # list of tensors in the forward graph that are inputs to the bwd graph
            ts_filtered = list(set(bwd_inputs).intersection(ts_all))
            debug_print("Using tensors %s", ts_filtered)

            # try two slightly different ways of getting bottlenecks tensors
            # to checkpoint
            for ts in [ts_filtered, ts_all]:

                # get all bottlenecks in the graph
                bottleneck_ts = []
                for t in ts:
                    b = set(
                        ge.get_backward_walk_ops(t.op,
                                                 inclusive=True,
                                                 within_ops=fwd_ops))
                    f = set(
                        ge.get_forward_walk_ops(t.op,
                                                inclusive=False,
                                                within_ops=fwd_ops))
                    # check that there are no shortcuts
                    b_inp = set([inp for op in b
                                 for inp in op.inputs]).intersection(ts_all)
                    f_inp = set([inp for op in f
                                 for inp in op.inputs]).intersection(ts_all)
                    if not set(b_inp).intersection(
                            f_inp) and len(b_inp) + len(f_inp) >= len(ts_all):
                        bottleneck_ts.append(t)  # we have a bottleneck!
                    else:
                        debug_print(
                            "Rejected bottleneck candidate and ops %s",
                            [t] + list(set(ts_all) - set(b_inp) - set(f_inp)))

                # success? or try again without filtering?
                if len(bottleneck_ts) >= np.sqrt(
                        len(ts_filtered)):  # yes, enough bottlenecks found!
                    break

            if not bottleneck_ts:
                raise Exception(
                    'unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".'
                )

            # sort the bottlenecks
            bottlenecks_sorted_lists = tf_toposort(bottleneck_ts,
                                                   within_ops=fwd_ops)
            sorted_bottlenecks = [
                t for ts in bottlenecks_sorted_lists for t in ts
            ]

            # save an approximately optimal number ~ sqrt(N)
            N = len(ts_filtered)
            if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
                checkpoints = sorted_bottlenecks
            else:
                step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
                checkpoints = sorted_bottlenecks[step::step]

        else:
            raise Exception('%s is unsupported input for "checkpoints"' %
                            (checkpoints, ))

    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection has happened and checkpoints is a list of nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: %s", checkpoints)
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print("Warning, some input nodes are also checkpoint nodes: %s",
                    xs_intersect_checkpoints)
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints,
                ys_intersect_checkpoints)
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print(
            "Warning, some output nodes are also checkpoints nodes: %s",
            format_ops(ys_intersect_checkpoints))

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    # if not checkpoints:
    #     raise Exception('no checkpoints nodes found or given as input! ')

    # disconnect dependencies between checkpointed tensors
    checkpoints_disconnected = {}
    for x in checkpoints:
        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
        else:
            grad_node = tf.stop_gradient(x)
        checkpoints_disconnected[x] = grad_node

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints,
                                    within_ops=fwd_ops)
    debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s",
                len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints)
    debug_print("ops_to_copy = %s", ops_to_copy)
    debug_print("Processing list %s", ys)
    copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    debug_print("Copied %s to %s", ops_to_copy, copied_ops)
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print("Rewired %s in place of %s restricted to %s",
                checkpoints_disconnected.values(),
                checkpoints_disconnected.keys(), copied_ops)

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys,
                      xs=boundary + xs,
                      grad_ys=grad_ys,
                      **kwargs)
    debug_print("Got gradients %s", dv)
    debug_print("for %s", copied_ys)
    debug_print("with respect to %s", boundary + xs)

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {
        r: dr
        for r, dr in zip(checkpoints_disconnected.keys(),
                         dv[:len(checkpoints_disconnected)])
    }
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list %s", ts)
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [
            checkpoints_disconnected[r] for r in checkpoints_other
        ]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoint nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print("Found %s ops to copy within %s, seed %s, stop_at %s",
                    len(ops_to_copy), fwd_ops, [r.op for r in ts],
                    checkpoints_other)
        debug_print("ops_to_copy = %s", ops_to_copy)
        if not ops_to_copy:  # we're done!
            break
        copied_sgv, info = ge.copy_with_input_replacements(
            ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        debug_print("Copied %s to %s", ops_to_copy, copied_ops)
        ge.reroute_ts(checkpoints_disconnected_other,
                      checkpoints_other,
                      can_modify=copied_ops)
        debug_print("Rewired %s in place of %s restricted to %s",
                    checkpoints_disconnected_other, checkpoints_other,
                    copied_ops)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops,
                          **kwargs)
        debug_print("Got gradients %s", dv)
        debug_print("for %s", boundary)
        debug_print("with respect to %s", checkpoints_disconnected_other + xs)
        debug_print("with boundary backprop substitutions %s",
                    substitute_backprops)

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        def _unsparsify(x):
            if not isinstance(x, tf.IndexedSlices):
                return x
            assert x.dense_shape is not None, "memory_saving_gradients encountered sparse gradients of unknown shape"
            indices = x.indices
            while indices.shape.ndims < x.values.shape.ndims:
                indices = tf.expand_dims(indices, -1)
            return tf.scatter_nd(indices, x.values, x.dense_shape)

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])

    return d_xs
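# Sketch of the 'collection' mode (not part of the original example; layer
# sizes and names are illustrative): tensors to keep are tagged in a TF
# collection named 'checkpoints' while the graph is built, and gradients()
# reads them back via tf.get_collection('checkpoints').
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 128])
h = x
for i in range(4):
    h = tf.layers.dense(h, 128, activation=tf.nn.relu, name='fc%d' % i)
    if i % 2 == 1:
        tf.add_to_collection('checkpoints', h)   # remember every other activation
loss = tf.reduce_mean(tf.square(h))

grads = gradients(loss, tf.trainable_variables(), checkpoints='collection')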
    def _get_checkpoint(self):
        # exclude ops with no inputs
        self.fwd_ops = [op for op in self.fwd_ops if op.inputs]

        # don't recompute xs, remove variables
        xs_ops = self._to_ops(self._xs)
        self.fwd_ops = [op for op in self.fwd_ops if op not in xs_ops]
        self.fwd_ops = [op for op in self.fwd_ops if '/assign' not in op.name]
        self.fwd_ops = [op for op in self.fwd_ops if '/Assign' not in op.name]
        self.fwd_ops = [op for op in self.fwd_ops if '/read' not in op.name]
        ts_all = ge.filter_ts(self.fwd_ops, True)  # get the tensors
        ts_all = [t for t in ts_all if '/read' not in t.name]
        ts_all = set(ts_all) - set(self._xs) - set(self._ys)

        # construct list of tensors to checkpoint during forward pass, if not
        # given as input
        # remove very small tensors and some weird ops
        # tf.Dimension values are not compatible with int, convert manually
        def fixdims(t):
            try:
                return [int(e if e.value is not None else 64) for e in t]
            except:
                return [0]  # unknown shape

        ts_all = [
            t for t in ts_all
            if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE
        ]
        ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
        ts_all = [t for t in ts_all if 'entropy' not in t.name]
        ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
        ts_all = [t for t in ts_all if 'Switch' not in t.name]
        ts_all = [t for t in ts_all if 'dropout' not in t.name]
        # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
        ts_all = [t for t in ts_all if 'Cast' not in t.name]

        self._log_info("ts_all, {}".format(ts_all))
        # filter out all tensors that are inputs of the backward graph
        with util.capture_ops() as bwd_ops:
            tf_gradients(self._ys, self._xs, self._grad_ys, **self._kwargs)

        bwd_inputs = [t for op in self.bwd_ops for t in op.inputs]

        self._log_info("bwd_inputs, {}".format(bwd_inputs))
        # list of tensors in forward graph that is in input to bwd graph
        ts_filtered = list(set(bwd_inputs).intersection(ts_all))
        self._log_info("Using tensors %{}".format(ts_filtered), 2)

        # try two slightly different ways of getting bottlenecks tensors
        # to checkpoint
        for ts in [ts_filtered, ts_all]:

            # get all bottlenecks in the graph
            bottleneck_ts = []
            for t in ts:
                b = set(
                    ge.get_backward_walk_ops(t.op,
                                             inclusive=True,
                                             within_ops=self.fwd_ops))
                f = set(
                    ge.get_forward_walk_ops(t.op,
                                            inclusive=False,
                                            within_ops=self.fwd_ops))
                # check that there are no shortcuts
                b_inp = set([inp for op in b
                             for inp in op.inputs]).intersection(ts_all)
                f_inp = set([inp for op in f
                             for inp in op.inputs]).intersection(ts_all)
                if not set(b_inp).intersection(
                        f_inp) and len(b_inp) + len(f_inp) >= len(ts_all):
                    bottleneck_ts.append(t)  # we have a bottleneck!
                else:
                    self._log_info(
                        "Rejected bottleneck candidate and ops {}".format(
                            [t] + list(set(ts_all) - set(b_inp) - set(f_inp))),
                        3)

            # success? or try again without filtering?
            if len(bottleneck_ts) >= np.sqrt(
                    len(ts_filtered)):  # yes, enough bottlenecks found!
                break

        if not bottleneck_ts:
            raise Exception(
                'unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".'
            )

        # sort the bottlenecks
        bottlenecks_sorted_lists = self.tf_toposort(bottleneck_ts,
                                                    within_ops=self.fwd_ops)
        sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts]

        # save an approximately optimal number ~ sqrt(N)
        N = len(ts_filtered)
        if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
            checkpoints = sorted_bottlenecks
        else:
            step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
            checkpoints = sorted_bottlenecks[step::step]

        checkpoints = list(set(checkpoints).intersection(ts_all))

        return checkpoints
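# Toy illustration of the sqrt(N) thinning performed above (all numbers are
# made up): with N = 100 filtered tensors and 40 topologically sorted
# bottlenecks, step = ceil(40 / sqrt(100)) = 4, so roughly every 4th
# bottleneck is kept and about sqrt(N) checkpoints remain.
import numpy as np

sorted_bottlenecks = ['t%02d' % i for i in range(40)]
N = 100
step = int(np.ceil(len(sorted_bottlenecks) / np.sqrt(N)))   # 4
checkpoints = sorted_bottlenecks[step::step]                # every 4th tensor
print(len(checkpoints), checkpoints[:3])                    # 9 ['t04', 't08', 't12']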