コード例 #1
0
ファイル: test_linear.py プロジェクト: uwsampl/dtr-prototype
def test_chen_sqrtn():
    """Smoke-test ``solve_chen_sqrtn`` on graphs built by ``gen_linear_graph``.

    For every graph length in (2, 4, 5, 7, 8) the graph is rebuilt and solved
    once per value of ``range(1, min(graph_length, 4))``. Each run must report
    a forward size equal to the requested length and a feasible schedule.
    """
    for n_nodes in (2, 4, 5, 7, 8):
        # The loop value is never handed to the solver; it only determines
        # how many times the same graph is rebuilt and solved.
        for _budget in range(1, min(n_nodes, 4)):
            graph = gen_linear_graph(n_nodes)
            assert graph.size_fwd == n_nodes

            ram_total = sum(graph.cost_ram.values())
            # NOTE(review): the second argument is the summed RAM cost of the
            # graph; confirm that this is what solve_chen_sqrtn expects.
            result = solve_chen_sqrtn(graph, ram_total)
            assert result.feasible
コード例 #2
0
from remat.core.utils.timer import Timer
import pandas as pd
import matplotlib.pyplot as plt

# Script entry point: for a graph from gen_linear_graph(N), run three solvers
# for each B in 4..11 and hand every result to plot() with a PNG path inside a
# per-(N, B) scratch directory.
if __name__ == "__main__":
    N = 16  # argument passed to gen_linear_graph
    for B in range(4, 12):
        # model = get_keras_model("MobileNet")
        # g = dfgraph_from_keras(mod=model)
        g = gen_linear_graph(N)
        # Output directory: <remat_data_dir>/scratch_linear/<N>/<B>, created
        # on demand (no error if it already exists).
        scratch_dir = remat_data_dir() / f"scratch_linear" / str(N) / str(B)
        scratch_dir.mkdir(parents=True, exist_ok=True)
        # Re-initialised on every iteration of B, so it only ever holds rows
        # for the current B.
        data = []

        # Only the Griewank solver receives B; the other two calls do not
        # depend on the loop variable.
        scheduler_result_all = solve_checkpoint_all(g)
        scheduler_result_sqrtn = solve_chen_sqrtn(g, True)
        scheduler_result_griewank = solve_griewank(g, B)
        # NOTE(review): the meaning of the positional False passed to plot()
        # is not visible here — confirm against plot()'s signature.
        plot(scheduler_result_all,
             False,
             save_file=scratch_dir / "CHECKPOINT_ALL.png")
        plot(scheduler_result_sqrtn,
             False,
             save_file=scratch_dir / "CHEN_SQRTN.png")
        plot(scheduler_result_griewank,
             False,
             save_file=scratch_dir / "GRIEWANK.png")
        data.append({
            "Strategy":
            str(scheduler_result_all.solve_strategy.value),
            "Name":
            "CHECKPOINT_ALL",
コード例 #3
0
        # Collects the handles returned by the .remote(...) submissions below.
        futures = []

        # load model at batch size
        # Builds the graph for batch size `bs`, passing loss_cpu_cost=0 and
        # loss_ram_cost=4 * bs to dfgraph_from_keras.
        g = dfgraph_from_keras(model,
                               batch_size=bs,
                               cost_model=cost_model,
                               loss_cpu_cost=0,
                               loss_ram_cost=(4 * bs))
        # Per-batch-size bookkeeping: sum of cost_cpu_fwd plus sum of cost_cpu,
        # and the graph's fixed RAM cost.
        bs_fwd2xcost[bs] = sum(g.cost_cpu_fwd.values()) + sum(
            g.cost_cpu.values())
        bs_param_ram_cost[bs] = g.cost_ram_fixed
        render_dfgraph(g, log_base, name=model_name)

        # run constant baselines
        # NOTE(review): solve_chen_sqrtn(g, False) is evaluated synchronously
        # here and the same call is submitted again as a remote task below —
        # confirm the duplicate solve is intended.
        result_dict[bs][SolveStrategy.CHEN_SQRTN_NOAP] = [
            solve_chen_sqrtn(g, False)
        ]
        # Each solver is wrapped with ray.remote(num_cpus=1) and submitted
        # with g (plus a boolean flag for the two solve_chen_sqrtn variants).
        futures.extend([
            ray.remote(num_cpus=1)(solve_checkpoint_all).remote(g),
            ray.remote(num_cpus=1)(solve_checkpoint_all_ap).remote(g),
            ray.remote(num_cpus=1)(solve_checkpoint_last_node).remote(g),
            ray.remote(num_cpus=1)(solve_chen_sqrtn).remote(g, True),
            ray.remote(num_cpus=1)(solve_chen_sqrtn).remote(g, False)
        ])

        # sweep chen's greedy baseline
        # Evaluation points are the synchronous SQRTN_NOAP result's
        # activation_ram scaled by factors that start at 0 and advance in
        # steps of 0.05 up to (but not including) approximately 3.
        chen_sqrtn_noap = result_dict[bs][SolveStrategy.CHEN_SQRTN_NOAP][0]
        greedy_eval_points = chen_sqrtn_noap.schedule_aux_data.activation_ram * (
            1. + np.arange(-1, 2, 0.05))
        # Bound submit function for the greedy solver (one CPU per task).
        remote_solve_chen_greedy = ray.remote(
            num_cpus=1)(solve_chen_greedy).remote
                                  to_file=log_base /
                                  f"plot_{model_name}_keras.png",
                                  show_shapes=True,
                                  show_layer_names=True)
        render_dfgraph(g, log_base, name=model_name)

    # sweep constant baselines
    # Each constant baseline is solved synchronously on g and stored in
    # result_dict as a one-element list keyed by its SolveStrategy member.
    logger.info(
        f"Running constant baselines (ALL, ALL_AP, LAST_NODE, SQRTN_NOAP, SQRTN)"
    )
    result_dict[SolveStrategy.CHECKPOINT_ALL] = [solve_checkpoint_all(g)]
    result_dict[SolveStrategy.CHECKPOINT_ALL_AP] = [solve_checkpoint_all_ap(g)]
    result_dict[SolveStrategy.CHECKPOINT_LAST_NODE] = [
        solve_checkpoint_last_node(g)
    ]
    result_dict[SolveStrategy.CHEN_SQRTN_NOAP] = [solve_chen_sqrtn(g, False)]
    result_dict[SolveStrategy.CHEN_SQRTN] = [solve_chen_sqrtn(g, True)]

    # sweep chen's greedy baseline
    logger.info(f"Running Chen's greedy baseline (APs only)")
    # Evaluation points are the SQRTN_NOAP result's activation_ram scaled by
    # factors that start at 0 and advance in steps of 0.01 up to (but not
    # including) approximately 3.
    chen_sqrtn_noap = result_dict[SolveStrategy.CHEN_SQRTN_NOAP][0]
    greedy_eval_points = chen_sqrtn_noap.schedule_aux_data.activation_ram * (
        1. + np.arange(-1, 2, 0.01))
    # Bound submit function for the greedy solver (one CPU per task).
    remote_solve_chen_greedy = ray.remote(num_cpus=1)(solve_chen_greedy).remote
    # One remote task per evaluation point; each point is converted to a
    # plain float before submission and the third argument is always False.
    futures = [
        remote_solve_chen_greedy(g, float(b), False)
        for b in greedy_eval_points
    ]
    # NOTE(review): get_futures is defined elsewhere; presumably it resolves
    # the submitted tasks and returns their results — confirm. `futures` is
    # already a list, so list(futures) only makes a shallow copy.
    result_dict[SolveStrategy.CHEN_GREEDY] = get_futures(
        list(futures), desc="Greedy (APs only)")
    if model_name not in CHAIN_GRAPH_MODELS: