示例#1
0
    def loss(self, src_mesh, src_verts):
        """Compute the weighted sum of the enabled mesh-fitting losses.

        Each term is evaluated only when ``self.consider_loss(name)`` is true for
        its name ("chamfer", "edge", "normal", "laplacian", "arap"), and is scaled
        by the matching ``self.loss_weights["w_<name>"]`` entry.

        Args:
            src_mesh: Predicted mesh; used by the edge, normal-consistency and
                Laplacian regularizers.
            src_verts: Predicted vertices; compared against ``self.target_verts``
                by the chamfer term and passed to ``arap_loss``.

        Returns:
            ``(loss, loss_chamfer)`` where ``loss_chamfer`` is the unweighted
            chamfer term, or ``0.0`` when the chamfer loss is disabled.
        """
        loss = 0
        # Default so the return statement is valid even when the chamfer term is
        # disabled (otherwise the name would be unbound at the return).
        loss_chamfer = 0.0

        if self.consider_loss("chamfer"):
            # (a) chamfer distance between the target and predicted point sets
            loss_chamfer, _ = chamfer_distance(self.target_verts, src_verts)
            loss += self.loss_weights["w_chamfer"] * loss_chamfer

        if self.consider_loss("edge"):
            # (b) edge-length regularizer on the predicted mesh
            loss_edge = mesh_edge_loss(src_mesh)
            loss += self.loss_weights["w_edge"] * loss_edge

        if self.consider_loss("normal"):
            loss_normal = mesh_normal_consistency(src_mesh)
            loss += self.loss_weights["w_normal"] * loss_normal

        if self.consider_loss("laplacian"):
            loss_laplacian = mesh_laplacian_smoothing(src_mesh, method="uniform")
            # Weight the Laplacian term itself (this used to weight the normal
            # term by mistake, which was also unbound if "normal" was disabled).
            loss += self.loss_weights["w_laplacian"] * loss_laplacian

        if self.consider_loss("arap"):
            # One ARAP term per target mesh, selected via ``mesh_idx``.
            for n in range(len(self.target_meshes)):
                loss_arap = arap_loss(self.prev_mesh,
                                      self.prev_verts,
                                      src_verts,
                                      mesh_idx=n)
                loss += self.loss_weights["w_arap"] * loss_arap

        return loss, loss_chamfer
示例#2
0
    def forward(self, src_mesh):
        """Return the weighted fitting loss of ``src_mesh`` against the targets.

        3000 points are sampled from ``self.target_meshes`` on every call; the
        chamfer term compares them with ``src_mesh.verts_padded()``. The edge,
        normal-consistency and uniform Laplacian regularizers of ``src_mesh`` are
        added when enabled through ``self.consider_loss``. Each term is scaled by
        its ``self.loss_weights["w_<name>"]`` weight.
        """
        sampled_target = sample_points_from_meshes(self.target_meshes, 3000)
        total = 0

        if self.consider_loss("chamfer"):
            chamfer_term, _ = chamfer_distance(sampled_target,
                                               src_mesh.verts_padded())
            total += self.loss_weights["w_chamfer"] * chamfer_term

        # (loss name, weight key, deferred computation of the unweighted term)
        regularizers = (
            ("edge", "w_edge", lambda: mesh_edge_loss(src_mesh)),
            ("normal", "w_normal", lambda: mesh_normal_consistency(src_mesh)),
            ("laplacian", "w_laplacian",
             lambda: mesh_laplacian_smoothing(src_mesh, method="uniform")),
        )
        for name, weight_key, compute in regularizers:
            if not self.consider_loss(name):
                continue
            term = compute()
            total += self.loss_weights[weight_key] * term

        return total
示例#3
0
def get_loss(mesh,
             trg_mesh,
             w_chamfer,
             w_edge,
             w_normal,
             w_laplacian,
             n_points=5000):
    """Weighted mesh-fitting loss of ``mesh`` against ``trg_mesh``.

    ``n_points`` points are sampled from the surface of each mesh and compared
    with the chamfer distance. The edge-length, normal-consistency and uniform
    Laplacian regularizers of ``mesh`` are then added.

    Returns:
        ``chamfer * w_chamfer + edge * w_edge + normal * w_normal
        + laplacian * w_laplacian``.
    """
    target_points = sample_points_from_meshes(trg_mesh, n_points)
    source_points = sample_points_from_meshes(mesh, n_points)
    chamfer, _ = chamfer_distance(target_points, source_points)

    # Shape priors on the predicted mesh.
    edge = mesh_edge_loss(mesh)
    normal = mesh_normal_consistency(mesh)
    laplacian = mesh_laplacian_smoothing(mesh, method="uniform")

    return (chamfer * w_chamfer
            + edge * w_edge
            + normal * w_normal
            + laplacian * w_laplacian)
