Example no. 1
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        Use the custom checkpointer, which loads other backbone models
        with matching heuristics.
        """
        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        # Load GAN model
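        # NOTE: esrgan_model, device, and hr_shape are assumed to be
        # imported/defined elsewhere in this file.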
        generator = esrgan_model.GeneratorRRDB(channels=3,
                                               filters=64,
                                               num_res_blocks=23).to(device)
        discriminator = esrgan_model.Discriminator(
            input_shape=(3, *hr_shape)).to(device)
        feature_extractor = esrgan_model.FeatureExtractor().to(device)
        feature_extractor.eval()

        # GAN losses
        criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
        criterion_content = torch.nn.L1Loss().to(device)
        criterion_pixel = torch.nn.L1Loss().to(device)

        # GAN optimizers
        optimizer_G = torch.optim.Adam(generator.parameters(),
                                       lr=0.0002,
                                       betas=(0.9, 0.999))
        optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                       lr=0.0002,
                                       betas=(0.9, 0.999))

        # For training, wrap with DDP; this is not needed for inference.
        if comm.get_world_size() > 1:
            model = DistributedDataParallel(model,
                                            device_ids=[comm.get_local_rank()],
                                            broadcast_buffers=False)
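        # NOTE: the parent __init__ is assumed to have been extended to accept
        # the extra GAN components passed below.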
        super(DefaultTrainer,
              self).__init__(model, data_loader, optimizer, discriminator,
                             generator, feature_extractor, optimizer_G,
                             optimizer_D, criterion_pixel, criterion_content,
                             criterion_GAN)

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = AdetCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())
Example no. 2
def main(args):
    cfg = setup(args)

    if args.eval_only:
        model = Trainer.build_model(cfg)
        AdetCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)  # d2 defaults.py
        if comm.is_main_process():
            verify_results(cfg, res)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop or subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    return trainer.train()
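For context, a main like this is typically driven by detectron2's multi-GPU launcher. A minimal sketch, assuming the standard default_argument_parser / launch entry point (not shown in the example):

from detectron2.engine import default_argument_parser, launch

if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    # Spawn one process per GPU and run main(args) in each of them.
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )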
Example no. 3
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        Use the custom checkpointer, which loads other backbone models
        with matching heuristics.
        """
        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        # For training, wrap with DDP; this is not needed for inference.
        if comm.get_world_size() > 1:
            model = DistributedDataParallel(model,
                                            device_ids=[comm.get_local_rank()],
                                            broadcast_buffers=False)
        super(DefaultTrainer, self).__init__(model, data_loader, optimizer)

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = AdetCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())
Example no. 4
    def resume_or_load(self, resume=True):
        if not isinstance(self.checkpointer, AdetCheckpointer):
            # support loading a few other backbones
            self.checkpointer = AdetCheckpointer(
                self.model,
                self.cfg.OUTPUT_DIR,
                optimizer=self.optimizer,
                scheduler=self.scheduler,
            )
        super().resume_or_load(resume=resume)
Example no. 5
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        Use the custom checkpointer, which loads other backbone models
        with matching heuristics.
        """
        # Assume these objects must be constructed in this order.
        dprint("build model")
        model = self.build_model(cfg)
        dprint("build optimizer")
        optimizer = self.build_optimizer(cfg, model)
        dprint("build train loader")
        data_loader = self.build_train_loader(cfg)

        images_per_batch = cfg.SOLVER.IMS_PER_BATCH
        if isinstance(data_loader, AspectRatioGroupedDataset):
            dataset_len = len(data_loader.dataset.dataset)
        else:
            dataset_len = len(data_loader.dataset)
        iters_per_epoch = dataset_len // images_per_batch

        self.iters_per_epoch = iters_per_epoch
        total_iters = cfg.SOLVER.TOTAL_EPOCHS * iters_per_epoch
        dprint("images_per_batch: ", images_per_batch)
        dprint("dataset length: ", dataset_len)
        dprint("iters per epoch: ", iters_per_epoch)
        dprint("total iters: ", total_iters)

        # For training, wrap with DDP; this is not needed for inference.
        if comm.get_world_size() > 1:
            model = DistributedDataParallel(model,
                                            device_ids=[comm.get_local_rank()],
                                            broadcast_buffers=False)
        super(DefaultTrainer, self).__init__(model, data_loader, optimizer)

        self.scheduler = self.build_lr_scheduler(cfg,
                                                 optimizer,
                                                 total_iters=total_iters)
        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = AdetCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        self.start_iter = 0
        self.max_iter = total_iters  # NOTE: ignore cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())
Example no. 6
def main(args):
    cfg = setup(args)
    model = Trainer.build_model(cfg)
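    # parsingNet and eval_lane are assumed to be imported from the
    # Ultra-Fast-Lane-Detection codebase.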
    net = parsingNet(pretrained=False,
                     backbone=model,
                     cls_dim=(200 + 1, 18, 4),
                     use_aux=False).cuda()
    AdetCheckpointer(net, save_dir=cfg.OUTPUT_DIR).resume_or_load(
        cfg.MODEL.WEIGHTS, resume=args.resume)

    eval_lane(net, 'culane', '/home/ghr/CULANEROOT',
              '/home/ghr/CULANEROOT/own_test_result', 200, False, False)
Example no. 7
def main(args):
    cfg = setup(args)

    from detectron2.data.datasets import register_coco_instances

    register_coco_instances("surgery_train2", {},
                            "data/coco/annotations/instances_train2017.json",
                            "data/coco/train2017")

    MetadataCatalog.get("surgery_train2").thing_classes = [
        'Cerebellum', 'CN8', 'CN5', 'CN7', 'SCA', 'AICA',
        'SuperiorPetrosalVein', 'Vein', 'Brainstem', 'Suction', 'Bipolar',
        'Forcep', 'BluntProbe', 'Drill', 'Kerrison', 'Cottonoid', 'Scissors',
        'Unknown'
    ]

    DatasetCatalog.get("surgery_train2")

    register_coco_instances("surgery_val2", {},
                            "data/coco/annotations/instances_train2017.json",
                            "data/coco/train2017")

    MetadataCatalog.get("surgery_val2").thing_classes = [
        'Cerebellum', 'CN8', 'CN5', 'CN7', 'SCA', 'AICA',
        'SuperiorPetrosalVein', 'Vein', 'Brainstem', 'Suction', 'Bipolar',
        'Forcep', 'BluntProbe', 'Drill', 'Kerrison', 'Cottonoid', 'Scissors',
        'Unknown'
    ]

    DatasetCatalog.get("surgery_val2")
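    # NOTE: "surgery_val2" above is registered with the same annotation/image
    # paths as the training split.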

    if args.eval_only:
        model = Trainer.build_model(cfg)
        AdetCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)  # d2 defaults.py
        if comm.is_main_process():
            verify_results(cfg, res)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        return res
    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop or subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0,
                           lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
Example no. 8
    def resume_or_load(self, resume=True):
        if not isinstance(self.checkpointer, AdetCheckpointer):
            # support loading a few other backbones
            self.checkpointer = AdetCheckpointer(
                self.model,
                self.cfg.OUTPUT_DIR,
                optimizer=self.optimizer,
                scheduler=self.scheduler,
            )
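        # NOTE: adet.utils.file_io is assumed to be imported at module level.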

        self.checkpointer.path_manager.register_handler(
            adet.utils.file_io.Detectron2Handler())

        super().resume_or_load(resume=resume)
Example no. 9
    def build_hooks(self):
        """
        Replace `DetectionCheckpointer` with `AdetCheckpointer`.

        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.
        """
        ret = super().build_hooks()
        for i in range(len(ret)):
            if isinstance(ret[i], hooks.PeriodicCheckpointer):
                self.checkpointer = AdetCheckpointer(
                    self.model,
                    self.cfg.OUTPUT_DIR,
                    optimizer=self.optimizer,
                    scheduler=self.scheduler,
                )
                ret[i] = hooks.PeriodicCheckpointer(
                    self.checkpointer, self.cfg.SOLVER.CHECKPOINT_PERIOD)
        return ret
Example no. 10
def do_train(cfg, model, resume=False):
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    # checkpointer = DetectionCheckpointer(
    #     model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler
    # )
    checkpointer = AdetCheckpointer(model,
                                    cfg.OUTPUT_DIR,
                                    optimizer=optimizer,
                                    scheduler=scheduler)
    start_iter = (checkpointer.resume_or_load(
        cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1)
    max_iter = cfg.SOLVER.MAX_ITER

    periodic_checkpointer = PeriodicCheckpointer(checkpointer,
                                                 cfg.SOLVER.CHECKPOINT_PERIOD,
                                                 max_iter=max_iter)

    writers = ([
        CommonMetricPrinter(max_iter),
        JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
        TensorboardXWriter(cfg.OUTPUT_DIR),
    ] if comm.is_main_process() else [])

    # compared to "train_net.py", we do not support accurate timing and
    # precise BN here, because they are not trivial to implement in a small training loop
    data_loader = build_detection_train_loader(cfg)
    logger.info("Starting training from iteration {}".format(start_iter))
    with EventStorage(start_iter) as storage:
        for data, iteration in zip(data_loader, range(start_iter, max_iter)):
            iteration = iteration + 1
            storage.step()

            loss_dict = model(data)
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {
                k: v.item()
                for k, v in comm.reduce_dict(loss_dict).items()
            }
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced,
                                    **loss_dict_reduced)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            storage.put_scalar("lr",
                               optimizer.param_groups[0]["lr"],
                               smoothing_hint=False)
            scheduler.step()

            if (cfg.TEST.EVAL_PERIOD > 0
                    and iteration % cfg.TEST.EVAL_PERIOD == 0
                    and iteration != max_iter):
                do_test(cfg, model)
                # Compared to "train_net.py", the test results are not dumped to EventStorage
                comm.synchronize()

            if iteration - start_iter > 5 and (iteration % 20 == 0
                                               or iteration == max_iter):
                for writer in writers:
                    writer.write()
            periodic_checkpointer.step(iteration)
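A minimal driver sketch for the training loop above, assuming a plain_train_net.py-style script in which setup and do_test are defined alongside do_train; build_model is detectron2's standard model factory:

from detectron2.modeling import build_model

def main(args):
    cfg = setup(args)
    model = build_model(cfg)
    if args.eval_only:
        # Evaluation only: load the weights and run the test loop.
        AdetCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        return do_test(cfg, model)
    do_train(cfg, model, resume=args.resume)
    return do_test(cfg, model)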