def cut_graph_def(graph_def, cut_nodes):
    """Cut graph_def into two parts by cut_nodes.

    All ancestors of cut_nodes are put into back and the rest are put into
    head.

    Args:
        graph_def: input tf.GraphDef
        cut_nodes: a list of node names to cut

    Returns:
        A pair (back, head_graph_def) of tf.GraphDef protos.
    """
    # back: cut_nodes plus all of their ancestors
    back = tf.graph_util.extract_sub_graph(graph_def, cut_nodes)
    # head: every node not in back. Compare by node name via a set instead of
    # `n not in back.node` -- membership on a proto repeated field does value
    # equality per element, i.e. an O(n^2) scan of full NodeDef comparisons.
    back_names = {n.name for n in back.node}
    head_node_names = [n.name for n in graph_def.node if n.name not in back_names]
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
        all_ops = make_list_of_op(graph)
        head_ops = [o for o in all_ops if o.name in head_node_names]
        head_subgraph = SubGraphView(inside_ops=head_ops)
        head_graph = tf.Graph()
        # feed head's dangling inputs from fresh placeholders; keys are the
        # tensor names of the subgraph's external inputs
        replace_ts = {}
        for i in head_subgraph.inputs:
            replace_ts[i.name] = make_placeholder_from_tensor(i)
        copy_with_input_replacements(
            head_subgraph, replace_ts, dst_graph=head_graph
        )
    # return
    return back, head_graph.as_graph_def()
def test_copy_with_input_replacements(self):
    """Copy self.o.op while rerouting its second input to a fresh constant."""
    with self.graph.as_default():
        replacement = tf.constant(10.0, shape=[10], name="Input")
        target_op = self.o.op
        sgv, _ = ge.copy_with_input_replacements(
            target_op, {target_op.inputs[1]: replacement})
        with tf.Session() as sess:
            result = sess.run(sgv.outputs[0])
        # the copied op should now compute with the constant 10 input
        self.assertNear(
            np.linalg.norm(result - np.array([11])), 0.0, ERROR_TOLERANCE)
def _clone_model(self, model, perturbations, dst_scope):
    '''
    make a copy of model and connect the resulting sub-graph to
    input ops of the original graph and parameter assignments by
    perturbator.
    '''
    def not_placeholder_or_trainvar_filter(op):
        # print(op.name)
        if op.type == 'Placeholder':
            # evaluation sub-graphs will be fed from original placeholders
            return False
        for var_name in self.tvars:
            if op.name.startswith(var_name):
                # remove Some/Var/(read,assign,...) -- will be replaced with perturbations
                return False
        return True

    ops_without_inputs = ge.filter_ops(model.ops, not_placeholder_or_trainvar_filter)
    # print("ModelOPS=========================")
    # for o in ops_without_inputs:
    #     print(o.name, o.type)
    # remove init op from clone if already present.
    # FIX: narrow the bare `except:` -- get_operation_by_name raises KeyError
    # when no op named "init" exists, list.remove raises ValueError when the
    # op is not in the filtered list; anything else should propagate.
    try:
        ops_without_inputs.remove(self.work_graph.get_operation_by_name("init"))
    except (KeyError, ValueError):
        pass
    clone_sgv = ge.make_view(ops_without_inputs)
    clone_sgv = clone_sgv.remove_unused_ops(control_inputs=True)

    input_replacements = {}
    for t in clone_sgv.inputs:
        if t.name in perturbations:  # membership test; .keys() was redundant
            # input from trainable var --> replace with perturbation
            input_replacements[t] = perturbations[t.name]
        else:
            # otherwise take input from original graph
            input_replacements[t] = self.work_graph.get_tensor_by_name(t.name)
    return ge.copy_with_input_replacements(clone_sgv, input_replacements,
                                           dst_scope=dst_scope)
def test_copy_with_input_replacements(self):
    """Verify that copying an op with one input swapped for a constant works."""
    with self.graph.as_default():
        new_input = tf.constant(10.0, shape=[10], name="Input")
        sgv, _ = ge.copy_with_input_replacements(
            self.o.op, {self.o.op.inputs[1]: new_input})
        with tf.Session() as sess:
            out = sess.run(sgv.outputs[0])
        distance = np.linalg.norm(out - np.array([11]))
        self.assertNear(distance, 0.0, ERROR_TOLERANCE)
def _duplicate_layer(layer_name, layer_sgv, branch_name, add_to_collections=True):
    """Duplicates a network layer, while preserving connections.

    Args:
        layer_name: a layer is identified by its name scope
        layer_sgv: SubgraphView (see tf.contrib.graph_editor)
        branch_name: the duplicate is "layer_name + branch_name"
        add_to_collections: add duplicate vars to the same collections

    Returns:
        info: see ret vals of `tf.contrib.graph_editor.copy`
        var_duplication: a list of tuples (var, dup_of_var)
    """
    # build the destination scope name, preserving a trailing '/' if present
    if layer_name[-1] == '/':
        new_layer_name = layer_name[:-1] + branch_name + '/'
    else:
        new_layer_name = layer_name + branch_name
    # identity replacements: the duplicate keeps reading the same input
    # tensors as the original layer
    replacement_ts = {}
    for op in layer_sgv.inputs:  # NOTE(review): these are tensors despite the name `op`
        replacement_ts[op] = op
    duplicate_sgv, info = ge.copy_with_input_replacements(
        layer_sgv,
        replacement_ts=replacement_ts,
        src_scope=layer_name,
        dst_scope=new_layer_name)
    var_duplication = []
    for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
        if layer_name not in v.name:
            continue
        # clone the variable by rewriting its VariableDef proto so that every
        # string field (names of the var/snapshot/initializer nodes) points
        # into the new scope
        vproto = v.to_proto()
        new_vardef = variable_pb2.VariableDef()
        for field, val in vproto.ListFields():
            if isinstance(val, str):
                # NOTE(review): proto string fields may be bytes on some
                # versions -- confirm isinstance(val, str) matches them here
                new_val = val.replace(layer_name, new_layer_name)
            else:
                new_val = val
            setattr(new_vardef, field.name, new_val)
        new_var = tf.Variable(variable_def=new_vardef)
        tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, new_var)
        var_duplication.append((v, new_var))
        if add_to_collections:
            # mirror the original var's membership in every other collection
            for k in tf.get_default_graph().get_all_collection_keys():
                collection = tf.get_collection(k)
                if v in collection and new_var not in collection:
                    tf.add_to_collection(k, new_var)
    return info, var_duplication
def clone_subgraph(outputs, mappings, clone_scope=''):
    """Clone the sub-graph that computes `outputs`, substituting the tensors
    given as keys of `mappings` with their mapped values, and return the
    transformed output tensors."""
    # Stateful resource ops must not be duplicated -- the clone keeps
    # referring to the originals instead.
    stateful_op_types = {
        'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
        'MutableHashTableV2', 'MutableHashTableOfTensors',
        'MutableHashTableOfTensorsV2', 'MutableDenseHashTable',
        'MutableDenseHashTableV2', 'VarHandleOp',
        'BoostedTreesEnsembleResourceHandleOp'
    }
    walked_ops = ge.get_backward_walk_ops(outputs, stop_at_ts=mappings.keys())
    replicable_ops = [candidate for candidate in walked_ops
                      if candidate.type not in stateful_op_types]
    view = ge.make_view(*replicable_ops)
    _, transform_info = ge.copy_with_input_replacements(
        view, mappings, dst_scope=clone_scope)
    return transform_info.transformed(outputs)
