Example #1
import os
import sys

import numpy as np
import tensorflow as tf

# Project-specific helpers (DQNAgent, ExperienceMemory, GameHistory, EpisodeHistory,
# Episode, MatchResults, transformImage) are assumed to be defined elsewhere in the
# project. The code targets the TensorFlow 1.x API (tf.Session, tf.train.Saver).

class AgentTrainer(object):
    def __init__(self, config):
        # Create session to store trained parameters
        self.session = tf.Session()

        self.action_count = config["action_count"]

        # Create agent for training
        self.agent = DQNAgent(self.action_count)

        # Create memory to store observations
        self.memory = ExperienceMemory(config["replay_memory_size"])

        # Tools for saving and loading networks
        self.saver = tf.train.Saver()

        # Last action that agent performed
        self.last_action_index = None

        # Deque to keep track of average reward and play time
        self.game_history = GameHistory(config["match_memory_size"])

        # Deque to store losses
        self.episode_history = EpisodeHistory(config["replay_memory_size"])

        self.INITIAL_EPSILON = config["initial_epsilon"]
        self.FINAL_EPSILON = config["final_epsilon"]
        self.OBSERVE = config["observe_step_count"]
        self.EXPLORE = config["explore_step_count"]
        self.FRAME_PER_ACTION = config["frame_per_action"]
        self.GAMMA = config["gamma"]
        self.LOG_PERIOD = config["log_period"]
        self.BATCH_SIZE = config["batch_size"]

    def init_training(self):
        # Initialize training parameters
        self.session.run(tf.global_variables_initializer())
        self.epsilon = self.INITIAL_EPSILON
        self.t = 0
        self.last_action_index = None

    def load_model(self, path):
        checkpoint = tf.train.get_checkpoint_state(path)
        if checkpoint and checkpoint.model_checkpoint_path:
            self.saver.restore(self.session, checkpoint.model_checkpoint_path)
            print("Successfully loaded: {}".format(checkpoint.model_checkpoint_path))
        else:
            print("Could not find old network weights")

    def save_model(self, path):
        self.saver.save(self.session, os.path.join(path, "dqn"), global_step=self.t)

    def reset_state(self, initial_state):
        # Preprocess the first frame and stack four copies of it into the 80x80x4 initial state
        x_t = transformImage(initial_state)
        self.s_t = np.concatenate((x_t, x_t, x_t, x_t), axis=2)
        self.match_reward = 0
        self.match_playtime = 0
        self.gamma_pow = 1

    def act(self):
        # Choose an action epsilon-greedily, but only once every FRAME_PER_ACTION frames
        if self.t % self.FRAME_PER_ACTION == 0:
            if np.random.random() <= self.epsilon:
                action_index = np.random.randint(0, self.action_count)
            else:
                action_index = self.agent.act(self.session, self.s_t)
        else:
            action_index = self.last_action_index  # repeat the previous action
        self.last_action_index = action_index
        return action_index

    def process_frame(self, screen, reward, terminal):
        if self.last_action_index is None:
            self.reset_state(screen)
            return

        a_t = np.zeros([self.action_count])
        a_t[self.last_action_index] = 1

        # anneal epsilon linearly from INITIAL_EPSILON to FINAL_EPSILON over EXPLORE steps
        if self.epsilon > self.FINAL_EPSILON and self.t > self.OBSERVE:
            self.epsilon -= (self.INITIAL_EPSILON - self.FINAL_EPSILON) / self.EXPLORE

        # preprocess the observed frame and push it onto the front of the 4-frame state stack
        r_t = reward
        x_t1 = transformImage(screen)
        s_t1 = np.append(x_t1, self.s_t[:, :, :3], axis=2)

        # store the transition in memory
        self.memory.add_experience((self.s_t, a_t, r_t, s_t1, terminal))

        # only train if done observing
        if self.t > self.OBSERVE:
            loss = self.make_train_step()
            self.episode_history.add_episode(Episode(loss))

        # update the old values
        self.s_t = s_t1
        self.t += 1

        # periodically log training progress
        if self.t % self.LOG_PERIOD == 0:
            print("TIMESTEP {}, EPSILON {}, EPISODE_STATS {}, MATCH_STATS {}".format(
                self.t,
                self.epsilon,
                self.episode_history.get_average_stats(),
                self.game_history.get_average_stats()))
            sys.stdout.flush()

        self.match_reward += r_t * self.gamma_pow
        self.match_playtime += 1
        self.gamma_pow *= self.GAMMA

        if terminal:
            self.game_history.add_match(MatchResults(
                self.match_reward,
                self.match_playtime,
                reward))
            self.reset_state(screen)

    def make_train_step(self):
        # sample a minibatch to train on
        minibatch = self.memory.sample(self.BATCH_SIZE)

        # get the batch variables
        s_j_batch = [d[0] for d in minibatch]
        a_batch = [d[1] for d in minibatch]
        r_batch = [d[2] for d in minibatch]
        s_j1_batch = [d[3] for d in minibatch]

        # score the successor states with the current network to build the Q-learning targets
        action_scores_batch = np.array(self.agent.score_actions(self.session, s_j1_batch))

        y_batch = []
        for i in range(len(minibatch)):
            if minibatch[i][4]:
                # terminal transition: the target is just the observed reward
                y_batch.append(r_batch[i])
            else:
                # otherwise bootstrap with the discounted maximum future Q-value
                y_batch.append(r_batch[i] + self.GAMMA * np.max(action_scores_batch[i]))

        return self.agent.train(self.session, y_batch, a_batch, s_j_batch)
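
Below is a minimal usage sketch showing how the trainer above could be driven by a game loop. The config keys mirror the ones read in __init__, but the values are only illustrative, and GameEnvironment with its reset/step interface is a hypothetical stand-in for whatever emulator the surrounding project actually uses.

config = {
    "action_count": 2,
    "replay_memory_size": 50000,
    "match_memory_size": 100,
    "initial_epsilon": 0.1,
    "final_epsilon": 0.0001,
    "observe_step_count": 10000,
    "explore_step_count": 3000000,
    "frame_per_action": 1,
    "gamma": 0.99,
    "log_period": 1000,
    "batch_size": 32,
}

trainer = AgentTrainer(config)
trainer.init_training()

env = GameEnvironment()          # hypothetical environment wrapper, not defined above
screen, reward, terminal = env.reset(), 0, False

while True:
    # feed the latest observation to the trainer (also handles state resets)
    trainer.process_frame(screen, reward, terminal)

    # pick the next action epsilon-greedily and step the environment
    action_index = trainer.act()
    screen, reward, terminal = env.step(action_index)

    # checkpoint periodically
    if trainer.t > 0 and trainer.t % 10000 == 0:
        trainer.save_model("saved_networks")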