def test_differentiable_sgd():
    """Test second order derivative after taking optimization step."""
    policy = torch.nn.Linear(10, 10, bias=False)
    lr = 0.01
    diff_sgd = DifferentiableSGD(policy, lr=lr)

    named_theta = dict(policy.named_parameters())
    theta = list(named_theta.values())[0]
    meta_loss = torch.sum(theta**2)
    meta_loss.backward(create_graph=True)

    diff_sgd.step()

    theta_prime = list(policy.parameters())[0]
    loss = torch.sum(theta_prime**2)
    update_module_params(policy, named_theta)
    diff_sgd.zero_grad()
    loss.backward()

    result = theta.grad

    dtheta_prime = 1 - 2 * lr  # dtheta_prime/dtheta
    dloss = 2 * theta_prime  # dloss/dtheta_prime
    expected_result = dloss * dtheta_prime  # dloss/dtheta

    assert torch.allclose(result, expected_result)
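# Why the expected value above is 2 * theta_prime * (1 - 2 * lr): for
# meta_loss = sum(theta ** 2), one SGD step gives
# theta_prime = theta - lr * 2 * theta = (1 - 2 * lr) * theta, and the chain
# rule through that step yields d(sum(theta_prime ** 2))/dtheta
# = 2 * theta_prime * (1 - 2 * lr). A minimal standalone sketch of the same
# arithmetic in plain PyTorch, independent of DifferentiableSGD:
import torch


def expected_second_order_grad(theta, lr=0.01):
    """Analytic d(sum(theta'^2))/dtheta after one SGD step on sum(theta^2)."""
    theta_prime = (1 - 2 * lr) * theta  # closed form of the SGD step
    return 2 * theta_prime * (1 - 2 * lr)  # chain rule through that step


# Autograd agrees with the closed form:
theta = torch.randn(3, requires_grad=True)
lr = 0.01
theta_prime = theta - lr * torch.autograd.grad(
    torch.sum(theta**2), theta, create_graph=True)[0]
torch.sum(theta_prime**2).backward()
assert torch.allclose(theta.grad, expected_second_order_grad(theta, lr))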
def train_once(self, runner, all_samples, all_params):
    """Train the algorithm once.

    Args:
        runner (garage.experiment.LocalRunner): The experiment runner.
        all_samples (list[list[MAMLTrajectoryBatch]]): A two dimensional
            list of MAMLTrajectoryBatch of size
            [meta_batch_size * (num_grad_updates + 1)].
        all_params (list[dict]): A list of named parameter dictionaries.
            Each dictionary contains key value pair of names (str) and
            parameters (torch.Tensor).

    Returns:
        float: Average return.

    """
    itr = runner.step_itr
    old_theta = dict(self._policy.named_parameters())

    kl_before = self._compute_kl_constraint(all_samples,
                                            all_params,
                                            set_grad=False)

    meta_objective = self._compute_meta_loss(all_samples, all_params)

    self._meta_optimizer.zero_grad()
    meta_objective.backward()

    self._meta_optimize(all_samples, all_params)

    # Log
    loss_after = self._compute_meta_loss(all_samples,
                                         all_params,
                                         set_grad=False)
    kl_after = self._compute_kl_constraint(all_samples,
                                           all_params,
                                           set_grad=False)

    with torch.no_grad():
        policy_entropy = self._compute_policy_entropy(
            [task_samples[0] for task_samples in all_samples])
        average_return = self.log_performance(itr, all_samples,
                                              meta_objective.item(),
                                              loss_after.item(),
                                              kl_before.item(),
                                              kl_after.item(),
                                              policy_entropy.mean().item())

    if self._meta_evaluator and itr % self._evaluate_every_n_epochs == 0:
        self._meta_evaluator.evaluate(self)

    update_module_params(self._old_policy, old_theta)

    return average_return
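# For context, train_once() is driven by an outer epoch loop that first
# collects pre- and post-adaptation samples and then performs the meta-update.
# A hypothetical sketch of that driver (method and attribute names follow the
# snippets in this section; the actual train() in this codebase may differ):
def train(self, runner):
    """Alternate between fast-adaptation sampling and meta-optimization."""
    last_return = None
    for _ in runner.step_epochs():
        all_samples, all_params = self._obtain_samples(runner)
        last_return = self.train_once(runner, all_samples, all_params)
        runner.step_itr += 1
    return last_return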
def _obtain_samples(self, runner):
    """Obtain samples for each task before and after the fast-adaptation.

    Args:
        runner (LocalRunner): A local runner instance to obtain samples.

    Returns:
        tuple: Tuple of (all_samples, all_params).
            all_samples (list[list[MAMLEpisodeBatch]]): A two dimensional
                list of MAMLEpisodeBatch of size
                [meta_batch_size * (num_grad_updates + 1)].
            all_params (list[dict]): A list of named parameter dictionaries.

    """
    tasks = self._env.sample_tasks(self._meta_batch_size)
    all_samples = [[] for _ in range(len(tasks))]
    all_params = []
    theta = dict(self._policy.named_parameters())

    for i, task in enumerate(tasks):
        for j in range(self._num_grad_updates + 1):
            env_up = SetTaskUpdate(None, task=task)
            episodes = runner.obtain_samples(runner.step_itr,
                                             env_update=env_up)
            batch_samples = self._process_samples(episodes)
            all_samples[i].append(batch_samples)

            # The last iteration only samples and does not adapt.
            if j < self._num_grad_updates:
                # Gradients need to be kept for the next grad update,
                # except for the last one.
                require_grad = j < self._num_grad_updates - 1
                self._adapt(batch_samples, set_grad=require_grad)

        all_params.append(dict(self._policy.named_parameters()))
        # Restore the pre-update policy before sampling the next task.
        update_module_params(self._policy, theta)

    return all_samples, all_params
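# _adapt() is called above but not listed in this section. Conceptually it
# performs one differentiable inner-loop update: compute the inner algorithm's
# loss on the given batch, backpropagate with create_graph so the update stays
# on the autograd graph when needed, and take a DifferentiableSGD step (the
# optimizer exercised by test_differentiable_sgd). A hedged sketch of that
# method, assuming an inner optimizer attribute named self._inner_optimizer:
def _adapt(self, batch_samples, set_grad=True):
    # pylint: disable=protected-access
    loss = self._inner_algo._compute_loss(*batch_samples[1:])

    # Keep the graph of this update only when later meta-gradients need it.
    self._inner_optimizer.zero_grad()
    loss.backward(create_graph=set_grad)

    with torch.set_grad_enabled(set_grad):
        self._inner_optimizer.step()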
def _compute_meta_loss(self, all_samples, all_params, set_grad=True):
    """Compute loss to meta-optimize.

    Args:
        all_samples (list[list[_MAMLEpisodeBatch]]): A two dimensional
            list of _MAMLEpisodeBatch of size
            [meta_batch_size * (num_grad_updates + 1)].
        all_params (list[dict]): A list of named parameter dictionaries.
            Each dictionary contains key value pair of names (str) and
            parameters (torch.Tensor).
        set_grad (bool): Whether to enable gradient calculation or not.

    Returns:
        torch.Tensor: Calculated mean value of loss.

    """
    theta = dict(self._policy.named_parameters())
    old_theta = dict(self._old_policy.named_parameters())

    losses = []
    for task_samples, task_params in zip(all_samples, all_params):
        with torch.set_grad_enabled(set_grad):
            # SG-MRL specific
            # pylint: disable=protected-access
            initial_samples = task_samples[0]
            init_log_probs = self._inner_algo._compute_log_probs(
                *initial_samples[1:])

        for i in range(self._num_grad_updates):
            require_grad = i < self._num_grad_updates - 1 or set_grad
            self._adapt(task_samples[i], set_grad=require_grad)

        update_module_params(self._old_policy, task_params)
        with torch.set_grad_enabled(set_grad):
            # pylint: disable=protected-access
            last_update = task_samples[-1]
            loss = self._inner_algo._compute_loss(*last_update[1:])

            # SG-MRL specific
            with torch.set_grad_enabled(False):
                adapted_reward = last_update.rewards.detach().clone().numpy()
                # note that we treat it as a constant
                j_tilde = np.mean([
                    discount_cumsum(path, self._inner_algo.discount)[0]
                    for path in adapted_reward
                ])

            # SG-MRL specific
            loss += j_tilde * init_log_probs

        losses.append(loss)

        update_module_params(self._policy, theta)
        update_module_params(self._old_policy, old_theta)

    return torch.stack(losses).mean()
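# In the SG-MRL-specific term above, j_tilde is the average discounted return
# of the post-adaptation paths, treated as a constant, so j_tilde *
# init_log_probs contributes what amounts to a score-function term through the
# pre-adaptation log-probabilities. discount_cumsum(path, discount)[0] is the
# total discounted return of one path; a minimal NumPy sketch of that helper
# (the imported version may be implemented with a filter, but the values are
# the same):
import numpy as np


def discount_cumsum(x, discount):
    """Discounted return-to-go for each timestep of a reward sequence x."""
    out = np.zeros_like(x, dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + discount * running
        out[t] = running
    return out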
def _compute_kl_constraint(self, all_samples, all_params, set_grad=True):
    """Compute KL divergence.

    For each task, compute the KL divergence between the old policy
    distribution and current policy distribution.

    Args:
        all_samples (list[list[_MAMLEpisodeBatch]]): Two dimensional list
            of _MAMLEpisodeBatch of size
            [meta_batch_size * (num_grad_updates + 1)].
        all_params (list[dict]): A list of named parameter dictionaries.
            Each dictionary contains key value pair of names (str) and
            parameters (torch.Tensor).
        set_grad (bool): Whether to enable gradient calculation or not.

    Returns:
        torch.Tensor: Calculated mean value of KL divergence.

    """
    theta = dict(self._policy.named_parameters())
    old_theta = dict(self._old_policy.named_parameters())

    kls = []
    for task_samples, task_params in zip(all_samples, all_params):
        for i in range(self._num_grad_updates):
            require_grad = i < self._num_grad_updates - 1 or set_grad
            self._adapt(task_samples[i], set_grad=require_grad)

        update_module_params(self._old_policy, task_params)
        with torch.set_grad_enabled(set_grad):
            # pylint: disable=protected-access
            kl = self._inner_algo._compute_kl_constraint(
                task_samples[-1].observations)
        kls.append(kl)

        update_module_params(self._policy, theta)
        update_module_params(self._old_policy, old_theta)

    return torch.stack(kls).mean()
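# The per-task KL above is delegated to the inner algorithm; per the docstring
# it is the KL divergence between the pre-update ("old") policy distribution
# and the current policy distribution at the sampled observations. A minimal
# sketch of that quantity with torch.distributions (the inner algorithm's own
# implementation may differ, e.g. in how the distributions are obtained):
import torch


def mean_kl_divergence(old_dist, new_dist):
    """Mean KL(old || new) over a batch of action distributions."""
    return torch.distributions.kl.kl_divergence(old_dist, new_dist).mean()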
def _compute_meta_loss(self, all_samples, all_params, set_grad=True):
    """Compute loss to meta-optimize.

    Args:
        all_samples (list[list[_MAMLEpisodeBatch]]): A two dimensional
            list of _MAMLEpisodeBatch of size
            [meta_batch_size * (num_grad_updates + 1)].
        all_params (list[dict]): A list of named parameter dictionaries.
            Each dictionary contains key value pair of names (str) and
            parameters (torch.Tensor).
        set_grad (bool): Whether to enable gradient calculation or not.

    Returns:
        torch.Tensor: Calculated mean value of loss.

    """
    theta = dict(self._policy.named_parameters())
    old_theta = dict(self._old_policy.named_parameters())

    losses = []
    for task_samples, task_params in zip(all_samples, all_params):
        for i in range(self._num_grad_updates):
            require_grad = i < self._num_grad_updates - 1 or set_grad
            self._adapt(task_samples[i], set_grad=require_grad)

        update_module_params(self._old_policy, task_params)
        with torch.set_grad_enabled(set_grad):
            # pylint: disable=protected-access
            last_update = task_samples[-1]
            loss = self._inner_algo._compute_loss(*last_update[1:])
        losses.append(loss)

        update_module_params(self._policy, theta)
        update_module_params(self._old_policy, old_theta)

    return torch.stack(losses).mean()
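# For reference, the nested structure that both _compute_meta_loss variants
# (and _compute_kl_constraint) iterate over, illustrated for a hypothetical
# meta_batch_size of 2 and num_grad_updates of 1:
#
#   all_samples = [[task0_pre_adapt_batch, task0_post_adapt_batch],
#                  [task1_pre_adapt_batch, task1_post_adapt_batch]]
#   all_params  = [task0_adapted_theta, task1_adapted_theta]
#
# task_samples[-1] is therefore the post-adaptation batch, and task_params is
# the adapted parameter dict loaded into self._old_policy before the loss or
# KL is evaluated on that batch.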