Example #1
def test_chain_rewrite(linearize=False):
    """Take chain of length 5, save 2 nodes, make sure 2 units of RAM is
  saved."""

    tf.reset_default_graph()
    tf_dev = tf.device('/cpu:0')
    tf_dev.__enter__()

    n = 5

    a0, a1, a2, a3, a4 = make_chain_tanh(n)
    grad = memory_saving_gradients.gradients([a4], [a0], checkpoints=[a1,
                                                                      a3])[0]
    expected_peak = (n + 1 - 2) * 10**6  # subtract 2 since we recompute 2

    sess = create_session()
    sessrun(tf.global_variables_initializer())

    sessrun(grad.op)
    if linearize:
        linearize_lib.linearize()

    peak_memory = cpu_peak()
    util.report_memory(peak_memory, expected_peak)

    if not REMOVE_ASSERTS:
        assert (peak_memory -
                expected_peak) < 1e6 + 10000, "Difference too large."
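
The tests on this page rely on a make_chain_tanh helper that is not shown
here. A minimal sketch of what it could look like, assuming each node is a
250,000-element float32 tensor (4 bytes per element, i.e. 10**6 bytes, which
is why expected peaks are counted in multiples of 10**6). The helper's name
and return value come from the calls above; the body is an assumption:

import tensorflow as tf

def make_chain_tanh(n, size=250000):
  """Sketch: build a chain a00 -> tanh -> a01 -> ... , where every node
  holds `size` float32 values (~10**6 bytes, one "unit" of RAM)."""
  a = tf.Variable(tf.ones((size,)), name="a00")
  nodes = [a]
  for i in range(1, n):
    a = tf.tanh(a, name="a%02d" % i)
    nodes.append(a)
  return nodes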
Example #2
def test_chain():
    """Runs regular chain gradient, makes sure memory usage makes sense."""

    tf.reset_default_graph()
    tf_dev = tf.device('/cpu:0')
    tf_dev.__enter__()

    n = 5
    nodes = make_chain_tanh(n)
    a0 = nodes[0]
    a = nodes[-1]
    with tf.control_dependencies([a]):
        grad = tf.gradients([a], [a0])[0]

    #linearize_lib.linearize()

    sess = create_session()
    sessrun(tf.global_variables_initializer())

    sessrun(grad.op)

    peak_memory = cpu_peak()
    expected_peak = n * 10**6  # one unit per activation, including the "loss" tensor

    # sanity check that the run actually allocated the chain
    assert peak_memory > 2e6

    util.report_memory(peak_memory, expected_peak)
    if not REMOVE_ASSERTS:
        assert (peak_memory -
                expected_peak) < 1e6 + 10000, "Difference too large."
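
Read together with Example #1: the plain chain gradient above keeps all n = 5
forward activations alive at once, so its expected peak is n * 10**6 =
5 * 10**6 bytes, while the checkpointed rewrite in Example #1 starts from an
(n + 1)-unit baseline and subtracts one unit for each of the 2 recomputed
nodes, giving (n + 1 - 2) * 10**6 = 4 * 10**6 bytes.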
Example #5
def test_golden_order():
  tf.reset_default_graph()
  n = 5
  nodes = util.make_chain_tanh(n)
  a0 = nodes[0]
  a = nodes[-1]
  grad = tf.gradients([a], [a0])[0]
  
  order = linearize_lib.linearize(modify_graph=False)
  golden_order = ['a00/read', 'a01', 'a02', 'a03', 'gradients/Shape',
                  'gradients/grad_ys_0', 'gradients/Fill', 'a04',
                  'gradients/a04_grad/TanhGrad', 'gradients/a03_grad/TanhGrad',
                  'gradients/a02_grad/TanhGrad', 'gradients/a01_grad/TanhGrad',
                  'ones']

  observed_order = [op.name for op in order]
  assert observed_order == golden_order
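
With modify_graph=False, linearize computes and returns the execution order
(a list of ops) without inserting any control dependencies. Note the shape of
the golden order: each gradients/aNN_grad/TanhGrad op runs as soon as its
forward activation and upstream gradient are both available, which is the
memory-friendly schedule the linearizer is after.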
Example #6
def test_chain_linearize():
  tf.reset_default_graph()
  n = 5
  # create a chain with only a single valid execution order; using
  # make_chain_tanh_const doesn't work because its "shape_as_tensor" op is
  # not constrained (see "Running ones/shape_as_tensor after ones/Const")

  nodes = util.make_chain_tanh(n)
  a0 = nodes[0]
  a = nodes[-1]
  order1 = linearize_lib.obtain_linear_order()
  observed_order1 = [op.name for op in order1]

  # the order is already fully determined, so linearize adds no control deps
  num_new_deps = linearize_lib.linearize(targets=[a])
  assert num_new_deps == 0
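
For contrast, a graph with more than one admissible execution order should
need new control edges. A hypothetical sketch using the same linearize_lib
API as above (the branch ops and the expected count are assumptions, not
taken from the test suite):

tf.reset_default_graph()
x = tf.ones((100,), name="x")
b1 = tf.tanh(x, name="b1")       # two independent branches:
b2 = tf.sigmoid(x, name="b2")    # either one could execute first
out = tf.add(b1, b2, name="out")
num_new_deps = linearize_lib.linearize(targets=[out])
# one would expect num_new_deps >= 1 here, since at least one control
# dependency is needed to pin down the order of b1 and b2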
Example #7
def test_chain_rewrite_save_last():
  """Take chain of length 5, save only the last node. This saves no memory
  and is an edge case that should make the rewriter raise an exception."""

  tf.reset_default_graph()
  tf_dev = tf.device('/cpu:0')
  tf_dev.__enter__()
  
  n = 5

  a0, a1, a2, a3, a4 = make_chain_tanh(n)
  try:
    grad = memory_saving_gradients.gradients([a4], [a0], checkpoints=[a4])[0]
  except Exception:
    return  # expected: rewriter rejects a checkpoint set that saves nothing
  else:
    if not REMOVE_ASSERTS:
      assert False, "Should've been 'no checkpoints nodes found' exception"
Example #9
def test_articulation_points():
  tf.reset_default_graph()
  n = 5
  nodes = util.make_chain_tanh(n)
  a0 = nodes[0]
  a = nodes[-1]
  points = linearize_lib.sorted_articulation_points(targets=[a])
  # the chain is ['a00', 'a01', 'a02', 'a03', 'a04']; end-points are not
  # considered separators, so only the interior nodes are returned
  assert util.format_ops(points) == ['a01', 'a02', 'a03']
  
  tf.reset_default_graph()
  n = 5
  nodes = _make_simple_caterpillar_graph(n)
  a0 = nodes[0]
  a = nodes[-1]
  points = linearize_lib.sorted_articulation_points(None)

  assert util.format_ops(points) == ['merge0', 'merge1', 'merge2',
                                     'merge3', 'merge4', 'merge5']
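
For reference: an articulation point (cut vertex) is a node whose removal
disconnects the graph. In the chain a00 - a01 - a02 - a03 - a04, removing any
interior node splits the chain in two, while removing an endpoint leaves the
rest connected, which is why a00 and a04 are excluded above. In the
caterpillar graph, each merge node on the spine presumably separates the
dangling leaf branches from one another, so the whole merge0..merge5 spine is
returned.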