コード例 #1
0
def connect_to_TDA():
    """Build a ``TDClient``, log it in, and return it.

    The client is configured from the module-level ``CONSUMER_KEY`` and
    ``REDIRECT_URI`` values, with ``tda_key.json`` as its credentials path.
    """
    session = TDClient(
        client_id=CONSUMER_KEY,
        redirect_uri=REDIRECT_URI,
        credentials_path='tda_key.json',
    )
    session.login()
    return session
コード例 #2
0
ファイル: ALLTD.py プロジェクト: JChaiTea/Chia
class tdAction(object):
    """Small wrapper around a logged-in ``TDClient`` for quotes and history."""

    def __init__(self):
        """Create the client from module-level settings and log it in."""
        self.conKey = CONSUMER_KEY
        self.redirect = REDIRECT_URL
        self.acct = ACCT_NUM
        self.tdClient = TDClient(client_id=self.conKey,
                                 redirect_uri=self.redirect)
        self.tdClient.login()

    def quote(self, tickerList):
        """Return the last price of every quoted ticker.

        Args:
            tickerList: Symbols passed to ``get_quotes`` as ``instruments``.

        Returns:
            dict mapping each key of the quote response to its ``lastPrice``.
            Entries whose ``lastPrice`` is missing or ``None`` are left out.
        """
        quotes = self.tdClient.get_quotes(instruments=tickerList)
        retDict = {}
        for key, entry in quotes.items():
            # Compare the value itself with None; ``type(value) is not None``
            # is always true and therefore never filtered anything.
            lastPrice = entry.get('lastPrice')
            if lastPrice is not None:
                retDict[key] = lastPrice
        return retDict

    def history(self, tickerList):
        """Return one DataFrame of closing prices per ticker.

        Args:
            tickerList: Symbols to request, one ``get_price_history`` call
                each, using ``period_type='year'``.

        Returns:
            dict mapping each ticker to a single-column DataFrame (the column
            is named after the ticker) holding the ``close`` of every candle
            in the response.
        """
        retDict = {}
        for item in tickerList:
            tickerHistory = self.tdClient.get_price_history(symbol=item,
                                                            period_type='year')
            closes = [candle['close'] for candle in tickerHistory['candles']]
            retDict[item] = DataFrame({item: closes})
        return retDict
コード例 #3
0
    def create_client(self) -> TDClient:
        """Return a logged-in ``TDClient`` built from this object's settings.

        The credentials path is obtained from ``__fetch_credential_file``
        before the client is constructed.
        """
        credentials_path = self.__fetch_credential_file()
        td_client = TDClient(
            client_id=self.td_client_id,
            redirect_uri=self.td_redirect_uri,
            credentials_path=credentials_path,
        )
        td_client.login()
        return td_client
コード例 #4
0
def initialize(client_id, redirect_uri, credentials_path):
    """Create and log in the module-wide ``client``, once.

    After logging in, a ``TDDailyBar`` instance is registered with
    ``TSTypeRegistry``.

    Raises:
        RuntimeError: if the global ``client`` is already truthy.
    """
    global client
    if client:
        raise RuntimeError("client已经初始化了")

    settings = {
        'client_id': client_id,
        'redirect_uri': redirect_uri,
        'credentials_path': credentials_path,
    }
    # The global is bound before login(), so it stays set even if login fails.
    client = TDClient(**settings)
    client.login()

    TSTypeRegistry.register(TDDailyBar())
コード例 #5
0
    def _create_session(self) -> TDClient:
        """Return a ``TDClient`` built from this object's settings, logged in."""
        session = TDClient(
            client_id=self.client_id,
            redirect_uri=self.redirect_uri,
            credentials_path=self.credentials_path,
        )
        session.login()
        return session
コード例 #6
0
def initialize(client_id, redirect_uri, credentials_path):
    """Create and log in the module-wide ``client``, once.

    Raises:
        RuntimeError: if the global ``client`` is already truthy.
    """
    global client
    if client:
        raise RuntimeError("client已经初始化了")

    # The global is bound before login(), so it stays set even if login fails.
    client = TDClient(client_id=client_id,
                      redirect_uri=redirect_uri,
                      credentials_path=credentials_path)
    client.login()
コード例 #7
0
  def create_client(force_refresh: bool = True, version: str = 'AWSCURRENT') -> TDClient:
    """Return a logged-in ``TDClient`` using a freshly fetched credentials file.

    A ``ClientFactory`` supplies the client id, the redirect URI and the
    credentials file. The file is requested at ``creds.json`` under ``/tmp/``,
    or under the parent of this module's directory when ``platform`` is
    ``'win32'``.

    ``force_refresh`` and ``version`` are forwarded unchanged to the
    factory's ``__fetch_credential_file``.
    """
    factory = ClientFactory()

    if platform == 'win32':
      base_path = path.join(path.dirname(__file__), '..')
    else:
      base_path = '/tmp/'
    outpath = path.join(base_path, 'creds.json')

    creds_file = factory.__fetch_credential_file(
      force_refresh=force_refresh,
      outpath=outpath,
      version=version)

    td_client = TDClient(client_id=factory.td_client_id,
                         redirect_uri=factory.td_redirect_uri,
                         credentials_path=creds_file)
    td_client.login()
    return td_client
コード例 #8
0
ファイル: robot.py プロジェクト: bdowe/python_trading_bot
    def _create_session(self) -> TDClient:
        """Open a TD Ameritrade session and log into it.

        The client is built from this object's ``client_id``,
        ``redirect_uri`` and ``credentials_path``.

        Returns:
        ----
        TDClient -- the client, after ``login()`` has been called on it.
        """
        session = TDClient(
            client_id=self.client_id,
            redirect_uri=self.redirect_uri,
            credentials_path=self.credentials_path,
        )
        session.login()
        return session
コード例 #9
0
def dt_signal_prices(candle_minutes, symbols):
    """Fetch recent intraday candles for ``symbols`` and add a 9-period SMA.

    The requested window runs from midnight of the previous calendar day
    (taken from the current US/Eastern date) up to now.

    Args:
        candle_minutes: Candle width in minutes; sent as the ``frequency``.
        symbols: Iterable of ticker symbols, one price-history request each.

    Returns:
        DataFrame with columns ``Symbol``, ``Date``, ``Open``, ``Close``,
        ``High``, ``Low`` and ``SMA_9`` (rolling 9-row mean of ``Close``
        computed within each symbol).
    """
    session = TDClient(client_id=config.client_id,
                       redirect_uri='http://localhost/test',
                       credentials_path='td_state.json')
    session.login()

    cur_day = datetime.datetime.now(tz=pytz.timezone('US/Eastern'))
    price_end_date = str(int(round(cur_day.timestamp() * 1000)))

    # Step back with a timedelta: ``cur_day.day - 1`` is 0 on the first day of
    # a month, which made ``datetime.datetime(...)`` raise ValueError.
    prev_day = cur_day.date() - datetime.timedelta(days=1)
    # Deliberately still a naive midnight (as before), so ``timestamp()``
    # interprets it in the machine's local time zone.
    price_start = datetime.datetime(prev_day.year, prev_day.month,
                                    prev_day.day)
    price_start_date = str(int(round(price_start.timestamp() * 1000)))

    candle_list = []

    for symbol in symbols:
        p_hist = session.get_price_history(symbol,
                                           period_type='day',
                                           frequency_type='minute',
                                           frequency=str(candle_minutes),
                                           end_date=price_end_date,
                                           start_date=price_start_date)

        # Candle timestamps are in milliseconds; convert to seconds first.
        candle_list.extend(
            [
                symbol,
                datetime.datetime.fromtimestamp(candle['datetime'] / 1000),
                candle['open'], candle['close'], candle['high'], candle['low']
            ]
            for candle in p_hist['candles']
        )

    df_dt = pd.DataFrame(
        candle_list,
        columns=['Symbol', 'Date', 'Open', 'Close', 'High', 'Low'])

    # Per-symbol moving average; dropping the group level of the index lets
    # the result align with the original rows.
    df_dt['SMA_9'] = df_dt.groupby('Symbol')['Close'].rolling(
        9).mean().reset_index(0, drop=True)

    return df_dt
