示例#1
0
 def __init__(self, request_iterator: Sequence[message_pb2.RunStepRequest]):
     """Split a stream of run-step requests into queued sub-DAGs.

     Requests that carry a ``head`` field hold the DAG definition; if
     several do, the last one seen is used. Every other request is treated
     as a body chunk, matched to an op through ``body.op_key``.

     Each entry put on ``self._dag_queue`` is a tuple
     ``(engine, dag, chunks)``: the engine the sub-DAG runs on, the
     ``DagDef`` holding its ops, and the body chunks of those ops. Nothing
     is queued when no head request is present.

     Args:
         request_iterator: The requests to split; iterated exactly once.
     """
     self._dag_queue = queue.Queue()
     head_request = None
     # body chunks, kept in arrival order
     chunk_requests = []
     for request in request_iterator:
         if not request.HasField("head"):
             chunk_requests.append(request)
             continue
         head_request = request
     if head_request is not None:
         # ops seen before the first split point go to the analytical engine
         sub_dag = op_def_pb2.DagDef()
         engine = GSEngine.analytical_engine
         sub_dag_chunks = []
         for op in head_request.head.dag_def.op:
             if self.is_splited_op(op):
                 # a split op closes the running sub-DAG (when it has ops)
                 # and opens a new one bound to the engine of this op
                 if sub_dag.op:
                     self._dag_queue.put((engine, sub_dag, sub_dag_chunks))
                 sub_dag = op_def_pb2.DagDef()
                 engine = self.get_op_exec_engine(op)
                 sub_dag_chunks = []
             sub_dag.op.extend([copy.deepcopy(op)])
             # collect the chunks that belong to this op
             sub_dag_chunks.extend(
                 chunk for chunk in chunk_requests if chunk.body.op_key == op.key
             )
         # flush the trailing sub-DAG
         if sub_dag.op:
             self._dag_queue.put((engine, sub_dag, sub_dag_chunks))
示例#2
0
 def extract_subdag_for(self, ops):
     """Extract all nodes on the paths that can reach the target ops.

     Args:
         ops: Target operations. Each one is expected to expose ``key``,
             ``is_leaf_op()`` and ``as_op_def()``.

     Returns:
         A ``DagDef``. When ``ops`` is a single leaf op, it contains only
         that op's definition (no validation, ``fetch`` left untouched).
         Otherwise it contains the definitions of every key returned by
         ``_bfs_for_reachable_ops``, ordered by ``_ops_seq_by_key``, with
         ``fetch`` set to ``True`` on the definitions of the target ops.

     Raises:
         AssertionError: If a target op is not registered in this dag, or
             has already been evaluated.
     """
     out = op_def_pb2.DagDef()
     # leaf op handle
     # there are two kinds of leaf op:
     #   1) unload graph / app
     #   2) networkx related op
     if len(ops) == 1 and ops[0].is_leaf_op():
         out.op.extend([ops[0].as_op_def()])
         return out
     op_keys = []
     # every target op must be present in the current dag and not yet evaluated
     for op in ops:
         assert op.key in self._ops_by_key, "%s is not in the dag" % op.key
         # "%s" (not "%i"): keys are not numbers, and "%i" would raise a
         # TypeError that hides the intended assertion message
         assert not self._ops_by_key[op.key].evaluated, (
             "%s is evaluated" % op.key
         )
         op_keys.append(op.key)
     op_keys_to_keep = sorted(
         self._bfs_for_reachable_ops(op_keys),
         key=lambda n: self._ops_seq_by_key[n],
     )
     # set membership keeps the fetch marking linear in the number of kept ops
     fetch_keys = set(op_keys)
     for key in op_keys_to_keep:
         op_def = self._ops_by_key[key].as_op_def()
         # mark op fetch or not
         if key in fetch_keys:
             op_def.fetch = True
         out.op.extend([op_def])
     return out
示例#3
0
def create_single_op_dag(op_type, config=None):
    """Build a ``DagDef`` that contains exactly one freshly keyed op.

    Args:
        op_type: Value stored in the ``op`` field of the new ``OpDef``.
        config: Optional mapping of attribute key to attribute value; each
            value is copied into ``attr[key]`` of the op with ``CopyFrom``.
            A falsy ``config`` (``None`` or empty) adds no attributes.

    Returns:
        A ``DagDef`` whose single op has a random ``uuid4`` hex string as key.
    """
    single_op = op_def_pb2.OpDef(op=op_type, key=uuid.uuid4().hex)
    for attr_key, attr_value in (config or {}).items():
        single_op.attr[attr_key].CopyFrom(attr_value)

    single_op_dag = op_def_pb2.DagDef()
    single_op_dag.op.extend([single_op])
    return single_op_dag
