def test_initialize_class_missing_authenticate():
    """Initialize() without an ``authenticate`` callable must raise.

    The app is given an explicit name (as the other Initialize tests do)
    because recent Sanic releases refuse to build an unnamed application,
    which would error out before the assertion under test is reached.
    """
    app = Sanic("sanic-jwt-test")
    with pytest.raises(exceptions.AuthenticateNotImplemented):
        Initialize(app)
# Route/auth modules for this server. The auth callbacks below are looked up
# as attributes of the `app.routes` package.
# NOTE(review): `app.api` is not referenced in this block — presumably it is
# imported for its side effects (e.g. route registration); confirm before removing.
import app.routes
import app.api

# Extra class-based views handed to sanic-jwt via `class_views` below.
my_views = (
    ('/register', app.routes.routes.Register),
    ('/login', app.routes.routes.Login),
    ('/uirefresh', app.routes.routes.UIRefresh),
)

# Configure sanic-jwt on the `apfell` app (created outside this view):
# cookie-based tokens, scopes, custom endpoint paths and refresh tokens, with
# every callback supplied by `app.routes.authentication`.
# NOTE(review): the JWT signing secret below is a hard-coded literal checked
# into source; consider loading it from configuration or the environment.
Initialize(apfell,
           authenticate=app.routes.authentication.authenticate,
           retrieve_user=app.routes.authentication.retrieve_user,
           cookie_set=True,
           cookie_strict=False,
           cookie_access_token_name='access_token',
           cookie_refresh_token_name='refresh_token',
           cookie_httponly=True,
           scopes_enabled=True,
           add_scopes_to_payload=app.routes.authentication.add_scopes_to_payload,
           scopes_name='scope',
           secret='apfell_secret jwt for signing here',
           url_prefix='/',
           class_views=my_views,
           path_to_authenticate='/auth',
           path_to_retrieve_user='******',
           path_to_verify='/verify',
           path_to_refresh='/refresh',
           refresh_token_enabled=True,
           expiration_delta=14400,  # initial token expiration time (presumably seconds, i.e. 4h — confirm)
           store_refresh_token=app.routes.authentication.store_refresh_token,
           retrieve_refresh_token=app.routes.authentication.retrieve_refresh_token)
# NOTE(review): `blueprint` is created outside this view; the GET to "/test/"
# in the test below suggests it was built with url_prefix="/test" — confirm.
@blueprint.get("/", strict_slashes=True)
@protected()
def protected_hello_world(request):
    # Guarded by @protected(); the test below expects 401 without a token.
    return json({"message": "hello world"})


async def authenticate(request, *args, **kwargs):
    # Test stub: ignores the request and always authenticates as user_id 1.
    return {"user_id": 1}


app = Sanic()
app.blueprint(blueprint)

sanicjwt = Initialize(app, authenticate=authenticate)


def test_protected_blueprint():
    """A protected blueprint route rejects a request with no Authorization header."""
    _, response = app.test_client.get("/test/")

    assert response.status == 401
    assert response.json.get("exception") == "Unauthorized"
    assert "Authorization header not present." in response.json.get("reasons")

    # NOTE(review): the /auth response is never checked in this view — confirm
    # whether the test was meant to continue and use the returned token.
    _, response = app.test_client.post("/auth", json={
        "username": "******", "password": "******"
    })
