예제 #1
0
def test_fetch_token_post():
    """Fetch a token via POST with and without client auth, then on error."""
    token_url = 'https://example.com/token'

    def check_form(request):
        # The authorization-code grant is expected to send these form fields.
        form = request.form
        expected = {
            'code': 'v',
            'client_id': 'foo',
            'grant_type': 'authorization_code',
        }
        for field, value in expected.items():
            assert form.get(field) == value

    ok_app = MockDispatch(default_token, assert_func=check_form)

    # Code extracted from a full authorization-response URL.
    with OAuth2Client('foo', app=ok_app) as client:
        fetched = client.fetch_token(
            token_url, authorization_response='https://i.b/?code=v')
        assert fetched == default_token

    # Code passed directly, with client authentication disabled.
    with OAuth2Client('foo', token_endpoint_auth_method='none',
                      app=ok_app) as client:
        fetched = client.fetch_token(token_url, code='v')
        assert fetched == default_token

    # An error payload from the token endpoint surfaces as OAuthError.
    error_app = MockDispatch({'error': 'invalid_request'})
    with OAuth2Client('foo', app=error_app) as client:
        with pytest.raises(OAuthError):
            client.fetch_token(token_url)
예제 #2
0
def test_auto_refresh_token2():
    """An expired client_credentials token is renewed on the next request;
    the update_token callback fires only when one was supplied."""

    def on_update(token, refresh_token=None, access_token=None):
        # The old token has no refresh_token; the test expects the callback
        # to receive the old access token instead.
        assert access_token == 'a'
        assert token == default_token

    update_token = mock.Mock(side_effect=on_update)

    # expires_at=100 is far in the past, so the token is already expired.
    expired = {'access_token': 'a', 'token_type': 'bearer', 'expires_at': 100}

    app = MockDispatch(default_token)
    common = dict(
        token=expired,
        token_endpoint='https://i.b/token',
        grant_type='client_credentials',
        app=app,
    )

    with OAuth2Client('foo', **common) as client:
        client.get('https://i.b/user')
        assert update_token.called is False

    with OAuth2Client('foo', update_token=update_token, **common) as client:
        client.get('https://i.b/user')
        assert update_token.called is True
예제 #3
0
def test_auto_refresh_token():
    """An expired token is refreshed with its refresh_token; without one,
    the request is expected to fail with OAuthError."""

    def on_update(token, refresh_token=None, access_token=None):
        assert refresh_token == 'b'
        assert token == default_token

    update_token = mock.Mock(side_effect=on_update)
    app = MockDispatch(default_token)
    endpoint = 'https://i.b/token'

    refreshable = {
        'access_token': 'a', 'refresh_token': 'b',
        'token_type': 'bearer', 'expires_at': 100,
    }
    with OAuth2Client('foo', token=refreshable, token_endpoint=endpoint,
                      update_token=update_token, app=app) as sess:
        sess.get('https://i.b/user')
        assert update_token.called is True

    # Same expired token, but nothing to refresh it with.
    unrefreshable = {
        'access_token': 'a',
        'token_type': 'bearer',
        'expires_at': 100,
    }
    with OAuth2Client('foo', token=unrefreshable, token_endpoint=endpoint,
                      update_token=update_token, app=app) as sess:
        with pytest.raises(OAuthError):
            sess.get('https://i.b/user')
예제 #4
0
    def test_fetch_token_post(self):
        """POST token request: authorization-response URL, bare code, error."""
        token_url = 'https://example.com/token'

        def check_body(request):
            # Grant parameters are form-encoded into the raw request body.
            payload = request.content.decode()
            for fragment in ('code=v', 'client_id=', 'grant_type=authorization_code'):
                self.assertIn(fragment, payload)

        ok_dispatch = MockDispatch(self.token, assert_func=check_body)
        with OAuth2Client(self.client_id, dispatch=ok_dispatch) as client:
            fetched = client.fetch_token(
                token_url, authorization_response='https://i.b/?code=v')
            self.assertEqual(fetched, self.token)

        with OAuth2Client(self.client_id, token_endpoint_auth_method='none',
                          dispatch=ok_dispatch) as client:
            fetched = client.fetch_token(token_url, code='v')
            self.assertEqual(fetched, self.token)

        error_dispatch = MockDispatch({'error': 'invalid_request'})
        with OAuth2Client(self.client_id, dispatch=error_dispatch) as client:
            with self.assertRaises(OAuthError):
                client.fetch_token(token_url)
