def compute_loss(self, net_output, sample, reduce=True):
    try:
        from pychain.graph import ChainGraphBatch
        from pychain.loss import ChainFunction
    except ImportError:
        raise ImportError(
            "Please install OpenFST and PyChain by `make openfst pychain` after entering espresso/tools"
        )

    den_graphs = ChainGraphBatch(self.den_graph, sample["nsentences"])
    encoder_out = net_output.encoder_out.transpose(0, 1)  # T x B x V -> B x T x V
    out_lengths = net_output.src_lengths.long()  # B
    den_objf = ChainFunction.apply(
        encoder_out, out_lengths, den_graphs, self.den_leaky_hmm_coefficient
    )
    num_objf = ChainFunction.apply(
        encoder_out, out_lengths, sample["target"], self.num_leaky_hmm_coefficient
    )
    loss = -num_objf + den_objf  # negative log-probs
    return loss, loss
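
# Sketch (not from the source): a minimal stand-in for the namedtuple-style `net_output`
# that the version above expects, useful for checking the tensor reshaping in isolation.
# `EncoderOut` and the T/B/V sizes below are hypothetical; only the attribute names
# `encoder_out` (T x B x V) and `src_lengths` (B) come from the code above. The LF-MMI
# objective itself is loss = -(num_objf - den_objf), i.e. the negated difference between
# the numerator and denominator log-likelihoods returned by PyChain's forward-backward.
from collections import namedtuple

import torch

EncoderOut = namedtuple("EncoderOut", ["encoder_out", "src_lengths"])

T, B, V = 10, 4, 100  # frames, batch size, number of output units
net_output = EncoderOut(
    encoder_out=torch.randn(T, B, V),   # time-major network output
    src_lengths=torch.full((B,), T),    # per-utterance output lengths
)
encoder_out = net_output.encoder_out.transpose(0, 1)  # B x T x V, as PyChain expects
out_lengths = net_output.src_lengths.long()
assert encoder_out.shape == (B, T, V) and out_lengths.shape == (B,)
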
def compute_loss(self, net_output, sample, reduce=True):
    try:
        from pychain.graph import ChainGraphBatch
        from pychain.loss import ChainFunction, ChainLossFunction
    except ImportError:
        raise ImportError(
            "Please install OpenFST and PyChain by `make openfst pychain` after entering espresso/tools"
        )

    encoder_out = net_output["encoder_out"][0].transpose(0, 1)  # T x B x V -> B x T x V
    out_lengths = net_output["src_lengths"][0].long()  # B
    den_graphs = ChainGraphBatch(self.den_graph, sample["nsentences"])
    if self.xent_regularize > 0.0:
        den_objf = ChainFunction.apply(
            encoder_out, out_lengths, den_graphs, self.leaky_hmm_coefficient
        )
        num_objf = ChainFunction.apply(encoder_out, out_lengths, sample["target"])
        loss = -num_objf + den_objf  # negative log-probs
        nll_loss = loss.clone().detach()
        loss -= self.xent_regularize * num_objf  # cross-entropy regularization term
    else:
        # Demonstrate another, more "integrated" usage of the PyChain loss. It is
        # equivalent to computing den_objf, num_objf and loss as in the "if" branch
        # above, but additionally supports discarding batches with a NaN loss by
        # setting their gradients to 0.
        loss = ChainLossFunction.apply(
            encoder_out, out_lengths, sample["target"], den_graphs, self.leaky_hmm_coefficient
        )
        nll_loss = loss.clone().detach()

    if self.output_l2_regularize > 0.0:
        encoder_padding_mask = (
            net_output["encoder_padding_mask"][0]
            if len(net_output["encoder_padding_mask"]) > 0
            else None
        )
        encoder_out_squared = encoder_out.pow(2.0)
        if encoder_padding_mask is not None:
            pad_mask = encoder_padding_mask.transpose(0, 1).unsqueeze(-1)  # T x B -> B x T x 1
            encoder_out_squared.masked_fill_(pad_mask, 0.0)
        loss += 0.5 * self.output_l2_regularize * encoder_out_squared.sum()

    return loss, nll_loss
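
# Sketch (not from the source): how a fairseq-style criterion forward() might wrap the
# compute_loss() above. The logging keys and the use of sample["ntokens"] as sample_size
# are assumptions based on common fairseq conventions, not taken from this file; only
# sample["nsentences"] and the (loss, nll_loss) return pair appear in the code above.
def forward(self, model, sample, reduce=True):
    net_output = model(**sample["net_input"])  # dict with "encoder_out", "src_lengths", etc.
    loss, nll_loss = self.compute_loss(net_output, sample, reduce=reduce)
    sample_size = sample["ntokens"]  # assumed: normalize by the number of output frames
    logging_output = {
        "loss": loss.data,
        "nll_loss": nll_loss.data,
        "ntokens": sample["ntokens"],
        "nsentences": sample["nsentences"],
        "sample_size": sample_size,
    }
    return loss, sample_size, logging_output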