Esempio n. 1
0
async def get_nodes(request: Request) -> ORJSONResponse:
    """List nodes owned by the current user.

    Query parameters:
        limit: maximum number of rows (default 50, allowed range 0-100).
        all: if present, return the user's nodes regardless of parent.
        parent: UUID of the parent node; when absent, the user's root nodes
            (``parent is null``) are returned.

    Raises:
        HTTPException(422): ``limit`` is not an integer, is negative or is
            larger than 100, or ``parent`` is not a valid UUID.
    """
    from uuid import UUID

    try:
        limit = int(request.query_params.get("limit", 50))
    except ValueError as e:
        raise HTTPException("Invalid limit", 422) from e
    if limit < 0:
        # PostgreSQL rejects a negative LIMIT; report it as bad input
        # instead of letting the database error become a 500.
        raise HTTPException("Invalid limit", 422)
    if limit > 100:
        raise HTTPException("Limit too large", 422)

    if "all" in request.query_params:
        cursor = await connection().fetch(
            "select * from nodes where user_id = $1 order by id limit $2;",
            request.user.id,
            limit,
        )
        return ORJSONResponse(list(cursor))

    parent_param = request.query_params.get("parent")
    if parent_param is None:
        # No parent given: list root nodes. `parent = NULL` is never true in
        # SQL, so root nodes must be selected with an explicit IS NULL test.
        cursor = await connection().fetch(
            "select * from nodes where user_id = $1 and parent is null order by id limit $2;",
            request.user.id,
            limit,
        )
        return ORJSONResponse(list(cursor))

    # Validate the id here rather than letting the driver fail on a
    # malformed uuid argument.
    try:
        parent = UUID(parent_param)
    except ValueError as e:
        raise HTTPException("invalid ID", 422) from e
    cursor = await connection().fetch(
        "select * from nodes where user_id = $1 and parent = $2 order by id limit $3;",
        request.user.id,
        parent,
        limit,
    )
    return ORJSONResponse(list(cursor))
Esempio n. 2
0
async def validate_api_key(uuid: UUID) -> Tuple[User, List[str]]:
    """Resolve an API key id into its owning user and granted scopes.

    Args:
        uuid: id of the row in ``api_keys``.

    Returns:
        ``(User, scopes)``. The User is built from the owner's ``users``
        columns; scopes come from the JSON-encoded ``api_keys.scope`` column.

    Raises:
        HTTPException(401): the key does not exist or has expired.
    """
    # Columns are listed without surrounding parentheses: `select (a, b, ...)`
    # would yield one composite ROW column instead of named columns.
    row = await connection().fetchrow(
        """
        select
            api_keys.expiry,
            api_keys.scope,
            users.id,
            users.name,
            users.email,
            users.password,
            users.two_factor_secret,
            users.two_factor_recovery
        from api_keys inner join users on api_keys.user_id = users.id
        where api_keys.id = $1 limit 1;
        """,
        uuid,
    )
    if row is None:
        raise HTTPException("invalid API key", 401)

    # asyncpg Records are read-only (no .pop), so copy into a dict. After the
    # api_keys columns are popped, only the users columns remain for User(...).
    fields = dict(row)
    expiry = fields.pop("expiry")
    # NOTE(review): compares against naive local time; if api_keys.expiry is a
    # timestamptz column this comparison raises TypeError — confirm the type.
    if expiry is not None and expiry <= datetime.now():
        # expired
        raise HTTPException("invalid API key", 401)
    scope = fields.pop("scope")
    return (
        User(**fields),
        orjson.loads(scope),
    )
Esempio n. 3
0
async def cancel_two_factor_setup(request: Request) -> ORJSONResponse:
    """Abandon a two-factor setup that was started but never confirmed.

    Clears the user's pending ``two_factor_secret``.

    Raises:
        HTTPException(422): no setup is in progress ("Not enabled"), or the
            setup has already been confirmed ("Already confirmed").
    """
    user = request.user
    if not user.two_factor_secret:
        raise HTTPException("Not enabled", 422)
    if user.two_factor_recovery:
        # a recovery code exists only once setup was confirmed
        raise HTTPException("Already confirmed", 422)

    db = connection()
    await db.execute("update users set two_factor_secret = null where id = $1",
                     user.id)
    return ORJSONResponse()
