def loss(self, src_mesh, src_verts):
    """Return the weighted sum of the enabled mesh-fitting losses.

    Each term is added only when ``self.consider_loss(name)`` is true and is
    scaled by ``self.loss_weights["w_<name>"]``.

    Args:
        src_mesh: predicted mesh passed to the edge, normal-consistency and
            Laplacian regularisers.
        src_verts: predicted vertices compared with ``self.target_verts`` via
            the Chamfer distance, and passed to the ARAP term.

    Returns:
        ``(total_loss, loss_chamfer)``. Both start at ``0``; ``loss_chamfer``
        stays ``0`` when the Chamfer term is disabled.
    """
    total = 0
    # Initialised so the return works when the Chamfer term is disabled
    # (otherwise the name would be unbound and raise UnboundLocalError).
    loss_chamfer = 0
    if self.consider_loss("chamfer"):
        # (a) Chamfer distance between target and predicted point sets.
        loss_chamfer, _ = chamfer_distance(self.target_verts, src_verts)
        total += self.loss_weights["w_chamfer"] * loss_chamfer
    if self.consider_loss("edge"):
        # (b) Edge-length regulariser of the predicted mesh.
        loss_edge = mesh_edge_loss(src_mesh)
        total += self.loss_weights["w_edge"] * loss_edge
    if self.consider_loss("normal"):
        loss_normal = mesh_normal_consistency(src_mesh)
        total += self.loss_weights["w_normal"] * loss_normal
    if self.consider_loss("laplacian"):
        loss_laplacian = mesh_laplacian_smoothing(src_mesh, method="uniform")
        # Weight the Laplacian term itself, not the normal-consistency term.
        total += self.loss_weights["w_laplacian"] * loss_laplacian
    if self.consider_loss("arap"):
        # One ARAP term per target mesh index.
        for n in range(len(self.target_meshes)):
            loss_arap = arap_loss(self.prev_mesh, self.prev_verts, src_verts, mesh_idx=n)
            total += self.loss_weights["w_arap"] * loss_arap
    return total, loss_chamfer
def forward(self, src_mesh):
    """Return the weighted sum of the enabled losses for ``src_mesh``.

    Target surface points (3000 per mesh) are re-sampled from
    ``self.target_meshes`` on every call, before any term is evaluated. The
    Chamfer term compares them with the padded vertices of ``src_mesh``. The
    edge, normal-consistency and uniform-Laplacian terms regularise
    ``src_mesh``. Each term is gated by ``self.consider_loss`` and scaled by
    its ``self.loss_weights`` entry.
    """
    total = 0
    target_samples = sample_points_from_meshes(self.target_meshes, 3000)

    if self.consider_loss("chamfer"):
        chamfer, _ = chamfer_distance(target_samples, src_mesh.verts_padded())
        total += self.loss_weights["w_chamfer"] * chamfer

    # (loss name, weight key, term) in the order they are accumulated.
    regularisers = (
        ("edge", "w_edge", lambda: mesh_edge_loss(src_mesh)),
        ("normal", "w_normal", lambda: mesh_normal_consistency(src_mesh)),
        ("laplacian", "w_laplacian",
         lambda: mesh_laplacian_smoothing(src_mesh, method="uniform")),
    )
    for name, weight_key, compute in regularisers:
        if self.consider_loss(name):
            term = compute()
            total += self.loss_weights[weight_key] * term
    return total
def get_loss(mesh, trg_mesh, w_chamfer, w_edge, w_normal, w_laplacian, n_points=5000):
    """Weighted mesh-fitting objective between ``mesh`` and ``trg_mesh``.

    ``n_points`` surface points are sampled from ``trg_mesh`` and then from
    ``mesh``. The result is
    ``chamfer * w_chamfer + edge * w_edge + normal * w_normal + laplacian * w_laplacian``,
    where the last three are regularisers of ``mesh`` (uniform Laplacian).
    """
    target_points = sample_points_from_meshes(trg_mesh, n_points)
    source_points = sample_points_from_meshes(mesh, n_points)
    chamfer, _ = chamfer_distance(target_points, source_points)

    weighted_terms = (
        chamfer * w_chamfer,
        mesh_edge_loss(mesh) * w_edge,
        mesh_normal_consistency(mesh) * w_normal,
        mesh_laplacian_smoothing(mesh, method="uniform") * w_laplacian,
    )
    # Left-to-right accumulation, matching a + b + c + d.
    total = weighted_terms[0]
    for term in weighted_terms[1:]:
        total = total + term
    return total
