def train(odometry_net, depth_net, train_loader, epoch, optimizer):
    """Jointly train the odometry and depth networks for one epoch.

    For each batch, warps source images (left stereo frame and previous
    right frame) into the target right frame using predicted inverse depth
    and relative poses, and minimizes photometric + smoothness losses.
    Prints epoch-averaged losses at the end.
    """
    global device
    # NOTE(review): presumably configures the (quantized) odometry net's
    # fixed-point mode — confirm against the `nfp` library's docs.
    odometry_net.set_fix_method(nfp.FIX_AUTO)
    odometry_net.train()
    depth_net.train()
    # Running sums of per-batch loss terms for epoch-level reporting.
    total_loss = 0
    lr_total = 0
    r12_total = 0
    smooth_total = 0
    for batch_idx, (img_R1, img_L2, img_R2, intrinsics, inv_intrinsics,
                    raw_K, T_R2L) in tqdm(enumerate(train_loader),
                                          desc='Train epoch %d' % epoch,
                                          leave=False,
                                          ncols=80):
        img_R1 = img_R1.type(torch.FloatTensor).to(device)
        img_R2 = img_R2.type(torch.FloatTensor).to(device)
        img_L2 = img_L2.type(torch.FloatTensor).to(device)
        intrinsics = intrinsics.type(torch.FloatTensor).to(device)
        inv_intrinsics = inv_intrinsics.type(torch.FloatTensor).to(device)
        raw_K = raw_K.type(torch.FloatTensor).to(device)
        T_R2L = T_R2L.type(torch.FloatTensor).to(device)
        batch_size = img_R1.size(0)
        # Stack the two right-camera frames along the channel dimension as
        # the odometry net's input; duplicate K along the batch dimension so
        # it matches the doubled batch built below.
        img_R = torch.cat((img_R2, img_R1), dim=1)
        K = torch.cat((raw_K, raw_K), dim=0)
        # 0.004 ≈ 1/255: rescale 8-bit images to roughly [0, 1] for the
        # photometric comparison.
        norm_img_L2 = 0.004 * img_L2
        norm_img_R1 = 0.004 * img_R1
        norm_img_R2 = 0.004 * img_R2
        inv_depth_img_R2 = depth_net(img_R2)
        T_2to1, _ = odometry_net(img_R)
        # Batch the two warps together: stereo (R2->L2, known extrinsics)
        # and temporal (R2->R1, predicted pose).
        T = torch.cat((T_R2L, T_2to1), dim=0)
        SE3 = generate_se3(T)
        inv_depth = torch.cat((inv_depth_img_R2, inv_depth_img_R2), dim=0)
        # 1e-4 guards against division by zero for near-zero inverse depth.
        depth = (1 / (inv_depth + 1e-4))
        pts3D = geo_transform(depth, SE3, K)
        proj_coords = pin_hole_project(pts3D, K)
        Isrc = torch.cat((norm_img_L2, norm_img_R1), dim=0)
        warp_Itgt = inverse_warp(Isrc, proj_coords)
        # Split the doubled batch back into the stereo and temporal warps.
        warp_Itgt_LR = warp_Itgt[:batch_size, :, :, :]
        warp_Itgt_R12 = warp_Itgt[batch_size:, :, :, :]
        # Mask pixels whose warped value is exactly zero on all channels
        # (out of the source image's bounds).
        out_of_bound = 1 - (warp_Itgt_LR == 0).prod(
            1, keepdim=True).type_as(warp_Itgt_LR)
        diff_LR = (norm_img_R2 - warp_Itgt_LR) * out_of_bound
        LR_error = diff_LR.abs().mean()
        out_of_bound = 1 - (warp_Itgt_R12 == 0).prod(
            1, keepdim=True).type_as(warp_Itgt_R12)
        diff_R12 = (norm_img_R2 - warp_Itgt_R12) * out_of_bound
        R12_error = diff_R12.abs().mean()
        smooth_error = smooth_loss(depth)
        # Fixed weighting: photometric terms at 1, smoothness at 10.
        loss = LR_error + R12_error + 10 * smooth_error
        total_loss += loss.item()
        lr_total += LR_error.item()
        r12_total += R12_error.item()
        smooth_total += smooth_error.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(
        "Train epoch {}: loss: {:.6f} LR-loss: {:.6f} R12-loss: {:.6f} smooth-loss: {:.6f}"
        .format(epoch, total_loss / len(train_loader),
                lr_total / len(train_loader), r12_total / len(train_loader),
                smooth_total / len(train_loader)))
