def validate_with_gt(args, val_loader, disp_net, epoch, logger, tb_writer, sample_nb_to_log=3):
    """Validate disp_net against ground-truth depth.

    Runs the network over ``val_loader`` in eval mode, accumulates depth error
    metrics and logs a handful of sample images / depth maps to tensorboard.

    Args:
        args: parsed options (reads ``args.print_freq``).
        val_loader: iterable yielding ``(tgt_img, depth)`` batches.
        disp_net: network predicting disparity from a single image.
        epoch: current epoch, used as tensorboard global step.
        logger: progress logger exposing ``valid_bar`` / ``valid_writer``.
        tb_writer: tensorboard SummaryWriter.
        sample_nb_to_log: number of batches to log images for (0 disables logging).

    Returns:
        ``(errors.avg, error_names)``: averaged metric values and their names.
    """
    global device
    batch_time = AverageMeter()
    error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    errors = AverageMeter(i=len(error_names))
    log_outputs = sample_nb_to_log > 0
    # Spread the logged samples evenly over the whole dataset
    batches_to_log = list(np.linspace(0, len(val_loader) - 1, sample_nb_to_log).astype(int))

    # switch to evaluate mode
    disp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, depth) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        depth = depth.to(device)

        # compute output
        output_disp = disp_net(tgt_img)
        output_depth = 1/output_disp[:, 0]

        if log_outputs and i in batches_to_log:
            index = batches_to_log.index(i)
            if epoch == 0:
                tb_writer.add_image('val Input/{}'.format(index), tensor2array(tgt_img[0]), 0)
                # Clone before mutating: depth[0] is a view into the GT batch, and
                # writing 1000 into invalid pixels would otherwise corrupt the depth
                # tensor used by compute_depth_errors below.
                depth_to_show = depth[0].clone()
                tb_writer.add_image('val target Depth Normalized/{}'.format(index),
                                    tensor2array(depth_to_show, max_value=None),
                                    epoch)
                # KITTI-like data marks invalid pixels with 0: push them to a far
                # depth so the displayed disparity stays near 0 there.
                depth_to_show[depth_to_show == 0] = 1000
                disp_to_show = (1/depth_to_show).clamp(0, 10)
                tb_writer.add_image('val target Disparity Normalized/{}'.format(index),
                                    tensor2array(disp_to_show, max_value=None, colormap='magma'),
                                    epoch)
            tb_writer.add_image('val Dispnet Output Normalized/{}'.format(index),
                                tensor2array(output_disp[0], max_value=None, colormap='magma'),
                                epoch)
            tb_writer.add_image('val Depth Output Normalized/{}'.format(index),
                                tensor2array(output_depth[0], max_value=None),
                                epoch)

        errors.update(compute_depth_errors(depth, output_depth))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i+1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Abs Error {:.4f} ({:.4f})'
                                      .format(batch_time, errors.val[0], errors.avg[0]))
    logger.valid_bar.update(len(val_loader))
    return errors.avg, error_names
