예제 #1
0
 def _compute_photometric_losses(
     self,
     inputs: Tensor,
     disparities: Dict[int, Tensor],
     poses: Dict[int, Tuple[Tensor, Tensor]],
 ) -> Tensor:
     """
     Return the photometric + smoothness loss averaged over all scales.

     The identity reprojection loss depends only on ``inputs``, so it is
     computed once up front. Then, for each scale in ``self.hparams.scales``:
       * the scale's reprojection loss (from
         ``_compute_scale_reprojection_loss``) is stacked with the identity
         loss along dim 1, the minimum over dim 1 is taken, and its mean is
         added to the running total;
       * ``self._compute_smooth_loss(inputs, disparity, scale)`` is added.
     The running total is finally divided by the number of scales
     (an empty ``scales`` therefore raises ``ZeroDivisionError``).
     """
     identity_loss = self._compute_identity_reprojection_loss(inputs)
     running_total: Tensor = 0
     for scale in self.hparams.scales:
         disparity = disparities[scale]
         warped_loss = self._compute_scale_reprojection_loss(
             inputs, poses, disparity)
         # Element-wise, keep whichever of the identity / warped losses is
         # smaller (torch.min with dim returns a (values, indices) pair).
         best_of_both, _ = tmin(cat((identity_loss, warped_loss), dim=1), dim=1)
         running_total += best_of_both.mean()
         running_total += self._compute_smooth_loss(inputs, disparity, scale)
     running_total /= len(self.hparams.scales)
     return running_total
예제 #2
0
 def _compute_identity_reprojection_loss(self, inputs: Tensor) -> Tensor:
     """
     Compute identity reprojection losses between source & target images.

     "Identity" here means the source frames are compared to the target
     frame without any warping. For every source id in
     ``self.batch_sources_id``, ``self._reprojection_loss`` is computed
     between the target frame and the raw source frame. Standard-normal
     noise scaled by 1e-5 is then added. The running element-wise minimum
     over sources (reduced along dim 1 with ``keepdim=True``) is returned.
     With a single source, that source's noisy loss is returned unchanged.

     Raises:
         ValueError: if ``self.batch_sources_id`` is empty (previously the
             method silently returned ``None`` despite its annotation).
     """
     # The target slice is the same for every source; take it once.
     target = inputs[:, self.batch_target_id]
     identity_reprojection_loss: Tensor = None
     for bsid in self.batch_sources_id:
         il = self._reprojection_loss(target, inputs[:, bsid])
         # Tiny noise, presumably to break exact ties when the caller takes
         # a minimum against the warped loss. TODO confirm intent.
         # Kept in-place so the loss keeps its own dtype.
         il += randn(il.shape, device=il.device) * 1e-5
         if identity_reprojection_loss is None:
             identity_reprojection_loss = il
         else:
             identity_reprojection_loss = tmin(
                 cat((il, identity_reprojection_loss), dim=1),
                 dim=1,
                 keepdim=True,
             )[0]
     if identity_reprojection_loss is None:
         raise ValueError(
             "batch_sources_id is empty; at least one source frame is "
             "required to compute the identity reprojection loss")
     return identity_reprojection_loss
예제 #3
0
def dice_loss_binary(outputs=None, target=None, beta=1, weights=None):
    """
    Weighted soft F-beta loss (plain Dice for beta=1) for binary targets.

    Both tensors are flattened to 1-D. With ``smooth = 1.0`` and ``w`` the
    (possibly adjusted) weights, the loss is::

        1 - ((1 + beta**2) * sum(w*p*t) + smooth)
            / (beta**2 * sum(w*p) + sum(w*t) + smooth)

    :param outputs: predicted tensor (presumably probabilities in [0, 1];
        not checked). It is flattened to 1-D.
    :param target: target tensor, expected to have as many elements as
        ``outputs``. It is flattened to 1-D.
    :param beta: More beta, better precision. ``beta**2`` scales the
        predicted-mass term ``sum(w*p)`` in the denominator, so false
        positives are penalized more as beta grows. 1 is neutral.
    :param weights: optional element-wise weights, expected to have as many
        elements as ``outputs``. If any weight is exactly 0, ``smooth`` is
        added to every weight. The caller's tensor is never modified.
    :return: 0-dim tensor holding the loss.
    """
    smooth = 1.0
    if weights is not None:
        w = weights.contiguous().float().view(-1)
        # Out-of-place add: when ``weights`` is already a contiguous float32
        # tensor, ``w`` is a view of it, and ``w += smooth`` would silently
        # modify the caller's tensor.
        if w.numel() > 0 and w.min().item() == 0:
            w = w + smooth
    else:
        w = 1.0

    iflat = outputs.contiguous().float().view(-1)
    tflat = target.contiguous().float().view(-1)
    intersection = (iflat * tflat * w).sum()

    numerator = (1 + beta**2) * intersection + smooth
    denominator = beta**2 * (w * iflat).sum() + (w * tflat).sum() + smooth
    return 1. - numerator / denominator
예제 #4
0
 def _compute_scale_reprojection_loss(
     self,
     inputs: Tensor,
     poses: Dict[int, Tuple[Tensor, Tensor]],
     scale_disparity: Tensor,
 ) -> Tensor:
     """
     Compute the reprojection loss between the target image and the source
     images warped with one scale's disparity.

     For every source id in ``self.batch_sources_id``:
       1. a transformation is built from that source's
          ``(axisangle, translation)`` pose via
          ``transformation_from_parameters``. It is inverted when the source
          id is smaller than the target id.
       2. the source frame is warped by ``self._warp_image`` using
          ``scale_disparity`` and that transformation.
       3. ``self._reprojection_loss`` is computed between the target frame
          and the warped source.
     The running element-wise minimum over sources is returned. It is
     reduced along dim 1 with ``keepdim=True``. With a single source, that
     source's loss is returned unchanged.

     Any resizing of ``scale_disparity`` to the input resolution presumably
     happens inside ``_warp_image``. That is not visible here; TODO confirm.

     Raises:
         ValueError: if ``self.batch_sources_id`` is empty (previously the
             method silently returned ``None`` despite its annotation).
     """
     # The target slice is the same for every source; take it once.
     target = inputs[:, self.batch_target_id]
     scale_reprojection_loss: Tensor = None
     for bsid in self.batch_sources_id:
         axisangle, translation = poses[bsid]
         # Sources with an id below the target id (presumably earlier
         # frames) get the inverted transformation.
         transformation = transformation_from_parameters(
             axisangle,
             translation,
             invert=bsid < self.batch_target_id,
         )
         warped = self._warp_image(
             inputs[:, bsid],
             scale_disparity,
             transformation,
         )
         reprojection_loss = self._reprojection_loss(target, warped)
         if scale_reprojection_loss is None:
             scale_reprojection_loss = reprojection_loss
         else:
             scale_reprojection_loss = tmin(
                 cat((scale_reprojection_loss, reprojection_loss), dim=1),
                 dim=1,
                 keepdim=True,
             )[0]
     if scale_reprojection_loss is None:
         raise ValueError(
             "batch_sources_id is empty; at least one source frame is "
             "required to compute the scale reprojection loss")
     return scale_reprojection_loss