def _f(*args):
  """Compile ``f`` for XLA on the basic-block device and invoke it.

  ``f`` and ``self`` are captured from the enclosing scope. Both the wrapping
  by ``xla.compile_nested_output`` and the call itself happen inside the
  ``tf.device(self._basic_block_xla_device)`` scope.

  Args:
    *args: Positional arguments forwarded unchanged to the compiled callable.

  Returns:
    Whatever the callable produced by ``xla.compile_nested_output`` returns
    for ``args``.
  """
  with tf.device(self._basic_block_xla_device):
    # Keep the wrapping inside the device scope, matching the original
    # single-expression form, where wrap-and-call both ran under tf.device.
    compiled_fn = xla.compile_nested_output(f, tf.xla.experimental.compile)
    return compiled_fn(*args)
def wrap_fn(self, f):
  """Wrap ``f`` with the XLA compiler appropriate for ``self.device``.

  If the substring ``'TPU'`` appears in ``self.device``, ``tf1.tpu.rewrite``
  is used as the compiler; otherwise ``tf.xla.experimental.compile`` is used.
  The chosen compiler is passed with ``f`` to ``xla.compile_nested_output``.

  Args:
    f: The function to wrap.

  Returns:
    The result of ``xla.compile_nested_output(f, compiler)``. This is
    presumably a callable; confirm against ``xla.compile_nested_output``.
  """
  # NOTE(review): this is a plain substring test on the device string and
  # assumes self.device is a str. Confirm against where self.device is set.
  if 'TPU' in self.device:
    compiler = tf1.tpu.rewrite
  else:
    compiler = tf.xla.experimental.compile
  return xla.compile_nested_output(f, compiler)