def get_descendants(x, collection=None):
    """Get descendant random variables of input.

    Starting from `x`, the graph is walked forward along each tensor's
    consuming ops. Every random variable in `collection` whose value
    tensor is reached is reported. Each node is expanded at most once.

    Parameters
    ----------
    x : RandomVariable or tf.Tensor
        Query node to find descendants of.
    collection : list of RandomVariable, optional
        The collection of random variables to check with respect
        to; defaults to all random variables in the graph.

    Returns
    -------
    list of RandomVariable
        Descendant random variables of x (never `x` itself).

    Examples
    --------
    >>> a = Normal(mu=0.0, sigma=1.0)
    >>> b = Normal(mu=a, sigma=1.0)
    >>> c = Normal(mu=a, sigma=1.0)
    >>> d = Normal(mu=c, sigma=1.0)
    >>> set(ed.get_descendants(a)) == set([b, c, d])
    True
    """
    rvs = collection if collection is not None else random_variables()
    # Map each random variable's value tensor back to its random variable.
    lookup = {rv.value(): rv for rv in rvs}

    result = set()
    seen = set()
    frontier = {x}
    while frontier:
        current = frontier.pop()
        if current in seen:
            continue
        seen.add(current)

        # Random variables are traversed through their value tensor.
        tensor = current.value() if isinstance(current, RandomVariable) else current

        hit = lookup.get(tensor)
        if hit is not None and hit != x:
            result.add(hit)

        for consumer in tensor.consumers():
            frontier.update(consumer.outputs)

    return list(result)
def get_descendants(x, collection=None):
    """Get descendant random variables of input.

    The graph is walked forward from `x`, following the ops that consume
    each tensor. Every random variable from `collection` whose value
    tensor is reached is reported. A node is expanded at most once, so
    shared subgraphs are not re-walked.

    Args:
      x: RandomVariable or tf.Tensor.
        Query node to find descendants of.
      collection: list of RandomVariable.
        The collection of random variables to check with respect
        to; defaults to all random variables in the graph.

    Returns:
      list of RandomVariable.
      Descendant random variables of x.

    #### Examples

    ```python
    a = Normal(0.0, 1.0)
    b = Normal(a, 1.0)
    c = Normal(a, 1.0)
    d = Normal(c, 1.0)
    assert set(ed.get_descendants(a)) == set([b, c, d])
    ```
    """
    rvs = random_variables() if collection is None else collection
    value_to_rv = {rv.value(): rv for rv in rvs}

    found = set()
    seen = set()
    frontier = {x}
    while frontier:
        current = frontier.pop()
        if current in seen:
            continue
        seen.add(current)

        if isinstance(current, RandomVariable):
            tensor = current.value()
        else:
            tensor = current

        owner = value_to_rv.get(tensor)
        # The query node itself is never reported as its own descendant.
        if owner is not None and owner != x:
            found.add(owner)

        for consumer in tensor.consumers():
            frontier.update(consumer.outputs)

    return list(found)
def get_descendants(x, collection=None):
    """Get descendant random variables of input.

    Performs a forward traversal from `x` through consuming ops and
    collects the random variables of `collection` whose value tensors
    are encountered. Already-expanded nodes are skipped.

    Args:
      x: RandomVariable or tf.Tensor.
        Query node to find descendants of.
      collection: list of RandomVariable.
        The collection of random variables to check with respect
        to; defaults to all random variables in the graph.

    Returns:
      list of RandomVariable.
      Descendant random variables of x.

    #### Examples

    ```python
    a = Normal(0.0, 1.0)
    b = Normal(a, 1.0)
    c = Normal(a, 1.0)
    d = Normal(c, 1.0)
    assert set(ed.get_descendants(a)) == set([b, c, d])
    ```
    """
    if collection is None:
        collection = random_variables()
    rv_by_tensor = dict((rv.value(), rv) for rv in collection)

    descendants = set()
    expanded = set()
    pending = set([x])
    while pending:
        item = pending.pop()
        if item in expanded:
            continue
        expanded.add(item)

        if isinstance(item, RandomVariable):
            item = item.value()

        owner = rv_by_tensor.get(item)
        if owner is not None and owner != x:
            descendants.add(owner)

        # Enqueue every output of every op that consumes this tensor.
        pending.update(out for consumer in item.consumers()
                       for out in consumer.outputs)

    return list(descendants)
