def cost_function(self, x, it=0):
    """Return the negated weighted returns of the plans built from ``x``.

    Particles are created from ``x`` and ``self._observation``, their plan
    values are reshaped to ``(self.num_rollouts, -1)``, and the weighted
    returns are negated to form the costs.

    Args:
        x: Plan parameters forwarded to ``self.create_particles``.
        it: Iteration index, used only to label the diagnostics entries.

    Returns:
        NumPy array equal to minus the weighted returns.
    """
    with torch.no_grad():
        plans, obs = self.create_particles(x, self._observation)
        plan_values = self.get_plan_values(obs, plans)
        returns = plan_values.view(self.num_rollouts, -1)
        weighted_returns = self.get_weighted_returns(returns)
        costs = -ptu.get_numpy(weighted_returns)

    if not self._need_to_update_diagnostics:
        return costs

    labelled_stats = (
        ('Iteration %d Returns' % it, ptu.get_numpy(weighted_returns)),
        ('Iteration %d Particle Stds' % it,
         np.std(ptu.get_numpy(returns), axis=-1)),
    )
    for stat_name, stat_values in labelled_stats:
        self.diagnostics.update(
            create_stats_ordered_dict(
                stat_name,
                stat_values,
                always_show_all_stats=True,
            ))

    # Variance of the weighted returns minus the mean per-row variance.
    leftover = weighted_returns.var() - returns.var(dim=-1).mean()
    self.diagnostics['Return Leftover Variance'] = \
        ptu.get_numpy(leftover).mean()
    return costs
def get_diagnostics(self):
    """Return an OrderedDict of summary stats for the mean and the stddev."""
    tracked = (
        ('mean', self.mean),
        ('std', self.distribution.stddev),
    )
    stats = OrderedDict()
    for label, tensor in tracked:
        stats.update(create_stats_ordered_dict(label, ptu.get_numpy(tensor)))
    return stats
def get_diagnostics(self):
    """Return summary stats for the mean, ``normal_std`` and its logarithm."""
    tracked = (
        ('mean', self.mean),
        ('normal/std', self.normal_std),
        ('normal/log_std', torch.log(self.normal_std)),
    )
    stats = OrderedDict()
    for label, tensor in tracked:
        stats.update(create_stats_ordered_dict(label, ptu.get_numpy(tensor)))
    return stats
def predict_transition(self, obs, actions, infos):
    """Predict one transition and split it into next obs, rewards and dones.

    The predictor is chosen by ``self.sampling_mode`` ('ts' or 'uniform').
    Its output is read column-wise: column 0 is the reward, column 1 is
    thresholded at 0.5 to give the done flag, and the remaining columns are
    added to ``obs`` to give the next observation.

    When ``self.reward_func`` is set, its rewards replace the predicted ones
    and the squared error between the two is recorded in the diagnostics.

    Raises:
        NotImplementedError: if ``self.sampling_mode`` is not recognized.
    """
    mode = self.sampling_mode
    if mode == 'ts':
        predictor = self._predict_transition_ts
    elif mode == 'uniform':
        predictor = self._predict_transition_uniform
    else:
        raise NotImplementedError('MPC sampling_mode not recognized')
    preds = predictor(obs, actions, infos)

    next_obs = obs + preds[:, 2:]
    rewards = preds[:, 0]
    dones = preds[:, 1] > 0.5

    if self.reward_func is None:
        return next_obs, rewards, dones

    given_rewards = self.reward_func(
        obs, actions, next_obs, num_timesteps=self.num_timesteps)
    squared_error = ptu.get_numpy((given_rewards - rewards)**2)
    self.diagnostics.update(
        create_stats_ordered_dict(
            'Reward Squared Error',
            squared_error,
            always_show_all_stats=True,
        ))
    return next_obs, given_rewards, dones
