def extract_reconstructions(encoder_input, style_mu, class_mu, class_logvar):
    """Optimize style/content latent inputs so the decoder reconstructs ``encoder_input``.

    Starting from the encoder's style means and the group-accumulated content
    mean, runs 50 Adam steps minimizing the squared reconstruction error.

    Parameters
    ----------
    encoder_input : target batch the decoder output is compared against.
    style_mu : per-sample style means from the encoder.
    class_mu, class_logvar : per-sample class posterior parameters; grouped
        into a single content estimate via ``accumulate_group_evidence``.

    Returns
    -------
    (reconstructed, reconstruction_error) from the final iteration.
    """
    # NOTE(review): the zeros tensor groups every sample into one batch group
    # — confirm that matches accumulate_group_evidence's label convention.
    grouped_mu, _ = accumulate_group_evidence(
        class_mu.data, class_logvar.data, torch.zeros(style_mu.size(0), 1)
    )

    # Leaf tensors we optimize directly (detached from the encoder graph).
    decoder_style_input = style_mu.clone().detach().requires_grad_(True)
    decoder_content_input = grouped_mu[0].clone().detach().requires_grad_(True)

    optimizer = optim.Adam([decoder_style_input, decoder_content_input])

    for _ in range(50):
        optimizer.zero_grad()

        # BUG FIX: `content` must be rebuilt every iteration. The original
        # computed it once before the loop; since `.double()` copies, Adam's
        # in-place updates to decoder_content_input never reached the decoder
        # (and re-running backward through the stale graph fails).
        content = decoder_content_input.expand(
            style_mu.size(0), decoder_content_input.size(0)).double()

        reconstructed = decoder(decoder_style_input, content)
        reconstruction_error = torch.sum((reconstructed - encoder_input).pow(2))
        reconstruction_error.backward()
        optimizer.step()

    return reconstructed, reconstruction_error
def forward(self, x, edge_index, batch, num_graphs):
    """One training step for the grouped graph VAE.

    Encodes nodes into per-node ("style") and per-graph ("class") latents,
    accumulates class evidence per graph, backpropagates the two KL terms and
    the reconstruction error separately (gradients accumulate on the model
    parameters), and returns the three scalar loss values.

    NOTE(review): this calls .backward() itself, so the caller is expected to
    zero gradients before and call optimizer.step() after — confirm.
    """
    # batch_size = data.num_graphs
    if x is None:
        # No node features available: substitute constant 1-features, one per
        # node. NOTE(review): `device` is a global captured from module scope.
        x = torch.ones(batch.shape[0]).to(device)

    node_mu, node_logvar, class_mu, class_logvar = self.encoder(x, edge_index, batch)

    # Combine per-node class posteriors into one posterior per graph.
    # .data detaches the inputs, so no gradient flows through the grouping.
    grouped_mu, grouped_logvar = accumulate_group_evidence(
        class_mu.data, class_logvar.data, batch, True
    )

    # kl-divergence error for style latent space
    node_kl_divergence_loss = torch.mean(
        - 0.5 * torch.sum(1 + node_logvar - node_mu.pow(2) - node_logvar.exp())
    )
    # 1e-7 weighting, scaled back up by the number of graphs in the batch.
    node_kl_divergence_loss = 0.0000001 * node_kl_divergence_loss * num_graphs
    # retain_graph: the same encoder graph is reused by the later backward calls.
    node_kl_divergence_loss.backward(retain_graph=True)

    # kl-divergence error for class latent space
    class_kl_divergence_loss = torch.mean(
        - 0.5 * torch.sum(1 + grouped_logvar - grouped_mu.pow(2) - grouped_logvar.exp())
    )
    class_kl_divergence_loss = 0.0000001 * class_kl_divergence_loss * num_graphs
    class_kl_divergence_loss.backward(retain_graph=True)

    # reconstruct samples
    """
    sampling from group mu and logvar for each graph in mini-batch differently makes
    the decoder consider class latent embeddings as random noise and ignore them
    """
    node_latent_embeddings = reparameterize(training=True, mu=node_mu, logvar=node_logvar)
    class_latent_embeddings = group_wise_reparameterize(
        training=True, mu=grouped_mu, logvar=grouped_logvar, labels_batch=batch, cuda=True
    )

    # need to reduce ml between node and class latents
    '''measure='JSD'
    mi_loss = local_global_loss_disen(node_latent_embeddings, class_latent_embeddings, edge_index, batch, measure)
    mi_loss.backward(retain_graph=True)'''

    reconstructed_node = self.decoder(node_latent_embeddings, class_latent_embeddings, edge_index)

    # check input feat first
    # print('recon ', x[0],reconstructed_node[0])
    reconstruction_error = 0.1 * mse_loss(reconstructed_node, x) * num_graphs
    reconstruction_error.backward()

    return reconstruction_error.item(), class_kl_divergence_loss.item(), node_kl_divergence_loss.item()
