def gumbel_sample(self, data):
    """Pick one mixture component per row with the Gumbel-max trick.

    Independent Gumbel(0, 1) noise is added to the log of every mixture
    weight and the arg-max over ``dim=1`` is returned.

    Args:
        data: tensor of mixture weights; the candidates are compared along
            ``dim=1``.  ``torch.log`` is applied to it, so zero weights
            become ``-inf`` and are never selected.

    Returns:
        Tensor of selected indices (``data`` with ``dim=1`` reduced away).
    """
    log_weights = torch.log(data)
    # One noise value per entry is essential: a single shared scalar shifts
    # every log-weight by the same amount and leaves the arg-max unchanged.
    noise = Gumbel(
        torch.zeros_like(log_weights), torch.ones_like(log_weights)
    ).sample()
    return (log_weights + noise).argmax(dim=1)
def __init__(self,
             start_tokens: torch.LongTensor,
             end_token: Union[int, torch.LongTensor],
             tau: float,
             straight_through: bool = False,
             stop_gradient: bool = False,
             use_finish: bool = True):
    """Initialise the helper and its Gumbel noise source.

    Every argument except ``straight_through`` is forwarded unchanged to
    the parent constructor; ``straight_through`` is stored on the instance.
    """
    super().__init__(start_tokens, end_token, tau, stop_gradient, use_finish)
    # Zero-location, unit-scale Gumbel distribution.
    zero_loc = torch.tensor(0.0)
    unit_scale = torch.tensor(1.0)
    self._gumbel = Gumbel(loc=zero_loc, scale=unit_scale)
    self._straight_through = straight_through
def __init__(self, do_entropy=False):
    '''
    Create the policy with a Gumbel(0, 1) noise source.

    :param do_entropy: stored as ``self.do_entropy``; when truthy an
        ``entropy`` attribute is also created, initialised to ``None``
    '''
    super().__init__()
    self.do_entropy = do_entropy
    self.gumbel = Gumbel(loc=0, scale=1)
    if self.do_entropy:
        # Placeholder only; nothing in this constructor computes it.
        self.entropy = None
def gumbel_with_maximum(phi, T, dim=-1):
    """
    Sample Gumbels located at ``phi`` and shift them so that their maximum
    along ``dim`` is conditioned on ``T``.

    ``phi.max(dim)[0]`` should be broadcastable with the desired maximum
    ``T``.  The shifting itself is delegated to ``_shift_gumbel_maximum``.

    Returns a pair ``(g, argmax)``: the shifted samples and the index of the
    unshifted maximum along ``dim``.
    """
    # Sampling through torch.distributions avoids -inf / inf draws, which
    # would cause trouble downstream.
    unit_scale = torch.ones_like(phi)
    perturbed = Gumbel(phi, unit_scale).rsample()
    current_max, argmax = perturbed.max(dim)
    shifted = _shift_gumbel_maximum(perturbed, T, dim, Z=current_max)
    return shifted, argmax
def __init__(self, bias=None):
    '''
    Create the policy with a Gumbel(0, 1) noise source and an optional bias.

    :param bias: a vector of log frequencies, to bias the sampling towards
        what we want; stored as a float32 tensor with a leading axis of
        size 1 on the module-level ``device``, or ``None`` when omitted
    '''
    super().__init__()
    self.gumbel = Gumbel(loc=0, scale=1)
    if bias is None:
        self.bias = None
        return
    bias_tensor = torch.from_numpy(np.array(bias))
    bias_tensor = bias_tensor.to(dtype=torch.float32, device=device)
    self.bias = bias_tensor.unsqueeze(0)
def __init__(self, temperature=1.0, eps=0.0, do_entropy=True):
    '''
    Create the policy with a Gumbel(0, 1) noise source.

    :param temperature: stored as a tensor that does not require gradients
    :param eps: stored as ``self.eps`` (presumably an epsilon-greedy
        exploration probability -- confirm against ``forward``)
    :param do_entropy: stored as ``self.do_entropy``; when truthy an
        ``entropy`` attribute is also created, initialised to ``None``
    '''
    super().__init__()
    self.gumbel = Gumbel(loc=0, scale=1)
    self.temperature = torch.tensor(temperature, requires_grad=False)
    self.eps, self.do_entropy = eps, do_entropy
    if self.do_entropy:
        self.entropy = None
def cond_gumbel_sample(all_joint_log_probs, perturbed_log_probs) -> torch.Tensor:
    """Sample Gumbels at ``all_joint_log_probs`` conditioned on a maximum.

    Fresh Gumbel variables are drawn with location ``all_joint_log_probs``
    and unit scale, then transformed using their own maximum over the last
    axis and the target ``perturbed_log_probs``.

    NOTE(review): ``perturbed_log_probs`` must broadcast against the sampled
    tensor -- confirm the expected shapes with the caller.
    """
    # plates x k? x |D_yv| Gumbel variables
    unconditioned = Gumbel(loc=all_joint_log_probs, scale=1.0).rsample()

    # Condition the samples on the maximum of previous samples (plates x k).
    row_max, _ = unconditioned.max(dim=-1)
    target = perturbed_log_probs
    gap = target - unconditioned + log1mexp(unconditioned - row_max.unsqueeze(-1))

    # plates (x k) x |D_yv|
    softplus_term = torch.nn.functional.softplus(-gap.abs())
    return target - gap.relu() - softplus_term
def __init__(
    self,
    shape: float = 1.0,
    topk: Optional[int] = None,
    equiv_len: Optional[int] = None,
    log_scores: bool = False,
):
    """FréchetSort is a softer version of descending sort which samples all
    possible orderings of items, favoring orderings which resemble
    descending sort. This can be used to convert descending sort by rank
    score into a differentiable, stochastic policy amenable to policy
    gradient algorithms.

    :param shape: parameter of the Frechet distribution. Lower values
        correspond to aggressive deviations from descending sort. The
        Gumbel noise is created with scale ``1.0 / shape``.
    :param topk: if specified, only the first topk actions are specified.
    :param equiv_len: orders are considered equivalent if the top
        equiv_len match. Used in probability computations; essentially
        specifies the action space. Defaults to ``topk`` when ``topk`` is
        given and this is omitted.
    :param log_scores: scores passed in are already log-transformed. In
        this case, we would simply add Gumbel noise. For LearnVM, we set
        this to be True because we expect input and output scores to be
        in the log space.
    :raises ValueError: if the effective equiv length exceeds ``topk``.

    Example:

    Consider the sampler ``FrechetSort(shape=3, topk=5, equiv_len=3)``.
    Given a set of scores, it produces indices of items roughly resembling
    an argsort by scores in descending order; the higher the shape, the
    closer the resemblance. ``topk=5`` means only the top 5 ranks will be
    output. ``equiv_len`` determines which orders are considered
    equivalent for probability computation: here the ``log_prob`` call
    gives the probability of the top 3 items appearing in a given order.
    """
    self.shape, self.topk = shape, topk
    use_topk_as_default = topk is not None and equiv_len is None
    self.upto = topk if use_topk_as_default else equiv_len
    # pyre-fixme[58]: `>` is not supported for operand types `Optional[int]`
    #  and `Optional[int]`.
    if topk is not None and self.upto > self.topk:
        raise ValueError(
            f"Equiv length {equiv_len} cannot exceed topk={topk}.")
    self.gumbel_noise = Gumbel(0, 1.0 / shape)
    self.log_scores = log_scores