def train(train_loader, mask_net, pose_net, optimizer, epoch_size,
          train_writer):
    """Train mask and pose networks for one epoch.

    Optimizes a weighted sum of mask smoothness, explainability,
    mask/residual consensus and pose (photometric) losses, and logs
    scalars/images to tensorboard.

    Returns the average total loss over the batches seen.

    BUG FIX: the average was computed as ``average_loss / i`` where ``i``
    is the last 0-based batch index — a ZeroDivisionError for a
    single-batch loader and an over-estimate otherwise. Divide by the
    batch count ``i + 1`` instead.
    """
    global args, n_iter
    w1 = args.smooth_loss_weight
    w2 = args.mask_loss_weight
    w3 = args.consensus_loss_weight
    w4 = args.pose_loss_weight

    mask_net.train()
    pose_net.train()
    average_loss = 0
    for i, (rgb_tgt_img, rgb_ref_imgs, depth_tgt_img, depth_ref_imgs,
            mask_tgt_img, mask_ref_imgs, intrinsics, intrinsics_inv,
            pose_list) in enumerate(tqdm(train_loader)):

        rgb_tgt_img_var = Variable(rgb_tgt_img.cuda())
        rgb_ref_imgs_var = [Variable(img.cuda()) for img in rgb_ref_imgs]
        # Depth maps arrive without a channel dimension; add one.
        depth_tgt_img_var = Variable(depth_tgt_img.unsqueeze(1).cuda())
        depth_ref_imgs_var = [
            Variable(img.unsqueeze(1).cuda()) for img in depth_ref_imgs
        ]
        mask_tgt_img_var = Variable(mask_tgt_img.cuda())
        mask_ref_imgs_var = [Variable(img.cuda()) for img in mask_ref_imgs]

        # Binarize the ground-truth masks: any positive value -> 1.
        mask_tgt_img_var = torch.where(mask_tgt_img_var > 0,
                                       torch.ones_like(mask_tgt_img_var),
                                       torch.zeros_like(mask_tgt_img_var))
        mask_ref_imgs_var = [
            torch.where(img > 0, torch.ones_like(img), torch.zeros_like(img))
            for img in mask_ref_imgs_var
        ]

        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())
        # pose_list_var = [Variable(one_pose.float().cuda()) for one_pose in pose_list]

        explainability_mask = mask_net(rgb_tgt_img_var, rgb_ref_imgs_var)
        # print(explainability_mask[0].size()) #torch.Size([4, 2, 384, 512])

        pose = pose_net(rgb_tgt_img_var, rgb_ref_imgs_var)

        # loss 1: smoothness loss
        loss1 = smooth_loss(explainability_mask)

        # loss 2: explainability loss
        loss2 = explainability_loss(explainability_mask)

        # loss 3: consensus loss (the mask from networks vs the mask from residual)
        loss3 = consensus_loss(explainability_mask[0], mask_ref_imgs_var)

        # loss 4: pose (photometric) loss; pixels with zero reference depth
        # are invalid and masked out.
        valid_pixle_mask = [
            torch.where(depth_ref_imgs_var[0] == 0,
                        torch.zeros_like(depth_tgt_img_var),
                        torch.ones_like(depth_tgt_img_var)),
            torch.where(depth_ref_imgs_var[1] == 0,
                        torch.zeros_like(depth_tgt_img_var),
                        torch.ones_like(depth_tgt_img_var))
        ]  # zero is invalid

        loss4, ref_img_warped, diff = pose_loss(
            valid_pixle_mask, mask_ref_imgs_var, rgb_tgt_img_var,
            rgb_ref_imgs_var, intrinsics_var, intrinsics_inv_var,
            depth_tgt_img_var, pose)

        # compute gradient and do Adam step
        loss = w1 * loss1 + w2 * loss2 + w3 * loss3 + w4 * loss4
        average_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # visualization in tensorboard
        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('smoothness loss', loss1.item(), n_iter)
            train_writer.add_scalar('explainability loss', loss2.item(),
                                    n_iter)
            train_writer.add_scalar('consensus loss', loss3.item(), n_iter)
            train_writer.add_scalar('pose loss', loss4.item(), n_iter)
            train_writer.add_scalar('total loss', loss.item(), n_iter)

        if n_iter % (args.training_output_freq) == 0:
            train_writer.add_image('train Input',
                                   tensor2array(rgb_tgt_img_var[0]), n_iter)
            train_writer.add_image(
                'train Exp mask Outputs ',
                tensor2array(explainability_mask[0][0, 0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train gt mask ',
                tensor2array(mask_tgt_img[0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train depth ',
                tensor2array(depth_tgt_img[0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train after mask',
                tensor2array(rgb_tgt_img_var[0] *
                             explainability_mask[0][0, 0]), n_iter)
            train_writer.add_image('train diff', tensor2array(diff[0]),
                                   n_iter)
            train_writer.add_image('train warped img',
                                   tensor2array(ref_img_warped[0]), n_iter)

        n_iter += 1

    # i is the last 0-based index, so the number of batches is i + 1.
    return average_loss / (i + 1)
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        logger,
                        output_writers=None):
    """Run one validation pass without ground truth.

    Computes photometric / explainability / smoothness losses on the
    validation set, optionally logging sample images and pose/disparity
    histograms to the tensorboard ``output_writers``.

    Returns ``(losses.avg, loss_names)``.

    BUG FIX: the pose-histogram loop called ``output_writers.add_histogram``
    on the *list* itself (AttributeError whenever logging is enabled); it now
    writes to ``output_writers[0]`` like the ``disp_values`` histogram below.
    Also replaces the mutable default argument ``output_writers=[]`` with a
    ``None`` sentinel (same behavior for callers).
    """
    # Mutable-default-argument fix; [] keeps log_outputs False as before.
    output_writers = [] if output_writers is None else output_writers
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    # Pre-allocated buffers for histogram logging; the last (possibly
    # partial) batch is excluded, hence len(val_loader) - 1.
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        # Legacy (pre-0.4) PyTorch API: volatile=True disables autograd.
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        # compute output
        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img_var, ref_imgs_var)

        loss_1 = photometric_reconstruction_loss(tgt_img_var, ref_imgs_var,
                                                 intrinsics_var,
                                                 intrinsics_inv_var, depth,
                                                 explainability_mask, pose,
                                                 args.rotation_mode,
                                                 args.padding_mode)
        loss_1 = loss_1.data[0]
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).data[0]
        else:
            loss_2 = 0
        loss_3 = smooth_loss(disp).data[0]

        if log_outputs and i % 100 == 0 and i / 100 < len(
                output_writers):  # log first output of every 100 batch
            index = int(i // 100)
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(tgt_img[0]),
                                                    0)
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(ref[0]), 1)

            output_writers[index].add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), epoch)
            output_writers[index].add_image(
                'val Depth Output',
                tensor2array(1. / disp.data[0].cpu(), max_value=10), epoch)
            # log warped images along with explainability mask
            for j, ref in enumerate(ref_imgs_var):
                ref_warped = inverse_warp(ref[:1], depth[:1, 0], pose[:1, j],
                                          intrinsics_var[:1],
                                          intrinsics_inv_var[:1],
                                          rotation_mode=args.rotation_mode,
                                          padding_mode=args.padding_mode)[0]
                output_writers[index].add_image(
                    'val Warped Outputs {}'.format(j),
                    tensor2array(ref_warped.data.cpu()), epoch)
                output_writers[index].add_image(
                    'val Diff Outputs {}'.format(j),
                    tensor2array(
                        0.5 * (tgt_img_var[0] - ref_warped).abs().data.cpu()),
                    epoch)
                if explainability_mask is not None:
                    output_writers[index].add_image(
                        'val Exp mask Outputs {}'.format(j),
                        tensor2array(explainability_mask[0, j].data.cpu(),
                                     max_value=1,
                                     colormap='bone'), epoch)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.data.cpu().view(-1,
                                                                  6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.data.cpu().view(args.batch_size, -1)
            # Per-sample min / median / max of the disparity map.
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            # Fixed: was `output_writers.add_histogram(...)` on the list.
            output_writers[0].add_histogram(
                '{} {}'.format(prefix, coeffs_names[i]), poses[:, i], epoch)
        output_writers[0].add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, ['Total loss', 'Photo loss', 'Exp loss']
def train(train_loader, disp_net, pose_exp_net, optimizer, epoch_size, logger,
          train_writer):
    """Train disparity and pose/explainability networks for one epoch.

    Minimizes w1*photometric + w2*explainability + w3*smoothness over at
    most ``epoch_size`` batches, logging scalars (every batch), images
    (every 200 iterations when ``args.log_output``) and a tab-separated
    CSV row per batch. Returns the running average of the total loss.

    NOTE(review): uses the legacy pre-0.4 PyTorch API (`Variable`,
    `loss.data[0]`) — assumes an old torch version; confirm.
    """
    global args, n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img_var = Variable(tgt_img.cuda())
        ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())

        # compute output: multi-scale disparities, inverted to depths
        disparities = disp_net(tgt_img_var)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img_var, ref_imgs_var)

        loss_1 = photometric_reconstruction_loss(tgt_img_var, ref_imgs_var,
                                                 intrinsics_var,
                                                 intrinsics_inv_var, depth,
                                                 explainability_mask, pose)
        loss_2 = explainability_loss(explainability_mask)
        loss_3 = smooth_loss(disparities)
        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        train_writer.add_scalar('photometric_error', loss_1.data[0], n_iter)
        train_writer.add_scalar('explanability_loss', loss_2.data[0], n_iter)
        train_writer.add_scalar('disparity_smoothness_loss', loss_3.data[0],
                                n_iter)
        train_writer.add_scalar('total_loss', loss.data[0], n_iter)

        if n_iter % 200 == 0 and args.log_output:
            # The three inputs share one tag; n_iter-1/n_iter/n_iter+1 make
            # them scrubbable as consecutive steps in tensorboard.
            train_writer.add_image('train Input',
                                   tensor2array(ref_imgs[0][0]), n_iter - 1)
            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)
            train_writer.add_image('train Input',
                                   tensor2array(ref_imgs[1][0]), n_iter + 1)

            for k, scaled_depth in enumerate(depth):
                train_writer.add_image(
                    'train Dispnet Output {}'.format(k),
                    tensor2array(disparities[k].data[0].cpu(),
                                 max_value=10,
                                 colormap='bone'), n_iter)
                train_writer.add_image(
                    'train Depth Output Normalized {}'.format(k),
                    tensor2array(1 / disparities[k].data[0].cpu(),
                                 max_value=None), n_iter)
                b, _, h, w = scaled_depth.size()
                downscale = tgt_img_var.size(2) / h

                # Downsample images and rescale intrinsics to this pyramid
                # level so the warp stays geometrically consistent.
                tgt_img_scaled = nn.functional.adaptive_avg_pool2d(
                    tgt_img_var, (h, w))
                ref_imgs_scaled = [
                    nn.functional.adaptive_avg_pool2d(ref_img, (h, w))
                    for ref_img in ref_imgs_var
                ]
                intrinsics_scaled = torch.cat(
                    (intrinsics_var[:, 0:2] / downscale,
                     intrinsics_var[:, 2:]),
                    dim=1)
                intrinsics_scaled_inv = torch.cat(
                    (intrinsics_inv_var[:, :, 0:2] * downscale,
                     intrinsics_inv_var[:, :, 2:]),
                    dim=2)

                # log warped images along with explainability mask
                for j, ref in enumerate(ref_imgs_scaled):
                    ref_warped = inverse_warp(ref, scaled_depth[:, 0],
                                              pose[:, j], intrinsics_scaled,
                                              intrinsics_scaled_inv)[0]
                    train_writer.add_image(
                        'train Warped Outputs {} {}'.format(k, j),
                        tensor2array(ref_warped.data.cpu(), max_value=1),
                        n_iter)
                    train_writer.add_image(
                        'train Diff Outputs {} {}'.format(k, j),
                        tensor2array(
                            0.5 *
                            (tgt_img_scaled[0] - ref_warped).abs().data.cpu()),
                        n_iter)
                    train_writer.add_image(
                        'train Exp mask Outputs {} {}'.format(k, j),
                        tensor2array(explainability_mask[k][0, j].data.cpu(),
                                     max_value=1,
                                     colormap='bone'), n_iter)

        # record loss and EPE
        losses.update(loss.data[0], args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow(
                [loss.data[0], loss_1.data[0], loss_2.data[0],
                 loss_3.data[0]])
        logger.train_bar.update(i)
        if i % args.print_freq == 0:
            logger.train_writer.write(
                'Train: Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                'Loss {loss.val:.4f} ({loss.avg:.4f}) '.format(
                    batch_time=batch_time, data_time=data_time, loss=losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg
def train(train_loader, mask_net, pose_net, flow_net, optimizer, epoch_size,
          train_writer):
    """Train mask, pose and flow networks for one epoch.

    Compares pose-induced rigid flow against the flow net's output to
    form rigidity masks, and optimizes a weighted sum of smoothness,
    explainability, consensus and flow (photometric) losses.

    Returns the average total loss over the batches seen.

    BUG FIX: the average was computed as ``average_loss / i`` where ``i``
    is the last 0-based batch index — a ZeroDivisionError for a
    single-batch loader and an over-estimate otherwise. Divide by the
    batch count ``i + 1`` instead.
    """
    global args, n_iter
    w1 = args.smooth_loss_weight
    w2 = args.mask_loss_weight
    w3 = args.consensus_loss_weight
    w4 = args.flow_loss_weight

    mask_net.train()
    pose_net.train()
    flow_net.train()
    average_loss = 0
    for i, (rgb_tgt_img, rgb_ref_imgs, depth_tgt_img, depth_ref_imgs,
            intrinsics, intrinsics_inv,
            pose_list) in enumerate(tqdm(train_loader)):

        rgb_tgt_img_var = Variable(rgb_tgt_img.cuda())
        # print(rgb_tgt_img_var.size())
        rgb_ref_imgs_var = [Variable(img.cuda()) for img in rgb_ref_imgs]
        # rgb_ref_imgs_var = [rgb_ref_imgs_var[0], rgb_ref_imgs_var[0], rgb_ref_imgs_var[1], rgb_ref_imgs_var[1]]
        # Depth maps arrive without a channel dimension; add one.
        depth_tgt_img_var = Variable(depth_tgt_img.unsqueeze(1).cuda())
        depth_ref_imgs_var = [
            Variable(img.unsqueeze(1).cuda()) for img in depth_ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())
        # pose_list_var = [Variable(one_pose.float().cuda()) for one_pose in pose_list]

        explainability_mask = mask_net(rgb_tgt_img_var, rgb_ref_imgs_var)
        valid_pixle_mask = torch.where(
            depth_tgt_img_var == 0, torch.zeros_like(depth_tgt_img_var),
            torch.ones_like(depth_tgt_img_var))  # zero is invalid

        pose = pose_net(rgb_tgt_img_var, rgb_ref_imgs_var)

        # generate flow from camera pose and depth
        flow_fwd, flow_bwd, _ = flow_net(rgb_tgt_img_var, rgb_ref_imgs_var)
        flows_cam_fwd = pose2flow(depth_ref_imgs_var[1].squeeze(1),
                                  pose[:, 1], intrinsics_var,
                                  intrinsics_inv_var)
        flows_cam_bwd = pose2flow(depth_ref_imgs_var[0].squeeze(1),
                                  pose[:, 0], intrinsics_var,
                                  intrinsics_inv_var)
        # Large rigid-vs-estimated flow discrepancy suggests non-rigid
        # (moving-object) pixels.
        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd[0]).abs()
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd[0]).abs()

        # loss 1: smoothness loss
        loss1 = smooth_loss(explainability_mask) + smooth_loss(
            flow_bwd) + smooth_loss(flow_fwd)

        # loss 2: explainability loss
        loss2 = explainability_loss(explainability_mask)

        # loss 3: consensus loss (the mask from networks vs the mask from residual)
        depth_Res_mask, depth_ref_img_warped, depth_diff = depth_residual_mask(
            valid_pixle_mask, explainability_mask[0], rgb_tgt_img_var,
            rgb_ref_imgs_var, intrinsics_var, intrinsics_inv_var,
            depth_tgt_img_var, pose)
        # print(depth_Res_mask[0].size(), explainability_mask[0].size())
        loss3 = consensus_loss(explainability_mask[0], rigidity_mask_bwd,
                               rigidity_mask_fwd, args.THRESH, args.wbce)

        # loss 4: flow loss
        loss4, flow_ref_img_warped, flow_diff = flow_loss(
            rgb_tgt_img_var, rgb_ref_imgs_var, [flow_bwd, flow_fwd],
            explainability_mask)

        # compute gradient and do Adam step
        loss = w1 * loss1 + w2 * loss2 + w3 * loss3 + w4 * loss4
        average_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # visualization in tensorboard
        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('smoothness loss', loss1.item(), n_iter)
            train_writer.add_scalar('explainability loss', loss2.item(),
                                    n_iter)
            train_writer.add_scalar('consensus loss', loss3.item(), n_iter)
            train_writer.add_scalar('flow loss', loss4.item(), n_iter)
            train_writer.add_scalar('total loss', loss.item(), n_iter)

        if n_iter % (args.training_output_freq) == 0:
            train_writer.add_image('train Input',
                                   tensor2array(rgb_tgt_img_var[0]), n_iter)
            train_writer.add_image(
                'train Exp mask Outputs ',
                tensor2array(explainability_mask[0][0, 0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train depth Res mask ',
                tensor2array(depth_Res_mask[0][0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train depth ',
                tensor2array(depth_tgt_img_var[0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train valid pixel ',
                tensor2array(valid_pixle_mask[0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train after mask',
                tensor2array(rgb_tgt_img_var[0] *
                             explainability_mask[0][0, 0]), n_iter)
            train_writer.add_image('train depth diff',
                                   tensor2array(depth_diff[0]), n_iter)
            train_writer.add_image('train flow diff',
                                   tensor2array(flow_diff[0]), n_iter)
            train_writer.add_image('train depth warped img',
                                   tensor2array(depth_ref_img_warped[0]),
                                   n_iter)
            train_writer.add_image('train flow warped img',
                                   tensor2array(flow_ref_img_warped[0]),
                                   n_iter)
            train_writer.add_image(
                'train Cam Flow Output',
                flow_to_image(tensor2array(flow_fwd[0].data[0].cpu())),
                n_iter)
            train_writer.add_image(
                'train Flow from Depth Output',
                flow_to_image(tensor2array(flows_cam_fwd.data[0].cpu())),
                n_iter)
            train_writer.add_image(
                'train Flow and Depth diff',
                flow_to_image(tensor2array(rigidity_mask_fwd.data[0].cpu())),
                n_iter)

        n_iter += 1

    # i is the last 0-based index, so the number of batches is i + 1.
    return average_loss / (i + 1)
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, tb_writer):
    """Train disparity and pose/explainability networks for one epoch.

    Minimizes w1*photometric + w2*explainability + w3*smoothness over at
    most ``epoch_size`` batches. Logs scalars every ``args.print_freq``
    iterations, images every ``args.training_output_freq`` iterations, and
    appends a tab-separated CSV row per batch. Returns the average total
    loss (``losses.avg[0]``).
    """
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        # Decide up-front whether this iteration logs scalars / images.
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output: multi-scale disparities, inverted to depths
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            # Skip the computation entirely when its weight is zero.
            loss_2 = 0
        loss_3 = smooth_loss(depth)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if log_losses:
            tb_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                tb_writer.add_scalar('explanability_loss', loss_2.item(),
                                     n_iter)
            tb_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                 n_iter)
            tb_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            tb_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                n_iter)
            # One image group per pyramid scale k.
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped, diff,
                        explainability_mask)):
                log_output_tensorboard(tb_writer, "train", 0, " {}".format(k),
                                       n_iter, *scaled_maps)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        logger,
                        tb_writer,
                        sample_nb_to_log=3):
    """Run one validation pass without ground truth.

    Computes photometric / explainability / smoothness losses, optionally
    adds a pose-consistency (photocon) term, logs the first
    ``sample_nb_to_log - 1`` batches plus pose/disparity histograms to
    tensorboard, and returns ``(losses.avg, loss_names)``.

    NOTE(review): no torch.no_grad() wrapper is visible here — gradients
    may be tracked during validation; confirm against the caller.
    """
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = sample_nb_to_log > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    # Histogram buffers; the last (possibly partial) batch is excluded.
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # compute output
        disp = disp_net(tgt_img)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        if log_outputs and i < sample_nb_to_log - 1:  # log first output of first batches
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    tb_writer.add_image('val Input {}/{}'.format(j, i),
                                        tensor2array(tgt_img[0]), 0)
                    tb_writer.add_image('val Input {}/{}'.format(j, i),
                                        tensor2array(ref[0]), 1)

            log_output_tensorboard(tb_writer, 'val', i, '', epoch, 1. / disp,
                                   disp, warped[0], diff[0],
                                   explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            # Per-sample min / median / max of the disparity map.
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        if args.with_photocon_loss:
            # Chain poses: T13 = T23 · inv(T21) warps ref_imgs[1] into
            # ref_imgs[0]'s frame, penalizing pose inconsistency.
            batch_size = pose.size()[0]
            # Homogeneous bottom row [0, 0, 0, 1] to lift 3x4 poses to 4x4.
            homo_row = torch.tensor([[0, 0, 0, 1]],
                                    dtype=torch.float).to(device)
            homo_row = homo_row.unsqueeze(0).expand(batch_size, -1, -1)

            T21 = pose_vec2mat(pose[:, 0])
            T21 = torch.cat((T21, homo_row), 1)
            T12 = torch.inverse(T21)
            T23 = pose_vec2mat(pose[:, 1])
            T23 = torch.cat((T23, homo_row), 1)
            T13 = torch.matmul(T23, T12)  #[B,4,4]
            # print("----",T13.size())
            # target = 1(ref_imgs[0]) and ref = 3(ref_imgs[1])
            ref_img_warped, valid_points = inverse_warp_posemat(
                ref_imgs[1], depth[:, 0], T13, intrinsics,
                args.rotation_mode, args.padding_mode)
            diff = (ref_imgs[0] -
                    ref_img_warped) * valid_points.unsqueeze(1).float()
            loss_4 = diff.abs().mean()
            loss += loss_4

        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))

    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            tb_writer.add_histogram('{} {}'.format(prefix, coeffs_names[i]),
                                    poses[:, i], epoch)
        tb_writer.add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, [
        'Validation Total loss', 'Validation Photo loss',
        'Validation Exp loss'
    ]
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        tb_writer,
                        sample_nb_to_log=3):
    """Run one validation pass without ground truth (tqdm progress-bar variant).

    Same losses as the logger-based variant, but progress is reported
    through a local tqdm bar instead of a Logger object. Returns
    ``(losses.avg, loss_names)``.
    """
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = sample_nb_to_log > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    # Histogram buffers; the last (possibly partial) batch is excluded.
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    validate_pbar = tqdm(
        total=len(val_loader),
        bar_format='{desc} {percentage:3.0f}%|{bar}| {postfix}')
    # Placeholder description/postfix shown until the first real update.
    validate_pbar.set_description(
        'valid: Loss *.**** *.*****.****(*.**** *.**** *.****)')
    validate_pbar.set_postfix_str('<Time *.***(*.***)>')

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # compute output
        disp = disp_net(tgt_img)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        if log_outputs and i < sample_nb_to_log - 1:  # log first output of first batches
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    tb_writer.add_image('val Input {}/{}'.format(j, i),
                                        tensor2array(tgt_img[0]), 0)
                    tb_writer.add_image('val Input {}/{}'.format(j, i),
                                        tensor2array(ref[0]), 1)

            # NOTE(review): this variant passes `warped`/`diff` while the
            # sibling passes `warped[0]`/`diff[0]` — confirm which shape
            # log_output_tensorboard expects.
            log_output_tensorboard(tb_writer, 'val', i, '', epoch, 1. / disp,
                                   disp, warped, diff, explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            # Per-sample min / median / max of the disparity map.
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        validate_pbar.clear()
        validate_pbar.update(1)
        validate_pbar.set_description('valid: Loss {}'.format(losses))
        validate_pbar.set_postfix_str('<Time {}>'.format(batch_time))

    validate_pbar.close()
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            tb_writer.add_histogram('{} {}'.format(prefix, coeffs_names[i]),
                                    poses[:, i], epoch)
        tb_writer.add_histogram('disp_values', disp_values, epoch)
        # Let the progress bar finish rendering before returning.
        time.sleep(0.2)
    else:
        time.sleep(1)
    return losses.avg, ['Total loss', 'Photo loss', 'Exp loss']
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        logger,
                        tb_writer,
                        sample_nb_to_log=3):
    """Run one validation pass without ground truth (flow/prior variant).

    In addition to the photometric / explainability / smoothness terms,
    optionally adds a flow-consistency loss (against precomputed
    ``flow_maps``) and a ground-prior loss on the disparity. Sample
    batches to log are spread evenly across the whole dataset.
    Returns ``(losses.avg, loss_names)``.
    """
    global device
    mse_l = torch.nn.MSELoss(reduction='mean')
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = sample_nb_to_log > 0
    # Output the logs throughout the whole dataset
    batches_to_log = list(
        np.linspace(0, len(val_loader), sample_nb_to_log).astype(int))
    w1, w2, w3, wf, wp = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.flow_loss_weight, args.prior_loss_weight
    # Histogram buffers; the last (possibly partial) batch is excluded.
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv,
            flow_maps) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)
        flow_maps = [flow_map.to(device) for flow_map in flow_maps]

        # compute output
        disp = disp_net(tgt_img)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        # This variant's reconstruction loss also returns the sampling grid,
        # which feeds the flow-consistency term below.
        loss_1, warped, diff, grid = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()
        if wf > 0:
            loss_f = flow_consistency_loss(grid, flow_maps, mse_l)
        else:
            loss_f = 0
        if wp > 0:
            loss_p = ground_prior_loss(disp)
        else:
            loss_p = 0
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        if log_outputs and i in batches_to_log:  # log first output of wanted batches
            index = batches_to_log.index(i)
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    tb_writer.add_image('val Input {}/{}'.format(j, index),
                                        tensor2array(tgt_img[0]), 0)
                    tb_writer.add_image('val Input {}/{}'.format(j, index),
                                        tensor2array(ref[0]), 1)

            log_output_tensorboard(tb_writer, 'val', index, '', epoch,
                                   1. / disp, disp, warped[0], diff[0],
                                   explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            # Per-sample min / median / max of the disparity map.
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + wf * loss_f + wp * loss_p
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))

    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            tb_writer.add_histogram('{} {}'.format(prefix, coeffs_names[i]),
                                    poses[:, i], epoch)
        tb_writer.add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, [
        'Validation Total loss', 'Validation Photo loss',
        'Validation Exp loss'
    ]
