コード例 #1
0
    def policy_given_q_values(
        q_scores: torch.Tensor,
        action_names: List[str],
        softmax_temperature: float,
        possible_actions_presence: Optional[torch.Tensor] = None,
    ) -> DqnPolicyActionSet:
        """Pick a greedy action and a softmax-sampled action from Q-values.

        Args:
            q_scores: Q-values for a single state; must have shape
                (1, num_actions). The caller's tensor is not modified.
            action_names: Action names indexed by action position; used to fill
                the ``greedy_act_name`` / ``softmax_act_name`` fields.
            softmax_temperature: Temperature forwarded to ``masked_softmax``.
            possible_actions_presence: Optional mask with ``num_actions``
                elements, reshaped to (1, num_actions). An entry of 0 marks an
                impossible action (presumably the mask is 0/1 valued). Defaults
                to all ones, i.e. every action is possible.

        Returns:
            DqnPolicyActionSet holding the argmax index of the masked Q-values,
            an index sampled from the softmax probabilities, and their names.
        """
        assert q_scores.shape[0] == 1 and len(q_scores.shape) == 2

        if possible_actions_presence is None:
            possible_actions_presence = torch.ones_like(q_scores)
        possible_actions_presence = possible_actions_presence.reshape(1, -1)
        assert possible_actions_presence.shape == q_scores.shape

        # Set impossible actions so low that they can't win the argmax. This is
        # deliberately out-of-place: ``-=`` would overwrite the caller's tensor
        # (and raises if that tensor is a leaf that requires grad).
        q_scores = q_scores - (1.0 - possible_actions_presence) * 1e10  # type: ignore

        q_scores_softmax = (
            masked_softmax(q_scores, possible_actions_presence, softmax_temperature)
            .detach()
            .cpu()  # .numpy() only works on CPU tensors
            .numpy()[0]
        )
        # Degenerate distribution (NaN, or every probability below 1e-3):
        # fall back to a uniform distribution so sampling still works.
        # NOTE(review): the fallback is uniform over *all* actions, including
        # masked-out ones — confirm that is intended.
        if np.isnan(q_scores_softmax).any() or np.max(q_scores_softmax) < 1e-3:
            q_scores_softmax[:] = 1.0 / q_scores_softmax.shape[0]
        greedy_act_idx = int(torch.argmax(q_scores))
        softmax_act_idx = int(np.random.choice(q_scores.size()[1], p=q_scores_softmax))

        return DqnPolicyActionSet(
            greedy=greedy_act_idx,
            softmax=softmax_act_idx,
            greedy_act_name=action_names[greedy_act_idx],
            softmax_act_name=action_names[softmax_act_idx],
        )
コード例 #2
0
    def policy_given_q_values(
        q_scores: torch.Tensor,
        softmax_temperature: float,
        possible_actions_presence: torch.Tensor,
    ) -> DqnPolicyActionSet:
        """Pick a greedy action and a softmax-sampled action from Q-values.

        Args:
            q_scores: Q-values for a single state; must have shape
                (1, num_actions). The caller's tensor is not modified.
            softmax_temperature: Temperature forwarded to ``masked_softmax``.
            possible_actions_presence: Mask with ``num_actions`` elements,
                reshaped to (1, num_actions). An entry of 0 marks an impossible
                action (presumably the mask is 0/1 valued).

        Returns:
            DqnPolicyActionSet with ``greedy`` (argmax index of the masked
            Q-values) and ``softmax`` (index sampled from the softmax
            probabilities, or uniformly when those are degenerate).
        """
        assert q_scores.shape[0] == 1 and len(q_scores.shape) == 2
        possible_actions_presence = possible_actions_presence.reshape(1, -1)
        assert possible_actions_presence.shape == q_scores.shape

        # Set impossible actions so low that they can't win the argmax. This is
        # deliberately out-of-place: ``-=`` would overwrite the caller's tensor
        # (and raises if that tensor is a leaf that requires grad).
        q_scores = q_scores - (1.0 - possible_actions_presence) * 1e10

        # q_scores is already (1, num_actions) per the asserts above, so no
        # extra reshape is needed before the softmax.
        q_scores_softmax_numpy = (
            masked_softmax(q_scores, possible_actions_presence, softmax_temperature)
            .detach()
            .cpu()  # .numpy() only works on CPU tensors
            .numpy()[0]
        )
        # Degenerate distribution (NaN, or every probability below 1e-3):
        # fall back to a uniform distribution so sampling still works.
        # NOTE(review): the fallback is uniform over *all* actions, including
        # masked-out ones — confirm that is intended.
        if (
            np.isnan(q_scores_softmax_numpy).any()
            or np.max(q_scores_softmax_numpy) < 1e-3
        ):
            q_scores_softmax_numpy[:] = 1.0 / q_scores_softmax_numpy.shape[0]

        return DqnPolicyActionSet(
            greedy=int(torch.argmax(q_scores)),
            softmax=int(np.random.choice(q_scores.size()[1], p=q_scores_softmax_numpy)),
        )
