示例#1
0
    def forward(self, input, conds):
        """Decode ``conds`` and centre-crop the result to the input's size.

        ``conds[0]`` is flattened to three dimensions, run through
        ``self.encode`` and reshaped to 64 channels before being handed to
        ``self.decode``.  An optional ``conds[1]`` is passed to the decoder
        as ``cond`` after dimension 1 has been squeezed twice.

        The decoder output has its second dimension split into groups of
        ``self.quantizer.num_levels`` (moved to the last axis) and is then
        cropped symmetrically along dimension 2 to ``input.shape[2]``.
        """
        primary = conds[0]
        flat = primary.view(primary.size(0), primary.size(1), -1)
        z = self.encode(flat.permute(0, 2, 1))
        z = z.view(z.size(0), 64, -1, z.size(2)).permute(0, 1, 3, 2)

        cond = None
        if len(conds) > 1:
            # Squeeze out the width and height, as this conditioning is
            # assumed to be global.
            cond = safe_squeeze(safe_squeeze(conds[1], 1), 1)

        x2 = self.decode(z, cond=cond)
        x2 = x2.view(x2.size(0), -1, self.quantizer.num_levels,
                     x2.size(2), x2.size(3))
        x2 = x2.permute(0, 3, 4, 1, 2)

        # Truncate x2 to input size
        _, _, h_in, *_ = input.shape
        _, _, h_x2, *_ = x2.shape

        assert h_x2 >= h_in, f"The reconstruction ({x2.shape}) must be as large as the input {input.shape}."

        # Drop the surplus evenly from both ends (extra element at the end).
        start = (h_x2 - h_in) // 2
        stop = start + h_in
        return x2[:, :, start:stop, :]
    def minibatch_loss_and_tokens(self, batch):
        """Compute the KL-annealed negative ELBO, plus an optional CPC term.

        Only the first channel of ``batch["features"]`` is used.  When the
        batch carries no ``"features_len"`` entry, every sequence is treated
        as spanning the full time dimension.

        Returns:
            ``(loss, details, None)``.  ``details`` holds ``neg_llk``,
            ``elbo`` and ``kl`` divided by the sum of the input mask, the
            CPC loss (``None`` when ``self.cpc`` is unset) and the annealing
            factor.
        """
        kl_weight = self.compute_kl_mult()
        # The counter advances on every call but is pinned to zero whenever
        # the module is not in training mode.
        self.batch_id += 1
        if not self.training:
            self.batch_id = 0

        x = batch["features"][..., :1]
        batch_size, num_frames, _, _ = x.size()
        if "features_len" in batch.keys():
            x_lens = batch["features_len"]
        else:
            x_lens = torch.tensor([num_frames] * batch_size,
                                  dtype=torch.long,
                                  device=x.device)
        x = self.input_layer(x)
        x_mask = utils.get_mini_batch_mask(x, x_lens).to(x.device.type)
        pre_bottleneck_h, z, latent_loss, info, z_lens = self.encode(x, x_lens)

        # The GRU is always run; its output only feeds the CPC loss.
        gru_out, gru_out_lens = self.CPCgru(pre_bottleneck_h, z_lens)
        cpc_loss = None
        if self.cpc is not None:
            cpc_loss = self.cpc.cpc_loss(
                gru_input_feats=pre_bottleneck_h,
                gru_output_feats=gru_out,
                feats_len=gru_out_lens,
            )[0]

        mixed_z = self.mix_latents(z)
        # expand tensors (the first two results are discarded)
        _ = self.pre_bottleneck(self.upsample(x, pre_bottleneck_h))
        _ = self.post_bottleneck(self.upsample(x, z))
        mixed_z = self.post_latent_mixer(self.upsample(x, mixed_z))

        likelihood = self.likelihood(utils.safe_squeeze(mixed_z, dim=-2))
        target = utils.safe_squeeze(x, dim=-1)
        llk = likelihood.log_prob(target) * x_mask.unsqueeze(-1)
        llk = llk.sum((1, 2))

        elbo = llk - latent_loss
        annealed_elbo = llk - kl_weight * latent_loss
        loss = annealed_elbo.mean().mul_(-1)
        if cpc_loss is not None:
            loss += cpc_loss

        total_time_steps_batch = float(torch.sum(x_mask))
        if not self.training:
            # Outside of training the loss is reported per time step.
            loss /= total_time_steps_batch

        details = {
            "neg_llk": -llk.sum() / total_time_steps_batch,
            "elbo": elbo.sum() / total_time_steps_batch,
            "kl": latent_loss.sum() / total_time_steps_batch,
            "cpc_loss": cpc_loss,
            "annealing_factor": kl_weight,
        }
        return loss, details, None
    def minibatch_loss_and_tokens(self, batch):
        """Compute the KL-annealed negative ELBO for reconstructing a batch field.

        The model always encodes the first channel of ``batch["features"]``.
        The reconstruction target is either the input-layer output of those
        features or the first channel of the batch field named by
        ``self.reconstruction_field``.

        Returns:
            ``(loss, details, None)``.  ``details`` holds ``neg_llk``,
            ``elbo`` and ``kl`` divided by the sum of the target mask, plus
            the annealing factor.
        """
        kl_mult = self.compute_kl_mult()
        # The counter advances on every call but is reset to zero whenever
        # the module is not in training mode.
        self.batch_id += 1
        if not self.training:
            self.batch_id = 0

        x = batch["features"][..., :1]
        b, t, f, c = x.size()
        # Without explicit lengths, every sequence is assumed to be full length.
        if "features_len" in list(batch.keys()):
            x_lens = batch["features_len"]
        else:
            x_lens = torch.tensor([t] * b, dtype=torch.long, device=x.device)
        x = self.input_layer(x)

        # Select the reconstruction target y and its lengths.
        k = self.reconstruction_field
        if k == 'features':
            # Here the target is x *after* the input layer.
            y = x
            y_lens = x_lens
        else:
            y = batch[k][..., :1]
            b, _, _, _ = y.size()
            if f"{k}_len" in list(batch.keys()):
                y_lens = batch[f"{k}_len"]
            else:
                # NOTE(review): this fallback reuses `t` from the features
                # tensor rather than the time dimension of `y` — confirm the
                # two fields always share a length.
                y_lens = torch.tensor([t] * b,
                                      dtype=torch.long,
                                      device=y.device)

        y_mask = utils.get_mini_batch_mask(y, y_lens).to(y.device.type)

        pre_bottleneck_h, z, latent_loss, info = self.encode(x, x_lens)
        mixed_z = self.mix_latents(z)
        # expand tensors (the first two results are discarded)
        _ = self.pre_bottleneck(self.upsample(x, pre_bottleneck_h))
        _ = self.post_bottleneck(self.upsample(x, z))
        mixed_z = self.post_latent_mixer(self.upsample(x, mixed_z))
        llk_fn = self.likelihood(utils.safe_squeeze(mixed_z, dim=-2))
        # Masked log-likelihood of the target, summed over dimensions 1 and 2.
        llk = llk_fn.log_prob(utils.safe_squeeze(
            y, dim=-1)) * y_mask.unsqueeze(-1)
        llk = llk.sum((1, 2))
        elbo = llk - latent_loss
        annealed_elbo = llk - kl_mult * latent_loss
        # The optimised loss is the negated batch mean of the annealed ELBO.
        loss = annealed_elbo.mean().mul_(-1)
        total_time_steps_batch = float(torch.sum(y_mask))
        details = {
            "neg_llk": -llk.sum() / total_time_steps_batch,
            "elbo": elbo.sum() / total_time_steps_batch,
            "kl": latent_loss.sum() / total_time_steps_batch,
            "annealing_factor": kl_mult
        }
        return loss, details, None
 def encode(self, x, x_lens):
     """Run the encoder and the bottleneck on ``x``.

     Returns the encoder output, the bottleneck latents, the latent loss,
     the bottleneck's ``info`` and the per-sequence sum of the encoder mask.
     """
     hidden, hidden_mask = self.encoder(x, x_lens)
     latent_lens = hidden_mask.sum(1)
     squeezed_hidden = utils.safe_squeeze(hidden, dim=-2)
     z, latent_loss, info = self.bottleneck(squeezed_hidden,
                                            mask=hidden_mask)
     return hidden, z, latent_loss, info, latent_lens
