def connect_to_TDA():
    """Create a TD Ameritrade client and log it in.

    The client is configured from the module constants ``CONSUMER_KEY`` and
    ``REDIRECT_URI`` and keeps its token state in ``tda_key.json``.

    Returns:
        The ``TDClient`` instance after ``login()`` has been called.
    """
    session_kwargs = {
        'client_id': CONSUMER_KEY,
        'redirect_uri': REDIRECT_URI,
        'credentials_path': 'tda_key.json',
    }
    session = TDClient(**session_kwargs)
    session.login()
    return session
class tdAction(object):
    """Thin wrapper around ``TDClient`` for fetching quotes and price history.

    Connection settings come from the module-level constants ``CONSUMER_KEY``,
    ``REDIRECT_URL`` and ``ACCT_NUM`` (defined elsewhere in this module).
    """

    def __init__(self):
        """Create the underlying ``TDClient`` and log it in."""
        self.conKey = CONSUMER_KEY
        # NOTE(review): other helpers in this codebase use ``REDIRECT_URI``;
        # confirm ``REDIRECT_URL`` is really the constant defined here.
        self.redirect = REDIRECT_URL
        self.acct = ACCT_NUM
        self.tdClient = TDClient(client_id=self.conKey, redirect_uri=self.redirect)
        self.tdClient.login()

    def quote(self, tickerList):
        """Return ``{symbol: lastPrice}`` for the requested tickers.

        Args:
            tickerList: symbols passed to ``TDClient.get_quotes``.

        Returns:
            dict mapping each returned symbol to its ``lastPrice``; symbols
            whose ``lastPrice`` is ``None`` are left out.
        """
        quotes = self.tdClient.get_quotes(instruments=tickerList)
        # The former check ``type(x) is not None`` was always true (a type
        # object is never None), so None prices leaked through. Test the value.
        return {
            symbol: data['lastPrice']
            for symbol, data in quotes.items()
            if data['lastPrice'] is not None
        }

    def history(self, tickerList):
        """Return closing-price history for each ticker.

        Prices are requested with ``get_price_history(period_type='year')``.

        Args:
            tickerList: symbols to request.

        Returns:
            dict mapping each symbol to a ``DataFrame`` with a single column
            (named after the symbol) holding the ``close`` of every candle.
        """
        retDict = {}
        for item in tickerList:
            tickerHistory = self.tdClient.get_price_history(symbol=item, period_type='year')
            closes = [candle['close'] for candle in tickerHistory['candles']]
            retDict[item] = DataFrame({item: closes})
        return retDict
def create_client(self) -> TDClient:
    """Build a ``TDClient`` from this object's TD settings and log it in.

    The credentials-file path is obtained from the private
    ``__fetch_credential_file`` helper before the client is constructed.

    Returns:
        TDClient: the client after ``login()`` has run.
    """
    credentials_path = self.__fetch_credential_file()
    td_client = TDClient(
        client_id=self.td_client_id,
        redirect_uri=self.td_redirect_uri,
        credentials_path=credentials_path,
    )
    td_client.login()
    return td_client
def initialize(client_id, redirect_uri, credentials_path):
    """Create and log in the module-level ``client``, then register ``TDDailyBar``.

    Args:
        client_id: TD application client id.
        redirect_uri: OAuth redirect URI.
        credentials_path: path of the token-state file.

    Raises:
        RuntimeError: if the global ``client`` is already set (truthy).
    """
    global client
    if client:
        # Message means "client has already been initialized".
        raise RuntimeError("client已经初始化了")
    client = TDClient(**{
        'client_id': client_id,
        'redirect_uri': redirect_uri,
        'credentials_path': credentials_path,
    })
    client.login()
    TSTypeRegistry.register(TDDailyBar())
def _create_session(self) -> TDClient:
    """Open and authenticate a new TD Ameritrade session.

    Returns:
        TDClient: a client built from ``self.client_id``, ``self.redirect_uri``
        and ``self.credentials_path`` on which ``login()`` has been called.
    """
    settings = dict(
        client_id=self.client_id,
        redirect_uri=self.redirect_uri,
        credentials_path=self.credentials_path,
    )
    session = TDClient(**settings)
    session.login()
    return session
def initialize(client_id, redirect_uri, credentials_path):
    """Create the module-level ``client`` once and log it in.

    Args:
        client_id: TD application client id.
        redirect_uri: OAuth redirect URI.
        credentials_path: path of the token-state file.

    Raises:
        RuntimeError: if the global ``client`` is already set (truthy).
    """
    global client
    if client:
        # Message means "client has already been initialized".
        raise RuntimeError("client已经初始化了")
    connection_args = {
        'client_id': client_id,
        'redirect_uri': redirect_uri,
        'credentials_path': credentials_path,
    }
    # Bind the global before logging in, so it is set even if login() raises.
    client = TDClient(**connection_args)
    client.login()
def create_client(force_refresh: bool = True, version: str = 'AWSCURRENT') -> TDClient:
    """Fetch TD Ameritrade credentials through ``ClientFactory`` and log a client in.

    Args:
        force_refresh: forwarded to the factory's credential-file fetcher.
        version: forwarded to the credential-file fetcher (default ``'AWSCURRENT'``).

    Returns:
        TDClient: a client on which ``login()`` has been called.
    """
    factory = ClientFactory()

    # creds.json is written to /tmp/, except on Windows where it goes one
    # directory above this module.
    base_path = '/tmp/'
    if platform == 'win32':
        base_path = path.join(path.dirname(__file__), '..')
    outpath = path.join(base_path, 'creds.json')

    # ``__fetch_credential_file`` is a double-underscore (private) method, so
    # Python stores it under the mangled name ``_ClientFactory__fetch_credential_file``.
    # Written outside the class body, ``factory.__fetch_credential_file`` is
    # NOT mangled and raises AttributeError, so spell out the mangled name.
    # NOTE(review): assumes the method is defined on ClientFactory itself — confirm.
    creds_file = factory._ClientFactory__fetch_credential_file(
        force_refresh=force_refresh, outpath=outpath, version=version)

    client = TDClient(
        client_id=factory.td_client_id,
        redirect_uri=factory.td_redirect_uri,
        credentials_path=creds_file)
    client.login()
    return client
def _create_session(self) -> TDClient:
    """Start a new, logged-in TD Ameritrade session.

    Returns:
        TDClient: a client configured from this object's ``client_id``,
        ``redirect_uri`` and ``credentials_path`` on which ``login()`` has run.
    """
    new_client = TDClient(
        client_id=self.client_id,
        redirect_uri=self.redirect_uri,
        credentials_path=self.credentials_path,
    )
    new_client.login()  # authenticate before handing the client back
    return new_client
