def initialize_flow(self, img):
    """Create the pair of coordinate grids whose difference encodes flow.

    Optical flow is represented implicitly as ``flow = coords1 - coords0``.
    Both grids are produced by ``coords_grid`` with identical arguments, at
    1/8 of the input's spatial size (integer division), and are moved to the
    same device as ``img``.

    Args:
        img: 4-D tensor whose shape is unpacked as (N, C, H, W).

    Returns:
        Tuple ``(coords0, coords1)``.
    """
    batch, _, height, width = img.shape
    grid_size = (height // 8, width // 8)
    target = img.device
    # Two independent grids so callers can update coords1 without touching coords0.
    coords0 = coords_grid(batch, *grid_size).to(target)
    coords1 = coords_grid(batch, *grid_size).to(target)
    return coords0, coords1
def flow_crop_and_resize(self, flow, ph, pw, nsize=None, mode='bilinear'):
    """Optionally crop a dense flow field, then resample it to ``self.image_shape``.

    Two kinds of samples in the result are set to NaN so downstream code can
    treat them as invalid:

    * samples whose (floored) source location failed a local consistency
      check between neighbouring flow vectors (see ``_gradient_check``);
    * samples whose flowed position falls outside the target image.

    Args:
        flow: flow field indexed as ``flow[row, col, channel]``; presumably
            shaped (H, W, 2) with channel 0 = x/width and channel 1 =
            y/height displacement (matches the per-channel rescale below) --
            TODO confirm against callers.
        ph: top row of the crop (only used when ``nsize`` is given).
        pw: left column of the crop (only used when ``nsize`` is given).
        nsize: optional ``(height, width)`` of the crop. When ``None`` the
            whole flow field is used.
        mode: interpolation mode forwarded to ``grid_sampler``.

    Returns:
        Tensor of shape (1, 2, image_h, image_w), where
        ``(image_h, image_w) = self.image_shape``. Flow values are rescaled
        from crop pixels to target pixels, and invalid entries are NaN.

    Raises:
        ValueError: if the requested crop does not lie inside ``flow``.
    """
    if nsize is not None:
        src_h, src_w = flow.shape[0], flow.shape[1]
        # Slicing would silently truncate an out-of-range crop, and the later
        # lookup into ``valid`` would then index past its end.
        if ph < 0 or pw < 0 or ph + nsize[0] > src_h or pw + nsize[1] > src_w:
            raise ValueError(
                f"crop (ph={ph}, pw={pw}, nsize={tuple(nsize)}) exceeds "
                f"flow of size ({src_h}, {src_w})")
        flow = flow[ph:ph + nsize[0], pw:pw + nsize[1]]
    else:
        nsize = (flow.shape[0], flow.shape[1])
    flow = flow.permute(2, 0, 1).unsqueeze(0)  # 1, C, h, w

    def _gradient_check(_fa, _fb):
        """Return an int (h, w) mask: 1 where neighbouring flow vectors agree."""
        # Drop only the batch dim; ``squeeze()`` would also collapse a spatial
        # dim of size 1 and make ``sum(dim=0)`` reduce the wrong axis.
        _fa = _fa[0]
        _fb = _fb[0]
        dotab = (_fa * _fb).sum(dim=0)
        nfa = torch.sqrt((_fa**2).sum(dim=0))
        nfb = torch.sqrt((_fb**2).sum(dim=0))
        nfab = nfa * nfb
        # Direction test on |cos|: anti-parallel vectors also count as aligned.
        # Zero norms produce NaN here; those entries are overridden just below.
        eps = 1e-6
        angle = torch.acos((dotab / nfab).abs().clamp(0, 1.0 - eps))
        angle_valid = angle <= np.pi / 4
        angle_valid[nfab == 0] = True        # a zero vector has no direction
        angle_valid[(nfa + nfb) < 2] = True  # both vectors tiny: ignore angle
        # Magnitude test: reject large jumps in flow length between neighbours.
        mag_valid = torch.abs(nfa - nfb) < 50
        return (angle_valid * mag_valid).int()

    # Row-to-row and column-to-column consistency, padded back to (h, w) by
    # treating the last row / last column as valid.
    gradh = _gradient_check(flow[:, :, :-1, :], flow[:, :, 1:, :])
    gradh = F.pad(gradh, (0, 0, 0, 1), value=1)
    gradw = _gradient_check(flow[:, :, :, :-1], flow[:, :, :, 1:])
    gradw = F.pad(gradw, (0, 1, 0, 0), value=1)
    valid = gradw * gradh  # h, w

    # Map every target pixel onto the crop with align-corners spacing:
    # target 0 -> source 0 and target (N - 1) -> source (n - 1).
    sample_scaleh = (nsize[0] - 1) / float(self.image_shape[0] - 1)
    sample_scalew = (nsize[1] - 1) / float(self.image_shape[1] - 1)
    coords_new = coords_grid(1, *self.image_shape).float()  # 1, 2, H, W
    coordsw = coords_new[:, :1, :, :] * sample_scalew
    coordsh = coords_new[:, 1:, :, :] * sample_scaleh
    coords = torch.cat([coordsw, coordsh], dim=1)
    interp = grid_sampler(flow, coords, mode=mode)  # 1, 2, H, W

    # Invalidate samples whose floored source location failed the check.
    cw, ch = torch.floor(coords).split(1, dim=1)  # each 1, 1, H, W
    validp = valid[(ch[0, 0].long(), cw[0, 0].long())][None, None, ...]  # 1, 1, H, W
    # masked_fill with a Python float keeps interp's dtype/device (a separate
    # CPU float32 ``torch.tensor(nan)`` can mismatch); the mask broadcasts
    # over both flow channels.
    interp = interp.masked_fill(~validp.bool(), float('nan'))  # 1, 2, H, W

    # Convert flow from crop pixels to target pixels.
    # NOTE(review): sampling above uses align-corners spacing (n-1)/(N-1)
    # while this rescale uses n/N; kept as-is to preserve existing outputs.
    scaleh = nsize[0] / float(self.image_shape[0])
    scalew = nsize[1] / float(self.image_shape[1])
    interp[:, 0, :, :] /= scalew
    interp[:, 1, :, :] /= scaleh

    # Invalidate samples whose flowed position leaves the target image.
    # NaN entries compare False and therefore stay as they are.
    flowed_coord = (coords_new + interp).squeeze(0)  # 2, H, W
    outbound = ((flowed_coord[0] < 0)
                | (flowed_coord[1] < 0)
                | (flowed_coord[0] > self.image_shape[1] - 1)
                | (flowed_coord[1] > self.image_shape[0] - 1))  # H, W
    interp = interp.masked_fill(outbound[None, None, ...], float('nan'))
    return interp