def recompute_tensor(target, known_values, preceding_op=None, copy_known_values=False):
    """Computes target tensor from known_values.

    If preceding_op is not None, adds necessary control dependencies
    such that newly created computation takes place after preceding_op.

    If copy_known_values is set, also copies known_values (for nicer graph
    visualization).
    """
    assert is_computable(target, known_values)

    # position of target in parent op
    target_pos = list(target.op.outputs).index(target)

    if copy_known_values:
        computation = ge.get_backward_walk_ops(target)
    else:
        computation = ge.get_backward_walk_ops(target, stop_at_ts=known_values)

    # create copy of computation
    copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(computation), {})

    # find our target tensor in the new computation
    new_target_op = info._transformed_ops[target.op]
    new_target = new_target_op.outputs[target_pos]
    new_computation = list(info._transformed_ops.values())

    # restrict computation to run after given op
    SAVE_ON_CONTROL_EDGES = True

    if SAVE_ON_CONTROL_EDGES:
        # FIX: guard against preceding_op being None, matching the else
        # branch below -- previously this branch called run_after with None.
        if preceding_op is not None:
            # only add "run_after" control dependencies to root of computation,
            # the rest automatically runs after because of data dependencies
            # TODO: more efficient implementation by walking back from new_target
            # instead of whole graph
            computation_graph = linearize_lib.get_graph(restrict_to=new_computation)
            # note, toposort order is reversed from networkx/mine convention
            computation_root = list(toposort.toposort(computation_graph))[-1]
            for op in computation_root:
                run_after(op, preceding_op)
    else:
        if preceding_op is not None:
            for op in info._transformed_ops.values():
                run_after(op, preceding_op)
    return new_target
def _duplicate_graph(self, graph, vars_to_replace, name='Duplicated'):
    """ Duplicates loss graph with swapped variables.

    :param graph: output tensor of the sub-graph to duplicate
    :param vars_to_replace: dict mapping tensors to their replacements
    :param name: name scope for the duplicated ops
    :return: Swapped graph.
    """
    if graph in vars_to_replace:
        return vars_to_replace[graph]

    operations = []
    # FIX: track visited ops -- without this, ops shared by several paths are
    # re-walked once per consuming edge (exponential on diamond-shaped graphs)
    # and appended to `operations` multiple times.
    visited = set()

    def get_ops(t):
        op = t.op
        if op in visited:
            return
        visited.add(op)
        if op.type != 'VariableV2' and op.type != 'Placeholder':
            operations.append(op)
        for i in op.inputs:
            if i not in vars_to_replace:
                get_ops(i)

    get_ops(graph)
    sgv = graph_editor.make_view(operations)
    with ops.name_scope(name):
        new_view, _ = graph_editor.copy_with_input_replacements(
            sgv, vars_to_replace)
    return new_view.outputs[sgv.output_index(graph)]
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
    '''
    Authors: Tim Salimans & Yaroslav Bulatov

    memory efficient gradient implementation inspired by "Training Deep Nets
    with Sublinear Memory Cost" by Chen et al. 2016 (https://arxiv.org/abs/1604.06174)

    ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients
    (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients)

    'checkpoints' can either be
        - a list consisting of tensors from the forward pass of the neural net
          that we should re-use when calculating the gradients in the backward pass
          all other tensors that do not appear in this list will be re-computed
        - a string specifying how this list should be determined. currently we support
            - 'speed':  checkpoint all outputs of convolutions and matmuls. these ops are usually the most expensive,
                        so checkpointing them maximizes the running speed
                        (this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory)
            - 'memory': try to minimize the memory usage
                        (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint)
            - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint
    '''
    # print("Calling memsaving gradients with", checkpoints)
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True)

    debug_print("bwd_ops: %s", bwd_ops)

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    debug_print("fwd_ops: %s", fwd_ops)

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if not op in xs_ops]
    fwd_ops = [op for op in fwd_ops if not '/assign' in op.name]
    fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name]
    fwd_ops = [op for op in fwd_ops if not '/read' in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    # construct list of tensors to checkpoint during forward pass, if not
    # given as input
    if type(checkpoints) is not list:
        if checkpoints == 'collection':
            checkpoints = tf.get_collection('checkpoints')
        elif checkpoints == 'speed':
            # checkpoint all expensive ops to maximize running speed
            checkpoints = ge.filter_ts_from_regex(fwd_ops, 'conv2d|Conv|MatMul')
        elif checkpoints == 'memory':

            # remove very small tensors and some weird ops
            def fixdims(
                t
            ):  # tf.Dimension values are not compatible with int, convert manually
                try:
                    return [int(e if e.value is not None else 64) for e in t]
                except:
                    return [0]  # unknown shape

            ts_all = [
                t for t in ts_all
                if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE
            ]
            ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
            ts_all = [t for t in ts_all if 'entropy' not in t.name]
            ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
            ts_all = [t for t in ts_all if 'Switch' not in t.name]
            ts_all = [t for t in ts_all if 'dropout' not in t.name]
            # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
            ts_all = [t for t in ts_all if 'Cast' not in t.name]

            # filter out all tensors that are inputs of the backward graph
            with util.capture_ops() as bwd_ops:
                tf_gradients(ys, xs, grad_ys, **kwargs)

            bwd_inputs = [t for op in bwd_ops for t in op.inputs]
            # list of tensors in forward graph that is in input to bwd graph
            ts_filtered = list(set(bwd_inputs).intersection(ts_all))
            debug_print("Using tensors %s", ts_filtered)

            # try two slightly different ways of getting bottlenecks tensors
            # to checkpoint
            for ts in [ts_filtered, ts_all]:

                # get all bottlenecks in the graph
                bottleneck_ts = []
                for t in ts:
                    b = set(
                        ge.get_backward_walk_ops(t.op,
                                                 inclusive=True,
                                                 within_ops=fwd_ops))
                    f = set(
                        ge.get_forward_walk_ops(t.op,
                                                inclusive=False,
                                                within_ops=fwd_ops))
                    # check that there are not shortcuts
                    b_inp = set([inp for op in b
                                 for inp in op.inputs]).intersection(ts_all)
                    f_inp = set([inp for op in f
                                 for inp in op.inputs]).intersection(ts_all)
                    if not set(b_inp).intersection(
                            f_inp) and len(b_inp) + len(f_inp) >= len(ts_all):
                        bottleneck_ts.append(t)  # we have a bottleneck!
                    else:
                        debug_print(
                            "Rejected bottleneck candidate and ops %s",
                            [t] + list(set(ts_all) - set(b_inp) - set(f_inp)))

                # success? or try again without filtering?
                if len(bottleneck_ts) >= np.sqrt(
                        len(ts_filtered)):  # yes, enough bottlenecks found!
                    break

            if not bottleneck_ts:
                raise Exception(
                    'unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".'
                )

            # sort the bottlenecks
            bottlenecks_sorted_lists = tf_toposort(bottleneck_ts,
                                                   within_ops=fwd_ops)
            sorted_bottlenecks = [
                t for ts in bottlenecks_sorted_lists for t in ts
            ]

            # save an approximately optimal number ~ sqrt(N)
            N = len(ts_filtered)
            if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
                checkpoints = sorted_bottlenecks
            else:
                step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
                checkpoints = sorted_bottlenecks[step::step]
        else:
            raise Exception('%s is unsupported input for "checkpoints"' %
                            (checkpoints, ))

    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection happened and checkpoints is list of nodes
    assert isinstance(checkpoints, list)
    debug_print("Checkpoint nodes used: %s", checkpoints)
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print("Warning, some input nodes are also checkpoint nodes: %s",
                    xs_intersect_checkpoints)
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints,
                ys_intersect_checkpoints)
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print(
            "Warning, some output nodes are also checkpoints nodes: %s",
            format_ops(ys_intersect_checkpoints))

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    # if not checkpoints:
    #     raise Exception('no checkpoints nodes found or given as input! ')

    # disconnect dependencies between checkpointed tensors:
    # each checkpoint is wrapped in stop_gradient so the later tf_gradients
    # calls treat it as an independent boundary input
    checkpoints_disconnected = {}
    for x in checkpoints:
        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
        else:
            grad_node = tf.stop_gradient(x)
        checkpoints_disconnected[x] = grad_node

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints,
                                    within_ops=fwd_ops)
    debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s",
                len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints)
    debug_print("ops_to_copy = %s", ops_to_copy)
    debug_print("Processing list %s", ys)
    # NOTE(review): info._transformed_ops / op._set_device / op._outputs are
    # private graph_editor / Operation APIs -- version sensitive
    copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    debug_print("Copied %s to %s", ops_to_copy, copied_ops)
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print("Rewired %s in place of %s restricted to %s",
                checkpoints_disconnected.values(),
                checkpoints_disconnected.keys(), copied_ops)

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys, xs=boundary + xs, grad_ys=grad_ys,
                      **kwargs)
    debug_print("Got gradients %s", dv)
    debug_print("for %s", copied_ys)
    debug_print("with respect to %s", boundary + xs)

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {
        r: dr
        for r, dr in zip(checkpoints_disconnected.keys(),
                         dv[:len(checkpoints_disconnected)])
    }
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes,
    # walking the checkpoints in reverse topological order
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list %s", ts)
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [
            checkpoints_disconnected[r] for r in checkpoints_other
        ]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoints nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print("Found %s ops to copy within %s, seed %s, stop_at %s",
                    len(ops_to_copy), fwd_ops, [r.op for r in ts],
                    checkpoints_other)
        debug_print("ops_to_copy = %s", ops_to_copy)
        if not ops_to_copy:  # we're done!
            break
        copied_sgv, info = ge.copy_with_input_replacements(
            ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        debug_print("Copied %s to %s", ops_to_copy, copied_ops)
        ge.reroute_ts(checkpoints_disconnected_other,
                      checkpoints_other,
                      can_modify=copied_ops)
        debug_print("Rewired %s in place of %s restricted to %s",
                    checkpoints_disconnected_other, checkpoints_other,
                    copied_ops)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops,
                          **kwargs)
        debug_print("Got gradients %s", dv)
        debug_print("for %s", boundary)
        debug_print("with respect to %s", checkpoints_disconnected_other + xs)
        debug_print("with boundary backprop substitutions %s",
                    substitute_backprops)

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        def _unsparsify(x):
            # densify IndexedSlices so gradients can be accumulated with +=
            if not isinstance(x, tf.IndexedSlices):
                return x
            assert x.dense_shape is not None, "memory_saving_gradients encountered sparse gradients of unknown shape"
            indices = x.indices
            while indices.shape.ndims < x.values.shape.ndims:
                indices = tf.expand_dims(indices, -1)
            return tf.scatter_nd(indices, x.values, x.dense_shape)

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])

    return d_xs
