def main(argv):
    """Run training or prediction, chosen by ``--mode``.

    ``--mode=train`` runs ``train.train_and_evaluate``; every other value runs
    ``predict.predict_and_evaluate`` with the checkpoint from ``--ckpt_path``.

    Args:
      argv: Positional arguments left after flag parsing; only the program
        name may be present.

    Raises:
      app.UsageError: If extra positional arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Keep TensorFlow away from the GPUs so it cannot reserve memory that JAX
    # needs.
    tf.config.experimental.set_visible_devices([], 'GPU')

    index, count = jax.process_index(), jax.process_count()
    logging.info('JAX process: %d / %d', index, count)
    logging.info('JAX local devices: %r', jax.local_devices())

    # Label the task with its JAX process index; depending on the platform,
    # task 0 is not necessarily JAX process 0.
    work_unit = platform.work_unit()
    work_unit.set_task_status(
        f'process_index: {index}, process_count: {count}')
    work_unit.create_artifact(
        platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir')

    if FLAGS.mode != 'train':
        predict.predict_and_evaluate(
            FLAGS.config, FLAGS.workdir, FLAGS.ckpt_path)
        return
    train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
def main(argv):
    """Set up logging and JAX, then run the trainer named in the config.

    ``FLAGS.config.trainer`` selects the entry point: ``'train'`` runs
    ``train.train_and_evaluate`` and ``'inference_time'`` runs
    ``inference_time.inference_time``.

    Args:
      argv: Positional arguments left after flag parsing; only the program
        name may be present.

    Raises:
      app.UsageError: On extra positional arguments or an unknown trainer.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    workdir = _WORKDIR.value
    utils.add_gfile_logger(workdir)

    # Keep TensorFlow away from the GPUs so it cannot reserve memory that JAX
    # needs.
    tf.config.experimental.set_visible_devices([], 'GPU')
    jax.config.update('jax_log_compiles', True)

    logging.info('JAX process: %d / %d', jax.process_index(),
                 jax.process_count())
    logging.info('JAX local devices: %r', jax.local_devices())
    backend = FLAGS.jax_xla_backend
    if backend is None:
        backend = 'None'
    logging.info('Using JAX XLA backend %s', backend)

    config = FLAGS.config
    logging.info('Config: %s', config)

    # Label the task with its JAX process index; depending on the platform,
    # task 0 is not necessarily JAX process 0.
    work_unit = platform.work_unit()
    work_unit.set_task_status(f'process_index: {jax.process_index()}, '
                              f'process_count: {jax.process_count()}')
    work_unit.create_artifact(platform.ArtifactType.DIRECTORY, workdir,
                              'workdir')

    trainer = config.trainer
    if trainer == 'train':
        train.train_and_evaluate(config, workdir)
        return
    if trainer == 'inference_time':
        inference_time.inference_time(config, workdir)
        return
    raise app.UsageError(f'Unknown trainer: {trainer}')