class SoftmaxRandomSamplePolicy(SimplePolicy):
    '''
    Randomly samples from the softmax of the logits
    # https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
    TODO: should probably switch that to something more like
    http://pytorch.org/docs/master/distributions.html
    '''

    def __init__(self, bias=None):
        '''
        :param bias: a vector of log frequencies, to bias the sampling
            towards what we want; stored as a float32 tensor with a leading
            axis of size 1 on the module-level ``device``
        '''
        super().__init__()
        self.gumbel = Gumbel(loc=0, scale=1)
        if bias is not None:
            self.bias = torch.from_numpy(np.array(bias)).to(
                dtype=torch.float32, device=device).unsqueeze(0)
        else:
            self.bias = None

    def forward(self, logits: Variable):
        '''
        Draw one index per row with the Gumbel-max trick.

        :param logits: Logits to generate probabilities from,
            batch_size x out_dim float32
        :return: indices of the maxima of the perturbed logits over the
            last axis
        '''
        eff_logits = self.effective_logits(logits)
        # Put the noise wherever the logits live instead of relying on the
        # module-level ``device``.
        noise = self.gumbel.sample(logits.shape).to(
            device=eff_logits.device, dtype=eff_logits.dtype)
        _, out = torch.max(noise + eff_logits, -1)
        return out

    def effective_logits(self, logits):
        '''
        Return the logits with the optional bias added.

        The sum is out of place, so the caller's tensor is never modified
        (an in-place ``+=`` would silently change the input and can break
        autograd for tensors that require gradients).
        '''
        if self.bias is None:
            return logits
        return logits + self.bias
def log_sample(self, x, u, n, alpha=1., xg=None, ug=None):
    """Draw ``n`` row indices with the Gumbel-max trick over scaled costs.

    The costs returned by ``self.c.cost(x, u, xg, ug)`` are reshaped to a
    column, multiplied by ``alpha`` and used as unnormalized log
    probabilities.  Gumbel(0, 1) noise of shape ``(x.shape[0], n)`` is
    added and the arg-max over axis 0 is taken for each of the ``n``
    columns.

    Returns:
        ``(choices, alpha * costs)`` where ``costs`` is the column of costs.
    """
    cost_column = self.c.cost(x, u, xg, ug).reshape(-1, 1)
    scaled = alpha * cost_column  # unnormalized log probabilities
    noise = Gumbel(loc=0., scale=1.).sample((x.shape[0], n))
    choices = torch.max(scaled + noise, 0)[1]
    return choices, alpha * cost_column
def reinforce_unordered(conditional_loss_fun, log_class_weights,
                        class_weights_detached, seq_tensor, z_sample, epoch,
                        data, n_samples=1, baseline_separate=False,
                        baseline_n_samples=1, baseline_deterministic=False,
                        baseline_constant=None):
    """Loss built from ``n_samples`` classes drawn without replacement.

    Classes are drawn with the Gumbel top-k trick from the detached
    ``log_class_weights``; each drawn class is one-hot encoded and passed to
    ``conditional_loss_fun`` to obtain its cost.  The log-ratios come from
    ``compute_log_R_O_nfac`` and the baseline is chosen in this order:
    ``baseline_constant`` if given, a separately computed baseline if
    ``baseline_separate``, the built-in baseline if more than one sample was
    drawn, otherwise none.

    ``class_weights_detached``, ``seq_tensor``, ``z_sample``, ``epoch`` and
    ``data`` are not read by this function.

    Returns:
        The loss summed over the sample axis.
    """
    # Sample without replacement using the Gumbel top-k trick.
    detached_log_weights = log_class_weights.detach()
    perturbed = Gumbel(
        detached_log_weights, torch.ones_like(detached_log_weights)
    ).rsample()
    sampled_idx = perturbed.topk(n_samples, -1)[1]

    log_p = log_class_weights.gather(-1, sampled_idx)
    n_classes = log_class_weights.shape[1]
    per_sample_costs = []
    for sample_column in sampled_idx.t():
        encoded = get_one_hot_encoding_from_int(sample_column, n_classes)
        per_sample_costs.append(conditional_loss_fun(encoded))
    costs = torch.stack(per_sample_costs, -1)

    # Don't compute gradients for advantage and ratio.
    with torch.no_grad():
        log_R_s, log_R_ss = compute_log_R_O_nfac(log_p)

        if baseline_constant is not None:
            baseline = baseline_constant
        elif baseline_separate:
            # Same baseline for all samples, so add a dimension.
            baseline = get_baseline(conditional_loss_fun, log_class_weights,
                                    baseline_n_samples,
                                    baseline_deterministic)[:, None]
        elif log_p.size(-1) > 1:
            # Built-in baseline.
            weighted = (log_p[:, None, :] + log_R_ss).exp() * costs[:, None, :]
            baseline = weighted.sum(-1)
        else:
            baseline = 0.  # No baseline
        advantage = costs - baseline

    # Also add the costs (with the unordered estimator) in case there is a
    # direct dependency on the parameters.
    sample_weights = (log_p + log_R_s).exp()
    score_term = sample_weights * advantage.detach()
    pathwise_term = sample_weights.detach() * costs
    return (score_term + pathwise_term).sum(-1)
def sample_stgs(proba_no_softmax, K, temp):
    """Straight-through Gumbel-softmax sample.

    Gumbel(0, 1) noise of shape ``(proba_no_softmax.size(0), K)`` is added
    to the unnormalised scores, the sum is divided by ``temp`` and
    soft-maxed.  The returned tensor equals the one-hot encoding of the
    arg-max in the forward pass and carries the gradient of the soft sample
    in the backward pass.
    """
    batch = proba_no_softmax.size(0)
    target_device = proba_no_softmax.device

    # Gumbel noise created directly on the input's device.
    loc = torch.zeros((batch, K), device=target_device)
    scale = torch.ones((batch, K), device=target_device)
    noise = Gumbel(loc, scale).sample().detach()

    # Soft approximation.
    soft = F.softmax((proba_no_softmax + noise) / temp, dim=-1)

    # Arg-max in the forward pass, routed through a one-hot encoding so the
    # backward pass follows the soft sample.
    hard = one_hot(soft.argmax(dim=-1), K)
    return (hard - soft).detach() + soft
def __init__(self, shape: float = 1, upto: Optional[int] = None):
    """Samples an ordering

    Args:
        shape: 1/scale parameter of the Gumbel noise distribution
        upto [Default: None]: If provided, position upto which
            log_prob is computed
    """
    noise_scale = 1.0 / shape
    self.gumbel_noise = Gumbel(0, noise_scale)
    self.upto = upto
    self.shape = shape
def sample_gumbel_softmax(logits, temperature, n_sample):
    """
    Draw ``n_sample`` Gumbel-softmax samples for every row of ``logits``.

    Input:
    logits: Tensor of log probs, shape = BS x k
    temperature = scalar
    n_sample = number of independent samples to draw

    Output: Tensor of values sampled from Gumbel softmax.
            These will tend towards a one-hot representation in the limit
            of temp -> 0
            shape = n_sample x BS x k
    """
    # Build the noise with the same device and dtype as ``logits`` rather
    # than allocating on CPU and copying to a module-level ``device``.
    noise = Gumbel(
        torch.zeros_like(logits), torch.ones_like(logits)
    ).rsample((n_sample,))
    return F.softmax((noise + logits) / temperature, dim=-1)
