Пример #1
0
    def init_weights(self, pretrained=None):
        """Initialize weights by delegating to ``self._init_weights``.

        The callable stored on ``self._init_weights`` is invoked with this
        object passed explicitly as its first argument, followed by
        ``pretrained``.

        Args:
            pretrained: Forwarded unchanged to ``self._init_weights``.
                Defaults to None.
        """
        # The stored callable receives `self` explicitly, so fetch it first
        # and then call it like a plain function.
        initializer = self._init_weights
        initializer(self, pretrained)

    def forward(self, x):
        """Apply the layer named by ``self.layer_name`` to the input.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: Whatever the looked-up layer returns for ``x``.
        """
        # The layer is stored on this object under a dynamic attribute name.
        return getattr(self, self.layer_name)(x)

    def train(self, mode=True):
        """Switch the training mode and re-apply freezing / norm settings.

        The parent ``train`` is called with ``mode`` and
        ``self._freeze_stages()`` is invoked afterwards. When ``mode`` and
        ``self.norm_eval`` are both truthy, every ``_BatchNorm`` instance
        among ``self.modules()`` has ``eval()`` called on it.

        Args:
            mode: Forwarded to the parent ``train``. Defaults to True.
        """
        super().train(mode)
        self._freeze_stages()

        # Nothing more to do unless norm layers must be put in eval mode
        # while training.
        if not (mode and self.norm_eval):
            return

        for module in self.modules():
            if isinstance(module, _BatchNorm):
                module.eval()


# Pass ResNet3dLayer to the callable returned by
# MMDET_SHARED_HEADS.register_module(), but only when the name `mmdet` is
# bound in this module's namespace (`dir()` without arguments lists the names
# of the current scope).
# NOTE(review): the guard checks for `mmdet` while the body uses
# `MMDET_SHARED_HEADS` -- confirm the import block binds both together.
if 'mmdet' in dir():
    MMDET_SHARED_HEADS.register_module()(ResNet3dLayer)
Пример #2
0
            dist.barrier()
        if rank > 0:
            return

        print('Gathering all the roi features...')

        lfb = {}
        for rank_id in range(world_size):
            _lfb_file_path = osp.normpath(
                osp.join(self.lfb_prefix_path,
                         f'_lfb_{self.dataset_mode}_{rank_id}.pkl'))

            # Since each frame will only be distributed to one GPU,
            # the roi features on the same timestamp of the same video are all
            # on the same GPU
            _lfb = torch.load(_lfb_file_path)
            for video_id in _lfb:
                if video_id not in lfb:
                    lfb[video_id] = _lfb[video_id]
                else:
                    lfb[video_id].update(_lfb[video_id])

        lfb_file_path = osp.normpath(
            osp.join(self.lfb_prefix_path, f'lfb_{self.dataset_mode}.pkl'))
        torch.save(lfb, lfb_file_path)
        print(f'LFB has been constructed in {lfb_file_path}!')


# When `mmdet_imported` is truthy, pass LFBInferHead to the callable returned
# by MMDET_SHARED_HEADS.register_module().
# NOTE(review): presumably `mmdet_imported` is set by a guarded mmdet import
# earlier in the file -- confirm.
if mmdet_imported:
    MMDET_SHARED_HEADS.register_module()(LFBInferHead)
Пример #3
0
        self.fbo.init_weights(pretrained=pretrained)

    def sample_lfb(self, rois, img_metas):
        """Gather one long-term feature per ROI and stack them into a batch.

        Column 0 of ``rois`` is cast to int64 and each value is used to index
        ``img_metas``; the ``'img_key'`` entry found there selects the
        feature from ``self.lfb``.

        Args:
            rois (torch.Tensor): ROIs whose first column indexes
                ``img_metas``.
            img_metas: Indexable collection of mappings that each provide an
                ``'img_key'`` entry.

        Returns:
            torch.Tensor: The stacked features with their last two axes
            swapped and two trailing singleton axes appended.
        """
        batch_inds = rois[:, 0].type(torch.int64)
        per_roi_feats = [
            self.lfb[img_metas[idx]['img_key']].to() for idx in batch_inds
        ]
        # After the permute:
        # [N, lfb_channels, window_size * max_num_feat_per_step]
        stacked = torch.stack(per_roi_feats, dim=0).permute(0, 2, 1)
        return stacked.contiguous().unsqueeze(-1).unsqueeze(-1)

    def forward(self, x, rois, img_metas, **kwargs):
        """Concatenate the pooled feature with the ``self.fbo`` output.

        ``x`` goes through ``self.temporal_pool`` and then
        ``self.spatial_pool``. Long-term features come from
        ``self.sample_lfb(rois, img_metas)`` and are moved to the pooled
        feature's device. Both are handed to ``self.fbo``, and its output is
        concatenated after the pooled feature along dim 1.

        Args:
            x: Input passed to ``self.temporal_pool``.
            rois: Forwarded to ``self.sample_lfb``.
            img_metas: Forwarded to ``self.sample_lfb``.
            **kwargs: Accepted but not used by this method.

        Returns:
            torch.Tensor: ``torch.cat`` of the pooled feature and the
            ``self.fbo`` output along dim 1.
        """
        # [N, C, 1, 1, 1]
        short_term = self.spatial_pool(self.temporal_pool(x))

        # [N, C, window_size * num_feat_per_step, 1, 1]
        long_term = self.sample_lfb(rois, img_metas)
        long_term = long_term.to(short_term.device)

        fused = self.fbo(short_term, long_term)
        return torch.cat([short_term, fused], dim=1)


# When `mmdet_imported` is truthy, pass FBOHead to the callable returned by
# MMDET_SHARED_HEADS.register_module().
# NOTE(review): presumably `mmdet_imported` is set by a guarded mmdet import
# earlier in the file -- confirm.
if mmdet_imported:
    MMDET_SHARED_HEADS.register_module()(FBOHead)
Пример #4
0
        Args:
            x (torch.Tensor): The extracted RoI feature.
            feat (torch.Tensor): The context feature.
            rois (torch.Tensor): The regions of interest.

        Returns:
            torch.Tensor: The RoI features that have interacted with context
                feature.
        """
        # We use max pooling by default
        x = self.max_pool(x)

        h, w = feat.shape[-2:]
        x_tile = x.repeat(1, 1, 1, h, w)

        roi_inds = rois[:, 0].type(torch.long)
        roi_gfeat = feat[roi_inds]

        new_feat = torch.cat([x_tile, roi_gfeat], dim=1)
        new_feat = self.conv1(new_feat)
        new_feat = self.conv2(new_feat)

        for conv in self.convs:
            new_feat = conv(new_feat)

        return new_feat


# When `mmdet_imported` is truthy, pass ACRNHead to the callable returned by
# MMDET_SHARED_HEADS.register_module().
# NOTE(review): presumably `mmdet_imported` is set by a guarded mmdet import
# earlier in the file -- confirm.
if mmdet_imported:
    MMDET_SHARED_HEADS.register_module()(ACRNHead)