def do_pass(self, min_ops, transformer):
    """
    Visit the ops until nothing changes.

    Each round registers any new state initializations required by the ops,
    visits every op (initialization ops included) in execution order, and then
    applies the replacements collected by ``visit``. The loop stops once a
    round produces neither replacements nor new initializations.

    Args:
        min_ops: The set of ops that must be computed.
        transformer: An InitGraph object.
    """
    assert isinstance(min_ops, Iterable), "Ops passed into do_pass must be an iterable"
    # Materialize once: min_ops is iterated several times per round below, so a
    # one-shot iterable (e.g. a generator) would be exhausted by its first use.
    min_ops = list(min_ops)
    has_work = True
    while True:
        ops = Op.ordered_ops(min_ops)

        # Check for ops that added state that needs to be initialized, so they can
        # be added to the initialization function.
        has_new_inits = transformer.add_initialization_ops(ops)
        if not has_work and not has_new_inits:
            return

        self.replacement_list = []

        # pass through the ops in an execution order collecting things to do
        ops = Op.ordered_ops(op.forwarded
                             for op in transformer.state_initialization_ops + min_ops)
        for op in ops:
            op.update_forwards()
            self.visit(op)

        # Perform the gathered replacements
        for old, rep in self.replacement_list:
            old.forwarded.replace_self(rep.forwarded)

        has_work = len(self.replacement_list) > 0
        # Follow forwarding so the next round starts from the replacement ops.
        min_ops = list(_.forwarded for _ in min_ops)
def test_all_op_references():
    # References collected from the top of the simple graph must include both
    # ops; references collected from the base op must not reach the op built on it.
    base_op, simple_graph = get_simple_graph()

    refs_from_top = Op.all_op_references([simple_graph])
    for expected in (base_op, simple_graph):
        assert expected in refs_from_top

    refs_from_base = Op.all_op_references([base_op])
    assert base_op in refs_from_base
    assert simple_graph not in refs_from_base
def Computation(self, request_iterator, context):
    """
    gRPC handler: assemble a subgraph streamed over one or more requests and
    build a computation for it on this worker's transformer.

    Args:
        request_iterator: Iterable of request messages. Each one contributes
            ``ops`` and ``edges`` and may contribute ``returns`` and
            ``placeholders``.
        context: gRPC servicer context (not used here).

    Returns:
        hetr_pb2.ComputationReply carrying the new ``comp_id`` on success, or
        ``comp_id=-1`` with an error ``message`` (the formatted traceback when
        an exception was raised).
    """
    logger.debug("server: computation")
    if not self.transformer:
        return hetr_pb2.ComputationReply(
            comp_id=-1, message="build transformer before computation")
    try:
        comp_id = self.new_comp_id()
        pb_ops, pb_edges = [], []
        returns, placeholders = [], []
        reconstructed_returns, reconstructed_placeholders = [], []
        # Collect every chunk first; the graph is deserialized only once all
        # ops and edges have arrived.
        for request in request_iterator:
            pb_ops.extend(request.ops)
            pb_edges.extend(request.edges)
            returns.extend([protobuf_to_op(op) for op in request.returns])
            placeholders.extend(
                [protobuf_to_op(op) for op in request.placeholders])

        subgraph = _deserialize_graph_ops_edges(pb_ops, pb_edges)

        # Add dependency on recv op to their send op in scenarios where the send buffer
        # is passed as an argument to the communication call (gather/scatter)
        # on the root device.
        # This ensures that the send buffer does not get reused before
        # the recv_buf gets access to items
        # NOTE(review): Get_rank() returns an int, while clone_graph stores
        # metadata['device_id'] as str(clone_id); confirm the types agree here,
        # otherwise the == never matches and the `in` test can raise TypeError.
        root_idx = 0
        for op in Op.all_op_references(subgraph):
            if isinstance(op, (GatherRecvOp)) and \
                    MPI.COMM_WORLD.Get_rank() == op.metadata['device_id']:
                args = list(op._args)
                args.extend(op.send_node().args)
                op._args = tuple(args)
                # args changed, so cached dependency information is stale
                op.invalidate_property_cache('all_deps')
            elif (isinstance(op, (ScatterRecvOp)) and
                  MPI.COMM_WORLD.Get_rank() == root_idx and
                  MPI.COMM_WORLD.Get_rank() in op.metadata['device_id']):
                args = list(op._args)
                args.extend(op.send_node().args)
                op._args = tuple(args)
                op.invalidate_property_cache('all_deps')

        ops = Op.ordered_ops(subgraph)
        # The requested returns/placeholders were deserialized standalone; map
        # each to the op with the same uuid in the rebuilt subgraph, keeping
        # request order (placeholders are passed positionally below).
        for r in returns:
            for op in ops:
                if op.uuid == r.uuid:
                    reconstructed_returns.append(op)
        for p in placeholders:
            for op in ops:
                if op.uuid == p.uuid:
                    reconstructed_placeholders.append(op)

        computation = self.transformer.computation(
            reconstructed_returns, *reconstructed_placeholders)
        self.computations[comp_id] = computation
        return hetr_pb2.ComputationReply(comp_id=comp_id)
    except Exception:
        # Report the failure to the client in the reply rather than failing the RPC.
        return hetr_pb2.ComputationReply(comp_id=-1,
                                         message=traceback.format_exc())