Esempio n. 4
0
def check_password_strength(password: str) -> Literal[True]:
    """Validate a candidate password, raising HTTPException if it is unacceptable.

    Checks, in order:
    - the "beef stew" easter egg (418, "Password not stroganoff");
    - ``strength(password) < 8`` (422, "Password not strong enough");
    - any control character, i.e. Unicode general category ``Cc``
      (U+0000-U+001F and U+007F-U+009F)
      (422, "Invalid character in password").

    Returns:
        True when every check passes.
    """
    import unicodedata

    if password == "beef stew":
        # easter egg
        raise HTTPException("Password not stroganoff", 418)
    elif strength(password) < 8:
        raise HTTPException("Password not strong enough", 422)
    elif any(unicodedata.category(char) == "Cc" for char in password):
        # Control character. Checking the Unicode category also catches DEL
        # (U+007F) and the C1 controls, which a plain `< " "` test misses.
        raise HTTPException("Invalid character in password", 422)
    return True
Esempio n. 5
0
async def delete_node(request: Request) -> ORJSONResponse:
    """Delete one of the current user's nodes.

    Query parameters:
        id: UUID of the node to delete.

    Raises:
        HTTPException(422): ``id`` is missing or not a valid UUID.
        HTTPException(404): no node with that id belongs to the current user.
    """
    try:
        id_ = uuid.UUID(request.query_params.get("id"))
    except (ValueError, TypeError) as e:
        # TypeError covers a missing `id` (uuid.UUID(None))
        raise HTTPException("invalid ID", 422) from e

    # Restrict the delete to the caller's own nodes. Without the user_id
    # filter, any authenticated user could delete any node by id.
    result = await connection().execute(
        "delete from nodes where id = $1 and user_id = $2",
        id_,
        request.user.id,
    )
    # execute() returns the command status tag, e.g. "DELETE 1"
    if result != "DELETE 1":
        raise HTTPException("not found", 404)

    return ORJSONResponse()
Esempio n. 6
0
async def signup(request: Request) -> ORJSONResponse:
    """Begin an e-mail-confirmed signup.

    Body: a JSON string holding the e-mail address. The address is validated
    and normalized, rejected if an account already uses it, and then a
    "signup_confirm" message is sent to it via ``send_message``.

    Returns:
        ORJSONResponse(None, 202).

    Raises:
        HTTPException(422): the address is invalid or already in use.
    """
    json = await get_json(request)
    validate_types_raising(json, str)
    try:
        # validate_email returns a ValidatedEmail object, not a str. `.email`
        # is the normalized address string needed by the query and the mailer.
        email = email_validator.validate_email(json).email
    except email_validator.EmailNotValidError as e:
        raise HTTPException("invalid email address", 422) from e

    if await connection().fetchval("""select id from users where email = $1""", email):
        raise HTTPException("email address in use", 422)

    await send_message(email, "Confirm Tome account", "signup_confirm")
    return ORJSONResponse(None, 202)
Esempio n. 7
0
async def _create_account(name: str, email: str, password: str) -> UUID:
    """Insert a new user and return the id generated by the database.

    The password is checked with ``check_password_strength`` and stored
    hashed via ``hash_password``.

    Raises:
        HTTPException(422): *name* is empty ("empty name") or the e-mail is
            already taken ("email address in use").
        Any HTTPException raised by ``check_password_strength``.
    """
    if not name:
        raise HTTPException("empty name", 422)
    check_password_strength(password)

    try:
        db = connection()
        new_id = await db.fetchval(
            """
            insert into users (email, name, password) values ($1, $2, $3) returning id;
            """,
            email,
            name,
            hash_password(password),
        )
    except asyncpg.UniqueViolationError as err:
        # a unique constraint rejected the row; reported as a taken address
        raise HTTPException("email address in use", 422) from err
    return new_id
