Example #1
0
def logits2pred(logits: tf.Tensor, activation: str) -> tf.Tensor:
    """Apply the named activation to logits, densifying ragged input first.

    When ``logits`` is a ``tf.RaggedTensor`` it is converted to a dense
    tensor before activation. For ``sigmoid``/``softmax`` the padding slots
    are filled with the dtype minimum so they activate to (near) zero;
    for any other activation the padding value is 0.
    """
    if isinstance(logits, tf.RaggedTensor):
        pad_value = logits.dtype.min if activation in {"sigmoid", "softmax"} else 0
        logits = logits.to_tensor(pad_value)
    activation_fn = tf.keras.activations.get(activation)
    return activation_fn(logits)
Example #2
0
    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """Compute a masked-logsumexp multi-label loss over logits.

        NOTE(review): this looks like the ZLPR / multi-label categorical
        cross-entropy, loss = log(1 + sum(exp(s_neg))) + log(1 + sum(exp(-s_pos)))
        per sample — confirm against the class docs.

        Args:
            y_true: binary labels; cast to ``y_pred``'s dtype below.
            y_pred: raw logits; a ``tf.RaggedTensor`` is densified first.

        Returns:
            Per-sample loss tensor of shape ``(sample_size,)``.
        """
        if isinstance(y_pred, tf.RaggedTensor):
            # Fill padding with the dtype minimum so padded slots behave as
            # strongly-negative logits after densification.
            y_pred = y_pred.to_tensor(y_pred.dtype.min)
        y_true = tf.cast(y_true, y_pred.dtype)
        # pad2shape (defined elsewhere in this module) pads y_pred up to
        # y_true's shape so the flatten/reshape below line up.
        y_pred = pad2shape(y_pred, tf.shape(y_true))

        # Collapse all axes before self.flatten_axis into one sample axis,
        # leaving a flat class axis per sample.
        sample_size = tf.reduce_prod(tf.shape(y_true)[:self.flatten_axis])
        y_true = tf.reshape(y_true, (sample_size, -1))
        y_pred = tf.reshape(y_pred, (sample_size, -1))
        # Flip the sign of positive-class logits: positives become -s, negatives stay +s.
        y_pred = (1 - 2 * y_true) * y_pred
        # Mask positives out of the negative branch (and vice versa) by pushing
        # them to ~ -dtype.max, so exp() of masked entries underflows to 0.
        y_pred_neg: tf.Tensor = y_pred - y_true * y_pred.dtype.max
        y_pred_pos: tf.Tensor = y_pred - (1 - y_true) * y_pred.dtype.max
        zeros = tf.zeros_like(y_pred[..., :1])  # acts as the "1" inside log(1 + sum(exp(...)))
        y_pred_neg = tf.concat([y_pred_neg, zeros], axis=-1)
        y_pred_pos = tf.concat([y_pred_pos, zeros], axis=-1)
        neg_loss = tf.math.reduce_logsumexp(y_pred_neg, axis=-1)
        pos_loss = tf.math.reduce_logsumexp(y_pred_pos, axis=-1)
        return neg_loss + pos_loss