def clone_graph(root, clone_id, shared_queues_idx, parallel_axis, num_clones):
    """
    Clone the subgraph rooted at ``root`` via a serialize/deserialize round trip.

    The cloned ops are retargeted to device ``clone_id``: their 'transformer'
    and 'device_id' metadata are rewritten, their axes containing
    ``parallel_axis`` are rescaled with ``calculate_scatter_axes``, and every
    cloned op gets a fresh uuid.

    Args:
        root: Root op of the subgraph to clone (e.g. a GatherSendOp).
        clone_id: Device id the clone is assigned to.
        shared_queues_idx: Queue index given to cloned ScatterRecvOp/GatherSendOp.
        parallel_axis: Axis being split across devices.
        num_clones: Number of devices the parallel axis is split across.

    Returns:
        (new_root, new_send_nodes, replaced_send_nodes): the cloned root, the
        extra CPUQueueSendOps created for cloned queue-recv ops, and the
        original send ops those new senders stand in for.
    """
    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))

    # The deserialized copy still carries the original uuids, so it can be
    # matched back to the original ops by uuid until the uuids are regenerated below.
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}

    # Prune ops that are not control_deps of new_gather_send_op
    # deserialize includes extra referenced nodes
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        op.metadata['transformer'] = op.metadata['device'] + str(clone_id)
        op.metadata['device_id'] = str(clone_id)
        if isinstance(op, (ScatterRecvOp, GatherSendOp)):
            # Clones share the original's queues; idx selects this clone's slot.
            op._shared_queues = orig_ops[op.uuid]._shared_queues
            op.idx = shared_queues_idx
            if isinstance(op, ScatterRecvOp):
                op._send_node = orig_ops[op.uuid].send_node()
        elif isinstance(op, (CPUQueueRecvOp, GPUQueueRecvOp)):
            # Cloning a recv node means we need a broadcast, so simulate one by adding an
            # additional sender with the same input data as the original sender.
            # TODO replace with real broadcast #1398 #1399
            send_op = CPUQueueSendOp(orig_ops[op.uuid].send_node().args[0])
            op._queue = send_op.queue
            op._send_node = send_op
            new_send_nodes.add(send_op)
            replaced_send_nodes.add(orig_ops[op.uuid].send_node())

        if hasattr(op, '_axes') and parallel_axis in op._axes:
            op._axes = calculate_scatter_axes(op.axes, parallel_axis, num_clones)
            # TODO: Revisit to handle axes updation better. Github Ticket #1355
            if isinstance(op, DotOp):
                # DotOp also caches its output axes per operand; rescale whichever
                # side carries the parallel axis.
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = calculate_scatter_axes(
                        op.x_out_axes, parallel_axis, num_clones)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = calculate_scatter_axes(
                        op.y_out_axes, parallel_axis, num_clones)
                else:
                    raise ValueError(
                        "Missing parallel_axis in Op's x_out_axes or y_out_axes"
                    )
        op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
def _transform_computations(self): """ Transform computation graphs to a form that can be run. """ # Run passes on the computation graphs all_results = [] for comp in self.computations: all_results.append(comp.computation_op) all_ops = self.run_registered_graph_passes(ops=all_results) # Collect up all ops from the graph and obtain the init graph all_ops = OrderedSet(Op.ordered_ops(all_ops)) def ensure_tensor(op): op = op.forwarded tensor_description = op.tensor_description() base = tensor_description.base tensor = self.op_tensors.get(base, None) if tensor is None: tensor = self.device_buffer_storage( base.tensor_size, base.dtype, base.name ) self.op_tensors[base] = tensor self.device_buffers.add(tensor) tensor_view = tensor.device_tensor(tensor_description) self.op_tensor_views[tensor_description] = tensor_view self.ops = Op.ordered_ops(all_ops) for op in self.ops: if op.is_tensor_op: ensure_tensor(op) self.start_transform_allocate() for device_buffer in self.device_buffers: device_buffer.transform_allocate() self.finish_transform_allocate() # Compile the computations now that we know their storage for comp in self.computations: comp.computation_name = \ self.transform_ordered_ops(comp, Op.ordered_ops([comp.computation_op]), name=comp.name) self.finish_transform() self.finalized = True
def _transform_computations(self): """ Transform computation graphs to a form that can be run. """ # Run passes on the computation graphs all_results = [] for comp in self.computations: all_results.append(comp.computation) all_ops = self.run_registered_graph_passes(all_results) self.init_computation = \ self.add_computation(computation(doall(self.state_initialization_ops)).named('init')) all_ops.append(self.init_computation.computation) # Collect up all ops from the graph and obtain the init graph all_ops = OrderedSet(Op.ordered_ops(all_ops)) def init_tensor_description(tensor_description): if tensor_description.buffer is None: tensor_description.buffer = self.device_buffer_storage( tensor_description.base.tensor_size, tensor_description.dtype, tensor_description.name ) self.device_buffers.add(tensor_description.buffer) tensor_description.value = \ tensor_description.buffer.device_tensor(tensor_description) for state in self.init_states: init_tensor_description(state.tensor_description()) self.ops = Op.ordered_ops(all_ops) for op in self.ops: if op.is_tensor_op: init_tensor_description(op.tensor_description()) self.start_transform_allocate() for device_buffer in self.device_buffers: device_buffer.transform_allocate() self.finish_transform_allocate() # Compile the computations now that we know their storage for comp in self.computations: comp.computation_name = \ self.transform_ordered_ops(Op.ordered_ops([comp.computation]), name=comp.name) self.finish_transform() self.finalized = True
def Computation(self, request_iterator, context):
    """
    gRPC handler: rebuild a chunked subgraph and create a computation from it.

    Every request in ``request_iterator`` contributes serialized ops and edges,
    and may contribute requested returns and placeholders. Those are matched by
    uuid against the rebuilt graph, in request order.

    Returns a ComputationReply with the new comp_id, or comp_id=-1 plus a
    message (the traceback on failure).
    """
    logger.info("server: computation")
    if not self.transformer:
        return hetr_pb2.ComputationReply(
            comp_id=-1, message="build transformer before computation")

    try:
        comp_id = self.new_comp_id()

        pb_ops, pb_edges = [], []
        returns, placeholders = [], []
        for request in request_iterator:
            pb_ops.extend(request.ops)
            pb_edges.extend(request.edges)
            returns.extend([protobuf_to_op(pb) for pb in request.returns])
            placeholders.extend([protobuf_to_op(pb) for pb in request.placeholders])

        ordered = Op.ordered_ops(_deserialize_graph_ops_edges(pb_ops, pb_edges))
        # outer loop over the requested op keeps request order
        matched_returns = [node for wanted in returns
                           for node in ordered if node.uuid == wanted.uuid]
        matched_placeholders = [node for wanted in placeholders
                                for node in ordered if node.uuid == wanted.uuid]

        self.computations[comp_id] = self.transformer.computation(
            matched_returns, *matched_placeholders)
        return hetr_pb2.ComputationReply(comp_id=comp_id)
    except Exception:
        return hetr_pb2.ComputationReply(comp_id=-1,
                                         message=traceback.format_exc())
