def __init__(self, request_iterator: Sequence[message_pb2.RunStepRequest]):
    """Split the DAG carried by a stream of run-step requests into sub-DAGs.

    Requests that have the ``head`` field carry the DAG definition; if more
    than one such request is seen, the last one is used. Every other request
    is treated as a chunk whose ``body.op_key`` names the op it belongs to.

    The ops of the head DAG are walked in order and cut into consecutive
    sub-DAGs: each op for which ``self.is_splited_op(op)`` is true starts a
    new sub-DAG. Each non-empty sub-DAG is put on ``self._dag_queue`` as a
    tuple ``(engine, dag_def, chunk_requests)``.

    Args:
        request_iterator: The requests to consume; iterated exactly once.
    """
    self._dag_queue = queue.Queue()
    req_head = None
    # Group chunks by the key of the op they belong to, so that each op finds
    # its chunks with a single lookup instead of rescanning every chunk.
    # Arrival order of the chunks is preserved within each group.
    bodies_by_op_key = {}
    for req in request_iterator:
        if req.HasField("head"):
            req_head = req
        else:
            bodies_by_op_key.setdefault(req.body.op_key, []).append(req)

    if req_head is None:
        # No DAG definition was received; the queue stays empty.
        return

    # Ops that precede the first split op go to the analytical engine.
    dag = op_def_pb2.DagDef()
    dag_for = GSEngine.analytical_engine
    dag_bodies = []
    for op in req_head.head.dag_def.op:
        if self.is_splited_op(op):
            # Flush what has been collected so far, then start a fresh
            # sub-DAG targeting the engine that executes this op.
            if dag.op:
                self._dag_queue.put((dag_for, dag, dag_bodies))
            dag = op_def_pb2.DagDef()
            dag_for = self.get_op_exec_engine(op)
            dag_bodies = []
        dag.op.extend([copy.deepcopy(op)])
        # Attach the chunks that belong to this op.
        dag_bodies.extend(bodies_by_op_key.get(op.key, ()))

    # Flush the trailing sub-DAG.
    if dag.op:
        self._dag_queue.put((dag_for, dag, dag_bodies))
def extract_subdag_for(self, ops):
    """Extract the sub-DAG of every op lying on a path that reaches ``ops``.

    Args:
        ops: The target operations. Each must be registered in this dag
            (its key is in ``self._ops_by_key``) and not evaluated yet,
            unless ``ops`` is a single leaf op.

    Returns:
        A ``DagDef`` with the reachable ops ordered by their sequence number
        in ``self._ops_seq_by_key``; the target ops have ``fetch`` set to
        True. For a single leaf op, a ``DagDef`` holding only that op.

    Raises:
        AssertionError: If a target op is not in the dag or has already been
            evaluated.
    """
    # Leaf op handling. There are two kinds of leaf op:
    #   1) unload graph / app
    #   2) networkx related op
    if len(ops) == 1 and ops[0].is_leaf_op():
        out = op_def_pb2.DagDef()
        out.op.extend([ops[0].as_op_def()])
        return out

    # Every target op must be present in the current dag and not evaluated.
    op_keys = []
    for op in ops:
        assert op.key in self._ops_by_key, "%s is not in the dag" % op.key
        # The placeholder is "%s": the previous "%is" applied an integer
        # conversion to the key and raised TypeError for string keys.
        assert not self._ops_by_key[op.key].evaluated, (
            "%s is evaluated" % op.key
        )
        op_keys.append(op.key)

    op_keys_to_keep = sorted(
        self._bfs_for_reachable_ops(op_keys),
        key=lambda n: self._ops_seq_by_key[n],
    )
    # Set membership keeps the fetch check constant-time per op.
    fetch_keys = set(op_keys)

    out = op_def_pb2.DagDef()
    for key in op_keys_to_keep:
        op_def = self._ops_by_key[key].as_op_def()
        # Only the requested targets are marked to be fetched.
        if key in fetch_keys:
            op_def.fetch = True
        out.op.extend([op_def])
    return out
def create_single_op_dag(op_type, config=None):
    """Build a ``DagDef`` that contains exactly one op.

    Args:
        op_type: Value stored in the ``op`` field of the new ``OpDef``.
        config: Optional mapping of attribute key to attribute value; every
            value is copied into the op's ``attr`` map via ``CopyFrom``.

    Returns:
        A ``DagDef`` whose only op has ``op_type`` as its type and a fresh
        ``uuid4`` hex string as its key.
    """
    single_op = op_def_pb2.OpDef(op=op_type, key=uuid.uuid4().hex)
    # A missing or empty config contributes no attributes.
    for attr_key, attr_value in (config or {}).items():
        single_op.attr[attr_key].CopyFrom(attr_value)

    result = op_def_pb2.DagDef()
    result.op.extend([single_op])
    return result