def loss(self, data, epoch):
    """Return the total training loss and a dict of detached components.

    ``epoch`` is accepted but not used.

    The voxel cross-entropy on ``pred[0][-1][3]`` is computed per element,
    multiplied by ``data['base_plane']`` and averaged. For each class
    ``c < num_classes - 1`` and each mesh stage in ``pred[c][1:]``, the
    following are accumulated:

    * Chamfer distance (3000 sampled points vs. ``data['surface_points'][c]``);
    * uniform Laplacian smoothing;
    * normal consistency;
    * edge loss.

    total = chamfer + ce + 0.1 * laplacian + edge + 0.1 * normal_consistency
    """
    pred = self.forward(data)

    voxel_weight = data['base_plane'].float().cuda()
    criterion = nn.CrossEntropyLoss(reduction='none')
    ce_loss = (criterion(pred[0][-1][3], data['y_voxels'].cuda()) * voxel_weight).mean()

    # Four independent scalar accumulators on the GPU.
    chamfer_loss, edge_loss, laplacian_loss, normal_consistency_loss = (
        torch.tensor(0).float().cuda() for _ in range(4)
    )

    for cls in range(self.config.num_classes - 1):
        target_points = data['surface_points'][cls].cuda()
        for vertices, faces, _, _, _ in pred[cls][1:]:
            stage_mesh = Meshes(verts=list(vertices), faces=list(faces))
            sampled = sample_points_from_meshes(stage_mesh, 3000)
            chamfer_loss += chamfer_distance(sampled, target_points)[0]
            laplacian_loss += mesh_laplacian_smoothing(stage_mesh, method="uniform")
            normal_consistency_loss += mesh_normal_consistency(stage_mesh)
            edge_loss += mesh_edge_loss(stage_mesh)

    total = (chamfer_loss + ce_loss + 0.1 * laplacian_loss + edge_loss
             + 0.1 * normal_consistency_loss)

    log = {
        "loss": total.detach(),
        "chamfer_loss": chamfer_loss.detach(),
        "ce_loss": ce_loss.detach(),
        "normal_consistency_loss": normal_consistency_loss.detach(),
        "edge_loss": edge_loss.detach(),
        "laplacian_loss": laplacian_loss.detach(),
    }
    return total, log
def update_mesh_shape_prior_losses(mesh, loss):
    """Store the shape-prior regularisers of ``mesh`` into the ``loss`` dict.

    Writes the keys ``"edge"`` (edge-length loss), ``"normal"`` (normal
    consistency) and ``"laplacian"`` (uniform Laplacian smoothing). Returns
    ``None``; ``loss`` is updated in place.
    """
    loss.update(
        edge=mesh_edge_loss(mesh),
        normal=mesh_normal_consistency(mesh),
        laplacian=mesh_laplacian_smoothing(mesh, method="uniform"),
    )
def forward(self, batch_size):
    """Build the textured, deformed template mesh and its regularisers.

    Returns:
        ``(meshes, laplacian_loss, flatten_loss)``:
        * ``meshes`` is the deformed template replicated ``batch_size`` times via ``extend``;
        * ``laplacian_loss`` is the uniform Laplacian smoothing loss;
        * ``flatten_loss`` is the normal-consistency loss.
        Both losses are computed on the single, un-extended mesh.
    """
    # Apply the per-vertex offsets to the template (this yields a mesh).
    offset_mesh = self.template_mesh.offset_verts(self.deform_verts)
    single = Meshes(
        verts=offset_mesh.verts_padded(),
        faces=offset_mesh.faces_padded(),
        textures=TexturesVertex(self.textures),
    )
    return (
        single.extend(batch_size),
        mesh_laplacian_smoothing(single, method="uniform"),
        mesh_normal_consistency(single),
    )
def refine_mesh_batched(self, deform_net, semantic_dis_net, mesh_verts_batch, img_batch, pose_batch, compute_losses=True):
    """Deform a mesh with ``deform_net`` and optionally compute refinement losses.

    Args:
        deform_net: called as ``deform_net(pose_batch, img_batch, mesh_verts_batch)``;
            its output is reshaped to per-vertex offsets of shape (-1, 3).
        semantic_dis_net: passed to ``compute_sem_dis_loss`` when
            ``self.semantic_dis_lam > 0``.
        mesh_verts_batch: vertices forwarded to ``deform_net``.
        img_batch: images forwarded to ``deform_net``.
        pose_batch: columns 0, 1 and 2 are read as distance, elevation and
            azimuth for ``look_at_view_transform``.
        compute_losses: if False, only the deformed mesh is returned.

    Returns:
        ``deformed_mesh`` when ``compute_losses`` is False, otherwise
        ``(loss_dict, deformed_mesh)``. Disabled optional terms are stored as
        an integer zero tensor on ``self.device``.

    NOTE(review): ``mesh`` and ``rgba_image`` are neither parameters nor locals
    of this method. Unless they exist as module-level globals, this raises
    NameError. They presumably should come from ``mesh_verts_batch`` /
    ``img_batch``; confirm against callers.
    """
    # computing mesh deformation
    delta_v = deform_net(pose_batch, img_batch, mesh_verts_batch)
    delta_v = delta_v.reshape((-1,3))
    deformed_mesh = mesh.offset_verts(delta_v)  # NOTE(review): `mesh` undefined here

    if not compute_losses:
        return deformed_mesh

    else:
        # prep inputs used to compute losses
        pred_dist = pose_batch[:,0]
        pred_elev = pose_batch[:,1]
        pred_azim = pose_batch[:,2]
        R, T = look_at_view_transform(pred_dist, pred_elev, pred_azim)
        # Silhouette target: alpha channel > 0. NOTE(review): `rgba_image` undefined here.
        mask = rgba_image[:,:,3] > 0
        mask_gt = torch.tensor(mask, dtype=torch.float).to(self.device)
        num_vertices = mesh.verts_packed().shape[0]
        # Reference for the L2 penalty that pulls offsets towards zero.
        zero_deformation_tensor = torch.zeros((num_vertices, 3)).to(self.device)
        sym_plane_normal = [0,0,1] # TODO: make this generalizable to other classes

        loss_dict = {}
        # computing losses
        rendered_deformed_mesh = utils.render_mesh(deformed_mesh, R, T, self.device, img_size=224, silhouette=True)
        # Only the first rendered image's alpha channel is compared with the mask.
        loss_dict["sil_loss"] = F.binary_cross_entropy(rendered_deformed_mesh[0, :,:, 3], mask_gt)
        loss_dict["l2_loss"] = F.mse_loss(delta_v, zero_deformation_tensor)
        loss_dict["lap_smoothness_loss"] = mesh_laplacian_smoothing(deformed_mesh)
        loss_dict["normal_consistency_loss"] = mesh_normal_consistency(deformed_mesh)

        # TODO: remove weights?
        # Optional terms are computed only when their weight (lambda) is positive.
        if self.img_sym_lam > 0:
            loss_dict["img_sym_loss"], _ = def_losses.image_symmetry_loss(deformed_mesh, sym_plane_normal, self.cfg["training"]["img_sym_num_azim"], self.device)
        else:
            loss_dict["img_sym_loss"] = torch.tensor(0).to(self.device)

        if self.vertex_sym_lam > 0:
            loss_dict["vertex_sym_loss"] = def_losses.vertex_symmetry_loss_fast(deformed_mesh, sym_plane_normal, self.device)
        else:
            loss_dict["vertex_sym_loss"] = torch.tensor(0).to(self.device)

        if self.semantic_dis_lam > 0:
            loss_dict["semantic_dis_loss"], _ = compute_sem_dis_loss(deformed_mesh, self.semantic_dis_loss_num_render, semantic_dis_net, self.device)
        else:
            loss_dict["semantic_dis_loss"] = torch.tensor(0).to(self.device)

        return loss_dict, deformed_mesh