def dt_signal_prices(candle_minutes, symbols):
    """Fetch intraday candles for ``symbols`` and add a 9-candle SMA per symbol.

    Args:
        candle_minutes: candle size in minutes, sent as the ``frequency``.
        symbols: ticker symbols to request.

    Returns:
        pandas DataFrame with columns Symbol, Date, Open, Close, High, Low and
        SMA_9 (rolling mean of Close over 9 candles, computed within each
        symbol), covering the start of the previous calendar day until now.
    """
    TDSession = TDClient(client_id=config.client_id,
                         redirect_uri='http://localhost/test',
                         credentials_path='td_state.json')
    TDSession.login()

    cur_day = datetime.datetime.now(tz=pytz.timezone('US/Eastern'))
    # Dates are sent as epoch milliseconds in string form.
    price_end_date = str(int(round(cur_day.timestamp() * 1000)))
    # Step back one calendar day with timedelta. The old ``cur_day.day - 1``
    # raised ValueError on the 1st of every month. The datetime built below is
    # naive, so ``timestamp()`` interprets it in the machine's local timezone,
    # exactly as before.
    prev_day = cur_day.date() - datetime.timedelta(days=1)
    prev_midnight = datetime.datetime(prev_day.year, prev_day.month, prev_day.day)
    price_start_date = str(int(round(prev_midnight.timestamp() * 1000)))

    candle_list = []
    for symbol in symbols:
        p_hist = TDSession.get_price_history(symbol,
                                             period_type='day',
                                             frequency_type='minute',
                                             frequency=str(candle_minutes),
                                             end_date=price_end_date,
                                             start_date=price_start_date)
        for candle in p_hist['candles']:
            candle_list.append([
                symbol,
                datetime.datetime.fromtimestamp(candle['datetime'] / 1000),
                candle['open'], candle['close'], candle['high'], candle['low']
            ])

    df_dt = pd.DataFrame(
        candle_list, columns=['Symbol', 'Date', 'Open', 'Close', 'High', 'Low'])
    # Per-symbol moving average; reset_index drops the group level so the
    # result aligns with df_dt's original index.
    df_dt['SMA_9'] = df_dt.groupby('Symbol')['Close'].rolling(
        9).mean().reset_index(0, drop=True)
    return df_dt
def get_symbols(watchlist):
    """Return the instrument symbols of the account watchlist named ``watchlist``.

    Args:
        watchlist: watchlist name; compared as ``str(watchlist)``.

    Returns:
        list of symbol strings from the first watchlist with that name.

    Raises:
        ValueError: if no watchlist has that name (this used to surface as an
            UnboundLocalError).
    """
    TDSession = TDClient(client_id=config.client_id,
                         redirect_uri='http://localhost/test',
                         credentials_path='td_state.json')
    TDSession.login()
    response = TDSession.get_watchlist_accounts()

    wanted = str(watchlist)
    # Use the matched entry directly. The former manual index ``i`` stopped
    # tracking the loop position after a match and could read the wrong list.
    for watch in response:
        if watch['name'] == wanted:
            return [item['instrument']['symbol'] for item in watch['watchlistItems']]
    raise ValueError(f"No watchlist named {wanted!r} was found.")
def grab_refresh_token(self, tda_creds: dict) -> dict:
    """Log in with ``tda_creds`` and return the client's cached token state.

    ``tda_creds`` is written as JSON to ``creds.json`` in the system temp
    directory. A ``TDClient`` is pointed at that file, logged in, and asked to
    grab a refresh token. The file's JSON contents are then read back.

    Args:
        tda_creds: credentials/state dict to seed the client with.

    Returns:
        dict: the parsed JSON of the credentials file after
        ``grab_refresh_token()`` (previously mis-annotated as ``TDClient``).

    Raises:
        ValueError: if ``login()`` or ``grab_refresh_token()`` returns falsy.
    """
    # Write the creds into a known location...
    # NOTE(review): this file holds secrets, sits at a fixed shared path and is
    # not removed afterwards; other code may rely on that path — confirm before
    # changing it.
    output = path.join(gettempdir(), 'creds.json')
    with open(output, 'w') as f:
        f.write(dumps(tda_creds))

    # Fetch the offline token...
    client = TDClient(credentials_path=output,
                      client_id=self.client_id,
                      redirect_uri=self.redirect_uri)
    if not client.login():
        raise ValueError('Unable to login')
    if not client.grab_refresh_token():
        raise ValueError('Unable to grab_refresh_token')

    # Read back whatever state the client cached in the file
    # (presumably now including the refresh token).
    with open(output, 'r') as f:
        token = f.read()
    return loads(token)
def grab_candle_data(pull_from_td: bool) -> list[dict]:
    """A function that grabs candle data from TD Ameritrade, cleans up the data,
    and saves it to a JSON file, so we can use it later.

    ### Parameters
    ----------
    pull_from_td : bool
        If `True`, pull fresh candles from the TD Ameritrade API and cache them
        in `data/candles.json`. If `False`, load the candles from that file.

    ### Returns
    -------
    list[dict]
        A list of candle dictionaries with cleaned dates and the additional
        keys `range`, `symbol`, `datetime_iso` and `datetime_non_milli`.

    ### Raises
    ------
    FileNotFoundError
        If `pull_from_td` is falsy and `data/candles.json` does not exist.
    """
    candles_path = pathlib.Path('data/candles.json')

    if not pull_from_td:
        # Load the cached candles. Previously this branch required
        # `pull_from_td is False` AND an existing file; otherwise the function
        # fell through to `return all_candles` with the name unbound
        # (UnboundLocalError). Fail with an explicit message instead.
        if not candles_path.exists():
            raise FileNotFoundError(
                f"{candles_path} does not exist; call grab_candle_data(True) to create it."
            )
        with open(file=candles_path, mode='r') as candle_file:
            return json.load(fp=candle_file)

    # Grab configuration values.
    config = ConfigParser()
    config.read('config/config.ini')

    # Read the Config File.
    CLIENT_ID = config.get('main', 'CLIENT_ID')
    REDIRECT_URI = config.get('main', 'REDIRECT_URI')
    JSON_PATH = config.get('main', 'JSON_PATH')
    ACCOUNT_NUMBER = config.get('main', 'ACCOUNT_NUMBER')

    # Create a new session and log in.
    TDSession = TDClient(client_id=CLIENT_ID,
                         redirect_uri=REDIRECT_URI,
                         credentials_path=JSON_PATH,
                         account_number=ACCOUNT_NUMBER)
    TDSession.login()

    all_candles = []
    for ticker in ['AAPL', 'NIO', 'FIT', 'TSLA', 'MSFT', 'AMZN', 'IBM']:
        quotes = TDSession.get_price_history(symbol=ticker,
                                             period_type='day',
                                             period='10',
                                             frequency_type='minute',
                                             frequency=1,
                                             extended_hours=False)

        for candle in quotes['candles']:
            # High-low range of the candle.
            candle['range'] = round(candle['high'] - candle['low'], 5)
            candle['symbol'] = quotes['symbol']
            # `datetime` is epoch milliseconds; derive ISO and whole-second forms.
            candle['datetime_iso'] = datetime.fromtimestamp(
                candle['datetime'] / 1000).isoformat()
            candle['datetime_non_milli'] = int(candle['datetime'] / 1000)
            all_candles.append(candle)

    # Cache to JSON; create `data/` first so the very first pull can't fail.
    candles_path.parent.mkdir(parents=True, exist_ok=True)
    with open(file=candles_path, mode='w+') as candle_file:
        json.dump(obj=all_candles, fp=candle_file, indent=4)

    return all_candles
