Example #1
def main(_):
    logging.set_verbosity(FLAGS.log_level)

    if FLAGS.enable_eager_execution:
        tf.compat.v1.enable_eager_execution()

    if FLAGS.tf_xla:
        tf.config.optimizer.set_jit(True)
        backend.set_tf_xla_forced_compile(FLAGS.tf_xla_forced_compile)

    # Pass experimental graph-optimizer options through from flags.
    tf.config.optimizer.set_experimental_options(
        {'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host})

    tf.config.optimizer.set_experimental_options(
        {'layout_optimizer': FLAGS.tf_opt_layout})

    set_tf_allow_float64(FLAGS.tf_allow_float64)

    _setup_gin()

    if FLAGS.enable_eager_execution and backend.get_name() in ('numpy', 'jax'):
        # The numpy backend doesn't benefit from running the input pipeline on
        # the GPU, and the jax backend contends with TF for GPU memory. Gin
        # must be set up before the backend is determined.
        tf.config.experimental.set_visible_devices([], 'GPU')

    # Set up the output directory.
    output_dir = FLAGS.output_dir or _default_output_dir()
    trainer_lib.log('Using --output_dir %s' % output_dir)
    output_dir = os.path.expanduser(output_dir)

    # If on TPU, let JAX know.
    if FLAGS.use_tpu:
        jax.config.update('jax_platform_name', 'tpu')
        jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend)
        jax.config.update('jax_backend_target', FLAGS.jax_backend_target)

    if FLAGS.use_tpu and backend.get_name() == 'tf':
        worker_cpu = tf_init_tpu()
        with tf.device(worker_cpu):
            if trainer_lib.num_devices() == 1:
                # TF's device priority is GPU > CPU > TPU, so we need to explicitly make
                # the TPU core the default device here.
                with tf.device('/device:TPU:0'):
                    trainer_lib.train(output_dir=output_dir)
            else:
                trainer_lib.train(output_dir=output_dir)
    else:
        trainer_lib.train(output_dir=output_dir)

    trainer_lib.log('Finished training.')
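The `main` functions on this page are written as absl app entry points: absl parses the command-line flags into `FLAGS` and then calls `main(argv)`. A minimal, hypothetical entry point under that assumption (only two of the many flags are registered here for brevity; the real module also defines --tf_xla, --use_tpu, --output_dir and the other flags used above):

from absl import app, flags, logging

flags.DEFINE_string('output_dir', None, 'Directory for checkpoints and logs.')
flags.DEFINE_string('log_level', 'INFO', 'Verbosity passed to absl logging.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
    app.run(main)  # Parses flags, then invokes main(argv).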
Example #2
def main(_):
    logging.set_verbosity(FLAGS.log_level)

    _tf_setup_from_flags()
    _gin_parse_configs()
    _jax_and_tf_configure_for_devices()

    output_dir = _output_dir_or_default()
    if FLAGS.use_tpu and math.backend_name() == 'tf':
        _train_using_tf(output_dir)
    else:
        trainer_lib.train(output_dir=output_dir)

    trainer_lib.log('Finished training.')
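Example #2 is Example #1 refactored into named helpers; their bodies are not shown on this page. `_tf_setup_from_flags` plausibly wraps the TF configuration block from Example #1. A sketch reconstructed under that assumption (not the verbatim library code):

def _tf_setup_from_flags():
    """Sketch: the TF-from-flags block of Example #1, factored out."""
    if FLAGS.enable_eager_execution:
        tf.compat.v1.enable_eager_execution()
    if FLAGS.tf_xla:
        tf.config.optimizer.set_jit(True)
        backend.set_tf_xla_forced_compile(FLAGS.tf_xla_forced_compile)
    tf.config.optimizer.set_experimental_options(
        {'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host})
    tf.config.optimizer.set_experimental_options(
        {'layout_optimizer': FLAGS.tf_opt_layout})
    set_tf_allow_float64(FLAGS.tf_allow_float64)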
Example #3
def _default_output_dir():
    """Default output directory."""
    try:
        dataset_name = gin.query_parameter('inputs.dataset_name')
    except ValueError:
        dataset_name = 'random'
    dir_name = '{model_name}_{dataset_name}_{timestamp}'.format(
        model_name=gin.query_parameter('train.model').configurable.name,
        dataset_name=dataset_name,
        timestamp=datetime.datetime.now().strftime('%Y%m%d_%H%M'),
    )
    dir_path = os.path.join('~', 'trax', dir_name)
    print()  # Blank line to visually separate the log output.
    trainer_lib.log('No --output_dir specified')
    return dir_path
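The directory name is assembled from gin bindings, so `_default_output_dir` only works after gin has been configured. A hypothetical usage sketch (the bindings are illustrative and assume trax's configurables are registered with gin; in practice they come from the config flags parsed during setup):

import gin

gin.parse_config("""
train.model = @trax.models.Transformer
inputs.dataset_name = 'wmt_ende'
""")

# With these bindings, the returned path looks like
# '~/trax/Transformer_wmt_ende_20240101_1200'.
print(_default_output_dir())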
Example #4
def main(_):
    logging.set_verbosity(FLAGS.log_level)

    _tf_setup_from_flags()
    _gin_parse_configs()
    _jax_and_tf_configure_for_devices()

    if FLAGS.disable_jit:
        fastmath.disable_jit()

    output_dir = _output_dir_or_default()
    if FLAGS.use_tpu and fastmath.is_backend(Backend.TFNP):
        _train_using_tf(output_dir)
    else:
        trainer_lib.train(output_dir=output_dir)

    trainer_lib.log('Finished training.')
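`_jax_and_tf_configure_for_devices` is likewise unshown; it plausibly collects the device logic from Example #1. A sketch reconstructed from that code (the grouping into one helper is an assumption):

def _jax_and_tf_configure_for_devices():
    """Sketch reconstructed from Example #1; not the verbatim library code."""
    if FLAGS.use_tpu:
        # Point JAX at the TPU backend.
        jax.config.update('jax_platform_name', 'tpu')
        jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend)
        jax.config.update('jax_backend_target', FLAGS.jax_backend_target)
    if FLAGS.enable_eager_execution and backend.get_name() in ('numpy', 'jax'):
        # The numpy backend gains nothing from a GPU input pipeline, and the
        # jax backend contends with TF for GPU memory, so hide the GPU from TF.
        tf.config.experimental.set_visible_devices([], 'GPU')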
Example #5
def main(_):
    logging.set_verbosity(FLAGS.log_level)

    _tf_setup_from_flags()
    _gin_parse_configs()
    _jax_and_tf_configure_for_devices()

    # Create a JAX GPU cluster if using JAX and given a chief IP.
    if fastmath.is_backend(Backend.JAX) and FLAGS.gpu_cluster_chief_ip:
        _make_jax_gpu_cluster(FLAGS.gpu_cluster_host_id,
                              FLAGS.gpu_cluster_chief_ip,
                              FLAGS.gpu_cluster_n_hosts,
                              FLAGS.gpu_cluster_port)

    if FLAGS.disable_jit:
        fastmath.disable_jit()

    output_dir = _output_dir_or_default()
    if FLAGS.use_tpu and fastmath.is_backend(Backend.TFNP):
        _train_using_tf(output_dir)
    else:
        trainer_lib.train(output_dir=output_dir)

    trainer_lib.log('Finished training.')
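`_make_jax_gpu_cluster` is also not shown. If it simply wires the cluster flags into JAX's multi-host setup, a minimal sketch could look like the following (the use of jax.distributed.initialize is an assumption, not confirmed by this page):

def _make_jax_gpu_cluster(host_id, chief_ip, n_hosts, port):
    """Sketch: join this process to a multi-host JAX GPU cluster."""
    # Assumption: the helper wraps jax.distributed.initialize.
    jax.distributed.initialize(
        coordinator_address='{}:{}'.format(chief_ip, port),
        num_processes=n_hosts,
        process_id=host_id)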
Example #6
def _output_dir_or_default():
    """Returns a path to the output directory."""
    if FLAGS.output_dir:
        output_dir = FLAGS.output_dir
        trainer_lib.log('Using --output_dir {}'.format(output_dir))
        return os.path.expanduser(output_dir)

    # Else, generate a default output dir (under the user's home directory).
    try:
        dataset_name = gin.query_parameter('inputs.dataset_name')
    except ValueError:
        dataset_name = 'random'
    output_name = '{model_name}_{dataset_name}_{timestamp}'.format(
        model_name=gin.query_parameter('train.model').configurable.name,
        dataset_name=dataset_name,
        timestamp=datetime.datetime.now().strftime('%Y%m%d_%H%M'),
    )
    output_dir = os.path.join('~', 'trax', output_name)
    output_dir = os.path.expanduser(output_dir)
    print()  # Blank line to visually separate the log output.
    trainer_lib.log('No --output_dir specified')
    trainer_lib.log('Using default output_dir: {}'.format(output_dir))
    return output_dir
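Finally, `_train_using_tf`, called from Examples #2, #4 and #5, plausibly contains the TPU branch that Example #1 still inlines. A sketch reconstructed from that code (not the verbatim library body):

def _train_using_tf(output_dir):
    """Sketch: Example #1's TPU/TF branch, factored out."""
    worker_cpu = tf_init_tpu()
    with tf.device(worker_cpu):
        if trainer_lib.num_devices() == 1:
            # TF's device priority is GPU > CPU > TPU, so explicitly make the
            # TPU core the default device.
            with tf.device('/device:TPU:0'):
                trainer_lib.train(output_dir=output_dir)
        else:
            trainer_lib.train(output_dir=output_dir)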