def run_through_network(X, labels_batch):
    """Encode a batch, accumulate class evidence per group, and decode it.

    Also runs ``importance_sampling`` on the first sample's posteriors as a
    side effect.

    Parameters
    ----------
    X : input batch fed to the (module-level) encoder.
    labels_batch : group labels used to pool class evidence.

    Returns
    -------
    The decoder's reconstructed images for the whole batch.
    """
    # FIX: torch.autograd.Variable is a deprecated no-op since PyTorch 0.4 —
    # the tensor is passed directly.
    style_mu, style_logvar, class_mu, class_logvar = encoder(X)

    grouped_mu, grouped_logvar = accumulate_group_evidence(
        class_mu.data, class_logvar.data, labels_batch, FLAGS.cuda
    )

    importance_sampling(X[0], style_mu[0], style_logvar[0], class_mu[0], class_logvar[0])

    # reconstruct samples
    style_latent_embeddings = reparameterize(training=True, mu=style_mu, logvar=style_logvar)
    class_latent_embeddings = group_wise_reparameterize(
        training=True, mu=grouped_mu, logvar=grouped_logvar,
        labels_batch=labels_batch, cuda=FLAGS.cuda
    )

    reconstructed_images = decoder(style_latent_embeddings, class_latent_embeddings)
    return reconstructed_images
def process(FLAGS, X, labels_batch, encoder, decoder):
    """Run one grouped-VAE pass and return per-group ELBO components.

    Returns ``(elbo, reconstruction_logprob, style_kl, class_kl)``, each
    divided by the number of groups in the batch.
    """
    s_mu, s_logvar, c_mu, c_logvar = encoder(X.cuda())

    g_mu, g_logvar, group_labels, group_sizes = \
        accumulate_group_evidence(FLAGS, c_mu, c_logvar, labels_batch, FLAGS.cuda)

    z_style = reparameterize(training=True, mu=s_mu, logvar=s_logvar)
    z_class, order, sizes = group_wise_reparameterize_each(
        training=True, mu=g_mu, logvar=g_logvar, labels_batch=labels_batch,
        list_groups_labels=group_labels, sizes_group=group_sizes, cuda=FLAGS.cuda)

    # Closed-form KL(q || N(0, I)), summed over all elements; the style term
    # is taken in the same order the group-wise sampler emitted.
    kl_style = 0.5 * (-1 - s_logvar[order, :]
                      + s_mu[order, :].pow(2)
                      + s_logvar[order, :].exp()).sum()
    kl_class = 0.5 * (-1 - g_logvar + g_mu.pow(2) + g_logvar.exp()).sum()

    # Decoder emits per-pixel Gaussian parameters; style embeddings are
    # reordered to line up with z_class.
    mu_x, logvar_x = decoder(z_style[order, :], z_class)
    sigma_x = (torch.exp(logvar_x) + 1e-12) ** 0.5
    sigma_x = sigma_x.view(X.size(0), 784)
    mu_x = mu_x.view(X.size(0), 784)

    # Log-likelihood of each (reordered) input image under its pixel Gaussian.
    pixel_dist = Normal(mu_x, sigma_x)
    per_image_logp = pixel_dist.log_prob(X[order, :].view(X.size(0), 784)).sum(1)
    recon_logprob = per_image_logp.sum(0)

    groups = g_mu.size(0)
    elbo = (recon_logprob - kl_style - kl_class) / groups
    return elbo, recon_logprob / groups, kl_style / groups, kl_class / groups
