def eval_db_agent(env, params):
    # Roll out a trained distilled agent for a number of episodes and report
    # per-episode and running-mean rewards.
    if params['use_preproc']:
        preprocessor = Preprocessor(params['state_dim'], params['history'],
                                    params['use_luminance'], params['resize_shape'])
        params['state_dim'] = preprocessor.state_shape
    else:
        preprocessor = None

    agent = VAE(params['state_dim'], params['action_dim'])
    if params['use_cuda']:
        agent = agent.cuda()
        agent.load_state_dict(torch.load('./agents/{0}_{1}'.format(params['arch'], params['env_name'])))
    else:
        agent.load_state_dict(
            torch.load('./agents/{0}_{1}'.format(params['arch'], params['env_name']), map_location='cpu'))
    agent.eval()

    agent_steps = 0
    episode_rewards = []
    start = time.time()

    for episode in xrange(1, params['num_episodes'] + 1):
        env_state = env.reset()
        episode_reward = 0.0
        for t in xrange(1, params['max_steps'] + 1):
            if params['env_render']:
                env.render()
            state = preprocessor.process_state(env_state) if preprocessor else env_state
            var_state = createVariable(state, use_cuda=params['use_cuda'])
            action, state_val = agent.sample_action_eval(var_state)

            reward = 0.0
            for _ in range(1):  # single-step action repeat
                env_state, r, terminal, _ = env.step(action)
                reward += r
                if terminal:
                    break
            episode_reward += reward
            if terminal:
                break

        episode_rewards.append(episode_reward)
        agent_steps += t
        if preprocessor:
            preprocessor.reset()

        print 'Episode {0} | Total Steps {1} | Total Reward {2} | Mean Reward {3} | Total Time {4}' \
            .format(episode, agent_steps, episode_reward,
                    sum(episode_rewards[-100:]) / float(len(episode_rewards[-100:])),
                    timeSince(start, float(episode) / params['num_episodes']))
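# `createVariable` is used throughout this section but not defined here. A
# minimal sketch of what it plausibly does (the real helper may differ): wrap
# a (possibly numpy) state in a batch-of-one float tensor, using the pre-0.4
# PyTorch Variable API these Python 2 scripts appear to target.
import numpy as np
import torch
from torch.autograd import Variable


def createVariable(state, use_cuda=False):
    tensor = torch.from_numpy(np.asarray(state, dtype=np.float32)).unsqueeze(0)
    if use_cuda:
        tensor = tensor.cuda()
    return Variable(tensor)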
def deep_barley(params):
    # Distill a cached A2C policy into a VAE-based abstraction by supervised
    # training on (state, policy) pairs.
    agent = VAE(params['state_dim'], params['action_dim'])
    agent.train()
    if params['use_cuda']:
        agent = agent.cuda()

    dataset = EpisodeDataset('./out/A2C_{0}_episode.pkl'.format(params['env_name']))
    trainloader = DataLoader(dataset, batch_size=params['batch_size'], shuffle=True, num_workers=4)

    optimizer = torch.optim.Adam(agent.parameters(), lr=params['learning_rate'])
    # optimizer = torch.optim.RMSprop(agent.parameters(), lr=params['learning_rate'])

    for epoch in xrange(1, params['num_epochs'] + 1):
        total_loss = 0.0
        for batch_id, batch in enumerate(trainloader):
            optimizer.zero_grad()
            batch_states, batch_pols = batch['state'], batch['policy']
            if params['use_cuda']:
                batch_pols = batch_pols.cuda()

            if agent.use_concrete:
                # Discrete (concrete/Gumbel) latent code.
                pi_phi, _, phi = agent.forward(createVariable(batch_states, use_cuda=params['use_cuda']))
                phi, _ = phi
                loss, r_loss, p_loss = loss_concrete(batch_pols, pi_phi, phi, params)
            else:
                # Gaussian latent code.
                pi_phi, _, rets = agent.forward(createVariable(batch_states, use_cuda=params['use_cuda']))
                mus, logvars = rets
                loss, r_loss, p_loss = loss_gauss(batch_pols, pi_phi, mus, logvars, params)

            loss.backward()
            total_loss += loss.data
            optimizer.step()

            if (batch_id + 1) % params['print_every'] == 0:
                print '\tBatch {} | Total Loss: {:.6f} | R-Loss {:.6f} | P-Loss {:.6f} | \t[{}/{} ({:.0f}%)]' \
                    .format(batch_id + 1, loss.data, r_loss.data, p_loss.data,
                            batch_id * len(batch_states), len(trainloader.dataset),
                            100. * batch_id / len(trainloader))

        # epoch is already 1-based, so report and checkpoint on it directly
        print 'Epoch {} | Total Loss {:.6f}'.format(epoch, total_loss)
        if epoch % params['save_every'] == 0 or epoch == params['num_epochs']:
            torch.save(agent.state_dict(), './agents/{0}_{1}'.format(params['arch'], params['env_name']))
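# `loss_gauss` (and `loss_concrete`) are called above but not shown. A sketch
# of a Gaussian-latent objective consistent with the call site: r_loss fits
# the student policy pi_phi to the teacher policy, p_loss is the analytic
# KL(q(z|x) || N(0, I)). The cross-entropy form and the 'beta' weight are
# assumptions, not the project's actual loss.
def loss_gauss(batch_pols, pi_phi, mus, logvars, params):
    r_loss = torch.mean(torch.sum(-batch_pols * torch.log(pi_phi + 1e-8), dim=1))
    p_loss = -0.5 * torch.mean(torch.sum(1 + logvars - mus.pow(2) - logvars.exp(), dim=1))
    return r_loss + params.get('beta', 1.0) * p_loss, r_loss, p_loss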
def initialize_network(sim_type, network_type, input_dim, latent_dim, dropout_val=0, device=None):
    """Helper method that grabs the appropriate network type given the parameters.

    Args:
        sim_type (string): "pairwise" or "triplet"
        network_type (string): "vae" or "embenc"
        input_dim (int): input dimension
        latent_dim (int): latent dimension
        dropout_val (float): dropout. Defaults to 0.
        device: torch device passed through to the VAE. Defaults to None.
    """
    print("sim type: ", sim_type)
    print("network_type: ", network_type)
    if network_type == 'vae':
        base_vae = VAE(input_dim, latent_dim, dropout=dropout_val, device=device)
        if sim_type == 'pairwise':
            network = PairwiseVAE(base_vae)
        elif sim_type == 'triplet':
            network = TripletVAE(base_vae)  # was `network ==`, a no-op comparison
    else:
        base_network = EmbeddingEncoder(input_dim=input_dim, latent_dim=latent_dim,
                                        dropout_val=dropout_val)
        if sim_type == 'pairwise':
            network = PairwiseEmbeddingEncoder(base_network)
        elif sim_type == 'triplet':
            network = TripletEmbeddingEncoder(base_network)
    return network
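# Usage sketch for initialize_network; the dimensions and dropout value here
# are illustrative assumptions, not defaults from the source.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = initialize_network(sim_type='triplet', network_type='vae',
                         input_dim=128, latent_dim=16,
                         dropout_val=0.1, device=device)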
def cache_abstraction(env, params):
    # Run one evaluation episode and dump a frame per step, grouping the
    # screenshots by the discrete latent code the agent assigns to each state.
    if os.path.exists('./out/{0}'.format(params['env_name'])):
        shutil.rmtree('./out/{0}'.format(params['env_name']))

    if params['use_preproc']:
        preprocessor = Preprocessor(params['state_dim'], params['history'],
                                    params['use_luminance'], params['resize_shape'])
        params['state_dim'] = preprocessor.state_shape
    else:
        preprocessor = None

    agent = VAE(params['state_dim'], params['action_dim'])
    if params['use_cuda']:
        agent = agent.cuda()
        agent.load_state_dict(torch.load('./agents/{0}_{1}'.format(params['arch'], params['env_name'])))
    else:
        agent.load_state_dict(
            torch.load('./agents/{0}_{1}'.format(params['arch'], params['env_name']), map_location='cpu'))
    agent.eval()

    agent_steps = 0
    episode_rewards = []
    start = time.time()

    for episode in xrange(1):
        env_state = env.reset()
        episode_reward = 0.0
        for t in xrange(1, params['max_steps'] + 1):
            if params['env_render']:
                env.render()
            state = preprocessor.process_state(env_state) if preprocessor else env_state
            var_state = createVariable(state, use_cuda=params['use_cuda'])

            # action, state_val = agent.sample_action_eval(var_state)
            action, state_val, code = agent.sample_action_eval_code(var_state)
            if not os.path.exists('./out/{0}/{1}'.format(params['env_name'], code)):
                os.makedirs('./out/{0}/{1}'.format(params['env_name'], code))
            # NB: requires use_preproc, since preprocessor is None otherwise.
            preprocessor.get_img_state().save('./out/{0}/{1}/{2}.png'.format(params['env_name'], code, t))

            reward = 0.0
            for _ in range(1):  # single-step action repeat
                env_state, r, terminal, _ = env.step(action)
                reward += r
                if terminal:
                    break
            episode_reward += reward
            if terminal:
                break

        episode_rewards.append(episode_reward)
        agent_steps += t
        if preprocessor:
            preprocessor.reset()

        print 'Episode {0} | Total Steps {1} | Total Reward {2} | Mean Reward {3}' \
            .format(episode, agent_steps, episode_reward,
                    sum(episode_rewards[-100:]) / float(len(episode_rewards[-100:])))
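# The RL scripts above all read a shared `params` dict. The keys below are the
# ones those functions access; every value is an illustrative assumption.
params = {
    'env_name': 'PongNoFrameskip-v4',
    'arch': 'vae',
    'state_dim': (210, 160, 3),  # replaced by preprocessor.state_shape when use_preproc is set
    'action_dim': 6,
    'use_preproc': True,
    'history': 4,
    'use_luminance': True,
    'resize_shape': (84, 84),
    'use_cuda': False,
    'env_render': False,
    'num_episodes': 100,
    'max_steps': 10000,
    'batch_size': 128,
    'learning_rate': 1e-4,
    'num_epochs': 50,
    'print_every': 10,
    'save_every': 5,
}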
class Text2MelGraph(Graph):
    def get_batchsize(self):
        return self.hp.batchsize['t2m']  ## TODO: naming?

    def build_model(self):
        self.load_data_in_memory()
        self.add_data(reuse=self.reuse)

        with tf.variable_scope("Text2Mel"):
            # Get S or decoder inputs. (B, T//r, n_mels). This is audio shifted 1 frame to the right.
            self.S = tf.concat((tf.zeros_like(self.mels[:, :1, :]), self.mels[:, :-1, :]), 1)

            # Networks
            if self.hp.text_encoder_type == 'none':
                assert self.hp.merlin_label_dir
                self.K = self.V = self.merlin_label
            elif self.hp.text_encoder_type == 'minimal_feedforward':
                assert self.hp.merlin_label_dir
                self.K = self.V = LinearTransformLabels(self.hp, self.merlin_label,
                                                        training=self.training, reuse=self.reuse)
            else:  ## default DCTTS text encoder
                # Build a latent representation of expressiveness, if uee is set in the
                # config file (unsupervised expressiveness embedding).
                if self.hp.uee != 0:
                    with tf.variable_scope("Audio2Emo"):
                        self.emos = Audio2Emo(self.hp, self.S, training=self.training,
                                              speaker_codes=self.speakers,
                                              reuse=self.reuse)  # (B, T/r, d=8)
                        self.emo_mean = tf.reduce_mean(self.emos, 1)
                        if self.hp.use_vae:
                            self.emo_mean_sampled, mu, log_var = VAE(inputs=self.emo_mean,
                                                                     num_units=self.hp.vae_dim,
                                                                     scope='vae', reuse=self.reuse)
                            self.mu = mu
                            self.log_var = log_var
                            print(self.emo_mean_sampled.get_shape())
                            self.emo_mean_expanded = tf.expand_dims(self.emo_mean_sampled, axis=1)
                            print(self.emo_mean_expanded.get_shape())
                        else:
                            print(self.emo_mean.get_shape())
                            self.emo_mean_expanded = tf.expand_dims(self.emo_mean, axis=1)
                            print(self.emo_mean_expanded.get_shape())
                else:
                    print('No unsupervised expressive embedding')
                    self.emo_mean_expanded = None

                with tf.variable_scope("TextEnc"):
                    self.K, self.V = TextEnc(self.hp, self.L, training=self.training,
                                             emos=self.emo_mean_expanded,
                                             speaker_codes=self.speakers,
                                             reuse=self.reuse)  # (N, Tx, e)

            with tf.variable_scope("AudioEnc"):
                if self.hp.history_type in ['fractional_position_in_phone',
                                            'absolute_position_in_phone']:
                    self.Q = self.position_in_phone
                elif self.hp.history_type == 'minimal_history':
                    sys.exit('Not implemented: hp.history_type=="minimal_history"')
                else:
                    assert self.hp.history_type == 'DCTTS_standard'
                    self.Q = AudioEnc(self.hp, self.S, training=self.training,
                                      speaker_codes=self.speakers, reuse=self.reuse)

            with tf.variable_scope("Attention"):
                # R: (B, T/r, 2d)
                # alignments: (B, N, T/r)
                # max_attentions: (B,)
                if not self.hp.attention_reparam:
                    AppropriateAttention = Attention
                else:
                    AppropriateAttention = Attention_reparametrized

                if self.hp.use_external_durations:
                    self.R, self.alignments, self.max_attentions = FixedAttention(
                        self.hp, self.durations, self.Q, self.V)
                elif self.mode == 'synthesize':  # '==' rather than 'is': string identity is unreliable
                    self.R, self.alignments, self.max_attentions = AppropriateAttention(
                        self.hp, self.Q, self.K, self.V,
                        monotonic_attention=True,
                        prev_max_attentions=self.prev_max_attentions)
                elif self.mode == 'train':
                    self.R, self.alignments, self.max_attentions = AppropriateAttention(
                        self.hp, self.Q, self.K, self.V,
                        monotonic_attention=False,
                        prev_max_attentions=self.prev_max_attentions)
                elif self.mode == 'generate_attention':
                    self.R, self.alignments, self.max_attentions = AppropriateAttention(
                        self.hp, self.Q, self.K, self.V,
                        monotonic_attention=False,
                        prev_max_attentions=None)

            with tf.variable_scope("AudioDec"):
                self.Y_logits, self.Y = AudioDec(self.hp, self.R, training=self.training,
                                                 speaker_codes=self.speakers,
                                                 reuse=self.reuse)  # (B, T/r, n_mels)

    def build_loss(self):
        hp = self.hp

        ## L2 loss (new)
        self.loss_l2 = tf.reduce_mean(tf.squared_difference(self.Y, self.mels))

        # mel L1 loss
        self.loss_mels = tf.reduce_mean(tf.abs(self.Y - self.mels))

        # mel binary divergence loss
        self.loss_bd1 = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.Y_logits, labels=self.mels))
        if not hp.squash_output_t2m:
            self.loss_bd1 = tf.zeros_like(self.loss_bd1)
            print("binary divergence loss disabled because squash_output_t2m==False")

        # guided attention loss
        self.A = tf.pad(self.alignments, [(0, 0), (0, hp.max_N), (0, hp.max_T)],
                        mode="CONSTANT", constant_values=-1.)[:, :hp.max_N, :hp.max_T]
        if hp.attention_guide_dir:
            self.gts = tf.pad(self.gts, [(0, 0), (0, hp.max_N), (0, hp.max_T)],
                              mode="CONSTANT",
                              constant_values=1.0)[:, :hp.max_N, :hp.max_T]  ## TODO: check adding penalty here (1.0 is the right thing)
        self.attention_masks = tf.to_float(tf.not_equal(self.A, -1))
        self.loss_att = tf.reduce_sum(
            tf.abs(self.A * self.gts) * self.attention_masks
        )  ## (B, Letters, Frames) * (Letters, Frames) -- broadcasting first adds singleton dimensions to the left until rank is matched.
        self.mask_sum = tf.reduce_sum(self.attention_masks)
        self.loss_att /= self.mask_sum

        # total loss
        try:
            ## new way to configure loss weights -- TODO: ensure all configs use the new pattern, and remove the 'except' branch
            self.loss = (hp.loss_weights['t2m']['L1'] * self.loss_mels) + \
                        (hp.loss_weights['t2m']['binary_divergence'] * self.loss_bd1) + \
                        (hp.loss_weights['t2m']['attention'] * self.loss_att) + \
                        (hp.loss_weights['t2m']['L2'] * self.loss_l2)
        except (AttributeError, KeyError):  # fall back to the old flat loss-weight attributes
            self.lw_mel = hp.lw_mel
            self.lw_bd1 = hp.lw_bd1
            self.lw_att = hp.lw_att
            self.lw_t2m_l2 = self.hp.lw_t2m_l2
            self.loss = (self.lw_mel * self.loss_mels) + (self.lw_bd1 * self.loss_bd1) + \
                        (self.lw_att * self.loss_att) + (self.lw_t2m_l2 * self.loss_l2)

        if self.hp.use_vae and self.hp.if_vae_use_loss:
            # KL divergence of the VAE posterior from the unit Gaussian prior,
            # annealed over training by vae_weight.
            self.ki_loss = -0.5 * tf.reduce_sum(
                1 + self.log_var - tf.pow(self.mu, 2) - tf.exp(self.log_var))
            self.vae_loss_weight = vae_weight(hp, self.global_step)
            self.loss += self.ki_loss * self.vae_loss_weight
            # loss_components attribute is used for reporting to log (osw)
            self.loss_components = [self.loss, self.loss_mels, self.loss_bd1,
                                    self.loss_att, self.loss_l2, self.ki_loss]
        else:
            # loss_components attribute is used for reporting to log (osw)
            self.loss_components = [self.loss, self.loss_mels, self.loss_bd1,
                                    self.loss_att, self.loss_l2]

        # summaries used for reporting to tensorboard (kp)
        tf.summary.scalar('train/loss_mels', self.loss_mels)
        tf.summary.scalar('train/loss_bd1', self.loss_bd1)
        tf.summary.scalar('train/loss_att', self.loss_att)
        if self.hp.use_vae and self.hp.if_vae_use_loss:
            tf.summary.scalar('train/ki_loss', self.ki_loss)
        tf.summary.image('train/mel_gt',
                         tf.expand_dims(tf.transpose(self.mels[:1], [0, 2, 1]), -1))
        tf.summary.image('train/mel_hat',
                         tf.expand_dims(tf.transpose(self.Y[:1], [0, 2, 1]), -1))
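# `vae_weight` is called in build_loss but not defined in this section. A
# minimal KL-annealing schedule consistent with its call signature; the
# hyperparameter names (vae_warmup_steps, vae_weight_max) are assumptions,
# not the project's config keys.
def vae_weight(hp, global_step):
    # Ramp the KL weight linearly from 0 to vae_weight_max, then hold it.
    step = tf.to_float(global_step)
    ramp = tf.minimum(step / float(hp.vae_warmup_steps), 1.0)
    return ramp * hp.vae_weight_max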
    print('====> Test set loss: {:.4f}'.format(test_loss))


def save_checkpoint(state, filename='checkpoints/vae.pth'):
    torch.save(state, filename)


batch_size = 64
epochs = 10000

train_loader = DataLoader(BufferDataset(), batch_size=batch_size)
test_loader = DataLoader(BufferDataset(train=False), batch_size=8)

model = VAE().to(0)  # device index 0
# chkpt = torch.load('checkpoints/vae.pth')['state_dict']
# model.load_state_dict(chkpt)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        # Decode 64 prior samples (20-dim latent) and save them as a grid of
        # 10x13 single-channel images.
        sample = torch.randn(64, 20).to(0)
        sample = model.decode(sample).cpu()
        save_image(sample.view(64, 1, 10, 13),
                   'results/sample_' + str(epoch) + '.png')
    save_checkpoint({
        'epoch': epoch,
        'state_dict': model.state_dict(),
    })
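# `train` and `test` are called above but only test's final print survives in
# this excerpt. A minimal sketch of `train` under standard VAE assumptions
# (BCE reconstruction + analytic KL; the forward returning (recon, mu, logvar)
# is an assumption, not the file's actual definition).
import torch.nn.functional as F


def train(epoch):
    model.train()
    train_loss = 0.0
    for data in train_loader:
        data = data.to(0)
        optimizer.zero_grad()
        recon, mu, logvar = model(data)
        bce = F.binary_cross_entropy(recon, data.view(data.size(0), -1), reduction='sum')
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss = bce + kld
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))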
class TrainerInfoVAE():
    def __init__(self, config):
        self.config = config
        lr = config['lr']
        # Initiate the networks
        imgconf = config['image']
        self.vae = VAE(imgconf, config['gen'], config['latent'])
        self.dis = QNet(config['dis'], imgconf['image_size'], imgconf['image_dim'], config['latent'])
        self.vae_optim = optim.Adam(self.vae.parameters(), lr=lr)
        self.dis_optim = optim.Adam(self.dis.parameters(), lr=lr)
        self.vae_scheduler = get_scheduler(self.vae_optim, config)
        self.dis_scheduler = get_scheduler(self.dis_optim, config)
        self.mse_crit = nn.MSELoss()
        self.bce_vae = nn.BCELoss()
        self.bce_dis = nn.BCELoss()

    def update_learning_rate(self):
        if self.vae_scheduler is not None:
            self.vae_scheduler.step()
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()

    def cuda(self, device=0):
        members = [attr for attr in dir(self) if isinstance(getattr(self, attr), torch.nn.Module)]
        for m in members:
            getattr(self, m).cuda(device)

    def update_vae(self, images, config):
        self.vae_optim.zero_grad()
        self.dis_optim.zero_grad()
        inputs = images
        batch_size = images.size(0)

        ##### vae part
        recons, latent, samples = self.vae(inputs)
        kl_loss = self.compute_KL_loss(latent) * config['kl_w']
        rec_loss = self.mse_crit(recons, images)

        ##### gan part
        prior_samples = self.vae.prior.sample_prior(batch_size, images.device)
        geners = self.vae.decoder(prior_samples)
        fake = torch.cat([geners, recons], 0)
        q_dist = self.dis(fake)
        ##### need to separate prior likelihood from latent for full mi loss

        ##### info part
        inf_dim = config['latent']['inform_dim']
        inf_code = torch.cat([prior_samples[:, :inf_dim], samples[:, :inf_dim]], 0)
        mi_loss = self.compute_mi(inf_code, q_dist) * config['inf_w']

        total_loss = rec_loss + kl_loss + mi_loss
        total_loss.backward()
        self.vae_optim.step()

        self.vae_kl_loss = kl_loss.item()
        self.vae_rec_loss = rec_loss.item()
        self.vae_total_loss = total_loss.item()
        self.vae_inf_loss = mi_loss.item()
        self.encoder_samples = samples.data
        return recons

    def compute_mi(self, samples, q_dist_raw):
        # So far computes only the entropy of Q(c|X).
        q_dist = self.vae.prior.activate(q_dist_raw)
        qx_li = self.vae.prior.log_li(samples, q_dist)
        qx_ent = torch.mean(-qx_li)
        return qx_ent

    def compute_KL_loss(self, distribution):
        # Analytic KL(N(mu, sigma^2) || N(0, I)), summed over latent
        # dimensions and averaged over the batch.
        mu_2 = torch.pow(distribution['mean'], 2)
        sigma_2 = torch.pow(distribution['std'], 2)
        return (-0.5 * (1 + torch.log(sigma_2) - mu_2 - sigma_2).sum(1)).mean()

    def update_dis(self, images, config):
        self.vae_optim.zero_grad()
        self.dis_optim.zero_grad()
        batch_size = images.size(0)
        inputs = images
        with torch.no_grad():
            recons, latent, samples = self.vae(inputs)
        samples = samples.detach()
        # device argument added for consistency with update_vae
        prior_samples = self.vae.prior.sample_prior(batch_size, images.device).detach()
        geners = self.vae.decoder(prior_samples)
        recons = recons.detach()
        geners = geners.detach()
        fake = torch.cat([geners, recons], 0)
        q_dist = self.dis(fake)
        inf_dim = config['latent']['inform_dim']
        inf_code = torch.cat([prior_samples[:, :inf_dim], samples[:, :inf_dim]], 0)
        mi_loss = self.compute_mi(inf_code, q_dist)
        dis_total_loss = mi_loss
        dis_total_loss.backward()
        self.dis_optim.step()
        self.dis_mi_loss = mi_loss.item()

    def save(self, snapshot_dir, iterations):
        # Save the generator weights.
        vae_name = os.path.join(snapshot_dir, 'vae_%08d.pt' % (iterations + 1))
        torch.save(self.vae.state_dict(), vae_name)

    def resume(self, checkpoint_dir, hyperparameters):
        # Load the latest generator checkpoint.
        last_model_name = get_model_list(checkpoint_dir, "vae")
        state_dict = torch.load(last_model_name)
        self.vae.load_state_dict(state_dict)

    def get_latent_visualization(self, image_directory, postfix, images, prior_samples):
        with torch.no_grad():
            recons, latent, samples = self.vae(images)
            start = latent['mean'].clone().detach()
            out_list1 = [recons[0:1]]
            out_list2 = [recons[1:2]]
            out_list3 = [recons[2:3]]
            for i in range(10):
                # Sweep the first latent dimension from -5 to +4 for three images.
                start.data[:3, 0] = -5. + i
                out = self.vae.decoder(start)
                out_list1.append(out[0:1])
                out_list2.append(out[1:2])
                out_list3.append(out[2:3])
            out_list = out_list1 + out_list2 + out_list3
            recons = torch.cat(out_list, 0)
            file_name = '%s/recons_intp%s.jpg' % (image_directory, postfix)
            vutils.save_image(recons / 2 + 0.5, file_name, nrow=11)
            #######################################
            geners = self.vae.decoder(prior_samples)
            out_list1 = [geners[0:1]]
            out_list2 = [geners[1:2]]
            out_list3 = [geners[2:3]]
            start = prior_samples.clone().detach()
            for i in range(10):
                start.data[:3, 0] = -5. + i
                out = self.vae.decoder(start)
                out_list1.append(out[0:1])
                out_list2.append(out[1:2])
                out_list3.append(out[2:3])
            out_list = out_list1 + out_list2 + out_list3
            geners = torch.cat(out_list, 0)
            file_name = '%s/geners_intp%s.jpg' % (image_directory, postfix)
            vutils.save_image(geners / 2 + 0.5, file_name, nrow=11)
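# Hypothetical training loop for TrainerInfoVAE; the loader, snapshot paths,
# and the 'snapshot_save_iter' key are assumptions, while 'lr', 'kl_w',
# 'inf_w', and ['latent']['inform_dim'] mirror keys the class actually reads.
trainer = TrainerInfoVAE(config)
trainer.cuda()
for it, images in enumerate(train_loader):
    images = images.cuda()
    trainer.update_dis(images, config)           # fit the Q-network (MI head)
    recons = trainer.update_vae(images, config)  # VAE + KL + info terms
    trainer.update_learning_rate()
    if (it + 1) % config['snapshot_save_iter'] == 0:
        trainer.save(snapshot_dir, it)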
class TrainerVAE():
    def __init__(self, config):
        self.config = config
        lr = config['lr']
        # Initiate the networks
        imgconf = config['image']
        self.vae = VAE(imgconf, config['gen'], config['latent'])
        self.vae_optim = optim.Adam(self.vae.parameters(), lr=lr)
        self.vae_scheduler = get_scheduler(self.vae_optim, config)
        self.mse_crit = nn.MSELoss()

    def update_learning_rate(self):
        if self.vae_scheduler is not None:
            self.vae_scheduler.step()

    def cuda(self, device=0):
        members = [attr for attr in dir(self) if isinstance(getattr(self, attr), torch.nn.Module)]
        for m in members:
            getattr(self, m).cuda(device)

    def update_vae(self, images, config):
        self.vae_optim.zero_grad()
        inputs = images
        batch_size = images.size(0)
        # vae part update
        recons, latent, samples = self.vae(inputs)
        kl_loss = self.compute_KL_loss(latent) * config['kl_w']
        rec_loss = self.mse_crit(recons, images)
        total_loss = rec_loss + kl_loss
        total_loss.backward()
        self.vae_optim.step()
        self.vae_kl_loss = kl_loss.item()
        self.vae_rec_loss = rec_loss.item()
        self.vae_total_loss = total_loss.item()
        self.encoder_samples = samples.data
        return recons

    def compute_KL_loss(self, distribution):
        # Analytic KL(N(mu, sigma^2) || N(0, I)), summed over latent
        # dimensions and averaged over the batch.
        mu_2 = torch.pow(distribution['mean'], 2)
        sigma_2 = torch.pow(distribution['std'], 2)
        return (-0.5 * (1 + torch.log(sigma_2) - mu_2 - sigma_2).sum(1)).mean()

    def get_latent_visualization(self, image_directory, postfix, images, prior_samples):
        with torch.no_grad():
            recons, latent, samples = self.vae(images)
            start = latent['mean'].clone().detach()
            out_list1 = [recons[0:1]]
            out_list2 = [recons[1:2]]
            out_list3 = [recons[2:3]]
            for i in range(10):
                # Sweep the first latent dimension from -5 to +4 for three images.
                start.data[:3, 0] = -5. + i
                out = self.vae.decoder(start)
                out_list1.append(out[0:1])
                out_list2.append(out[1:2])
                out_list3.append(out[2:3])
            out_list = out_list1 + out_list2 + out_list3
            recons = torch.cat(out_list, 0)
            file_name = '%s/recons_intp%s.jpg' % (image_directory, postfix)
            vutils.save_image(recons / 2 + 0.5, file_name, nrow=11)
            #######################################
            geners = self.vae.decoder(prior_samples)
            out_list1 = [geners[0:1]]
            out_list2 = [geners[1:2]]
            out_list3 = [geners[2:3]]
            start = prior_samples.clone().detach()
            for i in range(10):
                start.data[:3, 0] = -5. + i
                out = self.vae.decoder(start)
                out_list1.append(out[0:1])
                out_list2.append(out[1:2])
                out_list3.append(out[2:3])
            out_list = out_list1 + out_list2 + out_list3
            geners = torch.cat(out_list, 0)
            file_name = '%s/geners_intp%s.jpg' % (image_directory, postfix)
            vutils.save_image(geners / 2 + 0.5, file_name, nrow=11)

    def save(self, snapshot_dir, iterations):
        # Save the generator weights.
        vae_name = os.path.join(snapshot_dir, 'vae_%08d.pt' % (iterations + 1))
        torch.save(self.vae.state_dict(), vae_name)

    def resume(self, checkpoint_dir, hyperparameters):
        # Load the latest generator checkpoint and recover its iteration
        # count from the file name (vae_%08d.pt).
        last_model_name = get_model_list(checkpoint_dir, "vae")
        state_dict = torch.load(last_model_name)
        self.vae.load_state_dict(state_dict)
        iterations = int(last_model_name[-11:-3])
        return iterations  # was computed but never returned
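# Quick numerical check (shapes are illustrative): compute_KL_loss should
# agree with torch.distributions' analytic KL for a diagonal Gaussian against
# the standard normal prior.
import torch
from torch.distributions import Normal, kl_divergence

mean = torch.randn(4, 8)
std = torch.rand(4, 8) + 0.1
# self is unused in the formula, so None can stand in for this check
manual = TrainerVAE.compute_KL_loss(None, {'mean': mean, 'std': std})
reference = kl_divergence(Normal(mean, std), Normal(0.0, 1.0)).sum(1).mean()
assert torch.allclose(manual, reference, atol=1e-5)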
class TrainerVAEGAN():
    def __init__(self, config):
        self.config = config
        lr = config['lr']
        # Initiate the networks
        imgconf = config['image']
        self.vae = VAE(imgconf, config['gen'], config['latent'])
        self.dis = DiscriminatorVAE(config['dis'], imgconf['image_size'], imgconf['image_dim'])
        self.vae_optim = optim.Adam(self.vae.parameters(), lr=lr)
        self.dis_optim = optim.Adam(self.dis.parameters(), lr=lr)
        self.vae_scheduler = get_scheduler(self.vae_optim, config)
        self.dis_scheduler = get_scheduler(self.dis_optim, config)
        self.mse_crit = nn.MSELoss()
        self.bce_vae = nn.BCELoss()
        self.bce_dis = nn.BCELoss()

    def update_learning_rate(self):
        if self.vae_scheduler is not None:
            self.vae_scheduler.step()
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()

    def cuda(self, device=0):
        members = [attr for attr in dir(self) if isinstance(getattr(self, attr), torch.nn.Module)]
        for m in members:
            getattr(self, m).cuda(device)

    def update_vae(self, images, config):
        # Dispatch on whether adversarial gradients flow through the whole
        # VAE or only its decoder.
        if not config['vae_adv_full']:
            return self.__update_vae_adv_dec(images, config)
        else:
            return self.__update_vae_adv_all(images, config)

    def __update_vae_adv_dec(self, images, config):
        # Passes adversarial gradients to the decoder only.
        self.vae_optim.zero_grad()
        inputs = images
        batch_size = images.size(0)
        # vae part update
        recons, latent, samples = self.vae(inputs)
        kl_loss = self.compute_KL_loss(latent) * config['kl_w']
        rec_loss = self.mse_crit(recons, images)
        total_loss = rec_loss + kl_loss
        total_loss.backward()
        self.vae_optim.step()
        # gan part update
        self.vae_optim.zero_grad()
        self.dis_optim.zero_grad()
        # pass gan gradient only to decoder
        samples = samples.detach()
        prior_samples = self.vae.prior.sample_prior(batch_size, images.device)
        all_samples = torch.cat([samples, prior_samples], 0)
        geners = self.vae.decoder(all_samples)
        fake_out = self.dis(geners).view(geners.size(0), -1)
        adv_loss = self.dis.calc_gen_loss(fake_out) * config['adv_w']
        adv_loss.backward()
        self.vae_optim.step()
        # adv loss intentionally not added to the reported total loss
        self.vae_kl_loss = kl_loss.item()
        self.vae_rec_loss = rec_loss.item()
        self.vae_total_loss = total_loss.item()
        self.vae_adv_loss = adv_loss.item()
        self.encoder_samples = samples.data
        return recons

    def __update_vae_adv_all(self, images, config):
        # Passes adversarial gradients through the whole VAE.
        self.vae_optim.zero_grad()
        self.dis_optim.zero_grad()
        inputs = images
        batch_size = images.size(0)
        # vae part
        recons, latent, samples = self.vae(inputs)
        kl_loss = self.compute_KL_loss(latent) * config['kl_w']
        rec_loss = self.mse_crit(recons, images)
        # gan part
        prior_samples = self.vae.prior.sample_prior(batch_size, images.device)
        geners = self.vae.decoder(prior_samples)
        fake = torch.cat([geners, recons], 0)
        fake_out = self.dis(fake).view(fake.size(0), -1)
        adv_loss = self.dis.calc_gen_loss(fake_out) * config['adv_w']
        total_loss = rec_loss + kl_loss + adv_loss
        total_loss.backward()
        self.vae_optim.step()
        self.vae_kl_loss = kl_loss.item()
        self.vae_rec_loss = rec_loss.item()
        self.vae_total_loss = total_loss.item()
        self.vae_adv_loss = adv_loss.item()
        self.encoder_samples = samples.data
        return recons

    def compute_KL_loss(self, distribution):
        # Analytic KL(N(mu, sigma^2) || N(0, I)), summed over latent
        # dimensions and averaged over the batch.
        mu_2 = torch.pow(distribution['mean'], 2)
        sigma_2 = torch.pow(distribution['std'], 2)
        return (-0.5 * (1 + torch.log(sigma_2) - mu_2 - sigma_2).sum(1)).mean()

    def update_dis(self, images, config):
        self.vae_optim.zero_grad()
        self.dis_optim.zero_grad()
        batch_size = images.size(0)
        inputs = images
        with torch.no_grad():
            recons, _, _ = self.vae(inputs)
        # device argument added for consistency with update_vae
        prior_samples = self.vae.prior.sample_prior(batch_size, images.device)
        geners = self.vae.decoder(prior_samples)
        recons = recons.detach()
        geners = geners.detach()
        fake = torch.cat([recons, geners], 0)
        fake_out = self.dis(fake)
        real_out = self.dis(images)
        dis_total_loss, fake_loss, real_loss = self.dis.calc_dis_loss(fake_out, real_out)
        dis_total_loss *= config['adv_w']
        dis_total_loss.backward()
        self.dis_optim.step()
        self.dis_fake_loss = fake_loss.item()
        self.dis_real_loss = real_loss.item()
        self.dis_total_loss = dis_total_loss.item()

    def save(self, snapshot_dir, iterations):
        # Save the generator weights.
        vae_name = os.path.join(snapshot_dir, 'vae_%08d.pt' % (iterations + 1))
        torch.save(self.vae.state_dict(), vae_name)

    def resume(self, checkpoint_dir, hyperparameters):
        # Load the latest generator checkpoint.
        last_model_name = get_model_list(checkpoint_dir, "vae")
        state_dict = torch.load(last_model_name)
        self.vae.load_state_dict(state_dict)

    def get_latent_visualization(self, image_directory, postfix, images, prior_samples):
        with torch.no_grad():
            recons, latent, samples = self.vae(images)
            start = latent['mean'].clone().detach()
            out_list1 = [recons[0:1]]
            out_list2 = [recons[1:2]]
            out_list3 = [recons[2:3]]
            for i in range(10):
                # Sweep the first latent dimension from -5 to +4 for three images.
                start.data[:3, 0] = -5. + i
                out = self.vae.decoder(start)
                out_list1.append(out[0:1])
                out_list2.append(out[1:2])
                out_list3.append(out[2:3])
            out_list = out_list1 + out_list2 + out_list3
            recons = torch.cat(out_list, 0)
            file_name = '%s/recons_intp%s.jpg' % (image_directory, postfix)
            vutils.save_image(recons / 2 + 0.5, file_name, nrow=11)
            #######################################
            geners = self.vae.decoder(prior_samples)
            out_list1 = [geners[0:1]]
            out_list2 = [geners[1:2]]
            out_list3 = [geners[2:3]]
            start = prior_samples.clone().detach()
            for i in range(10):
                start.data[:3, 0] = -5. + i
                out = self.vae.decoder(start)
                out_list1.append(out[0:1])
                out_list2.append(out[1:2])
                out_list3.append(out[2:3])
            out_list = out_list1 + out_list2 + out_list3
            geners = torch.cat(out_list, 0)
            file_name = '%s/geners_intp%s.jpg' % (image_directory, postfix)
            vutils.save_image(geners / 2 + 0.5, file_name, nrow=11)
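# DiscriminatorVAE's loss helpers are not shown in this section. A minimal
# BCE-with-logits sketch consistent with how calc_gen_loss / calc_dis_loss
# are called above (raw-logit discriminator outputs are an assumption):
bce_logits = nn.BCEWithLogitsLoss()


def calc_gen_loss(fake_out):
    # Generator update: push fake outputs toward the "real" label.
    return bce_logits(fake_out, torch.ones_like(fake_out))


def calc_dis_loss(fake_out, real_out):
    # Discriminator update: fakes labelled 0, reals labelled 1; returns the
    # total plus both components, matching the triple unpacked in update_dis.
    fake_loss = bce_logits(fake_out, torch.zeros_like(fake_out))
    real_loss = bce_logits(real_out, torch.ones_like(real_out))
    return fake_loss + real_loss, fake_loss, real_loss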