def train(args, train_loader, disvo, optimizer, epoch_size, logger,
          train_writer):
    """Train the DiSVO pose network for one epoch.

    Minimizes an uncertainty-weighted pose regression loss: the first pose
    channels are the 6-DoF prediction, the remaining channels are treated as
    log-variances that scale the squared error (err^2 * exp(-s) + s).

    Returns the epoch-averaged total loss.
    """
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)

    # switch to train mode
    disvo.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (img_ref, img_tar, poses_gt) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        img_ref = img_ref.to(device)
        img_tar = img_tar.to(device)
        # BUGFIX: ground truth must live on the same device as the prediction,
        # otherwise the subtraction below fails on GPU runs.
        poses_gt = poses_gt.to(device)

        # compute output
        _, poses_pred = disvo(img_ref, img_tar)
        # BUGFIX: original used MATLAB syntax `.^2` (a SyntaxError in Python)
        # and referenced the undefined name `poses`; both replaced with the
        # intended `** 2` / `poses_pred`. `.sum()` reduces to a scalar loss.
        # NOTE(review): assumes poses_pred stacks [pose, log-variance] along
        # the dimension sliced by [:6] / [6:], as the original indexing
        # implies -- confirm against the disvo model output layout.
        loss = ((poses_pred[:6] - poses_gt) ** 2 * torch.exp(-poses_pred[6:])
                + poses_pred[6:]).sum()

        if log_losses:
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            # BUGFIX: the original logged undefined names (tgt_img, depth,
            # disparities, warped, diff, explainability_mask), which would
            # raise NameError the first time this branch ran. Log the target
            # image that actually exists in this function instead.
            train_writer.add_image('train Input', tensor2array(img_tar[0]),
                                   n_iter)

        # record loss
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item()])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
def validate_without_gt(args, val_loader, disp_net, pose_exp_net, epoch,
                        logger, output_writers=None):
    """Run one self-supervised validation epoch (no ground truth).

    Computes the photometric / explainability / smoothness losses on the
    validation set and optionally logs sample outputs and pose histograms to
    the given TensorBoard writers.

    Returns (losses.avg, loss_names).
    """
    global device
    # BUGFIX: `output_writers=[]` was a mutable default argument; use None as
    # the sentinel. Behavior is unchanged: an empty list disables logging.
    if output_writers is None:
        output_writers = []
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    # preallocate flat buffers for pose / disparity statistics histograms
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # compute output
        disp = disp_net(tgt_img)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        if log_outputs and i < len(
                output_writers):  # log first output of first batches
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    output_writers[i].add_image('val Input {}'.format(j),
                                                tensor2array(tgt_img[0]), 0)
                    output_writers[i].add_image('val Input {}'.format(j),
                                                tensor2array(ref[0]), 1)

            log_output_tensorboard(output_writers[i], 'val', '', epoch,
                                   1. / disp, disp, warped, diff,
                                   explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            output_writers[0].add_histogram(
                '{} {}'.format(prefix, coeffs_names[i]), poses[:, i], epoch)
        output_writers[0].add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, ['Total loss', 'Photo loss', 'Exp loss']
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, train_writer):
    """Train disp_net and pose_exp_net jointly for one epoch.

    Loss = w1 * photometric reconstruction + w2 * explainability +
    w3 * inverse-depth smoothness. Logs scalars/images to TensorBoard and
    appends per-batch losses to a tab-separated CSV file.

    Returns the epoch-averaged total loss.
    """
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        # i == 0 is skipped for scalar logging; image logging is gated by its
        # own frequency so both can fire independently.
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        # disp_net returns multi-scale disparities; depth is their reciprocal
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            # keep a plain 0 so the weighted sum below stays valid
            loss_2 = 0
        loss_3 = inverse_depth_smooth_loss(depth, tgt_img)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if log_losses:
            train_writer.add_scalar('photometric_error', loss_1.item(),
                                    n_iter)
            if w2 > 0:
                train_writer.add_scalar('explanability_loss', loss_2.item(),
                                        n_iter)
            train_writer.add_scalar('disparity_smoothness_loss',
                                    loss_3.item(), n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)
            # one log entry per pyramid scale k
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped, diff,
                        explainability_mask)):
                log_output_tensorboard(train_writer, "train", k, n_iter,
                                       *scaled_maps)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # append per-batch losses to the full CSV log
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
def validate_with_gt_pose(args, val_loader, disp_net, pose_exp_net, epoch,
                          logger, tb_writer, sample_nb_to_log=3):
    """Validate against ground-truth depth and poses.

    Predicted poses are inverted and re-anchored to the first frame of each
    sequence before being compared to gt_poses. Logs sample outputs and pose
    histograms to TensorBoard.

    Returns (depth_errors.avg + pose_errors.avg, error names).
    """
    global device
    batch_time = AverageMeter()
    depth_error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    depth_errors = AverageMeter(i=len(depth_error_names), precision=4)
    pose_error_names = ['ATE', 'RTE']
    pose_errors = AverageMeter(i=2, precision=4)
    log_outputs = sample_nb_to_log > 0
    # Output the logs throughout the whole dataset
    batches_to_log = list(
        np.linspace(0, len(val_loader), sample_nb_to_log).astype(int))
    poses_values = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, gt_depth, gt_poses) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        gt_depth = gt_depth.to(device)
        gt_poses = gt_poses.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        b = tgt_img.shape[0]

        # compute output
        output_disp = disp_net(tgt_img)
        output_depth = 1 / output_disp
        explainability_mask, output_poses = pose_exp_net(tgt_img, ref_imgs)

        # insert a zero (identity) pose for the target frame in the middle of
        # the predicted sequence so indices line up with gt_poses
        reordered_output_poses = torch.cat([
            output_poses[:, :gt_poses.shape[1] // 2],
            torch.zeros(b, 1, 6).to(output_poses),
            output_poses[:, gt_poses.shape[1] // 2:]
        ],
                                           dim=1)

        # pose_vec2mat only takes B, 6 tensors, so we simulate a batch dimension of B * seq_length
        unravelled_poses = reordered_output_poses.reshape(-1, 6)
        unravelled_matrices = pose_vec2mat(unravelled_poses,
                                           rotation_mode=args.rotation_mode)
        inv_transform_matrices = unravelled_matrices.reshape(b, -1, 3, 4)

        # invert each [R|t]: R' = R^T, t' = -R^T t
        rot_matrices = inv_transform_matrices[..., :3].transpose(-2, -1)
        tr_vectors = -rot_matrices @ inv_transform_matrices[..., -1:]

        transform_matrices = torch.cat([rot_matrices, tr_vectors], axis=-1)

        # compose with the first frame's inverse transform so all poses are
        # expressed relative to the first frame of the sequence
        first_inv_transform = inv_transform_matrices.reshape(b, -1, 3,
                                                             4)[:, :1]
        final_poses = first_inv_transform[..., :3] @ transform_matrices
        final_poses[..., -1:] += first_inv_transform[..., -1:]
        final_poses = final_poses.reshape(b, -1, 3, 4)

        if log_outputs and i in batches_to_log:  # log first output of wanted batches
            index = batches_to_log.index(i)
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    tb_writer.add_image('val Input {}/{}'.format(j, index),
                                        tensor2array(tgt_img[0]), 0)
                    tb_writer.add_image('val Input {}/{}'.format(j, index),
                                        tensor2array(ref[0]), 1)

            log_output_tensorboard(tb_writer, 'val', index, '', epoch,
                                   output_depth, output_disp, None, None,
                                   explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses_values[i * step:(i + 1) * step] = output_poses.cpu().view(
                -1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = output_disp.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        depth_errors.update(compute_depth_errors(gt_depth, output_depth[:,
                                                                        0]))
        pose_errors.update(compute_pose_errors(gt_poses, final_poses))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} Abs Error {:.4f} ({:.4f}), ATE {:.4f} ({:.4f})'
                .format(batch_time, depth_errors.val[0], depth_errors.avg[0],
                        pose_errors.val[0], pose_errors.avg[0]))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses_values.shape[1]):
            tb_writer.add_histogram('{} {}'.format(prefix, coeffs_names[i]),
                                    poses_values[:, i], epoch)
        tb_writer.add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return depth_errors.avg + pose_errors.avg, depth_error_names + pose_error_names
def validate_without_gt(args, val_loader, disp_net, pose_exp_net, epoch,
                        logger, tb_writer, sample_nb_to_log=3):
    """Self-supervised validation with optional flow-consistency and
    ground-prior terms.

    Loss = w1*photo + w2*explainability + w3*smooth + wf*flow + wp*prior.
    Logs sample outputs and pose histograms to TensorBoard.

    Returns (losses.avg, loss names).
    """
    global device
    mse_l = torch.nn.MSELoss(reduction='mean')
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = sample_nb_to_log > 0
    # Output the logs throughout the whole dataset
    batches_to_log = list(
        np.linspace(0, len(val_loader), sample_nb_to_log).astype(int))
    w1, w2, w3, wf, wp = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.flow_loss_weight, args.prior_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv,
            flow_maps) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)
        flow_maps = [flow_map.to(device) for flow_map in flow_maps]

        # compute output
        disp = disp_net(tgt_img)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        # this variant of the reconstruction loss also returns the sampling
        # grid, consumed by the flow-consistency term below
        loss_1, warped, diff, grid = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()
        if wf > 0:
            # NOTE(review): loss_f (and loss_p) are not converted with
            # .item() like the other terms, so `loss` below becomes a tensor
            # when these weights are non-zero -- confirm AverageMeter copes.
            loss_f = flow_consistency_loss(grid, flow_maps, mse_l)
        else:
            loss_f = 0
        if wp > 0:
            loss_p = ground_prior_loss(disp)
        else:
            loss_p = 0
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        if log_outputs and i in batches_to_log:  # log first output of wanted batches
            index = batches_to_log.index(i)
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    tb_writer.add_image('val Input {}/{}'.format(j, index),
                                        tensor2array(tgt_img[0]), 0)
                    tb_writer.add_image('val Input {}/{}'.format(j, index),
                                        tensor2array(ref[0]), 1)

            log_output_tensorboard(tb_writer, 'val', index, '', epoch,
                                   1. / disp, disp, warped[0], diff[0],
                                   explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + wf * loss_f + wp * loss_p
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            tb_writer.add_histogram('{} {}'.format(prefix, coeffs_names[i]),
                                    poses[:, i], epoch)
        tb_writer.add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, [
        'Validation Total loss', 'Validation Photo loss', 'Validation Exp loss'
    ]
def validate_without_gt(args, val_loader, depth_net, pose_net, epoch, logger,
                        tb_writer, sample_nb_to_log, **env):
    """Self-supervised validation for the DepthNet/PoseNet pair.

    For each reference frame, rotation-compensates it towards the middle
    (target) frame, feeds the stacked pair to depth_net and accumulates
    photometric and diffusion-smoothness losses.

    Returns an OrderedDict of averaged losses.
    """
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.ssim
    if args.log_output:
        poses_values = np.zeros(
            ((len(val_loader) - 1) * args.test_batch_size *
             (args.sequence_length - 1), 6))
        disp_values = np.zeros(
            ((len(val_loader) - 1) * args.test_batch_size * 3))

    # switch to evaluate mode
    depth_net.eval()
    pose_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, sample in enumerate(val_loader):
        # only the first sample_nb_to_log batches get image logging
        log_output = i < sample_nb_to_log
        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        intrinsics = sample['intrinsics'].to(device)
        if epoch == 1 and log_output:
            for j, img in enumerate(sample['imgs']):
                tb_writer.add_image('val Input/{}'.format(i),
                                    tensor2array(img[0]), j)

        batch_size, seq = imgs.size()[:2]
        poses = pose_net(imgs)
        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]
        # middle frame of the sequence is the depth target
        mid_index = (args.sequence_length - 1) // 2
        tgt_imgs = imgs[:, mid_index]  # [B, 3, H, W]
        tgt_poses = pose_matrices[:, mid_index]  # [B, 3, 4]
        compensated_poses = compensate_pose(
            pose_matrices,
            tgt_poses)  # [B, seq, 3, 4] tgt_poses are now neutral pose

        ref_ids = list(range(args.sequence_length))
        ref_ids.remove(mid_index)
        loss_1 = 0
        loss_2 = 0
        for ref_index in ref_ids:
            prior_imgs = imgs[:, ref_index]
            prior_poses = compensated_poses[:, ref_index]  # [B, 3, 4]

            # cancel the rotation between prior and target frames so the
            # depth network only sees translation-induced parallax
            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_poses[:, :, :3],
                                                    intrinsics)
            input_pair = torch.cat([prior_imgs_compensated, tgt_imgs],
                                   dim=1)  # [B, 6, W, H]
            predicted_magnitude = prior_poses[:, :, -1:].norm(
                p=2, dim=1, keepdim=True).unsqueeze(1)  # [B, 1, 1, 1]
            # rescale translations so the prior-target displacement matches
            # the nominal displacement the depth net was trained for
            scale_factor = args.nominal_displacement / predicted_magnitude
            normalized_translation = compensated_poses[:, :, :,
                                                       -1:] * scale_factor  # [B, seq, 3, 1]
            new_pose_matrices = torch.cat(
                [compensated_poses[:, :, :, :-1], normalized_translation],
                dim=-1)
            depth = depth_net(input_pair)
            disparity = 1 / depth
            tgt_id = torch.full((batch_size, ),
                                ref_index,
                                dtype=torch.int64,
                                device=device)
            ref_ids_tensor = torch.tensor(ref_ids,
                                          dtype=torch.int64,
                                          device=device).expand(
                                              batch_size, -1)
            photo_loss, *to_log = photometric_reconstruction_loss(
                imgs,
                tgt_id,
                ref_ids_tensor,
                depth,
                new_pose_matrices,
                intrinsics,
                args.rotation_mode,
                ssim_weight=w3,
                upsample=args.upscale)
            loss_1 += photo_loss
            if log_output:
                # NOTE(review): tag says "train" although this is validation
                # logging -- confirm whether that prefix is intentional
                log_output_tensorboard(tb_writer, "train", i, ref_index,
                                       epoch, depth[0], disparity[0], *to_log)
            loss_2 += grad_diffusion_loss(disparity, tgt_imgs, args.kappa)

        if args.log_output and i < len(val_loader) - 1:
            step = args.test_batch_size * (args.sequence_length - 1)
            poses_values[i * step:(i + 1) * step] = poses[:, :-1].cpu().view(
                -1, 6).numpy()
            step = args.test_batch_size * 3
            # NOTE(review): `disparity` here is from the LAST ref_index
            # iteration only -- confirm that is the intended statistic
            disp_unraveled = disparity.cpu().view(args.test_batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2
        losses.update([loss.item(), loss_1.item(), loss_2.item()])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))

    if args.log_output:
        rot_coeffs = ['rx', 'ry', 'rz'] if args.rotation_mode == 'euler' else [
            'qx', 'qy', 'qz'
        ]
        tr_coeffs = ['tx', 'ty', 'tz']
        for k, (coeff_name) in enumerate(tr_coeffs + rot_coeffs):
            tb_writer.add_histogram('val poses_{}'.format(coeff_name),
                                    poses_values[:, k], epoch)
        tb_writer.add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return OrderedDict(
        zip(['Total loss', 'Photo loss', 'Smooth loss'], losses.avg))
