def test_all_op_references():
    """all_op_references from a dependent op reaches its inputs, not vice versa."""
    base_op, simple_graph = get_simple_graph()

    # Traversal rooted at the dependent op must include both ops.
    from_graph = Op.all_op_references([simple_graph])
    assert base_op in from_graph
    assert simple_graph in from_graph

    # Traversal rooted at the base op must not reach the dependent op.
    from_base = Op.all_op_references([base_op])
    assert base_op in from_base
    assert simple_graph not in from_base
def Computation(self, request_iterator, context):
    """Build a transformer computation from a streamed, chunked graph request.

    Reassembles the serialized op graph from the request chunks, adds
    send-buffer dependencies to gather/scatter recv ops on the relevant MPI
    rank, then registers the resulting computation under a fresh comp_id.

    :param request_iterator: stream of requests, each carrying ops, edges,
        returns and placeholders chunks
    :param context: gRPC servicer context (unused here)
    :return: hetr_pb2.ComputationReply with the new comp_id, or comp_id=-1
        and an error message on failure
    """
    logger.debug("server: computation")
    if not self.transformer:
        return hetr_pb2.ComputationReply(
            comp_id=-1,
            message="build transformer before computation")
    try:
        comp_id = self.new_comp_id()
        pb_ops, pb_edges = [], []
        returns, placeholders = [], []
        # The graph arrives split across several messages; accumulate all
        # pieces before deserializing.
        for request in request_iterator:
            pb_ops.extend(request.ops)
            pb_edges.extend(request.edges)
            returns.extend([protobuf_to_op(op) for op in request.returns])
            placeholders.extend(
                [protobuf_to_op(op) for op in request.placeholders])
        subgraph = _deserialize_graph_ops_edges(pb_ops, pb_edges)

        # Add dependency on recv op to their send op in scenarios where the send buffer
        # is passed as an argument to the communication call (gather/scatter)
        # on the root device.
        # This ensures that by the send buffer does not get reused before
        # the recv_buf gets access to items
        root_idx = 0
        for op in Op.all_op_references(subgraph):
            if isinstance(op, (GatherRecvOp)) and \
                    MPI.COMM_WORLD.Get_rank() == op.metadata['device_id']:
                args = list(op._args)
                args.extend(op.send_node().args)
                op._args = tuple(args)
                op.invalidate_property_cache('all_deps')
            elif (isinstance(op, (ScatterRecvOp)) and
                  MPI.COMM_WORLD.Get_rank() == root_idx and
                  MPI.COMM_WORLD.Get_rank() in op.metadata['device_id']):
                args = list(op._args)
                args.extend(op.send_node().args)
                op._args = tuple(args)
                op.invalidate_property_cache('all_deps')

        # The ops decoded from request.returns/placeholders are distinct
        # objects from the deserialized graph ops; match them by uuid.
        # A single dict lookup replaces the original nested
        # O(len(returns) * len(ops)) scans.
        ops_by_uuid = {op.uuid: op for op in Op.ordered_ops(subgraph)}
        reconstructed_returns = [
            ops_by_uuid[r.uuid] for r in returns if r.uuid in ops_by_uuid
        ]
        reconstructed_placeholders = [
            ops_by_uuid[p.uuid] for p in placeholders if p.uuid in ops_by_uuid
        ]

        computation = self.transformer.computation(
            reconstructed_returns, *reconstructed_placeholders)
        self.computations[comp_id] = computation
        return hetr_pb2.ComputationReply(comp_id=comp_id)
    except Exception:
        # RPC boundary: report the traceback to the client rather than
        # letting the exception escape the servicer.
        return hetr_pb2.ComputationReply(comp_id=-1,
                                         message=traceback.format_exc())