示例#4
0
    def run(self, fetch):
        """Evaluate the operation given by `fetch` and return its value.

        Args:
            fetch: :class:`Operation`, or any object exposing one through
                an ``op`` attribute.

        Raises:
            RuntimeError:
                The session is closed, or it has no grpc client
                (disconnected from the service).

            ValueError:
                If fetch is not an instance of :class:`Operation`, or
                the fetch has already been evaluated.

            InvalidArgumentError:
                Not recognized on output type.

        Returns:
            Different values for different output types of :class:`Operation`
        """

        # unwrap objects that carry their operation in an `op` attribute
        target = fetch.op if hasattr(fetch, "op") else fetch
        if not isinstance(target, Operation):
            raise ValueError("Expect a `Operation`")
        if target.output is not None:
            raise ValueError("The op <%s> are evaluated duplicated." %
                             target.key)

        # keep a list so the shape matches the rpc client method signature
        targets = [target]

        request_dag = op_def_pb2.DagDef()
        request_dag.op.extend(
            [copy.deepcopy(item.as_op_def()) for item in targets])

        if self._closed:
            raise RuntimeError("Attempted to use a closed Session.")

        if not self._grpc_client:
            raise RuntimeError("Session disconnected.")

        # execute the query; a fatal error closes the session before
        # the exception is propagated
        try:
            response = self._grpc_client.run(request_dag)
        except FatalError:
            self.close()
            raise
        check_argument(
            len(targets) == 1,
            "Cannot execute multiple ops at the same time")
        return self._parse_value(targets[0], response)
示例#5
0
 def _get_engine_config(self):
     """Fetch the engine configuration via a ``GET_ENGINE_CONFIG`` run step.

     A one-op DAG is sent through ``self._analytical_engine_stub.RunStep``
     and the ``result`` bytes of the response are decoded as UTF-8 JSON.

     Returns:
         The decoded configuration. When the launcher type is ``K8S`` it
         additionally carries ``vineyard_service_name`` and
         ``vineyard_rpc_endpoint`` as reported by the launcher.
     """
     request_dag = op_def_pb2.DagDef()
     request_dag.op.extend([op_def_pb2.OpDef(op=types_pb2.GET_ENGINE_CONFIG)])
     response = self._analytical_engine_stub.RunStep(
         message_pb2.RunStepRequest(
             session_id=self._session_id, dag_def=request_dag
         )
     )
     engine_config = json.loads(response.result.decode("utf-8"))
     if self._launcher_type != types_pb2.K8S:
         return engine_config
     # only the k8s launcher contributes the vineyard entries
     engine_config["vineyard_service_name"] = (
         self._launcher.get_vineyard_service_name()
     )
     engine_config["vineyard_rpc_endpoint"] = (
         self._launcher.get_vineyard_rpc_endpoint()
     )
     return engine_config
