コード例 #1
0
def load_detector(run_id):
    """Load a trained detector from an experiment run directory.

    Reads ``EXP_DIR / run_id / 'config.yaml'``, passes it through
    ``check_update_config``, builds the network with ``create_model_detector``
    (one output per entry of ``cfg.label_to_category_id``) and restores the
    weights stored under the ``'state_dict'`` key of ``checkpoint.pth.tar``.
    The network is then moved to the current CUDA device, switched to eval
    mode and wrapped in ``Detector``.

    Args:
        run_id: Name of the run sub-directory inside ``EXP_DIR``.

    Returns:
        A ``Detector`` wrapping the loaded model.
    """
    run_dir = EXP_DIR / run_id

    # NOTE(review): FullLoader can construct some tagged Python objects; only
    # load configs written by trusted training runs.
    cfg = yaml.load((run_dir / 'config.yaml').read_text(), Loader=yaml.FullLoader)
    cfg = check_update_config(cfg)
    label_to_category_id = cfg.label_to_category_id
    model = create_model_detector(cfg, len(label_to_category_id))

    # Deserialize onto the CPU: without map_location, tensors are restored to
    # the device they were saved from, which fails when that GPU index is not
    # available here. The model is moved to CUDA explicitly just below, so the
    # resulting weights are the same.
    ckpt = torch.load(run_dir / 'checkpoint.pth.tar', map_location='cpu')
    model.load_state_dict(ckpt['state_dict'])
    model = model.cuda().eval()

    # The config is exposed under both attribute names; presumably different
    # consumers (e.g. Detector) read one or the other — TODO confirm.
    model.cfg = cfg
    model.config = cfg
    return Detector(model)
コード例 #2
0
def make_eval_configs(args, model_training, epoch):
    """Build one evaluation config per test dataset plus a Detector for the model.

    Side effect: ``args`` is attached to ``model_training.module`` as both
    ``.config`` and ``.cfg`` before that module is wrapped in ``Detector``.

    Args:
        args: Training arguments. Reads ``test_ds_names``, ``save_dir``,
            ``n_dataloader_workers`` and ``n_test_frames``. ``save_dir`` must
            support ``/`` with a str (presumably a ``pathlib.Path`` — confirm).
        model_training: Object exposing the underlying network as ``.module``
            (presumably a DataParallel/DistributedDataParallel wrapper — confirm).
        epoch: Epoch identifier, used only in each config's output directory.

    Returns:
        ``(configs, detector)``: ``configs`` is a list of ``argparse.Namespace``
        objects, one per entry of ``args.test_ds_names`` and in that order;
        ``detector`` is the ``Detector`` wrapping ``model_training.module``.
    """
    model = model_training.module
    model.config = args
    model.cfg = args
    detector = Detector(model)

    # argparse.Namespace(**kw) is equivalent to parsing [] with an empty parser
    # and then setting each attribute; n_frames is set once to its final value
    # (it used to be set to None and immediately overwritten).
    configs = [
        argparse.Namespace(
            ds_name=ds_name,
            save_dir=args.save_dir / f'dataset={ds_name}/epoch={epoch}',
            n_workers=args.n_dataloader_workers,
            pred_bsz=16,
            eval_bsz=16,
            n_frames=args.n_test_frames,
            skip_evaluation=False,
            skip_model_predictions=False,
            external_predictions=True,
        )
        for ds_name in args.test_ds_names
    ]
    return configs, detector