Example #1
0
def client_from_token_file(token_path, api_key):
    '''
    Returns a client built from a token previously saved at ``token_path``.
    The underlying OAuth2 session refreshes credentials as needed and writes
    each refreshed token back to the same path.

    :param token_path: Path to an existing token. Updated tokens will be written
                       to this path. If you do not yet have a token, use
                       :func:`~tda.auth.client_from_login_flow` or
                       :func:`~tda.auth.easy_client` to create one.
    :param api_key: Your TD Ameritrade application's API key, also known as the
                    client ID.
    '''
    # The token is stored pickled. pickle.load can execute arbitrary code, so
    # token_path must point at a file you trust.
    with open(token_path, 'rb') as token_file:
        token = pickle.load(token_file)

    # Keep token details out of debug logs.
    __register_token_redactions(token)

    refresh_url = 'https://api.tdameritrade.com/v1/oauth2/token'
    session = OAuth2Session(
        api_key,
        token=token,
        auto_refresh_url=refresh_url,
        auto_refresh_kwargs={'client_id': api_key},
        token_updater=__token_updater(token_path))
    return Client(api_key, session)
Example #2
0
def client_from_login_flow(webdriver,
                           api_key,
                           redirect_url,
                           token_path,
                           redirect_wait_time_seconds=0.1,
                           max_waits=None):
    '''Uses the webdriver to perform an OAuth webapp login flow and creates a
    client wrapped around the resulting token. The client will be configured to
    refresh the token as necessary, writing each updated version to
    ``token_path``.

    :param webdriver: `selenium <https://selenium-python.readthedocs.io>`__
                      webdriver which will be used to perform the login flow.
    :param api_key: Your TD Ameritrade application's API key, also known as the
                    client ID.
    :param redirect_url: Your TD Ameritrade application's redirect URL. Note
                         this must *exactly* match the value you've entered in
                         your application configuration, otherwise login will
                         fail with a security error.
    :param token_path: Path to which the new token will be written. Updated
                       tokens will be written to this path as well.
    :param redirect_wait_time_seconds: Seconds to sleep between checks of the
                                       webdriver's current URL while waiting
                                       for the post-login redirect.
    :param max_waits: Maximum number of sleeps to perform while waiting for the
                      redirect before giving up. ``None`` (the default) waits
                      indefinitely, matching the historical behavior.
    :raises TimeoutError: if ``max_waits`` is set and the webdriver has not
                          reached ``redirect_url`` after that many waits.
    '''
    token_endpoint = 'https://api.tdameritrade.com/v1/oauth2/token'

    oauth = OAuth2Session(api_key, redirect_uri=redirect_url)
    # The OAuth state value is not used further in this function.
    authorization_url, _ = oauth.authorization_url(
        'https://auth.tdameritrade.com/auth')

    # Open the login page and wait for the redirect
    print('Opening the login page in a webdriver. Please use this window to',
          'log in. Successful login will be detected automatically.')
    print('If you encounter any issues, see here for troubleshooting: ' +
          'https://tda-api.readthedocs.io/en/latest/auth.html' +
          '#troubleshooting')

    webdriver.get(authorization_url)

    # Poll the browser until it lands on the redirect URL; that full URL is then
    # handed to fetch_token as the authorization response. Without a bound the
    # loop would spin forever if the user never completes the login.
    num_waits = 0
    callback_url = webdriver.current_url
    while not callback_url.startswith(redirect_url):
        if max_waits is not None and num_waits >= max_waits:
            raise TimeoutError(
                'timed out waiting for redirect to {}'.format(redirect_url))
        time.sleep(redirect_wait_time_seconds)
        num_waits += 1
        callback_url = webdriver.current_url

    token = oauth.fetch_token(token_endpoint,
                              authorization_response=callback_url,
                              access_type='offline',
                              client_id=api_key,
                              include_client_id=True)

    # Record the token
    update_token = __token_updater(token_path)
    update_token(token)

    # Return a new session configured to refresh credentials
    return Client(
        api_key,
        OAuth2Session(
            api_key,
            token=token,
            auto_refresh_url=token_endpoint,
            auto_refresh_kwargs={'client_id': api_key},
            token_updater=update_token))
Example #3
0
def client_from_access_functions(api_key,
                                 token_read_func,
                                 token_write_func=None):
    '''
    Returns a client whose token is read and written through caller-supplied
    accessor functions instead of a token file. This is an advanced method for
    users who do not have access to a standard writable filesystem, such as
    users of AWS Lambda and other serverless products who must persist token
    updates on non-filesystem places, such as S3. 99.9% of users should not
    use this function.

    Users are free to customize how they represent the token file. In theory,
    since they have direct access to the token, they can get creative about how
    they store it and fetch it. In practice, it is *highly* recommended to
    simply accept the token object and use ``pickle`` to serialize and
    deserialize it, without inspecting it in any way.

    :param api_key: Your TD Ameritrade application's API key, also known as the
                    client ID.
    :param token_read_func: Function that takes no arguments and returns a token
                            object.
    :param token_write_func: Function that takes a token object and writes it.
                             Will be called whenever the token is updated, such
                             as when it is refreshed. Optional, but *highly*
                             recommended. Note old tokens become unusable on
                             refresh, so not setting this parameter risks
                             permanently losing refreshed tokens.
    '''
    token = token_read_func()

    # Don't emit token details in debug logs
    __register_token_redactions(token)

    api_key = __normalize_api_key(api_key)

    session_kwargs = {
        'token': token,
        'auto_refresh_url': 'https://api.tdameritrade.com/v1/oauth2/token',
        'auto_refresh_kwargs': {
            'client_id': api_key
        },
    }

    # Only install an updater when the caller supplied one; otherwise this
    # function does not persist refreshed tokens anywhere.
    if token_write_func is not None:
        session_kwargs['token_updater'] = token_write_func

    # Return a new client whose session is configured to refresh credentials
    return Client(api_key, OAuth2Session(api_key, **session_kwargs))
Example #4
0
def client_from_token_file(token_path, api_key):
    '''Returns a client wrapping an OAuth2 session loaded from ``token_path``.
    The session refreshes credentials as needed and saves each refreshed
    token back to the same path.

    :param token_path: Path to the token. Updated tokens will be written to this
                       path.
    :param api_key: Your TD Ameritrade application's API key, also known as the
                    client ID.
    '''
    # Deserialize the saved token. pickle.load can run arbitrary code, so only
    # point token_path at a file you trust.
    with open(token_path, 'rb') as handle:
        saved_token = pickle.load(handle)

    refresh_kwargs = {'client_id': api_key}
    session = OAuth2Session(
        api_key,
        token=saved_token,
        auto_refresh_url='https://api.tdameritrade.com/v1/oauth2/token',
        auto_refresh_kwargs=refresh_kwargs,
        token_updater=__token_updater(token_path))
    return Client(api_key, session)
