Example #1
def main():
    experiment = comet_ml.Experiment()
    cfg = setup()

    for d in ["train", "val"]:
        DatasetCatalog.register("balloon_" + d,
                                lambda d=d: get_balloon_dicts("balloon/" + d))
        MetadataCatalog.get("balloon_" + d).set(thing_classes=["balloon"])
    balloon_metadata = MetadataCatalog.get("balloon_train")

    # Wrap detectron2's DefaultTrainer
    trainer = CometDefaultTrainer(cfg, experiment)
    trainer.resume_or_load(resume=False)

    # Register a hook to compute metrics using an evaluator object
    trainer.register_hooks([
        hooks.EvalHook(10,
                       lambda: trainer.evaluate_metrics(cfg, trainer.model))
    ])

    # Register a hook to compute the evaluation loss
    trainer.register_hooks([
        hooks.EvalHook(10, lambda: trainer.evaluate_loss(cfg, trainer.model))
    ])
    trainer.train()

    # Evaluate Model Predictions
    cfg.MODEL.WEIGHTS = os.path.join(
        cfg.OUTPUT_DIR, "model_final.pth")  # path to the model we just trained
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7  # set a custom testing threshold
    predictor = DefaultPredictor(cfg)
    log_predictions(predictor, get_balloon_dicts("balloon/val"), experiment)
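Example #1 assumes a get_balloon_dicts helper that it never defines. A minimal sketch in the spirit of the detectron2 balloon tutorial, assuming VIA-format annotations in a via_region_data.json per split:

import json
import os

import cv2
from detectron2.structures import BoxMode


def get_balloon_dicts(img_dir):
    # Parse VIA-format polygon annotations into detectron2's dataset dicts
    with open(os.path.join(img_dir, "via_region_data.json")) as f:
        imgs_anns = json.load(f)

    dataset_dicts = []
    for idx, v in enumerate(imgs_anns.values()):
        filename = os.path.join(img_dir, v["filename"])
        height, width = cv2.imread(filename).shape[:2]
        record = {"file_name": filename, "image_id": idx,
                  "height": height, "width": width}

        objs = []
        for anno in v["regions"].values():
            shape = anno["shape_attributes"]
            px, py = shape["all_points_x"], shape["all_points_y"]
            # Flatten polygon vertices into [x0, y0, x1, y1, ...]
            poly = [float(c) for xy in zip(px, py) for c in xy]
            objs.append({
                "bbox": [min(px), min(py), max(px), max(py)],
                "bbox_mode": BoxMode.XYXY_ABS,
                "segmentation": [poly],
                "category_id": 0,  # the single "balloon" class
            })
        record["annotations"] = objs
        dataset_dicts.append(record)
    return dataset_dicts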
Example #2
def main(args):
    cfg = setup(args)

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
    
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)

        # Save the evaluation results
        pd.DataFrame(res).to_csv(f'{cfg.OUTPUT_DIR}/eval.csv')
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.eval_and_save(cfg, trainer.model))]
    )
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    return trainer.train()
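Nearly every example below calls a setup(args) helper without showing it. A minimal sketch of the version detectron2's stock tool scripts use (per-project variants add their own config keys):

from detectron2.config import get_cfg
from detectron2.engine import default_setup


def setup(args):
    # Build the config from file and CLI overrides, then run the standard
    # setup: logging, random seed, output directory, config dump.
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg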
Example #3
    def build_hooks(self):
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]

        # Do PreciseBN before checkpointer, because it updates the model and
        # needs to be saved by the checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(
                hooks.PeriodicCheckpointer(
                    self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD
                )
            )

        def test_and_save_results_student():
            self._last_eval_results_student = self.test(self.cfg, self.model)
            _last_eval_results_student = {
                k + "_student": self._last_eval_results_student[k]
                for k in self._last_eval_results_student.keys()
            }
            return _last_eval_results_student

        def test_and_save_results_teacher():
            self._last_eval_results_teacher = self.test(
                self.cfg, self.model_teacher)
            return self._last_eval_results_teacher

        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD,
                   test_and_save_results_student))
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD,
                   test_and_save_results_teacher))

        if comm.is_main_process():
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
        return ret
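Examples #3 and #25 hand self.build_writers() to PeriodicWriter. DefaultTrainer's stock implementation returns the three standard writers; a sketch of an equivalent override (class name illustrative):

import os

from detectron2.engine import DefaultTrainer
from detectron2.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter


class MyTrainer(DefaultTrainer):
    def build_writers(self):
        # Terminal printer, metrics.json, and TensorBoard event files
        return [
            CommonMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]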
Example #4
def main(args):
    setup_logger_global_cfg_global_textlogger(args, tl_textdir=args.tl_textdir)
    cfg = setup(args)
    cfg = D2Utils.cfg_merge_from_easydict(cfg, global_cfg)
    if comm.is_main_process():
        path = os.path.join(cfg.OUTPUT_DIR, "config.yaml")
        with open(path, "w") as f:
            f.write(cfg.dump())

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res
    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0,
                           lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
Example #5
def main(args):
    cfg = setup(args)  # restored: the training path below needs cfg; the eval path overrides it

    if args.eval_only:
        # TODO(jen)
        #model = Trainer.build_model(cfg)
        from models.rpn import ProposalNetwork
        rpn = ProposalNetwork('cuda')
        cfg = rpn.cfg
        model = rpn.predictor.model
        cfg.DATASETS.TRAIN = ('lvis_v0.5_train', )
        cfg.DATASETS.TEST = ('lvis_v0.5_val', )
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res
    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop or subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0,
                           lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
Example #6
def main(args):
    print("enter main")
    cfg = setup(args)
    print("setup cfg ")
    register_lofar_datasets(cfg)
    print("register lofar datasets")
    print(f"output dir is {cfg.OUTPUT_DIR} datasets")

    if args.eval_only:
        model = LOFARTrainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = LOFARTrainer.test(cfg, model)
        if comm.is_main_process():
            verify_results(cfg, res)
        if cfg.TEST.AUG.ENABLED:
            res.update(LOFARTrainer.test_with_TTA(cfg, model))
        return res
    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop or subclassing the trainer.
    """
    trainer = LOFARTrainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    print("set up trainer")
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0,
                           lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
Example #7
def main(args):
    cfg = setup(args)

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res
    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    # convert_bn2affine_model is a project-specific port (not stock detectron2)
    from detectron2.lib.utils.net import convert_bn2affine_model
    trainer.model = convert_bn2affine_model(trainer.model)
    trainer.model.to(cfg.MODEL.DEVICE)  # DefaultTrainer has no .cuda(); move the model instead
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0,
                           lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
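convert_bn2affine_model above comes from the older Detectron codebase. Stock detectron2 ships a comparable utility that freezes BatchNorm statistics into fixed affine transforms; a short sketch:

from detectron2.layers import FrozenBatchNorm2d

# Recursively replace BatchNorm/SyncBatchNorm modules with FrozenBatchNorm2d,
# turning running statistics into constants.
trainer.model = FrozenBatchNorm2d.convert_frozen_batchnorm(trainer.model)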
Example #8
def main(args):
    cfg = setup(args)

    if args.test_img:
        results = test_on_img(cfg, args.test_img)
        print(results)
        return results

    if args.eval_only:
        model = Trainer.build_model(cfg)
        AdetCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        evaluators = [
            Trainer.build_evaluator(cfg, name)
            for name in cfg.DATASETS.TEST
        ]
        res = Trainer.test(cfg, model, evaluators)
        if comm.is_main_process():
            verify_results(cfg, res)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop or subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    return trainer.train()
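Example #8 builds one evaluator per entry in cfg.DATASETS.TEST through Trainer.build_evaluator, which it never shows. A hedged sketch of such a classmethod using detectron2's COCOEvaluator (the real AdelaiDet trainer dispatches on the dataset's evaluator type):

import os

from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator


class Trainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
        return COCOEvaluator(dataset_name, output_dir=output_folder)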
Example #9
def main(args):
    cfg = setup(args)

    # TODO REGISTRATION OF DATASET

    register_isprs_train_instance(cfg)
    register_isprs_val_instance(cfg)
    register_isprs_test_instance(cfg)
    register_isprs_train_panoptic(cfg)
    register_isprs_val_panoptic(cfg)
    register_isprs_test_panoptic(cfg)

    if args.eval_only:
        model = ISPRSTrainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = ISPRSTrainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(ISPRSTrainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    trainer = ISPRSTrainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0,
                           lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
Example #10
def main(args):
    cfg = setup(args)

    # CSD: set up wandb (for tracking visualizations)
    if comm.is_main_process() and cfg.USE_WANDB:
        wandb.login()
        wandb.init(project=cfg.WANDB_PROJECT_NAME,
                   config=cfg,
                   sync_tensorboard=True)

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res
    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0,
                           lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
Example #11
def main(args):
    cfg = setup(args)

    # Load cfg as python dict
    config = load_yaml(args.config_file)

    # If evaluation
    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)

        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)

    # If training
    else:
        trainer = Trainer(cfg)
        # Load model weights (if specified)
        trainer.resume_or_load(resume=args.resume)
        if cfg.TEST.AUG.ENABLED:
            trainer.register_hooks(
                [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
            )
        # Will evaluation be done at end of training?
        res = trainer.train()

    return res
Example #12
def main(args):
    cfg = setup(args)

    if args.eval_only:
        model = Trainer.build_model(cfg)
        AdetCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)  # d2 defaults.py
        if comm.is_main_process():
            verify_results(cfg, res)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        return res

    if args.tpu:
        import torch_xla.core.xla_model as xm
        _ = xm.xla_device()
    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop or subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0,
                           lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
Example #13
def main(args):
    cfg = setup(args)

    if args.eval_only:
        cfg.DATASETS.TEST = ("usopen_nadal_test",)  
        cfg.freeze()
        default_setup(cfg, args)

        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=True
        )
        evaluator = Trainer.build_evaluator(
            cfg,
            dataset_name=cfg.DATASETS.TEST[0],
            output_folder=os.path.join(cfg.OUTPUT_DIR, "inference", "test")
            )
        res = Trainer.test(cfg, model, evaluator)
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.
    """
    cfg.freeze()
    default_setup(cfg, args)
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    return trainer.train()
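A note on the recurring EvalHook(0, ...) pattern: the first argument is the period. EvalHook(period, fn) calls fn every `period` iterations and once more after the final iteration, so a period of 0 means "run only after the final iteration". The callable must return None or a (possibly nested) dict of scalar metrics. A sketch reusing the names from the examples above:

from detectron2.engine import hooks

# Runs exactly once, after training finishes:
tta_hook = hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))
# Runs every EVAL_PERIOD iterations and once at the end:
eval_hook = hooks.EvalHook(cfg.TEST.EVAL_PERIOD, lambda: Trainer.test(cfg, trainer.model))
trainer.register_hooks([tta_hook, eval_hook])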
Example #14
def main(args):
    print('setting up configs')
    cfg = setup(args)
    print('configs loaded')

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.
    """
    print('building model from configs')
    trainer = Trainer(cfg)
    print('model built')
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    return trainer.train()
Example #15
def main(args):
    # TODO: remove this hardcoded stuff
    root = '/home/arash/Software/datasets/VOCdevkit/'
    register_all_pascal_voc_org(root)
    cfg = setup(args)

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)
        if comm.is_main_process():
            verify_results(cfg, res)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        return res
    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop or subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0,
                           lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
Example #16
def main(args):
    # Gets and sets up config
    cfg = setup(args)

    # If eval only
    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)
        if comm.is_main_process():
            verify_results(cfg, res)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        return res
    """
	If you'd like to do anything fancier than the standard training logic,
	consider writing your own training loop or subclassing the trainer.
	"""
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0,
                           lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
Example #17
def do_train(cfg):
    model = instantiate(cfg.model)
    logger = logging.getLogger("detectron2")
    logger.info("Model:\n{}".format(model))
    model.to(cfg.train.device)

    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer)

    train_loader = instantiate(cfg.dataloader.train)

    model = create_ddp_model(model, **cfg.train.ddp)
    trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(
        model, train_loader, optim)
    checkpointer = DetectionCheckpointer(
        model,
        cfg.train.output_dir,
        optimizer=optim,
        trainer=trainer,
    )
    trainer.register_hooks([
        hooks.IterationTimer(),
        hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
        hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
        if comm.is_main_process() else None,
        hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
        hooks.PeriodicWriter(
            default_writers(cfg.train.output_dir, cfg.train.max_iter),
            period=cfg.train.log_period,
        ) if comm.is_main_process() else None,
    ])

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=True)
    start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)
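The do_test that this EvalHook calls is not shown; detectron2's lazy-config training script defines something close to this sketch:

from detectron2.config import instantiate
from detectron2.evaluation import inference_on_dataset, print_csv_format


def do_test(cfg, model):
    # Run inference with the lazily-instantiated test loader and evaluator,
    # returning the metrics dict that EvalHook will log.
    if "evaluator" in cfg.dataloader:
        ret = inference_on_dataset(
            model,
            instantiate(cfg.dataloader.test),
            instantiate(cfg.dataloader.evaluator),
        )
        print_csv_format(ret)
        return ret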
Example #18
def main(args):
    register_coco_instances("object365_train", {},
                            "data/object_365/objects365_train.json",
                            "data/object_365/train")
    register_coco_instances("object365_val", {},
                            "data/object_365/objects365_val.json",
                            "data/object_365/val")
    cfg = setup(args)

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res
    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.
    """
    trainer = Trainer(cfg,
                      linear_eval=args.linear_eval,
                      mini=args.mini,
                      mini_ratio=args.mini_ratio)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0,
                           lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
Example #19
def main(args):
    cfg = setup(args)

    from thop import profile

    if args.eval_only:
        model = Trainer.build_model(cfg)
        #input_size = (1, 3, 288, 800)
        #input_size = torch.zeros(*input_size)
        #flops, params = profile(model, inputs=(input_size, ))
        AdetCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)  # d2 defaults.py
        if comm.is_main_process():
            verify_results(cfg, res)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        return res
    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop or subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0,
                           lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
Example #20
    def test_writer_hooks(self):
        model = _SimpleModel(sleep_sec=0.1)
        trainer = SimpleTrainer(model, self._data_loader("cpu"),
                                torch.optim.SGD(model.parameters(), 0.1))

        max_iter = 50

        with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:
            json_file = os.path.join(d, "metrics.json")
            writers = [CommonMetricPrinter(max_iter), JSONWriter(json_file)]

            trainer.register_hooks([
                hooks.EvalHook(0, lambda: {"metric": 100}),
                hooks.PeriodicWriter(writers)
            ])
            with self.assertLogs(writers[0].logger) as logs:
                trainer.train(0, max_iter)

            with open(json_file, "r") as f:
                data = [json.loads(line.strip()) for line in f]
                self.assertEqual([x["iteration"] for x in data],
                                 [19, 39, 49, 50])
                # the eval metric is in the last line with iter 50
                self.assertIn("metric", data[-1],
                              "Eval metric must be in last line of JSON!")

            # test logged messages from CommonMetricPrinter
            self.assertEqual(len(logs.output), 3)
            for log, iter in zip(logs.output, [19, 39, 49]):
                self.assertIn(f"iter: {iter}", log)

            self.assertIn("eta: 0:00:00", logs.output[-1],
                          "Last ETA must be 0!")
Example #21
def main(args):
    cfg = setup(args)
    # disable strict kwargs checking: allow one to specify path handle
    # hints through kwargs, like timeout in DP evaluation
    PathManager.set_strict_kwargs_checking(False)

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    return trainer.train()
Example #22
def main(args):
    cfg = setup(args)

    # Select which trainer to use
    Trainer = MetaReweightTrainer if cfg.get("META_REWEIGHT", False) else PlainTrainer

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res
    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0,
                           lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])

    return trainer.train()
Example #23
    def test_best_checkpointer(self):
        model = _SimpleModel()
        dataloader = self._data_loader("cpu")
        opt = torch.optim.SGD(model.parameters(), 0.1)
        metric_name = "metric"
        total_iter = 40
        test_period = 10
        test_cases = [
            ("max", iter([0.3, 0.4, 0.35, 0.5]), 3),
            ("min", iter([1.0, 0.8, 0.9, 0.9]), 2),
            ("min", iter([math.nan, 0.8, 0.9, 0.9]), 1),
        ]
        for mode, metrics, call_count in test_cases:
            trainer = SimpleTrainer(model, dataloader, opt)
            with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:
                checkpointer = Checkpointer(model, d, opt=opt, trainer=trainer)
                trainer.register_hooks([
                    hooks.EvalHook(test_period,
                                   lambda: {metric_name: next(metrics)}),
                    hooks.BestCheckpointer(test_period,
                                           checkpointer,
                                           metric_name,
                                           mode=mode),
                ])
                with mock.patch.object(checkpointer,
                                       "save") as mock_save_method:
                    trainer.train(0, total_iter)
                    self.assertEqual(mock_save_method.call_count, call_count)
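Outside of unit tests, the same pairing works in a DefaultTrainer: register BestCheckpointer together with (and after) an EvalHook, so the tracked metric is already in the event storage when it checks for improvement. A minimal sketch (the metric key is an assumption and must match what your EvalHook reports):

from detectron2.engine import hooks

trainer.register_hooks([
    hooks.BestCheckpointer(
        cfg.TEST.EVAL_PERIOD,   # check as often as evaluation runs
        trainer.checkpointer,
        "bbox/AP",              # assumed metric key
        mode="max",
    ),
])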
Example #24
def main(args):
    cfg = setup(args)

    # Load cfg as python dict
    config = load_yaml(args.config_file)

    # Setup wandb
    wandb.init(
        # Use exp name to resume run later on
        # id="cascade_df_resume",
        id=args.exp_name,
        project="piplup-od",
        # name="cascade_df_resume",
        name=args.exp_name,
        sync_tensorboard=True,
        config=config,
        # Resume making use of the same exp name
        resume=args.exp_name if args.resume else False,
        # dir=cfg.OUTPUT_DIR,
    )
    # Auto upload any checkpoints to wandb as they are written
    # wandb.save(os.path.join(cfg.OUTPUT_DIR, "*.pth"))

    # TODO: Visualize and log training examples and annotations
    # training_imgs = viz_data(cfg)
    # wandb.log({"training_examples": training_imgs})

    # If evaluation
    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)

        # FIXME: TTA
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)

    # If training
    else:
        trainer = Trainer(cfg)
        # Load model weights (if specified)
        trainer.resume_or_load(resume=args.resume)
        # FIXME: TTA
        if cfg.TEST.AUG.ENABLED:
            trainer.register_hooks(
                [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
            )
        # Will evaluation be done at end of training?
        res = trainer.train()

    # TODO: Visualize and log predictions and groundtruth annotations
    pred_imgs = viz_preds(cfg)
    wandb.log({"prediction_examples": pred_imgs})

    return res
Example #25
    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.
        Returns:
            list[HookBase]:
        """
        logger = logging.getLogger(__name__)
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
        ]

        if cfg.SOLVER.SWA.ENABLED:
            ret.append(
                additional_hooks.SWA(
                    cfg.SOLVER.MAX_ITER,
                    cfg.SOLVER.SWA.PERIOD,
                    cfg.SOLVER.SWA.LR_START,
                    cfg.SOLVER.SWA.ETA_MIN_LR,
                    cfg.SOLVER.SWA.LR_SCHED,
                )
            )

        if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model):
            logger.info("Prepare precise BN dataset")
            ret.append(hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            ))
        # Do PreciseBN before checkpointer, because it updates the model and
        # needs to be saved by the checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers(), 20))

        return ret
Example #26
def do_train(args, cfg):
    """
    Args:
        cfg: an object with the following attributes:
            model: instantiate to a module
            dataloader.{train,test}: instantiate to dataloaders
            dataloader.evaluator: instantiate to evaluator for test set
            optimizer: instantiate to an optimizer
            lr_multiplier: instantiate to a fvcore scheduler
            train: other misc config defined in `configs/common/train.py`, including:
                output_dir (str)
                init_checkpoint (str)
                amp.enabled (bool)
                max_iter (int)
                eval_period, log_period (int)
                device (str)
                checkpointer (dict)
                ddp (dict)
    """
    model = instantiate(cfg.model)
    logger = logging.getLogger("detectron2")
    logger.info("Model:\n{}".format(model))
    model.to(cfg.train.device)

    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer)

    train_loader = instantiate(cfg.dataloader.train)

    model = create_ddp_model(model, **cfg.train.ddp)
    trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(
        model, train_loader, optim)
    checkpointer = DetectionCheckpointer(
        model,
        cfg.train.output_dir,
        optimizer=optim,
        trainer=trainer,
    )
    trainer.register_hooks([
        hooks.IterationTimer(),
        hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
        hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
        if comm.is_main_process() else None,
        hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
        hooks.PeriodicWriter(
            default_writers(cfg.train.output_dir, cfg.train.max_iter),
            period=cfg.train.log_period,
        ) if comm.is_main_process() else None,
    ])

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
    if args.resume and checkpointer.has_checkpoint():
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration
        start_iter = trainer.iter + 1
    else:
        start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)
Example #27
def train(cfg, args):  # `args` added: the original referenced an undefined global
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0,
                           lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
Example #28
    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.
        Returns:
            list[HookBase]:
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            ) if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]

        # Do PreciseBN before checkpointer, because it updates the model and
        # needs to be saved by the checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            if cfg.SOLVER.CHECKPOINT_BY_EPOCH:
                ckpt_period = cfg.SOLVER.CHECKPOINT_PERIOD * self.iters_per_epoch
            else:
                ckpt_period = cfg.SOLVER.CHECKPOINT_PERIOD
            ret.append(
                MyPeriodicCheckpointer(self.checkpointer,
                                       ckpt_period,
                                       max_to_keep=cfg.SOLVER.get(
                                           "NUM_CKPT_KEEP", 5),
                                       iters_per_epoch=self.iters_per_epoch))

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # run writers in the end, so that evaluation metrics are written
            ret.append(
                hooks.PeriodicWriter(self.build_writers(),
                                     period=cfg.TRAIN.get("PRINT_FREQ", 100)))
        return ret
Example #29
    def test_eval_hook(self):
        model = _SimpleModel()
        dataloader = self._data_loader("cpu")
        opt = torch.optim.SGD(model.parameters(), 0.1)

        for total_iter, period, eval_count in [(30, 15, 2), (31, 15, 3), (20, 0, 1)]:
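            # EvalHook fires every `period` iterations plus once after the final
            # iteration: 30 iters @ period 15 -> iters 15 and 30 (2 calls);
            # 31 iters -> iters 15, 30, 31 (3 calls); period 0 -> final only.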
            test_func = mock.Mock(return_value={"metric": 3.0})
            trainer = SimpleTrainer(model, dataloader, opt)
            trainer.register_hooks([hooks.EvalHook(period, test_func)])
            trainer.train(0, total_iter)
            self.assertEqual(test_func.call_count, eval_count)
Example #30
def main(args):
    cfg = setup(args)

    from detectron2.data.datasets import register_coco_instances

    register_coco_instances("surgery_train2", {},
                            "data/coco/annotations/instances_train2017.json",
                            "data/coco/train2017")

    MetadataCatalog.get("surgery_train2").thing_classes = [
        'Cerebellum', 'CN8', 'CN5', 'CN7', 'SCA', 'AICA',
        'SuperiorPetrosalVein', 'Vein', 'Brainstem', 'Suction', 'Bipolar',
        'Forcep', 'BluntProbe', 'Drill', 'Kerrison', 'Cottonoid', 'Scissors',
        'Unknown'
    ]

    DatasetCatalog.get("surgery_train2")

    register_coco_instances("surgery_val2", {},
                            "data/coco/annotations/instances_train2017.json",
                            "data/coco/train2017")

    MetadataCatalog.get("surgery_val2").thing_classes = [
        'Cerebellum', 'CN8', 'CN5', 'CN7', 'SCA', 'AICA',
        'SuperiorPetrosalVein', 'Vein', 'Brainstem', 'Suction', 'Bipolar',
        'Forcep', 'BluntProbe', 'Drill', 'Kerrison', 'Cottonoid', 'Scissors',
        'Unknown'
    ]

    DatasetCatalog.get("surgery_val2")

    if args.eval_only:
        model = Trainer.build_model(cfg)
        AdetCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)  # d2 defaults.py
        if comm.is_main_process():
            verify_results(cfg, res)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        return res
    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop or subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0,
                           lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
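A side note on Example #30: it assigns thing_classes by attribute, which works, but the documented pattern (used in Example #1) goes through .set(...), which also validates against previously-set values. A one-line sketch with a truncated class list:

from detectron2.data import MetadataCatalog

MetadataCatalog.get("surgery_train2").set(
    thing_classes=["Cerebellum", "CN8", "CN5"])  # truncated for illustration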