예제 #5
0
def test_fetch_token_get():
    """A GET token request carries the grant parameters in the query string."""
    token_url = 'https://example.com/token'

    def check_query(request):
        requested = str(request.url)
        for fragment in ('code=v', 'client_id=', 'grant_type=authorization_code'):
            assert fragment in requested

    app = MockDispatch(default_token, assert_func=check_query)
    with OAuth2Client('foo', app=app) as client:
        fetched = client.fetch_token(
            token_url,
            authorization_response='https://i.b/?code=v',
            method='GET',
        )
        assert fetched == default_token

    with OAuth2Client('foo', token_endpoint_auth_method='none', app=app) as client:
        # A bare endpoint, then one that already carries a query string.
        for endpoint in (token_url, token_url + '?q=a'):
            fetched = client.fetch_token(endpoint, code='v', method='GET')
            assert fetched == default_token
예제 #6
0
    def test_fetch_token_get(self):
        """A GET token request carries the grant parameters in the URL query."""
        token_url = 'https://example.com/token'

        def check_query(request):
            requested = str(request.url)
            for fragment in ('code=v', 'client_id=', 'grant_type=authorization_code'):
                self.assertIn(fragment, requested)

        dispatch = MockDispatch(self.token, assert_func=check_query)
        with OAuth2Client(self.client_id, dispatch=dispatch) as client:
            fetched = client.fetch_token(
                token_url,
                authorization_response='https://i.b/?code=v',
                method='GET',
            )
            self.assertEqual(fetched, self.token)

        with OAuth2Client(self.client_id, token_endpoint_auth_method='none',
                          dispatch=dispatch) as client:
            # A token URL that already has a query string must also work.
            for endpoint in (token_url, token_url + '?q=a'):
                fetched = client.fetch_token(endpoint, code='v', method='GET')
                self.assertEqual(fetched, self.token)
예제 #7
0
    def test_auto_refresh_token2(self):
        """An expired client_credentials token is renewed on the next request;
        update_token is only invoked when it was passed to the client."""

        def on_update(token, refresh_token=None, access_token=None):
            self.assertEqual(access_token, 'a')
            self.assertEqual(token, self.token)

        update_token = mock.Mock(side_effect=on_update)

        # No refresh_token, and expires_at=100 is far in the past.
        expired = {'access_token': 'a', 'token_type': 'bearer', 'expires_at': 100}
        dispatch = MockDispatch(self.token)
        shared = {
            'token': expired,
            'token_endpoint': 'https://i.b/token',
            'grant_type': 'client_credentials',
            'dispatch': dispatch,
        }

        with OAuth2Client('foo', **shared) as sess:
            sess.get('https://i.b/user')
            self.assertFalse(update_token.called)

        with OAuth2Client('foo', update_token=update_token, **shared) as sess:
            sess.get('https://i.b/user')
            self.assertTrue(update_token.called)
예제 #8
0
    def test_revoke_token(self):
        """revoke_token returns the endpoint response and runs its compliance hook."""
        answer = {'status': 'ok'}
        dispatch = MockDispatch(answer)
        revoke_url = 'https://i.b/token'

        def passthrough(url, headers, data):
            # The hook must see the revocation endpoint URL unchanged.
            self.assertEqual(url, revoke_url)
            return url, headers, data

        revoke_token_request = mock.Mock(side_effect=passthrough)
        with OAuth2Client('a', dispatch=dispatch) as sess:
            # Without, then with, a token_type_hint.
            for extra in ({}, {'token_type_hint': 'access_token'}):
                resp = sess.revoke_token(revoke_url, 'hi', **extra)
                self.assertEqual(resp.json(), answer)

            sess.register_compliance_hook('revoke_token_request', revoke_token_request)
            sess.revoke_token(revoke_url, 'hi', body='', token_type_hint='access_token')
            self.assertTrue(revoke_token_request.called)
