    def nests_loss(self, energy: Tensor, target: Tensor) -> Tensor:
        """
        Args:
            energy: Tensor
                the energy tensor with shape = [length, num_label, num_label]
            target: Tensor
                the tensor of target labels with shape [length]

        Returns: Tensor
                a 0-D tensor holding the negative log-likelihood loss
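
        Example (a minimal sketch; ``crf``, ``input``, ``mask`` and
        ``target`` are assumptions, not defined in this module)::

            >>> energy = crf(input, mask=mask)  # [batch, length, L, L]
            >>> nll = crf.nests_loss(energy[0], target[0])  # 0-D loss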
        """
        length, _, _ = energy.size()

        num_label_3 = self.indices_is.size(0)

        indices_3 = energy.new_empty((length, num_label_3)).long()
        indices_3[0, :] = self.indices_bs
        if length > 2:
            indices_3[1:length - 1, :] = self.indices_is.repeat(
                (length - 2, 1))
        indices_3[length - 1, :] = self.indices_es
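        # Row t of indices_3 now lists the labels a nested mention may take
        # at position t: begin labels (indices_bs) at the first token, inside
        # labels (indices_is) in the middle, end labels (indices_es) at the
        # last token.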

        # partition_1: shape = [num_label]; partition_3: shape = [num_label_3]
        partition_1 = None
        partition_3 = None

        # scalar state: the previous gold label and the gold-path score
        prev_label = self.index_bos
        tgt_energy = 0

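        # Two forward (logsumexp) recursions run side by side: partition_1
        # is the standard forward score over the full label set,
        #   partition_1[j] = logsumexp_i(partition_1[i] + energy[t, i, j]),
        # while partition_3 restricts the same recursion to the nested
        # candidates indices_3[t]; its predecessor scores come from
        # partition_1 with the entries of the previous candidate set
        # replaced by partition_3 (the `partition` tensor below).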
        for t in range(length):
            # shape = [num_label, num_label]
            curr_energy = energy[t]
            if t == 0:
                partition_1 = curr_energy[self.index_bos, :]
                partition_3 = energy.new_full((num_label_3, ), -1e4)
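                # -1e4 approximates -inf: the nested track is empty at t = 0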
            else:
                # shape = [num_label]: flat-track scores with the
                # nested-track scores swapped in for the candidates at t - 1
                partition = partition_1.clone()
                partition[indices_3[t - 1]] = partition_3
                partition_1 = logsumexp(curr_energy + partition_1.unsqueeze(1),
                                        dim=0)
                partition_3 = logsumexp(curr_energy[:, indices_3[t]] +
                                        partition.unsqueeze(1),
                                        dim=0)
            label = target[t]
            tgt_energy += curr_energy[prev_label, label]
            prev_label = label

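        # termination: add the transition scores into EOS, keeping the
        # nested-track scores for the candidates at the final position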
        t = length - 1
        curr_energy = self.trans_matrix.data[:, self.index_eos]
        partition = curr_energy + partition_1
        partition[indices_3[t]] = curr_energy[indices_3[t]] + partition_3
        return logsumexp(partition, dim=0) - tgt_energy

    def loss(self,
             input: Tensor,
             target: Tensor,
             mask: Tensor = None) -> Tuple[Tensor, Tensor]:
        """
        Args:
            input: Tensor
                the input tensor with shape = [batch, length, input_size]
            target: Tensor
                the tensor of target labels with shape [batch, length]
            mask: Tensor or None
                the mask tensor with shape = [batch, length]

        Returns: Tuple[Tensor, Tensor]
                a 1D tensor of shape [batch] holding the negative
                log-likelihood loss, and the energy tensor with shape
                [batch, length, num_label, num_label]
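
        Example (a minimal sketch; ``crf``, ``input``, ``target`` and
        ``mask`` are assumptions)::

            >>> nll, energy = crf.loss(input, target, mask=mask)
            >>> nll.mean().backward()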
        """
        batch, length, _ = input.size()
        energy = self.forward(input, mask=mask)
        # shape = [length, batch, num_label, num_label]
        energy_transpose = energy.transpose(0, 1)
        # shape = [length, batch]
        target_transpose = target.transpose(0, 1)

        # shape = [batch, num_label]
        partition = None

        # shape = [batch]
        batch_index = torch.arange(batch, device=input.device)
        prev_label = input.new_full((batch, ), self.index_bos).long()
        tgt_energy = input.new_zeros(batch)

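        # batched forward algorithm: partition[b, j] accumulates the
        # logsumexp score of all label prefixes of sentence b ending in
        # label j, while tgt_energy accumulates the gold-path score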
        for t in range(length):
            # shape = [batch, num_label, num_label]
            curr_energy = energy_transpose[t]
            if t == 0:
                partition = curr_energy[:, self.index_bos, :]
            else:
                # shape = [batch, num_label]
                partition = logsumexp(curr_energy + partition.unsqueeze(2),
                                      dim=1)
            label = target_transpose[t]
            tgt_energy += curr_energy[batch_index, prev_label, label]
            prev_label = label

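        # termination: add the transition scores into EOS before reducing
        # over the final label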
        return logsumexp(
            self.trans_matrix.data[:, self.index_eos].unsqueeze(0) + partition,
            dim=1) - tgt_energy, energy

    def decode_nest(self, energy: Tensor) -> Tensor:
        """
        Args:
            energy: Tensor
                the energy tensor with shape = [length, num_label, num_label]

        Returns: Tensor
            a 1D tensor of decoded nested labels with shape [length]
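
        Example (a minimal sketch; ``crf`` and its inputs are assumptions)::

            >>> energy = crf(input, mask=mask)       # [batch, length, L, L]
            >>> labels = crf.decode_nest(energy[0])  # [length]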
        """

        # Labels from self.index_bos onwards are the special symbols
        # (bos, eos, pad); slice them off both label dimensions, leaving
        # shape [length, num_label, num_label] with num_label = index_bos.
        energy_transpose = energy[:, :self.index_bos, :self.index_bos]

        length, num_label, _ = energy_transpose.size()

        num_label_3 = self.indices_is.size(0)

        indices_3 = energy.new_empty((length, num_label_3)).long()
        indices_3[0, :] = self.indices_bs
        if length > 2:
            indices_3[1:length - 1, :] = self.indices_is.repeat(
                (length - 2, 1))
        indices_3[length - 1, :] = self.indices_es
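        # position-dependent nested-label candidates, built exactly as in
        # nests_loss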

        pointer_1 = energy.new_zeros((length, num_label)).long()
        pointer_3 = energy.new_zeros((length, num_label)).long()
        back_pointer = pointer_3.new_zeros(length)

        pi_1 = energy[0, self.index_bos, :self.index_bos]
        pi_3 = energy.new_full((num_label_3, ), -1e4)
        pointer_1[0] = self.index_bos
        pointer_3[0] = self.index_bos
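        # Viterbi analogue of nests_loss: the same two-track recursion with
        # max in place of logsumexp; pointer_1 / pointer_3 store the argmax
        # predecessor for the flat and nested tracks (row 0 is a BOS
        # sentinel that backtracking never reads)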
        for t in range(1, length):
            e_t = energy_transpose[t]
            pi = pi_1.clone()
            pi[indices_3[t - 1]] = pi_3
            pi_1, pointer_1[t] = torch.max(e_t + pi_1.unsqueeze(1), dim=0)
            pi_3, pointer_3[t, indices_3[t]] = torch.max(e_t[:, indices_3[t]] +
                                                         pi.unsqueeze(1),
                                                         dim=0)
        t = length - 1
        e_t = self.trans_matrix.data[:self.index_bos, self.index_eos]
        pi = e_t + pi_1
        pi[indices_3[t]] = e_t[indices_3[t]] + pi_3

        _, back_pointer[-1] = torch.max(pi, dim=0)
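        # backtrack in two phases: follow the nested-track pointers while
        # the predicted label stays inside the candidate set indices_3,
        # then follow the flat-track pointers for the remaining prefix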
        t = length - 2
        while t > -1:
            if not (indices_3[t + 1] == back_pointer[t + 1]).any():
                break
            pointer_last = pointer_3[t + 1]
            back_pointer[t] = pointer_last[back_pointer[t + 1]]
            t -= 1
        while t > -1:
            pointer_last = pointer_1[t + 1]
            back_pointer[t] = pointer_last[back_pointer[t + 1]]
            t -= 1

        return back_pointer