class FrechetOrderSampler:
    """Samples orderings of items by perturbing log-scores with Gumbel noise."""

    def __init__(self, shape: float = 1, upto: Optional[int] = None):
        """Samples an ordering

        Args:
            shape: 1/scale parameter of the Gumbel noise distribution
            upto [Default: None]: If provided, position upto which
                log_prob is computed
        """
        self.shape = shape
        self.upto = upto
        noise_scale = 1.0 / shape
        self.gumbel_noise = Gumbel(0, noise_scale)

    def sample(self, scores: torch.Tensor):
        """Sample an ordering given scores"""
        noise = self.gumbel_noise.sample((len(scores),))
        perturbed = torch.log(scores) + noise
        # Ascending sort of the negated values yields a descending order.
        return torch.argsort(-perturbed.detach())

    def log_prob(self, scores: torch.Tensor, permutations):
        """Compute log probability given scores and a permutation.

        The formula uses the equivalence of sorting with Gumbel noise and
        the Plackett-Luce model (see Yellot 1977).

        Args:
            scores: scores of different items
            permutations: prescribed (permutation) order of the items
        """
        s = torch.log(select_indices(scores, permutations))
        n = len(scores)
        depth = n - 1 if self.upto is None else self.upto
        total = 0
        for k in range(depth):
            remaining = torch.exp((s[k:] - s[k]) * self.shape)
            total = total + torch.log(remaining.sum(dim=0))
        return -total
def sample_model_spec(self, num):
    """
    Override, sample the alpha via gumbel softmax instead of normal softmax.

    For every intermediate node one ``Categorical`` is built over topology
    choices and one over operation choices, each from Gumbel-softmax
    probabilities of detached copies of the alphas.

    :param num: forwarded to ``self._sample_model_spec``
    :return: whatever ``self._sample_model_spec`` returns
    """
    topology_logits = self.alpha_topology.detach().clone()
    op_logits = self.alpha_ops.detach().clone()
    unit_gumbel = Gumbel(torch.tensor([.0]), torch.tensor([1.0]))

    topology_dists, op_dists = [], []
    with torch.no_grad():
        for node in range(self.num_intermediate_nodes):
            # Align with the topology weights: node ``node`` has
            # ``node + 2`` candidate inputs.
            topo_probs = gumbel_softmax(
                topology_logits[: node + 2, node], self.temperature(),
                unit_gumbel)
            topology_dists.append(Categorical(topo_probs))

            op_probs = gumbel_softmax(
                op_logits[:, node], self.temperature(), unit_gumbel)
            op_dists.append(Categorical(op_probs))
    return self._sample_model_spec(num, topology_dists, op_dists)
class SoftmaxRandomSamplePolicySparse(SimplePolicy):
    '''
    Randomly samples from the softmax of the logits
    # https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
    TODO: should probably switch that to something more like
    http://pytorch.org/docs/master/distributions.html
    '''

    def __init__(self, do_entropy=False):
        '''
        :param do_entropy: stored as ``self.do_entropy``; when truthy an
            ``entropy`` attribute is created, initialised to ``None``
        '''
        super().__init__()
        self.gumbel = Gumbel(loc=0, scale=1)
        self.do_entropy = do_entropy
        if self.do_entropy:
            self.entropy = None

    def forward(self, all_logits_list, all_action_inds_list):
        '''
        Sample one action per batch item with the Gumbel-max trick.

        :param all_logits_list: list of tensors of len batch_size
        :param all_action_inds_list: list of long tensors with indices of
            the actions corresponding to the logits
        :return: stacked tensor of the chosen action indices; the matching
            log-probabilities are stored in ``self.logp``
        '''
        # The number of feasible actions differs for every item in the
        # batch, so a for loop is the simplest.
        logp = []
        out_actions = []
        for logits, action_inds in zip(all_logits_list, all_action_inds_list):
            if len(logits):
                # Same device convention as the empty branch below.
                noise = self.gumbel.sample(logits.shape).to(
                    device=logits.device, dtype=logits.dtype)
                _, out = torch.max(noise + logits, -1)
                # Explicit ``dim``: the implicit choice is deprecated and
                # warns on every call.
                this_logp = F.log_softmax(logits, dim=-1)[out]
                out_actions.append(action_inds[out])
                logp.append(this_logp)
            else:
                # No feasible action: emit index 0 with log-probability 0.
                out_actions.append(torch.tensor(
                    0, device=logits.device, dtype=action_inds.dtype))
                logp.append(torch.tensor(
                    0.0, device=logits.device, dtype=logits.dtype))
        self.logp = torch.stack(logp)
        out = torch.stack(out_actions)
        return out
def ops_weights(self, node):
    """Return Gumbel-softmax weights over the operations of ``node``.

    Column ``node`` of ``self.alpha_ops`` is passed to ``gumbel_softmax``
    together with the current temperature and a fresh Gumbel(0, 1)
    distribution.
    """
    unit_gumbel = Gumbel(torch.tensor([.0]), torch.tensor([1.0]))
    node_alphas = self.alpha_ops[:, node]
    return gumbel_softmax(node_alphas, self.temperature(), unit_gumbel)
def topology_weights(self, node):
    """Return the soft weights for topology aggregation of ``node``.

    The first ``node + 2`` entries of column ``node`` of
    ``self.alpha_topology`` go through ``gumbel_softmax``; the first
    element of the result is dropped before returning.
    """
    unit_gumbel = Gumbel(torch.tensor([.0]), torch.tensor([1.0]))
    node_alphas = self.alpha_topology[: node + 2, node]
    weights = gumbel_softmax(node_alphas, self.temperature(), unit_gumbel)
    return weights[1:]
def rsample_gumbel(
    distr: Distribution,
    n: int,
) -> torch.Tensor:
    """Draw ``n`` reparameterised Gumbel samples located at ``distr.logits``.

    The Gumbel distribution has unit scale; the result has a leading axis
    of size ``n`` in front of the shape of ``distr.logits``.
    """
    sample_shape = torch.Size([n])
    return Gumbel(distr.logits, 1).rsample(sample_shape)
