from itertools import product
from typing import Optional, Tuple

import torch

# FT/LT (named float/long tensor aliases), OnePairBatch, get_tensor, NoName,
# get_length_mask, and the B/I/O label constants are assumed to be provided by
# the project's own modules.


def get_ce_loss(log_probs: FT, batch: OnePairBatch, agg='all') -> FT:
    """Compute cross-entropy loss against the gold target units, aggregated per `agg`."""
    # Pick out the log-probability of each gold unit at each position.
    ce_losses = -log_probs.gather('unit', batch.tgt_seqs.ids)
    # `paddings` is used as a float mask so padded positions contribute zero loss.
    weights = batch.tgt_seqs.paddings.float()
    ce_losses = ce_losses * weights
    if agg == 'batch':
        return ce_losses.sum(dim='pos')
    elif agg == 'batch_mean':
        return ce_losses.sum(dim='pos') / weights.sum(dim='pos')
    elif agg == 'all':
        return ce_losses.sum()
    elif agg == 'char':
        return ce_losses
    elif agg == 'char_mean':
        return ce_losses.sum() / weights.sum()
    else:
        raise ValueError(f'Unrecognized value "{agg}" for agg.')
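# A minimal, self-contained sketch (plain PyTorch, no named tensors or project
# helpers) of what `get_ce_loss` computes for agg='batch_mean'. The shapes and
# the padding convention (1.0 = real token, 0.0 = padding) are assumptions here,
# not taken from the project.
def _ce_loss_sketch() -> None:
    pos, bs, units = 5, 2, 10
    log_probs = torch.log_softmax(torch.randn(pos, bs, units), dim=-1)
    tgt_ids = torch.randint(units, (pos, bs))
    # 1.0 = real token, 0.0 = padding (the second sequence has length 2).
    weights = torch.tensor([[1., 1.], [1., 1.], [1., 0.], [1., 0.], [1., 0.]])
    # Gather the gold-unit log-probs, negate, and mask out padded positions.
    ce = -log_probs.gather(-1, tgt_ids.unsqueeze(-1)).squeeze(-1) * weights
    per_seq_mean = ce.sum(dim=0) / weights.sum(dim=0)  # agg='batch_mean'
    print(per_seq_mean.shape)  # torch.Size([2])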
def search_by_probs(self, lengths: LT, label_log_probs: FT) -> Tuple[LT, FT]:
    """Exhaustively score every B/I/O tag sequence up to the max length."""
    max_length = lengths.max().item()
    # Enumerate all 3 ** max_length candidate label sequences.
    samples = get_tensor(
        torch.LongTensor(list(product([B, I, O], repeat=max_length))))
    samples.rename_('sample', 'length')
    bs = label_log_probs.size('batch')
    samples = samples.align_to('batch', 'sample', 'length').expand(bs, -1, -1)
    sample_log_probs = label_log_probs.gather('label', samples)
    with NoName(lengths):
        length_mask = get_length_mask(lengths, max_length).rename(
            'batch', 'length')
    # Mask out positions beyond each sequence's true length before summing.
    length_mask = length_mask.align_as(sample_log_probs)
    sample_log_probs = (sample_log_probs * length_mask.float()).sum(dim='length')
    return samples, sample_log_probs
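# A hedged sketch of the exhaustive search above using plain (unnamed) tensors:
# enumerate all 3 ** max_length taggings, score each against per-position label
# log-probs, and zero out positions past each sequence's true length. Using the
# label ids 0/1/2 to stand in for B/I/O is an assumption for illustration.
def _search_by_probs_sketch() -> None:
    bs, max_length, n_labels = 2, 3, 3
    label_log_probs = torch.log_softmax(
        torch.randn(bs, max_length, n_labels), dim=-1)
    lengths = torch.tensor([3, 2])
    # All 3 ** max_length candidate taggings, shared across the batch.
    samples = torch.tensor(list(product(range(n_labels), repeat=max_length)))
    n_samples = samples.size(0)
    samples = samples.unsqueeze(0).expand(bs, -1, -1)  # (batch, sample, length)
    scores = label_log_probs.unsqueeze(1).expand(
        bs, n_samples, max_length, n_labels).gather(
            3, samples.unsqueeze(-1)).squeeze(-1)      # (batch, sample, length)
    # Zero out positions past each true length, then sum per candidate.
    length_mask = (torch.arange(max_length) < lengths.unsqueeze(-1)).float()
    scores = (scores * length_mask.unsqueeze(1)).sum(dim=-1)  # (batch, sample)
    print(scores.shape)  # torch.Size([2, 27])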
def search(self,
           lengths: LT,
           label_log_probs: FT,
           gold_tag_seqs: Optional[LT] = None) -> Tuple[LT, FT]:
    """Search over tag sequences, optionally prepending the gold sequence."""
    samples, sample_log_probs = self.search_by_probs(lengths, label_log_probs)
    if gold_tag_seqs is not None:
        gold_tag_seqs = gold_tag_seqs.align_as(samples)

        max_length = lengths.max().item()
        with NoName(lengths):
            length_mask = get_length_mask(lengths, max_length).rename(
                'batch', 'length')
        # Score the gold sequence the same way the enumerated samples are scored.
        gold_log_probs = label_log_probs.gather('label', gold_tag_seqs)
        gold_log_probs = (
            gold_log_probs * length_mask.align_as(gold_log_probs)).sum('length')

        # Prepend the gold sequence so it is always among the candidates.
        samples = torch.cat([gold_tag_seqs, samples], dim='sample')
        sample_log_probs = torch.cat([gold_log_probs, sample_log_probs],
                                     dim='sample')
    return samples, sample_log_probs
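# A small sketch of the gold-prepending behavior in `search`: after the two
# concatenations, the gold tagging and its score sit at sample index 0, so
# downstream code can treat column 0 as the reference. All tensors below are
# hypothetical stand-ins, not project data.
def _search_gold_sketch() -> None:
    bs, n_samples, length = 2, 4, 3
    samples = torch.randint(3, (bs, n_samples, length))
    sample_log_probs = torch.randn(bs, n_samples)
    gold = torch.randint(3, (bs, 1, length))
    gold_log_probs = torch.randn(bs, 1)
    samples = torch.cat([gold, samples], dim=1)
    sample_log_probs = torch.cat([gold_log_probs, sample_log_probs], dim=1)
    assert torch.equal(samples[:, 0], gold[:, 0])  # gold occupies sample index 0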