class MBPO(RLAlgorithm): """Model-Based Policy Optimization (MBPO) References ---------- Michael Janner, Justin Fu, Marvin Zhang, Sergey Levine. When to Trust Your Model: Model-Based Policy Optimization. arXiv preprint arXiv:1906.08253. 2019. """ def __init__( self, training_environment, evaluation_environment, policy, Qs, pool, static_fns, plotter=None, tf_summaries=False, lr=3e-4, reward_scale=1.0, target_entropy='auto', discount=0.99, tau=5e-3, target_update_interval=1, action_prior='uniform', reparameterize=False, store_extra_policy_info=False, deterministic=False, model_train_freq=250, num_networks=7, num_elites=5, model_retain_epochs=20, load_model_dir=None, rollout_batch_size=100e3, real_ratio=0.1, rollout_schedule=[20, 100, 1, 1], hidden_dim=200, max_model_t=None, **kwargs, ): """ Args: env (`SoftlearningEnv`): Environment used for training. policy: A policy function approximator. initial_exploration_policy: ('Policy'): A policy that we use for initial exploration which is not trained by the algorithm. Qs: Q-function approximators. The min of these approximators will be used. Usage of at least two Q-functions improves performance by reducing overestimation bias. pool (`PoolBase`): Replay pool to add gathered samples to. plotter (`QFPolicyPlotter`): Plotter instance to be used for visualizing Q-function during training. lr (`float`): Learning rate used for the function approximators. discount (`float`): Discount factor for Q-function updates. tau (`float`): Soft value function target update weight. target_update_interval ('int'): Frequency at which target network updates occur in iterations. reparameterize ('bool'): If True, we use a gradient estimator for the policy derived using the reparameterization trick. We use a likelihood ratio based estimator otherwise. 
""" super(MBPO, self).__init__(**kwargs) obs_dim = np.prod(training_environment.observation_space.shape) act_dim = np.prod(training_environment.action_space.shape) self._model_params = dict(obs_dim=obs_dim, act_dim=act_dim, hidden_dim=hidden_dim, num_networks=num_networks, num_elites=num_elites) if load_model_dir is not None: self._model_params['load_model'] = True self._model_params['model_dir'] = load_model_dir self._model = construct_model(**self._model_params) self._static_fns = static_fns self.fake_env = FakeEnv(self._model, self._static_fns) self._rollout_schedule = rollout_schedule self._max_model_t = max_model_t self._model_retain_epochs = model_retain_epochs self._model_train_freq = model_train_freq self._rollout_batch_size = int(rollout_batch_size) self._deterministic = deterministic self._real_ratio = real_ratio self._log_dir = os.getcwd() self._writer = Writer(self._log_dir) self._training_environment = training_environment self._evaluation_environment = evaluation_environment self._policy = policy self._Qs = Qs self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs) self._pool = pool self._plotter = plotter self._tf_summaries = tf_summaries self._policy_lr = lr self._Q_lr = lr self._reward_scale = reward_scale self._target_entropy = ( -np.prod(self._training_environment.action_space.shape) if target_entropy == 'auto' else target_entropy) print('[ MBPO ] Target entropy: {}'.format(self._target_entropy)) self._discount = discount self._tau = tau self._target_update_interval = target_update_interval self._action_prior = action_prior self._reparameterize = reparameterize self._store_extra_policy_info = store_extra_policy_info observation_shape = self._training_environment.active_observation_shape action_shape = self._training_environment.action_space.shape ### @ anyboby fixed pool size, reallocate causes memory leak obs_space = self._pool._observation_space act_space = self._pool._action_space rollouts_per_epoch = self._rollout_batch_size * self._epoch_length / self._model_train_freq model_steps_per_epoch = int(self._rollout_schedule[-1] * rollouts_per_epoch) mpool_size = self._model_retain_epochs * model_steps_per_epoch self._model_pool = SimpleReplayPool(obs_space, act_space, mpool_size) assert len(observation_shape) == 1, observation_shape self._observation_shape = observation_shape assert len(action_shape) == 1, action_shape self._action_shape = action_shape self._build() def _build(self): self._training_ops = {} self._init_global_step() self._init_placeholders() self._init_actor_update() self._init_critic_update() def _train(self): """Return a generator that performs RL training. Args: env (`SoftlearningEnv`): Environment used for training. 
policy (`Policy`): Policy used for training initial_exploration_policy ('Policy'): Policy used for exploration If None, then all exploration is done using policy pool (`PoolBase`): Sample pool to add samples to """ training_environment = self._training_environment evaluation_environment = self._evaluation_environment policy = self._policy pool = self._pool model_metrics = {} if not self._training_started: self._init_training() self._initial_exploration_hook(training_environment, self._initial_exploration_policy, pool) self.sampler.initialize(training_environment, policy, pool) gt.reset_root() gt.rename_root('RLAlgorithm') gt.set_def_unique(False) self._training_before_hook() for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)): self._epoch_before_hook() gt.stamp('epoch_before_hook') self._training_progress = Progress(self._epoch_length * self._n_train_repeat) start_samples = self.sampler._total_samples for i in count(): samples_now = self.sampler._total_samples self._timestep = samples_now - start_samples if (samples_now >= start_samples + self._epoch_length and self.ready_to_train): break self._timestep_before_hook() gt.stamp('timestep_before_hook') if self._timestep % self._model_train_freq == 0 and self._real_ratio < 1.0: self._training_progress.pause() print('[ MBPO ] log_dir: {} | ratio: {}'.format( self._log_dir, self._real_ratio)) print( '[ MBPO ] Training model at epoch {} | freq {} | timestep {} (total: {}) | epoch train steps: {} (total: {})' .format(self._epoch, self._model_train_freq, self._timestep, self._total_timestep, self._train_steps_this_epoch, self._num_train_steps)) model_train_metrics = self._train_model( batch_size=256, max_epochs=None, holdout_ratio=0.2, max_t=self._max_model_t) model_metrics.update(model_train_metrics) gt.stamp('epoch_train_model') self._set_rollout_length() if self._rollout_batch_size > 30000: factor = self._rollout_batch_size // 30000 + 1 mini_batch = self._rollout_batch_size // factor for i in range(factor): model_rollout_metrics = self._rollout_model( rollout_batch_size=mini_batch, deterministic=self._deterministic) else: model_rollout_metrics = self._rollout_model( rollout_batch_size=self._rollout_batch_size, deterministic=self._deterministic) model_metrics.update(model_rollout_metrics) gt.stamp('epoch_rollout_model') # self._visualize_model(self._evaluation_environment, self._total_timestep) self._training_progress.resume() self._do_sampling(timestep=self._total_timestep) gt.stamp('sample') if self.ready_to_train: self._do_training_repeats(timestep=self._total_timestep) gt.stamp('train') self._timestep_after_hook() gt.stamp('timestep_after_hook') training_paths = self.sampler.get_last_n_paths( math.ceil(self._epoch_length / self.sampler._max_path_length)) gt.stamp('training_paths') evaluation_paths = self._evaluation_paths(policy, evaluation_environment) gt.stamp('evaluation_paths') training_metrics = self._evaluate_rollouts(training_paths, training_environment) gt.stamp('training_metrics') if evaluation_paths: evaluation_metrics = self._evaluate_rollouts( evaluation_paths, evaluation_environment) gt.stamp('evaluation_metrics') else: evaluation_metrics = {} self._epoch_after_hook(training_paths) gt.stamp('epoch_after_hook') sampler_diagnostics = self.sampler.get_diagnostics() diagnostics = self.get_diagnostics( iteration=self._total_timestep, batch=self._evaluation_batch(), training_paths=training_paths, evaluation_paths=evaluation_paths) time_diagnostics = gt.get_times().stamps.itrs diagnostics.update( OrderedDict(( 
*((f'evaluation/{key}', evaluation_metrics[key]) for key in sorted(evaluation_metrics.keys())), *((f'training/{key}', training_metrics[key]) for key in sorted(training_metrics.keys())), *((f'times/{key}', time_diagnostics[key][-1]) for key in sorted(time_diagnostics.keys())), *((f'sampler/{key}', sampler_diagnostics[key]) for key in sorted(sampler_diagnostics.keys())), *((f'model/{key}', model_metrics[key]) for key in sorted(model_metrics.keys())), ('epoch', self._epoch), ('timestep', self._timestep), ('timesteps_total', self._total_timestep), ('train-steps', self._num_train_steps), ))) if self._eval_render_mode is not None and hasattr( evaluation_environment, 'render_rollouts'): training_environment.render_rollouts(evaluation_paths) yield diagnostics self.sampler.terminate() self._training_after_hook() self._training_progress.close() yield {'done': True, **diagnostics} def train(self, *args, **kwargs): return self._train(*args, **kwargs) def _log_policy(self): save_path = os.path.join(self._log_dir, 'models') filesystem.mkdir(save_path) weights = self._policy.get_weights() data = {'policy_weights': weights} full_path = os.path.join(save_path, 'policy_{}.pkl'.format(self._total_timestep)) print('Saving policy to: {}'.format(full_path)) pickle.dump(data, open(full_path, 'wb')) def _log_model(self): save_path = os.path.join(self._log_dir, 'models') filesystem.mkdir(save_path) print('Saving model to: {}'.format(save_path)) self._model.save(save_path, self._total_timestep) def _set_rollout_length(self): min_epoch, max_epoch, min_length, max_length = self._rollout_schedule if self._epoch <= min_epoch: y = min_length else: dx = (self._epoch - min_epoch) / (max_epoch - min_epoch) dx = min(dx, 1) y = dx * (max_length - min_length) + min_length self._rollout_length = int(y) print( '[ Model Length ] Epoch: {} (min: {}, max: {}) | Length: {} (min: {} , max: {})' .format(self._epoch, min_epoch, max_epoch, self._rollout_length, min_length, max_length)) def _reallocate_model_pool(self): obs_space = self._pool._observation_space act_space = self._pool._action_space rollouts_per_epoch = self._rollout_batch_size * self._epoch_length / self._model_train_freq model_steps_per_epoch = int(self._rollout_length * rollouts_per_epoch) new_pool_size = self._model_retain_epochs * model_steps_per_epoch if not hasattr(self, '_model_pool'): print( '[ MBPO ] Initializing new model pool with size {:.2e}'.format( new_pool_size)) self._model_pool = SimpleReplayPool(obs_space, act_space, new_pool_size) elif self._model_pool._max_size != new_pool_size: print('[ MBPO ] Updating model pool | {:.2e} --> {:.2e}'.format( self._model_pool._max_size, new_pool_size)) samples = self._model_pool.return_all_samples() new_pool = SimpleReplayPool(obs_space, act_space, new_pool_size) new_pool.add_samples(samples) assert self._model_pool.size == new_pool.size self._model_pool = new_pool def _train_model(self, **kwargs): env_samples = self._pool.return_all_samples() train_inputs, train_outputs = format_samples_for_training(env_samples) model_metrics = self._model.train(train_inputs, train_outputs, **kwargs) return model_metrics def _rollout_model(self, rollout_batch_size, **kwargs): print( '[ Model Rollout ] Starting | Epoch: {} | Rollout length: {} | Batch size: {}' .format(self._epoch, self._rollout_length, rollout_batch_size)) batch = self.sampler.random_batch(rollout_batch_size) obs = batch['observations'] steps_added = [] for i in range(self._rollout_length): act = self._policy.actions_np(obs) next_obs, rew, term, info = 
self.fake_env.step(obs, act, **kwargs) steps_added.append(len(obs)) samples = { 'observations': obs, 'actions': act, 'next_observations': next_obs, 'rewards': rew, 'terminals': term } self._model_pool.add_samples(samples) nonterm_mask = ~term.squeeze(-1) if nonterm_mask.sum() == 0: print('[ Model Rollout ] Breaking early: {} | {} / {}'.format( i, nonterm_mask.sum(), nonterm_mask.shape)) break obs = next_obs[nonterm_mask] mean_rollout_length = sum(steps_added) / rollout_batch_size rollout_stats = {'mean_rollout_length': mean_rollout_length} print( '[ Model Rollout ] Added: {:.1e} | Model pool: {:.1e} (max {:.1e}) | Length: {} | Train rep: {}' .format(sum(steps_added), self._model_pool.size, self._model_pool._max_size, mean_rollout_length, self._n_train_repeat)) return rollout_stats def _visualize_model(self, env, timestep): ## save env state state = env.unwrapped.state_vector() qpos_dim = len(env.unwrapped.sim.data.qpos) qpos = state[:qpos_dim] qvel = state[qpos_dim:] print('[ Visualization ] Starting | Epoch {} | Log dir: {}\n'.format( self._epoch, self._log_dir)) visualize_policy(env, self.fake_env, self._policy, self._writer, timestep) print('[ Visualization ] Done') ## set env state env.unwrapped.set_state(qpos, qvel) def _training_batch(self, batch_size=None): batch_size = batch_size or self.sampler._batch_size env_batch_size = int(batch_size * self._real_ratio) model_batch_size = batch_size - env_batch_size ## can sample from the env pool even if env_batch_size == 0 env_batch = self._pool.random_batch(env_batch_size) if model_batch_size > 0: model_batch = self._model_pool.random_batch(model_batch_size) keys = env_batch.keys() batch = { k: np.concatenate((env_batch[k], model_batch[k]), axis=0) for k in keys } else: ## if real_ratio == 1.0, no model pool was ever allocated, ## so skip the model pool sampling batch = env_batch return batch def _init_global_step(self): self.global_step = training_util.get_or_create_global_step() self._training_ops.update( {'increment_global_step': training_util._increment_global_step(1)}) def _init_placeholders(self): """Create input placeholders for the SAC algorithm. 
Creates `tf.placeholder`s for: - observation - next observation - action - reward - terminals """ self._iteration_ph = tf.placeholder(tf.int64, shape=None, name='iteration') self._observations_ph = tf.placeholder( tf.float32, shape=(None, *self._observation_shape), name='observation', ) self._next_observations_ph = tf.placeholder( tf.float32, shape=(None, *self._observation_shape), name='next_observation', ) self._actions_ph = tf.placeholder( tf.float32, shape=(None, *self._action_shape), name='actions', ) self._rewards_ph = tf.placeholder( tf.float32, shape=(None, 1), name='rewards', ) self._terminals_ph = tf.placeholder( tf.float32, shape=(None, 1), name='terminals', ) if self._store_extra_policy_info: self._log_pis_ph = tf.placeholder( tf.float32, shape=(None, 1), name='log_pis', ) self._raw_actions_ph = tf.placeholder( tf.float32, shape=(None, *self._action_shape), name='raw_actions', ) def _get_Q_target(self): next_actions = self._policy.actions([self._next_observations_ph]) next_log_pis = self._policy.log_pis([self._next_observations_ph], next_actions) next_Qs_values = tuple( Q([self._next_observations_ph, next_actions]) for Q in self._Q_targets) min_next_Q = tf.reduce_min(next_Qs_values, axis=0) next_value = min_next_Q - self._alpha * next_log_pis Q_target = td_target(reward=self._reward_scale * self._rewards_ph, discount=self._discount, next_value=(1 - self._terminals_ph) * next_value) return Q_target def _init_critic_update(self): """Create minimization operation for critic Q-function. Creates a `tf.optimizer.minimize` operation for updating critic Q-function with gradient descent, and appends it to `self._training_ops` attribute. """ Q_target = tf.stop_gradient(self._get_Q_target()) assert Q_target.shape.as_list() == [None, 1] Q_values = self._Q_values = tuple( Q([self._observations_ph, self._actions_ph]) for Q in self._Qs) Q_losses = self._Q_losses = tuple( tf.losses.mean_squared_error( labels=Q_target, predictions=Q_value, weights=0.5) for Q_value in Q_values) self._Q_optimizers = tuple( tf.train.AdamOptimizer(learning_rate=self._Q_lr, name='{}_{}_optimizer'.format(Q._name, i)) for i, Q in enumerate(self._Qs)) Q_training_ops = tuple( tf.contrib.layers.optimize_loss(Q_loss, self.global_step, learning_rate=self._Q_lr, optimizer=Q_optimizer, variables=Q.trainable_variables, increment_global_step=False, summaries=(( "loss", "gradients", "gradient_norm", "global_gradient_norm" ) if self._tf_summaries else ())) for i, (Q, Q_loss, Q_optimizer) in enumerate( zip(self._Qs, Q_losses, self._Q_optimizers))) self._training_ops.update({'Q': tf.group(Q_training_ops)}) def _init_actor_update(self): """Create minimization operations for policy and entropy. Creates a `tf.optimizer.minimize` operations for updating policy and entropy with gradient descent, and adds them to `self._training_ops` attribute. 
""" actions = self._policy.actions([self._observations_ph]) log_pis = self._policy.log_pis([self._observations_ph], actions) assert log_pis.shape.as_list() == [None, 1] log_alpha = self._log_alpha = tf.get_variable('log_alpha', dtype=tf.float32, initializer=0.0) alpha = tf.exp(log_alpha) if isinstance(self._target_entropy, Number): alpha_loss = -tf.reduce_mean( log_alpha * tf.stop_gradient(log_pis + self._target_entropy)) self._alpha_optimizer = tf.train.AdamOptimizer( self._policy_lr, name='alpha_optimizer') self._alpha_train_op = self._alpha_optimizer.minimize( loss=alpha_loss, var_list=[log_alpha]) self._training_ops.update( {'temperature_alpha': self._alpha_train_op}) self._alpha = alpha if self._action_prior == 'normal': policy_prior = tf.contrib.distributions.MultivariateNormalDiag( loc=tf.zeros(self._action_shape), scale_diag=tf.ones(self._action_shape)) policy_prior_log_probs = policy_prior.log_prob(actions) elif self._action_prior == 'uniform': policy_prior_log_probs = 0.0 Q_log_targets = tuple( Q([self._observations_ph, actions]) for Q in self._Qs) min_Q_log_target = tf.reduce_min(Q_log_targets, axis=0) if self._reparameterize: policy_kl_losses = (alpha * log_pis - min_Q_log_target - policy_prior_log_probs) else: raise NotImplementedError assert policy_kl_losses.shape.as_list() == [None, 1] policy_loss = tf.reduce_mean(policy_kl_losses) self._policy_optimizer = tf.train.AdamOptimizer( learning_rate=self._policy_lr, name="policy_optimizer") policy_train_op = tf.contrib.layers.optimize_loss( policy_loss, self.global_step, learning_rate=self._policy_lr, optimizer=self._policy_optimizer, variables=self._policy.trainable_variables, increment_global_step=False, summaries=("loss", "gradients", "gradient_norm", "global_gradient_norm") if self._tf_summaries else ()) self._training_ops.update({'policy_train_op': policy_train_op}) def _init_training(self): self._update_target(tau=1.0) def _update_target(self, tau=None): tau = tau or self._tau for Q, Q_target in zip(self._Qs, self._Q_targets): source_params = Q.get_weights() target_params = Q_target.get_weights() Q_target.set_weights([ tau * source + (1.0 - tau) * target for source, target in zip(source_params, target_params) ]) def _do_training(self, iteration, batch): """Runs the operations for updating training and target ops.""" self._training_progress.update() self._training_progress.set_description() feed_dict = self._get_feed_dict(iteration, batch) self._session.run(self._training_ops, feed_dict) if iteration % self._target_update_interval == 0: # Run target ops here. self._update_target() def _get_feed_dict(self, iteration, batch): """Construct TensorFlow feed_dict from sample batch.""" feed_dict = { self._observations_ph: batch['observations'], self._actions_ph: batch['actions'], self._next_observations_ph: batch['next_observations'], self._rewards_ph: batch['rewards'], self._terminals_ph: batch['terminals'], } if self._store_extra_policy_info: feed_dict[self._log_pis_ph] = batch['log_pis'] feed_dict[self._raw_actions_ph] = batch['raw_actions'] if iteration is not None: feed_dict[self._iteration_ph] = iteration return feed_dict def get_diagnostics(self, iteration, batch, training_paths, evaluation_paths): """Return diagnostic information as ordered dictionary. Records mean and standard deviation of Q-function and state value function, and TD-loss (mean squared Bellman error) for the sample batch. Also calls the `draw` method of the plotter, if plotter defined. 
""" feed_dict = self._get_feed_dict(iteration, batch) (Q_values, Q_losses, alpha, global_step) = self._session.run( (self._Q_values, self._Q_losses, self._alpha, self.global_step), feed_dict) diagnostics = OrderedDict({ 'Q-avg': np.mean(Q_values), 'Q-std': np.std(Q_values), 'Q_loss': np.mean(Q_losses), 'alpha': alpha, }) policy_diagnostics = self._policy.get_diagnostics( batch['observations']) diagnostics.update({ f'policy/{key}': value for key, value in policy_diagnostics.items() }) if self._plotter: self._plotter.draw() return diagnostics @property def tf_saveables(self): saveables = { '_policy_optimizer': self._policy_optimizer, **{ f'Q_optimizer_{i}': optimizer for i, optimizer in enumerate(self._Q_optimizers) }, '_log_alpha': self._log_alpha, } if hasattr(self, '_alpha_optimizer'): saveables['_alpha_optimizer'] = self._alpha_optimizer return saveables def save_model(self, dir): self._model.save(savedir=dir, timestep=self._epoch + 1)
def train(self, inputs, targets, batch_size=32, max_epochs=None, max_epochs_since_update=5,
          hide_progress=False, holdout_ratio=0.0, max_logging=5000, max_grad_updates=None,
          timer=None, max_t=None):
    """Trains/continues network training.

    Arguments:
        inputs (np.ndarray): Network inputs in the training dataset in rows.
        targets (np.ndarray): Network target outputs in the training dataset in rows
            corresponding to the rows in inputs.
        batch_size (int): The minibatch size to be used for training.
        max_epochs (int): Maximum number of epochs (full passes over the training data).
        hide_progress (bool): If True, hides the progress bar shown during training.

    Returns: None
    """
    self._max_epochs_since_update = max_epochs_since_update
    self._start_train()
    break_train = False

    def shuffle_rows(arr):
        idxs = np.argsort(np.random.uniform(size=arr.shape), axis=-1)
        return arr[np.arange(arr.shape[0])[:, None], idxs]

    # Split into training and holdout sets
    num_holdout = min(int(inputs.shape[0] * holdout_ratio), max_logging)
    permutation = np.random.permutation(inputs.shape[0])
    inputs, holdout_inputs = inputs[permutation[num_holdout:]], inputs[permutation[:num_holdout]]
    targets, holdout_targets = targets[permutation[num_holdout:]], targets[permutation[:num_holdout]]
    holdout_inputs = np.tile(holdout_inputs[None], [self.num_nets, 1, 1])
    holdout_targets = np.tile(holdout_targets[None], [self.num_nets, 1, 1])

    print('[ BNN ] Training {} | Holdout: {}'.format(inputs.shape, holdout_inputs.shape))

    with self.sess.as_default():
        self.scaler.fit(inputs)

    idxs = np.random.randint(inputs.shape[0], size=[self.num_nets, inputs.shape[0]])
    if hide_progress:
        progress = Silent()
    else:
        progress = Progress(max_epochs)

    if max_epochs:
        epoch_iter = range(max_epochs)
    else:
        epoch_iter = itertools.count()

    # else:
    #     epoch_range = trange(epochs, unit="epoch(s)", desc="Network training")

    t0 = time.time()
    grad_updates = 0
    for epoch in epoch_iter:
        for batch_num in range(int(np.ceil(idxs.shape[-1] / batch_size))):
            batch_idxs = idxs[:, batch_num * batch_size:(batch_num + 1) * batch_size]
            self.sess.run(
                self.train_op,
                feed_dict={
                    self.sy_train_in: inputs[batch_idxs],
                    self.sy_train_targ: targets[batch_idxs]
                })
            grad_updates += 1

        idxs = shuffle_rows(idxs)
        if not hide_progress:
            if holdout_ratio < 1e-12:
                losses = self.sess.run(
                    self.mse_loss,
                    feed_dict={
                        self.sy_train_in: inputs[idxs[:, :max_logging]],
                        self.sy_train_targ: targets[idxs[:, :max_logging]]
                    })
                named_losses = [['M{}'.format(i), losses[i]] for i in range(len(losses))]
                progress.set_description(named_losses)
            else:
                losses = self.sess.run(
                    self.mse_loss,
                    feed_dict={
                        self.sy_train_in: inputs[idxs[:, :max_logging]],
                        self.sy_train_targ: targets[idxs[:, :max_logging]]
                    })
                holdout_losses = self.sess.run(
                    self.mse_loss,
                    feed_dict={
                        self.sy_train_in: holdout_inputs,
                        self.sy_train_targ: holdout_targets
                    })
                named_losses = [['M{}'.format(i), losses[i]] for i in range(len(losses))]
                named_holdout_losses = [['V{}'.format(i), holdout_losses[i]]
                                        for i in range(len(holdout_losses))]
                named_losses = named_losses + named_holdout_losses + [['T', time.time() - t0]]
                progress.set_description(named_losses)

                break_train = self._save_best(epoch, holdout_losses)

        progress.update()
        t = time.time() - t0

        if break_train or (max_grad_updates and grad_updates > max_grad_updates):
            break
        if max_t and t > max_t:
            descr = 'Breaking because of timeout: {}! (max: {})'.format(t, max_t)
            progress.append_description(descr)
            # print('Breaking because of timeout: {}! | (max: {})\n'.format(t, max_t))
            # time.sleep(5)
            break

    progress.stamp()
    if timer: timer.stamp('bnn_train')

    self._set_state()
    if timer: timer.stamp('bnn_set_state')

    holdout_losses = self.sess.run(
        self.mse_loss,
        feed_dict={
            self.sy_train_in: holdout_inputs,
            self.sy_train_targ: holdout_targets
        })
    if timer: timer.stamp('bnn_holdout')

    self._end_train(holdout_losses)
    if timer: timer.stamp('bnn_end')

    val_loss = (np.sort(holdout_losses)[:self.num_elites]).mean()
    model_metrics = {'val_loss': val_loss}
    print('[ BNN ] Holdout', np.sort(holdout_losses), model_metrics)
    return OrderedDict(model_metrics)
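# --- Illustration (not part of the original source): the validation metric that
# `train` above returns averages the holdout MSE over only the `num_elites` best
# ensemble members. Synthetic numbers; the helper name is hypothetical.
import numpy as np

def _elite_val_loss_example(num_elites=5):
    # One holdout loss per ensemble member (7 networks, as in the MBPO defaults).
    holdout_losses = np.array([0.12, 0.31, 0.09, 0.25, 0.11, 0.44, 0.10])
    # Elites are the num_elites lowest losses: 0.09, 0.10, 0.11, 0.12, 0.25.
    return np.sort(holdout_losses)[:num_elites].mean()

assert abs(_elite_val_loss_example() - 0.134) < 1e-6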
class CMBPO(RLAlgorithm): """Model-Based Policy Optimization (MBPO) References ---------- Michael Janner, Justin Fu, Marvin Zhang, Sergey Levine. When to Trust Your Model: Model-Based Policy Optimization. arXiv preprint arXiv:1906.08253. 2019. """ def __init__( self, training_environment, evaluation_environment, mjc_model_environment, policy, Qs, pool, static_fns, plotter=None, tf_summaries=False, n_env_interacts=1e7, lr=3e-4, reward_scale=1.0, target_entropy='auto', discount=0.99, tau=5e-3, target_update_interval=1, action_prior='uniform', reparameterize=False, store_extra_policy_info=False, eval_every_n_steps=5e3, deterministic=False, model_train_freq=250, num_networks=7, num_elites=5, model_retain_epochs=20, rollout_batch_size=100e3, real_ratio=0.1, dyn_model_train_schedule=[20, 100, 1, 5], cost_model_train_schedule=[20, 100, 1, 30], cares_about_cost=False, policy_alpha=1, max_uncertainty_rew=None, max_uncertainty_c=None, rollout_mode='schedule', rollout_schedule=[20, 100, 1, 1], maxroll=80, max_tddyn_err=1e-5, max_tddyn_err_decay=.995, min_real_samples_per_epoch=1000, batch_size_policy=5000, hidden_dims=(200, 200, 200, 200), max_model_t=None, use_mjc_state_model=False, model_std_inc=0.02, **kwargs, ): """ Args: env (`SoftlearningEnv`): Environment used for training. policy: A policy function approximator. initial_exploration_policy: ('Policy'): A policy that we use for initial exploration which is not trained by the algorithm. Qs: Q-function approximators. The min of these approximators will be used. Usage of at least two Q-functions improves performance by reducing overestimation bias. pool (`PoolBase`): Replay pool to add gathered samples to. plotter (`QFPolicyPlotter`): Plotter instance to be used for visualizing Q-function during training. lr (`float`): Learning rate used for the function approximators. discount (`float`): Discount factor for Q-function updates. tau (`float`): Soft value function target update weight. target_update_interval ('int'): Frequency at which target network updates occur in iterations. reparameterize ('bool'): If True, we use a gradient estimator for the policy derived using the reparameterization trick. We use a likelihood ratio based estimator otherwise. 
""" super(CMBPO, self).__init__(**kwargs) self.obs_space = training_environment.observation_space self.act_space = training_environment.action_space self.obs_dim = np.prod(training_environment.observation_space.shape) self.act_dim = np.prod(training_environment.action_space.shape) self.n_env_interacts = n_env_interacts #### determine unstacked obs dim self.num_stacks = training_environment.stacks self.stacking_axis = training_environment.stacking_axis self.active_obs_dim = int(self.obs_dim / self.num_stacks) self.policy_alpha = policy_alpha self.cares_about_cost = cares_about_cost ## create fake environment for model self.fake_env = FakeEnv(training_environment, static_fns, num_networks=7, num_elites=3, hidden_dims=hidden_dims, cares_about_cost=cares_about_cost, session=self._session) self.use_mjc_state_model = use_mjc_state_model self._rollout_schedule = rollout_schedule self._max_model_t = max_model_t self._model_retain_epochs = model_retain_epochs self.eval_every_n_steps = eval_every_n_steps self._dyn_model_train_schedule = dyn_model_train_schedule self._cost_model_train_schedule = cost_model_train_schedule self._dyn_model_train_freq = 1 self._cost_model_train_freq = 1 self._rollout_batch_size = int(rollout_batch_size) self._max_uncertainty_rew = max_uncertainty_rew self._max_uncertainty_c = max_uncertainty_c self._deterministic = deterministic self._real_ratio = real_ratio self._log_dir = os.getcwd() self._writer = Writer(self._log_dir) self._training_environment = training_environment self._evaluation_environment = evaluation_environment self._mjc_model_environment = mjc_model_environment self.perturbed_env = PerturbedEnv(self._mjc_model_environment, std_inc=model_std_inc) self._policy = policy self._initial_exploration_policy = policy #overwriting initial _exploration policy, not implemented for cpo yet self._pool = pool self._plotter = plotter self._tf_summaries = tf_summaries # set up pool pi_info_shapes = { k: v.shape.as_list()[1:] for k, v in self._policy.pi_info_phs.items() } self._pool.initialize(pi_info_shapes, gamma=self._policy.gamma, lam=self._policy.lam, cost_gamma=self._policy.cost_gamma, cost_lam=self._policy.cost_lam) self._policy_lr = lr self._reward_scale = reward_scale self._target_entropy = ( -np.prod(self._training_environment.action_space.shape) if target_entropy == 'auto' else target_entropy) print('[ MBPO ] Target entropy: {}'.format(self._target_entropy)) self._discount = discount self._tau = tau self._target_update_interval = target_update_interval self._action_prior = action_prior self._reparameterize = reparameterize self._store_extra_policy_info = store_extra_policy_info observation_shape = self._training_environment.active_observation_shape action_shape = self._training_environment.action_space.shape assert len(observation_shape) == 1, observation_shape self._observation_shape = observation_shape assert len(action_shape) == 1, action_shape self._action_shape = action_shape ### model sampler and buffer self.rollout_mode = rollout_mode self.max_tddyn_err = max_tddyn_err self.max_tddyn_err_decay = max_tddyn_err_decay self.min_real_samples_per_epoch = min_real_samples_per_epoch self.batch_size_policy = batch_size_policy self.model_pool = ModelBuffer( batch_size=self._rollout_batch_size, max_path_length=maxroll, env=self.fake_env, ensemble_size=num_networks, rollout_mode=self.rollout_mode, cares_about_cost=cares_about_cost, max_uncertainty_c=self._max_uncertainty_c, max_uncertainty_r=self._max_uncertainty_rew, ) self.model_pool.initialize( pi_info_shapes, 
gamma=self._policy.gamma, lam=self._policy.lam, cost_gamma=self._policy.cost_gamma, cost_lam=self._policy.cost_lam, ) #@anyboby debug self.model_sampler = ModelSampler( max_path_length=maxroll, batch_size=self._rollout_batch_size, store_last_n_paths=10, cares_about_cost=cares_about_cost, max_uncertainty_c=self._max_uncertainty_c, max_uncertainty_r=self._max_uncertainty_rew, logger=None, rollout_mode=self.rollout_mode, ) # provide policy and sampler with the same logger self.logger = EpochLogger() self._policy.set_logger(self.logger) self.sampler.set_logger(self.logger) #self.model_sampler.set_logger(self.logger) def _train(self): """Return a generator that performs RL training. Args: env (`SoftlearningEnv`): Environment used for training. policy (`Policy`): Policy used for training initial_exploration_policy ('Policy'): Policy used for exploration If None, then all exploration is done using policy pool (`PoolBase`): Sample pool to add samples to """ #### pool is e.g. simple_replay_pool training_environment = self._training_environment evaluation_environment = self._evaluation_environment policy = self._policy pool = self._pool if not self._training_started: #### perform some initial steps (gather samples) using initial policy ###### fills pool with _n_initial_exploration_steps samples self._initial_exploration_hook(training_environment, self._policy, pool) #### set up sampler with train env and actual policy (may be different from initial exploration policy) ######## note: sampler is set up with the pool that may be already filled from initial exploration hook self.sampler.initialize(training_environment, policy, pool) self.model_sampler.initialize(self.fake_env, policy, self.model_pool) rollout_dkl_lim = self.model_sampler.compute_dynamics_dkl( obs_batch=self._pool.rand_batch_from_archive( 5000, fields=['observations'])['observations'], depth=self._rollout_schedule[2]) self.model_sampler.set_rollout_dkl(rollout_dkl_lim) self.initial_model_dkl = self.model_sampler.dyn_dkl #### reset gtimer (for coverage of project development) gt.reset_root() gt.rename_root('RLAlgorithm') gt.set_def_unique(False) self.policy_epoch = 0 ### count policy updates self.new_real_samples = 0 self.last_eval_step = 0 self.diag_counter = 0 running_diag = {} self.approx_model_batch = self.batch_size_policy - self.min_real_samples_per_epoch ### some size to start off #### not implemented, could train policy before hook self._training_before_hook() #### iterate over epochs, gt.timed_for to create loop with gt timestamps for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)): #### do something at beginning of epoch (in this case reset self._train_steps_this_epoch=0) self._epoch_before_hook() gt.stamp('epoch_before_hook') #### util class Progress, e.g. 
for plotting a progress bar ####### note: sampler may already contain samples in its pool from initial_exploration_hook or previous epochs self._training_progress = Progress(self._epoch_length * self._n_train_repeat / self._train_every_n_steps) samples_added = 0 #=====================================================================# # Rollout model # #=====================================================================# model_samples = None keep_rolling = True model_metrics = {} #### start model rollout if self._real_ratio < 1.0: #if self._timestep % self._model_train_freq == 0 and self._real_ratio < 1.0: #=====================================================================# # Model Rollouts # #=====================================================================# if self.rollout_mode == 'schedule': self._set_rollout_length() while keep_rolling: ep_b = self._pool.epoch_batch( batch_size=self._rollout_batch_size, epochs=self._pool.epochs_list, fields=['observations', 'pi_infos']) kls = np.clip(self._policy.compute_DKL( ep_b['observations'], ep_b['mu'], ep_b['log_std']), a_min=0, a_max=None) btz_dist = self._pool.boltz_dist(kls, alpha=self.policy_alpha) btz_b = self._pool.distributed_batch_from_archive( self._rollout_batch_size, btz_dist, fields=['observations', 'pi_infos']) start_states, mus, logstds = btz_b['observations'], btz_b[ 'mu'], btz_b['log_std'] btz_kl = np.clip(self._policy.compute_DKL( start_states, mus, logstds), a_min=0, a_max=None) self.model_sampler.reset(start_states) if self.rollout_mode == 'uncertainty': self.model_sampler.set_max_uncertainty( self.max_tddyn_err) for i in count(): # print(f'Model Sampling step Nr. {i+1}') _, _, _, info = self.model_sampler.sample( max_samples=int(self.approx_model_batch - samples_added)) if self.model_sampler._total_samples + samples_added >= .99 * self.approx_model_batch: keep_rolling = False break if info['alive_ratio'] <= 0.1: break ### diagnostics for rollout ### rollout_diagnostics = self.model_sampler.finish_all_paths() if self.rollout_mode == 'iv_gae': keep_rolling = self.model_pool.size + samples_added <= .99 * self.approx_model_batch ###################################################################### ### get model_samples, get() invokes the inverse variance rollouts ### model_samples_new, buffer_diagnostics_new = self.model_pool.get( ) model_samples = [ np.concatenate((o, n), axis=0) for o, n in zip(model_samples, model_samples_new) ] if model_samples else model_samples_new ###################################################################### ### diagnostics new_n_samples = len(model_samples_new[0]) + EPS diag_weight_old = samples_added / (new_n_samples + samples_added) diag_weight_new = new_n_samples / (new_n_samples + samples_added) model_metrics = update_dict(model_metrics, rollout_diagnostics, weight_a=diag_weight_old, weight_b=diag_weight_new) model_metrics = update_dict(model_metrics, buffer_diagnostics_new, weight_a=diag_weight_old, weight_b=diag_weight_new) ### run diagnostics on model data if buffer_diagnostics_new['poolm_batch_size'] > 0: model_data_diag = self._policy.run_diagnostics( model_samples_new) model_data_diag = { k + '_m': v for k, v in model_data_diag.items() } model_metrics = update_dict(model_metrics, model_data_diag, weight_a=diag_weight_old, weight_b=diag_weight_new) samples_added += new_n_samples model_metrics.update({'samples_added': samples_added}) ###################################################################### ## for debugging model_metrics.update({ 'cached_var': 
np.mean(self.fake_env._model.scaler_out.cached_var) }) model_metrics.update({ 'cached_mu': np.mean(self.fake_env._model.scaler_out.cached_mu) }) print(f'Rollouts finished') gt.stamp('epoch_rollout_model') #=====================================================================# # Sample # #=====================================================================# n_real_samples = self.model_sampler.dyn_dkl / self.initial_model_dkl * self.min_real_samples_per_epoch n_real_samples = max(n_real_samples, 1000) # n_real_samples = self.min_real_samples_per_epoch ### for ablation model_metrics.update({'n_real_samples': n_real_samples}) start_samples = self.sampler._total_samples ### train for epoch_length ### for i in count(): #### _timestep is within an epoch samples_now = self.sampler._total_samples self._timestep = samples_now - start_samples #### not implemented atm self._timestep_before_hook() gt.stamp('timestep_before_hook') ##### śampling from the real world ! ##### _, _, _, _ = self._do_sampling(timestep=self.policy_epoch) gt.stamp('sample') self._timestep_after_hook() gt.stamp('timestep_after_hook') if self.ready_to_train or self._timestep > n_real_samples: self.sampler.finish_all_paths(append_val=True, append_cval=True, reset_path=False) self.new_real_samples += self._timestep break #=====================================================================# # Train model # #=====================================================================# if self.new_real_samples > 2048 and self._real_ratio < 1.0: model_diag = self.train_model(min_epochs=1, max_epochs=10) self.new_real_samples = 0 model_metrics.update(model_diag) #=====================================================================# # Get Buffer Data # #=====================================================================# real_samples, buf_diag = self._pool.get() ### run diagnostics on real data policy_diag = self._policy.run_diagnostics(real_samples) policy_diag = {k + '_r': v for k, v in policy_diag.items()} model_metrics.update(policy_diag) model_metrics.update(buf_diag) #=====================================================================# # Update Policy # #=====================================================================# train_samples = [ np.concatenate((r, m), axis=0) for r, m in zip(real_samples, model_samples) ] if model_samples else real_samples self._policy.update_real_c(real_samples) self._policy.update_policy(train_samples) self._policy.update_critic( train_samples, train_vc=(train_samples[-3] > 0).any()) ### only train vc if there are any costs if self._real_ratio < 1.0: self.approx_model_batch = self.batch_size_policy - n_real_samples #self.model_sampler.dyn_dkl/self.initial_model_dkl * self.min_real_samples_per_epoch self.policy_epoch += 1 self.max_tddyn_err *= self.max_tddyn_err_decay #### log policy diagnostics self._policy.log() gt.stamp('train') #=====================================================================# # Log performance and stats # #=====================================================================# self.sampler.log() # write results to file, ray prints for us, so no need to print from logger logger_diagnostics = self.logger.dump_tabular( output_dir=self._log_dir, print_out=False) #=====================================================================# if self._total_timestep // self.eval_every_n_steps > self.last_eval_step: evaluation_paths = self._evaluation_paths( policy, evaluation_environment) gt.stamp('evaluation_paths') self.last_eval_step = self._total_timestep // self.eval_every_n_steps 
else: evaluation_paths = [] if evaluation_paths: evaluation_metrics = self._evaluate_rollouts( evaluation_paths, evaluation_environment) gt.stamp('evaluation_metrics') diag_obs_batch = np.concatenate(([ evaluation_paths[i]['observations'] for i in range(len(evaluation_paths)) ]), axis=0) else: evaluation_metrics = {} diag_obs_batch = [] gt.stamp('epoch_after_hook') new_diagnostics = {} time_diagnostics = gt.get_times().stamps.itrs # add diagnostics from logger new_diagnostics.update(logger_diagnostics) new_diagnostics.update( OrderedDict(( *((f'evaluation/{key}', evaluation_metrics[key]) for key in sorted(evaluation_metrics.keys())), *((f'times/{key}', time_diagnostics[key][-1]) for key in sorted(time_diagnostics.keys())), *((f'model/{key}', model_metrics[key]) for key in sorted(model_metrics.keys())), ))) if self._eval_render_mode is not None and hasattr( evaluation_environment, 'render_rollouts'): training_environment.render_rollouts(evaluation_paths) #### updateing and averaging old_ts_diag = running_diag.get('timestep', 0) new_ts_diag = self._total_timestep - self.diag_counter - old_ts_diag w_olddiag = old_ts_diag / (new_ts_diag + old_ts_diag) w_newdiag = new_ts_diag / (new_ts_diag + old_ts_diag) running_diag = update_dict(running_diag, new_diagnostics, weight_a=w_olddiag, weight_b=w_newdiag) running_diag.update({'timestep': new_ts_diag + old_ts_diag}) #### if new_ts_diag + old_ts_diag > self.eval_every_n_steps: running_diag.update({ 'epoch': self._epoch, 'timesteps_total': self._total_timestep, 'train-steps': self._num_train_steps, }) self.diag_counter = self._total_timestep diag = running_diag.copy() running_diag = {} yield diag if self._total_timestep >= self.n_env_interacts: self.sampler.terminate() self._training_after_hook() self._training_progress.close() print("###### DONE ######") yield {'done': True, **running_diag} break def train(self, *args, **kwargs): return self._train(*args, **kwargs) def _initial_exploration_hook(self, env, initial_exploration_policy, pool): if self._n_initial_exploration_steps < 1: return if not initial_exploration_policy: raise ValueError("Initial exploration policy must be provided when" " n_initial_exploration_steps > 0.") self.sampler.initialize(env, initial_exploration_policy, pool) while True: self.sampler.sample(timestep=0) if self.sampler._total_samples >= self._n_initial_exploration_steps: self.sampler.finish_all_paths(append_val=True, append_cval=True, reset_path=False) pool.get() # moves policy samples to archive break ### train model if self._real_ratio < 1.0: self.train_model(min_epochs=150, max_epochs=500) def train_model(self, min_epochs=5, max_epochs=100, batch_size=2048): self._dyn_model_train_freq = self._set_model_train_freq( self._dyn_model_train_freq, self._dyn_model_train_schedule) ## set current train freq. 
self._cost_model_train_freq = self._set_model_train_freq( self._cost_model_train_freq, self._cost_model_train_schedule) print('[ MBPO ] log_dir: {} | ratio: {}'.format( self._log_dir, self._real_ratio)) print( '[ MBPO ] Training model at epoch {} | freq {} | timestep {} (total: {}) (total: {})' .format(self._epoch, self._dyn_model_train_freq, self._timestep, self._total_timestep, self._num_train_steps)) model_samples = self._pool.get_archive([ 'observations', 'actions', 'next_observations', 'rewards', 'costs', 'terminals', 'epochs', ]) if self._epoch % self._dyn_model_train_freq == 0: diag_dyn = self.fake_env.train_dyn_model( model_samples, batch_size=batch_size, #512 max_epochs=max_epochs, # max_epochs min_epoch_before_break=min_epochs, # min_epochs, holdout_ratio=0.2, max_t=self._max_model_t) if self._epoch % self._cost_model_train_freq == 0 and self.fake_env.cares_about_cost: diag_c = self.fake_env.train_cost_model( model_samples, batch_size=batch_size, #batch_size, #512, min_epoch_before_break=min_epochs, #min_epochs, max_epochs=max_epochs, # max_epochs, holdout_ratio=0.2, max_t=self._max_model_t) diag_dyn.update(diag_c) return diag_dyn @property def _total_timestep(self): total_timestep = self.sampler._total_samples return total_timestep def _set_rollout_length(self): min_epoch, max_epoch, min_length, max_length = self._rollout_schedule if self._epoch <= min_epoch: y = min_length else: dx = (self._epoch - min_epoch) / (max_epoch - min_epoch) dx = min(dx, 1) y = dx * (max_length - min_length) + min_length self._rollout_length = int(y) self.model_sampler.set_max_path_length(self._rollout_length) print( '[ Model Length ] Epoch: {} (min: {}, max: {}) | Length: {} (min: {} , max: {})' .format(self._epoch, min_epoch, max_epoch, self._rollout_length, min_length, max_length)) def _do_sampling(self, timestep): return self.sampler.sample(timestep=timestep) def _set_model_train_freq(self, var, schedule): min_epoch, max_epoch, min_freq, max_freq = schedule if self._epoch <= min_epoch: y = min_freq else: dx = (self._epoch - min_epoch) / (max_epoch - min_epoch) dx = min(dx, 1) y = dx * (max_freq - min_freq) + min_freq var = int(y) print( '[ Model Train Frequency ] Epoch: {} (min: {}, max: {}) | Frequency: {} (min: {} , max: {})' .format(self._epoch, min_epoch, max_epoch, var, min_freq, max_freq)) return var def _evaluate_rollouts(self, paths, env): """Compute evaluation metrics for the given rollouts.""" total_returns = [path['rewards'].sum() for path in paths] episode_lengths = [len(p['rewards']) for p in paths] total_cost = [path['cost'].sum() for path in paths] diagnostics = OrderedDict(( ('return-average', np.mean(total_returns)), ('return-min', np.min(total_returns)), ('return-max', np.max(total_returns)), ('return-std', np.std(total_returns)), ('episode-length-avg', np.mean(episode_lengths)), ('episode-length-min', np.min(episode_lengths)), ('episode-length-max', np.max(episode_lengths)), ('episode-length-std', np.std(episode_lengths)), ('creturn-average', np.mean(total_cost)), ('creturn-fullep-average', np.mean(total_cost) / np.mean(episode_lengths) * self.sampler.max_path_length), ('creturn-min', np.min(total_cost)), ('creturn-max', np.max(total_cost)), ('creturn-std', np.std(total_cost)), )) env_infos = env.get_path_infos(paths) for key, value in env_infos.items(): diagnostics[f'env_infos/{key}'] = value return diagnostics def _visualize_model(self, env, timestep): ## save env state state = env.unwrapped.state_vector() qpos_dim = len(env.unwrapped.sim.data.qpos) qpos = state[:qpos_dim] 
qvel = state[qpos_dim:] print('[ Visualization ] Starting | Epoch {} | Log dir: {}\n'.format( self._epoch, self._log_dir)) visualize_policy(env, self.fake_env, self._policy, self._writer, timestep) print('[ Visualization ] Done') ## set env state env.unwrapped.set_state(qpos, qvel) def get_diagnostics(self, iteration, obs_batch=None, training_paths=None, evaluation_paths=None): """Return diagnostic information as ordered dictionary. Records state value function, and TD-loss (mean squared Bellman error) for the sample batch. Also calls the `draw` method of the plotter, if plotter defined. """ # @anyboby warnings.warn('diagnostics not implemented yet!') # diagnostics = OrderedDict({ # }) # policy_diagnostics = self._policy.get_diagnostics( #use eval paths # obs_batch[:,-self.active_obs_dim:]) # diagnostics.update({ # f'policy/{key}': value # for key, value in policy_diagnostics.items() # }) # if self._plotter: # self._plotter.draw() diagnostics = {} return diagnostics def save(self, savedir): self.fake_env._model.save(savedir, self._epoch) @property def tf_saveables(self): saveables = {self._policy.tf_saveables} # saveables = { # '_policy_optimizer': self._policy_optimizer, # **{ # f'Q_optimizer_{i}': optimizer # for i, optimizer in enumerate(self._Q_optimizers) # }, # '_log_alpha': self._log_alpha, # } # if hasattr(self, '_alpha_optimizer'): # saveables['_alpha_optimizer'] = self._alpha_optimizer return saveables
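# --- Sketch (not part of the original source): the linear schedule shared by
# `_set_rollout_length` and `_set_model_train_freq` above. Given a schedule
# [min_epoch, max_epoch, min_value, max_value], the value ramps linearly with
# the epoch and is clipped at both ends. The helper name is hypothetical.
def linear_schedule(epoch, schedule):
    min_epoch, max_epoch, min_value, max_value = schedule
    if epoch <= min_epoch:
        return int(min_value)
    dx = min((epoch - min_epoch) / (max_epoch - min_epoch), 1)
    return int(dx * (max_value - min_value) + min_value)

# With the default rollout_schedule=[20, 100, 1, 1] the length stays at 1;
# a schedule like [20, 100, 1, 15] ramps from 1 (epoch <= 20) to 15 (epoch >= 100).
assert linear_schedule(10, [20, 100, 1, 15]) == 1
assert linear_schedule(60, [20, 100, 1, 15]) == 8
assert linear_schedule(150, [20, 100, 1, 15]) == 15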
def _train(self): """Return a generator that performs RL training. Args: env (`SoftlearningEnv`): Environment used for training. policy (`Policy`): Policy used for training initial_exploration_policy ('Policy'): Policy used for exploration If None, then all exploration is done using policy pool (`PoolBase`): Sample pool to add samples to """ #### pool is e.g. simple_replay_pool training_environment = self._training_environment evaluation_environment = self._evaluation_environment policy = self._policy pool = self._pool if not self._training_started: #### perform some initial steps (gather samples) using initial policy ###### fills pool with _n_initial_exploration_steps samples self._initial_exploration_hook(training_environment, self._policy, pool) #### set up sampler with train env and actual policy (may be different from initial exploration policy) ######## note: sampler is set up with the pool that may be already filled from initial exploration hook self.sampler.initialize(training_environment, policy, pool) self.model_sampler.initialize(self.fake_env, policy, self.model_pool) rollout_dkl_lim = self.model_sampler.compute_dynamics_dkl( obs_batch=self._pool.rand_batch_from_archive( 5000, fields=['observations'])['observations'], depth=self._rollout_schedule[2]) self.model_sampler.set_rollout_dkl(rollout_dkl_lim) self.initial_model_dkl = self.model_sampler.dyn_dkl #### reset gtimer (for coverage of project development) gt.reset_root() gt.rename_root('RLAlgorithm') gt.set_def_unique(False) self.policy_epoch = 0 ### count policy updates self.new_real_samples = 0 self.last_eval_step = 0 self.diag_counter = 0 running_diag = {} self.approx_model_batch = self.batch_size_policy - self.min_real_samples_per_epoch ### some size to start off #### not implemented, could train policy before hook self._training_before_hook() #### iterate over epochs, gt.timed_for to create loop with gt timestamps for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)): #### do something at beginning of epoch (in this case reset self._train_steps_this_epoch=0) self._epoch_before_hook() gt.stamp('epoch_before_hook') #### util class Progress, e.g. 
for plotting a progress bar ####### note: sampler may already contain samples in its pool from initial_exploration_hook or previous epochs self._training_progress = Progress(self._epoch_length * self._n_train_repeat / self._train_every_n_steps) samples_added = 0 #=====================================================================# # Rollout model # #=====================================================================# model_samples = None keep_rolling = True model_metrics = {} #### start model rollout if self._real_ratio < 1.0: #if self._timestep % self._model_train_freq == 0 and self._real_ratio < 1.0: #=====================================================================# # Model Rollouts # #=====================================================================# if self.rollout_mode == 'schedule': self._set_rollout_length() while keep_rolling: ep_b = self._pool.epoch_batch( batch_size=self._rollout_batch_size, epochs=self._pool.epochs_list, fields=['observations', 'pi_infos']) kls = np.clip(self._policy.compute_DKL( ep_b['observations'], ep_b['mu'], ep_b['log_std']), a_min=0, a_max=None) btz_dist = self._pool.boltz_dist(kls, alpha=self.policy_alpha) btz_b = self._pool.distributed_batch_from_archive( self._rollout_batch_size, btz_dist, fields=['observations', 'pi_infos']) start_states, mus, logstds = btz_b['observations'], btz_b[ 'mu'], btz_b['log_std'] btz_kl = np.clip(self._policy.compute_DKL( start_states, mus, logstds), a_min=0, a_max=None) self.model_sampler.reset(start_states) if self.rollout_mode == 'uncertainty': self.model_sampler.set_max_uncertainty( self.max_tddyn_err) for i in count(): # print(f'Model Sampling step Nr. {i+1}') _, _, _, info = self.model_sampler.sample( max_samples=int(self.approx_model_batch - samples_added)) if self.model_sampler._total_samples + samples_added >= .99 * self.approx_model_batch: keep_rolling = False break if info['alive_ratio'] <= 0.1: break ### diagnostics for rollout ### rollout_diagnostics = self.model_sampler.finish_all_paths() if self.rollout_mode == 'iv_gae': keep_rolling = self.model_pool.size + samples_added <= .99 * self.approx_model_batch ###################################################################### ### get model_samples, get() invokes the inverse variance rollouts ### model_samples_new, buffer_diagnostics_new = self.model_pool.get( ) model_samples = [ np.concatenate((o, n), axis=0) for o, n in zip(model_samples, model_samples_new) ] if model_samples else model_samples_new ###################################################################### ### diagnostics new_n_samples = len(model_samples_new[0]) + EPS diag_weight_old = samples_added / (new_n_samples + samples_added) diag_weight_new = new_n_samples / (new_n_samples + samples_added) model_metrics = update_dict(model_metrics, rollout_diagnostics, weight_a=diag_weight_old, weight_b=diag_weight_new) model_metrics = update_dict(model_metrics, buffer_diagnostics_new, weight_a=diag_weight_old, weight_b=diag_weight_new) ### run diagnostics on model data if buffer_diagnostics_new['poolm_batch_size'] > 0: model_data_diag = self._policy.run_diagnostics( model_samples_new) model_data_diag = { k + '_m': v for k, v in model_data_diag.items() } model_metrics = update_dict(model_metrics, model_data_diag, weight_a=diag_weight_old, weight_b=diag_weight_new) samples_added += new_n_samples model_metrics.update({'samples_added': samples_added}) ###################################################################### ## for debugging model_metrics.update({ 'cached_var': 
                    np.mean(self.fake_env._model.scaler_out.cached_var)
                })
                model_metrics.update({
                    'cached_mu':
                    np.mean(self.fake_env._model.scaler_out.cached_mu)
                })

                print('Rollouts finished')
                gt.stamp('epoch_rollout_model')

            #=====================================================================#
            #                               Sample                                #
            #=====================================================================#
            n_real_samples = self.model_sampler.dyn_dkl / self.initial_model_dkl * self.min_real_samples_per_epoch
            n_real_samples = max(n_real_samples, 1000)
            # n_real_samples = self.min_real_samples_per_epoch ### for ablation

            model_metrics.update({'n_real_samples': n_real_samples})
            start_samples = self.sampler._total_samples

            ### train for epoch_length ###
            for i in count():

                #### _timestep is within an epoch
                samples_now = self.sampler._total_samples
                self._timestep = samples_now - start_samples

                #### not implemented atm
                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                ##### sampling from the real world ! #####
                _, _, _, _ = self._do_sampling(timestep=self.policy_epoch)
                gt.stamp('sample')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

                if self.ready_to_train or self._timestep > n_real_samples:
                    self.sampler.finish_all_paths(append_val=True,
                                                  append_cval=True,
                                                  reset_path=False)
                    self.new_real_samples += self._timestep
                    break

            #=====================================================================#
            #                            Train model                              #
            #=====================================================================#
            if self.new_real_samples > 2048 and self._real_ratio < 1.0:
                model_diag = self.train_model(min_epochs=1, max_epochs=10)
                self.new_real_samples = 0
                model_metrics.update(model_diag)

            #=====================================================================#
            #                          Get Buffer Data                            #
            #=====================================================================#
            real_samples, buf_diag = self._pool.get()

            ### run diagnostics on real data
            policy_diag = self._policy.run_diagnostics(real_samples)
            policy_diag = {k + '_r': v for k, v in policy_diag.items()}
            model_metrics.update(policy_diag)
            model_metrics.update(buf_diag)

            #=====================================================================#
            #                           Update Policy                             #
            #=====================================================================#
            train_samples = [
                np.concatenate((r, m), axis=0)
                for r, m in zip(real_samples, model_samples)
            ] if model_samples else real_samples

            self._policy.update_real_c(real_samples)
            self._policy.update_policy(train_samples)
            self._policy.update_critic(
                train_samples,
                train_vc=(train_samples[-3] > 0).any())  ### only train vc if there are any costs

            if self._real_ratio < 1.0:
                self.approx_model_batch = self.batch_size_policy - n_real_samples  #self.model_sampler.dyn_dkl/self.initial_model_dkl * self.min_real_samples_per_epoch

            self.policy_epoch += 1
            self.max_tddyn_err *= self.max_tddyn_err_decay
            #### log policy diagnostics
            self._policy.log()

            gt.stamp('train')
            #=====================================================================#
            #                      Log performance and stats                      #
            #=====================================================================#

            self.sampler.log()
            # write results to file, ray prints for us, so no need to print from logger
            logger_diagnostics = self.logger.dump_tabular(
                output_dir=self._log_dir, print_out=False)
            #=====================================================================#

            if self._total_timestep // self.eval_every_n_steps > self.last_eval_step:
                evaluation_paths = self._evaluation_paths(
                    policy, evaluation_environment)
                gt.stamp('evaluation_paths')

                self.last_eval_step = self._total_timestep // self.eval_every_n_steps
            else:
                evaluation_paths = []

            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
                gt.stamp('evaluation_metrics')
                diag_obs_batch = np.concatenate(([
                    evaluation_paths[i]['observations']
                    for i in range(len(evaluation_paths))
                ]), axis=0)
            else:
                evaluation_metrics = {}
                diag_obs_batch = []

            gt.stamp('epoch_after_hook')

            new_diagnostics = {}
            time_diagnostics = gt.get_times().stamps.itrs

            # add diagnostics from logger
            new_diagnostics.update(logger_diagnostics)
            new_diagnostics.update(
                OrderedDict((
                    *((f'evaluation/{key}', evaluation_metrics[key])
                      for key in sorted(evaluation_metrics.keys())),
                    *((f'times/{key}', time_diagnostics[key][-1])
                      for key in sorted(time_diagnostics.keys())),
                    *((f'model/{key}', model_metrics[key])
                      for key in sorted(model_metrics.keys())),
                )))

            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                training_environment.render_rollouts(evaluation_paths)

            #### updating and averaging
            old_ts_diag = running_diag.get('timestep', 0)
            new_ts_diag = self._total_timestep - self.diag_counter - old_ts_diag
            w_olddiag = old_ts_diag / (new_ts_diag + old_ts_diag)
            w_newdiag = new_ts_diag / (new_ts_diag + old_ts_diag)
            running_diag = update_dict(running_diag,
                                       new_diagnostics,
                                       weight_a=w_olddiag,
                                       weight_b=w_newdiag)
            running_diag.update({'timestep': new_ts_diag + old_ts_diag})
            ####

            if new_ts_diag + old_ts_diag > self.eval_every_n_steps:
                running_diag.update({
                    'epoch': self._epoch,
                    'timesteps_total': self._total_timestep,
                    'train-steps': self._num_train_steps,
                })
                self.diag_counter = self._total_timestep
                diag = running_diag.copy()
                running_diag = {}
                yield diag

            if self._total_timestep >= self.n_env_interacts:
                self.sampler.terminate()
                self._training_after_hook()
                self._training_progress.close()
                print("###### DONE ######")
                yield {'done': True, **running_diag}
                break
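
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the class above): the training loop above
# splits each policy-update batch between real and model samples.
# `n_real_samples` is scaled by how far the model's dynamics KL has drifted
# from its value at the start of training, floored at 1000, and the remainder
# of `batch_size_policy` is budgeted for model rollouts.  The helper below is
# a hypothetical restatement of that arithmetic with plain arguments; the
# actual values live as attributes on the algorithm instance.
def _example_sample_budget(dyn_dkl,
                           initial_dyn_dkl,
                           min_real_samples_per_epoch,
                           batch_size_policy,
                           min_real_floor=1000):
    """Return (n_real_samples, approx_model_batch) as computed in _train."""
    # More model drift (higher current dynamics KL) => demand more real samples.
    n_real_samples = dyn_dkl / initial_dyn_dkl * min_real_samples_per_epoch
    n_real_samples = max(n_real_samples, min_real_floor)
    # Whatever is left of the policy batch is filled with model rollouts.
    approx_model_batch = batch_size_policy - n_real_samples
    return n_real_samples, approx_model_batch


# Example with made-up numbers: batch_size_policy=50000,
# min_real_samples_per_epoch=500 and a model KL twice its initial value give
# max(2 * 500, 1000) = 1000 real samples and a model budget of 49000.
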
class MBPO(RLAlgorithm): """Model-Based Policy Optimization (MBPO) References ---------- Michael Janner, Justin Fu, Marvin Zhang, Sergey Levine. When to Trust Your Model: Model-Based Policy Optimization. arXiv preprint arXiv:1906.08253. 2019. """ def __init__( self, training_environment, evaluation_environment, policy, Qs, pool, static_fns, plotter=None, tf_summaries=False, lr=3e-4, reward_scale=1.0, target_entropy='auto', discount=0.99, tau=5e-3, target_update_interval=1, action_prior='uniform', reparameterize=False, store_extra_policy_info=False, deterministic=False, model_train_freq=250, model_train_slower=1, num_networks=7, num_elites=5, num_Q_elites=2, # The num of Q ensemble is set in command line model_retain_epochs=20, rollout_batch_size=100e3, real_ratio=0.1, critic_same_as_actor=True, rollout_schedule=[20,100,1,1], hidden_dim=200, max_model_t=None, dir_name=None, evaluate_explore_freq=0, num_Q_per_grp=2, num_Q_grp=1, cross_grp_diff_batch=False, model_load_dir=None, model_load_index=None, model_log_freq=0, **kwargs, ): """ Args: env (`SoftlearningEnv`): Environment used for training. policy: A policy function approximator. initial_exploration_policy: ('Policy'): A policy that we use for initial exploration which is not trained by the algorithm. Qs: Q-function approximators. The min of these approximators will be used. Usage of at least two Q-functions improves performance by reducing overestimation bias. pool (`PoolBase`): Replay pool to add gathered samples to. plotter (`QFPolicyPlotter`): Plotter instance to be used for visualizing Q-function during training. lr (`float`): Learning rate used for the function approximators. discount (`float`): Discount factor for Q-function updates. tau (`float`): Soft value function target update weight. target_update_interval ('int'): Frequency at which target network updates occur in iterations. reparameterize ('bool'): If True, we use a gradient estimator for the policy derived using the reparameterization trick. We use a likelihood ratio based estimator otherwise. critic_same_as_actor ('bool'): If True, use the same sampling schema (model free or model based) as the actor in critic training. Otherwise, use model free sampling to train critic. 
""" super(MBPO, self).__init__(**kwargs) if training_environment.unwrapped.spec.id.find("Fetch") != -1: # Fetch env obs_dim = sum([i.shape[0] for i in training_environment.observation_space.spaces.values()]) self.multigoal = 1 else: obs_dim = np.prod(training_environment.observation_space.shape) # print("====", obs_dim, "========") act_dim = np.prod(training_environment.action_space.shape) # TODO: add variable scope to directly extract model parameters self._model_load_dir = model_load_dir print("============Model dir: ", self._model_load_dir) if model_load_index: latest_model_index = model_load_index else: latest_model_index = self._get_latest_index() self._model = construct_model(obs_dim=obs_dim, act_dim=act_dim, hidden_dim=hidden_dim, num_networks=num_networks, num_elites=num_elites, model_dir=self._model_load_dir, model_load_timestep=latest_model_index, load_model=True if model_load_dir else False) self._static_fns = static_fns self.fake_env = FakeEnv(self._model, self._static_fns) model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self._model.name) all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) self._rollout_schedule = rollout_schedule self._max_model_t = max_model_t # self._model_pool_size = model_pool_size # print('[ MBPO ] Model pool size: {:.2E}'.format(self._model_pool_size)) # self._model_pool = SimpleReplayPool(pool._observation_space, pool._action_space, self._model_pool_size) self._model_retain_epochs = model_retain_epochs self._model_train_freq = model_train_freq self._rollout_batch_size = int(rollout_batch_size) self._deterministic = deterministic self._real_ratio = real_ratio self._log_dir = os.getcwd() self._writer = Writer(self._log_dir) self._training_environment = training_environment self._evaluation_environment = evaluation_environment self._policy = policy self._Qs = Qs self._Q_ensemble = len(Qs) self._Q_elites = num_Q_elites self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs) self._pool = pool self._plotter = plotter self._tf_summaries = tf_summaries self._policy_lr = lr self._Q_lr = lr self._reward_scale = reward_scale self._target_entropy = ( -np.prod(self._training_environment.action_space.shape) if target_entropy == 'auto' else target_entropy) print('[ MBPO ] Target entropy: {}'.format(self._target_entropy)) self._discount = discount self._tau = tau self._target_update_interval = target_update_interval self._action_prior = action_prior self._reparameterize = reparameterize self._store_extra_policy_info = store_extra_policy_info observation_shape = self._training_environment.active_observation_shape action_shape = self._training_environment.action_space.shape assert len(observation_shape) == 1, observation_shape self._observation_shape = observation_shape assert len(action_shape) == 1, action_shape self._action_shape = action_shape # self._critic_train_repeat = kwargs["critic_train_repeat"] # actor UTD should be n times larger or smaller than critic UTD assert self._actor_train_repeat % self._critic_train_repeat == 0 or \ self._critic_train_repeat % self._actor_train_repeat == 0 self._critic_train_freq = self._n_train_repeat // self._critic_train_repeat self._actor_train_freq = self._n_train_repeat // self._actor_train_repeat self._critic_same_as_actor = critic_same_as_actor self._model_train_slower = model_train_slower self._origin_model_train_epochs = 0 self._dir_name = dir_name self._evaluate_explore_freq = evaluate_explore_freq # Inter-group Qs are trained with the same data; Cross-group Qs different. 
self._num_Q_per_grp = num_Q_per_grp self._num_Q_grp = num_Q_grp self._cross_grp_diff_batch = cross_grp_diff_batch self._model_log_freq = model_log_freq self._build() def _build(self): self._training_ops = {} self._actor_training_ops = {} self._critic_training_ops = {} if not self._cross_grp_diff_batch else \ [{} for _ in range(self._num_Q_grp)] self._misc_training_ops = {} # basically no feeddict is needed # device = "/device:GPU:1" # with tf.device(device): # self._init_global_step() # self._init_placeholders() # self._init_actor_update() # self._init_critic_update() self._init_global_step() self._init_placeholders() self._init_actor_update() self._init_critic_update() def _train(self): """Return a generator that performs RL training. Args: env (`SoftlearningEnv`): Environment used for training. policy (`Policy`): Policy used for training initial_exploration_policy ('Policy'): Policy used for exploration If None, then all exploration is done using policy pool (`PoolBase`): Sample pool to add samples to """ training_environment = self._training_environment evaluation_environment = self._evaluation_environment policy = self._policy pool = self._pool model_metrics = {} if not self._training_started: self._init_training() self._initial_exploration_hook( training_environment, self._initial_exploration_policy, pool) self.sampler.initialize(training_environment, policy, pool) gt.reset_root() gt.rename_root('RLAlgorithm') gt.set_def_unique(False) self._training_before_hook() for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)): self._epoch_before_hook() gt.stamp('epoch_before_hook') if self._evaluate_explore_freq != 0 and self._epoch % self._evaluate_explore_freq == 0: self._evaluate_exploration() self._training_progress = Progress(self._epoch_length * self._n_train_repeat) start_samples = self.sampler._total_samples for i in count(): samples_now = self.sampler._total_samples self._timestep = samples_now - start_samples if (samples_now >= start_samples + self._epoch_length and self.ready_to_train): break self._timestep_before_hook() gt.stamp('timestep_before_hook') if self._timestep % self._model_train_freq == 0 and self._real_ratio < 1.0: self._training_progress.pause() print('[ MBPO ] log_dir: {} | ratio: {}'.format(self._log_dir, self._real_ratio)) print('[ MBPO ] Training model at epoch {} | freq {} | timestep {} (total: {}) | epoch train steps: {} (total: {}) | times slower: {}'.format( self._epoch, self._model_train_freq, self._timestep, self._total_timestep, self._train_steps_this_epoch, self._num_train_steps, self._model_train_slower) ) if self._origin_model_train_epochs % self._model_train_slower == 0: model_train_metrics = self._train_model(batch_size=256, max_epochs=None, holdout_ratio=0.2, max_t=self._max_model_t) model_metrics.update(model_train_metrics) gt.stamp('epoch_train_model') else: print('[ MBPO ] Skipping model training due to slowed training setting') self._origin_model_train_epochs += 1 self._set_rollout_length() self._reallocate_model_pool() model_rollout_metrics = self._rollout_model(rollout_batch_size=self._rollout_batch_size, deterministic=self._deterministic) model_metrics.update(model_rollout_metrics) if self._model_log_freq != 0 and self._timestep % self._model_log_freq == 0: self._log_model() gt.stamp('epoch_rollout_model') # self._visualize_model(self._evaluation_environment, self._total_timestep) self._training_progress.resume() self._do_sampling(timestep=self._total_timestep) gt.stamp('sample') if self.ready_to_train: 
self._do_training_repeats(timestep=self._total_timestep) gt.stamp('train') self._timestep_after_hook() gt.stamp('timestep_after_hook') training_paths = self.sampler.get_last_n_paths( math.ceil(self._epoch_length / self.sampler._max_path_length)) gt.stamp('training_paths') evaluation_paths = self._evaluation_paths( policy, evaluation_environment) gt.stamp('evaluation_paths') training_metrics = self._evaluate_rollouts( training_paths, training_environment) gt.stamp('training_metrics') if evaluation_paths: evaluation_metrics = self._evaluate_rollouts( evaluation_paths, evaluation_environment) gt.stamp('evaluation_metrics') else: evaluation_metrics = {} self._epoch_after_hook(training_paths) gt.stamp('epoch_after_hook') sampler_diagnostics = self.sampler.get_diagnostics() diagnostics = self.get_diagnostics( iteration=self._total_timestep, batch=self._evaluation_batch(), training_paths=training_paths, evaluation_paths=evaluation_paths) time_diagnostics = gt.get_times().stamps.itrs diagnostics.update(OrderedDict(( *( (f'evaluation/{key}', evaluation_metrics[key]) for key in sorted(evaluation_metrics.keys()) ), *( (f'training/{key}', training_metrics[key]) for key in sorted(training_metrics.keys()) ), *( (f'times/{key}', time_diagnostics[key][-1]) for key in sorted(time_diagnostics.keys()) ), *( (f'sampler/{key}', sampler_diagnostics[key]) for key in sorted(sampler_diagnostics.keys()) ), *( (f'model/{key}', model_metrics[key]) for key in sorted(model_metrics.keys()) ), ('epoch', self._epoch), ('timestep', self._timestep), ('timesteps_total', self._total_timestep), ('train-steps', self._num_train_steps), ))) if self._eval_render_mode is not None and hasattr( evaluation_environment, 'render_rollouts'): training_environment.render_rollouts(evaluation_paths) yield diagnostics self.sampler.terminate() self._training_after_hook() self._training_progress.close() yield {'done': True, **diagnostics} def _evaluate_exploration(self): print("=============evaluate exploration=========") # specify data dir base_dir = "/home/linus/Research/mbpo/mbpo_experiment/exploration_eval" if not self._dir_name: return data_dir = os.path.join(base_dir, self._dir_name) if not os.path.isdir(data_dir): os.mkdir(data_dir) # specify data name exp_name = "%d.pkl"%self._epoch path = os.path.join(data_dir, exp_name) evaluation_size = 3000 action_repeat = 20 batch = self.sampler.random_batch(evaluation_size) obs = batch['observations'] actions_repeat = [self._policy.actions_np(obs) for _ in range(action_repeat)] Qs = [] policy_std = [] for action in actions_repeat: Q = [] for (s,a) in zip(obs, action): s, a = np.array(s).reshape(1, -1), np.array(a).reshape(1, -1) Q.append( self._session.run( self._Q_values, feed_dict = { self._observations_ph: s, self._actions_ph: a } ) ) Qs.append(Q) Qs = np.array(Qs).squeeze() Qs_mean_action = np.mean(Qs, axis = 0) # Compute mean across different actions of one given state. 
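        # Shape bookkeeping for the block below (assuming every Q head returns
        # a scalar per state-action pair): `Qs` stacks to
        # (action_repeat, evaluation_size, ensemble_size), so `Qs_mean_action`
        # has shape (evaluation_size, ensemble_size) -- an approximate
        # per-state value estimate for each ensemble member.  The standard
        # deviation taken across the ensemble axis below then serves as a
        # rough epistemic-uncertainty signal for that state.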
if self._cross_grp_diff_batch: inter_grp_q_stds = [np.std(Qs_mean_action[:, i * self._num_Q_per_grp:(i+1) * self._num_Q_per_grp], axis = 1) for i in range(self._num_Q_grp)] mean_inter_grp_q_std = np.mean(np.array(inter_grp_q_stds), axis = 0) min_qs_per_grp = [np.mean(Qs_mean_action[:, i * self._num_Q_per_grp:(i+1) * self._num_Q_per_grp], axis = 1) for i in range(self._num_Q_grp)] cross_grp_std = np.std(np.array(min_qs_per_grp), axis = 0) else: q_std = np.std(Qs_mean_action, axis=1) # In fact V std policy_std = [np.prod(np.exp(self._policy.policy_log_scale_model.predict(np.array(s).reshape(1,-1)))) for s in obs] if self._cross_grp_diff_batch: data = { 'obs': obs, 'inter_q_std': mean_inter_grp_q_std, 'cross_q_std': cross_grp_std, 'pi_std': policy_std } else: data = { 'obs': obs, 'q_std': q_std, 'pi_std': policy_std } with open(path, 'wb') as f: pickle.dump(data, f) print("==========================================") def train(self, *args, **kwargs): return self._train(*args, **kwargs) def _log_policy(self): save_path = os.path.join(self._log_dir, 'models') filesystem.mkdir(save_path) weights = self._policy.get_weights() data = {'policy_weights': weights} full_path = os.path.join(save_path, 'policy_{}.pkl'.format(self._total_timestep)) print('Saving policy to: {}'.format(full_path)) pickle.dump(data, open(full_path, 'wb')) # TODO: use this function to save model def _log_model(self): save_path = os.path.join(self._log_dir, 'models') filesystem.mkdir(save_path) print('Saving model to: {}'.format(save_path)) self._model.save(save_path, self._total_timestep) def _set_rollout_length(self): min_epoch, max_epoch, min_length, max_length = self._rollout_schedule if self._epoch <= min_epoch: y = min_length else: dx = (self._epoch - min_epoch) / (max_epoch - min_epoch) dx = min(dx, 1) y = dx * (max_length - min_length) + min_length self._rollout_length = int(y) print('[ Model Length ] Epoch: {} (min: {}, max: {}) | Length: {} (min: {} , max: {})'.format( self._epoch, min_epoch, max_epoch, self._rollout_length, min_length, max_length )) def _reallocate_model_pool(self): obs_space = self._pool._observation_space act_space = self._pool._action_space rollouts_per_epoch = self._rollout_batch_size * self._epoch_length / self._model_train_freq model_steps_per_epoch = int(self._rollout_length * rollouts_per_epoch) new_pool_size = self._model_retain_epochs * model_steps_per_epoch if not hasattr(self, '_model_pool'): print('[ MBPO ] Initializing new model pool with size {:.2e}'.format( new_pool_size )) self._model_pool = SimpleReplayPool(obs_space, act_space, new_pool_size) elif self._model_pool._max_size != new_pool_size: print('[ MBPO ] Updating model pool | {:.2e} --> {:.2e}'.format( self._model_pool._max_size, new_pool_size )) samples = self._model_pool.return_all_samples() new_pool = SimpleReplayPool(obs_space, act_space, new_pool_size) new_pool.add_samples(samples) assert self._model_pool.size == new_pool.size self._model_pool = new_pool def _train_model(self, **kwargs): env_samples = self._pool.return_all_samples() # train_inputs, train_outputs = format_samples_for_training(env_samples, self.multigoal) train_inputs, train_outputs = format_samples_for_training(env_samples) model_metrics = self._model.train(train_inputs, train_outputs, **kwargs) return model_metrics def _rollout_model(self, rollout_batch_size, **kwargs): print('[ Model Rollout ] Starting | Epoch: {} | Rollout length: {} | Batch size: {}'.format( self._epoch, self._rollout_length, rollout_batch_size )) # Keep total rollout sample complexity 
unchanged batch = self.sampler.random_batch(rollout_batch_size // self._sample_repeat) obs = batch['observations'] steps_added = [] sampled_actions = [] for _ in range(self._sample_repeat): for i in range(self._rollout_length): # TODO: alter policy distribution in different times of sample repeating # self._policy: softlearning.policies.gaussian_policy.FeedforwardGaussianPolicy # self._policy._deterministic: False # print("=====================================") # print(self._policy._deterministic) # print("=====================================") act = self._policy.actions_np(obs) sampled_actions.append(act) next_obs, rew, term, info = self.fake_env.step(obs, act, **kwargs) steps_added.append(len(obs)) samples = {'observations': obs, 'actions': act, 'next_observations': next_obs, 'rewards': rew, 'terminals': term} self._model_pool.add_samples(samples) nonterm_mask = ~term.squeeze(-1) if nonterm_mask.sum() == 0: print('[ Model Rollout ] Breaking early: {} | {} / {}'.format(i, nonterm_mask.sum(), nonterm_mask.shape)) break obs = next_obs[nonterm_mask] # print(sampled_actions) mean_rollout_length = sum(steps_added) / rollout_batch_size rollout_stats = {'mean_rollout_length': mean_rollout_length} print('[ Model Rollout ] Added: {:.1e} | Model pool: {:.1e} (max {:.1e}) | Length: {} | Train rep: {}'.format( sum(steps_added), self._model_pool.size, self._model_pool._max_size, mean_rollout_length, self._n_train_repeat )) return rollout_stats def _visualize_model(self, env, timestep): ## save env state state = env.unwrapped.state_vector() qpos_dim = len(env.unwrapped.sim.data.qpos) qpos = state[:qpos_dim] qvel = state[qpos_dim:] print('[ Visualization ] Starting | Epoch {} | Log dir: {}\n'.format(self._epoch, self._log_dir)) visualize_policy(env, self.fake_env, self._policy, self._writer, timestep) print('[ Visualization ] Done') ## set env state env.unwrapped.set_state(qpos, qvel) def _training_batch(self, batch_size=None): batch_size = batch_size or self.sampler._batch_size env_batch_size = int(batch_size*self._real_ratio) model_batch_size = batch_size - env_batch_size ## can sample from the env pool even if env_batch_size == 0 if self._cross_grp_diff_batch: env_batch = [self._pool.random_batch(env_batch_size) for _ in range(self._num_Q_grp)] else: env_batch = self._pool.random_batch(env_batch_size) if model_batch_size > 0: model_batch = self._model_pool.random_batch(model_batch_size) keys = env_batch.keys() batch = {k: np.concatenate((env_batch[k], model_batch[k]), axis=0) for k in keys} else: ## if real_ratio == 1.0, no model pool was ever allocated, ## so skip the model pool sampling batch = env_batch return batch, env_batch def _init_global_step(self): self.global_step = training_util.get_or_create_global_step() self._training_ops.update({ 'increment_global_step': training_util._increment_global_step(1) }) self._misc_training_ops.update({ 'increment_global_step': training_util._increment_global_step(1) }) def _init_placeholders(self): """Create input placeholders for the SAC algorithm. 
Creates `tf.placeholder`s for: - observation - next observation - action - reward - terminals """ self._iteration_ph = tf.placeholder( tf.int64, shape=None, name='iteration') self._observations_ph = tf.placeholder( tf.float32, shape=(None, *self._observation_shape), name='observation', ) self._next_observations_ph = tf.placeholder( tf.float32, shape=(None, *self._observation_shape), name='next_observation', ) self._actions_ph = tf.placeholder( tf.float32, shape=(None, *self._action_shape), name='actions', ) self._rewards_ph = tf.placeholder( tf.float32, shape=(None, 1), name='rewards', ) self._terminals_ph = tf.placeholder( tf.float32, shape=(None, 1), name='terminals', ) if self._store_extra_policy_info: self._log_pis_ph = tf.placeholder( tf.float32, shape=(None, 1), name='log_pis', ) self._raw_actions_ph = tf.placeholder( tf.float32, shape=(None, *self._action_shape), name='raw_actions', ) def _get_Q_target(self): next_actions = self._policy.actions([self._next_observations_ph]) next_log_pis = self._policy.log_pis( [self._next_observations_ph], next_actions) next_Qs_values = tuple( Q([self._next_observations_ph, next_actions]) for Q in self._Q_targets) Qs_subset = np.random.choice(next_Qs_values, self._Q_elites, replace=False).tolist() # Line 8 of REDQ: min over M random indices min_next_Q = tf.reduce_min(Qs_subset, axis=0) next_value = min_next_Q - self._alpha * next_log_pis Q_target = td_target( reward=self._reward_scale * self._rewards_ph, discount=self._discount, next_value=(1 - self._terminals_ph) * next_value) return Q_target def _init_critic_update(self): """Create minimization operation for critic Q-function. Creates a `tf.optimizer.minimize` operation for updating critic Q-function with gradient descent, and appends it to `self._training_ops` attribute. """ Q_target = tf.stop_gradient(self._get_Q_target()) assert Q_target.shape.as_list() == [None, 1] Q_values = self._Q_values = tuple( Q([self._observations_ph, self._actions_ph]) for Q in self._Qs) Q_losses = self._Q_losses = tuple( tf.losses.mean_squared_error( labels=Q_target, predictions=Q_value, weights=0.5) for Q_value in Q_values) self._Q_optimizers = tuple( tf.train.AdamOptimizer( learning_rate=self._Q_lr, name='{}_{}_optimizer'.format(Q._name, i) ) for i, Q in enumerate(self._Qs)) # TODO: divide it to N separate ops, where N is # of Q grps Q_training_ops = tuple( tf.contrib.layers.optimize_loss( Q_loss, self.global_step, learning_rate=self._Q_lr, optimizer=Q_optimizer, variables=Q.trainable_variables, increment_global_step=False, summaries=(( "loss", "gradients", "gradient_norm", "global_gradient_norm" ) if self._tf_summaries else ())) for i, (Q, Q_loss, Q_optimizer) in enumerate(zip(self._Qs, Q_losses, self._Q_optimizers))) self._training_ops.update({'Q': tf.group(Q_training_ops)}) if self._cross_grp_diff_batch: assert len(Q_training_ops) >= self._num_Q_grp * self._num_Q_per_grp for i in range(self._num_Q_grp - 1): self._critic_training_ops[i].update({ 'Q': tf.group(Q_training_ops[i * self._num_Q_grp: (i+1) * self._num_Q_grp]) }) self._critic_training_ops[self._num_Q_grp - 1].update({ 'Q': tf.group(Q_training_ops[(self._num_Q_grp - 1) * self._num_Q_grp:]) }) else: self._critic_training_ops.update({'Q': tf.group(Q_training_ops)}) def _init_actor_update(self): """Create minimization operations for policy and entropy. Creates a `tf.optimizer.minimize` operations for updating policy and entropy with gradient descent, and adds them to `self._training_ops` attribute. 
""" actions = self._policy.actions([self._observations_ph]) log_pis = self._policy.log_pis([self._observations_ph], actions) assert log_pis.shape.as_list() == [None, 1] log_alpha = self._log_alpha = tf.get_variable( 'log_alpha', dtype=tf.float32, initializer=0.0) alpha = tf.exp(log_alpha) if isinstance(self._target_entropy, Number): alpha_loss = -tf.reduce_mean( log_alpha * tf.stop_gradient(log_pis + self._target_entropy)) self._alpha_optimizer = tf.train.AdamOptimizer( self._policy_lr, name='alpha_optimizer') self._alpha_train_op = self._alpha_optimizer.minimize( loss=alpha_loss, var_list=[log_alpha]) self._training_ops.update({ 'temperature_alpha': self._alpha_train_op }) self._actor_training_ops.update({ 'temperature_alpha': self._alpha_train_op }) self._alpha = alpha if self._action_prior == 'normal': policy_prior = tf.contrib.distributions.MultivariateNormalDiag( loc=tf.zeros(self._action_shape), scale_diag=tf.ones(self._action_shape)) policy_prior_log_probs = policy_prior.log_prob(actions) elif self._action_prior == 'uniform': policy_prior_log_probs = 0.0 Q_log_targets = tuple( Q([self._observations_ph, actions]) for Q in self._Qs) assert len(Q_log_targets) == self._Q_ensemble min_Q_log_target = tf.reduce_min(Q_log_targets, axis=0) mean_Q_log_target = tf.reduce_mean(Q_log_targets, axis=0) Q_target = min_Q_log_target if self._Q_ensemble == 2 else mean_Q_log_target if self._reparameterize: policy_kl_losses = ( alpha * log_pis - Q_target - policy_prior_log_probs) else: raise NotImplementedError assert policy_kl_losses.shape.as_list() == [None, 1] policy_loss = tf.reduce_mean(policy_kl_losses) self._policy_optimizer = tf.train.AdamOptimizer( learning_rate=self._policy_lr, name="policy_optimizer") policy_train_op = tf.contrib.layers.optimize_loss( policy_loss, self.global_step, learning_rate=self._policy_lr, optimizer=self._policy_optimizer, variables=self._policy.trainable_variables, increment_global_step=False, summaries=( "loss", "gradients", "gradient_norm", "global_gradient_norm" ) if self._tf_summaries else ()) self._training_ops.update({'policy_train_op': policy_train_op}) self._actor_training_ops.update({'policy_train_op': policy_train_op}) def _init_training(self): self._update_target(tau=1.0) def _update_target(self, tau=None): tau = tau or self._tau for Q, Q_target in zip(self._Qs, self._Q_targets): source_params = Q.get_weights() target_params = Q_target.get_weights() Q_target.set_weights([ tau * source + (1.0 - tau) * target for source, target in zip(source_params, target_params) ]) def _do_training(self, iteration, batch): """Runs the operations for updating training and target ops.""" mix_batch, mf_batch = batch self._training_progress.update() self._training_progress.set_description() if self._cross_grp_diff_batch: assert len(mix_batch) == self._num_Q_grp if self._real_ratio != 1: assert 0, "Currently different batch is not supported in MBPO" mix_feed_dict = [self._get_feed_dict(iteration, i) for i in mix_batch] single_mix_feed_dict = mix_feed_dict[0] else: mix_feed_dict = self._get_feed_dict(iteration, mix_batch) single_mix_feed_dict = mix_feed_dict if self._critic_same_as_actor: critic_feed_dict = mix_feed_dict else: critic_feed_dict = self._get_feed_dict(iteration, mf_batch) self._session.run(self._misc_training_ops, single_mix_feed_dict) if iteration % self._actor_train_freq == 0: self._session.run(self._actor_training_ops, single_mix_feed_dict) if iteration % self._critic_train_freq == 0: if self._cross_grp_diff_batch: assert len(self._critic_training_ops) == 
len(critic_feed_dict) [ self._session.run(op, feed_dict) for (op, feed_dict) in zip(self._critic_training_ops, critic_feed_dict) ] else: self._session.run(self._critic_training_ops, critic_feed_dict) if iteration % self._target_update_interval == 0: # Run target ops here. self._update_target() def _get_feed_dict(self, iteration, batch): """Construct TensorFlow feed_dict from sample batch.""" feed_dict = { self._observations_ph: batch['observations'], self._actions_ph: batch['actions'], self._next_observations_ph: batch['next_observations'], self._rewards_ph: batch['rewards'], self._terminals_ph: batch['terminals'], } if self._store_extra_policy_info: feed_dict[self._log_pis_ph] = batch['log_pis'] feed_dict[self._raw_actions_ph] = batch['raw_actions'] if iteration is not None: feed_dict[self._iteration_ph] = iteration return feed_dict def get_diagnostics(self, iteration, batch, training_paths, evaluation_paths): """Return diagnostic information as ordered dictionary. Records mean and standard deviation of Q-function and state value function, and TD-loss (mean squared Bellman error) for the sample batch. Also calls the `draw` method of the plotter, if plotter defined. """ mix_batch, _ = batch if self._cross_grp_diff_batch: mix_batch = mix_batch[0] mix_feed_dict = self._get_feed_dict(iteration, mix_batch) # (Q_values, Q_losses, alpha, global_step) = self._session.run( # (self._Q_values, # self._Q_losses, # self._alpha, # self.global_step), # feed_dict) Q_values, Q_losses = self._session.run( [self._Q_values, self._Q_losses], mix_feed_dict ) alpha, global_step = self._session.run( [self._alpha, self.global_step], mix_feed_dict ) diagnostics = OrderedDict({ 'Q-avg': np.mean(Q_values), 'Q-std': np.std(Q_values), 'Q_loss': np.mean(Q_losses), 'alpha': alpha, }) policy_diagnostics = self._policy.get_diagnostics( mix_batch['observations']) diagnostics.update({ f'policy/{key}': value for key, value in policy_diagnostics.items() }) if self._plotter: self._plotter.draw() return diagnostics @property def tf_saveables(self): saveables = { '_policy_optimizer': self._policy_optimizer, **{ f'Q_optimizer_{i}': optimizer for i, optimizer in enumerate(self._Q_optimizers) }, '_log_alpha': self._log_alpha, } if hasattr(self, '_alpha_optimizer'): saveables['_alpha_optimizer'] = self._alpha_optimizer return saveables def _get_latest_index(self): if self._model_load_dir is None: return return max([int(i.split("_")[1].split(".")[0]) for i in os.listdir(self._model_load_dir)])
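
    # -----------------------------------------------------------------------
    # Note on the critic target used by this class (hedged summary of
    # `_get_Q_target` and `_init_actor_update` above, assuming `td_target`
    # computes reward + discount * next_value):
    #
    #   Q_target(s, a) = reward_scale * r
    #                    + discount * (1 - done)
    #                      * ( min_{i in S} Q_target_i(s', a') - alpha * log pi(a'|s') )
    #
    # where a' ~ pi(.|s') and S is a subset of `num_Q_elites` indices from the
    # ensemble of `len(Qs)` target critics (the REDQ-style "min over M random
    # members" rule); note the subset is drawn with `np.random.choice` when
    # the graph is built.  The actor, in contrast, uses the plain minimum over
    # the ensemble when exactly two critics are used and the ensemble mean
    # otherwise.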
def _train(self): """Return a generator that performs RL training. Args: env (`SoftlearningEnv`): Environment used for training. policy (`Policy`): Policy used for training initial_exploration_policy ('Policy'): Policy used for exploration If None, then all exploration is done using policy pool (`PoolBase`): Sample pool to add samples to """ #### pool is e.g. simple_replay_pool training_environment = self._training_environment evaluation_environment = self._evaluation_environment policy = self._policy pool = self._pool model_metrics = {} #### init Qs for SAC if not self._training_started: self._init_training() #### perform some initial steps (gather samples) using initial policy ###### fills pool with _n_initial_exploration_steps samples self._initial_exploration_hook(training_environment, self._initial_exploration_policy, pool) #### set up sampler with train env and actual policy (may be different from initial exploration policy) ######## note: sampler is set up with the pool that may be already filled from initial exploration hook self.sampler.initialize(training_environment, policy, pool) #### reset gtimer (for coverage of project development) gt.reset_root() gt.rename_root('RLAlgorithm') gt.set_def_unique(False) #### not implemented, could train policy before hook self._training_before_hook() #### iterate over epochs, gt.timed_for to create loop with gt timestamps for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)): #### do something at beginning of epoch (in this case reset self._train_steps_this_epoch=0) self._epoch_before_hook() gt.stamp('epoch_before_hook') #### util class Progress, e.g. for plotting a progress bar ####### note: sampler may already contain samples in its pool from initial_exploration_hook or previous epochs self._training_progress = Progress(self._epoch_length * self._n_train_repeat / self._train_every_n_steps) start_samples = self.sampler._total_samples ### train for epoch_length ### for i in count(): #### _timestep is within an epoch samples_now = self.sampler._total_samples self._timestep = samples_now - start_samples #### check if you're at the end of an epoch to train if (samples_now >= start_samples + self._epoch_length and self.ready_to_train): break #### not implemented atm self._timestep_before_hook() gt.stamp('timestep_before_hook') #### start model rollout if self._timestep % self._model_train_freq == 0 and self._real_ratio < 1.0: self._training_progress.pause() print('[ MBPO ] log_dir: {} | ratio: {}'.format( self._log_dir, self._real_ratio)) print( '[ MBPO ] Training model at epoch {} | freq {} | timestep {} (total: {}) | epoch train steps: {} (total: {})' .format(self._epoch, self._model_train_freq, self._timestep, self._total_timestep, self._train_steps_this_epoch, self._num_train_steps)) #### train the model with input:(obs, act), outputs: (rew, delta_obs), inputs are divided into sets with holdout_ratio #@anyboby debug samples = self._pool.return_all_samples() self.fake_env.reset_model() model_train_metrics = self.fake_env.train( samples, batch_size=512, max_epochs=None, holdout_ratio=0.2, max_t=self._max_model_t) model_metrics.update(model_train_metrics) gt.stamp('epoch_train_model') #### rollout model env #### self._set_rollout_length() self._reallocate_model_pool( use_mjc_model_pool=self.use_mjc_state_model) model_rollout_metrics = self._rollout_model( rollout_batch_size=self._rollout_batch_size, deterministic=self._deterministic) model_metrics.update(model_rollout_metrics) ########################### 
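                    # Rough sizing arithmetic behind `_reallocate_model_pool`
                    # (based on the variant defined earlier in this document;
                    # epoch_length=1000 is an assumed value for the example):
                    #   rollouts_per_epoch    = rollout_batch_size * epoch_length / model_train_freq
                    #   model_steps_per_epoch = rollout_length * rollouts_per_epoch
                    #   new_pool_size         = model_retain_epochs * model_steps_per_epoch
                    # e.g. 100e3 * 1000 / 250 = 4e5 rollouts per epoch, so with
                    # rollout_length=1 and model_retain_epochs=20 the model
                    # pool holds 20 * 4e5 = 8e6 transitions.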
                    gt.stamp('epoch_rollout_model')
                    # self._visualize_model(self._evaluation_environment, self._total_timestep)
                    self._training_progress.resume()

                ##### sampling from the real world ! #####
                ##### _total_timestep % train_every_n_steps is checked inside _do_sampling
                self._do_sampling(timestep=self._total_timestep)
                gt.stamp('sample')

                ### n_train_repeat from config ###
                if self.ready_to_train:
                    self._do_training_repeats(timestep=self._total_timestep)
                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length))
            gt.stamp('training_paths')
            evaluation_paths = self._evaluation_paths(policy,
                                                      evaluation_environment)
            gt.stamp('evaluation_paths')

            training_metrics = self._evaluate_rollouts(training_paths,
                                                       training_environment)
            gt.stamp('training_metrics')
            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics = {}

            self._epoch_after_hook(training_paths)
            gt.stamp('epoch_after_hook')

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            time_diagnostics = gt.get_times().stamps.itrs

            diagnostics.update(
                OrderedDict((
                    *((f'evaluation/{key}', evaluation_metrics[key])
                      for key in sorted(evaluation_metrics.keys())),
                    *((f'training/{key}', training_metrics[key])
                      for key in sorted(training_metrics.keys())),
                    *((f'times/{key}', time_diagnostics[key][-1])
                      for key in sorted(time_diagnostics.keys())),
                    *((f'sampler/{key}', sampler_diagnostics[key])
                      for key in sorted(sampler_diagnostics.keys())),
                    *((f'model/{key}', model_metrics[key])
                      for key in sorted(model_metrics.keys())),
                    ('epoch', self._epoch),
                    ('timestep', self._timestep),
                    ('timesteps_total', self._total_timestep),
                    ('train-steps', self._num_train_steps),
                )))

            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                training_environment.render_rollouts(evaluation_paths)

            yield diagnostics

        self.sampler.terminate()

        self._training_after_hook()

        self._training_progress.close()

        ### this is where we yield the episode diagnostics to tune trial runner ###
        yield {'done': True, **diagnostics}
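
# ---------------------------------------------------------------------------
# Standalone sketch (hypothetical helper, numpy only): how a training batch is
# mixed from the environment pool and the model pool according to
# `real_ratio`, mirroring the logic of `_training_batch` above.  For
# illustration, pre-sampled batches are sliced here instead of drawing from
# the replay pools, and only the mixed batch is returned; the real method also
# returns the environment batch separately.
import numpy as np


def example_mixed_batch(env_pool_batch, model_pool_batch, batch_size, real_ratio):
    """Concatenate int(batch_size * real_ratio) real samples with model samples."""
    env_batch_size = int(batch_size * real_ratio)
    model_batch_size = batch_size - env_batch_size
    if model_batch_size <= 0:
        # real_ratio == 1.0: train purely on environment data.
        return {k: v[:batch_size] for k, v in env_pool_batch.items()}
    # Both batches are dicts of equally keyed numpy arrays (observations,
    # actions, rewards, ...); concatenate per key along the batch axis.
    return {
        k: np.concatenate(
            (env_pool_batch[k][:env_batch_size],
             model_pool_batch[k][:model_batch_size]),
            axis=0)
        for k in env_pool_batch.keys()
    }


# Usage with toy numbers: batch_size=256 and real_ratio=0.05 yields
# int(12.8) = 12 real samples and 244 model samples per update.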