    def optim_initialize(self, rank=0):
        self.rank = rank
        model = self.agent.model
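        # Collect the world-model components (encoder/decoder, reward, dynamics)
        # so that model, actor, and value networks each get their own optimizer.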
        self.model_modules = [
            model.observation_encoder, model.observation_decoder,
            model.reward_model, model.representation, model.transition
        ]
        if self.use_pcont:
            self.model_modules += [model.pcont]
        self.actor_modules = [model.action_decoder]
        self.value_modules = [model.value_model]
        self.model_optimizer = torch.optim.Adam(
            get_parameters(self.model_modules),
            lr=self.model_lr, **self.optim_kwargs)
        self.actor_optimizer = torch.optim.Adam(
            get_parameters(self.actor_modules),
            lr=self.actor_lr, **self.optim_kwargs)
        self.value_optimizer = torch.optim.Adam(
            get_parameters(self.value_modules),
            lr=self.value_lr, **self.optim_kwargs)

        if self.initial_optim_state_dict is not None:
            self.load_optim_state_dict(self.initial_optim_state_dict)
        # Must define these fields for logging purposes; used by the runner.
        self.opt_info_fields = OptInfo._fields
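
When resuming, optim_initialize hands the saved state to load_optim_state_dict. A minimal sketch of such a method, assuming the checkpoint simply stores one state dict per optimizer (the actual key layout is project-specific):

    def load_optim_state_dict(self, state_dict):
        # Hypothetical keys: restore each optimizer from its saved state.
        self.model_optimizer.load_state_dict(state_dict['model_optimizer'])
        self.actor_optimizer.load_state_dict(state_dict['actor_optimizer'])
        self.value_optimizer.load_state_dict(state_dict['value_optimizer'])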
Example #2
import torch.nn as nn

# The import path is an assumption; get_parameters lives in the project's utils module.
from dreamer.utils.module import get_parameters


def test_get_parameters():
    linear_module_1 = nn.Linear(4, 3)
    linear_module_2 = nn.Linear(3, 2)

    params = get_parameters([linear_module_1])
    assert len(params) == 2

    params = get_parameters([linear_module_1, linear_module_2])
    assert len(params) == 4
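
The test pins down the contract of get_parameters: it flattens the parameters of a list of modules into one list, so a single nn.Linear contributes its weight and bias (2 parameters) and two Linear layers contribute 4. A minimal sketch consistent with that contract, assuming the project defines it along these lines:

def get_parameters(modules):
    """Collect the parameters of several modules into a single flat list,
    e.g. to build one optimizer or clip gradients over a group of modules."""
    params = []
    for module in modules:
        params += list(module.parameters())
    return params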
Example #3
    def optimize_agent(self, itr, samples=None, sampler_itr=None):
        itr = itr if sampler_itr is None else sampler_itr
        if samples is not None:
            # Note: discount not saved here
            self.replay_buffer.append_samples(samples_to_buffer(samples))

        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
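        # Skip optimization until the prefill period has passed, and then only
        # optimize every train_every iterations.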
        if itr < self.prefill:
            return opt_info
        if itr % self.train_every != 0:
            return opt_info
        for i in tqdm(range(self.train_steps), desc='Imagination'):

            samples_from_replay = self.replay_buffer.sample_batch(
                self._batch_size, self.batch_length)
            buffed_samples = buffer_to(samples_from_replay, self.agent.device)
            model_loss, actor_loss, value_loss, loss_info = self.loss(
                buffed_samples, itr, i)

            self.model_optimizer.zero_grad()
            self.actor_optimizer.zero_grad()
            self.value_optimizer.zero_grad()

            model_loss.backward()
            actor_loss.backward()
            value_loss.backward()

            grad_norm_model = torch.nn.utils.clip_grad_norm_(
                get_parameters(self.model_modules), self.grad_clip)
            grad_norm_actor = torch.nn.utils.clip_grad_norm_(
                get_parameters(self.actor_modules), self.grad_clip)
            grad_norm_value = torch.nn.utils.clip_grad_norm_(
                get_parameters(self.value_modules), self.grad_clip)

            self.model_optimizer.step()
            self.actor_optimizer.step()
            self.value_optimizer.step()

            with torch.no_grad():
                loss = model_loss + actor_loss + value_loss
            opt_info.loss.append(loss.item())
            if isinstance(grad_norm_model, torch.Tensor):
                opt_info.grad_norm_model.append(grad_norm_model.item())
                opt_info.grad_norm_actor.append(grad_norm_actor.item())
                opt_info.grad_norm_value.append(grad_norm_value.item())
            else:
                opt_info.grad_norm_model.append(grad_norm_model)
                opt_info.grad_norm_actor.append(grad_norm_actor)
                opt_info.grad_norm_value.append(grad_norm_value)
            for field in loss_info_fields:
                if hasattr(opt_info, field) \
                        and getattr(loss_info, field) is not None:
                    getattr(opt_info, field).append(
                        getattr(loss_info, field).item())

        return opt_info
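
Both snippets assume module-level OptInfo and loss_info_fields definitions, which the logging loop above iterates over. A plausible sketch, with hypothetical field names standing in for the project's actual diagnostics:

from collections import namedtuple

# Hypothetical diagnostic names; the real project defines its own list.
loss_info_fields = ['model_loss', 'actor_loss', 'value_loss',
                    'prior_entropy', 'post_entropy']
OptInfo = namedtuple('OptInfo', ['loss', 'grad_norm_model', 'grad_norm_actor',
                                 'grad_norm_value'] + loss_info_fields)
LossInfo = namedtuple('LossInfo', loss_info_fields)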