def replace_read_ops(loss_or_losses, var_list):
    """Swap each variable's read op for a fresh one from `read_value()`.

    For every variable in `var_list` that the loss depends on, the consumers
    of the variable's current read tensor (restricted to ops on a path to the
    loss) are reconnected to a new read obtained from `var.read_value()`.
    This forces reading the most up-to-date values of the variables (which
    might incur copies across devices).

    Args:
        loss_or_losses: tensor(s) from which the graph walk is seeded.
        var_list: variables whose read ops should be replaced.

    Returns:
        None. The graph is modified in place; nothing happens when the loss
        does not depend on any variable in `var_list`.
    """
    var_ops = [v.op for v in var_list]
    # Ops lying between the variable ops and the loss.
    ops_on_path = set(ge.get_walks_intersection_ops(var_ops, loss_or_losses))
    if not ops_on_path:
        # The loss doesn't depend on any variable, so there is nothing to
        # replace.
        return

    # Drop the variables that are not involved in computing the loss.
    used_vars = [v for v in var_list if v.op in ops_on_path]

    # Assumption: for each variable, the only op required to compute the loss
    # is a read op, and there is exactly one per variable. The single-element
    # unpacking below raises if that does not hold. All read ops are gathered
    # before any rewiring takes place.
    current_reads = []
    for v in used_vars:
        (var_output,) = v.op.outputs
        (read_op,) = set(var_output.consumers()) & ops_on_path
        current_reads.append(read_op)

    for v, old_read in zip(used_vars, current_reads):
        # Create the new read under the old read's parent name scope and on
        # the old read's device.
        parent_scope = '/'.join(old_read.name.split('/')[:-1])
        with tf.name_scope(parent_scope), tf.device(old_read.device):
            (old_read_t,) = old_read.outputs
            downstream_ops = set(old_read_t.consumers()) & ops_on_path
            # The consumer subgraph might have multiple inputs, but only the
            # input that is the old read tensor must be replaced.
            targets = ge.sgv(downstream_ops)
            read_slot = list(targets.inputs).index(old_read_t)
            targets = targets.remap_inputs([read_slot])
            fresh_read = ge.sgv(v.read_value().op)
            ge.connect(fresh_read, targets)
def _reroute_network(outcoming_dict, endpoints, dup_info):
    """Re-route paths from a layer's outputs to the endpoints via its duplicate.

    Called after _duplicate_layer. Every outcoming node that lies on a walk
    towards the network's endpoints gets its inputs rewired: inputs produced
    by an op that has a duplicate are replaced by the duplicate's outputs,
    all other inputs are kept as they are.

    Args:
        outcoming_dict: a dict {op: [outputs]} of original layers' outcoming
            nodes; only ops from the layer and outputs outside the layer are
            considered.
        endpoints: network's endpoints (outputs to task-specific heads).
        dup_info: the `info` ret val of _duplicate_layer.
    """
    # Materialized as a set so the membership tests below are O(1) rather
    # than a scan of the returned collection for every candidate node.
    branch_ops = set(ge.get_walks_intersection_ops(
        forward_seed_ops=list(outcoming_dict),
        backward_seed_ops=endpoints,
        forward_inclusive=False,
        backward_inclusive=True))

    # Only nodes that actually lead to an endpoint need to be rewired.
    outputs_to_swap = []
    for outputs in outcoming_dict.values():
        outputs_to_swap.extend(o for o in outputs if o in branch_ops)

    for node in outputs_to_swap:
        new_inputs = []
        for ts in list(node.inputs):
            new_op = dup_info.transformed(ts.op)
            if new_op is None:
                # The producer was not duplicated: keep the original tensor.
                new_inputs.append(ts)
            else:
                # NOTE(review): every output of the duplicate op is added,
                # not only the one matching `ts` -- confirm this is intended
                # for producers with more than one output.
                new_inputs.extend(new_op.outputs)
        ge.reroute_inputs(new_inputs, node)
def test_get_walks_intersection(self):
    """Test for ge.get_walks_intersection_ops."""
    ops = ge.get_walks_intersection_ops([self.c.op], [self.g.op])
    self.assertEqual(len(ops), 2)

    ops = ge.get_walks_intersection_ops([self.a.op], [self.f.op])
    self.assertEqual(len(ops), 3)
    # assertIn reports the missing op and the container on failure.
    self.assertIn(self.a.op, ops)
    self.assertIn(self.c.op, ops)
    self.assertIn(self.f.op, ops)

    # Restricting the walk to `a` and `f` is expected to yield no ops
    # (presumably because `c`, found on the walk above, is excluded).
    within_ops = [self.a.op, self.f.op]
    ops = ge.get_walks_intersection_ops(
        [self.a.op], [self.f.op], within_ops=within_ops)
    self.assertEqual(len(ops), 0)

    # Same restriction, expressed as a predicate instead of a list.
    def within_ops_fn(op):
        return op in [self.a.op, self.f.op]

    ops = ge.get_walks_intersection_ops(
        [self.a.op], [self.f.op], within_ops_fn=within_ops_fn)
    self.assertEqual(len(ops), 0)
def _build_dependency_dict(self):
    """Build a dictionary of dependencies among nodes.

    Walks the graph forward from `self._seed_ops`, following the consumers
    of each visited op's outputs. For every visited op the dictionary maps
    the op to the set of ops it depends on: its control inputs plus the ops
    generating its input tensors, restricted to the ops found by
    `ge.get_walks_intersection_ops` between the seed ops and
    `self._grad_ops`.

    Returns:
        A dict {op: set of ops it depends on}.
    """
    open_set = Queue.Queue()
    # Mirrors the ops currently waiting in `open_set`, so membership can be
    # checked in O(1) without scanning the queue's internal container.
    queued = set()
    closed_set = set()

    dep_dict = {}
    for op in self._seed_ops:
        open_set.put(op)
        queued.add(op)

    reachable_ops = set(
        ge.get_walks_intersection_ops(
            list(self._seed_ops), list(self._grad_ops)))

    # traversal in the fw phase
    while not open_set.empty():
        src_op = open_set.get()
        queued.discard(src_op)

        # Collect the dependencies of src_op.
        dep_ops = set(src_op.control_inputs)
        for t in src_op.inputs:
            dep_ops |= set(util.get_generating_ops(t))
            # NOTE(review): the filter runs once per input, so an op with no
            # data inputs keeps its control inputs unfiltered -- confirm
            # this is intended.
            dep_ops &= reachable_ops
        dep_dict[src_op] = dep_ops

        # Schedule consumers that are neither finished nor already waiting.
        next_ops = set()
        for t in src_op.outputs:
            next_ops |= set(util.get_consuming_ops(t))
        for op in next_ops:
            if op in closed_set or op in queued:
                continue
            open_set.put(op)
            queued.add(op)

        closed_set.add(src_op)

    return dep_dict
def test_get_walks_intersection(self):
    """Check ge.get_walks_intersection_ops on the walk from `c` to `g`."""
    forward_seeds = [self.c.op]
    backward_seeds = [self.g.op]
    shared_ops = ge.get_walks_intersection_ops(forward_seeds, backward_seeds)
    # Two ops are expected on the walks between the seeds.
    self.assertEqual(len(shared_ops), 2)