def add_preprocessing(g, preproc_g):
  # type: (gde.Graph, gde.Graph) -> None
  """
  Add preprocessing ops to a graph.

  Replaces one or more input `Placeholder` ops in the target graph with
  subgraphs that preprocess the input values prior to feeding them into the
  original graph. After performing this rewrite, the inputs of the resulting
  graph may have a different shape and dtype than before, but they will have
  the same names.

  Args:
    g: `gde.Graph` to which preprocessing should be added.
      *Modified in place.*
    preproc_g: `gde.Graph` containing the preprocessing ops to add. For each
      placeholder in `g` that needs preprocessing, `preproc_g` should contain
      a placeholder with the same name and a second op named
      "<name of placeholder>_preprocessed", where `<name of placeholder>` is
      the name of the Placeholder op.
  """
  placeholders = gde.filter_ops_by_optype(preproc_g, "Placeholder")

  def preproc_name(placeholder_name):
    return placeholder_name + "_preprocessed"

  def orig_name(placeholder_name):
    return "__original__" + placeholder_name

  # Validate before modifying the graph
  for p in placeholders:
    if not g.contains_node(p.name):
      raise ValueError("Preprocessing graph contains a Placeholder called "
                       "'{}', but target graph does not have an input "
                       "Placeholder by that name.".format(p.name))
    if not preproc_g.contains_node(preproc_name(p.name)):
      raise ValueError("Preprocessing graph contains a Placeholder called "
                       "'{}', but it does not have an output node called "
                       "'{}' to produce the preprocessed version of that "
                       "input.".format(p.name, preproc_name(p.name)))

  # Rename all the target placeholders so we can bulk-copy the preprocessing
  # graph.
  for p in placeholders:
    g.rename_node(p.name, orig_name(p.name))

  # Now it should be safe to copy the preprocessing graph into the original
  # graph.
  gde.copy(preproc_g, g)

  for p in placeholders:
    preproc_p = g.get_node_by_name(preproc_name(p.name))
    orig_p = g.get_node_by_name(orig_name(p.name))

    # Reroute all connections from the original placeholder to go to the
    # corresponding output of the preprocessing graph.
    gde.reroute_ts(preproc_p.output(0), orig_p.output(0))

    # Get rid of the original placeholder
    g.remove_node_by_name(orig_p.name)
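# Hypothetical usage sketch for add_preprocessing, not part of the original
# source. It assumes a target graph whose input Placeholder is named "image"
# and a small preprocessing graph that decodes raw JPEG bytes into that
# input; the names "image", "image_preprocessed", and "mean_pixel" are
# illustrative only.
def _example_add_preprocessing():
  import tensorflow as tf
  import graph_def_editor as gde

  # Toy preprocessing graph: raw JPEG bytes -> float image tensor. Per the
  # contract above, the placeholder shares the target input's name and the
  # final op is named "<name>_preprocessed".
  tf_pre = tf.Graph()
  with tf_pre.as_default():
    raw = tf.compat.v1.placeholder(tf.string, shape=[], name="image")
    decoded = tf.io.decode_jpeg(raw, channels=3)
    tf.cast(decoded, tf.float32, name="image_preprocessed")

  # Toy target graph whose "image" placeholder expects the decoded tensor.
  tf_main = tf.Graph()
  with tf_main.as_default():
    img = tf.compat.v1.placeholder(tf.float32, shape=[None, None, 3],
                                   name="image")
    tf.reduce_mean(img, name="mean_pixel")

  g = gde.Graph(tf_main.as_graph_def())
  preproc_g = gde.Graph(tf_pre.as_graph_def())
  add_preprocessing(g, preproc_g)
  # After the rewrite, g still has an input named "image", but it now takes
  # a scalar string of raw JPEG bytes instead of a float image tensor.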
def test_reroute(self):
  gde.reroute_ts([self.a0, self.b0], [self.a1, self.b1])
  self.assertTrue(gde.OpMatcher("c0").input_ops("a0", "b0")(self.c0.op))
  self.assertTrue(gde.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op))

  gde.reroute_ts([self.a1, self.b1], [self.a0, self.b0])
  self.assertTrue(gde.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op))
  self.assertTrue(gde.OpMatcher("c1").input_ops("a1", "b1")(self.c1.op))
def _graft_pre_and_post_processing_to_main_graph(g):
  # type: (gde.Graph) -> None
  """
  Attach pre- and post-processing subgraphs to the main graph.

  Args:
    g: GDE representation of the core graph. Modified in place.
  """
  # Build the pre- and post-processing subgraphs and import into GDE
  pre_g = gde.Graph(_build_preprocessing_graph_def())
  post_g = gde.Graph(_build_postprocessing_graph_def())

  # Replace the graph's input placeholder with the contents of our
  # pre-processing graph.
  name_of_input_node = _INPUT_NODE_NAMES[0]
  gde.copy(pre_g, g)
  gde.reroute_ts(g.get_node_by_name("preprocessed_image").output(0),
                 g.get_node_by_name(name_of_input_node).output(0))
  g.remove_node_by_name(name_of_input_node)
  g.rename_node("raw_image", name_of_input_node)

  # Tack on the postprocessing graph at the original output and rename
  # the postprocessed output to the original output's name.
  # The original graph produces an output called "detection_classes".
  # The postprocessing graph goes from "detection_classes" to
  # "decoded_detection_classes".
  # The graph after modification produces decoded classes under the original
  # "detection_classes" name. The original output is renamed to
  # "raw_detection_classes".
  g.rename_node("detection_classes", "raw_detection_classes")
  gde.copy(post_g, g)
  gde.reroute_ts(g.get_node_by_name("raw_detection_classes").output(0),
                 g.get_node_by_name("detection_classes").output(0))
  g.remove_node_by_name("detection_classes")
  g.rename_node("decoded_detection_classes", "detection_classes")
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
    '''
    Authors: Tim Salimans & Yaroslav Bulatov

    memory efficient gradient implementation inspired by "Training Deep Nets
    with Sublinear Memory Cost" by Chen et al. 2016
    (https://arxiv.org/abs/1604.06174)

    ys, xs, grad_ys, kwargs are the arguments to standard tensorflow
    tf.gradients
    (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients)

    'checkpoints' can either be
        - a list consisting of tensors from the forward pass of the neural
          net that we should re-use when calculating the gradients in the
          backward pass; all other tensors that do not appear in this list
          will be re-computed
        - a string specifying how this list should be determined. currently
          we support
            - 'speed': checkpoint all outputs of convolutions and matmuls.
              these ops are usually the most expensive, so checkpointing them
              maximizes the running speed (this is a good option if
              nonlinearities, concats, batchnorms, etc are taking up a lot of
              memory)
            - 'memory': try to minimize the memory usage (currently using a
              very simple strategy that identifies a number of bottleneck
              tensors in the graph to checkpoint)
            - 'collection': look for a tensorflow collection named
              'checkpoints', which holds the tensors to checkpoint
    '''
    # print("Calling memsaving gradients with", checkpoints)
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True)
    debug_print("bwd_ops: %s", bwd_ops)

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    debug_print("fwd_ops: %s", fwd_ops)

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if not op in xs_ops]
    fwd_ops = [op for op in fwd_ops if not '/assign' in op.name]
    fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name]
    fwd_ops = [op for op in fwd_ops if not '/read' in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    # construct list of tensors to checkpoint during forward pass, if not
    # given as input
    if type(checkpoints) is not list:
        if checkpoints == 'collection':
            checkpoints = tf.compat.v1.get_collection('checkpoints')

        elif checkpoints == 'speed':
            # checkpoint all expensive ops to maximize running speed
            checkpoints = ge.filter_ts_from_regex(fwd_ops,
                                                  'conv2d|Conv|MatMul')

        elif checkpoints == 'memory':
            # remove very small tensors and some weird ops
            def fixdims(t):
                # tf.Dimension values are not compatible with int,
                # convert manually
                try:
                    return [int(e if e.value is not None else 64) for e in t]
                except:
                    return [0]  # unknown shape
            ts_all = [t for t in ts_all
                      if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE]
            ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
            ts_all = [t for t in ts_all if 'entropy' not in t.name]
            ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
            ts_all = [t for t in ts_all if 'Switch' not in t.name]
            ts_all = [t for t in ts_all if 'dropout' not in t.name]
            # DV: FP16_FIX - need to add 'Cast' layer here to make it work
            # for FP16
            ts_all = [t for t in ts_all if 'Cast' not in t.name]

            # filter out all tensors that are inputs of the backward graph
            with util.capture_ops() as bwd_ops:
                tf_gradients(ys, xs, grad_ys, **kwargs)

            bwd_inputs = [t for op in bwd_ops for t in op.inputs]
            # list of tensors in forward graph that is in input to bwd graph
            ts_filtered = list(set(bwd_inputs).intersection(ts_all))
            debug_print("Using tensors %s", ts_filtered)

            # try two slightly different ways of getting bottlenecks tensors
            # to checkpoint
            for ts in [ts_filtered, ts_all]:

                # get all bottlenecks in the graph
                bottleneck_ts = []
                for t in ts:
                    b = set(ge.get_backward_walk_ops(t.op, inclusive=True,
                                                     within_ops=fwd_ops))
                    f = set(ge.get_forward_walk_ops(t.op, inclusive=False,
                                                    within_ops=fwd_ops))
                    # check that there are not shortcuts
                    b_inp = set([inp for op in b for inp in op.inputs]
                                ).intersection(ts_all)
                    f_inp = set([inp for op in f for inp in op.inputs]
                                ).intersection(ts_all)
                    if (not set(b_inp).intersection(f_inp)
                            and len(b_inp) + len(f_inp) >= len(ts_all)):
                        bottleneck_ts.append(t)  # we have a bottleneck!
                    else:
                        debug_print(
                            "Rejected bottleneck candidate and ops %s",
                            [t] + list(set(ts_all) - set(b_inp) - set(f_inp)))

                # success? or try again without filtering?
                if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)):
                    # yes, enough bottlenecks found!
                    break

            if not bottleneck_ts:
                raise Exception(
                    'unable to find bottleneck tensors! please provide '
                    'checkpoint nodes manually, or use checkpoints="speed".')

            # sort the bottlenecks
            bottlenecks_sorted_lists = tf_toposort(bottleneck_ts,
                                                   within_ops=fwd_ops)
            sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists
                                  for t in ts]

            # save an approximately optimal number ~ sqrt(N)
            N = len(ts_filtered)
            if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
                checkpoints = sorted_bottlenecks
            else:
                step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
                checkpoints = sorted_bottlenecks[step::step]

        else:
            raise Exception('%s is unsupported input for "checkpoints"'
                            % (checkpoints,))

    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection happened and checkpoints is list of
    # nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: %s", checkpoints)
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print("Warning, some input nodes are also checkpoint nodes: %s",
                    xs_intersect_checkpoints)
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints,
                ys_intersect_checkpoints)
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print("Warning, some output nodes are also checkpoints nodes: "
                    "%s", format_ops(ys_intersect_checkpoints))

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    # if not checkpoints:
    #     raise Exception('no checkpoints nodes found or given as input!')

    # disconnect dependencies between checkpointed tensors
    checkpoints_disconnected = {}
    for x in checkpoints:
        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
        else:
            grad_node = tf.stop_gradient(x)
        checkpoints_disconnected[x] = grad_node

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints,
                                    within_ops=fwd_ops)
    debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s",
                len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints)
    debug_print("ops_to_copy = %s", ops_to_copy)
    debug_print("Processing list %s", ys)
    copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy),
                                                       {})
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    debug_print("Copied %s to %s", ops_to_copy, copied_ops)
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print("Rewired %s in place of %s restricted to %s",
                checkpoints_disconnected.values(),
                checkpoints_disconnected.keys(), copied_ops)

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys, xs=boundary + xs, grad_ys=grad_ys,
                      **kwargs)
    debug_print("Got gradients %s", dv)
    debug_print("for %s", copied_ys)
    debug_print("with respect to %s", boundary + xs)

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {r: dr for r, dr
                     in zip(checkpoints_disconnected.keys(),
                            dv[:len(checkpoints_disconnected)])}
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list %s", ts)
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [checkpoints_disconnected[r]
                                          for r in checkpoints_other]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoints nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print("Found %s ops to copy within %s, seed %s, stop_at %s",
                    len(ops_to_copy), fwd_ops, [r.op for r in ts],
                    checkpoints_other)
        debug_print("ops_to_copy = %s", ops_to_copy)
        if not ops_to_copy:  # we're done!
            break
        copied_sgv, info = ge.copy_with_input_replacements(
            ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        debug_print("Copied %s to %s", ops_to_copy, copied_ops)
        ge.reroute_ts(checkpoints_disconnected_other, checkpoints_other,
                      can_modify=copied_ops)
        debug_print("Rewired %s in place of %s restricted to %s",
                    checkpoints_disconnected_other, checkpoints_other,
                    copied_ops)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops, **kwargs)
        debug_print("Got gradients %s", dv)
        debug_print("for %s", boundary)
        debug_print("with respect to %s", checkpoints_disconnected_other + xs)
        debug_print("with boundary backprop substitutions %s",
                    substitute_backprops)

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv
                                             if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        def _unsparsify(x):
            if not isinstance(x, tf.IndexedSlices):
                return x
            assert x.dense_shape is not None, \
                "memory_saving_gradients encountered sparse gradients of " \
                "unknown shape"
            indices = x.indices
            while indices.shape.ndims < x.values.shape.ndims:
                indices = tf.expand_dims(indices, -1)
            return tf.scatter_nd(indices, x.values, x.dense_shape)

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])

    return d_xs
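# Minimal usage sketch, not part of the original source. It assumes the
# gradients() function above is in scope, that TF v1-style graph mode is
# active, and that a scalar `loss` tensor has already been built from the
# trainable variables. With checkpoints='collection', selected activations
# would first need to be registered via
# tf.compat.v1.add_to_collection('checkpoints', tensor); here the automatic
# 'memory' strategy is used instead. The helper name is hypothetical.
def _example_memory_saving_train_op(loss):
    params = tf.compat.v1.trainable_variables()
    # Drop-in replacement for tf.gradients that recomputes non-checkpointed
    # activations during the backward pass.
    grads = gradients(loss, params, checkpoints='memory')
    optimizer = tf.compat.v1.train.AdamOptimizer(1e-3)
    return optimizer.apply_gradients(list(zip(grads, params)))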
def test_compatibility(self):
  with self.assertRaises(ValueError):
    gde.reroute_ts([self.a0, self.b0], [self.a2, self.b2])
def add_postprocessing(g, postproc_g):
  # type: (gde.Graph, gde.Graph) -> None
  """
  Add postprocessing ops to a graph.

  The postprocessing ops can replace one or more output operations of the
  original graph with a series of operations that apply additional
  transformations to the output and return the result of the
  transformations. After performing this rewrite, the outputs of the
  resulting graph may have a different shape and dtype than before, but they
  will have the same names.

  Args:
    g: `gde.Graph` to which postprocessing should be added.
      *Modified in place.*
    postproc_g: `gde.Graph` containing the postprocessing ops to add. For
      each op in `g` that needs postprocessing, `postproc_g` should contain
      a placeholder with the same name and a second op named
      "<name of output>_postprocessed", where `<name of output>` is the name
      of the original op.
  """
  placeholders = gde.filter_ops_by_optype(postproc_g, "Placeholder")

  def postproc_name(placeholder_name):
    return placeholder_name + "_postprocessed"

  def orig_name(placeholder_name):
    return "__original__" + placeholder_name

  # Validate before modifying the graph
  for p in placeholders:
    if not g.contains_node(p.name):
      raise ValueError("Postprocessing graph contains a Placeholder called "
                       "'{}', but target graph does not have an op by that "
                       "name.".format(p.name))
    if 1 != len(g.get_node_by_name(p.name).outputs):
      raise ValueError("Output node '{}' of target graph has {} output "
                       "tensors. Only one output is supported."
                       "".format(p.name,
                                 len(g.get_node_by_name(p.name).outputs)))
    if not postproc_g.contains_node(postproc_name(p.name)):
      raise ValueError("Postprocessing graph contains a Placeholder called "
                       "'{}', but it does not have a node called '{}' "
                       "to produce the postprocessed version of that output."
                       "".format(p.name, postproc_name(p.name)))

  # Rename all the original output ops so we can bulk-copy the postprocessing
  # graph.
  for p in placeholders:
    g.rename_node(p.name, orig_name(p.name))

  # Now it should be safe to copy the postprocessing graph into the original
  # graph.
  gde.copy(postproc_g, g)

  for p in placeholders:
    postproc_input_p = g.get_node_by_name(p.name)
    orig_output_node = g.get_node_by_name(orig_name(p.name))

    # Reroute all connections from the postprocessing graph's placeholder to
    # go to the corresponding output of the original graph.
    gde.reroute_ts(orig_output_node.output(0), postproc_input_p.output(0))

    # Get rid of the placeholder
    g.remove_node_by_name(postproc_input_p.name)

    # Rename the postprocessed output to the name of the original output
    g.rename_node(postproc_name(p.name), p.name)
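# Hypothetical usage sketch for add_postprocessing, not part of the original
# source. It assumes a target graph whose single-output op "scores" produces
# logits and a postprocessing graph that converts those logits into
# probabilities; the names "features", "scores", and "scores_postprocessed"
# are illustrative only.
def _example_add_postprocessing():
  import tensorflow as tf
  import graph_def_editor as gde

  # Toy target graph: a placeholder feeding a matmul whose op is "scores".
  tf_main = tf.Graph()
  with tf_main.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 4], name="features")
    w = tf.constant(1.0, shape=[4, 3])
    tf.matmul(x, w, name="scores")

  # Postprocessing graph: placeholder "scores" -> "scores_postprocessed".
  tf_post = tf.Graph()
  with tf_post.as_default():
    s = tf.compat.v1.placeholder(tf.float32, shape=[None, 3], name="scores")
    tf.nn.softmax(s, name="scores_postprocessed")

  g = gde.Graph(tf_main.as_graph_def())
  postproc_g = gde.Graph(tf_post.as_graph_def())
  add_postprocessing(g, postproc_g)
  # The rewritten graph still exposes an output named "scores", but it now
  # returns softmax probabilities instead of raw logits.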