Code example #1
import torch

# NOTE: the import paths below are assumed from the metarl package layout and
# may need adjustment to match the actual module locations.
from metarl.torch import update_module_params
from metarl.torch.optimizers import DifferentiableSGD


def test_differentiable_sgd():
    """Test the second-order derivative after taking an optimization step."""
    policy = torch.nn.Linear(10, 10, bias=False)
    lr = 0.01
    diff_sgd = DifferentiableSGD(policy, lr=lr)

    named_theta = dict(policy.named_parameters())
    theta = list(named_theta.values())[0]
    meta_loss = torch.sum(theta**2)
    meta_loss.backward(create_graph=True)

    diff_sgd.step()

    theta_prime = list(policy.parameters())[0]
    loss = torch.sum(theta_prime**2)
    update_module_params(policy, named_theta)
    diff_sgd.zero_grad()
    loss.backward()

    result = theta.grad

    dtheta_prime = 1 - 2 * lr  # dtheta_prime/dtheta
    dloss = 2 * theta_prime  # dloss/dtheta_prime
    expected_result = dloss * dtheta_prime  # dloss/dtheta

    assert torch.allclose(result, expected_result)
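The assertion above checks a second-order gradient via the chain rule: the differentiable SGD step produces theta_prime = theta - lr * 2 * theta = (1 - 2 * lr) * theta, so dloss/dtheta = (dloss/dtheta_prime) * (dtheta_prime/dtheta) = 2 * theta_prime * (1 - 2 * lr), which is exactly expected_result. Below is a minimal standalone sketch of the same check using only plain torch, with the inner SGD step written out by hand instead of DifferentiableSGD; it is illustrative, not part of the metarl API.

import torch

lr = 0.01
theta = torch.randn(10, 10, requires_grad=True)

# Differentiable inner step: theta_prime = theta - lr * d(sum theta^2)/d(theta).
inner_loss = torch.sum(theta**2)
grad = torch.autograd.grad(inner_loss, theta, create_graph=True)[0]
theta_prime = theta - lr * grad  # equals (1 - 2 * lr) * theta

# Outer loss at the adapted parameters; backward flows through the inner step.
outer_loss = torch.sum(theta_prime**2)
outer_loss.backward()

expected = 2 * theta_prime * (1 - 2 * lr)
assert torch.allclose(theta.grad, expected)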
Code example #2
    def train_once(self, runner, all_samples, all_params):
        """Train the algorithm once.

        Args:
            runner (metarl.experiment.LocalRunner): The experiment runner.
            all_samples (list[list[MAMLTrajectoryBatch]]): A two-dimensional
                list of MAMLTrajectoryBatch of shape
                [meta_batch_size, num_grad_updates + 1].
            all_params (list[dict]): A list of named parameter dictionaries.
                Each dictionary maps parameter names (str) to parameters
                (torch.Tensor).

        Returns:
            float: Average return.

        """
        itr = runner.step_itr
        old_theta = dict(self._policy.named_parameters())

        kl_before = self._compute_kl_constraint(itr,
                                                all_samples,
                                                all_params,
                                                set_grad=False)

        meta_objective = self._compute_meta_loss(itr, all_samples, all_params)

        self._meta_optimizer.zero_grad()
        meta_objective.backward()

        self._meta_optimize(itr, all_samples, all_params)

        # Log
        loss_after = self._compute_meta_loss(itr,
                                             all_samples,
                                             all_params,
                                             set_grad=False)
        kl_after = self._compute_kl_constraint(itr,
                                               all_samples,
                                               all_params,
                                               set_grad=False)

        with torch.no_grad():
            policy_entropy = self._compute_policy_entropy(
                [task_samples[0] for task_samples in all_samples])
            average_return = self.log_performance(itr, all_samples,
                                                  meta_objective.item(),
                                                  loss_after.item(),
                                                  kl_before.item(),
                                                  kl_after.item(),
                                                  policy_entropy.mean().item())

        update_module_params(self._old_policy, old_theta)

        return average_return
Code example #3
    def _obtain_samples(self, runner):
        """Obtain samples for each task before and after the fast-adaptation.

        Args:
            runner (LocalRunner): A local runner instance to obtain samples.

        Returns:
            tuple: Tuple of (all_samples, all_params).
                all_samples (list[list[MAMLTrajectoryBatch]]): A
                    two-dimensional list of MAMLTrajectoryBatch of shape
                    [meta_batch_size, num_grad_updates + 1].
                all_params (list[dict]): A list of named parameter
                    dictionaries, one per task, holding the adapted
                    parameters.

        """
        tasks = self._env.sample_tasks(self._meta_batch_size)
        all_samples = [[] for _ in range(len(tasks))]
        all_params = []
        theta = dict(self._policy.named_parameters())

        for i, task in enumerate(tasks):
            self._set_task(runner, task)

            for j in range(self._num_grad_updates + 1):
                paths = runner.obtain_samples(runner.step_itr)
                batch_samples = self._process_samples(runner.step_itr, paths)
                all_samples[i].append(batch_samples)

                # The last iteration only collects samples; it does not adapt
                if j != self._num_grad_updates:
                    self._adapt(runner.step_itr, batch_samples, set_grad=False)

            all_params.append(dict(self._policy.named_parameters()))
            # Restore to pre-updated policy
            update_module_params(self._policy, theta)

        return all_samples, all_params
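To make the returned layout concrete: all_samples[i][j] holds the batch collected for task i after j inner-gradient updates, so index 0 is the pre-adaptation batch and index -1 (j == num_grad_updates) is the post-adaptation batch, while all_params[i] holds task i's adapted parameters. The standalone snippet below illustrates just that indexing, with placeholder strings standing in for MAMLTrajectoryBatch objects.

# Illustration only: placeholder values stand in for MAMLTrajectoryBatch.
meta_batch_size, num_grad_updates = 3, 1
all_samples = [[f'task{i}-step{j}' for j in range(num_grad_updates + 1)]
               for i in range(meta_batch_size)]

for task_samples in all_samples:
    pre_adaptation = task_samples[0]    # collected with the meta-parameters
    post_adaptation = task_samples[-1]  # collected with the adapted parameters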
Code example #4
    def _compute_kl_constraint(self,
                               itr,
                               all_samples,
                               all_params,
                               set_grad=True):
        """Compute KL divergence.

        For each task, compute the KL divergence between the old policy
        distribution and current policy distribution.

        Args:
            itr (int): Iteration number.
            all_samples (list[list[MAMLTrajectoryBatch]]): A two-dimensional
                list of MAMLTrajectoryBatch of shape
                [meta_batch_size, num_grad_updates + 1].
            all_params (list[dict]): A list of named parameter dictionaries.
                Each dictionary maps parameter names (str) to parameters
                (torch.Tensor).
            set_grad (bool): Whether to enable gradient calculation or not.

        Returns:
            torch.Tensor: Calculated mean value of KL divergence.

        """
        theta = dict(self._policy.named_parameters())
        old_theta = dict(self._old_policy.named_parameters())

        kls = []
        for task_samples, task_params in zip(all_samples, all_params):
            for i in range(self._num_grad_updates):
                self._adapt(itr, task_samples[i], set_grad=set_grad)

            update_module_params(self._old_policy, task_params)
            with torch.set_grad_enabled(set_grad):
                # pylint: disable=protected-access
                kl = self._inner_algo._compute_kl_constraint(
                    task_samples[-1].observations)
            kls.append(kl)

            update_module_params(self._policy, theta)
            update_module_params(self._old_policy, old_theta)

        return torch.stack(kls).mean()
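The quantity averaged above is, per task, the mean KL divergence between the old and current action distributions on the post-adaptation batch. As a point of reference, the standalone sketch below computes that kind of mean KL for two Gaussian action distributions with plain torch.distributions; it is not the metarl API, only the underlying calculation.

import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

# Old and new action distributions over a batch of 64 observations, 4-dim actions.
old_dist = Normal(loc=torch.zeros(64, 4), scale=torch.ones(64, 4))
new_dist = Normal(loc=0.1 * torch.ones(64, 4), scale=1.1 * torch.ones(64, 4))

# Sum the per-dimension KL over actions, then average over the batch.
mean_kl = kl_divergence(old_dist, new_dist).sum(dim=-1).mean()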
Code example #5
    def _compute_meta_loss(self, itr, all_samples, all_params, set_grad=True):
        """Compute loss to meta-optimize.

        Args:
            itr (int): Iteration number.
            all_samples (list[list[MAMLTrajectoryBatch]]): A two-dimensional
                list of MAMLTrajectoryBatch of shape
                [meta_batch_size, num_grad_updates + 1].
            all_params (list[dict]): A list of named parameter dictionaries.
                Each dictionary maps parameter names (str) to parameters
                (torch.Tensor).
            set_grad (bool): Whether to enable gradient calculation or not.

        Returns:
            torch.Tensor: Calculated mean value of loss.

        """
        theta = dict(self._policy.named_parameters())
        old_theta = dict(self._old_policy.named_parameters())

        losses = []
        for task_samples, task_params in zip(all_samples, all_params):
            for i in range(self._num_grad_updates):
                self._adapt(itr, task_samples[i], set_grad=set_grad)

            update_module_params(self._old_policy, task_params)
            with torch.set_grad_enabled(set_grad):
                # pylint: disable=protected-access
                last_update = task_samples[-1]
                loss = self._inner_algo._compute_loss(itr, *last_update[1:])
            losses.append(loss)

            update_module_params(self._policy, theta)
            update_module_params(self._old_policy, old_theta)

        return torch.stack(losses).mean()
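Stripped of the sampler and policy machinery, the structure above is the standard MAML meta-objective: adapt the parameters on each task with a differentiable inner step, evaluate the loss at the adapted parameters, and average the post-adaptation losses. A toy, self-contained sketch with hypothetical quadratic tasks (plain torch, not the metarl API):

import torch

theta = torch.randn(5, requires_grad=True)  # meta-parameters
inner_lr = 0.1
task_targets = [torch.randn(5) for _ in range(4)]  # hypothetical per-task data

losses = []
for target in task_targets:
    # Differentiable inner adaptation step (one gradient update per task).
    inner_loss = torch.sum((theta - target) ** 2)
    grad = torch.autograd.grad(inner_loss, theta, create_graph=True)[0]
    theta_prime = theta - inner_lr * grad

    # Post-adaptation loss, evaluated at the adapted parameters.
    losses.append(torch.sum((theta_prime - target) ** 2))

meta_objective = torch.stack(losses).mean()
meta_objective.backward()  # second-order gradients flow back into theta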