def do_pass(self, min_ops, transformer):
    """
    Visit the ops until nothing changes.

    Each round visits every op in execution order and then applies the
    replacements collected by ``visit``. Rounds repeat while a round produced
    at least one replacement.

    Args:
        min_ops: The set of ops that must be computed.
        transformer: An InitGraph object (not used by this implementation).
    """
    assert isinstance(
        min_ops, Iterable), "Ops passed into do_pass must be an iterable"
    # Materialize once: min_ops is read twice per round. With a generator, the
    # first read would exhaust it and the next round would start from [].
    min_ops = list(min_ops)
    has_work = True
    while has_work:
        self.replacement_list = []

        # pass through the ops in an execution order collecting things to do
        ops = Op.ordered_ops(op.forwarded for op in min_ops)
        for op in ops:
            op.update_forwards()
            self.visit(op)

        # Perform the gathered replacements
        for old, rep in self.replacement_list:
            old.forwarded.replace_self(rep.forwarded)

        has_work = len(self.replacement_list) > 0
        # Follow forwarding so the next round starts from the replacement ops.
        min_ops = list(op.forwarded for op in min_ops)
def do_pass(self, ops, transformer):
    """
    For each gather, clone its send-side subgraph once per source device.

    For every op marked as a gather (a GatherRecvOp), the graph under its
    current send node is cloned with ``clone_graph`` for each id in
    ``from_id``. The clones replace the original send node on the recv op.
    ``self.send_nodes`` is updated accordingly: the clones and any helper
    senders are added, and the replaced senders are removed.
    """
    forwarded = OrderedSet(op.forwarded for op in ops)
    for recv_op in reversed(Op.ordered_ops(forwarded)):
        if recv_op.metadata.get('marker') != 'gather':
            continue

        # recv_op is a GatherRecvOp
        self.parallel_axes = recv_op.metadata['parallel']
        original_send = recv_op.send_nodes[0]

        # clone nodes for each device_id
        retired_sends = OrderedSet()
        cloned_gather_sends = OrderedSet()
        for queue_idx, device_id in enumerate(recv_op.from_id):
            cloned_send, added_sends, dropped_sends = clone_graph(
                root=original_send,
                clone_id=device_id,
                shared_queues_idx=queue_idx,
                parallel_axis=self.parallel_axes,
                num_clones=len(recv_op.from_id))
            cloned_gather_sends.add(cloned_send)
            added_sends.add(cloned_send)
            for send in added_sends:
                self.send_nodes.add(send)
            retired_sends |= dropped_sends

        recv_op.send_nodes = cloned_gather_sends
        retired_sends.add(original_send)
        for send in retired_sends:
            self.send_nodes.remove(send)
def update_parallel_axis(root, parallel_axis):
    """
    Rewrite ``parallel_axis`` with ``set_parallel_axes`` throughout the graph under ``root``.

    The following are updated in place on each op reachable from ``root``:
    - its reduction axes;
    - its own axes;
    - for a DotOp, whichever of ``x_out_axes``/``y_out_axes`` holds the axis;
    - for a TensorValueOp, the axes of the wrapped tensor.

    Raises:
        ValueError: a DotOp's axes contain ``parallel_axis`` but neither of its
            out-axes do.
    """
    for node in Op.ordered_ops([root]):
        reduces_parallel = (hasattr(node, 'reduction_axes') and
                            parallel_axis in node.reduction_axes)
        if reduces_parallel:
            node.reduction_axes = set_parallel_axes(node.reduction_axes, parallel_axis)

        if getattr(node, 'axes', None) is None:
            axes_hit = False
        else:
            axes_hit = parallel_axis in Axes.as_flattened_list(node.axes)
        if axes_hit:
            node._axes = set_parallel_axes(node.axes, parallel_axis)
            if isinstance(node, DotOp):
                if parallel_axis in node.x_out_axes:
                    node.x_out_axes = set_parallel_axes(node.x_out_axes, parallel_axis)
                elif parallel_axis in node.y_out_axes:
                    node.y_out_axes = set_parallel_axes(node.y_out_axes, parallel_axis)
                else:
                    raise ValueError("Missing parallel_axis in Op's "
                                     "x_out_axes or y_out_axes")

        if isinstance(node, TensorValueOp) and parallel_axis in node.tensor.axes:
            node.tensor._axes = set_parallel_axes(node.tensor.axes, parallel_axis)
def Computation(self, request, context):
    """
    gRPC handler: build a computation from a serialized subgraph.

    Args:
        request: Message carrying ``subgraph`` plus the requested ``returns``
            and ``placeholders``.
        context: gRPC servicer context (not used here).

    Returns:
        hetr_pb2.ComputationReply with the new ``comp_id``, or ``comp_id=-1``
        if no transformer is built or building the computation fails.
    """
    if not self.transformer:
        return hetr_pb2.ComputationReply(comp_id=-1)
    try:
        comp_id = self.new_comp_id()
        subgraph = _deserialize_graph(request.subgraph)
        returns = [protobuf_to_op(pb_op) for pb_op in request.returns]
        placeholders = [protobuf_to_op(pb_op) for pb_op in request.placeholders]

        ops = Op.ordered_ops(subgraph)
        # Match in *request* order (outer loop over the requested ops), not graph
        # order. Placeholders are bound positionally by transformer.computation,
        # and the caller pairs results with the returns it sent, in that order.
        return_list = [op for r in returns for op in ops if op.uuid == r.uuid]
        placeholder_list = [op for p in placeholders for op in ops if op.uuid == p.uuid]

        computation = self.transformer.computation(return_list, *placeholder_list)
        self.computations[comp_id] = computation
        return hetr_pb2.ComputationReply(comp_id=comp_id)
    except Exception:
        # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt propagate.
        return hetr_pb2.ComputationReply(comp_id=-1)
