Example #1
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    logging.info('JAX process: %d / %d', jax.process_index(),
                 jax.process_count())
    logging.info('JAX local devices: %r', jax.local_devices())

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform, task 0 is not guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f'process_index: {jax.process_index()}, '
        f'process_count: {jax.process_count()}')
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.workdir, 'workdir')

    if FLAGS.mode == 'train':
        train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
    else:
        predict.predict_and_evaluate(FLAGS.config, FLAGS.workdir,
                                     FLAGS.ckpt_path)
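
Note: this example (like most of the ones below) assumes module-level scaffolding that the listing omits: the absl/CLU imports, the flag definitions behind FLAGS.workdir, FLAGS.config, FLAGS.mode and FLAGS.ckpt_path, and an app.run(main) entry point. A rough sketch of that scaffolding follows; the exact flag names and the train/predict modules are assumptions based on the typical Flax example layout, not taken verbatim from the snippet above.

from absl import app
from absl import flags
from absl import logging
from clu import platform
import jax
from ml_collections import config_flags
import tensorflow as tf

import predict  # assumed local modules providing the entry points used in main()
import train

FLAGS = flags.FLAGS

flags.DEFINE_string('workdir', None, 'Directory to store model data.')
flags.DEFINE_string('ckpt_path', None, 'Checkpoint to load in predict mode.')
flags.DEFINE_enum('mode', 'train', ['train', 'predict'], 'Whether to train or predict.')
config_flags.DEFINE_config_file(
    'config', None, 'File path to the training hyperparameter configuration.',
    lock_config=True)

# main(argv) is the function shown in Example #1 above.

if __name__ == '__main__':
    flags.mark_flags_as_required(['config', 'workdir'])
    app.run(main)
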
Example #2
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    utils.add_gfile_logger(_WORKDIR.value)

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    jax.config.update('jax_log_compiles', True)

    logging.info('JAX process: %d / %d', jax.process_index(),
                 jax.process_count())
    logging.info('JAX local devices: %r', jax.local_devices())
    jax_xla_backend = ('None' if FLAGS.jax_xla_backend is None else
                       FLAGS.jax_xla_backend)
    logging.info('Using JAX XLA backend %s', jax_xla_backend)

    logging.info('Config: %s', FLAGS.config)

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform, task 0 is not guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f'process_index: {jax.process_index()}, '
        f'process_count: {jax.process_count()}')
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         _WORKDIR.value, 'workdir')

    if FLAGS.config.trainer == 'train':
        train.train_and_evaluate(FLAGS.config, _WORKDIR.value)
    elif FLAGS.config.trainer == 'inference_time':
        inference_time.inference_time(FLAGS.config, _WORKDIR.value)
    else:
        raise app.UsageError(f'Unknown trainer: {FLAGS.config.trainer}')
Example #3
File: main.py  Project: joaogui1/flax
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    FLAGS.log_dir = FLAGS.workdir
    FLAGS.stderrthreshold = 'info'
    logging.get_absl_handler().start_logging_to_file()

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    logging.info('JAX host: %d / %d', jax.host_id(), jax.host_count())
    logging.info('JAX local devices: %r', jax.local_devices())

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform, task 0 is not guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f'host_id: {jax.host_id()}, host_count: {jax.host_count()}')
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.workdir, 'workdir')

    if FLAGS.sample:
        sample.save_images(sample.generate_sample(FLAGS.config, FLAGS.workdir),
                           'sample.png')
    else:
        train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
Example #4
def main(argv):
    del argv

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    logging.info("JAX process: %d / %d", jax.process_index(),
                 jax.process_count())
    logging.info("JAX devices: %r", jax.devices())

    # Add a note so that we can tell which task is which JAX process.
    platform.work_unit().set_task_status(
        f"process_index: {jax.process_index()}, process_count: {jax.process_count()}"
    )
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.workdir, "workdir")

    train_mode = FLAGS.config.mode
    if train_mode == TrainingMode.PRETRAINING:
        train_lib = run_pretraining
    elif train_mode == TrainingMode.CLASSIFICATION:
        train_lib = run_classifier
    else:
        raise ValueError("Unknown training mode: %s" % train_mode)

    train_lib.train_and_evaluate(FLAGS.config, FLAGS.workdir,
                                 FLAGS.vocab_filepath)
Example #5
def main(argv):
    del argv

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    if FLAGS.jax_backend_target:
        logging.info("Using JAX backend target %s", FLAGS.jax_backend_target)
        jax_xla_backend = ("None" if FLAGS.jax_xla_backend is None else
                           FLAGS.jax_xla_backend)
        logging.info("Using JAX XLA backend %s", jax_xla_backend)

    logging.info("JAX process: %d / %d", jax.process_index(),
                 jax.process_count())
    logging.info("JAX devices: %r", jax.devices())

    if FLAGS.is_train:
        # Add a note so that we can tell which task is which JAX host.
        # (Depending on the platform, task 0 is not guaranteed to be host 0)
        platform.work_unit().set_task_status(
            f"process_index: {jax.process_index()}, "
            f"process_count: {jax.process_count()}")
        platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                             FLAGS.workdir, "workdir")

        train_lib.train_and_evaluate(FLAGS.ml_config, FLAGS.workdir)
    else:
        eval_lib.evaluate(FLAGS.ml_config, FLAGS.workdir)
Example #6
def main(argv):
    del argv

    # Hide any GPUs from TensorFlow. Otherwise, TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    logging.info("JAX host: %d / %d", jax.process_index(), jax.process_count())
    logging.info("JAX devices: %r", jax.devices())

    # Add a note so that we can tell which task is which JAX host. (Task 0 is not
    # guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f"host_id: {jax.process_index()}, host_count: {jax.process_count()}")
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         _WORKDIR.value, "workdir")

    train_mode = _CONFIG.value.mode
    if train_mode == TrainingMode.PRETRAINING:
        train_lib = run_pretraining
    elif train_mode == TrainingMode.CLASSIFICATION:
        train_lib = run_classifier
    else:
        raise ValueError("Unknown mode: %s" % train_mode)

    train_lib.train_and_evaluate(_CONFIG.value, _WORKDIR.value,
                                 _VOCAB_FILEPATH.value)
Example #7
def _end_session(self, url: Optional[str]):
    platform.work_unit().create_artifact(
        platform.ArtifactType.URL,
        url,
        description=self._artifact_name.format(step=self._previous_step))
    self._session_running = False
    self._session_started = None

def _end_session(self, url: Optional[str]):
  platform.work_unit().create_artifact(
      platform.ArtifactType.URL,
      url,
      description=f"[{self._previous_step}] Profile")
  self._session_running = False
  self._session_started = None

def _end_session(self):
  url = profiler.stop()
  if url is not None:
    platform.work_unit().create_artifact(
        platform.ArtifactType.URL,
        url,
        description=f"[{self._previous_step}] Profile")
  self._session_running = False
Example #10
def _apply(self, step: int, t: float):
  steps_per_sec = (step - self._previous_step) / (t - self._previous_time)
  eta_seconds = (self._num_train_steps - step) / steps_per_sec
  message = (f"{100 * step / self._num_train_steps:.1f}% @{step}, "
             f"{steps_per_sec:.1f} steps/s, ETA: {eta_seconds / 60:.0f} min")
  if self._time_per_part:
    total = time.time() - self._t0
    message += " ({:.0f} min : {})".format(total / 60, ", ".join(
        f"{100 * dt / total:.1f}% {name}"
        for name, dt in sorted(self._time_per_part.items())))
  # This should be relatively cheap so we can do it in the same main thread.
  platform.work_unit().set_notes(message)
  if self._writer is not None:
    self._writer.write_scalars(step, {"steps_per_sec": steps_per_sec})
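
This _apply method is the kind of hook that clu.periodic_actions.ReportProgress runs at a configured interval; calling the action object once per step decides when _apply fires. A minimal usage sketch follows, assuming the clu library is available; the step count, directory, and the "run one training step" placeholder are illustrative only, not taken from the snippet above.

from clu import metric_writers, periodic_actions

num_train_steps = 1000                  # placeholder values for illustration
workdir = '/tmp/example_workdir'

writer = metric_writers.create_default_writer(workdir)
report_progress = periodic_actions.ReportProgress(
    num_train_steps=num_train_steps, writer=writer)

for step in range(1, num_train_steps + 1):
    # ... run one training step here ...
    # Calling the action checks its interval and, when due, invokes _apply(),
    # which sets the work-unit note and writes the steps_per_sec scalar.
    report_progress(step)
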
Example #11
def main(*_args, **_kwargs):
    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')
    info_string = f'JAX process: {jax.process_index()} / {jax.process_count()}'
    logging.info(info_string)

    info_string = f'JAX local devices: {jax.local_devices()}'
    logging.info(info_string)

    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.work_unit_dir, 'work_unit_dir')
    config = FLAGS.config
    train.train_and_evaluate(config, FLAGS.work_unit_dir)
def main(executable_dict, argv):
    del argv

    work_unit = platform.work_unit()
    tf.enable_v2_behavior()
    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    logging.info('JAX host: %d / %d', jax.host_id(), jax.host_count())
    logging.info('JAX devices: %r', jax.devices())

    work_unit.set_task_status(
        f'host_id: {jax.host_id()}, host_count: {jax.host_count()}')

    # Read configuration
    if FLAGS.config_json:
        logging.info('Reading config from JSON: %s', FLAGS.config_json)
        with tf.io.gfile.GFile(FLAGS.config_json, 'r') as f:
            config = ml_collections.ConfigDict(json.loads(f.read()))
    else:
        config = FLAGS.config
    logging.info('config=%s',
                 config.to_json_best_effort(indent=4, sort_keys=True))

    # Make output directories
    if FLAGS.experiment_dir:
        work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                                  FLAGS.experiment_dir, 'experiment_dir')
    if FLAGS.work_unit_dir:
        work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                                  FLAGS.work_unit_dir, 'work_unit_dir')
    logging.info('experiment_dir=%s work_unit_dir=%s', FLAGS.experiment_dir,
                 FLAGS.work_unit_dir)

    # Seeding
    random.seed(config.seed * jax.host_count() + jax.host_id())
    onp.random.seed(config.seed * jax.host_count() + jax.host_id())
    rng = utils.RngGen(
        jax.random.fold_in(jax.random.PRNGKey(config.seed), jax.host_id()))

    # Run the main function
    logging.info('Running executable: %s', FLAGS.executable_name)

    extra_args = {}
    if FLAGS.extra_args_json_str:
        extra_args = json.loads(FLAGS.extra_args_json_str)
        logging.info('Extra args passed in: %r', extra_args)

    executable_dict[FLAGS.executable_name](config=config,
                                           experiment_dir=FLAGS.experiment_dir,
                                           work_unit_dir=FLAGS.work_unit_dir,
                                           rng=rng,
                                           **extra_args)

    utils.barrier()
Example #13
def main(argv):
    del argv

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    logging.info("JAX host: %d / %d", jax.host_id(), jax.host_count())
    logging.info("JAX devices: %r", jax.devices())

    # Add a note so that we can tell which task is which JAX host.
    # (Borg task 0 is not guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f"host_id: {jax.host_id()}, host_count: {jax.host_count()}")
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.workdir, "workdir")

    state = train_and_evaluate(FLAGS.config, FLAGS.workdir)
    del state
Example #14
def main(argv: Sequence[str]) -> None:
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], "GPU")

  logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
  logging.info("JAX local devices: %r", jax.local_devices())

  # Add a note so that we can tell which task is which JAX host.
  # (Depending on the platform, task 0 is not guaranteed to be host 0)
  platform.work_unit().set_task_status(f"process_index: {jax.process_index()}, "
                                       f"process_count: {jax.process_count()}")
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.model_dir, "model_dir")

  tf.io.gfile.makedirs(FLAGS.model_dir)

  # Process config here
  if FLAGS.config_file:
    with tf.io.gfile.GFile(FLAGS.config_file, "r") as reader:
      config = json.load(reader)
  else:
    config = json.loads(FLAGS.config)
    # Save config to workdir if it does not exist yet
    if jax.process_index() == 0:
      config_file = os.path.join(FLAGS.model_dir, "config.json")
      with tf.io.gfile.GFile(config_file, "w") as writer:
        writer.write(json.dumps(config, indent=4))

  config["model_dir"] = FLAGS.model_dir
  if FLAGS.learning_rate is not None:
    config["learning_rate"] = FLAGS.learning_rate
  if FLAGS.per_device_batch_size is not None:
    config["per_device_batch_size"] = FLAGS.per_device_batch_size
  if FLAGS.num_train_steps is not None:
    config["num_train_steps"] = FLAGS.num_train_steps
  if FLAGS.warmup_steps is not None:
    config["warmup_steps"] = FLAGS.warmup_steps

  train(ml_collections.ConfigDict(config))
def main(argv: Sequence[str]) -> None:
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    logging.info('JAX process: %d / %d', jax.process_index(),
                 jax.process_count())
    logging.info('JAX local devices: %r', jax.local_devices())
    logging.info('JAX total devices: %r', jax.device_count())

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform, task 0 is not guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f'process_index: {jax.process_index()}, '
        f'process_count: {jax.process_count()}')
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.output_dir, 'output_dir')

    tf.io.gfile.makedirs(FLAGS.output_dir)

    # Process config here
    if FLAGS.config_file:
        with tf.io.gfile.GFile(FLAGS.config_file, 'r') as reader:
            config = json.load(reader)
    else:
        config = json.loads(FLAGS.config)
        # Save config to workdir if it does not exist yet
        if jax.process_index() == 0:
            config_file = os.path.join(FLAGS.output_dir, 'config.json')
            with tf.io.gfile.GFile(config_file, 'w') as writer:
                writer.write(json.dumps(config, indent=4))

    config['output_dir'] = FLAGS.output_dir

    if 'num_total_memories' not in config:
        config['num_total_memories'] = get_num_total_memories(
            ml_collections.ConfigDict(config))

    generate(ml_collections.ConfigDict(config))
Example #16
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  # This example only supports single-host training on a single device.
  logging.info('JAX host: %d / %d', jax.process_index(), jax.process_count())
  logging.info('JAX local devices: %r', jax.local_devices())

  # Add a note so that we can tell which task is which JAX host.
  # (Depending on the platform, task 0 is not guaranteed to be host 0)
  platform.work_unit().set_task_status(f'process_index: {jax.process_index()}, '
                                       f'process_count: {jax.process_count()}')
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       _WORKDIR.value, 'workdir')

  train.train_and_evaluate(_CONFIG.value, _WORKDIR.value)
Example #17
def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  flags.mark_flags_as_required(["workdir"])

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], "GPU")

  logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
  logging.info("JAX local devices: %r", jax.local_devices())

  # Add a note so that we can tell which task is which JAX host.
  # (Depending on the platform, task 0 is not guaranteed to be host 0)
  platform.work_unit().set_task_status(f"process_index: {jax.process_index()}, "
                                       f"process_count: {jax.process_count()}")
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.workdir, "workdir")

  train_and_evaluate(FLAGS.config, FLAGS.workdir)
Example #18
def main(argv):
  del argv

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], "GPU")

  if FLAGS.jax_backend_target:
    logging.info("Using JAX backend target %s", FLAGS.jax_backend_target)
    jax_xla_backend = ("None" if FLAGS.jax_xla_backend is None else
                       FLAGS.jax_xla_backend)
    logging.info("Using JAX XLA backend %s", jax_xla_backend)

  logging.info("JAX host: %d / %d", jax.process_index(), jax.process_count())
  logging.info("JAX devices: %r", jax.devices())

  platform.work_unit().set_task_status(f"process_index: {jax.process_index()}, "
                                       f"process_count: {jax.process_count()}")
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.workdir, "workdir")

  train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
Example #19
def main(argv):
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    if FLAGS.jax_backend_target:
        logging.info("Using JAX backend target %s", FLAGS.jax_backend_target)
        jax.config.update("jax_xla_backend", "tpu_driver")
        jax.config.update("jax_backend_target", FLAGS.jax_backend_target)

    logging.info("JAX host: %d / %d", jax.host_id(), jax.host_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform, task 0 is not guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f"host_id: {jax.host_id()}, host_count: {jax.host_count()}")
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.workdir, "workdir")

    train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
Example #20
File: main.py  Project: myagues/flax_nerf
def main(_):
    if FLAGS.config.precrop_iters > 0 and FLAGS.config.batching:
        raise ValueError(
            "'precrop_iters' has no effect when 'batching' the dataset")
    assert FLAGS.config.down_factor > 0 and FLAGS.config.render_factor > 0

    logging.info("JAX host: %d / %d", jax.process_index(), jax.host_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    platform.work_unit().set_task_status(
        f"host_id: {jax.process_index()}, host_count: {jax.host_count()}")
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.model_dir, "model_dir")

    os.makedirs(FLAGS.model_dir, exist_ok=True)
    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, rng_coarse, rng_fine, data_rng, step_rng = jax.random.split(rng, 5)
    rngs = common_utils.shard_prng_key(step_rng)

    ### Load dataset and data values
    datasets, counts, optics, render_datasets = get_dataset(
        FLAGS.data_dir,
        FLAGS.config,
        rng=data_rng,
        num_poses=FLAGS.config.num_poses)
    train_ds, val_ds, test_ds = datasets
    *_, test_items = counts
    hwf, r_hwf, near, far = optics
    render_ds, render_vdirs_ds, num_poses = render_datasets
    iter_render_ds = zip(range(num_poses), render_ds)
    iter_vdirs_ds = zip(range(num_poses), render_vdirs_ds)
    iter_test_ds = zip(range(test_items), test_ds)
    img_h, img_w, _ = hwf

    logging.info("Num poses: %d", num_poses)
    logging.info("Splits: train - %d, val - %d, test - %d", *counts)
    logging.info("Images: height %d, width %d, focal %.5f", *hwf)
    logging.info("Render: height %d, width %d, focal %.5f", *r_hwf)

    ### Init model parameters and optimizer
    initialized_ = functools.partial(initialized,
                                     model_config=FLAGS.config.model)
    pts_shape = (FLAGS.config.num_rand, FLAGS.config.num_samples, 3)
    views_shape = (FLAGS.config.num_rand, 3)
    model_coarse, params_coarse = initialized_(rng_coarse, pts_shape,
                                               views_shape)

    schedule_fn = optax.exponential_decay(
        init_value=FLAGS.config.learning_rate,
        transition_steps=FLAGS.config.lr_decay * 1000,
        decay_rate=FLAGS.config.decay_factor,
    )
    tx = optax.adam(learning_rate=schedule_fn)
    state = train_state.TrainState.create(apply_fn=(model_coarse.apply, None),
                                          params={"coarse": params_coarse},
                                          tx=tx)

    if FLAGS.config.num_importance > 0:
        pts_shape = (
            FLAGS.config.num_rand,
            FLAGS.config.num_importance + FLAGS.config.num_samples,
            3,
        )
        model_fine, params_fine = initialized_(rng_fine, pts_shape,
                                               views_shape)
        state = train_state.TrainState.create(
            apply_fn=(model_coarse.apply, model_fine.apply),
            params={
                "coarse": params_coarse,
                "fine": params_fine
            },
            tx=tx,
        )

    state = checkpoints.restore_checkpoint(FLAGS.model_dir, state)
    start_step = int(state.step)

    # cycle already seen examples if resuming from checkpoint
    # (only useful for ensuring deterministic dataset, slow for large start_step)
    # if start_step != 0:
    #     for _ in range(start_step):
    #         _ = next(train_ds)

    # parameter_overview.log_parameter_overview(state.optimizer_coarse.target)
    # if FLAGS.config.num_importance > 0:
    #     parameter_overview.log_parameter_overview(state.optimizer_fine.target)

    state = jax.device_put_replicated(state, jax.local_devices())

    ### Build "pmapped" functions for distributed training
    train_fn = functools.partial(train_step, near, far, FLAGS.config,
                                 schedule_fn)
    p_train_step = jax.pmap(
        train_fn,
        axis_name="batch",
        in_axes=(0, 0, None, 0),
        # donate_argnums=(0, 1, 2),
    )

    def render_fn(state, rays):
        step_fn = functools.partial(eval_step, FLAGS.config, near, far, state)
        return lax.map(step_fn, rays)

    p_eval_step = jax.pmap(
        render_fn,
        axis_name="batch",
        # in_axes=(0, 0, None),
        # donate_argnums=(0, 1))
    )

    # TODO: add hparams
    writer = metric_writers.create_default_writer(
        FLAGS.model_dir, just_logging=jax.process_index() > 0)
    logging.info("Starting training loop.")

    hooks = []
    profiler = periodic_actions.Profile(num_profile_steps=5,
                                        logdir=FLAGS.model_dir)
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=FLAGS.config.num_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [profiler, report_progress]
    train_metrics = []
    gen_video_ = functools.partial(gen_video, FLAGS.model_dir)

    for step in range(start_step, FLAGS.config.num_steps + 1):
        is_last_step = step == FLAGS.config.num_steps

        batch = next(train_ds)
        coords = None
        if not FLAGS.config.batching:
            coords = jnp.meshgrid(jnp.arange(img_h),
                                  jnp.arange(img_w),
                                  indexing="ij")
            if step < FLAGS.config.precrop_iters:
                dH = int(img_h // 2 * FLAGS.config.precrop_frac)
                dW = int(img_w // 2 * FLAGS.config.precrop_frac)
                coords = jnp.meshgrid(
                    jnp.arange(img_h // 2 - dH, img_h // 2 + dH),
                    jnp.arange(img_w // 2 - dW, img_w // 2 + dW),
                    indexing="ij",
                )
            coords = jnp.stack(coords, axis=-1).reshape([-1, 2])

        with jax.profiler.StepTraceAnnotation("train", step_num=step):
            state, metrics = p_train_step(batch, state, coords, rngs)
        train_metrics.append(metrics)

        logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                            step)
        _ = [h(step) for h in hooks]

        ### Write train summaries to TB
        if step % FLAGS.config.i_print == 0 or is_last_step:
            with report_progress.timed("training_metrics"):
                train_metrics = common_utils.get_metrics(train_metrics)
                train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
                summary = {f"train/{k}": v for k, v in train_summary.items()}
                writer.write_scalars(step, summary)
            train_metrics = []

        ### Eval a random validation image and plot it to TB
        if step % FLAGS.config.i_img == 0 and step > 0 or is_last_step:
            with report_progress.timed("validation"):
                inputs = next(val_ds)
                rays, padding = prepare_render_data(inputs["rays"]._numpy())
                outputs = p_eval_step(state, rays)
                preds, preds_c, z_std = jax.tree_map(
                    lambda x: to_np(x, hwf, padding), outputs)
                loss = np.mean((preds["rgb"] - inputs["image"])**2)
                summary = {"val/loss": loss, "val/psnr": psnr_fn(loss)}
                writer.write_scalars(step, summary)

                summary = {
                    "val/rgb": to_rgb(preds["rgb"]),
                    "val/target": to_np(inputs["image"], hwf, padding),
                    "val/disp": disp_post(preds["disp"], FLAGS.config),
                    "val/acc": preds["acc"],
                }
                if FLAGS.config.num_importance > 0:
                    summary["val/rgb_c"] = to_rgb(preds_c["rgb"])
                    summary["val/disp_c"] = disp_post(preds_c["disp"],
                                                      FLAGS.config)
                    summary["val/z_std"] = z_std
                writer.write_images(step, summary)

        ### Render a video with test poses
        if step % FLAGS.config.i_video == 0 and step > 0:
            with report_progress.timed("video_render"):
                logging.info("Rendering video at step %d", step)
                rgb_list = []
                disp_list = []
                for idx, inputs in tqdm(iter_render_ds, desc="Rays render"):
                    rays, padding = prepare_render_data(inputs["rays"].numpy())
                    preds, *_ = p_eval_step(state, rays)
                    preds = jax.tree_map(lambda x: to_np(x, r_hwf, padding),
                                         preds)
                    rgb_list.append(preds["rgb"])
                    disp_list.append(preds["disp"])

                gen_video_(np.stack(rgb_list), "rgb", r_hwf, step)
                disp = np.stack(disp_list)
                gen_video_(disp_post(disp, FLAGS.config),
                           "disp",
                           r_hwf,
                           step,
                           ch=1)

                if FLAGS.config.use_viewdirs:
                    rgb_list = []
                    for idx, inputs in tqdm(iter_vdirs_ds,
                                            desc="Viewdirs render"):
                        rays, padding = prepare_render_data(
                            inputs["rays"].numpy())
                        preds, *_ = p_eval_step(state, rays)
                        rgb_list.append(to_np(preds["rgb"], r_hwf, padding))
                    gen_video_(np.stack(rgb_list), "rgb_still", r_hwf, step)

        ### Save images in the test set
        if step % FLAGS.config.i_testset == 0 and step > 0:
            with report_progress.timed("test_render"):
                logging.info("Rendering test set at step %d", step)
                test_losses = []
                for idx, inputs in tqdm(iter_test_ds, desc="Test render"):
                    rays, padding = prepare_render_data(inputs["rays"].numpy())
                    preds, *_ = p_eval_step(state, rays)
                    save_test_imgs(FLAGS.model_dir, preds["rgb"], r_hwf, step,
                                   idx)

                    if FLAGS.config.render_factor == 0:
                        loss = np.mean((preds["rgb"] - inputs["image"])**2.0)
                        test_losses.append(loss)
                if FLAGS.config.render_factor == 0:
                    loss = np.mean(test_losses)
                    summary = {"test/loss": loss, "test/psnr": psnr_fn(loss)}
                    writer.write_scalars(step, summary)
        writer.flush()

        ### Save ckpt
        if step % FLAGS.config.i_weights == 0 or is_last_step:
            with report_progress.timed("checkpoint"):
                save_checkpoint(state, FLAGS.model_dir)
Example #21
def main(executable_dict, argv):
    del argv

    work_unit = platform.work_unit()
    tf.enable_v2_behavior()
    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    host_id = jax.host_id()
    n_host = jax.host_count()
    logging.info('JAX host: %d / %d', host_id, n_host)
    logging.info('JAX devices: %r', jax.devices())
    # Add a note so that we can tell which task is which JAX host.
    # (task 0 is not guaranteed to be host 0)
    work_unit.set_task_status(
        f'host_id: {jax.host_id()}, host_count: {jax.host_count()}')

    # Read configuration
    if FLAGS.config_json:
        logging.info('Reading config from JSON: %s', FLAGS.config_json)
        with gfile.GFile(FLAGS.config_json, 'r') as f:
            config = ml_collections.ConfigDict(json.loads(f.read()))
    else:
        config = FLAGS.config

    # Set query
    if FLAGS.query:
        config.query = FLAGS.query

    # Make output directories
    work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                              FLAGS.experiment_dir, 'experiment_dir')
    if not FLAGS.work_unit_dir:
        timestr = time.strftime('%Y%m%d-%H%M%S')
        FLAGS.work_unit_dir = os.path.join(FLAGS.experiment_dir,
                                           f"'{FLAGS.query}' {timestr}")
    work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                              FLAGS.work_unit_dir, 'work_unit_dir')
    logging.info('experiment_dir=%s work_unit_dir=%s', FLAGS.experiment_dir,
                 FLAGS.work_unit_dir)

    # Seeding
    if FLAGS.seed is not None:
        config.seed = FLAGS.seed
    random.seed(config.seed * n_host + host_id)
    onp.random.seed(config.seed * n_host + host_id)
    logging.debug('setting up RNG...')
    key = jax.random.PRNGKey(config.seed)
    key = jax.random.fold_in(key, host_id)
    rng = helpers.RngGen(key)
    logging.debug('done setting up RNG')

    # Log config
    logging.info('config=%s',
                 config.to_json_best_effort(indent=4, sort_keys=True))

    # Run the main function
    logging.info('Running executable: %s', FLAGS.executable_name)

    extra_args = {}
    if FLAGS.extra_args_json_str:
        extra_args = json.loads(FLAGS.extra_args_json_str)
        logging.info('Extra args passed in: %r', extra_args)

    executable_dict[FLAGS.executable_name](config=config,
                                           experiment_dir=FLAGS.experiment_dir,
                                           work_unit_dir=FLAGS.work_unit_dir,
                                           rng=rng,
                                           **extra_args)
Example #22
def main(argv: Sequence[str]) -> None:
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    setup_jax(FLAGS.globally_use_hardware_rng,
              FLAGS.jax_parallel_functions_output_gda,
              FLAGS.jax_backend_target, FLAGS.jax_xla_backend)

    # Add a note so that we can tell which Borg task is which JAX host.
    # (Borg task 0 is not guaranteed to be host 0)
    if jax.process_count() > 128:
        wait_with_random_jitter(min_secs=0, max_secs=60)
    work_unit = platform.work_unit()
    work_unit.set_task_status(f'process_index: {jax.process_index()}, '
                              f'process_count: {jax.process_count()}')
    work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                              FLAGS.job_log_dir, 'job_log_dir')

    # Start jax.profiler for Tensorboard and profiling in open source.
    if FLAGS.jax_profiler_port is not None:
        server = jax.profiler.start_server(FLAGS.jax_profiler_port)  # pylint:disable=unused-variable
    if FLAGS.mode == 'train':
        train.train_and_evaluate(
            model_name=FLAGS.model,
            job_log_dir=FLAGS.job_log_dir,
            multi_host_checkpointing=FLAGS.multi_host_checkpointing,
            maybe_use_persistence_checkpointing=FLAGS.
            maybe_use_persistence_checkpointing,
            restore_checkpoint_dir=FLAGS.restore_checkpoint_dir,
            restore_checkpoint_step=FLAGS.restore_checkpoint_step,
            eval_on_test=FLAGS.eval_on_test,
            checkpoint_todelete_subdir=FLAGS.checkpoint_todelete_subdir)
    elif FLAGS.mode == 'eval':
        eval_lib.evaluate(
            model_name=FLAGS.model,
            job_log_dir=FLAGS.job_log_dir,
            multi_host_checkpointing=FLAGS.multi_host_checkpointing,
            maybe_use_persistence_checkpointing=FLAGS.
            maybe_use_persistence_checkpointing)
    elif FLAGS.mode == 'decode':
        eval_lib.decode(
            model_name=FLAGS.model,
            job_log_dir=FLAGS.job_log_dir,
            multi_host_checkpointing=FLAGS.multi_host_checkpointing,
            maybe_use_persistence_checkpointing=FLAGS.
            maybe_use_persistence_checkpointing,
            restore_checkpoint_dir=None,
            restore_checkpoint_step=None,
            continuous_decode=True,
        )
    elif FLAGS.mode == 'decode_once':
        if not FLAGS.restore_checkpoint_dir:
            raise ValueError(
                '--mode=decode_once requires --restore_checkpoint_dir.')
        eval_lib.decode(
            model_name=FLAGS.model,
            job_log_dir=FLAGS.job_log_dir,
            multi_host_checkpointing=FLAGS.multi_host_checkpointing,
            maybe_use_persistence_checkpointing=FLAGS.
            maybe_use_persistence_checkpointing,
            restore_checkpoint_dir=FLAGS.restore_checkpoint_dir,
            restore_checkpoint_step=FLAGS.restore_checkpoint_step,
            continuous_decode=False,
        )