def get_parents(x, collection=None):
    """Get parent random variables of input.

    Walks backward from `x` through op inputs. The walk stops along a
    path as soon as it reaches the value tensor of a random variable in
    `collection` (other than `x`). That random variable is recorded as
    a parent, and its own inputs are not explored.

    Parameters
    ----------
    x : RandomVariable or tf.Tensor
        Query node to find parents of.
    collection : list of RandomVariable, optional
        The collection of random variables to check with respect
        to; defaults to all random variables in the graph.

    Returns
    -------
    list of RandomVariable
        Parent random variables of x.

    Examples
    --------
    >>> a = Normal(0.0, 1.0)
    >>> b = Normal(a, 1.0)
    >>> c = Normal(0.0, 1.0)
    >>> d = Normal(b * c, 1.0)
    >>> assert set(ed.get_parents(d)) == set([b, c])
    """
    rvs = random_variables() if collection is None else collection
    value_to_rv = {rv.value(): rv for rv in rvs}

    parents = set()
    seen = set()
    frontier = {x}
    while frontier:
        current = frontier.pop()
        if current not in seen:
            seen.add(current)
            if isinstance(current, RandomVariable):
                tensor = current.value()
            else:
                tensor = current

            rv = value_to_rv.get(tensor)
            is_parent = rv is not None and rv != x
            if is_parent:
                # Stop here: anything above a parent is an ancestor, not a parent.
                parents.add(rv)
            else:
                frontier.update(tensor.op.inputs)

    return list(parents)
def get_descendants(x, collection=None):
    """Get descendant random variables of input.

    The graph is walked forward from `x` along each tensor's consuming
    ops. Every random variable in `collection` whose value tensor is
    reached is reported. Each node is expanded at most once.

    Parameters
    ----------
    x : RandomVariable or tf.Tensor
        Query node to find descendants of.
    collection : list of RandomVariable, optional
        The collection of random variables to check with respect
        to; defaults to all random variables in the graph.

    Returns
    -------
    list of RandomVariable
        Descendant random variables of x.

    Examples
    --------
    >>> a = Normal(0.0, 1.0)
    >>> b = Normal(a, 1.0)
    >>> c = Normal(a, 1.0)
    >>> d = Normal(c, 1.0)
    >>> set(ed.get_descendants(a)) == set([b, c, d])
    True
    """
    if collection is None:
        collection = random_variables()

    node_dict = {node.value(): node for node in collection}

    # Traverse the graph. Add each node to the set if it's in the collection.
    output = set()
    visited = set()
    nodes = {x}
    while nodes:
        node = nodes.pop()

        # Expand each node only once. Without this, shared subgraphs are
        # re-walked once per path, and graphs containing cycles (e.g. the
        # back edge of a tf.while_loop) never terminate.
        if node in visited:
            continue
        visited.add(node)

        if isinstance(node, RandomVariable):
            node = node.value()

        candidate_node = node_dict.get(node, None)
        # Test against None explicitly instead of relying on truthiness;
        # tensor-like objects may refuse conversion to a Python bool.
        if candidate_node is not None and candidate_node != x:
            output.add(candidate_node)

        for op in node.consumers():
            nodes.update(op.outputs)

    return list(output)
