예제 #1
0
    def _build_loss(self, results, features, labels, loss_config, **kwargs):
        """Build the reconstruction loss and add a latent penalty to it.

        The reconstruction term comes from ``getters.get_loss`` using
        ``loss_config.module`` and ``loss_config.params``, with ``features``
        passed as the target. A latent term computed from
        ``kwargs['z_mean']`` and ``kwargs['z_log_sigma']`` is then added to
        both returned values.

        Args:
            results: model outputs forwarded to the loss getter.
            features: forwarded to the loss getter as the target.
            labels: not used by this method.
            loss_config: object exposing ``module`` and a ``params`` mapping.
            **kwargs: must contain ``'z_mean'`` and ``'z_log_sigma'``.

        Returns:
            tuple ``(losses, loss)`` from the getter, each with the latent
            term added.

        Raises:
            KeyError: if ``'z_mean'`` or ``'z_log_sigma'`` is absent.
        """
        recon_losses, recon_loss = getters.get_loss(
            loss_config.module, results, features, **loss_config.params)

        with get_name_scope('latent_loss'):
            mean = kwargs['z_mean']
            log_sigma = kwargs['z_log_sigma']

            # Matches the closed-form KL(q || N(0, I)) if z_log_sigma is a
            # log-variance -- TODO confirm with the encoder that produces it.
            # NOTE(review): reduce_sum is called without an axis, so the
            # result is a single scalar. If z_mean/z_log_sigma carry a batch
            # dimension, the term is summed over the batch and scales with
            # batch size. compute_weighted_loss on that scalar (default
            # weights) effectively returns it unchanged. Confirm whether a
            # per-sample reduction was intended.
            kl_terms = 1 + log_sigma - tf.square(mean) - tf.exp(log_sigma)
            latent_losses = -0.5 * tf.reduce_sum(kl_terms)
            latent_loss = tf.losses.compute_weighted_loss(latent_losses)

        recon_losses += latent_losses
        recon_loss += latent_loss
        return recon_losses, recon_loss
예제 #2
0
    def _build_loss(self, results, features, labels):
        """Create the loss operation.

        The base loss is looked up through ``getters.get_loss``, using
        ``self.loss.IDENTIFIER`` as the key, ``self.loss.to_dict()`` as
        keyword parameters and ``labels`` as the target. Its two outputs are
        cached on ``self._losses`` / ``self._loss``. Any tracked
        regularization losses are then summed in.

        Returns:
             tuple `(losses, loss)`:
                `losses` are the per-batch losses.
                `loss` is a single scalar tensor to minimize. When
                regularization losses are tracked, it is the ``tf.add_n``
                of the base loss and those losses (op name "TotalLoss").
        """
        per_batch, base_loss = getters.get_loss(
            self.loss.IDENTIFIER, results, labels, **self.loss.to_dict())
        self._loss = base_loss
        self._losses = per_batch

        reg_losses = get_tracked(tf.GraphKeys.REGULARIZATION_LOSSES)
        if not reg_losses:
            # NOTE(review): self._total_loss is not assigned on this path;
            # it keeps whatever value (if any) it already had.
            return per_batch, base_loss

        total = tf.add_n([base_loss] + reg_losses, name="TotalLoss")
        self._total_loss = total
        return per_batch, total
예제 #3
0
    def _build_loss(self, results, features, labels):
        """Create the loss operation.

        The base loss is looked up through ``getters.get_loss``, using
        ``self.loss_config.module`` as the key, ``self.loss_config.params``
        as keyword parameters and ``labels`` as the target. Its two outputs
        are cached on ``self._losses`` / ``self._loss``. Any tracked
        regularization losses are then summed in.

        Returns:
             tuple `(losses, loss)`:
                `losses` are the per-batch losses.
                `loss` is a single scalar tensor to minimize. When
                regularization losses are tracked, it is the ``tf.add_n``
                of the base loss and those losses (op name "TotalLoss").
        """
        per_batch, base_loss = getters.get_loss(
            self.loss_config.module, results, labels, **self.loss_config.params)
        self._loss = base_loss
        self._losses = per_batch

        reg_losses = get_tracked(tf.GraphKeys.REGULARIZATION_LOSSES)
        if not reg_losses:
            # NOTE(review): self._total_loss is not assigned on this path;
            # it keeps whatever value (if any) it already had.
            return per_batch, base_loss

        total = tf.add_n([base_loss] + reg_losses, name="TotalLoss")
        self._total_loss = total
        return per_batch, total
예제 #4
0
 def _build_loss(self, results, features, labels, loss, **kwargs):
     """Delegate to ``getters.get_loss`` and return its result unchanged.

     ``loss.IDENTIFIER`` selects the loss and ``loss.to_dict()`` supplies its
     keyword parameters. ``features`` is passed as the target. ``labels``
     and ``kwargs`` are accepted for interface compatibility but unused.
     """
     identifier = loss.IDENTIFIER
     params = loss.to_dict()
     return getters.get_loss(identifier, results, features, **params)
예제 #5
0
 def _build_loss(self, results, features, labels, loss_config, **kwargs):
     """Delegate to ``getters.get_loss`` and return its result unchanged.

     ``loss_config.module`` selects the loss and ``loss_config.params``
     supplies its keyword parameters. ``features`` is passed as the target.
     ``labels`` and ``kwargs`` are accepted but unused.
     """
     module = loss_config.module
     params = loss_config.params
     return getters.get_loss(module, results, features, **params)
예제 #6
0
 def _build_loss(self, incoming, results, loss_config, **kwargs):
     """Delegate to ``getters.get_loss`` and return its result unchanged.

     ``loss_config.module`` selects the loss and ``loss_config.params``
     supplies its keyword parameters. ``incoming`` is passed as the target
     that ``results`` is compared against. ``kwargs`` is accepted but unused.
     """
     module = loss_config.module
     params = loss_config.params
     return getters.get_loss(module, results, incoming, **params)
예제 #7
0
 def _build_loss(self, results, features, labels, loss_config, **kwargs):
     """Delegate to ``getters.get_loss`` and return its result unchanged.

     The loss is chosen by ``loss_config.module`` and configured by the
     ``loss_config.params`` mapping. ``features`` serves as the target.
     ``labels`` and ``kwargs`` are accepted for interface compatibility only.
     """
     module = loss_config.module
     params = loss_config.params
     return getters.get_loss(module, results, features, **params)