def validate_without_gt(val_loader, disp_net, pose_net, mask_net, flow_net, epoch, logger, tb_writer, nb_writers, global_vars_dict=None):
    """Validate the collaborative depth/pose/mask/flow networks without ground truth.

    Runs every network in eval mode over ``val_loader``, combines five loss
    terms (camera photometric E_R, mask E_M, smoothness E_S, flow photometric
    E_F, consensus E_C) with the weights from ``args``, and logs images and
    scalars to ``tb_writer``. Running state (``device``, ``n_iter_val``,
    ``args``) is passed through ``global_vars_dict`` and the updated
    ``n_iter_val`` is written back before returning.

    Returns:
        float: average total validation loss over the epoch.
    """
    # data prepared
    device = global_vars_dict['device']
    n_iter_val = global_vars_dict['n_iter_val']
    args = global_vars_dict['args']

    # Convert the fractional sample positions in args.show_samples into
    # integer batch indices (// 1 floors the float product).
    show_samples = copy.deepcopy(args.show_samples)
    for i in range(len(show_samples)):
        show_samples[i] *= len(val_loader)
        show_samples[i] = show_samples[i] // 1

    batch_time = AverageMeter()
    data_time = AverageMeter()
    log_outputs = nb_writers > 0  # NOTE(review): assigned but never used below
    losses = AverageMeter(precision=4)
    w1, w2, w3, w4 = args.cam_photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.flow_photo_loss_weight
    w5 = args.consensus_loss_weight
    loss_camera = photometric_reconstruction_loss
    loss_flow = photometric_flow_loss

    # to eval model
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    end = time.time()
    poses = np.zeros(
        ((len(val_loader) - 1) * 1 *
         (args.sequence_length - 1), 6))  # init; NOTE(review): never filled

    # 3. validation cycle
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics, intrinsics_inv = intrinsics.to(device), intrinsics_inv.to(
            device)

        # 3.1 forward pass
        # disp
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)
        depth = 1 / disp
        # pose
        pose = pose_net(tgt_img, ref_imgs)  # [b,3,h,w]; list

        # flow: full flow between the previous and next frame, depending on
        # which flow backbone was selected.
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])

        # Rigid (camera-motion induced) flow from depth + pose.
        # NOTE(review): flow_cam and flows_cam_fwd are computed from the
        # exact same inputs (pose[:, 2]) — flows_cam_fwd duplicates flow_cam.
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics,
                             intrinsics_inv)
        flows_cam_fwd = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics,
                                  intrinsics_inv)
        flows_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics,
                                  intrinsics_inv)

        exp_masks_target = consensus_exp_masks(flows_cam_fwd, flows_cam_bwd,
                                               flow_fwd, flow_bwd, tgt_img,
                                               ref_imgs[2], ref_imgs[1],
                                               wssim=args.wssim,
                                               wrig=args.wrig,
                                               ws=args.smooth_loss_weight)
        # Residual between full flow and rigid flow: moving objects.
        no_rigid_flow = flow_fwd - flows_cam_fwd

        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd).abs()  # [b,2,h,w]
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd).abs()

        # mask
        # 4. explainability_mask: per-reference-frame validity weights.
        # list(5) of tensors [4,4,128,512] ... [4,4,4,16], values ~0.33-0.63
        explainability_mask = mask_net(tgt_img, ref_imgs)

        # NOTE(review): both branches assign the same value — the joint-mask
        # computation is disabled (original call kept below for reference):
        # explainability_mask_for_depth = compute_joint_mask_for_depth(
        #     explainability_mask, rigidity_mask_bwd, rigidity_mask_fwd,
        #     THRESH=args.THRESH)
        if args.joint_mask_for_depth:  # false
            explainability_mask_for_depth = explainability_mask
        else:
            explainability_mask_for_depth = explainability_mask

        if args.no_non_rigid_mask:
            flow_exp_mask = None
            if args.DEBUG:
                print('Using no masks for flow')
        else:
            # Mask channels 1:3 correspond to the two adjacent frames used by
            # the flow loss; 1 - mask selects the non-rigid regions.
            flow_exp_mask = 1 - explainability_mask[:, 1:3]

        # 3.2 loss computation (each term gated by its weight; disabled terms
        # become zero tensors so .item() below is always valid).
        if w1 > 0:
            loss_1 = loss_camera(tgt_img, ref_imgs, intrinsics,
                                 intrinsics_inv, depth,
                                 explainability_mask_for_depth, pose,
                                 lambda_oob=args.lambda_oob, qch=args.qch,
                                 wssim=args.wssim)
        else:
            loss_1 = torch.tensor([0.]).to(device)
        # E_M
        if w2 > 0:
            loss_2 = explainability_loss(
                explainability_mask
            )  # + 0.2*gaussian_explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        # E_S (regular smoothness only)
        if w3 > 0:
            loss_3 = smooth_loss(depth) + smooth_loss(
                explainability_mask) + smooth_loss(flow_fwd) + smooth_loss(
                    flow_bwd)
        else:
            loss_3 = torch.tensor([0.]).to(device)
        # E_F
        if w4 > 0:
            loss_4 = loss_flow(tgt_img, ref_imgs[1:3], [flow_bwd, flow_fwd],
                               flow_exp_mask, lambda_oob=args.lambda_oob,
                               qch=args.qch, wssim=args.wssim)
        else:
            loss_4 = torch.tensor([0.]).to(device)
        # E_C
        if w5 > 0:
            loss_5 = consensus_depth_flow_mask(explainability_mask,
                                               rigidity_mask_bwd,
                                               rigidity_mask_fwd,
                                               exp_masks_target,
                                               exp_masks_target,
                                               THRESH=args.THRESH,
                                               wbce=args.wbce)
        else:
            loss_5 = torch.tensor([0.]).to(device)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + w4 * loss_4 + w5 * loss_5

        # 3.3 data update
        losses.update(loss.item(), args.batch_size)
        batch_time.update(time.time() - end)
        end = time.time()

        # 3.4 logging: inspect the forward pass at the selected sample batches
        if args.img_freq > 0 and i in show_samples:
            if epoch == 0:
                # Before training: log the raw input sequence once.
                # ref_imgs axis 0 is the batch index; the list index selects
                # the adjacent frame.
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[0][0]), 0)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[1][0]), 1)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(tgt_img[0]), 2)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[2][0]), 3)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[3][0]), 4)
                # NOTE(review): this logs depth under the 'Disp Output' tag,
                # while the else-branch logs disp — confirm which is intended.
                depth_to_show = depth[0].cpu(
                )  # tensor [1,h,w], roughly 0.5~3.1~10
                tb_writer.add_image(
                    'Disp Output/sample{}'.format(i),
                    tensor2array(depth_to_show, max_value=None,
                                 colormap='bone'), 0)
            else:
                # 2. disp
                depth_to_show = disp[0].cpu(
                )  # tensor [1,h,w], roughly 0.5~3.1~10
                tb_writer.add_image(
                    'Disp Output/sample{}'.format(i),
                    tensor2array(depth_to_show, max_value=None,
                                 colormap='bone'), epoch)
                # 3. flow
                tb_writer.add_image('Flow/Flow Output sample {}'.format(i),
                                    flow2rgb(flow_fwd[0], max_value=6), epoch)
                tb_writer.add_image(
                    'Flow/cam_Flow Output sample {}'.format(i),
                    flow2rgb(flow_cam[0], max_value=6), epoch)
                tb_writer.add_image(
                    'Flow/no rigid flow Output sample {}'.format(i),
                    flow2rgb(no_rigid_flow[0], max_value=6), epoch)
                tb_writer.add_image(
                    'Flow/rigidity_mask_fwd{}'.format(i),
                    flow2rgb(rigidity_mask_fwd[0], max_value=6), epoch)
                # 4. mask (only channel 0 and the consensus target are logged)
                tb_writer.add_image(
                    'Mask Output/mask0 sample{}'.format(i),
                    tensor2array(explainability_mask[0][0], max_value=None,
                                 colormap='magma'), epoch)
                tb_writer.add_image(
                    'Mask Output/exp_masks_target sample{}'.format(i),
                    tensor2array(exp_masks_target[0][0], max_value=None,
                                 colormap='magma'), epoch)

        # add scalar
        if args.scalar_freq > 0 and n_iter_val % args.scalar_freq == 0:
            tb_writer.add_scalar('val/E_R', loss_1.item(), n_iter_val)
            if w2 > 0:
                tb_writer.add_scalar('val/E_M', loss_2.item(), n_iter_val)
            tb_writer.add_scalar('val/E_S', loss_3.item(), n_iter_val)
            tb_writer.add_scalar('val/E_F', loss_4.item(), n_iter_val)
            tb_writer.add_scalar('val/E_C', loss_5.item(), n_iter_val)
            tb_writer.add_scalar('val/total_loss', loss.item(), n_iter_val)

        # terminal output
        if args.log_terminal:
            logger.valid_bar.update(i + 1)  # progress within current epoch
            if i % args.print_freq == 0:
                logger.valid_bar_writer.write(
                    'Valid: Time {} Data {} Loss {}'.format(
                        batch_time, data_time, losses))

        n_iter_val += 1

    global_vars_dict['n_iter_val'] = n_iter_val

    return losses.avg[0]  # epoch validate loss
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size, tb_writer):
    """Train disp_net + pose_exp_net for one epoch with a tqdm progress bar.

    Combines the photometric reconstruction, explainability and smoothness
    losses with weights from ``args``, logs scalars/images to ``tb_writer``
    at the configured frequencies, and appends per-iteration losses to the
    CSV log.

    Fix vs. previous revision: removed the dead debug leftover
    ``if loss < 0.0005: abc = 0`` (a breakpoint anchor with no effect).

    Returns:
        float: average total loss over the epoch.
    """
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    train_pbar = tqdm(total=min(len(train_loader), args.epoch_size),
                      bar_format='{desc} {percentage:3.0f}%|{bar}| {postfix}')
    train_pbar.set_description('Train: Total Loss=#.####(#.####)')
    train_pbar.set_postfix_str('<TIME: op=#.###(#.###) DataFlow=#.###(#.###)>')

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output: multi-scale disparities and relative poses
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth)
        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if log_losses:
            tb_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                tb_writer.add_scalar('explanability_loss', loss_2.item(),
                                     n_iter)
            tb_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                 n_iter)
            tb_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            tb_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                n_iter)
            # One tensorboard entry per pyramid scale k.
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped, diff,
                        explainability_mask)):
                log_output_tensorboard(tb_writer, "train", 0, k, n_iter,
                                       *scaled_maps)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])

        train_pbar.clear()
        train_pbar.update(1)
        train_pbar.set_description('Train: Total Loss={}'.format(losses))
        train_pbar.set_postfix_str('<TIME: op={} DataFlow={}>'.format(
            batch_time, data_time))

        if i >= epoch_size - 1:
            break
        n_iter += 1

    train_pbar.close()
    time.sleep(1)  # let the progress bar flush before returning
    return losses.avg[0]
