def test_chen_sqrtn():
    """Chen sqrt(n) solver must produce a feasible schedule on linear graphs.

    Sweeps every length in ``test_points``, builds a linear dependency graph,
    sanity-checks its size invariant, then solves and asserts feasibility.
    """
    for n in test_points:
        graph = gen_linear_graph(n)
        # A linear graph of length n has n forward ops, n backward ops, and the loss node.
        assert graph.size == 2 * n + 1

        budget = sum(graph.cost_ram.values())
        # NOTE(review): the second argument of solve_chen_sqrtn is a boolean flag in the
        # other call sites; passing a numeric budget here is truthy, so it behaves like
        # True — confirm this is intentional.
        result = solve_chen_sqrtn(graph, budget)
        assert result.feasible

        if SAVE_DEBUG_PLOTS:
            out_path = "/tmp/test_checkmate/plot_chen_sqrtn/{}.png".format(n)
            plot_schedule(result, save_file=out_path)
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from checkmate.core.solvers.strategy_checkpoint_all import solve_checkpoint_all
from checkmate.core.solvers.strategy_chen import solve_chen_sqrtn
from checkmate.core.utils.definitions import PathLike
from checkmate.tf2.execution import edit_graph
from checkmate.tf2.extraction import dfgraph_from_tf_function
from checkmate.plot.definitions import checkmate_data_dir
from checkmate.tf2.util.load_keras_model import get_keras_model

# NOTE(review): `logging` and `os` are used below but no import for them is visible
# in this chunk — presumably they are imported earlier in the file; verify.
logging.basicConfig(level=logging.INFO)
# Work around duplicate OpenMP runtime load errors (common with TF + MKL on macOS).
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

# Convenience aliases: solve Chen's sqrt(n) strategy without / with
# articulation-point checkpointing.
solve_chen_sqrtn_noap = lambda g: solve_chen_sqrtn(g, False)
solve_chen_sqrtn_ap = lambda g: solve_chen_sqrtn(g, True)


def get_data(dataset: str, batch_size: int = 32):
    """Load a named keras dataset and wrap it in batched tf.data pipelines.

    :param dataset: one of "mnist", "cifar10", "cifar100"
    :param batch_size: batch size for both the train and test pipelines
    """
    # NOTE(review): this function appears truncated in this chunk — no return
    # statement is visible, so train_ds/test_ds may be returned further down.
    if dataset in ["mnist", "cifar10", "cifar100"]:
        # NOTE(review): eval() on a caller-supplied name — safe only because the
        # name was whitelisted above; getattr(tf.keras.datasets, dataset) would be safer.
        # This also REBINDS `dataset` from a str to a module object.
        dataset = eval("tf.keras.datasets.{}".format(dataset))
        (x_train, y_train), (x_test, y_test) = dataset.load_data()
        # Normalize pixel values from [0, 255] to [0, 1].
        x_train, x_test = x_train / 255.0, x_test / 255.0
        # NOTE(review): BUG — `dataset` is now a module, so this comparison with the
        # string "mnist" is always False and the channel-axis expansion never runs.
        if dataset == "mnist":
            x_train = x_train[..., tf.newaxis]
            x_test = x_test[..., tf.newaxis]
        train_ds = tf.data.Dataset.from_tensor_slices(
            (x_train, y_train)).batch(batch_size)
        test_ds = tf.data.Dataset.from_tensor_slices(
            (x_test, y_test)).batch(batch_size)
