Beispiel #1
0
    def _get_token(self, timeout=100):
        """
        Request a new JWT from the delegation endpoint and cache it.

        The token is obtained by POSTing the client credentials to
        ``self.domain + "/auth/delegation"``. On success the returned
        ``id_token`` is stored on ``self._token`` and, if
        ``self.token_info_path`` is set, also written under the ``jwt_token``
        key of the JSON file at that path (other keys already present in the
        file are kept).

        :param timeout: Timeout value handed to the HTTP POST request.

        :raises AuthError: If ``client_id`` or ``client_secret`` is ``None``.
        :raises OauthError: If the endpoint answers with a status other than
            200, or if the response carries no ``id_token``.
        """
        if self.client_id is None:
            raise AuthError("Could not find CLIENT_ID")

        if self.client_secret is None:
            raise AuthError("Could not find CLIENT_SECRET")

        retry_options = dict(
            total=5,
            backoff_factor=random.uniform(1, 10),
            status_forcelist=[429, 500, 502, 503, 504],
        )
        retry_methods = frozenset(['GET', 'POST'])
        try:
            # urllib3 >= 1.26 spells the option ``allowed_methods``.
            retries = Retry(allowed_methods=retry_methods, **retry_options)
        except TypeError:
            # Older urllib3 only knows the original ``method_whitelist`` name.
            retries = Retry(method_whitelist=retry_methods, **retry_options)

        headers = {"content-type": "application/json"}
        params = {
            "scope": " ".join(self.scope),
            "client_id": self.client_id,
            "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
            "target": self.client_id,
            "api_type": "app",
            "refresh_token": self.client_secret
        }

        s = requests.Session()
        try:
            s.mount('https://', HTTPAdapter(max_retries=retries))
            r = s.post(self.domain + "/auth/delegation", headers=headers, data=json.dumps(params), timeout=timeout)
        finally:
            # The session is private to this call; release its connections.
            s.close()

        if r.status_code != 200:
            raise OauthError("%s: %s" % (r.status_code, r.text))

        data = r.json()
        id_token = data.get('id_token')
        if id_token is None:
            # Report a missing token as an auth failure, not a bare KeyError.
            raise OauthError("could not retrieve token")
        self._token = id_token

        token_info = {}

        if self.token_info_path:
            try:
                with open(self.token_info_path) as fp:
                    token_info = json.load(fp)
            except (IOError, ValueError):
                # A missing or unparsable file means there is nothing to merge.
                pass

            # Valid JSON that is not an object (e.g. ``[]``) cannot hold the
            # token entry, so start over with an empty mapping.
            if not isinstance(token_info, dict):
                token_info = {}

        token_info['jwt_token'] = self._token

        if self.token_info_path:
            token_info_directory = os.path.dirname(self.token_info_path)
            makedirs_if_not_exists(token_info_directory)

            try:
                with open(self.token_info_path, 'w+') as fp:
                    json.dump(token_info, fp)

                # Restrict the cached token to owner read/write.
                os.chmod(self.token_info_path, stat.S_IRUSR | stat.S_IWUSR)
            except (IOError, OSError) as e:
                # Saving is best effort: warn instead of failing the token fetch.
                warnings.warn('failed to save token: {}'.format(e))
Beispiel #2
0
    def payload(self):
        """
        Decode and return the claims section of the current token.

        If no token is cached yet, one is requested through ``_get_token()``
        first. The second dot-separated segment of the token is then
        base64url-decoded and parsed as JSON.

        :rtype: dict
        :return: Dictionary containing the fields specified by scope, which may include:

            .. highlight:: none

            ::

                name:           The name of the user.
                groups:         Groups to which the user belongs.
                org:            The organization to which the user belongs.
                email:          The email address of the user.
                email_verified: True if the user's email has been verified.
                sub:            The user identifier.
                exp:            The expiration time of the token, in seconds since
                                the start of the unix epoch.

        :raises ~descarteslabs.client.exceptions.AuthError: Raised when
            incomplete information has been provided, or when the token
            cannot be split, decoded or parsed.
        :raises ~descarteslabs.client.exceptions.OauthError: Raised when
            a token cannot be obtained or refreshed.
        """
        if self._token is None:
            self._get_token()

        # Work on bytes throughout; text tokens are encoded first.
        raw_token = self._token
        if isinstance(raw_token, six.text_type):
            raw_token = raw_token.encode("utf-8")

        try:
            segments = raw_token.split(b".")
            claims_json = base64url_decode(segments[1]).decode("utf-8")
            return json.loads(claims_json)
        except (
                IndexError,
                UnicodeDecodeError,
                binascii.Error,
                json.JSONDecodeError,
        ) as e:
            raise AuthError("Unable to read token: {}".format(e))