Esempio n. 8
0
def validate_types_raising(data: Any, type_: Any) -> None:
    """Check *data* against *type_* with ``validate_types``, raising on mismatch.

    Accepts exactly the same arguments as ``validate_types``.

    Raises:
        HTTPException(422): ``validate_types`` reports *data* as invalid.
    """
    is_valid = validate_types(data, type_)
    if is_valid:
        return
    raise HTTPException("invalid types", 422)
Esempio n. 9
0
    def decorator(
        endpoint: Callable[[Request], Awaitable[Response]]
    ) -> Callable[[Request], Awaitable[Response]]:
        """Wrap *endpoint* so it runs only when the request has every scope in ``scopes_set``.

        If any scope in ``scopes_set`` is missing from ``request.auth``:
        - when ``redirect`` is None, an HTTPException is raised with
          ``detail``, ``headers`` and ``status_code`` (default 401);
        - otherwise a RedirectResponse to ``redirect`` is returned, with
          ``headers`` and ``status_code`` (default 307).

        ``scopes_set``, ``redirect``, ``detail``, ``status_code`` and
        ``headers`` come from the enclosing function.
        """
        if redirect is None:
            denied_status = status_code or 401

            async def inner(request: Request) -> Response:
                if scopes_set - set(request.auth):
                    # Build a new exception for every denied request. Raising
                    # one shared instance keeps prepending frames to its
                    # existing __traceback__ on each raise (unbounded growth),
                    # and shares that mutable state across concurrent requests.
                    raise HTTPException(detail=detail,
                                        status_code=denied_status,
                                        headers=headers)
                return await endpoint(request)

        else:
            # built once and returned for every denied request
            response = RedirectResponse(redirect,
                                        status_code=status_code or 307,
                                        headers=headers)

            async def inner(request: Request) -> Response:
                if scopes_set - set(request.auth):
                    return response
                else:
                    return await endpoint(request)

        return functools.wraps(endpoint)(inner)
Esempio n. 10
0
async def validate_auth_token(
        token: Union[str, bytes]) -> Tuple[User, List[str]]:
    """Decode a JWT and load the user named by its ``sub`` claim.

    Returns:
        ``(User, scopes)``, where scopes is the token's ``scope`` claim.

    Raises:
        HTTPException(401): the token is invalid (raised by ``decode_jwt``).
        HTTPException(409): no user row matches the ``sub`` claim.
    """
    claims = decode_jwt(token)
    user_id = claims["sub"]
    row = await connection().fetchrow("SELECT * FROM users WHERE id = $1",
                                      user_id)
    if row is not None:
        return User(**row), claims["scope"]
    raise HTTPException("account not available", 409)
Esempio n. 11
0
async def patch_account_name(request: Request) -> ORJSONResponse:
    """Rename the current user.

    Body: a non-empty JSON string with the new name.

    Raises:
        HTTPException(400): the name is empty ("invalid name").
    """
    new_name = await get_json(request)
    validate_types_raising(new_name, str)
    if new_name:
        await connection().execute("update users set name = $1 where id = $2",
                                   new_name, request.user.id)
        return ORJSONResponse()
    raise HTTPException("invalid name", 400)
Esempio n. 12
0
async def begin_two_factor_setup(request: Request) -> ORJSONResponse:
    """Start TOTP two-factor setup for the current user.

    Generates a new base32 secret with ``pyotp.random_base32(32)`` and
    stores it as ``two_factor_secret``. Setup stays unconfirmed until a
    recovery code is issued.

    Returns:
        JSON object with ``secret`` and ``qr_code_url`` (from
        ``make_totp_qr_code``).

    Raises:
        HTTPException(422): 2FA is already enabled, or setup was already started.
    """
    user = request.user
    if user.two_factor_recovery:
        raise HTTPException("Already enabled", 422)
    if user.two_factor_secret:
        raise HTTPException("Setup already started", 422)

    new_secret = pyotp.random_base32(32)
    await connection().execute(
        "update users set two_factor_secret = $1 where id = $2",
        new_secret,
        user.id,
    )

    qr_code_url = make_totp_qr_code(new_secret, user.email)
    payload = {"secret": new_secret, "qr_code_url": qr_code_url}
    return ORJSONResponse(payload)