to_file=log_base / "plot_{}_keras.png".format(model_name), show_shapes=True, show_layer_names=True) plot_dfgraph(g, log_base, name=model_name) # sweep constant baselines logger.info( "Running constant baselines (ALL, ALL_AP, LAST_NODE, SQRTN_NOAP, SQRTN)" ) result_dict[SolveStrategy.CHECKPOINT_ALL] = [solve_checkpoint_all(g)] result_dict[SolveStrategy.CHECKPOINT_ALL_AP] = [solve_checkpoint_all_ap(g)] result_dict[SolveStrategy.CHECKPOINT_LAST_NODE] = [ solve_checkpoint_last_node(g) ] result_dict[SolveStrategy.CHEN_SQRTN_NOAP] = [solve_chen_sqrtn(g, False)] result_dict[SolveStrategy.CHEN_SQRTN] = [solve_chen_sqrtn(g, True)] # sweep chen's greedy baseline logger.info("Running Chen's greedy baseline (APs only)") chen_sqrtn_noap = result_dict[SolveStrategy.CHEN_SQRTN_NOAP][0] greedy_eval_points = chen_sqrtn_noap.schedule_aux_data.activation_ram * ( 1.0 + np.arange(-1, 2, 0.01)) remote_solve_chen_greedy = ray.remote(num_cpus=1)(solve_chen_greedy).remote futures = [ remote_solve_chen_greedy(g, float(b), False) for b in greedy_eval_points ] result_dict[SolveStrategy.CHEN_GREEDY] = get_futures( list(futures), desc="Greedy (APs only)") if model_name not in CHAIN_GRAPH_MODELS:
def compile_tf2(
    model: tf.keras.Model,
    loss: tf.losses.Loss,
    optimizer: tf.optimizers.Optimizer,
    input_spec=None,
    label_spec=None,
    budget="auto",
):
    """
    Checkmate optimizes your DNN graphs to consume less GPU memory. Call this function using a tf.function

    :param model: a keras Model to optimize
    :param loss: loss function to use when training
    :param optimizer: optimizer whose apply_gradients is called in the returned train step
    :param input_spec: tf.TensorSpec list that corresponds to model inputs; defaults to model.input_spec
    :param label_spec: tf.TensorSpec (or example tensor) describing the label shape; required
    :param budget: memory budget in MB, or "auto" to derive one from available GPU/system RAM
    :return: a tf.function train step (data, labels) -> (predictions, loss_val) that applies gradients
    :raises ValueError: if neither model.input_spec nor input_spec is set, or label_spec is missing
    """
    # set input, output shapes
    if model.input_spec is None and input_spec is None:
        raise ValueError(
            "Keras model has not been compiled yet! If model.input_spec is not defined, then input_spec "
            "parameter must be set in the call to checkmate.tensorflow2.compile."
        )
    if label_spec is None:
        # FIX: trailing space added so the two implicitly-concatenated literals don't
        # render as "Pass inan example input".
        raise ValueError(
            "Checkmate needs the shape of the label in order to calculate the size of all operations. Pass in "
            "an example input or tf.TensorSpec object representing the shape of the label."
        )
    input_spec = model.input_spec if input_spec is None else input_spec

    # query budget if not specified
    if budget == "auto":
        if _using_gpu_check():  # choose based on available GPU RAM
            gpu_ram = nvidiasmi_query("memory.total")
            # Conservative: 90% of the smallest GPU so the schedule fits on every device.
            budget = min(gpu_ram.values()) * 0.9
            logging.info(
                "[checkmate] No budget specified; defaulting to the minimum amount of total GPU RAM on any single "
                "GPU, {0:.2f}MB".format(budget))
        else:  # choose based available system memory
            # 80% of available DRAM, converted from bytes to MB.
            budget = psutil.virtual_memory().available * 0.8 / 1000000
            logging.debug(
                "[checkmate] No GPU detected, using system DRAM on CPU")
            logging.info(
                "[checkmate] No budget specified; defaulting to {0:.2f}MB".
                format(budget))

    # build gradient function for model
    @tf.function
    def grads_check(data, label):
        with tf.GradientTape() as check_tape:
            predictions = model(data)
            loss_val = loss(label, predictions)
        gradients = check_tape.gradient(loss_val, model.trainable_variables)
        return predictions, loss_val, gradients

    fn = grads_check.get_concrete_function(input_spec, label_spec)
    g = dfgraph_from_tf_function(fn)

    # choose solver and calculate solver
    # FIX: typo "guarentee" -> "guarantee" in user-facing log message.
    logging.error(
        "[checkmate] At the moment, Checkmate does not guarantee scheduling under the specified budget. "
        "This feature will appear soon.")
    logging.debug(
        "[checkmate] Solving for recomputation schedule, may take a while")
    logging.debug("[checkmate] Using Chen et al. (2016) sqrt(n) algorithm")
    sched_result = solve_chen_sqrtn(g, True)
    logging.debug("[checkmate] Schedule solved")

    # create recomputed gradient function
    def clean_bs(tensorspec):
        # Relax the batch dimension to None so the traced function accepts any batch size.
        newshape = list(tensorspec.shape)
        newshape[0] = None
        return tf.TensorSpec(shape=newshape, dtype=tensorspec.dtype)

    fn_nobatchsize = grads_check.get_concrete_function(clean_bs(input_spec),
                                                       clean_bs(label_spec))
    grad_fn_check = edit_graph(fn_nobatchsize, g.op_dict, sched_result.schedule)

    @tf.function
    def train_step_check(data, labels):
        predictions, loss_val, gradients = grad_fn_check(data, labels)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return predictions, loss_val

    return train_step_check