def get_plan_values_batch(self, obs, plans):
    """
    Compute the value of every plan (a higher value means a better plan).

    Classes that don't want to plan over actions or use trajectory sampling
    can reimplement convert_plans_to_actions (& convert_plan_to_action)
    and/or predict_transition.

    plans is input as torch (horizon_length, num_particles (total), plan_dim).
    Trajectory infos are maintained as torch (n_part, info_dim (ex. obs_dim)).
    """
    if self.use_gt_model:
        return self.get_plan_values_batch_gt(obs, plans)

    # *total* number of particles, NOT num_particles
    total_particles = plans.shape[1]
    returns = ptu.zeros(total_particles)
    dones = ptu.zeros(total_particles)
    infos = dict()
    discount = 1

    # Every plan step is applied repeat_length times, so the effective
    # planning horizon is self.horizon * self.repeat_length.
    for t in range(self.horizon):
        for _ in range(self.repeat_length):
            step_actions = self.convert_plans_to_actions(obs, plans[t])
            obs, step_rewards, step_dones = self.predict_transition(
                obs, step_actions, infos)
            # Particles already marked done stop accumulating reward.
            returns += discount * (1 - dones) * step_rewards
            discount *= self.discount
            if self.predict_terminal:
                dones = torch.max(dones, step_dones.float())

    self.diagnostics.update(
        create_stats_ordered_dict('MPC Termination', ptu.get_numpy(dones)))

    if self.value_func is not None:
        terminal_values = self.value_func(
            obs, **self.value_func_kwargs).view(-1)
        returns += discount * (1 - dones) * terminal_values
        self.diagnostics.update(
            create_stats_ordered_dict(
                'MPC Terminal Values',
                ptu.get_numpy(terminal_values),
            ))

    return returns
def get_diagnostics(self):
    """Return step/path totals plus path-length stats for the epoch's paths."""
    stats = OrderedDict()
    stats['num steps total'] = self._num_steps_total
    stats['num paths total'] = self._num_paths_total

    # A path's length is the number of actions it holds.
    lengths = [len(path['actions']) for path in self._epoch_paths]
    stats.update(
        create_stats_ordered_dict(
            "path length",
            lengths,
            always_show_all_stats=True,
        ))
    return stats
def reward_postprocessing(self, rewards, reward_kwargs=None, *args, **kwargs):
    """Apply the parent postprocessing, then penalize high-disagreement steps.

    When ``self.disagreement_threshold`` is None the parent's result is
    returned untouched. Otherwise every entry whose disagreement exceeds
    the threshold has its reward overwritten with ``self.reward_bounds[0]``.

    Args:
        rewards: Rewards handed to the parent implementation.
        reward_kwargs: Mapping with a 'disagreements' entry; it is read
            only when a threshold is configured.

    Returns:
        The parent's return value when no threshold is set, otherwise the
        tuple ``(rewards, diagnostics)``.
    """
    parent_result = super().reward_postprocessing(rewards)
    if self.disagreement_threshold is None:
        return parent_result

    rewards, diagnostics = parent_result
    disagreements = reward_kwargs['disagreements']
    over_cutoff = disagreements > self.disagreement_threshold
    rewards[over_cutoff] = self.reward_bounds[0]

    if self._need_to_update_eval_statistics:
        diagnostics.update(
            create_stats_ordered_dict('Model Disagreement', disagreements))
        diagnostics['Pct of Timesteps over Disagreement Cutoff'] = np.mean(
            over_cutoff)

    return rewards, diagnostics
