Example #1
0
    def __init__(self, cfg, **kwargs):
        """Read FID/IS settings from ``cfg`` and load the Inception network.

        ``cfg.torch_fid_stat`` is required. The other settings are resolved
        through ``get_attr_kwargs`` from ``cfg`` and ``kwargs`` (the exact
        precedence is defined by that helper):

        * ``num_inception_images`` (default 50000)
        * ``IS_splits`` (default 10)
        * ``calculate_FID_use_torch`` (default False)
        * ``no_FID`` (default False)

        If the ``torch_fid_stat`` file exists (after ``~`` expansion), its
        ``mu`` and ``sigma`` arrays are stored as ``self.data_mu`` and
        ``self.data_sigma``. Otherwise a warning is logged and those
        attributes are not set.

        ``self.num_inception_images`` ends up as the per-process count: the
        configured value floor-divided by ``comm.get_world_size()``.
        """
        # fmt: off
        self.torch_fid_stat = cfg.torch_fid_stat
        self.num_inception_images = get_attr_kwargs(cfg,
                                                    'num_inception_images',
                                                    default=50000,
                                                    **kwargs)
        self.IS_splits = get_attr_kwargs(cfg,
                                         'IS_splits',
                                         default=10,
                                         **kwargs)
        self.calculate_FID_use_torch = get_attr_kwargs(
            cfg, 'calculate_FID_use_torch', default=False, **kwargs)
        self.no_FID = get_attr_kwargs(cfg, 'no_FID', default=False, **kwargs)
        # fmt: on

        self.logger = logging.getLogger('tl')
        # Expand '~' so user-relative paths resolve. The configured attribute
        # itself is left unchanged.
        fid_stat_path = os.path.expanduser(self.torch_fid_stat)
        if os.path.isfile(fid_stat_path):
            self.logger.info(f"Loading torch_fid_stat : {self.torch_fid_stat}")
            # Open the archive once and close it deterministically, instead of
            # loading it twice and leaking both file handles.
            with np.load(fid_stat_path) as fid_stat:
                self.data_mu = fid_stat['mu']
                self.data_sigma = fid_stat['sigma']
        else:
            self.logger.warning(
                f"torch_fid_stat does not exist: {self.torch_fid_stat}")

        # Load inception_v3 network
        self.inception_net = self._load_inception_net()

        # Each process evaluates only its share of the images.
        ws = comm.get_world_size()
        self.num_inception_images = self.num_inception_images // ws
  def test_case_calculate_fid_stat_CIFAR():
    """Run the configured GAN metric's FID-stat computation over a dataset.

    The config is built from ``D2Utils.create_cfg()`` with ``global_cfg``
    overlaid on it, and the merged config is then merged back into
    ``global_cfg``. The per-process batch size is ``IMS_PER_BATCH``
    floor-divided by the world size.
    """
    from template_lib.d2.data import build_dataset_mapper
    from template_lib.d2template.trainer.base_trainer import build_detection_test_loader
    from template_lib.v2.GAN.evaluation import build_GAN_metric
    from template_lib.d2.utils.d2_utils import D2Utils
    from template_lib.v2.config_cfgnode import global_cfg

    from detectron2.utils import logger as d2_logger
    d2_logger.setup_logger('d2')

    # Defaults first, then the global overrides; push the merged result back
    # into the global config.
    merged_cfg = D2Utils.create_cfg()
    merged_cfg.update(global_cfg)
    global_cfg.merge_from_dict(merged_cfg)

    # Read every required key up front so a missing one fails before any work.
    # fmt: off
    name_of_dataset  = merged_cfg.dataset_name
    global_batch     = merged_cfg.IMS_PER_BATCH
    image_size       = merged_cfg.img_size
    mapper_cfg       = merged_cfg.dataset_mapper_cfg
    metric_cfg       = merged_cfg.GAN_metric
    # fmt: on

    per_process_batch = global_batch // comm.get_world_size()

    mapper = build_dataset_mapper(mapper_cfg, img_size=image_size)
    loader = build_detection_test_loader(
      merged_cfg, dataset_name=name_of_dataset, batch_size=per_process_batch, mapper=mapper)

    fid_is_metric = build_GAN_metric(metric_cfg)
    fid_is_metric.calculate_fid_stat_of_dataloader(data_loader=loader)

    # Wait for every process before returning.
    comm.synchronize()
  def __init__(self, cfg, **kwargs):
    """Read TF FID/IS settings from ``cfg`` and build the TF Inception graph.

    Reads ``cfg.tf_inception_model_dir`` and ``cfg.tf_fid_stat`` (both
    required). Also reads ``num_inception_images`` (default 50000) and
    ``IS_splits`` (default 10) via ``getattr``. ``kwargs`` is accepted but
    not used here.

    ``self.num_inception_images`` ends up as the per-process count: the
    configured value floor-divided by ``comm.get_world_size()``. If the
    ``tf_fid_stat`` file exists (after ``~`` expansion), its ``mu`` and
    ``sigma`` arrays are stored as ``self.mu_data`` and ``self.sigma_data``.
    Otherwise only a warning is logged.
    """
    # fmt: off
    self.tf_inception_model_dir       = cfg.tf_inception_model_dir
    self.tf_fid_stat                  = cfg.tf_fid_stat
    self.num_inception_images         = getattr(cfg, 'num_inception_images', 50000)
    self.IS_splits                    = getattr(cfg, 'IS_splits', 10)
    # fmt: on

    self.logger = logging.getLogger('tl')
    ws = comm.get_world_size()
    self.num_inception_images = self.num_inception_images // ws
    self.tf_graph_name = 'FID_IS_Inception_Net'
    # Expand '~' as is already done for the model dir below. The configured
    # attribute itself is left unchanged.
    fid_stat_path = os.path.expanduser(self.tf_fid_stat)
    if os.path.isfile(fid_stat_path):
      self.logger.info(f'Loading tf_fid_stat: {self.tf_fid_stat}')
      # The context manager closes the archive even if a key is missing.
      with np.load(fid_stat_path) as f:
        self.mu_data, self.sigma_data = f['mu'][:], f['sigma'][:]
    else:
      self.logger.warning(f"tf_fid_stat does not exist: {self.tf_fid_stat}")

    self.tf_inception_model_dir = os.path.expanduser(self.tf_inception_model_dir)
    # NOTE(review): by its name, presumably fetches the model if it is absent.
    inception_path = self._check_or_download_inception(self.tf_inception_model_dir)
    self.logger.info('Load tf inception model in %s', inception_path)
    self._create_inception_graph(inception_path, name=self.tf_graph_name)
    self._create_inception_net()
    comm.synchronize()
