def __init__(self, inputs, outputs, within_ops=None):
    assert isinstance(inputs, (list, tuple))
    assert isinstance(outputs, (list, tuple))

    g = tf.get_default_graph()
    self.g = g

    # special handling for Variable objects because graph_editor only
    # supports Tensor objects. Instead of replacing Variable with new inputs,
    # replace all of its read endpoints.
    #
    # Find all Identity ops connected to the variable op, replace them all in
    # tandem on apply.
    # [tensor, tensor, variable] becomes
    # [tensor, tensor, (var_read1, var_read2)]
    new_inputs = []
    for input in inputs:
        if isinstance(input, tf.Variable):
            # find all the read endpoints of the variable
            read_tensors = []
            for consumer in input.op.outputs[0].consumers():
                if consumer.type == 'Identity':
                    read_tensors.append(consumer.outputs[0])
            new_inputs.append(read_tensors)
        else:
            new_inputs.append(input)
    inputs = new_inputs

    self.inputs = inputs
    self.input_ops = [tensor.op for tensor in flatten1(inputs)]
    self.outputs = outputs

    # only support 1 output for now, need extra logic in apply
    for output_ts in outputs:
        num_siblings = len(output_ts.op.outputs) - 1
        assert num_siblings == 0

    # obtain the part of the graph that's needed
    output_ops = [ts.op for ts in self.outputs]
    self.ops = ge.get_backward_walk_ops(output_ops,
                                        inclusive=True,
                                        stop_at_ts=flatten1(self.inputs),
                                        within_ops=within_ops)
    # workaround for https://github.com/tensorflow/tensorflow/issues/9978
    clear_original_ops(self.ops)
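# The Variable handling above relies on the TF1 convention that reading a
# tf.Variable goes through an Identity op (typically named "<var>/read").
# A minimal sketch of what that looks like, assuming TF1 graph mode; the
# variable name is illustrative only.
import tensorflow as tf

v = tf.Variable(tf.zeros([3]), name="v")
# The read endpoints are outputs of Identity ops consuming the variable op's
# output; these are the tensors swapped in place of the Variable itself.
read_tensors = [c.outputs[0] for c in v.op.outputs[0].consumers()
                if c.type == 'Identity']
print(read_tensors)  # e.g. [<tf.Tensor 'v/read:0' shape=(3,) dtype=float32>]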
def dependency_of_targets(targets, op):
    """
    Check that op is in the subgraph induced by the dependencies of targets.
    The result is memoized.

    This is useful if some SessionRunHooks should be run only together with certain ops.

    Args:
        targets: a tuple of ops or tensors. The targets to find dependencies of.
        op (tf.Operation or tf.Tensor): the op to check.

    Returns:
        bool
    """
    # TODO tensorarray? sparsetensor?
    if isinstance(op, tf.Tensor):
        op = op.op
    assert isinstance(op, tf.Operation), op

    # alternative implementation can use graph_util.extract_sub_graph
    dependent_ops = get_backward_walk_ops(targets, control_inputs=True)
    return op in dependent_ops
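# A small usage sketch for dependency_of_targets, assuming TF1 graph mode.
# The graph below is invented for illustration; only the call pattern matters.
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4])
w = tf.get_variable("w", [4, 1])
loss = tf.reduce_mean(tf.matmul(x, w))
counter = tf.get_variable("counter", [], tf.int32,
                          initializer=tf.zeros_initializer())
bump = tf.assign_add(counter, 1)

print(dependency_of_targets([loss], w.op))  # True: loss reads w
print(dependency_of_targets([loss], bump))  # False: the counter update is unrelated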
def gradients(ys, xs,  # pylint: disable=too-many-statements, too-many-branches
              grad_ys=None, checkpoints='collection', **kwargs):
    '''
    Authors: Tim Salimans & Yaroslav Bulatov

    memory efficient gradient implementation inspired by "Training Deep Nets
    with Sublinear Memory Cost" by Chen et al. 2016 (https://arxiv.org/abs/1604.06174)

    ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients
    (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients)

    'checkpoints' can either be
        - a list consisting of tensors from the forward pass of the neural net
          that we should re-use when calculating the gradients in the backward pass;
          all other tensors that do not appear in this list will be re-computed
        - a string specifying how this list should be determined. currently we support
            - 'speed': checkpoint all outputs of convolutions and matmuls. these ops
                       are usually the most expensive, so checkpointing them maximizes
                       the running speed (this is a good option if nonlinearities,
                       concats, batchnorms, etc are taking up a lot of memory)
            - 'memory': try to minimize the memory usage (currently using a very
                        simple strategy that identifies a number of bottleneck
                        tensors in the graph to checkpoint)
            - 'collection': look for a tensorflow collection named 'checkpoints',
                            which holds the tensors to checkpoint
    '''
    # print("Calling memsaving gradients with", checkpoints)
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True)
    debug_print("bwd_ops: {}".format(bwd_ops))

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    debug_print("fwd_ops: {}".format(fwd_ops))

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if op not in xs_ops]
    fwd_ops = [op for op in fwd_ops if '/assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/Assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/read' not in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    # construct list of tensors to checkpoint during forward pass, if not
    # given as input
    if type(checkpoints) is not list:
        if checkpoints == 'collection':
            checkpoints = tf.get_collection('checkpoints')

        elif checkpoints == 'speed':
            # checkpoint all expensive ops to maximize running speed
            checkpoints = ge.filter_ts_from_regex(fwd_ops, 'conv2d|Conv|MatMul')

        elif checkpoints == 'memory':

            # remove very small tensors and some weird ops
            def fixdims(t):  # tf.Dimension values are not compatible with int, convert manually
                try:
                    return [int(e if e.value is not None else 64) for e in t]
                except:
                    return [0]  # unknown shape
            ts_all = [t for t in ts_all
                      if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE]
            ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
            ts_all = [t for t in ts_all if 'entropy' not in t.name]
            ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
            ts_all = [t for t in ts_all if 'Switch' not in t.name]
            ts_all = [t for t in ts_all if 'dropout' not in t.name]
            # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
            ts_all = [t for t in ts_all if 'Cast' not in t.name]

            # filter out all tensors that are inputs of the backward graph
            with util.capture_ops() as bwd_ops:
                tf_gradients(ys, xs, grad_ys, **kwargs)

            bwd_inputs = [t for op in bwd_ops for t in op.inputs]
            # list of tensors in the forward graph that are inputs to the bwd graph
            ts_filtered = list(set(bwd_inputs).intersection(ts_all))
            debug_print("Using tensors {}".format(ts_filtered))

            # try two slightly different ways of getting bottleneck tensors
            # to checkpoint
            for ts in [ts_filtered, ts_all]:

                # get all bottlenecks in the graph
                bottleneck_ts = []
                for t in ts:
                    b = set(ge.get_backward_walk_ops(t.op, inclusive=True, within_ops=fwd_ops))
                    f = set(ge.get_forward_walk_ops(t.op, inclusive=False, within_ops=fwd_ops))
                    # check that there are no shortcuts
                    b_inp = set([inp for op in b for inp in op.inputs]).intersection(ts_all)
                    f_inp = set([inp for op in f for inp in op.inputs]).intersection(ts_all)
                    if not set(b_inp).intersection(f_inp) and len(b_inp) + len(f_inp) >= len(ts_all):
                        bottleneck_ts.append(t)  # we have a bottleneck!
                    else:
                        debug_print("Rejected bottleneck candidate and ops {}".format(
                            [t] + list(set(ts_all) - set(b_inp) - set(f_inp))))

                # success? or try again without filtering?
                if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)):  # enough bottlenecks found!
                    break

            if not bottleneck_ts:
                raise Exception('unable to find bottleneck tensors! please provide checkpoint '
                                'nodes manually, or use checkpoints="speed".')

            # sort the bottlenecks
            bottlenecks_sorted_lists = tf_toposort(bottleneck_ts, within_ops=fwd_ops)
            sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts]

            # save an approximately optimal number ~ sqrt(N)
            N = len(ts_filtered)
            if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
                checkpoints = sorted_bottlenecks
            else:
                step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
                checkpoints = sorted_bottlenecks[step::step]

        else:
            raise Exception('%s is unsupported input for "checkpoints"' % (checkpoints,))

    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection happened and checkpoints is a list of nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: {}".format(checkpoints))
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print("Warning, some input nodes are also checkpoint nodes: {}".format(
            xs_intersect_checkpoints))
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print("ys: {}, checkpoints: {}, intersect: {}".format(
        ys, checkpoints, ys_intersect_checkpoints))
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print("Warning, some output nodes are also checkpoint nodes: {}".format(
            format_ops(ys_intersect_checkpoints)))

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    if not checkpoints:
        raise Exception('no checkpoints nodes found or given as input!')

    # disconnect dependencies between checkpointed tensors
    checkpoints_disconnected = {}
    for x in checkpoints:
        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
        else:
            grad_node = tf.stop_gradient(x)
        checkpoints_disconnected[x] = grad_node

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints,
                                    within_ops=fwd_ops)
    debug_print("Found {} ops to copy within fwd_ops {}, seed {}, stop_at {}".format(
        len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints))
    debug_print("ops_to_copy = {}".format(ops_to_copy))
    debug_print("Processing list {}".format(ys))
    _, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    debug_print("Copied {} to {}".format(ops_to_copy, copied_ops))
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print("Rewired {} in place of {} restricted to {}".format(
        checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops))

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys, xs=boundary + xs, grad_ys=grad_ys, **kwargs)

    debug_print("Got gradients {}".format(dv))
    debug_print("for {}".format(copied_ys))
    debug_print("with respect to {}".format(boundary + xs))

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {r: dr for r, dr in zip(checkpoints_disconnected.keys(),
                                            dv[:len(checkpoints_disconnected)])}
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list {}".format(ts))
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [checkpoints_disconnected[r] for r in checkpoints_other]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoint nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print("Found {} ops to copy within {}, seed {}, stop_at {}".format(
            len(ops_to_copy), fwd_ops, [r.op for r in ts], checkpoints_other))
        debug_print("ops_to_copy = {}".format(ops_to_copy))
        if not ops_to_copy:  # we're done!
            break

        _, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        debug_print("Copied {} to {}".format(ops_to_copy, copied_ops))
        ge.reroute_ts(checkpoints_disconnected_other, checkpoints_other,
                      can_modify=copied_ops)
        debug_print("Rewired {} in place of {} restricted to {}".format(
            checkpoints_disconnected_other, checkpoints_other, copied_ops))

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops, **kwargs)
        debug_print("Got gradients {}".format(dv))
        debug_print("for {}".format(boundary))
        debug_print("with respect to {}".format(checkpoints_disconnected_other + xs))
        debug_print("with boundary backprop substitutions {}".format(substitute_backprops))

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        def _unsparsify(var_x):
            if not isinstance(var_x, tf.IndexedSlices):
                return var_x
            assert var_x.dense_shape is not None, \
                "memory_saving_gradients encountered sparse gradients of unknown shape"
            indices = var_x.indices
            while indices.shape.ndims < var_x.values.shape.ndims:
                indices = tf.expand_dims(indices, -1)
            return tf.scatter_nd(indices, var_x.values, var_x.dense_shape)

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])

    return d_xs
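# A hedged usage sketch of the 'collection' strategy described in the docstring
# above. build_model is a hypothetical function that tags the activations it
# wants kept in memory via tf.add_to_collection('checkpoints', t); the rest is
# standard TF1 training boilerplate.
import tensorflow as tf

loss = build_model()  # hypothetical; internally calls tf.add_to_collection('checkpoints', ...)
params = tf.trainable_variables()
grads = gradients(loss, params, checkpoints='collection')
train_op = tf.train.GradientDescentOptimizer(0.1).apply_gradients(
    list(zip(grads, params)))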
def fast_backward_ops(within_ops, seed_ops, stop_at_ts):
    """Return the ops in `within_ops` reachable by a backward walk from `seed_ops`,
    stopping at `stop_at_ts` and excluding the ops that produce those tensors."""
    bwd_ops = set(ge.get_backward_walk_ops(seed_ops, stop_at_ts=stop_at_ts))
    ops = bwd_ops.intersection(within_ops).difference([t.op for t in stop_at_ts])
    return list(ops)
def fast_backward_ops(within_ops, seed_ops, stop_at_ts):
    bwd_ops = set(ge.get_backward_walk_ops(seed_ops, stop_at_ts=stop_at_ts))
    ops = bwd_ops.intersection(within_ops).difference(
        [t.op for t in stop_at_ts])
    return list(ops)
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
    '''
    Authors: Tim Salimans & Yaroslav Bulatov

    memory efficient gradient implementation inspired by "Training Deep Nets
    with Sublinear Memory Cost" by Chen et al. 2016 (https://arxiv.org/abs/1604.06174)

    ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients
    (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients)

    'checkpoints' can either be
        - a list consisting of tensors from the forward pass of the neural net
          that we should re-use when calculating the gradients in the backward pass;
          all other tensors that do not appear in this list will be re-computed
        - a string specifying how this list should be determined. currently we support
            - 'speed': checkpoint all outputs of convolutions and matmuls. these ops
                       are usually the most expensive, so checkpointing them maximizes
                       the running speed (this is a good option if nonlinearities,
                       concats, batchnorms, etc are taking up a lot of memory)
            - 'memory': try to minimize the memory usage (currently using a very
                        simple strategy that identifies a number of bottleneck
                        tensors in the graph to checkpoint)
            - 'collection': look for a tensorflow collection named 'checkpoints',
                            which holds the tensors to checkpoint
    '''
    # print("Calling memsaving gradients with", checkpoints)
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True)
    debug_print("bwd_ops: %s", bwd_ops)

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    debug_print("fwd_ops: %s", fwd_ops)

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if op not in xs_ops]
    fwd_ops = [op for op in fwd_ops if '/assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/Assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/read' not in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    # construct list of tensors to checkpoint during forward pass, if not
    # given as input
    if type(checkpoints) is not list:
        if checkpoints == 'collection':
            checkpoints = tf.get_collection('checkpoints')

        elif checkpoints == 'speed':
            # checkpoint all expensive ops to maximize running speed
            checkpoints = ge.filter_ts_from_regex(fwd_ops, 'conv2d|Conv|MatMul')

        elif checkpoints == 'memory':

            # remove very small tensors and some weird ops
            def fixdims(t):  # tf.Dimension values are not compatible with int, convert manually
                try:
                    return [int(e if e.value is not None else 64) for e in t]
                except:
                    return [0]  # unknown shape
            ts_all = [t for t in ts_all
                      if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE]
            ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
            ts_all = [t for t in ts_all if 'entropy' not in t.name]
            ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
            ts_all = [t for t in ts_all if 'Switch' not in t.name]
            ts_all = [t for t in ts_all if 'dropout' not in t.name]

            # filter out all tensors that are inputs of the backward graph
            with util.capture_ops() as bwd_ops:
                tf_gradients(ys, xs, grad_ys, **kwargs)

            bwd_inputs = [t for op in bwd_ops for t in op.inputs]
            # list of tensors in forward graph that is in input to bwd graph
            ts_filtered = list(set(bwd_inputs).intersection(ts_all))
            debug_print("Using tensors %s", ts_filtered)

            # try two slightly different ways of getting bottlenecks tensors
            # to checkpoint
            for ts in [ts_filtered, ts_all]:

                # get all bottlenecks in the graph
                bottleneck_ts = []
                for t in ts:
                    b = set(ge.get_backward_walk_ops(t.op, inclusive=True,
                                                     within_ops=fwd_ops))
                    f = set(ge.get_forward_walk_ops(t.op, inclusive=False,
                                                    within_ops=fwd_ops))
                    # check that there are not shortcuts
                    b_inp = set([inp for op in b for inp in op.inputs]).intersection(ts_all)
                    f_inp = set([inp for op in f for inp in op.inputs]).intersection(ts_all)
                    if not set(b_inp).intersection(f_inp) and len(b_inp) + len(f_inp) >= len(ts_all):
                        bottleneck_ts.append(t)  # we have a bottleneck!
                    else:
                        debug_print("Rejected bottleneck candidate and ops %s",
                                    [t] + list(set(ts_all) - set(b_inp) - set(f_inp)))

                # success? or try again without filtering?
                if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)):  # yes, enough bottlenecks found!
                    break

            if not bottleneck_ts:
                raise Exception(
                    'unable to find bottleneck tensors! please provide checkpoint '
                    'nodes manually, or use checkpoints="speed".')

            # sort the bottlenecks
            bottlenecks_sorted_lists = tf_toposort(bottleneck_ts, within_ops=fwd_ops)
            sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts]

            # save an approximately optimal number ~ sqrt(N)
            N = len(ts_filtered)
            if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
                checkpoints = sorted_bottlenecks
            else:
                step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
                checkpoints = sorted_bottlenecks[step::step]

        else:
            raise Exception('%s is unsupported input for "checkpoints"' % (checkpoints,))

    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection happened and checkpoints is list of nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: %s", checkpoints)
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print("Warning, some input nodes are also checkpoint nodes: %s",
                    xs_intersect_checkpoints)
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print("ys: %s, checkpoints: %s, intersect: %s",
                ys, checkpoints, ys_intersect_checkpoints)
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print("Warning, some output nodes are also checkpoints nodes: %s",
                    format_ops(ys_intersect_checkpoints))

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    if not checkpoints:
        raise Exception('no checkpoints nodes found or given as input!')

    # disconnect dependencies between checkpointed tensors
    checkpoints_disconnected = {}
    for x in checkpoints:
        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
        else:
            grad_node = tf.stop_gradient(x)
        checkpoints_disconnected[x] = grad_node

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints,
                                    within_ops=fwd_ops)
    debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s",
                len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints)
    debug_print("ops_to_copy = %s", ops_to_copy)
    debug_print("Processing list %s", ys)
    copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    debug_print("Copied %s to %s", ops_to_copy, copied_ops)
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print("Rewired %s in place of %s restricted to %s",
                checkpoints_disconnected.values(),
                checkpoints_disconnected.keys(), copied_ops)

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys, xs=boundary + xs, grad_ys=grad_ys, **kwargs)

    debug_print("Got gradients %s", dv)
    debug_print("for %s", copied_ys)
    debug_print("with respect to %s", boundary + xs)

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {
        r: dr
        for r, dr in zip(checkpoints_disconnected.keys(),
                         dv[:len(checkpoints_disconnected)])
    }
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list %s", ts)
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [
            checkpoints_disconnected[r] for r in checkpoints_other
        ]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoints nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print("Found %s ops to copy within %s, seed %s, stop_at %s",
                    len(ops_to_copy), fwd_ops, [r.op for r in ts],
                    checkpoints_other)
        debug_print("ops_to_copy = %s", ops_to_copy)
        if not ops_to_copy:  # we're done!
            break

        copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        debug_print("Copied %s to %s", ops_to_copy, copied_ops)
        ge.reroute_ts(checkpoints_disconnected_other, checkpoints_other,
                      can_modify=copied_ops)
        debug_print("Rewired %s in place of %s restricted to %s",
                    checkpoints_disconnected_other, checkpoints_other, copied_ops)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops,
                          **kwargs)
        debug_print("Got gradients %s", dv)
        debug_print("for %s", boundary)
        debug_print("with respect to %s", checkpoints_disconnected_other + xs)
        debug_print("with boundary backprop substitutions %s", substitute_backprops)

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = d_xs_new[j]
                else:
                    d_xs[j] += d_xs_new[j]

    return d_xs
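# This variant keeps the same signature as tf.gradients, so a common way to
# wire it in is to wrap it with a fixed strategy and point tf.gradients at the
# wrapper so existing optimizer code picks it up. A sketch, assuming this file
# is importable under the module name memory_saving_gradients (an assumption,
# not something this snippet defines).
import tensorflow as tf
import memory_saving_gradients  # assumed module name for this code

def gradients_memory(ys, xs, grad_ys=None, **kwargs):
    # always use the automatic 'memory' checkpoint selection
    return memory_saving_gradients.gradients(ys, xs, grad_ys,
                                             checkpoints='memory', **kwargs)

# route existing code that calls tf.gradients through the memory-saving version
tf.__dict__["gradients"] = gradients_memory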
def _get_checkpoint(self):
    # exclude ops with no inputs
    self.fwd_ops = [op for op in self.fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = self._to_ops(self._xs)
    self.fwd_ops = [op for op in self.fwd_ops if op not in xs_ops]
    self.fwd_ops = [op for op in self.fwd_ops if '/assign' not in op.name]
    self.fwd_ops = [op for op in self.fwd_ops if '/Assign' not in op.name]
    self.fwd_ops = [op for op in self.fwd_ops if '/read' not in op.name]
    ts_all = ge.filter_ts(self.fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(self._xs) - set(self._ys)

    # construct list of tensors to checkpoint during forward pass, if not
    # given as input

    # remove very small tensors and some weird ops
    def fixdims(t):  # tf.Dimension values are not compatible with int, convert manually
        try:
            return [int(e if e.value is not None else 64) for e in t]
        except:
            return [0]  # unknown shape

    ts_all = [t for t in ts_all
              if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE]
    ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
    ts_all = [t for t in ts_all if 'entropy' not in t.name]
    ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
    ts_all = [t for t in ts_all if 'Switch' not in t.name]
    ts_all = [t for t in ts_all if 'dropout' not in t.name]
    # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
    ts_all = [t for t in ts_all if 'Cast' not in t.name]
    self._log_info("ts_all, {}".format(ts_all))

    # filter out all tensors that are inputs of the backward graph
    with util.capture_ops() as bwd_ops:
        tf_gradients(self._ys, self._xs, self._grad_ys, **self._kwargs)

    bwd_inputs = [t for op in self.bwd_ops for t in op.inputs]
    self._log_info("bwd_inputs, {}".format(bwd_inputs))
    # list of tensors in the forward graph that are inputs to the bwd graph
    ts_filtered = list(set(bwd_inputs).intersection(ts_all))
    self._log_info("Using tensors {}".format(ts_filtered), 2)

    # try two slightly different ways of getting bottleneck tensors
    # to checkpoint
    for ts in [ts_filtered, ts_all]:

        # get all bottlenecks in the graph
        bottleneck_ts = []
        for t in ts:
            b = set(ge.get_backward_walk_ops(t.op, inclusive=True,
                                             within_ops=self.fwd_ops))
            f = set(ge.get_forward_walk_ops(t.op, inclusive=False,
                                            within_ops=self.fwd_ops))
            # check that there are no shortcuts
            b_inp = set([inp for op in b for inp in op.inputs]).intersection(ts_all)
            f_inp = set([inp for op in f for inp in op.inputs]).intersection(ts_all)
            if not set(b_inp).intersection(f_inp) and len(b_inp) + len(f_inp) >= len(ts_all):
                bottleneck_ts.append(t)  # we have a bottleneck!
            else:
                self._log_info(
                    "Rejected bottleneck candidate and ops {}".format(
                        [t] + list(set(ts_all) - set(b_inp) - set(f_inp))), 3)

        # success? or try again without filtering?
        if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)):  # yes, enough bottlenecks found!
            break

    if not bottleneck_ts:
        raise Exception(
            'unable to find bottleneck tensors! please provide checkpoint '
            'nodes manually, or use checkpoints="speed".')

    # sort the bottlenecks
    bottlenecks_sorted_lists = self.tf_toposort(bottleneck_ts,
                                                within_ops=self.fwd_ops)
    sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts]

    # save an approximately optimal number ~ sqrt(N)
    N = len(ts_filtered)
    if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
        checkpoints = sorted_bottlenecks
    else:
        step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
        checkpoints = sorted_bottlenecks[step::step]

    checkpoints = list(set(checkpoints).intersection(ts_all))

    return checkpoints
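# The tail of _get_checkpoint thins the sorted bottlenecks down to roughly
# sqrt(N) evenly spaced entries. A toy illustration of that arithmetic, with
# plain integers standing in for tensors:
import numpy as np

sorted_bottlenecks = list(range(16))  # pretend 16 topologically sorted bottleneck tensors
N = 16                                # pretend len(ts_filtered) == 16
step = int(np.ceil(len(sorted_bottlenecks) / np.sqrt(N)))  # ceil(16 / 4) = 4
print(sorted_bottlenecks[step::step])  # [4, 8, 12] -> ~sqrt(N) checkpoints kept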