def apply(self, new_inputs, update_colocation_groups=True):
    """Instantiate a copy of the captured sub-graph wired to `new_inputs`.

    Args:
        new_inputs: one replacement per entry of self.inputs; tf.Variable
            entries are replaced by their read endpoint.
        update_colocation_groups: rewrite each copied op's colocation
            constraint (`_class` attr) to point at the copied / replacement
            op instead of the original.

    Returns:
        List of output tensors of the copied sub-graph, parallel to
        self.outputs.
    """
    assert len(new_inputs) == len(self.inputs)
    g = tf.get_default_graph()  # todo: make that member variable

    # replace variable inputs with their read endpoint
    # (renamed from `input` to avoid shadowing the builtin)
    new_inputs2 = []
    for new_input in new_inputs:
        if isinstance(new_input, tf.Variable):
            new_inputs2.append(new_input.read_value())
        else:
            new_inputs2.append(new_input)
    new_inputs = new_inputs2

    replacements = OrderedDict()
    for old_input_ts, new_input_ts in zip(self.inputs, new_inputs):
        # shape/dtype checks
        if isinstance(old_input_ts, (list, tuple)):
            reference_ts = old_input_ts[0]
        else:
            reference_ts = old_input_ts
        assert reference_ts.get_shape() == new_input_ts.get_shape()
        assert reference_ts.dtype == new_input_ts.dtype
        # Variable with multiple read endpoints, replace all of them with
        # new input tensor
        if isinstance(old_input_ts, (list, tuple)):
            for sub_input in old_input_ts:
                replacements[sub_input] = new_input_ts
        # regular Tensor
        else:
            replacements[old_input_ts] = new_input_ts

    # sanity checks
    # copying Variables is confusing because they don't get added
    # to GLOBAL_VARIABLES collection hence escape global initialization
    # therefore forbid it
    for op in self.ops:
        if op.type.startswith('Variable'):  # 'VariableV2' or 'Variable'
            assert False, "Can't copy variables"

    # TODO: remove this
    def summarize_ts(ts):
        from collections import Counter
        ops = set([tensor.op for tensor in ts])
        # FIX: was a Python 2 print statement (SyntaxError under Python 3)
        print(Counter([op.type for op in ops]).most_common(10))

    sgv = ge.sgv(self.ops)
    # import pdb; pdb.set_trace()
    copied_sgv, info = ge.copy_with_input_replacements(sgv, replacements)

    # converting between Python bytes and unicode
    def to_bytes(s):
        return s.encode('ascii')

    def from_bytes(s):
        return s.decode('ascii')

    # fix colocation constraints to point to copied ops
    # see https://github.com/tensorflow/tensorflow/issues/9925
    if update_colocation_groups:
        new_ops = [info._transformed_ops[op] for op in self.ops]
        for new_op in new_ops:
            assert len(new_op.colocation_groups()) == 1
            colocation_group = new_op.colocation_groups()[0]
            assert colocation_group.startswith(b'loc:@')
            colocated_with_name = from_bytes(colocation_group[len(b'loc:@'):])
            # if there were no colocation constraints, the op gets colocated
            # with itself (default colocation group), ignore that constraint
            if colocated_with_name == new_op.name:
                continue
            colocation_op = g.get_operation_by_name(colocated_with_name)
            if colocation_op in info._transformed_ops:
                new_colocation_op = info._transformed_ops[colocation_op]
            else:
                assert colocation_op in self.input_ops
                colocation_op_idx = self.input_ops.index(colocation_op)
                new_colocation_op = new_inputs[colocation_op_idx].op
            # overwrite existing _class attribute with new colocation constraints
            new_colocation_groups = [b'loc:@' + to_bytes(new_colocation_op.name)]
            new_op.node_def.attr["_class"].CopyFrom(
                attr_value_pb2.AttrValue(
                    list=attr_value_pb2.AttrValue.ListValue(
                        s=new_colocation_groups)))

    # place new ops on device from current device context
    device = get_current_device()
    if device:
        for op in info._transformed_ops.values():
            op._set_device(device)

    new_outputs = []
    for old_output_ts in self.outputs:
        new_output_op = info._transformed_ops[old_output_ts.op]
        new_output_ts = new_output_op.outputs[0]
        new_outputs.append(new_output_ts)
    return new_outputs
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
    """Memory-saving gradients variant that additionally inserts swap-out /
    swap-in ops (OME) around checkpointed tensors.

    ys, xs, grad_ys, kwargs follow the tf.gradients signature; `checkpoints`
    is either a list of forward-pass tensors to keep, or any non-list value
    to trigger automatic bottleneck selection.
    """
    print("-------------------------------")
    debug_print("Editing model for OME")
    incpu_count = 0
    # print("Calling memsaving gradients with", checkpoints)
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True)

    for index, op in enumerate(bwd_ops):
        debug_print("bwd_ops: [{}] :{}".format(index, op.name), 1)

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    for index, op in enumerate(fwd_ops):
        debug_print("fwd_ops: [{}] : {}".format(index, op.name), 1)

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if not op in xs_ops]
    fwd_ops = [op for op in fwd_ops if not '/assign' in op.name]
    fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name]
    fwd_ops = [op for op in fwd_ops if not '/read' in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    # construct list of tensors to checkpoint during forward pass, if not
    # given as input
    if type(checkpoints) is not list:

        # remove very small tensors and some weird ops
        def fixdims(
            t
        ):  # tf.Dimension values are not compatible with int, convert manually
            try:
                return [int(e if e.value is not None else 64) for e in t]
            except:
                return [0]  # unknown shape

        ts_all = [
            t for t in ts_all
            if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE
        ]
        ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
        ts_all = [t for t in ts_all if 'entropy' not in t.name]
        ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
        ts_all = [t for t in ts_all if 'Switch' not in t.name]
        ts_all = [t for t in ts_all if 'dropout' not in t.name]
        # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
        ts_all = [t for t in ts_all if 'Cast' not in t.name]

        # filter out all tensors that are inputs of the backward graph
        with util.capture_ops() as bwd_ops:
            tf_gradients(ys, xs, grad_ys, **kwargs)

        bwd_inputs = [t for op in bwd_ops for t in op.inputs]
        # list of tensors in forward graph that is in input to bwd graph
        ts_filtered = list(set(bwd_inputs).intersection(ts_all))
        debug_print("Using tensors {}".format(ts_filtered), 1)

        # try two slightly different ways of getting bottlenecks tensors
        # to checkpoint
        for ts in [ts_filtered, ts_all]:

            # get all bottlenecks in the graph
            bottleneck_ts = []
            for t in ts:
                b = set(
                    ge.get_backward_walk_ops(t.op,
                                             inclusive=True,
                                             within_ops=fwd_ops))
                f = set(
                    ge.get_forward_walk_ops(t.op,
                                            inclusive=False,
                                            within_ops=fwd_ops))
                # check that there are not shortcuts
                b_inp = set([inp for op in b
                             for inp in op.inputs]).intersection(ts_all)
                f_inp = set([inp for op in f
                             for inp in op.inputs]).intersection(ts_all)
                if not set(b_inp).intersection(
                        f_inp) and len(b_inp) + len(f_inp) >= len(ts_all):
                    bottleneck_ts.append(t)  # we have a bottleneck!
                else:
                    debug_print(
                        "Rejected bottleneck candidate and ops {}".format(
                            [t] +
                            list(set(ts_all) - set(b_inp) - set(f_inp))), 2)

            # success? or try again without filtering?
            if len(bottleneck_ts) >= np.sqrt(
                    len(ts_filtered)):  # yes, enough bottlenecks found!
                break

        if not bottleneck_ts:
            raise Exception(
                'unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".'
            )

        # sort the bottlenecks
        bottlenecks_sorted_lists = tf_toposort(bottleneck_ts,
                                               within_ops=fwd_ops)
        sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts]

        # save an approximately optimal number ~ sqrt(N)
        N = len(ts_filtered)
        if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
            checkpoints = sorted_bottlenecks
        else:
            step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
            checkpoints = sorted_bottlenecks[step::step]

    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection happened and checkpoints is list of nodes
    assert isinstance(checkpoints, list)
    debug_print("Checkpoint nodes used: {}".format(checkpoints), 1)
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print(
            "Warning, some input nodes are also checkpoint nodes: {}".format(
                xs_intersect_checkpoints), 2)
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print(
        "ys: {}, checkpoints: {}, intersect: {}".format(
            ys, checkpoints, ys_intersect_checkpoints), 1)
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print(
            "Warning, some output nodes are also checkpoints nodes: {}".format(
                format_ops(ys_intersect_checkpoints)), 2)

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    if not checkpoints:
        raise Exception('no checkpoints nodes found or given as input! ')

    debug_print(
        "Select {} nodes to checkpoint nodes.".format(len(checkpoints)), 0)

    # disconnect dependencies between checkpointed tensors; in this OME
    # variant each kept checkpoint is additionally routed through
    # swap-out / swap-in ops (_add_swapout / _add_swapin)
    checkpoints_disconnected = {}
    for x in checkpoints:
        frontier_ops = set(graph_util.get_consuming_ops(x.op.outputs))
        debug_print("my frontier ops: {}".format(frontier_ops), 1)
        bw_frontier_ops = frontier_ops & set(bwd_ops)
        debug_print("my bw frontier ops: {}".format(bw_frontier_ops), 1)
        # skip tensors consumed by more than one backward op
        if len(bw_frontier_ops) > 1:
            continue
        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
        else:
            grad_node = tf.stop_gradient(x)
        swapout_op = _add_swapout(grad_node.op, grad_node.op.outputs)
        incpu_count = incpu_count + 1
        swapin_op = _add_swapin(swapout_op, bw_frontier_ops,
                                grad_node.op.outputs)
        checkpoints_disconnected[x] = swapin_op
        my_add_control_inputs(x, bw_frontier_ops, swapin_op)
        # control dependency -> swap_in
        # self._add_control_dependency(src_op, dest_op, swapin_op)

    # g = tf.get_default_graph()
    # print(g.get_operations())

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints,
                                    within_ops=fwd_ops)
    debug_print(
        "Found {} ops to copy within fwd_ops {}, seed {}, stop_at {}".format(
            len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints), 1)
    debug_print("ops_to_copy = {}".format(ops_to_copy), 1)
    debug_print("Processing list {}".format(ys), 1)
    # NOTE(review): info._transformed_ops / op._set_device / op._outputs are
    # private graph_editor / Operation APIs -- version sensitive
    copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    debug_print("Copied {} to {}".format(ops_to_copy, copied_ops), 1)
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print(
        "Rewired {} in place of {} restricted to {}".format(
            checkpoints_disconnected.values(),
            checkpoints_disconnected.keys(), copied_ops), 1)

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys, xs=boundary + xs, grad_ys=grad_ys,
                      **kwargs)
    debug_print("Got gradients {}".format(dv), 1)
    debug_print("for {}".format(copied_ys), 1)
    debug_print("with respect to {}".format(boundary + xs), 1)

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {
        r: dr
        for r, dr in zip(checkpoints_disconnected.keys(),
                         dv[:len(checkpoints_disconnected)])
    }
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes,
    # walking the checkpoints in reverse topological order
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list {}".format(ts), 1)
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [
            checkpoints_disconnected[r] for r in checkpoints_other
        ]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoints nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print(
            "Found {} ops to copy within {}, seed {}, stop_at {}".format(
                len(ops_to_copy), fwd_ops, [r.op for r in ts],
                checkpoints_other), 1)
        debug_print("ops_to_copy = {}".format(ops_to_copy), 1)
        if not ops_to_copy:  # we're done!
            break
        copied_sgv, info = ge.copy_with_input_replacements(
            ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        debug_print("Copied {} to {}".format(ops_to_copy, copied_ops), 1)
        ge.reroute_ts(checkpoints_disconnected_other,
                      checkpoints_other,
                      can_modify=copied_ops)
        debug_print(
            "Rewired {} in place of {} restricted to {}".format(
                checkpoints_disconnected_other, checkpoints_other,
                copied_ops), 1)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops,
                          **kwargs)
        debug_print("Got gradients {}".format(dv), 1)
        debug_print("for {}".format(boundary), 1)
        debug_print(
            "with respect to {}".format(checkpoints_disconnected_other + xs),
            1)
        debug_print(
            "with boundary backprop substitutions {}".format(
                substitute_backprops), 1)

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        def _unsparsify(x):
            # densify IndexedSlices so gradients can be accumulated with +=
            if not isinstance(x, tf.IndexedSlices):
                return x
            assert x.dense_shape is not None, "memory_saving_gradients encountered sparse gradients of unknown shape"
            indices = x.indices
            while indices.shape.ndims < x.values.shape.ndims:
                indices = tf.expand_dims(indices, -1)
            return tf.scatter_nd(indices, x.values, x.dense_shape)

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])

    return d_xs
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
    '''
    Authors: Tim Salimans & Yaroslav Bulatov

    memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost"
    by Chen et al. 2016 (https://arxiv.org/abs/1604.06174)

    ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients
    (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients)

    'checkpoints' can either be
        - a list consisting of tensors from the forward pass of the neural net
          that we should re-use when calculating the gradients in the backward pass
          all other tensors that do not appear in this list will be re-computed
        - a string specifying how this list should be determined. currently we support
            - 'speed': checkpoint all outputs of convolutions and matmuls. these ops are usually the most expensive,
                       so checkpointing them maximizes the running speed
                       (this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory)
            - 'memory': try to minimize the memory usage
                        (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint)
            - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint

    NOTE(review): in THIS fork the `checkpoints` argument is ignored — it is
    unconditionally replaced below by a hard-coded regex-based selection
    ("add" ops under *stereo*, "Conv2D" ops under *motion*). Confirm intended.
    '''
    # print("Calling memsaving gradients with", checkpoints)
    # Normalize ys/xs so the rest of the function can assume lists.
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    # All ops reachable backward from the outputs ys.
    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True)

    debug_print("bwd_ops: %s", bwd_ops)

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    debug_print("fwd_ops: %s", fwd_ops)

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if not op in xs_ops]
    fwd_ops = [op for op in fwd_ops if not '/assign' in op.name]
    fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name]
    fwd_ops = [op for op in fwd_ops if not '/read' in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    # NOTE(review): dead assignment — the user-supplied `checkpoints` value is
    # clobbered here, then replaced again by `my_ckps` a few lines below.
    checkpoints = 'collection'
    # construct list of tensors to checkpoint during forward pass, if not
    # given as input
    # Hard-coded selection for this project: checkpoint "Conv2D" outputs in the
    # motion tower and "add" outputs in the stereo tower, skipping BatchNorm.
    stereo_checkpoints = ge.filter_ts_from_regex(fwd_ops, "add")
    motion_checkpoints = ge.filter_ts_from_regex(fwd_ops, "Conv2D")
    my_ckps = []
    for x in motion_checkpoints:
        if ("motion" in x.name) and ("BatchNorm" not in x.name):
            my_ckps.append(x)
    for x in stereo_checkpoints:
        if ("stereo" in x.name) and ("BatchNorm" not in x.name):
            my_ckps.append(x)
    checkpoints = my_ckps

    # Only keep checkpoint candidates that are actual forward-pass tensors.
    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection happened and checkpoints is list of nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: %s", checkpoints)
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print("Warning, some input nodes are also checkpoint nodes: %s",
                    xs_intersect_checkpoints)
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints,
                ys_intersect_checkpoints)
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print(
            "Warning, some output nodes are also checkpoints nodes: %s",
            format_ops(ys_intersect_checkpoints))

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    if not checkpoints:
        raise Exception('no checkpoints nodes found or given as input! ')

    # disconnect dependencies between checkpointed tensors
    # (stop_gradient severs the backward edge so each segment is recomputed
    # independently; "_sg" suffix keeps names recognizable)
    checkpoints_disconnected = {}
    for x in checkpoints:
        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
        else:
            grad_node = tf.stop_gradient(x)
        checkpoints_disconnected[x] = grad_node

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints,
                                    within_ops=fwd_ops)
    debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s",
                len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints)
    debug_print("ops_to_copy = %s", ops_to_copy)
    debug_print("Processing list %s", ys)
    # Duplicate the last forward segment; the copy's inputs are rerouted to the
    # stop_gradient versions of the checkpoints.
    copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    copied_ops = info._transformed_ops.values()
    debug_print("Copied %s to %s", ops_to_copy, copied_ops)
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print("Rewired %s in place of %s restricted to %s",
                checkpoints_disconnected.values(),
                checkpoints_disconnected.keys(), copied_ops)

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys, xs=boundary + xs, grad_ys=grad_ys, **kwargs)
    debug_print("Got gradients %s", dv)
    debug_print("for %s", copied_ys)
    debug_print("with respect to %s", boundary + xs)

    # Force the copied forward segment to run only after the original ys (and
    # grad_ys) are available, so recomputation does not race the forward pass.
    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {
        r: dr
        for r, dr in zip(checkpoints_disconnected.keys(),
                         dv[:len(checkpoints_disconnected)])
    }
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes,
    # walking the checkpoints in reverse topological order
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list %s", ts)
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [
            checkpoints_disconnected[r] for r in checkpoints_other
        ]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoints nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print("Found %s ops to copy within %s, seed %s, stop_at %s",
                    len(ops_to_copy), fwd_ops, [r.op for r in ts],
                    checkpoints_other)
        debug_print("ops_to_copy = %s", ops_to_copy)
        if not ops_to_copy:  # we're done!
            break
        copied_sgv, info = ge.copy_with_input_replacements(
            ge.sgv(ops_to_copy), {})
        copied_ops = info._transformed_ops.values()
        debug_print("Copied %s to %s", ops_to_copy, copied_ops)
        ge.reroute_ts(checkpoints_disconnected_other, checkpoints_other,
                      can_modify=copied_ops)
        debug_print("Rewired %s in place of %s restricted to %s",
                    checkpoints_disconnected_other, checkpoints_other,
                    copied_ops)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops,
                          **kwargs)
        debug_print("Got gradients %s", dv)
        debug_print("for %s", boundary)
        debug_print("with respect to %s", checkpoints_disconnected_other + xs)
        debug_print("with boundary backprop substitutions %s",
                    substitute_backprops)

        # Recomputation of this segment must wait for the backprop signal
        # arriving at its checkpoints.
        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes: accumulate
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        # partial derivatives to xs (usually the params of the neural net)
        # NOTE(review): unlike the other fork in this file, this variant does
        # NOT densify tf.IndexedSlices before accumulating — confirm intended.
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = d_xs_new[j]
                else:
                    d_xs[j] += d_xs_new[j]

    return d_xs
