Exemplo n.º 1
0
def test_chain_rule():
    """Check that ``chain_rule`` yields derivatives only for non-constant inputs.

    Every scenario collects the generator returned by ``chain_rule`` into a
    list and checks its length before looking at the entries. Otherwise an
    implementation that yields nothing would let the per-item assertions
    pass without ever running.
    """
    # Two constants (variables built with ``None``): nothing may be yielded.
    constant = minitorch.Variable(None)
    back = list(
        Temp.chain_rule(ctx=None, inputs=[constant, constant], d_output=5)
    )
    assert len(back) == 0

    # One variable with a History plus one constant: exactly one derivative,
    # attributed to the variable.
    var = minitorch.Variable(History())
    constant = minitorch.Variable(None)
    back = list(Temp.chain_rule(ctx=None, inputs=[var, constant], d_output=5))
    assert len(back) == 1
    assert back[0].variable.name == var.name
    assert back[0].deriv == 5

    # Temp2 reads a value saved on the context. Per the expectations below,
    # the derivative of the second input is scaled by the saved value (10),
    # while the first input receives d_output unchanged.
    ctx = minitorch.Context()
    ctx.save_for_backward(10)
    back = list(Temp2.chain_rule(ctx=ctx, inputs=[constant, var], d_output=5))
    assert len(back) == 1
    assert back[0].variable.name == var.name
    assert back[0].deriv == 5 * 10

    ctx = minitorch.Context()
    ctx.save_for_backward(10)
    back = list(Temp2.chain_rule(ctx=ctx, inputs=[var, constant], d_output=5))
    assert len(back) == 1
    assert back[0].variable.name == var.name
    assert back[0].deriv == 5
Exemplo n.º 2
0
def test_chain_rule2():
    """Check that constants are ignored and the variable receives d_output."""
    tracked = minitorch.Variable(History())
    untracked = minitorch.Variable(None)
    results = list(
        Function1.chain_rule(ctx=None, inputs=[tracked, untracked], d_output=5)
    )
    # Only the input carrying a History should produce an entry.
    assert len(results) == 1
    received_variable, received_deriv = results[0]
    assert received_variable.name == tracked.name
    assert received_deriv == 5
Exemplo n.º 3
0
def test_chain_rule1():
    """Check that inputs built with ``None`` (constants) yield no derivatives."""
    untracked = minitorch.Variable(None)
    derivatives = list(
        Function1.chain_rule(
            ctx=None,
            inputs=[untracked, untracked],
            d_output=5,
        )
    )
    assert not derivatives