from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

# Embedder, RNNWrapper, CNNPooler, Encoder, Decoder, Maxout, Gaussian and the
# BOS / EOS / UNK special-token ids used below are project-local helpers; their
# imports are assumed to live alongside these model definitions.


class CNN(nn.Module):
    def __init__(self,
                 d_emb: int,
                 embeddings: Union[torch.Tensor, int],
                 kernel_widths: List[int],
                 n_class: int,
                 dropout_rate: float = 0.333,
                 n_filter: int = 128) -> None:
        super(CNN, self).__init__()
        self.vocab_size = embeddings if isinstance(embeddings, int) else embeddings.size(0)
        self.embed = Embedder(self.vocab_size, d_emb)
        if not isinstance(embeddings, int):
            self.embed.set_initial_embedding(embeddings, freeze=False)
        assert len(kernel_widths) > 1, 'kernel_widths needs at least two elements'
        n_kernel = len(kernel_widths)
        self.poolers = nn.ModuleList([
            CNNPooler(d_emb=d_emb, kernel_width=kernel_widths[i])
            for i in range(n_kernel)
        ])
        # highway architecture
        self.sigmoid = nn.Sigmoid()
        self.transform_gate = nn.Linear(n_filter * n_kernel, n_filter * n_kernel)
        self.highway = nn.Linear(n_filter * n_kernel, n_filter * n_kernel)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fc = nn.Linear(n_filter * n_kernel, n_class)
        self.params = {
            "DropoutRate": dropout_rate,
            "KernelWidths": kernel_widths,
            "NFilter": n_filter,
            "VocabSize": self.vocab_size
        }

    def forward(
            self,
            x: torch.Tensor,     # (b, max_seq_len) token ids
            mask: torch.Tensor,  # (b, max_seq_len)
    ) -> torch.Tensor:
        embedded = self.embed(x, mask)
        embedded = embedded.unsqueeze(1)  # (b, 1, max_seq_len, d_emb)
        pooled = self.poolers[0](embedded, mask)
        for pooler in self.poolers[1:]:
            pooled = torch.cat((pooled, pooler(embedded, mask)), dim=1)
        pooled = pooled.squeeze(-1)  # (b, n_filter * n_kernel)
        t = self.sigmoid(self.transform_gate(pooled))             # (b, n_filter * n_kernel)
        hw = t * F.relu(self.highway(pooled)) + (1 - t) * pooled  # (b, n_filter * n_kernel)
        y = self.fc(self.dropout(hw))                             # (b, n_class)
        return y
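
# Minimal usage sketch for CNN (not part of the original module). It assumes the
# project-local Embedder and CNNPooler behave as used above (token ids plus a 0/1
# mask in, embeddings / pooled feature maps out); the vocabulary size, batch shape,
# kernel widths, and class count below are arbitrary placeholders.
def _demo_cnn() -> torch.Tensor:
    vocab_size = 1000
    model = CNN(d_emb=50, embeddings=vocab_size, kernel_widths=[3, 4, 5], n_class=2)
    x = torch.randint(0, vocab_size, (4, 30))   # (b, max_seq_len) token ids
    mask = torch.ones(4, 30, dtype=torch.long)  # 1 = real token, 0 = padding
    return model(x, mask)                       # (b, n_class) logits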
class SelfAttentionLSTM(nn.Module):
    def __init__(self,
                 d_emb: int,
                 d_hid: int,
                 embeddings: Union[torch.Tensor, int],
                 n_class: int,
                 bi_directional: bool = True,
                 dropout_rate: float = 0.333,
                 n_layer: int = 1) -> None:
        super(SelfAttentionLSTM, self).__init__()
        self.vocab_size = embeddings if isinstance(embeddings, int) else embeddings.size(0)
        self.embed = Embedder(self.vocab_size, d_emb)
        if not isinstance(embeddings, int):
            self.embed.set_initial_embedding(embeddings, freeze=True)
        self.rnn = RNNWrapper(
            nn.LSTM(input_size=d_emb,
                    hidden_size=d_hid,
                    num_layers=n_layer,
                    batch_first=True,
                    dropout=dropout_rate,
                    bidirectional=bi_directional))
        self.attention = nn.Linear(d_hid * 2, 1)
        self.w_1 = nn.Linear(d_hid * 2, d_hid)
        self.tanh = nn.Tanh()
        self.w_2 = nn.Linear(d_hid, n_class)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.params = {
            "BiDirectional": bi_directional,
            "DropoutRate": dropout_rate,
            "NLayer": n_layer,
            "VocabSize": self.vocab_size
        }

    def forward(
            self,
            x: torch.Tensor,     # (b, max_seq_len) token ids
            mask: torch.Tensor,  # (b, max_seq_len)
    ) -> torch.Tensor:           # (b, n_class)
        embedded = self.embed(x, mask)
        rnn_out = self.rnn(embedded, mask)  # (b, max_seq_len, d_hid * 2)
        alignment_weights = self.calculate_alignment_weights(rnn_out, mask)  # (b, max_seq_len, 1)
        out = (alignment_weights * rnn_out).sum(dim=1)  # (b, d_hid * 2)
        h = self.tanh(self.w_1(out))   # (b, d_hid)
        y = self.w_2(self.dropout(h))  # (b, n_class)
        return y

    def calculate_alignment_weights(
            self,
            rnn_out: torch.Tensor,  # (b, max_seq_len, d_hid * 2)
            mask: torch.Tensor      # (b, max_seq_len)
    ) -> torch.Tensor:
        max_len = rnn_out.size(1)
        alignment_weights = self.attention(rnn_out)  # (b, max_seq_len, 1)
        alignment_weights_mask = mask.unsqueeze(-1).type(alignment_weights.dtype)
        alignment_weights.masked_fill_(alignment_weights_mask[:, :max_len, :].ne(1), -1e6)
        return F.softmax(alignment_weights, dim=1)  # (b, max_seq_len, 1)
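
# Minimal usage sketch for SelfAttentionLSTM (not part of the original module). It
# assumes Embedder and RNNWrapper follow the interfaces used above; the hyper-parameters
# and batch shape are arbitrary placeholders. Passing an int as `embeddings` trains the
# embedding table from scratch instead of loading pre-trained vectors.
def _demo_self_attention_lstm() -> torch.Tensor:
    vocab_size = 1000
    model = SelfAttentionLSTM(d_emb=50, d_hid=64, embeddings=vocab_size, n_class=3)
    x = torch.randint(0, vocab_size, (4, 20))   # (b, max_seq_len) token ids
    mask = torch.ones(4, 20, dtype=torch.long)  # 1 = real token, 0 = padding
    return model(x, mask)                       # (b, n_class) logits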
class VariationalSeq2seq(nn.Module):
    def __init__(self,
                 d_e_hid: int,
                 max_seq_len: int,
                 source_embeddings: torch.Tensor,  # TODO: Optional embeddings
                 target_embeddings: torch.Tensor,
                 attention: bool = True,
                 bi_directional: bool = True,
                 dropout_rate: float = 0.333,
                 freeze: bool = False,
                 n_e_layer: int = 2,
                 n_d_layer: int = 1) -> None:
        super(VariationalSeq2seq, self).__init__()
        self.max_seq_len = max_seq_len
        self.source_vocab_size, self.d_s_emb = source_embeddings.size()
        self.target_vocab_size, self.d_t_emb = target_embeddings.size()
        self.source_embed = Embedder(self.source_vocab_size, self.d_s_emb)
        self.source_embed.set_initial_embedding(source_embeddings, freeze=freeze)
        self.target_embed = Embedder(self.target_vocab_size, self.d_t_emb)
        self.target_embed.set_initial_embedding(target_embeddings, freeze=freeze)
        self.n_e_lay = n_e_layer
        self.bi_directional = bi_directional
        self.n_dir = 2 if bi_directional else 1
        self.encoder = Encoder(rnn=nn.LSTM(input_size=self.d_s_emb,
                                           hidden_size=d_e_hid,
                                           num_layers=self.n_e_lay,
                                           batch_first=True,
                                           dropout=dropout_rate,
                                           bidirectional=bi_directional))
        self.z_mu = nn.Linear(d_e_hid * self.n_dir, d_e_hid * self.n_dir)
        self.z_ln_var = nn.Linear(d_e_hid * self.n_dir, d_e_hid * self.n_dir)
        self.attention = attention
        self.d_d_hid = d_e_hid * self.n_dir
        self.n_d_lay = n_d_layer
        assert self.d_d_hid % self.n_dir == 0, 'invalid d_e_hid'
        self.d_c_hid = self.d_d_hid if attention else 0
        self.d_out = (self.d_d_hid + self.d_c_hid) // self.n_dir
        self.decoder = Decoder(rnn=nn.LSTMCell(input_size=self.d_t_emb, hidden_size=self.d_d_hid))
        self.c_tanh = nn.Tanh()
        self.c_linear = nn.Linear(self.d_c_hid, self.d_c_hid)
        self.maxout = Maxout(self.d_d_hid + self.d_c_hid, self.d_out, self.n_dir)
        self.w = nn.Linear(self.d_out, self.target_vocab_size)

    def forward(
            self,
            source: torch.Tensor,       # (b, max_sou_seq_len)
            source_mask: torch.Tensor,  # (b, max_sou_seq_len)
            target: torch.Tensor,       # (b, max_tar_seq_len)
            target_mask: torch.Tensor,  # (b, max_tar_seq_len)
            label: torch.Tensor,        # (b, max_tar_seq_len)
            annealing: float
    ) -> Tuple[torch.Tensor, Tuple]:    # scalar loss and its components
        b = source.size(0)
        source_embedded = self.source_embed(source, source_mask)  # (b, max_sou_seq_len, d_s_emb)
        e_out, (hidden, _) = self.encoder(source_embedded, source_mask)
        h = self.transform(hidden, True)  # (n_e_lay * b, d_e_hid * n_dir)
        z_mu = self.z_mu(h)               # (n_e_lay * b, d_e_hid * n_dir)
        z_ln_var = self.z_ln_var(h)       # (n_e_lay * b, d_e_hid * n_dir)
        hidden = Gaussian(z_mu, z_ln_var).rsample()  # reparameterization trick
        # (n_e_lay * b, d_e_hid * n_dir) -> (b, d_e_hid * n_dir), initialize cell state
        states = (self.transform(hidden, False),
                  self.transform(hidden.new_zeros(hidden.size()), False))
        max_tar_seq_len = target.size(1)
        output = source_embedded.new_zeros((b, max_tar_seq_len, self.target_vocab_size))
        target_embedded = self.target_embed(target, target_mask)  # (b, max_tar_seq_len, d_t_emb)
        target_embedded = target_embedded.transpose(1, 0)         # (max_tar_seq_len, b, d_t_emb)
        total_context_loss = 0
        # decode per word
        for i in range(max_tar_seq_len):
            d_out, states = self.decoder(target_embedded[i], target_mask[:, i], states)
            if self.attention:
                context, cs = self.calculate_context_vector(
                    e_out, states[0], source_mask, True)  # (b, d_d_hid)
                total_context_loss += self.calculate_context_loss(cs)
                d_out = torch.cat((d_out, context), dim=-1)  # (b, d_d_hid * 2)
            output[:, i, :] = self.w(self.maxout(d_out))  # (b, d_d_hid) -> (b, d_out) -> (b, tar_vocab_size)
        loss, details = self.calculate_loss(output, target_mask, label,
                                            z_mu, z_ln_var, total_context_loss, annealing)
        if torch.isnan(loss).any():
            raise ValueError('nan detected')
        return loss, details

    def predict(
            self,
            source: torch.Tensor,       # (b, max_sou_seq_len)
            source_mask: torch.Tensor,  # (b, max_sou_seq_len)
            sampling: bool = True
    ) -> torch.Tensor:                  # (b, max_seq_len, 1)
        self.eval()
        with torch.no_grad():
            b = source.size(0)
            source_embedded = self.source_embed(source, source_mask)  # (b, max_seq_len, d_s_emb)
            e_out, (hidden, _) = self.encoder(source_embedded, source_mask)
            h = self.transform(hidden, True)
            z_mu = self.z_mu(h)
            z_ln_var = self.z_ln_var(h)
            hidden = Gaussian(z_mu, z_ln_var).sample() if sampling else z_mu
            states = (self.transform(hidden, False),
                      self.transform(hidden.new_zeros(hidden.size()), False))
            target_id = torch.full((b, 1), BOS, dtype=source.dtype).to(source.device)
            target_mask = torch.full((b, 1), 1, dtype=source_mask.dtype).to(source_mask.device)
            predictions = source_embedded.new_zeros(b, self.max_seq_len, 1)
            for i in range(self.max_seq_len):
                target_embedded = self.target_embed(target_id, target_mask).squeeze(1)  # (b, d_t_emb)
                d_out, states = self.decoder(target_embedded, target_mask[:, 0], states)
                if self.attention:
                    context, _ = self.calculate_context_vector(e_out, states[0], source_mask, False)
                    d_out = torch.cat((d_out, context), dim=-1)
                output = self.w(self.maxout(d_out))  # (b, tar_vocab_size)
                output[:, UNK] -= 1e6  # mask <UNK>
                if i == 0:
                    output[:, EOS] -= 1e6  # avoid zero-length output
                prediction = torch.argmax(F.softmax(output, dim=1), dim=1).unsqueeze(1)  # (b, 1), greedy
                target_mask = target_mask * prediction.ne(EOS).type(target_mask.dtype)
                target_id = prediction
                predictions[:, i, :] = prediction
        return predictions

    def transform(
            self,
            state: torch.Tensor,  # (n_e_lay * n_dir, b, d_e_hid) or (n_e_lay * b, d_e_hid * n_dir)
            switch: bool
    ) -> torch.Tensor:
        if switch:
            b = state.size(1)
            state = state.contiguous().view(self.n_e_lay, self.n_dir, b, -1)
            state = state.permute(0, 2, 3, 1)                       # (n_e_lay, b, d_e_hid, n_dir)
            state = state.contiguous().view(self.n_e_lay * b, -1)   # (n_e_lay * b, d_e_hid * n_dir)
        else:
            b = state.size(0) // self.n_e_lay
            state = state.contiguous().view(self.n_e_lay, b, -1)
            # extract hidden layer
            state = state[0]
        return state

    # variational soft attention, score is calculated by dot product
    def calculate_context_vector(
            self,
            encoder_hidden_states: torch.Tensor,          # (b, max_sou_seq_len, d_e_hid * n_dir)
            previous_decoder_hidden_state: torch.Tensor,  # (b, d_d_hid)
            source_mask: torch.Tensor,                    # (b, max_sou_seq_len)
            is_training: bool
    ) -> Tuple[torch.Tensor, Tuple]:
        b, max_sou_seq_len, d_d_hid = encoder_hidden_states.size()
        # (b, max_sou_seq_len, d_d_hid)
        previous_decoder_hidden_states = previous_decoder_hidden_state.unsqueeze(1).expand(
            b, max_sou_seq_len, d_d_hid)
        alignment_weights = (encoder_hidden_states * previous_decoder_hidden_states).sum(dim=-1)
        alignment_weights.masked_fill_(source_mask.ne(1), -1e6)
        alignment_weights = F.softmax(alignment_weights, dim=-1).unsqueeze(-1)  # (b, max_sou_seq_len, 1)
        context_vector = (alignment_weights * encoder_hidden_states).sum(dim=1)  # (b, d_d_hid)
        c_mu = context_vector
        c_ln_var = (self.c_linear(self.c_tanh(context_vector))).exp()
        context_vector = Gaussian(c_mu, c_ln_var).rsample() if is_training \
            else Gaussian(c_mu, c_ln_var).sample()
        return context_vector, (c_mu, c_ln_var)

    @staticmethod
    def calculate_context_loss(cs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        c_mu, c_ln_var = cs
        b = c_mu.size(0)
        kl_divergence = (c_mu ** 2 + c_ln_var.exp() - c_ln_var - 1) * 0.5
        context_loss = kl_divergence.sum() / b
        return context_loss

    @staticmethod
    def calculate_loss(
            output: torch.Tensor,       # (b, max_tar_len, vocab_size)
            target_mask: torch.Tensor,  # (b, max_tar_len)
            label: torch.Tensor,        # (b, max_tar_len)
            mu: torch.Tensor,           # (n_e_lay * b, d_e_hid * n_dir)
            ln_var: torch.Tensor,       # (n_e_lay * b, d_e_hid * n_dir)
            total_context_loss: torch.Tensor,
            annealing: float,
            gamma: float = 10,
    ) -> Tuple[torch.Tensor, Tuple]:
        b, max_tar_len, vocab_size = output.size()
        label = label.masked_select(target_mask.eq(1))
        prediction_mask = target_mask.unsqueeze(-1).expand(
            b, max_tar_len, vocab_size)  # (b, max_tar_len, vocab_size)
        prediction = output.masked_select(prediction_mask.eq(1)).contiguous().view(-1, vocab_size)
        reconstruction_loss = F.cross_entropy(prediction, label, reduction='none').sum() / b
        kl_divergence = (mu ** 2 + ln_var.exp() - ln_var - 1) * 0.5
        regularization_loss = kl_divergence.sum() / b
        return reconstruction_loss + annealing * (regularization_loss + gamma * total_context_loss), \
            (reconstruction_loss, regularization_loss, total_context_loss)
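
# Self-contained sketch (not part of the original module) of the closed-form Gaussian
# KL term used by calculate_context_loss and calculate_loss above:
# KL(N(mu, exp(ln_var)) || N(0, 1)) = 0.5 * (mu^2 + exp(ln_var) - ln_var - 1),
# summed over dimensions and averaged over the batch. The shapes below are arbitrary.
def _demo_context_kl() -> torch.Tensor:
    c_mu = torch.randn(4, 16)      # (b, d) posterior means
    c_ln_var = torch.randn(4, 16)  # (b, d) posterior log variances
    return VariationalSeq2seq.calculate_context_loss((c_mu, c_ln_var))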
class Seq2seq(nn.Module):
    def __init__(self,
                 d_e_hid: int,
                 max_seq_len: int,
                 source_embeddings: torch.Tensor,  # TODO: Optional embeddings
                 target_embeddings: torch.Tensor,
                 attention: bool = True,
                 bi_directional: bool = True,
                 dropout_rate: float = 0.333,
                 freeze: bool = False,
                 n_e_layer: int = 2,
                 n_d_layer: int = 1) -> None:
        super(Seq2seq, self).__init__()
        self.max_seq_len = max_seq_len
        self.source_vocab_size, self.d_s_emb = source_embeddings.size()
        self.target_vocab_size, self.d_t_emb = target_embeddings.size()
        self.source_embed = Embedder(self.source_vocab_size, self.d_s_emb)
        self.source_embed.set_initial_embedding(source_embeddings, freeze=freeze)
        self.target_embed = Embedder(self.target_vocab_size, self.d_t_emb)
        self.target_embed.set_initial_embedding(target_embeddings, freeze=freeze)
        self.n_e_lay = n_e_layer
        self.bi_directional = bi_directional
        self.n_dir = 2 if bi_directional else 1
        self.encoder = Encoder(rnn=nn.LSTM(input_size=self.d_s_emb,
                                           hidden_size=d_e_hid,
                                           num_layers=self.n_e_lay,
                                           batch_first=True,
                                           dropout=dropout_rate,
                                           bidirectional=bi_directional))
        self.attention = attention
        self.d_d_hid = d_e_hid * self.n_dir
        self.n_d_lay = n_d_layer
        assert self.d_d_hid % self.n_dir == 0, 'invalid d_e_hid'
        self.d_c_hid = self.d_d_hid if attention else 0
        self.d_out = (self.d_d_hid + self.d_c_hid) // self.n_dir
        self.decoder = Decoder(rnn=nn.LSTMCell(input_size=self.d_t_emb, hidden_size=self.d_d_hid))
        self.maxout = Maxout(self.d_d_hid + self.d_c_hid, self.d_out, self.n_dir)
        self.w = nn.Linear(self.d_out, self.target_vocab_size)

    def forward(
            self,
            source: torch.Tensor,       # (b, max_sou_seq_len)
            source_mask: torch.Tensor,  # (b, max_sou_seq_len)
            target: torch.Tensor,       # (b, max_tar_seq_len)
            target_mask: torch.Tensor,  # (b, max_tar_seq_len)
            label: torch.Tensor         # (b, max_tar_seq_len)
    ) -> torch.Tensor:                  # scalar loss
        b = source.size(0)
        source_embedded = self.source_embed(source, source_mask)  # (b, max_sou_seq_len, d_s_emb)
        e_out, (hidden, cell) = self.encoder(source_embedded, source_mask)
        if self.attention:
            states = None
        else:
            # (n_e_lay * n_dir, b, d_e_hid) -> (b, d_e_hid * n_dir), initialize cell state
            states = (self.transform(hidden), self.transform(cell.new_zeros(cell.size())))
        max_tar_seq_len = target.size(1)
        output = source_embedded.new_zeros((b, max_tar_seq_len, self.target_vocab_size))
        target_embedded = self.target_embed(target, target_mask)  # (b, max_tar_seq_len, d_t_emb)
        target_embedded = target_embedded.transpose(1, 0)         # (max_tar_seq_len, b, d_t_emb)
        # decode per word
        for i in range(max_tar_seq_len):
            d_out, states = self.decoder(target_embedded[i], target_mask[:, i], states)
            if self.attention:
                context = self.calculate_context_vector(e_out, states[0], source_mask)  # (b, d_d_hid)
                d_out = torch.cat((d_out, context), dim=-1)  # (b, d_d_hid * 2)
            output[:, i, :] = self.w(self.maxout(d_out))  # (b, d_d_hid) -> (b, d_out) -> (b, tar_vocab_size)
        loss = self.calculate_loss(output, target_mask, label)
        return loss

    def predict(
            self,
            source: torch.Tensor,       # (b, max_sou_seq_len)
            source_mask: torch.Tensor,  # (b, max_sou_seq_len)
    ) -> torch.Tensor:                  # (b, max_seq_len, 1)
        self.eval()
        with torch.no_grad():
            b = source.size(0)
            source_embedded = self.source_embed(source, source_mask)  # (b, max_seq_len, d_s_emb)
            e_out, (hidden, cell) = self.encoder(source_embedded, source_mask)
            if self.attention:
                states = None
            else:
                # (n_e_lay * n_dir, b, d_e_hid) -> (b, d_e_hid * n_dir), initialize cell state
                states = (self.transform(hidden), self.transform(cell.new_zeros(cell.size())))
            target_id = torch.full((b, 1), BOS, dtype=source.dtype).to(source.device)
            target_mask = torch.full((b, 1), 1, dtype=source_mask.dtype).to(source_mask.device)
            predictions = source_embedded.new_zeros(b, self.max_seq_len, 1)
            for i in range(self.max_seq_len):
                target_embedded = self.target_embed(target_id, target_mask).squeeze(1)  # (b, d_t_emb)
                d_out, states = self.decoder(target_embedded, target_mask[:, 0], states)
                if self.attention:
                    context = self.calculate_context_vector(e_out, states[0], source_mask)  # (b, d_d_hid)
                    d_out = torch.cat((d_out, context), dim=-1)  # (b, d_d_hid * 2)
                output = self.w(self.maxout(d_out))  # (b, tar_vocab_size)
                output[:, UNK] -= 1e6  # mask <UNK>
                if i == 0:
                    output[:, EOS] -= 1e6  # avoid zero-length output
                prediction = torch.argmax(F.softmax(output, dim=1), dim=1).unsqueeze(1)  # (b, 1), greedy
                target_mask = target_mask * prediction.ne(EOS).type(target_mask.dtype)
                target_id = prediction
                predictions[:, i, :] = prediction
        return predictions

    def transform(
            self,
            state: torch.Tensor  # (n_e_lay * n_dir, b, d_e_hid)
    ) -> torch.Tensor:
        b = state.size(1)
        state = state.contiguous().view(self.n_e_lay, self.n_dir, b, -1)
        state = state.permute(0, 2, 3, 1)                     # (n_e_lay, b, d_e_hid, n_dir)
        state = state.contiguous().view(self.n_e_lay, b, -1)  # (n_e_lay, b, d_e_hid * n_dir)
        # extract hidden layer
        state = state[0]  # (b, d_e_hid * n_dir)
        return state

    # soft attention, score is calculated by dot product
    @staticmethod
    def calculate_context_vector(
            encoder_hidden_states: torch.Tensor,          # (b, max_sou_seq_len, d_e_hid * n_dir)
            previous_decoder_hidden_state: torch.Tensor,  # (b, d_d_hid)
            source_mask: torch.Tensor                     # (b, max_sou_seq_len)
    ) -> torch.Tensor:
        b, max_sou_seq_len, d_d_hid = encoder_hidden_states.size()
        # (b, max_sou_seq_len, d_d_hid)
        previous_decoder_hidden_states = previous_decoder_hidden_state.unsqueeze(1).expand(
            b, max_sou_seq_len, d_d_hid)
        alignment_weights = (encoder_hidden_states * previous_decoder_hidden_states).sum(dim=-1)
        alignment_weights.masked_fill_(source_mask.ne(1), -1e6)
        alignment_weights = F.softmax(alignment_weights, dim=-1).unsqueeze(-1)  # (b, max_sou_seq_len, 1)
        context_vector = (alignment_weights * encoder_hidden_states).sum(dim=1)  # (b, d_d_hid)
        return context_vector

    @staticmethod
    def calculate_loss(
            output: torch.Tensor,       # (b, max_tar_len, vocab_size)
            target_mask: torch.Tensor,  # (b, max_tar_len)
            label: torch.Tensor,        # (b, max_tar_len)
    ) -> torch.Tensor:
        b, max_tar_len, vocab_size = output.size()
        label = label.masked_select(target_mask.eq(1))
        prediction_mask = target_mask.unsqueeze(-1).expand(
            b, max_tar_len, vocab_size)  # (b, max_tar_len, vocab_size)
        prediction = output.masked_select(prediction_mask.eq(1)).contiguous().view(-1, vocab_size)
        loss = F.cross_entropy(prediction, label, reduction='none').sum() / b
        return loss
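
# Self-contained sketch (not part of the original module) of the dot-product soft
# attention implemented by Seq2seq.calculate_context_vector above, run on random
# tensors; batch size, source length, and hidden size are arbitrary.
def _demo_dot_product_attention() -> torch.Tensor:
    b, max_sou_seq_len, d_d_hid = 2, 5, 8
    encoder_hidden_states = torch.randn(b, max_sou_seq_len, d_d_hid)
    previous_decoder_hidden_state = torch.randn(b, d_d_hid)
    source_mask = torch.ones(b, max_sou_seq_len, dtype=torch.long)
    source_mask[1, 3:] = 0  # pretend the second source sentence is shorter
    return Seq2seq.calculate_context_vector(encoder_hidden_states,
                                            previous_decoder_hidden_state,
                                            source_mask)  # (b, d_d_hid)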