def train(self):
    """Run the main batch-sampling training loop.

    For each iteration in ``[self.current_itr, self.n_itr)``: obtain sample
    paths via ``self.sampler``, process them into training data, run one
    policy optimization step, and save a snapshot of the iteration.

    Side effects: starts/stops the sampler worker, initializes the
    optimizer, writes snapshots and tabular logs through ``logger``, and
    optionally drives an interactive plot of the current policy.
    """
    plotter = Plotter()
    if self.plot:
        plotter.init_plot(self.env, self.policy)
    self.start_worker()
    self.init_opt()
    # Resume from self.current_itr so a restored snapshot continues where
    # it left off rather than restarting at iteration 0.
    for itr in range(self.current_itr, self.n_itr):
        with logger.prefix('itr #%d | ' % itr):
            paths = self.sampler.obtain_samples(itr)
            samples_data = self.sampler.process_samples(itr, paths)
            self.log_diagnostics(paths)
            self.optimize_policy(itr, samples_data)
            logger.log("saving snapshot...")
            params = self.get_itr_snapshot(itr, samples_data)
            # Advance current_itr BEFORE storing the algo in the snapshot,
            # so a reloaded snapshot resumes at the next iteration.
            self.current_itr = itr + 1
            params["algo"] = self
            if self.store_paths:
                params["paths"] = samples_data["paths"]
            logger.save_itr_params(itr, params)
            logger.log("saved")
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                plotter.update_plot(self.policy, self.max_path_length)
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")
    plotter.shutdown()
    self.shutdown_worker()