def get_deform_verts(target_mesh, points_to_sample=5000, sphere_level=4, *,
                     device="cuda:0", learning_rate=0.01, num_iter=500,
                     w_chamfer=1.0, w_edge=0.05, w_normal=0.0005,
                     w_laplacian=0.005, verbose=True):
    """Fit an ico-sphere to ``target_mesh`` and return the per-vertex offsets.

    An ico-sphere with ``sphere_level`` subdivisions is deformed by optimising
    a zero-initialised offset tensor with Adam (betas 0.5, 0.999). Each
    iteration re-samples ``points_to_sample`` surface points from both meshes
    and minimises a weighted sum of four terms: Chamfer distance, edge length,
    normal consistency and uniform Laplacian smoothing.

    The keyword-only arguments were previously hard-coded. Their defaults
    reproduce the old behaviour, so ``device`` can now be set to e.g. "cpu".

    Returns:
        The optimised offset tensor (same shape as the sphere's packed
        vertices, ``requires_grad=True``).
    """
    device = torch.device(device)
    src_mesh = ico_sphere(sphere_level, device)
    deform_verts = torch.full(src_mesh.verts_packed().shape, 0.0, device=device, requires_grad=True)
    optimizer = torch.optim.Adam([deform_verts], lr=learning_rate, betas=(0.5, 0.999))

    for _ in range(num_iter):
        optimizer.zero_grad()
        new_src_mesh = src_mesh.offset_verts(deform_verts)
        sample_trg = sample_points_from_meshes(target_mesh, points_to_sample)
        sample_src = sample_points_from_meshes(new_src_mesh, points_to_sample)
        loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)
        loss_edge = mesh_edge_loss(new_src_mesh)
        loss_normal = mesh_normal_consistency(new_src_mesh)
        loss_laplacian = mesh_laplacian_smoothing(new_src_mesh, method="uniform")
        loss = loss_chamfer * w_chamfer + loss_edge * w_edge + loss_normal * w_normal + loss_laplacian * w_laplacian
        loss.backward()
        optimizer.step()
        if verbose:
            print(
                f"{datetime.now()} Loss Chamfer:{loss_chamfer * w_chamfer}, Loss Edge:{loss_edge * w_edge}, Loss Normal:{loss_normal * w_normal}, Loss Laplacian:{loss_laplacian * w_laplacian}"
            )
    return deform_verts
def run(self):
    """Fit ``self.src`` to ``self.target`` by SGD on per-vertex offsets.

    Runs 2000 iterations (lr 1.0, momentum 0.9). The loss is a weighted sum of
    Chamfer distance (5000 sampled points per mesh), edge length, normal
    consistency and uniform Laplacian smoothing. After every step it prints the
    total loss and emits the current deformed vertices through
    ``self.backwarded``. ``device`` is not defined in this method and is read
    from the enclosing scope.
    """
    offsets = torch.full(self.src.verts_packed().shape, 0.0, device=device, requires_grad=True)
    optimizer = torch.optim.SGD([offsets], lr=1.0, momentum=0.9)
    num_steps = 2000
    weight_chamfer, weight_edge, weight_normal, weight_laplacian = 1.0, 1.0, 0.01, 0.1

    for _step in range(num_steps):
        optimizer.zero_grad()
        deformed = self.src.offset_verts(offsets)
        target_points = sample_points_from_meshes(self.target, 5000)
        source_points = sample_points_from_meshes(deformed, 5000)
        chamfer, _ = chamfer_distance(target_points, source_points)
        edge = mesh_edge_loss(deformed)
        normal = mesh_normal_consistency(deformed)
        laplacian = mesh_laplacian_smoothing(deformed, method="uniform")
        total = (chamfer * weight_chamfer + edge * weight_edge
                 + normal * weight_normal + laplacian * weight_laplacian)
        total.backward()
        optimizer.step()
        print('total_loss = %.6f' % total)
        self.backwarded.emit(deformed.verts_packed())