def create_app(
    agent=None,
    cors_origins: Union[Text, List[Text]] = "*",
    auth_token: Optional[Text] = None,
    jwt_secret: Optional[Text] = None,
    jwt_method: Text = "HS256",
):
    """Create the Sanic application that serves the Rasa Core HTTP API.

    Args:
        agent: the agent used to answer requests. It may be ``None``; routes
            wrapped in ``ensure_loaded_agent`` are guarded by that decorator
            (defined elsewhere).
        cors_origins: origin(s) allowed by CORS. A falsy value is passed to
            CORS as ``""``.
        auth_token: token handed to ``requires_auth`` on every protected route.
        jwt_secret: secret/key for sanic-jwt. JWT support is only enabled when
            both this and ``jwt_method`` are truthy.
        jwt_method: JWT algorithm passed to sanic-jwt.

    Returns:
        The configured ``Sanic`` application, with ``app.agent`` set to ``agent``.
    """

    app = Sanic(__name__)
    app.config.RESPONSE_TIMEOUT = 60 * 60

    CORS(
        app, resources={r"/*": {"origins": cors_origins or ""}}, automatic_options=True
    )

    # Setup the Sanic-JWT extension
    if jwt_secret and jwt_method:
        # since we only want to check signatures, we don't actually care
        # about the JWT method and set the passed secret as either symmetric
        # or asymmetric key. jwt lib will choose the right one based on method
        app.config["USE_JWT"] = True
        Initialize(
            app,
            secret=jwt_secret,
            authenticate=authenticate,
            algorithm=jwt_method,
            user_id="username",
        )

    app.agent = agent

    def _int_query_arg(request: Request, name: Text, default: int) -> int:
        """Read query parameter ``name`` as an int, raising a 400 ErrorResponse if invalid."""
        value = request.raw_args.get(name, default)
        try:
            return int(value)
        except (TypeError, ValueError):
            raise ErrorResponse(
                400,
                "InvalidParameter",
                "Query parameter '{}' must be an integer, got '{}'.".format(
                    name, value
                ),
                {"parameter": name, "in": "query"},
            )

    @app.listener("after_server_start")
    async def warn_if_agent_is_unavailable(app, loop):
        if not app.agent or not app.agent.is_ready():
            logger.warning(
                "The loaded agent is not ready to be used yet "
                "(e.g. only the NLU interpreter is configured, "
                "but no Core model is loaded). This is NOT AN ISSUE "
                "some endpoints are not available until the agent "
                "is ready though."
            )

    @app.exception(NotFound)
    @app.exception(ErrorResponse)
    async def ignore_404s(request: Request, exception: ErrorResponse):
        if not isinstance(exception, ErrorResponse):
            # Sanic's NotFound has no `error_info`/`status`; wrap it so 404s
            # are reported in the same JSON shape instead of crashing here.
            exception = ErrorResponse(404, "NotFound", str(exception))
        return response.json(exception.error_info, status=exception.status)

    @app.get("/")
    async def hello(request: Request):
        """Check if the server is running and responds with the version."""
        return response.text("hello from Rasa: " + rasa.__version__)

    @app.get("/version")
    async def version(request: Request):
        """respond with the version number of the installed rasa core."""

        return response.json(
            {
                "version": rasa.__version__,
                "minimum_compatible_version": MINIMUM_COMPATIBLE_VERSION,
            }
        )

    # <sender_id> can be 'default' if there's only 1 client
    @app.post("/conversations/<sender_id>/execute")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def execute_action(request: Request, sender_id: Text):
        request_params = request.json

        # we'll accept both parameters to specify the actions name
        action_to_execute = request_params.get("name") or request_params.get("action")

        policy = request_params.get("policy", None)
        confidence = request_params.get("confidence", None)
        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)

        try:
            out = CollectingOutputChannel()
            await app.agent.execute_action(
                sender_id, action_to_execute, out, policy, confidence
            )

            # retrieve tracker and set to requested state
            tracker = app.agent.tracker_store.get_or_create_tracker(sender_id)
            state = tracker.current_state(verbosity)
            return response.json({"tracker": state, "messages": out.messages})

        except ValueError as e:
            raise ErrorResponse(400, "ValueError", e)
        except Exception as e:
            logger.error(
                "Encountered an exception while running action '{}'. "
                "Bot will continue, but the actions events are lost. "
                "Make sure to fix the exception in your custom "
                "code.".format(action_to_execute)
            )
            logger.debug(e, exc_info=True)
            raise ErrorResponse(
                500, "ValueError", "Server failure. Error: {}".format(e)
            )

    @app.post("/conversations/<sender_id>/tracker/events")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def append_event(request: Request, sender_id: Text):
        """Append a list of events to the state of a conversation"""

        request_params = request.json
        evt = Event.from_parameters(request_params)
        tracker = app.agent.tracker_store.get_or_create_tracker(sender_id)
        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)

        if evt:
            tracker.update(evt)
            app.agent.tracker_store.save(tracker)
            return response.json(tracker.current_state(verbosity))
        else:
            logger.warning(
                "Append event called, but could not extract a "
                "valid event. Request JSON: {}".format(request_params)
            )
            raise ErrorResponse(
                400,
                "InvalidParameter",
                "Couldn't extract a proper event from the request body.",
                {"parameter": "", "in": "body"},
            )

    @app.put("/conversations/<sender_id>/tracker/events")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def replace_events(request: Request, sender_id: Text):
        """Use a list of events to set a conversations tracker to a state."""

        request_params = request.json
        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)

        tracker = DialogueStateTracker.from_dict(
            sender_id, request_params, app.agent.domain.slots
        )
        # will override an existing tracker with the same id!
        app.agent.tracker_store.save(tracker)
        return response.json(tracker.current_state(verbosity))

    @app.get("/conversations")
    @requires_auth(app, auth_token)
    async def list_trackers(request: Request):
        if app.agent.tracker_store:
            keys = list(app.agent.tracker_store.keys())
        else:
            keys = []

        return response.json(keys)

    @app.get("/conversations/<sender_id>/tracker")
    @requires_auth(app, auth_token)
    async def retrieve_tracker(request: Request, sender_id: Text):
        """Get a dump of a conversation's tracker including its events."""

        if not app.agent.tracker_store:
            raise ErrorResponse(
                503,
                "NoTrackerStore",
                "No tracker store available. Make sure to "
                "configure a tracker store when starting "
                "the server.",
            )

        # parameters
        default_verbosity = EventVerbosity.AFTER_RESTART

        # this is for backwards compatibility
        if "ignore_restarts" in request.raw_args:
            ignore_restarts = utils.bool_arg(request, "ignore_restarts", default=False)
            if ignore_restarts:
                default_verbosity = EventVerbosity.ALL

        if "events" in request.raw_args:
            include_events = utils.bool_arg(request, "events", default=True)
            if not include_events:
                default_verbosity = EventVerbosity.NONE

        verbosity = event_verbosity_parameter(request, default_verbosity)

        # retrieve tracker and set to requested state
        tracker = app.agent.tracker_store.get_or_create_tracker(sender_id)
        if not tracker:
            raise ErrorResponse(
                503,
                "NoDomain",
                "Could not retrieve tracker. Most likely "
                "because there is no domain set on the agent.",
            )

        until_time = utils.float_arg(request, "until")
        if until_time is not None:
            tracker = tracker.travel_back_in_time(until_time)

        # dump and return tracker
        state = tracker.current_state(verbosity)
        return response.json(state)

    @app.get("/conversations/<sender_id>/story")
    @requires_auth(app, auth_token)
    async def retrieve_story(request: Request, sender_id: Text):
        """Get an end-to-end story corresponding to this conversation."""

        if not app.agent.tracker_store:
            raise ErrorResponse(
                503,
                "NoTrackerStore",
                "No tracker store available. Make sure to "
                "configure "
                "a tracker store when starting the server.",
            )

        # retrieve tracker and set to requested state
        tracker = app.agent.tracker_store.get_or_create_tracker(sender_id)
        if not tracker:
            raise ErrorResponse(
                503,
                "NoDomain",
                "Could not retrieve tracker. Most likely "
                "because there is no domain set on the agent.",
            )

        until_time = utils.float_arg(request, "until")
        if until_time is not None:
            tracker = tracker.travel_back_in_time(until_time)

        # dump and return tracker
        state = tracker.export_stories(e2e=True)
        return response.text(state)

    @app.route("/conversations/<sender_id>/respond", methods=["GET", "POST"])
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def respond(request: Request, sender_id: Text):
        request_params = request_parameters(request)

        if "query" in request_params:
            message = request_params["query"]
        elif "q" in request_params:
            message = request_params["q"]
        else:
            raise ErrorResponse(
                400,
                "InvalidParameter",
                "Missing the message parameter.",
                {"parameter": "query", "in": "query"},
            )

        try:
            # Set the output channel
            out = CollectingOutputChannel()
            # Fetches the appropriate bot response in a json format
            responses = await app.agent.handle_text(
                message, output_channel=out, sender_id=sender_id
            )
            return response.json(responses)

        except Exception as e:
            logger.exception("Caught an exception during respond.")
            raise ErrorResponse(
                500, "ActionException", "Server failure. Error: {}".format(e)
            )

    @app.post("/conversations/<sender_id>/predict")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def predict(request: Request, sender_id: Text):
        try:
            # Fetches the appropriate bot response in a json format
            responses = app.agent.predict_next(sender_id)
            # highest score first; ties broken alphabetically by action name
            responses["scores"] = sorted(
                responses["scores"], key=lambda k: (-k["score"], k["action"])
            )
            return response.json(responses)

        except Exception as e:
            logger.exception("Caught an exception during prediction.")
            raise ErrorResponse(
                500, "PredictionException", "Server failure. Error: {}".format(e)
            )

    @app.post("/conversations/<sender_id>/messages")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def log_message(request: Request, sender_id: Text):
        request_params = request.json
        try:
            message = request_params["message"]
        except KeyError:
            message = request_params.get("text")

        sender = request_params.get("sender")
        parse_data = request_params.get("parse_data")
        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)

        # TODO: implement properly for agent / bot
        if sender != "user":
            raise ErrorResponse(
                500,
                "NotSupported",
                "Currently, only user messages can be passed "
                "to this endpoint. Messages of sender '{}' "
                "cannot be handled.".format(sender),
                {"parameter": "sender", "in": "body"},
            )

        try:
            usermsg = UserMessage(message, None, sender_id, parse_data)
            tracker = await app.agent.log_message(usermsg)
            return response.json(tracker.current_state(verbosity))

        except Exception as e:
            logger.exception("Caught an exception while logging message.")
            raise ErrorResponse(
                500, "MessageException", "Server failure. Error: {}".format(e)
            )

    @app.post("/model")
    @requires_auth(app, auth_token)
    async def load_model(request: Request):
        """Loads a zipped model, replacing the existing one."""

        if "model" not in request.files:
            # model file is missing
            raise ErrorResponse(
                400,
                "InvalidParameter",
                "You did not supply a model as part of your request.",
                {"parameter": "model", "in": "body"},
            )

        # NOTE(review): depending on the Sanic version, `request.files[...]`
        # may yield a list of File tuples without a `.save` method — confirm.
        model_file = request.files["model"]

        logger.info("Received new model through REST interface.")
        zipped_path = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
        zipped_path.close()
        model_directory = tempfile.mkdtemp()

        try:
            model_file.save(zipped_path.name)
            logger.debug("Downloaded model to {}".format(zipped_path.name))

            with zipfile.ZipFile(zipped_path.name, "r") as zip_ref:
                zip_ref.extractall(model_directory)
        finally:
            # the archive is only needed until it has been extracted
            os.remove(zipped_path.name)
        logger.debug("Unzipped model to {}".format(os.path.abspath(model_directory)))

        domain_path = os.path.join(os.path.abspath(model_directory), "domain.yml")
        domain = Domain.load(domain_path)
        ensemble = PolicyEnsemble.load(model_directory)
        app.agent.update_model(domain, ensemble, None)
        logger.debug("Finished loading new agent.")
        return response.text("", 204)

    @app.post("/evaluate")
    @requires_auth(app, auth_token)
    async def evaluate_stories(request: Request):
        """Evaluate stories against the currently loaded model."""
        import rasa.nlu.utils

        tmp_file = rasa.nlu.utils.create_temporary_file(request.body, mode="w+b")
        use_e2e = utils.bool_arg(request, "e2e", default=False)
        try:
            evaluation = await test(tmp_file, app.agent, use_e2e=use_e2e)
            return response.json(evaluation)
        except ValueError as e:
            raise ErrorResponse(
                400,
                "FailedEvaluation",
                "Evaluation could not be created. Error: {}".format(e),
            )

    @app.post("/intentEvaluation")
    @requires_auth(app, auth_token)
    async def evaluate_intents(request: Request):
        """Evaluate intents against a Rasa NLU model."""

        # create `tmpdir` and cast as str for py3.5 compatibility
        tmpdir = str(tempfile.mkdtemp())

        zipped_model_path = os.path.join(tmpdir, "model.tar.gz")
        write_request_body_to_file(request, zipped_model_path)

        model_path, nlu_files = await nlu_model_and_evaluation_files_from_archive(
            zipped_model_path, tmpdir
        )

        # ErrorResponse must be *raised* so the registered exception handler
        # turns it into a JSON error response; returning it is not a response.
        if len(nlu_files) != 1:
            raise ErrorResponse(
                400,
                "FailedIntentEvaluation",
                "NLU evaluation file could not be found. "
                "This endpoint requires a single file ending "
                "on `.md` or `.json`.",
            )

        data_path = os.path.abspath(nlu_files[0])
        try:
            evaluation = run_evaluation(data_path, model_path)
            return response.json(evaluation)
        except ValueError as e:
            raise ErrorResponse(
                400,
                "FailedIntentEvaluation",
                "Evaluation could not be created. Error: {}".format(e),
            )

    @app.post("/jobs")
    @requires_auth(app, auth_token)
    async def train_stack(request: Request):
        """Train a Rasa Stack model."""

        from rasa.train import train_async

        rjs = request.json

        # create a temporary directory to store config, domain and
        # training data
        temp_dir = tempfile.mkdtemp()

        try:
            config_path = os.path.join(temp_dir, "config.yml")
            dump_obj_as_str_to_file(config_path, rjs["config"])

            domain_path = os.path.join(temp_dir, "domain.yml")
            dump_obj_as_str_to_file(domain_path, rjs["domain"])

            nlu_path = os.path.join(temp_dir, "nlu.md")
            dump_obj_as_str_to_file(nlu_path, rjs["nlu"])

            stories_path = os.path.join(temp_dir, "stories.md")
            dump_obj_as_str_to_file(stories_path, rjs["stories"])
        except KeyError:
            raise ErrorResponse(
                400,
                "TrainingError",
                "The Rasa Stack training request is "
                "missing a key. The required keys are "
                "`config`, `domain`, `nlu` and `stories`.",
            )

        # the model will be saved to the same temporary dir
        # unless `out` was specified in the request
        # NOTE(review): `out` is a client-controlled filesystem path on the
        # server; consider restricting it to a trusted directory.
        try:
            model_path = await train_async(
                domain=domain_path,
                config=config_path,
                training_files=[nlu_path, stories_path],
                output=rjs.get("out", temp_dir),
                force_training=rjs.get("force", False),
            )

            return await response.file(model_path)
        except Exception as e:
            raise ErrorResponse(
                400,
                "TrainingError",
                "Rasa Stack model could not be trained. Error: {}".format(e),
            )

    @app.get("/domain")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def get_domain(request: Request):
        """Get current domain in yaml or json format."""

        accepts = request.headers.get("Accept", default="application/json")
        if accepts.endswith("json"):
            domain = app.agent.domain.as_dict()
            return response.json(domain)
        elif accepts.endswith(("yml", "yaml")):
            domain_yaml = app.agent.domain.as_yaml()
            return response.text(
                domain_yaml, status=200, content_type="application/x-yml"
            )
        else:
            raise ErrorResponse(
                406,
                "InvalidHeader",
                "Invalid Accept header. Domain can be "
                "provided as "
                'json ("Accept: application/json") or '
                'yml ("Accept: application/x-yml"). '
                "Make sure you've set the appropriate Accept "
                "header.",
            )

    @app.post("/finetune")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def continue_training(request: Request):
        # query-string values arrive as strings; training needs real ints
        epochs = _int_query_arg(request, "epochs", 30)
        batch_size = _int_query_arg(request, "batch_size", 5)
        request_params = request.json
        sender_id = UserMessage.DEFAULT_SENDER_ID

        try:
            tracker = DialogueStateTracker.from_dict(
                sender_id, request_params, app.agent.domain.slots
            )
        except Exception as e:
            raise ErrorResponse(
                400,
                "InvalidParameter",
                "Supplied events are not valid. {}".format(e),
                {"parameter": "", "in": "body"},
            )

        try:
            # Continue training the loaded policies on the supplied tracker
            app.agent.continue_training([tracker], epochs=epochs, batch_size=batch_size)
            return response.text("", 204)

        except Exception as e:
            logger.exception("Caught an exception during training.")
            raise ErrorResponse(
                500, "TrainingException", "Server failure. Error: {}".format(e)
            )

    @app.get("/status")
    @requires_auth(app, auth_token)
    async def status(request: Request):
        return response.json(
            {
                "model_fingerprint": app.agent.fingerprint if app.agent else None,
                "is_ready": app.agent.is_ready() if app.agent else False,
            }
        )

    @app.post("/predict")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def tracker_predict(request: Request):
        """ Given a list of events, predicts the next action"""

        sender_id = UserMessage.DEFAULT_SENDER_ID
        request_params = request.json
        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)

        try:
            tracker = DialogueStateTracker.from_dict(
                sender_id, request_params, app.agent.domain.slots
            )
        except Exception as e:
            raise ErrorResponse(
                400,
                "InvalidParameter",
                "Supplied events are not valid. {}".format(e),
                {"parameter": "", "in": "body"},
            )

        policy_ensemble = app.agent.policy_ensemble
        probabilities, policy = policy_ensemble.probabilities_using_best_policy(
            tracker, app.agent.domain
        )

        scores = [
            {"action": a, "score": p}
            for a, p in zip(app.agent.domain.action_names, probabilities)
        ]

        return response.json(
            {
                "scores": scores,
                "policy": policy,
                "tracker": tracker.current_state(verbosity),
            }
        )

    @app.post("/parse")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def parse(request: Request):
        request_params = request.json
        parse_data = await app.agent.interpreter.parse(request_params.get("q"))
        return response.json(parse_data)

    return app
if not username or not password: raise exceptions.AuthenticationFailed("Missing username or password.") user = username_table.get(username, None) if user is None: raise exceptions.AuthenticationFailed("User not found.") if password != user.password: raise exceptions.AuthenticationFailed("Password is incorrect.") return user app = Sanic(__name__) sanic_jwt = Initialize(app, authenticate=authenticate, retrieve_user=retrieve_user) @app.route("/hello") async def test(request): return json({"hello": "world"}) @app.route("/protected") @protected() async def protected(request): return json({"protected": True}) @app.route("/protected_user")
def test_decorators_override_configuration_defaults():
    """Arguments given to the decorators override Initialize()-level defaults.

    * ``@protected(..., verify_exp=False)`` must still accept a token after it
      has expired, so that request is made while time is frozen past ``exp``.
    * ``@scoped(..., authorization_header="foobar")`` must read the token from
      the ``Foobar`` header instead of ``Authorization``.
    """
    blueprint = Blueprint("Test")
    app = Sanic("sanic-jwt-test")
    sanicjwt = Initialize(
        blueprint,
        app=app,
        authenticate=authenticate,
        scopes_enabled=True,
        retrieve_user=authenticate,
        add_scopes_to_payload=my_scope_extender,
    )

    @blueprint.get("/protected")
    @protected(blueprint, verify_exp=False)
    def protected_hello_world(request):
        return json({"message": "hello world"})

    @blueprint.route("/scoped")
    @sanicjwt.scoped("user", authorization_header="foobar")
    async def scoped_endpoint(request):
        return json({"scoped": True})

    app.blueprint(blueprint, url_prefix="/test")

    _, response = app.test_client.post(
        "/test/auth", json={"username": "******", "password": "******"}
    )

    access_token = response.json.get(sanicjwt.config.access_token_name(), None)

    # PyJWT expects `algorithms` to be a list of allowed algorithm names.
    payload = jwt.decode(
        access_token,
        sanicjwt.config.secret(),
        algorithms=[sanicjwt.config.algorithm()],
    )

    assert "exp" in payload
    exp = datetime.utcfromtimestamp(payload["exp"])

    with freeze_time(datetime.utcnow() + timedelta(seconds=(60 * 35))):
        assert isinstance(exp, datetime)
        assert datetime.utcnow() > exp

        # The token is expired at this frozen instant; verify_exp=False on the
        # decorator must let the request through anyway.
        _, response = app.test_client.get(
            "/test/protected",
            headers={"Authorization": "Bearer {}".format(access_token)},
        )

        assert response.status == 200

    # Wrong header name: the custom "foobar" header is missing.
    _, response = app.test_client.get(
        "/test/scoped", headers={"FudgeBar": "Bearer {}".format(access_token)}
    )

    assert response.status == 401
    assert "Authorization header not present." in response.json.get("reasons")
    assert response.json.get("exception") == "Unauthorized"

    # Right header name, malformed scheme.
    _, response = app.test_client.get(
        "/test/scoped", headers={"Foobar": "Bear {}".format(access_token)}
    )

    assert response.status == 401
    assert "Authorization header is invalid." in response.json.get("reasons")
    assert response.json.get("exception") == "Unauthorized"

    # Right header name and scheme.
    _, response = app.test_client.get(
        "/test/scoped", headers={"Foobar": "Bearer {}".format(access_token)}
    )

    assert response.status == 200
