def _proj_reconstruction_loss(proj_image, target, distance_type, mse_loss):
    """Return the optimisation loss between a generated image and its target.

    Args:
        proj_image: generator output for the current latent code.
        target: ``input_image`` reshaped and cast to float to line up with
            ``proj_image``.
        distance_type: one of ``'L2'``, ``'L1'`` or ``'Linf'``.
        mse_loss: an ``nn.MSELoss`` instance used for the ``'L2'`` case.

    Returns:
        A scalar tensor that is differentiable w.r.t. ``proj_image``.
    """
    if distance_type == 'L2':
        # The search minimises the mean squared error (not the raw L2 norm).
        return mse_loss(proj_image, target)
    abs_diff = torch.abs(proj_image - target)
    if distance_type == 'L1':
        return torch.sum(abs_diff)
    return torch.max(abs_diff)  # 'Linf'


def _proj_final_distance(proj_image, input_image, distance_type):
    """Return the reported distance between the best projection and the input.

    ``input_image`` is only cast to float here (not reshaped), so the two
    tensors are combined with ordinary broadcasting.
    """
    diff = proj_image - input_image.float()
    if distance_type == 'L2':
        # NOTE(review): the fixed divisor 32 equals sqrt(number of elements)
        # -- i.e. this is sqrt(MSE) -- only for 1024-element inputs such as a
        # 1x32x32 image. Kept as-is so existing thresholds do not shift;
        # confirm whether it should scale with the input size.
        return torch.norm(diff) / 32
    if distance_type == 'L1':
        return torch.sum(torch.abs(diff))
    return torch.max(torch.abs(diff))  # 'Linf'