def loss(self, data, epoch):
    """Return the total training loss and a dict of detached components.

    ``epoch`` is accepted but not used.

    The loss has two parts:

    * Cross-entropy on ``pred[0][-1][3]`` against ``data['y_voxels']``.
    * Accumulated over each class ``c < num_classes - 1`` and each mesh stage
      in ``pred[c][1:]``: Chamfer distance (3000 sampled points vs.
      ``data['surface_points'][c]``), uniform Laplacian smoothing, normal
      consistency and edge loss.

    total = chamfer + ce + 0.1 * laplacian + edge + 0.1 * normal_consistency
    """
    pred = self.forward(data)

    # Move the voxel labels to the GPU like every other target in this method.
    # Without this, CPU labels hit a device mismatch against GPU predictions.
    # It is a no-op if they are already on the GPU.
    CE_Loss = nn.CrossEntropyLoss()
    ce_loss = CE_Loss(pred[0][-1][3], data['y_voxels'].cuda())

    chamfer_loss = torch.tensor(0).float().cuda()
    edge_loss = torch.tensor(0).float().cuda()
    laplacian_loss = torch.tensor(0).float().cuda()
    normal_consistency_loss = torch.tensor(0).float().cuda()

    for c in range(self.config.num_classes - 1):
        target = data['surface_points'][c].cuda()
        for vertices, faces, _, _, _ in pred[c][1:]:
            pred_mesh = Meshes(verts=list(vertices), faces=list(faces))
            pred_points = sample_points_from_meshes(pred_mesh, 3000)
            chamfer_loss += chamfer_distance(pred_points, target)[0]
            laplacian_loss += mesh_laplacian_smoothing(pred_mesh, method="uniform")
            normal_consistency_loss += mesh_normal_consistency(pred_mesh)
            edge_loss += mesh_edge_loss(pred_mesh)

    loss = 1 * chamfer_loss + 1 * ce_loss + 0.1 * laplacian_loss + 1 * edge_loss + 0.1 * normal_consistency_loss

    log = {"loss": loss.detach(),
           "chamfer_loss": chamfer_loss.detach(),
           "ce_loss": ce_loss.detach(),
           "normal_consistency_loss": normal_consistency_loss.detach(),
           "edge_loss": edge_loss.detach(),
           "laplacian_loss": laplacian_loss.detach()}
    return loss, log
