def policy_given_q_values(
    q_scores: torch.Tensor,
    action_names: List[str],
    softmax_temperature: float,
    possible_actions_presence: Optional[torch.Tensor] = None,
) -> DqnPolicyActionSet:
    """Choose a greedy and a softmax-sampled action from one row of Q-values.

    Args:
        q_scores: Tensor of shape ``(1, num_actions)`` (asserted below). The
            caller's tensor is not modified.
        action_names: Names indexed by action position; used to label the two
            chosen indices in the result.
        softmax_temperature: Forwarded unchanged to ``masked_softmax``.
        possible_actions_presence: Mask with one entry per action; it is
            reshaped to ``(1, num_actions)``. An entry of 1 leaves the
            Q-value untouched, an entry of 0 lowers it by 1e10. Defaults to
            all ones (every action allowed).

    Returns:
        A ``DqnPolicyActionSet`` holding the argmax index (``greedy``), an
        index drawn with ``np.random.choice`` (``softmax``), and the matching
        entries of ``action_names``.

    Raises:
        AssertionError: If ``q_scores`` is not ``(1, num_actions)`` or the
            mask does not have the same shape after reshaping.
    """
    assert q_scores.shape[0] == 1 and len(q_scores.shape) == 2

    if possible_actions_presence is None:
        possible_actions_presence = torch.ones_like(q_scores)
    possible_actions_presence = possible_actions_presence.reshape(1, -1)
    assert possible_actions_presence.shape == q_scores.shape

    # Set impossible actions so low that they can't be picked. This is done
    # out-of-place so the tensor passed in by the caller keeps its values.
    q_scores = q_scores - (1.0 - possible_actions_presence) * 1e10

    # masked_softmax is defined elsewhere; presumably it returns one row of
    # sampling probabilities per input row -- only row 0 is used here.
    q_scores_softmax = (
        masked_softmax(q_scores, possible_actions_presence, softmax_temperature)
        .detach()
        .numpy()[0]
    )

    # Degenerate distribution (NaNs, or no entry reaching 1e-3): fall back to
    # a uniform distribution.
    # NOTE(review): the fallback gives weight to every action, including ones
    # masked out by possible_actions_presence -- confirm this is intended.
    if np.isnan(q_scores_softmax).any() or np.max(q_scores_softmax) < 1e-3:
        q_scores_softmax[:] = 1.0 / q_scores_softmax.shape[0]

    greedy_act_idx = int(torch.argmax(q_scores))
    softmax_act_idx = int(np.random.choice(q_scores.size()[1], p=q_scores_softmax))

    return DqnPolicyActionSet(
        greedy=greedy_act_idx,
        softmax=softmax_act_idx,
        greedy_act_name=action_names[greedy_act_idx],
        softmax_act_name=action_names[softmax_act_idx],
    )
def policy_given_q_values(
    q_scores: torch.Tensor,
    softmax_temperature: float,
    possible_actions_presence: torch.Tensor,
) -> DqnPolicyActionSet:
    """Choose a greedy and a softmax-sampled action from one row of Q-values.

    Args:
        q_scores: Tensor of shape ``(1, num_actions)`` (asserted below). The
            caller's tensor is not modified.
        softmax_temperature: Forwarded unchanged to ``masked_softmax``.
        possible_actions_presence: Mask with one entry per action; it is
            reshaped to ``(1, num_actions)``. An entry of 1 leaves the
            Q-value untouched, an entry of 0 lowers it by 1e10.

    Returns:
        A ``DqnPolicyActionSet`` whose ``greedy`` field is the argmax index
        of the masked Q-values and whose ``softmax`` field is an index drawn
        with ``np.random.choice``.

    Raises:
        AssertionError: If ``q_scores`` is not ``(1, num_actions)`` or the
            mask does not have the same shape after reshaping.
    """
    assert q_scores.shape[0] == 1 and len(q_scores.shape) == 2

    possible_actions_presence = possible_actions_presence.reshape(1, -1)
    assert possible_actions_presence.shape == q_scores.shape

    # Set impossible actions so low that they can't be picked. This is done
    # out-of-place so the tensor passed in by the caller keeps its values.
    q_scores = q_scores - (1.0 - possible_actions_presence) * 1e10

    # masked_softmax is defined elsewhere; presumably it returns one row of
    # sampling probabilities per input row -- only row 0 is used here.
    q_scores_softmax_numpy = (
        masked_softmax(
            q_scores.reshape(1, -1), possible_actions_presence, softmax_temperature
        )
        .detach()
        .numpy()[0]
    )

    # Degenerate distribution (NaNs, or no entry reaching 1e-3): fall back to
    # a uniform distribution.
    # NOTE(review): the fallback gives weight to every action, including ones
    # masked out by possible_actions_presence -- confirm this is intended.
    if (
        np.isnan(q_scores_softmax_numpy).any()
        or np.max(q_scores_softmax_numpy) < 1e-3
    ):
        q_scores_softmax_numpy[:] = 1.0 / q_scores_softmax_numpy.shape[0]

    return DqnPolicyActionSet(
        greedy=int(torch.argmax(q_scores)),
        softmax=int(np.random.choice(q_scores.size()[1], p=q_scores_softmax_numpy)),
    )
