def select_actions(self, obs, raw_context):
    # Repeat the obs as done in BCQ; candidate_size indicates how many
    # candidate actions we evaluate.
    obs = from_numpy(np.tile(obs.reshape(1, -1), (self.candidate_size, 1)))
    if len(raw_context) == 0:
        # In the beginning, the inferred_mdp is set to the zero vector.
        inferred_mdp = ptu.zeros((1, self.f.latent_dim))
    else:
        # Construct the context from the raw context
        context = from_numpy(np.concatenate(raw_context, axis=0))[None]
        inferred_mdp = self.f(context)
    with torch.no_grad():
        inferred_mdp = inferred_mdp.repeat(self.candidate_size, 1)
        z = from_numpy(
            np.random.normal(0, 1, size=(obs.size(0), self.vae_latent_dim))
        ).clamp(-0.5, 0.5).to(ptu.device)
        candidate_actions = self.vae_decoder(obs, z, inferred_mdp)
        perturbed_actions = self.perturbation_generator.get_perturbed_actions(
            obs, candidate_actions, inferred_mdp)
        qv = self.Qs(obs, perturbed_actions, inferred_mdp)
        ind = qv.max(0)[1]
    return ptu.get_numpy(perturbed_actions[ind])
def _elem_or_tuple_to_variable(elem_or_tuple):
    if isinstance(elem_or_tuple, tuple):
        return tuple(
            _elem_or_tuple_to_variable(e) for e in elem_or_tuple
        )
    elif isinstance(elem_or_tuple, (OrderedDict, dict)):
        return {
            k: ptu.from_numpy(v).float()
            for k, v in elem_or_tuple.items()
        }
    return ptu.from_numpy(elem_or_tuple).float()
def torch_ify(np_array_or_other):
    if isinstance(np_array_or_other, np.ndarray):
        return ptu.from_numpy(np_array_or_other)
    elif isinstance(np_array_or_other, (OrderedDict, dict)):
        return {k: ptu.from_numpy(v) for k, v in np_array_or_other.items()}
    else:
        return np_array_or_other
def _get_prod_of_gauss_mask(num_selected, desired_len):
    # Adapted from
    # https://discuss.pytorch.org/t/create-a-2d-tensor-with-varying-lengths-of-one-in-each-row/25359
    # desired_len is the desired size of the second dimension of the masks.
    seq_lens = ptu.from_numpy(np.array(num_selected)).unsqueeze(-1)
    max_len = torch.max(seq_lens)

    # Create an auxiliary range tensor of suitable shape.
    range_tensor = torch.arange(max_len).unsqueeze(0)
    range_tensor = range_tensor.to(ptu.device)
    range_tensor = range_tensor.expand(seq_lens.size(0), range_tensor.size(1))

    # The actual mask tensor is created with binary masking:
    mask_tensor = (range_tensor < seq_lens)
    mask_tensor = mask_tensor.type(torch.float)

    # Pad the mask with zeros up to desired_len.
    current_len = mask_tensor.shape[1]
    pad = ptu.zeros(mask_tensor.shape[0], desired_len - current_len)
    mask_tensor = torch.cat((mask_tensor, pad), dim=1)
    return mask_tensor
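# A minimal, self-contained sketch (plain torch, no ptu dependency) of the same
# variable-length masking trick used above: each row is 1 for the first
# num_selected[i] positions and zero-padded out to desired_len.
import torch

def _demo_length_mask(num_selected, desired_len):
    seq_lens = torch.tensor(num_selected).unsqueeze(-1)             # (B, 1)
    range_tensor = torch.arange(int(seq_lens.max())).unsqueeze(0)   # (1, max_len)
    mask = (range_tensor < seq_lens).float()                        # (B, max_len)
    pad = torch.zeros(mask.shape[0], desired_len - mask.shape[1])
    return torch.cat((mask, pad), dim=1)                            # (B, desired_len)

# _demo_length_mask([2, 4], 5) ->
# tensor([[1., 1., 0., 0., 0.],
#         [1., 1., 1., 1., 0.]])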
def select_actions(self, obs, inferred_mdp):
    # Repeat the obs as done in BCQ; candidate_size indicates how many
    # candidate actions we evaluate.
    obs = from_numpy(np.tile(obs.reshape(1, -1), (self.candidate_size, 1)))
    with torch.no_grad():
        inferred_mdp = inferred_mdp.repeat(self.candidate_size, 1)
        z = from_numpy(
            np.random.normal(0, 1, size=(obs.size(0), self.vae_latent_dim))
        ).clamp(-0.5, 0.5).to(ptu.device)
        candidate_actions = self.vae_decoder(obs, z, inferred_mdp)
        perturbed_actions = self.perturbation_generator.get_perturbed_actions(
            obs, candidate_actions, inferred_mdp)
        qv = self.Qs(obs, perturbed_actions, inferred_mdp)
        ind = qv.max(0)[1]
    return ptu.get_numpy(perturbed_actions[ind])
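# A small, self-contained illustration (plain torch, dummy tensors) of the
# BCQ-style selection step above: score every perturbed candidate action with
# the Q-network and keep the highest-scoring one. Random numbers stand in for
# the candidates and their Q-values here.
import torch

def _demo_bcq_candidate_selection(candidate_size=10, action_dim=3):
    perturbed_actions = torch.randn(candidate_size, action_dim)
    qv = torch.randn(candidate_size, 1)      # one Q-value per candidate
    ind = qv.max(0)[1]                       # index of the best candidate
    return perturbed_actions[ind]            # shape (1, action_dim)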
def calculate_rewards(next_obs, goals):
    # Reward is exp(-distance) between the (x, y) position in next_obs
    # and each goal, evaluated for every goal in `goals`.
    pos = next_obs[:, :2][None]                                # (1, batch, 2)
    pos = pos.expand(len(goals), pos.shape[1], pos.shape[2])   # (num_goals, batch, 2)
    goals_np = np.array(goals)
    goals_pt = ptu.from_numpy(goals_np)
    goals_pt = goals_pt.unsqueeze(1)                           # (num_goals, 1, 2)
    reward = torch.exp(-torch.norm(pos - goals_pt, dim=-1))    # (num_goals, batch)
    return reward
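# A small, self-contained numeric check of the exp(-distance) reward above,
# written with plain torch/numpy (ptu.from_numpy is assumed to only move a
# numpy array onto the current torch device).
import numpy as np
import torch

def _demo_calculate_rewards():
    next_obs = torch.tensor([[0.0, 0.0, 0.3], [1.0, 1.0, -0.2]])   # (batch, obs_dim)
    goals = [np.array([0.0, 0.0]), np.array([3.0, 4.0])]

    pos = next_obs[:, :2][None].expand(len(goals), 2, 2)
    goals_pt = torch.from_numpy(np.array(goals)).float().unsqueeze(1)
    reward = torch.exp(-torch.norm(pos - goals_pt, dim=-1))
    # reward[0, 0] == exp(0) == 1.0: the first sample sits exactly on the first goal.
    # reward[1, 0] == exp(-5.0): the distance from (0, 0) to (3, 4) is 5.
    return reward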
def async_evaluate(self, goal):
    self.env.set_goal(goal)
    self.policy.context_encoder.clear_z()
    avg_reward = 0.
    avg_achieved = []
    final_achieved = []
    raw_context = deque()
    for i in range(self.num_evals):
        # Sample the MDP identity
        self.policy.context_encoder.sample_z()
        inferred_mdp = self.policy.context_encoder.z

        obs = self.env.reset()
        done = False
        path_length = 0
        while not done and path_length < self.max_path_length:
            action = self.select_actions(np.array(obs), inferred_mdp)
            next_obs, reward, done, env_info = self.env.step(action)
            avg_achieved.append(env_info['achieved'])
            new_context = np.concatenate([
                obs.reshape(1, -1),
                action.reshape(1, -1),
                next_obs.reshape(1, -1),
                np.array(reward).reshape(1, -1)
            ], axis=1)
            raw_context.append(new_context)
            obs = next_obs.copy()
            if i > 1:
                avg_reward += reward
            path_length += 1

        context = from_numpy(np.concatenate(raw_context, axis=0))[None]
        self.policy.context_encoder.infer_posterior(context)
        if i > 1:
            final_achieved.append(env_info['achieved'])

    avg_reward /= (self.num_evals - 2)
    if np.isscalar(env_info['achieved']):
        avg_achieved = np.mean(avg_achieved)
        final_achieved = np.mean(final_achieved)
    else:
        avg_achieved = np.stack(avg_achieved)
        avg_achieved = np.mean(avg_achieved, axis=0)
        final_achieved = np.stack(final_achieved)
        final_achieved = np.mean(final_achieved, axis=0)
    print(avg_reward)
    return avg_reward, (final_achieved.tolist(), self.env._goal.tolist())
def async_evaluate_test(self, goal):
    self.env.set_goal(goal)
    self.context_encoder.clear_z()
    avg_reward_list = []
    online_achieved_list = []
    raw_context = deque()
    for _ in range(self.num_evals):
        # Sample the MDP identity
        self.context_encoder.sample_z()
        inferred_mdp = self.context_encoder.z

        obs = self.env.reset()
        done = False
        path_length = 0
        avg_reward = 0.
        online_achieved = []
        while not done and path_length < self.max_path_length:
            action = self.select_actions(np.array(obs), inferred_mdp)
            next_obs, reward, done, env_info = self.env.step(action)
            achieved = env_info['achieved']
            online_achieved.append(np.arctan(achieved[1] / achieved[0]))
            if self.use_next_obs_in_context:
                new_context = np.concatenate([
                    obs.reshape(1, -1),
                    action.reshape(1, -1),
                    next_obs.reshape(1, -1),
                    np.array(reward).reshape(1, -1)
                ], axis=1)
            else:
                new_context = np.concatenate([
                    obs.reshape(1, -1),
                    action.reshape(1, -1),
                    np.array(reward).reshape(1, -1)
                ], axis=1)
            raw_context.append(new_context)
            obs = next_obs.copy()
            avg_reward += reward
            path_length += 1

        avg_reward_list.append(avg_reward)
        online_achieved = np.array(online_achieved)
        online_achieved_list.append([
            online_achieved.mean(),
            online_achieved.std(),
            self.env._goal
        ])

        context = from_numpy(np.concatenate(raw_context, axis=0))[None]
        self.context_encoder.infer_posterior(context)

    return online_achieved_list
def check_q_funct_estimate(self, paths):
    s0 = np.stack([path["observations"][0] for path in paths])
    s0 = ptu.from_numpy(s0)
    a0 = np.stack([path["actions"][0] for path in paths])
    a0 = ptu.from_numpy(a0)

    inferred_mdps = torch.repeat_interleave(
        self.trainer._inferred_mdp, s0.shape[0], dim=0)
    q_values = torch.min(
        self.trainer.qf1(s0, a0, inferred_mdps),
        self.trainer.qf2(s0, a0, inferred_mdps),
    )
    q_values = ptu.get_numpy(q_values)

    discount_returns = []
    for path in paths:
        discount_coef = self.trainer.discount ** np.arange(len(path["rewards"]))
        discount_return = np.sum(path["rewards"].flatten() * discount_coef)
        discount_returns.append(discount_return)

    q_values_mean = np.mean(q_values)
    q_values_std = np.std(q_values)
    discount_returns_mean = np.mean(discount_returns)
    discount_returns_std = np.std(discount_returns)
    return dict(
        q_values_mean=q_values_mean,
        q_values_std=q_values_std,
        discount_returns_mean=discount_returns_mean,
        discount_returns_std=discount_returns_std,
    )
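# A tiny, self-contained check of the discounted-return computation above:
# with rewards [1, 1, 1] and discount 0.9 the return is 1 + 0.9 + 0.81 = 2.71.
import numpy as np

def _demo_discounted_return(rewards=(1.0, 1.0, 1.0), discount=0.9):
    rewards = np.array(rewards)
    discount_coef = discount ** np.arange(len(rewards))   # [1.0, 0.9, 0.81]
    return np.sum(rewards * discount_coef)                # 2.71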
def select_actions(self, obs, raw_context):
    if len(raw_context) == 0:
        # In the beginning, the inferred_mdp is set to the zero vector.
        inferred_mdp = ptu.zeros(
            (1, self.policy.mlp_encoder.encoder_latent_dim))
    else:
        # Construct the context from the raw context
        context = from_numpy(np.concatenate(raw_context, axis=0))[None]
        inferred_mdp = self.policy.mlp_encoder(context)

    # obs = torch.cat([obs, inferred_mdp], dim=1)
    action = self.policy.select_action(obs, get_numpy(inferred_mdp))
    return action
def _elem_or_tuple_to_variable(elem_or_tuple):
    if isinstance(elem_or_tuple, tuple):
        return tuple(
            _elem_or_tuple_to_variable(e) for e in elem_or_tuple
        )
    return ptu.from_numpy(elem_or_tuple).float()
def torch_ify(np_array_or_other):
    if isinstance(np_array_or_other, np.ndarray):
        return ptu.from_numpy(np_array_or_other)
    else:
        return np_array_or_other
def get_optimistic_exploration_action(ob_np, policy=None, qfs=None, hyper_params=None):
    # assert ob_np.ndim == 1

    beta_UB = hyper_params['beta_UB']
    delta = hyper_params['delta']

    # ob = ptu.from_numpy(ob_np)
    ob = {k: ptu.from_numpy(v[None]) for k, v in ob_np.items()}

    # Ensure that ob is not batched
    # assert len(list(ob.shape)) == 1

    _, pre_tanh_mu_T, _, _, std, _ = policy(ob)
    # print(pre_tanh_mu_T.shape)
    pre_tanh_mu_T = pre_tanh_mu_T[0]
    std = std[0]

    # Ensure that pre_tanh_mu_T is not batched
    assert len(list(pre_tanh_mu_T.shape)) == 1, pre_tanh_mu_T
    assert len(list(std.shape)) == 1

    pre_tanh_mu_T.requires_grad_()
    tanh_mu_T = torch.tanh(pre_tanh_mu_T)

    # Get the upper bound of the Q estimate
    args = [ob, torch.unsqueeze(tanh_mu_T, dim=0)]
    # args = list(torch.unsqueeze(i, dim=0) for i in (ob, tanh_mu_T))
    Q1 = qfs[0](*args)
    Q2 = qfs[1](*args)

    mu_Q = (Q1 + Q2) / 2.0
    sigma_Q = torch.abs(Q1 - Q2) / 2.0
    Q_UB = mu_Q + beta_UB * sigma_Q

    # Obtain the gradient of Q_UB wrt a, with a evaluated at mu_T
    grad = torch.autograd.grad(Q_UB, pre_tanh_mu_T)
    grad = grad[0]

    assert grad is not None
    assert pre_tanh_mu_T.shape == grad.shape

    # Obtain Sigma_T (the covariance of the normal distribution)
    Sigma_T = torch.pow(std, 2)

    # The divisor is (g^T Sigma g) ** 0.5.
    # Sigma is diagonal, so this works out to be
    # ( sum_{i=1}^k (g^(i))^2 (sigma^(i))^2 ) ** 0.5
    denom = torch.sqrt(
        torch.sum(torch.mul(torch.pow(grad, 2), Sigma_T))) + 10e-6

    # Obtain the change in mu
    mu_C = math.sqrt(2.0 * delta) * torch.mul(Sigma_T, grad) / denom

    assert mu_C.shape == pre_tanh_mu_T.shape

    mu_E = pre_tanh_mu_T + mu_C

    # Construct the tanh normal distribution and sample the exploratory
    # action from it
    assert mu_E.shape == std.shape
    dist = TanhNormal(mu_E, std)
    ac = dist.sample()
    ac_np = ptu.get_numpy(ac)

    # mu_T_np = ptu.get_numpy(pre_tanh_mu_T)
    # mu_C_np = ptu.get_numpy(mu_C)
    # mu_E_np = ptu.get_numpy(mu_E)
    # dict(mu_T=mu_T_np, mu_C=mu_C_np, mu_E=mu_E_np)

    # Return an empty dict, and do not log stats for now
    return ac_np, {}
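# A minimal, self-contained numeric sketch of the OAC mean shift used above
# (plain torch, with a hand-picked gradient instead of differentiating a Q
# upper bound): mu_E = mu_T + sqrt(2 * delta) * Sigma * g / sqrt(g^T Sigma g).
import math
import torch

def _demo_oac_mean_shift(delta=0.5):
    pre_tanh_mu_T = torch.tensor([0.0, 0.0])
    std = torch.tensor([1.0, 2.0])
    grad = torch.tensor([1.0, -1.0])      # stand-in for dQ_UB/da at a = mu_T

    Sigma_T = std ** 2                    # diagonal covariance, [1, 4]
    denom = torch.sqrt(torch.sum(grad ** 2 * Sigma_T)) + 10e-6   # sqrt(5)
    mu_C = math.sqrt(2.0 * delta) * Sigma_T * grad / denom
    return pre_tanh_mu_T + mu_C           # roughly [0.447, -1.789]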
def train(self, batch, batch_idxes):
    """
    Unpack data from the batch
    """
    obs = batch['obs']
    actions = batch['actions']
    contexts = batch['contexts']

    num_tasks = batch_idxes.shape[0]

    gt.stamp('unpack_data_from_the_batch', unique=False)

    # Get the in_mdp_batch_size
    obs_dim = obs.shape[1]
    action_dim = actions.shape[1]
    in_mdp_batch_size = obs.shape[0] // batch_idxes.shape[0]
    num_trans_context = contexts.shape[0] // batch_idxes.shape[0]

    """
    Relabel the context batches for each training task
    """
    with torch.no_grad():
        contexts_obs_actions = contexts[:, :obs_dim + action_dim]

        manual_batched_rewards = self.reward_ensemble_predictor.forward_mul_device(
            contexts_obs_actions)
        relabeled_rewards = manual_batched_rewards.reshape(
            num_tasks, self.num_network_ensemble, contexts.shape[0])

        gt.stamp('reward_ensemble_forward', unique=False)

        manual_batched_next_obs = self.transition_ensemble_predictor.forward_mul_device(
            contexts_obs_actions)
        relabeled_next_obs = manual_batched_next_obs.reshape(
            num_tasks, self.num_network_ensemble, contexts.shape[0], obs_dim)

        gt.stamp('transition_ensemble_forward', unique=False)

        relabeled_rewards_mean = torch.mean(relabeled_rewards, dim=1).squeeze()
        relabeled_rewards_std = torch.std(relabeled_rewards, dim=1).squeeze()

        relabeled_next_obs_mean = torch.mean(relabeled_next_obs, dim=1)
        relabeled_next_obs_std = torch.std(relabeled_next_obs, dim=1)
        relabeled_next_obs_std = torch.mean(relabeled_next_obs_std, dim=-1)

        # Replace the predicted reward with the ground-truth reward for
        # transitions that carry a ground-truth reward inside the batch
        for i in range(num_tasks):
            relabeled_rewards_mean[i, i * num_trans_context:(i + 1) * num_trans_context] = \
                contexts[i * num_trans_context:(i + 1) * num_trans_context, -1]

            relabeled_next_obs_mean[i, i * num_trans_context:(i + 1) * num_trans_context, :] = \
                contexts[i * num_trans_context:(i + 1) * num_trans_context,
                         obs_dim + action_dim:-1]

            if self.is_combine:
                # Set the std to be larger than the corresponding threshold,
                # so that these entries are initially filtered out when
                # producing the mask, which is conducive to the sampling.
                relabeled_rewards_std[i, i * num_trans_context:(i + 1) * num_trans_context] = (
                    self.reward_std_threshold + 1.0)
                relabeled_next_obs_std[i, i * num_trans_context:(i + 1) * num_trans_context] = (
                    self.next_obs_std_threshold + 1.0)
            else:
                relabeled_rewards_std[i, i * num_trans_context:(i + 1) * num_trans_context] = 0.0
                relabeled_next_obs_std[i, i * num_trans_context:(i + 1) * num_trans_context] = 0.0

        mask_reward = relabeled_rewards_std < self.reward_std_threshold
        mask_reward = mask_reward.type(torch.float)

        mask_next_obs = relabeled_next_obs_std < self.next_obs_std_threshold
        mask_next_obs = mask_next_obs.type(torch.float)

        mask = mask_reward * mask_next_obs
        mask = mask.type(torch.uint8)

        mask_from_the_other_tasks = mask.type(torch.uint8).clone()

        num_context_candidate_each_task = torch.sum(mask, dim=1)

        mask_list = []
        for i in range(num_tasks):
            assert mask[i].dim() == 1
            mask_nonzero = torch.nonzero(mask[i])
            mask_nonzero = mask_nonzero.flatten()

            mask_i = ptu.zeros_like(mask[i], dtype=torch.uint8)

            assert num_context_candidate_each_task[i].item() == mask_nonzero.shape[0]

            np_ind = np.random.choice(mask_nonzero.shape[0],
                                      num_trans_context,
                                      replace=False)
            ind = mask_nonzero[np_ind]
            mask_i[ind] = 1

            if self.is_combine:
                # Combine the additional relabeled context transitions with
                # the original context transitions with ground-truth rewards
                mask_i[i * num_trans_context:(i + 1) * num_trans_context] = 1
                assert torch.sum(mask_i).item() == 2 * num_trans_context
            else:
                assert torch.sum(mask_i).item() == num_trans_context

            mask_list.append(mask_i)

        mask = torch.cat(mask_list)
        mask = mask.type(torch.uint8)

        repeated_contexts = contexts.repeat(num_tasks, 1)
        context_without_next_obs_rewards = repeated_contexts[:, :obs_dim + action_dim]

        assert context_without_next_obs_rewards.shape[0] == \
            relabeled_rewards_mean.reshape(-1, 1).shape[0]
        assert context_without_next_obs_rewards.shape[0] == \
            relabeled_next_obs_mean.reshape(-1, obs_dim).shape[0]

        context_without_next_obs_rewards = context_without_next_obs_rewards[mask]
        context_next_obs = relabeled_next_obs_mean.reshape(-1, obs_dim)[mask]
        context_rewards = relabeled_rewards_mean.reshape(-1, 1)[mask]

        fast_contexts = torch.cat(
            (context_without_next_obs_rewards, context_next_obs, context_rewards),
            dim=1)
        fast_contexts = fast_contexts.reshape(num_tasks, -1, contexts.shape[-1])

    gt.stamp('relabel_context_transitions', unique=False)

    """
    Obtain the targets
    """
    with torch.no_grad():
        # Sample z for each state
        z = self.bcq_polices[0].vae.sample_z(obs).to(ptu.device)

        # Each item in critic_weights is a list with one entry per device;
        # each entry in critic_weights[i] is a list with one entry per layer;
        # each entry in critic_weights[i][j] is a tensor of dim
        # (num_tasks // device_count, layer_input_size, layer_output_size).
        # The same holds for the other weights and biases.
        (critic_weights, critic_biases, vae_weights, vae_biases,
         actor_weights, actor_biases) = self.combined_bcq_policies

        # CRITIC
        obs_reshaped = obs.reshape(len(batch_idxes), in_mdp_batch_size, -1)
        acs_reshaped = actions.reshape(len(batch_idxes), in_mdp_batch_size, -1)
        obs_acs_reshaped = torch.cat((obs_reshaped, acs_reshaped), dim=-1)

        target_q = batch_bcq(obs_acs_reshaped, critic_weights, critic_biases)
        target_q = target_q.reshape(-1)

        # VAE
        z_reshaped = z.reshape(len(batch_idxes), in_mdp_batch_size, -1)
        obs_z_reshaped = torch.cat((obs_reshaped, z_reshaped), dim=-1)

        tc = batch_bcq(obs_z_reshaped, vae_weights, vae_biases)
        tc = self.bcq_polices[0].vae.max_action * torch.tanh(tc)
        target_candidates = tc.reshape(-1, tc.shape[-1])

        # PERTURBATION
        tc_reshaped = target_candidates.reshape(len(batch_idxes), in_mdp_batch_size, -1)
        obs_tc_reshaped = torch.cat((obs_reshaped, tc_reshaped), dim=-1)

        tp = batch_bcq(obs_tc_reshaped, actor_weights, actor_biases)
        tp = self.bcq_polices[0].actor.max_action * torch.tanh(tp)
        target_perturbations = tp.reshape(-1, tp.shape[-1])

    gt.stamp('get_the_targets', unique=False)

    """
    Compute the triplet loss
    """
    # ----------------------------------Vectorized-------------------------------------------
    self.context_encoder_optimizer.zero_grad()

    anchors = []
    positives = []
    negatives = []

    num_selected_list = []

    # Pairs of tasks (i, j) where no transition from task j is selected
    # by the ensemble of task i
    exclude_tasks = []
    exclude_task_masks = []

    for i in range(num_tasks):
        # Compute the triplet loss for task i
        for j in range(num_tasks):
            if j != i:
                # mask_for_task_j: (num_trans_context, )
                # mask_from_the_other_tasks: (num_tasks, num_tasks * num_trans_context)
                mask_for_task_j = mask_from_the_other_tasks[
                    i, j * num_trans_context:(j + 1) * num_trans_context]
                num_selected = int(torch.sum(mask_for_task_j).item())

                if num_selected == 0:
                    exclude_tasks.append((i, j))
                    exclude_task_masks.append(0)
                else:
                    exclude_task_masks.append(1)

                # context_trans_all: (num_trans_context, context_dim)
                context_trans_all = contexts[
                    j * num_trans_context:(j + 1) * num_trans_context]
                # context_trans_selected: (num_selected, context_dim)
                context_trans_selected = context_trans_all[mask_for_task_j]

                # relabel_reward_all: (num_trans_context, )
                relabel_reward_all = relabeled_rewards_mean[
                    i, j * num_trans_context:(j + 1) * num_trans_context]
                # relabel_reward_selected: (num_selected, )
                relabel_reward_selected = relabel_reward_all[mask_for_task_j]
                # relabel_reward_selected: (num_selected, 1)
                relabel_reward_selected = relabel_reward_selected.reshape(-1, 1)

                # relabel_next_obs_all: (num_trans_context, obs_dim)
                relabel_next_obs_all = relabeled_next_obs_mean[
                    i, j * num_trans_context:(j + 1) * num_trans_context]
                # relabel_next_obs_selected: (num_selected, obs_dim)
                relabel_next_obs_selected = relabel_next_obs_all[mask_for_task_j]

                # context_trans_selected_relabel: (num_selected, context_dim)
                context_trans_selected_relabel = torch.cat([
                    context_trans_selected[:, :obs_dim + action_dim],
                    relabel_next_obs_selected,
                    relabel_reward_selected
                ], dim=1)

                # c_{i}
                ind = np.random.choice(num_trans_context, num_selected, replace=False)
                # Next 2 lines used for comparing to the sequential version
                # ind = ind_list[count]
                # count += 1

                # context_trans_task_i: (num_trans_context, context_dim)
                context_trans_task_i = contexts[
                    i * num_trans_context:(i + 1) * num_trans_context]
                # context_trans_task_i_sampled: (num_selected, context_dim)
                context_trans_task_i_sampled = context_trans_task_i[ind]

                # Pad the contexts with a zero tensor
                num_to_pad = num_trans_context - num_selected
                # pad_zero_tensor: (num_to_pad, context_dim)
                pad_zero_tensor = ptu.zeros((num_to_pad, context_trans_selected.shape[1]))

                num_selected_list.append(num_selected)

                # Dim: (num_trans_context, context_dim) after padding
                context_trans_selected = torch.cat(
                    [context_trans_selected, pad_zero_tensor], dim=0)
                context_trans_selected_relabel = torch.cat(
                    [context_trans_selected_relabel, pad_zero_tensor], dim=0)
                context_trans_task_i_sampled = torch.cat(
                    [context_trans_task_i_sampled, pad_zero_tensor], dim=0)

                anchors.append(context_trans_selected_relabel[None])
                positives.append(context_trans_task_i_sampled[None])
                negatives.append(context_trans_selected[None])

    # Dim: (num_tasks * (num_tasks - 1), num_trans_context, context_dim)
    anchors = torch.cat(anchors, dim=0)
    positives = torch.cat(positives, dim=0)
    negatives = torch.cat(negatives, dim=0)

    # input_contexts: (3 * num_tasks * (num_tasks - 1), num_trans_context, context_dim)
    input_contexts = torch.cat([anchors, positives, negatives], dim=0)

    # num_selected_pt: (num_tasks * (num_tasks - 1), )
    num_selected_pt = torch.from_numpy(np.array(num_selected_list))
    # num_selected_repeat: (3 * num_tasks * (num_tasks - 1), )
    num_selected_repeat = num_selected_pt.repeat(3)

    # z_means_vec, z_vars_vec: (3 * num_tasks * (num_tasks - 1), latent_dim)
    z_means_vec, z_vars_vec = self.context_encoder.infer_posterior_with_mean_var(
        input_contexts, num_trans_context, num_selected_repeat)

    # z_means_vec, z_vars_vec: (3, num_tasks * (num_tasks - 1), latent_dim)
    z_means_vec = z_means_vec.reshape(3, anchors.shape[0], -1)
    z_vars_vec = z_vars_vec.reshape(3, anchors.shape[0], -1)

    # Dim: (num_tasks * (num_tasks - 1), latent_dim)
    z_means_anchors, z_vars_anchors = z_means_vec[0], z_vars_vec[0]
    z_means_positives, z_vars_positives = z_means_vec[1], z_vars_vec[1]
    z_means_negatives, z_vars_negatives = z_means_vec[2], z_vars_vec[2]

    with_task_dist = compute_kl_div_diagonal(
        z_means_anchors, z_vars_anchors, z_means_positives, z_vars_positives)
    across_task_dist = compute_kl_div_diagonal(
        z_means_anchors, z_vars_anchors, z_means_negatives, z_vars_negatives)

    # Remove the triplets corresponding to num_selected == 0
    exclude_task_masks = ptu.from_numpy(np.array(exclude_task_masks))

    with_task_dist = with_task_dist * exclude_task_masks
    across_task_dist = across_task_dist * exclude_task_masks

    unscaled_triplet_loss_vec = F.relu(
        with_task_dist - across_task_dist + self.triplet_margin)
    unscaled_triplet_loss_vec = torch.mean(unscaled_triplet_loss_vec)

    # Assert that unscaled_triplet_loss_vec is not NaN
    assert not torch.isnan(unscaled_triplet_loss_vec).any()

    gt.stamp('get_triplet_loss', unique=False)

    unscaled_triplet_loss_vec.backward()
    check_grad_nan_nets(self.networks, f'triplet: {unscaled_triplet_loss_vec}')

    gt.stamp('get_triplet_loss_gradient', unique=False)

    """
    Infer the context variables
    """
    # inferred_mdps = self.context_encoder(new_contexts)
    inferred_mdps = self.context_encoder(fast_contexts)
    inferred_mdps = torch.repeat_interleave(inferred_mdps, in_mdp_batch_size, dim=0)

    gt.stamp('infer_mdps', unique=False)

    """
    Obtain the KL loss
    """
    kl_div = self.context_encoder.compute_kl_div()
    kl_loss_each_task = self.kl_lambda * torch.sum(kl_div, dim=1)
    kl_loss = torch.sum(kl_loss_each_task)

    gt.stamp('get_kl_loss', unique=False)

    """
    Obtain the Q-function loss
    """
    self.Qs_optimizer.zero_grad()
    pred_q = self.Qs(obs, actions, inferred_mdps)
    pred_q = torch.squeeze(pred_q)

    qf_loss_each_task = (pred_q - target_q) ** 2
    qf_loss_each_task = qf_loss_each_task.reshape(num_tasks, -1)
    qf_loss_each_task = torch.mean(qf_loss_each_task, dim=1)
    qf_loss = torch.mean(qf_loss_each_task)

    gt.stamp('get_qf_loss', unique=False)

    (kl_loss + qf_loss).backward()
    check_grad_nan_nets(self.networks, 'kl q')

    gt.stamp('get_kl_qf_gradient', unique=False)

    self.Qs_optimizer.step()
    self.context_encoder_optimizer.step()

    gt.stamp('update_Qs_encoder', unique=False)

    """
    Obtain the candidate action and perturbation loss
    """
    self.vae_decoder_optimizer.zero_grad()
    self.perturbation_generator_optimizer.zero_grad()

    pred_candidates = self.vae_decoder(obs, z, inferred_mdps.detach())
    pred_perturbations = self.perturbation_generator(
        obs, target_candidates, inferred_mdps.detach())

    candidate_loss_each_task = (pred_candidates - target_candidates) ** 2
    # Average over the action dimension
    candidate_loss_each_task = torch.mean(candidate_loss_each_task, dim=1)
    candidate_loss_each_task = candidate_loss_each_task.reshape(
        num_tasks, in_mdp_batch_size)
    # Average over the actions in each task
    candidate_loss_each_task = torch.mean(candidate_loss_each_task, dim=1)
    candidate_loss = torch.mean(candidate_loss_each_task)

    perturbation_loss_each_task = (pred_perturbations - target_perturbations) ** 2
    # Average over the action dimension
    perturbation_loss_each_task = torch.mean(perturbation_loss_each_task, dim=1)
    perturbation_loss_each_task = perturbation_loss_each_task.reshape(
        num_tasks, in_mdp_batch_size)
    # Average over the actions in each task
    perturbation_loss_each_task = torch.mean(perturbation_loss_each_task, dim=1)
    perturbation_loss = torch.mean(perturbation_loss_each_task)

    gt.stamp('get_candidate_and_perturbation_loss', unique=False)

    (candidate_loss + perturbation_loss).backward()
    check_grad_nan_nets(self.networks, 'perb')

    gt.stamp('get_candidate_and_perturbation_gradient', unique=False)

    self.vae_decoder_optimizer.step()
    self.perturbation_generator_optimizer.step()

    for net in self.networks:
        for name, m in net.named_parameters():
            if (m != m).any():
                print(net, name)
                print(num_selected_list)
                print(min(num_selected_list))
                exit()

    gt.stamp('update_vae_perturbation', unique=False)

    """
    Save some statistics for eval
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        """
        Eval should set this to None.
        This way, these statistics are only computed for one batch.
        """
        self.eval_statistics['qf_loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics['qf_loss_each_task'] = ptu.get_numpy(qf_loss_each_task)
        self.eval_statistics['kl_loss'] = np.mean(ptu.get_numpy(kl_loss))
        self.eval_statistics['triplet_loss'] = np.mean(
            ptu.get_numpy(unscaled_triplet_loss_vec))
        self.eval_statistics['kl_loss_each_task'] = ptu.get_numpy(kl_loss_each_task)
        self.eval_statistics['candidate_loss'] = np.mean(
            ptu.get_numpy(candidate_loss))
        self.eval_statistics['candidate_loss_each_task'] = ptu.get_numpy(
            candidate_loss_each_task)
        self.eval_statistics['perturbation_loss'] = np.mean(
            ptu.get_numpy(perturbation_loss))
        self.eval_statistics['perturbation_loss_each_task'] = ptu.get_numpy(
            perturbation_loss_each_task)
        self.eval_statistics['num_context_candidate_each_task'] = \
            num_context_candidate_each_task
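# The triplet loss above relies on compute_kl_div_diagonal, which is defined
# elsewhere in the codebase. A minimal, self-contained sketch of the standard
# closed-form KL divergence between diagonal Gaussians, KL(N(m1, v1) || N(m2, v2)),
# which is one reasonable form for that helper (the actual implementation may differ):
import torch

def _kl_div_diagonal_sketch(mu1, var1, mu2, var2):
    # KL = 0.5 * sum( log(var2 / var1) + (var1 + (mu1 - mu2)^2) / var2 - 1 )
    return 0.5 * torch.sum(
        torch.log(var2 / var1) + (var1 + (mu1 - mu2) ** 2) / var2 - 1.0,
        dim=-1)

# Identical Gaussians give a divergence of exactly zero:
# _kl_div_diagonal_sketch(torch.zeros(1, 4), torch.ones(1, 4),
#                         torch.zeros(1, 4), torch.ones(1, 4)) -> tensor([0.])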
def set_param_values_np(self, param_values):
    torch_dict = OrderedDict()
    for key, tensor in param_values.items():
        torch_dict[key] = ptu.from_numpy(tensor)
    self.load_state_dict(torch_dict)
def collect_new_paths(
        self,
        max_path_length,
        num_steps,
        discard_incomplete_paths,
):
    self.context_encoder.clear_z()
    paths = []
    num_steps_collected = 0
    raw_context = deque()

    while num_steps_collected < num_steps:
        max_path_length_this_loop = min(
            # Do not go over num_steps
            max_path_length,
            num_steps - num_steps_collected,
        )

        # Sample the MDP identity
        self.context_encoder.sample_z()
        inferred_mdp = self.context_encoder.z

        path_length = 0
        observations = []
        actions = []
        rewards = []
        terminals = []
        agent_infos = []
        env_infos = []

        obs = self.env.reset()
        done = False
        while not done and path_length < max_path_length_this_loop:
            action = self.select_actions(np.array(obs), inferred_mdp)
            next_obs, reward, done, _ = self.env.step(action)

            if self.use_next_obs_in_context:
                new_context = np.concatenate([
                    obs.reshape(1, -1),
                    action.reshape(1, -1),
                    next_obs.reshape(1, -1),
                    np.array(reward).reshape(1, -1)
                ], axis=1)
            else:
                assert False

            observations.append(obs)
            rewards.append(reward)
            terminals.append(done)
            actions.append(action)
            agent_infos.append(-1)
            env_infos.append(-1)
            path_length += 1

            raw_context.append(new_context)
            obs = next_obs.copy()

        context = from_numpy(np.concatenate(raw_context, axis=0))[None]
        self.context_encoder.infer_posterior(context)

        actions = np.array(actions)
        if len(actions.shape) == 1:
            actions = np.expand_dims(actions, 1)
        observations = np.array(observations)
        if len(observations.shape) == 1:
            observations = np.expand_dims(observations, 1)
            next_obs = np.array([next_obs])
        next_observations = np.vstack(
            (observations[1:, :], np.expand_dims(next_obs, 0)))

        path = dict(
            observations=observations,
            actions=actions,
            rewards=np.array(rewards).reshape(-1, 1),
            next_observations=next_observations,
            terminals=np.array(terminals).reshape(-1, 1),
            agent_infos=agent_infos,
            env_infos=env_infos,
        )
        path_len = len(path['actions'])

        if (
                # incomplete path
                path_len != max_path_length and
                # that did not end in a terminal state
                not path['terminals'][-1] and
                # and we should discard such paths
                discard_incomplete_paths
        ):
            break
        num_steps_collected += path_len
        paths.append(path)

    self._num_paths_total += len(paths)
    self._num_steps_total += num_steps_collected
    self._epoch_paths.extend(paths)
    return paths, inferred_mdp
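# A minimal numpy check of how next_observations is assembled above: the
# observation array shifted by one step, with the final next_obs appended.
import numpy as np

def _demo_next_observations():
    observations = np.array([[0.0], [1.0], [2.0]])   # o_0, o_1, o_2
    final_next_obs = np.array([3.0])                 # o_3
    next_observations = np.vstack(
        (observations[1:, :], np.expand_dims(final_next_obs, 0)))
    return next_observations                         # [[1.], [2.], [3.]]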