예제 #1
0
 def forward(self, features, rois):
     """Pool RoI features with whichever RoIAlign backend is configured.

     Args:
         features: passed to the RoIAlign op as its first argument.
         rois: passed to the RoIAlign op as its second argument.

     Returns:
         The value returned by the selected RoIAlign implementation.
     """
     if not self.use_torchvision:
         # Project-local operator.
         return roi_align(features, rois, self.out_size, self.spatial_scale,
                          self.sample_num)

     # Imported inside the branch, so torchvision is only needed on this path.
     from torchvision.ops import roi_align as tv_roi_align
     return tv_roi_align(features, rois, self.out_size, self.spatial_scale,
                         self.sample_num)
예제 #2
0
 def forward(self, input, rois):
     """Run RoIAlign on ``input`` for the boxes in ``rois``.

     Args:
         input: NCHW images
         rois: Bx5 boxes. First column is the index into N.
             The other 4 columns are xyxy. This tensor is never modified
             by the call.

     Returns:
         The value returned by the selected RoIAlign implementation.
     """
     if not self.use_torchvision:
         return roi_align(input, rois, self.output_size, self.spatial_scale,
                          self.sampling_ratio, self.pool_mode, self.aligned)

     from torchvision.ops import roi_align as tv_roi_align
     if 'aligned' in tv_roi_align.__code__.co_varnames:
         # The installed torchvision accepts ``aligned`` directly.
         return tv_roi_align(input, rois, self.output_size,
                             self.spatial_scale, self.sampling_ratio,
                             self.aligned)

     if self.aligned:
         # No ``aligned`` argument available: shift the four box coordinates
         # by 0.5 / spatial_scale instead. The subtraction is done
         # out-of-place; an in-place ``-=`` would change the caller's tensor,
         # so the shift would accumulate whenever the same ``rois`` is passed
         # to forward() more than once.
         shift = rois.new_tensor([0.] + [0.5 / self.spatial_scale] * 4)
         rois = rois - shift
     return tv_roi_align(input, rois, self.output_size,
                         self.spatial_scale, self.sampling_ratio)
예제 #3
0
 def forward(self, input, rois):
     """Run RoIAlign on ``input`` for the boxes in ``rois``.

     Args:
         input: NCHW images
         rois: Bx5 boxes. First column is the index into N.
             The other 4 columns are xyxy. This tensor is never modified
             by the call.

     Returns:
         The value returned by the selected RoIAlign implementation.
     """
     if not self.use_torchvision:
         return roi_align(input, rois, self.output_size, self.spatial_scale,
                          self.sampling_ratio, self.pool_mode, self.aligned)

     from torchvision.ops import roi_align as tv_roi_align
     # ``self.aligned`` is honoured on this path too; previously it was only
     # passed to the local op, so the two backends could disagree for the
     # same module configuration.
     if 'aligned' in tv_roi_align.__code__.co_varnames:
         return tv_roi_align(input, rois, self.output_size,
                             self.spatial_scale, self.sampling_ratio,
                             self.aligned)

     if self.aligned:
         # torchvision without an ``aligned`` argument: shift the four box
         # coordinates by 0.5 / spatial_scale instead. Done out-of-place so
         # the caller's ``rois`` is left untouched.
         shift = rois.new_tensor([0.] + [0.5 / self.spatial_scale] * 4)
         rois = rois - shift
     return tv_roi_align(input, rois, self.output_size,
                         self.spatial_scale, self.sampling_ratio)
예제 #4
0
    def forward(self, input, boxes):
        """
        Pool one square feature patch per box using torchvision's roi_align.

        Args:
            input (Tensor): shape (N x C x H x W)
            boxes (Tensor): boxes in pooling format (image_index, x1, y1, x2, y2), shape (M x 5)

        Returns:
            output (Tensor): pooled roi feature map, shape (M x C x out_size x out_size)
        """
        # The same configured size is used for both output dimensions.
        pooled_hw = (self.output_size, self.output_size)
        return tv_roi_align(
            input,
            boxes,
            pooled_hw,
            self.spatial_scale,
            self.sampling_ratio,
            self.aligned,
        )
예제 #5
0
    def forward(self, features, rois):
        """Run RoIAlign on ``features`` for the boxes in ``rois``.

        Args:
            features: NCHW images
            rois: Bx5 boxes. First column is the index into N. The other 4
                columns are xyxy. This tensor is never modified by the call.

        Returns:
            The value returned by the selected RoIAlign implementation.

        Raises:
            AssertionError: if ``rois`` is not a 2-D tensor with 5 columns.
        """
        assert rois.dim() == 2 and rois.size(1) == 5

        if not self.use_torchvision:
            return roi_align(features, rois, self.out_size, self.spatial_scale,
                             self.sample_num, self.aligned)

        from torchvision.ops import roi_align as tv_roi_align
        # ``self.aligned`` is honoured on this path too; previously it was
        # only passed to the local op, so the two backends could disagree for
        # the same module configuration.
        if 'aligned' in tv_roi_align.__code__.co_varnames:
            return tv_roi_align(features, rois, self.out_size,
                                self.spatial_scale, self.sample_num,
                                self.aligned)

        if self.aligned:
            # torchvision without an ``aligned`` argument: shift the four box
            # coordinates by 0.5 / spatial_scale instead. Done out-of-place
            # so the caller's ``rois`` is left untouched.
            shift = rois.new_tensor([0.] + [0.5 / self.spatial_scale] * 4)
            rois = rois - shift
        return tv_roi_align(features, rois, self.out_size,
                            self.spatial_scale, self.sample_num)
예제 #6
0
 def forward(self, input, rois):
     """Apply RoIAlign to ``input`` over ``rois`` with the selected backend.

     Args:
         input: passed to the RoIAlign op as its first argument.
         rois: passed to the RoIAlign op as its second argument.

     Returns:
         The value returned by the selected RoIAlign implementation.
     """
     if not self.use_torchvision:
         return roi_align(input, rois, self.output_size, self.spatial_scale,
                          self.sampling_ratio)

     from torchvision.ops import roi_align as tv_roi_align
     # NOTE(review): this path reads ``self.out_size`` while the other path
     # reads ``self.output_size`` -- confirm both attributes exist on the
     # class and are meant to differ.
     pooled_size = _pair(self.out_size)
     return tv_roi_align(input, rois, pooled_size, self.spatial_scale,
                         self.sampling_ratio)