def _init_options(self, **options):
    """Configure the market from keyword options, using defaults for missing keys.

    Keys read (default):
        'market' ('stock')            -> self.m_type
        'cash' (100000)               -> self.init_cash
        'logger' (None)               -> self.logger
        'use_sequence' (False)        -> self.use_sequence
        'use_one_hot' (True)          -> self.use_one_hot
        'use_normalized' (True)       -> self.use_normalized
        'state_mix_cash' (True)       -> self.use_state_mix_cash
        'seq_length' (5)              -> self.seq_length (values <= 1 become 2)
        'training_data_ratio' (0.7)   -> self.training_data_ratio
        'scaler' (StandardScaler)     -> called once per entry of self.codes

    Afterwards a Trader is created with ``self.init_cash`` and ``doc_class``
    is chosen from ``m_type``.
    """
    get = options.get
    self.m_type = get('market', 'stock')
    self.init_cash = get('cash', 100000)
    self.logger = get('logger', None)
    self.use_sequence = get('use_sequence', False)
    self.use_one_hot = get('use_one_hot', True)
    self.use_normalized = get('use_normalized', True)
    self.use_state_mix_cash = get('state_mix_cash', True)
    self.seq_length = get('seq_length', 5)
    if not self.seq_length > 1:
        # Sequence lengths of 1 or less are forced up to 2.
        self.seq_length = 2
    self.training_data_ratio = get('training_data_ratio', 0.7)
    # Resolve the default lazily so StandardScaler is only looked up when needed.
    scaler_factory = options['scaler'] if 'scaler' in options else StandardScaler
    self.scaler = [scaler_factory() for _ in self.codes]
    self.trader = Trader(self, cash=self.init_cash)
    self.doc_class = Stock if self.m_type == 'stock' else Future
def _init_options(self, **options):
    """Configure the market from keyword options, using defaults for missing keys.

    Keys read (default):
        'market_type' ('stock')       -> self.m_type; the key 'market' is
                                         accepted as a fallback so callers
                                         using the key read by the other
                                         Market variants are not silently
                                         given 'stock'
        'cash' (100000)               -> self.init_cash
        'use_sequence' (False)        -> self.use_sequence
        'use_one_hot' (True)          -> self.use_one_hot
        'use_normalized' (True)       -> self.use_normalized
        'state_mix_cash' (True)       -> self.use_state_mix_cash
        'seq_length' (5)              -> self.seq_length (values <= 1 become 2)
        'training_data_ratio' (0.7)   -> self.training_data_ratio

    Finally a Trader is created with ``self.init_cash``.
    """
    if 'market_type' in options:
        self.m_type = options['market_type']
    else:
        self.m_type = options.get('market', 'stock')
    self.init_cash = options.get('cash', 100000)
    self.use_sequence = options.get('use_sequence', False)
    self.use_one_hot = options.get('use_one_hot', True)
    self.use_normalized = options.get('use_normalized', True)
    self.use_state_mix_cash = options.get('state_mix_cash', True)
    seq_length = options.get('seq_length', 5)
    # Sequence lengths of 1 or less are forced up to 2.
    self.seq_length = seq_length if seq_length > 1 else 2
    self.training_data_ratio = options.get('training_data_ratio', 0.7)
    self.trader = Trader(self, cash=self.init_cash)