def train_from_paths(self, paths):
    """Update the policy and the value function from a batch of paths.

    The paths are deep-copied first so the caller's data is never modified.
    Rewards are squeezed to 1D, divided by ``self._reward_std`` and clipped
    to [-10, 10]. Each policy epoch recomputes baselines, returns and
    advantages, runs the policy minibatch steps, and stops early when the
    KL estimate exceeds 1.5 * ``self.target_kl`` or becomes NaN. The value
    function is trained for ``self.num_epochs`` epochs in total, regardless
    of how early the policy loop stops.

    Args:
        paths: Sequence of path dicts providing 'observations', 'actions',
            'rewards' and 'terminals'.
    """
    # Path preprocessing; have to copy so we don't modify when paths are
    # used elsewhere.
    paths = copy.deepcopy(paths)
    for path in paths:
        # Other places like to have an extra dimension so that all arrays are 2D
        path['terminals'] = np.squeeze(path['terminals'], axis=-1)
        path['rewards'] = np.squeeze(path['rewards'], axis=-1)
        # Reward normalization; divide by std of reward in replay buffer
        path['rewards'] = np.clip(
            path['rewards'] / (self._reward_std + 1e-3), -10, 10)

    obs = np.concatenate([path['observations'] for path in paths], axis=0)
    actions = np.concatenate([path['actions'] for path in paths], axis=0)
    obs_tensor, act_tensor = ptu.from_numpy(obs), ptu.from_numpy(actions)

    # ---- Policy training loop ----
    old_policy = copy.deepcopy(self.policy)
    with torch.no_grad():
        log_probs_old = old_policy.get_log_probs(
            obs_tensor, act_tensor).squeeze(dim=-1)

    rem_value_epochs = self.num_epochs
    # Leading empty float64 array keeps the dtype promotion that the
    # previous repeated np.append([], ...) accumulation produced.
    empty = np.zeros(0)
    for epoch in range(self.num_policy_epochs):
        # Recompute advantages at the beginning of each epoch. This allows
        # for advantages to utilize the latest value function.
        # Note: while this is not present in most implementations, it is
        # recommended by Andrychowicz et al. 2020.
        path_functions.calculate_baselines(paths, self.value_func)
        path_functions.calculate_returns(paths, self.discount)
        path_functions.calculate_advantages(
            paths,
            self.discount,
            self.gae_lambda,
            self.normalize_advantages,
        )

        # One concatenation instead of a quadratic np.append loop.
        advantages = np.concatenate(
            [empty] + [np.ravel(path['advantages']) for path in paths])
        returns = np.concatenate(
            [empty] + [np.ravel(path['returns']) for path in paths])

        if epoch == 0 and self._need_to_update_eval_statistics:
            with torch.no_grad():
                values = torch.squeeze(self.value_func(obs_tensor), dim=-1)
                values_np = ptu.get_numpy(values)
            first_val_loss = ((returns - values_np)**2).mean()

        # Snapshot so the policy can be rolled back if the KL check fails.
        old_params = self.policy.get_param_values()

        # NOTE(review): when there are fewer samples than policy_batch_size
        # no policy step runs and policy_loss stays undefined for the
        # statistics below -- confirm callers always pass enough data.
        num_policy_steps = len(advantages) // self.policy_batch_size
        for _ in range(num_policy_steps):
            if num_policy_steps == 1:
                # A single step uses the whole batch without resampling.
                batch = dict(
                    observations=obs,
                    actions=actions,
                    advantages=advantages,
                )
            else:
                batch = ppp.sample_batch(
                    self.policy_batch_size,
                    observations=obs,
                    actions=actions,
                    advantages=advantages,
                )
            policy_loss, kl = self.train_policy(batch, old_policy)

        with torch.no_grad():
            log_probs = self.policy.get_log_probs(
                obs_tensor, act_tensor).squeeze(dim=-1)
        kl = (log_probs_old - log_probs).mean()

        # kl != kl is the NaN check.
        if (self.target_kl is not None
                and kl > 1.5 * self.target_kl) or (kl != kl):
            if epoch > 0 or kl != kl:  # nan check
                self.policy.set_param_values(old_params)
            break

        num_value_steps = len(advantages) // self.value_batch_size
        for _ in range(num_value_steps):
            batch = ppp.sample_batch(
                self.value_batch_size,
                observations=obs,
                targets=returns,
            )
            value_loss = self.train_value(batch)
        rem_value_epochs -= 1

    # Ensure the value function is always updated for the maximum number
    # of epochs, regardless of if the policy wants to terminate early.
    for _ in range(rem_value_epochs):
        num_value_steps = len(advantages) // self.value_batch_size
        for _ in range(num_value_steps):
            batch = ppp.sample_batch(
                self.value_batch_size,
                observations=obs,
                targets=returns,
            )
            value_loss = self.train_value(batch)

    if self._need_to_update_eval_statistics:
        with torch.no_grad():
            _, _, _, log_pi, *_ = self.policy(obs_tensor, return_log_prob=True)
            values = torch.squeeze(self.value_func(obs_tensor), dim=-1)
            values_np = ptu.get_numpy(values)
        errors = returns - values_np
        explained_variance = 1 - (np.var(errors) / np.var(returns))
        value_loss = errors**2

        self.eval_statistics['Num Epochs'] = epoch + 1
        self.eval_statistics['Policy Loss'] = ptu.get_numpy(
            policy_loss).mean()
        self.eval_statistics['KL Divergence'] = ptu.get_numpy(kl).mean()
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Advantages',
                advantages,
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Returns',
                returns,
            ))
        self.eval_statistics['Value Loss'] = value_loss.mean()
        self.eval_statistics['First Value Loss'] = first_val_loss
        self.eval_statistics[
            'Value Explained Variance'] = explained_variance
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Values',
                ptu.get_numpy(values),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Value Squared Errors',
                value_loss,
            ))
