Example #1
    def _get_torch_exploration_action(self, action_dist: ActionDistribution,
                                      explore: bool,
                                      timestep: Union[int, TensorType]):
        # Set last timestep or (if not given) increase by one.
        self.last_timestep = timestep if timestep is not None else \
            self.last_timestep + 1

        # Apply exploration.
        if explore:
            # Random exploration phase.
            if self.last_timestep < self.random_timesteps:
                action, _ = \
                    self.random_exploration.get_torch_exploration_action(
                        action_dist, explore=True)
            # Apply base-scaled and time-annealed scaled OU-noise to
            # deterministic actions.
            else:
                det_actions = action_dist.deterministic_sample()
                scale = self.scale_schedule(self.last_timestep)
                gaussian_sample = scale * torch.normal(
                    mean=torch.zeros(self.ou_state.size()), std=1.0) \
                    .to(self.device)
                ou_new = self.ou_theta * -self.ou_state + \
                    self.ou_sigma * gaussian_sample
                self.ou_state += ou_new
                high_m_low = torch.from_numpy(
                    self.action_space.high - self.action_space.low). \
                    to(self.device)
                high_m_low = torch.where(
                    torch.isinf(high_m_low),
                    torch.ones_like(high_m_low).to(self.device), high_m_low)
                noise = scale * self.ou_base_scale * self.ou_state * high_m_low

                action = torch.min(
                    torch.max(
                        det_actions + noise,
                        torch.tensor(self.action_space.low,
                                     dtype=torch.float32,
                                     device=self.device)),
                    torch.tensor(self.action_space.high,
                                 dtype=torch.float32,
                                 device=self.device))

        # No exploration -> Return deterministic actions.
        else:
            action = action_dist.deterministic_sample()

        # Logp=always zero.
        logp = torch.zeros((action.size()[0], ),
                           dtype=torch.float32,
                           device=self.device)

        return action, logp
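The OU branch above can be distilled into a small, self-contained sketch of the discretized Ornstein-Uhlenbeck recurrence it implements. This is an illustration only; theta, sigma, the base scale, and the action bounds below are made-up values, not RLlib defaults.

import torch

# Discretized OU recurrence, as in the exploration branch above:
#   ou_state += theta * (-ou_state) + sigma * N(0, 1)
theta, sigma, base_scale = 0.15, 0.2, 0.1   # illustrative values
action_dim = 3

ou_state = torch.zeros(action_dim)
det_actions = torch.zeros(action_dim)  # stand-in for action_dist.deterministic_sample()

for t in range(5):
    gaussian_sample = torch.normal(mean=torch.zeros(action_dim), std=1.0)
    ou_state += theta * -ou_state + sigma * gaussian_sample
    # A time-annealed `scale` from a schedule would multiply in here as well.
    noise = base_scale * ou_state
    noisy_action = torch.clamp(det_actions + noise, min=-1.0, max=1.0)
    print(t, noisy_action)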
Example #2
    def _get_torch_exploration_action(self, action_dist: ActionDistribution,
                                      timestep: Union[TensorType, int],
                                      explore: Union[TensorType, bool]):
        # Set last timestep or (if not given) increase by one.
        self.last_timestep = timestep if timestep is not None else \
            self.last_timestep + 1

        # Apply exploration.
        if explore:
            # Random exploration phase.
            if self.last_timestep < self.random_timesteps:
                action, logp = \
                    self.random_exploration.get_torch_exploration_action(
                        action_dist, explore=True)
            # Take a sample from our distribution.
            else:
                action = action_dist.sample()
                logp = action_dist.sampled_action_logp()

        # No exploration -> Return deterministic actions.
        else:
            action = action_dist.deterministic_sample()
            logp = torch.zeros_like(action_dist.sampled_action_logp())

        return action, logp
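Example #2's explore/exploit split is the simplest of the group; the following sketch mimics it with a torch.distributions.Categorical standing in for RLlib's ActionDistribution (an assumption made purely for illustration).

import torch
from torch.distributions import Categorical

logits = torch.tensor([[1.0, 2.0, 0.5]])  # illustrative policy logits, batch size 1
dist = Categorical(logits=logits)

explore = True
if explore:
    action = dist.sample()                 # stochastic sample
    logp = dist.log_prob(action)           # its log-likelihood
else:
    action = torch.argmax(logits, dim=-1)  # deterministic "sample" (argmax)
    logp = torch.zeros_like(action, dtype=torch.float32)

print(action, logp)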
Example #3
 def extra_action_out_fn(
         policy: Policy, input_dict, state_batches, model,
         action_dist: ActionDistribution) -> Dict[str, TensorType]:
     action = action_dist.deterministic_sample()
     action_probs = torch.zeros_like(policy.q_values).long()
     action_probs[0][action[0]] = 1.0
     return {"q_values": policy.q_values, "action_probs": action_probs}
Example #4
    def get_exploration_action(
        self,
        action_distribution: ActionDistribution,
        timestep: Union[int, TensorType],
        explore: bool = True,
    ):
        assert (self.framework == "torch"
                ), "ERROR: SlateSoftQ only supports torch so far!"

        cls = type(action_distribution)

        # Re-create the action distribution with the correct temperature
        # applied.
        action_distribution = cls(action_distribution.inputs,
                                  self.model,
                                  temperature=self.temperature)
        batch_size = action_distribution.inputs.size()[0]
        action_logp = torch.zeros(batch_size, dtype=torch.float)

        self.last_timestep = timestep

        # Explore.
        if explore:
            # Return stochastic sample over (q-value) logits.
            action = action_distribution.sample()
        # Return the deterministic "sample" (argmax) over (q-value) logits.
        else:
            action = action_distribution.deterministic_sample()

        return action, action_logp
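Conceptually, re-creating the distribution with a temperature amounts to a temperature-scaled softmax over the Q-value logits. The sketch below illustrates that idea with made-up Q-values and temperature; it is not the SlateSoftQ implementation itself.

import torch

q_values = torch.tensor([[1.0, 3.0, 2.0]])  # illustrative per-action Q-values
temperature = 0.5                           # <1.0 sharpens, >1.0 flattens the distribution

probs = torch.softmax(q_values / temperature, dim=-1)

# Explore: sample from the temperature-scaled distribution.
stochastic_action = torch.multinomial(probs, num_samples=1)
# Exploit: deterministic "sample" (argmax) over the (q-value) logits.
greedy_action = torch.argmax(q_values, dim=-1)

print(probs, stochastic_action, greedy_action)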
Example #5
    def _get_tf_exploration_action_op(self, action_dist: ActionDistribution,
                                      explore: Union[bool, TensorType],
                                      timestep: Union[int, TensorType]):
        ts = timestep if timestep is not None else self.last_timestep
        scale = self.scale_schedule(ts)

        # The deterministic actions (if explore=False).
        deterministic_actions = action_dist.deterministic_sample()

        # Apply base-scaled and time-annealed scaled OU-noise to
        # deterministic actions.
        gaussian_sample = tf.random.normal(shape=[self.action_space.low.size],
                                           stddev=self.stddev)
        ou_new = self.ou_theta * -self.ou_state + \
            self.ou_sigma * gaussian_sample
        if self.framework in ["tf2", "tfe"]:
            self.ou_state.assign_add(ou_new)
            ou_state_new = self.ou_state
        else:
            ou_state_new = tf1.assign_add(self.ou_state, ou_new)
        high_m_low = self.action_space.high - self.action_space.low
        high_m_low = tf.where(tf.math.is_inf(high_m_low),
                              tf.ones_like(high_m_low), high_m_low)
        noise = scale * self.ou_base_scale * ou_state_new * high_m_low
        stochastic_actions = tf.clip_by_value(
            deterministic_actions + noise,
            self.action_space.low * tf.ones_like(deterministic_actions),
            self.action_space.high * tf.ones_like(deterministic_actions))

        # Stochastic actions could either be: random OR action + noise.
        random_actions, _ = \
            self.random_exploration.get_tf_exploration_action_op(
                action_dist, explore)
        exploration_actions = tf.cond(
            pred=tf.convert_to_tensor(ts < self.random_timesteps),
            true_fn=lambda: random_actions,
            false_fn=lambda: stochastic_actions)

        # Choose by `explore` (main exploration switch).
        action = tf.cond(pred=tf.constant(explore, dtype=tf.bool)
                         if isinstance(explore, bool) else explore,
                         true_fn=lambda: exploration_actions,
                         false_fn=lambda: deterministic_actions)
        # Logp=always zero.
        batch_size = tf.shape(deterministic_actions)[0]
        logp = tf.zeros(shape=(batch_size, ), dtype=tf.float32)

        # Increment `last_timestep` by 1 (or set to `timestep`).
        if self.framework in ["tf2", "tfe"]:
            if timestep is None:
                self.last_timestep.assign_add(1)
            else:
                self.last_timestep.assign(timestep)
            return action, logp
        else:
            assign_op = (tf1.assign_add(self.last_timestep, 1)
                         if timestep is None else tf1.assign(
                             self.last_timestep, timestep))
            with tf1.control_dependencies([assign_op, ou_state_new]):
                return action, logp
Example #6
    def _get_torch_exploration_action(self,
                                      action_distribution: ActionDistribution,
                                      explore: bool,
                                      timestep: Union[int, TensorType]):
        """Torch method to produce an epsilon exploration action.

        Args:
            action_distribution (ActionDistribution): The instantiated
                ActionDistribution object to work with when creating
                exploration actions.

        Returns:
            torch.Tensor: The exploration-action.
        """
        q_values = action_distribution.inputs
        self.last_timestep = timestep
        exploit_action = action_distribution.deterministic_sample()
        batch_size = q_values.size()[0]
        action_logp = torch.zeros(batch_size, dtype=torch.float)

        # Explore.
        if explore:
            # Get the current epsilon.
            epsilon = self.epsilon_schedule(self.last_timestep)
            if isinstance(action_distribution, TorchMultiActionDistribution):
                exploit_action = tree.flatten(exploit_action)
                for i in range(batch_size):
                    if random.random() < epsilon:
                        # TODO: (bcahlit) Mask out actions
                        random_action = tree.flatten(
                            self.action_space.sample())
                        for j in range(len(exploit_action)):
                            exploit_action[j][i] = torch.tensor(
                                random_action[j])
                exploit_action = tree.unflatten_as(
                    action_distribution.action_space_struct, exploit_action)

                return exploit_action, action_logp

            else:
                # Mask out actions whose Q-values are -inf so that we don't
                # even consider them for exploration.
                random_valid_action_logits = torch.where(
                    q_values <= FLOAT_MIN,
                    torch.ones_like(q_values) * 0.0, torch.ones_like(q_values))
                # A random action.
                random_actions = torch.squeeze(torch.multinomial(
                    random_valid_action_logits, 1),
                                               axis=1)

                # Pick either random or greedy.
                action = torch.where(
                    torch.empty(
                        (batch_size, )).uniform_().to(self.device) < epsilon,
                    random_actions, exploit_action)

                return action, action_logp
        # Return the deterministic "sample" (argmax) over the logits.
        else:
            return exploit_action, action_logp
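The non-multi-action branch of Example #6 reduces to plain epsilon-greedy over masked Q-values. Below is a self-contained sketch of that pattern; the -inf masking stands in for the FLOAT_MIN check, and epsilon and the Q-values are made up.

import torch

q_values = torch.tensor([[0.2, 1.5, float("-inf")],
                         [0.7, float("-inf"), 0.1]])  # -inf marks invalid actions
epsilon = 0.3
batch_size = q_values.size(0)

greedy_actions = torch.argmax(q_values, dim=-1)

# Give invalid actions zero sampling weight, valid actions uniform weight.
random_valid_logits = torch.where(
    torch.isinf(q_values), torch.zeros_like(q_values), torch.ones_like(q_values))
random_actions = torch.multinomial(random_valid_logits, 1).squeeze(dim=1)

# Per batch element: with probability epsilon, take the random action.
choose_random = torch.rand(batch_size) < epsilon
action = torch.where(choose_random, random_actions, greedy_actions)
print(action)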
Example #7
    def _get_torch_exploration_action(
        self,
        action_dist: ActionDistribution,
        explore: bool,
        timestep: Union[int, TensorType],
    ):
        # Set last timestep or (if not given) increase by one.
        self.last_timestep = (timestep if timestep is not None else
                              self.last_timestep + 1)

        # Apply exploration.
        if explore:
            # Random exploration phase.
            if self.last_timestep < self.random_timesteps:
                action, _ = self.random_exploration.get_torch_exploration_action(
                    action_dist, explore=True)
            # Take a Gaussian sample with our stddev (mean=0.0) and scale it.
            else:
                det_actions = action_dist.deterministic_sample()
                scale = self.scale_schedule(self.last_timestep)
                gaussian_sample = scale * torch.normal(
                    mean=torch.zeros(det_actions.size()), std=self.stddev).to(
                        self.device)
                action = torch.min(
                    torch.max(
                        det_actions + gaussian_sample,
                        torch.tensor(
                            self.action_space.low,
                            dtype=torch.float32,
                            device=self.device,
                        ),
                    ),
                    torch.tensor(self.action_space.high,
                                 dtype=torch.float32,
                                 device=self.device),
                )
        # No exploration -> Return deterministic actions.
        else:
            action = action_dist.deterministic_sample()

        # Logp=always zero.
        logp = torch.zeros((action.size()[0], ),
                           dtype=torch.float32,
                           device=self.device)

        return action, logp
Example #8
    def _get_tf_exploration_action_op(
        self,
        action_distribution: ActionDistribution,
        explore: Union[bool, TensorType],
        timestep: Union[int, TensorType],
    ) -> "tf.Tensor":

        per_slate_q_values = action_distribution.inputs
        all_slates = self.model.slates

        exploit_indices = action_distribution.deterministic_sample()
        exploit_action = tf.gather(all_slates, exploit_indices)

        batch_size = tf.shape(per_slate_q_values)[0]
        action_logp = tf.zeros(batch_size, dtype=tf.float32)

        # Get the current epsilon.
        epsilon = self.epsilon_schedule(
            timestep if timestep is not None else self.last_timestep
        )
        # Mask out actions whose Q-values are -inf so that we don't
        # even consider them for exploration.
        random_valid_action_logits = tf.where(
            tf.equal(per_slate_q_values, tf.float32.min),
            tf.ones_like(per_slate_q_values) * tf.float32.min,
            tf.ones_like(per_slate_q_values),
        )
        # A random action.
        random_indices = tf.squeeze(
            tf.random.categorical(random_valid_action_logits, 1), axis=1
        )
        random_actions = tf.gather(all_slates, random_indices)

        choose_random = (
            tf.random.uniform(
                tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32
            )
            < epsilon
        )

        # Pick either random or greedy.
        action = tf.cond(
            pred=tf.constant(explore, dtype=tf.bool)
            if isinstance(explore, bool)
            else explore,
            true_fn=(lambda: tf.where(choose_random, random_actions, exploit_action)),
            false_fn=lambda: exploit_action,
        )

        if self.framework in ["tf2", "tfe"] and not self.policy_config["eager_tracing"]:
            self.last_timestep = timestep
            return action, action_logp
        else:
            assign_op = tf1.assign(self.last_timestep, tf.cast(timestep, tf.int64))
            with tf1.control_dependencies([assign_op]):
                return action, action_logp
Example #9
    def _get_tf_exploration_action_op(
        self,
        action_distribution: ActionDistribution,
        explore: Union[bool, TensorType],
        timestep: Union[int, TensorType],
    ) -> "tf.Tensor":

        per_slate_q_values = action_distribution.inputs
        all_slates = action_distribution.all_slates

        exploit_action = action_distribution.deterministic_sample()

        batch_size, num_slates = (
            tf.shape(per_slate_q_values)[0],
            tf.shape(per_slate_q_values)[1],
        )
        action_logp = tf.zeros(batch_size, dtype=tf.float32)

        # Get the current epsilon.
        epsilon = self.epsilon_schedule(
            timestep if timestep is not None else self.last_timestep)
        # A random action.
        random_indices = tf.random.uniform(
            (batch_size, ),
            minval=0,
            maxval=num_slates,
            dtype=tf.dtypes.int32,
        )
        random_actions = tf.gather(all_slates, random_indices)

        choose_random = (tf.random.uniform(
            tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) <
                         epsilon)

        # Pick either random or greedy.
        action = tf.cond(
            pred=tf.constant(explore, dtype=tf.bool) if isinstance(
                explore, bool) else explore,
            true_fn=(lambda: tf.where(choose_random, random_actions,
                                      exploit_action)),
            false_fn=lambda: exploit_action,
        )

        if self.framework in ["tf2", "tfe"
                              ] and not self.policy_config["eager_tracing"]:
            self.last_timestep = timestep
            return action, action_logp
        else:
            assign_op = tf1.assign(self.last_timestep,
                                   tf.cast(timestep, tf.int64))
            with tf1.control_dependencies([assign_op]):
                return action, action_logp
Example #10
    def _get_tf_exploration_action_op(
        self,
        action_dist: ActionDistribution,
        explore: bool,
        timestep: Union[int, TensorType],
    ):
        ts = timestep if timestep is not None else self.last_timestep

        # The deterministic actions (if explore=False).
        deterministic_actions = action_dist.deterministic_sample()

        # Take a Gaussian sample with our stddev (mean=0.0) and scale it.
        gaussian_sample = self.scale_schedule(ts) * tf.random.normal(
            tf.shape(deterministic_actions), stddev=self.stddev)

        # Stochastic actions could either be: random OR action + noise.
        random_actions, _ = self.random_exploration.get_tf_exploration_action_op(
            action_dist, explore)
        stochastic_actions = tf.cond(
            pred=tf.convert_to_tensor(ts < self.random_timesteps),
            true_fn=lambda: random_actions,
            false_fn=lambda: tf.clip_by_value(
                deterministic_actions + gaussian_sample,
                self.action_space.low * tf.ones_like(deterministic_actions),
                self.action_space.high * tf.ones_like(deterministic_actions),
            ),
        )

        # Choose by `explore` (main exploration switch).
        action = tf.cond(
            pred=tf.constant(explore, dtype=tf.bool) if isinstance(
                explore, bool) else explore,
            true_fn=lambda: stochastic_actions,
            false_fn=lambda: deterministic_actions,
        )
        # Logp=always zero.
        logp = zero_logps_from_actions(deterministic_actions)

        # Increment `last_timestep` by 1 (or set to `timestep`).
        if self.framework in ["tf2", "tfe"]:
            if timestep is None:
                self.last_timestep.assign_add(1)
            else:
                self.last_timestep.assign(tf.cast(timestep, tf.int64))
            return action, logp
        else:
            assign_op = (tf1.assign_add(self.last_timestep, 1)
                         if timestep is None else tf1.assign(
                             self.last_timestep, timestep))
            with tf1.control_dependencies([assign_op]):
                return action, logp
Example #11
 def get_exploration_action(
     self,
     *,
     action_distribution: ActionDistribution,
     timestep: int,
     explore: bool = True,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     if explore:
         if timestep < self._pure_exploration_steps:
             return super().get_exploration_action(
                 action_distribution=action_distribution,
                 timestep=timestep,
                 explore=explore,
             )
         return action_distribution.sample()
     return action_distribution.deterministic_sample()
Example #12
    def _get_torch_exploration_action(
        self,
        action_distribution: ActionDistribution,
        explore: bool,
        timestep: Union[int, TensorType],
    ) -> "torch.Tensor":

        per_slate_q_values = action_distribution.inputs
        all_slates = self.model.slates

        exploit_indices = action_distribution.deterministic_sample()
        exploit_action = all_slates[exploit_indices]

        batch_size = per_slate_q_values.size()[0]
        action_logp = torch.zeros(batch_size, dtype=torch.float)

        self.last_timestep = timestep

        # Explore.
        if explore:
            # Get the current epsilon.
            epsilon = self.epsilon_schedule(self.last_timestep)
            # Mask out actions whose Q-values are -inf so that we don't
            # even consider them for exploration.
            random_valid_action_logits = torch.where(
                per_slate_q_values <= FLOAT_MIN,
                torch.ones_like(per_slate_q_values) * 0.0,
                torch.ones_like(per_slate_q_values),
            )
            # A random action.
            random_indices = torch.squeeze(torch.multinomial(
                random_valid_action_logits, 1),
                                           axis=1)
            random_actions = all_slates[random_indices]

            # Pick either random or greedy.
            action = torch.where(
                torch.empty(
                    (batch_size, )).uniform_().to(self.device) < epsilon,
                random_actions,
                exploit_action,
            )

            return action, action_logp
        # Return the deterministic "sample" (argmax) over the logits.
        else:
            return exploit_action, action_logp
Example #13
 def get_torch_exploration_action(self, action_dist: ActionDistribution,
                                  explore: bool):
     if explore:
         req = force_tuple(
             action_dist.required_model_output_shape(
                 self.action_space, self.model.model_config))
         # Add a batch dimension?
         if len(action_dist.inputs.shape) == len(req) + 1:
             batch_size = action_dist.inputs.shape[0]
             a = np.stack(
                 [self.action_space.sample() for _ in range(batch_size)])
         else:
             a = self.action_space.sample()
         # Convert action to torch tensor.
         action = torch.from_numpy(a).to(self.device)
     else:
         action = action_dist.deterministic_sample()
     logp = torch.zeros(
         (action.size()[0], ), dtype=torch.float32, device=self.device)
     return action, logp
Example #14
    def _get_torch_exploration_action(
        self,
        action_distribution: ActionDistribution,
        explore: bool,
        timestep: Union[int, TensorType],
    ) -> "torch.Tensor":

        per_slate_q_values = action_distribution.inputs
        all_slates = self.model.slates

        exploit_indices = action_distribution.deterministic_sample()
        exploit_action = all_slates[exploit_indices]

        batch_size = per_slate_q_values.size()[0]
        action_logp = torch.zeros(batch_size, dtype=torch.float)

        self.last_timestep = timestep

        # Explore.
        if explore:
            # Get the current epsilon.
            epsilon = self.epsilon_schedule(self.last_timestep)
            # A random action.
            random_indices = torch.randint(0, per_slate_q_values.shape[1],
                                           (per_slate_q_values.shape[0], ))
            random_actions = all_slates[random_indices]

            # Pick either random or greedy.
            action = torch.where(
                torch.empty((batch_size, )).uniform_() < epsilon,
                random_actions,
                exploit_action,
            )
            return action, action_logp
        # Return the deterministic "sample" (argmax) over the logits.
        else:
            return exploit_action, action_logp
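The slate examples (#12 and #14) hinge on one trick: the distribution produces slate indices, and the concrete slates are recovered by indexing all_slates with those indices. A minimal hedged sketch of that mapping, with made-up slates, Q-values, and epsilon:

import torch

# Illustrative setup: 4 candidate slates of 2 documents each, batch of 3 queries.
all_slates = torch.tensor([[0, 1], [0, 2], [1, 2], [2, 3]])
per_slate_q_values = torch.rand(3, 4)
epsilon = 0.2

exploit_indices = torch.argmax(per_slate_q_values, dim=-1)  # greedy slate index per row
exploit_slates = all_slates[exploit_indices]                # shape (3, 2)

random_indices = torch.randint(0, all_slates.size(0), (3, ))
random_slates = all_slates[random_indices]                  # shape (3, 2)

# Broadcast the per-row coin flip across the slate dimension.
choose_random = (torch.rand(3) < epsilon).unsqueeze(-1)
slates = torch.where(choose_random, random_slates, exploit_slates)
print(slates)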