Example #1
def test_tf_xla_forced_compile(self):
  # TODO(wangpeng): re-enable this test
  self.skipTest('Needs --config=cuda to pass this test')
  # Save the current flag so it can be restored after the forced-compile run.
  old_flag = backend.tf_xla_forced_compile_enabled()
  backend.set_tf_xla_forced_compile(True)
  self._test_train_eval_predict('tf')
  backend.set_tf_xla_forced_compile(old_flag)
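
Note that the save/restore pattern above is not exception-safe: if _test_train_eval_predict raises, the old flag is never restored. A minimal context-manager sketch, assuming backend is the same trax backend module used in these examples (the trax import path and the helper name tf_xla_forced_compile_scope are assumptions):

import contextlib

from trax import backend  # Assumed import path for the backend module above.


@contextlib.contextmanager
def tf_xla_forced_compile_scope(enabled=True):
  """Temporarily sets the forced-XLA-compile flag and restores it on exit."""
  old_flag = backend.tf_xla_forced_compile_enabled()
  backend.set_tf_xla_forced_compile(enabled)
  try:
    yield
  finally:
    # Runs even if the body raises, unlike the inline pattern in Example #1.
    backend.set_tf_xla_forced_compile(old_flag)

With this helper, the test body reduces to: with tf_xla_forced_compile_scope(): self._test_train_eval_predict('tf').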
Example #2
def main(_):

  logging.set_verbosity(FLAGS.log_level)

  if FLAGS.enable_eager_execution:
    tf.compat.v1.enable_eager_execution()

  if FLAGS.tf_xla:
    tf.config.optimizer.set_jit(True)
    backend.set_tf_xla_forced_compile(FLAGS.tf_xla_forced_compile)

  tf.config.optimizer.set_experimental_options(
      {'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host}
  )

  tf.config.optimizer.set_experimental_options(
      {'layout_optimizer': FLAGS.tf_opt_layout}
  )

  set_tf_allow_float64(FLAGS.tf_allow_float64)

  _setup_gin()

  if FLAGS.enable_eager_execution and backend.get_name() in ('numpy', 'jax'):
    # The numpy backend doesn't benefit from running the input pipeline on the
    # GPU, and the jax backend contends with TF for GPU memory if TF uses the
    # GPU. Gin must be set up before the backend name is read.
    tf.config.experimental.set_visible_devices([], 'GPU')

  # Set up the output directory.
  output_dir = FLAGS.output_dir or _default_output_dir()
  trainer_lib.log('Using --output_dir %s' % output_dir)
  output_dir = os.path.expanduser(output_dir)

  # If on TPU, let JAX know.
  if FLAGS.use_tpu:
    jax.config.update('jax_platform_name', 'tpu')
    jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend)
    jax.config.update('jax_backend_target', FLAGS.jax_backend_target)

  if FLAGS.use_tpu and backend.get_name() == 'tf':
    worker_cpu = tf_init_tpu()
    with tf.device(worker_cpu):
      if trainer_lib.num_devices() == 1:
        # TF's device priority is GPU > CPU > TPU, so we need to explicitly make
        # the TPU core the default device here.
        with tf.device('/device:TPU:0'):
          trainer_lib.train(output_dir=output_dir)
      else:
        trainer_lib.train(output_dir=output_dir)
  else:
    trainer_lib.train(output_dir=output_dir)

  trainer_lib.log('Finished training.')
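
Example #2 reads several absl flags; a sketch of how those flags might be declared (the flag names come from the code above, while the defaults and help strings are assumptions):

from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string('output_dir', None, 'Where to write checkpoints and logs.')
flags.DEFINE_string('log_level', 'INFO', 'absl logging verbosity.')
flags.DEFINE_bool('use_tpu', False, 'Run on TPU and point JAX at it.')
flags.DEFINE_bool('enable_eager_execution', True, 'Enable TF eager execution.')
flags.DEFINE_bool('tf_xla', True, 'Turn on TF XLA JIT compilation.')
flags.DEFINE_bool('tf_xla_forced_compile', False, 'Force XLA compilation.')
flags.DEFINE_bool('tf_opt_pin_to_host', False, 'Enable the pin-to-host grappler pass.')
flags.DEFINE_bool('tf_opt_layout', False, 'Enable the layout-optimizer grappler pass.')
flags.DEFINE_bool('tf_allow_float64', False, 'Allow float64 in the TF-numpy backend.')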
Example #3
def _tf_setup_from_flags():
  """Processes TensorFlow-relevant flags."""
  if FLAGS.enable_eager_execution:
    tf.compat.v1.enable_eager_execution()
  if FLAGS.tf_xla:
    tf.config.optimizer.set_jit(True)
    backend.set_tf_xla_forced_compile(FLAGS.tf_xla_forced_compile)
  tf.config.optimizer.set_experimental_options({
      'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host,
      'layout_optimizer': FLAGS.tf_opt_layout,
  })
  tf_np.set_allow_float64(FLAGS.tf_allow_float64)
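
Example #3 factors Example #2's TensorFlow setup into a single helper. A minimal entry-point sketch wiring it together (the app.run boilerplate is an assumption; _setup_gin and trainer_lib are the names used in Example #2):

from absl import app


def main(_):
  _tf_setup_from_flags()  # TF eager/XLA/optimizer setup, as in Example #3.
  _setup_gin()            # Gin configuration, as in Example #2.
  trainer_lib.train(output_dir=FLAGS.output_dir)


if __name__ == '__main__':
  app.run(main)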