Example 1
    def train_step_gen(self, training_batch: rlt.DiscreteDqnInput,
                       batch_idx: int):
        # TODO: calls to _maybe_run_optimizer removed, should be replaced with Trainer parameter
        assert isinstance(training_batch, rlt.DiscreteDqnInput)
        rewards = self.boost_rewards(training_batch.reward,
                                     training_batch.action)
        not_done_mask = training_batch.not_terminal.float()
        assert not_done_mask.dim() == 2

        discount_tensor = self.compute_discount_tensor(training_batch, rewards)
        td_loss = self.compute_td_loss(training_batch, rewards,
                                       discount_tensor)
        yield td_loss
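        # Detach after yielding: the copy passed to _log_dqn below is for
        # logging only and must not retain the autograd graph.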
        td_loss = td_loss.detach()

        # Get Q-values of next states, used in computing cpe
        all_next_action_scores = self.q_network(
            training_batch.next_state).detach()
        logged_action_idxs = torch.argmax(training_batch.action,
                                          dim=1,
                                          keepdim=True)

        yield from self._calculate_cpes(
            training_batch,
            training_batch.state,
            training_batch.next_state,
            # pyre-fixme[16]: `DQNTrainer` has no attribute `all_action_scores`.
            self.all_action_scores,
            all_next_action_scores,
            logged_action_idxs,
            discount_tensor,
            not_done_mask,
        )

        if self.maxq_learning:
            possible_actions_mask = training_batch.possible_actions_mask
            # The BCQ mask can only be applied when possible_actions_mask
            # exists, so it is nested under maxq_learning.
            if self.bcq:
                action_on_policy = get_valid_actions_from_imitator(
                    self.bcq_imitator, training_batch.state,
                    self.bcq_drop_threshold)
                possible_actions_mask *= action_on_policy

        # model_action_idxs is only consumed by the logging call below.
        model_action_idxs = self.get_max_q_values(
            self.all_action_scores,
            possible_actions_mask if self.maxq_learning else training_batch.action,
        )[1]

        self._log_dqn(td_loss, logged_action_idxs, training_batch, rewards,
                      model_action_idxs)

        # Use the soft update rule to update target network
        yield self.soft_update_result()
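For context, this generator is written to yield one loss per configured optimizer, in order: the TD loss, the CPE losses from _calculate_cpes, and finally the soft-update result. A minimal sketch of a consumer loop under that assumption (the function name and pairing convention here are illustrative, not ReAgent's actual training loop):

    def run_train_step(trainer, batch, batch_idx, optimizers):
        # Pair each yielded loss with its optimizer and step them in order.
        # Assumes one yield per optimizer and that every yielded value
        # supports backward().
        for loss, optimizer in zip(trainer.train_step_gen(batch, batch_idx),
                                   optimizers):
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()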
Example 2
    def compute_td_loss(
        self,
        batch: rlt.DiscreteDqnInput,
        boosted_rewards: torch.Tensor,
        discount_tensor: torch.Tensor,
    ):
        not_done_mask = batch.not_terminal.float()
        all_next_q_values, all_next_q_values_target = self.get_detached_model_outputs(
            batch.next_state)

        if self.maxq_learning:
            # Compute max a' Q(s', a') over all possible actions using target network
            possible_next_actions_mask = batch.possible_next_actions_mask.float()
            if self.bcq:
                action_on_policy = get_valid_actions_from_imitator(
                    self.bcq_imitator,
                    batch.next_state,
                    self.bcq_drop_threshold,
                )
                possible_next_actions_mask *= action_on_policy
            next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
                all_next_q_values,
                all_next_q_values_target,
                possible_next_actions_mask,
            )
        else:
            # SARSA
            next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
                all_next_q_values,
                all_next_q_values_target,
                batch.next_action,
            )

        filtered_next_q_vals = next_q_values * not_done_mask

        target_q_values = boosted_rewards + discount_tensor * filtered_next_q_vals

        # Get Q-value of action taken
        all_q_values = self.q_network(batch.state)
        # pyre-fixme[16]: `DQNTrainer` has no attribute `all_action_scores`.
        self.all_action_scores = all_q_values.detach()
        q_values = torch.sum(all_q_values * batch.action, 1, keepdim=True)
        td_loss = self.q_network_loss(q_values, target_q_values.detach())
        return td_loss
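As a concrete check of the target computed above, here is a toy example with a batch of two transitions and three actions; the second transition is terminal, so its bootstrapped term is masked out (all shapes and values are illustrative only):

    import torch

    rewards = torch.tensor([[1.0], [0.5]])    # boosted_rewards, shape (2, 1)
    not_done = torch.tensor([[1.0], [0.0]])   # second transition is terminal
    gamma = torch.full_like(rewards, 0.99)    # discount_tensor
    next_q_target = torch.tensor([[0.2, 0.7, 0.1],
                                  [0.4, 0.3, 0.9]])
    max_next_q, _ = next_q_target.max(dim=1, keepdim=True)
    target = rewards + gamma * (max_next_q * not_done)
    # target == [[1.0 + 0.99 * 0.7], [0.5]] == [[1.693], [0.5]]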
Example 3
    def train(self, training_batch: rlt.DiscreteDqnInput):
        if isinstance(training_batch, TrainingDataPage):
            training_batch = training_batch.as_discrete_maxq_training_batch()
        assert isinstance(training_batch, rlt.DiscreteDqnInput)
        boosted_rewards = self.boost_rewards(training_batch.reward,
                                             training_batch.action)

        self.minibatch += 1
        rewards = boosted_rewards
        discount_tensor = torch.full_like(rewards, self.gamma)
        not_done_mask = training_batch.not_terminal.float()
        assert not_done_mask.dim() == 2

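        # Default is a constant single-step discount; the branches below
        # override it for sequence-number time diffs and multi-step returns.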
        if self.use_seq_num_diff_as_time_diff:
            assert self.multi_steps is None
            discount_tensor = torch.pow(self.gamma,
                                        training_batch.time_diff.float())
        if self.multi_steps is not None:
            assert training_batch.step is not None
            # pyre-fixme[16]: `Optional` has no attribute `float`.
            discount_tensor = torch.pow(self.gamma,
                                        training_batch.step.float())

        all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
            training_batch.next_state)

        if self.maxq_learning:
            # Compute max a' Q(s', a') over all possible actions using target network
            possible_next_actions_mask = (
                training_batch.possible_next_actions_mask.float())
            if self.bcq:
                action_on_policy = get_valid_actions_from_imitator(
                    self.bcq_imitator,
                    training_batch.next_state,
                    self.bcq_drop_threshold,
                )
                possible_next_actions_mask *= action_on_policy
            next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
                all_next_q_values, all_next_q_values_target,
                possible_next_actions_mask)
        else:
            # SARSA
            next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
                all_next_q_values, all_next_q_values_target,
                training_batch.next_action)

        filtered_next_q_vals = next_q_values * not_done_mask

        target_q_values = rewards + (discount_tensor * filtered_next_q_vals)

        with torch.enable_grad():
            # Get Q-value of action taken
            all_q_values = self.q_network(training_batch.state)
            # pyre-fixme[16]: `DQNTrainer` has no attribute `all_action_scores`.
            self.all_action_scores = all_q_values.detach()
            q_values = torch.sum(all_q_values * training_batch.action,
                                 1,
                                 keepdim=True)

            loss = self.q_network_loss(q_values, target_q_values)
            # pyre-fixme[16]: `DQNTrainer` has no attribute `loss`.
            self.loss = loss.detach()

            loss.backward()
            self._maybe_run_optimizer(self.q_network_optimizer,
                                      self.minibatches_per_step)

        # Use the soft update rule to update target network
        self._maybe_soft_update(self.q_network, self.q_network_target,
                                self.tau, self.minibatches_per_step)

        # Get Q-values of next states, used in computing cpe
        all_next_action_scores = self.q_network(
            training_batch.next_state).detach()

        logged_action_idxs = torch.argmax(training_batch.action,
                                          dim=1,
                                          keepdim=True)
        reward_loss, model_rewards, model_propensities = self._calculate_cpes(
            training_batch,
            training_batch.state,
            training_batch.next_state,
            self.all_action_scores,
            all_next_action_scores,
            logged_action_idxs,
            discount_tensor,
            not_done_mask,
        )

        if self.maxq_learning:
            possible_actions_mask = training_batch.possible_actions_mask
            # The BCQ mask can only be applied when possible_actions_mask
            # exists, so it is nested under maxq_learning.
            if self.bcq:
                action_on_policy = get_valid_actions_from_imitator(
                    self.bcq_imitator, training_batch.state,
                    self.bcq_drop_threshold)
                possible_actions_mask *= action_on_policy

        model_action_idxs = self.get_max_q_values(
            self.all_action_scores,
            possible_actions_mask if self.maxq_learning else training_batch.action,
        )[1]

        # pyre-fixme[16]: `DQNTrainer` has no attribute `notify_observers`.
        self.notify_observers(
            td_loss=self.loss,
            reward_loss=reward_loss,
            logged_actions=logged_action_idxs,
            logged_propensities=training_batch.extras.action_probability,
            logged_rewards=rewards,
            model_propensities=model_propensities,
            model_rewards=model_rewards,
            model_values=self.all_action_scores,
            model_action_idxs=model_action_idxs,
        )

        self.loss_reporter.report(
            td_loss=self.loss,
            reward_loss=reward_loss,
            logged_actions=logged_action_idxs,
            logged_propensities=training_batch.extras.action_probability,
            logged_rewards=rewards,
            logged_values=None,  # Compute at end of each epoch for CPE
            model_propensities=model_propensities,
            model_rewards=model_rewards,
            model_values=self.all_action_scores,
            model_values_on_logged_actions=None,  # Compute at end of each epoch for CPE
            model_action_idxs=model_action_idxs,
        )
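Unlike the generator variant, this legacy train() steps its own optimizer through _maybe_run_optimizer, which fires only every minibatches_per_step batches, i.e. gradient accumulation. A minimal sketch of that behavior, assuming the self.minibatch counter incremented above; this is an illustration, not ReAgent's exact helper:

    def _maybe_run_optimizer(self, optimizer, minibatches_per_step):
        # Step (and clear) the optimizer only once every
        # `minibatches_per_step` backward passes; in between, gradients
        # simply accumulate on the parameters.
        if self.minibatch % minibatches_per_step == 0:
            optimizer.step()
            optimizer.zero_grad()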
Example 4
    def train_step_gen(self, training_batch: rlt.DiscreteDqnInput,
                       batch_idx: int):
        # TODO: calls to _maybe_run_optimizer removed, should be replaced with Trainer parameter
        assert isinstance(training_batch, rlt.DiscreteDqnInput)
        boosted_rewards = self.boost_rewards(training_batch.reward,
                                             training_batch.action)
        rewards = boosted_rewards
        discount_tensor = torch.full_like(rewards, self.gamma)
        not_done_mask = training_batch.not_terminal.float()
        assert not_done_mask.dim() == 2

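        # Default is a constant single-step discount; the branches below
        # override it for sequence-number time diffs and multi-step returns.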
        if self.use_seq_num_diff_as_time_diff:
            assert self.multi_steps is None
            discount_tensor = torch.pow(self.gamma,
                                        training_batch.time_diff.float())
        if self.multi_steps is not None:
            assert training_batch.step is not None
            # pyre-fixme[16]: `Optional` has no attribute `float`.
            discount_tensor = torch.pow(self.gamma,
                                        training_batch.step.float())

        all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
            training_batch.next_state)

        if self.maxq_learning:
            # Compute max a' Q(s', a') over all possible actions using target network
            possible_next_actions_mask = (
                training_batch.possible_next_actions_mask.float())
            if self.bcq:
                action_on_policy = get_valid_actions_from_imitator(
                    self.bcq_imitator,
                    training_batch.next_state,
                    self.bcq_drop_threshold,
                )
                possible_next_actions_mask *= action_on_policy
            next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
                all_next_q_values,
                all_next_q_values_target,
                possible_next_actions_mask,
            )
        else:
            # SARSA
            next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
                all_next_q_values,
                all_next_q_values_target,
                training_batch.next_action,
            )

        filtered_next_q_vals = next_q_values * not_done_mask

        target_q_values = rewards + (discount_tensor * filtered_next_q_vals)

        # Get Q-value of action taken
        all_q_values = self.q_network(training_batch.state)
        # pyre-fixme[16]: `DQNTrainer` has no attribute `all_action_scores`.
        self.all_action_scores = all_q_values.detach()
        q_values = torch.sum(all_q_values * training_batch.action,
                             1,
                             keepdim=True)
        loss = self.q_network_loss(q_values, target_q_values)

        # pyre-fixme[16]: `DQNTrainer` has no attribute `loss`.
        self.loss = loss.detach()
        yield loss

        # Get Q-values of next states, used in computing cpe
        all_next_action_scores = self.q_network(
            training_batch.next_state).detach()
        logged_action_idxs = torch.argmax(training_batch.action,
                                          dim=1,
                                          keepdim=True)

        yield from self._calculate_cpes(
            training_batch,
            training_batch.state,
            training_batch.next_state,
            self.all_action_scores,
            all_next_action_scores,
            logged_action_idxs,
            discount_tensor,
            not_done_mask,
        )

        if self.maxq_learning:
            possible_actions_mask = training_batch.possible_actions_mask
            # The BCQ mask can only be applied when possible_actions_mask
            # exists, so it is nested under maxq_learning.
            if self.bcq:
                action_on_policy = get_valid_actions_from_imitator(
                    self.bcq_imitator, training_batch.state,
                    self.bcq_drop_threshold)
                possible_actions_mask *= action_on_policy

        # model_action_idxs is only consumed by the reporter.log call below.
        model_action_idxs = self.get_max_q_values(
            self.all_action_scores,
            possible_actions_mask if self.maxq_learning else training_batch.action,
        )[1]

        self.reporter.log(
            td_loss=self.loss,
            logged_actions=logged_action_idxs,
            logged_propensities=training_batch.extras.action_probability,
            logged_rewards=rewards,
            logged_values=None,  # Compute at end of each epoch for CPE
            model_values=self.all_action_scores,
            model_values_on_logged_actions=None,  # Compute at end of each epoch for CPE
            model_action_idxs=model_action_idxs,
        )

        # Use the soft update rule to update target network
        yield self.soft_update_result()
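The final yield hands control to the soft (Polyak) update of the target network mentioned in the comment. A minimal sketch of that rule, assuming the conventional tau-weighted average (illustrative; in ReAgent the yielded soft_update_result is what triggers the update, and the exact helper may differ):

    import torch

    @torch.no_grad()
    def soft_update(network, target_network, tau):
        # target <- tau * online + (1 - tau) * target, parameter-wise.
        for param, target_param in zip(network.parameters(),
                                       target_network.parameters()):
            target_param.mul_(1.0 - tau).add_(param, alpha=tau)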