Beispiel #3
0
    def _get_token(self, timeout=100):
        """
        Fetch a fresh token from ``self.domain + "/token"`` and cache it.

        The request body depends on whether ``client_id`` is the legacy
        client id. The ``access_token`` of the response is preferred; the
        legacy ``id_token`` is used when no access token is present. The
        chosen token is stored on ``self._token`` and, when
        ``self.token_info_path`` is set, merged into the JSON file at that
        path under the ``jwt_token`` key.

        :param timeout: Timeout value handed to the HTTP POST request.

        :raises AuthError: If ``client_id`` is ``None``, or if both
            ``client_secret`` and ``refresh_token`` are ``None``.
        :raises OauthError: If the response status is not 200 or the response
            contains neither an ``access_token`` nor an ``id_token``.
        """
        if self.client_id is None:
            raise AuthError("Could not find client_id")

        if self.client_secret is None and self.refresh_token is None:
            raise AuthError("Could not find client_secret or refresh token")

        # TODO(justin) remove legacy handling
        legacy_client_ids = ("ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c",)

        if self.client_id in legacy_client_ids:
            # TODO (justin) insert deprecation warning
            requested_scope = self.scope
            if requested_scope is None:
                requested_scope = ["openid", "name", "groups", "org", "email"]
            body = {
                "scope": " ".join(requested_scope),
                "client_id": self.client_id,
                "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
                "target": self.client_id,
                "api_type": "app",
                "refresh_token": self.refresh_token,
            }
        else:
            body = {
                "client_id": self.client_id,
                "grant_type": "refresh_token",
                "refresh_token": self.refresh_token,
            }
            # Scope is only sent when one was configured.
            if self.scope is not None:
                body["scope"] = " ".join(self.scope)

        response = self.session.post(
            self.domain + "/token", json=body, timeout=timeout
        )
        if response.status_code != 200:
            raise OauthError("%s: %s" % (response.status_code, response.text))

        data = response.json()
        # Prefer the access token; fall back to the legacy id token.
        # TODO(justin) remove legacy id_token usage
        for key in ("access_token", "id_token"):
            candidate = data.get(key)
            if candidate is not None:
                self._token = candidate
                break
        else:
            raise OauthError("could not retrieve token")

        cache_path = self.token_info_path
        if not cache_path:
            # Nothing to persist without a token info file.
            return

        stored = {}
        try:
            with open(cache_path) as fp:
                stored = json.load(fp)
        except (IOError, ValueError):
            # Missing or unparsable file: start from an empty mapping.
            pass

        stored["jwt_token"] = self._token

        makedirs_if_not_exists(os.path.dirname(cache_path))

        try:
            with open(cache_path, "w+") as fp:
                json.dump(stored, fp)

            # Restrict the cached token to owner read/write.
            os.chmod(cache_path, stat.S_IRUSR | stat.S_IWUSR)
        except IOError as e:
            warnings.warn("failed to save token: {}".format(e))
