def __init__(self, cfg, **kwargs):
    """Configure the torch-based FID/IS metric.

    Reads the metric options, loads the reference FID statistics when the
    statistics file exists, loads the Inception network and splits the image
    budget across workers.

    :param cfg: config node. ``torch_fid_stat`` is read directly from it; the
        remaining options are resolved through ``get_attr_kwargs`` with the
        defaults shown below.
    :param kwargs: forwarded to every ``get_attr_kwargs`` call (presumably
        per-call overrides of ``cfg`` values -- confirm against that helper).
    """
    # fmt: off
    self.torch_fid_stat               = cfg.torch_fid_stat
    self.num_inception_images         = get_attr_kwargs(cfg, 'num_inception_images', default=50000, **kwargs)
    self.IS_splits                    = get_attr_kwargs(cfg, 'IS_splits', default=10, **kwargs)
    self.calculate_FID_use_torch      = get_attr_kwargs(cfg, 'calculate_FID_use_torch', default=False, **kwargs)
    self.no_FID                       = get_attr_kwargs(cfg, 'no_FID', default=False, **kwargs)
    # fmt: on

    self.logger = logging.getLogger('tl')

    if os.path.isfile(self.torch_fid_stat):
        self.logger.info(f"Loading torch_fid_stat : {self.torch_fid_stat}")
        # Open the archive once and close it deterministically; the arrays
        # are materialised in memory when indexed, so they outlive the file.
        with np.load(self.torch_fid_stat) as fid_stat:
            self.data_mu = fid_stat['mu']
            self.data_sigma = fid_stat['sigma']
    else:
        # NOTE: data_mu / data_sigma stay undefined on this path.
        self.logger.warning(
            f"torch_fid_stat does not exist: {self.torch_fid_stat}")

    # Load inception_v3 network
    self.inception_net = self._load_inception_net()

    # Each worker only processes its integer share of the image budget.
    ws = comm.get_world_size()
    self.num_inception_images = self.num_inception_images // ws
def test_case_calculate_fid_stat_CIFAR():
    """Compute FID statistics for the dataset selected by the global config.

    Builds a config merged with ``global_cfg``, creates the test data loader
    for ``cfg.dataset_name`` and passes it to the metric built from
    ``cfg.GAN_metric``.
    """
    from template_lib.d2.data import build_dataset_mapper
    from template_lib.d2template.trainer.base_trainer import build_detection_test_loader
    from template_lib.v2.GAN.evaluation import build_GAN_metric
    from template_lib.d2.utils.d2_utils import D2Utils
    from template_lib.v2.config_cfgnode import global_cfg

    from detectron2.utils import logger
    logger.setup_logger('d2')

    # Merge in both directions so cfg and global_cfg end up with the same keys.
    cfg = D2Utils.create_cfg()
    cfg.update(global_cfg)
    global_cfg.merge_from_dict(cfg)

    # Read every option up front so a missing key fails before any work starts.
    # fmt: off
    name              = cfg.dataset_name
    total_batch       = cfg.IMS_PER_BATCH
    size              = cfg.img_size
    mapper_cfg        = cfg.dataset_mapper_cfg
    metric_cfg        = cfg.GAN_metric
    # fmt: on

    # The configured batch size is global; each worker gets an integer share.
    per_worker_batch = total_batch // comm.get_world_size()

    mapper = build_dataset_mapper(mapper_cfg, img_size=size)
    loader = build_detection_test_loader(
        cfg,
        dataset_name=name,
        batch_size=per_worker_batch,
        mapper=mapper,
    )

    metric = build_GAN_metric(metric_cfg)
    metric.calculate_fid_stat_of_dataloader(data_loader=loader)

    comm.synchronize()
