def _get_containing_xla_context(graph: tf.Graph) -> Optional[object]:
  """Finds the innermost XLA control-flow context enclosing `graph`'s current context.

  Starts at the graph's current control-flow context and follows each
  `outer_context` link outward. It stops at the first context whose
  `IsXLAContext()` returns true.

  Args:
    graph: The graph whose current control-flow context is inspected.

  Returns:
    The first context in the chain reporting `IsXLAContext()`. Returns `None`
    if no context in the chain does, including when the graph has no current
    control-flow context at all.
  """
  context = graph._get_control_flow_context()  # pylint: disable=protected-access
  # Advance outward while there is a context that is not an XLA context.
  while context and not context.IsXLAContext():
    context = context.outer_context
  # The loop exits either on an XLA context or on an empty/falsy link.
  return context if context else None
def _graph_dataset_iterator(ds_iter, graph: tf.Graph) -> Iterator[NumpyElem]:
  """Constructs a Python generator from a tf.data.Iterator.

  The iterator's `initializer` and `get_next()` ops are fetched inside
  `graph`. The initializer is run in the session provided by
  `utils.nogpu_session()`. The generator then repeatedly evaluates the
  `get_next()` op and yields each result until evaluation raises
  `tf.errors.OutOfRangeError`.

  Args:
    ds_iter: A `tf.data` iterator exposing an `initializer` attribute and a
      `get_next()` method, presumably created in `graph` (verify at callers).
    graph: The graph in which the iterator's ops are fetched and run.

  Yields:
    Each value returned by `sess.run` on the `get_next()` op.
  """
  # NOTE(review): both `with` blocks below stay entered while this generator
  # is suspended between elements. `graph` therefore appears to remain the
  # default graph for the consuming thread until the generator is exhausted
  # or closed. Confirm callers consume or close it promptly.
  with graph.as_default():
    init = ds_iter.initializer
    ds_item = ds_iter.get_next()
    with utils.nogpu_session() as sess:
      sess.run(init)
      while True:
        try:
          item = sess.run(ds_item)
        except tf.errors.OutOfRangeError:
          # The iterator has no more elements.
          break
        # Yield outside the `try`. An OutOfRangeError thrown into this
        # generator by the consumer must not be mistaken for end-of-data.
        yield item