def create_particles(self, plans, obs, n_part=None):
    # note: the obs argument is unused; particles are built from the
    # stored self._observation
    n_opt = plans.shape[0]
    if n_part is None:
        n_part = self.num_particles * self.num_models

    plans = ptu.from_numpy(plans)                        # (N, H*m)
    plans = plans.view(-1, self.horizon, self.plan_dim)  # (N, H, m)
    transposed = plans.transpose(0, 1)                   # (H, N, m)
    expanded = transposed[:, :, None]                    # (H, N, 1, m)
    tiled = expanded.expand(-1, -1, n_part, -1)          # (H, N, P, m)
    plans = tiled.contiguous().view(
        self.horizon, -1, self.plan_dim)                 # (H, N*P, m)

    obs = ptu.from_numpy(self._observation)              # (n,)
    obs = obs[None]                                      # (1, n)
    obs = obs.expand(n_opt * n_part, -1)                 # (N*P, n)

    return plans, obs
def _get_model_plan_value(self, obs, plan):
    obs, plan = ptu.from_numpy(obs), ptu.from_numpy(plan)
    plans = plan.view(-1, self.horizon, self.plan_dim)
    plans = plans.permute(1, 0, 2)
    obs = obs.view(1, -1)
    returns = self.get_plan_values(obs, plans)
    return ptu.get_numpy(returns).mean()
def fit_input_stats(self, data, mask=None):
    mean = np.mean(data, axis=0, keepdims=True)
    std = np.std(data, axis=0, keepdims=True)
    std[std != std] = 0     # std != std only holds for NaN; zero those out
    std[std < 1e-12] = 1.0  # avoid dividing by ~0 for constant inputs
    if mask is not None:
        mean *= mask
        std = mask * std + (1 - mask) * np.ones(self.input_size)
    self.input_mu.data = ptu.from_numpy(mean)
    self.input_std.data = ptu.from_numpy(std)
def train_from_paths(self, paths, train_discrim=True, train_policy=True):

    """
    Reading new paths: append latent to state.
    Note that this is equivalent to on-policy training when the latent
    buffer size equals the sum of the path lengths.
    """

    epoch_obs, epoch_next_obs, epoch_latents = [], [], []

    for path in paths:
        obs = path['observations']
        next_obs = path['next_observations']
        actions = path['actions']
        latents = path.get('latents', None)
        path_len = len(obs) - self.empowerment_horizon + 1

        obs_latents = np.concatenate([obs, latents], axis=-1)
        log_probs = self.control_policy.get_log_probs(
            ptu.from_numpy(obs_latents),
            ptu.from_numpy(actions),
        )
        log_probs = ptu.get_numpy(log_probs)

        for t in range(path_len):
            self.add_sample(
                obs[t],
                next_obs[t + self.empowerment_horizon - 1],
                next_obs[t],
                actions[t],
                latents[t],
                logprob=log_probs[t],
            )

            epoch_obs.append(obs[t:t + 1])
            epoch_next_obs.append(next_obs[
                t + self.empowerment_horizon - 1:
                t + self.empowerment_horizon])
            epoch_latents.append(np.expand_dims(latents[t], axis=0))

    epoch_obs = np.concatenate(epoch_obs, axis=0)
    epoch_next_obs = np.concatenate(epoch_next_obs, axis=0)
    epoch_latents = np.concatenate(epoch_latents, axis=0)
    self._epoch_size = len(epoch_obs)

    gt.stamp('policy training', unique=False)

    """
    The rest is shared: train from the buffer
    """

    if train_discrim:
        self.train_discriminator(epoch_obs, epoch_next_obs, epoch_latents)
    if train_policy:
        self.train_from_buffer()
def train_value(self, batch):
    obs = ptu.from_numpy(batch['observations'])
    targets = ptu.from_numpy(batch['targets'])

    value_preds = torch.squeeze(self.value_func(obs), dim=-1)
    value_loss = 0.5 * ((value_preds - targets) ** 2).mean()

    self.value_optim.zero_grad()
    value_loss.backward()
    self.value_optim.step()

    return value_loss
def sample_latent(self, state=None):
    if self.unconditional or state is None:
        # this will probably be changed
        latent = self.prior.sample()  # n=1).squeeze(0)
    else:
        latent = self.prior.forward(ptu.from_numpy(state))
    self.set_latent(latent)
    return latent
def train_policy(self, batch, old_policy):
    obs = ptu.from_numpy(batch['observations'])
    actions = ptu.from_numpy(batch['actions'])
    advantages = ptu.from_numpy(batch['advantages'])

    objective, kl = self.policy_objective(
        obs, actions, advantages, old_policy)
    policy_loss = -objective

    self.policy_optim.zero_grad()
    policy_loss.backward()
    torch.nn.utils.clip_grad_norm_(
        self.policy.parameters(), self.max_grad_norm)
    self.policy_optim.step()

    return policy_loss, kl
def _create_full_tensors(start_states, max_path_length, obs_dim, action_dim):
    num_rollouts = start_states.shape[0]
    observations = ptu.zeros((num_rollouts, max_path_length + 1, obs_dim))
    observations[:, 0] = ptu.from_numpy(start_states)
    actions = ptu.zeros((num_rollouts, max_path_length, action_dim))
    rewards = ptu.zeros((num_rollouts, max_path_length, 1))
    terminals = ptu.zeros((num_rollouts, max_path_length, 1))
    return observations, actions, rewards, terminals
def set_param_values(self, new_params):
    current_idx = 0
    for idx, param in enumerate(self.trainable_params):
        vals = new_params[current_idx:current_idx + self.param_sizes[idx]]
        vals = vals.reshape(self.param_shapes[idx])
        param.data = ptu.from_numpy(vals).float()
        current_idx += self.param_sizes[idx]

    # keep the final (log-std) parameter above LOG_SIG_MIN for stability
    self.trainable_params[-1].data = torch.clamp(
        self.trainable_params[-1], LOG_SIG_MIN)
def train_from_paths(self, paths):

    """
    Path processing
    """

    paths = copy.deepcopy(paths)
    for path in paths:
        obs, next_obs = path['observations'], path['next_observations']
        states = obs[:, :self.state_dim]
        next_states = next_obs[:, :self.state_dim]
        goals = obs[:, self.state_dim:2 * self.state_dim]
        actions = path['actions']
        terminals = path['terminals']  # probably always False, but might want it?
        path_len = len(obs)

        # Relabel goals based on the transitions taken
        relabeled_goals = []
        for t in range(path_len):
            relabeled_goals.append(self.relabel_goal_func(
                states[t], actions[t], next_states[t], goals[t],
            ))
        relabeled_goals = np.array(relabeled_goals)

        # Add transitions & resampled goals to the replay buffer
        for t in range(path_len):
            goals_t = goals[t:t + 1]
            for _ in range(self.num_sampled_goals):
                if self.relabel_method == 'future':
                    goal_inds = np.random.randint(
                        t, path_len, self.num_sampled_goals)
                    goals_t = np.concatenate(
                        [goals_t, relabeled_goals[goal_inds]], axis=0)
                else:
                    raise NotImplementedError

            for k in range(len(goals_t)):
                if not self.learn_reward_func:
                    r = self.reward_func(
                        states[t], actions[t], next_states[t], goals_t[k])
                else:
                    r = ptu.get_numpy(self.learned_reward_func(
                        ptu.from_numpy(np.concatenate(
                            [next_states[t], goals[t]])))).mean()

                self.replay_buffer.add_sample(
                    observation=np.concatenate(
                        [states[t], goals_t[k], obs[t, 2 * self.state_dim:]]),
                    action=actions[t],
                    reward=r,
                    terminal=terminals[t],  # not obvious what desired behavior is
                    next_observation=np.concatenate(
                        [next_states[t], goals_t[k],
                         obs[t, 2 * self.state_dim:]]),
                    env_info=None,
                )

    """
    Off-policy training
    """

    for _ in range(self.num_policy_steps):
        train_data = self.replay_buffer.random_batch(self.policy_batch_size)
        self.policy_trainer.train(train_data)
def calculate_baselines(paths, value_func):
    for path in paths:
        obs = ptu.from_numpy(np.concatenate(
            [path['observations'], path['next_observations'][-1:]], axis=0))
        values = torch.squeeze(value_func(obs), dim=-1)
        path['baselines'] = ptu.get_numpy(values)
        if path['terminals'][-1]:
            path['baselines'][-1] = 0
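# Companion sketch: calculate_advantages (used alongside calculate_baselines
# in the PPO training loop below) is assumed to implement GAE, per Schulman
# et al. 2016. This is a hypothetical reference implementation; the repo's
# own path_functions.calculate_advantages may differ in details.
import numpy as np

def calculate_advantages_sketch(paths, discount, gae_lambda, normalize=False):
    for path in paths:
        rewards, baselines = path['rewards'], path['baselines']
        T = len(rewards)
        advantages = np.zeros(T)
        gae = 0.0
        for t in reversed(range(T)):
            # TD residual; baselines has length T + 1 (bootstrap value last)
            delta = rewards[t] + discount * baselines[t + 1] - baselines[t]
            gae = delta + discount * gae_lambda * gae
            advantages[t] = gae
        path['advantages'] = advantages
    if normalize:
        all_adv = np.concatenate([p['advantages'] for p in paths])
        mean, std = all_adv.mean(), all_adv.std() + 1e-8
        for p in paths:
            p['advantages'] = (p['advantages'] - mean) / std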
def generate_latents(self, obs):
    if self._train_calls < self.num_unif_train_calls:
        return super().generate_latents(obs)
    latents, *_ = self.skill_practice_dist(ptu.from_numpy(obs))
    latents = ptu.get_numpy(latents)
    if self.epsilon_greedy > 0:
        unif_r = np.random.uniform(0, 1, size=latents.shape[0])
        eps_replace = unif_r < self.epsilon_greedy
        unif_latents = super().generate_latents(obs[eps_replace])
        latents[eps_replace] = unif_latents
    return latents
def train_policy(self, batch, old_policy):
    obs = ptu.from_numpy(batch['observations'])
    actions = ptu.from_numpy(batch['actions'])
    advantages = ptu.from_numpy(batch['advantages'])

    log_probs = torch.squeeze(self.policy.get_log_probs(obs, actions), dim=-1)
    log_probs_old = torch.squeeze(
        old_policy.get_log_probs(obs, actions), dim=-1)
    kl = (log_probs_old - log_probs).mean()

    vpg_grad, cpi_surr = self.flat_vpg(obs, actions, advantages, old_policy)
    hvp = self.build_Hvp_eval([obs, actions, old_policy],
                              regu_coef=self.FIM_invert_args['damping'])
    npg_grad = cg_solve(hvp, vpg_grad, x_0=vpg_grad.copy(),
                        cg_iters=self.FIM_invert_args['iters'])

    alpha = np.sqrt(np.abs(self.normalized_step_size /
                           (np.dot(vpg_grad.T, npg_grad) + 1e-20)))
    cur_params = self.policy.get_param_values()
    new_params = cur_params + alpha * npg_grad
    self.policy.set_param_values(new_params)

    return -cpi_surr, kl
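# For reference: cg_solve above computes the natural gradient by solving
# F x = g with conjugate gradient, touching the Fisher matrix F only
# through Hessian-vector products (the `hvp` closure from build_Hvp_eval).
# A minimal sketch under that assumption; the repo's own cg_solve may
# differ in details.
import numpy as np

def cg_solve_sketch(hvp_eval, g, x_0=None, cg_iters=10, residual_tol=1e-10):
    x = np.zeros_like(g) if x_0 is None else x_0.copy()
    r = g - hvp_eval(x)   # residual of F x = g
    p = r.copy()          # conjugate search direction
    r_dot_r = r.dot(r)
    for _ in range(cg_iters):
        Ap = hvp_eval(p)
        alpha = r_dot_r / (p.dot(Ap) + 1e-20)  # step size along p
        x += alpha * p
        r -= alpha * Ap
        new_r_dot_r = r.dot(r)
        if new_r_dot_r < residual_tol:
            break
        p = r + (new_r_dot_r / r_dot_r) * p    # update direction
        r_dot_r = new_r_dot_r
    return x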
def get_plan_values_batch_gt(self, obs, plans):
    returns = ptu.zeros(plans.shape[1])
    obs, plans = ptu.get_numpy(obs), ptu.get_numpy(plans)
    final_obs = np.copy(obs)
    for i in range(plans.shape[1]):
        returns[i], final_obs[i] = self._get_true_env_value(
            obs[i], plans[:, i])
    if self.value_func is not None:
        returns += (self.discount ** (self.horizon * self.repeat_length)) * (
            self.value_func(ptu.from_numpy(final_obs),
                            **self.value_func_kwargs))
    return returns
def get_action(self, state):
    if (self._steps_since_last_sample >= self.steps_between_sampling
            or self._last_latent is None) and not self.fixed_latent:
        latent = self.sample_latent(state)
        self._steps_since_last_sample = 0
    else:
        latent = self._last_latent
    self._steps_since_last_sample += 1

    state = ptu.from_numpy(state)
    sz = torch.cat((state, latent))
    action, *_ = self.policy.forward(sz)
    return ptu.get_numpy(action), dict()
def HVP(self, observations, actions, old_policy, vector, regu_coef=None):
    regu_coef = self.FIM_invert_args['damping'] \
        if regu_coef is None else regu_coef
    vec = torch.autograd.Variable(
        ptu.from_numpy(vector).float(), requires_grad=False)

    if self.hvp_sample_frac is not None and self.hvp_sample_frac < 0.99:
        num_samples = observations.shape[0]
        rand_idx = np.random.choice(
            num_samples, size=int(self.hvp_sample_frac * num_samples))
        obs = observations[rand_idx]
        act = actions[rand_idx]
    else:
        obs = observations
        act = actions

    log_probs = torch.squeeze(self.policy.get_log_probs(obs, act), dim=-1)
    log_probs_old = torch.squeeze(old_policy.get_log_probs(obs, act), dim=-1)
    mean_kl = (log_probs_old - log_probs).mean()

    grad_fo = torch.autograd.grad(
        mean_kl, self.policy.trainable_params, create_graph=True)
    flat_grad = torch.cat([g.contiguous().view(-1) for g in grad_fo])
    h = torch.sum(flat_grad * vec)
    hvp = torch.autograd.grad(h, self.policy.trainable_params)
    hvp_flat = np.concatenate(
        [g.contiguous().view(-1).cpu().data.numpy() for g in hvp])

    return hvp_flat + regu_coef * vector
def _log_prob_from_pre_tanh(self, pre_tanh_value):
    """
    Adapted from
    https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L73

    This formula is mathematically equivalent to log(1 - tanh(x)^2).

    Derivation:

    log(1 - tanh(x)^2)
     = log(sech(x)^2)
     = 2 * log(sech(x))
     = 2 * log(2e^-x / (e^-2x + 1))
     = 2 * (log(2) - x - log(e^-2x + 1))
     = 2 * (log(2) - x - softplus(-2x))

    :param pre_tanh_value: arctanh(x) for a sample x, i.e. the Gaussian
        sample before the tanh squashing is applied
    :return: log probability of x under the tanh-squashed distribution
    """
    log_prob = self.normal.log_prob(pre_tanh_value)
    correction = -2. * (
        ptu.from_numpy(np.log([2.]))
        - pre_tanh_value
        - torch.nn.functional.softplus(-2. * pre_tanh_value)
    )
    return log_prob + correction
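# Illustrative check of the identity used above (not part of the class):
# for any x, 2 * (log(2) - x - softplus(-2x)) equals log(1 - tanh(x)^2).
def _check_tanh_log_det_identity():
    import math
    import torch
    import torch.nn.functional as F
    x = torch.linspace(-5., 5., steps=101)
    lhs = 2. * (math.log(2.) - x - F.softplus(-2. * x))
    rhs = torch.log(1. - torch.tanh(x) ** 2)
    assert torch.allclose(lhs, rhs, atol=1e-6)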
def create_mask(inverse_beta_func, n_quantiles, risk_kwargs):
    """
    x in [0, 1] represents the CDF of the input. beta(x) represents the
    cumulative weight assigned to the lower x% of values, i.e. it is
    analogous to a CDF. This is typically easier to represent via the
    inverse of the beta function, so we take the inverse of the inverse
    beta function to recover the original function.

    The reweighted function becomes:
        R(f, beta) = sum_i f(i/n) * (beta((i+1)/(n+1)) - beta(i/(n+1)))
    """
    tau = np.linspace(0, 1, n_quantiles + 1)
    betas = np.zeros(n_quantiles + 1)
    mask = np.zeros(n_quantiles)

    # TODO: there are some issues with mask and risk_kwargs caching
    for i in range(n_quantiles + 1):
        betas[i] = inverse_beta_func(tau[i], risk_kwargs)
    for i in range(n_quantiles):
        mask[i] = betas[i + 1] - betas[i]

    return ptu.from_numpy(mask)
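# Hypothetical usage: a CVaR-style distortion placing all weight on the
# lowest `alpha` fraction of quantiles, assuming the passed function is
# consumed directly as the cumulative weight beta (as in the loop above).
# Here beta(x) = min(x / alpha, 1); the resulting mask sums to 1 and
# zeroes out quantiles above the alpha level.
def cvar_beta(tau, risk_kwargs):
    alpha = risk_kwargs.get('alpha', 0.25)
    return min(tau / alpha, 1.0)

# mask = create_mask(cvar_beta, n_quantiles=32, risk_kwargs={'alpha': 0.25})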
def train_from_paths(self, paths):

    """
    Path preprocessing; have to copy so we don't modify the paths when
    they are used elsewhere
    """

    paths = copy.deepcopy(paths)
    for path in paths:
        # Other places like to have an extra dimension so that all arrays are 2D
        path['terminals'] = np.squeeze(path['terminals'], axis=-1)
        path['rewards'] = np.squeeze(path['rewards'], axis=-1)

        # Reward normalization; divide by the std of reward in the replay buffer
        path['rewards'] = np.clip(
            path['rewards'] / (self._reward_std + 1e-3), -10, 10)

    obs, actions = [], []
    for path in paths:
        obs.append(path['observations'])
        actions.append(path['actions'])
    obs = np.concatenate(obs, axis=0)
    actions = np.concatenate(actions, axis=0)
    obs_tensor, act_tensor = ptu.from_numpy(obs), ptu.from_numpy(actions)

    """
    Policy training loop
    """

    old_policy = copy.deepcopy(self.policy)
    with torch.no_grad():
        log_probs_old = old_policy.get_log_probs(
            obs_tensor, act_tensor).squeeze(dim=-1)

    rem_value_epochs = self.num_epochs
    for epoch in range(self.num_policy_epochs):

        """
        Recompute advantages at the beginning of each epoch. This allows
        the advantages to utilize the latest value function.
        Note: while this is not present in most implementations, it is
        recommended by Andrychowicz et al. 2020.
        """

        path_functions.calculate_baselines(paths, self.value_func)
        path_functions.calculate_returns(paths, self.discount)
        path_functions.calculate_advantages(
            paths, self.discount, self.gae_lambda, self.normalize_advantages,
        )

        advantages, returns, baselines = [], [], []
        for path in paths:
            advantages = np.append(advantages, path['advantages'])
            returns = np.append(returns, path['returns'])

        if epoch == 0 and self._need_to_update_eval_statistics:
            with torch.no_grad():
                values = torch.squeeze(self.value_func(obs_tensor), dim=-1)
            values_np = ptu.get_numpy(values)
            first_val_loss = ((returns - values_np) ** 2).mean()

        old_params = self.policy.get_param_values()

        num_policy_steps = len(advantages) // self.policy_batch_size
        for _ in range(num_policy_steps):
            if num_policy_steps == 1:
                batch = dict(
                    observations=obs,
                    actions=actions,
                    advantages=advantages,
                )
            else:
                batch = ppp.sample_batch(
                    self.policy_batch_size,
                    observations=obs,
                    actions=actions,
                    advantages=advantages,
                )
            policy_loss, kl = self.train_policy(batch, old_policy)

        with torch.no_grad():
            log_probs = self.policy.get_log_probs(
                obs_tensor, act_tensor).squeeze(dim=-1)
        kl = (log_probs_old - log_probs).mean()

        if (self.target_kl is not None and kl > 1.5 * self.target_kl) \
                or (kl != kl):
            if epoch > 0 or kl != kl:  # kl != kl is a NaN check
                self.policy.set_param_values(old_params)
            break

        num_value_steps = len(advantages) // self.value_batch_size
        for i in range(num_value_steps):
            batch = ppp.sample_batch(
                self.value_batch_size,
                observations=obs,
                targets=returns,
            )
            value_loss = self.train_value(batch)
        rem_value_epochs -= 1

    # Ensure the value function is always updated for the maximum number
    # of epochs, regardless of whether the policy wants to terminate early.
    for _ in range(rem_value_epochs):
        num_value_steps = len(advantages) // self.value_batch_size
        for i in range(num_value_steps):
            batch = ppp.sample_batch(
                self.value_batch_size,
                observations=obs,
                targets=returns,
            )
            value_loss = self.train_value(batch)

    if self._need_to_update_eval_statistics:
        with torch.no_grad():
            _, _, _, log_pi, *_ = self.policy(
                obs_tensor, return_log_prob=True)
            values = torch.squeeze(self.value_func(obs_tensor), dim=-1)
        values_np = ptu.get_numpy(values)
        errors = returns - values_np
        explained_variance = 1 - (np.var(errors) / np.var(returns))
        value_loss = errors ** 2

        self.eval_statistics['Num Epochs'] = epoch + 1
        self.eval_statistics['Policy Loss'] = ptu.get_numpy(
            policy_loss).mean()
        self.eval_statistics['KL Divergence'] = ptu.get_numpy(kl).mean()
        self.eval_statistics.update(create_stats_ordered_dict(
            'Log Pis', ptu.get_numpy(log_pi),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Advantages', advantages,
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Returns', returns,
        ))
        self.eval_statistics['Value Loss'] = value_loss.mean()
        self.eval_statistics['First Value Loss'] = first_val_loss
        self.eval_statistics['Value Explained Variance'] = explained_variance
        self.eval_statistics.update(create_stats_ordered_dict(
            'Values', ptu.get_numpy(values),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Value Squared Errors', value_loss,
        ))
def train_from_torch(self, batch):
    # We only use the original batch to get the batch size for policy training

    """
    Generate synthetic data using dynamics model
    """

    if self._n_train_steps_total % self.rollout_generation_freq == 0:
        rollout_len = self.rollout_len_func(self._n_train_steps_total)
        total_samples = self.rollout_generation_freq * self.num_model_rollouts

        num_samples, generated_rewards, terminated = 0, np.array([]), []
        while num_samples < total_samples:
            batch_samples = min(self.rollout_batch_size,
                                total_samples - num_samples)

            real_batch = self.replay_buffer.random_batch(batch_samples)
            start_states = real_batch['observations']

            with torch.no_grad():
                paths = self.sample_paths(start_states, rollout_len)

            for path in paths:
                self.generated_data_buffer.add_path(path)
                num_samples += len(path['observations'])
                generated_rewards = np.concatenate(
                    [generated_rewards, path['rewards'][:, 0]])
                terminated.append(path['terminals'][-1, 0] > 0.5)
                if num_samples >= total_samples:
                    break

        gt.stamp('generating rollouts', unique=False)

    """
    Update policy on both real and generated data
    """

    batch_size = batch['observations'].shape[0]
    n_real_data = int(self.real_data_pct * batch_size)
    n_generated_data = batch_size - n_real_data

    for _ in range(self.num_policy_updates):
        batch = self.replay_buffer.random_batch(n_real_data)
        generated_batch = self.generated_data_buffer.random_batch(
            n_generated_data)

        for k in ('rewards', 'terminals', 'observations', 'actions',
                  'next_observations'):
            batch[k] = np.concatenate((batch[k], generated_batch[k]), axis=0)
            batch[k] = ptu.from_numpy(batch[k])

        self.policy_trainer.train_from_torch(batch)

    """
    Save some statistics for eval
    """

    if self._need_to_update_eval_statistics and \
            self._n_train_steps_total % self.rollout_generation_freq == 0:
        self._need_to_update_eval_statistics = False
        self.eval_statistics['MBPO Rollout Length'] = rollout_len
        self.eval_statistics.update(create_stats_ordered_dict(
            'MBPO Reward Predictions', generated_rewards,
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'MBPO Rollout Terminations',
            np.array(terminated).astype(float),
        ))

    self._n_train_steps_total += 1
def train_from_torch(self, batch):
    self._train_calls += 1
    if self._train_calls % self.train_every > 0:
        return

    rollout_len = self.rollout_len_func(self._n_train_steps_total)
    num_model_rollouts = max(self.num_model_samples // rollout_len, 1)
    self.eval_statistics['Rollout Length'] = rollout_len

    real_batch = self.replay_buffer.random_batch(num_model_rollouts)
    start_states = real_batch['observations']
    latents = self.generate_latents(start_states)

    observations = np.zeros((self.num_model_samples, self.obs_dim))
    next_observations = np.zeros((self.num_model_samples, self.obs_dim))
    actions = np.zeros((self.num_model_samples, self.action_dim))
    unfolded_latents = np.zeros((self.num_model_samples, self.latent_dim))
    disagreements = np.zeros(self.num_model_samples)

    num_samples, b_ind, num_traj = 0, 0, 0
    while num_samples < self.num_model_samples:
        e_ind = b_ind + 4192 // rollout_len
        with torch.no_grad():
            paths, path_disagreements = self.generate_paths(
                dynamics_model=self.dynamics_model,
                control_policy=self.control_policy,
                start_states=start_states[b_ind:e_ind],
                latents=ptu.from_numpy(latents[b_ind:e_ind]),
                rollout_len=rollout_len,
            )
        b_ind = e_ind

        path_disagreements = ptu.get_numpy(path_disagreements)
        for i, path in enumerate(paths):
            clipped_len = min(
                len(path['observations']) - (self.empowerment_horizon - 1),
                self.num_model_samples - num_samples)
            bi, ei = num_samples, num_samples + clipped_len

            if self.empowerment_horizon > 1:
                path['observations'] = path['observations'][
                    :-(self.empowerment_horizon - 1)]
                path['next_observations'] = path['next_observations'][
                    (self.empowerment_horizon - 1):
                    (self.empowerment_horizon - 1) + clipped_len]
                path['actions'] = path['actions'][
                    :-(self.empowerment_horizon - 1)]

            observations[bi:ei] = path['observations'][:clipped_len]
            next_observations[bi:ei] = path['next_observations'][:clipped_len]
            actions[bi:ei] = path['actions'][:clipped_len]
            unfolded_latents[bi:ei] = latents[num_traj:num_traj + 1]
            disagreements[bi:ei] = path_disagreements[i, :clipped_len]

            num_samples += clipped_len
            num_traj += 1
            if num_samples >= self.num_model_samples:
                break

    gt.stamp('generating rollouts', unique=False)

    if not self.relabel_rewards:
        rewards, (logp, logp_altz, denom), reward_diagnostics = \
            self.calculate_intrinsic_rewards(
                observations, next_observations, unfolded_latents)
        orig_rewards = rewards.copy()
        rewards, postproc_dict = self.reward_postprocessing(
            rewards, reward_kwargs=dict(disagreements=disagreements))
        reward_diagnostics.update(postproc_dict)

        if self._need_to_update_eval_statistics:
            self.eval_statistics.update(reward_diagnostics)
            self.eval_statistics.update(create_stats_ordered_dict(
                'Discriminator Log Pis', logp,
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Discriminator Alt Log Pis', logp_altz,
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Intrinsic Reward Denominator', denom,
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Intrinsic Rewards (Original)', orig_rewards,
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Intrinsic Rewards (Processed)', rewards,
            ))

        gt.stamp('intrinsic reward calculation', unique=False)

    if self._need_to_update_eval_statistics:
        self.eval_statistics.update(create_stats_ordered_dict(
            'Latents', latents,
        ))

    for t in range(self.num_model_samples):
        self.add_sample(
            observations[t],
            next_observations[t],
            next_observations[t],  # fix this
            actions[t],
            unfolded_latents[t],
            disagreement=disagreements[t],
        )

    gt.stamp('policy training', unique=False)

    self.train_discriminator(observations, next_observations,
                             unfolded_latents)

    reward_kwargs = dict(
        disagreements=self._model_disagreements[:self._cur_replay_size])
    self.train_from_buffer(reward_kwargs=reward_kwargs)
def calculate_contrastive_empowerment(
    discriminator,
    obs,
    next_obs,
    latents,
    num_prior_samples=512,
    distribution_type='uniform',
    split_group=4096 * 32,
    obs_mean=None,
    obs_std=None,
    return_diagnostics=False,
    prior=None,
):
    """
    Described in Sharma et al. 2019. Approximates the variational lower
    bound using an estimate of s' from s, z. Uses contrastive negatives
    to approximate the denominator.
    """
    discriminator.eval()

    if obs_mean is not None:
        obs = (obs - obs_mean) / (obs_std + 1e-6)
        next_obs = (next_obs - obs_mean) / (obs_std + 1e-6)

    obs_deltas = ptu.from_numpy(next_obs - obs)
    obs_altz = np.concatenate([obs] * num_prior_samples, axis=0)

    with torch.no_grad():
        logp = discriminator.get_log_prob(
            ptu.from_numpy(obs),
            ptu.from_numpy(latents),
            obs_deltas,
        )
        logp = ptu.get_numpy(logp)

    if distribution_type == 'uniform':
        latent_altz = np.random.uniform(
            low=-1, high=1, size=(obs_altz.shape[0], latents.shape[1]))
    elif distribution_type == 'prior':
        if prior is None:
            raise AssertionError('prior specified but not passed in')
        obs_t = ptu.from_numpy(obs_altz)
        latent_altz, *_ = prior.get_action(obs_t, deterministic=False)
    else:
        raise NotImplementedError('distribution_type not found')

    # keep track of next obs/delta
    next_obs_altz = np.concatenate(
        [next_obs - obs] * num_prior_samples, axis=0)

    with torch.no_grad():
        if obs_altz.shape[0] <= split_group:
            logp_altz = ptu.get_numpy(discriminator.get_log_prob(
                ptu.from_numpy(obs_altz),
                ptu.from_numpy(latent_altz),
                ptu.from_numpy(next_obs_altz),
            ))
        else:
            # evaluate the discriminator in chunks of split_group rows
            logp_altz = []
            for split_idx in range(obs_altz.shape[0] // split_group):
                start_split = split_idx * split_group
                end_split = (split_idx + 1) * split_group
                logp_altz.append(ptu.get_numpy(discriminator.get_log_prob(
                    ptu.from_numpy(obs_altz[start_split:end_split]),
                    ptu.from_numpy(latent_altz[start_split:end_split]),
                    ptu.from_numpy(next_obs_altz[start_split:end_split]),
                )))
            if obs_altz.shape[0] % split_group:
                start_split = obs_altz.shape[0] % split_group
                logp_altz.append(ptu.get_numpy(discriminator.get_log_prob(
                    ptu.from_numpy(obs_altz[-start_split:]),
                    ptu.from_numpy(latent_altz[-start_split:]),
                    ptu.from_numpy(next_obs_altz[-start_split:]),
                )))
            logp_altz = np.concatenate(logp_altz)
    logp_altz = np.array(np.array_split(logp_altz, num_prior_samples))

    if return_diagnostics:
        diagnostics = dict()
        orig_rep = np.repeat(np.expand_dims(logp, axis=0),
                             axis=0, repeats=num_prior_samples)
        diagnostics['Pct Random Skills > Original'] = \
            (orig_rep < logp_altz).mean()

    # final DADS reward
    intrinsic_reward = np.log(num_prior_samples + 1) - np.log(1 + np.exp(
        np.clip(logp_altz - logp.reshape(1, -1), -50, 50)).sum(axis=0))

    if not return_diagnostics:
        return intrinsic_reward, (logp, logp_altz, logp - intrinsic_reward)
    else:
        return (intrinsic_reward,
                (logp, logp_altz, logp - intrinsic_reward),
                diagnostics)
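# The reward returned above is the contrastive DADS estimate: with
# K = num_prior_samples negative skills z_k drawn from the prior,
#
#   r(s, z, s') = log q(s'|s, z) - log[(q(s'|s, z) + sum_k q(s'|s, z_k)) / (K+1)]
#               = log(K+1) - log[1 + sum_k exp(log q(s'|s, z_k) - log q(s'|s, z))]
#
# which matches the log-sum-exp expression computed for intrinsic_reward;
# the clip to [-50, 50] only guards against overflow in the exponential.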
def train_from_buffer(self, replay_buffer, holdout_pct=0.2,
                      max_grad_steps=1000, epochs_since_last_update=5):
    self._n_train_steps_total += 1
    if self._n_train_steps_total % self.train_call_freq > 0 and \
            self._n_train_steps_total > 1:
        return

    data = replay_buffer.get_transitions()
    x = data[:, :self.obs_dim + self.action_dim]  # inputs s, a
    y = data[:, self.obs_dim + self.action_dim:]  # predict r, d, ns
    y[:, -self.obs_dim:] -= x[:, :self.obs_dim]   # predict delta in the state

    # normalize network inputs
    self.ensemble.fit_input_stats(x)

    # generate holdout set
    inds = np.random.permutation(data.shape[0])
    x, y = x[inds], y[inds]

    n_train = max(int((1 - holdout_pct) * data.shape[0]),
                  data.shape[0] - 8092)
    n_test = data.shape[0] - n_train

    x_train, y_train = x[:n_train], y[:n_train]
    x_test, y_test = x[n_train:], y[n_train:]
    x_test, y_test = ptu.from_numpy(x_test), ptu.from_numpy(y_test)

    # train until holdout set convergence
    num_epochs, num_steps = 0, 0
    num_epochs_since_last_update = 0
    best_holdout_loss = float('inf')
    num_batches = int(np.ceil(n_train / self.batch_size))

    while num_epochs_since_last_update < epochs_since_last_update and \
            num_steps < max_grad_steps:
        # generate idx for each model to bootstrap
        self.ensemble.train()
        for b in range(num_batches):
            b_idxs = np.random.randint(
                n_train, size=(self.ensemble_size * self.batch_size))
            x_batch, y_batch = x_train[b_idxs], y_train[b_idxs]
            x_batch, y_batch = ptu.from_numpy(x_batch), ptu.from_numpy(y_batch)
            x_batch = x_batch.view(self.ensemble_size, self.batch_size, -1)
            y_batch = y_batch.view(self.ensemble_size, self.batch_size, -1)

            loss = self.ensemble.get_loss(x_batch, y_batch)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        num_steps += num_batches

        # stop training based on holdout loss improvement
        self.ensemble.eval()
        with torch.no_grad():
            holdout_losses, holdout_errors = self.ensemble.get_loss(
                x_test, y_test, split_by_model=True, return_l2_error=True)
        holdout_loss = sum(
            sorted(holdout_losses)[:self.num_elites]) / self.num_elites

        if num_epochs == 0 or \
                (best_holdout_loss - holdout_loss) / abs(best_holdout_loss) > 0.01:
            best_holdout_loss = holdout_loss
            num_epochs_since_last_update = 0
        else:
            num_epochs_since_last_update += 1

        num_epochs += 1

    self.ensemble.elites = np.argsort(holdout_losses)

    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        self.eval_statistics['Model Elites Holdout Loss'] = \
            np.mean(ptu.get_numpy(holdout_loss))
        self.eval_statistics['Model Holdout Loss'] = \
            np.mean(ptu.get_numpy(sum(holdout_losses))) / self.ensemble_size
        self.eval_statistics['Model Training Epochs'] = num_epochs
        self.eval_statistics['Model Training Steps'] = num_steps

        for i in range(self.ensemble_size):
            name = 'M%d' % (i + 1)
            self.eval_statistics[name + ' Loss'] = \
                np.mean(ptu.get_numpy(holdout_losses[i]))
            self.eval_statistics[name + ' L2 Error'] = \
                np.mean(ptu.get_numpy(holdout_errors[i]))