class TestAuth(unittest.TestCase):
    """Tests for ``Auth`` construction, token retrieval and token refresh."""

    TOKEN_URL = "https://accounts.descarteslabs.com/token"
    LEGACY_CLIENT_ID = "ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c"
    ENVIRON_TARGET = "descarteslabs.client.auth.auth.os.environ"

    def tearDown(self):
        # Clear every warnings filter after each test; test_auth_init_env_vars
        # installs an "ignore" filter.
        warnings.resetwarnings()

    @staticmethod
    def _secret_auth():
        """Build an ``Auth`` with a plain client id and client secret."""
        return Auth(
            token_info_path=None,
            client_secret="client_secret",
            client_id="client_id",
        )

    def _legacy_auth(self):
        """Build an ``Auth`` that uses the legacy client id."""
        return Auth(
            token_info_path=None,
            client_secret="client_secret",
            client_id=self.LEGACY_CLIENT_ID,
        )

    @staticmethod
    def _fake_token(claims):
        """Join three base64-encoded segments into a token carrying ``claims``."""
        segments = ["header", json.dumps(claims), "sig"]
        return b".".join(base64.b64encode(to_bytes(s)) for s in segments)

    def test_auth_client_refresh_match(self):
        """A mismatched refresh token wins and triggers exactly one warning."""
        with warnings.catch_warnings(record=True) as caught:
            auth = Auth(
                client_id="client_id",
                client_secret="secret",
                refresh_token="mismatched_refresh_token",
            )
            assert 1 == len(caught)
            assert "mismatched_refresh_token" == auth.refresh_token
            assert "mismatched_refresh_token" == auth.client_secret

    @responses.activate
    def test_get_token(self):
        """An ``access_token`` in the response becomes the cached token."""
        responses.add(
            responses.POST,
            self.TOKEN_URL,
            json={"access_token": "access_token"},
            status=200,
        )
        auth = self._secret_auth()
        auth._get_token()

        assert "access_token" == auth._token

    @responses.activate
    def test_get_token_legacy(self):
        """An ``id_token`` is used when the response has no access token."""
        responses.add(
            responses.POST,
            self.TOKEN_URL,
            json={"id_token": "id_token"},
            status=200,
        )
        auth = self._secret_auth()
        auth._get_token()

        assert "id_token" == auth._token

    @patch("descarteslabs.client.auth.Auth.payload", new={"sub": "asdf"})
    def test_get_namespace(self):
        """The namespace for the patched ``sub`` claim matches the expected value."""
        auth = self._secret_auth()
        assert auth.namespace == "3da541559918a808c2402bba5012f6c60b27661c"

    def test_init_token_no_path(self):
        """A token passed to the constructor is stored as-is."""
        auth = Auth(jwt_token="token", token_info_path=None, client_id="foo")
        assert "token" == auth._token

    @responses.activate
    def test_get_token_schema_internal_only(self):
        """The credential may be supplied as refresh token or as client secret."""
        responses.add_callback(
            responses.POST,
            self.TOKEN_URL,
            callback=token_response_callback,
        )
        for credential_name in ("refresh_token", "client_secret"):
            credentials = {credential_name: "refresh_token"}
            auth = Auth(
                token_info_path=None, client_id="client_id", **credentials
            )
            auth._get_token()

            assert "access_token" == auth._token

    @responses.activate
    def test_get_token_schema_legacy_internal_only(self):
        """The legacy client id ends up with the ``id_token``."""
        responses.add_callback(
            responses.POST,
            self.TOKEN_URL,
            callback=token_response_callback,
        )
        auth = self._legacy_auth()
        auth._get_token()
        assert "id_token" == auth._token

    @patch("descarteslabs.client.auth.Auth._get_token")
    def test_token(self, mock_get_token):
        """A token far from expiry is returned without a refresh."""
        auth = self._legacy_auth()
        token = self._fake_token({"exp": 9999999999})
        auth._token = token

        assert auth.token == token
        mock_get_token.assert_not_called()

    @patch("descarteslabs.client.auth.Auth._get_token")
    def test_token_expired(self, mock_get_token):
        """An expired token triggers exactly one refresh call."""
        auth = self._legacy_auth()
        token = self._fake_token({"exp": 0})
        auth._token = token

        assert auth.token == token
        mock_get_token.assert_called_once()

    @patch("descarteslabs.client.auth.Auth._get_token",
           side_effect=AuthError("error"))
    def test_token_expired_autherror(self, mock_get_token):
        """A failed refresh of an expired token propagates ``AuthError``."""
        auth = self._legacy_auth()
        token = self._fake_token({"exp": 0})
        auth._token = token

        with pytest.raises(AuthError):
            auth.token
        mock_get_token.assert_called_once()

    @patch("descarteslabs.client.auth.Auth._get_token",
           side_effect=AuthError("error"))
    def test_token_in_leeway_autherror(self, mock_get_token):
        """A token expiring within the leeway is still returned if refresh fails."""
        auth = self._legacy_auth()
        epoch = datetime.datetime(1970, 1, 1)
        seconds_now = (datetime.datetime.utcnow() - epoch).total_seconds()
        # Expire half a leeway from now.
        exp = seconds_now + auth.leeway / 2
        token = self._fake_token({"exp": exp})
        auth._token = token

        assert auth.token == token
        mock_get_token.assert_called_once()

    def test_auth_init_env_vars(self):
        """Explicit arguments beat namespaced env vars, which beat legacy ones."""
        warnings.simplefilter("ignore")

        environ = {
            "CLIENT_SECRET": "secret_bar",
            "CLIENT_ID": "id_bar",
            "DESCARTESLABS_CLIENT_SECRET": "secret_foo",
            "DESCARTESLABS_CLIENT_ID": "id_foo",
            "DESCARTESLABS_REFRESH_TOKEN": "refresh_foo",
        }

        # should work with direct var
        with patch.dict(self.ENVIRON_TARGET, environ):
            auth = Auth(
                client_id="client_id",
                client_secret="client_secret",
                refresh_token="client_secret",
                jwt_token="jwt_token",
            )
            assert auth.client_secret == "client_secret"
            assert auth.client_id == "client_id"

        # should work with namespaced env vars
        with patch.dict(self.ENVIRON_TARGET, environ):
            auth = Auth()
            # when refresh_token and client_secret do not match,
            # the Auth implementation sets both to the value of
            # refresh_token
            assert auth.client_secret == environ["DESCARTESLABS_REFRESH_TOKEN"]
            assert auth.client_id == environ["DESCARTESLABS_CLIENT_ID"]

        # remove the namespaced ones, except the refresh token because
        # Auth does not recognize a REFRESH_TOKEN environment variable
        # and removing it from the dictionary would result in non-deterministic
        # results based on the token_info.json file on the test runner disk
        del environ["DESCARTESLABS_CLIENT_SECRET"]
        del environ["DESCARTESLABS_CLIENT_ID"]

        # should fallback to legacy env vars
        with patch.dict(self.ENVIRON_TARGET, environ):
            auth = Auth()
            assert auth.client_secret == environ["DESCARTESLABS_REFRESH_TOKEN"]
            assert auth.client_id == environ["CLIENT_ID"]
