Example #1
    def forward(
        self,
        source: torch.Tensor,  # (b, max_sou_seq_len)
        source_mask: torch.Tensor,  # (b, max_sou_seq_len)
        target: torch.Tensor,  # (b, max_tar_seq_len)
        target_mask: torch.Tensor,  # (b, max_tar_seq_len)
        label: torch.Tensor,  # (b, max_tar_seq_len)
        annealing: float
    ) -> Tuple[torch.Tensor, Tuple]:  # (loss, loss details)
        b = source.size(0)
        source_embedded = self.source_embed(
            source, source_mask)  # (b, max_sou_seq_len, d_s_emb)
        e_out, (hidden, _) = self.encoder(source_embedded, source_mask)

        h = self.transform(hidden, True)  # (n_e_lay * b, d_e_hid * n_dir)
        z_mu = self.z_mu(h)  # (n_e_lay * b, d_e_hid * n_dir)
        z_ln_var = self.z_ln_var(h)  # (n_e_lay * b, d_e_hid * n_dir)
        hidden = Gaussian(z_mu, z_ln_var).rsample()  # reparameterization trick
        # (n_e_lay * b, d_e_hid * n_dir) -> (b, d_e_hid * n_dir), initialize cell state
        states = (self.transform(hidden, False),
                  self.transform(hidden.new_zeros(hidden.size()), False))

        max_tar_seq_len = target.size(1)
        output = source_embedded.new_zeros(
            (b, max_tar_seq_len, self.target_vocab_size))
        target_embedded = self.target_embed(
            target, target_mask)  # (b, max_tar_seq_len, d_t_emb)
        target_embedded = target_embedded.transpose(
            1, 0)  # (max_tar_seq_len, b, d_t_emb)
        total_context_loss = 0
        # decode per word
        for i in range(max_tar_seq_len):
            d_out, states = self.decoder(target_embedded[i], target_mask[:, i],
                                         states)
            if self.attention:
                context, cs = self.calculate_context_vector(
                    e_out, states[0], source_mask, True)  # (b, d_d_hid)
                total_context_loss += self.calculate_context_loss(cs)
                d_out = torch.cat((d_out, context), dim=-1)  # (b, d_d_hid * 2)
            output[:, i, :] = self.w(self.maxout(
                d_out))  # (b, d_d_hid) -> (b, d_out) -> (b, tar_vocab_size)
        loss, details = self.calculate_loss(output, target_mask, label, z_mu,
                                            z_ln_var, total_context_loss,
                                            annealing)
        if torch.isnan(loss).any():
            raise ValueError('nan detected')
        return loss, details
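The forward pass draws the latent state with `Gaussian(z_mu, z_ln_var).rsample()` and hands `z_mu`, `z_ln_var`, and `annealing` to `calculate_loss`, which suggests a standard VAE objective with a KL-annealed regularizer. The `Gaussian` helper itself is not shown in these examples; below is a minimal sketch of what such a wrapper could look like, with a `kl()` method added purely for illustration.

    import torch


    class Gaussian:
        """Hypothetical sketch of the distribution wrapper used above."""

        def __init__(self, mu: torch.Tensor, ln_var: torch.Tensor):
            self.mu = mu          # mean of the approximate posterior
            self.ln_var = ln_var  # log-variance, same shape as mu

        def rsample(self) -> torch.Tensor:
            # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I),
            # so gradients flow back through z_mu and z_ln_var
            eps = torch.randn_like(self.mu)
            return self.mu + torch.exp(0.5 * self.ln_var) * eps

        def sample(self) -> torch.Tensor:
            # same draw without gradient tracking (used at inference time)
            with torch.no_grad():
                return self.rsample()

        def kl(self) -> torch.Tensor:
            # KL(N(mu, sigma^2) || N(0, I)) summed over the last dim; in
            # forward() a term like this would be scaled by `annealing`
            # inside calculate_loss (KL annealing)
            return -0.5 * torch.sum(
                1 + self.ln_var - self.mu.pow(2) - self.ln_var.exp(), dim=-1)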
Example #2
    def predict(
        self,
        source: torch.Tensor,  # (b, max_sou_seq_len)
        source_mask: torch.Tensor,  # (b, max_sou_seq_len)
        sampling: bool = True
    ) -> torch.Tensor:  # (b, max_seq_len, 1)
        self.eval()
        with torch.no_grad():
            b = source.size(0)
            source_embedded = self.source_embed(
                source, source_mask)  # (b, max_sou_seq_len, d_s_emb)
            e_out, (hidden, _) = self.encoder(source_embedded, source_mask)

            h = self.transform(hidden, True)
            z_mu = self.z_mu(h)
            z_ln_var = self.z_ln_var(h)
            hidden = Gaussian(z_mu, z_ln_var).sample() if sampling else z_mu
            states = (self.transform(hidden, False),
                      self.transform(hidden.new_zeros(hidden.size()), False))

            target_id = torch.full((b, 1), BOS,
                                   dtype=source.dtype).to(source.device)
            target_mask = torch.full(
                (b, 1), 1, dtype=source_mask.dtype).to(source_mask.device)
            predictions = source_embedded.new_zeros(b, self.max_seq_len, 1)
            for i in range(self.max_seq_len):
                target_embedded = self.target_embed(
                    target_id, target_mask).squeeze(1)  # (b, d_t_emb)
                d_out, states = self.decoder(target_embedded,
                                             target_mask[:, 0], states)
                if self.attention:
                    context, _ = self.calculate_context_vector(
                        e_out, states[0], source_mask, False)
                    d_out = torch.cat((d_out, context), dim=-1)

                output = self.w(self.maxout(d_out))  # (b, tar_vocab_size)
                output[:, UNK] -= 1e6  # mask <UNK>
                if i == 0:
                    output[:, EOS] -= 1e6  # avoid 0 length output
                prediction = torch.argmax(F.softmax(output, dim=1),
                                          dim=1).unsqueeze(1)  # (b, 1), greedy
                target_mask = target_mask * prediction.ne(EOS).type(
                    target_mask.dtype)
                target_id = prediction
                predictions[:, i, :] = prediction
        return predictions
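Putting the two methods together: a hedged usage sketch, assuming the surrounding model class exposes `forward` and `predict` exactly as listed above. The optimizer, data loader, epoch count, and the linear KL-annealing schedule are illustrative placeholders, not taken from the original examples.

    import torch

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    step, total_steps = 0, n_epochs * len(train_loader)
    for epoch in range(n_epochs):
        model.train()
        for source, source_mask, target, target_mask, label in train_loader:
            # linear KL-annealing weight ramped from 0 to 1 over the first
            # half of training (one common choice; the original schedule is
            # not shown in these examples)
            annealing = min(1.0, step / max(1, total_steps // 2))
            loss, details = model(source, source_mask, target, target_mask,
                                  label, annealing)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1

    # greedy decoding at inference; sampling=False uses the posterior mean z_mu
    predictions = model.predict(source, source_mask, sampling=False)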