def train_one_epoch(args, train_loader, depth_net, pose_net, optimizer, epoch,
                    n_iter, logger, tb_writer, **env):
    """Train DepthNet and PoseNet for one epoch with milestone-gated sampling.

    Before milestone e1 the prior frame is always the first of the sequence;
    after e1 it is sampled (possibly equal to the target, forcing max depth).
    Before milestone e2 the target is always the middle frame; after e2 it is
    random. Returns (epoch-averaged loss, updated n_iter).
    """
    global device
    logger.reset_train_bar()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.ssim
    e1, e2 = args.training_milestones

    # switch to train mode
    depth_net.train()
    pose_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, sample in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        intrinsics = sample['intrinsics'].to(device)
        batch_size, seq = imgs.size()[:2]

        if args.network_input_size is not None:
            # downscale the whole sequence before pose estimation
            h, w = args.network_input_size
            downsample_imgs = F.interpolate(imgs, (3, h, w), mode='area')
            poses = pose_net(downsample_imgs)  # [B, seq, 6]
        else:
            poses = pose_net(imgs)

        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]

        total_indices = torch.arange(seq, dtype=torch.int64,
                                     device=device).expand(batch_size, seq)
        batch_range = torch.arange(batch_size,
                                   dtype=torch.int64,
                                   device=device)
        ''' for each element of the batch select a random picture in the sequence to
        which we will compute the depth, all poses are then converted so that pose of
        this very picture is exactly identity. At first this image is always in the
        middle of the sequence'''
        if epoch > e2:
            tgt_id = torch.randint(0, seq, (batch_size, ), device=device)
        else:
            tgt_id = torch.full_like(batch_range, args.sequence_length // 2)
        # all sequence indices except the target, per batch element
        ref_ids = total_indices[total_indices != tgt_id.unsqueeze(1)].view(
            batch_size, seq - 1)
        ''' Select what other picture we are going to feed DepthNet, it must not be
        the same as tgt_id. At first, it's always first picture of the sequence, it
        is randomly chosen when first training milestone is reached '''
        if epoch > e1:
            probs = torch.ones_like(total_indices, dtype=torch.float32)
            # same_ratio controls how often prior_id may equal tgt_id
            probs[batch_range, tgt_id] = args.same_ratio
            prior_id = torch.multinomial(probs, 1)[:, 0]
        else:
            prior_id = torch.zeros_like(batch_range)

        # Treat the case of prior_id == tgt_id and the depth must be max_depth, regardless of apparent movement
        tgt_imgs = imgs[batch_range, tgt_id]  # [B, 3, H, W]
        tgt_poses = pose_matrices[batch_range, tgt_id]  # [B, 3, 4]
        prior_imgs = imgs[batch_range, prior_id]

        compensated_poses = compensate_pose(
            pose_matrices,
            tgt_poses)  # [B, seq, 3, 4] tgt_poses are now neutral pose
        prior_poses = compensated_poses[batch_range, prior_id]  # [B, 3, 4]

        if args.supervise_pose:
            # use ground-truth poses for the rotation compensation step
            from_GT = invert_mat(sample['pose']).to(device)
            compensated_GT_poses = compensate_pose(
                from_GT, from_GT[batch_range, tgt_id])
            prior_GT_poses = compensated_GT_poses[batch_range, prior_id]
            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_GT_poses[:, :, :-1],
                                                    intrinsics)
        else:
            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_poses[:, :, :-1],
                                                    intrinsics)

        input_pair = torch.cat([prior_imgs_compensated, tgt_imgs],
                               dim=1)  # [B, 6, W, H]
        depth = depth_net(input_pair)
        # depth = [sample['depth'].to(device).unsqueeze(1) * 3 / abs(tgt_id[0] - prior_id[0])]
        # depth.append(torch.nn.functional.interpolate(depth[0], scale_factor=2))
        disparities = [1 / d for d in depth]

        predicted_magnitude = prior_poses[:, :, -1:].norm(
            p=2, dim=1, keepdim=True).unsqueeze(1)
        # epsilon guards the division when the predicted translation is ~0
        scale_factor = args.nominal_displacement / (predicted_magnitude + 1e-5)
        normalized_translation = compensated_poses[:, :, :,
                                                   -1:] * scale_factor  # [B, seq_length-1, 3]
        new_pose_matrices = torch.cat(
            [compensated_poses[:, :, :, :-1], normalized_translation], dim=-1)

        biggest_scale = depth[0].size(-1)
        # Construct valid sequence to compute photometric error,
        # make the rest converge to max_depth because nothing moved
        vb = batch_range[prior_id != tgt_id]
        same_range = batch_range[prior_id == tgt_id]  # batch of still pairs
        loss_1 = 0
        loss_1_same = 0
        for k, scaled_depth in enumerate(depth):
            # weight each pyramid scale by its relative resolution
            size_ratio = scaled_depth.size(-1) / biggest_scale
            if len(same_range) > 0:
                # Frames are identical. The corresponding depth must be infinite. Here, we set it to max depth
                still_depth = scaled_depth[same_range]
                loss_same = F.smooth_l1_loss(still_depth / args.max_depth,
                                             torch.ones_like(still_depth))
            else:
                loss_same = 0
            loss_valid, *to_log = photometric_reconstruction_loss(
                imgs[vb],
                tgt_id[vb],
                ref_ids[vb],
                scaled_depth[vb],
                new_pose_matrices[vb],
                intrinsics[vb],
                args.rotation_mode,
                ssim_weight=w3,
                upsample=args.upscale)
            loss_1 += loss_valid * size_ratio
            loss_1_same += loss_same * size_ratio
            if log_output and len(vb) > 0:
                log_output_tensorboard(tb_writer, "train", 0, k, n_iter,
                                       scaled_depth[0], disparities[k][0],
                                       *to_log)

        loss_2 = grad_diffusion_loss(disparities, tgt_imgs, args.kappa)
        loss = w1 * (loss_1 + loss_1_same) + w2 * loss_2
        if args.supervise_pose:
            # L1 penalty on the rotation part of the predicted pose matrices
            loss += (from_GT[:, :, :, :3] -
                     pose_matrices[:, :, :, :3]).abs().mean()

        if log_losses:
            tb_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            tb_writer.add_scalar('disparity_smoothness_loss', loss_2.item(),
                                 n_iter)
            tb_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output and len(vb) > 0:
            valid_poses = poses[vb]
            nominal_translation_magnitude = valid_poses[:, -2, :3].norm(
                p=2, dim=-1)
            # Log the translation magnitude relative to translation magnitude between last and penultimate frames
            # for a perfectly constant displacement magnitude, you should get ratio of 2,3,4 and so forth.
            # last pose is always identity and penultimate translation magnitude is always 1, so you don't need to log them
            for j in range(args.sequence_length - 2):
                trans_mag = valid_poses[:, j, :3].norm(p=2, dim=-1)
                tb_writer.add_histogram(
                    'tr {}'.format(j),
                    (trans_mag /
                     nominal_translation_magnitude).detach().cpu().numpy(),
                    n_iter)
            for j in range(args.sequence_length - 1):
                # TODO log a better value : this is magnitude of vector (yaw, pitch, roll) which is not a physical value
                rot_mag = valid_poses[:, j, 3:].norm(p=2, dim=-1)
                tb_writer.add_histogram('rot {}'.format(j),
                                        rot_mag.detach().cpu().numpy(),
                                        n_iter)
            tb_writer.add_image('train Input', tensor2array(tgt_imgs[0]),
                                n_iter)

        # record loss for average meter
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item(), loss_1.item(), loss_2.item()])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= args.epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0], n_iter
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, train_writer):
    """Train disp_net and pose_exp_net jointly for one epoch.

    Loss = w1 * photometric reconstruction + w2 * explainability +
    w3 * depth smoothness. Logs to TensorBoard and appends per-batch losses
    to a tab-separated CSV file.

    Returns the epoch-averaged total loss.
    """
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    # train main cycle
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        # data (list): [tensor(B,3,H,W), list(B), (B,H,W), (b,h,w)]
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # 1. measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)  # (4,3,128,416)
        # previous and next frame of each image in the batch
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)  # (4,3,3)

        # 2. forward: compute outputs
        # multi-scale disparities, e.g. (4,1,128,416), (4,1,64,208),
        # (4,1,32,104), (4,1,16,52)
        disparities = disp_net(tgt_img)
        # pose: tensor (bs, seq_length-1, 6), relative camera pose
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)
        # depth = f*T/d is inversely proportional to disparity, so simply
        # take the reciprocal
        depth = [1 / disp for disp in disparities]

        # 3. loss computation
        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        # 4. batch-record data to TensorBoard (tags are created automatically
        # on first use, no need to declare them up front)
        if log_losses:
            train_writer.add_scalar('photometric_error', loss_1.item(),
                                    n_iter)
            if w2 > 0:
                # FIX: tag was misspelled 'explanabilityyyyyy_loss'; renamed
                # to match the tag used by the other train loops in this file
                # so the scalar lands in the same TensorBoard chart.
                train_writer.add_scalar('explanability_loss', loss_2.item(),
                                        n_iter)
            train_writer.add_scalar('disparity_smoothness_loss',
                                    loss_3.item(), n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            # dump images into the TensorBoard event files (names default to
            # an 'events' prefix)
            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped, diff,
                        explainability_mask)):
                log_output_tensorboard(train_writer, "train", k, n_iter,
                                       *scaled_maps)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # csv record
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
def validate_without_gt(args, val_loader, disp_net, pose_exp_net, epoch,
                        logger, tb_writer, sample_nb_to_log=3):
    """Self-supervised validation, with an optional pose-consistency
    photometric term.

    When args.with_photocon_loss is set, the two relative poses are chained
    (T13 = T23 @ T21^-1) to warp the second reference image onto the first
    and an extra L1 photometric loss is added.

    Returns (losses.avg, loss names).
    """
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = sample_nb_to_log > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # compute output
        disp = disp_net(tgt_img)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        if log_outputs and i < sample_nb_to_log - 1:  # log first output of first batches
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    tb_writer.add_image('val Input {}/{}'.format(j, i),
                                        tensor2array(tgt_img[0]), 0)
                    tb_writer.add_image('val Input {}/{}'.format(j, i),
                                        tensor2array(ref[0]), 1)

            log_output_tensorboard(tb_writer, 'val', i, '', epoch, 1. / disp,
                                   disp, warped[0], diff[0],
                                   explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if args.with_photocon_loss:
            batch_size = pose.size()[0]
            # bottom row [0,0,0,1] to promote [R|t] to a 4x4 homogeneous matrix
            homo_row = torch.tensor([[0, 0, 0, 1]],
                                    dtype=torch.float).to(device)
            homo_row = homo_row.unsqueeze(0).expand(batch_size, -1, -1)
            T21 = pose_vec2mat(pose[:, 0])
            T21 = torch.cat((T21, homo_row), 1)
            T12 = torch.inverse(T21)
            T23 = pose_vec2mat(pose[:, 1])
            T23 = torch.cat((T23, homo_row), 1)
            # chain the two relative poses: frame 1 -> frame 3
            T13 = torch.matmul(T23, T12)  #[B,4,4]
            # print("----",T13.size())
            # target = 1(ref_imgs[0]) and ref = 3(ref_imgs[1])
            ref_img_warped, valid_points = inverse_warp_posemat(
                ref_imgs[1], depth[:, 0], T13, intrinsics, args.rotation_mode,
                args.padding_mode)
            # mask out pixels that project outside the image
            diff = (ref_imgs[0] -
                    ref_img_warped) * valid_points.unsqueeze(1).float()
            loss_4 = diff.abs().mean()
            loss += loss_4

        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            tb_writer.add_histogram('{} {}'.format(prefix, coeffs_names[i]),
                                    poses[:, i], epoch)
        tb_writer.add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, [
        'Validation Total loss', 'Validation Photo loss', 'Validation Exp loss'
    ]
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, tb_writer):
    """Train disp_net and pose_exp_net for one epoch, with an optional
    pose-consistency photometric term.

    When args.with_photocon_loss is set, the two relative poses are chained
    (T13 = T23 @ T21^-1) to warp ref_imgs[1] onto ref_imgs[0] and an extra
    L1 photometric loss is added.

    Returns the epoch-averaged total loss.
    """
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        # multi-scale disparities; depth is their reciprocal
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        # print("***",len(depth),depth[0].size())
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if args.with_photocon_loss:
            batch_size = pose.size()[0]
            # bottom row [0,0,0,1] to promote [R|t] to a 4x4 homogeneous matrix
            homo_row = torch.tensor([[0, 0, 0, 1]],
                                    dtype=torch.float).to(device)
            homo_row = homo_row.unsqueeze(0).expand(batch_size, -1, -1)
            T21 = pose_vec2mat(pose[:, 0])
            T21 = torch.cat((T21, homo_row), 1)
            T12 = torch.inverse(T21)
            T23 = pose_vec2mat(pose[:, 1])
            T23 = torch.cat((T23, homo_row), 1)
            # chain the two relative poses: frame 1 -> frame 3
            T13 = torch.matmul(T23, T12)  #[B, 4, 4]
            # print("----",T13.size())
            # target = 1 and ref = 3; warp with the finest-scale depth
            ref_img_warped, valid_points = inverse_warp_posemat(
                ref_imgs[1], depth[0][:, 0], T13, intrinsics,
                args.rotation_mode, args.padding_mode)
            # mask out pixels that project outside the image
            diff = (ref_imgs[0] -
                    ref_img_warped) * valid_points.unsqueeze(1).float()
            loss_4 = diff.abs().mean()
            loss += loss_4

        if log_losses:
            tb_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                tb_writer.add_scalar('explanability_loss', loss_2.item(),
                                     n_iter)
            tb_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                 n_iter)
            tb_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            tb_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                n_iter)
            # one log entry per pyramid scale k
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped, diff,
                        explainability_mask)):
                log_output_tensorboard(tb_writer, "train", 0, " {}".format(k),
                                       n_iter, *scaled_maps)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]