def forward(self, x: NestedTensor):
    """
    Run the full pipeline: backbone -> tokenizer -> (optional) downsampling
    -> transformer -> regression head.

    :param x: input data
    :return: a dictionary object with keys
        - "disp_pred" [N,H,W]: predicted disparity
        - "occ_pred" [N,H,W]: predicted occlusion mask
        - "disp_pred_low_res" [N,H//s,W//s]: predicted low res (raw) disparity
    """
    # the unpacking doubles as a check that the left image is 4-dimensional
    batch_size, _, _, _ = x.left.size()

    # feature extraction; left and right are stacked along dim=0
    tokens = self.tokenizer(self.backbone(x))  # 2NxCxHxW
    pos_enc = self.pos_encoder(x)  # NxCxHx2W-1

    # first `batch_size` entries are the left view, the remainder the right view
    feat_left, feat_right = tokens[:batch_size], tokens[batch_size:]  # NxCxHxW

    # optional downsampling: columns (dim 3) first, then rows (dim 2)
    for dim, index in ((3, x.sampled_cols), (2, x.sampled_rows)):
        if index is None:
            continue
        feat_left = batched_index_select(feat_left, dim, index)
        feat_right = batched_index_select(feat_right, dim, index)

    # transformer produces the attention, the head turns it into predictions
    attn_weight = self.transformer(feat_left, feat_right, pos_enc)
    return self.regression_head(attn_weight, x)
def _compute_gt_location(self, scale: int, sampled_cols: Tensor, sampled_rows: Tensor,
                         attn_weight: Tensor, disp: Tensor):
    """
    Derive the target (matching) locations from the ground truth disparity and
    sample the attention weight at those locations.

    :param scale: high-res to low-res disparity scale
    :param sampled_cols: index to downsample columns (None to skip)
    :param sampled_rows: index to downsample rows (None to skip)
    :param attn_weight: attention weight (output from _optimal_transport), [N,H,W,W]
    :param disp: ground truth disparity
    :return: response at ground truth location [N,H,W,1] and target ground truth locations [N,H,W,1]
    """
    # full-resolution column coordinate of every pixel in the left image
    _, _, width = disp.size()
    left_cols = torch.linspace(0, width - 1, width).unsqueeze(0).to(disp.device)  # 1 x W (left)

    # target column = left column - disparity
    target = (left_cols - disp).unsqueeze(-1)  # N x H x W (left) x 1

    # apply the same sub-sampling as the features: columns (dim 2), then rows (dim 1)
    for dim, index in ((2, sampled_cols), (1, sampled_rows)):
        if index is not None:
            target = batched_index_select(target, dim, index)

    # express the target location in low-res coordinates
    target = target / scale

    # attention response at the target location, used by the rr loss
    gt_response = torch_1d_sample(attn_weight, target, 'linear')  # NxHxW_left
    return gt_response, target
def compute_l1_loss(self, pred: Tensor, inputs: NestedTensor, invalid_mask: Tensor, fullres: bool = True):
    """
    compute smooth l1 loss (via self.l1_criterion) between prediction and ground truth disparity

    :param pred: disparity prediction [N,H,W]
    :param inputs: input data
    :param invalid_mask: invalid disparities (including occ and places without data), [N,H,W];
        may be None, in which case every pixel contributes to the loss
    :param fullres: Boolean indicating if prediction is full resolution
    :return: smooth l1 loss
    """
    disp = inputs.disp

    if not fullres:
        # low-res prediction: bring ground truth (and the mask, if given) onto the sampled grid
        if inputs.sampled_cols is not None:
            if invalid_mask is not None:
                invalid_mask = batched_index_select(invalid_mask, 2, inputs.sampled_cols)
            disp = batched_index_select(disp, 2, inputs.sampled_cols)
        if inputs.sampled_rows is not None:
            if invalid_mask is not None:
                invalid_mask = batched_index_select(invalid_mask, 1, inputs.sampled_rows)
            disp = batched_index_select(disp, 1, inputs.sampled_rows)

    if invalid_mask is None:
        # no mask supplied: `~None` would raise, so use every pixel instead.
        # Flatten so the criterion sees the same 1-D layout boolean masking produces.
        return self.l1_criterion(pred.reshape(-1), disp.reshape(-1))

    # invert once and reuse for both tensors
    valid = ~invalid_mask
    return self.l1_criterion(pred[valid], disp[valid])