def main(argv):
    """Either generate a sample image or train, depending on ``--sample``.

    absl logs are also written to files under ``--workdir``, and INFO-level
    messages are shown on stderr.

    Args:
      argv: Positional arguments left after flag parsing; only the program
        name may be present.

    Raises:
      app.UsageError: If extra positional arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Route absl log files into the work directory and lower the stderr
    # threshold to INFO.
    FLAGS.log_dir = FLAGS.workdir
    FLAGS.stderrthreshold = 'info'
    logging.get_absl_handler().start_logging_to_file()

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    # jax.host_id()/jax.host_count() are deprecated aliases of
    # jax.process_index()/jax.process_count(); the logged text is unchanged.
    process_index = jax.process_index()
    process_count = jax.process_count()
    logging.info('JAX host: %d / %d', process_index, process_count)
    logging.info('JAX local devices: %r', jax.local_devices())

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform task 0 is not guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f'host_id: {process_index}, host_count: {process_count}')
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.workdir, 'workdir')

    if FLAGS.sample:
        images = sample.generate_sample(FLAGS.config, FLAGS.workdir)
        sample.save_images(images, 'sample.png')
    else:
        train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
def main(argv):
    """Train with the library selected by ``FLAGS.config.mode``.

    ``TrainingMode.PRETRAINING`` uses ``run_pretraining`` and
    ``TrainingMode.CLASSIFICATION`` uses ``run_classifier``; both receive the
    config, ``--workdir`` and ``--vocab_filepath``.

    Raises:
      ValueError: If the configured mode is neither of the above.
    """
    del argv  # Unused.

    # Keep TensorFlow away from the GPUs so it cannot reserve memory that JAX
    # needs.
    tf.config.experimental.set_visible_devices([], "GPU")

    logging.info("JAX process: %d / %d", jax.process_index(),
                 jax.process_count())
    logging.info("JAX devices: %r", jax.devices())

    # Label the task with its JAX process index.
    work_unit = platform.work_unit()
    work_unit.set_task_status(
        f"process_index: {jax.process_index()}, "
        f"process_count: {jax.process_count()}")
    work_unit.create_artifact(platform.ArtifactType.DIRECTORY, FLAGS.workdir,
                              "workdir")

    mode = FLAGS.config.mode
    if mode == TrainingMode.PRETRAINING:
        trainer_module = run_pretraining
    elif mode == TrainingMode.CLASSIFICATION:
        trainer_module = run_classifier
    else:
        raise ValueError("Unknown training mode: %s" % mode)

    trainer_module.train_and_evaluate(FLAGS.config, FLAGS.workdir,
                                      FLAGS.vocab_filepath)
def main(argv):
    """Train (``--is_train``) or evaluate a model configured by ``--ml_config``.

    Only the training path labels the task and registers ``--workdir`` as an
    artifact.
    """
    del argv  # Unused.

    # Keep TensorFlow away from the GPUs so it cannot reserve memory that JAX
    # needs.
    tf.config.experimental.set_visible_devices([], "GPU")

    if FLAGS.jax_backend_target:
        logging.info("Using JAX backend target %s", FLAGS.jax_backend_target)
        xla_backend = FLAGS.jax_xla_backend
        if xla_backend is None:
            xla_backend = "None"
        logging.info("Using JAX XLA backend %s", xla_backend)

    logging.info("JAX process: %d / %d", jax.process_index(),
                 jax.process_count())
    logging.info("JAX devices: %r", jax.devices())

    if not FLAGS.is_train:
        eval_lib.evaluate(FLAGS.ml_config, FLAGS.workdir)
        return

    # Label the task with its JAX process index; depending on the platform,
    # task 0 is not necessarily JAX process 0.
    work_unit = platform.work_unit()
    work_unit.set_task_status(
        f"process_index: {jax.process_index()}, "
        f"process_count: {jax.process_count()}")
    work_unit.create_artifact(platform.ArtifactType.DIRECTORY, FLAGS.workdir,
                              "workdir")
    train_lib.train_and_evaluate(FLAGS.ml_config, FLAGS.workdir)
def main(argv):
    """Train with the library selected by ``_CONFIG.value.mode``.

    ``TrainingMode.PRETRAINING`` uses ``run_pretraining`` and
    ``TrainingMode.CLASSIFICATION`` uses ``run_classifier``.

    Raises:
      ValueError: If the configured mode is neither of the above.
    """
    del argv  # Unused.

    # Keep TensorFlow away from the GPUs so it cannot reserve memory that JAX
    # needs.
    tf.config.experimental.set_visible_devices([], "GPU")

    logging.info("JAX host: %d / %d", jax.process_index(), jax.process_count())
    logging.info("JAX devices: %r", jax.devices())

    workdir = _WORKDIR.value
    # Label the task with its JAX process index (task 0 need not be host 0).
    work_unit = platform.work_unit()
    work_unit.set_task_status(
        f"host_id: {jax.process_index()}, host_count: {jax.process_count()}")
    work_unit.create_artifact(platform.ArtifactType.DIRECTORY, workdir,
                              "workdir")

    config = _CONFIG.value
    mode = config.mode
    if mode == TrainingMode.PRETRAINING:
        trainer_module = run_pretraining
    elif mode == TrainingMode.CLASSIFICATION:
        trainer_module = run_classifier
    else:
        raise ValueError("Unknown mode: %s" % mode)

    trainer_module.train_and_evaluate(config, workdir, _VOCAB_FILEPATH.value)
def _end_session(self, url: Optional[str]):
    """Close the current profiling session and publish its URL, if any.

    Args:
      url: Location of the captured profile, or ``None`` when no profile URL
        is available.
    """
    # A URL artifact without a URL carries no information, so only publish
    # one when a URL exists; the session state is reset either way.
    if url is not None:
        platform.work_unit().create_artifact(
            platform.ArtifactType.URL,
            url,
            description=self._artifact_name.format(step=self._previous_step))
    self._session_running = False
    self._session_started = None
def _end_session(self, url: Optional[str]):
    """Close the current profiling session and publish its URL, if any.

    Args:
      url: Location of the captured profile, or ``None`` when no profile URL
        is available.
    """
    # A URL artifact without a URL carries no information, so only publish
    # one when a URL exists; the session state is reset either way.
    if url is not None:
        platform.work_unit().create_artifact(
            platform.ArtifactType.URL,
            url,
            description=f"[{self._previous_step}] Profile")
    self._session_running = False
    self._session_started = None
def _end_session(self):
    """Stop the profiler and publish the returned URL as an artifact.

    Nothing is published when ``profiler.stop()`` returns ``None``; the
    session is marked as not running in both cases.
    """
    profile_url = profiler.stop()
    if profile_url is not None:
        label = f"[{self._previous_step}] Profile"
        platform.work_unit().create_artifact(
            platform.ArtifactType.URL, profile_url, description=label)
    self._session_running = False
def _apply(self, step: int, t: float):
    """Publish a progress line (percent done, throughput, ETA) for ``step``.

    The message is written to the work unit's notes, and ``steps_per_sec`` is
    also recorded as a scalar when a metric writer is configured.

    Args:
      step: The current training step.
      t: The current time, on the same clock as ``self._previous_time``.
    """
    elapsed = t - self._previous_time
    if elapsed <= 0:
        # No measurable time since the previous report: a rate cannot be
        # computed, so skip this report rather than divide by zero.
        return
    steps_per_sec = (step - self._previous_step) / elapsed
    if steps_per_sec:
        eta_seconds = (self._num_train_steps - step) / steps_per_sec
    else:
        # No steps were taken since the previous report: the ETA is unbounded.
        eta_seconds = float("inf")
    message = (f"{100 * step / self._num_train_steps:.1f}% @{step}, "
               f"{steps_per_sec:.1f} steps/s, ETA: {eta_seconds / 60:.0f} min")
    if self._time_per_part:
        total = time.time() - self._t0
        breakdown = ", ".join(
            f"{100 * dt / total:.1f}% {name}"
            for name, dt in sorted(self._time_per_part.items()))
        message += " ({:.0f} min : {})".format(total / 60, breakdown)
    # This should be relatively cheap so we can do it in the same main thread.
    platform.work_unit().set_notes(message)
    if self._writer is not None:
        self._writer.write_scalars(step, {"steps_per_sec": steps_per_sec})
def main(*_args, **_kwargs):
    """Train and evaluate ``FLAGS.config``, writing to ``--work_unit_dir``.

    All positional and keyword arguments are ignored.
    """
    # Keep TensorFlow away from the GPUs so it cannot reserve memory that JAX
    # needs.
    tf.config.experimental.set_visible_devices([], 'GPU')

    # Lazy %-style arguments; the rendered text matches the eager f-strings.
    logging.info('JAX process: %s / %s', jax.process_index(),
                 jax.process_count())
    logging.info('JAX local devices: %s', jax.local_devices())

    work_unit_dir = FLAGS.work_unit_dir
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         work_unit_dir, 'work_unit_dir')
    train.train_and_evaluate(FLAGS.config, work_unit_dir)
def main(executable_dict, argv):
    """Set up the job and run the executable named by ``--executable_name``.

    Args:
      executable_dict: Mapping from executable name to a callable. The chosen
        callable is invoked with ``config``, ``experiment_dir``,
        ``work_unit_dir``, ``rng`` and any keyword arguments decoded from
        ``--extra_args_json_str``.
      argv: Unused positional arguments.
    """
    del argv
    work_unit = platform.work_unit()
    tf.enable_v2_behavior()
    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    # jax.host_id()/jax.host_count() are deprecated aliases of
    # jax.process_index()/jax.process_count(); the logged text is unchanged.
    host_id = jax.process_index()
    host_count = jax.process_count()
    logging.info('JAX host: %d / %d', host_id, host_count)
    logging.info('JAX devices: %r', jax.devices())
    work_unit.set_task_status(f'host_id: {host_id}, host_count: {host_count}')

    # Read configuration: a JSON file takes precedence over --config.
    if FLAGS.config_json:
        logging.info('Reading config from JSON: %s', FLAGS.config_json)
        with tf.io.gfile.GFile(FLAGS.config_json, 'r') as f:
            config = ml_collections.ConfigDict(json.load(f))
    else:
        config = FLAGS.config
    logging.info('config=%s',
                 config.to_json_best_effort(indent=4, sort_keys=True))

    # Register output directories as artifacts when they are set.
    if FLAGS.experiment_dir:
        work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                                  FLAGS.experiment_dir, 'experiment_dir')
    if FLAGS.work_unit_dir:
        work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                                  FLAGS.work_unit_dir, 'work_unit_dir')
    logging.info('experiment_dir=%s work_unit_dir=%s', FLAGS.experiment_dir,
                 FLAGS.work_unit_dir)

    # Seed Python/NumPy with a per-host value and fold the host index into the
    # JAX key, so each host draws different random numbers.
    host_seed = config.seed * host_count + host_id
    random.seed(host_seed)
    onp.random.seed(host_seed)
    rng = utils.RngGen(
        jax.random.fold_in(jax.random.PRNGKey(config.seed), host_id))

    # Run the main function.
    logging.info('Running executable: %s', FLAGS.executable_name)
    extra_args = {}
    if FLAGS.extra_args_json_str:
        extra_args = json.loads(FLAGS.extra_args_json_str)
        logging.info('Extra args passed in: %r', extra_args)
    executable_dict[FLAGS.executable_name](
        config=config,
        experiment_dir=FLAGS.experiment_dir,
        work_unit_dir=FLAGS.work_unit_dir,
        rng=rng,
        **extra_args)
    utils.barrier()
def main(argv):
    """Train and evaluate ``FLAGS.config``, writing outputs to ``--workdir``."""
    del argv  # Unused.

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    # jax.host_id()/jax.host_count() are deprecated aliases of
    # jax.process_index()/jax.process_count(); the logged text is unchanged.
    process_index = jax.process_index()
    process_count = jax.process_count()
    logging.info("JAX host: %d / %d", process_index, process_count)
    logging.info("JAX devices: %r", jax.devices())

    # Add a note so that we can tell which task is which JAX host.
    # (Borg task 0 is not guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f"host_id: {process_index}, host_count: {process_count}")
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.workdir, "workdir")

    # The returned training state is not needed here.
    train_and_evaluate(FLAGS.config, FLAGS.workdir)
def main(argv: Sequence[str]) -> None:
    """Load a JSON config, apply command-line overrides and run ``train``.

    The config comes from ``--config_file`` if set, otherwise from the JSON
    string in ``--config``. Process 0 writes the config as loaded (before the
    ``model_dir`` entry and flag overrides are added) to
    ``<model_dir>/config.json``, replacing any existing file.

    Args:
      argv: Positional arguments left after flag parsing; only the program
        name may be present.

    Raises:
      app.UsageError: If extra positional arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    # Keep TensorFlow away from the GPUs so it cannot reserve memory that JAX
    # needs.
    tf.config.experimental.set_visible_devices([], "GPU")

    logging.info("JAX process: %d / %d", jax.process_index(),
                 jax.process_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    model_dir = FLAGS.model_dir
    # Label the task with its JAX process index; depending on the platform,
    # task 0 is not necessarily JAX process 0.
    work_unit = platform.work_unit()
    work_unit.set_task_status(f"process_index: {jax.process_index()}, "
                              f"process_count: {jax.process_count()}")
    work_unit.create_artifact(platform.ArtifactType.DIRECTORY, model_dir,
                              "model_dir")
    tf.io.gfile.makedirs(model_dir)

    if FLAGS.config_file:
        with tf.io.gfile.GFile(FLAGS.config_file, "r") as reader:
            config = json.load(reader)
    else:
        config = json.loads(FLAGS.config)

    # Only process 0 writes the file; it always overwrites config.json.
    if jax.process_index() == 0:
        saved_path = os.path.join(model_dir, "config.json")
        with tf.io.gfile.GFile(saved_path, "w") as writer:
            writer.write(json.dumps(config, indent=4))

    config["model_dir"] = model_dir
    # Flags left unset (None) keep the value from the loaded config.
    for key in ("learning_rate", "per_device_batch_size", "num_train_steps",
                "warmup_steps"):
        override = getattr(FLAGS, key)
        if override is not None:
            config[key] = override

    train(ml_collections.ConfigDict(config))
def main(argv: Sequence[str]) -> None:
    """Load a JSON config and run ``generate`` with outputs in ``--output_dir``.

    The config comes from ``--config_file`` if set, otherwise from the JSON
    string in ``--config``. Process 0 writes the config as loaded to
    ``<output_dir>/config.json``, replacing any existing file. If the config
    has no ``num_total_memories`` entry, it is filled in with
    ``get_num_total_memories``.

    Args:
      argv: Positional arguments left after flag parsing; only the program
        name may be present.

    Raises:
      app.UsageError: If extra positional arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Keep TensorFlow away from the GPUs so it cannot reserve memory that JAX
    # needs.
    tf.config.experimental.set_visible_devices([], 'GPU')

    logging.info('JAX process: %d / %d', jax.process_index(),
                 jax.process_count())
    logging.info('JAX local devices: %r', jax.local_devices())
    logging.info('JAX total devices: %r', jax.device_count())

    output_dir = FLAGS.output_dir
    # Label the task with its JAX process index; depending on the platform,
    # task 0 is not necessarily JAX process 0.
    work_unit = platform.work_unit()
    work_unit.set_task_status(f'process_index: {jax.process_index()}, '
                              f'process_count: {jax.process_count()}')
    work_unit.create_artifact(platform.ArtifactType.DIRECTORY, output_dir,
                              'output_dir')
    tf.io.gfile.makedirs(output_dir)

    if FLAGS.config_file:
        with tf.io.gfile.GFile(FLAGS.config_file, 'r') as reader:
            config = json.load(reader)
    else:
        config = json.loads(FLAGS.config)

    # Only process 0 writes the file; it always overwrites config.json.
    if jax.process_index() == 0:
        saved_path = os.path.join(output_dir, 'config.json')
        with tf.io.gfile.GFile(saved_path, 'w') as writer:
            writer.write(json.dumps(config, indent=4))

    config['output_dir'] = output_dir
    if 'num_total_memories' not in config:
        config['num_total_memories'] = get_num_total_memories(
            ml_collections.ConfigDict(config))

    generate(ml_collections.ConfigDict(config))
def main(argv):
    """Train and evaluate ``_CONFIG.value``, writing to ``_WORKDIR.value``.

    Per the original author, this example targets single-host training on a
    single device; that restriction is not checked here.

    Args:
      argv: Positional arguments left after flag parsing; only the program
        name may be present.

    Raises:
      app.UsageError: If extra positional arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Keep TensorFlow away from the GPUs so it cannot reserve memory that JAX
    # needs.
    tf.config.experimental.set_visible_devices([], 'GPU')

    logging.info('JAX host: %d / %d', jax.process_index(), jax.process_count())
    logging.info('JAX local devices: %r', jax.local_devices())

    workdir = _WORKDIR.value
    # Label the task with its JAX process index; depending on the platform,
    # task 0 is not necessarily JAX process 0.
    work_unit = platform.work_unit()
    work_unit.set_task_status(f'process_index: {jax.process_index()}, '
                              f'process_count: {jax.process_count()}')
    work_unit.create_artifact(platform.ArtifactType.DIRECTORY, workdir,
                              'workdir')

    train.train_and_evaluate(_CONFIG.value, workdir)
def main(argv):
    """Train and evaluate ``FLAGS.config``, writing outputs to ``--workdir``.

    Args:
      argv: Positional arguments left after flag parsing; only the program
        name may be present.

    Raises:
      app.UsageError: If extra positional arguments are given or ``--workdir``
        is unset.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")
    # `flags.mark_flags_as_required` registers a validator that only runs when
    # flags are parsed (or later changed). Inside `main`, parsing has already
    # happened, so registering it here never rejects a missing --workdir.
    # Check the requirement explicitly instead.
    if FLAGS.workdir is None:
        raise app.UsageError(
            "Flag --workdir must have a value other than None.")

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    logging.info("JAX process: %d / %d", jax.process_index(),
                 jax.process_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform task 0 is not guaranteed to be host 0)
    platform.work_unit().set_task_status(f"process_index: {jax.process_index()}, "
                                         f"process_count: {jax.process_count()}")
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.workdir, "workdir")

    train_and_evaluate(FLAGS.config, FLAGS.workdir)
def main(argv):
    """Train and evaluate ``FLAGS.config``, writing outputs to ``--workdir``.

    When ``--jax_backend_target`` is set, it is logged together with the XLA
    backend flag.
    """
    del argv  # Unused.

    # Keep TensorFlow away from the GPUs so it cannot reserve memory that JAX
    # needs.
    tf.config.experimental.set_visible_devices([], "GPU")

    if FLAGS.jax_backend_target:
        logging.info("Using JAX backend target %s", FLAGS.jax_backend_target)
        xla_backend = FLAGS.jax_xla_backend
        if xla_backend is None:
            xla_backend = "None"
        logging.info("Using JAX XLA backend %s", xla_backend)

    logging.info("JAX host: %d / %d", jax.process_index(), jax.process_count())
    logging.info("JAX devices: %r", jax.devices())

    work_unit = platform.work_unit()
    work_unit.set_task_status(f"process_index: {jax.process_index()}, "
                              f"process_count: {jax.process_count()}")
    work_unit.create_artifact(platform.ArtifactType.DIRECTORY, FLAGS.workdir,
                              "workdir")

    train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
def main(argv):
    """Train and evaluate ``FLAGS.config``, writing outputs to ``--workdir``.

    When ``--jax_backend_target`` is set, JAX is configured to use the
    ``tpu_driver`` XLA backend with that target.

    Args:
      argv: Positional arguments left after flag parsing; only the program
        name may be present.

    Raises:
      app.UsageError: If extra positional arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    if FLAGS.jax_backend_target:
        logging.info("Using JAX backend target %s", FLAGS.jax_backend_target)
        jax.config.update("jax_xla_backend", "tpu_driver")
        jax.config.update("jax_backend_target", FLAGS.jax_backend_target)

    # jax.host_id()/jax.host_count() are deprecated aliases of
    # jax.process_index()/jax.process_count(); the logged text is unchanged.
    process_index = jax.process_index()
    process_count = jax.process_count()
    logging.info("JAX host: %d / %d", process_index, process_count)
    logging.info("JAX local devices: %r", jax.local_devices())

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform task 0 is not guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f"host_id: {process_index}, host_count: {process_count}")
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.workdir, "workdir")

    train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