Esempio n. 13
0
async def confirm_two_factor_setup(request: Request) -> ORJSONResponse:
    """Confirm 2FA setup by checking a TOTP code against the pending secret.

    Body: a JSON string holding the current TOTP code.

    Returns:
        The newly generated recovery code (also stored as
        ``two_factor_recovery``).

    Raises:
        HTTPException(422): setup was already confirmed, was never started,
            or the code is incorrect.
    """
    code = await get_json(request)
    validate_types_raising(code, str)

    user = request.user
    if user.two_factor_recovery:
        raise HTTPException("Already confirmed", 422)
    if not user.two_factor_secret:
        raise HTTPException("Not enabled", 422)
    totp = pyotp.TOTP(user.two_factor_secret)
    if not totp.verify(code):
        raise HTTPException("Incorrect code", 422)

    recovery_code = pyotp.random_base32(32)
    await connection().execute(
        "update users set two_factor_recovery = $1 where id = $2",
        recovery_code,
        user.id,
    )
    return ORJSONResponse(recovery_code)
Esempio n. 14
0
async def test_http_exception_handler():
    """handle_http_exception wraps the detail under "error" as JSON and keeps the status."""
    import tome.exception_handlers
    from tome.exceptions import HTTPException

    handler = tome.exception_handlers.handle_http_exception
    response = await handler(None, HTTPException({"some": ["json", 45]}, 418))

    expected_body = {"error": {"some": ["json", 45]}}
    assert orjson.loads(response.body) == expected_body
    assert response.headers["content-type"] == "application/json"
    assert response.status_code == 418
Esempio n. 15
0
async def change_password(request: Request) -> ORJSONResponse:
    """Replace the current user's password.

    Body: JSON object ``{"new": str, "current": str}``.

    Raises:
        HTTPException(422): the new password equals the current one
            ("Password not changed").
        Any HTTPException from ``check_password_strength`` for a weak new password.
        HTTPException(401): ``current`` does not match the stored password.
    """
    body = await get_json(request)
    validate_types_raising(body, {"new": str, "current": str})
    new_password = body["new"]
    current_password = body["current"]

    # Cheap rejections come first. An unchanged password is refused even if
    # `current` is wrong, so there is no need to verify it in that case.
    if new_password == current_password:
        raise HTTPException("Password not changed", 422)
    check_password_strength(new_password)

    password_ok = verify_password(request.user.password, current_password)
    if not password_ok:
        raise HTTPException("Incorrect password", 401)

    new_hash = hash_password(new_password)
    await connection().execute("update users set password = $1 where id = $2",
                               new_hash, request.user.id)
    return ORJSONResponse()
Esempio n. 16
0
async def patch_account_email(request: Request) -> ORJSONResponse:
    """Change the current user's e-mail address.

    Body: a JSON string holding the new address. It is validated and stored
    in normalized form.

    Raises:
        HTTPException(400): the address is not valid.
        HTTPException(422): the address is already used by another account.
    """
    import asyncpg

    json = await get_json(request)
    validate_types_raising(json, str)
    try:
        email = email_validator.validate_email(json).email
    except email_validator.EmailNotValidError as e:
        raise HTTPException("invalid email address", 400) from e
    try:
        await connection().execute("update users set email = $1 where id = $2",
                                   email, request.user.id)
    except asyncpg.UniqueViolationError as e:
        # Presumably the unique constraint on users.email that account
        # creation also relies on. Report it as a taken address instead of
        # letting the driver error surface as a 500.
        raise HTTPException("email address in use", 422) from e
    return ORJSONResponse()
