def test_fid(self):
    """Sample from saved decoder checkpoints and log their FID on CelebA.

    The mode is chosen by two switches hard-coded in the body
    (``get_data_stats`` and ``manual``, both ``False`` as written, so only
    the last mode runs unless the source is edited):

    * ``get_data_stats``: save jittered images from the test split to
      ``<args.run>/datasets/celeba140_fid/raw_images`` and compute FID
      statistics for them.
    * ``manual``: restore ``checkpoint_100k.pth``, save 10,000 decoder
      samples to ``<args.log>/samples/raw_images`` and compute their
      statistics.
    * default: for every selected checkpoint ``checkpoint_<k>0k.pth``, save
      10,000 samples and their statistics into freshly recreated
      directories under ``<args.log>/samples`` and log the resulting FID.

    Side effects: sets ``self.config.input_dim``, reads checkpoints from
    ``self.args.log`` and writes (in the default mode: deletes and
    recreates) directories on disk.

    Raises:
        AssertionError: if ``self.config.data.dataset`` is not ``"CELEBA"``.
    """
    assert self.config.data.dataset == "CELEBA"

    config = self.config
    device = config.device

    def load_decoder(checkpoint_name):
        """Build a Decoder and restore it from ``<args.log>/<checkpoint_name>``."""
        states = torch.load(
            os.path.join(self.args.log, checkpoint_name),
            map_location=device,
        )
        decoder = Decoder(config).to(device)
        decoder.eval()

        # Only the decoder is used for sampling. The companion networks are
        # still built and restored so that a checkpoint that does not match
        # the configured algorithm fails exactly where it did before.
        algo = config.training.algo
        if algo == "vae":
            encoder = Encoder(config).to(device)
            encoder.load_state_dict(states[0])
            decoder.load_state_dict(states[1])
        elif algo in ["ssm", "ssm_fd"]:
            score = Score(config).to(device)
            imp_encoder = ImplicitEncoder(config).to(device)
            imp_encoder.load_state_dict(states[0])
            decoder.load_state_dict(states[1])
            score.load_state_dict(states[2])
        elif algo in ["spectral", "stein"]:
            # The estimators are not used here; the import is kept so a
            # missing module is still reported at this point.
            from models.kernel_score_estimators import (  # noqa: F401
                SpectralScoreEstimator,
                SteinScoreEstimator,
            )
            imp_encoder = ImplicitEncoder(config).to(device)
            imp_encoder.load_state_dict(states[0])
            decoder.load_state_dict(states[1])
        return decoder

    def sample_images(decoder):
        """Return 100 batches of 100 decoder samples as a flat list of images."""
        all_samples = []
        logging.info("Generating samples")
        with torch.no_grad():
            for _ in range(100):
                z = torch.randn(100, config.model.z_dim, device=device)
                samples, _ = decoder(z)
                samples = samples.view(
                    100,
                    config.data.channels,
                    config.data.image_size,
                    config.data.image_size,
                )
                # Same x / 2 + 0.5 rescaling that is applied to the data
                # images; presumably maps [-1, 1] to [0, 1] -- confirm
                # against the decoder's output range.
                all_samples.extend(samples / 2.0 + 0.5)
        return all_samples

    def write_images(images, directory):
        """Save every image as ``<directory>/<index>.png``."""
        for idx, image in enumerate(images):
            save_image(image, os.path.join(directory, "{}.png".format(idx)))

    def recreate_dir(path):
        """Make ``path`` an empty directory, discarding any previous content."""
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)

    transform = transforms.Compose([
        transforms.Resize(config.data.image_size),
        transforms.ToTensor(),
    ])

    if config.data.dataset == "CIFAR10":
        test_dataset = CIFAR10(
            os.path.join(self.args.run, "datasets", "cifar10"),
            train=False,
            download=True,
            transform=transform,
        )
    elif config.data.dataset == "MNIST":
        test_dataset = MNIST(
            os.path.join(self.args.run, "datasets", "mnist"),
            train=False,
            download=True,
            transform=transform,
        )
    elif config.data.dataset == "CELEBA":
        dataset = ImageFolder(
            # NOTE(review): machine-specific absolute path; it should most
            # likely be derived from self.args.run -- confirm before changing.
            root="/home/kunxu/tmp/",
            transform=transforms.Compose([
                transforms.CenterCrop(140),
                transforms.Resize(config.data.image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]),
        )
        num_items = len(dataset)
        indices = list(range(num_items))
        # Shuffle with a fixed seed so the split is reproducible, then put
        # the global NumPy RNG state back so callers are unaffected.
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        # The last 20% of the shuffled indices form the test split.
        test_indices = indices[int(0.8 * num_items):]
        test_dataset = Subset(dataset, test_indices)

    test_loader = DataLoader(
        test_dataset,
        batch_size=config.training.batch_size,
        shuffle=False,
        num_workers=2,
    )

    config.input_dim = config.data.image_size**2 * config.data.channels

    # Developer switches; edit by hand to select one of the other modes.
    get_data_stats = False
    manual = False

    # NOTE(review): the trailing (50, True, 2048) arguments of the fid calls
    # are presumably batch size, CUDA flag and feature dimension -- confirm
    # against the fid module.

    if get_data_stats:
        data_images = []
        for X, _ in test_loader:
            X = X.to(device)
            # Add uniform noise in [-1/256, 1/256) before rescaling.
            X = X + (torch.rand_like(X) - 0.5) / 128.0
            data_images.extend(X / 2.0 + 0.5)
            if len(data_images) > 10000:
                break

        fid_root = os.path.join(self.args.run, "datasets", "celeba140_fid")
        raw_dir = os.path.join(fid_root, "raw_images")
        os.makedirs(raw_dir, exist_ok=True)

        logging.info("Saving data images")
        write_images(data_images, raw_dir)
        logging.info("Images saved. Calculating fid statistics now")
        fid.calculate_data_statics(raw_dir, fid_root, 50, True, 2048)
        return

    samples_root = os.path.join(self.args.log, "samples")

    if manual:
        decoder = load_decoder("checkpoint_100k.pth")
        all_samples = sample_images(decoder)

        raw_dir = os.path.join(samples_root, "raw_images")
        os.makedirs(raw_dir, exist_ok=True)

        logging.info("Images generated. Saving images")
        write_images(all_samples, raw_dir)
        logging.info("Generating fid statistics")
        fid.calculate_data_statics(raw_dir, samples_root, 50, True, 2048)
        logging.info("Statistics generated.")
        return

    # ``ckpt`` counts checkpoints in units of 10k iterations.
    for ckpt in range(10, 11):
        decoder = load_decoder("checkpoint_{}0k.pth".format(ckpt))
        all_samples = sample_images(decoder)

        raw_dir = os.path.join(samples_root, "raw_images_{}0k".format(ckpt))
        stats_dir = os.path.join(samples_root, "statistics_{}0k".format(ckpt))
        recreate_dir(raw_dir)
        recreate_dir(stats_dir)

        logging.info("Images generated. Saving images")
        write_images(all_samples, raw_dir)
        logging.info("Generating fid statistics")
        fid.calculate_data_statics(raw_dir, stats_dir, 50, True, 2048)
        logging.info("Statistics generated.")

        # NOTE(review): the reference statistics path is relative to the
        # working directory, whereas the get_data_stats mode writes under
        # <args.run>/datasets/celeba140_fid -- confirm this is intended.
        fid_number = fid.calculate_fid_given_paths(
            [
                "run/datasets/celeba140_fid/celeba_test.npz",
                os.path.join(stats_dir, "celeba_test.npz"),
            ],
            50,
            True,
            2048,
        )
        logging.info("Number of iters: {}0k, FID: {}".format(
            ckpt, fid_number))
def test_fid(self):
    """Sample from saved decoder checkpoints and log their FID on CelebA.

    The mode is chosen by two switches hard-coded in the body
    (``get_data_stats`` and ``manual``, both ``False`` as written, so only
    the last mode runs unless the source is edited):

    * ``get_data_stats``: save jittered images from the test split to
      ``<args.run>/datasets/celeba140_fid/raw_images`` and compute FID
      statistics for them.
    * ``manual``: restore ``checkpoint.pth``, save 10,000 decoder samples to
      ``<args.log>/samples/raw_images`` and compute their statistics.
    * default: for each of the checkpoints ``checkpoint_10k.pth`` through
      ``checkpoint_100k.pth``, save 10,000 samples and their statistics into
      freshly recreated directories under ``<args.log>/samples`` and log the
      resulting FID.

    Side effects: sets ``self.config.input_dim``, reads checkpoints from
    ``self.args.log`` and writes (in the default mode: deletes and
    recreates) directories on disk.

    Raises:
        AssertionError: if ``self.config.data.dataset`` is not ``'CELEBA'``.
    """
    assert self.config.data.dataset == 'CELEBA'

    config = self.config
    device = config.device

    def load_decoder(checkpoint_name):
        """Build a Decoder and restore it from ``<args.log>/<checkpoint_name>``."""
        states = torch.load(os.path.join(self.args.log, checkpoint_name),
                            map_location=device)
        decoder = Decoder(config).to(device)
        decoder.eval()

        # Only the decoder is used for sampling. The companion networks are
        # still built and restored so that a checkpoint that does not match
        # the configured algorithm fails exactly where it did before.
        algo = config.training.algo
        if algo == 'vae':
            encoder = Encoder(config).to(device)
            encoder.load_state_dict(states[0])
            decoder.load_state_dict(states[1])
        elif algo == 'ssm':
            score = Score(config).to(device)
            imp_encoder = ImplicitEncoder(config).to(device)
            imp_encoder.load_state_dict(states[0])
            decoder.load_state_dict(states[1])
            score.load_state_dict(states[2])
        elif algo in ['spectral', 'stein']:
            # The estimators are not used here; the import is kept so a
            # missing module is still reported at this point.
            from models.kernel_score_estimators import SpectralScoreEstimator, SteinScoreEstimator  # noqa: F401
            imp_encoder = ImplicitEncoder(config).to(device)
            imp_encoder.load_state_dict(states[0])
            decoder.load_state_dict(states[1])
        return decoder

    def sample_images(decoder):
        """Return 100 batches of 100 decoder samples as a flat list of images."""
        all_samples = []
        logging.info("Generating samples")
        with torch.no_grad():
            for _ in range(100):
                z = torch.randn(100, config.model.z_dim, device=device)
                samples, _ = decoder(z)
                samples = samples.view(100, config.data.channels,
                                       config.data.image_size,
                                       config.data.image_size)
                # Same x / 2 + 0.5 rescaling that is applied to the data
                # images; presumably maps [-1, 1] to [0, 1] -- confirm
                # against the decoder's output range.
                all_samples.extend(samples / 2. + 0.5)
        return all_samples

    def write_images(images, directory):
        """Save every image as ``<directory>/<index>.png``."""
        for idx, image in enumerate(images):
            save_image(image, os.path.join(directory, '{}.png'.format(idx)))

    def recreate_dir(path):
        """Make ``path`` an empty directory, discarding any previous content."""
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)

    transform = transforms.Compose([
        transforms.Resize(config.data.image_size),
        transforms.ToTensor()
    ])

    if config.data.dataset == 'CIFAR10':
        test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'),
                               train=False, download=True, transform=transform)
    elif config.data.dataset == 'MNIST':
        test_dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'),
                             train=False, download=True, transform=transform)
    elif config.data.dataset == 'CELEBA':
        dataset = ImageFolder(
            root=os.path.join(self.args.run, 'datasets', 'celeba'),
            transform=transforms.Compose([
                transforms.CenterCrop(140),
                transforms.Resize(config.data.image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]))
        num_items = len(dataset)
        indices = list(range(num_items))
        # Shuffle with a fixed seed so the split is reproducible, then put
        # the global NumPy RNG state back so callers are unaffected.
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        # The last 20% of the shuffled indices form the test split.
        test_indices = indices[int(0.8 * num_items):]
        test_dataset = Subset(dataset, test_indices)

    test_loader = DataLoader(test_dataset,
                             batch_size=config.training.batch_size,
                             shuffle=False, num_workers=2)

    config.input_dim = config.data.image_size ** 2 * config.data.channels

    # Developer switches; edit by hand to select one of the other modes.
    get_data_stats = False
    manual = False

    # NOTE(review): the trailing (50, True, 2048) arguments of the fid calls
    # are presumably batch size, CUDA flag and feature dimension -- confirm
    # against the fid module.

    if get_data_stats:
        data_images = []
        for X, _ in test_loader:
            X = X.to(device)
            # Add uniform noise in [-1/256, 1/256) before rescaling.
            X = X + (torch.rand_like(X) - 0.5) / 128.
            data_images.extend(X / 2. + 0.5)
            if len(data_images) > 10000:
                break

        fid_root = os.path.join(self.args.run, 'datasets', 'celeba140_fid')
        raw_dir = os.path.join(fid_root, 'raw_images')
        os.makedirs(raw_dir, exist_ok=True)

        logging.info("Saving data images")
        write_images(data_images, raw_dir)
        logging.info("Images saved. Calculating fid statistics now")
        fid.calculate_data_statics(raw_dir, fid_root, 50, True, 2048)
        return

    samples_root = os.path.join(self.args.log, 'samples')

    if manual:
        decoder = load_decoder('checkpoint.pth')
        all_samples = sample_images(decoder)

        raw_dir = os.path.join(samples_root, 'raw_images')
        os.makedirs(raw_dir, exist_ok=True)

        logging.info("Images generated. Saving images")
        write_images(all_samples, raw_dir)
        logging.info("Generating fid statistics")
        fid.calculate_data_statics(raw_dir, samples_root, 50, True, 2048)
        logging.info("Statistics generated.")
        return

    # ``ckpt`` counts checkpoints in units of 10k iterations.
    for ckpt in range(1, 11):
        decoder = load_decoder('checkpoint_{}0k.pth'.format(ckpt))
        all_samples = sample_images(decoder)

        raw_dir = os.path.join(samples_root, 'raw_images_{}0k'.format(ckpt))
        stats_dir = os.path.join(samples_root, 'statistics_{}0k'.format(ckpt))
        recreate_dir(raw_dir)
        recreate_dir(stats_dir)

        logging.info("Images generated. Saving images")
        write_images(all_samples, raw_dir)
        logging.info("Generating fid statistics")
        fid.calculate_data_statics(raw_dir, stats_dir, 50, True, 2048)
        logging.info("Statistics generated.")

        # NOTE(review): the reference statistics path is relative to the
        # working directory, whereas the get_data_stats mode writes under
        # <args.run>/datasets/celeba140_fid -- confirm this is intended.
        fid_number = fid.calculate_fid_given_paths(
            ['run/datasets/celeba140_fid/celeba_test.npz',
             os.path.join(stats_dir, 'celeba_test.npz')],
            50, True, 2048)
        logging.info("Number of iters: {}0k, FID: {}".format(ckpt, fid_number))