# Assumed imports for the code below; `Trainer` (a DefaultTrainer-style class with
# classmethod build_model/test helpers) and `setup` are defined elsewhere in the
# original file.
import json
import os
import os.path as osp
import time

import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine import hooks
from detectron2.evaluation import verify_results


class Tester:
    """Evaluates checkpoints and keeps track of the best bbox AP seen so far."""

    def __init__(self, cfg):
        self.cfg = cfg
        self.model = Trainer.build_model(cfg)
        self.check_pointer = DetectionCheckpointer(self.model,
                                                   save_dir=cfg.OUTPUT_DIR)

        self.best_res = None
        self.best_file = None
        self.all_res = {}

    def test(self, ckpt):
        # Load the checkpoint's model weights, then run the standard evaluation.
        self.check_pointer._load_model(self.check_pointer._load_file(ckpt))
        print('evaluating checkpoint {}'.format(ckpt))
        res = Trainer.test(self.cfg, self.model)

        if comm.is_main_process():
            verify_results(self.cfg, res)
            print(res)
            if (self.best_res is None
                    or self.best_res['bbox']['AP'] < res['bbox']['AP']):
                self.best_res = res
                self.best_file = ckpt
            print('best results from checkpoint {}'.format(self.best_file))
            print(self.best_res)
            self.all_res["best_file"] = self.best_file
            self.all_res["best_res"] = self.best_res
            self.all_res[ckpt] = res
            os.makedirs(os.path.join(self.cfg.OUTPUT_DIR, 'inference'),
                        exist_ok=True)
            with open(
                    os.path.join(self.cfg.OUTPUT_DIR, 'inference',
                                 'all_res.json'), 'w') as fp:
                json.dump(self.all_res, fp)
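

# Hedged usage sketch (not from the source): drive Tester over every checkpoint
# already saved in cfg.OUTPUT_DIR, assuming detectron2's "model_*.pth" naming.
def evaluate_all_checkpoints(cfg):
    tester = Tester(cfg)
    for fname in sorted(os.listdir(cfg.OUTPUT_DIR)):
        if fname.startswith("model_") and fname.endswith(".pth"):
            tester.test(os.path.join(cfg.OUTPUT_DIR, fname))
    return tester.best_file, tester.best_res

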
def main(args):
    cfg = setup(args)

    # eval_only and eval_during_train are mainly used for jointly
    # training detection and self-supervised models.
    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)
        if comm.is_main_process():
            verify_results(cfg, res)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        return res
    elif args.eval_during_train:
        model = Trainer.build_model(cfg)
        check_pointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
        saved_checkpoint = None
        best_res = {}
        best_file = None
        # Poll the output directory and evaluate each new checkpoint as it appears.
        while True:
            if check_pointer.has_checkpoint():
                current_ckpt = check_pointer.get_checkpoint_file()
                if current_ckpt != saved_checkpoint:
                    check_pointer._load_model(
                        check_pointer._load_file(current_ckpt))
                    saved_checkpoint = current_ckpt
                    print("evaluating checkpoint {}".format(current_ckpt))
                    # Parse the training iteration from the checkpoint name,
                    # e.g. "model_0009999.pth" -> 9999.
                    iters = int(
                        osp.splitext(
                            osp.basename(current_ckpt))[0].split("_")[-1])
                    res = Trainer.test(cfg, model)
                    if comm.is_main_process():
                        verify_results(cfg, res)
                    if cfg.TEST.AUG.ENABLED:
                        res.update(Trainer.test_with_TTA(cfg, model))
                    print(res)
                    if (not best_res
                            or best_res["bbox"]["AP"] < res["bbox"]["AP"]):
                        best_res = res
                        best_file = current_ckpt
                    print("best so far is from {}".format(best_file))
                    print(best_res)
                    # Stop once the checkpoint for the final iteration is done.
                    if iters + 1 >= cfg.SOLVER.MAX_ITER:
                        return best_res
            # Wait for the training job to write the next checkpoint.
            time.sleep(10)
    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop or subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        # EvalHook with eval_period=0 runs the TTA evaluation once, after the
        # final training iteration.
        trainer.register_hooks([
            hooks.EvalHook(0,
                           lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
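

# A hedged sketch (not from the source) of how the `Trainer` used above might be
# defined: detectron2's DefaultTrainer with an evaluator and a TTA helper. The
# actual Trainer in the original repository may differ, and the COCOEvaluator
# constructor signature varies across detectron2 releases.
from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator
from detectron2.modeling import GeneralizedRCNNWithTTA


class Trainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        # COCO-style metrics (including the bbox AP read by Tester) are written here.
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        return COCOEvaluator(dataset_name, output_dir=output_folder)

    @classmethod
    def test_with_TTA(cls, cfg, model):
        # Wrap the model with test-time augmentation and reuse the normal test loop.
        model = GeneralizedRCNNWithTTA(cfg, model)
        res = cls.test(cfg, model)
        return {k + "_TTA": v for k, v in res.items()}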