예제 #1
0
    def forward(self, x: NestedTensor):
        """
        Run the network end to end: backbone, tokenizer, optional
        sub-sampling of the feature maps, transformer, regression head.

        :param x: input data; this method itself reads x.left,
            x.sampled_cols and x.sampled_rows and hands x to the sub-modules
        :return:
            the dictionary returned by self.regression_head, with keys
            - "disp_pred" [N,H,W]: predicted disparity
            - "occ_pred" [N,H,W]: predicted occlusion mask
            - "disp_pred_low_res" [N,H//s,W//s]: predicted low res (raw) disparity
        """
        # Only the batch size is needed; the 4-way unpack is kept so that a
        # left image that is not 4-D still fails right here.
        batch_size, _, _, _ = x.left.size()

        # Feature extraction. Left and right images are stacked along dim 0.
        tokens = self.tokenizer(self.backbone(x))  # 2NxCxHxW
        pos_enc = self.pos_encoder(x)  # NxCxHx2W-1

        # First N entries belong to the left image, the rest to the right.
        feat_left, feat_right = tokens[:batch_size], tokens[batch_size:]

        # Optional sub-sampling: columns (dim 3) first, then rows (dim 2).
        for dim, index in ((3, x.sampled_cols), (2, x.sampled_rows)):
            if index is None:
                continue
            feat_left = batched_index_select(feat_left, dim, index)
            feat_right = batched_index_select(feat_right, dim, index)

        # Attention between the two views, then disparity / occlusion
        # regression from the attention weights.
        attn_weight = self.transformer(feat_left, feat_right, pos_enc)
        return self.regression_head(attn_weight, x)
예제 #2
0
    def _compute_gt_location(self, scale: int, sampled_cols: Tensor,
                             sampled_rows: Tensor, attn_weight: Tensor,
                             disp: Tensor):
        """
        Locate, for every left pixel, the matching position given by the
        ground truth disparity, and sample the attention weight there.

        :param scale: high-res to low-res disparity scale
        :param sampled_cols: index to downsample columns, or None
        :param sampled_rows: index to downsample rows, or None
        :param attn_weight: attention weight (output from _optimal_transport), [N,H,W,W]
        :param disp: ground truth disparity, 3-D (unpacked as N,H,W below)
        :return: tuple (response at ground truth location, target ground
            truth locations); the target carries a trailing singleton dim
        """
        # Column coordinate of every left pixel at full resolution.
        _, _, width = disp.size()
        columns = torch.linspace(0, width - 1, width).unsqueeze(0)
        columns = columns.to(disp.device)  # 1 x W, broadcasts over N x H x W

        # Matching location = left column minus disparity.
        target = (columns - disp).unsqueeze(-1)  # N x H x W (left) x 1

        # Apply the same sub-sampling as the features: columns, then rows.
        if sampled_cols is not None:
            target = batched_index_select(target, 2, sampled_cols)
        if sampled_rows is not None:
            target = batched_index_select(target, 1, sampled_rows)

        # Express the location in low-res coordinates.
        target = target / scale

        # Ground truth response for the rr loss, sampled with 'linear' mode.
        gt_response = torch_1d_sample(attn_weight, target, 'linear')
        return gt_response, target
예제 #3
0
    def compute_l1_loss(self,
                        pred: Tensor,
                        inputs: NestedTensor,
                        invalid_mask: Tensor,
                        fullres: bool = True):
        """
        compute smooth l1 loss

        :param pred: disparity prediction [N,H,W]
        :param inputs: input data (reads inputs.disp, and when fullres is
            False also inputs.sampled_cols / inputs.sampled_rows)
        :param invalid_mask: invalid disparities (including occ and places
            without data), [N,H,W]; may be None, in which case every pixel
            is treated as valid
        :param fullres: Boolean indicating if prediction is full resolution
        :return: smooth l1 loss, as computed by self.l1_criterion
        """
        disp = inputs.disp

        # A low-res prediction must be compared against ground truth (and a
        # mask) sub-sampled the same way: columns (dim 2), then rows (dim 1).
        if not fullres:
            if inputs.sampled_cols is not None:
                if invalid_mask is not None:
                    invalid_mask = batched_index_select(
                        invalid_mask, 2, inputs.sampled_cols)
                disp = batched_index_select(disp, 2, inputs.sampled_cols)
            if inputs.sampled_rows is not None:
                if invalid_mask is not None:
                    invalid_mask = batched_index_select(
                        invalid_mask, 1, inputs.sampled_rows)
                disp = batched_index_select(disp, 1, inputs.sampled_rows)

        # No mask: use every pixel. Previously this path crashed on `~None`
        # even though the sub-sampling above explicitly allows a None mask.
        # Flatten so the criterion receives 1-D inputs like the masked path.
        if invalid_mask is None:
            return self.l1_criterion(pred.reshape(-1), disp.reshape(-1))

        # Invert once and reuse for both tensors.
        valid = ~invalid_mask
        return self.l1_criterion(pred[valid], disp[valid])
