def main(_):
  """Configures TF/JAX from command-line flags, sets up gin, and trains.

  Args:
    _: Unused; absl.app passes the leftover argv here.
  """
  logging.set_verbosity(FLAGS.log_level)

  # Opt-in TF execution modes.
  if FLAGS.enable_eager_execution:
    tf.compat.v1.enable_eager_execution()
  if FLAGS.tf_xla:
    tf.config.optimizer.set_jit(True)
    backend.set_tf_xla_forced_compile(FLAGS.tf_xla_forced_compile)

  # Grappler optimizer options and float64 permission, applied unconditionally.
  tf.config.optimizer.set_experimental_options(
      {'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host})
  tf.config.optimizer.set_experimental_options(
      {'layout_optimizer': FLAGS.tf_opt_layout})
  set_tf_allow_float64(FLAGS.tf_allow_float64)

  _setup_gin()

  if FLAGS.enable_eager_execution and backend.get_name() in ('numpy', 'jax'):
    # Numpy backend doesn't benefit from having the input pipeline run on GPU,
    # and jax backend has GPU memory contention if TF uses the GPU. Gin must be
    # set up first before determining the backend.
    tf.config.experimental.set_visible_devices([], 'GPU')

  # Resolve the output directory (flag wins over the generated default).
  out_dir = FLAGS.output_dir or _default_output_dir()
  trainer_lib.log('Using --output_dir %s' % out_dir)
  out_dir = os.path.expanduser(out_dir)

  # If on TPU, let JAX know.
  if FLAGS.use_tpu:
    jax.config.update('jax_platform_name', 'tpu')
    jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend)
    jax.config.update('jax_backend_target', FLAGS.jax_backend_target)

  if FLAGS.use_tpu and backend.get_name() == 'tf':
    # TF backend on TPU: pin ops to the TPU worker's host CPU by default.
    tpu_host_cpu = tf_init_tpu()
    with tf.device(tpu_host_cpu):
      if trainer_lib.num_devices() == 1:
        # TF's device priority is GPU > CPU > TPU, so we need to explicitly
        # make the TPU core the default device here.
        with tf.device('/device:TPU:0'):
          trainer_lib.train(output_dir=out_dir)
      else:
        trainer_lib.train(output_dir=out_dir)
  else:
    trainer_lib.train(output_dir=out_dir)

  trainer_lib.log('Finished training.')
def main(_):
  """Performs flag-driven setup, then runs a training session.

  Args:
    _: Unused; absl.app passes the leftover argv here.
  """
  logging.set_verbosity(FLAGS.log_level)
  # One-time environment setup, in order: TF flags, gin configs, devices.
  for setup in (_tf_setup_from_flags,
                _gin_parse_configs,
                _jax_and_tf_configure_for_devices):
    setup()

  out_dir = _output_dir_or_default()
  # The TF-backend-on-TPU path needs dedicated device placement handling.
  if FLAGS.use_tpu and math.backend_name() == 'tf':
    _train_using_tf(out_dir)
  else:
    trainer_lib.train(output_dir=out_dir)
  trainer_lib.log('Finished training.')
def _default_output_dir():
  """Builds a default output directory path under ~/trax.

  The directory name encodes the gin-configured model name, the dataset
  name (or 'random' when gin has no dataset configured), and a
  minute-resolution timestamp.

  Returns:
    A non-expanded path string, ~/trax/<model>_<dataset>_<timestamp>.
  """
  try:
    dataset_name = gin.query_parameter('inputs.dataset_name')
  except ValueError:
    # gin raises ValueError when the parameter isn't bound; use a placeholder.
    dataset_name = 'random'
  model_name = gin.query_parameter('train.model').configurable.name
  timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M')
  dir_name = '{model_name}_{dataset_name}_{timestamp}'.format(
      model_name=model_name,
      dataset_name=dataset_name,
      timestamp=timestamp,
  )
  dir_path = os.path.join('~', 'trax', dir_name)
  print()  # Blank line to set the notice apart on stdout.
  trainer_lib.log('No --output_dir specified')
  return dir_path
def main(_):
  """Performs flag-driven setup, then runs a training session.

  Args:
    _: Unused; absl.app passes the leftover argv here.
  """
  logging.set_verbosity(FLAGS.log_level)
  # One-time environment setup, in order: TF flags, gin configs, devices.
  for setup in (_tf_setup_from_flags,
                _gin_parse_configs,
                _jax_and_tf_configure_for_devices):
    setup()
  if FLAGS.disable_jit:
    fastmath.disable_jit()

  out_dir = _output_dir_or_default()
  # The TF-numpy-backend-on-TPU path needs dedicated placement handling.
  tpu_with_tf_backend = FLAGS.use_tpu and fastmath.is_backend(Backend.TFNP)
  if tpu_with_tf_backend:
    _train_using_tf(out_dir)
  else:
    trainer_lib.train(output_dir=out_dir)
  trainer_lib.log('Finished training.')
def main(_):
  """Performs flag-driven setup, then runs a training session.

  Args:
    _: Unused; absl.app passes the leftover argv here.
  """
  logging.set_verbosity(FLAGS.log_level)
  # One-time environment setup, in order: TF flags, gin configs, devices.
  for setup in (_tf_setup_from_flags,
                _gin_parse_configs,
                _jax_and_tf_configure_for_devices):
    setup()

  # Create a JAX GPU cluster if using JAX and given a chief IP.
  if fastmath.is_backend(Backend.JAX) and FLAGS.gpu_cluster_chief_ip:
    _make_jax_gpu_cluster(FLAGS.gpu_cluster_host_id,
                          FLAGS.gpu_cluster_chief_ip,
                          FLAGS.gpu_cluster_n_hosts,
                          FLAGS.gpu_cluster_port)

  if FLAGS.disable_jit:
    fastmath.disable_jit()

  out_dir = _output_dir_or_default()
  # The TF-numpy-backend-on-TPU path needs dedicated placement handling.
  tpu_with_tf_backend = FLAGS.use_tpu and fastmath.is_backend(Backend.TFNP)
  if tpu_with_tf_backend:
    _train_using_tf(out_dir)
  else:
    trainer_lib.train(output_dir=out_dir)
  trainer_lib.log('Finished training.')
def _output_dir_or_default():
  """Returns an (expanded) path to the training output directory.

  Uses --output_dir when given; otherwise derives a default directory name
  under ~/trax from the gin-configured model/dataset and a timestamp.
  """
  if FLAGS.output_dir:
    chosen = FLAGS.output_dir
    trainer_lib.log('Using --output_dir {}'.format(chosen))
    return os.path.expanduser(chosen)

  # No flag given: generate a default name from gin config + timestamp.
  try:
    dataset_name = gin.query_parameter('inputs.dataset_name')
  except ValueError:
    # gin raises ValueError when the parameter isn't bound; use a placeholder.
    dataset_name = 'random'
  model_name = gin.query_parameter('train.model').configurable.name
  timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M')
  output_name = '{model_name}_{dataset_name}_{timestamp}'.format(
      model_name=model_name,
      dataset_name=dataset_name,
      timestamp=timestamp,
  )
  output_dir = os.path.expanduser(os.path.join('~', 'trax', output_name))
  print()  # Blank line to set the notices apart on stdout.
  trainer_lib.log('No --output_dir specified')
  trainer_lib.log('Using default output_dir: {}'.format(output_dir))
  return output_dir