def get_config(config_url):
    """Download a model config file to a local cache and parse it.

    The file is cached under ``$TORCH_HOME/deep_privacy_cache/<name>`` so
    repeated calls skip the download.

    Args:
        config_url: URL of the config file; its final path component is
            used as the cache file name.

    Returns:
        The parsed ``Config`` object (``Config.fromfile``).

    Raises:
        ValueError: if no file name can be derived from the URL.
        RuntimeError: if the download did not produce a local file.
    """
    parts = urlparse(config_url)
    cfg_name = os.path.basename(parts.path)
    # os.path.basename never returns None (it returns "" for paths ending
    # in "/"), so the original `assert cfg_name is not None` could never
    # fire — and asserts are stripped under `python -O`. Validate for real.
    if not cfg_name:
        raise ValueError(
            f"Could not derive a config file name from URL: {config_url}")
    cfg_path = pathlib.Path(
        torch.hub._get_torch_home(), "deep_privacy_cache", cfg_name)
    cfg_path.parent.mkdir(exist_ok=True, parents=True)
    if not cfg_path.is_file():
        torch.hub.download_url_to_file(config_url, cfg_path)
    if not cfg_path.is_file():
        raise RuntimeError(
            f"Downloading {config_url} did not create {cfg_path}")
    return Config.fromfile(cfg_path)
def build_anonymizer(
        model_name=available_models[0],
        batch_size: int = 1,
        fp16_inference: bool = True,
        truncation_level: float = 0,
        detection_threshold: float = .1,
        opts: str = None,
        config_path: str = None,
        return_cfg=False) -> DeepPrivacyAnonymizer:
    """Builds anonymizer with detector and generator from checkpoints.

    Args:
        model_name: key into ``available_models``; ignored when
            config_path is given.
        batch_size: batch size used by the anonymizer.
        fp16_inference: run the generator in half precision.
        truncation_level: latent truncation level for the generator.
        detection_threshold: face-detector confidence threshold.
        opts: if not None, can override default settings. For example:
            opts="anonymizer.truncation_level=5, anonymizer.batch_size=32"
        config_path: If not None, will override model_name
        return_cfg: if True, return ``(anonymizer, cfg)`` instead of just
            the anonymizer.

    Raises:
        ValueError: if ``model_name`` is not a known model (and no
            ``config_path`` was given).
    """
    if config_path is None:
        # NOTE: removed leftover debug `print(config_path)` — it could
        # only ever print None on this branch.
        # Raise instead of `assert`: asserts are stripped under `python -O`.
        if model_name not in available_models:
            raise ValueError(
                f"{model_name} not in available models: {available_models}")
        cfg = get_config(config_urls[model_name])
    else:
        cfg = Config.fromfile(config_path)
    logger.info("Loaded model:" + cfg.model_name)
    generator = load_model_from_checkpoint(cfg)
    logger.info(
        f"Generator initialized with "
        f"{torch_utils.number_of_parameters(generator)/1e6:.2f}M parameters"
    )
    cfg.anonymizer.truncation_level = truncation_level
    cfg.anonymizer.batch_size = batch_size
    cfg.anonymizer.fp16_inference = fp16_inference
    cfg.anonymizer.detector_cfg.face_detector_cfg.confidence_threshold = \
        detection_threshold
    # `opts` is documented as optional; only merge when actually supplied.
    if opts is not None:
        cfg.merge_from_str(opts)
    anonymizer = DeepPrivacyAnonymizer(generator, cfg=cfg, **cfg.anonymizer)
    if return_cfg:
        return anonymizer, cfg
    return anonymizer
new_sd[key.replace("encoder.", "").replace(".conv.layers", "")] = value continue for subkey, new_subkey in mapping.items(): if subkey in key: old_key = key key = key.replace(subkey, new_subkey) break if "decoder.to_rgb" in key: continue new_sd[key] = value return super().load_state_dict(new_sd, strict=strict) if __name__ == "__main__": from deep_privacy.config import Config, default_parser args = default_parser().parse_args() cfg = Config.fromfile(args.config_path) g = MSGGenerator(cfg).cuda() g.extend() g.cuda() imsize = g.current_imsize batch = dict( mask=torch.ones((8, 1, imsize, imsize)).cuda(), condition=torch.randn((8, 3, imsize, imsize)).cuda(), landmarks=torch.randn((8, 14)).cuda() ) print(g(**batch).shape)