예제 #4
0
    def compute_rr_loss(self, outputs: dict, inputs: NestedTensor,
                        invalid_mask: Tensor):
        """
        compute rr loss

        The loss is the mean of -log(response + eps) over the valid
        ground-truth responses, plus the occluded-left / occluded-right
        responses when they are available.

        :param outputs: dictionary, outputs from the network; reads
            'gt_response' (required) and 'gt_response_occ_left' /
            'gt_response_occ_right' (optional: each may be missing or None)
        :param inputs: input data (reads inputs.sampled_cols /
            inputs.sampled_rows to sub-sample the mask)
        :param invalid_mask: invalid disparities (including occ and places
            without data), [N,H,W]; may be None
        :return: rr loss
        """
        # Bring the mask to the resolution of the responses:
        # columns (dim 2) first, then rows (dim 1).
        if invalid_mask is not None:
            if inputs.sampled_cols is not None:
                invalid_mask = batched_index_select(invalid_mask, 2,
                                                    inputs.sampled_cols)
            if inputs.sampled_rows is not None:
                invalid_mask = batched_index_select(invalid_mask, 1,
                                                    inputs.sampled_rows)

        # eps keeps log() finite when a response is exactly zero
        eps = 1e-6

        # compute rr loss in non-occluded region
        rr_loss = -torch.log(outputs['gt_response'] + eps)
        if invalid_mask is not None:
            rr_loss = rr_loss[~invalid_mask]

        # Occlusion terms are optional. The key may be absent or stored as
        # None; both mean "no occlusion data", so skip the term. Catching
        # only KeyError let the None case crash on `None + eps`.
        for key in ('gt_response_occ_left', 'gt_response_occ_right'):
            occ_response = outputs.get(key)
            if occ_response is None:
                continue
            rr_loss = torch.cat([rr_loss, -torch.log(occ_response + eps)])

        return rr_loss.mean()
예제 #5
0
    def forward(self, attn_weight: Tensor, x: NestedTensor):
        """
        Regression head. In order, it
            - computes the disparity scale (if columns were sub-sampled)
            - normalizes the attention (optimal transport or softmax)
            - gathers the response at ground truth and occluded locations
            - regresses low res disparity and occlusion
            - upsamples (if columns were sub-sampled)

        :param attn_weight: raw attention weight, [N,H,W,W]
        :param x: input data
        :return: dictionary with keys 'gt_response',
            'gt_response_occ_left', 'gt_response_occ_right' (each None when
            the matching ground truth is absent), 'disp_pred', 'occ_pred',
            and, only when x.sampled_cols is set, 'disp_pred_low_res'
        """
        # None of the sizes are used below; the 4-way unpack is kept so a
        # left image that is not 4-D still fails here.
        _, _, _, _ = x.left.size()
        output = {}

        # Ratio between full-res width and the number of sampled columns.
        scale = (1.0 if x.sampled_cols is None else
                 x.left.size(-1) / float(x.sampled_cols.size(-1)))

        # normalize attention to 0-1
        attn_ot = (self._optimal_transport(attn_weight, 10)
                   if self.ot else self._softmax(attn_weight))
        # The last row / column is sliced off here and indexed separately
        # for the occluded responses below.
        attn_inner = attn_ot[..., :-1, :-1]

        # Response at the ground truth location (only with gt disparity).
        # The returned target location is not used further in this method.
        if x.disp is None:
            output['gt_response'] = None
        else:
            output['gt_response'], _target = self._compute_gt_location(
                scale, x.sampled_cols, x.sampled_rows, attn_inner, x.disp)

        # Response at occluded locations (only with gt occlusion masks).
        occ_mask = x.occ_mask
        if occ_mask is None:
            output['gt_response_occ_left'] = None
            output['gt_response_occ_right'] = None
        else:
            occ_mask_right = x.occ_mask_right
            # Sub-sample both masks: columns (dim 2), then rows (dim 1).
            if x.sampled_cols is not None:
                occ_mask = batched_index_select(occ_mask, 2, x.sampled_cols)
                occ_mask_right = batched_index_select(
                    occ_mask_right, 2, x.sampled_cols)
            if x.sampled_rows is not None:
                occ_mask = batched_index_select(occ_mask, 1, x.sampled_rows)
                occ_mask_right = batched_index_select(
                    occ_mask_right, 1, x.sampled_rows)

            # Last column holds the left-occluded responses, last row the
            # right-occluded ones; keep only the masked entries.
            output['gt_response_occ_left'] = attn_ot[..., :-1, -1][occ_mask]
            output['gt_response_occ_right'] = (
                attn_ot[..., -1, :-1][occ_mask_right])

        # regress low res disparity
        pos_shift = self._compute_unscaled_pos_shift(attn_weight.shape[2],
                                                     attn_weight.device)
        disp_pred_low_res, matched_attn = self._compute_low_res_disp(
            pos_shift, attn_inner, occ_mask)
        # regress low res occlusion
        occ_pred_low_res = self._compute_low_res_occ(matched_attn)

        # Without column sub-sampling the low res prediction is the output.
        if x.sampled_cols is None:
            output['disp_pred'] = disp_pred_low_res
            output['occ_pred'] = occ_pred_low_res
        else:
            disp_full, disp_low, occ_full = self._upsample(
                x, disp_pred_low_res, occ_pred_low_res, scale)
            output['disp_pred'] = disp_full
            output['disp_pred_low_res'] = disp_low
            output['occ_pred'] = occ_full

        return output