def __init__(self, cfg): """ Args: cfg (CfgNode): Use the custom checkpointer, which loads other backbone models with matching heuristics. """ # Assume these objects must be constructed in this order. model = self.build_model(cfg) optimizer = self.build_optimizer(cfg, model) data_loader = self.build_train_loader(cfg) # Load GAN model generator = esrgan_model.GeneratorRRDB(channels=3, filters=64, num_res_blocks=23).to(device) discriminator = esrgan_model.Discriminator( input_shape=(3, *hr_shape)).to(device) feature_extractor = esrgan_model.FeatureExtractor().to(device) feature_extractor.eval() # GAN losses criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device) criterion_content = torch.nn.L1Loss().to(device) criterion_pixel = torch.nn.L1Loss().to(device) # GAN optimizers optimizer_G = torch.optim.Adam(generator.parameters(), lr=.0002, betas=(.9, .999)) optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=.0002, betas=(.9, .999)) # For training, wrap with DDP. But don't need this for inference. if comm.get_world_size() > 1: model = DistributedDataParallel(model, device_ids=[comm.get_local_rank()], broadcast_buffers=False) super(DefaultTrainer, self).__init__(model, data_loader, optimizer, discriminator, generator, feature_extractor, optimizer_G, optimizer_D, criterion_pixel, criterion_content, criterion_GAN) self.scheduler = self.build_lr_scheduler(cfg, optimizer) # Assume no other objects need to be checkpointed. # We can later make it checkpoint the stateful hooks self.checkpointer = AdetCheckpointer( # Assume you want to save checkpoints together with logs/statistics model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=self.scheduler, ) self.start_iter = 0 self.max_iter = cfg.SOLVER.MAX_ITER self.cfg = cfg self.register_hooks(self.build_hooks())
def main(args):
    """Entry point: run evaluation only, or train (optionally with TTA eval)."""
    cfg = setup(args)

    if args.eval_only:
        # Evaluation-only path: build the model, load weights, run the test set.
        eval_model = Trainer.build_model(cfg)
        AdetCheckpointer(eval_model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        results = Trainer.test(cfg, eval_model)  # d2 defaults.py
        if comm.is_main_process():
            verify_results(cfg, results)
        if cfg.TEST.AUG.ENABLED:
            results.update(Trainer.test_with_TTA(cfg, eval_model))
        return results

    # Standard training logic; subclass the trainer (or write your own
    # loop) for anything fancier.
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    return trainer.train()
def __init__(self, cfg):
    """
    Args:
        cfg (CfgNode):

    Uses the custom checkpointer so other backbone weights can be
    loaded with matching heuristics.
    """
    # Construction order matters: model -> optimizer -> data loader.
    model = self.build_model(cfg)
    optimizer = self.build_optimizer(cfg, model)
    data_loader = self.build_train_loader(cfg)

    # Training wraps the model with DDP when launched on multiple
    # workers; single-process inference does not need the wrapper.
    if comm.get_world_size() > 1:
        model = DistributedDataParallel(
            model,
            device_ids=[comm.get_local_rank()],
            broadcast_buffers=False,
        )

    # Deliberately skip DefaultTrainer.__init__ and call its parent.
    super(DefaultTrainer, self).__init__(model, data_loader, optimizer)

    self.scheduler = self.build_lr_scheduler(cfg, optimizer)
    # Only model/optimizer/scheduler are checkpointed for now; stateful
    # hooks could be added to the checkpointer later.
    self.checkpointer = AdetCheckpointer(
        model,
        cfg.OUTPUT_DIR,
        optimizer=optimizer,
        scheduler=self.scheduler,
    )
    self.start_iter = 0
    self.max_iter = cfg.SOLVER.MAX_ITER
    self.cfg = cfg
    self.register_hooks(self.build_hooks())
def resume_or_load(self, resume=True):
    """Swap in an AdetCheckpointer (if not already set) before resuming/loading."""
    needs_custom_checkpointer = not isinstance(self.checkpointer, AdetCheckpointer)
    if needs_custom_checkpointer:
        # The custom checkpointer supports loading a few other backbones.
        self.checkpointer = AdetCheckpointer(
            self.model,
            self.cfg.OUTPUT_DIR,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
        )
    super().resume_or_load(resume=resume)
def __init__(self, cfg):
    """
    Args:
        cfg (CfgNode):

    Uses the custom checkpointer, which can load other backbone
    weights with matching heuristics.  The iteration budget is derived
    from cfg.SOLVER.TOTAL_EPOCHS rather than cfg.SOLVER.MAX_ITER.
    """
    # Construction order matters: model -> optimizer -> data loader.
    dprint("build model")
    model = self.build_model(cfg)
    dprint('build optimizer')
    optimizer = self.build_optimizer(cfg, model)
    dprint("build train loader")
    data_loader = self.build_train_loader(cfg)

    images_per_batch = cfg.SOLVER.IMS_PER_BATCH
    # Aspect-ratio-grouped loaders wrap the underlying dataset one
    # level deeper than plain loaders.
    if isinstance(data_loader, AspectRatioGroupedDataset):
        dataset_len = len(data_loader.dataset.dataset)
    else:
        dataset_len = len(data_loader.dataset)
    iters_per_epoch = dataset_len // images_per_batch
    self.iters_per_epoch = iters_per_epoch
    total_iters = cfg.SOLVER.TOTAL_EPOCHS * iters_per_epoch
    dprint("images_per_batch: ", images_per_batch)
    dprint("dataset length: ", dataset_len)
    dprint("iters per epoch: ", iters_per_epoch)
    dprint("total iters: ", total_iters)

    # Wrap with DDP for multi-worker training; not needed for inference.
    if comm.get_world_size() > 1:
        model = DistributedDataParallel(
            model,
            device_ids=[comm.get_local_rank()],
            broadcast_buffers=False,
        )

    # Deliberately skip DefaultTrainer.__init__ and call its parent.
    super(DefaultTrainer, self).__init__(model, data_loader, optimizer)

    self.scheduler = self.build_lr_scheduler(cfg, optimizer, total_iters=total_iters)
    # Only model/optimizer/scheduler are checkpointed for now; stateful
    # hooks could be added to the checkpointer later.
    self.checkpointer = AdetCheckpointer(
        model,
        cfg.OUTPUT_DIR,
        optimizer=optimizer,
        scheduler=self.scheduler,
    )
    self.start_iter = 0
    self.max_iter = total_iters  # NOTE: cfg.SOLVER.MAX_ITER is intentionally ignored
    self.cfg = cfg
    self.register_hooks(self.build_hooks())
def main(args):
    """Build the backbone, wrap it in parsingNet, load weights, and evaluate on CULane."""
    cfg = setup(args)
    backbone = Trainer.build_model(cfg)
    net = parsingNet(
        pretrained=False,
        backbone=backbone,
        cls_dim=(200 + 1, 18, 4),
        use_aux=False,
    ).cuda()
    AdetCheckpointer(net, save_dir=cfg.OUTPUT_DIR).resume_or_load(
        cfg.MODEL.WEIGHTS, resume=args.resume)
    # NOTE(review): dataset root and output paths are hard-coded — confirm
    # before reusing outside this machine.
    eval_lane(net, 'culane', '/home/ghr/CULANEROOT',
              '/home/ghr/CULANEROOT/own_test_result', 200, False, False)
def main(args):
    """Register the surgery COCO datasets, then run evaluation or training."""
    cfg = setup(args)
    from detectron2.data.datasets import register_coco_instances

    # Both splits share the same category list.
    thing_classes = [
        'Cerebellum', 'CN8', 'CN5', 'CN7', 'SCA', 'AICA',
        'SuperiorPetrosalVein', 'Vein', 'Brainstem', 'Suction', 'Bipolar',
        'Forcep', 'BluntProbe', 'Drill', 'Kerrison', 'Cottonoid',
        'Scissors', 'Unknown'
    ]
    # NOTE(review): "surgery_val2" is registered against the *train*
    # annotations and images — confirm this is intentional.
    for split_name in ("surgery_train2", "surgery_val2"):
        register_coco_instances(
            split_name, {},
            "data/coco/annotations/instances_train2017.json",
            "data/coco/train2017")
        MetadataCatalog.get(split_name).thing_classes = thing_classes
        DatasetCatalog.get(split_name)

    if args.eval_only:
        # Evaluation-only path.
        eval_model = Trainer.build_model(cfg)
        AdetCheckpointer(eval_model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        results = Trainer.test(cfg, eval_model)  # d2 defaults.py
        if comm.is_main_process():
            verify_results(cfg, results)
        if cfg.TEST.AUG.ENABLED:
            results.update(Trainer.test_with_TTA(cfg, eval_model))
        return results

    # Standard training logic; subclass the trainer (or write your own
    # loop) for anything fancier.
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
def resume_or_load(self, resume=True):
    """Resume/load via an AdetCheckpointer with the d2 path handler registered."""
    if not isinstance(self.checkpointer, AdetCheckpointer):
        # The custom checkpointer supports loading a few other backbones;
        # register the detectron2:// path handler on the fresh instance.
        self.checkpointer = AdetCheckpointer(
            self.model,
            self.cfg.OUTPUT_DIR,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
        )
        self.checkpointer.path_manager.register_handler(
            adet.utils.file_io.Detectron2Handler())
    super().resume_or_load(resume=resume)
def build_hooks(self):
    """
    Replace `DetectionCheckpointer` with `AdetCheckpointer`.

    Build a list of default hooks, including timing, evaluation,
    checkpointing, lr scheduling, precise BN, writing events.

    Returns:
        list: the default hooks, with any periodic-checkpointer hook
        rebuilt around an `AdetCheckpointer` (also stored on
        ``self.checkpointer``).
    """
    ret = super().build_hooks()
    # Idiomatic enumerate instead of range(len(...)); we need the index
    # to replace the hook in place.
    for i, hook in enumerate(ret):
        if isinstance(hook, hooks.PeriodicCheckpointer):
            # Swap in the Adet checkpointer so periodic snapshots go
            # through the custom loading/saving heuristics.
            self.checkpointer = AdetCheckpointer(
                self.model,
                self.cfg.OUTPUT_DIR,
                optimizer=self.optimizer,
                scheduler=self.scheduler,
            )
            ret[i] = hooks.PeriodicCheckpointer(
                self.checkpointer, self.cfg.SOLVER.CHECKPOINT_PERIOD)
    return ret
def do_train(cfg, model, resume=False):
    """
    Minimal training loop.

    Compared to "train_net.py", accurate timing and precise BN are not
    supported here because they are not trivial to implement in a small
    training loop.

    Args:
        cfg (CfgNode): training configuration.
        model: the model to train (mutated in place by the optimizer).
        resume (bool): resume from the last checkpoint if available.
    """
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    # AdetCheckpointer adds the extra backbone-loading heuristics on top
    # of the default DetectionCheckpointer behavior.
    checkpointer = AdetCheckpointer(
        model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler)
    last_iteration = checkpointer.resume_or_load(
        cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1)
    start_iter = last_iteration + 1
    max_iter = cfg.SOLVER.MAX_ITER
    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter)

    # Only the main process writes metrics.
    if comm.is_main_process():
        writers = [
            CommonMetricPrinter(max_iter),
            JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(cfg.OUTPUT_DIR),
        ]
    else:
        writers = []

    data_loader = build_detection_train_loader(cfg)
    logger.info("Starting training from iteration {}".format(start_iter))

    with EventStorage(start_iter) as storage:
        for data, iteration in zip(data_loader, range(start_iter, max_iter)):
            iteration += 1
            storage.step()

            loss_dict = model(data)
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            # Reduce losses across workers for logging only.
            loss_dict_reduced = {
                k: v.item() for k, v in comm.reduce_dict(loss_dict).items()
            }
            losses_reduced = sum(loss_dict_reduced.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            storage.put_scalar(
                "lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
            scheduler.step()

            run_eval = (cfg.TEST.EVAL_PERIOD > 0
                        and iteration % cfg.TEST.EVAL_PERIOD == 0
                        and iteration != max_iter)
            if run_eval:
                do_test(cfg, model)
                # Unlike "train_net.py", the test results are not dumped
                # to EventStorage.
                comm.synchronize()

            if iteration - start_iter > 5 and (
                    iteration % 20 == 0 or iteration == max_iter):
                for writer in writers:
                    writer.write()
            periodic_checkpointer.step(iteration)