Code Example #1
    def compute_loss(self, net_output, sample, reduce=True):
        try:
            from pychain.graph import ChainGraphBatch
            from pychain.loss import ChainFunction
        except ImportError:
            raise ImportError("Please install OpenFST and PyChain by `make openfst pychain` after entering espresso/tools")

        den_graphs = ChainGraphBatch(self.den_graph, sample["nsentences"])  # replicate the denominator graph across the batch
        encoder_out = net_output.encoder_out.transpose(0, 1)  # T x B x V -> B x T x V
        out_lengths = net_output.src_lengths.long()  # B
        den_objf = ChainFunction.apply(encoder_out, out_lengths, den_graphs, self.den_leaky_hmm_coefficient)  # denominator log-prob
        num_objf = ChainFunction.apply(encoder_out, out_lengths, sample["target"], self.num_leaky_hmm_coefficient)  # numerator log-prob
        loss = -num_objf + den_objf  # negated LF-MMI objective (negative log-probs)
        return loss, loss  # no separate nll_loss in this version; the loss is returned twice
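
The two `ChainFunction.apply` calls return the total log-likelihood of the batch under the numerator (supervision) graphs and the denominator graph, and the LF-MMI loss is the negated difference of the two. A minimal sketch of that sign convention with dummy scalar values (the numbers are illustrative only, not actual PyChain output):

    import torch

    # Suppose the summed log-likelihoods for a batch came back as:
    num_objf = torch.tensor(-120.0)  # log p(X | numerator graphs)
    den_objf = torch.tensor(-150.0)  # log p(X | denominator graph)

    # LF-MMI maximizes (num - den), so the loss to minimize is its negation:
    loss = -num_objf + den_objf
    print(loss)  # tensor(-30.), i.e. -((-120) - (-150))
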
Code Example #2
    def compute_loss(self, net_output, sample, reduce=True):
        try:
            from pychain.graph import ChainGraphBatch
            from pychain.loss import ChainFunction, ChainLossFunction
        except ImportError:
            raise ImportError(
                "Please install OpenFST and PyChain by `make openfst pychain` after entering espresso/tools"
            )

        encoder_out = net_output["encoder_out"][0].transpose(0, 1)  # T x B x V -> B x T x V
        out_lengths = net_output["src_lengths"][0].long()  # B
        den_graphs = ChainGraphBatch(self.den_graph, sample["nsentences"])
        if self.xent_regularize > 0.0:
            den_objf = ChainFunction.apply(
                encoder_out, out_lengths, den_graphs, self.leaky_hmm_coefficient
            )
            num_objf = ChainFunction.apply(encoder_out, out_lengths, sample["target"])
            loss = -num_objf + den_objf  # negative log-probs
            nll_loss = loss.clone().detach()
            loss -= self.xent_regularize * num_objf
        else:
            # Demonstrates an alternative, more "integrated" usage of the PyChain loss. It is
            # equivalent to computing den_objf, num_objf, and loss as in the "if" branch above,
            # but additionally supports discarding batches with a NaN loss by setting their
            # gradients to 0.
            loss = ChainLossFunction.apply(
                encoder_out,
                out_lengths,
                sample["target"],
                den_graphs,
                self.leaky_hmm_coefficient,
            )
            nll_loss = loss.clone().detach()

        if self.output_l2_regularize > 0.0:
            encoder_padding_mask = (
                net_output["encoder_padding_mask"][0]
                if len(net_output["encoder_padding_mask"]) > 0
                else None
            )
            encoder_out_squared = encoder_out.pow(2.0)
            if encoder_padding_mask is not None:
                pad_mask = encoder_padding_mask.transpose(0, 1).unsqueeze(-1)  # T x B -> B x T x 1
                encoder_out_squared.masked_fill_(pad_mask, 0.0)
            loss += 0.5 * self.output_l2_regularize * encoder_out_squared.sum()

        return loss, nll_loss
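
The output L2 regularization at the end sums the squared encoder outputs over non-padded frames only. A self-contained sketch of that masking step, with hypothetical shapes and a made-up coefficient standing in for `self.output_l2_regularize`:

    import torch

    T, B, V = 4, 2, 3
    encoder_out = torch.randn(B, T, V)    # already transposed to B x T x V, as above
    encoder_padding_mask = torch.tensor(  # T x B; True marks padded frames
        [[False, False],
         [False, False],
         [False, True],
         [False, True]]
    )

    encoder_out_squared = encoder_out.pow(2.0)
    pad_mask = encoder_padding_mask.transpose(0, 1).unsqueeze(-1)  # T x B -> B x T x 1
    encoder_out_squared.masked_fill_(pad_mask, 0.0)  # padded frames contribute nothing

    output_l2_regularize = 0.01  # hypothetical coefficient
    l2_term = 0.5 * output_l2_regularize * encoder_out_squared.sum()
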