def reinforce_sum_and_sample(conditional_loss_fun, log_class_weights,
                             class_weights_detached, seq_tensor, z_sample,
                             epoch, data, n_samples=1, baseline_separate=False,
                             baseline_n_samples=1,
                             baseline_deterministic=False,
                             rao_blackwellize=False):
    """Per-row loss from ``n_samples`` classes drawn without replacement.

    Classes are drawn with the Gumbel top-k trick from the detached
    ``log_class_weights`` and each one is scored by ``conditional_loss_fun``
    on its one-hot encoding.  Rows whose sampled classes carry (almost) all
    of the probability mass are handled by ``compute_summed_terms``; all
    other rows go through ``compute_sum_and_sample_loss``, optionally
    averaged over all permutations of the sample when ``rao_blackwellize``
    is set.

    ``class_weights_detached``, ``seq_tensor``, ``z_sample``, ``epoch`` and
    ``data`` are not read by this function.

    NOTE(review): the nesting below (in particular what sits inside
    ``torch.no_grad()``) was reconstructed from the statement order and the
    comments -- confirm against the original file.

    Returns:
        Tensor with one loss per row of ``log_class_weights``.
    """
    # Sample without replacement using Gumbel top-k trick
    phi = log_class_weights.detach()
    g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()
    _, ind = g_phi.topk(n_samples, -1)

    # Log-probabilities of the sampled classes (gradients flow through).
    log_p = log_class_weights.gather(-1, ind)
    n_classes = log_class_weights.shape[1]
    # One cost column per sample position.
    costs = torch.stack([
        conditional_loss_fun(get_one_hot_encoding_from_int(
            z_sample, n_classes)) for z_sample in ind.t()
    ], -1)

    with torch.no_grad():  # Don't compute gradients for advantage and ratio
        if baseline_separate:
            bl_vals = get_baseline(conditional_loss_fun, log_class_weights,
                                   baseline_n_samples, baseline_deterministic)
            # Same bl for all samples, so add dimension
            bl_vals = bl_vals[:, None]
        else:
            assert baseline_n_samples < n_samples
            # Weight of the last baseline sample: 1 minus the probability
            # mass of the first ``baseline_n_samples - 1`` samples.
            bl_sampled_weight = log1mexp(
                log_p[:, :baseline_n_samples - 1].logsumexp(-1)).exp().detach()
            bl_vals = (log_p[:, :baseline_n_samples - 1].exp() * costs[:, :baseline_n_samples -1]).sum(-1)\
                + bl_sampled_weight * costs[:, baseline_n_samples - 1]
            bl_vals = bl_vals[:, None]

    # We compute an 'exact' gradient if the sum of probabilities is roughly more than 1 - 1e-5
    # in which case we can simply sum all the terms and the relative error will be < 1e-5
    use_exact = log_p.logsumexp(-1) > -1e-5
    not_use_exact = use_exact == 0

    cost_exact = costs[use_exact]
    exact_loss = compute_summed_terms(log_p[use_exact], cost_exact,
                                      cost_exact - bl_vals[use_exact])

    # Rows that still need the sampled estimator.
    log_p_est = log_p[not_use_exact]
    costs_est = costs[not_use_exact]
    bl_vals_est = bl_vals[not_use_exact]

    if rao_blackwellize:
        # Evaluate the estimator for every ordering of the sample.
        ap = all_perms(torch.arange(n_samples, dtype=torch.long),
                       device=log_p_est.device)
        log_p_ap = log_p_est[:, ap]
        bl_vals_ap = bl_vals_est.expand_as(costs_est)[:, ap]
        costs_ap = costs_est[:, ap]
        cond_losses = compute_sum_and_sample_loss(log_p_ap, costs_ap,
                                                  bl_vals_ap)

        # Compute probabilities for permutations
        log_probs_perms = log_pl_rec(log_p_ap, -1)
        cond_log_probs_perms = log_probs_perms - log_probs_perms.logsumexp(
            -1, keepdim=True)
        losses = (cond_losses * cond_log_probs_perms.exp()).sum(-1)
    else:
        losses = compute_sum_and_sample_loss(log_p_est, costs_est,
                                             bl_vals_est)

    # If they are summed we can simply concatenate but for consistency it is best to place them in order
    all_losses = log_p.new_zeros(log_p.size(0))
    all_losses[use_exact] = exact_loss
    all_losses[not_use_exact] = losses
    return all_losses