コード例 #10
0
def get_symbols(watchlist):
    """Return the symbols contained in the watchlist named ``watchlist``.

    Args:
        watchlist: Name of the watchlist; compared as ``str(watchlist)``
            against each entry's ``name``.

    Returns:
        List of the ``instrument.symbol`` of every item in the matching
        watchlist. If several watchlists share the name, the last one in the
        response wins. An empty list is returned when nothing matches
        (previously this raised ``UnboundLocalError``).
    """
    session = TDClient(client_id=config.client_id,
                       redirect_uri='http://localhost/test',
                       credentials_path='td_state.json')
    session.login()

    response = session.get_watchlist_accounts()

    wanted = str(watchlist)
    symbols = []

    # Use the iterated entry directly; the former hand-maintained index fell
    # out of step after the first match and could select the wrong watchlist.
    for entry in response:
        if entry['name'] == wanted:
            symbols = [
                item['instrument']['symbol']
                for item in entry['watchlistItems']
            ]

    return symbols
コード例 #11
0
ファイル: tda.py プロジェクト: dr-natetorious/aws-homenet
    def grab_refresh_token(self, tda_creds: dict) -> dict:
        """Log in with the given credentials and return the cached token data.

        The credentials are written to ``creds.json`` in the system temp
        directory, a ``TDClient`` is pointed at that file, and once
        ``login()`` and ``grab_refresh_token()`` have succeeded the file is
        read back and parsed.

        Args:
            tda_creds: JSON-serialisable credentials written to the file.

        Returns:
            The parsed JSON content of the credentials file. Presumably this
            is the dict the client rewrote with fresh tokens; confirm against
            ``TDClient``. (The annotation used to claim ``TDClient``, which
            this method never returns.)

        Raises:
            ValueError: if ``login()`` or ``grab_refresh_token()`` returns a
                falsy value.
        """
        # Write the creds into a known location...
        # NOTE(review): the path is fixed and inside the shared temp directory,
        # so concurrent callers overwrite each other's file — confirm this is
        # acceptable where the code runs.
        output = path.join(gettempdir(), 'creds.json')
        with open(output, 'w') as f:
            f.write(dumps(tda_creds))

        # Fetch the offline token...
        client = TDClient(credentials_path=output,
                          client_id=self.client_id,
                          redirect_uri=self.redirect_uri)

        if not client.login():
            raise ValueError('Unable to login')
        if not client.grab_refresh_token():
            raise ValueError('Unable to grab_refresh_token')

        # Read the cached offline token...
        with open(output, 'r') as f:
            return loads(f.read())
def grab_candle_data(pull_from_td: bool) -> list[dict]:
    """Return candle data pulled from TD Ameritrade or loaded from disk.

    When pulling, every candle is extended with `range`, `symbol`,
    `datetime_iso` and `datetime_non_milli`, and the full list is written
    to `data/candles.json` so it can be reloaded later.

    ### Parameters
    ----------
    pull_from_td : bool
        If truthy, pull fresh candles from the TD
        Ameritrade API. Otherwise load the data
        from `data/candles.json`.

    ### Returns
    -------
    list[dict]
        The candle dictionaries that were just pulled,
        or the parsed content of the JSON file.

    ### Raises
    -------
    FileNotFoundError
        If `pull_from_td` is falsy and `data/candles.json`
        does not exist. (This case used to fail with an
        `UnboundLocalError`.)
    """

    candle_path = 'data/candles.json'

    if not pull_from_td:

        # Fail with a meaningful error instead of falling through to a
        # `return` of a variable that was never assigned.
        if not pathlib.Path(candle_path).exists():
            raise FileNotFoundError(
                f"No cached candle data found at '{candle_path}'.")

        # Load it from the JSON File.
        with open(file=candle_path, mode='r') as candle_file:
            return json.load(fp=candle_file)

    # Grab configuration values.
    config = ConfigParser()
    config.read('config/config.ini')

    # Read the Config File.
    CLIENT_ID = config.get('main', 'CLIENT_ID')
    REDIRECT_URI = config.get('main', 'REDIRECT_URI')
    JSON_PATH = config.get('main', 'JSON_PATH')
    ACCOUNT_NUMBER = config.get('main', 'ACCOUNT_NUMBER')

    # Create a new session
    TDSession = TDClient(client_id=CLIENT_ID,
                         redirect_uri=REDIRECT_URI,
                         credentials_path=JSON_PATH,
                         account_number=ACCOUNT_NUMBER)

    # Login to the session
    TDSession.login()

    # Initialize the list to store candles.
    all_candles = []

    # Loop through each Ticker.
    for ticker in ['AAPL', 'NIO', 'FIT', 'TSLA', 'MSFT', 'AMZN', 'IBM']:

        # Grab the price history (10 days of 1-minute bars).
        quotes = TDSession.get_price_history(symbol=ticker,
                                             period_type='day',
                                             period='10',
                                             frequency_type='minute',
                                             frequency=1,
                                             extended_hours=False)

        # Loop through each candle.
        for candle in quotes['candles']:

            # Calculate the Range.
            candle['range'] = round(candle['high'] - candle['low'], 5)

            # Add the Symbol.
            candle['symbol'] = quotes['symbol']

            # `datetime` is in milliseconds; convert to an ISO String.
            candle['datetime_iso'] = datetime.fromtimestamp(
                candle['datetime'] / 1000).isoformat()

            # Convert to a Timestamp in whole seconds.
            candle['datetime_non_milli'] = int(candle['datetime'] / 1000)

            all_candles.append(candle)

    # Save it to a JSON File.
    with open(file=candle_path, mode='w+') as candle_file:
        json.dump(obj=all_candles, fp=candle_file, indent=4)

    return all_candles
class TDSession(TestCase):
    """Unit tests for a `TDClient` session.

    Every test gets a fresh `TDClient` built in `setUp` from
    `config/config.ini`. When the module-level `SAVE_FLAG` is truthy, the
    responses are also written to `samples/responses/` as sample data.
    """
    def setUp(self) -> None:
        """Set up the `TDClient`."""

        # Grab configuration values.
        config = ConfigParser()
        config.read('config/config.ini')

        # Load the values.
        CLIENT_ID = config.get('main', 'CLIENT_ID')
        REDIRECT_URI = config.get('main', 'REDIRECT_URI')
        JSON_PATH = config.get('main', 'JSON_PATH')
        ACCOUNT_NUMBER = config.get('main', 'ACCOUNT_NUMBER')

        # Initialize the session. Note that `login()` is not called here.
        self.td_session = TDClient(client_id=CLIENT_ID,
                                   redirect_uri=REDIRECT_URI,
                                   credentials_path=JSON_PATH,
                                   account_number=ACCOUNT_NUMBER)

    def test_creates_instance_of_session(self):
        """Create an instance and make sure it's a `TDClient` object."""

        self.assertIsInstance(self.td_session, TDClient)

    def test_login(self):
        """Test whether the session is authenticated or not."""

        self.assertTrue(self.td_session.login())
        self.assertTrue(self.td_session.authstate)

    def test_state(self):
        """Make sure the state is updated."""

        # NOTE(review): this test never calls `login()` itself; presumably the
        # state is populated when the client is constructed (e.g. from the
        # credentials file) — confirm.
        self.assertIsNotNone(self.td_session.state['refresh_token'])
        self.assertIsNotNone(self.td_session.state['access_token'])

        self.assertNotEqual(self.td_session.state['refresh_token_expires_at'],
                            0)
        self.assertNotEqual(self.td_session.state['access_token_expires_at'],
                            0)

    def test_single_get_quotes(self):
        """Test Getting a Single quote."""

        # Grab a single quote.
        quotes = self.td_session.get_quotes(instruments=['MSFT'])

        # See if the Symbol is in the Quotes.
        self.assertIn('MSFT', quotes)

        # Save the data.
        if SAVE_FLAG:
            with open('samples/responses/sample_single_quotes.jsonc',
                      'w+') as data_file:
                json.dump(obj=quotes, fp=data_file, indent=3)

    def test_get_quotes(self):
        """Test Getting Multiple Quotes."""

        # Grab multiple Quotes.
        quotes = self.td_session.get_quotes(instruments=['MSFT', 'AAPL'])

        # Check that no unexpected symbols came back.
        # NOTE(review): `issuperset` also passes when the response is empty or
        # contains only one of the two symbols.
        self.assertTrue(set(['MSFT', 'AAPL']).issuperset(set(quotes.keys())))

        # Save the data.
        if SAVE_FLAG:
            with open('samples/responses/sample_multiple_quotes.jsonc',
                      'w+') as data_file:
                json.dump(obj=quotes, fp=data_file, indent=3)

    def test_get_accounts(self):
        """Test Get Accounts."""

        accounts = self.td_session.get_accounts(account='all',
                                                fields=['orders', 'positions'])

        # Only the first returned account is inspected.
        self.assertIn('positions', accounts[0]['securitiesAccount'])
        self.assertIn('currentBalances', accounts[0]['securitiesAccount'])
        # self.assertIn('orderStrategies', accounts[0]['securitiesAccount'])

        # Save the data.
        if SAVE_FLAG:
            with open('samples/responses/sample_accounts.jsonc',
                      'w+') as data_file:
                json.dump(obj=accounts, fp=data_file, indent=3)

    def test_create_stream_session(self):
        """Test Creating a new streaming session."""

        stream_session = self.td_session.create_streaming_session()
        self.assertIsInstance(stream_session, TDStreamerClient)

    def test_get_transactions(self):
        """Test getting transactions."""

        # `get_transactions` Endpoint. Should not return an error
        transaction_data_multi = self.td_session.get_transactions(
            account=self.td_session.account_number, transaction_type='ALL')

        # Make sure it's a list, and that its first entry has a `type`.
        self.assertIsInstance(transaction_data_multi, list)
        self.assertIn('type', transaction_data_multi[0])

        # Save the data.
        if SAVE_FLAG:
            with open('samples/responses/sample_transaction_data.jsonc',
                      'w+') as data_file:
                json.dump(obj=transaction_data_multi, fp=data_file, indent=3)

    def test_get_market_hours(self):
        """Test get market hours."""

        # `get_market_hours` Endpoint with multiple values
        market_hours_multi = self.td_session.get_market_hours(
            markets=['EQUITY', 'FOREX'], date=datetime.today().isoformat())

        # The inner key expected under `equity` differs by day:
        # `equity` on Saturday/Sunday, `EQ` on weekdays.
        if datetime.today().weekday() in (5, 6):

            # Make sure it's a dict.
            self.assertIsInstance(market_hours_multi, dict)
            self.assertIn('isOpen', market_hours_multi['equity']['equity'])

        else:
            # Make sure it's a dict.
            self.assertIsInstance(market_hours_multi, dict)
            self.assertIn('isOpen', market_hours_multi['equity']['EQ'])

        # Save the data.
        if SAVE_FLAG:
            with open('samples/responses/sample_market_hours.jsonc',
                      'w+') as data_file:
                json.dump(obj=market_hours_multi, fp=data_file, indent=3)

    def test_get_instrument(self):
        """Test getting an instrument."""

        # `get_instruments` Endpoint.
        get_instrument = self.td_session.get_instruments(cusip='594918104')

        # Make sure it's a list.
        self.assertIsInstance(get_instrument, list)
        self.assertIn('cusip', get_instrument[0])

        # Save the data.
        if SAVE_FLAG:
            with open('samples/responses/sample_instrument.jsonc',
                      'w+') as data_file:
                json.dump(obj=get_instrument, fp=data_file, indent=3)

    def test_chart_history(self):
        """Test getting historical prices."""

        # Periods to request, laid out as
        # frequency type -> period type -> list of periods.
        valid_values = {
            'minute': {
                'day': [1, 2, 3, 4, 5, 10]
            },
            'daily': {
                'month': [1, 2, 3, 6],
                'year': [1, 2, 3, 5, 10, 15, 20],
                'ytd': [1]
            },
            'weekly': {
                'month': [1, 2, 3, 6],
                'year': [1, 2, 3, 5, 10, 15, 20],
                'ytd': [1]
            },
            'monthly': {
                'year': [1, 2, 3, 5, 10, 15, 20]
            }
        }

        # Define the static arguments.
        hist_symbol = 'MSFT'
        hist_needExtendedHoursData = False

        for frequency_type in valid_values.keys():
            frequency_periods = valid_values[frequency_type]

            for frequency_period in frequency_periods.keys():
                possible_values = frequency_periods[frequency_period]

                for value in possible_values:

                    # Define the dynamic arguments for this combination; the
                    # frequency itself is always 1.
                    hist_periodType = frequency_period
                    hist_period = value
                    hist_frequencyType = frequency_type
                    hist_frequency = 1

                    # make the request
                    historical_prices = self.td_session.get_price_history(
                        symbol=hist_symbol,
                        period_type=hist_periodType,
                        period=hist_period,
                        frequency_type=hist_frequencyType,
                        frequency=hist_frequency,
                        extended_hours=hist_needExtendedHoursData)

                    self.assertIsInstance(historical_prices, dict)
                    self.assertFalse(historical_prices['empty'])

        # Save the data. Only the response of the last combination is written.
        if SAVE_FLAG:
            with open('samples/responses/sample_historical_prices.jsonc',
                      'w+') as data_file:
                json.dump(obj=historical_prices, fp=data_file, indent=3)

    def test_custom_historical_prices(self):
        """Test getting historical prices for a custom date range."""

        # Number of days to look back. (Original note: the max look back
        # period for minute data is 31 Days.)
        lookback_period = 10

        # Define the end of the range: now.
        today_00 = datetime.now()

        # Define the start of the range: `lookback_period` days ago.
        today_ago = datetime.now() - timedelta(days=lookback_period)

        # The TD API expects a timestamp in milliseconds. However, the timestamp() method only returns to seconds so multiply it by 1000.
        today_00 = str(int(round(today_00.timestamp() * 1000)))
        today_ago = str(int(round(today_ago.timestamp() * 1000)))

        # These values will now be our startDate and endDate parameters.
        hist_startDate = today_ago
        hist_endDate = today_00

        # Define the dynamic arguments.
        hist_periodType = 'day'
        hist_frequencyType = 'minute'
        hist_frequency = 1

        # Make the request
        historical_custom = self.td_session.get_price_history(
            symbol='MSFT',
            period_type=hist_periodType,
            frequency_type=hist_frequencyType,
            start_date=hist_startDate,
            end_date=hist_endDate,
            frequency=hist_frequency,
            extended_hours=True)

        self.assertIsInstance(historical_custom, dict)
        self.assertFalse(historical_custom['empty'])

        # Save the data.
        # NOTE(review): this is the same file `test_chart_history` writes, so
        # whichever test runs last overwrites the other's sample.
        if SAVE_FLAG:
            with open('samples/responses/sample_historical_prices.jsonc',
                      'w+') as data_file:
                json.dump(obj=historical_custom, fp=data_file, indent=3)

    def test_search_instruments(self):
        """Test Searching for Instruments."""

        # `search_instruments` Endpoint
        instrument_search_data = self.td_session.search_instruments(
            symbol='MSFT', projection='symbol-search')

        self.assertIsInstance(instrument_search_data, dict)
        self.assertIn('MSFT', instrument_search_data)

        # Save the data.
        if SAVE_FLAG:
            with open('samples/responses/sample_search_instrument.jsonc',
                      'w+') as data_file:
                json.dump(obj=instrument_search_data, fp=data_file, indent=3)

    def test_get_movers(self):
        """Test getting Market movers."""

        # `get_movers` Endpoint
        movers_data = self.td_session.get_movers(market='$DJI',
                                                 direction='up',
                                                 change='value')

        # On Saturday/Sunday an empty list is expected; on weekdays the first
        # mover must carry a `symbol`.
        if datetime.today().weekday() in (5, 6):
            self.assertIsInstance(movers_data, list)
            self.assertFalse(movers_data)
        else:
            self.assertIsInstance(movers_data, list)
            self.assertIn('symbol', movers_data[0])

        # Save the data.
        if SAVE_FLAG:
            with open('samples/responses/sample_movers.jsonc',
                      'w+') as data_file:
                json.dump(obj=movers_data, fp=data_file, indent=3)

    def test_get_user_preferences(self):
        """Test getting user preferences."""

        # `get_preferences` endpoint. Should not return an error
        preference_data = self.td_session.get_preferences(
            account=self.td_session.account_number)

        self.assertIsInstance(preference_data, dict)
        self.assertIn('expressTrading', preference_data)

        # Save the data.
        if SAVE_FLAG:
            with open('samples/responses/sample_account_preferences.jsonc',
                      'w+') as data_file:
                json.dump(obj=preference_data, fp=data_file, indent=3)

    def test_get_user_principals(self):
        """Test getting user principals."""

        # `get_user_principals` endpoint. Should not return an error
        user_principals = self.td_session.get_user_principals(
            fields=['preferences', 'surrogateIds'])

        self.assertIsInstance(user_principals, dict)
        self.assertIn('userId', user_principals)

        # Save the data.
        if SAVE_FLAG:
            with open('samples/responses/sample_user_principals.jsonc',
                      'w+') as data_file:
                json.dump(obj=user_principals, fp=data_file, indent=3)

    def test_get_streamer_keys(self):
        """Test getting streamer subscription keys."""

        # `get_streamer_subscription_keys` endpoint. Should not return an error
        streamer_keys = self.td_session.get_streamer_subscription_keys(
            accounts=[self.td_session.account_number])

        self.assertIsInstance(streamer_keys, dict)
        self.assertIn('keys', streamer_keys)

        # Save the data.
        if SAVE_FLAG:
            with open('samples/responses/sample_streamer_keys.jsonc',
                      'w+') as data_file:
                json.dump(obj=streamer_keys, fp=data_file, indent=3)

    def tearDown(self) -> None:
        """Drop the reference to the session."""

        self.td_session = None
