def forward(self, input):
    loc, scale_log = self._get_loc_and_scale_log(input.state)
    r = torch.randn_like(scale_log, device=scale_log.device)
    action = torch.tanh(loc + r * scale_log.exp())
    if not self.training:
        # ONNX doesn't like reshape either..
        return rlt.ActorOutput(action=action)
    # Since the dimensions are independent, the log-prob is simply the sum
    log_prob = self._log_prob(r, scale_log)
    squash_correction = self._squash_correction(action)
    if SummaryWriterContext._global_step % 1000 == 0:
        SummaryWriterContext.add_histogram("actor/forward/loc", loc.detach().cpu())
        SummaryWriterContext.add_histogram(
            "actor/forward/scale_log", scale_log.detach().cpu()
        )
        SummaryWriterContext.add_histogram(
            "actor/forward/log_prob", log_prob.detach().cpu()
        )
        SummaryWriterContext.add_histogram(
            "actor/forward/squash_correction", squash_correction.detach().cpu()
        )
    log_prob = torch.sum(log_prob - squash_correction, dim=1)
    return rlt.ActorOutput(
        action=action, log_prob=log_prob.reshape(-1, 1), action_mean=loc
    )
def forward(self, input):
    loc, scale_log = self._get_loc_and_scale_log(input.state)
    r = torch.randn_like(scale_log, device=scale_log.device)
    action = torch.tanh(loc + r * scale_log.exp())
    if not self.training:
        # ONNX doesn't like reshape either..
        return rlt.ActorOutput(action=action)
    # Since the dimensions are independent, the log-prob is simply the sum
    log_prob = torch.sum(
        self._log_prob(r, scale_log) - self._squash_correction(action), dim=1
    )
    return rlt.ActorOutput(action=action, log_prob=log_prob.reshape(-1, 1))
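# Hedged sketch (an assumption, not this module's actual helpers) of what
# _log_prob and _squash_correction are taken to compute for the tanh-squashed
# Gaussian policy above, following the SAC change-of-variables correction:
# log pi(a|s) = log N(u; loc, scale) - log(1 - tanh(u)^2), where
# u = loc + r * exp(scale_log) and a = tanh(u).
import math

import torch


def _log_prob_sketch(r: torch.Tensor, scale_log: torch.Tensor) -> torch.Tensor:
    # Normal log-density at u = loc + r * scale, written in terms of the
    # standardized noise r = (u - loc) / scale and the log of the scale.
    return -0.5 * r ** 2 - scale_log - 0.5 * math.log(2 * math.pi)


def _squash_correction_sketch(action: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # log |d tanh(u) / du| = log(1 - tanh(u)^2) = log(1 - action^2);
    # eps keeps the log finite when the action saturates at +/-1.
    return torch.log(1 - action ** 2 + eps)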
def forward(self, input):
    concentration = self._get_concentration(input.state)
    if self.training:
        # PyTorch can't backprop through _sample_dirichlet, so use the
        # reparameterized sample for a differentiable path.
        action = Dirichlet(concentration).rsample()
    else:
        # ONNX can't export Dirichlet()
        action = torch._sample_dirichlet(concentration)
    if not self.training:
        # ONNX doesn't like reshape either..
        return rlt.ActorOutput(action=action)
    log_prob = Dirichlet(concentration).log_prob(action)
    return rlt.ActorOutput(action=action, log_prob=log_prob.unsqueeze(dim=1))
def forward(self, input):
    concentration = self._get_concentration(input.state)
    # The backward pass of the Dirichlet distribution is not implemented in
    # PyTorch, so sample via independent Gamma draws as outlined here:
    # https://en.wikipedia.org/wiki/Dirichlet_distribution#Random_number_generation
    gamma_samples = torch._standard_gamma(concentration)
    action = gamma_samples / torch.sum(gamma_samples, dim=1, keepdim=True)
    if not self.training:
        # ONNX doesn't like reshape either..
        return rlt.ActorOutput(action=action)
    log_prob = self.get_log_prob(input.state, action)
    return rlt.ActorOutput(action=action, log_prob=log_prob.unsqueeze(dim=1))
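# Hedged usage sketch: the Gamma trick above works because normalizing
# independent Gamma(alpha_i, 1) draws yields a Dirichlet(alpha) sample (see
# the Wikipedia section linked in the comment). A quick consistency check
# against torch.distributions; the concentration values are illustrative only.
import torch
from torch.distributions import Dirichlet

concentration = torch.tensor([[1.0, 2.0, 3.0]])
gamma_samples = torch._standard_gamma(concentration)
action = gamma_samples / torch.sum(gamma_samples, dim=1, keepdim=True)
assert torch.allclose(action.sum(dim=1), torch.tensor(1.0))  # on the simplex
log_prob = Dirichlet(concentration).log_prob(action)  # finite log-density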
def sample_action(
    self, scores: torch.Tensor, possible_actions_mask: Optional[torch.Tensor] = None
) -> rlt.ActorOutput:
    # TODO: temp hack
    scores = scores.unsqueeze(0)
    assert scores.dim() == 2, (
        "scores dim is %d" % scores.dim()
    )  # batch_size x num_actions
    batch_size, num_actions = scores.shape
    if possible_actions_mask is None:
        possible_actions_mask = torch.ones(num_actions)
    argmax = F.one_hot(scores.argmax(dim=1), num_actions).bool()
    p = torch.zeros_like(scores)
    allowed_action_count = float(possible_actions_mask.sum().item())
    mask = torch.repeat_interleave(
        possible_actions_mask.bool().unsqueeze(0), batch_size, dim=0
    )
    rand_prob = self.epsilon / allowed_action_count
    p[mask] = rand_prob
    greedy_prob = 1 - self.epsilon + rand_prob
    p[argmax] = greedy_prob
    m = torch.distributions.Categorical(probs=p)
    raw_action = m.sample()
    action = F.one_hot(raw_action, num_actions).squeeze(0)
    log_prob = m.log_prob(raw_action).squeeze(0)
    return rlt.ActorOutput(action=action, log_prob=log_prob)
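# Hedged numeric example of the epsilon-greedy distribution built above
# (assumes the argmax is itself an allowed action). With epsilon = 0.1 and
# 4 allowed actions, each allowed action gets 0.1 / 4 = 0.025 and the greedy
# action gets 1 - 0.1 + 0.025 = 0.925, so the mass sums to
# 3 * 0.025 + 0.925 = 1.0.
import torch

epsilon = 0.1
p = torch.full((4,), epsilon / 4)  # exploration mass on allowed actions
p[2] = 1 - epsilon + epsilon / 4  # greedy action (illustrative index)
assert torch.isclose(p.sum(), torch.tensor(1.0))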
def forward(self, input: rlt.StateInput) -> rlt.ActorOutput:
    """Forward pass for the actor network. Assumes activation names are
    valid PyTorch activation names.

    :param input: StateInput containing float_features
    """
    if input.state.float_features is None:
        raise NotImplementedError("Not implemented for non-float_features!")
    action = self.network.forward(state=input.state.float_features)
    return rlt.ActorOutput(action=action)
def sample_action(
    self, scores: torch.Tensor, possible_actions_mask: Optional[torch.Tensor] = None
) -> rlt.ActorOutput:
    # TODO: temp hack
    scores = scores.unsqueeze(0)
    assert scores.dim() == 2, (
        "scores dim is %d" % scores.dim()
    )  # batch_size x num_actions
    _, num_actions = scores.shape
    if possible_actions_mask is not None:
        assert scores.shape == possible_actions_mask.shape
        mod_scores = scores.clone().float()
        mod_scores[~possible_actions_mask.bool()] = -float("inf")
    else:
        mod_scores = scores
    raw_action = mod_scores.argmax(dim=1)
    # The greedy action is chosen with probability 1, so its log-prob is 0
    return rlt.ActorOutput(
        action=F.one_hot(raw_action, num_actions), log_prob=torch.tensor(0.0)
    )
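# Hedged usage sketch for the greedy sampler above; the tensor values are
# illustrative only. Disallowed actions are set to -inf before the argmax,
# so the highest raw score loses when its mask bit is 0.
import torch

scores = torch.tensor([3.0, 2.0, 1.0])
mask = torch.tensor([[0, 1, 1]])  # action 0 is disallowed
mod_scores = scores.unsqueeze(0).clone().float()
mod_scores[~mask.bool()] = -float("inf")  # becomes [[-inf, 2.0, 1.0]]
assert mod_scores.argmax(dim=1).item() == 1  # one-hot action is [[0, 1, 0]]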
def sample_action(self, scores: GaussianSamplerScore) -> rlt.ActorOutput:
    self.actor_network.eval()
    action, log_prob = self._sample_action(scores.loc, scores.scale_log)
    # Clamp actions to make sure they stay within the training range
    clamped_actions = torch.max(
        torch.min(action, self.max_training_action), self.min_training_action
    )
    rescaled_actions = rescale_torch_tensor(
        clamped_actions,
        new_min=self.min_serving_action,
        new_max=self.max_serving_action,
        prev_min=self.min_training_action,
        prev_max=self.max_training_action,
    )
    self.actor_network.train()
    action = rescaled_actions.squeeze(0)
    log_prob = torch.tensor(log_prob.item())
    return rlt.ActorOutput(action=action, log_prob=log_prob)
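# Hedged sketch of the linear rescaling rescale_torch_tensor is assumed to
# perform (the real helper lives elsewhere in this codebase): map each
# component from its training range onto the corresponding serving range.
import torch


def rescale_torch_tensor_sketch(
    t: torch.Tensor,
    new_min: torch.Tensor,
    new_max: torch.Tensor,
    prev_min: torch.Tensor,
    prev_max: torch.Tensor,
) -> torch.Tensor:
    # Shift to the origin, scale by the ratio of ranges, shift to new_min
    return (t - prev_min) / (prev_max - prev_min) * (new_max - new_min) + new_min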
def sample_action(
    self, scores: torch.Tensor, possible_actions_mask: Optional[torch.Tensor] = None
) -> rlt.ActorOutput:
    # TODO: temp hack, convert to single instead of batched
    scores = scores.unsqueeze(0)
    assert scores.dim() == 2, (
        "scores dim is %d" % scores.dim()
    )  # batch_size x num_actions
    _, num_actions = scores.shape
    f = F.log_softmax if self.key == "logits" else F.softmax
    if possible_actions_mask is not None:
        assert possible_actions_mask.dim() == 2  # batch_size x num_actions
        # Adding log(mask) sends disallowed scores to -inf before normalizing
        mod_scores = f(scores + torch.log(possible_actions_mask), dim=1)
    else:
        mod_scores = scores
    m = torch.distributions.Categorical(**{self.key: mod_scores / self.temperature})
    raw_action = m.sample()
    action = F.one_hot(raw_action, num_actions).squeeze(0)
    log_prob = m.log_prob(raw_action).float().squeeze(0)
    return rlt.ActorOutput(action=action, log_prob=log_prob)
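# Hedged numeric sketch of the masking trick above: adding log(mask) sends
# disallowed scores to -inf, so softmax assigns them probability zero. The
# score values are illustrative only.
import torch
import torch.nn.functional as F

scores = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1.0, 0.0, 1.0]])
probs = F.softmax(scores + torch.log(mask), dim=1)
assert probs[0, 1] == 0  # masked action gets no probability mass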
def forward(self, input):
    action = self.fc(input.state.float_features)
    return rlt.ActorOutput(action=action)