def _get_containing_xla_context(graph: tf.Graph) -> Optional[object]:
    """
    Find the nearest enclosing XLA control flow context of `graph`.

    The search starts at the graph's current control flow context and
    follows `outer_context` links outward.

    :param graph: graph whose current control flow context is the start point
    :return: the first context whose `IsXLAContext()` is true, or None when
        the chain of contexts ends without finding one.
    """
    # pylint: disable=protected-access
    current = graph._get_control_flow_context()
    while True:
        if not current:
            # Ran out of enclosing contexts without meeting an XLA one.
            return None
        if current.IsXLAContext():
            return current
        current = current.outer_context
def updated_graph_flow_context_to_loop_context(graph: tf.Graph, preceeding_tensor: tf.Tensor):
    """
    Replace the graph's control flow context with that of a tensor's producing op.

    The caller receives the context that was active before the switch.
    This function does not restore it.

    :param graph: TensorFlow Graph (tf.Graph) whose control flow context is replaced
    :param preceeding_tensor: TF tensor feeding into the op that needs
        modification; the control flow context of its producing op becomes
        the graph's new context
    :return: the control flow context the graph had before the switch
    """
    # pylint: disable=protected-access
    previous_context = graph._get_control_flow_context()
    producer_context = preceeding_tensor.op._get_control_flow_context()
    graph._set_control_flow_context(producer_context)
    return previous_context
def op_not_in_loop_control_flow_context(graph: tf.Graph, input_op: tf.Operation) -> bool:
    """
    Tell whether `input_op` is outside a separate control flow context.

    :param graph: tf.Graph that is the active graph
    :param input_op: op (tf.Operation) whose control flow context is examined
    :return: True if `input_op` has no control flow context, or if its
        context is the very same object as the graph's active context;
        False otherwise.
    """
    # pylint: disable=protected-access
    graph_context = graph._get_control_flow_context()
    op_context = input_op._get_control_flow_context()
    # Either the op carries no context at all, or it shares the graph's
    # currently active context (identity comparison, not equality).
    return not op_context or op_context is graph_context