def test_chain_rule():
    """Check that ``chain_rule`` skips constant inputs and applies saved context.

    Each call's results are materialized into a list and counted before
    being inspected. An empty result therefore fails the test instead of
    silently passing.
    """
    # A Variable built with ``None`` history is a constant; constants must
    # receive no derivative.
    constant = minitorch.Variable(None)
    back = list(
        Temp.chain_rule(ctx=None, inputs=[constant, constant], d_output=5)
    )
    assert len(back) == 0

    # Exactly one derivative, for the non-constant input.
    var = minitorch.Variable(History())
    constant = minitorch.Variable(None)
    back = list(Temp.chain_rule(ctx=None, inputs=[var, constant], d_output=5))
    assert len(back) == 1
    assert back[0].variable.name == var.name
    assert back[0].deriv == 5

    # NOTE(review): Temp2 appears to scale the second input's derivative by
    # the first saved value (10). Confirm against Temp2's definition.
    ctx = minitorch.Context()
    ctx.save_for_backward(10)
    back = list(Temp2.chain_rule(ctx=ctx, inputs=[constant, var], d_output=5))
    assert len(back) == 1
    assert back[0].variable.name == var.name
    assert back[0].deriv == 5 * 10

    # With the variable in the first position, the derivative is d_output
    # unscaled.
    ctx = minitorch.Context()
    ctx.save_for_backward(10)
    back = list(Temp2.chain_rule(ctx=ctx, inputs=[var, constant], d_output=5))
    assert len(back) == 1
    assert back[0].variable.name == var.name
    assert back[0].deriv == 5
def test_chain_rule2():
    "Check that constants are ignored and variables get derivatives."
    var = minitorch.Variable(History())
    # A Variable with ``None`` history acts as a constant input.
    constant = minitorch.Variable(None)
    back = Function1.chain_rule(ctx=None, inputs=[var, constant], d_output=5)
    # Materialize so the result can be counted and indexed.
    back = list(back)
    assert len(back) == 1
    # Each entry unpacks into (variable, derivative).
    variable, deriv = back[0]
    assert variable.name == var.name
    assert deriv == 5
def test_chain_rule1():
    "Check that constants are ignored."
    # Both inputs have no history, so no derivative should come back.
    const_input = minitorch.Variable(None)
    produced = list(
        Function1.chain_rule(
            ctx=None,
            inputs=[const_input, const_input],
            d_output=5,
        )
    )
    assert len(produced) == 0