    def __init__(self, cpt_dir: str, refresh: bool = True):
        """Initializes the checkpoint reader.

        Args:
            cpt_dir: The checkpoint directory.
            refresh: Whether to refresh the checkpoints from disk.
        """
        cpt_dir = fs.normpath(fs.abspath(cpt_dir))
        self.pers_cpt_dir = fs.join(cpt_dir, "persistent")
        self.tmp_cpt_dir = fs.join(cpt_dir, "temp")
        self.tmp_cpts = []
        self.pers_cpts = []
        if refresh:
            self.refresh()

    def save_state(self, state: bytes, step: int, persistent: bool = False):
        """Saves serialized state for `step`.

        Persistent checkpoints are kept; temporary ones may be cleaned up later.
        """
        if persistent:
            save_dir = self.pers_cpt_dir
            cpt_collection = self.pers_cpts
        else:
            save_dir = self.tmp_cpt_dir
            cpt_collection = self.tmp_cpts

        # Write to a temporary file first, then rename it into place, so an
        # interrupted write never leaves a partially written checkpoint behind.
        CM = CheckpointManager
        temp_path = fs.join(save_dir, f"temporary_state.{step:09}{CM._SUFFIX}")
        fs.write_bytes(temp_path, state)
        save_path = fs.join(save_dir, f"{CM._PREFIX}{step:09}{CM._SUFFIX}")
        fs.rename_file(temp_path, save_path)
        cpt_collection.append(_CheckPoint(save_path, step))

        self.cleanup_temporary_checkpoints()
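
The two-stage save above (write under a temporary name, then rename into place) is the usual way to keep a crash from leaving a half-written checkpoint behind. As a minimal standard-library sketch of the same idea (the `fs` wrapper used above is project-specific, so this version uses `os` and `tempfile` instead):

import os
import tempfile

def atomic_write_bytes(path: str, data: bytes) -> None:
    """Writes data to path so readers never observe a partial file."""
    dir_name = os.path.dirname(path) or "."
    # Write to a temporary file in the same directory, then atomically
    # replace the destination (a rename within one filesystem is atomic).
    fd, tmp_path = tempfile.mkstemp(dir=dir_name, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        os.replace(tmp_path, path)
    except BaseException:
        os.unlink(tmp_path)
        raise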
    def _write_image(self, scene_id: str, scene_images: List[t.Tensor]):
        scene_id = scene_id.replace("/", "_")
        image = t.cat(scene_images, dim=0)
        pil_image = PIL.Image.fromarray(image.cpu().numpy())
        # Encode the PNG in memory, then write the bytes through the fs layer.
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        fn = fs.join(self.image_output_dir, f"img_{scene_id}.png")
        fs.make_dirs(fs.dirname(fn))
        fs.write_bytes(fn, buffer.getvalue())
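
For context, the encode-into-a-buffer-then-write pattern in `_write_image` needs nothing beyond PIL, NumPy, and `io`; only the `fs` layer is project-specific. A small self-contained sketch (the array contents and output path are placeholders):

import io
import numpy as np
import PIL.Image

array = np.zeros((8, 8, 3), dtype=np.uint8)  # placeholder image data
buffer = io.BytesIO()
PIL.Image.fromarray(array).save(buffer, format="PNG")
with open("/tmp/example.png", "wb") as f:  # stand-in for fs.write_bytes
    f.write(buffer.getvalue())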
Example #4
    def run_eval(self, output_dir: str, global_step: int,
                 progress_bar_desc: str):
        """Runs evaluation and returns the mean IoU on rank 0 (None elsewhere)."""
        exit_stack = contextlib.ExitStack()
        if dist_util.info().global_rank == 0:
            # Only global rank 0 shows a progress bar; register it with the
            # exit stack so it is closed when evaluation finishes.
            progress_bar = ui.ProgressBar(desc=progress_bar_desc, leave=False)
            exit_stack.push(progress_bar)
        else:
            progress_bar = None
        with exit_stack:
            dataset = self.data_manager.create_dataset(local_seed=global_step)
            loader_config = self.config.data.data_loader
            data_loader = create_distributed_loader(
                dataset=dataset, loader_config=loader_config, pad_data=False)

            # Main evaluation loop
            progress_report_fn = ui.progress_bar_report_fn(
                progress_bar, progress_multiplier=loader_config.batch_size)
            progress = ui.DistributedProgress(
                report_progress_fn=progress_report_fn)
            qualitative_results = eval_results_lib.QualitativeResults(
                self.config, dataset, output_dir)
            quantitative_results = eval_results_lib.QuantitativeResults(
                dataset.classes, self.config)
            voxel_config = self.config.data.voxelization_config
            data_resolution = dataclasses.astuple(voxel_config.resolution)

            for batch in progress(data_loader):
                batch = batched_example.batch([v.cuda() for v in batch])
                batch = voxelize_batch(batch, voxel_config)

                with t.no_grad():
                    pmf = self.inference_fn(batch.input_image,
                                            batch.camera_transform,
                                            batch.v2x_transform,
                                            batch.grid_sampling_offset,
                                            data_resolution)
                quantitative_results.add_batch(pmf, batch)
                qualitative_results.add_batch(pmf, batch)

            quantitative_results.compute_metrics()
            if dist_util.info().global_rank == 0:
                voxel_metrics_path = fs.join(output_dir, "voxel_metrics.csv")
                quantitative_results.write_csv(voxel_metrics_path)
                quantitative_results.write_tensor_board_summary(
                    self.tb_writer, global_step)

            log.debug("Writing results to disk...")
            qualitative_results.write_tensor_board_summary(
                self.tb_writer, global_step)
            log.debug("Finished evaluating")
            t.distributed.barrier()

            if dist_util.info().global_rank == 0:
                return quantitative_results.get_mean_iou()
            else:
                return None
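
The rank handling in `run_eval` follows a common distributed-evaluation pattern: every rank runs the data loop and the final barrier, but only global rank 0 owns the progress bar and writes summaries. A rough sketch of that shape with plain `torch.distributed` (the `dist_util` and `ui` helpers above are project-specific):

import torch
import torch.distributed as dist
import tqdm

def evaluate(model, loader):
    """Sketch: every rank runs the loop, only rank 0 reports and writes."""
    rank = dist.get_rank() if dist.is_initialized() else 0
    batches = tqdm.tqdm(loader, leave=False) if rank == 0 else loader
    with torch.no_grad():
        for batch in batches:
            _ = model(batch)  # accumulate metrics here
    if rank == 0:
        pass  # write CSV / TensorBoard summaries here
    if dist.is_initialized():
        dist.barrier()  # keep all ranks in sync before returning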
    @classmethod
    def _get_checkpoints(cls, cpt_dir: str) -> List[_CheckPoint]:
        """Lists checkpoints found in cpt_dir, sorted by step number."""
        result = fs.glob_pattern(
            fs.join(cpt_dir, f"{cls._PREFIX}*{cls._SUFFIX}"))
        regex = rf"^{cls._PREFIX}(\d+){cls._SUFFIX}$"
        result = [(path, re.match(regex, fs.basename(path)))
                  for path in result]
        result = [
            _CheckPoint(path, int(m.group(1))) for path, m in result if m
        ]
        result = sorted(result, key=lambda v: v.step)
        return result
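
A standalone sketch of the same checkpoint-listing logic using only the standard library; the `state.`/`.cpt` prefix and suffix here are placeholders, since the original reads them from `cls._PREFIX` and `cls._SUFFIX`:

import glob
import os
import re
from typing import List, NamedTuple

class CheckPoint(NamedTuple):
    path: str
    step: int

def list_checkpoints(cpt_dir: str, prefix: str = "state.",
                     suffix: str = ".cpt") -> List[CheckPoint]:
    """Returns checkpoints under cpt_dir sorted by their step number."""
    pattern = re.compile(rf"^{re.escape(prefix)}(\d+){re.escape(suffix)}$")
    result = []
    for path in glob.glob(os.path.join(cpt_dir, f"{prefix}*{suffix}")):
        m = pattern.match(os.path.basename(path))
        if m:
            result.append(CheckPoint(path, int(m.group(1))))
    return sorted(result, key=lambda c: c.step)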
Example #6
    def __init__(self, eval_configs: List[configuration.RecurrentEvalConfig],
                 state: state_lib.State, tb_root_dir: str, eval_root_dir: str):
        self.state = state
        self.eval_root_dir = eval_root_dir
        inference_fn = super_resolution.super_resolution_from_state(state)
        # Configs with a negative start_step are disabled.
        self.eval_runs = [
            RecurrentEvals._EvalRun(
                misc_util.StepEvent(cfg.start_step, cfg.interval), cfg,
                pipeline.EvalPipeline(cfg.config,
                                      inference_fn=inference_fn,
                                      tb_dir=fs.join(tb_root_dir,
                                                     cfg.config.name)))
            for cfg in eval_configs if cfg.start_step >= 0
        ]
Example #7
def load_from_npz(path: Text,
                  meshes_dir: Text,
                  load_extra_fields=False) -> Scene:
    """Loads an input example.

  Args:
    path: Path to NPZ with scene.
    meshes_dir: Path containing ShapeNet meshes.
    load_extra_fields: Whether to load extra fields that are not required for
      running the pipeline (e.g. texture coordinates)

  Returns:
    The loaded input example.

  """
    scene_npz = NpzReader(path)
    mesh_paths = [
        fs.join(meshes_dir, *v) + ".npz" for v in zip(
            scene_npz.list("mesh_labels"), scene_npz.list("mesh_filenames"))
    ]

    result = Scene(
        mesh_vertices=[],
        view_transform=scene_npz.tensor("view_transform", t.float32),
        o2w_transforms=scene_npz.tensor("mesh_object_to_world_transforms",
                                        t.float32),
        camera_transform=scene_npz.tensor("camera_transform", t.float32),
        mesh_labels=[v for v in scene_npz.list("mesh_labels")],
        opengl_image=_load_image(scene_npz.scalar("opengl_image")),
        pbrt_image=_load_image(scene_npz.scalar("pbrt_image")),
        mesh_visible_fractions=scene_npz.tensor("mesh_visible_fractions",
                                                t.float32),
    )

    for mesh_path in mesh_paths:
        # noinspection PyTypeChecker
        mesh_npz = NpzReader(mesh_path)
        result.mesh_vertices.append(mesh_npz.tensor("vertices", t.float32))

        if load_extra_fields:
            result.normals.append(mesh_npz.tensor("normals", t.float32))
            result.material_ids.append(mesh_npz.tensor("material_ids",
                                                       t.int32))
            result.texcoords.append(mesh_npz.tensor("texcoords", t.float32))
            result.diffuse_colors.append(
                mesh_npz.tensor("diffuse_colors", t.float32))
            result.diffuse_texture_pngs.append(
                mesh_npz.scalar("diffuse_texture_pngs"))
    return result
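
`NpzReader` is project-specific, but its `tensor`/`list`/`scalar` accessors presumably wrap `numpy.load` over an `.npz` archive. A minimal round-trip sketch with NumPy and PyTorch only (field names and shapes are placeholders):

import io
import numpy as np
import torch

# Write a small archive into an in-memory buffer.
buffer = io.BytesIO()
np.savez_compressed(buffer,
                    vertices=np.zeros((10, 3), dtype=np.float32),
                    mesh_labels=np.array(["chair", "table"]))
buffer.seek(0)

# Read it back and convert arrays to torch tensors / Python lists.
archive = np.load(buffer, allow_pickle=False)
vertices = torch.as_tensor(archive["vertices"], dtype=torch.float32)
labels = [str(v) for v in archive["mesh_labels"]]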
Example #8
def process_mesh(input_path: str, output_root: str):
    log.info(f"Processing {input_path}...")
    fn_parts = fs.splitall(input_path)
    label = fn_parts[-4]
    mesh_id = fn_parts[-3]

    mesh = read_obj(input_path)
    mesh = cleanup_mesh(mesh)

    npz_path = fs.join(output_root, label, mesh_id + ".npz")

    # Compress the NPZ into an in-memory buffer, then write the bytes out
    # through the filesystem layer.
    buffer = io.BytesIO()
    np.savez_compressed(buffer, vertices=mesh, label=label, mesh_id=mesh_id)
    fs.make_dirs(fs.dirname(npz_path))
    fs.write_bytes(npz_path, buffer.getvalue())
Example #9
    def run(self, prev_step: int, next_step: int, force=False) -> bool:
        """Runs scheduled evals and returns True if any eval has run."""
        has_run = False
        for eval_run in self.eval_runs:
            should_run = force or eval_run.ev_run_eval.trigger(
                prev_step, next_step)
            if not should_run:
                continue
            eval_pipe = eval_run.eval_pipe
            state = self.state
            state.model.eval()
            name = eval_pipe.config.name
            desc = f"Eval, name={name}, step={state.global_step}"
            output_dir = fs.join(self.eval_root_dir, name,
                                 f"{state.global_step:09}")
            iou = eval_pipe.run_eval(output_dir, state.global_step, desc)
            # run_eval returns the mean IoU only on global rank 0.
            if iou is not None:
                log.info(
                    f"Eval '{name}', step={state.global_step}, mIoU={iou:.3f}")
            has_run = True
        return has_run
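
`misc_util.StepEvent.trigger(prev_step, next_step)` is not shown in these snippets; from the way it is used, it presumably fires when a scheduled step falls inside the interval `(prev_step, next_step]`. A hedged sketch of that trigger logic (the class name and exact semantics are assumptions):

class StepEvent:
    """Fires when a step of the form start_step + k * interval is crossed."""

    def __init__(self, start_step: int, interval: int):
        self.start_step = start_step
        self.interval = interval

    def trigger(self, prev_step: int, next_step: int) -> bool:
        # True if some scheduled step lies in the range (prev_step, next_step].
        if self.interval <= 0:
            return prev_step < self.start_step <= next_step
        if next_step < self.start_step:
            return False
        last_k = (next_step - self.start_step) // self.interval
        last_scheduled = self.start_step + last_k * self.interval
        return last_scheduled > prev_step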
Example #10
def main():
    args = cmd.parse_flags(Args)

    sn_root_dir = fs.normpath(fs.abspath(args.shapenet_root))
    print("Reading mesh file names ...")
    obj_files = sorted(
        fs.glob_pattern(fs.join(sn_root_dir,
                                "*/*/models/model_normalized.obj")))

    out_dir = fs.normpath(fs.abspath(args.output_root))

    print(
        f"Converting {len(obj_files)} meshes from {sn_root_dir} to {out_dir}")

    ray.init()
    process_fn = ray.remote(process_mesh)
    tasks = [process_fn.remote(v, out_dir) for v in obj_files]

    # Poll with a short timeout so the progress bar updates as tasks finish.
    progress_bar = tqdm.tqdm(total=len(tasks))
    while tasks:
        done, tasks = ray.wait(tasks, num_returns=len(tasks), timeout=0.3)
        progress_bar.update(len(done))
    progress_bar.close()
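
The `ray.wait` loop above polls with a short timeout so the progress bar keeps updating while tasks finish in any order. The same reporting pattern, sketched with the standard library's `concurrent.futures` in place of Ray workers:

import concurrent.futures as futures
import tqdm

def square(x):
    return x * x

if __name__ == "__main__":
    with futures.ProcessPoolExecutor() as pool:
        pending = [pool.submit(square, i) for i in range(100)]
        with tqdm.tqdm(total=len(pending)) as bar:
            for _ in futures.as_completed(pending):
                bar.update(1)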
Example #11
def main():
    dist_util.init()
    t.cuda.set_device(dist_util.info().local_rank)
    ui.initialize_logging()
    pipeline.jit_compile_cpp_modules()

    args = cmd_line_flags.parse_flags(ProgramArgs)
    config, original_config = pipeline.read_cmd_line_config(
        args, configuration.TrainPipeline)

    output_dir = fs.normpath(fs.abspath(config.output_path))
    tb_root_dir = fs.join(output_dir, "tb")
    eval_root_dir = fs.join(output_dir, "evals")
    cpt_dir = fs.join(output_dir, "cpt")

    train_pipe = pipeline.TrainPipeline(config.train,
                                        cpt_dir=cpt_dir,
                                        tb_dir=fs.join(tb_root_dir, "train"))
    state = train_pipe.create_or_load_state(
        extra_metadata=original_config.to_dict())
    recurrent_evals = RecurrentEvals(config.eval, state, tb_root_dir,
                                     eval_root_dir)
    max_steps = config.train.max_steps
    train_forever = max_steps < 0
    eta = None if train_forever else misc_util.Eta(state.global_step,
                                                   max_steps)
    train_pipe.switch_model_to_train()
    ev_save_temp_cpt = misc_util.StepEvent(0, config.train.checkpoint_interval)
    ev_save_pers_cpt = misc_util.StepEvent(
        0, config.train.persistent_checkpoint_interval)

    if dist_util.info().global_rank == 0:
        train_progress = ui.ProgressBar(
            desc="Training",
            bar_format=("{l_bar}{bar}| {n_fmt}/{total_fmt} "
                        "[{elapsed}, {rate_fmt}{postfix}]"),
            total=(max_steps if not train_forever else None),
            initial=state.global_step)
        bar_context = train_progress
    else:
        train_progress = None
        bar_context = contextlib.ExitStack()

    with bar_context:
        if train_progress:
            train_progress.unpause()

        while True:
            # Perform a training step
            prev_step = state.global_step
            loss = train_pipe.train_step()
            if train_progress:
                train_progress.postfix = f"loss={loss:.3f}"
                if eta:
                    train_progress.postfix += f", ETA {eta.cur_eta_str(state.global_step)}"
                train_progress.update(state.global_step - train_progress.n)
            next_step = state.global_step

            should_stop = not train_forever and next_step > max_steps

            # Save a checkpoint
            if dist_util.info().global_rank == 0:
                save_pers_cpt = (should_stop or ev_save_pers_cpt.trigger(
                    prev_step, next_step))
                if args.recurrent_evals:
                    save_pers_cpt = (save_pers_cpt
                                     or recurrent_evals.persistent_cpt(
                                         prev_step, next_step))
                save_tmp_cpt = ev_save_temp_cpt.trigger(prev_step, next_step)

                if save_tmp_cpt or save_pers_cpt:
                    train_pipe.cpt_manager.save_state(
                        state_lib.encode_state(state),
                        step=state.global_step,
                        persistent=save_pers_cpt)

            # Run evaluations
            if args.recurrent_evals or should_stop:
                eval_has_run = recurrent_evals.run(prev_step,
                                                   next_step,
                                                   force=should_stop)
                if eval_has_run:
                    train_pipe.switch_model_to_train()
                    if train_progress:
                        train_progress.unpause()

            if should_stop:
                break