def __init__(self, cfg, **kwargs):
    """Configure the TensorFlow-based FID/IS metric.

    Loads the reference FID statistics when the file exists, resolves the
    Inception model path, builds the Inception graph and network, then
    synchronizes with the other workers.

    :param cfg: config node. ``tf_inception_model_dir`` and ``tf_fid_stat``
        are required; ``num_inception_images`` (default 50000) and
        ``IS_splits`` (default 10) are optional.
    :param kwargs: accepted but not used by this constructor.
    """
    # fmt: off
    self.tf_inception_model_dir       = cfg.tf_inception_model_dir
    self.tf_fid_stat                  = cfg.tf_fid_stat
    self.num_inception_images         = getattr(cfg, 'num_inception_images', 50000)
    self.IS_splits                    = getattr(cfg, 'IS_splits', 10)
    # fmt: on

    self.logger = logging.getLogger('tl')

    # Each worker only processes its integer share of the image budget.
    ws = comm.get_world_size()
    self.num_inception_images = self.num_inception_images // ws
    self.tf_graph_name = 'FID_IS_Inception_Net'

    if os.path.isfile(self.tf_fid_stat):
        self.logger.info(f'Loading tf_fid_stat: {self.tf_fid_stat}')
        # The context manager closes the archive even if a key is missing.
        with np.load(self.tf_fid_stat) as f:
            self.mu_data, self.sigma_data = f['mu'][:], f['sigma'][:]
    else:
        # NOTE: mu_data / sigma_data stay undefined on this path.
        self.logger.warning(f"tf_fid_stat does not exist: {self.tf_fid_stat}")

    self.tf_inception_model_dir = os.path.expanduser(self.tf_inception_model_dir)
    inception_path = self._check_or_download_inception(self.tf_inception_model_dir)
    self.logger.info('Load tf inception model in %s', inception_path)
    self._create_inception_graph(inception_path, name=self.tf_graph_name)
    self._create_inception_net()
    comm.synchronize()
def get_sample_imgs_list_ddp(sample_func, num_imgs=50000, as_numpy=True, stdout=sys.stdout):
    """Sample images on every worker and merge the gathered results.

    :param sample_func: callable forwarded to ``get_sample_imgs_list``
    :param num_imgs: total image budget; each worker samples
        ``num_imgs // world_size`` images
    :param as_numpy: forwarded to ``get_sample_imgs_list``; also selects
        ``np.concatenate`` (True) or ``torch.cat`` (False) for merging
    :param stdout: stream forwarded to ``get_sample_imgs_list``
    :return: the gathered images concatenated along axis 0, or this worker's
        own images when ``comm.gather`` returns an empty result
    """
    import torch
    from template_lib.d2.utils import comm

    per_worker = num_imgs // comm.get_world_size()
    local_imgs = get_sample_imgs_list(
        sample_func=sample_func, num_imgs=per_worker, as_numpy=as_numpy, stdout=stdout)

    gathered = comm.gather(local_imgs)
    # Nothing gathered on this worker: fall back to the locally sampled images.
    if len(gathered) == 0:
        return local_imgs

    if as_numpy:
        return np.concatenate(gathered, axis=0)
    return torch.cat(gathered, dim=0)
def __init__(self, cfg):
    """Configure the torch-based FID/IS metric from ``cfg.GAN_metric``.

    Loads the reference FID statistics when the statistics file exists,
    loads the Inception network and splits the image budget across workers.

    :param cfg: config node with a ``GAN_metric`` sub-node.
        ``GAN_metric.torch_fid_stat`` is required; ``num_inception_images``
        (default 50000), ``IS_splits`` (default 10),
        ``calculate_FID_use_torch`` (default False) and ``no_FID``
        (default False) are optional.
    """
    metric_cfg = cfg.GAN_metric

    self.torch_fid_stat = metric_cfg.torch_fid_stat
    self.num_inception_images = getattr(metric_cfg, 'num_inception_images', 50000)
    self.IS_splits = getattr(metric_cfg, 'IS_splits', 10)
    self.calculate_FID_use_torch = getattr(metric_cfg, 'calculate_FID_use_torch', False)
    self.no_FID = getattr(metric_cfg, 'no_FID', False)

    self.logger = logging.getLogger('tl')

    if os.path.isfile(self.torch_fid_stat):
        self.logger.info(f"Loading torch_fid_stat : {self.torch_fid_stat}")
        # Open the archive once and close it deterministically; the arrays
        # are materialised in memory when indexed, so they outlive the file.
        with np.load(self.torch_fid_stat) as fid_stat:
            self.data_mu = fid_stat['mu']
            self.data_sigma = fid_stat['sigma']
    else:
        # NOTE: data_mu / data_sigma stay undefined on this path.
        self.logger.warning(
            f"torch_fid_stat does not exist: {self.torch_fid_stat}")

    # Load inception_v3 network
    self.inception_net = self._load_inception_net()

    # Each worker only processes its integer share of the image budget.
    ws = comm.get_world_size()
    self.num_inception_images = self.num_inception_images // ws
