def test_case_calculate_fid_stat_CIFAR():
    from template_lib.d2.data import build_dataset_mapper
    from template_lib.d2template.trainer.base_trainer import build_detection_test_loader
    from template_lib.v2.GAN.evaluation import build_GAN_metric
    from template_lib.d2.utils.d2_utils import D2Utils
    from template_lib.v2.config_cfgnode import global_cfg
    from detectron2.utils import logger
    logger.setup_logger('d2')

    cfg = D2Utils.create_cfg()
    cfg.update(global_cfg)
    global_cfg.merge_from_dict(cfg)

    # fmt: off
    dataset_name       = cfg.dataset_name
    IMS_PER_BATCH      = cfg.IMS_PER_BATCH
    img_size           = cfg.img_size
    dataset_mapper_cfg = cfg.dataset_mapper_cfg
    GAN_metric         = cfg.GAN_metric
    # fmt: on

    num_workers = comm.get_world_size()
    batch_size = IMS_PER_BATCH // num_workers

    dataset_mapper = build_dataset_mapper(dataset_mapper_cfg, img_size=img_size)
    data_loader = build_detection_test_loader(
        cfg, dataset_name=dataset_name, batch_size=batch_size, mapper=dataset_mapper)

    FID_IS_tf = build_GAN_metric(GAN_metric)
    FID_IS_tf.calculate_fid_stat_of_dataloader(data_loader=data_loader)

    comm.synchronize()

def __init__(self, cfg, **kwargs):
    # fmt: off
    self.tf_inception_model_dir = cfg.tf_inception_model_dir
    self.tf_fid_stat            = cfg.tf_fid_stat
    self.num_inception_images   = getattr(cfg, 'num_inception_images', 50000)
    self.IS_splits              = getattr(cfg, 'IS_splits', 10)
    # fmt: on

    self.logger = logging.getLogger('tl')
    ws = comm.get_world_size()
    # Each rank only produces its share of the total number of images.
    self.num_inception_images = self.num_inception_images // ws
    self.tf_graph_name = 'FID_IS_Inception_Net'

    if os.path.isfile(self.tf_fid_stat):
        self.logger.info(f'Loading tf_fid_stat: {self.tf_fid_stat}')
        f = np.load(self.tf_fid_stat)
        self.mu_data, self.sigma_data = f['mu'][:], f['sigma'][:]
        f.close()
    else:
        self.logger.warning(f"tf_fid_stat does not exist: {self.tf_fid_stat}")

    self.tf_inception_model_dir = os.path.expanduser(self.tf_inception_model_dir)
    inception_path = self._check_or_download_inception(self.tf_inception_model_dir)
    self.logger.info('Load tf inception model in %s', inception_path)
    self._create_inception_graph(inception_path, name=self.tf_graph_name)
    self._create_inception_net()
    comm.synchronize()

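# Reference sketch (not the project's implementation): the Frechet distance
# used for FID (Heusel et al., 2017), applied to (mu, sigma) pairs like the
# ones loaded above. `calculate_frechet_distance` itself is imported
# elsewhere in this module; this only spells out the formula
#     FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2}).
import numpy as np
from scipy import linalg


def frechet_distance_sketch(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # Matrix square root of the covariance product.
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # Regularize when the product is nearly singular.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        # Drop tiny imaginary parts caused by numerical error.
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
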
def calculate_fid_stat_of_dataloader(self, data_loader, sample_func=None,
                                     return_fid_stat=False,
                                     num_images=float('inf'),
                                     save_fid_stat=True):
    if sample_func is None:
        class SampleClass(object):
            def __init__(self, data_loader):
                self.data_iter = iter(data_loader)

            def __call__(self, *args, **kwargs):
                """
                :return: images: [-1, 1]
                """
                inputs = next(self.data_iter)
                # images = [x["image"].to('cuda') for x in inputs]
                # images = torch.stack(images)
                images, labels = inputs
                images = images.to('cuda')
                return images

        sample_func = SampleClass(data_loader)

    data, label = next(iter(data_loader))
    num_inception_images = len(data) * len(data_loader)
    num_inception_images = min(num_images, num_inception_images)

    pool, logits = self._accumulate_inception_activations(
        sample_func, net=self.inception_net,
        num_inception_images=num_inception_images, as_numpy=True)
    pool = self._gather_data(pool[:num_inception_images], is_numpy=True)
    logits = self._gather_data(logits[:num_inception_images], is_numpy=True)

    if return_fid_stat:
        if comm.is_main_process():
            self.logger.info(f"Num of images: {len(pool)}")
            mu, sigma = self._get_FID_stat(pool=pool)
        else:
            mu, sigma = 0, 0
        return mu, sigma

    if comm.is_main_process():
        self.logger.info(f"Num of images: {len(pool)}")
        IS_mean, IS_std = calculate_inception_score(logits, self.IS_splits)
        self.logger.info(f'dataset IS_mean: {IS_mean:.3f} +- {IS_std:.3f}')
        if save_fid_stat:
            mu, sigma = self._get_FID_stat(pool=pool)
            self.logger.info(f'Saving torch_fid_stat to {self.torch_fid_stat}')
            os.makedirs(os.path.dirname(self.torch_fid_stat), exist_ok=True)
            np.savez(self.torch_fid_stat, **{'mu': mu, 'sigma': sigma})
    comm.synchronize()

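# Reference sketch (not the project's implementation): the split-KL Inception
# Score computed by `calculate_inception_score` above. Given softmax
# predictions p(y|x), it computes exp(E_x[KL(p(y|x) || p(y))]) over
# `num_splits` chunks and returns the mean and std across chunks.
import numpy as np


def inception_score_sketch(logits, num_splits=10):
    """logits: [N, num_classes] softmax probabilities."""
    scores = []
    chunk = logits.shape[0] // num_splits
    for i in range(num_splits):
        part = logits[i * chunk:(i + 1) * chunk]
        # KL(p(y|x) || p(y)) per sample, summed over classes.
        kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
        scores.append(np.exp(np.mean(np.sum(kl, axis=1))))
    return np.mean(scores), np.std(scores)
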
def __call__(self, sample_func, return_fid_stat=False,
             num_inception_images=None, return_fid_logit=False,
             stdout=sys.stdout):
    import torch

    class SampleClass(object):
        def __init__(self, sample_func):
            self.sample_func = sample_func

        def __call__(self, *args, **kwargs):
            """
            :return: images: [0, 255]
            """
            images = self.sample_func()
            # Map [-1, 1] images to uint8 NHWC for the TF inception graph.
            images = images.mul_(127.5).add_(127.5).clamp_(0.0, 255.0) \
                .permute(0, 2, 3, 1).type(torch.uint8)
            images = images.cpu().numpy()
            return images

    sample_func = SampleClass(sample_func)

    if num_inception_images is None:
        num_inception_images = self.num_inception_images
    pred_FIDs, pred_ISs = self._get_activations_with_sample_func(
        sample_func=sample_func, num_inception_images=num_inception_images,
        stdout=stdout)

    if return_fid_stat:
        if comm.is_main_process():
            self.logger.info(f"Num of images: {len(pred_FIDs)}")
            mu, sigma = self._calculate_fid_stat(pred_FIDs=pred_FIDs)
        else:
            mu, sigma = 0, 0
        if return_fid_logit:
            return mu, sigma, pred_FIDs
        else:
            return mu, sigma

    if comm.is_main_process():
        self.logger.info(f"Num of images: {len(pred_FIDs)}")
        IS_mean_tf, IS_std_tf = self._calculate_IS(
            pred_ISs=pred_ISs, IS_splits=self.IS_splits)
        # Calculate FID against the precomputed dataset statistics.
        mu = np.mean(pred_FIDs, axis=0)
        sigma = np.cov(pred_FIDs, rowvar=False)
        FID_tf = calculate_frechet_distance(mu, sigma, self.mu_data, self.sigma_data)
    else:
        FID_tf = IS_mean_tf = IS_std_tf = 0

    del pred_FIDs, pred_ISs
    comm.synchronize()
    return FID_tf, IS_mean_tf, IS_std_tf

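# A hedged usage sketch of the __call__ above. `G`, `z_dim`, and `batch_size`
# are hypothetical stand-ins for a trained generator and its sampling
# parameters; `cfg` is assumed to carry the fields read in __init__. The
# sample_func must return images in [-1, 1]; the wrapper above rescales them
# to uint8 NHWC for the inception graph.
def demo_eval_tf_fid_is(cfg, G, z_dim=128, batch_size=64):
    import torch
    FID_IS_tf = TFFIDISScore(cfg)

    def sample_func():
        z = torch.randn(batch_size, z_dim, device='cuda')
        with torch.no_grad():
            return G(z)

    FID_tf, IS_mean_tf, IS_std_tf = FID_IS_tf(sample_func=sample_func)
    return FID_tf, IS_mean_tf, IS_std_tf
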
def test_case_calculate_fid_stat_CIFAR10():
    from template_lib.d2.data import build_dataset_mapper
    from template_lib.d2template.trainer.base_trainer import build_detection_test_loader
    from template_lib.gans.evaluation import build_GAN_metric
    from template_lib.utils.detection2_utils import D2Utils
    # Imported for its side effects (dataset registration).
    from template_lib.d2.data import build_cifar10

    cfg_str = """
        update_cfg: true
        dataset_name: "cifar10_train"
        IMS_PER_BATCH: 32
        img_size: 32
        NUM_WORKERS: 0
        dataset_mapper_cfg:
          name: CIFAR10DatasetMapper
        GAN_metric:
          tf_fid_stat: "datasets/fid_stats_tf_cifar10.npz"
        """
    config = EasyDict(yaml.safe_load(cfg_str))
    config = TFFIDISScore.update_cfg(config)

    cfg = D2Utils.create_cfg()
    cfg = D2Utils.cfg_merge_from_easydict(cfg, config)

    # fmt: off
    dataset_name       = cfg.dataset_name
    IMS_PER_BATCH      = cfg.IMS_PER_BATCH
    img_size           = cfg.img_size
    NUM_WORKERS        = cfg.NUM_WORKERS
    dataset_mapper_cfg = cfg.dataset_mapper_cfg
    GAN_metric         = cfg.GAN_metric
    # fmt: on

    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = NUM_WORKERS
    cfg.GAN_metric.tf_fid_stat = cfg.GAN_metric.tf_fid_stat.format(
        dataset_name=dataset_name, img_size=img_size)
    cfg.freeze()

    num_workers = comm.get_world_size()
    batch_size = IMS_PER_BATCH // num_workers

    dataset_mapper = build_dataset_mapper(dataset_mapper_cfg, img_size=img_size)
    data_loader = build_detection_test_loader(
        cfg, dataset_name=dataset_name, batch_size=batch_size, mapper=dataset_mapper)

    FID_IS_tf = build_GAN_metric(GAN_metric)
    FID_IS_tf.calculate_fid_stat_of_dataloader(data_loader=data_loader)

    comm.synchronize()

def calculate_fid_stat_of_dataloader(self, data_loader, sample_func=None,
                                     return_fid_stat=False,
                                     num_images=float('inf'),
                                     stdout=sys.stdout):
    import torch

    if sample_func is None:
        class SampleClass(object):
            def __init__(self, data_loader):
                self.data_iter = iter(data_loader)

            def __call__(self, *args, **kwargs):
                inputs = next(self.data_iter)
                images = [x["image"].to('cuda') for x in inputs]
                images = torch.stack(images)
                # Map [-1, 1] images to uint8 NHWC for the TF inception graph.
                images = images.mul_(127.5).add_(127.5).clamp_(0.0, 255.0) \
                    .permute(0, 2, 3, 1).type(torch.uint8)
                images = images.cpu().numpy()
                return images

        sample_func = SampleClass(data_loader)

    num_inception_images = len(next(iter(data_loader))) * len(data_loader)
    num_inception_images = min(num_images, num_inception_images)
    pred_FIDs, pred_ISs = self._get_activations_with_sample_func(
        sample_func=sample_func, num_inception_images=num_inception_images,
        stdout=stdout)

    if return_fid_stat:
        if comm.is_main_process():
            self.logger.info(f"Num of images: {len(pred_FIDs)}")
            mu, sigma = self._calculate_fid_stat(pred_FIDs=pred_FIDs)
        else:
            mu, sigma = 0, 0
        return mu, sigma

    if comm.is_main_process():
        self.logger.info(f"Num of images: {len(pred_FIDs)}")
        IS_mean, IS_std = self._calculate_IS(pred_ISs=pred_ISs, IS_splits=self.IS_splits)
        self.logger.info(f'dataset IS_mean: {IS_mean:.3f} +- {IS_std:.3f}')
        # Compute and save the FID statistics of the dataset.
        mu, sigma = self._calculate_fid_stat(pred_FIDs=pred_FIDs)
        self.logger.info(f'Saving tf_fid_stat to {self.tf_fid_stat}')
        os.makedirs(os.path.dirname(self.tf_fid_stat), exist_ok=True)
        np.savez(self.tf_fid_stat, **{'mu': mu, 'sigma': sigma})
    comm.synchronize()

def _get_activations_with_sample_func(self, sample_func, num_inception_images,
                                      stdout=sys.stdout, verbose=True):
    pred_FIDs = []
    pred_ISs = []
    count = 0

    while count < num_inception_images:
        if verbose and comm.is_main_process():
            print('\r', end=f'TF FID IS Score forwarding: [{count}/{num_inception_images}]',
                  file=stdout, flush=True)
        try:
            batch = sample_func()
            # batch_list = comm.gather(data=batch)
            # if len(batch_list) > 0:
            #     batch = np.concatenate(batch_list, axis=0)
        except StopIteration:
            break
        try:
            pred_FID, pred_IS = self.sess.run(
                [self.FID_pool3, self.IS_softmax],
                {f'{self.tf_graph_name}/ExpandDims:0': batch})
        except KeyboardInterrupt:
            sys.exit(-1)
        except Exception:
            print(traceback.format_exc())
            continue
        count += len(batch)
        pred_FIDs.append(pred_FID)
        pred_ISs.append(pred_IS)

    if verbose:
        print(f'rank: {comm.get_rank()}', file=stdout)

    pred_FIDs = np.concatenate(pred_FIDs, 0).squeeze()
    pred_ISs = np.concatenate(pred_ISs, 0)

    # Gather activations from all ranks and truncate to the requested count.
    pred_FIDs = self._gather_numpy_array(pred_FIDs[:num_inception_images])
    pred_ISs = self._gather_numpy_array(pred_ISs[:num_inception_images])
    comm.synchronize()
    return pred_FIDs, pred_ISs

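# `_gather_numpy_array` is defined elsewhere in this module; a minimal sketch
# of what it is assumed to do: gather each rank's activations with
# detectron2's comm helpers and concatenate them on every process.
def _gather_numpy_array_sketch(array):
    from detectron2.utils import comm
    import numpy as np
    if comm.get_world_size() == 1:
        return array
    array_list = comm.all_gather(array)  # list of per-rank numpy arrays
    return np.concatenate(array_list, axis=0)
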
def __call__(self, sample_func, return_fid_stat=False,
             num_inception_images=None, stdout=sys.stdout):
    start_time = time.time()

    if num_inception_images is None:
        num_inception_images = self.num_inception_images
    pool, logits = self._accumulate_inception_activations(
        sample_func, net=self.inception_net,
        num_inception_images=num_inception_images, as_numpy=True, stdout=stdout)
    pool = self._gather_data(pool[:num_inception_images], is_numpy=True)
    logits = self._gather_data(logits[:num_inception_images], is_numpy=True)

    if return_fid_stat:
        if comm.is_main_process():
            self.logger.info(f"Num of images: {len(pool)}")
            mu, sigma = self._get_FID_stat(pool=pool)
        else:
            mu, sigma = 0, 0
        return mu, sigma

    if comm.is_main_process():
        self.logger.info(f"Num of images: {len(pool)}")
        IS_mean_torch, IS_std_torch = calculate_inception_score(
            logits, num_splits=self.IS_splits)
        FID_torch = self._calculate_FID(
            pool=pool, no_fid=self.no_FID, use_torch=self.calculate_FID_use_torch)
    else:
        IS_mean_torch = IS_std_torch = FID_torch = 0

    elapsed_time = time.time() - start_time
    time_str = time.strftime('%H:%M:%S', time.gmtime(elapsed_time))
    self.logger.info('Elapsed time: %s', time_str)

    del pool, logits
    comm.synchronize()
    return FID_torch, IS_mean_torch, IS_std_torch

def __call__(self, images, labels, z, iteration, **kwargs):
    """
    :param images:
    :param labels:
    :param z: z.sample()
    :param iteration:
    """
    if self.dummy:
        return
    summary_d = collections.defaultdict(dict)

    real = images
    dy = labels
    gy = dy

    self.G.train()
    self.D.train()

    ############################
    # (1) Update D network
    ############################
    self.D.zero_grad()
    d_real = self.D(real, dy)
    d_real_mean = d_real.mean()
    summary_d['d_logit_mean']['d_real_mean'] = d_real_mean.item()

    z_sample = z.sample().to(self.device)
    fake = self.G(z_sample, y=gy, **kwargs)
    d_fake = self.D(fake.detach(), gy, **kwargs)
    d_fake_mean = d_fake.mean()
    summary_d['d_logit_mean']['d_fake_mean'] = d_fake_mean.item()

    # Gradient penalty on interpolates between real and fake samples.
    gp = self.wgan_gp_gradient_penalty_cond(
        x=real, G_z=fake, gy=gy, f=self.D, backward=False,
        gp_lambda=self.gp_lambda)
    summary_d['gp']['gp'] = gp.item()

    # Wasserstein distance estimate.
    wd = d_real_mean - d_fake_mean
    summary_d['wd']['wd'] = wd.item()

    d_loss = -wd + gp
    d_loss.backward()
    self.D_optim.step()
    summary_d['d_loss']['d_loss'] = d_loss.item()

    ############################
    # (2) Update G network
    ############################
    if iteration % self.n_critic == 0:
        self.G.zero_grad()
        z_sample = z.sample().to(self.device)
        gy = dy
        fake = self.G(z_sample, y=gy, **kwargs)
        d_fake_g = self.D(fake, gy, **kwargs)
        d_fake_g_mean = d_fake_g.mean()
        g_loss = -d_fake_g_mean
        g_loss.backward()
        summary_d['d_logit_mean']['d_fake_g_mean'] = d_fake_g_mean.item()
        summary_d['g_loss']['g_loss'] = g_loss.item()

        if self.child_grad_bound > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.G.parameters(), self.child_grad_bound)
            # float() works whether clip_grad_norm_ returns a float or a 0-dim tensor.
            summary_d['grad_norm']['grad_norm'] = float(grad_norm)
        self.G_optim.step()

    if iteration % self.log_every == 0:
        Trainer.summary_defaultdict2txtfig(
            default_dict=summary_d, prefix='WGANGPCond', step=iteration,
            textlogger=self.myargs.textlogger)
    comm.synchronize()

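# `wgan_gp_gradient_penalty_cond` is defined on the trainer; a minimal sketch
# of the standard conditional WGAN-GP penalty (Gulrajani et al., 2017) it is
# assumed to implement: penalize deviations of the critic's gradient norm
# from 1 at random interpolates between real and fake samples. Assumes 4D
# NCHW image batches.
def wgan_gp_gradient_penalty_sketch(x, G_z, gy, f, gp_lambda=10.0):
    import torch
    # One interpolation coefficient per sample, broadcast over C, H, W.
    alpha = torch.rand(x.size(0), 1, 1, 1, device=x.device)
    interp = (alpha * x + (1 - alpha) * G_z.detach()).requires_grad_(True)
    d_interp = f(interp, gy)
    grad = torch.autograd.grad(
        outputs=d_interp.sum(), inputs=interp,
        create_graph=True, retain_graph=True)[0]
    grad_norm = grad.view(grad.size(0), -1).norm(2, dim=1)
    return gp_lambda * ((grad_norm - 1) ** 2).mean()
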
def train_controller(self, searched_cnn, valid_dataset_iter,
                     preprocess_image_func, controller, controller_optim,
                     iteration, pbar):
    """
    :param controller: for ddp training
    """
    if comm.is_main_process() and iteration % 1000 == 0:
        pbar.set_postfix_str("ClsControllerRLAlphaFair")

    meter_dict = {}

    controller.train()
    controller.zero_grad()

    sampled_arcs = controller()
    sample_entropy = get_ddp_attr(controller, 'sample_entropy')
    sample_log_prob = get_ddp_attr(controller, 'sample_log_prob')

    val_data = next(valid_dataset_iter)
    bs = len(val_data)
    batched_arcs = sampled_arcs.repeat(bs, 1)

    # Reward: top-1 accuracy of the sampled architecture on validation batches.
    top1 = AverageMeter()
    for i in range(self.num_aggregate):
        val_data = next(valid_dataset_iter)
        val_X, val_y = preprocess_image_func(val_data, device=self.device)
        val_X = val_X.tensor
        with torch.set_grad_enabled(False):
            logits = searched_cnn(val_X, batched_arcs=batched_arcs)
            prec1, = top_accuracy(output=logits, target=val_y, topk=(1,))
            top1.update(prec1.item(), bs)

    reward_g = top1.avg
    meter_dict['reward_g'] = reward_g

    # Detach to make sure that gradients aren't backpropped through the reward.
    reward = torch.tensor(reward_g).cuda()
    sample_entropy_mean = sample_entropy.mean()
    meter_dict['sample_entropy'] = sample_entropy_mean.item()
    reward += self.entropy_weight * sample_entropy_mean

    if self.baseline is None:
        baseline = torch.tensor(reward_g)
    else:
        # Exponential moving average of the reward.
        baseline = self.baseline - (1 - self.bl_dec) * (self.baseline - reward)
        # Detach to make sure that gradients are not backpropped through the baseline.
        baseline = baseline.detach()

    sample_log_prob_mean = sample_log_prob.mean()
    meter_dict['sample_log_prob'] = sample_log_prob_mean.item()
    # REINFORCE loss with a moving-average baseline.
    loss = -1 * sample_log_prob_mean * (reward - baseline)

    meter_dict['reward'] = reward.item()
    meter_dict['baseline'] = baseline.item()
    meter_dict['loss'] = loss.item()

    loss.backward(retain_graph=False)
    grad_norm = torch.nn.utils.clip_grad_norm_(
        controller.parameters(), self.child_grad_bound)
    # float() works whether clip_grad_norm_ returns a float or a 0-dim tensor.
    meter_dict['grad_norm'] = float(grad_norm)
    controller_optim.step()

    # Average the baseline across ranks.
    baseline_list = comm.all_gather(baseline)
    baseline_mean = sum(map(lambda v: v.item(), baseline_list)) / len(baseline_list)
    baseline.fill_(baseline_mean)
    self.baseline = baseline

    if iteration % self.log_every_iter == 0:
        self.print_distribution(iteration=iteration, log_prob=False, print_interval=10)
        default_dicts = collections.defaultdict(dict)
        for meter_k, meter in meter_dict.items():
            if meter_k in ['reward', 'baseline']:
                default_dicts['reward_baseline'][meter_k] = meter
            else:
                default_dicts[meter_k][meter_k] = meter
        summary_defaultdict2txtfig(default_dict=default_dicts,
                                   prefix='train_controller', step=iteration,
                                   textlogger=self.myargs.textlogger)
    comm.synchronize()

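# `get_ddp_attr` is imported elsewhere; it is assumed to read an attribute
# from a module that may be wrapped in DistributedDataParallel, roughly:
def get_ddp_attr_sketch(module, name):
    import torch.nn as nn
    if isinstance(module, nn.parallel.DistributedDataParallel):
        module = module.module  # unwrap the DDP container
    return getattr(module, name)
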
def __call__(self, images, labels, z, iteration, ema=None, **kwargs):
    """
    :param images:
    :param labels:
    :param z: z.sample()
    :param iteration:
    """
    if self.dummy:
        return
    summary_d = collections.defaultdict(dict)

    real = images
    dy = labels
    gy = dy

    self.G.train()
    self.D.train()

    ############################
    # (1) Update D network
    ############################
    self.D.zero_grad()
    d_real = self.D(real, y=dy, **kwargs)

    z_sample = z.sample().to(self.device)
    fake = self.G(z_sample, y=gy, **kwargs)
    d_fake = self.D(fake.detach(), y=gy, **kwargs)

    r_logit_mean, f_logit_mean, d_loss = self.hinge_loss_discriminator(
        r_logit=d_real, f_logit=d_fake)
    summary_d['d_logit_mean']['r_logit_mean'] = r_logit_mean.item()
    summary_d['d_logit_mean']['f_logit_mean'] = f_logit_mean.item()

    d_loss.backward()
    self.D_optim.step()
    summary_d['d_loss']['d_loss'] = d_loss.item()

    ############################
    # (2) Update G network
    ############################
    if iteration % self.n_critic == 0:
        self.G.zero_grad()
        z_sample = z.sample().to(self.device)
        gy = dy
        fake = self.G(z_sample, y=gy, **kwargs)
        d_fake_g = self.D(fake, y=gy, **kwargs)
        G_f_logit_mean, g_loss = self.hinge_loss_generator(f_logit=d_fake_g)
        summary_d['d_logit_mean']['G_f_logit_mean'] = G_f_logit_mean.item()
        summary_d['g_loss']['g_loss'] = g_loss.item()
        g_loss.backward()
        self.G_optim.step()

    if ema is not None:
        ema.update(iteration)

    if iteration % self.log_every == 0:
        Trainer.summary_defaultdict2txtfig(default_dict=summary_d,
                                           prefix='HingeLossCond', step=iteration,
                                           textlogger=global_textlogger)
    comm.synchronize()

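# `hinge_loss_discriminator` / `hinge_loss_generator` are defined on the
# trainer; a minimal sketch of the standard hinge GAN losses (Lim & Ye, 2017;
# Miyato et al., 2018) they are assumed to implement, matching the return
# signatures unpacked above:
def hinge_loss_discriminator_sketch(r_logit, f_logit):
    import torch.nn.functional as F
    r_logit_mean = r_logit.mean()
    f_logit_mean = f_logit.mean()
    # max(0, 1 - D(x)) for real samples, max(0, 1 + D(G(z))) for fake samples.
    d_loss = F.relu(1.0 - r_logit).mean() + F.relu(1.0 + f_logit).mean()
    return r_logit_mean, f_logit_mean, d_loss


def hinge_loss_generator_sketch(f_logit):
    f_logit_mean = f_logit.mean()
    g_loss = -f_logit_mean  # generator maximizes D(G(z))
    return f_logit_mean, g_loss
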
def train_controller(self, G, z, y, controller, controller_optim, iteration, pbar):
    """
    :param controller: for ddp training
    """
    if comm.is_main_process() and iteration % 1000 == 0:
        pbar.set_postfix_str("ControllerRLAlpha")

    meter_dict = {}

    G.eval()
    controller.train()
    controller.zero_grad()

    sampled_arcs = controller(iteration)
    sample_entropy = get_ddp_attr(controller, 'sample_entropy')
    sample_log_prob = get_ddp_attr(controller, 'sample_log_prob')

    # Reward: Inception Score of samples generated with the sampled architectures.
    pool_list, logits_list = [], []
    for i in range(self.num_aggregate):
        z_samples = z.sample().to(self.device)
        y_samples = y.sample().to(self.device)
        with torch.set_grad_enabled(False):
            batched_arcs = sampled_arcs[y_samples]
            x = G(z=z_samples, y=y_samples, batched_arcs=batched_arcs)
            pool, logits = self.FID_IS.get_pool_and_logits(x)
            # pool_list.append(pool)
            logits_list.append(logits)

    # pool = np.concatenate(pool_list, 0)
    logits = np.concatenate(logits_list, 0)
    reward_g, _ = self.FID_IS.calculate_IS(logits)
    meter_dict['reward_g'] = reward_g

    # Detach to make sure that gradients aren't backpropped through the reward.
    reward = torch.tensor(reward_g).cuda()
    sample_entropy_mean = sample_entropy.mean()
    meter_dict['sample_entropy'] = sample_entropy_mean.item()
    reward += self.entropy_weight * sample_entropy_mean

    if self.baseline is None:
        baseline = torch.tensor(reward_g)
    else:
        # Exponential moving average of the reward.
        baseline = self.baseline - (1 - self.bl_dec) * (self.baseline - reward)
        # Detach to make sure that gradients are not backpropped through the baseline.
        baseline = baseline.detach()

    sample_log_prob_mean = sample_log_prob.mean()
    meter_dict['sample_log_prob'] = sample_log_prob_mean.item()
    # REINFORCE loss with a moving-average baseline.
    loss = -1 * sample_log_prob_mean * (reward - baseline)

    meter_dict['reward'] = reward.item()
    meter_dict['baseline'] = baseline.item()
    meter_dict['loss'] = loss.item()

    loss.backward(retain_graph=False)
    grad_norm = torch.nn.utils.clip_grad_norm_(
        controller.parameters(), self.child_grad_bound)
    # float() works whether clip_grad_norm_ returns a float or a 0-dim tensor.
    meter_dict['grad_norm'] = float(grad_norm)
    controller_optim.step()

    # Average the baseline across ranks.
    baseline_list = comm.all_gather(baseline)
    baseline_mean = sum(map(lambda v: v.item(), baseline_list)) / len(baseline_list)
    baseline.fill_(baseline_mean)
    self.baseline = baseline

    if iteration % self.log_every_iter == 0:
        self.print_distribution(iteration=iteration, print_interval=10)
        if len(sampled_arcs) <= 10:
            self.logger.info('\nsampled arcs: \n%s' % sampled_arcs.cpu().numpy())
        self.myargs.textlogger.logstr(
            iteration,
            sampled_arcs='\n' + np.array2string(sampled_arcs.cpu().numpy(),
                                                threshold=np.inf))
        default_dicts = collections.defaultdict(dict)
        for meter_k, meter in meter_dict.items():
            if meter_k in ['reward', 'baseline']:
                default_dicts['reward_baseline'][meter_k] = meter
            else:
                default_dicts[meter_k][meter_k] = meter
        summary_defaultdict2txtfig(default_dict=default_dicts,
                                   prefix='train_controller', step=iteration,
                                   textlogger=self.myargs.textlogger)
    comm.synchronize()