  def argmax_decode(self, z, decoder_self_attention_bias, **kwargs):
    """Decodes the latent sequence z into target ids and their embeddings."""
    hparams = self._hparams
    body_output = ops.decoder("decoder", z, hparams,
                              decoder_self_attention_bias, **kwargs)
    logits = self.top(body_output, {"targets": None})
    targets = tf.argmax(logits, axis=-1)
    targets_emb = self.bottom({"targets": targets})["targets"][..., 0, :]
    return targets, targets_emb

  def compute_iw_marginal(
      self, targets, targets_mask, decoder_self_attention_bias, features,
      n_samples, reduce_mean=True, **kwargs):
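    """Estimates the importance-weighted marginal log-likelihood."""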
    hparams = self._hparams
    z_q, log_q_z, _ = self.sample_q(
        targets, targets_mask, decoder_self_attention_bias,
        n_samples=n_samples, temp=1.0, **kwargs)  # [K*B, L, C]
    iw_kwargs = {
        key: ops.prepare_for_iw(value, n_samples)
        for key, value in kwargs.items()}
    iw_targets_mask = ops.prepare_for_iw(targets_mask, n_samples)
    iw_decoder_self_attention_bias = (
        common_attention.attention_bias_ignore_padding(1.0 - iw_targets_mask))
    iw_features = copy.copy(features)
    iw_features["targets"] = ops.prepare_for_iw(
        features["targets"], n_samples)

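    # Prior log-density of the sampled latents under the flow prior:
    # base log-probability plus the change-of-variables log |det| term.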
    log_p_z_base, log_abs_det = self.compute_prior_log_prob(
        z_q, iw_targets_mask, iw_decoder_self_attention_bias,
        check_invertibility=False, **iw_kwargs)
    log_p_z = log_p_z_base + log_abs_det

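    # Decode from the sampled latents and score the tiled reference targets.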
    body_output = ops.decoder(
        "decoder", z_q, hparams, iw_decoder_self_attention_bias, **iw_kwargs)
    logits = self.top(body_output, iw_features)
    numerator, denominator = self.loss_iw(logits, iw_features)
    numerator = tf.reduce_sum(numerator[..., 0, 0], 1)  # [K*B]
    denominator = tf.reduce_sum(denominator[..., 0, 0], 1)  # [K*B]
    log_p_x = -numerator / denominator  # Per-token log-likelihood, [K*B].
    log_q_z = gops.reduce_mean_over_l_sum_over_c(log_q_z, iw_targets_mask)
    log_p_z = log_p_z / tf.reduce_sum(iw_targets_mask, 1)

    log_p_x, log_q_z, log_p_z = [
        ops.unprepare_for_iw(ii, n_samples)
        for ii in [log_p_x, log_q_z, log_p_z]]

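    # Self-normalized importance weights over the K samples:
    # log w_k = log p(z_k) - log q(z_k | x, y), normalized with log_softmax.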
    log_w_n = log_p_z - log_q_z
    log_w_n = tf.nn.log_softmax(log_w_n, axis=0)  # [K, B]

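    # Weight the per-sample likelihoods and log-sum-exp over the K samples.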
    iw_marginal = log_p_x + log_w_n
    iw_marginal = tf.reduce_logsumexp(iw_marginal, 0)  # [B]

    if reduce_mean:
      iw_marginal = tf.cast(tf.reduce_mean(iw_marginal, 0), tf.float32)  # []
    else:
      iw_marginal = tf.cast(iw_marginal, tf.float32)  # [B]
    return iw_marginal

  def internal(self, features, real_features):
    """Main procedure for both training and inference."""
    inputs = common_layers.flatten4d3d(features["inputs"])
    targets = common_layers.flatten4d3d(features["targets"])
    target_space = features["target_space_id"]
    hparams = self._hparams
    inputs_mask = ops.embedding_to_non_padding(inputs)
    inputs_length = tf.reduce_sum(inputs_mask, axis=-1)

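    # Encode the source once; the encoder outputs and attention bias are
    # shared with the posterior, prior, and decoder through kwargs.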
    encoder_output, encoder_decoder_attention_bias = ops.encoder(
        "encoder", hparams, inputs, target_space)
    kwargs = {
        "encoder_output": encoder_output,
        "encoder_decoder_attention_bias": encoder_decoder_attention_bias
    }
    losses, monitor = {}, {}
    log_abs_det = tf.constant(0.0)

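    # Training / evaluation: ground-truth targets are available, so the
    # latents are drawn from the approximate posterior q(z | x, y).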
    if not self.is_predicting:
      # Training
      targets_mask = ops.embedding_to_non_padding(targets)
      targets_length = tf.reduce_sum(targets_mask, axis=-1)
      length_diff = targets_length - inputs_length
      decoder_self_attention_bias = (
          common_attention.attention_bias_ignore_padding(1.0 - targets_mask))
      z_q, log_q_z, q_dist = self.sample_q(
          targets, targets_mask, decoder_self_attention_bias,
          n_samples=1, temp=1.0, **kwargs)

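      # Decode conditioned on the posterior sample and compute the
      # reconstruction loss against the real targets.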
      body_output = ops.decoder("decoder", z_q, hparams,
                                decoder_self_attention_bias, **kwargs)
      logits = self.top(body_output, real_features)
      numerator, denominator = self.loss(logits, real_features)

      if not (self.is_evaluating and (hparams.compute_kl_refinement
                                      or hparams.compute_iw_marginal)):
        targets_length_pred, lenpred_loss = ops.predict_target_lengths(
            encoder_output, inputs_mask, hparams, length_diff)
        log_p_z_base, log_abs_det = self.compute_prior_log_prob(
            z_q, targets_mask, decoder_self_attention_bias,
            check_invertibility=False, **kwargs)
        losses, monitor = ops.save_log_loss(
            hparams, targets_mask, numerator, denominator, log_q_z,
            log_abs_det, log_p_z_base, z_q, lenpred_loss,
            targets_length_pred, targets_length)

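      # Evaluation-only diagnostics, selected through hparams.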
      if self.is_evaluating:
        if hparams.compute_kl_refinement:
          z_p, _ = self.sample_p(targets_length,
                                 temp=self._decode_hparams.temp,
                                 check_invertibility=False,
                                 targets_mask=targets_mask,
                                 **kwargs)
          z_dq = self.delta_posterior(
              z_p, targets_mask, decoder_self_attention_bias,
              self._decode_hparams.n_gibbs_steps, **kwargs)
          log_q_z_ = q_dist.log_prob(z_dq)
          log_q_z_ = gops.reduce_mean_over_bl_sum_over_c(
              log_q_z_, targets_mask)
          losses = {"training": log_q_z_}

        if hparams.compute_iw_marginal:
          log_p_y_x = self.compute_iw_marginal(
              targets, targets_mask, decoder_self_attention_bias,
              real_features, self._decode_hparams.n_samples, **kwargs)
          losses = {"training": log_p_y_x}

      return logits, losses, monitor, targets_mask

    else:
      # Inference: predict target lengths, sample latents from the prior, and
      # refine them with Gibbs steps on the delta posterior.
      targets_length, _ = ops.predict_target_lengths(
          encoder_output, inputs_mask, hparams)
      targets_mask = ops.sequence_mask(targets_length, hparams)
      decoder_self_attention_bias = (
          common_attention.attention_bias_ignore_padding(1.0 - targets_mask))
      z_p, _ = self.sample_p(targets_length,
                             temp=self._decode_hparams.temp,
                             check_invertibility=False,
                             **kwargs)
      z_q = self.delta_posterior(z_p, targets_mask,
                                 decoder_self_attention_bias,
                                 self._decode_hparams.n_gibbs_steps,
                                 **kwargs)

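      # Decode from the refined latents and return the decoder output.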
      body_output = ops.decoder("decoder", z_q, hparams,
                                decoder_self_attention_bias, **kwargs)
      return body_output, losses, monitor, targets_mask