def test_deprecated_handler_payload_extend():
    """Initialize() raises InvalidConfiguration for the deprecated setting.

    SANIC_JWT_HANDLER_PAYLOAD_EXTEND is set on the app config before
    Initialize() is called.
    """
    sanic_app = Sanic("sanic-jwt-test")

    def legacy_extender(*args, **kwargs):
        return {}

    sanic_app.config.SANIC_JWT_HANDLER_PAYLOAD_EXTEND = legacy_extender

    with pytest.raises(exceptions.InvalidConfiguration):
        Initialize(sanic_app, authenticate=lambda: True)
def app_with_async_methods():
    """Yield a ``(sanic_app, sanicjwt)`` pair wired with async auth callbacks.

    Users are looked up in a ``users`` collection defined outside this
    function (not visible here). Refresh tokens live in a dict local to this
    generator. The ``assert user_id == "0x1"`` checks make the fixture fail
    loudly if sanic-jwt hands a callback an unexpected user id.
    """
    cache = {}

    def _refresh_token_key(user_id):
        # Single definition of the cache-key format used by store/retrieve.
        return "refresh_token_{user_id}".format(user_id=user_id)

    async def authenticate(request, *args, **kwargs):
        username = request.json.get("username", None)
        password = request.json.get("password", None)

        if not username or not password:
            raise exceptions.AuthenticationFailed(
                "Missing username or password.")

        user = next((u for u in users if u.username == username), None)

        if user is None:
            raise exceptions.AuthenticationFailed("User not found.")

        if password != user.password:
            raise exceptions.AuthenticationFailed("Password is incorrect.")

        return user

    async def store_refresh_token(user_id, refresh_token, *args, **kwargs):
        assert user_id == "0x1"
        cache[_refresh_token_key(user_id)] = refresh_token

    async def retrieve_refresh_token(user_id, *args, **kwargs):
        assert user_id == "0x1"
        return cache.get(_refresh_token_key(user_id), None)

    async def retrieve_user(request, payload, *args, **kwargs):
        if not payload:
            return None

        user_id = payload.get("user_id", None)
        assert user_id == "0x1"
        if user_id is None:
            return None

        # user_id is a hex string (asserted "0x1" above); parse it once
        # instead of on every iteration of the lookup.
        numeric_id = int(user_id, base=16)
        return next((u for u in users if u.id == numeric_id), None)

    secret = str(binascii.hexlify(os.urandom(32)), "utf-8")

    sanic_app = Sanic()
    sanicjwt = Initialize(
        sanic_app,
        authenticate=authenticate,
        store_refresh_token=store_refresh_token,
        retrieve_refresh_token=retrieve_refresh_token,
        retrieve_user=retrieve_user,
        refresh_token_enabled=True,
        secret=secret,
    )

    @sanic_app.route("/")
    async def helloworld(request):
        return json({"hello": "world"})

    @sanic_app.route("/protected")
    @protected()
    async def protected_request(request):
        return json({"protected": True})

    yield (sanic_app, sanicjwt)
if not username or not password: raise exceptions.AuthenticationFailed("Missing username or password.") user = username_table.get(username, None) if user is None: raise exceptions.AuthenticationFailed("User not found.") if password != user.password: raise exceptions.AuthenticationFailed("Password is incorrect.") return user sanic_app = Sanic() sanic_jwt = Initialize(sanic_app, authenticate=authenticate) class PublicView(HTTPMethodView): def get(self, request): return json({"hello": "world"}) class ProtectedView(HTTPMethodView): decorators = [protected()] async def get(self, request): return json({"protected": True}) class PartiallyProtectedView(HTTPMethodView):
# Build the Sanic application from the `Settings` object, register the
# extension blueprints and the API blueprint, enable sanic-jwt, and start the
# server when this file is executed directly.
# NOTE(review): `Settings`, `ext_exceptions`, `ext_middlewares`, `app_name`,
# `swagger_blueprint`, `authenticate` and `args` come from outside this view;
# `args` presumably holds parsed command-line options — confirm.
app = Sanic(__name__)
app.config.from_object(Settings())

# Install extensions
app.blueprint(ext_exceptions)
app.blueprint(ext_middlewares)

# API version segment of the URL prefix
version = 'v1/'
# Name of the API
name = 'my_api'

# Install apps under "<version><name>", i.e. "v1/my_api".
# NOTE(review): the prefix has no leading '/'; confirm Sanic normalizes it.
app.blueprint(app_name, url_prefix=version+name)
app.blueprint(swagger_blueprint)  # swagger

# JWT authentication via sanic-jwt
Initialize(app, authenticate=authenticate)

# Only start the server when this module is run directly by the interpreter
# (not when it is imported).
# ref: http://sanic.readthedocs.io/en/latest/sanic/deploying.html#running-via-command
if __name__ == '__main__':
    app.run(
        host=args.host,
        port=args.port,
        workers=args.workers,
        debug=args.debug,
        access_log=args.accesslog
    )
