Example #1
def cut_graph_def(graph_def, cut_nodes):
    """Cut groph_def to two parts by cut_nodes. All ancesters  of cut_nodes are put
    into back and the rest are put into head.

    Args:
        graph_def: input tf.GraphDef
        cut_nodes: a list of node names to cut
    """
    # back
    back = tf.graph_util.extract_sub_graph(graph_def, cut_nodes)

    # head
    head_node_names = [n.name for n in graph_def.node if n not in back.node]
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
        all_ops = make_list_of_op(graph)
    head_ops = [o for o in all_ops if o.name in head_node_names]
    head_subgraph = SubGraphView(inside_ops=head_ops)

    head_graph = tf.Graph()
    replace_ts = {}
    for i in head_subgraph.inputs:
        k = i.name
        replace_ts[k] = make_placeholder_from_tensor(i)
    copy_with_input_replacements(
        head_subgraph,
        replace_ts,
        dst_graph=head_graph
    )

    # return
    return back, head_graph.as_graph_def()
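A minimal usage sketch for cut_graph_def (a sketch, not code from the original repo): it assumes a TF 1.x frozen model on disk, and that make_list_of_op, SubGraphView, make_placeholder_from_tensor and copy_with_input_replacements are imported from tf.contrib.graph_editor. The file path and node name are illustrative.

import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile("frozen_model.pb", "rb") as f:  # illustrative path
    graph_def.ParseFromString(f.read())

# Everything feeding "embeddings" ends up in back_def; the remaining nodes go to
# head_def, whose dangling inputs are replaced by placeholders.
back_def, head_def = cut_graph_def(graph_def, ["embeddings"])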
Example #2
    def test_copy_with_input_replacements(self):
        with self.graph.as_default():
            ten = tf.constant(10.0, shape=[10], name="Input")
            sgv, _ = ge.copy_with_input_replacements(
                self.o.op, {self.o.op.inputs[1]: ten})
            with tf.Session() as sess:
                val = sess.run(sgv.outputs[0])
            self.assertNear(np.linalg.norm(val - np.array([11])), 0.0,
                            ERROR_TOLERANCE)
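The fixture behind this test is not shown; one plausible setUp (an assumption, not the original code) builds self.graph and self.o roughly as follows:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    a = tf.constant(1.0, shape=[10], name="a")
    b = tf.constant(2.0, shape=[10], name="b")
    o = tf.add(a, b, name="o")  # o.op.inputs[1] is b

With self.graph = graph and self.o = o, replacing the second input of o with the all-tens constant gives 1 + 10 = 11 elementwise, which is exactly what the assertion checks.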
    def _clone_model(self, model, perturbations, dst_scope):
        '''Make a copy of `model` and connect the resulting sub-graph to the
            input ops of the original graph and to the parameter assignments
            produced by the perturbator.
        '''
        def not_placeholder_or_trainvar_filter(op):
            # print(op.name)
            if op.type == 'Placeholder':              # evaluation sub-graphs will be fed from original placeholders
                return False
            for var_name in self.tvars:
                if op.name.startswith(var_name):      # remove Some/Var/(read,assign,...) -- will be replaced with perturbations
                    return False
            return True

        ops_without_inputs = ge.filter_ops(model.ops, not_placeholder_or_trainvar_filter)
        # print("ModelOPS=========================")
        # for o in ops_without_inputs:
        #     print(o.name, o.type)
        # remove init op from clone if already present
        try:
            ops_without_inputs.remove(self.work_graph.get_operation_by_name("init"))
        except (KeyError, ValueError):
            # no "init" op in the graph, or it is not part of the clone
            pass
        clone_sgv = ge.make_view(ops_without_inputs)
        clone_sgv = clone_sgv.remove_unused_ops(control_inputs=True)

        input_replacements = {}
        for t in clone_sgv.inputs:
            if t.name in perturbations.keys():                  # input from trainable var --> replace with perturbation
                input_replacements[t] = perturbations[t.name]
            else:                                               # otherwise take input from original graph
                input_replacements[t] = self.work_graph.get_tensor_by_name(t.name)
        return ge.copy_with_input_replacements(clone_sgv, input_replacements, dst_scope=dst_scope)
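The same filter-then-copy pattern, shown standalone on a toy graph (a sketch with illustrative names, not code from the original project): keep every op except placeholders, wrap them in a SubGraphView, and copy the view into a new scope while rerouting its dangling inputs to replacement tensors.

import tensorflow as tf
import tensorflow.contrib.graph_editor as ge

x = tf.placeholder(tf.float32, [None, 3], name="x")
y = tf.identity(x * 2.0 + 1.0, name="y")

# keep everything except the placeholder, so x:0 becomes a dangling input of the view
keep = ge.filter_ops(tf.get_default_graph().get_operations(),
                     lambda op: op.type != 'Placeholder')
sgv = ge.make_view(keep).remove_unused_ops(control_inputs=True)

x_new = tf.placeholder(tf.float32, [None, 3], name="x_new")
replacements = {t: x_new for t in sgv.inputs}

_, info = ge.copy_with_input_replacements(sgv, replacements, dst_scope='clone')
y_clone = info.transformed(y)  # same computation, now fed from x_new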
Example #5
def _duplicate_layer(layer_name,
                     layer_sgv,
                     branch_name,
                     add_to_collections=True):
    """Duplicates a network layer, while preserving connections.

    Args:
      layer_name:         a layer is identified by its name scope
      layer_sgv:          SubgraphView (see tf.contrib.graph_editor)
      branch_name:        the duplicate is "layer_name + branch_name"
      add_to_collections: add duplicate vars to the same collections

    Returns:
      info:            see the return values of `tf.contrib.graph_editor.copy`
      var_duplication: a list of tuples (var, dup_of_var)
    """

    if layer_name[-1] == '/':
        new_layer_name = layer_name[:-1] + branch_name + '/'
    else:
        new_layer_name = layer_name + branch_name

    replacement_ts = {}
    for op in layer_sgv.inputs:
        replacement_ts[op] = op

    duplicate_sgv, info = ge.copy_with_input_replacements(
        layer_sgv,
        replacement_ts=replacement_ts,
        src_scope=layer_name,
        dst_scope=new_layer_name)

    var_duplication = []
    for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
        if layer_name not in v.name:
            continue
        vproto = v.to_proto()
        new_vardef = variable_pb2.VariableDef()
        for field, val in vproto.ListFields():
            if isinstance(val, str):
                new_val = val.replace(layer_name, new_layer_name)
            else:
                new_val = val
            setattr(new_vardef, field.name, new_val)
        new_var = tf.Variable(variable_def=new_vardef)
        tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, new_var)
        var_duplication.append((v, new_var))

        if add_to_collections:
            for k in tf.get_default_graph().get_all_collection_keys():
                collection = tf.get_collection(k)
                if v in collection and new_var not in collection:
                    tf.add_to_collection(k, new_var)

    return info, var_duplication
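A hedged usage sketch for _duplicate_layer (scope name and branch suffix are illustrative; variable_pb2 in the function above is assumed to be tensorflow.core.framework.variable_pb2):

import tensorflow as tf
import tensorflow.contrib.graph_editor as ge

# select every op under the layer's name scope as a SubGraphView
layer_sgv = ge.sgv_scope("model/fc1", graph=tf.get_default_graph())

# duplicate it as "model/fc1_branch2", reusing the original inputs
info, var_pairs = _duplicate_layer("model/fc1/", layer_sgv, "_branch2")
# var_pairs is a list of (original_var, duplicated_var) tuples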
Example #6
def clone_subgraph(outputs, mappings, clone_scope=''):
    NON_REPLICABLE = {
        'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
        'MutableHashTableV2', 'MutableHashTableOfTensors',
        'MutableHashTableOfTensorsV2', 'MutableDenseHashTable',
        'MutableDenseHashTableV2', 'VarHandleOp',
        'BoostedTreesEnsembleResourceHandleOp'
    }
    ops = ge.get_backward_walk_ops(outputs, stop_at_ts=mappings.keys())
    ops_replicate = [op for op in ops if op.type not in NON_REPLICABLE]
    sgv = ge.make_view(*ops_replicate)
    _, info = ge.copy_with_input_replacements(sgv,
                                              mappings,
                                              dst_scope=clone_scope)
    return info.transformed(outputs)
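A hedged usage sketch (tensor and scope names are illustrative). Because variable ops are listed in NON_REPLICABLE, the clone keeps reading the original variables; only the stateless computation is duplicated under clone_scope.

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4], name="x")
y = tf.layers.dense(x, 2, name="head")  # creates kernel/bias variables
x_new = tf.placeholder(tf.float32, [None, 4], name="x_new")

# re-express y in terms of x_new; the dense layer's weights stay shared
y_clone, = clone_subgraph([y], {x: x_new}, clone_scope="clone")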
Example #7
def recompute_tensor(target, known_values, preceding_op=None,
                     copy_known_values=False):
  """Computes target tensor from known_values. If preceding_op is not None,
  adds necessary control dependencies such that newly created computation takes
  place after preceding_op. 

  If copy_known_values is set, also copies known_values (for nicer graph
  visualization)
  """

  assert is_computable(target, known_values)
  
  # position of target in parent op
  target_pos = list(target.op.outputs).index(target)

  if copy_known_values:
    computation = ge.get_backward_walk_ops(target)
  else:
    computation = ge.get_backward_walk_ops(target, stop_at_ts=known_values)
    
  # create copy of computation
  copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(computation), {})

  # find our target tensor in the new computation
  new_target_op = info._transformed_ops[target.op]
  new_target = new_target_op.outputs[target_pos]
  new_computation = list(info._transformed_ops.values())

  # restrict computation to run after given op
  SAVE_ON_CONTROL_EDGES = True

  if SAVE_ON_CONTROL_EDGES:
    # only add "run_after" control dependencies to root of computation,
    # the rest automatically runs after because of data dependencies
    # TODO: more efficient implementation by walking back from new_target
    # instead of whole graph
    computation_graph = linearize_lib.get_graph(restrict_to=new_computation)

    # note, toposort order is reversed from networkx/mine convention
    computation_root = list(toposort.toposort(computation_graph))[-1]
    for op in computation_root:
      run_after(op, preceding_op)
  else:
    if preceding_op is not None:
      for op in info._transformed_ops.values():
        run_after(op, preceding_op)
  return new_target
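A hedged call sketch (names illustrative; is_computable, run_after and linearize_lib are assumed to be helpers defined elsewhere in the same module):

# recompute `activation` from the checkpointed tensors, scheduled to run after barrier_op
recomputed = recompute_tensor(activation,
                              known_values=[ckpt1, ckpt2],
                              preceding_op=barrier_op)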
Example #8
    def _duplicate_graph(self, graph, vars_to_replace, name='Duplicated'):
        """
        Duplicates loss graph with swapped variables.
        :return: Swapped graph.
        """
        if graph in vars_to_replace:
            return vars_to_replace[graph]

        operations = []

        def get_ops(t):
            if t.op.type != 'VariableV2' and t.op.type != 'Placeholder':
                operations.append(t.op)
                for i in t.op.inputs:
                    if i not in vars_to_replace:
                        get_ops(i)

        get_ops(graph)

        sgv = graph_editor.make_view(operations)
        with ops.name_scope(name):
            new_view, _ = graph_editor.copy_with_input_replacements(
                sgv, vars_to_replace)
            return new_view.outputs[sgv.output_index(graph)]
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
    '''
    Authors: Tim Salimans & Yaroslav Bulatov

    memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost"
    by Chen et al. 2016 (https://arxiv.org/abs/1604.06174)

    ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients
    (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients)

    'checkpoints' can either be
        - a list consisting of tensors from the forward pass of the neural net
          that we should re-use when calculating the gradients in the backward pass;
          all other tensors that do not appear in this list will be re-computed
        - a string specifying how this list should be determined. currently we support
            - 'speed':  checkpoint all outputs of convolutions and matmuls. these ops are usually the most expensive,
                        so checkpointing them maximizes the running speed
                        (this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory)
            - 'memory': try to minimize the memory usage
                        (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint)
            - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint
    '''

    #    print("Calling memsaving gradients with", checkpoints)
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True)

    debug_print("bwd_ops: %s", bwd_ops)

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    debug_print("fwd_ops: %s", fwd_ops)

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if op not in xs_ops]
    fwd_ops = [op for op in fwd_ops if '/assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/Assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/read' not in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    # construct list of tensors to checkpoint during forward pass, if not
    # given as input
    if type(checkpoints) is not list:
        if checkpoints == 'collection':
            checkpoints = tf.get_collection('checkpoints')

        elif checkpoints == 'speed':
            # checkpoint all expensive ops to maximize running speed
            checkpoints = ge.filter_ts_from_regex(fwd_ops,
                                                  'conv2d|Conv|MatMul')

        elif checkpoints == 'memory':

            # remove very small tensors and some weird ops
            def fixdims(
                t
            ):  # tf.Dimension values are not compatible with int, convert manually
                try:
                    return [int(e if e.value is not None else 64) for e in t]
                except:
                    return [0]  # unknown shape

            ts_all = [
                t for t in ts_all
                if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE
            ]
            ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
            ts_all = [t for t in ts_all if 'entropy' not in t.name]
            ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
            ts_all = [t for t in ts_all if 'Switch' not in t.name]
            ts_all = [t for t in ts_all if 'dropout' not in t.name]
            # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
            ts_all = [t for t in ts_all if 'Cast' not in t.name]

            # filter out all tensors that are inputs of the backward graph
            with util.capture_ops() as bwd_ops:
                tf_gradients(ys, xs, grad_ys, **kwargs)

            bwd_inputs = [t for op in bwd_ops for t in op.inputs]
            # list of tensors in forward graph that is in input to bwd graph
            ts_filtered = list(set(bwd_inputs).intersection(ts_all))
            debug_print("Using tensors %s", ts_filtered)

            # try two slightly different ways of getting bottlenecks tensors
            # to checkpoint
            for ts in [ts_filtered, ts_all]:

                # get all bottlenecks in the graph
                bottleneck_ts = []
                for t in ts:
                    b = set(
                        ge.get_backward_walk_ops(t.op,
                                                 inclusive=True,
                                                 within_ops=fwd_ops))
                    f = set(
                        ge.get_forward_walk_ops(t.op,
                                                inclusive=False,
                                                within_ops=fwd_ops))
                    # check that there are not shortcuts
                    b_inp = set([inp for op in b
                                 for inp in op.inputs]).intersection(ts_all)
                    f_inp = set([inp for op in f
                                 for inp in op.inputs]).intersection(ts_all)
                    if not set(b_inp).intersection(
                            f_inp) and len(b_inp) + len(f_inp) >= len(ts_all):
                        bottleneck_ts.append(t)  # we have a bottleneck!
                    else:
                        debug_print(
                            "Rejected bottleneck candidate and ops %s",
                            [t] + list(set(ts_all) - set(b_inp) - set(f_inp)))

                # success? or try again without filtering?
                if len(bottleneck_ts) >= np.sqrt(
                        len(ts_filtered)):  # yes, enough bottlenecks found!
                    break

            if not bottleneck_ts:
                raise Exception(
                    'unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".'
                )

            # sort the bottlenecks
            bottlenecks_sorted_lists = tf_toposort(bottleneck_ts,
                                                   within_ops=fwd_ops)
            sorted_bottlenecks = [
                t for ts in bottlenecks_sorted_lists for t in ts
            ]

            # save an approximately optimal number ~ sqrt(N)
            N = len(ts_filtered)
            if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
                checkpoints = sorted_bottlenecks
            else:
                step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
                checkpoints = sorted_bottlenecks[step::step]

        else:
            raise Exception('%s is unsupported input for "checkpoints"' %
                            (checkpoints, ))

    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection happened and checkpoints is list of nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: %s", checkpoints)
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print("Warning, some input nodes are also checkpoint nodes: %s",
                    xs_intersect_checkpoints)
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints,
                ys_intersect_checkpoints)
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print(
            "Warning, some output nodes are also checkpoints nodes: %s",
            format_ops(ys_intersect_checkpoints))

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    # if not checkpoints:
    #     raise Exception('no checkpoints nodes found or given as input! ')

    # disconnect dependencies between checkpointed tensors
    checkpoints_disconnected = {}
    for x in checkpoints:
        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
        else:
            grad_node = tf.stop_gradient(x)
        checkpoints_disconnected[x] = grad_node

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints,
                                    within_ops=fwd_ops)
    debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s",
                len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints)
    debug_print("ops_to_copy = %s", ops_to_copy)
    debug_print("Processing list %s", ys)
    copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    debug_print("Copied %s to %s", ops_to_copy, copied_ops)
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print("Rewired %s in place of %s restricted to %s",
                checkpoints_disconnected.values(),
                checkpoints_disconnected.keys(), copied_ops)

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys,
                      xs=boundary + xs,
                      grad_ys=grad_ys,
                      **kwargs)
    debug_print("Got gradients %s", dv)
    debug_print("for %s", copied_ys)
    debug_print("with respect to %s", boundary + xs)

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {
        r: dr
        for r, dr in zip(checkpoints_disconnected.keys(),
                         dv[:len(checkpoints_disconnected)])
    }
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list %s", ts)
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [
            checkpoints_disconnected[r] for r in checkpoints_other
        ]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoints nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print("Found %s ops to copy within %s, seed %s, stop_at %s",
                    len(ops_to_copy), fwd_ops, [r.op for r in ts],
                    checkpoints_other)
        debug_print("ops_to_copy = %s", ops_to_copy)
        if not ops_to_copy:  # we're done!
            break
        copied_sgv, info = ge.copy_with_input_replacements(
            ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        debug_print("Copied %s to %s", ops_to_copy, copied_ops)
        ge.reroute_ts(checkpoints_disconnected_other,
                      checkpoints_other,
                      can_modify=copied_ops)
        debug_print("Rewired %s in place of %s restricted to %s",
                    checkpoints_disconnected_other, checkpoints_other,
                    copied_ops)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops,
                          **kwargs)
        debug_print("Got gradients %s", dv)
        debug_print("for %s", boundary)
        debug_print("with respect to %s", checkpoints_disconnected_other + xs)
        debug_print("with boundary backprop substitutions %s",
                    substitute_backprops)

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        def _unsparsify(x):
            if not isinstance(x, tf.IndexedSlices):
                return x
            assert x.dense_shape is not None, "memory_saving_gradients encountered sparse gradients of unknown shape"
            indices = x.indices
            while indices.shape.ndims < x.values.shape.ndims:
                indices = tf.expand_dims(indices, -1)
            return tf.scatter_nd(indices, x.values, x.dense_shape)

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])

    return d_xs
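A hedged usage sketch of the checkpointing API described in the docstring; block1_out, block2_out, loss, params and optimizer are illustrative names from the surrounding model code.

# mark a few forward activations, then ask for memory-saving gradients
tf.add_to_collection('checkpoints', block1_out)
tf.add_to_collection('checkpoints', block2_out)
grads = gradients(loss, params, checkpoints='collection')
# alternatively, let the built-in heuristics pick the checkpoint tensors:
# grads = gradients(loss, params, checkpoints='memory')
train_op = optimizer.apply_gradients(list(zip(grads, params)))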
Example #10
    def apply(self, new_inputs, update_colocation_groups=True):
        assert len(new_inputs) == len(self.inputs)
        g = tf.get_default_graph()  # todo: make that member variable

        new_inputs2 = []
        # replace variable inputs with their read endpoint
        for input in new_inputs:
            if isinstance(input, tf.Variable):
                new_inputs2.append(input.read_value())
            else:
                new_inputs2.append(input)
        new_inputs = new_inputs2

        replacements = OrderedDict()
        for old_input_ts, new_input_ts in zip(self.inputs, new_inputs):
            # shape/dtype checks
            if isinstance(old_input_ts, (list, tuple)):
                reference_ts = old_input_ts[0]
            else:
                reference_ts = old_input_ts
            assert reference_ts.get_shape() == new_input_ts.get_shape()
            assert reference_ts.dtype == new_input_ts.dtype

            # Variable with multiple read endpoints, replace all of them with
            # new input tensor
            if isinstance(old_input_ts, (list, tuple)):
                for sub_input in old_input_ts:
                    replacements[sub_input] = new_input_ts
            # regular Tensor
            else:
                replacements[old_input_ts] = new_input_ts

        # sanity checks
        # copying Variables is confusing because they don't get added
        # to the GLOBAL_VARIABLES collection and hence escape global
        # initialization, therefore forbid it
        for op in self.ops:
            if op.type.startswith('Variable'):  # 'VariableV2' or 'Variable'
                assert False, "Can't copy variables"

        # TODO: remove this
        def summarize_ts(ts):
            from collections import Counter
            ops = set([tensor.op for tensor in ts])
            print(Counter([op.type for op in ops]).most_common(10))

        sgv = ge.sgv(self.ops)
        #    import pdb; pdb.set_trace()
        copied_sgv, info = ge.copy_with_input_replacements(sgv, replacements)

        # converting between Python bytes and unicode
        def to_bytes(s):
            return s.encode('ascii')

        def from_bytes(s):
            return s.decode('ascii')

        # fix colocation constraints to point to copied ops
        # see https://github.com/tensorflow/tensorflow/issues/9925
        if update_colocation_groups:
            new_ops = [info._transformed_ops[op] for op in self.ops]
            for new_op in new_ops:
                assert len(new_op.colocation_groups()) == 1
                colocation_group = new_op.colocation_groups()[0]
                assert colocation_group.startswith(b'loc:@')
                colocated_with_name = from_bytes(
                    colocation_group[len(b'loc:@'):])

                # if there were no colocation constraints, the op gets colocated with
                # itself (default colocation group), ignore that constraint
                if colocated_with_name == new_op.name:
                    continue

                colocation_op = g.get_operation_by_name(colocated_with_name)
                if colocation_op in info._transformed_ops:
                    new_colocation_op = info._transformed_ops[colocation_op]
                else:
                    assert colocation_op in self.input_ops
                    colocation_op_idx = self.input_ops.index(colocation_op)
                    new_colocation_op = new_inputs[colocation_op_idx].op

                # overwrite existing _class attribute with new colocation constraints
                new_colocation_groups = [
                    b'loc:@' + to_bytes(new_colocation_op.name)
                ]
                new_op.node_def.attr["_class"].CopyFrom(
                    attr_value_pb2.AttrValue(
                        list=attr_value_pb2.AttrValue.ListValue(
                            s=new_colocation_groups)))

        # place new ops on device from current device context
        device = get_current_device()
        if device:
            for op in info._transformed_ops.values():
                op._set_device(device)

        new_outputs = []
        for old_output_ts in self.outputs:
            new_output_op = info._transformed_ops[old_output_ts.op]
            new_output_ts = new_output_op.outputs[0]
            new_outputs.append(new_output_ts)

        return new_outputs
Example #11
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
    print("-------------------------------")
    debug_print("Editing model for OME")
    incpu_count = 0
    #    print("Calling memsaving gradients with", checkpoints)
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True)

    for index, op in enumerate(bwd_ops):
        debug_print("bwd_ops: [{}] :{}".format(index, op.name), 1)

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    for index, op in enumerate(fwd_ops):
        debug_print("fwd_ops: [{}] : {}".format(index, op.name), 1)

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if op not in xs_ops]
    fwd_ops = [op for op in fwd_ops if '/assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/Assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/read' not in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    # construct list of tensors to checkpoint during forward pass, if not
    # given as input
    if type(checkpoints) is not list:
        # remove very small tensors and some weird ops
        def fixdims(
            t
        ):  # tf.Dimension values are not compatible with int, convert manually
            try:
                return [int(e if e.value is not None else 64) for e in t]
            except:
                return [0]  # unknown shape

        ts_all = [
            t for t in ts_all
            if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE
        ]
        ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
        ts_all = [t for t in ts_all if 'entropy' not in t.name]
        ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
        ts_all = [t for t in ts_all if 'Switch' not in t.name]
        ts_all = [t for t in ts_all if 'dropout' not in t.name]
        # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
        ts_all = [t for t in ts_all if 'Cast' not in t.name]

        # filter out all tensors that are inputs of the backward graph
        with util.capture_ops() as bwd_ops:
            tf_gradients(ys, xs, grad_ys, **kwargs)

        bwd_inputs = [t for op in bwd_ops for t in op.inputs]
        # list of tensors in forward graph that is in input to bwd graph
        ts_filtered = list(set(bwd_inputs).intersection(ts_all))
        debug_print("Using tensors {}".format(ts_filtered), 1)

        # try two slightly different ways of getting bottlenecks tensors
        # to checkpoint
        for ts in [ts_filtered, ts_all]:

            # get all bottlenecks in the graph
            bottleneck_ts = []
            for t in ts:
                b = set(
                    ge.get_backward_walk_ops(t.op,
                                             inclusive=True,
                                             within_ops=fwd_ops))
                f = set(
                    ge.get_forward_walk_ops(t.op,
                                            inclusive=False,
                                            within_ops=fwd_ops))
                # check that there are not shortcuts
                b_inp = set([inp for op in b
                             for inp in op.inputs]).intersection(ts_all)
                f_inp = set([inp for op in f
                             for inp in op.inputs]).intersection(ts_all)
                if not set(b_inp).intersection(
                        f_inp) and len(b_inp) + len(f_inp) >= len(ts_all):
                    bottleneck_ts.append(t)  # we have a bottleneck!
                else:
                    debug_print(
                        "Rejected bottleneck candidate and ops {}".format(
                            [t] + list(set(ts_all) - set(b_inp) - set(f_inp))),
                        2)

            # success? or try again without filtering?
            if len(bottleneck_ts) >= np.sqrt(
                    len(ts_filtered)):  # yes, enough bottlenecks found!
                break

        if not bottleneck_ts:
            raise Exception(
                'unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".'
            )

        # sort the bottlenecks
        bottlenecks_sorted_lists = tf_toposort(bottleneck_ts,
                                               within_ops=fwd_ops)
        sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts]

        # save an approximately optimal number ~ sqrt(N)
        N = len(ts_filtered)
        if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
            checkpoints = sorted_bottlenecks
        else:
            step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
            checkpoints = sorted_bottlenecks[step::step]

    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection happened and checkpoints is list of nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: {}".format(checkpoints), 1)
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print(
            "Warning, some input nodes are also checkpoint nodes: {}".format(
                xs_intersect_checkpoints), 2)
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print(
        "ys: {}, checkpoints: {}, intersect: {}".format(
            ys, checkpoints, ys_intersect_checkpoints), 1)
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print(
            "Warning, some output nodes are also checkpoints nodes: {}".format(
                format_ops(ys_intersect_checkpoints)), 2)

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    if not checkpoints:
        raise Exception('no checkpoints nodes found or given as input! ')

    debug_print(
        "Select {} nodes to checkpoint nodes.".format(len(checkpoints)), 0)

    # disconnect dependencies between checkpointed tensors
    checkpoints_disconnected = {}
    for x in checkpoints:
        frontier_ops = set(graph_util.get_consuming_ops(x.op.outputs))
        debug_print("my frontier ops: {}".format(frontier_ops), 1)

        bw_frontier_ops = frontier_ops & set(bwd_ops)
        debug_print("my bw frontier ops: {}".format(bw_frontier_ops), 1)

        if len(bw_frontier_ops) > 1:
            continue

        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
        else:
            grad_node = tf.stop_gradient(x)

        swapout_op = _add_swapout(grad_node.op, grad_node.op.outputs)
        incpu_count = incpu_count + 1
        swapin_op = _add_swapin(swapout_op, bw_frontier_ops,
                                grad_node.op.outputs)
        checkpoints_disconnected[x] = swapin_op
        my_add_control_inputs(x, bw_frontier_ops, swapin_op)
        # control dependency -> swap_in
        # self._add_control_dependency(src_op, dest_op, swapin_op)

    # g = tf.get_default_graph()
    # print(g.get_operations())

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints,
                                    within_ops=fwd_ops)
    debug_print(
        "Found {} ops to copy within fwd_ops {}, seed {}, stop_at {}".format(
            len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints), 1)
    debug_print("ops_to_copy = {}".format(ops_to_copy), 1)
    debug_print("Processing list {}".format(ys), 1)
    copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    debug_print("Copied {} to {}".format(ops_to_copy, copied_ops), 1)
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print(
        "Rewired {} in place of {} restricted to {}".format(
            checkpoints_disconnected.values(), checkpoints_disconnected.keys(),
            copied_ops), 1)

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys,
                      xs=boundary + xs,
                      grad_ys=grad_ys,
                      **kwargs)
    debug_print("Got gradients {}".format(dv), 1)
    debug_print("for {}".format(copied_ys), 1)
    debug_print("with respect to {}".format(boundary + xs), 1)

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {
        r: dr
        for r, dr in zip(checkpoints_disconnected.keys(),
                         dv[:len(checkpoints_disconnected)])
    }
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list {}".format(ts), 1)
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [
            checkpoints_disconnected[r] for r in checkpoints_other
        ]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoints nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print(
            "Found {} ops to copy within {}, seed {}, stop_at {}".format(
                len(ops_to_copy), fwd_ops, [r.op for r in ts],
                checkpoints_other), 1)
        debug_print("ops_to_copy = {}".format(ops_to_copy), 1)
        if not ops_to_copy:  # we're done!
            break
        copied_sgv, info = ge.copy_with_input_replacements(
            ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        debug_print("Copied {} to {}".format(ops_to_copy, copied_ops), 1)
        ge.reroute_ts(checkpoints_disconnected_other,
                      checkpoints_other,
                      can_modify=copied_ops)
        debug_print(
            "Rewired {} in place of {} restricted to {}".format(
                checkpoints_disconnected_other, checkpoints_other, copied_ops),
            1)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops,
                          **kwargs)
        debug_print("Got gradients {}".format(dv), 1)
        debug_print("for {}".format(boundary), 1)
        debug_print(
            "with respect to {}".format(checkpoints_disconnected_other + xs),
            1)
        debug_print(
            "with boundary backprop substitutions {}".format(
                substitute_backprops), 1)

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        def _unsparsify(x):
            if not isinstance(x, tf.IndexedSlices):
                return x
            assert x.dense_shape is not None, "memory_saving_gradients encountered sparse gradients of unknown shape"
            indices = x.indices
            while indices.shape.ndims < x.values.shape.ndims:
                indices = tf.expand_dims(indices, -1)
            return tf.scatter_nd(indices, x.values, x.dense_shape)

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])

    return d_xs
Example #12
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
    '''
    Authors: Tim Salimans & Yaroslav Bulatov

    memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost"
    by Chen et al. 2016 (https://arxiv.org/abs/1604.06174)

    ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients
    (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients)

    'checkpoints' can either be
        - a list consisting of tensors from the forward pass of the neural net
          that we should re-use when calculating the gradients in the backward pass;
          all other tensors that do not appear in this list will be re-computed
        - a string specifying how this list should be determined. currently we support
            - 'speed':  checkpoint all outputs of convolutions and matmuls. these ops are usually the most expensive,
                        so checkpointing them maximizes the running speed
                        (this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory)
            - 'memory': try to minimize the memory usage
                        (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint)
            - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint
    '''

    #    print("Calling memsaving gradients with", checkpoints)
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True)

    debug_print("bwd_ops: %s", bwd_ops)

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    debug_print("fwd_ops: %s", fwd_ops)

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if op not in xs_ops]
    fwd_ops = [op for op in fwd_ops if '/assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/Assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/read' not in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    checkpoints = 'collection'
    # construct list of tensors to checkpoint during forward pass, if not
    # given as input

    stereo_checkpoints = ge.filter_ts_from_regex(fwd_ops, "add")
    motion_checkpoints = ge.filter_ts_from_regex(fwd_ops, "Conv2D")

    my_ckps = []
    for x in motion_checkpoints:
        if ("motion" in x.name) and ("BatchNorm" not in x.name):
            my_ckps.append(x)
    for x in stereo_checkpoints:
        if ("stereo" in x.name) and ("BatchNorm" not in x.name):
            my_ckps.append(x)

    checkpoints = my_ckps
    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection happened and checkpoints is list of nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: %s", checkpoints)
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print("Warning, some input nodes are also checkpoint nodes: %s",
                    xs_intersect_checkpoints)
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints,
                ys_intersect_checkpoints)
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print(
            "Warning, some output nodes are also checkpoints nodes: %s",
            format_ops(ys_intersect_checkpoints))

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    if not checkpoints:
        raise Exception('no checkpoints nodes found or given as input! ')

    # disconnect dependencies between checkpointed tensors
    checkpoints_disconnected = {}
    for x in checkpoints:
        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
        else:
            grad_node = tf.stop_gradient(x)
        checkpoints_disconnected[x] = grad_node

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints,
                                    within_ops=fwd_ops)
    debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s",
                len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints)
    debug_print("ops_to_copy = %s", ops_to_copy)
    debug_print("Processing list %s", ys)
    copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    copied_ops = info._transformed_ops.values()
    debug_print("Copied %s to %s", ops_to_copy, copied_ops)
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print("Rewired %s in place of %s restricted to %s",
                checkpoints_disconnected.values(),
                checkpoints_disconnected.keys(), copied_ops)

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys,
                      xs=boundary + xs,
                      grad_ys=grad_ys,
                      **kwargs)
    debug_print("Got gradients %s", dv)
    debug_print("for %s", copied_ys)
    debug_print("with respect to %s", boundary + xs)

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {
        r: dr
        for r, dr in zip(checkpoints_disconnected.keys(),
                         dv[:len(checkpoints_disconnected)])
    }
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list %s", ts)
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [
            checkpoints_disconnected[r] for r in checkpoints_other
        ]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoints nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print("Found %s ops to copy within %s, seed %s, stop_at %s",
                    len(ops_to_copy), fwd_ops, [r.op for r in ts],
                    checkpoints_other)
        debug_print("ops_to_copy = %s", ops_to_copy)
        if not ops_to_copy:  # we're done!
            break
        copied_sgv, info = ge.copy_with_input_replacements(
            ge.sgv(ops_to_copy), {})
        copied_ops = info._transformed_ops.values()
        debug_print("Copied %s to %s", ops_to_copy, copied_ops)
        ge.reroute_ts(checkpoints_disconnected_other,
                      checkpoints_other,
                      can_modify=copied_ops)
        debug_print("Rewired %s in place of %s restricted to %s",
                    checkpoints_disconnected_other, checkpoints_other,
                    copied_ops)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops,
                          **kwargs)
        debug_print("Got gradients %s", dv)
        debug_print("for %s", boundary)
        debug_print("with respect to %s", checkpoints_disconnected_other + xs)
        debug_print("with boundary backprop substitutions %s",
                    substitute_backprops)

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = d_xs_new[j]
                else:
                    d_xs[j] += d_xs_new[j]

    return d_xs
Example #14
def gradients(ys, xs,   # pylint: disable=too-many-statements, too-many-branches
              grad_ys=None, checkpoints='collection', **kwargs):
    '''
    Authors: Tim Salimans & Yaroslav Bulatov

    memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory
    Cost" by Chen et al. 2016 (https://arxiv.org/abs/1604.06174)

    ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients
    (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients)

    'checkpoints' can either be
        - a list consisting of tensors from the forward pass of the neural net
          that we should re-use when calculating the gradients in the backward pass;
          all other tensors that do not appear in this list will be re-computed
        - a string specifying how this list should be determined. currently we support
            - 'speed':  checkpoint all outputs of convolutions and matmuls. these ops are usually
                        the most expensive, so checkpointing them maximizes the running speed
                        (this is a good option if nonlinearities, concats, batchnorms, etc are
                        taking up a lot of memory)
            - 'memory': try to minimize the memory usage
                        (currently using a very simple strategy that identifies a number of
                        bottleneck tensors in the graph to checkpoint)
            - 'collection': look for a tensorflow collection named 'checkpoints', which holds the
                            tensors to checkpoint
    '''
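    # Minimal usage sketch (hedged; `loss`, `params` and `optimizer` are
    # hypothetical names, not defined here):
    #
    #   params = tf.trainable_variables()
    #   grads = gradients(loss, params, checkpoints='memory')
    #   train_op = optimizer.apply_gradients(zip(grads, params))
    #
    # With checkpoints='collection', tensors to keep must first be registered
    # via tf.add_to_collection('checkpoints', tensor).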

    #    print("Calling memsaving gradients with", checkpoints)
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys],
                                       inclusive=True)

    debug_print("bwd_ops: {}".format(bwd_ops))

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    debug_print("fwd_ops: {}".format(fwd_ops))

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if op not in xs_ops]
    fwd_ops = [op for op in fwd_ops if '/assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/Assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/read' not in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    # construct list of tensors to checkpoint during forward pass, if not
    # given as input
    if type(checkpoints) is not list:
        if checkpoints == 'collection':
            checkpoints = tf.get_collection('checkpoints')

        elif checkpoints == 'speed':
            # checkpoint all expensive ops to maximize running speed
            checkpoints = ge.filter_ts_from_regex(fwd_ops, 'conv2d|Conv|MatMul')

        elif checkpoints == 'memory':

            # remove very small tensors and some weird ops
            def fixdims(t):  # tf.Dimension values are not compatible with int, convert manually
                try:
                    return [int(e if e.value is not None else 64) for e in t]
                except:
                    return [0]  # unknown shape
            ts_all = [t for t in ts_all if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE]
            ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
            ts_all = [t for t in ts_all if 'entropy' not in t.name]
            ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
            ts_all = [t for t in ts_all if 'Switch' not in t.name]
            ts_all = [t for t in ts_all if 'dropout' not in t.name]
            # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
            ts_all = [t for t in ts_all if 'Cast' not in t.name]

            # filter out all tensors that are inputs of the backward graph
            with util.capture_ops() as bwd_ops:
                tf_gradients(ys, xs, grad_ys, **kwargs)

            bwd_inputs = [t for op in bwd_ops for t in op.inputs]
            # list of tensors in the forward graph that are inputs to the backward graph
            ts_filtered = list(set(bwd_inputs).intersection(ts_all))
            debug_print("Using tensors {}".format(ts_filtered))

            # try two slightly different ways of getting bottlenecks tensors
            # to checkpoint
            for ts in [ts_filtered, ts_all]:

                # get all bottlenecks in the graph
                bottleneck_ts = []
                for t in ts:
                    b = set(ge.get_backward_walk_ops(t.op, inclusive=True, within_ops=fwd_ops))
                    f = set(ge.get_forward_walk_ops(t.op, inclusive=False, within_ops=fwd_ops))
                    # check that there are no shortcuts
                    b_inp = set([inp for op in b for inp in op.inputs]).intersection(ts_all)
                    f_inp = set([inp for op in f for inp in op.inputs]).intersection(ts_all)
                    if not set(b_inp).intersection(f_inp) and len(b_inp)+len(f_inp) >= len(ts_all):
                        bottleneck_ts.append(t)  # we have a bottleneck!
                    else:
                        debug_print("Rejected bottleneck candidate and ops {}".format(
                            [t] + list(set(ts_all) - set(b_inp) - set(f_inp))))

                # success? or try again without filtering?
                if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)):  # enough bottlenecks found!
                    break

            if not bottleneck_ts:
                raise Exception('unable to find bottleneck tensors! please provide checkpoint '
                                'nodes manually, or use checkpoints="speed".')

            # sort the bottlenecks
            bottlenecks_sorted_lists = tf_toposort(bottleneck_ts, within_ops=fwd_ops)
            sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts]

            # save an approximately optimal number ~ sqrt(N)
            N = len(ts_filtered)
            if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
                checkpoints = sorted_bottlenecks
            else:
                step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
                checkpoints = sorted_bottlenecks[step::step]
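            # worked example: with N = 100 filtered tensors and 40 bottlenecks,
            # sqrt(N) = 10 and step = ceil(40 / 10) = 4, so every 4th sorted
            # bottleneck is kept, leaving roughly 10 checkpoints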

        else:
            raise Exception('%s is unsupported input for "checkpoints"' % (checkpoints,))

    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection has happened and checkpoints is a list of nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: {}".format(checkpoints))
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print("Warning, some input nodes are also checkpoint nodes: {}".format(
            xs_intersect_checkpoints))
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print("ys: {}, checkpoints:{}, intersect: {}".format(
        ys, checkpoints, ys_intersect_checkpoints))
    # saving an output node (ys) gives no memory benefit while creating
    # new edge cases, so exclude them
    if ys_intersect_checkpoints:
        debug_print("Warning, some output nodes are also checkpoint nodes: {}".format(
            format_ops(ys_intersect_checkpoints)))

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    if not checkpoints:
        raise Exception('no checkpoint nodes found or given as input!')

    # disconnect dependencies between checkpointed tensors
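    # tf.stop_gradient acts as an identity whose gradient is blocked; rerouting
    # the recomputed subgraph to read from these stand-ins (below) keeps
    # tf.gradients from backpropagating through the original forward pass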
    checkpoints_disconnected = {}
    for x in checkpoints:
        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name+"_sg")
        else:
            grad_node = tf.stop_gradient(x)
        checkpoints_disconnected[x] = grad_node

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints, within_ops=fwd_ops)
    debug_print("Found {} ops to copy within fwd_ops {}, seed {}, stop_at {}".format(
        len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints))
    debug_print("ops_to_copy = {}".format(ops_to_copy))
    debug_print("Processing list {}".format(ys))
    _, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    debug_print("Copied {} to {}".format(ops_to_copy, copied_ops))
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
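    # within the copied ops only, consumers of the original checkpoint tensors
    # now read from their stop_gradient stand-ins, so the copied subgraph
    # recomputes everything downstream of the checkpoints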
    debug_print("Rewired {} in place of {} restricted to {}".format(
        checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops))

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys, xs=boundary+xs, grad_ys=grad_ys, **kwargs)
    debug_print("Got gradients {}".format(dv))
    debug_print("for %s", copied_ys)
    debug_print("with respect to {}".format(boundary+xs))

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)
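    # the added control dependencies make the recomputation ops (and the new
    # gradient ops) wait until the original forward pass over ys has finished,
    # so recomputation is deferred to the backward pass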

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {r: dr for r, dr in zip(checkpoints_disconnected.keys(),
                                            dv[:len(checkpoints_disconnected)])}
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list {}".format(ts))
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [checkpoints_disconnected[r] for r in checkpoints_other]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoints nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print("Found {} ops to copy within {}, seed {}, stop_at {}".format(
            len(ops_to_copy), fwd_ops, [r.op for r in ts], checkpoints_other))
        debug_print("ops_to_copy = {}".format(ops_to_copy))
        if not ops_to_copy:  # we're done!
            break
        _, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        debug_print("Copied {} to {}".format(ops_to_copy, copied_ops))
        ge.reroute_ts(checkpoints_disconnected_other, checkpoints_other, can_modify=copied_ops)
        debug_print("Rewired %s in place of %s restricted to %s",
                    checkpoints_disconnected_other, checkpoints_other, copied_ops)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other+xs,
                          grad_ys=substitute_backprops, **kwargs)
        debug_print("Got gradients {}".format(dv))
        debug_print("for {}".format(boundary))
        debug_print("with respect to {}".format(checkpoints_disconnected_other+xs))
        debug_print("with boundary backprop substitutions {}".format(substitute_backprops))

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        def _unsparsify(var_x):
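            # gradients arriving as tf.IndexedSlices (e.g. from embedding
            # lookups) are densified with scatter_nd so they can be
            # accumulated with += below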
            if not isinstance(var_x, tf.IndexedSlices):
                return var_x
            assert var_x.dense_shape is not None, \
                "memory_saving_gradients encountered sparse gradients of unknown shape"
            indices = var_x.indices
            while indices.shape.ndims < var_x.values.shape.ndims:
                indices = tf.expand_dims(indices, -1)
            return tf.scatter_nd(indices, var_x.values, var_x.dense_shape)

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])

    return d_xs
    def run(self):

        checkpoints = self._get_checkpoint()

        # at this point automatic selection has happened and checkpoints is a list of nodes
        assert isinstance(checkpoints, list)

        self._log_info("Checkpoint nodes used: {}".format(checkpoints), 1)
        # better error handling of special cases
        # xs are already handled as checkpoint nodes, so no need to include them
        xs_intersect_checkpoints = set(self._xs).intersection(set(checkpoints))
        if xs_intersect_checkpoints:
            self._log_info(
                "Warning, some input nodes are also checkpoint nodes: {}".
                format(xs_intersect_checkpoints))
        ys_intersect_checkpoints = set(self._ys).intersection(set(checkpoints))
        self._log_info(
            "ys: {}, checkpoints: {}, intersect: {}".format(
                self._ys, checkpoints, ys_intersect_checkpoints), 1)
        # saving an output node (ys) gives no memory benefit while creating
        # new edge cases, so exclude them
        if ys_intersect_checkpoints:
            self._log_info(
                "Warning, some output nodes are also checkpoint nodes: {}".
                format(self.format_ops(ys_intersect_checkpoints)))

        # remove initial and terminal nodes from checkpoints list if present
        checkpoints = list(set(checkpoints) - set(self._ys) - set(self._xs))

        # check that we have some nodes to checkpoint
        if not checkpoints:
            raise Exception('no checkpoint nodes found or given as input!')

        # disconnect dependencies between checkpointed tensors
        checkpoints_disconnected = {}
        for x in checkpoints:
            if x.op and x.op.name is not None:
                grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
            else:
                grad_node = tf.stop_gradient(x)
            checkpoints_disconnected[x] = grad_node

        # partial derivatives to the checkpointed tensors and xs
        ops_to_copy = self.fast_backward_ops(seed_ops=[y.op for y in self._ys],
                                             stop_at_ts=checkpoints,
                                             within_ops=self.fwd_ops)
        self._log_info(
            "Found {} ops to copy within fwd_ops {}, seed {}, stop_at {}".
            format(len(ops_to_copy), self.fwd_ops, [r.op for r in self._ys],
                   checkpoints), 1)
        self._log_info("ops_to_copy = {}".format(ops_to_copy), 1)
        self._log_info("Processing list {}".format(self._ys), 1)
        copied_sgv, info = ge.copy_with_input_replacements(
            ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        self._log_info("Copied {} to {}".format(ops_to_copy, copied_ops, 1))
        ge.reroute_ts(checkpoints_disconnected.values(),
                      checkpoints_disconnected.keys(),
                      can_modify=copied_ops)
        self._log_info(
            "Rewired {} in place of {} restricted to {}".format(
                checkpoints_disconnected.values(),
                checkpoints_disconnected.keys(), copied_ops), 1)

        # get gradients with respect to current boundary + original x's
        copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in self._ys]
        boundary = list(checkpoints_disconnected.values())
        dv = tf_gradients(ys=copied_ys,
                          xs=boundary + self._xs,
                          grad_ys=self._grad_ys,
                          **self._kwargs)
        self._log_info("Got gradients {}".format(dv), 1)
        self._log_info("for {}".format(copied_ys), 1)
        self._log_info("with respect to {}".format(boundary + self._xs), 1)

        inputs_to_do_before = [y.op for y in self._ys]
        if self._grad_ys is not None:
            inputs_to_do_before += self._grad_ys
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        self.my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        # dictionary of "node: backprop" for nodes in the boundary
        d_checkpoints = {
            r: dr
            for r, dr in zip(checkpoints_disconnected.keys(),
                             dv[:len(checkpoints_disconnected)])
        }
        # partial derivatives to xs (usually the params of the neural net)
        d_xs = dv[len(checkpoints_disconnected):]

        # incorporate derivatives flowing through the checkpointed nodes
        checkpoints_sorted_lists = self.tf_toposort(checkpoints,
                                                    within_ops=self.fwd_ops)
        for ts in checkpoints_sorted_lists[::-1]:
            self._log_info("Processing list {}".format(ts), 1)
            checkpoints_other = [r for r in checkpoints if r not in ts]
            checkpoints_disconnected_other = [
                checkpoints_disconnected[r] for r in checkpoints_other
            ]

            # copy part of the graph below current checkpoint node, stopping at
            # other checkpoints nodes
            ops_to_copy = self.fast_backward_ops(within_ops=self.fwd_ops,
                                                 seed_ops=[r.op for r in ts],
                                                 stop_at_ts=checkpoints_other)
            self._log_info(
                "Found {} ops to copy within {}, seed {}, stop_at {}".format(
                    len(ops_to_copy), self.fwd_ops, [r.op for r in ts],
                    checkpoints_other), 1)
            self._log_info("ops_to_copy = {}".format(ops_to_copy), 1)
            if not ops_to_copy:  # we're done!
                break
            copied_sgv, info = ge.copy_with_input_replacements(
                ge.sgv(ops_to_copy), {})
            for origin_op, op in info._transformed_ops.items():
                op._set_device(origin_op.node_def.device)
            copied_ops = info._transformed_ops.values()
            self._log_info("Copied {} to {}".format(ops_to_copy, copied_ops),
                           1)
            ge.reroute_ts(checkpoints_disconnected_other,
                          checkpoints_other,
                          can_modify=copied_ops)
            self._log_info(
                "Rewired {} in place of {} restricted to {}".format(
                    checkpoints_disconnected_other, checkpoints_other,
                    copied_ops), 1)

            # gradient flowing through the checkpointed node
            boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
            substitute_backprops = [d_checkpoints[r] for r in ts]
            dv = tf_gradients(boundary,
                              checkpoints_disconnected_other + self._xs,
                              grad_ys=substitute_backprops,
                              **self._kwargs)
            self._log_info("Got gradients {}".format(dv), 1)
            self._log_info("for {}".format(boundary), 1)
            self._log_info(
                "with respect to {}".format(checkpoints_disconnected_other +
                                            self._xs), 1)
            self._log_info(
                "with boundary backprop substitutions {}".format(
                    substitute_backprops), 1)

            inputs_to_do_before = [d_checkpoints[r].op for r in ts]
            wait_to_do_ops = list(copied_ops) + [
                g.op for g in dv if g is not None
            ]
            self.my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

            # partial derivatives to the checkpointed nodes
            for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
                if dr is not None:
                    if d_checkpoints[r] is None:
                        d_checkpoints[r] = dr
                    else:
                        d_checkpoints[r] += dr

            def _unsparsify(x):
                if not isinstance(x, tf.IndexedSlices):
                    return x
                assert x.dense_shape is not None, "memory_saving_gradients encountered sparse gradients of unknown shape"
                indices = x.indices
                while indices.shape.ndims < x.values.shape.ndims:
                    indices = tf.expand_dims(indices, -1)
                return tf.scatter_nd(indices, x.values, x.dense_shape)

            # partial derivatives to xs (usually the params of the neural net)
            d_xs_new = dv[len(checkpoints_other):]
            for j in range(len(self._xs)):
                if d_xs_new[j] is not None:
                    if d_xs[j] is None:
                        d_xs[j] = _unsparsify(d_xs_new[j])
                    else:
                        d_xs[j] += _unsparsify(d_xs_new[j])

        return d_xs