コード例 #1
0
ファイル: transform.py プロジェクト: bikong2/tensorflow
def _add_control_flow_ops(ops, control_ios):
  """Complete `ops` so that the tranformed graph is valid.

  Partially copying a graph can lead to a malformed graph. For instance,
  copying half of a while construct is likely to result in an invalid graph.
  This function attempts to add missing ops so that the transformation result
  in a valid graph.

  Args:
    ops: list of ops (modifed in-place).
    control_ios: object created by a call to `util.ControlOutputs`.
  """
  # Find while contexts.
  control_flow_contexts = set()
  for op in ops:
    cfc = op._control_flow_context  # pylint: disable=protected-access
    if cfc:
      control_flow_contexts.add(cfc)
  # Find new ops.
  new_ops = []
  for cfc in control_flow_contexts:
    if cfc.IsWhileContext():
      new_ops += select.get_walks_intersection_ops(
          [enter_t.op for enter_t in cfc.loop_enters],
          [exit_t.op for exit_t in cfc.loop_exits],
          control_ios=control_ios)
  # Add new ops.
  new_ops_set = set(new_ops)
  ops_set = frozenset(ops)
  for op in new_ops_set:
    if op not in ops_set:
      ops.append(op)
コード例 #2
0
def _add_control_flow_ops(ops, control_ios):
  """Complete `ops` so that the transformed graph is valid.

  Partially copying a graph can lead to a malformed graph. For instance,
  copying half of a while construct is likely to result in an invalid graph.
  This function attempts to add missing ops so that the transformation result
  in a valid graph.

  Args:
    ops: list of ops (modifed in-place).
    control_ios: object created by a call to `util.ControlOutputs`.
  """
  # Find while contexts.
  control_flow_contexts = set()
  for op in ops:
    cfc = op._control_flow_context  # pylint: disable=protected-access
    if cfc:
      control_flow_contexts.add(cfc)
  # Find new ops.
  new_ops = []
  for cfc in control_flow_contexts:
    if cfc.IsWhileContext():
      new_ops += select.get_walks_intersection_ops(
          [enter_t.op for enter_t in cfc.loop_enters],
          [exit_t.op for exit_t in cfc.loop_exits],
          control_ios=control_ios)
  # Add new ops.
  new_ops_set = set(new_ops)
  ops_set = frozenset(ops)
  for op in new_ops_set:
    if op not in ops_set:
      ops.append(op)
コード例 #3
0
def clone_replace(f, replace):
    """Clone the subgraph computing `f` with the tensors in `replace` swapped.

    Args:
      f: a tf.Tensor or a nested structure of tf.Tensor.
      replace: dict mapping original tensors to their replacement tensors.

    Returns:
      The transformed counterpart of `f`, or `f` unchanged when the
      replacement tensors are disconnected from the targets.
    """
    flatten_target_ts = util.flatten_tree(f)
    # BUG FIX: `(tf_ops.Tensor)` is just `tf_ops.Tensor`, not a tuple; the
    # trailing comma makes check_types an actual one-element tuple of types.
    graph = util.get_unique_graph(flatten_target_ts,
                                  check_types=(tf_ops.Tensor,))
    control_ios = util.ControlOutputs(graph)
    ops = select.get_walks_intersection_ops(list(iterkeys(replace)),
                                            flatten_target_ts,
                                            control_ios=control_ios)
    if not ops:
        # This happens with disconnected inputs: nothing to rewrite.
        return f
    return tf.contrib.graph_editor.graph_replace(f, replace)
コード例 #4
0
def clone_replace(f, replace):
    """Return `f` recomputed with the tensors in `replace` substituted.

    When the replacement tensors are disconnected from the targets, `f` is
    returned unchanged instead of raising.
    """
    targets = util.flatten_tree(f)
    graph = util.get_unique_graph(targets, check_types=(tf_ops.Tensor))
    ctrl_outputs = util.ControlOutputs(graph)
    affected = select.get_walks_intersection_ops(
        list(iterkeys(replace)), targets, control_ios=ctrl_outputs)
    # Disconnected inputs yield an empty intersection; keep the original.
    if affected:
        return tf.contrib.graph_editor.graph_replace(f, replace)
    return f
コード例 #5
0
def graph_replace(target_ts,
                  replacement_ts,
                  dst_scope="",
                  src_scope="",
                  reuse_dst_scope=False):
    """Create a new graph which computes the targets from the replaced Tensors.

    Args:
      target_ts: a single tf.Tensor or an iterable of tf.Tensor.
      replacement_ts: dictionary mapping from original tensors to replaced
        tensors.
      dst_scope: the destination scope.
      src_scope: the source scope.
      reuse_dst_scope: if True the dst_scope is re-used if it already exists.
        Otherwise, the scope is given a unique name based on the one given
        by appending an underscore followed by a digit (default).

    Returns:
      A single tf.Tensor or a list of target tf.Tensor, depending on
      the type of the input argument `target_ts`. The returned tensors are
      recomputed using the tensors from replacement_ts.

    Raises:
      ValueError: if the targets are not connected to replacement_ts.
    """
    # Identify operations in the graph that will change.
    # Start forward walk at Tensors that will be replaced, and
    # backward walk at the target output Tensors.
    flatten_target_ts = util.flatten_tree(target_ts)
    # Construct the forward control dependencies edges so that
    # get_walks_intersection_ops can also traverse the control dependencies.
    # BUG FIX: trailing comma makes check_types an actual tuple of types;
    # `(tf_ops.Tensor)` was just the bare class.
    graph = util.get_unique_graph(flatten_target_ts,
                                  check_types=(tf_ops.Tensor,))
    control_ios = util.ControlOutputs(graph)
    ops = select.get_walks_intersection_ops(list(iterkeys(replacement_ts)),
                                            flatten_target_ts,
                                            control_ios=control_ios)
    if not ops:
        raise ValueError("Targets and replacements are not connected!")

    # Complete ops to avoid malformed control flow.
    # TODO (fkp): Consider moving this function deeper (in the transformer?). id:1300
    # https://github.com/imdone/tensorflow/issues/1301
    _add_control_flow_ops(ops, control_ios)

    # Create a copy of the relevant subgraph.
    unused_sgv_, info = copy_with_input_replacements(ops, replacement_ts, None,
                                                     dst_scope, src_scope,
                                                     reuse_dst_scope)
    # Return the transformed targets but keep the original if the transformed
    # counterpart cannot be found.
    missing_fn = lambda original_t: original_t
    return info.transformed(target_ts, missing_fn)
コード例 #6
0
ファイル: transform.py プロジェクト: bikong2/tensorflow
def graph_replace(target_ts, replacement_ts, dst_scope="",
                  src_scope="", reuse_dst_scope=False):
  """Recompute `target_ts` in a copied subgraph with tensors substituted.

  Args:
    target_ts: a single tf.Tensor or an iterable of tf.Tensor.
    replacement_ts: dictionary mapping from original tensors to replaced
      tensors.
    dst_scope: the destination scope.
    src_scope: the source scope.
    reuse_dst_scope: if True the dst_scope is re-used if it already exists.
      Otherwise, the scope is given a unique name based on the one given
      by appending an underscore followed by a digit (default).

  Returns:
    A single tf.Tensor or a list of target tf.Tensor, depending on
    the type of the input argument `target_ts`. The returned tensors are
    recomputed using the tensors from `replacement_ts`.

  Raises:
    ValueError: if the targets are not connected to replacement_ts.
  """
  # Walk forward from the tensors being replaced and backward from the
  # targets; the intersection is the set of ops that must be copied.
  targets = util.flatten_tree(target_ts)
  # Build the forward control-dependency edges so the walk can follow them.
  graph = util.get_unique_graph(targets, check_types=(tf_ops.Tensor))
  ctrl_outputs = util.ControlOutputs(graph)
  changed_ops = select.get_walks_intersection_ops(
      list(iterkeys(replacement_ts)), targets, control_ios=ctrl_outputs)
  if not changed_ops:
    raise ValueError("Targets and replacements are not connected!")

  # Complete the op set so partially-copied control flow stays valid.
  # TODO(fkp): Consider moving this function deeper (in the transformer?).
  _add_control_flow_ops(changed_ops, ctrl_outputs)

  # Copy the relevant subgraph, wiring in the replacement tensors.
  _, info = copy_with_input_replacements(
      changed_ops, replacement_ts, None, dst_scope, src_scope, reuse_dst_scope)

  # Fall back to the original tensor when no transformed counterpart exists.
  def _keep_original(original_t):
    return original_t

  return info.transformed(target_ts, _keep_original)
コード例 #7
0
ファイル: transform.py プロジェクト: vmichals/tensorflow
def graph_replace(target_ts, replacement_ts, dst_scope="",
                  src_scope="", reuse_dst_scope=False):
  """Create a new graph which computes the targets from the replaced Tensors.

  Args:
    target_ts: a single tf.Tensor or an iterable of tf.Tensor.
    replacement_ts: dictionary mapping from original tensors to replaced
      tensors.
    dst_scope: the destination scope.
    src_scope: the source scope.
    reuse_dst_scope: if True the dst_scope is re-used if it already exists.
      Otherwise, the scope is given a unique name based on the one given
      by appending an underscore followed by a digit (default).

  Returns:
    A single tf.Tensor or a list of target tf.Tensor, depending on
    the type of the input argument `target_ts`. The returned tensors are
    recomputed using the tensors from replacement_ts.

  Raises:
    ValueError: if the targets are not connected to replacement_ts.
  """
  # Identify operations in the graph that will change.
  # Start forward walk at Tensors that will be replaced, and
  # backward walk at the target output Tensors.
  flatten_target_ts = util.flatten_tree(target_ts)
  # Construct the forward control dependencies edges so that
  # the get_walks_intersection_ops can also traverse the
  # control dependencies.
  # BUG FIX: trailing comma makes check_types an actual tuple of types;
  # `(tf_ops.Tensor)` was just the bare class.
  graph = util.get_unique_graph(flatten_target_ts,
                                check_types=(tf_ops.Tensor,))
  control_ios = util.ControlOutputs(graph)
  # Materialize keys() so the walk receives a concrete list (matches the
  # newer versions of this function and is safe under Python 3 dict views).
  ops = select.get_walks_intersection_ops(list(replacement_ts.keys()),
                                          flatten_target_ts,
                                          control_ios=control_ios)
  if not ops:
    raise ValueError("Targets and replacements are not connected!")
  # Create a copy of the relevant subgraph.
  _, info = copy_with_input_replacements(
      ops, replacement_ts, None, dst_scope, src_scope, reuse_dst_scope)
  # Return the transformed targets.
  return info.transformed(target_ts)