def train_from_torch(self, batch):
    """Run one update of both critics and, periodically, of the policy.

    The critic target uses the target policy's action plus clipped noise
    and the minimum of the two target critics. The policy and all target
    networks are updated only when ``self._n_train_steps_total`` is a
    multiple of ``self.policy_and_target_update_period``.

    Args:
        batch: Mapping with 'rewards', 'terminals', 'observations',
            'actions' and 'next_observations'.
    """
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    # ---- Critic operations ----
    next_actions = self.target_policy(next_obs)
    raw_noise = ptu.randn(next_actions.shape) * self.target_policy_noise
    clipped_noise = torch.clamp(
        raw_noise,
        -self.target_policy_noise_clip,
        self.target_policy_noise_clip,
    )
    noisy_next_actions = next_actions + clipped_noise

    target_q_values = torch.min(
        self.target_qf1(next_obs, noisy_next_actions),
        self.target_qf2(next_obs, noisy_next_actions),
    )
    q_target = (self.reward_scale * rewards +
                (1. - terminals) * self.discount * target_q_values).detach()

    q1_pred = self.qf1(obs, actions)
    bellman_errors_1 = (q1_pred - q_target)**2
    qf1_loss = bellman_errors_1.mean()

    q2_pred = self.qf2(obs, actions)
    bellman_errors_2 = (q2_pred - q_target)**2
    qf2_loss = bellman_errors_2.mean()

    # ---- Update networks ----
    critic_updates = (
        (self.qf1_optimizer, qf1_loss),
        (self.qf2_optimizer, qf2_loss),
    )
    for optimizer, loss in critic_updates:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    policy_actions = policy_loss = None
    if self._n_train_steps_total % self.policy_and_target_update_period == 0:
        policy_actions = self.policy(obs)
        policy_loss = -self.qf1(obs, policy_actions).mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        network_pairs = (
            (self.policy, self.target_policy),
            (self.qf1, self.target_qf1),
            (self.qf2, self.target_qf2),
        )
        for online, target in network_pairs:
            ptu.soft_update_from_to(online, target, self.tau)

    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        if policy_loss is None:
            # The policy was not updated this step; evaluate it for logging.
            policy_actions = self.policy(obs)
            policy_loss = -self.qf1(obs, policy_actions).mean()

        scalar_stats = (
            ('QF1 Loss', qf1_loss),
            ('QF2 Loss', qf2_loss),
            ('Policy Loss', policy_loss),
        )
        for key, tensor in scalar_stats:
            self.eval_statistics[key] = np.mean(ptu.get_numpy(tensor))

        array_stats = (
            ('Q1 Predictions', q1_pred),
            ('Q2 Predictions', q2_pred),
            ('Q Targets', q_target),
            ('Bellman Errors 1', bellman_errors_1),
            ('Bellman Errors 2', bellman_errors_2),
            ('Policy Action', policy_actions),
        )
        for key, tensor in array_stats:
            self.eval_statistics.update(
                create_stats_ordered_dict(key, ptu.get_numpy(tensor)))

    self._n_train_steps_total += 1
def train_from_torch(self, batch):
    """Generate model rollouts when due, then update the policy.

    Every ``self.rollout_generation_freq`` calls, synthetic paths are
    sampled from start states drawn from the replay buffer and added to
    ``self.generated_data_buffer``. The policy trainer is then run
    ``self.num_policy_updates`` times on batches that mix real and
    generated data according to ``self.real_data_pct``.

    Args:
        batch: Used only for its batch size (``batch['observations']``).
    """
    generate_now = (
        self._n_train_steps_total % self.rollout_generation_freq == 0)

    # ---- Generate synthetic data using dynamics model ----
    if generate_now:
        rollout_len = self.rollout_len_func(self._n_train_steps_total)
        total_samples = self.rollout_generation_freq * self.num_model_rollouts

        num_samples = 0
        # Starts with an empty float array so the final concatenation has
        # the same dtype as accumulating onto np.array([]).
        reward_chunks = [np.array([])]
        terminated = []
        while num_samples < total_samples:
            batch_samples = min(self.rollout_batch_size,
                                total_samples - num_samples)
            real_batch = self.replay_buffer.random_batch(batch_samples)
            start_states = real_batch['observations']

            with torch.no_grad():
                paths = self.sample_paths(start_states, rollout_len)

            for path in paths:
                self.generated_data_buffer.add_path(path)
                num_samples += len(path['observations'])
                reward_chunks.append(path['rewards'][:, 0])
                terminated.append(path['terminals'][-1, 0] > 0.5)
                if num_samples >= total_samples:
                    break

        generated_rewards = np.concatenate(reward_chunks)
        gt.stamp('generating rollouts', unique=False)

    # ---- Update policy on both real and generated data ----
    batch_size = batch['observations'].shape[0]
    n_real_data = int(self.real_data_pct * batch_size)
    n_generated_data = batch_size - n_real_data

    mixed_keys = ('rewards', 'terminals', 'observations', 'actions',
                  'next_observations')
    for _ in range(self.num_policy_updates):
        batch = self.replay_buffer.random_batch(n_real_data)
        generated_batch = self.generated_data_buffer.random_batch(
            n_generated_data)
        for k in mixed_keys:
            stacked = np.concatenate((batch[k], generated_batch[k]), axis=0)
            batch[k] = ptu.from_numpy(stacked)
        self.policy_trainer.train_from_torch(batch)

    # ---- Save some statistics for eval ----
    if self._need_to_update_eval_statistics and generate_now:
        self._need_to_update_eval_statistics = False
        self.eval_statistics['MBPO Rollout Length'] = rollout_len
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'MBPO Reward Predictions',
                generated_rewards,
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'MBPO Rollout Terminations',
                np.array(terminated).astype(float),
            ))

    self._n_train_steps_total += 1