コード例 #14
0
import dateutil.relativedelta

from td.client import TDClient
from td.enums import ORDER_SESSION, ORDER_TYPE
from util.options import processOptionTrade, orderedKeys
from config import ACCOUNT_ID, ACCOUNT_NUMBER, ACCOUNT_PASSWORD, CONSUMER_ID, REDIRECT_URI, SECRET_QUESTIONS

# Open a TD Ameritrade session and authenticate it.
sess = TDClient(
    account_number=ACCOUNT_NUMBER,
    account_password=ACCOUNT_PASSWORD,
    consumer_id=CONSUMER_ID,
    redirect_uri=REDIRECT_URI,
    secret_questions=SECRET_QUESTIONS,
)
sess.login()

# Meant to run once a day, shortly before midnight.
# The start date is exclusive and the end date inclusive, so for the daily
# job the range is yesterday -> today.
today = date.today()
delta = dateutil.relativedelta.relativedelta(days=1)
startDt = (today - delta).strftime("%Y-%m-%d")
endDt = today.strftime("%Y-%m-%d")

data = sess.get_transactions(
    account=ACCOUNT_ID,
    transaction_type='TRADE',
    start_date=startDt,
    end_date=endDt,
)

# Keep only the entries typed as trades.
trades = [entry for entry in data if entry['type'] == 'TRADE']