def add_initialization_ops(self, ops):
    """
    Ensure initializations have been captured for state in ops.

    Each op is examined at most once; ops already in ``self.init_checked_ops``
    are skipped. Initializations found for the states an op reads or writes
    are added to ``self.state_initialization_ops``. The init ops themselves are
    then scanned recursively, since they may touch further state.

    Args:
        ops: Collection of ops.

    Returns:
        True if new initializations were added.
    """
    added_any = False
    for candidate in ops:
        if candidate in self.init_checked_ops:
            continue
        self.init_checked_ops.add(candidate)

        inits = self.state_initializations(candidate.states_read)
        inits.update(self.state_initializations(candidate.states_written))
        if len(inits) == 0:
            continue

        added_any = True
        self.state_initialization_ops.update(inits)
        self.add_initialization_ops(Op.ordered_ops(inits))

    self.state_initialization_ops = OrderedSet(
        init_op.forwarded for init_op in self.state_initialization_ops)
    return added_any
def _transform_computations(self): """ Transform computation graphs to a form that can be run. """ # with Op.saved_user_deps(): # Run passes on the computation graphs self.run_registered_graph_passes(self.all_results) # Collect up all ops from the graph and obtain the init graph all_ops = OrderedSet(Op.ordered_ops(self.all_results)) init_op = doall(self.ordered_initializers(all_ops)) # Run passes on the initialization graphs self.run_registered_graph_passes([init_op]) # Union the init and computation graphs self.inits = Op.ordered_ops([init_op]) all_ops.update(self.inits) # create computation which initializes values (called once per session) init_op.update_forwards() self.init_computation = self.computation(init_op, name="init") # Give ids for op in all_ops: if op not in self.opids: self.opids[op] = len(self.opids) self.dataflow, self.memory = assign_buffers(self, all_ops, self.fusion) # Initialize tensor descriptions for op in all_ops: self.initialize_tensor_descriptions(op) self.ops = self.dataflow.instructions self.start_transform_allocate() for device_buffer in self.device_buffers: device_buffer.transform_allocate() self.finish_transform_allocate() # Compile the computations now that we know their storage for computation in self.computations: computation.transform() self.finish_transform() self.finalized = True
def do_pass(self, min_ops, transformer):
    """
    Visit the ops until nothing changes.

    Each round does the following:
    1. Registers any new state initializations required by the ops.
    2. Replaces ParallelOp control dependencies with the ParallelOp's own
       control deps, and computes pending SequentialOp control dependencies.
    3. Visits every op (initialization ops included) in execution order.
    4. Applies the replacements collected by ``visit``.

    The loop stops once a round yields neither replacements nor new
    initializations.

    Args:
        min_ops: The set of ops that must be computed.
        transformer: An InitGraph object.
    """
    assert isinstance(min_ops, Iterable), "Ops passed into do_pass must be an iterable"
    # Materialize once: min_ops is iterated several times per round below, so a
    # one-shot iterable (e.g. a generator) would be exhausted by its first use.
    min_ops = list(min_ops)
    has_work = True
    while True:
        ops = Op.ordered_ops(min_ops)

        # Check for ops that added state that needs to be initialized, so they can
        # be added to the initialization function.
        has_new_inits = transformer.add_initialization_ops(ops)
        if not has_work and not has_new_inits:
            return

        self.replacement_list = []

        # Make control dependency adjustments for any added control blocks.
        ops = Op.ordered_ops(op.forwarded
                             for op in transformer.state_initialization_ops + min_ops)
        for op in ops:
            # NOTE(review): op.control_deps is mutated while being iterated; this
            # relies on the collection tolerating that (as a linked-list based
            # OrderedSet does). Confirm before changing its type.
            for cop in op.control_deps:
                if isinstance(cop, ParallelOp):
                    op.remove_control_dep(cop)
                    for dep in cop.control_deps:
                        op.add_control_dep(dep)
            if isinstance(op, SequentialOp) and not op.control_dependencies_computed:
                op.compute_control_dependencies()

        # pass through the ops in an execution order collecting things to do
        ops = Op.ordered_ops(op.forwarded
                             for op in transformer.state_initialization_ops + min_ops)
        for op in ops:
            op.update_forwards()
            self.visit(op)

        for old, rep in self.replacement_list:
            old.forwarded.replace_self(rep.forwarded)

        has_work = len(self.replacement_list) > 0
        # Follow forwarding so the next round starts from the replacement ops.
        min_ops = list(_.forwarded for _ in min_ops)
def do_pass(self, ops, **kwargs):
    """
    Render the graph reachable from ``ops`` with graphviz.

    When ``self.subgraph_attr`` is set, ops are grouped into clusters keyed by
    that metadata attribute. The graph is rendered into a fresh temporary
    directory.

    Args:
        ops: Ops whose referenced graph is drawn.

    Returns:
        ``ops``, unchanged.

    Raises:
        ImportError: the python graphviz package is not installed.
    """
    try:
        import graphviz
    except ImportError:
        raise ImportError(
            "You tried to use the ShowGraph transformer pass but did "
            "not have the python graphviz library installed")

    # Get all ops and edges from this set
    all_ops = Op.all_op_references(ops)
    all_edges = ser._serialize_graph(ops).edges

    vg = graphviz.Digraph(node_attr={'shape': 'box'},
                          graph_attr={
                              'nodesep': '.5',
                              'ranksep': '.5'
                          })
    if self.subgraph_attr is not None:
        subgraphs = {}
        for subgraph_name in self.get_subgraphs(all_ops):
            if subgraph_name not in subgraphs and subgraph_name is not None:
                sg = graphviz.Digraph(
                    name='cluster_{}'.format(subgraph_name))
                sg.body.append('color="{}"'.format(self.random_color()))
                sg.body.append('style=filled')
                sg.body.append('label="{}"'.format(subgraph_name))
                subgraphs[subgraph_name] = sg
        for op in all_ops:
            subgraph_name = op.metadata.get(self.subgraph_attr, '')
            # hack to show hetr graph, a tuple becomes a list after clone_graph
            if isinstance(subgraph_name, list):
                subgraph_name = tuple(subgraph_name)
            if subgraph_name in subgraphs:
                graph = subgraphs[subgraph_name]
            else:
                graph = vg
            self.add_op_to_graph(op, graph)
        # Attach clusters only after they are populated.
        for sg in subgraphs.values():
            vg.subgraph(sg)
    else:
        for op in all_ops:
            self.add_op_to_graph(op, vg)

    for edge in all_edges:
        self.add_edge_to_graph(edge, vg)

    tmp_dir = tempfile.mkdtemp()
    vg.render(directory=tmp_dir, view=self.view, cleanup=True)
    if not self.view:
        # logging uses %-style lazy formatting; a "{}" placeholder is never
        # filled in and makes logging report a formatting error instead.
        logging.info("VizPass graph rendered to %s", tmp_dir)
    # Cleanup
    self.uuid_lookup_table.clear()

    return ops
