def decode(self, context, states, sampling=False):
    """Run the decoder and return the resulting states.

    Two execution paths exist:

    * One-shot: when step-wise training is off and ``sampling`` is false,
      the feedback tokens are embedded at once and ``decode_step`` is called
      a single time with ``full_sequence=True``.
    * Step-wise: otherwise the decoder is unrolled one step at a time on a
      fresh copy of ``states`` per step. Without sampling, each step is fed
      column ``t`` of ``context.feedbacks`` / ``context.feedback_embeds``.
      With sampling, step 0 is fed column 0 of ``context.feedbacks`` and
      every later step is fed the argmax of ``self.expand(states)``.

    Args:
        context: Object exposing ``feedbacks`` and (for the step-wise
            non-sampling path) ``feedback_embeds``; both are indexed as
            ``[:, t]`` (presumably batch-major tensors; confirm with caller).
        states: Decoder state container; must provide ``copy()`` on the
            step-wise path.
        sampling: If true, feed back the model's own greedy predictions.

    Returns:
        ``TensorMap(states)`` on the one-shot path, otherwise whatever
        ``self.combine_states`` builds from the list of per-step states.
    """
    if not (self._stepwise_training or sampling):
        # One-shot path: the whole feedback sequence is decoded in one call.
        states.feedback_embed = self.lookup_feedback(context.feedbacks)
        self.decode_step(context, states, full_sequence=True)
        return TensorMap(states)

    seq_len = context.feedbacks.shape[1]
    if sampling:
        # Hard-coded budget: sampling runs 9 steps beyond the feedback length.
        n_steps = seq_len + 9
    else:
        n_steps = seq_len - 1

    collected = []
    for step in range(n_steps):
        # Each step works on its own copy so earlier entries stay untouched.
        states = states.copy()
        states.t = step
        if not sampling:
            states.prev_token = context.feedbacks[:, step].unsqueeze(0)
            states.feedback_embed = context.feedback_embeds[:, step].unsqueeze(0)
        else:
            if step == 0:
                token = context.feedbacks[:, 0].unsqueeze(0)
            else:
                # Greedy choice computed from the freshly copied states.
                token = self.expand(states).argmax(-1)
            states.prev_token = token
            states.feedback_embed = self.lookup_feedback(token.squeeze(0))
        self.decode_step(context, states)
        collected.append(states)
    return self.combine_states(collected)
def forward(self, x, y, sampling=False, return_code=False):
    """Compute training scores for a source/target batch and back-propagate.

    The method builds a prior over latents from ``x``, samples latents from
    the approximate posterior built from ``x`` and ``y``, resizes the latents
    to the target length, decodes, and accumulates all scores. Gradients are
    propagated here (via ``torch.autograd.backward``) rather than by the
    caller.

    Args:
        x: Source token ids; positions equal to 0 are masked out.
        y: Target token ids; positions equal to 0 are masked out.
        sampling: Unused in this method.
        return_code: Unused in this method.

    Returns:
        The score map returned by ``self.compute_final_loss``, optionally
        extended with a negated ``"BLEU"`` entry.

    Raises:
        SystemError: If ``self._shard_size`` is ``None`` or not positive.
    """
    src_mask = self.to_float(torch.ne(x, 0))
    tgt_mask = self.to_float(torch.ne(y, 0))
    scores = {}

    # ---- Prior p(z|x) ----
    prior_states = self.prior_encoder(x, src_mask)
    if OPTS.zeroprior:
        prior_prob = self.standard_gaussian_dist(x.shape[0], x.shape[1])
    else:
        prior_prob = self.prior_prob_estimator(prior_states)

    # ---- Approximate posterior q(z|x,y) and a sample from it ----
    src_embed = self.x_embed_layer(x)
    q_states = self.compute_Q_states(src_embed, src_mask, y, tgt_mask)
    latent_mask = src_mask  # the latents share the source mask
    latent, q_prob = self.sample_from_Q(q_states)

    # ---- Length prediction loss, then resize latents to target length ----
    scores.update(
        self.compute_length_predictor_loss(
            prior_states, latent, latent_mask, tgt_mask))
    resized_latent = self.convert_length(latent, latent_mask, tgt_mask.sum(-1))

    # ---- Decoder ----
    decoder_states = self.decoder(
        resized_latent, tgt_mask, prior_states, src_mask)
    decoder_outputs = TensorMap({"final_states": decoder_states})

    # ---- Sharded reconstruction loss (sharding is mandatory) ----
    sharded = self._shard_size is not None and self._shard_size > 0
    if not sharded:
        raise SystemError("Shard size must be setted or the memory is not enough for this model.")
    loss_scores, decoder_tensors, decoder_grads = self.compute_shard_loss(
        decoder_outputs,
        y,
        tgt_mask,
        denominator=x.shape[0],
        ignore_first_token=False,
        backward=False,
    )
    # Rescale word accuracy by (first mask dimension / number of unmasked
    # target positions).
    unmasked_count = self.to_float(tgt_mask.sum())
    loss_scores["word_acc"] *= float(tgt_mask.shape[0]) / unmasked_count
    scores.update(loss_scores)

    scores, remain_loss = self.compute_final_loss(
        q_prob, prior_prob, latent_mask, scores)

    # ---- BLEU report: only with gradients disabled and BLEU criteria ----
    if not torch.is_grad_enabled() and self.training_criteria == "BLEU":
        predictions = self.expander_nn(
            decoder_outputs["final_states"]).argmax(-1)
        score_map_bleu = self.get_BLEU(predictions, y)
        scores["BLEU"] = -score_map_bleu

    # ---- Back-propagate shard tensors together with the remaining loss ----
    if sharded and decoder_tensors is not None:
        decoder_tensors.append(remain_loss)
        # No explicit gradient is supplied for the remaining loss.
        decoder_grads.append(None)
        torch.autograd.backward(decoder_tensors, decoder_grads)
    return scores
