def _reroute_sgv_inputs(sgv0, sgv1, mode):
    """Re-route all the inputs of two subgraphs.

    Args:
        sgv0: the first subgraph to have its inputs swapped. This argument is
            converted to a subgraph using the same rules as the function
            subgraph.make_view.
        sgv1: the second subgraph to have its inputs swapped. This argument is
            converted to a subgraph using the same rules as the function
            subgraph.make_view.
        mode: reroute mode, see _reroute_ts(...).
    Returns:
        A tuple `(sgv0, sgv1)` of subgraph views with their inputs swapped.
        Note that the function arguments sgv0 and sgv1 are also modified in
        place.
    Raises:
        StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView
            using the same rules as the function subgraph.make_view.
    """
    sgv0 = _subgraph.make_view(sgv0)
    sgv1 = _subgraph.make_view(sgv1)
    _util.check_graphs(sgv0, sgv1)
    can_modify = sgv0.ops + sgv1.ops
    # also allow consumers of passthrough tensors to be modified:
    can_modify += _util.get_consuming_ops(sgv0.passthroughs)
    can_modify += _util.get_consuming_ops(sgv1.passthroughs)
    _reroute_ts(sgv0.inputs, sgv1.inputs, mode, can_modify=can_modify)
    _reroute_sgv_remap(sgv0, sgv1, mode)
    return sgv0, sgv1
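# Hedged usage sketch (not part of the original module): `_reroute_sgv_inputs`
# is the routine behind public graph-editor helpers such as `swap_inputs`.
# The example assumes TensorFlow 1.x with `tensorflow.contrib.graph_editor`;
# after the swap, `c` is computed from (x2, y2) and `d` from (x1, y1).
def _example_swap_inputs():
    import tensorflow as tf
    from tensorflow.contrib import graph_editor as ge

    g = tf.Graph()
    with g.as_default():
        x1, y1 = tf.constant(1.0, name="x1"), tf.constant(2.0, name="y1")
        x2, y2 = tf.constant(3.0, name="x2"), tf.constant(4.0, name="y2")
        c = tf.add(x1, y1, name="c")
        d = tf.add(x2, y2, name="d")
    # swap_inputs(...) eventually reaches _reroute_sgv_inputs with the swap mode
    ge.swap_inputs(ge.sgv(c.op), ge.sgv(d.op))
    return c, d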
def _do_action(self, src_ops):
    """Add swapin and swapout ops for ops that are reachable from `src_ops`.

    Args:
        src_ops: a list of `tf.Operation`.
    """
    open_set = Queue.Queue()
    closed_set = set()

    for op in src_ops:
        open_set.put(op)

    while not open_set.empty():
        src_op = open_set.get()

        # get the next ops before the graph is changed
        next_ops = set()
        for t in src_op.outputs:
            frontier_ops = set(util.get_consuming_ops(t))
            next_ops |= frontier_ops - self._grad_ops

        # do action for src_op
        self._insert_swap_nodes(src_op)
        if self._swapped_max_tensors():
            return

        for op in next_ops:
            if op in closed_set:
                continue
            if op not in open_set.queue:
                open_set.put(op)
        closed_set.add(src_op)
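# Hedged sketch (a hypothetical helper, not part of the class): the traversal
# above follows a recurring breadth-first pattern over op outputs that also
# appears in several methods below.  Factoring it out would look roughly like
# this, with `visit` performing the per-op action and `expand` returning the
# ops to enqueue next.
def _bfs_ops(seed_ops, visit, expand):
    import Queue  # Python 2 module, matching the surrounding code
    open_set = Queue.Queue()
    closed_set = set()
    for op in seed_ops:
        open_set.put(op)
    while not open_set.empty():
        src_op = open_set.get()
        visit(src_op)
        for op in expand(src_op):
            if op in closed_set:
                continue
            if op not in open_set.queue:
                open_set.put(op)
        closed_set.add(src_op)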
def _get_seed_ops(self):
    """Return a list of `tf.Operation` used as a starting point for LMS
    to traverse the graph.

    If a starting scope or starting operation names are given, the matching
    ops are used. Otherwise, this method automatically searches for starting
    ops.
    """
    # seed ops for search
    seed_ops = set()
    ops = ge.make_list_of_op(self._graph)
    if self._starting_scope:
        scope_ops = set(ge.filter_ops_from_regex(
            ops, "^{}".format(self._starting_scope)))
        if not scope_ops:
            raise ValueError('No operations were found in starting '
                             'scope {}.'.format(self._starting_scope))
        seed_ops |= scope_ops

    if self._starting_op_names:
        for name in self._starting_op_names:
            name_ops = set(ge.filter_ops_from_regex(
                ops, "^{}$".format(name)))
            if not name_ops:
                raise ValueError('No starting operation was found with '
                                 'name {}.'.format(name))
            seed_ops |= name_ops

    seed_ops = list(seed_ops)
    if not seed_ops:
        candidates = set()
        non_grad_ops = [
            op for op in self._graph.get_operations()
            if op not in self._grad_ops
        ]
        for op in non_grad_ops:
            for t in op.outputs:
                frontier_ops = set(util.get_consuming_ops(t))
                if frontier_ops & self._grad_ops:
                    candidates.add(op)
                    break

        # rank each candidate by how many of the other candidates it covers
        # when walking forward from it
        tmp_dict = {}
        max_nelems = -1
        for op in candidates:
            nelems = len(
                set(ge.get_forward_walk_ops(
                    op, within_ops=non_grad_ops, inclusive=False))
                & candidates)
            if nelems > 0:
                tmp_dict[op] = nelems
                max_nelems = max(max_nelems, nelems)

        # seed ops are the candidates that cover the most forward ops
        seed_ops = [k for k, v in tmp_dict.items() if v == max_nelems]

    return seed_ops
def get_forward_walk_ops(seed_ops,
                         inclusive=True,
                         within_ops=None,
                         stop_at_ts=(),
                         control_outputs=None):
    """Do a forward graph walk and return all the visited ops.

    Args:
        seed_ops: an iterable of operations from which the forward graph walk
            starts. If a list of tensors is given instead, the seed_ops are
            set to be the consumers of those tensors.
        inclusive: if True the given seed_ops are also part of the resulting
            set.
        within_ops: an iterable of tf.Operation within which the search is
            restricted. If within_ops is None, the search is performed within
            the whole graph.
        stop_at_ts: an iterable of tensors at which the graph walk stops.
        control_outputs: a util.ControlOutputs instance or None. If not None,
            it will be used while walking the graph forward.
    Returns:
        A Python list of all the tf.Operation ahead of seed_ops.
    Raises:
        TypeError: if seed_ops or within_ops cannot be converted to a list of
            tf.Operation.
    """
    _, control_outputs = check_cios(False, control_outputs)
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    if not seed_ops:
        return []
    if isinstance(seed_ops[0], tf_ops.Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_consuming_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    seed_ops = frozenset(seed_ops)
    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    if within_ops:
        within_ops = util.make_list_of_op(within_ops, allow_graph=False)
        within_ops = frozenset(within_ops)
        seed_ops &= within_ops

    def is_within(op):
        return within_ops is None or op in within_ops

    result = list(seed_ops)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.outputs:
                if new_t in stop_at_ts:
                    continue
                for new_op in new_t.consumers():
                    if new_op not in result and is_within(new_op):
                        new_wave.add(new_op)
            if control_outputs is not None:
                for new_op in control_outputs.get(op):
                    if new_op not in result and is_within(new_op):
                        new_wave.add(new_op)
        util.concatenate_unique(result, new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result
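# Hedged usage sketch (illustration only, assuming a TensorFlow 1.x graph):
# walking forward from the op producing `a` visits the consumers of `a` and
# everything downstream of them; `stop_at_ts` cuts the walk at a tensor.
def _example_forward_walk():
    import tensorflow as tf

    g = tf.Graph()
    with g.as_default():
        a = tf.constant(1.0, name="a")
        b = tf.add(a, 2.0, name="b")
        c = tf.multiply(b, 3.0, name="c")

    # includes a.op, b.op and c.op; pass inclusive=False to drop the seed op
    ops = get_forward_walk_ops([a.op])
    # the walk stops at tensor b, so c.op is not visited
    ops_before_c = get_forward_walk_ops([a.op], stop_at_ts=[b])
    return ops, ops_before_c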
def _find_new_src_op(self, original_op):
    """Find a set of new operations whose output tensors should be swapped out.

    This method is used when `original_op` produces a tensor that is
    consumed by a backward op whose order is negative. In this case, the
    tensor might be consumed immediately by the backward op, depending on
    the TensorFlow runtime, so there is no need to swap out the tensor.

    This method starts from `original_op` and returns operations whose
    output tensors are consumed by backward operations with a positive
    order.

    Args:
        original_op: a `tf.Operation`.

    Returns:
        A set of `tf.Operation`.
    """
    src_ops = set()
    open_set = Queue.Queue()
    closed_set = set()

    open_set.put(original_op)
    while not open_set.empty():
        src_op = open_set.get()

        # do action for src_op
        frontier_ops = set()
        for t in src_op.outputs:
            frontier_ops |= set(util.get_consuming_ops(t))
        has_order_ops = {
            op for op in frontier_ops
            if (self._topo_sort.get_order(op) >
                self._topo_sort.bw_starting_order)
        }
        if has_order_ops:
            src_ops.add(src_op)

        next_ops = frontier_ops - has_order_ops
        for op in next_ops:
            if op in closed_set:
                continue
            if op not in open_set.queue:
                open_set.put(op)
        closed_set.add(src_op)

    return src_ops
def _build_dependency_dict(self):
    """Build a dictionary of dependencies among nodes.

    Returns:
        A dictionary mapping each reachable `tf.Operation` to the set of
        `tf.Operation` it depends on (data inputs and control inputs).
    """
    open_set = Queue.Queue()
    closed_set = set()

    dep_dict = {}
    for op in self._seed_ops:
        open_set.put(op)

    reachable_ops = set(
        ge.get_walks_intersection_ops(list(self._seed_ops),
                                      list(self._grad_ops)))

    # traversal in the fw phase
    while not open_set.empty():
        src_op = open_set.get()

        # do action for src_op
        dep_ops = set(src_op.control_inputs)
        for t in src_op.inputs:
            dep_ops |= set(util.get_generating_ops(t))
        dep_ops &= reachable_ops
        dep_dict[src_op] = dep_ops

        next_ops = set()
        for t in src_op.outputs:
            next_ops |= set(util.get_consuming_ops(t))
        for op in next_ops:
            if op in closed_set:
                continue
            if op not in open_set.queue:
                open_set.put(op)
        closed_set.add(src_op)

    return dep_dict
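# Hedged sketch (a hypothetical helper, not part of the class): one way the
# dependency dictionary can be consumed is a Kahn-style levelling that assigns
# each op the earliest "order" at which all of its dependencies are satisfied.
# `dep_dict` is the mapping produced by `_build_dependency_dict`.
def _levels_from_dependencies(dep_dict):
    order = {}
    remaining = dict(dep_dict)
    level = 0
    while remaining:
        ready = [op for op, deps in remaining.items()
                 if all(d in order for d in deps)]
        if not ready:
            # cycle or dangling dependency; stop instead of looping forever
            break
        for op in ready:
            order[op] = level
            del remaining[op]
        level += 1
    return order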
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
    print("-------------------------------")
    debug_print("Editing model for OME")
    incpu_count = 0
    # print("Calling memsaving gradients with", checkpoints)
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True)

    for index, op in enumerate(bwd_ops):
        debug_print("bwd_ops: [{}] : {}".format(index, op.name), 1)

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    for index, op in enumerate(fwd_ops):
        debug_print("fwd_ops: [{}] : {}".format(index, op.name), 1)

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if op not in xs_ops]
    fwd_ops = [op for op in fwd_ops if '/assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/Assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/read' not in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    # construct the list of tensors to checkpoint during the forward pass,
    # if not given as input
    if type(checkpoints) is not list:
        # remove very small tensors and some weird ops
        def fixdims(t):
            # tf.Dimension values are not compatible with int, convert manually
            try:
                return [int(e if e.value is not None else 64) for e in t]
            except:
                return [0]  # unknown shape

        ts_all = [
            t for t in ts_all
            if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE
        ]
        ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
        ts_all = [t for t in ts_all if 'entropy' not in t.name]
        ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
        ts_all = [t for t in ts_all if 'Switch' not in t.name]
        ts_all = [t for t in ts_all if 'dropout' not in t.name]
        # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
        ts_all = [t for t in ts_all if 'Cast' not in t.name]

        # filter out all tensors that are inputs of the backward graph
        with util.capture_ops() as bwd_ops:
            tf_gradients(ys, xs, grad_ys, **kwargs)

        bwd_inputs = [t for op in bwd_ops for t in op.inputs]
        # list of tensors in the forward graph that are inputs to the bwd graph
        ts_filtered = list(set(bwd_inputs).intersection(ts_all))
        debug_print("Using tensors {}".format(ts_filtered), 1)

        # try two slightly different ways of getting bottleneck tensors
        # to checkpoint
        for ts in [ts_filtered, ts_all]:
            # get all bottlenecks in the graph
            bottleneck_ts = []
            for t in ts:
                b = set(ge.get_backward_walk_ops(t.op,
                                                 inclusive=True,
                                                 within_ops=fwd_ops))
                f = set(ge.get_forward_walk_ops(t.op,
                                                inclusive=False,
                                                within_ops=fwd_ops))
                # check that there are no shortcuts
                b_inp = set(inp for op in b
                            for inp in op.inputs).intersection(ts_all)
                f_inp = set(inp for op in f
                            for inp in op.inputs).intersection(ts_all)
                if not b_inp.intersection(f_inp) and \
                        len(b_inp) + len(f_inp) >= len(ts_all):
                    bottleneck_ts.append(t)  # we have a bottleneck!
                else:
                    debug_print(
                        "Rejected bottleneck candidate and ops {}".format(
                            [t] + list(set(ts_all) - set(b_inp) - set(f_inp))),
                        2)

            # success? or try again without filtering?
            if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)):
                # yes, enough bottlenecks found!
                break

        if not bottleneck_ts:
            raise Exception(
                'unable to find bottleneck tensors! please provide checkpoint '
                'nodes manually, or use checkpoints="speed".')

        # sort the bottlenecks
        bottlenecks_sorted_lists = tf_toposort(bottleneck_ts,
                                               within_ops=fwd_ops)
        sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts]

        # save an approximately optimal number ~ sqrt(N)
        N = len(ts_filtered)
        if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
            checkpoints = sorted_bottlenecks
        else:
            step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
            checkpoints = sorted_bottlenecks[step::step]

    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection happened and checkpoints is a list of nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: {}".format(checkpoints), 1)
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print(
            "Warning, some input nodes are also checkpoint nodes: {}".format(
                xs_intersect_checkpoints), 2)
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print(
        "ys: {}, checkpoints: {}, intersect: {}".format(
            ys, checkpoints, ys_intersect_checkpoints), 1)
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print(
            "Warning, some output nodes are also checkpoint nodes: {}".format(
                format_ops(ys_intersect_checkpoints)), 2)

    # remove initial and terminal nodes from the checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    if not checkpoints:
        raise Exception('no checkpoint nodes found or given as input!')

    debug_print(
        "Selected {} nodes as checkpoint nodes.".format(len(checkpoints)), 0)

    # disconnect dependencies between checkpointed tensors
    checkpoints_disconnected = {}
    for x in checkpoints:
        frontier_ops = set(graph_util.get_consuming_ops(x.op.outputs))
        debug_print("my frontier ops: {}".format(frontier_ops), 1)

        bw_frontier_ops = frontier_ops & set(bwd_ops)
        debug_print("my bw frontier ops: {}".format(bw_frontier_ops), 1)

        if len(bw_frontier_ops) > 1:
            continue

        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
        else:
            grad_node = tf.stop_gradient(x)

        swapout_op = _add_swapout(grad_node.op, grad_node.op.outputs)
        incpu_count = incpu_count + 1
        swapin_op = _add_swapin(swapout_op, bw_frontier_ops,
                                grad_node.op.outputs)
        checkpoints_disconnected[x] = swapin_op
        # control dependency -> swap_in
        my_add_control_inputs(x, bw_frontier_ops, swapin_op)
        # self._add_control_dependency(src_op, dest_op, swapin_op)

    # g = tf.get_default_graph()
    # print(g.get_operations())

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints,
                                    within_ops=fwd_ops)
    debug_print(
        "Found {} ops to copy within fwd_ops {}, seed {}, stop_at {}".format(
            len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints), 1)
    debug_print("ops_to_copy = {}".format(ops_to_copy), 1)
    debug_print("Processing list {}".format(ys), 1)
    copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    debug_print("Copied {} to {}".format(ops_to_copy, copied_ops), 1)
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print(
        "Rewired {} in place of {} restricted to {}".format(
            checkpoints_disconnected.values(),
            checkpoints_disconnected.keys(), copied_ops), 1)

    # get gradients with respect to the current boundary + the original xs
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys, xs=boundary + xs, grad_ys=grad_ys, **kwargs)
    debug_print("Got gradients {}".format(dv), 1)
    debug_print("for {}".format(copied_ys), 1)
    debug_print("with respect to {}".format(boundary + xs), 1)

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {
        r: dr
        for r, dr in zip(checkpoints_disconnected.keys(),
                         dv[:len(checkpoints_disconnected)])
    }
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list {}".format(ts), 1)
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [
            checkpoints_disconnected[r] for r in checkpoints_other
        ]

        # copy part of the graph below the current checkpoint node,
        # stopping at other checkpoint nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print(
            "Found {} ops to copy within {}, seed {}, stop_at {}".format(
                len(ops_to_copy), fwd_ops, [r.op for r in ts],
                checkpoints_other), 1)
        debug_print("ops_to_copy = {}".format(ops_to_copy), 1)
        if not ops_to_copy:  # we're done!
            break
        copied_sgv, info = ge.copy_with_input_replacements(
            ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        debug_print("Copied {} to {}".format(ops_to_copy, copied_ops), 1)
        ge.reroute_ts(checkpoints_disconnected_other,
                      checkpoints_other,
                      can_modify=copied_ops)
        debug_print(
            "Rewired {} in place of {} restricted to {}".format(
                checkpoints_disconnected_other, checkpoints_other,
                copied_ops), 1)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops,
                          **kwargs)
        debug_print("Got gradients {}".format(dv), 1)
        debug_print("for {}".format(boundary), 1)
        debug_print(
            "with respect to {}".format(checkpoints_disconnected_other + xs),
            1)
        debug_print(
            "with boundary backprop substitutions {}".format(
                substitute_backprops), 1)

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        def _unsparsify(x):
            if not isinstance(x, tf.IndexedSlices):
                return x
            assert x.dense_shape is not None, \
                "memory_saving_gradients encountered sparse gradients of unknown shape"
            indices = x.indices
            while indices.shape.ndims < x.values.shape.ndims:
                indices = tf.expand_dims(indices, -1)
            return tf.scatter_nd(indices, x.values, x.dense_shape)

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])

    return d_xs
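# Hedged usage sketch: `gradients` above is intended as a drop-in replacement
# for `tf.gradients`.  Passing a list of tensors as `checkpoints` keeps exactly
# those activations; passing a non-list value (as with the default
# 'collection') triggers the automatic bottleneck selection.  The names
# `loss`, `params` and `optimizer` are placeholders, not part of this module.
def _example_memory_saving_train_op(loss, params, optimizer):
    grads = gradients([loss], params, checkpoints='collection')
    return optimizer.apply_gradients(zip(grads, params))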
def _do_chain_rule(self, fw_op, bw_op, lower_b, upper_b):
    """Find a control dependency operation using chain rules.

    Go down along the forward phase to find corresponding backward ops
    as candidates for control dependency ops.

    Args:
        fw_op: a `tf.Operation` that has a tensor swapped out.
        bw_op: a `tf.Operation` that consumes a tensor swapped in.
        lower_b: an `integer`. The distance in the graph between `fw_op` and
            a forward operation that has corresponding backward ops as
            candidates for control dependency ops must be greater than
            `lower_b`.
        upper_b: an `integer`. The distance in the graph between `fw_op` and
            a forward operation that has corresponding backward ops as
            candidates for control dependency ops must be smaller than
            `upper_b`.

    Returns:
        A tuple of (`tf.Operation`, an `integer`). The first item is the
        control dependency operation that triggers swapping in the input
        tensor of `bw_op`. The second item is the order of the control
        dependency operation in the topological order.
    """
    fw_order = self._topo_sort.get_order(fw_op)
    bw_order = self._topo_sort.get_order(bw_op)

    # check if the bw op is near the boundary between the fw and bw phases
    if (bw_order - lower_b) < self._topo_sort.bw_starting_order:
        return self._do_direct_order(fw_op, bw_op, lower_b, upper_b)

    open_set1 = Queue.Queue()
    open_set2 = Queue.Queue()
    closed_set = set()

    open_set1.put(fw_op)

    result_ops = set()
    while not open_set1.empty():
        # stop if reaching the upper bound
        if upper_b == 0 or (lower_b > upper_b):
            break

        src_op = open_set1.get()

        # do action for src_op
        total_consuming_ops = set()
        for t in src_op.outputs:
            consuming_ops = set(util.get_consuming_ops(t))
            total_consuming_ops |= consuming_ops

        if lower_b <= 0:
            # inside the range: collect and validate backward consumers
            consuming_ops_bw = total_consuming_ops & self._grad_ops
            consuming_ops_bw = {
                op for op in consuming_ops_bw
                if self._topo_sort.get_order(op) > fw_order
            }
            consuming_ops_bw = {
                op for op in consuming_ops_bw
                if self._topo_sort.get_order(op) < bw_order
            }
            consuming_ops_bw = {
                op for op in consuming_ops_bw
                if "/cond/" not in op.name
            }
            result_ops |= consuming_ops_bw

        # go to the next level
        next_ops = total_consuming_ops - self._grad_ops
        for op in next_ops:
            if op in closed_set:
                continue
            if op not in open_set2.queue:
                open_set2.put(op)
        closed_set.add(src_op)

        if open_set1.empty():
            if result_ops:
                break
            lower_b = lower_b - 1
            upper_b = upper_b - 1
            while not open_set2.empty():
                open_set1.put(open_set2.get())

    if result_ops:
        ctrld_op = next(iter(result_ops))
        return (ctrld_op, self._topo_sort.get_order(ctrld_op))
    else:
        return (None, -1)
def _insert_swap_nodes(self, src_op):
    """Insert swapin and swapout ops for the given operation into the graph.

    This method does an in-place modification to the graph.

    Args:
        src_op: a `tf.Operation`.
    """
    self._log_info("Operation: {}".format(src_op), 2)

    # bypass excluded ops
    if src_op in self._excl_ops:
        return

    # if inclusive mode is enabled, only proceed if this op is included
    if self._incl_ops:
        if src_op not in self._incl_ops:
            return

    for t in src_op.outputs:
        if self._swapped_max_tensors():
            return

        frontier_ops = set(util.get_consuming_ops(t))
        self._log_info("my frontier ops: {}".format(frontier_ops), 2)

        bw_frontier_ops = frontier_ops & self._grad_ops
        self._log_info("my bw frontier ops: {}".format(bw_frontier_ops), 2)

        # swap branch ops if they are far enough away (depending on the threshold)
        if self._swap_branches:
            fw_branch_ops = self._get_branch_ops(
                frontier_ops - self._grad_ops, self._branch_threshold)
            bw_frontier_ops = bw_frontier_ops | fw_branch_ops

        # Do not swap tensors used by bw ops without outgoing ops.
        # These bw ops can be removed by the TensorFlow compiler.
        bw_frontier_ops = {
            op for op in bw_frontier_ops
            if set(self._get_forward_walk_ops(op, inclusive=False))
        }

        if not bw_frontier_ops:
            continue

        self._log_info("Operation: {}, order {}, type {}".format(
            src_op.name, self._topo_sort.get_order(src_op), src_op.type), 1)

        # create a swap_out node only if there exists a real dest. operation
        swapout_op = None
        for op in bw_frontier_ops:
            if self._topo_sort.get_order(op) >= 0:
                swapout_op = self._add_swapout(src_op, t)
                self._incpu_count = self._incpu_count + 1
                break

        # create swap_in nodes
        if self._fuse_swapins and swapout_op:
            bw_frontier_ops = self._fuse_swapin_ops(
                src_op, swapout_op, bw_frontier_ops, t)
        for dest_op in bw_frontier_ops:
            if self._topo_sort.get_order(dest_op) < 0:
                if src_op in self._grad_ops:
                    continue
                else:
                    new_src_ops = self._find_new_src_op(dest_op)
                    for op in new_src_ops:
                        self._insert_swap_nodes(op)
            else:
                # swap_in op
                swapin_op = self._add_swapin(swapout_op, dest_op, t)
                # control dependency -> swap_in
                self._add_control_dependency(src_op, dest_op, swapin_op)
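# Hedged illustration (not the actual `_add_swapout`/`_add_swapin` helpers,
# whose signatures and device placement belong to the surrounding class):
# conceptually, a swap-out is an identity of the tensor pinned to host memory
# and a swap-in is a second identity that the backward op reads instead of the
# original tensor; the device string here is an assumption for the sketch.
def _sketch_swap_pair(tensor):
    import tensorflow as tf
    with tf.device("/cpu:0"):
        swap_out = tf.identity(tensor, name=tensor.op.name + "_swap_out")
        swap_in = tf.identity(swap_out, name=tensor.op.name + "_swap_in")
    return swap_out, swap_in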