def test_hetr_send_recv_graph_serialization():
    """
    Round-trip the hetr send/recv graph built by create_send_recv_graph
    through serde and compare every op against its original, matched by uuid.
    """
    z, _, _, _, _, _, _ = create_send_recv_graph()

    round_tripped = ser.deserialize_graph(ser.serialize_graph([z]))
    originals = Op.all_op_references([z])

    def by_uuid(node):
        return node.uuid

    pairs = zip(sorted(round_tripped, key=by_uuid),
                sorted(originals, key=by_uuid))
    for restored, original in pairs:
        assert_object_equality(restored, original)
def test_full_graph_serialization_endtoend():
    """
    Serialize the simple graph, deserialize it, and compare each op with its
    original, matched by uuid.
    """
    _, simple_graph = get_simple_graph()

    round_tripped = ser.deserialize_graph(ser.serialize_graph([simple_graph]))
    originals = Op.all_op_references([simple_graph])

    # Checking only the leaves would already recurse up the graph, but the
    # whole set of returned nodes is compared pairwise as well.
    def by_uuid(node):
        return node.uuid

    for restored, original in zip(sorted(round_tripped, key=by_uuid),
                                  sorted(originals, key=by_uuid)):
        assert_object_equality(restored, original)
def run_pass(self, process_op, ops, **kwargs):
    """
    Apply ``process_op`` to every op in execution order, batch by batch.

    Each batch is bracketed by ``self.begin_batch()`` and ``self.end_batch()``.
    Batches repeat for as long as ``end_batch()`` returns a truthy value.
    """
    assert isinstance(ops, Iterable), "Ops passed into do_pass must be an iterable"
    keep_going = True
    while keep_going:
        self.begin_batch()
        # pass through the ops in an execution order collecting things to do
        ordered = Op.ordered_ops(node.forwarded for node in ops)
        for node in ordered:
            node.update_forwards()
            process_op(node)
        keep_going = self.end_batch()
        ops = [node.forwarded for node in ordered]
def generate_messages():
    """
    Yield the computation request as a stream of smaller messages.

    Ops reachable from ``returns`` and ``placeholders`` (both taken from the
    enclosing scope) are serialized together with their edges, at most
    ``_OPS_PER_MSG`` ops per message. The protobuf returns/placeholders are
    attached to the first message only; later messages carry empty lists.
    """
    pb_ops, pb_edges = [], []
    pb_returns, pb_placeholders = generate_returns_placeholders()
    ops = Op.all_op_references(returns + list(placeholders))
    last_idx = len(ops) - 1
    for i, op in enumerate(ops):
        pb_ops.append(op_to_protobuf(op))
        add_edges(pb_edges, pb_ops, op)
        # Flush after every _OPS_PER_MSG ops and after the final op. Testing
        # (i + 1) keeps every chunk at exactly _OPS_PER_MSG ops; the former
        # `i != 0 and i % _OPS_PER_MSG == 0` put one extra op in the first chunk.
        if (i + 1) % _OPS_PER_MSG == 0 or i == last_idx:
            msg = make_computation_request(pb_ops, pb_edges, pb_returns, pb_placeholders)
            yield msg
            pb_ops, pb_edges = [], []
            pb_returns, pb_placeholders = [], []
def do_pass(self, ops):
    """
    Repeatedly visit the ops and apply collected replacements until a round
    produces none.

    Returns:
        The set of (forwarded) ops visited in the final round.
    """
    assert isinstance(
        ops, Iterable), "Ops passed into do_pass must be an iterable"
    while True:
        self.replacement_list = []
        # NOTE(review): a plain set gives Op.ordered_ops an unordered input, so
        # tie-breaking between independent ops may vary between runs.
        ops = {node.forwarded for node in ops}
        for node in Op.ordered_ops(ops):
            node.update_forwards()
            self.visit(node)
        for old, rep in self.replacement_list:
            old.forwarded.replace_self(rep.forwarded)
        if not self.replacement_list:
            return ops
def do_pass(self, ops, **kwargs):
    """
    Shrink the parallel axis to its per-device length in each gathered subgraph.

    On the first gather op encountered, ``self.parallel_axis`` is created from
    the op's 'parallel' metadata axis divided by the number of source devices
    (it must divide evenly). Every gather's send subgraph is then rewritten
    with ``update_parallel_axis``.
    """
    forwarded = OrderedSet(op.forwarded for op in ops)
    for node in reversed(Op.ordered_ops(forwarded)):
        if node.metadata.get('marker') != 'gather':
            continue

        # node is a GatherRecvOp
        if self.parallel_axis is None:
            full_axis = node.metadata['parallel']
            num_devices = len(node.from_id)
            assert full_axis.length % num_devices == 0, \
                '{} can not be equally divided by {}'.format(full_axis, num_devices)
            self.parallel_axis = make_axis(
                name=full_axis.name,
                length=full_axis.length // num_devices,
                docstring='HeTr parallel axis')

        update_parallel_axis(node.send_node(), self.parallel_axis)
def do_pass(self, ops, **kwargs):
    """
    For each gather, clone its send-side subgraph once per source device.

    On the first gather encountered, ``self.parallel_axes`` is created as the
    per-device slice of the op's 'parallel' axis. The graph under the gather's
    send node is then cloned with ``clone_graph`` for each id in ``from_id``.
    The clones replace the original send node, and ``self.send_nodes`` is
    updated accordingly.
    """
    forwarded = OrderedSet(op.forwarded for op in ops)
    for recv_op in reversed(Op.ordered_ops(forwarded)):
        if recv_op.metadata.get('marker') != 'gather':
            continue

        # recv_op is a GatherRecvOp
        if self.parallel_axes is None:
            full_axis = recv_op.metadata['parallel']
            num_devices = len(recv_op.from_id)
            assert full_axis.length % num_devices == 0, \
                '{} can not be equally divided by {}'.format(full_axis, num_devices)
            self.parallel_axes = make_axis(
                name=full_axis.name,
                length=full_axis.length // num_devices,
                docstring='HeTr parallel axis')

        original_send = recv_op.send_nodes[0]

        # clone nodes for each device_id
        retired_sends = OrderedSet()
        cloned_gather_sends = OrderedSet()
        for queue_idx, device_id in enumerate(recv_op.from_id):
            cloned_send, added_sends, dropped_sends = clone_graph(
                root=original_send,
                clone_id=device_id,
                shared_queues_idx=queue_idx,
                parallel_axis=self.parallel_axes,
                num_clones=len(recv_op.from_id))
            cloned_gather_sends.add(cloned_send)
            added_sends.add(cloned_send)
            for send in added_sends:
                self.send_nodes.add(send)
            retired_sends |= dropped_sends

        recv_op.send_nodes = cloned_gather_sends
        retired_sends.add(original_send)
        for send in retired_sends:
            self.send_nodes.remove(send)
