def replace_read_ops(loss_or_losses, var_list):
    """
    Swap the read op of every variable in `var_list` for a fresh read op
    obtained from `read_value()`, thus forcing consumers to read the most
    up-to-date values of the variables (which might incur copies across
    devices).

    Only ops lying on a walk between the variables' ops and the tensor(s)
    `loss_or_losses` are considered; variables whose op is not on such a walk
    are left alone. The graph is edited in place and nothing is returned.
    """
    var_ops = [v.op for v in var_list]
    # ops between var ops and the loss
    between_ops = set(ge.get_walks_intersection_ops(var_ops, loss_or_losses))
    if not between_ops:
        # loss_or_losses doesn't depend on any var in var_list, so there is
        # nothing to replace
        return

    # keep only the variables that are involved in computing the loss
    used_vars = [v for v in var_list if v.op in between_ops]

    def _single_read_op(v):
        # Assumes the variable op has exactly one output and that exactly one
        # of its consumers (the read op) is needed to compute the loss; the
        # tuple-unpacking raises ValueError otherwise.
        var_out, = v.op.outputs
        found, = set(var_out.consumers()) & between_ops
        return found

    # Collect every existing read op first, before the graph is modified.
    pairs = [(v, _single_read_op(v)) for v in used_vars]

    for v, old_read in pairs:
        # Create the replacement read op under the same name scope and on the
        # same device as the read op it replaces.
        parent_scope = old_read.name.rpartition('/')[0]
        with tf.name_scope(parent_scope), tf.device(old_read.device):
            old_read_t, = old_read.outputs
            targets = ge.sgv(set(old_read_t.consumers()) & between_ops)
            # the consumers might have multiple inputs, but only the input
            # that is the old read tensor has to be replaced
            slot = list(targets.inputs).index(old_read_t)
            targets = targets.remap_inputs([slot])
            ge.connect(ge.sgv(v.read_value().op), targets)
Example #2
0
def _reroute_network(outcoming_dict, endpoints, dup_info):
    """
    Called after _duplicate_layer. Re-route the paths from layers' outputs
    to the network's endpoints to the duplicate layer.

    For every outcoming node that lies on a walk towards `endpoints`, each
    input whose producing op has a duplicate (according to
    `dup_info.transformed`) is replaced by the duplicate's outputs; inputs
    without a duplicate are kept as they are.

    Args:
      outcoming_dict: a dict {op: [outputs]} of original layers' outcoming
                      nodes; only ops from the layer and outputs outside
                      the layer are considered.
      endpoints:      network's endpoints (outputs to task-specific heads)
      dup_info:       the `info` ret val of _duplicate_layer.
    """
    # Materialise as a set: it is only used for membership tests below, and a
    # list would make every test a linear scan.
    branch_ops = set(ge.get_walks_intersection_ops(
        forward_seed_ops=list(outcoming_dict),
        backward_seed_ops=endpoints,
        forward_inclusive=False,
        backward_inclusive=True))

    # Outcoming nodes that actually lead to one of the endpoints.
    outputs_to_swap = []
    for outputs in outcoming_dict.values():
        outputs_to_swap.extend(o for o in outputs if o in branch_ops)

    for node in outputs_to_swap:
        # Snapshot the inputs before building their replacements.
        orig_inputs = list(node.inputs)
        new_inputs = []
        for ts in orig_inputs:
            new_op = dup_info.transformed(ts.op)
            if new_op is None:
                # The producer was not duplicated: keep the original tensor.
                new_inputs.append(ts)
            else:
                # NOTE(review): this adds *all* outputs of the duplicate op;
                # if the producer has several outputs, new_inputs will be
                # longer than orig_inputs -- confirm this is intended.
                new_inputs.extend(new_op.outputs)
        ge.reroute_inputs(new_inputs, node)
Example #3
0
  def test_get_walks_intersection(self):
    """Test for ge.get_walks_intersection_ops."""
    ops = ge.get_walks_intersection_ops([self.c.op], [self.g.op])
    self.assertEqual(len(ops), 2)

    ops = ge.get_walks_intersection_ops([self.a.op], [self.f.op])
    self.assertEqual(len(ops), 3)
    # assertIn reports the missing op and the container on failure.
    self.assertIn(self.a.op, ops)
    self.assertIn(self.c.op, ops)
    self.assertIn(self.f.op, ops)

    # Restricting the walk to a and f only is expected to yield no op at all
    # (presumably because c, found on the unrestricted walk, is excluded).
    within_ops = [self.a.op, self.f.op]
    ops = ge.get_walks_intersection_ops(
        [self.a.op], [self.f.op], within_ops=within_ops)
    self.assertEqual(len(ops), 0)

    # Same restriction, expressed as a predicate instead of a list.
    def within_ops_fn(op):
      return op in within_ops

    ops = ge.get_walks_intersection_ops(
        [self.a.op], [self.f.op], within_ops_fn=within_ops_fn)
    self.assertEqual(len(ops), 0)
Example #4
0
    def _build_dependency_dict(self):
        """Build a dictionary of dependencies among nodes.

        Starting from `self._seed_ops`, ops are visited in FIFO order by
        following the consumers of each visited op's outputs. For every
        visited op, its dependencies are its control inputs plus the
        generating ops of its input tensors, restricted to the ops found by
        `ge.get_walks_intersection_ops` between `self._seed_ops` and
        `self._grad_ops`.

        Returns:
            A dict mapping each visited op to the set of ops it depends on.
        """
        from collections import deque

        reachable_ops = set(
            ge.get_walks_intersection_ops(list(self._seed_ops),
                                          list(self._grad_ops)))

        open_queue = deque(self._seed_ops)
        # Mirror of the queue's content, for O(1) "already queued" checks.
        queued = set(open_queue)
        closed_set = set()
        dep_dict = {}

        # traversal in the fw phase
        while open_queue:
            src_op = open_queue.popleft()
            queued.discard(src_op)

            # do action for src_op
            dep_ops = set(src_op.control_inputs)
            for t in src_op.inputs:
                dep_ops.update(util.get_generating_ops(t))
            # Filter once, after all dependencies are gathered, so that
            # control inputs are restricted to reachable ops even when the
            # op has no data inputs.
            dep_ops &= reachable_ops
            dep_dict[src_op] = dep_ops

            next_ops = set()
            for t in src_op.outputs:
                next_ops.update(util.get_consuming_ops(t))
            for op in next_ops:
                # Skip ops that were already processed or are waiting.
                if op in closed_set or op in queued:
                    continue
                open_queue.append(op)
                queued.add(op)

            closed_set.add(src_op)

        return dep_dict
Example #5
0
 def test_get_walks_intersection(self):
     """Check how many ops ge.get_walks_intersection_ops finds from c to g."""
     forward_seeds = [self.c.op]
     backward_seeds = [self.g.op]
     walk_ops = ge.get_walks_intersection_ops(forward_seeds, backward_seeds)
     self.assertEqual(len(walk_ops), 2)
Example #6
0
 def test_get_walks_intersection(self):
   """Check how many ops ge.get_walks_intersection_ops finds from c to g."""
   start_ops = [self.c.op]
   end_ops = [self.g.op]
   found = ge.get_walks_intersection_ops(start_ops, end_ops)
   self.assertEqual(len(found), 2)