예제 #1
0
    def _maybe_compile_app(self, op):
        """Point ``op`` at its application library, building it if the file is absent.

        The signature is taken from ``get_app_sha256(op.attr)`` and the expected
        library path from ``get_lib_path`` under the selected workspace. When no
        file exists at that path, ``self._compile_lib_and_distribute`` is invoked
        with ``compile_app`` and the path it returns must match the expected one.

        Args:
            op: Operation whose ``attr`` mapping is read to derive the signature
                and is updated in place under ``types_pb2.APP_LIBRARY_PATH``.

        Returns:
            The tuple ``(op, signature, library_path)``.

        Raises:
            RuntimeError: If the path returned by the compile step differs from
                the computed library path.
        """
        signature = get_app_sha256(op.attr)

        # Ops that carry a GAR attribute use the UDF workspace; everything else
        # falls back to the builtin workspace.
        default_workspace = self._builtin_workspace
        workspace = (
            self._udf_app_workspace if types_pb2.GAR in op.attr else default_workspace
        )

        lib_path = get_lib_path(os.path.join(workspace, signature), signature)
        already_built = os.path.isfile(lib_path)
        if not already_built:
            built_path = self._compile_lib_and_distribute(compile_app, signature, op)
            # The compile step must land the library exactly where we expect it.
            if lib_path != built_path:
                raise RuntimeError("Computed path not equal to compiled path.")

        path_slot = op.attr[types_pb2.APP_LIBRARY_PATH]
        path_slot.CopyFrom(attr_value_pb2.AttrValue(s=lib_path.encode("utf-8")))
        return op, signature, lib_path
예제 #2
0
    def _maybe_register_graph(self, op, session_id):
        """Make sure the graph type behind ``op`` is built and registered, then tag ``op``.

        Steps, in order:

        1. Derive the signature with ``get_graph_sha256(op.attr)`` and the
           expected library path under ``self._builtin_workspace``.
        2. If no file exists at that path, call
           ``self._compile_lib_and_distribute`` with ``compile_graph_frame`` and
           require that the returned path equals the expected one.
        3. If the signature is not in ``self._object_manager``, send a
           ``REGISTER_GRAPH_TYPE`` op through ``self._analytical_engine_stub.RunStep``
           and store a ``LibMeta`` for the signature on success.
        4. Write the signature into ``op.attr[types_pb2.TYPE_SIGNATURE]``.

        Args:
            op: Operation whose ``attr`` mapping supplies the signature and graph
                type; it is updated in place with the type signature.
            session_id: Value forwarded as ``session_id`` of the ``RunStepRequest``.

        Returns:
            The same ``op`` object, after its ``TYPE_SIGNATURE`` attribute is set.

        Raises:
            RuntimeError: If the compiled path differs from the computed path, or
                if the register response status code is not ``error_codes_pb2.OK``.
        """
        signature = get_graph_sha256(op.attr)
        workspace = self._builtin_workspace
        lib_path = get_lib_path(os.path.join(workspace, signature), signature)

        library_exists = os.path.isfile(lib_path)
        if not library_exists:
            built_path = self._compile_lib_and_distribute(
                compile_graph_frame, signature, op
            )
            # The compile step must land the library exactly where we expect it.
            if lib_path != built_path:
                raise RuntimeError("Computed path not equal to compiled path.")

        if signature not in self._object_manager:
            # Describe the registration as a single-op DAG.
            register_op = op_def_pb2.OpDef(op=types_pb2.REGISTER_GRAPH_TYPE)

            lib_path_slot = register_op.attr[types_pb2.GRAPH_LIBRARY_PATH]
            lib_path_slot.CopyFrom(attr_value_pb2.AttrValue(s=lib_path.encode("utf-8")))

            signature_slot = register_op.attr[types_pb2.TYPE_SIGNATURE]
            signature_slot.CopyFrom(
                attr_value_pb2.AttrValue(s=signature.encode("utf-8"))
            )

            # Carry the graph type over from the incoming op unchanged.
            graph_type_slot = register_op.attr[types_pb2.GRAPH_TYPE]
            graph_type_slot.CopyFrom(
                attr_value_pb2.AttrValue(
                    graph_type=op.attr[types_pb2.GRAPH_TYPE].graph_type
                )
            )

            dag = op_def_pb2.DagDef()
            dag.op.extend([register_op])
            request = message_pb2.RunStepRequest(session_id=session_id, dag_def=dag)
            response = self._analytical_engine_stub.RunStep(request)

            if response.status.code != error_codes_pb2.OK:
                raise RuntimeError("Error occur when register graph")

            # Remember the registration so later calls skip the RunStep round trip.
            self._object_manager.put(
                signature,
                LibMeta(response.result, "graph_frame", lib_path),
            )

        op_signature_slot = op.attr[types_pb2.TYPE_SIGNATURE]
        op_signature_slot.CopyFrom(
            attr_value_pb2.AttrValue(s=signature.encode("utf-8"))
        )
        return op