def gumbel_sample(self, data):
    """Pick one mixture per row with the Gumbel-max trick.

    Independent Gumbel(0, 1) noise is added to every entry of ``log(data)``
    and the index of the largest perturbed value along ``dim=1`` is returned.

    :param data: tensor of mixture weights; the choice is made along dim 1
    :return: tensor of selected indices (``data`` with dim 1 reduced away)
    """
    log_weights = torch.log(data)
    distribution = Gumbel(loc=0, scale=1)
    # One noise draw per entry is required: a single shared scalar shifts
    # every candidate by the same amount and never changes the argmax.
    z = distribution.sample(data.shape).to(log_weights)
    return (log_weights + z).argmax(dim=1)
class SoftmaxRandomSamplePolicy(SimplePolicy):
    '''
    Randomly samples from the softmax of the logits using the Gumbel-max trick
    # https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
    TODO: should probably switch that to something more like
    http://pytorch.org/docs/master/distributions.html
    '''

    def __init__(self, bias=None):
        '''
        :param bias: a vector of log frequencies, to bias the sampling towards what we want
        '''
        super().__init__()
        self.gumbel = Gumbel(loc=0, scale=1)
        if bias is not None:
            # Stored with a leading axis of size 1 so it broadcasts over the batch.
            self.bias = torch.from_numpy(np.array(bias)).to(
                dtype=torch.float32, device=device).unsqueeze(0)
        else:
            self.bias = None

    def forward(self, logits: Variable):
        '''
        :param logits: Logits to generate probabilities from, batch_size x out_dim float32
        :return: index of the sampled entry along the last dimension, one per batch row
        '''
        eff_logits = self.effective_logits(logits)
        # Draw the noise on CPU, then match the logits' device and dtype.
        noise = self.gumbel.sample(logits.shape).to(
            device=eff_logits.device, dtype=eff_logits.dtype)
        _, out = torch.max(noise + eff_logits, -1)
        return out

    def effective_logits(self, logits):
        '''
        :param logits: raw logits
        :return: logits with the bias added (the input itself when no bias is set)
        '''
        if self.bias is not None:
            # Out-of-place add: the caller's tensor must not be modified,
            # otherwise repeated calls would apply the bias more than once.
            logits = logits + self.bias
        return logits
class FrechetOrderSampler:
    def __init__(self, shape: float = 1, upto: Optional[int] = None):
        """Samples an ordering

        Args:
            shape: 1/scale parameter of the Gumbel noise. The noise scale is
                ``1 / shape``, so larger values keep the order closer to a
                plain descending sort of the scores.
            upto [Default: None]: If provided, position upto which
                log_prob is computed
        """
        self.shape = shape
        self.upto = upto
        self.gumbel_noise = Gumbel(0, 1.0 / shape)

    def sample(self, scores: torch.Tensor):
        """Sample an ordering given scores

        Args:
            scores: positive scores; the ordering is taken along the last
                dimension.

        Returns:
            Indices that sort the Gumbel-perturbed log-scores in descending
            order.
        """
        log_scores = torch.log(scores)
        # One independent noise draw per score. For 1-D input this matches
        # the shape (len(scores),); for batched input it avoids a broadcast
        # failure.
        noise = self.gumbel_noise.sample(scores.shape).to(log_scores)
        perturbed = log_scores + noise
        return torch.argsort(-perturbed.detach())

    def log_prob(self, scores: torch.Tensor, permutations):
        """Compute log probability given scores and an action (a permutation).

        The formula uses the equivalence of sorting with Gumbel noise and
        Plackett-Luce model (See Yellot 1977)

        Args:
            scores: scores of different items
            permutations: prescribed (permutation) order of the items
        """
        s = torch.log(select_indices(scores, permutations))
        n = len(scores)
        p = self.upto if self.upto is not None else n - 1
        # logsumexp is the overflow-safe form of log(sum(exp(.))).
        return -sum(
            torch.logsumexp((s[k:] - s[k]) * self.shape, dim=0)
            for k in range(p)
        )
