def test_custom_dp_can_set_cost_cap():
    """Check that the ``cost_cap`` setting does not change the optimal cost.

    The same random equation is contracted with ``cost_cap`` set to
    ``True``, ``False`` and the explicit value ``100``; all three runs
    must report an identical ``opt_cost``.
    """
    eq, shapes = oe.helpers.rand_equation(5, 3, seed=42)

    costs = []
    for cap in (True, False, 100):
        optimizer = oe.DynamicProgramming(cost_cap=cap)
        info = oe.contract_path(
            eq, *shapes, shapes=True, optimize=optimizer
        )[1]
        costs.append(info.opt_cost)

    cost_true, cost_false, cost_fixed = costs
    assert cost_true == cost_false == cost_fixed
def test_custom_dp_can_optimize_for_size():
    """Compare ``minimize='flops'`` against ``minimize='size'``.

    On the same random equation, the flops-minimizing run must report a
    strictly lower ``opt_cost`` and a strictly larger
    ``largest_intermediate`` than the size-minimizing run.
    """
    eq, shapes = oe.helpers.rand_equation(10, 4, seed=43)

    def path_info(target):
        # Build a fresh optimizer per target and keep only the info object.
        optimizer = oe.DynamicProgramming(minimize=target)
        return oe.contract_path(
            eq, *shapes, shapes=True, optimize=optimizer
        )[1]

    flops_info = path_info('flops')
    size_info = path_info('size')

    assert flops_info.opt_cost < size_info.opt_cost
    assert flops_info.largest_intermediate > size_info.largest_intermediate
def test_custom_dp_can_optimize_for_outer_products():
    """Check that ``search_outer=True`` gives a cheaper path here.

    For the equation ``a,b,abc->c`` with dimensions a=2, b=2, c=3, the
    run with ``search_outer=True`` must report a strictly lower
    ``opt_cost`` than the run with ``search_outer=False``.
    """
    eq = "a,b,abc->c"
    da = db = 2
    dc = 3
    shapes = [(da,), (db,), (da, db, dc)]

    infos = {}
    for outer in (False, True):
        optimizer = oe.DynamicProgramming(search_outer=outer)
        infos[outer] = oe.contract_path(
            eq, *shapes, shapes=True, optimize=optimizer
        )[1]

    assert infos[True].opt_cost < infos[False].opt_cost
def test_dp_errors_when_no_contractions_found():
    """Check the ``"dp"`` optimizer raises when ``memory_limit`` is too low.

    The smallest achievable ``largest_intermediate`` is first obtained
    from a size-minimizing ``DynamicProgramming`` run. Contracting with
    ``memory_limit`` equal to that value must succeed, while a limit one
    below it must raise ``RuntimeError``.
    """
    eq, shapes, _ = oe.helpers.rand_equation(
        10, 3, seed=42, return_size_dict=True
    )

    # Establish the true minimum intermediate size via explicit size
    # minimization.
    size_optimizer = oe.DynamicProgramming(minimize="size")
    info = oe.contract_path(
        eq, *shapes, shapes=True, optimize=size_optimizer
    )[1]
    mincost = info.largest_intermediate

    # The default "dp" optimizer must still find a path exactly at the
    # limit.
    oe.contract_path(
        eq, *shapes, shapes=True, memory_limit=mincost, optimize="dp"
    )

    # One unit below the limit there is no valid path, so it must raise.
    too_small = mincost - 1
    with pytest.raises(RuntimeError):
        oe.contract_path(
            eq, *shapes, shapes=True, memory_limit=too_small, optimize="dp"
        )
def test_custom_dp_can_set_minimize(minimize, cost, width, path):
    """Check path, cost and width for a given ``minimize`` setting.

    NOTE(review): the arguments are presumably supplied by a
    ``pytest.mark.parametrize`` decorator that is not visible in this
    chunk -- confirm against the full file.

    Args:
        minimize: Value passed to ``DynamicProgramming(minimize=...)``.
        cost: Expected ``opt_cost`` of the resulting path info.
        width: Expected ``largest_intermediate`` of the resulting path info.
        path: Expected ``path`` of the resulting path info.
    """
    eq, shapes = oe.helpers.rand_equation(10, 4, seed=43)
    optimizer = oe.DynamicProgramming(minimize=minimize)
    result = oe.contract_path(eq, *shapes, shapes=True, optimize=optimizer)
    info = result[1]

    assert info.path == path
    assert info.opt_cost == cost
    assert info.largest_intermediate == width