def create_computation(self, pb_graph, returns, placeholders):
    """Kick off an asynchronous Computation RPC from a pre-serialized graph.

    Streams the (ops, edges) chunks of ``pb_graph`` to the server; only the
    first message of the stream carries the serialized ``returns`` and
    ``placeholders``.  The gRPC future is stashed on
    ``self.computation_response_future`` rather than awaited here.

    Raises:
        RuntimeError: if build_transformer has not been called yet.
    """
    logger.debug("client: create_computation")

    # Fail fast if the transformer was never built; nothing below has
    # side effects before this check in the original flow either.
    if not self.is_trans_built:
        raise RuntimeError(
            "call build_transformer before create_computation")

    def _build_request(ops_chunk, edges_chunk, rets=None, phs=None):
        # Later chunks omit returns/placeholders entirely instead of
        # sending empty repeated fields.
        if rets or phs:
            return hetr_pb2.ComputationRequest(ops=ops_chunk,
                                               edges=edges_chunk,
                                               returns=rets,
                                               placeholders=phs)
        return hetr_pb2.ComputationRequest(ops=ops_chunk, edges=edges_chunk)

    def _message_stream():
        # First message includes returns/placeholders; they are cleared
        # after the first yield so subsequent chunks go without them.
        rets = [op_to_protobuf(op) for op in returns]
        phs = [op_to_protobuf(op) for op in placeholders]
        for ops_chunk, edges_chunk in pb_graph:
            yield _build_request(ops_chunk, edges_chunk, rets, phs)
            rets, phs = [], []

    update_comm_deps(returns)
    self.computation_response_future = self.RPC.Computation.future(
        _message_stream(), _TIMEOUT_SECONDS)
def computation(self, returns, placeholders):
    """Return a lightweight handle for a computation built in the child process.

    The real computation cannot be created here — that has to happen inside
    the transformer's child process.  This method registers the build request
    (``comp_id -> (returns, placeholders)``) on the transformer's queues and
    returns an ``AsyncComputation`` wrapper that feeds inputs and fetches
    results through those queues.
    """
    class AsyncComputation(object):
        # Thin client-side proxy for a computation, identified by a fresh
        # comp_id obtained from the owning AsyncTransformer.

        def __init__(self, async_transformer):
            self.async_transformer = async_transformer
            self.comp_id = self.async_transformer.new_comp_id()

        def feed_input(self, values):
            # Lazily start the child process on first use.
            if not self.async_transformer.started:
                self.async_transformer.start()
                self.async_transformer.started = True

            # Does this need to be thread safe? only one caller thread right?
            # no- the caller is actually the mapper
            self.async_transformer.work_q.put((self.comp_id, values))

        def get_results(self):
            """Block until results arrive, polling so child death is noticed.

            Raises:
                RuntimeError: if the child process exited while we waited.
            """
            while True:
                try:
                    q = self.async_transformer.results_qs[self.comp_id]
                    return_list = q.get(timeout=AsyncTransformer.SLEEP_S)

                    # TODO set self.returns somewhere cleaner
                    # NOTE(review): self.returns is never assigned in
                    # __init__ — presumably set externally before
                    # get_results is called; confirm against callers.
                    return {
                        op: return_list[mypos]
                        for (op, mypos) in iteritems(self.returns)
                    }
                # Narrowed from `except Exception` + isinstance(e, Empty):
                # any other exception now propagates naturally, exactly as
                # the old `else: raise` branch did.
                except Empty:
                    # Timed out waiting for a result; keep polling only if
                    # the child process is still alive.
                    if not self.async_transformer.is_alive():
                        ecode = self.async_transformer.exitcode
                        if sys.platform == 'darwin' \
                                and ecode == -signal.SIGSEGV:
                            import pytest
                            pytest.xfail(
                                "Hetr: OSX blas fork-safety issue (#961)"
                            )
                        elif ecode == PYCUDA_LOGIC_ERROR_CODE:
                            import pytest
                            pytest.xfail(
                                "Hetr: CUDA driver init in child issue (#1059)"
                            )
                        # Fixed: was RuntimeError("... code ", ecode), which
                        # produced a two-element args tuple instead of a
                        # single readable message.
                        raise RuntimeError(
                            "Child process unexpectedly exited with "
                            "code {}".format(ecode))

    update_comm_deps(returns)
    c = AsyncComputation(self)
    self.results_qs[c.comp_id] = self.manager.Queue()
    self.computation_builds[c.comp_id] = (returns, placeholders)
    self.computation_q.put(c.comp_id)
    return c