def allocate(self):
    """
    Allocate storage and then initialize constants.

    Finalizes (transforms the computations) first if that has not happened
    yet. Does nothing when already allocated.
    """
    if not self.allocated:
        # Disable user_deps during transformations
        with Op.saved_user_deps():
            if not self.finalized:
                self._transform_computations()
            self.allocate_storage()
            for op in OrderedSet(self.inits + self.ops):
                self.initialize_constant(op)
        self.allocated = True
def do_pass(self, ops, transformer):
    """
    Rescale and clone the send subgraph of every gathered argument.

    Ops are walked in reverse execution order. For each argument marked as a
    gather:
    1. Its 'parallel' metadata becomes ``self.parallel_axis``.
    2. The nodes of its send subgraph get the axes computed by
       ``calculate_new_axes``.
    3. The subgraph is cloned via ``self.clone_nodes`` for every device in
       ``arg.from_id[1:]``.

    Args:
        ops: Ops whose (forwarded) graph is processed.
        transformer: Not used by this pass.
    """
    ops = OrderedSet(op.forwarded for op in ops)

    def set_new_axes(root, num_devices):
        # Assign the newly computed axes to every node produced by do_traversal(root).
        visit = self.do_traversal(root)
        self.new_axes = calculate_new_axes(root.axes, self.parallel_axis, num_devices, False)
        while visit:
            node = visit.pop()
            if hasattr(node, 'axes'):
                node._TensorOp__axes = self.new_axes

    # Start traversal from the top to the bottom
    for op in reversed(Op.ordered_ops(ops)):
        args = list()
        for arg in op.args:
            # Compare by value: the former `'gather' is ...` tested object
            # identity and only worked when both strings happened to be interned.
            if arg.metadata.get('marker') == 'gather':
                self.parallel_axis = arg.metadata['parallel']
                set_new_axes(arg.send_node(), len(arg.from_id))
                for d in range(1, len(arg.from_id)):
                    if d == (len(arg.from_id) - 1):
                        # The last clone recomputes axes with the final flag set to
                        # True. NOTE(review): presumably this handles a remainder; confirm.
                        self.new_axes = calculate_new_axes(
                            arg.axes, self.parallel_axis, len(arg.from_id), True)
                    nodes = self.do_traversal(arg.send_node())
                    self.clone_nodes(nodes, arg.from_id[d], self.new_axes,
                                     self.scatter_shared_queues[d],
                                     self.gather_shared_queues[d])
            args.append(arg)

        if isinstance(op.args, tuple):
            op._Op__args = tuple(args)
        else:
            # NOTE(review): this calls op.args as a function; confirm that
            # non-tuple args really are callable here.
            op.args(args)
def _serialize_graph(ops):
    """
    Serialize the graph reachable from ``ops`` into a GraphDef.

    The result is the protobuf python object itself, not the serialized byte
    string that `serialize_graph` produces. Edges are added to the GraphDef
    before ops.
    """
    assert isinstance(
        ops, Iterable), "Ops passed into `serialize_graph` must be an iterable"

    pb_ops, pb_edges = [], []
    for node in Op.all_op_references(ops):
        pb_ops.append(op_to_protobuf(node))
        add_edges(pb_edges, pb_ops, node)

    graph_def = ops_pb.GraphDef()
    for pb_edge in pb_edges:
        graph_def.edges.add().CopyFrom(pb_edge)
    for pb_op in pb_ops:
        graph_def.ops.add().CopyFrom(pb_op)
    return graph_def
