Beispiel #1
0
def get_fetchable_tensors(graph: tf.Graph,
                          names: List[str]) -> Dict[str, tf.Tensor]:
    """Retrieve fetch tensors from graph.

    Each entry of ``names`` may be either a tensor name or an operation
    name. An operation name is resolved to its single output tensor.

    Parameters
    ----------
    graph : tf.Graph
        A Tensorflow Graph
    names : List[str]
        List of operations or tensor names

    Returns
    -------
    Dict[str, tf.Tensor]
        Mapping of each requested name to its tf.Tensor, in the order the
        names were given.

    Raises
    ------
    ValueError
        If a name refers to an operation with no output or with more than
        one output, or if the resolved tensor is not fetchable in ``graph``.
        Errors raised by ``graph.as_graph_element`` for names that cannot be
        resolved are propagated unchanged.
    """
    fetchable_tensors = {}
    for name in names:
        op_or_tensor = graph.as_graph_element(name)
        if isinstance(op_or_tensor, tf.Tensor):
            tensor = op_or_tensor
        else:
            outputs = op_or_tensor.outputs
            # Guard the empty case explicitly: indexing ``outputs[0]`` on an
            # operation with no outputs would otherwise surface as an
            # opaque IndexError instead of a descriptive ValueError.
            if not outputs:
                raise ValueError(
                    f"Operation {op_or_tensor} has no output tensor")
            if len(outputs) > 1:
                raise ValueError(
                    f"Found more than one tensor for operation {op_or_tensor}")
            tensor = outputs[0]
        if not graph.is_fetchable(tensor):
            raise ValueError(f"{name} should be fetchable but is not")
        fetchable_tensors[name] = tensor
    return fetchable_tensors
Beispiel #2
0
def get_by_name(graph: tf.Graph, name: str):
    """Return op in Graph with name or None if not found.

    Parameters
    ----------
    graph : tf.Graph
        A Tensorflow Graph
    name : str
        Operation name to look up; compared by exact string equality.

    Returns
    -------
    tf.Operation or None
        The first operation in ``graph`` whose name equals ``name``, or None
        when no operation matches.
    """
    # Scan the live operation list rather than ``graph.as_graph_def().node``:
    # serializing the whole graph to a GraphDef on every lookup is costly for
    # large graphs and raises once the proto would be too large.
    for op in graph.get_operations():
        if op.name == name:
            return op
    return None