Example #1
    # Module-level imports assumed: torch, torch.nn.functional as F, and
    # the project's model_utils helpers.
    def forward(self, data, mask, char, char_mask, label, prior_mean,
                prior_logvar, kl_temp):
        """
        data: batch x batch_len token ids
        mask: batch x batch_len (1 for real tokens, 0 for padding)
        """
        # Move all inputs onto the model's device / wrap them for autograd.
        data, mask, char, char_mask, label, prior_mean, prior_logvar = \
            self.to_vars(data, mask, char, char_mask, label,
                         prior_mean, prior_logvar)

        batch_size, batch_len = data.size()
        # Embed tokens (word + character features) and encode with an RNN.
        input_vecs = self.get_input_vecs(data, mask, char, char_mask)
        hidden_vecs, _, _ = model_utils.get_rnn_output(input_vecs, mask,
                                                       self.word_encoder)

        # Per-token latent variables from the approximate posterior.
        z, mean_qs, logvar_qs = \
            self.to_latent_variable(hidden_vecs, mask, self.sampling)

        # Decode latents into a Gaussian over the output space, then
        # project the (noisy) decodings to token logits.
        mean_x = self.z2x(z)
        x = model_utils.gaussian(
            mean_x, mean_x.new_full((1,), self.expe.config.xvar))
        x_pred = self.x2token(x)

        if label is None:
            sup_loss = class_logits = None
        else:
            # Optional supervised objective: predict each token's label
            # from its latent code.
            class_logits = self.classifier(z)
            sup_loss = F.cross_entropy(
                class_logits.view(batch_size * batch_len, -1),
                label.view(-1).long(),
                reduction="none").view_as(data) * mask
            # Average over the unmasked positions of each sequence.
            sup_loss = sup_loss.sum(-1) / mask.sum(-1)

        # Reconstruction loss: masked per-token cross entropy, averaged
        # over each sequence's real tokens.
        log_loss = F.cross_entropy(
            x_pred.view(batch_size * batch_len, -1),
            data.view(-1).long(),
            reduction="none").view_as(data) * mask
        log_loss = log_loss.sum(-1) / mask.sum(-1)

        if prior_mean is not None and prior_logvar is not None:
            # KL between the approximate posterior and the given prior,
            # masked over padding and averaged per sequence.
            kl_div = model_utils.compute_KL_div(mean_qs, logvar_qs, prior_mean,
                                                prior_logvar)
            kl_div = (kl_div * mask.unsqueeze(-1)).sum(-1)
            kl_div = kl_div.sum(-1) / mask.sum(-1)

            # Anneal the KL term with the temperature kl_temp.
            loss = log_loss + kl_temp * kl_div
        else:
            kl_div = None
            loss = log_loss

        if sup_loss is not None:
            loss = loss + sup_loss

        # Return scalar losses, posterior statistics, and hard label
        # predictions (as a numpy array) when the classifier ran.
        return loss.mean(), log_loss.mean(), \
            kl_div.mean() if kl_div is not None else None, \
            sup_loss.mean() if sup_loss is not None else None, \
            mean_qs, logvar_qs, \
            class_logits.detach().cpu().numpy().argmax(-1) \
            if class_logits is not None else None
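
The compute_KL_div helper is not shown in this snippet. As a point of reference, a minimal sketch of the closed-form KL divergence between two diagonal Gaussians, matching the compute_KL_div(mean_q, logvar_q, mean_p, logvar_p) call signature inferred from the code above (the signature and semantics are assumptions, not the repo's actual implementation):

    import torch

    def compute_KL_div(mean_q, logvar_q, mean_p, logvar_p):
        # Elementwise KL(q || p) for diagonal Gaussians:
        # 0.5 * (logvar_p - logvar_q
        #        + (var_q + (mean_q - mean_p)^2) / var_p - 1)
        return 0.5 * (logvar_p - logvar_q
                      + (logvar_q.exp() + (mean_q - mean_p) ** 2)
                      / logvar_p.exp() - 1.0)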
Example #2
    def forward(self, inputs, mask, sample):
        """
        inputs: batch x batch_len x input_size
        mask: batch x batch_len (1 for real tokens, 0 for padding)
        sample: draw latents stochastically if True, else use the means
        """
        batch_size, batch_len, _ = inputs.size()

        # Two diagonal-Gaussian posteriors, q(z|x) and q(y|x), each
        # parameterized by its own mean / log-variance MLP.
        mean_qs = self.q_mean_mlp(inputs) * mask.unsqueeze(-1)
        logvar_qs = self.q_logvar_mlp(inputs) * mask.unsqueeze(-1)

        mean2_qs = self.q_mean2_mlp(inputs) * mask.unsqueeze(-1)
        logvar2_qs = self.q_logvar2_mlp(inputs) * mask.unsqueeze(-1)

        if sample:
            # Reparameterized samples from both posteriors.
            y = gaussian(mean2_qs, logvar2_qs) * mask.unsqueeze(-1)
            z = gaussian(mean_qs, logvar_qs) * mask.unsqueeze(-1)
        else:
            # Deterministic pass: use the posterior means directly.
            y = mean2_qs * mask.unsqueeze(-1)
            z = mean_qs * mask.unsqueeze(-1)

        return z, y, mean_qs, logvar_qs, mean2_qs, logvar2_qs
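
These examples rely on a gaussian(mean, logvar) helper for sampling. A minimal sketch of the reparameterization trick it presumably implements, assuming the second argument is a log-variance (as the paired mean/log-variance MLPs above suggest); this is an illustration, not the repo's actual helper:

    import torch

    def gaussian(mean, logvar):
        # Reparameterization trick: z = mean + std * eps, eps ~ N(0, I).
        # The sample stays differentiable w.r.t. mean and logvar.
        std = (0.5 * logvar).exp()
        eps = torch.randn_like(mean)
        return mean + std * eps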
Example #3
    # Module-level imports assumed: torch, torch.nn.functional as F, and
    # the project's model_utils helpers.
    def forward(self, data, mask, char, char_mask, label, prior_mean,
                prior_logvar, kl_temp):
        """
        prior_mean, prior_logvar: pairs of tensors, one entry per latent
        variable, or None when no prior is given.
        """
        if prior_mean is not None:
            prior_mean1, prior_mean2 = prior_mean
            prior_logvar1, prior_logvar2 = prior_logvar
        else:
            prior_mean1 = prior_mean2 = prior_logvar1 = prior_logvar2 = None

        # Move all inputs onto the model's device / wrap them for autograd.
        data, mask, char, char_mask, label, prior_mean1, \
            prior_mean2, prior_logvar1, prior_logvar2 = \
            self.to_vars(data, mask, char, char_mask, label,
                         prior_mean1, prior_mean2,
                         prior_logvar1, prior_logvar2)

        batch_size, batch_len = data.size()
        # Embed tokens (word + character features) and encode with an RNN.
        input_vecs = self.get_input_vecs(data, mask, char, char_mask)
        hidden_vecs, _, _ = model_utils.get_rnn_output(input_vecs, mask,
                                                       self.word_encoder)

        # Two per-token latent variables from the approximate posterior.
        z, y, mean_qs, logvar_qs, mean2_qs, logvar2_qs = \
            self.to_latent_variable(hidden_vecs, mask, self.sampling)

        # "flat": decode from the concatenation [z; y];
        # "hier": decode from z alone (y only feeds the classifier here).
        if self.expe.config.model.lower() == "flat":
            yz = torch.cat([z, y], dim=-1)
        elif self.expe.config.model.lower() == "hier":
            yz = z
        else:
            # Guard against yz being silently undefined.
            raise ValueError(
                "unknown model type: {}".format(self.expe.config.model))

        mean_x = self.yz2x(yz)
        x = model_utils.gaussian(
            mean_x, mean_x.new_full((1,), self.expe.config.xvar))
        x_pred = self.x2token(x)

        if label is None:
            sup_loss = class_logits = None
        else:
            # Optional supervised objective: predict labels from y.
            class_logits = self.classifier(y)
            sup_loss = F.cross_entropy(
                class_logits.view(batch_size * batch_len, -1),
                label.view(-1).long(),
                reduction="none").view_as(data) * mask
            sup_loss = sup_loss.sum(-1) / mask.sum(-1)

        # Reconstruction loss: masked per-token cross entropy, averaged
        # over each sequence's real tokens.
        log_loss = F.cross_entropy(
            x_pred.view(batch_size * batch_len, -1),
            data.view(-1).long(),
            reduction="none").view_as(data) * mask
        log_loss = log_loss.sum(-1) / mask.sum(-1)

        if prior_mean is not None:
            # One KL term per latent variable, each against its own prior;
            # both are masked over padding and averaged per sequence.
            kl_div1 = model_utils.compute_KL_div(mean_qs, logvar_qs,
                                                 prior_mean1, prior_logvar1)
            kl_div2 = model_utils.compute_KL_div(mean2_qs, logvar2_qs,
                                                 prior_mean2, prior_logvar2)

            kl_div = (kl_div1 * mask.unsqueeze(-1)).sum(-1) + \
                (kl_div2 * mask.unsqueeze(-1)).sum(-1)
            kl_div = kl_div.sum(-1) / mask.sum(-1)

            # Anneal the KL term with the temperature kl_temp.
            loss = log_loss + kl_temp * kl_div
        else:
            kl_div = None
            loss = log_loss

        if sup_loss is not None:
            loss = loss + sup_loss

        # Return scalar losses, both posteriors' statistics, and hard
        # label predictions when the classifier ran.
        return loss.mean(), log_loss.mean(), \
            kl_div.mean() if kl_div is not None else None, \
            sup_loss.mean() if sup_loss is not None else None, \
            mean_qs, logvar_qs, mean2_qs, logvar2_qs, \
            class_logits.detach().cpu().numpy().argmax(-1) \
            if class_logits is not None else None
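
The masked averaging pattern that recurs in all three forward passes can be isolated in a small runnable toy, which may make the view / mask / sum steps easier to follow (all shapes and values here are made up for illustration):

    import torch
    import torch.nn.functional as F

    # Per-token cross entropy, zeroed at padding, averaged per sequence.
    batch_size, batch_len, vocab = 2, 4, 10
    logits = torch.randn(batch_size, batch_len, vocab)
    targets = torch.randint(vocab, (batch_size, batch_len))
    mask = torch.tensor([[1., 1., 1., 0.],   # sequence 1: 3 real tokens
                         [1., 1., 0., 0.]])  # sequence 2: 2 real tokens

    loss = F.cross_entropy(logits.view(batch_size * batch_len, -1),
                           targets.view(-1),
                           reduction="none").view_as(mask) * mask
    loss = loss.sum(-1) / mask.sum(-1)   # one loss per sequence
    print(loss.shape)                    # torch.Size([2])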