def test_chen_sqrtn():
    """Smoke-test ``solve_chen_sqrtn`` on linear graphs of several lengths.

    For each length a linear graph is generated, its forward size is checked
    against the requested length, and the schedule returned by
    ``solve_chen_sqrtn`` is asserted to be feasible.
    """
    for graph_length in [2, 4, 5, 7, 8]:
        g = gen_linear_graph(graph_length)
        assert g.size_fwd == graph_length

        # The second argument of solve_chen_sqrtn is used as an on/off flag
        # (it is called with literal True/False), not as a memory budget.
        # The test previously passed the summed RAM cost here, which only
        # worked because a positive number is truthy; pass the flag
        # explicitly instead. The former inner loop over an unused `budget`
        # value merely repeated the identical solve and has been dropped.
        # NOTE(review): confirm the flag's meaning against the solver's
        # signature.
        scheduler_result = solve_chen_sqrtn(g, True)
        assert scheduler_result.feasible
from remat.core.utils.timer import Timer import pandas as pd import matplotlib.pyplot as plt if __name__ == "__main__": N = 16 for B in range(4, 12): # model = get_keras_model("MobileNet") # g = dfgraph_from_keras(mod=model) g = gen_linear_graph(N) scratch_dir = remat_data_dir() / f"scratch_linear" / str(N) / str(B) scratch_dir.mkdir(parents=True, exist_ok=True) data = [] scheduler_result_all = solve_checkpoint_all(g) scheduler_result_sqrtn = solve_chen_sqrtn(g, True) scheduler_result_griewank = solve_griewank(g, B) plot(scheduler_result_all, False, save_file=scratch_dir / "CHECKPOINT_ALL.png") plot(scheduler_result_sqrtn, False, save_file=scratch_dir / "CHEN_SQRTN.png") plot(scheduler_result_griewank, False, save_file=scratch_dir / "GRIEWANK.png") data.append({ "Strategy": str(scheduler_result_all.solve_strategy.value), "Name": "CHECKPOINT_ALL",
futures = [] # load model at batch size g = dfgraph_from_keras(model, batch_size=bs, cost_model=cost_model, loss_cpu_cost=0, loss_ram_cost=(4 * bs)) bs_fwd2xcost[bs] = sum(g.cost_cpu_fwd.values()) + sum( g.cost_cpu.values()) bs_param_ram_cost[bs] = g.cost_ram_fixed render_dfgraph(g, log_base, name=model_name) # run constant baselines result_dict[bs][SolveStrategy.CHEN_SQRTN_NOAP] = [ solve_chen_sqrtn(g, False) ] futures.extend([ ray.remote(num_cpus=1)(solve_checkpoint_all).remote(g), ray.remote(num_cpus=1)(solve_checkpoint_all_ap).remote(g), ray.remote(num_cpus=1)(solve_checkpoint_last_node).remote(g), ray.remote(num_cpus=1)(solve_chen_sqrtn).remote(g, True), ray.remote(num_cpus=1)(solve_chen_sqrtn).remote(g, False) ]) # sweep chen's greedy baseline chen_sqrtn_noap = result_dict[bs][SolveStrategy.CHEN_SQRTN_NOAP][0] greedy_eval_points = chen_sqrtn_noap.schedule_aux_data.activation_ram * ( 1. + np.arange(-1, 2, 0.05)) remote_solve_chen_greedy = ray.remote( num_cpus=1)(solve_chen_greedy).remote
to_file=log_base / f"plot_{model_name}_keras.png", show_shapes=True, show_layer_names=True) render_dfgraph(g, log_base, name=model_name) # sweep constant baselines logger.info( f"Running constant baselines (ALL, ALL_AP, LAST_NODE, SQRTN_NOAP, SQRTN)" ) result_dict[SolveStrategy.CHECKPOINT_ALL] = [solve_checkpoint_all(g)] result_dict[SolveStrategy.CHECKPOINT_ALL_AP] = [solve_checkpoint_all_ap(g)] result_dict[SolveStrategy.CHECKPOINT_LAST_NODE] = [ solve_checkpoint_last_node(g) ] result_dict[SolveStrategy.CHEN_SQRTN_NOAP] = [solve_chen_sqrtn(g, False)] result_dict[SolveStrategy.CHEN_SQRTN] = [solve_chen_sqrtn(g, True)] # sweep chen's greedy baseline logger.info(f"Running Chen's greedy baseline (APs only)") chen_sqrtn_noap = result_dict[SolveStrategy.CHEN_SQRTN_NOAP][0] greedy_eval_points = chen_sqrtn_noap.schedule_aux_data.activation_ram * ( 1. + np.arange(-1, 2, 0.01)) remote_solve_chen_greedy = ray.remote(num_cpus=1)(solve_chen_greedy).remote futures = [ remote_solve_chen_greedy(g, float(b), False) for b in greedy_eval_points ] result_dict[SolveStrategy.CHEN_GREEDY] = get_futures( list(futures), desc="Greedy (APs only)") if model_name not in CHAIN_GRAPH_MODELS: