def log_diagnostics(self, eval_statistics):
    """Write summary statistics of the latent encoding into ``eval_statistics``.

    Stores the mean absolute value of ``self.z_means[0]`` under
    ``'Z mean eval'`` and the mean of ``self.z_vars[0]`` under
    ``'Z variance eval'``.
    """
    first_mean = ptu.get_numpy(self.z_means[0])
    first_var = ptu.get_numpy(self.z_vars[0])
    eval_statistics['Z mean eval'] = np.mean(np.abs(first_mean))
    eval_statistics['Z variance eval'] = np.mean(first_var)
def train_to_imitate(self, np_batch):
    """Run one behaviour-cloning update of ``self.policy`` on ``np_batch``.

    The batch-averaged objective is
    ``sum_d (a_d - mu_d)^2 / sigma_d^2 + sum_d log sigma_d^2``, i.e. a
    diagonal-Gaussian negative log-likelihood without its constant term and
    factor 1/2. The loss value is recorded under ``'Policy Loss'``.
    """
    torch_batch = np_to_pytorch_batch(np_batch)
    states = torch_batch['observations']
    target_actions = torch_batch['actions']

    # Only the mean (index 1) and std (index 4) of the policy outputs are used.
    _, mean, _, _, std, *_ = self.policy(
        states,
        reparameterize=True,
        return_log_prob=True,
    )
    variance = std.pow(2)

    squared_error_term = ((target_actions - mean).pow(2) / variance).sum(dim=1)
    log_det_term = variance.log().sum(dim=1)
    imitation_loss = (squared_error_term + log_det_term).mean()

    self.policy_imitation_optimizer.zero_grad()
    imitation_loss.backward()
    self.policy_imitation_optimizer.step()

    self.eval_statistics['Policy Loss'] = np.mean(
        ptu.get_numpy(imitation_loss))
def select_actions(self, obs, raw_context):
    """Return the highest-Q perturbed candidate action for a single ``obs``.

    The task latent is inferred by ``self.f`` from ``raw_context`` (a
    sequence of arrays concatenated along axis 0); with no context yet, an
    all-zero latent of size ``self.f.latent_dim`` is used. As in BCQ,
    ``self.candidate_size`` candidate actions are decoded from clipped
    Gaussian noise, perturbed, scored by ``self.Qs``, and the best is
    returned as a numpy array.
    """
    # Repeat the observation once per candidate so all candidates are
    # decoded and scored in a single batched pass (as BCQ does).
    obs = from_numpy(np.tile(obs.reshape(1, -1), (self.candidate_size, 1)))
    # Pure inference: the encoder used to run outside no_grad and built an
    # autograd graph that was never used.
    with torch.no_grad():
        if len(raw_context) == 0:
            # No transitions collected yet: condition on the zero latent.
            inferred_mdp = ptu.zeros((1, self.f.latent_dim))
        else:
            # Add a leading batch axis of 1 to the concatenated context.
            context = from_numpy(np.concatenate(raw_context, axis=0))[None]
            inferred_mdp = self.f(context)
        inferred_mdp = inferred_mdp.repeat(self.candidate_size, 1)
        # VAE latent noise, clipped to [-0.5, 0.5] as in BCQ.
        z = from_numpy(
            np.random.normal(0, 1, size=(obs.size(0), self.vae_latent_dim))
        ).clamp(-0.5, 0.5).to(ptu.device)
        candidate_actions = self.vae_decoder(obs, z, inferred_mdp)
        perturbed_actions = self.perturbation_generator.get_perturbed_actions(
            obs, candidate_actions, inferred_mdp)
        qv = self.Qs(obs, perturbed_actions, inferred_mdp)
        # Index of the best candidate along the candidate axis.
        ind = qv.max(0)[1]
        return ptu.get_numpy(perturbed_actions[ind])
def select_actions(self, obs, inferred_mdp):
    """Return the highest-Q perturbed candidate action for one observation.

    ``inferred_mdp`` is the task latent (a tensor with a leading row axis
    that is tiled ``self.candidate_size`` times). Candidates are decoded by
    the VAE from clipped Gaussian noise, perturbed, scored by ``self.Qs``,
    and the best one is returned as a numpy array.
    """
    # One copy of the observation per candidate (BCQ-style batching).
    tiled_obs = np.tile(obs.reshape(1, -1), (self.candidate_size, 1))
    obs = from_numpy(tiled_obs)
    with torch.no_grad():
        latent = inferred_mdp.repeat(self.candidate_size, 1)
        noise = np.random.normal(0, 1, size=(obs.size(0), self.vae_latent_dim))
        z = from_numpy(noise).clamp(-0.5, 0.5).to(ptu.device)
        candidates = self.vae_decoder(obs, z, latent)
        perturbed = self.perturbation_generator.get_perturbed_actions(
            obs, candidates, latent)
        q_values = self.Qs(obs, perturbed, latent)
        _, best = q_values.max(0)
        return ptu.get_numpy(perturbed[best])
def select_actions(self, obs, raw_context):
    """Select an action for ``obs`` with ``self.policy``, conditioned on a
    task latent inferred from the collected context.

    ``raw_context`` is a sequence of arrays that is concatenated along
    axis 0 and fed, with a leading batch axis, to
    ``self.policy.mlp_encoder``. When it is empty, an all-zero latent of
    size ``encoder_latent_dim`` is used. The latent is handed to
    ``self.policy.select_action`` as a numpy array. The observation is passed
    through unchanged (it is not tiled into candidates).
    """
    # Local import: this block did not otherwise reference torch.
    import torch

    # Inference only: the encoder output never needs gradients.
    with torch.no_grad():
        if len(raw_context) == 0:
            # No transitions collected yet: condition on the zero latent.
            inferred_mdp = ptu.zeros(
                (1, self.policy.mlp_encoder.encoder_latent_dim))
        else:
            context = from_numpy(np.concatenate(raw_context, axis=0))[None]
            inferred_mdp = self.policy.mlp_encoder(context)
    return self.policy.select_action(obs, get_numpy(inferred_mdp))
def check_q_funct_estimate(self, paths):
    """Compare initial-state Q estimates with empirical discounted returns.

    For every path, the Q-value is ``min(qf1, qf2)`` of ``self.trainer`` at
    the path's first observation/action, conditioned on
    ``self.trainer._inferred_mdp`` repeated once per path. The discounted
    return uses ``self.trainer.discount``.

    Returns:
        dict with ``q_values_mean``, ``q_values_std``,
        ``dicount_returns_mean`` and ``dicount_returns_std``. The
        ``dicount`` spelling is kept unchanged for compatibility with
        existing consumers of these keys.
    """
    s0 = ptu.from_numpy(np.stack([path["observations"][0] for path in paths]))
    a0 = ptu.from_numpy(np.stack([path["actions"][0] for path in paths]))

    # Evaluation only: avoid building an autograd graph for the Q-functions.
    with torch.no_grad():
        # NOTE(review): assumes _inferred_mdp has a leading dim of 1, so this
        # yields one latent row per path — confirm.
        inferred_mdps = torch.repeat_interleave(
            self.trainer._inferred_mdp, s0.shape[0], dim=0)
        q_values = torch.min(
            self.trainer.qf1(s0, a0, inferred_mdps),
            self.trainer.qf2(s0, a0, inferred_mdps),
        )
    q_values = ptu.get_numpy(q_values)

    discounted_returns = []
    for path in paths:
        discount_cof = self.trainer.discount**np.arange(len(path["rewards"]))
        discounted_returns.append(
            np.sum(path["rewards"].flatten() * discount_cof))

    return dict(
        q_values_mean=np.mean(q_values),
        q_values_std=np.std(q_values),
        dicount_returns_mean=np.mean(discounted_returns),
        dicount_returns_std=np.std(discounted_returns),
    )
def train(self, batch, batch_idxes, epoch):
    """Train the reward-prediction ensemble and optionally log eval stats.

    ``batch['obs']``, ``batch['actions']`` and ``batch['rewards']`` are
    split into ``batch_idxes.shape[0]`` equal consecutive chunks, one per
    task. Only the first chunk (task 0) supervises the ensemble.

    When eval statistics are requested and ``epoch > -1``, the ensemble's
    predictions on every chunk are filtered to transitions whose ensemble
    standard deviation is below ``self.std_threshold``. Each kept transition
    is then compared with the reward obtained by setting ``self.env`` to the
    transition's state and stepping it with the transition's action.

    Raises:
        ValueError: if eval statistics are computed for a ``self.domain``
            without a known observation -> (qpos, qvel) mapping.
    """
    rewards = batch['rewards']
    obs = batch['obs']
    actions = batch['actions']

    num_tasks = batch_idxes.shape[0]
    # Number of transitions in each per-task chunk.
    in_mdp_batch_size = obs.shape[0] // num_tasks

    # --- Reward-prediction loss (no next-observation loss is trained) ---
    pred_rewards = [net(obs, actions) for net in self.network_ensemble]
    # To train a single reward network instead of the ensemble (Fig. 7 of
    # the paper), predict every entry with self.network_ensemble[0] and
    # back-propagate only reward_loss_task_0[0] below.
    reward_loss_task_0 = [
        F.mse_loss(pred_r[:in_mdp_batch_size], rewards[:in_mdp_batch_size])
        for pred_r in pred_rewards
    ]
    gt.stamp('get_reward_loss', unique=False)

    self.network_ensemble_optimizer.zero_grad()
    for loss in reward_loss_task_0:
        loss.backward()
    self.network_ensemble_optimizer.step()
    gt.stamp('update', unique=False)

    # --- Statistics for eval ---
    if self._need_to_update_eval_statistics:
        if epoch > -1:
            # Statistics only: detach so no further autograd graph is built.
            detached_preds = [pred_r.detach() for pred_r in pred_rewards]
            reward_loss_other_tasks = []
            reward_loss_prop_other_tasks = []
            num_selected_trans_other_tasks = []
            for i in range(num_tasks):
                chunk = slice(in_mdp_batch_size * i, in_mdp_batch_size * (i + 1))
                # One column per ensemble member (assumes each net outputs
                # shape (N, 1) — TODO confirm).
                pred_r_other_task = torch.cat(
                    [pred_r[chunk] for pred_r in detached_preds], dim=1)
                # Keep only transitions on which the ensemble agrees.
                pred_std = torch.std(pred_r_other_task, dim=1)
                mask = ptu.get_numpy(pred_std < self.std_threshold)
                num_selected_trans_other_tasks.append(np.sum(mask))
                mask = mask.astype(bool)
                pred_r_record = ptu.get_numpy(pred_r_other_task)[mask]
                o_other_task = ptu.get_numpy(obs[chunk])[mask]
                a_other_task = ptu.get_numpy(actions[chunk])[mask]

                mse_loss = []
                mse_loss_prop = []
                for pred_r, o, a in zip(pred_r_record, o_other_task,
                                        a_other_task):
                    # Recover the simulator state from the observation.
                    if self.domain == 'ant-dir':
                        qpos = np.concatenate([np.zeros(2), o[:13]])
                        qvel = o[13:27]
                    elif self.domain == 'ant-goal':
                        qpos = o[:15]
                        qvel = o[15:29]
                    elif self.domain == 'humanoid-ndone-goal':
                        qpos = o[:24]
                        qvel = o[24:47]
                    elif self.domain == 'humanoid-openai-dir':
                        qpos = np.concatenate([np.zeros(2), o[:22]])
                        qvel = o[22:45]
                    elif self.domain == 'halfcheetah-vel':
                        qpos = np.concatenate([np.zeros(1), o[:8]])
                        qvel = o[8:17]
                    elif 'maze' in self.domain:
                        qpos = o[:2]
                        qvel = o[2:4]
                    else:
                        raise ValueError(
                            f"Unsupported domain for reward evaluation: "
                            f"{self.domain!r}")
                    self.env.set_state(qpos, qvel)
                    _, r, _, _ = self.env.step(a)
                    mse_loss.append((pred_r - r)**2)
                    # Relative error |pred - r| / |r|; produces inf/nan
                    # when r == 0.
                    mse_loss_prop.append(np.sqrt((pred_r - r)**2 / r**2))
                if len(mse_loss) > 0:
                    reward_loss_other_tasks.append(
                        np.mean(np.stack(mse_loss), axis=0).tolist())
                    reward_loss_prop_other_tasks.append(
                        np.mean(np.stack(mse_loss_prop), axis=0).tolist())

            # np.mean(..., axis=1) on an empty list raises, so skip these
            # keys when no transition passed the std filter in any task.
            if reward_loss_other_tasks:
                self.eval_statistics[
                    'average_task_reward_loss_other_tasks_mean'] = np.mean(
                        reward_loss_other_tasks, axis=1)
                # Spread of the per-member mean errors across the ensemble.
                self.eval_statistics[
                    'average_task_reward_loss_other_tasks_std'] = np.std(
                        reward_loss_other_tasks, axis=1)
                self.eval_statistics[
                    'average_task_reward_loss_prop_other_task'] = np.mean(
                        reward_loss_prop_other_tasks, axis=1)
            self.eval_statistics[
                'num_selected_trans_other_tasks'] = num_selected_trans_other_tasks
        self._need_to_update_eval_statistics = False
        # Eval should set the flag again; this way these statistics are only
        # computed for one batch.
        self.eval_statistics['reward_loss_task_0'] = np.mean(
            ptu.get_numpy(torch.mean(torch.stack(reward_loss_task_0))))
def train(self, batch, batch_idxes):
    """Distil the per-task BCQ policies into the context-conditioned networks.

    Context transitions are relabelled for every training task with the
    mean reward of ``self.ensemble_predictor``. Transitions whose ensemble
    standard deviation is below ``self.std_threshold`` are candidates, and
    ``num_trans_context`` of them are sampled per task; when
    ``self.is_combine`` is set, the task's own ground-truth transitions are
    added as well. The relabelled contexts drive ``self.context_encoder``.
    Regression targets for the Q-network, VAE decoder and perturbation
    generator come from the stacked BCQ weights in
    ``self.combined_bcq_policies``.

    Raises:
        ValueError: if a task has fewer low-uncertainty candidate context
            transitions than it needs to sample.
    """
    # --- Unpack data from the batch ---
    obs = batch['obs']
    actions = batch['actions']
    contexts = batch['contexts']
    num_tasks = batch_idxes.shape[0]
    gt.stamp('unpack_data_from_the_batch', unique=False)

    obs_dim = obs.shape[1]
    action_dim = actions.shape[1]
    in_mdp_batch_size = obs.shape[0] // num_tasks
    num_trans_context = contexts.shape[0] // num_tasks

    # --- Relabel the context batches for each training task ---
    with torch.no_grad():
        contexts_obs_actions = contexts[:, :obs_dim + action_dim]
        manual_batched_rewards = self.ensemble_predictor.forward_mul_device(
            contexts_obs_actions)
        # (num_tasks, num_network_ensemble, num_context_transitions)
        relabeled_rewards = manual_batched_rewards.reshape(
            num_tasks, self.num_network_ensemble, contexts.shape[0])
    gt.stamp('relabel_ensemble', unique=False)

    # Both are (num_tasks, num_context_transitions). They are deliberately
    # not squeezed: squeezing dropped the task axis when num_tasks == 1 and
    # broke the [i, ...] indexing below.
    relabeled_rewards_mean = torch.mean(relabeled_rewards, dim=1)
    relabeled_rewards_std = torch.std(relabeled_rewards, dim=1)

    # Task i's own context transitions keep their ground-truth rewards.
    for i in range(num_tasks):
        own = slice(i * num_trans_context, (i + 1) * num_trans_context)
        relabeled_rewards_mean[i, own] = contexts[own, -1]
        if self.is_combine:
            # Push them above the threshold so the mask below excludes them
            # from sampling; they are added back explicitly afterwards.
            relabeled_rewards_std[i, own] = self.std_threshold + 1.0
        else:
            relabeled_rewards_std[i, own] = 0.0

    mask = relabeled_rewards_std < self.std_threshold
    num_context_candidate_each_task = torch.sum(mask, dim=1)
    mask_list = []
    for i in range(num_tasks):
        assert mask[i].dim() == 1
        mask_nonzero = torch.nonzero(mask[i]).flatten()
        mask_i = ptu.zeros_like(mask[i], dtype=torch.uint8)
        assert num_context_candidate_each_task[i].item(
        ) == mask_nonzero.shape[0]
        if mask_nonzero.shape[0] < num_trans_context:
            raise ValueError(
                f"Task {i} has only {mask_nonzero.shape[0]} context "
                f"transitions with ensemble std below {self.std_threshold}, "
                f"but {num_trans_context} are required")
        np_ind = np.random.choice(mask_nonzero.shape[0],
                                  num_trans_context,
                                  replace=False)
        ind = mask_nonzero[np_ind]
        mask_i[ind] = 1
        if self.is_combine:
            # Combine the additional relabeled context transitions with the
            # original context transitions that have ground-truth rewards.
            mask_i[i * num_trans_context:(i + 1) * num_trans_context] = 1
            assert torch.sum(mask_i).item() == 2 * num_trans_context
        else:
            assert torch.sum(mask_i).item() == num_trans_context
        mask_list.append(mask_i)
    mask = torch.cat(mask_list)
    mask = mask.type(torch.uint8)

    # Row k of the flattened (task, transition) grid is transition
    # k % num_context_transitions relabelled for task k // num_context_transitions.
    repeated_contexts = contexts.repeat(num_tasks, 1)
    context_without_rewards = repeated_contexts[:, :-1]
    flat_rewards = relabeled_rewards_mean.reshape(-1, 1)
    assert context_without_rewards.shape[0] == flat_rewards.shape[0]
    context_without_rewards = context_without_rewards[mask]
    context_rewards = flat_rewards[mask]
    fast_contexts = torch.cat((context_without_rewards, context_rewards),
                              dim=1)
    fast_contexts = fast_contexts.reshape(num_tasks, -1, contexts.shape[-1])
    gt.stamp('relabel_context_transitions', unique=False)

    # --- Infer the context variables ---
    inferred_mdps = self.context_encoder(fast_contexts)
    inferred_mdps = torch.repeat_interleave(inferred_mdps,
                                            in_mdp_batch_size,
                                            dim=0)
    gt.stamp('infer_mdps', unique=False)

    with torch.no_grad():
        # Sample z for each state.
        z = self.bcq_polices[0].vae.sample_z(obs).to(ptu.device)

        # Each item of critic_weights is a list with one entry per device.
        # Each entry critic_weights[i] is a list with one entry per layer.
        # critic_weights[i][j] has shape
        # (num_tasks // device_count, layer_in, layer_out).
        # The other weight and bias lists are laid out the same way.
        (critic_weights, critic_biases, vae_weights, vae_biases,
         actor_weights, actor_biases) = self.combined_bcq_policies

        # Critic targets.
        obs_reshaped = obs.reshape(len(batch_idxes), in_mdp_batch_size, -1)
        acs_reshaped = actions.reshape(len(batch_idxes), in_mdp_batch_size, -1)
        obs_acs_reshaped = torch.cat((obs_reshaped, acs_reshaped), dim=-1)
        target_q = batch_bcq(obs_acs_reshaped, critic_weights, critic_biases)
        target_q = target_q.reshape(-1)

        # VAE decoder targets.
        z_reshaped = z.reshape(len(batch_idxes), in_mdp_batch_size, -1)
        obs_z_reshaped = torch.cat((obs_reshaped, z_reshaped), dim=-1)
        tc = batch_bcq(obs_z_reshaped, vae_weights, vae_biases)
        tc = self.bcq_polices[0].vae.max_action * torch.tanh(tc)
        target_candidates = tc.reshape(-1, tc.shape[-1])

        # Perturbation targets.
        tc_reshaped = target_candidates.reshape(len(batch_idxes),
                                                in_mdp_batch_size, -1)
        obs_tc_reshaped = torch.cat((obs_reshaped, tc_reshaped), dim=-1)
        tp = batch_bcq(obs_tc_reshaped, actor_weights, actor_biases)
        tp = self.bcq_polices[0].actor.max_action * torch.tanh(tp)
        target_perturbations = tp.reshape(-1, tp.shape[-1])
    gt.stamp('get_the_targets', unique=False)

    # --- KL constraint on the context posterior ---
    self.context_encoder_optimizer.zero_grad()
    kl_div = self.context_encoder.compute_kl_div()
    kl_loss_each_task = self.kl_lambda * torch.sum(kl_div, dim=1)
    kl_loss = torch.sum(kl_loss_each_task)
    gt.stamp('get_kl_loss', unique=False)

    # --- Q-function loss (averaged per task, then across tasks) ---
    self.Qs_optimizer.zero_grad()
    pred_q = self.Qs(obs, actions, inferred_mdps)
    pred_q = torch.squeeze(pred_q)
    qf_loss_each_task = (pred_q - target_q)**2
    qf_loss_each_task = qf_loss_each_task.reshape(num_tasks, -1)
    qf_loss_each_task = torch.mean(qf_loss_each_task, dim=1)
    qf_loss = torch.mean(qf_loss_each_task)
    gt.stamp('get_qf_loss', unique=False)

    (kl_loss + qf_loss).backward()
    gt.stamp('get_kl_qf_gradient', unique=False)
    self.Qs_optimizer.step()
    self.context_encoder_optimizer.step()
    gt.stamp('update_Qs_encoder', unique=False)

    # --- Candidate-action and perturbation losses (encoder detached) ---
    self.vae_decoder_optimizer.zero_grad()
    self.perturbation_generator_optimizer.zero_grad()
    pred_candidates = self.vae_decoder(obs, z, inferred_mdps.detach())
    pred_perturbations = self.perturbation_generator(
        obs, target_candidates, inferred_mdps.detach())

    # Average over the action dimension, then within each task, then across tasks.
    candidate_loss_each_task = torch.mean(
        (pred_candidates - target_candidates)**2, dim=1)
    candidate_loss_each_task = torch.mean(
        candidate_loss_each_task.reshape(num_tasks, in_mdp_batch_size),
        dim=1)
    candidate_loss = torch.mean(candidate_loss_each_task)

    perturbation_loss_each_task = torch.mean(
        (pred_perturbations - target_perturbations)**2, dim=1)
    perturbation_loss_each_task = torch.mean(
        perturbation_loss_each_task.reshape(num_tasks, in_mdp_batch_size),
        dim=1)
    perturbation_loss = torch.mean(perturbation_loss_each_task)
    gt.stamp('get_candidate_and_perturbation_loss', unique=False)

    (candidate_loss + perturbation_loss).backward()
    gt.stamp('get_candidate_and_perturbation_gradient', unique=False)
    self.vae_decoder_optimizer.step()
    self.perturbation_generator_optimizer.step()
    gt.stamp('update_vae_perturbation', unique=False)

    # --- Statistics for eval ---
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        # Eval should set the flag again; this way these statistics are only
        # computed for one batch.
        self.eval_statistics['qf_loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics['qf_loss_each_task'] = ptu.get_numpy(
            qf_loss_each_task)
        self.eval_statistics['kl_loss'] = np.mean(ptu.get_numpy(kl_loss))
        self.eval_statistics['kl_loss_each_task'] = ptu.get_numpy(
            kl_loss_each_task)
        self.eval_statistics['candidate_loss'] = np.mean(
            ptu.get_numpy(candidate_loss))
        self.eval_statistics['candidate_loss_each_task'] = ptu.get_numpy(
            candidate_loss_each_task)
        self.eval_statistics['perturbation_loss'] = np.mean(
            ptu.get_numpy(perturbation_loss))
        self.eval_statistics[
            'perturbation_loss_each_task'] = ptu.get_numpy(
                perturbation_loss_each_task)
        self.eval_statistics[
            'num_context_candidate_each_task'] = num_context_candidate_each_task
def train_from_torch_qf2_policy(self, batch):
    """One SAC-style update of ``self.qf2`` and ``self.policy`` on ``batch``.

    qf1 contributes to the policy objective (via ``min(qf1, qf2)``) but is
    not updated here. Every ``self.target_update_period`` steps, target_qf2
    and target_policy are soft-updated, and the policy is then overwritten
    with target_policy's weights.
    """
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    # --- QF loss ---
    if self.use_automatic_entropy_tuning:
        alpha = self.log_alpha.exp()
    else:
        alpha = 1
    q2_pred = self.qf2(obs, actions)
    # Make sure the policy accounts for squashing functions like tanh
    # correctly!
    new_next_actions, _, _, new_log_pi, *_ = self.policy(
        next_obs,
        reparameterize=True,
        return_log_prob=True,
    )
    target_q2_values = torch.min(
        self.target_qf1(next_obs, new_next_actions),
        self.target_qf2(next_obs, new_next_actions),
    ) - alpha * new_log_pi
    q2_target = self.reward_scale * rewards + \
        (1. - terminals) * self.discount * target_q2_values
    qf2_loss = self.qf_criterion(q2_pred, q2_target.detach())

    # --- Policy and alpha loss ---
    new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
        obs,
        reparameterize=True,
        return_log_prob=True,
    )
    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha *
                       (log_pi + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        alpha = 1
    q_new_actions = torch.min(
        self.qf1(obs, new_obs_actions),
        self.qf2(obs, new_obs_actions),
    )
    policy_loss = (alpha * log_pi - q_new_actions).mean()

    # --- Update networks ---
    # The policy is updated first. policy_loss's graph runs through
    # self.qf2, so stepping qf2's optimizer before policy_loss.backward()
    # would modify, in place, weights that the backward pass still needs.
    # PyTorch >= 1.5 raises on this; older versions give stale gradients.
    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    # zero_grad here also discards the gradients that policy_loss.backward()
    # accumulated into qf2's parameters.
    self.qf2_optimizer.zero_grad()
    qf2_loss.backward()
    self.qf2_optimizer.step()

    # --- Soft updates ---
    # NOTE(review): the load_state_dict placement (inside this periodic
    # branch) was reconstructed from flattened source — confirm.
    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                self.soft_target_tau_qf)
        ptu.soft_update_from_to(self.policy, self.target_policy,
                                self.soft_target_tau_policy)
        self.policy.load_state_dict(self.target_policy.state_dict())

    # --- Statistics for eval ---
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        # Eval should set the flag again; this way these statistics are only
        # computed for one batch. 'Policy Loss' reports the loss that was
        # actually optimized (including alpha).
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Q2 Targets',
                ptu.get_numpy(q2_target),
            ))
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
        if self.use_automatic_entropy_tuning:
            self.eval_statistics['Alpha'] = alpha.item()
            self.eval_statistics['Alpha Loss'] = alpha_loss.item()
    self._n_train_steps_total += 1
def train_from_torch_qf1(self, batch):
    """Run one soft Bellman update of ``self.qf1`` only.

    The target bootstraps from ``min(target_qf1, target_qf2)`` at
    policy-sampled next actions minus ``alpha * log_pi``.
    ``target_qf1`` is soft-updated every ``self.target_update_period`` steps.

    NOTE(review): unlike sibling update methods, this does not reset
    ``_need_to_update_eval_statistics``; presumably another update run in
    the same step does — confirm.
    """
    rewards, terminals = batch['rewards'], batch['terminals']
    obs, actions = batch['observations'], batch['actions']
    next_obs = batch['next_observations']

    alpha = self.log_alpha.exp() if self.use_automatic_entropy_tuning else 1

    # Bellman target with clipped double-Q and an entropy bonus.
    q1_pred = self.qf1(obs, actions)
    next_actions, _, _, next_log_pi, *_ = self.policy(
        next_obs, reparameterize=True, return_log_prob=True)
    clipped_next_q = torch.min(
        self.target_qf1(next_obs, next_actions),
        self.target_qf2(next_obs, next_actions),
    )
    soft_next_value = clipped_next_q - alpha * next_log_pi
    q1_target = (self.reward_scale * rewards
                 + (1. - terminals) * self.discount * soft_next_value)
    qf1_loss = self.qf_criterion(q1_pred, q1_target.detach())

    self.qf1_optimizer.zero_grad()
    qf1_loss.backward()
    self.qf1_optimizer.step()

    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                self.soft_target_tau_qf)

    if self._need_to_update_eval_statistics:
        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        for label, values in (('Q1 Predictions', q1_pred),
                              ('Q1 Targets', q1_target)):
            self.eval_statistics.update(
                create_stats_ordered_dict(label, ptu.get_numpy(values)))

    self._n_train_steps_total += 1
def np_ify(tensor_or_other):
    """Convert a torch tensor with ``ptu.get_numpy``; return anything else unchanged.

    ``torch.autograd.Variable`` is deprecated. Since PyTorch 0.4, an
    isinstance check against it is equivalent to checking ``torch.Tensor``,
    so the modern check is used here.
    """
    if isinstance(tensor_or_other, torch.Tensor):
        return ptu.get_numpy(tensor_or_other)
    return tensor_or_other
def train(self, train_data, discount=0.99, tau=0.005):
    """Run one BCQ update: critic, VAE, perturbation actor and targets.

    Args:
        train_data: tuple ``(state, next_state, action, reward, done,
            context)`` of array-likes accepted by ``torch.FloatTensor``.
        discount: Bellman discount factor.
        tau: Polyak coefficient for the critic and actor target networks.
    """
    state_np, next_state_np, action, reward, done, context = train_data
    state = torch.FloatTensor(state_np).to(device)
    action = torch.FloatTensor(action).to(device)
    next_state = torch.FloatTensor(next_state_np).to(device)
    reward = torch.FloatTensor(reward).to(device)
    # Holds the "not done" mask (1 - done) that stops bootstrapping.
    done = torch.FloatTensor(1 - done).to(device)
    # Converted like the other fields but not used by this update.
    context = torch.FloatTensor(context).to(device)
    gt.stamp('unpack_data', unique=False)

    # --- Critic training ---
    self.critic_optimizer.zero_grad()
    with torch.no_grad():
        # Score 10 decoded-and-perturbed candidate actions per next state.
        state_rep = next_state.repeat_interleave(10, dim=0)
        gt.stamp('check0', unique=False)
        target_Q1, target_Q2 = self.critic_target(
            state_rep,
            self.actor_target(state_rep, self.vae.decode(state_rep)))
        # Soft clipped double Q-learning: weighted mix of min and max.
        target_Q = self.target_q_coef * torch.min(target_Q1, target_Q2) + (
            1 - self.target_q_coef) * torch.max(target_Q1, target_Q2)
        # Keep the best candidate for each next state.
        target_Q = target_Q.view(state.shape[0], -1).max(1)[0].view(-1, 1)
        target_Q = reward + done * discount * target_Q

    current_Q1, current_Q2 = self.critic(state, action)
    critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
        current_Q2, target_Q)
    gt.stamp('get_critic_loss', unique=False)
    critic_loss.backward()
    gt.stamp('get_critic_gradient', unique=False)

    # --- Variational auto-encoder training ---
    recon, mean, std = self.vae(state, action)
    recon_loss = F.mse_loss(recon, action)
    KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) -
                      std.pow(2)).mean()
    vae_loss = recon_loss + 0.5 * KL_loss
    gt.stamp('get_vae_loss', unique=False)
    self.vae_optimizer.zero_grad()
    vae_loss.backward()
    self.vae_optimizer.step()
    gt.stamp('update_vae', unique=False)

    self.critic_optimizer.step()
    gt.stamp('update_critic', unique=False)

    # --- Perturbation model / actor training (DPG) ---
    sampled_actions = self.vae.decode(state)
    perturbed_actions = self.actor(state, sampled_actions)
    self.actor_optimizer.zero_grad()
    actor_loss = -self.critic.q1(state, perturbed_actions).mean()
    gt.stamp('get_actor_loss', unique=False)
    # Without this backward pass the step below was a no-op. Gradients that
    # also land in the critic/VAE are cleared by their zero_grad calls
    # before their next backward.
    actor_loss.backward()
    self.actor_optimizer.step()
    gt.stamp('update_actor', unique=False)

    # --- Polyak-average the target networks ---
    for param, target_param in zip(self.critic.parameters(),
                                   self.critic_target.parameters()):
        target_param.data.copy_(tau * param.data +
                                (1 - tau) * target_param.data)
    for param, target_param in zip(self.actor.parameters(),
                                   self.actor_target.parameters()):
        target_param.data.copy_(tau * param.data +
                                (1 - tau) * target_param.data)
    gt.stamp('update_target_network', unique=False)

    # --- Statistics for eval ---
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        # Eval should set the flag again; this way these statistics are only
        # computed for one batch.
        self.eval_statistics['actor_loss'] = np.mean(get_numpy(actor_loss))
        self.eval_statistics['critic_loss'] = np.mean(
            get_numpy(critic_loss))
        self.eval_statistics['vae_loss'] = np.mean(get_numpy(vae_loss))
def train(self, batch, batch_idxes):
    """Train the context encoder, Q-network, VAE decoder and perturbation
    generator by distilling the per-task BCQ policies.

    ``batch['contexts']`` is a list with one tensor per task; its first
    dimension (``num_candidate_context``) is the number of context sets per
    task. Each set presumably yields one posterior from
    ``infer_posterior_with_mean_var``.

    The encoder is trained with three terms:
    - a triplet loss on pairwise KL divergences between posteriors;
    - a KL penalty to the unit Gaussian prior;
    - the Q-regression loss.
    """
    obs = batch['obs']
    actions = batch['actions']
    contexts = batch['contexts']
    num_candidate_context = contexts[0].shape[0]
    meta_batch_size = batch_idxes.shape[0]
    num_posterior = meta_batch_size * num_candidate_context
    contexts = torch.cat(contexts, dim=0)

    in_mdp_batch_size = obs.shape[0] // meta_batch_size

    # Sample z for each state.
    z = self.bcq_polices[0].vae.sample_z(obs).to(ptu.device)

    # --- Regression targets from each task's own BCQ policy ---
    # Computed under no_grad instead of building a graph and detaching.
    target_q = []
    target_candidates = []
    target_perturbations = []
    with torch.no_grad():
        for i, batch_idx in enumerate(batch_idxes):
            rows = slice(i * in_mdp_batch_size, (i + 1) * in_mdp_batch_size)
            policy = self.bcq_polices[batch_idx]
            target_q.append(policy.critic.q1(obs[rows], actions[rows]))
            tc = policy.vae.decode(obs[rows], z[rows])
            target_candidates.append(tc)
            target_perturbations.append(policy.get_perturbation(obs[rows], tc))
    target_q = torch.cat(target_q, dim=0).squeeze()
    target_candidates = torch.cat(target_candidates, dim=0)
    target_perturbations = torch.cat(target_perturbations, dim=0)
    gt.stamp('get_the_targets', unique=False)

    # --- Triplet loss over pairwise posterior KL divergences ---
    self.context_encoder_optimizer.zero_grad()
    # z_means, z_vars: (num_posterior, latent_dim)
    z_means, z_vars = self.context_encoder.infer_posterior_with_mean_var(
        contexts)

    def rotate_by_task(t):
        # For each task i, rotate the rows so task i's posteriors come
        # first, then tile the rotated tensor num_candidate_context times.
        # Result: (num_posterior * num_posterior, latent_dim).
        return torch.cat([
            torch.cat([
                t[i * num_candidate_context:],
                t[:i * num_candidate_context]
            ], dim=0).repeat(num_candidate_context, 1)
            for i in range(meta_batch_size)
        ], dim=0)

    # Row r = a * num_posterior + b pairs posterior a ("1", interleaved)
    # with posterior (b + (a // num_candidate_context) * num_candidate_context)
    # % num_posterior ("2", rotated).
    z_means_interleave = torch.repeat_interleave(z_means, num_posterior, dim=0)
    z_vars_interleave = torch.repeat_interleave(z_vars, num_posterior, dim=0)
    z_means_repeat = rotate_by_task(z_means)
    z_vars_repeat = rotate_by_task(z_vars)
    gt.stamp('get_repeated_mean_var', unique=False)

    # 2 * KL(N1 || N2) for diagonal Gaussians:
    # log(det S2 / det S1) - d + tr(S2^-1 S1) + (m2 - m1)^T S2^-1 (m2 - m1).
    # The log-determinant ratio is a sum of log-ratios rather than
    # log(prod(ratios)), which over/underflows for larger latent dims.
    kl_divergence = torch.sum(torch.log(z_vars_repeat / z_vars_interleave),
                              dim=1)
    kl_divergence = kl_divergence - z_means.shape[-1]
    kl_divergence = kl_divergence + torch.sum(
        z_vars_interleave / z_vars_repeat, dim=1)
    kl_divergence = kl_divergence + torch.sum(
        (z_means_repeat - z_means_interleave)**2 / z_vars_repeat, dim=1)
    # (num_posterior, num_posterior): in row a, the first
    # num_candidate_context columns are posteriors of a's own task, the rest
    # belong to other tasks.
    kl_divergence = kl_divergence.reshape(num_posterior, num_posterior) / 2

    within_task_dist = torch.max(kl_divergence[:, :num_candidate_context],
                                 dim=1)[0]
    across_task_dist = torch.min(kl_divergence[:, num_candidate_context:],
                                 dim=1)[0]
    unscaled_triplet_loss = torch.sum(
        F.relu(within_task_dist - across_task_dist + self.triplet_margin))
    gt.stamp('get_triplet_loss', unique=False)

    # --- Infer the context variables: one random posterior per task ---
    index = np.random.choice(
        num_candidate_context,
        meta_batch_size) + num_candidate_context * np.arange(meta_batch_size)
    # mean, var: (meta_batch_size, latent_dim)
    mean = z_means[index]
    var = z_vars[index]
    inferred_mdps = self.context_encoder.sample_z_from_mean_var(mean, var)
    inferred_mdps = torch.repeat_interleave(inferred_mdps,
                                            in_mdp_batch_size,
                                            dim=0)
    gt.stamp('infer_mdps', unique=False)

    # --- KL penalty to the unit-Gaussian prior ---
    prior_mean = ptu.zeros(mean.shape)
    prior_var = ptu.ones(var.shape)
    kl_loss = self.kl_lambda * self.context_encoder.compute_kl_div_between_posterior(
        mean, var, prior_mean, prior_var)
    gt.stamp('get_kl_loss', unique=False)

    # --- Q-function loss ---
    self.Qs_optimizer.zero_grad()
    pred_q = self.Qs(obs, actions, inferred_mdps)
    pred_q = torch.squeeze(pred_q)
    qf_loss = F.mse_loss(pred_q, target_q)
    gt.stamp('get_qf_loss', unique=False)

    (qf_loss + unscaled_triplet_loss + kl_loss).backward()
    gt.stamp('get_qf_encoder_gradient', unique=False)
    self.Qs_optimizer.step()
    self.context_encoder_optimizer.step()

    # --- Candidate-action and perturbation losses (encoder detached) ---
    self.vae_decoder_optimizer.zero_grad()
    self.perturbation_generator_optimizer.zero_grad()
    pred_candidates = self.vae_decoder(obs, z, inferred_mdps.detach())
    pred_perturbations = self.perturbation_generator(
        obs, target_candidates, inferred_mdps.detach())
    candidate_loss = F.mse_loss(pred_candidates, target_candidates)
    perturbation_loss = F.mse_loss(pred_perturbations, target_perturbations)
    gt.stamp('get_candidate_and_perturbation_loss', unique=False)

    candidate_loss.backward()
    perturbation_loss.backward()
    gt.stamp('get_candidate_and_perturbation_gradient', unique=False)
    self.vae_decoder_optimizer.step()
    self.perturbation_generator_optimizer.step()

    # --- Statistics for eval ---
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        # Eval should set the flag again; this way these statistics are only
        # computed for one batch.
        self.eval_statistics['qf_loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics['unscaled_triplet_loss'] = np.mean(
            ptu.get_numpy(unscaled_triplet_loss))
        self.eval_statistics['kl_loss'] = np.mean(ptu.get_numpy(kl_loss))
        self.eval_statistics['candidate_loss'] = np.mean(
            ptu.get_numpy(candidate_loss))
        self.eval_statistics['perturbation_loss'] = np.mean(
            ptu.get_numpy(perturbation_loss))
