Exemplo n.º 1
0
class OAuthClient:
    """
    Helper class to handle the OAuth authentication flow.

    The logic is divided in 2 steps:
    - open the browser on the GitGuardian login screen and run a local server
      to wait for the callback
    - handle the OAuth callback to exchange an authorization code against a
      valid access token

    The authorization-code flow is protected with PKCE (a code verifier and
    its S256 challenge are generated at construction time) and with a `state`
    value that must come back unchanged in the callback URL.
    """

    def __init__(self, config: Config, instance: str) -> None:
        """
        :param config: ggshield configuration; the obtained token is written
            to it and persisted through `config.save()`
        :param instance: key of the entry in `config.auth_config.instances`
            the new token is saved for
        """
        self.config = config
        self.instance = instance
        self._oauth_client = WebApplicationClient(CLIENT_ID)
        self._state = ""  # use the `state` property instead

        # The wrapper receives this client so its request handler can call
        # `process_callback`; `_wait_for_callback` polls its `complete` and
        # `error_message` attributes.
        self._handler_wrapper = RequestHandlerWrapper(oauth_client=self)
        self._access_token: Optional[str] = None
        # First candidate port; `_prepare_server` overwrites it with the port
        # that was actually bound.
        self._port = USABLE_PORT_RANGE[0]
        self.server: Optional[HTTPServer] = None

        self._generate_pkce_pair()

    def oauth_process(self,
                      token_name: Optional[str] = None,
                      lifetime: Optional[int] = None) -> None:
        """
        Handle the whole OAuth process, which includes:
        - opening the user's web browser to the GitGuardian login page
        - opening a local server and waiting for the callback processing

        :param token_name: name of the created token; defaults to
            "ggshield token <today's date, YYYY-MM-DD>"
        :param lifetime: token lifetime (presumably in days, like
            `default_token_lifetime`); defaults to the instance default,
            where None means no expiry
        :raises click.ClickException: if no local port is free, the user
            aborts with Ctrl+C, or the callback processing reported an error

        The local server socket is always closed when the process ends,
        whether it succeeded or not.
        """
        # enable redirection to http://localhost
        os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = str(True)

        if token_name is None:
            token_name = "ggshield token " + datetime.today().strftime(
                "%Y-%m-%d")
        self._token_name = token_name

        if lifetime is None:
            lifetime = self.default_token_lifetime
        self._lifetime = lifetime

        self._prepare_server()
        try:
            self._redirect_to_login()
            self._wait_for_callback()
        finally:
            # Release the listening socket even when the flow fails or is
            # aborted; previously it was never closed.
            self._close_server()

        message = f"Created Personal Access Token {self._token_name} "
        expire_at = self.instance_config.account.expire_at
        if expire_at is not None:
            message += "expiring on " + get_pretty_date(expire_at)
        else:
            message += "with no expiry"
        click.echo(message)

    def process_callback(self, callback_url: str) -> None:
        """
        Process the OAuth callback request.

        This function runs within the request handler do_GET method.
        It takes the url of the callback request as argument and does:
        - Extract the authorization code
        - Exchange the code against an access token with GitGuardian's api
        - Validate the new token against GitGuardian's api
        - Save the token in configuration
        Any error during this process will raise an OAuthError.
        """
        authorization_code = self._get_code(callback_url)
        self._claim_token(authorization_code)
        token_data = self._validate_access_token()
        self._save_token(token_data)

    def _generate_pkce_pair(self) -> None:
        """
        Generate the PKCE pair used to secure the code exchange.

        `code_verifier` is a 128-character random string. `code_challenge`
        is its SHA-256 digest, base64url-encoded with the "=" padding
        stripped (the "S256" method).
        """
        self.code_verifier = self._oauth_client.create_code_verifier(
            128)  # type: ignore
        self.code_challenge = (urlsafe_b64encode(
            sha256(self.code_verifier.encode()).digest()).decode().rstrip("="))

    def _redirect_to_login(self) -> None:
        """
        Open the user's browser to the GitGuardian ggshield authentication page.

        The URL carries the PKCE challenge, the `state` value and the local
        `redirect_uri` the dashboard should call back.
        """
        static_params = {
            "auth_mode": "ggshield_login",
            "utm_source": "cli",
            "utm_medium": "login",
            "utm_campaign": "ggshield",
        }
        request_uri = self._oauth_client.prepare_request_uri(
            uri=urlparse.urljoin(self.dashboard_url, "auth/login"),
            redirect_uri=self.redirect_uri,
            scope=SCOPE,
            code_challenge=self.code_challenge,
            code_challenge_method="S256",
            state=self.state,
            **static_params,
        )
        click.echo(
            f"To complete the login process, follow the instructions from {request_uri}.\n"
            "Opening your web browser now...")
        webbrowser.open_new_tab(request_uri)

    def _prepare_server(self) -> None:
        """
        Bind the local callback server on the first free port of
        `USABLE_PORT_RANGE` and store it in `self.server` and `self._port`.

        :raises click.ClickException: if every port in the range is taken
        """
        for port in range(*USABLE_PORT_RANGE):
            try:
                self.server = HTTPServer(
                    # only consider requests from localhost on the predetermined port
                    ("127.0.0.1", port),
                    # attach the wrapped request handler
                    self._handler_wrapper.request_handler,
                )
                self._port = port
                break
            except OSError:
                # port already in use (or otherwise unbindable): try the next one
                continue
        else:
            raise click.ClickException("Could not find unoccupied port.")

    def _close_server(self) -> None:
        """
        Close the local callback server's socket, if a server was opened.
        Calling it when no server exists is a no-op.
        """
        if self.server is not None:
            self.server.server_close()

    def _wait_for_callback(self) -> None:
        """
        Wait to receive and process the authorization callback on the local server.

        This actually catches HTTP requests made on the previously opened
        server. The callback processing logic is implemented in the request
        handler class and the `process_callback` method.

        :raises click.ClickException: on Ctrl+C, or when the handler recorded
            an error message
        """
        try:
            while not self._handler_wrapper.complete:
                # Wait for a callback on the local server including an
                # authorization code; any matching request gets processed by
                # the request handler and the `process_callback` function.
                self.server.handle_request()  # type: ignore
        except KeyboardInterrupt:
            raise click.ClickException("Aborting")

        # If no error message is attached, the process is considered successful.
        if self._handler_wrapper.error_message is not None:
            raise click.ClickException(self._handler_wrapper.error_message)

    def _get_code(self, uri: str) -> str:
        """
        Extract the authorization code from the incoming request uri.

        The state from the uri is checked against the one stored internally.

        :raises OAuthError: if no code can be extracted or the state is invalid
        :return: the extracted authorization code
        """
        try:
            authorization_code = self._oauth_client.parse_request_uri_response(
                uri, self.state).get("code")
        except OAuth2Error:
            authorization_code = None
        if authorization_code is None:
            raise OAuthError(
                "Invalid code or state received from the callback.")
        return authorization_code  # type: ignore

    def _claim_token(self, authorization_code: str) -> None:
        """
        Exchange the authorization code for an access token using GitGuardian's public api.

        On success the token is stored in `self._access_token` and set as
        `config.auth_config.current_token`.

        :raises OAuthError: if the request fails or the response does not
            contain a token
        """
        request_params = {"name": self._token_name}
        if self._lifetime is not None:
            request_params["lifetime"] = str(self._lifetime)

        request_body = self._oauth_client.prepare_request_body(
            code=authorization_code,
            redirect_uri=self.redirect_uri,
            code_verifier=self.code_verifier,
            body=urlparse.urlencode(request_params),
        )

        response = requests.post(
            urlparse.urljoin(self.api_url, "oauth/token"),
            request_body,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )

        if not response.ok:
            raise OAuthError("Cannot create a token.")

        try:
            access_token = response.json()["key"]
        except (ValueError, KeyError, TypeError) as exc:
            # A non-JSON or unexpected payload must surface as an OAuthError,
            # like every other failure of the callback processing.
            raise OAuthError("Cannot create a token.") from exc

        self._access_token = access_token
        self.config.auth_config.current_token = self._access_token

    def _validate_access_token(self) -> Dict[str, Any]:
        """
        Validate the token using GitGuardian's public api.

        The client comes from `retrieve_client(self.config)`. It presumably
        authenticates with the `current_token` set by `_claim_token`, so
        confirm this against `retrieve_client`.

        :raises OAuthError: if the token endpoint rejects the token
        :return: the decoded JSON body of the "token" endpoint
        """
        response = retrieve_client(self.config).get(endpoint="token")
        if not response.ok:
            raise OAuthError("The created token is invalid.")
        return response.json()  # type: ignore

    def _save_token(self, api_token_data: Dict[str, Any]) -> None:
        """
        Save the new token as the instance's account and persist the configuration.

        :param api_token_data: payload returned by `_validate_access_token`;
            "account_id" is required, while "expire_at", "name" and "type"
            are optional
        """
        account_config = AccountConfig(
            account_id=api_token_data["account_id"],
            token=self._access_token,  # type: ignore
            expire_at=api_token_data.get("expire_at"),
            token_name=api_token_data.get("name", ""),
            type=api_token_data.get("type", ""),
        )
        self.instance_config.account = account_config
        self.config.save()

    @property
    def instance_config(self) -> InstanceConfig:
        """Configuration entry of `self.instance` in `config.auth_config.instances`."""
        return self.config.auth_config.instances[self.instance]

    @property
    def default_token_lifetime(self) -> Optional[int]:
        """
        Return the default token lifetime saved in the instance config, in days.

        The stored value exposes `.days` (presumably a `timedelta`).
        None is interpreted as no expiry.
        """
        default_lifetime = self.instance_config.default_token_lifetime
        if default_lifetime is not None:
            return default_lifetime.days

        return None

    @property
    def redirect_uri(self) -> str:
        """URL of the local callback server, using the currently selected port."""
        return f"http://localhost:{self._port}"

    @property
    def state(self) -> str:
        """
        Return the state used to verify the auth process.

        The state is sent as a parameter of the authorization request URL and
        is expected back, unchanged, in the callback url. If both states don't
        match, the process fails. The state is a url-quoted JSON dict holding
        the token name and lifetime. It is cached on first access so its value
        cannot change during the process, which means `oauth_process` must
        have set the token name and lifetime before the first access.
        """
        if not self._state:
            self._state = urlparse.quote(
                json.dumps({
                    "token_name": self._token_name,
                    "lifetime": self._lifetime
                }))
        return self._state

    @property
    def dashboard_url(self) -> str:
        """Base URL of the GitGuardian dashboard, taken from the configuration."""
        return self.config.dashboard_url

    @property
    def api_url(self) -> str:
        """Base URL of the GitGuardian API, taken from the configuration."""
        return self.config.api_url