def validate_with_gt(args, val_loader, depth_net, pose_net, epoch, logger, output_writers=None, **env):
    """Validate depth_net and pose_net against ground-truth depth and poses.

    For every non-target frame of each sequence, the depth network is fed two
    stabilized image pairs: one rotated with the GT pose ('stab') and one with
    the predicted pose ('unstab'). Depth errors are accumulated separately so
    the two can be compared; pose errors (ATE / rotation) are accumulated from
    the pose network output.

    Args:
        args: parsed options (reads network_input_size, sequence_length,
            rotation_mode, nominal_displacement, max_depth, print_freq).
        val_loader: iterable yielding sample dicts with keys 'imgs',
            'intrinsics', 'intrinsics_inv', 'depth', 'pose'.
        depth_net: network taking a 6-channel stacked image pair, returning depth.
        pose_net: network taking the stacked sequence, returning pose vectors.
        epoch: current epoch; sample images are only logged when epoch == 1.
        logger: progress logger exposing ``valid_bar`` / ``valid_writer``.
        output_writers: per-sample tensorboard writers; only the first
            ``len(output_writers)`` batches get image logging. Defaults to no
            logging.
        **env: ignored; absorbs extra keyword arguments from the caller.

    Returns:
        OrderedDict mapping error names (pose, then 'unstab'/'stab' depth
        metrics) to their averaged values.
    """
    global device
    # None sentinel instead of a mutable default argument ([] would be shared
    # across calls).
    if output_writers is None:
        output_writers = []
    batch_time = AverageMeter()
    depth_error_names = ['abs diff', 'abs rel', 'sq rel', 'a1', 'a2', 'a3']
    stab_depth_errors = AverageMeter(i=len(depth_error_names))
    unstab_depth_errors = AverageMeter(i=len(depth_error_names))
    pose_error_names = ['Absolute Trajectory Error', 'Rotation Error']
    pose_errors = AverageMeter(i=len(pose_error_names))

    # switch to evaluate mode
    depth_net.eval()
    pose_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, sample in enumerate(val_loader):
        log_output = i < len(output_writers)
        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        batch_size, seq, c, h, w = imgs.size()
        intrinsics = sample['intrinsics'].to(device)
        intrinsics_inv = sample['intrinsics_inv'].to(device)
        if args.network_input_size is not None:
            # Resize the whole sequence at once and rescale the intrinsics to match
            imgs = F.interpolate(imgs, (c, *args.network_input_size), mode='area')
            downscale = h / args.network_input_size[0]
            intrinsics = torch.cat(
                (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]), dim=1)
            intrinsics_inv = torch.cat(
                (intrinsics_inv[:, :, 0:2] * downscale, intrinsics_inv[:, :, 2:]), dim=2)
        GT_depth = sample['depth'].to(device)
        GT_pose = sample['pose'].to(device)
        mid_index = (args.sequence_length - 1) // 2

        tgt_img = imgs[:, mid_index]

        if epoch == 1 and log_output:
            for j, img in enumerate(sample['imgs']):
                output_writers[i].add_image('val Input', tensor2array(img[0]), j)
            # Clone so the invalid-pixel fill below cannot alias GT_depth:
            # .cpu() only copies when the tensor lives on the GPU.
            depth_to_show = GT_depth[0].cpu().clone()
            # KITTI-like data routine to discard invalid data (0 marks invalid)
            depth_to_show[depth_to_show == 0] = 1000
            disp_to_show = (1 / depth_to_show).clamp(0, 10)
            output_writers[i].add_image(
                'val target Disparity Normalized',
                tensor2array(disp_to_show, max_value=None, colormap='bone'),
                epoch)

        poses = pose_net(imgs)
        pose_matrices = pose_vec2mat(poses, args.rotation_mode)  # [B, seq, 3, 4]
        inverted_pose_matrices = invert_mat(pose_matrices)
        pose_errors.update(
            compute_pose_error(GT_pose[:, :-1], inverted_pose_matrices.data[:, :-1]))
        tgt_poses = pose_matrices[:, mid_index]  # [B, 3, 4]

        # Express every pose (predicted and GT) relative to the middle (target) frame
        compensated_predicted_poses = compensate_pose(pose_matrices, tgt_poses)
        compensated_GT_poses = compensate_pose(GT_pose, GT_pose[:, mid_index])

        for j in range(args.sequence_length):
            if j == mid_index:
                # Target frame: nothing to stabilize against itself
                if log_output and epoch == 1:
                    output_writers[i].add_image(
                        'val Input Stabilized', tensor2array(sample['imgs'][j][0]), j)
                continue
            # Compute displacement magnitude for each element of the batch, and
            # rescale depth accordingly.
            prior_img = imgs[:, j]
            displacement = compensated_GT_poses[:, j, :, -1]  # [B,3]
            displacement_magnitude = displacement.norm(p=2, dim=1)  # [B]
            current_GT_depth = GT_depth * args.nominal_displacement / displacement_magnitude.view(-1, 1, 1)

            prior_predicted_pose = compensated_predicted_poses[:, j]  # [B, 3, 4]
            prior_GT_pose = compensated_GT_poses[:, j]
            prior_predicted_rot = prior_predicted_pose[:, :, :-1]
            prior_GT_rot = prior_GT_pose[:, :, :-1].transpose(1, 2)

            prior_compensated_from_GT = inverse_rotate(prior_img, prior_GT_rot,
                                                       intrinsics, intrinsics_inv)
            if log_output and epoch == 1:
                depth_to_show = current_GT_depth[0]
                output_writers[i].add_image(
                    'val target Depth {}'.format(j),
                    tensor2array(depth_to_show, max_value=args.max_depth),
                    epoch)
                output_writers[i].add_image(
                    'val Input Stabilized',
                    tensor2array(prior_compensated_from_GT[0]), j)

            prior_compensated_from_prediction = inverse_rotate(
                prior_img, prior_predicted_rot, intrinsics, intrinsics_inv)
            predicted_input_pair = torch.cat(
                [prior_compensated_from_prediction, tgt_img], dim=1)  # [B, 6, W, H]
            GT_input_pair = torch.cat([prior_compensated_from_GT, tgt_img], dim=1)  # [B, 6, W, H]

            # Depth from footage stabilized with the GT pose should be better than
            # depth from footage stabilized with the predicted pose (no GT info).
            raw_depth_stab = depth_net(GT_input_pair)
            raw_depth_unstab = depth_net(predicted_input_pair)

            # Upsample depth so that it matches GT size
            scale_factor = GT_depth.size(-1) // raw_depth_stab.size(-1)
            depth_stab = F.interpolate(raw_depth_stab, scale_factor=scale_factor,
                                       mode='bilinear', align_corners=False)
            depth_unstab = F.interpolate(raw_depth_unstab, scale_factor=scale_factor,
                                         mode='bilinear', align_corners=False)

            for k, depth in enumerate([depth_stab, depth_unstab]):
                disparity = 1 / depth
                errors = stab_depth_errors if k == 0 else unstab_depth_errors
                errors.update(
                    compute_depth_errors(current_GT_depth, depth, crop=True))
                if log_output:
                    prefix = 'stabilized' if k == 0 else 'unstabilized'
                    output_writers[i].add_image(
                        'val {} Dispnet Output Normalized {}'.format(prefix, j),
                        tensor2array(disparity[0], max_value=None, colormap='bone'),
                        epoch)
                    output_writers[i].add_image(
                        'val {} Depth Output {}'.format(prefix, j),
                        tensor2array(depth[0], max_value=args.max_depth),
                        epoch)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} ATE Error {:.4f} ({:.4f}), Unstab Rel Abs Error {:.4f} ({:.4f})'
                .format(batch_time, pose_errors.val[0], pose_errors.avg[0],
                        unstab_depth_errors.val[1], unstab_depth_errors.avg[1]))
    logger.valid_bar.update(len(val_loader))

    errors = (*pose_errors.avg, *unstab_depth_errors.avg, *stab_depth_errors.avg)
    error_names = (*pose_error_names,
                   *['unstab {}'.format(e) for e in depth_error_names],
                   *['stab {}'.format(e) for e in depth_error_names])
    return OrderedDict(zip(error_names, errors))
