예제 #1
0
def test_weighted_bce():
    """
    Compare label.weighted_binary_cross_entropy with hard-coded reference values.

    Both inputs are float64 NumPy arrays of shape (3, 3, 3, 3). In the
    target, every [i, j] slice over the last two axes is the 3x3 identity;
    the prediction matches it for j in {0, 1} and is all zeros for j == 2.
    """
    eye3 = np.identity(3)

    # The (3, 3) identity broadcasts across the two leading axes, so every
    # [i, j] slice of the target becomes the identity matrix.
    target = np.zeros((3, 3, 3, 3))
    target[:, :] = eye3

    # Only the first two slices along axis 1 receive the identity; the
    # third slice is left as zeros.
    prediction = np.zeros((3, 3, 3, 3))
    prediction[:, :2] = eye3

    # NOTE(review): presumably one reference loss per entry along axis 0 — confirm.
    expected = [1.535057] * 3
    # Argument order (target first, prediction second) follows the original call.
    result = label.weighted_binary_cross_entropy(target, prediction)
    # NOTE(review): assumes assertTensorsEqual returns a truthy value on success
    # (it is wrapped in a plain assert) — confirm against its definition.
    assert assertTensorsEqual(result, expected)
예제 #2
0
def test_weighted_bce():
    """
    Compare label.weighted_binary_cross_entropy with hard-coded reference values.

    The inputs are built as float32 NumPy arrays of shape (3, 3, 3, 3) and
    then converted to tf.float32 tensors. In the target, every [i, j] slice
    over the last two axes is the 3x3 identity; the prediction matches it
    for j in {0, 1} and is all zeros for j == 2.
    """
    eye3 = np.identity(3, dtype=np.float32)

    # The (3, 3) identity broadcasts across the two leading axes.
    target_np = np.zeros((3, 3, 3, 3), dtype=np.float32)
    target_np[:, :] = eye3

    # Only the first two slices along axis 1 receive the identity.
    prediction_np = np.zeros((3, 3, 3, 3), dtype=np.float32)
    prediction_np[:, :2] = eye3

    # Hand the loss TensorFlow tensors rather than raw NumPy arrays.
    target = tf.convert_to_tensor(target_np, dtype=tf.float32)
    prediction = tf.convert_to_tensor(prediction_np, dtype=tf.float32)

    # NOTE(review): presumably one reference loss per entry along axis 0 — confirm.
    expected = [1.535057] * 3
    # Argument order (target first, prediction second) follows the original call.
    result = label.weighted_binary_cross_entropy(target, prediction)
    assert is_equal_tf(result, expected)