def __init__(self, cpt_dir: str, refresh: bool = True):
    """Initializes the checkpoint reader.

    Args:
      cpt_dir: The checkpoint directory.
      refresh: Whether to refresh the checkpoints from disk.
    """
    cpt_dir = fs.normpath(fs.abspath(cpt_dir))
    self.pers_cpt_dir = fs.join(cpt_dir, "persistent")
    self.tmp_cpt_dir = fs.join(cpt_dir, "temp")
    self.tmp_cpts = []
    self.pers_cpts = []
    if refresh:
        self.refresh()
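# Usage sketch (hypothetical run directory; assumes the enclosing class is
# CheckpointManager, as referenced in save_state below):
#
#   manager = CheckpointManager("/tmp/run1/cpt")
#   # manager.pers_cpt_dir -> /tmp/run1/cpt/persistent  (long-lived cpts)
#   # manager.tmp_cpt_dir  -> /tmp/run1/cpt/temp        (rolling cpts)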
def save_state(self, state: bytes, step: int, persistent: bool = False):
    if persistent:
        save_dir = self.pers_cpt_dir
        cpt_collection = self.pers_cpts
    else:
        save_dir = self.tmp_cpt_dir
        cpt_collection = self.tmp_cpts

    # Save in two stages to avoid corruption: write under a temporary name
    # first, then atomically rename into the final checkpoint name.
    CM = CheckpointManager
    temp_path = fs.join(save_dir, f"temporary_state.{step:09}{CM._SUFFIX}")
    fs.write_bytes(temp_path, state)
    save_path = fs.join(save_dir, f"{CM._PREFIX}{step:09}{CM._SUFFIX}")
    fs.rename_file(temp_path, save_path)

    cpt_collection.append(_CheckPoint(save_path, step))
    self.cleanup_temporary_checkpoints()
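# Sketch of why the two-stage save is safe (hypothetical caller variable
# `manager`; assumes _PREFIX/_SUFFIX as defined on CheckpointManager): the
# state is first written as "temporary_state.<step><suffix>", a name that
# neither the checkpoint glob nor the anchored regex in _get_checkpoints
# matches, and only then renamed to "<prefix><step><suffix>". A crash
# mid-write therefore never leaves a file that looks like a valid checkpoint:
#
#   manager.save_state(encoded_state, step=1000)                   # -> temp/
#   manager.save_state(encoded_state, step=5000, persistent=True)  # -> persistent/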
def _write_image(self, scene_id: str, scene_images: List[t.Tensor]):
    scene_id = scene_id.replace("/", "_")
    image = t.cat(scene_images, dim=0)
    pil_image = PIL.Image.fromarray(image.cpu().numpy())
    pil_image.save(fl := io.BytesIO(), format="png")
    fn = fs.join(self.image_output_dir, f"img_{scene_id}.png")
    fs.make_dirs(fs.dirname(fn))
    fs.write_bytes(fn, fl.getvalue())
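# _write_image encodes the PNG into an in-memory buffer and writes the raw
# bytes through the fs layer instead of calling pil_image.save(fn) directly,
# presumably so the same code works when fs is backed by a non-local file
# system. A minimal standalone sketch of the pattern (illustration only):
#
#   buf = io.BytesIO()
#   pil_image.save(buf, format="png")
#   raw_png: bytes = buf.getvalue()  # plain bytes, usable with any writer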
def run_eval(self, output_dir: str, global_step: int, progress_bar_desc: str):
    exit_stack = contextlib.ExitStack()
    if dist_util.info().global_rank == 0:
        progress_bar = ui.ProgressBar(desc=progress_bar_desc, leave=False)
        exit_stack.push(progress_bar)
    else:
        progress_bar = None

    with exit_stack:
        dataset = self.data_manager.create_dataset(local_seed=global_step)
        loader_config = self.config.data.data_loader
        data_loader = create_distributed_loader(
            dataset=dataset, loader_config=loader_config, pad_data=False)

        # Main evaluation loop
        progress_report_fn = ui.progress_bar_report_fn(
            progress_bar, progress_multiplier=loader_config.batch_size)
        progress = ui.DistributedProgress(
            report_progress_fn=progress_report_fn)
        qualitative_results = eval_results_lib.QualitativeResults(
            self.config, dataset, output_dir)
        quantitative_results = eval_results_lib.QuantitativeResults(
            dataset.classes, self.config)
        voxel_config = self.config.data.voxelization_config
        data_resolution = dataclasses.astuple(voxel_config.resolution)
        for batch in progress(data_loader):
            batch = batched_example.batch([v.cuda() for v in batch])
            batch = voxelize_batch(batch, voxel_config)
            with t.no_grad():
                pmf = self.inference_fn(batch.input_image,
                                        batch.camera_transform,
                                        batch.v2x_transform,
                                        batch.grid_sampling_offset,
                                        data_resolution)
            quantitative_results.add_batch(pmf, batch)
            qualitative_results.add_batch(pmf, batch)

        quantitative_results.compute_metrics()
        if dist_util.info().global_rank == 0:
            voxel_metrics_path = fs.join(output_dir, "voxel_metrics.csv")
            quantitative_results.write_csv(voxel_metrics_path)
            quantitative_results.write_tensor_board_summary(
                self.tb_writer, global_step)
            log.debug("Writing results to disk...")
            qualitative_results.write_tensor_board_summary(
                self.tb_writer, global_step)
            log.debug("Finished evaluating")
        t.distributed.barrier()
        if dist_util.info().global_rank == 0:
            return quantitative_results.get_mean_iou()
        else:
            return None
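# Rank-handling note: only global rank 0 owns the progress bar and writes the
# CSV/TensorBoard summaries; the barrier keeps the remaining ranks from
# running ahead while rank 0 flushes results. Consequently the mean IoU is
# returned on rank 0 only and every other rank returns None, which is why
# callers (see RecurrentEvals.run) guard on `iou is not None`.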
@classmethod
def _get_checkpoints(cls, cpt_dir: str) -> List[_CheckPoint]:
    result = fs.glob_pattern(
        fs.join(cpt_dir, f"{cls._PREFIX}*{cls._SUFFIX}"))
    regex = rf"^{cls._PREFIX}(\d+){cls._SUFFIX}$"
    result = [(path, re.match(regex, fs.basename(path))) for path in result]
    result = [_CheckPoint(path, int(m.group(1))) for path, m in result if m]
    result = sorted(result, key=lambda v: v.step)
    return result
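# Filename sketch with hypothetical constants _PREFIX = "state." and
# _SUFFIX = ".cpt" (the real values live on the class):
#
#   state.000001000.cpt            -> _CheckPoint(path, step=1000)
#   temporary_state.000001000.cpt  -> ignored: both the glob and the anchored
#                                     regex fail to match, so half-written
#                                     files from save_state are never listed.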
def __init__(self, eval_configs: List[configuration.RecurrentEvalConfig],
             state: state_lib.State, tb_root_dir: str, eval_root_dir: str):
    self.state = state
    self.eval_root_dir = eval_root_dir
    inference_fn = super_resolution.super_resolution_from_state(state)
    self.eval_runs = [
        RecurrentEvals._EvalRun(
            misc_util.StepEvent(cfg.start_step, cfg.interval), cfg,
            pipeline.EvalPipeline(
                cfg.config,
                inference_fn=inference_fn,
                tb_dir=fs.join(tb_root_dir, cfg.config.name)))
        for cfg in eval_configs if cfg.start_step >= 0
    ]
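# Note: eval configs with a negative start_step are filtered out of
# self.eval_runs here, which presumably is the mechanism for disabling an
# individual recurrent eval from the configuration.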
def load_from_npz(path: Text, meshes_dir: Text,
                  load_extra_fields: bool = False) -> Scene:
    """Loads an input example.

    Args:
      path: Path to the NPZ file containing the scene.
      meshes_dir: Path containing the ShapeNet meshes.
      load_extra_fields: Whether to load extra fields that are not required
        for running the pipeline (e.g. texture coordinates).

    Returns:
      The loaded input example.
    """
    scene_npz = NpzReader(path)
    mesh_paths = [
        fs.join(meshes_dir, *v) + ".npz" for v in zip(
            scene_npz.list("mesh_labels"), scene_npz.list("mesh_filenames"))
    ]
    result = Scene(
        mesh_vertices=[],
        view_transform=scene_npz.tensor("view_transform", t.float32),
        o2w_transforms=scene_npz.tensor("mesh_object_to_world_transforms",
                                        t.float32),
        camera_transform=scene_npz.tensor("camera_transform", t.float32),
        mesh_labels=list(scene_npz.list("mesh_labels")),
        opengl_image=_load_image(scene_npz.scalar("opengl_image")),
        pbrt_image=_load_image(scene_npz.scalar("pbrt_image")),
        mesh_visible_fractions=scene_npz.tensor("mesh_visible_fractions",
                                                t.float32),
    )
    for mesh_path in mesh_paths:
        # noinspection PyTypeChecker
        mesh_npz = NpzReader(mesh_path)
        result.mesh_vertices.append(mesh_npz.tensor("vertices", t.float32))
        if load_extra_fields:
            result.normals.append(mesh_npz.tensor("normals", t.float32))
            result.material_ids.append(
                mesh_npz.tensor("material_ids", t.int32))
            result.texcoords.append(mesh_npz.tensor("texcoords", t.float32))
            result.diffuse_colors.append(
                mesh_npz.tensor("diffuse_colors", t.float32))
            result.diffuse_texture_pngs.append(
                mesh_npz.scalar("diffuse_texture_pngs"))
    return result
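# Usage sketch (hypothetical paths):
#
#   scene = load_from_npz("/data/scenes/scene_0001.npz",
#                         meshes_dir="/data/shapenet_npz",
#                         load_extra_fields=True)
#   # Mesh NPZs are resolved as <meshes_dir>/<label>/<filename>.npz, one per
#   # entry in the scene's "mesh_labels"/"mesh_filenames" lists.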
def process_mesh(input_path: str, output_root: str):
    log.info(f"Processing {input_path}...")
    fn_parts = fs.splitall(input_path)
    label = fn_parts[-4]
    mesh_id = fn_parts[-3]
    mesh = read_obj(input_path)
    mesh = cleanup_mesh(mesh)
    npz_path = fs.join(output_root, label, mesh_id + ".npz")
    np.savez_compressed(fl := io.BytesIO(),
                        vertices=mesh,
                        label=label,
                        mesh_id=mesh_id)
    fs.make_dirs(fs.dirname(npz_path))
    fs.write_bytes(npz_path, fl.getvalue())
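# Index sketch: for the ShapeNet layout globbed in main() below,
# ".../<label>/<mesh_id>/models/model_normalized.obj", fs.splitall yields
# [..., <label>, <mesh_id>, "models", "model_normalized.obj"], so the label
# sits at parts[-4] and the mesh id at parts[-3].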
def run(self, prev_step: int, next_step: int, force: bool = False) -> bool:
    """Runs scheduled evals and returns True if any eval has run."""
    has_run = False
    for eval_run in self.eval_runs:
        should_run = force or eval_run.ev_run_eval.trigger(
            prev_step, next_step)
        if not should_run:
            continue
        eval_pipe = eval_run.eval_pipe
        state = self.state
        state.model.eval()
        name = eval_pipe.config.name
        desc = f"Eval, name={name}, step={state.global_step}"
        output_dir = fs.join(self.eval_root_dir, name,
                             f"{state.global_step:09}")
        iou = eval_pipe.run_eval(output_dir, state.global_step, desc)
        if iou is not None:
            log.info(
                f"Eval '{name}', step={state.global_step}, mIoU={iou:.3f}")
        has_run = True
    return has_run
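# Scheduling sketch (hypothetical numbers; assumes StepEvent.trigger fires
# when an interval boundary is crossed between prev_step and next_step): with
# StepEvent(0, 500), a train step advancing global_step from 999 to 1001
# crosses the 1000 boundary, so trigger(999, 1001) is True and that eval runs
# once; force=True runs every configured eval regardless of the schedule.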
def main():
    args = cmd.parse_flags(Args)
    sn_root_dir = fs.normpath(fs.abspath(args.shapenet_root))
    print("Reading mesh file names...")
    obj_files = sorted(
        fs.glob_pattern(
            fs.join(sn_root_dir, "*/*/models/model_normalized.obj")))
    out_dir = fs.normpath(fs.abspath(args.output_root))
    print(f"Converting {len(obj_files)} meshes from {sn_root_dir} to {out_dir}")

    ray.init()
    process_fn = ray.remote(process_mesh)
    tasks = [process_fn.remote(v, out_dir) for v in obj_files]

    # Poll the outstanding Ray tasks every 0.3 s and advance the progress bar
    # as they complete; the loop exits once every task has finished.
    progress_bar = tqdm.tqdm(total=len(tasks))
    while tasks:
        done, tasks = ray.wait(tasks, num_returns=len(tasks), timeout=0.3)
        progress_bar.update(len(done))
def main():
    dist_util.init()
    t.cuda.set_device(dist_util.info().local_rank)
    ui.initialize_logging()
    pipeline.jit_compile_cpp_modules()

    args = cmd_line_flags.parse_flags(ProgramArgs)
    config, original_config = pipeline.read_cmd_line_config(
        args, configuration.TrainPipeline)
    output_dir = fs.normpath(fs.abspath(config.output_path))
    tb_root_dir = fs.join(output_dir, "tb")
    eval_root_dir = fs.join(output_dir, "evals")
    cpt_dir = fs.join(output_dir, "cpt")

    train_pipe = pipeline.TrainPipeline(config.train,
                                        cpt_dir=cpt_dir,
                                        tb_dir=fs.join(tb_root_dir, "train"))
    state = train_pipe.create_or_load_state(
        extra_metadata=original_config.to_dict())
    recurrent_evals = RecurrentEvals(config.eval, state, tb_root_dir,
                                     eval_root_dir)

    max_steps = config.train.max_steps
    train_forever = max_steps < 0
    eta = (None if train_forever else
           misc_util.Eta(state.global_step, max_steps))
    train_pipe.switch_model_to_train()
    ev_save_temp_cpt = misc_util.StepEvent(0, config.train.checkpoint_interval)
    ev_save_pers_cpt = misc_util.StepEvent(
        0, config.train.persistent_checkpoint_interval)

    if dist_util.info().global_rank == 0:
        train_progress = ui.ProgressBar(
            desc="Training",
            bar_format=("{l_bar}{bar}| {n_fmt}/{total_fmt} "
                        "[{elapsed}, {rate_fmt}{postfix}]"),
            total=(max_steps if not train_forever else None),
            initial=state.global_step)
        bar_context = train_progress
    else:
        train_progress = None
        bar_context = contextlib.ExitStack()

    with bar_context:
        if train_progress:
            train_progress.unpause()
        while True:
            # Perform a training step
            prev_step = state.global_step
            loss = train_pipe.train_step()
            if train_progress:
                train_progress.postfix = f"loss={loss:.3f}"
                if eta:
                    train_progress.postfix += (
                        f", ETA {eta.cur_eta_str(state.global_step)}")
                train_progress.update(state.global_step - train_progress.n)
            next_step = state.global_step
            should_stop = not train_forever and next_step > max_steps

            # Save a checkpoint
            if dist_util.info().global_rank == 0:
                save_pers_cpt = (should_stop or ev_save_pers_cpt.trigger(
                    prev_step, next_step))
                if args.recurrent_evals:
                    save_pers_cpt = (save_pers_cpt or
                                     recurrent_evals.persistent_cpt(
                                         prev_step, next_step))
                save_tmp_cpt = ev_save_temp_cpt.trigger(prev_step, next_step)
                if save_tmp_cpt or save_pers_cpt:
                    train_pipe.cpt_manager.save_state(
                        state_lib.encode_state(state),
                        step=state.global_step,
                        persistent=save_pers_cpt)

            # Run evaluations
            if args.recurrent_evals or should_stop:
                eval_has_run = recurrent_evals.run(prev_step, next_step,
                                                   force=should_stop)
                if eval_has_run:
                    train_pipe.switch_model_to_train()
                    if train_progress:
                        train_progress.unpause()

            if should_stop:
                break