def _topic_loss(self, inp, dec1, src_lengths, trg_lengths):
    """
    Compute the pairwise distance of various outputs of the seq^3 architecture.

    Args:
        inp: the input sequence (token ids), fed to the first encoder
        dec1: the outputs of the first decoder (latent sequence)
        src_lengths: the lengths of the input sequence
        trg_lengths: the lengths of the target sequence (summary)
    """
    enc_mask = sequence_mask(src_lengths).unsqueeze(-1).float()
    dec_mask = sequence_mask(trg_lengths - 1).unsqueeze(-1).float()

    enc_embs = self.model.inp_encoder.embed(inp)
    dec_embs = self.model.compressor.embed.expectation(dec1[3])

    if self.config["model"]["topic_idf"]:
        enc1_energies = self.model.idf(inp)
        # dec1_energies = expected_vecs(dec1[3], self.model.idf.weight)

        x_emb, att_x = avg_vectors(enc_embs, enc_mask, enc1_energies)
        # y_emb, att_y = avg_vectors(dec_reps, dec_mask, dec1_energies)
        y_emb, att_y = avg_vectors(dec_embs, dec_mask)
    else:
        x_emb, att_x = avg_vectors(enc_embs, enc_mask)
        y_emb, att_y = avg_vectors(dec_embs, dec_mask)

    distance = self.config["model"]["topic_distance"]

    loss = pairwise_loss(x_emb, y_emb, distance)

    return loss, (att_x, att_y)
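# Nearly every function in this section builds padding masks with a
# `sequence_mask` helper that is not defined in this excerpt. A minimal
# sketch, assuming it returns a boolean (batch, max_len) mask that is True
# at valid timesteps (the repository's actual implementation may differ):
import torch


def sequence_mask(lengths, max_len=None):
    """Boolean mask of shape (batch, max_len), True where t < lengths[i]."""
    if max_len is None:
        max_len = lengths.max()
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) < lengths.unsqueeze(1)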
def centroid_loss(enc_feats, dec_feats, src_lengths, trg_lengths,
                  enc_scores=None, dec_scores=None,
                  distance="cosine", pool_func="mean",
                  mapping: torch.Tensor = None, **kwargs):
    """
    Compute the pairwise distance of various outputs of the seq^3 architecture.
    """
    enc_mask = sequence_mask(src_lengths).unsqueeze(-1).float()
    dec_mask = sequence_mask(trg_lengths).unsqueeze(-1).float()

    # Aggregate the vectors of each sequence
    if pool_func == "mean":
        x_emb, _ = avg_vectors(enc_feats, enc_mask, enc_scores)
        y_emb, _ = avg_vectors(dec_feats, dec_mask, dec_scores)
    elif pool_func == "max":
        x_emb = enc_feats.max(1)[0]
        y_emb = dec_feats.max(1)[0]
    elif pool_func == "sum":
        x_emb = enc_feats.sum(1)
        y_emb = dec_feats.sum(1)
    else:
        raise ValueError(f"Unknown pool_func: {pool_func}")

    # Apply a rotation operation on the source embedding
    if mapping is not None:
        x_emb = torch.matmul(x_emb, mapping)

    return pairwise_loss(x_emb, y_emb, distance)
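# Both losses above pool token vectors with `avg_vectors` and compare the
# pooled centroids with `pairwise_loss`; neither helper is shown in this
# excerpt. Minimal sketches under assumed signatures (the repository's actual
# implementations may differ):
import torch
import torch.nn.functional as F


def avg_vectors(vectors, mask, scores=None):
    """Masked mean over the time axis; if `scores` are given, use them as
    (unnormalized) weights instead of uniform averaging."""
    if scores is None:
        weights = mask / mask.sum(1, keepdim=True).clamp(min=1e-13)
    else:
        weights = scores * mask
        weights = weights / weights.sum(1, keepdim=True).clamp(min=1e-13)
    return (vectors * weights).sum(1), weights


def pairwise_loss(x, y, distance="cosine"):
    """Distance between row-aligned pairs of pooled vectors."""
    if distance == "cosine":
        return (1 - F.cosine_similarity(x, y, dim=-1)).mean()
    elif distance == "euclidean":
        return F.pairwise_distance(x, y).mean()
    else:
        raise ValueError(f"Unknown distance: {distance}")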
def forward(self, sequence, lengths):
    energies = self.attention(sequence).squeeze()

    # construct a mask, based on sentence lengths
    if len(energies.size()) < 2:
        mask = sequence_mask(lengths, 1)
    else:
        mask = sequence_mask(lengths, energies.size(1))

    scores = masked_normalization(energies, mask)
    contexts = (sequence * scores.unsqueeze(-1)).sum(1)

    return contexts, scores
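# The attention modules in this section normalize the energies with
# `masked_normalization` (and a `masked_normalization_inf` variant that masks
# with -inf before the softmax). A sketch of the former under assumed
# semantics, i.e. a softmax renormalized over the valid timesteps:
import torch.nn.functional as F


def masked_normalization(energies, mask):
    """Softmax over the last dimension, renormalized over unmasked positions."""
    scores = F.softmax(energies, dim=-1)
    masked_scores = scores * mask.float()
    return masked_scores / masked_scores.sum(-1, keepdim=True).clamp(min=1e-13)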
def cross_entropy_loss(self, logits, labels, lengths=None):
    """
    Args:
        logits (FloatTensor): batch_size x seq_length x n_classes
        labels (LongTensor): batch_size x seq_length
        lengths (LongTensor): batch_size
    """
    _logits = logits.contiguous().view(-1, logits.size(-1))

    if self.ignore_index >= 0:
        _labels = labels.contiguous().view(-1)
    else:
        assert lengths is not None
        mask = ~sequence_mask(lengths, labels.size(1))
        _labels = labels.masked_fill_(mask, -1).contiguous().view(-1)

    if lengths is None:
        loss = F.cross_entropy(_logits, _labels,
                               ignore_index=self.ignore_index)
        return loss
    else:
        _loss = F.cross_entropy(_logits, _labels,
                                ignore_index=self.ignore_index,
                                reduction='none')
        _loss_per_step = _loss.view(labels.size())
        loss = _loss.sum() / lengths.float().sum()
        return loss, _loss_per_step
def forward(self, src, trg, src_lengths, trg_lengths, **kwargs):
    enc = self.encode(src, src_lengths)

    src_mask = sequence_mask(src_lengths, src.size(1))
    dec_init = self.init_decoder(enc["outputs"], enc["hidden"])
    dec = self.decode(trg, enc["outputs"], dec_init,
                      src_lengths, src_mask, trg_lengths, **kwargs)

    return enc, dec
def forward(self, x, lengths):
    emb = self.embed(x)

    # mask padded + future steps
    pad_mask = sequence_mask(lengths, x.size(1)).unsqueeze(1)
    mask = pad_mask & subsequent_mask(x.size(1)).type_as(pad_mask)

    states = self.encoder(emb, None, mask)[0]
    logits = self.logits(states)

    return {"logits": logits}
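# The forward above combines the padding mask with a causal `subsequent_mask`,
# which is not defined in this excerpt. A sketch following the common
# (1, size, size) lower-triangular convention (an assumption about the actual
# helper):
import torch


def subsequent_mask(size):
    """Causal mask: position i may attend only to positions <= i."""
    return torch.tril(torch.ones(1, size, size, dtype=torch.bool))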
def forward(self, sequence, query, lengths, coverage=None):
    energies = self.score(sequence, query, coverage)

    # construct a mask, based on sentence lengths
    mask = sequence_mask(lengths, energies.size(1))
    scores = masked_normalization_inf(energies, mask)
    # scores = self.masked_normalization(energies, mask)

    contexts = (sequence * scores.unsqueeze(-1)).sum(1)

    return contexts, scores
def forward(self, sequence, lengths):
    # sequence size: batch_size x length x rnn_size
    energies = self.attention(sequence).squeeze()

    # construct a mask, based on sentence lengths
    mask = sequence_mask(lengths, energies.size(1))
    # scores = masked_normalization_inf(energies, mask)
    scores = masked_normalization(energies, mask)
    # scores size: batch_size x length

    contexts = (sequence * scores.unsqueeze(-1)).sum(1)

    return contexts, scores
def decode(self, y, memory, src_mask, y_lengths, **kwargs):
    y_emb = self.embed_tgt(y)

    if y_lengths is None:
        trg_mask = src_mask.new_ones([1, 1, 1])
    else:
        trg_mask = sequence_mask(y_lengths, y.size(1)).unsqueeze(1)

    output, states = self.decoder(trg_embed=y_emb,
                                  encoder_output=memory,
                                  src_mask=src_mask,
                                  trg_mask=trg_mask)[:2]

    return output, states
def _global_prior(logits, word_idx, lengths):
    """
    Evaluate the probability of a sequence under a language model.
    """
    mask = sequence_mask(lengths)
    labels = (word_idx * mask.long()).contiguous().view(-1)
    _logits = logits.contiguous().view(-1, logits.size(-1))
    loss = F.cross_entropy(_logits, labels,
                           ignore_index=0, reduction='none')

    # normalize by length to avoid mode collapse
    total = loss.sum() / mask.float().sum()

    return total, loss.view(mask.size())
def masked_mse(inp_logits, trg_logits, lengths, mask_ids=[]):
    # zero padded timesteps
    mask = sequence_mask(lengths).unsqueeze(-1).float()

    # shape: batch x seq_length x tokens
    loss = F.mse_loss(inp_logits * mask, trg_logits * mask,
                      reduction='none')

    for i in mask_ids:
        loss[:, :, i] = 0

    loss = loss.mean(-1)
    loss = loss * mask.squeeze()
    total_loss = loss.sum() / mask.sum()

    return total_loss, loss
def _ce_loss(logits, labels, lengths, ignore_index=0):
    _logits = logits.contiguous().view(-1, logits.size(-1))

    if ignore_index >= 0:
        _labels = labels.contiguous().view(-1)
    else:
        assert lengths is not None
        mask = ~sequence_mask(lengths, labels.size(1))
        _labels = labels.masked_fill_(mask, -1).contiguous().view(-1)

    _loss = F.cross_entropy(_logits, _labels,
                            ignore_index=ignore_index,
                            reduction='none')
    _loss_per_step = _loss.view(labels.size())
    loss = _loss_per_step.sum(-1) / lengths.float()

    return loss, _loss_per_step
def kl_length(logits, lengths, eos):
    """
    Length control loss, using a sequence of length labels (with eos token).

    Args:
        logits: the decoder logits over the vocabulary (batch x steps x vocab)
        lengths: the desired lengths of the generated sequences
        eos: the index of the EOS token

    Returns:
        the cross-entropy of predicting EOS at the end of each sequence
    """
    mask = sequence_mask(lengths - 1, lengths.max())
    eos_labels = ((1 - mask) * eos).long().contiguous().view(-1)
    _logits = logits.contiguous().view(-1, logits.size(-1))
    loss = F.cross_entropy(_logits, eos_labels, ignore_index=0)
    return loss
def prior_loss(outputs, trg_len, prior, mode, sos_id=1, tau=1, init_h=None):
    # The actual tokens that were used during generating the target seq.
    # When the decoder is trained with 100% teacher forcing,
    # sampled_tokens == trg_inp
    # sample_ids = outputs["dists"].max(-1)[1]

    prior_inps = differentiable_samples(prior.encoder.embed,
                                        outputs["dists"], sos_id)

    if mode == "prior":
        lm_outs = prior(prior_inps, trg_len, init_h)
        loss, loss_i = masked_kld(outputs["logits"], lm_outs["logits"],
                                  trg_len, tau=tau, mask_ids=[0, 1, 2, 3])

    elif mode == "discriminator":
        # feed the embeddings to the LM Discriminator
        lm_outs = prior(prior_inps, trg_len, init_h)
        mask = sequence_mask(trg_len).float()

        # check = F.cross_entropy(
        #     lm_outs["logits"].contiguous().view(-1, lm_outs["logits"].size(-1)),
        #     outputs["dists"].argmax(-1).view(-1), ignore_index=0,
        #     reduction='none')

        prior_log_probs = F.log_softmax(lm_outs["logits"], -1)
        loss_i = dot3D(outputs["dists"].contiguous(),
                       prior_log_probs.contiguous()) * mask
        cross_entropy = loss_i.sum() / mask.sum()

        # avoid collapse
        # agg_logits = outputs["logits"].sum(1) / mask.sum(-1, keepdim=True)
        # entropy = Categorical(logits=agg_logits).entropy().mean()

        loss = -cross_entropy
    else:
        raise ValueError(f"Unknown mode: {mode}")

    return loss, loss_i, lm_outs["logits"]
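# The discriminator branch above scores the generator's soft predictions
# against the prior LM's log-probabilities with `dot3D`, which is not shown in
# this excerpt. A minimal sketch, assuming a per-timestep dot product over the
# vocabulary axis:
def dot3D(a, b):
    """(batch, time, vocab) x (batch, time, vocab) -> (batch, time)."""
    return (a * b).sum(-1)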
def masked_kld(inp_logits, trg_logits, lengths, tau=1, mask_ids=[]):
    """
    Compute the grounding loss using a pretrained "oracle" LM.
    The loss is computed between the posteriors over the vocabulary
    produced by a generator and the posteriors of the "oracle" LM.

    Args:
        inp_logits: the logits of the generator
        trg_logits: the logits of the "oracle" LM
        lengths: the lengths of the target sequence; used for masking the loss
        tau: the temperature of the softmax
        mask_ids: vocabulary ids whose contribution is zeroed out

    Note (debug):
        per-element KLD = -softmax(trg_logits) * log(softmax(inp_logits) / softmax(trg_logits))

    Returns:
        the average KL divergence per timestep (word)
    """
    input_logp = F.log_softmax(inp_logits / tau, -1)
    target_p = F.softmax(trg_logits / tau, -1)

    # zero padded timesteps
    mask = sequence_mask(lengths).unsqueeze(-1).float()

    # shape: batch x seq_length x tokens
    loss = F.kl_div(input_logp * mask, target_p * mask, reduction='none')

    for i in mask_ids:
        loss[:, :, i] = 0

    # sum over the vocabulary (KL per word/timestep!)
    loss = loss.sum(-1)
    loss = loss * mask.squeeze()
    total_loss = loss.sum() / mask.sum()

    return total_loss, loss
def kl_loss(self, logits, labels, lengths):
    """
    Args:
        logits (FloatTensor): batch_size x seq_length x n_classes
        labels (LongTensor): batch_size x seq_length
        lengths (LongTensor): batch_size
    """
    _logits = logits.contiguous().view(-1, logits.size(-1))
    _labels = labels.contiguous().view(-1)

    log_prob = F.log_softmax(_logits, dim=1)

    model_prob = self.one_hot.repeat(_labels.size(0), 1)
    model_prob.scatter_(1, _labels.unsqueeze(1), self.high_confidence)

    losses = F.kl_div(log_prob, model_prob, reduction='none')

    mask = sequence_mask(lengths, labels.size(1)).view(-1).float()
    losses = losses.sum(1) * mask
    loss = losses.sum() / mask.sum()

    return loss, losses
def beam(self, x, x_len, sos_id, eos_id, pad_id,
         beam_size, length_penalty, **kwargs):
    enc = self.encode(x, x_len)
    dec_init = self.init_decoder(enc["outputs"], enc["hidden"])
    src_mask = sequence_mask(x_len, x.size(1))

    outputs = beam_search(decoder=self.decoder,
                          size=beam_size,
                          bos_index=sos_id,
                          eos_index=eos_id,
                          pad_index=pad_id,
                          encoder_output=enc["outputs"],
                          encoder_hidden=dec_init,
                          src_mask=src_mask,
                          max_output_length=(x_len.float() * 1.5).long().max(),
                          alpha=length_penalty,
                          lm_hidden=None,
                          **kwargs)

    return outputs
def translate(self, x, x_lengths, sos_id, y_lengths=None, **kwargs):
    enc = self.encode(x, x_lengths)
    dec_init = self.init_decoder(enc["outputs"], enc["hidden"])

    # Set the target length larger than the source.
    # It will be pruned after the EOS anyway.
    if y_lengths is None:
        y_lengths = (x_lengths.float() * 1.5).long()

    src_mask = sequence_mask(x_lengths, x.size(1))
    inp_fake = fake_inputs(x, y_lengths, sos_id)
    dec = self.decode(inp_fake, enc["outputs"], dec_init,
                      x_lengths, src_mask, y_lengths,
                      sampling=1, sampling_mode="argmax", **kwargs)

    return enc, dec
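# A hedged usage sketch of the greedy `translate` path above; the concrete
# `model`, source batch `x` / `x_lengths` and the BOS id are placeholders for
# illustration. The decoder output dict exposes "logits", as also used by
# `process_batch` below.
#
#   enc, dec = model.translate(x, x_lengths, sos_id=1)
#   hyp_ids = dec["logits"].argmax(-1)   # greedy token ids per timestep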
def process_batch(self, x_sos, x_eos, x_len, y_sos, y_eos, y_len, **kwargs):
    """
    The inputs will be the following, assuming this pair of sentences:

    x = ['<sos>', 'every', 'clever', 'cat', 'hates', 'every', 'dog', '<eos>']
    y = ['<sos>', 'κάθε', 'έξυπνη', 'γάτα', 'μισεί', 'κάθε', 'σκύλο', '<eos>']

    Args:
        x_sos: ['<sos>', 'every', 'clever', 'cat', 'hates', 'every', 'dog']
        x_eos: ['every', 'clever', 'cat', 'hates', 'every', 'dog', '<eos>']
        x_len: 7
        y_sos: ['<sos>', 'κάθε', 'έξυπνη', 'γάτα', 'μισεί', 'κάθε', 'σκύλο']
        y_eos: ['κάθε', 'έξυπνη', 'γάτα', 'μισεί', 'κάθε', 'σκύλο', '<eos>']
        y_len: 7

    Note:
        _sos will be the input to decoders
        _eos will be the input to encoders and target for decoders

    Returns:
        a dict with the losses and a dict with the model outputs
    """
    decoding = dict(self.config["model"].get("decoding", {}))
    if decoding.get("fusion") is not None:
        decoding["lm"] = self.prior

    outputs = self.model(x_eos, y_sos, x_len, y_len, **decoding)

    losses = dict()
    is_gpt2 = self.get_vocab()[1].is_gpt2

    # Loss calculation
    losses["mt"] = self.criterion(outputs[1]["logits"], y_eos, y_len)[0]

    if "prior" in self.config["losses"] and self.prior is not None:
        f_reg = self.config["losses"]["prior"].get("objective", "kl")

        if f_reg == "mse":
            lm_logits = self.prior(y_sos, y_len)["logits"]
            prior_loss, prior_loss_i = masked_mse(outputs[1]["logits"],
                                                  lm_logits, y_len)

        elif f_reg in ["kl", "rkl"]:
            _tau = self.config["losses"]["prior"]["tau"]

            if is_gpt2:
                _mask = sequence_mask(y_len, y_sos.size(1)).float()
                lm_logits = self.prior(y_sos, attention_mask=_mask)[0]
            else:
                lm_logits = self.prior(y_sos, y_len)["logits"]

            if f_reg == "kl":
                # KL(p_prior, p_model)
                prior_loss, prior_loss_i = masked_kld(
                    outputs[1]["logits"], lm_logits, y_len, _tau)
            else:
                # rkl: KL(p_model, p_prior)
                prior_loss, prior_loss_i = masked_kld(
                    lm_logits, outputs[1]["logits"], y_len, _tau)

            # multiply with tau^2 to make the loss tau-invariant
            prior_loss = prior_loss * (_tau ** 2)

        elif f_reg == "ppl":
            prob_tm = relax_softmax(outputs[1]["logits"],
                                    tau=1, gumbel=False, hard=False)
            prior_inps = differentiable_samples(self.prior.encoder.embed,
                                                prob_tm, 1)
            lm_logits = self.prior(prior_inps, y_len)["logits"]

            mask = sequence_mask(y_len).float()
            prior_log_probs = F.log_softmax(lm_logits, -1)
            loss_i = dot3D(prob_tm.contiguous(),
                           prior_log_probs.contiguous()) * mask
            cross_entropy = loss_i.sum() / mask.sum()
            prior_loss = -cross_entropy

        else:
            raise ValueError(f"Unknown prior objective: {f_reg}")

        losses["prior"] = prior_loss

    return losses, {'model_outputs': outputs}
def encode(self, x, lengths, **kwargs):
    emb = self.embed_src(x)
    pad_mask = sequence_mask(lengths, x.size(1)).unsqueeze(1)
    memory = self.encoder(emb, None, pad_mask)[0]
    return memory, pad_mask