class Market(object):
    """Daily-bar trading environment over a fixed list of instrument codes.

    The market loads OHLCV data for ``codes`` (plus the optional 'sh' index
    code), scales it per code, builds supervised arrays served by
    ``get_batch_data`` / ``get_test_data``, and steps a ``Trader`` through the
    dates via ``reset`` / ``forward``.
    """

    # Episode status values returned by ``forward``.
    Running = 0
    Done = -1

    def __init__(self, codes, start_date="2008-01-01", end_date="2018-01-01", **options):
        """Create the market and load its data.

        Args:
            codes: list of tradable instrument codes.
            start_date: first date passed to the data loader.
            end_date: last date passed to the data loader.
            **options: configuration, see ``_init_options``.
        """
        # Initialize codes.
        self.codes = codes
        self.index_codes = []
        self.state_codes = []
        # Initialize dates.
        self.dates = []
        self.t_dates = []
        self.e_dates = []
        # Initialize data frames.
        self.origin_frames = dict()
        self.scaled_frames = dict()
        # Initialize scaled data x, y.
        self.data_x = None
        self.data_y = None
        # Initialize scaled seq data x, y.
        self.seq_data_x = None
        self.seq_data_y = None
        # Initialize flag date.
        self.next_date = None
        self.iter_dates = None
        self.current_date = None
        # Initialize parameters.
        self._init_options(**options)
        # Initialize stock data.
        self._init_data(start_date, end_date)

    def _init_options(self, **options):
        """Read configuration from ``options``, falling back to defaults.

        Keys read (default): 'market' ('stock'), 'cash' (100000),
        'logger' (None), 'use_sequence' (False), 'use_normalized' (True),
        'mix_trader_state' (True), 'mix_index_state' (False; when true the
        'sh' code is added to ``index_codes``), 'seq_length' (5; values <= 1
        become 2), 'training_data_ratio' (0.7) and 'scaler'
        (StandardScaler; called once per state code).
        """
        self.m_type = options.get('market', 'stock')
        self.init_cash = options.get('cash', 100000)
        self.logger = options.get('logger', None)
        self.use_sequence = options.get('use_sequence', False)
        self.use_normalized = options.get('use_normalized', True)
        self.mix_trader_state = options.get('mix_trader_state', True)
        self.mix_index_state = options.get('mix_index_state', False)
        if self.mix_index_state:
            self.index_codes.append('sh')
        seq_length = options.get('seq_length', 5)
        # Sequence lengths of 1 or less are forced up to 2.
        self.seq_length = seq_length if seq_length > 1 else 2
        self.training_data_ratio = options.get('training_data_ratio', 0.7)
        # Resolve the default lazily so StandardScaler is only looked up when needed.
        scaler = options['scaler'] if 'scaler' in options else StandardScaler
        # Tradable codes come first, index codes last; _init_sequence_data
        # relies on this order when deciding which codes get labels.
        self.state_codes = self.codes + self.index_codes
        # One independent scaler per state code.
        self.scaler = [scaler() for _ in self.state_codes]
        self.trader = Trader(self, cash=self.init_cash)
        self.doc_class = Stock if self.m_type == 'stock' else Future

    def _init_data(self, start_date, end_date):
        """Load the frames, build the model arrays and the train/eval split."""
        self._init_data_frames(start_date, end_date)
        self._init_env_data()
        self._init_data_indices()

    def _validate_codes(self):
        """Raise ValueError if there are no state codes or one is absent from the DB."""
        if not self.state_code_count:
            raise ValueError("Codes cannot be empty.")
        for code in self.state_codes:
            if not self.doc_class.exist_in_db(code):
                raise ValueError("Code: {} not exists in database.".format(code))

    def _init_data_frames(self, start_date, end_date):
        """Load raw and scaled OHLCV frames per state code, aligned on all dates."""
        # Fail fast on empty or unknown codes.
        self._validate_codes()
        columns, dates_set = ['open', 'high', 'low', 'close', 'volume'], set()
        for index, code in enumerate(self.state_codes):
            instrument_docs = self.doc_class.get_k_data(code, start_date, end_date)
            instrument_dicts = [instrument.to_dic() for instrument in instrument_docs]
            # to_dic() results are indexed positionally: [1] is used as the date
            # and [2:] as the five OHLCV values (presumably [0] is a document id
            # -- confirm against the document class).
            dates = [instrument[1] for instrument in instrument_dicts]
            instruments = [instrument[2:] for instrument in instrument_dicts]
            dates_set = dates_set.union(dates)
            # Each code is scaled with its own scaler, fitted on its own data.
            scaler = self.scaler[index]
            scaler.fit(instruments)
            instruments_scaled = scaler.transform(instruments)
            self.origin_frames[code] = pd.DataFrame(
                data=instruments, index=dates, columns=columns)
            self.scaled_frames[code] = pd.DataFrame(
                data=instruments_scaled, index=dates, columns=columns)
        self.dates = sorted(list(dates_set))
        # Align every code on the union of dates; gaps are back-filled.
        for code in self.state_codes:
            self.origin_frames[code] = self.origin_frames[code].reindex(
                self.dates, method='bfill')
            self.scaled_frames[code] = self.scaled_frames[code].reindex(
                self.dates, method='bfill')

    def _init_env_data(self):
        """Build either the flat (series) or the windowed (sequence) arrays."""
        if not self.use_sequence:
            self._init_series_data()
        else:
            self._init_sequence_data()

    def _init_series_data(self):
        """Build one sample per date except the last.

        x[i] is row i of every state code flattened to shape (1, -1);
        y[i] is row i + 1 of every state code.
        """
        self.data_count = len(self.dates[:-1])
        self.bound_index = int(self.data_count * self.training_data_ratio)
        scaled_data_x, scaled_data_y = [], []
        for index in range(self.data_count):
            x = [self.scaled_frames[code].iloc[index] for code in self.state_codes]
            y = [self.scaled_frames[code].iloc[index + 1] for code in self.state_codes]
            scaled_data_x.append(np.array(x).reshape((1, -1)))
            scaled_data_y.append(np.array(y))
        self.data_x = np.array(scaled_data_x)
        self.data_y = np.array(scaled_data_y)

    def _init_sequence_data(self):
        """Build one windowed sample per date index in [seq_length, len(dates) - 2].

        seq_data_x[k] has shape (seq_length, state_code_count * n_features):
        each step concatenates the rows of all state codes for one date.
        seq_data_y[k] holds the scaled 'close' at date_index + 1 for every
        tradable code (the first ``code_count`` state codes); index codes
        contribute inputs only.
        """
        self.data_count = len(self.dates[:-1 - self.seq_length])
        self.bound_index = int(self.data_count * self.training_data_ratio)
        scaled_seqs_x, scaled_seqs_y = [], []
        for date_index in range(self.seq_length, len(self.dates) - 1):
            data_x, data_y = [], []
            for index, code in enumerate(self.state_codes):
                scaled_frame = self.scaled_frames[code]
                # Rows [date_index - seq_length, date_index) form the window.
                instruments_x = scaled_frame.iloc[date_index - self.seq_length: date_index]
                data_x.append(np.array(instruments_x))
                # Previously gated on state_code_count - 1, which dropped the
                # last tradable code whenever no index code was present.
                if index < self.code_count:
                    # NOTE(review): the window ends at row date_index - 1 while
                    # the label comes from row date_index + 1, skipping row
                    # date_index -- confirm this gap is intended.
                    instruments_y = scaled_frame.iloc[date_index + 1]['close']
                    data_y.append(np.array(instruments_y))
            # data_x: (state_code_count, seq_length, n_features)
            data_x = np.array(data_x)
            seq_y = np.array(data_y)
            seq_x = np.array([data_x[:, seq_index, :].reshape((-1))
                              for seq_index in range(self.seq_length)])
            scaled_seqs_x.append(seq_x)
            scaled_seqs_y.append(seq_y)
        self.seq_data_x = np.array(scaled_seqs_x)
        self.seq_data_y = np.array(scaled_seqs_y)

    def _init_data_indices(self):
        """Split sample indices and dates into train (< bound_index) and eval parts."""
        self.data_indices = np.arange(0, self.data_count)
        self.t_data_indices = self.data_indices[:self.bound_index]
        self.e_data_indices = self.data_indices[self.bound_index:]
        self.t_dates = self.dates[:self.bound_index]
        self.e_dates = self.dates[self.bound_index:]

    def _origin_data(self, code, date):
        """Return the unscaled row of ``code`` for ``date``."""
        date_index = self.dates.index(date)
        return self.origin_frames[code].iloc[date_index]

    def _scaled_data_as_state(self, date):
        """Return the model state for ``date``.

        Series mode: the flattened scaled row, followed by the trader's state
        when ``mix_trader_state`` is set (shape (1, -1)).
        Sequence mode: an entry of ``seq_data_x``.
        """
        if self.use_sequence:
            # NOTE(review): seq_data_x[i] was built for date index
            # i + seq_length and has only len(dates) - 1 - seq_length rows, so
            # indexing it by the raw date index is offset and can run out of
            # range for late dates -- confirm before using sequence mode here.
            return self.seq_data_x[self.dates.index(date)]
        data = self.data_x[self.dates.index(date)]
        if self.mix_trader_state:
            trader_state = self.trader.scaled_data_as_state()
            # Append after all market features (np.insert at -1 used to put
            # the trader state before the last market feature).
            data = np.append(data, trader_state).reshape((1, -1))
        return data

    def reset(self, mode='train'):
        """Restart an episode over the train (mode='train') or eval dates.

        Raises:
            ValueError: if fewer than two dates are available.
        """
        self.trader.reset()
        self.iter_dates = iter(self.t_dates) if mode == 'train' else iter(self.e_dates)
        try:
            self.current_date = next(self.iter_dates)
            self.next_date = next(self.iter_dates)
        except StopIteration:
            raise ValueError("Reset error, dates are empty.")
        self._reset_baseline()
        return self._scaled_data_as_state(self.current_date)

    def get_batch_data(self, batch_size=32):
        """Sample ``batch_size`` training samples (with replacement)."""
        batch_indices = np.random.choice(self.t_data_indices, batch_size)
        if not self.use_sequence:
            batch_x = self.data_x[batch_indices]
            batch_y = self.data_y[batch_indices]
        else:
            batch_x = self.seq_data_x[batch_indices]
            batch_y = self.seq_data_y[batch_indices]
        return batch_x, batch_y

    def get_test_data(self):
        """Return all evaluation samples as (x, y)."""
        if not self.use_sequence:
            test_x = self.data_x[self.e_data_indices]
            test_y = self.data_y[self.e_data_indices]
        else:
            test_x = self.seq_data_x[self.e_data_indices]
            test_y = self.seq_data_y[self.e_data_indices]
        return test_x, test_y

    def forward(self, stock_code, action_code):
        """Apply one action and return (state_next, reward, status, action_status).

        The date advances only after every tradable code has acted once
        (``trader.action_times == code_count``); ``status`` becomes ``Done``
        when the date iterator is exhausted.
        """
        self.trader.remove_invalid_positions()
        self.trader.reset_reward()
        stock = self._origin_data(stock_code, self.current_date)
        stock_next = self._origin_data(stock_code, self.next_date)
        # 100 is passed as the action's third argument (presumably the share
        # amount -- confirm against Trader).
        action = self.trader.action_by_code(action_code)
        action(stock_code, stock, 100, stock_next)
        episode_done = self.Running
        self.trader.action_times += 1
        if self.trader.action_times == self.code_count:
            self._update_profits_and_baseline()
            try:
                self.current_date, self.next_date = self.next_date, next(self.iter_dates)
            except StopIteration:
                episode_done = self.Done
            finally:
                self.trader.action_times = 0
        state_next = self._scaled_data_as_state(self.current_date)
        return state_next, self.trader.reward, episode_done, self.trader.cur_action_status

    def _update_profits_and_baseline(self):
        """Append baseline and policy profits at current_date to the trader history."""
        prices = [self._origin_data(code, self.current_date).close for code in self.codes]
        baseline_profits = np.dot(
            self.stocks_holding_baseline, np.transpose(prices)) - self.trader.initial_cash
        policy_profits = self.trader.profits
        self.trader.history_baselines.append(baseline_profits)
        self.trader.history_profits.append(policy_profits)

    def _reset_baseline(self):
        """Split init_cash equally and buy whole shares of each code at current_date."""
        cash_piece = self.init_cash / self.code_count
        stocks = [self._origin_data(code, self.current_date) for code in self.codes]
        self.stocks_holding_baseline = [
            int(math.floor(cash_piece / stock.close)) for stock in stocks]

    @property
    def code_count(self):
        """Number of tradable codes."""
        return len(self.codes)

    @property
    def index_code_count(self):
        """Number of index (non-traded) codes."""
        return len(self.index_codes)

    @property
    def state_code_count(self):
        """Number of codes contributing to the state (tradable + index)."""
        return len(self.state_codes)

    @property
    def data_dim(self):
        """Feature width of one state step (plus trader features in series mode)."""
        data_dim = self.state_code_count * self.scaled_frames[self.codes[0]].shape[1]
        if not self.use_sequence:
            if self.mix_trader_state:
                data_dim += (2 + self.code_count)
        return data_dim
