Example #1
 def load_weights(self, checkpoint):
     if not os.path.exists(checkpoint):
         raise ValueError('Checkpoint path does not exist: %s' %
                          checkpoint)
     if tf.gfile.IsDirectory(checkpoint):
         checkpoint = tf.train.latest_checkpoint(checkpoint)
     self._saver.restore(self.sess, save_path=checkpoint)
     self.step = self.sess.run(self._obs_counter)
     self.episode = self.sess.run(self._ep_counter)
     logger.info('Checkpoint has been restored from: %s' % checkpoint)
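
The restore above assumes the agent created persistent step/episode counters at graph-build time. A minimal sketch of how such counters might be defined (TF 1.x; only the counter and saver names come from the snippet, everything else is an assumption):

    import tensorflow as tf

    # Hypothetical setup: the counters are non-trainable scalar variables,
    # so they are stored in checkpoints but never touched by the optimizer.
    obs_counter = tf.get_variable('obs_counter', shape=(), dtype=tf.int64,
                                  initializer=tf.zeros_initializer(),
                                  trainable=False)
    ep_counter = tf.get_variable('ep_counter', shape=(), dtype=tf.int64,
                                 initializer=tf.zeros_initializer(),
                                 trainable=False)
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(obs_counter))  # -> 0
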
Example #2
 def save_weights(self, path, model_name='model.ckpt'):
     if not os.path.exists(path):
         os.makedirs(path)
     self.sess.run(self._obs_counter.assign(self.step))
     self.sess.run(self._ep_counter.assign(self.episode))
     self._saver.save(self.sess,
                      os.path.join(path, model_name),
                      global_step=self.global_step)
     logger.info('Checkpoint has been saved to: %s' %
                 os.path.join(path, model_name))
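
Because save_weights passes global_step, the saver suffixes each checkpoint file with the step (e.g. model.ckpt-1000), which is why load_weights in Example #1 resolves a directory through tf.train.latest_checkpoint. A minimal round-trip sketch under that assumption (TF 1.x, paths hypothetical):

    import os
    import tensorflow as tf

    v = tf.get_variable('v', shape=(), initializer=tf.zeros_initializer())
    saver = tf.train.Saver()
    if not os.path.exists('logs'):
        os.makedirs('logs')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Writes logs/model.ckpt-1000; latest_checkpoint finds it by step.
        saver.save(sess, os.path.join('logs', 'model.ckpt'), global_step=1000)
        saver.restore(sess, tf.train.latest_checkpoint('logs'))
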
Example #3
    def train(self):
        """Starts training."""
        writer = tf.summary.FileWriter(self.logdir, self.agent.sess.graph)
        threads = []
        stats = []
        for uid, env in enumerate(self.thread_envs):
            stat = Stats(self.agent)
            stats.append(stat)
            t = Thread(target=self.train_thread, args=(uid, env, stat))
            t.daemon = True
            t.start()
            threads.append(t)
        self.request_stop = False
        last_log_time = time.time()
        try:
            while self.agent.step < self.maxsteps:
                # If shared batch is ready, perform gradient step
                if len(self.shared_batch['thread_ready']) >= len(
                        self.thread_envs):
                    self.agent.train_on_batch(
                        obs=np.asarray(self.shared_batch['obs']),
                        actions=np.asarray(self.shared_batch['actions']),
                        rewards=np.asarray(self.shared_batch['rewards']),
                        term=np.asarray(self.shared_batch['term']),
                        obs_next=np.asarray(self.shared_batch['obs_next']),
                        traj_ends=np.asarray(self.shared_batch['traj_ends']),
                        lr=self.lr_schedule.value(self.agent.step),
                        summarize=False)
                    self.shared_batch = self._clear_batch()

                if time.time() - last_log_time >= self.logfreq:
                    last_log_time = time.time()
                    flush_stats(stats,
                                name="%s Train" % self.agent.name,
                                maxsteps=self.maxsteps,
                                writer=writer)
                    self.agent.save_weights(self.logdir)
                    self.agent.test(self.test_env,
                                    self.test_episodes,
                                    max_steps=self.test_maxsteps,
                                    render=self.test_render,
                                    writer=writer)
                    writer.flush()
                if self.render:
                    [env.render() for env in self.thread_envs]
                time.sleep(0.01)
        except KeyboardInterrupt:
            logger.info('Caught Ctrl+C! Stopping training process.')
        self.request_stop = True
        logger.info('Saving progress & performing evaluation.')
        self.agent.save_weights(self.logdir)
        self.agent.test(self.test_env,
                        self.test_episodes,
                        render=self.test_render)
        [t.join() for t in threads]
        logger.info('Training finished!')
        writer.close()
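
This trainer shuts down cooperatively: the workers are daemon threads that poll request_stop instead of being killed, so the main thread can set the flag and join them cleanly. A stripped-down sketch of that idiom (all names hypothetical):

    import time
    from threading import Thread

    class Trainer(object):
        def __init__(self):
            self.request_stop = False

        def train_thread(self, uid):
            # One environment step per iteration would go here.
            while not self.request_stop:
                time.sleep(0.01)

    trainer = Trainer()
    threads = [Thread(target=trainer.train_thread, args=(i,))
               for i in range(4)]
    for t in threads:
        t.daemon = True
        t.start()
    time.sleep(0.1)
    trainer.request_stop = True
    [t.join() for t in threads]
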
Example #4
 def train(self):
     """Starts training."""
     writer = tf.summary.FileWriter(self.logdir, self.agent.sess.graph)
     threads = []
     for thread_agent, sync, stats in zip(self.thread_agents, self.sync_ops,
                                          self.thread_stats):
         thread_agent.sess = self.agent.sess
         t = Thread(target=self.train_thread,
                    args=(thread_agent, sync, stats))
         t.daemon = True
         t.start()
         threads.append(t)
     self.request_stop = False
     last_log_time = time.time()
     try:
         while self.agent.step < self.maxsteps:
             if time.time() - last_log_time >= self.logfreq:
                 last_log_time = time.time()
                 flush_stats(self.thread_stats,
                             name="%s Thread" % self.agent.name,
                             maxsteps=self.maxsteps,
                             writer=writer)
                 self.agent.save_weights(self.logdir)
                 self.agent.test(self.test_env,
                                 self.test_episodes,
                                 max_steps=self.test_maxsteps,
                                 render=self.test_render,
                                 writer=writer)
                 writer.flush()
             if self.render:
                 [agent.env.render() for agent in self.thread_agents]
             time.sleep(0.01)
     except KeyboardInterrupt:
         logger.info('Caught Ctrl+C! Stopping training process.')
     self.request_stop = True
     logger.info('Saving progress & performing evaluation.')
     self.agent.save_weights(self.logdir)
     self.agent.test(self.test_env,
                     self.test_episodes,
                     render=self.test_render)
     [t.join() for t in threads]
     logger.info('Training finished!')
     writer.close()
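
Here every thread agent shares the global session and is paired with a sync op. The snippet does not show those ops; a plausible sketch (TF 1.x, function name hypothetical) is a group of assigns that copies the global network's weights into a worker's local variables:

    import tensorflow as tf

    def make_sync_op(global_vars, local_vars):
        # Copy each global variable into its local counterpart.
        return tf.group(*[local_v.assign(global_v)
                          for global_v, local_v in zip(global_vars,
                                                       local_vars)])

    global_w = tf.get_variable('global_w', shape=(2,))
    local_w = tf.get_variable('local_w', shape=(2,))
    sync = make_sync_op([global_w], [local_w])
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(sync)  # local_w now holds global_w's values
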
Example #5
    def train(self):
        """Starts training."""
        try:
            lr_schedule = Schedule.create(self.lr_schedule, self.agent.opt.lr,
                                          self.maxsteps)
            writer = tf.summary.FileWriter(self.logdir, self.agent.sess.graph)
            t = Thread(target=self.collect_replay,
                       args=(self.maxsteps, self.agent, self.replay,
                             self.train_stats, self.render))
            t.daemon = True
            t.start()
            while self.agent.step < self.maxsteps:
                if not self.replay.is_ready:
                    logger.info("Filling replay to minimum size %d/%d." %
                                (self.replay.size, self.replay.min_size))
                    time.sleep(2)
                    continue

                (obs, actions, rewards, term, obs_next,
                 ends, idxs, importance) = self.runner.sample()
                # TODO info and lr (take from train on batch dict?)
                lr = lr_schedule.value(self.agent.step)
                self.perform_stats.add(actions, rewards, term, {})
                summarize = time.time() - self._last_log_time > self.logfreq
                res = self.agent.train_on_batch(obs=obs,
                                                actions=actions,
                                                rewards=rewards,
                                                term=term,
                                                obs_next=obs_next,
                                                traj_ends=ends,
                                                lr=lr,
                                                summarize=summarize,
                                                importance=importance)

                if isinstance(self.replay, ProportionalReplay):
                    # TODO value methods
                    self.replay.update(
                        idxs,
                        np.abs(
                            np.sum(res['value'] * actions, 1) - res['target']))

                if summarize:
                    self._last_log_time = time.time()
                    self.agent.save_weights(self.logdir)
                    flush_stats(self.perform_stats,
                                "%s Performance" % self.agent.name,
                                log_progress=False,
                                log_rewards=False,
                                log_hyperparams=False,
                                writer=writer)
                    flush_stats(self.train_stats,
                                "%s Train" % self.agent.name,
                                log_performance=False,
                                log_hyperparams=False,
                                maxsteps=self.maxsteps,
                                writer=writer)
                    self.agent.test(self.test_env,
                                    self.test_episodes,
                                    max_steps=self.test_maxsteps,
                                    render=self.test_render,
                                    writer=writer)
                    if self.logdir and 'summary' in res:
                        writer.add_summary(res['summary'],
                                           global_step=self.agent.step)
                    writer.flush()

            logger.info('Performing final evaluation.')
            self.agent.test(self.test_env,
                            self.test_episodes,
                            max_steps=self.test_maxsteps,
                            render=self.test_render)
            writer.close()
            logger.info('Training finished.')
        except KeyboardInterrupt:
            logger.info('Stopping training process...')
        self.agent.save_weights(self.logdir)
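
The ProportionalReplay update above turns the TD error into a priority: with one-hot actions, np.sum(res['value'] * actions, 1) picks out the Q-value of the action actually taken. A worked toy illustration (numbers invented):

    import numpy as np

    value = np.array([[1.0, 2.0],
                      [0.5, 3.0]])    # per-action Q-values, two samples
    actions = np.array([[0, 1],
                        [1, 0]])      # one-hot taken actions
    target = np.array([1.5, 1.0])     # bootstrapped targets
    priority = np.abs(np.sum(value * actions, 1) - target)
    print(priority)                   # -> [0.5 0.5]
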
Example #6
    def test(self,
             env,
             episodes,
             max_steps=1e5,
             render=False,
             max_fps=None,
             writer=None):
        """Tests agent's performance on a given number of episodes.

        Args:
            env (gym.Env): Test environment.
            episodes (int): Number of episodes.
            max_steps (int): Maximum allowed steps per episode.
            render (bool): Enables game screen rendering.
            max_fps (int): Maximum allowed fps; pass None to disable the limit.
            writer (FileWriter): TensorBoard summary writer.

        Returns (utils.RewardStats): Reward statistics over the test episodes.
        """
        if env is not None:
            self.test_env = env
        elif self.test_env is None:
            logger.warning("Testing environment is not provided. "
                           "Using a copy of the training env for testing.")
            self.test_env = copy.deepcopy(self.env)
        stats = Stats(agent=self)
        delta_frame = 1. / max_fps if max_fps else 0
        step_counter = 0
        episode_counter = 0
        max_steps = int(max_steps)
        for _ in range(episodes):
            obs = self.test_env.reset()
            for i in range(max_steps):
                start_time = time.time()
                action = self.act(obs)
                obs, r, terminal, info = self.test_env.step(action)
                step_limit = i >= max_steps - 1
                terminal = terminal or step_limit
                if step_limit:
                    logger.info("Interrupting test episode due to the "
                                "maximum allowed number of steps (%d)." %
                                max_steps)
                step_counter += 1
                episode_counter += terminal
                stats.add(action, r, terminal, info)
                if render:
                    self.test_env.render()
                    if delta_frame > 0:
                        delay = max(0,
                                    delta_frame - (time.time() - start_time))
                        time.sleep(delay)
                if terminal:
                    # TODO: Check for atari life lost
                    break
        reward_stats = copy.deepcopy(stats.reward_stats)
        flush_stats(stats,
                    log_progress=False,
                    log_performance=False,
                    log_hyperparams=False,
                    name='%s Test' % self.name,
                    writer=writer)
        return reward_stats
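
The max_fps cap works by sleeping off whatever remains of the per-frame time budget after the step and render. A minimal standalone sketch of the same arithmetic (the 60 fps value is invented):

    import time

    max_fps = 60
    delta_frame = 1. / max_fps
    start_time = time.time()
    # ... env.step(action) and env.render() would happen here ...
    time.sleep(max(0, delta_frame - (time.time() - start_time)))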