# We sample 5k points from the surface of each mesh sample_trg = sample_points_from_meshes(trg_mesh, 5000) sample_src = sample_points_from_meshes(new_src_mesh, 5000) # We compare the two sets of pointclouds by computing (a) the chamfer loss loss_chamfer, _ = chamfer_distance(sample_trg, sample_src) # and (b) the edge length of the predicted mesh loss_edge = mesh_edge_loss(new_src_mesh) # mesh normal consistency loss_normal = mesh_normal_consistency(new_src_mesh) # mesh laplacian smoothing loss_laplacian = mesh_laplacian_smoothing(new_src_mesh, method="uniform") # Weighted sum of the losses loss = loss_chamfer * w_chamfer + loss_edge * w_edge + loss_normal * w_normal + loss_laplacian * w_laplacian # Print the losses t.set_description('total_loss = %.6f' % loss) # Save the losses for plotting chamfer_losses.append(loss_chamfer) edge_losses.append(loss_edge) normal_losses.append(loss_normal) laplacian_losses.append(loss_laplacian) losses.append(loss) # Plot mesh if i % plot_period == 0:
def compute_loss(self, batch, ep=None):
    """Compute the neighbour-weight training loss for one batch.

    The network output is viewed as ``(batch_size, layer_size, num_neigh)``
    weights over input neighbours and passed through ``self.out_layer``.

    * Pre-training (``ep < 16``): L1 distance of the weights to
      ``self.init_weight``, summed over neighbours and averaged over the rest.
    * Otherwise: predicted vertices are the weighted sum of the neighbour
      positions ``inp[:, self.idx2, :3]``. The loss is the sum of:
      - a data term (``data_loss`` for the 'Body' layer, Chamfer distance
        to ``gt_verts`` otherwise);
      - 100 * uniform Laplacian smoothing of the predicted garment mesh;
      - a local-neighbourhood regulariser.

    Args:
        batch: mapping with 'inp', 'gt_verts', 'betas', 'pose' and 'trans'.
        ep: current epoch. ``None`` (the default) selects the full loss.
            Previously ``None < 16`` raised TypeError.

    Returns:
        ``(loss, loss_dict)``; ``loss_dict`` is currently always empty.
    """
    inp = batch.get('inp').to(self.device)
    gt_verts = batch.get('gt_verts').to(self.device)
    betas = batch.get('betas').to(self.device)
    pose = batch.get('pose').to(self.device)
    trans = batch.get('trans').to(self.device)

    weights_from_net = self.model(inp).view(self.batch_size, self.layer_size, self.num_neigh)
    weights_from_net = self.out_layer(weights_from_net)
    loss_dict = {}

    # Warm-up: for the first 16 epochs only pull the weights towards init_weight.
    pretrain = ep is not None and ep < 16
    if pretrain:
        loss = (weights_from_net - self.init_weight).abs().sum(-1).mean()
        return loss, loss_dict

    input_copy = inp[:, self.idx2, :3]
    pred_x = weights_from_net * input_copy[:, :, :, 0]
    pred_y = weights_from_net * input_copy[:, :, :, 1]
    pred_z = weights_from_net * input_copy[:, :, :, 2]
    # Weighted sum over the neighbour axis -> (batch, layer_size, 3).
    pred_verts = torch.sum(torch.stack((pred_x, pred_y, pred_z), axis=3), axis=2)

    # Local neighbourhood regulariser: penalise weight placed on neighbours far
    # from the current arg-max neighbour.
    current_argmax = torch.argmax(weights_from_net, axis=2)
    idx = torch.stack([
        torch.index_select(self.layer_neigh, 1, current_argmax[i])[0]
        for i in range(self.batch_size)
    ])
    current_argmax_verts = torch.stack([
        torch.index_select(inp[i, :, :3], 0, idx[i])
        for i in range(self.batch_size)
    ])
    current_argmax_verts = torch.stack(
        [current_argmax_verts for _ in range(self.num_neigh)], dim=2)
    dist_from_max = current_argmax_verts - input_copy  # TODO: should it be input copy??
    dist_from_max = torch.sqrt(torch.sum(dist_from_max * dist_from_max, dim=3))
    local_regu = torch.sum(dist_from_max * weights_from_net) / (
        self.batch_size * self.num_neigh * self.layer_size)

    body_tmp = self.smpl.forward(beta=betas, theta=pose, trans=trans)

    if self.garment_layer == 'Body':
        # update body verts with prediction
        body_tmp[:, self.vert_indices, :] = pred_verts
        # get skin cutout
        loss_data = data_loss(self.garment_layer, pred_verts, inp[:, self.vert_indices, :], self.geo_weights)
    else:
        loss_data, _ = chamfer_distance(pred_verts, gt_verts)

    # Mesh of the predicted vertices with the garment faces, one copy per batch item.
    pred_mesh = Meshes(verts=pred_verts,
                       faces=self.garment_f_torch.unsqueeze(0).repeat(self.batch_size, 1, 1))
    loss_lap = mesh_laplacian_smoothing(pred_mesh, method='uniform')

    loss = loss_data + 100. * loss_lap + local_regu
    return loss, loss_dict
def compute_loss(self, batch, ep=None):
    """Garment re-targeting loss over all 3 x 3 (source, destination) pairs.

    For each ordered pair (i, j) of the three samples in ``batch``, the model
    receives garment ``i`` with its size and destination size ``j``. The
    result is posed with ``pose{j}`` / ``trans{j}`` (shape ``betas0``) and
    compared with garment ``j``. The nine pairs are concatenated along the
    batch dimension in the order i-major, j-minor.

    loss = Chamfer(pred, gt) + 100 * uniform Laplacian smoothing(pred mesh)

    ``ep`` is accepted but not used. Returns ``(loss, loss_dict)`` with an
    empty ``loss_dict``.
    """
    def fetch(key):
        return batch.get(key).to(self.device)

    gar = [fetch('gar_vert0'), fetch('gar_vert1'), fetch('gar_vert2')]
    betas0 = fetch('betas0')
    poses = [fetch('pose0'), fetch('pose1'), fetch('pose2')]
    trans = [fetch('trans0'), fetch('trans1'), fetch('trans2')]
    sizes = [fetch('size0'), fetch('size1'), fetch('size2')]

    # (source i, destination j) pairs, i-major: (0,0), (0,1), (0,2), (1,0), ...
    pairs = [(i, j) for i in range(3) for j in range(3)]
    inp_gar = torch.cat([gar[i] for i, _ in pairs], dim=0)
    size_inp = torch.cat([sizes[i] for i, _ in pairs], dim=0)
    size_des = torch.cat([sizes[j] for _, j in pairs], dim=0)
    pose_all = torch.cat([poses[j] for _, j in pairs], dim=0)
    trans_all = torch.cat([trans[j] for _, j in pairs], dim=0)
    betas_feat = torch.cat([betas0] * len(pairs), dim=0)
    gt_verts = torch.cat([gar[j] for _, j in pairs], dim=0)

    all_dist = self.model(inp_gar, size_inp, size_des, betas_feat)
    # TODO: change this to displacement in unposed space (not really, because of wrong correspondence)
    _, pred_verts = self.smpl.forward(beta=betas_feat, theta=pose_all, trans=trans_all,
                                      garment_class='t-shirt', garment_d=all_dist)

    # One face set per predicted mesh. The batch holds 9 * batch_size meshes;
    # the previous hard-coded `batch_size * 4` did not match the vertex batch.
    faces = self.garment_f_torch.unsqueeze(0).repeat(len(pred_verts), 1, 1)
    pred_mesh = Meshes(verts=pred_verts, faces=faces)

    loss_data, _ = chamfer_distance(pred_verts, gt_verts)
    loss_lap = mesh_laplacian_smoothing(pred_mesh, method='uniform')
    loss_dict = {}
    loss = loss_data + 100. * loss_lap
    return loss, loss_dict