def sample_paths(self, start_states, rollout_len):
    """Roll out the policy in the dynamics model from ``start_states``.

    Behaviour depends on ``self.sampling_mode``:

    * 'uniform': plain rollouts via ``mrf.policy``.
    * 'mean_disagreement': once the disagreement exceeds
      ``sampling_kwargs['threshold']`` at some step, that step and every
      later step of the path has its reward replaced by
      ``-sampling_kwargs['penalty']``.
    * 'var_disagreement': ``sampling_kwargs['reward_penalty']`` times the
      disagreement is subtracted from each reward.

    Args:
        start_states: Initial states handed to the rollout function.
        rollout_len: Maximum path length of each rollout.

    Returns:
        The list of generated paths, with rewards modified in the two
        disagreement modes.

    Raises:
        NotImplementedError: for any other ``sampling_mode``.
    """
    if self.sampling_mode == 'uniform':
        # Sample uniformly from a model of the ensemble (original MBPO;
        # Janner et al. 2019)
        paths = mrf.policy(
            self.dynamics_model,
            self.policy_trainer.policy,
            start_states,
            max_path_length=rollout_len,
        )

    elif self.sampling_mode == 'mean_disagreement':
        # Sample with penalty for disagreement of the mean (MOReL;
        # Kidambi et al. 2020)
        paths, disagreements = mrf.policy_with_disagreement(
            self.dynamics_model,
            self.policy_trainer.policy,
            start_states,
            max_path_length=rollout_len,
            disagreement_type='mean',
        )
        disagreements = ptu.get_numpy(disagreements)

        threshold = self.sampling_kwargs['threshold']
        penalty = self.sampling_kwargs['penalty']
        total_penalized, total_transitions = 0, 0
        for i, path in enumerate(paths):
            n_steps = len(path['rewards'])
            exceeded = np.reshape(disagreements[i][:n_steps],
                                  n_steps) > threshold
            # Running OR: the mask stays on after the first violation.
            mask = np.logical_or.accumulate(exceeded).astype(float)
            mask = mask.reshape(n_steps, 1)
            path['rewards'] = (1 - mask) * path['rewards'] - mask * penalty
            total_penalized += mask.sum()
            # Count transitions, not the number of keys in the path dict.
            total_transitions += n_steps

        self.eval_statistics['Percent of Transitions Penalized'] = (
            total_penalized / total_transitions if total_transitions else 0.0)
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Disagreement Values',
                disagreements,
            ))

    elif self.sampling_mode == 'var_disagreement':
        # Sample with penalty for disagreement of the variance (MOPO;
        # Yu et al. 2020)
        paths, disagreements = mrf.policy_with_disagreement(
            self.dynamics_model,
            self.policy_trainer.policy,
            start_states,
            max_path_length=rollout_len,
            disagreement_type='var',
        )
        disagreements = ptu.get_numpy(disagreements)

        reward_penalty = self.sampling_kwargs['reward_penalty']
        for i, path in enumerate(paths):
            path_disagreements = disagreements[
                i, :len(path['rewards'])].reshape(*path['rewards'].shape)
            path['rewards'] -= reward_penalty * path_disagreements

        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Disagreement Values',
                disagreements,
            ))

    else:
        raise NotImplementedError

    return paths
