Example #1
    def test_masked_softmax(self):
        # Positive value case
        x = torch.tensor([[15.0, 6.0, 9.0], [3.0, 2.0, 1.0]])
        temperature = 1
        mask = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]])
        out = masked_softmax(x, mask, temperature)
        expected_out = torch.tensor([[0.9975, 0.0000, 0.0025],
                                     [0, 0.7311, 0.2689]])
        npt.assert_array_almost_equal(out, expected_out, 4)

        # Positive value case (masked value goes to inf)
        x = torch.tensor([[150.0, 2.0]])
        temperature = 0.01
        mask = torch.tensor([[0.0, 1.0]])
        out = masked_softmax(x, mask, temperature)
        expected_out = torch.tensor([[0.0, 1.0]])
        npt.assert_array_almost_equal(out, expected_out, 4)

        # Negative value case
        x = torch.tensor([[-10.0, -1.0, -5.0]])
        temperature = 0.01
        mask = torch.tensor([[1.0, 1.0, 0.0]])
        out = masked_softmax(x, mask, temperature)
        expected_out = torch.tensor([[0.0, 1.0, 0.0]])
        npt.assert_array_almost_equal(out, expected_out, 4)

        # All values in a row are masked case
        x = torch.tensor([[-5.0, 4.0, 3.0], [2.0, 1.0, 2.0]])
        temperature = 1
        mask = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
        out = masked_softmax(x, mask, temperature)
        expected_out = torch.tensor([[0.0, 0.0, 0.0], [0.4223, 0.1554,
                                                       0.4223]])
        npt.assert_array_almost_equal(out, expected_out, 4)
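
For reference, below is a minimal masked_softmax sketch that would satisfy the assertions above; it is an assumption, not necessarily the project's actual implementation. It scales the scores by the temperature, pushes masked entries to -inf before the softmax, and zeroes out rows in which every action is masked (matching the all-masked case in the last assertion).

import torch
import torch.nn.functional as F

def masked_softmax(x: torch.Tensor, mask: torch.Tensor, temperature: float) -> torch.Tensor:
    # Scale by temperature, then force masked-out entries to -inf so they get zero probability.
    scores = x / temperature
    scores = scores.masked_fill(mask == 0.0, float("-inf"))
    probs = F.softmax(scores, dim=1)
    # A fully masked row comes out as NaN; the test above expects all zeros instead.
    return torch.where(torch.isnan(probs), torch.zeros_like(probs), probs)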
Example #2
    def policy_given_q_values(
        q_scores: torch.Tensor,
        action_names: List[str],
        softmax_temperature: float,
        possible_actions_presence: Optional[torch.Tensor] = None,
    ) -> DqnPolicyActionSet:
        assert q_scores.shape[0] == 1 and len(q_scores.shape) == 2

        if possible_actions_presence is None:
            possible_actions_presence = torch.ones_like(q_scores)
        possible_actions_presence = possible_actions_presence.reshape(1, -1)
        assert possible_actions_presence.shape == q_scores.shape

        # set impossible actions so low that they can't be picked
        q_scores -= (1.0 - possible_actions_presence) * 1e10  # type: ignore

        q_scores_softmax = (
            masked_softmax(q_scores, possible_actions_presence, softmax_temperature)
            .detach()
            .numpy()[0]
        )
        if np.isnan(q_scores_softmax).any() or np.max(q_scores_softmax) < 1e-3:
            q_scores_softmax[:] = 1.0 / q_scores_softmax.shape[0]
        greedy_act_idx = int(torch.argmax(q_scores))
        softmax_act_idx = int(np.random.choice(q_scores.size()[1], p=q_scores_softmax))

        return DqnPolicyActionSet(
            greedy=greedy_act_idx,
            softmax=softmax_act_idx,
            greedy_act_name=action_names[greedy_act_idx],
            softmax_act_name=action_names[softmax_act_idx],
        )
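
An illustrative call with made-up values (assuming DqnPolicyActionSet and masked_softmax are importable from the surrounding codebase): masking action "b" pushes its score down by 1e10, so it can be neither the greedy nor the sampled action.

q_scores = torch.tensor([[1.0, 3.0, 2.0]])
presence = torch.tensor([[1.0, 0.0, 1.0]])  # action "b" is unavailable
action_set = policy_given_q_values(
    q_scores, ["a", "b", "c"], softmax_temperature=1.0,
    possible_actions_presence=presence,
)
print(action_set.greedy_act_name)   # "c" -- the highest-scoring available action
print(action_set.softmax_act_name)  # "a" or "c", sampled from the masked softmax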
Example #3
    def policy_given_q_values(
        q_scores: torch.Tensor,
        softmax_temperature: float,
        possible_actions_presence: torch.Tensor,
    ) -> DqnPolicyActionSet:
        assert q_scores.shape[0] == 1 and len(q_scores.shape) == 2
        possible_actions_presence = possible_actions_presence.reshape(1, -1)
        assert possible_actions_presence.shape == q_scores.shape

        # set impossible actions so low that they can't be picked
        q_scores -= (1.0 - possible_actions_presence) * 1e10

        q_scores_softmax_numpy = (
            masked_softmax(
                q_scores.reshape(1, -1), possible_actions_presence, softmax_temperature
            )
            .detach()
            .numpy()[0]
        )
        if (
            np.isnan(q_scores_softmax_numpy).any()
            or np.max(q_scores_softmax_numpy) < 1e-3
        ):
            q_scores_softmax_numpy[:] = 1.0 / q_scores_softmax_numpy.shape[0]

        return DqnPolicyActionSet(
            greedy=int(torch.argmax(q_scores)),
            softmax=int(np.random.choice(q_scores.size()[1], p=q_scores_softmax_numpy)),
        )
Example #4
    def policy(
        self,
        states: torch.Tensor,
        possible_actions_with_presence: Tuple[torch.Tensor, torch.Tensor],
    ):
        possible_actions, possible_actions_presence = possible_actions_with_presence
        assert states.size()[0] == 1
        assert possible_actions.size()[1] == self.action_dim
        assert possible_actions.size()[0] == possible_actions_presence.size()[0]

        q_scores = self.predict(states, possible_actions)

        # set impossible actions so low that they can't be picked
        q_scores -= (
            1.0 - possible_actions_presence.reshape(1, -1)  # type: ignore
        ) * 1e10

        q_scores_softmax_numpy = masked_softmax(
            q_scores.reshape(1, -1),
            possible_actions_presence.reshape(1, -1),
            self.trainer.rl_temperature,
        ).numpy()[0]
        if (np.isnan(q_scores_softmax_numpy).any()
                or np.max(q_scores_softmax_numpy) < 1e-3):
            q_scores_softmax_numpy[:] = 1.0 / q_scores_softmax_numpy.shape[0]
        return DqnPolicyActionSet(
            greedy=int(torch.argmax(q_scores)),
            softmax=int(
                np.random.choice(q_scores.size()[1],
                                 p=q_scores_softmax_numpy)),
        )
Example #5
    def policy(self, state: torch.Tensor,
               possible_actions_presence: torch.Tensor) -> DqnPolicyActionSet:
        assert state.size()[0] == 1, "Only pass in one state when getting a policy"
        q_scores = self.predict(state)
        assert q_scores.shape[0] == 1

        # set impossible actions so low that they can't be picked
        q_scores -= (1.0 - possible_actions_presence) * 1e10  # type: ignore

        q_scores_softmax = masked_softmax(
            q_scores, possible_actions_presence,
            self.trainer.rl_temperature).numpy()[0]
        if np.isnan(q_scores_softmax).any() or np.max(q_scores_softmax) < 1e-3:
            q_scores_softmax[:] = 1.0 / q_scores_softmax.shape[0]
        return DqnPolicyActionSet(
            greedy=int(torch.argmax(q_scores)),
            softmax=int(
                np.random.choice(q_scores.size()[1], p=q_scores_softmax)),
        )
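
The NaN / max < 1e-3 guard that recurs in these policy examples is a uniform fallback: when masked_softmax yields no usable distribution (for instance, every action is masked and the output row is all zeros, as in Example #1), the policy samples uniformly instead. A small standalone illustration with hypothetical values:

import numpy as np

q_scores_softmax = np.zeros(3)  # e.g. masked_softmax output when every action is masked
if np.isnan(q_scores_softmax).any() or np.max(q_scores_softmax) < 1e-3:
    q_scores_softmax[:] = 1.0 / q_scores_softmax.shape[0]
print(q_scores_softmax)  # [0.3333 0.3333 0.3333] -- uniform over the 3 actions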
Example #6
    def policy(self, state: torch.Tensor,
               possible_actions_presence: torch.Tensor) -> DqnPolicyActionSet:
        assert state.size()[0] == 1, "Only pass in one state when getting a policy"
        assert (self.softmax_temperature is not None
                ), "Please set the softmax temperature before calling policy()"
        state_feature_presence = torch.ones_like(state)
        _, q_scores = self.model((state, state_feature_presence))
        assert q_scores.shape[0] == 1

        # set impossible actions so low that they can't be picked
        q_scores -= (1.0 - possible_actions_presence) * 1e10  # type: ignore

        q_scores_softmax = (masked_softmax(
            q_scores, possible_actions_presence,
            self.softmax_temperature).detach().numpy()[0])
        if np.isnan(q_scores_softmax).any() or np.max(q_scores_softmax) < 1e-3:
            q_scores_softmax[:] = 1.0 / q_scores_softmax.shape[0]
        return DqnPolicyActionSet(
            greedy=int(torch.argmax(q_scores)),
            softmax=int(
                np.random.choice(q_scores.size()[1], p=q_scores_softmax)),
        )
Example #7
    def policy(
        self,
        tiled_states: torch.Tensor,
        possible_actions_with_presence: Tuple[torch.Tensor, torch.Tensor],
    ):
        possible_actions, possible_actions_presence = possible_actions_with_presence
        assert tiled_states.size()[0] == possible_actions.size()[0]
        assert possible_actions.size()[0] == possible_actions_presence.size()[0]
        assert (self.softmax_temperature is not None
                ), "Please set the softmax temperature before calling policy()"

        state_feature_presence = torch.ones_like(tiled_states)
        _, q_scores = self.model((tiled_states, state_feature_presence),
                                 possible_actions_with_presence)
        q_scores = q_scores.reshape(1, -1)

        # set impossible actions so low that they can't be picked
        q_scores -= (
            1.0 - possible_actions_presence.reshape(1, -1)  # type: ignore
        ) * 1e10

        q_scores_softmax_numpy = (masked_softmax(
            q_scores.reshape(1, -1),
            possible_actions_presence.reshape(1, -1),
            self.softmax_temperature,
        ).detach().numpy()[0])
        if (np.isnan(q_scores_softmax_numpy).any()
                or np.max(q_scores_softmax_numpy) < 1e-3):
            q_scores_softmax_numpy[:] = 1.0 / q_scores_softmax_numpy.shape[0]
        return DqnPolicyActionSet(
            greedy=int(torch.argmax(q_scores)),
            softmax=int(
                np.random.choice(q_scores.size()[1],
                                 p=q_scores_softmax_numpy)),
        )
Example #8
    def calculate_cpes(
        self,
        training_batch,
        states,
        next_states,
        all_next_action_scores,
        logged_action_idxs,
        discount_tensor,
        not_done_mask,
    ):
        if not self.calc_cpe_in_training:
            return None, None, None

        if training_batch.extras.metrics is None:
            metrics_reward_concat_real_vals = training_batch.training_input.reward
        else:
            metrics_reward_concat_real_vals = torch.cat(
                (training_batch.training_input.reward,
                 training_batch.extras.metrics),
                dim=1,
            )

        model_propensities_next_states = masked_softmax(
            all_next_action_scores,
            training_batch.training_input.possible_next_actions_mask if
            self.maxq_learning else training_batch.training_input.next_action,
            self.rl_temperature,
        )

        with torch.enable_grad():
            ######### Train separate reward network for CPE evaluation #############
            # FIXME: the reward network should be outputting a tensor, not a q-value object
            reward_estimates = self.reward_network(states).q_values
            reward_estimates_for_logged_actions = reward_estimates.gather(
                1, self.reward_idx_offsets + logged_action_idxs)
            reward_loss = F.mse_loss(reward_estimates_for_logged_actions,
                                     metrics_reward_concat_real_vals)
            reward_loss.backward()
            self._maybe_run_optimizer(self.reward_network_optimizer,
                                      self.minibatches_per_step)

            ######### Train separate q-network for CPE evaluation #############
            metric_q_values = self.q_network_cpe(states).q_values.gather(
                1, self.reward_idx_offsets + logged_action_idxs)
            all_metrics_target_q_values = torch.chunk(
                self.q_network_cpe_target(next_states).q_values.detach(),
                len(self.metrics_to_score),
                dim=1,
            )
            target_metric_q_values = []
            for i, per_metric_target_q_values in enumerate(
                    all_metrics_target_q_values):
                per_metric_next_q_values = torch.sum(
                    per_metric_target_q_values *
                    model_propensities_next_states,
                    1,
                    keepdim=True,
                )
                per_metric_next_q_values = per_metric_next_q_values * not_done_mask
                per_metric_target_q_values = metrics_reward_concat_real_vals[:, i:i + 1] + (
                    discount_tensor * per_metric_next_q_values)
                target_metric_q_values.append(per_metric_target_q_values)

            target_metric_q_values = torch.cat(target_metric_q_values, dim=1)
            metric_q_value_loss = self.q_network_loss(metric_q_values,
                                                      target_metric_q_values)
            metric_q_value_loss.backward()
            self._maybe_run_optimizer(self.q_network_cpe_optimizer,
                                      self.minibatches_per_step)

        # Use the soft update rule to update target network
        self._maybe_soft_update(
            self.q_network_cpe,
            self.q_network_cpe_target,
            self.tau,
            self.minibatches_per_step,
        )

        model_propensities = masked_softmax(
            self.all_action_scores,
            training_batch.training_input.possible_actions_mask
            if self.maxq_learning else training_batch.training_input.action,
            self.rl_temperature,
        )
        model_rewards = reward_estimates[
            :,
            torch.arange(
                self.reward_idx_offsets[0],
                self.reward_idx_offsets[0] + self.num_actions,
            ),
        ]
        return reward_loss, model_rewards, model_propensities
Example #9
    def calculate_cpes(
        self,
        training_batch,
        states,
        logged_action_idxs,
        max_q_action_idxs,
        discount_tensor,
        not_done_mask,
    ):
        if not self.calc_cpe_in_training:
            return None, None, None

        if training_batch.extras.metrics is None:
            metrics_reward_concat_real_vals = training_batch.training_input.reward
        else:
            metrics_reward_concat_real_vals = torch.cat(
                (training_batch.training_input.reward,
                 training_batch.extras.metrics),
                dim=1,
            )

        ######### Train separate reward network for CPE evaluation #############
        # FIXME: the reward network should be outputting a tensor, not a q-value object
        reward_estimates = self.reward_network(states).q_values
        reward_estimates_for_logged_actions = reward_estimates.gather(
            1, self.reward_idx_offsets + logged_action_idxs)
        reward_loss = F.mse_loss(reward_estimates_for_logged_actions,
                                 metrics_reward_concat_real_vals)
        self.reward_network_optimizer.zero_grad()
        reward_loss.backward()
        self.reward_network_optimizer.step()

        ######### Train separate q-network for CPE evaluation #############
        metric_q_values = self.q_network_cpe(states).q_values.gather(
            1, self.reward_idx_offsets + logged_action_idxs)
        metric_target_q_values = self.q_network_cpe_target(
            states).q_values.detach()
        max_q_values_metrics = metric_target_q_values.gather(
            1, self.reward_idx_offsets + max_q_action_idxs)
        filtered_max_q_values_metrics = max_q_values_metrics * not_done_mask
        if self.minibatch < self.reward_burnin:
            target_metric_q_values = metrics_reward_concat_real_vals
        else:
            target_metric_q_values = metrics_reward_concat_real_vals + (
                discount_tensor * filtered_max_q_values_metrics)
        metric_q_value_loss = self.q_network_loss(metric_q_values,
                                                  target_metric_q_values)
        self.q_network_cpe.zero_grad()
        metric_q_value_loss.backward()
        self.q_network_cpe_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.q_network_cpe, self.q_network_cpe_target,
                              1.0)
        else:
            # Use the soft update rule to update target network
            self._soft_update(self.q_network_cpe, self.q_network_cpe_target,
                              self.tau)

        model_propensities = masked_softmax(
            self.all_action_scores,
            training_batch.training_input.possible_actions_mask
            if self.maxq_learning else training_batch.training_input.action,
            self.rl_temperature,
        )
        model_rewards = reward_estimates[
            :,
            torch.arange(
                self.reward_idx_offsets[0],
                self.reward_idx_offsets[0] + self.num_actions,
            ),
        ]
        return reward_loss, model_rewards, model_propensities
Example #10
    def create_from_tensors(
        cls,
        trainer: RLTrainer,
        mdp_ids: np.ndarray,
        sequence_numbers: torch.Tensor,
        states: torch.Tensor,
        actions: torch.Tensor,
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_state_concat: Optional[torch.Tensor],
        possible_actions_mask: torch.Tensor,
        metrics: Optional[torch.Tensor] = None,
    ):
        with torch.no_grad():
            # Switch to evaluation mode for the network
            old_q_train_state = trainer.q_network.training
            old_reward_train_state = trainer.reward_network.training
            trainer.q_network.train(False)
            trainer.reward_network.train(False)

            if possible_actions_state_concat is not None:
                state_action_pairs = torch.cat((states, actions), dim=1)

                # Parametric actions
                rewards = rewards
                model_values = trainer.q_network(possible_actions_state_concat)
                assert (
                    model_values.shape[0] *
                    model_values.shape[1] == possible_actions_mask.shape[0] *
                    possible_actions_mask.shape[1]), (
                        "Invalid shapes: " + str(model_values.shape) + " != " +
                        str(possible_actions_mask.shape))
                model_values = model_values.reshape(
                    possible_actions_mask.shape)

                model_rewards = trainer.reward_network(
                    possible_actions_state_concat)
                assert (
                    model_rewards.shape[0] *
                    model_rewards.shape[1] == possible_actions_mask.shape[0] *
                    possible_actions_mask.shape[1]), (
                        "Invalid shapes: " + str(model_rewards.shape) +
                        " != " + str(possible_actions_mask.shape))
                model_rewards = model_rewards.reshape(
                    possible_actions_mask.shape)

                model_values_for_logged_action = trainer.q_network(
                    state_action_pairs)
                model_rewards_for_logged_action = trainer.reward_network(
                    state_action_pairs)

                action_mask = (
                    torch.abs(model_values - model_values_for_logged_action) <
                    1e-3).float()

                model_metrics = None
                model_metrics_for_logged_action = None
                model_metrics_values = None
                model_metrics_values_for_logged_action = None
            else:
                action_mask = actions.float()

                # Switch to evaluation mode for the network
                old_q_cpe_train_state = trainer.q_network_cpe.training
                trainer.q_network_cpe.train(False)

                # Discrete actions
                rewards = trainer.boost_rewards(rewards, actions)
                model_values = trainer.get_detached_q_values(states)[0]
                assert model_values.shape == actions.shape, (
                    "Invalid shape: " + str(model_values.shape) + " != " +
                    str(actions.shape))
                assert model_values.shape == possible_actions_mask.shape, (
                    "Invalid shape: " + str(model_values.shape) + " != " +
                    str(possible_actions_mask.shape))
                model_values_for_logged_action = torch.sum(model_values *
                                                           action_mask,
                                                           dim=1,
                                                           keepdim=True)

                rewards_and_metric_rewards = trainer.reward_network(states)

                num_actions = trainer.num_actions

                model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
                assert model_rewards.shape == actions.shape, (
                    "Invalid shape: " + str(model_rewards.shape) + " != " +
                    str(actions.shape))
                model_rewards_for_logged_action = torch.sum(model_rewards *
                                                            action_mask,
                                                            dim=1,
                                                            keepdim=True)

                model_metrics = rewards_and_metric_rewards[:, num_actions:]

                assert model_metrics.shape[1] % num_actions == 0, (
                    "Invalid metrics shape: " + str(model_metrics.shape) +
                    " " + str(num_actions))
                num_metrics = model_metrics.shape[1] // num_actions

                if num_metrics == 0:
                    model_metrics_values = None
                    model_metrics_for_logged_action = None
                    model_metrics_values_for_logged_action = None
                else:
                    model_metrics_values = trainer.q_network_cpe(
                        states)[:, num_actions:]
                    assert model_metrics_values.shape[1] == num_actions * num_metrics, (
                        "Invalid shape: " + str(model_metrics_values.shape[1]) +
                        " != " + str(actions.shape[1] * num_metrics))

                    model_metrics_for_logged_action_list = []
                    model_metrics_values_for_logged_action_list = []
                    for metric_index in range(num_metrics):
                        metric_start = metric_index * num_actions
                        metric_end = (metric_index + 1) * num_actions
                        model_metrics_for_logged_action_list.append(
                            torch.sum(
                                model_metrics[:, metric_start:metric_end] *
                                action_mask,
                                dim=1,
                                keepdim=True,
                            ))

                        model_metrics_values_for_logged_action_list.append(
                            torch.sum(
                                model_metrics_values[:,
                                                     metric_start:metric_end] *
                                action_mask,
                                dim=1,
                                keepdim=True,
                            ))
                    model_metrics_for_logged_action = torch.cat(
                        model_metrics_for_logged_action_list, dim=1)
                    model_metrics_values_for_logged_action = torch.cat(
                        model_metrics_values_for_logged_action_list, dim=1)

                # Switch back to the old mode
                trainer.q_network_cpe.train(old_q_cpe_train_state)

            # Switch back to the old mode
            trainer.q_network.train(old_q_train_state)
            trainer.reward_network.train(old_reward_train_state)

            return cls(
                mdp_id=mdp_ids,
                sequence_number=sequence_numbers,
                logged_propensities=propensities,
                logged_rewards=rewards,
                action_mask=action_mask,
                model_rewards=model_rewards,
                model_rewards_for_logged_action=model_rewards_for_logged_action,
                model_values=model_values,
                model_values_for_logged_action=model_values_for_logged_action,
                model_metrics_values=model_metrics_values,
                model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
                model_propensities=masked_softmax(model_values,
                                                  possible_actions_mask,
                                                  trainer.rl_temperature),
                logged_metrics=metrics,
                model_metrics=model_metrics,
                model_metrics_for_logged_action=model_metrics_for_logged_action,
                # Will compute later
                logged_values=None,
                logged_metrics_values=None,
                possible_actions_state_concat=possible_actions_state_concat,
                possible_actions_mask=possible_actions_mask,
            )
Example #11
    def create_from_tensors(
        cls,
        trainer: RLTrainer,
        mdp_ids: np.ndarray,
        sequence_numbers: torch.Tensor,
        states: Union[mt.State, torch.Tensor],
        actions: Union[mt.Action, torch.Tensor],
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_mask: torch.Tensor,
        possible_actions: Optional[mt.FeatureVector] = None,
        max_num_actions: Optional[int] = None,
        metrics: Optional[torch.Tensor] = None,
    ):
        with torch.no_grad():
            # Switch to evaluation mode for the network
            old_q_train_state = trainer.q_network.training
            old_reward_train_state = trainer.reward_network.training
            trainer.q_network.train(False)
            trainer.reward_network.train(False)

            if max_num_actions:
                # Parametric model CPE
                state_action_pairs = mt.StateAction(state=states, action=actions)
                tiled_state = mt.FeatureVector(
                    states.float_features.repeat(1, max_num_actions).reshape(
                        -1, states.float_features.shape[1]
                    )
                )
                # Get Q-value of action taken
                possible_actions_state_concat = mt.StateAction(
                    state=tiled_state, action=possible_actions
                )

                # Parametric actions
                model_values = trainer.q_network(possible_actions_state_concat).q_value
                assert (
                    model_values.shape[0] * model_values.shape[1]
                    == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
                ), (
                    "Invalid shapes: "
                    + str(model_values.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_values = model_values.reshape(possible_actions_mask.shape)

                model_rewards = trainer.reward_network(
                    possible_actions_state_concat
                ).q_value
                assert (
                    model_rewards.shape[0] * model_rewards.shape[1]
                    == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
                ), (
                    "Invalid shapes: "
                    + str(model_rewards.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_rewards = model_rewards.reshape(possible_actions_mask.shape)

                model_values_for_logged_action = trainer.q_network(
                    state_action_pairs
                ).q_value
                model_rewards_for_logged_action = trainer.reward_network(
                    state_action_pairs
                ).q_value

                action_mask = (
                    torch.abs(model_values - model_values_for_logged_action) < 1e-3
                ).float()

                model_metrics = None
                model_metrics_for_logged_action = None
                model_metrics_values = None
                model_metrics_values_for_logged_action = None
            else:
                action_mask = actions.float()

                # Switch to evaluation mode for the network
                old_q_cpe_train_state = trainer.q_network_cpe.training
                trainer.q_network_cpe.train(False)

                # Discrete actions
                rewards = trainer.boost_rewards(rewards, actions)
                model_values = trainer.get_detached_q_values(states)[0]
                assert model_values.shape == actions.shape, (
                    "Invalid shape: "
                    + str(model_values.shape)
                    + " != "
                    + str(actions.shape)
                )
                assert model_values.shape == possible_actions_mask.shape, (
                    "Invalid shape: "
                    + str(model_values.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_values_for_logged_action = torch.sum(
                    model_values * action_mask, dim=1, keepdim=True
                )

                if isinstance(states, mt.State):
                    states = mt.StateInput(state=states)

                rewards_and_metric_rewards = trainer.reward_network(states)

                # In case we reuse the modular for Q-network
                if hasattr(rewards_and_metric_rewards, "q_values"):
                    rewards_and_metric_rewards = rewards_and_metric_rewards.q_values

                num_actions = trainer.num_actions

                model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
                assert model_rewards.shape == actions.shape, (
                    "Invalid shape: "
                    + str(model_rewards.shape)
                    + " != "
                    + str(actions.shape)
                )
                model_rewards_for_logged_action = torch.sum(
                    model_rewards * action_mask, dim=1, keepdim=True
                )

                model_metrics = rewards_and_metric_rewards[:, num_actions:]

                assert model_metrics.shape[1] % num_actions == 0, (
                    "Invalid metrics shape: "
                    + str(model_metrics.shape)
                    + " "
                    + str(num_actions)
                )
                num_metrics = model_metrics.shape[1] // num_actions

                if num_metrics == 0:
                    model_metrics_values = None
                    model_metrics_for_logged_action = None
                    model_metrics_values_for_logged_action = None
                else:
                    model_metrics_values = trainer.q_network_cpe(states)
                    # Backward compatibility
                    if hasattr(model_metrics_values, "q_values"):
                        model_metrics_values = model_metrics_values.q_values
                    model_metrics_values = model_metrics_values[:, num_actions:]
                    assert model_metrics_values.shape[1] == num_actions * num_metrics, (
                        "Invalid shape: "
                        + str(model_metrics_values.shape[1])
                        + " != "
                        + str(actions.shape[1] * num_metrics)
                    )

                    model_metrics_for_logged_action_list = []
                    model_metrics_values_for_logged_action_list = []
                    for metric_index in range(num_metrics):
                        metric_start = metric_index * num_actions
                        metric_end = (metric_index + 1) * num_actions
                        model_metrics_for_logged_action_list.append(
                            torch.sum(
                                model_metrics[:, metric_start:metric_end] * action_mask,
                                dim=1,
                                keepdim=True,
                            )
                        )

                        model_metrics_values_for_logged_action_list.append(
                            torch.sum(
                                model_metrics_values[:, metric_start:metric_end]
                                * action_mask,
                                dim=1,
                                keepdim=True,
                            )
                        )
                    model_metrics_for_logged_action = torch.cat(
                        model_metrics_for_logged_action_list, dim=1
                    )
                    model_metrics_values_for_logged_action = torch.cat(
                        model_metrics_values_for_logged_action_list, dim=1
                    )

                # Switch back to the old mode
                trainer.q_network_cpe.train(old_q_cpe_train_state)

            # Switch back to the old mode
            trainer.q_network.train(old_q_train_state)
            trainer.reward_network.train(old_reward_train_state)

            return cls(
                mdp_id=mdp_ids,
                sequence_number=sequence_numbers,
                logged_propensities=propensities,
                logged_rewards=rewards,
                action_mask=action_mask,
                model_rewards=model_rewards,
                model_rewards_for_logged_action=model_rewards_for_logged_action,
                model_values=model_values,
                model_values_for_logged_action=model_values_for_logged_action,
                model_metrics_values=model_metrics_values,
                model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
                model_propensities=masked_softmax(
                    model_values, possible_actions_mask, trainer.rl_temperature
                ),
                logged_metrics=metrics,
                model_metrics=model_metrics,
                model_metrics_for_logged_action=model_metrics_for_logged_action,
                # Will compute later
                logged_values=None,
                logged_metrics_values=None,
                possible_actions_mask=possible_actions_mask,
            )
Example #12
    def train(self, training_samples: TrainingDataPage):

        if self.minibatch == 0:
            # Assume that the tensors are the right shape after the first minibatch
            assert (training_samples.states.shape[0] == self.minibatch_size
                    ), "Invalid shape: " + str(training_samples.states.shape)
            assert training_samples.actions.shape == torch.Size([
                self.minibatch_size, len(self._actions)
            ]), "Invalid shape: " + str(training_samples.actions.shape)
            assert training_samples.rewards.shape == torch.Size(
                [self.minibatch_size,
                 1]), "Invalid shape: " + str(training_samples.rewards.shape)
            assert (training_samples.next_states.shape ==
                    training_samples.states.shape), "Invalid shape: " + str(
                        training_samples.next_states.shape)
            assert (training_samples.not_terminal.shape ==
                    training_samples.rewards.shape), "Invalid shape: " + str(
                        training_samples.not_terminal.shape)
            if training_samples.possible_next_actions_mask is not None:
                assert (
                    training_samples.possible_next_actions_mask.shape ==
                    training_samples.actions.shape), (
                        "Invalid shape: " +
                        str(training_samples.possible_next_actions_mask.shape))
            if training_samples.propensities is not None:
                assert (training_samples.propensities.shape == training_samples
                        .rewards.shape), "Invalid shape: " + str(
                            training_samples.propensities.shape)
            if training_samples.metrics is not None:
                assert (
                    training_samples.metrics.shape[0] == self.minibatch_size
                ), "Invalid shape: " + str(training_samples.metrics.shape)

        boosted_rewards = self.boost_rewards(training_samples.rewards,
                                             training_samples.actions)

        self.minibatch += 1
        states = training_samples.states.detach().requires_grad_(True)
        actions = training_samples.actions
        rewards = boosted_rewards
        discount_tensor = torch.full(training_samples.time_diffs.shape,
                                     self.gamma).type(self.dtype)
        not_done_mask = training_samples.not_terminal

        if self.use_seq_num_diff_as_time_diff:
            discount_tensor = discount_tensor.pow(training_samples.time_diffs)

        all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
            training_samples.next_states)
        if self.maxq_learning:
            # Compute max a' Q(s', a') over all possible actions using target network
            next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
                all_next_q_values,
                all_next_q_values_target,
                training_samples.possible_next_actions_mask,
            )
        else:
            # SARSA
            next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
                all_next_q_values,
                all_next_q_values_target,
                training_samples.next_actions,
            )

        filtered_next_q_vals = next_q_values * not_done_mask

        if self.minibatch < self.reward_burnin:
            target_q_values = rewards
        else:
            target_q_values = rewards + (discount_tensor *
                                         filtered_next_q_vals)

        # Get Q-value of action taken
        all_q_values = self.q_network(states)
        self.all_action_scores = all_q_values.detach()
        q_values = torch.sum(all_q_values * actions, 1, keepdim=True)

        loss = self.q_network_loss(q_values, target_q_values)
        self.loss = loss.detach()

        self.q_network_optimizer.zero_grad()
        loss.backward()
        if self.gradient_handler:
            self.gradient_handler(self.q_network.parameters())
        self.q_network_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.q_network, self.q_network_target, 1.0)
        else:
            # Use the soft update rule to update target network
            self._soft_update(self.q_network, self.q_network_target, self.tau)

        if training_samples.metrics is None:
            metrics_reward_concat_real_vals = training_samples.rewards
        else:
            metrics_reward_concat_real_vals = torch.cat(
                (training_samples.rewards, training_samples.metrics), dim=1)

        ######### Train separate reward network for CPE evaluation #############
        reward_estimates = self.reward_network(states)
        logged_action_idxs = actions.argmax(dim=1, keepdim=True)
        reward_estimates_for_logged_actions = reward_estimates.gather(
            1, self.reward_idx_offsets + logged_action_idxs)
        reward_loss = F.mse_loss(reward_estimates_for_logged_actions,
                                 metrics_reward_concat_real_vals)
        self.reward_network_optimizer.zero_grad()
        reward_loss.backward()
        self.reward_network_optimizer.step()

        ######### Train separate q-network for CPE evaluation #############
        metric_q_values = self.q_network_cpe(states).gather(
            1, self.reward_idx_offsets + logged_action_idxs)
        metric_target_q_values = self.q_network_cpe_target(states).detach()
        max_q_values_metrics = metric_target_q_values.gather(
            1, self.reward_idx_offsets + max_q_action_idxs)
        filtered_max_q_values_metrics = max_q_values_metrics * not_done_mask
        if self.minibatch < self.reward_burnin:
            target_metric_q_values = metrics_reward_concat_real_vals
        else:
            target_metric_q_values = metrics_reward_concat_real_vals + (
                discount_tensor * filtered_max_q_values_metrics)
        metric_q_value_loss = self.q_network_loss(metric_q_values,
                                                  target_metric_q_values)
        self.q_network_cpe.zero_grad()
        metric_q_value_loss.backward()
        self.q_network_cpe_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.q_network_cpe, self.q_network_cpe_target,
                              1.0)
        else:
            # Use the soft update rule to update target network
            self._soft_update(self.q_network_cpe, self.q_network_cpe_target,
                              self.tau)

        model_propensities = masked_softmax(
            self.all_action_scores,
            training_samples.possible_actions_mask,
            self.rl_temperature,
        )

        self.loss_reporter.report(
            td_loss=self.loss,
            reward_loss=reward_loss,
            logged_actions=logged_action_idxs,
            logged_propensities=training_samples.propensities,
            logged_rewards=rewards,
            logged_values=None,  # Compute at end of each epoch for CPE
            model_propensities=model_propensities,
            model_rewards=reward_estimates[
                :,
                torch.arange(
                    self.reward_idx_offsets[0],
                    self.reward_idx_offsets[0] + self.num_actions,
                ),
            ],
            model_values=self.all_action_scores,
            model_values_on_logged_actions=None,  # Compute at end of each epoch for CPE
            model_action_idxs=self.get_max_q_values(
                self.all_action_scores,
                training_samples.possible_actions_mask)[1],
        )

        training_metadata = {}
        training_metadata["model_rewards"] = reward_estimates.detach().cpu().numpy()
        return training_metadata
Example #13
    def create_from_tensors_parametric_dqn(
        cls,
        trainer: ParametricDQNTrainer,
        mdp_ids: np.ndarray,
        sequence_numbers: torch.Tensor,
        states: rlt.PreprocessedFeatureVector,
        actions: rlt.PreprocessedFeatureVector,
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_mask: torch.Tensor,
        possible_actions: rlt.PreprocessedFeatureVector,
        max_num_actions: int,
        metrics: Optional[torch.Tensor] = None,
    ):
        old_q_train_state = trainer.q_network.training
        old_reward_train_state = trainer.reward_network.training
        trainer.q_network.train(False)
        trainer.reward_network.train(False)

        state_action_pairs = rlt.PreprocessedStateAction(state=states,
                                                         action=actions)
        tiled_state = states.float_features.repeat(1, max_num_actions).reshape(
            -1, states.float_features.shape[1])
        assert possible_actions is not None
        # Get Q-value of action taken
        possible_actions_state_concat = rlt.PreprocessedStateAction(
            state=rlt.PreprocessedFeatureVector(float_features=tiled_state),
            action=possible_actions,
        )

        # FIXME: model_values, model_values_for_logged_action, and model_metrics_values
        # should be calculated using q_network_cpe (as in discrete dqn).
        # q_network_cpe has not been added in parametric dqn yet.
        model_values = trainer.q_network(
            possible_actions_state_concat).q_value  # type: ignore
        optimal_q_values, _ = trainer.get_detached_q_values(
            possible_actions_state_concat.state,
            possible_actions_state_concat.action)
        eval_action_idxs = None

        assert (model_values.shape[1] == 1
                and model_values.shape[0] == possible_actions_mask.shape[0] *
                possible_actions_mask.shape[1]), (
                    "Invalid shapes: " + str(model_values.shape) + " != " +
                    str(possible_actions_mask.shape))
        model_values = model_values.reshape(possible_actions_mask.shape)
        optimal_q_values = optimal_q_values.reshape(
            possible_actions_mask.shape)
        model_propensities = masked_softmax(optimal_q_values,
                                            possible_actions_mask,
                                            trainer.rl_temperature)

        rewards_and_metric_rewards = trainer.reward_network(
            possible_actions_state_concat).q_value  # type: ignore
        model_rewards = rewards_and_metric_rewards[:, :1]
        assert (model_rewards.shape[0] *
                model_rewards.shape[1] == possible_actions_mask.shape[0] *
                possible_actions_mask.shape[1]), (
                    "Invalid shapes: " + str(model_rewards.shape) + " != " +
                    str(possible_actions_mask.shape))
        model_rewards = model_rewards.reshape(possible_actions_mask.shape)

        model_metrics = rewards_and_metric_rewards[:, 1:]
        model_metrics = model_metrics.reshape(possible_actions_mask.shape[0],
                                              -1)

        model_values_for_logged_action = trainer.q_network(
            state_action_pairs).q_value
        model_rewards_and_metrics_for_logged_action = trainer.reward_network(
            state_action_pairs).q_value
        model_rewards_for_logged_action = model_rewards_and_metrics_for_logged_action[:, :1]

        action_dim = possible_actions.float_features.shape[1]
        action_mask = torch.all(
            possible_actions.float_features.view(
                -1, max_num_actions,
                action_dim) == actions.float_features.unsqueeze(dim=1),
            dim=2,
        ).float()
        assert torch.all(action_mask.sum(dim=1) == 1)
        num_metrics = model_metrics.shape[1] // max_num_actions

        model_metrics_values = None
        model_metrics_for_logged_action = None
        model_metrics_values_for_logged_action = None
        if num_metrics > 0:
            # FIXME: calculate model_metrics_values when q_network_cpe is added
            # to parametric dqn
            model_metrics_values = model_values.repeat(1, num_metrics)

        trainer.q_network.train(old_q_train_state)  # type: ignore
        trainer.reward_network.train(old_reward_train_state)  # type: ignore

        return cls(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            logged_propensities=propensities,
            logged_rewards=rewards,
            action_mask=action_mask,
            model_rewards=model_rewards,
            model_rewards_for_logged_action=model_rewards_for_logged_action,
            model_values=model_values,
            model_values_for_logged_action=model_values_for_logged_action,
            model_metrics_values=model_metrics_values,
            model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
            model_propensities=model_propensities,
            logged_metrics=metrics,
            model_metrics=model_metrics,
            model_metrics_for_logged_action=model_metrics_for_logged_action,
            # Will compute later
            logged_values=None,
            logged_metrics_values=None,
            possible_actions_mask=possible_actions_mask,
            optimal_q_values=optimal_q_values,
            eval_action_idxs=eval_action_idxs,
        )
Example #14
    def create_from_tensors_dqn(
        cls,
        trainer: DQNTrainer,
        mdp_ids: np.ndarray,
        sequence_numbers: torch.Tensor,
        states: rlt.PreprocessedFeatureVector,
        actions: rlt.PreprocessedFeatureVector,
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_mask: torch.Tensor,
        metrics: Optional[torch.Tensor] = None,
    ):
        old_q_train_state = trainer.q_network.training
        old_reward_train_state = trainer.reward_network.training
        old_q_cpe_train_state = trainer.q_network_cpe.training
        trainer.q_network.train(False)
        trainer.reward_network.train(False)
        trainer.q_network_cpe.train(False)

        num_actions = trainer.num_actions
        action_mask = actions.float()  # type: ignore

        rewards = trainer.boost_rewards(rewards, actions)  # type: ignore
        model_values = trainer.q_network_cpe(
            rlt.PreprocessedState(state=states)).q_values[:, 0:num_actions]
        optimal_q_values, _ = trainer.get_detached_q_values(
            states  # type: ignore
        )
        eval_action_idxs = trainer.get_max_q_values(  # type: ignore
            optimal_q_values, possible_actions_mask)[1]
        model_propensities = masked_softmax(optimal_q_values,
                                            possible_actions_mask,
                                            trainer.rl_temperature)
        assert model_values.shape == actions.shape, (  # type: ignore
            "Invalid shape: " + str(model_values.shape)  # type: ignore
            + " != " + str(actions.shape)  # type: ignore
        )
        assert model_values.shape == possible_actions_mask.shape, (  # type: ignore
            "Invalid shape: " + str(model_values.shape)  # type: ignore
            + " != " + str(possible_actions_mask.shape)  # type: ignore
        )
        model_values_for_logged_action = torch.sum(model_values * action_mask,
                                                   dim=1,
                                                   keepdim=True)

        rewards_and_metric_rewards = trainer.reward_network(
            rlt.PreprocessedState(state=states))

        # In case we reuse the modular for Q-network
        if hasattr(rewards_and_metric_rewards, "q_values"):
            rewards_and_metric_rewards = rewards_and_metric_rewards.q_values

        model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
        assert model_rewards.shape == actions.shape, (  # type: ignore
            "Invalid shape: " + str(model_rewards.shape)  # type: ignore
            + " != " + str(actions.shape)  # type: ignore
        )
        model_rewards_for_logged_action = torch.sum(model_rewards *
                                                    action_mask,
                                                    dim=1,
                                                    keepdim=True)

        model_metrics = rewards_and_metric_rewards[:, num_actions:]

        assert model_metrics.shape[1] % num_actions == 0, (
            "Invalid metrics shape: " + str(model_metrics.shape) + " " +
            str(num_actions))
        num_metrics = model_metrics.shape[1] // num_actions

        if num_metrics == 0:
            model_metrics_values = None
            model_metrics_for_logged_action = None
            model_metrics_values_for_logged_action = None
        else:
            model_metrics_values = trainer.q_network_cpe(
                rlt.PreprocessedState(state=states))
            # Backward compatibility
            if hasattr(model_metrics_values, "q_values"):
                model_metrics_values = model_metrics_values.q_values
            model_metrics_values = model_metrics_values[:, num_actions:]
            assert (model_metrics_values.shape[1] == num_actions *
                    num_metrics), (  # type: ignore
                        "Invalid shape: " +
                        str(model_metrics_values.shape[1])  # type: ignore
                        + " != " +
                        str(actions.shape[1] * num_metrics)  # type: ignore
                    )

            model_metrics_for_logged_action_list = []
            model_metrics_values_for_logged_action_list = []
            for metric_index in range(num_metrics):
                metric_start = metric_index * num_actions
                metric_end = (metric_index + 1) * num_actions
                model_metrics_for_logged_action_list.append(
                    torch.sum(
                        model_metrics[:, metric_start:metric_end] *
                        action_mask,
                        dim=1,
                        keepdim=True,
                    ))

                model_metrics_values_for_logged_action_list.append(
                    torch.sum(
                        model_metrics_values[:, metric_start:metric_end] *
                        action_mask,
                        dim=1,
                        keepdim=True,
                    ))
            model_metrics_for_logged_action = torch.cat(
                model_metrics_for_logged_action_list, dim=1)
            model_metrics_values_for_logged_action = torch.cat(
                model_metrics_values_for_logged_action_list, dim=1)

        trainer.q_network_cpe.train(old_q_cpe_train_state)  # type: ignore
        trainer.q_network.train(old_q_train_state)  # type: ignore
        trainer.reward_network.train(old_reward_train_state)  # type: ignore

        return cls(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            logged_propensities=propensities,
            logged_rewards=rewards,
            action_mask=action_mask,
            model_rewards=model_rewards,
            model_rewards_for_logged_action=model_rewards_for_logged_action,
            model_values=model_values,
            model_values_for_logged_action=model_values_for_logged_action,
            model_metrics_values=model_metrics_values,
            model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
            model_propensities=model_propensities,
            logged_metrics=metrics,
            model_metrics=model_metrics,
            model_metrics_for_logged_action=model_metrics_for_logged_action,
            # Will compute later
            logged_values=None,
            logged_metrics_values=None,
            possible_actions_mask=possible_actions_mask,
            optimal_q_values=optimal_q_values,
            eval_action_idxs=eval_action_idxs,
        )
Example #15
    def create_from_tensors(
        cls,
        trainer: DQNTrainer,
        mdp_ids: np.ndarray,
        sequence_numbers: torch.Tensor,
        states: rlt.PreprocessedFeatureVector,
        actions: rlt.PreprocessedFeatureVector,
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_mask: torch.Tensor,
        possible_actions: Optional[rlt.PreprocessedFeatureVector] = None,
        max_num_actions: Optional[int] = None,
        metrics: Optional[torch.Tensor] = None,
    ):
        # Switch to evaluation mode for the network
        old_q_train_state = trainer.q_network.training
        old_reward_train_state = trainer.reward_network.training
        trainer.q_network.train(False)
        trainer.reward_network.train(False)

        if max_num_actions:
            # Parametric model CPE
            state_action_pairs = rlt.PreprocessedStateAction(
                state=states, action=actions
            )
            tiled_state = states.float_features.repeat(1, max_num_actions).reshape(
                -1, states.float_features.shape[1]
            )
            assert possible_actions is not None
            # Get Q-value of action taken
            possible_actions_state_concat = rlt.PreprocessedStateAction(
                state=rlt.PreprocessedFeatureVector(float_features=tiled_state),
                action=possible_actions,
            )

            # Parametric actions
            # FIXME: model_values and model propensities should be calculated
            # as in discrete dqn model
            model_values = trainer.q_network(
                possible_actions_state_concat
            ).q_value  # type: ignore
            optimal_q_values = model_values
            eval_action_idxs = None

            assert (
                model_values.shape[0] * model_values.shape[1]
                == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
            ), (
                "Invalid shapes: "
                + str(model_values.shape)
                + " != "
                + str(possible_actions_mask.shape)
            )
            model_values = model_values.reshape(possible_actions_mask.shape)
            model_propensities = masked_softmax(
                model_values, possible_actions_mask, trainer.rl_temperature
            )

            model_rewards = trainer.reward_network(
                possible_actions_state_concat
            ).q_value  # type: ignore
            assert (
                model_rewards.shape[0] * model_rewards.shape[1]
                == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
            ), (
                "Invalid shapes: "
                + str(model_rewards.shape)
                + " != "
                + str(possible_actions_mask.shape)
            )
            model_rewards = model_rewards.reshape(possible_actions_mask.shape)

            model_values_for_logged_action = trainer.q_network(
                state_action_pairs
            ).q_value
            model_rewards_for_logged_action = trainer.reward_network(
                state_action_pairs
            ).q_value

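            # Approximate a one-hot mask for the logged action by matching each
            # possible action's Q-value against the logged action's Q-value
            # (within 1e-3); near-ties can mark more than one action.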
            action_mask = (
                torch.abs(model_values - model_values_for_logged_action) < 1e-3
            ).float()

            model_metrics = None
            model_metrics_for_logged_action = None
            model_metrics_values = None
            model_metrics_values_for_logged_action = None
        else:
            num_actions = trainer.num_actions
            action_mask = actions.float()  # type: ignore

            # Switch to evaluation mode for the network
            old_q_cpe_train_state = trainer.q_network_cpe.training
            trainer.q_network_cpe.train(False)

            # Discrete actions
            rewards = trainer.boost_rewards(rewards, actions)  # type: ignore
            model_values = trainer.q_network_cpe(
                rlt.PreprocessedState(state=states)
            ).q_values[:, 0:num_actions]
            optimal_q_values = trainer.get_detached_q_values(states)[0]  # type: ignore
            eval_action_idxs = trainer.get_max_q_values(  # type: ignore
                optimal_q_values, possible_actions_mask
            )[1]
            model_propensities = masked_softmax(
                optimal_q_values, possible_actions_mask, trainer.rl_temperature
            )
            assert model_values.shape == actions.shape, (  # type: ignore
                "Invalid shape: "
                + str(model_values.shape)  # type: ignore
                + " != "
                + str(actions.shape)  # type: ignore
            )
            assert model_values.shape == possible_actions_mask.shape, (  # type: ignore
                "Invalid shape: "
                + str(model_values.shape)  # type: ignore
                + " != "
                + str(possible_actions_mask.shape)  # type: ignore
            )
            model_values_for_logged_action = torch.sum(
                model_values * action_mask, dim=1, keepdim=True
            )

            rewards_and_metric_rewards = trainer.reward_network(
                rlt.PreprocessedState(state=states)
            )

            # In case the reward network reuses the modular Q-network class,
            # unwrap its q_values output
            if hasattr(rewards_and_metric_rewards, "q_values"):
                rewards_and_metric_rewards = rewards_and_metric_rewards.q_values

            model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
            assert model_rewards.shape == actions.shape, (  # type: ignore
                "Invalid shape: "
                + str(model_rewards.shape)  # type: ignore
                + " != "
                + str(actions.shape)  # type: ignore
            )
            model_rewards_for_logged_action = torch.sum(
                model_rewards * action_mask, dim=1, keepdim=True
            )

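            # Columns past the first num_actions hold metric predictions,
            # laid out as num_actions columns per metric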
            model_metrics = rewards_and_metric_rewards[:, num_actions:]

            assert model_metrics.shape[1] % num_actions == 0, (
                "Invalid metrics shape: "
                + str(model_metrics.shape)
                + " "
                + str(num_actions)
            )
            num_metrics = model_metrics.shape[1] // num_actions

            if num_metrics == 0:
                model_metrics_values = None
                model_metrics_for_logged_action = None
                model_metrics_values_for_logged_action = None
            else:
                model_metrics_values = trainer.q_network_cpe(
                    rlt.PreprocessedState(state=states)
                )
                # Backward compatibility
                if hasattr(model_metrics_values, "q_values"):
                    model_metrics_values = model_metrics_values.q_values
                model_metrics_values = model_metrics_values[:, num_actions:]
                assert (
                    model_metrics_values.shape[1] == num_actions * num_metrics
                ), (  # type: ignore
                    "Invalid shape: "
                    + str(model_metrics_values.shape[1])  # type: ignore
                    + " != "
                    + str(actions.shape[1] * num_metrics)  # type: ignore
                )

                model_metrics_for_logged_action_list = []
                model_metrics_values_for_logged_action_list = []
                for metric_index in range(num_metrics):
                    metric_start = metric_index * num_actions
                    metric_end = (metric_index + 1) * num_actions
                    model_metrics_for_logged_action_list.append(
                        torch.sum(
                            model_metrics[:, metric_start:metric_end] * action_mask,
                            dim=1,
                            keepdim=True,
                        )
                    )

                    model_metrics_values_for_logged_action_list.append(
                        torch.sum(
                            model_metrics_values[:, metric_start:metric_end]
                            * action_mask,
                            dim=1,
                            keepdim=True,
                        )
                    )
                model_metrics_for_logged_action = torch.cat(
                    model_metrics_for_logged_action_list, dim=1
                )
                model_metrics_values_for_logged_action = torch.cat(
                    model_metrics_values_for_logged_action_list, dim=1
                )

            # Switch back to the old mode
            trainer.q_network_cpe.train(old_q_cpe_train_state)  # type: ignore

        # Switch back to the old mode
        trainer.q_network.train(old_q_train_state)  # type: ignore
        trainer.reward_network.train(old_reward_train_state)  # type: ignore

        return cls(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            logged_propensities=propensities,
            logged_rewards=rewards,
            action_mask=action_mask,
            model_rewards=model_rewards,
            model_rewards_for_logged_action=model_rewards_for_logged_action,
            model_values=model_values,
            model_values_for_logged_action=model_values_for_logged_action,
            model_metrics_values=model_metrics_values,
            model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
            model_propensities=model_propensities,
            logged_metrics=metrics,
            model_metrics=model_metrics,
            model_metrics_for_logged_action=model_metrics_for_logged_action,
            # Will compute later
            logged_values=None,
            logged_metrics_values=None,
            possible_actions_mask=possible_actions_mask,
            optimal_q_values=optimal_q_values,
            eval_action_idxs=eval_action_idxs,
        )
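Both this example and the next convert Q-values into model_propensities via masked_softmax(q_values, possible_actions_mask, trainer.rl_temperature). The library's own implementation is not reproduced here; the sketch below (masked_softmax_sketch is a made-up name) only illustrates behavior consistent with these call sites, assuming masked actions should receive zero probability:

import torch
import torch.nn.functional as F

def masked_softmax_sketch(x: torch.Tensor, mask: torch.Tensor, temperature: float) -> torch.Tensor:
    # Scale logits by the temperature, then push masked entries to -inf so
    # they contribute zero probability after the softmax.
    scores = (x / temperature).masked_fill(mask == 0, float("-inf"))
    probs = F.softmax(scores, dim=1)
    # A row that is entirely masked softmaxes to NaN; map it to all zeros here
    # (an assumption about how such rows should be treated).
    return torch.nan_to_num(probs, nan=0.0)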
Example #16
0
    @classmethod
    def create_from_tensors(
        cls,
        trainer: RLTrainer,
        mdp_ids: np.ndarray,
        sequence_numbers: torch.Tensor,
        states: torch.Tensor,
        actions: torch.Tensor,
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_state_concat: Optional[torch.Tensor],
        possible_actions_mask: torch.Tensor,
        metrics: Optional[torch.Tensor] = None,
    ):
        with torch.no_grad():
            # Switch to evaluation mode for the network
            old_q_train_state = trainer.q_network.training
            old_reward_train_state = trainer.reward_network.training
            trainer.q_network.train(False)
            trainer.reward_network.train(False)

            if possible_actions_state_concat is not None:
                state_action_pairs = torch.cat((states, actions), dim=1)

                # Parametric actions
                model_values = trainer.q_network(possible_actions_state_concat)
                assert (
                    model_values.shape[0] * model_values.shape[1]
                    == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
                ), (
                    "Invalid shapes: "
                    + str(model_values.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_values = model_values.reshape(possible_actions_mask.shape)

                model_rewards = trainer.reward_network(possible_actions_state_concat)
                assert (
                    model_rewards.shape[0] * model_rewards.shape[1]
                    == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
                ), (
                    "Invalid shapes: "
                    + str(model_rewards.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_rewards = model_rewards.reshape(possible_actions_mask.shape)

                model_values_for_logged_action = trainer.q_network(state_action_pairs)
                model_rewards_for_logged_action = trainer.reward_network(
                    state_action_pairs
                )

                action_mask = (
                    torch.abs(model_values - model_values_for_logged_action) < 1e-3
                ).float()

                model_metrics = None
                model_metrics_for_logged_action = None
                model_metrics_values = None
                model_metrics_values_for_logged_action = None
            else:
                action_mask = actions.float()

                # Switch to evaluation mode for the network
                old_q_cpe_train_state = trainer.q_network_cpe.training
                trainer.q_network_cpe.train(False)

                # Discrete actions
                rewards = trainer.boost_rewards(rewards, actions)
                model_values = trainer.get_detached_q_values(states)[0]
                assert model_values.shape == actions.shape, (
                    "Invalid shape: "
                    + str(model_values.shape)
                    + " != "
                    + str(actions.shape)
                )
                assert model_values.shape == possible_actions_mask.shape, (
                    "Invalid shape: "
                    + str(model_values.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_values_for_logged_action = torch.sum(
                    model_values * action_mask, dim=1, keepdim=True
                )

                rewards_and_metric_rewards = trainer.reward_network(states)

                num_actions = trainer.num_actions

                model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
                assert model_rewards.shape == actions.shape, (
                    "Invalid shape: "
                    + str(model_rewards.shape)
                    + " != "
                    + str(actions.shape)
                )
                model_rewards_for_logged_action = torch.sum(
                    model_rewards * action_mask, dim=1, keepdim=True
                )

                model_metrics = rewards_and_metric_rewards[:, num_actions:]

                assert model_metrics.shape[1] % num_actions == 0, (
                    "Invalid metrics shape: "
                    + str(model_metrics.shape)
                    + " "
                    + str(num_actions)
                )
                num_metrics = model_metrics.shape[1] // num_actions

                if num_metrics == 0:
                    model_metrics_values = None
                    model_metrics_for_logged_action = None
                    model_metrics_values_for_logged_action = None
                else:
                    model_metrics_values = trainer.q_network_cpe(states)[
                        :, num_actions:
                    ]
                    assert model_metrics_values.shape[1] == num_actions * num_metrics, (
                        "Invalid shape: "
                        + str(model_metrics_values.shape[1])
                        + " != "
                        + str(actions.shape[1] * num_metrics)
                    )

                    model_metrics_for_logged_action_list = []
                    model_metrics_values_for_logged_action_list = []
                    for metric_index in range(num_metrics):
                        metric_start = metric_index * num_actions
                        metric_end = (metric_index + 1) * num_actions
                        model_metrics_for_logged_action_list.append(
                            torch.sum(
                                model_metrics[:, metric_start:metric_end] * action_mask,
                                dim=1,
                                keepdim=True,
                            )
                        )

                        model_metrics_values_for_logged_action_list.append(
                            torch.sum(
                                model_metrics_values[:, metric_start:metric_end]
                                * action_mask,
                                dim=1,
                                keepdim=True,
                            )
                        )
                    model_metrics_for_logged_action = torch.cat(
                        model_metrics_for_logged_action_list, dim=1
                    )
                    model_metrics_values_for_logged_action = torch.cat(
                        model_metrics_values_for_logged_action_list, dim=1
                    )

                # Switch back to the old mode
                trainer.q_network_cpe.train(old_q_cpe_train_state)

            # Switch back to the old mode
            trainer.q_network.train(old_q_train_state)
            trainer.reward_network.train(old_reward_train_state)

            return cls(
                mdp_id=mdp_ids,
                sequence_number=sequence_numbers,
                logged_propensities=propensities,
                logged_rewards=rewards,
                action_mask=action_mask,
                model_rewards=model_rewards,
                model_rewards_for_logged_action=model_rewards_for_logged_action,
                model_values=model_values,
                model_values_for_logged_action=model_values_for_logged_action,
                model_metrics_values=model_metrics_values,
                model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
                model_propensities=masked_softmax(
                    model_values, possible_actions_mask, trainer.rl_temperature
                ),
                logged_metrics=metrics,
                model_metrics=model_metrics,
                model_metrics_for_logged_action=model_metrics_for_logged_action,
                # Will compute later
                logged_values=None,
                logged_metrics_values=None,
                possible_actions_state_concat=possible_actions_state_concat,
                possible_actions_mask=possible_actions_mask,
            )
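Both create_from_tensors variants bracket their forward passes by saving each network's .training flag, switching the networks to eval mode, and restoring the saved flags afterwards. A sketch of the same bookkeeping as a reusable context manager (eval_mode is a hypothetical helper, not part of the library):

import contextlib
import torch.nn as nn

@contextlib.contextmanager
def eval_mode(*networks: nn.Module):
    # Remember each network's current training flag, force eval mode,
    # and restore the original flags on exit.
    old_states = [net.training for net in networks]
    try:
        for net in networks:
            net.train(False)
        yield
    finally:
        for net, old in zip(networks, old_states):
            net.train(old)

# Usage sketch:
# with eval_mode(trainer.q_network, trainer.reward_network):
#     model_values = trainer.q_network(possible_actions_state_concat)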