示例#4
0
    def loss(self, data, epoch):
        """Compute the training loss and a dict of detached components.

        The loss is ``chamfer + ce + 0.1 * laplacian + edge + 0.1 * normal``:
        ``ce`` is the per-element cross-entropy between ``pred[0][-1][3]`` and
        ``data['y_voxels']``, multiplied by ``data['base_plane']`` and averaged.
        The mesh terms are summed over every entry of ``pred[c][1:]`` for classes
        ``0 .. num_classes - 2``. Each entry is compared with
        ``data['surface_points'][c]`` using 3000 sampled points.

        ``epoch`` is accepted but not used.

        Returns:
            ``(loss, log)`` where ``log`` holds detached copies of the total and
            of every component.
        """
        pred = self.forward(data)

        # Cross-entropy per element, weighted by the base-plane map, then averaged.
        ce_weight = data['base_plane'].float().cuda()
        per_voxel_ce = nn.CrossEntropyLoss(reduction='none')(
            pred[0][-1][3], data['y_voxels'].cuda())
        ce_loss = (per_voxel_ce * ce_weight).mean()

        # Four independent zero accumulators (never aliased: they are updated in place).
        chamfer_loss, edge_loss, laplacian_loss, normal_consistency_loss = (
            torch.tensor(0).float().cuda() for _ in range(4)
        )

        for class_idx in range(self.config.num_classes - 1):
            target_points = data['surface_points'][class_idx].cuda()
            # pred[class_idx][0] is skipped; later entries unpack to five fields
            # of which only vertices and faces are used.
            for vertices, faces, _, _, _ in pred[class_idx][1:]:
                mesh = Meshes(verts=list(vertices), faces=list(faces))
                sampled_points = sample_points_from_meshes(mesh, 3000)

                chamfer_loss += chamfer_distance(sampled_points, target_points)[0]
                laplacian_loss += mesh_laplacian_smoothing(mesh, method="uniform")
                normal_consistency_loss += mesh_normal_consistency(mesh)
                edge_loss += mesh_edge_loss(mesh)

        total_loss = (1 * chamfer_loss + 1 * ce_loss + 0.1 * laplacian_loss
                      + 1 * edge_loss + 0.1 * normal_consistency_loss)

        named_terms = (
            ("loss", total_loss),
            ("chamfer_loss", chamfer_loss),
            ("ce_loss", ce_loss),
            ("normal_consistency_loss", normal_consistency_loss),
            ("edge_loss", edge_loss),
            ("laplacian_loss", laplacian_loss),
        )
        log = {key: value.detach() for key, value in named_terms}
        return total_loss, log
def update_mesh_shape_prior_losses(mesh, loss):
    """Store the shape-prior regularizers of ``mesh`` into ``loss`` in place.

    Sets (overwriting any existing values) ``loss["edge"]`` to the edge-length
    loss, ``loss["normal"]`` to the normal-consistency loss and
    ``loss["laplacian"]`` to the uniform Laplacian smoothing loss. Returns None.
    """
    shape_priors = (
        ("edge", lambda: mesh_edge_loss(mesh)),
        ("normal", lambda: mesh_normal_consistency(mesh)),
        ("laplacian", lambda: mesh_laplacian_smoothing(mesh, method="uniform")),
    )
    for key, compute in shape_priors:
        loss[key] = compute()
示例#6
0
    def forward(self, batch_size):
        """Deform the template mesh and replicate it ``batch_size`` times.

        The template is offset by ``self.deform_verts`` and rebuilt with
        per-vertex textures from ``self.textures``.

        Returns:
            ``(deformed_meshes, laplacian_loss, flatten_loss)``: the mesh extended
            to ``batch_size`` copies, and the uniform Laplacian smoothing and
            normal-consistency losses. Both losses are computed on the single,
            un-extended mesh.
        """
        offset_mesh = self.template_mesh.offset_verts(self.deform_verts)
        vertex_textures = TexturesVertex(self.textures)
        single_mesh = Meshes(
            verts=offset_mesh.verts_padded(),
            faces=offset_mesh.faces_padded(),
            textures=vertex_textures,
        )
        batched_meshes = single_mesh.extend(batch_size)

        smoothness = mesh_laplacian_smoothing(single_mesh, method="uniform")
        flatness = mesh_normal_consistency(single_mesh)
        return batched_meshes, smoothness, flatness
