Example #1
0
    def Computation(self, request_iterator, context):
        """Build a computation from a streamed, serialized subgraph.

        The chunks yielded by ``request_iterator`` are accumulated and
        deserialized into one subgraph. The requested return and placeholder
        ops are then looked up in that subgraph by uuid and passed to
        ``self.transformer.computation``. The resulting computation is stored
        under a fresh computation id.

        Arguments:
            request_iterator: iterable of request messages, each exposing
                ``ops``, ``edges``, ``returns`` and ``placeholders``.
            context: RPC context; not used by this method.

        Returns:
            hetr_pb2.ComputationReply: the new ``comp_id`` on success, or
            ``comp_id=-1`` with an explanatory ``message`` when no transformer
            has been built or when any step raises (the traceback is also
            logged server-side).
        """
        logger.info("server: computation")
        if not self.transformer:
            return hetr_pb2.ComputationReply(
                comp_id=-1, message="build transformer before computation")
        try:
            comp_id = self.new_comp_id()
            pb_ops, pb_edges = [], []
            returns, placeholders = [], []
            for request in request_iterator:
                pb_ops.extend(request.ops)
                pb_edges.extend(request.edges)
                returns.extend(protobuf_to_op(op) for op in request.returns)
                placeholders.extend(
                    protobuf_to_op(op) for op in request.placeholders)

            subgraph = _deserialize_graph_ops_edges(pb_ops, pb_edges)

            # Index the deserialized ops by uuid once instead of scanning
            # every op for every return/placeholder (O(N + R) vs O(N * R)).
            # Results follow the requested order; ops sharing a uuid keep
            # their graph order.
            ops_by_uuid = {}
            for op in Op.ordered_ops(subgraph):
                ops_by_uuid.setdefault(op.uuid, []).append(op)
            reconstructed_returns = [
                op for r in returns
                for op in ops_by_uuid.get(r.uuid, ())]
            reconstructed_placeholders = [
                op for p in placeholders
                for op in ops_by_uuid.get(p.uuid, ())]

            computation = self.transformer.computation(
                reconstructed_returns, *reconstructed_placeholders)
            self.computations[comp_id] = computation
            return hetr_pb2.ComputationReply(comp_id=comp_id)
        except Exception:
            # Report the failure to the client and keep a server-side record.
            logger.exception("server: computation failed")
            return hetr_pb2.ComputationReply(comp_id=-1,
                                             message=traceback.format_exc())
Example #2
0
    def Computation(self, request, context):
        """Build a computation from a serialized subgraph and store it.

        Arguments:
            request: message exposing ``subgraph`` (deserialized with
                ``_deserialize_graph``) plus the ``returns`` and
                ``placeholders`` ops to look up in it by uuid.
            context: RPC context; not used by this method.

        Returns:
            hetr_pb2.ComputationReply: the new ``comp_id`` on success;
            ``comp_id=-1`` if no transformer has been built or if building
            the computation raised (the traceback is logged server-side).
        """
        import logging

        if not self.transformer:
            return hetr_pb2.ComputationReply(comp_id=-1)

        try:
            comp_id = self.new_comp_id()
            subgraph = _deserialize_graph(request.subgraph)
            returns = [protobuf_to_op(pb_op) for pb_op in request.returns]
            placeholders = [protobuf_to_op(pb_op)
                            for pb_op in request.placeholders]

            # Index the deserialized ops by uuid once (O(N + R) lookups).
            ops_by_uuid = {}
            for op in Op.ordered_ops(subgraph):
                ops_by_uuid.setdefault(op.uuid, []).append(op)

            # Walk the *requested* ops in the outer loop so both lists keep
            # the caller's order. Placeholders are passed positionally to
            # transformer.computation, so following graph order instead
            # could presumably bind fed values to the wrong placeholders.
            return_list = [op for r in returns
                           for op in ops_by_uuid.get(r.uuid, ())]
            placeholder_list = [op for p in placeholders
                                for op in ops_by_uuid.get(p.uuid, ())]
            computation = self.transformer.computation(return_list,
                                                       *placeholder_list)
            self.computations[comp_id] = computation
            return hetr_pb2.ComputationReply(comp_id=comp_id)
        except Exception:
            # Catch Exception rather than using a bare except so that
            # KeyboardInterrupt/SystemExit still propagate, and log the
            # failure instead of discarding it silently.
            logging.getLogger(__name__).exception(
                "server: computation failed")
            return hetr_pb2.ComputationReply(comp_id=-1)