Example #4
0
def get_sample_imgs_list_ddp(sample_func,
                             num_imgs=50000,
                             as_numpy=True,
                             stdout=sys.stdout):
    """Sample ``num_imgs`` images in total, split across all processes.

    Each process draws its share with ``get_sample_imgs_list``, and the
    results are collected with ``comm.gather``. When ``num_imgs`` is not
    divisible by the world size, the first ``num_imgs % world_size`` ranks
    draw one extra image, so the gathered total equals ``num_imgs``. Plain
    floor division would silently drop the remainder.

    :param sample_func: callable forwarded to ``get_sample_imgs_list``.
    :param num_imgs: total number of images across all processes.
    :param as_numpy: if True, gathered chunks are joined with
        ``np.concatenate``; otherwise with ``torch.cat``.
    :param stdout: stream forwarded to ``get_sample_imgs_list``.
    :return: where ``comm.gather`` returns a non-empty list, all gathered
        chunks concatenated along axis 0; elsewhere, this process's own
        samples.
    """
    import torch
    from template_lib.d2.utils import comm

    ws = comm.get_world_size()
    per_rank, remainder = divmod(num_imgs, ws)
    local_num_imgs = per_rank + (1 if comm.get_rank() < remainder else 0)
    imgs = get_sample_imgs_list(sample_func=sample_func,
                                num_imgs=local_num_imgs,
                                as_numpy=as_numpy,
                                stdout=stdout)

    imgs_list = comm.gather(imgs)
    if imgs_list:
        if as_numpy:
            imgs = np.concatenate(imgs_list, axis=0)
        else:
            imgs = torch.cat(imgs_list, dim=0)
    return imgs
Example #5
0
    def __init__(self, cfg):
        """Read FID/IS settings from ``cfg.GAN_metric`` and load the Inception net.

        ``cfg.GAN_metric.torch_fid_stat`` is required. The optional settings,
        read with ``getattr``, are:

        * ``num_inception_images`` (default 50000)
        * ``IS_splits`` (default 10)
        * ``calculate_FID_use_torch`` (default False)
        * ``no_FID`` (default False)

        If the ``torch_fid_stat`` file exists (after ``~`` expansion), its
        ``mu`` and ``sigma`` arrays are stored as ``self.data_mu`` and
        ``self.data_sigma``. Otherwise a warning is logged and those
        attributes are not set. ``self.num_inception_images`` ends up as the
        per-process count (floor-divided by the world size).
        """
        metric_cfg = cfg.GAN_metric
        self.torch_fid_stat = metric_cfg.torch_fid_stat
        self.num_inception_images = getattr(metric_cfg,
                                            'num_inception_images', 50000)
        self.IS_splits = getattr(metric_cfg, 'IS_splits', 10)
        self.calculate_FID_use_torch = getattr(metric_cfg,
                                               'calculate_FID_use_torch',
                                               False)
        self.no_FID = getattr(metric_cfg, 'no_FID', False)

        self.logger = logging.getLogger('tl')
        fid_stat_path = os.path.expanduser(self.torch_fid_stat)
        if os.path.isfile(fid_stat_path):
            self.logger.info(f"Loading torch_fid_stat : {self.torch_fid_stat}")
            # Open the archive once and close it, rather than leaking two handles.
            with np.load(fid_stat_path) as fid_stat:
                self.data_mu = fid_stat['mu']
                self.data_sigma = fid_stat['sigma']
        else:
            self.logger.warning(
                f"torch_fid_stat does not exist: {self.torch_fid_stat}")

        # Load inception_v3 network
        self.inception_net = self._load_inception_net()

        # Each process evaluates only its share of the images.
        ws = comm.get_world_size()
        self.num_inception_images = self.num_inception_images // ws