cache[key] = refresh_token async def retrieve_refresh_token(self, user_id, *args, **kwargs): key = "refresh_token_{user_id}".format(user_id=user_id) token = cache.get(key, None) return token async def retrieve_user(self, request, payload, *args, **kwargs): return {"user_id": 1} app = Sanic("sanic-jwt-test") sanicjwt = Initialize( blueprint, app=app, authentication_class=MyAuthentication, refresh_token_enabled=True, ) app.blueprint(blueprint, url_prefix="/test") def test_protected_blueprint(): _, response = app.test_client.get("/test/") assert response.status == 401 assert response.json.get("exception") == "Unauthorized" assert "Authorization header not present." in response.json.get("reasons") _, response = app.test_client.post("/test/auth", json={
# NOTE(review): JWT_SECRET is a hard-coded literal, and Initialize() below is
# not passed `secret=`; confirm sanic-jwt actually reads this key (its own
# config keys are typically SANIC_JWT_-prefixed).
app.config.JWT_SECRET = "NextYearInDublin"
# A fresh Fernet key is generated on every process start.
app.config.CSRF_SECRET = Fernet.generate_key()


def authenticate(request: Request) -> t.Dict[str, int]:
    # Example stub: every request authenticates as user_id 1.
    return {"user_id": 1}


def scope_extender(*args) -> str:
    # Example stub: every user gets the same fixed scope string.
    return "top_secret:read:write"


sanicjwt = Initialize(
    app,
    authenticate=authenticate,
    add_scopes_to_payload=scope_extender,
    cookie_set=True,
    cookie_split=True,
)


@app.exception(Forbidden, Unauthorized)
async def handle_exceptions(request: Request, exception: Exception) -> HTTPResponse:
    """Return Forbidden/Unauthorized errors as JSON rather than Sanic's default HTML.

    The body carries the exception class name and message; the HTTP status is
    taken from the exception's ``status_code``.
    """
    return json(
        {"error": exception.__class__.__name__, "message": str(exception)},
        status=exception.status_code,
    )


# Registers the request middleware defined immediately after this block.
@app.middleware("request")
async def retrieve_refresh_token(request, user_id):
    """Return the refresh token sent in the request body if it verifies, else None.

    Refresh tokens are never stored server-side (``store_refresh_token`` is a
    no-op), so the presented token is validated on the fly.
    """
    candidate = request.json["refresh_token"]
    is_valid = await app.ctx.auth.verify_token(candidate)
    return candidate if is_valid else None


def store_refresh_token(*args, **kwargs):
    """Intentionally a no-op: nothing is persisted for refresh tokens."""


def retrieve_user(request, payload):
    """Look up the user referenced by the token payload's ``user_id``."""
    return userid_table.get(payload["user_id"])


app = Sanic("myapp")

jwt_options = dict(
    authenticate=authenticate,
    generate_refresh_token=generate_refresh_token,
    refresh_token_enabled=True,
    retrieve_user=retrieve_user,
    retrieve_refresh_token=retrieve_refresh_token,
    store_refresh_token=store_refresh_token,
)
Initialize(app, **jwt_options)

if __name__ == "__main__":
    app.run(host="127.0.0.1", port=8888, debug=True)
def test_initialize_class():
    """Initialize() accepts an app plus an ``authenticate`` callable.

    The test passes as long as construction does not raise. The app is
    named explicitly (as the other Initialize tests do) because recent Sanic
    releases refuse to build an unnamed application.
    """
    app = Sanic("sanic-jwt-test")
    Initialize(app, authenticate=lambda: True)
return user class User2Claim(Claim): key = "custom_user_id" def setup(self, payload, user): return user.user_id def verify(self, value): return value == 2 custom_claims = [User2Claim] app = Sanic(__name__) sanicjwt = Initialize(app, authenticate=authenticate, custom_claims=custom_claims, debug=True) @app.route("/protected") @protected() async def protected(request): return json({"protected": True}) if __name__ == "__main__": app.run(host="127.0.0.1", port=8888, auto_reload=True)
if user_id: async with db.cursor() as cur: await cur.execute("SELECT is_admin FROM user WHERE id=%s", [user_id]) res = await cur.fetchone() if res is not None: is_admin, = res return User(user_id=user_id, scope=('admin' if is_admin else 'user')) sanic_app = Sanic() CORS(sanic_app) Initialize(sanic_app, authenticate=authenticate, secret=SECRET_KEY, add_scopes_to_payload=lambda user: user.scope, retrieve_user=retrieve_user) @sanic_app.get("/") async def index(request): return await file('index.html') class UserView(HTTPMethodView): @staticmethod @inject_user() @scoped(['user', 'admin'], require_all=False) async def get(request, user): id = request.raw_args.get('id', None)
return json({"user": id}) @blueprint.route("/scoped_empty") @scoped("something", initialized_on=blueprint) async def scoped(request): return json({"scoped": True}) async def authenticate(request, *args, **kwargs): return {"user_id": 1} app = Sanic() sanicjwt = Initialize(blueprint, app=app, authenticate=authenticate) app.blueprint(blueprint, url_prefix="/test") def test_protected_blueprint(): _, response = app.test_client.get("/test/") assert response.status == 401 _, response = app.test_client.post("/test/auth", json={ "username": "******", "password": "******" })
def app_with_sync_methods(users):
    """Yield a ``(sanic_app, sanicjwt)`` pair wired with synchronous callbacks.

    ``users`` is searched for login and for user retrieval. Refresh tokens are
    kept in a dict local to this generator.
    """
    token_store = {}

    def authenticate(request, *args, **kwargs):
        username = request.json.get('username', None)
        password = request.json.get('password', None)

        if not (username and password):
            raise exceptions.AuthenticationFailed(
                "Missing username or password.")

        matched = next((u for u in users if u.username == username), None)
        if matched is None:
            raise exceptions.AuthenticationFailed("User not found.")
        if password != matched.password:
            raise exceptions.AuthenticationFailed("Password is incorrect.")
        return matched

    def _key_for(user_id):
        # Cache-key format shared by store/retrieve below.
        return 'refresh_token_{user_id}'.format(user_id=user_id)

    def store_refresh_token(user_id, refresh_token, *args, **kwargs):
        token_store[_key_for(user_id)] = refresh_token

    def retrieve_refresh_token(user_id, *args, **kwargs):
        return token_store.get(_key_for(user_id), None)

    def retrieve_user(request, payload, *args, **kwargs):
        if not payload:
            return None
        user_id = payload.get('user_id', None)
        if user_id is None:
            return None
        return next((u for u in users if u.user_id == user_id), None)

    sanic_app = Sanic()
    sanicjwt = Initialize(sanic_app,
                          authenticate=authenticate,
                          store_refresh_token=store_refresh_token,
                          retrieve_refresh_token=retrieve_refresh_token,
                          retrieve_user=retrieve_user)

    # NOTE(review): these settings are assigned after Initialize(); confirm the
    # installed sanic-jwt version reads them lazily rather than at init time.
    sanic_app.config.SANIC_JWT_REFRESH_TOKEN_ENABLED = True
    sanic_app.config.SANIC_JWT_SECRET = str(
        binascii.hexlify(os.urandom(32)), 'utf-8')

    @sanic_app.route("/")
    async def helloworld(request):
        return json({"hello": "world"})

    @sanic_app.route("/protected")
    @protected()
    async def protected_request(request):
        return json({"protected": True})

    yield (sanic_app, sanicjwt)
def test_configuration_initialize_class_default():
    """Building an app and calling Initialize() with defaults must not raise.

    Any exception is reported through pytest.fail with its message.
    """
    try:
        sanic_app = Sanic("sanic-jwt-test")
        Initialize(sanic_app, authenticate=lambda: True)
    except Exception as exc:
        failure_message = "Raised exception: {}".format(exc)
        pytest.fail(failure_message)
user = username_table.get(username, None) if user is None: raise exceptions.AuthenticationFailed("User not found.") if password != user.password: raise exceptions.AuthenticationFailed("Password is incorrect.") return user def user2(payload): return payload.get("user_id") == 2 extra_verifications = [user2] app = Sanic() Initialize(app, authenticate=authenticate, extra_verifications=extra_verifications) @app.route("/protected") @protected() async def protected(request): return json({"protected": True}) if __name__ == "__main__": app.run(host="127.0.0.1", port=8888, auto_reload=True)
import asyncio
import math

from sanic import Sanic
from sanic_jwt import Claim, exceptions, Initialize


def is_prime(n):
    """Return True if the integer ``n`` is prime.

    Values below 2 (0, 1 and negatives) are not prime. Previously they fell
    through to the trial-division step: 0 and 1 were reported as prime, and
    negative values raised ``ValueError`` from ``math.sqrt``.
    """
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    # only odd divisors up to sqrt(n) need checking
    return all(n % i for i in range(3, int(math.sqrt(n)) + 1, 2))


app = Sanic()
sanicjwt = Initialize(app, auth_mode=False)

user7 = {"user_id": 7}
user8 = {"user_id": 8}


class UserIsPrime(Claim):
    """Custom claim: the token is valid only if the user's id is prime."""

    key = "user_id_checker"

    def setup(self, payload, user):
        # value stored under `key` in the payload when the token is created
        return user.get("user_id")

    def verify(self, value):
        return is_prime(value)
user = username_table.get(username, None) if user is None: raise exceptions.AuthenticationFailed("User not found.") if password != user.password: raise exceptions.AuthenticationFailed("Password is incorrect.") return user secret = str(binascii.hexlify(os.urandom(32)), "utf-8") app = Sanic(__name__) sanic_jwt = Initialize(app, authenticate=authenticate, expiration_delta=(60 * 60), secret=secret, retrieve_user=retrieve_user) @app.route("/hello") async def test(request): return json({"hello": "world"}) @app.route("/protected") @protected() async def protected(request): return json({"protected": True})
def create_app(
    agent: Optional["Agent"] = None,
    cors_origins: Union[Text, List[Text], None] = "*",
    auth_token: Optional[Text] = None,
    response_timeout: int = DEFAULT_RESPONSE_TIMEOUT,
    jwt_secret: Optional[Text] = None,
    jwt_method: Text = "HS256",
    endpoints: Optional[AvailableEndpoints] = None,
):
    """Create the Sanic application that serves the Rasa HTTP API.

    Args:
        agent: Agent to serve. Stored as ``app.agent``; replaced by the
            ``PUT /model`` and ``DELETE /model`` endpoints.
        cors_origins: Allowed CORS origins, forwarded to ``configure_cors``.
        auth_token: Token forwarded to ``requires_auth`` on protected routes.
        response_timeout: Value assigned to ``app.config.RESPONSE_TIMEOUT``.
        jwt_secret: Secret for the Sanic-JWT extension. JWT support is only
            enabled when both ``jwt_secret`` and ``jwt_method`` are truthy.
        jwt_method: JWT algorithm passed to Sanic-JWT.
        endpoints: Endpoint configuration passed to ``_load_agent`` when a new
            model is loaded through ``PUT /model``.

    Returns:
        The configured Sanic application.
    """
    app = Sanic(__name__)
    app.config.RESPONSE_TIMEOUT = response_timeout
    configure_cors(app, cors_origins)

    # Setup the Sanic-JWT extension
    if jwt_secret and jwt_method:
        # since we only want to check signatures, we don't actually care
        # about the JWT method and set the passed secret as either symmetric
        # or asymmetric key. jwt lib will choose the right one based on method
        app.config["USE_JWT"] = True
        Initialize(
            app,
            secret=jwt_secret,
            authenticate=authenticate,
            algorithm=jwt_method,
            user_id="username",
        )

    app.agent = agent
    # Initialize shared object of type unsigned int for tracking
    # the number of active training processes
    app.active_training_processes = multiprocessing.Value("I", 0)

    @app.exception(ErrorResponse)
    async def handle_error_response(request: Request, exception: ErrorResponse):
        """Serialize an ``ErrorResponse`` into a JSON response with its status."""
        return response.json(exception.error_info, status=exception.status)

    add_root_route(app)

    @app.get("/version")
    async def version(request: Request):
        """Respond with the version number of the installed Rasa."""
        return response.json(
            {
                "version": rasa.__version__,
                "minimum_compatible_version": MINIMUM_COMPATIBLE_VERSION,
            }
        )

    @app.get("/status")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def status(request: Request):
        """Respond with the model name and the fingerprint of that model."""
        return response.json(
            {
                "model_file": app.agent.path_to_model_archive
                or app.agent.model_directory,
                "fingerprint": model.fingerprint_from_path(app.agent.model_directory),
                "num_active_training_jobs": app.active_training_processes.value,
            }
        )

    @app.get("/conversations/<conversation_id:path>/tracker")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def retrieve_tracker(request: Request, conversation_id: Text):
        """Get a dump of a conversation's tracker including its events."""
        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)
        until_time = rasa.utils.endpoints.float_arg(request, "until")

        tracker = await get_tracker(app.agent.create_processor(), conversation_id)

        try:
            if until_time is not None:
                tracker = tracker.travel_back_in_time(until_time)

            state = tracker.current_state(verbosity)
            return response.json(state)
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500, "ConversationError", f"An unexpected error occurred. Error: {e}"
            )

    @app.post("/conversations/<conversation_id:path>/tracker/events")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def append_events(request: Request, conversation_id: Text):
        """Append a list of events to the state of a conversation."""
        validate_request_body(
            request,
            "You must provide events in the request body in order to append them "
            "to the state of a conversation.",
        )

        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)

        try:
            async with app.agent.lock_store.lock(conversation_id):
                processor = app.agent.create_processor()
                tracker = processor.get_tracker(conversation_id)
                _validate_tracker(tracker, conversation_id)

                events = _get_events_from_request_body(request)

                for event in events:
                    tracker.update(event, app.agent.domain)
                app.agent.tracker_store.save(tracker)

            return response.json(tracker.current_state(verbosity))
        except ErrorResponse:
            # Deliberate client errors (e.g. the 400 raised by
            # `_get_events_from_request_body`) must not be turned into a 500.
            raise
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500, "ConversationError", f"An unexpected error occurred. Error: {e}"
            )

    def _get_events_from_request_body(request: Request) -> List[Event]:
        """Parse the request JSON (single event or list) into `Event` objects.

        Raises:
            ErrorResponse: 400 if no valid event could be extracted.
        """
        events = request.json

        if not isinstance(events, list):
            events = [events]

        events = [Event.from_parameters(event) for event in events]
        # `Event.from_parameters` yields a falsy value for unparsable input.
        events = [event for event in events if event]

        if not events:
            common_utils.raise_warning(
                f"Append event called, but could not extract a valid event. "
                f"Request JSON: {request.json}"
            )
            raise ErrorResponse(
                400,
                "BadRequest",
                "Couldn't extract a proper event from the request body.",
                {"parameter": "", "in": "body"},
            )

        return events

    @app.put("/conversations/<conversation_id:path>/tracker/events")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def replace_events(request: Request, conversation_id: Text):
        """Use a list of events to set a conversations tracker to a state."""
        validate_request_body(
            request,
            "You must provide events in the request body to set the state of the "
            "conversation tracker.",
        )

        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)

        try:
            async with app.agent.lock_store.lock(conversation_id):
                tracker = DialogueStateTracker.from_dict(
                    conversation_id, request.json, app.agent.domain.slots
                )

                # will override an existing tracker with the same id!
                app.agent.tracker_store.save(tracker)

            return response.json(tracker.current_state(verbosity))
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500, "ConversationError", f"An unexpected error occurred. Error: {e}"
            )

    @app.get("/conversations/<conversation_id:path>/story")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def retrieve_story(request: Request, conversation_id: Text):
        """Get an end-to-end story corresponding to this conversation."""
        # retrieve tracker and set to requested state
        tracker = await get_tracker(app.agent.create_processor(), conversation_id)

        until_time = rasa.utils.endpoints.float_arg(request, "until")

        try:
            if until_time is not None:
                tracker = tracker.travel_back_in_time(until_time)

            # dump and return tracker
            state = tracker.export_stories(e2e=True)
            return response.text(state)
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500, "ConversationError", f"An unexpected error occurred. Error: {e}"
            )

    @app.post("/conversations/<conversation_id:path>/execute")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def execute_action(request: Request, conversation_id: Text):
        """Execute the named action and return the updated tracker state."""
        request_params = request.json

        action_to_execute = request_params.get("name", None)

        if not action_to_execute:
            raise ErrorResponse(
                400,
                "BadRequest",
                "Name of the action not provided in request body.",
                {"parameter": "name", "in": "body"},
            )

        policy = request_params.get("policy", None)
        confidence = request_params.get("confidence", None)
        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)

        try:
            async with app.agent.lock_store.lock(conversation_id):
                tracker = await get_tracker(
                    app.agent.create_processor(), conversation_id
                )
                output_channel = _get_output_channel(request, tracker)
                await app.agent.execute_action(
                    conversation_id,
                    action_to_execute,
                    output_channel,
                    policy,
                    confidence,
                )
        except ErrorResponse:
            # Keep helper-raised ErrorResponses (if any) intact, as
            # `trigger_intent` does, instead of masking them as a 500.
            raise
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500, "ConversationError", f"An unexpected error occurred. Error: {e}"
            )

        # Re-fetch so the response reflects the events the action produced.
        tracker = await get_tracker(app.agent.create_processor(), conversation_id)
        state = tracker.current_state(verbosity)

        response_body = {"tracker": state}

        if isinstance(output_channel, CollectingOutputChannel):
            response_body["messages"] = output_channel.messages

        return response.json(response_body)

    @app.post("/conversations/<conversation_id:path>/trigger_intent")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def trigger_intent(request: Request, conversation_id: Text) -> HTTPResponse:
        """Trigger the named intent (with optional entities) on a conversation."""
        request_params = request.json

        intent_to_trigger = request_params.get("name")
        entities = request_params.get("entities", [])

        if not intent_to_trigger:
            raise ErrorResponse(
                400,
                "BadRequest",
                "Name of the intent not provided in request body.",
                {"parameter": "name", "in": "body"},
            )

        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)

        try:
            async with app.agent.lock_store.lock(conversation_id):
                tracker = await get_tracker(
                    app.agent.create_processor(), conversation_id
                )
                output_channel = _get_output_channel(request, tracker)
                if intent_to_trigger not in app.agent.domain.intents:
                    raise ErrorResponse(
                        404,
                        "NotFound",
                        f"The intent {intent_to_trigger} does not exist in the domain.",
                    )
                await app.agent.trigger_intent(
                    intent_name=intent_to_trigger,
                    entities=entities,
                    output_channel=output_channel,
                    tracker=tracker,
                )
        except ErrorResponse:
            raise
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500, "ConversationError", f"An unexpected error occurred. Error: {e}"
            )

        state = tracker.current_state(verbosity)

        response_body = {"tracker": state}

        if isinstance(output_channel, CollectingOutputChannel):
            response_body["messages"] = output_channel.messages

        return response.json(response_body)

    @app.post("/conversations/<conversation_id:path>/predict")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def predict(request: Request, conversation_id: Text) -> HTTPResponse:
        """Predict the next action; scores are sorted by score desc, then name."""
        try:
            # Fetches the appropriate bot response in a json format
            responses = await app.agent.predict_next(conversation_id)
            responses["scores"] = sorted(
                responses["scores"], key=lambda k: (-k["score"], k["action"])
            )
            return response.json(responses)
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500, "ConversationError", f"An unexpected error occurred. Error: {e}"
            )

    @app.post("/conversations/<conversation_id:path>/messages")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def add_message(request: Request, conversation_id: Text):
        """Log a user message on the conversation without running actions."""
        validate_request_body(
            request,
            "No message defined in request body. Add a message to the request body in "
            "order to add it to the tracker.",
        )

        request_params = request.json

        message = request_params.get("text")
        sender = request_params.get("sender")
        parse_data = request_params.get("parse_data")

        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)

        # TODO: implement for agent / bot
        if sender != "user":
            raise ErrorResponse(
                400,
                "BadRequest",
                "Currently, only user messages can be passed to this endpoint. "
                "Messages of sender '{}' cannot be handled.".format(sender),
                {"parameter": "sender", "in": "body"},
            )

        user_message = UserMessage(message, None, conversation_id, parse_data)

        try:
            async with app.agent.lock_store.lock(conversation_id):
                tracker = await app.agent.log_message(user_message)

            return response.json(tracker.current_state(verbosity))
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500, "ConversationError", f"An unexpected error occurred. Error: {e}"
            )

    @app.post("/model/train")
    @requires_auth(app, auth_token)
    async def train(request: Request) -> HTTPResponse:
        """Train a Rasa Model."""
        validate_request_body(
            request,
            "You must provide training data in the request body in order to "
            "train your model.",
        )

        if request.headers.get("Content-type") == YAML_CONTENT_TYPE:
            training_payload = _training_payload_from_yaml(request)
        else:
            training_payload = _training_payload_from_json(request)

        # Increment outside `try` so the `finally` decrement only runs for an
        # increment that actually happened.
        with app.active_training_processes.get_lock():
            app.active_training_processes.value += 1

        try:
            loop = asyncio.get_event_loop()

            from rasa import train as train_model

            # Declare `model_path` upfront to avoid pytype `name-error`
            model_path: Optional[Text] = None
            # pass `None` to run in default executor
            model_path = await loop.run_in_executor(
                None, functools.partial(train_model, **training_payload)
            )

            if model_path:
                filename = os.path.basename(model_path)

                return await response.file(
                    model_path, filename=filename, headers={"filename": filename}
                )
            else:
                raise ErrorResponse(
                    500,
                    "TrainingError",
                    "Ran training, but it finished without a trained model.",
                )
        except ErrorResponse:
            raise
        except InvalidDomain as e:
            raise ErrorResponse(
                400,
                "InvalidDomainError",
                f"Provided domain file is invalid. Error: {e}",
            )
        except Exception as e:
            logger.error(traceback.format_exc())
            raise ErrorResponse(
                500,
                "TrainingError",
                f"An unexpected error occurred during training. Error: {e}",
            )
        finally:
            with app.active_training_processes.get_lock():
                app.active_training_processes.value -= 1

    @app.post("/model/test/stories")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app, require_core_is_ready=True)
    async def evaluate_stories(request: Request) -> HTTPResponse:
        """Evaluate stories against the currently loaded model."""
        validate_request_body(
            request,
            "You must provide some stories in the request body in order to "
            "evaluate your model.",
        )

        test_data = _test_data_file_from_payload(request)

        use_e2e = rasa.utils.endpoints.bool_arg(request, "e2e", default=False)

        try:
            evaluation = await test(test_data, app.agent, e2e=use_e2e)
            return response.json(evaluation)
        except Exception as e:
            logger.error(traceback.format_exc())
            raise ErrorResponse(
                500,
                "TestingError",
                f"An unexpected error occurred during evaluation. Error: {e}",
            )

    @app.post("/model/test/intents")
    @requires_auth(app, auth_token)
    async def evaluate_intents(request: Request) -> HTTPResponse:
        """Evaluate intents against a Rasa model.

        The currently loaded agent is used unless a ``model`` query argument is
        given, in which case that model is loaded just for this evaluation.
        """
        import copy

        validate_request_body(
            request,
            "You must provide some nlu data in the request body in order to "
            "evaluate your model.",
        )

        test_data = _test_data_file_from_payload(request)

        eval_agent = app.agent

        model_path = request.args.get("model", None)
        if model_path:
            model_server = app.agent.model_server
            if model_server is not None:
                # Copy before overriding `url`: assigning it on the live
                # agent's EndpointConfig would leave the serving agent
                # pointing at the model that is only being evaluated.
                model_server = copy.copy(model_server)
                model_server.url = model_path
            eval_agent = await _load_agent(
                model_path, model_server, app.agent.remote_storage
            )

        data_path = os.path.abspath(test_data)

        if not os.path.exists(eval_agent.model_directory):
            raise ErrorResponse(409, "Conflict", "Loaded model file not found.")

        model_directory = eval_agent.model_directory
        _, nlu_model = model.get_model_subdirectories(model_directory)

        try:
            evaluation = run_evaluation(data_path, nlu_model)
            return response.json(evaluation)
        except Exception as e:
            logger.error(traceback.format_exc())
            raise ErrorResponse(
                500,
                "TestingError",
                f"An unexpected error occurred during evaluation. Error: {e}",
            )

    @app.post("/model/predict")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app, require_core_is_ready=True)
    async def tracker_predict(request: Request) -> HTTPResponse:
        """Given a list of events, predict the next action."""
        validate_request_body(
            request,
            "No events defined in request_body. Add events to request body in order to "
            "predict the next action.",
        )

        sender_id = UserMessage.DEFAULT_SENDER_ID
        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)
        request_params = request.json

        try:
            tracker = DialogueStateTracker.from_dict(
                sender_id, request_params, app.agent.domain.slots
            )
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                400,
                "BadRequest",
                f"Supplied events are not valid. {e}",
                {"parameter": "", "in": "body"},
            )

        try:
            policy_ensemble = app.agent.policy_ensemble
            probabilities, policy = policy_ensemble.probabilities_using_best_policy(
                tracker, app.agent.domain, app.agent.interpreter
            )

            scores = [
                {"action": a, "score": p}
                for a, p in zip(app.agent.domain.action_names, probabilities)
            ]

            return response.json(
                {
                    "scores": scores,
                    "policy": policy,
                    "tracker": tracker.current_state(verbosity),
                }
            )
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500, "PredictionError", f"An unexpected error occurred. Error: {e}"
            )

    @app.post("/model/parse")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def parse(request: Request) -> HTTPResponse:
        """Run the NLU interpreter on the request text.

        A failure inside the interpreter is reported as a 400; any other
        failure (request/response normalisation) as a 500.
        """
        validate_request_body(
            request,
            "No text message defined in request_body. Add text message to request body "
            "in order to obtain the intent and extracted entities.",
        )
        emulation_mode = request.args.get("emulation_mode")
        emulator = _create_emulator(emulation_mode)

        try:
            data = emulator.normalise_request_json(request.json)
            try:
                parsed_data = await app.agent.parse_message_using_nlu_interpreter(
                    data.get("text")
                )
            except Exception as e:
                logger.debug(traceback.format_exc())
                raise ErrorResponse(
                    400, "ParsingError", f"An unexpected error occurred. Error: {e}"
                )
            response_data = emulator.normalise_response_json(parsed_data)

            return response.json(response_data)
        except ErrorResponse:
            # Without this, the 400 above is caught by the generic handler
            # below and re-reported as a 500.
            raise
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500, "ParsingError", f"An unexpected error occurred. Error: {e}"
            )

    @app.put("/model")
    @requires_auth(app, auth_token)
    async def load_model(request: Request) -> HTTPResponse:
        """Replace `app.agent` with an agent loaded from the given model."""
        validate_request_body(request, "No path to model file defined in request_body.")

        model_path = request.json.get("model_file", None)
        model_server = request.json.get("model_server", None)
        remote_storage = request.json.get("remote_storage", None)

        if model_server:
            try:
                model_server = EndpointConfig.from_dict(model_server)
            except TypeError as e:
                logger.debug(traceback.format_exc())
                raise ErrorResponse(
                    400,
                    "BadRequest",
                    f"Supplied 'model_server' is not valid. Error: {e}",
                    {"parameter": "model_server", "in": "body"},
                )

        # Reuse the current lock store so conversation locks survive the swap.
        app.agent = await _load_agent(
            model_path, model_server, remote_storage, endpoints, app.agent.lock_store
        )

        logger.debug(f"Successfully loaded model '{model_path}'.")
        return response.json(None, status=204)

    @app.delete("/model")
    @requires_auth(app, auth_token)
    async def unload_model(request: Request) -> HTTPResponse:
        """Replace `app.agent` with an empty Agent sharing the same lock store."""
        model_file = app.agent.model_directory

        app.agent = Agent(lock_store=app.agent.lock_store)

        logger.debug(f"Successfully unloaded model '{model_file}'.")
        return response.json(None, status=204)

    @app.get("/domain")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def get_domain(request: Request) -> HTTPResponse:
        """Get current domain in yaml or json format."""
        accepts = request.headers.get("Accept", default=JSON_CONTENT_TYPE)
        if accepts.endswith("json"):
            domain = app.agent.domain.as_dict()
            return response.json(domain)
        elif accepts.endswith("yml") or accepts.endswith("yaml"):
            domain_yaml = app.agent.domain.as_yaml()
            return response.text(
                domain_yaml, status=200, content_type=YAML_CONTENT_TYPE
            )
        else:
            raise ErrorResponse(
                406,
                "NotAcceptable",
                f"Invalid Accept header. Domain can be "
                f"provided as "
                f'json ("Accept: {JSON_CONTENT_TYPE}") or '
                f'yml ("Accept: {YAML_CONTENT_TYPE}"). '
                f"Make sure you've set the appropriate Accept "
                f"header.",
            )

    return app
