def test_fast__BasicGraphReversal():
    """Compare ``Gradient`` with ``BaselineGradient`` on the fast network set.

    Each analyzer is wrapped in a one-argument factory that builds it from a
    model. Both factories are passed to ``dryrun.test_equal_analyzer`` with
    the network filter ``"trivia.*:mnist.log_reg"``.
    """

    def baseline_factory(model):
        return BaselineGradient(model)

    def reversal_factory(model):
        return Gradient(model)

    network_filter = "trivia.*:mnist.log_reg"
    dryrun.test_equal_analyzer(baseline_factory, reversal_factory, network_filter)
def test_precommit__BasicGraphReversal():
    """Compare ``Gradient`` with ``BaselineGradient`` on the MNIST network set.

    This is the pre-commit variant of the fast test. It runs the same pair of
    analyzer factories through ``dryrun.test_equal_analyzer``, using the
    broader network filter ``"mnist.*"``.
    """

    def baseline_factory(model):
        return BaselineGradient(model)

    def reversal_factory(model):
        return Gradient(model)

    network_filter = "mnist.*"
    dryrun.test_equal_analyzer(baseline_factory, reversal_factory, network_filter)
def test_fast_LRPZ_equal_BaselineLRPZ():
    """Compare ``LRPZ`` with ``BaselineLRPZ`` on a small set of networks.

    Both analyzers are built from a model by one-argument factories. The
    factories are passed to ``dryrun.test_equal_analyzer`` with the network
    filter ``"trivia.dot:mnist.log_reg"``.
    """

    def baseline_factory(model):
        return BaselineLRPZ(model)

    def lrpz_factory(model):
        # LRP-Z variant that includes the bias term.
        return LRPZ(model)

    # Caveat: this equivalence only holds for networks built from relu and max
    # activations without skip connections, so the filter is kept narrow.
    network_filter = "trivia.dot:mnist.log_reg"
    dryrun.test_equal_analyzer(
        baseline_factory,
        lrpz_factory,
        network_filter,
    )