Exemplo n.º 1
0
def main(_):
    """Configure TensorFlow/JAX from command-line flags, then start training.

    Args:
        _: Unused positional argument (presumably the argv list passed in by
            the application runner -- confirm against the caller).
    """
    logging.set_verbosity(FLAGS.log_level)

    if FLAGS.enable_eager_execution:
        # Use the compat.v1 endpoint: the top-level `tf.enable_eager_execution`
        # alias is not part of the TF 2.x namespace, while the compat.v1 path
        # is available alongside the `tf.config.*` APIs used below.
        tf.compat.v1.enable_eager_execution()

    if FLAGS.tf_xla:
        tf.config.optimizer.set_jit(True)

    tf.config.optimizer.set_experimental_options(
        {'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host})

    tf.config.optimizer.set_experimental_options(
        {'layout_optimizer': FLAGS.tf_opt_layout})

    # Gin must be set up before the backend name is queried below.
    _setup_gin()

    if FLAGS.enable_eager_execution and backend.get_name() in ('numpy', 'jax'):
        # Numpy backend doesn't benefit from having the input pipeline run on
        # GPU, and jax backend has GPU memory contention if TF uses the GPU.
        # Passing an empty list hides every GPU from TensorFlow.
        tf.config.experimental.set_visible_devices([], 'GPU')

    # Resolve the output directory: the explicit flag wins, otherwise fall
    # back to the generated default. The path is logged before '~' expansion.
    output_dir = FLAGS.output_dir or _default_output_dir()
    trainer_lib.log('Using --output_dir %s' % output_dir)
    output_dir = os.path.expanduser(output_dir)

    # If on TPU, let JAX know.
    if FLAGS.use_tpu:
        jax.config.update('jax_platform_name', 'tpu')

    trainer_lib.train(output_dir=output_dir)
Exemplo n.º 2
0
def main(_):
  """Apply flag-driven TF/JAX settings, run training, and log completion.

  Args:
    _: Unused positional argument (presumably the argv list passed in by the
      application runner -- confirm against the caller).
  """
  logging.set_verbosity(FLAGS.log_level)

  if FLAGS.enable_eager_execution:
    tf.compat.v1.enable_eager_execution()

  if FLAGS.tf_xla:
    tf.config.optimizer.set_jit(True)
    backend.set_tf_xla_forced_compile(FLAGS.tf_xla_forced_compile)

  # Each optimizer option is applied through its own call, pin-to-host first.
  pin_to_host_option = {'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host}
  tf.config.optimizer.set_experimental_options(pin_to_host_option)

  layout_option = {'layout_optimizer': FLAGS.tf_opt_layout}
  tf.config.optimizer.set_experimental_options(layout_option)

  set_tf_allow_float64(FLAGS.tf_allow_float64)

  # Gin has to be configured before the backend name can be determined.
  _setup_gin()

  if FLAGS.enable_eager_execution and backend.get_name() in ('numpy', 'jax'):
    # The numpy backend gains nothing from a GPU-resident input pipeline, and
    # the jax backend would contend with TF for GPU memory, so hide all GPUs
    # from TensorFlow.
    tf.config.experimental.set_visible_devices([], 'GPU')

  # Pick the output directory: the flag if given, else the generated default.
  # The path is logged as-is and only afterwards has '~' expanded.
  chosen_dir = FLAGS.output_dir or _default_output_dir()
  trainer_lib.log('Using --output_dir %s' % chosen_dir)
  output_dir = os.path.expanduser(chosen_dir)

  # Tell JAX about the TPU when one is requested.
  if FLAGS.use_tpu:
    jax.config.update('jax_platform_name', 'tpu')
    jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend)
    jax.config.update('jax_backend_target', FLAGS.jax_backend_target)

  if not (FLAGS.use_tpu and backend.get_name() == 'tf'):
    # Anything other than the TF backend on TPU trains without device scopes.
    trainer_lib.train(output_dir=output_dir)
  else:
    worker_cpu = tf_init_tpu()
    with tf.device(worker_cpu):
      if trainer_lib.num_devices() != 1:
        trainer_lib.train(output_dir=output_dir)
      else:
        # TF's device priority is GPU > CPU > TPU, so the TPU core has to be
        # made the default device explicitly in the single-device case.
        with tf.device('/device:TPU:0'):
          trainer_lib.train(output_dir=output_dir)

  trainer_lib.log('Finished training.')
Exemplo n.º 3
0
def _default_output_dir():
    """Build the fallback output path used when no --output_dir is given.

    The directory name joins the configured model's name, the dataset name
    ('random' if querying 'inputs.dataset_name' from gin raises ValueError)
    and a minute-resolution timestamp. The result lives under '~/trax'; the
    '~' is returned unexpanded.

    Returns:
        The default output directory path as a string.
    """
    try:
        dataset = gin.query_parameter('inputs.dataset_name')
    except ValueError:
        # The gin query failed; fall back to a placeholder dataset name.
        dataset = 'random'

    model = gin.query_parameter('train.model').configurable.name
    stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M')
    name_template = '{model_name}_{dataset_name}_{timestamp}'
    run_name = name_template.format(
        model_name=model, dataset_name=dataset, timestamp=stamp)

    default_path = os.path.join('~', 'trax', run_name)
    # Emit a blank line on stdout ahead of the log notice.
    print()
    trainer_lib.log('No --output_dir specified')
    return default_path
Exemplo n.º 4
0
def _default_output_dir():
  """Return the '~/trax/<model>_<dataset>_<timestamp>' default output path.

  The dataset part comes from gin's "inputs.dataset_name" and falls back to
  "random" when that query raises ValueError. The timestamp has minute
  resolution. The leading '~' is not expanded here.
  """
  try:
    dataset = gin.query_parameter("inputs.dataset_name")
  except ValueError:
    # Query failed: use a placeholder instead of a real dataset name.
    dataset = "random"

  # Dict entries are evaluated top to bottom: model lookup, then timestamp.
  name_fields = {
      "model_name": gin.query_parameter("train.model").configurable.name,
      "dataset_name": dataset,
      "timestamp": datetime.datetime.now().strftime("%Y%m%d_%H%M"),
  }
  leaf = "{model_name}_{dataset_name}_{timestamp}".format(**name_fields)
  result = os.path.join("~", "trax", leaf)

  # Blank line on stdout, then tell the user a default is being used.
  print()
  trainer_lib.log("No --output_dir specified")
  return result