return json({"version": 1}) @blueprint2.get("/", strict_slashes=True) @protected(blueprint2) def protected_hello_world_2(request): return json({"version": 2}) async def authenticate(request, *args, **kwargs): return {"user_id": 1} app = Sanic() sanicjwt1 = Initialize(blueprint1, app=app, authenticate=authenticate) sanicjwt2 = Initialize( blueprint2, app=app, authenticate=authenticate, url_prefix="/a", access_token_name="token", cookie_access_token_name="token", cookie_set=True, secret="somethingdifferent", ) app.blueprint(blueprint1, url_prefix="/test1") app.blueprint(blueprint2, url_prefix="/test2")
def create_app( agent: Optional["Agent"] = None, cors_origins: Union[Text, List[Text], None] = "*", auth_token: Optional[Text] = None, response_timeout: int = DEFAULT_RESPONSE_TIMEOUT, jwt_secret: Optional[Text] = None, jwt_method: Text = "HS256", endpoints: Optional[AvailableEndpoints] = None, ): """Class representing a Rasa HTTP server.""" app = Sanic(__name__) app.config.RESPONSE_TIMEOUT = response_timeout configure_cors(app, cors_origins) # Setup the Sanic-JWT extension if jwt_secret and jwt_method: # since we only want to check signatures, we don't actually care # about the JWT method and set the passed secret as either symmetric # or asymmetric key. jwt lib will choose the right one based on method app.config["USE_JWT"] = True Initialize( app, secret=jwt_secret, authenticate=authenticate, algorithm=jwt_method, user_id="username", ) app.agent = agent # Initialize shared object of type unsigned int for tracking # the number of active training processes app.active_training_processes = multiprocessing.Value("I", 0) @app.exception(ErrorResponse) async def handle_error_response(request: Request, exception: ErrorResponse): return response.json(exception.error_info, status=exception.status) add_root_route(app) @app.get("/version") async def version(request: Request): """Respond with the version number of the installed Rasa.""" return response.json({ "version": rasa.__version_bf__, # bf "minimum_compatible_version": MINIMUM_COMPATIBLE_VERSION, }) @app.get("/status") @requires_auth(app, auth_token) @ensure_loaded_agent(app) async def status(request: Request): """Respond with the model name and the fingerprint of that model.""" return response.json({ "model_file": app.agent.path_to_model_archive or app.agent.model_directory, "fingerprint": model.fingerprint_from_path(app.agent.model_directory), "num_active_training_jobs": app.active_training_processes.value, }) @app.get("/conversations/<conversation_id>/tracker") @requires_auth(app, auth_token) @ensure_loaded_agent(app) 
async def retrieve_tracker(request: Request, conversation_id: Text): """Get a dump of a conversation's tracker including its events.""" verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART) until_time = rasa.utils.endpoints.float_arg(request, "until") tracker = await get_tracker(app.agent.create_processor(), conversation_id) try: if until_time is not None: tracker = tracker.travel_back_in_time(until_time) state = tracker.current_state(verbosity) return response.json(state) except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse(500, "ConversationError", f"An unexpected error occurred. Error: {e}") @app.post("/conversations/<conversation_id>/tracker/events") @requires_auth(app, auth_token) @ensure_loaded_agent(app) async def append_events(request: Request, conversation_id: Text): """Append a list of events to the state of a conversation""" validate_request_body( request, "You must provide events in the request body in order to append them" "to the state of a conversation.", ) verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART) try: async with app.agent.lock_store.lock(conversation_id): processor = app.agent.create_processor() tracker = processor.get_tracker(conversation_id) _validate_tracker(tracker, conversation_id) events = _get_events_from_request_body(request) for event in events: tracker.update(event, app.agent.domain) app.agent.tracker_store.save(tracker) return response.json(tracker.current_state(verbosity)) except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse(500, "ConversationError", f"An unexpected error occurred. Error: {e}") def _get_events_from_request_body(request: Request) -> List[Event]: events = request.json if not isinstance(events, list): events = [events] events = [Event.from_parameters(event) for event in events] events = [event for event in events if event] if not events: raise_warning( f"Append event called, but could not extract a valid event. 
" f"Request JSON: {request.json}") raise ErrorResponse( 400, "BadRequest", "Couldn't extract a proper event from the request body.", { "parameter": "", "in": "body" }, ) return events @app.put("/conversations/<conversation_id>/tracker/events") @requires_auth(app, auth_token) @ensure_loaded_agent(app) async def replace_events(request: Request, conversation_id: Text): """Use a list of events to set a conversations tracker to a state.""" validate_request_body( request, "You must provide events in the request body to set the sate of the " "conversation tracker.", ) verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART) try: async with app.agent.lock_store.lock(conversation_id): tracker = DialogueStateTracker.from_dict( conversation_id, request.json, app.agent.domain.slots) # will override an existing tracker with the same id! app.agent.tracker_store.save(tracker) return response.json(tracker.current_state(verbosity)) except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse(500, "ConversationError", f"An unexpected error occurred. Error: {e}") @app.get("/conversations/<conversation_id>/story") @requires_auth(app, auth_token) @ensure_loaded_agent(app) async def retrieve_story(request: Request, conversation_id: Text): """Get an end-to-end story corresponding to this conversation.""" # retrieve tracker and set to requested state tracker = await get_tracker(app.agent.create_processor(), conversation_id) until_time = rasa.utils.endpoints.float_arg(request, "until") try: if until_time is not None: tracker = tracker.travel_back_in_time(until_time) # dump and return tracker state = tracker.export_stories(e2e=True) return response.text(state) except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse(500, "ConversationError", f"An unexpected error occurred. 
Error: {e}") @app.post("/conversations/<conversation_id>/execute") @requires_auth(app, auth_token) @ensure_loaded_agent(app) async def execute_action(request: Request, conversation_id: Text): request_params = request.json action_to_execute = request_params.get("name", None) if not action_to_execute: raise ErrorResponse( 400, "BadRequest", "Name of the action not provided in request body.", { "parameter": "name", "in": "body" }, ) policy = request_params.get("policy", None) confidence = request_params.get("confidence", None) verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART) try: async with app.agent.lock_store.lock(conversation_id): tracker = await get_tracker(app.agent.create_processor(), conversation_id) output_channel = _get_output_channel(request, tracker) await app.agent.execute_action( conversation_id, action_to_execute, output_channel, policy, confidence, ) except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse(500, "ConversationError", f"An unexpected error occurred. 
Error: {e}") tracker = await get_tracker(app.agent.create_processor(), conversation_id) state = tracker.current_state(verbosity) response_body = {"tracker": state} if isinstance(output_channel, CollectingOutputChannel): response_body["messages"] = output_channel.messages return response.json(response_body) @app.post("/conversations/<conversation_id>/trigger_intent") @requires_auth(app, auth_token) @ensure_loaded_agent(app) async def trigger_intent(request: Request, conversation_id: Text) -> HTTPResponse: request_params = request.json intent_to_trigger = request_params.get("name") entities = request_params.get("entities", []) if not intent_to_trigger: raise ErrorResponse( 400, "BadRequest", "Name of the intent not provided in request body.", { "parameter": "name", "in": "body" }, ) verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART) try: async with app.agent.lock_store.lock(conversation_id): tracker = await get_tracker(app.agent.create_processor(), conversation_id) output_channel = _get_output_channel(request, tracker) if intent_to_trigger not in app.agent.domain.intents: raise ErrorResponse( 404, "NotFound", f"The intent {trigger_intent} does not exist in the domain.", ) await app.agent.trigger_intent( intent_name=intent_to_trigger, entities=entities, output_channel=output_channel, tracker=tracker, ) except ErrorResponse: raise except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse(500, "ConversationError", f"An unexpected error occurred. 
Error: {e}") state = tracker.current_state(verbosity) response_body = {"tracker": state} if isinstance(output_channel, CollectingOutputChannel): response_body["messages"] = output_channel.messages return response.json(response_body) @app.post("/conversations/<conversation_id>/predict") @requires_auth(app, auth_token) @ensure_loaded_agent(app) async def predict(request: Request, conversation_id: Text): try: # Fetches the appropriate bot response in a json format responses = await app.agent.predict_next(conversation_id) responses["scores"] = sorted(responses["scores"], key=lambda k: (-k["score"], k["action"])) return response.json(responses) except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse(500, "ConversationError", f"An unexpected error occurred. Error: {e}") @app.post("/conversations/<conversation_id>/messages") @requires_auth(app, auth_token) @ensure_loaded_agent(app) async def add_message(request: Request, conversation_id: Text): validate_request_body( request, "No message defined in request body. Add a message to the request body in " "order to add it to the tracker.", ) request_params = request.json message = request_params.get("text") sender = request_params.get("sender") parse_data = request_params.get("parse_data") verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART) # TODO: implement for agent / bot if sender != "user": raise ErrorResponse( 400, "BadRequest", "Currently, only user messages can be passed to this endpoint. " "Messages of sender '{}' cannot be handled.".format(sender), { "parameter": "sender", "in": "body" }, ) user_message = UserMessage(message, None, conversation_id, parse_data) try: async with app.agent.lock_store.lock(conversation_id): tracker = await app.agent.log_message(user_message) return response.json(tracker.current_state(verbosity)) except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse(500, "ConversationError", f"An unexpected error occurred. 
Error: {e}") @app.post("/model/train") @requires_auth(app, auth_token) async def train(request: Request) -> HTTPResponse: """Train a Rasa Model.""" validate_request_body( request, "You must provide training data in the request body in order to " "train your model.", ) rjs = request.json validate_request(rjs) # create a temporary directory to store config, domain and # training data temp_dir = tempfile.mkdtemp() # bf >> config_paths = [] for key in rjs["config"].keys(): config_path = os.path.join(temp_dir, "config-{}.yml".format(key)) rasa.utils.io.write_text_file(rjs["config"][key], config_path) config_paths += [config_path] if "nlu" in rjs: nlu_dir = os.path.join(temp_dir, "nlu") os.mkdir(nlu_dir) for key in rjs["nlu"].keys(): nlu_path = os.path.join(nlu_dir, "{}.md".format(key)) rasa.utils.io.write_text_file(rjs["nlu"][key]["data"], nlu_path) # << bf if "stories" in rjs: stories_path = os.path.join(temp_dir, "stories.md") rasa.utils.io.write_text_file(rjs["stories"], stories_path) domain_path = DEFAULT_DOMAIN_PATH if "domain" in rjs: domain_path = os.path.join(temp_dir, "domain.yml") rasa.utils.io.write_text_file(rjs["domain"], domain_path) if rjs.get("save_to_default_model_directory", True) is True: model_output_directory = DEFAULT_MODELS_PATH else: model_output_directory = tempfile.gettempdir() try: with app.active_training_processes.get_lock(): app.active_training_processes.value += 1 info = dict( domain=domain_path, config=config_paths, # bf training_files=temp_dir, output=model_output_directory, force_training=rjs.get("force", False), fixed_model_name=rjs.get("fixed_model_name"), # bf persist_nlu_training_data=True, # bf ) loop = asyncio.get_event_loop() from rasa import train as train_model # Declare `model_path` upfront to avoid pytype `name-error` model_path: Optional[Text] = None # pass `None` to run in default executor model_path = await loop.run_in_executor( None, functools.partial(train_model, **info)) filename = os.path.basename(model_path) if 
model_path else None return await response.file(model_path, filename=filename, headers={"filename": filename}) except InvalidDomain as e: raise ErrorResponse( 400, "InvalidDomainError", f"Provided domain file is invalid. Error: {e}", ) except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse( 500, "TrainingError", f"An unexpected error occurred during training. Error: {e}", ) finally: with app.active_training_processes.get_lock(): app.active_training_processes.value -= 1 def validate_request(rjs): if "config" not in rjs: raise ErrorResponse( 400, "BadRequest", "The training request is missing the required key `config`.", { "parameter": "config", "in": "body" }, ) if "nlu" not in rjs and "stories" not in rjs: raise ErrorResponse( 400, "BadRequest", "To train a Rasa model you need to specify at least one type of " "training data. Add `nlu` and/or `stories` to the request.", { "parameters": ["nlu", "stories"], "in": "body" }, ) if "stories" in rjs and "domain" not in rjs: raise ErrorResponse( 400, "BadRequest", "To train a Rasa model with story training data, you also need to " "specify the `domain`.", { "parameter": "domain", "in": "body" }, ) @app.post("/model/test/stories") @requires_auth(app, auth_token) @ensure_loaded_agent(app, require_core_is_ready=True) async def evaluate_stories(request: Request): """Evaluate stories against the currently loaded model.""" validate_request_body( request, "You must provide some stories in the request body in order to " "evaluate your model.", ) stories = rasa.utils.io.create_temporary_file(request.body, mode="w+b") use_e2e = rasa.utils.endpoints.bool_arg(request, "e2e", default=False) try: evaluation = await test(stories, app.agent, e2e=use_e2e) return response.json(evaluation) except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse( 500, "TestingError", f"An unexpected error occurred during evaluation. 
Error: {e}", ) @app.post("/model/test/intents") @requires_auth(app, auth_token) async def evaluate_intents(request: Request): """Evaluate intents against a Rasa model.""" validate_request_body( request, "You must provide some nlu data in the request body in order to " "evaluate your model.", ) eval_agent = app.agent model_path = request.args.get("model", None) if model_path: model_server = app.agent.model_server if model_server is not None: model_server.url = model_path eval_agent = await _load_agent(model_path, model_server, app.agent.remote_storage) nlu_data = rasa.utils.io.create_temporary_file(request.body, mode="w+b") data_path = os.path.abspath(nlu_data) if not os.path.exists(eval_agent.model_directory): raise ErrorResponse(409, "Conflict", "Loaded model file not found.") model_directory = eval_agent.model_directory model_directory = os.path.abspath( os.path.join(model_directory, os.pardir)) # bf _, nlu_model = model.get_model_subdirectories(model_directory) try: language = request.args.get("language", None) # bf evaluation = run_evaluation(data_path, nlu_model.get(language)) #bf return response.json(evaluation) except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse( 500, "TestingError", f"An unexpected error occurred during evaluation. Error: {e}", ) @app.post("/model/predict") @requires_auth(app, auth_token) @ensure_loaded_agent(app, require_core_is_ready=True) async def tracker_predict(request: Request): """ Given a list of events, predicts the next action""" validate_request_body( request, "No events defined in request_body. 
Add events to request body in order to " "predict the next action.", ) sender_id = UserMessage.DEFAULT_SENDER_ID verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART) request_params = request.json try: tracker = DialogueStateTracker.from_dict(sender_id, request_params, app.agent.domain.slots) except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse( 400, "BadRequest", f"Supplied events are not valid. {e}", { "parameter": "", "in": "body" }, ) try: policy_ensemble = app.agent.policy_ensemble probabilities, policy = policy_ensemble.probabilities_using_best_policy( tracker, app.agent.domain) scores = [{ "action": a, "score": p } for a, p in zip(app.agent.domain.action_names, probabilities)] return response.json({ "scores": scores, "policy": policy, "tracker": tracker.current_state(verbosity), }) except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse(500, "PredictionError", f"An unexpected error occurred. Error: {e}") @app.post("/model/parse") @requires_auth(app, auth_token) @ensure_loaded_agent(app) async def parse(request: Request): validate_request_body( request, "No text message defined in request_body. 
Add text message to request body " "in order to obtain the intent and extracted entities.", ) if not request.json.get("lang"): raise ErrorResponse(400, "Bad Request", "'lang' property is required'") emulation_mode = request.args.get("emulation_mode") emulator = _create_emulator(emulation_mode) try: data = emulator.normalise_request_json(request.json) try: # bf: get query args from rasa.core.interpreter import NaturalLanguageInterpreter if isinstance(app.agent.interpreter, dict): parsed_data = await app.agent.interpreter.get( request.json.get("lang")).parse( data.get("text"), data.get("message_id"), ) elif isinstance(app.agent.interpreter, NaturalLanguageInterpreter): parsed_data = await app.agent.interpreter.parse( data.get("text"), data.get("message_id"), ) # bf: end except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse( 400, "ParsingError", f"An unexpected error occurred. Error: {e}") response_data = emulator.normalise_response_json(parsed_data) return response.json(response_data) except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse(500, "ParsingError", f"An unexpected error occurred. Error: {e}") @app.put("/model") @requires_auth(app, auth_token) async def load_model(request: Request): validate_request_body( request, "No path to model file defined in request_body.") model_path = request.json.get("model_file", None) model_server = request.json.get("model_server", None) remote_storage = request.json.get("remote_storage", None) if model_server: try: model_server = EndpointConfig.from_dict(model_server) except TypeError as e: logger.debug(traceback.format_exc()) raise ErrorResponse( 400, "BadRequest", f"Supplied 'model_server' is not valid. 
Error: {e}", { "parameter": "model_server", "in": "body" }, ) app.agent = await _load_agent(model_path, model_server, remote_storage, endpoints, app.agent.lock_store) logger.debug(f"Successfully loaded model '{model_path}'.") return response.json(None, status=204) @app.delete("/model") @requires_auth(app, auth_token) async def unload_model(request: Request): model_file = app.agent.model_directory app.agent = Agent(lock_store=app.agent.lock_store) logger.debug(f"Successfully unloaded model '{model_file}'.") return response.json(None, status=204) @app.get("/domain") @requires_auth(app, auth_token) @ensure_loaded_agent(app) async def get_domain(request: Request): """Get current domain in yaml or json format.""" accepts = request.headers.get("Accept", default="application/json") if accepts.endswith("json"): domain = app.agent.domain.as_dict() return response.json(domain) elif accepts.endswith("yml") or accepts.endswith("yaml"): domain_yaml = app.agent.domain.as_yaml() return response.text(domain_yaml, status=200, content_type="application/x-yml") else: raise ErrorResponse( 406, "NotAcceptable", "Invalid Accept header. Domain can be " "provided as " 'json ("Accept: application/json") or' 'yml ("Accept: application/x-yml"). 
' "Make sure you've set the appropriate Accept " "header.", ) @app.post("/data/convert") @requires_auth(app, auth_token) async def post_data_convert(request: Request): """Converts current domain in yaml or json format.""" validate_request_body( request, "You must provide training data in the request body in order to " "train your model.", ) rjs = request.json if 'data' not in rjs: raise ErrorResponse( 400, "BadRequest", "Must provide training data in 'data' property") if 'output_format' not in rjs or rjs["output_format"] not in [ "json", "md" ]: raise ErrorResponse( 400, "BadRequest", "'output_format' is required and must be either 'md' or 'json") if 'language' not in rjs: raise ErrorResponse(400, "BadRequest", "'language' is required") temp_dir = tempfile.mkdtemp() out_dir = tempfile.mkdtemp() nlu_data_path = os.path.join(temp_dir, "nlu_data") output_path = os.path.join(out_dir, "output") if type(rjs["data"]) is dict: rasa.utils.io.dump_obj_as_json_to_file(nlu_data_path, rjs["data"]) else: rasa.utils.io.write_text_file(rjs["data"], nlu_data_path) from rasa.nlu.convert import convert_training_data convert_training_data(nlu_data_path, output_path, rjs["output_format"], rjs["language"]) with open(output_path, encoding='utf-8') as f: data = f.read() if rjs["output_format"] == 'json': import json data = json.loads(data, encoding='utf-8') return response.json({"data": data}) return app
@mythic.middleware("request")
async def add_session(request):
    """Request middleware: expose the module-level ``session`` on each request.

    After this runs, handlers can reach it as ``request.ctx.session``.
    """
    request.ctx.session = session


# Configure sanic-jwt for the ``mythic`` app. The keyword arguments are
# grouped by concern below; the set of names and values is unchanged.
Initialize(
    mythic,
    # behaviour hooks supplied by app.routes.authentication
    authentication_class=app.routes.authentication.MyAuthentication,
    configuration_class=app.routes.authentication.MyConfig,
    add_scopes_to_payload=app.routes.authentication.add_scopes_to_payload,
    store_refresh_token=app.routes.authentication.store_refresh_token,
    retrieve_refresh_token=app.routes.authentication.retrieve_refresh_token,
    # token transport via cookies
    cookie_set=True,
    cookie_strict=False,
    cookie_httponly=True,
    cookie_access_token_name="access_token",
    cookie_refresh_token_name="refresh_token",
    # scopes
    scopes_enabled=True,
    scopes_name="scope",
    # signing secret and token lifetime
    # NOTE(review): the secret is generated when this module is imported, so it
    # changes on every process start (and would differ between processes if
    # the app is run with several workers); presumably tokens issued before a
    # restart stop verifying. Confirm this is intended.
    secret="".join(str(uuid.uuid4()) for _ in range(2)),
    refresh_token_enabled=True,
    expiration_delta=8 * 60 * 60,  # initial token expiration time, 8hrs
    # routing
    url_prefix="/",
    class_views=my_views,
    path_to_authenticate="/auth",
    path_to_retrieve_user="******",
    path_to_verify="/verify",
    path_to_refresh="/refresh",
    login_redirect_url="/login",
)
@blueprint2.get("/", strict_slashes=True)
@protected(blueprint2)
def protected_hello_world_2(request):
    """JWT-protected endpoint on ``blueprint2``; returns ``{"version": 2}``."""
    return json(dict(version=2))


async def authenticate(request, *args, **kwargs):
    """Stub authenticator: every request is accepted as user 1."""
    return dict(user_id=1)


app = Sanic()

# First sanic-jwt instance, bound to blueprint1 with default settings.
sanicjwt1 = Initialize(blueprint1, app=app, authenticate=authenticate)

# Second, independent sanic-jwt instance for blueprint2 with its own URL
# prefix, access-token/cookie name and signing secret.
sanicjwt2 = Initialize(
    blueprint2,
    app=app,
    authenticate=authenticate,
    secret="somethingdifferent",
    url_prefix="/a",
    access_token_name="token",
    cookie_set=True,
    cookie_access_token_name="token",
)

app.blueprint(blueprint1, url_prefix="/test1")
app.blueprint(blueprint2, url_prefix="/test2")
def test_deprecated_payload_handler():
    """Setting SANIC_JWT_PAYLOAD_HANDLER must make Initialize raise
    ``exceptions.InvalidConfiguration``."""
    sanic_app = Sanic()

    def legacy_payload_handler(*args, **kwargs):
        return {}

    sanic_app.config.SANIC_JWT_PAYLOAD_HANDLER = legacy_payload_handler

    with pytest.raises(exceptions.InvalidConfiguration):
        Initialize(sanic_app, authenticate=lambda: True)
def create_app(
    agent: Optional["Agent"] = None,
    cors_origins: Union[Text, List[Text]] = "*",
    auth_token: Optional[Text] = None,
    jwt_secret: Optional[Text] = None,
    jwt_method: Text = "HS256",
    endpoints: Optional[AvailableEndpoints] = None,
):
    """Create the Sanic application that serves the Rasa HTTP API.

    Args:
        agent: Agent stored as ``app.agent``; endpoints read and replace it.
        cors_origins: Allowed CORS origin(s), passed to ``CORS``.
        auth_token: Token checked by ``requires_auth`` on protected routes.
        jwt_secret: When set together with ``jwt_method``, enables sanic-jwt
            signature checking with this key.
        jwt_method: JWT algorithm passed to sanic-jwt.
        endpoints: Endpoint configuration forwarded to ``_load_agent`` when a
            model is loaded through ``PUT /model``.

    Returns:
        The configured ``Sanic`` application.
    """

    app = Sanic(__name__)
    app.config.RESPONSE_TIMEOUT = 60 * 60

    CORS(
        app,
        resources={r"/*": {"origins": cors_origins or ""}},
        automatic_options=True,
    )

    # Setup the Sanic-JWT extension
    if jwt_secret and jwt_method:
        # since we only want to check signatures, we don't actually care
        # about the JWT method and set the passed secret as either symmetric
        # or asymmetric key. jwt lib will choose the right one based on method
        app.config["USE_JWT"] = True
        Initialize(
            app,
            secret=jwt_secret,
            authenticate=authenticate,
            algorithm=jwt_method,
            user_id="username",
        )

    app.agent = agent

    @app.exception(ErrorResponse)
    async def handle_error_response(request: Request, exception: ErrorResponse):
        return response.json(exception.error_info, status=exception.status)

    @app.get("/")
    async def hello(request: Request):
        """Check if the server is running and responds with the version."""
        return response.text("Hello from Rasa: " + rasa.__version__)

    @app.get("/version")
    async def version(request: Request):
        """Respond with the version number of the installed Rasa."""

        return response.json(
            {
                "version": rasa.__version__,
                "minimum_compatible_version": MINIMUM_COMPATIBLE_VERSION,
            }
        )

    @app.get("/status")
    @requires_auth(app, auth_token)
    async def status(request: Request):
        """Respond with the model name and the fingerprint of that model."""

        return response.json(
            {
                "model_file": app.agent.model_directory,
                "fingerprint": fingerprint_from_path(app.agent.model_directory),
            }
        )

    @app.get("/conversations/<conversation_id>/tracker")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def retrieve_tracker(request: Request, conversation_id: Text):
        """Get a dump of a conversation's tracker including its events."""
        if not app.agent.tracker_store:
            raise ErrorResponse(
                409,
                "Conflict",
                "No tracker store available. Make sure to "
                "configure a tracker store when starting "
                "the server.",
            )

        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)
        until_time = rasa.utils.endpoints.float_arg(request, "until")

        tracker = obtain_tracker_store(app.agent, conversation_id)

        try:
            if until_time is not None:
                tracker = tracker.travel_back_in_time(until_time)

            state = tracker.current_state(verbosity)
            return response.json(state)
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500,
                "ConversationError",
                "An unexpected error occurred. Error: {}".format(e),
            )

    @app.post("/conversations/<conversation_id>/tracker/events")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def append_events(request: Request, conversation_id: Text):
        """Append a list of events to the state of a conversation"""
        validate_request_body(
            request,
            "You must provide events in the request body in order to append them "
            "to the state of a conversation.",
        )

        events = request.json
        if not isinstance(events, list):
            events = [events]

        events = [Event.from_parameters(event) for event in events]
        # drop entries that could not be turned into an event
        events = [event for event in events if event]

        if not events:
            logger.warning(
                "Append event called, but could not extract a valid event. "
                "Request JSON: {}".format(request.json)
            )
            raise ErrorResponse(
                400,
                "BadRequest",
                "Couldn't extract a proper event from the request body.",
                {"parameter": "", "in": "body"},
            )

        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)
        tracker = obtain_tracker_store(app.agent, conversation_id)

        try:
            for event in events:
                tracker.update(event, app.agent.domain)

            app.agent.tracker_store.save(tracker)

            return response.json(tracker.current_state(verbosity))
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500,
                "ConversationError",
                "An unexpected error occurred. Error: {}".format(e),
            )

    @app.put("/conversations/<conversation_id>/tracker/events")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def replace_events(request: Request, conversation_id: Text):
        """Use a list of events to set a conversations tracker to a state."""
        validate_request_body(
            request,
            "You must provide events in the request body to set the state of the "
            "conversation tracker.",
        )

        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)

        try:
            tracker = DialogueStateTracker.from_dict(
                conversation_id, request.json, app.agent.domain.slots
            )

            # will override an existing tracker with the same id!
            app.agent.tracker_store.save(tracker)
            return response.json(tracker.current_state(verbosity))
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500,
                "ConversationError",
                "An unexpected error occurred. Error: {}".format(e),
            )

    @app.get("/conversations/<conversation_id>/story")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def retrieve_story(request: Request, conversation_id: Text):
        """Get an end-to-end story corresponding to this conversation."""
        if not app.agent.tracker_store:
            raise ErrorResponse(
                409,
                "Conflict",
                "No tracker store available. Make sure to "
                "configure a tracker store when starting "
                "the server.",
            )

        # retrieve tracker and set to requested state
        tracker = obtain_tracker_store(app.agent, conversation_id)

        until_time = rasa.utils.endpoints.float_arg(request, "until")

        try:
            if until_time is not None:
                tracker = tracker.travel_back_in_time(until_time)

            # dump and return tracker
            state = tracker.export_stories(e2e=True)
            return response.text(state)
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500,
                "ConversationError",
                "An unexpected error occurred. Error: {}".format(e),
            )

    @app.post("/conversations/<conversation_id>/execute")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def execute_action(request: Request, conversation_id: Text):
        """Execute the action named in the body and return tracker + messages."""
        request_params = request.json
        if not isinstance(request_params, dict):
            # An empty or non-object body would otherwise fail on ``.get`` and
            # surface as a 500; treat it as "no action name provided" (400).
            request_params = {}

        action_to_execute = request_params.get("name", None)

        if not action_to_execute:
            raise ErrorResponse(
                400,
                "BadRequest",
                "Name of the action not provided in request body.",
                {"parameter": "name", "in": "body"},
            )

        policy = request_params.get("policy", None)
        confidence = request_params.get("confidence", None)
        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)

        try:
            out = CollectingOutputChannel()
            await app.agent.execute_action(
                conversation_id, action_to_execute, out, policy, confidence
            )
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500,
                "ConversationError",
                "An unexpected error occurred. Error: {}".format(e),
            )

        tracker = obtain_tracker_store(app.agent, conversation_id)
        state = tracker.current_state(verbosity)
        return response.json({"tracker": state, "messages": out.messages})

    @app.post("/conversations/<conversation_id>/predict")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def predict(request: Request, conversation_id: Text):
        """Predict the next action; scores sorted by descending score, then name."""
        try:
            # Fetches the appropriate bot response in a json format
            responses = app.agent.predict_next(conversation_id)
            responses["scores"] = sorted(
                responses["scores"], key=lambda k: (-k["score"], k["action"])
            )
            return response.json(responses)
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500,
                "ConversationError",
                "An unexpected error occurred. Error: {}".format(e),
            )

    @app.post("/conversations/<conversation_id>/messages")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def add_message(request: Request, conversation_id: Text):
        """Log a user message on the conversation's tracker and return its state."""
        validate_request_body(
            request,
            "No message defined in request body. Add a message to the request body in "
            "order to add it to the tracker.",
        )

        request_params = request.json

        message = request_params.get("text")
        sender = request_params.get("sender")
        parse_data = request_params.get("parse_data")

        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)

        # TODO: implement for agent / bot
        if sender != "user":
            raise ErrorResponse(
                400,
                "BadRequest",
                "Currently, only user messages can be passed to this endpoint. "
                "Messages of sender '{}' cannot be handled.".format(sender),
                {"parameter": "sender", "in": "body"},
            )

        try:
            user_message = UserMessage(message, None, conversation_id, parse_data)
            tracker = await app.agent.log_message(user_message)
            return response.json(tracker.current_state(verbosity))
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500,
                "ConversationError",
                "An unexpected error occurred. Error: {}".format(e),
            )

    @app.post("/model/train")
    @requires_auth(app, auth_token)
    async def train(request: Request):
        """Train a Rasa Model."""
        import shutil

        from rasa.train import train_async

        validate_request_body(
            request,
            "You must provide training data in the request body in order to "
            "train your model.",
        )

        rjs = request.json
        validate_request(rjs)

        # create a temporary directory to store config, domain and
        # training data
        temp_dir = tempfile.mkdtemp()

        try:
            config_path = os.path.join(temp_dir, "config.yml")
            dump_obj_as_str_to_file(config_path, rjs["config"])

            if "nlu" in rjs:
                nlu_path = os.path.join(temp_dir, "nlu.md")
                dump_obj_as_str_to_file(nlu_path, rjs["nlu"])

            if "stories" in rjs:
                stories_path = os.path.join(temp_dir, "stories.md")
                dump_obj_as_str_to_file(stories_path, rjs["stories"])

            domain_path = DEFAULT_DOMAIN_PATH
            if "domain" in rjs:
                domain_path = os.path.join(temp_dir, "domain.yml")
                dump_obj_as_str_to_file(domain_path, rjs["domain"])

            try:
                # NOTE(review): "out" comes from the request body and chooses
                # where the model is written on the server's filesystem.
                model_path = await train_async(
                    domain=domain_path,
                    config=config_path,
                    training_files=temp_dir,
                    output_path=rjs.get("out", DEFAULT_MODELS_PATH),
                    force_training=rjs.get("force", False),
                )

                return await response.file(model_path)
            except InvalidDomain as e:
                raise ErrorResponse(
                    400,
                    "InvalidDomainError",
                    "Provided domain file is invalid. Error: {}".format(e),
                )
            except Exception as e:
                logger.debug(traceback.format_exc())
                raise ErrorResponse(
                    500,
                    "TrainingError",
                    "An unexpected error occurred during training. Error: {}".format(
                        e
                    ),
                )
        finally:
            # The temp dir only holds the request's training inputs; the model
            # is written to ``output_path`` and has already been read into the
            # response above, so the directory can always be removed.
            shutil.rmtree(temp_dir, ignore_errors=True)

    def validate_request(rjs):
        """Raise a 400 ErrorResponse if the training payload lacks required keys."""
        if "config" not in rjs:
            raise ErrorResponse(
                400,
                "BadRequest",
                "The training request is missing the required key `config`.",
                {"parameter": "config", "in": "body"},
            )

        if "nlu" not in rjs and "stories" not in rjs:
            raise ErrorResponse(
                400,
                "BadRequest",
                "To train a Rasa model you need to specify at least one type of "
                "training data. Add `nlu` and/or `stories` to the request.",
                {"parameters": ["nlu", "stories"], "in": "body"},
            )

        if "stories" in rjs and "domain" not in rjs:
            raise ErrorResponse(
                400,
                "BadRequest",
                "To train a Rasa model with story training data, you also need to "
                "specify the `domain`.",
                {"parameter": "domain", "in": "body"},
            )

    @app.post("/model/test/stories")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def evaluate_stories(request: Request):
        """Evaluate stories against the currently loaded model."""
        validate_request_body(
            request,
            "You must provide some stories in the request body in order to "
            "evaluate your model.",
        )

        stories = rasa.utils.io.create_temporary_file(request.body, mode="w+b")
        use_e2e = rasa.utils.endpoints.bool_arg(request, "e2e", default=False)

        try:
            evaluation = await test(stories, app.agent, e2e=use_e2e)
            return response.json(evaluation)
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500,
                "TestingError",
                "An unexpected error occurred during evaluation. Error: {}".format(e),
            )

    @app.post("/model/test/intents")
    @requires_auth(app, auth_token)
    async def evaluate_intents(request: Request):
        """Evaluate intents against a Rasa model."""
        import copy

        validate_request_body(
            request,
            "You must provide some nlu data in the request body in order to "
            "evaluate your model.",
        )

        eval_agent = app.agent

        model_path = request.args.get("model", None)
        if model_path:
            # Work on a copy: assigning ``.url`` on the live agent's
            # ``model_server`` would permanently repoint the loaded agent's
            # endpoint configuration at the evaluation model.
            model_server = copy.copy(app.agent.model_server)
            if model_server is not None:
                model_server.url = model_path
            eval_agent = await _load_agent(
                model_path, model_server, app.agent.remote_storage
            )

        nlu_data = rasa.utils.io.create_temporary_file(request.body, mode="w+b")
        data_path = os.path.abspath(nlu_data)

        model_directory = getattr(eval_agent, "model_directory", None)
        # os.path.exists(None) raises TypeError, so check for a missing
        # directory first to return the intended 409 instead of a 500.
        if not model_directory or not os.path.exists(model_directory):
            raise ErrorResponse(409, "Conflict", "Loaded model file not found.")

        _, nlu_model = get_model_subdirectories(model_directory)

        try:
            evaluation = run_evaluation(data_path, nlu_model)
            return response.json(evaluation)
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500,
                "TestingError",
                "An unexpected error occurred during evaluation. Error: {}".format(e),
            )

    @app.post("/model/predict")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def tracker_predict(request: Request):
        """ Given a list of events, predicts the next action"""
        validate_request_body(
            request,
            "No events defined in request_body. Add events to request body in order to "
            "predict the next action.",
        )

        sender_id = UserMessage.DEFAULT_SENDER_ID
        verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)
        request_params = request.json
        try:
            tracker = DialogueStateTracker.from_dict(
                sender_id, request_params, app.agent.domain.slots
            )
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                400,
                "BadRequest",
                "Supplied events are not valid. {}".format(e),
                {"parameter": "", "in": "body"},
            )

        try:
            policy_ensemble = app.agent.policy_ensemble
            probabilities, policy = policy_ensemble.probabilities_using_best_policy(
                tracker, app.agent.domain
            )

            scores = [
                {"action": a, "score": p}
                for a, p in zip(app.agent.domain.action_names, probabilities)
            ]

            return response.json(
                {
                    "scores": scores,
                    "policy": policy,
                    "tracker": tracker.current_state(verbosity),
                }
            )
        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500,
                "PredictionError",
                "An unexpected error occurred. Error: {}".format(e),
            )

    @app.post("/model/parse")
    @requires_auth(app, auth_token)
    async def parse(request: Request):
        """Run the loaded interpreter on the request text (optionally emulated)."""
        validate_request_body(
            request,
            "No text message defined in request_body. Add text message to request body "
            "in order to obtain the intent and extracted entities.",
        )
        emulation_mode = request.args.get("emulation_mode")
        emulator = _create_emulator(emulation_mode)

        try:
            data = emulator.normalise_request_json(request.json)
            parse_data = await app.agent.interpreter.parse(data.get("text"))
            response_data = emulator.normalise_response_json(parse_data)

            return response.json(response_data)

        except Exception as e:
            logger.debug(traceback.format_exc())
            raise ErrorResponse(
                500, "ParsingError", "An unexpected error occurred. Error: {}".format(e)
            )

    @app.put("/model")
    @requires_auth(app, auth_token)
    async def load_model(request: Request):
        """Replace ``app.agent`` with an agent loaded from the given model."""
        validate_request_body(request, "No path to model file defined in request_body.")

        model_path = request.json.get("model_file", None)
        model_server = request.json.get("model_server", None)
        remote_storage = request.json.get("remote_storage", None)

        if model_server:
            # The body carries a plain dict; ``_load_agent`` needs an endpoint
            # config object (``evaluate_intents`` sets ``model_server.url``).
            try:
                model_server = EndpointConfig.from_dict(model_server)
            except TypeError as e:
                logger.debug(traceback.format_exc())
                raise ErrorResponse(
                    400,
                    "BadRequest",
                    "Supplied 'model_server' is not valid. Error: {}".format(e),
                    {"parameter": "model_server", "in": "body"},
                )

        app.agent = await _load_agent(
            model_path, model_server, remote_storage, endpoints
        )

        logger.debug("Successfully loaded model '{}'.".format(model_path))
        return response.json(None, status=204)

    @app.delete("/model")
    @requires_auth(app, auth_token)
    async def unload_model(request: Request):
        """Replace ``app.agent`` with an empty ``Agent``."""
        # app.agent may be None when the server started without an agent
        model_file = getattr(app.agent, "model_directory", None)

        app.agent = Agent()

        logger.debug("Successfully unloaded model '{}'.".format(model_file))
        return response.json(None, status=204)

    @app.get("/domain")
    @requires_auth(app, auth_token)
    @ensure_loaded_agent(app)
    async def get_domain(request: Request):
        """Get current domain in yaml or json format."""

        accepts = request.headers.get("Accept", default="application/json")
        if accepts.endswith("json"):
            domain = app.agent.domain.as_dict()
            return response.json(domain)
        elif accepts.endswith("yml") or accepts.endswith("yaml"):
            domain_yaml = app.agent.domain.as_yaml()
            return response.text(
                domain_yaml, status=200, content_type="application/x-yml"
            )
        else:
            raise ErrorResponse(
                406,
                "NotAcceptable",
                "Invalid Accept header. Domain can be "
                "provided as "
                'json ("Accept: application/json") or '
                'yml ("Accept: application/x-yml"). '
                "Make sure you've set the appropriate Accept "
                "header.",
            )

    return app
def test_invalid_initialization_object():
    """Initialize must reject a target that is neither an app nor a blueprint."""
    sanic_app = Sanic()
    bogus_target = object

    with pytest.raises(exceptions.InitializationFailure):
        Initialize(bogus_target, app=sanic_app, authenticate=lambda: True)