# Lazy, single-pass iterator over the trades whose instrument is an option.
optionTrades = filter(
    lambda entry: entry['transactionItem']['instrument']['assetType'] == 'OPTION',
    trades,
)
コード例 #15
0
class PyRobotPortfolioTest(TestCase):

    """Unit tests for the `Portfolio` object.

    Every test starts from a fresh, empty `Portfolio` and a logged-in
    `TDClient` created in `setUp`.
    """

    def setUp(self) -> None:
        """Set up the Portfolio and a logged-in TD client."""

        self.portfolio = Portfolio()

        # Show full diffs when a dict/tuple comparison fails.
        self.maxDiff = None

        # Grab configuration values.
        config = ConfigParser()
        config.read('configs/config.ini')       

        CLIENT_ID = config.get('main', 'CLIENT_ID')
        REDIRECT_URI = config.get('main', 'REDIRECT_URI')
        CREDENTIALS_PATH = config.get('main', 'JSON_PATH')
        self.ACCOUNT_NUMBER = config.get('main', 'ACCOUNT_NUMBER')

        self.td_client = TDClient(
            client_id=CLIENT_ID,
            redirect_uri=REDIRECT_URI,
            credentials_path=CREDENTIALS_PATH
        )

        # Runs before every test, including those that never use the client.
        self.td_client.login()

    def test_create_portofolio(self):
        """Make sure it's a Portfolio."""

        self.assertIsInstance(self.portfolio, Portfolio)

    def test_td_client_property(self):
        """Test the TD Client property."""

        # Should be None if it wasn't initialized from the PyRobot.
        self.assertIsNone(self.portfolio.td_client)
    
    def test_stock_frame_property(self):
        """Test the Stock Frame property."""

        # Should be None if it wasn't initialized from the PyRobot.
        self.assertIsNone(self.portfolio.stock_frame)        

    def test_historical_prices_property(self):
        """Test the Historical Prices property."""

        # Should be falsy if it wasn't initialized from the PyRobot.
        self.assertFalse(self.portfolio.historical_prices)  

    def test_add_position(self):
        """Test adding a single position to the portfolio."""

        new_position = self.portfolio.add_position(
            symbol='MSFT',
            asset_type='equity',
            quantity=10,
            purchase_price=3.00,
            purchase_date='2020-01-31'
        )

        # `add_position` is expected to echo the arguments back and to mark
        # the position as owned.
        correct_position = {
            'symbol': 'MSFT',
            'asset_type': 'equity',
            'ownership_status': True, 
            'quantity': 10, 
            'purchase_price': 3.00,
            'purchase_date': '2020-01-31'
        }

        self.assertDictEqual(new_position, correct_position)

    def test_add_position_default_arguments(self):
        """Test adding a single position to the portfolio, no date."""

        new_position = self.portfolio.add_position(
            symbol='MSFT',
            asset_type='equity'
        )

        # With only symbol and asset type given, the position is expected to
        # be unowned with zero quantity and price and no purchase date.
        correct_position = {
            'symbol': 'MSFT',
            'asset_type': 'equity',
            'ownership_status': False, 
            'quantity': 0, 
            'purchase_price': 0.00,
            'purchase_date': None
        }

        self.assertDictEqual(new_position, correct_position)

    def test_delete_existing_position(self):
        """Test deleting an existing position."""

        self.portfolio.add_position(
            symbol='MSFT',
            asset_type='equity',
            quantity=10,
            purchase_price=3.00,
            purchase_date='2020-01-31'
        )

        # `remove_position` is expected to return a (success, message) tuple.
        delete_status = self.portfolio.remove_position(symbol='MSFT')
        correct_status = (True, 'MSFT was successfully removed.')

        self.assertTupleEqual(delete_status, correct_status)

    def test_delete_non_existing_position(self):
        """Test deleting a non-existing position."""

        # The message text is compared verbatim, spelling included.
        delete_status = self.portfolio.remove_position(symbol='AAPL')
        correct_status = (False, 'AAPL did not exist in the porfolio.')

        self.assertTupleEqual(delete_status, correct_status)
    
    def test_in_portfolio_exisitng(self):
        """Checks to see if an existing position exists."""

        self.portfolio.add_position(
            symbol='MSFT',
            asset_type='equity',
            quantity=10,
            purchase_price=3.00,
            purchase_date='2020-01-31'
        )

        in_portfolio_flag = self.portfolio.in_portfolio(symbol='MSFT')
        self.assertTrue(in_portfolio_flag)
    
    def test_in_portfolio_non_exisitng(self):
        """Checks to see if a non-existing position exists."""

        in_portfolio_flag = self.portfolio.in_portfolio(symbol='AAPL')
        self.assertFalse(in_portfolio_flag)

    def test_is_profitable(self):
        """Checks to see if a position is profitable."""

        # Add a position bought at 3.00.
        self.portfolio.add_position(
            symbol='MSFT',
            asset_type='equity',
            quantity=10,
            purchase_price=3.00,
            purchase_date='2020-01-31'
        )

        # Test for being Profitable: current price above the purchase price.
        is_profitable = self.portfolio.is_profitable(
            symbol='MSFT',
            current_price=5.00
        
        )

        # Test for not being profitable: current price below the purchase price.
        is_not_profitable = self.portfolio.is_profitable(
            symbol='MSFT',
            current_price=1.00
        )
        
        self.assertTrue(is_profitable)
        self.assertFalse(is_not_profitable)

    def test_projected_market_value(self):
        """Tests the generation of a market value summary, for all of the positions."""

        # Add a position.
        self.portfolio.add_position(
            symbol='MSFT',
            asset_type='equity',
            quantity=10,
            purchase_price=3.00,
            purchase_date='2020-01-31'
        )

        # Expected summary for 10 shares bought at 3.0 and priced at 5.0:
        # one per-symbol entry plus a `total` roll-up.
        correct_dict = {
            'MSFT': {
                'current_price': 5.0,
                'is_profitable': True,
                'purchase_price': 3.0,
                'quantity': 10,
                'total_invested_capital': 30.0,
                'total_loss_or_gain_$': 20.0,
                'total_loss_or_gain_%': 0.6667,
                'total_market_value': 50.0
            },
            'total': {
                'number_of_breakeven_positions': 0,
                'number_of_non_profitable_positions': 0,
                'number_of_profitable_positions': 1,
                'total_invested_capital': 30.0,
                'total_market_value': 50.0,
                'total_positions': 1,
                'total_profit_or_loss': 20.0
            }
        }

        # Prices are passed in directly, so no request is made by this test.
        portfolio_summary = self.portfolio.projected_market_value(current_prices={'MSFT':{'lastPrice':5.0}})
        self.assertDictEqual(correct_dict, portfolio_summary)

    def test_grab_historical_prices(self):
        """Placeholder: no assertions have been written yet."""
        pass

    def test_portfolio_summary(self):
        """Tests the generation of a portfolio summary, for all of the positions."""

        # Add a position.
        self.portfolio.add_position(
            symbol='MSFT',
            asset_type='equity',
            quantity=10,
            purchase_price=3.00,
            purchase_date='2020-01-31'
        )

        # Attach the logged-in client created in `setUp`.
        self.portfolio.td_client = self.td_client

        # Despite the name, this is the collection of keys the summary must
        # contain; it is turned into a set below.
        correct_dict = [
            'projected_market_value',
            'portfolio_weights',
            'portfolio_risk'
        ]

        correct_dict = set(correct_dict)

        summary_dict = self.portfolio.portfolio_summary()

        # Only the presence of the keys is checked, not their values.
        self.assertTrue(correct_dict.issubset(summary_dict))    

    def test_ownership_status(self):
        """Tests getting and setting the ownership status."""

        # Add a position.
        self.portfolio.add_position(
            symbol='MSFT',
            asset_type='equity',
            quantity=10,
            purchase_price=3.00,
            purchase_date='2020-01-31'
        )

        # Should be True, since `purchase_date` was set.
        self.assertTrue(self.portfolio.get_ownership_status(symbol='MSFT'))

        # Reassign it.
        self.portfolio.set_ownership_status(symbol='MSFT', ownership=False)

        # Should be False.
        self.assertFalse(self.portfolio.get_ownership_status(symbol='MSFT'))


    def tearDown(self) -> None:
        """Teardown the Portfolio object."""

        # Only the portfolio reference is cleared; `td_client` is left as is.
        self.portfolio = None
