def front(self):
    # Front-End

    # x : [Btot, 1, L, 1]
    # Equivalent to B_tot batches of image of height = 1, width = L and 1 channel -> for Conv1D with Conv2D
    input_front = tf.reshape(self.x, [self.B_tot, 1, self.L, 1])

    # Filter [filter_height, filter_width, input_channels, output_channels] = [1, W, 1, N]
    # self.window_filter = get_scope_variable('window', 'w', shape=[self.window], initializer=tf.contrib.layers.xavier_initializer_conv2d())
    # self.bases = get_scope_variable('bases', 'bases', shape=[self.window, self.N], initializer=tf.contrib.layers.xavier_initializer_conv2d())
    # self.conv_filter = tf.reshape(tf.expand_dims(self.window_filter, 1) * self.bases, [1, self.window, 1, self.N])
    self.conv_filter = get_scope_variable('filters_front', 'filters_front',
                                          shape=[1, self.window, 1, self.N],
                                          initializer=tf.contrib.layers.xavier_initializer_conv2d())

    # 1 Dimensional convolution along T axis with a window length = self.window
    # And N = 256 filters -> Create a [Btot, 1, T, N]
    self.X = tf.nn.conv2d(input_front, self.conv_filter, strides=[1, 1, 1, 1],
                          padding="SAME", name='Conv_STFT')

    # Reshape to Btot batches of T x N images with 1 channel
    self.X = tf.reshape(self.X, [self.B_tot, -1, self.N, 1])
    self.T = tf.shape(self.X)[1]

    # Max Pooling with argmax for unpooling later in the back-end layer
    # Along the T axis (time)
    self.y, argmax = tf.nn.max_pool_with_argmax(self.X, (1, self.max_pool_value, 1, 1),
                                                strides=[1, self.max_pool_value, 1, 1],
                                                padding="SAME", name='output')

    y_shape = tf.shape(self.y)
    y = tf.reshape(self.y, [self.B_tot, y_shape[1] * y_shape[2]])
    self.p_hat = tf.reduce_mean(tf.abs(y), 0)
    self.sparse_constraint = tf.reduce_sum(kl_div(self.p, self.p_hat))

    return self.y, argmax
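# --- Hypothetical helpers assumed by front() above (not part of the original
# listing): a minimal sketch of get_scope_variable (create-or-reuse a variable
# inside a named scope) and of the sparsity penalty kl_div(p, p_hat), written
# here as the usual sparse-autoencoder KL between a target activation level p
# and the observed mean activation p_hat (assumed to lie in (0, 1)). The real
# implementations in the codebase may differ.

import tensorflow as tf


def get_scope_variable(scope_name, var_name, shape=None, initializer=None):
    # Create the variable on the first call and reuse it on later calls, so
    # repeated graph constructions share the same front-end filters.
    with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
        return tf.get_variable(var_name, shape=shape, initializer=initializer)


def kl_div(p, p_hat, eps=1e-8):
    # Element-wise KL divergence between Bernoulli distributions with means
    # p (target sparsity) and p_hat (observed mean activation).
    p_hat = tf.clip_by_value(p_hat, eps, 1.0 - eps)
    return p * tf.log(p / p_hat) + (1.0 - p) * tf.log((1.0 - p) / (1.0 - p_hat))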
def i_max_batch(index, mu, log_var):
    # KL term restricted to the latent dimensions selected by `index`:
    # a single selected dimension reduces to the per-dimension KL,
    # otherwise the multi-dimensional KL is used.
    mu_syn = mu[:, index]
    log_var_syn = log_var[:, index]

    if len(mu_syn.size()) == 1:
        i_max = kl_div_uni_dim(mu_syn, log_var_syn).mean()
    else:
        i_max = kl_div(mu_syn, log_var_syn)

    return i_max
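# --- Hypothetical PyTorch helpers assumed by i_max_batch() and by the training
# loops below (distinct from the TensorFlow-side kl_div sketched earlier):
# kl_div_uni_dim as the element-wise KL of a diagonal Gaussian
# q(z|x) = N(mu, exp(log_var)) against the unit Gaussian prior, and kl_div as
# that quantity summed over latent dimensions and averaged over the batch.
# A minimal sketch only; the codebase's own definitions may differ.

import torch


def kl_div_uni_dim(mu, log_var):
    # Per-element KL( N(mu, sigma^2) || N(0, 1) ), no reduction.
    return -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())


def kl_div(mu, log_var):
    # Sum over latent dimensions, mean over the batch -> scalar.
    return kl_div_uni_dim(mu, log_var).sum(dim=1).mean()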
def S_metric_1A(mu, logvar, z_dim, batch_size):
    alpha = 1.5
    Smax = torch.empty((1, batch_size))

    for s in range(batch_size):
        mu_s = mu[s, :].view(1, -1)
        logvar_s = logvar[s, :].view(1, -1)

        # get the argmax
        index = greedy_policy_Smax_discount(z_dim, mu_s, logvar_s, alpha=0.8)
        print("sample {}, index {}".format(s, index))

        # get the dims:
        mu_syn = mu_s[:, index]
        logvar_syn = logvar_s[:, index]

        if len(mu_syn.size()) == 1:
            I_m = kl_div_uni_dim(mu_syn, logvar_syn).mean()
            # print("here")
        else:
            I_m = kl_div(mu_syn, logvar_syn)

        Smax[0, s] = I_m

    print("Smax {}".format(Smax))
    print("Smax size {}".format(Smax.size()))
    print("Smax requires grad {}".format(Smax.requires_grad))

    I_max = Smax.mean()
    print("I_max {}".format(I_max))
    print("I_max {}".format(I_max.requires_grad))

    syn_loss = alpha * I_max
    return syn_loss
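# Illustrative call of S_metric_1A with random statistics. The shapes follow
# the indexing above (mu and logvar are [batch_size, z_dim]); the values are
# placeholders, and greedy_policy_Smax_discount is assumed to be provided by
# the same codebase.
mu_example = torch.randn(64, 10)       # batch_size = 64, z_dim = 10
logvar_example = torch.randn(64, 10)
syn_loss_example = S_metric_1A(mu_example, logvar_example, z_dim=10, batch_size=64)
print("synergy loss: {:.4f}".format(syn_loss_example.item()))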
def front(self):
    # Front-End

    # x : [Btot, 1, L, 1]
    # Equivalent to B_tot batches of image of height = 1, width = L and 1 channel -> for Conv1D with Conv2D
    input_front = tf.reshape(self.x, [self.B_tot, 1, self.L, 1])

    # Filter [filter_height, filter_width, input_channels, output_channels] = [1, W, 1, N]
    self.window_filter = get_scope_variable(
        'window', 'w', shape=[self.window],
        initializer=tf.contrib.layers.xavier_initializer_conv2d())
    self.bases = get_scope_variable(
        'bases', 'bases', shape=[self.window, self.N],
        initializer=tf.contrib.layers.xavier_initializer_conv2d())
    self.conv_filter = tf.reshape(
        tf.abs(tf.expand_dims(self.window_filter, 1)) * self.bases,
        [1, self.window, 1, self.N])
    # self.conv_filter = get_scope_variable('filters_front', 'filters_front', shape=[1, self.window, 1, self.N])
    variable_summaries(self.conv_filter)

    # 1 Dimensional convolution along T axis with a window length = self.window
    # And N = 256 filters -> Create a [Btot, 1, T, N]
    self.T = tf.shape(input_front)[2]

    if self.with_max_pool:
        self.X = tf.nn.conv2d(input_front, self.conv_filter, strides=[1, 1, 1, 1],
                              padding="SAME", name='Conv_STFT')
        self.y, self.argmax = tf.nn.max_pool_with_argmax(
            self.X, [1, 1, self.max_pool_value, 1],
            strides=[1, 1, self.max_pool_value, 1],
            padding="SAME", name='output')
        print(self.argmax)
    elif self.with_average_pool:
        self.X = tf.nn.conv2d(input_front, self.conv_filter, strides=[1, 1, 1, 1],
                              padding="SAME", name='Conv_STFT')
        # self.y = tf.nn.avg_pool(self.X, [1, 1, self.max_pool_value, 1], [1, 1, self.max_pool_value, 1], padding="SAME")
        self.y = tf.layers.average_pooling2d(self.X, (1, self.max_pool_value),
                                             strides=(1, self.max_pool_value),
                                             name='output')
    else:
        self.y = tf.nn.conv2d(input_front, self.conv_filter,
                              strides=[1, 1, self.max_pool_value, 1],
                              padding="SAME", name='Conv_STFT')

    # Reshape to Btot batches of T x N images with 1 channel
    # [Btot, 1, T_pool, N] -> [Btot, T_pool, N, 1]
    self.y = tf.transpose(self.y, [0, 2, 3, 1], name='output')
    tf.summary.image('front/output', self.y, max_outputs=3)

    # Sparsity constraint on the pooled activations
    y_shape = tf.shape(self.y)
    y = tf.reshape(self.y, [self.B_tot, y_shape[1] * y_shape[2]])
    self.p_hat = tf.reduce_sum(tf.abs(y), 0)
    self.sparse_constraint = tf.reduce_sum(kl_div(self.p, self.p_hat))

    return self.y
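# --- Hypothetical back-end counterpart (not in the original listing): the
# argmax saved by tf.nn.max_pool_with_argmax above is kept for unpooling in
# the back-end layer. A minimal sketch of such an unpooling op, assuming NHWC
# tensors and the TF1 default flattening of argmax indices
# (include_batch_in_index=False, i.e. indices are flattened per batch element).

def unpool_with_argmax(pooled, argmax, output_shape):
    # output_shape: shape of the tensor that was pooled, e.g. tf.shape(self.X).
    output_shape = tf.cast(output_shape, tf.int64)
    batch = output_shape[0]
    flat_out = output_shape[1] * output_shape[2] * output_shape[3]
    neg_one = tf.constant(-1, dtype=tf.int64)

    # Flatten the pooled values and their target positions per batch element.
    pooled_flat = tf.reshape(pooled, tf.stack([batch, neg_one]))
    argmax_flat = tf.reshape(argmax, tf.stack([batch, neg_one]))
    n = tf.shape(argmax_flat, out_type=tf.int64)[1]

    # Prefix each index with its batch id, then scatter back onto a zero tensor.
    batch_idx = tf.tile(tf.reshape(tf.range(batch), [-1, 1]),
                        tf.stack([tf.constant(1, tf.int64), n]))
    indices = tf.stack([batch_idx, argmax_flat], axis=2)  # [B, n, 2]
    unpooled = tf.scatter_nd(indices, pooled_flat, tf.stack([batch, flat_out]))
    return tf.reshape(unpooled, output_shape)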
def train(self):
    self.net_mode(train=True)
    epochs = int(np.ceil(self.steps / len(self.dataloader)))
    print("number of epochs {}".format(epochs))

    step = 0
    c = Counter()
    d = Counter()

    for e in range(epochs):
        for x_true1, x_true2 in self.dataloader:
            step += 1

            # VAE
            x_true1 = x_true1.unsqueeze(1).to(self.device)
            x_recon, mu, log_var, z = self.VAE(x_true1)

            # Reconstruction and KL
            vae_recon_loss = recon_loss(x_true1, x_recon)
            vae_kl = kl_div(mu, log_var)
            vae_loss = vae_recon_loss + vae_kl

            # Optimise VAE
            self.optim_VAE.zero_grad()
            vae_loss.backward(retain_graph=True)  # grad parameters are populated
            self.optim_VAE.step()

            # Sampling
            if self.args.sample == "sample":
                x_true2 = x_true2.unsqueeze(1).to(self.device)
                parameters = self.VAE(x_true2, decode=False)
                mu_prime = parameters[1]
                log_var_prime = parameters[2]
            else:
                mu_prime = mu
                log_var_prime = log_var

            # Synergy Max

            # Step 1: compute the arg-max of D_KL(q(a_i | x_i) || p(a_i))
            best_ai, worst_ai = greedy_policy_s_max_discount_worst(
                self.z_dim, mu_prime, log_var_prime, alpha=self.omega)
            c.update(best_ai)
            d.update(worst_ai)

            # Step 2: compute the I-max
            mu_syn = mu_prime[:, worst_ai]
            log_var_syn = log_var_prime[:, worst_ai]

            if len(mu_syn.size()) == 1:
                i_max = kl_div_uni_dim(mu_syn, log_var_syn).mean()
            else:
                i_max = kl_div(mu_syn, log_var_syn)

            # Step 3: use it in the loss
            syn_loss = self.alpha * i_max

            # Step 4: optimise the synergy term
            self.optim_VAE.zero_grad()
            syn_loss.backward()
            self.optim_VAE.step()  # does the update of the VAE network parameters

            # Logging
            if step % self.args.log_interval == 0:
                O = OrderedDict(
                    [(i, str(round(count / sum(c.values()) * 100.0, 3)) + '%')
                     for i, count in c.most_common()])
                P = OrderedDict(
                    [(i, str(round(count / sum(d.values()) * 100.0, 3)) + '%')
                     for i, count in d.most_common()])

                print("Step {}".format(step))
                print("Recons. Loss = " + "{:.4f}".format(vae_recon_loss))
                print("KL Loss = " + "{:.4f}".format(vae_kl))
                print("VAE Loss = " + "{:.4f}".format(vae_loss))
                print("best_ai {}".format(best_ai))
                print("worst_ai {}".format(worst_ai))
                print("I_max {}".format(i_max))
                print("Syn loss {:.4f}".format(syn_loss))
                print()
                for k, v in O.items():
                    print("best latent {}: {}".format(k, v))
                print()
                for k, v in P.items():
                    print("worst latent {}: {}".format(k, v))
                print()

            # Saving traversals
            if not step % self.args.save_interval:
                filename = 'alpha_' + str(self.alpha) + '_traversal_' + str(step) + '.png'
                filepath = os.path.join(self.args.output_dir, filename)
                traverse(self.net_mode, self.VAE, self.test_imgs, filepath)

            # Gather data
            if self.viz_on and (step % self.viz_il_iter == 0):
                Q = OrderedDict(
                    [(i, round(count / sum(c.values()) * 100.0, 3))
                     for i, count in c.items()])
                H = dict()
                for k in range(10):
                    H[k] = Q[k] if k in Q else 0.0

                self.line_gather.insert(iter=step,
                                        recon=vae_recon_loss.item(),
                                        kl=vae_kl.item(),
                                        syn=syn_loss.item(),
                                        l0=H[0], l1=H[1], l2=H[2], l3=H[3],
                                        l4=H[4], l5=H[5], l6=H[6], l7=H[7],
                                        l8=H[8], l9=H[9])

            # Visualise data
            if self.viz_on and (step % self.viz_la_iter == 0):
                self.visualize_line()
                self.line_gather.flush()
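# --- Hypothetical definition of recon_loss used in the training loops (not in
# the original listing): a Bernoulli reconstruction term, i.e. binary
# cross-entropy on the decoder logits summed over pixels and averaged over the
# batch, a common choice for dSprites-style binary images. The actual codebase
# may use a different likelihood (e.g. Gaussian / MSE).

import torch.nn.functional as F


def recon_loss(x_true, x_recon):
    # x_recon is assumed to be pre-sigmoid logits with the same shape as x_true.
    batch_size = x_true.size(0)
    return F.binary_cross_entropy_with_logits(
        x_recon, x_true, reduction='sum') / batch_size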
def train(self):
    self.net_mode(train=True)
    epochs = int(np.ceil(self.steps / len(self.dataloader)))
    print("number of epochs {}".format(epochs))

    step = 0
    c = Counter()
    d = Counter()

    for e in range(epochs):
        for x_true1, x_true2 in self.dataloader:
            step += 1

            # VAE
            x_true1 = x_true1.unsqueeze(1).to(self.device)
            x_recon, mu, log_var, z = self.VAE(x_true1)

            # Reconstruction and KL
            vae_recon_loss = recon_loss(x_true1, x_recon)
            vae_kl = kl_div(mu, log_var)
            vae_loss = vae_recon_loss + vae_kl

            # Optimise VAE
            self.optim_VAE.zero_grad()
            vae_loss.backward(retain_graph=True)
            self.optim_VAE.step()

            # Sampling
            if self.args.sample == "sample":
                x_true2 = x_true2.unsqueeze(1).to(self.device)
                parameters = self.VAE(x_true2, decode=False)
                mu_prime = parameters[1]
                log_var_prime = parameters[2]
            else:
                mu_prime = mu
                log_var_prime = log_var

            # Synergy Max

            # Step 1: compute the arg-max of D_KL(q(a_i | x_i) || p(a_i)) in a greedy way
            best_ai, worst_ai = greedy_policy_s_max_discount_worst(
                self.z_dim, mu_prime, log_var_prime, alpha=self.omega)
            c.update(best_ai)
            d.update(worst_ai)

            # Step 2: compute the I-max
            mu_syn = mu_prime[:, worst_ai]
            log_var_syn = log_var_prime[:, worst_ai]

            if len(mu_syn.size()) == 1:
                i_max = kl_div_uni_dim(mu_syn, log_var_syn).mean()
            else:
                i_max = kl_div(mu_syn, log_var_syn)

            # Step 3: use it in the loss
            syn_loss = self.alpha * i_max  # alpha > 0, typically ~2-4

            # Step 4: optimise the synergy term
            self.optim_VAE.zero_grad()
            syn_loss.backward()  # back-propagate the gradients
            self.optim_VAE.step()  # does the update of the VAE network parameters

            # Logging
            if step % self.args.log_interval == 0:
                O = OrderedDict(
                    [(i, str(round(count / sum(c.values()) * 100.0, 3)) + '%')
                     for i, count in c.most_common()])
                P = OrderedDict(
                    [(i, str(round(count / sum(d.values()) * 100.0, 3)) + '%')
                     for i, count in d.most_common()])

                print("Step {}".format(step))
                print("Recons. Loss = " + "{:.4f}".format(vae_recon_loss))
                print("KL Loss = " + "{:.4f}".format(vae_kl))
                print("VAE Loss = " + "{:.4f}".format(vae_loss))
                print("best_ai {}".format(best_ai))
                print("worst_ai {}".format(worst_ai))
                print("I_max {}".format(i_max))
                print("Syn loss {:.4f}".format(syn_loss))
                print()
                for k, v in O.items():
                    print("best latent {}: {}".format(k, v))
                print()
                for k, v in P.items():
                    print("worst latent {}: {}".format(k, v))
                print()

            # Saving traversals
            if not step % self.args.save_interval:
                filename = 'alpha_' + str(self.alpha) + '_traversal_' + str(step) + '.png'
                filepath = os.path.join(self.args.output_dir, filename)
                traverse(self.net_mode, self.VAE, self.test_imgs, filepath)

            # Saving plot of ground truth vs. predicted
            if not step % self.args.gt_interval:
                filename = 'alpha_' + str(self.alpha) + '_gt_' + str(step) + '.png'
                filepath = os.path.join(self.args.output_dir, filename)
                plot_gt_shapes(self.net_mode, self.VAE, self.dataloader_gt, filepath)
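# --- Illustrative sketch only of greedy_policy_s_max_discount_worst, which is
# assumed to be defined elsewhere in the codebase; the real discounted greedy
# policy may differ. This version ranks latent dimensions by their
# batch-averaged KL against the unit prior (using the kl_div_uni_dim sketch
# above) and greedily keeps dimensions while each new one still contributes at
# least `alpha` times the last kept one, returning the kept ("best") and
# remaining ("worst") indices.

def greedy_policy_s_max_discount_worst(z_dim, mu, log_var, alpha=0.8):
    per_dim_kl = kl_div_uni_dim(mu, log_var).mean(dim=0)  # [z_dim]
    order = torch.argsort(per_dim_kl, descending=True).tolist()

    best_ai = [order[0]]
    for idx in order[1:]:
        if per_dim_kl[idx] >= alpha * per_dim_kl[best_ai[-1]]:
            best_ai.append(idx)
        else:
            break

    worst_ai = [i for i in range(z_dim) if i not in best_ai]
    if not worst_ai:
        # Keep at least one "worst" dimension so the synergy term is defined.
        best_ai, worst_ai = order[:-1], [order[-1]]
    return best_ai, worst_ai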