class TDSession(TestCase):
    """Will perform a unit test for the TD session.

    These tests call the live TD Ameritrade API through ``TDClient``. When the
    module-level ``SAVE_FLAG`` is truthy, responses are also written to
    ``samples/responses/``.
    """

    def setUp(self) -> None:
        """Set up the `TDClient` from ``config/config.ini``."""
        # Grab configuration values.
        config = ConfigParser()
        config.read('config/config.ini')

        # Load the values.
        CLIENT_ID = config.get('main', 'CLIENT_ID')
        REDIRECT_URI = config.get('main', 'REDIRECT_URI')
        JSON_PATH = config.get('main', 'JSON_PATH')
        ACCOUNT_NUMBER = config.get('main', 'ACCOUNT_NUMBER')

        # Initialize the session.
        self.td_session = TDClient(client_id=CLIENT_ID,
                                   redirect_uri=REDIRECT_URI,
                                   credentials_path=JSON_PATH,
                                   account_number=ACCOUNT_NUMBER)

    def _save_sample(self, file_name: str, data) -> None:
        """Dump ``data`` to ``samples/responses/<file_name>`` when ``SAVE_FLAG`` is set."""
        if SAVE_FLAG:
            with open('samples/responses/' + file_name, 'w+') as data_file:
                json.dump(obj=data, fp=data_file, indent=3)

    def test_creates_instance_of_session(self):
        """Create an instance and make sure it's a `TDClient` object."""
        self.assertIsInstance(self.td_session, TDClient)

    def test_login(self):
        """Test whether the session is authenticated or not."""
        self.assertTrue(self.td_session.login())
        self.assertTrue(self.td_session.authstate)

    def test_state(self):
        """Make sure the state is updated."""
        self.assertIsNotNone(self.td_session.state['refresh_token'])
        self.assertIsNotNone(self.td_session.state['access_token'])
        self.assertNotEqual(self.td_session.state['refresh_token_expires_at'], 0)
        self.assertNotEqual(self.td_session.state['access_token_expires_at'], 0)

    def test_single_get_quotes(self):
        """Test getting a single quote."""
        quotes = self.td_session.get_quotes(instruments=['MSFT'])

        # See if the symbol is in the quotes.
        self.assertIn('MSFT', quotes)

        self._save_sample('sample_single_quotes.jsonc', quotes)

    def test_get_quotes(self):
        """Test getting multiple quotes."""
        quotes = self.td_session.get_quotes(instruments=['MSFT', 'AAPL'])

        # Every requested symbol must be present. The former
        # `{'MSFT', 'AAPL'}.issuperset(quotes.keys())` check was inverted and
        # passed vacuously on an empty response.
        for symbol in ('MSFT', 'AAPL'):
            self.assertIn(symbol, quotes)

        self._save_sample('sample_multiple_quotes.jsonc', quotes)

    def test_get_accounts(self):
        """Test Get Accounts."""
        accounts = self.td_session.get_accounts(account='all',
                                                fields=['orders', 'positions'])

        self.assertIn('positions', accounts[0]['securitiesAccount'])
        self.assertIn('currentBalances', accounts[0]['securitiesAccount'])
        # self.assertIn('orderStrategies', accounts[0]['securitiesAccount'])

        self._save_sample('sample_accounts.jsonc', accounts)

    def test_create_stream_session(self):
        """Test Creating a new streaming session."""
        stream_session = self.td_session.create_streaming_session()
        self.assertIsInstance(stream_session, TDStreamerClient)

    def test_get_transactions(self):
        """Test getting transactions."""
        # `get_transactions` Endpoint. Should not return an error.
        transaction_data_multi = self.td_session.get_transactions(
            account=self.td_session.account_number, transaction_type='ALL')

        # Make sure it's a list of transaction dicts.
        self.assertIsInstance(transaction_data_multi, list)
        self.assertIn('type', transaction_data_multi[0])

        self._save_sample('sample_transaction_data.jsonc', transaction_data_multi)

    def test_get_market_hours(self):
        """Test get market hours."""
        # `get_market_hours` Endpoint with multiple values.
        market_hours_multi = self.td_session.get_market_hours(
            markets=['EQUITY', 'FOREX'], date=datetime.today().isoformat())

        # Make sure it's a dict.
        self.assertIsInstance(market_hours_multi, dict)

        # This test expects the inner equity key to be 'equity' on weekends
        # and 'EQ' on weekdays.
        if datetime.today().weekday() in (5, 6):
            self.assertIn('isOpen', market_hours_multi['equity']['equity'])
        else:
            self.assertIn('isOpen', market_hours_multi['equity']['EQ'])

        self._save_sample('sample_market_hours.jsonc', market_hours_multi)

    def test_get_instrument(self):
        """Test getting an instrument."""
        # `get_instruments` Endpoint.
        get_instrument = self.td_session.get_instruments(cusip='594918104')

        # Make sure it's a list.
        self.assertIsInstance(get_instrument, list)
        self.assertIn('cusip', get_instrument[0])

        self._save_sample('sample_instrument.jsonc', get_instrument)

    def test_chart_history(self):
        """Test getting historical prices for every valid period combination."""
        # Maps frequency type -> period type -> allowed period values.
        valid_values = {
            'minute': {
                'day': [1, 2, 3, 4, 5, 10]
            },
            'daily': {
                'month': [1, 2, 3, 6],
                'year': [1, 2, 3, 5, 10, 15, 20],
                'ytd': [1]
            },
            'weekly': {
                'month': [1, 2, 3, 6],
                'year': [1, 2, 3, 5, 10, 15, 20],
                'ytd': [1]
            },
            'monthly': {
                'year': [1, 2, 3, 5, 10, 15, 20]
            }
        }

        # Define the static arguments.
        hist_symbol = 'MSFT'
        hist_needExtendedHoursData = False

        for frequency_type, frequency_periods in valid_values.items():
            for frequency_period, possible_values in frequency_periods.items():
                for value in possible_values:
                    # Request `value` periods of `frequency_period`, one bar
                    # per `frequency_type` unit.
                    historical_prices = self.td_session.get_price_history(
                        symbol=hist_symbol,
                        period_type=frequency_period,
                        period=value,
                        frequency_type=frequency_type,
                        frequency=1,
                        extended_hours=hist_needExtendedHoursData)

                    self.assertIsInstance(historical_prices, dict)
                    self.assertFalse(historical_prices['empty'])

        # Save the last response received.
        self._save_sample('sample_historical_prices.jsonc', historical_prices)

    def test_custom_historical_prices(self):
        """Test getting historical prices for a custom date range."""
        # The max look back period for minute data is 31 Days.
        lookback_period = 10

        # Define now, and `lookback_period` days ago.
        today_00 = datetime.now()
        today_ago = datetime.now() - timedelta(days=lookback_period)

        # The TD API expects a timestamp in milliseconds. However, the
        # timestamp() method only returns seconds, so multiply it by 1000.
        hist_endDate = str(int(round(today_00.timestamp() * 1000)))
        hist_startDate = str(int(round(today_ago.timestamp() * 1000)))

        # Make the request.
        historical_custom = self.td_session.get_price_history(
            symbol='MSFT',
            period_type='day',
            frequency_type='minute',
            start_date=hist_startDate,
            end_date=hist_endDate,
            frequency=1,
            extended_hours=True)

        self.assertIsInstance(historical_custom, dict)
        self.assertFalse(historical_custom['empty'])

        # NOTE(review): shares its output file with test_chart_history.
        self._save_sample('sample_historical_prices.jsonc', historical_custom)

    def test_search_instruments(self):
        """Test Searching for Instruments."""
        # `search_instruments` Endpoint.
        instrument_search_data = self.td_session.search_instruments(
            symbol='MSFT', projection='symbol-search')

        self.assertIsInstance(instrument_search_data, dict)
        self.assertIn('MSFT', instrument_search_data)

        self._save_sample('sample_search_instrument.jsonc', instrument_search_data)

    def test_get_movers(self):
        """Test getting Market movers."""
        # `get_movers` Endpoint.
        movers_data = self.td_session.get_movers(market='$DJI',
                                                 direction='up',
                                                 change='value')

        self.assertIsInstance(movers_data, list)
        if datetime.today().weekday() in (5, 6):
            # Expected to be empty on weekends.
            self.assertFalse(movers_data)
        else:
            self.assertIn('symbol', movers_data[0])

        self._save_sample('sample_movers.jsonc', movers_data)

    def test_get_user_preferences(self):
        """Test getting user preferences."""
        # `get_preferences` endpoint. Should not return an error.
        preference_data = self.td_session.get_preferences(
            account=self.td_session.account_number)

        self.assertIsInstance(preference_data, dict)
        self.assertIn('expressTrading', preference_data)

        self._save_sample('sample_account_preferences.jsonc', preference_data)

    def test_get_user_principals(self):
        """Test getting user principals."""
        # `get_user_principals` endpoint. Should not return an error.
        user_principals = self.td_session.get_user_principals(
            fields=['preferences', 'surrogateIds'])

        self.assertIsInstance(user_principals, dict)
        self.assertIn('userId', user_principals)

        self._save_sample('sample_user_principals.jsonc', user_principals)

    def test_get_streamer_keys(self):
        """Test getting streamer subscription keys."""
        # `get_streamer_subscription_keys` endpoint. Should not return an error.
        streamer_keys = self.td_session.get_streamer_subscription_keys(
            accounts=[self.td_session.account_number])

        self.assertIsInstance(streamer_keys, dict)
        self.assertIn('keys', streamer_keys)

        self._save_sample('sample_streamer_keys.jsonc', streamer_keys)

    def tearDown(self) -> None:
        """Drop the session reference."""
        self.td_session = None
