# Example No. 1
# 0
def test_resnet_rewrite_tarjan(linearize=False):
    """Check peak CPU memory of the "tarjan" rewrite on a short resnet (n=6).

    Builds ``make_resnet(6)`` and differentiates its last node with respect to
    its first using ``memory_saving_gradients.gradients_tarjan``. It then runs
    the gradient op once and compares ``cpu_peak()`` against the expected
    peak.

    Args:
        linearize: if True, call ``linearize_lib.linearize(grad.op)`` before
            the session is created, so the measured run uses the modified
            graph.
    """
    tf.reset_default_graph()
    # A ``with`` block (instead of a bare ``__enter__()`` that is never
    # exited) guarantees the device scope is popped even if the test fails.
    with tf.device('/cpu:0'):
        n = 6  # use n>5 (see test_chain_memory)

        nodes = make_resnet(n)
        a0 = nodes[0]
        a = nodes[-1]

        # gradients_tarjan is given no checkpoint list. The explicit list
        # [nodes[3], nodes[5]] ('a03_add:0', 'a05_add:0') that used to be
        # built here was never passed to it, so it has been dropped.
        grad = memory_saving_gradients.gradients_tarjan([a], [a0])[0]
        if linearize:
            linearize_lib.linearize(grad.op)

        # NOTE(review): `sess` is unused locally; the binding is kept in case
        # the session object must stay referenced while sessrun() executes.
        sess = create_session()
        sessrun(tf.global_variables_initializer())

        sessrun(grad.op)

        peak_memory = cpu_peak()
        expected_peak = 4e6
        util.report_memory(peak_memory, expected_peak)

        # One-sided check: fails only when the peak exceeds the expectation
        # by at least the tolerance; a lower-than-expected peak passes.
        if not REMOVE_ASSERTS:
            assert (peak_memory -
                    expected_peak) < 1.1 * 10**6, "Difference too large."
# Example No. 2
# 0
def test_long_chain_tarjan(linearize=False):
    """Like test_chain, but use automatic rewriting with checkpoints="tarjan"
    strategy.

    Builds a 100-node ``make_chain_tanh_constant`` chain and differentiates
    its last node with respect to its first via
    ``memory_saving_gradients.gradients_tarjan``. It then runs the gradient op
    once and compares ``cpu_peak()`` against the expected peak.

    Args:
        linearize: if True, call ``linearize_lib.linearize()`` before the
            session is created, so the measured run uses the modified graph.
    """
    tf.reset_default_graph()
    # A ``with`` block (instead of a bare ``__enter__()`` that is never
    # exited) guarantees the device scope is popped even if the test fails.
    with tf.device('/cpu:0'):
        n = 100

        nodes = make_chain_tanh_constant(n)
        a0 = nodes[0]
        a = nodes[-1]
        grad = memory_saving_gradients.gradients_tarjan([a], [a0])[0]

        # Linearize BEFORE the session is created and grad.op is run.
        # Calling it after sessrun(grad.op) means the graph changes it makes
        # (presumably added control dependencies; confirm in linearize_lib)
        # cannot influence the peak memory being measured.
        if linearize:
            linearize_lib.linearize()

        # NOTE(review): `sess` is unused locally; the binding is kept in case
        # the session object must stay referenced while sessrun() executes.
        sess = create_session()
        sessrun(tf.global_variables_initializer())

        sessrun(grad.op)

        peak_memory = cpu_peak()
        # points picked
        #  a09:0,a19:0,a29:0,a39:0,a49:0,a58:0,a68:0,a78:0,a88:0,a97:0
        expected_peak = 18e6
        util.report_memory(peak_memory, expected_peak)

        # todo: remove "REMOVE_ASSERTS"
        # One-sided check: fails only when the peak exceeds the expectation
        # by at least the tolerance; a lower-than-expected peak passes.
        if not REMOVE_ASSERTS:
            assert (peak_memory - expected_peak) < 1.1e6, (
                "Difference too large.")
# Example No. 3
# 0
def test_long_resnet_rewrite_tarjan(linearize=False):
    """Check peak CPU memory of the "tarjan" rewrite on a long resnet (n=100).

    Builds ``make_resnet(100)`` and differentiates its last node with respect
    to its first using ``memory_saving_gradients.gradients_tarjan``, with the
    gradient built under a control dependency on the forward output. It then
    runs the gradient op once and compares ``cpu_peak()`` against the
    expected peak.

    Args:
        linearize: if True, call ``linearize_lib.linearize(grad.op)`` before
            the session is created, so the measured run uses the modified
            graph.
    """
    tf.reset_default_graph()
    # A ``with`` block (instead of a bare ``__enter__()`` that is never
    # exited) guarantees the device scope is popped even if the test fails.
    with tf.device('/cpu:0'):
        n = 100
        nodes = make_resnet(n)
        a0 = nodes[0]
        a = nodes[-1]

        # Ops created inside this context get a control dependency on `a`.
        with tf.control_dependencies([a]):
            grad = memory_saving_gradients.gradients_tarjan([a], [a0])[0]

        if linearize:
            linearize_lib.linearize(grad.op)

        # NOTE(review): `sess` is unused locally; the binding is kept in case
        # the session object must stay referenced while sessrun() executes.
        sess = create_session()
        sessrun(tf.global_variables_initializer())

        sessrun(grad.op)

        peak_memory = cpu_peak()
        # 20 mem used with following tensors picked automatically
        # ['a10_add:0', 'a19_add:0', 'a28_add:0', 'a37_add:0', 'a46_add:0',
        # 'a55_add:0', 'a64_add:0', 'a73_add:0', 'a82_add:0', 'a91_add:0']

        expected_peak = 18 * 10**6
        util.report_memory(peak_memory, expected_peak)

        # One-sided check: fails only when the peak exceeds the expectation
        # by at least the tolerance; a lower-than-expected peak passes.
        if not REMOVE_ASSERTS:
            assert (peak_memory - expected_peak) < 1.1e6, (
                "Difference too large.")
# Example No. 4
# 0
def test_chain_tarjan(linearize=False):
    """Like test_chain, but use automatic rewriting with checkpoints="tarjan"
    strategy.

    Builds a 6-node ``util.make_chain_tanh_fill`` chain and differentiates its
    last node with respect to its first via
    ``memory_saving_gradients.gradients_tarjan``. It then runs the gradient op
    once and compares ``cpu_peak()`` against the expected peak.

    Args:
        linearize: if True, call ``linearize_lib.linearize()`` before the
            session is created, so the measured run uses the modified graph.
    """
    tf.reset_default_graph()
    # A ``with`` block (instead of a bare ``__enter__()`` that is never
    # exited) guarantees the device scope is popped even if the test fails.
    with tf.device('/cpu:0'):
        n = 6  # for n=5, only choice of a2 saves memory, and alg picks a3
        # hence use n>5 to avoid this edge condition

        nodes = util.make_chain_tanh_fill(n)
        a0 = nodes[0]
        a = nodes[-1]
        grad = memory_saving_gradients.gradients_tarjan([a], [a0])[0]

        # Linearize BEFORE the session is created and grad.op is run.
        # Calling it after sessrun(grad.op) means the graph changes it makes
        # (presumably added control dependencies; confirm in linearize_lib)
        # cannot influence the peak memory being measured.
        if linearize:
            linearize_lib.linearize()

        # NOTE(review): `sess` is unused locally; the binding is kept in case
        # the session object must stay referenced while sessrun() executes.
        sess = create_session()
        sessrun(tf.global_variables_initializer())

        sessrun(grad.op)

        peak_memory = cpu_peak()
        expected_peak = 5e6  # originally needed 7 units, now a3,a5 are recomputed
        util.report_memory(peak_memory, expected_peak)
        # One-sided check: fails only when the peak exceeds the expectation
        # by at least the tolerance; a lower-than-expected peak passes.
        if not REMOVE_ASSERTS:
            assert (peak_memory - expected_peak) < 1e5, "Difference too large."
def test_resnet_rewrite_tarjan(linearize=False):
    """Check peak CPU memory of the "tarjan" rewrite on a short resnet (n=6).

    Builds ``make_resnet(6)`` and differentiates its last node with respect to
    its first using ``memory_saving_gradients.gradients_tarjan``. It then runs
    the gradient op once and compares ``cpu_peak()`` against the expected
    peak.

    Args:
        linearize: if True, call ``linearize_lib.linearize(grad.op)`` before
            the session is created, so the measured run uses the modified
            graph.
    """
    tf.reset_default_graph()
    # A ``with`` block (instead of a bare ``__enter__()`` that is never
    # exited) guarantees the device scope is popped even if the test fails.
    with tf.device('/cpu:0'):
        n = 6  # use n>5 (see test_chain_memory)

        nodes = make_resnet(n)
        a0 = nodes[0]
        a = nodes[-1]

        # gradients_tarjan is given no checkpoint list. The explicit list
        # [nodes[3], nodes[5]] ('a03_add:0', 'a05_add:0') that used to be
        # built here was never passed to it, so it has been dropped.
        grad = memory_saving_gradients.gradients_tarjan([a], [a0])[0]
        if linearize:
            linearize_lib.linearize(grad.op)

        # NOTE(review): `sess` is unused locally; the binding is kept in case
        # the session object must stay referenced while sessrun() executes.
        sess = create_session()
        sessrun(tf.global_variables_initializer())

        sessrun(grad.op)

        peak_memory = cpu_peak()
        expected_peak = 4e6
        util.report_memory(peak_memory, expected_peak)

        # One-sided check: fails only when the peak exceeds the expectation
        # by at least the tolerance; a lower-than-expected peak passes.
        if not REMOVE_ASSERTS:
            assert (peak_memory -
                    expected_peak) < 1.1 * 10**6, "Difference too large."
def test_long_resnet_rewrite_tarjan(linearize=False):
    """Check peak CPU memory of the "tarjan" rewrite on a long resnet (n=100).

    Builds ``make_resnet(100)`` and differentiates its last node with respect
    to its first using ``memory_saving_gradients.gradients_tarjan``, with the
    gradient built under a control dependency on the forward output. It then
    runs the gradient op once and compares ``cpu_peak()`` against the
    expected peak.

    Args:
        linearize: if True, call ``linearize_lib.linearize(grad.op)`` before
            the session is created, so the measured run uses the modified
            graph.
    """
    tf.reset_default_graph()
    # A ``with`` block (instead of a bare ``__enter__()`` that is never
    # exited) guarantees the device scope is popped even if the test fails.
    with tf.device('/cpu:0'):
        n = 100
        nodes = make_resnet(n)
        a0 = nodes[0]
        a = nodes[-1]

        # Ops created inside this context get a control dependency on `a`.
        with tf.control_dependencies([a]):
            grad = memory_saving_gradients.gradients_tarjan([a], [a0])[0]

        if linearize:
            linearize_lib.linearize(grad.op)

        # NOTE(review): `sess` is unused locally; the binding is kept in case
        # the session object must stay referenced while sessrun() executes.
        sess = create_session()
        sessrun(tf.global_variables_initializer())

        sessrun(grad.op)

        peak_memory = cpu_peak()
        # 20 mem used with following tensors picked automatically
        # ['a10_add:0', 'a19_add:0', 'a28_add:0', 'a37_add:0', 'a46_add:0',
        # 'a55_add:0', 'a64_add:0', 'a73_add:0', 'a82_add:0', 'a91_add:0']

        expected_peak = 18 * 10**6
        util.report_memory(peak_memory, expected_peak)

        # One-sided check: fails only when the peak exceeds the expectation
        # by at least the tolerance; a lower-than-expected peak passes.
        if not REMOVE_ASSERTS:
            assert (peak_memory - expected_peak) < 1.1e6, (
                "Difference too large.")
def test_long_chain_tarjan(linearize=False):
    """Like test_chain, but use automatic rewriting with checkpoints="tarjan"
    strategy.

    Builds a 100-node ``make_chain_tanh_constant`` chain and differentiates
    its last node with respect to its first via
    ``memory_saving_gradients.gradients_tarjan``. It then runs the gradient op
    once and compares ``cpu_peak()`` against the expected peak.

    Args:
        linearize: if True, call ``linearize_lib.linearize()`` before the
            session is created, so the measured run uses the modified graph.
    """
    tf.reset_default_graph()
    # A ``with`` block (instead of a bare ``__enter__()`` that is never
    # exited) guarantees the device scope is popped even if the test fails.
    with tf.device('/cpu:0'):
        n = 100

        nodes = make_chain_tanh_constant(n)
        a0 = nodes[0]
        a = nodes[-1]
        grad = memory_saving_gradients.gradients_tarjan([a], [a0])[0]

        # Linearize BEFORE the session is created and grad.op is run.
        # Calling it after sessrun(grad.op) means the graph changes it makes
        # (presumably added control dependencies; confirm in linearize_lib)
        # cannot influence the peak memory being measured.
        if linearize:
            linearize_lib.linearize()

        # NOTE(review): `sess` is unused locally; the binding is kept in case
        # the session object must stay referenced while sessrun() executes.
        sess = create_session()
        sessrun(tf.global_variables_initializer())

        sessrun(grad.op)

        peak_memory = cpu_peak()
        # points picked
        #  a09:0,a19:0,a29:0,a39:0,a49:0,a58:0,a68:0,a78:0,a88:0,a97:0
        expected_peak = 18e6
        util.report_memory(peak_memory, expected_peak)

        # todo: remove "REMOVE_ASSERTS"
        # One-sided check: fails only when the peak exceeds the expectation
        # by at least the tolerance; a lower-than-expected peak passes.
        if not REMOVE_ASSERTS:
            assert (peak_memory - expected_peak) < 1.1e6, (
                "Difference too large.")
def test_chain_tarjan(linearize=False):
    """Like test_chain, but use automatic rewriting with checkpoints="tarjan"
    strategy.

    Builds a 6-node ``util.make_chain_tanh_fill`` chain and differentiates its
    last node with respect to its first via
    ``memory_saving_gradients.gradients_tarjan``. It then runs the gradient op
    once and compares ``cpu_peak()`` against the expected peak.

    Args:
        linearize: if True, call ``linearize_lib.linearize()`` before the
            session is created, so the measured run uses the modified graph.
    """
    tf.reset_default_graph()
    # A ``with`` block (instead of a bare ``__enter__()`` that is never
    # exited) guarantees the device scope is popped even if the test fails.
    with tf.device('/cpu:0'):
        n = 6  # for n=5, only choice of a2 saves memory, and alg picks a3
        # hence use n>5 to avoid this edge condition

        nodes = util.make_chain_tanh_fill(n)
        a0 = nodes[0]
        a = nodes[-1]
        grad = memory_saving_gradients.gradients_tarjan([a], [a0])[0]

        # Linearize BEFORE the session is created and grad.op is run.
        # Calling it after sessrun(grad.op) means the graph changes it makes
        # (presumably added control dependencies; confirm in linearize_lib)
        # cannot influence the peak memory being measured.
        if linearize:
            linearize_lib.linearize()

        # NOTE(review): `sess` is unused locally; the binding is kept in case
        # the session object must stay referenced while sessrun() executes.
        sess = create_session()
        sessrun(tf.global_variables_initializer())

        sessrun(grad.op)

        peak_memory = cpu_peak()
        expected_peak = 5e6  # originally needed 7 units, now a3,a5 are recomputed
        util.report_memory(peak_memory, expected_peak)
        # One-sided check: fails only when the peak exceeds the expectation
        # by at least the tolerance; a lower-than-expected peak passes.
        if not REMOVE_ASSERTS:
            assert (peak_memory - expected_peak) < 1e5, "Difference too large."