def forward(
    self,
    source: torch.Tensor,       # (b, max_sou_seq_len)
    source_mask: torch.Tensor,  # (b, max_sou_seq_len)
    target: torch.Tensor,       # (b, max_tar_seq_len)
    target_mask: torch.Tensor,  # (b, max_tar_seq_len)
    label: torch.Tensor,        # (b, max_tar_seq_len)
    annealing: float
) -> Tuple[torch.Tensor, Tuple]:  # loss, loss details
    b = source.size(0)
    source_embedded = self.source_embed(source, source_mask)  # (b, max_sou_seq_len, d_s_emb)
    e_out, (hidden, _) = self.encoder(source_embedded, source_mask)
    h = self.transform(hidden, True)  # (n_e_lay * b, d_e_hid * n_dir)
    z_mu = self.z_mu(h)               # (n_e_lay * b, d_e_hid * n_dir)
    z_ln_var = self.z_ln_var(h)       # (n_e_lay * b, d_e_hid * n_dir)
    hidden = Gaussian(z_mu, z_ln_var).rsample()  # reparameterization trick
    # (n_e_lay * b, d_e_hid * n_dir) -> (b, d_e_hid * n_dir); cell state initialized to zeros
    states = (self.transform(hidden, False),
              self.transform(hidden.new_zeros(hidden.size()), False))

    max_tar_seq_len = target.size(1)
    output = source_embedded.new_zeros((b, max_tar_seq_len, self.target_vocab_size))
    target_embedded = self.target_embed(target, target_mask)  # (b, max_tar_seq_len, d_t_emb)
    target_embedded = target_embedded.transpose(1, 0)         # (max_tar_seq_len, b, d_t_emb)
    total_context_loss = 0

    # decode per word
    for i in range(max_tar_seq_len):
        d_out, states = self.decoder(target_embedded[i], target_mask[:, i], states)
        if self.attention:
            context, cs = self.calculate_context_vector(e_out, states[0], source_mask, True)  # (b, d_d_hid)
            total_context_loss += self.calculate_context_loss(cs)
            d_out = torch.cat((d_out, context), dim=-1)  # (b, d_d_hid * 2)
        output[:, i, :] = self.w(self.maxout(d_out))  # (b, d_d_hid) -> (b, d_out) -> (b, tar_vocab_size)

    loss, details = self.calculate_loss(output, target_mask, label,
                                        z_mu, z_ln_var, total_context_loss, annealing)
    if torch.isnan(loss).any():
        raise ValueError('nan detected')
    return loss, details
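# --- Sketch: the `Gaussian` helper used above -------------------------------
# `Gaussian` is defined elsewhere in this repo; the version below is a minimal
# sketch of what it presumably looks like, assuming a diagonal Gaussian
# parameterized by a mean and a log-variance (it relies on the module-level
# `import torch` this file already needs). rsample() is the reparameterization
# trick, z = mu + exp(0.5 * ln_var) * eps with eps ~ N(0, I), which keeps the
# draw differentiable w.r.t. z_mu and z_ln_var during training.
class Gaussian:
    def __init__(self, mu: torch.Tensor, ln_var: torch.Tensor):
        self.mu = mu          # posterior mean
        self.ln_var = ln_var  # posterior log-variance, log(sigma^2)

    def rsample(self) -> torch.Tensor:
        # differentiable draw via reparameterization (used in forward)
        eps = torch.randn_like(self.mu)
        return self.mu + torch.exp(0.5 * self.ln_var) * eps

    def sample(self) -> torch.Tensor:
        # gradient-free draw (used in predict)
        with torch.no_grad():
            return self.rsample()

    def kl(self) -> torch.Tensor:
        # closed-form KL(N(mu, sigma^2) || N(0, I)); presumably the term that
        # calculate_loss scales by `annealing` -- an assumption, not confirmed here
        return -0.5 * torch.sum(1 + self.ln_var - self.mu ** 2 - torch.exp(self.ln_var))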
def predict(
    self,
    source: torch.Tensor,       # (b, max_sou_seq_len)
    source_mask: torch.Tensor,  # (b, max_sou_seq_len)
    sampling: bool = True
) -> torch.Tensor:  # (b, max_seq_len, 1)
    self.eval()
    with torch.no_grad():
        b = source.size(0)
        source_embedded = self.source_embed(source, source_mask)  # (b, max_sou_seq_len, d_s_emb)
        e_out, (hidden, _) = self.encoder(source_embedded, source_mask)
        h = self.transform(hidden, True)
        z_mu = self.z_mu(h)
        z_ln_var = self.z_ln_var(h)
        # sample z from the posterior, or use its mean for deterministic decoding
        hidden = Gaussian(z_mu, z_ln_var).sample() if sampling else z_mu
        states = (self.transform(hidden, False),
                  self.transform(hidden.new_zeros(hidden.size()), False))

        target_id = torch.full((b, 1), BOS, dtype=source.dtype).to(source.device)
        target_mask = torch.full((b, 1), 1, dtype=source_mask.dtype).to(source_mask.device)
        predictions = source_embedded.new_zeros(b, self.max_seq_len, 1)

        for i in range(self.max_seq_len):
            target_embedded = self.target_embed(target_id, target_mask).squeeze(1)  # (b, d_t_emb)
            d_out, states = self.decoder(target_embedded, target_mask[:, 0], states)
            if self.attention:
                context, _ = self.calculate_context_vector(e_out, states[0], source_mask, False)
                d_out = torch.cat((d_out, context), dim=-1)
            output = self.w(self.maxout(d_out))  # (b, tar_vocab_size)
            output[:, UNK] -= 1e6  # mask <UNK>
            if i == 0:
                output[:, EOS] -= 1e6  # avoid zero-length output
            prediction = torch.argmax(F.softmax(output, dim=1),
                                      dim=1).unsqueeze(1)  # (b, 1), greedy
            # zero the mask once <EOS> is emitted so finished sequences stay finished
            target_mask = target_mask * prediction.ne(EOS).type(target_mask.dtype)
            target_id = prediction
            predictions[:, i, :] = prediction
        return predictions
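# --- Hypothetical usage of predict() ----------------------------------------
# Illustrative only: `Seq2SeqVAE`, the token ids, and the PAD id below are
# assumptions, not part of this file. The point is the expected shapes: a
# padded id batch plus a 0/1 mask in, greedy token ids of shape
# (b, max_seq_len, 1) out.
#
#     model = Seq2SeqVAE(...)                        # hypothetical constructor
#     source = torch.tensor([[BOS, 17, 4, EOS, 0]])  # (b=1, len=5), 0 = <PAD> (assumed)
#     source_mask = source.ne(0).long()              # 1 on real tokens, 0 on padding
#     predictions = model.predict(source, source_mask, sampling=False)  # decode from z = z_mu
#     # ids emitted after the first <EOS> should be discarded when detokenizing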