예제 #9
0
def test_create_authorization_url():
    """create_authorization_url adds the standard query parameters."""
    base_url = 'https://example.com/authorize?foo=bar'

    # Defaults: a generated state, the client id and response_type=code.
    client = OAuth2Client(client_id='foo')
    auth_url, state = client.create_authorization_url(base_url)
    for fragment in (state, 'client_id=foo', 'response_type=code'):
        assert fragment in auth_url

    # Explicit state/redirect/scope plus a client-level extra parameter.
    client = OAuth2Client(client_id='foo', prompt='none')
    auth_url, state = client.create_authorization_url(
        base_url, state='foo', redirect_uri='https://i.b', scope='profile')
    assert state == 'foo'
    for fragment in ('i.b', 'profile', 'prompt=none'):
        assert fragment in auth_url
예제 #10
0
    def test_create_authorization_url(self):
        """create_authorization_url adds the standard query parameters."""
        url = 'https://example.com/authorize?foo=bar'

        sess = OAuth2Client(client_id=self.client_id)
        auth_url, state = sess.create_authorization_url(url)
        self.assertIn(state, auth_url)
        # Match the whole ``client_id=`` parameter. A bare client id can be
        # satisfied by unrelated parts of the URL, e.g. the ``foo=bar`` query
        # already present in ``url`` when the client id is 'foo'.
        # NOTE(review): assumes self.client_id needs no URL escaping.
        self.assertIn('client_id=' + self.client_id, auth_url)
        self.assertIn('response_type=code', auth_url)

        sess = OAuth2Client(client_id=self.client_id, prompt='none')
        auth_url, state = sess.create_authorization_url(
            url, state='foo', redirect_uri='https://i.b', scope='profile')
        self.assertEqual(state, 'foo')
        self.assertIn('i.b', auth_url)
        self.assertIn('profile', auth_url)
        self.assertIn('prompt=none', auth_url)
예제 #11
0
def test_code_challenge():
    """With S256 configured, the authorization URL carries a PKCE challenge."""
    client = OAuth2Client('foo', code_challenge_method='S256')
    verifier = generate_token(48)
    auth_url, _ = client.create_authorization_url(
        'https://example.com/authorize', code_verifier=verifier)
    for fragment in ('code_challenge=', 'code_challenge_method=S256'):
        assert fragment in auth_url
예제 #12
0
    def test_code_challenge(self):
        """With S256 configured, the authorization URL carries a PKCE challenge."""
        sess = OAuth2Client(client_id=self.client_id, code_challenge_method='S256')

        url = 'https://example.com/authorize'
        auth_url, _ = sess.create_authorization_url(
            url, code_verifier=generate_token(48))
        # Require the ``code_challenge=`` parameter itself. A bare
        # 'code_challenge' substring is always satisfied by the
        # 'code_challenge_method=...' parameter checked below.
        self.assertIn('code_challenge=', auth_url)
        self.assertIn('code_challenge_method=S256', auth_url)
