def forward(self, logits: Tensor, mask: Tensor) -> Tensor:
    """Sums the loss functions, each weighted by its prefactor."""
    loss = logits.new_zeros(())
    for loss_fn, prefact in self.loss_fns.items():
        if prefact != 0:
            loss += prefact * loss_fn(logits, mask)
    return loss
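# Illustrative usage of the weighted sum above (a sketch, not part of the
# original module): the two loss functions and their prefactors are
# hypothetical stand-ins, chosen only to make the example self-contained.
import torch

def l2_term(logits, mask):
    return ((logits * mask.unsqueeze(-1)) ** 2).mean()

def l1_term(logits, mask):
    return (logits * mask.unsqueeze(-1)).abs().mean()

loss_fns = {l2_term: 1.0, l1_term: 0.5}
logits = torch.randn(4, 10, 8)   # [batch, length, num_label]
mask = torch.ones(4, 10)         # [batch, length]
total = logits.new_zeros(())
for loss_fn, prefact in loss_fns.items():
    if prefact != 0:             # zero-weight terms are skipped entirely
        total = total + prefact * loss_fn(logits, mask)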
def forward(self, input: Tensor, mask: Tensor = None,
            hx: Tuple[Tensor, Tensor] = None) -> Tuple[Tensor, Tensor]:
    batch_size = input.size(0) if self.batch_first else input.size(1)
    if hx is None:
        # Initialize hidden and cell states to zeros when none are given.
        num_directions = 2 if self.bidirectional else 1
        hx = input.new_zeros((self.num_layers * num_directions,
                              batch_size, self.hidden_size))
        hx = (hx, hx)
    func = rnn_f.autograd_var_masked_rnn(num_layers=self.num_layers,
                                         batch_first=self.batch_first,
                                         bidirectional=self.bidirectional,
                                         lstm=True)
    # Resample the variational dropout noise for this batch.
    self.reset_noise(batch_size)
    # Append a trailing singleton dimension so the mask broadcasts over features.
    output, hidden = func(input, self.all_cells, hx,
                          None if mask is None else mask.view(mask.size() + (1,)))
    return output, hidden
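# The mask consumed by `forward` appears to be a float tensor of shape
# [batch, length], 1.0 at valid positions and 0.0 at padding (an assumption,
# not stated in the original code). A minimal sketch for building it from
# per-sequence lengths:
import torch

def lengths_to_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # positions: [max_len]; compare against each sequence's length.
    positions = torch.arange(max_len, device=lengths.device)
    return (positions.unsqueeze(0) < lengths.unsqueeze(1)).float()

# lengths_to_mask(torch.tensor([3, 5]), 5)
# -> tensor([[1., 1., 1., 0., 0.],
#            [1., 1., 1., 1., 1.]])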
def loss(self, input: Tensor, target: Tensor, mask: Tensor = None) -> Tuple[Tensor, Tensor]:
    """
    Args:
        input: Tensor
            the input tensor with shape = [batch, length, input_size]
        target: Tensor
            the tensor of target labels with shape = [batch, length]
        mask: Tensor or None
            the mask tensor with shape = [batch, length]

    Returns: Tuple[Tensor, Tensor]
        a 1D tensor holding the negative log-likelihood loss for each
        sequence in the batch, and the energy tensor with shape
        = [batch, length, num_label, num_label]
    """
    batch, length, _ = input.size()
    energy = self.forward(input, mask=mask)
    # shape = [length, batch, num_label, num_label]
    energy_transpose = energy.transpose(0, 1)
    # shape = [length, batch]
    target_transpose = target.transpose(0, 1)
    # shape = [batch, num_label]
    partition = None
    # shape = [batch]
    batch_index = torch.arange(0, batch).type_as(input).long()
    prev_label = input.new_full((batch, ), self.index_bos).long()
    tgt_energy = input.new_zeros(batch)
    for t in range(length):
        # shape = [batch, num_label, num_label]
        curr_energy = energy_transpose[t]
        if t == 0:
            partition = curr_energy[:, self.index_bos, :]
        else:
            # shape = [batch, num_label]
            partition = logsumexp(curr_energy + partition.unsqueeze(2), dim=1)
        label = target_transpose[t]
        tgt_energy += curr_energy[batch_index, prev_label, label]
        prev_label = label
    return logsumexp(self.trans_matrix.data[:, self.index_eos].unsqueeze(0) + partition,
                     dim=1) - tgt_energy, energy
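# `loss` relies on a `logsumexp` helper. A numerically stable sketch along the
# usual lines (the original helper may differ; torch.logsumexp is the built-in
# equivalent):
import torch
from torch import Tensor

def logsumexp(x: Tensor, dim: int) -> Tensor:
    # Subtract the per-slice max before exponentiating to avoid overflow.
    xmax, _ = x.max(dim=dim, keepdim=True)
    return (x - xmax).exp().sum(dim=dim).log() + xmax.squeeze(dim)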
def step(self, input: Tensor, hx: Tuple[Tensor, Tensor] = None,
         mask: Tensor = None) -> Tuple[Tensor, Tensor]:
    """
    Execute one step forward (only for a unidirectional RNN).

    Args:
        input (batch, input_size): input tensor of this step.
        hx (num_layers, batch, hidden_size): the hidden state of the last step.
        mask (batch): the mask tensor of this step.

    Returns:
        output (batch, hidden_size): tensor containing the output of this step
            from the last layer of the RNN.
        hn (num_layers, batch, hidden_size): tensor containing the hidden state
            of this step.
    """
    assert not self.bidirectional, "step cannot be applied to a bidirectional RNN."
    batch_size = input.size(0)
    if hx is None:
        hx = input.new_zeros((self.num_layers, batch_size, self.hidden_size))
        hx = (hx, hx)
    func = rnn_f.autograd_var_masked_step(num_layers=self.num_layers, lstm=True)
    output, hidden = func(input, self.all_cells, hx, mask)
    return output, hidden
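# Illustrative stepwise unrolling (a sketch): `rnn` is assumed to be an
# instance of this unidirectional module, `seq` a [batch, length, input_size]
# input and `mask` a [batch, length] mask; none of these names come from the
# original code.
import torch

def unroll(rnn, seq: torch.Tensor, mask: torch.Tensor):
    # Feed one timestep at a time through `step`, threading the hidden state.
    hx, outputs = None, []
    for t in range(seq.size(1)):
        out, hx = rnn.step(seq[:, t], hx=hx, mask=mask[:, t])
        outputs.append(out)
    return torch.stack(outputs, dim=1), hx   # [batch, length, hidden_size]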
def decode_nest(self, energy: Tensor) -> Tensor:
    """
    Args:
        energy: Tensor
            the energy tensor with shape = [length, num_label, num_label]

    Returns: Tensor
        the decoded nested label sequence, with shape [length]
    """
    # The last row and column hold the tag for the pad symbol; reduce both
    # dimensions by 1 to remove it, and also remove the first #symbolic rows
    # and columns. energy_transpose then has shape [length, t, t] where
    # t = num_labels - #symbolic - 1.
    energy_transpose = energy[:, :self.index_bos, :self.index_bos]
    length, num_label, _ = energy_transpose.size()
    num_label_3 = self.indices_is.size(0)
    # Per-position nested label sets: B-tags first, I-tags inside, E-tags last.
    indices_3 = energy.new_empty((length, num_label_3)).long()
    indices_3[0, :] = self.indices_bs
    if length > 2:
        indices_3[1:length - 1, :] = self.indices_is.repeat((length - 2, 1))
    indices_3[length - 1, :] = self.indices_es
    pointer_1 = energy.new_zeros((length, num_label)).long()
    pointer_3 = energy.new_zeros((length, num_label)).long()
    back_pointer = pointer_3.new_zeros(length)
    pi_1 = energy[0, self.index_bos, :self.index_bos]
    pi_3 = energy.new_full((num_label_3, ), -1e4)
    pointer_1[0] = self.index_bos
    pointer_3[0] = self.index_bos
    for t in range(1, length):
        e_t = energy_transpose[t]
        pi = pi_1.clone()
        pi[indices_3[t - 1]] = pi_3
        pi_1, pointer_1[t] = torch.max(e_t + pi_1.unsqueeze(1), dim=0)
        pi_3, pointer_3[t, indices_3[t]] = torch.max(e_t[:, indices_3[t]] + pi.unsqueeze(1), dim=0)
    # Final transition into the EOS tag.
    t = length - 1
    e_t = self.trans_matrix.data[:self.index_bos, self.index_eos]
    pi = e_t + pi_1
    pi[indices_3[t]] = e_t[indices_3[t]] + pi_3
    _, back_pointer[-1] = torch.max(pi, dim=0)
    # Follow the nested (pointer_3) backpointers while the chosen label stays
    # inside the nested label set, then fall back to the flat (pointer_1) chain.
    t = length - 2
    while t > -1:
        if (indices_3[t + 1] == back_pointer[t + 1]).nonzero().numel() == 0:
            break
        pointer_last = pointer_3[t + 1]
        back_pointer[t] = pointer_last[back_pointer[t + 1]]
        t -= 1
    while t > -1:
        pointer_last = pointer_1[t + 1]
        back_pointer[t] = pointer_last[back_pointer[t + 1]]
        t -= 1
    return back_pointer
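# Illustrative batch decoding sketch (`crf` is a hypothetical instance of this
# CRF module): its forward produces energies of shape
# [batch, length, num_label, num_label], while `decode_nest` expects a single
# sequence, so decoding iterates over the batch dimension.
import torch

def decode_batch(crf, input: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
    energy = crf(input, mask=mask)
    # Decode each sequence independently; result has shape [batch, length].
    return torch.stack([crf.decode_nest(energy[b])
                        for b in range(energy.size(0))], dim=0)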