def forward(self, input, conds):
    """Decode ``conds[0]`` and centre-crop the result to the input's size.

    ``conds[0]`` is flattened over its trailing dimensions, transposed and
    passed through ``self.encode``. The result is regrouped into 64
    channels and decoded, optionally conditioned on ``conds[1]``. The
    decoder output is split into ``self.quantizer.num_levels`` logits per
    position, permuted, and centre-cropped along dim 2 to ``input.shape[2]``.

    Raises:
        AssertionError: if the reconstruction is smaller than the input
            along dim 2.
    """
    first = conds[0]
    flat = first.view(first.size(0), first.size(1), -1).permute(0, 2, 1)
    z = self.encode(flat)
    # Split dim 1 into 64 groups, then swap the last two axes.
    z = z.view(z.size(0), 64, -1, z.size(2)).permute(0, 1, 3, 2)

    cond = None
    if len(conds) > 1:
        # Squeeze out the width and height: this conditioning is assumed
        # to be global.
        cond = safe_squeeze(safe_squeeze(conds[1], 1), 1)

    out = self.decode(z, cond=cond)
    out = out.view(out.size(0), -1, self.quantizer.num_levels,
                   out.size(2), out.size(3))
    out = out.permute(0, 3, 4, 1, 2)

    # Centre-crop the reconstruction along dim 2 to the input's size.
    _, _, target_h, *_ = input.shape
    _, _, have_h, *_ = out.shape
    assert have_h >= target_h, f"The reconstruction ({out.shape}) must be as large as the input {input.shape}."
    start = (have_h - target_h) // 2
    return out[:, :, start:start + target_h, :]
def minibatch_loss_and_tokens(self, batch):
    """Compute the annealed negative ELBO (plus optional CPC loss) for a batch.

    Only the first channel of ``batch["features"]`` is modelled. If
    ``batch["features_len"]`` is absent, every utterance is taken to span
    the full time axis.

    Returns:
        (loss, details, None). ``details`` holds per-frame 'neg_llk', 'elbo'
        and 'kl', plus 'cpc_loss' and 'annealing_factor'. No tokens are
        produced, hence the trailing ``None``.
    """
    kl_mult = self.compute_kl_mult()
    self.batch_id += 1
    if not self.training:
        self.batch_id = 0

    feats = batch["features"][..., :1]
    n_utts, n_frames, _, _ = feats.size()
    if "features_len" in batch:
        feat_lens = batch["features_len"]
    else:
        feat_lens = torch.tensor([n_frames] * n_utts, dtype=torch.long,
                                 device=feats.device)

    feats = self.input_layer(feats)
    frame_mask = utils.get_mini_batch_mask(feats, feat_lens).to(
        feats.device.type)

    pre_bottleneck_h, z, latent_loss, _, z_lens = self.encode(
        feats, feat_lens)
    enc_gru, gru_lens = self.CPCgru(pre_bottleneck_h, z_lens)
    cpc_loss = None
    if self.cpc is not None:
        cpc_loss = self.cpc.cpc_loss(
            gru_input_feats=pre_bottleneck_h,
            gru_output_feats=enc_gru,
            feats_len=gru_lens,
        )[0]

    mixed_z = self.mix_latents(z)
    # Expand tensors to the input rate. The pre/post bottleneck results are
    # discarded; presumably these calls are made for their side effects
    # (TODO confirm).
    self.pre_bottleneck(self.upsample(feats, pre_bottleneck_h))
    self.post_bottleneck(self.upsample(feats, z))
    mixed_z = self.post_latent_mixer(self.upsample(feats, mixed_z))

    llk_fn = self.likelihood(utils.safe_squeeze(mixed_z, dim=-2))
    masked_llk = llk_fn.log_prob(
        utils.safe_squeeze(feats, dim=-1)) * frame_mask.unsqueeze(-1)
    llk = masked_llk.sum((1, 2))
    elbo = llk - latent_loss
    loss = -(llk - kl_mult * latent_loss).mean()
    if cpc_loss is not None:
        loss += cpc_loss

    n_valid = float(torch.sum(frame_mask))
    if not self.training:
        # NOTE(review): only in eval mode is the loss normalised per frame;
        # in training it is a per-utterance mean. Confirm this is intended.
        loss /= n_valid

    details = {
        "neg_llk": -llk.sum() / n_valid,
        "elbo": elbo.sum() / n_valid,
        "kl": latent_loss.sum() / n_valid,
        "cpc_loss": cpc_loss,
        "annealing_factor": kl_mult
    }
    return loss, details, None
