def main_train(tf_configs=None):
    """Build the shared A3C graph and train every agent in its own thread.

    Creates ``cfg.model_path`` if it does not exist, builds the ``'global'``
    ``network.ACNetwork`` and ``cfg.AGENTS_NUM`` ``agent.Agent`` workers (each
    with its own ``DoomGame``). It then either restores the latest checkpoint
    (when the module-level ``load_model`` flag is truthy) or initialises all
    variables. Each worker's ``train_a3c`` runs in a separate thread, and this
    call blocks until the ``tf.train.Coordinator`` has joined all of them.

    Reads the module-level settings ``load_model``, ``max_episode_length`` and
    ``gamma``, which are defined outside this function.

    Args:
        tf_configs: Session configuration forwarded as
            ``tf.Session(config=...)``; ``None`` uses TensorFlow's defaults.

    Raises:
        OSError: If ``load_model`` is set but ``cfg.model_path`` holds no
            checkpoint to restore.
    """
    s_t = time.time()

    tf.reset_default_graph()

    if not os.path.exists(cfg.model_path):
        os.makedirs(cfg.model_path)

    # Episode counter shared by all workers; excluded from gradient updates.
    global_episodes = tf.Variable(0,
                                  dtype=tf.int32,
                                  name='global_episodes',
                                  trainable=False)
    with tf.device("/gpu:0"):
        optimizer = tf.train.RMSPropOptimizer(learning_rate=1e-5)
        # Built for its side effect of adding the 'global' network to the
        # graph; the Python handle is not used again in this function.
        global_network = network.ACNetwork('global',
                                           optimizer,
                                           img_shape=cfg.IMG_SHAPE)
        # One worker per configured agent, each with its own game instance.
        agents = [
            agent.Agent(game=DoomGame(),
                        name=i,
                        optimizer=optimizer,
                        model_path=cfg.model_path,
                        global_episodes=global_episodes)
            for i in range(cfg.AGENTS_NUM)
        ]
    saver = tf.train.Saver(max_to_keep=100)

    with tf.Session(config=tf_configs) as sess:
        coord = tf.train.Coordinator()
        if load_model:
            print('Loading Model...')
            ckpt = tf.train.get_checkpoint_state(cfg.model_path)
            # get_checkpoint_state returns None when no checkpoint exists;
            # report that clearly instead of failing with an AttributeError.
            if ckpt is None or not ckpt.model_checkpoint_path:
                raise OSError(
                    "No checkpoint found in {}".format(cfg.model_path))
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            sess.run(tf.global_variables_initializer())

        # Start each worker's training loop in a separate thread. The agent is
        # bound through Thread(target=..., args=...) rather than a lambda: a
        # closure would look up the loop variable late and could race with
        # the next iteration, so two threads might train the same agent.
        worker_threads = []
        for ag in agents:
            t = threading.Thread(target=ag.train_a3c,
                                 args=(max_episode_length, gamma, sess,
                                       coord, saver))
            t.start()
            # Stagger worker start-up.
            time.sleep(0.5)
            worker_threads.append(t)
        coord.join(worker_threads)
    print("training ends, costs{}".format(time.time() - s_t))
    def __init__(self, game, name, s_size, a_size, optimizer=None, model_path=None, global_episodes=None, play=False):
        """Set up one worker for the ``defend_the_center`` scenario.

        Builds the worker's local ``network.ACNetwork``. When training, it also
        builds the op that increments the shared episode counter, a summary
        writer, and the op that copies parameters from the ``'global'`` scope.
        Finally it configures and initialises ``game``.

        Args:
            game: ``DoomGame`` instance to configure; it becomes ``self.env``.
            name: Worker index; the worker's name is ``"worker_<name>"``.
            s_size: Stored unchanged as ``self.s_size``.
            a_size: Stored unchanged as ``self.a_size``.
            optimizer: Passed to the local ``network.ACNetwork``; also kept
                as ``self.trainer`` when training.
            model_path: Kept as ``self.model_path`` when training.
            global_episodes: Shared counter variable; ``assign_add`` is called
                on it when ``play`` is False, so it is required then.
            play: If True, skip all training-only setup, show the game window
                and cap the game at 35 tics per second.

        Raises:
            TypeError: If ``game`` is not a ``DoomGame``.
        """
        # Fail fast, before any TF ops or summary directories are created.
        if not isinstance(game, DoomGame):
            raise TypeError("game must be a DoomGame, got {}".format(
                type(game).__name__))

        self.s_size = s_size
        self.a_size = a_size

        self.summary_step = 3

        self.name = "worker_" + str(name)
        self.number = name

        # Per-episode statistics (presumably appended to by the training
        # loop, which is not shown here).
        self.episode_reward = []
        # NOTE(review): looks like a typo'd duplicate of episode_health; kept
        # because other code may still read it.
        self.episode_episode_health = []
        self.episode_lengths = []
        self.episode_mean_values = []
        self.episode_health = []
        self.episode_kills = []

        # Create the local copy of the network; training additionally needs
        # the op that copies the 'global' parameters into it.
        if not play:
            self.model_path = model_path
            self.trainer = optimizer
            self.global_episodes = global_episodes
            self.increment = self.global_episodes.assign_add(1)
            self.local_AC_network = network.ACNetwork(self.name, optimizer, play=play)
            self.summary_writer = tf.summary.FileWriter("./summaries/defend_the_center/agent_%s" % str(self.number))
            self.update_local_ops = tf.group(*utils.update_target_graph('global', self.name))
        else:
            self.local_AC_network = network.ACNetwork(self.name, optimizer, play=play)

        # Configure the caller's game instead of discarding it for a fresh
        # DoomGame.
        game.load_config("../scenarios/defend_the_center.cfg")
        game.set_screen_resolution(ScreenResolution.RES_640X480)
        game.set_screen_format(ScreenFormat.RGB24)
        game.set_render_hud(False)
        game.set_render_crosshair(False)
        game.set_render_weapon(True)
        game.set_render_decals(False)
        game.set_render_particles(False)
        # Enables labeling of the in game objects.
        game.set_labels_buffer_enabled(True)
        game.add_available_button(Button.TURN_LEFT)
        game.add_available_button(Button.TURN_RIGHT)
        game.add_available_button(Button.ATTACK)
        game.add_available_game_variable(GameVariable.USER1)
        game.set_episode_timeout(2100)
        game.set_episode_start_time(5)
        game.set_window_visible(play)
        game.set_sound_enabled(False)
        game.set_living_reward(0)
        game.set_mode(Mode.PLAYER)
        if play:
            # Cap the game at 35 tics per second while it is being watched.
            game.set_ticrate(35)
        game.init()
        self.env = game
        self.actions = self.button_combinations()
