def forward(self, x):
    """Run the wrapped base network and collect its intermediate outputs.

    Args:
        x: input passed to ``self.base``.

    Returns:
        (sources, x): the two values produced by
        ``get_multiple_outputs(self.base, x, self.out_layers)`` -- presumably
        the outputs of the layers listed in ``self.out_layers`` plus the final
        output (confirm against that helper).
    """
    # torch.jit.scope is a tracing helper from older PyTorch releases and may
    # be absent on newer ones; fall back to a no-op context manager so the
    # forward pass still runs (the scope name only affects trace naming).
    jit_scope = getattr(torch.jit, 'scope', None)
    if jit_scope is None:
        import contextlib
        jit_scope = contextlib.nullcontext
    with jit_scope('Sequential[base]'):
        sources, x = get_multiple_outputs(self.base, x, self.out_layers)
    return sources, x
def get_out_channels(self):
    """Return the channel count of every source produced by ``self.base``.

    A dummy all-ones batch of shape [1, 3, 300, 300] is pushed through the
    base network in eval mode, and ``size(1)`` of each collected source is
    reported.

    Returns:
        list of int, one entry per source returned by ``get_multiple_outputs``.
    """
    # Remember the caller's mode so this query does not silently leave the
    # module in eval mode afterwards.
    was_training = self.training
    # Create the probe on the same device as the module's weights (default
    # device if the module has no parameters) so GPU models work too.
    first_param = next(self.parameters(), None)
    device = first_param.device if first_param is not None else None
    self.eval()
    try:
        dummy = torch.ones((1, 3, 300, 300), dtype=torch.float, device=device)
        # Shape probing only: no autograd graph is needed.
        with torch.no_grad():
            sources, _ = get_multiple_outputs(self.base, dummy, self.out_layers)
    finally:
        self.train(was_training)
    return [source.size(1) for source in sources]
def forward(self, img):
    """Run the detector and return flattened per-location predictions.

    Args:
        img: torch.tensor(:shape [Batch, Channel, Height, Width])

    Returns:
        tuple of:
            scores: torch.tensor(:shape [Batch, N_class]) -- the class-head
                output of every source, permuted to channels-last, flattened
                per sample and concatenated along dim 1 (AnchorBoxes * Classes
                per the original design).
            locs: torch.tensor(:shape [Batch, N_loc]) -- the loc-head outputs,
                flattened and concatenated the same way (AnchorBoxes * 4).
            loc_sources: list of the feature maps that were fed to the loc
                heads, one per source. This is NOT an anchor tensor.
    """
    # torch.jit.scope is a tracing helper from older PyTorch releases and may
    # be absent on newer ones; fall back to a no-op context manager so the
    # forward pass still runs (scope names only affect trace naming).
    jit_scope = getattr(torch.jit, 'scope', None)
    if jit_scope is None:
        import contextlib
        jit_scope = contextlib.nullcontext

    def flatten_head(output, batch_size):
        # [B, C, H, W] -> [B, H * W * C]: channels-last, so the values for one
        # spatial location stay adjacent after flattening.
        return output.permute((0, 2, 3, 1)).contiguous().view(batch_size, -1)

    # backward compatibility
    # ToDo: remove
    if isinstance(self.features, nn.Sequential):
        from bf.utils.torch_utils import get_multiple_outputs
        with jit_scope('Sequential[features]'):
            sources, x = get_multiple_outputs(self.features, img, self.source_layers)
    else:
        with jit_scope(f'{type(self.features).__name__}[features]'):
            sources, x = self.features(img)

    # Each extra layer consumes the previous output and contributes a new source.
    with jit_scope('Sequential[extras]'):
        for i, layer in enumerate(self.extras):
            with jit_scope(f'_item[{i}]'):
                x = layer(x)
                sources.append(x)

    class_sources = loc_sources = sources

    # backward compatibility
    # ToDo: remove
    if hasattr(self, 'predictor_conv'):
        for class_conv, loc_conv, class_norm, loc_norm in zip(
                self.predictor_conv['class'], self.predictor_conv['loc'],
                self.predictor_norm['class'], self.predictor_norm['loc']):
            # One conv (followed by the shared activation) is applied to every
            # source, while the norms are paired with the sources one-to-one.
            class_sources = [
                norm(self.predictor_activation(class_conv(source)))
                for norm, source in zip(class_norm, class_sources)]
            loc_sources = [
                norm(self.predictor_activation(loc_conv(source)))
                for norm, source in zip(loc_norm, loc_sources)]

    scores = []
    locs = []
    for i, (head, class_source, loc_source) in enumerate(
            zip(self.heads, class_sources, loc_sources)):
        with jit_scope(f'ModuleList[heads]/ModuleDict[{i}]'):
            with jit_scope('_item[class]'):
                scores.append(flatten_head(head['class'](class_source), class_source.size(0)))
            with jit_scope('_item[loc]'):
                locs.append(flatten_head(head['loc'](loc_source), loc_source.size(0)))

    scores = torch.cat(scores, dim=1)
    locs = torch.cat(locs, dim=1)
    return scores, locs, loc_sources