Exemple #1
0
 def load_G(self):
     """Load the trained supernet checkpoint into the model from ``get_saved_model``.

     The checkpoint path is built by ``self._get_ckpt_path`` from
     ``self.trained_supernet_ckpt_dir``, ``self.ckpt_epoch`` and
     ``self.supernet_iter_every_epoch``. The checkpointer is created with
     ``save_to_disk=False`` and called with ``resume=False``. Under fvcore's
     ``Checkpointer`` API, that loads the computed path rather than any
     recorded "last checkpoint".
     """
     ckpt_path = self._get_ckpt_path(
         ckpt_dir=self.trained_supernet_ckpt_dir,
         ckpt_epoch=self.ckpt_epoch,
         iter_every_epoch=self.supernet_iter_every_epoch)
     checkpointer = Checkpointer(self.get_saved_model(), save_to_disk=False)
     # The dict returned by resume_or_load (extra checkpoint entries) was
     # never used, so it is no longer bound to a dead local.
     checkpointer.resume_or_load(ckpt_path, resume=False)
    def load_model_weights(self,
                           ckpt_dir,
                           ckpt_epoch,
                           ckpt_iter_every_epoch,
                           ckpt_path=None):
        """Load checkpoint weights into the model returned by ``get_saved_model``.

        Args:
            ckpt_dir: directory handed to ``_get_ckpt_path`` when
                ``ckpt_path`` is not supplied.
            ckpt_epoch: epoch handed to ``_get_ckpt_path``.
            ckpt_iter_every_epoch: iterations-per-epoch handed to
                ``_get_ckpt_path``.
            ckpt_path: explicit checkpoint file. When given, the three
                arguments above are not used to locate the file.
        """
        resolved_path = (ckpt_path if ckpt_path is not None else
                         self._get_ckpt_path(ckpt_dir, ckpt_epoch,
                                             ckpt_iter_every_epoch))

        target_model = self.get_saved_model()
        loader = Checkpointer(target_model, save_to_disk=False)
        # resume=False: load exactly resolved_path.
        loader.resume_or_load(resolved_path, resume=False)
    def _load_G(self, eval_ckpt):
        """Load ``eval_ckpt`` into the saved model after stripping ``.module.`` prefixes.

        The file is read with the checkpointer's ``_load_file``. Its ``'model'``
        entry is passed through ``_convert_ndarray_to_tensor`` and
        ``self._strip_module_if_present``. The whole checkpoint dict is then
        handed to ``_load_model``.
        """
        ckpt_loader = Checkpointer(self.get_saved_model(), save_to_disk=False)

        loaded = ckpt_loader._load_file(eval_ckpt)
        state_dict = loaded['model']
        # Return values are ignored and ``loaded`` (not ``state_dict``) is what
        # gets passed on, so both helpers are presumably in-place mutators of
        # the same dict object stored under loaded['model'].
        ckpt_loader._convert_ndarray_to_tensor(state_dict)
        self._strip_module_if_present(state_dict)

        self.logger.info(f"Load model from {eval_ckpt}")
        ckpt_loader._load_model(loaded)
Exemple #4
0
 def get_d2_checkpointer(model_dict, optim_dict, ckptdir):
     """Return a ``Checkpointer`` over ``model_dict`` wrapped in ``DumpModule``.

     ``ckptdir`` is passed as the second positional argument (the save
     directory in fvcore's ``Checkpointer`` signature). Each entry of
     ``optim_dict`` is forwarded as a keyword argument, presumably extra
     checkpointables such as optimizers; confirm against callers.
     """
     return Checkpointer(DumpModule(model_dict), ckptdir, **optim_dict)
    def _load_G(self, eval_ckpt):
        """Load the weights stored in ``eval_ckpt`` into the saved model.

        ``resume=False`` is required to load exactly ``eval_ckpt``. With
        fvcore's default ``resume=True``, ``resume_or_load`` first checks for a
        ``last_checkpoint`` file in the checkpointer's save directory. That
        directory defaults to ``""``, i.e. the current working directory. If
        the file exists, it loads that file instead, so evaluation would
        silently use the wrong weights. This now matches the other loaders,
        which also pass ``resume=False``.
        """
        checkpointer = Checkpointer(self.get_saved_model(), save_to_disk=False)
        checkpointer.resume_or_load(eval_ckpt, resume=False)
