def update_state(
    self,
    y_true: tf.Tensor,
    y_logits: tf.Tensor,
    sample_weight: Optional[tf.Tensor] = None,
) -> None:
    """Accumulate one batch of predictions into the parent metric.

    Args:
        y_true: Ground-truth tensor; its shape is the padding target for
            the predictions.
        y_logits: Raw model outputs, converted to predictions via
            ``logits2pred`` using ``self.activation``.
        sample_weight: Optional per-sample weights, forwarded unchanged
            to the parent metric's ``update_state``.
    """
    # logits -> predictions, then zero-pad up to y_true's shape before
    # delegating the actual metric bookkeeping to the superclass.
    predictions = pad2shape(
        logits2pred(y_logits, self.activation), tf.shape(y_true), value=0
    )
    super().update_state(y_true, predictions, sample_weight=sample_weight)
# --- Beispiel #2 (scrape artifact: example-listing separator, vote count "0") ---
    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """Per-sample multi-label loss from raw scores.

        Computes ``logsumexp`` over the negative-class scores plus
        ``logsumexp`` over the (sign-flipped) positive-class scores, each
        with an implicit extra 0 term, i.e. ``log(1 + sum(exp(.)))`` per
        branch, summed.

        Args:
            y_true: Binary label tensor (0/1), cast to ``y_pred``'s dtype.
            y_pred: Raw scores; may be a ``tf.RaggedTensor``.

        Returns:
            A rank-1 tensor of losses, one per flattened sample.
        """
        # Densify ragged scores, filling gaps with dtype.min so the filled
        # slots act as "very negative" scores that cannot dominate.
        if isinstance(y_pred, tf.RaggedTensor):
            y_pred = y_pred.to_tensor(y_pred.dtype.min)
        y_true = tf.cast(y_true, y_pred.dtype)
        y_pred = pad2shape(y_pred, tf.shape(y_true))

        # Collapse every axis before self.flatten_axis into one sample axis;
        # everything after it becomes the class axis.
        n_samples = tf.reduce_prod(tf.shape(y_true)[:self.flatten_axis])
        y_true = tf.reshape(y_true, (n_samples, -1))
        y_pred = tf.reshape(y_pred, (n_samples, -1))

        # Flip the sign of positive-class scores, then mask each branch by
        # pushing the other class's entries toward -dtype.max so they vanish
        # inside the exp.
        scores = (1 - 2 * y_true) * y_pred
        masked_neg: tf.Tensor = scores - y_true * scores.dtype.max
        masked_pos: tf.Tensor = scores - (1 - y_true) * scores.dtype.max

        # A zero column appended before logsumexp supplies the implicit "1"
        # term: logsumexp([s, 0]) == log(1 + sum(exp(s))).
        one_col = tf.zeros_like(scores[..., :1])
        loss_neg = tf.math.reduce_logsumexp(
            tf.concat([masked_neg, one_col], axis=-1), axis=-1)
        loss_pos = tf.math.reduce_logsumexp(
            tf.concat([masked_pos, one_col], axis=-1), axis=-1)
        return loss_neg + loss_pos