예제 #13
0
파일: auth.py 프로젝트: xpertdev/tda-api
    def ensure_refresh_token_update(
            self, api_key, session, update_interval_seconds=None):
        '''
        If the refresh token is older than update_interval_seconds, update it by
        issuing a call to the token refresh endpoint and return a new session
        wrapped around the resulting token. Returns None if the refresh token
        was not updated.

        :param api_key: Client ID used for the refresh request and passed as
                        the first argument of the returned session.
        :param session: Existing session. Its ``token['refresh_token']`` is
                        exchanged for a new token, and its class is used to
                        build the returned session.
        :param update_interval_seconds: Minimum age, in seconds, of the current
                        token before a refresh is attempted. Defaults to 85 days.
        :raises KeyError: if ``session.token`` has no ``refresh_token`` entry.
        '''
        logger = get_logger()

        if update_interval_seconds is None:
            # 85 days is less than the documented 90 day expiration window of
            # the token, but hopefully long enough to not trigger TDA's
            # thresholds for excessive refresh token updates.
            update_interval_seconds = 60 * 60 * 24 * 85

        now = int(time.time())

        logger.info((
            'Updating refresh token:\n'
            ' - Current timestamp is {}\n'
            ' - Token creation timestamp is {}\n'
            ' - Update interval is {} seconds').format(
                now, self.creation_timestamp, update_interval_seconds))

        # An unknown creation time is treated as "too old", so the token is
        # always refreshed in that case.
        if (self.creation_timestamp is not None
                and now - self.creation_timestamp <= update_interval_seconds):
            logger.info('Skipping refresh token update')
            return None

        old_token = session.token

        # Use the throwaway client as a context manager so it is closed once
        # the refresh request completes; previously it was never closed.
        with OAuth2Client(api_key) as oauth:
            new_token = oauth.fetch_token(
                TOKEN_ENDPOINT,
                grant_type='refresh_token',
                refresh_token=old_token['refresh_token'],
                access_type='offline')

        logger.info('Updated refresh token')

        # Set before obtaining the writer below. NOTE(review): presumably the
        # wrapped writer persists creation_timestamp alongside the token, so
        # this ordering matters -- confirm against wrapped_token_write_func.
        self.creation_timestamp = now

        # Don't emit token details in debug logs
        _register_token_redactions(new_token)

        token_write_func = self.wrapped_token_write_func()
        token_write_func(new_token)

        session_class = session.__class__
        return session_class(
            api_key,
            token=new_token,
            token_endpoint=TOKEN_ENDPOINT,
            update_token=token_write_func)
예제 #14
0
 def test_invalid_token_type(self):
     """Requests made with an unsupported token_type raise OAuthError."""
     token = dict(
         token_type='invalid',
         access_token='a',
         refresh_token='b',
         expires_in='3600',
         expires_at=int(time.time()) + 3600,
     )
     with OAuth2Client(self.client_id, token=token) as client:
         with self.assertRaises(OAuthError):
             client.get('https://i.b')
예제 #15
0
def test_add_token_to_streaming_request(assert_func, token_placement):
    """Streaming requests carry the token in the placement under test.

    NOTE(review): assert_func/token_placement are presumably supplied by a
    pytest parametrization or fixture defined elsewhere.
    """
    app = MockDispatch({'a': 'a'}, assert_func=assert_func)
    with OAuth2Client(
            'foo',
            token=default_token,
            token_placement=token_placement,
            app=app,
    ) as client, client.stream("GET", 'https://i.b') as stream:
        stream.read()
        payload = stream.json()
    assert payload['a'] == 'a'
예제 #16
0
def test_add_token_get_request(assert_func, token_placement):
    """A plain GET carries the token in the placement under test."""
    app = MockDispatch({'a': 'a'}, assert_func=assert_func)
    client_kwargs = {
        'token': default_token,
        'token_placement': token_placement,
        'app': app,
    }
    with OAuth2Client('foo', **client_kwargs) as client:
        resp = client.get('https://i.b')

    assert resp.json()['a'] == 'a'
예제 #17
0
def test_cleans_previous_token_before_fetching_new_one():
    """fetch_token must succeed even if the client's current token expired.

    The client is seeded with an already-expired token; fetching a new one
    should not fail because of that expiry.
    """
    now = int(time.time())
    # Work on copies: default_token is shared with the other tests in this
    # module, and mutating it would leak an expired token into them.
    expired_token = deepcopy(default_token)
    expired_token['expires_at'] = now - 7200
    new_token = deepcopy(default_token)
    new_token['expires_at'] = now + 3600
    url = 'https://example.com/token'

    app = MockDispatch(new_token)
    with mock.patch('time.time', lambda: now):
        with OAuth2Client('foo', token=expired_token, app=app) as sess:
            assert sess.fetch_token(url) == new_token