Example #6
0
    def test_case_calculate_fid_stat_CIFAR10():
        """Run TF FID-stat computation on the ``cifar10_train`` dataset.

        An inline YAML config (32x32 images, ``IMS_PER_BATCH`` 32,
        ``NUM_WORKERS`` 0, ``tf_fid_stat`` path
        ``datasets/fid_stats_tf_cifar10.npz``) is passed through
        ``TFFIDISScore.update_cfg`` and merged into a fresh D2 config.
        """
        from template_lib.d2.data import build_dataset_mapper
        from template_lib.d2template.trainer.base_trainer import build_detection_test_loader
        from template_lib.gans.evaluation import build_GAN_metric
        from template_lib.utils.detection2_utils import D2Utils
        # Not referenced below. Kept because the import may have side effects
        # (e.g. dataset registration) — NOTE(review): confirm before removing.
        from template_lib.d2.data import build_cifar10

        cfg_str = """
                  update_cfg: true
                  dataset_name: "cifar10_train"
                  IMS_PER_BATCH: 32
                  img_size: 32
                  NUM_WORKERS: 0
                  dataset_mapper_cfg:
                    name: CIFAR10DatasetMapper
                  GAN_metric:
                    tf_fid_stat: "datasets/fid_stats_tf_cifar10.npz"
              """
        yaml_cfg = TFFIDISScore.update_cfg(EasyDict(yaml.safe_load(cfg_str)))
        cfg = D2Utils.cfg_merge_from_easydict(D2Utils.create_cfg(), yaml_cfg)

        # fmt: off
        ds_name          = cfg.dataset_name
        batch_total      = cfg.IMS_PER_BATCH
        size             = cfg.img_size
        n_loader_workers = cfg.NUM_WORKERS
        mapper_cfg       = cfg.dataset_mapper_cfg
        metric_cfg       = cfg.GAN_metric
        # fmt: on

        # Apply overrides while the config is mutable, then lock it again.
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = n_loader_workers
        # tf_fid_stat may contain {dataset_name} / {img_size} placeholders.
        cfg.GAN_metric.tf_fid_stat = cfg.GAN_metric.tf_fid_stat.format(
            dataset_name=ds_name, img_size=size)
        cfg.freeze()

        per_process_batch = batch_total // comm.get_world_size()

        mapper = build_dataset_mapper(mapper_cfg, img_size=size)
        loader = build_detection_test_loader(cfg,
                                             dataset_name=ds_name,
                                             batch_size=per_process_batch,
                                             mapper=mapper)

        fid_is_metric = build_GAN_metric(metric_cfg)
        fid_is_metric.calculate_fid_stat_of_dataloader(data_loader=loader)

        comm.synchronize()
  def __call__(self, sample_func, return_fid_stat=False, num_inception_images=None,
               return_fid_logit=False, stdout=sys.stdout):
    """Score generated images with FID/IS, or return their FID statistics.

    :param sample_func: zero-argument callable returning a torch image batch.
      Values are mapped with ``x * 127.5 + 127.5``, clamped to [0, 255] and
      permuted with (0, 2, 3, 1). The input is therefore presumably NCHW in
      [-1, 1] — TODO confirm. The mapping is done in place on the returned
      tensor.
    :param return_fid_stat: if True, return Inception mean/covariance
      instead of FID/IS scores.
    :param num_inception_images: total image count, floor-divided by the
      world size. If None, ``self.num_inception_images`` is used as is.
    :param return_fid_logit: together with ``return_fid_stat``, also return
      the raw FID activations.
    :param stdout: forwarded to ``_get_activations_with_sample_func``.
    :return: ``(FID, IS_mean, IS_std)`` on the main process and zeros
      elsewhere. With ``return_fid_stat``, returns ``(mu, sigma)`` or
      ``(mu, sigma, pred_FIDs)``, where ``mu, sigma`` are ``0, 0`` off the
      main process.
    """
    import torch

    class _Uint8NHWCSampler(object):
      """Wraps ``sample_func`` so each call yields a uint8 numpy batch."""

      def __init__(self, sample_func):
        self.sample_func = sample_func

      def __call__(self, *args, **kwargs):
        """
        :return: images: [0, 255]
        """
        batch = self.sample_func()
        # In-place rescale and clamp, then reorder axes and cast to uint8.
        batch = batch.mul_(127.5).add_(127.5).clamp_(0.0, 255.0)
        batch = batch.permute(0, 2, 3, 1).type(torch.uint8)
        return batch.cpu().numpy()

    wrapped_sampler = _Uint8NHWCSampler(sample_func)

    if num_inception_images is not None:
      num_inception_images = num_inception_images // comm.get_world_size()
    else:
      num_inception_images = self.num_inception_images

    pred_FIDs, pred_ISs = self._get_activations_with_sample_func(
      sample_func=wrapped_sampler, num_inception_images=num_inception_images, stdout=stdout)

    is_main = comm.is_main_process()

    if return_fid_stat:
      if is_main:
        self.logger.info(f"Num of images: {len(pred_FIDs)}")
        mu, sigma = self._calculate_fid_stat(pred_FIDs=pred_FIDs)
      else:
        mu, sigma = 0, 0
      return (mu, sigma, pred_FIDs) if return_fid_logit else (mu, sigma)

    if not is_main:
      FID_tf = IS_mean_tf = IS_std_tf = 0
    else:
      self.logger.info(f"Num of images: {len(pred_FIDs)}")
      IS_mean_tf, IS_std_tf = self._calculate_IS(pred_ISs=pred_ISs, IS_splits=self.IS_splits)
      # FID between the generated activations' moments and the dataset moments.
      fake_mu = np.mean(pred_FIDs, axis=0)
      fake_sigma = np.cov(pred_FIDs, rowvar=False)
      FID_tf = calculate_frechet_distance(fake_mu, fake_sigma, self.mu_data, self.sigma_data)

    # Release the activation arrays before waiting at the barrier.
    del pred_FIDs, pred_ISs
    comm.synchronize()
    return FID_tf, IS_mean_tf, IS_std_tf