def test_case_calculate_fid_stat_CIFAR10():
    """Compute FID statistics for the CIFAR-10 training set.

    Parses an inline YAML config, merges it into a fresh config node, builds
    the test data loader for ``cfg.dataset_name`` and passes it to the metric
    built from ``cfg.GAN_metric``.
    """
    from template_lib.d2.data import build_dataset_mapper
    from template_lib.d2template.trainer.base_trainer import build_detection_test_loader
    from template_lib.gans.evaluation import build_GAN_metric
    from template_lib.utils.detection2_utils import D2Utils
    # Not referenced below; presumably imported for a registration side
    # effect -- confirm before removing.
    from template_lib.d2.data import build_cifar10

    # One key per line: YAML block mappings cannot share a line, and the code
    # below reads ``dataset_mapper_cfg`` and ``GAN_metric`` as nested nodes.
    cfg_str = """
        update_cfg: true
        dataset_name: "cifar10_train"
        IMS_PER_BATCH: 32
        img_size: 32
        NUM_WORKERS: 0
        dataset_mapper_cfg:
          name: CIFAR10DatasetMapper
        GAN_metric:
          tf_fid_stat: "datasets/fid_stats_tf_cifar10.npz"
    """
    config = EasyDict(yaml.safe_load(cfg_str))
    config = TFFIDISScore.update_cfg(config)

    cfg = D2Utils.create_cfg()
    cfg = D2Utils.cfg_merge_from_easydict(cfg, config)

    # fmt: off
    dataset_name                 = cfg.dataset_name
    IMS_PER_BATCH                = cfg.IMS_PER_BATCH
    img_size                     = cfg.img_size
    NUM_WORKERS                  = cfg.NUM_WORKERS
    dataset_mapper_cfg           = cfg.dataset_mapper_cfg
    GAN_metric                   = cfg.GAN_metric
    # fmt: on

    # The node is unfrozen only for these two in-place updates.
    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = NUM_WORKERS
    # Fills any {dataset_name} / {img_size} placeholders in the path; the
    # path configured above has none, so it is left unchanged.
    cfg.GAN_metric.tf_fid_stat = cfg.GAN_metric.tf_fid_stat.format(
        dataset_name=dataset_name, img_size=img_size)
    cfg.freeze()

    # The configured batch size is global; each worker gets an integer share.
    num_workers = comm.get_world_size()
    batch_size = IMS_PER_BATCH // num_workers

    dataset_mapper = build_dataset_mapper(dataset_mapper_cfg, img_size=img_size)
    data_loader = build_detection_test_loader(
        cfg, dataset_name=dataset_name, batch_size=batch_size, mapper=dataset_mapper)

    FID_IS_tf = build_GAN_metric(GAN_metric)
    FID_IS_tf.calculate_fid_stat_of_dataloader(data_loader=data_loader)

    comm.synchronize()