if __name__ == "__main__":
    # Benchmark script: build a ResNet50 gradient graph, trace it into a checkmate
    # DFGraph, and solve a Chen sqrt(n) recomputation schedule — timing each phase.
    logging.basicConfig(level=logging.DEBUG)

    logging.info("building graph")
    with Timer("build_graph", print_results=True):
        model = get_keras_model("ResNet50")

        def grads(images, labels):
            # Simple surrogate loss (mean of prediction minus label) — enough to
            # produce a full backward graph for scheduling purposes.
            with tf.GradientTape() as tape:
                pred = model(images)
                loss = tf.reduce_mean(pred - labels)
            gradient = tape.gradient(loss, model.trainable_variables)
            return loss, gradient

        # NOTE(review): BS is not defined in this chunk — presumably a module-level
        # batch-size constant defined earlier in the file; verify.
        grad_fn = tf.function(grads).get_concrete_function(
            tf.TensorSpec(shape=(BS, 224, 224, 3)),
            tf.TensorSpec(shape=(BS, 1000)))

    logging.info("tracing graph")
    with Timer("trace_graph", print_results=True):
        g = dfgraph_from_tf_function(grad_fn)

    # Alternative solvers, kept for reference:
    # sched_result = solve_ilp_gurobi(g, budget=platform_memory("p2xlarge"), approx=False, eps_noise=0.0)
    # sched_result = solve_approx_lp_deterministic_05_threshold(g, budget=platform_memory("p2xlarge"))
    logging.info("solving graph")
    with Timer("sched_graph", print_results=True):
        sched_result = solve_chen_sqrtn(g, True)

    # logging.info("rebuilding graph")
    # new_graph = edit_graph(grad_fn, g.op_dict, sched_result.schedule)
    # plot_path = checkmate_data_dir() / "exec"
    # plot_path.mkdir(parents=True, exist_ok=True)
    # plot_schedule(sched_result, save_file=plot_path / "optimal_vgg16.png")
futures = [] # load model at batch size g = dfgraph_from_keras(model, batch_size=bs, cost_model=cost_model, loss_cpu_cost=0, loss_ram_cost=(4 * bs)) bs_fwd2xcost[bs] = sum(g.cost_cpu_fwd.values()) + sum( g.cost_cpu.values()) bs_param_ram_cost[bs] = g.cost_ram_fixed plot_dfgraph(g, log_base, name=model_name) # run constant baselines result_dict[bs][SolveStrategy.CHEN_SQRTN_NOAP] = [ solve_chen_sqrtn(g, False) ] futures.extend([ ray.remote(num_cpus=1)(solve_checkpoint_all).remote(g), ray.remote(num_cpus=1)(solve_checkpoint_all_ap).remote(g), ray.remote(num_cpus=1)(solve_checkpoint_last_node).remote(g), ray.remote(num_cpus=1)(solve_chen_sqrtn).remote(g, True), ray.remote(num_cpus=1)(solve_chen_sqrtn).remote(g, False), ]) # sweep chen's greedy baseline chen_sqrtn_noap = result_dict[bs][SolveStrategy.CHEN_SQRTN_NOAP][0] greedy_eval_points = chen_sqrtn_noap.schedule_aux_data.activation_ram * ( 1.0 + np.arange(-1, 2, 0.05)) remote_solve_chen_greedy = ray.remote( num_cpus=1)(solve_chen_greedy).remote
False, save_file=scratch_dir / "ALL.png") data.append({ "Strategy": str(scheduler_result_all.solve_strategy.value), "Name": "CHECKPOINT_ALL", "CPU": scheduler_result_all.schedule_aux_data.cpu, "Activation RAM": scheduler_result_all.schedule_aux_data.activation_ram, }) if args.model_name in LINEAR_MODELS: # Sqrt(n) scheduler_result_sqrtn = solve_chen_sqrtn(g, True) plot_schedule(scheduler_result_sqrtn, False, save_file=scratch_dir / "SQRTN.png") data.append({ "Strategy": str(scheduler_result_sqrtn.solve_strategy.value), "Name": "CHEN_SQRTN", "CPU": scheduler_result_sqrtn.schedule_aux_data.cpu, "Activation RAM": scheduler_result_sqrtn.schedule_aux_data.activation_ram, }) if not args.skip_ilp: