Beispiel #1
0
def _reshape_target(input_image, data_dim, data_class, numCH, image_size_ref):
    """Return ``input_image`` as a float tensor laid out like the generator output.

    'Real' data is viewed as (-1, numCH, image_size_ref, image_size_ref);
    every other data class is viewed as (1, data_dim).
    """
    if data_class == 'Real':
        return input_image.view(-1, numCH, image_size_ref,
                                image_size_ref).float()
    return input_image.view(1, data_dim).float()


def _projection_loss(proj_image, target, distance_type):
    """Return the reconstruction loss that is minimised over z.

    'L2'   -> mean squared error,
    'L1'   -> sum of absolute differences,
    'Linf' -> maximum absolute difference.
    """
    if distance_type == 'L2':
        return nn.functional.mse_loss(proj_image, target)
    abs_diff = torch.abs(proj_image - target)
    if distance_type == 'L1':
        return torch.sum(abs_diff)
    return torch.max(abs_diff)


def get_proj_distance_square(input_image, manifold_label_proj, data_, label_,
                             args):
    """Project ``input_image`` onto one class manifold of a conditional generator.

    Searches the latent space of ``args.gan`` (conditioned on the one-hot
    label for ``manifold_label_proj``) for the z whose generated sample is
    closest to ``input_image``. The search runs ``num_random_z`` random
    restarts. Each restart optimises z with Adam for ``num_epochs_z`` steps,
    clamping ||z|| into [z_minRad, z_maxRad] before every step. The z with
    the smallest loss over all restarts and steps is kept.

    Overrides of ``args``: for data_type 'mnist_ext' the learning rate is
    forced to 0.01 and the step count to 500. For 'Synthetic' data the
    radius shell is widened from [0.9, 1.1] to [0, 3] times sqrt(z_dim).

    Args:
        input_image: tensor holding the point/image to project.
        manifold_label_proj: class index used to build the one-hot label.
        data_, label_: not used by the computation; kept for interface
            compatibility.
        args: namespace providing gan, label_GAN, z_dim, data_type,
            data_class, device, distance_type ('L2', 'L1' or 'Linf'), lr,
            num_epochs_z, num_random_z and image_size_ref.

    Returns:
        (best_z, projected_image, distance). ``best_z`` has shape
        (1, z_dim). ``distance`` compares the projected image with
        ``input_image`` in the chosen norm. Despite the function name, it is
        not squared; for 'L2' it is the Euclidean norm divided by 32.

    Raises:
        ValueError: unknown distance_type, unsupported data_type for 'Real'
            data, or fewer than one restart / optimisation step.
    """
    saved_generator = args.gan
    label_GAN = args.label_GAN
    manifold_dim = args.z_dim
    data_type = args.data_type
    data_class = args.data_class
    device = args.device
    distance_type = args.distance_type
    lr = args.lr
    num_epochs_z = args.num_epochs_z
    num_random_z = args.num_random_z
    image_size_ref = args.image_size_ref

    if distance_type not in ('L2', 'L1', 'Linf'):
        raise ValueError("distance_type must be 'L2', 'L1' or 'Linf', got %r"
                         % (distance_type,))

    # Radius shell for z, expressed as multiples of sqrt(manifold_dim).
    z_maxRad_coeff = 1.1
    z_minRad_coeff = 0.9
    # Per-dataset overrides of the search settings taken from args.
    if data_type == 'mnist_ext':
        lr = 0.01
        num_epochs_z = 500
    elif data_class == 'Synthetic':
        z_maxRad_coeff = 3
        z_minRad_coeff = 0

    if num_random_z < 1 or num_epochs_z < 1:
        raise ValueError('num_random_z and num_epochs_z must both be >= 1')

    # The channel count is only needed to reshape 'Real' (image) inputs.
    numCH = {'mnist': 1, 'mnist_ext': 1, 'cifar10': 3}.get(data_type)
    if data_class == 'Real' and numCH is None:
        raise ValueError('unsupported data_type for Real data: %r'
                         % (data_type,))

    input_image = input_image.to(device)
    saved_generator.eval()

    # One-hot label selecting which class manifold the generator samples
    # from. 'Real' generators receive it as a (1, label_GAN, 1, 1) tensor.
    y_label = torch.zeros([1, label_GAN])
    y_label[0][manifold_label_proj] = 1
    if data_class == 'Real':
        y_label = y_label.view(1, -1, 1, 1)
    y_label = y_label.to(device)

    z_maxRad = float(z_maxRad_coeff * np.sqrt(manifold_dim))
    z_minRad = float(z_minRad_coeff * np.sqrt(manifold_dim))

    BoB_z_Var = None  # best z over all restarts ("best of best")
    BoB_loss_z = None
    for _ in range(num_random_z):
        # Create z directly in the (1, manifold_dim) shape the generator is
        # called with. Reshaping it every step would make it a non-leaf
        # view, so the clamp would no longer reach the optimised tensor.
        z_Var = torch.randn([1, manifold_dim]).to(device).requires_grad_()
        z_optimizer = optim.Adam([z_Var], lr, betas=(0.5, 0.999))

        best_z_Var = None
        best_loss_z = None
        for _ in range(num_epochs_z):
            # Project z back into the radius shell (in place, on the leaf).
            with torch.no_grad():
                z_norm = torch.norm(z_Var)
                if z_norm > z_maxRad:
                    z_Var.div_(z_norm).mul_(z_maxRad)
                elif z_norm < z_minRad:
                    z_Var.div_(z_norm).mul_(z_minRad)

            proj_image = saved_generator(z_Var, y_label)
            target = _reshape_target(input_image, proj_image.size()[1],
                                     data_class, numCH, image_size_ref)
            recLoss = _projection_loss(proj_image, target, distance_type)

            # Remember the z that produced the smallest loss in this restart.
            loss_value = recLoss.detach()
            if best_loss_z is None or loss_value < best_loss_z:
                best_loss_z = loss_value
                best_z_Var = z_Var.detach().clone()

            # Differentiate w.r.t. z only, so no gradients accumulate in the
            # generator's parameters during the search.
            z_grad, = torch.autograd.grad(recLoss, z_Var)
            z_Var.grad = z_grad
            z_optimizer.step()

        if BoB_loss_z is None or best_loss_z < BoB_loss_z:
            BoB_z_Var = best_z_Var
            BoB_loss_z = best_loss_z

    with torch.no_grad():
        BoB_proj_image = saved_generator(BoB_z_Var, y_label).detach()

    diff = BoB_proj_image - input_image.float()
    if distance_type == 'L2':
        # NOTE(review): intended as sqrt(MSE), but dividing by 32 equals that
        # only when the sample has 32 * 32 = 1024 elements (e.g. 1x32x32
        # images). Kept unchanged so reported distances keep their scale.
        distance = torch.norm(diff) / 32
    elif distance_type == 'L1':
        distance = torch.sum(torch.abs(diff))
    else:  # 'Linf'
        distance = torch.max(torch.abs(diff))

    return BoB_z_Var, BoB_proj_image, distance

