# NOTE: this snippet assumes that ``mgb`` (the MegBrain / MegEngine internal
# module) and the helpers ``log``, ``zero_grad`` and ``indexing_one_hot`` are
# imported by the surrounding module.
def cross_entropy(pred, label, axis=1):
    """Cross-entropy loss for ``pred`` that already holds probabilities
    (e.g. the output of a softmax) along ``axis``."""
    n0 = pred.ndim
    n1 = label.ndim
    assert n0 == n1 + 1, (
        "target ndim must be one less than input ndim; input_ndim={} "
        "target_ndim={}".format(n0, n1))

    num_classes = pred.shapeof(axis)

    # ``pred`` is expected to already be normalized along ``axis``, so no
    # softmax denominator is computed here (cf. cross_entropy_with_softmax).

    # Probability that the model assigns to the ground-truth class
    up = indexing_one_hot(pred, label, axis, keepdims=True)

    # The small epsilon guards against log(0) when the target class receives
    # zero probability
    return (-log(up + 1e-8)).mean()
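

# Illustrative sketch, not part of the original module: a hypothetical NumPy
# helper showing the same computation as cross_entropy above, assuming
# ``pred`` holds probabilities of shape (N, C) and ``label`` holds integer
# class ids of shape (N,).
def _cross_entropy_reference_np(pred, label, eps=1e-8):
    import numpy as np

    # Probability of the true class for every sample
    picked = pred[np.arange(pred.shape[0]), label]
    return float(np.mean(-np.log(picked + eps)))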


def cross_entropy_with_softmax(pred, label, axis=1, mask=None):
    """Cross-entropy loss computed from raw logits ``pred``; the softmax is
    folded into the loss for numerical stability.  If ``mask`` is given, the
    per-element losses are weighted by it and renormalized by its sum."""
    n0 = pred.ndim
    n1 = label.ndim
    assert n0 == n1 + 1, (
        "target ndim must be one less than input ndim; input_ndim={} "
        "target_ndim={}".format(n0, n1))

    num_classes = pred.shapeof(axis)

    # Subtract the (gradient-stopped) per-sample maximum before exponentiating
    # so that exp() cannot overflow; the softmax value is unchanged.
    offset = zero_grad(pred.max(axis=axis, keepdims=True))
    pred = pred - offset
    # Denominator of the softmax (sum of exponentials of the shifted logits)
    down = mgb.opr.elem.exp(pred).sum(axis=axis, keepdims=True)

    # Shifted logit of the ground-truth class, i.e. the log of the softmax
    # numerator
    up = indexing_one_hot(pred, label, axis, keepdims=True)

    # loss = log(sum(exp(pred))) - pred[label] = -log(softmax(pred)[label])
    if mask is None:
        return (log(down) - up).mean()
    else:
        # Weighted mean: only elements with non-zero mask contribute
        return ((log(down) - up) * mask).sum() / mask.sum()
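

# Illustrative sketch, not part of the original module: a hypothetical NumPy
# helper showing the numerically stable computation above, assuming logits of
# shape (N, C), integer labels of shape (N,) and an optional per-sample
# ``mask`` of shape (N,).
def _cross_entropy_with_softmax_reference_np(pred, label, mask=None):
    import numpy as np

    pred = pred - pred.max(axis=1, keepdims=True)  # stability offset
    log_down = np.log(np.exp(pred).sum(axis=1))    # log-sum-exp denominator
    up = pred[np.arange(pred.shape[0]), label]     # shifted true-class logit
    loss = log_down - up                           # -log(softmax(pred)[label])
    if mask is None:
        return float(loss.mean())
    return float((loss * mask).sum() / mask.sum())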