示例#6
0
    def _maybe_register_graph(self, op, session_id):
        """Compile and register the graph type of `op` if needed, then tag `op`.

        The graph signature is derived from ``op.attr``. The library is
        compiled when no file exists at the path computed for that
        signature, and a ``REGISTER_GRAPH_TYPE`` run step is issued when the
        signature is not yet known to ``self._object_manager``.

        Args:
            op: Op definition whose attributes describe the graph type; its
                ``TYPE_SIGNATURE`` attribute is overwritten in place.
            session_id: Session id placed on the register request.

        Returns:
            The same `op`, with ``TYPE_SIGNATURE`` set to the graph signature.

        Raises:
            RuntimeError: If the compiled library path differs from the
                computed one, or the register response status is not ``OK``.
        """
        signature = get_graph_sha256(op.attr)
        workspace = self._builtin_workspace
        lib_path = get_lib_path(os.path.join(workspace, signature), signature)
        if not os.path.isfile(lib_path):
            built_path = self._compile_lib_and_distribute(
                compile_graph_frame, signature, op
            )
            if lib_path != built_path:
                raise RuntimeError("Computed path not equal to compiled path.")
        if signature not in self._object_manager:
            # describe the library to register on a fresh op
            register_op = op_def_pb2.OpDef(op=types_pb2.REGISTER_GRAPH_TYPE)
            register_attrs = (
                (
                    types_pb2.GRAPH_LIBRARY_PATH,
                    attr_value_pb2.AttrValue(s=lib_path.encode("utf-8")),
                ),
                (
                    types_pb2.TYPE_SIGNATURE,
                    attr_value_pb2.AttrValue(s=signature.encode("utf-8")),
                ),
                (
                    types_pb2.GRAPH_TYPE,
                    attr_value_pb2.AttrValue(
                        graph_type=op.attr[types_pb2.GRAPH_TYPE].graph_type
                    ),
                ),
            )
            for attr_key, attr_value in register_attrs:
                register_op.attr[attr_key].CopyFrom(attr_value)

            register_dag = op_def_pb2.DagDef()
            register_dag.op.extend([register_op])
            response = self._analytical_engine_stub.RunStep(
                message_pb2.RunStepRequest(
                    session_id=session_id, dag_def=register_dag
                )
            )

            if response.status.code != error_codes_pb2.OK:
                raise RuntimeError("Error occur when register graph")
            # remember the registration under its signature
            self._object_manager.put(
                signature,
                LibMeta(response.result, "graph_frame", lib_path),
            )
        op.attr[types_pb2.TYPE_SIGNATURE].CopyFrom(
            attr_value_pb2.AttrValue(s=signature.encode("utf-8"))
        )
        return op
示例#7
0
    def _maybe_register_graph(self, op, session_id):
        """Register the graph type of `op` unless it is already known.

        The graph signature is generated from ``op.attr``. When
        ``self._object_manager`` does not contain it, the graph library is
        compiled and distributed, a ``REGISTER_GRAPH_TYPE`` run step is
        issued, and the result is stored in the object manager under the
        signature. When the signature is already present nothing is
        compiled or sent.

        Args:
            op: Op definition whose attributes describe the graph type; its
                ``TYPE_SIGNATURE`` attribute is overwritten in place.
            session_id: Session id placed on the register request.

        Returns:
            The same `op`, with ``TYPE_SIGNATURE`` set to the graph signature.

        Raises:
            RuntimeError: If the register response status is not ``OK``.
        """
        graph_sig = self._generate_graph_sig(op.attr)
        # The library path is only needed to build the register request, so
        # an already registered signature needs no lookup here.
        if graph_sig not in self._object_manager:
            graph_lib_path = self._compile_lib_and_distribute(
                compile_graph_frame, graph_sig, op
            )

            # register graph
            op_def = op_def_pb2.OpDef(op=types_pb2.REGISTER_GRAPH_TYPE)
            op_def.attr[types_pb2.GRAPH_LIBRARY_PATH].CopyFrom(
                attr_value_pb2.AttrValue(s=graph_lib_path.encode("utf-8"))
            )
            op_def.attr[types_pb2.TYPE_SIGNATURE].CopyFrom(
                attr_value_pb2.AttrValue(s=graph_sig.encode("utf-8"))
            )
            op_def.attr[types_pb2.GRAPH_TYPE].CopyFrom(
                attr_value_pb2.AttrValue(
                    graph_type=op.attr[types_pb2.GRAPH_TYPE].graph_type
                )
            )
            dag_def = op_def_pb2.DagDef()
            dag_def.op.extend([op_def])
            register_request = message_pb2.RunStepRequest(
                session_id=session_id, dag_def=dag_def
            )
            register_response = self._analytical_engine_stub.RunStep(register_request)

            if register_response.status.code != error_codes_pb2.OK:
                raise RuntimeError("Error occur when register graph")
            self._object_manager.put(
                graph_sig,
                LibMeta(register_response.result, "graph_frame", graph_lib_path),
            )
        op.attr[types_pb2.TYPE_SIGNATURE].CopyFrom(
            attr_value_pb2.AttrValue(s=graph_sig.encode("utf-8"))
        )
        return op
示例#8
0
 def as_dag_def(self):
     """Return :class:`Dag` as a :class:`DagDef` proto buffer.

     The op definitions appear in the iteration order of
     ``self._ops_by_key``.
     """
     proto = op_def_pb2.DagDef()
     proto.op.extend(
         [registered.as_op_def() for registered in self._ops_by_key.values()]
     )
     return proto