def _init_options(self, **options):
    """Configure the market from keyword options, using defaults for missing keys.

    Keys read (default): 'pre_process_strategy' (active_stragery),
    'market' ('stock'), 'cash' (100000), 'logger' (None),
    'use_sequence' (False), 'use_normalized' (True),
    'mix_trader_state' (True), 'mix_index_state' (False; when true 'sh' is
    appended to ``index_codes``), 'seq_length' (5; values <= 1 become 2),
    'training_data_ratio' (0.7) and 'scaler' (a ``StandardScaler()``
    instance).
    """
    get = options.get
    # Defaults that name other objects are resolved lazily, only when needed.
    if 'pre_process_strategy' in options:
        self.pre_process_strategy = options['pre_process_strategy']
    else:
        self.pre_process_strategy = active_stragery
    self.m_type = get('market', 'stock')
    self.init_cash = get('cash', 100000)
    self.logger = get('logger', None)
    self.use_sequence = get('use_sequence', False)
    self.use_normalized = get('use_normalized', True)
    self.mix_trader_state = get('mix_trader_state', True)
    self.mix_index_state = get('mix_index_state', False)
    if self.mix_index_state:
        self.index_codes.append('sh')
    self.seq_length = get('seq_length', 5)
    if not self.seq_length > 1:
        # Sequence lengths of 1 or less are forced up to 2.
        self.seq_length = 2
    self.training_data_ratio = get('training_data_ratio', 0.7)
    shared_scaler = options['scaler'] if 'scaler' in options else StandardScaler()
    # NOTE(review): index_codes is not part of the state here, so
    # mix_index_state has no effect on state_codes; state_codes also aliases
    # (is the same list object as) self.codes.
    self.state_codes = self.codes
    # NOTE(review): the same scaler object is repeated for every state code.
    self.scaler = [shared_scaler] * len(self.state_codes)
    self.trader = Trader(self, cash=self.init_cash)
    self.doc_class = Stock if self.m_type == 'stock' else Future