class DDPG(RLAlgorithm):
    """
    Deep Deterministic Policy Gradient.
    """

    def __init__(self, env, policy, qf, es, batch_size=32, n_epochs=200,
                 epoch_length=1000, min_pool_size=10000,
                 replay_pool_size=1000000, discount=0.99, max_path_length=250,
                 qf_weight_decay=0., qf_update_method='adam',
                 qf_learning_rate=1e-3, policy_weight_decay=0,
                 policy_update_method='adam', policy_learning_rate=1e-4,
                 eval_samples=10000, soft_target=True, soft_target_tau=0.001,
                 n_updates_per_sample=1, scale_reward=1.0,
                 include_horizon_terminal_transitions=False, plot=False,
                 pause_for_plot=False):
        """
        :param env: Environment
        :param policy: Policy
        :param qf: Q function
        :param es: Exploration strategy
        :param batch_size: Number of samples for each minibatch.
        :param n_epochs: Number of epochs. Policy will be evaluated after
         each epoch.
        :param epoch_length: How many timesteps for each epoch.
        :param min_pool_size: Minimum size of the pool to start training.
        :param replay_pool_size: Size of the experience replay pool.
        :param discount: Discount factor for the cumulative return.
        :param max_path_length: Maximum length of a single rollout before it
         is treated as terminated by the horizon.
        :param qf_weight_decay: Weight decay factor for parameters of the Q
         function.
        :param qf_update_method: Online optimization method for training Q
         function.
        :param qf_learning_rate: Learning rate for training Q function.
        :param policy_weight_decay: Weight decay factor for parameters of
         the policy.
        :param policy_update_method: Online optimization method for training
         the policy.
        :param policy_learning_rate: Learning rate for training the policy.
        :param eval_samples: Number of samples (timesteps) for evaluating
         the policy.
        :param soft_target: Accepted for API compatibility.
         NOTE(review): this flag is never stored or read — soft target
         updates are always performed; confirm before relying on it.
        :param soft_target_tau: Interpolation parameter for doing the soft
         target update.
        :param n_updates_per_sample: Number of Q function and policy updates
         per new sample obtained
        :param scale_reward: The scaling factor applied to the rewards when
         training
        :param include_horizon_terminal_transitions: whether to include
         transitions with terminal=True because the horizon was reached.
         This might make the Q value back up less stable for certain tasks.
        :param plot: Whether to visualize the policy performance after each
         eval_interval.
        :param pause_for_plot: Whether to pause before continuing when
         plotting
        :return:
        """
        self.env = env
        self.policy = policy
        self.qf = qf
        self.es = es
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.epoch_length = epoch_length
        self.min_pool_size = min_pool_size
        self.replay_pool_size = replay_pool_size
        self.discount = discount
        self.max_path_length = max_path_length
        self.qf_weight_decay = qf_weight_decay
        # parse_update_method resolves the string name (e.g. 'adam') into a
        # callable update-rule factory bound to the given learning rate.
        self.qf_update_method = \
            parse_update_method(
                qf_update_method,
                learning_rate=qf_learning_rate,
            )
        self.qf_learning_rate = qf_learning_rate
        self.policy_weight_decay = policy_weight_decay
        self.policy_update_method = \
            parse_update_method(
                policy_update_method,
                learning_rate=policy_learning_rate,
            )
        self.policy_learning_rate = policy_learning_rate
        self.eval_samples = eval_samples
        self.soft_target_tau = soft_target_tau
        self.n_updates_per_sample = n_updates_per_sample
        self.include_horizon_terminal_transitions = \
            include_horizon_terminal_transitions
        self.plot = plot
        self.pause_for_plot = pause_for_plot

        # Running statistics accumulated by do_training(); flushed and
        # reset by evaluate() at the end of each epoch.
        self.qf_loss_averages = []
        self.policy_surr_averages = []
        self.q_averages = []
        self.y_averages = []

        self.paths = []
        self.es_path_returns = []
        self.paths_samples_cnt = 0

        self.scale_reward = scale_reward

        # Populated by init_opt() with compiled training functions and the
        # target networks.
        self.opt_info = None
        self.plotter = Plotter()

    def start_worker(self):
        """Distribute env/policy to the sampler workers; init plot if on."""
        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            self.plotter.init_plot(self.env, self.policy)

    @overrides
    def train(self):
        """Run the DDPG training loop over n_epochs * epoch_length steps."""
        # This seems like a rather sequential method
        pool = ReplayBuffer(
            max_buffer_size=self.replay_pool_size,
            observation_dim=self.env.observation_space.flat_dim,
            action_dim=self.env.action_space.flat_dim,
        )
        self.start_worker()

        self.init_opt()
        itr = 0
        path_length = 0
        path_return = 0
        terminal = False
        observation = self.env.reset()

        # Deep-copy of the policy used for exploration rollouts; its
        # parameters are refreshed from the trained policy after each
        # round of updates.
        sample_policy = pickle.loads(pickle.dumps(self.policy))

        for epoch in range(self.n_epochs):
            logger.push_prefix('epoch #%d | ' % epoch)
            logger.log("Training started")
            for epoch_itr in pyprind.prog_bar(range(self.epoch_length)):
                # Execute policy
                if terminal:  # or path_length > self.max_path_length:
                    # Note that if the last time step ends an episode, the
                    # very last state and observation will be ignored and not
                    # added to the replay pool
                    observation = self.env.reset()
                    self.es.reset()
                    sample_policy.reset()
                    self.es_path_returns.append(path_return)
                    path_length = 0
                    path_return = 0
                action = self.es.get_action(
                    itr, observation, policy=sample_policy)

                next_observation, reward, terminal, _ = self.env.step(action)
                path_length += 1
                path_return += reward

                if not terminal and path_length >= self.max_path_length:
                    terminal = True
                    # only include the terminal transition in this case if
                    # the flag was set
                    if self.include_horizon_terminal_transitions:
                        pool.add_transition(observation, action,
                                            reward * self.scale_reward,
                                            terminal, next_observation)
                else:
                    pool.add_transition(observation, action,
                                        reward * self.scale_reward,
                                        terminal, next_observation)

                observation = next_observation

                if pool.size >= self.min_pool_size:
                    for update_itr in range(self.n_updates_per_sample):
                        # Train policy
                        batch = pool.random_sample(self.batch_size)
                        self.do_training(itr, batch)
                    # Sync exploration policy with the freshly updated one.
                    sample_policy.set_param_values(
                        self.policy.get_param_values())

                itr += 1

            logger.log("Training finished")
            if pool.size >= self.min_pool_size:
                self.evaluate(epoch, pool)
                params = self.get_epoch_snapshot(epoch)
                logger.save_itr_params(epoch, params)
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                self.update_plot()
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")
        self.env.close()
        self.policy.terminate()
        self.plotter.shutdown()

    def init_opt(self):
        """Build target networks and compile the Theano training functions.

        Stores the compiled Q-function and policy training functions plus
        the target networks in ``self.opt_info``.
        """
        # First, create "target" policy and Q functions
        target_policy = pickle.loads(pickle.dumps(self.policy))
        target_qf = pickle.loads(pickle.dumps(self.qf))

        # y need to be computed first
        obs = self.env.observation_space.new_tensor_variable(
            'obs',
            extra_dims=1,
        )

        # The yi values are computed separately as above and then passed to
        # the training functions below
        action = self.env.action_space.new_tensor_variable(
            'action',
            extra_dims=1,
        )
        yvar = TT.vector('ys')

        # L2 weight decay on the regularizable Q-function parameters.
        qf_weight_decay_term = 0.5 * self.qf_weight_decay * \
            sum([TT.sum(TT.square(param)) for param in
                 self.qf.get_params(regularizable=True)])

        qval = self.qf.get_qval_sym(obs, action)

        # Bellman error against the externally supplied targets yvar.
        qf_loss = TT.mean(TT.square(yvar - qval))
        qf_reg_loss = qf_loss + qf_weight_decay_term

        policy_weight_decay_term = 0.5 * self.policy_weight_decay * sum([
            TT.sum(TT.square(param))
            for param in self.policy.get_params(regularizable=True)
        ])
        # Deterministic policy gradient objective: maximize Q(s, pi(s)),
        # expressed as a minimization of its negative.
        policy_qval = self.qf.get_qval_sym(
            obs, self.policy.get_action_sym(obs),
            deterministic=True)
        policy_surr = -TT.mean(policy_qval)

        policy_reg_surr = policy_surr + policy_weight_decay_term

        qf_updates = self.qf_update_method(
            qf_reg_loss, self.qf.get_params(trainable=True))
        policy_updates = self.policy_update_method(
            policy_reg_surr, self.policy.get_params(trainable=True))

        f_train_qf = ext.compile_function(
            inputs=[yvar, obs, action],
            outputs=[qf_loss, qval],
            updates=qf_updates)

        f_train_policy = ext.compile_function(
            inputs=[obs],
            outputs=policy_surr,
            updates=policy_updates)

        self.opt_info = dict(
            f_train_qf=f_train_qf,
            f_train_policy=f_train_policy,
            target_qf=target_qf,
            target_policy=target_policy,
        )

    def do_training(self, itr, batch):
        """One gradient step on Q-function and policy from a replay batch.

        Also performs the soft (Polyak) update of both target networks and
        appends per-batch statistics to the running-average lists.
        """
        obs, actions, rewards, next_obs, terminals = ext.extract(
            batch,
            "observations", "actions", "rewards", "next_observations",
            "terminals")

        # compute the on-policy y values
        target_qf = self.opt_info["target_qf"]
        target_policy = self.opt_info["target_policy"]

        next_actions, _ = target_policy.get_actions(next_obs)
        next_qvals = target_qf.get_qval(next_obs, next_actions)

        # Bootstrap is masked out for terminal transitions.
        ys = rewards + (1. - terminals) * self.discount * next_qvals

        f_train_qf = self.opt_info["f_train_qf"]
        f_train_policy = self.opt_info["f_train_policy"]

        qf_loss, qval = f_train_qf(ys, obs, actions)

        policy_surr = f_train_policy(obs)

        # Soft target updates: interpolate target params toward the
        # current networks with rate soft_target_tau.
        target_policy.set_param_values(
            target_policy.get_param_values() * (1.0 - self.soft_target_tau) +
            self.policy.get_param_values() * self.soft_target_tau)
        target_qf.set_param_values(
            target_qf.get_param_values() * (1.0 - self.soft_target_tau) +
            self.qf.get_param_values() * self.soft_target_tau)

        self.qf_loss_averages.append(qf_loss)
        self.policy_surr_averages.append(policy_surr)
        self.q_averages.append(qval)
        self.y_averages.append(ys)

    def evaluate(self, epoch, pool):
        """Evaluate the deterministic policy and log epoch statistics.

        Collects eval_samples timesteps via the parallel sampler, records
        tabular diagnostics, and resets the running-statistics lists.
        """
        logger.log("Collecting samples for evaluation")
        paths = parallel_sampler.sample_paths(
            policy_params=self.policy.get_param_values(),
            max_samples=self.eval_samples,
            max_path_length=self.max_path_length,
        )

        average_discounted_return = np.mean([
            special.discount_return(path["rewards"], self.discount)
            for path in paths
        ])

        returns = [sum(path["rewards"]) for path in paths]

        all_qs = np.concatenate(self.q_averages)
        all_ys = np.concatenate(self.y_averages)

        average_q_loss = np.mean(self.qf_loss_averages)
        average_policy_surr = np.mean(self.policy_surr_averages)
        average_action = np.mean(
            np.square(np.concatenate([path["actions"] for path in paths])))

        policy_reg_param_norm = np.linalg.norm(
            self.policy.get_param_values(regularizable=True))
        qfun_reg_param_norm = np.linalg.norm(
            self.qf.get_param_values(regularizable=True))

        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('AverageReturn', np.mean(returns))
        logger.record_tabular('StdReturn', np.std(returns))
        logger.record_tabular('MaxReturn', np.max(returns))
        logger.record_tabular('MinReturn', np.min(returns))
        if self.es_path_returns:
            logger.record_tabular('AverageEsReturn',
                                  np.mean(self.es_path_returns))
            logger.record_tabular('StdEsReturn',
                                  np.std(self.es_path_returns))
            logger.record_tabular('MaxEsReturn',
                                  np.max(self.es_path_returns))
            logger.record_tabular('MinEsReturn',
                                  np.min(self.es_path_returns))
        logger.record_tabular('AverageDiscountedReturn',
                              average_discounted_return)
        logger.record_tabular('AverageQLoss', average_q_loss)
        logger.record_tabular('AveragePolicySurr', average_policy_surr)
        logger.record_tabular('AverageQ', np.mean(all_qs))
        logger.record_tabular('AverageAbsQ', np.mean(np.abs(all_qs)))
        logger.record_tabular('AverageY', np.mean(all_ys))
        logger.record_tabular('AverageAbsY', np.mean(np.abs(all_ys)))
        logger.record_tabular('AverageAbsQYDiff',
                              np.mean(np.abs(all_qs - all_ys)))
        logger.record_tabular('AverageAction', average_action)

        logger.record_tabular('PolicyRegParamNorm', policy_reg_param_norm)
        logger.record_tabular('QFunRegParamNorm', qfun_reg_param_norm)

        self.env.log_diagnostics(paths)
        self.policy.log_diagnostics(paths)

        # Reset the per-epoch statistics accumulators.
        self.qf_loss_averages = []
        self.policy_surr_averages = []

        self.q_averages = []
        self.y_averages = []
        self.es_path_returns = []

    def update_plot(self):
        """Refresh the live plot of the current policy, if plotting is on."""
        if self.plot:
            self.plotter.update_plot(self.policy, self.max_path_length)

    def get_epoch_snapshot(self, epoch):
        """Return a dict snapshot of everything needed to resume training."""
        return dict(
            env=self.env,
            epoch=epoch,
            qf=self.qf,
            policy=self.policy,
            target_qf=self.opt_info["target_qf"],
            target_policy=self.opt_info["target_policy"],
            es=self.es,
        )
