def forward(self, laf: torch.Tensor, img: torch.Tensor) -> torch.Tensor:
    """Estimate patch orientations and write them into the LAFs.

    Extracts patches from an image pyramid at each LAF, runs the angle
    detector on them, and adds the predicted angle (converted to degrees)
    to each LAF's current orientation.

    Args:
        laf: local affine frames, shape [BxNx2x3].
        img: input image, shape [Bx1xHxW].

    Returns:
        laf_out: LAFs with updated orientation, shape [BxNx2x3].

    Raises:
        TypeError: if img is not a torch.Tensor.
        ValueError: if img is not 4D, or laf/img batch sizes differ.
    """
    raise_error_if_laf_is_not_valid(laf)
    # BUGFIX: type-check before touching img.shape, so passing a
    # non-tensor raises the intended TypeError rather than an opaque
    # AttributeError while formatting the shape message.
    if not isinstance(img, torch.Tensor):
        raise TypeError("img type is not a torch.Tensor. Got {}".format(type(img)))
    if len(img.shape) != 4:
        raise ValueError(
            "Invalid img shape, we expect BxCxHxW. Got: {}".format(img.shape))
    if laf.size(0) != img.size(0):
        # BUGFIX: report sizes in the order the message names them (laf, img);
        # the original formatted img first and laf second.
        raise ValueError(
            "Batch size of laf and img should be the same. Got {}, {}".format(
                laf.size(0), img.size(0)))
    B, N = laf.shape[:2]
    # Flatten BxN patches into a single batch for the angle detector.
    patches: torch.Tensor = extract_patches_from_pyramid(
        img, laf, self.patch_size).view(-1, 1, self.patch_size, self.patch_size)
    angles_radians: torch.Tensor = self.angle_detector(patches).view(B, N)
    # get_laf_orientation returns degrees; the detector output is radians,
    # hence the rad2deg conversion before summing.
    prev_angle = get_laf_orientation(laf).view_as(angles_radians)
    laf_out: torch.Tensor = set_laf_orientation(
        laf, rad2deg(angles_radians) + prev_angle)
    return laf_out
def forward(self, x: torch.Tensor, mask=None):
    """Detect OpenCV keypoints on x, refine their affine shape, keep orientation.

    Converts the tensor to a uint8 image, runs the wrapped OpenCV detector,
    builds LAFs from the keypoints, refines their shape with ``self.affnet``,
    and restores the detector's original orientation afterwards.

    Args:
        x: input image tensor (converted to uint8 HxWxC for OpenCV).
        mask: optional detection mask tensor; presumably the same spatial
            size as x — TODO confirm against the OpenCV detector's contract.

    Returns:
        (lafs, resp): local affine frames and detector responses.
    """
    self.affnet = self.affnet.to(x.device)
    # Heuristic: values below 2.0 are treated as a [0, 1]-range image and
    # rescaled to [0, 255] before the uint8 cast.
    if x.max() < 2.0:
        img_np = (255 * K.tensor_to_image(x)).astype(np.uint8)
    else:
        img_np = K.tensor_to_image(x).astype(np.uint8)
    if mask is not None:
        # BUGFIX: convert the mask itself; the original converted x here,
        # silently replacing the user-supplied mask with the image.
        mask = K.tensor_to_image(mask).astype(np.uint8)
    kpts = self.features.detect(img_np, mask)
    lafs, resp = laf_from_opencv_kpts(
        kpts, mrSize=self.mrSize, with_resp=True, device=x.device)
    # affnet may alter orientation as a side effect of shape estimation,
    # so save the detector's orientation and restore it afterwards.
    ori = KF.get_laf_orientation(lafs)
    lafs = self.affnet(lafs, x.mean(dim=1, keepdim=True))
    lafs = KF.set_laf_orientation(lafs, ori)
    return lafs, resp