Example #1
    def _reconstruction_loss(self, Y, O, R, free_nats, global_prior):
        B, S_pri, MU_pri, STD_pri, S_pos, MU_pos, STD_pos = Y
        # Observation reconstruction loss, summed over feature dims and averaged over batch/time.
        o_loss = F.mse_loss(bottle(self.wm.d_model, (B, S_pos)),
                            O[:, 1:].cuda(),
                            reduction='none').sum(2).mean()
        # Reward prediction loss.
        r_loss = F.mse_loss(bottle(self.wm.r_model, (B, S_pos)),
                            R[:, :-1].cuda(),
                            reduction='none').mean()
        # KL between posterior and prior, lower-bounded by free nats.
        kl_loss = torch.max(
            kl_divergence(Normal(MU_pos, STD_pos),
                          Normal(MU_pri, STD_pri)).sum(2), free_nats).mean()
        if cfg.global_kl_beta != 0:
            # Optional extra KL term against a fixed global prior.
            kl_loss += cfg.global_kl_beta * kl_divergence(
                Normal(MU_pos, STD_pos), global_prior).sum(2).mean()
        return o_loss, r_loss, kl_loss
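
The `bottle` helper used throughout these examples is not shown. A minimal sketch of a common PlaNet-style definition (flatten the two leading batch/time dimensions, apply the module, then restore them) could look like the following; the exact signature used in this codebase is an assumption:

import torch

def bottle(f, x_tuple):
    # Sketch only: apply module `f` to tensors with two leading (batch, time)
    # dims by flattening them, running `f`, then unflattening the output.
    sizes = tuple(x.size() for x in x_tuple)
    flat = (x.reshape(s[0] * s[1], *s[2:]) for x, s in zip(x_tuple, sizes))
    y = f(*flat)
    return y.view(sizes[0][0], sizes[0][1], *y.size()[1:])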
Example #2
    def _train_policy(self, metrics: dict, D: ExperienceReplay, epoch: int,
                      global_prior, free_nats):
        self.wm.eval()
        self.policy.train()
        losses = []
        for _ in tqdm(range(cfg.collect_interval),
                      desc=poem(f"{epoch} Policy Interval"),
                      leave=False):
            # Sample a batch: observations, actions, (rewards unused), non-terminal masks.
            O, A, _, M = D.sample()
            with torch.no_grad():
                # Infer initial beliefs and posterior states with the frozen world model.
                b_0, _, _, _, s_0, _, _ = self.wm.t_model(
                    torch.zeros(cfg.batch_size, cfg.state_size), A[:, :-1],
                    torch.zeros(cfg.batch_size, cfg.belief_size),
                    bottle(self.wm.e_model, (O[:, 1:], )), M[:, :-1])
                b_0 = b_0.view(-1, cfg.belief_size)
                s_0 = s_0.view(cfg.batch_size * (cfg.chunk_size - 1),
                               cfg.state_size)
                # Non-terminal mask; currently unused because the filtering below is disabled.
                m0 = M[:, 1:].reshape(cfg.batch_size *
                                      (cfg.chunk_size - 1)).byte()
                # b_0, s_0 = b_0[m0], s_0[m0]

            T = cfg.planning_horizon + 1
            B, S = [torch.empty(0)] * T, [torch.empty(0)] * T
            B[0], S[0] = b_0, s_0

            # Imagine a trajectory by rolling the transition model forward
            # under actions chosen by the policy.
            for t in range(T - 1):
                A = self.policy(B[t], S[t])  # shadows the sampled actions, which are no longer needed
                b_t, s_t, _, _ = self.wm.t_model(S[t], A.unsqueeze(dim=1),
                                                 B[t])
                B[t + 1], S[t + 1] = b_t.squeeze(dim=1), s_t.squeeze(dim=1)

            # Policy loss: negative mean predicted reward along the imagined trajectory.
            loss = -self.wm.r_model(torch.cat(B, dim=0),
                                    torch.cat(S, dim=0)).mean()

            if cfg.learning_rate_schedule != 0:
                _linearly_ramping_lr(self.plcy_optimizer,
                                     cfg.learning_rate_plcy)

            self.plcy_optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.policy.parameters(),
                                     cfg.grad_clip_norm,
                                     norm_type=2)
            self.plcy_optimizer.step()

            losses.append(loss.item())
        metrics['p_loss'].append(mean(losses))
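
The `_linearly_ramping_lr` helper referenced above is also not shown. A hedged sketch of a linear learning-rate warm-up (the real signature and ramp length are assumptions; `cfg.learning_rate_schedule` is taken as the number of ramp-up steps):

def _linearly_ramping_lr(optimizer, final_lr):
    # Sketch only: raise each parameter group's lr linearly toward final_lr,
    # adding final_lr / cfg.learning_rate_schedule per call and capping at final_lr.
    for group in optimizer.param_groups:
        group['lr'] = min(group['lr'] + final_lr / cfg.learning_rate_schedule, final_lr)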
Example #3
    def _train_worldmodel(self, metrics: dict, D: ExperienceReplay, epoch: int,
                          global_prior, free_nats):
        losses = []
        for _ in tqdm(range(cfg.collect_interval_worm),
                      desc=poem(f"{epoch} Train Interval"),
                      leave=False):
            # self.optimizer.zero_grad()
            O, A, R, M = D.sample()

            # Initial belief and latent state are zero tensors.
            b_0 = torch.zeros(cfg.batch_size, cfg.belief_size)
            s_0 = torch.zeros(cfg.batch_size, cfg.state_size)

            # Y := B, S_pri, MU_pri, STD_pri, S_pos, MU_pos, STD_pos
            Y = self.wm.t_model(s_0, A[:, :-1], b_0,
                                bottle(self.wm.e_model, (O[:, 1:], )),
                                M[:, :-1])

            o_loss, r_loss, kl_loss = self._reconstruction_loss(
                Y, O, R, free_nats, global_prior)

            if cfg.overshooting_kl_beta != 0:
                # Optional latent-overshooting KL regularizer.
                kl_loss += self._latent_overshooting(Y, A, M, free_nats)

            if cfg.learning_rate_schedule != 0:
                self._linearly_ramping_lr(self.wm_optimizer)

            self.wm_optimizer.zero_grad()
            (o_loss + r_loss + kl_loss).backward()
            nn.utils.clip_grad_norm_(self.param_list,
                                     cfg.grad_clip_norm,
                                     norm_type=2)
            self.wm_optimizer.step()

            losses.append([o_loss.item(), r_loss.item(), kl_loss.item()])

        o_loss, r_loss, kl_loss = tuple(zip(*losses))
        metrics['o_loss'].append(mean(o_loss))
        metrics['r_loss'].append(mean(r_loss))
        metrics['kl_loss'].append(mean(kl_loss))
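
The `global_prior` and `free_nats` arguments threaded through these methods are constructed elsewhere. A sketch following the usual PlaNet conventions (a unit Gaussian over the latent state and a constant free-nats floor); the shapes and config field names here are assumptions:

import torch
from torch.distributions import Normal

# Standard-normal prior over the latent state, broadcast across batch and time.
global_prior = Normal(torch.zeros(cfg.state_size), torch.ones(cfg.state_size))
# Constant floor for the KL term ("free nats"), broadcast against the summed KL.
free_nats = torch.full((1,), cfg.free_nats, dtype=torch.float32)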