Example 1
    def update_q_functions(self,
                           batch,
                           writer,
                           imp_ws1=None,
                           imp_ws2=None,
                           fast_batch=None):
        states, actions, rewards, next_states, dones, *_ = batch

        # Calculate current and target Q values.
        curr_qs1, curr_qs2 = self.calc_current_qs(states, actions)
        target_qs = self.calc_target_qs(rewards, next_states, dones)

        # Update Q functions.
        q_loss, mean_q1, mean_q2, unweighted_q_loss = \
            self.calc_q_loss(curr_qs1, curr_qs2, target_qs, imp_ws1, imp_ws2)
        update_params(self._q_optim, q_loss)

        # TODO: compute the Q loss for the online batch.

        if self._learning_steps % self._log_interval == 0:
            writer.add_scalar('loss/Q',
                              unweighted_q_loss.detach().item(),
                              self._learning_steps)
            writer.add_scalar('stats/mean_Q1', mean_q1, self._learning_steps)
            writer.add_scalar('stats/mean_Q2', mean_q2, self._learning_steps)

        # Return these values for the DisCor algorithm.
        return curr_qs1.detach(), curr_qs2.detach(), target_qs
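
The update_params() helper called throughout these examples is not shown. Below is a minimal sketch, assuming the conventional PyTorch zero_grad/backward/step pattern; the optional gradient-clipping arguments are an assumption and not taken from the source.

    import torch.nn as nn

    def update_params(optim, loss, networks=None, retain_graph=False,
                      grad_clip=None):
        # Standard gradient step: clear stale gradients, backpropagate, update.
        optim.zero_grad()
        loss.backward(retain_graph=retain_graph)
        if grad_clip is not None and networks is not None:
            # Optional global-norm clipping (an assumption, not in the source).
            for net in networks:
                nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
        optim.step()
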
Example 2
    def calc_update_d_pi_iw(self,
                            slow_obs,
                            slow_act,
                            fast_obs,
                            fast_act,
                            target_obs=None,
                            target_act=None):
        slow_samples = torch.cat((slow_obs, slow_act), dim=1)
        fast_samples = torch.cat((fast_obs, fast_act), dim=1)

        # Labels shaped (B, 1) to match the classifier's per-sample logits.
        zeros = torch.zeros((slow_samples.size(0), 1), device=self._device)
        ones = torch.ones((fast_samples.size(0), 1), device=self._device)

        slow_preds = self._prob_classifier(slow_samples)
        fast_preds = self._prob_classifier(fast_samples)

        loss = F.binary_cross_entropy(torch.sigmoid(slow_preds), zeros) + \
               F.binary_cross_entropy(torch.sigmoid(fast_preds), ones)

        update_params(self._prob_optim, loss)

        # Optionally compute the ratio on data other than what the classifier
        # was just trained on.
        if target_obs is None:
            target_obs = slow_obs
        if target_act is None:
            target_act = slow_act
        target_samples = torch.cat((target_obs, target_act), dim=1)
        target_preds = self._prob_classifier(target_samples)

        importance_weights = torch.sigmoid(
            target_preds / self.prob_temperature).detach()
        importance_weights = importance_weights / torch.sum(importance_weights)

        return importance_weights, loss
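
For context rather than as source code: a classifier trained as above to output 1 on "fast" (near-on-policy) samples and 0 on "slow" (replay) samples satisfies sigmoid(logit) ≈ d_fast / (d_fast + d_slow), so exp(logit) would recover the density ratio d_fast / d_slow. The method instead uses the bounded surrogate sigmoid(logit / temperature), renormalized over the batch. A standalone sketch of that weighting step, with a hypothetical function name and default temperature:

    import torch

    def lfiw_weights_from_logits(logits, temperature=5.0):
        # Bounded, temperature-smoothed surrogate for the density ratio
        # exp(logit), normalized so the batch weights sum to one, as in
        # calc_update_d_pi_iw above. The default temperature is hypothetical.
        weights = torch.sigmoid(logits / temperature).detach()
        return weights / weights.sum()
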
Example 3
    def update_q_functions_and_error_models(self, batch, writer):
        states, actions, rewards, next_states, dones = batch

        # Calculate importance weights.
        imp_ws1, imp_ws2 = self.calc_importance_weights(next_states, dones)

        # Update Q functions.
        curr_qs1, curr_qs2, target_qs = \
            self.update_q_functions(batch, writer, imp_ws1, imp_ws2)

        # Calculate the current and target errors of the Q functions.
        curr_errs1, curr_errs2 = self.calc_current_errors(states, actions)
        target_errs1, target_errs2 = self.calc_target_errors(
            next_states, dones, curr_qs1, curr_qs2, target_qs)

        # Update error models.
        err_loss = self.calc_error_loss(curr_errs1, curr_errs2, target_errs1,
                                        target_errs2)
        update_params(self._error_optim, err_loss)

        if self._learning_steps % self._log_interval == 0:
            writer.add_scalar('loss/error',
                              err_loss.detach().item(), self._learning_steps)
            writer.add_scalar('stats/tau1', self._tau1.item(),
                              self._learning_steps)
            writer.add_scalar('stats/tau2', self._tau2.item(),
                              self._learning_steps)
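
calc_importance_weights() is assumed here to implement the DisCor weighting rule, w ∝ exp(-gamma * (1 - done) * error(s', a') / tau), where the error is the predicted cumulative Bellman error of the target error network and tau is the temperature logged above. A minimal sketch under that assumption; the error predictions are taken as an input, since how they are produced is not shown in these examples:

    import torch
    import torch.nn.functional as F

    def discor_importance_weights(next_errors, dones, tau, gamma=0.99):
        # next_errors: predicted cumulative Bellman error at (s', a' ~ pi),
        # shape (B, 1). Transitions whose bootstrap target is expected to be
        # inaccurate get small weight; softmax makes the weights sum to one.
        # In the examples above this would be applied once per error model
        # (tau1 and tau2).
        x = -(1.0 - dones) * gamma * next_errors / tau
        return F.softmax(x, dim=0).detach()
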
Example 4
    def update_q_functions(self,
                           batch,
                           writer,
                           imp_ws1=None,
                           imp_ws2=None,
                           fast_batch=None,
                           err_preds=None):
        states, actions, rewards, next_states, dones = \
            batch["states"], batch["actions"], batch["rewards"], batch["next_states"], batch["dones"]

        # Calculate current and target Q values.
        curr_qs1, curr_qs2 = self.calc_current_qs(states, actions)
        target_qs = self.calc_target_qs(rewards, next_states, dones)

        # Update Q functions.
        q_loss, mean_q1, mean_q2, unweighted_q_loss = \
            self.calc_q_loss(curr_qs1, curr_qs2, target_qs, imp_ws1, imp_ws2)
        update_params(self._q_optim, q_loss)

        # TODO: compute the Q loss for the online batch.

        if self._learning_steps % self._log_interval == 0:
            writer.add_scalar('loss/Q',
                              unweighted_q_loss.detach().item(),
                              self._learning_steps)
            writer.add_scalar('stats/mean_Q1', mean_q1, self._learning_steps)
            writer.add_scalar('stats/mean_Q2', mean_q2, self._learning_steps)

        if self._eval_tper and self._learning_steps % self._eval_tper_interval == 0:
            steps = batch["steps"]
            sim_states = batch["sim_states"]
            done_cnts = batch["done_cnts"]
            self.eval_Q(states[:128], actions[:128], steps[:128],
                        sim_states[:128], curr_qs1[:128], done_cnts[:128],
                        err_preds[:128] if err_preds is not None else None)

        # Return these values for the DisCor algorithm.
        return curr_qs1.detach(), curr_qs2.detach(), target_qs
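
calc_q_loss() is not shown in these examples. A plausible sketch, assuming a per-sample squared Bellman error that is combined with the (already normalized) importance weights when they are given and averaged uniformly otherwise; the exact reduction used by the source may differ:

    import torch

    def calc_q_loss(curr_qs1, curr_qs2, target_qs, imp_ws1=None, imp_ws2=None):
        # Per-sample TD errors against a shared target (assumed to carry
        # no gradient).
        errs1 = (curr_qs1 - target_qs).pow(2)
        errs2 = (curr_qs2 - target_qs).pow(2)
        # The unweighted loss is what gets logged; the weighted one is optimized.
        unweighted_q_loss = errs1.mean() + errs2.mean()
        w1 = imp_ws1 if imp_ws1 is not None else 1.0 / errs1.size(0)
        w2 = imp_ws2 if imp_ws2 is not None else 1.0 / errs2.size(0)
        q_loss = (errs1 * w1).sum() + (errs2 * w2).sum()
        mean_q1 = curr_qs1.detach().mean().item()
        mean_q2 = curr_qs2.detach().mean().item()
        return q_loss, mean_q1, mean_q2, unweighted_q_loss
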
Example 5
    def update_policy_and_entropy(self, batch, writer):
        states = batch["states"]

        # Update policy.
        policy_loss, entropies = self.calc_policy_loss(states)
        update_params(self._policy_optim, policy_loss)

        # Update the entropy coefficient.
        entropy_loss = self.calc_entropy_loss(entropies)
        update_params(self._alpha_optim, entropy_loss)
        self._alpha = self._log_alpha.detach().exp()

        if self._learning_steps % self._log_interval == 0:
            writer.add_scalar('loss/policy',
                              policy_loss.detach().item(),
                              self._learning_steps)
            writer.add_scalar('loss/entropy',
                              entropy_loss.detach().item(),
                              self._learning_steps)
            writer.add_scalar('stats/alpha', self._alpha.item(),
                              self._learning_steps)
            writer.add_scalar('stats/entropy',
                              entropies.detach().mean().item(),
                              self._learning_steps)
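
calc_policy_loss() and calc_entropy_loss() are assumed to be the standard SAC objectives. The sketch below is standalone, and the handles policy_net, q1_net and q2_net are hypothetical names rather than identifiers from the source:

    import torch

    def calc_policy_loss(policy_net, q1_net, q2_net, alpha, states):
        # Reparameterized actions and their log-probabilities under the policy.
        actions, log_pis = policy_net.sample(states)
        qs = torch.min(q1_net(states, actions), q2_net(states, actions))
        entropies = -log_pis
        # Maximize Q + alpha * entropy, i.e. minimize alpha * log_pi - Q.
        policy_loss = (alpha * log_pis - qs).mean()
        return policy_loss, entropies

    def calc_entropy_loss(log_alpha, entropies, target_entropy):
        # Raise alpha when entropy drops below the target, lower it otherwise.
        return -(log_alpha * (target_entropy - entropies).detach()).mean()
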
Example 6
    def update_q_functions_and_error_models(self, batch, writer):
        uniform_batch = batch["uniform"]
        if self.lfiw:
            fast_batch = batch["fast"]
            fast_states = fast_batch["states"]
            fast_actions = fast_batch["actions"]
        else:
            fast_batch = None
        # train_batch = batch["prior"] if self.tper else batch["uniform"]
        train_batch = batch["uniform"]

        # Transitions used to update the Q networks.
        states, actions, next_states, dones = \
            train_batch["states"], train_batch["actions"], \
            train_batch["next_states"], train_batch["dones"]
        # (s, a) pairs used to update the LFIW classifier.
        slow_states = uniform_batch["states"]
        slow_actions = uniform_batch["actions"]

        # Calculate importance weights.
        batch_size = states.shape[0]
        weights1 = torch.ones((batch_size, 1), device=self._device)
        weights2 = torch.ones((batch_size, 1), device=self._device)
        if self.discor:
            discor_weights = self.calc_importance_weights(next_states, dones)
            weights1 *= discor_weights[0]
            weights2 *= discor_weights[1]
        # Update the probability classifier and compute LFIW weights.
        if self.lfiw:
            lfiw_weights, prob_loss = self.calc_update_d_pi_iw(
                slow_states, slow_actions, fast_states, fast_actions, states,
                actions)
            weights1 *= lfiw_weights
            weights2 *= lfiw_weights
        # Calculate temporal-priority (TPER) weights.
        if self.tper:
            steps = train_batch["steps"]
            done_cnts = train_batch["done_cnts"]
            tper_weights = self.calc_tper_weights(steps, done_cnts)
            weights1 *= tper_weights
            weights2 *= tper_weights

        # Update Q functions.
        curr_errs1, curr_errs2 = None, None
        if self.discor:
            curr_errs1, curr_errs2 = self.calc_current_errors(states, actions)
        # Pass in curr_errs1 for evaluating DisCor.
        curr_qs1, curr_qs2, target_qs = \
            self.update_q_functions(train_batch, writer, weights1, weights2, fast_batch, curr_errs1)

        # Calculate current and target errors.
        if self.discor:
            target_errs1, target_errs2 = self.calc_target_errors(
                next_states, dones, curr_qs1, curr_qs2, target_qs)
            # Update error models.
            err_loss = self.calc_error_loss(curr_errs1, curr_errs2,
                                            target_errs1, target_errs2)
            update_params(self._error_optim, err_loss)

        if self._learning_steps % self._log_interval == 0:
            if self.discor:
                writer.add_scalar('loss/error',
                                  err_loss.detach().item(),
                                  self._learning_steps)
                writer.add_scalar('stats/tau1', self._tau1.item(),
                                  self._learning_steps)
                writer.add_scalar('stats/tau2', self._tau2.item(),
                                  self._learning_steps)
            if self.lfiw:
                writer.add_scalar('loss/prob_loss',
                                  prob_loss.detach().item(),
                                  self._learning_steps)
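
For orientation, a hypothetical usage sketch of how a training step might assemble the nested batch dictionary that Example 6 expects and dispatch to the update methods above; the agent and buffer APIs shown here are assumptions, not part of the source:

    def learn_step(agent, replay_buffer, fast_buffer, writer, batch_size=256):
        # "uniform" samples come from the whole replay buffer; "fast" samples
        # come from a smaller buffer of recent, near-on-policy transitions.
        batch = {"uniform": replay_buffer.sample(batch_size)}
        if agent.lfiw:
            batch["fast"] = fast_buffer.sample(batch_size)
        agent.update_q_functions_and_error_models(batch, writer)
        agent.update_policy_and_entropy(batch["uniform"], writer)
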