示例#7
0
    def refine_mesh_batched(self, deform_net, semantic_dis_net, mesh_verts_batch, img_batch, pose_batch, compute_losses=True):
        """Deform a mesh with ``deform_net`` and optionally compute its losses.

        ``deform_net(pose_batch, img_batch, mesh_verts_batch)`` predicts vertex
        offsets. These are reshaped to rows of 3 and applied with ``offset_verts``.

        Returns:
            ``deformed_mesh`` when ``compute_losses`` is false. Otherwise
            ``(loss_dict, deformed_mesh)``, with keys "sil_loss", "l2_loss",
            "lap_smoothness_loss", "normal_consistency_loss", "img_sym_loss",
            "vertex_sym_loss" and "semantic_dis_loss".

        NOTE(review): ``mesh`` and ``rgba_image`` are neither parameters nor
        locals of this method. Unless they exist as module-level globals, the
        lines using them raise NameError. They were probably meant to be derived
        from ``mesh_verts_batch`` / ``img_batch`` — TODO confirm and fix.
        """

        # computing mesh deformation
        delta_v = deform_net(pose_batch, img_batch, mesh_verts_batch)
        # flatten predicted offsets into rows of (x, y, z)
        delta_v = delta_v.reshape((-1,3))
        deformed_mesh = mesh.offset_verts(delta_v)  # NOTE(review): `mesh` is not defined in this method

        if not compute_losses:
            return deformed_mesh

        else:
            # prep inputs used to compute losses
            # camera pose columns of pose_batch: distance, elevation, azimuth
            pred_dist = pose_batch[:,0]
            pred_elev = pose_batch[:,1]
            pred_azim = pose_batch[:,2]
            R, T = look_at_view_transform(pred_dist, pred_elev, pred_azim) 
            # target silhouette = pixels with alpha > 0
            # NOTE(review): `rgba_image` is not defined in this method
            mask = rgba_image[:,:,3] > 0
            mask_gt = torch.tensor(mask, dtype=torch.float).to(self.device)
            num_vertices = mesh.verts_packed().shape[0]
            # L2 regularizer target: no deformation at all
            zero_deformation_tensor = torch.zeros((num_vertices, 3)).to(self.device)
            sym_plane_normal = [0,0,1] # TODO: make this generalizable to other classes

            loss_dict = {}
            # computing losses
            rendered_deformed_mesh = utils.render_mesh(deformed_mesh, R, T, self.device, img_size=224, silhouette=True)
            # NOTE(review): only rendered image index 0 (channel 3) is compared
            # against the single mask, even though the poses are batched
            loss_dict["sil_loss"] = F.binary_cross_entropy(rendered_deformed_mesh[0, :,:, 3], mask_gt)
            loss_dict["l2_loss"] = F.mse_loss(delta_v, zero_deformation_tensor)
            loss_dict["lap_smoothness_loss"] = mesh_laplacian_smoothing(deformed_mesh)
            loss_dict["normal_consistency_loss"] = mesh_normal_consistency(deformed_mesh)

            # Optional terms: computed only when their weight (lambda) is > 0;
            # otherwise stored as an integer 0-dim tensor on self.device.
            # TODO: remove weights?
            if self.img_sym_lam > 0:
                loss_dict["img_sym_loss"], _ = def_losses.image_symmetry_loss(deformed_mesh, sym_plane_normal, self.cfg["training"]["img_sym_num_azim"], self.device)
            else:
                loss_dict["img_sym_loss"] = torch.tensor(0).to(self.device)
            if self.vertex_sym_lam > 0:
                loss_dict["vertex_sym_loss"] = def_losses.vertex_symmetry_loss_fast(deformed_mesh, sym_plane_normal, self.device)
            else:
                loss_dict["vertex_sym_loss"] = torch.tensor(0).to(self.device)
            if self.semantic_dis_lam > 0:
                loss_dict["semantic_dis_loss"], _ = compute_sem_dis_loss(deformed_mesh, self.semantic_dis_loss_num_render, semantic_dis_net, self.device)
            else:
                loss_dict["semantic_dis_loss"] = torch.tensor(0).to(self.device)

            return loss_dict, deformed_mesh
def get_deform_verts(target_mesh, points_to_sample=5000, sphere_level=4, *,
                     device="cuda:0", learning_rate=0.01, num_iter=500,
                     w_chamfer=1.0, w_edge=0.05, w_normal=0.0005,
                     w_laplacian=0.005):
    """Fit a sphere mesh to ``target_mesh`` and return the learned vertex offsets.

    The source mesh is ``ico_sphere(sphere_level, device)``. Per-vertex offsets
    (initialised to zero) are optimised with Adam (``betas=(0.5, 0.999)``) for
    ``num_iter`` steps. The objective is a weighted sum of:

    - the chamfer distance between ``points_to_sample`` points sampled from
      each mesh,
    - edge-length loss,
    - normal-consistency loss,
    - uniform Laplacian smoothing loss.

    The weighted final-iteration losses are printed once at the end.

    All keyword-only arguments default to the values that were previously
    hard-coded, so existing calls behave unchanged. Passing e.g.
    ``device="cpu"`` allows use without a GPU.

    Returns:
        The optimised offset tensor (``requires_grad=True``), shaped like the
        sphere's packed vertices.
    """
    if isinstance(device, str):
        device = torch.device(device)

    src_mesh = ico_sphere(sphere_level, device)

    # One learnable 3-D offset per sphere vertex, starting from no deformation.
    deform_verts = torch.zeros(src_mesh.verts_packed().shape,
                               device=device,
                               requires_grad=True)

    optimizer = torch.optim.Adam([deform_verts],
                                 lr=learning_rate,
                                 betas=(0.5, 0.999))

    for _ in range(num_iter):
        optimizer.zero_grad()

        new_src_mesh = src_mesh.offset_verts(deform_verts)

        sample_trg = sample_points_from_meshes(target_mesh, points_to_sample)
        sample_src = sample_points_from_meshes(new_src_mesh, points_to_sample)

        loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)
        loss_edge = mesh_edge_loss(new_src_mesh)
        loss_normal = mesh_normal_consistency(new_src_mesh)
        loss_laplacian = mesh_laplacian_smoothing(new_src_mesh,
                                                  method="uniform")
        loss = loss_chamfer * w_chamfer + loss_edge * w_edge + loss_normal * w_normal + loss_laplacian * w_laplacian

        loss.backward()
        optimizer.step()

    # The loss names only exist once at least one iteration has run.
    if num_iter > 0:
        print(
            f"{datetime.now()} Loss Chamfer:{loss_chamfer * w_chamfer}, Loss Edge:{loss_edge * w_edge}, Loss Normal:{loss_normal * w_normal}, Loss Laplacian:{loss_laplacian * w_laplacian}"
        )

    return deform_verts