def forward(self, x, y, sampling=False, return_code=False):
    """Auto-encode ``y`` through a latent bottleneck and back-propagate.

    Only ``y`` is used: the posterior is computed from ``y`` alone and the
    prior comes from ``self.standard_gaussian_dist``; ``x`` is accepted for
    interface compatibility but never read. Gradients are propagated inside
    this method via ``torch.autograd.backward``.

    Args:
        x: Unused in this method.
        y: Token ids to reconstruct; positions equal to 0 are masked out.
        sampling: Unused in this method.
        return_code: Unused in this method.

    Returns:
        The score map returned by ``self.compute_final_loss``.

    Raises:
        SystemError: If ``self._shard_size`` is ``None`` or not positive.
    """
    score_map = {}
    seq = y
    mask = self.to_float(torch.ne(seq, 0))

    # ----------- Compute prior and approximated posterior -------------#
    # Prior built from the first two dimensions of ``seq`` (no dependence on x).
    prior_prob = self.standard_gaussian_dist(seq.shape[0], seq.shape[1])
    # Posterior states computed from the sequence itself.
    q_states = self.compute_Q_states(seq, mask)
    # Sample latent variables from the posterior.
    sampled_z, q_prob = self.sample_from_Q(q_states)

    # -------------------------- Decoder -------------------------------#
    # torch.tanh replaces the legacy F.tanh alias, which emitted a
    # deprecation warning in PyTorch 1.x; the computed values are the same.
    sampled_z = torch.tanh(sampled_z)
    decoder_states = self.decoder(sampled_z, mask)

    # -------------------------- Compute losses ------------------------#
    decoder_outputs = TensorMap({"final_states": decoder_states})
    denom = seq.shape[0]
    if self._shard_size is not None and self._shard_size > 0:
        loss_scores, decoder_tensors, decoder_grads = self.compute_shard_loss(
            decoder_outputs,
            seq,
            mask,
            denominator=denom,
            ignore_first_token=False,
            backward=False,
        )
        # Rescale word accuracy by (first mask dimension / number of
        # unmasked positions).
        loss_scores["word_acc"] *= float(mask.shape[0]) / self.to_float(
            mask.sum())
        score_map.update(loss_scores)
    else:
        raise SystemError(
            "Shard size must be setted or the memory is not enough for this model."
        )

    score_map, remain_loss = self.compute_final_loss(
        q_prob, prior_prob, mask, score_map)

    # -------------------------- Backprop gradient ---------------------#
    # The shard-size condition is already guaranteed by the raise above, so
    # only the presence of shard tensors needs checking here.
    if decoder_tensors is not None:
        decoder_tensors.append(remain_loss)
        # No explicit gradient is supplied for the remaining loss.
        decoder_grads.append(None)
        torch.autograd.backward(decoder_tensors, decoder_grads)
    return score_map