def run(self, fetch):
    """Run the operation given by `fetch` and return its parsed result.

    Args:
        fetch: :class:`Operation`, or any object exposing one through an
            ``op`` attribute.

    Raises:
        ValueError: If `fetch` is not an :class:`Operation`, or its
            ``output`` is already set (it has been evaluated before).
        RuntimeError: If the session is closed or has no grpc client.
        FatalError: Re-raised from the grpc client after the session has
            been closed.
        InvalidArgumentError: Not recognized on output type (presumably
            raised while parsing the response; not visible in this block).

    Returns:
        Different values for different output types of :class:`Operation`,
        as produced by ``self._parse_value``.
    """
    # Unwrap objects that carry their operation in `.op`.
    target = fetch.op if hasattr(fetch, "op") else fetch
    if not isinstance(target, Operation):
        raise ValueError("Expect a `Operation`")
    if target.output is not None:
        raise ValueError("The op <%s> are evaluated duplicated." % target.key)

    # Keep a list so the shape matches the rpc client method signature.
    fetch_ops = [target]
    dag = op_def_pb2.DagDef()
    dag.op.extend([copy.deepcopy(op.as_op_def()) for op in fetch_ops])

    if self._closed:
        raise RuntimeError("Attempted to use a closed Session.")
    if not self._grpc_client:
        raise RuntimeError("Session disconnected.")

    # Execute the query; a fatal error tears the session down before it
    # propagates to the caller.
    try:
        response = self._grpc_client.run(dag)
    except FatalError:
        self.close()
        raise

    check_argument(
        len(fetch_ops) == 1, "Cannot execute multiple ops at the same time"
    )
    return self._parse_value(fetch_ops[0], response)
def _get_engine_config(self):
    """Fetch the engine configuration from the analytical engine.

    Sends a single ``GET_ENGINE_CONFIG`` op through the analytical engine
    stub and decodes the UTF-8 JSON payload of the response.

    Returns:
        The decoded config. When the launcher type is ``K8S``, the keys
        ``vineyard_service_name`` and ``vineyard_rpc_endpoint`` are added
        from the launcher.
    """
    dag_def = op_def_pb2.DagDef()
    dag_def.op.extend([op_def_pb2.OpDef(op=types_pb2.GET_ENGINE_CONFIG)])
    request = message_pb2.RunStepRequest(
        session_id=self._session_id, dag_def=dag_def
    )
    response = self._analytical_engine_stub.RunStep(request)
    config = json.loads(response.result.decode("utf-8"))

    if self._launcher_type != types_pb2.K8S:
        return config

    # Only the K8S launcher knows the vineyard service coordinates.
    config["vineyard_service_name"] = self._launcher.get_vineyard_service_name()
    config["vineyard_rpc_endpoint"] = self._launcher.get_vineyard_rpc_endpoint()
    return config
