예제 #1
0
def test_fast__BasicGraphReversal():
    """Compare ``Gradient`` against ``BaselineGradient`` on a small network set.

    Two one-argument factories are passed to ``dryrun.test_equal_analyzer``:
    each receives a model and returns an analyzer for it. The networks are
    selected by the filter string ``"trivia.*:mnist.log_reg"``.
    NOTE(review): the helper presumably asserts the two analyzers produce
    equal results (per its name); its exact checks live outside this file.
    """

    def build_reference(model):
        # Baseline analyzer the other implementation is compared against.
        return BaselineGradient(model)

    def build_candidate(model):
        # Analyzer under test; presumably the graph-reversal implementation,
        # per this test's name.
        return Gradient(model)

    network_filter = "trivia.*:mnist.log_reg"
    dryrun.test_equal_analyzer(build_reference, build_candidate, network_filter)
예제 #2
0
def test_precommit__BasicGraphReversal():
    """Compare ``Gradient`` against ``BaselineGradient`` on the ``mnist.*`` networks.

    Two one-argument factories are passed to ``dryrun.test_equal_analyzer``:
    each receives a model and returns an analyzer for it. The networks are
    selected by the filter string ``"mnist.*"``.
    NOTE(review): the helper presumably asserts the two analyzers produce
    equal results (per its name); its exact checks live outside this file.
    """

    def build_reference(model):
        # Baseline analyzer the other implementation is compared against.
        return BaselineGradient(model)

    def build_candidate(model):
        # Analyzer under test; presumably the graph-reversal implementation,
        # per this test's name.
        return Gradient(model)

    network_filter = "mnist.*"
    dryrun.test_equal_analyzer(build_reference, build_candidate, network_filter)
예제 #3
0
def test_fast_LRPZ_equal_BaselineLRPZ():
    """Compare ``LRPZ`` against ``BaselineLRPZ`` on a restricted network set.

    Two one-argument factories are passed to ``dryrun.test_equal_analyzer``:
    each receives a model and returns an analyzer for it. The networks are
    selected by the filter string ``"trivia.dot:mnist.log_reg"``.
    NOTE(review): the helper presumably asserts the two analyzers produce
    equal results (per its name); its exact checks live outside this file.
    """

    def build_reference(model):
        # Baseline LRP-Z analyzer the other implementation is compared against.
        return BaselineLRPZ(model)

    def build_candidate(model):
        # LRP-Z with bias
        return LRPZ(model)

    # Mind: this comparison only works for networks whose activations are
    # ReLU or max and that have no skip connections, hence the narrow filter.
    network_filter = "trivia.dot:mnist.log_reg"
    dryrun.test_equal_analyzer(build_reference, build_candidate, network_filter)