def nests_loss(self, energy: Tensor, target: Tensor) -> Tensor:
    """Compute the negative log-likelihood of `target` under a constrained chain CRF.

    Runs the forward algorithm over pairwise energies while maintaining two
    parallel partition recursions: `partition_1` (unconstrained over all labels)
    and `partition_3` (restricted to the per-position label subsets held in
    `self.indices_bs` / `self.indices_is` / `self.indices_es` — presumably
    begin/inside/end label ids of a nested tagging scheme; confirm against the
    class that defines them).

    Args:
        energy: Tensor
            the energy tensor with shape = [length, num_label, num_label];
            entry [t, i, j] scores transitioning from label i to label j at step t
        target: Tensor
            the tensor of target labels with shape [length]

    Returns: Tensor
        A 0D tensor for minus log likelihood loss
        (log partition function minus the energy of the gold path).
    """
    length, _, _ = energy.size()

    # Per-position constrained label-id sets, shape [length, num_label_3]:
    # row 0 gets the "begin" ids, middle rows the "inside" ids, last row the "end" ids.
    num_label_3 = self.indices_is.size(0)
    indices_3 = energy.new_empty((length, num_label_3)).long()
    indices_3[0, :] = self.indices_bs
    if length > 2:
        indices_3[1:length - 1, :] = self.indices_is.repeat((length - 2, 1))
    indices_3[length - 1, :] = self.indices_es

    # Forward variables, each shape = [num_label] (resp. [num_label_3]) once initialized.
    partition_1 = None
    partition_3 = None

    # Gold-path score accumulator; starts from the BOS label. shape = []
    prev_label = self.index_bos
    tgt_energy = 0
    for t in range(length):
        # shape = [num_label, num_label]
        curr_energy = energy[t]
        if t == 0:
            # First step: transitions out of BOS; constrained partition starts at
            # -1e4, which acts as an effective log(0) (an unreachable state).
            partition_1 = curr_energy[self.index_bos, :]
            partition_3 = energy.new_full((num_label_3, ), -1e4)
        else:
            # Mixed predecessor scores: unconstrained values everywhere, but the
            # previous step's constrained positions take the constrained values.
            # NOTE: clone() before the scatter keeps partition_1 intact for its
            # own recursion below. shape = [num_label]
            partition = partition_1.clone()
            partition[indices_3[t - 1]] = partition_3
            partition_1 = logsumexp(curr_energy + partition_1.unsqueeze(1), dim=0)
            # Constrained recursion only ranges over this step's allowed columns.
            partition_3 = logsumexp(curr_energy[:, indices_3[t]] + partition.unsqueeze(1), dim=0)
        # Accumulate the gold transition energy prev_label -> target[t].
        label = target[t]
        tgt_energy += curr_energy[prev_label, label]
        prev_label = label

    # Terminate with the transition into EOS; at the final positions that belong
    # to the constrained set, use the constrained partition instead.
    t = length - 1
    curr_energy = self.trans_matrix.data[:, self.index_eos]
    partition = curr_energy + partition_1
    partition[indices_3[t]] = curr_energy[indices_3[t]] + partition_3
    return logsumexp(partition, dim=0) - tgt_energy
def decode_nest(self, energy: Tensor) -> Tensor: """ Args: energy: Tensor the energy tensor with shape = [length, num_label, num_label] Returns: Tensor decoding nested results in shape [length] """ # the last row and column is the tag for pad symbol. reduce these two dimensions by 1 to remove that. # also remove the first #symbolic rows and columns. # now the shape of energies_shuffled is [n_time_steps, t, t] where t = num_labels - #symbolic - 1. energy_transpose = energy[:, :self.index_bos, :self.index_bos] length, num_label, _ = energy_transpose.size() num_label_3 = self.indices_is.size(0) indices_3 = energy.new_empty((length, num_label_3)).long() indices_3[0, :] = self.indices_bs if length > 2: indices_3[1:length - 1, :] = self.indices_is.repeat( (length - 2, 1)) indices_3[length - 1, :] = self.indices_es pointer_1 = energy.new_zeros((length, num_label)).long() pointer_3 = energy.new_zeros((length, num_label)).long() back_pointer = pointer_3.new_zeros(length) pi_1 = energy[0, self.index_bos, :self.index_bos] pi_3 = energy.new_full((num_label_3, ), -1e4) pointer_1[0] = self.index_bos pointer_3[0] = self.index_bos for t in range(1, length): e_t = energy_transpose[t] pi = pi_1.clone() pi[indices_3[t - 1]] = pi_3 pi_1, pointer_1[t] = torch.max(e_t + pi_1.unsqueeze(1), dim=0) pi_3, pointer_3[t, indices_3[t]] = torch.max(e_t[:, indices_3[t]] + pi.unsqueeze(1), dim=0) t = length - 1 e_t = self.trans_matrix.data[:self.index_bos, self.index_eos] pi = e_t + pi_1 pi[indices_3[t]] = e_t[indices_3[t]] + pi_3 _, back_pointer[-1] = torch.max(pi, dim=0) t = length - 2 while t > -1: if (indices_3[t + 1] == back_pointer[t + 1]).nonzero().numel() == 0: break pointer_last = pointer_3[t + 1] back_pointer[t] = pointer_last[back_pointer[t + 1]] t -= 1 while t > -1: pointer_last = pointer_1[t + 1] back_pointer[t] = pointer_last[back_pointer[t + 1]] t -= 1 return back_pointer