def test_fb_consistency_with_occlusion(self):
  """Inconsistent forward/backward flows should be marked fully occluded.

  The backward flow is deliberately set to -0.5 times the forward flow instead
  of its exact negation, at both pyramid levels. With 'brox' occlusion
  estimation, every pixel of the cropped interior of both not-occluded masks is
  then expected to be zero (i.e. occluded).
  """
  batch_size, height, width = 4, 64, 64
  # Pixels this close to the image edge are excluded from the check
  # (presumably because warps there leave the image - TODO confirm).
  border = 4

  # Constant forward flow with both flow channels equal to 4.
  forward_np = np.ones((batch_size, height, width, 2)) * 4.
  # "Imperfect" backward flow: opposite sign, half the magnitude.
  backward_np = -forward_np * .5

  forward = tf.convert_to_tensor(value=forward_np.astype(np.float32))
  # Second pyramid level: half resolution, so flow magnitudes are halved too.
  forward_level1 = tf.image.resize(forward, (height // 2, width // 2)) / 2.
  backward = tf.convert_to_tensor(value=backward_np.astype(np.float32))
  backward_level1 = -forward_level1 * .5

  flows = {
      (0, 1, 0): [forward, forward_level1],
      (1, 0, 0): [backward, backward_level1],
  }
  _, _, _, not_occluded_masks, _, _ = uflow_utils.compute_warps_and_occlusion(
      flows, occlusion_estimation='brox')

  # Everything inside the border must be occluded (mask value 0).
  expected = np.zeros(
      (batch_size, height - 2 * border, width - 2 * border, 1))
  for direction in ((0, 1, 0), (1, 0, 0)):
    interior = not_occluded_masks[direction][0][
        :, border:-border, border:-border, :]
    self.assertTrue(np.equal(expected, interior).all())
def infer_occlusion(self, flow_forward, flow_backward):
  """Gets a 'soft' occlusion mask from the forward and backward flow.

  Args:
    flow_forward: Flow stored under the key (0, 1, 'inference').
    flow_backward: Flow stored under the key (1, 0, 'inference').

  Returns:
    The first (level-0) mask for the (0, 1, 'inference') direction, as
    produced by uflow_utils.compute_warps_and_occlusion with
    occlusions_are_zeros=False (presumably 1 marks occluded pixels - TODO
    confirm against uflow_utils).
  """
  forward_key = (0, 1, 'inference')
  backward_key = (1, 0, 'inference')
  flows = {forward_key: [flow_forward], backward_key: [flow_backward]}

  # Only the fourth output (the masks) is needed here.
  _, _, _, occlusion_masks, _, _ = uflow_utils.compute_warps_and_occlusion(
      flows,
      occlusion_estimation=self._occlusion_estimation,
      occ_weights=self._occ_weights,
      occ_thresholds=self._occ_thresholds,
      occ_clip_max=self._occ_clip_max,
      occlusions_are_zeros=False)
  return occlusion_masks[forward_key][0]
def compute_loss(self,
                 batch,
                 weights,
                 plot_dir=None,
                 distance_metrics=None,
                 ground_truth_flow=None,
                 ground_truth_valid=None,
                 ground_truth_occlusions=None,
                 images_without_photo_aug=None,
                 occ_active=None):
  """Applies the model and computes losses for a batch of image sequences.

  When the model is configured with `self._train_with_supervision`, only the
  supervised loss against `ground_truth_flow` is computed. Otherwise:
  1. Flows are computed for `batch`.
  2. Warps and occlusion masks are derived from those flows.
  3. The (preferably unaugmented) images are warped.
  4. The unsupervised losses are evaluated with `uflow_utils.compute_loss`.

  Args:
    batch: Image sequences. Axis 1 is indexed as the sequence axis.
    weights: Loss weights forwarded to the loss functions.
    plot_dir: Optional; forwarded to uflow_utils.compute_loss.
    distance_metrics: Optional; forwarded to uflow_utils.compute_loss.
    ground_truth_flow: Required when training with supervision.
    ground_truth_valid: Forwarded to uflow_utils.supervised_loss.
    ground_truth_occlusions: Forwarded to uflow_utils.compute_loss.
    images_without_photo_aug: Optional version of `batch` without photometric
      augmentation. It is used for the unsupervised losses and defaults to
      `batch` when not given.
    occ_active: Forwarded to uflow_utils.compute_warps_and_occlusion.

  Returns:
    A dict mapping '<loss name>-loss' to the corresponding loss value.

  Raises:
    ValueError: If training with supervision and `ground_truth_flow` is None.
  """
  # Compute only a supervised loss.
  if self._train_with_supervision:
    if ground_truth_flow is None:
      raise ValueError(
          'Need ground truth flow to compute supervised loss.')
    flows = uflow_utils.compute_flow_for_supervised_loss(
        self._feature_model, self._flow_model, batch=batch, training=True)
    losses = uflow_utils.supervised_loss(weights, ground_truth_flow,
                                         ground_truth_valid, flows)
    return {key + '-loss': losses[key] for key in losses}

  # Use possibly augmented images if non augmented version is not provided.
  if images_without_photo_aug is None:
    images_without_photo_aug = batch

  flows, selfsup_transform_fns = uflow_utils.compute_features_and_flow(
      self._feature_model,
      self._flow_model,
      batch=batch,
      batch_without_aug=images_without_photo_aug,
      training=True,
      build_selfsup_transformations=self._build_selfsup_transformations,
      teacher_feature_model=self._teacher_feature_model,
      teacher_flow_model=self._teacher_flow_model,
      teacher_image_version=self._teacher_image_version,
  )

  # Prepare images for unsupervised loss (prefer unaugmented images), keyed
  # by their index along the sequence axis.
  seq_len = int(batch.shape[1])
  images = {i: images_without_photo_aug[:, i] for i in range(seq_len)}

  # Warp stuff and compute occlusion. The third output is not needed here.
  (warps, valid_warp_masks, _, not_occluded_masks, fb_sq_diff,
   fb_sum_sq) = uflow_utils.compute_warps_and_occlusion(
       flows,
       occlusion_estimation=self._occlusion_estimation,
       occ_weights=self._occ_weights,
       occ_thresholds=self._occ_thresholds,
       occ_clip_max=self._occ_clip_max,
       occlusions_are_zeros=True,
       occ_active=occ_active)

  # Warp the images only (features are not warped here).
  warped_images = uflow_utils.apply_warps_stop_grad(images, warps, level=0)

  # Compute losses.
  losses = uflow_utils.compute_loss(
      weights=weights,
      images=images,
      flows=flows,
      warps=warps,
      valid_warp_masks=valid_warp_masks,
      not_occluded_masks=not_occluded_masks,
      fb_sq_diff=fb_sq_diff,
      fb_sum_sq=fb_sum_sq,
      warped_images=warped_images,
      only_forward=self._only_forward,
      selfsup_transform_fns=selfsup_transform_fns,
      fb_sigma_teacher=self._fb_sigma_teacher,
      fb_sigma_student=self._fb_sigma_student,
      plot_dir=plot_dir,
      distance_metrics=distance_metrics,
      smoothness_edge_weighting=self._smoothness_edge_weighting,
      stop_gradient_mask=self._stop_gradient_mask,
      selfsup_mask=self._selfsup_mask,
      ground_truth_occlusions=ground_truth_occlusions,
      smoothness_at_level=self._smoothness_at_level)
  return {key + '-loss': losses[key] for key in losses}