def np_ify(tensor_or_other):
    """Pass ``Variable``/tensor inputs through ``ptu.get_numpy``; return
    every other value unchanged."""
    is_tensor = isinstance(tensor_or_other, Variable)
    return ptu.get_numpy(tensor_or_other) if is_tensor else tensor_or_other
def select_actions(self, obs, inferred_mdp):
    """Delegate to ``self.policy.select_action``, passing the task latent
    ``inferred_mdp`` converted with ``get_numpy``."""
    latent_np = get_numpy(inferred_mdp)
    return self.policy.select_action(obs, latent_np)
def train(self, batch, batch_idxes):
    """Run one training step with ensemble-relabelled contexts and a triplet loss.

    The step:
      1. relabels every context transition in the batch under every training
         task with the reward / transition ensembles, and builds per-task
         "fast" contexts from transitions whose ensemble std is below the
         configured thresholds;
      2. trains the context encoder with a triplet loss: for each task pair
         (i, j), task j's transitions relabelled under task i (anchor) should
         be closer (KL) to task i's own transitions (positive) than to task
         j's un-relabelled transitions (negative);
      3. trains the Q-network and context encoder on the KL loss plus a
         regression onto the per-task BCQ critics;
      4. trains the VAE decoder and perturbation generator to match the
         per-task BCQ VAE / actor outputs.

    Args:
        batch: dict with 'obs', 'actions' and 'contexts' tensors. Rows are
            grouped contiguously by task (task ``i`` owns rows
            ``i * n : (i + 1) * n``). Each context row is laid out as
            ``[obs, action, next_obs, reward]``.
        batch_idxes: task indices of the batch; only its length (the number
            of tasks) is used here.

    Raises:
        AssertionError: if the triplet loss is NaN, or the candidate sampling
            invariants do not hold.
        SystemExit: (non-zero status) if any network parameter becomes NaN
            after the update.
    """
    # ---- Unpack data from the batch ----
    obs = batch['obs']
    actions = batch['actions']
    contexts = batch['contexts']
    num_tasks = batch_idxes.shape[0]
    gt.stamp('unpack_data_from_the_batch', unique=False)

    obs_dim = obs.shape[1]
    action_dim = actions.shape[1]
    in_mdp_batch_size = obs.shape[0] // batch_idxes.shape[0]
    num_trans_context = contexts.shape[0] // batch_idxes.shape[0]

    # ---- Relabel the context batches for each training task ----
    with torch.no_grad():
        contexts_obs_actions = contexts[:, :obs_dim + action_dim]
        manual_batched_rewards = self.reward_ensemble_predictor.forward_mul_device(
            contexts_obs_actions)
        relabeled_rewards = manual_batched_rewards.reshape(
            num_tasks, self.num_network_ensemble, contexts.shape[0])
        gt.stamp('reward_ensemble_forward', unique=False)

        manual_batched_next_obs = self.transition_ensemble_predictor.forward_mul_device(
            contexts_obs_actions)
        relabeled_next_obs = manual_batched_next_obs.reshape(
            num_tasks, self.num_network_ensemble, contexts.shape[0], obs_dim)
        gt.stamp('transition_ensemble_forward', unique=False)

        # Mean / std across the ensemble axis (dim=1).
        relabeled_rewards_mean = torch.mean(relabeled_rewards, dim=1).squeeze()
        relabeled_rewards_std = torch.std(relabeled_rewards, dim=1).squeeze()
        relabeled_next_obs_mean = torch.mean(relabeled_next_obs, dim=1)
        relabeled_next_obs_std = torch.std(relabeled_next_obs, dim=1)
        relabeled_next_obs_std = torch.mean(relabeled_next_obs_std, dim=-1)

        # Replace the predictions with ground truth for the transitions that
        # actually come from task i.
        for i in range(num_tasks):
            own = slice(i * num_trans_context, (i + 1) * num_trans_context)
            relabeled_rewards_mean[i, own] = contexts[own, -1]
            relabeled_next_obs_mean[i, own, :] = contexts[own, obs_dim + action_dim:-1]
            if self.is_combine:
                # Push the std above the threshold so these ground-truth
                # transitions are filtered out of the candidate mask; they are
                # added back unconditionally below.
                relabeled_rewards_std[i, own] = self.reward_std_threshold + 1.0
                relabeled_next_obs_std[i, own] = self.next_obs_std_threshold + 1.0
            else:
                relabeled_rewards_std[i, own] = 0.0
                relabeled_next_obs_std[i, own] = 0.0

        # A transition is a candidate when both ensembles are confident.
        mask_reward = (relabeled_rewards_std < self.reward_std_threshold).type(torch.float)
        mask_next_obs = (relabeled_next_obs_std < self.next_obs_std_threshold).type(torch.float)
        mask = (mask_reward * mask_next_obs).type(torch.uint8)
        mask_from_the_other_tasks = mask.type(torch.uint8).clone()
        num_context_candidate_each_task = torch.sum(mask, dim=1)

        # For every task sample exactly num_trans_context candidates
        # (plus, in combine mode, the task's own transitions).
        mask_list = []
        for i in range(num_tasks):
            assert mask[i].dim() == 1
            mask_nonzero = torch.nonzero(mask[i]).flatten()
            mask_i = ptu.zeros_like(mask[i], dtype=torch.uint8)
            assert num_context_candidate_each_task[i].item() == mask_nonzero.shape[0]
            np_ind = np.random.choice(mask_nonzero.shape[0], num_trans_context, replace=False)
            mask_i[mask_nonzero[np_ind]] = 1
            if self.is_combine:
                # Combine the relabelled transitions with the original
                # context transitions that carry ground-truth rewards.
                mask_i[i * num_trans_context:(i + 1) * num_trans_context] = 1
                assert torch.sum(mask_i).item() == 2 * num_trans_context
            else:
                assert torch.sum(mask_i).item() == num_trans_context
            mask_list.append(mask_i)
        mask = torch.cat(mask_list).type(torch.uint8)

        repeated_contexts = contexts.repeat(num_tasks, 1)
        context_without_next_obs_rewards = repeated_contexts[:, :obs_dim + action_dim]
        assert context_without_next_obs_rewards.shape[0] == relabeled_rewards_mean.reshape(-1, 1).shape[0]
        assert context_without_next_obs_rewards.shape[0] == relabeled_next_obs_mean.reshape(-1, obs_dim).shape[0]
        context_without_next_obs_rewards = context_without_next_obs_rewards[mask]
        context_next_obs = relabeled_next_obs_mean.reshape(-1, obs_dim)[mask]
        context_rewards = relabeled_rewards_mean.reshape(-1, 1)[mask]
        fast_contexts = torch.cat(
            (context_without_next_obs_rewards, context_next_obs, context_rewards), dim=1)
        fast_contexts = fast_contexts.reshape(num_tasks, -1, contexts.shape[-1])
    gt.stamp('relabel_context_transitions', unique=False)

    # ---- Obtain the targets from the per-task BCQ policies ----
    with torch.no_grad():
        # Sample z for each state
        z = self.bcq_polices[0].vae.sample_z(obs).to(ptu.device)
        # Each item in critic_weights is a list with one entry per device;
        # each of those is a list with one entry per layer, a tensor of shape
        # (num tasks // device count, layer input size, layer output size).
        # The other weights and biases are organised the same way.
        critic_weights, critic_biases, vae_weights, vae_biases, actor_weights, actor_biases = self.combined_bcq_policies
        # CRITIC
        obs_reshaped = obs.reshape(len(batch_idxes), in_mdp_batch_size, -1)
        acs_reshaped = actions.reshape(len(batch_idxes), in_mdp_batch_size, -1)
        obs_acs_reshaped = torch.cat((obs_reshaped, acs_reshaped), dim=-1)
        target_q = batch_bcq(obs_acs_reshaped, critic_weights, critic_biases)
        target_q = target_q.reshape(-1)
        # VAE
        z_reshaped = z.reshape(len(batch_idxes), in_mdp_batch_size, -1)
        obs_z_reshaped = torch.cat((obs_reshaped, z_reshaped), dim=-1)
        tc = batch_bcq(obs_z_reshaped, vae_weights, vae_biases)
        tc = self.bcq_polices[0].vae.max_action * torch.tanh(tc)
        target_candidates = tc.reshape(-1, tc.shape[-1])
        # PERTURBATION
        tc_reshaped = target_candidates.reshape(len(batch_idxes), in_mdp_batch_size, -1)
        obs_tc_reshaped = torch.cat((obs_reshaped, tc_reshaped), dim=-1)
        tp = batch_bcq(obs_tc_reshaped, actor_weights, actor_biases)
        tp = self.bcq_polices[0].actor.max_action * torch.tanh(tp)
        target_perturbations = tp.reshape(-1, tp.shape[-1])
    gt.stamp('get_the_targets', unique=False)

    # ---- Compute the triplet loss (vectorised over task pairs) ----
    self.context_encoder_optimizer.zero_grad()
    anchors = []
    positives = []
    negatives = []
    num_selected_list = []
    # 0 for a pair (i, j) where the ensemble of task i selected no transition
    # of task j; such triplets are masked out of the loss below.
    exclude_task_masks = []
    for i in range(num_tasks):
        for j in range(num_tasks):
            if j == i:
                continue
            # mask_for_task_j: (num_trans_context, )
            # mask_from_the_other_tasks: (num_tasks, num_tasks * num_trans_context)
            mask_for_task_j = mask_from_the_other_tasks[
                i, j * num_trans_context:(j + 1) * num_trans_context]
            num_selected = int(torch.sum(mask_for_task_j).item())
            exclude_task_masks.append(0 if num_selected == 0 else 1)

            # (num_selected, context_dim): task j's transitions kept by task i.
            context_trans_all = contexts[j * num_trans_context:(j + 1) * num_trans_context]
            context_trans_selected = context_trans_all[mask_for_task_j]
            # Task i's relabelled reward / next_obs for those transitions.
            relabel_reward_all = relabeled_rewards_mean[
                i, j * num_trans_context:(j + 1) * num_trans_context]
            relabel_reward_selected = relabel_reward_all[mask_for_task_j].reshape(-1, 1)
            relabel_next_obs_all = relabeled_next_obs_mean[
                i, j * num_trans_context:(j + 1) * num_trans_context]
            relabel_next_obs_selected = relabel_next_obs_all[mask_for_task_j]
            # (num_selected, context_dim)
            context_trans_selected_relabel = torch.cat([
                context_trans_selected[:, :obs_dim + action_dim],
                relabel_next_obs_selected,
                relabel_reward_selected,
            ], dim=1)

            # c_{i}: the same number of task i's own transitions.
            ind = np.random.choice(num_trans_context, num_selected, replace=False)
            context_trans_task_i = contexts[i * num_trans_context:(i + 1) * num_trans_context]
            context_trans_task_i_sampled = context_trans_task_i[ind]

            # Zero-pad every set to num_trans_context rows so they stack.
            num_to_pad = num_trans_context - num_selected
            pad_zero_tensor = ptu.zeros((num_to_pad, context_trans_selected.shape[1]))
            num_selected_list.append(num_selected)
            context_trans_selected = torch.cat([context_trans_selected, pad_zero_tensor], dim=0)
            context_trans_selected_relabel = torch.cat(
                [context_trans_selected_relabel, pad_zero_tensor], dim=0)
            context_trans_task_i_sampled = torch.cat(
                [context_trans_task_i_sampled, pad_zero_tensor], dim=0)

            anchors.append(context_trans_selected_relabel[None])
            positives.append(context_trans_task_i_sampled[None])
            negatives.append(context_trans_selected[None])

    # (num_tasks * (num_tasks - 1), num_trans_context, context_dim)
    anchors = torch.cat(anchors, dim=0)
    positives = torch.cat(positives, dim=0)
    negatives = torch.cat(negatives, dim=0)
    # (3 * num_tasks * (num_tasks - 1), num_trans_context, context_dim)
    input_contexts = torch.cat([anchors, positives, negatives], dim=0)
    # (3 * num_tasks * (num_tasks - 1), ): number of non-padding rows per set
    num_selected_pt = torch.from_numpy(np.array(num_selected_list))
    num_selected_repeat = num_selected_pt.repeat(3)
    z_means_vec, z_vars_vec = self.context_encoder.infer_posterior_with_mean_var(
        input_contexts, num_trans_context, num_selected_repeat)
    # (3, num_tasks * (num_tasks - 1), latent_dim)
    z_means_vec = z_means_vec.reshape(3, anchors.shape[0], -1)
    z_vars_vec = z_vars_vec.reshape(3, anchors.shape[0], -1)
    z_means_anchors, z_vars_anchors = z_means_vec[0], z_vars_vec[0]
    z_means_positives, z_vars_positives = z_means_vec[1], z_vars_vec[1]
    z_means_negatives, z_vars_negatives = z_means_vec[2], z_vars_vec[2]
    with_task_dist = compute_kl_div_diagonal(
        z_means_anchors, z_vars_anchors, z_means_positives, z_vars_positives)
    across_task_dist = compute_kl_div_diagonal(
        z_means_anchors, z_vars_anchors, z_means_negatives, z_vars_negatives)
    # Remove the triplets whose num_selected is 0.
    # NOTE(review): NaN * 0 is still NaN, so this only helps if the encoder
    # returns finite values for all-padding inputs (num_selected == 0) — confirm.
    exclude_task_masks = ptu.from_numpy(np.array(exclude_task_masks))
    with_task_dist = with_task_dist * exclude_task_masks
    across_task_dist = across_task_dist * exclude_task_masks
    unscaled_triplet_loss_vec = F.relu(with_task_dist - across_task_dist + self.triplet_margin)
    unscaled_triplet_loss_vec = torch.mean(unscaled_triplet_loss_vec)
    # `(x != x).any() is not True` was always true (a tensor is never the
    # `True` singleton), so the NaN check never fired.
    assert not torch.isnan(unscaled_triplet_loss_vec).any(), 'triplet loss is NaN'
    gt.stamp('get_triplet_loss', unique=False)
    unscaled_triplet_loss_vec.backward()
    check_grad_nan_nets(self.networks, f'triplet: {unscaled_triplet_loss_vec}')
    gt.stamp('get_triplet_loss_gradient', unique=False)

    # ---- Infer the context variables from the relabelled contexts ----
    inferred_mdps = self.context_encoder(fast_contexts)
    inferred_mdps = torch.repeat_interleave(inferred_mdps, in_mdp_batch_size, dim=0)
    gt.stamp('infer_mdps', unique=False)

    # ---- KL loss ----
    kl_div = self.context_encoder.compute_kl_div()
    kl_loss_each_task = self.kl_lambda * torch.sum(kl_div, dim=1)
    kl_loss = torch.sum(kl_loss_each_task)
    gt.stamp('get_kl_loss', unique=False)

    # ---- Q-function loss ----
    self.Qs_optimizer.zero_grad()
    pred_q = self.Qs(obs, actions, inferred_mdps)
    pred_q = torch.squeeze(pred_q)
    qf_loss_each_task = (pred_q - target_q)**2
    qf_loss_each_task = qf_loss_each_task.reshape(num_tasks, -1)
    qf_loss_each_task = torch.mean(qf_loss_each_task, dim=1)
    qf_loss = torch.mean(qf_loss_each_task)
    gt.stamp('get_qf_loss', unique=False)

    # Encoder gradients = triplet (above) + KL + Q, applied in one step.
    (kl_loss + qf_loss).backward()
    check_grad_nan_nets(self.networks, 'kl q')
    gt.stamp('get_kl_qf_gradient', unique=False)
    self.Qs_optimizer.step()
    self.context_encoder_optimizer.step()
    gt.stamp('update_Qs_encoder', unique=False)

    # ---- Candidate action and perturbation loss ----
    self.vae_decoder_optimizer.zero_grad()
    self.perturbation_generator_optimizer.zero_grad()
    pred_candidates = self.vae_decoder(obs, z, inferred_mdps.detach())
    pred_perturbations = self.perturbation_generator(
        obs, target_candidates, inferred_mdps.detach())
    candidate_loss_each_task = (pred_candidates - target_candidates)**2
    # average over the action dimension, then over each task's transitions
    candidate_loss_each_task = torch.mean(candidate_loss_each_task, dim=1)
    candidate_loss_each_task = candidate_loss_each_task.reshape(num_tasks, in_mdp_batch_size)
    candidate_loss_each_task = torch.mean(candidate_loss_each_task, dim=1)
    candidate_loss = torch.mean(candidate_loss_each_task)
    perturbation_loss_each_task = (pred_perturbations - target_perturbations)**2
    perturbation_loss_each_task = torch.mean(perturbation_loss_each_task, dim=1)
    perturbation_loss_each_task = perturbation_loss_each_task.reshape(num_tasks, in_mdp_batch_size)
    perturbation_loss_each_task = torch.mean(perturbation_loss_each_task, dim=1)
    perturbation_loss = torch.mean(perturbation_loss_each_task)
    gt.stamp('get_candidate_and_perturbation_loss', unique=False)

    (candidate_loss + perturbation_loss).backward()
    check_grad_nan_nets(self.networks, 'perb')
    gt.stamp('get_candidate_and_perturbation_gradient', unique=False)
    self.vae_decoder_optimizer.step()
    self.perturbation_generator_optimizer.step()

    # Abort on NaN parameters. A bare `exit()` would exit with status 0
    # (reporting success) and relies on the interactive-only `site` helper.
    for net in self.networks:
        for name, m in net.named_parameters():
            if torch.isnan(m).any():
                print(net, name)
                print(num_selected_list)
                print(min(num_selected_list))
                raise SystemExit(f'NaN detected in parameter {name}')
    gt.stamp('update_vae_perturbation', unique=False)

    # ---- Save some statistics for eval ----
    if self._need_to_update_eval_statistics:
        # Eval resets this flag, so the statistics are computed for one batch only.
        self._need_to_update_eval_statistics = False
        self.eval_statistics['qf_loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics['qf_loss_each_task'] = ptu.get_numpy(qf_loss_each_task)
        self.eval_statistics['kl_loss'] = np.mean(ptu.get_numpy(kl_loss))
        self.eval_statistics['triplet_loss'] = np.mean(
            ptu.get_numpy(unscaled_triplet_loss_vec))
        self.eval_statistics['kl_loss_each_task'] = ptu.get_numpy(kl_loss_each_task)
        self.eval_statistics['candidate_loss'] = np.mean(ptu.get_numpy(candidate_loss))
        self.eval_statistics['candidate_loss_each_task'] = ptu.get_numpy(
            candidate_loss_each_task)
        self.eval_statistics['perturbation_loss'] = np.mean(
            ptu.get_numpy(perturbation_loss))
        self.eval_statistics['perturbation_loss_each_task'] = ptu.get_numpy(
            perturbation_loss_each_task)
        # Converted like every other entry (was stored as a raw tensor).
        self.eval_statistics['num_context_candidate_each_task'] = ptu.get_numpy(
            num_context_candidate_each_task)
def get_param_values_np(self):
    """Return ``self.state_dict()`` as an ``OrderedDict`` of numpy arrays.

    Keys and their order are taken from ``state_dict()``; every value is
    converted with ``ptu.get_numpy``.
    """
    return OrderedDict(
        (key, ptu.get_numpy(value)) for key, value in self.state_dict().items()
    )
def train(self, train_data, discount=0.99, tau=0.005):
    """Run one BCQ update step whose networks are conditioned on an inferred task latent.

    Order of updates: context encoder forward -> VAE step -> critic step ->
    context encoder step -> actor step -> soft target-network updates.

    Args:
        train_data: 6-tuple ``(state, next_state, action, reward, done,
            context)``; each entry is converted with ``torch.FloatTensor`` and
            moved to ``device``.
        discount: discount factor applied to the bootstrapped target Q.
        tau: Polyak coefficient for the soft updates of ``critic_target`` and
            ``actor_target``.
    """
    state_np, next_state_np, action, reward, done, context = train_data
    state = torch.FloatTensor(state_np).to(device)
    action = torch.FloatTensor(action).to(device)
    next_state = torch.FloatTensor(next_state_np).to(device)
    reward = torch.FloatTensor(reward).to(device)
    # Stored as (1 - done): a not-done mask that zeroes the bootstrap term below.
    done = torch.FloatTensor(1 - done).to(device)
    context = torch.FloatTensor(context).to(device)
    gt.stamp('unpack_data', unique=False)
    # Infer MDP identity using context
    self.context_encoder_optimizer.zero_grad()
    inferred_mdp = self.context_encoder(context)
    # One latent per context row, repeated for that row's block of states.
    # Assumes states are grouped contiguously per context row — TODO confirm.
    in_mdp_batch_size = state.shape[0] // context.shape[0]
    inferred_mdp = torch.repeat_interleave(inferred_mdp, in_mdp_batch_size, dim=0)
    gt.stamp('infer_mdp_identity', unique=False)
    # Variational Auto-Encoder Training
    # inferred_mdp is NOT detached here, so vae_loss also sends gradients into
    # the context encoder; they accumulate with the critic's gradients before
    # context_encoder_optimizer.step() below.
    recon, mean, std = self.vae(state, action, inferred_mdp)
    recon_loss = F.mse_loss(recon, action)
    KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
    vae_loss = recon_loss + 0.5 * KL_loss
    gt.stamp('get_vae_loss', unique=False)
    self.vae_optimizer.zero_grad()
    # retain_graph keeps the encoder part of the graph alive because
    # critic_loss backpropagates through inferred_mdp again.
    vae_loss.backward(retain_graph=True)
    self.vae_optimizer.step()
    gt.stamp('update_vae', unique=False)
    # Critic Training
    self.critic_optimizer.zero_grad()
    with torch.no_grad():
        # Duplicate state 10 times
        state_rep = next_state.repeat_interleave(10, dim=0)
        inferred_mdp_rep = inferred_mdp.repeat_interleave(10, dim=0)
        # Q of 10 VAE-decoded, target-actor-perturbed actions per next state.
        target_Q1, target_Q2 = self.critic_target(
            state_rep,
            self.actor_target(
                state_rep,
                self.vae.decode(state_rep, inferred_mdp=inferred_mdp_rep),
                inferred_mdp_rep),
            inferred_mdp_rep)
        # Soft Clipped Double Q-learning: convex mix of min and max of the two heads.
        target_Q = self.target_q_coef * torch.min(target_Q1, target_Q2) + (
            1 - self.target_q_coef) * torch.max(target_Q1, target_Q2)
        # Best of the 10 candidates for each next state -> (batch, 1).
        target_Q = target_Q.view(state.shape[0], -1).max(1)[0].view(-1, 1)
        target_Q = reward + done * discount * target_Q
    current_Q1, current_Q2 = self.critic(state, action, inferred_mdp)
    critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
        current_Q2, target_Q)
    gt.stamp('get_critic_loss', unique=False)
    # NOTE(review): second zero_grad of the critic in this step; redundant
    # with the call above but harmless.
    self.critic_optimizer.zero_grad()
    critic_loss.backward(retain_graph=True)
    self.critic_optimizer.step()
    gt.stamp('update_critic', unique=False)
    # Applies the accumulated VAE + critic gradients to the encoder.
    self.context_encoder_optimizer.step()
    # Perturbation Model / Action Training (encoder output detached from here on)
    sampled_actions = self.vae.decode(state, inferred_mdp=inferred_mdp.detach())
    perturbed_actions = self.actor(state, sampled_actions, inferred_mdp.detach())
    # Update through DPG
    # NOTE(review): this backward also writes .grad into the critic (q1) and
    # VAE parameters; presumably cleared by their zero_grad calls next step.
    actor_loss = -self.critic.q1(state, perturbed_actions, inferred_mdp.detach()).mean()
    gt.stamp('get_actor_loss', unique=False)
    self.actor_optimizer.zero_grad()
    actor_loss.backward()
    self.actor_optimizer.step()
    gt.stamp('update_actor', unique=False)
    # Update Target Networks (Polyak averaging with coefficient tau)
    for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
    for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
    """ Save some statistics for eval """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        """ Eval should set this to None. This way, these statistics are only computed for one batch. """
        self.eval_statistics['actor_loss'] = np.mean(get_numpy(actor_loss))
        self.eval_statistics['critic_loss'] = np.mean(
            get_numpy(critic_loss))
        self.eval_statistics['vae_loss'] = np.mean(get_numpy(vae_loss))