コード例 #16
0
class TDAccount(AbstractAccount):
    """Account whose orders are routed through a TD Ameritrade ``TDClient``.

    Orders that were placed successfully are kept in ``td_orders`` (keyed by
    the broker's order id) and a background thread, started from
    ``__init__``, polls the broker for their status and executions until
    they reach a terminal state.
    """

    def valid_scope(self, scope: Scope):
        """Check that every code in *scope* refers to a known equity.

        Each code is split on ``_``; the first part is used as the symbol and
        the second as the symbol type.

        Raises:
            NotImplementedError: if the symbol type is not ``STK``.
            RuntimeError: if the instrument search does not return the symbol
                or the instrument's asset type is not ``EQUITY``.
        """
        for code in scope.codes:
            cc = code.split("_")
            symbol = cc[0]
            symbol_type = cc[1]
            if symbol_type != 'STK':
                raise NotImplementedError
            res = self.client.search_instruments(symbol, "symbol-search")
            if not res or symbol not in res:
                raise RuntimeError("没有查询到资产,code:" + code)
            if res[symbol]['assetType'] != 'EQUITY':
                raise RuntimeError("资产不是股票类型,暂不支持")

    def __init__(self, name: str, initial_cash: float):
        """Create the account and start the order-synchronisation thread.

        The broker client is not available yet; call ``with_client`` before
        placing or synchronising orders.
        """
        super().__init__(name, initial_cash)

        self.account_id = None
        self.client: TDClient = None
        # Broker order id -> TDOrder wrapper. Entries are added by
        # place_order and read by the sync thread.
        self.td_orders: Dict[str, TDOrder] = {}
        self.start_sync_order_thread()

    def with_client(self, client_id, redirect_uri, credentials_path,
                    account_id):
        """Create and log in the ``TDClient`` used for *account_id*."""
        self.account_id = account_id
        self.client = TDClient(client_id=client_id,
                               redirect_uri=redirect_uri,
                               credentials_path=credentials_path)
        self.client.login()

    @do_log(target_name='下单', escape_params=[EscapeParam(index=0, key='self')])
    @alarm(target='下单', escape_params=[EscapeParam(index=0, key='self')])
    @retry(limit=3)
    def place_order(self, order: Order):
        """Send *order* to the broker and register it for synchronisation.

        After the broker accepted the request the order is synchronised once.
        If it is still ``CREATED`` afterwards it is cancelled again; if it is
        ``CREATED`` or ``FAILED`` a ``RuntimeError`` is raised.

        Raises:
            Exception: whatever the client raises while placing the order
                (the order is marked ``FAILED`` first).
            RuntimeError: if the order did not leave the ``CREATED`` state or
                failed.
        """
        td_order = TDOrder(order, self.order_callback, self)
        try:
            resp = self.client.place_order(self.account_id, td_order.to_dict())
        except Exception:
            order.status = OrderStatus.FAILED
            # Bare raise keeps the original traceback intact.
            raise
        td_order_id = resp.get('order_id')
        order.td_order_id = td_order_id
        self.sync_order(td_order)
        if order.status == OrderStatus.CREATED:
            # cancel_open_order refuses orders that are not in td_orders, so
            # the order has to be registered before it can be cancelled;
            # otherwise the cancel never reaches the broker.
            self.td_orders[td_order_id] = td_order
            self.cancel_open_order(order)
            raise RuntimeError("place order error")
        if order.status == OrderStatus.FAILED:
            raise RuntimeError("place order error")
        self.td_orders[td_order_id] = td_order

    def start_sync_order_thread(self):
        """Start the thread that synchronises all open orders every 0.5s."""
        import time
        import traceback

        terminal_states = (OrderStatus.FAILED, OrderStatus.FILLED,
                           OrderStatus.CANCELED)

        def do_sync():
            while True:
                # Iterate over a snapshot: place_order may add entries from
                # another thread, and mutating a dict while iterating it
                # raises RuntimeError, which would end this thread.
                for td_order in list(self.td_orders.values()):
                    if td_order.framework_order.status in terminal_states:
                        # Orders in a terminal state need no further syncing.
                        continue
                    try:
                        self.sync_order(td_order)
                    except Exception:
                        # Best effort: log and keep polling the other orders.
                        logging.error("{}".format(traceback.format_exc()))
                time.sleep(0.5)

        threading.Thread(target=do_sync, name="sync td orders").start()

    @alarm(level=AlarmLevel.ERROR,
           target="同步订单",
           escape_params=[EscapeParam(index=0, key='self')])
    def sync_order(self, td_order: TDOrder):
        """Synchronise the status and the executions of *td_order*.

        The broker order is fetched, its status is mapped onto the framework
        order and, if the broker reports a larger filled quantity than the
        framework order knows about, a new aggregated execution is applied.

        Raises:
            RuntimeError: if *td_order* has no broker order id.
            NotImplementedError: for broker statuses, execution types or
                multi-leg executions that are not supported.
        """
        if not td_order.td_order_id:
            raise RuntimeError("非法的td订单")
        framework_order = td_order.framework_order
        if framework_order.status in [
                OrderStatus.FAILED, OrderStatus.CANCELED, OrderStatus.FILLED
        ]:
            logging.info("订单已经到终态,不需要同步")
            return
        o: dict = self.client.get_orders(self.account_id, td_order.td_order_id)

        # Update the status of the framework order. A missing status is
        # treated like an unknown one instead of failing with a TypeError.
        td_order_status: str = o.get('status') or ''
        if td_order_status in ['ACCEPTED', 'WORKING', 'QUEUED']:
            # Only the first sighting promotes the order; later polls of a
            # still-working order are a no-op (they used to fall through to
            # the "unsupported status" branch).
            if framework_order.status == OrderStatus.CREATED:
                framework_order.status = OrderStatus.SUBMITTED
        elif 'PENDING' in td_order_status:
            # Pending states do not change the framework status.
            pass
        elif td_order_status == "CANCELED":
            if framework_order.status != OrderStatus.CANCELED:
                framework_order.status = OrderStatus.CANCELED
                self.order_callback.order_status_change(framework_order, None)
        elif td_order_status == 'FILLED':
            # The FILLED status is set by order_filled below.
            pass
        else:
            raise NotImplementedError("无法处理的订单状态:" + td_order_status)

        # Synchronise the executions. TD executions carry no unique id, so
        # every sync replaces the previous aggregated execution with a new
        # version instead of appending individual fills.
        executions: List[Dict] = o.get("orderActivityCollection")
        if not executions:
            return

        filled_quantity = 0
        total_cost = 0
        # Aggregate all executions into one quantity and one total cost.
        for execution in executions:
            if execution.get("executionType") != 'FILL':
                raise NotImplementedError
            execution_legs: List[Dict] = execution.get("executionLegs") or []
            if len(execution_legs) > 1:
                raise NotImplementedError
            if execution_legs:
                execution_leg = execution_legs[0]
                single_filled_quantity = execution_leg.get("quantity")
                single_filled_price = execution_leg.get('price')
                total_cost += single_filled_price * single_filled_quantity
                filled_quantity += single_filled_quantity

        # Nothing filled (avoids dividing by zero) or nothing new to apply.
        if filled_quantity <= 0 or \
                filled_quantity <= framework_order.filled_quantity:
            return

        filled_avg_price = total_cost / filled_quantity
        if len(framework_order.execution_map) > 0:
            version = framework_order.execution_map.get("default").version + 1
        else:
            version = 1
        oe = OrderExecution("default", version, 0, filled_quantity,
                            filled_avg_price, None, None,
                            framework_order.direction, None)
        self.order_filled(framework_order, oe)

    def match(self, data):
        """Not supported by this account type."""
        raise NotImplementedError

    @do_log(target_name='取消订单',
            escape_params=[EscapeParam(index=0, key='self')])
    @alarm(target='取消订单', escape_params=[EscapeParam(index=0, key='self')])
    @retry(limit=3)
    def cancel_open_order(self, open_order: Order):
        """Cancel *open_order* at the broker and mark it ``CANCELED``.

        Orders that are already ``CANCELED`` are left untouched.

        Raises:
            RuntimeError: if the order has no broker order id or is not
                registered in ``td_orders``.
        """
        if not open_order.td_order_id or open_order.td_order_id not in self.td_orders:
            raise RuntimeError("没有td订单号")
        if open_order.status == OrderStatus.CANCELED:
            return
        td_order = self.td_orders[open_order.td_order_id]
        self.client.cancel_order(self.account_id, open_order.td_order_id)
        # Pick up fills that happened before the cancel took effect.
        self.sync_order(td_order)
        open_order.status = OrderStatus.CANCELED
        self.order_callback.order_status_change(open_order, self)

    def update_order(self, order: Order, reason):
        """Replace a limit order by cancelling it and placing the remainder.

        *reason* is accepted but not used by this implementation.

        Raises:
            RuntimeError: if the order has no broker order id or is not
                registered in ``td_orders``.
            NotImplementedError: if *order* is not a ``LimitOrder``.
        """
        if not order.td_order_id or (order.td_order_id not in self.td_orders):
            raise RuntimeError("没有td订单号")
        if not isinstance(order, LimitOrder):
            raise NotImplementedError
        self.cancel_open_order(order)
        new_order = LimitOrder(order.code, order.direction,
                               order.quantity - order.filled_quantity,
                               Timestamp.now(tz='Asia/Shanghai'),
                               order.limit_price, None)
        self.place_order(new_order)

    def start_save_thread(self):
        """Start the thread that saves the account data every half hour."""
        import time
        import traceback

        def save():
            while True:
                try:
                    logging.info("开始保存账户数据")
                    self.save()
                except Exception:
                    # Best effort: a failed save is logged and retried on
                    # the next cycle.
                    err_msg = "保存账户失败:{}".format(traceback.format_exc())
                    logging.error(err_msg)
                time.sleep(30 * 60)

        threading.Thread(name="account_save", target=save).start()
