Example #1
    def learn(self):
        self.learning_steps += 1
        self.online_net.sample_noise()
        self.target_net.sample_noise()

        if self.use_per:
            (states, actions, rewards, next_states, dones), weights = \
                self.memory.sample(self.batch_size)
        else:
            states, actions, rewards, next_states, dones = \
                self.memory.sample(self.batch_size)
            weights = None

        quantile_loss, mean_q, errors = self.calculate_loss(
            states, actions, rewards, next_states, dones, weights)

        update_params(
            self.optim, quantile_loss,
            networks=[self.online_net],
            retain_graph=False, grad_cliping=self.grad_cliping)

        if self.use_per:
            self.memory.update_priority(errors)

        if 4 * self.steps % self.log_interval == 0:
            self.writer.add_scalar(
                'loss/quantile_loss', quantile_loss.detach().item(),
                4 * self.steps)
            self.writer.add_scalar('stats/mean_Q', mean_q, 4 * self.steps)
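All of the examples on this page pass the computed loss to update_params from the same repository. Below is a minimal sketch of what such a helper might look like; the signature is taken from the calls above, but the body (zeroing gradients, backpropagating, optionally clipping gradient norms, stepping the optimizer) is an assumption, not the repository's verbatim implementation:

import torch


def update_params(optim, loss, networks, retain_graph=False,
                  grad_cliping=None):
    # Reset accumulated gradients before backpropagation.
    optim.zero_grad()
    # Backpropagate the loss; retain_graph=True keeps the graph alive
    # when a second backward pass follows (as in Example #3).
    loss.backward(retain_graph=retain_graph)
    # Optionally clip the gradient norms of the listed networks.
    if grad_cliping:
        for net in networks:
            torch.nn.utils.clip_grad_norm_(net.parameters(), grad_cliping)
    # Apply the gradient step.
    optim.step()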
Example #2
    def learn(self):
        self.learning_steps += 1
        self.online_net.sample_noise()
        self.target_net.sample_noise()

        if self.use_per:
            (states, actions, rewards, next_states, dones), weights = \
                self.memory.sample(self.batch_size)
        else:
            states, actions, rewards, next_states, dones = \
                self.memory.sample(self.batch_size)
            weights = None

        dmog_loss = self.calculate_loss(states, actions, rewards, next_states,
                                        dones, weights)

        update_params(self.optim,
                      dmog_loss,
                      networks=[self.online_net],
                      retain_graph=False,
                      grad_cliping=self.grad_cliping)
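When use_per is enabled, memory.sample also returns importance-sampling weights that calculate_loss is expected to apply per sample. A rough sketch of that weighting pattern, using hypothetical names (per_sample_loss is not shown in Example #2 and is an assumption here):

import torch


def weighted_loss(per_sample_loss, weights=None):
    # per_sample_loss: shape (batch_size, 1); weights: same shape or None.
    if weights is None:
        # Uniform replay: plain mean over the batch.
        return per_sample_loss.mean()
    # Prioritized replay: scale each sample by its importance-sampling
    # weight before averaging, which corrects the sampling bias.
    return (per_sample_loss * weights).mean()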
Example #3
    def learn(self):
        self.learning_steps += 1
        self.online_net.sample_noise()
        self.target_net.sample_noise()

        if self.use_per:
            (states, actions, rewards, next_states, dones), weights =\
                self.memory.sample(self.batch_size)
        else:
            states, actions, rewards, next_states, dones =\
                self.memory.sample(self.batch_size)
            weights = None

        # Calculate embeddings of current states.
        state_embeddings = self.online_net.calculate_state_embeddings(states)

        # Calculate fractions of current states and entropies.
        taus, tau_hats, entropies =\
            self.online_net.calculate_fractions(
                state_embeddings=state_embeddings)

        # Calculate quantile values of current states and actions at tau_hats.
        current_sa_quantile_hats = evaluate_quantile_at_action(
            self.online_net.calculate_quantiles(
                tau_hats, state_embeddings=state_embeddings), actions)
        assert current_sa_quantile_hats.shape == (self.batch_size, self.N, 1)

        # NOTE: Detach state_embeddings so the convolution layers are not
        # updated by the fraction loss. Also detach current_sa_quantile_hats,
        # because the gradients of taus are calculated explicitly rather than
        # by backpropagation.
        fraction_loss = self.calculate_fraction_loss(
            state_embeddings.detach(), current_sa_quantile_hats.detach(), taus,
            actions, weights)

        quantile_loss, mean_q, errors = self.calculate_quantile_loss(
            state_embeddings, tau_hats, current_sa_quantile_hats, actions,
            rewards, next_states, dones, weights)

        entropy_loss = -self.ent_coef * entropies.mean()

        update_params(self.fraction_optim,
                      fraction_loss + entropy_loss,
                      networks=[self.online_net.fraction_net],
                      retain_graph=True,
                      grad_cliping=self.grad_cliping)
        update_params(self.quantile_optim,
                      quantile_loss + entropy_loss,
                      networks=[
                          self.online_net.dqn_net, self.online_net.cosine_net,
                          self.online_net.quantile_net
                      ],
                      retain_graph=False,
                      grad_cliping=self.grad_cliping)

        if self.use_per:
            self.memory.update_priority(errors)

        if self.learning_steps % self.log_interval == 0:
            self.writer.add_scalar('loss/fraction_loss',
                                   fraction_loss.detach().item(),
                                   4 * self.steps)
            self.writer.add_scalar('loss/quantile_loss',
                                   quantile_loss.detach().item(),
                                   4 * self.steps)
            if self.ent_coef > 0.0:
                self.writer.add_scalar('loss/entropy_loss',
                                       entropy_loss.detach().item(),
                                       4 * self.steps)

            self.writer.add_scalar('stats/mean_Q', mean_q, 4 * self.steps)
            self.writer.add_scalar('stats/mean_entropy_of_value_distribution',
                                   entropies.mean().detach().item(),
                                   4 * self.steps)
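Example #3 relies on evaluate_quantile_at_action to pick out, for each transition, the quantile values of the action that was actually taken. A plausible sketch of such a helper, assuming quantile values of shape (batch_size, N, num_actions) and integer actions of shape (batch_size, 1); this illustrates the gather pattern and is not necessarily the repository's exact code:

import torch


def evaluate_quantile_at_action(s_quantiles, actions):
    # s_quantiles: (batch_size, N, num_actions); actions: (batch_size, 1).
    batch_size, N = s_quantiles.shape[:2]
    # Expand actions to (batch_size, N, 1) so they index the action axis.
    action_index = actions[..., None].expand(batch_size, N, 1)
    # Gather the quantile values of the chosen action at every tau,
    # giving shape (batch_size, N, 1) as asserted in Example #3.
    return s_quantiles.gather(dim=2, index=action_index)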