def test_tf_xla_forced_compile(self):
    """Runs the train/eval/predict check on 'tf' with forced XLA compile on.

    The test is currently skipped unconditionally (see the TODO below). When
    it does run, the forced-compile flag is switched on for the duration of
    the check and then set back to the value it had before.
    """
    # TODO(wangpeng): re-enable this test
    self.skipTest('Needs --config=cuda to pass this test')
    old_flag = backend.tf_xla_forced_compile_enabled()
    backend.set_tf_xla_forced_compile(True)
    try:
        self._test_train_eval_predict('tf')
    finally:
        # Restore the previous setting even when the check fails, so the
        # forced-compile flag does not leak into whatever runs afterwards.
        backend.set_tf_xla_forced_compile(old_flag)
def main(_):
    """Configures TensorFlow and JAX from flags, then runs training.

    The single positional argument is ignored.
    """
    logging.set_verbosity(FLAGS.log_level)

    # --- TensorFlow settings taken from flags -------------------------------
    if FLAGS.enable_eager_execution:
        tf.compat.v1.enable_eager_execution()

    if FLAGS.tf_xla:
        tf.config.optimizer.set_jit(True)
        backend.set_tf_xla_forced_compile(FLAGS.tf_xla_forced_compile)

    pin_to_host_option = {'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host}
    tf.config.optimizer.set_experimental_options(pin_to_host_option)
    layout_option = {'layout_optimizer': FLAGS.tf_opt_layout}
    tf.config.optimizer.set_experimental_options(layout_option)

    set_tf_allow_float64(FLAGS.tf_allow_float64)

    # Gin has to be configured before the backend name is queried below.
    _setup_gin()

    if FLAGS.enable_eager_execution and backend.get_name() in ('numpy', 'jax'):
        # Hide GPUs from TF: the numpy backend gains nothing from a GPU input
        # pipeline, and the jax backend would contend with TF for GPU memory.
        tf.config.experimental.set_visible_devices([], 'GPU')

    # --- Output directory ---------------------------------------------------
    requested_dir = FLAGS.output_dir or _default_output_dir()
    # The path is logged as given, before '~' expansion.
    trainer_lib.log('Using --output_dir %s' % requested_dir)
    output_dir = os.path.expanduser(requested_dir)

    # --- Tell JAX about the TPU, if one is in use ---------------------------
    if FLAGS.use_tpu:
        jax.config.update('jax_platform_name', 'tpu')
        jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend)
        jax.config.update('jax_backend_target', FLAGS.jax_backend_target)

    # --- Training -----------------------------------------------------------
    if not (FLAGS.use_tpu and backend.get_name() == 'tf'):
        trainer_lib.train(output_dir=output_dir)
    else:
        worker_cpu = tf_init_tpu()
        with tf.device(worker_cpu):
            if trainer_lib.num_devices() != 1:
                trainer_lib.train(output_dir=output_dir)
            else:
                # TF prefers GPU > CPU > TPU when picking a device, so the TPU
                # core must be made the default device explicitly.
                with tf.device('/device:TPU:0'):
                    trainer_lib.train(output_dir=output_dir)

    trainer_lib.log('Finished training.')
def _tf_setup_from_flags():
    """Applies the TensorFlow-related command-line flags.

    Covers eager execution, XLA JIT and forced compile, the two optimizer
    options (pin-to-host and layout), and whether float64 is allowed.
    """
    if FLAGS.enable_eager_execution:
        tf.compat.v1.enable_eager_execution()

    if FLAGS.tf_xla:
        tf.config.optimizer.set_jit(True)
        backend.set_tf_xla_forced_compile(FLAGS.tf_xla_forced_compile)

    # Both optimizer options are passed together in a single call.
    optimizer_options = dict(
        pin_to_host_optimization=FLAGS.tf_opt_pin_to_host,
        layout_optimizer=FLAGS.tf_opt_layout,
    )
    tf.config.optimizer.set_experimental_options(optimizer_options)

    tf_np.set_allow_float64(FLAGS.tf_allow_float64)