示例#9
0
    def run(self):
        """Deform ``self.src`` towards ``self.target``, reporting each step.

        Runs 2000 SGD steps (lr=1.0, momentum=0.9) on per-vertex offsets. The
        objective is chamfer distance (5000 points sampled from each mesh)
        + 1.0 * edge + 0.01 * normal consistency + 0.1 * uniform Laplacian.
        After every step the total loss is printed and the packed vertices of the
        current mesh are passed to ``self.backwarded.emit``.

        NOTE(review): ``device`` is not defined in this method; it is assumed to
        be a module-level global — confirm.
        """
        offsets = torch.full(self.src.verts_packed().shape, 0.0,
                             device=device, requires_grad=True)
        optimizer = torch.optim.SGD([offsets], lr=1.0, momentum=0.9)

        n_iterations = 2000
        w_chamfer, w_edge, w_normal, w_laplacian = 1.0, 1.0, 0.01, 0.1

        for _ in range(n_iterations):
            optimizer.zero_grad()
            current_mesh = self.src.offset_verts(offsets)

            target_points = sample_points_from_meshes(self.target, 5000)
            source_points = sample_points_from_meshes(current_mesh, 5000)

            chamfer, _ = chamfer_distance(target_points, source_points)
            edge = mesh_edge_loss(current_mesh)
            normal = mesh_normal_consistency(current_mesh)
            laplacian = mesh_laplacian_smoothing(current_mesh, method="uniform")

            # weighted sum of the losses
            total = (chamfer * w_chamfer + edge * w_edge
                     + normal * w_normal + laplacian * w_laplacian)

            total.backward()
            optimizer.step()
            print('total_loss = %.6f' % total)
            self.backwarded.emit(current_mesh.verts_packed())
示例#10
0
    def loss(self, data, epoch):
        """Compute the training loss and a dict of detached components.

        The loss is ``chamfer + ce + 0.1 * laplacian + edge + 0.1 * normal``:
        ``ce`` is the mean cross-entropy between ``pred[0][-1][3]`` and
        ``data['y_voxels']``. The mesh terms are summed over every entry of
        ``pred[c][1:]`` for classes ``0 .. num_classes - 2``. Each entry is
        compared with ``data['surface_points'][c]`` using 3000 sampled points.

        ``epoch`` is accepted but not used.

        Returns:
            ``(loss, log)`` where ``log`` holds detached copies of every term.
        """
        pred = self.forward(data)

        ce_loss = nn.CrossEntropyLoss()(pred[0][-1][3], data['y_voxels'])

        # Four independent zero accumulators (never aliased: they are updated in place).
        chamfer_loss, edge_loss, laplacian_loss, normal_consistency_loss = (
            torch.tensor(0).float().cuda() for _ in range(4)
        )

        for class_idx in range(self.config.num_classes - 1):
            target_points = data['surface_points'][class_idx].cuda()
            for vertices, faces, _, _, _ in pred[class_idx][1:]:
                mesh = Meshes(verts=list(vertices), faces=list(faces))
                sampled_points = sample_points_from_meshes(mesh, 3000)

                chamfer_loss += chamfer_distance(sampled_points, target_points)[0]
                laplacian_loss += mesh_laplacian_smoothing(mesh, method="uniform")
                normal_consistency_loss += mesh_normal_consistency(mesh)
                edge_loss += mesh_edge_loss(mesh)

        total_loss = (1 * chamfer_loss + 1 * ce_loss + 0.1 * laplacian_loss
                      + 1 * edge_loss + 0.1 * normal_consistency_loss)

        named_terms = (
            ("loss", total_loss),
            ("chamfer_loss", chamfer_loss),
            ("ce_loss", ce_loss),
            ("normal_consistency_loss", normal_consistency_loss),
            ("edge_loss", edge_loss),
            ("laplacian_loss", laplacian_loss),
        )
        log = {key: value.detach() for key, value in named_terms}
        return total_loss, log
    # NOTE(review): fragment of an optimisation-loop body. The enclosing loop
    # header and the body of the trailing `if` are outside this chunk.
    # `trg_mesh`, `new_src_mesh`, the `w_*` weights, `t`, `i`, `plot_period`
    # and the `*_losses` lists must be defined by the surrounding code.
    # We sample 5k points from the surface of each mesh
    sample_trg = sample_points_from_meshes(trg_mesh, 5000)
    sample_src = sample_points_from_meshes(new_src_mesh, 5000)

    # We compare the two sets of pointclouds by computing (a) the chamfer loss
    loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)

    # and (b) the edge length of the predicted mesh

    loss_edge = mesh_edge_loss(new_src_mesh)

    # mesh normal consistency
    loss_normal = mesh_normal_consistency(new_src_mesh)

    # mesh laplacian smoothing
    loss_laplacian = mesh_laplacian_smoothing(new_src_mesh, method="uniform")

    # Weighted sum of the losses
    loss = loss_chamfer * w_chamfer + loss_edge * w_edge + loss_normal * w_normal + loss_laplacian * w_laplacian

    # Show the total loss via `t.set_description`
    # (`t` is presumably a tqdm progress bar — confirm)
    t.set_description('total_loss = %.6f' % loss)

    # Save the losses for plotting
    # NOTE(review): values are appended as-is (no .item()/.detach()). If they are
    # autograd tensors, every iteration's graph is kept alive — confirm.
    chamfer_losses.append(loss_chamfer)
    edge_losses.append(loss_edge)
    normal_losses.append(loss_normal)
    laplacian_losses.append(loss_laplacian)
    losses.append(loss)
    # Plot mesh
    if i % plot_period == 0:
示例#12
0
    def compute_loss(self, batch, ep=None):
        """Compute the training loss for one batch.

        During pre-training (``ep < 16``) the loss is the mean L1 distance
        between the network's neighbour weights and ``self.init_weight``.

        Otherwise predicted vertices are the weighted sum of each vertex's
        neighbour positions, and the loss is:

        - a data term: ``data_loss`` for the 'Body' layer, chamfer distance to
          ``gt_verts`` otherwise,
        - plus 100 * uniform Laplacian smoothing,
        - plus a local-neighbourhood regulariser.

        Args:
            batch: Mapping with tensors under 'inp', 'gt_verts', 'betas',
                'pose' and 'trans'.
            ep: Current epoch. ``None`` (the default) is treated as "not in
                pre-training"; previously ``None < 16`` raised TypeError.

        Returns:
            ``(loss, loss_dict)``; ``loss_dict`` is currently always empty.
        """
        inp = batch.get('inp').to(self.device)
        gt_verts = batch.get('gt_verts').to(self.device)
        betas = batch.get('betas').to(self.device)
        pose = batch.get('pose').to(self.device)
        trans = batch.get('trans').to(self.device)
        weights_from_net = self.model(inp).view(self.batch_size,
                                                self.layer_size,
                                                self.num_neigh)
        weights_from_net = self.out_layer(weights_from_net)

        loss_dict = {}
        # The default ep=None would make `ep < 16` raise TypeError.
        pretrain = ep is not None and ep < 16
        if pretrain:
            loss = (weights_from_net - self.init_weight).abs().sum(-1).mean()
            return loss, loss_dict

        # Neighbour positions (xyz only) gathered for every output vertex.
        input_copy = inp[:, self.idx2, :3]
        # Weighted sum over the neighbour axis (dim 2). This is equivalent to
        # weighting x, y and z separately and stacking them.
        pred_verts = torch.sum(weights_from_net.unsqueeze(-1) * input_copy,
                               dim=2)

        # Local-neighbourhood regulariser: weight mass placed on neighbours far
        # from the currently highest-weighted neighbour is penalised.
        current_argmax = torch.argmax(weights_from_net, axis=2)
        idx = torch.stack([
            torch.index_select(self.layer_neigh, 1, current_argmax[i])[0]
            for i in range(self.batch_size)
        ])
        current_argmax_verts = torch.stack([
            torch.index_select(inp[i, :, :3], 0, idx[i])
            for i in range(self.batch_size)
        ])
        current_argmax_verts = torch.stack(
            [current_argmax_verts] * self.num_neigh, dim=2)
        # NOTE(review): the original author questioned whether the distance
        # should be measured against `input_copy` — confirm.
        dist_from_max = current_argmax_verts - input_copy
        dist_from_max = torch.sqrt(
            torch.sum(dist_from_max * dist_from_max, dim=3))
        local_regu = torch.sum(dist_from_max * weights_from_net) / (
            self.batch_size * self.num_neigh * self.layer_size)

        body_tmp = self.smpl.forward(beta=betas, theta=pose, trans=trans)

        if self.garment_layer == 'Body':
            # update body verts with prediction
            # NOTE(review): body_tmp is not used after this assignment.
            body_tmp[:, self.vert_indices, :] = pred_verts
            loss_data = data_loss(self.garment_layer, pred_verts,
                                  inp[:, self.vert_indices, :],
                                  self.geo_weights)
        else:
            loss_data, _ = chamfer_distance(pred_verts, gt_verts)

        # One copy of the garment faces per batch element.
        pred_mesh = Meshes(verts=pred_verts,
                           faces=self.garment_f_torch.unsqueeze(0).repeat(
                               self.batch_size, 1, 1))
        loss_lap = mesh_laplacian_smoothing(pred_mesh, method='uniform')

        # TODO: normal, edge and interpenetration losses are not included.
        loss = loss_data + 100. * loss_lap + local_regu
        return loss, loss_dict
示例#13
0
    def compute_loss(self, batch, ep=None):
        """Compute the size-transfer training loss for one batch.

        All 9 ordered pairs ``(i, j)`` over the three garment samples in
        ``batch`` (key suffixes 0, 1, 2) are stacked along dim 0, with sub-batch
        ``k = 3 * i + j``. For each pair:

        - the model gets garment ``i``'s vertices, size ``i``, desired size
          ``j`` and ``betas0``;
        - the predicted displacements are posed via ``self.smpl.forward`` with
          pose/trans ``j``;
        - the result is compared against garment ``j``'s vertices.

        Args:
            batch: Mapping with tensors 'gar_vert{0,1,2}', 'betas0',
                'pose{0,1,2}', 'trans{0,1,2}' and 'size{0,1,2}'.
            ep: Unused; kept for interface compatibility.

        Returns:
            ``(loss, loss_dict)`` where ``loss`` is chamfer distance + 100 *
            uniform Laplacian smoothing, and ``loss_dict`` is always empty.
        """
        def fetch(key):
            return batch.get(key).to(self.device)

        gar_verts = [fetch('gar_vert%d' % k) for k in range(3)]
        betas0 = fetch('betas0')
        poses = [fetch('pose%d' % k) for k in range(3)]
        transes = [fetch('trans%d' % k) for k in range(3)]
        sizes = [fetch('size%d' % k) for k in range(3)]

        # (source sample i, destination sample j) for sub-batch k = 3 * i + j.
        pairs = [(i, j) for i in range(3) for j in range(3)]
        inp_gar = torch.cat([gar_verts[i] for i, _ in pairs], dim=0)
        size_inp = torch.cat([sizes[i] for i, _ in pairs], dim=0)
        size_des = torch.cat([sizes[j] for _, j in pairs], dim=0)
        pose_all = torch.cat([poses[j] for _, j in pairs], dim=0)
        trans_all = torch.cat([transes[j] for _, j in pairs], dim=0)
        betas_feat = torch.cat([betas0] * len(pairs), dim=0)
        gt_verts = torch.cat([gar_verts[j] for _, j in pairs], dim=0)

        all_dist = self.model(inp_gar, size_inp, size_des, betas_feat)
        # TODO (original author): consider displacements in unposed space;
        # noted as problematic because of wrong correspondence.
        _, pred_verts = self.smpl.forward(beta=betas_feat,
                                          theta=pose_all,
                                          trans=trans_all,
                                          garment_class='t-shirt',
                                          garment_d=all_dist)

        # One face tensor per stacked mesh. This was previously repeated
        # `self.batch_size * 4` times, which does not match the 9 stacked
        # sub-batches; take the count from the predictions instead.
        num_meshes = len(pred_verts)
        faces = self.garment_f_torch.unsqueeze(0).repeat(num_meshes, 1, 1)
        pred_mesh = Meshes(verts=pred_verts, faces=faces)

        loss_data, _ = chamfer_distance(pred_verts, gt_verts)
        loss_lap = mesh_laplacian_smoothing(pred_mesh, method='uniform')
        loss_dict = {}
        loss = loss_data + 100. * loss_lap
        return loss, loss_dict