from datetime import date

import dateutil.relativedelta
from td.client import TDClient
from td.enums import ORDER_SESSION, ORDER_TYPE
from util.options import processOptionTrade, orderedKeys
from config import ACCOUNT_ID, ACCOUNT_NUMBER, ACCOUNT_PASSWORD, CONSUMER_ID, REDIRECT_URI, SECRET_QUESTIONS

# Create a new session and log in (this runs at import time).
sess = TDClient(
    account_number=ACCOUNT_NUMBER,
    account_password=ACCOUNT_PASSWORD,
    consumer_id=CONSUMER_ID,
    redirect_uri=REDIRECT_URI,
    secret_questions=SECRET_QUESTIONS
)
sess.login()

# Run this job at the end of every day, before midnight.
# NOTE: assumes the API treats start_date as exclusive and end_date as
# inclusive, so for a daily job endDt is today and startDt is yesterday.
# (`date` was used here without being imported above; it is now imported.)
today = date.today()
delta = dateutil.relativedelta.relativedelta(days=1)
startDt = (today - delta).strftime("%Y-%m-%d")
endDt = today.strftime("%Y-%m-%d")

data = sess.get_transactions(
    account=ACCOUNT_ID,
    transaction_type='TRADE',
    start_date=startDt,
    end_date=endDt
)