class Market(object):
    """Daily-bar trading environment whose frames come from ``pre_process``.

    Dates, scaled, original and post-processed frames are produced by
    ``pre_process.ProcessStrategy``. The market builds supervised arrays
    served by ``get_batch_data`` / ``get_test_data`` and steps a ``Trader``
    through the dates via ``reset`` / ``forward``.
    """

    # Episode status values returned by ``forward``.
    Running = 0
    Done = -1

    class Source(Enum):
        """Data source identifiers."""
        CSV = 'CSV'
        MONGODB = 'MongoDB'

    def __init__(self, codes, start_date="2008-01-01", end_date="2018-01-01",
                 col_order=None, **options):
        """Create the market and load its data.

        Args:
            codes: list of tradable instrument codes.
            start_date: first date passed to the pre-processing strategy.
            end_date: last date passed to the pre-processing strategy.
            col_order: forwarded unchanged to ``pre_process.ProcessStrategy``.
            **options: configuration, see ``_init_options``.
        """
        # Initialize codes.
        self.codes = codes
        self.index_codes = []
        self.state_codes = []
        # Initialize dates.
        self.dates = []
        self.t_dates = []
        self.e_dates = []
        # Initialize data frames.
        self.origin_frames = dict()
        self.scaled_frames = dict()
        # Original frames plus indicators computed on the fly by pre_process.
        self.post_frames = dict()
        # Initialize scaled data x, y.
        self.data_x = None
        self.data_y = None
        # Initialize scaled seq data x, y.
        self.seq_data_x = None
        self.seq_data_y = None
        # Initialize flag date.
        self.next_date = None
        self.iter_dates = None
        self.current_date = None
        self.col_order = col_order
        # Initialize parameters.
        self._init_options(**options)
        # Initialize stock data.
        self._init_data(start_date, end_date)

    def _init_options(self, **options):
        """Read configuration from ``options``, falling back to defaults.

        Keys read (default): 'pre_process_strategy' (active_stragery; must
        provide a 'label' entry for sequence mode), 'market' ('stock'),
        'cash' (100000), 'logger' (None), 'use_sequence' (False),
        'use_normalized' (True), 'mix_trader_state' (True),
        'mix_index_state' (False), 'seq_length' (5; values <= 1 become 2),
        'training_data_ratio' (0.7) and 'scaler' (a ``StandardScaler()``
        instance).
        """
        # Defaults that name other objects are resolved lazily.
        if 'pre_process_strategy' in options:
            self.pre_process_strategy = options['pre_process_strategy']
        else:
            self.pre_process_strategy = active_stragery
        self.m_type = options.get('market', 'stock')
        self.init_cash = options.get('cash', 100000)
        self.logger = options.get('logger', None)
        self.use_sequence = options.get('use_sequence', False)
        self.use_normalized = options.get('use_normalized', True)
        self.mix_trader_state = options.get('mix_trader_state', True)
        self.mix_index_state = options.get('mix_index_state', False)
        if self.mix_index_state:
            self.index_codes.append('sh')
        seq_length = options.get('seq_length', 5)
        # Sequence lengths of 1 or less are forced up to 2.
        self.seq_length = seq_length if seq_length > 1 else 2
        self.training_data_ratio = options.get('training_data_ratio', 0.7)
        scaler = options['scaler'] if 'scaler' in options else StandardScaler()
        # NOTE(review): index_codes is not included in the state, so
        # mix_index_state currently has no effect on the data.
        self.state_codes = self.codes
        # The same scaler object is repeated per code; only scaler[0] is
        # handed to pre_process in _init_data_frames.
        self.scaler = [scaler for _ in self.state_codes]
        self.trader = Trader(self, cash=self.init_cash)
        self.doc_class = Stock if self.m_type == 'stock' else Future

    def _init_data(self, start_date, end_date):
        """Load the frames, build the model arrays and the train/eval split."""
        self._init_data_frames(start_date, end_date)
        self._init_env_data()
        self._init_data_indices()

    def _validate_codes(self):
        """Raise ValueError if there are no state codes or one is absent from the DB."""
        if not self.state_code_count:
            raise ValueError("Codes cannot be empty.")
        for code in self.state_codes:
            if not self.doc_class.exist_in_db(code):
                raise ValueError(
                    "Code: {} not exists in database.".format(code))

    def _init_data_frames(self, start_date, end_date, source=Source.CSV.value):
        """Fill dates and frames from ``pre_process.ProcessStrategy(...).process()``.

        ``source`` is accepted for API compatibility but is not used here.
        """
        self.dates, self.scaled_frames, self.origin_frames, self.post_frames = \
            pre_process.ProcessStrategy(
                self.state_codes, start_date, end_date, self.scaler[0],
                self.pre_process_strategy, self.col_order).process()

    def _init_env_data(self):
        """Build either the flat (series) or the windowed (sequence) arrays."""
        if not self.use_sequence:
            self._init_series_data()
        else:
            self._init_sequence_data()

    def _init_series_data(self):
        """Build one sample per date except the last.

        x[i] is row i of every state code flattened to shape (1, -1);
        y[i] is row i + 1 of every state code.
        """
        self.data_count = len(self.dates[:-1])
        self.bound_index = int(self.data_count * self.training_data_ratio)
        scaled_data_x, scaled_data_y = [], []
        for index in range(self.data_count):
            x = [self.scaled_frames[code].iloc[index] for code in self.state_codes]
            y = [self.scaled_frames[code].iloc[index + 1] for code in self.state_codes]
            scaled_data_x.append(np.array(x).reshape((1, -1)))
            scaled_data_y.append(np.array(y))
        self.data_x = np.array(scaled_data_x)
        self.data_y = np.array(scaled_data_y)

    def _init_sequence_data(self):
        """Build one windowed sample per date index in [seq_length, len(dates) - 2].

        Inputs exclude the label column named by pre_process_strategy['label'].
        The window covers rows [date_index - seq_length, date_index), and the
        label is read from row date_index (the first row after the window)
        for every tradable code.
        """
        self.data_count = len(self.dates[:-1 - self.seq_length])
        self.bound_index = int(self.data_count * self.training_data_ratio)
        label = self.pre_process_strategy['label']
        # Drop the label column once per code instead of once per date.
        feature_frames = {
            code: self.scaled_frames[code].drop([label], axis=1)
            for code in self.state_codes
        }
        scaled_seqs_x, scaled_seqs_y = [], []
        for date_index in range(self.seq_length, len(self.dates) - 1):
            data_x, data_y = [], []
            for index, code in enumerate(self.state_codes):
                instruments_x = feature_frames[code].iloc[
                    date_index - self.seq_length:date_index]
                data_x.append(np.array(instruments_x))
                # Previously gated on `index < date_index - 1`, which compared a
                # code index with a date index and produced ragged label
                # vectors for early dates when there were many codes.
                if index < self.code_count:
                    instruments_y = self.scaled_frames[code].iloc[date_index][label]
                    data_y.append(np.array(instruments_y))
            # data_x: (state_code_count, seq_length, n_features_without_label)
            data_x = np.array(data_x)
            seq_y = np.array(data_y)
            seq_x = np.array([data_x[:, seq_index, :].reshape((-1))
                              for seq_index in range(self.seq_length)])
            scaled_seqs_x.append(seq_x)
            scaled_seqs_y.append(seq_y)
        self.seq_data_x = np.array(scaled_seqs_x)
        self.seq_data_y = np.array(scaled_seqs_y)

    def _init_data_indices(self):
        """Split sample indices and dates into train (< bound_index) and eval parts."""
        self.data_indices = np.arange(0, self.data_count)
        self.t_data_indices = self.data_indices[:self.bound_index]
        self.e_data_indices = self.data_indices[self.bound_index:]
        self.t_dates = self.dates[:self.bound_index]
        self.e_dates = self.dates[self.bound_index:]

    def _origin_data(self, code, date):
        """Return the original (unscaled) row of ``code`` for ``date``."""
        date_index = self.dates.index(date)
        return self.origin_frames[code].iloc[date_index]

    def _scaled_data_as_state(self, date):
        """Return the model state for ``date``.

        Series mode: the flattened scaled row, followed by the trader's state
        when ``mix_trader_state`` is set (shape (1, -1)).
        Sequence mode: an entry of ``seq_data_x``.
        """
        if self.use_sequence:
            # NOTE(review): seq_data_x[i] corresponds to date index
            # i + seq_length; indexing by the raw date index is offset.
            return self.seq_data_x[self.dates.index(date)]
        data = self.data_x[self.dates.index(date)]
        if self.mix_trader_state:
            trader_state = self.trader.scaled_data_as_state()
            # Append after all market features (np.insert at -1 used to put
            # the trader state before the last market feature).
            data = np.append(data, trader_state).reshape((1, -1))
        return data

    def reset(self, mode='train'):
        """Restart an episode over the train (mode='train') or eval dates.

        Raises:
            ValueError: if fewer than two dates are available.
        """
        self.trader.reset()
        self.iter_dates = iter(self.t_dates) if mode == 'train' else iter(
            self.e_dates)
        try:
            self.current_date = next(self.iter_dates)
            self.next_date = next(self.iter_dates)
        except StopIteration:
            raise ValueError("Reset error, dates are empty.")
        self._reset_baseline()
        return self._scaled_data_as_state(self.current_date)

    def get_batch_data(self, batch_size=32):
        """Sample ``batch_size`` training samples (with replacement)."""
        batch_indices = np.random.choice(self.t_data_indices, batch_size)
        if not self.use_sequence:
            batch_x = self.data_x[batch_indices]
            batch_y = self.data_y[batch_indices]
        else:
            batch_x = self.seq_data_x[batch_indices]
            batch_y = self.seq_data_y[batch_indices]
        return batch_x, batch_y

    def get_test_data(self):
        """Return all evaluation samples as (x, y)."""
        if not self.use_sequence:
            test_x = self.data_x[self.e_data_indices]
            test_y = self.data_y[self.e_data_indices]
        else:
            test_x = self.seq_data_x[self.e_data_indices]
            test_y = self.seq_data_y[self.e_data_indices]
        return test_x, test_y

    def forward(self, stock_code, action_code):
        """Apply one action and return (state_next, reward, status, action_status).

        The date advances only after every tradable code has acted once
        (``trader.action_times == code_count``); ``status`` becomes ``Done``
        when the date iterator is exhausted.
        """
        self.trader.remove_invalid_positions()
        self.trader.reset_reward()
        stock = self._origin_data(stock_code, self.current_date)
        stock_next = self._origin_data(stock_code, self.next_date)
        # 100 is passed as the action's third argument (presumably the share
        # amount -- confirm against Trader).
        action = self.trader.action_by_code(action_code)
        action(stock_code, stock, 100, stock_next)
        episode_done = self.Running
        self.trader.action_times += 1
        if self.trader.action_times == self.code_count:
            self._update_profits_and_baseline()
            try:
                self.current_date, self.next_date = self.next_date, next(
                    self.iter_dates)
            except StopIteration:
                episode_done = self.Done
            finally:
                self.trader.action_times = 0
        state_next = self._scaled_data_as_state(self.current_date)
        return state_next, self.trader.reward, episode_done, self.trader.cur_action_status

    def _update_profits_and_baseline(self):
        """Append baseline and policy profits at current_date to the trader history."""
        prices = [
            self._origin_data(code, self.current_date).close
            for code in self.codes
        ]
        baseline_profits = np.dot(
            self.stocks_holding_baseline,
            np.transpose(prices)) - self.trader.initial_cash
        policy_profits = self.trader.profits
        self.trader.history_baselines.append(baseline_profits)
        self.trader.history_profits.append(policy_profits)

    def _reset_baseline(self):
        """Split init_cash equally and buy whole shares of each code at current_date."""
        cash_piece = self.init_cash / self.code_count
        stocks = [
            self._origin_data(code, self.current_date) for code in self.codes
        ]
        self.stocks_holding_baseline = [
            int(math.floor(cash_piece / stock.close)) for stock in stocks
        ]

    @property
    def code_count(self):
        """Number of tradable codes."""
        return len(self.codes)

    @property
    def index_code_count(self):
        """Number of index (non-traded) codes."""
        return len(self.index_codes)

    @property
    def state_code_count(self):
        """Number of codes contributing to the state."""
        return len(self.state_codes)

    @property
    def data_dim(self):
        """Feature width of one state step, excluding the label column.

        NOTE(review): _init_series_data keeps every column (label included)
        in data_x, so in series mode this is one column per code narrower
        than the actual data -- confirm which is intended.
        """
        data_dim = self.state_code_count * (
            self.scaled_frames[self.codes[0]].shape[1] - 1)
        if not self.use_sequence:
            if self.mix_trader_state:
                data_dim += (2 + self.code_count)
        return data_dim