class FrechetSort(Sampler):
    """Stochastic, softened descending sort driven by Gumbel noise."""

    EPS = 1e-12

    @resolve_defaults
    def __init__(
        self,
        shape: float = 1.0,
        topk: Optional[int] = None,
        equiv_len: Optional[int] = None,
        log_scores: bool = False,
    ):
        """FréchetSort is a softer version of descending sort which samples
        all possible orderings of items favoring orderings which resemble
        descending sort. This can be used to convert descending sort by
        rank score into a differentiable, stochastic policy amenable to
        policy gradient algorithms.

        :param shape: parameter of Frechet Distribution. Lower values
            correspond to aggressive deviations from descending sort.
        :param topk: If specified, only the first topk actions are
            specified.
        :param equiv_len: Orders are considered equivalent if the top
            equiv_len match. Used in probability computations.
            Essentially specifies the action space.
        :param log_scores: Scores passed in are already log-transformed.
            In this case, we would simply add Gumbel noise. For LearnVM,
            we set this to be True because we expect input and output
            scores to be in the log space.
        :raises ValueError: if the effective equiv length exceeds topk.

        Example:

        Consider the sampler:

        sampler = FrechetSort(shape=3, topk=5, equiv_len=3)

        Given a set of scores, this sampler will produce indices of items
        roughly resembling a argsort by scores in descending order. The
        higher the shape, the more it would resemble a descending argsort.
        `topk=5` means only the top 5 ranks will be output. The
        `equiv_len` determines what orders are considered equivalent for
        probability computation. In this example, the sampler will
        produce probability for the top 3 items appearing in a given
        order for the `log_prob` call.
        """
        self.shape = shape
        self.topk = topk
        self.upto = equiv_len
        if topk is not None:
            if equiv_len is None:
                self.upto = topk
            # pyre-fixme[58]: `>` is not supported for operand types `Optional[int]`
            #  and `Optional[int]`.
            if self.upto > self.topk:
                raise ValueError(
                    f"Equiv length {equiv_len} cannot exceed topk={topk}.")
        self.gumbel_noise = Gumbel(0, 1.0 / shape)
        self.log_scores = log_scores

    def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
        """Sample a ranking according to Frechet sort. Note that
        possible_actions_mask is ignored as the list of rankings scales
        exponentially with slate size and number of items and it can be
        difficult to enumerate them.

        :param scores: 2-D batch of scores, one row per ranking problem.
        :return: ``ActorOutput`` with the sampled ranking (truncated to the
            first ``topk`` ranks of every row when ``topk`` is set) and the
            log-probability of the full ranking.
        """
        assert scores.dim() == 2, "sample_action only accepts batches"
        log_scores = scores if self.log_scores else torch.log(scores)
        perturbed = log_scores + self.gumbel_noise.sample(scores.shape)
        action = torch.argsort(perturbed.detach(), descending=True)
        log_prob = self.log_prob(scores, action)
        # Only truncate the action before returning.  Ranks live on axis 1;
        # slicing axis 0 would drop batch rows and leave ``action`` and
        # ``log_prob`` with different batch sizes.
        if self.topk is not None:
            action = action[:, :self.topk]
        return rlt.ActorOutput(action, log_prob)

    def log_prob(
        self,
        scores: torch.Tensor,
        action: torch.Tensor,
        equiv_len_override: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        What is the probability of a given set of scores producing the
        given list of permutations only considering the top `equiv_len`
        ranks?

        We may want to override the default equiv_len here when we know
        the having larger action space doesn't matter. i.e. in Reels

        :param scores: scores, 1-D or 2-D (batch); 1-D input is treated as
            a batch of one and the result keeps that leading axis.
        :param action: ranking(s) with exactly as many entries per row as
            there are scores; the index ``scores.shape[1]`` refers to a
            padding item with log-score ``-inf``.
        :param equiv_len_override: optional per-row equiv length, shape
            ``(scores.shape[0],)``.
        :raises ValueError: if the override exceeds ``topk`` or the action
            has more entries than there are scores.
        :raises NotImplementedError: if the action is shorter than the
            scores.
        """
        upto = self.upto
        if equiv_len_override is not None:
            assert equiv_len_override.shape == (
                scores.shape[0],
            ), f"Invalid shape {equiv_len_override.shape}, compared to scores {scores.shape}. equiv_len_override {equiv_len_override}"
            upto = equiv_len_override.long()
            if self.topk is not None and torch.any(
                    equiv_len_override > self.topk):
                raise ValueError(
                    f"Override {equiv_len_override} cannot exceed topk={self.topk}."
                )

        if len(scores.shape) == 1:
            scores = scores.unsqueeze(0)
            action = action.unsqueeze(0)

        assert len(action.shape) == len(
            scores.shape) == 2, "scores should be batch"
        if action.shape[1] > scores.shape[1]:
            raise ValueError(
                f"action cardinality ({action.shape[1]}) is larger than the number of scores ({scores.shape[1]})"
            )
        elif action.shape[1] < scores.shape[1]:
            raise NotImplementedError(
                f"This semantic is ambiguous. If you have shorter slate, pad it with scores.shape[1] ({scores.shape[1]})"
            )

        log_scores = scores if self.log_scores else torch.log(scores)
        n = log_scores.shape[-1]
        # Add scores for the padding value; match dtype so ``torch.cat``
        # neither fails nor promotes unexpectedly.
        log_scores = torch.cat(
            [
                log_scores,
                torch.full(
                    (log_scores.shape[0], 1),
                    -math.inf,
                    dtype=log_scores.dtype,
                    device=log_scores.device,
                ),
            ],
            dim=1,
        )
        s = torch.gather(log_scores, 1, action) * self.shape

        p = upto if upto is not None else n

        # log_softmax over the not-yet-placed items gives the log-probability
        # of the item at rank i; padding items yield -inf/NaN, mapped to 0.
        if isinstance(p, int):
            probs = sum(
                torch.nan_to_num(F.log_softmax(s[:, i:], dim=1)[:, 0],
                                 neginf=0.0) for i in range(p))
        elif isinstance(p, torch.Tensor):
            # do masked sum
            probs = sum(
                torch.nan_to_num(F.log_softmax(s[:, i:], dim=1)[:, 0],
                                 neginf=0.0) * (i < p).float()
                for i in range(n))
        else:
            raise RuntimeError(f"p is {p}")
        return probs
def sample_genotype_index(self):
    """Return the arg-max (over axis 0) of one Gumbel-softmax sample.

    The sample is drawn from the squeezed ``self.gumble_arch_params`` at
    temperature ``self.gumbel_temperature`` with gradients disabled.
    """
    with torch.no_grad():
        arch_logits = self.gumble_arch_params.squeeze()
        relaxed = gumbel_softmax_sample(
            arch_logits, temperature=self.gumbel_temperature, dim=0)
        winner = torch.argmax(relaxed, dim=0)
    return winner.detach()


class MixedSequential(torch.nn.Sequential):
    """Sequential container that unpacks ``(output, cost)`` pairs returned
    by ``MixedModule`` children."""

    def forward(self, input):
        """Run the children in order and return the final output.

        Costs reported by ``MixedModule`` children are summed locally but
        are not part of the return value.
        """
        accumulated_cost = 0
        for layer in self._modules.values():
            result = layer(input)
            if isinstance(layer, MixedModule):
                result, layer_cost = result
                accumulated_cost += layer_cost
            input = result
        return input


# Shared Gumbel(0, 1) noise source for ``gumbel_softmax_sample``.
gumbel_dist = Gumbel(0.0, 1.0)


def gumbel_softmax_sample(logits, temperature, dim=None, std=1.0):
    """Softmax of Gumbel-perturbed ``logits`` divided by ``temperature``.

    The noise is cast to the device and dtype of ``logits``.  ``std`` is
    accepted but not used.
    """
    noise = gumbel_dist.sample(logits.shape)
    noise = noise.to(device=logits.device, dtype=logits.dtype)
    perturbed = logits + noise
    return F.softmax(perturbed / temperature, dim=dim)
class FrechetSort(Sampler):
    """Stochastic, softened descending sort driven by Gumbel noise."""

    @resolve_defaults
    def __init__(
        self,
        shape: float = 1.0,
        topk: Optional[int] = None,
        equiv_len: Optional[int] = None,
        log_scores: bool = False,
    ):
        """FréchetSort is a softer version of descending sort which samples
        all possible orderings of items favoring orderings which resemble
        descending sort. This can be used to convert descending sort by
        rank score into a differentiable, stochastic policy amenable to
        policy gradient algorithms.

        :param shape: parameter of Frechet Distribution. Lower values
            correspond to aggressive deviations from descending sort.
        :param topk: If specified, only the first topk actions are
            specified.
        :param equiv_len: Orders are considered equivalent if the top
            equiv_len match. Used in probability computations
        :param log_scores: Scores passed in are already log-transformed.
            In this case, we would simply add Gumbel noise.
        :raises ValueError: if the effective equiv length exceeds topk.

        Example:

        Consider the sampler:

        sampler = FrechetSort(shape=3, topk=5, equiv_len=3)

        Given a set of scores, this sampler will produce indices of items
        roughly resembling a argsort by scores in descending order. The
        higher the shape, the more it would resemble a descending argsort.
        `topk=5` means only the top 5 ranks will be output. The
        `equiv_len` determines what orders are considered equivalent for
        probability computation. In this example, the sampler will
        produce probability for the top 3 items appearing in a given
        order for the `log_prob` call.
        """
        self.shape = shape
        self.topk = topk
        self.upto = equiv_len
        if topk is not None:
            if equiv_len is None:
                self.upto = topk
            # pyre-fixme[58]: `>` is not supported for operand types `Optional[int]`
            #  and `Optional[int]`.
            if self.upto > self.topk:
                raise ValueError(f"Equiv length {equiv_len} cannot exceed topk={topk}.")
        self.gumbel_noise = Gumbel(0, 1.0 / shape)
        self.log_scores = log_scores

    @staticmethod
    def select_indices(scores: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Helper for scores[actions] that also works for batched tensors.

        For batched ``actions`` the result is transposed, i.e. laid out as
        (rank, batch).
        """
        if len(actions.shape) > 1:
            num_rows = scores.size(0)
            # Keep the row index on the same device as the scores.
            row_indices = torch.arange(
                num_rows, device=scores.device
            ).unsqueeze(0).T  # pyre-ignore[ 16 ]
            return scores[row_indices, actions].T
        else:
            return scores[actions]

    def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
        """Sample a ranking according to Frechet sort. Note that
        possible_actions_mask is ignored as the list of rankings scales
        exponentially with slate size and number of items and it can be
        difficult to enumerate them.

        :param scores: 2-D batch of scores, one row per ranking problem.
        :return: ``ActorOutput`` with the sampled ranking (first ``topk``
            ranks of every row when ``topk`` is set) and its
            log-probability.
        """
        assert scores.dim() == 2, "sample_action only accepts batches"
        log_scores = scores if self.log_scores else torch.log(scores)
        # Independent noise for every entry; a single vector of length
        # scores.shape[1] would be broadcast and shared by all batch rows.
        noise = self.gumbel_noise.sample(scores.shape).to(
            device=log_scores.device, dtype=log_scores.dtype)
        perturbed = log_scores + noise
        action = torch.argsort(perturbed.detach(), descending=True)
        # Score the full ranking first so the normaliser of every rank
        # still sees all remaining items, then truncate along the rank axis
        # (axis 1); slicing axis 0 would drop batch rows instead.
        log_prob = self.log_prob(scores, action)
        if self.topk is not None:
            action = action[:, : self.topk]
        return rlt.ActorOutput(action, log_prob)

    def log_prob(self, scores: torch.Tensor, action) -> torch.Tensor:
        """What is the probability of a given set of scores producing the
        given list of permutations only considering the top `equiv_len`
        ranks?

        :param scores: 1-D scores, or a 2-D batch of scores.
        :param action: ranking(s) indexing into ``scores``.
        :return: log-probability, one value per batch row for 2-D input.
        """
        log_scores = scores if self.log_scores else torch.log(scores)
        s = self.select_indices(log_scores, action)
        # ``s`` is laid out rank-first, so its leading size is the number
        # of ranked items for both 1-D and batched input (``len(log_scores)``
        # would be the batch size for 2-D input).
        n = s.shape[0]
        p = self.upto if self.upto is not None else n
        return -sum(
            torch.log(torch.exp((s[k:] - s[k]) * self.shape).sum(dim=0))
            for k in range(p)  # pyre-ignore
        )
def gumbel_softmax_sample(logits, temp=1.):
    """Return one Gumbel-softmax sample over the last axis of ``logits``.

    Gumbel(0, 1) noise is added to ``logits``, the sum is divided by
    ``temp`` and soft-maxed over ``dim=-1``.

    Args:
        logits: tensor of unnormalised log-probabilities.
        temp: softmax temperature.

    Returns:
        Tensor with the shape, device and dtype of ``logits``.
    """
    # The distribution samples on CPU in the default dtype; move the noise
    # next to the logits so CUDA and non-default-dtype inputs work.
    noise = Gumbel(0, 1).sample(logits.shape).to(
        device=logits.device, dtype=logits.dtype)
    return torch.softmax((noise + logits) / temp, dim=-1)
class SoftmaxRandomSamplePolicy(SimplePolicy):
    '''
    Randomly samples from the softmax of the logits
    # https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
    TODO: should probably switch that to something more like
    http://pytorch.org/docs/master/distributions.html
    '''

    def __init__(self, temperature=1.0, eps=0.0, do_entropy=True):
        '''
        :param temperature: divisor applied to the model logits in
            ``forward``; stored as a tensor without gradients
        :param eps: probability with which ``forward`` takes its
            epsilon-greedy branch
        :param do_entropy: when truthy, ``forward`` stores an entropy value
            in ``self.entropy``
        '''
        super().__init__()
        self.gumbel = Gumbel(loc=0, scale=1)
        self.temperature = torch.tensor(temperature, requires_grad=False)
        self.eps = eps
        self.do_entropy = do_entropy
        if self.do_entropy:
            self.entropy = None

    def set_temperature(self, new_temperature):
        '''
        Replace the temperature if it differs from the current one,
        keeping the dtype and device of the existing tensor.
        '''
        if self.temperature != new_temperature:
            self.temperature = torch.tensor(new_temperature,
                                            dtype=self.temperature.dtype,
                                            device=self.temperature.device)

    def forward(self, logits: Variable, priors=None):
        '''
        Sample one index per row with the Gumbel-max trick.

        :param logits: Logits to generate probabilities from,
            batch_size x out_dim float32
        :param priors: optional prior logits; the temperature is applied to
            the model part only, not to the priors
        :return: sampled indices; their log-probabilities are stored in
            ``self.logp`` and, if enabled, the entropy in ``self.entropy``
        '''
        # epsilon-greediness
        if random.random() < self.eps:
            if priors is None:
                new_logits = torch.zeros_like(logits)
                new_logits[logits < logits.max()-1000] = -1e4
            else:
                new_logits = priors
            if self.do_entropy:
                self.entropy = torch.zeros_like(logits).sum(dim=1)
        else:
            if self.temperature.dtype != logits.dtype or self.temperature.device != logits.device:
                # ``.to`` converts the existing tensor; wrapping a tensor in
                # ``torch.tensor(...)`` copy-constructs and warns.
                self.temperature = self.temperature.to(
                    dtype=logits.dtype, device=logits.device)

            # temperature is applied to model only, not priors!
            if priors is None:
                new_logits = logits/self.temperature
                raw_logits = new_logits
            else:
                raw_logits = (logits-priors)/self.temperature
                new_logits = priors + raw_logits
            if self.do_entropy:
                raw_logits_normalized = F.log_softmax(raw_logits, dim=1)
                self.entropy = torch.sum(
                    -raw_logits_normalized*torch.exp(raw_logits_normalized),
                    dim=1)

        eff_logits = new_logits
        # Noise follows the logits' device instead of a module-level one.
        noise = self.gumbel.sample(logits.shape).to(
            device=eff_logits.device, dtype=eff_logits.dtype)
        _, out = torch.max(noise + eff_logits, -1)
        all_logp = F.log_softmax(eff_logits, dim=1)
        # Log-probability of the sampled index of every row, gathered in one
        # vectorised call rather than a Python loop over rows.
        self.logp = all_logp.gather(1, out.unsqueeze(1)).squeeze(1)
        return out

    def effective_logits(self, logits):
        '''Return ``logits`` unchanged.'''
        return logits
class GumbelSoftmaxEmbeddingHelper(SoftmaxEmbeddingHelper):
    r"""A helper that feeds Gumbel softmax sample to the next step.

    Uses the Gumbel softmax vector to pass through word embeddings to
    get the next input (i.e., a mixed word embedding).

    A subclass of :class:`~texar.modules.Helper`. Used as a helper to
    :class:`~texar.modules.RNNDecoderBase` in inference mode.

    Same as :class:`~texar.modules.SoftmaxEmbeddingHelper` except that here
    Gumbel softmax (instead of softmax) is used.

    Args:
        embedding: A callable or the ``params`` argument for
            :torch_nn:`functional.embedding`.
            If a callable, it can take a vector tensor of ``ids`` (argmax
            ids), or take two arguments (``ids``, ``times``), where ``ids``
            is a vector of argmax ids, and ``times`` is a vector of current
            time steps (i.e., position ids). The latter case can be used
            when :attr:`embedding` is a combination of word embedding and
            position embedding.
            The returned tensor will be passed to the decoder input.
        start_tokens: 1D :tensor:`LongTensor` shaped ``[batch_size]``,
            representing the start tokens for each sequence in batch.
        end_token: Python int or scalar :tensor:`LongTensor`, denoting the
            token that marks end of decoding.
        tau: A float scalar tensor, the softmax temperature.
        straight_through (bool): Whether to use straight through gradient
            between time steps. If `True`, a single token with highest
            probability (i.e., greedy sample) is fed to the next step and
            gradient is computed using straight through. If `False`
            (default), the soft Gumbel-softmax distribution is fed to the
            next step.
        stop_gradient (bool): Whether to stop the gradient backpropagation
            when feeding softmax vector to the next step.
        use_finish (bool): Whether to stop decoding once :attr:`end_token`
            is generated. If `False`, decoding will continue until
            :attr:`max_decoding_length` of the decoder is reached.

    Raises:
        ValueError: if :attr:`start_tokens` is not a 1D tensor or
            :attr:`end_token` is not a scalar.
    """

    def __init__(self,
                 start_tokens: torch.LongTensor,
                 end_token: Union[int, torch.LongTensor],
                 tau: float,
                 straight_through: bool = False,
                 stop_gradient: bool = False,
                 use_finish: bool = True):
        super().__init__(start_tokens, end_token, tau,
                         stop_gradient, use_finish)
        # Zero-location, unit-scale Gumbel distribution.
        zero_loc = torch.tensor(0.0)
        unit_scale = torch.tensor(1.0)
        self._gumbel = Gumbel(loc=zero_loc, scale=unit_scale)
        self._straight_through = straight_through

    def sample(self, time: int, outputs: torch.Tensor) -> torch.Tensor:
        r"""Returns ``sample_id`` of shape ``[batch_size, vocab_size]``. If
        :attr:`straight_through` is `False`, this contains the Gumbel
        softmax distributions over vocabulary with temperature
        :attr:`tau`. If :attr:`straight_through` is `True`, this contains
        one-hot vectors of the greedy samples.
        """
        noise = self._gumbel.sample(outputs.size())
        noise = noise.to(device=outputs.device, dtype=outputs.dtype)
        soft_ids = torch.softmax((outputs + noise) / self._tau, dim=-1)
        if not self._straight_through:
            return soft_ids

        # One-hot vectors in the forward pass, soft gradients in the
        # backward pass.
        top_ids = torch.argmax(soft_ids, dim=-1).unsqueeze(1)
        hard_ids = torch.zeros_like(soft_ids).scatter_(
            dim=-1, index=top_ids, value=1.0)
        return (hard_ids - soft_ids).detach() + soft_ids
def forward(self, inputs, hidden_state):
    """Compute Q-values with an output layer generated from a sampled latent.

    Steps performed:
      1. A categorical latent is computed per agent from the last
         ``n_agents`` columns of the first ``n_agents`` input rows and
         stored in ``self.latent``; it is then tiled over the batch.
      2. A second categorical is inferred from the detached hidden state
         and the remaining input columns.
      3. ``loss`` is the cross-entropy between the two, averaged over
         ``bs * n_agents``.
      4. The tiled latent is perturbed with Gumbel(0, 1) noise and
         soft-maxed, then mapped to the weights and bias of the final
         linear layer.
      5. The inputs go through ``fc1`` and the recurrent cell; the
         generated layer is applied to the new hidden state.

    NOTE(review): the last ``n_agents`` input columns are presumably the
    one-hot agent id -- confirm against the caller.

    Returns:
        ``(q, h, loss)`` with ``q`` viewed as ``(-1, n_actions)`` and ``h``
        as ``(-1, rnn_hidden_dim)``.
    """
    inputs = inputs.reshape(-1, self.input_shape)
    h_in = hidden_state.reshape(-1, self.hidden_dim)
    rnn_hidden_dim = self.args.rnn_hidden_dim
    n_actions = self.args.n_actions

    # Latent distribution per agent, (n, pi), tiled to (bs*n, pi).
    self.latent = F.softmax(
        self.embed_fc(inputs[:self.n_agents, -self.n_agents:]), dim=1)
    latent_embed = self.latent.unsqueeze(0).expand(
        self.bs, self.n_agents, self.latent_dim
    ).reshape(self.bs * self.n_agents, self.latent_dim)

    # Latent inferred from the trajectory; the hidden state is detached so
    # this branch does not back-propagate into the recurrent cell.
    latent_infer = F.relu(self.inference_fc1(
        th.cat([h_in.detach(), inputs[:, :-self.n_agents]], dim=1)))
    latent_infer = F.softmax(self.inference_fc2(latent_infer), dim=1)  # (bs*n, pi)

    # Cross-entropy between inferred and embedded latent distributions.
    loss = -(latent_infer * th.log(latent_embed + self.eps)).sum() / (
        self.bs * self.n_agents)

    # Gumbel-softmax sample.  The noise is moved to wherever the latent
    # lives; a hard-coded ``.cuda()`` fails on CPU-only hosts.
    g = Gumbel(0.0, 1.0)
    noise = g.sample(latent_embed.size()).to(
        device=latent_embed.device, dtype=latent_embed.dtype)
    latent_embed = F.softmax(th.log(latent_embed + self.eps) + noise, dim=1)

    latent = self.latent_fc1(latent_embed)

    # Per-sample weights and bias of the output layer.
    fc2_w = self.fc2_w_nn(latent)
    fc2_b = self.fc2_b_nn(latent)
    fc2_w = fc2_w.reshape(-1, rnn_hidden_dim, n_actions)
    fc2_b = fc2_b.reshape((-1, 1, n_actions))

    x = F.relu(self.fc1(inputs))  # (bs*n, (obs+act+id)) at time t
    h = self.rnn(x, h_in)
    h = h.reshape(-1, 1, rnn_hidden_dim)
    q = th.bmm(h, fc2_w) + fc2_b

    return q.view(-1, n_actions), h.view(-1, rnn_hidden_dim), loss
def beam_search(self, init, steps, beam_size, temperature=1.0,
                stochastic=False, verbose=False):
    """Generate event sequences with (optionally Gumbel-perturbed) beam search.

    Parameters
    ----------
    init : torch.Tensor
        2-D tensor whose second dimension equals ``self.init_dim``; the first
        dimension is the batch size.
    steps : int
        Number of decoding steps (must be > 0).
    beam_size : int
        Number of beams kept per batch element; must satisfy
        ``0 < beam_size <= self.event_dim``.
    temperature : float
        Logits returned by ``self.gen_forward`` are divided by this value
        before being accumulated into the beam scores.
    stochastic : bool
        If True, independent Gumbel(0, 1) noise is added to every candidate
        score before the top-``beam_size`` selection.  The beam scores that
        are carried forward (and used to pick the final beam) are always the
        unperturbed ones.
    verbose : bool
        If True, the step loop is wrapped in a ``Bar`` progress indicator.

    Returns
    -------
    torch.Tensor
        The event sequence of the highest-scoring beam for every batch
        element, transposed to ``[sequence_length, batch]``.  With the
        current initialisation ``sequence_length == beam_size + steps``.

    Notes
    -----
    NOTE(review): the search starts with ``beam_size`` identical copies of
    the initial state (``current_beam_size = beam_size``), so the initial
    event sequence is ``beam_size`` long and, in deterministic mode, the
    beams appear to stay duplicates of each other.  Starting from a single
    beam would avoid this but changes the returned length -- confirm intent
    before changing it.
    """
    assert len(init.shape) == 2 and init.shape[1] == self.init_dim
    assert self.event_dim >= beam_size > 0 and steps > 0

    batch_size = init.shape[0]
    current_beam_size = beam_size

    # Initial hidden state, replicated once per beam.
    hidden = self.init_to_hidden(init)  # [layers, batch, hid]
    # Every helper tensor is created next to the model state, so the search
    # does not depend on a module-level device setting.
    dev = hidden.device
    hidden = hidden[:, :, None, :]  # [layers, batch, 1, hid]
    hidden = hidden.repeat(1, 1, current_beam_size, 1)  # [layers, batch, cbeam, hid]

    # Initial event, replicated once per beam.
    event = self.get_primary_event(batch_size)  # [1, batch]
    event = event[:, :, None].repeat(1, 1, current_beam_size)  # [1, batch, cbeam]

    # Event history of every beam: [batch, cbeam, cbeam] at this point
    # (see the review note in the docstring).
    beam_events = event[0, :, None, :].repeat(1, current_beam_size, 1)

    # Accumulated (unperturbed) score of every beam: [batch, cbeam]
    beam_log_prob = torch.zeros(batch_size, current_beam_size, device=dev)

    if stochastic:
        # Tensor parameters make the samples live on `dev`; with Python
        # scalars the noise would be created on the CPU and the addition
        # below would fail for a GPU model.
        gumbel_dist = Gumbel(torch.tensor(0.0, device=dev),
                             torch.tensor(1.0, device=dev))

    step_iter = range(steps)
    if verbose:
        step_iter = Bar(['', 'Stochastic '][stochastic] + 'Beam Search').iter(step_iter)

    for _ in step_iter:
        # Flatten (batch, beam) so the generator sees one long batch.
        event = event.view(1, batch_size * current_beam_size)  # [1, batch*cbeam]
        hidden = hidden.view(self.rnn_layers,
                             batch_size * current_beam_size,
                             self.hidden_dim)  # [layers, batch*cbeam, hid]
        logits, hidden = self.gen_forward(event, hidden)
        hidden = hidden.view(self.rnn_layers, batch_size,
                             current_beam_size,
                             self.hidden_dim)  # [layers, batch, cbeam, hid]
        logits = (logits / temperature).view(
            1, batch_size, current_beam_size, self.event_dim)  # [1, batch, cbeam, out]

        # Score of every (beam, next event) candidate.
        beam_log_prob_expand = logits + beam_log_prob[None, :, :, None]  # [1, batch, cbeam, out]
        beam_log_prob_expand_batch = beam_log_prob_expand.view(
            1, batch_size, -1)  # [1, batch, cbeam*out]

        if stochastic:
            # Rank the candidates by Gumbel-perturbed score.
            perturbed = beam_log_prob_expand + gumbel_dist.sample(
                beam_log_prob_expand.shape)
            _, top_indices = perturbed.view(1, batch_size, -1).topk(
                beam_size, -1)  # [1, batch, beam]
        else:
            _, top_indices = beam_log_prob_expand_batch.topk(
                beam_size, -1)  # [1, batch, beam]

        # The carried-forward score is always the unperturbed one.
        beam_log_prob = torch.gather(
            beam_log_prob_expand_batch, -1, top_indices)[0]  # [batch, beam]

        # Map every flat candidate index back to the beam it came from.
        beam_index_old = torch.arange(
            current_beam_size, device=dev)[None, None, :, None]  # [1, 1, cbeam, 1]
        beam_index_old = beam_index_old.repeat(
            1, batch_size, 1, self.output_dim)  # [1, batch, cbeam, out]
        beam_index_old = beam_index_old.view(1, batch_size, -1)  # [1, batch, cbeam*out]
        beam_index_new = torch.gather(beam_index_old, -1, top_indices)  # [1, batch, beam]

        # Reorder the hidden states to follow the surviving beams.  The
        # index is expanded to the model's real layer count and hidden
        # width instead of fixed constants.
        hidden = torch.gather(
            hidden, 2,
            beam_index_new[:, :, :, None].repeat(
                self.rnn_layers, 1, 1, self.hidden_dim))

        # Map every flat candidate index back to its event id.
        event_index = torch.arange(
            self.output_dim, device=dev)[None, None, None, :]  # [1, 1, 1, out]
        event_index = event_index.repeat(
            1, batch_size, current_beam_size, 1)  # [1, batch, cbeam, out]
        event_index = event_index.view(1, batch_size, -1)  # [1, batch, cbeam*out]
        event = torch.gather(event_index, -1, top_indices)  # [1, batch, beam]

        # Reorder the histories, then append the newly chosen events.
        beam_events = torch.gather(
            beam_events[None], 2,
            beam_index_new.unsqueeze(-1).repeat(1, 1, 1, beam_events.shape[-1]))
        beam_events = torch.cat([beam_events, event.unsqueeze(-1)], -1)[0]

        current_beam_size = beam_size

    # Keep, per batch element, the beam with the highest accumulated score.
    batch_index = torch.arange(batch_size, device=beam_events.device).long()
    best = beam_events[batch_index, beam_log_prob.argmax(-1)]
    best = best.contiguous().t()
    return best
def get_raoblackwell_ps_loss(
        conditional_loss_fun, log_class_weights, topk, sample_topk,
        grad_estimator,
        grad_estimator_kwargs={'grad_estimator_kwargs': None},
        epoch=None,
        data=None):
    """
    Build a Rao-Blackwellized pseudo-loss: calling ``.backward()`` on the
    returned value yields the gradient estimate for the true loss.

    ``topk`` categories per observation are handled by an explicit weighted
    sum; the remaining probability mass is represented by one extra sampled
    category.

    Parameters
    ----------
    conditional_loss_fun : function
        Loss conditional on a realisation of the categorical variable.  It
        is not called here, only forwarded to ``grad_estimator``.
    log_class_weights : torch.Tensor
        Log class weights, batchsize x n_categories.  Every entry must be
        <= 0 (checked with an assertion).
    topk : Integer
        Number of categories that are summed explicitly.
    sample_topk : bool
        If True, the summed categories and the extra sample are drawn
        together by taking the ``topk + 1`` largest Gumbel-perturbed log
        weights (first ``topk`` are summed, the last one is the sample).
        If False, ``get_concentrated_mask`` selects the summed categories
        and ``sample_class_weights`` draws the sample from the remaining
        ones.
    grad_estimator : function
        Returns the per-observation pseudo loss for a given ``z_sample``.
        Called as ``grad_estimator(conditional_loss_fun, log_class_weights,
        class_weights, seq_tensor, z_sample=..., epoch=..., data=...,
        **grad_estimator_kwargs)``.
    grad_estimator_kwargs : dict
        Extra keyword arguments forwarded to ``grad_estimator``.
    epoch : int
        Forwarded to ``grad_estimator``.
    data : torch.Tensor
        Forwarded to ``grad_estimator``.

    Returns
    -------
    ps_loss :
        Sum of the explicitly summed term and the weighted sampled term.
        In general this value need not equal the actual loss.
    """

    # weights of the variational distribution; log weights must be <= 0
    assert np.all(log_class_weights.detach().cpu().numpy() <= 0)
    probs = torch.exp(log_class_weights.detach())

    if sample_topk:
        # Gumbel-perturb the log weights and keep the topk + 1 largest
        detached_log_w = log_class_weights.detach()
        perturbed = Gumbel(detached_log_w,
                           torch.ones_like(detached_log_w)).rsample()
        _, order = perturbed.topk(topk + 1, dim=-1)

        topk_domain = order[..., :-1]   # categories that are summed
        sample_ind = order[..., -1]     # the last one acts as the sample
        concentrated_mask = torch.zeros_like(detached_log_w).scatter(
            -1, topk_domain, 1).detach()
        row_ids = torch.arange(probs.size(0), dtype=torch.long,
                               device=probs.device)
    else:
        # indicator of the summed categories, as chosen by the helper
        concentrated_mask, topk_domain, row_ids = \
            get_concentrated_mask(probs, topk)
        concentrated_mask = concentrated_mask.float().detach()

    def _estimate(z_sample):
        # single place where the gradient estimator is invoked
        return grad_estimator(conditional_loss_fun, log_class_weights,
                              probs, row_ids,
                              z_sample=z_sample, epoch=epoch, data=data,
                              **grad_estimator_kwargs)

    ############################
    # explicitly summed categories
    summed_term = 0.0
    for column in range(topk):
        categories = topk_domain[:, column]
        grad_summed = _estimate(categories)
        weights = probs[row_ids, categories].squeeze()
        summed_term = summed_term + (grad_summed * weights).sum()

    ############################
    # sampled remainder
    # probability mass outside the summed set, one value per observation
    sampled_weight = torch.sum(probs * (1 - concentrated_mask),
                               dim=1, keepdim=True)

    if topk == probs.shape[1]:
        # every category was summed, nothing is left to sample
        grad_sampled = 0.
    else:
        if sample_topk:
            # the sample was already drawn together with the top-k set
            conditional_z_sample = sample_ind
        else:
            # weights conditioned on lying outside the summed set
            conditional_class_weights = (probs + 1e-12) * \
                (1 - concentrated_mask) / (sampled_weight + 1e-12)
            conditional_z_sample = sample_class_weights(
                conditional_class_weights)

        grad_sampled = _estimate(conditional_z_sample)

    return (grad_sampled * sampled_weight.squeeze()).sum() + summed_term