# Keep only TRADE entries, then lazily select the option trades among them.
trades = [entry for entry in data if entry['type'] == 'TRADE']
optionTrades = (
    entry for entry in trades
    if entry['transactionItem']['instrument']['assetType'] == 'OPTION'
)
class PyRobotPortfolioTest(TestCase):
    """Will perform a unit test for the Portfolio object."""

    def setUp(self) -> None:
        """Set up a fresh, empty Portfolio for each test."""
        self.portfolio = Portfolio()
        self.maxDiff = None

    def _login_td_client(self) -> TDClient:
        """Build and log in a ``TDClient`` from ``configs/config.ini``.

        Only ``test_portfolio_summary`` needs a live client. Logging in here
        on demand, instead of in ``setUp``, keeps the pure unit tests from
        depending on the live API and a valid config file.
        """
        # Grab configuration values.
        config = ConfigParser()
        config.read('configs/config.ini')

        CLIENT_ID = config.get('main', 'CLIENT_ID')
        REDIRECT_URI = config.get('main', 'REDIRECT_URI')
        CREDENTIALS_PATH = config.get('main', 'JSON_PATH')
        self.ACCOUNT_NUMBER = config.get('main', 'ACCOUNT_NUMBER')

        self.td_client = TDClient(
            client_id=CLIENT_ID,
            redirect_uri=REDIRECT_URI,
            credentials_path=CREDENTIALS_PATH
        )
        self.td_client.login()
        return self.td_client

    def test_create_portofolio(self):
        """Make sure it's a Portfolio."""
        self.assertIsInstance(self.portfolio, Portfolio)

    def test_td_client_property(self):
        """Test the TD Client property."""
        # Should be None if it wasn't initialized from the PyRobot.
        self.assertIsNone(self.portfolio.td_client)

    def test_stock_frame_property(self):
        """Test the Stock Frame property."""
        # Should be None if it wasn't initialized from the PyRobot.
        self.assertIsNone(self.portfolio.stock_frame)

    def test_historical_prices_property(self):
        """Test the Historical Prices property."""
        # Should be falsy if it wasn't initialized from the PyRobot.
        self.assertFalse(self.portfolio.historical_prices)

    def test_add_position(self):
        """Test adding a single position to the portfolio."""
        new_position = self.portfolio.add_position(
            symbol='MSFT',
            asset_type='equity',
            quantity=10,
            purchase_price=3.00,
            purchase_date='2020-01-31'
        )

        correct_position = {
            'symbol': 'MSFT',
            'asset_type': 'equity',
            'ownership_status': True,
            'quantity': 10,
            'purchase_price': 3.00,
            'purchase_date': '2020-01-31'
        }

        self.assertDictEqual(new_position, correct_position)

    def test_add_position_default_arguments(self):
        """Test adding a single position to the portfolio, no date."""
        new_position = self.portfolio.add_position(
            symbol='MSFT',
            asset_type='equity'
        )

        correct_position = {
            'symbol': 'MSFT',
            'asset_type': 'equity',
            'ownership_status': False,
            'quantity': 0,
            'purchase_price': 0.00,
            'purchase_date': None
        }

        self.assertDictEqual(new_position, correct_position)

    def test_delete_existing_position(self):
        """Test deleting an existing position."""
        self.portfolio.add_position(
            symbol='MSFT',
            asset_type='equity',
            quantity=10,
            purchase_price=3.00,
            purchase_date='2020-01-31'
        )

        delete_status = self.portfolio.remove_position(symbol='MSFT')
        correct_status = (True, 'MSFT was successfully removed.')

        self.assertTupleEqual(delete_status, correct_status)

    def test_delete_non_existing_position(self):
        """Test deleting a non-existing position."""
        delete_status = self.portfolio.remove_position(symbol='AAPL')
        # The expected message text must match what Portfolio returns verbatim.
        correct_status = (False, 'AAPL did not exist in the porfolio.')

        self.assertTupleEqual(delete_status, correct_status)

    def test_in_portfolio_exisitng(self):
        """Check that an existing position is reported as in the portfolio."""
        self.portfolio.add_position(
            symbol='MSFT',
            asset_type='equity',
            quantity=10,
            purchase_price=3.00,
            purchase_date='2020-01-31'
        )

        in_portfolio_flag = self.portfolio.in_portfolio(symbol='MSFT')
        self.assertTrue(in_portfolio_flag)

    def test_in_portfolio_non_exisitng(self):
        """Check that a non-existing position is reported as absent."""
        in_portfolio_flag = self.portfolio.in_portfolio(symbol='AAPL')
        self.assertFalse(in_portfolio_flag)

    def test_is_profitable(self):
        """Checks to see if a position is profitable."""
        # Add a position.
        self.portfolio.add_position(
            symbol='MSFT',
            asset_type='equity',
            quantity=10,
            purchase_price=3.00,
            purchase_date='2020-01-31'
        )

        # Test for being profitable.
        is_profitable = self.portfolio.is_profitable(
            symbol='MSFT',
            current_price=5.00
        )

        # Test for not being profitable.
        is_not_profitable = self.portfolio.is_profitable(
            symbol='MSFT',
            current_price=1.00
        )

        self.assertTrue(is_profitable)
        self.assertFalse(is_not_profitable)

    def test_projected_market_value(self):
        """Tests the generation of a market value summary, for all of the positions."""
        # Add a position.
        self.portfolio.add_position(
            symbol='MSFT',
            asset_type='equity',
            quantity=10,
            purchase_price=3.00,
            purchase_date='2020-01-31'
        )

        correct_dict = {
            'MSFT': {
                'current_price': 5.0,
                'is_profitable': True,
                'purchase_price': 3.0,
                'quantity': 10,
                'total_invested_capital': 30.0,
                'total_loss_or_gain_$': 20.0,
                'total_loss_or_gain_%': 0.6667,
                'total_market_value': 50.0
            },
            'total': {
                'number_of_breakeven_positions': 0,
                'number_of_non_profitable_positions': 0,
                'number_of_profitable_positions': 1,
                'total_invested_capital': 30.0,
                'total_market_value': 50.0,
                'total_positions': 1,
                'total_profit_or_loss': 20.0
            }
        }

        portfolio_summary = self.portfolio.projected_market_value(
            current_prices={'MSFT': {'lastPrice': 5.0}})
        self.assertDictEqual(correct_dict, portfolio_summary)

    def test_grab_historical_prices(self):
        """Placeholder: historical-price grabbing is not tested yet."""
        # Previously an empty `pass`, which reported a silent success.
        self.skipTest('Not implemented yet.')

    def test_portfolio_summary(self):
        """Tests the generation of a portfolio summary, for all of the positions."""
        # Add a position.
        self.portfolio.add_position(
            symbol='MSFT',
            asset_type='equity',
            quantity=10,
            purchase_price=3.00,
            purchase_date='2020-01-31'
        )

        # This is the only test that needs a live, logged-in client.
        self.portfolio.td_client = self._login_td_client()

        expected_keys = {
            'projected_market_value',
            'portfolio_weights',
            'portfolio_risk'
        }

        summary_dict = self.portfolio.portfolio_summary()
        self.assertTrue(expected_keys.issubset(summary_dict))

    def test_ownership_status(self):
        """Tests getting and setting the ownership status."""
        # Add a position.
        self.portfolio.add_position(
            symbol='MSFT',
            asset_type='equity',
            quantity=10,
            purchase_price=3.00,
            purchase_date='2020-01-31'
        )

        # Should be True, since `purchase_date` was set.
        self.assertTrue(self.portfolio.get_ownership_status(symbol='MSFT'))

        # Reassign it.
        self.portfolio.set_ownership_status(symbol='MSFT', ownership=False)

        # Should be False.
        self.assertFalse(self.portfolio.get_ownership_status(symbol='MSFT'))

    def tearDown(self) -> None:
        """Teardown the Portfolio object."""
        self.portfolio = None