def refine_mesh(self, mesh, rgba_image, pred_dist, pred_elev, pred_azim, record_intermediate=False):
    '''Fit a fresh DeformationNetwork so that ``mesh`` matches ``rgba_image``.

    Args:
        mesh: mesh to deform (uses ``verts_packed`` / ``offset_verts``).
        rgba_image (np int array, 224 x 224 x 4, rgba, 0-255): RGB is fed to
            the network; alpha > 0 is the silhouette target.
        pred_dist (int)
        pred_elev (int)
        pred_azim (int): camera pose for ``look_at_view_transform``.
        record_intermediate (bool): if True, return the meshes from every
            100th iteration and the last iteration instead of only the final mesh.

    Each iteration minimises a weighted sum (weights ``self.*_lam``) of these
    terms:
    * silhouette BCE;
    * L2 on the offsets;
    * Laplacian smoothing;
    * normal consistency;
    * optional image-symmetry, vertex-symmetry and semantic-discriminator terms.

    Returns:
        ``(deformed_meshes, loss_info)`` if ``record_intermediate`` else
        ``(deformed_mesh, loss_info)``. ``loss_info`` is a pandas DataFrame with
        one row per iteration. With zero iterations, the input ``mesh`` is
        returned undeformed.
    '''
    # prep inputs used during training
    image = rgba_image[:, :, :3]
    image_in = torch.unsqueeze(torch.tensor(image / 255, dtype=torch.float).permute(2, 0, 1), 0).to(self.device)
    mask = rgba_image[:, :, 3] > 0
    mask_gt = torch.tensor(mask, dtype=torch.float).to(self.device)
    # NOTE(review): with int pose values (as documented) this tensor is int64;
    # confirm DeformationNetwork accepts or casts it.
    pose_in = torch.unsqueeze(torch.tensor([pred_dist, pred_elev, pred_azim]), 0).to(self.device)
    verts_in = torch.unsqueeze(mesh.verts_packed(), 0).to(self.device)
    R, T = look_at_view_transform(pred_dist, pred_elev, pred_azim)
    num_vertices = mesh.verts_packed().shape[0]
    zero_deformation_tensor = torch.zeros((num_vertices, 3)).to(self.device)
    sym_plane_normal = [0, 0, 1]  # loop-invariant, so computed once

    # prep network & optimizer
    deform_net = DeformationNetwork(self.cfg, num_vertices, self.device)
    deform_net.to(self.device)
    optimizer = optim.Adam(deform_net.parameters(), lr=self.cfg["training"]["learning_rate"])

    # optimizing
    loss_records = []
    deformed_meshes = []
    deformed_mesh = mesh  # returned as-is if num_iterations == 0
    for i in tqdm(range(self.num_iterations)):
        deform_net.train()
        optimizer.zero_grad()

        # computing mesh deformation & its render at the input pose
        delta_v = deform_net(pose_in, image_in, verts_in)
        delta_v = delta_v.reshape((-1, 3))
        deformed_mesh = mesh.offset_verts(delta_v)
        rendered_deformed_mesh = utils.render_mesh(deformed_mesh, R, T, self.device, img_size=224, silhouette=True)

        # computing losses
        l2_loss = F.mse_loss(delta_v, zero_deformation_tensor)
        lap_smoothness_loss = mesh_laplacian_smoothing(deformed_mesh)
        normal_consistency_loss = mesh_normal_consistency(deformed_mesh)
        sil_loss = F.binary_cross_entropy(rendered_deformed_mesh[0, :, :, 3], mask_gt)

        if self.img_sym_lam > 0:
            img_sym_loss, _ = def_losses.image_symmetry_loss(deformed_mesh, sym_plane_normal, self.img_sym_num_azim, self.device)
        else:
            img_sym_loss = torch.tensor(0).to(self.device)

        if self.vertex_sym_lam > 0:
            vertex_sym_loss = def_losses.vertex_symmetry_loss_fast(deformed_mesh, sym_plane_normal, self.device)
        else:
            vertex_sym_loss = torch.tensor(0).to(self.device)

        if self.semantic_dis_lam > 0:
            semantic_dis_loss, _ = self.semantic_loss_computer.compute_loss(deformed_mesh)
        else:
            semantic_dis_loss = torch.tensor(0).to(self.device)

        # optimization step on weighted losses
        total_loss = (sil_loss * self.sil_lam + l2_loss * self.l2_lam + lap_smoothness_loss * self.lap_lam +
                      normal_consistency_loss * self.normals_lam + img_sym_loss * self.img_sym_lam +
                      vertex_sym_loss * self.vertex_sym_lam + semantic_dis_loss * self.semantic_dis_lam)
        total_loss.backward()
        optimizer.step()

        # saving info
        loss_records.append({"iter": i, "sil_loss": sil_loss.item(), "l2_loss": l2_loss.item(),
                             "lap_smoothness_loss": lap_smoothness_loss.item(),
                             "normal_consistency_loss": normal_consistency_loss.item(),
                             "img_sym_loss": img_sym_loss.item(),
                             "vertex_sym_loss": vertex_sym_loss.item(),
                             "semantic_dis_loss": semantic_dis_loss.item(),
                             "total_loss": total_loss.item()})
        if record_intermediate and (i % 100 == 0 or i == self.num_iterations - 1):
            print(i)
            deformed_meshes.append(deformed_mesh)

    # Build the frame once. DataFrame.append was removed in pandas 2.0, and
    # appending row by row was quadratic in the number of iterations.
    loss_info = pd.DataFrame(loss_records)

    if record_intermediate:
        return deformed_meshes, loss_info
    return deformed_mesh, loss_info