Esempio n. 17
0
async def signup_no_confirm(request: Request) -> ORJSONResponse:
    """Create an account immediately, without e-mail confirmation.

    Body: JSON object with string fields ``name``, ``password`` and ``email``.

    Returns:
        ORJSONResponse(user_id, 201) with the new user's id.

    Raises:
        HTTPException(422): the e-mail address is invalid.
        Errors from ``_create_account`` (empty name, weak password, address
        in use) propagate unchanged.
    """
    json = await get_json(request)
    validate_types_raising(json, {"name": str, "password": str, "email": str})
    try:
        # Store the normalized address (as patch_account_email does), so the
        # same mailbox is not saved under different spellings.
        email = email_validator.validate_email(json["email"]).email
    except email_validator.EmailNotValidError as e:
        raise HTTPException("invalid email address", 422) from e

    # Explicit arguments: `**json` would forward any extra keys the type
    # check tolerates and fail with TypeError inside _create_account.
    user_id = await _create_account(json["name"], email, json["password"])

    return ORJSONResponse(user_id, 201)
Esempio n. 18
0
async def get_json(request: starlette.requests.Request) -> Any:
    """
    Return the JSON-decoded body of a request.

    :param request: starlette request object
    :return: any valid JSON data
    :raises HTTPException: 400 if the body cannot be decoded as JSON
    """
    body = await request.body()
    try:
        return orjson.loads(body)
    except orjson.JSONDecodeError as e:
        # chain the decode error explicitly so the original cause is kept as
        # __cause__, as every other HTTPException conversion here does
        raise HTTPException(f"invalid json: {e}", 400) from e
Esempio n. 19
0
async def create_api_key(expiry: Optional[datetime], user_id: UUID,
                         scope: Sequence[str]) -> UUID:
    """Store a new API key for *user_id* and return its id.

    Args:
        expiry: when the key stops being valid, or None for no expiry.
        user_id: owner of the key.
        scope: granted scopes; each must be in ``ALLOWED_API_KEY_SCOPES``.
            Stored as a JSON-encoded string.

    Raises:
        HTTPException(422): *scope* contains an entry that is not allowed.
    """
    disallowed = set(scope) - ALLOWED_API_KEY_SCOPES
    if disallowed:
        raise HTTPException("cannot issue an API token with this scope", 422)

    scope_json = orjson.dumps(scope).decode()
    new_key_id = await connection().fetchval(
        "insert into api_keys (scope, user_id, expiry) values ($1, $2, $3) returning id",
        scope_json,
        user_id,
        expiry,
    )
    return cast(UUID, new_key_id)
Esempio n. 20
0
def decode_jwt(token: Union[str, bytes]) -> Any:
    """Verify and decode a JWT issued by this service.

    The signature, algorithm, issuer and audience are checked against
    ``SECRET_KEY``, ``ALGORITHM``, ``ISSUER`` and ``AUDIENCE``.

    Returns:
        The decoded claims.

    Raises:
        HTTPException(401): PyJWT rejected the token (``InvalidTokenError``).
    """
    try:
        claims = jwt.decode(
            token,
            key=SECRET_KEY,
            algorithms=[ALGORITHM],
            audience=AUDIENCE,
            issuer=ISSUER,
        )
    except jwt.InvalidTokenError as err:
        raise HTTPException("invalid token", 401) from err
    return claims
Esempio n. 21
0
async def delete_api_key(request: Request) -> ORJSONResponse:
    """Delete one of the current user's API keys.

    Query parameters:
        id: UUID of the key to delete.

    Raises:
        HTTPException(422): invalid query parameters, or ``id`` is not a valid UUID.
        HTTPException(404): no key with that id belongs to the current user.
    """
    from uuid import UUID

    query = dict(request.query_params)
    validate_types_raising(query, {"id": str})
    # Parse up front: a malformed string bound to a uuid parameter makes the
    # driver raise, which would otherwise surface as a 500.
    try:
        key_id = UUID(query["id"])
    except ValueError as e:
        raise HTTPException("invalid ID", 422) from e
    result = await connection().execute(
        "delete from api_keys where id = $1 and user_id = $2",
        key_id,
        request.user.id,
    )
    # status tag looks like "DELETE <count>"
    if result.split(" ")[1] == "0":
        raise HTTPException("not found", 404)
    return ORJSONResponse()
