def forward(self, data, mask, char, char_mask, label,
            prior_mean, prior_logvar, kl_temp):
    data, mask, char, char_mask, label, prior_mean, prior_logvar = \
        self.to_vars(data, mask, char, char_mask, label,
                     prior_mean, prior_logvar)
    batch_size, batch_len = data.size()

    # Encode: input embeddings -> RNN hidden states over non-pad positions.
    input_vecs = self.get_input_vecs(data, mask, char, char_mask)
    hidden_vecs, _, _ = model_utils.get_rnn_output(
        input_vecs, mask, self.word_encoder)

    # Per-token latent variable z with a diagonal-Gaussian posterior.
    z, mean_qs, logvar_qs = \
        self.to_latent_variable(hidden_vecs, mask, self.sampling)

    # Decode: z -> continuous x (fixed spread xvar) -> token logits.
    mean_x = self.z2x(z)
    x = model_utils.gaussian(
        mean_x,
        Variable(mean_x.data.new(1).fill_(self.expe.config.xvar)))
    x_pred = self.x2token(x)

    # Optional supervised tagging loss, classifying from z.
    if label is None:
        sup_loss = class_logits = None
    else:
        class_logits = self.classifier(z)
        sup_loss = F.cross_entropy(
            class_logits.view(batch_size * batch_len, -1),
            label.view(-1).long(),
            reduce=False).view_as(data) * mask
        sup_loss = sup_loss.sum(-1) / mask.sum(-1)

    # Per-token reconstruction loss, averaged over non-pad positions.
    log_loss = F.cross_entropy(
        x_pred.view(batch_size * batch_len, -1),
        data.view(-1).long(),
        reduce=False).view_as(data) * mask
    log_loss = log_loss.sum(-1) / mask.sum(-1)

    # KL(q(z|x) || p(z)), annealed by kl_temp, when a prior is provided.
    if prior_mean is not None and prior_logvar is not None:
        kl_div = model_utils.compute_KL_div(
            mean_qs, logvar_qs, prior_mean, prior_logvar)
        kl_div = (kl_div * mask.unsqueeze(-1)).sum(-1)
        kl_div = kl_div.sum(-1) / mask.sum(-1)
        loss = log_loss + kl_temp * kl_div
    else:
        kl_div = None
        loss = log_loss

    if sup_loss is not None:
        loss = loss + sup_loss

    return loss.mean(), log_loss.mean(), \
        kl_div.mean() if kl_div is not None else None, \
        sup_loss.mean() if sup_loss is not None else None, \
        mean_qs, logvar_qs, \
        class_logits.data.cpu().numpy().argmax(-1) \
        if class_logits is not None else None
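# NOTE: model_utils.gaussian and model_utils.compute_KL_div are not defined in
# this section. The sketches below are assumptions reconstructed only from how
# they are called above (reparameterized sampling and a diagonal-Gaussian KL);
# the repository's actual implementations may differ. In particular, gaussian
# is treated here as taking (mean, logvar), matching the encoder's call; the
# fixed xvar passed at decode time might instead be a variance or std.
import torch
from torch.autograd import Variable


def _gaussian_sketch(mean, logvar):
    # Reparameterization trick: draw from N(mean, exp(logvar)) as
    # mean + eps * std with eps ~ N(0, 1), keeping gradients through
    # mean and logvar.
    eps = Variable(mean.data.new(*mean.size()).normal_())
    return mean + eps * torch.exp(0.5 * logvar)


def _compute_KL_div_sketch(mean_q, logvar_q, mean_p, logvar_p):
    # Elementwise KL(N(mean_q, var_q) || N(mean_p, var_p)) between diagonal
    # Gaussians:
    #   0.5 * (logvar_p - logvar_q
    #          + (var_q + (mean_q - mean_p)^2) / var_p - 1)
    return 0.5 * (logvar_p - logvar_q
                  + (torch.exp(logvar_q) + (mean_q - mean_p) ** 2)
                  / torch.exp(logvar_p) - 1)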
def forward(self, inputs, mask, sample):
    """
    inputs: batch x batch_len x input_size
    """
    batch_size, batch_len, _ = inputs.size()

    # Two diagonal-Gaussian posteriors per token: q(z|x) and q(y|x), with
    # padded positions zeroed out by the mask.
    mean_qs = self.q_mean_mlp(inputs) * mask.unsqueeze(-1)
    logvar_qs = self.q_logvar_mlp(inputs) * mask.unsqueeze(-1)
    mean2_qs = self.q_mean2_mlp(inputs) * mask.unsqueeze(-1)
    logvar2_qs = self.q_logvar2_mlp(inputs) * mask.unsqueeze(-1)

    # Reparameterized samples when training; posterior means otherwise.
    if sample:
        y = gaussian(mean2_qs, logvar2_qs) * mask.unsqueeze(-1)
        z = gaussian(mean_qs, logvar_qs) * mask.unsqueeze(-1)
    else:
        y = mean2_qs * mask.unsqueeze(-1)
        z = mean_qs * mask.unsqueeze(-1)

    return z, y, mean_qs, logvar_qs, mean2_qs, logvar2_qs
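# Usage sketch for the encoder above (hypothetical sizes; `encoder` stands in
# for the module that owns this forward):
#
#   inputs: 16 x 40 x 512 RNN states, mask: 16 x 40 (1.0 = token, 0.0 = pad)
#   z, y, mean_qs, logvar_qs, mean2_qs, logvar2_qs = encoder(
#       inputs, mask, sample=True)   # sample=True: reparameterized draws
#
#   # z, y: 16 x 40 x latent_size. Padded positions are zeroed by the mask,
#   # so they contribute nothing to the masked losses downstream.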
def forward(self, data, mask, char, char_mask, label,
            prior_mean, prior_logvar, kl_temp):
    # The priors arrive as (z, y) pairs; unpack them before to_vars.
    if prior_mean is not None:
        prior_mean1, prior_mean2 = prior_mean
        prior_logvar1, prior_logvar2 = prior_logvar
    else:
        prior_mean1 = prior_mean2 = prior_logvar1 = prior_logvar2 = None

    data, mask, char, char_mask, label, prior_mean1, \
        prior_mean2, prior_logvar1, prior_logvar2 = \
        self.to_vars(data, mask, char, char_mask, label, prior_mean1,
                     prior_mean2, prior_logvar1, prior_logvar2)
    batch_size, batch_len = data.size()

    # Encode: input embeddings -> RNN hidden states over non-pad positions.
    input_vecs = self.get_input_vecs(data, mask, char, char_mask)
    hidden_vecs, _, _ = model_utils.get_rnn_output(
        input_vecs, mask, self.word_encoder)

    # Two per-token latent streams: z, and y (which feeds the classifier).
    z, y, mean_qs, logvar_qs, mean2_qs, logvar2_qs = \
        self.to_latent_variable(hidden_vecs, mask, self.sampling)

    # "flat": decode from the concatenation [z; y]; "hier": decode from z.
    if self.expe.config.model.lower() == "flat":
        yz = torch.cat([z, y], dim=-1)
    elif self.expe.config.model.lower() == "hier":
        yz = z

    mean_x = self.yz2x(yz)
    x = model_utils.gaussian(
        mean_x,
        Variable(mean_x.data.new(1).fill_(self.expe.config.xvar)))
    x_pred = self.x2token(x)

    # Optional supervised tagging loss, classifying from y.
    if label is None:
        sup_loss = class_logits = None
    else:
        class_logits = self.classifier(y)
        sup_loss = F.cross_entropy(
            class_logits.view(batch_size * batch_len, -1),
            label.view(-1).long(),
            reduce=False).view_as(data) * mask
        sup_loss = sup_loss.sum(-1) / mask.sum(-1)

    # Per-token reconstruction loss, averaged over non-pad positions.
    log_loss = F.cross_entropy(
        x_pred.view(batch_size * batch_len, -1),
        data.view(-1).long(),
        reduce=False).view_as(data) * mask
    log_loss = log_loss.sum(-1) / mask.sum(-1)

    # Sum of the two KL terms, annealed by kl_temp, when priors are given.
    if prior_mean is not None:
        kl_div1 = model_utils.compute_KL_div(
            mean_qs, logvar_qs, prior_mean1, prior_logvar1)
        kl_div2 = model_utils.compute_KL_div(
            mean2_qs, logvar2_qs, prior_mean2, prior_logvar2)
        kl_div = (kl_div1 * mask.unsqueeze(-1)).sum(-1) + \
            (kl_div2 * mask.unsqueeze(-1)).sum(-1)
        kl_div = kl_div.sum(-1) / mask.sum(-1)
        loss = log_loss + kl_temp * kl_div
    else:
        kl_div = None
        loss = log_loss

    if sup_loss is not None:
        loss = loss + sup_loss

    return loss.mean(), log_loss.mean(), \
        kl_div.mean() if kl_div is not None else None, \
        sup_loss.mean() if sup_loss is not None else None, \
        mean_qs, logvar_qs, mean2_qs, logvar2_qs, \
        class_logits.data.cpu().numpy().argmax(-1) \
        if class_logits is not None else None
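# Objective summary (a reading of the code above, not an authoritative
# statement of the model's loss): with T = mask.sum(-1) non-pad tokens per
# sequence,
#
#   loss = (1/T) * sum_t CE(x_pred_t, data_t)
#        + kl_temp * (1/T) * sum_t [KL_z,t + KL_y,t]   # only when priors given
#        + (1/T) * sum_t CE(class_logits_t, label_t)   # only when labels given
#
# The "flat" variant decodes from [z; y], while "hier" decodes from z alone,
# so in "hier" y influences the objective only through the classifier and its
# KL/prior term.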