#---------------------------------------------------- # メッシュの変形 mesh_s_new = mesh_s.offset_verts(verts_deform) # 各メッシュの表面から5000個の点をサンプリング sample_t = sample_points_from_meshes(mesh_t, 5000) sample_s = sample_points_from_meshes(mesh_s_new, 5000) #---------------------------------------------------- # モデルの更新処理 #---------------------------------------------------- # 損失関数を計算する loss_chamfer, _ = chamfer_distance(sample_t, sample_s) loss_edge = mesh_edge_loss(mesh_s_new) loss_normal = mesh_normal_consistency(mesh_s_new) loss_laplacian = mesh_laplacian_smoothing(mesh_s_new, method="uniform") loss_G = args.lambda_chamfer * loss_chamfer + args.lambda_edge * loss_edge + args.lambda_normal * loss_normal + args.lambda_laplacian * loss_laplacian # ネットワークの更新処理 optimizer_G.zero_grad() loss_G.backward() optimizer_G.step() #==================================================== # 学習過程の表示 #==================================================== if (step == 0 or (step % args.n_diaplay_step == 0)): # lr for param_group in optimizer_G.param_groups: lr = param_group['lr']
def attack(self, data, target, label):
    """Binary-search, C&W-style attack that deforms mesh vertices toward ``target``.

    The attack runs ``self.binary_step`` outer steps.

    * In each step, per-vertex offsets (initialised to 1e-6) are optimised
      with Adam for ``self.num_iter`` iterations. The objective is
      ``adv_func(logits, target).mean()`` plus the batch-mean penalty weight
      times the weighted Chamfer / edge / Laplacian terms.
    * After every optimiser step the offsets are passed through ``self.clip``.
    * At the end of each step, each sample's weight is bisected between
      ``lower_bound`` and ``upper_bound`` depending on whether it reached
      ``target``.

    Args:
        data: victim meshes. The code calls ``len``, ``.cuda()``, ``.clone()``,
            ``.verts_packed()``, ``.offset_verts()`` and
            ``sample_points_from_meshes`` on it. It is therefore a batched mesh
            object (presumably pytorch3d ``Meshes``), not a
            ``[B, num_vertices, 3]`` tensor.
        target (torch.LongTensor): target output, [B]
        label: true labels, [B]; only used to name the saved ``.obj`` files.

    Returns:
        ``(o_bestdist, o_bestattack, success_num)``:
        * ``o_bestdist``: best distances per sample, [B];
        * ``o_bestattack``: best adversarial point clouds transposed to
          [B, K, 3] with K = 1024;
        * ``success_num``: the number of samples with ``lower_bound > 0``.

    Side effects: increments the module-level counter ``bas`` and writes one
    ``.obj`` per sample under ``./p1_manifold_random_target01``. The device
    string uses a module-level ``args.local_rank``.

    NOTE(review): ``self.num_iter // 5`` is 0 when ``num_iter < 5``, which
    raises ZeroDivisionError. ``input_val`` is unbound at the end if
    ``num_iter`` or ``binary_step`` is 0.
    """
    B, K = len(data), 1024
    global bas
    data = data.cuda()
    label_val = target.detach().cpu().numpy()  # [B]

    label = label.long().cuda().detach()
    label_true = label.detach().cpu().numpy()

    deform_ori = data.clone()

    # weight factor for budget regularization
    lower_bound = np.zeros((B, ))
    upper_bound = np.ones((B, )) * self.max_weight
    current_weight = np.ones((B, )) * self.init_weight

    # record best results in binary search
    o_bestdist = np.array([1e10] * B)
    o_bestscore = np.array([-1] * B)
    o_bestattack = np.zeros((B, 3, K))
    # Weight for the chamfer loss
    w_chamfer = 1.0
    # Weight for mesh edge loss
    w_edge = 0.2
    # Weight for mesh laplacian smoothing
    w_laplacian = 0.5

    # perform binary search
    for binary_step in range(self.binary_step):
        # Fresh offsets every binary step; ori_def is the reference passed to self.clip.
        deform_verts = torch.full(deform_ori.verts_packed().shape, 0.000001, device='cuda:%s' % args.local_rank, requires_grad=True)
        ori_def = deform_verts.detach().clone()

        bestdist = np.array([1e10] * B)
        bestscore = np.array([-1] * B)
        dist_val = 0  # not read before being overwritten below
        opt = optim.Adam([deform_verts], lr=self.attack_lr, weight_decay=0.)
        # opt = optim.SGD([deform_verts], lr=1.0, momentum=0.9) #optim.Adam([deform_verts], lr=self.attack_lr, weight_decay=0.)

        adv_loss = torch.tensor(0.).cuda()
        dist_loss = torch.tensor(0.).cuda()  # unused in the visible code

        total_time = 0.
        forward_time = 0.
        backward_time = 0.
        update_time = 0.  # never incremented; always printed as 0

        # one step in binary search
        for iteration in range(self.num_iter):
            t1 = time.time()
            opt.zero_grad()
            new_defrom_mesh = deform_ori.offset_verts(deform_verts)

            # forward passing
            # NOTE(review): ori_data and adv_pl are independent random samplings,
            # so the pointwise `adv_pl - ori_data` below likely does not measure a
            # per-point displacement. Confirm this distance is intended.
            ori_data = sample_points_from_meshes(data, 1024)
            adv_pl = sample_points_from_meshes(new_defrom_mesh, 1024)
            adv_pl1 = adv_pl.transpose(1, 2).contiguous()
            logits = self.model(adv_pl1)  # [B, num_classes]
            if isinstance(logits, tuple):  # PointNet
                logits = logits[0]

            t2 = time.time()
            forward_time += t2 - t1

            pred = torch.argmax(logits, dim=1)  # [B]
            success_num = (pred == target).sum().item()

            if iteration % (self.num_iter // 5) == 0:
                print('Step {}, iteration {}, current_c {},success {}/{}\n'
                      'adv_loss: {:.4f}'.format(
                          binary_step, iteration, torch.from_numpy(current_weight).mean(), success_num, B, adv_loss.item()))
            dist_val = torch.sqrt(torch.sum(
                (adv_pl - ori_data) ** 2, dim=[1, 2])).\
                detach().cpu().numpy()  # [B]
            pred_val = pred.detach().cpu().numpy()  # [B]
            input_val = adv_pl1.detach().cpu().numpy()  # [B, 3, K]

            # update
            # `label` here is the per-sample target from label_val, and `pred`
            # shadows the tensor above; success means prediction == target.
            for e, (dist, pred, label, ii) in \
                    enumerate(zip(dist_val, pred_val, label_val, input_val)):
                if dist < bestdist[e] and pred == label:
                    bestdist[e] = dist
                    bestscore[e] = pred
                if dist < o_bestdist[e] and pred == label:
                    o_bestdist[e] = dist
                    o_bestscore[e] = pred
                    o_bestattack[e] = ii

            t3 = time.time()
            # compute loss and backward
            adv_loss = self.adv_func(logits, target).mean()
            loss_chamfer, _ = chamfer_distance(ori_data, adv_pl)
            loss_edge = mesh_edge_loss(new_defrom_mesh)
            loss_laplacian = mesh_laplacian_smoothing(new_defrom_mesh, method="uniform")
            # The per-sample weights are averaged into one scalar for the whole batch.
            loss = adv_loss + torch.from_numpy(current_weight).mean() * (
                loss_chamfer * w_chamfer + loss_edge * w_edge + loss_laplacian * w_laplacian)
            loss.backward()
            opt.step()

            # Project the offsets back with self.clip relative to the initial offsets.
            deform_verts.data = self.clip(deform_verts.clone().detach(), ori_def)

            t4 = time.time()
            backward_time += t4 - t3
            total_time += t4 - t1

            if iteration % 100 == 0:
                print(
                    'total time: {:.2f}, for: {:.2f}, '
                    'back: {:.6f}, update: {:.2f}, total loss: {:.6f}, chamfer loss: {:.6f}'
                    .format(total_time, forward_time, backward_time, update_time, loss, loss_chamfer))
                total_time = 0.
                forward_time = 0.
                backward_time = 0.
                update_time = 0.
                torch.cuda.empty_cache()

        # adjust weight factor
        for e, label in enumerate(label_val):
            if bestscore[e] == label and bestscore[e] != -1 and bestdist[
                    e] <= o_bestdist[e]:
                # success
                lower_bound[e] = max(lower_bound[e], current_weight[e])
                current_weight[e] = (lower_bound[e] + upper_bound[e]) / 2.
            else:
                # failure
                upper_bound[e] = min(upper_bound[e], current_weight[e])
                current_weight[e] = (lower_bound[e] + upper_bound[e]) / 2.

    bas += 1
    ## save the mesh
    # Uses the offsets left by the last binary-search step.
    new_defrom_mesh = deform_ori.offset_verts(deform_verts)
    for e1 in range(B):
        final_verts, final_faces = new_defrom_mesh.get_mesh_verts_faces(e1)
        final_obj = os.path.join(
            './p1_manifold_random_target01', 'result_model%s_%s_%s_%s.obj' % (bas, e1, label_val[e1], label_true[e1]))
        save_obj(final_obj, final_verts, final_faces)

    # Samples never attacked successfully fall back to the last adversarial sample.
    fail_idx = (lower_bound == 0.)
    o_bestattack[fail_idx] = input_val[fail_idx]

    # return final results
    success_num = (lower_bound > 0.).sum()
    print('Successfully attack {}/{}'.format(success_num, B))
    return o_bestdist, o_bestattack.transpose((0, 2, 1)), success_num