Esempio n. 22
0
async def signup_confirm(request: Request) -> ORJSONResponse:
    """Finish an e-mail-confirmed signup.

    Body: JSON object with string fields ``name``, ``password`` and ``token``.
    The account's e-mail address is the ``sub`` claim of the decoded token.

    Returns:
        ORJSONResponse(user_id, 201) with the new user's id.

    Raises:
        HTTPException(401): the token is invalid (raised by ``decode_jwt``).
        HTTPException(409): an account with this address already exists.
        Errors from ``_create_account`` (empty name, weak password, ...)
        propagate with their own status codes.
    """
    json = await get_json(request)
    validate_types_raising(json, {"name": str, "password": str, "token": str})

    # NOTE(review): any token accepted by decode_jwt is honoured here; confirm
    # that login tokens (whose "sub" is a user id) cannot be replayed as
    # signup tokens.
    email = decode_jwt(json["token"])["sub"]

    # Detect an existing account up front, as signup() does. Converting every
    # HTTPException from _create_account into a 409 hid validation errors such
    # as an empty name or a weak password. A concurrent duplicate is still
    # rejected by _create_account itself.
    if await connection().fetchval("select id from users where email = $1", email):
        raise HTTPException("account has already been created", 409)

    # Explicit arguments: `**json` also forwarded "token", which
    # _create_account does not accept, so every call raised TypeError.
    user_id = await _create_account(json["name"], email, json["password"])

    return ORJSONResponse(user_id, 201)
Esempio n. 23
0
async def modify_node(request: Request) -> ORJSONResponse:
    """Update a node's ``content`` and/or ``parent`` and return the query's row.

    Query parameters:
        id: UUID of the node to modify.

    Body: JSON object with optional ``content`` and ``parent`` strings.

    Raises:
        HTTPException(422): the node id is invalid, or the driver rejected an
            argument (TypeError/ValueError, e.g. a malformed parent id).
        HTTPException(404): the referenced parent does not exist.
    """
    raw_id = request.query_params.get("id")
    try:
        node_id = uuid.UUID(raw_id)
    except (ValueError, TypeError) as err:
        raise HTTPException("invalid ID", 422) from err

    body = await get_json(request)
    validate_types_raising(body, {
        "content": Optional[str],
        "parent": Optional[str],
    })

    # NOTE(review): the query is built only from the body and node id; whether
    # it also restricts to request.user's nodes depends on the template behind
    # _modify_node_query, which is not visible here — confirm.
    query_args = _modify_node_query(body, node_id)

    try:
        updated = await connection().fetchrow(*query_args)
    except (TypeError, ValueError) as err:
        raise HTTPException("Invalid ID", 422) from err
    except asyncpg.ForeignKeyViolationError as err:
        raise HTTPException("Parent does not exist", 404) from err

    return ORJSONResponse(updated)
Esempio n. 24
0
async def create_node(request: Request) -> ORJSONResponse:
    """Create a node owned by the current user and return its id.

    Body: JSON object with optional ``parent`` (UUID string or null) and
    ``content`` (string or null). A missing or null ``parent`` creates a root
    node. A missing or null ``content`` leaves that column out of the insert,
    so the table default applies.

    Raises:
        HTTPException(422): ``parent`` is not a valid UUID.
        HTTPException(404): the parent node does not exist.
    """
    json = await get_json(request)
    validate_types_raising(json, {
        "parent": Optional[str],
        "content": Optional[str]
    })

    # `parent` is Optional[str], so it may be absent or explicitly null; both
    # mean "root node". Passing None to uuid.UUID would raise TypeError.
    parent = None
    raw_parent = json.get("parent")
    if raw_parent is not None:
        try:
            parent = uuid.UUID(raw_parent)
        except ValueError as e:
            raise HTTPException("invalid ID", 422) from e

    # `content` is accepted and validated above, so store it when provided
    # rather than silently dropping it.
    content = json.get("content")
    try:
        if content is None:
            result = await connection().fetchval(
                "insert into nodes (user_id, parent) values ($1, $2) returning id",
                request.user.id,
                parent,
            )
        else:
            result = await connection().fetchval(
                "insert into nodes (user_id, parent, content) values ($1, $2, $3) returning id",
                request.user.id,
                parent,
                content,
            )
    except asyncpg.ForeignKeyViolationError as e:
        raise HTTPException("Parent node does not exist", 404) from e

    return ORJSONResponse(result)