class SoftmaxRandomSamplePolicySparse(SimplePolicy):
    '''
    Randomly samples from the softmax of the logits, one item at a time
    # https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
    TODO: should probably switch that to something more like
    http://pytorch.org/docs/master/distributions.html
    '''

    def __init__(self, do_entropy=False):
        '''
        :param do_entropy: when true, an ``entropy`` attribute is created (initialised to None)
        '''
        super().__init__()
        self.gumbel = Gumbel(loc=0, scale=1)
        self.do_entropy = do_entropy
        if self.do_entropy:
            self.entropy = None

    def forward(self, all_logits_list, all_action_inds_list):
        '''
        :param all_logits_list: list of tensors of len batch_size
        :param all_action_inds_list: list of long tensors with indices of the
            actions corresponding to the logits
        :return: stacked tensor with the chosen action index of every item;
            the log-probabilities of the choices are stored in ``self.logp``
        '''
        # The number of feasible actions is different for every item in the
        # batch, so a for loop is the simplest
        logp = []
        out_actions = []
        for logits, action_inds in zip(all_logits_list, all_action_inds_list):
            if len(logits):
                # Keep the noise on the same device/dtype as this item's logits.
                noise = self.gumbel.sample(logits.shape).to(
                    device=logits.device, dtype=logits.dtype)
                _, out = torch.max(noise + logits, -1)
                # Explicit dim: the implicit-dim form of log_softmax is deprecated.
                this_logp = F.log_softmax(logits, dim=-1)[out]
                out_actions.append(action_inds[out])
                logp.append(this_logp)
            else:
                # No feasible action: emit action 0 with log-probability 0.
                out_actions.append(
                    torch.tensor(0, device=logits.device, dtype=action_inds.dtype))
                logp.append(
                    torch.tensor(0.0, device=logits.device, dtype=logits.dtype))
        self.logp = torch.stack(logp)
        out = torch.stack(out_actions)
        return out