def compute_rr_loss(self, outputs: dict, inputs: NestedTensor, invalid_mask: Tensor):
    """
    compute rr loss

    :param outputs: dictionary, outputs from the network; must contain 'gt_response'.
        'gt_response_occ_left' / 'gt_response_occ_right' are optional: a missing key
        or a None value means there is no occlusion term for that side
    :param inputs: input data
    :param invalid_mask: invalid disparities (including occ and places without data), [N,H,W];
        may be None, in which case every location contributes
    :return: rr loss
    """
    if invalid_mask is not None:
        # bring the mask onto the same sampled grid as the responses
        if inputs.sampled_cols is not None:
            invalid_mask = batched_index_select(invalid_mask, 2, inputs.sampled_cols)
        if inputs.sampled_rows is not None:
            invalid_mask = batched_index_select(invalid_mask, 1, inputs.sampled_rows)

    eps = 1e-6  # keeps log() finite when a response is exactly zero

    # rr loss in the non-occluded region
    rr_loss = -torch.log(outputs['gt_response'] + eps)
    if invalid_mask is not None:
        rr_loss = rr_loss[~invalid_mask]

    # every term is flattened so that terms of different rank can be concatenated
    terms = [rr_loss.reshape(-1)]

    # occlusion terms: the keys may be absent or explicitly set to None
    for key in ('gt_response_occ_left', 'gt_response_occ_right'):
        response = outputs.get(key)
        if response is not None:
            terms.append((-torch.log(response + eps)).reshape(-1))

    return torch.cat(terms).mean()
def forward(self, attn_weight: Tensor, x: NestedTensor):
    """
    Regression head. Steps:
    - compute scale for disparity (if there is downsampling)
    - normalize attention (optimal transport if self.ot, otherwise softmax)
    - collect the responses at ground truth / occluded locations (inputs of the RR loss)
    - regress low res disparity and occlusion
    - upsample (if there is downsampling) and adjust based on context

    :param attn_weight: raw attention weight, [N,H,W,W]
    :param x: input data
    :return: dictionary of predicted values
    """
    # nothing is kept from the sizes; the unpacking only checks x.left is 4-dimensional
    _, _, _, _ = x.left.size()
    output = {}

    # ratio between full-res width and the number of sampled columns
    downsampled = x.sampled_cols is not None
    scale = x.left.size(-1) / float(x.sampled_cols.size(-1)) if downsampled else 1.0

    # normalize attention to 0-1
    attn_ot = self._optimal_transport(attn_weight, 10) if self.ot else self._softmax(attn_weight)
    # attention with the last row and last column dropped
    attn_inner = attn_ot[..., :-1, :-1]

    # response at the ground truth location (the returned target location is not needed here)
    if x.disp is None:
        output['gt_response'] = None
    else:
        output['gt_response'], _ = self._compute_gt_location(
            scale, x.sampled_cols, x.sampled_rows, attn_inner, x.disp)

    # response at occluded locations, read from the last column / last row of the attention
    occ_mask = x.occ_mask
    if occ_mask is None:
        output['gt_response_occ_left'] = None
        output['gt_response_occ_right'] = None
    else:
        occ_mask_right = x.occ_mask_right
        if downsampled:
            occ_mask = batched_index_select(occ_mask, 2, x.sampled_cols)
            occ_mask_right = batched_index_select(occ_mask_right, 2, x.sampled_cols)
        if x.sampled_rows is not None:
            occ_mask = batched_index_select(occ_mask, 1, x.sampled_rows)
            occ_mask_right = batched_index_select(occ_mask_right, 1, x.sampled_rows)
        output['gt_response_occ_left'] = attn_ot[..., :-1, -1][occ_mask]
        output['gt_response_occ_right'] = attn_ot[..., -1, :-1][occ_mask_right]

    # regress low res disparity
    pos_shift = self._compute_unscaled_pos_shift(
        attn_weight.shape[2], attn_weight.device)  # NxHxW_leftxW_right
    low_res_disp, matched_attn = self._compute_low_res_disp(pos_shift, attn_inner, occ_mask)

    # regress low res occlusion
    low_res_occ = self._compute_low_res_occ(matched_attn)

    # without downsampling the low res predictions are already the final ones
    if not downsampled:
        output['disp_pred'] = low_res_disp
        output['occ_pred'] = low_res_occ
        return output

    # upsample and context adjust
    full_disp, raw_low_res_disp, full_occ = self._upsample(x, low_res_disp, low_res_occ, scale)
    output['disp_pred'] = full_disp
    output['disp_pred_low_res'] = raw_low_res_disp
    output['occ_pred'] = full_occ
    return output