class DQNAgent(Agent):
    """DQN-based trading agent: builds a gym trading env, wraps a keras-rl
    DQN around a Keras model, and maps Q-argmax actions to trade signals.

    Relies on config attributes supplied by the ``Agent`` base class
    (``env_name``, ``csv_list``, ``trading_cost``, ``time_cost``, ``market``,
    ``seed``, ``split_ratio``, ``memory_len``, ``eps``,
    ``target_model_update``, ``lr``) — TODO confirm against ``Agent``.
    """

    def __init__(self, config, save_name=None):
        super().__init__(config, save_name)

    def load_model(self, model_path):
        """Load a previously saved Keras model from ``model_path``.

        Note: ``load_model`` in the body resolves to the module-level Keras
        loader, not this method (method names do not shadow globals inside
        the body), so there is no recursion here.
        """
        self.model = load_model(model_path)

    def init_env(self):
        """Create, wrap, and seed the trading environment.

        Also derives ``steps_per_episode`` (training-split length) and
        ``nb_actions`` from the environment.
        """
        self.env = gym.make(
            self.env_name,
            csv_list=self.csv_list,
            trading_cost=self.trading_cost,
            time_cost=self.time_cost,
            market=self.market,
        )
        self.env = MyWrapper(self.env)
        # Seed every RNG source in play for reproducibility.
        np.random.seed(self.seed)
        self.env.seed(self.seed)
        self.env.action_space.seed(self.seed)
        random.seed(self.seed)
        tf.random.set_random_seed(self.seed)  # TF 1.x API
        # Steps per episode = rows in the processed data times the train split.
        self.steps_per_episode = int(
            self.env.processed_array.shape[0] * self.split_ratio
        )
        self.nb_actions = self.env.action_space.n

    def create_agent(self):
        """Build the Keras model and the keras-rl DQN agent around it.

        Uses a single-threaded TF 1.x session so runs are reproducible with
        the seeds set in ``init_env``.
        """
        session_conf = tf.ConfigProto(
            intra_op_parallelism_threads=1,
            inter_op_parallelism_threads=1,
        )
        sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
        K.set_session(sess)
        self.model = Model.simple(self.env)
        memory = SequentialMemory(limit=self.memory_len, window_length=1)
        policy = EpsGreedyQPolicy(eps=self.eps)
        # BUG FIX: this wrapper class shadows keras-rl's DQNAgent at module
        # level, so the original bare `DQNAgent(model=...)` call recursively
        # instantiated THIS class with arguments its __init__ does not accept.
        # Import the library class under an alias and use that instead.
        from rl.agents.dqn import DQNAgent as KerasRLDQNAgent
        self.agent = KerasRLDQNAgent(
            model=self.model,
            nb_actions=self.nb_actions,
            memory=memory,
            nb_steps_warmup=self.steps_per_episode,
            target_model_update=self.target_model_update,
            policy=policy,
        )
        # Old Keras optimizer API (`lr=`), consistent with the TF 1.x session
        # usage above.
        self.agent.compile(Adam(lr=self.lr))

    def predict(self, df):
        """Transform ``df`` with the env's feature extractor and emit the
        trade signal for the latest row."""
        # .iloc[:] takes a defensive copy before the feature transform.
        data = self.env.fe.transform(df.iloc[:]).values
        # NOTE(review): stock keras-rl DQNAgent exposes `forward`, not
        # `select_action(..., do_train=...)` — presumably a project fork of
        # keras-rl; confirm against the installed `rl` package.
        val = np.argmax(self.agent.select_action(data[-1:], do_train=False))
        self.signal_transform(val)

    def signal_transform(self, val):
        """Print the human-readable trading signal for action index ``val``:
        0 -> Buy, 1 -> Sell, anything else -> Close/Hold."""
        if val == 0:
            print('Buy')
        elif val == 1:
            print('Sell')
        else:
            print('Close/Hold')