Example #1
# Assumed context (not part of the original excerpt): TF1-style graph code
# built on tensor2tensor-style helpers. `ops` and `gops` refer to
# project-local helper modules (sequence/length utilities and glow-style
# tensor ops), and `get_anneal_mask`/`lenpred_stats` are assumed to be
# defined alongside this function; their import paths are not reproduced.
import tensorflow.compat.v1 as tf

from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers


def save_log_loss(hparams, targets_mask, numerator, denominator, log_q_z,
                  log_abs_det, log_p_z_base, z_q, lenpred_loss,
                  targets_length_pred, targets_length):
    """Populate loss dictionary and summary."""
    anneal, kl_mask = get_anneal_mask(hparams)
    lenpred_acc, lenpred_acc5 = lenpred_stats(targets_length_pred,
                                              targets_length)
    batch_length = tf.reduce_sum(targets_mask)

    z_q_norm = gops.reduce_mean_over_bl(tf.norm(z_q, axis=2, keepdims=True),
                                        targets_mask)[0]

    log_q_z = gops.reduce_mean_over_bl_sum_over_c(log_q_z, targets_mask)
    log_p_z_base = tf.reduce_sum(log_p_z_base, axis=0) / batch_length
    log_abs_det = tf.reduce_sum(log_abs_det, axis=0) / batch_length
    log_p_z_reg = gops.standard_normal_density(z_q,
                                               targets_mask,
                                               reduce_sum=True)

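    # Single-sample ELBO pieces, all averaged per non-padding target token:
    # log_p_x is the reconstruction term (negative averaged cross-entropy),
    # the flow prior log p(z) is the base density plus the Jacobian
    # log-determinant (change of variables), and kl_reg regularizes the
    # learned prior toward a standard normal.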
    log_p_x = -1 * numerator / denominator
    log_p_z = log_p_z_base + log_abs_det
    kl = log_q_z - log_p_z
    kl_reg = log_p_z - log_p_z_reg
    elbo = log_p_x - kl
    monitor = {
        "elbo": elbo,
        "kl": kl,
        "kl_reg": kl_reg,
        "log_p_x": log_p_x,
        "log_q_z": log_q_z,
        "log_p_z": log_p_z,
        "log_p_z_base": log_p_z_base,
        "log_abs_det": log_abs_det,
        "anneal": anneal,
        "z_q_norm": z_q_norm,
        "lenpred_acc": lenpred_acc,
        "lenpred_acc5": lenpred_acc5,
    }

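    # Apply the KL annealing schedule (and mask) only to the optimized
    # losses; the values recorded in `monitor` above stay un-annealed.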
    kl = kl * anneal
    kl_reg = hparams.kl_reg * kl_reg * anneal
    loss_dict = {
        "training": -1 * log_p_x,
        "kl": kl * kl_mask,
        "kl_reg": kl_reg * kl_mask,
    }
    if lenpred_loss is not None:
        monitor["lenpred_loss"] = lenpred_loss
        loss_dict["lenpred_loss"] = lenpred_loss
    return loss_dict, monitor
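

# Note: the `internal` method below is excerpted from the surrounding model
# class (hence the `self` parameter and the extra indentation); it is not
# nested inside save_log_loss. Within that class, save_log_loss is reached
# through the helper module referenced as `ops` (see ops.save_log_loss).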
    def internal(self, features, real_features):
        """Main procedure for both training and inference."""
        inputs = common_layers.flatten4d3d(features["inputs"])
        targets = common_layers.flatten4d3d(features["targets"])
        target_space = features["target_space_id"]
        hparams = self._hparams
        inputs_mask = ops.embedding_to_non_padding(inputs)
        inputs_length = tf.reduce_sum(inputs_mask, axis=-1)

        encoder_output, encoder_decoder_attention_bias = (ops.encoder(
            "encoder", hparams, inputs, target_space))
        kwargs = {
            "encoder_output": encoder_output,
            "encoder_decoder_attention_bias": encoder_decoder_attention_bias
        }
        losses, monitor = {}, {}
        log_abs_det = tf.constant(0.0)

        if not self.is_predicting:
            # Training
            targets_mask = ops.embedding_to_non_padding(targets)
            targets_length = tf.reduce_sum(targets_mask, axis=-1)
            length_diff = targets_length - inputs_length
            decoder_self_attention_bias = (
                common_attention.attention_bias_ignore_padding(1.0 -
                                                               targets_mask))
            z_q, log_q_z, q_dist = self.sample_q(targets,
                                                 targets_mask,
                                                 decoder_self_attention_bias,
                                                 n_samples=1,
                                                 temp=1.0,
                                                 **kwargs)

            body_output = ops.decoder("decoder", z_q, hparams,
                                      decoder_self_attention_bias, **kwargs)
            logits = self.top(body_output, real_features)
            numerator, denominator = self.loss(logits, real_features)

            if not (self.is_evaluating and (hparams.compute_kl_refinement
                                            or hparams.compute_iw_marginal)):
                targets_length_pred, lenpred_loss = ops.predict_target_lengths(
                    encoder_output, inputs_mask, hparams, length_diff)
                log_p_z_base, log_abs_det = self.compute_prior_log_prob(
                    z_q,
                    targets_mask,
                    decoder_self_attention_bias,
                    check_invertibility=False,
                    **kwargs)
                losses, monitor = ops.save_log_loss(
                    hparams, targets_mask, numerator, denominator, log_q_z,
                    log_abs_det, log_p_z_base, z_q, lenpred_loss,
                    targets_length_pred, targets_length)

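            # Evaluation-only diagnostics: either re-score delta-posterior
            # refined prior samples under q, or estimate the marginal
            # likelihood by importance weighting; both overwrite `losses`.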
            if self.is_evaluating:
                if hparams.compute_kl_refinement:
                    z_p, _ = self.sample_p(targets_length,
                                           temp=self._decode_hparams.temp,
                                           check_invertibility=False,
                                           targets_mask=targets_mask,
                                           **kwargs)
                    z_dq = self.delta_posterior(
                        z_p, targets_mask, decoder_self_attention_bias,
                        self._decode_hparams.n_gibbs_steps, **kwargs)
                    log_q_z_ = q_dist.log_prob(z_dq)
                    log_q_z_ = gops.reduce_mean_over_bl_sum_over_c(
                        log_q_z_, targets_mask)
                    losses = {"training": log_q_z_}

                if hparams.compute_iw_marginal:
                    log_p_y_x = self.compute_iw_marginal(
                        targets, targets_mask, decoder_self_attention_bias,
                        real_features, self._decode_hparams.n_samples,
                        **kwargs)
                    # (alternatively, with a single importance sample:
                    #  real_features, 1, **kwargs)
                    losses = {"training": log_p_y_x}

            return logits, losses, monitor, targets_mask

        else:
            # Inference
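            # Predict the target length from the encoder output, sample z
            # from the flow prior, refine it with n_gibbs_steps
            # delta-posterior updates, then run the decoder on the result.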
            targets_length, _ = ops.predict_target_lengths(
                encoder_output, inputs_mask, hparams)
            targets_mask = ops.sequence_mask(targets_length, hparams)
            decoder_self_attention_bias = (
                common_attention.attention_bias_ignore_padding(1.0 -
                                                               targets_mask))
            z_p, _ = self.sample_p(targets_length,
                                   temp=self._decode_hparams.temp,
                                   check_invertibility=False,
                                   **kwargs)
            z_q = self.delta_posterior(z_p, targets_mask,
                                       decoder_self_attention_bias,
                                       self._decode_hparams.n_gibbs_steps,
                                       **kwargs)
            # (alternatively, with 0 Gibbs steps: 0, **kwargs)

            body_output = ops.decoder("decoder", z_q, hparams,
                                      decoder_self_attention_bias, **kwargs)
            return body_output, losses, monitor, targets_mask
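

# ---------------------------------------------------------------------------
# Illustration (not part of the original excerpt): a minimal NumPy sketch of
# the masked reduction convention used above, where per-token log-densities
# are summed over channels, padding positions are masked out, and the total
# is averaged over the number of real target tokens. The helper name below is
# hypothetical; it only mirrors what gops.reduce_mean_over_bl_sum_over_c is
# assumed to compute.
import numpy as np


def masked_mean_over_tokens(x, mask):
    """Sum over channels, mask padding, average over non-padding tokens.

    x: [batch, length, channels] per-token, per-channel log-densities.
    mask: [batch, length] with 1.0 for real tokens and 0.0 for padding.
    """
    per_token = x.sum(axis=-1)          # sum over the channel axis
    total = (per_token * mask).sum()    # drop contributions from padding
    return total / mask.sum()           # average per non-padding token


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((2, 5, 8))
    mask = np.array([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]], dtype=np.float64)
    print(masked_mean_over_tokens(x, mask))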