Example #1
    def forward_test(self, img, img_metas, return_heatmap=False, **kwargs):
        """Inference the bottom-up model.

        Note:
            Batchsize = N (currently support batchsize = 1)
            num_img_channel: C
            img_weight: imgW
            img_height: imgH

        Args:
            flip_index (List(int)):
            aug_data (List(Tensor[NxCximgHximgW])): Multi-scale image
            test_scale_factor (List(float)): Multi-scale factor
            base_size (Tuple(int)): Base size of image when scale is 1
            center (np.ndarray): center of image
            scale (np.ndarray): the scale of image
        """
        assert img.size(0) == 1
        assert len(img_metas) == 1

        img_metas = img_metas[0]

        aug_data = img_metas['aug_data']

        test_scale_factor = img_metas['test_scale_factor']
        base_size = img_metas['base_size']
        center = img_metas['center']
        scale = img_metas['scale']

        aggregated_heatmaps = None
        tags_list = []
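        # run the model once per test scale, from largest to smallest,
        # accumulating heatmaps (averaged after the loop) and tag maps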
        for idx, s in enumerate(sorted(test_scale_factor, reverse=True)):
            image_resized = aug_data[idx].to(img.device)

            outputs = self.backbone(image_resized)
            outputs = self.keypoint_head(outputs)

            if self.test_cfg['flip_test']:
                # use flip test
                outputs_flip = self.backbone(torch.flip(image_resized, [3]))
                outputs_flip = self.keypoint_head(outputs_flip)
            else:
                outputs_flip = None

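            # collect heatmaps and associative-embedding tags from all
            # stages, folding in the flipped outputs when available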
            _, heatmaps, tags = get_multi_stage_outputs(
                outputs, outputs_flip, self.test_cfg['num_joints'],
                self.test_cfg['with_heatmaps'], self.test_cfg['with_ae'],
                self.test_cfg['tag_per_joint'], img_metas['flip_index'],
                self.test_cfg['project2image'], base_size)

            aggregated_heatmaps, tags_list = aggregate_results(
                s, aggregated_heatmaps, tags_list, heatmaps, tags,
                test_scale_factor, self.test_cfg['project2image'],
                self.test_cfg['flip_test'])

        # average heatmaps of different scales
        aggregated_heatmaps = aggregated_heatmaps / float(
            len(test_scale_factor))
        tags = torch.cat(tags_list, dim=4)

        # perform grouping
        grouped, scores = self.parser.parse(aggregated_heatmaps, tags,
                                            self.test_cfg['adjust'],
                                            self.test_cfg['refine'])

        results = get_group_preds(
            grouped, center, scale,
            [aggregated_heatmaps.size(3),
             aggregated_heatmaps.size(2)])

        image_path = []
        # append the full path; extend() would split a string into characters
        image_path.append(img_metas['image_file'])

        if return_heatmap:
            output_heatmap = aggregated_heatmaps.detach().cpu().numpy()
        else:
            output_heatmap = None

        return results, scores, image_path, output_heatmap
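
A minimal usage sketch, assuming a constructed bottom-up `model` exposing
this method and inputs prepared by the dataset pipeline (`img`, `img_metas`
and the tensor shape below are illustrative assumptions, not part of the
source):

import torch

model.eval()
with torch.no_grad():  # no gradients needed at inference time
    results, scores, image_path, output_heatmap = model.forward_test(
        img,               # Tensor of shape [1, C, imgH, imgW]
        img_metas,         # list holding a single meta dict
        return_heatmap=True)

Example #2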
    def forward_test(self, img, img_metas, return_heatmap=False, **kwargs):
        """Inference the bottom-up model.

        Note:
            - Batchsize: N (currently support batchsize = 1)
            - num_img_channel: C
            - img_width: imgW
            - img_height: imgH

        Args:
            flip_index (List(int)):
            aug_data (List(Tensor[NxCximgHximgW])): Multi-scale image
            test_scale_factor (List(float)): Multi-scale factor
            base_size (Tuple(int)): Base size of image when scale is 1
            center (np.ndarray): center of image
            scale (np.ndarray): the scale of image
        """
        assert img.size(0) == 1
        assert len(img_metas) == 1

        img_metas = img_metas[0]

        aug_data = img_metas['aug_data']

        test_scale_factor = img_metas['test_scale_factor']
        base_size = img_metas['base_size']
        center = img_metas['center']
        scale = img_metas['scale']

        result = {}

        scale_heatmaps_list = []
        scale_tags_list = []

        for idx, s in enumerate(sorted(test_scale_factor, reverse=True)):
            image_resized = aug_data[idx].to(img.device)

            features = self.backbone(image_resized)
            if self.with_keypoint:
                outputs = self.keypoint_head(features)

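            # split the head outputs into heatmap tensors and
            # associative-embedding tag tensors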
            heatmaps, tags = split_ae_outputs(
                outputs, self.test_cfg['num_joints'],
                self.test_cfg['with_heatmaps'], self.test_cfg['with_ae'],
                self.test_cfg.get('select_output_index', range(len(outputs))))

            if self.test_cfg.get('flip_test', True):
                # use flip test
                features_flipped = self.backbone(torch.flip(
                    image_resized, [3]))
                if self.with_keypoint:
                    outputs_flipped = self.keypoint_head(features_flipped)

                heatmaps_flipped, tags_flipped = split_ae_outputs(
                    outputs_flipped, self.test_cfg['num_joints'],
                    self.test_cfg['with_heatmaps'], self.test_cfg['with_ae'],
                    self.test_cfg.get('select_output_index',
                                      range(len(outputs))))

                heatmaps_flipped = flip_feature_maps(
                    heatmaps_flipped, flip_index=img_metas['flip_index'])
                if self.test_cfg['tag_per_joint']:
                    tags_flipped = flip_feature_maps(
                        tags_flipped, flip_index=img_metas['flip_index'])
                else:
                    tags_flipped = flip_feature_maps(tags_flipped,
                                                     flip_index=None,
                                                     flip_output=True)

            else:
                heatmaps_flipped = None
                tags_flipped = None

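            # fuse stage and flip outputs: heatmaps are averaged, while
            # tags are concatenated so each embedding dimension is kept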
            aggregated_heatmaps = aggregate_stage_flip(
                heatmaps,
                heatmaps_flipped,
                index=-1,
                project2image=self.test_cfg['project2image'],
                size_projected=base_size,
                align_corners=self.test_cfg.get('align_corners', True),
                aggregate_stage='average',
                aggregate_flip='average')

            aggregated_tags = aggregate_stage_flip(
                tags,
                tags_flipped,
                index=-1,
                project2image=self.test_cfg['project2image'],
                size_projected=base_size,
                align_corners=self.test_cfg.get('align_corners', True),
                aggregate_stage='concat',
                aggregate_flip='concat')

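            # tags are only kept at the original scale (or when a single
            # scale is tested); heatmaps are collected at every scale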
            if s == 1 or len(test_scale_factor) == 1:
                if isinstance(aggregated_tags, list):
                    scale_tags_list.extend(aggregated_tags)
                else:
                    scale_tags_list.append(aggregated_tags)

            if isinstance(aggregated_heatmaps, list):
                scale_heatmaps_list.extend(aggregated_heatmaps)
            else:
                scale_heatmaps_list.append(aggregated_heatmaps)

        aggregated_heatmaps = aggregate_scale(scale_heatmaps_list,
                                              align_corners=self.test_cfg.get(
                                                  'align_corners', True),
                                              aggregate_scale='average')

        aggregated_tags = aggregate_scale(scale_tags_list,
                                          align_corners=self.test_cfg.get(
                                              'align_corners', True),
                                          aggregate_scale='unsqueeze_concat')

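        # if tags were aggregated at a different resolution than the
        # heatmaps, bilinearly resize them to match before grouping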
        heatmap_size = aggregated_heatmaps.shape[2:4]
        tag_size = aggregated_tags.shape[2:4]
        if heatmap_size != tag_size:
            tmp = []
            for idx in range(aggregated_tags.shape[-1]):
                tmp.append(
                    torch.nn.functional.interpolate(
                        aggregated_tags[..., idx],
                        size=heatmap_size,
                        mode='bilinear',
                        align_corners=self.test_cfg.get('align_corners',
                                                        True)).unsqueeze(-1))
            aggregated_tags = torch.cat(tmp, dim=-1)

        # perform grouping
        grouped, scores = self.parser.parse(aggregated_heatmaps,
                                            aggregated_tags,
                                            self.test_cfg['adjust'],
                                            self.test_cfg['refine'])

        preds = get_group_preds(
            grouped,
            center,
            scale, [aggregated_heatmaps.size(3),
                    aggregated_heatmaps.size(2)],
            use_udp=self.use_udp)

        image_paths = []
        image_paths.append(img_metas['image_file'])

        if return_heatmap:
            output_heatmap = aggregated_heatmaps.detach().cpu().numpy()
        else:
            output_heatmap = None

        result['preds'] = preds
        result['scores'] = scores
        result['image_paths'] = image_paths
        result['output_heatmap'] = output_heatmap

        return result
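
This refactored version packs its outputs into a dict instead of a tuple. A
minimal sketch of consuming it, under the same assumptions as the sketch
after Example #1 (`model`, `img` and `img_metas` are hypothetical):

result = model.forward_test(img, img_metas, return_heatmap=True)
preds = result['preds']              # grouped keypoint predictions
scores = result['scores']            # per-person grouping scores
image_paths = result['image_paths']  # path(s) of the source image
heatmap = result['output_heatmap']   # aggregated heatmaps (numpy) or None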