def eval(_run, _log):
    cfg = edict(_run.config)
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    evaluation_dir = os.path.join('../evaluations_7_scenes_eval', str(_run._id))
    if not os.path.exists(evaluation_dir):
        os.makedirs(evaluation_dir)

    # build depth_network
    depth_network = depthNet(depth_scale=cfg.dataset.depth_scale)

    # load nets into gpu
    if cfg.num_gpus > 0 and torch.cuda.is_available():
        depth_network = torch.nn.DataParallel(depth_network)

    if cfg.resume_dir != 'None':
        print('resume training')
        checkpoint = torch.load(cfg.resume_dir)
        # NOTE: the state-dict loading below may need to be revised
        try:
            try:
                depth_network.load_state_dict(
                    checkpoint['depth_network_state_dict'])
            except Exception:
                # the checkpoint was saved from an nn.DataParallel model:
                # strip the `module.` prefix from every key before loading
                from collections import OrderedDict
                new_state_dict = OrderedDict()
                state_dict = checkpoint['depth_network_state_dict']
                for k, v in state_dict.items():
                    name = k[7:]  # remove `module.`
                    new_state_dict[name] = v
                # load params
                depth_network.load_state_dict(new_state_dict)
        except Exception:
            depth_network.load_state_dict(checkpoint['state_dict'])
        depth_network = depth_network.to(device)
    else:
        print("evaluation requires a checkpoint")

    if cfg.resume_dir == 'None':
        depth_network.to(device)

    depth2normal = Depth2normal(cfg.k_size)
    depth2normal.to(device)

    # data loader
    sevenScenes = LoadSevenScenes(cfg.dataset.root_dir)

    depth_network.eval()

    # main loop
    for scene, seq in sevenScenes.test_seqs_list:
        rgb_dir = os.path.join(evaluation_dir, scene, seq, 'rgb')
        gt_depth_dir = os.path.join(evaluation_dir, scene, seq, 'gt_depth')
        pred_depth_dir = os.path.join(evaluation_dir, scene, seq, 'pred_depth')
        pred_normal_dir = os.path.join(evaluation_dir, scene, seq, 'pred_normal')
        dirs = [rgb_dir, gt_depth_dir, pred_depth_dir, pred_normal_dir]
        for directory in dirs:
            if not os.path.exists(directory):
                os.makedirs(directory)

        filepaths_list = sevenScenes.get_filepaths(scene, seq)
        for index in range(0, len(filepaths_list) - 10, 1):
            if index % 3 != 0:
                continue
            print(scene, seq, index)
            ref_sample_path = filepaths_list[index]
            source_sample_path = filepaths_list[index + 10]

            ref_rgb, gt_depth, ref_cam, pred_depth_name = sevenScenes.load_sample(
                ref_sample_path, cfg.dataset.image_height, cfg.dataset.image_width)
            source_rgb, _, source_cam, _ = sevenScenes.load_sample(
                source_sample_path, cfg.dataset.image_height, cfg.dataset.image_width)

            ref_rgb = ref_rgb.to(device)
            source_rgb = source_rgb.to(device)
            c, h, w = ref_rgb.shape
            gt_depth = gt_depth.to(device)  # [h, w]
            ref_cam = ref_cam.to(device)
            source_cam = source_cam.to(device)

            depth_preds_list, _ = depth_network(ref_rgb.unsqueeze(0),
                                                source_rgb.unsqueeze(0),
                                                ref_cam.unsqueeze(0),
                                                source_cam.unsqueeze(0))
            depth_preds = depth_preds_list[0].squeeze(1)

            intrinsic = ref_cam[1, 0:3, 0:3]
            intrinsic_inv = torch.inverse(intrinsic)
            normal_from_depth, _ = depth2normal(depth_preds,
                                                intrinsic_inv.unsqueeze(0))

            # ================================================================== #
            #                        Tensorboard Logging                         #
            # ================================================================== #
            with torch.no_grad():
                info = {
                    'rgb': ref_rgb.permute(1, 2, 0).cpu().numpy(),
                    'normal_from_depth':
                        normal_from_depth.permute(0, 2, 3, 1).squeeze().cpu().numpy(),
                    'gt_depth': gt_depth.cpu().numpy(),
                    'pred_depth': depth_preds_list[0].squeeze().cpu().numpy()
                }

                rgb_filepath = os.path.join(
                    rgb_dir, pred_depth_name.replace("pred_depth", "color"))
                scipy.misc.imsave(rgb_filepath, info['rgb'])

                pred_normal_filepath = os.path.join(
                    pred_normal_dir,
                    pred_depth_name.replace("pred_depth.png", "pred_normal.npy"))
                np.save(pred_normal_filepath, info['normal_from_depth'])
                pred_normal_color = normal2color(info['normal_from_depth'])
                pred_normal_color_filepath = os.path.join(
                    pred_normal_dir,
                    pred_depth_name.replace("pred_depth", "pred_normal"))
                scipy.misc.imsave(pred_normal_color_filepath, pred_normal_color)

                gt_depth_filepath = os.path.join(
                    gt_depth_dir,
                    pred_depth_name.replace("pred_depth.png", "gt_depth.npy"))
                np.save(gt_depth_filepath, info['gt_depth'])
                gt_depth_color = depth2color(
                    info['gt_depth'], depth_scale=cfg.dataset.depth_scale)
                gt_depth_color_filepath = os.path.join(
                    gt_depth_dir,
                    pred_depth_name.replace("pred_depth", "gt_depth"))
                scipy.misc.imsave(gt_depth_color_filepath, gt_depth_color)

                pred_depth = info['pred_depth']
                pred_depth_filepath = os.path.join(
                    pred_depth_dir,
                    pred_depth_name.replace("pred_depth.png", "pred_depth.npy"))
                np.save(pred_depth_filepath, pred_depth)
                pred_depth_color = depth2color(
                    pred_depth, depth_scale=cfg.dataset.depth_scale)
                pred_depth_color_filepath = os.path.join(
                    pred_depth_dir, pred_depth_name)
                scipy.misc.imsave(pred_depth_color_filepath, pred_depth_color)
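# The DataParallel-checkpoint handling above is duplicated in each command.
# Below is a minimal, optional sketch of a helper that could factor it out;
# `_strip_module_prefix` is a hypothetical name, not part of the original code.
def _strip_module_prefix(state_dict):
    """Return a copy of `state_dict` with any leading `module.` removed from keys.

    Checkpoints saved from an nn.DataParallel model prefix every parameter
    name with `module.`; a plain (unwrapped) model expects the bare names.
    """
    from collections import OrderedDict
    stripped = OrderedDict()
    for k, v in state_dict.items():
        stripped[k[7:] if k.startswith('module.') else k] = v
    return stripped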
def eval_refine_seven_views(_run, _log):
    cfg = edict(_run.config)
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    evaluation_dir = os.path.join('../evaluations_7_scenes_refine_seven_views',
                                  str(_run._id))
    if not os.path.exists(evaluation_dir):
        os.makedirs(evaluation_dir)

    # build depth_network and depth_refine_network
    depth_network = depthNet(depth_scale=cfg.dataset.depth_scale)
    depth_refine_network = DepthRefineNet(depth_scale=cfg.dataset.depth_scale)

    if cfg.resume_dir != 'None':
        print('resume training')
        checkpoint = torch.load(cfg.resume_dir)
        # NOTE: the state-dict loading below may need to be revised
        try:
            depth_network.load_state_dict(
                checkpoint['depth_network_state_dict'])
        except Exception:
            # the checkpoint was saved from an nn.DataParallel model:
            # strip the `module.` prefix from every key before loading
            from collections import OrderedDict
            new_state_dict = OrderedDict()
            state_dict = checkpoint['depth_network_state_dict']
            for k, v in state_dict.items():
                name = k[7:]  # remove `module.`
                new_state_dict[name] = v
            # load params
            depth_network.load_state_dict(new_state_dict)

        try:
            depth_refine_network.load_state_dict(
                checkpoint['depth_refine_network_state_dict'])
        except Exception:
            # same handling for the refine network saved with nn.DataParallel
            from collections import OrderedDict
            refine_state_dict = OrderedDict()
            state_dict = checkpoint['depth_refine_network_state_dict']
            for k, v in state_dict.items():
                name = k[7:]  # remove `module.`
                refine_state_dict[name] = v
            # load params
            depth_refine_network.load_state_dict(refine_state_dict)
    else:
        print("evaluation requires a checkpoint")

    # load nets into gpu
    if cfg.num_gpus > 1 and torch.cuda.is_available():
        depth_network = torch.nn.DataParallel(depth_network)

    depth2normal = Depth2normal(cfg.k_size)
    depth2normal.to(device)
    depth_network.to(device)
    depth_refine_network.to(device)

    # data loader
    sevenScenes = LoadSevenScenes(cfg.dataset.root_dir)

    depth_network.eval()
    depth_refine_network.eval()

    # main loop
    for scene, seq in sevenScenes.test_seqs_list:
        rgb_dir = os.path.join(evaluation_dir, scene, seq, 'rgb')
        gt_depth_dir = os.path.join(evaluation_dir, scene, seq, 'gt_depth')
        pred_depth_dir = os.path.join(evaluation_dir, scene, seq, 'pred_depth')
        pred_normal_dir = os.path.join(evaluation_dir, scene, seq, 'pred_normal')
        prob_map_dir = os.path.join(evaluation_dir, scene, seq, 'prob_map')
        dirs = [
            rgb_dir, gt_depth_dir, pred_depth_dir, pred_normal_dir, prob_map_dir
        ]
        for directory in dirs:
            if not os.path.exists(directory):
                os.makedirs(directory)

        filepaths_list = sevenScenes.get_filepaths(scene, seq)
        for index in range(10, len(filepaths_list) - 20, 1):
            if index % 3 != 0:
                continue
            print(scene, seq, index)
            ref_sample_path = filepaths_list[index]
            # experiments show that the frame offsets
            # (20, 10, 5, 0, -5, -10, -20) work best
            source_1_sample_path = filepaths_list[index + 10]
            source_2_sample_path = filepaths_list[index - 10]
            source_3_sample_path = filepaths_list[index + 5]
            source_4_sample_path = filepaths_list[index - 5]
            source_5_sample_path = filepaths_list[index + 20]
            source_6_sample_path = filepaths_list[index - 20]

            # the camera may be invalid
            try:
                ref_rgb, gt_depth, ref_cam, pred_depth_name = sevenScenes.load_sample(
                    ref_sample_path, cfg.dataset.image_height,
                    cfg.dataset.image_width)
                source_1_rgb, _, source_1_cam, _ = sevenScenes.load_sample(
                    source_1_sample_path, cfg.dataset.image_height,
                    cfg.dataset.image_width)
                source_2_rgb, _, source_2_cam, _ = sevenScenes.load_sample(
                    source_2_sample_path, cfg.dataset.image_height,
                    cfg.dataset.image_width)
                source_3_rgb, _, source_3_cam, _ = sevenScenes.load_sample(
                    source_3_sample_path, cfg.dataset.image_height,
                    cfg.dataset.image_width)
                source_4_rgb, _, source_4_cam, _ = sevenScenes.load_sample(
                    source_4_sample_path, cfg.dataset.image_height,
                    cfg.dataset.image_width)
                source_5_rgb, _, source_5_cam, _ = sevenScenes.load_sample(
                    source_5_sample_path, cfg.dataset.image_height,
                    cfg.dataset.image_width)
                source_6_rgb, _, source_6_cam, _ = sevenScenes.load_sample(
                    source_6_sample_path, cfg.dataset.image_height,
                    cfg.dataset.image_width)
            except Exception:
                print("invalid camera")
                continue

            ref_rgb = ref_rgb.to(device)
            source_1_rgb = source_1_rgb.to(device)
            source_2_rgb = source_2_rgb.to(device)
            source_3_rgb = source_3_rgb.to(device)
            source_4_rgb = source_4_rgb.to(device)
            source_5_rgb = source_5_rgb.to(device)
            source_6_rgb = source_6_rgb.to(device)
            c, h, w = ref_rgb.shape
            gt_depth = gt_depth.to(device)  # [h, w]
            ref_cam = ref_cam.to(device)
            source_1_cam = source_1_cam.to(device)
            source_2_cam = source_2_cam.to(device)
            source_3_cam = source_3_cam.to(device)
            source_4_cam = source_4_cam.to(device)
            source_5_cam = source_5_cam.to(device)
            source_6_cam = source_6_cam.to(device)

            idepth_preds_01, iconv_01 = depth_network(
                ref_rgb.unsqueeze(0), source_1_rgb.unsqueeze(0),
                ref_cam.unsqueeze(0), source_1_cam.unsqueeze(0))
            idepth_preds_02, iconv_02 = depth_network(
                ref_rgb.unsqueeze(0), source_2_rgb.unsqueeze(0),
                ref_cam.unsqueeze(0), source_2_cam.unsqueeze(0))
            idepth_preds_03, iconv_03 = depth_network(
                ref_rgb.unsqueeze(0), source_3_rgb.unsqueeze(0),
                ref_cam.unsqueeze(0), source_3_cam.unsqueeze(0))
            idepth_preds_04, iconv_04 = depth_network(
                ref_rgb.unsqueeze(0), source_4_rgb.unsqueeze(0),
                ref_cam.unsqueeze(0), source_4_cam.unsqueeze(0))
            idepth_preds_05, iconv_05 = depth_network(
                ref_rgb.unsqueeze(0), source_5_rgb.unsqueeze(0),
                ref_cam.unsqueeze(0), source_5_cam.unsqueeze(0))
            idepth_preds_06, iconv_06 = depth_network(
                ref_rgb.unsqueeze(0), source_6_rgb.unsqueeze(0),
                ref_cam.unsqueeze(0), source_6_cam.unsqueeze(0))

            idepth_refined, prob_map = depth_refine_network(
                idepth01=(idepth_preds_01[0] + idepth_preds_03[0] +
                          idepth_preds_05[0]) / 3.,
                idepth02=(idepth_preds_02[0] + idepth_preds_04[0] +
                          idepth_preds_06[0]) / 3.,
                iconv01=(iconv_01 + iconv_03 + iconv_05) / 3.,
                iconv02=(iconv_02 + iconv_04 + iconv_06) / 3.)
            depth_preds = idepth_refined.squeeze(1)

            intrinsic = ref_cam[1, 0:3, 0:3]
            intrinsic_inv = torch.inverse(intrinsic)
            normal_from_depth, _ = depth2normal(depth_preds,
                                                intrinsic_inv.unsqueeze(0))

            # ================================================================== #
            #                        Tensorboard Logging                         #
            # ================================================================== #
            with torch.no_grad():
                info = {
                    'rgb': ref_rgb.permute(1, 2, 0).cpu().numpy(),
                    'normal_from_depth':
                        normal_from_depth.permute(0, 2, 3, 1).squeeze().cpu().numpy(),
                    'gt_depth': gt_depth.cpu().numpy(),
                    'pred_idepth': idepth_refined.squeeze().cpu().numpy(),
                    'prob_map': prob_map.squeeze().cpu().numpy()
                }

                rgb_filepath = os.path.join(
                    rgb_dir, pred_depth_name.replace("pred_depth", "color"))
                scipy.misc.imsave(rgb_filepath, info['rgb'])

                pred_normal_filepath = os.path.join(
                    pred_normal_dir,
                    pred_depth_name.replace("pred_depth.png", "pred_normal.npy"))
                np.save(pred_normal_filepath, info['normal_from_depth'])
                pred_normal_color = normal2color(info['normal_from_depth'])
                pred_normal_color_filepath = os.path.join(
                    pred_normal_dir,
                    pred_depth_name.replace("pred_depth", "pred_normal"))
                scipy.misc.imsave(pred_normal_color_filepath, pred_normal_color)

                gt_depth_filepath = os.path.join(
                    gt_depth_dir,
                    pred_depth_name.replace("pred_depth.png", "gt_depth.npy"))
                np.save(gt_depth_filepath, info['gt_depth'])
                gt_depth_color = depth2color(
                    info['gt_depth'], depth_scale=cfg.dataset.depth_scale)
                gt_depth_color_filepath = os.path.join(
                    gt_depth_dir,
                    pred_depth_name.replace("pred_depth", "gt_depth"))
                scipy.misc.imsave(gt_depth_color_filepath, gt_depth_color)

                pred_depth = info['pred_idepth']
                pred_depth_filepath = os.path.join(
                    pred_depth_dir,
                    pred_depth_name.replace("pred_depth.png", "pred_depth.npy"))
                np.save(pred_depth_filepath, pred_depth)
                pred_depth_color = depth2color(
                    pred_depth, depth_scale=cfg.dataset.depth_scale)
                pred_depth_color_filepath = os.path.join(
                    pred_depth_dir, pred_depth_name)
                scipy.misc.imsave(pred_depth_color_filepath, pred_depth_color)

                prob_map_color = colorize_probmap(info['prob_map'])
                prob_map_color_filepath = os.path.join(
                    prob_map_dir,
                    pred_depth_name.replace("pred_depth.png", "prob_map.png"))
                scipy.misc.imsave(prob_map_color_filepath, prob_map_color)
                prob_map_filepath = os.path.join(
                    prob_map_dir,
                    pred_depth_name.replace("pred_depth.png", "prob_map.npy"))
                np.save(prob_map_filepath, info['prob_map'])
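# eval_refine_seven_views fuses the per-pair predictions by simple averaging:
# the "forward" offsets (+10, +5, +20) feed one refinement input and the
# "backward" offsets (-10, -5, -20) feed the other. A minimal sketch of that
# fusion, assuming the predictions are same-shaped tensors; the helper name
# `_average_predictions` is hypothetical and not used by the original code.
def _average_predictions(preds):
    """Average a list of same-shaped tensors (inverse-depth maps or feature maps)."""
    return sum(preds) / float(len(preds))

# e.g. the refinement input above could equivalently be written as
#   idepth01=_average_predictions([idepth_preds_01[0], idepth_preds_03[0], idepth_preds_05[0]])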
def train(_run, _log):
    cfg = edict(_run.config)
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    checkpoint_dir = os.path.join('../experiments', str(_run._id), 'checkpoints')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    logger_dir = os.path.join('../experiments', str(_run._id), 'log')
    if not os.path.exists(logger_dir):
        os.mkdir(logger_dir)

    # logger
    logger = Logger(logger_dir)

    # build depth_network and depth_refine_network
    depth_network = depthNet(idepth_scale=cfg.idepth_scale)
    # for p in depth_network.parameters():
    #     p.requires_grad = False
    depth_refine_network = DepthRefineNet(idepth_scale=cfg.idepth_scale)

    # set up optimizers
    network_params = list(depth_refine_network.parameters()) + list(depth_network.parameters())
    optimizer = get_optimizer(network_params, cfg.solver)

    if cfg.resume_dir != 'None':
        print('resume training')
        checkpoint = torch.load(cfg.resume_dir)
        depth_network.load_state_dict(checkpoint['depth_network_state_dict'])
        try:
            depth_refine_network.load_state_dict(checkpoint['depth_refine_network_state_dict'])
        except Exception:
            print("no checkpoint for refineNet")
        # optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        global_step = checkpoint['global_step']
    else:
        global_step = 0
        start_epoch = 0

    # load nets into gpu
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        gpu_num = torch.cuda.device_count()
        depth_network = torch.nn.DataParallel(depth_network)
        depth_refine_network = torch.nn.DataParallel(depth_refine_network)
    else:
        gpu_num = 1
    device_normal = torch.device('cuda:' + str(gpu_num - 1))
    device_net = torch.device(device)
    depth_network.to(device_net)
    depth_refine_network.to(device_net)

    depth2normal = Depth2normal(cfg.k_size)
    depth2normal.to(device_normal)

    # data loader
    data_loader = load_dataset('train', cfg.dataset)

    depth_network.train()
    depth_refine_network.train()

    criterion_234 = IdepthLoss_234()
    criterion_1 = IdepthLoss()
    criterion_idepth_prob = IdepthwithProbLoss()
    criterion_prob = ProbLoss()

    # main loop
    for epoch in range(start_epoch + 1, cfg.num_epochs):
        batch_time = AverageMeter()
        tic = time.time()
        for iter, sample in enumerate(data_loader):
            # try:
            image = sample['rgbs'].to(device)
            batch_size, views, c, h, w = image.shape
            instance = sample['plane_instance_segs']
            instance = instance.to(device)
            # semantic = sample['semantic'].to(device)
            gt_depth = sample['depths'].to(device)
            gt_seg = sample['plane_segs'].to(device)  # [b, views, h, w]
            normals_from_plane_para = sample['normals_from_plane_para'].to(device)
            gt_normal = sample['normals'].to(device)
            gt_normal_valid = gt_depth > 0.1
            plane_nums = sample['plane_nums']  # [b, views]
            # valid_region = sample['valid_region'].to(device)
            # gt_plane_instance_parameter = sample['plane_instance_parameter'].to(device)
            gt_cam = sample['cameras'].to(device)
            gt_disparity = sample['disparities'].to(device)

            idepth_preds_01, iconv_01 = depth_network(image[:, 0, :, :, :],
                                                      image[:, 1, :, :, :],
                                                      gt_cam[:, 0, :, :, :],
                                                      gt_cam[:, 1, :, :, :])
            idepth_preds_02, iconv_02 = depth_network(image[:, 0, :, :, :],
                                                      image[:, 2, :, :, :],
                                                      gt_cam[:, 0, :, :, :],
                                                      gt_cam[:, 2, :, :, :])

            ######################
            # left-right refine idepth
            ######################
            idepth_refined, prob_map = depth_refine_network(idepth01=idepth_preds_01[0],
                                                            idepth02=idepth_preds_02[0],
                                                            iconv01=iconv_01,
                                                            iconv02=iconv_02)

            loss_idepth_1 = (criterion_1(idepth_preds_01[0], gt_disparity[:, 0, :, :, :]) +
                             criterion_1(idepth_preds_02[0], gt_disparity[:, 0, :, :, :])) * 0.5
            loss_idepth_refined = criterion_1(idepth_refined, gt_disparity[:, 0, :, :, :])
            loss_idepth_234 = (criterion_234(idepth_preds_01, gt_disparity[:, 0, :, :, :]) +
                               criterion_234(idepth_preds_02, gt_disparity[:, 0, :, :, :])) * 0.5

            depth_preds_01 = torch.div(1.0, idepth_preds_01[0].squeeze(1))
            depth_preds_02 = torch.div(1.0, idepth_preds_02[0].squeeze(1))
            depth_refined = torch.div(1.0, idepth_refined.squeeze(1) + 1e-5)

            ####################################
            # prob loss
            ####################################
            prob_loss_depth = (criterion_idepth_prob(idepth_refined, gt_disparity[:, 0, :, :, :], prob_map) +
                               criterion_idepth_prob(depth_refined.unsqueeze(1), gt_depth[:, 0, :, :, :], prob_map))
            prob_loss_minusmean = 1 - prob_map.mean()
            prob_map_loss, prob_map_gt = criterion_prob(prob_map, idepth_refined,
                                                        gt_disparity[:, 0, :, :, :])
            prob_loss = 5 * prob_loss_depth + prob_loss_minusmean  # + prob_map_loss

            intrinsic = gt_cam[:, 0, 1, 0:3, 0:3]
            intrinsic_inv = torch.inverse(intrinsic)
            normal_from_depth_01, _ = depth2normal(depth_preds_01.to(device_normal),
                                                   intrinsic_inv.to(device_normal))
            normal_from_depth_02, _ = depth2normal(depth_preds_02.to(device_normal),
                                                   intrinsic_inv.to(device_normal))
            normal_from_depth_refined, _ = depth2normal(depth_refined.to(device_normal),
                                                        intrinsic_inv.to(device_normal))
            normal_from_depth_01 = normal_from_depth_01.to('cuda:0')
            normal_from_depth_02 = normal_from_depth_02.to('cuda:0')
            normal_from_depth_refined = normal_from_depth_refined.to('cuda:0')

            normal_std = 0
            normal_by_planes = get_normal_by_planes(gt_normal[:, 0, :, :, :],
                                                    instance[:, 0, :, :, :],
                                                    plane_nums[:, 0])

            loss_depth_1 = (criterion_1(depth_preds_01.unsqueeze(1), gt_depth[:, 0, :, :, :]) +
                            criterion_1(depth_preds_02.unsqueeze(1), gt_depth[:, 0, :, :, :])) * 0.5
            loss_depth_refined = criterion_1(depth_refined.unsqueeze(1), gt_depth[:, 0, :, :, :])

            # calculate loss
            loss, loss_depth, loss_normal = 0., 0., 0.
            loss_normal_depth = 0
            loss_normal_depth_refined = 0
            for i in range(batch_size):
                if not cfg.use_normal_refined_by_planes:
                    _loss_normal_depth_01, mean_angle_depth_01 = surface_normal_loss(
                        normal_from_depth_01[i:i + 1],
                        gt_normal[i:i + 1, 0, :, :, :],
                        gt_normal_valid[i:i + 1, 0, :, :, :])
                    _loss_normal_depth_02, mean_angle_depth_02 = surface_normal_loss(
                        normal_from_depth_02[i:i + 1],
                        gt_normal[i:i + 1, 0, :, :, :],
                        gt_normal_valid[i:i + 1, 0, :, :, :])
                    _loss_normal_depth_refined, mean_angle_depth_refined = surface_normal_loss(
                        normal_from_depth_refined[i:i + 1],
                        gt_normal[i:i + 1, 0, :, :, :],
                        gt_normal_valid[i:i + 1, 0, :, :, :])
                    _loss_normal_depth = (_loss_normal_depth_01 + _loss_normal_depth_02) * 0.5
                    mean_angle_depth = (mean_angle_depth_01 + mean_angle_depth_02 +
                                        mean_angle_depth_refined) / 3.0
                else:
                    _loss_normal_depth_01, mean_angle_depth_01 = surface_normal_loss(
                        normal_from_depth_01[i:i + 1],
                        normal_by_planes[i:i + 1, :, :, :],
                        gt_normal_valid[i:i + 1, 0, :, :, :])
                    _loss_normal_depth_02, mean_angle_depth_02 = surface_normal_loss(
                        normal_from_depth_02[i:i + 1],
                        normal_by_planes[i:i + 1, :, :, :],
                        gt_normal_valid[i:i + 1, 0, :, :, :])
                    _loss_normal_depth_refined, mean_angle_depth_refined = surface_normal_loss(
                        normal_from_depth_refined[i:i + 1],
                        normal_by_planes[i:i + 1, :, :, :],
                        gt_normal_valid[i:i + 1, 0, :, :, :])
                    _loss_normal_depth = (_loss_normal_depth_01 + _loss_normal_depth_02) * 0.5
                    mean_angle_depth = (mean_angle_depth_01 + mean_angle_depth_02 +
                                        mean_angle_depth_refined) / 3.0

                # planar segmentation iou
                loss_normal_depth += _loss_normal_depth
                loss_normal_depth_refined += _loss_normal_depth_refined
                # loss_pw += _pw_loss

            loss_depth /= batch_size
            loss_normal /= batch_size
            loss_normal_depth /= batch_size
            loss_normal_depth_refined /= batch_size
            normal_std /= batch_size
            # loss_pw /= batch_size

            loss += loss_idepth_1
            loss += loss_idepth_234

            if not ((~torch.isnan(loss_normal_depth)) & (~torch.isnan(loss_normal_depth_refined))):
                print('normal loss is nan')
                print(sample['filenames'])
                loss_train = loss_idepth_1 + loss_depth_1 + loss_depth_refined + loss_idepth_refined
            else:
                loss_train = (loss_idepth_1 + loss_normal_depth + loss_depth_1 +
                              loss_depth_refined + loss_idepth_refined + loss_normal_depth_refined)
            loss_train += prob_loss

            ref_extrinsic = gt_cam[:, 0, 0, :, :]
            source1_extrinsic = gt_cam[:, 1, 0, :, :]
            pose1 = (source1_extrinsic @ torch.inverse(ref_extrinsic))[:, :3, :]
            warped_depth_loss_1 = get_warped_depth_loss(depth_refined, gt_depth[:, 1, 0, :, :],
                                                        pose1, intrinsic, intrinsic_inv)
            source2_extrinsic = gt_cam[:, 2, 0, :, :]
            pose2 = (source2_extrinsic @ torch.inverse(ref_extrinsic))[:, :3, :]
            warped_depth_loss_2 = get_warped_depth_loss(depth_refined, gt_depth[:, 2, 0, :, :],
                                                        pose2, intrinsic, intrinsic_inv)

            # if (epoch - start_epoch) < 2:
            #     loss_train = loss_idepth_refined
            # elif (epoch - start_epoch) < 4:
            #     loss_train = loss_idepth_refined + loss_depth_refined
            # elif (epoch - start_epoch) < 7:
            #     loss_train = loss_idepth_refined + loss_depth_refined + loss_normal_depth_refined
            # else:
            #     loss_train = loss_idepth_refined + loss_depth_refined + loss_normal_depth_refined + prob_map_loss
            loss_train += (warped_depth_loss_1 + warped_depth_loss_2)

            # Backward
            optimizer.zero_grad()
            loss_train.backward()
            # loss_idepth_refined.backward()
            optimizer.step()

            # update time
            batch_time.update(time.time() - tic)
            tic = time.time()

            if iter % cfg.print_interval == 0:
                _log.info(f"[{epoch:2d}][{iter:5d}/{len(data_loader):5d}] "
                          f"Time: {batch_time.avg:.2f} "
                          f"Loss: {loss_train.item():.4f} "
                          f"Depth {loss_depth_1.item():.4f} "
                          f"Depth_warp1 {warped_depth_loss_1.item():.4f} "
                          f"Depth_warp2 {warped_depth_loss_2.item():.4f} "
                          f"LN: {loss_normal_depth.item():.4f} "
                          f"IDepth: {loss_idepth_1.item():.4f} "
                          f"Depth_refined: {loss_depth_refined.item():.4f} "
                          f"LN_refined: {loss_normal_depth_refined.item():.4f} "
                          f"IDepth_refined: {loss_idepth_refined.item():.4f} "
                          f"prob_loss: {prob_loss.item():.4f} "
                          f"prob_loss_depth: {prob_loss_depth.item():.4f} "
                          f"prob_loss_minusmean: {prob_loss_minusmean.item():.4f} "
                          f"prob_map_loss: {prob_map_loss.item():.4f}")

                # ================================================================== #
                #                        Tensorboard Logging                         #
                # ================================================================== #
                with torch.no_grad():
                    # 1. Log scalar values (scalar summary)
                    info = {'loss': loss_train.item(),
                            'loss_idepth': loss_idepth_1.item(),
                            'loss_depth': loss_depth_1.item(),
                            'loss_normal_depth': loss_normal_depth.item(),
                            'loss_idepth_refined': loss_idepth_refined.item(),
                            'loss_depth_refined': loss_depth_refined.item(),
                            'loss_normal_depth_refined': loss_normal_depth_refined.item(),
                            'prob_loss': prob_loss.item(),
                            'prob_loss_depth': prob_loss_depth.item(),
                            'prob_loss_minusmean': prob_loss_minusmean.item(),
                            'prob_map_loss': prob_map_loss.item()}
                    for tag, value in info.items():
                        logger.scalar_summary(tag, value, global_step)

                    if iter % (cfg.print_interval * 10) == 0:
                        # 2. Log histograms (histogram summary)
                        info = {'prob_map_gt': prob_map_gt.cpu().numpy(),
                                'prob_map': prob_map.cpu().numpy(),
                                'diff': np.clip(torch.abs(depth_refined - gt_depth).cpu().numpy(),
                                                a_min=0.0, a_max=8.0)}
                        for tag, values in info.items():
                            logger.histo_summary(tag, values, global_step)

                        # 3. Log training images (image summary)
                        gt_segmentation = gt_seg[:, 0, :, :]
                        gt_segmentation += 1
                        gt_segmentation[gt_segmentation == 21] = 0
                        info = {'rgb': image[:, 0, :, :, :].permute(0, 2, 3, 1).cpu().numpy(),
                                'gt_normal': gt_normal[:, 0, :, :, :].permute(0, 2, 3, 1).cpu().numpy(),
                                'normal_by_planes': normal_by_planes.permute(0, 2, 3, 1).cpu().numpy(),
                                'plane_normal': normals_from_plane_para[:, 0, :, :, :].permute(0, 2, 3, 1).cpu().numpy(),
                                'normal_from_depth_01': normal_from_depth_01.permute(0, 2, 3, 1).cpu().numpy(),
                                'normal_from_depth_refined': normal_from_depth_refined.permute(0, 2, 3, 1).cpu().numpy(),
                                'gt_seg': np.stack([colors[gt_segmentation.cpu().numpy(), 0],
                                                    colors[gt_segmentation.cpu().numpy(), 1],
                                                    colors[gt_segmentation.cpu().numpy(), 2]], axis=3),
                                'gt_idepth': np2Depth(gt_disparity[:, 0, :, :, :].squeeze(1).cpu().numpy()),
                                'pred_idepth_01': np2Depth(idepth_preds_01[0].squeeze(1).cpu().numpy()),
                                'pred_idepth_refined': np2Depth(idepth_refined.squeeze(1).cpu().numpy()),
                                'prob_map_pred': colorize_probmap(prob_map.squeeze(1).cpu().numpy()),
                                'prob_map_gt': colorize_probmap(prob_map_gt.squeeze(1).cpu().numpy())}
                        for tag, images in info.items():
                            logger.image_summary(tag, images, global_step)

            # update global_step
            global_step = global_step + 1
            # except:
            #     print(sample['rgbs_filepath'])
            #     exit(1)

            if iter % (len(data_loader) // 8) == 0:
                # save checkpoint (handle both DataParallel-wrapped and plain modules,
                # so single-GPU runs do not fail on the missing `.module` attribute)
                depth_state = (depth_network.module.state_dict()
                               if isinstance(depth_network, torch.nn.DataParallel)
                               else depth_network.state_dict())
                refine_state = (depth_refine_network.module.state_dict()
                                if isinstance(depth_refine_network, torch.nn.DataParallel)
                                else depth_refine_network.state_dict())
                torch.save({
                    'epoch': epoch,
                    'global_step': global_step,
                    'depth_network_state_dict': depth_state,
                    'depth_refine_network_state_dict': refine_state,
                    'optimizer': optimizer.state_dict()},
                    os.path.join(checkpoint_dir,
                                 f"network_epoch_{epoch}_scale_{int(cfg.idepth_scale)}.pt"))
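# The checkpoint written at the end of train() uses the same dictionary keys
# that eval() and eval_refine_seven_views() read back when cfg.resume_dir
# points at it ('epoch', 'global_step', 'depth_network_state_dict',
# 'depth_refine_network_state_dict', 'optimizer'). A minimal sketch of
# inspecting such a checkpoint on the CPU; `_describe_checkpoint` is a
# hypothetical helper and the path argument is a placeholder.
def _describe_checkpoint(path):
    checkpoint = torch.load(path, map_location='cpu')
    for key in ('epoch', 'global_step'):
        print(key, checkpoint.get(key))
    for key in ('depth_network_state_dict', 'depth_refine_network_state_dict'):
        state = checkpoint.get(key, {})
        print(key, len(state), 'tensors')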