def beam_search(self, init, steps, beam_size, temperature=1.0,
                stochastic=False, verbose=False):
    """Generate event sequences with (optionally Gumbel-perturbed) beam search.

    :param init: 2-D tensor ``[batch, self.init_dim]`` fed to ``init_to_hidden``
    :param steps: number of generation steps, must be > 0
    :param beam_size: beams kept per batch item, ``0 < beam_size <= self.event_dim``
    :param temperature: the logits are divided by this value
    :param stochastic: when true, Gumbel(0, 1) noise is added to the candidate
        scores before the top-k selection; the accumulated beam score itself
        is still gathered from the unperturbed values
    :param verbose: when true, the step loop is wrapped in a ``Bar`` progress bar
    :return: for every batch item the event sequence of the beam with the
        highest accumulated score, transposed so the sequence axis comes first
    """
    assert len(init.shape) == 2 and init.shape[1] == self.init_dim
    assert self.event_dim >= beam_size > 0 and steps > 0

    batch_size = init.shape[0]
    current_beam_size = beam_size

    # Initial hidden weights
    hidden = self.init_to_hidden(init)  # [gru_layers, batch, hidden]
    hidden = hidden[:, :, None, :]  # [gru_layers, batch, 1, hidden]
    hidden = hidden.repeat(1, 1, current_beam_size, 1)  # [gru_layers, batch, beam, hidden]

    # Initial event
    event = self.get_primary_event(batch_size)  # [1, batch]
    event = event[:, :, None].repeat(1, 1, current_beam_size)  # [1, batch, beam]

    # Event sequences of the beams; grows by one entry along the last axis per step.
    beam_events = event[0, :, None, :].repeat(1, current_beam_size, 1)

    # [batch, beam] accumulated score of each beam
    beam_log_prob = torch.zeros(batch_size, current_beam_size).to(device)

    if stochastic:
        gumbel_dist = Gumbel(0, 1)

    step_iter = range(steps)
    if verbose:
        step_iter = Bar(['', 'Stochastic '][stochastic] + 'Beam Search').iter(step_iter)

    for step in step_iter:
        event = event.view(1, batch_size * current_beam_size)  # [1, batch*beam]
        hidden = hidden.view(self.rnn_layers,
                             batch_size * current_beam_size,
                             self.hidden_dim)  # [grus, batch*beam, hid]
        logits, hidden = self.gen_forward(event, hidden)
        hidden = hidden.view(self.rnn_layers, batch_size,
                             current_beam_size, self.hidden_dim)  # [grus, batch, cbeam, hid]
        logits = (logits / temperature).view(
            1, batch_size, current_beam_size, self.event_dim)  # [1, batch, cbeam, out]

        beam_log_prob_expand = logits + beam_log_prob[None, :, :, None]  # [1, batch, cbeam, out]
        beam_log_prob_expand_batch = beam_log_prob_expand.view(
            1, batch_size, -1)  # [1, batch, cbeam*out]

        if stochastic:
            # The noise is sampled on CPU; match the scores' device and dtype
            # before adding it.
            noise = gumbel_dist.sample(beam_log_prob_expand.shape).to(beam_log_prob_expand)
            perturbed_batch = (beam_log_prob_expand + noise).view(
                1, batch_size, -1)  # [1, batch, cbeam*out]
            # The perturbed scores are ranked as-is; no renormalisation
            # against the parent beam's perturbed score is applied.
            _, top_indices = perturbed_batch.topk(beam_size, -1)  # [1, batch, beam]
        else:
            _, top_indices = beam_log_prob_expand_batch.topk(beam_size, -1)  # [1, batch, beam]

        beam_log_prob = torch.gather(beam_log_prob_expand_batch, -1,
                                     top_indices)[0]  # [batch, beam]

        # Map each flat (beam, event) candidate index back to its source beam.
        # NOTE(review): self.output_dim is used here while the logits are
        # viewed with self.event_dim; presumably the two are equal - confirm.
        beam_index_old = torch.arange(
            current_beam_size, device=device)[None, None, :, None]  # [1, 1, cbeam, 1]
        beam_index_old = beam_index_old.repeat(
            1, batch_size, 1, self.output_dim)  # [1, batch, cbeam, out]
        beam_index_old = beam_index_old.view(1, batch_size, -1)  # [1, batch, cbeam*out]
        beam_index_new = torch.gather(beam_index_old, -1, top_indices)  # [1, batch, beam]

        # Reorder hidden states to follow the surviving beams. The index is
        # sized from the model's own dimensions instead of fixed constants.
        hidden = torch.gather(
            hidden, 2,
            beam_index_new[:, :, :, None].repeat(
                self.rnn_layers, 1, 1, self.hidden_dim))

        # Map each flat candidate index to the event it emits.
        event_index = torch.arange(
            self.output_dim, device=device)[None, None, None, :]  # [1, 1, 1, out]
        event_index = event_index.repeat(
            1, batch_size, current_beam_size, 1)  # [1, batch, cbeam, out]
        event_index = event_index.view(1, batch_size, -1)  # [1, batch, cbeam*out]
        event = torch.gather(event_index, -1, top_indices)  # [1, batch, beam]

        # Carry over the history of the surviving beams and append the new event.
        beam_events = torch.gather(
            beam_events[None], 2,
            beam_index_new.unsqueeze(-1).repeat(1, 1, 1, beam_events.shape[-1]))
        beam_events = torch.cat([beam_events, event.unsqueeze(-1)], -1)[0]

        current_beam_size = beam_size

    best = beam_events[torch.arange(batch_size, device=beam_events.device),
                       beam_log_prob.argmax(-1)]
    best = best.contiguous().t()
    return best