def clone_graph(root, clone_id, shared_queues_idx, parallel_axis, num_clones):
    """
    Clone the subgraph rooted at ``root`` via a serialize/deserialize round trip.

    Each freshly cloned op is handled as follows:
    - Skipped entirely if its original already has a clone registered for
      ``clone_id`` (in ``metadata['clones']``).
    - Otherwise it is retargeted to device ``clone_id``.
    - Its axes containing ``parallel_axis`` are rewritten with ``set_parallel_axes``.
    - Args whose original already has a clone for ``clone_id`` are redirected
      to that clone.
    - The op is recorded as its original's clone for ``clone_id``, unless it
      is the new root.
    - It gets a fresh uuid.

    Args:
        root: Root op of the subgraph to clone.
        clone_id: Device id the clone is assigned to.
        shared_queues_idx: Queue index given to cloned communication ops.
        parallel_axis: Axis being split across devices.
        num_clones: Number of clones (not used in this body).

    Returns:
        (new_root, new_send_nodes, replaced_send_nodes): the cloned root, the
        extra CPUQueueSendOps created for cloned queue-recv ops, and the
        original send ops those new senders stand in for.

    Raises:
        ValueError: a DotOp's axes contain ``parallel_axis`` but neither of its
            out-axes do.
    """
    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))

    # The copy still carries the original uuids, so it can be matched back to
    # the originals by uuid until each op's uuid is regenerated below.
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}

    # Prune ops that are not control_deps of new_gather_send_op
    # deserialize includes extra referenced nodes
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        cloned_ops = orig_ops[op.uuid].metadata.get('clones')
        # Only process ops whose original has no clone for this device yet.
        if cloned_ops is None or cloned_ops.get(str(clone_id)) is None:
            op.metadata['transformer'] = op.metadata['device'] + str(clone_id)
            op.metadata['device_id'] = str(clone_id)
            if isinstance(
                    op, (ScatterRecvOp, GatherSendOp, AllReduceOp, BroadcastRecvOp)):
                op._shared_queues = orig_ops[op.uuid]._shared_queues
                op.idx = shared_queues_idx
                if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                    op._send_node = orig_ops[op.uuid].send_node()
            elif isinstance(op, (CPUQueueRecvOp, GPUQueueRecvOp)):
                # Cloning a recv node means we need a broadcast, so simulate one by adding an
                # additional sender with the same input data as the original sender.
                send_op = CPUQueueSendOp(orig_ops[op.uuid].send_node().args[0])
                op._queue = send_op.queue
                op._send_node = send_op
                new_send_nodes.add(send_op)
                replaced_send_nodes.add(orig_ops[op.uuid].send_node())

            if hasattr(
                    op, 'reduction_axes') and parallel_axis in op.reduction_axes:
                op.reduction_axes = set_parallel_axes(op.reduction_axes, parallel_axis)

            if getattr(op, 'axes', None) is not None \
                    and parallel_axis in Axes.as_flattened_list(op.axes):
                # if parallel_axis in Axes.as_flattened_list(op.axes):
                op._axes = set_parallel_axes(op.axes, parallel_axis)
                if isinstance(op, DotOp):
                    if parallel_axis in op.x_out_axes:
                        op.x_out_axes = set_parallel_axes(
                            op.x_out_axes, parallel_axis)
                    elif parallel_axis in op.y_out_axes:
                        op.y_out_axes = set_parallel_axes(
                            op.y_out_axes, parallel_axis)
                    else:
                        raise ValueError("Missing parallel_axis in Op's "
                                         "x_out_axes or y_out_axes")

            if isinstance(op, TensorValueOp) and parallel_axis in op.tensor.axes:
                op.tensor._axes = set_parallel_axes(op.tensor.axes, parallel_axis)

            # Args processed earlier in this loop already carry fresh uuids, so
            # only args still holding an original uuid (i.e. skipped as already
            # cloned) are looked up and redirected to their existing clone.
            args_list = list(op.args)
            for arg_idx, arg_op in enumerate(args_list):
                if arg_op.uuid in orig_ops.keys():
                    if orig_ops[arg_op.uuid].metadata.get('clones') and \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id)):
                        args_list[arg_idx] = \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id))

            op.invalidate_property_cache('all_deps')
            op._args = tuple(args_list)
            # Record this op as the original's clone for this device so later
            # calls reuse it instead of cloning again.
            if op != new_root:
                if orig_ops[op.uuid].metadata.get('clones') is None:
                    orig_ops[op.uuid].metadata['clones'] = dict()
                    orig_ops[op.uuid].metadata['clones'][str(clone_id)] = op
                else:
                    orig_ops[op.uuid].metadata['clones'][str(clone_id)] = op

            op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
def clone_graph(root, clone_id, parallel_axis):
    """
    Clone the subgraph rooted at ``root`` via a serialize/deserialize round trip.

    Each freshly cloned op is handled as follows:
    - Skipped entirely if its original already has a clone registered for
      ``clone_id`` (in ``metadata['clones']``).
    - Otherwise it is retargeted to device ``clone_id``; communication ops get
      ``idx = int(clone_id)``.
    - Its axes containing ``parallel_axis`` are rewritten with ``set_parallel_axes``.
    - Args whose original already has a clone are redirected to that clone.
    - The op is registered as a clone, unless it is the new root.
    - It gets a fresh uuid.

    Finally, every op referenced from the new root is given a fresh uuid.

    Args:
        root: Root op of the subgraph to clone.
        clone_id: Device id the clone is assigned to.
        parallel_axis: Axis being split across devices.

    Returns:
        (new_root, new_send_nodes, replaced_send_nodes). In this version both
        sets are never populated and are always returned empty.

    Raises:
        ValueError: a DotOp's axes contain ``parallel_axis`` but neither of its
            out-axes do.
    """
    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))

    # The copy still carries the original uuids, so it can be matched back to
    # the originals by uuid until each op's uuid is regenerated below.
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        cloned_ops = orig_ops[op.uuid].metadata.get('clones')
        # Only process ops whose original has no clone for this device yet.
        if cloned_ops is None or cloned_ops.get(str(clone_id)) is None:
            op.metadata['transformer'] = op.metadata['device'] + str(clone_id)
            op.metadata['device_id'] = str(clone_id)
            if isinstance(
                    op, (ScatterRecvOp, GatherSendOp, AllReduceOp, BroadcastRecvOp)):
                # for gpu communication op buffer
                op.idx = int(clone_id)
                if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                    op._send_node = orig_ops[op.uuid].send_node()

            if hasattr(
                    op, 'reduction_axes') and parallel_axis in op.reduction_axes:
                op.reduction_axes = set_parallel_axes(op.reduction_axes, parallel_axis)

            if getattr(op, 'axes', None) is not None \
                    and parallel_axis in Axes.as_flattened_list(op.axes):
                # if parallel_axis in Axes.as_flattened_list(op.axes):
                op._axes = set_parallel_axes(op.axes, parallel_axis)
                if isinstance(op, DotOp):
                    if parallel_axis in op.x_out_axes:
                        op.x_out_axes = set_parallel_axes(
                            op.x_out_axes, parallel_axis)
                    elif parallel_axis in op.y_out_axes:
                        op.y_out_axes = set_parallel_axes(
                            op.y_out_axes, parallel_axis)
                    else:
                        raise ValueError("Missing parallel_axis in Op's "
                                         "x_out_axes or y_out_axes")

            if isinstance(op, TensorValueOp) and parallel_axis in op.tensor.axes:
                op.tensor._axes = set_parallel_axes(op.tensor.axes, parallel_axis)

            # Args processed earlier in this loop already carry fresh uuids, so
            # only args still holding an original uuid (skipped as already
            # cloned) are redirected to their existing clone.
            args_list = list(op.args)
            for arg_idx, arg_op in enumerate(args_list):
                if arg_op.uuid in orig_ops.keys():
                    if orig_ops[arg_op.uuid].metadata.get('clones') and \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id)):
                        args_list[arg_idx] = \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id))

            op.invalidate_property_cache('all_deps')
            op._args = tuple(args_list)
            # Record this op as the original's clone for this device.
            if op != new_root:
                if orig_ops[op.uuid].metadata.get('clones') is None:
                    orig_ops[op.uuid].metadata['clones'] = dict()
                    orig_ops[op.uuid].metadata['clones'][str(clone_id)] = op
                else:
                    orig_ops[op.uuid].metadata['clones'][str(clone_id)] = op

            op.uuid = uuid.uuid4()

    # create new uuids for all the ops that have references to the new root
    for _op in Op.all_op_references([new_root]):
        _op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