def apply(self, new_inputs, update_colocation_groups=True):
    """Instantiate a copy of the captured subgraph wired to `new_inputs`.

    Args:
        new_inputs: one replacement per captured input, in the same order as
            `self.inputs`; `tf.Variable` entries are replaced by their read
            endpoint.
        update_colocation_groups: if True, rewrite each copied op's `_class`
            (colocation) attribute to point at the copied/replacement ops
            instead of the originals
            (see https://github.com/tensorflow/tensorflow/issues/9925).

    Returns:
        List of output tensors of the copied subgraph, parallel to
        `self.outputs`.
    """
    assert len(new_inputs) == len(self.inputs)
    g = tf.get_default_graph()  # todo: make that member variable

    new_inputs2 = []
    # replace variable inputs with their read endpoint
    for inp in new_inputs:  # renamed from `input` — don't shadow the builtin
        if isinstance(inp, tf.Variable):
            new_inputs2.append(inp.read_value())
        else:
            new_inputs2.append(inp)
    new_inputs = new_inputs2

    replacements = OrderedDict()
    for old_input_ts, new_input_ts in zip(self.inputs, new_inputs):
        # shape/dtype checks; a list/tuple entry means "a Variable with
        # multiple read endpoints" — check against its first endpoint
        if isinstance(old_input_ts, (list, tuple)):
            reference_ts = old_input_ts[0]
        else:
            reference_ts = old_input_ts
        assert reference_ts.get_shape() == new_input_ts.get_shape()
        assert reference_ts.dtype == new_input_ts.dtype

        # Variable with multiple read endpoints, replace all of them with
        # new input tensor
        if isinstance(old_input_ts, (list, tuple)):
            for sub_input in old_input_ts:
                replacements[sub_input] = new_input_ts
        # regular Tensor
        else:
            replacements[old_input_ts] = new_input_ts

    # sanity checks
    # copying Variables is confusing because they don't get added
    # to GLOBAL_VARIABLES collection hence escape global initialization
    # therefore forbid it
    for op in self.ops:
        if op.type.startswith('Variable'):  # 'VariableV2' or 'Variable'
            assert False, "Can't copy variables"

    # TODO: remove this
    def summarize_ts(ts):
        """Debug helper: print the 10 most common op types feeding `ts`."""
        from collections import Counter
        ops = set([tensor.op for tensor in ts])
        # FIX: was a Python-2 print statement (`print Counter(...)`) which is
        # a SyntaxError under Python 3.
        print(Counter([op.type for op in ops]).most_common(10))

    sgv = ge.sgv(self.ops)
    # import pdb; pdb.set_trace()
    copied_sgv, info = ge.copy_with_input_replacements(sgv, replacements)

    # converting between Python bytes and unicode
    def to_bytes(s):
        return s.encode('ascii')

    def from_bytes(s):
        return s.decode('ascii')

    # fix colocation constraints to point to copied ops
    # see https://github.com/tensorflow/tensorflow/issues/9925
    if update_colocation_groups:
        new_ops = [info._transformed_ops[op] for op in self.ops]
        for new_op in new_ops:
            assert len(new_op.colocation_groups()) == 1
            colocation_group = new_op.colocation_groups()[0]
            assert colocation_group.startswith(b'loc:@')
            colocated_with_name = from_bytes(colocation_group[len(b'loc:@'):])

            # if there were no colocation constraints, the op gets colocated
            # with itself (default colocation group), ignore that constraint
            if colocated_with_name == new_op.name:
                continue

            colocation_op = g.get_operation_by_name(colocated_with_name)
            if colocation_op in info._transformed_ops:
                new_colocation_op = info._transformed_ops[colocation_op]
            else:
                # colocated with an op outside the copied set: must be one of
                # the captured input ops, map it to its replacement's op
                assert colocation_op in self.input_ops
                colocation_op_idx = self.input_ops.index(colocation_op)
                new_colocation_op = new_inputs[colocation_op_idx].op

            # overwrite existing _class attribute with new colocation constraints
            new_colocation_groups = [b'loc:@' + to_bytes(new_colocation_op.name)]
            new_op.node_def.attr["_class"].CopyFrom(attr_value_pb2.AttrValue(
                list=attr_value_pb2.AttrValue.ListValue(s=new_colocation_groups)))

    # place new ops on device from current device context
    device = get_current_device()
    if device:
        for op in info._transformed_ops.values():
            op._set_device(device)

    new_outputs = []
    for old_output_ts in self.outputs:
        new_output_op = info._transformed_ops[old_output_ts.op]
        new_output_ts = new_output_op.outputs[0]
        new_outputs.append(new_output_ts)

    return new_outputs