def do_pass(self, ops, **kwargs):
    """Render the op graph with graphviz, optionally clustered into subgraphs.

    When ``self.subgraph_attr`` is set, ops sharing that metadata value are
    grouped into colored clusters; otherwise all ops go into a single
    digraph.  The result is rendered into a temp directory (and opened when
    ``self.view`` is set).

    :param ops: iterable of root ops; the full graph is gathered via
        Op.all_op_references
    :return: ``ops`` unchanged, so the pass can be chained
    :raises ImportError: if the graphviz python package is not installed
    """
    try:
        import graphviz
    except ImportError:
        raise ImportError(
            "You tried to use the ShowGraph transformer pass but did "
            "not have the python graphviz library installed")
    # Get all ops and edges from this set
    all_ops = Op.all_op_references(ops)
    all_edges = ser._serialize_graph(ops).edges
    vg = graphviz.Digraph(node_attr={'shape': 'box'},
                          graph_attr={
                              'nodesep': '.5',
                              'ranksep': '.5'
                          })
    if self.subgraph_attr is not None:
        # One cluster per distinct metadata value, each with a random color.
        subgraphs = {}
        for subgraph_name in self.get_subgraphs(all_ops):
            if subgraph_name not in subgraphs and subgraph_name is not None:
                sg = graphviz.Digraph(
                    name='cluster_{}'.format(subgraph_name))
                sg.body.append('color="{}"'.format(self.random_color()))
                sg.body.append('style=filled')
                sg.body.append('label="{}"'.format(subgraph_name))
                subgraphs[subgraph_name] = sg
        for op in all_ops:
            subgraph_name = op.metadata.get(self.subgraph_attr, '')
            # hack to show hetr graph, a tuple becomes a list after clone_graph
            if isinstance(subgraph_name, list):
                subgraph_name = tuple(subgraph_name)
            if subgraph_name in subgraphs:
                graph = subgraphs[subgraph_name]
            else:
                graph = vg
            self.add_op_to_graph(op, graph)
        for sg in subgraphs.values():
            vg.subgraph(sg)
    else:
        for op in all_ops:
            self.add_op_to_graph(op, vg)
    for edge in all_edges:
        self.add_edge_to_graph(edge, vg)
    tmp_dir = tempfile.mkdtemp()
    vg.render(directory=tmp_dir, view=self.view, cleanup=True)
    if not self.view:
        # BUG FIX: logging uses %-style lazy formatting, not str.format;
        # the original "{}" placeholder was never interpolated.
        logging.info("VizPass graph rendered to %s", tmp_dir)
    # Cleanup
    self.uuid_lookup_table.clear()
    return ops
def test_hetr_send_recv_graph_serialization():
    """Round-trip hetr send/recv comm ops through serde, compare node-by-node."""
    graph_nodes = create_send_recv_graph()
    z = graph_nodes[0]  # root; all other returned nodes are reachable from it

    round_tripped = ser.deserialize_graph(ser.serialize_graph([z]))
    expected_ops = Op.all_op_references([z])

    # Sort both node sets by uuid so the pairwise comparison lines up.
    restored_sorted = sorted(round_tripped, key=lambda op: op.uuid)
    expected_sorted = sorted(expected_ops, key=lambda op: op.uuid)
    for restored, expected in zip(restored_sorted, expected_sorted):
        assert_object_equality(restored, expected)
def test_full_graph_serialization_endtoend():
    """Serialize a simple graph end to end and check node-for-node equality."""
    base_op, simple_graph = get_simple_graph()

    round_tripped = ser.deserialize_graph(ser.serialize_graph([simple_graph]))
    expected_ops = Op.all_op_references([simple_graph])

    # Checking the leaf nodes would recursively verify their ancestors, but
    # comparing the complete sorted node sets also ensures no ops were
    # dropped or invented by the round trip.
    restored_sorted = sorted(round_tripped, key=lambda op: op.uuid)
    expected_sorted = sorted(expected_ops, key=lambda op: op.uuid)
    for restored, expected in zip(restored_sorted, expected_sorted):
        assert_object_equality(restored, expected)