def get_proj_distance_square(input_image, manifold_label_proj, data_, label_, args):
    """Project ``input_image`` onto one GAN class manifold and measure the gap.

    The manifold for class ``manifold_label_proj`` is the image of the
    conditional generator ``args.gan`` with a one-hot label for that class.
    For each of ``args.num_random_z`` random restarts, a latent code ``z`` is
    optimised with Adam for ``args.num_epochs_z`` steps. Before every step,
    ``z`` is rescaled into the shell
    ``[z_minRad_coeff, z_maxRad_coeff] * sqrt(z_dim)``. The lowest-loss ``z``
    seen across all steps and restarts is kept.

    Args:
        input_image: the data point to project. For ``'Real'`` data it must
            be viewable as ``(-1, numCH, image_size_ref, image_size_ref)``;
            otherwise as ``(1, data_dim)``, where ``data_dim`` is the
            generator's output width.
        manifold_label_proj: index of the class manifold to project onto.
        data_, label_: unused (previously consumed by debug plotting); kept
            for call compatibility.
        args: namespace providing ``gan``, ``label_GAN``, ``z_dim``,
            ``data_type``, ``data_class``, ``device``, ``distance_type``
            (``'L2'``, ``'L1'`` or ``'Linf'``), ``lr``, ``num_epochs_z``,
            ``num_random_z`` and ``image_size_ref``. For
            ``data_type == 'mnist_ext'``, ``lr`` and ``num_epochs_z`` are
            overridden with 0.01 and 500.

    Returns:
        ``(best_z, projected_image, distance)``:

        - ``best_z`` is the best latent code, with shape ``(1, z_dim)``.
        - ``projected_image`` is the detached generator output for ``best_z``.
        - ``distance`` is a scalar tensor. Despite the function name, it is
          not squared: ``'L2'`` returns the Euclidean norm divided by 32.

    Raises:
        ValueError: on an unsupported ``distance_type``, on an unsupported
            ``data_type`` for ``'Real'`` data, or when no optimisation step
            would run.
    """
    saved_generator = args.gan
    label_GAN = args.label_GAN
    manifold_dim = args.z_dim
    data_type = args.data_type
    data_class = args.data_class
    device = args.device
    distance_type = args.distance_type
    lr = args.lr
    num_epochs_z = args.num_epochs_z
    num_random_z = args.num_random_z
    image_size_ref = args.image_size_ref

    # Latent-norm shell, as multiples of sqrt(z_dim).
    z_maxRad_coeff = 1.1
    z_minRad_coeff = 0.9
    if data_type == 'mnist_ext':
        lr = 0.01
        num_epochs_z = 500
    elif data_class == 'Synthetic':
        z_maxRad_coeff = 3
        z_minRad_coeff = 0

    if distance_type not in ('L2', 'L1', 'Linf'):
        raise ValueError("Unsupported distance_type %r; expected 'L2', 'L1' or 'Linf'"
                         % (distance_type,))
    if num_random_z < 1 or num_epochs_z < 1:
        raise ValueError('num_random_z and num_epochs_z must both be >= 1 (got %r and %r)'
                         % (num_random_z, num_epochs_z))

    if data_type in ('mnist', 'mnist_ext'):
        numCH = 1
    elif data_type == 'cifar10':
        numCH = 3
    else:
        numCH = None  # only needed for 'Real' data
    if data_class == 'Real' and numCH is None:
        raise ValueError('Unsupported data_type %r for Real data' % (data_type,))

    z_maxRad = z_maxRad_coeff * np.sqrt(manifold_dim)
    z_minRad = z_minRad_coeff * np.sqrt(manifold_dim)

    input_image = input_image.to(device)
    saved_generator.eval()

    # One-hot vector selecting which class manifold the generator draws from.
    y_label = torch.zeros([1, label_GAN])
    y_label[0][manifold_label_proj] = 1
    if data_class == 'Real':
        y_label = y_label.view(1, -1, 1, 1)
    # Move to the device for both data classes (the Real path previously
    # left the label on the CPU).
    y_label = y_label.to(device)

    if data_class == 'Real':
        z_shape = [1, manifold_dim, 1, 1]
        target = input_image.view(-1, numCH, image_size_ref, image_size_ref).float()
    else:
        z_shape = [manifold_dim, 1]
        target = None  # width depends on the generator output; set on first pass

    mse_loss = nn.MSELoss()
    BoB_z_Var = None  # best-of-best latent code across all restarts
    BoB_loss_z = None
    for _ in range(num_random_z):
        # Sample on the CPU and then move, so seeded runs draw the same z as before.
        z_Var = torch.randn(z_shape).to(device).requires_grad_()
        z_optimizer = optim.Adam([z_Var], lr, betas=(0.5, 0.999))
        best_z_Var = None
        best_loss_z = None
        for _ in range(num_epochs_z):
            # Rescale the optimised leaf tensor itself into the norm shell.
            # z_Var must stay that leaf: rebinding it to a reshaped view and
            # writing .data on the view would detach the forward pass from
            # the tensor Adam updates, freezing the search.
            with torch.no_grad():
                if torch.norm(z_Var) > z_maxRad:
                    z_Var.data = z_Var.data / torch.norm(z_Var) * z_maxRad
                if torch.norm(z_Var) < z_minRad:
                    z_Var.data = z_Var.data / torch.norm(z_Var) * z_minRad
            z_in = z_Var.reshape(-1, manifold_dim)

            proj_image = saved_generator(z_in, y_label)
            if target is None:
                target = input_image.view(1, proj_image.size()[1]).float()
            recLoss = _proj_reconstruction_loss(proj_image, target, distance_type, mse_loss)

            # Track the lowest loss seen; strict '<' keeps the earliest tie.
            loss_value = recLoss.detach()
            if best_loss_z is None or loss_value < best_loss_z:
                best_loss_z = loss_value
                best_z_Var = z_in.detach().clone()

            z_optimizer.zero_grad()
            recLoss.backward()
            z_optimizer.step()

        if BoB_loss_z is None or best_loss_z < BoB_loss_z:
            BoB_loss_z = best_loss_z
            BoB_z_Var = best_z_Var

    BoB_proj_image = saved_generator(BoB_z_Var, y_label).detach()
    distance = _proj_final_distance(BoB_proj_image, input_image, distance_type)
    return BoB_z_Var, BoB_proj_image, distance
edge_avg_w = edge_avg_w[sel_idx] E = sel_idx.shape[0] // 2 print('[Filtering Edges] After filtering: %d' % E) ''' Select top_k nodes ------------------------------------------------------------------------------------------------------- ''' start_nodes = graph_utils.choose_anchor_node(N, bi_e_node_idx, edge_avg_w[:2 * E], top_k=selected_topk) ori_gt_q = rot2quaternion(node_Es[:, :3, :3]) ''' Optimize the init orientation -------------------------------------------------------------------------------------------- ''' # create optimizer with torch.cuda.device(train_params.DEV_IDS[0]): cur_dev = torch.cuda.current_device() opt_w_var = Variable(edge_avg_w.clone(), requires_grad=True) ref_w = edge_avg_w.clone().detach() if enable_sigmoid_norm is True else torch.sigmoid(0.6 * opt_w_var.clone()).detach() sub_optimzer = torch.optim.Adam([{'params': opt_w_var, 'lr': 1.5}]) bi_e_node_idx = bi_e_node_idx.to(cur_dev) bi_e_rel_q = (bi_e_rel_q.to(cur_dev)) bi_e_label = torch.ones_like(opt_w_var) ori_gt_q = ori_gt_q.to(cur_dev) pre_ref_w = None print('[Optimizing Initial Orientation]') for itr in range(max_init_itr): sub_optimzer.zero_grad() if enable_sigmoid_norm is False: opt_w = torch.sigmoid(0.6 * opt_w_var) # tip: factor 0.6 makes faster convergency else: opt_w = opt_w_var