Example #1
    def train(self, replay_buffer, writer, iterations):
        self.step += 1
        # Step 4: sample a batch of transitions (s, s', a, r) from the replay memory
        # silence prints from replay_buffer.sample (needs the module-level os and sys imports)
        sys.stdout = open(os.devnull, "w")
        obs, action, reward, next_obs, not_done, obs_list, next_obs_list = replay_buffer.sample(
            self.batch_size)
        sys.stdout.close()
        sys.stdout = sys.__stdout__  # restore standard output
        #batch_states, batch_next_states, batch_actions, batch_rewards, batch_dones = replay_buffer.sample(self.batch_size)
        #state_image = torch.Tensor(batch_states).to(self.device).div_(255)
        #next_state = torch.Tensor(batch_next_states).to(self.device).div_(255)
        # create vector
        #reward = torch.Tensor(batch_rewards).to(self.device)
        #done = torch.Tensor(batch_dones).to(self.device)
        obs = obs.div_(255)
        next_obs = next_obs.div_(255)

        state = self.decoder.create_vector(obs)
        detach_state = state.detach()
        next_state = self.decoder.create_vector(next_obs)

        alpha = torch.exp(self.log_alpha)
        with torch.no_grad():
            # Step 5: get the next action and its log-probability from the current policy
            next_action, next_log_pi = self.actor(next_state)
            # compute quantile
            next_z = self.critic_target(next_obs_list, next_action)
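            # TQC: pool the atoms from every critic network, sort them, and drop the
            # top top_quantiles_to_drop atoms to control overestimation bias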
            sorted_z, _ = torch.sort(next_z.reshape(self.batch_size, -1))
            sorted_z_part = sorted_z[:, :self.quantiles_total -
                                     self.top_quantiles_to_drop]

            # get target
            target = reward + not_done * self.discount * (sorted_z_part -
                                                          alpha * next_log_pi)
        #---update critic
        cur_z = self.critic(obs_list, action)
        critic_loss = quantile_huber_loss_f(cur_z, target, self.device)
        self.critic_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        critic_loss.backward()
        self.decoder_optimizer.step()
        self.critic_optimizer.step()
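        # Polyak-average the online critic weights into the target critic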
        for param, target_param in zip(self.critic.parameters(),
                                       self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data +
                                    (1 - self.tau) * target_param.data)

        #---Update policy and alpha
        new_action, log_pi = self.actor(detach_state)
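        # temperature loss nudges log_alpha so the policy entropy tracks target_entropy;
        # the actor maximizes the critic's mean atom value plus the entropy bonus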
        alpha_loss = -self.log_alpha * (log_pi +
                                        self.target_entropy).detach().mean()
        actor_loss = (alpha * log_pi - self.critic(
            obs_list, new_action).mean(2).mean(1, keepdim=True)).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        self.total_it += 1
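
The example above calls a quantile_huber_loss_f helper that is not shown. A minimal sketch of such a helper, following the standard TQC quantile Huber loss and assuming cur_z has shape [batch, n_nets, n_quantiles] and target has shape [batch, n_target_atoms]:

import torch


def quantile_huber_loss_f(quantiles, samples, device):
    # quantiles: [batch, n_nets, n_quantiles]; samples: [batch, n_target_atoms]
    pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
    abs_pairwise_delta = torch.abs(pairwise_delta)
    # Huber loss with threshold kappa = 1
    huber_loss = torch.where(abs_pairwise_delta > 1,
                             abs_pairwise_delta - 0.5,
                             pairwise_delta ** 2 * 0.5)
    n_quantiles = quantiles.shape[2]
    # quantile fraction midpoints: (2i + 1) / (2 * n_quantiles)
    tau = torch.arange(n_quantiles, device=device).float() / n_quantiles + 1 / (2 * n_quantiles)
    # asymmetric weighting of the Huber loss by |tau - 1{delta < 0}|
    loss = (torch.abs(tau[None, None, :, None] - (pairwise_delta < 0).float()) * huber_loss).mean()
    return loss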

Example #2

    def train(self, replay_buffer, writer, iterations):
        self.step += 1
        if self.step % 1000 == 0:
            self.write_tensorboard = 1 - self.write_tensorboard
        for it in range(iterations):
            # Step 4: sample a batch of transitions (s, s', a, r) from the replay memory
            # silence prints from replay_buffer.sample (needs the module-level os and sys imports)
            sys.stdout = open(os.devnull, "w")
            obs, action, reward, next_obs, not_done, obs_aug, obs_next_aug = replay_buffer.sample(
                self.batch_size)
            sys.stdout.close()
            sys.stdout = sys.__stdout__  # restore standard output
            # for augment 1
            obs = obs.div_(255)
            next_obs = next_obs.div_(255)
            state = self.decoder.create_vector(obs)
            detach_state = state.detach()
            next_state = self.target_decoder.create_vector(next_obs)
            # for augment 2
            obs_aug = obs_aug.div_(255)
            next_obs_aug = obs_next_aug.div_(255)
            state_aug = self.decoder.create_vector(obs_aug)
            detach_state_aug = state_aug.detach()
            next_state_aug = self.target_decoder.create_vector(next_obs_aug)

            alpha = torch.exp(self.log_alpha)
            with torch.no_grad():
                # Step 5: Get policy action
                new_next_action, next_log_pi = self.actor(next_state)

                # compute quantile at next state
                next_z = self.target_critic(next_state, new_next_action)
                sorted_z, _ = torch.sort(next_z.reshape(self.batch_size, -1))
                sorted_z_part = sorted_z[:, :self.quantiles_total -
                                         self.top_quantiles_to_drop]
                target = reward + not_done * self.discount * (
                    sorted_z_part - alpha * next_log_pi)

                # repeat the target computation for the augmented observation
                new_next_action_aug, next_log_pi_aug = self.actor(
                    next_state_aug)
                next_z_aug = self.target_critic(next_state_aug,
                                                new_next_action_aug)
                sorted_z_aug, _ = torch.sort(
                    next_z_aug.reshape(self.batch_size, -1))
                sorted_z_part_aug = sorted_z_aug[:, :self.quantiles_total -
                                                 self.top_quantiles_to_drop]
                target_aug = reward + not_done * self.discount * (
                    sorted_z_part_aug - alpha * next_log_pi_aug)

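            # average the Bellman targets from the two augmented views (DrQ-style
            # target regularization) so both critic losses share one target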
            target = (target + target_aug) / 2.
            #---update critic
            cur_z = self.critic(state, action)
            #print("curz shape", cur_z.shape)
            #print("target shape", target.shape)
            critic_loss = quantile_huber_loss_f(cur_z, target, self.device)

            # critic loss on the augmented observation, using the same averaged target
            cur_z_aug = self.critic(state_aug, action)
            critic_loss += quantile_huber_loss_f(cur_z_aug, target,
                                                 self.device)
            self.critic_optimizer.zero_grad()
            self.decoder_optimizer.zero_grad()
            critic_loss.backward()
            self.decoder_optimizer.step()
            self.critic_optimizer.step()

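            # Polyak-average both the critic and the decoder (observation encoder)
            # into their target networks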
            for param, target_param in zip(self.critic.parameters(),
                                           self.target_critic.parameters()):
                target_param.data.copy_(self.tau * param.data +
                                        (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.decoder.parameters(),
                                           self.target_decoder.parameters()):
                target_param.data.copy_(self.tau * param.data +
                                        (1 - self.tau) * target_param.data)

            #---Update policy and alpha
            new_action, log_pi = self.actor(detach_state)
            alpha_loss = -self.log_alpha * (
                log_pi + self.target_entropy).detach().mean()
            actor_loss = (alpha * log_pi -
                          self.critic(detach_state, new_action).mean(2).mean(
                              1, keepdim=True)).mean()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            self.total_it += 1
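
The Polyak target update is written out three times across the two examples. A hypothetical helper, not part of the original code, that factors it out:

import torch
import torch.nn as nn


def soft_update(online: nn.Module, target: nn.Module, tau: float) -> None:
    # Polyak averaging: target <- tau * online + (1 - tau) * target
    with torch.no_grad():
        for param, target_param in zip(online.parameters(), target.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

# Inside train(), the copy loops then reduce to calls such as:
#     soft_update(self.critic, self.target_critic, self.tau)
#     soft_update(self.decoder, self.target_decoder, self.tau)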