示例#1
0
def test_gradients():
    """Sigmoid backward pass scales the upstream gradient by s * (1 - s).

    With s = sigmoid([1, 2, 3]) the local derivatives are roughly
    [0.1966, 0.1050, 0.0452]; an upstream gradient of [1, 0, -1] therefore
    yields the expected values below (rounded to four decimals, hence atol).
    """
    inputs = Tensor.from_builtin([1, 2, 3])
    activated = inputs.sigmoid()
    upstream = np.array([1, 0, -1])
    seed = Gradient(tensor=activated, gradient=upstream)
    traced = Gradients._trace(seed)
    expected = [0.1966, 0, -0.0452]
    assert np.allclose(traced[inputs], expected, atol=1e-4)
示例#2
0
def test_gradients():
    """Subtraction passes the upstream gradient through to both operands.

    d(l - r)/dl = 1 and d(l - r)/dr = -1 element-wise, so the left operand
    receives the upstream gradient unchanged and the right operand receives
    it negated.
    """
    left = Tensor.from_builtin([2, 3, 4])
    right = Tensor.from_builtin([3, 2, 1])
    difference = left - right
    seed = Gradient(tensor=difference, gradient=np.array([1, 2, 3]))
    traced = Gradients._trace(seed)
    expected_left = [1, 2, 3]
    expected_right = [-1, -2, -3]
    assert np.all(traced[left] == expected_left)
    assert np.all(traced[right] == expected_right)
def test_gradients_full():
    """Softmax backward pass over a 1-D input.

    softmax([1, 2, 3]) is about [0.0900, 0.2447, 0.6652]. Seeding only the
    first output gives p0 * (delta_0j - p_j) for each input j, rounded to
    four decimals in the expectation (hence atol).
    """
    inputs = Tensor.from_builtin([1, 2, 3])
    distribution = inputs.softmax()
    upstream = np.array([1, 0, 0])
    traced = Gradients._trace(Gradient(tensor=distribution, gradient=upstream))
    expected = [0.0819, -0.0220, -0.0599]
    assert np.allclose(traced[inputs], expected, atol=1e-4)
示例#4
0
def test_gradients():
    """Division follows the quotient rule, scaled by the upstream gradient.

    d(l / r)/dl = 1 / r and d(l / r)/dr = -l / r**2, element-wise.
    """
    # Plain copies of the operand values, used only to derive expectations.
    numerator = np.array([2, 3, 4])
    denominator = np.array([3, 2, 1])

    left = Tensor.from_builtin([2, 3, 4])
    right = Tensor.from_builtin([3, 2, 1])
    quotient = left / right
    upstream = np.array([1, 2, 3])
    traced = Gradients._trace(Gradient(tensor=quotient, gradient=upstream))

    assert np.allclose(traced[left], upstream / denominator)
    assert np.allclose(traced[right], -upstream * numerator / denominator**2)
示例#5
0
def test_gradients():
    """Clip routes gradient only through elements left unclipped.

    Element-wise, 2 lies in [1, 3] and 3 lies in [2, 4], so both pass their
    upstream gradient through; 4 exceeds its upper bound 2 and gets zero.
    The bound tensors must have no gradient entry at all.
    """
    tensor = Tensor.from_builtin([2, 3, 4])
    low = Tensor.from_builtin([1, 2, 1])
    high = Tensor.from_builtin([3, 4, 2])
    clipped = tensor.clip(low, high)
    upstream = np.array([1, 2, 3])
    traced = Gradients._trace(Gradient(tensor=clipped, gradient=upstream))

    within_bounds = np.array([1, 1, 0])
    assert np.allclose(traced[tensor], upstream * within_bounds)

    # Looking up either bound must fail: no gradient was recorded for it.
    for bound in (low, high):
        with pytest.raises(KeyError):
            traced[bound]
def test_gradients_one_axis():
    """Softmax with argument 1 (row-wise) backpropagates each row independently.

    Both rows differ by a constant, so they share the probabilities
    [0.0900, 0.2447, 0.6652]. Row 0 is seeded on column 0 and row 1 on
    column 1, giving p_k * (delta_kj - p_j) per row (rounded, hence atol).
    """
    inputs = Tensor.from_builtin([[1, 2, 3], [4, 5, 6]])
    distribution = inputs.softmax(1)
    upstream = np.array([[1, 0, 0], [0, 1, 0]])
    traced = Gradients._trace(Gradient(tensor=distribution, gradient=upstream))

    first_row = [0.0819, -0.0220, -0.0599]
    second_row = [-0.0220, 0.1848, -0.1628]
    assert np.allclose(traced[inputs], [first_row, second_row], atol=1e-4)