예제 #1
0
def CrossEntropyLossWithLogSoftmax():
    """Mean prediction-target cross-entropy for multiclass classification.

    Builds a `cb.Serial` combinator whose stages run in this order:
    `core.LogSoftmax`, then `_CrossEntropy`, then `_WeightedMean`.
    `sublayers_to_print=[]` is passed, presumably so the individual stages
    are omitted when the layer is printed (confirm against `cb.Serial`).

    Returns:
      A `cb.Serial` layer named 'CrossEntropyLossWithLogSoftmax'.
    """
    # Stages are constructed in the same order the original argument list did.
    stages = (
        core.LogSoftmax(),
        _CrossEntropy(),
        _WeightedMean(),
    )
    return cb.Serial(
        *stages,
        name='CrossEntropyLossWithLogSoftmax',
        sublayers_to_print=[],
    )
예제 #2
0
파일: metrics.py 프로젝트: yaoshuyin/trax
def CrossEntropyLossWithLogSoftmax():
    """Mean prediction-target cross-entropy for multiclass classification.

    Composes `core.LogSoftmax` followed by `CrossEntropyLoss` in a single
    `cb.Serial` combinator.

    NOTE(review): assumes the first input holds unnormalized scores, since a
    LogSoftmax is applied before the loss — confirm against callers.

    Returns:
      A `cb.Serial` layer named 'CrossEntropyLossWithLogSoftmax'.
    """
    # The combinator is named after this function rather than 'CrossEntropyLoss'.
    # The old name duplicated the wrapped sublayer's name, which made printed
    # layer trees ambiguous and hid the extra LogSoftmax step.
    return cb.Serial(core.LogSoftmax(),
                     CrossEntropyLoss(),
                     name='CrossEntropyLossWithLogSoftmax')