Example #1
    def forward(self, src_mesh):
        loss = 0

        # Sample from target meshes
        target_verts = sample_points_from_meshes(self.target_meshes, 3000)

        if self.consider_loss("chamfer"):
            loss_chamfer, _ = chamfer_distance(target_verts,
                                               src_mesh.verts_padded())
            loss += self.loss_weights["w_chamfer"] * loss_chamfer

        if self.consider_loss("edge"):
            loss_edge = mesh_edge_loss(
                src_mesh)  # and (b) the edge length of the predicted mesh
            loss += self.loss_weights["w_edge"] * loss_edge

        if self.consider_loss("normal"):
            loss_normal = mesh_normal_consistency(
                src_mesh)  # mesh normal consistency
            loss += self.loss_weights["w_normal"] * loss_normal

        if self.consider_loss("laplacian"):
            loss_laplacian = mesh_laplacian_smoothing(
                src_mesh, method="uniform")  # mesh laplacian smoothing
            loss += self.loss_weights["w_laplacian"] * loss_laplacian

        return loss
Example #2
    def loss(self, src_mesh, src_verts):
        loss = 0
        loss_chamfer = 0  # returned below even when the chamfer term is disabled

        if self.consider_loss("chamfer"):
            loss_chamfer, _ = chamfer_distance(
                self.target_verts, src_verts
            )  # We compare the two sets of pointclouds by computing (a) the chamfer loss

            loss += self.loss_weights["w_chamfer"] * loss_chamfer

        if self.consider_loss("edge"):
            loss_edge = mesh_edge_loss(
                src_mesh)  # and (b) the edge length of the predicted mesh
            loss += self.loss_weights["w_edge"] * loss_edge

        if self.consider_loss("normal"):
            loss_normal = mesh_normal_consistency(
                src_mesh)  # mesh normal consistency
            loss += self.loss_weights["w_normal"] * loss_normal

        if self.consider_loss("laplacian"):
            loss_laplacian = mesh_laplacian_smoothing(
                src_mesh, method="uniform")  # mesh laplacian smoothing
            loss += self.loss_weights["w_laplacian"] * loss_normal

        if self.consider_loss("arap"):
            for n in range(len(self.target_meshes)):
                loss_arap = arap_loss(self.prev_mesh,
                                      self.prev_verts,
                                      src_verts,
                                      mesh_idx=n)
                loss += self.loss_weights["w_arap"] * loss_arap

        return loss, loss_chamfer
Example #3
def get_loss(mesh,
             trg_mesh,
             w_chamfer,
             w_edge,
             w_normal,
             w_laplacian,
             n_points=5000):
    # We sample 5k points from the surface of each mesh
    sample_trg = sample_points_from_meshes(trg_mesh, n_points)
    sample_src = sample_points_from_meshes(mesh, n_points)

    # We compare the two sets of pointclouds by computing (a) the chamfer loss

    loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)

    # and (b) the edge length of the predicted mesh
    loss_edge = mesh_edge_loss(mesh)

    # mesh normal consistency
    loss_normal = mesh_normal_consistency(mesh)

    # mesh laplacian smoothing
    loss_laplacian = mesh_laplacian_smoothing(mesh, method="uniform")

    # Weighted sum of the losses
    loss = loss_chamfer * w_chamfer + loss_edge * w_edge + loss_normal * w_normal + loss_laplacian * w_laplacian

    return loss
Example #4
    def test_mesh_edge_loss_output(self):
        """
        Check outputs of tensorized and iterative implementations are the same.
        """
        device = torch.device("cuda:0")
        target_length = 0.5
        num_meshes = 10
        num_verts = 32
        num_faces = 64

        verts_list = []
        faces_list = []
        valid = torch.randint(2, size=(num_meshes, ))

        for n in range(num_meshes):
            if valid[n]:
                vn = torch.randint(3, high=num_verts, size=(1, ))[0].item()
                fn = torch.randint(vn, high=num_faces, size=(1, ))[0].item()
                verts = torch.rand((vn, 3), dtype=torch.float32, device=device)
                faces = torch.randint(vn,
                                      size=(fn, 3),
                                      dtype=torch.int64,
                                      device=device)
            else:
                verts = torch.tensor([], dtype=torch.float32, device=device)
                faces = torch.tensor([], dtype=torch.int64, device=device)
            verts_list.append(verts)
            faces_list.append(faces)
        meshes = Meshes(verts=verts_list, faces=faces_list)
        loss = mesh_edge_loss(meshes, target_length=target_length)

        predloss = TestMeshEdgeLoss.mesh_edge_loss_naive(meshes, target_length)
        self.assertTrue(torch.allclose(loss, predloss))
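The naive reference used by this test, TestMeshEdgeLoss.mesh_edge_loss_naive, is not shown above. A minimal sketch of what such a per-mesh reference could look like (an assumption, not the actual PyTorch3D test helper): average the squared deviation of each unique edge length from target_length per mesh, then average over the batch, with empty meshes contributing zero.

import torch

def mesh_edge_loss_naive(meshes, target_length: float = 0.0):
    per_mesh = []
    for verts, faces in zip(meshes.verts_list(), meshes.faces_list()):
        if faces.numel() == 0:
            # Empty meshes contribute zero to the batch average
            per_mesh.append(torch.zeros((), dtype=torch.float32, device=verts.device))
            continue
        # The three undirected edges of every face, deduplicated
        edges = torch.cat([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], dim=0)
        edges, _ = edges.sort(dim=1)
        edges = torch.unique(edges, dim=0)
        lengths = (verts[edges[:, 0]] - verts[edges[:, 1]]).norm(dim=1)
        per_mesh.append(((lengths - target_length) ** 2).mean())
    return torch.stack(per_mesh).mean()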
Example #5
    def loss(self, data, epoch):

        pred = self.forward(data)

        # Voxel cross-entropy, weighted per voxel by the base-plane mask
        weight = data['base_plane'].float().cuda()
        CE_Loss = nn.CrossEntropyLoss(reduction='none')
        ce_loss = CE_Loss(pred[0][-1][3], data['y_voxels'].cuda()) * weight
        ce_loss = ce_loss.mean()

        chamfer_loss = torch.tensor(0).float().cuda()
        edge_loss = torch.tensor(0).float().cuda()
        laplacian_loss = torch.tensor(0).float().cuda()
        normal_consistency_loss = torch.tensor(0).float().cuda()  

        for c in range(self.config.num_classes - 1):
            target = data['surface_points'][c].cuda()
            for k, (vertices, faces, _, _, _) in enumerate(pred[c][1:]):

                pred_mesh = Meshes(verts=list(vertices), faces=list(faces))
                pred_points = sample_points_from_meshes(pred_mesh, 3000)

                chamfer_loss += chamfer_distance(pred_points, target)[0]
                laplacian_loss += mesh_laplacian_smoothing(pred_mesh, method="uniform")
                normal_consistency_loss += mesh_normal_consistency(pred_mesh)
                edge_loss += mesh_edge_loss(pred_mesh)

        loss = 1 * chamfer_loss + 1 * ce_loss + 0.1 * laplacian_loss + 1 * edge_loss + 0.1 * normal_consistency_loss

        log = {"loss": loss.detach(),
               "chamfer_loss": chamfer_loss.detach(), 
               # "loss_coef": loss_coef,
               "ce_loss": ce_loss.detach(),
               "normal_consistency_loss": normal_consistency_loss.detach(),
               "edge_loss": edge_loss.detach(),
               "laplacian_loss": laplacian_loss.detach()}
        return loss, log
Example #6
def update_mesh_shape_prior_losses(mesh, loss):
    # and (b) the edge length of the predicted mesh
    loss["edge"] = mesh_edge_loss(mesh)

    # mesh normal consistency
    loss["normal"] = mesh_normal_consistency(mesh)

    # mesh laplacian smoothing
    loss["laplacian"] = mesh_laplacian_smoothing(mesh, method="uniform")
Example #7
    def _mesh_loss(self, meshes_pred, points_gt, normals_gt):
        """
        Args:
          meshes_pred: Meshes containing N meshes
          points_gt: Tensor of shape NxPx3
          normals_gt: Tensor of shape NxPx3

        Returns:
          total_loss (float): The sum of all losses specific to meshes
          losses (dict): All (unweighted) mesh losses in a dictionary
        """
        device = meshes_pred.verts_list()[0].device
        zero = torch.tensor(0.0).to(device)
        losses = {"chamfer": zero, "normal": zero, "edge": zero}
        if self.upsample_pred_mesh:
            points_pred, normals_pred = sample_points_from_meshes(
                meshes_pred,
                num_samples=self.pred_num_samples,
                return_normals=True)
        else:
            points_pred = meshes_pred.verts_list()
            normals_pred = meshes_pred.verts_normals_list()

        total_loss = torch.tensor(0.0).to(device)
        if points_pred is None or points_gt is None:
            # Sampling failed, so return None
            total_loss = None
            which = "predictions" if points_pred is None else "GT"
            logger.info("WARNING: Sampling %s failed" % (which))
            return total_loss, losses

        losses = {}
        cham_loss, normal_loss = chamfer_distance(points_pred,
                                                  points_gt,
                                                  x_normals=normals_pred,
                                                  y_normals=normals_gt)

        total_loss = total_loss + self.chamfer_weight * cham_loss
        total_loss = total_loss + self.normal_weight * normal_loss
        losses["chamfer"] = cham_loss
        losses["normal"] = normal_loss

        edge_loss = mesh_edge_loss(meshes_pred)
        total_loss = total_loss + self.edge_weight * edge_loss
        losses["edge"] = edge_loss

        return total_loss, losses
Example #8
    def test_empty_meshes(self):
        device = torch.device("cuda:0")
        target_length = 0
        N = 10
        V = 32
        verts_list = []
        faces_list = []
        for _ in range(N):
            vn = torch.randint(3, high=V, size=(1,))[0].item()
            verts = torch.rand((vn, 3), dtype=torch.float32, device=device)
            faces = torch.tensor([], dtype=torch.int64, device=device)
            verts_list.append(verts)
            faces_list.append(faces)
        mesh = Meshes(verts=verts_list, faces=faces_list)
        loss = mesh_edge_loss(mesh, target_length=target_length)

        self.assertClose(loss, torch.tensor([0.0], dtype=torch.float32, device=device))
        self.assertTrue(loss.requires_grad)
Example #9
def get_deform_verts(target_mesh, points_to_sample=5000, sphere_level=4):
    device = torch.device("cuda:0")

    src_mesh = ico_sphere(sphere_level, device)

    deform_verts = torch.full(src_mesh.verts_packed().shape,
                              0.0,
                              device=device,
                              requires_grad=True)

    learning_rate = 0.01
    num_iter = 500
    w_chamfer = 1.0
    w_edge = 0.05
    w_normal = 0.0005
    w_laplacian = 0.005

    optimizer = torch.optim.Adam([deform_verts],
                                 lr=learning_rate,
                                 betas=(0.5, 0.999))

    for _ in range(num_iter):
        optimizer.zero_grad()

        new_src_mesh = src_mesh.offset_verts(deform_verts)

        sample_trg = sample_points_from_meshes(target_mesh, points_to_sample)
        sample_src = sample_points_from_meshes(new_src_mesh, points_to_sample)

        loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)
        loss_edge = mesh_edge_loss(new_src_mesh)
        loss_normal = mesh_normal_consistency(new_src_mesh)
        loss_laplacian = mesh_laplacian_smoothing(new_src_mesh,
                                                  method="uniform")
        loss = loss_chamfer * w_chamfer + loss_edge * w_edge + loss_normal * w_normal + loss_laplacian * w_laplacian

        loss.backward()
        optimizer.step()
    print(
        f"{datetime.now()} Loss Chamfer:{loss_chamfer * w_chamfer}, Loss Edge:{loss_edge * w_edge}, Loss Normal:{loss_normal * w_normal}, Loss Laplacian:{loss_laplacian * w_laplacian}"
    )

    return deform_verts
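One plausible way to consume the returned offsets, applying them to the same level-4 ico-sphere the function optimizes and exporting the fitted mesh (the target .obj path and output path are assumptions):

import torch
from pytorch3d.io import load_objs_as_meshes, save_obj
from pytorch3d.utils import ico_sphere

device = torch.device("cuda:0")
target_mesh = load_objs_as_meshes(["target.obj"], device=device)  # hypothetical input

deform_verts = get_deform_verts(target_mesh, points_to_sample=5000, sphere_level=4)
fitted_mesh = ico_sphere(4, device).offset_verts(deform_verts)

final_verts, final_faces = fitted_mesh.get_mesh_verts_faces(0)
save_obj("fitted.obj", final_verts, final_faces)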
Example #10
    def compute_loss(self, mesh, pcd=None):
        if pcd is None: pcd = self.pcd

        face_loss = pt3loss.point_mesh_face_distance(mesh, pcd)
        edge_loss = pt3loss.point_mesh_edge_distance(mesh, pcd)
        point_loss = pt3loss.chamfer_distance(mesh.verts_padded(), pcd)[0]

        length_loss = pt3loss.mesh_edge_loss(mesh)
        normal_loss = pt3loss.mesh_normal_consistency(mesh)

        mpcd = sample_points_from_meshes(mesh,
                                         2 * pcd.points_padded()[0].shape[0])
        sample_loss, _ = pt3loss.chamfer_distance(mpcd, pcd)

        # torch.stack keeps the autograd graph intact; re-wrapping the values with
        # torch.tensor(..., requires_grad=True) would detach them from the mesh.
        losses = torch.stack((face_loss, edge_loss, point_loss, length_loss,
                              normal_loss, sample_loss)).to(device='cuda')

        return losses
Example #11
    def loss(self, data, epoch):

        pred = self.forward(data)

        CE_Loss = nn.CrossEntropyLoss()
        ce_loss = CE_Loss(pred[0][-1][3], data['y_voxels'])

        chamfer_loss = torch.tensor(0).float().cuda()
        edge_loss = torch.tensor(0).float().cuda()
        laplacian_loss = torch.tensor(0).float().cuda()
        normal_consistency_loss = torch.tensor(0).float().cuda()  

        for c in range(self.config.num_classes - 1):
            target = data['surface_points'][c].cuda()
            for k, (vertices, faces, _, _, _) in enumerate(pred[c][1:]):

                pred_mesh = Meshes(verts=list(vertices), faces=list(faces))
                pred_points = sample_points_from_meshes(pred_mesh, 3000)

                chamfer_loss += chamfer_distance(pred_points, target)[0]
                laplacian_loss += mesh_laplacian_smoothing(pred_mesh, method="uniform")
                normal_consistency_loss += mesh_normal_consistency(pred_mesh)
                edge_loss += mesh_edge_loss(pred_mesh)

        loss = 1 * chamfer_loss + 1 * ce_loss + 0.1 * laplacian_loss + 1 * edge_loss + 0.1 * normal_consistency_loss

        log = {"loss": loss.detach(),
               "chamfer_loss": chamfer_loss.detach(), 
               "ce_loss": ce_loss.detach(),
               "normal_consistency_loss": normal_consistency_loss.detach(),
               "edge_loss": edge_loss.detach(),
               "laplacian_loss": laplacian_loss.detach()}
        return loss, log
Example #12
    def run(self):

        deform_verts = torch.full(self.src.verts_packed().shape, 0.0,
                                  device=device, requires_grad=True)
        optimizer = torch.optim.SGD([deform_verts], lr=1.0, momentum=0.9)

        Niter = 2000
        w_chamfer = 1.0
        w_edge = 1.0
        w_normal = 0.01
        w_laplacian = 0.1

        for i in range(Niter):
            optimizer.zero_grad()

            new_src_mesh = self.src.offset_verts(deform_verts)

            sample_trg = sample_points_from_meshes(self.target, 5000)
            sample_src = sample_points_from_meshes(new_src_mesh, 5000)

            loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)
            loss_edge = mesh_edge_loss(new_src_mesh)
            loss_normal = mesh_normal_consistency(new_src_mesh)
            loss_laplacian = mesh_laplacian_smoothing(new_src_mesh, method="uniform")

            # Weighted sum of the losses
            loss = loss_chamfer * w_chamfer + loss_edge * w_edge + loss_normal * w_normal + loss_laplacian * w_laplacian

            loss.backward()
            optimizer.step()
            print('total_loss = %.6f' % loss)
            self.backwarded.emit(new_src_mesh.verts_packed())
Example #13
    # Initialize optimizer
    optimizer.zero_grad()

    # Deform the mesh
    new_src_mesh = src_mesh.offset_verts(deform_verts)

    # We sample 5k points from the surface of each mesh
    sample_trg = sample_points_from_meshes(trg_mesh, 5000)
    sample_src = sample_points_from_meshes(new_src_mesh, 5000)

    # We compare the two sets of pointclouds by computing (a) the chamfer loss
    loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)

    # and (b) the edge length of the predicted mesh

    loss_edge = mesh_edge_loss(new_src_mesh)

    # mesh normal consistency
    loss_normal = mesh_normal_consistency(new_src_mesh)

    # mesh laplacian smoothing
    loss_laplacian = mesh_laplacian_smoothing(new_src_mesh, method="uniform")

    # Weighted sum of the losses
    loss = loss_chamfer * w_chamfer + loss_edge * w_edge + loss_normal * w_normal + loss_laplacian * w_laplacian

    # Print the losses
    t.set_description('total_loss = %.6f' % loss)

    # Save the losses for plotting
    chamfer_losses.append(loss_chamfer)
Example #14
        #----------------------------------------------------
        # Model forward pass
        #----------------------------------------------------
        # Deform the mesh
        mesh_s_new = mesh_s.offset_verts(verts_deform)

        # Sample 5000 points from the surface of each mesh
        sample_t = sample_points_from_meshes(mesh_t, 5000)
        sample_s = sample_points_from_meshes(mesh_s_new, 5000)

        #----------------------------------------------------
        # Model update step
        #----------------------------------------------------
        # Compute the loss terms
        loss_chamfer, _ = chamfer_distance(sample_t, sample_s)
        loss_edge = mesh_edge_loss(mesh_s_new)
        loss_normal = mesh_normal_consistency(mesh_s_new)
        loss_laplacian = mesh_laplacian_smoothing(mesh_s_new, method="uniform")
        loss_G = args.lambda_chamfer * loss_chamfer + args.lambda_edge * loss_edge + args.lambda_normal * loss_normal + args.lambda_laplacian * loss_laplacian

        # Update the network
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        #====================================================
        # Display training progress
        #====================================================
        if (step == 0 or (step % args.n_diaplay_step == 0)):
            # lr
            for param_group in optimizer_G.param_groups:
Example #15
def compute_loss():
    mesh_edge_loss(meshes, target_length=0.0)
    torch.cuda.synchronize()
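This closure reads like a micro-benchmark body; one way such a function might be driven (the batch construction and the use of torch.utils.benchmark are assumptions, not part of the original snippet):

import torch
from torch.utils.benchmark import Timer
from pytorch3d.loss import mesh_edge_loss
from pytorch3d.utils import ico_sphere

device = torch.device("cuda:0")
meshes = ico_sphere(4, device).extend(8)  # a batch of 8 identical spheres

def compute_loss():
    mesh_edge_loss(meshes, target_length=0.0)
    torch.cuda.synchronize()

timer = Timer(stmt="compute_loss()", globals={"compute_loss": compute_loss})
print(timer.timeit(100))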
Example #16
def mesh_rcnn_loss(
    pred_meshes,
    instances,
    loss_weights=None,
    gt_num_samples=5000,
    pred_num_samples=5000,
    gt_coord_thresh=None,
):
    """
    Compute the mesh prediction loss defined in the Mesh R-CNN paper.

    Args:
        pred_meshes (list of Meshes): A list of K Meshes. Each entry contains B meshes,
            where B is the total number of predicted meshes in all images.
            K is the number of refinements
        instances (list[Instances]): A list of N Instances, where N is the number of images
            in the batch. These instances are in 1:1 correspondence with the pred_meshes.
            The ground-truth labels (class, box, mask, ...) associated with each instance
            are stored in fields.
        loss_weights (dict): Contains the weights for the different losses, e.g.
            loss_weights = {'chamfer': 1.0, 'normals': 0.0, 'edge': 0.2}
        gt_num_samples (int): The number of points to sample from gt meshes
        pred_num_samples (int): The number of points to sample from predicted meshes
        gt_coord_thresh (float): A threshold value over which the batch is ignored
    Returns:
        mesh_loss (Tensor): A scalar tensor containing the loss.
    """
    if not isinstance(pred_meshes, list):
        raise ValueError("Expecting a list of Meshes")

    gt_verts, gt_faces = [], []
    for instances_per_image in instances:
        if len(instances_per_image) == 0:
            continue

        gt_K = instances_per_image.gt_K
        gt_mesh_per_image = batch_crop_meshes_within_box(
            instances_per_image.gt_meshes, instances_per_image.proposal_boxes.tensor, gt_K
        ).to(device=pred_meshes[0].device)
        gt_verts.extend(gt_mesh_per_image.verts_list())
        gt_faces.extend(gt_mesh_per_image.faces_list())

    if len(gt_verts) == 0:
        return None, None

    gt_meshes = Meshes(verts=gt_verts, faces=gt_faces)
    gt_valid = gt_meshes.valid
    gt_sampled_verts, gt_sampled_normals = sample_points_from_meshes(
        gt_meshes, num_samples=gt_num_samples, return_normals=True
    )

    all_loss_chamfer = []
    all_loss_normals = []
    all_loss_edge = []
    for pred_mesh in pred_meshes:
        pred_sampled_verts, pred_sampled_normals = sample_points_from_meshes(
            pred_mesh, num_samples=pred_num_samples, return_normals=True
        )
        wts = (pred_mesh.valid * gt_valid).to(dtype=torch.float32)
        # chamfer loss
        loss_chamfer, loss_normals = chamfer_distance(
            pred_sampled_verts,
            gt_sampled_verts,
            pred_sampled_normals,
            gt_sampled_normals,
            weights=wts,
        )

        # weight the chamfer loss
        loss_chamfer = loss_chamfer * loss_weights["chamfer"]
        all_loss_chamfer.append(loss_chamfer)
        # normal loss
        loss_normals = loss_normals * loss_weights["normals"]
        all_loss_normals.append(loss_normals)
        # mesh edge regularization
        loss_edge = mesh_edge_loss(pred_mesh)
        loss_edge = loss_edge * loss_weights["edge"]
        all_loss_edge.append(loss_edge)

    loss_chamfer = sum(all_loss_chamfer)
    loss_normals = sum(all_loss_normals)
    loss_edge = sum(all_loss_edge)

    # if the rois are bad, the target verts can be arbitrarily large
    # causing exploding gradients. If this is the case, ignore the batch
    if gt_coord_thresh and gt_sampled_verts.abs().max() > gt_coord_thresh:
        loss_chamfer = loss_chamfer * 0.0
        loss_normals = loss_normals * 0.0
        loss_edge = loss_edge * 0.0

    return loss_chamfer, loss_normals, loss_edge, gt_meshes
Example #17
    def attack(self, data, target, label):
        """Attack on given data to target.

        Args:
            data (torch.FloatTensor): victim data, [B, num_vertices, 3]
            target (torch.LongTensor): target output, [B]
        """
        B, K = len(data), 1024
        global bas
        data = data.cuda()
        label_val = target.detach().cpu().numpy()  # [B]

        label = label.long().cuda().detach()
        label_true = label.detach().cpu().numpy()

        deform_ori = data.clone()

        # weight factor for budget regularization
        lower_bound = np.zeros((B, ))
        upper_bound = np.ones((B, )) * self.max_weight
        current_weight = np.ones((B, )) * self.init_weight

        # record best results in binary search
        o_bestdist = np.array([1e10] * B)
        o_bestscore = np.array([-1] * B)
        o_bestattack = np.zeros((B, 3, K))
        # Weight for the chamfer loss
        w_chamfer = 1.0
        # Weight for mesh edge loss
        w_edge = 0.2
        # Weight for mesh laplacian smoothing
        w_laplacian = 0.5

        # perform binary search
        for binary_step in range(self.binary_step):
            deform_verts = torch.full(deform_ori.verts_packed().shape,
                                      0.000001,
                                      device='cuda:%s' % args.local_rank,
                                      requires_grad=True)
            ori_def = deform_verts.detach().clone()

            bestdist = np.array([1e10] * B)
            bestscore = np.array([-1] * B)
            dist_val = 0
            opt = optim.Adam([deform_verts],
                             lr=self.attack_lr,
                             weight_decay=0.)
            # opt = optim.SGD([deform_verts], lr=1.0, momentum=0.9) #optim.Adam([deform_verts], lr=self.attack_lr, weight_decay=0.)

            adv_loss = torch.tensor(0.).cuda()
            dist_loss = torch.tensor(0.).cuda()

            total_time = 0.
            forward_time = 0.
            backward_time = 0.
            update_time = 0.

            # one step in binary search
            for iteration in range(self.num_iter):
                t1 = time.time()
                opt.zero_grad()
                new_defrom_mesh = deform_ori.offset_verts(deform_verts)

                # forward passing
                ori_data = sample_points_from_meshes(data, 1024)
                adv_pl = sample_points_from_meshes(new_defrom_mesh, 1024)
                adv_pl1 = adv_pl.transpose(1, 2).contiguous()
                logits = self.model(adv_pl1)  # [B, num_classes]
                if isinstance(logits, tuple):  # PointNet
                    logits = logits[0]

                t2 = time.time()
                forward_time += t2 - t1

                pred = torch.argmax(logits, dim=1)  # [B]
                success_num = (pred == target).sum().item()
                if iteration % (self.num_iter // 5) == 0:
                    print('Step {}, iteration {}, current_c {},success {}/{}\n'
                          'adv_loss: {:.4f}'.format(
                              binary_step, iteration,
                              torch.from_numpy(current_weight).mean(),
                              success_num, B, adv_loss.item()))
                dist_val = torch.sqrt(torch.sum(
                    (adv_pl - ori_data) ** 2, dim=[1, 2])).\
                    detach().cpu().numpy()  # [B]
                pred_val = pred.detach().cpu().numpy()  # [B]
                input_val = adv_pl1.detach().cpu().numpy()  # [B, 3, K]

                # update
                for e, (dist, pred, label, ii) in \
                        enumerate(zip(dist_val, pred_val, label_val, input_val)):
                    if dist < bestdist[e] and pred == label:
                        bestdist[e] = dist
                        bestscore[e] = pred
                    if dist < o_bestdist[e] and pred == label:
                        o_bestdist[e] = dist
                        o_bestscore[e] = pred
                        o_bestattack[e] = ii

                t3 = time.time()
                # compute loss and backward
                adv_loss = self.adv_func(logits, target).mean()
                loss_chamfer, _ = chamfer_distance(ori_data, adv_pl)
                loss_edge = mesh_edge_loss(new_defrom_mesh)
                loss_laplacian = mesh_laplacian_smoothing(new_defrom_mesh,
                                                          method="uniform")

                loss = adv_loss + torch.from_numpy(current_weight).mean() * (
                    loss_chamfer * w_chamfer + loss_edge * w_edge +
                    loss_laplacian * w_laplacian)
                loss.backward()
                opt.step()

                deform_verts.data = self.clip(deform_verts.clone().detach(),
                                              ori_def)

                t4 = time.time()
                backward_time += t4 - t3
                total_time += t4 - t1

                if iteration % 100 == 0:
                    print(
                        'total time: {:.2f}, for: {:.2f}, '
                        'back: {:.6f}, update: {:.2f}, total loss: {:.6f}, chamfer loss: {:.6f}'
                        .format(total_time, forward_time, backward_time,
                                update_time, loss, loss_chamfer))
                    total_time = 0.
                    forward_time = 0.
                    backward_time = 0.
                    update_time = 0.
                    torch.cuda.empty_cache()

            # adjust weight factor
            for e, label in enumerate(label_val):
                if bestscore[e] == label and bestscore[e] != -1 and bestdist[
                        e] <= o_bestdist[e]:
                    # success
                    lower_bound[e] = max(lower_bound[e], current_weight[e])
                    current_weight[e] = (lower_bound[e] + upper_bound[e]) / 2.
                else:
                    # failure
                    upper_bound[e] = min(upper_bound[e], current_weight[e])
                    current_weight[e] = (lower_bound[e] + upper_bound[e]) / 2.

        bas += 1
        ## save the mesh
        new_defrom_mesh = deform_ori.offset_verts(deform_verts)
        for e1 in range(B):
            final_verts, final_faces = new_defrom_mesh.get_mesh_verts_faces(e1)
            final_obj = os.path.join(
                './p1_manifold_random_target01',
                'result_model%s_%s_%s_%s.obj' %
                (bas, e1, label_val[e1], label_true[e1]))
            save_obj(final_obj, final_verts, final_faces)

        fail_idx = (lower_bound == 0.)
        o_bestattack[fail_idx] = input_val[fail_idx]

        # return final results
        success_num = (lower_bound > 0.).sum()
        print('Successfully attack {}/{}'.format(success_num, B))
        return o_bestdist, o_bestattack.transpose((0, 2, 1)), success_num