예제 #18
0
    def test_add_token_to_header(self):
        """By default the access token is sent as a Bearer Authorization header."""

        def check_header(request):
            expected = 'Bearer ' + self.token['access_token']
            self.assertEqual(request.headers.get('authorization'), expected)

        dispatch = MockDispatch({'a': 'a'}, assert_func=check_header)
        with OAuth2Client(self.client_id, token=self.token, dispatch=dispatch) as client:
            resp = client.get('https://i.b')

        self.assertEqual(resp.json()['a'], 'a')
예제 #19
0
    def test_add_token_to_uri(self):
        """With token_placement='uri' the access token appears in the URL."""

        def check_url(request):
            self.assertIn(self.token['access_token'], str(request.url))

        dispatch = MockDispatch({'a': 'a'}, assert_func=check_url)
        client_kwargs = {
            'token': self.token,
            'token_placement': 'uri',
            'dispatch': dispatch,
        }
        with OAuth2Client(self.client_id, **client_kwargs) as client:
            resp = client.get('https://i.b')

        self.assertEqual(resp.json()['a'], 'a')
예제 #20
0
def test_revoke_token():
    """revoke_token returns the endpoint's JSON, with or without a type hint."""
    answer = {'status': 'ok'}
    app = MockDispatch(answer)

    with OAuth2Client('a', app=app) as sess:
        for extra in ({}, {'token_type_hint': 'access_token'}):
            resp = sess.revoke_token('https://i.b/token', 'hi', **extra)
            assert resp.json() == answer
예제 #21
0
def test_client_credentials_type():
    """Without credentials or a code, fetch_token uses client_credentials."""
    token_url = 'https://example.com/token'

    def check_form(request):
        form = request.form
        assert form.get('scope') == 'profile'
        assert form.get('grant_type') == 'client_credentials'

    app = MockDispatch(default_token, assert_func=check_form)
    with OAuth2Client('foo', scope='profile', app=app) as sess:
        # Implicit grant type first, then the same grant requested explicitly.
        for extra in ({}, {'grant_type': 'client_credentials'}):
            assert sess.fetch_token(token_url, **extra) == default_token
예제 #22
0
    def test_client_credentials_type(self):
        """Without credentials or a code, fetch_token uses client_credentials."""
        token_url = 'https://example.com/token'

        def check_body(request):
            payload = request.content.decode()
            for fragment in ('scope=profile', 'grant_type=client_credentials'):
                self.assertIn(fragment, payload)

        dispatch = MockDispatch(self.token, assert_func=check_body)
        with OAuth2Client(self.client_id, scope='profile', dispatch=dispatch) as sess:
            # Implicit grant type first, then the same grant requested explicitly.
            for extra in ({}, {'grant_type': 'client_credentials'}):
                self.assertEqual(sess.fetch_token(token_url, **extra), self.token)
예제 #23
0
def test_add_token_to_uri():
    """With token_placement='uri' the access token appears in the request URL."""

    def check_url(request):
        assert default_token['access_token'] in str(request.url)

    app = MockDispatch({'a': 'a'}, assert_func=check_url)
    with OAuth2Client('foo', token=default_token, token_placement='uri',
                      app=app) as client:
        resp = client.get('https://i.b')

    payload = resp.json()
    assert payload['a'] == 'a'
예제 #24
0
    def test_access_token_response_hook(self):
        """A registered access_token_response hook sees the token response."""
        token_url = 'https://example.com/token'

        def inspect_response(resp):
            self.assertEqual(resp.json(), self.token)
            return resp

        access_token_response_hook = mock.Mock(side_effect=inspect_response)
        dispatch = MockDispatch(self.token)
        with OAuth2Client(self.client_id, token=self.token, dispatch=dispatch) as sess:
            sess.register_compliance_hook('access_token_response', access_token_response_hook)
            fetched = sess.fetch_token(token_url)
            self.assertEqual(fetched, self.token)
            self.assertTrue(access_token_response_hook.called)