def policy(
    self,
    states: torch.Tensor,
    possible_actions_with_presence: Tuple[torch.Tensor, torch.Tensor],
) -> DqnPolicyActionSet:
    """Choose a greedy and a softmax-sampled action for a single state.

    Args:
        states: Tensor whose first dimension must be 1 (asserted below).
        possible_actions_with_presence: Pair ``(possible_actions,
            possible_actions_presence)``. ``possible_actions`` must have
            ``self.action_dim`` columns and as many rows as the presence
            tensor (both asserted below).

    Returns:
        A ``DqnPolicyActionSet`` whose ``greedy`` field is the argmax index
        of the masked Q-values and whose ``softmax`` field is an index drawn
        with ``np.random.choice``.

    Raises:
        AssertionError: If any of the shape checks below fails.
    """
    possible_actions, possible_actions_presence = possible_actions_with_presence

    assert states.size()[0] == 1
    assert possible_actions.size()[1] == self.action_dim
    assert possible_actions.size()[0] == possible_actions_presence.size()[0]

    q_scores = self.predict(states, possible_actions)

    # set impossible actions so low that they can't be picked
    q_scores -= (
        1.0 - possible_actions_presence.reshape(1, -1)  # type: ignore
    ) * 1e10

    # Detach before converting: Tensor.numpy() raises on a tensor that
    # requires grad. masked_softmax is defined elsewhere; presumably it
    # returns one row of sampling probabilities -- only row 0 is used here.
    q_scores_softmax_numpy = (
        masked_softmax(
            q_scores.reshape(1, -1),
            possible_actions_presence.reshape(1, -1),
            self.trainer.rl_temperature,
        )
        .detach()
        .numpy()[0]
    )

    # Degenerate distribution (NaNs, or no entry reaching 1e-3): fall back to
    # a uniform distribution.
    # NOTE(review): the fallback gives weight to every action, including ones
    # masked out by possible_actions_presence -- confirm this is intended.
    if (
        np.isnan(q_scores_softmax_numpy).any()
        or np.max(q_scores_softmax_numpy) < 1e-3
    ):
        q_scores_softmax_numpy[:] = 1.0 / q_scores_softmax_numpy.shape[0]

    return DqnPolicyActionSet(
        greedy=int(torch.argmax(q_scores)),
        softmax=int(np.random.choice(q_scores.size()[1], p=q_scores_softmax_numpy)),
    )
def policy(
    self, state: torch.Tensor, possible_actions_presence: torch.Tensor
) -> DqnPolicyActionSet:
    """Choose a greedy and a softmax-sampled action for a single state.

    Args:
        state: Tensor whose first dimension must be 1 (asserted below).
        possible_actions_presence: Mask combined element-wise with the
            Q-values returned by ``self.predict``. An entry of 1 leaves the
            Q-value untouched, an entry of 0 lowers it by 1e10.

    Returns:
        A ``DqnPolicyActionSet`` whose ``greedy`` field is the argmax index
        of the masked Q-values and whose ``softmax`` field is an index drawn
        with ``np.random.choice``.

    Raises:
        AssertionError: If more than one state is passed, or the predicted
            Q-values do not have a leading dimension of 1.
    """
    assert state.size()[0] == 1, "Only pass in one state when getting a policy"

    q_scores = self.predict(state)
    assert q_scores.shape[0] == 1

    # set impossible actions so low that they can't be picked
    q_scores -= (1.0 - possible_actions_presence) * 1e10  # type: ignore

    # Detach before converting: Tensor.numpy() raises on a tensor that
    # requires grad. masked_softmax is defined elsewhere; presumably it
    # returns one row of sampling probabilities -- only row 0 is used here.
    q_scores_softmax = (
        masked_softmax(
            q_scores, possible_actions_presence, self.trainer.rl_temperature
        )
        .detach()
        .numpy()[0]
    )

    # Degenerate distribution (NaNs, or no entry reaching 1e-3): fall back to
    # a uniform distribution.
    # NOTE(review): the fallback gives weight to every action, including ones
    # masked out by possible_actions_presence -- confirm this is intended.
    if np.isnan(q_scores_softmax).any() or np.max(q_scores_softmax) < 1e-3:
        q_scores_softmax[:] = 1.0 / q_scores_softmax.shape[0]

    return DqnPolicyActionSet(
        greedy=int(torch.argmax(q_scores)),
        softmax=int(np.random.choice(q_scores.size()[1], p=q_scores_softmax)),
    )
def policy(
    self, state: torch.Tensor, possible_actions_presence: torch.Tensor
) -> DqnPolicyActionSet:
    """Return the greedy and the softmax-sampled action index for one state.

    ``state`` must have a leading dimension of 1 and
    ``self.softmax_temperature`` must be set. The model is run with an
    all-ones feature-presence tensor shaped like ``state``; the second
    element of its output is taken as the Q-values.
    """
    assert state.size()[0] == 1, "Only pass in one state when getting a policy"
    assert (
        self.softmax_temperature is not None
    ), "Please set the softmax temperature before calling policy()"

    all_features_present = torch.ones_like(state)
    _unused, q_scores = self.model((state, all_features_present))
    assert q_scores.shape[0] == 1

    # Push the Q-value of every absent action down by 1e10 so it cannot be
    # the argmax.
    masking_penalty = (1.0 - possible_actions_presence) * 1e10
    q_scores -= masking_penalty  # type: ignore

    softmax_tensor = masked_softmax(
        q_scores, possible_actions_presence, self.softmax_temperature
    )
    action_probs = softmax_tensor.detach().numpy()[0]

    # NaNs or a maximum below 1e-3: replace with a uniform distribution.
    if np.isnan(action_probs).any() or np.max(action_probs) < 1e-3:
        action_probs[:] = 1.0 / action_probs.shape[0]

    greedy_index = int(torch.argmax(q_scores))
    sampled_index = int(np.random.choice(q_scores.size()[1], p=action_probs))
    return DqnPolicyActionSet(greedy=greedy_index, softmax=sampled_index)
def policy(
    self,
    tiled_states: torch.Tensor,
    possible_actions_with_presence: Tuple[torch.Tensor, torch.Tensor],
):
    """Return the greedy and the softmax-sampled action index.

    ``possible_actions_with_presence`` is a pair ``(possible_actions,
    possible_actions_presence)``. ``tiled_states``, ``possible_actions`` and
    the presence tensor must agree on their first dimension, and
    ``self.softmax_temperature`` must be set. The second element of the model
    output is flattened into a single row of Q-values.
    """
    possible_actions, possible_actions_presence = possible_actions_with_presence

    num_rows = possible_actions.size()[0]
    assert tiled_states.size()[0] == num_rows
    assert num_rows == possible_actions_presence.size()[0]
    assert (
        self.softmax_temperature is not None
    ), "Please set the softmax temperature before calling policy()"

    all_features_present = torch.ones_like(tiled_states)
    _unused, raw_scores = self.model(
        (tiled_states, all_features_present), possible_actions_with_presence
    )
    q_scores = raw_scores.reshape(1, -1)
    presence_row = possible_actions_presence.reshape(1, -1)

    # Push the Q-value of every absent action down by 1e10 so it cannot be
    # the argmax.
    q_scores -= (1.0 - presence_row) * 1e10  # type: ignore

    softmax_tensor = masked_softmax(
        q_scores.reshape(1, -1), presence_row, self.softmax_temperature
    )
    action_probs = softmax_tensor.detach().numpy()[0]

    # NaNs or a maximum below 1e-3: replace with a uniform distribution.
    if np.isnan(action_probs).any() or np.max(action_probs) < 1e-3:
        action_probs[:] = 1.0 / action_probs.shape[0]

    greedy_index = int(torch.argmax(q_scores))
    sampled_index = int(np.random.choice(q_scores.size()[1], p=action_probs))
    return DqnPolicyActionSet(greedy=greedy_index, softmax=sampled_index)