Ejemplo n.º 1
0
def launch():
    """Launch t2t_trainer on Cloud ML Engine.

    Validates the flags, builds the ML Engine job spec, logs it, and asks the
    user to confirm. It then tars and copies the T2T code (and, if
    ``--t2t_usr_dir`` is set, the user directory) using ``FLAGS.output_dir``
    as the training directory, and submits the job.

    Raises:
        AssertionError: if the user does not confirm the launch (in addition
            to anything ``validate_flags()`` may raise).
    """
    validate_flags()
    job_spec = configure_job()
    job_name = job_spec["jobId"]
    tf.logging.info("Launching job %s with ML Engine spec:\n%s", job_name,
                    job_spec)
    # Explicit check instead of `assert cloud.confirm()`: under `python -O`
    # the whole assert statement -- including the confirm() prompt itself --
    # is stripped, so the job would be launched without asking. The
    # AssertionError type is kept so callers see the same failure as before.
    if not cloud.confirm():
        raise AssertionError("Launch of job %s was not confirmed." % job_name)
    train_dir = FLAGS.output_dir
    # Package the T2T code relative to the training directory and point the
    # job spec at it.
    t2t_tar = tar_and_copy_t2t(train_dir)
    configure_trainer_package(job_spec, t2t_tar)
    # Optionally ship the user's extra problems/models as well.
    if FLAGS.t2t_usr_dir:
        usr_tar = tar_and_copy_usr_dir(FLAGS.t2t_usr_dir, train_dir)
        configure_usr_dir(job_spec, usr_tar)
    launch_job(job_spec)
    tf.logging.info("Launched %s. See console to track: %s.", job_name,
                    CONSOLE_URL)
Ejemplo n.º 2
0
def launch():
  """Launch t2t_trainer on Cloud ML Engine.

  Checks the flag preconditions for this launcher:
    * no Cloud TPU;
    * empty ``job_dir()``;
    * ``gs://`` output and data paths;
    * at most one worker replica and no parameter servers.

  It then builds the T2T Python package, logs the job spec, and asks the user
  to confirm. Finally it uploads the package using ``FLAGS.output_dir`` as the
  training directory and submits the job.

  Raises:
    AssertionError: if a flag precondition fails or the user does not
      confirm the launch.
  """

  def _require(condition, message):
    # Explicit raise instead of `assert` so these checks (and the confirm()
    # prompt below) are not stripped under `python -O`. AssertionError keeps
    # the exception type callers saw from the original asserts.
    if not condition:
      raise AssertionError(message)

  _require(not FLAGS.cloud_tpu,
           '--cloud_tpu is not supported when launching on Cloud ML Engine.')
  _require(not job_dir(),
           'job_dir must be empty when launching on Cloud ML Engine.')
  # `or ''` turns an unset path into a clear failure instead of an
  # AttributeError on None.
  _require((FLAGS.output_dir or '').startswith('gs://'),
           '--output_dir must be a gs:// path, got %r.' % FLAGS.output_dir)
  _require((FLAGS.data_dir or '').startswith('gs://'),
           '--data_dir must be a gs:// path, got %r.' % FLAGS.data_dir)
  _require(FLAGS.worker_replicas <= 1,
           '--worker_replicas must be <= 1, got %r.' % FLAGS.worker_replicas)
  _require(FLAGS.ps_replicas <= 0,
           '--ps_replicas must be <= 0, got %r.' % FLAGS.ps_replicas)

  build_t2t_python_package()
  job_spec = configure_job()
  job_name = job_spec['jobId']
  tf.logging.info('Launching job %s with ML Engine spec:\n%s', job_name,
                  job_spec)
  _require(cloud.confirm(), 'Launch of job %s was not confirmed.' % job_name)
  train_dir = FLAGS.output_dir
  # Upload the built package relative to the training directory and point
  # the job spec at its GCS location.
  trainer_package_gcs_path = upload_trainer_package_to_gcs(train_dir)
  configure_trainer_package(job_spec, trainer_package_gcs_path)
  launch_job(job_spec)
  tf.logging.info('Launched %s. See console to track: %s.', job_name,
                  CONSOLE_URL)
Ejemplo n.º 3
0
def launch():
    """Launch t2t_trainer on Cloud ML Engine.

    Checks the flag preconditions for this launcher:
      * no Cloud TPU;
      * no ``--job_dir``;
      * ``gs://`` output and data paths;
      * at most one worker replica and no parameter servers.

    It then logs the job spec and asks the user to confirm. Next it tars and
    copies the T2T code (and, if ``--t2t_usr_dir`` is set, the user directory)
    using ``FLAGS.output_dir`` as the training directory, and submits the job.

    Raises:
        AssertionError: if a flag precondition fails or the user does not
            confirm the launch.
    """

    def _require(condition, message):
        # Explicit raise instead of `assert` so these checks (and the
        # confirm() prompt below) are not stripped under `python -O`.
        # AssertionError keeps the exception type callers saw before.
        if not condition:
            raise AssertionError(message)

    _require(not FLAGS.cloud_tpu,
             '--cloud_tpu is not supported when launching on Cloud ML Engine.')
    _require(not FLAGS.job_dir,
             '--job_dir must be empty when launching on Cloud ML Engine.')
    # `or ''` turns an unset path into a clear failure instead of an
    # AttributeError on None.
    _require((FLAGS.output_dir or '').startswith('gs://'),
             '--output_dir must be a gs:// path, got %r.' % FLAGS.output_dir)
    _require((FLAGS.data_dir or '').startswith('gs://'),
             '--data_dir must be a gs:// path, got %r.' % FLAGS.data_dir)
    _require(FLAGS.worker_replicas <= 1,
             '--worker_replicas must be <= 1, got %r.' % FLAGS.worker_replicas)
    _require(FLAGS.ps_replicas <= 0,
             '--ps_replicas must be <= 0, got %r.' % FLAGS.ps_replicas)

    job_spec = configure_job()
    job_name = job_spec['jobId']
    tf.logging.info('Launching job %s with ML Engine spec:\n%s', job_name,
                    job_spec)
    _require(cloud.confirm(), 'Launch of job %s was not confirmed.' % job_name)
    train_dir = FLAGS.output_dir
    # Package the T2T code relative to the training directory and point the
    # job spec at it.
    t2t_tar = tar_and_copy_t2t(train_dir)
    configure_trainer_package(job_spec, t2t_tar)
    # Optionally ship the user's extra problems/models as well.
    if FLAGS.t2t_usr_dir:
        usr_tar = tar_and_copy_usr_dir(FLAGS.t2t_usr_dir, train_dir)
        configure_usr_dir(job_spec, usr_tar)
    launch_job(job_spec)
    tf.logging.info('Launched %s. See console to track: %s.', job_name,
                    CONSOLE_URL)