예제 #25
0
def test_add_token_to_header():
    """By default the access token is sent as a Bearer Authorization header."""

    def check_header(request):
        expected = 'Bearer ' + default_token['access_token']
        assert request.headers.get('authorization') == expected

    app = MockDispatch({'a': 'a'}, assert_func=check_header)
    with OAuth2Client('foo', token=default_token, app=app) as client:
        resp = client.get('https://i.b')

    payload = resp.json()
    assert payload['a'] == 'a'
예제 #26
0
def test_access_token_response_hook():
    """A registered access_token_response hook is called with the response."""
    token_url = 'https://example.com/token'

    def inspect_response(resp):
        assert resp.json() == default_token
        return resp

    hook = mock.Mock(side_effect=inspect_response)
    app = MockDispatch(default_token)
    with OAuth2Client('foo', token=default_token, app=app) as sess:
        sess.register_compliance_hook('access_token_response', hook)
        assert sess.fetch_token(token_url) == default_token
        assert hook.called is True
예제 #27
0
def test_add_token_to_body():
    """With token_placement='body' the token is form-encoded into the body."""

    def check_body(request):
        body = request.data.decode()
        assert body == 'access_token=%s' % default_token['access_token']

    app = MockDispatch({'a': 'a'}, assert_func=check_body)
    client_kwargs = {
        'token': default_token,
        'token_placement': 'body',
        'app': app,
    }
    with OAuth2Client('foo', **client_kwargs) as client:
        resp = client.get('https://i.b')

    assert resp.json()['a'] == 'a'
예제 #28
0
    def test_password_grant_type(self):
        """fetch_token with username/password sends a password-grant request."""
        url = 'https://example.com/token'

        def assert_func(request):
            body = request.content.decode()
            self.assertIn('username=v', body)
            self.assertIn('scope=profile', body)
            self.assertIn('grant_type=password', body)

        dispatch = MockDispatch(self.token, assert_func=assert_func)
        with OAuth2Client(self.client_id, scope='profile', dispatch=dispatch) as sess:
            # The username must be 'v' to satisfy the 'username=v' check above;
            # the former '******' placeholder could never match it.
            token = sess.fetch_token(url, username='v', password='v')
            self.assertEqual(token, self.token)

            token = sess.fetch_token(
                url, username='v', password='v', grant_type='password')
            self.assertEqual(token, self.token)
예제 #29
0
    def test_cleans_previous_token_before_fetching_new_one(self):
        """Makes sure the previous token is cleaned before fetching a new one.
        The reason behind it is that, if the previous token is expired, this
        method shouldn't fail with a TokenExpiredError, since it's attempting
        to get a new one (which shouldn't be expired).
        """
        now = int(time.time())
        # Copy rather than mutate the self.token fixture, so the expired variant
        # cannot leak to other users of it (e.g. if it is a class attribute).
        expired_token = deepcopy(self.token)
        expired_token['expires_at'] = now - 7200
        new_token = deepcopy(self.token)
        new_token['expires_at'] = now + 3600
        url = 'https://example.com/token'

        dispatch = MockDispatch(new_token)
        with mock.patch('time.time', lambda: now):
            with OAuth2Client(self.client_id, token=expired_token, dispatch=dispatch) as sess:
                self.assertEqual(sess.fetch_token(url), new_token)
예제 #30
0
def test_password_grant_type():
    """fetch_token with username/password sends a password-grant request."""
    url = 'https://example.com/token'

    def assert_func(request):
        content = request.form
        assert content.get('username') == 'v'
        assert content.get('scope') == 'profile'
        assert content.get('grant_type') == 'password'

    app = MockDispatch(default_token, assert_func=assert_func)
    with OAuth2Client('foo', scope='profile', app=app) as sess:
        # The username must be 'v' to satisfy assert_func; the former '******'
        # placeholder made the username check fail.
        token = sess.fetch_token(url, username='v', password='v')
        assert token == default_token

        token = sess.fetch_token(
            url, username='v', password='v', grant_type='password')
        assert token == default_token