Example #1
    def run_eval(self, output_dir: str, global_step: int,
                 progress_bar_desc: str):
        exit_stack = contextlib.ExitStack()
        if dist_util.info().global_rank == 0:
            progress_bar = ui.ProgressBar(desc=progress_bar_desc, leave=False)
            exit_stack.push(progress_bar)
        else:
            progress_bar = None
        with exit_stack:
            dataset = self.data_manager.create_dataset(local_seed=global_step)
            loader_config = self.config.data.data_loader
            data_loader = create_distributed_loader(
                dataset=dataset, loader_config=loader_config, pad_data=False)

            # Main evaluation loop
            progress_report_fn = ui.progress_bar_report_fn(
                progress_bar, progress_multiplier=loader_config.batch_size)
            progress = ui.DistributedProgress(
                report_progress_fn=progress_report_fn)
            qualitative_results = eval_results_lib.QualitativeResults(
                self.config, dataset, output_dir)
            quantitative_results = eval_results_lib.QuantitativeResults(
                dataset.classes, self.config)
            voxel_config = self.config.data.voxelization_config
            data_resolution = dataclasses.astuple(voxel_config.resolution)

            for batch in progress(data_loader):
                batch = batched_example.batch([v.cuda() for v in batch])
                batch = voxelize_batch(batch, voxel_config)

                with t.no_grad():
                    pmf = self.inference_fn(batch.input_image,
                                            batch.camera_transform,
                                            batch.v2x_transform,
                                            batch.grid_sampling_offset,
                                            data_resolution)
                quantitative_results.add_batch(pmf, batch)
                qualitative_results.add_batch(pmf, batch)

            quantitative_results.compute_metrics()
            if dist_util.info().global_rank == 0:
                voxel_metrics_path = fs.join(output_dir, "voxel_metrics.csv")
                quantitative_results.write_csv(voxel_metrics_path)
                quantitative_results.write_tensor_board_summary(
                    self.tb_writer, global_step)

            log.debug("Writing results to disk...")
            qualitative_results.write_tensor_board_summary(
                self.tb_writer, global_step)
            log.debug("Finished evaluating")
            t.distributed.barrier()

            if dist_util.info().global_rank == 0:
                return quantitative_results.get_mean_iou()
            else:
                return None
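
The rank-0-only progress handling above reduces to a small, generic pattern: rank 0 owns the UI object, every other rank enters an empty ExitStack. A minimal sketch of that pattern, assuming an initialized torch.distributed process group and using tqdm as a stand-in for the project's ui.ProgressBar:

import contextlib

import torch.distributed as dist
from tqdm import tqdm  # stand-in for ui.ProgressBar


def evaluate_with_rank0_progress(work_items, process_fn):
    # Only rank 0 creates a progress bar; on all other ranks the stack is
    # empty, so the `with` block is a no-op for them.
    exit_stack = contextlib.ExitStack()
    progress_bar = None
    if dist.get_rank() == 0:
        progress_bar = exit_stack.enter_context(
            tqdm(total=len(work_items), leave=False))
    with exit_stack:
        for item in work_items:
            process_fn(item)
            if progress_bar is not None:
                progress_bar.update(1)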
Example #2
def jit_compile_cpp_modules():
    if dist_util.info().local_rank == 0:
        log.info("JIT compiling C++ modules... "
                 "(this might take a few minutes on the first run)")
        fill_voxels.get_module()
        log.info("Done JIT compiling")
    t.distributed.barrier(dist_util.local_nccl_group())
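
A stripped-down sketch of the same build-once-per-node idiom, using plain torch.distributed and the LOCAL_RANK variable set by torchrun. fill_voxels.get_module and dist_util.local_nccl_group are project-specific and not reproduced; the global barrier here is a simplification of the per-node group used above:

import os

import torch.distributed as dist


def compile_extensions_once_per_node(compile_fn):
    # The first process on each node triggers the (cached) JIT build...
    if int(os.environ.get("LOCAL_RANK", "0")) == 0:
        compile_fn()
    # ...and the barrier keeps the other ranks from loading the extension
    # before the build artifacts exist.
    dist.barrier()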
Example #3
    def _process_batch(self, batch: List[dataset_lib.DatasetElement]) -> float:
        """Processes one batch and returns the loss."""
        # Keep the batch on the CPU to reduce memory fragmentation
        batch = batched_example.batch(batch)
        batch = voxelize_batch(batch, self.config.data.voxelization_config)
        v2s = batch.camera_transform @ batch.v2x_transform.inverse()

        state = self._state
        state.optimizer.zero_grad()
        logits = self.ddp_model(batch.input_image.cuda(), v2s.cuda(),
                                batch.grid_sampling_offset.cuda())
        assert batch.grid.device.type == "cuda"
        loss = self.loss_fn(batch.grid.to(t.int64), logits)
        loss.backward()
        state.optimizer.step()

        prev_step = state.global_step
        state.global_step += self.step_size

        cpu_loss = float(loss.detach().cpu().item())
        if dist_util.info().global_rank == 0:
            if self.ev_log_to_tb.trigger(prev_step, state.global_step):
                self.tb_writer.add_scalar("loss", cpu_loss, state.global_step)
                self.tb_writer.flush()
        return cpu_loss
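
Stripped of the project-specific batching and voxelization, the step above follows the standard DDP recipe. A generic sketch, assuming the model is already wrapped in DistributedDataParallel and CUDA tensors are expected:

import torch.distributed as dist


def train_step(ddp_model, optimizer, loss_fn, inputs, targets):
    optimizer.zero_grad()
    logits = ddp_model(inputs.cuda(non_blocking=True))
    loss = loss_fn(logits, targets.cuda(non_blocking=True))
    loss.backward()                      # DDP all-reduces the gradients here
    optimizer.step()
    # As above, only rank 0 would forward this value to TensorBoard.
    return float(loss.detach().cpu().item())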
Example #4
    def __init__(self, config: configuration.TrainConfig, cpt_dir: str,
                 tb_dir: str):
        self.config = config

        loss_fns = {
            configuration.TaskType.FG_BG: losses.iou_fgbg,
            configuration.TaskType.SEMANTIC: losses.xent_times_iou_agnostic
        }
        self.loss_fn = loss_fns[config.data.voxelization_config.task_type]

        dist_info = dist_util.info()
        if dist_info.global_rank == 0:
            self.tb_writer = t.utils.tensorboard.SummaryWriter(tb_dir)
            self.ev_log_to_tb = misc_util.StepEvent(
                0, config.tensorboard_log_interval)

        self.data_manager = DatasetManager(config.data)
        self.step_size = (dist_info.global_world_size *
                          self.config.data.data_loader.batch_size)
        self.cpt_dir = cpt_dir

        self.ddp_model: Optional[t.nn.parallel.DistributedDataParallel] = None
        self._step_it = None
        self._state: Optional[state_lib.State] = None
        self.cpt_manager: Optional[cpt_manager_lib.CheckpointManager] = None
Example #5
    def create_or_load_state(self, extra_metadata: Any) -> state_lib.State:
        dist_info = dist_util.info()
        if dist_info.global_rank == 0:
            self.cpt_manager = cpt_manager_lib.CheckpointManager(self.cpt_dir)
            if not self.cpt_manager.has_checkpoints():
                log.info("Initializing training from scratch")
                state = state_lib.create_initial_state(
                    self.config, len(self.data_manager.classes))
                state = dataclasses.replace(state,
                                            extra_metadata=extra_metadata)
                self.cpt_manager.save_state(state_lib.encode_state(state),
                                            step=0,
                                            persistent=True)
            cpt_reader = self.cpt_manager
        t.distributed.barrier()
        if dist_info.global_rank != 0:
            self.cpt_manager = None
            cpt_reader = cpt_manager_lib.CheckpointReader(self.cpt_dir)

        # noinspection PyUnboundLocalVariable
        raw_state = cpt_reader.read_last_checkpoint()
        self._state = state_lib.decode_state(raw_state,
                                             f"cuda:{dist_info.local_rank}")
        log.info(f"Starting training from step={self._state.global_step}")

        self.ddp_model = t.nn.parallel.DistributedDataParallel(
            self._state.model, device_ids=[dist_info.local_rank])

        return self._state
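
cpt_manager_lib handles checkpoint naming and retention; the underlying handshake is rank 0 writes, everyone waits at a barrier, then every rank loads onto its own device. A sketch of that handshake with plain torch.save / torch.load (the path and model arguments are placeholders, not the project's API):

import os

import torch
import torch.distributed as dist


def create_or_load(model, cpt_path, local_rank):
    if dist.get_rank() == 0 and not os.path.exists(cpt_path):
        torch.save(model.state_dict(), cpt_path)   # initialize from scratch
    dist.barrier()                                  # wait until the file exists
    state_dict = torch.load(cpt_path, map_location=f"cuda:{local_rank}")
    model.load_state_dict(state_dict)
    model.cuda(local_rank)
    return torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank])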
Example #6
    def compute_metrics(self):
        t.distributed.reduce(self.confusion_matrix,
                             0,
                             op=t.distributed.ReduceOp.SUM)
        if dist_util.info().global_rank == 0:
            self.voxel_metrics_df = compute_voxel_metrics(
                self.confusion_matrix, self.classes)
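
compute_voxel_metrics is project code, but the aggregation itself is a plain sum-reduce of per-rank confusion matrices onto rank 0. A sketch (with the NCCL backend the tensor must live on the GPU; with gloo a CPU tensor works):

import torch
import torch.distributed as dist


def reduce_confusion_matrix(confusion_matrix: torch.Tensor) -> torch.Tensor:
    # Every rank contributes its local counts; only rank 0 ends up holding
    # the global matrix and should derive metrics from it.
    dist.reduce(confusion_matrix, dst=0, op=dist.ReduceOp.SUM)
    return confusion_matrix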
Example #7
def create_distributed_loader(dataset: dataset_lib.CoReNetDataset,
                              loader_config: configuration.DataLoaderConfig,
                              pad_data=False) -> t.utils.data.DataLoader:
    """Creates a distributed-aware data loader for a dataset."""
    dist_info = dist_util.info()
    sampler = dist_util.DistributedSampler(
        dataset,
        global_rank=dist_info.global_rank,
        global_world_size=dist_info.global_world_size,
        pad_data=pad_data)
    ctx = (t.multiprocessing.get_context("fork")
           if loader_config.num_data_workers > 0 else None)

    # noinspection PyArgumentList
    data_loader = t.utils.data.DataLoader(
        dataset,
        batch_size=loader_config.batch_size,
        num_workers=loader_config.num_data_workers,
        pin_memory=True,
        collate_fn=lambda v: v,
        sampler=sampler,
        multiprocessing_context=ctx,
        drop_last=False,
        prefetch_factor=loader_config.prefetch_factor)

    return data_loader
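
dist_util.DistributedSampler is a project class; for the common case the stock torch.utils.data.DistributedSampler is enough. A sketch (note that, unlike the pad_data=False path above, the built-in sampler repeats a few samples so every rank sees the same number of batches, unless drop_last=True is passed):

import torch.utils.data as td


def create_loader(dataset, batch_size, num_workers):
    # The sampler picks up rank and world size from the default process group.
    sampler = td.DistributedSampler(dataset, shuffle=False)
    return td.DataLoader(dataset,
                         batch_size=batch_size,
                         num_workers=num_workers,
                         sampler=sampler,
                         pin_memory=True,
                         drop_last=False)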
Example #8
  @staticmethod
  def _create_structs():
    dist_info = dist_util.info()
    cls = DistributedProgress
    if dist_info.global_rank == 0:
      if cls._max_progress is not None:
        raise ValueError("Multiple distributed progress bars not supported!")
      world_size = dist_info.global_world_size
      cls._max_progress = t.ones(world_size, dtype=t.int64, device="cpu") * -1
      cls._current_progress = t.zeros(world_size, dtype=t.int64, device="cpu")
    t.distributed.barrier()
Example #9
def main():
    dist_util.init()
    t.cuda.set_device(dist_util.info().local_rank)
    ui.initialize_logging()
    pipeline.jit_compile_cpp_modules()
    tf_model_lib.setup_tensorflow(dist_util.info().local_rank)

    args = cmd_line_flags.parse_flags(ProgramArgs)
    config, _ = pipeline.read_cmd_line_config(
        args, configuration.TfModelEvalPipeline)

    inference_fn = tf_model_lib.super_resolution_from_tf_model(
        config.frozen_graph_path)

    eval_pipe = pipeline.EvalPipeline(config.eval_config,
                                      inference_fn=inference_fn,
                                      tb_dir=None)
    iou = eval_pipe.run_eval(config.output_path, -1, "TF Model Eval")
    if iou is not None:
        log.info(f"Evaluation complete, mIoU={iou:.3f}")
Example #10
    def __init__(self, config: configuration.EvalConfig,
                 inference_fn: InferenceFn, tb_dir: Optional[str]):
        self.config = config

        if dist_util.info().global_rank == 0 and tb_dir:
            self.tb_writer = t.utils.tensorboard.SummaryWriter(tb_dir)
        else:
            self.tb_writer = None
        self.data_manager = DatasetManager(config.data, global_seed=0x4F1A2379)

        self.inference_fn = inference_fn
Example #11
  def __init__(
      self,
      report_progress_fn: Optional[DistributedProgressReportFn] = None
  ):
    """Initializes the progress reporter.

    Args:
      report_progress_fn: Called to report progress. Arguments are
        (current_progress: int, max_progress: int, worker_status: str). If not
        specified, a default ProgressBar() is used to report progress.
    """
    if dist_util.info().global_rank == 0:
      if report_progress_fn is None:
        report_progress_fn = progress_bar_report_fn(ProgressBar())
      self.report_progress_fn = report_progress_fn
    elif report_progress_fn is not None:
      raise ValueError("Only rank 0 can specify a progress reporting function.")
Example #12
    def write_tensor_board_summary(
            self, sw: Optional[t.utils.tensorboard.SummaryWriter],
            global_step: int):
        """Writes the saved images to a tensorboard summary."""
        all_results = dist_util.gather(self.tb_results, 0)
        if dist_util.info().global_rank == 0 and sw:
            all_results = {k: v for d in all_results for k, v in d.items()}
            all_results = sorted(all_results.items(), key=lambda v: v[0])
            for rec_idx, (scene_id, scene_images) in enumerate(all_results):
                for cam_idx, image in enumerate(scene_images):
                    assert (len(image.shape) == 3 and image.shape[-1] == 3
                            and image.dtype == t.uint8)
                    image = image.permute([2, 0, 1])
                    sw.add_image(f"rec_{rec_idx}/cam_{cam_idx}", image,
                                 global_step)
        else:
            assert sw is None
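
dist_util.gather moves arbitrary Python objects to rank 0; with stock PyTorch, all_gather_object gives every rank a copy and rank 0 merges, which covers the same use. A sketch, assuming each rank holds a dict of per-scene results:

import torch.distributed as dist


def gather_results_to_rank0(local_results: dict) -> dict:
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_results)
    merged = {}
    if dist.get_rank() == 0:
        for part in gathered:
            merged.update(part)   # rank 0 flattens the per-rank dicts
    return merged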
Example #13
  def __call__(self, seq: Sequence[T]) -> Iterable[T]:
    ev_rpc_report = misc_util.TimedEvent(1)
    ev_call_report_fn = misc_util.TimedEvent(0.2)

    rank = dist_util.info().global_rank
    master_node = dist_util.get_node_name(0)
    seq_len = len(seq)
    C = DistributedProgress
    self._create_structs()
    t.distributed.rpc.rpc_sync(master_node, C._set_max_progress,
                               (rank, seq_len))
    for progress, v in enumerate(seq):
      if ev_rpc_report.trigger():
        t.distributed.rpc.rpc_sync(master_node, C._update_progress,
                                   (rank, progress))
      if rank == 0 and ev_call_report_fn.trigger():
        self._report_progress()
      yield v
    t.distributed.rpc.rpc_sync(master_node, C._update_progress, (rank, seq_len))
    async_op = t.distributed.barrier(async_op=True)
    if rank == 0:
      self._flush(async_op)
    else:
      async_op.wait()
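
The RPC calls above push per-rank counters to the process that owns rank 0. A minimal sketch of that pattern with torch.distributed.rpc, assuming rpc.init_rpc was called with worker names of the form "worker{rank}"; the real DistributedProgress keeps its counters in class-level tensors instead:

import torch.distributed.rpc as rpc

_progress = {}          # lives in the process hosting rank 0


def _update_progress(rank: int, done: int):
    _progress[rank] = done


def report_progress(my_rank: int, done: int):
    # Cheap enough to call on a timer, as with the TimedEvent gating above.
    rpc.rpc_sync("worker0", _update_progress, args=(my_rank, done))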
Example #14
def main():
    dist_util.init()
    t.cuda.set_device(dist_util.info().local_rank)
    ui.initialize_logging()
    pipeline.jit_compile_cpp_modules()

    args = cmd_line_flags.parse_flags(ProgramArgs)
    config, original_config = pipeline.read_cmd_line_config(
        args, configuration.TrainPipeline)

    output_dir = fs.normpath(fs.abspath(config.output_path))
    tb_root_dir = fs.join(output_dir, "tb")
    eval_root_dir = fs.join(output_dir, "evals")
    cpt_dir = fs.join(output_dir, "cpt")

    train_pipe = pipeline.TrainPipeline(config.train,
                                        cpt_dir=cpt_dir,
                                        tb_dir=fs.join(tb_root_dir, "train"))
    state = train_pipe.create_or_load_state(
        extra_metadata=original_config.to_dict())
    recurrent_evals = RecurrentEvals(config.eval, state, tb_root_dir,
                                     eval_root_dir)
    max_steps = config.train.max_steps
    train_forever = max_steps < 0
    eta = None if train_forever else misc_util.Eta(state.global_step,
                                                   max_steps)
    train_pipe.switch_model_to_train()
    ev_save_temp_cpt = misc_util.StepEvent(0, config.train.checkpoint_interval)
    ev_save_pers_cpt = misc_util.StepEvent(
        0, config.train.persistent_checkpoint_interval)

    if dist_util.info().global_rank == 0:
        train_progress = ui.ProgressBar(
            desc="Training",
            bar_format=("{l_bar}{bar}| {n_fmt}/{total_fmt} "
                        "[{elapsed}, {rate_fmt}{postfix}]"),
            total=(max_steps if not train_forever else None),
            initial=state.global_step)
        bar_context = train_progress
    else:
        train_progress = None
        bar_context = contextlib.ExitStack()

    with bar_context:
        if train_progress:
            train_progress.unpause()

        while True:
            # Perform a training step
            prev_step = state.global_step
            loss = train_pipe.train_step()
            if train_progress:
                train_progress.postfix = f"loss={loss:.3f}"
                if eta:
                    train_progress.postfix += (
                        f", ETA {eta.cur_eta_str(state.global_step)}")
                train_progress.update(state.global_step - train_progress.n)
            next_step = state.global_step

            should_stop = not train_forever and next_step > max_steps

            # Save a checkpoint
            if dist_util.info().global_rank == 0:
                save_pers_cpt = (should_stop or ev_save_pers_cpt.trigger(
                    prev_step, next_step))
                if args.recurrent_evals:
                    save_pers_cpt = (save_pers_cpt
                                     or recurrent_evals.persistent_cpt(
                                         prev_step, next_step))
                save_tmp_cpt = ev_save_temp_cpt.trigger(prev_step, next_step)

                if save_tmp_cpt or save_pers_cpt:
                    train_pipe.cpt_manager.save_state(
                        state_lib.encode_state(state),
                        step=state.global_step,
                        persistent=save_pers_cpt)

            # Run evaluations
            if args.recurrent_evals or should_stop:
                eval_has_run = recurrent_evals.run(prev_step,
                                                   next_step,
                                                   force=should_stop)
                if eval_has_run:
                    train_pipe.switch_model_to_train()
                    if train_progress:
                        train_progress.unpause()

            if should_stop:
                break
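
StepEvent and Eta come from the project's misc_util and are not shown here; the checkpoint and eval gating only needs an "interval boundary crossed between prev_step and next_step" test. A hypothetical sketch of such a trigger (an assumption, not the actual misc_util implementation):

class StepEvent:
    """Fires when a multiple of `interval` falls inside (prev_step, next_step]."""

    def __init__(self, start: int, interval: int):
        self.start = start
        self.interval = interval

    def trigger(self, prev_step: int, next_step: int) -> bool:
        if self.interval <= 0:
            return False
        return ((next_step - self.start) // self.interval >
                (prev_step - self.start) // self.interval)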