def get_ancestors(x, collection=None):
    """Get ancestor random variables of input.

    The graph is walked backward from `x` through op inputs, including
    through intermediate random variables. Every random variable in
    `collection` (other than `x`) whose value tensor is reached is
    reported. Each node is expanded at most once.

    Parameters
    ----------
    x : RandomVariable or tf.Tensor
        Query node to find ancestors of.
    collection : list of RandomVariable, optional
        The collection of random variables to check with respect
        to; defaults to all random variables in the graph.

    Returns
    -------
    list of RandomVariable
        Ancestor random variables of x.

    Examples
    --------
    >>> a = Normal(0.0, 1.0)
    >>> b = Normal(a, 1.0)
    >>> c = Normal(0.0, 1.0)
    >>> d = Normal(b * c, 1.0)
    >>> set(ed.get_ancestors(d)) == set([a, b, c])
    True
    """
    if collection is None:
        collection = random_variables()

    node_dict = {node.value(): node for node in collection}

    # Traverse the graph. Add each node to the set if it's in the collection.
    output = set()
    visited = set()
    nodes = {x}
    while nodes:
        node = nodes.pop()

        # Expand each node only once. Without this, shared ancestors are
        # re-walked once per path, and cyclic graphs (e.g. the back edge
        # of a tf.while_loop) never terminate.
        if node in visited:
            continue
        visited.add(node)

        if isinstance(node, RandomVariable):
            node = node.value()

        candidate_node = node_dict.get(node, None)
        # Test against None explicitly instead of relying on truthiness;
        # tensor-like objects may refuse conversion to a Python bool.
        if candidate_node is not None and candidate_node != x:
            output.add(candidate_node)

        # Unlike get_parents, keep walking past random variables so that
        # all ancestors, not just immediate parents, are collected.
        nodes.update(node.op.inputs)

    return list(output)