コード例 #3
0
    def policy(
        self,
        states: torch.Tensor,
        possible_actions_with_presence: Tuple[torch.Tensor, torch.Tensor],
    ):
        """Choose greedy and softmax-sampled actions among candidate actions.

        Args:
            states: State features; the first dimension must be 1.
            possible_actions_with_presence: ``(possible_actions, presence)``
                pair. ``possible_actions`` must have ``self.action_dim`` columns
                and as many rows as ``presence`` has entries (presumably one row
                per candidate action). An entry of 0 in ``presence`` marks an
                impossible action (presumably 0/1 valued).

        Returns:
            DqnPolicyActionSet with ``greedy`` (argmax index of the masked
            scores) and ``softmax`` (index sampled from the softmax
            probabilities, or uniformly when those are degenerate).
        """
        possible_actions, possible_actions_presence = possible_actions_with_presence
        assert states.size()[0] == 1
        assert possible_actions.size()[1] == self.action_dim
        assert possible_actions.size()[0] == possible_actions_presence.size()[0]

        presence_row = possible_actions_presence.reshape(1, -1)
        # Normalize to a single row of per-action scores aligned with
        # presence_row (no-op when predict already returns (1, num_actions)).
        q_scores = self.predict(states, possible_actions).reshape(1, -1)

        # Set impossible actions so low that they can't win the argmax;
        # out-of-place so the tensor returned by predict is left untouched.
        q_scores = q_scores - (1.0 - presence_row) * 1e10  # type: ignore

        # detach() is required: Tensor.numpy() raises on tensors that require
        # grad. cpu() makes this work for CUDA tensors too.
        q_scores_softmax_numpy = (
            masked_softmax(q_scores, presence_row, self.trainer.rl_temperature)
            .detach()
            .cpu()
            .numpy()[0]
        )
        # Degenerate distribution (NaN, or every probability below 1e-3):
        # fall back to a uniform distribution so sampling still works.
        # NOTE(review): the fallback is uniform over *all* actions, including
        # masked-out ones — confirm that is intended.
        if (
            np.isnan(q_scores_softmax_numpy).any()
            or np.max(q_scores_softmax_numpy) < 1e-3
        ):
            q_scores_softmax_numpy[:] = 1.0 / q_scores_softmax_numpy.shape[0]
        return DqnPolicyActionSet(
            greedy=int(torch.argmax(q_scores)),
            softmax=int(
                np.random.choice(q_scores.size()[1], p=q_scores_softmax_numpy)
            ),
        )
コード例 #4
0
    def policy(self, state: torch.Tensor,
               possible_actions_presence: torch.Tensor) -> DqnPolicyActionSet:
        """Choose greedy and softmax-sampled actions for a single state.

        Args:
            state: State features; the first dimension must be 1.
            possible_actions_presence: Mask broadcastable against the Q-values
                returned by ``self.predict``. An entry of 0 marks an impossible
                action (presumably 0/1 valued).

        Returns:
            DqnPolicyActionSet with ``greedy`` (argmax index of the masked
            Q-values) and ``softmax`` (index sampled from the softmax
            probabilities, or uniformly when those are degenerate).
        """
        assert state.size(
        )[0] == 1, "Only pass in one state when getting a policy"
        q_scores = self.predict(state)
        assert q_scores.shape[0] == 1

        # Set impossible actions so low that they can't win the argmax;
        # out-of-place so the tensor returned by predict is left untouched.
        q_scores = q_scores - (1.0 - possible_actions_presence) * 1e10  # type: ignore

        # detach() is required: Tensor.numpy() raises on tensors that require
        # grad. cpu() makes this work for CUDA tensors too.
        q_scores_softmax = (
            masked_softmax(
                q_scores, possible_actions_presence, self.trainer.rl_temperature
            )
            .detach()
            .cpu()
            .numpy()[0]
        )
        # Degenerate distribution (NaN, or every probability below 1e-3):
        # fall back to a uniform distribution so sampling still works.
        # NOTE(review): the fallback is uniform over *all* actions, including
        # masked-out ones — confirm that is intended.
        if np.isnan(q_scores_softmax).any() or np.max(q_scores_softmax) < 1e-3:
            q_scores_softmax[:] = 1.0 / q_scores_softmax.shape[0]
        return DqnPolicyActionSet(
            greedy=int(torch.argmax(q_scores)),
            softmax=int(
                np.random.choice(q_scores.size()[1], p=q_scores_softmax)),
        )
