Example #1
    def train_model(self, observations_tensor, ext_returns_tensor,
                    int_returns_tensor, actions_tensor, advantages_tensor,
                    one_channel_observations_tensor, old_log_prob):

        if flag.DEBUG:
            print("input observations shape", observations_tensor.shape)
            print("ext returns shape", ext_returns_tensor.shape)
            print("int returns shape", int_returns_tensor.shape)
            print("input actions shape", actions_tensor.shape)
            print("input advantages shape", advantages_tensor.shape)
            print("one channel observations",
                  one_channel_observations_tensor.shape)

        self.new_model.train()
        self.predictor_model.train()
        target_value = self.target_model(one_channel_observations_tensor)
        predictor_value = self.predictor_model(one_channel_observations_tensor)
        predictor_loss = self.predictor_mse_loss(predictor_value,
                                                 target_value).mean(-1)

        mask = torch.rand(len(predictor_loss)).to(self.device)
        mask = (mask < self.predictor_update_proportion).type(
            torch.FloatTensor).to(self.device)
        predictor_loss = (predictor_loss * mask).sum() / torch.max(
            mask.sum(),
            torch.Tensor([1]).to(self.device))
        new_policy, ext_new_values, int_new_values = self.new_model(
            observations_tensor)
        ext_value_loss = self.mse_loss(ext_new_values, ext_returns_tensor)
        int_value_loss = self.mse_loss(int_new_values, int_returns_tensor)
        value_loss = ext_value_loss + int_value_loss
        softmax_policy = F.softmax(new_policy, dim=1)
        new_dist = Categorical(softmax_policy)
        new_log_prob = new_dist.log_prob(actions_tensor)

        ratio = torch.exp(new_log_prob - old_log_prob)

        clipped_policy_loss = torch.clamp(
            ratio, 1.0 - self.clip_range,
            1.0 + self.clip_range) * advantages_tensor
        policy_loss = ratio * advantages_tensor

        selected_policy_loss = -torch.min(clipped_policy_loss,
                                          policy_loss).mean()
        entropy = new_dist.entropy().mean()
        self.optimizer.zero_grad()

        loss = selected_policy_loss + (self.value_coef * value_loss) \
            - (self.entropy_coef * entropy) + predictor_loss
        loss.backward()

        global_grad_norm_(
            list(self.new_model.parameters()) +
            list(self.predictor_model.parameters()))

        self.optimizer.step()
        return loss, selected_policy_loss, value_loss, predictor_loss, entropy
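
These snippets are class methods copied verbatim; they assume the usual imports (torch, torch.nn as nn, torch.nn.functional as F, numpy as np, torch.distributions.Categorical) plus repo-specific helpers such as flag, global_grad_norm_ and pytorch_ssim. The following is a minimal, standalone sketch (not part of the example above) of the proportional-update mask applied to the predictor loss; the names are illustrative and torch.clamp is just an equivalent way to express the max-with-1 guard used above.

import torch

# Stand-in per-sample predictor loss and update fraction (illustrative values).
per_sample_loss = torch.rand(8)
update_proportion = 0.25

# Keep a random ~update_proportion of the samples and average over the number
# actually kept (at least 1, so an all-zero mask does not divide by zero).
mask = (torch.rand_like(per_sample_loss) < update_proportion).float()
masked_loss = (per_sample_loss * mask).sum() / torch.clamp(mask.sum(), min=1.0)
print(masked_loss)
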
    def train_just_vae(self, s_batch, next_obs_batch):
        s_batch = torch.FloatTensor(s_batch).to(self.device)
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

        sample_range = np.arange(len(s_batch))
        reconstruction_loss = nn.MSELoss(reduction='none')

        recon_losses = np.array([])
        kld_losses = np.array([])

        for i in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(int(len(s_batch) / self.batch_size)):
                sample_idx = sample_range[self.batch_size * j:self.batch_size *
                                          (j + 1)]

                # --------------------------------------------------------------------------------
                # for generative curiosity (VAE loss)
                gen_next_state, mu, logvar = self.vae(
                    next_obs_batch[sample_idx])

                # d is only used by the commented-out MSE reconstruction below
                d = len(gen_next_state.shape)
                recon_loss = -1 * pytorch_ssim.ssim(gen_next_state,
                                                    next_obs_batch[sample_idx],
                                                    size_average=False)
                # recon_loss = reconstruction_loss(gen_next_state, next_obs_batch[sample_idx]).mean(axis=list(range(1, d)))

                kld_loss = -0.5 * (1 + logvar - mu.pow(2) -
                                   logvar.exp()).sum(axis=1)

                # TODO: keep this proportion of experience used for VAE update?
                # Proportion of experience used for VAE update
                mask = torch.rand(len(recon_loss)).to(self.device)
                mask = (mask < self.update_proportion).type(
                    torch.FloatTensor).to(self.device)
                recon_loss = (recon_loss * mask).sum() / torch.max(
                    mask.sum(),
                    torch.Tensor([1]).to(self.device))
                kld_loss = (kld_loss * mask).sum() / torch.max(
                    mask.sum(),
                    torch.Tensor([1]).to(self.device))

                recon_losses = np.append(recon_losses,
                                         recon_loss.detach().cpu().numpy())
                kld_losses = np.append(kld_losses,
                                       kld_loss.detach().cpu().numpy())
                # ---------------------------------------------------------------------------------

                self.optimizer.zero_grad()
                loss = recon_loss + kld_loss
                loss.backward()
                global_grad_norm_(list(self.vae.parameters()))
                self.optimizer.step()

        return recon_losses, kld_losses
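
For reference, the kld_loss term above is the closed-form KL divergence between a diagonal Gaussian posterior N(mu, exp(logvar)) and a standard normal prior, -0.5 * sum(1 + logvar - mu^2 - exp(logvar)) per sample. A standalone check with dummy tensors:

import torch

mu = torch.zeros(4, 16)      # dummy batch of 4 latent means (16-dim latent)
logvar = torch.zeros(4, 16)  # dummy log-variances; all zeros gives KL = 0

kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=1)
print(kld)  # tensor([0., 0., 0., 0.])
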
Example #3
    def train_model(self, s_batch, target_ext_batch, target_int_batch, y_batch,
                    adv_batch, next_obs_batch, old_policy):
        s_batch = torch.FloatTensor(s_batch).to(self.device)
        target_ext_batch = torch.FloatTensor(target_ext_batch).to(self.device)
        target_int_batch = torch.FloatTensor(target_int_batch).to(self.device)
        y_batch = torch.LongTensor(y_batch).to(self.device)
        adv_batch = torch.FloatTensor(adv_batch).to(self.device)
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

        sample_range = np.arange(len(s_batch))
        forward_mse = nn.MSELoss(reduction='none')

        # Get old policy
        with torch.no_grad():
            policy_old_list = torch.stack(old_policy).permute(
                1, 0, 2).contiguous().view(-1,
                                           self.output_size).to(self.device)

            m_old = Categorical(F.softmax(policy_old_list, dim=-1))
            log_prob_old = m_old.log_prob(y_batch)
            # ------------------------------------------------------------

        for i in range(self.epoch):
            # Here we'll do minibatches of training
            np.random.shuffle(sample_range)
            for j in range(int(len(s_batch) / self.batch_size)):
                sample_idx = sample_range[self.batch_size * j:self.batch_size *
                                          (j + 1)]

                # --------------------------------------------------------------------------------
                # for curiosity-driven exploration (Random Network Distillation)
                predict_next_state_feature, target_next_state_feature = self.rnd(
                    next_obs_batch[sample_idx])

                forward_loss = forward_mse(
                    predict_next_state_feature,
                    target_next_state_feature.detach()).mean(-1)
                # Proportion of exp used for predictor update
                mask = torch.rand(len(forward_loss)).to(self.device)
                mask = (mask < self.update_proportion).type(
                    torch.FloatTensor).to(self.device)
                forward_loss = (forward_loss * mask).sum() / torch.max(
                    mask.sum(),
                    torch.Tensor([1]).to(self.device))
                # ---------------------------------------------------------------------------------

                policy, value_ext, value_int = self.model(s_batch[sample_idx])
                m = Categorical(F.softmax(policy, dim=-1))
                log_prob = m.log_prob(y_batch[sample_idx])

                ratio = torch.exp(log_prob - log_prob_old[sample_idx])

                surr1 = ratio * adv_batch[sample_idx]
                surr2 = torch.clamp(ratio, 1.0 - self.ppo_eps,
                                    1.0 + self.ppo_eps) * adv_batch[sample_idx]

                # Calculate actor loss
                # minimizing -J is equivalent to maximizing J, hence the minus sign
                actor_loss = -torch.min(surr1, surr2).mean()

                # Calculate critic loss
                critic_ext_loss = F.mse_loss(value_ext.sum(1),
                                             target_ext_batch[sample_idx])
                critic_int_loss = F.mse_loss(value_int.sum(1),
                                             target_int_batch[sample_idx])

                # Critic loss = critic E loss + critic I loss
                critic_loss = critic_ext_loss + critic_int_loss

                # Calculate the entropy
                # Entropy improves exploration by discouraging premature convergence to a suboptimal policy.
                entropy = m.entropy().mean()

                # Reset the gradients
                self.optimizer.zero_grad()

                # CALCULATE THE LOSS
                # Total loss = actor loss + value coefficient * critic loss - entropy coefficient * entropy + forward loss
                loss = actor_loss + 0.5 * critic_loss - self.ent_coef * entropy + forward_loss

                # Backpropagation
                loss.backward()
                global_grad_norm_(
                    list(self.model.parameters()) +
                    list(self.rnd.predictor.parameters()))
                self.optimizer.step()
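
A standalone sketch of the old_policy flattening performed above. It assumes old_policy is a list with one entry per rollout step, each of shape (num_env, output_size); the permute/view groups all steps of one environment together before flattening.

import torch

num_step, num_env, output_size = 5, 3, 4  # illustrative sizes
old_policy = [torch.randn(num_env, output_size) for _ in range(num_step)]

# (num_step, num_env, A) -> (num_env, num_step, A) -> (num_env * num_step, A)
flat = torch.stack(old_policy).permute(1, 0, 2).contiguous().view(-1, output_size)
print(flat.shape)  # torch.Size([15, 4])
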
    def train_model(self, s_batch, target_ext_batch, target_int_batch, y_batch,
                    adv_batch, next_obs_batch, old_policy):
        s_batch = torch.FloatTensor(s_batch).to(self.device)
        target_ext_batch = torch.FloatTensor(target_ext_batch).to(self.device)
        target_int_batch = torch.FloatTensor(target_int_batch).to(self.device)
        y_batch = torch.LongTensor(y_batch).to(self.device)
        adv_batch = torch.FloatTensor(adv_batch).to(self.device)
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

        sample_range = np.arange(len(s_batch))
        forward_mse = nn.MSELoss(reduction='none')

        with torch.no_grad():
            policy_old_list = torch.stack(old_policy).permute(
                1, 0, 2).contiguous().view(-1,
                                           self.output_size).to(self.device)

            m_old = Categorical(F.softmax(policy_old_list, dim=-1))
            log_prob_old = m_old.log_prob(y_batch)
            # ------------------------------------------------------------

        for i in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(int(len(s_batch) / self.batch_size)):
                sample_idx = sample_range[self.batch_size * j:self.batch_size *
                                          (j + 1)]

                # --------------------------------------------------------------------------------
                # for curiosity-driven exploration (Random Network Distillation)
                predict_next_state_feature, target_next_state_feature = self.rnd(
                    next_obs_batch[sample_idx])

                forward_loss = forward_mse(
                    predict_next_state_feature,
                    target_next_state_feature.detach()).mean(-1)
                # Proportion of exp used for predictor update
                mask = torch.rand(len(forward_loss)).to(self.device)
                mask = (mask < self.update_proportion).type(
                    torch.FloatTensor).to(self.device)
                forward_loss = (forward_loss * mask).sum() / torch.max(
                    mask.sum(),
                    torch.Tensor([1]).to(self.device))
                # ---------------------------------------------------------------------------------

                policy, value_ext, value_int = self.model(s_batch[sample_idx])
                m = Categorical(F.softmax(policy, dim=-1))
                log_prob = m.log_prob(y_batch[sample_idx])

                ratio = torch.exp(log_prob - log_prob_old[sample_idx])

                surr1 = ratio * adv_batch[sample_idx]
                surr2 = torch.clamp(ratio, 1.0 - self.ppo_eps,
                                    1.0 + self.ppo_eps) * adv_batch[sample_idx]

                actor_loss = -torch.min(surr1, surr2).mean()
                critic_ext_loss = F.mse_loss(value_ext.sum(1),
                                             target_ext_batch[sample_idx])
                critic_int_loss = F.mse_loss(value_int.sum(1),
                                             target_int_batch[sample_idx])

                critic_loss = critic_ext_loss + critic_int_loss

                entropy = m.entropy().mean()

                self.optimizer.zero_grad()
                loss = actor_loss + 0.5 * critic_loss - self.ent_coef * entropy + forward_loss
                loss.backward()
                global_grad_norm_(
                    list(self.model.parameters()) +
                    list(self.rnd.predictor.parameters()))
                self.optimizer.step()
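
A standalone sketch, with toy numbers only, of the clipped PPO surrogate that these train_model variants compute:

import torch

log_prob_new = torch.tensor([-0.9, -1.2, -0.3])
log_prob_old = torch.tensor([-1.0, -1.0, -1.0])
advantages = torch.tensor([1.0, -0.5, 2.0])
ppo_eps = 0.1

ratio = torch.exp(log_prob_new - log_prob_old)
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1.0 - ppo_eps, 1.0 + ppo_eps) * advantages

# Pessimistic bound, negated so gradient descent maximizes the objective.
actor_loss = -torch.min(surr1, surr2).mean()
print(actor_loss)
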
    def train_model(self, s_batch, target_ext_batch, target_int_batch, y_batch,
                    adv_batch, next_obs_batch, old_policy):
        s_batch = torch.FloatTensor(s_batch).to(self.device)
        target_ext_batch = torch.FloatTensor(target_ext_batch).to(self.device)
        target_int_batch = torch.FloatTensor(target_int_batch).to(self.device)
        y_batch = torch.LongTensor(y_batch).to(self.device)
        adv_batch = torch.FloatTensor(adv_batch).to(self.device)
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

        sample_range = np.arange(len(s_batch))
        reconstruction_loss = nn.MSELoss(reduction='none')

        with torch.no_grad():
            policy_old_list = torch.stack(old_policy).permute(
                1, 0, 2).contiguous().view(-1,
                                           self.output_size).to(self.device)

            m_old = Categorical(F.softmax(policy_old_list, dim=-1))
            log_prob_old = m_old.log_prob(y_batch)
            # ------------------------------------------------------------

        recon_losses = np.array([])
        kld_losses = np.array([])

        for i in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(int(len(s_batch) / self.batch_size)):
                sample_idx = sample_range[self.batch_size * j:self.batch_size *
                                          (j + 1)]

                # --------------------------------------------------------------------------------
                # for generative curiosity (VAE loss)
                gen_next_state, mu, logvar = self.vae(
                    next_obs_batch[sample_idx])

                d = len(gen_next_state.shape)
                recon_loss = reconstruction_loss(
                    gen_next_state,
                    next_obs_batch[sample_idx]).mean(axis=list(range(1, d)))

                kld_loss = -0.5 * (1 + logvar - mu.pow(2) -
                                   logvar.exp()).sum(axis=1)

                # TODO: keep this proportion of experience used for VAE update?
                # Proportion of experience used for VAE update
                mask = torch.rand(len(recon_loss)).to(self.device)
                mask = (mask < self.update_proportion).type(
                    torch.FloatTensor).to(self.device)
                recon_loss = (recon_loss * mask).sum() / torch.max(
                    mask.sum(),
                    torch.Tensor([1]).to(self.device))
                kld_loss = (kld_loss * mask).sum() / torch.max(
                    mask.sum(),
                    torch.Tensor([1]).to(self.device))

                recon_losses = np.append(recon_losses,
                                         recon_loss.detach().cpu().numpy())
                kld_losses = np.append(kld_losses,
                                       kld_loss.detach().cpu().numpy())
                # ---------------------------------------------------------------------------------

                policy, value_ext, value_int = self.model(s_batch[sample_idx])
                m = Categorical(F.softmax(policy, dim=-1))
                log_prob = m.log_prob(y_batch[sample_idx])

                ratio = torch.exp(log_prob - log_prob_old[sample_idx])

                surr1 = ratio * adv_batch[sample_idx]
                surr2 = torch.clamp(ratio, 1.0 - self.ppo_eps,
                                    1.0 + self.ppo_eps) * adv_batch[sample_idx]

                actor_loss = -torch.min(surr1, surr2).mean()
                critic_ext_loss = F.mse_loss(value_ext.sum(1),
                                             target_ext_batch[sample_idx])
                critic_int_loss = F.mse_loss(value_int.sum(1),
                                             target_int_batch[sample_idx])

                critic_loss = critic_ext_loss + critic_int_loss

                entropy = m.entropy().mean()

                self.optimizer.zero_grad()
                loss = actor_loss + 0.5 * critic_loss - self.ent_coef * entropy + recon_loss + kld_loss
                loss.backward()
                global_grad_norm_(
                    list(self.model.parameters()) +
                    list(self.vae.parameters()))
                self.optimizer.step()

        return recon_losses, kld_losses
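
All of the loops above share the same minibatching pattern: shuffle an index array once per epoch, then slice it into consecutive chunks of batch_size. A standalone sketch with illustrative sizes:

import numpy as np

batch_size, n_samples = 4, 10
sample_range = np.arange(n_samples)

np.random.shuffle(sample_range)
for j in range(n_samples // batch_size):  # any remainder is dropped
    sample_idx = sample_range[batch_size * j:batch_size * (j + 1)]
    print(j, sample_idx)
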
Example #6
    def update(self, o, a, r_i, r_e, mask, o_, log_prob):
        self.normalizer_obs.update(o_.reshape(-1, 4, 84, 84).copy())
        self.normalizer_ri.update(r_i.reshape(-1).copy())

        r_i = self.normalizer_ri.normalize(r_i)
        o_ = self.normalizer_obs.normalize(o_)
        o = torch.from_numpy(o).to(self.device).float() / 255.

        returns_ex = np.zeros_like(r_e)
        returns_in = np.zeros_like(r_e)
        advantage_ex = np.zeros_like(r_e)
        advantage_in = np.zeros_like(r_e)
        for i in range(r_e.shape[0]):
            action_logits, value_ex, value_in = self.actor_critic(o[i])
            value_ex = value_ex.cpu().detach().numpy()
            value_in = value_in.cpu().detach().numpy()
            returns_ex[i], _, advantage_ex[i] = self.GAE_caculate(
                r_e[i], mask[i], value_ex, self.gamma_e, self.lamda)  #episodic
            returns_in[i], _, advantage_in[i] = self.GAE_caculate(
                r_i[i], np.ones_like(mask[i]), value_in, self.gamma_i,
                self.lamda)  #non_episodic

        o = o.reshape((-1, 4, 84, 84))
        a = np.reshape(a, -1)
        o_ = np.reshape(o_[:, :, 3, :, :], (-1, 1, 84, 84))
        log_prob = np.reshape(log_prob, -1)
        returns_ex = np.reshape(returns_ex, -1)
        returns_in = np.reshape(returns_in, -1)
        advantage_ex = np.reshape(advantage_ex, -1)
        advantage_in = np.reshape(advantage_in, -1)

        a = torch.from_numpy(a).float().to(self.device)
        o_ = torch.from_numpy(o_).float().to(self.device)
        log_prob = torch.from_numpy(log_prob).float().to(self.device)
        returns_ex = torch.from_numpy(returns_ex).float().to(
            self.device).unsqueeze(dim=1)
        returns_in = torch.from_numpy(returns_in).float().to(
            self.device).unsqueeze(dim=1)
        advantage_ex = torch.from_numpy(advantage_ex).float().to(self.device)
        advantage_in = torch.from_numpy(advantage_in).float().to(self.device)

        sample_range = list(range(len(o)))

        for i_update in range(self.update_epoch):
            np.random.shuffle(sample_range)
            for j in range(int(len(o) / self.batch_size)):
                idx = sample_range[self.batch_size * j:self.batch_size *
                                   (j + 1)]
                #update RND
                pred_RND, tar_RND = self.RND(o_[idx])
                loss_RND = F.mse_loss(pred_RND,
                                      tar_RND.detach(),
                                      reduction='none').mean(-1)
                # uniform in [0, 1) so roughly update_proportion of the samples are kept
                mask = torch.rand(len(loss_RND)).to(self.device)
                mask = (mask < self.update_proportion).type(
                    torch.FloatTensor).to(self.device)
                loss_RND = (loss_RND * mask).sum() / torch.max(
                    mask.sum(),
                    torch.Tensor([1]).to(self.device))

                #update actor-critic
                action_logits, value_ex, value_in = self.actor_critic(o[idx])
                advantage = self.ex_coef * advantage_ex[
                    idx] + self.in_coef * advantage_in[idx]
                dist = Categorical(action_logits)
                new_log_prob = dist.log_prob(a[idx])

                ratio = torch.exp(new_log_prob - log_prob[idx])
                surr1 = ratio * advantage
                surr2 = torch.clamp(ratio, 1 - self.clip_eps,
                                    1 + self.clip_eps) * advantage
                # Negate the clipped surrogate: minimizing -J maximizes the PPO objective.
                loss_actor = -torch.min(surr1, surr2).mean() \
                    - self.entropy_coef * dist.entropy().mean()
                loss_critic = F.mse_loss(value_ex,
                                         returns_ex[idx]) + F.mse_loss(
                                             value_in, returns_in[idx])

                loss_ac = loss_actor + 0.5 * loss_critic

                loss = loss_RND + loss_ac
                self.optimizer.zero_grad()
                loss.backward()
                global_grad_norm_(
                    list(self.actor_critic.parameters()) +
                    list(self.RND.predictor.parameters()))
                self.optimizer.step()

        return (loss_RND.cpu().detach().numpy(),
                loss_actor.cpu().detach().numpy(),
                loss_critic.cpu().detach().numpy())
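
GAE_caculate is not shown in this example, so the following is only a hypothetical sketch of a generalized advantage estimation pass under common assumptions (a bootstrap value appended to values, mask set to 0 at terminal steps); the real method's signature and return values may differ.

import numpy as np

def gae_sketch(rewards, masks, values, gamma, lam):
    # values has one more entry than rewards (bootstrap value for the last state).
    advantages = np.zeros_like(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] * masks[t] - values[t]
        gae = delta + gamma * lam * masks[t] * gae
        advantages[t] = gae
    returns = advantages + values[:-1]
    return returns, advantages

rewards = np.array([1.0, 0.0, 1.0])
masks = np.array([1.0, 1.0, 0.0])        # 0 marks a terminal step
values = np.array([0.5, 0.4, 0.6, 0.0])  # bootstrap value appended
print(gae_sketch(rewards, masks, values, gamma=0.99, lam=0.95))
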