예제 #1
0
    def nests_loss(self, energy: Tensor, target: Tensor) -> Tensor:
        """Return the minus log likelihood loss of ``target`` for one sequence.

        Two forward scores are propagated along the sequence: a full score
        over every label, and a restricted score kept only for the labels
        listed in ``self.indices_bs`` / ``self.indices_is`` /
        ``self.indices_es`` (first / middle / last position).

        Args:
            energy: Tensor
                the energy tensor with shape = [length, num_label, num_label]
            target: Tensor
                the tensor of target labels with shape [length]

        Returns: Tensor
            A 0D tensor: log-sum-exp of the final scores minus the energy
            accumulated along ``target``
        """
        length, _, _ = energy.size()
        restricted_width = self.indices_is.size(0)

        # restricted[t] = label ids that carry the restricted score at step t.
        # The last row is written last, so for length == 1 it replaces the
        # first-position row.
        restricted = energy.new_empty((length, restricted_width)).long()
        restricted[0] = self.indices_bs
        if length > 2:
            restricted[1:-1] = self.indices_is
        restricted[-1] = self.indices_es

        # Forward recursion.
        # full_score: shape = [num_label]; restricted_score: [restricted_width]
        full_score = energy[0][self.index_bos, :]
        restricted_score = energy.new_full((restricted_width, ), -1e4)
        for step in range(1, length):
            scores = energy[step]
            # Predecessors listed in restricted[step - 1] contribute their
            # restricted score; every other predecessor contributes its
            # full score.
            blended = full_score.clone()
            blended[restricted[step - 1]] = restricted_score
            full_score = logsumexp(scores + full_score.unsqueeze(1), dim=0)
            restricted_score = logsumexp(
                scores[:, restricted[step]] + blended.unsqueeze(1), dim=0)

        # Energy of the target path, starting from the index_bos row.
        gold_score = 0
        previous = self.index_bos
        for step in range(length):
            current = target[step]
            gold_score = gold_score + energy[step][previous, current]
            previous = current

        # Closing step: add the index_eos column of the transition matrix,
        # using the restricted score for the labels of the last position.
        to_eos = self.trans_matrix.data[:, self.index_eos]
        last_restricted = restricted[length - 1]
        final_score = to_eos + full_score
        final_score[last_restricted] = to_eos[last_restricted] + restricted_score
        return logsumexp(final_score, dim=0) - gold_score
예제 #2
0
    def decode_nest(self, energy: Tensor) -> Tensor:
        """Decode one label sequence with a max-score recursion.

        Two score tables are propagated: a full one over every label below
        ``self.index_bos`` and a restricted one kept only for the labels
        listed in ``self.indices_bs`` / ``self.indices_is`` /
        ``self.indices_es`` (first / middle / last position). Each table has
        its own back-pointers.

        Args:
            energy: Tensor
                the energy tensor with shape = [length, num_label, num_label]

        Returns: Tensor
            decoding nested results in shape [length]
        """
        # Keep only rows / columns whose index is below index_bos.
        trimmed = energy[:, :self.index_bos, :self.index_bos]
        length, num_label, _ = trimmed.size()
        restricted_width = self.indices_is.size(0)

        # restricted[t] = label ids that carry the restricted score at step t.
        # The last row is written last, so for length == 1 it replaces the
        # first-position row.
        restricted = energy.new_empty((length, restricted_width)).long()
        restricted[0] = self.indices_bs
        if length > 2:
            restricted[1:-1] = self.indices_is
        restricted[-1] = self.indices_es

        full_ptr = energy.new_zeros((length, num_label)).long()
        restricted_ptr = energy.new_zeros((length, num_label)).long()
        best_path = restricted_ptr.new_zeros(length)
        full_ptr[0] = self.index_bos
        restricted_ptr[0] = self.index_bos

        full_score = energy[0, self.index_bos, :self.index_bos]
        restricted_score = energy.new_full((restricted_width, ), -1e4)
        for step in range(1, length):
            scores = trimmed[step]
            here = restricted[step]
            # Predecessors listed in restricted[step - 1] contribute their
            # restricted score; every other predecessor contributes its
            # full score.
            blended = full_score.clone()
            blended[restricted[step - 1]] = restricted_score
            next_full, arg_full = torch.max(scores + full_score.unsqueeze(1),
                                            dim=0)
            next_restricted, arg_restricted = torch.max(
                scores[:, here] + blended.unsqueeze(1), dim=0)
            full_ptr[step] = arg_full
            restricted_ptr[step, here] = arg_restricted
            full_score, restricted_score = next_full, next_restricted

        # Closing step: add the index_eos column of the transition matrix,
        # using the restricted score for the labels of the last position.
        to_eos = self.trans_matrix.data[:self.index_bos, self.index_eos]
        last_restricted = restricted[length - 1]
        final_score = to_eos + full_score
        final_score[last_restricted] = to_eos[last_restricted] + restricted_score
        _, best_last = torch.max(final_score, dim=0)
        best_path[-1] = best_last

        # Walk backwards. Restricted pointers are followed while the label at
        # step + 1 is one of restricted[step + 1]; after the first label that
        # is not, full pointers are used for every remaining step.
        on_restricted_chain = True
        for step in range(length - 2, -1, -1):
            following = best_path[step + 1]
            if on_restricted_chain and not (restricted[step + 1]
                                            == following).any():
                on_restricted_chain = False
            table = restricted_ptr if on_restricted_chain else full_ptr
            best_path[step] = table[step + 1][following]

        return best_path