Example No. 1
 def act(self,
         obs: rlt.FeatureData,
         possible_actions_mask: Optional[np.ndarray] = None
         ) -> rlt.ActorOutput:
     greedy = self.cem_planner_network(obs)
     if self.discrete_action:
         _, onehot = greedy
         return rlt.ActorOutput(action=onehot.unsqueeze(0),
                                log_prob=torch.tensor(0.0))
     else:
         return rlt.ActorOutput(action=greedy.unsqueeze(0),
                                log_prob=torch.tensor(0.0))
Example No. 2
 def act(self,
         obs: Any,
         possible_actions_mask: Optional[np.ndarray] = None
         ) -> rlt.ActorOutput:
     output = self.predictor(obs)
     if isinstance(output, tuple):
         action, log_prob = output
         log_prob = log_prob.clamp(LOG_PROB_MIN, LOG_PROB_MAX)
         return rlt.ActorOutput(action=action.cpu(),
                                log_prob=log_prob.cpu())
     else:
         return rlt.ActorOutput(action=output.cpu())
Example No. 3
    def forward(self, state):
        concentration = self._get_concentration(state)
        if self.training:
            # PyTorch can't backwards pass _sample_dirichlet
            action = Dirichlet(concentration).rsample()
        else:
            # ONNX can't export Dirichlet()
            action = torch._sample_dirichlet(concentration)

        if not self.training:
            # ONNX doesn't like reshape either..
            return rlt.ActorOutput(action=action)

        log_prob = Dirichlet(concentration).log_prob(action)
        return rlt.ActorOutput(action=action,
                               log_prob=log_prob.unsqueeze(dim=1))
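A minimal standalone sketch of the two sampling paths the example switches between: Dirichlet(...).rsample() is reparameterized, so gradients flow back to the concentration parameters during training, while torch._sample_dirichlet is a plain op that ONNX export tolerates. The concentration values below are made up for illustration.

    import torch
    from torch.distributions import Dirichlet

    # Hypothetical concentration parameters (in the example above they come
    # from self._get_concentration(state)).
    concentration = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)

    # Training path: the reparameterized sample keeps the autograd graph intact.
    action = Dirichlet(concentration).rsample()
    log_prob = Dirichlet(concentration).log_prob(action)
    log_prob.sum().backward()  # gradients reach `concentration`

    # Serving path: a raw sampling op, avoiding the Dirichlet object for ONNX.
    with torch.no_grad():
        eval_action = torch._sample_dirichlet(concentration)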
Example No. 4
    def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
        # scores: batch_size x num_actions
        assert scores.dim() == 2, "scores dim is %d" % scores.dim()
        batch_size, num_actions = scores.shape

        # pyre-fixme[16]: `Tensor` has no attribute `argmax`.
        argmax = F.one_hot(scores.argmax(dim=1), num_actions).bool()

        valid_actions_ind = (scores > INVALID_ACTION_CONSTANT).bool()
        num_valid_actions = valid_actions_ind.float().sum(1, keepdim=True)

        rand_prob = self.epsilon / num_valid_actions
        p = torch.ones_like(scores) * rand_prob

        greedy_prob = 1 - self.epsilon + rand_prob
        p[argmax] = greedy_prob.squeeze()

        p[~valid_actions_ind] = 0.0  # pyre-ignore

        assert torch.isclose(p.sum(1), torch.ones(p.shape[0])).all()

        m = torch.distributions.Categorical(probs=p)
        raw_action = m.sample()
        action = F.one_hot(raw_action, num_actions)
        assert action.shape == (batch_size, num_actions)
        log_prob = m.log_prob(raw_action)
        assert log_prob.shape == (batch_size, )
        return rlt.ActorOutput(action=action, log_prob=log_prob)
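A small standalone walk-through of how the mixed greedy/uniform distribution above is assembled, using made-up numbers; the epsilon value and the value of INVALID_ACTION_CONSTANT are assumptions for illustration only.

    import torch
    import torch.nn.functional as F

    epsilon = 0.2
    INVALID_ACTION_CONSTANT = -1e10  # assumed sentinel marking invalid actions
    scores = torch.tensor([[1.0, 3.0, 2.0, INVALID_ACTION_CONSTANT]])

    valid = scores > INVALID_ACTION_CONSTANT        # last action is invalid
    num_valid = valid.float().sum(1, keepdim=True)  # 3 valid actions
    rand_prob = epsilon / num_valid                 # 0.2 / 3 ~ 0.0667

    p = torch.ones_like(scores) * rand_prob
    argmax = F.one_hot(scores.argmax(dim=1), scores.shape[1]).bool()
    p[argmax] = (1 - epsilon + rand_prob).squeeze()  # greedy action gets ~0.8667
    p[~valid] = 0.0
    print(p)  # ~ [[0.0667, 0.8667, 0.0667, 0.0000]]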
Example No. 5
 def act(self,
         obs: Any,
         possible_actions_mask: Optional[np.ndarray] = None
         ) -> rlt.ActorOutput:
     action = self.predictor(obs).cpu()
     # TODO: return log_probs as well
     return rlt.ActorOutput(action=action)
Example No. 6
    def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:

        batch_size, num_actions = scores.shape
        raw_action = self._get_greedy_indices(scores)
        action = F.one_hot(raw_action, num_actions)
        assert action.shape == (batch_size, num_actions)
        return rlt.ActorOutput(action=action,
                               log_prob=torch.ones_like(raw_action,
                                                        dtype=torch.float))
Example No. 7
    def forward(self, state: rlt.FeatureData) -> rlt.ActorOutput:
        action = self.fc(state.float_features)
        batch_size = action.shape[0]
        assert action.shape == (
            batch_size,
            self.action_dim,
        ), f"{action.shape} != ({batch_size}, {self.action_dim})"

        if self.exploration_variance is None:
            log_prob = torch.zeros(batch_size).to(action.device).float().view(
                -1, 1)
            return rlt.ActorOutput(action=action, log_prob=log_prob)

        noise = self.noise_dist.sample((batch_size, ))
        # TODO: log prob is affected by clamping, how to handle that?
        log_prob = (self.noise_dist.log_prob(noise).to(
            action.device).sum(dim=1).view(-1, 1))
        action = (action + noise.to(action.device)).clamp(
            *CONTINUOUS_TRAINING_ACTION_RANGE)
        return rlt.ActorOutput(action=action, log_prob=log_prob)
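A standalone sketch of the exploration branch above; the variance, the clamp range, and the way self.noise_dist is built are assumptions made for illustration.

    import torch
    from torch.distributions import Normal

    exploration_variance = 0.1
    CONTINUOUS_TRAINING_ACTION_RANGE = (-1.0, 1.0)  # assumed clamp range
    batch_size, action_dim = 4, 2

    # Assumed construction: zero-mean Gaussian noise per action dimension.
    noise_dist = Normal(
        torch.zeros(action_dim),
        torch.full((action_dim,), exploration_variance ** 0.5),
    )

    action = torch.zeros(batch_size, action_dim)  # stand-in for self.fc(...)
    noise = noise_dist.sample((batch_size,))      # (batch_size, action_dim)
    log_prob = noise_dist.log_prob(noise).sum(dim=1).view(-1, 1)
    action = (action + noise).clamp(*CONTINUOUS_TRAINING_ACTION_RANGE)

As the TODO in the example notes, clamping moves probability mass onto the range boundaries, so the Gaussian log-probability is only an approximation for actions that hit the clamp.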
Example No. 8
 def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
     """Sample a ranking according to Frechet sort. Note that possible_actions_mask
     is ignored as the list of rankings scales exponentially with slate size and
     number of items and it can be difficult to enumerate them."""
     assert scores.dim() == 2, "sample_action only accepts batches"
     log_scores = scores if self.log_scores else torch.log(scores)
     perturbed = log_scores + self.gumbel_noise.sample(scores.shape)
     action = torch.argsort(perturbed.detach(), descending=True)
     log_prob = self.log_prob(scores, action)
     # Only truncate the action before returning
     if self.topk is not None:
         action = action[:, :self.topk]
     return rlt.ActorOutput(action, log_prob)
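A standalone sketch of the Gumbel perturb-and-sort trick the example relies on; the shape value and the way self.gumbel_noise is constructed are assumptions.

    import torch
    from torch.distributions import Gumbel

    shape = 2.0                              # assumed Frechet shape parameter
    gumbel_noise = Gumbel(0.0, 1.0 / shape)

    scores = torch.tensor([[0.1, 0.7, 0.2]])
    log_scores = torch.log(scores)
    perturbed = log_scores + gumbel_noise.sample(scores.shape)
    # Sorting the perturbed log-scores yields a random ranking that favors
    # high-scoring items, e.g. [[1, 2, 0]] ranks item 1 first.
    ranking = torch.argsort(perturbed, dim=1, descending=True)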
Example No. 9
 def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
     assert (
         scores.dim() == 2
     ), f"scores shape is {scores.shape}, not (batch_size, num_actions)"
     batch_size, num_actions = scores.shape
     m = self._get_distribution(scores)
     raw_action = m.sample()
     assert raw_action.shape == (
         batch_size, ), f"{raw_action.shape} != ({batch_size}, )"
     action = F.one_hot(raw_action, num_actions)
     assert action.ndim == 2
     log_prob = m.log_prob(raw_action)
     assert log_prob.ndim == 1
     return rlt.ActorOutput(action=action, log_prob=log_prob)
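A minimal standalone version with _get_distribution replaced by a plain softmax Categorical over the scores; this stand-in is hypothetical, and the real helper may apply a temperature or masking.

    import torch
    import torch.nn.functional as F

    scores = torch.tensor([[1.0, 2.0, 0.5]])
    m = torch.distributions.Categorical(logits=scores)  # softmax over scores
    raw_action = m.sample()                              # shape (1,)
    action = F.one_hot(raw_action, scores.shape[1])      # shape (1, 3)
    log_prob = m.log_prob(raw_action)                    # shape (1,)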
Example No. 10
 def act(
     self,
     obs: Union[rlt.ServingFeatureData, Tuple[torch.Tensor, torch.Tensor]],
     possible_actions_mask: Optional[torch.Tensor] = None,
 ) -> rlt.ActorOutput:
     """Input is either state_with_presence, or
     ServingFeatureData (in the case of sparse features)"""
     assert isinstance(obs, tuple)
     if isinstance(obs, rlt.ServingFeatureData):
         state: rlt.ServingFeatureData = obs
     else:
         state = rlt.ServingFeatureData(
             float_features_with_presence=obs,
             id_list_features={},
             id_score_list_features={},
         )
     output = self.predictor(*state)
     if isinstance(output, tuple):
         action, log_prob = output
         log_prob = log_prob.clamp(LOG_PROB_MIN, LOG_PROB_MAX)
         return rlt.ActorOutput(action=action.cpu(), log_prob=log_prob.cpu())
     else:
         return rlt.ActorOutput(action=output.cpu())
Example No. 11
 def act(self,
         obs: rlt.FeatureData,
         possible_actions_mask: Optional[np.ndarray] = None
         ) -> rlt.ActorOutput:
     """Act randomly regardless of the observation."""
     # pyre-fixme[35]: Target cannot be annotated.
     obs: torch.Tensor = obs.float_features
     assert obs.dim() >= 2, f"obs has shape {obs.shape} (dim < 2)"
     batch_size = obs.size(0)
     # pyre-fixme[6]: Expected `Union[torch.Size, torch.Tensor]` for 1st param
     #  but got `Tuple[int]`.
     action = self.dist.sample((batch_size, ))
     # sum over action_dim (since assuming i.i.d. per coordinate)
     log_prob = self.dist.log_prob(action).sum(1)
     return rlt.ActorOutput(action=action, log_prob=log_prob)
Example No. 12
    def forward(self, state: rlt.FeatureData):
        loc, scale_log = self._get_loc_and_scale_log(state)
        r = torch.randn_like(scale_log, device=scale_log.device)
        raw_action = loc + r * scale_log.exp()
        squashed_action = self._squash_raw_action(raw_action)
        squashed_loc = self._squash_raw_action(loc)
        if SummaryWriterContext._global_step % 1000 == 0:
            SummaryWriterContext.add_histogram("actor/forward/loc",
                                               loc.detach().cpu())
            SummaryWriterContext.add_histogram("actor/forward/scale_log",
                                               scale_log.detach().cpu())

        return rlt.ActorOutput(
            action=squashed_action,
            log_prob=self.get_log_prob(state, squashed_action),
            squashed_mean=squashed_loc,
        )
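A minimal sketch of the reparameterized sample and tanh squashing used above, with the change-of-variables correction that get_log_prob is assumed to apply written out explicitly; the loc and scale values are made up.

    import torch
    from torch.distributions import Normal

    loc = torch.zeros(1, 2)
    scale_log = torch.full((1, 2), -1.0)

    r = torch.randn_like(scale_log)
    raw_action = loc + r * scale_log.exp()    # reparameterized Gaussian sample
    squashed_action = torch.tanh(raw_action)  # squash into (-1, 1)

    # log pi(a) = log N(raw_action) - sum log(1 - tanh(raw_action)^2)
    normal = Normal(loc, scale_log.exp())
    log_prob = (
        normal.log_prob(raw_action) - torch.log(1 - squashed_action ** 2 + 1e-6)
    ).sum(dim=1, keepdim=True)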
Example No. 13
    def act(self,
            obs: rlt.FeatureData,
            possible_actions_mask: Optional[np.ndarray] = None
            ) -> rlt.ActorOutput:
        # pyre-fixme[35]: Target cannot be annotated.
        obs: torch.Tensor = obs.float_features
        batch_size, _ = obs.shape

        actions = []
        log_probs = []
        for m in self.dists:
            actions.append(m.sample((batch_size, 1)))
            log_probs.append(m.log_prob(actions[-1]).float())

        return rlt.ActorOutput(
            action=torch.cat(actions, dim=1),
            log_prob=torch.cat(log_probs, dim=1).sum(1, keepdim=True),
        )
Example No. 14
    def act(self,
            obs: rlt.FeatureData,
            possible_actions_mask: Optional[np.ndarray] = None
            ) -> rlt.ActorOutput:
        """Act randomly regardless of the observation."""
        # pyre-fixme[35]: Target cannot be annotated.
        obs: torch.Tensor = obs.float_features
        assert obs.dim() >= 2, f"obs has shape {obs.shape} (dim < 2)"
        assert obs.shape[0] == 1, f"obs has shape {obs.shape} (0th dim != 1)"
        batch_size = obs.shape[0]
        scores = torch.ones((batch_size, self.num_actions))
        scores = apply_possible_actions_mask(scores,
                                             possible_actions_mask,
                                             invalid_score=0.0)

        # sample a random action
        m = torch.distributions.Categorical(scores)
        raw_action = m.sample()
        action = F.one_hot(raw_action, self.num_actions)
        log_prob = m.log_prob(raw_action).float()
        return rlt.ActorOutput(action=action, log_prob=log_prob)
Example No. 15
 def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
     top_values, item_idxs = torch.topk(scores, self.k, dim=1)
     return rlt.ActorOutput(action=item_idxs,
                            log_prob=torch.zeros(item_idxs.shape[0], 1))
Example No. 16
    def sample_action(self, scores: GaussianSamplerScore) -> rlt.ActorOutput:
        self.actor_network.eval()
        unscaled_actions, log_prob = self._sample_action(scores.loc, scores.scale_log)
        self.actor_network.train()

        return rlt.ActorOutput(action=unscaled_actions, log_prob=log_prob)