def argmax_decode(self, z, decoder_self_attention_bias, **kwargs):
  """Greedy decode: run the decoder on z and take the argmax over logits."""
  hparams = self._hparams
  body_output = ops.decoder("decoder", z, hparams,
                            decoder_self_attention_bias, **kwargs)
  logits = self.top(body_output, {"targets": None})
  targets = tf.argmax(logits, axis=-1)
  targets_emb = self.bottom({"targets": targets})["targets"][..., 0, :]
  return targets, targets_emb
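# A minimal, self-contained illustration of the greedy step above: take the
# argmax over the vocabulary axis of the logits, then look the resulting ids
# back up in an embedding table. `embedding_table` is a stand-in for the
# model's target modality (self.bottom), not an identifier from this codebase.
def _greedy_ids_and_embeddings(logits, embedding_table):
  ids = tf.argmax(logits, axis=-1)              # [B, L], one id per position
  embeddings = tf.gather(embedding_table, ids)  # [B, L, C]
  return ids, embeddings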
def compute_iw_marginal(self, targets, targets_mask,
                        decoder_self_attention_bias, features, n_samples,
                        reduce_mean=True, **kwargs):
  """Importance-weighted estimate of the per-token log marginal likelihood."""
  hparams = self._hparams
  z_q, log_q_z, _ = self.sample_q(
      targets, targets_mask, decoder_self_attention_bias,
      n_samples=n_samples, temp=1.0, **kwargs)  # [K*B, L, C]

  # Tile everything K times so the K posterior samples share one batch.
  iw_kwargs = {key: ops.prepare_for_iw(value, n_samples)
               for key, value in kwargs.items()}
  iw_targets_mask = ops.prepare_for_iw(targets_mask, n_samples)
  iw_decoder_self_attention_bias = (
      common_attention.attention_bias_ignore_padding(1.0 - iw_targets_mask))
  iw_features = copy.copy(features)
  iw_features["targets"] = ops.prepare_for_iw(features["targets"], n_samples)

  # Prior density of z_q via the flow's change of variables.
  log_p_z_base, log_abs_det = self.compute_prior_log_prob(
      z_q, iw_targets_mask, iw_decoder_self_attention_bias,
      check_invertibility=False, **iw_kwargs)
  log_p_z = log_p_z_base + log_abs_det

  body_output = ops.decoder("decoder", z_q, hparams,
                            iw_decoder_self_attention_bias, **iw_kwargs)
  logits = self.top(body_output, iw_features)
  numerator, denominator = self.loss_iw(logits, iw_features)

  # Normalize each term per token so sequences of different lengths are
  # comparable.
  numerator = tf.reduce_sum(numerator[..., 0, 0], 1)      # [K*B]
  denominator = tf.reduce_sum(denominator[..., 0, 0], 1)  # [K*B]
  log_p_x = -1 * numerator / denominator
  log_q_z = gops.reduce_mean_over_l_sum_over_c(log_q_z, iw_targets_mask)
  log_p_z = log_p_z / tf.reduce_sum(iw_targets_mask, 1)

  # Reshape [K*B] -> [K, B] and combine the K samples.
  log_p_x, log_q_z, log_p_z = [
      ops.unprepare_for_iw(ii, n_samples)
      for ii in [log_p_x, log_q_z, log_p_z]]
  log_w_n = log_p_z - log_q_z
  log_w_n = tf.nn.log_softmax(log_w_n, axis=0)  # [K, B]
  iw_marginal = log_p_x + log_w_n
  iw_marginal = tf.reduce_logsumexp(iw_marginal, 0)  # [B]
  if reduce_mean:
    iw_marginal = tf.cast(tf.reduce_mean(iw_marginal, 0), tf.float32)  # scalar
  else:
    iw_marginal = tf.cast(iw_marginal, tf.float32)  # [B]
  return iw_marginal
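# compute_iw_marginal implements an IWAE-style estimator: with K posterior
# samples z_1..z_K, it forms importance weights from the prior/posterior
# ratio, normalizes them across samples, and log-sum-exps the weighted
# reconstruction terms. A minimal sketch of just that reduction, assuming the
# per-sample, per-token-normalized log densities are already computed (the
# helper name and its inputs are illustrative, not part of this codebase):
def _iw_logsumexp_sketch(log_p_x_given_z, log_p_z, log_q_z):
  """All inputs have shape [K, B]: K posterior samples per batch element."""
  log_w = tf.nn.log_softmax(log_p_z - log_q_z, axis=0)  # normalize over K
  return tf.reduce_logsumexp(log_p_x_given_z + log_w, axis=0)  # [B]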
def internal(self, features, real_features):
  """Main procedure for both training and inference."""
  inputs = common_layers.flatten4d3d(features["inputs"])
  targets = common_layers.flatten4d3d(features["targets"])
  target_space = features["target_space_id"]
  hparams = self._hparams

  inputs_mask = ops.embedding_to_non_padding(inputs)
  inputs_length = tf.reduce_sum(inputs_mask, axis=-1)
  encoder_output, encoder_decoder_attention_bias = ops.encoder(
      "encoder", hparams, inputs, target_space)
  kwargs = {
      "encoder_output": encoder_output,
      "encoder_decoder_attention_bias": encoder_decoder_attention_bias,
  }
  losses, monitor = {}, {}
  log_abs_det = tf.constant(0.0)

  if not self.is_predicting:  # Training / evaluation.
    targets_mask = ops.embedding_to_non_padding(targets)
    targets_length = tf.reduce_sum(targets_mask, axis=-1)
    length_diff = targets_length - inputs_length
    decoder_self_attention_bias = (
        common_attention.attention_bias_ignore_padding(1.0 - targets_mask))

    z_q, log_q_z, q_dist = self.sample_q(
        targets, targets_mask, decoder_self_attention_bias,
        n_samples=1, temp=1.0, **kwargs)
    body_output = ops.decoder("decoder", z_q, hparams,
                              decoder_self_attention_bias, **kwargs)
    logits = self.top(body_output, real_features)
    numerator, denominator = self.loss(logits, real_features)

    # Skip the full training loss when evaluation only needs the KL
    # refinement or the importance-weighted marginal.
    if not (self.is_evaluating and
            (hparams.compute_kl_refinement or hparams.compute_iw_marginal)):
      targets_length_pred, lenpred_loss = ops.predict_target_lengths(
          encoder_output, inputs_mask, hparams, length_diff)
      log_p_z_base, log_abs_det = self.compute_prior_log_prob(
          z_q, targets_mask, decoder_self_attention_bias,
          check_invertibility=False, **kwargs)
      losses, monitor = ops.save_log_loss(
          hparams, targets_mask, numerator, denominator, log_q_z,
          log_abs_det, log_p_z_base, z_q, lenpred_loss, targets_length_pred,
          targets_length)

    if self.is_evaluating:
      if hparams.compute_kl_refinement:
        z_p, _ = self.sample_p(
            targets_length, temp=self._decode_hparams.temp,
            check_invertibility=False, targets_mask=targets_mask, **kwargs)
        z_dq = self.delta_posterior(
            z_p, targets_mask, decoder_self_attention_bias,
            self._decode_hparams.n_gibbs_steps, **kwargs)
        log_q_z_ = q_dist.log_prob(z_dq)
        log_q_z_ = gops.reduce_mean_over_bl_sum_over_c(log_q_z_, targets_mask)
        losses = {"training": log_q_z_}
      if hparams.compute_iw_marginal:
        log_p_y_x = self.compute_iw_marginal(
            targets, targets_mask, decoder_self_attention_bias,
            real_features, self._decode_hparams.n_samples, **kwargs)
        losses = {"training": log_p_y_x}
    return logits, losses, monitor, targets_mask

  else:  # Inference: predict length, sample the prior, then refine.
    targets_length, _ = ops.predict_target_lengths(
        encoder_output, inputs_mask, hparams)
    targets_mask = ops.sequence_mask(targets_length, hparams)
    decoder_self_attention_bias = (
        common_attention.attention_bias_ignore_padding(1.0 - targets_mask))

    z_p, _ = self.sample_p(
        targets_length, temp=self._decode_hparams.temp,
        check_invertibility=False, **kwargs)
    z_q = self.delta_posterior(
        z_p, targets_mask, decoder_self_attention_bias,
        self._decode_hparams.n_gibbs_steps, **kwargs)
    body_output = ops.decoder("decoder", z_q, hparams,
                              decoder_self_attention_bias, **kwargs)
    return body_output, losses, monitor, targets_mask
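# Schematically, the training objective assembled by ops.save_log_loss above
# is a flow-prior ELBO plus the length-prediction loss: reconstruction minus
# a Monte Carlo KL term, with the prior density evaluated through the
# normalizing flow's change of variables. A sketch under those assumptions
# (per-sequence scalars; this is not the exact implementation):
def _negative_elbo_sketch(log_p_x_given_z, log_q_z, log_p_z_base,
                          log_abs_det, lenpred_loss):
  log_p_z = log_p_z_base + log_abs_det  # flow prior via change of variables
  kl = log_q_z - log_p_z                # single-sample KL estimate
  return -log_p_x_given_z + kl + lenpred_loss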