def test_resnet_rewrite_tarjan(linearize=False):
  """Resnet version of the automatic rewrite test with the checkpoints="tarjan" strategy."""
  tf.reset_default_graph()
  tf_dev = tf.device('/cpu:0')
  tf_dev.__enter__()

  n = 6  # use n > 5 (see test_chain_memory)
  nodes = make_resnet(n)
  a0 = nodes[0]
  a = nodes[-1]
  # Manual checkpoints kept for reference only; the tarjan strategy picks its own.
  checkpoints = [nodes[3], nodes[5]]  # ['a03_add:0', 'a05_add:0']

  grad = memory_saving_gradients.gradients_tarjan([a], [a0])[0]
  if linearize:
    added = linearize_lib.linearize(grad.op)

  sess = create_session()
  sessrun(tf.global_variables_initializer())
  sessrun(grad.op)

  peak_memory = cpu_peak()
  expected_peak = 4e6
  util.report_memory(peak_memory, expected_peak)

  if not REMOVE_ASSERTS:
    assert (peak_memory - expected_peak) < 1.1e6, "Difference too large."
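
# make_resnet() is defined elsewhere in the test utilities. The sketch below is
# only an assumption about its shape: a toy residual chain whose nodes are
# ~1 MB float32 tensors and whose per-step adds produce the 'aXX_add:0' names
# referenced in the comments above. The repo's actual helper may differ.
def make_resnet_sketch(n, size=250000):
  """Builds a residual chain a[i] = a[i-1] + tanh(a[i-1]) and returns all n nodes."""
  a = tf.fill((size,), 1.0, name="a00")  # ~1 MB of float32
  nodes = [a]
  for i in range(1, n):
    nonlin = tf.tanh(a, name="a%02d_tanh" % i)
    a = tf.add(a, nonlin, name="a%02d_add" % i)
    nodes.append(a)
  return nodes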

def test_long_chain_tarjan(linearize=False):
  """Like test_chain, but uses automatic rewriting with the checkpoints="tarjan" strategy."""
  tf.reset_default_graph()
  tf_dev = tf.device('/cpu:0')
  tf_dev.__enter__()

  n = 100
  nodes = make_chain_tanh_constant(n)
  a0 = nodes[0]
  a = nodes[-1]

  grad = memory_saving_gradients.gradients_tarjan([a], [a0])[0]

  sess = create_session()
  sessrun(tf.global_variables_initializer())
  sessrun(grad.op)

  if linearize:
    added = linearize_lib.linearize()

  peak_memory = cpu_peak()
  # Checkpoints picked automatically:
  # a09:0, a19:0, a29:0, a39:0, a49:0, a58:0, a68:0, a78:0, a88:0, a97:0
  expected_peak = 18e6
  util.report_memory(peak_memory, expected_peak)

  # TODO: remove REMOVE_ASSERTS
  if not REMOVE_ASSERTS:
    assert (peak_memory - expected_peak) < 1.1e6, "Difference too large."
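
# make_chain_tanh_constant() is likewise defined elsewhere. A minimal sketch,
# assuming a plain chain of tanh ops over a ~1 MB constant-filled tensor, so
# that n live nodes cost roughly n MB when nothing is recomputed:
def make_chain_tanh_constant_sketch(n, size=250000):
  """Builds a00 -> a01 -> ... (n tanh nodes) and returns all of them."""
  a = tf.fill((size,), 1.0, name="a00")  # ~1 MB of float32
  nodes = [a]
  for i in range(1, n):
    a = tf.tanh(a, name="a%02d" % i)
    nodes.append(a)
  return nodes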

def test_long_resnet_rewrite_tarjan(linearize=False):
  """Like test_long_chain_tarjan, but on a resnet graph."""
  tf.reset_default_graph()
  tf_dev = tf.device('/cpu:0')
  tf_dev.__enter__()

  n = 100
  nodes = make_resnet(n)
  a0 = nodes[0]
  a = nodes[-1]

  start_time = time.time()
  with tf.control_dependencies([a]):
    grad = memory_saving_gradients.gradients_tarjan([a], [a0])[0]

  if linearize:
    added = linearize_lib.linearize(grad.op)

  sess = create_session()
  sessrun(tf.global_variables_initializer())
  sessrun(grad.op)

  peak_memory = cpu_peak()
  # ~20 MB used with the following tensors picked automatically:
  # ['a10_add:0', 'a19_add:0', 'a28_add:0', 'a37_add:0', 'a46_add:0',
  #  'a55_add:0', 'a64_add:0', 'a73_add:0', 'a82_add:0', 'a91_add:0']
  expected_peak = 18e6
  util.report_memory(peak_memory, expected_peak)

  if not REMOVE_ASSERTS:
    assert (peak_memory - expected_peak) < 1.1e6, "Difference too large."
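
# sessrun() and cpu_peak() come from the surrounding test utilities. The
# sketch below is an assumption about how they could work, not the repo's
# actual code: run each fetch with full tracing, keep the RunMetadata, and
# report the largest allocator peak_bytes observed on a CPU device as a rough
# proxy for peak memory. It assumes create_session() installs the session as
# the default one (e.g. via tf.InteractiveSession); the real helpers may
# instead replay individual allocation records, which is more precise.
_last_run_metadata = None

def sessrun_sketch(fetches):
  """Runs `fetches` on the default session with full tracing enabled."""
  global _last_run_metadata
  _last_run_metadata = tf.RunMetadata()
  options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
  return tf.get_default_session().run(fetches, options=options,
                                      run_metadata=_last_run_metadata)

def cpu_peak_sketch():
  """Returns the largest per-allocator peak_bytes seen on a CPU device."""
  peak = 0
  for dev_stats in _last_run_metadata.step_stats.dev_stats:
    if "cpu" not in dev_stats.device.lower():
      continue
    for node_stats in dev_stats.node_stats:
      for mem in node_stats.memory:
        peak = max(peak, mem.peak_bytes)
  return peak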

def test_chain_tarjan(linearize=False):
  """Like test_chain, but uses automatic rewriting with the checkpoints="tarjan" strategy."""
  tf.reset_default_graph()
  tf_dev = tf.device('/cpu:0')
  tf_dev.__enter__()

  # For n=5 only the choice of a2 saves memory, while the algorithm picks a3,
  # hence use n > 5 to avoid this edge condition.
  n = 6
  nodes = util.make_chain_tanh_fill(n)
  a0 = nodes[0]
  a = nodes[-1]

  grad = memory_saving_gradients.gradients_tarjan([a], [a0])[0]

  sess = create_session()
  sessrun(tf.global_variables_initializer())
  sessrun(grad.op)

  if linearize:
    linearize_lib.linearize()

  peak_memory = cpu_peak()
  expected_peak = 5e6  # originally needed 7 units; now a3 and a5 are recomputed
  util.report_memory(peak_memory, expected_peak)

  if not REMOVE_ASSERTS:
    assert (peak_memory - expected_peak) < 1e5, "Difference too large."
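
# Outside of these tests, gradients_tarjan is called like tf.gradients(ys, xs);
# checkpoints are selected automatically, so nothing else needs to be passed.
# The small model below is made up purely for illustration.
def usage_sketch():
  tf.reset_default_graph()
  x = tf.random_normal([128, 784])
  w1 = tf.Variable(tf.random_normal([784, 256]))
  w2 = tf.Variable(tf.random_normal([256, 10]))
  h = tf.tanh(tf.matmul(x, w1))
  loss = tf.reduce_sum(tf.square(tf.matmul(h, w2)))
  # Same call pattern as in the tests above: list of ys, list of xs.
  grads = memory_saving_gradients.gradients_tarjan([loss], [w1, w2])
  return grads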