def Computation(self, request_iterator, context):
    """Build a transformer computation from a streamed, serialized graph.

    Drains the request stream into flat op/edge collections plus
    return/placeholder marker ops, rebuilds the subgraph, matches markers
    to reconstructed ops by uuid, and registers the resulting computation
    under a fresh comp_id.  Replies with comp_id=-1 and a diagnostic
    message when the transformer is missing or anything raises.
    """
    logger.info("server: computation")
    if not self.transformer:
        return hetr_pb2.ComputationReply(
            comp_id=-1, message="build transformer before computation")
    try:
        comp_id = self.new_comp_id()

        # Accumulate everything from the request stream before rebuilding.
        all_ops, all_edges = [], []
        marker_returns, marker_placeholders = [], []
        for chunk in request_iterator:
            all_ops.extend(chunk.ops)
            all_edges.extend(chunk.edges)
            marker_returns.extend(protobuf_to_op(o) for o in chunk.returns)
            marker_placeholders.extend(
                protobuf_to_op(o) for o in chunk.placeholders)

        subgraph = _deserialize_graph_ops_edges(all_ops, all_edges)
        ordered = Op.ordered_ops(subgraph)

        # Match markers to reconstructed ops by uuid, preserving marker
        # order (and any duplicate matches, as the original scan did).
        reconstructed_returns = [
            op for r in marker_returns
            for op in ordered if op.uuid == r.uuid]
        reconstructed_placeholders = [
            op for p in marker_placeholders
            for op in ordered if op.uuid == p.uuid]

        computation = self.transformer.computation(
            reconstructed_returns, *reconstructed_placeholders)
        self.computations[comp_id] = computation
        return hetr_pb2.ComputationReply(comp_id=comp_id)
    except Exception:
        return hetr_pb2.ComputationReply(comp_id=-1,
                                         message=traceback.format_exc())
def Computation(self, request, context):
    """Build a transformer computation from a single serialized subgraph.

    Deserializes ``request.subgraph``, matches the request's return and
    placeholder marker ops to the reconstructed ops by uuid, and registers
    the computation under a fresh comp_id.  Replies with comp_id=-1 and a
    diagnostic message on failure (matching the streaming variant of this
    handler, which previously this unary variant omitted).
    """
    if not self.transformer:
        # Tell the client why the request was rejected instead of
        # returning a bare -1.
        return hetr_pb2.ComputationReply(
            comp_id=-1, message="build transformer before computation")
    try:
        comp_id = self.new_comp_id()
        subgraph = _deserialize_graph(request.subgraph)

        returns = [protobuf_to_op(pb_op) for pb_op in request.returns]
        placeholders = [protobuf_to_op(pb_op)
                        for pb_op in request.placeholders]

        ops = Op.ordered_ops(subgraph)
        # Match by uuid in marker order so the computation's outputs and
        # inputs line up with the order the client requested, consistent
        # with the streaming variant of this handler (the previous
        # ops-outer scan yielded graph order instead).
        return_list = [op for r in returns
                       for op in ops if op.uuid == r.uuid]
        placeholder_list = [op for p in placeholders
                            for op in ops if op.uuid == p.uuid]

        computation = self.transformer.computation(return_list,
                                                   *placeholder_list)
        self.computations[comp_id] = computation
        return hetr_pb2.ComputationReply(comp_id=comp_id)
    except Exception:
        # Was a bare 'except:', which also swallowed KeyboardInterrupt/
        # SystemExit and hid every failure from the client; report the
        # traceback like the sibling handlers do.
        return hetr_pb2.ComputationReply(comp_id=-1,
                                         message=traceback.format_exc())
def Computation(self, request_iterator, context):
    """Build a transformer computation from a streamed, serialized graph,
    wiring extra dependencies for MPI gather/scatter recv ops first.

    Drains the request stream, rebuilds the subgraph, adds send-op
    argument dependencies to recv ops on the root device, matches
    return/placeholder markers to reconstructed ops by uuid, and registers
    the computation under a fresh comp_id.  Replies with comp_id=-1 and a
    diagnostic message on failure.
    """
    logger.debug("server: computation")
    if not self.transformer:
        return hetr_pb2.ComputationReply(
            comp_id=-1, message="build transformer before computation")
    try:
        comp_id = self.new_comp_id()

        # Accumulate everything from the request stream before rebuilding.
        all_ops, all_edges = [], []
        marker_returns, marker_placeholders = [], []
        for chunk in request_iterator:
            all_ops.extend(chunk.ops)
            all_edges.extend(chunk.edges)
            marker_returns.extend(protobuf_to_op(o) for o in chunk.returns)
            marker_placeholders.extend(
                protobuf_to_op(o) for o in chunk.placeholders)

        subgraph = _deserialize_graph_ops_edges(all_ops, all_edges)

        # On the root device, make each recv op depend on its send op's
        # arguments in the scenarios where the send buffer is passed as an
        # argument to the communication call (gather/scatter).  This keeps
        # the send buffer from being reused before the recv buffer has
        # read its items.
        root_idx = 0
        my_rank = MPI.COMM_WORLD.Get_rank()
        for op in Op.all_op_references(subgraph):
            gather_on_root = (isinstance(op, GatherRecvOp)
                              and my_rank == op.metadata['device_id'])
            scatter_on_root = (isinstance(op, ScatterRecvOp)
                               and my_rank == root_idx
                               and my_rank in op.metadata['device_id'])
            if gather_on_root or scatter_on_root:
                op._args = tuple(list(op._args) + list(op.send_node().args))
                op.invalidate_property_cache('all_deps')

        # Order ops only after the dependency rewiring above, so the new
        # deps are reflected in the ordering.
        ordered = Op.ordered_ops(subgraph)
        reconstructed_returns = [
            op for r in marker_returns
            for op in ordered if op.uuid == r.uuid]
        reconstructed_placeholders = [
            op for p in marker_placeholders
            for op in ordered if op.uuid == p.uuid]

        computation = self.transformer.computation(
            reconstructed_returns, *reconstructed_placeholders)
        self.computations[comp_id] = computation
        return hetr_pb2.ComputationReply(comp_id=comp_id)
    except Exception:
        return hetr_pb2.ComputationReply(comp_id=-1,
                                         message=traceback.format_exc())