def gradients(ys, xs,  # pylint: disable=too-many-statements, too-many-branches
              grad_ys=None, checkpoints='collection', **kwargs):
    '''
    Authors: Tim Salimans & Yaroslav Bulatov

    memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost"
    by Chen et al. 2016 (https://arxiv.org/abs/1604.06174)

    ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients
    (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients)

    'checkpoints' can either be
        - a list consisting of tensors from the forward pass of the neural net
          that we should re-use when calculating the gradients in the backward pass
          all other tensors that do not appear in this list will be re-computed
        - a string specifying how this list should be determined. currently we support
            - 'speed': checkpoint all outputs of convolutions and matmuls. these ops are usually the most expensive,
                       so checkpointing them maximizes the running speed
                       (this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory)
            - 'memory': try to minimize the memory usage
                        (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint)
            - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint
    '''
    # print("Calling memsaving gradients with", checkpoints)
    # Normalize ys/xs so the rest of the function can assume lists.
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    # All ops reachable backward from the outputs ys.
    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True)

    debug_print("bwd_ops: {}".format(bwd_ops))

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    debug_print("fwd_ops: {}".format(fwd_ops))

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if op not in xs_ops]
    fwd_ops = [op for op in fwd_ops if '/assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/Assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/read' not in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    # construct list of tensors to checkpoint during forward pass, if not
    # given as input
    if type(checkpoints) is not list:
        if checkpoints == 'collection':
            checkpoints = tf.get_collection('checkpoints')

        elif checkpoints == 'speed':
            # checkpoint all expensive ops to maximize running speed
            checkpoints = ge.filter_ts_from_regex(fwd_ops, 'conv2d|Conv|MatMul')

        elif checkpoints == 'memory':

            # remove very small tensors and some weird ops
            def fixdims(t):
                # tf.Dimension values are not compatible with int, convert manually
                try:
                    return [int(e if e.value is not None else 64) for e in t]
                except Exception:  # FIX: was a bare `except:`
                    return [0]  # unknown shape

            ts_all = [t for t in ts_all
                      if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE]
            ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
            ts_all = [t for t in ts_all if 'entropy' not in t.name]
            ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
            ts_all = [t for t in ts_all if 'Switch' not in t.name]
            ts_all = [t for t in ts_all if 'dropout' not in t.name]
            # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
            ts_all = [t for t in ts_all if 'Cast' not in t.name]

            # filter out all tensors that are inputs of the backward graph
            with util.capture_ops() as bwd_ops:
                tf_gradients(ys, xs, grad_ys, **kwargs)

            bwd_inputs = [t for op in bwd_ops for t in op.inputs]
            # list of tensors in forward graph that is in input to bwd graph
            ts_filtered = list(set(bwd_inputs).intersection(ts_all))
            debug_print("Using tensors {}".format(ts_filtered))

            # try two slightly different ways of getting bottlenecks tensors
            # to checkpoint
            for ts in [ts_filtered, ts_all]:

                # get all bottlenecks in the graph
                bottleneck_ts = []
                for t in ts:
                    b = set(ge.get_backward_walk_ops(t.op, inclusive=True,
                                                     within_ops=fwd_ops))
                    f = set(ge.get_forward_walk_ops(t.op, inclusive=False,
                                                    within_ops=fwd_ops))
                    # check that there are not shortcuts
                    b_inp = set([inp for op in b
                                 for inp in op.inputs]).intersection(ts_all)
                    f_inp = set([inp for op in f
                                 for inp in op.inputs]).intersection(ts_all)
                    if not set(b_inp).intersection(f_inp) and \
                            len(b_inp) + len(f_inp) >= len(ts_all):
                        bottleneck_ts.append(t)  # we have a bottleneck!
                    else:
                        debug_print("Rejected bottleneck candidate and ops {}".format(
                            [t] + list(set(ts_all) - set(b_inp) - set(f_inp))))

                # success? or try again without filtering?
                if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)):
                    # enough bottlenecks found!
                    break

            if not bottleneck_ts:
                raise Exception('unable to find bottleneck tensors! please provide checkpoint '
                                'nodes manually, or use checkpoints="speed".')

            # sort the bottlenecks
            bottlenecks_sorted_lists = tf_toposort(bottleneck_ts,
                                                   within_ops=fwd_ops)
            sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists
                                  for t in ts]

            # save an approximately optimal number ~ sqrt(N)
            N = len(ts_filtered)
            if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
                checkpoints = sorted_bottlenecks
            else:
                step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
                checkpoints = sorted_bottlenecks[step::step]

        else:
            raise Exception('%s is unsupported input for "checkpoints"' % (checkpoints,))

    # Only keep checkpoint candidates that are actual forward-pass tensors.
    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection happened and checkpoints is list of nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: {}".format(checkpoints))
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print("Warning, some input nodes are also checkpoint nodes: {}".format(
            xs_intersect_checkpoints))
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print("ys: {}, checkpoints:{}, intersect: {}".format(
        ys, checkpoints, ys_intersect_checkpoints))
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print("Warning, some output nodes are also checkpoints nodes: {}".format(
            format_ops(ys_intersect_checkpoints)))

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    if not checkpoints:
        raise Exception('no checkpoints nodes found or given as input! ')

    # disconnect dependencies between checkpointed tensors
    # (stop_gradient severs the backward edge so each segment is recomputed
    # independently; "_sg" suffix keeps names recognizable)
    checkpoints_disconnected = {}
    for x in checkpoints:
        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
        else:
            grad_node = tf.stop_gradient(x)
        checkpoints_disconnected[x] = grad_node

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints,
                                    within_ops=fwd_ops)
    debug_print("Found {} ops to copy within fwd_ops {}, seed {}, stop_at {}".format(
        len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints))
    debug_print("ops_to_copy = {}".format(ops_to_copy))
    debug_print("Processing list {}".format(ys))
    _, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    # keep the copied ops on the same devices as the originals
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    debug_print("Copied {} to {}".format(ops_to_copy, copied_ops))
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print("Rewired {} in place of {} restricted to {}".format(
        checkpoints_disconnected.values(),
        checkpoints_disconnected.keys(), copied_ops))

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys, xs=boundary + xs, grad_ys=grad_ys,
                      **kwargs)
    debug_print("Got gradients {}".format(dv))
    # FIX: was debug_print("for %s", copied_ys) — %-style multi-arg call,
    # inconsistent with every other single-string call to debug_print here.
    debug_print("for {}".format(copied_ys))
    debug_print("with respect to {}".format(boundary + xs))

    # Force the copied forward segment to run only after the original ys (and
    # grad_ys) are available, so recomputation does not race the forward pass.
    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {r: dr for r, dr in zip(checkpoints_disconnected.keys(),
                                            dv[:len(checkpoints_disconnected)])}
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes,
    # walking the checkpoints in reverse topological order
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list {}".format(ts))
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [checkpoints_disconnected[r]
                                          for r in checkpoints_other]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoints nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print("Found {} ops to copy within {}, seed {}, stop_at {}".format(
            len(ops_to_copy), fwd_ops, [r.op for r in ts], checkpoints_other))
        debug_print("ops_to_copy = {}".format(ops_to_copy))
        if not ops_to_copy:  # we're done!
            break
        _, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
        # keep the copied ops on the same devices as the originals
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        debug_print("Copied {} to {}".format(ops_to_copy, copied_ops))
        ge.reroute_ts(checkpoints_disconnected_other, checkpoints_other,
                      can_modify=copied_ops)
        # FIX: was %-style multi-arg debug_print call — same inconsistency.
        debug_print("Rewired {} in place of {} restricted to {}".format(
            checkpoints_disconnected_other, checkpoints_other, copied_ops))

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops,
                          **kwargs)
        debug_print("Got gradients {}".format(dv))
        debug_print("for {}".format(boundary))
        debug_print("with respect to {}".format(checkpoints_disconnected_other + xs))
        debug_print("with boundary backprop substitutions {}".format(substitute_backprops))

        # Recomputation of this segment must wait for the backprop signal
        # arriving at its checkpoints.
        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes: accumulate
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        def _unsparsify(var_x):
            # densify tf.IndexedSlices so they can be added with `+=` below
            if not isinstance(var_x, tf.IndexedSlices):
                return var_x
            assert var_x.dense_shape is not None, \
                "memory_saving_gradients encountered sparse gradients of unknown shape"
            indices = var_x.indices
            while indices.shape.ndims < var_x.values.shape.ndims:
                indices = tf.expand_dims(indices, -1)
            return tf.scatter_nd(indices, var_x.values, var_x.dense_shape)

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])

    return d_xs