def train_from_torch(self, batch):
    """Run one gradient step on the policy and on the Q-function.

    The policy is trained to maximize ``self.qf`` at its own actions, with
    an optional pre-activation penalty. The Q-function is regressed onto a
    clamped one-step target built from the target networks. Target networks
    are updated via ``self._update_target_networks()`` afterwards.

    Args:
        batch: Mapping with 'rewards', 'terminals', 'observations',
            'actions' and 'next_observations' tensors.
    """
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    # ---- Policy operations ----
    if self.policy_pre_activation_weight > 0:
        policy_actions, pre_tanh_value = self.policy(
            obs,
            return_preactivations=True,
        )
        pre_activation_policy_loss = ((pre_tanh_value**2).sum(dim=1).mean())
        q_output = self.qf(obs, policy_actions)
        raw_policy_loss = -q_output.mean()
        policy_loss = (
            raw_policy_loss +
            pre_activation_policy_loss * self.policy_pre_activation_weight)
    else:
        policy_actions = self.policy(obs)
        q_output = self.qf(obs, policy_actions)
        raw_policy_loss = policy_loss = -q_output.mean()

    # ---- Critic operations ----
    # Detach so nothing downstream tracks gradients into the target policy.
    # (detach() is not in-place; its result has to be kept.)
    next_actions = self.target_policy(next_obs).detach()
    target_q_values = self.target_qf(
        next_obs,
        next_actions,
    )
    q_target = rewards + (1. - terminals) * self.discount * target_q_values
    q_target = q_target.detach()
    q_target = torch.clamp(q_target, self.min_q_value, self.max_q_value)
    q_pred = self.qf(obs, actions)
    bellman_errors = (q_pred - q_target)**2
    raw_qf_loss = self.qf_criterion(q_pred, q_target)

    if self.qf_weight_decay > 0:
        reg_loss = self.qf_weight_decay * sum(
            torch.sum(param**2)
            for param in self.qf.regularizable_parameters())
        qf_loss = raw_qf_loss + reg_loss
    else:
        qf_loss = raw_qf_loss

    # ---- Update networks ----
    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    # zero_grad also clears the Q-function gradients left by policy_loss.
    self.qf_optimizer.zero_grad()
    qf_loss.backward()
    self.qf_optimizer.step()

    self._update_target_networks()

    # ---- Save some statistics for eval using just one batch ----
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics['Raw Policy Loss'] = np.mean(
            ptu.get_numpy(raw_policy_loss))
        self.eval_statistics['Preactivation Policy Loss'] = (
            self.eval_statistics['Policy Loss'] -
            self.eval_statistics['Raw Policy Loss'])
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Q Predictions',
                ptu.get_numpy(q_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Bellman Errors',
                ptu.get_numpy(bellman_errors),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Policy Action',
                ptu.get_numpy(policy_actions),
            ))

    self._n_train_steps_total += 1
def train_from_torch(self, batch):
    """Run one update of alpha (optional), the policy and both Q-functions.

    Actions are sampled from a ``TanhNormal`` built from the policy's mean
    and log-std outputs. The Q target is the minimum of the two target
    Q-functions at a freshly sampled next action, minus ``alpha`` times that
    action's log-probability.

    The step counter is advanced exactly once per call, before
    ``self.try_update_target_networks()`` is invoked.

    Args:
        batch: Mapping with 'observations', 'next_observations', 'actions'
            and 'rewards'; 'terminals' is optional and defaults to zeros.
    """
    obs = batch['observations']
    next_obs = batch['next_observations']
    actions = batch['actions']
    rewards = batch['rewards']
    terminals = batch.get('terminals', ptu.zeros(rewards.shape[0], 1))

    # ---- Policy and Alpha Loss ----
    _, policy_mean, policy_logstd, *_ = self.policy(obs)
    dist = TanhNormal(policy_mean, policy_logstd.exp())
    new_obs_actions, log_pi = dist.rsample_and_logprob()
    log_pi = log_pi.sum(dim=-1, keepdims=True)
    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha *
                       (log_pi + self.target_entropy).detach()).mean()
        alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        alpha = 1

    q_new_actions = torch.min(
        self.qf1(obs, new_obs_actions),
        self.qf2(obs, new_obs_actions),
    )
    policy_loss = (alpha * log_pi - q_new_actions).mean()

    # ---- QF Loss ----
    q1_pred = self.qf1(obs, actions)
    q2_pred = self.qf2(obs, actions)
    _, next_policy_mean, next_policy_logstd, *_ = self.policy(next_obs)
    next_dist = TanhNormal(next_policy_mean, next_policy_logstd.exp())
    new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
    new_log_pi = new_log_pi.sum(dim=-1, keepdims=True)
    target_q_values = torch.min(
        self.target_qf1(next_obs, new_next_actions),
        self.target_qf2(next_obs, new_next_actions),
    ) - alpha * new_log_pi

    future_values = (1. - terminals) * self.discount * target_q_values
    q_target = self.reward_scale * rewards + future_values
    qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
    qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

    # ---- Update networks ----
    if self.use_automatic_entropy_tuning:
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    self.qf1_optimizer.zero_grad()
    qf1_loss.backward()
    self.qf1_optimizer.step()

    self.qf2_optimizer.zero_grad()
    qf2_loss.backward()
    self.qf2_optimizer.step()

    # Single increment per call. A second increment at the end of this
    # method used to advance the counter by two, so the value visible to
    # try_update_target_networks() was always odd.
    self._n_train_steps_total += 1

    self.try_update_target_networks()

    # ---- Save some statistics for eval ----
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False

        # Logged policy loss deliberately omits the alpha factor.
        policy_loss = (log_pi - q_new_actions).mean()
        policy_avg_std = torch.exp(policy_logstd).mean()

        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
            policy_loss
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(q2_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets',
            ptu.get_numpy(q_target),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Log Pis',
            ptu.get_numpy(log_pi),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy mu',
            ptu.get_numpy(policy_mean),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy log std',
            ptu.get_numpy(policy_logstd),
        ))
        self.eval_statistics['Policy std'] = np.mean(
            ptu.get_numpy(policy_avg_std))

        if self.use_automatic_entropy_tuning:
            self.eval_statistics['Alpha'] = alpha.item()
            self.eval_statistics['Alpha Loss'] = alpha_loss.item()