def _init_options(self, **options):
    """Configure the market from keyword options, using defaults for missing keys.

    Keys read (default): 'market' ('stock'), 'cash' (100000),
    'logger' (None), 'use_sequence' (False), 'use_normalized' (True),
    'mix_trader_state' (True), 'mix_index_state' (False), 'seq_length'
    (5), 'training_data_ratio' (0.7) and 'scaler' (StandardScaler).
    """
    get = options.get
    self.m_type = get('market', 'stock')
    self.init_cash = get('cash', 100000)
    # Optional logger.
    self.logger = get('logger', None)
    # Whether windowed (sequence) samples are built instead of flat ones.
    self.use_sequence = get('use_sequence', False)
    # Whether to normalize; normalization is on when the key is missing.
    self.use_normalized = get('use_normalized', True)
    # Whether the trader's state is mixed into the market state.
    self.mix_trader_state = get('mix_trader_state', True)
    self.mix_index_state = get('mix_index_state', False)
    if self.mix_index_state:
        # Add the 'sh' index code after the other index codes.
        self.index_codes.append('sh')
    # Length of each sequence window: 5 by default; values of 1 or less
    # are forced up to 2.
    self.seq_length = get('seq_length', 5)
    if not self.seq_length > 1:
        self.seq_length = 2
    # Fraction of samples used for training; 0.7 when the key is missing.
    self.training_data_ratio = get('training_data_ratio', 0.7)
    # Feature scaler factory; StandardScaler by default, resolved lazily.
    scaler_factory = options['scaler'] if 'scaler' in options else StandardScaler
    # State codes are codes followed by index codes,
    # e.g. ['600036', '601998', 'sh'].
    self.state_codes = self.codes + self.index_codes
    # One scaler object per state code.
    self.scaler = [scaler_factory() for _ in self.state_codes]
    self.trader = Trader(self, cash=self.init_cash)
    self.doc_class = Stock if self.m_type == 'stock' else Future