def do_pass(self, ops):
    """Return how many ops ``Op.ordered_ops`` produces for ``ops``."""
    ordered = Op.ordered_ops(ops)
    return len(ordered)
def __init__(self, hetr, computation_op):
    """
    Set up a heterogeneous computation and its per-worker child computations.

    Steps:
    1. Wrap device-split results in ResultOps.
    2. Run the hetr graph passes.
    3. Launch the child processes.
    4. Serialize the whole graph once, in chunks.
    5. Ask every child transformer to build its part of the computation.

    Args:
        hetr: The parent hetr transformer.
        computation_op: The ComputationOp describing returns and parameters.
    """
    # Container moved to collections.abc (the collections alias was removed in
    # Python 3.10); fall back for interpreters without collections.abc.
    try:
        from collections.abc import Container as _Container
    except ImportError:
        from collections import Container as _Container

    self.child_computations = dict()
    self.transformer = hetr
    # clear send_nodes for multiple computations
    if hetr.send_nodes:
        hetr.send_nodes.clear()
    self.send_nodes = hetr.send_nodes
    self.computation_op = computation_op

    # self.returns could be replaced by comp_op.returns if it were expressed as a set
    self.returns = OrderedSet()
    if isinstance(computation_op.returns, _Container):
        self.returns.update(list(computation_op.returns))
    elif isinstance(computation_op.returns, Op):
        self.returns.update(list([computation_op.returns]))

    # if one of the requested results is marked as distributed across devices,
    # wrap it in a ResultOp to facilitate DistributedPass inserting a gather operation
    new_returns = OrderedSet()
    for op in self.returns:
        if 'device_id' in op.metadata and \
                isinstance(op.metadata['device_id'], (list, tuple)):
            op.metadata['is_split_op'] = True
            new_result = ResultOp(device_id=0, args=tuple([op]))
            op.metadata['hetr_replaced_by'] = new_result
            new_result.metadata['replaces_op'] = op
            new_returns.add(new_result)
        else:
            new_returns.add(op)

    # Do Hetr passes
    logger.info('Running graph passes')
    pass_ops = new_returns | OrderedSet(self.computation_op.parameters)
    for graph_pass in self.transformer.graph_passes:
        # passes may add send nodes; include them in the ops given to the next pass
        pass_ops = pass_ops | OrderedSet(hetr.send_nodes)
        graph_pass.do_pass(ops=pass_ops)

    # hack around new TensorValueOp that wraps AssignableTensorOp
    # autogenerated by creating a ComputationOp:
    for p in self.computation_op.parameters:
        if isinstance(p, TensorValueOp):
            p.metadata.update(p.states_read[0].metadata)

    logger.info('Launching child processes')
    # assume all children are the same type
    # and all GPUs are in one chassis
    num_process = len(self.transformer.child_transformers)
    ppn = 1 if self.transformer.default_device == 'cpu' else num_process
    self.transformer.mpilauncher.launch(num_process, ppn)
    self.transformer.setup_child_transformers(num_process)

    def is_my_op(op, name):
        # True when the op's 'transformer' metadata equals, or contains, name.
        op_trans = op.metadata['transformer']
        return name == op_trans or name in op_trans

    logger.info('Serializaing computation graph')
    # build whole_graph once to avoid slow serialization once per worker
    # split whole pb message into list of smaller chunks
    # gRPC prefers sending smaller messages
    placeholders = [p for p in self.computation_op.parameters]
    all_returns = [o for o in self.send_nodes | new_returns]
    transform_returns = [
        o.args[0] if isinstance(o, ResultOp) else o for o in all_returns
    ]
    whole_graph = Op.all_op_references(transform_returns + placeholders)

    pb_whole_graph = []
    pb_ops, pb_edges = [], []
    last_idx = len(whole_graph) - 1
    for i, o in enumerate(whole_graph):
        pb_ops.append(op_to_protobuf(o))
        add_edges(pb_edges, pb_ops, o)
        # Flush every _OPS_PER_MSG ops (and after the last op); testing (i + 1)
        # avoids putting _OPS_PER_MSG + 1 ops in the first chunk.
        if (i + 1) % _OPS_PER_MSG == 0 or i == last_idx:
            pb_whole_graph.append((pb_ops, pb_edges))
            pb_ops, pb_edges = [], []

    t_placeholders, t_returns = {}, {}
    for t_name in self.transformer.child_transformers.keys():
        t_placeholders[t_name] = [
            p for p in placeholders if is_my_op(p, t_name)
        ]
        t_returns[t_name] = [r for r in all_returns if is_my_op(r, t_name)]

    # create_computation is an async call using gPRC future
    # allowing child transformers to create computation simultaneously
    # get_computation waits the corresponding request to finish
    logger.info('Creating remote computations')
    for t_name, trans in iteritems(self.transformer.child_transformers):
        logger.debug('child transformer: {}'.format(t_name))
        trans.build_transformer()
        transform_ops = [
            r.args[0] if isinstance(r, ResultOp) else r for r in t_returns[t_name]
        ]
        trans.create_computation(pb_whole_graph, transform_ops,
                                 t_placeholders[t_name])

    for t_name, trans in iteritems(self.transformer.child_transformers):
        comp = trans.get_computation()
        comp.param_idx = [
            g_pos for g_pos, p in enumerate(self.computation_op.parameters)
            if is_my_op(p, t_name)
        ]

        # when there is a ResultOp, hack around it
        comp.returns = dict()
        for i, op in enumerate(t_returns[t_name]):
            if op in self.returns and 'hetr_replaced_by' not in op.metadata:
                comp.returns[op] = i
            elif 'replaces_op' in op.metadata and op.metadata[
                    'replaces_op'] in self.returns:
                comp.returns[op.metadata['replaces_op']] = i
        self.child_computations[t_name] = comp