Example #3
0
    def __init__(self,
                 game,
                 name,
                 optimizer=None,
                 model_path=None,
                 global_episodes=None,
                 play=False,
                 task_name='healthpack_simple'):
        """Set up one worker for the health-gathering scenario.

        Builds the worker's local ``network.ACNetwork``. When training, it also
        builds the episode-counter increment op, a summary writer, and the op
        that copies parameters from ``<task_name>/global`` into
        ``<task_name>/worker_<name>``. It then configures ``game`` with
        ``health_gathering.wad`` (or ``health_gathering_supreme.wad`` when
        ``cfg.IS_SUPREME_VERSION`` is set), initialises it and enumerates the
        discrete action set.

        Args:
            game: ``DoomGame`` instance to configure; it becomes ``self.env``.
            name: Worker index; the worker's name is ``"worker_<name>"``.
            optimizer: Passed to the local ``network.ACNetwork``; also kept
                as ``self.trainer`` when training.
            model_path: Kept as ``self.model_path`` when training.
            global_episodes: Shared counter variable; ``assign_add`` is called
                on it when ``play`` is False, so it is required then.
            play: If True, skip training-only setup, show the game window,
                add the ``+viz_render_all 1`` game argument and cap the game
                at 35 tics per second.
            task_name: Prefix of the TF scopes used by the weight-sync op.

        Raises:
            TypeError: If ``game`` is not a ``DoomGame``.
        """
        # Imported locally so this does not rely on a module-level alias that
        # shadows the builtin ``iter``.
        import itertools

        # Fail fast, before any TF ops or summary directories are created.
        if not isinstance(game, DoomGame):
            raise TypeError("game must be a DoomGame, got {}".format(
                type(game).__name__))

        self.task_name = task_name

        self.summary_step = 3

        self.name = "worker_" + str(name)
        self.number = name

        # Health total from the previous step; starts at 100.
        self.last_total_health = 100.
        self.img_shape = cfg.IMG_SHAPE

        # Per-episode statistics (presumably appended to by the training
        # loop, which is not shown here).
        self.episode_reward = []
        # NOTE(review): attribute name looks misspelled ("pickes"); kept
        # because other code may read it.
        self.episode_episode_total_pickes = []
        self.episode_lengths = []
        self.episode_mean_values = []
        self.episode_health = []

        # Create the local copy of the network and, when training, the op
        # that copies the global parameters into it.
        # NOTE(review): the local network is built under ``self.name`` while
        # the sync op targets ``task_name + '/' + self.name``; presumably the
        # caller opens a ``task_name`` variable scope -- confirm.
        if not play:
            self.model_path = model_path
            self.trainer = optimizer
            self.global_episodes = global_episodes
            self.increment = self.global_episodes.assign_add(1)
            self.local_AC_network = network.ACNetwork(self.name,
                                                      optimizer,
                                                      play=play,
                                                      img_shape=cfg.IMG_SHAPE)
            self.summary_writer = tf.summary.FileWriter(
                "./summaries/healthpack/train_health%s" % str(self.number))
            self.update_local_ops = tf.group(
                *utils.update_target_graph(self.task_name +
                                           '/global', self.task_name + '/' +
                                           self.name))
        else:
            self.local_AC_network = network.ACNetwork(self.name,
                                                      optimizer,
                                                      play=play,
                                                      img_shape=cfg.IMG_SHAPE)

        # Configure the caller's game instead of discarding it for a fresh
        # DoomGame.
        wad_name = ('health_gathering_supreme.wad' if cfg.IS_SUPREME_VERSION
                    else 'health_gathering.wad')
        game.set_doom_scenario_path("../scenarios/{}".format(wad_name))

        game.set_doom_map("map01")
        game.set_screen_resolution(ScreenResolution.RES_640X480)
        game.set_screen_format(ScreenFormat.RGB24)
        game.set_render_hud(False)
        game.set_render_crosshair(False)
        game.set_render_weapon(True)
        game.set_render_decals(False)
        game.set_render_particles(True)
        # Enables labeling of the in game objects.
        game.set_labels_buffer_enabled(True)
        game.add_available_button(Button.TURN_LEFT)
        game.add_available_button(Button.TURN_RIGHT)
        game.add_available_button(Button.MOVE_FORWARD)
        game.add_available_game_variable(GameVariable.USER1)
        game.set_episode_timeout(2100)
        game.set_episode_start_time(5)
        game.set_window_visible(play)
        game.set_sound_enabled(False)
        game.set_living_reward(0)
        game.set_mode(Mode.PLAYER)
        if play:
            game.add_game_args("+viz_render_all 1")
            # Cap the game at 35 tics per second while it is being watched.
            game.set_ticrate(35)
        game.init()
        self.env = game
        # Every on/off combination of the buttons (in itertools.product
        # order), minus those that press TURN_LEFT and TURN_RIGHT (buttons 0
        # and 1) together. For the three buttons above this is exactly the
        # old list with [True, True, True] and [True, True, False] removed.
        self.actions = [
            list(combo)
            for combo in itertools.product(
                [False, True], repeat=game.get_available_buttons_size())
            if not (combo[0] and combo[1])
        ]
    def __init__(self,
                 game,
                 name,
                 optimizer=None,
                 model_path=None,
                 global_episodes=None,
                 play=False,
                 task_name='healthpack_simple'):
        """Set up one worker for the scenario config at ``cfg.SCENARIO_PATH``.

        Builds the worker's local ``network.ACNetwork``. When training, it also
        builds the episode-counter increment op, a summary writer, and the op
        that copies parameters from the ``'global'`` scope. It then configures
        and initialises ``game`` and takes the action set from
        ``cfg.button_combinations()``.

        Args:
            game: ``DoomGame`` instance to configure; it becomes ``self.env``.
            name: Worker index; the worker's name is
                ``cfg.AGENT_PREFIX + str(name)``.
            optimizer: Passed to the local ``network.ACNetwork``; also kept
                as ``self.trainer`` when training.
            model_path: Kept as ``self.model_path`` when training.
            global_episodes: Shared counter variable; ``assign_add`` is called
                on it when ``play`` is False, so it is required then.
            play: If True, skip training-only setup, show the game window,
                add the ``+viz_render_all 1`` game argument and cap the game
                at 35 tics per second.
            task_name: Names the summary directory
                ``./summaries/<task_name>/ag_<name>``.

        Raises:
            TypeError: If ``game`` is not a ``DoomGame``.
        """
        # Fail fast, before any TF ops or summary directories are created.
        if not isinstance(game, DoomGame):
            raise TypeError("game must be a DoomGame, got {}".format(
                type(game).__name__))

        self.task_name = task_name
        self.play = play
        self.summary_step = 3

        self.name = cfg.AGENT_PREFIX + str(name)
        self.number = name

        self.imitate_data = None

        # Totals from the previous step, with their starting values
        # (presumably used to compute per-step differences -- confirm).
        self.last_total_health = 100.
        self.last_total_kills = 0.
        self.last_total_ammos = 0.
        self.img_shape = cfg.IMG_SHAPE

        # Per-episode statistics (presumably appended to by the training
        # loop, which is not shown here).
        self.episode_reward = []
        self.episode_lengths = []
        self.episode_mean_values = []
        self.episode_health = []
        self.episode_kills = []

        if not self.play:
            self.model_path = model_path
            self.trainer = optimizer
            self.global_episodes = global_episodes
            self.increment = self.global_episodes.assign_add(1)
            self.local_AC_network = network.ACNetwork(self.name,
                                                      optimizer,
                                                      play=self.play,
                                                      img_shape=cfg.IMG_SHAPE)
            self.summary_writer = tf.summary.FileWriter(
                "./summaries/%s/ag_%s" % (self.task_name, str(self.number)))
            # create a tensorflow op to copy weights from global network regularly when training
            self.update_local_ops = tf.group(
                *utils.update_target_graph('global', self.name))
        else:
            self.local_AC_network = network.ACNetwork(self.name,
                                                      optimizer,
                                                      play=self.play,
                                                      img_shape=cfg.IMG_SHAPE)

        # Configure the caller's game instead of discarding it for a fresh
        # DoomGame.
        game.load_config(cfg.SCENARIO_PATH)
        game.set_doom_map("map01")
        game.set_screen_resolution(ScreenResolution.RES_640X480)
        game.set_screen_format(ScreenFormat.RGB24)
        game.set_render_hud(False)
        game.set_render_crosshair(False)
        game.set_render_weapon(True)
        game.set_render_decals(False)
        game.set_render_particles(True)
        # Enables labeling of the in game objects.
        game.set_labels_buffer_enabled(True)
        game.add_available_button(Button.MOVE_FORWARD)
        game.add_available_button(Button.MOVE_RIGHT)
        game.add_available_button(Button.MOVE_LEFT)
        game.add_available_button(Button.TURN_LEFT)
        game.add_available_button(Button.TURN_RIGHT)
        game.add_available_button(Button.ATTACK)
        game.add_available_button(Button.SPEED)
        game.add_available_game_variable(GameVariable.AMMO2)
        game.add_available_game_variable(GameVariable.HEALTH)
        game.add_available_game_variable(GameVariable.USER2)
        game.set_episode_timeout(2100)
        game.set_episode_start_time(5)
        game.set_window_visible(self.play)
        game.set_sound_enabled(False)
        game.set_living_reward(0)
        game.set_mode(Mode.PLAYER)
        if self.play:
            game.add_game_args("+viz_render_all 1")
            # Cap the game at 35 tics per second while it is being watched.
            game.set_ticrate(35)
        game.init()
        self.env = game
        self.actions = cfg.button_combinations()