class A3CAgent(object):
    def __init__(self):
        self.num_state = OBS_SPACE  # observation size
        self.num_actions = NUM_ACTIONS  # number of actions
        self.lr = tf.Variable(3e-4)  # variable used for decaying learning rate
        self.starter_lr = 3e-4  # start value of learning rate

        # optimizer that trains the global network with the gradients of the locals
        # use locking because multiple threads apply gradients concurrently
        self.optimizer = tf.train.AdamOptimizer(learning_rate=self.starter_lr,
                                                use_locking=True)

        # the global Actor-Critic network
        self.global_network = Actor_Critic(self.num_actions)
        # prepare the global network - a dummy forward pass builds its variables under eager execution
        self.global_network(
            tf.convert_to_tensor(np.random.random((1, 84, 84, 4)),
                                 dtype=tf.float32))

        self.discount_rate = 0.99

    def start_threads(self):
        # max number of episodes
        max_eps = 1e6
        envs = []
        # create 1 local environment for each thread
        for _ in range(NUM_THREADS):
            _env = gym_super_mario_bros.make(env_name)
            _env = JoypadSpace(_env, SIMPLE_MOVEMENT)
            env = atari_wrapper.wrap_dqn(_env)
            envs.append(env)
        # create the threads and assign each its environment and exploration rate
        threads = []
        for i in range(NUM_THREADS):
            thread = threading.Thread(
                target=train_thread,
                daemon=True,
                args=(self, max_eps, envs[i],
                      self.discount_rate, self.optimizer, stats,
                      AnnealingVariable(.7, 1e-20, 10000), i))
            threads.append(thread)

        # starts the threads
        for t in threads:
            print("STARTING")
            t.start()
            time.sleep(0.5)
        try:
            for t in threads:
                t.join()  # wait for all threads to finish
        except KeyboardInterrupt:
            print("Exiting threads!")

    def save_weights(self):
        print("Saving Weights")
        self.global_network.save_weights("A3CMarioWeights.h5")

    def restore_weights(self):
        print("Restoring Weights!")
        self.global_network.load_weights("A3CMarioWeights.h5")
Example #2
class A3CAgent(object):
    def __init__(self):
        self.num_actions = NUM_ACTIONS  # number of actions
        self.starter_lr = 1e-4  # start value of learning rate

        # optimizer that trains the global network with the gradients of the locals
        # use locking because multiple threads apply gradients concurrently
        self.optimizer = tf.train.AdamOptimizer(learning_rate=self.starter_lr,
                                                use_locking=True)
        # the global Actor-Critic network
        self.global_network = Actor_Critic(self.num_actions)
        # prepare the global network - a dummy forward pass builds its variables under eager execution
        self.global_network(
            tf.convert_to_tensor(np.random.random((1, 84, 84, 4)),
                                 dtype=tf.float32))
        self.restore_weights()

    def pick_action(self, state, exploration_rate=0.0):
        if np.random.random() < exploration_rate:
            return test_env.action_space.sample()  # pick randomly (test_env is a module-level environment)

        state = np.expand_dims(state, axis=0)
        logits, _ = self.global_network(state)
        probs = tf.nn.softmax(logits)
        action = np.random.choice(self.num_actions, 1, p=probs.numpy()[0])
        return action[0]

    def play(self, env, stats, episodes: int = 100, exploration_rate=0.0):
        rewards_arr = np.zeros(episodes)
        for episode in range(episodes):
            episode_reward = 0
            done = False
            state = env.reset()
            while not done:
                env.render()
                # time.sleep(0.05)
                action = self.pick_action(state, exploration_rate)
                next_state, reward, done, _ = env.step(action)
                episode_reward += reward
                state = next_state
            if callable(stats):
                stats(self, episode_reward)
            rewards_arr[episode] = episode_reward
            print(episode_reward)
        stats.save_stats()
        return rewards_arr

    def restore_weights(self):
        self.global_network.load_weights('A3CPong.h5')
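
A minimal evaluation sketch for this second example, assuming the project's atari_wrapper module, a stats helper object with a save_stats() method, and the module-level test_env referenced inside pick_action are available; the Pong environment name here is an assumption:

import gym

_env = gym.make("PongNoFrameskip-v4")
test_env = atari_wrapper.wrap_dqn(_env)  # pick_action samples random actions from this global

agent = A3CAgent()  # __init__ restores the weights from A3CPong.h5
rewards = agent.play(test_env, stats, episodes=10, exploration_rate=0.0)
print("mean reward over 10 episodes:", rewards.mean())
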