def validate_without_gt(val_loader, disp_net, pose_net, mask_net, epoch, logger, output_writers=[]):
    """Validate disp/pose/mask networks without ground truth (no flow net).

    Combines camera photometric, explainability and smoothness losses with
    the weights from the module-level ``args`` and logs images/scalars to the
    given tensorboard ``output_writers``.

    NOTE(review): ``output_writers=[]`` is a mutable default argument; it is
    only read here, but consider ``None``/tuple.

    Returns:
        float: average total validation loss over the epoch.
    """
    # data prepared
    global args, device, n_iter_val
    batch_time = AverageMeter()
    data_time = AverageMeter()
    log_outputs = len(output_writers) > 0
    losses = AverageMeter(precision=4)
    w1, w2, w3, w4 = args.cam_photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.flow_photo_loss_weight
    loss_camera = photometric_reconstruction_loss
    loss_flow = photometric_flow_loss  # NOTE(review): unused — no flow net here

    # to eval model
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    # flow_net.eval()

    end = time.time()
    poses = np.zeros(
        ((len(val_loader) - 1) * 1 *
         (args.sequence_length - 1), 6))  # init; NOTE(review): never filled

    # 3. validation cycle
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics, intrinsics_inv = intrinsics.to(device), intrinsics_inv.to(
            device)

        # 3.1 forward pass
        # disp
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)
        depth = 1 / disp
        # pose
        pose = pose_net(tgt_img, ref_imgs)

        # mask
        # 4. explainability_mask: per-reference-frame validity weights.
        # list(5) of tensors [4,4,128,512] ... [4,4,4,16], values ~0.33-0.63
        explainability_mask = mask_net(tgt_img, ref_imgs)

        # -------------------------------------------------
        # NOTE(review): rigidity_mask_bwd / rigidity_mask_fwd are NOT defined
        # in this function (they require the flow net, which is disabled
        # above) — taking this branch raises NameError. Run with
        # args.joint_mask_for_depth == False, or restore the flow net.
        if args.joint_mask_for_depth:
            explainability_mask_for_depth = compute_joint_mask_for_depth(
                explainability_mask, rigidity_mask_bwd, rigidity_mask_fwd)
        else:
            explainability_mask_for_depth = explainability_mask

        # 3.2 loss computation
        loss_1 = loss_camera(tgt_img, ref_imgs, intrinsics, intrinsics_inv,
                             depth, explainability_mask_for_depth, pose,
                             lambda_oob=args.lambda_oob, qch=args.qch,
                             wssim=args.wssim)
        # E_M
        if w2 > 0:
            loss_2 = explainability_loss(
                explainability_mask
            )  # + 0.2*gaussian_explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        # regular smoothness on depth and mask
        loss_3 = smooth_loss(depth) + smooth_loss(explainability_mask)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        # 3.3 data update
        losses.update(loss.item(), args.batch_size)
        batch_time.update(time.time() - end)
        end = time.time()

        # 3.4 logging: inspect forward pass every 40 batches, one writer per
        # hundred batches.
        if log_outputs and i % 40 == 0 and i / 100 < len(
                output_writers):  # output_writers list(3)
            index = int(i // 40)
            if epoch == 0:
                output_writers[index].add_image('val Input',
                                                tensor2array(tgt_img[0]), 0)
            disp_to_show = disp[0].cpu(
            )  # tensor [1,h,w], roughly 0.5~3.1~10
            output_writers[index].add_image(
                'val target disp222',
                tensor2array(disp_to_show, max_value=None, colormap='magma'),
                epoch)
            # NOTE(review): debug leftover — dumps the colormapped disparity
            # to a fixed file in the working directory every logging step.
            save = (255 * tensor2array(
                disp_to_show, max_value=None, colormap='magma')).astype(
                    np.uint8)
            save = save.transpose(1, 2, 0)
            plt.imsave('ep1_test.jpg', save, cmap='plasma')
            output_writers[index].add_image(
                'val Dispnet Output Normalized123',
                tensor2array(disp.data[0].cpu(), max_value=None,
                             colormap='bone'), epoch)

        # add scalar
        if args.scalar_freq > 0 and n_iter_val % args.scalar_freq == 0:
            output_writers[0].add_scalar('val/cam_photometric_error',
                                         loss_1.item(), n_iter_val)
            if w2 > 0:
                output_writers[0].add_scalar('val/explanability_loss',
                                             loss_2.item(), n_iter_val)
            output_writers[0].add_scalar('val/disparity_smoothness_loss',
                                         loss_3.item(), n_iter_val)
            output_writers[0].add_scalar('val/total_loss', loss.item(),
                                         n_iter_val)

        # terminal output
        if args.log_terminal:
            logger.valid_bar.update(i + 1)  # progress within current epoch
            if i % args.print_freq == 0:
                logger.valid_bar_writer.write(
                    'Valid: Time {} Data {} Loss {}'.format(
                        batch_time, data_time, losses))

        n_iter_val += 1

    return losses.avg[0]  # epoch validate loss
def train(odometry_net, depth_net, feat_extractor, train_loader, epoch, optimizer):
    """One training epoch for the odometry/depth networks with an additional
    feature-metric reconstruction term from ``feat_extractor``.

    Prints the per-epoch average of the total, image-reconstruction,
    feature-reconstruction and smoothness losses.
    """
    global device
    global data_parallel

    # set_fix_method lives on the wrapped module when DataParallel is active.
    fix_target = odometry_net.module if data_parallel else odometry_net
    fix_target.set_fix_method(nfp.FIX_AUTO)

    odometry_net.train()
    depth_net.train()
    feat_extractor.train()

    running_loss = 0
    running_img_recon = 0
    running_feat_recon = 0
    running_smooth = 0

    progress = tqdm(enumerate(train_loader),
                    desc='Train epoch %d' % epoch,
                    leave=False,
                    ncols=80)
    for batch_idx, batch in progress:
        # Move every tensor of the sample to the device as float32.
        img_R1, img_L2, img_R2, intrinsics, inv_intrinsics, raw_K, T_R2L = (
            t.type(torch.FloatTensor).to(device) for t in batch)

        # Inverse depth from the current right frame; relative pose from the
        # channel-stacked frame pair.
        inv_depth_img_R2 = depth_net(img_R2)
        T_2to1, _ = odometry_net(torch.cat((img_R2, img_R1), dim=1))
        T_2to1 = T_2to1.view(T_2to1.size(0), -1)
        T_R2L = T_R2L.view(T_R2L.size(0), -1)
        depth = (1 / (inv_depth_img_R2 + 1e-4)).squeeze(1)

        # Photometric reconstruction on 1/255-scaled images.
        img_reconstruction_error = photometric_reconstruction_loss(
            0.004 * img_R2, 0.004 * img_R1, 0.004 * img_L2, depth, T_2to1,
            T_R2L, intrinsics, inv_intrinsics)
        smooth_error = smooth_loss(depth.unsqueeze(1))

        # Extract features for all three views in a single batch, then split
        # the result back per view.
        feat = feat_extractor(torch.cat((img_L2, img_R2, img_R1), dim=0))
        batch_size = img_R1.size(0)
        f_L2 = feat[:batch_size, :, :, :]
        f_R2 = feat[batch_size:batch_size * 2, :, :, :]
        f_R1 = feat[2 * batch_size:, :, :, :]
        feat_reconstruction_error = photometric_reconstruction_loss(
            f_R2, f_R1, f_L2, depth, T_2to1, T_R2L, intrinsics,
            inv_intrinsics)

        loss = img_reconstruction_error + 0.1 * feat_reconstruction_error + 10 * smooth_error

        running_loss += loss.item()
        running_img_recon += img_reconstruction_error.item()
        running_feat_recon += feat_reconstruction_error.item()
        running_smooth += smooth_error.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    n_batches = len(train_loader)
    print(
        "Train epoch {}: loss: {:.9f} img-recon-loss: {:.9f} f-recon-loss: {:.9f} smooth-loss: {:.9f}"
        .format(epoch, running_loss / n_batches, running_img_recon /
                n_batches, running_feat_recon / n_batches,
                running_smooth / n_batches))
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size, logger, train_writer):
    """Train disp_net + pose_exp_net for one epoch (logger/train_writer variant).

    Fixes vs. previous revision:
    - ``log_output_tensorboard`` was called without the sample-index argument
      used by every sibling call (``(writer, prefix, index, suffix, step,
      *maps)``); the index ``0`` is now passed, matching the other train loop.
    - scalar tag ``'explanabilityyyyyy_loss'`` (debug typo) restored to the
      ``'explanability_loss'`` tag used by the other train/validate loops so
      tensorboard curves stay comparable across runs.

    Returns:
        float: average total loss over the epoch.
    """
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    # train main cycle; each sample is
    # [tensor(B,3,H,W), list(B), (B,H,W), (b,h,w)]
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # 1. measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)  # (4,3,128,416)
        ref_imgs = [img.to(device)
                    for img in ref_imgs]  # previous/next frames per sample
        intrinsics = intrinsics.to(device)  # (4,3,3)

        # 2. compute output
        # disparities: multi-scale list, e.g. (4,1,128,416) ... (4,1,16,52)
        disparities = disp_net(tgt_img)
        # pose: tensor (bs, seq_length-1, 6), relative camera poses
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)
        # depth is proportional to 1/disparity
        depth = [1 / disp for disp in disparities]

        # 3. loss computation
        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth)
        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        # 4. per-batch tensorboard records
        if log_losses:
            train_writer.add_scalar('photometric_error', loss_1.item(),
                                    n_iter)
            if w2 > 0:
                train_writer.add_scalar('explanability_loss', loss_2.item(),
                                        n_iter)
            train_writer.add_scalar('disparity_smoothness_loss',
                                    loss_3.item(), n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)
            # One tensorboard entry per pyramid scale k; sample index 0.
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped, diff,
                        explainability_mask)):
                log_output_tensorboard(train_writer, "train", 0, k, n_iter,
                                       *scaled_maps)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # csv record
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])

        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break
        n_iter += 1

    return losses.avg[0]