def copy(org_instance, dict_swap=None, scope="copied",
         replace_itself=False, copy_q=False, copy_parent_rvs=True):
    """Build a new node in the TensorFlow graph from `org_instance`,
    where any of its ancestors existing in `dict_swap` are
    replaced with `dict_swap`'s corresponding value.

    Copying is done recursively. Any `Operation` whose output is
    required to copy `org_instance` is also copied (if it isn't already
    copied within the new scope).

    `tf.Variable`s, `tf.placeholder`s, and nodes of type `Queue` are
    always reused and not copied. In addition, `tf.Operation`s with
    operation-level seeds are copied with a new operation-level seed.

    Args:
      org_instance: RandomVariable, tf.Operation, tf.Tensor, or tf.Variable.
        Node to add in graph with replaced ancestors.
      dict_swap: dict.
        Random variables, variables, tensors, or operations to swap with.
        Its keys are what `org_instance` may depend on, and its values are
        the corresponding object (not necessarily of the same class
        instance, but must have the same type, e.g., float32) that is used
        in exchange.
      scope: str.
        A scope for the new node(s). This is used to avoid name
        conflicts with the original node(s).
      replace_itself: bool.
        Whether to replace `org_instance` itself if it exists in
        `dict_swap`. (This is used for the recursion.)
      copy_q: bool.
        Whether to copy the replaced tensors too (if not already
        copied within the new scope). Otherwise will reuse them.
      copy_parent_rvs: bool.
        Whether to copy parent random variables `org_instance` depends
        on. Otherwise will copy only the sample tensors and not the
        random variable class itself.

    Returns:
      RandomVariable, tf.Variable, tf.Tensor, or tf.Operation.
      The copied node.

    Raises:
      TypeError.
      If `org_instance` is not one of the above types.

    #### Examples

    ```python
    x = tf.constant(2.0)
    y = tf.constant(3.0)
    z = x * y

    qx = tf.constant(4.0)
    # The TensorFlow graph is currently
    # `x` -> `z` <- `y`, `qx`

    # This adds a subgraph with newly copied nodes,
    # `qx` -> `copied/z` <- `copied/y`
    z_new = ed.copy(z, {x: qx})

    sess = tf.Session()
    sess.run(z)
    6.0
    sess.run(z_new)
    12.0
    ```
    """
    if not isinstance(org_instance,
                      (RandomVariable, tf.Operation, tf.Tensor, tf.Variable)):
        raise TypeError("Could not copy instance: " + str(org_instance))

    if dict_swap is None:
        dict_swap = {}
    if scope[-1] != '/':
        scope += '/'

    # Swap instance if in dictionary.
    if org_instance in dict_swap and replace_itself:
        org_instance = dict_swap[org_instance]
        if not copy_q:
            return org_instance
    elif isinstance(org_instance, tf.Tensor) and replace_itself:
        # Deal with case when `org_instance` is the associated tensor
        # from the RandomVariable, e.g., `z.value()`. If
        # `dict_swap={z: qz}`, we aim to swap it with `qz.value()`.
        for key, value in six.iteritems(dict_swap):
            if isinstance(key, RandomVariable):
                if org_instance == key.value():
                    if isinstance(value, RandomVariable):
                        org_instance = value.value()
                    else:
                        org_instance = value
                    if not copy_q:
                        return org_instance
                    break

    # If instance is a tf.Variable, return it; do not copy any. Note we
    # check variables via their name. If we get variables through an
    # op's inputs, it has type tf.Tensor and not tf.Variable.
    if isinstance(org_instance, (tf.Tensor, tf.Variable)):
        for variable in tf.global_variables():
            if org_instance.name == variable.name:
                if variable in dict_swap and replace_itself:
                    # Deal with case when `org_instance` is the associated _ref
                    # tensor for a tf.Variable.
                    org_instance = dict_swap[variable]
                    if not copy_q or isinstance(org_instance, tf.Variable):
                        return org_instance
                    for variable in tf.global_variables():
                        if org_instance.name == variable.name:
                            return variable
                    break
                else:
                    return variable

    graph = tf.get_default_graph()
    new_name = scope + org_instance.name

    # If an instance of the same name exists, return it.
    if isinstance(org_instance, RandomVariable):
        for rv in random_variables():
            if new_name == rv.name:
                return rv
    elif isinstance(org_instance, (tf.Tensor, tf.Operation)):
        try:
            return graph.as_graph_element(new_name,
                                          allow_tensor=True,
                                          allow_operation=True)
        except (KeyError, ValueError):
            # No element with this name has been copied yet; only
            # "not found"/"invalid name" lookups are expected here, so
            # anything else is allowed to propagate.
            pass

    # Preserve ordering of random variables. Random variables are always
    # copied first (from parent -> child) before any deterministic
    # operations that depend on them.
    if copy_parent_rvs and \
            isinstance(org_instance, (RandomVariable, tf.Tensor, tf.Variable)):
        for v in get_parents(org_instance):
            copy(v, dict_swap, scope, True, copy_q, True)

    if isinstance(org_instance, RandomVariable):
        rv = org_instance

        # If it has copiable arguments, copy them.
        args = [_copy_default(arg, dict_swap, scope, True, copy_q, False)
                for arg in rv._args]

        kwargs = {}
        for key, value in six.iteritems(rv._kwargs):
            if isinstance(value, list):
                kwargs[key] = [
                    _copy_default(v, dict_swap, scope, True, copy_q, False)
                    for v in value]
            else:
                kwargs[key] = _copy_default(
                    value, dict_swap, scope, True, copy_q, False)

        kwargs['name'] = new_name
        # Create new random variable with copied arguments.
        try:
            new_rv = type(rv)(*args, **kwargs)
        except ValueError:
            # Handle case where parameters are copied under absolute name
            # scope. This can cause an error when creating a new random
            # variable as tf.identity name ops are called on parameters ("op
            # with name already exists"). To avoid remove absolute name scope.
            kwargs['name'] = new_name[:-1]
            new_rv = type(rv)(*args, **kwargs)
        return new_rv
    elif isinstance(org_instance, tf.Tensor):
        tensor = org_instance

        # Do not copy tf.placeholders.
        if 'Placeholder' in tensor.op.type:
            return tensor

        # A tensor is one of the outputs of its underlying
        # op. Therefore copy the op itself.
        op = tensor.op
        new_op = copy(op, dict_swap, scope, True, copy_q, False)

        output_index = op.outputs.index(tensor)
        new_tensor = new_op.outputs[output_index]

        # Add copied tensor to collections that the original one is in.
        for collection_name, collection in six.iteritems(
                tensor.graph._collections):
            if tensor in collection:
                graph.add_to_collection(collection_name, new_tensor)

        return new_tensor
    elif isinstance(org_instance, tf.Operation):
        op = org_instance

        # Do not copy queue operations.
        if 'Queue' in op.type:
            return op

        # Copy the node def.
        # It is unique to every Operation instance. Replace the name and
        # its operation-level seed if it has one.
        node_def = deepcopy(op.node_def)
        node_def.name = new_name

        # When copying control flow contexts,
        # we need to make sure frame definitions are copied.
        if 'frame_name' in node_def.attr and \
                node_def.attr['frame_name'].s != b'':
            node_def.attr['frame_name'].s = (scope.encode('utf-8') +
                                             node_def.attr['frame_name'].s)

        if 'seed2' in node_def.attr and tf.get_seed(None)[1] is not None:
            node_def.attr['seed2'].i = tf.get_seed(None)[1]

        # Copy other arguments needed for initialization.
        output_types = op._output_types[:]

        # If it has an original op, copy it.
        if op._original_op is not None:
            original_op = copy(op._original_op, dict_swap, scope, True,
                               copy_q, False)
        else:
            original_op = None

        # Copy the op def.
        # It is unique to every Operation type.
        op_def = deepcopy(op.op_def)

        new_op = tf.Operation(node_def,
                              graph,
                              [],  # inputs; will add them afterwards
                              output_types,
                              [],  # control inputs; will add them afterwards
                              [],  # input types; will add them afterwards
                              original_op,
                              op_def)

        # Advertise op early to break recursions.
        graph._add_op(new_op)

        # If it has control inputs, copy them.
        control_inputs = []
        for x in op.control_inputs:
            elem = copy(x, dict_swap, scope, True, copy_q, False)
            if not isinstance(elem, tf.Operation):
                elem = tf.convert_to_tensor(elem)
            control_inputs.append(elem)
        new_op._add_control_inputs(control_inputs)

        # If it has inputs, copy them.
        for x in op.inputs:
            elem = copy(x, dict_swap, scope, True, copy_q, False)
            if not isinstance(elem, tf.Operation):
                elem = tf.convert_to_tensor(elem)
            new_op._add_input(elem)

        # Copy the control flow context.
        control_flow_context = _copy_context(op._get_control_flow_context(), {},
                                             dict_swap, scope, copy_q)
        new_op._set_control_flow_context(control_flow_context)

        # Use Graph's private methods to add the op, following
        # implementation of `tf.Graph().create_op()`.
        # `op_type` must be the op's *type* (e.g. "VariableV2"), since it
        # is looked up in `graph._registered_ops` below; the op's name
        # would never match there.
        op_type = op.type
        set_shape_and_handle_data_for_outputs(new_op)
        graph._record_op_seen_by_control_dependencies(new_op)
        graph._apply_device_functions(new_op)

        if graph._colocation_stack:
            all_colocation_groups = []
            for colocation_op in graph._colocation_stack:
                all_colocation_groups.extend(colocation_op.colocation_groups())
                if colocation_op.device:
                    # NOTE(review): unlike `tf.Graph.create_op`, the device
                    # of `new_op` is not reassigned here; a conflict only
                    # emits a warning.
                    if new_op.device and new_op.device != colocation_op.device:
                        logging.warning(
                            "Tried to colocate %s with an op %s that had "
                            "a different device: %s vs %s. "
                            "Ignoring colocation property.",
                            new_op.name, colocation_op.name,
                            new_op.device, colocation_op.device)

            all_colocation_groups = sorted(set(all_colocation_groups))
            new_op.node_def.attr["_class"].CopyFrom(attr_value_pb2.AttrValue(
                list=attr_value_pb2.AttrValue.ListValue(
                    s=all_colocation_groups)))

        # Sets "container" attribute if
        # (1) graph._container is not None
        # (2) "is_stateful" is set in OpDef
        # (3) "container" attribute is in OpDef
        # (4) "container" attribute is None
        if (graph._container and
                op_type in graph._registered_ops and
                graph._registered_ops[op_type].is_stateful and
                "container" in new_op.node_def.attr and
                not new_op.node_def.attr["container"].s):
            new_op.node_def.attr["container"].CopyFrom(
                attr_value_pb2.AttrValue(s=compat.as_bytes(graph._container)))

        return new_op
    else:
        raise TypeError("Could not copy instance: " + str(org_instance))
