import numpy as np
import tensorflow as tf  # TF 1.x API (tf.gradients, optimizer.apply_gradients)

# NOTE: this listing is an excerpt; GameState, A3CFFNetwork, A3CLSTMNetwork,
# timestamp() and the constants (STATE_DIM, STATE_CHN, ACTION_DIM, ENTROPY_BETA,
# GAMMA, LOCAL_T_MAX, USE_LSTM) are defined elsewhere in the project.


class A3CActorThread(object):
    def __init__(self, thread_index, global_network, initial_learning_rate,
                 learning_rate_input, optimizer, max_global_time_step, device):

        self.thread_index = thread_index
        self.learning_rate_input = learning_rate_input
        self.max_global_time_step = max_global_time_step

        if USE_LSTM:
            self.local_network = A3CLSTMNetwork(STATE_DIM, STATE_CHN,
                                                ACTION_DIM, device,
                                                thread_index)
        else:
            self.local_network = A3CFFNetwork(STATE_DIM, STATE_CHN, ACTION_DIM,
                                              device, thread_index)
        self.local_network.create_loss(ENTROPY_BETA)
        self.gradients = tf.gradients(self.local_network.total_loss,
                                      self.local_network.get_vars())

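        # clip each per-variable gradient of the rollout loss by norm before it
        # is applied; despite the name, these are the gradients from a single
        # n-step rollout, not accumulated gradients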
        clip_accum_grads = [
            tf.clip_by_norm(accum_grad, 10.0) for accum_grad in self.gradients
        ]
        self.apply_gradients = optimizer.apply_gradients(
            zip(clip_accum_grads, global_network.get_vars()))
        # self.apply_gradients = optimizer.apply_gradients(zip(self.gradients, global_network.get_vars()))

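        # op that copies the shared/global network's weights into this
        # thread's local network (run at the start of every rollout)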
        self.sync = self.local_network.sync_from(global_network)

        self.game_state = GameState(thread_index)

        self.local_t = 0
        self.initial_learning_rate = initial_learning_rate

        # for log
        self.episode_reward = 0.0
        self.episode_start_time = 0.0
        self.prev_local_t = 0
        return

    def _anneal_learning_rate(self, global_time_step):
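        """Linearly decay the learning rate towards 0 over max_global_time_step steps."""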
        learning_rate = self.initial_learning_rate * \
            (self.max_global_time_step - global_time_step) / self.max_global_time_step
        if learning_rate < 0.0:
            learning_rate = 0.0
        return learning_rate

    def choose_action(self, policy_output):
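        """Sample an action index from the categorical policy distribution."""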
        return np.random.choice(range(len(policy_output)), p=policy_output)

    def _record_log(self, sess, global_t, summary_writer, summary_op,
                    reward_input, reward, time_input, living_time):
        summary_str = sess.run(summary_op,
                               feed_dict={
                                   reward_input: reward,
                                   time_input: living_time
                               })
        summary_writer.add_summary(summary_str, global_t)
        summary_writer.flush()
        return

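    # One call to process() runs a single n-step rollout: sync with the global
    # network, interact with the environment for up to LOCAL_T_MAX steps, then
    # compute returns/advantages and apply one gradient update to the global
    # network.  It returns the number of local steps taken.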
    def process(self, sess, global_t, summary_writer, summary_op, reward_input,
                time_input):
        states = []
        actions = []
        rewards = []
        values = []

        terminal_end = False
        # reduce the influence of socket connecting time
        if self.episode_start_time == 0.0:
            self.episode_start_time = timestamp()

        # copy weight from global network
        sess.run(self.sync)

        start_local_t = self.local_t
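        # remember the LSTM state at the start of the rollout so the batched
        # update below can be unrolled from the same hidden state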
        if USE_LSTM:
            start_lstm_state = self.local_network.lstm_state_out

        for i in range(LOCAL_T_MAX):
            policy_, value_ = self.local_network.run_policy_and_value(
                sess, self.game_state.s_t)
            if self.thread_index == 0 and self.local_t % 1000 == 0:
                print 'policy=', policy_
                print 'value=', value_

            action_id = self.choose_action(policy_)

            states.append(self.game_state.s_t)
            actions.append(action_id)
            values.append(value_)

            self.game_state.process(action_id)
            reward = self.game_state.reward
            terminal = self.game_state.terminal

            self.episode_reward += reward
            rewards.append(np.clip(reward, -1.0, 1.0))

            self.local_t += 1

            # s_t1 -> s_t
            self.game_state.update()

            if terminal:
                terminal_end = True
                episode_end_time = timestamp()
                living_time = episode_end_time - self.episode_start_time

                self._record_log(sess, global_t, summary_writer, summary_op,
                                 reward_input, self.episode_reward, time_input,
                                 living_time)

                print("global_t=%d / reward=%.2f / living_time=%.4f") % (
                    global_t, self.episode_reward, living_time)

                # reset variables
                self.episode_reward = 0.0
                self.episode_start_time = episode_end_time
                self.game_state.reset()
                if USE_LSTM:
                    self.local_network.reset_lstm_state()
                break
            # log
            if self.local_t % 40 == 0:
                living_time = timestamp() - self.episode_start_time
                self._record_log(sess, global_t, summary_writer, summary_op,
                                 reward_input, self.episode_reward, time_input,
                                 living_time)
        # -----------end of batch (LOCAL_T_MAX)--------------------

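        # bootstrap the return: R = 0 after a terminal state, otherwise the
        # local critic's value estimate V(s_t) for the last observed state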
        R = 0.0
        if not terminal_end:
            R = self.local_network.run_value(sess, self.game_state.s_t)
        # print ('global_t: %d, R: %f') % (global_t, R)

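        # walk the rollout backwards so the discounted return
        # R_i = r_i + GAMMA * R_{i+1} and the TD error td_i = R_i - V(s_i)
        # can be accumulated in a single pass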
        states.reverse()
        actions.reverse()
        rewards.reverse()
        values.reverse()

        batch_state = []
        batch_action = []
        batch_td = []
        batch_R = []

        for (ai, ri, si, Vi) in zip(actions, rewards, states, values):
            R = ri + GAMMA * R
            td = R - Vi
            action = np.zeros([ACTION_DIM])
            action[ai] = 1

            batch_state.append(si)
            batch_action.append(action)
            batch_td.append(td)
            batch_R.append(R)

        cur_learning_rate = self._anneal_learning_rate(global_t)
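        # the LSTM network consumes the rollout in time order, so undo the
        # reversal and also feed the saved initial LSTM state and unroll length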
        if USE_LSTM:
            batch_state.reverse()
            batch_action.reverse()
            batch_td.reverse()
            batch_R.reverse()
            sess.run(self.apply_gradients,
                     feed_dict={
                         self.local_network.state_input: batch_state,
                         self.local_network.action_input: batch_action,
                         self.local_network.td: batch_td,
                         self.local_network.R: batch_R,
                         self.local_network.step_size: [len(batch_state)],
                         self.local_network.initial_lstm_state:
                         start_lstm_state,
                         self.learning_rate_input: cur_learning_rate
                     })
        else:
            sess.run(self.apply_gradients,
                     feed_dict={
                         self.local_network.state_input: batch_state,
                         self.local_network.action_input: batch_action,
                         self.local_network.td: batch_td,
                         self.local_network.R: batch_R,
                         self.learning_rate_input: cur_learning_rate
                     })

        diff_local_t = self.local_t - start_local_t
        return diff_local_t
Example 2
class A3CActorThread(object):
    def __init__(self, thread_index, global_network, initial_learning_rate,
                 learning_rate_input, optimizer, max_global_time_step, device):

        self.thread_index = thread_index
        self.learning_rate_input = learning_rate_input
        self.max_global_time_step = max_global_time_step

        if USE_LSTM:
            self.local_network = A3CLSTMNetwork(STATE_DIM, STATE_CHN,
                                                ACTION_DIM, device,
                                                thread_index)
        else:
            self.local_network = A3CFFNetwork(STATE_DIM, STATE_CHN, ACTION_DIM,
                                              device, thread_index)
        self.local_network.create_loss(ENTROPY_BETA)
        self.gradients = tf.gradients(self.local_network.total_loss,
                                      self.local_network.get_vars())

        clip_accum_grads = [
            tf.clip_by_norm(accum_grad, 10.0) for accum_grad in self.gradients
        ]
        self.apply_gradients = optimizer.apply_gradients(
            zip(clip_accum_grads, global_network.get_vars()))
        # self.apply_gradients = optimizer.apply_gradients(zip(self.gradients, global_network.get_vars()))

        self.sync = self.local_network.sync_from(global_network)

        self.game_state = GameState(thread_index)

        self.local_t = 0
        self.initial_learning_rate = initial_learning_rate

        # for log
        self.episode_reward = 0.0
        self.episode_start_time = 0.0
        self.prev_local_t = 0
        return

    def _anneal_learning_rate(self, global_time_step):
        learning_rate = self.initial_learning_rate * \
            (self.max_global_time_step - global_time_step) / self.max_global_time_step
        if learning_rate < 0.0:
            learning_rate = 0.0
        return learning_rate

    def choose_action(self, policy_output):
        return np.random.choice(range(len(policy_output)), p=policy_output)

    def _record_log(self, sess, global_t, summary_writer, summary_op,
                    reward_input, reward, time_input, living_time):
        summary_str = sess.run(summary_op,
                               feed_dict={
                                   reward_input: reward,
                                   time_input: living_time
                               })
        summary_writer.add_summary(summary_str, global_t)
        summary_writer.flush()
        return

    def _discount_accum_reward(self, rewards, running_add=0.0, gamma=0.99):
        """ discounted the reward using gamma
        """
        discounted_r = np.zeros_like(rewards, dtype=np.float32)
        for t in reversed(range(len(rewards))):
            running_add = rewards[t] + running_add * gamma
            discounted_r[t] = running_add

        return list(discounted_r)

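    # This variant differs from Example 1 in how it builds the training batch:
    # it stores transitions in forward (time) order, evaluates V(s) for the
    # whole rollout with a single run_batch_value call after the rollout, and
    # derives the returns with _discount_accum_reward instead of an explicit
    # backward loop.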
    def process(self, sess, global_t, summary_writer, summary_op, reward_input,
                time_input):
        batch_state = []
        batch_action = []
        batch_reward = []

        terminal_end = False
        # reduce the influence of socket connecting time
        if self.episode_start_time == 0.0:
            self.episode_start_time = timestamp()

        # copy weight from global network
        sess.run(self.sync)

        start_local_t = self.local_t
        if USE_LSTM:
            start_lstm_state = self.local_network.lstm_state_out
        else:
            # no recurrent state without an LSTM; run_batch_value below is
            # assumed to ignore this argument in that case
            start_lstm_state = None

        for i in range(LOCAL_T_MAX):
            policy_ = self.local_network.run_policy(sess, self.game_state.s_t)
            if self.thread_index == 0 and self.local_t % 1000 == 0:
                print 'policy=', policy_

            action_id = self.choose_action(policy_)

            action_onehot = np.zeros([ACTION_DIM])
            action_onehot[action_id] = 1
            batch_state.append(self.game_state.s_t)
            batch_action.append(action_onehot)

            self.game_state.process(action_id)
            reward = self.game_state.reward
            terminal = self.game_state.terminal

            self.episode_reward += reward
            batch_reward.append(np.clip(reward, -1.0, 1.0))

            self.local_t += 1

            # s_t1 -> s_t
            self.game_state.update()

            if terminal:
                terminal_end = True
                episode_end_time = timestamp()
                living_time = episode_end_time - self.episode_start_time

                self._record_log(sess, global_t, summary_writer, summary_op,
                                 reward_input, self.episode_reward, time_input,
                                 living_time)

                print("global_t=%d / reward=%.2f / living_time=%.4f") % (
                    global_t, self.episode_reward, living_time)

                # reset variables
                self.episode_reward = 0.0
                self.episode_start_time = episode_end_time
                self.game_state.reset()
                if USE_LSTM:
                    self.local_network.reset_lstm_state()
                break
            # log
            if self.local_t % 40 == 0:
                living_time = timestamp() - self.episode_start_time
                self._record_log(sess, global_t, summary_writer, summary_op,
                                 reward_input, self.episode_reward, time_input,
                                 living_time)
        # -----------end of batch (LOCAL_T_MAX)--------------------

        R = 0.0
        if not terminal_end:
            R = self.local_network.run_value(sess, self.game_state.s_t)
        # print ('global_t: %d, R: %f') % (global_t, R)

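        # evaluate the critic for every state of the rollout in one pass, then
        # form the discounted returns (seeded with the bootstrap value R) and
        # the TD errors used as advantages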
        batch_value = self.local_network.run_batch_value(
            sess, batch_state, start_lstm_state)
        batch_R = self._discount_accum_reward(batch_reward, R, GAMMA)
        batch_td = np.array(batch_R) - np.array(batch_value)
        cur_learning_rate = self._anneal_learning_rate(global_t)

        # print("=" * 60)
        # print(batch_value)
        # print(self.local_network.run_batch_value(sess, batch_state, start_lstm_state))
        # print("=" * 60)
        # import sys
        # sys.exit()

        if USE_LSTM:
            sess.run(self.apply_gradients,
                     feed_dict={
                         self.local_network.state_input: batch_state,
                         self.local_network.action_input: batch_action,
                         self.local_network.td: batch_td,
                         self.local_network.R: batch_R,
                         self.local_network.step_size: [len(batch_state)],
                         self.local_network.initial_lstm_state:
                         start_lstm_state,
                         self.learning_rate_input: cur_learning_rate
                     })
        else:
            sess.run(self.apply_gradients,
                     feed_dict={
                         self.local_network.state_input: batch_state,
                         self.local_network.action_input: batch_action,
                         self.local_network.td: batch_td,
                         self.local_network.R: batch_R,
                         self.learning_rate_input: cur_learning_rate
                     })

        diff_local_t = self.local_t - start_local_t
        return diff_local_t
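
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original listings): how a driver typically
# wires A3CActorThread together.  PARALLEL_SIZE, MAX_TIME_STEP,
# INITIAL_LEARNING_RATE, the RMSProp hyper-parameters and the use of -1 as the
# global network's thread index are assumptions for illustration only.
#
#   learning_rate_input = tf.placeholder(tf.float32)
#   optimizer = tf.train.RMSPropOptimizer(learning_rate_input,
#                                         decay=0.99, epsilon=0.1)
#   global_network = A3CFFNetwork(STATE_DIM, STATE_CHN, ACTION_DIM, device, -1)
#
#   actor_threads = [
#       A3CActorThread(i, global_network, INITIAL_LEARNING_RATE,
#                      learning_rate_input, optimizer, MAX_TIME_STEP, device)
#       for i in range(PARALLEL_SIZE)
#   ]
#
#   def train_function(actor):
#       global global_t
#       while global_t < MAX_TIME_STEP:
#           global_t += actor.process(sess, global_t, summary_writer,
#                                     summary_op, reward_input, time_input)
#
#   threads = [threading.Thread(target=train_function, args=(actor,))
#              for actor in actor_threads]
# ---------------------------------------------------------------------------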