def _maybe_register_graph(self, op, session_id):
    """Compile and register the graph type of ``op`` when needed.

    The graph library is compiled (and distributed) only if its file does
    not exist under the builtin workspace, and it is registered with the
    analytical engine only if its signature is not yet in the object
    manager. In every case the signature is written to the op's
    ``TYPE_SIGNATURE`` attribute.

    Args:
        op: The op whose attributes describe the graph type.
        session_id: Session id to put on the register request.

    Returns:
        The same ``op``, with ``TYPE_SIGNATURE`` set.

    Raises:
        RuntimeError: If the compiled library path differs from the computed
            one, or the engine does not answer the registration with ``OK``.
    """
    graph_sig = get_graph_sha256(op.attr)
    lib_dir = os.path.join(self._builtin_workspace, graph_sig)
    graph_lib_path = get_lib_path(lib_dir, graph_sig)

    lib_missing = not os.path.isfile(graph_lib_path)
    if lib_missing:
        compiled_path = self._compile_lib_and_distribute(
            compile_graph_frame, graph_sig, op
        )
        if compiled_path != graph_lib_path:
            raise RuntimeError("Computed path not equal to compiled path.")

    if graph_sig not in self._object_manager:
        # Describe the library to the engine in a REGISTER_GRAPH_TYPE op.
        register_op = op_def_pb2.OpDef(op=types_pb2.REGISTER_GRAPH_TYPE)
        register_attrs = (
            (
                types_pb2.GRAPH_LIBRARY_PATH,
                attr_value_pb2.AttrValue(s=graph_lib_path.encode("utf-8")),
            ),
            (
                types_pb2.TYPE_SIGNATURE,
                attr_value_pb2.AttrValue(s=graph_sig.encode("utf-8")),
            ),
            (
                types_pb2.GRAPH_TYPE,
                attr_value_pb2.AttrValue(
                    graph_type=op.attr[types_pb2.GRAPH_TYPE].graph_type
                ),
            ),
        )
        for attr_key, attr_value in register_attrs:
            register_op.attr[attr_key].CopyFrom(attr_value)

        dag_def = op_def_pb2.DagDef()
        dag_def.op.extend([register_op])
        request = message_pb2.RunStepRequest(
            session_id=session_id, dag_def=dag_def
        )
        response = self._analytical_engine_stub.RunStep(request)
        if response.status.code != error_codes_pb2.OK:
            raise RuntimeError("Error occur when register graph")
        # Remember the registration so later calls skip it.
        self._object_manager.put(
            graph_sig,
            LibMeta(response.result, "graph_frame", graph_lib_path),
        )

    signature_attr = attr_value_pb2.AttrValue(s=graph_sig.encode("utf-8"))
    op.attr[types_pb2.TYPE_SIGNATURE].CopyFrom(signature_attr)
    return op
def _maybe_register_graph(self, op, session_id):
    """Register the graph type of ``op`` unless its signature is known.

    When the signature is not in the object manager, the graph library is
    compiled and distributed, registered with the analytical engine, and
    recorded in the object manager. In every case the signature is written
    to the op's ``TYPE_SIGNATURE`` attribute.

    Args:
        op: The op whose attributes describe the graph type.
        session_id: Session id to put on the register request.

    Returns:
        The same ``op``, with ``TYPE_SIGNATURE`` set.

    Raises:
        RuntimeError: If the engine does not answer the registration with
            ``OK``.
    """
    graph_sig = self._generate_graph_sig(op.attr)

    if graph_sig not in self._object_manager:
        graph_lib_path = self._compile_lib_and_distribute(
            compile_graph_frame, graph_sig, op
        )
        # Describe the library to the engine in a REGISTER_GRAPH_TYPE op.
        register_op = op_def_pb2.OpDef(op=types_pb2.REGISTER_GRAPH_TYPE)
        register_attrs = (
            (
                types_pb2.GRAPH_LIBRARY_PATH,
                attr_value_pb2.AttrValue(s=graph_lib_path.encode("utf-8")),
            ),
            (
                types_pb2.TYPE_SIGNATURE,
                attr_value_pb2.AttrValue(s=graph_sig.encode("utf-8")),
            ),
            (
                types_pb2.GRAPH_TYPE,
                attr_value_pb2.AttrValue(
                    graph_type=op.attr[types_pb2.GRAPH_TYPE].graph_type
                ),
            ),
        )
        for attr_key, attr_value in register_attrs:
            register_op.attr[attr_key].CopyFrom(attr_value)

        dag_def = op_def_pb2.DagDef()
        dag_def.op.extend([register_op])
        request = message_pb2.RunStepRequest(
            session_id=session_id, dag_def=dag_def
        )
        response = self._analytical_engine_stub.RunStep(request)
        if response.status.code != error_codes_pb2.OK:
            raise RuntimeError("Error occur when register graph")
        # Remember the registration so later calls skip compilation.
        self._object_manager.put(
            graph_sig,
            LibMeta(response.result, "graph_frame", graph_lib_path),
        )
    else:
        # Already registered: look up the recorded library path.
        known_meta = self._object_manager.get(graph_sig)
        graph_lib_path = known_meta.lib_path

    signature_attr = attr_value_pb2.AttrValue(s=graph_sig.encode("utf-8"))
    op.attr[types_pb2.TYPE_SIGNATURE].CopyFrom(signature_attr)
    return op
def as_dag_def(self):
    """Return :class:`Dag` as a :class:`DagDef` proto buffer.

    Every op held in ``self._ops_by_key`` is converted with ``as_op_def``
    and added in the mapping's iteration order.
    """
    dag_def = op_def_pb2.DagDef()
    dag_def.op.extend([op.as_op_def() for op in self._ops_by_key.values()])
    return dag_def