class CEM(RLAlgorithm, Serializable):
    """Cross-Entropy Method over policy parameters.

    Maintains a Gaussian distribution over flattened policy parameters;
    each iteration samples parameter vectors, evaluates them by rollout,
    and refits the distribution to the best-performing fraction.
    """

    def __init__(self, env, policy, n_itr=500, max_path_length=500,
                 discount=0.99, init_std=1., n_samples=100, batch_size=None,
                 best_frac=0.05, extra_std=1., extra_decay_time=100,
                 plot=False, n_evals=1, **kwargs):
        """
        :param n_itr: Number of iterations.
        :param max_path_length: Maximum length of a single rollout.
        :param batch_size: # of samples from trajs from param distribution,
         when this is set, n_samples is ignored
        :param discount: Discount.
        :param plot: Plot evaluation run after each iteration.
        :param init_std: Initial std for param distribution
        :param extra_std: Decaying std added to param distribution at each
         iteration
        :param extra_decay_time: Iterations that it takes to decay extra std
        :param n_samples: # of samples from param distribution
        :param best_frac: Best fraction of the sampled params
        :param n_evals: # of evals per sample from the param distr. returned
         score is mean - stderr of evals
        :return:
        """
        Serializable.quick_init(self, locals())
        self.env = env
        self.policy = policy
        self.batch_size = batch_size
        self.plot = plot
        self.extra_decay_time = extra_decay_time
        self.extra_std = extra_std
        self.best_frac = best_frac
        self.n_samples = n_samples
        self.init_std = init_std
        self.discount = discount
        self.max_path_length = max_path_length
        self.n_itr = n_itr
        self.n_evals = n_evals
        self.plotter = Plotter()

    def train(self):
        """Run the CEM loop: sample params, evaluate, refit to elites."""
        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            self.plotter.init_plot(self.env, self.policy)
        cur_std = self.init_std
        cur_mean = self.policy.get_param_values()
        # Number of elite samples kept each iteration; at least one.
        n_best = max(1, int(self.n_samples * self.best_frac))

        for itr in range(self.n_itr):
            # Sample around the current distribution, with extra variance
            # that decays linearly to zero over extra_decay_time iterations
            # to encourage early exploration.
            extra_var_mult = max(1.0 - itr / self.extra_decay_time, 0)
            sample_std = np.sqrt(
                np.square(cur_std) +
                np.square(self.extra_std) * extra_var_mult)
            if self.batch_size is None:
                # Collect a fixed number of parameter samples (paths).
                criterion = 'paths'
                threshold = self.n_samples
            else:
                # Collect until a fixed number of timesteps is reached.
                criterion = 'samples'
                threshold = self.batch_size
            infos = stateful_pool.singleton_pool.run_collect(
                _worker_rollout_policy,
                threshold=threshold,
                args=(dict(
                    cur_mean=cur_mean,
                    sample_std=sample_std,
                    max_path_length=self.max_path_length,
                    discount=self.discount,
                    criterion=criterion,
                    n_evals=self.n_evals), ))

            xs = np.asarray([info[0] for info in infos])
            paths = [info[1] for info in infos]

            # Score is the discounted return of the whole path (first entry
            # of the precomputed returns array).
            fs = np.array([path['returns'][0] for path in paths])
            # NOTE: removed a leftover debug print of (xs.shape, fs.shape).
            best_inds = (-fs).argsort()[:n_best]
            best_xs = xs[best_inds]
            # Refit the sampling distribution to the elite samples.
            cur_mean = best_xs.mean(axis=0)
            cur_std = best_xs.std(axis=0)
            best_x = best_xs[0]
            logger.push_prefix('itr #%d | ' % itr)
            logger.record_tabular('Iteration', itr)
            logger.record_tabular('CurStdMean', np.mean(cur_std))
            undiscounted_returns = np.array(
                [path['undiscounted_return'] for path in paths])
            logger.record_tabular('AverageReturn',
                                  np.mean(undiscounted_returns))
            logger.record_tabular('StdReturn',
                                  np.std(undiscounted_returns))
            logger.record_tabular('MaxReturn',
                                  np.max(undiscounted_returns))
            logger.record_tabular('MinReturn',
                                  np.min(undiscounted_returns))
            logger.record_tabular('AverageDiscountedReturn',
                                  np.mean(fs))
            logger.record_tabular('NumTrajs',
                                  len(paths))
            # Flatten the per-sample path lists for the case n_evals > 1.
            paths = list(chain(*[d['full_paths'] for d in paths]))
            logger.record_tabular(
                'AvgTrajLen',
                np.mean([len(path['returns']) for path in paths]))

            # Load the single best parameter vector into the policy so the
            # snapshot and diagnostics reflect the current best.
            self.policy.set_param_values(best_x)

            self.env.log_diagnostics(paths)
            self.policy.log_diagnostics(paths)

            logger.save_itr_params(
                itr,
                dict(
                    itr=itr,
                    policy=self.policy,
                    env=self.env,
                    cur_mean=cur_mean,
                    cur_std=cur_std,
                ))
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                self.plotter.update_plot(self.policy, self.max_path_length)
        parallel_sampler.terminate_task()
        self.plotter.shutdown()