def minibatch_loss_and_tokens(self, batch):
    """Compute the annealed negative ELBO of reconstructing a batch field.

    The model encodes the first channel of ``batch["features"]``. It
    reconstructs ``batch[self.reconstruction_field]``, or the transformed
    features themselves when that field is ``'features'``. A missing
    ``<field>_len`` entry means every utterance spans that field's full
    time axis.

    Returns:
        (loss, details, None). ``details`` holds per-frame 'neg_llk', 'elbo'
        and 'kl', plus 'annealing_factor'. No tokens are produced.
    """
    kl_mult = self.compute_kl_mult()
    self.batch_id += 1
    if not self.training:
        self.batch_id = 0

    x = batch["features"][..., :1]
    b, t, _, _ = x.size()
    if "features_len" in batch:
        x_lens = batch["features_len"]
    else:
        x_lens = torch.tensor([t] * b, dtype=torch.long, device=x.device)
    x = self.input_layer(x)

    k = self.reconstruction_field
    if k == 'features':
        y = x
        y_lens = x_lens
    else:
        y = batch[k][..., :1]
        y_b, y_t, _, _ = y.size()
        if f"{k}_len" in batch:
            y_lens = batch[f"{k}_len"]
        else:
            # Default to the target's own length. It need not match the
            # input's time length.
            y_lens = torch.tensor([y_t] * y_b, dtype=torch.long,
                                  device=y.device)
    y_mask = utils.get_mini_batch_mask(y, y_lens).to(y.device.type)

    # Some encode() variants also return latent lengths as a 5th value;
    # only the first four are needed here.
    pre_bottleneck_h, z, latent_loss, _info = self.encode(x, x_lens)[:4]
    mixed_z = self.mix_latents(z)

    # Expand tensors to the input rate. The pre/post bottleneck results are
    # discarded; presumably these calls are made for their side effects
    # (TODO confirm).
    _ = self.pre_bottleneck(self.upsample(x, pre_bottleneck_h))
    _ = self.post_bottleneck(self.upsample(x, z))
    mixed_z = self.post_latent_mixer(self.upsample(x, mixed_z))

    llk_fn = self.likelihood(utils.safe_squeeze(mixed_z, dim=-2))
    llk = llk_fn.log_prob(utils.safe_squeeze(
        y, dim=-1)) * y_mask.unsqueeze(-1)
    llk = llk.sum((1, 2))
    elbo = llk - latent_loss
    annealed_elbo = llk - kl_mult * latent_loss
    loss = annealed_elbo.mean().mul_(-1)

    total_time_steps_batch = float(torch.sum(y_mask))
    details = {
        "neg_llk": -llk.sum() / total_time_steps_batch,
        "elbo": elbo.sum() / total_time_steps_batch,
        "kl": latent_loss.sum() / total_time_steps_batch,
        "annealing_factor": kl_mult
    }
    return loss, details, None
def encode(self, x, x_lens):
    """Run the encoder followed by the bottleneck.

    Returns:
        (pre_bottleneck_h, z, latent_loss, info, z_lens). ``z_lens`` is the
        encoder mask summed over dim 1.
    """
    hidden, valid = self.encoder(x, x_lens)
    lengths = valid.sum(1)
    bottleneck_in = utils.safe_squeeze(hidden, dim=-2)
    z, latent_loss, info = self.bottleneck(bottleneck_in, mask=valid)
    return hidden, z, latent_loss, info, lengths
def forward(self, input, conds):
    """Decode ``conds[0]``, optionally conditioned on ``conds[1]``.

    ``input`` is ignored. ``conds[0]`` is flattened over its trailing
    dimensions, transposed and passed through ``self.conv``. The result is
    regrouped into 64 channels and decoded. The decoder output is permuted
    with ``(0, 3, 2, 1)`` and gains a singleton axis at position 3.
    """
    del input  # unused
    first = conds[0]
    flat = first.view(first.size(0), first.size(1), -1).permute(0, 2, 1)
    z = self.conv(flat)
    z = z.view(z.size(0), 64, -1, z.size(2))

    cond = None
    if len(conds) > 1:
        # Squeeze out the width and height: this conditioning is assumed
        # to be global.
        cond = safe_squeeze(safe_squeeze(conds[1], 1), 1)

    decoded = self.decode(z, cond=cond)
    return decoded.permute(0, 3, 2, 1).unsqueeze(3)
def minibatch_loss(self, batch):
    """Compute the CTC loss and greedy-decode CER, plus an optional adversarial term.

    Returns:
        (loss, details). ``details`` holds 'cer' and 'main_loss' (the CTC
        loss alone). When ``self.adversarial`` is set, it also holds the
        'adv_friend_loss', 'adv_adv_loss' and 'adv_acc' entries.
    """
    log_probs, log_prob_lens = self(batch['features'], batch['features_len'])
    targets = batch['targets'].int()
    targets_len = batch['targets_len']

    # (bs x t x 1 x nc) -> (t x bs x nc)
    log_probs = utils.safe_squeeze(log_probs, 2).permute(1, 0, 2)
    n_utts = log_prob_lens.size(0)
    ctc_loss = self.ctc(log_probs, targets, log_prob_lens, targets_len) / n_utts

    hypotheses = utils.greedy_ctc_decode(log_probs, log_prob_lens)
    references = [
        row[:length]
        for row, length in zip(targets.to('cpu').numpy(), targets_len)
    ]
    cer = utils.error_rate(hypotheses, references)

    details = {'cer': torch.tensor(cer), 'main_loss': ctc_loss}
    total = ctc_loss
    if self.adversarial is not None:
        friend_loss, adv_loss, adv_details = self.adversarial.loss(
            batch['spkid'])
        # Only the friend term is added to the returned loss; adv_loss is
        # reported in details but not added here.
        total = total + friend_loss
        details['adv_friend_loss'] = friend_loss
        details['adv_adv_loss'] = adv_loss
        details['adv_acc'] = adv_details['acc']
    return total, details