def train_from_buffer(self, reward_kwargs=None):
    """Relabel buffer rewards (optional) and train the policy from the buffer.

    When ``self.relabel_rewards`` is set, intrinsic rewards are recomputed
    for the first ``self._cur_replay_size`` buffer entries, postprocessed,
    and written back to ``self._rewards``. The policy trainer then runs
    ``self.num_policy_updates`` updates on batches of state-latent pairs.

    Args:
        reward_kwargs: Optional mapping forwarded to
            ``calculate_intrinsic_rewards`` and ``reward_postprocessing``.
    """
    # ---- Compute intrinsic reward: approximate lower bound to I(s'; z | s) ----
    if self.relabel_rewards:
        rewards, (logp, logp_altz, denom), reward_diagnostics = \
            self.calculate_intrinsic_rewards(
                self._obs[:self._cur_replay_size],
                self._next_obs[:self._cur_replay_size],
                self._latents[:self._cur_replay_size],
                reward_kwargs=reward_kwargs)
        orig_rewards = rewards.copy()
        rewards, postproc_dict = self.reward_postprocessing(
            rewards, reward_kwargs=reward_kwargs)
        reward_diagnostics.update(postproc_dict)
        self._rewards[:self._cur_replay_size] = np.expand_dims(rewards,
                                                               axis=-1)

    gt.stamp('intrinsic reward calculation', unique=False)

    # ---- Train policy ----
    state_latents = np.concatenate(
        [self._obs, self._latents], axis=-1)[:self._cur_replay_size]
    next_state_latents = np.concatenate(
        [self._true_next_obs, self._latents],
        axis=-1)[:self._cur_replay_size]

    for _ in range(self.num_policy_updates):
        batch = ppp.sample_batch(
            self.policy_batch_size,
            observations=state_latents,
            next_observations=next_state_latents,
            actions=self._actions[:self._cur_replay_size],
            rewards=self._rewards[:self._cur_replay_size],
        )
        batch = ptu.np_to_pytorch_batch(batch)
        self.policy_trainer.train_from_torch(batch)

    gt.stamp('policy training', unique=False)

    # ---- Diagnostics ----
    if self._need_to_update_eval_statistics:
        self.eval_statistics.update(self.policy_trainer.eval_statistics)

        if self.relabel_rewards:
            self.eval_statistics.update(reward_diagnostics)

            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Discriminator Log Pis',
                    logp,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Discriminator Alt Log Pis',
                    logp_altz,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Intrinsic Reward Denominator',
                    denom,
                ))

            # Adjustment so intrinsic rewards are over last epoch.
            # NOTE(review): assumes _ptr is the write index of a circular
            # buffer, so the last epoch is the _epoch_size entries ending
            # at _ptr -- confirm against add_sample.
            if self._ptr < self._epoch_size:
                # Window wraps: [0, _ptr) plus the remaining
                # (_epoch_size - _ptr) entries from the tail. The clamp
                # keeps the tail from overlapping the head or going
                # negative when the buffer is not yet full.
                tail_start = max(
                    self._ptr,
                    len(rewards) - (self._epoch_size - self._ptr))
                inds = np.r_[0:self._ptr, tail_start:len(rewards)]
            else:
                inds = np.r_[self._ptr - self._epoch_size:self._ptr]

            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Intrinsic Rewards (Original)',
                    orig_rewards[inds],
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Intrinsic Rewards (Processed)',
                    rewards[inds],
                ))

    self._n_train_steps_total += 1
