def test_weighted_bce():
    """Check weighted binary cross entropy against precomputed values.

    Inputs are built as float32 TensorFlow tensors so that this test feeds
    the loss the same dtype and container type as the other
    ``test_weighted_bce`` definition in this module.

    NOTE(review): a later ``def test_weighted_bce`` in this module rebinds
    the same name, so only that later definition is collected and run.
    This body never executes. Consider deleting this copy or giving it a
    distinct ``test_...`` name.
    """
    # Target: every (i, j) slice over the two leading axes is a 3x3 identity.
    array_eye = np.identity(3, dtype=np.float32)
    tensor_eye = np.zeros((3, 3, 3, 3), dtype=np.float32)
    tensor_eye[:, :, 0:3, 0:3] = array_eye
    tensor_eye = tf.convert_to_tensor(tensor_eye, dtype=tf.float32)

    # Prediction: identity only at indices 0 and 1 of axis 1; index 2 stays zero.
    tensor_pred = np.zeros((3, 3, 3, 3), dtype=np.float32)
    tensor_pred[:, 0:2, :, :] = array_eye
    tensor_pred = tf.convert_to_tensor(tensor_pred, dtype=tf.float32)

    # Precomputed reference values; presumably one per leading-axis entry -- TODO confirm.
    expect = [1.535057, 1.535057, 1.535057]
    get = label.weighted_binary_cross_entropy(tensor_eye, tensor_pred)
    assert is_equal_tf(get, expect)
def test_weighted_bce():
    """Compare weighted binary cross entropy with a precomputed reference.

    The target holds a 3x3 identity in every (i, j) position of its two
    leading axes. The prediction keeps that identity only for indices 0
    and 1 on axis 1 and is all zeros for index 2. Both are passed to the
    loss as float32 TensorFlow tensors.
    """
    identity = np.identity(3, dtype=np.float32)

    # Repeat the identity over the two leading axes -> shape (3, 3, 3, 3).
    y_true = tf.convert_to_tensor(np.tile(identity, (3, 3, 1, 1)), dtype=tf.float32)

    y_pred_np = np.zeros((3, 3, 3, 3), dtype=np.float32)
    y_pred_np[:, :2] = identity  # broadcast into the trailing 3x3 of axis-1 indices 0 and 1
    y_pred = tf.convert_to_tensor(y_pred_np, dtype=tf.float32)

    reference = [1.535057] * 3
    result = label.weighted_binary_cross_entropy(y_true, y_pred)
    assert is_equal_tf(result, reference)