示例#14
0
    def refine_mesh(self, mesh, rgba_image, pred_dist, pred_elev, pred_azim, record_intermediate=False):
        '''
        Optimise a freshly built DeformationNetwork so that ``mesh``, offset by
        the network's predicted per-vertex deltas, matches ``rgba_image``.

        Each of ``self.num_iterations`` Adam steps minimises a weighted sum of:
        - a silhouette BCE loss (render at the given pose vs. the alpha mask),
        - an L2 penalty on the offsets,
        - Laplacian smoothing and normal consistency,
        - the optional image-symmetry, vertex-symmetry and
          semantic-discriminator losses (each only computed when its
          ``self.*_lam`` weight is > 0).

        Args:
        mesh: mesh to refine (``verts_packed`` and ``offset_verts`` are used)
        rgba_image (np int array, 224 x 224 x 4, rgba, 0-255); pixels with
            alpha > 0 form the target silhouette
        pred_dist (int)
        pred_elev (int)
        pred_azim (int)
        record_intermediate (bool): if True, keep the deformed mesh every 100
            iterations and at the final iteration

        Returns:
        (deformed_mesh, loss_info), or (list_of_meshes, loss_info) when
        record_intermediate is True. loss_info is a pandas DataFrame with one
        row of loss values per iteration. With zero iterations the input mesh
        is returned unchanged.
        '''

        # prep inputs used during training
        image = rgba_image[:,:,:3]
        image_in = torch.unsqueeze(torch.tensor(image/255, dtype=torch.float).permute(2,0,1),0).to(self.device)
        mask = rgba_image[:,:,3] > 0
        mask_gt = torch.tensor(mask, dtype=torch.float).to(self.device)
        pose_in = torch.unsqueeze(torch.tensor([pred_dist, pred_elev, pred_azim]),0).to(self.device)
        verts_in = torch.unsqueeze(mesh.verts_packed(),0).to(self.device)

        R, T = look_at_view_transform(pred_dist, pred_elev, pred_azim)
        num_vertices = mesh.verts_packed().shape[0]
        # L2 regulariser target: no deformation at all
        zero_deformation_tensor = torch.zeros((num_vertices, 3)).to(self.device)

        # prep network & optimizer
        deform_net = DeformationNetwork(self.cfg, num_vertices, self.device)
        deform_net.to(self.device)
        optimizer = optim.Adam(deform_net.parameters(), lr=self.cfg["training"]["learning_rate"])

        # optimizing
        # Rows are collected in a list and turned into a DataFrame once at the
        # end: DataFrame.append was removed in pandas 2.0 (and was quadratic).
        loss_records = []
        deformed_meshes = []
        # With zero iterations nothing is deformed; return the input mesh
        # instead of failing on an unbound name.
        deformed_mesh = mesh
        for i in tqdm(range(self.num_iterations)):
            deform_net.train()
            optimizer.zero_grad()

            # computing mesh deformation & its render at the input pose
            delta_v = deform_net(pose_in, image_in, verts_in)
            delta_v = delta_v.reshape((-1,3))
            deformed_mesh = mesh.offset_verts(delta_v)
            rendered_deformed_mesh = utils.render_mesh(deformed_mesh, R, T, self.device, img_size=224, silhouette=True)

            # computing losses
            l2_loss = F.mse_loss(delta_v, zero_deformation_tensor)
            lap_smoothness_loss = mesh_laplacian_smoothing(deformed_mesh)
            normal_consistency_loss = mesh_normal_consistency(deformed_mesh)

            sil_loss = F.binary_cross_entropy(rendered_deformed_mesh[0, :,:, 3], mask_gt)

            sym_plane_normal = [0,0,1]
            if self.img_sym_lam > 0:
                img_sym_loss, _ = def_losses.image_symmetry_loss(deformed_mesh, sym_plane_normal, self.img_sym_num_azim, self.device)
            else:
                img_sym_loss = torch.tensor(0).to(self.device)
            if self.vertex_sym_lam > 0:
                vertex_sym_loss = def_losses.vertex_symmetry_loss_fast(deformed_mesh, sym_plane_normal, self.device)
            else:
                vertex_sym_loss = torch.tensor(0).to(self.device)
            if self.semantic_dis_lam > 0:
                semantic_dis_loss, _ = self.semantic_loss_computer.compute_loss(deformed_mesh)
            else:
                semantic_dis_loss = torch.tensor(0).to(self.device)

            # optimization step on weighted losses
            total_loss = (sil_loss*self.sil_lam + l2_loss*self.l2_lam + lap_smoothness_loss*self.lap_lam +
                          normal_consistency_loss*self.normals_lam + img_sym_loss*self.img_sym_lam +
                          vertex_sym_loss*self.vertex_sym_lam + semantic_dis_loss*self.semantic_dis_lam)
            total_loss.backward()
            optimizer.step()

            # saving info
            loss_records.append({"iter":i, "sil_loss": sil_loss.item(), "l2_loss": l2_loss.item(),
                                 "lap_smoothness_loss":lap_smoothness_loss.item(),
                                 "normal_consistency_loss": normal_consistency_loss.item(),"img_sym_loss": img_sym_loss.item(),
                                 "vertex_sym_loss": vertex_sym_loss.item(), "semantic_dis_loss": semantic_dis_loss.item(),
                                 "total_loss": total_loss.item()})
            if record_intermediate and (i % 100 == 0 or i == self.num_iterations-1):
                print(i)
                deformed_meshes.append(deformed_mesh)

        loss_info = pd.DataFrame(loss_records)

        if record_intermediate:
            return deformed_meshes, loss_info
        return deformed_mesh, loss_info
            
        # NOTE(review): training-step fragment. As placed here it sits after the
        # `return` statements of the preceding method at the same indentation,
        # so it would never execute. `mesh_s`, `verts_deform`, `mesh_t`,
        # `args`, `optimizer_G` and `step` are not defined in this chunk.
        #----------------------------------------------------
        # Deform the mesh
        mesh_s_new = mesh_s.offset_verts(verts_deform)

        # Sample 5000 points from the surface of each mesh
        sample_t = sample_points_from_meshes(mesh_t, 5000)
        sample_s = sample_points_from_meshes(mesh_s_new, 5000)

        #----------------------------------------------------
        # Model update step
        #----------------------------------------------------
        # Compute the loss function (weights come from args.lambda_*)
        loss_chamfer, _ = chamfer_distance(sample_t, sample_s)
        loss_edge = mesh_edge_loss(mesh_s_new)
        loss_normal = mesh_normal_consistency(mesh_s_new)
        loss_laplacian = mesh_laplacian_smoothing(mesh_s_new, method="uniform")
        loss_G = args.lambda_chamfer * loss_chamfer + args.lambda_edge * loss_edge + args.lambda_normal * loss_normal + args.lambda_laplacian * loss_laplacian

        # Network update step: clear gradients, backpropagate, step
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        #====================================================
        # Display training progress
        #====================================================
        if (step == 0 or (step % args.n_diaplay_step == 0)):
            # lr (only the last param group's value remains in `lr`)
            for param_group in optimizer_G.param_groups:
                lr = param_group['lr']
