def test_differentiable_sgd():
    """Test second order derivative after taking optimization step."""
    policy = torch.nn.Linear(10, 10, bias=False)
    lr = 0.01
    diff_sgd = DifferentiableSGD(policy, lr=lr)

    named_theta = dict(policy.named_parameters())
    theta = list(named_theta.values())[0]
    meta_loss = torch.sum(theta**2)
    meta_loss.backward(create_graph=True)

    diff_sgd.step()

    theta_prime = list(policy.parameters())[0]
    loss = torch.sum(theta_prime**2)
    update_module_params(policy, named_theta)
    diff_sgd.zero_grad()
    loss.backward()

    result = theta.grad

    dtheta_prime = 1 - 2 * lr  # dtheta_prime/dtheta
    dloss = 2 * theta_prime  # dloss/dtheta_prime
    expected_result = dloss * dtheta_prime  # dloss/dtheta

    assert torch.allclose(result, expected_result)
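# Hedged sketch (not part of the test suite): the same second-order check
# written with plain torch.autograd, to spell out what DifferentiableSGD is
# expected to preserve, namely the computation graph through the parameter
# update. Only torch is assumed; every other name below is local.
def _manual_second_order_check(lr=0.01):
    theta = torch.randn(10, 10, requires_grad=True)
    meta_loss = torch.sum(theta**2)
    # create_graph=True keeps the update itself differentiable.
    grad = torch.autograd.grad(meta_loss, theta, create_graph=True)[0]
    theta_prime = theta - lr * grad  # equals theta * (1 - 2 * lr)
    loss = torch.sum(theta_prime**2)
    result = torch.autograd.grad(loss, theta)[0]
    # Chain rule: dloss/dtheta = (dloss/dtheta_prime) * (dtheta_prime/dtheta)
    expected = 2 * theta_prime * (1 - 2 * lr)
    assert torch.allclose(result, expected)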
def train_once(self, runner, all_samples, all_params):
    """Train the algorithm once.

    Args:
        runner (metarl.experiment.LocalRunner): The experiment runner.
        all_samples (list[list[MAMLTrajectoryBatch]]): A two dimensional
            list of MAMLTrajectoryBatch of size
            [meta_batch_size * (num_grad_updates + 1)].
        all_params (list[dict]): A list of named parameter dictionaries.
            Each dictionary contains key-value pairs of names (str) and
            parameters (torch.Tensor).

    Returns:
        float: Average return.

    """
    itr = runner.step_itr
    old_theta = dict(self._policy.named_parameters())

    kl_before = self._compute_kl_constraint(itr,
                                            all_samples,
                                            all_params,
                                            set_grad=False)

    meta_objective = self._compute_meta_loss(itr, all_samples, all_params)

    self._meta_optimizer.zero_grad()
    meta_objective.backward()

    self._meta_optimize(itr, all_samples, all_params)

    # Log
    loss_after = self._compute_meta_loss(itr,
                                         all_samples,
                                         all_params,
                                         set_grad=False)
    kl_after = self._compute_kl_constraint(itr,
                                           all_samples,
                                           all_params,
                                           set_grad=False)

    with torch.no_grad():
        policy_entropy = self._compute_policy_entropy(
            [task_samples[0] for task_samples in all_samples])
        average_return = self.log_performance(itr, all_samples,
                                              meta_objective.item(),
                                              loss_after.item(),
                                              kl_before.item(),
                                              kl_after.item(),
                                              policy_entropy.mean().item())

    update_module_params(self._old_policy, old_theta)

    return average_return
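# Hedged sketch of the behaviour assumed of update_module_params here and in
# the methods below: replace a module's parameters in place with the given
# tensors (which may be graph-connected), so snapshots such as old_theta can
# be swapped back in after optimization. The real metarl helper may differ in
# its details; the function name below is illustrative only.
def _sketch_update_module_params(module, new_params):
    for name, tensor in new_params.items():
        # Walk dotted names such as 'layer.weight' down to the owning
        # submodule, then assign into its parameter dict directly so that
        # non-leaf tensors are accepted (intentional protected access).
        *path, leaf = name.split('.')
        target = module
        for part in path:
            target = getattr(target, part)
        target._parameters[leaf] = tensor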
def _obtain_samples(self, runner):
    """Obtain samples for each task before and after the fast-adaptation.

    Args:
        runner (LocalRunner): A local runner instance to obtain samples.

    Returns:
        tuple: Tuple of (all_samples, all_params).
            all_samples (list[list[MAMLTrajectoryBatch]]): A two
                dimensional list of MAMLTrajectoryBatch of size
                [meta_batch_size * (num_grad_updates + 1)].
            all_params (list[dict]): A list of named parameter
                dictionaries.

    """
    tasks = self._env.sample_tasks(self._meta_batch_size)
    all_samples = [[] for _ in range(len(tasks))]
    all_params = []
    theta = dict(self._policy.named_parameters())

    for i, task in enumerate(tasks):
        self._set_task(runner, task)

        for j in range(self._num_grad_updates + 1):
            paths = runner.obtain_samples(runner.step_itr)
            batch_samples = self._process_samples(runner.step_itr, paths)
            all_samples[i].append(batch_samples)

            # The last iteration only samples; it does not adapt further.
            if j != self._num_grad_updates:
                self._adapt(runner.step_itr, batch_samples, set_grad=False)

        all_params.append(dict(self._policy.named_parameters()))
        # Restore the pre-update policy before adapting to the next task.
        update_module_params(self._policy, theta)

    return all_samples, all_params
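# Hedged sketch (illustrative, not part of the class): the layout of the two
# return values above, written as assertions. The helper name is
# hypothetical.
def _check_sample_layout(all_samples, all_params, meta_batch_size,
                         num_grad_updates):
    # One row of trajectory batches per task: the batch at index j was
    # collected after j inner-loop adaptation steps.
    assert len(all_samples) == meta_batch_size
    assert all(len(row) == num_grad_updates + 1 for row in all_samples)
    # One adapted parameter dict per task.
    assert len(all_params) == meta_batch_size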
def _compute_kl_constraint(self,
                           itr,
                           all_samples,
                           all_params,
                           set_grad=True):
    """Compute KL divergence.

    For each task, compute the KL divergence between the old policy
    distribution and the current policy distribution.

    Args:
        itr (int): Iteration number.
        all_samples (list[list[MAMLTrajectoryBatch]]): A two dimensional
            list of MAMLTrajectoryBatch of size
            [meta_batch_size * (num_grad_updates + 1)].
        all_params (list[dict]): A list of named parameter dictionaries.
            Each dictionary contains key-value pairs of names (str) and
            parameters (torch.Tensor).
        set_grad (bool): Whether to enable gradient calculation or not.

    Returns:
        torch.Tensor: Calculated mean value of KL divergence.

    """
    theta = dict(self._policy.named_parameters())
    old_theta = dict(self._old_policy.named_parameters())

    kls = []
    for task_samples, task_params in zip(all_samples, all_params):
        for i in range(self._num_grad_updates):
            self._adapt(itr, task_samples[i], set_grad=set_grad)

        update_module_params(self._old_policy, task_params)
        with torch.set_grad_enabled(set_grad):
            # pylint: disable=protected-access
            kl = self._inner_algo._compute_kl_constraint(
                task_samples[-1].observations)
        kls.append(kl)

        update_module_params(self._policy, theta)

    update_module_params(self._old_policy, old_theta)

    return torch.stack(kls).mean()
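# Hedged sketch of the per-task quantity gathered above, assuming Gaussian
# policies whose forward pass returns (mean, std); that call signature and
# the function name are illustrative, not the metarl API. The adapted old
# policy is held fixed while the current policy stays differentiable.
def _sketch_task_kl(old_policy, policy, observations):
    with torch.no_grad():
        old_dist = torch.distributions.Normal(*old_policy(observations))
    new_dist = torch.distributions.Normal(*policy(observations))
    return torch.distributions.kl_divergence(old_dist, new_dist).mean()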
def _compute_meta_loss(self, itr, all_samples, all_params, set_grad=True):
    """Compute loss to meta-optimize.

    Args:
        itr (int): Iteration number.
        all_samples (list[list[MAMLTrajectoryBatch]]): A two dimensional
            list of MAMLTrajectoryBatch of size
            [meta_batch_size * (num_grad_updates + 1)].
        all_params (list[dict]): A list of named parameter dictionaries.
            Each dictionary contains key-value pairs of names (str) and
            parameters (torch.Tensor).
        set_grad (bool): Whether to enable gradient calculation or not.

    Returns:
        torch.Tensor: Calculated mean value of loss.

    """
    theta = dict(self._policy.named_parameters())
    old_theta = dict(self._old_policy.named_parameters())

    losses = []
    for task_samples, task_params in zip(all_samples, all_params):
        for i in range(self._num_grad_updates):
            self._adapt(itr, task_samples[i], set_grad=set_grad)

        update_module_params(self._old_policy, task_params)
        with torch.set_grad_enabled(set_grad):
            # pylint: disable=protected-access
            last_update = task_samples[-1]
            loss = self._inner_algo._compute_loss(itr, *last_update[1:])
        losses.append(loss)

        update_module_params(self._policy, theta)

    update_module_params(self._old_policy, old_theta)

    return torch.stack(losses).mean()
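# Hedged sketch, in MAML terms, of the value computed above: the meta
# objective is the mean of per-task surrogate losses L_i, each evaluated
# after replaying that task's inner adaptation, i.e.
#
#     meta_loss = (1 / meta_batch_size) * sum_i L_i(theta_i'),
#     theta_i' = adapt(theta, task_i).
#
# With set_grad=True the replayed _adapt calls stay in the autograd graph, so
# backpropagating this value differentiates through the inner updates; with
# set_grad=False the same scalar is produced for logging only. The helper
# below is illustrative, not part of the class.
def _sketch_meta_objective(per_task_losses):
    # per_task_losses: list of scalar loss tensors, one per sampled task.
    return torch.stack(per_task_losses).mean()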