''' Select top_k nodes -------------------------------------------------------------------------------------------------------
'''
# Pick `selected_topk` anchor nodes with graph_utils.choose_anchor_node,
# passing only the first 2 * E entries of edge_avg_w.
# NOTE(review): the 2 * E slice presumably keeps both directions of each of
# the E retained edges — confirm where E is computed. None of N,
# bi_e_node_idx, edge_avg_w, E, selected_topk or node_Es is bound at module
# level in the visible part of this file; this looks like a fragment pasted
# from another script.
start_nodes = graph_utils.choose_anchor_node(N, bi_e_node_idx, edge_avg_w[:2 * E], top_k=selected_topk)
# Convert the top-left 3x3 block of every entry of node_Es with rot2quaternion
# (presumably rotation matrix -> quaternion, and `gt` = ground truth — verify).
ori_gt_q = rot2quaternion(node_Es[:, :3, :3])

''' Optimize the init orientation --------------------------------------------------------------------------------------------
'''
# create optimizer
with torch.cuda.device(train_params.DEV_IDS[0]):
    cur_dev = torch.cuda.current_device()
    opt_w_var = Variable(edge_avg_w.clone(), requires_grad=True)
    ref_w = edge_avg_w.clone().detach() if enable_sigmoid_norm is True else torch.sigmoid(0.6 * opt_w_var.clone()).detach()
    sub_optimzer = torch.optim.Adam([{'params': opt_w_var, 'lr': 1.5}])
    bi_e_node_idx = bi_e_node_idx.to(cur_dev)
    bi_e_rel_q = (bi_e_rel_q.to(cur_dev))
    bi_e_label = torch.ones_like(opt_w_var)
    ori_gt_q = ori_gt_q.to(cur_dev)

    pre_ref_w = None
    print('[Optimizing Initial Orientation]')
    for itr in range(max_init_itr):
        sub_optimzer.zero_grad()

        if enable_sigmoid_norm is False:
            opt_w = torch.sigmoid(0.6 * opt_w_var)          # tip: factor 0.6 makes faster convergency
        else:
            opt_w = opt_w_var