def load_weights(self, checkpoint):
    if not os.path.exists(checkpoint):
        raise ValueError('Checkpoint path does not exist: %s' % checkpoint)
    if tf.gfile.IsDirectory(checkpoint):
        checkpoint = tf.train.latest_checkpoint(checkpoint)
    self._saver.restore(self.sess, save_path=checkpoint)
    self.step = self.sess.run(self._obs_counter)
    self.episode = self.sess.run(self._ep_counter)
    logger.info('Checkpoint has been restored from: %s' % checkpoint)
def save_weights(self, path, model_name='model.ckpt'):
    if not os.path.exists(path):
        os.makedirs(path)
    self.sess.run(self._obs_counter.assign(self.step))
    self.sess.run(self._ep_counter.assign(self.episode))
    self._saver.save(self.sess, os.path.join(path, model_name),
                     global_step=self.global_step)
    logger.info('Checkpoint has been saved to: %s'
                % os.path.join(path, model_name))
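# --- Illustration (not part of the trainer): a minimal standalone sketch of the
# checkpoint pattern used by save_weights/load_weights above. A tf.train.Saver
# writes "model.ckpt-<global_step>" files into a directory, and restoring from a
# directory resolves to the newest file via tf.train.latest_checkpoint. All names
# below (ckpt_dir, counter, the demo function itself) are hypothetical.
import os
import tensorflow as tf

def _checkpoint_roundtrip_demo(ckpt_dir='/tmp/ckpt_demo'):
    counter = tf.get_variable('obs_counter_demo', initializer=0, trainable=False)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(counter.assign(42))
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)
        saver.save(sess, os.path.join(ckpt_dir, 'model.ckpt'), global_step=42)
    with tf.Session() as sess:
        path = ckpt_dir
        if tf.gfile.IsDirectory(path):
            # A directory resolves to its most recent model.ckpt-* checkpoint.
            path = tf.train.latest_checkpoint(path)
        saver.restore(sess, path)
        print(sess.run(counter))  # -> 42, the persisted counter value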
def train(self): """Starts training.""" writer = tf.summary.FileWriter(self.logdir, self.agent.sess.graph) threads = [] stats = [] for uid, env in enumerate(self.thread_envs): stat = Stats(self.agent) stats.append(stat) t = Thread(target=self.train_thread, args=(uid, env, stat)) t.daemon = True t.start() threads.append(t) self.request_stop = False last_log_time = time.time() try: while self.agent.step < self.maxsteps: # If shared batch is ready, perform gradient step if len(self.shared_batch['thread_ready']) >= len( self.thread_envs): self.agent.train_on_batch( obs=np.asarray(self.shared_batch['obs']), actions=np.asarray(self.shared_batch['actions']), rewards=np.asarray(self.shared_batch['rewards']), term=np.asarray(self.shared_batch['term']), obs_next=np.asarray(self.shared_batch['obs_next']), traj_ends=np.asarray(self.shared_batch['traj_ends']), lr=self.lr_schedule.value(self.agent.step), summarize=False) self.shared_batch = self._clear_batch() if time.time() - last_log_time >= self.logfreq: last_log_time = time.time() flush_stats(stats, name="%s Train" % self.agent.name, maxsteps=self.maxsteps, writer=writer) self.agent.save_weights(self.logdir) self.agent.test(self.test_env, self.test_episodes, max_steps=self.test_maxsteps, render=self.test_render, writer=writer) writer.flush() if self.render: [env.render() for env in self.thread_envs] time.sleep(0.01) except KeyboardInterrupt: logger.info('Caught Ctrl+C! Stopping training process.') self.request_stop = True logger.info('Saving progress & performing evaluation.') self.agent.save_weights(self.logdir) self.agent.test(self.test_env, self.test_episodes, render=self.test_render) [t.join() for t in threads] logger.info('Training finished!') writer.close()
def train(self): """Starts training.""" writer = tf.summary.FileWriter(self.logdir, self.agent.sess.graph) threads = [] for thread_agent, sync, stats in zip(self.thread_agents, self.sync_ops, self.thread_stats): thread_agent.sess = self.agent.sess t = Thread(target=self.train_thread, args=(thread_agent, sync, stats)) t.daemon = True t.start() threads.append(t) self.request_stop = False last_log_time = time.time() try: while self.agent.step < self.maxsteps: if time.time() - last_log_time >= self.logfreq: last_log_time = time.time() flush_stats(self.thread_stats, name="%s Thread" % self.agent.name, maxsteps=self.maxsteps, writer=writer) self.agent.save_weights(self.logdir) self.agent.test(self.test_env, self.test_episodes, max_steps=self.test_maxsteps, render=self.test_render, writer=writer) writer.flush() if self.render: [agent.env.render() for agent in self.thread_agents] time.sleep(0.01) except KeyboardInterrupt: logger.info('Caught Ctrl+C! Stopping training process.') self.request_stop = True logger.info('Saving progress & performing evaluation.') self.agent.save_weights(self.logdir) self.agent.test(self.test_env, self.test_episodes, render=self.test_render) [t.join() for t in threads] logger.info('Training finished!') writer.close()
def train(self): """Starts training.""" try: lr_schedule = Schedule.create(self.lr_schedule, self.agent.opt.lr, self.maxsteps) writer = tf.summary.FileWriter(self.logdir, self.agent.sess.graph) t = Thread(target=self.collect_replay, args=(self.maxsteps, self.agent, self.replay, self.train_stats, self.render)) t.daemon = True t.start() while self.agent.step < self.maxsteps: if not self.replay.is_ready: logger.info("Fulfilling minimum replay size %d/%d." % (self.replay.size, self.replay.min_size)) time.sleep(2) continue obs, actions, rewards, term, obs_next, ends, idxs, importance = self.runner.sample( ) # TODO info and lr (take from train on batch dict?) lr = lr_schedule.value(self.agent.step) self.perform_stats.add(actions, rewards, term, {}) summarize = time.time() - self._last_log_time > self.logfreq res = self.agent.train_on_batch(obs=obs, actions=actions, rewards=rewards, term=term, obs_next=obs_next, traj_ends=ends, lr=lr, summarize=summarize, importance=importance) if isinstance(self.replay, ProportionalReplay): # TODO value methods self.replay.update( idxs, np.abs( np.sum(res['value'] * actions, 1) - res['target'])) if summarize: self._last_log_time = time.time() self.agent.save_weights(self.logdir) flush_stats(self.perform_stats, "%s Performance" % self.agent.name, log_progress=False, log_rewards=False, log_hyperparams=False, writer=writer) flush_stats(self.train_stats, "%s Train" % self.agent.name, log_performance=False, log_hyperparams=False, maxsteps=self.maxsteps, writer=writer) self.agent.test(self.test_env, self.test_episodes, max_steps=self.test_maxsteps, render=self.test_render, writer=writer) if self.logdir and 'summary' in res: writer.add_summary(res['summary'], global_step=self.agent.step) writer.flush() logger.info('Performing final evaluation.') self.agent.test(self.test_env, self.test_episodes, max_steps=self.test_maxsteps, render=self.test_render) writer.close() logger.info('Training finished.') except KeyboardInterrupt: logger.info('Stopping training process...') self.agent.save_weights(self.logdir)
def test(self, env, episodes, max_steps=1e5, render=False, max_fps=None,
         writer=None):
    """Tests agent's performance on a given number of episodes.

    Args:
        env (gym.Env): Test environment.
        episodes (int): Number of episodes.
        max_steps (int): Maximum allowed steps per episode.
        render (bool): Enables game screen rendering.
        max_fps (int): Maximum allowed fps. To disable fps limitation,
            pass None.
        writer (FileWriter): TensorBoard summary writer.

    Returns (utils.RewardStats): Average reward per episode.
    """
    if env is not None:
        self.test_env = env
    elif self.test_env is None:
        logger.warning("Testing environment is not provided. "
                       "Using training env as testing.")
        self.test_env = copy.deepcopy(self.env)
    stats = Stats(agent=self)
    delta_frame = 1. / max_fps if max_fps else 0
    step_counter = 0
    episode_counter = 0
    max_steps = int(max_steps)
    for _ in range(episodes):
        obs = self.test_env.reset()
        for i in range(max_steps):
            start_time = time.time()
            action = self.act(obs)
            obs, r, terminal, info = self.test_env.step(action)
            step_limit = i >= max_steps - 1
            terminal = terminal or step_limit
            if step_limit:
                logger.info("Interrupting test episode due to the "
                            "maximum allowed number of steps (%d)" % i)
            step_counter += 1
            episode_counter += terminal
            stats.add(action, r, terminal, info)
            if render:
                self.test_env.render()
            if delta_frame > 0:
                delay = max(0, delta_frame - (time.time() - start_time))
                time.sleep(delay)
            if terminal:
                # TODO: Check for atari life lost
                break
    reward_stats = copy.deepcopy(stats.reward_stats)
    flush_stats(stats, log_progress=False, log_performance=False,
                log_hyperparams=False, name='%s Test' % self.name,
                writer=writer)
    return reward_stats
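# --- Hedged usage sketch for the test() method above. The agent construction is
# hypothetical (make_agent is a placeholder, not an API of this library); only the
# test() call mirrors the signature defined above. Evaluation runs 10 episodes,
# caps each at 500 steps, and throttles rendering to at most 60 fps.
import gym

env = gym.make('CartPole-v0')
agent = make_agent(env)  # hypothetical factory: build or restore a trained agent
reward_stats = agent.test(env, episodes=10, max_steps=500,
                          render=False, max_fps=60)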