Example #5
0
 def setUp(self):
     # Build the Client under test around a MagicMock session so each test
     # can inspect the calls made through it.
     session = MagicMock()
     self.mock_session = session
     self.client = Client(API_KEY, session)
Example #6
0
class TestClient(unittest.TestCase):
    def setUp(self):
        # Fresh mocked HTTP session per test; the Client under test sends its
        # requests through it so assertions can inspect them.
        session = MagicMock()
        self.mock_session = session
        self.client = Client(API_KEY, session)

    def make_url(self, path):
        # Fill the path's {placeholders} with the module-level test IDs and
        # prefix the API host, producing the URL the tests expect.
        ids = {
            'accountId': ACCOUNT_ID,
            'orderId': ORDER_ID,
            'savedOrderId': SAVED_ORDER_ID,
            'cusip': CUSIP,
            'market': MARKET,
            'index': INDEX,
            'symbol': SYMBOL,
            'transactionId': TRANSACTION_ID,
            'watchlistId': WATCHLIST_ID,
        }
        return 'https://api.tdameritrade.com' + path.format(**ids)

    # get_order

    def test_get_order(self):
        # A single order is fetched with a GET and no query parameters.
        self.client.get_order(ORDER_ID, ACCOUNT_ID)
        order_url = self.make_url('/v1/accounts/{accountId}/orders/{orderId}')
        self.mock_session.get.assert_called_once_with(order_url, params={})

    # cancel_order

    def test_cancel_order(self):
        # Cancelling is a DELETE on the order's own URL.
        self.client.cancel_order(ORDER_ID, ACCOUNT_ID)
        order_url = self.make_url('/v1/accounts/{accountId}/orders/{orderId}')
        self.mock_session.delete.assert_called_once_with(order_url)

    # get_orders_by_path

    @patch('datetime.datetime', mockdatetime)
    def test_get_orders_by_path_vanilla(self):
        # With no filters the entered-time window spans MIN_ISO..NOW_ISO.
        self.client.get_orders_by_path(ACCOUNT_ID)
        orders_url = self.make_url('/v1/accounts/{accountId}/orders')
        expected = dict(fromEnteredTime=MIN_ISO, toEnteredTime=NOW_ISO)
        self.mock_session.get.assert_called_once_with(orders_url,
                                                      params=expected)

    @patch('datetime.datetime', mockdatetime)
    def test_get_orders_by_path_max_results(self):
        # max_results is forwarded as the maxResults query parameter.
        self.client.get_orders_by_path(ACCOUNT_ID, max_results=100)
        orders_url = self.make_url('/v1/accounts/{accountId}/orders')
        expected = dict(fromEnteredTime=MIN_ISO,
                        toEnteredTime=NOW_ISO,
                        maxResults=100)
        self.mock_session.get.assert_called_once_with(orders_url,
                                                      params=expected)

    @patch('datetime.datetime', mockdatetime)
    def test_get_orders_by_path_from_entered_datetime(self):
        # An explicit start time replaces the MIN_ISO lower bound.
        self.client.get_orders_by_path(ACCOUNT_ID,
                                       from_entered_datetime=EARLIER_DATETIME)
        orders_url = self.make_url('/v1/accounts/{accountId}/orders')
        expected = dict(fromEnteredTime=EARLIER_ISO, toEnteredTime=NOW_ISO)
        self.mock_session.get.assert_called_once_with(orders_url,
                                                      params=expected)

    @patch('datetime.datetime', mockdatetime)
    def test_get_orders_by_path_to_entered_datetime(self):
        # An explicit end time replaces the NOW_ISO upper bound.
        self.client.get_orders_by_path(ACCOUNT_ID,
                                       to_entered_datetime=EARLIER_DATETIME)
        orders_url = self.make_url('/v1/accounts/{accountId}/orders')
        expected = dict(fromEnteredTime=MIN_ISO, toEnteredTime=EARLIER_ISO)
        self.mock_session.get.assert_called_once_with(orders_url,
                                                      params=expected)

    @patch('datetime.datetime', mockdatetime)
    def test_get_orders_by_path_status_and_statuses(self):
        # Supplying both status and statuses must be rejected.
        # status is passed as an enum member rather than the raw string
        # 'EXPIRED': enum enforcement appears to be on by default (the
        # *_unchecked tests disable it before passing strings), so a raw string
        # could raise its own ValueError and let this test pass without ever
        # reaching the status/statuses exclusivity check. The message is
        # asserted for the same reason.
        with self.assertRaises(ValueError) as cm:
            self.client.get_orders_by_path(
                ACCOUNT_ID,
                to_entered_datetime=EARLIER_DATETIME,
                status=Client.Order.Status.EXPIRED,
                statuses=[Client.Order.Status.FILLED])
        self.assertEqual(str(cm.exception),
                         'at most one of status or statuses may be set')

    @patch('datetime.datetime', mockdatetime)
    def test_get_orders_by_path_status(self):
        # A single Status enum is sent as its string value.
        self.client.get_orders_by_path(ACCOUNT_ID,
                                       status=Client.Order.Status.FILLED)
        orders_url = self.make_url('/v1/accounts/{accountId}/orders')
        expected = dict(fromEnteredTime=MIN_ISO,
                        toEnteredTime=NOW_ISO,
                        status='FILLED')
        self.mock_session.get.assert_called_once_with(orders_url,
                                                      params=expected)

    @patch('datetime.datetime', mockdatetime)
    def test_get_orders_by_path_status_unchecked(self):
        # With enum enforcement off, a plain string status passes through.
        self.client.set_enforce_enums(False)
        self.client.get_orders_by_path(ACCOUNT_ID, status='FILLED')
        orders_url = self.make_url('/v1/accounts/{accountId}/orders')
        expected = dict(fromEnteredTime=MIN_ISO,
                        toEnteredTime=NOW_ISO,
                        status='FILLED')
        self.mock_session.get.assert_called_once_with(orders_url,
                                                      params=expected)

    @patch('datetime.datetime', mockdatetime)
    def test_get_orders_by_path_statuses(self):
        # Multiple statuses are comma-joined into one status parameter.
        wanted = [Client.Order.Status.FILLED, Client.Order.Status.EXPIRED]
        self.client.get_orders_by_path(ACCOUNT_ID, statuses=wanted)
        orders_url = self.make_url('/v1/accounts/{accountId}/orders')
        expected = dict(fromEnteredTime=MIN_ISO,
                        toEnteredTime=NOW_ISO,
                        status='FILLED,EXPIRED')
        self.mock_session.get.assert_called_once_with(orders_url,
                                                      params=expected)

    @patch('datetime.datetime', mockdatetime)
    def test_get_orders_by_path_statuses_unchecked(self):
        # Plain-string statuses are joined the same way when unchecked.
        self.client.set_enforce_enums(False)
        self.client.get_orders_by_path(ACCOUNT_ID,
                                       statuses=['FILLED', 'EXPIRED'])
        orders_url = self.make_url('/v1/accounts/{accountId}/orders')
        expected = dict(fromEnteredTime=MIN_ISO,
                        toEnteredTime=NOW_ISO,
                        status='FILLED,EXPIRED')
        self.mock_session.get.assert_called_once_with(orders_url,
                                                      params=expected)

    # get_orders_by_query

    @patch('datetime.datetime', mockdatetime)
    def test_get_orders_by_query_vanilla(self):
        # Account-independent query with the default MIN_ISO..NOW_ISO window.
        self.client.get_orders_by_query()
        query_url = self.make_url('/v1/orders')
        expected = dict(fromEnteredTime=MIN_ISO, toEnteredTime=NOW_ISO)
        self.mock_session.get.assert_called_once_with(query_url,
                                                      params=expected)

    @patch('datetime.datetime', mockdatetime)
    def test_get_orders_by_query_max_results(self):
        # max_results becomes the maxResults query parameter.
        self.client.get_orders_by_query(max_results=100)
        query_url = self.make_url('/v1/orders')
        expected = dict(fromEnteredTime=MIN_ISO,
                        toEnteredTime=NOW_ISO,
                        maxResults=100)
        self.mock_session.get.assert_called_once_with(query_url,
                                                      params=expected)

    @patch('datetime.datetime', mockdatetime)
    def test_get_orders_by_query_from_entered_datetime(self):
        # An explicit start time replaces the MIN_ISO lower bound.
        self.client.get_orders_by_query(from_entered_datetime=EARLIER_DATETIME)
        query_url = self.make_url('/v1/orders')
        expected = dict(fromEnteredTime=EARLIER_ISO, toEnteredTime=NOW_ISO)
        self.mock_session.get.assert_called_once_with(query_url,
                                                      params=expected)

    @patch('datetime.datetime', mockdatetime)
    def test_get_orders_by_query_to_entered_datetime(self):
        # An explicit end time replaces the NOW_ISO upper bound.
        self.client.get_orders_by_query(to_entered_datetime=EARLIER_DATETIME)
        query_url = self.make_url('/v1/orders')
        expected = dict(fromEnteredTime=MIN_ISO, toEnteredTime=EARLIER_ISO)
        self.mock_session.get.assert_called_once_with(query_url,
                                                      params=expected)

    @patch('datetime.datetime', mockdatetime)
    def test_get_orders_by_query_status_and_statuses(self):
        # Supplying both status and statuses must be rejected.
        # assertRaises(msg=...) only sets the text shown when NO exception is
        # raised; it never checks the exception's message. Check it explicitly.
        # status is an enum member (not the raw string 'EXPIRED') so an
        # enum-enforcement ValueError cannot satisfy this test by accident.
        with self.assertRaises(ValueError) as cm:
            self.client.get_orders_by_query(
                to_entered_datetime=EARLIER_DATETIME,
                status=Client.Order.Status.EXPIRED,
                statuses=[
                    Client.Order.Status.FILLED, Client.Order.Status.EXPIRED
                ])
        self.assertEqual(str(cm.exception),
                         'at most one of status or statuses may be set')

    @patch('datetime.datetime', mockdatetime)
    def test_get_orders_by_query_status(self):
        # A single Status enum is sent as its string value.
        self.client.get_orders_by_query(status=Client.Order.Status.FILLED)
        query_url = self.make_url('/v1/orders')
        expected = dict(fromEnteredTime=MIN_ISO,
                        toEnteredTime=NOW_ISO,
                        status='FILLED')
        self.mock_session.get.assert_called_once_with(query_url,
                                                      params=expected)

    @patch('datetime.datetime', mockdatetime)
    def test_get_orders_by_query_status_unchecked(self):
        # With enum enforcement off, a plain string status passes through.
        self.client.set_enforce_enums(False)
        self.client.get_orders_by_query(status='FILLED')
        query_url = self.make_url('/v1/orders')
        expected = dict(fromEnteredTime=MIN_ISO,
                        toEnteredTime=NOW_ISO,
                        status='FILLED')
        self.mock_session.get.assert_called_once_with(query_url,
                                                      params=expected)

    @patch('datetime.datetime', mockdatetime)
    def test_get_orders_by_query_statuses(self):
        # Multiple statuses are comma-joined into one status parameter.
        wanted = [Client.Order.Status.FILLED, Client.Order.Status.EXPIRED]
        self.client.get_orders_by_query(statuses=wanted)
        query_url = self.make_url('/v1/orders')
        expected = dict(fromEnteredTime=MIN_ISO,
                        toEnteredTime=NOW_ISO,
                        status='FILLED,EXPIRED')
        self.mock_session.get.assert_called_once_with(query_url,
                                                      params=expected)

    @patch('datetime.datetime', mockdatetime)
    def test_get_orders_by_query_statuses_unchecked(self):
        # Plain-string statuses are joined the same way when unchecked.
        self.client.set_enforce_enums(False)
        self.client.get_orders_by_query(statuses=['FILLED', 'EXPIRED'])
        query_url = self.make_url('/v1/orders')
        expected = dict(fromEnteredTime=MIN_ISO,
                        toEnteredTime=NOW_ISO,
                        status='FILLED,EXPIRED')
        self.mock_session.get.assert_called_once_with(query_url,
                                                      params=expected)

    # place_order

    def test_place_order(self):
        # The order spec is POSTed verbatim as the JSON body.
        spec = {'order': 'spec'}
        self.client.place_order(ACCOUNT_ID, spec)
        orders_url = self.make_url('/v1/accounts/{accountId}/orders')
        self.mock_session.post.assert_called_once_with(orders_url, json=spec)

    # replace_order

    def test_replace_order(self):
        # Replacing PUTs the new spec to the existing order's URL.
        spec = {'order': 'spec'}
        self.client.replace_order(ACCOUNT_ID, ORDER_ID, spec)
        order_url = self.make_url('/v1/accounts/{accountId}/orders/{orderId}')
        self.mock_session.put.assert_called_once_with(order_url, json=spec)

    # create_saved_order

    def test_create_saved_order(self):
        # Saved orders are created by POSTing the spec to savedorders.
        spec = {'order': 'spec'}
        self.client.create_saved_order(ACCOUNT_ID, spec)
        saved_url = self.make_url('/v1/accounts/{accountId}/savedorders')
        self.mock_session.post.assert_called_once_with(saved_url, json=spec)

    # delete_saved_order

    def test_delete_saved_order(self):
        # Deleting a saved order is a DELETE on its own URL.
        self.client.delete_saved_order(ACCOUNT_ID, SAVED_ORDER_ID)
        saved_url = self.make_url(
            '/v1/accounts/{accountId}/savedorders/{savedOrderId}')
        self.mock_session.delete.assert_called_once_with(saved_url)

    # get_saved_order

    def test_get_saved_order(self):
        # A single saved order is fetched with a parameterless GET.
        self.client.get_saved_order(ACCOUNT_ID, SAVED_ORDER_ID)
        saved_url = self.make_url(
            '/v1/accounts/{accountId}/savedorders/{savedOrderId}')
        self.mock_session.get.assert_called_once_with(saved_url, params={})

    # get_saved_orders_by_path

    def test_get_saved_orders_by_path(self):
        # All saved orders for an account come from one parameterless GET.
        self.client.get_saved_orders_by_path(ACCOUNT_ID)
        saved_url = self.make_url('/v1/accounts/{accountId}/savedorders')
        self.mock_session.get.assert_called_once_with(saved_url, params={})

    # replace_saved_order

    def test_replace_saved_order(self):
        # Replacing a saved order PUTs the new spec to its URL.
        spec = {'order': 'spec'}
        self.client.replace_saved_order(ACCOUNT_ID, SAVED_ORDER_ID, spec)
        saved_url = self.make_url(
            '/v1/accounts/{accountId}/savedorders/{savedOrderId}')
        self.mock_session.put.assert_called_once_with(saved_url, json=spec)

    # get_account

    def test_get_account(self):
        # Without fields, the account is fetched with no query parameters.
        self.client.get_account(ACCOUNT_ID)
        account_url = self.make_url('/v1/accounts/{accountId}')
        self.mock_session.get.assert_called_once_with(account_url, params={})

    def test_get_account_fields(self):
        # Field enums are sent as a comma-joined 'fields' parameter.
        wanted = [
            Client.Account.Fields.POSITIONS, Client.Account.Fields.ORDERS
        ]
        self.client.get_account(ACCOUNT_ID, fields=wanted)
        account_url = self.make_url('/v1/accounts/{accountId}')
        self.mock_session.get.assert_called_once_with(
            account_url, params=dict(fields='positions,orders'))

    def test_get_account_fields_unchecked(self):
        # Plain-string fields are accepted once enforcement is disabled.
        self.client.set_enforce_enums(False)
        self.client.get_account(ACCOUNT_ID, fields=['positions', 'orders'])
        account_url = self.make_url('/v1/accounts/{accountId}')
        self.mock_session.get.assert_called_once_with(
            account_url, params=dict(fields='positions,orders'))

    # get_accounts

    def test_get_accounts(self):
        # Listing all accounts is a parameterless GET.
        self.client.get_accounts()
        accounts_url = self.make_url('/v1/accounts')
        self.mock_session.get.assert_called_once_with(accounts_url, params={})

    def test_get_accounts_fields(self):
        # Field enums are sent as a comma-joined 'fields' parameter.
        wanted = [
            Client.Account.Fields.POSITIONS, Client.Account.Fields.ORDERS
        ]
        self.client.get_accounts(fields=wanted)
        accounts_url = self.make_url('/v1/accounts')
        self.mock_session.get.assert_called_once_with(
            accounts_url, params=dict(fields='positions,orders'))

    def test_get_accounts_fields_unchecked(self):
        # Plain-string fields are accepted once enforcement is disabled.
        self.client.set_enforce_enums(False)
        self.client.get_accounts(fields=['positions', 'orders'])
        accounts_url = self.make_url('/v1/accounts')
        self.mock_session.get.assert_called_once_with(
            accounts_url, params=dict(fields='positions,orders'))

    # search_instruments

    def test_search_instruments(self):
        # Symbols are comma-joined and the projection enum sent by value.
        projection = Client.Instrument.Projection.FUNDAMENTAL
        self.client.search_instruments(['AAPL', 'MSFT'], projection)
        expected = dict(apikey=API_KEY,
                        symbol='AAPL,MSFT',
                        projection='fundamental')
        self.mock_session.get.assert_called_once_with(
            self.make_url('/v1/instruments'), params=expected)

    def test_search_instruments_unchecked(self):
        # A plain-string projection is accepted once enforcement is disabled.
        self.client.set_enforce_enums(False)
        self.client.search_instruments(['AAPL', 'MSFT'], 'fundamental')
        expected = dict(apikey=API_KEY,
                        symbol='AAPL,MSFT',
                        projection='fundamental')
        self.mock_session.get.assert_called_once_with(
            self.make_url('/v1/instruments'), params=expected)

    # get_instrument

    def test_get_instrument(self):
        # A CUSIP lookup sends only the API key as a parameter.
        self.client.get_instrument(CUSIP)
        instrument_url = self.make_url('/v1/instruments/{cusip}')
        self.mock_session.get.assert_called_once_with(
            instrument_url, params=dict(apikey=API_KEY))

    def test_get_instrument_cusip_must_be_string(self):
        # Non-string CUSIPs are rejected because ints would drop leading zeroes.
        # assertRaises(msg=...) only sets the failure text shown when NO
        # exception is raised; it does not check the exception's message, so
        # compare the message explicitly.
        msg = 'CUSIPs must be passed as strings to preserve leading zeroes'
        with self.assertRaises(ValueError) as cm:
            self.client.get_instrument(123456)
        self.assertEqual(str(cm.exception), msg)

    # get_hours_for_multiple_markets

    def test_get_hours_for_multiple_markets(self):
        # Market enums are comma-joined and the date sent in ISO form.
        markets = [Client.Markets.EQUITY, Client.Markets.BOND]
        self.client.get_hours_for_multiple_markets(markets, NOW_DATETIME)
        expected = dict(apikey=API_KEY, markets='EQUITY,BOND', date=NOW_ISO)
        self.mock_session.get.assert_called_once_with(
            self.make_url('/v1/marketdata/hours'), params=expected)

    def test_get_hours_for_multiple_markets_unchecked(self):
        # Plain-string markets are accepted once enforcement is disabled.
        self.client.set_enforce_enums(False)
        self.client.get_hours_for_multiple_markets(['EQUITY', 'BOND'],
                                                   NOW_DATETIME)
        expected = dict(apikey=API_KEY, markets='EQUITY,BOND', date=NOW_ISO)
        self.mock_session.get.assert_called_once_with(
            self.make_url('/v1/marketdata/hours'), params=expected)

    # get_hours_for_single_market

    def test_get_hours_for_single_market(self):
        # The market goes into the path; only key and date are parameters.
        self.client.get_hours_for_single_market(Client.Markets.EQUITY,
                                                NOW_DATETIME)
        hours_url = self.make_url('/v1/marketdata/{market}/hours')
        self.mock_session.get.assert_called_once_with(
            hours_url, params=dict(apikey=API_KEY, date=NOW_ISO))

    def test_get_hours_for_single_market_unchecked(self):
        # A plain-string market is accepted once enforcement is disabled.
        self.client.set_enforce_enums(False)
        self.client.get_hours_for_single_market('EQUITY', NOW_DATETIME)
        hours_url = self.make_url('/v1/marketdata/{market}/hours')
        self.mock_session.get.assert_called_once_with(
            hours_url, params=dict(apikey=API_KEY, date=NOW_ISO))

    # get_movers

    def test_get_movers(self):
        # Direction and change enums are sent by their string values.
        direction = Client.Movers.Direction.UP
        change = Client.Movers.Change.PERCENT
        self.client.get_movers(INDEX, direction, change)
        expected = dict(apikey=API_KEY, direction='up', change='percent')
        self.mock_session.get.assert_called_once_with(
            self.make_url('/v1/marketdata/{index}/movers'), params=expected)

    def test_get_movers_unchecked(self):
        # Plain strings are accepted once enforcement is disabled.
        self.client.set_enforce_enums(False)
        self.client.get_movers(INDEX, 'up', 'percent')
        expected = dict(apikey=API_KEY, direction='up', change='percent')
        self.mock_session.get.assert_called_once_with(
            self.make_url('/v1/marketdata/{index}/movers'), params=expected)

    # get_option_chain

    def test_get_option_chain_vanilla(self):
        # With no options only the API key and symbol are sent.
        self.client.get_option_chain('AAPL')
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL')
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    def test_get_option_chain_contract_type(self):
        # The contract type enum maps to contractType.
        put = Client.Options.ContractType.PUT
        self.client.get_option_chain('AAPL', contract_type=put)
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL', contractType='PUT')
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    def test_get_option_chain_contract_type_unchecked(self):
        # A plain-string contract type passes through when unchecked.
        self.client.set_enforce_enums(False)
        self.client.get_option_chain('AAPL', contract_type='PUT')
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL', contractType='PUT')
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    def test_get_option_chain_strike_count(self):
        # strike_count maps to strikeCount.
        self.client.get_option_chain('AAPL', strike_count=100)
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL', strikeCount=100)
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    def test_get_option_chain_include_quotes(self):
        # include_quotes maps to includeQuotes.
        self.client.get_option_chain('AAPL', include_quotes=True)
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL', includeQuotes=True)
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    def test_get_option_chain_strategy(self):
        # The strategy enum is sent by its string value.
        strangle = Client.Options.Strategy.STRANGLE
        self.client.get_option_chain('AAPL', strategy=strangle)
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL', strategy='STRANGLE')
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    def test_get_option_chain_strategy_unchecked(self):
        # A plain-string strategy passes through when unchecked.
        self.client.set_enforce_enums(False)
        self.client.get_option_chain('AAPL', strategy='STRANGLE')
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL', strategy='STRANGLE')
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    def test_get_option_chain_interval(self):
        # interval is forwarded unchanged.
        self.client.get_option_chain('AAPL', interval=10.0)
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL', interval=10.0)
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    def test_get_option_chain_strike(self):
        # strike is forwarded unchanged.
        self.client.get_option_chain('AAPL', strike=123)
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL', strike=123)
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    def test_get_option_chain_strike_range(self):
        # The strike-range enum is sent under the 'range' key.
        itm = Client.Options.StrikeRange.IN_THE_MONEY
        self.client.get_option_chain('AAPL', strike_range=itm)
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL', range='ITM')
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    def test_get_option_chain_strike_range_unchecked(self):
        # A plain-string strike range passes through when unchecked.
        self.client.set_enforce_enums(False)
        self.client.get_option_chain('AAPL', strike_range='ITM')
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL', range='ITM')
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    def test_get_option_chain_from_date(self):
        # strike_from_date is sent as fromDate in ISO form.
        self.client.get_option_chain('AAPL', strike_from_date=EARLIER_DATETIME)
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL', fromDate=EARLIER_ISO)
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    def test_get_option_chain_to_date(self):
        # strike_to_date is sent as toDate in ISO form.
        self.client.get_option_chain('AAPL', strike_to_date=NOW_DATETIME)
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL', toDate=NOW_ISO)
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    def test_get_option_chain_volatility(self):
        # volatility is forwarded unchanged.
        self.client.get_option_chain('AAPL', volatility=40.0)
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL', volatility=40.0)
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    def test_get_option_chain_underlying_price(self):
        # underlying_price maps to underlyingPrice.
        self.client.get_option_chain('AAPL', underlying_price=234.0)
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL', underlyingPrice=234.0)
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    def test_get_option_chain_interest_rate(self):
        # interest_rate maps to interestRate.
        self.client.get_option_chain('AAPL', interest_rate=0.07)
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL', interestRate=0.07)
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    def test_get_option_chain_days_to_expiration(self):
        # days_to_expiration maps to daysToExpiration.
        self.client.get_option_chain('AAPL', days_to_expiration=12)
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL', daysToExpiration=12)
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    def test_get_option_chain_exp_month(self):
        # exp_month maps to expMonth.
        self.client.get_option_chain('AAPL', exp_month='JAN')
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL', expMonth='JAN')
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    def test_get_option_chain_option_type(self):
        # The option-type enum is sent by its one-letter value.
        standard = Client.Options.Type.STANDARD
        self.client.get_option_chain('AAPL', option_type=standard)
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL', optionType='S')
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    def test_get_option_chain_option_type_unchecked(self):
        # A plain-string option type passes through when unchecked.
        self.client.set_enforce_enums(False)
        self.client.get_option_chain('AAPL', option_type='S')
        chains_url = self.make_url('/v1/marketdata/chains')
        expected = dict(apikey=API_KEY, symbol='AAPL', optionType='S')
        self.mock_session.get.assert_called_once_with(chains_url,
                                                      params=expected)

    # get_price_history

    def test_get_price_history_vanilla(self):
        '''With no optional arguments only the API key is sent.'''
        self.client.get_price_history(SYMBOL)

        url = self.make_url('/v1/marketdata/{symbol}/pricehistory')
        self.mock_session.get.assert_called_once_with(
            url, params={'apikey': API_KEY})

    def test_get_price_history_period_type(self):
        '''The MONTH period-type enum is sent as ``periodType='month'``.'''
        month = Client.PriceHistory.PeriodType.MONTH
        self.client.get_price_history(SYMBOL, period_type=month)

        url = self.make_url('/v1/marketdata/{symbol}/pricehistory')
        expected_params = {
            'apikey': API_KEY,
            'periodType': 'month',
        }
        self.mock_session.get.assert_called_once_with(
            url, params=expected_params)

    def test_get_price_history_period_type_unchecked(self):
        '''With enum enforcement off, a raw ``period_type`` string is sent
        unchanged as ``periodType``.'''
        self.client.set_enforce_enums(False)
        self.client.get_price_history(SYMBOL, period_type='month')

        url = self.make_url('/v1/marketdata/{symbol}/pricehistory')
        expected_params = {
            'apikey': API_KEY,
            'periodType': 'month',
        }
        self.mock_session.get.assert_called_once_with(
            url, params=expected_params)

    def test_get_price_history_num_periods(self):
        '''The TEN_DAYS period enum is sent as ``period=10``.'''
        ten_days = Client.PriceHistory.Period.TEN_DAYS
        self.client.get_price_history(SYMBOL, period=ten_days)

        url = self.make_url('/v1/marketdata/{symbol}/pricehistory')
        expected_params = {
            'apikey': API_KEY,
            'period': 10,
        }
        self.mock_session.get.assert_called_once_with(
            url, params=expected_params)

    def test_get_price_history_num_periods_unchecked(self):
        '''With enum enforcement off, a raw integer ``period`` is sent
        unchanged.'''
        self.client.set_enforce_enums(False)
        self.client.get_price_history(SYMBOL, period=10)

        url = self.make_url('/v1/marketdata/{symbol}/pricehistory')
        expected_params = {
            'apikey': API_KEY,
            'period': 10,
        }
        self.mock_session.get.assert_called_once_with(
            url, params=expected_params)

    def test_get_price_history_frequency_type(self):
        '''The DAILY frequency-type enum is sent as
        ``frequencyType='daily'``.'''
        daily = Client.PriceHistory.FrequencyType.DAILY
        self.client.get_price_history(SYMBOL, frequency_type=daily)

        url = self.make_url('/v1/marketdata/{symbol}/pricehistory')
        expected_params = {
            'apikey': API_KEY,
            'frequencyType': 'daily',
        }
        self.mock_session.get.assert_called_once_with(
            url, params=expected_params)

    def test_get_price_history_frequency_type_unchecked(self):
        '''With enum enforcement off, a raw ``frequency_type`` string is sent
        unchanged as ``frequencyType``.'''
        self.client.set_enforce_enums(False)
        self.client.get_price_history(SYMBOL, frequency_type='daily')

        url = self.make_url('/v1/marketdata/{symbol}/pricehistory')
        expected_params = {
            'apikey': API_KEY,
            'frequencyType': 'daily',
        }
        self.mock_session.get.assert_called_once_with(
            url, params=expected_params)

    def test_get_price_history_frequency(self):
        '''The EVERY_FIVE_MINUTES frequency enum is sent as
        ``frequency=5``.'''
        five_minutes = Client.PriceHistory.Frequency.EVERY_FIVE_MINUTES
        self.client.get_price_history(SYMBOL, frequency=five_minutes)

        url = self.make_url('/v1/marketdata/{symbol}/pricehistory')
        expected_params = {
            'apikey': API_KEY,
            'frequency': 5,
        }
        self.mock_session.get.assert_called_once_with(
            url, params=expected_params)

    def test_get_price_history_frequency_unchecked(self):
        '''With enum enforcement off, a raw integer ``frequency`` is sent
        unchanged.'''
        self.client.set_enforce_enums(False)
        self.client.get_price_history(SYMBOL, frequency=5)

        url = self.make_url('/v1/marketdata/{symbol}/pricehistory')
        expected_params = {
            'apikey': API_KEY,
            'frequency': 5,
        }
        self.mock_session.get.assert_called_once_with(
            url, params=expected_params)

    def test_get_price_history_start_datetime(self):
        '''``start_datetime`` is sent as ``startDate`` in the
        EARLIER_MILLIS encoding.'''
        self.client.get_price_history(SYMBOL, start_datetime=EARLIER_DATETIME)

        url = self.make_url('/v1/marketdata/{symbol}/pricehistory')
        expected_params = {
            'apikey': API_KEY,
            'startDate': EARLIER_MILLIS,
        }
        self.mock_session.get.assert_called_once_with(
            url, params=expected_params)

    def test_get_price_history_end_datetime(self):
        '''``end_datetime`` is sent as ``endDate`` in the EARLIER_MILLIS
        encoding.'''
        self.client.get_price_history(SYMBOL, end_datetime=EARLIER_DATETIME)

        url = self.make_url('/v1/marketdata/{symbol}/pricehistory')
        expected_params = {
            'apikey': API_KEY,
            'endDate': EARLIER_MILLIS,
        }
        self.mock_session.get.assert_called_once_with(
            url, params=expected_params)

    def test_get_price_history_need_extended_hours_data(self):
        '''``need_extended_hours_data`` is sent as
        ``needExtendedHoursData``.'''
        self.client.get_price_history(SYMBOL, need_extended_hours_data=True)

        url = self.make_url('/v1/marketdata/{symbol}/pricehistory')
        expected_params = {
            'apikey': API_KEY,
            'needExtendedHoursData': True,
        }
        self.mock_session.get.assert_called_once_with(
            url, params=expected_params)

    # get_quote

    def test_get_quote(self):
        '''A single-symbol quote GETs the per-symbol quotes URL with only the
        API key as a query param.'''
        self.client.get_quote(SYMBOL)

        url = self.make_url('/v1/marketdata/{symbol}/quotes')
        self.mock_session.get.assert_called_once_with(
            url, params={'apikey': API_KEY})

    # get_quotes

    def test_get_quotes(self):
        '''Multiple symbols are sent as one comma-joined ``symbol`` param.'''
        symbols = ['AAPL', 'MSFT']
        self.client.get_quotes(symbols)

        expected_params = {
            'apikey': API_KEY,
            'symbol': 'AAPL,MSFT',
        }
        self.mock_session.get.assert_called_once_with(
            self.make_url('/v1/marketdata/quotes'), params=expected_params)

    # get_transaction

    def test_get_transaction(self):
        '''Fetching one transaction GETs its account/transaction URL with
        only the API key as a query param.'''
        self.client.get_transaction(ACCOUNT_ID, TRANSACTION_ID)

        path = '/v1/accounts/{accountId}/transactions/{transactionId}'
        self.mock_session.get.assert_called_once_with(
            self.make_url(path), params={'apikey': API_KEY})

    # get_transactions

    def test_get_transactions(self):
        '''With no filters only the API key is sent.'''
        self.client.get_transactions(ACCOUNT_ID)

        url = self.make_url('/v1/accounts/{accountId}/transactions')
        self.mock_session.get.assert_called_once_with(
            url, params={'apikey': API_KEY})

    def test_get_transactions_type(self):
        '''The DIVIDEND transaction-type enum is sent as
        ``type='DIVIDEND'``.'''
        dividend = Client.Transactions.TransactionType.DIVIDEND
        self.client.get_transactions(ACCOUNT_ID, transaction_type=dividend)

        url = self.make_url('/v1/accounts/{accountId}/transactions')
        expected_params = {
            'apikey': API_KEY,
            'type': 'DIVIDEND',
        }
        self.mock_session.get.assert_called_once_with(
            url, params=expected_params)

    def test_get_transactions_type_unchecked(self):
        '''With enum enforcement off, a raw ``transaction_type`` string is
        sent unchanged as ``type``.'''
        self.client.set_enforce_enums(False)
        self.client.get_transactions(ACCOUNT_ID, transaction_type='DIVIDEND')

        url = self.make_url('/v1/accounts/{accountId}/transactions')
        expected_params = {
            'apikey': API_KEY,
            'type': 'DIVIDEND',
        }
        self.mock_session.get.assert_called_once_with(
            url, params=expected_params)

    def test_get_transactions_symbol(self):
        '''A ``symbol`` filter is sent verbatim.'''
        self.client.get_transactions(ACCOUNT_ID, symbol='AAPL')

        url = self.make_url('/v1/accounts/{accountId}/transactions')
        expected_params = {
            'apikey': API_KEY,
            'symbol': 'AAPL',
        }
        self.mock_session.get.assert_called_once_with(
            url, params=expected_params)

    def test_get_transactions_start_datetime(self):
        '''``start_datetime`` is sent as ``startDate`` in the
        EARLIER_DATE_STR encoding.'''
        self.client.get_transactions(
            ACCOUNT_ID, start_datetime=EARLIER_DATETIME)

        url = self.make_url('/v1/accounts/{accountId}/transactions')
        expected_params = {
            'apikey': API_KEY,
            'startDate': EARLIER_DATE_STR,
        }
        self.mock_session.get.assert_called_once_with(
            url, params=expected_params)

    def test_get_transactions_end_datetime(self):
        '''``end_datetime`` is sent as ``endDate`` in the EARLIER_DATE_STR
        encoding.'''
        self.client.get_transactions(
            ACCOUNT_ID, end_datetime=EARLIER_DATETIME)

        url = self.make_url('/v1/accounts/{accountId}/transactions')
        expected_params = {
            'apikey': API_KEY,
            'endDate': EARLIER_DATE_STR,
        }
        self.mock_session.get.assert_called_once_with(
            url, params=expected_params)

    # get_preferences

    def test_get_preferences(self):
        '''Fetching preferences GETs the account preferences URL with only
        the API key as a query param.'''
        self.client.get_preferences(ACCOUNT_ID)

        url = self.make_url('/v1/accounts/{accountId}/preferences')
        self.mock_session.get.assert_called_once_with(
            url, params={'apikey': API_KEY})

    # get_streamer_subscription_keys

    def test_get_streamer_subscription_keys(self):
        '''Account IDs are sent as one comma-joined ``accountIds`` param.'''
        account_ids = [1000, 2000, 3000]
        self.client.get_streamer_subscription_keys(account_ids)

        url = self.make_url('/v1/userprincipals/streamersubscriptionkeys')
        expected_params = {
            'apikey': API_KEY,
            'accountIds': '1000,2000,3000',
        }
        self.mock_session.get.assert_called_once_with(
            url, params=expected_params)

    # get_user_principals

    def test_get_user_principals_vanilla(self):
        '''With no ``fields`` only the API key is sent.'''
        self.client.get_user_principals()

        url = self.make_url('/v1/userprincipals')
        self.mock_session.get.assert_called_once_with(
            url, params={'apikey': API_KEY})

    def test_get_user_principals_fields(self):
        '''Field enums are comma-joined into the ``fields`` param, in the
        order given.'''
        field_enum = Client.UserPrincipals.Fields
        requested = [
            field_enum.STREAMER_SUBSCRIPTION_KEYS,
            field_enum.PREFERENCES,
        ]
        self.client.get_user_principals(fields=requested)

        expected_params = {
            'apikey': API_KEY,
            'fields': 'streamerSubscriptionKeys,preferences',
        }
        self.mock_session.get.assert_called_once_with(
            self.make_url('/v1/userprincipals'), params=expected_params)

    def test_get_user_principals_fields_unchecked(self):
        '''With enum enforcement off, raw field strings are comma-joined
        unchanged into ``fields``.'''
        self.client.set_enforce_enums(False)
        requested = ['streamerSubscriptionKeys', 'preferences']
        self.client.get_user_principals(fields=requested)

        expected_params = {
            'apikey': API_KEY,
            'fields': 'streamerSubscriptionKeys,preferences',
        }
        self.mock_session.get.assert_called_once_with(
            self.make_url('/v1/userprincipals'), params=expected_params)

    # update_preferences

    def test_update_preferences(self):
        '''update_preferences PUTs the given dict as the JSON body.'''
        prefs = {'wantMoney': True}
        self.client.update_preferences(ACCOUNT_ID, prefs)

        url = self.make_url('/v1/accounts/{accountId}/preferences')
        self.mock_session.put.assert_called_once_with(url, json=prefs)

    # create_watchlist

    def test_create_watchlist(self):
        '''create_watchlist POSTs the watchlist as the JSON body.'''
        spec = {'AAPL': True}
        self.client.create_watchlist(ACCOUNT_ID, spec)

        url = self.make_url('/v1/accounts/{accountId}/watchlists')
        self.mock_session.post.assert_called_once_with(url, json=spec)

    # delete_watchlist

    def test_delete_watchlist(self):
        '''delete_watchlist issues a DELETE against the watchlist's URL,
        with no body or query params.'''
        # No watchlist payload is needed: deletion is addressed purely by
        # account ID and watchlist ID.
        self.client.delete_watchlist(ACCOUNT_ID, WATCHLIST_ID)
        self.mock_session.delete.assert_called_once_with(
            self.make_url('/v1/accounts/{accountId}/watchlists/{watchlistId}'))

    # get_watchlist

    def test_get_watchlist(self):
        '''get_watchlist GETs the watchlist URL with an empty params dict.'''
        self.client.get_watchlist(ACCOUNT_ID, WATCHLIST_ID)

        path = '/v1/accounts/{accountId}/watchlists/{watchlistId}'
        self.mock_session.get.assert_called_once_with(
            self.make_url(path), params={})

    # get_watchlists_for_multiple_accounts

    def test_get_watchlists_for_multiple_accounts(self):
        '''The cross-account watchlist URL is fetched with empty params.'''
        self.client.get_watchlists_for_multiple_accounts()

        url = self.make_url('/v1/accounts/watchlists')
        self.mock_session.get.assert_called_once_with(url, params={})

    # get_watchlists_for_single_account

    def test_get_watchlists_for_single_account(self):
        '''One account's watchlists URL is fetched with empty params.'''
        self.client.get_watchlists_for_single_account(ACCOUNT_ID)

        url = self.make_url('/v1/accounts/{accountId}/watchlists')
        self.mock_session.get.assert_called_once_with(url, params={})

    # replace_watchlist

    def test_replace_watchlist(self):
        '''replace_watchlist PUTs the new watchlist as the JSON body.'''
        spec = {'AAPL': True}
        self.client.replace_watchlist(ACCOUNT_ID, WATCHLIST_ID, spec)

        path = '/v1/accounts/{accountId}/watchlists/{watchlistId}'
        self.mock_session.put.assert_called_once_with(
            self.make_url(path), json=spec)

    # update_watchlist

    def test_update_watchlist(self):
        '''update_watchlist PATCHes the watchlist as the JSON body.'''
        spec = {'AAPL': True}
        self.client.update_watchlist(ACCOUNT_ID, WATCHLIST_ID, spec)

        path = '/v1/accounts/{accountId}/watchlists/{watchlistId}'
        self.mock_session.patch.assert_called_once_with(
            self.make_url(path), json=spec)