def forward(self, x, y, sampling=False, return_code=False):
    """Compute training scores for a source/target batch and back-propagate.

    The source is embedded and encoded once; the resulting states are shared
    by the prior, the posterior and the decoder. A latent is drawn from the
    posterior (its mean when ``self.training`` is false, a reparameterized
    sample otherwise), projected by ``self.lat2hid`` and decoded. Gradients
    are propagated inside this method via ``torch.autograd.backward``.

    Args:
        x: Source token ids; positions equal to 0 are masked out.
        y: Target token ids; positions equal to 0 are masked out.
        sampling: Unused in this method.
        return_code: Unused in this method.

    Returns:
        The score map returned by ``self.compute_final_loss``, extended with
        a negated ``"BLEU"`` entry when gradients are disabled and the module
        is not in training mode.

    Raises:
        SystemError: If ``self._shard_size`` is ``None`` or not positive.
        FloatingPointError: If the final ``"loss"`` score is NaN or infinite.
    """
    score_map = {}
    x_mask = self.to_float(torch.ne(x, 0))
    y_mask = self.to_float(torch.ne(y, 0))

    # Source sentence hidden states, shared between prior, posterior, decoder.
    x_states = self.embed_layer(x)
    x_states = self.x_encoder(x_states, x_mask)

    # Compute p(z|x)
    p_prob = self.compute_prior(y_mask, x_states, x_mask)
    # Compute q(z|x,y)
    q_prob = self.compute_posterior(y, y_mask, x_states, x_mask)
    # The last axis packs the mean (first ``latent_dim`` entries) and the
    # pre-softplus standard deviation (remaining entries).
    q_mean = q_prob[..., :self.latent_dim]
    q_stddev = F.softplus(q_prob[..., self.latent_dim:])
    if not self.training:
        # Deterministic latent outside training mode.
        z_q = q_mean
    else:
        # Reparameterized sample: mean + stddev * standard normal noise.
        z_q = q_mean + q_stddev * torch.randn_like(q_stddev)

    # Compute length prediction loss
    length_scores = self.compute_length_predictor_loss(
        x_states, x_mask, y_mask)
    score_map.update(length_scores)

    # -------------------------- Decoder -------------------------------#
    hid_q = self.lat2hid(z_q)
    decoder_states = self.decoder(hid_q, y_mask, x_states, x_mask)

    # -------------------------- Compute losses ------------------------#
    decoder_outputs = TensorMap({"final_states": decoder_states})
    denom = x.shape[0]
    if self._shard_size is not None and self._shard_size > 0:
        loss_scores, decoder_tensors, decoder_grads = self.compute_shard_loss(
            decoder_outputs,
            y,
            y_mask,
            denominator=denom,
            ignore_first_token=False,
            backward=False,
        )
        # Rescale word accuracy by (first mask dimension / number of
        # unmasked target positions).
        loss_scores["word_acc"] *= float(y_mask.shape[0]) / self.to_float(
            y_mask.sum())
        score_map.update(loss_scores)
    else:
        raise SystemError(
            "Shard size must be setted or the memory is not enough for this model."
        )

    score_map, remain_loss = self.compute_final_loss(
        q_prob, p_prob, y_mask, score_map)

    # Report smoothed BLEU during validation
    if not torch.is_grad_enabled() and not self.training:
        # NOTE(review): presumably restores the full batch after shard
        # selection inside compute_shard_loss; confirm against TensorMap.
        decoder_outputs.unselect_batch()
        logits = self.expander_nn(decoder_outputs["final_states"])
        predictions = logits.argmax(-1)
        # Masked positions are zeroed in both hypothesis and reference.
        score_map["BLEU"] = -self.get_BLEU(
            (predictions * y_mask).long(), (y * y_mask).long())

    # -------------------------- Backprop gradient ---------------------#
    # The shard-size condition is already guaranteed by the raise above, so
    # only the presence of shard tensors needs checking here.
    if decoder_tensors is not None:
        decoder_tensors.append(remain_loss)
        # No explicit gradient is supplied for the remaining loss.
        decoder_grads.append(None)
        torch.autograd.backward(decoder_tensors, decoder_grads)

    # Fail loudly on a diverged loss. This replaces an interactive pdb
    # breakpoint, which blocks or aborts with BdbQuit in unattended runs.
    if torch.isnan(score_map["loss"]) or torch.isinf(score_map["loss"]):
        raise FloatingPointError(
            "Non-finite loss encountered in forward(): {}".format(
                score_map["loss"]))
    return score_map