示例#16
0
    def attack(self, data, target, label):
        """Attack on given data to target.

        Binary-searches a per-sample weight on a geometric regulariser while
        optimising per-vertex offsets of ``data``. The goal is that
        ``self.model``, applied to 1024 points sampled from the deformed meshes,
        predicts ``target``.

        Each binary step runs ``self.num_iter`` Adam iterations on
        ``adv_func(logits, target).mean()`` plus mean(current weight) * (chamfer
        + 0.2 * edge + 0.5 * uniform Laplacian). After the search, the meshes
        deformed by the final binary step's offsets are saved as .obj files
        under ``./p1_manifold_random_target01``.

        Args:
            data: victim meshes. The code calls ``verts_packed``,
                ``offset_verts`` and ``get_mesh_verts_faces`` on it and passes
                it to ``sample_points_from_meshes``, so it must be a batched mesh
                object (not a [B, num_vertices, 3] tensor).
            target (torch.LongTensor): target output, [B]
            label: true labels, [B]; only used in the saved file names.

        Returns:
            ``(o_bestdist, o_bestattack, success_num)``:
            - ``o_bestdist``: best point distance per sample;
            - ``o_bestattack``: best adversarial point sets as a [B, K, 3]
              array (K = 1024). Samples whose weight lower bound stayed 0 fall
              back to the last iteration's points;
            - ``success_num``: the number of samples with a positive lower bound.
        """
        B, K = len(data), 1024
        global bas
        data = data.cuda()
        label_val = target.detach().cpu().numpy()  # [B] target classes

        label = label.long().cuda().detach()
        label_true = label.detach().cpu().numpy()

        deform_ori = data.clone()

        # weight factor for budget regularization
        lower_bound = np.zeros((B, ))
        upper_bound = np.ones((B, )) * self.max_weight
        current_weight = np.ones((B, )) * self.init_weight

        # record best results in binary search
        o_bestdist = np.array([1e10] * B)
        o_bestscore = np.array([-1] * B)
        o_bestattack = np.zeros((B, 3, K))
        # Weights of the chamfer, mesh-edge and Laplacian-smoothing terms
        w_chamfer = 1.0
        w_edge = 0.2
        w_laplacian = 0.5

        # Progress is printed ~5 times per binary step; max(1, ...) avoids a
        # ZeroDivisionError when self.num_iter < 5.
        log_every = max(1, self.num_iter // 5)

        # perform binary search
        for binary_step in range(self.binary_step):
            deform_verts = torch.full(deform_ori.verts_packed().shape,
                                      0.000001,
                                      device='cuda:%s' % args.local_rank,
                                      requires_grad=True)
            ori_def = deform_verts.detach().clone()

            bestdist = np.array([1e10] * B)
            bestscore = np.array([-1] * B)
            opt = optim.Adam([deform_verts],
                             lr=self.attack_lr,
                             weight_decay=0.)

            adv_loss = torch.tensor(0.).cuda()

            total_time = 0.
            forward_time = 0.
            backward_time = 0.
            update_time = 0.

            # one step in binary search
            for iteration in range(self.num_iter):
                t1 = time.time()
                opt.zero_grad()
                new_defrom_mesh = deform_ori.offset_verts(deform_verts)

                # forward passing
                ori_data = sample_points_from_meshes(data, 1024)
                adv_pl = sample_points_from_meshes(new_defrom_mesh, 1024)
                adv_pl1 = adv_pl.transpose(1, 2).contiguous()
                logits = self.model(adv_pl1)  # [B, num_classes]
                if isinstance(logits, tuple):  # PointNet
                    logits = logits[0]

                t2 = time.time()
                forward_time += t2 - t1

                pred = torch.argmax(logits, dim=1)  # [B]
                success_num = (pred == target).sum().item()
                if iteration % log_every == 0:
                    print('Step {}, iteration {}, current_c {},success {}/{}\n'
                          'adv_loss: {:.4f}'.format(
                              binary_step, iteration,
                              torch.from_numpy(current_weight).mean(),
                              success_num, B, adv_loss.item()))
                # NOTE(review): ori_data and adv_pl are independent random
                # samplings, so this point-wise distance is only approximate.
                dist_val = torch.sqrt(torch.sum(
                    (adv_pl - ori_data) ** 2, dim=[1, 2])).\
                    detach().cpu().numpy()  # [B]
                pred_val = pred.detach().cpu().numpy()  # [B]
                input_val = adv_pl1.detach().cpu().numpy()  # [B, 3, K]

                # update best results (success = prediction equals target)
                for e, (dist, pred_e, target_e, adv_e) in \
                        enumerate(zip(dist_val, pred_val, label_val, input_val)):
                    if dist < bestdist[e] and pred_e == target_e:
                        bestdist[e] = dist
                        bestscore[e] = pred_e
                    if dist < o_bestdist[e] and pred_e == target_e:
                        o_bestdist[e] = dist
                        o_bestscore[e] = pred_e
                        o_bestattack[e] = adv_e

                t3 = time.time()
                # compute loss and backward
                adv_loss = self.adv_func(logits, target).mean()
                loss_chamfer, _ = chamfer_distance(ori_data, adv_pl)
                loss_edge = mesh_edge_loss(new_defrom_mesh)
                loss_laplacian = mesh_laplacian_smoothing(new_defrom_mesh,
                                                          method="uniform")

                loss = adv_loss + torch.from_numpy(current_weight).mean() * (
                    loss_chamfer * w_chamfer + loss_edge * w_edge +
                    loss_laplacian * w_laplacian)
                loss.backward()
                opt.step()

                deform_verts.data = self.clip(deform_verts.clone().detach(),
                                              ori_def)

                t4 = time.time()
                backward_time += t4 - t3
                total_time += t4 - t1

                if iteration % 100 == 0:
                    print(
                        'total time: {:.2f}, for: {:.2f}, '
                        'back: {:.6f}, update: {:.2f}, total loss: {:.6f}, chamfer loss: {:.6f}'
                        .format(total_time, forward_time, backward_time,
                                update_time, loss, loss_chamfer))
                    total_time = 0.
                    forward_time = 0.
                    backward_time = 0.
                    update_time = 0.
                    torch.cuda.empty_cache()

            # adjust weight factor
            for e, target_e in enumerate(label_val):
                if bestscore[e] == target_e and bestscore[e] != -1 and bestdist[
                        e] <= o_bestdist[e]:
                    # success
                    lower_bound[e] = max(lower_bound[e], current_weight[e])
                    current_weight[e] = (lower_bound[e] + upper_bound[e]) / 2.
                else:
                    # failure
                    upper_bound[e] = min(upper_bound[e], current_weight[e])
                    current_weight[e] = (lower_bound[e] + upper_bound[e]) / 2.

        bas += 1
        ## save the mesh
        new_defrom_mesh = deform_ori.offset_verts(deform_verts)
        out_dir = './p1_manifold_random_target01'
        # Make sure the output directory exists before writing files into it.
        os.makedirs(out_dir, exist_ok=True)
        for e1 in range(B):
            final_verts, final_faces = new_defrom_mesh.get_mesh_verts_faces(e1)
            final_obj = os.path.join(
                out_dir,
                'result_model%s_%s_%s_%s.obj' %
                (bas, e1, label_val[e1], label_true[e1]))
            save_obj(final_obj, final_verts, final_faces)

        fail_idx = (lower_bound == 0.)
        o_bestattack[fail_idx] = input_val[fail_idx]

        # return final results
        success_num = (lower_bound > 0.).sum()
        print('Successfully attack {}/{}'.format(success_num, B))
        return o_bestdist, o_bestattack.transpose((0, 2, 1)), success_num