def create_computation(self, returns, placeholders):
    """Serialize the graph reachable from ``returns``/``placeholders`` and
    stream it to the server as chunked Computation requests.

    Ops are serialized in batches of ``_OPS_PER_MSG``; only the first
    message carries the serialized returns and placeholders.  The gRPC
    future is stored on ``self.computation_response_future``.

    Raises:
        RuntimeError: if build_transformer has not been called yet.
    """
    logger.info("client: create_computation")

    # Guard first — the nested definitions below have no side effects,
    # so checking up front is behaviorally equivalent.
    if not self.is_trans_built:
        raise RuntimeError(
            "call build_transformer before create_computation")

    def _request_for(ops_chunk, edges_chunk, rets=None, phs=None):
        # Omit returns/placeholders from all but the first message.
        if rets or phs:
            return hetr_pb2.ComputationRequest(ops=ops_chunk,
                                               edges=edges_chunk,
                                               returns=rets,
                                               placeholders=phs)
        return hetr_pb2.ComputationRequest(ops=ops_chunk, edges=edges_chunk)

    def _message_stream():
        rets = [op_to_protobuf(op) for op in returns]
        phs = [op_to_protobuf(op) for op in placeholders]
        all_ops = Op.all_op_references(returns + list(placeholders))
        last_idx = len(all_ops) - 1
        ops_chunk, edges_chunk = [], []
        for idx, op in enumerate(all_ops):
            ops_chunk.append(op_to_protobuf(op))
            add_edges(edges_chunk, ops_chunk, op)
            # Flush a full batch, or whatever remains at the final op.
            if (idx != 0 and idx % _OPS_PER_MSG == 0) or idx == last_idx:
                yield _request_for(ops_chunk, edges_chunk, rets, phs)
                ops_chunk, edges_chunk = [], []
                rets, phs = [], []

    update_comm_deps(returns)
    self.computation_response_future = self.RPC.Computation.future(
        _message_stream(), _TIMEOUT_SECONDS)
def test_update_comm_deps_scatter_gather():
    # Build a two-device scatter/compute/gather graph and verify that
    # update_comm_deps wires the matching send ops into the control_deps
    # of the corresponding recv ops.
    ax_a = ng.make_axis(length=10, name='A')
    ax_b = ng.make_axis(length=15, name='B')
    axes = ng.make_axes([ax_a, ax_b])
    # Metadata marking ops as parallel over axis A across devices 0 and 1.
    parallel_metadata = dict(parallel=ax_a, device_id=(0, 1),
                             transformer=None, host_transformer=None,
                             device=None)
    # Device cpu0: scatter the input, compute x+1 on its shard, gather-send.
    with ng.metadata(transformer='cpu0'):
        with ng.metadata(**parallel_metadata):
            from_node_a = ng.placeholder(axes)
            to_node_a = ng.placeholder(axes)
        scatter_send_x = ScatterSendOp(from_node=from_node_a,
                                       to_node=to_node_a)
        scatter_recv_a = ScatterRecvOp(to_node=to_node_a,
                                       send_node=scatter_send_x)
        with ng.metadata(**parallel_metadata):
            x_plus_one_a = scatter_recv_a + 1
        gather_send_x_plus_one_a = GatherSendOp(from_node=x_plus_one_a)
    # Device cpu1: receive its scatter shard, compute x+1, gather-send.
    with ng.metadata(transformer='cpu1'):
        with ng.metadata(**parallel_metadata):
            to_node_b = ng.placeholder(axes)
        scatter_recv_b = ScatterRecvOp(to_node=to_node_b,
                                       send_node=scatter_send_x)
        with ng.metadata(**parallel_metadata):
            x_plus_one_b = scatter_recv_b + 1
        gather_send_x_plus_one_b = GatherSendOp(from_node=x_plus_one_b)
    # Back on cpu0: gather the partial results and compute the final z_a.
    with ng.metadata(transformer='cpu0'):
        with ng.metadata(**parallel_metadata):
            gather_recv_x_plus_one_a = GatherRecvOp(
                from_node=from_node_a, to_node=to_node_a,
                send_node=gather_send_x_plus_one_a)
        z_a = gather_recv_x_plus_one_a + 1
    # Run dependency resolution per transformer's op set.
    update_comm_deps((scatter_send_x, gather_send_x_plus_one_a, z_a))
    update_comm_deps((gather_send_x_plus_one_b, ))
    # Each recv op must now depend on the send op(s) feeding it.
    assert set([scatter_send_x]) == set(scatter_recv_a.control_deps)
    assert set([scatter_send_x, gather_send_x_plus_one_a]) == \
        set(gather_recv_x_plus_one_a.control_deps)
def computation(self, returns, placeholders):
    """Create a computation on the RPC server and return a client handle.

    Serializes the subgraph reachable from ``returns``/``placeholders``,
    sends a single ComputationRequest, and wraps the server-assigned
    comp_id in an RPCComputationClient.

    Raises:
        RuntimeError: if build_transformer has not succeeded, or the server
            reports failure (negative comp_id).
    """
    if not self.initialized:
        raise RuntimeError("RPC build_transformer request failed!")

    update_comm_deps(returns)
    pb_subgraph = _serialize_graph(returns + list(placeholders))
    # Comprehensions instead of the original append loops (same order).
    pb_returns = [op_to_protobuf(op) for op in returns]
    pb_placeholders = [op_to_protobuf(op) for op in placeholders]
    response = self.RPC.Computation(
        hetr_pb2.ComputationRequest(
            subgraph=pb_subgraph,
            returns=pb_returns,
            placeholders=pb_placeholders),
        _TIMEOUT_SECONDS)

    # A negative comp_id is the server's failure signal.
    if response.comp_id < 0:
        raise RuntimeError("RPC computation request failed!")
    return RPCComputationClient(response.comp_id, self.RPC)
def test_update_comm_deps():
    # Build the canned send/recv graph and verify that update_comm_deps
    # links the recv op into the dependency set of the downstream op z.
    with ng.metadata(transformer='cpu0'):
        z, recv_x, recv_x_plus_one, send_x, x_plus_one, from_node, send_x_plus_one = \
            create_send_recv_graph()
        update_comm_deps((z, send_x))
        # After dependency resolution, z must depend on the recv feeding it.
        assert recv_x_plus_one in z.all_deps