Ejemplo n.º 1
0
def get_config(config_url):
    """Fetch a config file by URL, caching it locally, and parse it.

    The file name is the last component of the URL path. The file is cached
    under ``<torch home>/deep_privacy_cache/<file name>``; if that file already
    exists it is reused and nothing is downloaded.

    Args:
        config_url: URL pointing at a config file.

    Returns:
        The result of ``Config.fromfile`` on the cached file.

    Raises:
        ValueError: if the URL path has no file-name component (e.g. it is
            empty or ends with "/"), so there is nothing to cache under.
    """
    parts = urlparse(config_url)
    cfg_name = os.path.basename(parts.path)
    # os.path.basename never returns None; an empty string is the failure
    # case. Without this check the cache *directory* would be used as the
    # download target. Validate before touching the filesystem or network.
    if not cfg_name:
        raise ValueError(
            f"Config URL has no file name component: {config_url!r}")
    cfg_path = pathlib.Path(torch.hub._get_torch_home(), "deep_privacy_cache",
                            cfg_name)
    cfg_path.parent.mkdir(exist_ok=True, parents=True)
    if not cfg_path.is_file():
        torch.hub.download_url_to_file(config_url, str(cfg_path))
    assert cfg_path.is_file(), f"Download of {config_url} did not produce {cfg_path}"
    return Config.fromfile(cfg_path)
Ejemplo n.º 2
0
def build_anonymizer(model_name=available_models[0],
                     batch_size: int = 1,
                     fp16_inference: bool = True,
                     truncation_level: float = 0,
                     detection_threshold: float = .1,
                     opts: str = None,
                     config_path: str = None,
                     return_cfg=False) -> DeepPrivacyAnonymizer:
    """
        Builds anonymizer with detector and generator from checkpoints.

        Args:
            model_name: Name from ``available_models``; its config is fetched
                from ``config_urls[model_name]`` via ``get_config``. Ignored
                when ``config_path`` is given.
            batch_size: Written to ``cfg.anonymizer.batch_size``.
            fp16_inference: Written to ``cfg.anonymizer.fp16_inference``.
            truncation_level: Written to ``cfg.anonymizer.truncation_level``.
            detection_threshold: Written to
                ``cfg.anonymizer.detector_cfg.face_detector_cfg.confidence_threshold``.
            opts: If not None/empty, can override default settings. For example:
                opts="anonymizer.truncation_level=5, anonymizer.batch_size=32"
                It is merged after the keyword arguments above are applied,
                so values given here take precedence over them.
            config_path: If not None, will override model_name.
            return_cfg: If True, also return the final config.

        Returns:
            The anonymizer, or ``(anonymizer, cfg)`` when ``return_cfg`` is True.
    """
    if config_path is None:
        assert model_name in available_models,\
            f"{model_name} not in available models: {available_models}"
        cfg = get_config(config_urls[model_name])
    else:
        cfg = Config.fromfile(config_path)
    logger.info("Loaded model:" + cfg.model_name)
    generator = load_model_from_checkpoint(cfg)
    logger.info(
        f"Generator initialized with {torch_utils.number_of_parameters(generator)/1e6:.2f}M parameters"
    )
    cfg.anonymizer.truncation_level = truncation_level
    cfg.anonymizer.batch_size = batch_size
    cfg.anonymizer.fp16_inference = fp16_inference
    cfg.anonymizer.detector_cfg.face_detector_cfg.confidence_threshold = detection_threshold
    # The default is opts=None. Skip the merge when there are no overrides
    # instead of relying on merge_from_str to tolerate None or "".
    if opts:
        cfg.merge_from_str(opts)
    anonymizer = DeepPrivacyAnonymizer(generator, cfg=cfg, **cfg.anonymizer)
    if return_cfg:
        return anonymizer, cfg
    return anonymizer
Ejemplo n.º 3
0
                new_sd[key.replace("encoder.", "").replace(".conv.layers", "")] = value
                continue
            for subkey, new_subkey in mapping.items():
                if subkey in key:
                    old_key = key
                    key = key.replace(subkey, new_subkey)

                    break
            if "decoder.to_rgb" in key:
                continue

            new_sd[key] = value
        return super().load_state_dict(new_sd, strict=strict)
        

if __name__ == "__main__":
    # Smoke test: build the generator from the config given on the command
    # line, call extend() once, then run a single forward pass on random
    # inputs and print the output shape.
    from deep_privacy.config import Config, default_parser
    cli_args = default_parser().parse_args()
    cfg = Config.fromfile(cli_args.config_path)

    generator = MSGGenerator(cfg).cuda()
    generator.extend()
    # Moved to the GPU again after extend(). Presumably extend() adds modules
    # that are not yet on the GPU; TODO confirm.
    generator.cuda()

    side = generator.current_imsize
    n_samples = 8
    inputs = {
        "mask": torch.ones((n_samples, 1, side, side)).cuda(),
        "condition": torch.randn((n_samples, 3, side, side)).cuda(),
        "landmarks": torch.randn((n_samples, 14)).cuda(),
    }
    output = generator(**inputs)
    print(output.shape)