Example 1
    def train(self, training_batch: rlt.DiscreteDqnInput):
        if isinstance(training_batch, TrainingDataPage):
            training_batch = training_batch.as_discrete_maxq_training_batch()
        assert isinstance(training_batch, rlt.DiscreteDqnInput)
        boosted_rewards = self.boost_rewards(training_batch.reward,
                                             training_batch.action)

        self.minibatch += 1
        rewards = boosted_rewards
        discount_tensor = torch.full_like(rewards, self.gamma)
        not_done_mask = training_batch.not_terminal.float()
        assert not_done_mask.dim() == 2

        if self.use_seq_num_diff_as_time_diff:
            assert self.multi_steps is None
            discount_tensor = torch.pow(self.gamma,
                                        training_batch.time_diff.float())
        if self.multi_steps is not None:
            assert training_batch.step is not None
            # pyre-fixme[16]: `Optional` has no attribute `float`.
            discount_tensor = torch.pow(self.gamma,
                                        training_batch.step.float())

        all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
            training_batch.next_state)

        if self.maxq_learning:
            # Compute max a' Q(s', a') over all possible actions using target network
            possible_next_actions_mask = (
                training_batch.possible_next_actions_mask.float())
            if self.bcq:
                action_on_policy = get_valid_actions_from_imitator(
                    self.bcq_imitator,
                    training_batch.next_state,
                    self.bcq_drop_threshold,
                )
                possible_next_actions_mask *= action_on_policy
            next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
                all_next_q_values, all_next_q_values_target,
                possible_next_actions_mask)
        else:
            # SARSA
            next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
                all_next_q_values, all_next_q_values_target,
                training_batch.next_action)

        filtered_next_q_vals = next_q_values * not_done_mask

        target_q_values = rewards + (discount_tensor * filtered_next_q_vals)

        with torch.enable_grad():
            # Get Q-value of action taken
            all_q_values = self.q_network(training_batch.state)
            # pyre-fixme[16]: `DQNTrainer` has no attribute `all_action_scores`.
            self.all_action_scores = all_q_values.detach()
            q_values = torch.sum(all_q_values * training_batch.action,
                                 1,
                                 keepdim=True)

            loss = self.q_network_loss(q_values, target_q_values)
            # pyre-fixme[16]: `DQNTrainer` has no attribute `loss`.
            self.loss = loss.detach()

            loss.backward()
            self._maybe_run_optimizer(self.q_network_optimizer,
                                      self.minibatches_per_step)

        # Use the soft update rule to update target network
        self._maybe_soft_update(self.q_network, self.q_network_target,
                                self.tau, self.minibatches_per_step)

        # Get Q-values of next states, used in computing cpe
        all_next_action_scores = self.q_network(
            training_batch.next_state).detach()

        logged_action_idxs = torch.argmax(training_batch.action,
                                          dim=1,
                                          keepdim=True)
        reward_loss, model_rewards, model_propensities = self._calculate_cpes(
            training_batch,
            training_batch.state,
            training_batch.next_state,
            self.all_action_scores,
            all_next_action_scores,
            logged_action_idxs,
            discount_tensor,
            not_done_mask,
        )

        if self.maxq_learning:
            possible_actions_mask = training_batch.possible_actions_mask

        if self.bcq:
            action_on_policy = get_valid_actions_from_imitator(
                self.bcq_imitator, training_batch.state,
                self.bcq_drop_threshold)
            possible_actions_mask *= action_on_policy

        model_action_idxs = self.get_max_q_values(
            self.all_action_scores,
            possible_actions_mask
            if self.maxq_learning else training_batch.action,
        )[1]

        # pyre-fixme[16]: `DQNTrainer` has no attribute `notify_observers`.
        self.notify_observers(
            td_loss=self.loss,
            reward_loss=reward_loss,
            logged_actions=logged_action_idxs,
            logged_propensities=training_batch.extras.action_probability,
            logged_rewards=rewards,
            model_propensities=model_propensities,
            model_rewards=model_rewards,
            model_values=self.all_action_scores,
            model_action_idxs=model_action_idxs,
        )

        self.loss_reporter.report(
            td_loss=self.loss,
            reward_loss=reward_loss,
            logged_actions=logged_action_idxs,
            logged_propensities=training_batch.extras.action_probability,
            logged_rewards=rewards,
            logged_values=None,  # Compute at end of each epoch for CPE
            model_propensities=model_propensities,
            model_rewards=model_rewards,
            model_values=self.all_action_scores,
            model_values_on_logged_actions=None,  # Compute at end of each epoch for CPE
            model_action_idxs=model_action_idxs,
        )
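
The update above is the standard max-Q TD target, r + gamma * (1 - done) * max_a' Q_target(s', a'), with the max restricted to the allowed next actions (and further restricted by the BCQ imitator when enabled). The standalone sketch below reproduces just that target computation with plain tensors. The function name and the plain max-over-the-target-network variant are illustrative assumptions; the trainer's own get_max_q_values_with_target helper may additionally use the online network for action selection, and is not reproduced here.

import torch

def masked_max_q_target(rewards, not_terminal, gamma, next_q_target, possible_next_actions_mask):
    """Illustrative TD target: r + gamma * (1 - done) * max_a' Q_target(s', a'),
    with the max restricted to actions allowed by the mask."""
    # Push disallowed actions to -inf so the max can never pick them.
    masked_q = next_q_target.masked_fill(possible_next_actions_mask == 0, float("-inf"))
    max_next_q, _ = masked_q.max(dim=1, keepdim=True)   # (batch, 1)
    return rewards + gamma * not_terminal * max_next_q  # (batch, 1)

# Tiny smoke test: batch of 2 transitions, 3 discrete actions.
rewards = torch.tensor([[1.0], [0.5]])
not_terminal = torch.tensor([[1.0], [0.0]])                # second transition is terminal
next_q_target = torch.tensor([[0.2, 0.9, 0.1], [0.3, 0.4, 0.8]])
mask = torch.tensor([[1.0, 0.0, 1.0], [1.0, 1.0, 1.0]])    # action 1 is illegal in state 0
print(masked_max_q_target(rewards, not_terminal, 0.99, next_q_target, mask))
# -> tensor([[1.1980], [0.5000]])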
Example 2
    def train(self, training_batch: rlt.DiscreteDqnInput) -> None:
        if isinstance(training_batch, TrainingDataPage):
            training_batch = training_batch.as_discrete_maxq_training_batch()

        rewards = self.boost_rewards(training_batch.reward, training_batch.action)
        discount_tensor = torch.full_like(rewards, self.gamma)
        possible_next_actions_mask = training_batch.possible_next_actions_mask.float()
        possible_actions_mask = training_batch.possible_actions_mask.float()

        self.minibatch += 1
        not_terminal = training_batch.not_terminal.float()

        if self.use_seq_num_diff_as_time_diff:
            assert self.multi_steps is None
            discount_tensor = torch.pow(self.gamma, training_batch.time_diff.float())
        if self.multi_steps is not None:
            assert training_batch.step is not None
            discount_tensor = torch.pow(self.gamma, training_batch.step.float())

        next_dist = self.q_network_target.log_dist(training_batch.next_state).exp()

        if self.maxq_learning:
            # Select distribution corresponding to max valued action
            if self.double_q_learning:
                next_q_values = (
                    self.q_network.log_dist(training_batch.next_state).exp()
                    * self.support
                ).sum(2)
            else:
                next_q_values = (next_dist * self.support).sum(2)

            next_action = self.argmax_with_mask(
                next_q_values, possible_next_actions_mask
            )
            next_dist = next_dist[range(rewards.shape[0]), next_action.reshape(-1)]
        else:
            next_dist = (next_dist * training_batch.next_action.unsqueeze(-1)).sum(1)

        # Build target distribution
        target_Q = rewards + discount_tensor * not_terminal * self.support
        target_Q = target_Q.clamp(self.qmin, self.qmax)

        # rescale to indices [0, 1, ..., N-1]
        b = (target_Q - self.qmin) / self.scale_support
        # pyre-fixme[16]: `Tensor` has no attribute `floor`.
        lo = b.floor().to(torch.int64)
        # pyre-fixme[16]: `Tensor` has no attribute `ceil`.
        up = b.ceil().to(torch.int64)

        # handle corner cases of l == b == u
        # without the following, it would give 0 signal, whereas we want
        # m to add p(s_t+n, a*) to index l == b == u.
        # So we artificially adjust l and u.
        # (1) If 0 < l == u < N-1, we make l = l-1, so b-l = 1
        # (2) If 0 == l == u, we make u = 1, so u-b=1
        # (3) If l == u == N-1, we make l = N-2, so b-1 = 1
        # This first line handles (1) and (3).
        lo[(up > 0) * (lo == up)] -= 1
        # Note: l has already changed, so the only way l == u is possible is
        # if u == 0, in which case we let u = 1
        # I don't even think we need the first condition in the next line
        up[(lo < (self.num_atoms - 1)) * (lo == up)] += 1

        # distribute the probabilities
        # m_l = m_l + p(s_t+n, a*)(u - b)
        # m_u = m_u + p(s_t+n, a*)(b - l)
        m = torch.zeros_like(next_dist)
        # pyre-fixme[16]: `Tensor` has no attribute `scatter_add_`.
        m.scatter_add_(dim=1, index=lo, src=next_dist * (up.float() - b))
        m.scatter_add_(dim=1, index=up, src=next_dist * (b - lo.float()))

        with torch.enable_grad():
            log_dist = self.q_network.log_dist(training_batch.state)

            # for reporting only
            all_q_values = (log_dist.exp() * self.support).sum(2).detach()

            log_dist = (log_dist * training_batch.action.unsqueeze(-1)).sum(1)

            loss = -(m * log_dist).sum(1).mean()
            loss.backward()
            self._maybe_run_optimizer(
                self.q_network_optimizer, self.minibatches_per_step
            )

        # Use the soft update rule to update target network
        self._maybe_soft_update(
            self.q_network, self.q_network_target, self.tau, self.minibatches_per_step
        )

        model_action_idxs = self.argmax_with_mask(
            all_q_values,
            possible_actions_mask if self.maxq_learning else training_batch.action,
        )

        # pyre-fixme[16]: `C51Trainer` has no attribute `notify_observers`.
        self.notify_observers(
            td_loss=loss,
            logged_actions=torch.argmax(training_batch.action, dim=1, keepdim=True),
            logged_propensities=training_batch.extras.action_probability,
            logged_rewards=rewards,
            model_values=all_q_values,
            model_action_idxs=model_action_idxs,
        )

        self.loss_reporter.report(
            td_loss=loss,
            logged_actions=training_batch.action.argmax(dim=1, keepdim=True),
            logged_propensities=training_batch.extras.action_probability,
            logged_rewards=rewards,
            model_values=all_q_values,
            model_action_idxs=model_action_idxs,
        )
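
Most of the work in this example is the categorical (C51) projection: the Bellman-shifted support r + gamma * z falls between atoms, so each next-state probability is split between the two nearest atoms in proportion to distance, with the lo/up nudging covering the case where the shifted value lands exactly on an atom. The sketch below isolates that projection under the assumption that scale_support = (qmax - qmin) / (num_atoms - 1), which matches a uniformly spaced support; the function name and toy inputs are illustrative, not part of the trainer.

import torch

def project_categorical(next_dist, rewards, not_terminal, gamma, support, qmin, qmax):
    """Illustrative C51 projection: distribute p(s', a*) over the fixed support
    after the Bellman shift r + gamma * z, as in the training step above."""
    num_atoms = support.numel()
    scale = (qmax - qmin) / (num_atoms - 1)

    target_q = (rewards + gamma * not_terminal * support).clamp(qmin, qmax)  # (batch, atoms)
    b = (target_q - qmin) / scale            # fractional atom indices in [0, N-1]
    lo = b.floor().to(torch.int64)
    up = b.ceil().to(torch.int64)

    # When b is an integer, lo == up; nudge the pair apart so no mass is dropped.
    lo[(up > 0) & (lo == up)] -= 1
    up[(lo < num_atoms - 1) & (lo == up)] += 1

    m = torch.zeros_like(next_dist)
    m.scatter_add_(1, lo, next_dist * (up.float() - b))
    m.scatter_add_(1, up, next_dist * (b - lo.float()))
    return m  # each row still sums to 1

# Example: 51 atoms on [-10, 10], batch of 2, uniform next-state distributions.
support = torch.linspace(-10.0, 10.0, 51)
next_dist = torch.full((2, 51), 1.0 / 51)
rewards = torch.tensor([[1.0], [0.0]])
not_terminal = torch.tensor([[1.0], [0.0]])
m = project_categorical(next_dist, rewards, not_terminal, 0.99, support, -10.0, 10.0)
print(m.sum(dim=1))  # -> tensor([1.0000, 1.0000])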
Example 3
    def train(self, training_batch: rlt.DiscreteDqnInput):
        if isinstance(training_batch, TrainingDataPage):
            training_batch = training_batch.as_discrete_maxq_training_batch()

        rewards = self.boost_rewards(training_batch.reward,
                                     training_batch.action)
        discount_tensor = torch.full_like(rewards, self.gamma)
        possible_next_actions_mask = training_batch.possible_next_actions_mask.float()
        possible_actions_mask = training_batch.possible_actions_mask.float()

        self.minibatch += 1
        not_done_mask = training_batch.not_terminal.float()

        if self.use_seq_num_diff_as_time_diff:
            assert self.multi_steps is None
            discount_tensor = torch.pow(self.gamma,
                                        training_batch.time_diff.float())
        if self.multi_steps is not None:
            assert training_batch.step is not None
            discount_tensor = torch.pow(self.gamma,
                                        training_batch.step.float())

        next_qf = self.q_network_target(training_batch.next_state)

        if self.maxq_learning:
            # Select distribution corresponding to max valued action
            next_q_values = (self.q_network(training_batch.next_state) if
                             self.double_q_learning else next_qf).mean(dim=2)
            next_action = self.argmax_with_mask(next_q_values,
                                                possible_next_actions_mask)
            next_qf = next_qf[range(rewards.shape[0]), next_action.reshape(-1)]
        else:
            next_qf = (next_qf *
                       training_batch.next_action.unsqueeze(-1)).sum(1)

        # Build target distribution
        target_Q = rewards + discount_tensor * not_done_mask * next_qf

        with torch.enable_grad():
            current_qf = self.q_network(training_batch.state)

            # for reporting only
            all_q_values = current_qf.mean(2).detach()

            current_qf = (current_qf *
                          training_batch.action.unsqueeze(-1)).sum(1)

            # (batch, atoms) -> (atoms, batch, 1) -> (atoms, batch, atoms)
            td = target_Q.t().unsqueeze(-1) - current_qf
            loss = (
                self.huber(td)
                # pyre-fixme[16]: `FloatTensor` has no attribute `abs`.
                * (self.quantiles - (td.detach() < 0).float()).abs()).mean()

            loss.backward()
            self._maybe_run_optimizer(self.q_network_optimizer,
                                      self.minibatches_per_step)

        # Use the soft update rule to update target network
        self._maybe_soft_update(self.q_network, self.q_network_target,
                                self.tau, self.minibatches_per_step)

        # Get Q-values of next states, used in computing cpe
        all_next_action_scores = (self.q_network(
            training_batch.next_state).detach().mean(dim=2))

        logged_action_idxs = torch.argmax(training_batch.action,
                                          dim=1,
                                          keepdim=True)
        reward_loss, model_rewards, model_propensities = self._calculate_cpes(
            training_batch,
            training_batch.state,
            training_batch.next_state,
            all_q_values,
            all_next_action_scores,
            logged_action_idxs,
            discount_tensor,
            not_done_mask,
        )

        model_action_idxs = self.argmax_with_mask(
            all_q_values,
            possible_actions_mask
            if self.maxq_learning else training_batch.action,
        )

        # pyre-fixme[16]: `QRDQNTrainer` has no attribute `notify_observers`.
        self.notify_observers(
            td_loss=loss,
            logged_actions=logged_action_idxs,
            logged_propensities=training_batch.extras.action_probability,
            logged_rewards=rewards,
            model_propensities=model_propensities,
            model_rewards=model_rewards,
            model_values=all_q_values,
            model_action_idxs=model_action_idxs,
        )

        self.loss_reporter.report(
            td_loss=loss,
            logged_actions=logged_action_idxs,
            logged_propensities=training_batch.extras.action_probability,
            logged_rewards=rewards,
            logged_values=None,  # Compute at end of each epoch for CPE
            model_propensities=model_propensities,
            model_rewards=model_rewards,
            model_values=all_q_values,
            model_values_on_logged_actions=None,  # Compute at end of each epoch for CPE
            model_action_idxs=model_action_idxs,
        )
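
The loss in this example is the quantile-regression Huber loss from QR-DQN: for every pair of a predicted quantile and a target sample, the Huber loss is weighted by |tau_i - 1{td < 0}|, so over- and under-estimation are penalized asymmetrically according to the quantile level. A self-contained sketch follows; the quantile midpoints tau_i = (i + 0.5) / N and the Huber threshold kappa = 1 are assumptions chosen to match the common formulation, and none of the names below come from the trainer.

import torch

def quantile_huber_loss(current_quantiles, target_quantiles, kappa=1.0):
    """Illustrative QR-DQN loss for (batch, num_quantiles) tensors.

    current_quantiles: theta_i(s, a) for the logged action.
    target_quantiles:  r + gamma * (1 - done) * theta_j(s', a*), treated as constant.
    """
    n = current_quantiles.shape[1]
    taus = (torch.arange(n, dtype=torch.float32) + 0.5) / n  # quantile midpoints

    # Pairwise TD errors, shaped (batch, target j, current i).
    td = target_quantiles.unsqueeze(2) - current_quantiles.unsqueeze(1)
    huber = torch.where(td.abs() <= kappa,
                        0.5 * td.pow(2),
                        kappa * (td.abs() - 0.5 * kappa))
    # Asymmetric weight |tau_i - 1{td < 0}| applied along the current-quantile axis.
    weight = (taus.view(1, 1, n) - (td.detach() < 0).float()).abs()
    return (weight * huber).mean()

# Usage: 4 transitions, 8 quantiles each.
torch.manual_seed(0)
current = torch.randn(4, 8, requires_grad=True)
target = torch.randn(4, 8)
loss = quantile_huber_loss(current, target)
loss.backward()
print(loss.item(), current.grad.shape)  # scalar loss, gradients shaped like current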
Example 4
    def train(self, training_batch: rlt.DiscreteDqnInput):
        if isinstance(training_batch, TrainingDataPage):
            training_batch = training_batch.as_discrete_maxq_training_batch()

        state = rlt.PreprocessedState(state=training_batch.state)
        next_state = rlt.PreprocessedState(state=training_batch.next_state)
        rewards = self.boost_rewards(training_batch.reward,
                                     training_batch.action)
        discount_tensor = torch.full_like(rewards, self.gamma)
        possible_next_actions_mask = training_batch.possible_next_actions_mask.float()
        possible_actions_mask = training_batch.possible_actions_mask.float()

        self.minibatch += 1
        not_done_mask = training_batch.not_terminal.float()

        if self.use_seq_num_diff_as_time_diff:
            assert self.multi_steps is None
            discount_tensor = torch.pow(self.gamma,
                                        training_batch.time_diff.float())
        if self.multi_steps is not None:
            assert training_batch.step is not None
            discount_tensor = torch.pow(self.gamma,
                                        training_batch.step.float())

        next_dist = self.q_network_target.log_dist(next_state).exp()

        if self.maxq_learning:
            # Select distribution corresponding to max valued action
            if self.double_q_learning:
                next_q_values = (self.q_network.log_dist(next_state).exp() *
                                 self.support).sum(2)
            else:
                next_q_values = (next_dist * self.support).sum(2)

            next_action = self.argmax_with_mask(next_q_values,
                                                possible_next_actions_mask)
            next_dist = next_dist[range(rewards.shape[0]),
                                  next_action.reshape(-1)]
        else:
            next_dist = (next_dist *
                         training_batch.next_action.unsqueeze(-1)).sum(1)

        # Build target distribution
        target_Q = rewards + discount_tensor * not_done_mask * self.support

        # Project target distribution back onto support
        # remove support outliers
        target_Q = target_Q.clamp(self.qmin, self.qmax)
        # rescale to indices
        b = (target_Q - self.qmin) / (self.qmax -
                                      self.qmin) * (self.num_atoms - 1.0)
        lower = b.floor()
        upper = b.ceil()

        # Since index_add_ doesn't work with multiple dimensions
        # we operate on the flattened tensors
        offset = self.num_atoms * torch.arange(
            rewards.shape[0], device=self.device, dtype=torch.long
        ).reshape(-1, 1).repeat(1, self.num_atoms)

        m = torch.zeros_like(next_dist)
        m.reshape(-1).index_add_(  # type: ignore
            0,
            (lower.long() + offset).reshape(-1),
            (next_dist * (upper - b)).reshape(-1),
        )
        m.reshape(-1).index_add_(  # type: ignore
            0,
            (upper.long() + offset).reshape(-1),
            (next_dist * (b - lower)).reshape(-1),
        )

        with torch.enable_grad():
            log_dist = self.q_network.log_dist(state)

            # for reporting only
            all_q_values = (log_dist.exp() * self.support).sum(2).detach()

            log_dist = (log_dist * training_batch.action.unsqueeze(-1)).sum(1)

            loss = -(m * log_dist).sum(1).mean()
            loss.backward()
            self._maybe_run_optimizer(self.q_network_optimizer,
                                      self.minibatches_per_step)

        # Use the soft update rule to update target network
        self._maybe_soft_update(self.q_network, self.q_network_target,
                                self.tau, self.minibatches_per_step)

        model_action_idxs = self.argmax_with_mask(
            all_q_values,
            possible_actions_mask
            if self.maxq_learning else training_batch.action,
        )

        self.notify_observers(  # type: ignore
            td_loss=loss,
            logged_actions=torch.argmax(training_batch.action,
                                        dim=1,
                                        keepdim=True),
            logged_propensities=training_batch.extras.action_probability,
            logged_rewards=rewards,
            model_values=all_q_values,
            model_action_idxs=model_action_idxs,
        )

        self.loss_reporter.report(
            td_loss=loss,
            logged_actions=training_batch.action.argmax(dim=1, keepdim=True),
            logged_propensities=training_batch.extras.action_probability,
            logged_rewards=rewards,
            model_values=all_q_values,
            model_action_idxs=model_action_idxs,
        )
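
This older variant builds the same C51 projection as Example 2 but uses index_add_ on a flattened tensor instead of scatter_add_, because index_add_ only accepts a 1-D index; the per-row offset of num_atoms * row makes each row of the flattened buffer address its own slice. The toy check below (illustrative only, with made-up indices and sources) confirms that the two routes agree.

import torch

# Minimal equivalence check: flattened index_add_ versus scatter_add_ along dim 1.
torch.manual_seed(0)
batch, num_atoms = 3, 5
src = torch.rand(batch, num_atoms)
idx = torch.randint(0, num_atoms, (batch, num_atoms))

# Route 1: scatter_add_ handles a 2-D index tensor directly.
m_scatter = torch.zeros(batch, num_atoms)
m_scatter.scatter_add_(1, idx, src)

# Route 2: index_add_ only takes a 1-D index, so flatten and add a per-row
# offset of num_atoms * row so each row lands in its own slice.
offset = num_atoms * torch.arange(batch).reshape(-1, 1)
m_index = torch.zeros(batch, num_atoms)
m_index.reshape(-1).index_add_(0, (idx + offset).reshape(-1), src.reshape(-1))

print(torch.allclose(m_scatter, m_index))  # -> True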