def run(self):
    """Build the memory-saving backward graph and return gradients d_xs.

    Selects checkpoint tensors, severs their backward edges with
    tf.stop_gradient, then recomputes each forward segment on demand during
    backprop, accumulating partial derivatives into `d_xs` (one entry per
    tensor in `self._xs`, possibly None).
    """
    checkpoints = self._get_checkpoint()

    # at this point automatic selection happened and checkpoints is list of nodes
    assert isinstance(checkpoints, list)

    self._log_info("Checkpoint nodes used: {}".format(checkpoints), 1)
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(self._xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        # FIX: placeholder was "%s" inside a .format() string, so the value
        # was silently dropped from the log message.
        self._log_info(
            "Warning, some input nodes are also checkpoint nodes: {}".format(
                xs_intersect_checkpoints))
    ys_intersect_checkpoints = set(self._ys).intersection(set(checkpoints))
    # FIX: first placeholder was "%s", misaligning all three format args.
    self._log_info(
        "ys: {}, checkpoints: {}, intersect: {}".format(
            self._ys, checkpoints, ys_intersect_checkpoints), 1)
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        self._log_info(
            "Warning, some output nodes are also checkpoints nodes: {}".
            format(self.format_ops(ys_intersect_checkpoints)))

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(self._ys) - set(self._xs))

    # check that we have some nodes to checkpoint
    if not checkpoints:
        raise Exception('no checkpoints nodes found or given as input! ')

    # disconnect dependencies between checkpointed tensors
    # (stop_gradient severs the backward edge so each segment is recomputed
    # independently; "_sg" suffix keeps names recognizable)
    checkpoints_disconnected = {}
    for x in checkpoints:
        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
        else:
            grad_node = tf.stop_gradient(x)
        checkpoints_disconnected[x] = grad_node

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = self.fast_backward_ops(seed_ops=[y.op for y in self._ys],
                                         stop_at_ts=checkpoints,
                                         within_ops=self.fwd_ops)
    # FIX: first placeholder was "%s", misaligning the four format args.
    self._log_info(
        "Found {} ops to copy within fwd_ops {}, seed {}, stop_at {}".
        format(len(ops_to_copy), self.fwd_ops, [r.op for r in self._ys],
               checkpoints), 1)
    # FIX: the log level `1` was passed into .format() (ignored) instead of
    # to _log_info.
    self._log_info("ops_to_copy = {}".format(ops_to_copy), 1)
    self._log_info("Processing list {}".format(self._ys), 1)
    copied_sgv, info = ge.copy_with_input_replacements(
        ge.sgv(ops_to_copy), {})
    # keep the copied ops on the same devices as the originals
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    # FIX: the log level `1` was passed into .format() instead of _log_info.
    self._log_info("Copied {} to {}".format(ops_to_copy, copied_ops), 1)
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    # FIX: first placeholder was "%s", misaligning the three format args.
    self._log_info(
        "Rewired {} in place of {} restricted to {}".format(
            checkpoints_disconnected.values(),
            checkpoints_disconnected.keys(), copied_ops), 1)

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in self._ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys,
                      xs=boundary + self._xs,
                      grad_ys=self._grad_ys,
                      **self._kwargs)
    self._log_info("Got gradients {}".format(dv), 1)
    self._log_info("for {}".format(copied_ys), 1)
    self._log_info("with respect to {}".format(boundary + self._xs), 1)

    # Force the copied forward segment to run only after the original ys (and
    # grad_ys) are available, so recomputation does not race the forward pass.
    inputs_to_do_before = [y.op for y in self._ys]
    if self._grad_ys is not None:
        inputs_to_do_before += self._grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    self.my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {
        r: dr
        for r, dr in zip(checkpoints_disconnected.keys(),
                         dv[:len(checkpoints_disconnected)])
    }
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes,
    # walking the checkpoints in reverse topological order
    checkpoints_sorted_lists = self.tf_toposort(checkpoints,
                                                within_ops=self.fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        self._log_info("Processing list {}".format(ts), 1)
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [
            checkpoints_disconnected[r] for r in checkpoints_other
        ]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoints nodes
        ops_to_copy = self.fast_backward_ops(within_ops=self.fwd_ops,
                                             seed_ops=[r.op for r in ts],
                                             stop_at_ts=checkpoints_other)
        self._log_info(
            "Found {} ops to copy within {}, seed {}, stop_at {}".format(
                len(ops_to_copy), self.fwd_ops, [r.op for r in ts],
                checkpoints_other), 1)
        self._log_info("ops_to_copy = {}".format(ops_to_copy), 1)
        if not ops_to_copy:  # we're done!
            break
        copied_sgv, info = ge.copy_with_input_replacements(
            ge.sgv(ops_to_copy), {})
        # keep the copied ops on the same devices as the originals
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        self._log_info("Copied {} to {}".format(ops_to_copy, copied_ops), 1)
        ge.reroute_ts(checkpoints_disconnected_other,
                      checkpoints_other,
                      can_modify=copied_ops)
        self._log_info(
            "Rewired {} in place of {} restricted to {}".format(
                checkpoints_disconnected_other, checkpoints_other,
                copied_ops), 1)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + self._xs,
                          grad_ys=substitute_backprops,
                          **self._kwargs)
        self._log_info("Got gradients {}".format(dv), 1)
        self._log_info("for {}".format(boundary), 1)
        self._log_info(
            "with respect to {}".format(checkpoints_disconnected_other +
                                        self._xs), 1)
        self._log_info(
            "with boundary backprop substitutions {}".format(
                substitute_backprops), 1)

        # Recomputation of this segment must wait for the backprop signal
        # arriving at its checkpoints.
        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [
            g.op for g in dv if g is not None
        ]
        self.my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes: accumulate
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        def _unsparsify(x):
            # densify tf.IndexedSlices so they can be added with `+=` below
            if not isinstance(x, tf.IndexedSlices):
                return x
            assert x.dense_shape is not None, "memory_saving_gradients encountered sparse gradients of unknown shape"
            indices = x.indices
            while indices.shape.ndims < x.values.shape.ndims:
                indices = tf.expand_dims(indices, -1)
            return tf.scatter_nd(indices, x.values, x.dense_shape)

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(self._xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])

    return d_xs