Exemple #6
0
    return imgs_anns


# Register the "caltech_train" / "caltech_test" splits and their single
# "person" class with detectron2's dataset and metadata catalogs.
# The ``d=d`` default binds the current loop value at definition time, so each
# lambda loads its own split (avoids the late-binding closure pitfall).
for d in ["train", "test"]:
    DatasetCatalog.register("caltech_" + d, lambda d=d: get_caltech_dicts(d))
    MetadataCatalog.get("caltech_" + d).set(thing_classes=["person"])
caltech_metadata = MetadataCatalog.get("caltech_train")

# Build the inference config from ./configs/frcn_dt.yaml.
cfg = get_cfg()
cfg.merge_from_file("./configs/frcn_dt.yaml")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = .0  # set to 0 to achieve smaller miss rate
# NOTE(review): ``predictor`` and ``evaluator`` are not used in the visible part
# of this script (the inference_on_dataset call below is commented out).
# Presumably they are consumed further down; verify. In detectron2,
# DefaultPredictor builds and loads its own separate model.
predictor = DefaultPredictor(cfg)
evaluator = COCOEvaluator("caltech_test", cfg, False, output_dir='./output/')
val_loader = build_detection_test_loader(cfg, "caltech_test")
# Build a fresh model and load one specific iteration's weights from OUTPUT_DIR.
# NOTE(review): the checkpoint filename is hard-coded; confirm it matches the run.
model = build_model(cfg)
ckpt = Checkpointer(model)
ckpt.load(os.path.join(cfg.OUTPUT_DIR, "model_0049999.pth"))
# inference_on_dataset(model, val_loader, evaluator)  # compute map value

# convert to caltech data eval format
res = {}
with torch.no_grad():
    model.eval()
    for inputs in tqdm(val_loader):
        outputs = model(inputs)
        for i, o in zip(inputs, outputs):
            fn = i['file_name']
            idx = fn.rfind('/') + 1
            fn = fn[idx:-4].split('_')
            sid, vid, frame = fn[0], fn[1], int(fn[2]) + 1
            if sid not in res:
# Set up logging only if the logger is not already emitting INFO records.
if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
    setup_logger()

# Build the model, put it in training mode, and create optimizer + LR scheduler.
model = build_model(cfg)
logger.info("Model:\n{}".format(model))
model.train()
optimizer = build_optimizer(cfg, model)
scheduler = build_lr_scheduler(cfg, optimizer)

# Checkpointer registered with the optimizer and scheduler, writing to
# cfg.OUTPUT_DIR.
checkpointer = DetectionCheckpointer(model,
                                     cfg.OUTPUT_DIR,
                                     optimizer=optimizer,
                                     scheduler=scheduler)
# Load cfg.MODEL.WEIGHTS. Training starts one past the stored "iteration",
# or at 0 when the returned dict has no "iteration" key.
start_iter = (checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=False).get(
    "iteration", -1) + 1)
# NOTE(review): this immediately reloads model weights from a hard-coded path,
# replacing (for matching keys) whatever cfg.MODEL.WEIGHTS provided, while
# start_iter above still comes from cfg.MODEL.WEIGHTS. Confirm the override is
# intentional, or move this path into cfg.MODEL.WEIGHTS.
ckpt = Checkpointer(model)
ckpt.load("./frcn_attn_0/model_0044999.pth")
max_iter = cfg.SOLVER.MAX_ITER

# Periodically save checkpoints every cfg.SOLVER.CHECKPOINT_PERIOD iterations.
periodic_checkpointer = PeriodicCheckpointer(checkpointer,
                                             cfg.SOLVER.CHECKPOINT_PERIOD,
                                             max_iter=max_iter)

# Metric writers (console, metrics.json, TensorBoard) are created only on the
# main process; other ranks get an empty list (presumably to avoid duplicate output).
writers = ([
    CommonMetricPrinter(max_iter),
    JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
    TensorboardXWriter(cfg.OUTPUT_DIR),
] if comm.is_main_process() else [])

# compared to "train_net.py", we do not support accurate timing and
# precise BN here, because they are not trivial to implement