Example #1
    def __init__(self,
                 inner_algo,
                 env,
                 policy,
                 baseline,
                 meta_optimizer,
                 meta_batch_size=40,
                 inner_lr=0.1,
                 outer_lr=1e-3,
                 num_grad_updates=1):
        # Pick the sampler class based on whether the policy supports
        # vectorized (batched) rollouts.
        if policy.vectorized:
            self.sampler_cls = OnPolicyVectorizedSampler
        else:
            self.sampler_cls = BatchSampler

        self.max_path_length = inner_algo.max_path_length
        self._policy = policy
        self._env = env
        self._baseline = baseline
        self._num_grad_updates = num_grad_updates
        self._meta_batch_size = meta_batch_size
        self._inner_algo = inner_algo
        # Inner-loop optimizer: differentiable SGD, so gradients can flow
        # through the adaptation step back to the pre-adaptation parameters.
        self._inner_optimizer = DifferentiableSGD(self._policy, lr=inner_lr)
        # Outer-loop (meta) optimizer over the policy parameters.
        self._meta_optimizer = make_optimizer(meta_optimizer,
                                              policy,
                                              lr=_Default(outer_lr),
                                              eps=_Default(1e-5))
Example #2
    def __init__(self,
                 inner_algo,
                 env,
                 policy,
                 value_function,
                 meta_optimizer,
                 meta_batch_size=40,
                 inner_lr=0.1,
                 outer_lr=1e-3,
                 num_grad_updates=1,
                 meta_evaluator=None,
                 evaluate_every_n_epochs=1):
        if policy.vectorized:
            self.sampler_cls = OnPolicyVectorizedSampler
        else:
            self.sampler_cls = BatchSampler

        self.max_path_length = inner_algo.max_path_length

        self._meta_evaluator = meta_evaluator
        self._policy = policy
        self._env = env
        self._value_function = value_function
        self._num_grad_updates = num_grad_updates
        self._meta_batch_size = meta_batch_size
        self._inner_algo = inner_algo
        self._inner_optimizer = DifferentiableSGD(self._policy, lr=inner_lr)
        self._meta_optimizer = make_optimizer(meta_optimizer,
                                              policy,
                                              lr=_Default(outer_lr),
                                              eps=_Default(1e-5))
        self._evaluate_every_n_epochs = evaluate_every_n_epochs
Example #3
    def __init__(self,
                 inner_algo,
                 env,
                 policy,
                 sampler,
                 task_sampler,
                 meta_optimizer,
                 meta_batch_size=40,
                 inner_lr=0.1,
                 outer_lr=1e-3,
                 num_grad_updates=1,
                 meta_evaluator=None,
                 evaluate_every_n_epochs=1):
        self._sampler = sampler

        self.max_episode_length = inner_algo.max_episode_length

        self._meta_evaluator = meta_evaluator
        self._policy = policy
        self._env = env
        self._task_sampler = task_sampler
        # Deep-copy the inner algorithm's value function and save its
        # initial state so it can be restored later.
        self._value_function = copy.deepcopy(inner_algo._value_function)
        self._initial_vf_state = self._value_function.state_dict()
        self._num_grad_updates = num_grad_updates
        self._meta_batch_size = meta_batch_size
        self._inner_algo = inner_algo
        self._inner_optimizer = DifferentiableSGD(self._policy, lr=inner_lr)
        self._meta_optimizer = make_optimizer(meta_optimizer,
                                              module=policy,
                                              lr=_Default(outer_lr),
                                              eps=_Default(1e-5))
        self._evaluate_every_n_epochs = evaluate_every_n_epochs
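
All three constructors above split optimization into a differentiable inner-loop optimizer (DifferentiableSGD with inner_lr) and a separate meta optimizer built with make_optimizer (outer_lr). The sketch below is a minimal, plain-PyTorch illustration of why the inner step must keep the computation graph; it uses hypothetical quadratic stand-in losses instead of sampled-episode losses and does not call garage's own training loop.

import torch


def maml_step_sketch():
    """Minimal sketch of the inner/outer optimizer split (assumed losses)."""
    policy = torch.nn.Linear(4, 2, bias=False)
    inner_lr, outer_lr = 0.1, 1e-3
    meta_optimizer = torch.optim.Adam(policy.parameters(), lr=outer_lr, eps=1e-5)
    theta = next(policy.parameters())

    # Inner loop: one differentiable SGD step, keeping the graph
    # (this is the role DifferentiableSGD plays above).
    x = torch.randn(8, 4)
    inner_loss = policy(x).pow(2).mean()
    grad, = torch.autograd.grad(inner_loss, theta, create_graph=True)
    theta_prime = theta - inner_lr * grad  # adapted (post-update) parameters

    # Outer loop: evaluate the adapted parameters, then step the meta
    # optimizer on the original (pre-adaptation) parameters.
    outer_loss = (x @ theta_prime.t()).pow(2).mean()
    meta_optimizer.zero_grad()
    outer_loss.backward()
    meta_optimizer.step()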
Example #4
import torch

# Imports assumed from the garage library; exact module paths may vary by
# garage version.
from garage.torch import update_module_params
from garage.torch.optimizers import DifferentiableSGD


def test_differentiable_sgd():
    """Test second order derivative after taking optimization step."""
    policy = torch.nn.Linear(10, 10, bias=False)
    lr = 0.01
    diff_sgd = DifferentiableSGD(policy, lr=lr)

    named_theta = dict(policy.named_parameters())
    theta = list(named_theta.values())[0]
    meta_loss = torch.sum(theta**2)
    # create_graph=True retains the graph so the gradient produced here can
    # itself be differentiated (second-order derivative).
    meta_loss.backward(create_graph=True)

    diff_sgd.step()

    theta_prime = list(policy.parameters())[0]
    loss = torch.sum(theta_prime**2)
    # Restore the original (pre-update) parameters on the module.
    update_module_params(policy, named_theta)
    diff_sgd.zero_grad()
    loss.backward()

    result = theta.grad

    dtheta_prime = 1 - 2 * lr  # dtheta_prime/dtheta
    dloss = 2 * theta_prime  # dloss/dtheta_prime
    expected_result = dloss * dtheta_prime  # dloss/dtheta

    assert torch.allclose(result, expected_result)
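
For this loss, the differentiable step gives theta_prime = theta - lr * 2 * theta = (1 - 2 * lr) * theta, so by the chain rule dloss/dtheta = dloss/dtheta_prime * dtheta_prime/dtheta = 2 * theta_prime * (1 - 2 * lr). The assertion checks that backpropagating through the DifferentiableSGD update reproduces exactly this second-order quantity.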