Example #1
    def create_computation(self, pb_graph, returns, placeholders):
        logger.debug("client: create_computation")

        def make_computation_request(pb_ops,
                                     pb_edges,
                                     pb_returns=None,
                                     pb_placeholders=None):
            if pb_returns or pb_placeholders:
                return hetr_pb2.ComputationRequest(
                    ops=pb_ops,
                    edges=pb_edges,
                    returns=pb_returns,
                    placeholders=pb_placeholders)
            else:
                return hetr_pb2.ComputationRequest(ops=pb_ops, edges=pb_edges)

        def generate_messages():
            pb_returns = [op_to_protobuf(o) for o in returns]
            pb_placeholders = [op_to_protobuf(o) for o in placeholders]

            for pb_ops, pb_edges in pb_graph:
                msg = make_computation_request(pb_ops, pb_edges, pb_returns,
                                               pb_placeholders)
                yield msg
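                # Clear these so that only the first streamed request carries
                # the returns and placeholders.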
                pb_returns, pb_placeholders = [], []

        if not self.is_trans_built:
            raise RuntimeError(
                "call build_transformer before create_computation")

        update_comm_deps(returns)

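        # Kick off the streaming Computation RPC without blocking; the result
        # is read later from the stored future.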
        self.computation_response_future = self.RPC.Computation.future(
            generate_messages(), _TIMEOUT_SECONDS)
Example #2
    def computation(self, returns, placeholders):
        # Don't actually create the computation here; that has to be done
        # inside the child process. Instead, return a lightweight computation
        # wrapper that can be used later.
        class AsyncComputation(object):
            def __init__(self, async_transformer):
                self.async_transformer = async_transformer
                self.comp_id = self.async_transformer.new_comp_id()

            def feed_input(self, values):
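                # Lazily start the worker process on first use.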
                if not self.async_transformer.started:
                    self.async_transformer.start()
                    self.async_transformer.started = True

                # This doesn't need to be thread safe: the mapper is the only
                # caller thread.
                self.async_transformer.work_q.put((self.comp_id, values))

            def get_results(self):
                while True:
                    try:
                        q = self.async_transformer.results_qs[self.comp_id]
                        return_list = q.get(timeout=AsyncTransformer.SLEEP_S)
                        # TODO set self.returns somewhere cleaner
                        return_dict = {
                            op: return_list[mypos]
                            for (op, mypos) in iteritems(self.returns)
                        }
                        return return_dict
                    except Empty:
                        # Result not ready yet; if the worker process died,
                        # surface the failure instead of polling forever.
                        if not self.async_transformer.is_alive():
                            ecode = self.async_transformer.exitcode
                            if sys.platform == 'darwin' and ecode == -signal.SIGSEGV:
                                import pytest
                                pytest.xfail(
                                    "Hetr: OSX blas fork-safety issue (#961)")
                            elif ecode == PYCUDA_LOGIC_ERROR_CODE:
                                import pytest
                                pytest.xfail(
                                    "Hetr: CUDA driver init in child issue (#1059)")
                            raise RuntimeError(
                                "Child process unexpectedly exited with code "
                                "{}".format(ecode))

        update_comm_deps(returns)
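        # Register the computation with the worker process: allocate its result
        # queue, record the build spec, and enqueue the build request.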
        c = AsyncComputation(self)
        self.results_qs[c.comp_id] = self.manager.Queue()
        self.computation_builds[c.comp_id] = (returns, placeholders)
        self.computation_q.put(c.comp_id)
        return c
Example #3
    def create_computation(self, returns, placeholders):
        logger.info("client: create_computation")

        def make_computation_request(pb_ops,
                                     pb_edges,
                                     pb_returns=None,
                                     pb_placeholders=None):
            if pb_returns or pb_placeholders:
                return hetr_pb2.ComputationRequest(
                    ops=pb_ops,
                    edges=pb_edges,
                    returns=pb_returns,
                    placeholders=pb_placeholders)
            else:
                return hetr_pb2.ComputationRequest(ops=pb_ops, edges=pb_edges)

        def generate_returns_placeholders():
            pb_returns = []
            pb_placeholders = []
            for op in returns:
                pb_returns.append(op_to_protobuf(op))
            for op in placeholders:
                pb_placeholders.append(op_to_protobuf(op))
            return pb_returns, pb_placeholders

        def generate_messages():
            pb_ops, pb_edges = [], []
            pb_returns, pb_placeholders = generate_returns_placeholders()
            ops = Op.all_op_references(returns + list(placeholders))
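            # Emit a request every _OPS_PER_MSG serialized ops, plus a final
            # request for any remainder.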
            for i, op in enumerate(ops):
                pb_ops.append(op_to_protobuf(op))
                add_edges(pb_edges, pb_ops, op)
                if (i != 0 and i % _OPS_PER_MSG == 0) or (i == len(ops) - 1):
                    msg = make_computation_request(pb_ops, pb_edges,
                                                   pb_returns, pb_placeholders)
                    yield msg

                    pb_ops, pb_edges = [], []
                    pb_returns, pb_placeholders = [], []

        if not self.is_trans_built:
            raise RuntimeError(
                "call build_transformer before create_computation")
        update_comm_deps(returns)

        self.computation_response_future = self.RPC.Computation.future(
            generate_messages(), _TIMEOUT_SECONDS)
Example #4
def test_update_comm_deps_scatter_gather():
    ax_a = ng.make_axis(length=10, name='A')
    ax_b = ng.make_axis(length=15, name='B')
    axes = ng.make_axes([ax_a, ax_b])

    parallel_metadata = dict(parallel=ax_a,
                             device_id=(0, 1),
                             transformer=None,
                             host_transformer=None,
                             device=None)
    with ng.metadata(transformer='cpu0'):
        with ng.metadata(**parallel_metadata):
            from_node_a = ng.placeholder(axes)
            to_node_a = ng.placeholder(axes)
        scatter_send_x = ScatterSendOp(from_node=from_node_a,
                                       to_node=to_node_a)
        scatter_recv_a = ScatterRecvOp(to_node=to_node_a,
                                       send_node=scatter_send_x)
        with ng.metadata(**parallel_metadata):
            x_plus_one_a = scatter_recv_a + 1
        gather_send_x_plus_one_a = GatherSendOp(from_node=x_plus_one_a)

    with ng.metadata(transformer='cpu1'):
        with ng.metadata(**parallel_metadata):
            to_node_b = ng.placeholder(axes)
        scatter_recv_b = ScatterRecvOp(to_node=to_node_b,
                                       send_node=scatter_send_x)
        with ng.metadata(**parallel_metadata):
            x_plus_one_b = scatter_recv_b + 1
        gather_send_x_plus_one_b = GatherSendOp(from_node=x_plus_one_b)

    with ng.metadata(transformer='cpu0'):
        with ng.metadata(**parallel_metadata):
            gather_recv_x_plus_one_a = GatherRecvOp(
                from_node=from_node_a,
                to_node=to_node_a,
                send_node=gather_send_x_plus_one_a)
            z_a = gather_recv_x_plus_one_a + 1

    update_comm_deps((scatter_send_x, gather_send_x_plus_one_a, z_a))
    update_comm_deps((gather_send_x_plus_one_b, ))

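    # The recv ops should now carry control dependencies on their matching send
    # ops, so sends are issued before the receives that wait on them.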
    assert set([scatter_send_x]) == set(scatter_recv_a.control_deps)
    assert set([scatter_send_x, gather_send_x_plus_one_a]) == \
        set(gather_recv_x_plus_one_a.control_deps)
Example #5
    def computation(self, returns, placeholders):
        if not self.initialized:
            raise RuntimeError("RPC build_transformer request failed!")
        update_comm_deps(returns)
        # Serialize the whole subgraph and send it in a single (non-streaming)
        # ComputationRequest.
        pb_subgraph = _serialize_graph(returns + list(placeholders))
        pb_returns = []
        pb_placeholders = []
        for op in returns:
            pb_returns.append(op_to_protobuf(op))
        for op in placeholders:
            pb_placeholders.append(op_to_protobuf(op))
        response = self.RPC.Computation(
            hetr_pb2.ComputationRequest(
                subgraph=pb_subgraph,
                returns=pb_returns,
                placeholders=pb_placeholders),
            _TIMEOUT_SECONDS)
        if response.comp_id >= 0:
            rpcComputationClient = RPCComputationClient(response.comp_id, self.RPC)
            return rpcComputationClient
        else:
            raise RuntimeError("RPC computation request failed!")
Example #6
def test_update_comm_deps():
    with ng.metadata(transformer='cpu0'):
        z, recv_x, recv_x_plus_one, send_x, x_plus_one, from_node, send_x_plus_one = \
            create_send_recv_graph()
    update_comm_deps((z, send_x))
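    # update_comm_deps should make z depend on recv_x_plus_one so the
    # communication is ordered ahead of z.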
    assert recv_x_plus_one in z.all_deps