class Trainer: """Base class of trainer. Use trainer.setup(algo, env) to set up the algorithm and environment for the trainer, and trainer.train() to start training. Args: snapshot_config (garage.experiment.SnapshotConfig): The snapshot configuration used by Trainer to create the snapshotter. If None, it will create one with default settings. Note: To use TensorFlow environments, policies and algorithms, please use TFTrainer(). Examples: | # to train | trainer = Trainer() | env = Env(...) | policy = Policy(...) | algo = Algo( | env=env, | policy=policy, | ...) | trainer.setup(algo, env) | trainer.train(n_epochs=100, batch_size=4000) | # to resume immediately. | trainer = Trainer() | trainer.restore(resume_from_dir) | trainer.resume() | # to resume with modified training arguments. | trainer = Trainer() | trainer.restore(resume_from_dir) | trainer.resume(n_epochs=20) """ def __init__(self, snapshot_config): self._snapshotter = Snapshotter(snapshot_config.snapshot_dir, snapshot_config.snapshot_mode, snapshot_config.snapshot_gap) self._has_setup = False self._plot = False self._setup_args = None self._train_args = None self._stats = ExperimentStats(total_itr=0, total_env_steps=0, total_epoch=0, last_episode=None) self._algo = None self._env = None self._sampler = None self._plotter = None self._start_time = None self._itr_start_time = None self.step_itr = None self.step_episode = None # only used for off-policy algorithms self.enable_logging = True self._n_workers = None self._worker_class = None self._worker_args = None def make_sampler(self, sampler_cls, *, seed=None, n_workers=psutil.cpu_count(logical=False), max_episode_length=None, worker_class=None, sampler_args=None, worker_args=None): """Construct a Sampler from a Sampler class. Args: sampler_cls (type): The type of sampler to construct. seed (int): Seed to use in sampler workers. max_episode_length (int): Maximum episode length to be sampled by the sampler. Episodes longer than this will be truncated. n_workers (int): The number of workers the sampler should use. worker_class (type): Type of worker the Sampler should use. sampler_args (dict or None): Additional arguments that should be passed to the sampler. worker_args (dict or None): Additional arguments that should be passed to the worker. Raises: ValueError: If `max_episode_length` isn't passed and the algorithm doesn't contain a `max_episode_length` field, or if the algorithm doesn't have a policy field. Returns: sampler_cls: An instance of the sampler class.
""" policy = getattr(self._algo, 'exploration_policy', None) if policy is None: policy = getattr(self._algo, 'policy', None) if policy is None: raise ValueError('If the trainer is used to construct a sampler, ' 'the algorithm must have a `policy` or ' '`exploration_policy` field.') if max_episode_length is None: if hasattr(self._algo, 'max_episode_length'): max_episode_length = self._algo.max_episode_length if max_episode_length is None: raise ValueError('If `sampler_cls` is specified in trainer.setup, ' 'the algorithm must specify `max_episode_length`') if worker_class is None: worker_class = getattr(self._algo, 'worker_cls', DefaultWorker) if seed is None: seed = get_seed() if sampler_args is None: sampler_args = {} if worker_args is None: worker_args = {} return sampler_cls.from_worker_factory(WorkerFactory( seed=seed, max_episode_length=max_episode_length, n_workers=n_workers, worker_class=worker_class, worker_args=worker_args), agents=policy, envs=self._env) def setup(self, algo, env, sampler_cls=None, sampler_args=None, n_workers=psutil.cpu_count(logical=False), worker_class=None, worker_args=None): """Set up trainer for algorithm and environment. This method saves algo and env within trainer and creates a sampler. Note: After setup() is called all variables in session should have been initialized. setup() respects existing values in session so policy weights can be loaded before setup(). Args: algo (RLAlgorithm): An algorithm instance. env (Environment): An environment instance. sampler_cls (type): A class which implements :class:`Sampler`. sampler_args (dict): Arguments to be passed to sampler constructor. n_workers (int): The number of workers the sampler should use. worker_class (type): Type of worker the sampler should use. worker_args (dict or None): Additional arguments that should be passed to the worker. Raises: ValueError: If sampler_cls is passed and the algorithm doesn't contain a `max_episode_length` field. """ self._algo = algo self._env = env self._n_workers = n_workers self._worker_class = worker_class if sampler_args is None: sampler_args = {} if sampler_cls is None: sampler_cls = getattr(algo, 'sampler_cls', None) if worker_class is None: worker_class = getattr(algo, 'worker_cls', DefaultWorker) if worker_args is None: worker_args = {} self._worker_args = worker_args if sampler_cls is None: self._sampler = None else: self._sampler = self.make_sampler(sampler_cls, sampler_args=sampler_args, n_workers=n_workers, worker_class=worker_class, worker_args=worker_args) self._has_setup = True self._setup_args = SetupArgs(sampler_cls=sampler_cls, sampler_args=sampler_args, seed=get_seed()) def _start_worker(self): """Start Plotter and Sampler workers.""" if self._plot: # pylint: disable=import-outside-toplevel from garage.plotter import Plotter self._plotter = Plotter() self._plotter.init_plot(self.get_env_copy(), self._algo.policy) def _shutdown_worker(self): """Shutdown Plotter and Sampler workers.""" if self._sampler is not None: self._sampler.shutdown_worker() if self._plot: self._plotter.close() def obtain_episodes(self, itr, batch_size=None, agent_update=None, env_update=None): """Obtain one batch of episodes. Args: itr (int): Index of iteration (epoch). batch_size (int): Number of steps in batch. This is a hint that the sampler may or may not respect. agent_update (object): Value which will be passed into the `agent_update_fn` before doing sampling episodes. If a list is passed in, it must have length exactly `factory.n_workers`, and will be spread across the workers. 
env_update (object): Value which will be passed into the `env_update_fn` before sampling episodes. If a list is passed in, it must have length exactly `factory.n_workers`, and will be spread across the workers. Raises: ValueError: If the trainer was initialized without a sampler, or batch_size wasn't provided here or to train. Returns: EpisodeBatch: Batch of episodes. """ if self._sampler is None: raise ValueError('trainer was not initialized with `sampler_cls`. ' 'Either provide `sampler_cls` to trainer.setup, ' ' or set `algo.sampler_cls`.') if batch_size is None and self._train_args.batch_size is None: raise ValueError( 'trainer was not initialized with `batch_size`. ' 'Either provide `batch_size` to trainer.train, ' ' or pass `batch_size` to trainer.obtain_samples.') episodes = None if agent_update is None: policy = getattr(self._algo, 'exploration_policy', None) if policy is None: # This field should exist, since self.make_sampler would have # failed otherwise. policy = self._algo.policy agent_update = policy.get_param_values() episodes = self._sampler.obtain_samples( itr, (batch_size or self._train_args.batch_size), agent_update=agent_update, env_update=env_update) self._stats.total_env_steps += sum(episodes.lengths) return episodes def obtain_samples(self, itr, batch_size=None, agent_update=None, env_update=None): """Obtain one batch of samples. Args: itr (int): Index of iteration (epoch). batch_size (int): Number of steps in batch. This is a hint that the sampler may or may not respect. agent_update (object): Value which will be passed into the `agent_update_fn` before sampling episodes. If a list is passed in, it must have length exactly `factory.n_workers`, and will be spread across the workers. env_update (object): Value which will be passed into the `env_update_fn` before sampling episodes. If a list is passed in, it must have length exactly `factory.n_workers`, and will be spread across the workers. Raises: ValueError: Raised if the trainer was initialized without a sampler, or batch_size wasn't provided here or to train. Returns: list[dict]: One batch of samples. """ eps = self.obtain_episodes(itr, batch_size, agent_update, env_update) return eps.to_list() def save(self, epoch): """Save snapshot of current batch. Args: epoch (int): Epoch. Raises: NotSetupError: if save() is called before the trainer is set up. """ if not self._has_setup: raise NotSetupError('Use setup() to setup trainer before saving.') logger.log('Saving snapshot...') params = dict() # Save arguments params['setup_args'] = self._setup_args params['train_args'] = self._train_args params['stats'] = self._stats # Save states params['env'] = self._env params['algo'] = self._algo params['n_workers'] = self._n_workers params['worker_class'] = self._worker_class params['worker_args'] = self._worker_args self._snapshotter.save_itr_params(epoch, params) logger.log('Saved') def restore(self, from_dir, from_epoch='last'): """Restore experiment from snapshot. Args: from_dir (str): Directory of the pickle file to resume experiment from. from_epoch (str or int): The epoch to restore from. Can be 'first', 'last' or a number. Not applicable when snapshot_mode='last'. Returns: TrainArgs: Arguments for train(). 
""" saved = self._snapshotter.load(from_dir, from_epoch) self._setup_args = saved['setup_args'] self._train_args = saved['train_args'] self._stats = saved['stats'] set_seed(self._setup_args.seed) self.setup(env=saved['env'], algo=saved['algo'], sampler_cls=self._setup_args.sampler_cls, sampler_args=self._setup_args.sampler_args, n_workers=saved['n_workers'], worker_class=saved['worker_class'], worker_args=saved['worker_args']) n_epochs = self._train_args.n_epochs last_epoch = self._stats.total_epoch last_itr = self._stats.total_itr total_env_steps = self._stats.total_env_steps batch_size = self._train_args.batch_size store_episodes = self._train_args.store_episodes pause_for_plot = self._train_args.pause_for_plot fmt = '{:<20} {:<15}' logger.log('Restore from snapshot saved in %s' % self._snapshotter.snapshot_dir) logger.log(fmt.format('-- Train Args --', '-- Value --')) logger.log(fmt.format('n_epochs', n_epochs)) logger.log(fmt.format('last_epoch', last_epoch)) logger.log(fmt.format('batch_size', batch_size)) logger.log(fmt.format('store_episodes', store_episodes)) logger.log(fmt.format('pause_for_plot', pause_for_plot)) logger.log(fmt.format('-- Stats --', '-- Value --')) logger.log(fmt.format('last_itr', last_itr)) logger.log(fmt.format('total_env_steps', total_env_steps)) self._train_args.start_epoch = last_epoch + 1 return copy.copy(self._train_args) def log_diagnostics(self, pause_for_plot=False): """Log diagnostics. Args: pause_for_plot (bool): Pause for plot. """ logger.log('Time %.2f s' % (time.time() - self._start_time)) logger.log('EpochTime %.2f s' % (time.time() - self._itr_start_time)) tabular.record('TotalEnvSteps', self._stats.total_env_steps) logger.log(tabular) if self._plot: self._plotter.update_plot(self._algo.policy, self._algo.max_episode_length) if pause_for_plot: input('Plotting evaluation run: Press Enter to " "continue...') def train(self, n_epochs, batch_size=None, plot=False, store_episodes=False, pause_for_plot=False): """Start training. Args: n_epochs (int): Number of epochs. batch_size (int or None): Number of environment steps in one batch. plot (bool): Visualize an episode from the policy after each epoch. store_episodes (bool): Save episodes in snapshot. pause_for_plot (bool): Pause for plot. Raises: NotSetupError: If train() is called before setup(). Returns: float: The average return in last epoch cycle. """ if not self._has_setup: raise NotSetupError( 'Use setup() to setup trainer before training.') # Save arguments for restore self._train_args = TrainArgs(n_epochs=n_epochs, batch_size=batch_size, plot=plot, store_episodes=store_episodes, pause_for_plot=pause_for_plot, start_epoch=0) self._plot = plot average_return = self._algo.train(self) self._shutdown_worker() return average_return def step_epochs(self): """Step through each epoch. This function returns a magic generator. When iterated through, this generator automatically performs services such as snapshotting and log management. It is used inside train() in each algorithm. The generator initializes two variables: `self.step_itr` and `self.step_episode`. To use the generator, these two have to be updated manually in each epoch, as the example shows below. Yields: int: The next training epoch. Examples: for epoch in trainer.step_epochs(): trainer.step_episode = trainer.obtain_samples(...) self.train_once(...) 
trainer.step_itr += 1 """ self._start_worker() self._start_time = time.time() self.step_itr = self._stats.total_itr self.step_episode = None # Used by integration tests to ensure examples can run one epoch. n_epochs = int( os.environ.get('GARAGE_EXAMPLE_TEST_N_EPOCHS', self._train_args.n_epochs)) logger.log('Obtaining samples...') for epoch in range(self._train_args.start_epoch, n_epochs): self._itr_start_time = time.time() with logger.prefix('epoch #%d | ' % epoch): yield epoch save_episode = (self.step_episode if self._train_args.store_episodes else None) self._stats.last_episode = save_episode self._stats.total_epoch = epoch self._stats.total_itr = self.step_itr self.save(epoch) if self.enable_logging: self.log_diagnostics(self._train_args.pause_for_plot) logger.dump_all(self.step_itr) tabular.clear() def resume(self, n_epochs=None, batch_size=None, plot=None, store_episodes=None, pause_for_plot=None): """Resume from restored experiment. This method provides the same interface as train(). If not specified, an argument will default to the saved arguments from the last call to train(). Args: n_epochs (int): Number of epochs. batch_size (int): Number of environment steps in one batch. plot (bool): Visualize an episode from the policy after each epoch. store_episodes (bool): Save episodes in snapshot. pause_for_plot (bool): Pause for plot. Raises: NotSetupError: If resume() is called before restore(). Returns: float: The average return in last epoch cycle. """ if self._train_args is None: raise NotSetupError('You must call restore() before resume().') self._train_args.n_epochs = n_epochs or self._train_args.n_epochs self._train_args.batch_size = batch_size or self._train_args.batch_size if plot is not None: self._train_args.plot = plot if store_episodes is not None: self._train_args.store_episodes = store_episodes if pause_for_plot is not None: self._train_args.pause_for_plot = pause_for_plot average_return = self._algo.train(self) self._shutdown_worker() return average_return def get_env_copy(self): """Get a copy of the environment. Returns: Environment: An environment instance. """ if self._env: return cloudpickle.loads(cloudpickle.dumps(self._env)) else: return None @property def total_env_steps(self): """Total environment steps collected. Returns: int: Total environment steps collected. """ return self._stats.total_env_steps
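# The step_epochs() generator above is the hook most algorithms use to drive
# training. A minimal sketch of that pattern is below; `_example_algo_train`
# and the algorithm's `train_once` method are hypothetical names used only for
# illustration, mirroring the Examples block in the step_epochs() docstring.
def _example_algo_train(algo, trainer):
    """Drive training through Trainer.step_epochs() (illustrative only)."""
    last_return = None
    for epoch in trainer.step_epochs():
        # One batch of samples per epoch; storing it on step_episode lets the
        # trainer snapshot it when train(store_episodes=True) was requested.
        trainer.step_episode = trainer.obtain_samples(trainer.step_itr)
        last_return = algo.train_once(trainer.step_itr, trainer.step_episode)
        # step_itr must be advanced manually so logging/snapshots stay in sync.
        trainer.step_itr += 1
    return last_return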
class CMAES(RLAlgorithm, Serializable): def __init__(self, env, policy, n_itr=500, max_path_length=500, discount=0.99, sigma0=1., batch_size=None, plot=False, play_every_itr=None, play_rollouts_num=3, **kwargs): """ :param n_itr: Number of iterations. :param max_path_length: Maximum length of a single rollout. :param batch_size: # of samples from trajs from param distribution, when this is set, n_samples is ignored :param discount: Discount. :param plot: Plot evaluation run after each iteration. :param sigma0: Initial std for param dist :return: """ Serializable.quick_init(self, locals()) self.env = env self.policy = policy self.plot = plot self.sigma0 = sigma0 self.discount = discount self.max_path_length = max_path_length self.n_itr = n_itr self.batch_size = batch_size self.plotter = Plotter() self.play_every_itr = play_every_itr self.play_rollouts_num = play_rollouts_num def play_policy(self, env, policy, n_rollout=3): # Rollout for _ in range(n_rollout): obs = env.reset() while True: env.render() action, _ = policy.get_action(obs) obs, _, done, _ = env.step(action) if done: break def train(self, sess=None): created_session = True if (sess is None) else False if sess is None: sess = tf.Session() sess.__enter__() sess.run(tf.global_variables_initializer()) cur_std = self.sigma0 cur_mean = self.policy.get_param_values() es = cma.CMAEvolutionStrategy(cur_mean, cur_std) parallel_sampler.populate_task(self.env, self.policy) if self.plot: self.plotter.init_plot(self.env, self.policy) cur_std = self.sigma0 cur_mean = self.policy.get_param_values() itr = 0 while itr < self.n_itr and not es.stop(): if self.batch_size is None: # Sample from multivariate normal distribution. xs = es.ask() xs = np.asarray(xs) # For each sample, do a rollout. infos = (stateful_pool.singleton_pool.run_map( sample_return, [(x, self.max_path_length, self.discount) for x in xs])) else: cum_len = 0 infos = [] xss = [] done = False while not done: sbs = stateful_pool.singleton_pool.n_parallel * 2 # Sample from multivariate normal distribution. # You want to ask for sbs samples here. xs = es.ask(sbs) xs = np.asarray(xs) xss.append(xs) sinfos = stateful_pool.singleton_pool.run_map( sample_return, [(x, self.max_path_length, self.discount) for x in xs]) for info in sinfos: infos.append(info) cum_len += len(info['returns']) if cum_len >= self.batch_size: xs = np.concatenate(xss) done = True break # Evaluate fitness of samples (negative as it is minimization # problem). fs = -np.array([info['returns'][0] for info in infos]) # When batching, you could have generated too many samples compared # to the actual evaluations. So we cut it off in this case. xs = xs[:len(fs)] # Update CMA-ES params based on sample fitness. 
es.tell(xs, fs) logger.push_prefix('itr #%d | ' % itr) logger.record_tabular('Iteration', itr) logger.record_tabular('CurStdMean', np.mean(cur_std)) undiscounted_returns = np.array( [info['undiscounted_return'] for info in infos]) logger.record_tabular('AverageReturn', np.mean(undiscounted_returns)) logger.record_tabular('StdReturn', np.std(undiscounted_returns)) logger.record_tabular('MaxReturn', np.max(undiscounted_returns)) logger.record_tabular('MinReturn', np.min(undiscounted_returns)) logger.record_tabular('AverageDiscountedReturn', np.mean(-fs)) logger.record_tabular( 'AvgTrajLen', np.mean([len(info['returns']) for info in infos])) self.policy.log_diagnostics(infos) logger.save_itr_params( itr, dict( itr=itr, policy=self.policy, env=self.env, )) logger.dump_tabular(with_prefix=False) if self.plot: self.plotter.update_plot(self.policy, self.max_path_length) logger.pop_prefix() # Update iteration. itr += 1 # Show the policy from time to time. if self.play_every_itr is not None and self.play_every_itr > 0 and itr % self.play_every_itr == 0: self.play_policy(env=self.env, policy=self.policy, n_rollout=self.play_rollouts_num) # Set final params. self.policy.set_param_values(es.result()[0]) parallel_sampler.terminate_task() self.plotter.close()
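# For reference, the core of CMAES.train() reduces to the standard ask/tell
# loop of the `cma` package, sketched here without the parallel sampler,
# batching, and logging. `rollout_return` is a hypothetical helper that loads
# a parameter vector into the policy, runs one episode, and returns its
# discounted return; returns are negated because cma minimizes its objective.
def _example_cmaes_loop(initial_params, rollout_return, sigma0=1.0, n_itr=500):
    import cma
    import numpy as np
    es = cma.CMAEvolutionStrategy(initial_params, sigma0)
    itr = 0
    while itr < n_itr and not es.stop():
        xs = es.ask()  # sample candidate parameter vectors
        fs = [-rollout_return(np.asarray(x)) for x in xs]  # fitness = -return
        es.tell(xs, fs)  # update the search distribution from the fitnesses
        itr += 1
    # Best parameter vector found (es.result() in older cma versions, as used
    # by the class above).
    return es.result[0]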
class DDPG(RLAlgorithm): """ Deep Deterministic Policy Gradient. """ def __init__(self, env, policy, qf, es, pool, batch_size=32, n_epochs=200, epoch_length=1000, min_pool_size=10000, discount=0.99, max_path_length=250, qf_weight_decay=0., qf_update_method='adam', qf_learning_rate=1e-3, policy_weight_decay=0, policy_update_method='adam', policy_learning_rate=1e-4, eval_samples=10000, soft_target=True, soft_target_tau=0.001, n_updates_per_sample=1, scale_reward=1.0, include_horizon_terminal_transitions=False, plot=False, pause_for_plot=False): """ :param env: Environment :param policy: Policy :param qf: Q function :param es: Exploration strategy :param pool: Replay pool for storing transitions :param batch_size: Number of samples for each minibatch. :param n_epochs: Number of epochs. Policy will be evaluated after each epoch. :param epoch_length: How many timesteps for each epoch. :param min_pool_size: Minimum size of the pool to start training. :param discount: Discount factor for the cumulative return. :param max_path_length: Maximum length of a single rollout. :param qf_weight_decay: Weight decay factor for parameters of the Q function. :param qf_update_method: Online optimization method for training Q function. :param qf_learning_rate: Learning rate for training Q function. :param policy_weight_decay: Weight decay factor for parameters of the policy. :param policy_update_method: Online optimization method for training the policy. :param policy_learning_rate: Learning rate for training the policy. :param eval_samples: Number of samples (timesteps) for evaluating the policy. :param soft_target_tau: Interpolation parameter for doing the soft target update. :param n_updates_per_sample: Number of Q function and policy updates per new sample obtained :param scale_reward: The scaling factor applied to the rewards when training :param include_horizon_terminal_transitions: whether to include transitions with terminal=True because the horizon was reached. This might make the Q value back up less stable for certain tasks. :param plot: Whether to visualize the policy performance after each epoch.
:param pause_for_plot: Whether to pause before continuing when plotting :return: """ self.env = env self.policy = policy self.qf = qf self.es = es self.batch_size = batch_size self.n_epochs = n_epochs self.epoch_length = epoch_length self.min_pool_size = min_pool_size self.pool = pool self.discount = discount self.max_path_length = max_path_length self.qf_weight_decay = qf_weight_decay self.qf_update_method = \ parse_update_method( qf_update_method, learning_rate=qf_learning_rate, ) self.qf_learning_rate = qf_learning_rate self.policy_weight_decay = policy_weight_decay self.policy_update_method = \ parse_update_method( policy_update_method, learning_rate=policy_learning_rate, ) self.policy_learning_rate = policy_learning_rate self.eval_samples = eval_samples self.soft_target_tau = soft_target_tau self.n_updates_per_sample = n_updates_per_sample self.include_horizon_terminal_transitions = \ include_horizon_terminal_transitions self.plot = plot self.pause_for_plot = pause_for_plot self.qf_loss_averages = [] self.policy_surr_averages = [] self.q_averages = [] self.y_averages = [] self.paths = [] self.es_path_returns = [] self.paths_samples_cnt = 0 self.scale_reward = scale_reward self.opt_info = None self.plotter = Plotter() def start_worker(self): parallel_sampler.populate_task(self.env, self.policy) if self.plot: self.plotter.init_plot(self.env, self.policy) @overrides def train(self): # This seems like a rather sequential method self.start_worker() self.init_opt() itr = 0 path_length = 0 path_return = 0 terminal = False observation = self.env.reset() sample_policy = pickle.loads(pickle.dumps(self.policy)) for epoch in range(self.n_epochs): logger.push_prefix('epoch #%d | ' % epoch) logger.log("Training started") for epoch_itr in pyprind.prog_bar(range(self.epoch_length)): # Execute policy if terminal: # or path_length > self.max_path_length: # Note that if the last time step ends an episode, the very # last state and observation will be ignored and not added # to the replay pool observation = self.env.reset() self.es.reset() sample_policy.reset() self.es_path_returns.append(path_return) path_length = 0 path_return = 0 action = self.es.get_action(itr, observation, policy=sample_policy) next_observation, reward, terminal, _ = self.env.step(action) path_length += 1 path_return += reward if not terminal and path_length >= self.max_path_length: terminal = True # only include the terminal transition in this case if the # flag was set if self.include_horizon_terminal_transitions: self.pool.add_transition( observation=[observation], action=[action], reward=[reward * self.scale_reward], terminal=[terminal], next_observation=[next_observation]) else: self.pool.add_transition( observation=[observation], action=[action], reward=[reward * self.scale_reward], terminal=[terminal], next_observation=[next_observation]) observation = next_observation if self.pool.n_transitions_stored >= self.min_pool_size: for update_itr in range(self.n_updates_per_sample): # Train policy batch = self.pool.sample(self.batch_size) self.do_training(itr, batch) sample_policy.set_param_values( self.policy.get_param_values()) itr += 1 logger.log("Training finished") if self.pool.n_transitions_stored >= self.min_pool_size: self.evaluate(epoch, self.pool) params = self.get_epoch_snapshot(epoch) logger.save_itr_params(epoch, params) logger.dump_tabular(with_prefix=False) logger.pop_prefix() if self.plot: self.update_plot() if self.pause_for_plot: input("Plotting evaluation run: Press Enter to " "continue...") self.env.close() 
self.policy.terminate() self.plotter.close() def init_opt(self): # First, create "target" policy and Q functions target_policy = pickle.loads(pickle.dumps(self.policy)) target_qf = pickle.loads(pickle.dumps(self.qf)) # y need to be computed first obs = self.env.observation_space.new_tensor_variable( 'obs', extra_dims=1, ) # The yi values are computed separately as above and then passed to # the training functions below action = self.env.action_space.new_tensor_variable( 'action', extra_dims=1, ) yvar = TT.vector('ys') qf_weight_decay_term = 0.5 * self.qf_weight_decay * \ sum([TT.sum(TT.square(param)) for param in self.qf.get_params(regularizable=True)]) qval = self.qf.get_qval_sym(obs, action) qf_loss = TT.mean(TT.square(yvar - qval)) qf_reg_loss = qf_loss + qf_weight_decay_term policy_weight_decay_term = 0.5 * self.policy_weight_decay * sum([ TT.sum(TT.square(param)) for param in self.policy.get_params(regularizable=True) ]) policy_qval = self.qf.get_qval_sym(obs, self.policy.get_action_sym(obs), deterministic=True) policy_surr = -TT.mean(policy_qval) policy_reg_surr = policy_surr + policy_weight_decay_term qf_updates = self.qf_update_method(qf_reg_loss, self.qf.get_params(trainable=True)) policy_updates = self.policy_update_method( policy_reg_surr, self.policy.get_params(trainable=True)) f_train_qf = tensor_utils.compile_function(inputs=[yvar, obs, action], outputs=[qf_loss, qval], updates=qf_updates) f_train_policy = tensor_utils.compile_function(inputs=[obs], outputs=policy_surr, updates=policy_updates) self.opt_info = dict( f_train_qf=f_train_qf, f_train_policy=f_train_policy, target_qf=target_qf, target_policy=target_policy, ) def do_training(self, itr, batch): obs, actions, rewards, next_obs, terminals = ext.extract( batch, "observation", "action", "reward", "next_observation", "terminal") rewards = rewards.reshape(-1, ) terminals = terminals.reshape(-1, ) # compute the on-policy y values target_qf = self.opt_info["target_qf"] target_policy = self.opt_info["target_policy"] next_actions, _ = target_policy.get_actions(next_obs) next_qvals = target_qf.get_qval(next_obs, next_actions) ys = rewards + (1. 
- terminals) * self.discount * next_qvals f_train_qf = self.opt_info["f_train_qf"] f_train_policy = self.opt_info["f_train_policy"] qf_loss, qval = f_train_qf(ys, obs, actions) policy_surr = f_train_policy(obs) target_policy.set_param_values(target_policy.get_param_values() * (1.0 - self.soft_target_tau) + self.policy.get_param_values() * self.soft_target_tau) target_qf.set_param_values(target_qf.get_param_values() * (1.0 - self.soft_target_tau) + self.qf.get_param_values() * self.soft_target_tau) self.qf_loss_averages.append(qf_loss) self.policy_surr_averages.append(policy_surr) self.q_averages.append(qval) self.y_averages.append(ys) def evaluate(self, epoch, pool): logger.log("Collecting samples for evaluation") paths = parallel_sampler.sample_paths( policy_params=self.policy.get_param_values(), max_samples=self.eval_samples, max_path_length=self.max_path_length, ) average_discounted_return = np.mean([ special.discount_return(path["rewards"], self.discount) for path in paths ]) returns = [sum(path["rewards"]) for path in paths] all_qs = np.concatenate(self.q_averages) all_ys = np.concatenate(self.y_averages) average_q_loss = np.mean(self.qf_loss_averages) average_policy_surr = np.mean(self.policy_surr_averages) average_action = np.mean( np.square(np.concatenate([path["actions"] for path in paths]))) policy_reg_param_norm = np.linalg.norm( self.policy.get_param_values(regularizable=True)) qfun_reg_param_norm = np.linalg.norm( self.qf.get_param_values(regularizable=True)) logger.record_tabular('Epoch', epoch) logger.record_tabular('AverageReturn', np.mean(returns)) logger.record_tabular('StdReturn', np.std(returns)) logger.record_tabular('MaxReturn', np.max(returns)) logger.record_tabular('MinReturn', np.min(returns)) if self.es_path_returns: logger.record_tabular('AverageEsReturn', np.mean(self.es_path_returns)) logger.record_tabular('StdEsReturn', np.std(self.es_path_returns)) logger.record_tabular('MaxEsReturn', np.max(self.es_path_returns)) logger.record_tabular('MinEsReturn', np.min(self.es_path_returns)) logger.record_tabular('AverageDiscountedReturn', average_discounted_return) logger.record_tabular('AverageQLoss', average_q_loss) logger.record_tabular('AveragePolicySurr', average_policy_surr) logger.record_tabular('AverageQ', np.mean(all_qs)) logger.record_tabular('AverageAbsQ', np.mean(np.abs(all_qs))) logger.record_tabular('AverageY', np.mean(all_ys)) logger.record_tabular('AverageAbsY', np.mean(np.abs(all_ys))) logger.record_tabular('AverageAbsQYDiff', np.mean(np.abs(all_qs - all_ys))) logger.record_tabular('AverageAction', average_action) logger.record_tabular('PolicyRegParamNorm', policy_reg_param_norm) logger.record_tabular('QFunRegParamNorm', qfun_reg_param_norm) self.policy.log_diagnostics(paths) self.qf_loss_averages = [] self.policy_surr_averages = [] self.q_averages = [] self.y_averages = [] self.es_path_returns = [] def update_plot(self): if self.plot: self.plotter.update_plot(self.policy, self.max_path_length) def get_epoch_snapshot(self, epoch): return dict( env=self.env, epoch=epoch, qf=self.qf, policy=self.policy, target_qf=self.opt_info["target_qf"], target_policy=self.opt_info["target_policy"], es=self.es, )
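# The numerical heart of DDPG.do_training() above is just two formulas: the TD
# target y_i = r_i + (1 - terminal_i) * gamma * Q'(s'_i, mu'(s'_i)) and the
# Polyak ("soft target") update. A minimal NumPy sketch follows; the function
# names are illustrative only, since the class itself uses compiled Theano
# functions and flat parameter vectors from the policy/Q-function objects.
def _example_td_targets(rewards, terminals, next_qvals, discount=0.99):
    """Targets for the Q-function regression, matching ys in do_training()."""
    return rewards + (1.0 - terminals) * discount * next_qvals


def _example_soft_target_update(target_params, learned_params, tau=0.001):
    """Move target network parameters a small step toward the learned ones."""
    return target_params * (1.0 - tau) + learned_params * tau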
class CEM(RLAlgorithm, Serializable): def __init__(self, env, policy, n_itr=500, max_path_length=500, discount=0.99, init_std=1., n_samples=100, batch_size=None, best_frac=0.05, extra_std=1., extra_decay_time=100, plot=False, n_evals=1, play_every_itr=None, play_rollouts_num=3, **kwargs): """ :param n_itr: Number of iterations. :param max_path_length: Maximum length of a single rollout. :param batch_size: # of environment steps to collect from sampled trajectories per iteration; when this is set, n_samples is ignored :param discount: Discount. :param plot: Plot evaluation run after each iteration. :param init_std: Initial std for param distribution :param extra_std: Decaying std added to param distribution at each iteration :param extra_decay_time: Iterations that it takes to decay extra std :param n_samples: # of samples from param distribution :param best_frac: Best fraction of the sampled params :param n_evals: # of evals per sample from the param distr. Returned score is mean - stderr of evals :return: """ Serializable.quick_init(self, locals()) self.env = env self.policy = policy self.batch_size = batch_size self.plot = plot self.extra_decay_time = extra_decay_time self.extra_std = extra_std self.best_frac = best_frac self.n_samples = n_samples self.init_std = init_std self.discount = discount self.max_path_length = max_path_length self.n_itr = n_itr self.n_evals = n_evals self.plotter = Plotter() self.play_every_itr = play_every_itr self.play_rollouts_num = play_rollouts_num def play_policy(self, env, policy, n_rollout=3): # Rollout for _ in range(n_rollout): obs = env.reset() while True: env.render() action, _ = policy.get_action(obs) obs, _, done, _ = env.step(action) if done: break def train(self, sess=None): created_session = True if (sess is None) else False if sess is None: sess = tf.Session() sess.__enter__() sess.run(tf.global_variables_initializer()) parallel_sampler.populate_task(self.env, self.policy) if self.plot: self.plotter.init_plot(self.env, self.policy) cur_std = self.init_std cur_mean = self.policy.get_param_values() # K = cur_mean.size n_best = max(1, int(self.n_samples * self.best_frac)) for itr in range(self.n_itr): # sample around the current distribution extra_var_mult = max(1.0 - itr / self.extra_decay_time, 0) sample_std = np.sqrt( np.square(cur_std) + np.square(self.extra_std) * extra_var_mult) if self.batch_size is None: criterion = 'paths' threshold = self.n_samples else: criterion = 'samples' threshold = self.batch_size infos = stateful_pool.singleton_pool.run_collect( _worker_rollout_policy, threshold=threshold, args=(dict( cur_mean=cur_mean, sample_std=sample_std, max_path_length=self.max_path_length, discount=self.discount, criterion=criterion, n_evals=self.n_evals), )) xs = np.asarray([info[0] for info in infos]) paths = [info[1] for info in infos] fs = np.array([path['returns'][0] for path in paths]) print((xs.shape, fs.shape)) best_inds = (-fs).argsort()[:n_best] best_xs = xs[best_inds] cur_mean = best_xs.mean(axis=0) cur_std = best_xs.std(axis=0) best_x = best_xs[0] logger.push_prefix('itr #%d | ' % itr) logger.record_tabular('Iteration', itr) logger.record_tabular('CurStdMean', np.mean(cur_std)) undiscounted_returns = np.array( [path['undiscounted_return'] for path in paths]) logger.record_tabular('AverageReturn', np.mean(undiscounted_returns)) logger.record_tabular('StdReturn', np.std(undiscounted_returns)) logger.record_tabular('MaxReturn', np.max(undiscounted_returns)) logger.record_tabular('MinReturn', np.min(undiscounted_returns))
logger.record_tabular('AverageDiscountedReturn', np.mean(fs)) logger.record_tabular('NumTrajs', len(paths)) paths = list(chain( *[d['full_paths'] for d in paths])) # flatten paths for the case n_evals > 1 logger.record_tabular( 'AvgTrajLen', np.mean([len(path['returns']) for path in paths])) self.policy.set_param_values(best_x) self.policy.log_diagnostics(paths) logger.save_itr_params( itr, dict( itr=itr, policy=self.policy, env=self.env, cur_mean=cur_mean, cur_std=cur_std, )) logger.dump_tabular(with_prefix=False) logger.pop_prefix() if self.plot: self.plotter.update_plot(self.policy, self.max_path_length) # Showing policy from time to time if self.play_every_itr is not None and self.play_every_itr > 0 and itr % self.play_every_itr == 0: self.play_policy(env=self.env, policy=self.policy, n_rollout=self.play_rollouts_num) parallel_sampler.terminate_task() self.plotter.close() if created_session: sess.close()
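# CEM.train() above boils down to one simple update per iteration: sample
# parameter vectors around the current Gaussian, evaluate them, keep the best
# fraction, and refit the Gaussian to those elites. A self-contained sketch is
# below; `rollout_return` is a hypothetical helper that evaluates a single
# parameter vector and returns its score.
def _example_cem_update(cur_mean, cur_std, rollout_return, n_samples=100,
                        best_frac=0.05, extra_std=0.0):
    """One cross-entropy method iteration (illustrative only)."""
    import numpy as np
    sample_std = np.sqrt(np.square(cur_std) + np.square(extra_std))
    # Sample candidates around the current parameter distribution.
    xs = cur_mean + sample_std * np.random.randn(n_samples, cur_mean.size)
    fs = np.array([rollout_return(x) for x in xs])
    # Keep the best-performing fraction and refit mean/std to the elites.
    n_best = max(1, int(n_samples * best_frac))
    best_xs = xs[(-fs).argsort()[:n_best]]
    return best_xs.mean(axis=0), best_xs.std(axis=0), best_xs[0]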