def training_procedure(FLAGS):
    """Train the grouped VAE (style + class latents) on MNIST.

    Builds encoder/decoder, optionally restores checkpoints, then for each
    epoch backpropagates the style KL, class KL and reconstruction losses
    separately and steps a shared Adam optimizer. Progress is printed,
    appended to ``FLAGS.log_file`` and written to TensorBoard; checkpoints
    are saved every 5 epochs and at the final epoch.
    """
    # model definition
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)
    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    # pre-allocated input buffer, refilled from the loader each iteration
    X = torch.FloatTensor(FLAGS.batch_size, 1, FLAGS.image_size, FLAGS.image_size)

    # add option to run on GPU
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        X = X.cuda()

    # optimizer definition: one Adam over both networks' parameters
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2))

    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create data loader instance
    print('Loading MNIST dataset...')
    mnist = datasets.MNIST(root='mnist', download=True, train=True,
                           transform=transform_config)
    loader = cycle(
        DataLoader(mnist, batch_size=FLAGS.batch_size, shuffle=True,
                   num_workers=0, drop_last=True))

    # initialize summary writer
    writer = SummaryWriter()

    # hoisted loop invariants
    iterations_per_epoch = int(len(mnist) / FLAGS.batch_size)
    kl_normalizer = (FLAGS.batch_size * FLAGS.num_channels *
                     FLAGS.image_size * FLAGS.image_size)

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )

        for iteration in range(iterations_per_epoch):
            # load a mini-batch
            image_batch, labels_batch = next(loader)

            # set zero_grad for the optimizer
            auto_encoder_optimizer.zero_grad()

            X.copy_(image_batch)

            # FIX: Variable() is a deprecated no-op since PyTorch 0.4.
            style_mu, style_logvar, class_mu, class_logvar = encoder(X)
            grouped_mu, grouped_logvar = accumulate_group_evidence(
                class_mu.data, class_logvar.data, labels_batch, FLAGS.cuda)

            # kl-divergence error for style latent space
            style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                 style_logvar.exp()))
            style_kl_divergence_loss /= kl_normalizer
            style_kl_divergence_loss.backward(retain_graph=True)

            # kl-divergence error for class latent space
            class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + grouped_logvar - grouped_mu.pow(2) -
                                 grouped_logvar.exp()))
            class_kl_divergence_loss /= kl_normalizer
            class_kl_divergence_loss.backward(retain_graph=True)

            # reconstruct samples
            """
            sampling from group mu and logvar for each image in mini-batch
            differently makes the decoder consider class latent embeddings as
            random noise and ignore them
            """
            style_latent_embeddings = reparameterize(
                training=True, mu=style_mu, logvar=style_logvar)
            class_latent_embeddings = group_wise_reparameterize(
                training=True, mu=grouped_mu, logvar=grouped_logvar,
                labels_batch=labels_batch, cuda=FLAGS.cuda)

            reconstructed_images = decoder(style_latent_embeddings,
                                           class_latent_embeddings)
            reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                reconstructed_images, X)
            reconstruction_error.backward()

            auto_encoder_optimizer.step()

            if (iteration + 1) % 50 == 0:
                # FIX: .item() replaces the fragile, deprecated
                # .data.storage().tolist()[0] scalar extraction.
                rec = reconstruction_error.item()
                style_kl = style_kl_divergence_loss.item()
                class_kl = class_kl_divergence_loss.item()
                global_step = epoch * (iterations_per_epoch + 1) + iteration

                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))
                print('')
                print('Reconstruction loss: ' + str(rec))
                print('Style KL-Divergence loss: ' + str(style_kl))
                print('Class KL-Divergence loss: ' + str(class_kl))

                # write to log
                with open(FLAGS.log_file, 'a') as log:
                    log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                        epoch, iteration, rec, style_kl, class_kl))

                # write to tensorboard
                writer.add_scalar('Reconstruction loss', rec, global_step)
                writer.add_scalar('Style KL-Divergence loss', style_kl,
                                  global_step)
                writer.add_scalar('Class KL-Divergence loss', class_kl,
                                  global_step)

        # save checkpoints after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.decoder_save))

    # FIX: flush and release the TensorBoard writer (was never closed).
    writer.close()
