Beispiel #1
0
    def Computation(self, request_iterator, context):
        """Build and register a computation from a streamed serialized subgraph.

        The ops, edges, return ops and placeholder ops of every streamed
        request are concatenated, the subgraph is deserialized, and each
        requested return/placeholder is resolved by uuid to the matching op of
        the reconstructed graph. The resulting computation is stored in
        ``self.computations`` under a freshly allocated id.

        Args:
            request_iterator: iterable of request messages, each carrying
                ``ops``, ``edges``, ``returns`` and ``placeholders``.
            context: RPC context; not used here (presumably a gRPC
                ``ServicerContext`` -- confirm).

        Returns:
            ``hetr_pb2.ComputationReply`` with the new ``comp_id`` on success.
            On failure ``comp_id`` is -1 and ``message`` explains why (the
            formatted traceback when an exception was raised).
        """
        logger.info("server: computation")
        if not self.transformer:
            return hetr_pb2.ComputationReply(
                comp_id=-1, message="build transformer before computation")
        try:
            comp_id = self.new_comp_id()
            pb_ops, pb_edges = [], []
            returns, placeholders = [], []
            for request in request_iterator:
                pb_ops.extend(request.ops)
                pb_edges.extend(request.edges)
                returns.extend(protobuf_to_op(op) for op in request.returns)
                placeholders.extend(
                    protobuf_to_op(op) for op in request.placeholders)

            subgraph = _deserialize_graph_ops_edges(pb_ops, pb_edges)
            # Index the reconstructed ops once so every lookup is O(1) instead
            # of rescanning all ops for each return/placeholder.
            ops_by_uuid = {op.uuid: op for op in Op.ordered_ops(subgraph)}

            # A silently dropped op would shift every later positional
            # placeholder argument, so fail loudly instead; the handler below
            # turns this into an error reply carrying the traceback.
            missing = [op.uuid for op in returns + placeholders
                       if op.uuid not in ops_by_uuid]
            if missing:
                raise ValueError(
                    "ops not found in deserialized subgraph: {}".format(
                        missing))

            # Resolve in the order the client sent them.
            reconstructed_returns = [ops_by_uuid[r.uuid] for r in returns]
            reconstructed_placeholders = [
                ops_by_uuid[p.uuid] for p in placeholders]

            computation = self.transformer.computation(
                reconstructed_returns, *reconstructed_placeholders)
            self.computations[comp_id] = computation
            return hetr_pb2.ComputationReply(comp_id=comp_id)
        except Exception:
            return hetr_pb2.ComputationReply(comp_id=-1,
                                             message=traceback.format_exc())
Beispiel #2
0
    def Computation(self, request, context):
        """Build and register a computation from a single serialized request.

        The subgraph in ``request.subgraph`` is deserialized, and each
        requested return/placeholder is resolved by uuid to the matching op of
        the reconstructed graph, in the order the client listed them. The
        resulting computation is stored in ``self.computations`` under a
        freshly allocated id.

        Args:
            request: message carrying ``subgraph``, ``returns`` and
                ``placeholders``.
            context: RPC context; not used here (presumably a gRPC
                ``ServicerContext`` -- confirm).

        Returns:
            ``hetr_pb2.ComputationReply`` with the new ``comp_id`` on success,
            or ``comp_id=-1`` if no transformer exists or building fails (the
            failure traceback is logged).
        """
        import logging

        if not self.transformer:
            return hetr_pb2.ComputationReply(comp_id=-1)

        try:
            comp_id = self.new_comp_id()
            subgraph = _deserialize_graph(request.subgraph)
            returns = [protobuf_to_op(pb_op) for pb_op in request.returns]
            placeholders = [protobuf_to_op(pb_op)
                            for pb_op in request.placeholders]

            ops_by_uuid = {op.uuid: op for op in Op.ordered_ops(subgraph)}

            # A silently dropped op would shift every later positional
            # placeholder argument, so refuse to build the computation.
            missing = [op.uuid for op in returns + placeholders
                       if op.uuid not in ops_by_uuid]
            if missing:
                raise ValueError(
                    "ops not found in deserialized subgraph: {}".format(
                        missing))

            # Iterate the *requested* lists so the client's ordering is kept;
            # placeholders are passed positionally to computation().
            return_list = [ops_by_uuid[r.uuid] for r in returns]
            placeholder_list = [ops_by_uuid[p.uuid] for p in placeholders]

            computation = self.transformer.computation(return_list,
                                                       *placeholder_list)
            self.computations[comp_id] = computation
            return hetr_pb2.ComputationReply(comp_id=comp_id)
        except Exception:
            # The reply carries no error text, so keep the traceback in the
            # server log rather than discarding it.
            logging.getLogger(__name__).exception(
                "server: computation failed")
            return hetr_pb2.ComputationReply(comp_id=-1)
