"roi_heads.box_head.cls_score.weight": "roi_heads.box_predictor.cls_score.weight", "roi_heads.box_head.cls_score.bias": "roi_heads.box_predictor.cls_score.bias", "roi_heads.box_head.bbox_pred.weight": "roi_heads.box_predictor.bbox_pred.weight", "roi_heads.box_head.bbox_pred.bias": "roi_heads.box_predictor.bbox_pred.bias", } temp = torch.load("weight.pt") temp = {state_dict_map.get(k, k): v for k, v in temp.items()} print("Problems with:\n" + "\n".join([k for k in net.state_dict() if k not in temp])) net.load_state_dict({k: temp.get(k, v) for k, v in net.state_dict().items()}) #net.eval() targets = Instances((512, 512)) targets.gt_boxes = Boxes(torch.load("targets.pt")["boxes"]) targets.gt_classes = torch.load("targets.pt")["classes"] data = [{"image": torch.load("data.pt").cuda(), "instances": targets}] storage_4del = EventStorage(0).__enter__() torch.random.manual_seed(0) torch.cuda.manual_seed(0) with torch.no_grad(): torch.save(net(data), "res1.pt") storage_4del.__exit__(None, None, None)
class TrainingModule(LightningModule):
    """Lightning wrapper reproducing detectron2's default training loop.

    Manages the detectron2 ``EventStorage`` manually (entered lazily on the
    first training step because Lightning may spawn worker processes after
    ``setup``), mirrors ``SimpleTrainer``'s metric writing and iteration
    timing, and runs the configured dataset evaluators during validation.
    """

    def __init__(self, cfg):
        super().__init__()
        if not logger.isEnabledFor(logging.INFO):
            # setup_logger is not called for d2
            setup_logger()
        self.cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
        # Created lazily in training_step (see comment there); None until then.
        self.storage: EventStorage = None
        self.model = build_model(self.cfg)
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Persist the d2 iteration counter alongside the Lightning state.
        checkpoint["iteration"] = self.storage.iter

    def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]) -> None:
        self.start_iter = checkpointed_state["iteration"]
        self.storage.iter = self.start_iter

    def setup(self, stage: str):
        # Fix: build the checkpointer unconditionally — training_epoch_end
        # saves through it even when no pretrained weights were configured
        # (the original only created it inside the WEIGHTS guard, leaving
        # self.checkpointer undefined when training from scratch).
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            self.model,
            self.cfg.OUTPUT_DIR,
        )
        if self.cfg.MODEL.WEIGHTS:
            logger.info(f"Load model weights from checkpoint: {self.cfg.MODEL.WEIGHTS}.")
            # Only load weights, use lightning checkpointing if you want to resume
            self.checkpointer.load(self.cfg.MODEL.WEIGHTS)

        self.iteration_timer = hooks.IterationTimer()
        self.iteration_timer.before_train()
        self.data_start = time.perf_counter()
        self.writers = None

    def training_step(self, batch, batch_idx):
        data_time = time.perf_counter() - self.data_start
        # Need to manually enter/exit since trainer may launch processes
        # This ideally belongs in setup, but setup seems to run before processes are spawned
        if self.storage is None:
            self.storage = EventStorage(0)
            self.storage.__enter__()
            self.iteration_timer.trainer = weakref.proxy(self)
            self.iteration_timer.before_step()
            # Fix: use [] (not {}) as the empty fallback so self.writers is
            # consistently a list of writers on every rank.
            self.writers = (
                default_writers(self.cfg.OUTPUT_DIR, self.max_iter)
                if comm.is_main_process()
                else []
            )

        loss_dict = self.model(batch)
        SimpleTrainer.write_metrics(loss_dict, data_time)

        opt = self.optimizers()
        self.storage.put_scalar(
            "lr", opt.param_groups[self._best_param_group_id]["lr"], smoothing_hint=False
        )
        self.iteration_timer.after_step()
        self.storage.step()
        # A little odd to put before step here, but it's the best way to get a proper timing
        self.iteration_timer.before_step()

        if self.storage.iter % 20 == 0:
            for writer in self.writers:
                writer.write()
        return sum(loss_dict.values())

    def training_step_end(self, training_step_outpus):
        # Restart the data-loading timer so the next step measures only I/O wait.
        self.data_start = time.perf_counter()
        return training_step_outpus

    def training_epoch_end(self, training_step_outputs):
        self.iteration_timer.after_train()
        if comm.is_main_process():
            self.checkpointer.save("model_final")
        for writer in self.writers:
            writer.write()
            writer.close()
        self.storage.__exit__(None, None, None)

    def _process_dataset_evaluation_results(self) -> OrderedDict:
        # Collect per-dataset metrics from the evaluators populated in
        # validation_step; single-dataset runs return the bare result dict.
        results = OrderedDict()
        for idx, dataset_name in enumerate(self.cfg.DATASETS.TEST):
            results[dataset_name] = self._evaluators[idx].evaluate()
            if comm.is_main_process():
                print_csv_format(results[dataset_name])

        if len(results) == 1:
            results = list(results.values())[0]
        return results

    def _reset_dataset_evaluators(self):
        # Fresh evaluator per test dataset, reset before accumulating.
        self._evaluators = []
        for dataset_name in self.cfg.DATASETS.TEST:
            evaluator = build_evaluator(self.cfg, dataset_name)
            evaluator.reset()
            self._evaluators.append(evaluator)

    def on_validation_epoch_start(self, _outputs=None):
        # Fix: Lightning invokes this hook with no arguments, so a required
        # `_outputs` parameter raised TypeError; kept with a default for
        # backward compatibility with any direct callers.
        self._reset_dataset_evaluators()

    def validation_epoch_end(self, _outputs):
        # Fix: _process_dataset_evaluation_results takes no arguments; the
        # original passed `_outputs`, which raised TypeError every epoch.
        results = self._process_dataset_evaluation_results()
        flattened_results = flatten_results_dict(results)
        for k, v in flattened_results.items():
            try:
                v = float(v)
            except Exception as e:
                raise ValueError(
                    "[EvalHook] eval_function should return a nested dict of float. "
                    "Got '{}: {}' instead.".format(k, v)
                ) from e
        self.storage.put_scalars(**flattened_results, smoothing_hint=False)

    def validation_step(self, batch, batch_idx: int, dataloader_idx: int = 0) -> None:
        # d2 models expect a list of per-image dicts; wrap single samples.
        if not isinstance(batch, List):
            batch = [batch]
        outputs = self.model(batch)
        self._evaluators[dataloader_idx].process(batch, outputs)

    def configure_optimizers(self):
        optimizer = build_optimizer(self.cfg, self.model)
        self._best_param_group_id = hooks.LRScheduler.get_best_param_group_id(optimizer)
        scheduler = build_lr_scheduler(self.cfg, optimizer)
        # Step the LR scheduler every optimizer step, matching d2 semantics.
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]