def test_resnet_rewrite_memory(linearize=False):
  tf.reset_default_graph()
  tf_dev = tf.device('/cpu:0')
  tf_dev.__enter__()

  n = 6  # use n>5 (see test_chain_memory)
  nodes = make_resnet(n)
  a0 = nodes[0]
  a = nodes[-1]
  checkpoints = [nodes[3], nodes[5]]  # ['a03_add:0', 'a05_add:0']

  grad = memory_saving_gradients.gradients_memory([a], [a0])[0]
  if linearize:
    added = linearize_lib.linearize(grad.op)

  sess = create_session()
  sessrun(tf.global_variables_initializer())
  sessrun(grad.op)

  peak_memory = cpu_peak()
  # 1 for activation of each tanh node + 1 for initial backprop node
  # + 1 temporary memory for computing the adds,
  # -1 for discarding, then recomputing a1_tanh
  expected_peak = (n + 1 + 1 - 1) * 10**6
  util.report_memory(peak_memory, expected_peak)

  if not REMOVE_ASSERTS:
    assert (peak_memory - expected_peak) < 1.1 * 10**6, "Difference too large."
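# --- Illustrative usage sketch (not one of the repo's tests) ----------------
# The test above calls memory_saving_gradients.gradients_memory with the same
# (ys, xs) calling convention as tf.gradients. The sketch below assumes a
# plain 6-node tanh chain is representative and checks that the rewritten
# gradient agrees numerically with tf.gradients; the function name and tensor
# sizes are illustrative only and not part of the test suite.
def _gradients_memory_usage_sketch():
  import numpy as np
  tf.reset_default_graph()
  x = tf.constant(np.random.randn(100).astype(np.float32), name="x")
  h = x
  for i in range(6):
    h = tf.tanh(h, name="h%02d" % i)
  loss = tf.reduce_sum(h)

  # Same call signature as tf.gradients; checkpoints are chosen automatically.
  grad_rewritten = memory_saving_gradients.gradients_memory([loss], [x])[0]
  grad_reference = tf.gradients([loss], [x])[0]

  with tf.Session() as sess:
    g1, g2 = sess.run([grad_rewritten, grad_reference])
    assert np.allclose(g1, g2, atol=1e-6)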
def test_long_chain_memory(linearize=False):
  """Like test_chain, but use automatic rewriting with checkpoints="memory"
  strategy."""

  tf.reset_default_graph()
  tf_dev = tf.device('/cpu:0')
  tf_dev.__enter__()

  n = 100
  nodes = make_chain_tanh_constant(n)
  a0 = nodes[0]
  a = nodes[-1]
  tf.add_to_collection("checkpoints", nodes[10])
  tf.add_to_collection("checkpoints", nodes[20])

  # grad = memory_saving_gradients.gradients_collection([a], [a0])[0]
  grad = memory_saving_gradients.gradients_memory([a], [a0])[0]

  sess = create_session()
  sessrun(tf.global_variables_initializer())
  sessrun(grad.op)

  if linearize:
    added = linearize_lib.linearize()

  peak_memory = cpu_peak()
  # 20 MB used with the following tensors picked automatically as bottlenecks:
  # ['a10:0', 'a19:0', 'a28:0', 'a37:0', 'a46:0', 'a55:0', 'a64:0', 'a73:0',
  #  'a82:0', 'a91:0']
  expected_peak = 20 * 10**6
  util.report_memory(peak_memory, expected_peak)

  if not REMOVE_ASSERTS:
    assert (peak_memory - expected_peak) < 1.1e6, "Difference too large."
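# --- Hypothetical helper sketch ----------------------------------------------
# make_chain_tanh_constant is defined in the repo's test utilities and is not
# shown in this section. A minimal sketch of what such a helper could look
# like follows; the element count is an assumption chosen so each activation
# occupies roughly 10**6 bytes (250k float32 values), matching the per-node
# unit used in the expected_peak arithmetic. The _sketch suffix keeps it from
# shadowing the real helper.
def _make_chain_tanh_constant_sketch(n, size=250000):
  """Builds a chain a00 -> a01 -> ... where a(i+1) = tanh(a(i))."""
  a0 = tf.ones((size,), dtype=tf.float32, name="a00")
  nodes = [a0]
  for i in range(1, n):
    nodes.append(tf.tanh(nodes[-1], name="a%02d" % i))
  return nodes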
def test_long_resnet_rewrite_memory(linearize=False):
  tf.reset_default_graph()
  tf_dev = tf.device('/cpu:0')
  tf_dev.__enter__()

  n = 100
  nodes = make_resnet(n)
  a0 = nodes[0]
  a = nodes[-1]

  start_time = time.time()
  with tf.control_dependencies([a]):
    grad = memory_saving_gradients.gradients_memory([a], [a0])[0]

  if linearize:
    added = linearize_lib.linearize(grad.op)

  sess = create_session()
  sessrun(tf.global_variables_initializer())
  sessrun(grad.op)

  peak_memory = cpu_peak()
  # 20 MB used with the following tensors picked automatically:
  # ['a10_add:0', 'a19_add:0', 'a28_add:0', 'a37_add:0', 'a46_add:0',
  #  'a55_add:0', 'a64_add:0', 'a73_add:0', 'a82_add:0', 'a91_add:0']
  expected_peak = 20 * 10**6
  util.report_memory(peak_memory, expected_peak)

  if not REMOVE_ASSERTS:
    assert (peak_memory - expected_peak) < 10000, "Difference too large."
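# --- Hypothetical helper sketch ----------------------------------------------
# make_resnet likewise lives in the test utilities. Judging by the node names
# referenced in these tests (a10_add:0, a1_tanh, ...), each step appears to
# add a residual tanh branch; the sketch below is an assumption along those
# lines, not the repo's actual implementation.
def _make_resnet_sketch(n, size=250000):
  """Builds a residual chain a(i) = a(i-1) + tanh(a(i-1))."""
  a0 = tf.ones((size,), dtype=tf.float32, name="a00")
  nodes = [a0]
  for i in range(1, n):
    prev = nodes[-1]
    t = tf.tanh(prev, name="a%02d_tanh" % i)
    nodes.append(tf.add(prev, t, name="a%02d_add" % i))
  return nodes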
def test_chain_memory(linearize=False):
  """Like test_chain, but use automatic rewriting with checkpoints="memory"
  strategy."""

  tf.reset_default_graph()
  tf_dev = tf.device('/cpu:0')
  tf_dev.__enter__()

  n = 6  # for n=5, only choice of a2 saves memory, and alg picks a3,
         # hence use n>5 to avoid this edge condition

  nodes = make_chain_tanh_constant(n)
  a0 = nodes[0]
  a = nodes[-1]
  grad = memory_saving_gradients.gradients_memory([a], [a0])[0]

  sess = create_session()
  sessrun(tf.global_variables_initializer())
  sessrun(grad.op)

  if linearize:
    linearize_lib.linearize()

  peak_memory = cpu_peak()
  expected_peak = (n + 1 - 1) * 10**6  # 1 for each node + 1 for generated
                                       # "loss" tensor - 1 saved
  util.report_memory(peak_memory, expected_peak)

  if not REMOVE_ASSERTS:
    assert (peak_memory - expected_peak) < 10000, "Difference too large."