def get_exploration_action(
    self,
    action_distribution: "ActionDistribution",
    timestep: "Union[int, TensorType]",
    explore: bool = True,
):
    """Return an action (plus all-zero log-probs) for the given distribution.

    The incoming distribution is re-instantiated with this exploration's
    temperature before sampling, so the temperature setting actually
    takes effect. Log-probabilities are not computed here; a zero tensor
    of batch size is returned in their place.

    Args:
        action_distribution: The (slate q-value) distribution to draw from.
        timestep: Current sampling timestep; stored on `self.last_timestep`.
        explore: If True, draw a stochastic sample; otherwise take the
            argmax ("deterministic sample") over the (q-value) logits.

    Returns:
        Tuple of (action tensor, zero log-prob tensor of shape [batch]).
    """
    assert (
        self.framework == "torch"
    ), "ERROR: SlateSoftQ only supports torch so far!"

    dist_class = type(action_distribution)
    # Rebuild the distribution so our temperature setting is applied.
    action_distribution = dist_class(
        action_distribution.inputs, self.model, temperature=self.temperature
    )

    # Dummy (all-zero) log-probs, one per batch row.
    num_rows = action_distribution.inputs.size()[0]
    logp = torch.zeros(num_rows, dtype=torch.float)

    self.last_timestep = timestep

    # Stochastic sample over the (q-value) logits when exploring,
    # otherwise the deterministic argmax "sample".
    action = (
        action_distribution.sample()
        if explore
        else action_distribution.deterministic_sample()
    )
    return action, logp
def _get_torch_exploration_action(self, action_dist: ActionDistribution, timestep: Union[TensorType, int], explore: Union[TensorType, bool]): # Set last timestep or (if not given) increase by one. self.last_timestep = timestep if timestep is not None else \ self.last_timestep + 1 # Apply exploration. if explore: # Random exploration phase. if self.last_timestep < self.random_timesteps: action, logp = \ self.random_exploration.get_torch_exploration_action( action_dist, explore=True) # Take a sample from our distribution. else: action = action_dist.sample() logp = action_dist.sampled_action_logp() # No exploration -> Return deterministic actions. else: action = action_dist.deterministic_sample() logp = torch.zeros_like(action_dist.sampled_action_logp()) return action, logp
def get_exploration_action(
    self,
    *,
    action_distribution: "ActionDistribution",
    timestep: int,
    explore: bool = True,
) -> "Tuple[torch.Tensor, Optional[torch.Tensor]]":
    """Sample an action, deferring to the parent during pure exploration.

    For the first `self._pure_exploration_steps` timesteps (while
    exploring), the parent class's exploration is used; afterwards a
    plain stochastic sample is drawn. With `explore=False`, the
    deterministic (argmax) sample is returned.

    NOTE(review): two branches return a bare sample although the
    annotation declares a tuple — kept as-is since callers may rely on
    this; confirm against the parent class's return convention.
    """
    if not explore:
        # Greedy path: no exploration at all.
        return action_distribution.deterministic_sample()
    if timestep < self._pure_exploration_steps:
        # Pure-exploration phase: delegate to the parent exploration.
        return super().get_exploration_action(
            action_distribution=action_distribution,
            timestep=timestep,
            explore=explore,
        )
    return action_distribution.sample()