コード例 #17
0
ファイル: request.py プロジェクト: rrrudolph/old-and-lame
def request_data():
    """Backfill three months of daily candles for every entry in ``symbols``.

    For each symbol the price history is requested from TD Ameritrade,
    merged with the CSV already stored for that symbol (if any) and written
    back. Rows are de-duplicated on the candle timestamp, keeping the most
    recently downloaded values. Errors while processing one symbol are
    printed and do not stop the remaining symbols.
    """
    # NOTE(review): the client id and the file locations are hard-coded;
    # consider loading them from configuration instead of source code.
    td_client = TDClient(
        client_id='AZ2BZPRDVDNFBHUB5ADYDAPMD2CLG9RG',
        redirect_uri='https://localhost/first',
        credentials_path='C:/Users/R/Desktop/code/2nd-tda-api/td_state.json')

    # Login to a new session
    td_client.login()
    print('td client logged in')

    # Request parameters are the same for every symbol.
    ext_hours = True  # extended hours
    period_type = 'month'
    period = 3
    frequency_type = 'daily'
    frequency = 1

    # Backfill: fills in data from overnight or from whenever the program
    # was not running during market hours.
    for symbol in symbols:
        print(symbol)

        # The response is a dictionary with 3 items: candles (a list),
        # empty, and symbol.
        backfill = td_client.get_price_history(symbol=symbol,
                                               period_type=period_type,
                                               period=period,
                                               frequency_type=frequency_type,
                                               frequency=frequency,
                                               extended_hours=ext_hours)

        # Remove the symbol prefix if present (file names use the bare
        # symbol). startswith also copes with an empty symbol.
        if symbol.startswith(('$', '/')):
            symbol = symbol[1:]

        print('backfill requested')
        csv_path = 'C:/Users/R/Desktop/code/2nd-tda-api/data/{}_1d.csv'.format(
            symbol)
        try:
            # Shape the downloaded OHLC data.
            new_data = pd.DataFrame(backfill['candles'])
            new_data['datetime'] = pd.to_datetime(new_data['datetime'],
                                                  unit='ms',
                                                  origin='unix')
            new_data.rename(columns={"datetime": "dt"}, inplace=True)
            new_data = new_data[[
                'dt', 'open', 'high', 'low', 'close', 'volume'
            ]]

            # Load what was saved earlier. Opening the file with 'w+' (as
            # was done before) truncates it, so the stored history was lost
            # on every run; reading it first keeps it.
            try:
                old_data = pd.read_csv(csv_path,
                                       index_col=0,
                                       parse_dates=['dt'])
            except (FileNotFoundError, pd.errors.EmptyDataError):
                old_data = None

            if old_data is None:
                df = new_data
            else:
                df = pd.concat([old_data, new_data], ignore_index=True)

            # One row per timestamp; the freshly downloaded candle wins.
            df = df.drop_duplicates(subset='dt', keep='last')
            df = df.sort_values('dt').reset_index(drop=True)
            df = df.ffill()
            df.to_csv(csv_path, index=True)
            print('backfill saved!!!!')
        except Exception as e:
            # Bad data for one symbol must not abort the whole backfill.
            print(e)
コード例 #18
0
# Define the asset to be traded on the existing order leg. The asset type is
# given as an enum member; the symbol must always be a string.
new_order_leg.order_leg_asset(asset_type=ORDER_ASSET_TYPE.EQUITY,
                              symbol='MSFT')