def train(self, batch, batch_idxes):
    """Distil per-task BCQ teachers into task-conditioned student networks.

    One step trains (a) the context encoder and Q-network on the encoder's KL
    term plus a regression onto the teacher critics, then (b) the VAE decoder
    and perturbation generator onto the teacher VAE / actor outputs, using the
    inferred task latents detached from the encoder.

    Args:
        batch: dict with 'obs', 'actions' and 'contexts' tensors whose rows
            are grouped contiguously by task.
        batch_idxes: task indices of the batch; only its length is used.
    """
    # ---- unpack ----
    obs, actions, contexts = batch['obs'], batch['actions'], batch['contexts']
    num_tasks = batch_idxes.shape[0]
    gt.stamp('unpack_data_from_the_batch', unique=False)

    per_task_obs = obs.shape[0] // num_tasks
    per_task_ctx = contexts.shape[0] // num_tasks

    # ---- one latent per task, repeated for every row of that task ----
    task_latents = self.context_encoder(
        contexts.reshape(num_tasks, per_task_ctx, -1))
    task_latents = torch.repeat_interleave(task_latents, per_task_obs, dim=0)
    gt.stamp('infer_mdps', unique=False)

    def by_task(flat):
        # (num_tasks * per_task_obs, d) -> (num_tasks, per_task_obs, d)
        return flat.reshape(num_tasks, per_task_obs, -1)

    # ---- teacher targets ----
    with torch.no_grad():
        teacher = self.bcq_polices[0]
        z = teacher.vae.sample_z(obs).to(ptu.device)
        # Per-device, per-layer stacked weights / biases of all task teachers.
        (critic_w, critic_b, vae_w, vae_b,
         actor_w, actor_b) = self.combined_bcq_policies
        obs_by_task = by_task(obs)

        target_q = batch_bcq(
            torch.cat((obs_by_task, by_task(actions)), dim=-1),
            critic_w, critic_b,
        ).reshape(-1)

        candidates = batch_bcq(
            torch.cat((obs_by_task, by_task(z)), dim=-1), vae_w, vae_b)
        candidates = teacher.vae.max_action * torch.tanh(candidates)
        target_candidates = candidates.reshape(-1, candidates.shape[-1])

        perturbations = batch_bcq(
            torch.cat((obs_by_task, by_task(target_candidates)), dim=-1),
            actor_w, actor_b)
        perturbations = teacher.actor.max_action * torch.tanh(perturbations)
        target_perturbations = perturbations.reshape(-1, perturbations.shape[-1])
    gt.stamp('get_the_targets', unique=False)

    # ---- KL regulariser on the posterior ----
    self.context_encoder_optimizer.zero_grad()
    kl_div = self.context_encoder.compute_kl_div()
    kl_loss_each_task = self.kl_lambda * kl_div.sum(dim=1)
    kl_loss = kl_loss_each_task.sum()
    gt.stamp('get_kl_loss', unique=False)

    # ---- Q regression ----
    self.Qs_optimizer.zero_grad()
    pred_q = self.Qs(obs, actions, task_latents).squeeze()
    qf_loss_each_task = ((pred_q - target_q) ** 2).reshape(num_tasks, -1).mean(dim=1)
    qf_loss = qf_loss_each_task.mean()
    gt.stamp('get_qf_loss', unique=False)

    (kl_loss + qf_loss).backward()
    gt.stamp('get_kl_qf_gradient', unique=False)
    self.Qs_optimizer.step()
    self.context_encoder_optimizer.step()
    gt.stamp('update_Qs_encoder', unique=False)

    # ---- candidate-action / perturbation regression ----
    def task_mean(squared_error):
        # mean over action dims, then over the transitions of each task
        return squared_error.mean(dim=1).reshape(num_tasks, per_task_obs).mean(dim=1)

    self.vae_decoder_optimizer.zero_grad()
    self.perturbation_generator_optimizer.zero_grad()
    frozen_latents = task_latents.detach()
    pred_candidates = self.vae_decoder(obs, z, frozen_latents)
    pred_perturbations = self.perturbation_generator(
        obs, target_candidates, frozen_latents)

    candidate_loss_each_task = task_mean((pred_candidates - target_candidates) ** 2)
    candidate_loss = candidate_loss_each_task.mean()
    perturbation_loss_each_task = task_mean(
        (pred_perturbations - target_perturbations) ** 2)
    perturbation_loss = perturbation_loss_each_task.mean()
    gt.stamp('get_candidate_and_perturbation_loss', unique=False)

    (candidate_loss + perturbation_loss).backward()
    gt.stamp('get_candidate_and_perturbation_gradient', unique=False)
    self.vae_decoder_optimizer.step()
    self.perturbation_generator_optimizer.step()
    gt.stamp('update_vae_perturbation', unique=False)

    # ---- eval statistics: recorded for one batch, until eval re-arms the flag ----
    if not self._need_to_update_eval_statistics:
        return
    self._need_to_update_eval_statistics = False
    to_np = ptu.get_numpy
    snapshot = {
        'qf_loss': np.mean(to_np(qf_loss)),
        'qf_loss_each_task': to_np(qf_loss_each_task),
        'kl_loss': np.mean(to_np(kl_loss)),
        'kl_loss_each_task': to_np(kl_loss_each_task),
        'candidate_loss': np.mean(to_np(candidate_loss)),
        'candidate_loss_each_task': to_np(candidate_loss_each_task),
        'perturbation_loss': np.mean(to_np(perturbation_loss)),
        'perturbation_loss_each_task': to_np(perturbation_loss_each_task),
    }
    for key, value in snapshot.items():
        self.eval_statistics[key] = value
