Example #1
    def forward(self, input):
        loc, scale_log = self._get_loc_and_scale_log(input.state)
        r = torch.randn_like(scale_log, device=scale_log.device)
        action = torch.tanh(loc + r * scale_log.exp())
        if not self.training:
            # ONNX doesn't like reshape either..
            return rlt.ActorOutput(action=action)
        # Since each dim is independent, the log-prob is simply the sum
        log_prob = self._log_prob(r, scale_log)
        squash_correction = self._squash_correction(action)
        if SummaryWriterContext._global_step % 1000 == 0:
            SummaryWriterContext.add_histogram("actor/forward/loc", loc.detach().cpu())
            SummaryWriterContext.add_histogram(
                "actor/forward/scale_log", scale_log.detach().cpu()
            )
            SummaryWriterContext.add_histogram(
                "actor/forward/log_prob", log_prob.detach().cpu()
            )
            SummaryWriterContext.add_histogram(
                "actor/forward/squash_correction", squash_correction.detach().cpu()
            )
        log_prob = torch.sum(log_prob - squash_correction, dim=1)

        return rlt.ActorOutput(
            action=action, log_prob=log_prob.reshape(-1, 1), action_mean=loc
        )
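The _log_prob and _squash_correction helpers are not shown in this example. As a rough illustration only (an assumed formulation of the usual tanh-squashing correction, not code from this class), the pieces fit together like this:

import torch

# Sketch: log-prob of a tanh-squashed Gaussian sample (assumed shapes and helpers).
loc = torch.zeros(4, 2)
scale_log = torch.zeros(4, 2)
r = torch.randn_like(scale_log)
pre_tanh = loc + r * scale_log.exp()
action = torch.tanh(pre_tanh)

# Per-dimension Gaussian log-density of the pre-tanh sample.
normal = torch.distributions.Normal(loc, scale_log.exp())
log_prob = normal.log_prob(pre_tanh)

# Change-of-variables correction for the tanh squash: log(1 - tanh(u)^2),
# with a small epsilon for numerical stability.
squash_correction = torch.log(1 - action ** 2 + 1e-6)

# Dimensions are independent, so sum over the action dimension.
log_prob = (log_prob - squash_correction).sum(dim=1, keepdim=True)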
Example #2
 def act(self, obs: rlt.FeatureData) -> rlt.ActorOutput:
     greedy = self.cem_planner_network(obs)
     if self.discrete_action:
         _, onehot = greedy
         return rlt.ActorOutput(
             action=onehot.unsqueeze(0), log_prob=torch.tensor(0.0)
         )
     else:
         return rlt.ActorOutput(
             action=greedy.unsqueeze(0), log_prob=torch.tensor(0.0)
         )
Example #3
 def act(self,
         obs: rlt.FeatureData,
         possible_actions_mask: Optional[np.ndarray] = None
         ) -> rlt.ActorOutput:
     greedy = self.cem_planner_network(obs)
     if self.discrete_action:
         _, onehot = greedy
         return rlt.ActorOutput(action=onehot.unsqueeze(0),
                                log_prob=torch.tensor(0.0))
     else:
         return rlt.ActorOutput(action=greedy.unsqueeze(0),
                                log_prob=torch.tensor(0.0))
Example #4
    def forward(self, input):
        concentration = self._get_concentration(input.state)
        if self.training:
            # PyTorch can't backprop through _sample_dirichlet
            action = Dirichlet(concentration).rsample()
        else:
            # ONNX can't export Dirichlet()
            action = torch._sample_dirichlet(concentration)

        if not self.training:
            # ONNX doesn't like reshape either..
            return rlt.ActorOutput(action=action)

        log_prob = Dirichlet(concentration).log_prob(action)
        return rlt.ActorOutput(action=action, log_prob=log_prob.unsqueeze(dim=1))
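To see why the training branch uses Dirichlet(...).rsample() (per the comment that the raw _sample_dirichlet op can't be backpropagated through), here is a minimal standalone check with made-up concentration values:

import torch
from torch.distributions import Dirichlet

concentration = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)
action = Dirichlet(concentration).rsample()       # reparameterized, differentiable sample
loss = -Dirichlet(concentration).log_prob(action).sum()
loss.backward()                                   # gradients reach the concentration parameters
assert concentration.grad is not None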
Example #5
 def act(self,
         obs: Any,
         possible_actions_mask: Optional[np.ndarray] = None
         ) -> rlt.ActorOutput:
     action = self.predictor(obs).cpu()
     # TODO: return log_probs as well
     return rlt.ActorOutput(action=action)
Example #6
 def sample_action(
     self,
     scores: torch.Tensor,
     possible_actions_mask: Optional[torch.Tensor] = None
 ) -> rlt.ActorOutput:
     assert scores.dim() == 2, ("scores dim is %d" % scores.dim()
                                )  # batch_size x num_actions
     _, num_actions = scores.shape
     f = F.log_softmax if self.key == "logits" else F.softmax
     if possible_actions_mask is not None:
         assert possible_actions_mask.dim() == 2  # batch_size x num_actions
         mod_scores = f(scores + torch.log(possible_actions_mask))
     else:
         mod_scores = scores
     m = torch.distributions.Categorical(
         **{self.key: mod_scores / self.temperature})
     raw_action = m.sample()
     assert raw_action.ndim == 1
     assert ((0 <= raw_action).all() and
             (raw_action < num_actions).all()
             ), f"negative {raw_action} or >= {num_actions}."
     action = F.one_hot(raw_action, num_actions)
     assert action.ndim == 2
     log_prob = m.log_prob(raw_action).float()
     assert log_prob.ndim == 1
     return rlt.ActorOutput(action=action, log_prob=log_prob)
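For reference, the logits path of the sampler above can be exercised in isolation. The snippet below is only a sketch with made-up shapes and temperature, not part of the original class:

import torch
import torch.nn.functional as F

scores = torch.randn(4, 3)                     # batch_size x num_actions
temperature = 0.5
logits = F.log_softmax(scores, dim=1) / temperature
m = torch.distributions.Categorical(logits=logits)
raw_action = m.sample()                        # shape: (4,)
action = F.one_hot(raw_action, 3)              # shape: (4, 3)
log_prob = m.log_prob(raw_action)              # shape: (4,)
assert action.shape == (4, 3) and log_prob.shape == (4,)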
Example #7
    def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
        assert scores.dim() == 2, ("scores dim is %d" % scores.dim()
                                   )  # batch_size x num_actions
        batch_size, num_actions = scores.shape

        # pyre-fixme[16]: `Tensor` has no attribute `argmax`.
        argmax = F.one_hot(scores.argmax(dim=1), num_actions).bool()

        valid_actions_ind = (scores > INVALID_ACTION_CONSTANT).bool()
        num_valid_actions = valid_actions_ind.float().sum(1, keepdim=True)

        rand_prob = self.epsilon / num_valid_actions
        p = torch.ones_like(scores) * rand_prob

        greedy_prob = 1 - self.epsilon + rand_prob
        p[argmax] = greedy_prob.squeeze()

        p[~valid_actions_ind] = 0.0  # pyre-ignore

        assert torch.allclose(p.sum(1), torch.ones(p.shape[0]))

        m = torch.distributions.Categorical(probs=p)
        raw_action = m.sample()
        action = F.one_hot(raw_action, num_actions)
        assert action.shape == (batch_size, num_actions)
        log_prob = m.log_prob(raw_action)
        assert log_prob.shape == (batch_size, )
        return rlt.ActorOutput(action=action, log_prob=log_prob)
Example #8
    def sample_action(
        self,
        scores: torch.Tensor,
        possible_actions_mask: Optional[torch.Tensor] = None
    ) -> rlt.ActorOutput:
        assert scores.dim() == 2, ("scores dim is %d" % scores.dim()
                                   )  # batch_size x num_actions
        batch_size, num_actions = scores.shape

        if possible_actions_mask is None:
            possible_actions_mask = torch.ones(num_actions)

        argmax = F.one_hot(scores.argmax(dim=1), num_actions).bool()

        p = torch.zeros_like(scores)
        allowed_action_count = float(possible_actions_mask.sum().item())
        mask = torch.repeat_interleave(
            possible_actions_mask.bool().unsqueeze(0), batch_size, dim=0
        )

        rand_prob = self.epsilon / allowed_action_count
        p[mask] = rand_prob

        greedy_prob = 1 - self.epsilon + rand_prob
        p[argmax] = greedy_prob

        m = torch.distributions.Categorical(probs=p)
        raw_action = m.sample()
        action = F.one_hot(raw_action, num_actions)
        assert action.shape == (batch_size, num_actions)
        log_prob = m.log_prob(raw_action)
        assert log_prob.shape == (batch_size, )
        return rlt.ActorOutput(action=action, log_prob=log_prob)
Example #9
    def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:

        batch_size, num_actions = scores.shape
        raw_action = self._get_greedy_indices(scores)
        action = F.one_hot(raw_action, num_actions)
        assert action.shape == (batch_size, num_actions)
        return rlt.ActorOutput(action=action, log_prob=torch.ones_like(raw_action))
Example #10
 def act(self, obs: rlt.FeatureData) -> rlt.ActorOutput:
     # TODO: Why doesn't predictor take the whole preprocessed_state?
     state = obs.float_features
     action = self.predictor.policy(state=state).softmax
     assert action is not None
     # since act should return batched data
     return rlt.ActorOutput(action=torch.tensor([[action]]))
Example #11
    def sample_action(self, scores: GaussianSamplerScore) -> rlt.ActorOutput:
        self.actor_network.eval()
        unscaled_actions, log_prob = self._sample_action(
            scores.loc, scores.scale_log)
        self.actor_network.train()

        return rlt.ActorOutput(action=unscaled_actions, log_prob=log_prob)
Example #12
 def act(self, obs: rlt.FeatureData) -> rlt.ActorOutput:
     """ Act randomly regardless of the observation. """
     obs: torch.Tensor = obs.float_features
     assert obs.dim() >= 2, f"obs has shape {obs.shape} (dim < 2)"
     batch_size = obs.size(0)
     # pyre-fixme[6]: Expected `Union[torch.Size, torch.Tensor]` for 1st param
     #  but got `Tuple[int]`.
     action = self.dist.sample((batch_size, ))
     log_prob = self.dist.log_prob(action)
     return rlt.ActorOutput(action=action, log_prob=log_prob)
Example #13
    def forward(self, state: rlt.FeatureData) -> rlt.ActorOutput:
        action = self.fc(state.float_features)
        batch_size = action.shape[0]
        assert action.shape == (
            batch_size,
            self.action_dim,
        ), f"{action.shape} != ({batch_size}, {self.action_dim})"

        if self.exploration_variance is None:
            log_prob = torch.zeros(batch_size).to(action.device).float().view(
                -1, 1)
            return rlt.ActorOutput(action=action, log_prob=log_prob)

        noise = self.noise_dist.sample((batch_size, ))
        # TODO: log prob is affected by clamping, how to handle that?
        log_prob = (self.noise_dist.log_prob(noise).to(
            action.device).sum(dim=1).view(-1, 1))
        action = (action + noise.to(action.device)).clamp(
            *CONTINUOUS_TRAINING_ACTION_RANGE)
        return rlt.ActorOutput(action=action, log_prob=log_prob)
Example #14
 def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
     """Sample a ranking according to Frechet sort. Note that possible_actions_mask
     is ignored as the list of rankings scales exponentially with slate size and
     number of items and it can be difficult to enumerate them."""
     assert scores.dim() == 2, "sample_action only accepts batches"
     log_scores = scores if self.log_scores else torch.log(scores)
     perturbed = log_scores + self.gumbel_noise.sample((scores.shape[1],))
     action = torch.argsort(perturbed.detach(), descending=True)
     if self.topk is not None:
         action = action[:, : self.topk]
     log_prob = self.log_prob(scores, action)
     return rlt.ActorOutput(action, log_prob)
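The Gumbel-perturbed argsort used above is the standard trick for sampling rankings. Below is a self-contained sketch of just that trick, with explicit Gumbel noise and made-up shapes; it does not reproduce this class's gumbel_noise or log_prob:

import torch

scores = torch.rand(4, 5) + 1e-6               # batch_size x num_items, strictly positive
shape = 2.0                                    # assumed Frechet shape parameter
gumbel = torch.distributions.Gumbel(0.0, 1.0 / shape)

perturbed = torch.log(scores) + gumbel.sample(scores.shape)
ranking = torch.argsort(perturbed, dim=1, descending=True)  # item indices, best first
topk = ranking[:, :3]                          # keep the first 3 positions of each ranking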
Example #15
 def act(
     self, obs: rlt.FeatureData, possible_actions_mask: Optional[np.ndarray] = None
 ) -> rlt.ActorOutput:
     """ Act randomly regardless of the observation. """
     obs: torch.Tensor = obs.float_features
     assert obs.dim() >= 2, f"obs has shape {obs.shape} (dim < 2)"
     batch_size = obs.size(0)
     # pyre-fixme[6]: Expected `Union[torch.Size, torch.Tensor]` for 1st param
     #  but got `Tuple[int]`.
     action = self.dist.sample((batch_size,))
     # sum over action_dim (since assuming i.i.d. per coordinate)
     log_prob = self.dist.log_prob(action).sum(1)
     return rlt.ActorOutput(action=action, log_prob=log_prob)
Example #16
    def act(self, obs: rlt.FeatureData) -> rlt.ActorOutput:
        """ Act randomly regardless of the observation. """
        obs: torch.Tensor = obs.float_features
        assert obs.dim() >= 2, f"obs has shape {obs.shape} (dim < 2)"
        batch_size = obs.shape[0]
        weights = torch.ones((batch_size, self.num_actions))

        # sample a random action
        m = torch.distributions.Categorical(weights)
        raw_action = m.sample()
        action = F.one_hot(raw_action, self.num_actions)
        log_prob = m.log_prob(raw_action).float()
        return rlt.ActorOutput(action=action, log_prob=log_prob)
Example #17
 def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
     assert (
         scores.dim() == 2
     ), f"scores shape is {scores.shape}, not (batch_size, num_actions)"
     batch_size, num_actions = scores.shape
     m = self._get_distribution(scores)
     raw_action = m.sample()
     assert raw_action.shape == (
         batch_size, ), f"{raw_action.shape} != ({batch_size}, )"
     action = F.one_hot(raw_action, num_actions)
     assert action.ndim == 2
     log_prob = m.log_prob(raw_action)
     assert log_prob.ndim == 1
     return rlt.ActorOutput(action=action, log_prob=log_prob)
Example #18
    def act(self, obs: rlt.FeatureData) -> rlt.ActorOutput:
        obs: torch.Tensor = obs.float_features
        batch_size, _ = obs.shape

        actions = []
        log_probs = []
        for m in self.dists:
            actions.append(m.sample((batch_size, 1)))
            log_probs.append(m.log_prob(actions[-1]).float())

        return rlt.ActorOutput(
            action=torch.cat(actions, dim=1),
            log_prob=torch.cat(log_probs, dim=1).sum(1, keepdim=True),
        )
Example #19
    def act(
        self, obs: Any, possible_actions_mask: Optional[torch.Tensor] = None
    ) -> rlt.ActorOutput:
        """ Act randomly regardless of the observation. """
        weights = self.default_weights
        if possible_actions_mask is not None:
            assert possible_actions_mask.shape == self.default_weights.shape
            weights = weights * possible_actions_mask

        # sample a random action
        m = torch.distributions.Categorical(weights)
        raw_action = m.sample()
        action = F.one_hot(raw_action, self.num_actions)
        log_prob = m.log_prob(raw_action).float()
        return rlt.ActorOutput(action=action, log_prob=log_prob)
Example #20
    def forward(self, state: rlt.FeatureData):
        loc, scale_log = self._get_loc_and_scale_log(state)
        r = torch.randn_like(scale_log, device=scale_log.device)
        raw_action = loc + r * scale_log.exp()
        squashed_action = self._squash_raw_action(raw_action)
        squashed_loc = self._squash_raw_action(loc)
        if SummaryWriterContext._global_step % 1000 == 0:
            SummaryWriterContext.add_histogram("actor/forward/loc",
                                               loc.detach().cpu())
            SummaryWriterContext.add_histogram("actor/forward/scale_log",
                                               scale_log.detach().cpu())

        return rlt.ActorOutput(
            action=squashed_action,
            log_prob=self.get_log_prob(state, squashed_action),
            squashed_mean=squashed_loc,
        )
Example #21
    def sample_action(self, scores: GaussianSamplerScore) -> rlt.ActorOutput:
        self.actor_network.eval()
        action, log_prob = self._sample_action(scores.loc, scores.scale_log)

        # clamp actions to make sure actions are in the range
        clamped_actions = torch.max(
            torch.min(action, self.max_training_action),
            self.min_training_action)
        rescaled_actions = rescale_torch_tensor(
            clamped_actions,
            new_min=self.min_serving_action,
            new_max=self.max_serving_action,
            prev_min=self.min_training_action,
            prev_max=self.max_training_action,
        )
        self.actor_network.train()
        return rlt.ActorOutput(action=rescaled_actions, log_prob=log_prob)
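rescale_torch_tensor is not defined in this snippet. A plausible linear mapping from the training range to the serving range (an assumption about its behavior, not the actual implementation) would be:

import torch

def rescale_torch_tensor(
    t: torch.Tensor,
    new_min: torch.Tensor,
    new_max: torch.Tensor,
    prev_min: torch.Tensor,
    prev_max: torch.Tensor,
) -> torch.Tensor:
    """Linearly map values from [prev_min, prev_max] to [new_min, new_max]."""
    return (t - prev_min) / (prev_max - prev_min) * (new_max - new_min) + new_min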
Example #22
    def act(self,
            obs: rlt.FeatureData,
            possible_actions_mask: Optional[np.ndarray] = None
            ) -> rlt.ActorOutput:
        # pyre-fixme[35]: Target cannot be annotated.
        obs: torch.Tensor = obs.float_features
        batch_size, _ = obs.shape

        actions = []
        log_probs = []
        for m in self.dists:
            actions.append(m.sample((batch_size, 1)))
            log_probs.append(m.log_prob(actions[-1]).float())

        return rlt.ActorOutput(
            action=torch.cat(actions, dim=1),
            log_prob=torch.cat(log_probs, dim=1).sum(1, keepdim=True),
        )
Example #23
 def sample_action(
     self, scores: torch.Tensor, possible_actions_mask: Optional[torch.Tensor] = None
 ) -> rlt.ActorOutput:
     assert scores.dim() == 2, (
         "scores dim is %d" % scores.dim()
     )  # batch_size x num_actions
     batch_size, num_actions = scores.shape
     if possible_actions_mask is not None:
         assert scores.shape == possible_actions_mask.shape
         mod_scores = scores.clone().float()
         mod_scores[~possible_actions_mask.bool()] = -float("inf")
     else:
         mod_scores = scores
     raw_action = mod_scores.argmax(dim=1)
     assert raw_action.ndim == 1
     action = F.one_hot(raw_action, num_actions)
     assert action.shape == (batch_size, num_actions)
     log_prob = torch.ones(batch_size, device=scores.device)
     return rlt.ActorOutput(action=action, log_prob=log_prob)
Example #24
    def act(
        self, obs: rlt.FeatureData, possible_actions_mask: Optional[np.ndarray] = None
    ) -> rlt.ActorOutput:
        """ Act randomly regardless of the observation. """
        obs: torch.Tensor = obs.float_features
        assert obs.dim() >= 2, f"obs has shape {obs.shape} (dim < 2)"
        assert obs.shape[0] == 1, f"obs has shape {obs.shape} (0th dim != 1)"
        batch_size = obs.shape[0]
        scores = torch.ones((batch_size, self.num_actions))
        scores = apply_possible_actions_mask(
            scores, possible_actions_mask, invalid_score=0.0
        )

        # sample a random action
        m = torch.distributions.Categorical(scores)
        raw_action = m.sample()
        action = F.one_hot(raw_action, self.num_actions)
        log_prob = m.log_prob(raw_action).float()
        return rlt.ActorOutput(action=action, log_prob=log_prob)
Example #25
    def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
        assert scores.dim() == 2, ("scores dim is %d" % scores.dim()
                                   )  # batch_size x num_actions
        batch_size, num_actions = scores.shape

        # pyre-fixme[16]: `Tensor` has no attribute `argmax`.
        argmax = F.one_hot(scores.argmax(dim=1), num_actions).bool()

        rand_prob = self.epsilon / num_actions
        p = torch.full_like(scores, rand_prob)

        greedy_prob = 1 - self.epsilon + rand_prob
        p[argmax] = greedy_prob

        m = torch.distributions.Categorical(probs=p)
        raw_action = m.sample()
        action = F.one_hot(raw_action, num_actions)
        assert action.shape == (batch_size, num_actions)
        log_prob = m.log_prob(raw_action)
        assert log_prob.shape == (batch_size, )
        return rlt.ActorOutput(action=action, log_prob=log_prob)
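As a sanity check on the construction above: each row of p holds num_actions - 1 entries equal to epsilon / num_actions plus one argmax entry equal to 1 - epsilon + epsilon / num_actions, which sums to exactly 1. A quick numeric sketch with made-up values:

import torch
import torch.nn.functional as F

epsilon = 0.1
scores = torch.randn(4, 3)                     # batch_size x num_actions
batch_size, num_actions = scores.shape

argmax = F.one_hot(scores.argmax(dim=1), num_actions).bool()
rand_prob = epsilon / num_actions
p = torch.full_like(scores, rand_prob)
p[argmax] = 1 - epsilon + rand_prob

assert torch.allclose(p.sum(dim=1), torch.ones(batch_size))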
Example #26
 def act(self, obs: Any) -> rlt.ActorOutput:
     action = self.predictor(obs).cpu()
     # TODO: return log_probs as well
     return rlt.ActorOutput(action=action)
Example #27
 def act(self, obs: rlt.FeatureData) -> rlt.ActorOutput:
     # TODO: Why doesn't predictor take the whole preprocessed_state?
     state = obs.float_features
     actions = self.predictor.policy(states=state).greedy
     return rlt.ActorOutput(action=actions)
Example #28
 def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
     top_values, item_idxs = torch.topk(scores, self.k, dim=1)
     return rlt.ActorOutput(action=item_idxs,
                            log_prob=torch.zeros(item_idxs.shape[0], 1))
Example #29
 def forward(self, input):
     action = self.fc(input.state.float_features)
     return rlt.ActorOutput(action=action)
Example #30
 def act(self, obs: Any, possible_actions_mask=None) -> rlt.ActorOutput:
     action = self.dist.sample()
     log_prob = self.dist.log_prob(action)
     return rlt.ActorOutput(action=action, log_prob=log_prob)