Пример #1
0
def main(cfg):
    """Evaluate every fold of a pretrained model and log the pooled confusion matrix.

    For each fold listed in ``MODELS_URL[cfg.model_name]`` the checkpoint is
    downloaded to ``BASE_DIR/<model_name>/<model_name>_<fold>.pt``, evaluated
    on the ``"test"`` stage, and the tracker's full confusion matrix is saved
    next to it as ``<model_name>_<fold>.npy``. The per-fold matrices are then
    summed and passed to ``log_confusion_matrix``.

    Args:
        cfg: OmegaConf config. It is mutated in place: struct mode is
            disabled, and ``checkpoint_dir``, ``model_name`` and
            ``tracker_options.full_res`` are overwritten.

    Raises:
        KeyError: if ``cfg.model_name`` is not a key of ``MODELS_URL``.
        ValueError: if no checkpoint is registered for the model, or a loaded
            checkpoint reports a different fold than the one it was
            downloaded for.
    """
    # Disable struct mode so getattr/hasattr work on keys absent from the config.
    OmegaConf.set_struct(cfg, False)
    # cfg.pretty() is deprecated in OmegaConf 2.0 and removed in 2.1;
    # OmegaConf.to_yaml is its replacement.
    log.info(OmegaConf.to_yaml(cfg))

    workdir = os.path.join(BASE_DIR, cfg.model_name)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(workdir, exist_ok=True)

    cfg.checkpoint_dir = workdir
    cfg.tracker_options.full_res = True

    # Download each fold's checkpoint; keep its path without the ".pt" suffix
    # so the matching ".npy" output can be derived from the same stem.
    checkpoint_stems = {}
    for fold, url in MODELS_URL[cfg.model_name].items():
        stem = os.path.join(workdir, "{}_{}".format(cfg.model_name, fold))
        checkpoint_stems[fold] = stem
        download_file(url, stem + ".pt")

    if not checkpoint_stems:
        raise ValueError(
            "No pretrained checkpoints registered for model {}".format(cfg.model_name))

    conf_paths = []
    for fold, stem in checkpoint_stems.items():
        # Point the config at this fold's checkpoint. Presumably the Trainer
        # resolves it from checkpoint_dir/model_name — TODO confirm.
        cfg.model_name = stem
        cfg.tracker_options.full_res = True
        trainer = Trainer(cfg)

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        loaded_fold = str(trainer._checkpoint.data_config.fold)
        if loaded_fold != str(fold):
            raise ValueError(
                "Checkpoint {}.pt reports fold {}, expected fold {}".format(
                    stem, loaded_fold, fold))
        trainer.eval(stage_name="test")

        # Derive the output path from the checkpoint stem directly; joining
        # workdir with an already-prefixed model_name doubles the directory
        # when BASE_DIR is relative.
        conf_path = stem + ".npy"
        np.save(conf_path,
                trainer._tracker.full_confusion_matrix.get_confusion_matrix())
        conf_paths.append(conf_path)

    confusion_matrix = ConfusionMatrix.create_from_matrix(
        np.sum([np.load(p) for p in conf_paths], axis=0))
    log_confusion_matrix(confusion_matrix)
Пример #2
0
def main(cfg):
    """Print the config as YAML and run evaluation with a Trainer built from it.

    Args:
        cfg: OmegaConf config; struct mode is disabled on it in place.
    """
    # Disable struct mode so getattr/hasattr work on keys absent from the config.
    OmegaConf.set_struct(cfg, False)
    # cfg.pretty() is deprecated in OmegaConf 2.0 and removed in 2.1;
    # OmegaConf.to_yaml is its replacement.
    print(OmegaConf.to_yaml(cfg))

    trainer = Trainer(cfg)
    trainer.eval()
def main(cfg):
    """Train a model from *cfg*, then reset Hydra's global state.

    Args:
        cfg: OmegaConf config; struct mode is disabled on it in place. If its
            optional ``pretty_print`` key is truthy, the config is printed as
            YAML before training.

    Returns:
        0 once training has completed.
    """
    OmegaConf.set_struct(cfg, False)  # This allows getattr and hasattr methods to function correctly
    # .get() tolerates a missing key; newer OmegaConf raises on attribute
    # access to an absent key even with struct mode off.
    if cfg.get("pretty_print", False):
        print(OmegaConf.to_yaml(cfg))

    trainer = Trainer(cfg)
    trainer.train()

    # Reset the global Hydra singleton so main() can run again in the same process.
    # https://github.com/facebookresearch/hydra/issues/440
    # In Hydra 1.x, GlobalHydra.get_state() returns a newly built dict, so
    # calling .clear() on it had no effect; clear the singleton itself instead.
    GlobalHydra.instance().clear()
    return 0