Beispiel #3
0
    def Computation(self, request_iterator, context):
        """Build and register a computation from a streamed serialized subgraph.

        The ops, edges, return ops and placeholder ops of every streamed
        request are concatenated and the subgraph is deserialized. Gather/
        scatter recv ops on the root device are then given extra dependencies
        on their send op's arguments (see the comment below). Each requested
        return/placeholder is resolved by uuid to the matching op of the
        reconstructed graph, and the resulting computation is stored in
        ``self.computations`` under a freshly allocated id.

        Args:
            request_iterator: iterable of request messages, each carrying
                ``ops``, ``edges``, ``returns`` and ``placeholders``.
            context: RPC context; not used here (presumably a gRPC
                ``ServicerContext`` -- confirm).

        Returns:
            ``hetr_pb2.ComputationReply`` with the new ``comp_id`` on success.
            On failure ``comp_id`` is -1 and ``message`` explains why (the
            formatted traceback when an exception was raised).
        """
        logger.debug("server: computation")
        if not self.transformer:
            return hetr_pb2.ComputationReply(
                comp_id=-1, message="build transformer before computation")
        try:
            comp_id = self.new_comp_id()
            pb_ops, pb_edges = [], []
            returns, placeholders = [], []
            for request in request_iterator:
                pb_ops.extend(request.ops)
                pb_edges.extend(request.edges)
                returns.extend(protobuf_to_op(op) for op in request.returns)
                placeholders.extend(
                    protobuf_to_op(op) for op in request.placeholders)

            subgraph = _deserialize_graph_ops_edges(pb_ops, pb_edges)

            # When the send buffer is itself passed as an argument to the
            # communication call (gather/scatter) on the root device, make the
            # recv op also depend on its send op's arguments. This ensures the
            # send buffer is not reused before the recv buffer has read from
            # it.
            root_idx = 0  # rank treated as the scatter root
            for op in Op.all_op_references(subgraph):
                if not isinstance(op, (GatherRecvOp, ScatterRecvOp)):
                    continue
                # Only query MPI for recv ops, as before, so graphs without
                # collectives never touch it.
                rank = MPI.COMM_WORLD.Get_rank()
                is_root_recv = (
                    (isinstance(op, GatherRecvOp) and
                     rank == op.metadata['device_id']) or
                    (isinstance(op, ScatterRecvOp) and
                     rank == root_idx and
                     rank in op.metadata['device_id']))
                if is_root_recv:
                    op._args = tuple(op._args) + tuple(op.send_node().args)
                    # Cached dependency sets are now stale.
                    op.invalidate_property_cache('all_deps')

            # Index after patching, so ops reached through the added
            # arguments are included; lookups are then O(1) each.
            ops_by_uuid = {op.uuid: op for op in Op.ordered_ops(subgraph)}

            # A silently dropped op would shift every later positional
            # placeholder argument, so fail loudly instead; the handler below
            # turns this into an error reply carrying the traceback.
            missing = [op.uuid for op in returns + placeholders
                       if op.uuid not in ops_by_uuid]
            if missing:
                raise ValueError(
                    "ops not found in deserialized subgraph: {}".format(
                        missing))

            # Resolve in the order the client sent them.
            reconstructed_returns = [ops_by_uuid[r.uuid] for r in returns]
            reconstructed_placeholders = [
                ops_by_uuid[p.uuid] for p in placeholders]

            computation = self.transformer.computation(
                reconstructed_returns, *reconstructed_placeholders)
            self.computations[comp_id] = computation
            return hetr_pb2.ComputationReply(comp_id=comp_id)
        except Exception:
            return hetr_pb2.ComputationReply(comp_id=-1,
                                             message=traceback.format_exc())