Example #7
0
def client_from_login_flow(webdriver,
                           api_key,
                           redirect_url,
                           token_path,
                           redirect_wait_time_seconds=0.1,
                           max_waits=3000):
    '''
    Uses the webdriver to perform an OAuth webapp login flow and creates a
    client wrapped around the resulting token. The client will be configured to
    refresh the token as necessary, writing each updated version to
    ``token_path``.

    :param webdriver: `selenium <https://selenium-python.readthedocs.io>`__
                      webdriver which will be used to perform the login flow.
    :param api_key: Your TD Ameritrade application's API key, also known as the
                    client ID.
    :param redirect_url: Your TD Ameritrade application's redirect URL. Note
                         this must *exactly* match the value you've entered in
                         your application configuration, otherwise login will
                         fail with a security error.
    :param token_path: Path to which the new token will be written. If the token
                       file already exists, it will be overwritten with a new
                       one. Updated tokens will be written to this path as well.
    :param redirect_wait_time_seconds: Seconds to sleep between successive
                                       checks of the webdriver's current URL
                                       while waiting for the redirect.
    :param max_waits: Approximate number of sleeps to perform before giving
                      up. The total wait is therefore roughly
                      ``max_waits * redirect_wait_time_seconds`` seconds.
    :raises RedirectTimeoutError: if the browser does not reach the redirect
                                  URL before the wait budget is exhausted.
    '''
    get_logger().info(('Creating new token with redirect URL \'{}\' ' +
                       'and token path \'{}\'').format(redirect_url,
                                                       token_path))

    api_key = __normalize_api_key(api_key)

    oauth = OAuth2Session(api_key, redirect_uri=redirect_url)
    authorization_url, _ = oauth.authorization_url(
        'https://auth.tdameritrade.com/auth')

    # Open the login page and wait for the redirect
    print('\n**************************************************************\n')
    print('Opening the login page in a webdriver. Please use this window to',
          'log in. Successful login will be detected automatically.')
    print()
    print('If you encounter any issues, see here for troubleshooting: ' +
          'https://tda-api.readthedocs.io/en/stable/auth.html' +
          '#troubleshooting')
    print('\n**************************************************************\n')

    webdriver.get(authorization_url)

    # Tolerate redirects to HTTPS on the callback URL
    if redirect_url.startswith('http://'):
        print(
            ('WARNING: Your redirect URL ({}) will transmit data over HTTP, ' +
             'which is a potentially severe security vulnerability. ' +
             'Please go to your app\'s configuration with TDAmeritrade ' +
             'and update your redirect URL to begin with \'https\' ' +
             'to stop seeing this message.').format(redirect_url))

        # 'http://x'[4:] == '://x', so this builds the https:// variant.
        redirect_urls = (redirect_url, 'https' + redirect_url[4:])
    else:
        redirect_urls = (redirect_url, )

    # Poll until the browser lands on one of the callback URLs. The freshly
    # read URL is tested before the timeout check, so a redirect observed on
    # the final poll is accepted rather than reported as a timeout, and no
    # extra sleep happens once the redirect has been seen.
    num_waits = 0
    while True:
        current_url = webdriver.current_url
        if current_url.startswith(redirect_urls):
            break

        if num_waits > max_waits:
            raise RedirectTimeoutError('timed out waiting for redirect')
        time.sleep(redirect_wait_time_seconds)
        num_waits += 1

    token = oauth.fetch_token('https://api.tdameritrade.com/v1/oauth2/token',
                              authorization_response=current_url,
                              access_type='offline',
                              client_id=api_key,
                              include_client_id=True)

    # Don't emit token details in debug logs
    __register_token_redactions(token)

    # Record the token
    update_token = __token_updater(token_path)
    update_token(token)

    # Return a new session configured to refresh credentials
    return Client(
        api_key,
        OAuth2Session(
            api_key,
            token=token,
            auto_refresh_url='https://api.tdameritrade.com/v1/oauth2/token',
            auto_refresh_kwargs={'client_id': api_key},
            token_updater=update_token))