示例#1
0
 def forward(self, x):
     """Run ``self.base`` and gather the outputs of the layers listed in ``self.out_layers``.

     Returns:
         tuple: ``(sources, x)`` exactly as returned by
         ``get_multiple_outputs(self.base, x, self.out_layers)``. Presumably
         ``sources`` holds the tapped intermediate outputs and ``x`` the final
         backbone output; confirm against ``get_multiple_outputs``.
     """
     scope_name = 'Sequential[base]'  # trace scope label for the backbone
     with torch.jit.scope(scope_name):
         tapped, last = get_multiple_outputs(self.base, x, self.out_layers)
     return tapped, last
示例#2
0
 def get_out_channels(self):
     """Return the channel count of each source feature map produced by ``self.base``.

     A dummy all-ones batch of shape ``(1, 3, 300, 300)`` is pushed through
     ``self.base`` via ``get_multiple_outputs``. The result is ``x.size(1)`` for
     every tensor in the returned ``sources``.

     The module is put in eval mode for the probe (presumably so layers such as
     BatchNorm accept a batch of one). Afterwards, the exact train/eval flag of
     every submodule is restored, so calling this (e.g. from ``__init__``) does
     not leave the model stuck in eval mode. No autograd graph is built for the
     probe.

     Returns:
         list[int]: ``x.size(1)`` for each tensor in ``sources``.
     """
     # Snapshot per-module flags rather than only ``self.training``, so that
     # submodules deliberately kept in a different mode are restored as they were.
     saved_modes = [(module, module.training) for module in self.modules()]
     self.eval()
     try:
         # Create the probe on the parameters' device (CPU default when the
         # module has no parameters), so this still works after ``.to(device)``.
         param = next(self.parameters(), None)
         device = param.device if param is not None else None
         dummy = torch.ones((1, 3, 300, 300), dtype=torch.float, device=device)
         with torch.no_grad():
             sources, _ = get_multiple_outputs(self.base, dummy, self.out_layers)
     finally:
         for module, mode in saved_modes:
             module.training = mode
     return [x.size(1) for x in sources]
示例#3
0
    def forward(self, img):
        """Run the feature extractor, extra layers and the class/loc prediction heads.

        Args:
            img: torch.tensor(:shape [Batch, Channel, Height, Width])
        Returns:
            prediction: tuple of
                torch.tensor(:shape [Batch, AnchorBoxes * Classes]) -- concatenated
                    outputs of the ``class`` heads
                torch.tensor(:shape [Batch, AnchorBoxes * 4]) -- concatenated
                    outputs of the ``loc`` heads
                list of torch.tensor -- the feature maps that were fed to the
                    ``loc`` heads (after the legacy predictor layers, if present)
        """
        scores = []
        locs = []

        # Backward-compatibility path: a plain nn.Sequential feature extractor is
        # tapped at ``self.source_layers``. Otherwise ``self.features`` is expected
        # to return ``(sources, x)`` itself.
        # TODO: remove the nn.Sequential path
        if isinstance(self.features, nn.Sequential):
            from bf.utils.torch_utils import get_multiple_outputs
            with torch.jit.scope('Sequential[features]'):
                sources, x = get_multiple_outputs(self.features, img, self.source_layers)
        else:
            with torch.jit.scope(f'{type(self.features).__name__}[features]'):
                sources, x = self.features(img)

        # Copy before appending the extra-layer outputs. This avoids mutating a list
        # owned by the feature extractor and also accepts a returned tuple.
        sources = list(sources)

        # Each extra layer consumes the previous output; every result becomes an
        # additional source feature map.
        with torch.jit.scope('Sequential[extras]'):
            for i, layer in enumerate(self.extras):
                with torch.jit.scope(f'_item[{i}]'):
                    x = layer(x)
                sources.append(x)

        class_sources = loc_sources = sources

        # Backward compatibility: optional predictor layers applied before the heads.
        # Each stage applies one conv to every source (presumably shared across
        # scales), then the activation, then one norm per source.
        # TODO: remove
        if hasattr(self, 'predictor_conv'):
            for class_conv, loc_conv, class_norm, loc_norm in zip(self.predictor_conv['class'],
                                                                  self.predictor_conv['loc'],
                                                                  self.predictor_norm['class'],
                                                                  self.predictor_norm['loc']):
                class_sources = map(class_conv, class_sources)
                loc_sources = map(loc_conv, loc_sources)

                class_sources = map(self.predictor_activation, class_sources)
                loc_sources = map(self.predictor_activation, loc_sources)

                class_sources = [norm(x) for norm, x in zip(class_norm, class_sources)]
                loc_sources = [norm(x) for norm, x in zip(loc_norm, loc_sources)]

        def flatten(prediction, batch_size):
            # [B, C, H, W] -> [B, H * W * C]: move channels last so each spatial
            # location's predictions are contiguous, then flatten per sample.
            return prediction.permute((0, 2, 3, 1)).contiguous().view(batch_size, -1)

        for i, (head, class_source, loc_source) in enumerate(zip(self.heads, class_sources, loc_sources)):
            with torch.jit.scope(f'ModuleList[heads]/ModuleDict[{i}]'):
                with torch.jit.scope('_item[class]'):
                    scores.append(flatten(head['class'](class_source), class_source.size(0)))
                with torch.jit.scope('_item[loc]'):
                    locs.append(flatten(head['loc'](loc_source), loc_source.size(0)))

        scores = torch.cat(scores, dim=1)
        locs = torch.cat(locs, dim=1)

        return scores, locs, loc_sources