Ejemplo n.º 1
0
def get_ce_loss(log_probs: FT, batch: OnePairBatch, agg='all') -> FT:
    """Compute the masked cross-entropy loss of the gold target sequences.

    The negative log-probability of every gold target id is gathered from
    ``log_probs`` along its ``'unit'`` dimension and multiplied by
    ``batch.tgt_seqs.paddings`` (cast to float). Positions where that mask is
    zero therefore contribute nothing. Presumably the mask is nonzero at real,
    non-pad positions; confirm against the batch construction.

    Args:
        log_probs: log-probabilities with a ``'unit'`` named dimension. The
            per-sequence reductions also need a ``'pos'`` dimension.
        batch: batch whose ``tgt_seqs`` supplies ``ids`` and ``paddings``.
        agg: how to reduce the masked per-position losses:
            ``'all'`` (default): sum over every position;
            ``'batch'``: sum over ``'pos'``;
            ``'batch_mean'``: sum over ``'pos'`` divided by the mask sum over
            ``'pos'``;
            ``'char'``: no reduction;
            ``'char_mean'``: total sum divided by the total mask sum.

    Returns:
        The loss tensor, reduced according to ``agg``.

    Raises:
        ValueError: if ``agg`` is not one of the values listed above.

    NOTE(review): in the ``*_mean`` modes, a zero mask sum makes the float
    division produce nan/inf.
    """
    gold_ids = batch.tgt_seqs.ids
    nll = -log_probs.gather('unit', gold_ids)
    mask = batch.tgt_seqs.paddings.float()
    # Zero out the loss wherever the mask is zero.
    masked = nll * mask

    if agg == 'batch':
        return masked.sum(dim='pos')
    if agg == 'batch_mean':
        per_seq_total = masked.sum(dim='pos')
        per_seq_count = mask.sum(dim='pos')
        return per_seq_total / per_seq_count
    if agg == 'all':
        return masked.sum()
    if agg == 'char':
        return masked
    if agg == 'char_mean':
        return masked.sum() / mask.sum()
    raise ValueError(f'Unrecognized value "{agg}" for agg.')
Ejemplo n.º 2
0
 def search_by_probs(self, lengths: LT,
                     label_log_probs: FT) -> Tuple[LT, FT]:
     """Exhaustively score every B/I/O tag sequence up to the longest length.

     Builds all ``3 ** max(lengths)`` tag sequences over ``(B, I, O)`` and
     shares that candidate set across every batch element. For each
     candidate, it sums the per-position log-probabilities taken from
     ``label_log_probs``, counting only positions inside each element's
     length. ``B``/``I``/``O`` are presumably module-level tag-id constants;
     confirm.

     NOTE(review): the number of candidates grows exponentially with the
     longest length, so this is only practical for short sequences.

     Args:
         lengths: one length per batch element. The mask built from it is
             named ``('batch', 'length')``.
         label_log_probs: log-probabilities with ``'batch'`` and ``'label'``
             named dims.

     Returns:
         ``(samples, sample_log_probs)``. ``samples`` holds the candidate tag
         ids with dims ``('batch', 'sample', 'length')``.
         ``sample_log_probs`` holds their masked log-probability totals, with
         ``'length'`` summed out.
     """
     longest = lengths.max().item()
     # Every tag sequence of length `longest` over the alphabet (B, I, O).
     all_seqs = list(product([B, I, O], repeat=longest))
     candidates = get_tensor(torch.LongTensor(all_seqs))
     candidates.rename_('sample', 'length')
     batch_size = label_log_probs.size('batch')
     # Add a leading 'batch' dim, then broadcast the same candidate set to
     # every batch element.
     candidates = candidates.align_to('batch', 'sample', 'length')
     candidates = candidates.expand(batch_size, -1, -1)
     per_pos = label_log_probs.gather('label', candidates)
     with NoName(lengths):
         raw_mask = get_length_mask(lengths, longest)
         mask = raw_mask.rename('batch', 'length')
     mask = mask.align_to(per_pos)
     # Drop positions past each element's length before summing.
     totals = (per_pos * mask.float()).sum(dim='length')
     return candidates, totals
Ejemplo n.º 3
0
    def search(self,
               lengths: LT,
               label_log_probs: FT,
               gold_tag_seqs: Optional[LT] = None) -> Tuple[LT, FT]:
        """Score all candidate tag sequences, optionally prepending the gold one.

        The exhaustive candidates and their scores come from
        ``self.search_by_probs``. When ``gold_tag_seqs`` is given, it is first
        aligned to the candidates' named dims. It is then scored the same way:
        per-position log-probs gathered from ``label_log_probs`` along
        ``'label'``, masked by ``lengths``, and summed over ``'length'``.
        Finally it is concatenated in front of the candidates along
        ``'sample'``.

        Args:
            lengths: one length per batch element (the mask is named
                ``('batch', 'length')``).
            label_log_probs: log-probabilities with a ``'label'`` named dim.
            gold_tag_seqs: optional gold tag ids. Presumably
                ``('batch', 'length')`` before alignment; confirm.

        Returns:
            ``(samples, sample_log_probs)``. When ``gold_tag_seqs`` is given,
            the gold entries come first along ``'sample'``.
        """
        samples, sample_log_probs = self.search_by_probs(lengths, label_log_probs)
        if gold_tag_seqs is None:
            return samples, sample_log_probs

        gold = gold_tag_seqs.align_as(samples)
        longest = lengths.max().item()
        with NoName(lengths):
            raw_mask = get_length_mask(lengths, longest)
            mask = raw_mask.rename('batch', 'length')
        gold_per_pos = label_log_probs.gather('label', gold)
        # Only positions inside each element's length count toward the score.
        masked_gold = gold_per_pos * mask.align_as(gold_per_pos)
        gold_scores = masked_gold.sum('length')

        # Gold goes first along 'sample', followed by the exhaustive candidates.
        merged_samples = torch.cat([gold, samples], dim='sample')
        merged_scores = torch.cat([gold_scores, sample_log_probs],
                                  dim='sample')
        return merged_samples, merged_scores