Example #1
    def _topic_loss(self, inp, dec1, src_lengths, trg_lengths):
        """
        Compute the pairwise distance of various outputs of the seq^3 architecture.
        Args:
            inp: the token ids of the input sequence
            dec1: the outputs of the first decoder (latent sequence)
            src_lengths: the lengths of the input sequence
            trg_lengths: the lengths of the target sequence (summary)

        """

        enc_mask = sequence_mask(src_lengths).unsqueeze(-1).float()
        dec_mask = sequence_mask(trg_lengths - 1).unsqueeze(-1).float()

        enc_embs = self.model.inp_encoder.embed(inp)
        dec_embs = self.model.compressor.embed.expectation(dec1[3])

        if self.config["model"]["topic_idf"]:
            enc1_energies = self.model.idf(inp)
            # dec1_energies = expected_vecs(dec1[3], self.model.idf.weight)

            x_emb, att_x = avg_vectors(enc_embs, enc_mask, enc1_energies)
            # y_emb, att_y = avg_vectors(dec_reps, dec_mask, dec1_energies)
            y_emb, att_y = avg_vectors(dec_embs, dec_mask)

        else:
            x_emb, att_x = avg_vectors(enc_embs, enc_mask)
            y_emb, att_y = avg_vectors(dec_embs, dec_mask)

        distance = self.config["model"]["topic_distance"]
        loss = pairwise_loss(x_emb, y_emb, distance)

        return loss, (att_x, att_y)
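
Every example on this page calls sequence_mask, and this one also relies on avg_vectors; neither helper is shown here. The sketches below are assumptions inferred from how the two functions are called (a padding mask of shape batch x max_len, and a masked, optionally weighted average over the time dimension), not the project's actual implementations.

import torch

def sequence_mask(lengths, max_len=None):
    # Boolean mask: True for real tokens, False for padded positions.
    if max_len is None:
        max_len = int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) < lengths.unsqueeze(1)

def avg_vectors(vectors, mask, scores=None):
    # Masked (optionally weighted, e.g. by IDF energies) average over time.
    # vectors: (batch, seq_len, dim); mask, scores: (batch, seq_len, 1).
    if scores is None:
        scores = torch.ones_like(mask)
    weights = scores * mask
    weights = weights / weights.sum(dim=1, keepdim=True).clamp(min=1e-13)
    centroid = (vectors * weights).sum(dim=1)
    return centroid, weights
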
Example #2
def centroid_loss(enc_feats,
                  dec_feats,
                  src_lengths,
                  trg_lengths,
                  enc_scores=None,
                  dec_scores=None,
                  distance="cosine",
                  pool_func="mean",
                  mapping: torch.Tensor = None,
                  **kwargs):
    """
    Compute the pairwise distance of various outputs of the seq^3 architecture.
    """

    enc_mask = sequence_mask(src_lengths).unsqueeze(-1).float()
    dec_mask = sequence_mask(trg_lengths).unsqueeze(-1).float()

    # Aggregate the vectors of each sequence
    if pool_func == "mean":
        x_emb, _ = avg_vectors(enc_feats, enc_mask, enc_scores)
        y_emb, _ = avg_vectors(dec_feats, dec_mask, dec_scores)
    elif pool_func == "max":
        x_emb = enc_feats.max(1)[0]
        y_emb = dec_feats.max(1)[0]
    elif pool_func == "sum":
        x_emb = enc_feats.sum(1)
        y_emb = dec_feats.sum(1)
    else:
        raise ValueError(f"Unknown pool_func: {pool_func}")

    # Apply a rotation operation on the source embedding
    if mapping is not None:
        x_emb = torch.matmul(x_emb, mapping)

    return pairwise_loss(x_emb, y_emb, distance)
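
Both examples so far end by calling pairwise_loss(x_emb, y_emb, distance), whose implementation is not included. Assuming distance names such as "cosine" and "euclidean", a minimal sketch might be:

import torch.nn.functional as F

def pairwise_loss(x, y, distance="cosine"):
    # x, y: (batch, dim) centroids; returns a scalar distance-based loss.
    if distance == "cosine":
        return (1 - F.cosine_similarity(x, y, dim=-1)).mean()
    elif distance == "euclidean":
        return F.pairwise_distance(x, y).mean()
    else:
        raise ValueError(f"Unknown distance: {distance}")

With such a helper, centroid_loss takes encoder/decoder features of shape (batch, seq_len, dim) plus their length vectors and returns a scalar distance between the two sequence centroids.
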
Example #3
    def forward(self, sequence, lengths):

        energies = self.attention(sequence).squeeze()

        # construct a mask, based on sentence lengths
        if len(energies.size()) < 2:
            mask = sequence_mask(lengths, 1)
        else:
            mask = sequence_mask(lengths, energies.size(1))
        scores = masked_normalization(energies, mask)
        contexts = (sequence * scores.unsqueeze(-1)).sum(1)

        return contexts, scores
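
masked_normalization is another utility defined elsewhere. Judging by its use here (energies of shape batch x seq_len, a padding mask, and scores that sum to one over the real tokens), a plausible masked-softmax sketch, again an assumption rather than the project's code, is:

import torch.nn.functional as F

def masked_normalization(energies, mask):
    # Softmax over the last dimension, renormalized after zeroing padded steps.
    scores = F.softmax(energies, dim=-1)
    scores = scores * mask.float()
    return scores / scores.sum(dim=-1, keepdim=True).clamp(min=1e-13)
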
Example #4
    def cross_entropy_loss(self, logits, labels, lengths=None):
        """
        logits (FloatTensor): batch_size x seq_len x n_classes
        labels (LongTensor): batch_size x seq_len
        """
        _logits = logits.contiguous().view(-1, logits.size(-1))

        if self.ignore_index >= 0:
            _labels = labels.contiguous().view(-1)
        else:
            assert lengths is not None
            mask = ~sequence_mask(lengths, labels.size(1))
            _labels = labels.masked_fill_(mask, -1).contiguous().view(-1)

        if lengths is None:
            loss = F.cross_entropy(_logits,
                                   _labels,
                                   ignore_index=self.ignore_index)
            return loss

        else:
            _loss = F.cross_entropy(_logits,
                                    _labels,
                                    ignore_index=self.ignore_index,
                                    reduction='none')
            _loss_per_step = _loss.view(labels.size())
            loss = _loss.sum() / lengths.float().sum()
            return loss, _loss_per_step
Example #5
    def forward(self, src, trg, src_lengths, trg_lengths, **kwargs):
        enc = self.encode(src, src_lengths)
        src_mask = sequence_mask(src_lengths, src.size(1))
        dec_init = self.init_decoder(enc["outputs"], enc["hidden"])
        dec = self.decode(trg, enc["outputs"], dec_init, src_lengths, src_mask,
                          trg_lengths, **kwargs)

        return enc, dec
Example #6
    def forward(self, x, lengths):
        emb = self.embed(x)

        # mask padded + future steps
        pad_mask = sequence_mask(lengths, x.size(1)).unsqueeze(1)
        mask = pad_mask & subsequent_mask(x.size(1)).type_as(pad_mask)

        states = self.encoder(emb, None, mask)[0]
        logits = self.logits(states)

        return {"logits": logits}
Example #7
    def forward(self, sequence, query, lengths, coverage=None):

        energies = self.score(sequence, query, coverage)

        # construct a mask, based on sentence lengths
        mask = sequence_mask(lengths, energies.size(1))

        scores = masked_normalization_inf(energies, mask)
        # scores = self.masked_normalization(energies, mask)

        contexts = (sequence * scores.unsqueeze(-1)).sum(1)

        return contexts, scores
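
Unlike masked_normalization above, masked_normalization_inf presumably masks before the softmax by filling padded positions with -inf. A hedged sketch under that assumption:

import torch.nn.functional as F

def masked_normalization_inf(energies, mask):
    # Fill padded positions with -inf so they receive exactly zero attention.
    energies = energies.masked_fill(~mask, float("-inf"))
    return F.softmax(energies, dim=-1)
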
Example #8
    def forward(self, sequence, lengths):
        # sequence size: batch_size x length x rnn size

        energies = self.attention(sequence).squeeze()

        # construct a mask, based on sentence lengths
        mask = sequence_mask(lengths, energies.size(1))

        # scores = masked_normalization_inf(energies, mask)
        scores = masked_normalization(energies, mask)
        # scores size: batch_size x length
        contexts = (sequence * scores.unsqueeze(-1)).sum(1)

        return contexts, scores
Example #9
    def decode(self, y, memory, src_mask, y_lengths, **kwargs):
        y_emb = self.embed_tgt(y)

        if y_lengths is None:
            trg_mask = src_mask.new_ones([1, 1, 1])
        else:
            trg_mask = sequence_mask(y_lengths, y.size(1)).unsqueeze(1)

        output, states = self.decoder(trg_embed=y_emb,
                                      encoder_output=memory,
                                      src_mask=src_mask,
                                      trg_mask=trg_mask)[:2]

        return output, states
Example #10
def _global_prior(logits, word_idx, lengths):
    """
    Evaluate the probability of a sequence under a language model.
    """

    mask = sequence_mask(lengths)
    labels = (word_idx * mask.long()).contiguous().view(-1)
    _logits = logits.contiguous().view(-1, logits.size(-1))
    loss = F.cross_entropy(_logits, labels, ignore_index=0, reduction='none')

    # normalize by length to avoid mode collapse
    total = loss.sum() / mask.float().sum()

    return total, loss.view(mask.size())
Example #11
def masked_mse(inp_logits, trg_logits, lengths, mask_ids=[]):
    # zero padded timesteps
    mask = sequence_mask(lengths).unsqueeze(-1).float()

    # shape: batch x seq_length x tokens
    loss = F.mse_loss(inp_logits * mask, trg_logits * mask, reduction='none')

    for i in mask_ids:
        loss[:, :, i] = 0

    loss = loss.mean(-1)
    loss = loss * mask.squeeze()
    total_loss = loss.sum() / mask.sum()

    return total_loss, loss
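
A small usage sketch for masked_mse, with hypothetical shapes and assuming the helpers above are importable:

import torch

# Two logit tensors over the same target sequence, e.g. a model and a prior LM.
inp_logits = torch.randn(2, 5, 100)   # batch x seq_len x vocab
trg_logits = torch.randn(2, 5, 100)
lengths = torch.tensor([5, 3])        # real (unpadded) lengths

# mask_ids=[0, 1] would exclude e.g. the <pad>/<sos> logits from the loss.
total, per_step = masked_mse(inp_logits, trg_logits, lengths, mask_ids=[0, 1])
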
Example #12
def _ce_loss(logits, labels, lengths, ignore_index=0):
    _logits = logits.contiguous().view(-1, logits.size(-1))

    if ignore_index >= 0:
        _labels = labels.contiguous().view(-1)
    else:
        assert lengths is not None
        mask = ~sequence_mask(lengths, labels.size(1))
        _labels = labels.masked_fill_(mask, -1).contiguous().view(-1)

    _loss = F.cross_entropy(_logits,
                            _labels,
                            ignore_index=ignore_index,
                            reduction='none')
    _loss_per_step = _loss.view(labels.size())
    loss = _loss_per_step.sum(-1) / lengths.float()
    return loss, _loss_per_step
Example #13
def kl_length(logits, lengths, eos):
    """
    Length control loss, using a sequence of length labels (with eos token).

    Args:
        logits: the decoder logits over the vocabulary
        lengths: the lengths of the target sequences
        eos: the index of the <eos> token

    Returns:
        the cross-entropy loss that encourages predicting <eos>
        at (and after) the final timestep of each sequence
    """
    mask = sequence_mask(lengths - 1, lengths.max())
    eos_labels = ((1 - mask.long()) * eos).contiguous().view(-1)

    _logits = logits.contiguous().view(-1, logits.size(-1))
    loss = F.cross_entropy(_logits, eos_labels, ignore_index=0)

    return loss
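
To make the labeling concrete, here is a small hypothetical check (using the sequence_mask sketch from earlier):

import torch

lengths = torch.tensor([3, 5])
eos = 2
mask = sequence_mask(lengths - 1, int(lengths.max()))
eos_labels = (1 - mask.long()) * eos
# tensor([[0, 0, 2, 2, 2],
#         [0, 0, 0, 0, 2]])
# With ignore_index=0, only the positions at and after the true last step
# contribute to the loss, pushing the decoder to emit <eos> there.
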
Example #14
def prior_loss(outputs, trg_len, prior, mode, sos_id=1, tau=1, init_h=None):
    # The actual tokens that were used during generating the target seq.
    # When the decoder is trained with 100% teacher forcing,
    # sampled_tokens == trg_inp
    # sample_ids = outputs["dists"].max(-1)[1]
    prior_inps = differentiable_samples(prior.encoder.embed, outputs["dists"],
                                        sos_id)

    if mode == "prior":
        lm_outs = prior(prior_inps, trg_len, init_h)
        loss, loss_i = masked_kld(outputs["logits"],
                                  lm_outs["logits"],
                                  trg_len,
                                  tau=tau,
                                  mask_ids=[0, 1, 2, 3])

    elif mode == "discriminator":
        # feed the embeddings to the LM Discriminator
        lm_outs = prior(prior_inps, trg_len, init_h)
        mask = sequence_mask(trg_len).float()

        # check = F.cross_entropy(
        #     lm_outs["logits"].contiguous().view(-1, lm_outs["logits"].size(-1)),
        #     outputs["dists"].argmax(-1).view(-1), ignore_index=0,
        #     reduction='none')

        prior_log_probs = F.log_softmax(lm_outs["logits"], -1)
        loss_i = dot3D(outputs["dists"].contiguous(),
                       prior_log_probs.contiguous()) * mask

        cross_entropy = loss_i.sum() / mask.sum()

        # avoid collapse
        # agg_logits = outputs["logits"].sum(1) / mask.sum(-1, keepdim=True)
        # entropy = Categorical(logits=agg_logits).entropy().mean()

        loss = -cross_entropy

    else:
        raise ValueError(f"Unknown mode: {mode}")

    return loss, loss_i, lm_outs["logits"]
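
dot3D is used here to take a per-timestep dot product between the generator's distributions and the prior's log-probabilities over the vocabulary. Assuming exactly that contract, a one-line sketch:

import torch

def dot3D(a, b):
    # Per-timestep dot product over the vocabulary: (B, T, V) x (B, T, V) -> (B, T).
    return torch.einsum("btv,btv->bt", a, b)
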
Example #15
def masked_kld(inp_logits, trg_logits, lengths, tau=1, mask_ids=[]):
    """
    Compute the grounding loss using a pretrained "oracle" LM.
    The loss is computed between the posteriors over the vocabulary
    produced by the generator and the posteriors of the "oracle" LM.

    Args:
        inp_logits: the logits of the generator
        trg_logits: the logits of the "oracle" LM
        lengths: the lengths of the target sequence, used for masking the loss
        tau: the temperature of the softmax
        mask_ids: vocabulary indices whose loss contribution is zeroed out

    Returns:
        the average KL divergence per timestep (word),
        and the per-timestep KL divergence

    """

    input_logp = F.log_softmax(inp_logits / tau, -1)
    target_p = F.softmax(trg_logits / tau, -1)

    # zero padded timesteps
    mask = sequence_mask(lengths).unsqueeze(-1).float()

    # shape: batch x seq_length x tokens
    loss = F.kl_div(input_logp * mask, target_p * mask, reduction='none')

    for i in mask_ids:
        loss[:, :, i] = 0

    # sum over words/vocab (KL per word/timestep !)
    loss = loss.sum(-1)

    loss = loss * mask.squeeze()
    total_loss = loss.sum() / mask.sum()

    return total_loss, loss
Example #16
    def kl_loss(self, logits, labels, lengths):
        """
        logits (FloatTensor): batch_size x seq_len x n_classes
        labels (LongTensor): batch_size x seq_len
        """

        _logits = logits.contiguous().view(-1, logits.size(-1))
        _labels = labels.contiguous().view(-1)

        log_prob = F.log_softmax(_logits, dim=1)

        model_prob = self.one_hot.repeat(_labels.size(0), 1)
        model_prob.scatter_(1, _labels.unsqueeze(1), self.high_confidence)

        losses = F.kl_div(log_prob, model_prob, reduction='none')

        mask = sequence_mask(lengths, labels.size(1)).view(-1).float()
        losses = losses.sum(1) * mask
        loss = losses.sum() / mask.sum()

        return loss, losses
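
kl_loss reads like a label-smoothing criterion; self.one_hot and self.high_confidence are not shown. A setup consistent with how they are used here (OpenNMT-style label smoothing, an assumption about the constructor, with placeholder values):

import torch

# Hypothetical values; in the class these would come from the vocabulary/config.
n_classes, pad_idx, smoothing = 32000, 0, 0.1

# Smooth mass over the vocabulary, zero on padding, high confidence on gold labels.
one_hot = torch.full((1, n_classes), smoothing / (n_classes - 2))
one_hot[0, pad_idx] = 0
high_confidence = 1.0 - smoothing
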
Example #17
    def beam(self, x, x_len, sos_id, eos_id, pad_id, beam_size, length_penalty,
             **kwargs):

        enc = self.encode(x, x_len)
        dec_init = self.init_decoder(enc["outputs"], enc["hidden"])
        src_mask = sequence_mask(x_len, x.size(1))

        outputs = beam_search(decoder=self.decoder,
                              size=beam_size,
                              bos_index=sos_id,
                              eos_index=eos_id,
                              pad_index=pad_id,
                              encoder_output=enc["outputs"],
                              encoder_hidden=dec_init,
                              src_mask=src_mask,
                              max_output_length=(x_len.float() *
                                                 1.5).long().max(),
                              alpha=length_penalty,
                              lm_hidden=None,
                              **kwargs)

        return outputs
Example #18
    def translate(self, x, x_lengths, sos_id, y_lengths=None, **kwargs):

        enc = self.encode(x, x_lengths)
        dec_init = self.init_decoder(enc["outputs"], enc["hidden"])

        # Set the target length larger than source.
        # It will be pruned after the EOS anyway.
        if y_lengths is None:
            y_lengths = (x_lengths.float() * 1.5).long()

        src_mask = sequence_mask(x_lengths, x.size(1))
        inp_fake = fake_inputs(x, y_lengths, sos_id)
        dec = self.decode(inp_fake,
                          enc["outputs"],
                          dec_init,
                          x_lengths,
                          src_mask,
                          y_lengths,
                          sampling=1,
                          sampling_mode="argmax",
                          **kwargs)

        return enc, dec
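
fake_inputs is not shown; since the decoder here runs with sampling_mode="argmax" and generates its own previous tokens, the tensor most likely only fixes the batch size, the unroll length, and the initial <sos> symbol. A hypothetical sketch under that assumption:

import torch

def fake_inputs(x, lengths, sos_id):
    # Placeholder decoder inputs: an all-<sos> LongTensor of shape
    # (batch, max(lengths)); the values are largely ignored when the
    # decoder samples its own previous tokens (sampling=1).
    return torch.full((x.size(0), int(lengths.max())), sos_id,
                      dtype=torch.long, device=x.device)
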
Example #19
    def process_batch(self, x_sos, x_eos, x_len, y_sos, y_eos, y_len,
                      **kwargs):
        """
        The inputs will be the following, assuming this pair of sentences:

        x = ['<sos>', 'every', 'clever', 'cat', 'hates', 'every', 'dog', '<eos>']
        y = ['<sos>', 'κάθε', 'έξυπνη', 'γάτα', 'μισεί', 'κάθε', 'σκύλο', '<eos>']

        Args:
            x_sos: ['<sos>', 'every', 'clever', 'cat', 'hates', 'every', 'dog']
            x_eos: ['every', 'clever', 'cat', 'hates', 'every', 'dog', '<eos>']
            x_len: 7

            y_sos: ['<sos>', 'κάθε', 'έξυπνη', 'γάτα', 'μισεί', 'κάθε', 'σκύλο']
            y_eos: ['κάθε', 'έξυπνη', 'γάτα', 'μισεί', 'κάθε', 'σκύλο', '<eos>']
            y_len: 7

            Note:

                _sos will be the input to decoders
                _eos will be the input to encoders and target for decoders

        Returns:

        """
        decoding = dict(self.config["model"].get("decoding", {}))

        if decoding.get("fusion") is not None:
            decoding["lm"] = self.prior

        outputs = self.model(x_eos, y_sos, x_len, y_len, **decoding)
        losses = dict()
        is_gpt2 = self.get_vocab()[1].is_gpt2

        # Loss calculation
        losses["mt"] = self.criterion(outputs[1]["logits"], y_eos, y_len)[0]

        if "prior" in self.config["losses"] and self.prior is not None:
            f_reg = self.config["losses"]["prior"].get("objective", "kl")

            if f_reg == "mse":
                lm_logits = self.prior(y_sos, y_len)["logits"]
                prior_loss, prior_loss_i = masked_mse(outputs[1]["logits"],
                                                      lm_logits, y_len)

            elif f_reg in ["kl", "rkl"]:
                _tau = self.config["losses"]["prior"]["tau"]

                if is_gpt2:
                    _mask = sequence_mask(y_len, y_sos.size(1)).float()
                    lm_logits = self.prior(y_sos, attention_mask=_mask)[0]
                else:
                    lm_logits = self.prior(y_sos, y_len)["logits"]

                if f_reg == "kl":  # KL(p_prior, p_model)
                    prior_loss, prior_loss_i = masked_kld(
                        outputs[1]["logits"], lm_logits, y_len, _tau)
                else:  # rkl: KL(p_model, p_prior)
                    prior_loss, prior_loss_i = masked_kld(
                        lm_logits, outputs[1]["logits"], y_len, _tau)

                # multiply with tau^2 to make loss tau invariant
                prior_loss = prior_loss * (_tau**2)

            elif self.config["losses"]["prior"].get("objective",
                                                    "kl") == "ppl":
                prob_tm = relax_softmax(outputs[1]["logits"],
                                        tau=1,
                                        gumbel=False,
                                        hard=False)
                prior_inps = differentiable_samples(self.prior.encoder.embed,
                                                    prob_tm, 1)
                lm_logits = self.prior(prior_inps, y_len)["logits"]
                mask = sequence_mask(y_len).float()
                prior_log_probs = F.log_softmax(lm_logits, -1)
                loss_i = dot3D(prob_tm.contiguous(),
                               prior_log_probs.contiguous()) * mask

                cross_entropy = loss_i.sum() / mask.sum()
                prior_loss = -cross_entropy
            else:
                raise ValueError

            losses["prior"] = prior_loss

        return losses, {'model_outputs': outputs}
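
For reference, the configuration keys read by process_batch suggest a fragment like the following; the values are placeholders, not the project's defaults:

config = {
    "model": {
        "decoding": {"fusion": None},  # when set, the prior LM is passed as decoding["lm"]
    },
    "losses": {
        "prior": {"objective": "kl",   # one of: "mse", "kl", "rkl", "ppl"
                  "tau": 2.0},
    },
}
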
Example #20
    def encode(self, x, lengths, **kwargs):
        emb = self.embed_src(x)
        pad_mask = sequence_mask(lengths, x.size(1)).unsqueeze(1)
        memory = self.encoder(emb, None, pad_mask)[0]
        return memory, pad_mask