Example #1
def get_descendants(x, collection=None):
    """Get descendant random variables of input.

    Parameters
    ----------
    x : RandomVariable or tf.Tensor
      Query node to find descendants of.
    collection : list of RandomVariable, optional
      The collection of random variables to check with respect to;
      defaults to all random variables in the graph.

    Returns
    -------
    list of RandomVariable
      Descendant random variables of x.

    Examples
    --------
    >>> a = Normal(mu=0.0, sigma=1.0)
    >>> b = Normal(mu=a, sigma=1.0)
    >>> c = Normal(mu=a, sigma=1.0)
    >>> d = Normal(mu=c, sigma=1.0)
    >>> set(ed.get_descendants(a)) == set([b, c, d])
    True
    """
    if collection is None:
        collection = random_variables()

    node_dict = {node.value(): node for node in collection}

    # Traverse the graph. Add each node to the set if it's in the collection.
    output = set()
    visited = set()
    nodes = {x}
    while nodes:
        node = nodes.pop()

        if node in visited:
            continue
        visited.add(node)

        if isinstance(node, RandomVariable):
            node = node.value()

        candidate_node = node_dict.get(node, None)
        if candidate_node is not None and candidate_node != x:
            output.add(candidate_node)

        for op in node.consumers():
            nodes.update(op.outputs)

    return list(output)
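Structurally, this is a plain worklist traversal: each popped node is resolved to its sample tensor, matched against `node_dict` (which maps each collection member's `value()` tensor back to the random variable), and the outputs of all consuming ops are queued. A minimal sketch of the same pattern, with hypothetical node names and a plain dict in place of the TensorFlow graph:

```python
# A minimal sketch of the same worklist traversal over a toy dict-based
# DAG; no TensorFlow needed. `toy_graph` and its node names are
# hypothetical stand-ins for tensors and their consumers' outputs.
toy_graph = {
    'a': ['b', 'c'],  # 'b' and 'c' consume 'a'
    'b': [],
    'c': ['d'],
    'd': [],
}

def toy_descendants(graph, x):
    output = set()
    visited = set()
    nodes = {x}
    while nodes:
        node = nodes.pop()
        if node in visited:
            continue  # the visited set keeps diamond-shaped fan-out linear
        visited.add(node)
        if node != x:
            output.add(node)
        nodes.update(graph[node])
    return output

assert toy_descendants(toy_graph, 'a') == {'b', 'c', 'd'}
```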
Example #2
def get_descendants(x, collection=None):
    """Get descendant random variables of input.

    Args:
      x: RandomVariable or tf.Tensor.
        Query node to find descendants of.
      collection: list of RandomVariable.
        The collection of random variables to check with respect to;
        defaults to all random variables in the graph.

    Returns:
      list of RandomVariable.
      Descendant random variables of x.

    #### Examples

    ```python
    a = Normal(0.0, 1.0)
    b = Normal(a, 1.0)
    c = Normal(a, 1.0)
    d = Normal(c, 1.0)
    assert set(ed.get_descendants(a)) == set([b, c, d])
    ```
    """
    if collection is None:
        collection = random_variables()

    node_dict = {node.value(): node for node in collection}

    # Traverse the graph. Add each node to the set if it's in the collection.
    output = set()
    visited = set()
    nodes = {x}
    while nodes:
        node = nodes.pop()

        if node in visited:
            continue
        visited.add(node)

        if isinstance(node, RandomVariable):
            node = node.value()

        candidate_node = node_dict.get(node, None)
        if candidate_node is not None and candidate_node != x:
            output.add(candidate_node)

        for op in node.consumers():
            nodes.update(op.outputs)

    return list(output)
Example #3
def get_descendants(x, collection=None):
  """Get descendant random variables of input.

  Args:
    x: RandomVariable or tf.Tensor.
      Query node to find descendants of.
    collection: list of RandomVariable.
      The collection of random variables to check with respect to;
      defaults to all random variables in the graph.

  Returns:
    list of RandomVariable.
    Descendant random variables of x.

  #### Examples

  ```python
  a = Normal(0.0, 1.0)
  b = Normal(a, 1.0)
  c = Normal(a, 1.0)
  d = Normal(c, 1.0)
  assert set(ed.get_descendants(a)) == set([b, c, d])
  ```
  """
  if collection is None:
    collection = random_variables()

  node_dict = {node.value(): node for node in collection}

  # Traverse the graph. Add each node to the set if it's in the collection.
  output = set()
  visited = set()
  nodes = {x}
  while nodes:
    node = nodes.pop()

    if node in visited:
      continue
    visited.add(node)

    if isinstance(node, RandomVariable):
      node = node.value()

    candidate_node = node_dict.get(node, None)
    if candidate_node is not None and candidate_node != x:
      output.add(candidate_node)

    for op in node.consumers():
      nodes.update(op.outputs)

  return list(output)
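For a sense of the `collection` argument, here is a short, hedged usage sketch (assuming Edward 1.x on TensorFlow 1.x; the positional `Normal(loc, scale)` form matches the docstring above, while older releases used `mu=`/`sigma=` keywords):

```python
# Usage sketch, assuming Edward 1.x on TensorFlow 1.x is installed.
import edward as ed
from edward.models import Normal

a = Normal(0.0, 1.0)
b = Normal(a, 1.0)
c = Normal(a, 1.0)
d = Normal(c, 1.0)

# Restricting the search to a user-supplied collection: `c` descends
# from `a` but is not a candidate, while `d` is still found because the
# traversal walks through `c`'s tensor regardless.
assert set(ed.get_descendants(a, collection=[b, d])) == set([b, d])
```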
Example #4
def get_parents(x, collection=None):
  """Get parent random variables of input.

  Parameters
  ----------
  x : RandomVariable or tf.Tensor
    Query node to find parents of.
  collection : list of RandomVariable, optional
    The collection of random variables to check with respect to;
    defaults to all random variables in the graph.

  Returns
  -------
  list of RandomVariable
    Parent random variables of x.

  Examples
  --------
  >>> a = Normal(0.0, 1.0)
  >>> b = Normal(a, 1.0)
  >>> c = Normal(0.0, 1.0)
  >>> d = Normal(b * c, 1.0)
  >>> assert set(ed.get_parents(d)) == set([b, c])
  """
  if collection is None:
    collection = random_variables()

  node_dict = {node.value(): node for node in collection}

  # Traverse the graph. Add each node to the set if it's in the collection.
  output = set()
  visited = set()
  nodes = {x}
  while nodes:
    node = nodes.pop()

    if node in visited:
      continue
    visited.add(node)

    if isinstance(node, RandomVariable):
      node = node.value()

    candidate_node = node_dict.get(node, None)
    if candidate_node is not None and candidate_node != x:
      output.add(candidate_node)
    else:
      nodes.update(node.op.inputs)

  return list(output)
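Note the `else`: unlike `get_descendants`, the backward walk stops as soon as it reaches a random variable in the collection, so only direct stochastic parents are returned, not all ancestors. A self-contained sketch of that stopping rule (hypothetical names, no TensorFlow):

```python
# Sketch of the stopping rule on a toy DAG; all names are hypothetical.
# `inputs` maps each node to its direct inputs, and `rvs` marks which
# nodes stand for random variables in the collection.
inputs = {'d': ['bc'], 'bc': ['b', 'c'], 'b': ['a'], 'c': [], 'a': []}
rvs = {'a', 'b', 'c', 'd'}

def toy_parents(x):
    output, nodes = set(), {x}
    while nodes:
        node = nodes.pop()
        if node in rvs and node != x:
            output.add(node)            # hit a random variable: stop here
        else:
            nodes.update(inputs[node])  # deterministic op: keep walking back
    return output

assert toy_parents('d') == {'b', 'c'}   # 'a' is an ancestor, not a parent
```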
Example #5
def get_descendants(x, collection=None):
  """Get descendant random variables of input.

  Parameters
  ----------
  x : RandomVariable or tf.Tensor
    Query node to find descendants of.
  collection : list of RandomVariable, optional
    The collection of random variables to check with respect to;
    defaults to all random variables in the graph.

  Returns
  -------
  list of RandomVariable
    Descendant random variables of x.

  Examples
  --------
  >>> a = Normal(mu=0.0, sigma=1.0)
  >>> b = Normal(mu=a, sigma=1.0)
  >>> c = Normal(mu=a, sigma=1.0)
  >>> d = Normal(mu=c, sigma=1.0)
  >>> set(ed.get_descendants(a)) == set([b, c, d])
  True
  """
  if collection is None:
    collection = random_variables()

  node_dict = {node.value(): node for node in collection}

  # Traverse the graph. Add each node to the set if it's in the collection.
  output = set()
  nodes = {x}
  while nodes:
    node = nodes.pop()
    if isinstance(node, RandomVariable):
      node = node.value()

    candidate_node = node_dict.get(node, None)
    if candidate_node is not None and candidate_node != x:
      output.add(candidate_node)

    for op in node.consumers():
      nodes.update(op.outputs)

  return list(output)
Example #6
def get_ancestors(x, collection=None):
    """Get ancestor random variables of input.

    Parameters
    ----------
    x : RandomVariable or tf.Tensor
      Query node to find ancestors of.
    collection : list of RandomVariable, optional
      The collection of random variables to check with respect to;
      defaults to all random variables in the graph.

    Returns
    -------
    list of RandomVariable
      Ancestor random variables of x.

    Examples
    --------
    >>> a = Normal(mu=0.0, sigma=1.0)
    >>> b = Normal(mu=a, sigma=1.0)
    >>> c = Normal(mu=0.0, sigma=1.0)
    >>> d = Normal(mu=tf.multiply(b, c), sigma=1.0)
    >>> set(ed.get_ancestors(d)) == set([a, b, c])
    True
    """
    if collection is None:
        collection = random_variables()

    node_dict = {node.value(): node for node in collection}

    # Traverse the graph. Add each node to the set if it's in the collection.
    output = set()
    nodes = {x}
    while nodes:
        node = nodes.pop()
        if isinstance(node, RandomVariable):
            node = node.value()

        candidate_node = node_dict.get(node, None)
        if candidate_node is not None and candidate_node != x:
            output.add(candidate_node)

        nodes.update(node.op.inputs)

    return list(output)
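The only substantive difference from `get_parents` is that `nodes.update(node.op.inputs)` runs unconditionally, so the walk continues past each random variable and collects everything upstream. Reusing the toy-DAG idea from the `get_parents` sketch above (names still hypothetical):

```python
# Same toy DAG as in the get_parents sketch; expansion is now
# unconditional, so all upstream random variables are collected.
inputs = {'d': ['bc'], 'bc': ['b', 'c'], 'b': ['a'], 'c': [], 'a': []}
rvs = {'a', 'b', 'c', 'd'}

def toy_ancestors(x):
    output, nodes = set(), {x}
    while nodes:
        node = nodes.pop()
        if node in rvs and node != x:
            output.add(node)
        nodes.update(inputs[node])  # always keep walking back
    return output

assert toy_ancestors('d') == {'a', 'b', 'c'}
```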
Example #7
def copy(org_instance,
         dict_swap=None,
         scope="copied",
         replace_itself=False,
         copy_q=False,
         copy_parent_rvs=True):
    """Build a new node in the TensorFlow graph from `org_instance`,
    where any of its ancestors existing in `dict_swap` are
    replaced with `dict_swap`'s corresponding value.

    Copying is done recursively. Any `Operation` whose output is
    required to copy `org_instance` is also copied (if it isn't already
    copied within the new scope).

    `tf.Variable`s, `tf.placeholder`s, and nodes of type `Queue` are
    always reused and not copied. In addition, `tf.Operation`s with
    operation-level seeds are copied with a new operation-level seed.

    Args:
      org_instance: RandomVariable, tf.Operation, tf.Tensor, or tf.Variable.
        Node to add in graph with replaced ancestors.
      dict_swap: dict.
        Random variables, variables, tensors, or operations to swap with.
        Its keys are what `org_instance` may depend on, and its values are
        the corresponding object (not necessarily of the same class
        instance, but must have the same type, e.g., float32) that is used
        in exchange.
      scope: str.
        A scope for the new node(s). This is used to avoid name
        conflicts with the original node(s).
      replace_itself: bool.
        Whether to replace `org_instance` itself if it exists in
        `dict_swap`. (This is used for the recursion.)
      copy_q: bool.
        Whether to copy the replaced tensors too (if not already
        copied within the new scope). Otherwise will reuse them.
      copy_parent_rvs: bool.
        Whether to copy parent random variables `org_instance` depends
        on. Otherwise will copy only the sample tensors and not the
        random variable class itself.

    Returns:
      RandomVariable, tf.Variable, tf.Tensor, or tf.Operation.
      The copied node.

    Raises:
      TypeError.
      If `org_instance` is not one of the above types.

    #### Examples

    ```python
    x = tf.constant(2.0)
    y = tf.constant(3.0)
    z = x * y

    qx = tf.constant(4.0)
    # The TensorFlow graph is currently
    # `x` -> `z` <- `y`, `qx`

    # This adds a subgraph with newly copied nodes,
    # `qx` -> `copied/z` <- `copied/y`
    z_new = ed.copy(z, {x: qx})

    sess = tf.Session()
    sess.run(z)      # 6.0
    sess.run(z_new)  # 12.0
    ```
    """
    if not isinstance(org_instance,
                      (RandomVariable, tf.Operation, tf.Tensor, tf.Variable)):
        raise TypeError("Could not copy instance: " + str(org_instance))

    if dict_swap is None:
        dict_swap = {}
    if scope[-1] != '/':
        scope += '/'

    # Swap instance if in dictionary.
    if org_instance in dict_swap and replace_itself:
        org_instance = dict_swap[org_instance]
        if not copy_q:
            return org_instance
    elif isinstance(org_instance, tf.Tensor) and replace_itself:
        # Deal with case when `org_instance` is the associated tensor
        # from the RandomVariable, e.g., `z.value()`. If
        # `dict_swap={z: qz}`, we aim to swap it with `qz.value()`.
        for key, value in six.iteritems(dict_swap):
            if isinstance(key, RandomVariable):
                if org_instance == key.value():
                    if isinstance(value, RandomVariable):
                        org_instance = value.value()
                    else:
                        org_instance = value
                    if not copy_q:
                        return org_instance
                    break

    # If instance is a tf.Variable, return it; do not copy any. Note we
    # check variables via their name. If we get variables through an
    # op's inputs, it has type tf.Tensor and not tf.Variable.
    if isinstance(org_instance, (tf.Tensor, tf.Variable)):
        for variable in tf.global_variables():
            if org_instance.name == variable.name:
                if variable in dict_swap and replace_itself:
                    # Deal with case when `org_instance` is the associated _ref
                    # tensor for a tf.Variable.
                    org_instance = dict_swap[variable]
                    if not copy_q or isinstance(org_instance, tf.Variable):
                        return org_instance
                    for variable in tf.global_variables():
                        if org_instance.name == variable.name:
                            return variable
                    break
                else:
                    return variable

    graph = tf.get_default_graph()
    new_name = scope + org_instance.name

    # If an instance of the same name exists, return it.
    if isinstance(org_instance, RandomVariable):
        for rv in random_variables():
            if new_name == rv.name:
                return rv
    elif isinstance(org_instance, (tf.Tensor, tf.Operation)):
        try:
            return graph.as_graph_element(new_name,
                                          allow_tensor=True,
                                          allow_operation=True)
        except (KeyError, ValueError):
            pass

    # Preserve ordering of random variables. Random variables are always
    # copied first (from parent -> child) before any deterministic
    # operations that depend on them.
    if copy_parent_rvs and \
            isinstance(org_instance, (RandomVariable, tf.Tensor, tf.Variable)):
        for v in get_parents(org_instance):
            copy(v, dict_swap, scope, True, copy_q, True)

    if isinstance(org_instance, RandomVariable):
        rv = org_instance

        # If it has copiable arguments, copy them.
        args = [
            _copy_default(arg, dict_swap, scope, True, copy_q, False)
            for arg in rv._args
        ]

        kwargs = {}
        for key, value in six.iteritems(rv._kwargs):
            if isinstance(value, list):
                kwargs[key] = [
                    _copy_default(v, dict_swap, scope, True, copy_q, False)
                    for v in value
                ]
            else:
                kwargs[key] = _copy_default(value, dict_swap, scope, True,
                                            copy_q, False)

        kwargs['name'] = new_name
        # Create new random variable with copied arguments.
        try:
            new_rv = type(rv)(*args, **kwargs)
        except ValueError:
            # Handle case where parameters are copied under an absolute name
            # scope. This can cause an error when creating a new random
            # variable, as tf.identity name ops are called on parameters ("op
            # with name already exists"). To avoid this, remove the absolute
            # name scope.
            kwargs['name'] = new_name[:-1]
            new_rv = type(rv)(*args, **kwargs)
        return new_rv
    elif isinstance(org_instance, tf.Tensor):
        tensor = org_instance

        # Do not copy tf.placeholders.
        if 'Placeholder' in tensor.op.type:
            return tensor

        # A tensor is one of the outputs of its underlying
        # op. Therefore copy the op itself.
        op = tensor.op
        new_op = copy(op, dict_swap, scope, True, copy_q, False)

        output_index = op.outputs.index(tensor)
        new_tensor = new_op.outputs[output_index]

        # Add copied tensor to collections that the original one is in.
        for name, collection in six.iteritems(tensor.graph._collections):
            if tensor in collection:
                graph.add_to_collection(name, new_tensor)

        return new_tensor
    elif isinstance(org_instance, tf.Operation):
        op = org_instance

        # Do not copy queue operations.
        if 'Queue' in op.type:
            return op

        # Copy the node def.
        # It is unique to every Operation instance. Replace the name and
        # its operation-level seed if it has one.
        node_def = deepcopy(op.node_def)
        node_def.name = new_name

        # When copying control flow contexts, we need to make sure
        # frame definitions are copied.
        if 'frame_name' in node_def.attr and node_def.attr[
                'frame_name'].s != b'':
            node_def.attr['frame_name'].s = (scope.encode('utf-8') +
                                             node_def.attr['frame_name'].s)

        if 'seed2' in node_def.attr and tf.get_seed(None)[1] is not None:
            node_def.attr['seed2'].i = tf.get_seed(None)[1]

        # Copy other arguments needed for initialization.
        output_types = op._output_types[:]

        # If it has an original op, copy it.
        if op._original_op is not None:
            original_op = copy(op._original_op, dict_swap, scope, True, copy_q,
                               False)
        else:
            original_op = None

        # Copy the op def.
        # It is unique to every Operation type.
        op_def = deepcopy(op.op_def)

        new_op = tf.Operation(
            node_def,
            graph,
            [],  # inputs; will add them afterwards
            output_types,
            [],  # control inputs; will add them afterwards
            [],  # input types; will add them afterwards
            original_op,
            op_def)

        # advertise op early to break recursions
        graph._add_op(new_op)

        # If it has control inputs, copy them.
        control_inputs = []
        for x in op.control_inputs:
            elem = copy(x, dict_swap, scope, True, copy_q, False)
            if not isinstance(elem, tf.Operation):
                elem = tf.convert_to_tensor(elem)

            control_inputs.append(elem)

        new_op._add_control_inputs(control_inputs)

        # If it has inputs, copy them.
        for x in op.inputs:
            elem = copy(x, dict_swap, scope, True, copy_q, False)
            if not isinstance(elem, tf.Operation):
                elem = tf.convert_to_tensor(elem)

            new_op._add_input(elem)

        # Copy the control flow context.
        control_flow_context = _copy_context(op._get_control_flow_context(),
                                             {}, dict_swap, scope, copy_q)
        new_op._set_control_flow_context(control_flow_context)

        # Use Graph's private methods to add the op, following
        # implementation of `tf.Graph().create_op()`.
        compute_shapes = True
        compute_device = True
        op_type = new_name

        if compute_shapes:
            set_shape_and_handle_data_for_outputs(new_op)
        graph._record_op_seen_by_control_dependencies(new_op)

        if compute_device:
            graph._apply_device_functions(new_op)

        if graph._colocation_stack:
            all_colocation_groups = []
            for colocation_op in graph._colocation_stack:
                all_colocation_groups.extend(colocation_op.colocation_groups())
                if colocation_op.device:
                    # Make this device match the device of the colocated op, to
                    # provide consistency between the device and the colocation
                    # property.
                    if new_op.device and new_op.device != colocation_op.device:
                        logging.warning(
                            "Tried to colocate %s with an op %s that had "
                            "a different device: %s vs %s. "
                            "Ignoring colocation property.", name,
                            colocation_op.name, new_op.device,
                            colocation_op.device)

            all_colocation_groups = sorted(set(all_colocation_groups))
            new_op.node_def.attr["_class"].CopyFrom(
                attr_value_pb2.AttrValue(
                    list=attr_value_pb2.AttrValue.ListValue(
                        s=all_colocation_groups)))

        # Sets "container" attribute if
        # (1) graph._container is not None
        # (2) "is_stateful" is set in OpDef
        # (3) "container" attribute is in OpDef
        # (4) "container" attribute is not already set
        if (graph._container and op_type in graph._registered_ops
                and graph._registered_ops[op_type].is_stateful
                and "container" in new_op.node_def.attr
                and not new_op.node_def.attr["container"].s):
            new_op.node_def.attr["container"].CopyFrom(
                attr_value_pb2.AttrValue(s=compat.as_bytes(graph._container)))

        return new_op
    else:
        raise TypeError("Could not copy instance: " + str(org_instance))
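In Edward, the usual reason to call `copy` is to rewire a model so a latent variable is replaced by its inferred approximation, e.g. when forming a posterior predictive. A hedged sketch (assuming Edward 1.x; the variables here are illustrative, with `qz` standing in for a fitted variational approximation):

```python
# Usage sketch, assuming Edward 1.x. `z`, `x`, and `qz` are
# illustrative: a latent variable, its likelihood, and a stand-in
# for an inferred approximation of z.
import edward as ed
from edward.models import Normal

z = Normal(0.0, 1.0)    # prior over the latent variable
x = Normal(z, 1.0)      # likelihood

qz = Normal(0.5, 0.1)   # stands in for a fitted posterior approximation

# Rebuild x's subgraph with z's sample tensor swapped for qz's; the
# new nodes live under the given scope to avoid name collisions.
x_post = ed.copy(x, {z: qz}, scope="posterior")
```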
Example #8
def copy(org_instance, dict_swap=None, scope="copied",
         replace_itself=False, copy_q=False, copy_parent_rvs=True):
  """Build a new node in the TensorFlow graph from `org_instance`,
  where any of its ancestors existing in `dict_swap` are
  replaced with `dict_swap`'s corresponding value.

  Copying is done recursively. Any `Operation` whose output is
  required to copy `org_instance` is also copied (if it isn't already
  copied within the new scope).

  `tf.Variable`s, `tf.placeholder`s, and nodes of type `Queue` are
  always reused and not copied. In addition, `tf.Operation`s with
  operation-level seeds are copied with a new operation-level seed.

  Args:
    org_instance: RandomVariable, tf.Operation, tf.Tensor, or tf.Variable.
      Node to add in graph with replaced ancestors.
    dict_swap: dict.
      Random variables, variables, tensors, or operations to swap with.
      Its keys are what `org_instance` may depend on, and its values are
      the corresponding object (not necessarily of the same class
      instance, but must have the same type, e.g., float32) that is used
      in exchange.
    scope: str.
      A scope for the new node(s). This is used to avoid name
      conflicts with the original node(s).
    replace_itself: bool.
      Whether to replace `org_instance` itself if it exists in
      `dict_swap`. (This is used for the recursion.)
    copy_q: bool.
      Whether to copy the replaced tensors too (if not already
      copied within the new scope). Otherwise will reuse them.
    copy_parent_rvs: bool.
      Whether to copy parent random variables `org_instance` depends
      on. Otherwise will copy only the sample tensors and not the
      random variable class itself.

  Returns:
    RandomVariable, tf.Variable, tf.Tensor, or tf.Operation.
    The copied node.

  Raises:
    TypeError.
    If `org_instance` is not one of the above types.

  #### Examples

  ```python
  x = tf.constant(2.0)
  y = tf.constant(3.0)
  z = x * y

  qx = tf.constant(4.0)
  # The TensorFlow graph is currently
  # `x` -> `z` <- `y`, `qx`

  # This adds a subgraph with newly copied nodes,
  # `qx` -> `copied/z` <- `copied/y`
  z_new = ed.copy(z, {x: qx})

  sess = tf.Session()
  sess.run(z)      # 6.0
  sess.run(z_new)  # 12.0
  ```
  """
  if not isinstance(org_instance,
                    (RandomVariable, tf.Operation, tf.Tensor, tf.Variable)):
    raise TypeError("Could not copy instance: " + str(org_instance))

  if dict_swap is None:
    dict_swap = {}
  if scope[-1] != '/':
    scope += '/'

  # Swap instance if in dictionary.
  if org_instance in dict_swap and replace_itself:
    org_instance = dict_swap[org_instance]
    if not copy_q:
      return org_instance
  elif isinstance(org_instance, tf.Tensor) and replace_itself:
    # Deal with case when `org_instance` is the associated tensor
    # from the RandomVariable, e.g., `z.value()`. If
    # `dict_swap={z: qz}`, we aim to swap it with `qz.value()`.
    for key, value in six.iteritems(dict_swap):
      if isinstance(key, RandomVariable):
        if org_instance == key.value():
          if isinstance(value, RandomVariable):
            org_instance = value.value()
          else:
            org_instance = value
          if not copy_q:
            return org_instance
          break

  # If instance is a tf.Variable, return it; do not copy any. Note we
  # check variables via their name. If we get variables through an
  # op's inputs, it has type tf.Tensor and not tf.Variable.
  if isinstance(org_instance, (tf.Tensor, tf.Variable)):
    for variable in tf.global_variables():
      if org_instance.name == variable.name:
        if variable in dict_swap and replace_itself:
          # Deal with case when `org_instance` is the associated _ref
          # tensor for a tf.Variable.
          org_instance = dict_swap[variable]
          if not copy_q or isinstance(org_instance, tf.Variable):
            return org_instance
          for variable in tf.global_variables():
            if org_instance.name == variable.name:
              return variable
          break
        else:
          return variable

  graph = tf.get_default_graph()
  new_name = scope + org_instance.name

  # If an instance of the same name exists, return it.
  if isinstance(org_instance, RandomVariable):
    for rv in random_variables():
      if new_name == rv.name:
        return rv
  elif isinstance(org_instance, (tf.Tensor, tf.Operation)):
    try:
      return graph.as_graph_element(new_name,
                                    allow_tensor=True,
                                    allow_operation=True)
    except (KeyError, ValueError):
      pass

  # Preserve ordering of random variables. Random variables are always
  # copied first (from parent -> child) before any deterministic
  # operations that depend on them.
  if copy_parent_rvs and \
          isinstance(org_instance, (RandomVariable, tf.Tensor, tf.Variable)):
    for v in get_parents(org_instance):
      copy(v, dict_swap, scope, True, copy_q, True)

  if isinstance(org_instance, RandomVariable):
    rv = org_instance

    # If it has copiable arguments, copy them.
    args = [_copy_default(arg, dict_swap, scope, True, copy_q, False)
            for arg in rv._args]

    kwargs = {}
    for key, value in six.iteritems(rv._kwargs):
      if isinstance(value, list):
        kwargs[key] = [_copy_default(v, dict_swap, scope, True, copy_q, False)
                       for v in value]
      else:
        kwargs[key] = _copy_default(
            value, dict_swap, scope, True, copy_q, False)

    kwargs['name'] = new_name
    # Create new random variable with copied arguments.
    try:
      new_rv = type(rv)(*args, **kwargs)
    except ValueError:
      # Handle case where parameters are copied under an absolute name
      # scope. This can cause an error when creating a new random
      # variable, as tf.identity name ops are called on parameters ("op
      # with name already exists"). To avoid this, remove the absolute
      # name scope.
      kwargs['name'] = new_name[:-1]
      new_rv = type(rv)(*args, **kwargs)
    return new_rv
  elif isinstance(org_instance, tf.Tensor):
    tensor = org_instance

    # Do not copy tf.placeholders.
    if 'Placeholder' in tensor.op.type:
      return tensor

    # A tensor is one of the outputs of its underlying
    # op. Therefore copy the op itself.
    op = tensor.op
    new_op = copy(op, dict_swap, scope, True, copy_q, False)

    output_index = op.outputs.index(tensor)
    new_tensor = new_op.outputs[output_index]

    # Add copied tensor to collections that the original one is in.
    for name, collection in six.iteritems(tensor.graph._collections):
      if tensor in collection:
        graph.add_to_collection(name, new_tensor)

    return new_tensor
  elif isinstance(org_instance, tf.Operation):
    op = org_instance

    # Do not copy queue operations.
    if 'Queue' in op.type:
      return op

    # Copy the node def.
    # It is unique to every Operation instance. Replace the name and
    # its operation-level seed if it has one.
    node_def = deepcopy(op.node_def)
    node_def.name = new_name

    # When copying control flow contexts, we need to make sure
    # frame definitions are copied.
    if 'frame_name' in node_def.attr and node_def.attr['frame_name'].s != b'':
      node_def.attr['frame_name'].s = (scope.encode('utf-8') +
                                       node_def.attr['frame_name'].s)

    if 'seed2' in node_def.attr and tf.get_seed(None)[1] is not None:
      node_def.attr['seed2'].i = tf.get_seed(None)[1]

    # Copy other arguments needed for initialization.
    output_types = op._output_types[:]

    # If it has an original op, copy it.
    if op._original_op is not None:
      original_op = copy(op._original_op, dict_swap, scope, True, copy_q, False)
    else:
      original_op = None

    # Copy the op def.
    # It is unique to every Operation type.
    op_def = deepcopy(op.op_def)

    new_op = tf.Operation(node_def,
                          graph,
                          [],  # inputs; will add them afterwards
                          output_types,
                          [],  # control inputs; will add them afterwards
                          [],  # input types; will add them afterwards
                          original_op,
                          op_def)

    # advertise op early to break recursions
    graph._add_op(new_op)

    # If it has control inputs, copy them.
    control_inputs = []
    for x in op.control_inputs:
      elem = copy(x, dict_swap, scope, True, copy_q, False)
      if not isinstance(elem, tf.Operation):
        elem = tf.convert_to_tensor(elem)

      control_inputs.append(elem)

    new_op._add_control_inputs(control_inputs)

    # If it has inputs, copy them.
    for x in op.inputs:
      elem = copy(x, dict_swap, scope, True, copy_q, False)
      if not isinstance(elem, tf.Operation):
        elem = tf.convert_to_tensor(elem)

      new_op._add_input(elem)

    # Copy the control flow context.
    control_flow_context = _copy_context(op._get_control_flow_context(), {},
                                         dict_swap, scope, copy_q)
    new_op._set_control_flow_context(control_flow_context)

    # Use Graph's private methods to add the op, following
    # implementation of `tf.Graph().create_op()`.
    compute_shapes = True
    compute_device = True
    op_type = new_name

    if compute_shapes:
      set_shapes_for_outputs(new_op)
    graph._record_op_seen_by_control_dependencies(new_op)

    if compute_device:
      graph._apply_device_functions(new_op)

    if graph._colocation_stack:
      all_colocation_groups = []
      for colocation_op in graph._colocation_stack:
        all_colocation_groups.extend(colocation_op.colocation_groups())
        if colocation_op.device:
          # Make this device match the device of the colocated op, to
          # provide consistency between the device and the colocation
          # property.
          if new_op.device and new_op.device != colocation_op.device:
            logging.warning("Tried to colocate %s with an op %s that had "
                            "a different device: %s vs %s. "
                            "Ignoring colocation property.",
                            new_op.name, colocation_op.name, new_op.device,
                            colocation_op.device)

      all_colocation_groups = sorted(set(all_colocation_groups))
      new_op.node_def.attr["_class"].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))

    # Sets "container" attribute if
    # (1) graph._container is not None
    # (2) "is_stateful" is set in OpDef
    # (3) "container" attribute is in OpDef
    # (4) "container" attribute is not already set
    if (graph._container and
        op_type in graph._registered_ops and
        graph._registered_ops[op_type].is_stateful and
        "container" in new_op.node_def.attr and
            not new_op.node_def.attr["container"].s):
      new_op.node_def.attr["container"].CopyFrom(
          attr_value_pb2.AttrValue(s=compat.as_bytes(graph._container)))

    return new_op
  else:
    raise TypeError("Could not copy instance: " + str(org_instance))
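Because every copied node is renamed under `scope`, the early name lookups (`random_variables()` for random variables, `graph.as_graph_element` for tensors and ops) make repeated calls with the same scope return the node already built rather than copying again. A quick sketch of the naming (hedged; exact tensor names depend on the TensorFlow version):

```python
# Naming sketch, assuming TensorFlow 1.x. Exact names can vary by
# version, but the scope prefix is added by `copy` itself.
import tensorflow as tf
import edward as ed

x = tf.constant(2.0, name="x")
z = tf.multiply(x, 3.0, name="z")
qx = tf.constant(4.0, name="qx")

z_new = ed.copy(z, {x: qx})
print(z.name)      # e.g. "z:0"
print(z_new.name)  # e.g. "copied/z:0"

# A second call with the same scope finds "copied/z:0" in the graph
# and returns it instead of building another copy.
assert ed.copy(z, {x: qx}) is z_new
```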