Example #1
0
def clone_keras_tensors(args, keras_tensor_mapping):
    """Return a copy of `args` in which every KerasTensor is cloned.

    KerasTensors cannot be shared between models, so cloning Nodes requires
    fresh KerasTensor objects. Each KerasTensor found in `args` is replaced by
    its clone. A tensor that already has an entry in `keras_tensor_mapping`
    (keyed by `id()` of the original) reuses that entry instead of being
    cloned again. Every other object is passed through unchanged.

    Args:
      args: A nested structure of objects, which could contain KerasTensor.
      keras_tensor_mapping: A dict mapping the `id()` of an original
        KerasTensor to its cloned KerasTensor. New clones created here are
        added to this dict in place.
    Returns:
      Same structure as inputs, with KerasTensor cloned.
    """

    def _lookup_or_clone(item):
        # Leave anything that is not a KerasTensor untouched.
        if not node_module.is_keras_tensor(item):
            return item
        key = id(item)
        if key in keras_tensor_mapping:
            # Already cloned earlier (in this call or a previous one): reuse.
            return keras_tensor_mapping[key]
        # First encounter: clone, carry over the original's history, memoize.
        clone = _clone_keras_tensor(item)
        clone._keras_history = item._keras_history
        keras_tensor_mapping[key] = clone
        return clone

    flat_items = tf.nest.flatten(args)
    replaced = [_lookup_or_clone(item) for item in flat_items]
    return tf.nest.pack_sequence_as(args, replaced)
Example #2
0
def is_input_keras_tensor(tensor):
    """Report whether `tensor` comes straight from `tf.keras.Input`.

    When a functional model is built from tensors that are not direct
    `tf.keras.Input` outputs, its Nodes and KerasTensors must be cloned. This
    predicate is used to detect that situation.

    Args:
      tensor: A `KerasTensor` as inputs to the functional model.

    Returns:
      bool. True if the Node that produced `tensor` is an input Node.

    Raises:
      ValueError: if the tensor is not a KerasTensor instance.
    """
    if node_module.is_keras_tensor(tensor):
        return tensor.node.is_input
    # Anything other than a KerasTensor is a caller error.
    raise ValueError(_KERAS_TENSOR_TYPE_CHECK_ERROR_MSG.format(tensor))
Example #3
0
def find_nodes_by_inputs_and_outputs(inputs, outputs):
    """Fetch all Nodes in the graph defined by "inputs" and "outputs".

    This method is used to find and then clone Nodes when creating a new
    sub-model from an existing functional model.

    The graph is walked breadth-first and bottom up. The walk starts from the
    Nodes that produce `outputs` and follows each Node's `keras_inputs`
    upstream. Tracing along a path stops when it reaches a tensor listed in
    `inputs`. Nodes are returned in the order they were first visited.

    Args:
      inputs: A nested structure of KerasTensor to use as model inputs.
      outputs: A nested structure of KerasTensor to use as model outputs.

    Returns:
      A list of Nodes that are connected to the inputs and outputs.

    Raises:
      ValueError: in any of these cases:
        - some of `inputs` are disconnected from `outputs`;
        - the walk reaches a `tf.keras.Input` tensor that is not listed in
          `inputs`;
        - `inputs`/`outputs` contain objects that are not KerasTensors.
    """
    # deque gives O(1) pops from the left; list.pop(0) is O(n) per pop and
    # made the traversal quadratic in the number of visited Nodes.
    import collections

    # We walk the graph bottom up, starting from output nodes, and keep tracing
    # the upstream node, until we find all the inputs nodes. We don't use top
    # down search here since we don't know whether a certain node is in the
    # graph between inputs and outputs, e.g. a functional graph could have
    # multiple outputs, and the user could choose a subset of them to build the
    # model. The bottom up approach will ensure all the nodes we visit are
    # actually in use. If we reach the top and didn't find the nodes in the
    # `inputs`, that's an error, since the user didn't specify the correct
    # inputs.
    start_keras_tensors = tf.nest.flatten(outputs)
    end_keras_tensors = tf.nest.flatten(inputs)

    for t in start_keras_tensors + end_keras_tensors:
        if not node_module.is_keras_tensor(t):
            raise ValueError(_KERAS_TENSOR_TYPE_CHECK_ERROR_MSG.format(t))
    # Tensors and Nodes are tracked by object identity via id().
    end_ids = {id(kt) for kt in end_keras_tensors}
    # Track all the end tensors we found so far, if we didn't reach all the
    # user-specified keras inputs after we finish the search, then that's an
    # error since the inputs are disconnected from the outputs.
    end_ids_found = set()

    # FIFO queue of Nodes still to process (breadth-first order).
    nodes_to_visit = collections.deque(t.node for t in start_keras_tensors)
    nodes_in_graph = []
    node_id_visited = set()

    while nodes_to_visit:
        node = nodes_to_visit.popleft()
        if id(node) in node_id_visited:
            continue
        node_id_visited.add(id(node))
        nodes_in_graph.append(node)
        # Any input keras_tensor that produce the current node.
        for kt in node.keras_inputs:
            if id(kt) in end_ids:
                # We found the inputs of the model, stop tracing upstream nodes
                end_ids_found.add(id(kt))
                continue

            inbound_node = kt.node
            # In case this is the tf.keras.Input node, we have reached the end
            # of the tracing of upstream nodes. Any further tracing will just be
            # an infinite loop. We should raise an error here since we didn't
            # find the input in the user-specified inputs.
            if inbound_node.is_input:
                raise ValueError(
                    "Found input tensor cannot be reached given provided "
                    "output tensors. Please make sure the tensor {} is "
                    "included in the model inputs when building "
                    "functional model.".format(kt))
            nodes_to_visit.append(inbound_node)

    # Do a final check and make sure we have reached all the user-specified
    # inputs
    if end_ids != end_ids_found:
        unvisited_inputs = [
            kt for kt in end_keras_tensors if id(kt) not in end_ids_found
        ]
        raise ValueError(
            "Found unvisited input tensors that are disconnected from "
            "the outputs: {}".format(unvisited_inputs))
    return nodes_in_graph