Example #1
def _train_using_tf(output_dir):
    worker_cpu = tf_init_tpu()
    with tf.device(worker_cpu):
        if trainer_lib.num_devices() == 1:
            # TF's device priority is GPU > CPU > TPU, so we need to explicitly make
            # the TPU core the default device here.
            with tf.device('/device:TPU:0'):
                trainer_lib.train(output_dir=output_dir)
        else:
            trainer_lib.train(output_dir=output_dir)
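
This snippet relies on a `tf_init_tpu` helper defined elsewhere in the same module; it connects to the TPU worker and returns a device string for the worker's host, so the outer `tf.device` scope keeps the input pipeline on that host's CPU. A minimal sketch of such a helper, assuming TensorFlow 2.x TPU cluster APIs (the name `tf_init_tpu_sketch` and the return values are illustrative, not the module's actual implementation):

import tensorflow as tf

def tf_init_tpu_sketch(worker='', protocol='grpc'):
    """Connects to the TPU system and returns a device scope for its host."""
    is_local = worker in ('', 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=worker)
    if not is_local:
        # For a remote worker, connect the TF runtime to the cluster before
        # the TPU system can be initialized.
        tf.config.experimental_connect_to_cluster(resolver, protocol=protocol)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    # An empty string leaves device placement unchanged; for a remote worker,
    # scope ops to that worker so the input pipeline runs on its host CPU.
    return '' if is_local else '/job:worker'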
Example #2
def main(_):

    logging.set_verbosity(FLAGS.log_level)

    if FLAGS.enable_eager_execution:
        tf.compat.v1.enable_eager_execution()

    if FLAGS.tf_xla:
        tf.config.optimizer.set_jit(True)
        backend.set_tf_xla_forced_compile(FLAGS.tf_xla_forced_compile)

    tf.config.optimizer.set_experimental_options(
        {'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host})

    tf.config.optimizer.set_experimental_options(
        {'layout_optimizer': FLAGS.tf_opt_layout})

    set_tf_allow_float64(FLAGS.tf_allow_float64)

    _setup_gin()

    if FLAGS.enable_eager_execution and backend.get_name() in ('numpy', 'jax'):
        # Numpy backend doesn't benefit from having the input pipeline run on GPU,
        # and jax backend has GPU memory contention if TF uses the GPU. Gin must be
        # set up first before determining the backend.
        tf.config.experimental.set_visible_devices([], 'GPU')

    # Set up the output directory.
    output_dir = FLAGS.output_dir or _default_output_dir()
    trainer_lib.log('Using --output_dir %s' % output_dir)
    output_dir = os.path.expanduser(output_dir)

    # If on TPU, let JAX know.
    if FLAGS.use_tpu:
        jax.config.update('jax_platform_name', 'tpu')
        jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend)
        jax.config.update('jax_backend_target', FLAGS.jax_backend_target)

    if FLAGS.use_tpu and backend.get_name() == 'tf':
        worker_cpu = tf_init_tpu()
        with tf.device(worker_cpu):
            if trainer_lib.num_devices() == 1:
                # TF's device priority is GPU > CPU > TPU, so we need to explicitly make
                # the TPU core the default device here.
                with tf.device('/device:TPU:0'):
                    trainer_lib.train(output_dir=output_dir)
            else:
                trainer_lib.train(output_dir=output_dir)
    else:
        trainer_lib.train(output_dir=output_dir)

    trainer_lib.log('Finished training.')
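
`main` is an absl-style entry point: `FLAGS` comes from `absl.flags`, `logging` is `absl.logging`, and the function is handed to `absl.app.run`. A minimal sketch of how the flags it reads might be declared and how the trainer is launched; the flag names match the ones read above, but the default values and help strings are assumptions, not the module's actual definitions:

from absl import app, flags, logging

FLAGS = flags.FLAGS

# Illustrative declarations for the flags referenced in main().
flags.DEFINE_string('output_dir', None, 'Directory for checkpoints and logs.')
flags.DEFINE_bool('use_tpu', False, 'Whether to run on TPU.')
flags.DEFINE_bool('enable_eager_execution', True, 'Enable TF eager execution.')
flags.DEFINE_bool('tf_xla', False, 'Turn on XLA JIT compilation for TF.')
flags.DEFINE_bool('tf_xla_forced_compile', False, 'Force XLA compilation.')
flags.DEFINE_bool('tf_opt_pin_to_host', False, 'Grappler pin-to-host optimization.')
flags.DEFINE_bool('tf_opt_layout', False, 'Grappler layout optimization.')
flags.DEFINE_bool('tf_allow_float64', False, 'Allow float64 in the TF backend.')
flags.DEFINE_string('log_level', 'INFO', 'Verbosity passed to absl logging.')
flags.DEFINE_string('jax_xla_backend', 'xla', 'JAX XLA backend to use on TPU.')
flags.DEFINE_string('jax_backend_target', 'local', 'JAX backend target address.')

if __name__ == '__main__':
    app.run(main)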