def main(_):
    """Train a NeRF model with periodic validation, rendering and checkpoints.

    Configured by ``FLAGS.config`` plus ``--data_dir``, ``--model_dir`` and
    ``--seed``. Training resumes from the checkpoint restored from
    ``--model_dir`` (via ``checkpoints.restore_checkpoint``) when one exists.

    Raises:
      ValueError: If ``precrop_iters`` is combined with ``batching``, or if
        ``down_factor`` or ``render_factor`` is not positive.
    """
    config = FLAGS.config
    if config.precrop_iters > 0 and config.batching:
        raise ValueError(
            "'precrop_iters' has no effect when 'batching' the dataset")
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not (config.down_factor > 0 and config.render_factor > 0):
        raise ValueError(
            "'down_factor' and 'render_factor' must both be positive")

    logging.info("JAX host: %d / %d", jax.process_index(), jax.process_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    platform.work_unit().set_task_status(
        f"host_id: {jax.process_index()}, host_count: {jax.process_count()}")
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.model_dir, "model_dir")
    os.makedirs(FLAGS.model_dir, exist_ok=True)

    rng = jax.random.PRNGKey(FLAGS.seed)
    _, rng_coarse, rng_fine, data_rng, step_rng = jax.random.split(rng, 5)
    # Mapped over axis 0 by the pmapped train step (see its in_axes below).
    rngs = common_utils.shard_prng_key(step_rng)

    ### Load dataset and data values
    datasets, counts, optics, render_datasets = get_dataset(
        FLAGS.data_dir, config, rng=data_rng, num_poses=config.num_poses)
    train_ds, val_ds, test_ds = datasets
    *_, test_items = counts
    hwf, r_hwf, near, far = optics
    render_ds, render_vdirs_ds, num_poses = render_datasets
    img_h, img_w, _ = hwf

    logging.info("Num poses: %d", num_poses)
    logging.info("Splits: train - %d, val - %d, test - %d", *counts)
    logging.info("Images: height %d, width %d, focal %.5f", *hwf)
    logging.info("Render: height %d, width %d, focal %.5f", *r_hwf)

    ### Init model parameters and optimizer
    initialized_ = functools.partial(initialized, model_config=config.model)
    pts_shape = (config.num_rand, config.num_samples, 3)
    views_shape = (config.num_rand, 3)
    model_coarse, params_coarse = initialized_(rng_coarse, pts_shape,
                                               views_shape)

    schedule_fn = optax.exponential_decay(
        init_value=config.learning_rate,
        transition_steps=config.lr_decay * 1000,
        decay_rate=config.decay_factor,
    )
    tx = optax.adam(learning_rate=schedule_fn)
    state = train_state.TrainState.create(
        apply_fn=(model_coarse.apply, None),
        params={"coarse": params_coarse},
        tx=tx)

    if config.num_importance > 0:
        pts_shape = (
            config.num_rand,
            config.num_importance + config.num_samples,
            3,
        )
        model_fine, params_fine = initialized_(rng_fine, pts_shape,
                                               views_shape)
        state = train_state.TrainState.create(
            apply_fn=(model_coarse.apply, model_fine.apply),
            params={
                "coarse": params_coarse,
                "fine": params_fine
            },
            tx=tx,
        )

    state = checkpoints.restore_checkpoint(FLAGS.model_dir, state)
    start_step = int(state.step)
    # When resuming, already-seen training examples are not skipped. Doing so
    # would make the data order deterministic but is slow for large
    # start_step.
    state = jax.device_put_replicated(state, jax.local_devices())

    ### Build "pmapped" functions for distributed training
    train_fn = functools.partial(train_step, near, far, config, schedule_fn)
    p_train_step = jax.pmap(
        train_fn,
        axis_name="batch",
        in_axes=(0, 0, None, 0),
    )

    def render_fn(state, rays):
        step_fn = functools.partial(eval_step, config, near, far, state)
        return lax.map(step_fn, rays)

    p_eval_step = jax.pmap(render_fn, axis_name="batch")

    # TODO: add hparams
    writer = metric_writers.create_default_writer(
        FLAGS.model_dir, just_logging=jax.process_index() > 0)
    logging.info("Starting training loop.")

    hooks = []
    profiler = periodic_actions.Profile(num_profile_steps=5,
                                        logdir=FLAGS.model_dir)
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [profiler, report_progress]
    train_metrics = []
    gen_video_ = functools.partial(gen_video, FLAGS.model_dir)

    for step in range(start_step, config.num_steps + 1):
        is_last_step = step == config.num_steps

        batch = next(train_ds)
        coords = None
        if not config.batching:
            coords = jnp.meshgrid(jnp.arange(img_h), jnp.arange(img_w),
                                  indexing="ij")
            if step < config.precrop_iters:
                # Early on, sample rays only from a central crop.
                dH = int(img_h // 2 * config.precrop_frac)
                dW = int(img_w // 2 * config.precrop_frac)
                coords = jnp.meshgrid(
                    jnp.arange(img_h // 2 - dH, img_h // 2 + dH),
                    jnp.arange(img_w // 2 - dW, img_w // 2 + dW),
                    indexing="ij",
                )
            coords = jnp.stack(coords, axis=-1).reshape([-1, 2])

        with jax.profiler.StepTraceAnnotation("train", step_num=step):
            state, metrics = p_train_step(batch, state, coords, rngs)
            train_metrics.append(metrics)

        logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
        for hook in hooks:
            hook(step)

        ### Write train summaries to TB
        if step % config.i_print == 0 or is_last_step:
            with report_progress.timed("training_metrics"):
                train_metrics = common_utils.get_metrics(train_metrics)
                train_summary = jax.tree_util.tree_map(lambda x: x.mean(),
                                                       train_metrics)
                summary = {f"train/{k}": v for k, v in train_summary.items()}
                writer.write_scalars(step, summary)
            train_metrics = []

        ### Eval a random validation image and plot it to TB
        if step % config.i_img == 0 and step > 0 or is_last_step:
            with report_progress.timed("validation"):
                inputs = next(val_ds)
                rays, padding = prepare_render_data(inputs["rays"].numpy())
                outputs = p_eval_step(state, rays)
                preds, preds_c, z_std = jax.tree_util.tree_map(
                    lambda x: to_np(x, hwf, padding), outputs)
                loss = np.mean((preds["rgb"] - inputs["image"])**2)
                summary = {"val/loss": loss, "val/psnr": psnr_fn(loss)}
                writer.write_scalars(step, summary)

                summary = {
                    "val/rgb": to_rgb(preds["rgb"]),
                    "val/target": to_np(inputs["image"], hwf, padding),
                    "val/disp": disp_post(preds["disp"], config),
                    "val/acc": preds["acc"],
                }
                if config.num_importance > 0:
                    summary["val/rgb_c"] = to_rgb(preds_c["rgb"])
                    summary["val/disp_c"] = disp_post(preds_c["disp"], config)
                    summary["val/z_std"] = z_std
                writer.write_images(step, summary)

        ### Render a video with test poses
        if step % config.i_video == 0 and step > 0:
            with report_progress.timed("video_render"):
                logging.info("Rendering video at step %d", step)
                rgb_list = []
                disp_list = []
                # Build a fresh zip on every render: a zip object is a one-shot
                # iterator, so a shared one is empty after the first render and
                # np.stack([]) then fails.
                for _pose_idx, inputs in tqdm(zip(range(num_poses), render_ds),
                                              desc="Rays render"):
                    rays, padding = prepare_render_data(inputs["rays"].numpy())
                    preds, *_ = p_eval_step(state, rays)
                    preds = jax.tree_util.tree_map(
                        lambda x: to_np(x, r_hwf, padding), preds)
                    rgb_list.append(preds["rgb"])
                    disp_list.append(preds["disp"])

                gen_video_(np.stack(rgb_list), "rgb", r_hwf, step)
                disp = np.stack(disp_list)
                gen_video_(disp_post(disp, config), "disp", r_hwf, step, ch=1)

                if config.use_viewdirs:
                    rgb_list = []
                    for _pose_idx, inputs in tqdm(
                            zip(range(num_poses), render_vdirs_ds),
                            desc="Viewdirs render"):
                        rays, padding = prepare_render_data(
                            inputs["rays"].numpy())
                        preds, *_ = p_eval_step(state, rays)
                        rgb_list.append(to_np(preds["rgb"], r_hwf, padding))
                    gen_video_(np.stack(rgb_list), "rgb_still", r_hwf, step)

        ### Save images in the test set
        if step % config.i_testset == 0 and step > 0:
            with report_progress.timed("test_render"):
                logging.info("Rendering test set at step %d", step)
                test_losses = []
                # Fresh zip for the same one-shot-iterator reason as above.
                for idx, inputs in tqdm(zip(range(test_items), test_ds),
                                        desc="Test render"):
                    rays, padding = prepare_render_data(inputs["rays"].numpy())
                    preds, *_ = p_eval_step(state, rays)
                    save_test_imgs(FLAGS.model_dir, preds["rgb"], r_hwf, step,
                                   idx)
                    # NOTE(review): unreachable while render_factor is required
                    # to be positive above; confirm which one is intended.
                    if config.render_factor == 0:
                        loss = np.mean((preds["rgb"] - inputs["image"])**2.0)
                        test_losses.append(loss)
                if config.render_factor == 0:
                    loss = np.mean(test_losses)
                    summary = {"test/loss": loss, "test/psnr": psnr_fn(loss)}
                    writer.write_scalars(step, summary)
        writer.flush()

        ### Save ckpt
        if step % config.i_weights == 0 or is_last_step:
            with report_progress.timed("checkpoint"):
                save_checkpoint(state, FLAGS.model_dir)
def main(executable_dict, argv):
    """Set up the job and run the executable named by ``--executable_name``.

    ``--config_json`` (if set) replaces ``--config``. ``--query`` and
    ``--seed`` (if set) override the config's ``query`` and ``seed``. When
    ``--work_unit_dir`` is empty, a timestamped directory under
    ``--experiment_dir`` is chosen.

    Args:
      executable_dict: Mapping from executable name to a callable. The chosen
        callable is invoked with ``config``, ``experiment_dir``,
        ``work_unit_dir``, ``rng`` and any keyword arguments decoded from
        ``--extra_args_json_str``.
      argv: Unused positional arguments.
    """
    del argv
    work_unit = platform.work_unit()
    tf.enable_v2_behavior()
    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    # jax.host_id()/jax.host_count() are deprecated aliases of
    # jax.process_index()/jax.process_count(); the logged text is unchanged.
    host_id = jax.process_index()
    n_host = jax.process_count()
    logging.info('JAX host: %d / %d', host_id, n_host)
    logging.info('JAX devices: %r', jax.devices())
    # Add a note so that we can tell which task is which JAX host.
    # (task 0 is not guaranteed to be host 0)
    work_unit.set_task_status(f'host_id: {host_id}, host_count: {n_host}')

    # Read configuration: a JSON file takes precedence over --config.
    if FLAGS.config_json:
        logging.info('Reading config from JSON: %s', FLAGS.config_json)
        with gfile.GFile(FLAGS.config_json, 'r') as f:
            config = ml_collections.ConfigDict(json.load(f))
    else:
        config = FLAGS.config

    # Set query
    if FLAGS.query:
        config.query = FLAGS.query

    # Make output directories
    work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                              FLAGS.experiment_dir, 'experiment_dir')
    if not FLAGS.work_unit_dir:
        timestr = time.strftime('%Y%m%d-%H%M%S')
        FLAGS.work_unit_dir = os.path.join(FLAGS.experiment_dir,
                                           f"'{FLAGS.query}' {timestr}")
    work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                              FLAGS.work_unit_dir, 'work_unit_dir')
    logging.info('experiment_dir=%s work_unit_dir=%s', FLAGS.experiment_dir,
                 FLAGS.work_unit_dir)

    # Seeding: a per-host seed so every host draws different random numbers.
    if FLAGS.seed is not None:
        config.seed = FLAGS.seed
    host_seed = config.seed * n_host + host_id
    random.seed(host_seed)
    onp.random.seed(host_seed)
    logging.debug('setting up RNG...')
    key = jax.random.fold_in(jax.random.PRNGKey(config.seed), host_id)
    rng = helpers.RngGen(key)
    logging.debug('done setting up RNG')

    # Log config
    logging.info('config=%s',
                 config.to_json_best_effort(indent=4, sort_keys=True))

    # Run the main function
    logging.info('Running executable: %s', FLAGS.executable_name)
    extra_args = {}
    if FLAGS.extra_args_json_str:
        extra_args = json.loads(FLAGS.extra_args_json_str)
        logging.info('Extra args passed in: %r', extra_args)
    executable_dict[FLAGS.executable_name](
        config=config,
        experiment_dir=FLAGS.experiment_dir,
        work_unit_dir=FLAGS.work_unit_dir,
        rng=rng,
        **extra_args)
def main(argv: Sequence[str]) -> None:
    """Configure JAX and run the job selected by ``--mode``.

    Supported modes are ``train``, ``eval``, ``decode`` (continuous decoding)
    and ``decode_once`` (decode a single checkpoint from
    ``--restore_checkpoint_dir``).

    Args:
      argv: Positional arguments left after flag parsing; only the program
        name may be present.

    Raises:
      app.UsageError: If extra positional arguments are given.
      ValueError: If ``--mode=decode_once`` lacks ``--restore_checkpoint_dir``,
        or ``--mode`` is not one of the supported modes.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    setup_jax(FLAGS.globally_use_hardware_rng,
              FLAGS.jax_parallel_functions_output_gda,
              FLAGS.jax_backend_target, FLAGS.jax_xla_backend)

    # Add a note so that we can tell which Borg task is which JAX host.
    # (Borg task 0 is not guaranteed to be host 0)
    # Above 128 processes, first wait a random 0-60 s (presumably to spread
    # out the work-unit calls below; confirm).
    if jax.process_count() > 128:
        wait_with_random_jitter(min_secs=0, max_secs=60)
    work_unit = platform.work_unit()
    work_unit.set_task_status(f'process_index: {jax.process_index()}, '
                              f'process_count: {jax.process_count()}')
    work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                              FLAGS.job_log_dir, 'job_log_dir')

    # Start jax.profiler for Tensorboard and profiling in open source.
    # The returned server object is kept bound for the rest of main.
    if FLAGS.jax_profiler_port is not None:
        server = jax.profiler.start_server(FLAGS.jax_profiler_port)  # pylint:disable=unused-variable

    if FLAGS.mode == 'train':
        train.train_and_evaluate(
            model_name=FLAGS.model,
            job_log_dir=FLAGS.job_log_dir,
            multi_host_checkpointing=FLAGS.multi_host_checkpointing,
            maybe_use_persistence_checkpointing=(
                FLAGS.maybe_use_persistence_checkpointing),
            restore_checkpoint_dir=FLAGS.restore_checkpoint_dir,
            restore_checkpoint_step=FLAGS.restore_checkpoint_step,
            eval_on_test=FLAGS.eval_on_test,
            checkpoint_todelete_subdir=FLAGS.checkpoint_todelete_subdir)
    elif FLAGS.mode == 'eval':
        eval_lib.evaluate(
            model_name=FLAGS.model,
            job_log_dir=FLAGS.job_log_dir,
            multi_host_checkpointing=FLAGS.multi_host_checkpointing,
            maybe_use_persistence_checkpointing=(
                FLAGS.maybe_use_persistence_checkpointing))
    elif FLAGS.mode == 'decode':
        eval_lib.decode(
            model_name=FLAGS.model,
            job_log_dir=FLAGS.job_log_dir,
            multi_host_checkpointing=FLAGS.multi_host_checkpointing,
            maybe_use_persistence_checkpointing=(
                FLAGS.maybe_use_persistence_checkpointing),
            restore_checkpoint_dir=None,
            restore_checkpoint_step=None,
            continuous_decode=True,
        )
    elif FLAGS.mode == 'decode_once':
        if not FLAGS.restore_checkpoint_dir:
            raise ValueError(
                '--mode=decode_once requires --restore_checkpoint_dir.')
        eval_lib.decode(
            model_name=FLAGS.model,
            job_log_dir=FLAGS.job_log_dir,
            multi_host_checkpointing=FLAGS.multi_host_checkpointing,
            maybe_use_persistence_checkpointing=(
                FLAGS.maybe_use_persistence_checkpointing),
            restore_checkpoint_dir=FLAGS.restore_checkpoint_dir,
            restore_checkpoint_step=FLAGS.restore_checkpoint_step,
            continuous_decode=False,
        )
    else:
        # Previously an unrecognized mode fell through and exited successfully
        # without doing anything.
        raise ValueError(f'Unknown mode: {FLAGS.mode}')