class GumbelSoftmaxEmbeddingHelper(SoftmaxEmbeddingHelper):
    r"""Decoding helper that feeds a Gumbel-softmax sample to the next step.

    Behaves like :class:`~texar.modules.SoftmaxEmbeddingHelper`, except that
    Gumbel noise is added to the decoder outputs before the tempered softmax,
    so the vector passed through the word embeddings is a Gumbel-softmax
    sample (a mixed word embedding) rather than a plain softmax.

    A subclass of :class:`~texar.modules.Helper`, used with
    :class:`~texar.modules.RNNDecoderBase` in inference mode.

    Args:
        start_tokens: 1D :tensor:`LongTensor` shaped ``[batch_size]``,
            the start token of every sequence in the batch.
        end_token: Python int or scalar :tensor:`LongTensor`, the token
            that marks the end of decoding.
        tau: A float scalar, the softmax temperature.
        straight_through (bool): If `True`, a one-hot vector of the most
            probable token is fed to the next step and the gradient is
            computed with the straight-through estimator. If `False`
            (default), the soft Gumbel-softmax distribution is fed instead.
        stop_gradient (bool): Whether to stop gradient backpropagation when
            feeding the softmax vector to the next step.
        use_finish (bool): Whether to stop decoding once :attr:`end_token`
            is generated. If `False`, decoding continues until the decoder's
            :attr:`max_decoding_length` is reached.

    All arguments except :attr:`straight_through` are forwarded unchanged to
    the parent constructor.
    """

    def __init__(self, start_tokens: torch.LongTensor,
                 end_token: Union[int, torch.LongTensor], tau: float,
                 straight_through: bool = False,
                 stop_gradient: bool = False, use_finish: bool = True):
        super().__init__(start_tokens, end_token, tau,
                         stop_gradient, use_finish)
        # Standard Gumbel distribution: location 0, scale 1.
        zero, one = torch.tensor(0.0), torch.tensor(1.0)
        self._gumbel = Gumbel(loc=zero, scale=one)
        self._straight_through = straight_through

    def sample(self, time: int, outputs: torch.Tensor) -> torch.Tensor:
        r"""Return ``sample_id`` of shape ``[batch_size, vocab_size]``.

        With :attr:`straight_through` disabled the result is the
        Gumbel-softmax distribution over the vocabulary at temperature
        :attr:`tau`. With it enabled the result holds one-hot vectors of the
        arg-max tokens, while gradients flow through the soft distribution.
        """
        noise = self._gumbel.sample(outputs.size())
        noise = noise.to(device=outputs.device, dtype=outputs.dtype)
        soft = ((outputs + noise) / self._tau).softmax(dim=-1)
        if not self._straight_through:
            return soft

        top_ids = soft.argmax(dim=-1).unsqueeze(1)
        # One-hot encoding of the arg-max token per row.
        hard = torch.zeros_like(soft).scatter_(
            dim=-1, index=top_ids, value=1.0)
        # Forward value equals `hard`; the gradient is that of `soft`.
        return (hard - soft).detach() + soft
