def forward(batch, data_features, network, conf,
            is_val=False, step=None, epoch=None, batch_ind=0, num_batch=1, start_time=0,
            log_console=False, log_tb=False, tb_writer=None, lr=None):
    """Pack shapes into one part-level batch, predict part masks, return the mask loss.

    Samples are taken from ``batch`` in order, starting with sample 0, until
    the running part count reaches 32; a sample that would push the count
    above 40 is skipped.  Every packed shape contributes one image copy per
    part.  The network output is matched against the ground-truth masks shape
    by shape with ``network.linear_assignment`` and scored with
    ``network.get_mask_loss``.

    Args:
        batch: list of per-feature lists; ``batch[data_features.index(name)][k]``
            is feature ``name`` of sample ``k``.
        data_features: list of feature names giving the layout of ``batch``.
        network: callable model that also provides ``linear_assignment`` and
            ``get_mask_loss``.
        conf: configuration object (``device``, ``loss_weight_mask`` and, for
            logging/visualisation, ``flog``, ``epochs``, ``exp_dir``, ...).
        is_val: True when called on validation data; enables visualisation.
        step: global step used for tensorboard scalars.
        epoch, batch_ind, num_batch, start_time, lr: values shown in the
            console log line.
        log_console: write the progress line through ``utils.printout``.
        log_tb: write scalars to ``tb_writer`` (if it is not None).
        tb_writer: tensorboard writer or None.

    Returns:
        ``None`` when ``batch`` is empty or when the first sample has exactly
        one part; otherwise the scalar tensor
        ``mean(per-shape mask loss) * conf.loss_weight_mask``.
    """
    if len(batch) == 0:
        return None

    def feature(name):
        # Per-sample list stored in `batch` for feature `name`.
        return batch[data_features.index(name)]

    def parts_of(name, sample_idx, part_cnt):
        # Per-part tensor of one sample, keeping only its first `part_cnt` rows.
        return feature(name)[sample_idx].squeeze(0)[:part_cnt]

    part_counts = feature('total_parts_cnt')
    cur_batch_size = len(part_counts)
    total_part_cnt = part_counts[0]
    if total_part_cnt == 1:
        print('passed an entire shape does not work for batch norm')
        return None

    img_chunks, pts_chunks, ins_chunks, similar_chunks = [], [], [], []
    gt_mask = []
    input_total_part_cnt = []

    def add_sample(sample_idx, part_cnt):
        # Append every per-part tensor of one shape to the packed batch.
        img_chunks.append(feature('img')[sample_idx].repeat(part_cnt, 1, 1, 1))
        pts_chunks.append(parts_of('pts', sample_idx, part_cnt))
        ins_chunks.append(parts_of('ins_one_hot', sample_idx, part_cnt))
        similar_chunks.append(parts_of('similar_parts_cnt', sample_idx, part_cnt))
        gt_mask.append(parts_of('mask', sample_idx, part_cnt).to(conf.device))
        input_total_part_cnt.append(part_counts[sample_idx])

    add_sample(0, part_counts[0])
    batch_index = 1
    while total_part_cnt < 32 and batch_index < cur_batch_size:
        cur_input_cnt = part_counts[batch_index]
        if total_part_cnt + cur_input_cnt > 40:
            # this shape would overflow the part budget: skip it
            batch_index += 1
            continue
        # rebind instead of `+=` so a tensor-valued count taken from `batch`
        # is never modified in place
        total_part_cnt = total_part_cnt + cur_input_cnt
        add_sample(batch_index, cur_input_cnt)
        batch_index += 1

    # one concatenation per feature instead of one per packed shape
    input_img = torch.cat(img_chunks, dim=0).to(conf.device)
    input_pts = torch.cat(pts_chunks, dim=0).to(conf.device)
    input_ins_one_hot = torch.cat(ins_chunks, dim=0).to(conf.device)
    input_similar_part_cnt = torch.cat(similar_chunks, dim=0).to(conf.device)
    batch_size = input_img.shape[0]

    # forward through the network
    pred_masks = network(input_img - 0.5, input_pts, input_ins_one_hot, input_total_part_cnt)

    # perform matching and calculate masks
    # buffer resolution follows the ground-truth masks instead of a fixed 224 x 224
    mask_hw = tuple(gt_mask[0].shape[-2:])
    matched_pred_mask_all = torch.zeros(batch_size, *mask_hw)
    matched_gt_mask_all = torch.zeros(batch_size, *mask_hw)
    mask_loss_per_data = []
    t = 0
    for i, total_cnt in enumerate(input_total_part_cnt):
        # the last predicted mask of each shape is excluded from the matching
        matched_gt_ids, matched_pred_ids = network.linear_assignment(
            gt_mask[i], pred_masks[i][:-1, :, :], input_similar_part_cnt[t:t + total_cnt])

        # select the matched data
        matched_pred_mask = pred_masks[i][matched_pred_ids]
        matched_gt_mask = gt_mask[i][matched_gt_ids]
        matched_gt_mask_all[t:t + total_cnt, :, :] = matched_gt_mask
        # the buffers are only read for visualisation, so keep them out of autograd
        matched_pred_mask_all[t:t + total_cnt, :, :] = matched_pred_mask.detach()

        # for computing mask soft iou loss
        matched_mask_loss = network.get_mask_loss(matched_pred_mask, matched_gt_mask)
        mask_loss_per_data.append(matched_mask_loss.mean())
        t += total_cnt

    # for each type of loss, compute avg loss per batch
    mask_loss = torch.stack(mask_loss_per_data).mean()

    # compute total loss
    total_loss = mask_loss * conf.loss_weight_mask

    # display information
    data_split = 'val' if is_val else 'train'

    with torch.no_grad():
        # log to console
        if log_console:
            utils.printout(conf.flog,
                f'''{strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} '''
                f'''{epoch:>5.0f}/{conf.epochs:<5.0f} '''
                f'''{data_split:^10s} '''
                f'''{batch_ind:>5.0f}/{num_batch:<5.0f} '''
                f'''{100. * (1+batch_ind+num_batch*epoch) / (num_batch*conf.epochs):>9.1f}% '''
                f'''{lr:>5.2E} '''
                f'''{mask_loss.item():>10.5f}'''
                f'''{total_loss.item():>10.5f}''')
            conf.flog.flush()

        # log to tensorboard
        if log_tb and tb_writer is not None:
            tb_writer.add_scalar('mask_loss', mask_loss.item(), step)
            tb_writer.add_scalar('total_loss', total_loss.item(), step)
            tb_writer.add_scalar('lr', lr, step)

        # gen visu
        if is_val and (not conf.no_visu) and epoch % conf.num_epoch_every_visu == 0:
            visu_dir = os.path.join(conf.exp_dir, 'val_visu')
            out_dir = os.path.join(visu_dir, 'epoch-%04d' % epoch)
            input_img_dir = os.path.join(out_dir, 'input_img')
            input_pts_dir = os.path.join(out_dir, 'input_pts')
            gt_mask_dir = os.path.join(out_dir, 'gt_mask')
            pred_mask_dir = os.path.join(out_dir, 'pred_mask')
            info_dir = os.path.join(out_dir, 'info')

            if batch_ind == 0:
                # create folders; tolerate a missing parent and an existing
                # epoch folder (e.g. when validation is re-run for an epoch)
                for folder in (out_dir, input_img_dir, input_pts_dir,
                               gt_mask_dir, pred_mask_dir, info_dir):
                    os.makedirs(folder, exist_ok=True)

            if batch_ind < conf.num_batch_every_visu:
                utils.printout(conf.flog, 'Visualizing ...')
                # NOTE(review): batch_size is the number of packed parts and can
                # differ between calls, so `batch_ind * batch_size + i` may
                # repeat a file name across batches - confirm.
                for i in range(batch_size):
                    fn = 'data-%03d.png' % (batch_ind * batch_size + i)

                    cur_input_img = (input_img[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    Image.fromarray(cur_input_img).save(os.path.join(input_img_dir, fn))

                    cur_input_pts = input_pts[i].cpu().numpy()
                    render_utils.render_pts(os.path.join(BASE_DIR, input_pts_dir, fn),
                                            cur_input_pts, blender_fn='object_centered.blend')

                    cur_gt_mask = (matched_gt_mask_all[i].cpu().numpy() > 0.5).astype(np.uint8) * 255
                    Image.fromarray(cur_gt_mask).save(os.path.join(gt_mask_dir, fn))

                    cur_pred_mask = (matched_pred_mask_all[i].cpu().numpy() > 0.5).astype(np.uint8) * 255
                    Image.fromarray(cur_pred_mask).save(os.path.join(pred_mask_dir, fn))

            if batch_ind == conf.num_batch_every_visu - 1:
                # visu html
                utils.printout(conf.flog, 'Generating html visualization ...')
                sublist = 'input_img,input_pts,gt_mask,pred_mask,info'
                cmd = 'cd %s && python %s . 10 htmls %s %s > /dev/null' % (
                    out_dir, os.path.join(BASE_DIR, '../utils/gen_html_hierachy_local.py'),
                    sublist, sublist)
                call(cmd, shell=True)
                utils.printout(conf.flog, 'DONE')

    return total_loss
2] = (pred_shape_mask[j] * 255) if len(gt_inds) != 0: gt_mask_to_vis[gt_inds[:, 0], gt_inds[:, 1], :] = (np.array(color) * 255) cur_gt_shape_mask_to_vis[gt_inds[:, 0], gt_inds[:, 1], :] = ( np.array(color) * 255) #if len(pred_inds) != 0: #pred_mask_to_vis[pred_inds[:,0], pred_inds[:,1], :] = (np.array(color) * 255) #cur_pred_shape_mask_to_vis[pred_inds[:,0], pred_inds[:,1], :] = (np.array(color) * 255) cur_input_pts = cur_shape_input_pts[j] render_utils.render_pts(os.path.join( BASE_DIR, child_input_pts_dir, child_fn), cur_input_pts, blender_fn='object_centered.blend') cur_gt_mask = cur_gt_shape_mask_to_vis.astype(np.uint8) Image.fromarray(cur_gt_mask).save( os.path.join(child_gt_mask_dir, child_fn + '.png')) cur_pred_mask = cur_pred_shape_mask_to_vis.astype(np.uint8) Image.fromarray(cur_pred_mask).save( os.path.join(child_pred_mask_dir, child_fn + '.png')) cur_pred_pts = pred_shape_to_vis[j] render_utils.render_pts(os.path.join( BASE_DIR, child_pred_pose_dir, child_fn), cur_pred_pts, blender_fn='camera_centered.blend') cur_gt_pts = gt_shape_to_vis[j] render_utils.render_pts(os.path.join( BASE_DIR, child_gt_pose_dir, child_fn),
def forward(batch, data_features, network, conf,
            is_val=False, step=None, epoch=None, batch_ind=0, num_batch=1, start_time=0,
            log_console=False, log_tb=False, tb_writer=None, lr=None):
    """Run the network on sub-sampled point clouds and return the result loss.

    The point clouds in ``batch`` are sub-sampled to
    ``conf.num_point_per_shape`` points each, using the indices returned by
    ``furthest_point_sample``; pixel ids and movable flags are gathered with
    the same indices.  The network prediction is scored against the ``result``
    labels with ``network.critic.get_ce_loss``.

    Args:
        batch: list of per-feature lists; ``batch[data_features.index(name)]``
            holds feature ``name`` for every sample.
        data_features: list of feature names giving the layout of ``batch``.
        network: callable model exposing ``critic.get_ce_loss``.
        conf: configuration object (``device``, ``num_point_per_shape`` and,
            for logging/visualisation, ``flog``, ``epochs``, ``exp_dir``, ...).
        is_val: True when called on validation data; enables visualisation.
        step: global step used for tensorboard scalars.
        epoch, batch_ind, num_batch, start_time, lr: values shown in the
            console log line.
        log_console: write the progress line through ``utils.printout``.
        log_tb: write scalars to ``tb_writer`` (if it is not None).
        tb_writer: tensorboard writer or None.

    Returns:
        Tuple ``(total_loss, pred_whole_feats, input_pcs, input_pxids,
        input_movables)``; everything but the loss is detached.
    """
    def stacked(name):
        # Concatenate the per-sample tensors of one feature and move them to the device.
        return torch.cat(batch[data_features.index(name)], dim=0).to(conf.device)

    # prepare input
    input_pcs = stacked('pcs')                  # B x 3N x 3
    input_pxids = stacked('pc_pxids')           # B x 3N x 2
    input_movables = stacked('pc_movables')     # B x 3N
    batch_size = input_pcs.shape[0]
    num_sampled = conf.num_point_per_shape

    # (row, column) index pairs selecting the sampled points of every cloud
    input_pcid1 = torch.arange(batch_size).unsqueeze(1).repeat(1, num_sampled).long().reshape(-1)  # BN
    input_pcid2 = furthest_point_sample(input_pcs, num_sampled).long().reshape(-1)                 # BN
    input_pcs = input_pcs[input_pcid1, input_pcid2, :].reshape(batch_size, num_sampled, -1)
    input_pxids = input_pxids[input_pcid1, input_pcid2, :].reshape(batch_size, num_sampled, -1)
    input_movables = input_movables[input_pcid1, input_pcid2].reshape(batch_size, num_sampled)

    input_dirs1 = stacked('gripper_direction_camera')           # B x 3
    input_dirs2 = stacked('gripper_forward_direction_camera')   # B x 3

    # forward through the network
    pred_result_logits, pred_whole_feats = network(input_pcs, input_dirs1, input_dirs2)  # B x 2, B x F x N

    # prepare gt
    gt_result = torch.Tensor(batch[data_features.index('result')]).long().to(conf.device)  # B
    gripper_img_target = stacked('gripper_img_target')  # B x 3 x H x W

    # for each type of loss, compute losses per data
    result_loss_per_data = network.critic.get_ce_loss(pred_result_logits, gt_result)

    # for each type of loss, compute avg loss per batch
    result_loss = result_loss_per_data.mean()

    # compute total loss
    total_loss = result_loss

    # display information
    data_split = 'val' if is_val else 'train'

    with torch.no_grad():
        # log to console
        if log_console:
            utils.printout(conf.flog,
                f'''{strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} '''
                f'''{epoch:>5.0f}/{conf.epochs:<5.0f} '''
                f'''{data_split:^10s} '''
                f'''{batch_ind:>5.0f}/{num_batch:<5.0f} '''
                f'''{100. * (1+batch_ind+num_batch*epoch) / (num_batch*conf.epochs):>9.1f}% '''
                f'''{lr:>5.2E} '''
                f'''{total_loss.item():>10.5f}''')
            conf.flog.flush()

        # log to tensorboard
        if log_tb and tb_writer is not None:
            tb_writer.add_scalar('total_loss', total_loss.item(), step)
            tb_writer.add_scalar('lr', lr, step)

        # gen visu
        if is_val and (not conf.no_visu) and epoch % conf.num_epoch_every_visu == 0:
            visu_dir = os.path.join(conf.exp_dir, 'val_visu')
            out_dir = os.path.join(visu_dir, 'epoch-%04d' % epoch)
            input_pc_dir = os.path.join(out_dir, 'input_pc')
            gripper_img_target_dir = os.path.join(out_dir, 'gripper_img_target')
            info_dir = os.path.join(out_dir, 'info')

            if batch_ind == 0:
                # create folders; tolerate a missing parent and an existing
                # epoch folder (e.g. when validation is re-run for an epoch)
                for folder in (out_dir, input_pc_dir, gripper_img_target_dir, info_dir):
                    os.makedirs(folder, exist_ok=True)

            if batch_ind < conf.num_batch_every_visu:
                utils.printout(conf.flog, 'Visualizing ...')
                for i in range(batch_size):
                    fn = 'data-%03d.png' % (batch_ind * batch_size + i)

                    render_utils.render_pts(os.path.join(BASE_DIR, input_pc_dir, fn),
                                            input_pcs[i].cpu().numpy(), highlight_id=0)

                    cur_gripper_img_target = (
                        gripper_img_target[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    Image.fromarray(cur_gripper_img_target).save(
                        os.path.join(gripper_img_target_dir, fn))

                    with open(os.path.join(info_dir, fn.replace('.png', '.txt')), 'w') as fout:
                        fout.write('cur_dir: %s\n' % batch[data_features.index('cur_dir')][i])
                        fout.write('pred: %s\n' % utils.print_true_false(
                            (pred_result_logits[i] > 0).cpu().numpy()))
                        fout.write('gt: %s\n' % utils.print_true_false(gt_result[i].cpu().numpy()))
                        fout.write('result_loss: %f\n' % result_loss_per_data[i].item())

            if batch_ind == conf.num_batch_every_visu - 1:
                # visu html
                utils.printout(conf.flog, 'Generating html visualization ...')
                sublist = 'input_pc,gripper_img_target,info'
                cmd = 'cd %s && python %s . 10 htmls %s %s > /dev/null' % (
                    out_dir, os.path.join(BASE_DIR, 'gen_html_hierachy_local.py'),
                    sublist, sublist)
                call(cmd, shell=True)
                utils.printout(conf.flog, 'DONE')

    return (total_loss, pred_whole_feats.detach(), input_pcs.detach(),
            input_pxids.detach(), input_movables.detach())
def forward(batch, data_features, network, conf,
            is_val=False, step=None, epoch=None, batch_ind=0, num_batch=1, start_time=0,
            log_console=False, log_tb=False, tb_writer=None, lr=None):
    """Pack shapes into one part-level batch, predict part poses, return the pose loss.

    Samples are taken from ``batch`` in order, starting with sample 0, until
    the running part count reaches 70; a sample that would push the count
    above 90 is skipped.  The network returns masks plus two (center, quat)
    predictions per part.  Predictions are matched to the ground truth shape
    by shape with ``network.linear_assignment``; center, quaternion, l2
    rotation and shape chamfer losses are each summed over the two
    predictions.

    Args:
        batch: list of per-feature lists; ``batch[data_features.index(name)][k]``
            is feature ``name`` of sample ``k``.
        data_features: list of feature names giving the layout of ``batch``.
        network: callable model that also provides ``linear_assignment`` and
            the ``get_*_loss`` methods used below.
        conf: configuration object (``device``, the ``loss_weight_*`` values
            and, for logging/visualisation, ``flog``, ``epochs``, ``exp_dir``, ...).
        is_val: True when called on validation data; enables visualisation.
        step: global step used for tensorboard scalars.
        epoch, batch_ind, num_batch, start_time, lr: values shown in the
            console log line.
        log_console: write the progress line through ``utils.printout``.
        log_tb: write scalars to ``tb_writer`` (if it is not None).
        tb_writer: tensorboard writer or None.

    Returns:
        ``None`` when ``batch`` is empty; otherwise the scalar tensor obtained
        by weighting the center, quat, l2 rotation and shape chamfer losses
        with the matching ``conf.loss_weight_*``.  The mask loss is logged but
        is not part of the returned value.
    """
    if len(batch) == 0:
        return None

    def feature(name):
        # Per-sample list stored in `batch` for feature `name`.
        return batch[data_features.index(name)]

    def parts_of(name, sample_idx, part_cnt):
        # Per-part tensor of one sample, keeping only its first `part_cnt` rows.
        return feature(name)[sample_idx].squeeze(0)[:part_cnt]

    part_counts = feature('total_parts_cnt')
    cur_batch_size = len(part_counts)
    total_part_cnt = part_counts[0]

    img_chunks, pts_chunks, ins_chunks, similar_chunks, dof_chunks = [], [], [], [], []
    input_shape_id, input_view_id = [], []
    gt_mask = []
    input_total_part_cnt = []
    input_similar_parts_edge_indices = []

    def add_sample(sample_idx, part_cnt):
        # Append every per-part tensor of one shape to the packed batch.
        img_chunks.append(feature('img')[sample_idx].repeat(part_cnt, 1, 1, 1))
        pts_chunks.append(parts_of('pts', sample_idx, part_cnt))
        ins_chunks.append(parts_of('ins_one_hot', sample_idx, part_cnt))
        similar_chunks.append(parts_of('similar_parts_cnt', sample_idx, part_cnt))
        input_shape_id.extend([feature('shape_id')[sample_idx]] * part_cnt)
        input_view_id.extend([feature('view_id')[sample_idx]] * part_cnt)
        # prepare gt
        dof_chunks.append(parts_of('parts_cam_dof', sample_idx, part_cnt))
        gt_mask.append(parts_of('mask', sample_idx, part_cnt).to(conf.device))
        input_total_part_cnt.append(part_counts[sample_idx])
        input_similar_parts_edge_indices.append(
            feature('similar_parts_edge_indices')[sample_idx].to(conf.device))

    add_sample(0, part_counts[0])
    batch_index = 1
    while total_part_cnt < 70 and batch_index < cur_batch_size:
        cur_input_cnt = part_counts[batch_index]
        if total_part_cnt + cur_input_cnt > 90:
            # this shape would overflow the part budget: skip it
            batch_index += 1
            continue
        # rebind instead of `+=` so a tensor-valued count taken from `batch`
        # is never modified in place
        total_part_cnt = total_part_cnt + cur_input_cnt
        add_sample(batch_index, cur_input_cnt)
        batch_index += 1

    # one concatenation per feature instead of one per packed shape
    input_img = torch.cat(img_chunks, dim=0).to(conf.device)
    input_pts = torch.cat(pts_chunks, dim=0).to(conf.device)
    input_similar_part_cnt = torch.cat(similar_chunks, dim=0).to(conf.device)
    input_ins_one_hot = torch.cat(ins_chunks, dim=0).to(conf.device)
    gt_cam_dof = torch.cat(dof_chunks, dim=0).to(conf.device)

    # prepare gt
    gt_center = gt_cam_dof[:, :3]   # B x 3
    gt_quat = gt_cam_dof[:, 3:]     # B x 4
    batch_size = input_img.shape[0]
    num_point = input_pts.shape[1]

    # forward through the network
    pred_masks, pred_center, pred_quat, pred_center2, pred_quat2 = network(
        input_img - 0.5, input_pts, input_ins_one_hot, input_total_part_cnt,
        input_similar_parts_edge_indices)

    mask_loss_per_data = []
    matched_pred_mask_all = []
    matched_gt_mask_all = []
    matched_mask_loss_per_data_all = []
    matched_pred_center_all = []
    matched_gt_center_all = []
    matched_pred_center2_all = []
    matched_pred_quat_all = []
    matched_gt_quat_all = []
    matched_pred_quat2_all = []
    matched_ins_onehot_all = []

    t = 0
    for i, total_cnt in enumerate(input_total_part_cnt):
        cur = slice(t, t + total_cnt)  # rows of shape i inside the packed batch
        matched_gt_ids, matched_pred_ids = network.linear_assignment(
            gt_mask[i], pred_masks[i],
            input_similar_part_cnt[cur], input_pts[cur], gt_center[cur],
            gt_quat[cur], pred_center[cur], pred_quat[cur])

        # select the matched data
        matched_pred_mask = pred_masks[i][matched_pred_ids]
        matched_gt_mask = gt_mask[i][matched_gt_ids]

        matched_ins_onehot_all.append(input_ins_one_hot[cur][matched_pred_ids])
        matched_gt_mask_all.append(matched_gt_mask)
        matched_pred_mask_all.append(matched_pred_mask)
        matched_pred_center_all.append(pred_center[cur][matched_pred_ids])
        matched_pred_center2_all.append(pred_center2[cur][matched_pred_ids])
        matched_gt_center_all.append(gt_center[cur][matched_gt_ids])
        matched_pred_quat_all.append(pred_quat[cur][matched_pred_ids])
        matched_pred_quat2_all.append(pred_quat2[cur][matched_pred_ids])
        matched_gt_quat_all.append(gt_quat[cur][matched_gt_ids])

        # for computing mask soft iou loss
        matched_mask_loss_per_data = network.get_mask_loss(matched_pred_mask, matched_gt_mask)
        matched_mask_loss_per_data_all.append(matched_mask_loss_per_data)
        mask_loss_per_data.append(matched_mask_loss_per_data.mean())
        t += total_cnt

    matched_ins_onehot_all = torch.cat(matched_ins_onehot_all, dim=0)
    matched_pred_mask_all = torch.cat(matched_pred_mask_all, dim=0)
    matched_gt_mask_all = torch.cat(matched_gt_mask_all, dim=0)
    matched_mask_loss_per_data_all = torch.cat(matched_mask_loss_per_data_all, dim=0)
    matched_pred_quat_all = torch.cat(matched_pred_quat_all, dim=0)
    matched_pred_quat2_all = torch.cat(matched_pred_quat2_all, dim=0)
    matched_gt_quat_all = torch.cat(matched_gt_quat_all, dim=0)
    matched_pred_center_all = torch.cat(matched_pred_center_all, dim=0)
    matched_pred_center2_all = torch.cat(matched_pred_center2_all, dim=0)
    matched_gt_center_all = torch.cat(matched_gt_center_all, dim=0)

    # every pose loss is the sum over the two predictions returned by the network
    center_loss_per_data = (
        network.get_center_loss(matched_pred_center_all, matched_gt_center_all)
        + network.get_center_loss(matched_pred_center2_all, matched_gt_center_all))

    quat_loss_per_data = (
        network.get_quat_loss(input_pts, matched_pred_quat_all, matched_gt_quat_all)
        + network.get_quat_loss(input_pts, matched_pred_quat2_all, matched_gt_quat_all))

    l2_rot_loss_per_data = (
        network.get_l2_rotation_loss(input_pts, matched_pred_quat_all, matched_gt_quat_all)
        + network.get_l2_rotation_loss(input_pts, matched_pred_quat2_all, matched_gt_quat_all))

    whole_shape_cd_per_data = (
        network.get_shape_chamfer_loss(
            input_pts, matched_pred_quat_all, matched_gt_quat_all,
            matched_pred_center_all, matched_gt_center_all, input_total_part_cnt)
        + network.get_shape_chamfer_loss(
            input_pts, matched_pred_quat2_all, matched_gt_quat_all,
            matched_pred_center2_all, matched_gt_center_all, input_total_part_cnt))

    # for each type of loss, compute avg loss per batch
    mask_loss_per_data = torch.stack(mask_loss_per_data)
    center_loss = center_loss_per_data.mean()
    quat_loss = quat_loss_per_data.mean()
    mask_loss = mask_loss_per_data.mean()
    l2_rot_loss = l2_rot_loss_per_data.mean()
    shape_chamfer_loss = whole_shape_cd_per_data.mean()

    # compute total loss (mask_loss is reported below but not optimised here)
    total_loss = \
        center_loss * conf.loss_weight_center + \
        quat_loss * conf.loss_weight_quat + \
        l2_rot_loss * conf.loss_weight_l2_rot + \
        shape_chamfer_loss * conf.loss_weight_shape_chamfer

    # display information
    data_split = 'val' if is_val else 'train'

    with torch.no_grad():
        # log to console
        if log_console:
            utils.printout(conf.flog,
                f'''{strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} '''
                f'''{epoch:>5.0f}/{conf.epochs:<5.0f} '''
                f'''{data_split:^10s} '''
                f'''{batch_ind:>5.0f}/{num_batch:<5.0f} '''
                f'''{100. * (1+batch_ind+num_batch*epoch) / (num_batch*conf.epochs):>9.1f}% '''
                f'''{lr:>5.2E} '''
                f'''{mask_loss.item():>10.5f}'''
                f'''{center_loss.item():>10.5f}'''
                f'''{quat_loss.item():>10.5f}'''
                f'''{l2_rot_loss.item():>10.5f}'''
                f'''{shape_chamfer_loss.item():>10.5f}'''
                f'''{total_loss.item():>10.5f}''')
            conf.flog.flush()

        # log to tensorboard
        if log_tb and tb_writer is not None:
            tb_writer.add_scalar('mask_loss', mask_loss.item(), step)
            tb_writer.add_scalar('center_loss', center_loss.item(), step)
            tb_writer.add_scalar('quat_loss', quat_loss.item(), step)
            tb_writer.add_scalar('l2 rotation_loss', l2_rot_loss.item(), step)
            tb_writer.add_scalar('shape_chamfer_loss', shape_chamfer_loss.item(), step)
            tb_writer.add_scalar('total_loss', total_loss.item(), step)
            tb_writer.add_scalar('lr', lr, step)

        # gen visu
        if is_val and (not conf.no_visu) and epoch % conf.num_epoch_every_visu == 0:
            visu_dir = os.path.join(conf.exp_dir, 'val_visu')
            out_dir = os.path.join(visu_dir, 'epoch-%04d' % epoch)
            input_img_dir = os.path.join(out_dir, 'input_img')
            input_pts_dir = os.path.join(out_dir, 'input_pts')
            gt_mask_dir = os.path.join(out_dir, 'gt_mask')
            pred_mask_dir = os.path.join(out_dir, 'pred_mask')
            gt_dof_dir = os.path.join(out_dir, 'gt_dof')
            pred_dof_dir = os.path.join(out_dir, 'pred_dof')
            pred_dof2_dir = os.path.join(out_dir, 'pred_dof2')
            info_dir = os.path.join(out_dir, 'info')

            if batch_ind == 0:
                # create folders; tolerate a missing parent and an existing
                # epoch folder (e.g. when validation is re-run for an epoch)
                for folder in (out_dir, input_img_dir, input_pts_dir, gt_mask_dir,
                               pred_mask_dir, gt_dof_dir, pred_dof_dir, pred_dof2_dir,
                               info_dir):
                    os.makedirs(folder, exist_ok=True)

            if batch_ind < conf.num_batch_every_visu:
                utils.printout(conf.flog, 'Visualizing ...')

                # compute pred_pts and gt_pts
                pred_pts = qrot(
                    matched_pred_quat_all.unsqueeze(1).repeat(1, num_point, 1),
                    input_pts) + matched_pred_center_all.unsqueeze(1).repeat(1, num_point, 1)
                pred2_pts = qrot(
                    matched_pred_quat2_all.unsqueeze(1).repeat(1, num_point, 1),
                    input_pts) + matched_pred_center2_all.unsqueeze(1).repeat(1, num_point, 1)
                gt_pts = qrot(
                    matched_gt_quat_all.unsqueeze(1).repeat(1, num_point, 1),
                    input_pts) + matched_gt_center_all.unsqueeze(1).repeat(1, num_point, 1)

                # NOTE(review): batch_size is the number of packed parts and can
                # differ between calls, so `batch_ind * batch_size + i` may
                # repeat a file name across batches - confirm.
                for i in range(batch_size):
                    fn = 'data-%03d.png' % (batch_ind * batch_size + i)

                    cur_input_img = (input_img[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    Image.fromarray(cur_input_img).save(os.path.join(input_img_dir, fn))

                    cur_input_pts = input_pts[i].cpu().numpy()
                    render_utils.render_pts(os.path.join(BASE_DIR, input_pts_dir, fn),
                                            cur_input_pts, blender_fn='object_centered.blend')

                    cur_gt_mask = (matched_gt_mask_all[i].cpu().numpy() > 0.5).astype(np.uint8) * 255
                    Image.fromarray(cur_gt_mask).save(os.path.join(gt_mask_dir, fn))

                    cur_pred_mask = (matched_pred_mask_all[i].cpu().numpy() > 0.5).astype(np.uint8) * 255
                    Image.fromarray(cur_pred_mask).save(os.path.join(pred_mask_dir, fn))

                    cur_pred_pts = pred_pts[i].cpu().numpy()
                    render_utils.render_pts(os.path.join(BASE_DIR, pred_dof_dir, fn),
                                            cur_pred_pts, blender_fn='camera_centered.blend')

                    cur_pred_pts = pred2_pts[i].cpu().numpy()
                    render_utils.render_pts(os.path.join(BASE_DIR, pred_dof2_dir, fn),
                                            cur_pred_pts, blender_fn='camera_centered.blend')

                    cur_gt_pts = gt_pts[i].cpu().numpy()
                    render_utils.render_pts(os.path.join(BASE_DIR, gt_dof_dir, fn),
                                            cur_gt_pts, blender_fn='camera_centered.blend')

                    with open(os.path.join(info_dir, fn.replace('.png', '.txt')), 'w') as fout:
                        fout.write('shape_id: %s, view_id: %s\n' % (
                            input_shape_id[i],
                            input_view_id[i]))
                        fout.write('ins onehot %s\n' % matched_ins_onehot_all[i])
                        fout.write('mask_loss: %f\n' % matched_mask_loss_per_data_all[i].item())
                        fout.write('center_loss: %f\n' % center_loss_per_data[i].item())
                        fout.write('quat_loss: %f\n' % quat_loss_per_data[i].item())
                        fout.write('l2_rot_loss: %f\n' % l2_rot_loss_per_data[i].item())
                        # batch-level mean, not a per-part value
                        fout.write('shape_chamfer_loss %f\n' % shape_chamfer_loss.item())

    return total_loss