Example #1
    def forward(self, input):
        loc, scale_log = self._get_loc_and_scale_log(input.state)
        r = torch.randn_like(scale_log, device=scale_log.device)
        action = torch.tanh(loc + r * scale_log.exp())
        if not self.training:
            # ONNX doesn't like reshape either..
            return rlt.ActorOutput(action=action)
        # Since each dim is independent, the log-prob is simply the sum
        log_prob = self._log_prob(r, scale_log)
        squash_correction = self._squash_correction(action)
        if SummaryWriterContext._global_step % 1000 == 0:
            SummaryWriterContext.add_histogram("actor/forward/loc",
                                               loc.detach().cpu())
            SummaryWriterContext.add_histogram("actor/forward/scale_log",
                                               scale_log.detach().cpu())
            SummaryWriterContext.add_histogram("actor/forward/log_prob",
                                               log_prob.detach().cpu())
            SummaryWriterContext.add_histogram(
                "actor/forward/squash_correction",
                squash_correction.detach().cpu())
        log_prob = torch.sum(log_prob - squash_correction, dim=1)

        return rlt.ActorOutput(action=action,
                               log_prob=log_prob.reshape(-1, 1),
                               action_mean=loc)
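
The squashed-Gaussian examples above rely on helpers such as `_log_prob` and `_squash_correction` that are not shown. Below is a minimal, self-contained sketch of the same idea, assuming `_log_prob` is a diagonal Gaussian log-density of the pre-squash sample and `_squash_correction` is the tanh change-of-variables term `log(1 - tanh(u)^2)`; the standalone names, shapes, and the `1e-6` stabilizer are illustrative, not ReAgent's actual code.

import torch
from torch.distributions import Normal

loc = torch.zeros(4, 2)                  # batch_size x action_dim
scale_log = torch.full((4, 2), -1.0)

r = torch.randn_like(scale_log)          # standard-normal noise
u = loc + r * scale_log.exp()            # pre-squash Gaussian sample
action = torch.tanh(u)                   # squashed into (-1, 1)

# Per-dimension Gaussian log-density of the pre-squash sample.
log_prob = Normal(loc, scale_log.exp()).log_prob(u)
# Change-of-variables correction for the tanh squash: log|d tanh(u)/du| = log(1 - tanh(u)^2).
squash_correction = torch.log(1 - action ** 2 + 1e-6)
# Dimensions are independent, so the joint log-prob is the sum over the action dimension.
log_prob = (log_prob - squash_correction).sum(dim=1, keepdim=True)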
Example #2
    def forward(self, input):
        loc, scale_log = self._get_loc_and_scale_log(input.state)
        r = torch.randn_like(scale_log, device=scale_log.device)
        action = torch.tanh(loc + r * scale_log.exp())
        if not self.training:
            # ONNX doesn't like reshape either..
            return rlt.ActorOutput(action=action)
        # Since each dim is independent, the log-prob is simply the sum
        log_prob = torch.sum(self._log_prob(r, scale_log) -
                             self._squash_correction(action),
                             dim=1)
        return rlt.ActorOutput(action=action, log_prob=log_prob.reshape(-1, 1))
Example #3
    def forward(self, input):
        concentration = self._get_concentration(input.state)
        if self.training:
            # PyTorch can't backprop through _sample_dirichlet
            action = Dirichlet(concentration).rsample()
        else:
            # ONNX can't export Dirichlet()
            action = torch._sample_dirichlet(concentration)

        if not self.training:
            # ONNX doesn't like reshape either..
            return rlt.ActorOutput(action=action)

        log_prob = Dirichlet(concentration).log_prob(action)
        return rlt.ActorOutput(action=action, log_prob=log_prob.unsqueeze(dim=1))
Example #4
    def forward(self, input):
        concentration = self._get_concentration(input.state)
        # The backwards pass of the Dirichlet distribution is not implemented in
        # PyTorch, so sample using the Gamma construction outlined here:
        # https://en.wikipedia.org/wiki/Dirichlet_distribution#Random_number_generation
        gamma_samples = torch._standard_gamma(concentration)
        action = gamma_samples / torch.sum(gamma_samples, dim=1, keepdim=True)

        if not self.training:
            # ONNX doesn't like reshape either..
            return rlt.ActorOutput(action=action)

        log_prob = self.get_log_prob(input.state, action)
        return rlt.ActorOutput(action=action,
                               log_prob=log_prob.unsqueeze(dim=1))
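
As a rough standalone illustration of the Gamma construction referenced in the comment above (a sketch, not ReAgent's implementation): normalizing independent Gamma(alpha_i, 1) draws yields a Dirichlet(alpha) sample, and using `torch.distributions.Gamma(...).rsample()` instead of `torch._standard_gamma` keeps the draw differentiable with respect to the concentration. The concentration values here are illustrative.

import torch
from torch.distributions import Dirichlet, Gamma

concentration = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)

# Reparameterized Gamma draws: gamma_i ~ Gamma(alpha_i, 1).
gamma_samples = Gamma(concentration, torch.ones_like(concentration)).rsample()
# Normalizing the Gamma draws gives a sample on the simplex, distributed Dirichlet(alpha).
action = gamma_samples / gamma_samples.sum(dim=1, keepdim=True)

# The log-density can still be evaluated with the Dirichlet distribution,
# and gradients flow back to `concentration` through the reparameterized sample.
log_prob = Dirichlet(concentration).log_prob(action)
log_prob.sum().backward()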
Example #5
    def sample_action(
        self,
        scores: torch.Tensor,
        possible_actions_mask: Optional[torch.Tensor] = None
    ) -> rlt.ActorOutput:
        # TODO: temp hack
        scores = scores.unsqueeze(0)
        assert scores.dim() == 2, ("scores dim is %d" % scores.dim()
                                   )  # batch_size x num_actions
        batch_size, num_actions = scores.shape

        if possible_actions_mask is None:
            possible_actions_mask = torch.ones(num_actions)

        argmax = F.one_hot(scores.argmax(dim=1), num_actions).bool()

        p = torch.zeros_like(scores)
        allowed_action_count = float(possible_actions_mask.sum().item())
        mask = torch.repeat_interleave(
            possible_actions_mask.bool().unsqueeze(0), batch_size, axis=0)

        rand_prob = self.epsilon / allowed_action_count
        p[mask] = rand_prob

        greedy_prob = 1 - self.epsilon + rand_prob
        p[argmax] = greedy_prob

        m = torch.distributions.Categorical(probs=p)
        raw_action = m.sample()
        action = F.one_hot(raw_action, num_actions).squeeze(0)
        log_prob = m.log_prob(raw_action).squeeze(0)
        return rlt.ActorOutput(action=action, log_prob=log_prob)
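
A compact sketch of the epsilon-greedy construction used above, with illustrative `scores`, `epsilon`, and mask (and `expand` in place of `repeat_interleave`): every allowed action receives probability `epsilon / allowed_count`, and the argmax action additionally receives the remaining `1 - epsilon`.

import torch
import torch.nn.functional as F

epsilon = 0.1
scores = torch.tensor([[1.0, 3.0, 2.0, 0.5]])        # batch_size x num_actions
possible_actions_mask = torch.tensor([1, 1, 1, 0])   # last action disallowed
batch_size, num_actions = scores.shape

allowed_action_count = float(possible_actions_mask.sum().item())
rand_prob = epsilon / allowed_action_count

p = torch.zeros_like(scores)
mask = possible_actions_mask.bool().unsqueeze(0).expand(batch_size, -1)
p[mask] = rand_prob                                  # uniform exploration mass

argmax = F.one_hot(scores.argmax(dim=1), num_actions).bool()
p[argmax] = 1 - epsilon + rand_prob                  # the greedy action gets the rest

m = torch.distributions.Categorical(probs=p)
raw_action = m.sample()
action = F.one_hot(raw_action, num_actions).squeeze(0)
log_prob = m.log_prob(raw_action).squeeze(0)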
Example #6
    def forward(self, input: rlt.StateInput) -> rlt.ActorOutput:
        """ Forward pass for actor network. Assumes activation names are
        valid pytorch activation names.
        :param input StateInput containing float_features
        """
        if input.state.float_features is None:
            raise NotImplementedError("Not implemented for non-float_features!")

        action = self.network.forward(state=input.state.float_features)
        return rlt.ActorOutput(action=action)
Example #7
    def sample_action(
        self,
        scores: torch.Tensor,
        possible_actions_mask: Optional[torch.Tensor] = None
    ) -> rlt.ActorOutput:
        # TODO: temp hack
        scores = scores.unsqueeze(0)
        assert scores.dim() == 2, ("scores dim is %d" % scores.dim()
                                   )  # batch_size x num_actions
        _, num_actions = scores.shape
        if possible_actions_mask is not None:
            assert scores.shape == possible_actions_mask.shape
            mod_scores = scores.clone().float()
            mod_scores[~possible_actions_mask.bool()] = -float("inf")
        else:
            mod_scores = scores
        raw_action = mod_scores.argmax(dim=1)
        return rlt.ActorOutput(action=F.one_hot(raw_action, num_actions),
                               log_prob=torch.tensor(1.0))
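
A brief usage sketch of the greedy sampler above, with made-up scores: disallowed actions are pushed to `-inf` before the argmax, so the returned one-hot action always respects the mask.

import torch
import torch.nn.functional as F

scores = torch.tensor([5.0, 9.0, 7.0])            # num_actions = 3
possible_actions_mask = torch.tensor([1, 0, 1])   # action 1 (the unmasked argmax) is disallowed

mod_scores = scores.clone().float()
mod_scores[~possible_actions_mask.bool()] = -float("inf")
raw_action = mod_scores.argmax()                  # index 2, the best allowed action
action = F.one_hot(raw_action, 3)                 # tensor([0, 0, 1])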
Example #8
    def sample_action(self, scores: GaussianSamplerScore) -> rlt.ActorOutput:
        self.actor_network.eval()
        action, log_prob = self._sample_action(scores.loc, scores.scale_log)

        # clamp actions to make sure actions are in the range
        clamped_actions = torch.max(
            torch.min(action, self.max_training_action),
            self.min_training_action)
        rescaled_actions = rescale_torch_tensor(
            clamped_actions,
            new_min=self.min_serving_action,
            new_max=self.max_serving_action,
            prev_min=self.min_training_action,
            prev_max=self.max_training_action,
        )
        self.actor_network.train()
        action = rescaled_actions.squeeze(0)
        log_prob = torch.tensor(log_prob.item())
        return rlt.ActorOutput(action=action, log_prob=log_prob)
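
A minimal sketch of the clamp-and-rescale step above, assuming `rescale_torch_tensor` is a straightforward linear map from the training range to the serving range; `rescale_linear` and the range values below are hypothetical stand-ins, not the library helper.

import torch

def rescale_linear(x, new_min, new_max, prev_min, prev_max):
    # Elementwise linear map from [prev_min, prev_max] to [new_min, new_max].
    return (x - prev_min) / (prev_max - prev_min) * (new_max - new_min) + new_min

min_training_action, max_training_action = torch.tensor([-1.0]), torch.tensor([1.0])
min_serving_action, max_serving_action = torch.tensor([0.0]), torch.tensor([10.0])

action = torch.tensor([[1.3]])   # raw sample, slightly outside the training range
clamped = torch.max(torch.min(action, max_training_action), min_training_action)
rescaled = rescale_linear(clamped, min_serving_action, max_serving_action,
                          min_training_action, max_training_action)   # tensor([[10.]])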
Example #9
    def sample_action(
        self,
        scores: torch.Tensor,
        possible_actions_mask: Optional[torch.Tensor] = None
    ) -> rlt.ActorOutput:
        # TODO: temp hack, convert to single instead of batched
        scores = scores.unsqueeze(0)
        assert scores.dim() == 2, ("scores dim is %d" % scores.dim()
                                   )  # batch_size x num_actions
        _, num_actions = scores.shape
        f = F.log_softmax if self.key == "logits" else F.softmax
        if possible_actions_mask is not None:
            assert possible_actions_mask.dim() == 2  # batch_size x num_actions
            mod_scores = f(scores + torch.log(possible_actions_mask))
        else:
            mod_scores = scores
        m = torch.distributions.Categorical(
            **{self.key: mod_scores / self.temperature})
        raw_action = m.sample()
        action = F.one_hot(raw_action, num_actions).squeeze(0)
        log_prob = m.log_prob(raw_action).float().squeeze(0)
        return rlt.ActorOutput(action=action, log_prob=log_prob)
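
A standalone sketch of the temperature-controlled sampler above for the `logits` case, with illustrative scores and mask: masked actions get `log(0) = -inf` added to their logits so they can never be drawn, and dividing by the temperature sharpens (temperature < 1) or flattens (temperature > 1) the distribution.

import torch
import torch.nn.functional as F

temperature = 0.5
scores = torch.tensor([[2.0, 1.0, 0.1, -1.0]])        # batch_size x num_actions
possible_actions_mask = torch.tensor([[1, 1, 0, 1]])

# Masked actions receive -inf logits and therefore zero probability.
logits = F.log_softmax(scores + torch.log(possible_actions_mask.float()), dim=1)

m = torch.distributions.Categorical(logits=logits / temperature)
raw_action = m.sample()
action = F.one_hot(raw_action, scores.shape[1]).squeeze(0)
log_prob = m.log_prob(raw_action).squeeze(0)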
Example #10
    def forward(self, input):
        action = self.fc(input.state.float_features)
        return rlt.ActorOutput(action=action)