class FrechetSort(Sampler):
    EPS = 1e-12

    @resolve_defaults
    def __init__(
        self,
        shape: float = 1.0,
        topk: Optional[int] = None,
        equiv_len: Optional[int] = None,
        log_scores: bool = False,
    ):
        """FréchetSort is a softer version of descending sort which samples all
        possible orderings of items favoring orderings which resemble
        descending sort. This can be used to convert descending sort by rank
        score into a differentiable, stochastic policy amenable to policy
        gradient algorithms.

        :param shape: parameter of Frechet Distribution. Lower values
            correspond to aggressive deviations from descending sort.
        :param topk: If specified, only the first topk actions are returned.
        :param equiv_len: Orders are considered equivalent if the top
            equiv_len match. Used in probability computations.
            Essentially specifies the action space.
        :param log_scores: Scores passed in are already log-transformed. In
            this case, we would simply add Gumbel noise.
            For LearnVM, we set this to be True because we expect input and
            output scores to be in the log space.
        :raises ValueError: if the effective equiv_len exceeds topk.

        Example:

        Consider the sampler:

        sampler = FrechetSort(shape=3, topk=5, equiv_len=3)

        Given a set of scores, this sampler will produce indices of items
        roughly resembling a argsort by scores in descending order. The higher
        the shape, the more it would resemble a descending argsort. `topk=5`
        means only the top 5 ranks will be output. The `equiv_len` determines
        what orders are considered equivalent for probability computation. In
        this example, the sampler will produce probability for the top 3 items
        appearing in a given order for the `log_prob` call.
        """
        self.shape = shape
        self.topk = topk
        self.upto = equiv_len
        if topk is not None:
            if equiv_len is None:
                self.upto = topk
            # pyre-fixme[58]: `>` is not supported for operand types `Optional[int]`
            #  and `Optional[int]`.
            if self.upto > self.topk:
                raise ValueError(
                    f"Equiv length {equiv_len} cannot exceed topk={topk}.")
        self.gumbel_noise = Gumbel(0, 1.0 / shape)
        self.log_scores = log_scores

    def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
        """Sample a ranking according to Frechet sort. Note that
        possible_actions_mask is ignored as the list of rankings scales
        exponentially with slate size and number of items and it can be
        difficult to enumerate them.

        :param scores: 2-D tensor of item scores, one row per batch entry
        :return: ActorOutput holding the sampled ranking (cut to ``topk``
            ranks when set) and its log-probability
        """
        assert scores.dim() == 2, "sample_action only accepts batches"
        log_scores = scores if self.log_scores else torch.log(scores)
        noise = self.gumbel_noise.sample(scores.shape).to(log_scores.device)
        perturbed = log_scores + noise
        action = torch.argsort(perturbed.detach(), descending=True)
        # The probability is computed on the full permutation ...
        log_prob = self.log_prob(scores, action)
        # ... and only then is the action truncated. The ranks live on
        # dim 1; slicing dim 0 would drop batch rows instead of ranks.
        if self.topk is not None:
            action = action[:, :self.topk]
        return rlt.ActorOutput(action, log_prob)

    def log_prob(
        self,
        scores: torch.Tensor,
        action: torch.Tensor,
        equiv_len_override: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        What is the probability of a given set of scores producing the given
        list of permutations only considering the top `equiv_len` ranks?

        We may want to override the default equiv_len here when we know the
        having larger action space doesn't matter. i.e. in Reels

        :param scores: item scores, 1-D or ``[batch, n]``
        :param action: permutation(s) with as many entries per row as there
            are scores; index ``n`` refers to a padding item
        :param equiv_len_override: optional per-row equiv_len, shape ``[batch]``
        :return: log-probability per batch row. A 1-D input is treated as a
            batch of one and the result keeps that batch dimension.
        :raises ValueError: if the override exceeds topk or the action has
            more entries than there are scores
        :raises NotImplementedError: if the action has fewer entries than
            there are scores
        """
        upto = self.upto
        if equiv_len_override is not None:
            assert equiv_len_override.shape == (
                scores.shape[0],
            ), f"Invalid shape {equiv_len_override.shape}, compared to scores {scores.shape}. equiv_len_override {equiv_len_override}"
            upto = equiv_len_override.long()
            if self.topk is not None and torch.any(
                    equiv_len_override > self.topk):
                raise ValueError(
                    f"Override {equiv_len_override} cannot exceed topk={self.topk}."
                )

        if len(scores.shape) == 1:
            scores = scores.unsqueeze(0)
            action = action.unsqueeze(0)

        assert len(action.shape) == len(
            scores.shape) == 2, "scores should be batch"
        if action.shape[1] > scores.shape[1]:
            raise ValueError(
                f"action cardinality ({action.shape[1]}) is larger than the number of scores ({scores.shape[1]})"
            )
        elif action.shape[1] < scores.shape[1]:
            raise NotImplementedError(
                f"This semantic is ambiguous. If you have shorter slate, pad it with scores.shape[1] ({scores.shape[1]})"
            )

        log_scores = scores if self.log_scores else torch.log(scores)
        n = log_scores.shape[-1]
        # Add scores for the padding value
        log_scores = torch.cat(
            [
                log_scores,
                torch.full((log_scores.shape[0], 1), -math.inf,
                           device=log_scores.device),
            ],
            dim=1,
        )
        s = torch.gather(log_scores, 1, action) * self.shape

        p = upto if upto is not None else n
        # Each term is the log-softmax weight of the item at rank i among
        # the items not yet placed; -inf/NaN terms from padding count as 0.
        if isinstance(p, int):
            probs = sum(
                torch.nan_to_num(F.log_softmax(s[:, i:], dim=1)[:, 0],
                                 neginf=0.0)
                for i in range(p))
        elif isinstance(p, torch.Tensor):
            # do masked sum
            probs = sum(
                torch.nan_to_num(F.log_softmax(s[:, i:], dim=1)[:, 0],
                                 neginf=0.0) * (i < p).float()
                for i in range(n))
        else:
            raise RuntimeError(f"p is {p}")
        return probs
class FrechetSort(Sampler):
    @resolve_defaults
    def __init__(
        self,
        shape: float = 1.0,
        topk: Optional[int] = None,
        equiv_len: Optional[int] = None,
        log_scores: bool = False,
    ):
        """FréchetSort is a softer version of descending sort which samples all
        possible orderings of items favoring orderings which resemble
        descending sort. This can be used to convert descending sort by rank
        score into a differentiable, stochastic policy amenable to policy
        gradient algorithms.

        :param shape: parameter of Frechet Distribution. Lower values
            correspond to aggressive deviations from descending sort.
        :param topk: If specified, only the first topk actions are returned.
        :param equiv_len: Orders are considered equivalent if the top
            equiv_len match. Used in probability computations
        :param log_scores: Scores passed in are already log-transformed. In
            this case, we would simply add Gumbel noise.
        :raises ValueError: if the effective equiv_len exceeds topk.

        Example:

        Consider the sampler:

        sampler = FrechetSort(shape=3, topk=5, equiv_len=3)

        Given a set of scores, this sampler will produce indices of items
        roughly resembling a argsort by scores in descending order. The higher
        the shape, the more it would resemble a descending argsort. `topk=5`
        means only the top 5 ranks will be output. The `equiv_len` determines
        what orders are considered equivalent for probability computation. In
        this example, the sampler will produce probability for the top 3 items
        appearing in a given order for the `log_prob` call.
        """
        self.shape = shape
        self.topk = topk
        self.upto = equiv_len
        if topk is not None:
            if equiv_len is None:
                self.upto = topk
            # pyre-fixme[58]: `>` is not supported for operand types `Optional[int]`
            #  and `Optional[int]`.
            if self.upto > self.topk:
                raise ValueError(f"Equiv length {equiv_len} cannot exceed topk={topk}.")
        self.gumbel_noise = Gumbel(0, 1.0 / shape)
        self.log_scores = log_scores

    @staticmethod
    def select_indices(scores: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Helper for scores[actions] that also works for batched tensors.

        For batched ``actions`` the result is transposed, so the rank axis
        comes first and the batch axis second.
        """
        if len(actions.shape) > 1:
            num_rows = scores.size(0)
            row_indices = torch.arange(
                num_rows, device=scores.device
            ).unsqueeze(0).T  # pyre-ignore[ 16 ]
            return scores[row_indices, actions].T
        else:
            return scores[actions]

    def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
        """Sample a ranking according to Frechet sort. Note that
        possible_actions_mask is ignored as the list of rankings scales
        exponentially with slate size and number of items and it can be
        difficult to enumerate them.

        :param scores: 2-D tensor of item scores, one row per batch entry
        :return: ActorOutput holding the sampled ranking (cut to ``topk``
            ranks when set) and its log-probability
        """
        assert scores.dim() == 2, "sample_action only accepts batches"
        log_scores = scores if self.log_scores else torch.log(scores)
        # Independent noise for every (row, item) pair; a single vector
        # broadcast over the batch would correlate all rows.
        noise = self.gumbel_noise.sample(scores.shape).to(log_scores.device)
        perturbed = log_scores + noise
        action = torch.argsort(perturbed.detach(), descending=True)
        # The probability needs the full permutation: the normaliser at each
        # rank sums over every item that has not been placed yet.
        log_prob = self.log_prob(scores, action)
        if self.topk is not None:
            # Ranks live on dim 1; slicing dim 0 would drop batch rows.
            action = action[:, : self.topk]
        return rlt.ActorOutput(action, log_prob)

    def log_prob(self, scores: torch.Tensor, action) -> torch.Tensor:
        """What is the probability of a given set of scores producing the given
        list of permutations only considering the top `equiv_len` ranks?

        :param scores: item scores, 1-D or ``[batch, n]``
        :param action: permutation(s) of item indices matching ``scores``
        :return: log-probability (one value per batch row for batched input)
        """
        log_scores = scores if self.log_scores else torch.log(scores)
        s = self.select_indices(log_scores, action)
        # `s` has the rank axis first, so its length is the number of ranked
        # items for both 1-D and batched input.
        n = s.shape[0]
        p = min(self.upto, n) if self.upto is not None else n
        # logsumexp is the overflow-safe form of log(sum(exp(.))).
        return -sum(
            torch.logsumexp((s[k:] - s[k]) * self.shape, dim=0)
            for k in range(p)  # pyre-ignore
        )
def forward(self, inputs, hidden_state):
    """Compute Q-values whose output layer is generated from a sampled latent.

    :param inputs: reshaped to ``(-1, self.input_shape)``; the last
        ``self.n_agents`` columns are fed to ``embed_fc`` and stripped from
        the inference-network input
    :param hidden_state: recurrent state, reshaped to ``(-1, self.hidden_dim)``
    :return: tuple ``(q, h, loss)``: ``q`` viewed as ``(-1, n_actions)``,
        ``h`` viewed as ``(-1, rnn_hidden_dim)``, and the scalar
        cross-entropy between the inferred and the embedded latent
        distributions

    Side effect: ``self.latent`` is set to the per-agent categorical latent
    distribution.
    """
    inputs = inputs.reshape(-1, self.input_shape)
    h_in = hidden_state.reshape(-1, self.hidden_dim)

    # Latent distribution per agent, computed from the first n_agents rows
    # and the trailing n_agents input columns.
    self.latent = F.softmax(
        self.embed_fc(inputs[:self.n_agents, -self.n_agents:]), dim=1)  # (n, pi)
    latent_embed = self.latent.unsqueeze(0).expand(
        self.bs, self.n_agents, self.latent_dim).reshape(
        self.bs * self.n_agents, self.latent_dim)  # (bs*n, pi)

    # Inferred latent distribution from the detached hidden state and the
    # inputs without their trailing n_agents columns.
    latent_infer = F.relu(self.inference_fc1(
        th.cat([h_in.detach(), inputs[:, :-self.n_agents]], dim=1)))
    latent_infer = F.softmax(self.inference_fc2(latent_infer), dim=1)  # (bs*n, pi)

    # Cross-entropy of the embedded distribution under the inferred one,
    # averaged over bs * n_agents. Must be computed before latent_embed is
    # replaced by its noisy sample below.
    loss = -(latent_infer * th.log(latent_embed + self.eps)).sum() / (
        self.bs * self.n_agents)

    # Gumbel-softmax sample of the embedded latent. The noise follows the
    # tensor's own device so the module also runs without CUDA.
    g = Gumbel(0.0, 1.0)
    noise = g.sample(latent_embed.size()).to(
        device=latent_embed.device, dtype=latent_embed.dtype)
    latent_embed = F.softmax(th.log(latent_embed + self.eps) + noise, dim=1)

    latent = self.latent_fc1(latent_embed)

    # Weights and bias of the output layer are generated from the latent.
    fc2_w = self.fc2_w_nn(latent)
    fc2_b = self.fc2_b_nn(latent)
    fc2_w = fc2_w.reshape(-1, self.args.rnn_hidden_dim, self.args.n_actions)
    fc2_b = fc2_b.reshape((-1, 1, self.args.n_actions))

    x = F.relu(self.fc1(inputs))  # (bs*n, (obs+act+id)) at time t
    h = self.rnn(x, h_in)
    h = h.reshape(-1, 1, self.args.rnn_hidden_dim)
    q = th.bmm(h, fc2_w) + fc2_b

    return (q.view(-1, self.args.n_actions),
            h.view(-1, self.args.rnn_hidden_dim),
            loss)
class SoftmaxRandomSamplePolicy(SimplePolicy):
    '''
    Randomly samples from the softmax of the logits using the Gumbel-max trick
    # https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
    TODO: should probably switch that to something more like
    http://pytorch.org/docs/master/distributions.html
    '''

    def __init__(self, temperature=1.0, eps=0.0, do_entropy=True):
        '''
        :param temperature: the model logits are divided by this value
        :param eps: probability of taking the epsilon-greedy branch in forward
        :param do_entropy: when true, forward stores the entropy in ``self.entropy``
        '''
        super().__init__()
        self.gumbel = Gumbel(loc=0, scale=1)
        self.temperature = torch.tensor(temperature, requires_grad=False)
        self.eps = eps
        self.do_entropy = do_entropy
        if self.do_entropy:
            self.entropy = None

    def set_temperature(self, new_temperature):
        '''
        Replace the temperature, keeping the dtype and device of the current one.

        :param new_temperature: new temperature value
        '''
        if self.temperature != new_temperature:
            self.temperature = torch.tensor(new_temperature,
                                            dtype=self.temperature.dtype,
                                            device=self.temperature.device)

    def forward(self, logits: Variable, priors=None):
        '''
        :param logits: Logits to generate probabilities from, batch_size x out_dim float32
        :param priors: optional prior logits; when given, the temperature is
            applied only to ``logits - priors``
        :return: sampled index per batch row; the log-probability of every
            sample is stored in ``self.logp``
        '''
        # epsilon-greediness
        if random.random() < self.eps:
            if priors is None:
                # Uniform over all entries, except those more than 1000
                # below the maximum logit, which are pushed down to -1e4.
                new_logits = torch.zeros_like(logits)
                new_logits[logits < logits.max()-1000] = -1e4
            else:
                new_logits = priors
            if self.do_entropy:
                self.entropy = torch.zeros_like(logits).sum(dim=1)
        else:
            if self.temperature.dtype != logits.dtype or self.temperature.device != logits.device:
                # Convert the existing tensor; wrapping a tensor in
                # torch.tensor() again triggers a copy-construct warning.
                self.temperature = self.temperature.detach().to(
                    dtype=logits.dtype, device=logits.device)
            # temperature is applied to model only, not priors!
            if priors is None:
                new_logits = logits/self.temperature
                raw_logits = new_logits
            else:
                raw_logits = (logits-priors)/self.temperature
                new_logits = priors + raw_logits
            if self.do_entropy:
                raw_logits_normalized = F.log_softmax(raw_logits, dim=1)
                self.entropy = torch.sum(
                    -raw_logits_normalized*torch.exp(raw_logits_normalized), dim=1)

        eff_logits = new_logits
        # Noise is drawn on CPU and moved to wherever the logits live.
        noise = self.gumbel.sample(logits.shape).to(
            device=eff_logits.device, dtype=eff_logits.dtype)
        _, out = torch.max(noise + eff_logits, -1)
        all_logp = F.log_softmax(eff_logits, dim=1)
        # Log-probability of the sampled entry of each row, in one gather
        # instead of a Python loop over the batch.
        self.logp = all_logp.gather(1, out.unsqueeze(1)).squeeze(1)
        return out

    def effective_logits(self, logits):
        '''
        :param logits: raw logits
        :return: the logits unchanged
        '''
        return logits