class A2C(Model):
    def __init__(self, scope_name, env, states_shape, n_actions,
                 channel_min, channel_rate=2):
        super(A2C, self).__init__(scope_name)
        self.scope_name = scope_name
        self.channel_min = channel_min
        self.channel_rate = channel_rate
        self.states_shape = states_shape
        self.n_actions = n_actions
        self.env = env

    def build(self, lr=1e-5):  # fixed typo: was `buind`
        # Build the underlying policy network that does the actual learning.
        self.policy = Policy('Policy', env=self.env,
                             state_shape=self.states_shape,
                             n_actions=self.n_actions)
        self.policy.build(lr=lr, show_img=True)

    def predict(self, z):
        # NOTE: self.input_z / self.fake_img are not defined in this class;
        # they are assumed to come from the Model base class.
        session = tf.get_default_session()
        feed_dict = {self.input_z: z}
        output = session.run(self.fake_img, feed_dict=feed_dict)
        return output

    def train(self, datas):
        # Delegate one training step to the policy network.
        self.policy.train(datas.states, datas.rewards, datas.conditions,
                          datas.actions, datas.last_actions, datas.values)
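# A minimal usage sketch (assumption-heavy): `Toy` and `get_data` are taken
# from the worker code below and are not defined here; the argument values
# are illustrative, not canonical.
env = Toy()
with tf.Session().as_default():
    model = A2C('A2C', env=env,
                states_shape=env.observation_shape,
                n_actions=16, channel_min=32)
    model.build(lr=1e-5)
    tf.get_default_session().run(tf.global_variables_initializer())
    for _ in range(1000):
        datas = get_data(model.policy, 3, fixed=True, chooce_max=False)
        model.train(datas)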
class Worker(object):
    def __init__(self, name, globalP):
        self.env = Toy()
        self.name = name
        self.policy = Policy(name + '/Policy', env=self.env,
                             state_shape=self.env.observation_shape,
                             n_actions=16)
        self.policy.build()
        # Ops to sync this worker's local network with the global network.
        self.pull_global_op = get_pull_global(globalP, self.policy)
        self.update_global_op = get_update_global(globalP, self.policy)

    def work(self):
        train_times = 0
        while not COORD.should_stop() and train_times < MAX_EPISODE:
            # Collect a batch with the local policy, push an update to the
            # global network, then pull the global parameters back down.
            datas = get_data(self.policy, 3, fixed=True, chooce_max=False)
            loss = self.train(datas.states, datas.rewards, datas.conditions,
                              datas.actions, datas.last_actions, datas.values)
            self.pull_global()
            # if train_times % MAX_EP_STEPS == 0:
            #     datas = get_data(self.globalP, self.env, 3)
            #     print(datas.reward_total)
            #     print(self.name + ' time:', train_times,
            #           ' reward_total:', datas.reward_total,
            #           ' mean_reward:', datas.reward_total / BATCH_SIZE)
            train_times += 1

    def train(self, state, reward, condition, action, last_action, value):
        policy = self.policy
        lstm_state_c, lstm_state_h = policy.get_initial_features(len(state))
        value_ = value[:, 1:]  # bootstrap values, shifted by one step
        session = SESS
        feed_dict = {
            policy.last_action: last_action,
            policy.state: state,
            policy.action: action,
            policy.value_: value_,
            policy.reward: reward,
            policy.condition: condition,
            policy.init_state[0]: lstm_state_c,
            policy.init_state[1]: lstm_state_h
        }
        loss, _ = session.run([self.policy.loss, self.update_global_op],
                              feed_dict=feed_dict)
        return loss

    def pull_global(self):
        SESS.run(self.pull_global_op)
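# The two sync ops used in Worker.__init__ are not shown in this section.
# A common A3C-style sketch, assuming each Policy exposes a `scope_name`
# attribute and a `loss` tensor, and builds its variables under that scope:
def get_pull_global(global_policy, local_policy):
    # local <- global: copy the global parameters into the local network.
    g_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 scope=global_policy.scope_name)
    l_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 scope=local_policy.scope_name)
    return [l.assign(g) for g, l in zip(g_params, l_params)]

def get_update_global(global_policy, local_policy, optimizer=None):
    # global <- grad(local): apply gradients of the local loss to the
    # global parameters (the choice of optimizer here is an assumption).
    optimizer = optimizer or tf.train.RMSPropOptimizer(1e-5)
    g_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 scope=global_policy.scope_name)
    l_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 scope=local_policy.scope_name)
    grads = tf.gradients(local_policy.loss, l_params)
    return optimizer.apply_gradients(zip(grads, g_params))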
            # (tail of SummerWorker.work; the session.run feed_dict this
            # closes is truncated above this excerpt)
            })
            SUMMARY_WRITER.add_summary(summary_str, train_times)
            train_times += 1


if __name__ == "__main__":
    with tf.device("/cpu:0"):
        GLOBAL_P = Policy(GLOBAL_POLICY_NET_SCOPE, env=env,
                          state_shape=N_F, n_actions=N_A)  # we only need its params
        workers = []
        # Create workers: one summary worker plus N_WORKERS training workers.
        workers.append(SummerWorker(GLOBAL_P))
        GLOBAL_P.build(show_img=True)
        for i in range(N_WORKERS):
            i_name = 'W_%i' % i  # worker name
            workers.append(Worker(i_name, GLOBAL_P))

    COORD = tf.train.Coordinator()
    SUMMARY_OP = tf.summary.merge_all()
    GLOBAL_P.complete()
    SESS = tf.get_default_session()
    SUMMARY_WRITER = tf.summary.FileWriter("log/", graph=SESS.graph)

    worker_threads = []
    for worker in workers:
        # Bind the worker method directly; the original
        # `job = lambda: worker.work()` captures the loop variable by
        # reference, so every thread would run the *last* worker.
        thread = threading.Thread(target=worker.work)
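        # Sketch of the standard tf.train.Coordinator completion of the loop
        # above (the excerpt ends before the threads are started; the
        # author's actual tail may differ):
        thread.start()
        worker_threads.append(thread)
    COORD.join(worker_threads)  # block until all workers finish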