class Market(object): Running = 0 Done = -1 def __init__(self, codes, start_date="2008-01-01", end_date="2019-07-19", **options): # Initialize codes. # 股票代码例如“[600318,600515]” self.codes = codes # self.index_codes = [] # self.state_codes = self.codes + self.index_codes, 某次state_codes返回为<class 'list'>: ['600036', '601998', 'sh'] self.state_codes = [] # Initialize dates. self.dates = [] # 全局变量dates self.t_dates = [] self.e_dates = [] # Initialize data frames. self.origin_frames = dict() self.scaled_frames = dict() # Initialize scaled data x, y. self.data_x = None #由_init_series_data返回 self.data_y = None #由_init_series_data返回 # Initialize scaled seq data x, y. self.seq_data_x = None #由_init_sequence_data返回 将列表转换成numpy数组 返回形状:(len(dates),5,3*特征dim),三维数组 self.seq_data_y = None #由_init_sequence_data返回 将列表转换成numpy数组 返回形状:(len(dates),3),二维数组 # Initialize flag date. # 下个日期 self.next_date = None self.iter_dates = None self.current_date = None # 初始化可选参数 self._init_options(**options) # 初始化股票数据 self._init_data(start_date, end_date) # 初始化可选参数 def _init_options(self, **options): try: self.m_type = options['market'] except KeyError: self.m_type = 'stock' try: self.init_cash = options['cash'] except KeyError: self.init_cash = 100000 # 初始化记录器 try: self.logger = options['logger'] except KeyError: self.logger = None # 初始化? 
try: self.use_sequence = options['use_sequence'] except KeyError: self.use_sequence = False # 初始化是否进行归一化处理,如果输错,则默认归一化处理 try: self.use_normalized = options['use_normalized'] except KeyError: self.use_normalized = True #初始化是否混合交易信息 try: self.mix_trader_state = options['mix_trader_state'] except KeyError: self.mix_trader_state = True try: self.mix_index_state = options['mix_index_state'] except KeyError: self.mix_index_state = False # 初始化如果mix_index_state存在,则再index_codes后面增加一个‘sh’ finally: if self.mix_index_state: self.index_codes.append('sh') # # 初始化每个片段序列的长度,如果输出默认是5(天),如果输入的长度大于1, 则接受, 否则片段长度强制设置为2 try: self.seq_length = options['seq_length'] except KeyError: self.seq_length = 5 finally: self.seq_length = self.seq_length if self.seq_length > 1 else 2 #初始化训练数据集的比例,如果输入错误,则默认0.7,默认传入main函数中的0.98 try: self.training_data_ratio = options['training_data_ratio'] except KeyError: self.training_data_ratio = 0.7 # 初始化特征缩放器,默认标准缩放器 try: scaler = options['scaler'] except KeyError: scaler = StandardScaler # 展示代码为codes + index_codes,某次state_codes返回为<class 'list'>: ['600036', '601998', 'sh'] self.state_codes = self.codes + self.index_codes self.scaler = [ scaler() for _ in self.state_codes ] # 返回一个scaler()对象列表, state_codes里面有几个股票代码就初始化几个scaler对象 self.trader = Trader(self, cash=self.init_cash) self.doc_class = Stock if self.m_type == 'stock' else Future def _init_data(self, start_date, end_date): self._init_data_frames(start_date, end_date) # 初始化原始数据和被scaled的原始数据 self._init_env_data() # 初始化seq_data和series_data self._init_data_indices() # 初始化训练集和测试集的索引集合和日期集合 def _validate_codes(self): if not self.state_code_count: raise ValueError("Codes cannot be empty.") for code in self.state_codes: if not self.doc_class.exist_in_db(code): raise ValueError( "Code: {} not exists in database.".format(code)) def _init_data_frames(self, start_date, end_date): # Remove invalid codes first. self._validate_codes() # Init columns and data set. 
columns, dates_set = ['open', 'high', 'low', 'close', 'volume'], set() # Load data. for index, code in enumerate(self.state_codes): # Load instrument docs by code. doc_class 此时为Stock,调用document里面Stock类 instrument_docs = self.doc_class.get_k_data( code, start_date, end_date) # 返回某个股票代码从start_date到end_date的数据,按date升序排列 # Init instrument dicts. instrument_dicts = [ instrument.to_dic() for instrument in instrument_docs ] # 调用document文档里的to_dict() 方法,返回这只股票所有行除了_id列以外的数据的to_dict() 值 # Split dates.分离出date列 dates = [instrument[1] for instrument in instrument_dicts ] #第0索引是tushare下载下来数据的第一列,日期在第二列 # Split instruments.# 分离出除了_id列和date列以外的所有数据 instruments = [instrument[2:] for instrument in instrument_dicts] # Update dates set. dates_set = dates_set.union( dates ) #去重返回两个集合的并集 ,经历循环结束后, 会把其他股票代码的日期也放到这个地方, 所以用union返回去重并集合 # Build origin and scaled frames. scaler = self.scaler[ index] # 把scaler()对象中第一个scaler()对象赋值到scaler, 用来为第一条数据进行缩放操作 scaler.fit(instruments) # 对分离的除了_id列和date列以外的所有数据进行缩放 instruments_scaled = scaler.transform(instruments) origin_frame = pd.DataFrame(data=instruments, index=dates, columns=columns) scaled_frame = pd.DataFrame(data=instruments_scaled, index=dates, columns=columns) # Build code - frame map.建立股票代码-dataframe的对应关系 self.origin_frames[ code] = origin_frame # origin_frames是以dates为索引, 包含5列分别为['open', 'high', 'low', 'close', 'volume']的Dataframe self.scaled_frames[ code] = scaled_frame # 用分别用Scaler()缩放器缩放过后的origin_frame, 也包含['open', 'high', 'low', 'close', 'volume']等列 # Init date iter. self.dates = sorted( list(dates_set)) #sorted, 对可迭代对象进行升序排序, 最终传递到market的self.dates变量中 # Rebuild index. 
for code in self.state_codes:
    # Tail of the preceding method (its header lies before this chunk):
    # re-index every state code's frames onto self.dates so that all codes
    # share one row per date. Per the original comment, self.dates is the
    # de-duplicated union of the codes' dates (built outside this chunk).
    origin_frame = self.origin_frames[code]
    scaled_frame = self.scaled_frames[code]
    # method='bfill' fills a missing date with the next valid observation.
    self.origin_frames[code] = origin_frame.reindex(self.dates, method='bfill')
    self.scaled_frames[code] = scaled_frame.reindex(self.dates, method='bfill')