class CMAES(RLAlgorithm, Serializable):
    """Covariance Matrix Adaptation Evolution Strategy over policy params.

    Wraps the ``cma`` package: each iteration asks the evolution strategy
    for candidate parameter vectors, evaluates them by rollout, and tells
    the strategy their (negated, since CMA minimizes) fitness.
    """

    def __init__(self, env, policy, n_itr=500, max_path_length=500,
                 discount=0.99, sigma0=1., batch_size=None, plot=False,
                 **kwargs):
        """
        :param n_itr: Number of iterations.
        :param max_path_length: Maximum length of a single rollout.
        :param batch_size: # of samples from trajs from param distribution,
         when this is set, n_samples is ignored
        :param discount: Discount.
        :param plot: Plot evaluation run after each iteration.
        :param sigma0: Initial std for param dist
        :return:
        """
        Serializable.quick_init(self, locals())
        self.env = env
        self.policy = policy
        self.plot = plot
        self.sigma0 = sigma0
        self.discount = discount
        self.max_path_length = max_path_length
        self.n_itr = n_itr
        self.batch_size = batch_size
        self.plotter = Plotter()

    def train(self):
        """Run the CMA-ES optimization loop."""
        cur_std = self.sigma0
        cur_mean = self.policy.get_param_values()
        es = cma.CMAEvolutionStrategy(cur_mean, cur_std)

        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            self.plotter.init_plot(self.env, self.policy)

        itr = 0
        while itr < self.n_itr and not es.stop():
            if self.batch_size is None:
                # Sample one generation from the multivariate normal
                # distribution and evaluate each candidate with a rollout.
                xs = es.ask()
                xs = np.asarray(xs)
                infos = (stateful_pool.singleton_pool.run_map(
                    sample_return,
                    [(x, self.max_path_length, self.discount) for x in xs]))
            else:
                # Keep asking for small sub-batches until the accumulated
                # number of timesteps reaches batch_size.
                cum_len = 0
                infos = []
                xss = []
                done = False
                while not done:
                    sbs = stateful_pool.singleton_pool.n_parallel * 2
                    # Sample sbs candidates from the current distribution.
                    xs = es.ask(sbs)
                    xs = np.asarray(xs)
                    xss.append(xs)
                    sinfos = stateful_pool.singleton_pool.run_map(
                        sample_return,
                        [(x, self.max_path_length, self.discount)
                         for x in xs])
                    for info in sinfos:
                        infos.append(info)
                        cum_len += len(info['returns'])
                        if cum_len >= self.batch_size:
                            xs = np.concatenate(xss)
                            done = True
                            break

            # Evaluate fitness of samples (negative as it is minimization
            # problem).
            fs = -np.array([info['returns'][0] for info in infos])
            # When batching, you could have generated too many samples
            # compared to the actual evaluations. So we cut it off in this
            # case.
            xs = xs[:len(fs)]
            # Update CMA-ES params based on sample fitness.
            es.tell(xs, fs)

            logger.push_prefix('itr #%d | ' % itr)
            logger.record_tabular('Iteration', itr)
            logger.record_tabular('CurStdMean', np.mean(cur_std))
            undiscounted_returns = np.array(
                [info['undiscounted_return'] for info in infos])
            logger.record_tabular('AverageReturn',
                                  np.mean(undiscounted_returns))
            # BUG FIX: previously logged np.mean under the 'StdReturn' key;
            # the standard deviation is the intended statistic.
            logger.record_tabular('StdReturn',
                                  np.std(undiscounted_returns))
            logger.record_tabular('MaxReturn',
                                  np.max(undiscounted_returns))
            logger.record_tabular('MinReturn',
                                  np.min(undiscounted_returns))
            logger.record_tabular('AverageDiscountedReturn',
                                  np.mean(fs))
            logger.record_tabular(
                'AvgTrajLen',
                np.mean([len(info['returns']) for info in infos]))

            self.policy.log_diagnostics(infos)
            logger.save_itr_params(
                itr,
                dict(
                    itr=itr,
                    policy=self.policy,
                    env=self.env,
                ))
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                self.plotter.update_plot(self.policy, self.max_path_length)
            logger.pop_prefix()
            # Update iteration.
            itr += 1

        # Set final params to the best solution found by CMA-ES.
        self.policy.set_param_values(es.result()[0])
        parallel_sampler.terminate_task()
        self.plotter.shutdown()