# Once the order leg is built, attach it to the order object.
new_order.add_order_leg(order_leg=new_order_leg)

# Create a new session.
# NOTE(review): the state file path is specific to one Windows machine.
td_session = TDClient(
    account_number=ACCOUNT_NUMBER,
    account_password=ACCOUNT_PASSWORD,
    consumer_id=CONSUMER_ID,
    redirect_uri=REDIRECT_URI,
    json_path=r"C:\Users\Alex\OneDrive\Desktop\TDAmeritradeState.json")

# Log in before making any request with the session.
td_session.login()

# Placing the order is intentionally left commented out:
# td_session.place_order(account = '11111', order= new_order)

# Create a second order: a single LIMIT order at 12.00 for the NORMAL
# session that stays active until cancelled.
option_order = Order()
option_order.order_session(session='NORMAL')
option_order.order_duration(duration='GOOD_TILL_CANCEL')
option_order.order_type(order_type='LIMIT')
option_order.order_strategy_type(order_strategy_type='SINGLE')
option_order.order_price(price=12.00)

# Create the order leg for it and set its instruction to BUY_TO_OPEN.
option_order_leg = OrderLeg()
option_order_leg.order_leg_instruction(instruction='BUY_TO_OPEN')
コード例 #19
0
class AmeritradeRebalanceUtils:
    """Helpers to inspect and rebalance a single TD Ameritrade account.

    Call ``auth`` first; it stores the logged-in session, the first account
    returned by the API and that account's id on the instance.
    """

    def __init__(self):
        self.session = None
        self.account = None
        self.account_id = None

    def auth(self, credentials_path='./td_state.json', client_path='./td_client_auth.json'):
        """Log in and load the account (including its positions).

        Args:
            credentials_path: file passed to ``TDClient`` as its credentials
                path.
            client_path: JSON file providing ``client_id`` and
                ``callback_url``.

        Returns:
            The logged-in ``TDClient`` session.
        """
        with open(client_path) as f:
            data = json.load(f)
        self.session = TDClient(
            client_id=data['client_id'],
            redirect_uri=data['callback_url'],
            credentials_path=credentials_path
        )
        self.session.login()

        # assuming only 1 account under management
        self.account = self.session.get_accounts(fields=['positions'])[0]
        self.account_id = self.account['securitiesAccount']['accountId']
        return self.session

    def get_portfolio(self):
        """Return ``{symbol: long quantity}`` for the loaded account.

        An account payload without a ``positions`` entry yields an empty
        portfolio instead of a ``KeyError``.
        """
        positions = self.account['securitiesAccount'].get('positions', [])
        return {
            position['instrument']['symbol']: position['longQuantity']
            for position in positions
        }

    def place_orders_dry_run(self, portfolio_diff: dict):
        """Describe the orders needed to apply *portfolio_diff*.

        Args:
            portfolio_diff: ``{ticker: signed quantity}``; positive values
                are buys, all other values are sells.

        Returns:
            ``{ticker: {'instruction', 'qty', 'money_movement'}}`` where
            ``qty`` is the rounded absolute quantity and ``money_movement``
            is the signed cash effect at the last price (negative for buys).
        """
        if not portfolio_diff:
            # Nothing to do; do not send a quote request without symbols.
            return {}

        prices = self._get_last_prices(portfolio_diff)
        result = {}
        for ticker, qty in portfolio_diff.items():
            round_qty = round(qty)
            result[ticker] = {
                'instruction': 'BUY' if qty > 0 else 'SELL',
                'qty': abs(round_qty),
                'money_movement': round_qty * prices[ticker] * -1,
            }
        return result

    def place_orders(self, place_orders_dry_run: dict):
        """Place one market order per entry of a dry-run result.

        Entries whose quantity rounded down to zero are skipped, because an
        order for zero shares is not a meaningful request.

        Returns:
            The list of responses of the orders that were placed.
        """
        result = []
        for ticker, order in place_orders_dry_run.items():
            if not order['qty']:
                continue
            payload = self._get_market_order_payload(
                ticker, order['qty'], order['instruction'])
            res = self.session.place_order(account=self.account_id,
                                           order=payload)
            result.append(res)
        return result

    def _get_market_order_payload(self, ticker, quantity, instruction='BUY'):
        """Build the payload of a single-leg equity market order (DAY)."""
        return {
            "orderType": "MARKET",
            "session": "NORMAL",
            "duration": "DAY",
            "orderStrategyType": "SINGLE",
            "orderLegCollection": [
                {
                    "instruction": instruction,
                    "quantity": quantity,
                    "instrument": {
                        "symbol": ticker,
                        "assetType": "EQUITY"
                    }
                }
            ]
        }

    def _get_last_prices(self, portfolio: dict):
        """Return ``{ticker: last price}`` for every ticker in *portfolio*."""
        quotes = self.session.get_quotes(instruments=list(portfolio))
        return {ticker: quotes[ticker]['lastPrice'] for ticker in portfolio}
コード例 #20
0
# from new_tos import td_consumer_key, redirect_url, json_path
from td.client import TDClient
# Connection settings for the TD Ameritrade client.
# NOTE(review): the consumer key and the account number are hard-coded in
# source; consider loading them from a private config module (see the
# commented-out `new_tos` import) so they are not committed.
td_consumer_key = "HS7K2SZXYBG2HMOYU6JOMXWAWA2QRASG"
redirect_url = "https://localhost/test"
# Path handed to TDClient as `json_path`; specific to one Windows machine.
json_paths = 'C:/Users/Admin/github/Random_Projects/stock trading/td_state.json'

# Account number; not used by the statements below.
td_account = '490558627'

# Create the client from the settings above and log in.
td_client = TDClient(consumer_id=td_consumer_key,
                     redirect_uri=redirect_url,
                     json_path=json_paths)
td_client.login()
コード例 #21
0
import pprint
from datetime import datetime
from datetime import timedelta
from td.client import TDClient

# Create a new session. The three string arguments are placeholders and must
# be replaced with real values before running.
TDSession = TDClient(client_id='<CLIENT_ID>',
                     redirect_uri='<REDIRECT_URI>',
                     credentials_path='<CREDENTIALS_PATH>')

# Log in to the session before making any request.
TDSession.login()

# Define a list of all valid periods
valid_values = {
    'minute': {
        'day': [1, 2, 3, 4, 5, 10]
    },
    'daily': {
        'month': [1, 2, 3, 6],
        'year': [1, 2, 3, 5, 10, 15, 20],
        'ytd': [1]
    },
    'weekly': {
        'month': [1, 2, 3, 6],
        'year': [1, 2, 3, 5, 10, 15, 20],
        'ytd': [1]
    },
    'monthly': {
        'year': [1, 2, 3, 5, 10, 15, 20]
    }
コード例 #22
0
 def login(self):
     """Build a TDClient from CONSUMER_KEY, REDIRECT_URI and JSON_PATH, log it in and return it."""
     session = TDClient(
         client_id=CONSUMER_KEY,
         redirect_uri=REDIRECT_URI,
         credentials_path=JSON_PATH,
     )
     session.login()
     return session