def decode(self, batch):
    """Greedy-CTC decode a batch and pass the results to the dataset for printing.

    Targets are used when the batch carries both 'targets' and
    'targets_len'. Otherwise recognition runs forward-only and ``None`` is
    passed for both.
    """
    # Call forward() on this model
    log_probs, log_prob_lens = self(batch['features'], batch['features_len'])
    # (bs x t x 1 x nc) --> (t x bs x nc)
    log_probs = utils.safe_squeeze(log_probs, 2).permute(1, 0, 2)

    # 'decodes': best-path tokens per sample, blanks removed.
    # 'decodesWithBlanks': the same paths with blanks kept, for path
    # generation and processing.
    decodes, decodesWithBlanks = utils.greedy_ctc_decode(
        log_probs, log_prob_lens, return_raw=True)

    # The first utterance must span the whole time axis of log_probs.
    szLongestProbSequence = log_prob_lens[0].item()
    assert len(decodesWithBlanks[0]) == szLongestProbSequence
    assert len(decodesWithBlanks[0]) == log_probs.shape[0]
    assert szLongestProbSequence == log_probs.shape[0]

    try:
        # Some datasets are fully transcribed. Use if available.
        targets = batch['targets'].int()
        targets_len = batch['targets_len']
    except KeyError:
        # No targets: run forward-only recognition. Only a missing key is
        # expected here; any other error propagates instead of being hidden.
        targets = None
        targets_len = None

    # Pretty print of paths, strings and meanings
    self.dataset.decode(self.aligner, decodesWithBlanks, decodes, log_probs,
                        log_prob_lens, targets, targets_len, batch,
                        self.verbose)
def loss(self, logits, targets):
    """Compute the per-element L1 distance between VGG features of prediction and target.

    Both tensors are permuted with ``(0, 3, 2, 1)`` and expanded to 3
    channels (``expand`` requires the channel axis to have size 1). They
    are then rescaled with ``t * 2 - 1`` and passed through ``self.vgg``.
    """
    preds = safe_squeeze(logits, -1).permute(0, 3, 2, 1)
    batch, _, height, width = preds.shape
    rgb_shape = (batch, 3, height, width)
    preds = preds.expand(*rgb_shape)
    refs = targets.permute(0, 3, 2, 1).expand(*rgb_shape)
    pred_feats = self.vgg(preds * 2 - 1)
    ref_feats = self.vgg(refs * 2 - 1)
    return F.l1_loss(pred_feats, ref_feats, reduction='none')
def forward(self, x):
    """Project ``x`` with ``self.conv``, reshape it to an image and run ``self.conv_stack``.

    Per the original comment, x is (bsz x dim x t). The stack output is
    split into ``self.quantizer.num_levels`` values per position, permuted
    with ``(0, 4, 3, 1, 2)``, and ``safe_squeeze`` is applied to the last
    axis.
    """
    h = self.conv(x.permute(0, 2, 1))
    image = h.view(h.size(0), self.hid_channels, self.image_height,
                   h.size(-1))
    out = self.conv_stack(image)
    out = out.view(out.size(0), -1, self.quantizer.num_levels,
                   out.size(2), out.size(3))
    # (bs x 1 x 1 x h x t) -> (bs x t x 1 x h x 1)
    # NOTE(review): the original author was unsure whether this squeeze
    # leaves the channel axis as intended ('3').
    return utils.safe_squeeze(out.permute(0, 4, 3, 1, 2), -1)
def forward(self, x, conds=()):
    """
    x: BS x Dim x T
    conds: list of BS x DimC x T/k

    Runs the residual/skip stack once per context ('past', plus 'future'
    when ``self.bidirectional``) on a time-shifted copy of ``x``. The skip
    outputs of all contexts are accumulated, then passed through
    ``self.skip_to_out`` (each preceded by ReLU and dropout).
    """
    x_skip = 0
    if self.ahead_corruption is not None:
        # Keep each element of x with probability 1 - ahead_corruption.
        # The Bernoulli parameter has shape (1,), so samples gain a
        # trailing axis, which is squeezed away.
        ber = torch.distributions.bernoulli.Bernoulli(
            torch.tensor([1.0 - self.ahead_corruption], device=x.device))
        mask = utils.safe_squeeze(ber.sample(sample_shape=x.size()), -1)
        x_corrupt = x * mask
    if self.ahead_fraction is not None:
        # Random shift: 0 frames with prob 1 - ahead_fraction, otherwise
        # uniformly 1..ahead_frames.
        # NOTE(review): divides by ahead_frames, so ahead_frames must be > 0
        # here.
        probs = (np.ones((self.ahead_frames + 1, ), dtype=np.float32) *
                 self.ahead_fraction / self.ahead_frames)
        probs[0] = 1.0 - self.ahead_fraction
        nframes = np.random.choice(self.ahead_frames + 1, p=probs)
    else:
        nframes = self.ahead_frames
    contexts = ('past', 'future') if self.bidirectional else ('past', )
    for ctx in contexts:
        if nframes == 0:
            x_shift = x
        elif ctx == 'past':
            # Apply padding on the time axis (dim=2)
            x_shift = F.pad(x, (nframes, 0))[:, :, :-nframes]
        elif ctx == 'future':
            x_shift = F.pad(x, (0, nframes))[:, :, nframes:]
        if self.ahead_corruption is not None:
            # Stack on the dim axis
            x_shift = torch.cat([x_shift, x_corrupt], dim=1)
        if ctx == 'future':
            # Reverse time for the future context.
            # NOTE(review): the resulting skip contribution is not flipped
            # back before being summed with the 'past' one, and conds are
            # not flipped either. Confirm this is intended.
            x_shift = torch.flip(x_shift, dims=[2])
        x_res = self.x_to_res(x_shift)
        x_skip += self.res_to_skip(x_res)
        for res_to_hid, hid_to_skip, hid_to_res in zip(
                self.res_to_hid, self.hid_to_skip, self.hid_to_res):
            x_hid = res_to_hid(x_res, conds)
            # Positional args: p=self.dropout, training=self.training,
            # inplace=True.
            x_hid = F.dropout(x_hid, self.dropout, self.training, True)
            x_skip += hid_to_skip(x_hid)
            if hid_to_res is None:
                x_res = None  # We don't use the last residual output
            else:
                x_res = x_res + hid_to_res(x_hid)
    for skip_to_out in self.skip_to_out:
        x_skip = torch.relu(x_skip)
        x_skip = F.dropout(x_skip, self.dropout, self.training, True)
        x_skip = skip_to_out(x_skip)
    return x_skip
