示例#1
0
def test_toposort():
    """Check that the first group yielded by toposort is the single op 'merge2'."""
    tf.reset_default_graph()
    # The return value is not needed; presumably the call adds the caterpillar
    # ops to the default graph -- verify against util.make_caterpillar_graph.
    util.make_caterpillar_graph(length=2)

    sorted_groups = list(toposort(linearize_lib.get_graph()))
    first_group = sorted_groups[0]
    assert len(first_group) == 1
    assert list(first_group)[0].name == 'merge2'
def test_toposort():
  """The first toposort group of a length-2 caterpillar graph is {'merge2'}."""
  tf.reset_default_graph()
  # Only the call matters here; its return value is never inspected.
  util.make_caterpillar_graph(length=2)

  dependency_graph = linearize_lib.get_graph()
  leading_group = list(toposort(dependency_graph))[0]
  assert len(leading_group) == 1

  leading_ops = list(leading_group)
  assert leading_ops[0].name == 'merge2'
示例#3
0
def test_prune():
    """prune_graph keeps the inputs of requested tensors and drops unrelated ops.

    Pruning is requested for the sum and for one standalone constant; a second,
    separate constant that was not requested must be absent from the result.
    """
    tf.reset_default_graph()

    # Ops are created in the same order as before so auto-generated names match.
    lhs = tf.constant([1, 2, 3])
    rhs = tf.constant([4, 5, 6])
    total = lhs + rhs
    requested_const = tf.constant([7, 8, 9])
    unrequested_const = tf.constant([7, 8, 9])

    full_graph = linearize_lib.get_graph()
    pruned = linearize_lib.prune_graph(full_graph, [total, requested_const])

    assert lhs.op in pruned
    assert unrequested_const.op not in pruned
def test_prune():
  """Pruning to [c, d] must retain a's op and must not retain e's op."""
  tf.reset_default_graph()

  # Creation order is kept unchanged so auto-generated op names stay the same.
  first_term = tf.constant([1, 2, 3])
  second_term = tf.constant([4, 5, 6])
  summed = first_term + second_term
  kept = tf.constant([7, 8, 9])
  left_out = tf.constant([7, 8, 9])

  targets = [summed, kept]
  pruned = linearize_lib.prune_graph(linearize_lib.get_graph(), targets)

  assert first_term.op in pruned
  assert left_out.op not in pruned
示例#5
0
def test_print():
  """Create a length-2 caterpillar graph and print what get_graph() returns.

  Should print:
  leaf1 -> merge1
  leaf0 -> merge0
  merge1 -> merge2
  merge0 -> merge1
  leaf2 -> merge2
  leaf0/shape -> leaf0
  leaf1/shape -> leaf1
  leaf2/shape -> leaf2
  """
  # The return value is unused; only the call itself is needed.
  make_caterpillar_graph(length=2)

  current_graph = linearize.get_graph()
  linearize.print_tf_graph(current_graph)
示例#6
0
def test_print():
    """Print the graph obtained after creating a length-2 caterpillar graph.

    Should print:
      leaf1 -> merge1
      leaf0 -> merge0
      merge1 -> merge2
      merge0 -> merge1
      leaf2 -> merge2
      leaf0/shape -> leaf0
      leaf1/shape -> leaf1
      leaf2/shape -> leaf2
    """
    make_caterpillar_graph(length=2)  # return value intentionally ignored

    graph_to_show = linearize.get_graph()
    linearize.print_tf_graph(graph_to_show)
示例#7
0
def test_print():
    """Reset the default graph, create a length-2 caterpillar, and print it.

    Should print:
      leaf1 -> merge1
      leaf0 -> merge0
      merge1 -> merge2
      merge0 -> merge1
      leaf2 -> merge2
      leaf0/shape -> leaf0
      leaf1/shape -> leaf1
      leaf2/shape -> leaf2
    """
    tf.reset_default_graph()
    # The returned nodes are not used below; only the call is needed.
    util.make_caterpillar_graph(length=2)

    current_graph = linearize_lib.get_graph()
    linearize_lib.print_graph(current_graph)
def test_print():
  """Print the result of get_graph() for a fresh length-2 caterpillar graph.

  Should print:
  leaf1 -> merge1
  leaf0 -> merge0
  merge1 -> merge2
  merge0 -> merge1
  leaf2 -> merge2
  leaf0/shape -> leaf0
  leaf1/shape -> leaf1
  leaf2/shape -> leaf2
  """
  tf.reset_default_graph()
  util.make_caterpillar_graph(length=2)  # return value intentionally ignored

  graph_to_show = linearize_lib.get_graph()
  linearize_lib.print_graph(graph_to_show)
示例#9
0
def recompute_tensor(target, known_values, preceding_op=None,
                     copy_known_values=False):
  """Creates a copy of the computation producing target and returns the copy.

  The ops needed for target are collected by walking backwards from it, then
  copied. If preceding_op is not None, control dependencies are added (through
  run_after) so that the newly created computation takes place after
  preceding_op. If preceding_op is None, no control dependencies are added.

  Args:
    target: tensor to recompute; must satisfy
      is_computable(target, known_values).
    known_values: tensors at which the backward walk stops, unless
      copy_known_values is set.
    preceding_op: op that the copied computation must run after, or None to
      leave the copy unconstrained.
    copy_known_values: if set, the backward walk does not stop at
      known_values, so they are copied too (for nicer graph visualization).

  Returns:
    The output of the copied op that sits at the same output position as
    target does in its original op.
  """

  assert is_computable(target, known_values)

  # position of target in parent op
  target_pos = list(target.op.outputs).index(target)

  if copy_known_values:
    computation = ge.get_backward_walk_ops(target)
  else:
    computation = ge.get_backward_walk_ops(target, stop_at_ts=known_values)

  # create copy of computation (the returned subgraph view is not needed)
  _, info = ge.copy_with_input_replacements(ge.sgv(computation), {})

  # find our target tensor in the new computation
  new_target_op = info._transformed_ops[target.op]
  new_target = new_target_op.outputs[target_pos]

  # Nothing to order against: previously the SAVE_ON_CONTROL_EDGES branch
  # still called run_after(op, None) here, contradicting the documented
  # contract and the guarded fallback branch.
  if preceding_op is None:
    return new_target

  new_computation = list(info._transformed_ops.values())

  # restrict computation to run after given op
  SAVE_ON_CONTROL_EDGES = True

  if SAVE_ON_CONTROL_EDGES:
    # only add "run_after" control dependencies to root of computation,
    # the rest automatically runs after because of data dependencies
    # TODO: more efficient implementation by walking back from new_target
    # instead of whole graph
    computation_graph = linearize_lib.get_graph(restrict_to=new_computation)

    # note, toposort order is reversed from networkx/mine convention
    computation_root = list(toposort.toposort(computation_graph))[-1]
    for op in computation_root:
      run_after(op, preceding_op)
  else:
    # fallback: constrain every copied op individually
    for op in new_computation:
      run_after(op, preceding_op)
  return new_target
示例#10
0
def test_reversed_graph():
    """Check node order and first edge of a graph before and after reversal."""
    tf.reset_default_graph()

    # Ops are created in exactly this order; the asserted names depend on it.
    values = tf.constant([1, 2, 3])
    offset = tf.constant([4, 5, 6])
    top = tf.nn.top_k(values)
    top[0] + top[1] + offset
    tf.constant([7, 8, 9])

    forward = linearize_lib.get_graph()

    # forward graph looks like this:
    #   Const -> TopKV2
    #   Const_1 -> add_1
    #   Const_2
    #   TopKV2 -> add
    #   TopKV2/k -> TopKV2
    #   add -> add_1
    #   add_1
    forward_nodes = list(forward.keys())
    first_forward = forward_nodes[0]
    assert first_forward.name == 'Const'
    assert forward_nodes[-1].name == 'add_1'
    assert list(forward[first_forward])[0].name == 'TopKV2'

    backward = linearize_lib.reversed_graph(forward, deterministic=True)

    # reversed graph looks like this:
    #   TopKV2 -> Const
    #   TopKV2 -> TopKV2/k
    #   Const
    #   add_1 -> Const_1
    #   add_1 -> add
    #   Const_1
    #   add -> TopKV2
    #   TopKV2/k
    #   Const_2
    backward_nodes = list(backward.keys())
    first_backward = backward_nodes[0]
    assert first_backward.name == 'TopKV2'
    assert backward_nodes[-1].name == 'Const_2'
    assert list(backward[first_backward])[0].name == 'Const'
def test_reversed_graph():
  """reversed_graph(deterministic=True) yields the expected order and edges."""
  tf.reset_default_graph()

  # Keep this creation order: the op names asserted below are auto-generated.
  source = tf.constant([1, 2, 3])
  addend = tf.constant([4, 5, 6])
  top_k_result = tf.nn.top_k(source)
  top_k_result[0] + top_k_result[1] + addend
  tf.constant([7, 8, 9])

  original = linearize_lib.get_graph()

  # original graph:
  #   Const -> TopKV2
  #   Const_1 -> add_1
  #   Const_2
  #   TopKV2 -> add
  #   TopKV2/k -> TopKV2
  #   add -> add_1
  #   add_1
  original_keys = list(original.keys())
  assert original_keys[0].name == 'Const'
  assert original_keys[-1].name == 'add_1'
  original_first_edges = list(original[original_keys[0]])
  assert original_first_edges[0].name == 'TopKV2'

  flipped = linearize_lib.reversed_graph(original, deterministic=True)

  # flipped graph:
  #   TopKV2 -> Const
  #   TopKV2 -> TopKV2/k
  #   Const
  #   add_1 -> Const_1
  #   add_1 -> add
  #   Const_1
  #   add -> TopKV2
  #   TopKV2/k
  #   Const_2
  flipped_keys = list(flipped.keys())
  assert flipped_keys[0].name == 'TopKV2'
  assert flipped_keys[-1].name == 'Const_2'
  flipped_first_edges = list(flipped[flipped_keys[0]])
  assert flipped_first_edges[0].name == 'Const'
示例#12
0
def test_toposort():
  """Print the toposort groups of the graph after creating a caterpillar."""
  make_caterpillar_graph(length=2)  # return value is not used
  sorted_groups = list(toposort.toposort(linearize.get_graph()))
  print(sorted_groups)
示例#13
0
def test_toposort():
    """Create a length-2 caterpillar graph and print its toposort as a list."""
    make_caterpillar_graph(length=2)  # only the call is needed, not its result

    dependency_graph = linearize.get_graph()
    ordering = toposort.toposort(dependency_graph)
    print(list(ordering))