class TDAccount(AbstractAccount):
    """Trading account that places and tracks orders through TD Ameritrade.

    ``__init__`` starts a background thread that repeatedly re-syncs every
    non-terminal order held in ``td_orders`` via ``sync_order``.
    """

    def valid_scope(self, scope: Scope):
        """Validate that every code in ``scope`` is an equity TD can find.

        Each code is split on ``_`` into ``<symbol>_<type>``; only type ``STK``
        is supported.

        Raises:
            NotImplementedError: a code's type is not ``STK``.
            RuntimeError: the symbol is not found, or its assetType is not EQUITY.
        """
        for code in scope.codes:
            cc = code.split("_")
            symbol = cc[0]
            symbol_type = cc[1]
            if symbol_type != 'STK':
                raise NotImplementedError
            res = self.client.search_instruments(symbol, "symbol-search")
            # Message: "asset not found, code: ..."
            if not res or symbol not in res:
                raise RuntimeError("没有查询到资产,code:" + code)
            # Message: "asset is not a stock; not supported yet"
            if res[symbol]['assetType'] != 'EQUITY':
                raise RuntimeError("资产不是股票类型,暂不支持")

    def __init__(self, name: str, initial_cash: float):
        super().__init__(name, initial_cash)
        self.account_id = None
        self.client: TDClient = None
        # Keyed by TD order id. Written by place_order while the sync thread reads it.
        self.td_orders: Dict[str, TDOrder] = {}
        self.start_sync_order_thread()

    def with_client(self, client_id, redirect_uri, credentials_path, account_id):
        """Attach and log in a ``TDClient`` used for account ``account_id``."""
        self.account_id = account_id
        self.client = TDClient(client_id=client_id, redirect_uri=redirect_uri,
                               credentials_path=credentials_path)
        self.client.login()

    @do_log(target_name='下单', escape_params=[EscapeParam(index=0, key='self')])
    @alarm(target='下单', escape_params=[EscapeParam(index=0, key='self')])
    @retry(limit=3)
    def place_order(self, order: Order):
        """Submit ``order`` to TD and start tracking it in ``td_orders``.

        Raises:
            Exception: whatever ``client.place_order`` raises; the order is
                marked FAILED first.
            RuntimeError: if after the first sync the order is still CREATED
                (it is cancelled first) or has FAILED.
        """
        td_order = TDOrder(order, self.order_callback, self)
        try:
            resp = self.client.place_order(self.account_id, td_order.to_dict())
        except Exception:
            order.status = OrderStatus.FAILED
            raise
        td_order_id = resp.get('order_id')
        order.td_order_id = td_order_id
        self.sync_order(td_order)
        if order.status in (OrderStatus.CREATED, OrderStatus.FAILED):
            if order.status == OrderStatus.CREATED:
                # cancel_open_order only accepts orders already present in
                # td_orders. Without registering first, it raised its
                # "no TD order id" error and the live TD order was never cancelled.
                self.td_orders[td_order_id] = td_order
                self.cancel_open_order(order)
            raise RuntimeError("place order error")
        self.td_orders[td_order_id] = td_order

    def start_sync_order_thread(self):
        """Start a thread that re-syncs non-terminal orders every 0.5 seconds."""
        import time
        import traceback

        terminal_statuses = (OrderStatus.FAILED, OrderStatus.FILLED, OrderStatus.CANCELED)

        def do_sync():
            while True:
                # Iterate a snapshot. place_order inserts into td_orders from
                # other threads, and iterating the live dict across the
                # blocking sync_order calls could raise "dictionary changed size
                # during iteration" outside the try, killing this thread.
                for td_order in list(self.td_orders.values()):
                    if td_order.framework_order.status in terminal_statuses:
                        # Orders in a terminal state need no further syncing.
                        continue
                    try:
                        self.sync_order(td_order)
                    except Exception:
                        logging.error("{}".format(traceback.format_exc()))
                time.sleep(0.5)

        threading.Thread(target=do_sync, name="sync td orders").start()

    @alarm(level=AlarmLevel.ERROR, target="同步订单", escape_params=[EscapeParam(index=0, key='self')])
    def sync_order(self, td_order: TDOrder):
        """
        Sync the order status and its execution details from TD.
        :return:
        """
        if not td_order.td_order_id:
            # Message: "invalid TD order"
            raise RuntimeError("非法的td订单")
        if td_order.framework_order.status in [
                OrderStatus.FAILED, OrderStatus.CANCELED, OrderStatus.FILLED
        ]:
            # Message: "order already terminal, no sync needed"
            logging.info("订单已经到终态,不需要同步")
            return
        o: dict = self.client.get_orders(self.account_id, td_order.td_order_id)

        # Update the framework order's status.
        td_order_status: str = o.get('status')
        if td_order_status in ['ACCEPTED', 'WORKING', 'QUEUED']:
            # Only CREATED moves to SUBMITTED. Any other framework status is
            # left alone. Previously such statuses fell through to the
            # NotImplementedError branch on every poll.
            if td_order.framework_order.status == OrderStatus.CREATED:
                td_order.framework_order.status = OrderStatus.SUBMITTED
        elif 'PENDING' in td_order_status:
            # Leave the status unchanged while TD reports a PENDING_* state.
            pass
        elif td_order_status == "CANCELED" and td_order.framework_order.status != OrderStatus.CANCELED:
            td_order.framework_order.status = OrderStatus.CANCELED
            self.order_callback.order_status_change(td_order.framework_order, None)
        elif td_order_status == 'FILLED':
            # The FILLED status is applied through account.order_filled.
            pass
        else:
            # Message: "unhandled order status: ..."
            raise NotImplementedError("无法处理的订单状态:" + td_order_status)

        # Sync execution details. TD executions carry no unique id, so each
        # sync re-applies the aggregate of all fills as a new version of the
        # single "default" execution, replacing the previously applied one.
        executions: List[Dict] = o.get("orderActivityCollection")
        if executions:
            filled_quantity = 0
            total_cost = 0
            # Aggregate all executions.
            for execution in executions:
                if execution.get("executionType") != 'FILL':
                    raise NotImplementedError
                execution_legs: List[Dict] = execution.get("executionLegs")
                if execution_legs and len(execution_legs) > 1:
                    raise NotImplementedError
                if execution_legs and len(execution_legs) == 1:
                    execution_leg = execution_legs[0]
                    single_filled_quantity = execution_leg.get("quantity")
                    single_filled_price = execution_leg.get('price')
                    total_cost += single_filled_price * single_filled_quantity
                    filled_quantity += single_filled_quantity

            # Guard against division by zero when no leg reported a fill.
            if filled_quantity > 0 and filled_quantity > td_order.framework_order.filled_quantity:
                filled_avg_price = total_cost / filled_quantity
                execution_map = td_order.framework_order.execution_map
                if len(execution_map) > 0:
                    version = execution_map.get("default").version + 1
                else:
                    version = 1
                oe = OrderExecution("default", version, 0, filled_quantity,
                                    filled_avg_price, None, None,
                                    td_order.framework_order.direction, None)
                self.order_filled(td_order.framework_order, oe)

    def match(self, data):
        raise NotImplementedError

    @do_log(target_name='取消订单', escape_params=[EscapeParam(index=0, key='self')])
    @alarm(target='取消订单', escape_params=[EscapeParam(index=0, key='self')])
    @retry(limit=3)
    def cancel_open_order(self, open_order: Order):
        """Cancel a tracked order at TD and mark it CANCELED.

        Raises:
            RuntimeError: if the order has no TD id or is not in ``td_orders``.
        """
        if not open_order.td_order_id or open_order.td_order_id not in self.td_orders:
            # Message: "no TD order id"
            raise RuntimeError("没有td订单号")
        if open_order.status == OrderStatus.CANCELED:
            return
        td_order = self.td_orders[open_order.td_order_id]
        self.client.cancel_order(self.account_id, open_order.td_order_id)
        self.sync_order(td_order)
        open_order.status = OrderStatus.CANCELED
        self.order_callback.order_status_change(open_order, self)

    def update_order(self, order: Order, reason):
        """Replace a limit order: cancel it, then place a new one for the unfilled rest.

        Raises:
            RuntimeError: if the order is not a tracked TD order.
            NotImplementedError: if the order is not a ``LimitOrder``.
        """
        if not order.td_order_id or (order.td_order_id not in self.td_orders):
            raise RuntimeError("没有td订单号")
        if not isinstance(order, LimitOrder):
            raise NotImplementedError
        self.cancel_open_order(order)
        new_order = LimitOrder(order.code, order.direction,
                               order.quantity - order.filled_quantity,
                               Timestamp.now(tz='Asia/Shanghai'),
                               order.limit_price, None)
        self.place_order(new_order)

    def start_save_thread(self):
        """Start a thread that calls ``self.save()`` every 30 minutes."""
        import time
        import traceback

        def save():
            while True:
                try:
                    logging.info("开始保存账户数据")
                    self.save()
                except Exception:
                    err_msg = "保存账户失败:{}".format(traceback.format_exc())
                    logging.error(err_msg)
                time.sleep(30 * 60)

        threading.Thread(name="account_save", target=save).start()
def request_data():
    """Backfill three months of daily OHLCV candles for every symbol in ``symbols``.

    For each symbol (taken from a module-level ``symbols`` iterable defined
    elsewhere), daily candles are requested with extended hours on. They are
    merged with the symbol's existing ``<symbol>_1d.csv`` (if any), de-duplicated
    by timestamp (newest row wins), forward-filled and written back. Failures
    for one symbol are printed and the loop continues.
    """
    import os

    csv_template = 'C:/Users/R/Desktop/code/2nd-tda-api/data/{}_1d.csv'

    # Create a new session.
    # NOTE(review): client id and paths are hard-coded for one machine.
    td_client = TDClient(
        client_id='AZ2BZPRDVDNFBHUB5ADYDAPMD2CLG9RG',
        redirect_uri='https://localhost/first',
        credentials_path='C:/Users/R/Desktop/code/2nd-tda-api/td_state.json')

    # Login to a new session.
    td_client.login()
    print('td client logged in')

    # Backfill data from overnight, or whenever the program hasn't been running
    # during market hours.
    for symbol in symbols:
        print(symbol)

        # Returns a dict with 3 items: candles (a list), empty, and symbol.
        backfill = td_client.get_price_history(symbol=symbol,
                                               period_type='month',
                                               period=3,
                                               frequency_type='daily',
                                               frequency=1,
                                               extended_hours=True)

        # Remove a leading '$' or '/' prefix so it can be used in a file name.
        if symbol.startswith(('$', '/')):
            symbol = symbol[1:]

        print('backfill requested')
        try:
            new_data = pd.DataFrame(backfill['candles'])
            new_data['datetime'] = pd.to_datetime(new_data['datetime'],
                                                  unit='ms',
                                                  origin='unix')
            new_data.rename(columns={"datetime": "dt"}, inplace=True)
            new_data = new_data[['dt', 'open', 'high', 'low', 'close', 'volume']]

            csv_path = csv_template.format(symbol)
            # Merge with the existing history. The old code opened the file with
            # 'w+' (truncating it) and wrapped the file object in pd.DataFrame,
            # so previous data was wiped instead of appended to.
            if os.path.exists(csv_path):
                # Files are written with index=True, so column 0 is the index.
                old_data = pd.read_csv(csv_path, index_col=0, parse_dates=['dt'])
                df = pd.concat([old_data, new_data], ignore_index=True)
            else:
                df = new_data

            # One row per timestamp; keep the freshest candle for each.
            df = df.drop_duplicates(subset='dt', keep='last').reset_index(drop=True)
            df = df.ffill()
            df.to_csv(csv_path, index=True)
            print('backfill saved!!!!')
        except Exception as e:
            # Best-effort per symbol: report and continue with the next one.
            print(e)
# Define the ASSET to be traded - ENUM EXAMPLE -- SYMBOL MUST ALWAYS BE A STRING. new_order_leg.order_leg_asset(asset_type=ORDER_ASSET_TYPE.EQUITY, symbol='MSFT') # Once we have built our order leg, we can add it to our OrderObject. new_order.add_order_leg(order_leg=new_order_leg) # Create a new session td_session = TDClient( account_number=ACCOUNT_NUMBER, account_password=ACCOUNT_PASSWORD, consumer_id=CONSUMER_ID, redirect_uri=REDIRECT_URI, json_path=r"C:\Users\Alex\OneDrive\Desktop\TDAmeritradeState.json") td_session.login() # # Place the Order # td_session.place_order(account = '11111', order= new_order) # Create the Order. option_order = Order() option_order.order_session(session='NORMAL') option_order.order_duration(duration='GOOD_TILL_CANCEL') option_order.order_type(order_type='LIMIT') option_order.order_strategy_type(order_strategy_type='SINGLE') option_order.order_price(price=12.00) # Create the Order Leg option_order_leg = OrderLeg() option_order_leg.order_leg_instruction(instruction='BUY_TO_OPEN')
class AmeritradeRebalanceUtils:
    """Authenticate with TD Ameritrade and rebalance one account with market orders."""

    def __init__(self):
        self.session = None
        self.account = None
        self.account_id = None

    def auth(self, credentials_path='./td_state.json', client_path='./td_client_auth.json'):
        """Log in and cache the first account returned (with positions).

        Args:
            credentials_path: token-state file passed to ``TDClient``.
            client_path: JSON file with ``client_id`` and ``callback_url`` keys.

        Returns:
            The logged-in ``TDClient``.
        """
        with open(client_path) as f:
            data = json.load(f)
        self.session = TDClient(
            client_id=data['client_id'],
            redirect_uri=data['callback_url'],
            credentials_path=credentials_path
        )
        self.session.login()
        # Assumes only one account is under management.
        self.account = self.session.get_accounts(fields=['positions'])[0]
        self.account_id = self.account['securitiesAccount']['accountId']
        return self.session

    def get_portfolio(self):
        """Return ``{symbol: longQuantity}`` for the cached account's positions.

        If the ``positions`` key is absent (e.g. possibly an account with no
        holdings), an empty dict is returned instead of raising KeyError.
        """
        positions = self.account['securitiesAccount'].get('positions', [])
        return {
            position['instrument']['symbol']: position['longQuantity']
            for position in positions
        }

    def place_orders_dry_run(self, portfolio_diff: dict):
        """Describe the market orders needed to apply ``portfolio_diff``.

        Args:
            portfolio_diff: ``{ticker: signed share delta}``; positive buys,
                negative sells.

        Returns:
            ``{ticker: {'instruction', 'qty', 'money_movement'}}`` where ``qty``
            is the absolute rounded delta and ``money_movement`` is
            ``-rounded_delta * lastPrice`` (cash out for buys).
        """
        prices = self._get_last_prices(portfolio_diff)
        result = {}
        for ticker, qty in portfolio_diff.items():
            round_qty = round(qty)
            result[ticker] = {
                'instruction': ('BUY' if qty > 0 else 'SELL'),
                'qty': abs(round_qty),
                'money_movement': round_qty * prices[ticker] * -1
            }
        return result

    def place_orders(self, place_orders_dry_run: dict):
        """Submit a market order for each entry from ``place_orders_dry_run``.

        Entries whose rounded quantity is 0 are skipped rather than sent to TD
        as zero-share orders.

        Returns:
            list of ``session.place_order`` responses, in submission order.
        """
        result = []
        for ticker, order in place_orders_dry_run.items():
            if order['qty'] == 0:
                continue
            payload = self._get_market_order_payload(ticker, order['qty'], order['instruction'])
            result.append(self.session.place_order(account=self.account_id, order=payload))
        return result

    def _get_market_order_payload(self, ticker, quantity, instruction='BUY'):
        """Build a single-leg DAY market-order payload for an equity."""
        return {
            "orderType": "MARKET",
            "session": "NORMAL",
            "duration": "DAY",
            "orderStrategyType": "SINGLE",
            "orderLegCollection": [
                {
                    "instruction": instruction,
                    "quantity": quantity,
                    "instrument": {
                        "symbol": ticker,
                        "assetType": "EQUITY"
                    }
                }
            ]
        }

    def _get_last_prices(self, portfolio: dict):
        """Return ``{ticker: lastPrice}`` for every ticker key in ``portfolio``."""
        quotes = self.session.get_quotes(instruments=list(portfolio))
        return {ticker: quotes[ticker]['lastPrice'] for ticker in portfolio}
# from new_tos import td_consumer_key, redirect_url, json_path from td.client import TDClient td_consumer_key = "HS7K2SZXYBG2HMOYU6JOMXWAWA2QRASG" redirect_url = "https://localhost/test" json_paths = 'C:/Users/Admin/github/Random_Projects/stock trading/td_state.json' td_account = '490558627' td_client = TDClient(consumer_id=td_consumer_key, redirect_uri=redirect_url, json_path=json_paths) td_client.login()
import pprint
from datetime import datetime
from datetime import timedelta
from td.client import TDClient

# Create a new session.
# The angle-bracket values are placeholders to replace with real settings.
TDSession = TDClient(client_id='<CLIENT_ID>', redirect_uri='<REDIRECT_URI>', credentials_path='<CREDENTIALS_PATH>')

# Login to the session
TDSession.login()

# Valid price-history parameter combinations. Presumably the outer keys are
# frequency types, the inner keys period types, and the lists the allowed
# `period` values for each — confirm against how this dict is used below.
valid_values = {
    'minute': {
        'day': [1, 2, 3, 4, 5, 10]
    },
    'daily': {
        'month': [1, 2, 3, 6],
        'year': [1, 2, 3, 5, 10, 15, 20],
        'ytd': [1]
    },
    'weekly': {
        'month': [1, 2, 3, 6],
        'year': [1, 2, 3, 5, 10, 15, 20],
        'ytd': [1]
    },
    'monthly': {
        'year': [1, 2, 3, 5, 10, 15, 20]
    }
def login(self):
    """Create a ``TDClient`` from the module constants and log it in.

    Uses ``CONSUMER_KEY``, ``REDIRECT_URI`` and ``JSON_PATH``.

    Returns:
        The ``TDClient`` after ``login()`` has been called.
    """
    client_kwargs = {
        "client_id": CONSUMER_KEY,
        "redirect_uri": REDIRECT_URI,
        "credentials_path": JSON_PATH,
    }
    session = TDClient(**client_kwargs)
    session.login()
    return session