def test_chain_rewrite(linearize=False):
  """Take chain of length 5, save 2 nodes, make sure 2 units of RAM is saved.

  Args:
    linearize: if True, run linearize_lib.linearize() to add control
      dependencies before executing the graph.
  """
  tf.reset_default_graph()
  tf_dev = tf.device('/cpu:0')
  tf_dev.__enter__()

  n = 5
  a0, a1, a2, a3, a4 = make_chain_tanh(n)
  grad = memory_saving_gradients.gradients([a4], [a0],
                                           checkpoints=[a1, a3])[0]
  # BUG FIX: linearize must run BEFORE the graph is executed; previously it
  # was called after sessrun(grad.op), so the flag had no effect on the
  # peak-memory measurement below.
  if linearize:
    linearize_lib.linearize()

  expected_peak = (n + 1 - 2) * 10**6  # subtract 2 since we recompute 2

  sess = create_session()
  sessrun(tf.global_variables_initializer())
  sessrun(grad.op)

  peak_memory = cpu_peak()
  util.report_memory(peak_memory, expected_peak)
  if not REMOVE_ASSERTS:
    assert (peak_memory - expected_peak) < 1e6 + 10000, "Difference too large."
def test_chain():
  """Runs regular chain gradient, makes sure memory usage makes sense."""
  tf.reset_default_graph()
  device_ctx = tf.device('/cpu:0')
  device_ctx.__enter__()

  chain_length = 5
  chain = make_chain_tanh(chain_length)
  first, last = chain[0], chain[-1]
  with tf.control_dependencies([last]):
    grad = tf.gradients([last], [first])[0]
  #linearize_lib.linearize()

  sess = create_session()
  sessrun(tf.global_variables_initializer())
  sessrun(grad.op)

  peak_memory = cpu_peak()
  expected_peak = chain_length * 10**6
  # sanity check: at least the "loss" tensor must have been materialized
  assert peak_memory > 2e6
  util.report_memory(peak_memory, expected_peak)
  if not REMOVE_ASSERTS:
    assert (peak_memory - expected_peak) < 1e6 + 10000, "Difference too large."
def test_chain_rewrite(linearize=False):
  """Take chain of length 5, save 2 nodes, make sure 2 units of RAM is saved.

  Args:
    linearize: if True, run linearize_lib.linearize() to add control
      dependencies before executing the graph.
  """
  tf.reset_default_graph()
  tf_dev = tf.device('/cpu:0')
  tf_dev.__enter__()

  n = 5
  a0, a1, a2, a3, a4 = make_chain_tanh(n)
  grad = memory_saving_gradients.gradients([a4], [a0],
                                           checkpoints=[a1, a3])[0]
  # BUG FIX: linearize must run BEFORE the graph is executed; previously it
  # was called after sessrun(grad.op), so the flag had no effect on the
  # peak-memory measurement below.
  if linearize:
    linearize_lib.linearize()

  expected_peak = (n + 1 - 2) * 10**6  # subtract 2 since we recompute 2

  sess = create_session()
  sessrun(tf.global_variables_initializer())
  sessrun(grad.op)

  peak_memory = cpu_peak()
  util.report_memory(peak_memory, expected_peak)
  if not REMOVE_ASSERTS:
    assert (peak_memory - expected_peak) < 1e6 + 10000, "Difference too large."
def test_chain():
  """Runs regular chain gradient, makes sure memory usage makes sense."""
  tf.reset_default_graph()
  cpu_device = tf.device('/cpu:0')
  cpu_device.__enter__()

  num_nodes = 5
  tanh_chain = make_chain_tanh(num_nodes)
  head = tanh_chain[0]
  tail = tanh_chain[-1]
  with tf.control_dependencies([tail]):
    grad = tf.gradients([tail], [head])[0]
  #linearize_lib.linearize()

  sess = create_session()
  sessrun(tf.global_variables_initializer())
  sessrun(grad.op)

  peak_memory = cpu_peak()
  expected_peak = num_nodes * 10**6
  # the "loss" tensor alone should account for at least this much memory
  assert peak_memory > 2e6
  util.report_memory(peak_memory, expected_peak)
  if not REMOVE_ASSERTS:
    assert (peak_memory - expected_peak) < 1e6 + 10000, "Difference too large."
def test_golden_order():
  """Check that linearize's execution order matches a known-good ordering."""
  tf.reset_default_graph()
  n = 5
  chain = util.make_chain_tanh(n)
  first, last = chain[0], chain[-1]
  # building the gradient adds the gradients/* ops checked below
  grad = tf.gradients([last], [first])[0]

  execution_order = linearize_lib.linearize(modify_graph=False)
  golden_order = ['a00/read', 'a01', 'a02', 'a03', 'gradients/Shape',
                  'gradients/grad_ys_0', 'gradients/Fill', 'a04',
                  'gradients/a04_grad/TanhGrad', 'gradients/a03_grad/TanhGrad',
                  'gradients/a02_grad/TanhGrad', 'gradients/a01_grad/TanhGrad',
                  'ones']
  observed_order = [op.name for op in execution_order]
  assert observed_order == golden_order
def test_chain_linearize():
  """Verify that a pure chain needs no extra control dependencies.

  A chain has exactly one valid execution order, so linearize() should
  report zero new dependencies.
  """
  tf.reset_default_graph()
  n = 5
  # create a chain with only a single execution order
  # using make_chain_tanh_const doesn't work because of "shape_as_tensor"
  # op that is not constrained
  # (see "Running ones/shape_as_tensor after ones/Const")
  nodes = util.make_chain_tanh(n)
  a0 = nodes[0]
  a = nodes[-1]
  # exercise the read-only ordering API; its result was previously bound to
  # unused locals (order1 / observed_order1), which have been removed
  linearize_lib.obtain_linear_order()
  num_new_deps = linearize_lib.linearize(targets=[a])
  assert num_new_deps == 0
def test_chain_rewrite_save_last():
  """Take chain of length 5, save last node. This saved no memory, and is
  and edge case that should raise exception by rewriter."""
  tf.reset_default_graph()
  tf_dev = tf.device('/cpu:0')
  tf_dev.__enter__()

  n = 5
  a0, a1, a2, a3, a4 = make_chain_tanh(n)
  try:
    grad = memory_saving_gradients.gradients([a4], [a0], checkpoints=[a4])[0]
  except Exception:
    return  # expected: checkpointing only the output node is rejected
  else:
    if not REMOVE_ASSERTS:
      # BUG FIX: the original code was `assert "<message>"`, which asserts a
      # non-empty (always-truthy) string and could never fail. Use
      # `assert False, msg` so a missing exception is actually reported.
      assert False, "Should've been 'no checkpoints nodes found' exception"
def test_articulation_points():
  """Check articulation-point discovery on a chain and a caterpillar graph."""
  tf.reset_default_graph()
  n = 5
  chain = util.make_chain_tanh(n)
  chain_head = chain[0]
  chain_tail = chain[-1]
  points = linearize_lib.sorted_articulation_points(targets=[chain_tail])
  # original list is ['a00', 'a01', 'a02', 'a03', 'a04']
  # end-points are not considered separators, so result should be
  assert util.format_ops(points) == ['a01', 'a02', 'a03']

  tf.reset_default_graph()
  caterpillar = _make_simple_caterpillar_graph(n)
  cat_head = caterpillar[0]
  cat_tail = caterpillar[-1]
  points = linearize_lib.sorted_articulation_points(None)
  assert util.format_ops(points) == ['merge0', 'merge1', 'merge2',
                                     'merge3', 'merge4', 'merge5']