def copy(org_instance, dict_swap=None, scope="copied",
         replace_itself=False, copy_q=False, copy_parent_rvs=True):
    """Build a new node in the TensorFlow graph from `org_instance`,
    where any of its ancestors existing in `dict_swap` are
    replaced with `dict_swap`'s corresponding value.

    Copying is done recursively. Any `Operation` whose output is
    required to copy `org_instance` is also copied (if it isn't already
    copied within the new scope).

    `tf.Variable`s, `tf.placeholder`s, and nodes of type `Queue` are
    always reused and not copied. In addition, `tf.Operation`s with
    operation-level seeds are copied with a new operation-level seed.

    Args:
      org_instance: RandomVariable, tf.Operation, tf.Tensor, or tf.Variable.
        Node to add in graph with replaced ancestors.
      dict_swap: dict.
        Random variables, variables, tensors, or operations to swap with.
        Its keys are what `org_instance` may depend on, and its values are
        the corresponding object (not necessarily of the same class
        instance, but must have the same type, e.g., float32) that is used
        in exchange.
      scope: str.
        A scope for the new node(s). This is used to avoid name
        conflicts with the original node(s).
      replace_itself: bool.
        Whether to replace `org_instance` itself if it exists in
        `dict_swap`. (This is used for the recursion.)
      copy_q: bool.
        Whether to copy the replaced tensors too (if not already
        copied within the new scope). Otherwise will reuse them.
      copy_parent_rvs: bool.
        Whether to copy parent random variables `org_instance` depends
        on. Otherwise will copy only the sample tensors and not the
        random variable class itself.

    Returns:
      RandomVariable, tf.Variable, tf.Tensor, or tf.Operation.
      The copied node.

    Raises:
      TypeError.
      If `org_instance` is not one of the above types.

    #### Examples

    ```python
    x = tf.constant(2.0)
    y = tf.constant(3.0)
    z = x * y

    qx = tf.constant(4.0)
    # The TensorFlow graph is currently
    # `x` -> `z` <- `y`, `qx`

    # This adds a subgraph with newly copied nodes,
    # `qx` -> `copied/z` <- `copied/y`
    z_new = ed.copy(z, {x: qx})

    sess = tf.Session()
    sess.run(z)
    6.0
    sess.run(z_new)
    12.0
    ```
    """
    if not isinstance(org_instance,
                      (RandomVariable, tf.Operation, tf.Tensor, tf.Variable)):
        raise TypeError("Could not copy instance: " + str(org_instance))

    if dict_swap is None:
        dict_swap = {}
    if scope[-1] != '/':
        scope += '/'

    # Swap instance if in dictionary.
    if org_instance in dict_swap and replace_itself:
        org_instance = dict_swap[org_instance]
        if not copy_q:
            return org_instance
    elif isinstance(org_instance, tf.Tensor) and replace_itself:
        # Deal with case when `org_instance` is the associated tensor
        # from the RandomVariable, e.g., `z.value()`. If
        # `dict_swap={z: qz}`, we aim to swap it with `qz.value()`.
        for key, value in six.iteritems(dict_swap):
            if isinstance(key, RandomVariable):
                if org_instance == key.value():
                    if isinstance(value, RandomVariable):
                        org_instance = value.value()
                    else:
                        org_instance = value
                    if not copy_q:
                        return org_instance
                    break

    # If instance is a tf.Variable, return it; do not copy any. Note we
    # check variables via their name. If we get variables through an
    # op's inputs, it has type tf.Tensor and not tf.Variable.
    if isinstance(org_instance, (tf.Tensor, tf.Variable)):
        for variable in tf.global_variables():
            if org_instance.name == variable.name:
                if variable in dict_swap and replace_itself:
                    # Deal with case when `org_instance` is the associated _ref
                    # tensor for a tf.Variable.
                    org_instance = dict_swap[variable]
                    if not copy_q or isinstance(org_instance, tf.Variable):
                        return org_instance
                    for variable in tf.global_variables():
                        if org_instance.name == variable.name:
                            return variable
                    break
                else:
                    return variable

    graph = tf.get_default_graph()
    new_name = scope + org_instance.name

    # If an instance of the same name exists, return it.
    if isinstance(org_instance, RandomVariable):
        for rv in random_variables():
            if new_name == rv.name:
                return rv
    elif isinstance(org_instance, (tf.Tensor, tf.Operation)):
        try:
            return graph.as_graph_element(new_name,
                                          allow_tensor=True,
                                          allow_operation=True)
        except (KeyError, ValueError):
            # No element with this name has been copied yet; only
            # "not found"/"invalid name" lookups are expected here, so
            # anything else is allowed to propagate.
            pass

    # Preserve ordering of random variables. Random variables are always
    # copied first (from parent -> child) before any deterministic
    # operations that depend on them.
    if copy_parent_rvs and \
            isinstance(org_instance, (RandomVariable, tf.Tensor, tf.Variable)):
        for v in get_parents(org_instance):
            copy(v, dict_swap, scope, True, copy_q, True)

    if isinstance(org_instance, RandomVariable):
        rv = org_instance

        # If it has copiable arguments, copy them.
        args = [_copy_default(arg, dict_swap, scope, True, copy_q, False)
                for arg in rv._args]

        kwargs = {}
        for key, value in six.iteritems(rv._kwargs):
            if isinstance(value, list):
                kwargs[key] = [
                    _copy_default(v, dict_swap, scope, True, copy_q, False)
                    for v in value]
            else:
                kwargs[key] = _copy_default(
                    value, dict_swap, scope, True, copy_q, False)

        kwargs['name'] = new_name
        # Create new random variable with copied arguments.
        try:
            new_rv = type(rv)(*args, **kwargs)
        except ValueError:
            # Handle case where parameters are copied under absolute name
            # scope. This can cause an error when creating a new random
            # variable as tf.identity name ops are called on parameters ("op
            # with name already exists"). To avoid remove absolute name scope.
            kwargs['name'] = new_name[:-1]
            new_rv = type(rv)(*args, **kwargs)
        return new_rv
    elif isinstance(org_instance, tf.Tensor):
        tensor = org_instance

        # Do not copy tf.placeholders.
        if 'Placeholder' in tensor.op.type:
            return tensor

        # A tensor is one of the outputs of its underlying
        # op. Therefore copy the op itself.
        op = tensor.op
        new_op = copy(op, dict_swap, scope, True, copy_q, False)

        output_index = op.outputs.index(tensor)
        new_tensor = new_op.outputs[output_index]

        # Add copied tensor to collections that the original one is in.
        for collection_name, collection in six.iteritems(
                tensor.graph._collections):
            if tensor in collection:
                graph.add_to_collection(collection_name, new_tensor)

        return new_tensor
    elif isinstance(org_instance, tf.Operation):
        op = org_instance

        # Do not copy queue operations.
        if 'Queue' in op.type:
            return op

        # Copy the node def.
        # It is unique to every Operation instance. Replace the name and
        # its operation-level seed if it has one.
        node_def = deepcopy(op.node_def)
        node_def.name = new_name

        # When copying control flow contexts,
        # we need to make sure frame definitions are copied.
        if 'frame_name' in node_def.attr and \
                node_def.attr['frame_name'].s != b'':
            node_def.attr['frame_name'].s = (scope.encode('utf-8') +
                                             node_def.attr['frame_name'].s)

        if 'seed2' in node_def.attr and tf.get_seed(None)[1] is not None:
            node_def.attr['seed2'].i = tf.get_seed(None)[1]

        # Copy other arguments needed for initialization.
        output_types = op._output_types[:]

        # If it has an original op, copy it.
        if op._original_op is not None:
            original_op = copy(op._original_op, dict_swap, scope, True,
                               copy_q, False)
        else:
            original_op = None

        # Copy the op def.
        # It is unique to every Operation type.
        op_def = deepcopy(op.op_def)

        new_op = tf.Operation(node_def,
                              graph,
                              [],  # inputs; will add them afterwards
                              output_types,
                              [],  # control inputs; will add them afterwards
                              [],  # input types; will add them afterwards
                              original_op,
                              op_def)

        # Advertise op early to break recursions.
        graph._add_op(new_op)

        # If it has control inputs, copy them.
        control_inputs = []
        for x in op.control_inputs:
            elem = copy(x, dict_swap, scope, True, copy_q, False)
            if not isinstance(elem, tf.Operation):
                elem = tf.convert_to_tensor(elem)
            control_inputs.append(elem)
        new_op._add_control_inputs(control_inputs)

        # If it has inputs, copy them.
        for x in op.inputs:
            elem = copy(x, dict_swap, scope, True, copy_q, False)
            if not isinstance(elem, tf.Operation):
                elem = tf.convert_to_tensor(elem)
            new_op._add_input(elem)

        # Copy the control flow context.
        control_flow_context = _copy_context(op._get_control_flow_context(), {},
                                             dict_swap, scope, copy_q)
        new_op._set_control_flow_context(control_flow_context)

        # Use Graph's private methods to add the op, following
        # implementation of `tf.Graph().create_op()`.
        # `op_type` must be the op's *type* (e.g. "VariableV2"), since it
        # is looked up in `graph._registered_ops` below; the op's name
        # would never match there.
        op_type = op.type
        set_shapes_for_outputs(new_op)
        graph._record_op_seen_by_control_dependencies(new_op)
        graph._apply_device_functions(new_op)

        if graph._colocation_stack:
            all_colocation_groups = []
            for colocation_op in graph._colocation_stack:
                all_colocation_groups.extend(colocation_op.colocation_groups())
                if colocation_op.device:
                    # NOTE(review): unlike `tf.Graph.create_op`, the device
                    # of `new_op` is not reassigned here; a conflict only
                    # emits a warning.
                    if new_op.device and new_op.device != colocation_op.device:
                        logging.warning(
                            "Tried to colocate %s with an op %s that had "
                            "a different device: %s vs %s. "
                            "Ignoring colocation property.",
                            new_op.name, colocation_op.name,
                            new_op.device, colocation_op.device)

            all_colocation_groups = sorted(set(all_colocation_groups))
            new_op.node_def.attr["_class"].CopyFrom(attr_value_pb2.AttrValue(
                list=attr_value_pb2.AttrValue.ListValue(
                    s=all_colocation_groups)))

        # Sets "container" attribute if
        # (1) graph._container is not None
        # (2) "is_stateful" is set in OpDef
        # (3) "container" attribute is in OpDef
        # (4) "container" attribute is None
        if (graph._container and
                op_type in graph._registered_ops and
                graph._registered_ops[op_type].is_stateful and
                "container" in new_op.node_def.attr and
                not new_op.node_def.attr["container"].s):
            new_op.node_def.attr["container"].CopyFrom(
                attr_value_pb2.AttrValue(s=compat.as_bytes(graph._container)))

        return new_op
    else:
        raise TypeError("Could not copy instance: " + str(org_instance))