Example #8
0
  def load_inception_net(parallel, device='cuda'):
    """Build the Inception network on ``device``, optionally wrapped in DDP.

    The network comes from the module-level ``load_inception_net``, called
    with ``parallel=False``.

    :param parallel: if True and more than one process is running, wrap the
      net in ``DistributedDataParallel`` over a new process group containing
      every rank.
    :param device: device the network is moved to.
    :return: the network, DDP-wrapped when applicable. A single-process run
      with ``parallel=True`` now returns the plain network; it previously
      returned ``None``.
    """
    net = load_inception_net(parallel=False)
    net = net.to(device)

    if not parallel or comm.get_world_size() <= 1:
      return net

    # new_group must be entered by every process in the default group.
    pg = torch.distributed.new_group(range(torch.distributed.get_world_size()))
    # Use the device the module actually lives on. The global rank is not a
    # valid CUDA index on multi-node jobs or when ``device`` names a specific
    # GPU. CPU modules require device_ids=None.
    net_device = next(net.parameters()).device
    device_ids = [net_device] if net_device.type == 'cuda' else None
    net = DistributedDataParallel(
      net, device_ids=device_ids, broadcast_buffers=False,
      process_group=pg, check_reduction=False
    )
    return net
Example #9
0
  def __init__(self, saved_inception_moments, parallel=False, device='cuda'):
    """Load reference Inception moments and the Inception network.

    :param saved_inception_moments: path (``~`` expanded) to an ``.npz``
      archive whose ``mu`` and ``sigma`` arrays are stored as
      ``self.data_mu`` and ``self.data_sigma``.
    :param parallel: request a DDP-wrapped network. Forced to False when
      the world size is not greater than 1.
    :param device: device the network is placed on. It was previously
      accepted but ignored, so the network always used the loader's
      default.
    """
    self.logger = logging.getLogger('tl')
    saved_inception_moments = os.path.expanduser(saved_inception_moments)
    # Open the archive once and close it, rather than leaking two handles.
    with np.load(saved_inception_moments) as moments:
      self.data_mu = moments['mu']
      self.data_sigma = moments['sigma']
    # Load network
    self.parallel = parallel
    if not comm.get_world_size() > 1:
      self.parallel = False
    self.net = self.load_inception_net(parallel=self.parallel, device=device)