def get_optimistic_exploration_action(ob_np, policy=None, qfs=None, hyper_params=None):
    """Sample an exploratory action from a mean shifted toward an optimistic Q bound.

    The policy's pre-tanh mean is moved along the Sigma-scaled gradient of
    ``Q_UB = mean(Q1, Q2) + beta_UB * |Q1 - Q2| / 2`` by a step whose size is
    set by ``delta``. An action is then sampled from ``TanhNormal(shifted
    mean, std)``.

    Args:
        ob_np: dict of un-batched numpy observation components.
        policy: callable on a batched observation dict, returning a 6-tuple
            whose 2nd item is the pre-tanh mean and 5th item is the std.
        qfs: sequence of (at least) two critics called as ``qf(ob, action)``.
        hyper_params: mapping with keys 'beta_UB' and 'delta'.

    Returns:
        ``(action, {})``: the sampled action as numpy and an empty info dict.
    """
    beta_UB = hyper_params['beta_UB']
    delta = hyper_params['delta']

    # Add a leading batch axis of size 1 to every observation component.
    ob = {key: ptu.from_numpy(value[None]) for key, value in ob_np.items()}

    (_, pre_tanh_batch, _, _, std_batch, _) = policy(ob)
    pre_tanh_mu_T = pre_tanh_batch[0]
    std = std_batch[0]
    # After dropping the batch axis both must be rank 1.
    assert pre_tanh_mu_T.dim() == 1, pre_tanh_mu_T
    assert std.dim() == 1

    pre_tanh_mu_T.requires_grad_()
    action_batch = torch.tanh(pre_tanh_mu_T).unsqueeze(0)

    # Upper confidence bound of the two critics.
    q1 = qfs[0](ob, action_batch)
    q2 = qfs[1](ob, action_batch)
    q_ub = (q1 + q2) / 2.0 + beta_UB * (torch.abs(q1 - q2) / 2.0)

    # Gradient of Q_UB w.r.t. the pre-tanh mean.
    (grad,) = torch.autograd.grad(q_ub, pre_tanh_mu_T)
    assert grad is not None
    assert pre_tanh_mu_T.shape == grad.shape

    # Diagonal covariance; (g^T Sigma g) ** 0.5 reduces to a weighted norm.
    # 10e-6 (== 1e-5) guards against division by zero.
    variance = std.pow(2)
    denom = (grad.pow(2) * variance).sum().sqrt() + 10e-6
    mu_C = math.sqrt(2.0 * delta) * (variance * grad) / denom
    assert mu_C.shape == pre_tanh_mu_T.shape

    mu_E = pre_tanh_mu_T + mu_C
    assert mu_E.shape == std.shape
    action = TanhNormal(mu_E, std).sample()
    # No statistics are logged for now, hence the empty dict.
    return ptu.get_numpy(action), {}
