Example #1
def get_target_paths(source_paths: typing.List[pathlib.Path],
                     target_path: str,
                     default_dir: pathlib.Path):
    if target_path is not None:
        target_path = pathlib.Path(target_path)
        if len(source_paths) > 1:
            target_path.mkdir(exist_ok=True, parents=True)
            target_paths = []
            for source_path in source_paths:
                target_paths.append(target_path.joinpath(source_path.name))
            return target_paths
        else:
            target_path.parent.mkdir(exist_ok=True)
            return [target_path]
    logger.info(
        f"Found no target path. Setting to default output path: {default_dir}")
    default_target_dir = default_dir
    target_path = default_target_dir
    target_path.mkdir(exist_ok=True, parents=True)
    target_paths = []
    for source_path in source_paths:
        if source_path.suffix in video_suffix:
            target_path = default_target_dir.joinpath("anonymized_videos")
        else:
            target_path = default_target_dir.joinpath("anonymized_images")
        target_path = target_path.joinpath(source_path.name)
        target_path.parent.mkdir(exist_ok=True, parents=True)
        target_paths.append(target_path)
    return target_paths
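A minimal usage sketch, assuming the surrounding module defines `logger` and `video_suffix` (e.g. a set such as {".mp4", ".avi"}); the paths below are hypothetical:

import pathlib

source_paths = [pathlib.Path("input/face1.png"), pathlib.Path("input/clip.mp4")]
# No explicit target: files are routed to anonymized_images/ and anonymized_videos/
# under the default output directory.
target_paths = get_target_paths(source_paths, target_path=None,
                                default_dir=pathlib.Path("outputs"))
# With a single source and an explicit target, that target is used directly.
single_target = get_target_paths([source_paths[0]], "outputs/result.png",
                                 default_dir=pathlib.Path("outputs"))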
Example #2
    def calculate_fid(self):
        logger.info("Starting calculation of FID value")
        generator = self.trainer.RA_generator
        real_images, fake_images = infer.infer_images(
            self.trainer.dataloader_val, generator,
            truncation_level=0
        )
        """
        # FID calculation disabled here; it is prohibitively expensive to compute.
        cfg = self.trainer.cfg
        identifier = f"{cfg.dataset_type}_{cfg.data_val.dataset.percentage}_{self.current_imsize()}"
        transition_value = self.trainer.RA_generator.transition_value
        fid_val = metric_api.fid(
            real_images, fake_images,
            batch_size=self.fid_batch_size)
        logger.log_variable("stats/fid", np.mean(fid_val),
                            log_level=logging.INFO)
        """

        l1 = metric_api.l1(real_images, fake_images)
        l2 = metric_api.l2(real_images, fake_images)
        psnr = metric_api.psnr(real_images, fake_images)
        lpips = metric_api.lpips(
            real_images, fake_images, self.lpips_batch_size)
        logger.log_variable("stats/l1", l1, log_level=logging.INFO)
        logger.log_variable("stats/l2", l2, log_level=logging.INFO)
        logger.log_variable("stats/psnr", psnr, log_level=logging.INFO)
        logger.log_variable("stats/lpips", lpips, log_level=logging.INFO)
Example #3
    def _grow_phase(self):
        # Log the transition value here to avoid a misleading representation
        # on TensorBoard.
        if self.transition_value is not None:
            logger.log_variable("stats/transition-value",
                                self.get_transition_value())

        self._update_transition_value()
        transition_iters = self.transition_iters
        minibatch_repeats = self.cfg.trainer.progressive.minibatch_repeats
        next_transition = self.prev_transition + transition_iters
        num_batches = (next_transition - self.global_step) / self.batch_size()
        num_batches = int(np.ceil(num_batches))
        num_repeats = int(np.ceil(num_batches / minibatch_repeats))
        logger.info(
            f"Starting grow phase for imsize={self.current_imsize()}" +
            f" Training for {num_batches} batches with batch size: {self.batch_size()}"
        )
        for it in range(num_repeats):
            for _ in range(
                    min(minibatch_repeats,
                        num_batches - it * minibatch_repeats)):
                self.train_step()
            self._update_transition_value()
        # Check that grow phase happens at correct spot
        assert self.global_step >= self.prev_transition + transition_iters,\
            f"Global step: {self.global_step}, batch size: {self.batch_size()}, prev_transition: {self.prev_transition}" +\
            f" transition iters: {transition_iters}"
        assert self.global_step - self.batch_size() <= self.prev_transition + transition_iters,\
            f"Global step: {self.global_step}, batch size: {self.batch_size()}, prev_transition: {self.prev_transition}" +\
            f" transition iters: {transition_iters}"
Example #4
def load_checkpoint(ckpt_dir_or_file: pathlib.Path) -> dict:
    if ckpt_dir_or_file.is_dir():
        with open(ckpt_dir_or_file.joinpath('latest_checkpoint')) as f:
            ckpt_path = f.readline().strip()
            ckpt_path = ckpt_dir_or_file.joinpath(ckpt_path)
    else:
        ckpt_path = ckpt_dir_or_file
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"Did not find path: {ckpt_path}")
    ckpt = torch.load(ckpt_path, map_location=_get_map_location())
    logger.info(f"Loaded checkpoint from {ckpt_path}")
    return ckpt
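A usage sketch; the directory name is hypothetical and the contents of the returned dict depend on what the trainer saved:

import pathlib

# Passing a directory resolves the newest file via the `latest_checkpoint` index;
# passing a file loads that checkpoint directly.
ckpt = load_checkpoint(pathlib.Path("outputs/checkpoints"))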
Example #5
 def save_validation_checkpoint(self):
     # Milestones are given in millions of global steps.
     checkpoints = [12, 20, 30, 40, 50]
     for checkpoint_step in checkpoints:
         checkpoint_step = checkpoint_step * 10**6
         previous_global_step = self.global_step() - self.trainer.batch_size()
         if self.global_step() >= checkpoint_step \
                 and previous_global_step < checkpoint_step:
             logger.info("Saving global checkpoint for validation")
             filepath = self.validation_checkpoint_dir.joinpath(
                 f"step_{self.global_step()}.ckpt")
             self.trainer.save_checkpoint(filepath,
                                          max_keep=len(checkpoints) + 1)
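The condition fires exactly once per milestone: on the first step whose global step reaches it. With hypothetical numbers:

batch_size = 256                 # hypothetical
checkpoint_step = 20 * 10**6     # the 20M milestone
global_step = 20_000_128         # hypothetical current step
previous_global_step = global_step - batch_size   # 19_999_872
# 20_000_128 >= 20_000_000 and 19_999_872 < 20_000_000 -> checkpoint saved once.
# On later calls previous_global_step is already past the milestone,
# so the condition is False again.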
Example #6
    def __init__(self, dirpath, imsize: int, transform, percentage: float):
        dirpath = pathlib.Path(dirpath)
        self.dirpath = dirpath
        self.transform = transform
        self._percentage = percentage
        self.imsize = imsize
        assert self.dirpath.is_dir(),\
            f"Did not find dataset at: {dirpath}"
        self.image_paths = self._load_impaths()
        self.filter_images()

        logger.info(
            f"Dataset loaded from: {dirpath}. Number of samples: {len(self)}, imsize={imsize}"
        )
Example #7
 def init_models(self):
     self.discriminator = models.build_discriminator(
         self.cfg, data_parallel=torch.cuda.device_count() > 1)
     self.generator = models.build_generator(
         self.cfg, data_parallel=torch.cuda.device_count() > 1)
     self.RA_generator = models.build_generator(
         self.cfg, data_parallel=torch.cuda.device_count() > 1)
     self.RA_generator = torch_utils.to_cuda(self.RA_generator)
     self.RA_generator.load_state_dict(self.generator.state_dict())
     logger.info(str(self.generator))
     logger.info(str(self.discriminator))
     logger.log_variable(
         "stats/discriminator_parameters",
         torch_utils.number_of_parameters(self.discriminator))
     logger.log_variable("stats/generator_parameters",
                         torch_utils.number_of_parameters(self.generator))
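torch_utils.number_of_parameters is not shown in this excerpt; a minimal sketch of what such a helper typically does (an assumption, not the project's actual implementation):

import torch

def number_of_parameters(module: torch.nn.Module) -> int:
    # Total number of elements across all parameter tensors of the module.
    return sum(p.numel() for p in module.parameters())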
Example #8
    def _stability_phase(self):
        self._update_transition_value()
        assert self.get_transition_value() == 1.0

        if self.prev_transition == 0:
            next_transition = self.transition_iters
        else:
            next_transition = self.prev_transition + self.transition_iters * 2

        num_batches = (next_transition - self.global_step) / self.batch_size()
        num_batches = int(np.ceil(num_batches))
        assert num_batches > 0
        logger.info(
            f"Starting stability phase for imsize={self.current_imsize()}." +
            f" Training for {num_batches} batches with batch size: {self.batch_size()}"
        )
        for _ in range(num_batches):
            self.train_step()
Example #9
 def save_checkpoint(self,
                     state_dict: dict,
                     filepath: pathlib.Path = None,
                     max_keep=2):
     if filepath is None:
         global_step = self.trainer.global_step
         filename = f"step_{global_step}.ckpt"
         filepath = self.checkpoint_dir.joinpath(filename)
     list_path = filepath.parent.joinpath("latest_checkpoint")
     torch.save(state_dict, filepath)
     previous_checkpoints = get_previous_checkpoints(filepath)
     if filepath.name not in previous_checkpoints:
         previous_checkpoints = [filepath.name] + previous_checkpoints
     if len(previous_checkpoints) > max_keep:
         for ckpt in previous_checkpoints[max_keep:]:
             path = self.checkpoint_dir.joinpath(ckpt)
             if path.exists():
                 logger.info(f"Removing old checkpoint: {path}")
                 path.unlink()
     previous_checkpoints = previous_checkpoints[:max_keep]
     with open(list_path, 'w') as fp:
         fp.write("\n".join(previous_checkpoints))
     logger.info(f"Saved checkpoint to: {filepath}")
Example #10
def build_anonymizer(model_name=available_models[0],
                     batch_size: int = 1,
                     fp16_inference: bool = True,
                     truncation_level: float = 0,
                     detection_threshold: float = .1,
                     opts: str = None,
                     config_path: str = None,
                     return_cfg=False) -> DeepPrivacyAnonymizer:
    """
        Builds anonymizer with detector and generator from checkpoints.

        Args:
            config_path: If not None, will override model_name
            opts: if not None, can override default settings. For example:
                opts="anonymizer.truncation_level=5, anonymizer.batch_size=32"
    """
    if config_path is None:
        assert model_name in available_models,\
            f"{model_name} not in available models: {available_models}"
        cfg = get_config(config_urls[model_name])
    else:
        cfg = Config.fromfile(config_path)
    logger.info("Loaded model: " + cfg.model_name)
    generator = load_model_from_checkpoint(cfg)
    logger.info(
        f"Generator initialized with {torch_utils.number_of_parameters(generator)/1e6:.2f}M parameters"
    )
    cfg.anonymizer.truncation_level = truncation_level
    cfg.anonymizer.batch_size = batch_size
    cfg.anonymizer.fp16_inference = fp16_inference
    cfg.anonymizer.detector_cfg.face_detector_cfg.confidence_threshold = detection_threshold
    cfg.merge_from_str(opts)
    anonymizer = DeepPrivacyAnonymizer(generator, cfg=cfg, **cfg.anonymizer)
    if return_cfg:
        return anonymizer, cfg
    return anonymizer
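A usage sketch; it only uses parameters that appear in the signature above, and the downstream anonymizer calls are omitted since they are not part of this excerpt:

# Build with the default model and a larger batch size; opts overrides use the
# comma-separated "key=value" format shown in the docstring.
anonymizer, cfg = build_anonymizer(
    batch_size=32,
    fp16_inference=True,
    detection_threshold=0.3,
    return_cfg=True)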
Example #11
def get_target_paths(source_paths: typing.List[pathlib.Path], target_path: str,
                     default_dir: pathlib.Path):
    if target_path is not None:
        target_path = pathlib.Path(target_path)
        assert len(source_paths) == 1, \
            "Target path points to a single file, but got multiple source paths. " + \
            f"target path={target_path}, source_paths={source_paths}"
        target_path.parent.mkdir(exist_ok=True)
        return [target_path]
    logger.info(
        f"Found no target path. Setting to default output path: {default_dir}")
    default_target_dir = default_dir
    target_path = default_target_dir
    target_path.mkdir(exist_ok=True, parents=True)
    target_paths = []
    for source_path in source_paths:
        if source_path.suffix in video_suffix:
            target_path = default_target_dir.joinpath("anonymized_videos")
        else:
            target_path = default_target_dir.joinpath("anonymized_images")
        target_path = target_path.joinpath(source_path.name)
        target_path.parent.mkdir(exist_ok=True, parents=True)
        target_paths.append(target_path)
    return target_paths
Example #12
 def save_checkpoint(self, filepath=None, max_keep=2):
     logger.info(f"Saving checkpoint to: {filepath}")
     state_dict = self.state_dict()
     self.checkpointer.save_checkpoint(state_dict, filepath, max_keep)
Example #13
 def after_step(self):
     if self.sigterm_received:
         logger.info("[SIGTERM RECEIVED] Stopping train.")
         self.trainer.save_checkpoint(max_keep=3)
         exit()
Example #14
 def handle_sigterm(self, signum, frame):
     logger.info(
         "[SIGTERM RECEIVED] Received sigterm. Stopping train after step.")
     # Only set the flag; after_step saves a checkpoint and stops the run.
     self.sigterm_received = True
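The registration of this handler is not shown in the excerpt; with Python's standard signal module it would typically be hooked up as sketched below (the class name SigTermHook and the trainer argument are hypothetical):

import signal

class SigTermHook:
    # Hypothetical minimal hook showing how handle_sigterm could be registered.
    def __init__(self, trainer):
        self.trainer = trainer
        self.sigterm_received = False
        # Route SIGTERM (e.g. from a job scheduler) to the handler so training
        # can stop gracefully after the current step.
        signal.signal(signal.SIGTERM, self.handle_sigterm)

    def handle_sigterm(self, signum, frame):
        self.sigterm_received = True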