import multiprocessing
import time

import numpy as np
import tensorflow as tf

from env import StockEnv     # project-local module (import path assumed)
from logger import Logger    # project-local module (import path assumed)

N_WORKERS = multiprocessing.cpu_count()
MAX_GLOBAL_EP = 50
GLOBAL_NET_SCOPE = 'Global_Net'
UPDATE_GLOBAL_ITER = 50
# GAMMA = 0.9
GAMMA = 0.8
# ENTROPY_BETA = 0.001
ENTROPY_BETA = 0.1
LR_A = 0.001    # learning rate for actor
# LR_C = 0.001  # learning rate for critic
LR_C = 0.01     # learning rate for critic

GLOBAL_RUNNING_R = []
GLOBAL_EP = 0

env = StockEnv()
N_S = env.get_state().shape[0]
N_A = 4
logger = Logger('A3C')


class ACNet(object):
    def __init__(self, scope, N_S, N_A, globalAC=None):
        if scope == GLOBAL_NET_SCOPE:   # get global network
            with tf.variable_scope(scope):
                self.s = tf.placeholder(tf.float32, [None, N_S], 'S')
                self.a_params, self.c_params = self._build_net(scope)[-2:]
        else:   # local net, calculate losses
            with tf.variable_scope(scope):
                self.s = tf.placeholder(tf.float32, [None, N_S], 'S')
                self.a_his = tf.placeholder(tf.int32, [None, ], 'A')  # shape assumed; the original listing cuts off mid-line here
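The original listing breaks off right after `self.a_his`. What usually follows at this point in this style of A3C code is the loss construction for the local net. Below is a minimal sketch, assuming the standard discrete-action A3C recipe; the `v_target`, `a_prob`, and `v` tensor names and the `_build_net` return order are assumptions, not taken from the source:

                self.v_target = tf.placeholder(tf.float32, [None, 1], 'Vtarget')
                self.a_prob, self.v, self.a_params, self.c_params = self._build_net(scope)

                # critic loss: squared TD error against the bootstrapped target
                td = tf.subtract(self.v_target, self.v, name='TD_error')
                with tf.name_scope('c_loss'):
                    self.c_loss = tf.reduce_mean(tf.square(td))

                # actor loss: policy gradient weighted by the TD error,
                # plus an entropy bonus (ENTROPY_BETA) to keep exploring
                with tf.name_scope('a_loss'):
                    log_prob = tf.reduce_sum(
                        tf.log(self.a_prob + 1e-5) *
                        tf.one_hot(self.a_his, N_A, dtype=tf.float32),
                        axis=1, keep_dims=True)
                    exp_v = log_prob * tf.stop_gradient(td)
                    entropy = -tf.reduce_sum(
                        self.a_prob * tf.log(self.a_prob + 1e-5),
                        axis=1, keep_dims=True)
                    self.a_loss = tf.reduce_mean(-(ENTROPY_BETA * entropy + exp_v))

Note the large `ENTROPY_BETA = 0.1` chosen above (versus the commented-out 0.001): under this loss form it weights exploration much more heavily than the usual A3C defaults.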
class Worker(object):
    def __init__(self, name, globalAC):
        self.env = StockEnv()
        self.name = name
        self.AC = ACNet(name, self.env.get_state().shape[0], 4, globalAC)

    def _update_global_reward(self, ep_r):
        global GLOBAL_RUNNING_R, GLOBAL_EP
        if len(GLOBAL_RUNNING_R) == 0:  # record running episode reward
            GLOBAL_RUNNING_R.append(ep_r)
        else:
            # exponential moving average of the episode reward
            GLOBAL_RUNNING_R.append(0.99 * GLOBAL_RUNNING_R[-1] + 0.01 * ep_r)
        logger.debug(
            [self.name, "Ep:", GLOBAL_EP, "| Ep_r: %i" % GLOBAL_RUNNING_R[-1]]
        )
        GLOBAL_EP += 1

    def _update_global_acnet(self, done, s_, buffer_s, buffer_a, buffer_r):
        # bootstrap the value of the last state, then roll discounted
        # returns backwards through the buffer to form the critic targets
        if done:
            v_s_ = 0  # terminal
        else:
            v_s_ = SESS.run(self.AC.v, {self.AC.s: s_[np.newaxis, :]})[0, 0]
        buffer_v_target = []
        for r in buffer_r[::-1]:  # reverse buffer r
            v_s_ = r + GAMMA * v_s_
            buffer_v_target.append(v_s_)
        buffer_v_target.reverse()
        buffer_s, buffer_a, buffer_v_target = (
            np.vstack(buffer_s), np.array(buffer_a), np.vstack(buffer_v_target))
        feed_dict = {
            self.AC.s: buffer_s,
            self.AC.a_his: buffer_a,
            self.AC.v_target: buffer_v_target,
        }
        self.AC.update_global(feed_dict)

    def work(self):
        total_step = 1
        buffer_s, buffer_a, buffer_r = [], [], []
        self.env.reset()
        if self.name == 'W_0':
            self.env.render()
        while not COORD.should_stop():
            ep_r = 0
            while True:
                s = self.env.get_state()
                a, p = self.AC.choose_action(s)
                s_, r, done = self.env.step(a)
                if done:
                    r = -0.5
                ep_r += r
                buffer_s.append(s)
                buffer_a.append(a)
                buffer_r.append(r)
                if total_step % UPDATE_GLOBAL_ITER == 0 or done:
                    self._update_global_acnet(done, s_, buffer_s, buffer_a, buffer_r)
                    buffer_s, buffer_a, buffer_r = [], [], []
                    self.AC.pull_global()
                # s = s_
                total_step += 1
                if done:
                    self._update_global_reward(ep_r)
                    break
                if self.name == 'W_0':
                    logger.debug(["s", s, " a:", a, " p:", p, " r:", r,
                                  " total_step:", total_step, 'total', self.env.total])
                    time.sleep(0.5)

    def train(self):
        global GLOBAL_RUNNING_R, GLOBAL_EP
        total_step = 1
        buffer_s, buffer_a, buffer_r = [], [], []
        while not COORD.should_stop() and GLOBAL_EP < MAX_GLOBAL_EP:
            s = self.env.reset()
            ep_r = 0
            while True:
                # if self.name == 'W_0':
                #     self.env.render()
                a, p = self.AC.choose_action(s)
                s_, r, done = self.env.step(a)
                if done:
                    r = -0.5
                ep_r += r
                buffer_s.append(s)
                buffer_a.append(a)
                buffer_r.append(r)
                if total_step % UPDATE_GLOBAL_ITER == 0 or done:
                    # update global and assign to local net
                    self._update_global_acnet(done, s_, buffer_s, buffer_a, buffer_r)
                    buffer_s, buffer_a, buffer_r = [], [], []
                    self.AC.pull_global()
                if done:
                    self._update_global_reward(ep_r)
                    logger.debug(["s", s, " a:", a, " p:", p, " r:", r,
                                  " total_step:", total_step, 'total', self.env.total])
                    break
                s = s_
                total_step += 1
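Neither method runs on its own: `SESS` and `COORD` are module-level objects created in the main block, which the excerpt omits. A minimal sketch of the usual wiring, assuming the conventional single-process, multi-threaded TF1 A3C setup (names not defined above are illustrative):

import threading

if __name__ == "__main__":
    SESS = tf.Session()

    with tf.device("/cpu:0"):
        GLOBAL_AC = ACNet(GLOBAL_NET_SCOPE, N_S, N_A)  # global net holds parameters only
        workers = [Worker('W_%i' % i, GLOBAL_AC) for i in range(N_WORKERS)]

    COORD = tf.train.Coordinator()
    SESS.run(tf.global_variables_initializer())

    worker_threads = []
    for worker in workers:
        t = threading.Thread(target=worker.train)  # or worker.work for the endless variant
        t.start()
        worker_threads.append(t)
    COORD.join(worker_threads)

Threads (not processes) are what make the shared global network work here: every local net pushes gradients to and pulls weights from the same in-memory graph.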
class Worker(object):
    GAMMA = 0.9
    GLOBAL_RUNNING_R = []
    GLOBAL_EP = 0

    def __init__(self, sess, name, N_S, N_A, globalAC):
        self.SESS = sess
        self.N_S = N_S
        self.N_A = N_A
        self.env = StockEnv()
        self.name = name
        self.AC = A3CNet(self.SESS, self.name, self.N_S, self.N_A, globalAC)
        # self.saver = tf.train.Saver()

    def _record_global_reward_and_print(self, global_running_rs, ep_r, global_ep, total_step):
        global_running_rs.append(ep_r)
        try:
            print(self.name, "Ep:", global_ep,
                  "| Ep_r: %i" % global_running_rs[-1],
                  "| total step:", total_step)
        except Exception as e:
            print(e)

    def train(self):
        buffer_s, buffer_a, buffer_r = [], [], []
        s = self.env.reset()
        ep_r = 0
        total_step = 1

        def reset():
            nonlocal ep_r, total_step
            self.env.reset()
            ep_r = 0
            total_step = 1

        while not COORD.should_stop() and self.GLOBAL_EP < MAX_GLOBAL_EP:
            # s = self.env.reset()
            # ep_r = 0
            # total_step = 1
            reset()
            while total_step < MAX_TOTAL_STEP:
                try:
                    s = self.env.get_state()
                    a, p = self.AC.choose_action(s)
                    s_, r, done = self.env.step(a)
                    if done:
                        r = -2
                    ep_r += r
                    buffer_s.append(s)
                    buffer_a.append(a)
                    buffer_r.append(r)
                    if total_step % UPDATE_GLOBAL_ITER == 0 or done:
                        # update global and assign to local net
                        self.AC.update(done, s_, buffer_r, buffer_s, buffer_a)
                        buffer_s, buffer_a, buffer_r = [], [], []
                    if done:
                        self._record_global_reward_and_print(
                            self.GLOBAL_RUNNING_R, ep_r, self.GLOBAL_EP, total_step)
                        self.GLOBAL_EP += 1
                        reset()
                    # s = s_
                    total_step += 1
                    if self.name == 'W_0':
                        self.env.render()
                        time.sleep(0.05)
                        logger.debug([
                            "s ", s, " v ", self.AC.get_v(s), " a ", a, " p ", p,
                            " ep_r ", ep_r, " total ", self.env.total,
                            " acct ", self.env.acct
                        ])
                except Exception as e:
                    print(e)
            try:
                print(self.name, " episode did not finish; the environment may be stuck.",
                      " total_step:", total_step)
            except Exception as e:
                print(e)
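In this second version the n-step update has been refactored out of the Worker and into the network class: `self.AC.update(...)` replaces `_update_global_acnet` from the first version. The method body does not appear in the source; a sketch of what it presumably does, mirroring the earlier logic (everything here except the call signature is an assumption):

class A3CNet(object):
    # ... placeholders, losses and optimizers as in ACNet above ...

    def update(self, done, s_, buffer_r, buffer_s, buffer_a):
        # bootstrap the value of the last state, then roll discounted
        # returns backwards through the buffer (as in _update_global_acnet)
        if done:
            v_s_ = 0
        else:
            v_s_ = self.sess.run(self.v, {self.s: s_[np.newaxis, :]})[0, 0]
        buffer_v_target = []
        for r in buffer_r[::-1]:
            v_s_ = r + GAMMA * v_s_
            buffer_v_target.append(v_s_)
        buffer_v_target.reverse()
        feed_dict = {
            self.s: np.vstack(buffer_s),
            self.a_his: np.array(buffer_a),
            self.v_target: np.vstack(buffer_v_target),
        }
        self.update_global(feed_dict)  # push local gradients to the global net
        self.pull_global()             # then copy the global weights back

Folding the pull into `update` would explain why this Worker, unlike the first version, never calls `pull_global` itself.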