def generate_messages():
    """Yield computation requests covering the graph in small chunks.

    Serializes every op reachable from the enclosing scope's ``returns`` and
    ``placeholders``, emitting one request roughly every _OPS_PER_MSG ops and
    a final one for the tail.  Only the first request carries the
    returns/placeholders payload; later chunks send them empty.
    """
    op_buffer, edge_buffer = [], []
    return_buffer, placeholder_buffer = generate_returns_placeholders()
    all_ops = Op.all_op_references(returns + list(placeholders))
    last_index = len(all_ops) - 1
    for index, op in enumerate(all_ops):
        op_buffer.append(op_to_protobuf(op))
        add_edges(edge_buffer, op_buffer, op)
        flush = (index != 0 and index % _OPS_PER_MSG == 0) or index == last_index
        if flush:
            yield make_computation_request(op_buffer, edge_buffer,
                                           return_buffer, placeholder_buffer)
            op_buffer, edge_buffer = [], []
            return_buffer, placeholder_buffer = [], []
def _serialize_graph(ops):
    """Serialize a graph into a protobuf GraphDef python object.

    Unlike `serialize_graph`, which produces the serialized byte string,
    this returns the protobuf message itself.
    """
    assert isinstance(
        ops, Iterable), "Ops passed into `serialize_graph` must be an iterable"
    all_ops = Op.all_op_references(ops)

    serialized_ops = []
    serialized_edges = []
    for op in all_ops:
        serialized_ops.append(op_to_protobuf(op))
        add_edges(serialized_edges, serialized_ops, op)

    # Copy accumulated messages into the GraphDef container.
    graph_def = ops_pb.GraphDef()
    for pb_edge in serialized_edges:
        graph_def.edges.add().CopyFrom(pb_edge)
    for pb_op in serialized_ops:
        graph_def.ops.add().CopyFrom(pb_op)
    return graph_def
def clone_graph(root, clone_id, parallel_axis):
    """
    Clone the graph rooted at ``root`` via serde (serialize + deserialize).

    :param root: op at the root of the subgraph to clone
    :param clone_id: id of the device the clone is made for; stamped into
        each cloned op's transformer/device_id metadata
    :param parallel_axis: axis replaced (via set_parallel_axes) wherever it
        appears in a cloned op's axes, reduction_axes, or DotOp out_axes
    :return: (new_root, new_send_nodes, replaced_send_nodes)

    NOTE(review): new_send_nodes and replaced_send_nodes are returned empty;
    nothing in this function adds to them — confirm callers expect that.
    """
    # clone nodes with GatherSendOp as root using serde
    ser_cloned_nodes = deserialize_graph(serialize_graph([root]))
    # the clone initially shares the original's uuid; all clone uuids are
    # regenerated before returning
    new_root = next((o for o in ser_cloned_nodes if o.uuid == root.uuid), None)

    # originals keyed by uuid, used for clone bookkeeping lookups below
    orig_ops = {op.uuid: op for op in Op.ordered_ops([root])}
    cloned_graph = Op.ordered_ops([new_root])

    new_send_nodes = OrderedSet()
    replaced_send_nodes = OrderedSet()

    # update newly cloned op metadata, generate new UUIDs
    for op in cloned_graph:
        # skip ops already cloned for this clone_id in a previous call
        cloned_ops = orig_ops[op.uuid].metadata.get('clones')
        if cloned_ops is None or cloned_ops.get(str(clone_id)) is None:
            op.metadata['transformer'] = op.metadata['device'] + str(clone_id)
            op.metadata['device_id'] = str(clone_id)

            if isinstance(
                    op,
                    (ScatterRecvOp, GatherSendOp, AllReduceOp, BroadcastRecvOp)):
                # for gpu communication op buffer
                op.idx = int(clone_id)
                if isinstance(op, (ScatterRecvOp, BroadcastRecvOp)):
                    # point the cloned recv at the ORIGINAL op's send node
                    op._send_node = orig_ops[op.uuid].send_node()

            # substitute parallel_axis wherever it appears in axis collections
            if hasattr(op, 'reduction_axes') and parallel_axis in op.reduction_axes:
                op.reduction_axes = set_parallel_axes(op.reduction_axes,
                                                      parallel_axis)

            if getattr(op, 'axes', None) is not None \
                    and parallel_axis in Axes.as_flattened_list(op.axes):
                # if parallel_axis in Axes.as_flattened_list(op.axes):
                op._axes = set_parallel_axes(op.axes, parallel_axis)

            if isinstance(op, DotOp):
                if parallel_axis in op.x_out_axes:
                    op.x_out_axes = set_parallel_axes(op.x_out_axes,
                                                      parallel_axis)
                elif parallel_axis in op.y_out_axes:
                    op.y_out_axes = set_parallel_axes(op.y_out_axes,
                                                      parallel_axis)
                else:
                    raise ValueError("Missing parallel_axis in Op's "
                                     "x_out_axes or y_out_axes")

            if isinstance(op, TensorValueOp) and parallel_axis in op.tensor.axes:
                op.tensor._axes = set_parallel_axes(op.tensor.axes,
                                                    parallel_axis)

            # rewire args to clones recorded for this clone_id where they exist
            args_list = list(op.args)
            for arg_idx, arg_op in enumerate(args_list):
                if arg_op.uuid in orig_ops.keys():
                    if orig_ops[arg_op.uuid].metadata.get('clones') and \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id)):
                        args_list[arg_idx] = \
                            orig_ops[arg_op.uuid].metadata['clones'].get(str(clone_id))

            # args changed, so the cached dependency set must be rebuilt
            op.invalidate_property_cache('all_deps')
            op._args = tuple(args_list)
            if op != new_root:
                # record this clone on the original op so later calls reuse it
                if orig_ops[op.uuid].metadata.get('clones') is None:
                    orig_ops[op.uuid].metadata['clones'] = dict()
                    orig_ops[op.uuid].metadata['clones'][str(clone_id)] = op
                else:
                    orig_ops[op.uuid].metadata['clones'][str(clone_id)] = op

            op.uuid = uuid.uuid4()

    # create new uuids for all the ops that have references to the new root
    for _op in Op.all_op_references([new_root]):
        _op.uuid = uuid.uuid4()

    return new_root, new_send_nodes, replaced_send_nodes