def validate_with_gt_pose(args, val_loader, disp_net, pose_exp_net, epoch, logger, tb_writer, sample_nb_to_log=3):
    """Validate disp_net and pose_exp_net against GT depth and GT poses.

    Runs both networks in eval mode over ``val_loader``, converts the predicted
    pose vectors into trajectories rebased on the first frame of each sequence,
    and accumulates depth errors and pose errors (ATE / RTE). A few sample
    outputs and pose/disparity histograms are logged to tensorboard.

    Args:
        args: parsed options (reads batch_size, sequence_length, rotation_mode,
            print_freq).
        val_loader: iterable yielding ``(tgt_img, ref_imgs, gt_depth, gt_poses)``.
        disp_net: network predicting disparity from a single image.
        pose_exp_net: network predicting explainability mask and relative poses.
        epoch: current epoch, used as tensorboard global step.
        logger: progress logger exposing ``valid_bar`` / ``valid_writer``.
        tb_writer: tensorboard SummaryWriter.
        sample_nb_to_log: number of batches to log outputs for (0 disables).

    Returns:
        ``(depth_errors.avg + pose_errors.avg, depth_error_names + pose_error_names)``
    """
    global device
    batch_time = AverageMeter()
    depth_error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    depth_errors = AverageMeter(i=len(depth_error_names), precision=4)
    pose_error_names = ['ATE', 'RTE']
    pose_errors = AverageMeter(i=2, precision=4)
    log_outputs = sample_nb_to_log > 0
    # Spread the logged samples over the whole dataset. Upper bound is
    # len(val_loader) - 1: batch indices never reach len(val_loader), so using
    # it would silently drop the last requested sample.
    batches_to_log = list(
        np.linspace(0, len(val_loader) - 1, sample_nb_to_log).astype(int))
    # Histogram buffers: the last (possibly partial) batch is excluded below,
    # hence len(val_loader) - 1 full batches.
    poses_values = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1), 6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, gt_depth, gt_poses) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        gt_depth = gt_depth.to(device)
        gt_poses = gt_poses.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        b = tgt_img.shape[0]

        # compute output
        output_disp = disp_net(tgt_img)
        output_depth = 1 / output_disp
        explainability_mask, output_poses = pose_exp_net(tgt_img, ref_imgs)

        # Re-insert the target frame's (identity, i.e. zero-vector) pose in the
        # middle of the predicted sequence so it lines up with gt_poses.
        reordered_output_poses = torch.cat([
            output_poses[:, :gt_poses.shape[1] // 2],
            torch.zeros(b, 1, 6).to(output_poses),
            output_poses[:, gt_poses.shape[1] // 2:]
        ], dim=1)

        # pose_vec2mat only takes B, 6 tensors, so we simulate a batch dimension of B * seq_length
        unravelled_poses = reordered_output_poses.reshape(-1, 6)
        unravelled_matrices = pose_vec2mat(unravelled_poses,
                                           rotation_mode=args.rotation_mode)
        inv_transform_matrices = unravelled_matrices.reshape(b, -1, 3, 4)

        # Invert each [R|t]: rotation becomes R^T, translation becomes -R^T t
        rot_matrices = inv_transform_matrices[..., :3].transpose(-2, -1)
        tr_vectors = -rot_matrices @ inv_transform_matrices[..., -1:]

        transform_matrices = torch.cat([rot_matrices, tr_vectors], dim=-1)

        # Rebase every pose on the first frame of the sequence
        first_inv_transform = inv_transform_matrices.reshape(b, -1, 3, 4)[:, :1]
        final_poses = first_inv_transform[..., :3] @ transform_matrices
        final_poses[..., -1:] += first_inv_transform[..., -1:]
        final_poses = final_poses.reshape(b, -1, 3, 4)

        if log_outputs and i in batches_to_log:  # log first output of wanted batches
            index = batches_to_log.index(i)
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    tb_writer.add_image('val Input {}/{}'.format(j, index),
                                        tensor2array(tgt_img[0]), 0)
                    tb_writer.add_image('val Input {}/{}'.format(j, index),
                                        tensor2array(ref[0]), 1)
            log_output_tensorboard(tb_writer, 'val', index, '', epoch,
                                   output_depth, output_disp, None, None,
                                   explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            # Collect per-batch pose vectors and disparity min/median/max for the
            # end-of-epoch histograms (skipping the possibly-partial last batch).
            step = args.batch_size * (args.sequence_length - 1)
            poses_values[i * step:(i + 1) * step] = output_poses.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = output_disp.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        depth_errors.update(compute_depth_errors(gt_depth, output_depth[:, 0]))
        pose_errors.update(compute_pose_errors(gt_poses, final_poses))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} Abs Error {:.4f} ({:.4f}), ATE {:.4f} ({:.4f})'
                .format(batch_time, depth_errors.val[0], depth_errors.avg[0],
                        pose_errors.val[0], pose_errors.avg[0]))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses_values.shape[1]):
            tb_writer.add_histogram('{} {}'.format(prefix, coeffs_names[i]),
                                    poses_values[:, i], epoch)
        tb_writer.add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return depth_errors.avg + pose_errors.avg, depth_error_names + pose_error_names