from typing import Dict, Tuple

from torch import Tensor, cat, randn
from torch import min as tmin

# `transformation_from_parameters` (used below) builds a 4x4 camera transform
# from an axis-angle rotation and a translation; it is assumed to be importable
# from the project's Monodepth2-style layers module.


def _compute_photometric_losses(
        self,
        inputs: Tensor,
        disparities: Dict[int, Tensor],
        poses: Dict[int, Tuple[Tensor, Tensor]],
) -> Tensor:
    """
    Compute the total photometric loss: for every scale, take the per-pixel
    minimum of the identity and warped reprojection losses, add the disparity
    smoothness term, and average over scales.
    """
    identity_reprojection_loss = (
        self._compute_identity_reprojection_loss(inputs))
    total_loss: Tensor = 0
    for scale in self.hparams.scales:
        scale_disparity = disparities[scale]
        scale_reprojection_loss = self._compute_scale_reprojection_loss(
            inputs,
            poses,
            scale_disparity,
        )
        # Per-pixel minimum over identity and warped reprojection losses
        # (automatic masking of static pixels), reduced to a batch mean.
        total_loss += tmin(cat(
            (identity_reprojection_loss, scale_reprojection_loss),
            dim=1,
        ), dim=1)[0].mean()
        total_loss += self._compute_smooth_loss(
            inputs,
            scale_disparity,
            scale,
        )
    total_loss /= len(self.hparams.scales)
    return total_loss
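# NOTE: `_compute_smooth_loss` is called above but not shown in this excerpt.
# Below is a minimal sketch, assuming the standard edge-aware disparity
# smoothness term of Monodepth2-style training. `self.hparams.smoothness_weight`
# is a hypothetical hyper-parameter, and the sketch assumes
# `from torch.nn.functional import interpolate`.
def _compute_smooth_loss(
        self,
        inputs: Tensor,
        scale_disparity: Tensor,
        scale: int,
) -> Tensor:
    """Edge-aware smoothness of the mean-normalized disparity (sketch)."""
    target = inputs[:, self.batch_target_id]
    # Bring the target image down to the resolution of this scale's disparity.
    target = interpolate(
        target,
        size=scale_disparity.shape[-2:],
        mode="bilinear",
        align_corners=False,
    )
    disparity = scale_disparity / (
        scale_disparity.mean((2, 3), keepdim=True) + 1e-7)

    disparity_dx = (disparity[..., :, :-1] - disparity[..., :, 1:]).abs()
    disparity_dy = (disparity[..., :-1, :] - disparity[..., 1:, :]).abs()
    image_dx = (target[..., :, :-1] - target[..., :, 1:]).abs().mean(1, keepdim=True)
    image_dy = (target[..., :-1, :] - target[..., 1:, :]).abs().mean(1, keepdim=True)

    # Damp disparity gradients where the image itself has strong edges.
    smoothness = (
        (disparity_dx * (-image_dx).exp()).mean()
        + (disparity_dy * (-image_dy).exp()).mean()
    )
    # Lower-resolution scales contribute less, as in Monodepth2.
    return self.hparams.smoothness_weight * smoothness / (2 ** scale)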
def _compute_identity_reprojection_loss(self, inputs: Tensor) -> Tensor:
    """
    Compute identity reprojection losses between source & target images.
    """
    identity_reprojection_loss: Tensor = None
    for bsid in self.batch_sources_id:
        il = self._reprojection_loss(
            inputs[:, self.batch_target_id],
            inputs[:, bsid],
        )
        # Tiny noise breaks ties with the warped reprojection loss.
        il += randn(il.shape, device=il.device) * 1e-5
        if identity_reprojection_loss is None:
            identity_reprojection_loss = il
            continue
        # Keep the per-pixel minimum over all source frames.
        identity_reprojection_loss = tmin(
            cat((il, identity_reprojection_loss), dim=1),
            dim=1,
            keepdim=True,
        )[0]
    return identity_reprojection_loss
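# NOTE: `_reprojection_loss` (used above and in `_compute_scale_reprojection_loss`)
# is not part of this excerpt. A minimal sketch follows, assuming the usual
# SSIM + L1 photometric error of Monodepth2-style pipelines; `self.ssim` is a
# hypothetical per-pixel SSIM module returning a (B, C, H, W) dissimilarity map.
def _reprojection_loss(self, target: Tensor, predicted: Tensor) -> Tensor:
    """Per-pixel photometric error between the target and a (warped) source (sketch)."""
    l1_loss = (target - predicted).abs().mean(1, keepdim=True)
    ssim_loss = self.ssim(predicted, target).mean(1, keepdim=True)
    # The 0.85 / 0.15 blend is the weighting commonly used in Monodepth2.
    return 0.85 * ssim_loss + 0.15 * l1_loss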
def dice_loss_binary(outputs=None, target=None, beta=1, weights=None):
    """
    Soft Dice / F-beta loss for binary segmentation.

    :param outputs: predicted probabilities
    :param target: ground-truth binary mask, same shape as ``outputs``
    :param beta: larger beta favours precision; 1 is neutral
    :param weights: element-wise weights
    :return: scalar loss tensor
    """
    from torch import min as tmin

    smooth = 1.0
    if weights is not None:
        w = weights.contiguous().float().view(-1)
        # Shift the weights if any are zero, so no element is fully masked out.
        if tmin(w).item() == 0:
            w += smooth
    else:
        w = 1.0
    iflat = outputs.contiguous().float().view(-1)
    tflat = target.contiguous().float().view(-1)
    intersection = (iflat * tflat * w).sum()
    return 1. - (((1 + beta**2) * intersection) + smooth) / (
        (beta**2 * (w * iflat).sum()) + (w * tflat).sum() + smooth)
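# Illustrative usage of `dice_loss_binary` on sigmoid probabilities and a binary
# mask; shapes and variable names below are hypothetical.
if __name__ == "__main__":
    import torch

    logits = torch.randn(4, 1, 64, 64)
    probabilities = torch.sigmoid(logits)
    mask = (torch.rand(4, 1, 64, 64) > 0.5).float()
    # beta > 1 penalizes predicted mass (false positives) more, favouring precision.
    loss = dice_loss_binary(outputs=probabilities, target=mask, beta=2)
    print(loss.item())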
def _compute_scale_reprojection_loss(
        self,
        inputs: Tensor,
        poses: Dict[int, Tuple[Tensor, Tensor]],
        scale_disparity: Tensor,
) -> Tensor:
    """
    Compute the reprojection loss between the target image and each source
    image warped with the scaled disparity (upscaled to the original input
    size), keeping the per-pixel minimum over sources.
    """
    scale_reprojection_loss: Tensor = None
    for bsid in self.batch_sources_id:
        axisangle, translation = poses[bsid]
        # Invert the pose for source frames that precede the target frame.
        transformation = transformation_from_parameters(
            axisangle,
            translation,
            invert=bsid < self.batch_target_id,
        )
        warped = self._warp_image(
            inputs[:, bsid],
            scale_disparity,
            transformation,
        )
        reprojection_loss = self._reprojection_loss(
            inputs[:, self.batch_target_id],
            warped,
        )
        if scale_reprojection_loss is None:
            scale_reprojection_loss = reprojection_loss
            continue
        scale_reprojection_loss = tmin(
            cat((scale_reprojection_loss, reprojection_loss), dim=1),
            dim=1,
            keepdim=True,
        )[0]
    return scale_reprojection_loss
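# NOTE: `_warp_image` is called above but not defined in this excerpt. A minimal
# sketch follows, assuming the usual inverse-warping step of Monodepth2-style
# training: disparity -> depth, backprojection, rigid transformation,
# reprojection, bilinear sampling. `self.backproject`, `self.project_3d`,
# `self.intrinsics`, `self.inverse_intrinsics`, `self.min_depth` and
# `self.max_depth` are all hypothetical, and the sketch assumes
# `from torch.nn.functional import grid_sample, interpolate`.
def _warp_image(
        self,
        source: Tensor,
        scale_disparity: Tensor,
        transformation: Tensor,
) -> Tensor:
    """Inverse-warp a source image into the target view (sketch)."""
    # Upsample the disparity to the full input resolution before warping.
    disparity = interpolate(
        scale_disparity,
        size=source.shape[-2:],
        mode="bilinear",
        align_corners=False,
    )
    # Convert sigmoid disparity to depth (Monodepth2 convention).
    min_disp, max_disp = 1 / self.max_depth, 1 / self.min_depth
    depth = 1 / (min_disp + (max_disp - min_disp) * disparity)

    # Backproject target pixels to 3D, apply the relative pose, and project the
    # points into the source camera to obtain sampling coordinates.
    points = self.backproject(depth, self.inverse_intrinsics)
    grid = self.project_3d(points, self.intrinsics, transformation)

    return grid_sample(
        source,
        grid,
        mode="bilinear",
        padding_mode="border",
        align_corners=False,
    )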