def __init__(self, hetr, computation_op):
    """Build a HeTr computation.

    Runs the transformer's graph passes, launches the child processes via
    MPI, serializes the whole graph once into small chunks, and creates one
    remote computation per child transformer.

    :param hetr: the parent HeTr transformer; supplies send_nodes,
        graph_passes, child_transformers and the MPI launcher
    :param computation_op: op carrying the requested returns and parameters
    """
    self.child_computations = dict()
    self.transformer = hetr
    # clear send_nodes for multiple computations
    if hetr.send_nodes:
        hetr.send_nodes.clear()
    self.send_nodes = hetr.send_nodes
    self.computation_op = computation_op

    # self.returns could be replaced by comp_op.returns if it were expressed as a set
    self.returns = OrderedSet()
    if isinstance(computation_op.returns, collections.Container):
        self.returns.update(list(computation_op.returns))
    elif isinstance(computation_op.returns, Op):
        self.returns.update(list([computation_op.returns]))

    # if one of the requested results is marked as distributed across devices,
    # wrap it in a ResultOp to facilitate DistributedPass inserting a gather operation
    new_returns = OrderedSet()
    for op in self.returns:
        if 'device_id' in op.metadata and \
                isinstance(op.metadata['device_id'], (list, tuple)):
            op.metadata['is_split_op'] = True
            new_result = ResultOp(device_id=0, args=tuple([op]))
            # record the replacement in both directions so results can later
            # be mapped back to the op the caller asked for
            op.metadata['hetr_replaced_by'] = new_result
            new_result.metadata['replaces_op'] = op
            new_returns.add(new_result)
        else:
            new_returns.add(op)

    # Do Hetr passes
    # NOTE(review): the trailing comma makes this statement a 1-tuple
    # expression; harmless, but probably unintended (also below).
    logger.info('Running graph passes'),
    pass_ops = new_returns | OrderedSet(self.computation_op.parameters)
    for graph_pass in self.transformer.graph_passes:
        # a pass may add send nodes; fold them into the op set each iteration
        pass_ops = pass_ops | OrderedSet(hetr.send_nodes)
        graph_pass.do_pass(ops=pass_ops)

    # hack around new TensorValueOp that wraps AssignableTensorOp
    # autogenerated by creating a ComputationOp:
    for p in self.computation_op.parameters:
        if isinstance(p, TensorValueOp):
            p.metadata.update(p.states_read[0].metadata)

    logger.info('Launching child processes'),
    # assume all children are the same type
    # and all GPUs are in one chassis
    num_process = len(self.transformer.child_transformers)
    ppn = 1 if self.transformer.default_device == 'cpu' else num_process
    self.transformer.mpilauncher.launch(num_process, ppn)
    self.transformer.setup_child_transformers(num_process)

    def is_my_op(op, name):
        # op belongs to transformer `name` if its metadata matches exactly or
        # contains it (op_trans may be a string or a collection of names)
        op_trans = op.metadata['transformer']
        return name == op_trans or name in op_trans

    # NOTE(review): 'Serializaing' is a typo in the log message; left intact
    # because this change only touches comments.
    logger.info('Serializaing computation graph'),
    # build whole_graph once to avoid slow serialization once per worker
    # split whole pb message into list of smaller chunks
    # gRPC prefers sending smaller messages
    placeholders = [p for p in self.computation_op.parameters]
    all_returns = [o for o in self.send_nodes | new_returns]
    # unwrap ResultOps for serialization; children work on the inner op
    transform_returns = [
        o.args[0] if isinstance(o, ResultOp) else o for o in all_returns
    ]
    whole_graph = Op.all_op_references(transform_returns + placeholders)

    pb_whole_graph = []
    pb_ops, pb_edges = [], []
    for i, o in enumerate(whole_graph):
        pb_ops.append(op_to_protobuf(o))
        add_edges(pb_edges, pb_ops, o)
        # flush a chunk every _OPS_PER_MSG ops, and always flush the tail
        if (i != 0 and i % _OPS_PER_MSG == 0) or (i == len(whole_graph) - 1):
            pb_whole_graph.append((pb_ops, pb_edges))
            pb_ops, pb_edges = [], []

    # partition placeholders and returns by owning child transformer
    t_placeholders, t_returns = {}, {}
    for t_name in self.transformer.child_transformers.keys():
        t_placeholders[t_name] = [
            p for p in placeholders if is_my_op(p, t_name)
        ]
        t_returns[t_name] = [r for r in all_returns if is_my_op(r, t_name)]

    # create_computation is an async call using gPRC future
    # allowing child transformers to create computation simultaneously
    # get_computation waits the corresponding request to finish
    logger.info('Creating remote computations'),
    for t_name, trans in iteritems(self.transformer.child_transformers):
        logger.debug('child transformer: {}'.format(t_name))
        trans.build_transformer()
        transform_ops = [
            r.args[0] if isinstance(r, ResultOp) else r
            for r in t_returns[t_name]
        ]
        trans.create_computation(pb_whole_graph, transform_ops,
                                 t_placeholders[t_name])
    for t_name, trans in iteritems(self.transformer.child_transformers):
        comp = trans.get_computation()
        # indices of this child's parameters in the global parameter list,
        # so argument values can be routed to the right child at call time
        comp.param_idx = [
            g_pos for g_pos, p in enumerate(self.computation_op.parameters)
            if is_my_op(p, t_name)
        ]

        # when there is a ResultOp, hack around it
        comp.returns = dict()
        for i, op in enumerate(t_returns[t_name]):
            if op in self.returns and 'hetr_replaced_by' not in op.metadata:
                comp.returns[op] = i
            elif 'replaces_op' in op.metadata and \
                    op.metadata['replaces_op'] in self.returns:
                comp.returns[op.metadata['replaces_op']] = i
        self.child_computations[t_name] = comp