Example #1
def measure(name, iters=5000, **settings):
    print(name)
    for k, v in settings.items():
        print("\t{}: {}".format(k, v))

    # Vizdoom wrapper
    doom_wrapper = VizdoomWrapper(**settings)
    start = time()
    for _ in trange(iters, leave=False):
        current_img, current_misc = doom_wrapper.get_current_state()
        action_index = randint(0, doom_wrapper.actions_num - 1)
        doom_wrapper.make_action(action_index)

        if doom_wrapper.is_terminal():
            doom_wrapper.reset()
    end = time()
    wrapper_t = (end - start)

    # Vanilla vizdoom:
    doom = vzd.DoomGame()
    if "scenarios_path" not in settings:
        scenarios_path = vzd.__path__[0] + "/scenarios"
    else:
        scenarios_path = settings["scenarios_path"]
    config_file = scenarios_path + "/" + settings["config_file"]
    doom.load_config(config_file)
    doom.set_window_visible(False)
    doom.set_screen_format(vzd.ScreenFormat.GRAY8)
    doom.set_screen_resolution(vzd.ScreenResolution.RES_160X120)
    doom.init()
    # One 0/1 entry per available button gives every possible action combination.
    actions = [
        list(a)
        for a in it.product([0, 1],
                            repeat=len(doom.get_available_buttons()))
    ]
    start = time()
    frame_skip = settings["frame_skip"]
    for _ in trange(iters, leave=False):
        if doom.is_episode_finished():
            doom.new_episode()
        doom.make_action(choice(actions), frame_skip)

    end = time()
    vanilla_t = end - start
    print(green("\twrapper: {:0.2f} steps/s".format(iters / wrapper_t)))
    print(
        green("\twrapper: {:0.2f} s/1000 steps".format(wrapper_t / iters *
                                                       1000)))
    print(blue("\tvanilla: {:0.2f} steps/s".format(iters / vanilla_t)))
    print(
        blue("\tvanilla: {:0.2f} s/1000 steps\n".format(vanilla_t / iters *
                                                        1000)))
class A3CLearner(Thread):
    def __init__(self,
                 thread_index=0,
                 game=None,
                 model_savefile=None,
                 network_class="ACLstmNet",
                 global_steps_counter=None,
                 scenario_tag=None,
                 run_id_string=None,
                 session=None,
                 tf_logdir=None,
                 global_network=None,
                 optimizer=None,
                 learning_rate=None,
                 test_only=False,
                 test_interval=1,
                 write_summaries=True,
                 enable_progress_bar=True,
                 deterministic_testing=True,
                 save_interval=1,
                 writer_max_queue=10,
                 writer_flush_secs=120,
                 gamma_compensation=False,
                 figar_gamma=False,
                 gamma=0.99,
                 show_heatmaps=True,
                 **settings):
        super(A3CLearner, self).__init__()

        log("Creating actor-learner #{}.".format(thread_index))
        self.thread_index = thread_index

        self._global_steps_counter = global_steps_counter
        self.write_summaries = write_summaries
        self.save_interval = save_interval
        self.enable_progress_bar = enable_progress_bar
        self._model_savefile = None
        self._train_writer = None
        self._test_writer = None
        self._summaries = None
        self._session = session
        self.deterministic_testing = deterministic_testing
        self.local_steps = 0
        # TODO epoch as tf variable?
        self._epoch = 1
        self.train_scores = []
        self.train_actions = []
        self.train_frameskips = []
        self.show_heatmaps = show_heatmaps
        self.test_interval = test_interval

        self.local_steps_per_epoch = settings["local_steps_per_epoch"]
        self._run_tests = settings["test_episodes_per_epoch"] > 0 and settings["run_tests"]
        self.test_episodes_per_epoch = settings["test_episodes_per_epoch"]
        self._epochs = np.float32(settings["epochs"])
        self.max_remembered_steps = settings["max_remembered_steps"]

        assert not (gamma_compensation and figar_gamma)

        gamma = np.float32(gamma)

        if gamma_compensation:
            self.scale_gamma = lambda fskip: ((1 - gamma ** fskip) / (1 - gamma), gamma ** fskip)
        elif figar_gamma:
            self.scale_gamma = lambda fskip: (1.0, gamma ** fskip)
        else:
            self.scale_gamma = lambda _: (1.0, gamma)
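        # Example with gamma = 0.99 and frameskip = 4:
        #   gamma_compensation: scale = (1 - 0.99**4) / (1 - 0.99) ≈ 3.94, discount = 0.99**4 ≈ 0.961
        #                       (a repeated action is credited roughly like 4 single steps)
        #   figar_gamma:        scale = 1.0, discount = 0.99**4 (only the discount reflects the skip)
        #   default:            scale = 1.0, discount = 0.99 regardless of frameskip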

        if self.write_summaries and thread_index == 0 and not test_only:
            assert tf_logdir is not None
            self.run_id_string = run_id_string
            self.tf_models_path = settings["models_path"]
            create_directory(tf_logdir)

            if self.tf_models_path is not None:
                create_directory(self.tf_models_path)

        if game is None:
            self.doom_wrapper = VizdoomWrapper(**settings)
        else:
            self.doom_wrapper = game
        misc_len = self.doom_wrapper.misc_len
        img_shape = self.doom_wrapper.img_shape
        self.use_misc = self.doom_wrapper.use_misc

        self.actions_num = self.doom_wrapper.actions_num
        self.local_network = getattr(networks, network_class)(actions_num=self.actions_num, img_shape=img_shape,
                                                              misc_len=misc_len,
                                                              thread=thread_index, **settings)

        if not test_only:
            self.learning_rate = learning_rate
            # TODO check gate_gradients != Optimizer.GATE_OP
            grads_and_vars = optimizer.compute_gradients(self.local_network.ops.loss,
                                                         var_list=self.local_network.get_params())
            grads, local_vars = zip(*grads_and_vars)

            grads_and_global_vars = zip(grads, global_network.get_params())
            self.train_op = optimizer.apply_gradients(grads_and_global_vars, global_step=tf.train.get_global_step())

            self.global_network = global_network
            self.local_network.prepare_sync_op(global_network)

        if self.thread_index == 0 and not test_only:
            self._model_savefile = model_savefile
            if self.write_summaries:
                self.actions_placeholder = tf.placeholder(tf.int32, None)
                self.frameskips_placeholder = tf.placeholder(tf.int32, None)
                self.scores_placeholder, summaries = setup_vector_summaries(scenario_tag + "/scores")

                # TODO remove scenario_tag from histograms
                a_histogram = tf.summary.histogram(scenario_tag + "/actions", self.actions_placeholder)
                fs_histogram = tf.summary.histogram(scenario_tag + "/frameskips", self.frameskips_placeholder)
                score_histogram = tf.summary.histogram(scenario_tag + "/scores", self.scores_placeholder)
                lr_summary = tf.summary.scalar(scenario_tag + "/learning_rate", self.learning_rate)
                summaries.append(lr_summary)
                summaries.append(a_histogram)
                summaries.append(fs_histogram)
                summaries.append(score_histogram)
                self._summaries = tf.summary.merge(summaries)
                self._train_writer = tf.summary.FileWriter("{}/{}/{}".format(tf_logdir, self.run_id_string, "train"),
                                                           flush_secs=writer_flush_secs, max_queue=writer_max_queue)
                self._test_writer = tf.summary.FileWriter("{}/{}/{}".format(tf_logdir, self.run_id_string, "test"),
                                                          flush_secs=writer_flush_secs, max_queue=writer_max_queue)

    def heatmap(self, actions, frameskips):
        min_frameskip = np.min(frameskips)
        max_frameskip = np.max(frameskips)
        fs_values = range(min_frameskip, max_frameskip + 1)

        a_labels = [str(a) for a in self.doom_wrapper.actions]

        mat = np.zeros((self.actions_num, (len(fs_values))))

        for f, a in zip(frameskips, actions):
            mat[a, f - min_frameskip] += 1

        return string_heatmap(mat, fs_values, a_labels)

    @staticmethod
    def choose_best_index(policy, deterministic=True):
        if deterministic:
            return np.argmax(policy)

        r = random.random()
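        # Inverse-CDF sampling: walk the cumulative policy probabilities and
        # return the first action index whose cumulative mass reaches r.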
        cumulative_sum = 0.0
        for i, p in enumerate(policy):
            cumulative_sum += p
            if r <= cumulative_sum:
                return i

        return len(policy) - 1

    def make_training_step(self):
        states_img = []
        states_misc = []
        actions = []
        rewards_reversed = []
        Rs = []

        self._session.run(self.local_network.ops.sync)

        initial_network_state = None
        if self.local_network.has_state():
            initial_network_state = self.local_network.get_current_network_state()

        terminal = None
        steps_performed = 0
        for _ in range(self.max_remembered_steps):
            steps_performed += 1
            current_state = self.doom_wrapper.get_current_state()
            policy = self.local_network.get_policy(self._session, current_state)
            action_index = A3CLearner.choose_best_index(policy, deterministic=False)
            states_img.append(current_state[0])
            states_misc.append(current_state[1])
            actions.append(action_index)
            reward = self.doom_wrapper.make_action(action_index)
            terminal = self.doom_wrapper.is_terminal()
            rewards_reversed.insert(0, reward)
            self.local_steps += 1
            if terminal:
                if self.thread_index == 0:
                    self.train_scores.append(self.doom_wrapper.get_total_reward())
                self.doom_wrapper.reset()
                if self.local_network.has_state():
                    self.local_network.reset_state()
                break

        self.train_actions += actions
        self.train_frameskips += [self.doom_wrapper.frameskip] * len(actions)

        if terminal:
            R = 0.0
        else:
            R = self.local_network.get_value(self._session, self.doom_wrapper.get_current_state())

        # TODO: this could be handled more cleanly ...
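        # Bootstrapped n-step returns, computed newest-reward-first:
        # R <- scale * r_i + gamma * R, with scale/gamma from the frameskip handling above.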
        for ri in rewards_reversed:
            scale, gamma = self.scale_gamma(self.doom_wrapper.frameskip)
            R = scale * ri + gamma * R
            Rs.insert(0, R)

        train_op_feed_dict = {
            self.local_network.vars.state_img: states_img,
            self.local_network.vars.a: actions,
            self.local_network.vars.R: Rs
        }
        if self.use_misc:
            train_op_feed_dict[self.local_network.vars.state_misc] = states_misc

        if self.local_network.has_state():
            train_op_feed_dict[self.local_network.vars.initial_network_state] = initial_network_state
            train_op_feed_dict[self.local_network.vars.sequence_length] = [len(actions)]

        self._session.run(self.train_op, feed_dict=train_op_feed_dict)

        return steps_performed

    def run_episode(self, deterministic=True, return_stats=False):
        self.doom_wrapper.reset()
        if self.local_network.has_state():
            self.local_network.reset_state()
        actions = []
        frameskips = []
        rewards = []
        while not self.doom_wrapper.is_terminal():
            current_state = self.doom_wrapper.get_current_state()
            action_index, frameskip = self._get_best_action(self._session, current_state, deterministic=deterministic)
            reward = self.doom_wrapper.make_action(action_index, frameskip)
            if return_stats:
                actions.append(action_index)
                if frameskip is None:
                    frameskip = self.doom_wrapper.frameskip
                frameskips.append(frameskip)
                rewards.append(reward)

        total_reward = self.doom_wrapper.get_total_reward()
        if return_stats:
            return total_reward, actions, frameskips, rewards
        else:
            return total_reward

    def test(self, episodes_num=None, deterministic=True):
        if episodes_num is None:
            episodes_num = self.test_episodes_per_epoch

        test_start_time = time.time()
        test_rewards = []
        test_actions = []
        test_frameskips = []
        for _ in trange(episodes_num, desc="Testing", file=sys.stdout,
                        leave=False, disable=not self.enable_progress_bar):
            total_reward, actions, frameskips, _ = self.run_episode(deterministic=deterministic, return_stats=True)
            test_rewards.append(total_reward)
            test_actions += actions
            test_frameskips += frameskips

        self.doom_wrapper.reset()
        if self.local_network.has_state():
            self.local_network.reset_state()

        test_end_time = time.time()
        test_duration = test_end_time - test_start_time
        min_score = np.min(test_rewards)
        max_score = np.max(test_rewards)
        mean_score = np.mean(test_rewards)
        score_std = np.std(test_rewards)
        log(
            "TEST: mean: {}, min: {}, max: {}, test time: {}".format(
                green("{:0.3f}±{:0.2f}".format(mean_score, score_std)),
                red("{:0.3f}".format(min_score)),
                blue("{:0.3f}".format(max_score)),
                sec_to_str(test_duration)))
        return test_rewards, test_actions, test_frameskips

    def _print_train_log(self, scores, overall_start_time, last_log_time, steps):
        current_time = time.time()
        mean_score = np.mean(scores)
        score_std = np.std(scores)
        min_score = np.min(scores)
        max_score = np.max(scores)

        elapsed_time = time.time() - overall_start_time
        global_steps = self._global_steps_counter.get()
        local_steps_per_sec = steps / (current_time - last_log_time)
        global_steps_per_sec = global_steps / elapsed_time
        global_mil_steps_per_hour = global_steps_per_sec * 3600 / 1000000.0
        log(
            "TRAIN: {}(GlobalSteps), {} episodes, mean: {}, min: {}, max: {}, "
            "\nLocalSpd: {:.0f} STEPS/s GlobalSpd: "
            "{} STEPS/s, {:.2f}M STEPS/hour, total elapsed time: {}".format(
                global_steps,
                len(scores),
                green("{:0.3f}±{:0.2f}".format(mean_score, score_std)),
                red("{:0.3f}".format(min_score)),
                blue("{:0.3f}".format(max_score)),
                local_steps_per_sec,
                blue("{:.0f}".format(
                    global_steps_per_sec)),
                global_mil_steps_per_hour,
                sec_to_str(elapsed_time)
            ))

    def run(self):
        # TODO this method is ugly, make it nicer
        try:
            overall_start_time = time.time()
            last_log_time = overall_start_time
            local_steps_for_log = 0
            while self._epoch <= self._epochs:
                steps = self.make_training_step()
                local_steps_for_log += steps
                global_steps = self._global_steps_counter.inc(steps)
                # Logs & tests
                if self.local_steps_per_epoch * self._epoch <= self.local_steps:
                    self._epoch += 1

                    if self.thread_index == 0:
                        log("EPOCH {}".format(self._epoch - 1))
                        self._print_train_log(
                            self.train_scores, overall_start_time, last_log_time, local_steps_for_log)
                        run_test_this_epoch = ((self._epoch - 1) % self.test_interval) == 0
                        if self._run_tests and run_test_this_epoch:
                            test_scores, test_actions, test_frameskips = self.test(
                                deterministic=self.deterministic_testing)

                        if self.write_summaries:
                            train_summary = self._session.run(self._summaries,
                                                              {self.scores_placeholder: self.train_scores,
                                                               self.actions_placeholder: self.train_actions,
                                                               self.frameskips_placeholder: self.train_frameskips})
                            self._train_writer.add_summary(train_summary, global_steps)
                            if self._run_tests and run_test_this_epoch:
                                test_summary = self._session.run(self._summaries,
                                                                 {self.scores_placeholder: test_scores,
                                                                  self.actions_placeholder: test_actions,
                                                                  self.frameskips_placeholder: test_frameskips})
                                self._test_writer.add_summary(test_summary, global_steps)

                        last_log_time = time.time()
                        local_steps_for_log = 0
                        log("Learning rate: {}".format(self._session.run(self.learning_rate)))

                        # Saves model
                        if self._epoch % self.save_interval == 0:
                            self.save_model()
                        now = datetime.datetime.now()
                        log("Time: {:2d}:{:02d}".format(now.hour, now.minute))

                        if self.show_heatmaps:
                            log("Train heatmaps:")
                            log(self.heatmap(self.train_actions, self.train_frameskips))
                            log("")
                            if run_test_this_epoch:
                                log("Test heatmaps:")
                                log(self.heatmap(test_actions, test_frameskips))
                        log("")
                    self.train_scores = []
                    self.train_actions = []
                    self.train_frameskips = []

            threadsafe_print("Thread {} finished.".format(self.thread_index))
        except (SignalException, ViZDoomUnexpectedExitException):
            threadsafe_print(red("Thread #{} aborting (ViZDoom killed).".format(self.thread_index)))

    def run_training(self, session):
        self._session = session
        self.start()

    def save_model(self):
        ensure_parent_directories(self._model_savefile)
        log("Saving model to: {}".format(self._model_savefile))
        saver = tf.train.Saver(self.local_network.get_params())
        saver.save(self._session, self._model_savefile)

    def load_model(self, session, savefile):
        saver = tf.train.Saver(self.local_network.get_params())
        log("Loading model from: {}".format(savefile))
        saver.restore(session, savefile)
        log("Loaded model.")

    def _get_best_action(self, sess, state, deterministic=True):
        policy = self.local_network.get_policy(sess, state)
        action_index = self.choose_best_index(policy, deterministic=deterministic)
        frameskip = None
        return action_index, frameskip
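
# Rough wiring sketch for these actor-learner threads (the names used below —
# global_network, optimizer, global_steps_counter, learning_rate, session,
# settings, threads_num — are assumptions standing in for whatever the launcher
# script actually builds):
#
#     learners = [A3CLearner(thread_index=i,
#                            global_network=global_network,
#                            global_steps_counter=global_steps_counter,
#                            optimizer=optimizer,
#                            learning_rate=learning_rate,
#                            scenario_tag="basic",
#                            run_id_string=run_id_string,
#                            tf_logdir="tensorboard_logs",
#                            **settings)
#                 for i in range(threads_num)]
#     session.run(tf.global_variables_initializer())
#     for learner in learners:
#         learner.run_training(session)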
Example #3
                              learning_rate,
                              gamma,
                              save_path=model_path)

if load_pretrained_network:
    doomguy.load_model()

if train_network:
    for episode in range(episodes):
        print('Episode', episode)
        doom.new_game()
        done = False
        step = 0

        while not done:
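            # Observe the current state, let the agent choose an action, apply it,
            # and remember the transition; training runs once the episode is done.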
            state = doom.get_current_state()
            action_index = doomguy.act(state)
            next_state, reward, done = doom.step(action_index)
            doomguy.remember(state, action_index, reward, next_state, done)

            step += 1

        loss = doomguy.train()
        doomguy.reset_memory()
        print('Total steps: {}, loss was: {}'.format(step, loss))

if show_results:
    doom = VizdoomWrapper(config_path=config_path,
                          reward_table=reward_table,
                          frame_resolution=resolution,
                          show_mode=True,
Example #4
class DQN(object):
    def __init__(self,
                 scenario_tag=None,
                 run_id_string=None,
                 network_type="networks.DQNNet",
                 write_summaries=True,
                 tf_logdir="tensorboard_logs",
                 epochs=100,
                 train_steps_per_epoch=1000000,
                 test_episodes_per_epoch=100,
                 run_tests=True,
                 initial_epsilon=1.0,
                 final_epsilon=0.0000,
                 epsilon_decay_steps=10e07,
                 epsilon_decay_start_step=2e05,
                 frozen_steps=5000,
                 batchsize=32,
                 memory_capacity=10000,
                 update_pattern=(4, 4),
                 prioritized_memory=False,
                 enable_progress_bar=True,
                 save_interval=1,
                 writer_max_queue=10,
                 writer_flush_secs=120,
                 dynamic_frameskips=None,
                 **settings):

        if prioritized_memory:
            # TODO maybe some day ...
            raise NotImplementedError("Prioritized memory not implemented. Maybe some day.")

        if dynamic_frameskips:
            if isinstance(dynamic_frameskips, (list, tuple)):
                self.frameskips = list(dynamic_frameskips)
            elif isinstance(dynamic_frameskips, int):
                self.frameskips = list(range(1, dynamic_frameskips + 1))
        else:
            self.frameskips = [None]
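        # e.g. dynamic_frameskips=4 -> [1, 2, 3, 4]; dynamic_frameskips=(2, 4, 8) -> [2, 4, 8];
        # [None] presumably lets VizdoomWrapper fall back to its default frameskip.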

        self.update_pattern = update_pattern
        self.write_summaries = write_summaries
        self._settings = settings
        self.run_id_string = run_id_string
        self.train_steps_per_epoch = train_steps_per_epoch
        self._run_tests = test_episodes_per_epoch > 0 and run_tests
        self.test_episodes_per_epoch = test_episodes_per_epoch
        self._epochs = np.float32(epochs)

        self.doom_wrapper = VizdoomWrapper(**settings)
        misc_len = self.doom_wrapper.misc_len
        img_shape = self.doom_wrapper.img_shape
        self.use_misc = self.doom_wrapper.use_misc
        self.actions_num = self.doom_wrapper.actions_num
        self.replay_memory = ReplayMemory(img_shape, misc_len, batch_size=batchsize, capacity=memory_capacity)
        self.network = eval(network_type)(actions_num=self.actions_num * len(self.frameskips), img_shape=img_shape,
                                          misc_len=misc_len,
                                          **settings)

        self.batchsize = batchsize
        self.frozen_steps = frozen_steps

        self.save_interval = save_interval

        self._model_savefile = settings["models_path"] + "/" + self.run_id_string
        # TODO: move summaries somewhere so they are consistent between DQN and the async learners
        if self.write_summaries:
            assert tf_logdir is not None
            if not os.path.isdir(tf_logdir):
                os.makedirs(tf_logdir)

            self.scores_placeholder, summaries = setup_vector_summaries(scenario_tag + "/scores")
            self._summaries = tf.summary.merge(summaries)
            self._train_writer = tf.summary.FileWriter("{}/{}/{}".format(tf_logdir, self.run_id_string, "train"),
                                                       flush_secs=writer_flush_secs, max_queue=writer_max_queue)
            self._test_writer = tf.summary.FileWriter("{}/{}/{}".format(tf_logdir, self.run_id_string, "test"),
                                                      flush_secs=writer_flush_secs, max_queue=writer_max_queue)
        else:
            self._train_writer = None
            self._test_writer = None
            self._summaries = None
        self.steps = 0
        # TODO epoch as tf variable?
        self._epoch = 1

        # Epsilon
        self.epsilon_decay_rate = (initial_epsilon - final_epsilon) / epsilon_decay_steps
        self.epsilon_decay_start_step = epsilon_decay_start_step
        self.initial_epsilon = initial_epsilon
        self.final_epsilon = final_epsilon

        self.enable_progress_bar = enable_progress_bar

    def get_current_epsilon(self):
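        # Linear decay: epsilon drops by epsilon_decay_rate per step once past
        # epsilon_decay_start_step, clipped to the range [final_epsilon, 1.0].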
        eps = self.initial_epsilon - (self.steps - self.epsilon_decay_start_step) * self.epsilon_decay_rate
        return np.clip(eps, self.final_epsilon, 1.0)

    def get_action_and_frameskip(self, ai):
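        # The network outputs actions_num * len(frameskips) combined indices;
        # e.g. with actions_num=3 and frameskips=[1, 2], ai=4 decodes to
        # action 1 with frameskip 2.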
        action = ai % self.actions_num
        frameskip = self.frameskips[ai // self.actions_num]
        return action, frameskip

    @staticmethod
    def print_epoch_log(prefix, scores, steps, epoch_time):
        mean_score = np.mean(scores)
        score_std = np.std(scores)
        min_score = np.min(scores)
        max_score = np.max(scores)
        episodes = len(scores)

        steps_per_sec = steps / epoch_time
        mil_steps_per_hour = steps_per_sec * 3600 / 1000000.0
        log(
            "{}: Episodes: {}, mean: {}, min: {}, max: {}, "
            " Speed: {:.0f} STEPS/s, {:.2f}M STEPS/hour, time: {}".format(
                prefix,
                episodes,
                green("{:0.3f}±{:0.2f}".format(mean_score, score_std)),
                red("{:0.3f}".format(min_score)),
                blue("{:0.3f}".format(max_score)),
                steps_per_sec,
                mil_steps_per_hour,
                sec_to_str(epoch_time)
            ))

    def save_model(self, session, savefile=None):
        if savefile is None:
            savefile = self._model_savefile
        savedir = os.path.dirname(savefile)
        if not os.path.exists(savedir):
            log("Creating directory: {}".format(savedir))
            os.makedirs(savedir)
        log("Saving model to: {}".format(savefile))
        saver = tf.train.Saver()
        saver.save(session, savefile)

    def load_model(self, session, savefile):
        saver = tf.train.Saver()
        log("Loading model from: {}".format(savefile))
        saver.restore(session, savefile)
        log("Loaded model.")

    def train(self, session):

        # Prefill replay memory:
        for _ in trange(self.replay_memory.capacity, desc="Filling replay memory",
                        leave=False, disable=not self.enable_progress_bar, file=sys.stdout):
            if self.doom_wrapper.is_terminal():
                self.doom_wrapper.reset()
            s1 = self.doom_wrapper.get_current_state()
            action_frameskip_index = randint(0, self.actions_num * len(self.frameskips) - 1)
            action_index, frameskip = self.get_action_and_frameskip(action_frameskip_index)
            reward = self.doom_wrapper.make_action(action_index, frameskip)
            terminal = self.doom_wrapper.is_terminal()
            s2 = self.doom_wrapper.get_current_state()
            self.replay_memory.add_transition(s1, action_frameskip_index, s2, reward, terminal)

        overall_start_time = time()
        self.network.update_target_network(session)

        log(green("Starting training.\n"))
        while self._epoch <= self._epochs:
            self.doom_wrapper.reset()
            train_scores = []
            test_scores = []
            train_start_time = time()

            for _ in trange(self.train_steps_per_epoch, desc="Training, epoch {}".format(self._epoch),
                            leave=False, disable=not self.enable_progress_bar, file=sys.stdout):
                self.steps += 1
                s1 = self.doom_wrapper.get_current_state()

                if random() <= self.get_current_epsilon():
                    action_frameskip_index = randint(0, self.actions_num * len(self.frameskips) - 1)
                    action_index, frameskip = self.get_action_and_frameskip(action_frameskip_index)
                else:
                    action_frameskip_index = self.network.get_action(session, s1)
                    action_index, frameskip = self.get_action_and_frameskip(action_frameskip_index)

                reward = self.doom_wrapper.make_action(action_index, frameskip)
                terminal = self.doom_wrapper.is_terminal()
                s2 = self.doom_wrapper.get_current_state()
                self.replay_memory.add_transition(s1, action_frameskip_index, s2, reward, terminal)

                if self.steps % self.update_pattern[0] == 0:
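                    # update_pattern = (n, k): every n environment steps,
                    # run k mini-batch updates sampled from replay memory.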
                    for _ in range(self.update_pattern[1]):
                        self.network.train_batch(session, self.replay_memory.get_sample())

                if terminal:
                    train_scores.append(self.doom_wrapper.get_total_reward())
                    self.doom_wrapper.reset()
                if self.steps % self.frozen_steps == 0:
                    self.network.update_target_network(session)

            train_time = time() - train_start_time

            log("Epoch {}".format(self._epoch))
            log("Training steps: {}, epsilon: {}".format(self.steps, self.get_current_epsilon()))
            self.print_epoch_log("TRAIN", train_scores, self.train_steps_per_epoch, train_time)
            test_start_time = time()
            test_steps = 0
            # TESTING
            for _ in trange(self.test_episodes_per_epoch, desc="Testing, epoch {}".format(self._epoch),
                            leave=False, disable=not self.enable_progress_bar, file=sys.stdout):
                self.doom_wrapper.reset()
                while not self.doom_wrapper.is_terminal():
                    test_steps += 1
                    state = self.doom_wrapper.get_current_state()
                    action_frameskip_index = self.network.get_action(session, state)
                    action_index, frameskip = self.get_action_and_frameskip(action_frameskip_index)
                    self.doom_wrapper.make_action(action_index, frameskip)

                test_scores.append(self.doom_wrapper.get_total_reward())

            test_time = time() - test_start_time

            self.print_epoch_log("TEST", test_scores, test_steps, test_time)

            if self.write_summaries:
                log("Writing summaries.")
                train_summary = session.run(self._summaries, {self.scores_placeholder: train_scores})
                self._train_writer.add_summary(train_summary, self.steps)
                if self._run_tests:
                    test_summary = session.run(self._summaries, {self.scores_placeholder: test_scores})
                    self._test_writer.add_summary(test_summary, self.steps)

            # Save model
            if self._epoch % self.save_interval == 0:
                self.save_model(session)

            overall_time = time() - overall_start_time
            log("Total elapsed time: {}\n".format(sec_to_str(overall_time)))
            self._epoch += 1

    def run_test_episode(self, session):
        self.doom_wrapper.reset()
        while not self.doom_wrapper.is_terminal():
            state = self.doom_wrapper.get_current_state()
            action_frameskip_index = self.network.get_action(session, state)
            action_index, frameskip = self.get_action_and_frameskip(action_frameskip_index)
            self.doom_wrapper.make_action(action_index, frameskip)
        reward = self.doom_wrapper.get_total_reward()
        return reward
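
# Minimal usage sketch (assumes a settings dict carrying models_path plus the
# VizdoomWrapper options; network_type must name a class importable as written,
# e.g. the default "networks.DQNNet"):
#
#     dqn = DQN(scenario_tag="basic", run_id_string="dqn_basic", **settings)
#     with tf.Session() as session:
#         session.run(tf.global_variables_initializer())
#         dqn.train(session)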
class A3CLearner(Thread):
    def __init__(self,
                 thread_index,
                 network_type,
                 global_steps_counter,
                 scenario_tag=None,
                 run_id_string=None,
                 session=None,
                 tf_logdir=None,
                 global_network=None,
                 optimizer=None,
                 learning_rate=None,
                 test_only=False,
                 write_summaries=True,
                 enable_progress_bar=True,
                 deterministic_testing=True,
                 save_interval=1,
                 writer_max_queue=10,
                 writer_flush_secs=120,
                 **settings):
        super(A3CLearner, self).__init__()

        log("Creating actor-learner #{}.".format(thread_index))
        self.thread_index = thread_index

        self._global_steps_counter = global_steps_counter
        self.write_summaries = write_summaries
        self.save_interval = save_interval
        self.enable_progress_bar = enable_progress_bar
        self._model_savefile = None
        self._train_writer = None
        self._test_writer = None
        self._summaries = None
        self._session = session
        self.deterministic_testing = deterministic_testing
        self.local_steps = 0
        # TODO epoch as tf variable?
        self._epoch = 1
        self.train_scores = []

        self.local_steps_per_epoch = settings["local_steps_per_epoch"]
        self._run_tests = settings["test_episodes_per_epoch"] > 0 and settings[
            "run_tests"]
        self.test_episodes_per_epoch = settings["test_episodes_per_epoch"]
        self._epochs = np.float32(settings["epochs"])
        self.max_remembered_steps = settings["max_remembered_steps"]
        self.gamma = np.float32(settings["gamma"])

        if self.write_summaries and thread_index == 0 and not test_only:
            assert tf_logdir is not None
            self.run_id_string = run_id_string
            self.tf_models_path = settings["models_path"]
            if not os.path.isdir(tf_logdir):
                os.makedirs(tf_logdir)

            if self.tf_models_path is not None:
                if not os.path.isdir(settings["models_path"]):
                    os.makedirs(settings["models_path"])

        self.doom_wrapper = VizdoomWrapper(**settings)
        misc_len = self.doom_wrapper.misc_len
        img_shape = self.doom_wrapper.img_shape
        self.use_misc = self.doom_wrapper.use_misc

        self.actions_num = self.doom_wrapper.actions_num
        # TODO add debug log
        self.local_network = eval(network_type)(actions_num=self.actions_num,
                                                img_shape=img_shape,
                                                misc_len=misc_len,
                                                thread=thread_index,
                                                **settings)

        if not test_only:
            self.learning_rate = learning_rate
            # TODO check gate_gradients != Optimizer.GATE_OP
            grads_and_vars = optimizer.compute_gradients(
                self.local_network.ops.loss,
                var_list=self.local_network.get_params())
            grads, local_vars = zip(*grads_and_vars)

            grads_and_global_vars = zip(grads, global_network.get_params())
            self.train_op = optimizer.apply_gradients(
                grads_and_global_vars, global_step=tf.train.get_global_step())

            self.global_network = global_network
            self.local_network.prepare_sync_op(global_network)

        if self.thread_index == 0 and not test_only:
            self._model_savefile = settings[
                "models_path"] + "/" + self.run_id_string

            if self.write_summaries:
                self.scores_placeholder, summaries = setup_vector_summaries(
                    scenario_tag + "/scores")
                lr_summary = tf.summary.scalar(scenario_tag + "/learning_rate",
                                               self.learning_rate)
                summaries.append(lr_summary)
                self._summaries = tf.summary.merge(summaries)
                self._train_writer = tf.summary.FileWriter(
                    "{}/{}/{}".format(tf_logdir, self.run_id_string, "train"),
                    flush_secs=writer_flush_secs,
                    max_queue=writer_max_queue)
                self._test_writer = tf.summary.FileWriter(
                    "{}/{}/{}".format(tf_logdir, self.run_id_string, "test"),
                    flush_secs=writer_flush_secs,
                    max_queue=writer_max_queue)

    @staticmethod
    def choose_action_index(policy, deterministic=False):
        if deterministic:
            return np.argmax(policy)

        r = random.random()
        cumulative_sum = 0.0
        for i, p in enumerate(policy):
            cumulative_sum += p
            if r <= cumulative_sum:
                return i

        return len(policy) - 1

    def make_training_step(self):
        states_img = []
        states_misc = []
        actions = []
        rewards_reversed = []
        values_reversed = []
        advantages = []
        Rs = []

        # TODO check how default session works
        self._session.run(self.local_network.ops.sync)

        initial_network_state = None
        if self.local_network.has_state():
            initial_network_state = self.local_network.get_current_network_state()

        terminal = None
        steps_performed = 0
        for _ in range(self.max_remembered_steps):
            steps_performed += 1
            current_img, current_misc = self.doom_wrapper.get_current_state()
            policy, state_value = self.local_network.get_policy_and_value(
                self._session, [current_img, current_misc])
            action_index = A3CLearner.choose_action_index(policy)
            values_reversed.insert(0, state_value)
            states_img.append(current_img)
            states_misc.append(current_misc)
            actions.append(action_index)
            reward = self.doom_wrapper.make_action(action_index)
            terminal = self.doom_wrapper.is_terminal()

            rewards_reversed.insert(0, reward)
            self.local_steps += 1
            if terminal:
                if self.thread_index == 0:
                    self.train_scores.append(
                        self.doom_wrapper.get_total_reward())
                self.doom_wrapper.reset()
                if self.local_network.has_state():
                    self.local_network.reset_state()
                break

        if terminal:
            R = 0.0
        else:
            R = self.local_network.get_value(
                self._session, self.doom_wrapper.get_current_state())

        for ri, Vi in zip(rewards_reversed, values_reversed):
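            # Rewards and values were stored newest-first, so this builds the
            # n-step return R and advantage A_i = R_i - V(s_i) back-to-front.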
            R = ri + self.gamma * R
            advantages.insert(0, R - Vi)
            Rs.insert(0, R)

        train_op_feed_dict = {
            self.local_network.vars.state_img: states_img,
            self.local_network.vars.a: actions,
            self.local_network.vars.advantage: advantages,
            self.local_network.vars.R: Rs
        }
        if self.use_misc:
            train_op_feed_dict[
                self.local_network.vars.state_misc] = states_misc

        if self.local_network.has_state():
            train_op_feed_dict[self.local_network.vars.
                               initial_network_state] = initial_network_state
            train_op_feed_dict[self.local_network.vars.sequence_length] = [
                len(actions)
            ]

        self._session.run(self.train_op, feed_dict=train_op_feed_dict)

        return steps_performed

    def run_episode(self, deterministic=True):
        self.doom_wrapper.reset()
        if self.local_network.has_state():
            self.local_network.reset_state()
        while not self.doom_wrapper.is_terminal():
            current_state = self.doom_wrapper.get_current_state()
            action_index = self._get_best_action(self._session,
                                                 current_state,
                                                 deterministic=deterministic)
            self.doom_wrapper.make_action(action_index)

        total_reward = self.doom_wrapper.get_total_reward()
        return total_reward

    def test(self, episodes_num=None, deterministic=True):
        if episodes_num is None:
            episodes_num = self.test_episodes_per_epoch

        test_start_time = time.time()
        test_rewards = []
        for _ in trange(episodes_num,
                        desc="Testing",
                        file=sys.stdout,
                        leave=False,
                        disable=not self.enable_progress_bar):
            total_reward = self.run_episode(deterministic=deterministic)
            test_rewards.append(total_reward)

        self.doom_wrapper.reset()
        if self.local_network.has_state():
            self.local_network.reset_state()

        test_end_time = time.time()
        test_duration = test_end_time - test_start_time
        min_score = np.min(test_rewards)
        max_score = np.max(test_rewards)
        mean_score = np.mean(test_rewards)
        score_std = np.std(test_rewards)
        log("TEST: mean: {}, min: {}, max: {}, test time: {}".format(
            green("{:0.3f}±{:0.2f}".format(mean_score, score_std)),
            red("{:0.3f}".format(min_score)),
            blue("{:0.3f}".format(max_score)), sec_to_str(test_duration)))
        return test_rewards

    def _print_train_log(self, scores, overall_start_time, last_log_time,
                         steps):
        current_time = time.time()
        mean_score = np.mean(scores)
        score_std = np.std(scores)
        min_score = np.min(scores)
        max_score = np.max(scores)

        elapsed_time = time.time() - overall_start_time
        global_steps = self._global_steps_counter.get()
        local_steps_per_sec = steps / (current_time - last_log_time)
        global_steps_per_sec = global_steps / elapsed_time
        global_mil_steps_per_hour = global_steps_per_sec * 3600 / 1000000.0
        log("TRAIN: {}(GlobalSteps), mean: {}, min: {}, max: {}, "
            " LocalSpd: {:.0f} STEPS/s GlobalSpd: "
            "{} STEPS/s, {:.2f}M STEPS/hour, total elapsed time: {}".format(
                global_steps,
                green("{:0.3f}±{:0.2f}".format(mean_score, score_std)),
                red("{:0.3f}".format(min_score)),
                blue("{:0.3f}".format(max_score)), local_steps_per_sec,
                blue("{:.0f}".format(global_steps_per_sec)),
                global_mil_steps_per_hour, sec_to_str(elapsed_time)))

    def run(self):
        # TODO this method is ugly, make it nicer
        try:
            overall_start_time = time.time()
            last_log_time = overall_start_time
            local_steps_for_log = 0
            while self._epoch <= self._epochs:
                steps = self.make_training_step()
                local_steps_for_log += steps
                global_steps = self._global_steps_counter.inc(steps)
                # Logs & tests
                if self.local_steps_per_epoch * self._epoch <= self.local_steps:
                    self._epoch += 1

                    if self.thread_index == 0:
                        self._print_train_log(self.train_scores,
                                              overall_start_time,
                                              last_log_time,
                                              local_steps_for_log)

                        if self._run_tests:
                            test_scores = self.test(
                                deterministic=self.deterministic_testing)

                        if self.write_summaries:
                            train_summary = self._session.run(
                                self._summaries,
                                {self.scores_placeholder: self.train_scores})
                            self._train_writer.add_summary(
                                train_summary, global_steps)
                            if self._run_tests:
                                test_summary = self._session.run(
                                    self._summaries,
                                    {self.scores_placeholder: test_scores})
                                self._test_writer.add_summary(
                                    test_summary, global_steps)

                        last_log_time = time.time()
                        local_steps_for_log = 0
                        log("Learning rate: {}".format(
                            self._session.run(self.learning_rate)))

                        # Saves model
                        if self._epoch % self.save_interval == 0:
                            self.save_model()
                        log("")
                    self.train_scores = []

        except (SignalException, ViZDoomUnexpectedExitException):
            threadsafe_print(
                red("Thread #{} aborting (ViZDoom killed).".format(
                    self.thread_index)))

    def run_training(self, session):
        self._session = session
        self.start()

    def save_model(self):
        savedir = os.path.dirname(self._model_savefile)
        if not os.path.exists(savedir):
            log("Creating directory: {}".format(savedir))
            os.makedirs(savedir)
        log("Saving model to: {}".format(self._model_savefile))
        saver = tf.train.Saver(self.local_network.get_params())
        saver.save(self._session, self._model_savefile)

    def load_model(self, session, savefile):
        saver = tf.train.Saver(self.local_network.get_params())
        log("Loading model from: {}".format(savefile))
        saver.restore(session, savefile)
        log("Loaded model.")

    def _get_best_action(self, sess, state, deterministic=True):
        policy = self.local_network.get_policy(sess, state)
        action_index = self.choose_action_index(policy,
                                                deterministic=deterministic)
        return action_index