示例#5
0
    def forward(self, input, conds):
        """Decode from the conditioning only; ``input`` is ignored.

        ``conds[0]`` is flattened to three dimensions, passed through
        ``self.conv`` and reshaped to 64 channels for ``self.decode``.  An
        optional ``conds[1]`` is given to the decoder as ``cond`` after
        dimension 1 has been squeezed twice.  The decoder output is returned
        with axes 1 and 3 swapped and a singleton axis inserted at position 3.
        """
        del input  # unused
        primary = conds[0]
        flat = primary.view(primary.size(0), primary.size(1), -1)
        z = self.conv(flat.permute(0, 2, 1))
        # The trailing permute is an identity ordering.
        z = z.view(z.size(0), 64, -1, z.size(2)).permute(0, 1, 2, 3)

        if len(conds) > 1:
            # Squeeze out the width and height, as this conditioning is
            # assumed to be global.
            cond = safe_squeeze(safe_squeeze(conds[1], 1), 1)
        else:
            cond = None

        decoded = self.decode(z, cond=cond)
        return decoded.permute(0, 3, 2, 1).unsqueeze(3)
示例#6
0
    def minibatch_loss(self, batch):
        """Return the CTC loss and a details dict for one batch.

        The CTC value is divided by the number of entries in the output
        lengths tensor.  A character error rate is computed from a greedy
        CTC decode.  When ``self.adversarial`` is set, its friend loss is
        added to the returned loss; the adversary loss is only reported in
        the details.
        """
        log_probs, log_prob_lens = self(batch['features'],
                                        batch['features_len'])
        targets = batch['targets'].int()
        targets_len = batch['targets_len']

        # log_probs: (bs x t x 1 x nc) -> (t x bs x nc)
        log_probs = utils.safe_squeeze(log_probs, 2).permute(1, 0, 2)
        ctc_total = self.ctc(log_probs, targets, log_prob_lens, targets_len)
        loss = ctc_total / log_prob_lens.size(0)

        hypotheses = utils.greedy_ctc_decode(log_probs, log_prob_lens)
        targets_np = targets.to('cpu').numpy()
        # Trim every reference to its true length before scoring.
        references = [seq[:n] for seq, n in zip(targets_np, targets_len)]
        cer = utils.error_rate(hypotheses, references)

        details = {'cer': torch.tensor(cer), 'main_loss': loss}
        if self.adversarial is None:
            return loss, details

        friend_loss, adv_loss, adv_details = self.adversarial.loss(
            batch['spkid'])
        details['adv_friend_loss'] = friend_loss
        details['adv_adv_loss'] = adv_loss
        details['adv_acc'] = adv_details['acc']
        return loss + friend_loss, details
示例#7
0
    def decode(self, batch):
        """Greedy-CTC decode a batch and pass the result to ``self.dataset.decode``.

        Targets are optional: when the batch has no ``'targets'`` /
        ``'targets_len'`` entries (or the targets entry has no ``.int()``),
        ``None`` is passed for both and only recognition is performed.

        Returns nothing; all output goes through ``self.dataset.decode``.
        """
        # Call forward() on this model
        log_probs, log_prob_lens = self(batch['features'],
                                        batch['features_len'])
        # (bs x t x 1 x nc) --> (t x bs x nc)
        log_probs = utils.safe_squeeze(log_probs, 2).permute(1, 0, 2)
        # 'decodes'           is a Python list of the sequence of best path tokens, per sample, no blanks
        # 'decodesWithBlanks' is the same but keeps the blanks for path generation and processing
        # 'log_probs' shape is [ maxLengthDecodeSequences, batchSize, num columns]
        decodes, decodesWithBlanks = utils.greedy_ctc_decode(log_probs,
                                                             log_prob_lens,
                                                             return_raw=True)

        # The checks below rely on the first sample being the longest one.
        szLongestProbSequence = log_prob_lens[0].item()

        assert len(decodesWithBlanks[0]) == szLongestProbSequence
        assert len(decodesWithBlanks[0]) == log_probs.shape[0]
        assert szLongestProbSequence == log_probs.shape[0]

        try:
            # Some datasets are fully transcribed. Use if available.
            targets = batch['targets'].int()
            targets_len = batch['targets_len']
        except (KeyError, AttributeError):
            # We have no targets and therefore run a forward() only recognition.
            # Only "missing/empty targets" is handled here, so unrelated
            # failures (and KeyboardInterrupt) are no longer swallowed.
            targets = None
            targets_len = None

        # Pretty print of paths, strings and meanings
        self.dataset.decode(self.aligner, decodesWithBlanks, decodes,
                            log_probs, log_prob_lens, targets, targets_len,
                            batch, self.verbose)
示例#8
0
 def loss(self, logits, targets):
     """Element-wise L1 distance between ``self.vgg`` outputs.

     Both tensors are permuted with ``(0, 3, 2, 1)``, expanded to three
     channels and rescaled with ``* 2 - 1`` before being fed to
     ``self.vgg``.  No reduction is applied to the result.
     """
     logits = safe_squeeze(logits, -1).permute(0, 3, 2, 1)
     batch, _, height, width = logits.shape
     rgb_shape = (batch, 3, height, width)
     logits = logits.expand(*rgb_shape)
     targets = targets.permute(0, 3, 2, 1).expand(*rgb_shape)
     predicted_feats = self.vgg(logits * 2 - 1)
     reference_feats = self.vgg(targets * 2 - 1)
     return F.l1_loss(predicted_feats, reference_feats, reduction='none')
示例#9
0
 def forward(self, x):
     """Map ``x`` through ``self.conv`` and ``self.conv_stack``.

     The convolution output is viewed with ``self.hid_channels`` channels
     and ``self.image_height`` rows.  The stack output has its channel
     axis split into groups of ``self.quantizer.num_levels``, is permuted
     and has its last axis squeezed.
     """
     # x: (bsz x dim x t)
     hidden = self.conv(x.permute(0, 2, 1))
     hidden = hidden.view(hidden.size(0), self.hid_channels,
                          self.image_height, hidden.size(-1))
     out = self.conv_stack(hidden)
     out = out.view(out.size(0), -1, self.quantizer.num_levels,
                    out.size(2), out.size(3))
     # (bs x 1 x 1 x h x t) -> (bs x t x 1 x h x 1)
     # XXX This should squeeze out 1 and leave channels as '3', but not sure
     out = out.permute(0, 4, 3, 1, 2)
     return utils.safe_squeeze(out, -1)
示例#10
0
    def forward(self, x, conds=()):
        """
        Run the residual/skip stack over a time-shifted copy of the input.

        x: BS x Dim x T
        conds: list of BS x DimC x T/k

        The input is shifted by ``nframes`` steps along the time axis (once
        for the 'past' context and, if ``self.bidirectional``, once more for
        the 'future' context). Skip outputs of all layers and contexts are
        summed and passed through ``self.skip_to_out``.

        Returns the output of the last ``skip_to_out`` layer.
        """
        x_skip = 0

        # Optional corrupted copy of x: each element is kept with
        # probability 1 - ahead_corruption and zeroed otherwise.
        if self.ahead_corruption is not None:
            ber = torch.distributions.bernoulli.Bernoulli(
                torch.tensor([1.0 - self.ahead_corruption], device=x.device))
            mask = utils.safe_squeeze(ber.sample(sample_shape=x.size()), -1)
            x_corrupt = x * mask

        # Choose the shift: either sampled (0 with probability
        # 1 - ahead_fraction, otherwise uniform over 1..ahead_frames)
        # or fixed to ahead_frames.
        if self.ahead_fraction is not None:
            probs = (np.ones((self.ahead_frames + 1, ), dtype=np.float32) *
                     self.ahead_fraction / self.ahead_frames)
            probs[0] = 1.0 - self.ahead_fraction
            nframes = np.random.choice(self.ahead_frames + 1, p=probs)
        else:
            nframes = self.ahead_frames

        contexts = ('past', 'future') if self.bidirectional else ('past', )
        for ctx in contexts:
            if nframes == 0:
                x_shift = x
            elif ctx == 'past':  # Apply padding on the time axis (dim=2)
                # Pad at the start and drop the tail: position t sees t - nframes.
                x_shift = F.pad(x, (nframes, 0))[:, :, :-nframes]
            elif ctx == 'future':
                # Pad at the end and drop the head: position t sees t + nframes.
                x_shift = F.pad(x, (0, nframes))[:, :, nframes:]

            if self.ahead_corruption is not None:  # Stack on the dim axis
                x_shift = torch.cat([x_shift, x_corrupt], dim=1)

            # The future context is processed time-reversed.
            if ctx == 'future':
                x_shift = torch.flip(x_shift, dims=[2])

            x_res = self.x_to_res(x_shift)
            x_skip += self.res_to_skip(x_res)
            for res_to_hid, hid_to_skip, hid_to_res in zip(
                    self.res_to_hid, self.hid_to_skip, self.hid_to_res):
                x_hid = res_to_hid(x_res, conds)
                # The final positional argument enables in-place dropout.
                x_hid = F.dropout(x_hid, self.dropout, self.training, True)
                x_skip += hid_to_skip(x_hid)
                if hid_to_res is None:
                    x_res = None  # We don't use the last residual output
                else:
                    x_res = x_res + hid_to_res(x_hid)

        # Output head: ReLU -> (in-place) dropout -> layer, for every layer.
        for skip_to_out in self.skip_to_out:
            x_skip = torch.relu(x_skip)
            x_skip = F.dropout(x_skip, self.dropout, self.training, True)
            x_skip = skip_to_out(x_skip)
        return x_skip
