def do_traversal(self, root):
    """
    Order the ops reachable from ``root`` through ``control_deps`` so that
    each op comes after all of its (forwarded) control dependencies.

    Note: This is almost identical to Op's visit_input_closure.

    ``update_forwards()`` is called on every op popped during discovery.
    An op whose pending-dependency count never reaches zero (for example,
    one on a dependency cycle) is left out of the result.

    Arguments:
        root: The op the traversal starts from.

    Returns:
        list: The reachable ops in dependency-first order.
    """
    pending = OrderedSet([root])
    remaining = dict()  # op -> number of its deps not yet emitted
    dependents = collections.defaultdict(list)  # dep -> ops waiting on it
    emittable = OrderedSet()  # ops with no un-emitted deps

    # Phase 1: discover the dependency graph below root.
    while pending:
        op = pending.pop()
        op.update_forwards()
        if op in remaining:
            # Already expanded; nothing more to record.
            continue
        deps = [dep.forwarded for dep in op.control_deps]
        if not deps:
            emittable.add(op)
            continue
        remaining[op] = len(deps)
        for dep in deps:
            dependents[dep].append(op)
        pending.update(deps)

    # Phase 2: emit ops once every dependency has been emitted.
    ordered = []
    while emittable:
        op = emittable.pop()
        ordered.append(op)
        for waiter in dependents.get(op, ()):
            remaining[waiter] -= 1
            if remaining[waiter] == 0:
                del remaining[waiter]
                emittable.add(waiter)
    return ordered
def find_recvs(fro):
    """
    Find all the Receivers that ``fro`` depends on.

    The walk follows each op's ``args``. On reaching a Receiver, it records
    the Receiver and continues from the op returned by its ``send_node()``.
    Receivers reached through chained send/recv pairs are therefore included.

    Arguments:
        fro: The op to start from. If it is itself a Receiver, it is included
            in the result.

    Returns:
        OrderedSet: The Receivers found, in the order they were first reached.
    """
    # Ops already expanded. Without this, an op shared by several consumers
    # is re-expanded once per path (exponential on diamond-shaped graphs),
    # and the loop would never end if args/send_node edges formed a cycle.
    visited = set()
    visit = OrderedSet()
    recvs = OrderedSet()
    visit.add(fro)
    while visit:
        v = visit.pop()
        if v in visited:
            continue
        visited.add(v)
        if isinstance(v, Receiver):
            recvs.add(v)
            visit.add(v.send_node())
        elif hasattr(v, 'args'):
            visit.update(v.args)
    return recvs
def comm_path_exists(fro, to):
    """
    Find a path from fro to to, including paths through the non-explicit
    edges from a Receiver to its Sender.

    Note- this is a non-standard traversal, as most traversals stop at a
    Receiver.

    Arguments:
        fro: The op to start from. The search begins at ``fro.args``, so
            ``fro == to`` alone does not count as a path.
        to: The op to look for.

    Returns:
        bool: True if ``to`` is reachable from ``fro``, otherwise False.
    """
    # TODO: does this correctly handle traversing multiple send-recv junctions
    # from fro to to?
    # Ops already expanded. This keeps shared sub-graphs from being re-walked
    # once per path and guarantees termination if the args/send_node edges
    # ever form a cycle.
    visited = set()
    visit = OrderedSet(fro.args)
    while visit:
        v = visit.pop()
        if v == to:
            return True
        if v in visited:
            continue
        visited.add(v)
        if isinstance(v, Receiver):
            visit.add(v.send_node())
        elif hasattr(v, 'args'):
            # Same guard as find_recvs: ops without args are leaves.
            visit.update(v.args)
    return False
def __init__(self, transformer, returns, *args, **kwargs):
    """
    Build a computation over ``returns`` that takes ``args`` as parameters.

    Arguments:
        transformer: The transformer this computation belongs to. Every op
            gathered here is added to its ``all_results``.
        returns: What the computation produces: a single Op, a set of Ops,
            a sequence of Ops, or None. TensorOps are wrapped in a
            ResultHandle. A set is stored as a plain ``set``; a sequence is
            stored as a ``list``.
        *args: The computation's parameters. Each must be an Op with
            ``input=True`` (typically a placeholder).
        **kwargs: Forwarded to the base-class constructor.

    Raises:
        ValueError: If ``returns`` has an unsupported type, or if an element
            of ``args`` is not an Op or does not have ``input=True``.
    """
    # The ABC aliases in ``collections`` (collections.Set, ...) were removed
    # in Python 3.10. Use collections.abc, falling back where it is missing.
    try:
        import collections.abc as collections_abc
    except ImportError:
        collections_abc = collections

    super(Computation, self).__init__(**kwargs)
    self.transformer = transformer
    self.computation_name = None

    def wrap_op(op):
        # TensorOps are returned through a ResultHandle; other ops as-is.
        if isinstance(op, TensorOp):
            return ResultHandle(op)
        return op

    def wrap_ops(ops):
        return [wrap_op(op) for op in ops]

    self.ops = OrderedSet()
    if isinstance(returns, collections_abc.Set):
        returns = set(wrap_ops(returns))
        self.ops.update(returns)
    elif isinstance(returns, collections_abc.Sequence):
        returns = wrap_ops(returns)
        self.ops.update(returns)
    elif isinstance(returns, Op):
        returns = wrap_op(returns)
        self.ops.add(returns)
    elif returns is not None:
        raise ValueError((
            'The returns of a computation must be an Op, a set or sequence '
            'of Ops, or None, but an object of type {returns_type} was '
            'passed in.'
        ).format(returns_type=returns.__class__.__name__))
    self.returns = returns

    self.parameters = []
    for arg in args:
        # Check the type first so a non-Op argument gets a ValueError
        # rather than an AttributeError from reading ``arg.input``.
        if not isinstance(arg, Op):
            raise ValueError((
                'The arguments to a computation must all be ops, but '
                '{arg} was passed in, of type {arg_type}.'
            ).format(
                arg=arg,
                arg_type=arg.__class__.__name__,
            ))
        if not arg.input:
            raise ValueError((
                'The arguments to a computation must all have property '
                'input=True, but the op passed had input=False. In most '
                'cases you want to pass placeholder ops in as arguments. '
                '{op} was passed in, of type {op_type}.'
            ).format(
                op=arg,
                op_type=arg.__class__.__name__,
            ))
        self.parameters.append(arg)
        self.ops.add(arg)

    # The user_deps of the directly requested ops are added to the ops.
    control_ops = OrderedSet()
    for op in self.ops:
        control_ops.update(op.user_deps)

    # Walk args and other_deps transitively from every requested op,
    # collecting every other_deps op encountered along the way.
    processed_ops = set()
    pending_ops = OrderedSet(self.ops)
    while pending_ops:
        op = pending_ops.pop()
        if op in processed_ops:
            continue
        control_ops.update(op.other_deps)
        pending_ops.update(op.other_deps)
        pending_ops.update(op.args)
        processed_ops.add(op)

    self.ops.update(control_ops)
    self.transformer.all_results.update(self.ops)
    self.executor = None