def train_from_torch(self, batch):
    """Generate model rollouts, store them, and train from the buffer.

    Work is done only on every ``self.train_every``-th call. Start states
    are drawn from the replay buffer, one latent is generated per start
    state, and paths are rolled out in chunks until
    ``self.num_model_samples`` transitions have been collected. The
    transitions are added via ``self.add_sample``, the discriminator is
    trained on them, and ``self.train_from_buffer`` is called.

    With ``self.empowerment_horizon > 1`` each observation is paired with
    the observation ``empowerment_horizon`` steps later, so the last
    ``empowerment_horizon - 1`` steps of a path yield no sample.

    Args:
        batch: Unused by this method.
    """
    self._train_calls += 1
    if self._train_calls % self.train_every > 0:
        return

    rollout_len = self.rollout_len_func(self._n_train_steps_total)
    num_model_rollouts = max(self.num_model_samples // rollout_len, 1)
    self.eval_statistics['Rollout Length'] = rollout_len

    real_batch = self.replay_buffer.random_batch(num_model_rollouts)
    start_states = real_batch['observations']
    latents = self.generate_latents(start_states)

    observations = np.zeros((self.num_model_samples, self.obs_dim))
    next_observations = np.zeros((self.num_model_samples, self.obs_dim))
    actions = np.zeros((self.num_model_samples, self.action_dim))
    unfolded_latents = np.zeros((self.num_model_samples, self.latent_dim))
    disagreements = np.zeros(self.num_model_samples)

    horizon_offset = self.empowerment_horizon - 1
    num_samples, b_ind, num_traj = 0, 0, 0
    while num_samples < self.num_model_samples:
        # Roll out in chunks of start states.
        # NOTE(review): 4192 appears to be a per-chunk transition budget
        # (possibly meant to be 4096) -- confirm.
        e_ind = b_ind + 4192 // rollout_len
        with torch.no_grad():
            paths, path_disagreements = self.generate_paths(
                dynamics_model=self.dynamics_model,
                control_policy=self.control_policy,
                start_states=start_states[b_ind:e_ind],
                latents=ptu.from_numpy(latents[b_ind:e_ind]),
                rollout_len=rollout_len,
            )
        b_ind = e_ind
        path_disagreements = ptu.get_numpy(path_disagreements)

        for i, path in enumerate(paths):
            # Usable samples: path length minus the horizon offset. The
            # subtraction must be applied to the length, not to the
            # observation array. Clamp at zero for very short paths.
            usable_len = max(len(path['observations']) - horizon_offset, 0)
            clipped_len = min(usable_len,
                              self.num_model_samples - num_samples)
            bi, ei = num_samples, num_samples + clipped_len

            if self.empowerment_horizon > 1:
                path['observations'] = path['observations'][:-horizon_offset]
                path['next_observations'] = path['next_observations'][
                    horizon_offset:horizon_offset + clipped_len]
                path['actions'] = path['actions'][:-horizon_offset]

            observations[bi:ei] = path['observations'][:clipped_len]
            next_observations[bi:ei] = path['next_observations'][:clipped_len]
            actions[bi:ei] = path['actions'][:clipped_len]
            # One latent per trajectory, broadcast over its samples.
            unfolded_latents[bi:ei] = latents[num_traj:num_traj + 1]
            disagreements[bi:ei] = path_disagreements[i, :clipped_len]

            num_samples += clipped_len
            num_traj += 1

            if num_samples >= self.num_model_samples:
                break

    gt.stamp('generating rollouts', unique=False)

    if not self.relabel_rewards:
        rewards, (logp, logp_altz, denom), reward_diagnostics = \
            self.calculate_intrinsic_rewards(
                observations, next_observations, unfolded_latents)
        orig_rewards = rewards.copy()
        rewards, postproc_dict = self.reward_postprocessing(
            rewards, reward_kwargs=dict(disagreements=disagreements))
        reward_diagnostics.update(postproc_dict)

        if self._need_to_update_eval_statistics:
            self.eval_statistics.update(reward_diagnostics)

            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Discriminator Log Pis',
                    logp,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Discriminator Alt Log Pis',
                    logp_altz,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Intrinsic Reward Denominator',
                    denom,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Intrinsic Rewards (Original)',
                    orig_rewards,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Intrinsic Rewards (Processed)',
                    rewards,
                ))

    gt.stamp('intrinsic reward calculation', unique=False)

    if self._need_to_update_eval_statistics:
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Latents',
                latents,
            ))

    for t in range(self.num_model_samples):
        self.add_sample(
            observations[t],
            next_observations[t],
            next_observations[t],  # fix this
            actions[t],
            unfolded_latents[t],
            disagreement=disagreements[t],
        )

    gt.stamp('policy training', unique=False)

    self.train_discriminator(observations, next_observations,
                             unfolded_latents)

    # NOTE(review): the attribute name has a capital "L"
    # (_modeL_disagreements); kept as-is because its definition is not
    # visible here -- confirm it matches the attribute set elsewhere.
    reward_kwargs = dict(
        disagreements=self._modeL_disagreements[:self._cur_replay_size])
    self.train_from_buffer(reward_kwargs=reward_kwargs)