示例#11
0
 def forward(self, x, conds=()):
     """
     Apply ``self.wave_net`` with the height axis used as channels.

     x: BS x T x H x 1
     conds: list of BS x DimC x T/k

     The result is reshaped so that its last axis has
     ``self.quantizer.num_levels`` entries, preceded by a singleton axis.
     """
     # make the height the channel
     channels_first = safe_squeeze(x, 3).transpose(1, 2)
     # run the net, then move the channel back to height
     logits = self.wave_net.forward(channels_first, conds).transpose(1, 2)
     # add the channel dim, the logit
     out_shape = [
         logits.size(0),
         logits.size(1), -1, 1, self.quantizer.num_levels
     ]
     return logits.reshape(out_shape)
示例#12
0
    def forward(self, x, conds):
        """Add every projected conditioning signal to ``x``.

        x is BS x C x T x H!!!
        each c in conds is BS x T' x 1 x C'

        Each conditioning has axis 2 squeezed, is run through its matching
        layer from ``self.cond_convs`` and repeated ``T // T'`` times along
        axis 2 before being added.
        """
        assert len(conds) == len(self.cond_convs)
        _, _, t = x.shape[:3]

        for cconv, cond in zip(self.cond_convs, conds):
            _, c_t, _, _ = cond.size()
            projected = cconv(safe_squeeze(cond, 2).permute(0, 2, 1))
            # expand cond to length of x
            projected = projected.repeat_interleave(t // c_t, 2)
            if x.dim() == 4:
                # Trailing singleton axis lets the sum broadcast over H.
                projected = projected.unsqueeze(3)
            x = x + projected
        return x
示例#13
0
    def minibatch_loss(self, batch):
        """Return the CTC loss and the character error rate for one batch.

        The CTC value is divided by the number of entries in the output
        lengths tensor.  The greedy decode (with and without blanks) is also
        forwarded to ``self.dataset.decode``.
        """
        # Call forward() on this model
        log_probs, log_prob_lens = self(batch['features'],
                                        batch['features_len'])
        targets = batch['targets'].int()
        targets_len = batch['targets_len']
        # log_probs: (bs x t x 1 x nc) -> (t x bs x nc)
        log_probs = utils.safe_squeeze(log_probs, 2).permute(1, 0, 2)
        ctc_total = self.ctc(log_probs, targets, log_prob_lens, targets_len)
        loss = ctc_total / log_prob_lens.size(0)

        # 'decodes' holds the best-path tokens per sample without blanks,
        # 'raw_paths' is the same but keeps the blanks.
        decodes, raw_paths = utils.greedy_ctc_decode(log_probs,
                                                     log_prob_lens,
                                                     return_raw=True)

        # The first sample must span the whole time axis of log_probs.
        longest = log_prob_lens[0].item()
        first_path_len = len(raw_paths[0])
        num_frames = log_probs.shape[0]
        assert first_path_len == longest
        assert first_path_len == num_frames
        assert longest == num_frames

        # Pretty print of paths, strings and meanings
        # Also, write path to output file if requested.
        self.dataset.decode(self.aligner, raw_paths, decodes, log_probs,
                            log_prob_lens, targets, targets_len, batch,
                            self.verbose)

        # Calculate Levenshtein character (or label) error-rate, on clean strings
        references = [
            seq[:n] for seq, n in zip(targets.to('cpu'), targets_len)
        ]
        cer = utils.error_rate(decodes, references)

        return loss, {'cer': torch.tensor(cer)}
示例#14
0
 def retrieve_saved_input(self):
     """Return the saved 'indices' with the last two axes squeezed.

     The saved input is cleared (``self.input = None``) afterwards.
     """
     # Pick 'indices' and squeeze to (bsz x L)
     squeezed = utils.safe_squeeze(
         utils.safe_squeeze(self.input['indices'], -1), -1)
     self.input = None
     return squeezed
示例#15
0
    def loss(self, features, targets, features_len=None, targets_len=None):
        """Framewise cross-entropy and accuracy of ``self(self.input)``.

        ``features`` is used only for its shape and for the debug plot; the
        prediction is computed from ``self.input``.

        Args:
            features: tensor whose dimension 1 is the frame axis.
            targets: per-frame labels (cast to long).
            features_len: optional per-sequence lengths; must equal
                ``targets_len`` when given.
            targets_len: optional per-sequence target lengths.

        Returns:
            ``(loss, details)`` where ``details`` holds the loss, the
            accuracy and the detached predicted labels.
        """
        # the features may be padded
        if features_len is None:
            assert targets_len is None
            assert features.shape[1] == targets.shape[1], (
                f"The lengths of the targets and the inputs should "
                f"be the same for a framewise prediction. "
                f"Currently: {targets.shape[1]} and {features.shape[1]} respectively."
            )
        else:
            assert (torch.all(features_len == targets_len)
                    and (features.shape[1] >= targets.shape[1]))
        lens = features_len

        # Without explicit lengths every sequence counts as full length.
        if lens is None:
            lens = torch.full((features.shape[0], ),
                              fill_value=features.shape[1],
                              device=targets.device)

        hidden = self(self.input)
        feat_aligned_len = features.shape[1]
        hidden_aligned_len = hidden.shape[1]

        assert feat_aligned_len >= lens.max(), (
            f"Incompatible shapes for features, hidden, targets: "
            f"{(features.shape, hidden.shape, targets.shape)}")
        targets = targets.long()

        # Upsample the prediction along dimension 1 to the feature rate by
        # repeating each step an integer number of times.
        rate_factor = feat_aligned_len // hidden_aligned_len
        assert (feat_aligned_len % hidden_aligned_len) == 0, (
            "The hidden (captured) representation should evenly divide the "
            "features length")
        hidden = hidden.repeat_interleave(rate_factor, dim=1)
        assert lens.max() <= hidden.shape[1], (
            f" Incompatible shapes for lens, hidden.shape[1]: "
            f"{(lens.max(), hidden.shape[1])}")
        # Trim to the targets' length.
        hidden = hidden[:, :targets.shape[1]].contiguous()

        # Class scores live on dimension 3; dimension 2 is squeezed out.
        pred_labels = utils.safe_squeeze(hidden.argmax(dim=3), 2)
        accs = (pred_labels == targets).float()

        # cross_entropy expects the class axis at position 1.
        losses = F.cross_entropy(utils.safe_squeeze(hidden,
                                                    2).permute(0, 2, 1),
                                 targets,
                                 reduction="none")

        # Normalised so that the weights sum to one (presumably a 0/1 length
        # mask — confirm against utils.get_mask1d).
        mask = utils.get_mask1d(lens, mask_length=losses.size(1))
        mask = mask / mask.sum()

        # With ignore_padding disabled every position gets weight 1, so the
        # values below become plain sums over all positions.
        if not self.ignore_padding:
            mask[:] = 1

        acc = (accs * mask).sum()
        loss = (losses * mask).sum()

        if logger.is_currently_logging():
            logger.log_mpl_figure(
                "framewise_debug",
                self.plot(features, F.softmax(hidden.detach(), dim=-1)))
        details = {"loss": loss, "acc": acc, "out_seq": pred_labels.detach()}
        return loss, details
示例#16
0
 def sample(self, logits):
     """Draw from a ``Normal`` with scale 1.0 located at the squeezed logits."""
     location = safe_squeeze(logits, -1)
     distribution = Normal(location, 1.0)
     return distribution.sample()
示例#17
0
 def mean_field(self, logits):
     """Return the sigmoid of the logits with their last axis squeezed."""
     return torch.sigmoid(safe_squeeze(logits, -1))
示例#18
0
 def sample(self, logits):
     """Draw from a ``Laplace`` with scale 1.0 located at the squeezed logits."""
     location = safe_squeeze(logits, -1)
     distribution = Laplace(location, 1.0)
     return distribution.sample()
示例#19
0
    def forward(ctx, log_probs, act_lens, graph_matrices, neg_inf=-np.inf):
        """Forward-backward pass over a graph given as eight matrices.

        Computes, per batch item, the logsumexp over paths (``log_cost``) and
        stores the gradient with respect to ``log_probs`` in ``ctx.grads``.
        The input is detached, so the gradient is produced here by hand
        rather than by autodiff.

        Args:
            log_probs: emission log-probabilities; dimension 2 is squeezed
                and the first two axes are swapped to get a (T, bs, ...)
                layout.
            act_lens: per-item lengths, sorted in descending order.
            graph_matrices: eight tensors (forward and ``*_out`` variants of
                states / ilabels / weights, a terminal matrix and one unused
                entry) with leading size 1 or ``bs``.
            neg_inf: value used to initialise unreachable states.

        Returns:
            ``log_cost``, one value per batch item.
        """
        logsumexp = torch.logsumexp
        log_probs = log_probs.detach()
        log_probs = safe_squeeze(log_probs, 2).transpose(0, 1).contiguous()
        T, bs, _ = log_probs.size()
        assert graph_matrices[0].size(0) in [1, bs]
        assert all(
            sm.size(0) == graph_matrices[0].size(0) for sm in graph_matrices)
        # A single shared graph is broadcast to the whole batch.
        if graph_matrices[0].size(0) == 1:
            graph_matrices = [gm.expand(bs, -1, -1) for gm in graph_matrices]
        (states_mat, ilabels_mat, weights_mat, terminal_mat, states_mat_out,
         ilabels_mat_out, weights_mat_out, _) = graph_matrices

        terminal_mat = terminal_mat.squeeze(-1)

        _, n, _ = states_mat.size()

        # a helper to select the next indices for a transition
        def get_idx(m, i):
            _bs = m.size(0)
            return torch.gather(m, 1, i.view(_bs, -1)).view(i.size())

        # Forward variables: only state 0 is reachable at the start.
        lalpha = torch.full((bs, n), neg_inf, device=log_probs.device)
        lalpha[:, 0] = 0
        lalpha0 = lalpha.clone()

        # lalphas[t] keeps the forward variables *before* consuming frame t.
        lalphas = torch.full((T, bs, n), neg_inf, device=log_probs.device)

        # The utterances are sorted according to length descending.
        # Rather than masking, stop updates to alphas when an utterance ends.
        assert act_lens.tolist() == sorted(act_lens, reverse=True)
        last_iter_end = 0
        for bitem in range(bs, 0, -1):
            iter_end = act_lens[bitem - 1]
            for t in range(last_iter_end, iter_end):
                lalphas[t] = lalpha
                token_probs = weights_mat[:bitem].clone()
                token_probs += get_idx(lalpha[:bitem], states_mat[:bitem])
                token_probs += get_idx(log_probs[t, :bitem],
                                       ilabels_mat[:bitem])
                # Written in place into the first `bitem` rows of lalpha.
                logsumexp(token_probs, dim=-1, out=lalpha[:bitem])
            last_iter_end = iter_end

        log_cost = logsumexp(lalpha + terminal_mat, dim=-1)

        # Backward variables start from the terminal weights.
        lbeta = terminal_mat.clone()
        logprobs_grad = torch.zeros_like(log_probs)

        # Walk time backwards, widening the active batch prefix as shorter
        # utterances come into range.
        last_iter_end = T
        for bitem in range(1, bs + 1):
            if bitem < bs:
                iter_end = act_lens[bitem]
            else:
                iter_end = 0
            for t in range(last_iter_end - 1, iter_end - 1, -1):
                token_probs = weights_mat_out[:bitem].clone()
                token_probs += get_idx(lbeta[:bitem], states_mat_out[:bitem])
                token_probs += get_idx(log_probs[t, :bitem],
                                       ilabels_mat_out[:bitem])
                logsumexp(token_probs, dim=-1, out=lbeta[:bitem])

                # Combine with the stored forward variables and normalise by
                # the total cost, then leave log space.
                token_probs += (lalphas[t, :bitem] -
                                log_cost[:bitem].unsqueeze(-1)).unsqueeze(-1)
                token_probs.exp_()

                # Accumulate per input label.
                logprobs_grad[t, :bitem].scatter_add_(
                    1, ilabels_mat_out[:bitem].view(bitem, -1),
                    token_probs.view(bitem, -1))
            last_iter_end = iter_end

        # Stored in the layout of the original input: (bs, T, 1, ...).
        ctx.grads = logprobs_grad.transpose(0, 1).unsqueeze(2)

        # approximate the numerical error
        # (the cost recomputed from the backward variables should match)
        log_cost0 = logsumexp(lalpha0 + lbeta, dim=1)
        if torch.abs(log_cost - log_cost0).max().item() > 1e-3:
            print('forward_backward num error: fwd losses %s bwd losses %s' %
                  (log_cost, log_cost0))
        return log_cost
示例#20
0
 def sample(self, logits):
     """Return the logits with their last axis squeezed; nothing is drawn."""
     squeezed = safe_squeeze(logits, -1)
     return squeezed
示例#21
0
 def mean_field(self, logits):
     """Return the logits with their last axis squeezed."""
     return safe_squeeze(logits, -1)
示例#22
0
 def _get_normal(self, logits):
     """Build a ``Normal`` from logits split in two along the last axis.

     The first half is the location; the exponential of the second half is
     the scale.  Both halves have their last axis squeezed.
     """
     loc_part, log_scale_part = logits.chunk(2, dim=-1)
     location = safe_squeeze(loc_part, -1)
     log_scale = safe_squeeze(log_scale_part, -1)
     return Normal(location, torch.exp(log_scale))
示例#23
0
 def loss(self, logits, targets):
     """Element-wise squared error between squeezed logits and targets."""
     prediction = safe_squeeze(logits, -1)
     assert prediction.size() == targets.size()
     return F.mse_loss(prediction, targets, reduction='none')
示例#24
0
    def evaluate(self, batches):
        """Aggregate losses, probe losses and alignment metrics over batches.

        Per-batch values are weighted by the number of examples and averaged.
        When the model emits tokens and the batches carry ground-truth
        alignments, F1, mapping, clustering and perplexity metrics are added.

        Args:
            batches: iterable of batch dicts with at least a ``'features'``
                entry.

        Returns:
            A dict of averaged scores; ``'loss'`` is the main loss plus the
            back-propagated probes loss.
        """
        tot_examples = 0.
        tot_loss = 0.
        tot_detached_probesloss = 0.
        tot_backprop_probesloss = 0.
        tot_errs = 0.

        alis_es = []
        alis_gt = []
        alis_lens = []
        total_stats = {}

        # A private copy of the first batch is kept for plotting later on.
        first_batch = None

        for batch in batches:
            if first_batch is None:
                first_batch = copy.deepcopy(batch)

            num_examples = batch['features'].shape[0]
            loss, stats, tokens = self.minibatch_loss_and_tokens(batch)

            # Run the probes
            detached_loss, backprop_loss, probes_details = self.probes_loss(
                batch)
            stats.update(probes_details)

            if tokens is not None:
                # Tokens should be in layout B x W x 1 x 1
                tokens = utils.safe_squeeze(tokens, dim=3)
                tokens = utils.safe_squeeze(tokens, dim=2)

                feat_len = batch['features_len']
                alis_lens.append(feat_len)

                # the tokens should match the rate of the alignment
                ali_es = self.align_tokens_to_features(batch, tokens)
                assert (ali_es.shape[0] == batch['features'].shape[0])
                assert (ali_es.shape[1] == batch['features'].shape[1])
                alis_es.append(ali_es[:, :])
                if 'alignment' in batch:
                    ali_gt = batch['alignment']
                    ali_len = batch['alignment_len']

                    assert ((ali_len == feat_len).all())
                    alis_gt.append(ali_gt)

            tot_examples += num_examples
            tot_loss += loss * num_examples
            # A batch without an 'err' stat contributes NaN to the total.
            tot_errs += stats.get('err', np.nan) * num_examples

            tot_detached_probesloss += detached_loss * num_examples
            tot_backprop_probesloss += backprop_loss * num_examples
            for k, v in stats.items():
                if k == 'segmental_values':
                    # Plotted rather than accumulated.
                    if logger.is_currently_logging():
                        import matplotlib.pyplot as plt
                        f = plt.figure(dpi=300)
                        plt.plot(v.data.cpu().numpy(), 'r.-')
                        f.set_tight_layout(True)
                        logger.log_mpl_figure(f'segmentation_values', f)
                elif utils.is_scalar(v):
                    # Scalar stats are accumulated weighted by batch size.
                    if k not in total_stats:
                        total_stats[k] = v * num_examples
                    else:
                        total_stats[k] += v * num_examples
        # loss is special, as we use it e.g. for learn rate control
        # add all signals that we train against, but remove the passive ones
        all_scores = {
            'loss': (tot_loss + tot_backprop_probesloss) / tot_examples,
            'probes_backprop_loss':
            tot_backprop_probesloss / tot_examples,
            'probes_detached_loss':
            tot_detached_probesloss / tot_examples,
            'err':
            tot_errs / tot_examples,
            'probes_loss':
            (tot_detached_probesloss + tot_backprop_probesloss) / tot_examples
        }

        # Averaged scalar stats may overwrite the entries set above.
        for k, v in total_stats.items():
            all_scores[k] = v / tot_examples

        if (len(alis_es) > 0) and (len(alis_gt) > 0):
            # If we have gathered any alignments
            f1_scores = dict(precision=[], recall=[], f1=[])
            # NOTE: `batch` is rebound here to (gt, es, lens) triples.
            for batch in zip(alis_gt, alis_es, alis_lens):
                batch = [t.detach().cpu().numpy() for t in batch]
                for k, v in scoring.compute_f1_scores(*batch, delta=1).items():
                    f1_scores[k].extend(v)
            for k in ('f1', 'precision', 'recall'):
                print(f"f1/{k}: {np.mean(f1_scores[k])}")
                logger.log_scalar(f'f1/{k}', np.mean(f1_scores[k]))

            alis_es = self._unpad_and_concat(alis_es, alis_lens)
            # The `else None` arm cannot trigger: the enclosing condition
            # already requires a non-empty alis_gt.
            alis_gt = self._unpad_and_concat(
                alis_gt, alis_lens) if len(alis_gt) else None

            # Each entry is (metric-name prefix, filter applied to both
            # alignments before scoring).
            scores_to_compute = [('', lambda x: x)]
            if alis_gt is not None and self.pad_symbol is not None:
                not_pad = (alis_gt != self.pad_symbol)
                scores_to_compute.append(('nonpad_', lambda x: x[not_pad]))

            # Negative estimated tokens exist: also score with -1 removed.
            if alis_gt is not None and alis_es.min() < 0:
                not_pad2 = (alis_es != -1)
                scores_to_compute.append(
                    ('validtokens_', lambda x: x[not_pad2]))

            for prefix, ali_filter in scores_to_compute:
                es = ali_filter(alis_es)

                if alis_gt is not None:
                    gt = ali_filter(alis_gt)

                    mapping_scores, mapping = self._mapping_metrics(
                        gt, es, prefix=prefix)
                    all_scores.update(mapping_scores)

                    # Run the segmentation plotting with mapping
                    # (the first batch is re-evaluated once per prefix)
                    if logger.is_currently_logging():
                        _, _, tokens = self.minibatch_loss_and_tokens(
                            first_batch)
                        self.plot_input_and_alignments(
                            first_batch['features'],
                            alignment_es=tokens,
                            alignment_gt=first_batch['alignment'],
                            mapping=mapping,
                            imshow_kwargs=dict(cmap='Greys'),
                            log_suffix=f'{prefix[:-1]}')

                    clustering_scores = self._clustering_metrics(gt,
                                                                 es,
                                                                 prefix=prefix)
                    all_scores.update(clustering_scores)

                perplexity_scores = self._perplexity_metrics(es, prefix=prefix)
                all_scores.update(perplexity_scores)

        return all_scores
示例#25
0
 def align_tokens_to_features(self, batch, tokens):
     """Return ``tokens`` with axis 1 squeezed; ``batch`` is not used."""
     # No downsampling in our case
     aligned = utils.safe_squeeze(tokens, 1)
     return aligned
示例#26
0
 def loss(self, logits, targets):
     """Element-wise binary cross-entropy between squeezed logits and targets."""
     prediction = safe_squeeze(logits, -1)
     assert prediction.size() == targets.size()
     return F.binary_cross_entropy_with_logits(
         prediction, targets, reduction='none')
示例#27
0
 def sample(self, logits):
     """Draw 0/1 floats that are 1 with probability ``sigmoid(logits)``."""
     probs = torch.sigmoid(safe_squeeze(logits, -1))
     uniform = torch.rand_like(probs)
     return uniform.lt(probs).float()
示例#28
0
    def loss(self, features, targets, features_len=None, targets_len=None):
        """Sum the framewise cross-entropy of every named prediction.

        ``self(self.input, inputs_len)`` yields a mapping from prediction
        name to prediction; each one is repeated along dimension 1 to the
        feature rate, trimmed to the targets' length and scored against
        ``targets``.  ``features`` is used for its shape, the length
        computations and the debug plot.

        Args:
            features: tensor whose dimension 1 is the frame axis.
            targets: per-frame labels (cast to long).
            features_len: optional per-sequence lengths; must equal
                ``targets_len`` when given.
            targets_len: optional per-sequence target lengths.

        Returns:
            ``(total_loss, details)`` where ``details`` holds, per prediction
            name, ``loss_<name>``, ``acc_<name>`` and ``out_seq_<name>``.
        """
        # the features may be padded
        if features_len is None:
            assert targets_len is None
            assert features.shape[1] == targets.shape[1], (
                f"The lengths of the targets and the inputs should "
                f"be the same for a framewise prediction. "
                f"Currently: {targets.shape[1]} and {features.shape[1]} respectively."
            )
        else:
            assert (torch.all(features_len == targets_len)
                    and (features.shape[1] >= targets.shape[1]))
        features_len = self.calculateFeatureLens(features, features_len)
        inputs_len, rate_factor = self.calculateInputLengths(
            self.input, features, features_len)
        feat_aligned_len = features.shape[1]
        # No prediction exists yet at this point, so the message may only
        # mention tensors that are already bound (referencing `pred` here
        # would raise NameError instead of the intended AssertionError).
        assert feat_aligned_len >= features_len.max(), (
            f"Incompatible shapes for features, targets: "
            f"{(features.shape, targets.shape)}")
        targets = targets.long()

        details = {}
        total_loss = 0
        for pred_name, pred in self(self.input, inputs_len).items():
            hidden_aligned_len = pred.shape[1]

            assert (feat_aligned_len % hidden_aligned_len) == 0, (
                "The hidden (captured) representation should evenly divide the "
                "features length")
            # Upsample along dimension 1 to the feature rate.
            pred = pred.repeat_interleave(rate_factor, dim=1)
            assert features_len.max() <= pred.shape[1], (
                f" Incompatible shapes for features_len, pred.shape[1]: "
                f"{(features_len.max(), pred.shape[1])}")
            pred = pred[:, :targets.shape[1]].contiguous()

            # Class scores live on dimension 3; dimension 2 is squeezed out.
            pred_labels = utils.safe_squeeze(pred.argmax(dim=3), 2)
            accs = (pred_labels == targets).float()

            # cross_entropy expects the class axis at position 1.
            losses = F.cross_entropy(utils.safe_squeeze(pred,
                                                        2).permute(0, 2, 1),
                                     targets,
                                     reduction="none")

            # Normalised so that the weights sum to one.
            mask = utils.get_mask1d(features_len.to(losses.device),
                                    mask_length=losses.size(1))
            mask = mask / mask.sum()

            # With ignore_padding disabled every position gets weight 1.
            if not self.ignore_padding:
                mask[:] = 1

            acc = (accs * mask).sum()
            loss = (losses * mask).sum()

            if logger.is_currently_logging():
                logger.log_mpl_figure(
                    "framewise_debug_" + pred_name,
                    self.plot(features, F.softmax(pred.detach(), dim=-1)))
            total_loss = total_loss + loss
            details.update({
                "loss_" + pred_name: loss,
                "acc_" + pred_name: acc,
                "out_seq_" + pred_name: pred_labels.detach()
            })
        return total_loss, details
示例#29
0
def path_reduction(log_probs,
                   act_lens,
                   graph_matrices,
                   red_kind='logsumexp',
                   neg_inf=-1e20):
    """
    Compute a sum of all paths through a graph.
    Args:
        log_probs: bs x T x 1 x NUM_SYMBOLS tensor of log_probs of emitting symbols
        act_lens: bs tensor of lengths of utterances, sorted descending
        graph_matrices: a tuple of four matrices of shape bs x N [x K]
            that encode the transitions and weights in the graph. A tuple
            of eight matrices (the full forward-backward set) is also
            accepted.
        red_kind: logsumexp / viterbi - chooses between aggregating all paths by
            summing their probabilities (logsumexp of logprobs), or
            by taking the maximally probable one. Also encodes which reduction
            engine to use:
                logsumexp_fwb forces a forward-backward algo, while
                logsumexp_autodiff uses backward pass using autodiff.
            Plain logsumexp uses forward-backward when eight matrices are
            given.
        neg_inf: what value to use for improbable events (-1e10 or -1e20 are OK)
    Returns:
        tensor of shape bs: a sum of weights on the maximally probable path
        or on all paths
    """
    if (red_kind == 'logsumexp_fwb'
            or (red_kind == 'logsumexp' and len(graph_matrices) == 8)):
        # NOTE(review): this branch passes a hard-coded -1e20 and ignores the
        # `neg_inf` argument — confirm whether that is intentional.
        return path_logsumexp(log_probs, act_lens, graph_matrices, -1e20)

    # Drop the singleton axis and switch to a time-major layout.
    log_probs = safe_squeeze(log_probs, 2).transpose(0, 1).contiguous()
    _, bs, _ = log_probs.size()
    assert graph_matrices[0].size(0) in [1, bs]
    assert all(
        sm.size(0) == graph_matrices[0].size(0) for sm in graph_matrices)
    # This can happen if we get the matrices for full forward-backward
    # and here we only need the ones for forward
    if len(graph_matrices) == 8:
        graph_matrices = graph_matrices[:4]
    # A single shared graph is broadcast to the whole batch.
    if graph_matrices[0].size(0) == 1:
        graph_matrices = [gm.expand(bs, -1, -1) for gm in graph_matrices]
    states_mat, ilabels_mat, weights_mat, terminal_mat = graph_matrices

    _, n, k = states_mat.size()

    if red_kind in ['logsumexp', 'logsumexp_autodiff']:
        # reduction = torch.logsumexp
        reduction = torch.logsumexp
    else:
        assert red_kind in ['viterbi', 'viterbi_autodiff']

        # Viterbi keeps only the best incoming score.
        def reduction(t, dim):
            return torch.max(t, dim)[0]

    # a helper to select the next indices for a transition
    def get_idx(m, i):
        _bs = m.size(0)
        return torch.gather(m, 1, i.view(_bs, n * k)).view((_bs, n, k))

    # Forward variables: only state 0 is reachable at the start.
    lalpha = torch.full((bs, n), neg_inf, device=log_probs.device)
    lalpha[:, 0] = 0

    # The utterances are sorted according to length descending.
    # Rather than masking, stop updates to alphas when an utterance ends.
    assert act_lens.tolist() == sorted(act_lens, reverse=True)
    last_iter_end = 0
    for bitem in range(bs, 0, -1):
        iter_end = act_lens[bitem - 1]
        for t in range(last_iter_end, iter_end):
            # print(torch.softmax(lalpha[0], -1))
            token_probs = (get_idx(lalpha[:bitem], states_mat[:bitem]) +
                           weights_mat[:bitem] +
                           get_idx(log_probs[t, :bitem], ilabels_mat[:bitem]))
            la = reduction(token_probs, dim=-1)
            # Clone before the slice assignment so the previous lalpha
            # tensor is left unmodified.
            lalpha = lalpha.clone()
            lalpha[:bitem] = la
        last_iter_end = iter_end

    # Add the terminal weights and reduce over the states.
    path_sum = reduction(lalpha + terminal_mat.squeeze(2), dim=-1)
    return path_sum
示例#30
0
 def loss(self, logits, targets):
     """Element-wise absolute error between squeezed logits and targets."""
     prediction = safe_squeeze(logits, -1)
     assert prediction.size() == targets.size(
     ), f"{prediction.size()} != {targets.size()}"
     return F.l1_loss(prediction, targets, reduction='none')