Пример #4
0
    def test_trainer_on_scannet_object_detection(self):
        """Train an object-detection model for two epochs on the fixed ScanNet test data.

        The test runs with ``data/scannet-fixed/outputs`` (created if missing)
        as the working directory. The previous working directory is restored
        afterwards so the chdir does not leak into other tests.
        """
        self.path_outputs = os.path.join(DIR_PATH,
                                         "data/scannet-fixed/outputs")
        os.makedirs(self.path_outputs, exist_ok=True)
        previous_cwd = os.getcwd()
        os.chdir(self.path_outputs)
        try:
            cfg = OmegaConf.load(
                os.path.join(DIR_PATH,
                             "data/scannet-fixed/config_object_detection.yaml"))
            # Short run; num_workers=0 presumably keeps data loading in-process.
            cfg.training.epochs = 2
            cfg.training.num_workers = 0
            cfg.data.is_test = True
            cfg.data.dataroot = os.path.join(DIR_PATH, "data/")
            trainer = Trainer(cfg)
            trainer.train()
        finally:
            os.chdir(previous_cwd)
Пример #5
0
    def test_trainer_on_shapenet_fixed(self):
        """Train on the ShapeNet test data, check trainer flags and metric keys, then run voting eval.

        The test runs with ``data/shapenet/outputs`` (created if missing) as
        the working directory. The previous working directory is restored
        afterwards so the chdir does not leak into other tests.
        """
        self.path_outputs = os.path.join(DIR_PATH, "data/shapenet/outputs")
        os.makedirs(self.path_outputs, exist_ok=True)
        previous_cwd = os.getcwd()
        os.chdir(self.path_outputs)
        try:
            cfg = OmegaConf.load(
                os.path.join(DIR_PATH, "data/shapenet/shapenet_config.yaml"))
            # Short run; num_workers=0 presumably keeps data loading in-process.
            cfg.training.epochs = 2
            cfg.training.num_workers = 0
            cfg.data.is_test = True
            cfg.data.dataroot = os.path.join(DIR_PATH, "data/")

            trainer = Trainer(cfg)
            trainer.train()

            self.assertEqual(trainer.early_break, True)
            self.assertEqual(trainer.profiling, False)
            self.assertEqual(trainer.precompute_multi_scale, False)
            self.assertEqual(trainer.wandb_log, False)

            keys = list(trainer._tracker.get_metrics().keys())
            self.assertEqual(keys, ["test_loss_seg", "test_Cmiou", "test_Imiou"])

            # Re-evaluate with two voting runs on the already-trained model.
            trainer._cfg.voting_runs = 2
            trainer.eval()
        finally:
            os.chdir(previous_cwd)
Пример #6
0
 def test_trainer_on_scannet_segmentation(self):
     """Train a ScanNet segmentation model, then evaluate with voting, full-res tracking and submission output.

     The test runs with ``data/scannet/outputs`` (created if missing) as the
     working directory. The previous working directory is restored afterwards
     so the chdir does not leak into other tests.
     """
     self.path_outputs = os.path.join(DIR_PATH, "data/scannet/outputs")
     os.makedirs(self.path_outputs, exist_ok=True)
     previous_cwd = os.getcwd()
     os.chdir(self.path_outputs)
     try:
         cfg = OmegaConf.load(
             os.path.join(DIR_PATH, "data/scannet/config_segmentation.yaml"))
         # Short run; num_workers=0 presumably keeps data loading in-process.
         cfg.training.epochs = 2
         cfg.training.num_workers = 0
         cfg.data.is_test = True
         cfg.data.dataroot = os.path.join(DIR_PATH, "data/")
         trainer = Trainer(cfg)
         trainer.train()

         # Re-evaluate the trained model with the extra eval-time options enabled.
         trainer._cfg.voting_runs = 2
         trainer._cfg.tracker_options.full_res = True
         trainer._cfg.tracker_options.make_submission = True
         trainer.eval()
     finally:
         os.chdir(previous_cwd)