class TestAuth(unittest.TestCase):
    """Tests for ``Auth`` construction, token retrieval and token refresh."""

    TOKEN_URL = "https://accounts.descarteslabs.com/token"
    LEGACY_CLIENT_ID = "ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c"
    ENVIRON_TARGET = "descarteslabs.client.auth.auth.os.environ"

    @staticmethod
    def _secret_auth():
        """Build an ``Auth`` with a plain client id and client secret."""
        return Auth(
            token_info_path=None,
            client_secret="client_secret",
            client_id="client_id",
        )

    def _legacy_auth(self):
        """Build an ``Auth`` that uses the legacy client id."""
        return Auth(
            token_info_path=None,
            client_secret="client_secret",
            client_id=self.LEGACY_CLIENT_ID,
        )

    @staticmethod
    def _fake_token(claims):
        """Join three base64-encoded segments into a token carrying ``claims``."""
        segments = ["header", json.dumps(claims), "sig"]
        return b".".join(base64.b64encode(to_bytes(s)) for s in segments)

    @responses.activate
    def test_get_token(self):
        """An ``access_token`` in the response becomes the cached token."""
        responses.add(
            responses.POST,
            self.TOKEN_URL,
            json={"access_token": "access_token"},
            status=200,
        )
        auth = self._secret_auth()
        auth._get_token()

        self.assertEqual("access_token", auth._token)

    @responses.activate
    def test_get_token_legacy(self):
        """An ``id_token`` is used when the response has no access token."""
        responses.add(
            responses.POST,
            self.TOKEN_URL,
            json={"id_token": "id_token"},
            status=200,
        )
        auth = self._secret_auth()
        auth._get_token()

        self.assertEqual("id_token", auth._token)

    @patch("descarteslabs.client.auth.Auth.payload", new={"sub": "asdf"})
    def test_get_namespace(self):
        """The namespace for the patched ``sub`` claim matches the expected value."""
        auth = self._secret_auth()
        self.assertEqual(auth.namespace,
                         "3da541559918a808c2402bba5012f6c60b27661c")

    def test_init_token_no_path(self):
        """A token passed to the constructor is stored as-is."""
        auth = Auth(jwt_token="token", token_info_path=None, client_id="foo")
        self.assertEqual("token", auth._token)

    @responses.activate
    def test_get_token_schema_internal_only(self):
        """The credential may be supplied as refresh token or as client secret."""
        responses.add_callback(
            responses.POST,
            self.TOKEN_URL,
            callback=token_response_callback,
        )
        for credential_name in ("refresh_token", "client_secret"):
            credentials = {credential_name: "refresh_token"}
            auth = Auth(
                token_info_path=None, client_id="client_id", **credentials
            )
            auth._get_token()

            self.assertEqual("access_token", auth._token)

    @responses.activate
    def test_get_token_schema_legacy_internal_only(self):
        """The legacy client id ends up with the ``id_token``."""
        responses.add_callback(
            responses.POST,
            self.TOKEN_URL,
            callback=token_response_callback,
        )
        auth = self._legacy_auth()
        auth._get_token()
        self.assertEqual("id_token", auth._token)

    @patch("descarteslabs.client.auth.Auth._get_token")
    def test_token(self, mock_get_token):
        """A token far from expiry is returned without a refresh."""
        auth = self._legacy_auth()
        token = self._fake_token({"exp": 9999999999})
        auth._token = token

        self.assertEqual(auth.token, token)
        mock_get_token.assert_not_called()

    @patch("descarteslabs.client.auth.Auth._get_token")
    def test_token_expired(self, mock_get_token):
        """An expired token triggers exactly one refresh call."""
        auth = self._legacy_auth()
        token = self._fake_token({"exp": 0})
        auth._token = token

        self.assertEqual(auth.token, token)
        mock_get_token.assert_called_once()

    @patch("descarteslabs.client.auth.Auth._get_token",
           side_effect=AuthError("error"))
    def test_token_expired_autherror(self, mock_get_token):
        """A failed refresh of an expired token propagates ``AuthError``."""
        auth = self._legacy_auth()
        token = self._fake_token({"exp": 0})
        auth._token = token

        with self.assertRaises(AuthError):
            auth.token
        mock_get_token.assert_called_once()

    @patch("descarteslabs.client.auth.Auth._get_token",
           side_effect=AuthError("error"))
    def test_token_in_leeway_autherror(self, mock_get_token):
        """A token expiring within the leeway is still returned if refresh fails."""
        auth = self._legacy_auth()
        epoch = datetime.datetime(1970, 1, 1)
        seconds_now = (datetime.datetime.utcnow() - epoch).total_seconds()
        # Expire half a leeway from now.
        exp = seconds_now + auth.leeway / 2
        token = self._fake_token({"exp": exp})
        auth._token = token

        self.assertEqual(auth.token, token)
        mock_get_token.assert_called_once()

    def test_auth_init_env_vars(self):
        """Explicit arguments beat namespaced env vars, which beat legacy ones."""
        environ = {
            "CLIENT_SECRET": "secret_bar",
            "CLIENT_ID": "id_bar",
            "DESCARTESLABS_CLIENT_SECRET": "secret_foo",
            "DESCARTESLABS_CLIENT_ID": "id_foo",
        }

        # should work with direct var
        with patch.dict(self.ENVIRON_TARGET, environ):
            auth = Auth(
                client_id="client_id",
                client_secret="client_secret",
                jwt_token="jwt_token",
            )
            self.assertEqual(auth.client_secret, "client_secret")
            self.assertEqual(auth.client_id, "client_id")

        # should work with namespaced env vars
        with patch.dict(self.ENVIRON_TARGET, environ):
            auth = Auth()
            self.assertEqual(auth.client_secret,
                             environ["DESCARTESLABS_CLIENT_SECRET"])
            self.assertEqual(auth.client_id,
                             environ["DESCARTESLABS_CLIENT_ID"])

        # remove the namespaced ones
        del environ["DESCARTESLABS_CLIENT_SECRET"]
        del environ["DESCARTESLABS_CLIENT_ID"]

        # should fallback to legacy env vars
        with patch.dict(self.ENVIRON_TARGET, environ):
            auth = Auth()
            self.assertEqual(auth.client_secret, environ["CLIENT_SECRET"])
            self.assertEqual(auth.client_id, environ["CLIENT_ID"])