def forward(self, x):
    prev_hidden = self.agent(x)
    prev_c = torch.zeros_like(prev_hidden)  # only used by the LSTM cell
    e_t = torch.stack([self.sos_embedding] * prev_hidden.size(0))
    sequence = []

    for step in range(self.max_len):
        if isinstance(self.cell, nn.LSTMCell):
            h_t, prev_c = self.cell(e_t, (prev_hidden, prev_c))
        else:
            h_t = self.cell(e_t, prev_hidden)

        step_logits = F.log_softmax(self.hidden_to_output(h_t), dim=1)
        distr = RelaxedOneHotCategorical(logits=step_logits,
                                         temperature=self.temperature)

        if self.training:
            x = distr.rsample()
        else:
            x = torch.zeros_like(step_logits).scatter_(
                -1, step_logits.argmax(dim=-1, keepdim=True), 1.0)

        prev_hidden = h_t
        e_t = self.embedding(x)
        sequence.append(x)

    sequence = torch.stack(sequence).permute(1, 0, 2)

    if self.force_eos:
        eos = torch.zeros_like(sequence[:, 0, :]).unsqueeze(1)
        eos[:, 0, 0] = 1
        sequence = torch.cat([sequence, eos], dim=1)

    return sequence
def multi_dist_sample(self, logits, temp, return_only_action):
    action = []
    log_prob = []
    for i in range(0, self.num_heads * self.num_actions, self.num_heads):
        dist = RelaxedOneHotCategorical(logits=logits[:, i:i + self.num_heads],
                                        temperature=temp)
        raw_action = dist.rsample()
        action.append(raw_action.argmax(1).unsqueeze(1))
        if not return_only_action:
            log_prob.append(dist.log_prob(raw_action).t()[:, 0:1])

    action = torch.cat(action, 1).float()

    if len(action) == 1 and random.random() < self.epsilon:
        for _ in range(4):
            action_ind = random.randint(
                0, int(self.num_actions / self.num_heads))
            action[0, action_ind] = 0.0 if random.random() < 0.5 else 1.0

    if self.epsilon > self.epsilon_end:
        self.epsilon -= self.epsilon_decay_rate

    # action -= 1  # Translate 0,1,2 to -1,0,1

    if return_only_action:
        return action

    log_prob = torch.cat(log_prob, 1).mean(1).unsqueeze(1)
    return action, log_prob
def gumbel_softmax_sample(logits: torch.Tensor,
                          temperature: float = 1.0,
                          training: bool = True,
                          straight_through: bool = False):
    size = logits.size()
    if not training:
        indexes = logits.argmax(dim=-1)
        one_hot = torch.zeros_like(logits).view(-1, size[-1])
        one_hot.scatter_(1, indexes.view(-1, 1), 1)
        one_hot = one_hot.view(*size)
        return one_hot

    sample = RelaxedOneHotCategorical(logits=logits,
                                      temperature=temperature).rsample()

    if straight_through:
        size = sample.size()
        indexes = sample.argmax(dim=-1)
        hard_sample = torch.zeros_like(sample).view(-1, size[-1])
        hard_sample.scatter_(1, indexes.view(-1, 1), 1)
        hard_sample = hard_sample.view(*size)
        sample = sample + (hard_sample - sample).detach()
    return sample
def forward(self, input, args, n_particles, test=False):
    """ n_particles is interpreted as 1 for now to not screw anything up """
    n_particles = 1
    T = nn.Softmax(dim=0)(self.T)  # NOTE: not in log-space
    pi = nn.Softmax(dim=0)(self.pi)
    emit = self.calc_emit()

    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step
    z = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    # now a log-prob
    prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

    for i in range(seq_len):
        # logits = self.logits(torch.cat([hidden_states[i], h], 1))
        # logits = self.logits(nn.functional.relu(self.z_decoder(hidden_states[i], logits)))
        logits = self.logits(hidden_states[i])

        # build the next z sample
        q = RelaxedOneHotCategorical(temperature=Variable(torch.Tensor([args.temp]).cuda()),
                                     logits=logits)
        z = q.sample()

        lse = log_sum_exp(logits, dim=1).view(-1, 1)
        log_probs = logits - lse

        # now, compute the log-likelihood of the data given this z-sample
        # so emission is [batch_sz x z_dim], i.e. emission[i, j] is the log-probability of getting this
        # data for element i given choice z
        emission = F.embedding(input[i].repeat(n_particles), emit)

        NLL = -log_sum_exp(emission + log_probs, 1)
        nlls[i] = NLL.data
        KL = (log_probs.exp() * (log_probs - (prior_probs + 1e-16).log())).sum(1)
        loss += (NLL + KL)

        if i != seq_len - 1:
            prior_probs = (T.unsqueeze(0) * z.unsqueeze(1)).sum(2)

    # now, we calculate the final log-marginal estimator
    return loss.sum(), nlls.sum(), (seq_len * batch_sz * n_particles), 0
def reparameterize(self, p):
    if self.training:
        # At training time we sample from a relaxed Gumbel-Softmax distribution.
        # The samples are continuous, but as the temperature is lowered they get
        # closer to samples from a Categorical.
        m = RelaxedOneHotCategorical(TEMPERATURE, p)
        return m.rsample()
    else:
        # At testing time we sample from a Categorical distribution.
        m = OneHotCategorical(p)
        return m.sample()
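# Standalone sketch (not part of the class above) illustrating the temperature
# behaviour described in the comments: as the temperature is lowered, relaxed
# samples approach one-hot Categorical samples while remaining differentiable
# via rsample(). All values below are illustrative.
import torch
from torch.distributions import RelaxedOneHotCategorical, OneHotCategorical

probs = torch.tensor([0.1, 0.2, 0.7])
for temperature in (5.0, 1.0, 0.1):
    relaxed = RelaxedOneHotCategorical(torch.tensor(temperature), probs=probs)
    sample = relaxed.rsample()  # continuous relaxation, differentiable w.r.t. probs/logits
    print(temperature, sample, sample.argmax(-1))
hard = OneHotCategorical(probs=probs).sample()  # exact one-hot, not differentiable
print(hard)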
def noisy_action(self, obs, return_only_action=True):
    _, log_temp, logits = self.clean_action(obs, return_only_action=False)
    temp = log_temp.exp()

    dist = RelaxedOneHotCategorical(temperature=temp, probs=F.softmax(logits, dim=1))
    action = dist.rsample()

    if return_only_action:
        return action.argmax(1)

    log_prob = dist.log_prob(action)
    log_prob = torch.diagonal(log_prob, offset=0).unsqueeze(1)

    return action.argmax(1), log_prob, logits
def forward(self, input: torch.Tensor,
            proposal: distributions.RelaxedOneHotCategorical,
            proposal_sample: torch.Tensor,
            reconstruction: torch.Tensor) -> torch.Tensor:
    if self.likelihood == 'bernoulli':
        likelihood = distributions.Bernoulli(probs=reconstruction)
    else:
        likelihood = distributions.Normal(reconstruction,
                                          torch.ones_like(reconstruction))
    likelihood = distributions.Independent(likelihood, reinterpreted_batch_ndims=-1)
    reconstruction_loss = likelihood.log_prob(input).mean()

    assert proposal.logits.dim() == 2, \
        "proposal.shape == [*, D], D is shape of isotropic gaussian"
    prior = distributions.RelaxedOneHotCategorical(proposal.temperature,
                                                   logits=torch.ones_like(proposal.logits))
    regularization = (proposal.log_prob(proposal_sample) -
                      prior.log_prob(proposal_sample)).mean()

    # evidence lower bound (maximize)
    total_loss = reconstruction_loss - self.beta * regularization
    return -total_loss, -reconstruction_loss, regularization
def forward(self, state, action):
    # encoding
    state = state.unsqueeze(0)
    action = action.unsqueeze(0)
    q = F.relu(self.e1(torch.cat([state, action], dim=1)))
    q = F.relu(self.e2(q))
    q = F.relu(self.e3(q))
    q_y = q.view(q.size(0), self.latent_dim, self.categorical_dim)  # 1 x 2 x 5

    # decoding
    z = RelaxedOneHotCategorical(torch.tensor([self.temperature]), logits=q)
    log_prob = z.log_prob(action)
    sample = z.sample()
    recon = self.decode(state, sample)
    return recon, F.softmax(q_y, dim=1).reshape(*q.size())
def forward(self, *args, **kwargs):
    logits = self.agent(*args, **kwargs)
    if self.training:
        return RelaxedOneHotCategorical(logits=logits,
                                        temperature=self.temperature).rsample()
    else:
        return (logits / self.temperature).softmax(dim=1)
def forward(self, x):
    params = self.network(x)
    if self.beta not in (None, 0., np.inf):
        # finite, non-zero beta: return the relaxed (Concrete) distribution
        # with temperature 1/beta
        return RelaxedOneHotCategorical(temperature=1. / self.beta, logits=params)
    return OneHotCategorical(logits=params)
def prob_dists(self, obs, temperature=1.0):
    logits = self.forward(obs)
    split_logits = torch.split(logits, self.action_split, dim=-1)
    temperature = torch.tensor(temperature).to(DEVICE)
    return [
        RelaxedOneHotCategorical(temperature, logits=l) for l in split_logits
    ]
def concrete(self, state):
    conv = F.relu(self.c3(F.relu(self.c2(F.relu(self.c1(state))))))
    conv_flat = conv.view(state.size()[0], -1)
    fc_out = self.fc2(F.relu(self.fc1(conv_flat)))
    # print fc_out.data[0].numpy()
    c = torch.clamp(torch.sign(fc_out), 0.0).data[0].cpu().numpy()
    return RelaxedOneHotCategorical(self.temperature, logits=fc_out).sample(), c
def forward(self, *args, **kwargs):
    logits = self.agent(*args, **kwargs)
    if self.training:
        return RelaxedOneHotCategorical(
            logits=logits, temperature=self.temperature).rsample()
    else:
        return torch.zeros_like(logits).scatter_(
            -1, logits.argmax(dim=-1, keepdim=True), 1.0)
def forward(self, logits: torch.Tensor):
    size = logits.size()
    if not self.training:
        indexes = logits.argmax(dim=-1)
        one_hot = torch.zeros_like(logits).view(-1, size[-1])
        one_hot.scatter_(1, indexes.view(-1, 1), 1)
        one_hot = one_hot.view(*size)
        return one_hot

    sample = RelaxedOneHotCategorical(
        logits=logits, temperature=self.temperature).rsample()

    if self.straight_through:
        size = sample.size()
        indexes = sample.argmax(dim=-1)
        hard_sample = torch.zeros_like(sample).view(-1, size[-1])
        hard_sample.scatter_(1, indexes.view(-1, 1), 1)
        hard_sample = hard_sample.view(*size)
        sample = sample + (hard_sample - sample).detach()
    return sample
def forward(self, tensor_list):
    """
    Creates a RelaxedOneHotCategorical distribution conditioned on the inputs.

    Parameters
    ----------
    tensor_list: list of torch.Tensor
        a list of tensors that will first be concatenated on the last dimension.
    """
    x = torch.cat(tensor_list, dim=-1)
    logits = self.w_dense(x)
    return RelaxedOneHotCategorical(self._temperature, logits=logits)
def forward(ctx, input, temperature):
    """Forward pass

    Parameters
    ==========
    :param input: input tensor

    Returns
    =======
    :return: a one-hot tensor with 1 indicating the max of that input vector"""
    # We can cache arbitrary Tensors for use in the backward pass using the
    # save_for_backward method.
    # ctx.save_for_backward(input)

    batch_size = input.shape[0]
    # maxes = torch.max(input, 1)[1]
    dist = RelaxedOneHotCategorical(temperature, input)
    # NOTE: as written this returns the relaxed (continuous) samples rather than
    # the hard one-hot vectors described in the docstring; the one-hot path is
    # left commented out below.
    samples = dist.sample()
    # probs =
    # out = torch.zeros_like(input)
    # out[range(batch_size), samples] = 1
    # ctx.save_for_backward(out)
    # return out
    return samples
def forward(self, x):
    B, C, H, W = x.size()
    N, M, D = self.embedding.size()
    assert C == N * D

    x = x.view(B, N, D, H, W).permute(1, 0, 3, 4, 2)
    x_flat = x.reshape(N, -1, D)

    distances = torch.baddbmm(
        torch.sum(self.embedding ** 2, dim=2).unsqueeze(1) +
        torch.sum(x_flat ** 2, dim=2, keepdim=True),
        x_flat, self.embedding.transpose(1, 2),
        alpha=-2.0, beta=1.0)
    distances = distances.view(N, B, H, W, M)

    dist = RelaxedOneHotCategorical(0.5, logits=-distances)
    if self.training:
        samples = dist.rsample().view(N, -1, M)
    else:
        samples = torch.argmax(dist.probs, dim=-1)
        samples = F.one_hot(samples, M).float()
        samples = samples.view(N, -1, M)

    quantized = torch.bmm(samples, self.embedding)
    quantized = quantized.view_as(x)

    KL = dist.probs * (dist.logits + math.log(M))
    KL[(dist.probs == 0).expand_as(KL)] = 0
    KL = KL.sum(dim=(0, 2, 3, 4)).mean()

    avg_probs = torch.mean(samples, dim=1)
    perplexity = torch.exp(
        -torch.sum(avg_probs * torch.log(avg_probs + 1e-10), dim=-1))

    return quantized.permute(1, 0, 4, 2, 3).reshape(B, C, H, W), KL, perplexity.sum()
def gumbel_softmax_sample(logits: torch.Tensor,
                          temperature: float = 1.0,
                          straight_through: bool = False):
    """Samples from a Gumbel-Softmax/Concrete relaxation of a Categorical distribution.

    More details in:
    - Gumbel-Softmax: https://arxiv.org/abs/1611.01144
    - Concrete distribution: https://arxiv.org/abs/1611.00712

    Arguments:
        logits {torch.Tensor} -- tensor of logits, the output of an inference
            network. Size: [batch_size, n_categories]

    Keyword Arguments:
        temperature {float} -- temperature of the softmax relaxation. The lower
            the temperature (-->0), the closer the sample is to a discrete
            sample. (default: {1.0})
        straight_through {bool} -- Whether to use the straight-through
            estimator. (default: {False})

    Returns:
        torch.Tensor -- the relaxed sample. Size: [batch_size, n_categories]
    """
    sample = RelaxedOneHotCategorical(logits=logits,
                                      temperature=temperature).rsample()

    if straight_through:
        size = sample.size()
        indexes = sample.argmax(dim=-1)
        hard_sample = torch.zeros_like(sample).view(-1, size[-1])
        hard_sample.scatter_(1, indexes.view(-1, 1), 1)
        hard_sample = hard_sample.view(*size)
        sample = sample + (hard_sample - sample).detach()
    return sample
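# Hedged usage sketch for gumbel_softmax_sample as defined directly above
# (assumes the same imports: torch and torch.distributions.RelaxedOneHotCategorical).
# With straight_through=True the forward value is (numerically) one-hot, while
# the backward pass uses the relaxed sample's gradient.
logits = torch.randn(4, 10, requires_grad=True)
soft = gumbel_softmax_sample(logits, temperature=0.5)
hard = gumbel_softmax_sample(logits, temperature=0.5, straight_through=True)
print(soft.sum(-1))         # rows of the relaxed sample sum to ~1
print(hard.max(-1).values)  # straight-through rows are ~exactly one-hot
hard.sum().backward()       # gradients still reach `logits` via the relaxed sample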
def sampled_elbo(self, input, args, n_particles, emb, hidden_states):
    seq_len, batch_sz = input.size()
    T = nn.Softmax(dim=0)(self.T)  # NOTE: not in log-space
    pi = nn.Softmax(dim=0)(self.pi)
    emit = self.calc_emit()

    hidden_states = hidden_states.repeat(1, n_particles, 1)

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    # now a value in probability space
    prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

    for i in range(seq_len):
        logits = self.logits(hidden_states[i])

        # build the next z sample
        p = RelaxedOneHotCategorical(temperature=self.temp_prior, probs=prior_probs)
        q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
        z = q.rsample()

        log_probs = F.log_softmax(logits, dim=1)

        # now, compute the log-likelihood of the data given this z-sample
        # so emission is [batch_sz x z_dim], i.e. emission[i, j] is the log-probability of getting this
        # data for element i given choice z
        emission = F.embedding(input[i].repeat(n_particles), emit)

        NLL = -log_sum_exp(emission + log_probs, 1)
        nlls[i] = NLL.data
        KL = q.log_prob(z) - p.log_prob(z)  # pretty inexact
        loss += (NLL + KL)

        if i != seq_len - 1:
            prior_probs = (T.unsqueeze(0) * z.unsqueeze(1)).sum(2)

    (loss.sum() / (seq_len * batch_sz * n_particles)).backward(retain_graph=True)
    return loss, 0, seq_len * batch_sz * n_particles, 0
def rsample(self):
    if not self.has_rsample:
        raise NotImplementedError("Mixture does not support rsample.")
    samples = []
    for distribution in self.distributions:
        sample = distribution.rsample().unsqueeze(0)
        samples.append(sample)
    samples = torch.cat(samples, dim=0)
    expand = samples.dim() - 2
    choice = RelaxedOneHotCategorical(probs=self.weights, temperature=0.1)
    choice = choice.rsample().permute(1, 0)
    # append one singleton dimension per event dimension so the relaxed
    # mixture weights broadcast against the stacked component samples
    choice = choice.view(choice.size(0), choice.size(1), *([1] * expand))
    result = (samples * choice).sum(dim=0)
    return result
def add_noise_(self, batch):
    for i in range(len(batch.actions)):
        if i == self.index:
            continue
        # get observations and actions for agent i
        obs = batch.observations[i]
        actions = batch.actions[i]
        # create noise tensors, same shape and on same device
        if self.sigma_noise is not None:
            obs = obs + torch.randn_like(obs) * self.sigma_noise
        if self.temp_noise is not None:
            temp = torch.tensor(self.temp_noise, dtype=torch.float,
                                device=actions.device)
            # avoid zero probs which lead to nan samples
            probs = actions + 1e-45
            actions = RelaxedOneHotCategorical(temp, probs=probs).sample()
        # add noise
        batch.observations[i] = obs
        batch.actions[i] = actions
def rsample_gumbel_softmax(
    distr: Distribution,
    n: int,
    temperature: torch.Tensor,
    straight_through: bool = False,
) -> torch.Tensor:
    if isinstance(distr, (Categorical, OneHotCategorical)):
        if straight_through:
            gumbel_distr = RelaxedOneHotCategoricalStraightThrough(
                temperature, probs=distr.probs)
        else:
            gumbel_distr = RelaxedOneHotCategorical(temperature, probs=distr.probs)
    elif isinstance(distr, Bernoulli):
        if straight_through:
            gumbel_distr = RelaxedBernoulliStraightThrough(temperature,
                                                           probs=distr.probs)
        else:
            gumbel_distr = RelaxedBernoulli(temperature, probs=distr.probs)
    else:
        raise ValueError("Using Gumbel Softmax with non-discrete distribution")
    return gumbel_distr.rsample((n,))
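# Hedged usage sketch for rsample_gumbel_softmax above (assumes the torch and
# torch.distributions imports used there). It draws n relaxed samples from a
# Categorical via its Gumbel-Softmax relaxation; values are illustrative.
base = Categorical(probs=torch.tensor([0.2, 0.3, 0.5]))
temp = torch.tensor(0.7)
samples = rsample_gumbel_softmax(base, n=8, temperature=temp)
print(samples.shape)  # torch.Size([8, 3]); each row sums to ~1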
def sample(self, mean, logvar, probabilities):
    normal = Normal(mean, torch.exp(0.5 * logvar))
    categorical = RelaxedOneHotCategorical(self.temperature, probabilities)
    return normal.rsample(), categorical.rsample()
def forward(self, x, mask, num_particles=4):
    logweight_acc = torch.zeros(x.size(1), num_particles).to(device)  # (batch_size, num_particles)
    log_hat_p_acc = torch.zeros(x.size(1)).to(device)  # (batch_size, )
    log_hat_p_iwae_acc = torch.zeros(x.size(1)).to(device)
    kl_acc = torch.zeros(x.size(1)).to(device)  # (batch_size, )

    # [0, 1, 2, 3, 4, 5, 6, 7, ... ]
    noresampleidxs = torch.arange(x.size(1) * num_particles).to(device)

    h = Variable(torch.zeros(self.n_layers, x.size(1) * num_particles, self.h_dim)).to(device)
    c = Variable(torch.zeros(self.n_layers, x.size(1) * num_particles, self.h_dim)).to(device)

    # with torch.autograd.set_detect_anomaly(True):
    for t in range(x.size(0)):
        # VRNN Cell
        xts = x[t].repeat((1, num_particles)).reshape((x.size(1) * num_particles, -1))
        phi_x_ts = self.phi_x(xts)  # [batch_size * num_particle, embed_size]

        enc_t = self.enc(torch.cat([phi_x_ts, h[-1]], 1))
        enc_mean_t = self.enc_mean(enc_t)
        enc_std_t = self.enc_std(enc_t)
        encoder_dist = MultivariateNormal(enc_mean_t,
                                          scale_tril=torch.diag_embed(enc_std_t))

        prior_t = self.prior(h[-1])
        prior_mean_t = self.prior_mean(prior_t)
        prior_std_t = self.prior_std(prior_t)
        prior_dist = MultivariateNormal(prior_mean_t,
                                        scale_tril=torch.diag_embed(prior_std_t))

        z_t_is = encoder_dist.rsample()  # reparametrizable  # [batch_size * seq_len, latent_size]
        phi_z_ts = self.phi_z(z_t_is)

        dec_t = self.dec(torch.cat([phi_z_ts, h[-1]], 1))
        dec_mean_t = self.dec_mean(dec_t)
        decoder_dist = Bernoulli(probs=dec_mean_t)

        prior_logprob_ti = prior_dist.log_prob(z_t_is.detach()) + 1e-7
        encoder_logprob_ti = encoder_dist.log_prob(z_t_is.detach()) + 1e-7
        decoder_logprob_ti = decoder_dist.log_prob(xts).sum(-1) + 1e-7

        # recurrence
        _, (h, c) = self.rnn(torch.cat([phi_x_ts, phi_z_ts], 1).unsqueeze(0), (h, c))

        kl = torch.distributions.kl_divergence(encoder_dist, prior_dist)
        kl_acc += kl.mean(-1) * mask[t]

        nll = self._nll_bernoulli(dec_mean_t, xts)
        # log_alpha_ti = prior_logprob_ti + decoder_logprob_ti - encoder_logprob_ti  # [batch_size, ]
        log_alpha_ti = -(nll + kl)
        log_alpha_ti = log_alpha_ti.reshape(x.size(1), -1)  # [batch_size, num_particles]
        log_alpha_ti = log_alpha_ti * mask[t][None].T  # [batch_size, num_particles] * [batch_size, 1]

        # hat_p = torch.exp(logweight_acc + log_alpha_ti)  # [batch_size, num_particles]
        logweight_acc += log_alpha_ti

        # Add resampling procedure here
        # ess = 1. / (torch.exp(logweight_acc) ** 2).sum(-1)  # [batch_size, ]
        # logess = torch.log(1. / (torch.exp(logweight_acc) ** 2).sum(-1))
        logess_num = 2 * torch.logsumexp(logweight_acc, dim=-1)
        logess_denom = torch.logsumexp(2 * logweight_acc, dim=-1)
        logess = logess_num - logess_denom

        if not self.use_resampling_gradient:
            resample_dist = Categorical(
                logits=logweight_acc.reshape(x.size(1), num_particles))
            resampled_idxs = resample_dist.sample([num_particles]).T

            # [0, 0, 0, 0, 4, 4, 4, 4, ... ]
            sample_offset = torch.arange(x.size(1)).repeat(
                [num_particles, 1]).T.reshape(-1).to(device) * num_particles
            resampled_idxs = resampled_idxs.reshape(-1) + sample_offset

            should_resample = logess <= torch.log(
                torch.ones_like(logess).to(device) * num_particles / 2.0)
            should_resample = should_resample & mask[t].bool()
            should_resample_tiled = should_resample.repeat(
                [num_particles, 1]).T.reshape(-1)

            new_idxs = torch.where(should_resample_tiled, resampled_idxs, noresampleidxs)

            h[-1] = h[-1][new_idxs]
            c[-1] = c[-1][new_idxs]

            log_hat_p = torch.logsumexp(logweight_acc.clone(), dim=-1) - math.log(float(num_particles))
            log_hat_p_acc += log_hat_p * should_resample.float()
            logweight_acc *= (1. - should_resample_tiled.reshape(
                x.size(1), num_particles).float())
        else:
            # raise NotImplementedError
            resample_dist = RelaxedOneHotCategorical(
                logits=logweight_acc.reshape(x.size(1), num_particles),
                temperature=0.1)
            resampled_onehot_relaxedidxs = resample_dist.rsample(
                [num_particles]).permute(1, 0, 2)  # .reshape(-1, num_particles)

            should_resample = logess <= torch.log(
                torch.ones_like(logess).to(device) * num_particles / 2.0)
            should_resample = should_resample & mask[t].bool()
            should_resample_tiled = should_resample.repeat(
                [num_particles, 1]).T.reshape(-1)

            # noresample_onehot = torch.eye(x.size(1) * num_particles)
            for batch_idx in range(x.size(1)):
                if should_resample[batch_idx]:
                    # cur_slice = (batch_idx * x.size(1) * num_particles) : (batch_idx * x.size(1) * num_particles + x.size(1) * num_particles)
                    h[-1][(batch_idx * num_particles):(batch_idx * num_particles + num_particles)] = \
                        resampled_onehot_relaxedidxs[batch_idx] @ h[-1][(batch_idx * num_particles):(batch_idx * num_particles + num_particles)].clone()
                    c[-1][(batch_idx * num_particles):(batch_idx * num_particles + num_particles)] = \
                        resampled_onehot_relaxedidxs[batch_idx] @ c[-1][(batch_idx * num_particles):(batch_idx * num_particles + num_particles)].clone()

            log_hat_p = torch.logsumexp(logweight_acc.clone(), dim=-1) - math.log(float(num_particles))
            log_hat_p_acc += log_hat_p * should_resample.float()
            logweight_acc *= (1. - should_resample_tiled.reshape(
                x.size(1), num_particles).float())

        log_hat_p_iwae_acc += (torch.logsumexp(log_alpha_ti.detach(), dim=-1) -
                               math.log(float(num_particles))) * mask[t]

    # computing losses
    # kld_loss /= self.num_zs
    # nll_loss /= self.num_zs
    log_hat_p_acc += torch.logsumexp(logweight_acc, dim=-1) - math.log(float(num_particles))
    fivo_bound = torch.sum(log_hat_p_acc)
    # kl = torch.mean(kl_acc.reshape(x.size(1), -1), dim=-1)

    # return fivo_loss, kld_loss, nll_loss, \
    #     (all_enc_mean, all_enc_std), \
    #     (all_dec_mean, all_dec_std), \
    #     log_hat_ps
    return -fivo_bound, log_hat_p_acc, logweight_acc, kl_acc, log_hat_p_iwae_acc
def forward(self, X, index, length, hidden=None, temp=1.0):
    # A mode for learning to predict words, ignoring how to distinguish between parser states
    # Set to True for pre-training, then set to False for fine-tuning
    pretrain = False
    dice = random.random()

    prev_a, prev_b, prev_depth, _ = hidden
    batch_size = X.shape[0]
    hidden_size = prev_b.shape[-1]
    batch_range = range(batch_size)
    device = next(self.parameters()).device

    # Depth "0" is initialized to 0 (needed for conditioning of depth 1)
    ab_00 = [torch.zeros(batch_size, 2 * hidden_size, device=device)]
    ab_01 = [torch.zeros(batch_size, 2 * hidden_size, device=device)]
    ab_10 = [torch.zeros(batch_size, 2 * hidden_size, device=device)]
    ab_11 = [torch.zeros(batch_size, 2 * hidden_size, device=device)]

    sect_start = time.time()

    for d in range(1, self.depth + 1):
        fork_join_a = prev_a[:, d, :]
        nofork_nojoin_a = self.w_a00(
            torch.cat((X, prev_b[:, d - 1, :], prev_a[:, d, :]), 1))
        fork_nojoin_a = self.w_a10(torch.cat((X, prev_b[:, d - 1, :]), 1))
        nofork_join_a = prev_a[:, d - 1, :]

        ## At next depth, need to update a and/or b
        next_a_d_00 = (
            torch.eq(prev_depth, float(d)).float() * nofork_nojoin_a +
            # at shallower depth, copy over
            torch.gt(prev_depth, float(d)).float() * prev_a[:, d, :] +
            # at deeper depth, zero out
            torch.lt(prev_depth, float(d)).float() * torch.zeros_like(prev_a[:, d, :]))
        next_a_d_11 = (
            torch.eq(prev_depth, float(d)).float() * fork_join_a +
            # at shallower depth, copy over
            torch.gt(prev_depth, float(d)).float() * prev_a[:, d, :] +
            # at deeper depth, zero out
            torch.lt(prev_depth, float(d)).float() * torch.zeros_like(prev_a[:, d, :]))
        next_a_d_10 = (
            torch.eq(prev_depth + 1, float(d)).float() * fork_nojoin_a +
            # at shallower depth, copy over
            torch.gt(prev_depth + 1, float(d)).float() * prev_a[:, d, :] +
            # at deeper depth, zero out
            torch.lt(prev_depth + 1, float(d)).float() * torch.zeros_like(prev_a[:, d, :]))
        next_a_d_01 = (
            torch.eq(prev_depth - 1, float(d)).float() * nofork_join_a +
            # at shallower depth, copy over
            torch.gt(prev_depth - 1, float(d)).float() * prev_a[:, d, :] +
            # at deeper depth, zero out
            torch.lt(prev_depth - 1, float(d)).float() * torch.zeros_like(prev_a[:, d, :]))

        fork_join_b = self.w_b11(torch.cat((X, prev_b[:, d, :]), 1))
        nofork_nojoin_b = self.w_b00(
            torch.cat((X, prev_a[:, d, :], next_a_d_00), 1))
        fork_nojoin_b = self.w_b10(torch.cat((X, next_a_d_10), 1))
        nofork_join_b = self.w_b01(
            torch.cat((X, prev_b[:, d - 1, :], prev_a[:, d, :]), 1))

        next_b_d_00 = (
            torch.eq(prev_depth, float(d)).float() * nofork_nojoin_b +
            # at shallower depth, copy over
            torch.gt(prev_depth, float(d)).float() * prev_b[:, d, :] +
            # at deeper depth, zero out
            torch.lt(prev_depth, float(d)).float() * torch.zeros_like(prev_b[:, d, :]))
        next_b_d_11 = (
            torch.eq(prev_depth, float(d)).float() * fork_join_b +
            # at shallower depth, copy over
            torch.gt(prev_depth, float(d)).float() * prev_b[:, d, :] +
            # at deeper depth, zero out
            torch.lt(prev_depth, float(d)).float() * torch.zeros_like(prev_b[:, d, :]))
        next_b_d_10 = (
            torch.eq(prev_depth + 1, float(d)).float() * fork_nojoin_b +
            # at shallower depth, copy over
            torch.gt(prev_depth + 1, float(d)).float() * prev_b[:, d, :] +
            # at deeper depth, zero out
            torch.lt(prev_depth + 1, float(d)).float() * torch.zeros_like(prev_b[:, d, :]))
        next_b_d_01 = (
            torch.eq(prev_depth - 1, float(d)).float() * nofork_join_b +
            # at shallower depth, copy over
            torch.gt(prev_depth - 1, float(d)).float() * prev_b[:, d, :] +
            # at deeper depth, zero out
            torch.lt(prev_depth - 1, float(d)).float() * torch.zeros_like(prev_b[:, d, :]))

        next_ab_00 = torch.cat((next_a_d_00, next_b_d_00), 1)
        next_ab_01 = torch.cat((next_a_d_01, next_b_d_01), 1)
        next_ab_10 = torch.cat((next_a_d_10, next_b_d_10), 1)
        next_ab_11 = torch.cat((next_a_d_11, next_b_d_11), 1)

        ab_00.append(next_ab_00)
        ab_01.append(next_ab_01)
        ab_10.append(next_ab_10)
        ab_11.append(next_ab_11)

    sect_end = time.time()
    self.state_compute += (sect_end - sect_start)
    sect_start = time.time()

    # Now flatten the depth and predict the attention variables:
    # next_state_flat = torch.squeeze( next_state.view(batch_size, 1, -1) )
    ab_00_flat = torch.stack(ab_00, 1).view(batch_size, 1, -1)
    ab_01_flat = torch.stack(ab_01, 1).view(batch_size, 1, -1)
    ab_10_flat = torch.stack(ab_10, 1).view(batch_size, 1, -1)
    ab_11_flat = torch.stack(ab_11, 1).view(batch_size, 1, -1)

    next_state_flat = torch.squeeze(
        torch.cat((ab_00_flat, ab_01_flat, ab_10_flat, ab_11_flat), 2), 1)

    ## These are our deterministic masks: (dis)allow certain states at start, end, and depth limits
    # At time 0, and only at time 0, prev_depth is 0, so we must choose 1/0
    mask = torch.ones(batch_size, 4, device=device)
    # if we're at depth 0, we can only allow 1/0
    mask[:, (0, 1, 3)] *= (1 - torch.eq(prev_depth, 0).float())
    # if we're at the maximum depth, we cannot allow 1/0
    mask[:, (2,)] *= (1 - torch.eq(prev_depth, self.depth).float())
    # if we're at depth 1, we cannot allow 0/1 (reduce in the middle of the sentence)
    mask[:, (1,)] *= (1 - torch.eq(prev_depth, 1).float())

    # if we're in word-learning mode, disallow either 0/0 or 1/1.
    # (Logic above takes care of forcing 1/0 where necessary)
    if pretrain:
        if dice > 0.5:
            mask[:, 0] *= 0
        else:
            mask[:, 3] *= 0

    # Get the attention variables
    dist = RelaxedOneHotCategorical(
        temp,
        torch.nn.functional.softmax(
            torch.sigmoid(self.attention(next_state_flat[:, self.depth_size:])),
            dim=1))
    att_vars = torch.nn.functional.normalize(mask * dist.sample())
    # att_vars = torch.nn.functional.softmax(mask * SampleST(torch.nn.functional.softmax(torch.sigmoid(self.attention(next_state_flat[:, self.depth_size:])), dim=1), temp))

    if self.parsing:
        selection = ArgmaxST(att_vars)
    else:
        selection = att_vars

    sect_end = time.time()
    self.mask_time += (sect_end - sect_start)
    sect_start = time.time()

    striped = torch.mm(selection, self.selection_striper)

    ## It's ok up to here. Now we need to broadcast a dot product for each
    ## stripe across the stacked identity matrix to mask out the unwanted
    ## parts of the state space
    mask_list = []
    for b in range(batch_size):
        batch_mask = []
        expanded_stripe = striped[b].repeat(self.stride, 1)
        batch_mask = expanded_stripe * self.stripe_expander.t()
        mask_list.append(batch_mask.t())
    striped_identity = torch.stack(mask_list, 0)

    sect_end = time.time()
    self.batch_mask_time += (sect_end - sect_start)
    sect_start = time.time()

    # It's ok below here.
    hidden = torch.squeeze(
        torch.bmm(torch.unsqueeze(next_state_flat, 1), striped_identity))

    # now hidden needs to be re-viewed as batch x depth x hidden
    hidden = hidden.view(batch_size, self.depth + 1, self.hidden_size * 2)
    next_a = hidden[:, :, :self.hidden_size]
    next_b = hidden[:, :, self.hidden_size:]

    # compute the selected next depth and equivalent f/j variables from the selection:
    next_depth = torch.unsqueeze(
        (selection[:, 0].long() * prev_depth[:, 0] +
         selection[:, 1].long() * (prev_depth[:, 0] - 1) +
         selection[:, 2].long() * (prev_depth[:, 0] + 1) +
         selection[:, 3].long() * prev_depth[:, 0]), 1)
    f = torch.unsqueeze((selection[:, 2] * 1 + selection[:, 3] * 1), 1)
    j = torch.unsqueeze((selection[:, 1] * 1 + selection[:, 3] * 1), 1)

    sect_end = time.time()
    self.finish_time += (sect_end - sect_start)

    return next_a, next_b, next_depth, (f, j)
def forward(self, input, targets, args, n_particles, criterion, test=False):
    """
    This version takes the inputs, and does not expose the logits, but instead computes the losses directly
    """
    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (h, c) = self.encoder(emb, hidden)

    # teacher-forcing
    out_emb = self.dropout(self.dec_embedding(targets))

    # now, we'll replicate it for each particle - it's currently [seq_len x batch_sz x nhid]
    hidden_states = hidden_states.repeat(1, n_particles, 1)
    out_emb = out_emb.repeat(1, n_particles, 1)
    # now [seq_len x (n_particles x batch_sz) x nhid]
    # out_emb, hidden_states should be viewed as (n_particles x batch_sz) - this means that's true for h as well

    # run the z-decoder at this point, evaluating the NLL at each step
    p_h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)  # initially zero
    h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)
    d_h = self.init_hidden(batch_sz * n_particles, self.nhid, squeeze=True)

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}
    resamples = 0

    for i in range(seq_len):
        h = self.z_decoder(hidden_states[i], h)
        logits = self.logits(h)

        # build the next z sample
        if test:
            q = OneHotCategorical(logits=logits)
            z = q.sample()
        else:
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            z = q.rsample()
        h = z

        # prior
        if test:
            p = OneHotCategorical(logits=p_h)
        else:
            p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=p_h)

        # now, compute the log-likelihood of the data given this mean, and the input out_emb
        d_h = self.decoder(torch.cat([z, out_emb[i]], 1), d_h)
        decoder_logits = self.out_embedding(d_h)
        NLL = criterion(decoder_logits, input[i].repeat(n_particles))
        nlls[i] = NLL.data

        # compute the weight using `reweight` on page (4)
        f_term = p.log_prob(z)  # prior
        r_term = q.log_prob(z)  # proposal
        alpha = -NLL + args.anneal * (f_term - r_term)

        wa = accumulated_weights + alpha.view(n_particles, batch_sz)

        # sample ancestors, and reindex everything
        Z = log_sum_exp(wa, dim=0)  # line 7
        if (Z.data > 0.1).any():
            pdb.set_trace()

        loss += Z  # line 8
        accumulated_weights = wa - Z  # line 9

        probs = accumulated_weights.data.exp()
        probs += 0.01
        probs = probs / probs.sum(0, keepdim=True)
        effective_sample_size = 1. / probs.pow(2).sum(0)

        # resample / RSAMP if 3 batch elements need resampling
        if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
            resamples += 1
            ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)

            # now, reindex, which is the most important thing
            offsets = n_particles * torch.arange(batch_sz).unsqueeze(1).repeat(1, n_particles).long()
            if ancestors.is_cuda:
                offsets = offsets.cuda()
            unrolled_idx = Variable(ancestors.t().contiguous() + offsets).view(-1)
            h = torch.index_select(h, 0, unrolled_idx)
            p_h = torch.index_select(p_h, 0, unrolled_idx)
            d_h = torch.index_select(d_h, 0, unrolled_idx)

            # reset accumulated_weights
            accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}

        if i != seq_len - 1:
            # build the next mean prediction, feeding in the correct ancestor
            p_h = self.ar_prior_logits(torch.cat([h, out_emb[i]], 1), p_h)

    # now, we calculate the final log-marginal estimator
    nll = nlls.view(seq_len, n_particles, batch_sz).mean(1).sum()
    return -loss.sum(), nll, (seq_len * batch_sz), resamples
def gumbel_softmax(input: torch.Tensor, dim: int, temp: float) -> torch.Tensor:
    """ gumbel softmax """
    return RelaxedOneHotCategorical(temp, input.softmax(dim=dim)).rsample()
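# Hedged usage sketch for gumbel_softmax above: note that the first positional
# argument of RelaxedOneHotCategorical is the temperature, and that the input
# scores are converted to probabilities via softmax before sampling.
# Values are illustrative.
scores = torch.randn(2, 5)
relaxed = gumbel_softmax(scores, dim=-1, temp=0.5)
print(relaxed.shape)  # torch.Size([2, 5]); each row sums to ~1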
def forward(self, input, args, n_particles, test=False):
    """ evaluation is the IWAE-10 bound """
    if test:
        n_particles = 10
    else:
        n_particles = 1
    pi = F.log_softmax(self.pi, 0)

    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step
    h = (Variable(hidden_states.data.new(batch_sz * n_particles, self.hidden_size).zero_()),
         Variable(hidden_states.data.new(batch_sz * n_particles, self.hidden_size).zero_()))

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    # now a log-prob
    prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    prior_h = (Variable(torch.zeros(batch_sz * n_particles, 50).cuda()),
               Variable(torch.zeros(batch_sz * n_particles, 50).cuda()))
    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)
    feed = None

    x_emb = self.lockdrop(emb, self.dropout_x)

    for i in range(seq_len):
        # build the next z sample - not differentiable! we don't train the inference network
        logits = F.log_softmax(self.logits(hidden_states[i]), 1)
        if test:
            q = OneHotCategorical(logits=logits)
            # p = OneHotCategorical(logits=prior_logits)
            z = q.sample()
        else:
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            # p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=prior_logits)
            z = q.rsample()

        # this should be batch_sz x x_dim
        scores = torch.mm(self.project(torch.cat([h[0], z], 1)), self.emit.t())
        NLL = nn.CrossEntropyLoss(reduce=False)(scores, input[i].repeat(n_particles))

        # KL = q.log_prob(z) - p.log_prob(z)
        KL = (logits.exp() * (logits - prior_logits)).sum(1)
        loss += (NLL + KL)
        # else:
        #     loss += (NLL + args.anneal * KL)
        nlls[i] = NLL.data

        # set things up for next time
        if i != seq_len - 1:
            feed = torch.cat([emb[i].repeat(n_particles, 1), self.z_emb(z)], 1)
            prior_h = self.z_decoder(feed, prior_h)
            prior_logits = F.log_softmax(self.project_z(prior_h[0]), 1)
            h = self.hidden_rnn(x_emb[i].repeat(n_particles, 1), h)  # feed the next word into the RNN

    if n_particles != 1:
        loss = -log_sum_exp(-loss.view(n_particles, batch_sz), 0) + math.log(n_particles)
        NLL = -log_sum_exp(-nlls.view(seq_len, n_particles, batch_sz), 1) + math.log(n_particles)  # not quite accurate, but what can you do
    else:
        NLL = nlls

    # now, we calculate the final log-marginal estimator
    return loss.sum(), NLL.sum(), (seq_len * batch_sz), 0
def sampled_filter(self, input, args, n_particles, emb, hidden_states):
    seq_len, batch_sz = input.size()
    T = F.log_softmax(self.T, 0)  # NOTE: in log-space
    pi = F.log_softmax(self.pi, 0)  # NOTE: in log-space
    emit = self.calc_emit()

    hidden_states = hidden_states.repeat(1, n_particles, 1)

    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0

    accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}
    resamples = 0

    # in log probability space
    prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)

    for i in range(seq_len):
        # the approximate posterior comes from the same thing as before
        logits = self.logits(hidden_states[i])

        if not self.training:
            # this is crucial!!
            p = OneHotCategorical(logits=prior_logits)
            q = OneHotCategorical(logits=logits)
            z = q.sample()
        else:
            p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=prior_logits)
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            z = q.rsample()

        # now, compute the log-likelihood of the data given this z-sample
        emission = F.embedding(input[i].repeat(n_particles), emit)
        NLL = -(emission * z).sum(1)
        # NLL = -self.decode(z, input[i].repeat(n_particles), (emit,))  # diff. w.r.t. z
        nlls[i] = NLL.data

        # compute the weight using `reweight` on page (4)
        f_term = p.log_prob(z)  # prior
        r_term = q.log_prob(z)  # proposal
        alpha = -NLL + (f_term - r_term)

        wa = accumulated_weights + alpha.view(n_particles, batch_sz)

        Z = log_sum_exp(wa, dim=0)  # line 7

        loss += Z  # line 8
        accumulated_weights = wa - Z  # F.log_softmax(wa, dim=0)  # line 9

        # sample ancestors, and reindex everything
        if args.filter:
            probs = accumulated_weights.data.exp()
            probs += 0.01
            probs = probs / probs.sum(0, keepdim=True)
            effective_sample_size = 1. / probs.pow(2).sum(0)

            # probs is [n_particles, batch_sz]
            # ancestors [2 x 15] = [[0, 0, 0, ..., 0], [0, 1, 2, 3, ...]]
            # offsets   [2 x 15] = [[0, 0, 0, ..., 0], [1, 1, 1, 1, ...]]

            # resample / RSAMP
            if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                resamples += 1
                ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)

                # now, reindex, which is the most important thing
                offsets = n_particles * torch.arange(batch_sz).unsqueeze(1).repeat(1, n_particles).long()
                if ancestors.is_cuda:
                    offsets = offsets.cuda()
                unrolled_idx = Variable(ancestors + offsets).view(-1)
                z = torch.index_select(z, 0, unrolled_idx)

                # reset accumulated_weights
                accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}

        if i != seq_len - 1:
            # now in log-probability space
            prior_logits = log_sum_exp(T.unsqueeze(0) + z.unsqueeze(1), 2)

    if self.training:
        (-loss.sum() / (seq_len * batch_sz * n_particles)).backward(retain_graph=True)

    return -loss.sum(), nlls.sum(), seq_len * batch_sz * n_particles, 0
def inference(
    self,
    x,
    y=None,
    temperature=None,
    n_samples=1,
    reparam=True,
    encoder_key="default",
    counts=None,
):
    """
    Dimension choice
        (n_categories, n_is, n_batch, n_latent)
        log_q (n_categories, n_is, n_batch)
    """
    if temperature is None:
        raise ValueError(
            "Please provide a temperature for the relaxed OneHot distribution"
        )

    if counts is not None:
        return self.inference_defensive_sampling(
            x=x, y=y, temperature=temperature, counts=counts
        )
    n_cat = self.n_labels
    n_batch = len(x)

    # Z | X
    inp = x
    q_z1 = self.encoder_z1[encoder_key](
        inp, n_samples=n_samples, reparam=reparam, squeeze=False
    )
    # if not self.do_iaf:
    qz1_m = q_z1["q_m"]
    qz1_v = q_z1["q_v"]
    z1 = q_z1["latent"]
    assert z1.dim() == 3
    # log_qz1_x = Normal(qz1_m, qz1_v.sqrt()).log_prob(z1).sum(-1)
    log_qz1_x = q_z1["dist"].log_prob(z1)
    dfs = q_z1.get("df", None)
    if q_z1["sum_last"]:
        log_qz1_x = log_qz1_x.sum(-1)
    z1s = z1
    # torch.cuda.synchronize()

    # C | Z
    # Broadcast labels if necessary
    qc_z1 = self.classifier[encoder_key](z1)
    log_qc_z1 = qc_z1.log()
    qc_z1_all_probas = qc_z1

    # C
    if y is None:
        if reparam:
            cat_dist = RelaxedOneHotCategorical(
                temperature=temperature, probs=qc_z1
            )
            ys_probs = cat_dist.rsample()
        else:
            cat_dist = OneHotCategorical(probs=qc_z1)
            ys_probs = cat_dist.sample()
        ys = (ys_probs == ys_probs.max(-1, keepdim=True).values).float()
        y_int = ys.argmax(-1)
    else:
        ys = torch.cuda.FloatTensor(n_batch, n_cat)
        ys.zero_()
        ys.scatter_(1, y.view(-1, 1), 1)
        ys = ys.view(1, n_batch, n_cat).expand(n_samples, n_batch, n_cat)
        y_int = y.view(1, -1).expand(n_samples, n_batch)
    log_pc = self.y_prior.log_prob(y_int)
    assert y_int.unsqueeze(-1).shape == (n_samples, n_batch, 1), y_int.shape
    log_qc_z1 = torch.gather(log_qc_z1, dim=-1, index=y_int.unsqueeze(-1)).squeeze(-1)
    qc_z1 = torch.gather(qc_z1, dim=-1, index=y_int.unsqueeze(-1)).squeeze(-1)
    assert qc_z1.shape == (n_samples, n_batch)
    pc = log_pc.exp()

    # U | Z1, C
    z1_y = torch.cat([z1s, ys], dim=-1)
    q_z2_z1 = self.encoder_z2_z1[encoder_key](z1_y, n_samples=1, reparam=reparam)
    z2 = q_z2_z1["latent"]
    qz2_z1_m = q_z2_z1["q_m"]
    qz2_z1_v = q_z2_z1["q_v"]
    # log_qz2_z1 = Normal(q_z2_z1["q_m"], q_z2_z1["q_v"].sqrt()).log_prob(z2).sum(-1)
    log_qz2_z1 = q_z2_z1["dist"].log_prob(z2)
    if q_z2_z1["sum_last"]:
        log_qz2_z1 = log_qz2_z1.sum(-1)

    z2_y = torch.cat([z2, ys], dim=-1)
    pz1_z2m, pz1_z2_v = self.decoder_z1_z2(z2_y)
    log_pz1_z2 = Normal(pz1_z2m, pz1_z2_v.sqrt()).log_prob(z1).sum(-1)
    log_pz2 = Normal(torch.zeros_like(z2), torch.ones_like(z2)).log_prob(z2).sum(-1)

    px_z_loc = self.x_decoder(z1)
    log_px_z = Bernoulli(px_z_loc).log_prob(x).sum(-1)
    generative_density = log_pz2 + log_pc + log_pz1_z2 + log_px_z
    variational_density = log_qz1_x + log_qz2_z1
    log_ratio = generative_density - variational_density

    variables = dict(
        z1=z1,
        ys=ys,
        z2=z2,
        qz1_m=qz1_m,
        qz1_v=qz1_v,
        qz2_z1_m=qz2_z1_m,
        qz2_z1_v=qz2_z1_v,
        pz1_z2m=pz1_z2m,
        pz1_z2_v=pz1_z2_v,
        px_z_m=px_z_loc,
        log_qz1_x=log_qz1_x,
        qc_z1=qc_z1,
        log_qc_z1=log_qc_z1,
        log_qz2_z1=log_qz2_z1,
        log_pz2=log_pz2,
        log_pc=log_pc,
        pc=pc,
        log_pz1_z2=log_pz1_z2,
        log_px_z=log_px_z,
        generative_density=generative_density,
        variational_density=variational_density,
        log_ratio=log_ratio,
        qc_z1_all_probas=qc_z1_all_probas,
        df=dfs,
    )
    # torch.cuda.synchronize()
    return variables