def forward(self, x, conds=()):
    """Run the WaveNet with the height axis used as channels.

    x: BS x T x H x 1
    conds: list of BS x DimC x T/k
    Returns the WaveNet output transposed back and reshaped to
    (BS, T, -1, 1, self.quantizer.num_levels).
    """
    # BS x T x H x 1 -> BS x H x T: height becomes the channel axis.
    as_channels = safe_squeeze(x, 3).transpose(1, 2)
    out = self.wave_net.forward(as_channels, conds).transpose(1, 2)
    batch, steps = out.size(0), out.size(1)
    # Add the channel axis and split the last axis into per-level logits.
    return out.reshape([batch, steps, -1, 1, self.quantizer.num_levels])
def forward(self, x, conds):
    """Add each projected conditioning tensor to ``x``.

    x is BS x C x T (optionally x H). Each entry of ``conds`` is
    BS x T' x 1 x C'. Each one is squeezed and transposed to BS x C' x T',
    passed through its matching ``self.cond_convs`` entry, and repeated
    along time by ``T // T'``. It gets a trailing axis when x is 4-D, and
    is then added to x.
    """
    assert len(conds) == len(self.cond_convs)
    _, _, steps = x.shape[:3]
    for project, cond in zip(self.cond_convs, conds):
        _, cond_steps, _, _ = cond.size()
        projected = project(safe_squeeze(cond, 2).permute(0, 2, 1))
        # Upsample along time by repetition to match x.
        projected = projected.repeat_interleave(steps // cond_steps, 2)
        if x.dim() == 4:
            projected = projected.unsqueeze(3)
        x = x + projected
    return x
def minibatch_loss(self, batch):
    """Compute the CTC loss with greedy decoding, dataset pretty-printing and CER.

    Returns:
        (loss, {'cer': tensor}), where loss is the CTC loss divided by the
        batch size.
    """
    # Call forward() on this model
    log_probs, log_prob_lens = self(batch['features'], batch['features_len'])
    targets = batch['targets'].int()
    targets_len = batch['targets_len']

    # (bs x t x 1 x nc) -> (t x bs x nc)
    log_probs = utils.safe_squeeze(log_probs, 2).permute(1, 0, 2)
    n_utts = log_prob_lens.size(0)
    loss = self.ctc(log_probs, targets, log_prob_lens, targets_len) / n_utts

    # 'decodes': best-path tokens per sample, blanks removed.
    # 'decodesWithBlanks': the same paths with blanks kept.
    decodes, decodesWithBlanks = utils.greedy_ctc_decode(
        log_probs, log_prob_lens, return_raw=True)

    # The first utterance must span the whole time axis of log_probs.
    longest = log_prob_lens[0].item()
    first_raw_len = len(decodesWithBlanks[0])
    assert first_raw_len == longest
    assert first_raw_len == log_probs.shape[0]
    assert longest == log_probs.shape[0]

    # Pretty-print paths, strings and meanings; also writes the path to an
    # output file if requested.
    self.dataset.decode(self.aligner, decodesWithBlanks, decodes, log_probs,
                        log_prob_lens, targets, targets_len, batch,
                        self.verbose)

    # Levenshtein character (or label) error rate on the blank-free strings.
    references = [row[:n] for row, n in zip(targets.to('cpu'), targets_len)]
    cer = utils.error_rate(decodes, references)
    return loss, {'cer': torch.tensor(cer)}
def retrieve_saved_input(self):
    """Return the saved 'indices' tensor squeezed to (bsz x L) and clear ``self.input``."""
    indices = self.input['indices']
    # Squeeze the last axis twice.
    for _ in range(2):
        indices = utils.safe_squeeze(indices, -1)
    self.input = None
    return indices
def loss(self, features, targets, features_len=None, targets_len=None):
    """Compute the framewise cross-entropy of the captured hidden representation.

    ``self(self.input)`` is repeated along time to the feature rate and
    truncated to the target length. It is then scored against ``targets``
    with per-frame cross-entropy.

    With ``self.ignore_padding`` the loss/accuracy average over valid
    frames (per ``features_len``). Otherwise they average over every frame,
    padding included.

    Returns:
        (loss, details), with details 'loss', 'acc' and 'out_seq'.
    """
    # the features may be padded
    if features_len is None:
        assert targets_len is None
        assert features.shape[1] == targets.shape[1], (
            f"The lengths of the targets and the inputs should "
            f"be the same for a framewise prediction. "
            f"Currently: {targets.shape[1]} and {features.shape[1]} respectively."
        )
    else:
        assert (torch.all(features_len == targets_len)
                and (features.shape[1] >= targets.shape[1]))
    lens = features_len
    if lens is None:
        lens = torch.full((features.shape[0], ),
                          fill_value=features.shape[1],
                          device=targets.device)
    hidden = self(self.input)
    feat_aligned_len = features.shape[1]
    hidden_aligned_len = hidden.shape[1]
    assert feat_aligned_len >= lens.max(), (
        f"Incompatible shapes for features, hidden, targets: "
        f"{(features.shape, hidden.shape, targets.shape)}")
    targets = targets.long()
    rate_factor = feat_aligned_len // hidden_aligned_len
    assert (feat_aligned_len % hidden_aligned_len) == 0, (
        "The hidden (captured) representation should evenly divide the "
        "features length")
    # Upsample the hidden representation to the feature rate.
    hidden = hidden.repeat_interleave(rate_factor, dim=1)
    assert lens.max() <= hidden.shape[1], (
        f" Incompatible shapes for lens, hidden.shape[1]: "
        f"{(lens.max(), hidden.shape[1])}")
    hidden = hidden[:, :targets.shape[1]].contiguous()
    pred_labels = utils.safe_squeeze(hidden.argmax(dim=3), 2)
    accs = (pred_labels == targets).float()
    losses = F.cross_entropy(utils.safe_squeeze(hidden, 2).permute(0, 2, 1),
                             targets,
                             reduction="none")
    mask = utils.get_mask1d(lens.to(losses.device),
                            mask_length=losses.size(1))
    mask = mask / mask.sum()
    if not self.ignore_padding:
        # Uniform weights over every frame. A weight of 1 would turn the
        # loss into a sum and the accuracy into a count.
        mask[:] = 1.0 / mask.numel()
    acc = (accs * mask).sum()
    loss = (losses * mask).sum()
    if logger.is_currently_logging():
        logger.log_mpl_figure(
            "framewise_debug",
            self.plot(features, F.softmax(hidden.detach(), dim=-1)))
    details = {"loss": loss, "acc": acc, "out_seq": pred_labels.detach()}
    return loss, details
def sample(self, logits):
    """Sample from a unit-scale Normal centred on the (squeezed) logits."""
    mean = safe_squeeze(logits, -1)
    dist = Normal(mean, 1.0)
    return dist.sample()
def mean_field(self, logits):
    """Return sigmoid of the logits after ``safe_squeeze`` on the last axis."""
    return torch.sigmoid(safe_squeeze(logits, -1))
def sample(self, logits):
    """Sample from a unit-scale Laplace centred on the (squeezed) logits."""
    centre = safe_squeeze(logits, -1)
    dist = Laplace(centre, 1.0)
    return dist.sample()
def forward(ctx, log_probs, act_lens, graph_matrices, neg_inf=-np.inf):
    """Forward-backward log-sum over all paths through a graph.

    Computes the log-sum over all paths with a forward (alpha) pass. A
    backward (beta) pass then accumulates the gradient w.r.t. log_probs,
    which is stored in ``ctx.grads`` with shape (bs, T, 1, NUM_SYMBOLS).

    Args:
        ctx: autograd context; only ``ctx.grads`` is set here.
        log_probs: tensor squeezed at dim 2 and transposed to
            (T, bs, NUM_SYMBOLS) below. It is detached, so gradients come
            only from ``ctx.grads``.
        act_lens: per-utterance lengths, asserted to be sorted descending.
        graph_matrices: 8 matrices (states, ilabels, weights, terminal,
            states_out, ilabels_out, weights_out, unused), with batch size
            1 (expanded) or bs.
        neg_inf: fill value for unreachable states.

    Returns:
        log_cost: tensor of shape (bs,) with the forward log-sum.
    """
    logsumexp = torch.logsumexp
    log_probs = log_probs.detach()
    log_probs = safe_squeeze(log_probs, 2).transpose(0, 1).contiguous()
    T, bs, _ = log_probs.size()
    assert graph_matrices[0].size(0) in [1, bs]
    assert all(
        sm.size(0) == graph_matrices[0].size(0) for sm in graph_matrices)
    if graph_matrices[0].size(0) == 1:
        graph_matrices = [gm.expand(bs, -1, -1) for gm in graph_matrices]
    (states_mat, ilabels_mat, weights_mat, terminal_mat, states_mat_out,
     ilabels_mat_out, weights_mat_out, _) = graph_matrices
    terminal_mat = terminal_mat.squeeze(-1)
    _, n, _ = states_mat.size()

    # a helper to select the next indices for a transition
    def get_idx(m, i):
        _bs = m.size(0)
        return torch.gather(m, 1, i.view(_bs, -1)).view(i.size())

    # Start in state 0 with log-prob 0; every other state is unreachable.
    lalpha = torch.full((bs, n), neg_inf, device=log_probs.device)
    lalpha[:, 0] = 0
    lalpha0 = lalpha.clone()
    # lalphas[t] holds alpha *before* consuming frame t.
    lalphas = torch.full((T, bs, n), neg_inf, device=log_probs.device)
    # The utterances are sorted according to length descending.
    # Rather than masking, stop updates to alphas when an utterance ends.
    assert act_lens.tolist() == sorted(act_lens, reverse=True)
    last_iter_end = 0
    for bitem in range(bs, 0, -1):
        # Advance the first `bitem` utterances up to the length of the
        # shortest of them.
        iter_end = act_lens[bitem - 1]
        for t in range(last_iter_end, iter_end):
            lalphas[t] = lalpha
            token_probs = weights_mat[:bitem].clone()
            token_probs += get_idx(lalpha[:bitem], states_mat[:bitem])
            token_probs += get_idx(log_probs[t, :bitem], ilabels_mat[:bitem])
            # Written in place into lalpha's leading rows.
            logsumexp(token_probs, dim=-1, out=lalpha[:bitem])
        last_iter_end = iter_end
    log_cost = logsumexp(lalpha + terminal_mat, dim=-1)

    # Backward pass, starting from the terminal weights.
    lbeta = terminal_mat.clone()
    logprobs_grad = torch.zeros_like(log_probs)
    # NOTE(review): the backward pass starts at the padded length T, while
    # the forward pass stops at act_lens[0]. If T > act_lens[0], padding
    # frames enter lbeta and the gradient. Confirm that T == act_lens[0].
    last_iter_end = T
    for bitem in range(1, bs + 1):
        if bitem < bs:
            iter_end = act_lens[bitem]
        else:
            iter_end = 0
        for t in range(last_iter_end - 1, iter_end - 1, -1):
            token_probs = weights_mat_out[:bitem].clone()
            token_probs += get_idx(lbeta[:bitem], states_mat_out[:bitem])
            token_probs += get_idx(log_probs[t, :bitem],
                                   ilabels_mat_out[:bitem])
            logsumexp(token_probs, dim=-1, out=lbeta[:bitem])
            # Turn per-transition scores into posteriors:
            # exp(alpha_t + arc + beta - log_cost).
            # NOTE(review): with neg_inf=-inf and an unreachable utterance
            # (log_cost == -inf) this subtraction yields NaN.
            token_probs += (lalphas[t, :bitem] -
                            log_cost[:bitem].unsqueeze(-1)).unsqueeze(-1)
            token_probs.exp_()
            # Sum arc posteriors into the symbol each arc emits.
            logprobs_grad[t, :bitem].scatter_add_(
                1, ilabels_mat_out[:bitem].view(bitem, -1),
                token_probs.view(bitem, -1))
        last_iter_end = iter_end
    ctx.grads = logprobs_grad.transpose(0, 1).unsqueeze(2)

    # approximate the numerical error: the backward total, read off via the
    # initial alpha, should match the forward log_cost.
    log_cost0 = logsumexp(lalpha0 + lbeta, dim=1)
    if torch.abs(log_cost - log_cost0).max().item() > 1e-3:
        print('forward_backward num error: fwd losses %s bwd losses %s' %
              (log_cost, log_cost0))
    return log_cost
def sample(self, logits):
    """Deterministic "sample": the logits with ``safe_squeeze`` applied to the last axis."""
    squeezed = safe_squeeze(logits, -1)
    return squeezed
def mean_field(self, logits):
    """Mean-field value: the logits with ``safe_squeeze`` applied to the last axis."""
    return safe_squeeze(logits, -1)
def _get_normal(self, logits):
    """Build a Normal whose loc and log-scale are the two halves of the last axis."""
    loc_half, log_scale_half = logits.chunk(2, dim=-1)
    mean = safe_squeeze(loc_half, -1)
    std = safe_squeeze(log_scale_half, -1).exp()
    return Normal(mean, std)
def loss(self, logits, targets):
    """Compute the per-element squared error between the squeezed logits and targets.

    Raises:
        AssertionError: if the shapes differ after squeezing. The message
            shows both shapes.
    """
    logits = safe_squeeze(logits, -1)
    assert logits.size() == targets.size(
    ), f"{logits.size()} != {targets.size()}"
    return F.mse_loss(logits, targets, reduction='none')
def evaluate(self, batches):
    """Evaluate the model on ``batches`` and return averaged scores.

    Per batch, runs ``minibatch_loss_and_tokens`` and ``probes_loss`` and
    accumulates example-weighted losses and scalar stats. When tokens are
    produced, it also collects estimated alignments (and ground-truth ones
    when the batch has 'alignment'). When both exist, it computes F1,
    mapping, clustering and perplexity metrics on the unpadded, concatenated
    alignments.

    Returns:
        dict mapping score names to values averaged over examples.
    """
    tot_examples = 0.
    tot_loss = 0.
    tot_detached_probesloss = 0.
    tot_backprop_probesloss = 0.
    tot_errs = 0.
    alis_es = []
    alis_gt = []
    alis_lens = []
    total_stats = {}
    # Kept (deep-copied) for plotting after the loop.
    first_batch = None
    for batch in batches:
        if first_batch is None:
            first_batch = copy.deepcopy(batch)
        num_examples = batch['features'].shape[0]
        loss, stats, tokens = self.minibatch_loss_and_tokens(batch)
        # Run the probes
        detached_loss, backprop_loss, probes_details = self.probes_loss(
            batch)
        stats.update(probes_details)
        if tokens is not None:
            # Tokens should be in layout B x W x 1 x 1
            tokens = utils.safe_squeeze(tokens, dim=3)
            tokens = utils.safe_squeeze(tokens, dim=2)
            feat_len = batch['features_len']
            alis_lens.append(feat_len)
            # the tokens should match the rate of the alignment
            ali_es = self.align_tokens_to_features(batch, tokens)
            assert (ali_es.shape[0] == batch['features'].shape[0])
            assert (ali_es.shape[1] == batch['features'].shape[1])
            alis_es.append(ali_es[:, :])
            if 'alignment' in batch:
                ali_gt = batch['alignment']
                ali_len = batch['alignment_len']
                assert ((ali_len == feat_len).all())
                alis_gt.append(ali_gt)
        # Weight every per-batch quantity by the number of examples.
        tot_examples += num_examples
        tot_loss += loss * num_examples
        tot_errs += stats.get('err', np.nan) * num_examples
        tot_detached_probesloss += detached_loss * num_examples
        tot_backprop_probesloss += backprop_loss * num_examples
        for k, v in stats.items():
            if k == 'segmental_values':
                # Plotted (when logging is active) rather than averaged.
                if logger.is_currently_logging():
                    import matplotlib.pyplot as plt
                    f = plt.figure(dpi=300)
                    plt.plot(v.data.cpu().numpy(), 'r.-')
                    f.set_tight_layout(True)
                    logger.log_mpl_figure(f'segmentation_values', f)
            elif utils.is_scalar(v):
                if k not in total_stats:
                    total_stats[k] = v * num_examples
                else:
                    total_stats[k] += v * num_examples
    # loss is special, as we use it e.g. for learn rate control
    # add all signals that we train against, but remove the passive ones
    all_scores = {
        'loss': (tot_loss + tot_backprop_probesloss) / tot_examples,
        'probes_backprop_loss': tot_backprop_probesloss / tot_examples,
        'probes_detached_loss': tot_detached_probesloss / tot_examples,
        'err': tot_errs / tot_examples,
        'probes_loss':
        (tot_detached_probesloss + tot_backprop_probesloss) / tot_examples
    }
    for k, v in total_stats.items():
        all_scores[k] = v / tot_examples
    if (len(alis_es) > 0) and (len(alis_gt) > 0):
        # If we have gathered any alignments
        f1_scores = dict(precision=[], recall=[], f1=[])
        for batch in zip(alis_gt, alis_es, alis_lens):
            batch = [t.detach().cpu().numpy() for t in batch]
            for k, v in scoring.compute_f1_scores(*batch, delta=1).items():
                f1_scores[k].extend(v)
        for k in ('f1', 'precision', 'recall'):
            print(f"f1/{k}: {np.mean(f1_scores[k])}")
            logger.log_scalar(f'f1/{k}', np.mean(f1_scores[k]))
        alis_es = self._unpad_and_concat(alis_es, alis_lens)
        alis_gt = self._unpad_and_concat(
            alis_gt, alis_lens) if len(alis_gt) else None
        # Each entry is (score-name prefix, filter applied to alignments).
        scores_to_compute = [('', lambda x: x)]
        if alis_gt is not None and self.pad_symbol is not None:
            not_pad = (alis_gt != self.pad_symbol)
            scores_to_compute.append(('nonpad_', lambda x: x[not_pad]))
        if alis_gt is not None and alis_es.min() < 0:
            # Estimated token -1 is excluded by the 'validtokens_' variant.
            not_pad2 = (alis_es != -1)
            scores_to_compute.append(
                ('validtokens_', lambda x: x[not_pad2]))
        for prefix, ali_filter in scores_to_compute:
            es = ali_filter(alis_es)
            if alis_gt is not None:
                gt = ali_filter(alis_gt)
                mapping_scores, mapping = self._mapping_metrics(
                    gt, es, prefix=prefix)
                all_scores.update(mapping_scores)
                # Run the segmentation plotting with mapping
                if logger.is_currently_logging():
                    _, _, tokens = self.minibatch_loss_and_tokens(
                        first_batch)
                    self.plot_input_and_alignments(
                        first_batch['features'],
                        alignment_es=tokens,
                        alignment_gt=first_batch['alignment'],
                        mapping=mapping,
                        imshow_kwargs=dict(cmap='Greys'),
                        log_suffix=f'{prefix[:-1]}')
                clustering_scores = self._clustering_metrics(
                    gt, es, prefix=prefix)
                all_scores.update(clustering_scores)
            perplexity_scores = self._perplexity_metrics(es, prefix=prefix)
            all_scores.update(perplexity_scores)
    return all_scores
def align_tokens_to_features(self, batch, tokens):
    """Return ``tokens`` with ``safe_squeeze`` applied to axis 1.

    No downsampling happens in this model, so tokens are already at the
    feature rate. ``batch`` is unused.
    """
    aligned = utils.safe_squeeze(tokens, 1)
    return aligned
def loss(self, logits, targets):
    """Compute the per-element binary cross-entropy (on logits) against targets.

    Raises:
        AssertionError: if the shapes differ after squeezing. The message
            shows both shapes.
    """
    logits = safe_squeeze(logits, -1)
    assert logits.size() == targets.size(
    ), f"{logits.size()} != {targets.size()}"
    return F.binary_cross_entropy_with_logits(logits,
                                              targets,
                                              reduction='none')
def sample(self, logits):
    """Draw Bernoulli samples (as float tensors) with probability sigmoid(logits)."""
    p = torch.sigmoid(safe_squeeze(logits, -1))
    uniform = torch.rand_like(p)
    hits = uniform < p
    return hits.float()
def loss(self, features, targets, features_len=None, targets_len=None):
    """Compute the framewise cross-entropy for every prediction head returned by ``self``.

    Each head's output is repeated along time by ``rate_factor`` (from
    ``self.calculateInputLengths``), truncated to the target length, and
    scored with per-frame cross-entropy.

    With ``self.ignore_padding`` the loss/accuracy average over valid
    frames (per ``features_len``). Otherwise they average over every frame,
    padding included.

    Returns:
        (total_loss, details). total_loss sums the heads' losses; details
        holds 'loss_<name>', 'acc_<name>' and 'out_seq_<name>' per head.
    """
    # the features may be padded
    if features_len is None:
        assert targets_len is None
        assert features.shape[1] == targets.shape[1], (
            f"The lengths of the targets and the inputs should "
            f"be the same for a framewise prediction. "
            f"Currently: {targets.shape[1]} and {features.shape[1]} respectively."
        )
    else:
        assert (torch.all(features_len == targets_len)
                and (features.shape[1] >= targets.shape[1]))
    features_len = self.calculateFeatureLens(features, features_len)
    inputs_len, rate_factor = self.calculateInputLengths(
        self.input, features, features_len)
    feat_aligned_len = features.shape[1]
    # No prediction exists yet at this point, so the message must only use
    # names that are already bound (it previously referenced `pred`).
    assert feat_aligned_len >= features_len.max(), (
        f"Incompatible shapes for features, features_len, targets: "
        f"{(features.shape, features_len.max(), targets.shape)}")
    targets = targets.long()
    details = {}
    total_loss = 0
    for pred_name, pred in self(self.input, inputs_len).items():
        hidden_aligned_len = pred.shape[1]
        assert (feat_aligned_len % hidden_aligned_len) == 0, (
            "The hidden (captured) representation should evenly divide the "
            "features length")
        # Upsample the prediction to the feature rate.
        pred = pred.repeat_interleave(rate_factor, dim=1)
        assert features_len.max() <= pred.shape[1], (
            f" Incompatible shapes for features_len, pred.shape[1]: "
            f"{(features_len.max(), pred.shape[1])}")
        pred = pred[:, :targets.shape[1]].contiguous()
        pred_labels = utils.safe_squeeze(pred.argmax(dim=3), 2)
        accs = (pred_labels == targets).float()
        losses = F.cross_entropy(utils.safe_squeeze(pred, 2).permute(0, 2, 1),
                                 targets,
                                 reduction="none")
        mask = utils.get_mask1d(features_len.to(losses.device),
                                mask_length=losses.size(1))
        mask = mask / mask.sum()
        if not self.ignore_padding:
            # Uniform weights over every frame. A weight of 1 would turn
            # the loss into a sum and the accuracy into a count.
            mask[:] = 1.0 / mask.numel()
        acc = (accs * mask).sum()
        loss = (losses * mask).sum()
        if logger.is_currently_logging():
            logger.log_mpl_figure(
                "framewise_debug_" + pred_name,
                self.plot(features, F.softmax(pred.detach(), dim=-1)))
        total_loss = total_loss + loss
        details.update({
            "loss_" + pred_name: loss,
            "acc_" + pred_name: acc,
            "out_seq_" + pred_name: pred_labels.detach()
        })
    return total_loss, details
def path_reduction(log_probs,
                   act_lens,
                   graph_matrices,
                   red_kind='logsumexp',
                   neg_inf=-1e20):
    """Compute a reduction over all paths through a graph.

    Args:
        log_probs: bs x T x 1 x NUM_SYMBOLS tensor of log-probs of emitting
            symbols.
        act_lens: length-bs tensor of utterance lengths, which must be
            sorted in descending order (asserted).
        graph_matrices: tuple of matrices with batch size 1 (expanded to bs)
            or bs. Four matrices (states bs x N x K, ilabels, weights,
            terminal bs x N x 1) drive the forward loop below. An 8-matrix
            tuple, as used by forward-backward, is also accepted; only its
            first four are used here.
        red_kind: how to aggregate paths.
            'logsumexp' / 'logsumexp_autodiff' sum the probabilities of all
            paths (logsumexp of log-probs), differentiated via autodiff.
            'viterbi' / 'viterbi_autodiff' take the maximally probable path.
            'logsumexp_fwb', or 'logsumexp' with 8 matrices, delegates to
            ``path_logsumexp`` (explicit forward-backward).
        neg_inf: value used for improbable events (-1e10 or -1e20 are OK).
            Also forwarded to ``path_logsumexp``.

    Returns:
        Tensor of shape bs: the sum of weights on the maximally probable
        path, or the log-sum over all paths.
    """
    if (red_kind == 'logsumexp_fwb'
            or (red_kind == 'logsumexp' and len(graph_matrices) == 8)):
        # Honour the caller's neg_inf instead of a hard-coded -1e20.
        return path_logsumexp(log_probs, act_lens, graph_matrices, neg_inf)
    # (bs x T x 1 x S) -> (T x bs x S)
    log_probs = safe_squeeze(log_probs, 2).transpose(0, 1).contiguous()
    _, bs, _ = log_probs.size()
    assert graph_matrices[0].size(0) in [1, bs]
    assert all(
        sm.size(0) == graph_matrices[0].size(0) for sm in graph_matrices)
    # This can happen if we get the matrices for full forward-backward
    # and here we only need the ones for forward
    if len(graph_matrices) == 8:
        graph_matrices = graph_matrices[:4]
    if graph_matrices[0].size(0) == 1:
        graph_matrices = [gm.expand(bs, -1, -1) for gm in graph_matrices]
    states_mat, ilabels_mat, weights_mat, terminal_mat = graph_matrices
    _, n, k = states_mat.size()

    if red_kind in ['logsumexp', 'logsumexp_autodiff']:
        reduction = torch.logsumexp
    else:
        assert red_kind in ['viterbi', 'viterbi_autodiff']

        def reduction(t, dim):
            return torch.max(t, dim)[0]

    # a helper to select the next indices for a transition
    def get_idx(m, i):
        _bs = m.size(0)
        return torch.gather(m, 1, i.view(_bs, n * k)).view((_bs, n, k))

    # Start in state 0 with log-prob 0; every other state is unreachable.
    lalpha = torch.full((bs, n), neg_inf, device=log_probs.device)
    lalpha[:, 0] = 0
    # The utterances are sorted according to length descending.
    # Rather than masking, stop updates to alphas when an utterance ends.
    assert act_lens.tolist() == sorted(act_lens, reverse=True)
    last_iter_end = 0
    for bitem in range(bs, 0, -1):
        iter_end = act_lens[bitem - 1]
        for t in range(last_iter_end, iter_end):
            token_probs = (get_idx(lalpha[:bitem], states_mat[:bitem]) +
                           weights_mat[:bitem] +
                           get_idx(log_probs[t, :bitem], ilabels_mat[:bitem]))
            la = reduction(token_probs, dim=-1)
            # Clone before the partial in-place update, presumably so that
            # tensors saved for autograd are not overwritten.
            lalpha = lalpha.clone()
            lalpha[:bitem] = la
        last_iter_end = iter_end
    path_sum = reduction(lalpha + terminal_mat.squeeze(2), dim=-1)
    return path_sum
def loss(self, logits, targets):
    """Compute the per-element absolute (L1) error between the squeezed logits and targets.

    Raises:
        AssertionError: if the shapes differ after squeezing.
    """
    preds = safe_squeeze(logits, -1)
    assert preds.size() == targets.size(), f"{preds.size()} != {targets.size()}"
    return F.l1_loss(preds, targets, reduction='none')