Esempio n. 25
0
 async def authenticate(
         self, request: HTTPConnection) -> Tuple[Optional[User], List[str]]:
     """Resolve the ``Authorization`` header into ``(user, scopes)``.

     - header without the ``Bearer`` scheme: ``(None, ["anonymous"])``;
     - ``Bearer api-key-<uuid>``: resolved by ``auth.validate_api_key``;
     - any other bearer token: treated as a JWT and resolved by
       ``auth.validate_auth_token``.

     Raises:
         HTTPException(401): the part after ``api-key-`` is not a valid UUID.
     """
     api_key_prefix = "api-key-"
     scheme, _, credential = request.headers.get("Authorization", "").partition(" ")
     if scheme != "Bearer":
         return None, ["anonymous"]
     if not credential.startswith(api_key_prefix):
         # token is JWT
         return await auth.validate_auth_token(credential.encode())
     try:
         key_id = UUID(credential[len(api_key_prefix):])
     except ValueError as err:
         raise HTTPException("invalid API key", 401) from err
     return await auth.validate_api_key(key_id)
Esempio n. 26
0
    def inner(data: dict, *extra: Any) -> Iterable[Any]:
        """Build query arguments that update only the columns present in *data*.

        Each column in ``columns`` fills the next positional format field of
        ``template``. Columns present in *data* become ``$N`` placeholders,
        numbered upward from the enclosing ``placeholders`` value, and their
        values are collected. Absent columns are formatted as the bare column
        name, so the stored value is kept.

        Returns:
            ``[query, *extra, *values]``, ready to splat into ``fetchrow``.

        Raises:
            HTTPException(422): *data* has a key that is not in ``columns``.
        """
        # if there are any provided keys that don't exist as columns
        if data.keys() - columns:
            raise HTTPException("invalid types", 422)

        # Number placeholders from a per-call copy of the starting value. A
        # `nonlocal` counter would persist across calls, so every later call
        # would emit $N numbers beyond the arguments actually supplied.
        next_placeholder = placeholders
        format_args = []
        values = []
        for column in columns:
            if column in data:
                # format with a postgresql placeholder
                format_args.append("$" + str(next_placeholder))
                values.append(data[column])
                next_placeholder += 1
            else:
                # format with the column name, to keep the existing stored value
                format_args.append(column)

        return [template.format(*format_args), *extra, *values]
Esempio n. 27
0
async def get_api_key(request: Request) -> ORJSONResponse:
    """Return the current user's API keys as ``{id, scope, expiry}`` objects.

    With an ``id`` query parameter, return that single key as an object;
    otherwise return a list of all of the user's keys.

    Raises:
        HTTPException(422): ``id`` is not a valid UUID.
        HTTPException(404): no key with that id belongs to the current user.
    """
    from uuid import UUID

    if "id" in request.query_params:
        # Parse up front: a malformed string bound to a uuid parameter makes
        # the driver raise, which would otherwise surface as a 500.
        try:
            key_id = UUID(request.query_params["id"])
        except ValueError as e:
            raise HTTPException("invalid ID", 422) from e
        row = await connection().fetchrow(
            "select id, scope, expiry from api_keys where id = $1 and user_id = $2",
            key_id,
            request.user.id,
        )
        if not row:
            raise HTTPException("not found", 404)
        result = dict(row)
    else:
        rows = await connection().fetch(
            "select id, scope, expiry from api_keys where user_id = $1",
            request.user.id,
        )
        result = [dict(row) for row in rows]
    return ORJSONResponse(result)
Esempio n. 28
0
 async def fake_authenticate_incorrect(_self, _request):
     """Test double that always fails with HTTPException("foo bar", 418).

     Presumably swapped in for a backend's ``authenticate`` method in tests;
     both arguments are ignored.
     """
     failure = HTTPException("foo bar", 418)
     raise failure