Example 1
def sequence_masking(x, mask, value=0, axis=None):
    """为序列条件mask的函数
    mask: 形如(batch_size, seq_len)的0-1矩阵;
    value: mask部分要被替换成的值,可以是'-inf'或'inf';
    axis: 序列所在轴,默认为1;
    """
    if mask is None:
        return x
    else:
        x_dtype = K.dtype(x)
        if x_dtype == 'bool':
            # Booleans can't be multiplied below; compute in int32 and cast back at the end.
            x = K.cast(x, 'int32')
        if K.dtype(mask) != K.dtype(x):
            mask = K.cast(mask, K.dtype(x))
        if value == '-inf':
            value = -K.infinity()  # K.infinity(): bert4keras's large finite stand-in for inf
        elif value == 'inf':
            value = K.infinity()
        if axis is None:
            axis = 1
        elif axis < 0:
            axis = K.ndim(x) + axis
        assert axis > 0, 'axis must be greater than 0'
        # Broadcast the (batch_size, seq_len) mask onto axes [0, axis] of x.
        mask = align(mask, [0, axis], K.ndim(x))
        value = K.cast(value, K.dtype(x))
        x = x * mask + value * (1 - mask)
        if x_dtype == 'bool':
            x = K.cast(x, 'bool')
        return x
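Below is a minimal usage sketch, not part of the original example: it assumes bert4keras is installed and that sequence_masking and the patched backend K are importable from bert4keras.backend, and it masks padded attention logits before a softmax.

import numpy as np
from bert4keras.backend import K, sequence_masking  # assumed import path

# Toy logits for a batch of 2 sequences of length 4.
scores = K.constant(np.random.randn(2, 4).astype('float32'))
# 0/1 mask: the second sequence only has 2 real tokens.
mask = K.constant([[1, 1, 1, 1], [1, 1, 0, 0]])

# Padded positions are pushed to -K.infinity(), so softmax gives them ~0 weight.
probs = K.softmax(sequence_masking(scores, mask, value='-inf', axis=1), axis=-1)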
Example 2
def sparse_multilabel_categorical_crossentropy(y_true,
                                               y_pred,
                                               mask_zero=False):
    """稀疏版多标签分类的交叉熵
    说明:
        1. y_true.shape=[..., num_positive],
           y_pred.shape=[..., num_classes];
        2. 请保证y_pred的值域是全体实数,换言之一般情况下
           y_pred不用加激活函数,尤其是不能加sigmoid或者
           softmax;
        3. 预测阶段则输出y_pred大于0的类;
        4. 详情请看:https://kexue.fm/archives/7359 。
    """
    # Append a zero logit to act as the threshold class, so that
    # pos_loss = log(1 + sum_pos e^-s) and neg_loss = log(1 + sum_neg e^s).
    zeros = K.zeros_like(y_pred[..., :1])
    y_pred = K.concatenate([y_pred, zeros], axis=-1)
    if mask_zero:
        # Class 0 is reserved for padding: set its logit to +inf so padded
        # entries of y_true contribute e^-inf = 0 to pos_loss.
        infs = zeros + K.infinity()
        y_pred = K.concatenate([infs, y_pred[..., 1:]], axis=-1)
    y_pos_2 = batch_gather(y_pred, y_true)  # logits of the positive classes
    y_pos_1 = K.concatenate([y_pos_2, zeros], axis=-1)
    if mask_zero:
        # Flip class 0 to -inf so padding also drops out of the negative sum.
        y_pred = K.concatenate([-infs, y_pred[..., 1:]], axis=-1)
        y_pos_2 = batch_gather(y_pred, y_true)
    pos_loss = K.logsumexp(-y_pos_1, axis=-1)
    all_loss = K.logsumexp(y_pred, axis=-1)
    # The negative sum is all classes minus the positive ones, computed
    # stably as all_loss + log(1 - exp(logsumexp(pos) - all_loss)).
    aux_loss = K.logsumexp(y_pos_2, axis=-1) - all_loss
    aux_loss = K.clip(1 - K.exp(aux_loss), K.epsilon(), 1)
    neg_loss = all_loss + K.log(aux_loss)
    return pos_loss + neg_loss
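A hedged usage sketch (assuming bert4keras, where this function and batch_gather live in bert4keras.backend): y_true carries the indices of the positive classes, right-padded with 0 when mask_zero=True, so class id 0 must be reserved for padding.

import numpy as np
from bert4keras.backend import K, sparse_multilabel_categorical_crossentropy  # assumed import path

num_classes = 10
# Raw logits, no sigmoid/softmax (see note 2 in the docstring).
y_pred = K.constant(np.random.randn(2, num_classes).astype('float32'))
# Positive-class indices per sample, right-padded with the reserved id 0.
y_true = K.constant([[3, 7], [5, 0]], dtype='int32')

loss = sparse_multilabel_categorical_crossentropy(y_true, y_pred, mask_zero=True)  # shape (2,)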
Example 3
def multilabel_categorical_crossentropy(y_true, y_pred):
    """多标签分类的交叉熵
    说明:
        1. y_true和y_pred的shape一致,y_true的元素非0即1,
           1表示对应的类为目标类,0表示对应的类为非目标类;
        2. 请保证y_pred的值域是全体实数,换言之一般情况下
           y_pred不用加激活函数,尤其是不能加sigmoid或者
           softmax;
        3. 预测阶段则输出y_pred大于0的类;
        4. 详情请看:https://kexue.fm/archives/7359 。
    """
    # Flip the sign of the target-class logits, then split into negative
    # and positive parts, sending the other side to -inf so it drops out.
    y_pred = (1 - 2 * y_true) * y_pred
    y_neg = y_pred - y_true * K.infinity()        # keeps non-target logits
    y_pos = y_pred - (1 - y_true) * K.infinity()  # keeps negated target logits
    # The appended zero is the threshold class:
    # loss = log(1 + sum_neg e^s) + log(1 + sum_pos e^-s).
    zeros = K.zeros_like(y_pred[..., :1])
    y_neg = K.concatenate([y_neg, zeros], axis=-1)
    y_pos = K.concatenate([y_pos, zeros], axis=-1)
    neg_loss = K.logsumexp(y_neg, axis=-1)
    pos_loss = K.logsumexp(y_pos, axis=-1)
    return neg_loss + pos_loss
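A short usage sketch, again assuming the function is imported from bert4keras.backend: hard 0/1 targets and raw logits in, one loss value per sample out.

import numpy as np
from bert4keras.backend import K, multilabel_categorical_crossentropy  # assumed import path

# Raw logits (no activation) for 3 samples over 5 classes.
y_pred = K.constant(np.random.randn(3, 5).astype('float32'))
# Hard 0/1 targets: 1 marks a target class.
y_true = K.constant([[1, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]])

loss = multilabel_categorical_crossentropy(y_true, y_pred)  # shape (3,)
# At prediction time, output the classes whose logit is greater than 0.
preds = K.cast(y_pred > 0, K.floatx())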
Example 4
def multilabel_categorical_crossentropy(y_true, y_pred):
    """多标签分类的交叉熵
    说明:
        1. y_true和y_pred的shape一致,y_true的元素是0~1
           的数,表示当前类是目标类的概率;
        2. 请保证y_pred的值域是全体实数,换言之一般情况下
           y_pred不用加激活函数,尤其是不能加sigmoid或者
           softmax;
        3. 预测阶段则输出y_pred大于0的类;
        4. 详情请看:https://kexue.fm/archives/7359 和
           https://kexue.fm/archives/9064 。
    """
    # Ignore positions already masked to -inf in y_pred.
    y_mask = y_pred > -K.infinity() / 10
    n_mask = (y_true < 1 - K.epsilon()) & y_mask  # classes with some negative weight
    p_mask = (y_true > K.epsilon()) & y_mask      # classes with some positive weight
    infs = K.zeros_like(y_pred) + K.infinity()
    # Soft targets weight the two sides by log(1 - y_true) and log(y_true);
    # the second argument of K.log enables the patched backend's epsilon
    # guard, so targets of exactly 0 or 1 don't hit log(0).
    y_neg = K.switch(n_mask, y_pred, -infs) + K.log(1 - y_true, True)
    y_pos = K.switch(p_mask, -y_pred, -infs) + K.log(y_true, True)
    zeros = K.zeros_like(y_pred[..., :1])
    y_neg = K.concatenate([y_neg, zeros], axis=-1)
    y_pos = K.concatenate([y_pos, zeros], axis=-1)
    neg_loss = K.logsumexp(y_neg, axis=-1)
    pos_loss = K.logsumexp(y_pos, axis=-1)
    return neg_loss + pos_loss
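A sketch for the soft-label revision above (assuming a bert4keras version that ships this variant in bert4keras.backend): targets are probabilities in [0, 1], and with hard 0/1 targets the loss reduces to Example 3.

import numpy as np
from bert4keras.backend import K, multilabel_categorical_crossentropy  # assumed import path

# Soft targets: the probability that each class is a target class.
y_true = K.constant([[0.9, 0.1, 0.0], [0.2, 0.8, 1.0]])
y_pred = K.constant(np.random.randn(2, 3).astype('float32'))

loss = multilabel_categorical_crossentropy(y_true, y_pred)  # shape (2,)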
Example 5
def attention_normalize(a, axis=-1, method='softmax'):
    """不同的注意力归一化方案
    softmax:常规/标准的指数归一化;
    squared_relu:来自 https://arxiv.org/abs/2202.10447 ;
    softmax_plus:来自 https://kexue.fm/archives/8823 。
    """
    if method == 'softmax':
        return K.softmax(a, axis=axis)
    else:
        # Effective length l: count the positions not already masked to -inf.
        mask = K.cast(a > -K.infinity() / 10, K.floatx())
        l = K.maximum(K.sum(mask, axis=axis, keepdims=True), 1)
        if method == 'squared_relu':
            return K.relu(a)**2 / l
        elif method == 'softmax_plus':
            # Rescale the logits by log(l) / log(512), 512 being the reference length.
            return K.softmax(a * K.log(l) / np.log(512), axis=axis)
    return a
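A usage sketch (assuming attention_normalize is importable from bert4keras.backend): the three schemes applied to the same attention logits; any masked positions should already sit at -K.infinity().

import numpy as np
from bert4keras.backend import K, attention_normalize  # assumed import path

# Attention logits for 2 heads over an 8x8 score matrix.
a = K.constant(np.random.randn(2, 8, 8).astype('float32'))

p_softmax = attention_normalize(a, axis=-1, method='softmax')
p_sq_relu = attention_normalize(a, axis=-1, method='squared_relu')  # ReLU(a)^2 / l
p_plus = attention_normalize(a, axis=-1, method='softmax_plus')     # length-scaled softmax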