def test(self):
    # `states` was saved by train() as [encoder, decoder, optimizer, ...];
    # index 1 is always the decoder.
    states = torch.load(
        os.path.join(self.args.log, "checkpoint.pth"),
        map_location=self.config.device,
    )
    decoder = (MLPDecoder(self.config).to(self.config.device)
               if self.config.data.dataset == "MNIST"
               else Decoder(self.config).to(self.config.device))
    decoder.eval()
    decoder.load_state_dict(states[1])

    z = torch.randn(100, self.config.model.z_dim, device=self.config.device)
    if self.config.data.dataset == "CELEBA":
        samples, _ = decoder(z)
        samples = samples.view(
            100,
            self.config.data.channels,
            self.config.data.image_size,
            self.config.data.image_size,
        )
        image_grid = make_grid(samples, 10)
        # CELEBA samples live on the [-1, 1] scale; map back to [0, 1].
        image_grid = torch.clamp(image_grid / 2.0 + 0.5, 0.0, 1.0)
    elif self.config.data.dataset == "MNIST":
        samples_logits = decoder(z)
        samples = torch.sigmoid(samples_logits)  # Bernoulli decoder: logits -> probabilities
        samples = samples.view(
            100,
            self.config.data.channels,
            self.config.data.image_size,
            self.config.data.image_size,
        )
        image_grid = make_grid(samples, 10)
    save_image(image_grid, "image_grid.png")
def __init__(self, dims): """ M2 code replication from the paper 'Semi-Supervised Learning with Deep Generative Models' (Kingma 2014) in PyTorch. The "Generative semi-supervised model" is a probabilistic model that incorporates label information in both inference and generation. Initialise a new generative model :param dims: dimensions of x, y, z and hidden layers. """ [x_dim, self.y_dim, z_dim, h_dim] = dims super(DeepGenerativeModel, self).__init__([x_dim, z_dim, h_dim]) self.encoder = Encoder([x_dim + self.y_dim, h_dim, z_dim]) self.decoder = Decoder( [z_dim + self.y_dim, list(reversed(h_dim)), x_dim]) self.classifier = Classifier([x_dim, h_dim[0], self.y_dim]) for m in self.modules(): if isinstance(m, nn.Linear): init.xavier_normal_(m.weight.data) if m.bias is not None: m.bias.data.zero_()
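# Usage sketch (illustrative, not from the original source): constructing the
# M2 model under the docstring's dims convention [x_dim, y_dim, z_dim, h_dim].
# Sizes below assume flattened 28x28 MNIST with 10 classes; `torch` is assumed
# imported as elsewhere in the file.
def _demo_m2():
    model = DeepGenerativeModel([784, 10, 32, [256, 128]])
    x = torch.bernoulli(torch.rand(16, 784))  # dummy binarized batch
    return model.classifier(x)                # q(y|x); expected shape (16, 10)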
def __init__(self, dims): """ Auxiliary Deep Generative Models [Maaløe 2016] code replication. The ADGM introduces an additional latent variable 'a', which enables the model to fit more complex variational distributions. :param dims: dimensions of x, y, z, a and hidden layers. """ [x_dim, y_dim, z_dim, a_dim, h_dim] = dims super(AuxiliaryDeepGenerativeModel, self).__init__([x_dim, y_dim, z_dim, h_dim]) self.aux_encoder = Encoder([x_dim, h_dim, a_dim]) # q(a|x) self.aux_decoder = Encoder( [x_dim + z_dim + y_dim, list(reversed(h_dim)), a_dim]) # p(a|x,y,z) self.classifier = Classifier([x_dim + a_dim, h_dim[0], y_dim]) # q(y|a,x) self.encoder = Encoder([a_dim + y_dim + x_dim, h_dim, z_dim]) # q(z|a,y,x) self.decoder = Decoder([y_dim + z_dim, list(reversed(h_dim)), x_dim]) # p(x|y,z)
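# Usage sketch (illustrative, not from the original source): the ADGM takes one
# extra entry for the auxiliary variable, dims = [x_dim, y_dim, z_dim, a_dim,
# h_dim]. The aux_encoder call below assumes the project's Encoder returns a
# (sample, mu, log_var) triple, as GaussianSample-style encoders typically do;
# sizes are illustrative only.
def _demo_adgm():
    model = AuxiliaryDeepGenerativeModel([784, 10, 32, 32, [256, 128]])
    x = torch.bernoulli(torch.rand(16, 784))
    a, *_ = model.aux_encoder(x)                       # a ~ q(a|x)
    return model.classifier(torch.cat([x, a], dim=1))  # q(y|a,x)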
def _create_vae_model(self):
    original_dim = self.all_data.shape[1]
    intermediate_dim = self.parameters['TechniqueParameters']['IntermediateDimension']
    latent_dim = self.parameters['TechniqueParameters']['LatentDimension']

    encoder = Encoder(intermediate_dim=intermediate_dim, latent_dim=latent_dim)
    encoder.build((None, original_dim))
    decoder = Decoder(intermediate_dim=intermediate_dim, original_dim=original_dim)
    decoder.build((None, latent_dim))

    model = VariationalAutoencoder(encoder, decoder, latent_dim=latent_dim)
    return model
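# Sketch of the `parameters` structure that _create_vae_model reads. The key
# names are taken from the method above; the values are illustrative and not
# from the original source.
EXAMPLE_PARAMETERS = {
    'TechniqueParameters': {
        'IntermediateDimension': 64,  # width of the encoder/decoder hidden layer
        'LatentDimension': 8,         # dimensionality of the latent code z
    }
}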
def __init__(self, dims): """ Ladder version of the Deep Generative Model. Uses a hierarchical representation that is trained end-to-end to give very nice disentangled representations. :param dims: dimensions of x, y, z layers and h layers note that len(z) == len(h). """ [x_dim, y_dim, z_dim, h_dim] = dims super(LadderDeepGenerativeModel, self).__init__([x_dim, y_dim, z_dim[0], h_dim]) neurons = [x_dim, *h_dim] encoder_layers = [ LadderEncoder([neurons[i - 1], neurons[i], z_dim[i - 1]]) for i in range(1, len(neurons)) ] e = encoder_layers[-1] encoder_layers[-1] = LadderEncoder( [e.in_features + y_dim, e.out_features, e.z_dim]) decoder_layers = [ LadderDecoder([z_dim[i - 1], h_dim[i - 1], z_dim[i]]) for i in range(1, len(h_dim)) ][::-1] self.classifier = Classifier([x_dim, h_dim[0], y_dim]) self.encoder = nn.ModuleList(encoder_layers) self.decoder = nn.ModuleList(decoder_layers) self.reconstruction = Decoder([z_dim[0] + y_dim, h_dim, x_dim]) for m in self.modules(): if isinstance(m, nn.Linear): init.xavier_normal_(m.weight.data) if m.bias is not None: m.bias.data.zero_()
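# Usage sketch (illustrative, not from the original source): per the docstring,
# z_dim and h_dim carry one entry per ladder rung, so len(z_dim) == len(h_dim).
def _demo_ladder():
    # Three-rung ladder on flattened 28x28 inputs with 10 classes.
    return LadderDeepGenerativeModel([784, 10, [32, 16, 8], [256, 128, 64]])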
def train(self):
    transform = transforms.Compose([
        transforms.Resize(self.config.data.image_size),
        transforms.ToTensor()
    ])

    if self.config.data.dataset == "CIFAR10":
        dataset = CIFAR10(
            os.path.join(self.args.run, "datasets", "cifar10"),
            train=True,
            download=True,
            transform=transform,
        )
        test_dataset = CIFAR10(
            os.path.join(self.args.run, "datasets", "cifar10"),
            train=False,
            download=True,
            transform=transform,
        )
    elif self.config.data.dataset == "MNIST":
        dataset = MNIST(
            os.path.join(self.args.run, "datasets", "mnist"),
            train=True,
            download=True,
            transform=transform,
        )
        # Fixed 80/20 split: seed 2019, restoring the global RNG state afterwards.
        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, test_indices = (
            indices[:int(num_items * 0.8)],
            indices[int(num_items * 0.8):],
        )
        test_dataset = Subset(dataset, test_indices)
        dataset = Subset(dataset, train_indices)
    elif self.config.data.dataset == "CELEBA":
        dataset = ImageFolder(
            # root="/raid/tianyu/ncsn/run/datasets/celeba/celeba",
            root="/home/kunxu/tmp/",
            transform=transforms.Compose([
                transforms.CenterCrop(140),
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]),
        )
        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, test_indices = (
            indices[:int(num_items * 0.7)],
            indices[int(num_items * 0.7):int(num_items * 0.8)],
        )
        test_dataset = Subset(dataset, test_indices)
        dataset = Subset(dataset, train_indices)

    dataloader = DataLoader(
        dataset,
        batch_size=self.config.training.batch_size,
        shuffle=True,
        num_workers=4,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=self.config.training.batch_size,
        shuffle=True,
        num_workers=2,
    )
    test_iter = iter(test_loader)
    self.config.input_dim = (self.config.data.image_size ** 2
                             * self.config.data.channels)

    tb_path = os.path.join(self.args.run, "tensorboard", self.args.doc)
    if os.path.exists(tb_path):
        shutil.rmtree(tb_path)
    tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)

    decoder = (MLPDecoder(self.config).to(self.config.device)
               if self.config.data.dataset == "MNIST"
               else Decoder(self.config).to(self.config.device))

    if self.config.training.algo == "vae":
        encoder = (MLPEncoder(self.config).to(self.config.device)
                   if self.config.data.dataset == "MNIST"
                   else Encoder(self.config).to(self.config.device))
        optimizer = self.get_optimizer(
            itertools.chain(encoder.parameters(), decoder.parameters()))
        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log, "checkpoint.pth"))
            encoder.load_state_dict(states[0])
            decoder.load_state_dict(states[1])
            optimizer.load_state_dict(states[2])
    elif self.config.training.algo in ["ssm", "ssm_fd"]:
        score = (MLPScore(self.config).to(self.config.device)
                 if self.config.data.dataset == "MNIST"
                 else Score(self.config).to(self.config.device))
        imp_encoder = (MLPImplicitEncoder(self.config).to(self.config.device)
                       if self.config.data.dataset == "MNIST"
                       else ImplicitEncoder(self.config).to(self.config.device))
        opt_ae = optim.RMSprop(
            itertools.chain(decoder.parameters(), imp_encoder.parameters()),
            lr=self.config.optim.lr,
        )
        opt_score = optim.RMSprop(score.parameters(), lr=self.config.optim.lr)
        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log, "checkpoint.pth"))
            imp_encoder.load_state_dict(states[0])
            decoder.load_state_dict(states[1])
            score.load_state_dict(states[2])
            opt_ae.load_state_dict(states[3])
            opt_score.load_state_dict(states[4])
    elif self.config.training.algo in ["spectral", "stein"]:
        from models.kernel_score_estimators import (
            SpectralScoreEstimator,
            SteinScoreEstimator,
        )
        imp_encoder = (MLPImplicitEncoder(self.config).to(self.config.device)
                       if self.config.data.dataset == "MNIST"
                       else ImplicitEncoder(self.config).to(self.config.device))
        estimator = (SpectralScoreEstimator()
                     if self.config.training.algo == "spectral"
                     else SteinScoreEstimator())
        optimizer = self.get_optimizer(
            itertools.chain(imp_encoder.parameters(), decoder.parameters()))
        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log, "checkpoint.pth"))
            imp_encoder.load_state_dict(states[0])
            decoder.load_state_dict(states[1])
            optimizer.load_state_dict(states[2])

    step = 0
    best_validation_loss = np.inf
    validation_losses = []
    recon_type = "bernoulli" if self.config.data.dataset == "MNIST" else "gaussian"
    time_dur = 0.0

    for _ in range(self.config.training.n_epochs):
        for _, (X, y) in enumerate(dataloader):
            decoder.train()
            X = X.to(self.config.device)
            if self.config.data.dataset == "CELEBA":
                X = X + (torch.rand_like(X) - 0.5) / 128.0  # uniform dequantization
            elif self.config.data.dataset == "MNIST":
                eps = torch.rand_like(X)
                X = (eps <= X).float()  # dynamic binarization

            # Only the ssm variants report a per-step timing; default to zero so
            # `time_dur += run_time` below is well-defined for every algorithm.
            run_time = 0.0
            if self.config.training.algo == "vae":
                encoder.train()
                loss, *_ = elbo(encoder, decoder, X, recon_type)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            elif self.config.training.algo == "ssm":
                imp_encoder.train()
                loss, run_time, ssm_loss, *_ = elbo_ssm(
                    imp_encoder,
                    decoder,
                    score,
                    opt_score,
                    X,
                    recon_type,
                    training=True,
                    n_particles=self.config.model.n_particles,
                )
                opt_ae.zero_grad()
                loss.backward()
                opt_ae.step()
            elif self.config.training.algo == "ssm_fd":
                imp_encoder.train()
                loss, run_time, ssm_loss, *_ = elbo_ssm_fd(
                    imp_encoder,
                    decoder,
                    score,
                    opt_score,
                    X,
                    recon_type,
                    training=True,
                    n_particles=self.config.model.n_particles,
                )
                opt_ae.zero_grad()
                loss.backward()
                opt_ae.step()
            elif self.config.training.algo in ["spectral", "stein"]:
                imp_encoder.train()
                loss = elbo_kernel(
                    imp_encoder,
                    decoder,
                    estimator,
                    X,
                    recon_type,
                    n_particles=self.config.model.n_particles,
                )
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            time_dur += run_time

            if step % 10 == 0:
                try:
                    test_X, _ = next(test_iter)
                except StopIteration:
                    test_iter = iter(test_loader)
                    test_X, _ = next(test_iter)
                test_X = test_X.to(self.config.device)
                if self.config.data.dataset == "CELEBA":
                    test_X = test_X + (torch.rand_like(test_X) - 0.5) / 128.0
                elif self.config.data.dataset == "MNIST":
                    test_eps = torch.rand_like(test_X)
                    test_X = (test_eps <= test_X).float()

                decoder.eval()
                if self.config.training.algo == "vae":
                    encoder.eval()
                    with torch.no_grad():
                        test_loss, *_ = elbo(encoder, decoder, test_X, recon_type)
                    logging.info("loss: {}, test_loss: {}".format(
                        loss.item(), test_loss.item()))
                elif self.config.training.algo == "ssm":
                    imp_encoder.eval()
                    test_loss, *_ = elbo_ssm(
                        imp_encoder,
                        decoder,
                        score,
                        None,
                        test_X,
                        recon_type,
                        training=False,
                    )
                    logging.info("loss: {}, ssm_loss: {}, test_loss: {}".format(
                        loss.item(), ssm_loss.item(), test_loss.item()))
                    z = imp_encoder(test_X)
                    tb_logger.add_histogram("z_X", z, global_step=step)
                elif self.config.training.algo == "ssm_fd":
                    imp_encoder.eval()
                    test_loss, *_ = elbo_ssm_fd(
                        imp_encoder,
                        decoder,
                        score,
                        None,
                        test_X,
                        recon_type,
                        training=False,
                    )
                    logging.info("loss: {}, ssm_loss: {}, test_loss: {}".format(
                        loss.item(), ssm_loss.item(), test_loss.item()))
                    z = imp_encoder(test_X)
                    tb_logger.add_histogram("z_X", z, global_step=step)
                elif self.config.training.algo in ["spectral", "stein"]:
                    imp_encoder.eval()
                    with torch.no_grad():
                        test_loss = elbo_kernel(imp_encoder, decoder, estimator,
                                                test_X, recon_type, 10)
                    logging.info("loss: {}, test_loss: {}".format(
                        loss.item(), test_loss.item()))

                validation_losses.append(test_loss.item())
                tb_logger.add_scalar("loss", loss, global_step=step)
                tb_logger.add_scalar("test_loss", test_loss, global_step=step)
                if self.config.training.algo in ["ssm", "ssm_fd"]:
                    tb_logger.add_scalar("ssm_loss", ssm_loss, global_step=step)

            if step % 500 == 0:
                logging.info("Time Dur in this 500 iters: {}".format(time_dur))
                time_dur = 0.0
                with torch.no_grad():
                    z = torch.randn(100, self.config.model.z_dim, device=X.device)
                    decoder.eval()
                    if self.config.data.dataset == "CELEBA":
                        samples, _ = decoder(z)
                        samples = samples.view(
                            100,
                            self.config.data.channels,
                            self.config.data.image_size,
                            self.config.data.image_size,
                        )
                        image_grid = make_grid(samples, 10)
                        image_grid = torch.clamp(image_grid / 2.0 + 0.5, 0.0, 1.0)
                        data_grid = make_grid(X[:100], 10)
                        data_grid = torch.clamp(data_grid / 2.0 + 0.5, 0.0, 1.0)
                    elif self.config.data.dataset == "MNIST":
                        samples_logits = decoder(z)
                        samples = torch.sigmoid(samples_logits)
                        samples = samples.view(
                            100,
                            self.config.data.channels,
                            self.config.data.image_size,
                            self.config.data.image_size,
                        )
                        image_grid = make_grid(samples, 10)
                        data_grid = make_grid(X[:100], 10)
                    tb_logger.add_image("samples", image_grid, global_step=step)
                    tb_logger.add_image("data", data_grid, global_step=step)

                if len(validation_losses) != 0:
                    validation_loss = sum(validation_losses) / len(validation_losses)
                    if validation_loss < best_validation_loss:
                        best_validation_loss = validation_loss
                        validation_losses = []
                    # else:
                    #     return 0

            if (step + 1) % 10000 == 0:
                if self.config.training.algo == "vae":
                    states = [
                        encoder.state_dict(),
                        decoder.state_dict(),
                        optimizer.state_dict(),
                    ]
                elif self.config.training.algo in ["ssm", "ssm_fd"]:
                    states = [
                        imp_encoder.state_dict(),
                        decoder.state_dict(),
                        score.state_dict(),
                        opt_ae.state_dict(),
                        opt_score.state_dict(),
                    ]
                elif self.config.training.algo in ["spectral", "stein"]:
                    states = [
                        imp_encoder.state_dict(),
                        decoder.state_dict(),
                        optimizer.state_dict(),
                    ]
                torch.save(
                    states,
                    os.path.join(
                        self.args.log,
                        "checkpoint_{}0k.pth".format((step + 1) // 10000),
                    ),
                )
                torch.save(states, os.path.join(self.args.log, "checkpoint.pth"))

            step += 1
            if step >= self.config.training.n_iters:
                return 0
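# The two input transforms used in train() above, shown in isolation: MNIST
# batches are dynamically binarized (each pixel is a Bernoulli draw with
# probability equal to its intensity), and CELEBA batches are uniformly
# dequantized with noise of one 8-bit quantization bin's width on the [-1, 1]
# scale. A self-contained sketch, not part of the original file:
def _demo_input_transforms():
    import torch
    x = torch.rand(16, 1, 28, 28)                  # stand-in image batch in [0, 1]
    binarized = (torch.rand_like(x) <= x).float()  # dynamic binarization (MNIST)
    dequantized = x + (torch.rand_like(x) - 0.5) / 128.0  # uniform dequantization (CELEBA)
    return binarized, dequantized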
def test_fid(self):
    assert self.config.data.dataset == "CELEBA"
    transform = transforms.Compose([
        transforms.Resize(self.config.data.image_size),
        transforms.ToTensor()
    ])

    if self.config.data.dataset == "CIFAR10":
        test_dataset = CIFAR10(
            os.path.join(self.args.run, "datasets", "cifar10"),
            train=False,
            download=True,
            transform=transform,
        )
    elif self.config.data.dataset == "MNIST":
        test_dataset = MNIST(
            os.path.join(self.args.run, "datasets", "mnist"),
            train=False,
            download=True,
            transform=transform,
        )
    elif self.config.data.dataset == "CELEBA":
        dataset = ImageFolder(
            # root="/raid/tianyu/ncsn/run/datasets/celeba/celeba",
            root="/home/kunxu/tmp/",
            transform=transforms.Compose([
                transforms.CenterCrop(140),
                transforms.Resize(self.config.data.image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]),
        )
        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        test_indices = indices[int(0.8 * num_items):]
        test_dataset = Subset(dataset, test_indices)

    test_loader = DataLoader(
        test_dataset,
        batch_size=self.config.training.batch_size,
        shuffle=False,
        num_workers=2,
    )
    self.config.input_dim = (self.config.data.image_size ** 2
                             * self.config.data.channels)

    get_data_stats = False
    manual = False

    if get_data_stats:
        data_images = []
        for _, (X, y) in enumerate(test_loader):
            X = X.to(self.config.device)
            X = X + (torch.rand_like(X) - 0.5) / 128.0
            data_images.extend(X / 2.0 + 0.5)
            if len(data_images) > 10000:
                break
        if not os.path.exists(
                os.path.join(self.args.run, "datasets", "celeba140_fid",
                             "raw_images")):
            os.makedirs(
                os.path.join(self.args.run, "datasets", "celeba140_fid",
                             "raw_images"))
        logging.info("Saving data images")
        for i, image in enumerate(data_images):
            save_image(
                image,
                os.path.join(
                    self.args.run,
                    "datasets",
                    "celeba140_fid",
                    "raw_images",
                    "{}.png".format(i),
                ),
            )
        logging.info("Images saved. Calculating fid statistics now")
        fid.calculate_data_statics(
            os.path.join(self.args.run, "datasets", "celeba140_fid",
                         "raw_images"),
            os.path.join(self.args.run, "datasets", "celeba140_fid"),
            50,
            True,
            2048,
        )
    else:
        if manual:
            states = torch.load(
                os.path.join(self.args.log, "checkpoint_100k.pth"),
                map_location=self.config.device,
            )
            decoder = Decoder(self.config).to(self.config.device)
            decoder.eval()
            if self.config.training.algo == "vae":
                encoder = Encoder(self.config).to(self.config.device)
                encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])
            elif self.config.training.algo in ["ssm", "ssm_fd"]:
                score = Score(self.config).to(self.config.device)
                imp_encoder = ImplicitEncoder(self.config).to(self.config.device)
                imp_encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])
                score.load_state_dict(states[2])
            elif self.config.training.algo in ["spectral", "stein"]:
                from models.kernel_score_estimators import (
                    SpectralScoreEstimator,
                    SteinScoreEstimator,
                )
                imp_encoder = ImplicitEncoder(self.config).to(self.config.device)
                imp_encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])

            all_samples = []
            logging.info("Generating samples")
            for i in range(100):
                with torch.no_grad():
                    z = torch.randn(100, self.config.model.z_dim,
                                    device=self.config.device)
                    samples, _ = decoder(z)
                    samples = samples.view(
                        100,
                        self.config.data.channels,
                        self.config.data.image_size,
                        self.config.data.image_size,
                    )
                    all_samples.extend(samples / 2.0 + 0.5)

            if not os.path.exists(
                    os.path.join(self.args.log, "samples", "raw_images")):
                os.makedirs(os.path.join(self.args.log, "samples", "raw_images"))
            logging.info("Images generated. Saving images")
            for i, image in enumerate(all_samples):
                save_image(
                    image,
                    os.path.join(self.args.log, "samples", "raw_images",
                                 "{}.png".format(i)),
                )
            logging.info("Generating fid statistics")
            fid.calculate_data_statics(
                os.path.join(self.args.log, "samples", "raw_images"),
                os.path.join(self.args.log, "samples"),
                50,
                True,
                2048,
            )
            logging.info("Statistics generated.")
        else:
            # `ckpt_iter` renamed from `iter`, which shadowed the builtin.
            for ckpt_iter in range(10, 11):
                states = torch.load(
                    os.path.join(self.args.log,
                                 "checkpoint_{}0k.pth".format(ckpt_iter)),
                    map_location=self.config.device,
                )
                decoder = Decoder(self.config).to(self.config.device)
                decoder.eval()
                if self.config.training.algo == "vae":
                    encoder = Encoder(self.config).to(self.config.device)
                    encoder.load_state_dict(states[0])
                    decoder.load_state_dict(states[1])
                elif self.config.training.algo in ["ssm", "ssm_fd"]:
                    score = Score(self.config).to(self.config.device)
                    imp_encoder = ImplicitEncoder(self.config).to(self.config.device)
                    imp_encoder.load_state_dict(states[0])
                    decoder.load_state_dict(states[1])
                    score.load_state_dict(states[2])
                elif self.config.training.algo in ["spectral", "stein"]:
                    from models.kernel_score_estimators import (
                        SpectralScoreEstimator,
                        SteinScoreEstimator,
                    )
                    imp_encoder = ImplicitEncoder(self.config).to(self.config.device)
                    imp_encoder.load_state_dict(states[0])
                    decoder.load_state_dict(states[1])

                all_samples = []
                logging.info("Generating samples")
                for i in range(100):
                    with torch.no_grad():
                        z = torch.randn(100, self.config.model.z_dim,
                                        device=self.config.device)
                        samples, _ = decoder(z)
                        samples = samples.view(
                            100,
                            self.config.data.channels,
                            self.config.data.image_size,
                            self.config.data.image_size,
                        )
                        all_samples.extend(samples / 2.0 + 0.5)

                raw_dir = os.path.join(self.args.log, "samples",
                                       "raw_images_{}0k".format(ckpt_iter))
                stats_dir = os.path.join(self.args.log, "samples",
                                         "statistics_{}0k".format(ckpt_iter))
                for directory in (raw_dir, stats_dir):
                    if os.path.exists(directory):
                        shutil.rmtree(directory)
                    os.makedirs(directory)

                logging.info("Images generated. Saving images")
                for i, image in enumerate(all_samples):
                    save_image(image, os.path.join(raw_dir, "{}.png".format(i)))
                logging.info("Generating fid statistics")
                fid.calculate_data_statics(raw_dir, stats_dir, 50, True, 2048)
                logging.info("Statistics generated.")
                fid_number = fid.calculate_fid_given_paths(
                    [
                        "run/datasets/celeba140_fid/celeba_test.npz",
                        os.path.join(stats_dir, "celeba_test.npz"),
                    ],
                    50,
                    True,
                    2048,
                )
                logging.info("Number of iters: {}0k, FID: {}".format(
                    ckpt_iter, fid_number))
import torch
import torchaudio

# `get_audio_dataset`, `Encoder`, and `Decoder` are project-local modules
# assumed to be imported elsewhere in the original file.


def main(data_directory, num_epochs, batch_size):
    dataset = get_audio_dataset(data_directory, max_length_in_seconds=2,
                                pad_and_truncate=True)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             num_workers=8)
    train_dataloader_len = len(dataloader)

    encoder = Encoder(64, 1, 100).to("cuda")
    decoder = Decoder(64, 1, 100).to("cuda")

    # The siamese network reuses the encoder architecture; its convolutional
    # trunk and mu head serve as a learned feature space for the
    # reconstruction loss.
    siamese = Encoder(64, 1, 100)
    siamese_main = siamese.main.to("cuda")
    siamese_output = siamese.mu.to("cuda")

    optimizer_encoder = torch.optim.Adam(encoder.parameters(), lr=1e-4)
    optimizer_decoder = torch.optim.Adam(decoder.parameters(), lr=1e-4)
    optimizer_siamese = torch.optim.Adam(siamese.parameters(), lr=1e-4)

    criterion = torch.nn.MSELoss()

    for epoch in range(num_epochs):
        for sample_idx, (audio, _) in enumerate(dataloader):
            decoder.zero_grad()
            encoder.zero_grad()
            siamese.zero_grad()

            audio = audio.to("cuda")
            z, mu, logvar = encoder(audio)
            decoded = decoder(z)
            # 2 s at 16 kHz -> truncate the decoder output to 32000 samples.
            decoded = decoded.narrow(2, 0, 32000)

            # Reconstruction loss measured in the siamese feature space.
            hidden_fake_main = siamese_main(decoded)
            hidden_fake_output = siamese_output(hidden_fake_main)
            hidden_real_main = siamese_main(audio)
            hidden_real_output = siamese_output(hidden_real_main)
            err = criterion(hidden_fake_output, hidden_real_output)

            # Closed-form KL divergence between q(z|x) and the N(0, I) prior.
            KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss = err + KLD

            loss.backward()
            optimizer_encoder.step()
            optimizer_decoder.step()
            optimizer_siamese.step()

            print(
                f"{epoch:06d}-[{sample_idx + 1}/{train_dataloader_len}]: "
                f"loss {loss.mean().item()}"
            )

            if sample_idx % 100 == 0:
                with torch.no_grad():
                    fake_noise = torch.randn(1, 100, 1).to("cuda")
                    output_gen = decoder(fake_noise).narrow(2, 0, 32000).to("cpu")
                    torchaudio.save(
                        f"outputs/decoder_output_{epoch:06d}_{sample_idx:06d}.wav",
                        output_gen[0],
                        16000,
                    )

        # Checkpoint once per epoch; the filenames are keyed by epoch only.
        torch.save(encoder.state_dict(),
                   "%s/encoder_epoch_%d.pth" % ("checkpoints", epoch))
        torch.save(decoder.state_dict(),
                   "%s/netD_epoch_%d.pth" % ("checkpoints", epoch))
        torch.save(siamese.state_dict(),
                   "%s/siamese_epoch_%d.pth" % ("checkpoints", epoch))
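# The KLD term above is the closed-form KL divergence between the encoder's
# diagonal Gaussian q(z|x) = N(mu, exp(logvar)) and the standard-normal prior.
# A small self-contained sanity check of that identity against
# torch.distributions (illustrative, not part of the original script):
def _check_kl_closed_form():
    from torch.distributions import Normal, kl_divergence

    mu, logvar = torch.randn(4, 100), torch.randn(4, 100)
    closed_form = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    std = (0.5 * logvar).exp()  # logvar parameterizes the variance
    reference = kl_divergence(
        Normal(mu, std), Normal(torch.zeros_like(mu), torch.ones_like(std))
    ).sum()
    assert torch.allclose(closed_form, reference, rtol=1e-4)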
def train(self):
    transform = transforms.Compose([
        transforms.Resize(self.config.data.image_size),
        transforms.ToTensor()
    ])

    if self.config.data.dataset == 'CIFAR10':
        dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'),
                          train=True, download=True, transform=transform)
        test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'),
                               train=False, download=True, transform=transform)
    elif self.config.data.dataset == 'MNIST':
        dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'),
                        train=True, download=True, transform=transform)
        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, test_indices = (indices[:int(num_items * 0.8)],
                                       indices[int(num_items * 0.8):])
        test_dataset = Subset(dataset, test_indices)
        dataset = Subset(dataset, train_indices)
    elif self.config.data.dataset == 'CELEBA':
        dataset = ImageFolder(root=os.path.join(self.args.run, 'datasets', 'celeba'),
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(self.config.data.image_size),
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.5, 0.5, 0.5),
                                                       (0.5, 0.5, 0.5)),
                              ]))
        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, test_indices = (indices[:int(num_items * 0.7)],
                                       indices[int(num_items * 0.7):int(num_items * 0.8)])
        test_dataset = Subset(dataset, test_indices)
        dataset = Subset(dataset, train_indices)

    dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size,
                            shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=self.config.training.batch_size,
                             shuffle=True, num_workers=2)
    test_iter = iter(test_loader)
    self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

    tb_path = os.path.join(self.args.run, 'tensorboard', self.args.doc)
    if os.path.exists(tb_path):
        shutil.rmtree(tb_path)
    tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)

    decoder = MLPDecoder(self.config).to(self.config.device) \
        if self.config.data.dataset == 'MNIST' \
        else Decoder(self.config).to(self.config.device)

    if self.config.training.algo == 'vae':
        encoder = MLPEncoder(self.config).to(self.config.device) \
            if self.config.data.dataset == 'MNIST' \
            else Encoder(self.config).to(self.config.device)
        optimizer = self.get_optimizer(itertools.chain(encoder.parameters(),
                                                       decoder.parameters()))
        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
            encoder.load_state_dict(states[0])
            decoder.load_state_dict(states[1])
            optimizer.load_state_dict(states[2])
    elif self.config.training.algo == 'ssm':
        score = MLPScore(self.config).to(self.config.device) \
            if self.config.data.dataset == 'MNIST' \
            else Score(self.config).to(self.config.device)
        imp_encoder = MLPImplicitEncoder(self.config).to(self.config.device) \
            if self.config.data.dataset == 'MNIST' \
            else ImplicitEncoder(self.config).to(self.config.device)
        opt_ae = optim.RMSprop(itertools.chain(decoder.parameters(),
                                               imp_encoder.parameters()),
                               lr=self.config.optim.lr)
        opt_score = optim.RMSprop(score.parameters(), lr=self.config.optim.lr)
        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
            imp_encoder.load_state_dict(states[0])
            decoder.load_state_dict(states[1])
            score.load_state_dict(states[2])
            opt_ae.load_state_dict(states[3])
            opt_score.load_state_dict(states[4])
    elif self.config.training.algo in ['spectral', 'stein']:
        from models.kernel_score_estimators import SpectralScoreEstimator, SteinScoreEstimator
        imp_encoder = MLPImplicitEncoder(self.config).to(self.config.device) \
            if self.config.data.dataset == 'MNIST' \
            else ImplicitEncoder(self.config).to(self.config.device)
        estimator = SpectralScoreEstimator() if self.config.training.algo == 'spectral' \
            else SteinScoreEstimator()
        optimizer = self.get_optimizer(itertools.chain(imp_encoder.parameters(),
                                                       decoder.parameters()))
        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
            imp_encoder.load_state_dict(states[0])
            decoder.load_state_dict(states[1])
            optimizer.load_state_dict(states[2])

    step = 0
    best_validation_loss = np.inf
    validation_losses = []
    recon_type = 'bernoulli' if self.config.data.dataset == 'MNIST' else 'gaussian'

    for _ in range(self.config.training.n_epochs):
        for _, (X, y) in enumerate(dataloader):
            decoder.train()
            X = X.to(self.config.device)
            if self.config.data.dataset == 'CELEBA':
                X = X + (torch.rand_like(X) - 0.5) / 128.
            elif self.config.data.dataset == 'MNIST':
                eps = torch.rand_like(X)
                X = (eps <= X).float()

            if self.config.training.algo == 'vae':
                encoder.train()
                loss, *_ = elbo(encoder, decoder, X, recon_type)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            elif self.config.training.algo == 'ssm':
                imp_encoder.train()
                loss, ssm_loss, *_ = elbo_ssm(imp_encoder, decoder, score, opt_score,
                                              X, recon_type, training=True,
                                              n_particles=self.config.model.n_particles)
                opt_ae.zero_grad()
                loss.backward()
                opt_ae.step()
            elif self.config.training.algo in ['spectral', 'stein']:
                imp_encoder.train()
                loss = elbo_kernel(imp_encoder, decoder, estimator, X, recon_type,
                                   n_particles=self.config.model.n_particles)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if step % 10 == 0:
                try:
                    test_X, _ = next(test_iter)
                except StopIteration:
                    test_iter = iter(test_loader)
                    test_X, _ = next(test_iter)
                test_X = test_X.to(self.config.device)
                if self.config.data.dataset == 'CELEBA':
                    test_X = test_X + (torch.rand_like(test_X) - 0.5) / 128.
                elif self.config.data.dataset == 'MNIST':
                    test_eps = torch.rand_like(test_X)
                    test_X = (test_eps <= test_X).float()

                decoder.eval()
                if self.config.training.algo == 'vae':
                    encoder.eval()
                    with torch.no_grad():
                        test_loss, *_ = elbo(encoder, decoder, test_X, recon_type)
                    logging.info("loss: {}, test_loss: {}".format(loss.item(),
                                                                  test_loss.item()))
                elif self.config.training.algo == 'ssm':
                    imp_encoder.eval()
                    test_loss, *_ = elbo_ssm(imp_encoder, decoder, score, None,
                                             test_X, recon_type, training=False)
                    logging.info("loss: {}, ssm_loss: {}, test_loss: {}".format(
                        loss.item(), ssm_loss.item(), test_loss.item()))
                    z = imp_encoder(test_X)
                    tb_logger.add_histogram('z_X', z, global_step=step)
                elif self.config.training.algo in ['spectral', 'stein']:
                    imp_encoder.eval()
                    with torch.no_grad():
                        test_loss = elbo_kernel(imp_encoder, decoder, estimator,
                                                test_X, recon_type, 10)
                    logging.info("loss: {}, test_loss: {}".format(loss.item(),
                                                                  test_loss.item()))

                validation_losses.append(test_loss.item())
                tb_logger.add_scalar('loss', loss, global_step=step)
                tb_logger.add_scalar('test_loss', test_loss, global_step=step)
                if self.config.training.algo == 'ssm':
                    tb_logger.add_scalar('ssm_loss', ssm_loss, global_step=step)

            if step % 500 == 0:
                with torch.no_grad():
                    z = torch.randn(100, self.config.model.z_dim, device=X.device)
                    decoder.eval()
                    if self.config.data.dataset == 'CELEBA':
                        samples, _ = decoder(z)
                        samples = samples.view(100, self.config.data.channels,
                                               self.config.data.image_size,
                                               self.config.data.image_size)
                        image_grid = make_grid(samples, 10)
                        image_grid = torch.clamp(image_grid / 2. + 0.5, 0.0, 1.0)
                        data_grid = make_grid(X[:100], 10)
                        data_grid = torch.clamp(data_grid / 2. + 0.5, 0.0, 1.0)
                    elif self.config.data.dataset == 'MNIST':
                        samples_logits = decoder(z)
                        samples = torch.sigmoid(samples_logits)
                        samples = samples.view(100, self.config.data.channels,
                                               self.config.data.image_size,
                                               self.config.data.image_size)
                        image_grid = make_grid(samples, 10)
                        data_grid = make_grid(X[:100], 10)
                    tb_logger.add_image('samples', image_grid, global_step=step)
                    tb_logger.add_image('data', data_grid, global_step=step)

                if len(validation_losses) != 0:
                    validation_loss = sum(validation_losses) / len(validation_losses)
                    if validation_loss < best_validation_loss:
                        best_validation_loss = validation_loss
                        validation_losses = []
                    # else:
                    #     return 0

            if (step + 1) % 10000 == 0:
                if self.config.training.algo == 'vae':
                    states = [
                        encoder.state_dict(),
                        decoder.state_dict(),
                        optimizer.state_dict()
                    ]
                elif self.config.training.algo == 'ssm':
                    states = [
                        imp_encoder.state_dict(),
                        decoder.state_dict(),
                        score.state_dict(),
                        opt_ae.state_dict(),
                        opt_score.state_dict()
                    ]
                elif self.config.training.algo in ['spectral', 'stein']:
                    states = [
                        imp_encoder.state_dict(),
                        decoder.state_dict(),
                        optimizer.state_dict()
                    ]
                torch.save(states, os.path.join(
                    self.args.log, 'checkpoint_{}0k.pth'.format((step + 1) // 10000)))
                torch.save(states, os.path.join(self.args.log, 'checkpoint.pth'))

            step += 1
            if step >= self.config.training.n_iters:
                return 0
def test_fid(self):
    assert self.config.data.dataset == 'CELEBA'
    transform = transforms.Compose([
        transforms.Resize(self.config.data.image_size),
        transforms.ToTensor()
    ])

    if self.config.data.dataset == 'CIFAR10':
        test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'),
                               train=False, download=True, transform=transform)
    elif self.config.data.dataset == 'MNIST':
        test_dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'),
                             train=False, download=True, transform=transform)
    elif self.config.data.dataset == 'CELEBA':
        dataset = ImageFolder(root=os.path.join(self.args.run, 'datasets', 'celeba'),
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(self.config.data.image_size),
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.5, 0.5, 0.5),
                                                       (0.5, 0.5, 0.5)),
                              ]))
        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        test_indices = indices[int(0.8 * num_items):]
        test_dataset = Subset(dataset, test_indices)

    test_loader = DataLoader(test_dataset, batch_size=self.config.training.batch_size,
                             shuffle=False, num_workers=2)
    self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

    get_data_stats = False
    manual = False

    if get_data_stats:
        data_images = []
        for _, (X, y) in enumerate(test_loader):
            X = X.to(self.config.device)
            X = X + (torch.rand_like(X) - 0.5) / 128.
            data_images.extend(X / 2. + 0.5)
            if len(data_images) > 10000:
                break
        if not os.path.exists(os.path.join(self.args.run, 'datasets',
                                           'celeba140_fid', 'raw_images')):
            os.makedirs(os.path.join(self.args.run, 'datasets',
                                     'celeba140_fid', 'raw_images'))
        logging.info("Saving data images")
        for i, image in enumerate(data_images):
            save_image(image, os.path.join(self.args.run, 'datasets', 'celeba140_fid',
                                           'raw_images', '{}.png'.format(i)))
        logging.info("Images saved. Calculating fid statistics now")
        fid.calculate_data_statics(os.path.join(self.args.run, 'datasets',
                                                'celeba140_fid', 'raw_images'),
                                   os.path.join(self.args.run, 'datasets',
                                                'celeba140_fid'),
                                   50, True, 2048)
    else:
        if manual:
            states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                                map_location=self.config.device)
            decoder = Decoder(self.config).to(self.config.device)
            decoder.eval()
            if self.config.training.algo == 'vae':
                encoder = Encoder(self.config).to(self.config.device)
                encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])
            elif self.config.training.algo == 'ssm':
                score = Score(self.config).to(self.config.device)
                imp_encoder = ImplicitEncoder(self.config).to(self.config.device)
                imp_encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])
                score.load_state_dict(states[2])
            elif self.config.training.algo in ['spectral', 'stein']:
                from models.kernel_score_estimators import SpectralScoreEstimator, SteinScoreEstimator
                imp_encoder = ImplicitEncoder(self.config).to(self.config.device)
                imp_encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])

            all_samples = []
            logging.info("Generating samples")
            for i in range(100):
                with torch.no_grad():
                    z = torch.randn(100, self.config.model.z_dim,
                                    device=self.config.device)
                    samples, _ = decoder(z)
                    samples = samples.view(100, self.config.data.channels,
                                           self.config.data.image_size,
                                           self.config.data.image_size)
                    all_samples.extend(samples / 2. + 0.5)

            if not os.path.exists(os.path.join(self.args.log, 'samples', 'raw_images')):
                os.makedirs(os.path.join(self.args.log, 'samples', 'raw_images'))
            logging.info("Images generated. Saving images")
            for i, image in enumerate(all_samples):
                save_image(image, os.path.join(self.args.log, 'samples', 'raw_images',
                                               '{}.png'.format(i)))
            logging.info("Generating fid statistics")
            fid.calculate_data_statics(os.path.join(self.args.log, 'samples', 'raw_images'),
                                       os.path.join(self.args.log, 'samples'),
                                       50, True, 2048)
            logging.info("Statistics generated.")
        else:
            # `ckpt_iter` renamed from `iter`, which shadowed the builtin.
            for ckpt_iter in range(1, 11):
                states = torch.load(os.path.join(self.args.log,
                                                 'checkpoint_{}0k.pth'.format(ckpt_iter)),
                                    map_location=self.config.device)
                decoder = Decoder(self.config).to(self.config.device)
                decoder.eval()
                if self.config.training.algo == 'vae':
                    encoder = Encoder(self.config).to(self.config.device)
                    encoder.load_state_dict(states[0])
                    decoder.load_state_dict(states[1])
                elif self.config.training.algo == 'ssm':
                    score = Score(self.config).to(self.config.device)
                    imp_encoder = ImplicitEncoder(self.config).to(self.config.device)
                    imp_encoder.load_state_dict(states[0])
                    decoder.load_state_dict(states[1])
                    score.load_state_dict(states[2])
                elif self.config.training.algo in ['spectral', 'stein']:
                    from models.kernel_score_estimators import SpectralScoreEstimator, SteinScoreEstimator
                    imp_encoder = ImplicitEncoder(self.config).to(self.config.device)
                    imp_encoder.load_state_dict(states[0])
                    decoder.load_state_dict(states[1])

                all_samples = []
                logging.info("Generating samples")
                for i in range(100):
                    with torch.no_grad():
                        z = torch.randn(100, self.config.model.z_dim,
                                        device=self.config.device)
                        samples, _ = decoder(z)
                        samples = samples.view(100, self.config.data.channels,
                                               self.config.data.image_size,
                                               self.config.data.image_size)
                        all_samples.extend(samples / 2. + 0.5)

                raw_dir = os.path.join(self.args.log, 'samples',
                                       'raw_images_{}0k'.format(ckpt_iter))
                stats_dir = os.path.join(self.args.log, 'samples',
                                         'statistics_{}0k'.format(ckpt_iter))
                for directory in (raw_dir, stats_dir):
                    if os.path.exists(directory):
                        shutil.rmtree(directory)
                    os.makedirs(directory)

                logging.info("Images generated. Saving images")
                for i, image in enumerate(all_samples):
                    save_image(image, os.path.join(raw_dir, '{}.png'.format(i)))
                logging.info("Generating fid statistics")
                fid.calculate_data_statics(raw_dir, stats_dir, 50, True, 2048)
                logging.info("Statistics generated.")
                fid_number = fid.calculate_fid_given_paths(
                    ['run/datasets/celeba140_fid/celeba_test.npz',
                     os.path.join(stats_dir, 'celeba_test.npz')],
                    50, True, 2048)
                logging.info("Number of iters: {}0k, FID: {}".format(ckpt_iter,
                                                                     fid_number))