# Imports assumed for a standalone run: assertpy provides assert_that, and the
# node classes are assumed to live in a local neural_network module.
from assertpy import assert_that

from neural_network import (ConstantNode, InputNode, LinearNode, ReluNode,
                            SigmoidNode, single_linear_relu_network)


def test_linear_with_relu_evaluate():
    input_nodes = InputNode.make_input_nodes(3)
    inputs = [1, 2, 3]
    initial_weights = [-20, 3, 2, 1]
    linear_node = LinearNode(input_nodes, initial_weights=initial_weights)
    relu_node = ReluNode(linear_node)
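    # Assuming LinearNode prepends a bias input (an implicit constant 1), the
    # linear output is -20*1 + 3*1 + 2*2 + 1*3 = -10, which ReLU clamps to 0.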
    assert_that(relu_node.evaluate(inputs)).is_equal_to(0)
    assert_that(linear_node.output).is_equal_to(-10)


def test_pretty_print():
    const = ConstantNode()
    input_node = InputNode(0)
    sigmoid = SigmoidNode(const)
    sigmoid.evaluate([])
    relu = ReluNode(input_node)
    relu.evaluate([2])

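    # sigmoid(1) = 1 / (1 + e^-1) ≈ 0.731, rendered below to two decimal places.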
    assert_that(sigmoid.pretty_print()).is_equal_to(
        "Sigmoid output=0.73\n  Constant(1)\n")
    assert_that(relu.pretty_print()).is_equal_to(
        "Relu output=2.00\n  InputNode(0) output = 2.00\n")

    network = single_linear_relu_network(3, [-20, 3, 2, 1])
    network.evaluate([1, 2, 3])
    network.compute_error([1, 2, 3], 1)
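    # The linear output is -10, so the ReLU is "dead" at this point: its output
    # and the error gradient for every weight are zero.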
    assert_that(network.pretty_print()).is_equal_to("""Relu output=0.00
  Linear weights=-20.00,3.00,2.00,1.00 gradient=0.00,0.00,0.00,0.00 output=-10.00
    Constant(1)
    InputNode(0) output = 1.00
    InputNode(1) output = 2.00
    InputNode(2) output = 3.00

""")


def test_relu_evaluate_negative():
    input_node = InputNode(0)
    relu = ReluNode(input_node)
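    # ReLU maps any negative input to zero.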
    assert_that(relu.evaluate([-2])).is_equal_to(0)


def test_relu_local_parameter_gradient_empty():
    input_node = InputNode(0)
    relu = ReluNode(input_node)
    relu.evaluate([3])
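    # ReLU has no tunable parameters, so its parameter gradient is empty.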
    assert_that(len(relu.local_parameter_gradient)).is_equal_to(0)


def test_relu_local_gradient_negative():
    input_node = InputNode(0)
    relu = ReluNode(input_node)
    relu.evaluate([-3])
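    # ReLU is flat on negative inputs, so its derivative there is zero.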
    assert_that(relu.local_gradient_for_argument(input_node)).is_equal_to(0)