コード例 #5
0
    def policy(self, state: torch.Tensor,
               possible_actions_presence: torch.Tensor) -> DqnPolicyActionSet:
        """Return the greedy action and a softmax-sampled action for one state.

        Args:
            state: State features; the first dimension must be 1.
            possible_actions_presence: Mask broadcastable against the model's
                Q-values; an entry of 0 marks an impossible action (presumably
                0/1 valued).

        Returns:
            DqnPolicyActionSet holding the argmax index of the masked Q-values
            and an index drawn from the softmax probabilities (uniform if those
            are degenerate).
        """
        assert state.size()[0] == 1, "Only pass in one state when getting a policy"
        assert (
            self.softmax_temperature is not None
        ), "Please set the softmax temperature before calling policy()"

        # Treat every state feature as present when querying the model.
        feature_presence = torch.ones_like(state)
        _, q_values = self.model((state, feature_presence))
        assert q_values.shape[0] == 1

        # Sink impossible actions 1e10 below their score so they lose the argmax.
        penalty = (1.0 - possible_actions_presence) * 1e10
        q_values -= penalty  # type: ignore

        probs_tensor = masked_softmax(
            q_values, possible_actions_presence, self.softmax_temperature
        )
        probs = probs_tensor.detach().numpy()[0]
        degenerate = np.isnan(probs).any() or np.max(probs) < 1e-3
        if degenerate:
            # Uniform over all actions so np.random.choice still gets a valid p.
            probs[:] = 1.0 / probs.shape[0]

        greedy_idx = int(torch.argmax(q_values))
        sampled_idx = int(np.random.choice(q_values.size()[1], p=probs))
        return DqnPolicyActionSet(greedy=greedy_idx, softmax=sampled_idx)
コード例 #6
0
    def policy(
        self,
        tiled_states: torch.Tensor,
        possible_actions_with_presence: Tuple[torch.Tensor, torch.Tensor],
    ):
        """Score each candidate action for one state; pick greedy and sampled actions.

        Args:
            tiled_states: State features whose first dimension must equal the
                number of rows in ``possible_actions``.
            possible_actions_with_presence: ``(possible_actions, presence)``
                pair, passed unchanged to ``self.model``. ``presence`` must have
                as many entries as ``possible_actions`` has rows; an entry of 0
                marks an impossible action (presumably 0/1 valued).

        Returns:
            DqnPolicyActionSet holding the argmax index of the masked scores and
            an index drawn from the softmax probabilities (uniform if those are
            degenerate).
        """
        actions, presence = possible_actions_with_presence
        num_candidates = actions.size()[0]
        assert tiled_states.size()[0] == num_candidates
        assert num_candidates == presence.size()[0]
        assert (
            self.softmax_temperature is not None
        ), "Please set the softmax temperature before calling policy()"

        # Treat every state feature as present when querying the model.
        feature_presence = torch.ones_like(tiled_states)
        _, raw_scores = self.model(
            (tiled_states, feature_presence), possible_actions_with_presence
        )

        # One row of scores and one row of presence flags, aligned by action.
        presence_row = presence.reshape(1, -1)
        scores = raw_scores.reshape(1, -1)
        # Sink impossible actions 1e10 below their score so they lose the argmax.
        scores -= (1.0 - presence_row) * 1e10  # type: ignore

        probs = (
            masked_softmax(scores, presence_row, self.softmax_temperature)
            .detach()
            .numpy()[0]
        )
        if np.isnan(probs).any() or np.max(probs) < 1e-3:
            # Uniform over all actions so np.random.choice still gets a valid p.
            probs[:] = 1.0 / probs.shape[0]

        greedy_choice = int(torch.argmax(scores))
        sampled_choice = int(np.random.choice(scores.size()[1], p=probs))
        return DqnPolicyActionSet(greedy=greedy_choice, softmax=sampled_choice)