コード例 #1
0
ファイル: trainer.py プロジェクト: tianhai123/-
def main(_):
    logging.set_verbosity(FLAGS.log_level)

    if FLAGS.tf_eager:
        tf.enable_eager_execution()

    if FLAGS.tf_xla:
        tf.config.optimizer.set_jit(True)

    tf.config.optimizer.set_experimental_options(
        {"pin_to_host_optimization": FLAGS.tf_opt_pin_to_host})

    tf.config.optimizer.set_experimental_options(
        {"layout_optimizer": FLAGS.tf_opt_layout})

    _setup_gin()

    if FLAGS.tf_eager and backend.get_name() in ("numpy", "jax"):
        # Numpy backend doesn't benefit from having the input pipeline run on GPU,
        # and jax backend has GPU memory contention if TF uses the GPU. Gin must be
        # set up first before determining the backend.
        tf.config.experimental.set_visible_devices([], "GPU")

    # Setup output directory
    output_dir = FLAGS.output_dir or _default_output_dir()
    trax.log("Using --output_dir %s" % output_dir)
    output_dir = os.path.expanduser(output_dir)

    # If on TPU, let JAX know.
    if FLAGS.use_tpu:
        jax.config.update("jax_platform_name", "tpu")

    trax.train(output_dir=output_dir)
コード例 #2
0
def main(_):
  logging.set_verbosity(FLAGS.log_level)

  if FLAGS.tf_eager:
    tf.enable_eager_execution()

  if FLAGS.tf_xla:
    tf.config.optimizer.set_jit(True)

  tf.config.optimizer.set_experimental_options(
      {"pin_to_host_optimization": FLAGS.tf_opt_pin_to_host}
  )

  tf.config.optimizer.set_experimental_options(
      {"layout_optimizer": FLAGS.tf_opt_layout}
  )

  _setup_gin()

  # Setup output directory
  output_dir = FLAGS.output_dir or _default_output_dir()
  trax.log("Using --output_dir %s" % output_dir)
  output_dir = os.path.expanduser(output_dir)

  # If on TPU, let JAX know.
  if FLAGS.use_tpu:
    jax.config.update("jax_platform_name", "tpu")

  trax.train(output_dir=output_dir)
コード例 #3
0
ファイル: trainer.py プロジェクト: weiczhu/tensor2tensor
def main(_):
  _setup_gin()

  # Setup output directory
  output_dir = FLAGS.output_dir or _default_output_dir()
  trax.log("Using --output_dir %s" % output_dir)
  output_dir = os.path.expanduser(output_dir)

  trax.train(output_dir=output_dir)
コード例 #4
0
def main(_):
  logging.set_verbosity(FLAGS.log_level)

  _setup_gin()

  # Setup output directory
  output_dir = FLAGS.output_dir or _default_output_dir()
  trax.log("Using --output_dir %s" % output_dir)
  output_dir = os.path.expanduser(output_dir)

  trax.train(output_dir=output_dir)
コード例 #5
0
def _default_output_dir():
  """Default output directory."""
  dir_name = "{model_name}_{dataset_name}_{timestamp}".format(
      model_name=gin.query_parameter("train.model").configurable.name,
      dataset_name=gin.query_parameter("inputs.dataset_name"),
      timestamp=datetime.datetime.now().strftime("%Y%m%d_%H%M"),
  )
  dir_path = os.path.join("~", "trax", dir_name)
  print()
  trax.log("No --output_dir specified")
  return dir_path
コード例 #6
0
ファイル: trainer.py プロジェクト: daishu7/tensor2tensor
def main(_):
  _setup_gin()

  # Setup directories
  data_dir = FLAGS.data_dir
  output_dir = FLAGS.output_dir or _default_output_dir()
  assert data_dir, "Must specify a data directory"
  assert output_dir, "Must specify an output directory"
  trax.log("Using --output_dir %s" % output_dir)

  data_dir = os.path.expanduser(data_dir)
  output_dir = os.path.expanduser(output_dir)

  trax.train(data_dir=data_dir, output_dir=output_dir)
コード例 #7
0
ファイル: trainer.py プロジェクト: zhodj/tensor2tensor
def main(_):
  logging.set_verbosity(FLAGS.log_level)

  _setup_gin()

  # Setup output directory
  output_dir = FLAGS.output_dir or _default_output_dir()
  trax.log("Using --output_dir %s" % output_dir)
  output_dir = os.path.expanduser(output_dir)

  # If on TPU, let JAX know.
  if FLAGS.use_tpu:
    jax.config.update("jax_platform_name", "tpu")

  trax.train(output_dir=output_dir)