def validate(val_loader, disp_net, pose_exp_net, epoch, logger, output_writers=[]):
    """Validate disp_net + pose_exp_net, logging warped views and pose histograms.

    Fix vs. previous revision: modernized dead PyTorch 0.3 idioms —
    ``Variable(..., volatile=True)`` has been a no-op since 0.4, and
    ``.data[0]`` raises ``IndexError`` on 0-dim loss tensors with the current
    API the rest of this file uses. Forward passes now run under
    ``torch.no_grad()`` and scalars are extracted with ``.item()``.

    Returns:
        losses.avg: averaged [total, photometric, explainability] losses.
    """
    global args
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        with torch.no_grad():
            tgt_img_var = tgt_img.cuda()
            ref_imgs_var = [img.cuda() for img in ref_imgs]
            intrinsics_var = intrinsics.cuda()
            intrinsics_inv_var = intrinsics_inv.cuda()

            # compute output
            disp = disp_net(tgt_img_var)
            depth = 1 / disp
            explainability_mask, pose = pose_exp_net(tgt_img_var,
                                                     ref_imgs_var)
            loss_1 = photometric_reconstruction_loss(
                tgt_img_var, ref_imgs_var, intrinsics_var,
                intrinsics_inv_var, depth, explainability_mask, pose)
            loss_1 = loss_1.item()
            if w2 > 0:
                loss_2 = explainability_loss(explainability_mask).item()
            else:
                loss_2 = 0
            loss_3 = smooth_loss(disp).item()

            if log_outputs and i % 100 == 0 and i / 100 < len(
                    output_writers):  # log first output of every 100 batch
                index = int(i // 100)
                if epoch == 0:
                    for j, ref in enumerate(ref_imgs):
                        output_writers[index].add_image(
                            'val Input {}'.format(j),
                            tensor2array(tgt_img[0]), 0)
                        output_writers[index].add_image(
                            'val Input {}'.format(j), tensor2array(ref[0]),
                            1)

                output_writers[index].add_image(
                    'val Dispnet Output Normalized',
                    tensor2array(disp[0].cpu(), max_value=None,
                                 colormap='bone'), epoch)
                output_writers[index].add_image(
                    'val Depth Output',
                    tensor2array(1. / disp[0].cpu(), max_value=10), epoch)

                # log warped images along with explainability mask
                for j, ref in enumerate(ref_imgs_var):
                    ref_warped = inverse_warp(ref[:1], depth[:1, 0],
                                              pose[:1, j],
                                              intrinsics_var[:1],
                                              intrinsics_inv_var[:1])[0]
                    output_writers[index].add_image(
                        'val Warped Outputs {}'.format(j),
                        tensor2array(ref_warped.cpu()), epoch)
                    output_writers[index].add_image(
                        'val Diff Outputs {}'.format(j),
                        tensor2array(
                            0.5 *
                            (tgt_img_var[0] - ref_warped).abs().cpu()),
                        epoch)
                    if explainability_mask is not None:
                        output_writers[index].add_image(
                            'val Exp mask Outputs {}'.format(j),
                            tensor2array(explainability_mask[0, j].cpu(),
                                         max_value=1, colormap='bone'),
                            epoch)

            if log_outputs and i < len(val_loader) - 1:
                step = args.batch_size * (args.sequence_length - 1)
                poses[i * step:(i + 1) * step] = pose.cpu().view(-1,
                                                                 6).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))

    if log_outputs:
        output_writers[0].add_histogram('val poses_tx', poses[:, 0], epoch)
        output_writers[0].add_histogram('val poses_ty', poses[:, 1], epoch)
        output_writers[0].add_histogram('val poses_tz', poses[:, 2], epoch)
        output_writers[0].add_histogram('val poses_rx', poses[:, 3], epoch)
        output_writers[0].add_histogram('val poses_ry', poses[:, 4], epoch)
        output_writers[0].add_histogram('val poses_rz', poses[:, 5], epoch)
    return losses.avg