def __call__(self, sample_func, return_fid_stat=False, num_inception_images=None,
             return_fid_logit=False, stdout=sys.stdout):
    """Evaluate FID and Inception Score for images drawn from ``sample_func``.

    :param sample_func: zero-argument callable returning an image tensor; its
        output is rescaled in place with ``*127.5 + 127.5``, clamped to
        [0, 255] and permuted with ``(0, 2, 3, 1)`` (presumably NCHW images
        in [-1, 1] -- confirm against the generator)
    :param return_fid_stat: when True, return the FID statistics of the
        samples instead of the scores
    :param num_inception_images: total image budget, divided by the world
        size; ``None`` uses ``self.num_inception_images``
    :param return_fid_logit: with ``return_fid_stat``, also return the raw
        FID activations
    :param stdout: stream forwarded to the activation extraction
    :return: ``(mu, sigma)`` or ``(mu, sigma, pred_FIDs)`` when
        ``return_fid_stat`` is set, otherwise
        ``(FID_tf, IS_mean_tf, IS_std_tf)``; non-main processes get zeros
        for the computed values
    """
    import torch

    draw_batch = sample_func

    def sample_uint8(*args, **kwargs):
        """
        :return: images: [0, 255]
        """
        batch = draw_batch()
        batch = batch.mul_(127.5).add_(127.5).clamp_(0.0, 255.0)
        batch = batch.permute(0, 2, 3, 1).type(torch.uint8)
        return batch.cpu().numpy()

    if num_inception_images is not None:
        per_worker_images = num_inception_images // comm.get_world_size()
    else:
        per_worker_images = self.num_inception_images

    pred_FIDs, pred_ISs = self._get_activations_with_sample_func(
        sample_func=sample_uint8, num_inception_images=per_worker_images, stdout=stdout)

    # Only the main process logs and computes; the others return placeholders.
    on_main = comm.is_main_process()
    if on_main:
        self.logger.info(f"Num of images: {len(pred_FIDs)}")

    if return_fid_stat:
        if on_main:
            mu, sigma = self._calculate_fid_stat(pred_FIDs=pred_FIDs)
        else:
            mu, sigma = 0, 0
        if return_fid_logit:
            return mu, sigma, pred_FIDs
        return mu, sigma

    if on_main:
        IS_mean_tf, IS_std_tf = self._calculate_IS(pred_ISs=pred_ISs, IS_splits=self.IS_splits)

        # calculate FID stat
        sample_mu = np.mean(pred_FIDs, axis=0)
        sample_sigma = np.cov(pred_FIDs, rowvar=False)
        FID_tf = calculate_frechet_distance(sample_mu, sample_sigma, self.mu_data, self.sigma_data)
    else:
        FID_tf = IS_mean_tf = IS_std_tf = 0

    # Drop the activation arrays before waiting on the other workers.
    del pred_FIDs, pred_ISs
    comm.synchronize()
    return FID_tf, IS_mean_tf, IS_std_tf
def load_inception_net(parallel, device='cuda'):
    """Load the Inception network and optionally wrap it for distributed use.

    :param parallel: when True and more than one worker is running, wrap the
        network in ``DistributedDataParallel``
    :param device: device the network is moved to before any wrapping
    :return: the network on ``device``, wrapped in ``DistributedDataParallel``
        only when ``parallel`` is set and the world size is greater than 1
    """
    # NOTE(review): this call is expected to resolve to a module-level loader
    # with the same name (this def being a static method); if this def is
    # itself module-level the call recurses -- confirm.
    net = load_inception_net(parallel=False)
    net = net.to(device)

    # With a single worker there is nothing to wrap: return the plain network
    # rather than None so the caller always receives a usable model.
    if parallel and comm.get_world_size() > 1:
        pg = torch.distributed.new_group(range(torch.distributed.get_world_size()))
        net = DistributedDataParallel(
            net,
            device_ids=[dist.get_rank()],
            broadcast_buffers=False,
            process_group=pg,
            check_reduction=False
        )
    return net
def __init__(self, saved_inception_moments, parallel=False, device='cuda'):
    """Load the reference Inception moments and the Inception network.

    :param saved_inception_moments: path to an ``.npz`` archive holding the
        ``mu`` and ``sigma`` arrays; ``~`` is expanded
    :param parallel: request a distributed network; forced to False when the
        world size is not greater than 1
    :param device: device forwarded to ``load_inception_net``
    """
    self.logger = logging.getLogger('tl')

    saved_inception_moments = os.path.expanduser(saved_inception_moments)
    # Open the archive once and close it deterministically; the arrays are
    # materialised in memory when indexed, so they outlive the file.
    with np.load(saved_inception_moments) as moments:
        self.data_mu = moments['mu']
        self.data_sigma = moments['sigma']

    # Load network
    self.parallel = parallel and comm.get_world_size() > 1
    # Forward the requested device instead of silently using the loader's
    # default (assumes load_inception_net accepts ``device`` -- confirm).
    self.net = self.load_inception_net(parallel=self.parallel, device=device)