print() X_i = test_data[i] # 1-d vector l = X_i.size(0) errors = [] for eta in range(10, 90): print('\tRunning eta =', eta) g1 = X_i[0:eta].view(eta, -1).cuda() g2 = X_i[eta:l].view(l-eta, -1).cuda() total_error = 0 for g in [g1, g2]: style_mu, _, class_mu, class_logvar = encoder(g) grouped_mu, _ = accumulate_group_evidence( class_mu.data, class_logvar.data, torch.zeros(g.size(0), 1) ) decoder_style_input = torch.tensor(style_mu, requires_grad = True, device='cuda') decoder_content_input = torch.tensor(grouped_mu[0], requires_grad = True, device='cuda') content = decoder_content_input.expand(g.size(0), 1) optimizer = optim.Adam( [decoder_style_input, decoder_content_input], lr = 0.01 # this may be an important parameter ) for iterations in range(500): optimizer.zero_grad()
DataLoader(paired_mnist, batch_size=FLAGS.num_test_samples, shuffle=True, num_workers=0, drop_last=True)) image_batch, _, labels_batch = next(loader) style_mu, style_logvar, class_mu, class_logvar = encoder( Variable(image_batch)) style_latent_embeddings = reparameterize(training=True, mu=style_mu, logvar=style_logvar) if FLAGS.accumulate_evidence: grouped_mu, grouped_logvar = accumulate_group_evidence( class_mu.data, class_logvar.data, labels_batch, FLAGS.cuda) class_latent_embeddings = group_wise_reparameterize( training=True, mu=grouped_mu, logvar=grouped_logvar, labels_batch=labels_batch, cuda=FLAGS.cuda) else: class_latent_embeddings = reparameterize(training=True, mu=class_mu, logvar=class_logvar) # perform t-SNE embedding vis_data = TSNE(n_components=2, verbose=1, perplexity=30.0, n_iter=1000).fit_transform(
def training_procedure(FLAGS):
    """Train the grouped VAE on the double-multivariate-normal time series.

    Scans ``data/`` for dataset directories whose third ``_``-separated name
    component equals ``theta=-1`` and trains one model per matching dataset,
    backpropagating the style KL, class KL and reconstruction losses
    separately each iteration. Checkpoints are written per dataset; progress
    goes to stdout, ``FLAGS.log_file`` and TensorBoard.
    """
    # model definition
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)
    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    # pre-allocated flat input buffer (784 features per sample)
    X = torch.FloatTensor(FLAGS.batch_size, 784)

    # run on GPU if GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder.to(device=device)
    decoder.to(device=device)
    X = X.to(device=device)

    # optimizer definition: one Adam over both networks' parameters
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2))

    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create data loader instance
    dirs = os.listdir(os.path.join(os.getcwd(), 'data'))
    print('Loading double multivariate normal time series data...')

    for dsname in dirs:
        params = dsname.split('_')
        # BUG FIX: the original `params[2] in ('theta=-1')` was a substring
        # test — ('theta=-1') is just a string, not a tuple — so fragments
        # like 'theta' or '=' also matched. Equality is what was intended.
        if params[2] == 'theta=-1':
            print('Running dataset ', dsname)
            ds = DoubleMulNormal(dsname)
            # ds = experiment3(1000, 50, 3)
            loader = cycle(
                DataLoader(ds, batch_size=FLAGS.batch_size, shuffle=True,
                           drop_last=True))

            # initialize summary writer
            writer = SummaryWriter()

            # hoisted loop invariants
            iterations_per_epoch = int(len(ds) / FLAGS.batch_size)
            kl_normalizer = (FLAGS.batch_size * FLAGS.num_channels *
                             FLAGS.image_size * FLAGS.image_size)

            for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
                print()
                print(
                    'Epoch #' + str(epoch) +
                    '........................................................')

                # FIX: accumulate plain floats. Summing loss *tensors* kept
                # graph references alive across the epoch, and total_loss
                # started as int 0, so .item() crashed on an empty epoch.
                total_loss = 0.0

                for iteration in range(iterations_per_epoch):
                    # load a mini-batch
                    image_batch, labels_batch = next(loader)

                    # set zero_grad for the optimizer
                    auto_encoder_optimizer.zero_grad()

                    X.copy_(image_batch)

                    # FIX: Variable() is a deprecated no-op since PyTorch 0.4.
                    style_mu, style_logvar, class_mu, class_logvar = encoder(X)
                    grouped_mu, grouped_logvar = accumulate_group_evidence(
                        class_mu.data, class_logvar.data, labels_batch,
                        FLAGS.cuda)

                    # kl-divergence error for style latent space
                    style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                        -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                         style_logvar.exp()))
                    style_kl_divergence_loss /= kl_normalizer
                    style_kl_divergence_loss.backward(retain_graph=True)

                    # kl-divergence error for class latent space
                    class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                        -0.5 * torch.sum(1 + grouped_logvar - grouped_mu.pow(2)
                                         - grouped_logvar.exp()))
                    class_kl_divergence_loss /= kl_normalizer
                    class_kl_divergence_loss.backward(retain_graph=True)

                    # reconstruct samples
                    """
                    sampling from group mu and logvar for each image in
                    mini-batch differently makes the decoder consider class
                    latent embeddings as random noise and ignore them
                    """
                    style_latent_embeddings = reparameterize(
                        training=True, mu=style_mu, logvar=style_logvar)
                    class_latent_embeddings = group_wise_reparameterize(
                        training=True, mu=grouped_mu, logvar=grouped_logvar,
                        labels_batch=labels_batch, cuda=FLAGS.cuda)

                    reconstructed_images = decoder(style_latent_embeddings,
                                                   class_latent_embeddings)
                    reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                        reconstructed_images, X)
                    reconstruction_error.backward()

                    total_loss += (style_kl_divergence_loss +
                                   class_kl_divergence_loss +
                                   reconstruction_error).item()

                    auto_encoder_optimizer.step()

                    if (iteration + 1) % 50 == 0:
                        # FIX: .item() replaces the fragile, deprecated
                        # .data.storage().tolist()[0] scalar extraction.
                        rec = reconstruction_error.item()
                        style_kl = style_kl_divergence_loss.item()
                        class_kl = class_kl_divergence_loss.item()
                        global_step = (epoch * (iterations_per_epoch + 1) +
                                       iteration)

                        print('\tIteration #' + str(iteration))
                        print('Reconstruction loss: ' + str(rec))
                        print('Style KL loss: ' + str(style_kl))
                        print('Class KL loss: ' + str(class_kl))

                        # write to log
                        with open(FLAGS.log_file, 'a') as log:
                            log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                                epoch, iteration, rec, style_kl, class_kl))

                        # write to tensorboard
                        writer.add_scalar('Reconstruction loss', rec,
                                          global_step)
                        writer.add_scalar('Style KL-Divergence loss', style_kl,
                                          global_step)
                        writer.add_scalar('Class KL-Divergence loss', class_kl,
                                          global_step)

                    # early checkpoints while still in the first epoch
                    if epoch == 0 and (iteration + 1) % 50 == 0:
                        torch.save(
                            encoder.state_dict(),
                            os.path.join('checkpoints', 'encoder_' + dsname))
                        torch.save(
                            decoder.state_dict(),
                            os.path.join('checkpoints', 'decoder_' + dsname))

                # save checkpoints after every 10 epochs
                if (epoch + 1) % 10 == 0 or (epoch + 1) == FLAGS.end_epoch:
                    torch.save(
                        encoder.state_dict(),
                        os.path.join('checkpoints', 'encoder_' + dsname))
                    torch.save(
                        decoder.state_dict(),
                        os.path.join('checkpoints', 'decoder_' + dsname))

                print('Total loss at current epoch: ', total_loss)