def train(self, batch, batch_idxes, epoch):
    """Train the next-observation ensemble on the first task's slice of the batch.

    Each network in ``self.network_ensemble`` is regressed (MSE) onto
    ``next_obs`` for rows ``[:in_mdp_batch_size]`` only, and the summed losses
    are applied in one optimizer step. When eval statistics are requested and
    ``epoch > 150``, the ensemble's confident predictions for every task slice
    of the batch are additionally compared against ``self.env`` by replaying
    each selected transition.

    Args:
        batch: dict with 'obs', 'actions', 'next_obs', 'qpos' and 'qvel'
            tensors; rows are grouped contiguously per task.
        batch_idxes: task indices of the batch; only its length is used.
        epoch: current epoch; gates the (expensive) env-replay diagnostics.
    """
    """ Unpack data from the batch """
    obs = batch['obs']
    actions = batch['actions']
    next_obs = batch['next_obs']
    qpos = batch['qpos']
    qvel = batch['qvel']
    # Get the in_mdp_batch_size
    in_mdp_batch_size = obs.shape[0] // batch_idxes.shape[0]
    num_tasks = batch_idxes.shape[0]
    """ Obtain the model prediction loss """
    # Note that here, we do not calculate the obs_loss.
    next_obs_loss_task_0 = []
    # Every network predicts for the whole batch (reused by the diagnostics below).
    pred_next_obs = [net(obs, actions) for net in self.network_ensemble]
    for pred_no in pred_next_obs:
        # Only the first task's slice contributes to the training loss.
        loss = F.mse_loss(pred_no[:in_mdp_batch_size], next_obs[:in_mdp_batch_size])
        next_obs_loss_task_0.append(loss)
    next_obs_magnitude = torch.mean(
        torch.norm(next_obs[:in_mdp_batch_size], dim=1))
    gt.stamp('get_tranistion_prediction_loss', unique=False)
    self.network_ensemble_optimizer.zero_grad()
    # Sum (not mean) over ensemble members; each member's gradient is unscaled.
    next_obs_loss_task_0 = torch.stack(next_obs_loss_task_0)
    next_obs_loss_task_0 = torch.sum(next_obs_loss_task_0)
    next_obs_loss_task_0.backward()
    # [loss.backward() for loss in next_obs_loss_task_0]
    self.network_ensemble_optimizer.step()
    gt.stamp('update', unique=False)
    """ Save some statistics for eval """
    if self._need_to_update_eval_statistics:
        if epoch > 150:
            # Per-task slices of the batch. NOTE(review): range starts at 0, so
            # the "other_tasks" statistics include the training task too.
            qpos_other_tasks = [
                qpos[in_mdp_batch_size * i:in_mdp_batch_size * (i + 1)]
                for i in range(0, batch_idxes.shape[0])
            ]
            qvel_other_tasks = [
                qvel[in_mdp_batch_size * i:in_mdp_batch_size * (i + 1)]
                for i in range(0, batch_idxes.shape[0])
            ]
            actions_other_tasks = [
                actions[in_mdp_batch_size * i:in_mdp_batch_size * (i + 1)]
                for i in range(0, batch_idxes.shape[0])
            ]
            # Per task: ensemble predictions stacked on a new last axis,
            # i.e. (in_mdp_batch_size, obs_dim, num_networks).
            pred_next_obs_other_tasks = [
                torch.cat([
                    pred_no[in_mdp_batch_size * i:in_mdp_batch_size * (i + 1)][..., None]
                    for pred_no in pred_next_obs
                ], dim=-1) for i in range(0, batch_idxes.shape[0])
            ]
            next_obs_loss_other_tasks = []
            next_obs_loss_other_tasks_std = []
            num_selected_trans_other_tasks = []
            for item in zip(pred_next_obs_other_tasks, qpos_other_tasks,
                            qvel_other_tasks, actions_other_tasks):
                pred_no_other_task, qp_other_task, qv_other_task, a_other_task = item
                # Ensemble disagreement per transition: std over networks,
                # averaged over observation dims.
                # NOTE(review): squeeze() would drop a size-1 batch or obs axis
                # and break the dim=1 mean below — confirm those never occur.
                pred_std = torch.std(pred_no_other_task, dim=-1)
                pred_std = pred_std.squeeze()
                pred_std = torch.mean(pred_std, dim=1)
                # Keep only the transitions the ensemble is confident about.
                mask = ptu.get_numpy(pred_std < self.std_threshold)
                num_selected_trans_other_tasks.append(np.sum(mask))
                mask = mask.astype(bool)
                pred_no_other_task = ptu.get_numpy(pred_no_other_task)
                pred_no_other_task = pred_no_other_task[mask]
                qp_other_task = ptu.get_numpy(qp_other_task)
                qp_other_task = qp_other_task[mask]
                qv_other_task = ptu.get_numpy(qv_other_task)
                qv_other_task = qv_other_task[mask]
                a_other_task = ptu.get_numpy(a_other_task)
                a_other_task = a_other_task[mask]
                mse_loss = []
                for pred_no, qp, qv, a in zip(pred_no_other_task, qp_other_task,
                                              qv_other_task, a_other_task):
                    # Side effect: overwrites and steps the state of self.env
                    # to obtain the true next observation.
                    self.env.set_state(qp, qv)
                    no, _, _, _ = self.env.step(a)
                    # Squared error per obs dim and network, then averaged over
                    # obs dims -> one error per ensemble member.
                    loss = (pred_no - no.reshape(-1, 1))**2
                    loss = np.mean(loss, axis=0)
                    mse_loss.append(loss)
                # NOTE(review): tasks with no confident transition append nothing,
                # so these two lists can be shorter than the number of tasks.
                if len(mse_loss) > 0:
                    mse_loss = np.stack(mse_loss)
                    mse_loss_mean = np.mean(mse_loss)
                    next_obs_loss_other_tasks.append(mse_loss_mean)
                    # Spread of the error across ensemble members, averaged over transitions.
                    mse_loss_std = np.std(mse_loss, axis=1)
                    mse_loss_std = np.mean(mse_loss_std)
                    next_obs_loss_other_tasks_std.append(mse_loss_std)
            self.eval_statistics[
                'average_task_next_obs_loss_other_tasks_mean'] = next_obs_loss_other_tasks
            self.eval_statistics[
                'average_task_next_obs_loss_other_tasks_std'] = next_obs_loss_other_tasks_std
            self.eval_statistics[
                'num_selected_trans_other_tasks'] = num_selected_trans_other_tasks
        self._need_to_update_eval_statistics = False
        """ Eval should set this to None. This way, these statistics are only computed for one batch. """
        # self.eval_statistics['next_obs_loss_task_0'] = np.mean(
        #     ptu.get_numpy(torch.mean(torch.stack(next_obs_loss_task_0)))
        # )
        # Summed ensemble loss divided by ensemble size = mean per-network loss.
        self.eval_statistics['next_obs_loss_task_0'] = np.mean(
            ptu.get_numpy(next_obs_loss_task_0 / len(self.network_ensemble)))
        self.eval_statistics['next_obs_magnitude'] = ptu.get_numpy(
            next_obs_magnitude)