def logits2pred(logits: tf.Tensor, activation: str) -> tf.Tensor:
    """Turn (possibly ragged) logits into predictions via an activation.

    A ragged input is densified first. The padding value is chosen so that
    padded positions end up with (near-)zero probability: for "sigmoid" and
    "softmax" the dtype's minimum is used, otherwise 0.
    """
    if isinstance(logits, tf.RaggedTensor):
        if activation in ("sigmoid", "softmax"):
            pad_value = logits.dtype.min
        else:
            pad_value = 0
        logits = logits.to_tensor(pad_value)
    activation_fn = tf.keras.activations.get(activation)
    return activation_fn(logits)
def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    """Multi-label categorical crossentropy (ZLPR-style) loss.

    All axes before ``self.flatten_axis`` are collapsed into one sample
    axis; the loss is then a log-sum-exp over negative-class logits plus a
    log-sum-exp over (sign-flipped) positive-class logits, each with an
    appended zero logit that supplies the "+1" term inside the log.
    """
    if isinstance(y_pred, tf.RaggedTensor):
        # Densify; dtype.min keeps padded slots out of the log-sum-exp.
        y_pred = y_pred.to_tensor(y_pred.dtype.min)
    y_true = tf.cast(y_true, y_pred.dtype)
    y_pred = pad2shape(y_pred, tf.shape(y_true))

    # Collapse every axis before ``flatten_axis`` into a single sample axis.
    num_samples = tf.reduce_prod(tf.shape(y_true)[:self.flatten_axis])
    y_true = tf.reshape(y_true, (num_samples, -1))
    y_pred = tf.reshape(y_pred, (num_samples, -1))

    # Flip the sign of positive-class logits so both classes reduce the
    # same way; then mask out the opposite class with a huge negative shift.
    scores = (1 - 2 * y_true) * y_pred
    neg_scores: tf.Tensor = scores - y_true * scores.dtype.max
    pos_scores: tf.Tensor = scores - (1 - y_true) * scores.dtype.max
    # A zero logit contributes exp(0) = 1 inside each log-sum-exp.
    zero_logit = tf.zeros_like(scores[..., :1])
    neg_scores = tf.concat([neg_scores, zero_logit], axis=-1)
    pos_scores = tf.concat([pos_scores, zero_logit], axis=-1)
    neg_loss = tf.math.reduce_logsumexp(neg_scores, axis=-1)
    pos_loss = tf.math.reduce_logsumexp(pos_scores, axis=-1)
    return neg_loss + pos_loss