def _get_containing_xla_context(graph: tf.Graph) -> Optional[object]:
    """Find the innermost XLA control flow context enclosing `graph`'s current one.

    The search starts at the control flow context currently set on `graph`
    and follows `outer_context` links outward. The first context whose
    `IsXLAContext()` is truthy is returned; `None` is returned when the chain
    ends without such a context (or when the graph has no context at all).
    """
    # pylint: disable=protected-access
    candidate = graph._get_control_flow_context()
    # Keep climbing while there is a context and it is not an XLA one.
    while candidate and not candidate.IsXLAContext():
        candidate = candidate.outer_context
    return candidate if candidate else None
# Example no. 2
# 0
def updated_graph_flow_context_to_loop_context(graph: tf.Graph,
                                               preceeding_tensor: tf.Tensor):
    """
    Switches the graph's control flow context to the one of the op producing
    `preceeding_tensor`, and hands back the context that was active before.
    :param graph: TensorFlow Graph (tf.Graph) whose control flow context is replaced
    :param preceeding_tensor: TF tensor that feeds into the op which needs modification;
        the control flow context of its producing op becomes the graph's context
    :return: the control flow context the graph had before this call, so the
        caller can restore it later
    """

    # pylint: disable=protected-access
    # Remember the current context first; it is overwritten below.
    previous_context = graph._get_control_flow_context()

    producer_context = preceeding_tensor.op._get_control_flow_context()
    graph._set_control_flow_context(producer_context)

    return previous_context
# Example no. 3
# 0
def op_not_in_loop_control_flow_context(graph: tf.Graph,
                                        input_op: tf.Operation) -> bool:
    """
    Tells whether `input_op` lives outside any control flow context that
    differs from the graph's currently active one.
    :param graph: tf.Graph is the active graph
    :param input_op: op as tf.Operation
    :return: True if `input_op` has no control flow context, or if its context
        is the very same object as the graph's current context; False otherwise.
    """
    # pylint: disable=protected-access
    graph_context = graph._get_control_flow_context()
    op_context = input_op._get_control_flow_context()

    # Either the op has no control flow context at all, or it shares
    # (by identity) the context that is currently active on the graph.
    return not op_context or op_context is graph_context