def train(train_loader, disp_net, pose_net, mask_net, flow_net, optimizer, logger=None, train_writer=None, global_vars_dict=None):
    """One training epoch of the collaborative depth/pose/mask/flow networks.

    Optimizes the weighted sum of five loss terms — camera photometric (E_R),
    explainability mask (E_M), smoothness (E_S), flow photometric (E_F) and
    depth/flow consensus (E_C) — and logs scalars/images to ``train_writer``.
    Running state (``args``, ``n_iter``, ``device``) is exchanged through
    ``global_vars_dict``; the updated ``n_iter`` is written back on return.

    Returns:
        float: average total loss over the epoch.
    """
    # 0. setup
    args = global_vars_dict['args']
    n_iter = global_vars_dict['n_iter']
    device = global_vars_dict['device']

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3, w4 = args.cam_photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.flow_photo_loss_weight
    w5 = args.consensus_loss_weight

    # Robust variants swap in a heavier-tailed penalty for both photometric
    # losses.
    if args.robust:
        loss_camera = photometric_reconstruction_loss_robust
        loss_flow = photometric_flow_loss_robust
    else:
        loss_camera = photometric_reconstruction_loss
        loss_flow = photometric_flow_loss

    # 2. switch to train mode
    disp_net.train()
    pose_net.train()
    mask_net.train()
    flow_net.train()

    end = time.time()

    # 3. train cycle
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # 3.1 compute outputs / loss-function inputs ---------------------
        # 1. disp -> depth (multi-scale pyramid, finest first)
        disparities = disp_net(tgt_img)
        if args.spatial_normalize:
            disparities = [spatial_normalize(disp) for disp in disparities]
        depth = [1 / disp for disp in disparities]

        # 2. pose
        pose = pose_net(tgt_img, ref_imgs)  # pose: [4,4,6]

        # 3. flow_fwd, flow_bwd: full optical flow from the selected backbone
        if args.flownet == 'Back2Future':  # uses three adjacent frames
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        elif args.flownet == 'FlowNetS':
            # NOTE(review): this branch leaves flow_fwd/flow_bwd undefined —
            # the code below would raise NameError. Confirm FlowNetS is
            # actually unsupported here.
            print(' ')

        # flow_cam is the rigid (camera-motion induced) background flow;
        # full flow - rigid flow = object flow.
        flow_cam = pose2flow(
            depth[0].squeeze(), pose[:, 2], intrinsics,
            intrinsics_inv)  # pose[:,2] belongs to forward frame
        flows_cam_fwd = [
            pose2flow(depth_.squeeze(1), pose[:, 2], intrinsics,
                      intrinsics_inv) for depth_ in depth
        ]
        flows_cam_bwd = [
            pose2flow(depth_.squeeze(1), pose[:, 1], intrinsics,
                      intrinsics_inv) for depth_ in depth
        ]

        exp_masks_target = consensus_exp_masks(flows_cam_fwd, flows_cam_bwd,
                                               flow_fwd, flow_bwd, tgt_img,
                                               ref_imgs[2], ref_imgs[1],
                                               wssim=args.wssim,
                                               wrig=args.wrig,
                                               ws=args.smooth_loss_weight)

        # Per-scale residual between rigid and full flow (large where the
        # rigid-scene assumption fails).
        rigidity_mask_fwd = [
            (flows_cam_fwd_i - flow_fwd_i).abs()
            for flows_cam_fwd_i, flow_fwd_i in zip(flows_cam_fwd, flow_fwd)
        ]  # .normalize()
        rigidity_mask_bwd = [
            (flows_cam_bwd_i - flow_bwd_i).abs()
            for flows_cam_bwd_i, flow_bwd_i in zip(flows_cam_bwd, flow_bwd)
        ]  # .normalize()

        # 4. explainability_mask: per-reference-frame validity weights.
        # list(5) of tensors [4,4,128,512] ... [4,4,4,16], values ~0.33-0.63
        explainability_mask = mask_net(tgt_img, ref_imgs)

        # -------------------------------------------------
        if args.joint_mask_for_depth:
            explainability_mask_for_depth = compute_joint_mask_for_depth(
                explainability_mask, rigidity_mask_bwd, rigidity_mask_fwd,
                args.THRESH)
        else:
            explainability_mask_for_depth = explainability_mask
        # explainability_mask_for_depth: list(5) of [b,2,h/,w/]

        if args.no_non_rigid_mask:
            flow_exp_mask = [None for exp_mask in explainability_mask]
            if args.DEBUG:
                print('Using no masks for flow')
        else:
            # The explainability mask marks static background with 1; invert
            # it to select moving objects, keeping only the two adjacent
            # frames used by the flow loss. list(4) of [4,2,256,512]
            flow_exp_mask = [
                1 - exp_mask[:, 1:3] for exp_mask in explainability_mask
            ]

        # 3.2 compute losses (each gated by its weight; disabled terms
        # become zero tensors so .item() below is always valid)
        # E_R: minimizes the photometric loss on static scene
        if w1 > 0:
            loss_1 = loss_camera(tgt_img, ref_imgs, intrinsics,
                                 intrinsics_inv, depth,
                                 explainability_mask_for_depth, pose,
                                 lambda_oob=args.lambda_oob, qch=args.qch,
                                 wssim=args.wssim)
        else:
            loss_1 = torch.tensor([0.]).to(device)
        # E_M
        if w2 > 0:
            loss_2 = explainability_loss(
                explainability_mask
            )  # + 0.2*gaussian_explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        # E_S
        if w3 > 0:
            if args.smoothness_type == "regular":
                loss_3 = smooth_loss(depth) + smooth_loss(
                    flow_fwd) + smooth_loss(flow_bwd) + smooth_loss(
                        explainability_mask)
            elif args.smoothness_type == "edgeaware":
                loss_3 = edge_aware_smoothness_loss(
                    tgt_img, depth) + edge_aware_smoothness_loss(
                        tgt_img, flow_fwd)
                loss_3 += edge_aware_smoothness_loss(
                    tgt_img, flow_bwd) + edge_aware_smoothness_loss(
                        tgt_img, explainability_mask)
        else:
            loss_3 = torch.tensor([0.]).to(device)
        # E_F: minimizes photometric loss on moving regions
        if w4 > 0:
            loss_4 = loss_flow(tgt_img, ref_imgs[1:3], [flow_bwd, flow_fwd],
                               flow_exp_mask, lambda_oob=args.lambda_oob,
                               qch=args.qch, wssim=args.wssim)
        else:
            loss_4 = torch.tensor([0.]).to(device)
        # E_C: drives the collaboration between depth/pose and flow
        # explainability_mask: list(6) of [4,4,4,16];
        # rigidity_mask: list(4) of [4,2,128,512]
        if w5 > 0:
            loss_5 = consensus_depth_flow_mask(explainability_mask,
                                               rigidity_mask_bwd,
                                               rigidity_mask_fwd,
                                               exp_masks_target,
                                               exp_masks_target,
                                               THRESH=args.THRESH,
                                               wbce=args.wbce)
        else:
            loss_5 = torch.tensor([0.]).to(device)

        # 3.2.6 total loss
        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + w4 * loss_4 + w5 * loss_5
        # end of loss

        # 3.3 record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # 3.4 log data
        # add scalar
        if args.scalar_freq > 0 and n_iter % args.scalar_freq == 0:
            train_writer.add_scalar('batch/cam_photometric_error',
                                    loss_1.item(), n_iter)
            if w2 > 0:
                train_writer.add_scalar('batch/explanability_loss',
                                        loss_2.item(), n_iter)
            train_writer.add_scalar('batch/disparity_smoothness_loss',
                                    loss_3.item(), n_iter)
            train_writer.add_scalar('batch/flow_photometric_error',
                                    loss_4.item(), n_iter)
            train_writer.add_scalar('batch/consensus_error', loss_5.item(),
                                    n_iter)
            train_writer.add_scalar('batch/total_loss', loss.item(), n_iter)

        # add_image (a frequency of 0 disables image output)
        if args.training_output_freq > 0 and n_iter % args.training_output_freq == 0:
            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)
            train_writer.add_image(
                'train Cam Flow Output',
                flow_to_image(tensor2array(flow_cam.data[0].cpu())), n_iter)

            for k, scaled_depth in enumerate(depth):
                train_writer.add_image(
                    'train Dispnet Output Normalized111 {}'.format(k),
                    tensor2array(disparities[k].data[0].cpu(),
                                 max_value=None, colormap='bone'), n_iter)
                train_writer.add_image(
                    'train Depth Output {}'.format(k),
                    tensor2array(1 / disparities[k].data[0].cpu(),
                                 max_value=10), n_iter)
                train_writer.add_image(
                    'train Non Rigid Flow Output {}'.format(k),
                    flow_to_image(tensor2array(flow_fwd[k].data[0].cpu())),
                    n_iter)
                train_writer.add_image(
                    'train Target Rigidity {}'.format(k),
                    tensor2array((rigidity_mask_fwd[k] > args.THRESH).type_as(
                        rigidity_mask_fwd[k]).data[0].cpu(), max_value=1,
                                 colormap='bone'), n_iter)

                # Downscale images and intrinsics to scale k's resolution so
                # the warps below are geometrically consistent.
                b, _, h, w = scaled_depth.size()
                downscale = tgt_img.size(2) / h
                tgt_img_scaled = nn.functional.adaptive_avg_pool2d(
                    tgt_img, (h, w))
                ref_imgs_scaled = [
                    nn.functional.adaptive_avg_pool2d(ref_img, (h, w))
                    for ref_img in ref_imgs
                ]
                intrinsics_scaled = torch.cat(
                    (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]),
                    dim=1)
                intrinsics_scaled_inv = torch.cat(
                    (intrinsics_inv[:, :, 0:2] * downscale,
                     intrinsics_inv[:, :, 2:]), dim=2)

                train_writer.add_image(
                    'train Non Rigid Warped Image {}'.format(k),
                    tensor2array(
                        flow_warp(ref_imgs_scaled[2],
                                  flow_fwd[k]).data[0].cpu()), n_iter)

                # log warped images along with explainability mask
                for j, ref in enumerate(ref_imgs_scaled):
                    ref_warped = inverse_warp(
                        ref, scaled_depth[:, 0], pose[:, j],
                        intrinsics_scaled, intrinsics_scaled_inv,
                        rotation_mode=args.rotation_mode,
                        padding_mode=args.padding_mode)[0]
                    train_writer.add_image(
                        'train Warped Outputs {} {}'.format(k, j),
                        tensor2array(ref_warped.data.cpu()), n_iter)
                    train_writer.add_image(
                        'train Diff Outputs {} {}'.format(k, j),
                        tensor2array(0.5 * (tgt_img_scaled[0] -
                                            ref_warped).abs().data.cpu()),
                        n_iter)
                    if explainability_mask[k] is not None:
                        train_writer.add_image(
                            'train Exp mask Outputs {} {}'.format(k, j),
                            tensor2array(
                                explainability_mask[k][0, j].data.cpu(),
                                max_value=1, colormap='bone'), n_iter)

        # csv file write
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item(),
                loss_4.item()
            ])

        # terminal output
        if args.log_terminal:
            logger.train_bar.update(i + 1)  # progress within current epoch
            if i % args.print_freq == 0:
                # NOTE(review): writes train progress through
                # logger.valid_bar_writer — confirm this should not be a
                # train writer.
                logger.valid_bar_writer.write(
                    'Train: Time {} Data {} Loss {}'.format(
                        batch_time, data_time, losses))

        # 3.4 edge conditions
        # NOTE(review): epoch_size is recomputed from the loader, so any
        # configured args.epoch_size is ignored here.
        epoch_size = len(train_loader)
        if i >= epoch_size - 1:
            break

        n_iter += 1

    global_vars_dict['n_iter'] = n_iter
    return losses.avg[0]  # epoch loss
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, train_writer):
    """Run one training epoch of the joint disparity + pose/explainability nets.

    For each batch: predict multi-scale disparities and relative poses,
    combine photometric reconstruction, explainability-mask and smoothness
    losses, and take one Adam step. Scalars/images are periodically written
    to tensorboard and every batch is appended to the CSV log.

    Returns the running average of the total loss over the epoch.
    """
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    photo_w = args.photo_loss_weight
    mask_w = args.mask_loss_weight
    smooth_w = args.smooth_loss_weight

    # switch both networks to train mode
    disp_net.train()
    pose_exp_net.train()

    tic = time.time()
    logger.train_bar.update(0)

    for step, batch in enumerate(train_loader):
        tgt_img, ref_imgs, intrinsics, intrinsics_inv = batch
        data_time.update(time.time() - tic)

        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # forward passes: multi-scale disparities -> depths, then pose/mask
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1 = photometric_reconstruction_loss(tgt_img, ref_imgs, intrinsics,
                                                 depth, explainability_mask,
                                                 pose, args.rotation_mode,
                                                 args.padding_mode)
        loss_2 = explainability_loss(explainability_mask) if mask_w > 0 else 0
        loss_3 = smooth_loss(depth)
        loss = photo_w * loss_1 + mask_w * loss_2 + smooth_w * loss_3

        # periodic scalar logging (skipped on the very first iteration)
        if step > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if mask_w > 0:
                train_writer.add_scalar('explanability_loss', loss_2.item(),
                                        n_iter)
            train_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                    n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        # periodic image logging of every disparity scale and warped refs
        if args.training_output_freq > 0 and n_iter % args.training_output_freq == 0:
            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)
            with torch.no_grad():
                for k, scaled_depth in enumerate(depth):
                    train_writer.add_image(
                        'train Dispnet Output Normalized {}'.format(k),
                        tensor2array(disparities[k][0], max_value=None,
                                     colormap='magma'), n_iter)
                    train_writer.add_image(
                        'train Depth Output Normalized {}'.format(k),
                        tensor2array(1 / disparities[k][0], max_value=None),
                        n_iter)
                    b, _, h, w = scaled_depth.size()
                    downscale = tgt_img.size(2) / h
                    tgt_img_scaled = F.interpolate(tgt_img, (h, w),
                                                   mode='area')
                    ref_imgs_scaled = [
                        F.interpolate(ref_img, (h, w), mode='area')
                        for ref_img in ref_imgs
                    ]
                    # downscale focal/center rows of K to this pyramid level
                    intrinsics_scaled = torch.cat(
                        (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]),
                        dim=1)
                    # log warped images along with explainability mask
                    for j, ref in enumerate(ref_imgs_scaled):
                        ref_warped = inverse_warp(
                            ref, scaled_depth[:, 0], pose[:, j],
                            intrinsics_scaled,
                            rotation_mode=args.rotation_mode,
                            padding_mode=args.padding_mode)[0]
                        train_writer.add_image(
                            'train Warped Outputs {} {}'.format(k, j),
                            tensor2array(ref_warped), n_iter)
                        train_writer.add_image(
                            'train Diff Outputs {} {}'.format(k, j),
                            tensor2array(
                                0.5 * (tgt_img_scaled[0] - ref_warped).abs()),
                            n_iter)
                        if explainability_mask[k] is not None:
                            train_writer.add_image(
                                'train Exp mask Outputs {} {}'.format(k, j),
                                tensor2array(explainability_mask[k][0, j],
                                             max_value=1, colormap='bone'),
                                n_iter)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            csv.writer(csvfile, delimiter='\t').writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if mask_w > 0 else 0,
                loss_3.item()
            ])

        logger.train_bar.update(step + 1)
        if step % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if step >= epoch_size - 1:
            break
        n_iter += 1

    return losses.avg[0]