def _init_env_data(self):
    """Build the model-ready arrays for the configured input mode.

    Dispatches to ``_init_sequence_data`` when ``self.use_sequence`` is
    true, otherwise to ``_init_series_data``.
    """
    if self.use_sequence:
        self._init_sequence_data()
    else:
        self._init_series_data()


def _init_series_data(self):
    """Build one-day samples: x is day ``i``, y is day ``i + 1``.

    Sets:
        data_count: number of samples, ``len(self.dates) - 1`` (the last
            date has no following day to act as its label).
        bound_index: first sample index of the evaluation split.
        data_x: ``(data_count, 1, state_code_count * n_features)``; the
            scaled rows of every state code for day ``i``, concatenated.
        data_y: ``(data_count, state_code_count, n_features)``; the full
            scaled rows of every state code for day ``i + 1``.

    The shapes assume every scaled frame has the same columns.
    Sample ``i`` belongs to ``self.dates[i]``.
    """
    self.data_count = len(self.dates[:-1])
    self.bound_index = int(self.data_count * self.training_data_ratio)
    frames = [self.scaled_frames[code] for code in self.state_codes]
    scaled_data_x, scaled_data_y = [], []
    for index in range(self.data_count):
        # Concatenate today's rows of all state codes into one row vector.
        x = np.array([frame.iloc[index] for frame in frames]).reshape((1, -1))
        y = np.array([frame.iloc[index + 1] for frame in frames])
        scaled_data_x.append(x)
        scaled_data_y.append(y)
    self.data_x = np.array(scaled_data_x)
    self.data_y = np.array(scaled_data_y)


def _init_sequence_data(self):
    """Build sliding-window samples for sequence models.

    For every date index ``t`` with ``seq_length <= t < len(self.dates) - 1``
    one sample is produced:

    * x: scaled rows ``t - seq_length .. t - 1`` of every state code,
      arranged as ``(seq_length, state_code_count * n_features)`` so each
      time step holds all codes' features side by side;
    * y: the scaled ``'close'`` of day ``t + 1`` for each traded code in
      ``self.codes``, i.e. ``code_count`` labels per sample.

    Sample ``r`` therefore belongs to date index ``r + seq_length``.
    Sets data_count, bound_index, seq_data_x and seq_data_y.
    """
    # The first seq_length dates have no complete window before them and the
    # last date has no next-day label.
    self.data_count = len(self.dates[:-1 - self.seq_length])
    self.bound_index = int(self.data_count * self.training_data_ratio)
    frames = [self.scaled_frames[code] for code in self.state_codes]
    scaled_seqs_x, scaled_seqs_y = [], []
    for date_index in range(self.seq_length, len(self.dates) - 1):
        start = date_index - self.seq_length
        # (n_state_codes, seq_length, n_features)
        window = np.array(
            [np.asarray(frame.iloc[start:date_index]) for frame in frames])
        # -> (seq_length, n_state_codes * n_features): one row per day with
        # the codes' features concatenated in state_codes order.
        seq_x = window.transpose(1, 0, 2).reshape((self.seq_length, -1))
        # Labels come from the traded codes only, so every y has code_count
        # entries (index codes such as "sh" are inputs, not targets).
        seq_y = np.array([
            self.scaled_frames[code].iloc[date_index + 1]['close']
            for code in self.codes
        ])
        scaled_seqs_x.append(seq_x)
        scaled_seqs_y.append(seq_y)
    self.seq_data_x = np.array(scaled_seqs_x)
    self.seq_data_y = np.array(scaled_seqs_y)


def _init_data_indices(self):
    """Split sample indices and their dates into training and evaluation parts.

    Samples ``[0, bound_index)`` are for training and
    ``[bound_index, data_count)`` for evaluation. Sample ``r`` belongs to
    ``self.dates[r + offset]``, where ``offset`` is ``seq_length`` in
    sequence mode (the first ``seq_length`` dates have no window) and 0 in
    series mode; t_dates/e_dates are shifted by the same offset so each
    episode date lines up with its state.
    """
    self.data_indices = np.arange(0, self.data_count)
    self.t_data_indices = self.data_indices[:self.bound_index]
    self.e_data_indices = self.data_indices[self.bound_index:]
    offset = self.seq_length if self.use_sequence else 0
    split = offset + self.bound_index
    self.t_dates = self.dates[offset:split]
    self.e_dates = self.dates[split:]