Example #3
0
    def Computation(self, request_iterator, context):
        """Build a computation from a streamed, serialized subgraph.

        The chunks yielded by ``request_iterator`` are accumulated and
        deserialized into one subgraph. Before building, recv ops of
        gather/scatter communication on the root device get an extra
        dependency on their send op's args (see the comment in the body).
        The requested return and placeholder ops are then looked up in the
        graph by uuid and passed to ``self.transformer.computation``. The
        result is stored under a fresh computation id.

        Arguments:
            request_iterator: iterable of request messages, each exposing
                ``ops``, ``edges``, ``returns`` and ``placeholders``.
            context: RPC context; not used by this method.

        Returns:
            hetr_pb2.ComputationReply: the new ``comp_id`` on success, or
            ``comp_id=-1`` with a ``message`` when no transformer has been
            built or when any step raises (the traceback is also logged).
        """
        logger.debug("server: computation")
        if not self.transformer:
            return hetr_pb2.ComputationReply(
                comp_id=-1, message="build transformer before computation")
        try:
            comp_id = self.new_comp_id()
            pb_ops, pb_edges = [], []
            returns, placeholders = [], []
            for request in request_iterator:
                pb_ops.extend(request.ops)
                pb_edges.extend(request.edges)
                returns.extend(protobuf_to_op(op) for op in request.returns)
                placeholders.extend(
                    protobuf_to_op(op) for op in request.placeholders)

            subgraph = _deserialize_graph_ops_edges(pb_ops, pb_edges)

            # On the root device the send buffer is passed as an argument to
            # the communication call (gather/scatter). Make such recv ops
            # also depend on their send op's args. This ensures the send
            # buffer does not get reused before the recv buffer has accessed
            # its items.
            root_idx = 0
            rank = MPI.COMM_WORLD.Get_rank()  # constant for this process

            def depend_on_send_args(recv_op):
                """Append the send op's args to ``recv_op``'s args."""
                recv_op._args = (tuple(recv_op._args) +
                                 tuple(recv_op.send_node().args))
                # The dependency set is cached; force it to be recomputed.
                recv_op.invalidate_property_cache('all_deps')

            for op in Op.all_op_references(subgraph):
                is_gather_root = (isinstance(op, GatherRecvOp) and
                                  rank == op.metadata['device_id'])
                if is_gather_root or (isinstance(op, ScatterRecvOp) and
                                      rank == root_idx and
                                      rank in op.metadata['device_id']):
                    depend_on_send_args(op)

            # Index the deserialized ops by uuid once instead of scanning
            # every op for every return/placeholder (O(N + R) vs O(N * R)).
            # Results follow the requested order; ops sharing a uuid keep
            # their graph order.
            ops_by_uuid = {}
            for op in Op.ordered_ops(subgraph):
                ops_by_uuid.setdefault(op.uuid, []).append(op)
            reconstructed_returns = [
                op for r in returns
                for op in ops_by_uuid.get(r.uuid, ())]
            reconstructed_placeholders = [
                op for p in placeholders
                for op in ops_by_uuid.get(p.uuid, ())]

            computation = self.transformer.computation(
                reconstructed_returns, *reconstructed_placeholders)
            self.computations[comp_id] = computation
            return hetr_pb2.ComputationReply(comp_id=comp_id)
        except Exception:
            # Report the failure to the client and keep a server-side record.
            logger.exception("server: computation failed")
            return hetr_pb2.ComputationReply(comp_id=-1,
                                             message=traceback.format_exc())
Example #4
0
def test_op_to_protobuf():
    """Round-trip a placeholder op through protobuf serialization.

    A placeholder is decorated with attributes and metadata of several value
    kinds (strings, lists, dicts, slices, an axis and an axes collection).
    It is serialized with ``ser.op_to_protobuf`` and deserialized with
    ``ser.protobuf_to_op``, and the copy must equal the original.
    """
    axis_c = ng.make_axis(name='C', length=2)
    axes_c = ng.make_axes([axis_c])
    orig_op = ng.placeholder(axes_c)

    # Plain attributes test0..test8, in this exact order.
    attribute_values = (
        'stringval_attr',
        [-1.0, 4],
        {'foo': 2, 'you': 'bar'},
        {},
        slice(1, 3, 5),
        slice(1, 3),
        slice(1, None, 3),
        axis_c,
        axes_c,
    )
    for idx, value in enumerate(attribute_values):
        setattr(orig_op, 'test%d' % idx, value)

    # Metadata entries, assigned one key at a time in this order.
    metadata_items = (
        ('test0', 'stringval'),
        ('test1', [1, 4.0]),
        ('test2', {'hey': 1, 'you': 4.0}),
        ('test4', {}),
        ('test5', slice(1, 3, 5)),
        ('test6', slice(1, 3)),
        ('test7', slice(1, None, 5)),
        ('test8', axis_c),
        ('test9', axes_c),
    )
    for key, value in metadata_items:
        orig_op.metadata[key] = value

    round_tripped = ser.protobuf_to_op(ser.op_to_protobuf(orig_op))
    assert_object_equality(round_tripped, orig_op)