def train_depth_gt(train_loader,
                   disp_net,
                   optimizer,
                   criterion,
                   logger=None,
                   train_writer=None,
                   global_vars_dict=None):
    """Supervised depth training: fit disp_net against ground-truth depth.

    Per batch: predict multi-scale disparities (optionally spatially
    normalized), invert to depth, and minimize a weighted masked-L1 loss
    against ``gt_depth`` plus a smoothness term.

    Returns ``(loss_names, losses)``: the loss labels and the epoch-level
    AverageMeter accumulating [total, l1, smooth] losses.
    """
    # 0. setup — pull shared state from the global-vars dict
    args = global_vars_dict['args']
    n_iter = global_vars_dict['n_iter']
    device = global_vars_dict['device']

    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_names = ['total_loss', 'l1_loss', 'smooth']
    losses = AverageMeter(precision=4, i=len(loss_names))
    w1, w2 = args.gt_loss_weight, args.smooth_loss_weight
    loss_l1 = MaskedL1Loss().to(device)

    # 2. switch to train mode
    disp_net.train()
    #pose_net.train()
    #mask_net.train()
    #flow_net.train()
    end = time.time()

    # 3. train cycle
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv,
            gt_depth) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # move batch to device
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)
        gt_depth = gt_depth.to(device)  # [0~1]

        # predict multi-scale disparities, then convert to depth
        disparities = disp_net(tgt_img)
        if args.spatial_normalize:
            disparities = [spatial_normalize(disp) for disp in disparities]
        output_depth = [1 / disp for disp in disparities]

        loss1 = loss_l1(gt_depth, output_depth)
        loss2 = smooth_loss(output_depth)
        loss = w1 * loss1 + w2 * loss2
        # NOTE(review): if the graph were intact, calling requires_grad_() on a
        # non-leaf loss would raise — its presence suggests the loss may be
        # detached upstream, in which case backward() updates nothing. Verify.
        loss.requires_grad_()

        losses.update([loss.item(), loss1.item(), loss2.item()],
                      args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # log to terminal
        if args.log_terminal:
            logger.train_logger_update(batch=i,
                                       time=batch_time,
                                       names=loss_names,
                                       values=losses)

        # 3.4 log batch data (only during training, to check learning early)
        # NOTE(review): tag says 'l2_loss' but the logged value is the total
        # (masked-L1 + smoothness) loss — confirm intended tag.
        train_writer.add_scalar('batch/l2_loss', loss.item(), n_iter)

        # 3.4 edge conditions
        epoch_size = len(train_loader)
        if i >= epoch_size - 1:
            break
        n_iter += 1

    global_vars_dict['n_iter'] = n_iter
    return loss_names, losses  # epoch loss
@torch.no_grad()  # pure evaluation: no graphs needed, and .numpy() below
                  # would raise on a grad-requiring tensor
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        logger,
                        output_writers=None):
    """Validate disp_net + pose_exp_net without ground truth.

    Computes the same self-supervised losses as training (photometric,
    explainability, smoothness) on the validation set, optionally logging
    sample outputs and pose/disparity histograms to ``output_writers``.

    Returns ``(losses.avg, names)`` where names label the three averaged
    losses.
    """
    global device
    # avoid a shared mutable default argument
    output_writers = [] if output_writers is None else output_writers
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    # preallocate flat buffers for pose/disparity histograms (last,
    # possibly partial, batch is excluded)
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # compute output (single-scale disparity in eval mode)
        disp = disp_net(tgt_img)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        if log_outputs and i < len(
                output_writers):  # log first output of first batches
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    output_writers[i].add_image('val Input {}'.format(j),
                                                tensor2array(tgt_img[0]), 0)
                    output_writers[i].add_image('val Input {}'.format(j),
                                                tensor2array(ref[0]), 1)
            log_output_tensorboard(output_writers[i], 'val', '', epoch,
                                   1. / disp, disp, warped, diff,
                                   explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            # per-sample min / median / max of the disparity map
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))

    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            output_writers[0].add_histogram(
                '{} {}'.format(prefix, coeffs_names[i]), poses[:, i], epoch)
        output_writers[0].add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, ['Total loss', 'Photo loss', 'Exp loss']
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, tb_writer):
    """One training epoch with an optional photo-consistency (pose-chain) loss.

    Standard self-supervised losses (photometric, explainability, smoothness)
    plus, when ``args.with_photocon_loss`` is set, an extra term that warps
    ref 3 into ref 1 through the composed pose T13 = T23 @ inv(T21).

    Returns the running average of the total loss over the epoch.
    """
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth)
        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if args.with_photocon_loss:
            batch_size = pose.size()[0]
            # homogeneous bottom row [0,0,0,1] to lift 3x4 poses to 4x4
            homo_row = torch.tensor([[0, 0, 0, 1]],
                                    dtype=torch.float).to(device)
            homo_row = homo_row.unsqueeze(0).expand(batch_size, -1, -1)
            T21 = pose_vec2mat(pose[:, 0])
            T21 = torch.cat((T21, homo_row), 1)
            T12 = torch.inverse(T21)
            T23 = pose_vec2mat(pose[:, 1])
            T23 = torch.cat((T23, homo_row), 1)
            T13 = torch.matmul(T23, T12)  # [B, 4, 4]
            # target = 1 and ref = 3
            ref_img_warped, valid_points = inverse_warp_posemat(
                ref_imgs[1], depth[0][:, 0], T13, intrinsics,
                args.rotation_mode, args.padding_mode)
            # BUGFIX: use a distinct name — the original overwrote `diff`
            # returned by photometric_reconstruction_loss, which is still
            # needed by the tensorboard logging below.
            photocon_diff = (ref_imgs[0] -
                             ref_img_warped) * valid_points.unsqueeze(1).float()
            loss_4 = photocon_diff.abs().mean()
            loss += loss_4

        if log_losses:
            tb_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                tb_writer.add_scalar('explanability_loss', loss_2.item(),
                                     n_iter)
            tb_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                 n_iter)
            tb_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            tb_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                n_iter)
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped, diff,
                        explainability_mask)):
                log_output_tensorboard(tb_writer, "train", 0, " {}".format(k),
                                       n_iter, *scaled_maps)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break
        n_iter += 1

    return losses.avg[0]