specified_factor_images = [] for i in range(1, 11): specified_factor_images.append(image_array[0][i][:, :, 0]) specified_factor_images = np.asarray(specified_factor_images) specified_factor_images = np.expand_dims(specified_factor_images, axis=3) specified_factor_images = np.transpose(specified_factor_images, (0, 3, 1, 2)) specified_factor_images = torch.FloatTensor(specified_factor_images) specified_factor_images = specified_factor_images.contiguous() if FLAGS.accumulate_evidence: # sample a big batch, accumulate evidence and use that for class embeddings image_batch, _, labels_batch = next(loader) _, __, class_mu, class_logvar = encoder(image_batch) content_mu, content_logvar, list_g, sizes_group = accumulate_group_evidence( FLAGS, class_mu.data, class_logvar.data, labels_batch, FLAGS.cuda) # generate all possible combinations using the encoder and decoder architecture in the grid for row in range(1, 11): style_image = image_array[row][0] style_image = np.transpose(style_image, (2, 0, 1)) style_image = torch.FloatTensor(style_image) style_image = style_image.contiguous() style_image = style_image[0, :, :] style_image = style_image.view(1, 1, 28, 28) style_mu, style_logvar, _, _ = encoder(Variable(style_image)) style_latent_embeddings = reparameterize(training=True, mu=style_mu, logvar=style_logvar) for col in range(1, 11):