def _origin_data(self, code, date):
    """Return the unscaled row of ``code`` on ``date``.

    ``date`` must be an element of ``self.dates`` (list ``.index`` raises
    ValueError otherwise).
    """
    return self.origin_frames[code].iloc[self.dates.index(date)]


def _scaled_data_as_state(self, date):
    """Return the scaled observation for ``date``.

    Series mode: ``data_x[i]`` for ``date == self.dates[i]``; when
    ``self.mix_trader_state`` is truthy the trader's scaled state is inserted
    into it and the result is reshaped to ``(1, -1)``.

    Sequence mode: the window ending the day before ``date``, i.e.
    ``seq_data_x[i - seq_length]``. Indexing ``seq_data_x[i]`` directly would
    expose the ``seq_length - 1`` days after ``date`` and would run past the
    end of ``seq_data_x`` during the last dates of an evaluation episode.

    Raises:
        ValueError: in sequence mode, if ``date`` is one of the first
            ``seq_length`` dates, which have no complete window.
    """
    date_index = self.dates.index(date)
    if self.use_sequence:
        row = date_index - self.seq_length
        if row < 0:
            raise ValueError(
                "No sequence window for date %r: it is among the first %d "
                "dates." % (date, self.seq_length))
        return self.seq_data_x[row]
    data = self.data_x[date_index]
    # NOTE(review): the _init_options definitions at the top of the file set
    # use_state_mix_cash, not mix_trader_state -- confirm where this is set.
    if self.mix_trader_state:
        trader_state = self.trader.scaled_data_as_state()
        # np.insert without `axis` flattens `data` and inserts before its
        # LAST element (position -1), not after it.
        data = np.insert(data, -1, trader_state).reshape((1, -1))
    return data


def reset(self, mode='train'):
    """Start a new episode over the training or evaluation dates.

    Args:
        mode: 'train' iterates ``t_dates``; any other value iterates
            ``e_dates``.

    Returns:
        The scaled state for the first date of the episode.

    Raises:
        ValueError: if the selected date range has fewer than two dates.
    """
    self.trader.reset()
    self.iter_dates = iter(self.t_dates if mode == 'train' else self.e_dates)
    try:
        self.current_date = next(self.iter_dates)
        self.next_date = next(self.iter_dates)
    except StopIteration:
        raise ValueError("Reset error, dates are empty.")
    self._reset_baseline()
    return self._scaled_data_as_state(self.current_date)


def get_batch_data(self, batch_size=32):
    """Sample ``batch_size`` training samples (with replacement) as ``(x, y)``."""
    # np.random.choice samples with replacement by default.
    batch_indices = np.random.choice(self.t_data_indices, batch_size)
    if self.use_sequence:
        return self.seq_data_x[batch_indices], self.seq_data_y[batch_indices]
    return self.data_x[batch_indices], self.data_y[batch_indices]


def get_test_data(self):
    """Return the whole evaluation split as ``(x, y)``."""
    if self.use_sequence:
        return (self.seq_data_x[self.e_data_indices],
                self.seq_data_y[self.e_data_indices])
    return self.data_x[self.e_data_indices], self.data_y[self.e_data_indices]


def forward(self, stock_code, action_code):
    """Apply one action for one code; advance a day once every code has acted.

    Args:
        stock_code: code the action is applied to.
        action_code: value passed to ``self.trader.action_by_code``.

    Returns:
        ``(state_next, reward, episode_done, action_status)`` where
        ``episode_done`` is ``self.Done`` once the date iterator is
        exhausted and ``self.Running`` otherwise.
    """
    self.trader.remove_invalid_positions()
    self.trader.reset_reward()
    stock = self._origin_data(stock_code, self.current_date)
    stock_next = self._origin_data(stock_code, self.next_date)
    action = self.trader.action_by_code(action_code)
    # NOTE(review): 100 is presumably the trade amount (shares) -- confirm
    # against Trader's action signature.
    action(stock_code, stock, 100, stock_next)
    episode_done = self.Running
    self.trader.action_times += 1
    # The date only advances after one action per traded code.
    if self.trader.action_times == self.code_count:
        self._update_profits_and_baseline()
        try:
            self.current_date, self.next_date = (
                self.next_date, next(self.iter_dates))
        except StopIteration:
            # Dates are left unchanged when the iterator is exhausted.
            episode_done = self.Done
        finally:
            self.trader.action_times = 0
    state_next = self._scaled_data_as_state(self.current_date)
    return state_next, self.trader.reward, episode_done, self.trader.cur_action_status


def _update_profits_and_baseline(self):
    """Append the baseline profit and the policy profit for ``current_date``.

    The baseline profit is the value of ``stocks_holding_baseline`` (whole
    share counts fixed by ``_reset_baseline``) at the current closes, minus
    the trader's initial cash.
    """
    prices = [self._origin_data(code, self.current_date).close
              for code in self.codes]
    baseline_profits = (np.dot(self.stocks_holding_baseline, prices)
                        - self.trader.initial_cash)
    self.trader.history_baselines.append(baseline_profits)
    self.trader.history_profits.append(self.trader.profits)


def _reset_baseline(self):
    """Split ``init_cash`` equally across codes and buy whole shares at today's close."""
    cash_piece = self.init_cash / self.code_count
    stocks = [self._origin_data(code, self.current_date) for code in self.codes]
    self.stocks_holding_baseline = [
        int(math.floor(cash_piece / stock.close)) for stock in stocks
    ]


@property
def code_count(self):
    """Number of traded codes; also the width of each sequence-mode label."""
    return len(self.codes)


@property
def index_code_count(self):
    """Number of index codes."""
    return len(self.index_codes)


@property
def state_code_count(self):
    """Number of codes whose features make up the observation."""
    return len(self.state_codes)


@property
def data_dim(self):
    """Feature width of one time step of the observation.

    ``state_code_count * n_features``, where ``n_features`` is the column
    count of the first traded code's scaled frame (assumed equal for all
    codes). In series mode with ``mix_trader_state`` set, ``2 + code_count``
    extra entries are added (presumably the length of
    ``trader.scaled_data_as_state()`` -- confirm).
    """
    data_dim = self.state_code_count * self.scaled_frames[self.codes[0]].shape[1]
    if not self.use_sequence and self.mix_trader_state:
        data_dim += 2 + self.code_count
    return data_dim