def gumbel_sample(self, data):
    """Use Gumbel sampling to pick one of the mixtures based on their mixture weights."""
    distribution = Gumbel(loc=0, scale=1)
    z = distribution.sample()
    # z = np.random.gumbel(loc=0, scale=1, size=data.shape)
    return (torch.log(data) + z).argmax(dim=1)
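# Illustrative sketch (not part of the original source): the Gumbel-max trick used by
# gumbel_sample above. Adding standard Gumbel noise to the log-weights and taking the
# argmax is equivalent to sampling a component index from Categorical(weights).
import torch
from torch.distributions import Gumbel

weights = torch.tensor([[0.1, 0.2, 0.7]])            # mixture weights, one row per batch item
g = Gumbel(0.0, 1.0).sample(weights.shape)           # i.i.d. Gumbel(0, 1) noise
component = (torch.log(weights) + g).argmax(dim=1)   # distributed as Categorical(weights)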
def log_sample(self, x, u, n, alpha=1., xg=None, ug=None):
    c = self.c.cost(x, u, xg, ug)
    costs = c.reshape(-1, 1)  # unnormalized log probabilities
    samples = Gumbel(loc=0., scale=1.).sample((x.shape[0], n))
    log_gumbel = alpha * costs + samples
    _, choices = torch.max(log_gumbel, 0)
    return choices, alpha * costs
def __init__(self, do_entropy=False):
    super().__init__()
    self.gumbel = Gumbel(loc=0, scale=1)
    self.do_entropy = do_entropy
    if self.do_entropy:
        self.entropy = None
def __init__(self,
             start_tokens: torch.LongTensor,
             end_token: Union[int, torch.LongTensor],
             tau: float,
             straight_through: bool = False,
             stop_gradient: bool = False,
             use_finish: bool = True):
    super().__init__(start_tokens, end_token, tau, stop_gradient, use_finish)
    self._straight_through = straight_through
    # unit-scale, zero-location Gumbel distribution
    self._gumbel = Gumbel(loc=torch.tensor(0.0), scale=torch.tensor(1.0))
def gumbel_with_maximum(phi, T, dim=-1):
    """
    Samples a set of Gumbels which are conditioned on having a maximum along a dimension.
    phi.max(dim)[0] should be broadcastable with the desired maximum T.
    """
    # Gumbel with location phi; use PyTorch distributions so we cannot get -inf or inf
    # (which causes trouble)
    g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()
    Z, argmax = g_phi.max(dim)
    g = _shift_gumbel_maximum(g_phi, T, dim, Z=Z)
    return g, argmax
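# Numeric sketch (an assumption about the helper _shift_gumbel_maximum, which is not shown
# in this excerpt): the naive form of the shift maps Gumbels g with observed maximum Z to
# -log(exp(-T) - exp(-Z) + exp(-g)), whose maximum is exactly the desired value T.
import torch
from torch.distributions import Gumbel

phi = torch.randn(5)
T = torch.tensor(0.0)                                  # desired maximum
g = Gumbel(phi, torch.ones_like(phi)).sample()
Z = g.max()
shifted = -torch.log(torch.exp(-T) - torch.exp(-Z) + torch.exp(-g))
assert torch.isclose(shifted.max(), T)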
def __init__(self, bias=None):
    """
    :param bias: a vector of log frequencies, to bias the sampling towards what we want
    """
    super().__init__()
    self.gumbel = Gumbel(loc=0, scale=1)
    if bias is not None:
        self.bias = torch.from_numpy(np.array(bias)).to(
            dtype=torch.float32, device=device).unsqueeze(0)
    else:
        self.bias = None
def __init__(self, temperature=1.0, eps=0.0, do_entropy=True):
    super().__init__()
    self.gumbel = Gumbel(loc=0, scale=1)
    self.temperature = torch.tensor(temperature, requires_grad=False)
    self.eps = eps
    self.do_entropy = do_entropy
    if self.do_entropy:
        self.entropy = None
def cond_gumbel_sample(all_joint_log_probs, perturbed_log_probs) -> torch.Tensor:
    # Sample plates x k? x |D_yv| Gumbel variables
    gumbel_d = Gumbel(loc=all_joint_log_probs, scale=1.0)
    G_yv = gumbel_d.rsample()
    # Condition the Gumbel samples on the maximum of previous samples
    # plates x k
    Z = G_yv.max(dim=-1)[0]
    T = perturbed_log_probs
    vi = T - G_yv + log1mexp(G_yv - Z.unsqueeze(-1))
    # plates (x k) x |D_yv|
    return T - vi.relu() - torch.nn.Softplus()(-vi.abs())
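# cond_gumbel_sample calls log1mexp, which is not defined in this excerpt. A common
# numerically stable implementation of log(1 - exp(x)) for x <= 0 (an assumption,
# following the standard two-branch formula) is:
import torch

def log1mexp(x: torch.Tensor) -> torch.Tensor:
    # For x > -log(2), log(-expm1(x)) is more accurate; otherwise use log1p(-exp(x)).
    return torch.where(x > -0.6931471805599453,
                       torch.log(-torch.expm1(x)),
                       torch.log1p(-torch.exp(x)))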
def __init__(
    self,
    shape: float = 1.0,
    topk: Optional[int] = None,
    equiv_len: Optional[int] = None,
    log_scores: bool = False,
):
    """FréchetSort is a softer version of descending sort which samples all possible
    orderings of items, favoring orderings which resemble descending sort. This can be
    used to convert descending sort by rank score into a differentiable, stochastic
    policy amenable to policy gradient algorithms.

    :param shape: shape parameter of the Fréchet distribution. Lower values correspond
        to aggressive deviations from descending sort.
    :param topk: If specified, only the first topk actions are returned.
    :param equiv_len: Orderings are considered equivalent if the top equiv_len entries
        match. Used in probability computations. Essentially specifies the action space.
    :param log_scores: Scores passed in are already log-transformed. In this case, we
        simply add Gumbel noise. For LearnVM, we set this to True because we expect
        input and output scores to be in the log space.

    Example:

        sampler = FrechetSort(shape=3, topk=5, equiv_len=3)

    Given a set of scores, this sampler will produce indices of items roughly
    resembling an argsort by scores in descending order. The higher the shape, the
    more it resembles a descending argsort. `topk=5` means only the top 5 ranks will
    be output. `equiv_len` determines which orderings are considered equivalent for
    probability computation. In this example, the sampler will produce the probability
    of the top 3 items appearing in a given order for the `log_prob` call.
    """
    self.shape = shape
    self.topk = topk
    self.upto = equiv_len
    if topk is not None:
        if equiv_len is None:
            self.upto = topk
        # pyre-fixme[58]: `>` is not supported for operand types `Optional[int]`
        #  and `Optional[int]`.
        if self.upto > self.topk:
            raise ValueError(f"Equiv length {equiv_len} cannot exceed topk={topk}.")
    self.gumbel_noise = Gumbel(0, 1.0 / shape)
    self.log_scores = log_scores
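# Minimal sketch (an assumption, not the class's actual sample() method): the constructor
# above prepares Gumbel(0, 1/shape) noise, and a Fréchet-sort style sample is a descending
# argsort of the (log-)scores after perturbing them with that noise.
import torch
from torch.distributions import Gumbel

scores = torch.tensor([2.0, 0.5, 1.0, 3.0])
shape = 3.0
noise = Gumbel(0.0, 1.0 / shape).sample(scores.shape)
ranking = torch.argsort(torch.log(scores) + noise, descending=True)  # log() since log_scores=False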
def sample_gumbel_softmax(logits, temperature, n_sample):
    """
    Input:
        logits: Tensor of log probs, shape = BS x k
        temperature: scalar
    Output:
        Tensor of values sampled from the Gumbel softmax. These tend towards a
        one-hot representation in the limit of temp -> 0.
        shape = n_sample x BS x k
    """
    g = Gumbel(torch.zeros(*logits.shape),
               torch.ones(*logits.shape)).rsample((n_sample, )).to(device)
    h = (g + logits) / temperature
    y = F.softmax(h, dim=-1)
    return y
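# Illustrative sketch (not from the source): with a fixed Gumbel draw, lowering the
# temperature makes the Gumbel-softmax sample approach a one-hot vector.
import torch
import torch.nn.functional as F
from torch.distributions import Gumbel

logits = torch.log(torch.tensor([0.2, 0.3, 0.5]))
g = Gumbel(torch.zeros_like(logits), torch.ones_like(logits)).rsample()
for temperature in (1.0, 0.1, 0.01):
    y = F.softmax((logits + g) / temperature, dim=-1)
    print(temperature, y)   # increasingly close to one-hot as temperature -> 0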
def reinforce_unordered(conditional_loss_fun, log_class_weights,
                        class_weights_detached, seq_tensor, z_sample,
                        epoch, data, n_samples=1,
                        baseline_separate=False, baseline_n_samples=1,
                        baseline_deterministic=False, baseline_constant=None):
    # Sample without replacement using the Gumbel top-k trick
    phi = log_class_weights.detach()
    g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()
    _, ind = g_phi.topk(n_samples, -1)
    log_p = log_class_weights.gather(-1, ind)

    n_classes = log_class_weights.shape[1]
    costs = torch.stack([
        conditional_loss_fun(get_one_hot_encoding_from_int(z_sample, n_classes))
        for z_sample in ind.t()
    ], -1)

    with torch.no_grad():  # Don't compute gradients for advantage and ratio
        # log_R_s, log_R_ss = compute_log_R(log_p)
        log_R_s, log_R_ss = compute_log_R_O_nfac(log_p)

        if baseline_constant is not None:
            bl_vals = baseline_constant
        elif baseline_separate:
            bl_vals = get_baseline(conditional_loss_fun, log_class_weights,
                                   baseline_n_samples, baseline_deterministic)
            # Same baseline for all samples, so add a dimension
            bl_vals = bl_vals[:, None]
        elif log_p.size(-1) > 1:
            # Compute the built-in baseline
            bl_vals = ((log_p[:, None, :] + log_R_ss).exp() * costs[:, None, :]).sum(-1)
        else:
            bl_vals = 0.  # No baseline
        adv = costs - bl_vals

    # Also add the costs (with the unordered estimator) in case there is a direct
    # dependency on the parameters
    loss = ((log_p + log_R_s).exp() * adv.detach()
            + (log_p + log_R_s).exp().detach() * costs).sum(-1)
    return loss
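# Standalone sketch (assumption) of the Gumbel top-k trick used above: perturbing the
# log-probabilities with Gumbel noise and keeping the k largest entries yields k distinct
# categories per row, i.e. a sample without replacement.
import torch
from torch.distributions import Gumbel

log_p = torch.log_softmax(torch.randn(4, 10), dim=-1)   # batch of 4 categorical distributions
g = Gumbel(log_p, torch.ones_like(log_p)).rsample()     # Gumbel noise with location log_p
_, topk_idx = g.topk(3, dim=-1)                          # 3 distinct indices per row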
def sample_stgs(proba_no_softmax, K, temp):
    device = proba_no_softmax.device
    dims = (proba_no_softmax.size(0), K)
    # sample with Gumbel-softmax
    mean = torch.zeros(dims, device=device)
    scale = torch.ones(dims, device=device)
    samp_gumb = Gumbel(mean, scale).sample()
    g = samp_gumb.detach()
    out = (proba_no_softmax + g) / temp
    # soft approximation
    y = F.softmax(out, dim=-1)
    # use argmax at the forward pass
    symb = y.argmax(dim=-1)
    # go through one_hot for the backward pass
    one_hot_symb = one_hot(symb, K)
    sample = (one_hot_symb - y).detach() + y
    return sample
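# Sketch (assumption) of the straight-through behaviour above: the forward value is a hard
# one-hot sample, while gradients flow through the soft softmax `y` back to the logits.
import torch
import torch.nn.functional as F
from torch.distributions import Gumbel

logits = torch.randn(2, 5, requires_grad=True)
g = Gumbel(torch.zeros_like(logits), torch.ones_like(logits)).sample()
y = F.softmax((logits + g) / 0.5, dim=-1)
hard = F.one_hot(y.argmax(dim=-1), num_classes=5).float()
sample = (hard - y).detach() + y        # forward: one-hot; backward: gradient of y
sample.sum().backward()
assert logits.grad is not None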
def __init__(self, shape: float = 1, upto: Optional[int] = None):
    """Samples an ordering.

    Args:
        shape: 1/scale parameter of the Gumbel distribution.
            At lower temperatures, the order is closer to argmax(scores).
        upto: [Default: None] If provided, position up to which log_prob is computed.

    Returns:
        An action sampler which produces an ordering.
    """
    self.shape = shape
    self.upto = upto
    self.gumbel_noise = Gumbel(0, 1.0 / shape)
def sample_model_spec(self, num):
    """
    Override: sample the alphas via Gumbel softmax instead of the normal softmax.

    :param num:
    :return:
    """
    alpha_topology = self.alpha_topology.detach().clone()
    alpha_ops = self.alpha_ops.detach().clone()
    sample_archs = []
    sample_ops = []
    gumbel_dist = Gumbel(torch.tensor([.0]), torch.tensor([1.0]))
    with torch.no_grad():
        for i in range(self.num_intermediate_nodes):
            # align with topology weights
            probs = gumbel_softmax(alpha_topology[: i + 2, i], self.temperature(), gumbel_dist)
            sample_archs.append(Categorical(probs))
            probs_op = gumbel_softmax(alpha_ops[:, i], self.temperature(), gumbel_dist)
            sample_ops.append(Categorical(probs_op))
    return self._sample_model_spec(num, sample_archs, sample_ops)
def forward(self, inputs, hidden_state):
    inputs = inputs.reshape(-1, self.input_shape)
    h_in = hidden_state.reshape(-1, self.hidden_dim)

    # compute latent
    self.latent = F.softmax(self.embed_fc(inputs[:self.n_agents, -self.n_agents:]), dim=1)  # (n, pi)
    latent_embed = self.latent.unsqueeze(0).expand(self.bs, self.n_agents, self.latent_dim).reshape(
        self.bs * self.n_agents, self.latent_dim)  # (bs*n, pi)

    latent_infer = F.relu(self.inference_fc1(th.cat([h_in.detach(), inputs[:, :-self.n_agents]], dim=1)))
    latent_infer = F.softmax(self.inference_fc2(latent_infer), dim=1)  # (bs*n, pi)

    # self.latent = self.embed_fc(inputs[:self.n_agents, -self.n_agents:])  # (n, 2*latent_dim) == (n, mu + log var)
    # self.latent[:, -self.latent_dim:] = th.exp(self.latent[:, -self.latent_dim:])  # var
    # latent_embed = self.latent.unsqueeze(0).expand(self.bs, self.n_agents, self.latent_dim * 2).reshape(
    #     self.bs * self.n_agents, self.latent_dim * 2)
    # latent_infer = F.relu(self.inference_fc1(th.cat([h_in, inputs[:, :-self.n_agents]], dim=1)))
    # latent_infer = self.inference_fc2(latent_infer)  # (n, 2*latent_dim) == (n, mu + log var)
    # latent_infer[:, -self.latent_dim:] = th.exp(latent_infer[:, -self.latent_dim:])

    # loss
    # loss = (latent_embed - latent_infer).norm(dim=1).sum()
    # loss = -(latent_embed * th.log(latent_infer + self.eps)).sum() / (self.bs * self.n_agents)
    loss = -(latent_infer * th.log(latent_embed + self.eps)).sum() / (self.bs * self.n_agents)
    # loss = -((latent_infer * th.log(latent_embed + self.eps)).sum()
    #          + 0.01 * (latent_embed * th.log(latent_embed + self.eps)).sum()) / (self.bs * self.n_agents)

    # sample
    g = Gumbel(0.0, 1.0)
    latent_embed = F.softmax(th.log(latent_embed + self.eps) + g.sample(latent_embed.size()).cuda(), dim=1)  # softmax onehot
    # latent_infer = F.softmax(th.log(latent_infer + self.eps) + g.sample(latent_infer.size()), dim=1)

    # gaussian_embed = D.Normal(latent_embed[:, :self.latent_dim], (latent_embed[:, self.latent_dim:]) ** (1 / 2))
    # gaussian_infer = D.Normal(latent_infer[:, :self.latent_dim], (latent_infer[:, self.latent_dim:]) ** (1 / 2))
    # loss = gaussian_embed.entropy().sum() + kl_divergence(gaussian_embed, gaussian_infer).sum()  # CE = H + KL
    # loss = loss / (self.bs * self.n_agents)

    # handcrafted reparameterization
    # (1, n*latent_dim) (1, n*latent_dim) ==> (bs, n*latent_dim)
    # latent_embed = self.latent[:, :self.latent_dim].reshape(1, -1) \
    #     + self.latent[:, -self.latent_dim:].reshape(1, -1) * th.randn(self.bs, self.n_agents * self.latent_dim)
    # latent_embed = latent_embed.reshape(-1, self.latent_dim)  # (bs*n, latent_dim)
    # latent_infer = latent_infer[:, :self.latent_dim] \
    #     + latent_infer[:, -self.latent_dim:] * th.randn_like(latent_infer[:, -self.latent_dim:])
    # loss = (latent_embed - latent_infer).norm(dim=1).sum() / (self.bs * self.n_agents)

    # latent = gaussian_embed.rsample()
    latent = self.latent_fc1(latent_embed)
    # latent = F.relu(self.latent_fc1(latent))
    # latent = self.latent_fc2(latent)
    # latent = latent.reshape(-1, self.args.latent_dim)

    # fc1_w = F.relu(self.fc1_w_nn(latent))
    # fc1_b = F.relu(self.fc1_b_nn(latent))
    # fc1_w = fc1_w.reshape(-1, self.input_shape, self.args.rnn_hidden_dim)
    # fc1_b = fc1_b.reshape(-1, 1, self.args.rnn_hidden_dim)

    # rnn_ih_w = F.relu(self.rnn_ih_w_nn(latent))
    # rnn_ih_b = F.relu(self.rnn_ih_b_nn(latent))
    # rnn_hh_w = F.relu(self.rnn_hh_w_nn(latent))
    # rnn_hh_b = F.relu(self.rnn_hh_b_nn(latent))
    # rnn_ih_w = rnn_ih_w.reshape(-1, self.args.rnn_hidden_dim, self.args.rnn_hidden_dim)
    # rnn_ih_b = rnn_ih_b.reshape(-1, 1, self.args.rnn_hidden_dim)
    # rnn_hh_w = rnn_hh_w.reshape(-1, self.args.rnn_hidden_dim, self.args.rnn_hidden_dim)
    # rnn_hh_b = rnn_hh_b.reshape(-1, 1, self.args.rnn_hidden_dim)

    fc2_w = self.fc2_w_nn(latent)
    fc2_b = self.fc2_b_nn(latent)
    fc2_w = fc2_w.reshape(-1, self.args.rnn_hidden_dim, self.args.n_actions)
    fc2_b = fc2_b.reshape((-1, 1, self.args.n_actions))

    # x = F.relu(th.bmm(inputs, fc1_w) + fc1_b)  # (bs*n, (obs+act+id)) at time t
    x = F.relu(self.fc1(inputs))  # (bs*n, (obs+act+id)) at time t

    # gi = th.bmm(x, rnn_ih_w) + rnn_ih_b
    # gh = th.bmm(h_in, rnn_hh_w) + rnn_hh_b
    # i_r, i_i, i_n = gi.chunk(3, 2)
    # h_r, h_i, h_n = gh.chunk(3, 2)
    # resetgate = th.sigmoid(i_r + h_r)
    # inputgate = th.sigmoid(i_i + h_i)
    # newgate = th.tanh(i_n + resetgate * h_n)
    # h = newgate + inputgate * (h_in - newgate)
    # h = th.tanh(gi + gh)
    # x = x.reshape(-1, self.args.rnn_hidden_dim)

    h = self.rnn(x, h_in)
    h = h.reshape(-1, 1, self.args.rnn_hidden_dim)
    q = th.bmm(h, fc2_w) + fc2_b

    # h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim)  # (bs, n, dim) ==> (bs*n, dim)
    # h = self.rnn(x, h_in)
    # q = self.fc2(h)

    return q.view(-1, self.args.n_actions), h.view(-1, self.args.rnn_hidden_dim), loss
def reinforce_sum_and_sample(conditional_loss_fun, log_class_weights,
                             class_weights_detached, seq_tensor, z_sample,
                             epoch, data, n_samples=1,
                             baseline_separate=False, baseline_n_samples=1,
                             baseline_deterministic=False, rao_blackwellize=False):
    # Sample without replacement using the Gumbel top-k trick
    phi = log_class_weights.detach()
    g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()
    _, ind = g_phi.topk(n_samples, -1)
    log_p = log_class_weights.gather(-1, ind)

    n_classes = log_class_weights.shape[1]
    costs = torch.stack([
        conditional_loss_fun(get_one_hot_encoding_from_int(z_sample, n_classes))
        for z_sample in ind.t()
    ], -1)

    with torch.no_grad():  # Don't compute gradients for advantage and ratio
        if baseline_separate:
            bl_vals = get_baseline(conditional_loss_fun, log_class_weights,
                                   baseline_n_samples, baseline_deterministic)
            # Same baseline for all samples, so add a dimension
            bl_vals = bl_vals[:, None]
        else:
            assert baseline_n_samples < n_samples
            bl_sampled_weight = log1mexp(
                log_p[:, :baseline_n_samples - 1].logsumexp(-1)).exp().detach()
            bl_vals = (log_p[:, :baseline_n_samples - 1].exp()
                       * costs[:, :baseline_n_samples - 1]).sum(-1) \
                + bl_sampled_weight * costs[:, baseline_n_samples - 1]
            bl_vals = bl_vals[:, None]

    # We compute an 'exact' gradient if the sum of probabilities is roughly more than
    # 1 - 1e-5, in which case we can simply sum all the terms and the relative error
    # will be < 1e-5
    use_exact = log_p.logsumexp(-1) > -1e-5
    not_use_exact = use_exact == 0

    cost_exact = costs[use_exact]
    exact_loss = compute_summed_terms(log_p[use_exact], cost_exact,
                                      cost_exact - bl_vals[use_exact])

    log_p_est = log_p[not_use_exact]
    costs_est = costs[not_use_exact]
    bl_vals_est = bl_vals[not_use_exact]
    if rao_blackwellize:
        ap = all_perms(torch.arange(n_samples, dtype=torch.long),
                       device=log_p_est.device)
        log_p_ap = log_p_est[:, ap]
        bl_vals_ap = bl_vals_est.expand_as(costs_est)[:, ap]
        costs_ap = costs_est[:, ap]
        cond_losses = compute_sum_and_sample_loss(log_p_ap, costs_ap, bl_vals_ap)

        # Compute probabilities for permutations
        log_probs_perms = log_pl_rec(log_p_ap, -1)
        cond_log_probs_perms = log_probs_perms - log_probs_perms.logsumexp(-1, keepdim=True)
        losses = (cond_losses * cond_log_probs_perms.exp()).sum(-1)
    else:
        losses = compute_sum_and_sample_loss(log_p_est, costs_est, bl_vals_est)

    # If they are summed we can simply concatenate, but for consistency it is best to
    # place them in order
    all_losses = log_p.new_zeros(log_p.size(0))
    all_losses[use_exact] = exact_loss
    all_losses[not_use_exact] = losses
    return all_losses
def ops_weights(self, node):
    gumbel_dist = Gumbel(torch.tensor([.0]), torch.tensor([1.0]))
    return gumbel_softmax(self.alpha_ops[:, node], self.temperature(), gumbel_dist)
def topology_weights(self, node):
    # return the soft weights for topology aggregation
    gumbel_dist = Gumbel(torch.tensor([.0]), torch.tensor([1.0]))
    return gumbel_softmax(self.alpha_topology[: node + 2, node], self.temperature(), gumbel_dist)[1:]
def rsample_gumbel(
    distr: Distribution,
    n: int,
) -> torch.Tensor:
    gumbel_distr = Gumbel(distr.logits, 1)
    return gumbel_distr.rsample((n, ))
def gumbel_softmax_sample(logits, temp=1.):
    g = Gumbel(0, 1).sample(logits.shape)
    y = (g + logits) / temp
    return torch.softmax(y, dim=-1)
def sample_genotype_index(self):
    with torch.no_grad():
        sampled_alphas = gumbel_softmax_sample(
            self.gumble_arch_params.squeeze(),
            temperature=self.gumbel_temperature,
            dim=0)
        best_sampled_alphas = torch.argmax(sampled_alphas, dim=0)
        return best_sampled_alphas.detach()


class MixedSequential(torch.nn.Sequential):
    def forward(self, input):
        total_cost = 0
        for module in self._modules.values():
            input = module(input)
            if isinstance(module, MixedModule):
                input, cost = input
                total_cost += cost
        return input


gumbel_dist = Gumbel(0.0, 1.0)


def gumbel_softmax_sample(logits, temperature, dim=None, std=1.0):
    y = logits + gumbel_dist.sample(logits.shape).to(device=logits.device, dtype=logits.dtype)
    # y = logits + sample_gumbel(logits=logits, std=std)
    return F.softmax(y / temperature, dim=dim)
def get_raoblackwell_ps_loss(conditional_loss_fun, log_class_weights,
                             topk, sample_topk, grad_estimator,
                             grad_estimator_kwargs={'grad_estimator_kwargs': None},
                             epoch=None, data=None):
    """
    Returns a pseudo-loss such that the gradient obtained by calling
    pseudo_loss.backwards() is unbiased for the true loss.

    Parameters
    ----------
    conditional_loss_fun : function
        A function that returns the loss conditional on an instance of the
        categorical random variable. It must take in a one-hot-encoding matrix
        (batchsize x n_categories) and return a vector of losses, one for each
        observation in the batch.
    log_class_weights : torch.Tensor
        A tensor of shape batchsize x n_categories of the log class weights.
    topk : int
        The number of categories to sum over.
    sample_topk : bool
        Whether to sample the top-k categories with the Gumbel top-k trick
        instead of taking the k largest class weights.
    grad_estimator : function
        A function that returns the pseudo-loss, that is, the loss which gives a
        gradient estimator when .backwards() is called.
        See baselines_lib for details.
    grad_estimator_kwargs : dict
        Keyword arguments to the gradient estimator.
    epoch : int
        The epoch of the optimizer (for Gumbel-softmax, which has an annealing rate).
    data : torch.Tensor
        The data at which we evaluate the loss (for NVIL and RELAX, which have a
        data-dependent baseline).

    Returns
    -------
    ps_loss :
        A value such that ps_loss.backward() returns an estimate of the gradient.
        In general, ps_loss might not equal the actual loss.
    """
    # class weights from the variational distribution
    assert np.all(log_class_weights.detach().cpu().numpy() <= 0)
    class_weights = torch.exp(log_class_weights.detach())

    if sample_topk:
        # perturb the log_class_weights
        phi = log_class_weights.detach()
        g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()
        _, ind = g_phi.topk(topk + 1, dim=-1)
        topk_domain = ind[..., :-1]
        concentrated_mask = torch.zeros_like(phi).scatter(-1, topk_domain, 1).detach()
        sample_ind = ind[..., -1]  # Last sample we use as the real sample
        seq_tensor = torch.arange(class_weights.size(0), dtype=torch.long,
                                  device=class_weights.device)
    else:
        # this is the indicator C_k
        concentrated_mask, topk_domain, seq_tensor = \
            get_concentrated_mask(class_weights, topk)
        concentrated_mask = concentrated_mask.float().detach()

    ############################
    # compute the summed term
    summed_term = 0.0
    for i in range(topk):
        # get categories to be summed
        summed_indx = topk_domain[:, i]
        # compute gradient estimate
        grad_summed = grad_estimator(
            conditional_loss_fun, log_class_weights, class_weights, seq_tensor,
            z_sample=summed_indx, epoch=epoch, data=data, **grad_estimator_kwargs)
        # sum
        summed_weights = class_weights[seq_tensor, summed_indx].squeeze()
        summed_term = summed_term + (grad_summed * summed_weights).sum()

    ############################
    # compute the sampled term
    sampled_weight = torch.sum(class_weights * (1 - concentrated_mask),
                               dim=1, keepdim=True)
    if not (topk == class_weights.shape[1]):
        # if we didn't sum everything, we sample from the remaining terms
        if not sample_topk:
            # class weights conditioned on being in the diffuse set
            conditional_class_weights = (class_weights + 1e-12) * \
                (1 - concentrated_mask) / (sampled_weight + 1e-12)
            # sample from the conditional distribution
            conditional_z_sample = sample_class_weights(conditional_class_weights)
        else:
            conditional_z_sample = sample_ind  # We have already sampled it

        grad_sampled = grad_estimator(
            conditional_loss_fun, log_class_weights, class_weights, seq_tensor,
            z_sample=conditional_z_sample, epoch=epoch, data=data,
            **grad_estimator_kwargs)
    else:
        grad_sampled = 0.

    return (grad_sampled * sampled_weight.squeeze()).sum() + summed_term
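# Numeric sketch (not from the source) of the decomposition this pseudo-loss relies on:
# E[f(z)] = sum over the top-k categories of p_k * f(k), plus the remaining probability
# mass times the conditional expectation of f over the non-top-k categories.
import torch

p = torch.tensor([0.5, 0.3, 0.15, 0.05])
f = torch.tensor([1.0, 2.0, 3.0, 4.0])
topk = 2
top_idx = p.topk(topk).indices
mask = torch.zeros_like(p)
mask[top_idx] = 1.0
summed = (p * f * mask).sum()                         # exactly summed top-k term
rest_mass = (p * (1 - mask)).sum()
cond = (p * (1 - mask) / rest_mass * f).sum()         # E[f(z) | z outside top-k]
assert torch.isclose(summed + rest_mass * cond, (p * f).sum())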
def beam_search(self, init, steps, beam_size, temperature=1.0,
                stochastic=False, verbose=False):
    assert len(init.shape) == 2 and init.shape[1] == self.init_dim
    assert self.event_dim >= beam_size > 0 and steps > 0

    batch_size = init.shape[0]
    current_beam_size = beam_size

    # Initial hidden weights
    hidden = self.init_to_hidden(init)   # [gru_layers, batch_size, hidden_size]
    hidden = hidden[:, :, None, :]        # [gru_layers, batch_size, 1, hidden_size]
    hidden = hidden.repeat(1, 1, current_beam_size, 1)  # [gru_layers, batch_size, beam_size, hidden_dim]

    # Initial event
    event = self.get_primary_event(batch_size)                 # [1, batch]
    event = event[:, :, None].repeat(1, 1, current_beam_size)  # [1, batch, 1]

    # [batch, beam, 1] event sequences of beams
    beam_events = event[0, :, None, :].repeat(1, current_beam_size, 1)

    # [batch, beam] log probs sum of beams
    beam_log_prob = torch.zeros(batch_size, current_beam_size).to(device)

    if stochastic:
        # [batch, beam] Gumbel perturbed log probs of beams
        beam_log_prob_perturbed = torch.zeros(batch_size, current_beam_size).to(device)
        beam_z = torch.full((batch_size, beam_size), float('inf'))
        gumbel_dist = Gumbel(0, 1)

    step_iter = range(steps)
    if verbose:
        step_iter = Bar(['', 'Stochastic '][stochastic] + 'Beam Search').iter(step_iter)

    for step in step_iter:
        event = event.view(1, batch_size * current_beam_size)  # [1, batch*beam0]
        hidden = hidden.view(self.rnn_layers, batch_size * current_beam_size,
                             self.hidden_dim)                   # [grus, batch*beam, hid]

        logits, hidden = self.gen_forward(event, hidden)
        hidden = hidden.view(self.rnn_layers, batch_size, current_beam_size,
                             self.hidden_dim)                   # [grus, batch, cbeam, hid]
        logits = (logits / temperature).view(
            1, batch_size, current_beam_size, self.event_dim)   # [1, batch, cbeam, out]

        beam_log_prob_expand = logits + beam_log_prob[None, :, :, None]  # [1, batch, cbeam, out]
        beam_log_prob_expand_batch = beam_log_prob_expand.view(
            1, batch_size, -1)                                   # [1, batch, cbeam*out]

        if stochastic:
            beam_log_prob_expand_perturbed = beam_log_prob_expand + gumbel_dist.sample(
                beam_log_prob_expand.shape)
            beam_log_prob_Z, _ = beam_log_prob_expand_perturbed.max(-1)  # [1, batch, cbeam]
            # print(beam_log_prob_Z)
            beam_log_prob_expand_perturbed_normalized = beam_log_prob_expand_perturbed
            # beam_log_prob_expand_perturbed_normalized = -torch.log(
            #     torch.exp(-beam_log_prob_perturbed[None, :, :, None])
            #     - torch.exp(-beam_log_prob_Z[:, :, :, None])
            #     + torch.exp(-beam_log_prob_expand_perturbed))  # [1, batch, cbeam, out]
            # beam_log_prob_expand_perturbed_normalized = beam_log_prob_perturbed[None, :, :, None] \
            #     + beam_log_prob_expand_perturbed  # [1, batch, cbeam, out]

            beam_log_prob_expand_perturbed_normalized_batch = \
                beam_log_prob_expand_perturbed_normalized.view(1, batch_size, -1)  # [1, batch, cbeam*out]

            _, top_indices = beam_log_prob_expand_perturbed_normalized_batch.topk(
                beam_size, -1)  # [1, batch, cbeam]

            beam_log_prob_perturbed = \
                torch.gather(beam_log_prob_expand_perturbed_normalized_batch, -1,
                             top_indices)[0]  # [batch, beam]
        else:
            _, top_indices = beam_log_prob_expand_batch.topk(beam_size, -1)

        top_indices.to(device)
        beam_log_prob = torch.gather(beam_log_prob_expand_batch, -1,
                                     top_indices)[0]  # [batch, beam]

        beam_index_old = torch.arange(current_beam_size, device=device)[None, None, :, None]  # [1, 1, cbeam, 1]
        beam_index_old = beam_index_old.repeat(1, batch_size, 1, self.output_dim)  # [1, batch, cbeam, out]
        beam_index_old = beam_index_old.view(1, batch_size, -1)  # [1, batch, cbeam*out]
        # beam_index_old.to(device)
        # print(device)
        # print(beam_index_old.device)
        # print(top_indices.device)
        beam_index_new = torch.gather(beam_index_old, -1, top_indices)

        hidden = torch.gather(hidden, 2,
                              beam_index_new[:, :, :, None].repeat(4, 1, 1, 1024))

        event_index = torch.arange(self.output_dim, device=device)[None, None, None, :]  # [1, 1, 1, out]
        event_index = event_index.repeat(1, batch_size, current_beam_size, 1)  # [1, batch, cbeam, out]
        event_index = event_index.view(1, batch_size, -1)  # [1, batch, cbeam*out]
        event_index.to(device)
        event = torch.gather(event_index, -1, top_indices)  # [1, batch, cbeam*out]

        beam_events = torch.gather(
            beam_events[None], 2,
            beam_index_new.unsqueeze(-1).repeat(1, 1, 1, beam_events.shape[-1]))
        beam_events = torch.cat([beam_events, event.unsqueeze(-1)], -1)[0]

        current_beam_size = beam_size

    best = beam_events[torch.arange(batch_size).long(), beam_log_prob.argmax(-1)]
    best = best.contiguous().t()
    return best
import numpy as np
import torch
from torch import nn
from torch.distributions import Normal, Gumbel

_normalization_factor = float(1.0 / np.sqrt(2.0 * np.pi))
_std_normal = Normal(0.0, 1.0)
_gumbel = Gumbel(0.0, 1.0)


class MDN(nn.Module):
    def __init__(self, n_inputs, n_outputs, n_components, hidden_units=None):
        super(MDN, self).__init__()
        hidden_units = list(hidden_units or [20])
        self.n_inputs = n_inputs
        self.n_components = n_components
        self.n_outputs = n_outputs
        sizes = [n_inputs] + hidden_units
        self.h = nn.Sequential(*[
            nn.Sequential(nn.Linear(sizes[i], sizes[i + 1]), nn.Tanh())
            for i in range(len(sizes) - 1)
        ])
        self.out_pi = nn.Sequential(
            nn.Linear(hidden_units[-1], n_outputs * n_components),
            nn.Softmax(dim=-1))
        self.out_mu = nn.Linear(hidden_units[-1], n_outputs * n_components)
        self.out_sig2 = nn.Linear(hidden_units[-1], n_outputs * n_components)

    def forward(self,