Exemplo n.º 1
0
def test_checkpoint_all():
    """Check the checkpoint-all solver on linear graphs of 2 through 31 nodes.

    For each length, the generated graph must report that many forward nodes,
    the solver result must be feasible, and the result's
    ``schedule_aux_data.cpu`` must equal the graph's ``size``.
    """
    for n_nodes in range(2, 32):
        graph = gen_linear_graph(n_nodes)
        assert graph.size_fwd == n_nodes
        result = solve_checkpoint_all(graph)
        assert result.feasible
        assert result.schedule_aux_data.cpu == graph.size
Exemplo n.º 2
0
def test_griewank():
    """Check Griewank's solver on linear graphs of 2, 4, 8, 16 and 32 nodes.

    The solver is given the sum of all ``cost_ram`` entries of the graph and
    its result must be feasible.
    """
    lengths = (2 ** k for k in range(1, 6))
    for n_nodes in lengths:
        graph = gen_linear_graph(n_nodes)
        assert graph.size_fwd == n_nodes
        budget = sum(graph.cost_ram.values())
        outcome = solve_griewank(graph, budget)
        assert outcome.feasible
Exemplo n.º 3
0
def test_chen_greedy_ap():
    """Check Chen's greedy solver on linear graphs of 2, 4, 5, 7 and 8 nodes.

    The solver is called with the sum of all ``cost_ram`` entries and a third
    positional argument of ``True``. Presumably that flag enables actuation
    points, given the ``_ap`` suffix; confirm against ``solve_chen_greedy``.
    The result must be feasible.
    """
    for graph_length in [2, 4, 5, 7, 8]:
        # The solver always receives the full RAM total, so a per-budget sweep
        # would only repeat an identical call. One call per graph length is
        # enough.
        # TODO(review): if a budget sweep was intended, pass a per-budget
        # value here instead of total_cost.
        g = gen_linear_graph(graph_length)
        assert g.size_fwd == graph_length
        total_cost = sum(g.cost_ram.values())
        scheduler_result = solve_chen_greedy(g, total_cost, True)
        assert scheduler_result.feasible
Exemplo n.º 4
0
def test_ilp():
    """Check the Gurobi ILP solver on linear graphs of 2, 4 and 8 nodes.

    Gurobi is optional. If ``gurobipy`` cannot be imported, the ImportError is
    logged, a warning is emitted, and the function returns early without
    running any solver. Otherwise each graph is solved with the sum of its
    ``cost_ram`` entries, with console output and log-file writing disabled,
    and the result must be feasible.
    """
    try:
        import gurobipy as _  # imported only to probe availability
    except ImportError as err:
        logging.exception(err)
        logging.warning("Continuing with tests, gurobi not installed")
        return

    # Deferred import: only reachable once gurobipy is known to be importable.
    from remat.core.solvers.strategy_optimal_ilp import solve_ilp_gurobi

    for n_nodes in (2, 4, 8):
        graph = gen_linear_graph(n_nodes)
        assert graph.size_fwd == n_nodes
        budget = sum(graph.cost_ram.values())
        result = solve_ilp_gurobi(
            graph,
            budget,
            print_to_console=False,
            write_log_file=None,
        )
        assert result.feasible
Exemplo n.º 5
0
from experiments.common.definitions import remat_data_dir
from experiments.common.graph_plotting import plot
from remat.core.solvers.strategy_checkpoint_all import solve_checkpoint_all
from remat.core.solvers.strategy_chen import solve_chen_sqrtn
from remat.core.solvers.strategy_griewank import solve_griewank
from remat.core.solvers.strategy_optimal_ilp import solve_ilp_gurobi
from remat.core.utils.timer import Timer
import pandas as pd
import matplotlib.pyplot as plt

if __name__ == "__main__":
    # Script entry: build a linear graph of N nodes, then for each B in 4..11
    # run three solvers (checkpoint-all, Chen sqrt(n), Griewank) and plot each
    # schedule into a scratch directory keyed by N and B.
    # NOTE(review): this example is truncated inside the third plot(...)
    # call below, so it is not runnable as shown.
    # NOTE(review): gen_linear_graph is used but does not appear in the
    # imports shown for this example; confirm it is imported.
    N = 16
    for B in range(4, 12):
        # model = get_keras_model("MobileNet")
        # g = dfgraph_from_keras(mod=model)
        g = gen_linear_graph(N)
        # The f-string has no placeholders; the path segments are
        # "scratch_linear", then N, then B.
        scratch_dir = remat_data_dir() / f"scratch_linear" / str(N) / str(B)
        scratch_dir.mkdir(parents=True, exist_ok=True)
        # NOTE(review): `data` is never used in the visible portion.
        data = []

        scheduler_result_all = solve_checkpoint_all(g)
        # NOTE(review): meaning of the positional True flag is not shown
        # here; confirm against solve_chen_sqrtn.
        scheduler_result_sqrtn = solve_chen_sqrtn(g, True)
        # B is passed in the position where test_griewank passes a RAM total;
        # presumably a memory budget. TODO confirm.
        scheduler_result_griewank = solve_griewank(g, B)
        # NOTE(review): the positional False passed to plot() is not
        # documented here; confirm its meaning in graph_plotting.plot.
        plot(scheduler_result_all,
             False,
             save_file=scratch_dir / "CHECKPOINT_ALL.png")
        plot(scheduler_result_sqrtn,
             False,
             save_file=scratch_dir / "CHEN_SQRTN.png")
        plot(scheduler_result_griewank,
             False,
Exemplo n.º 6
0
def test_checkpoint_last():
    """Check the checkpoint-last-node solver on linear graphs of 2 through 31 nodes.

    For each length, the generated graph must report that many forward nodes
    and the solver result must be feasible.
    """
    for n_nodes in range(2, 32):
        graph = gen_linear_graph(n_nodes)
        assert graph.size_fwd == n_nodes
        outcome = solve_checkpoint_last_node(graph)
        assert outcome.feasible