Ejemplo n.º 1
0
 def __init__(self, config: Config, *args, **kwargs):
     """Build the wrapped ``GeneralizedRCNN`` and optionally load weights.

     Args:
         config: Encoder configuration. ``pretrained`` and
             ``pretrained_path`` are read with ``.get``; the whole config is
             also passed to ``GeneralizedRCNN``.
         *args, **kwargs: Accepted for interface compatibility; unused.

     Raises:
         ValueError: If ``pretrained`` is truthy but ``pretrained_path`` is
             missing or empty.
     """
     super().__init__()
     self.config = config
     # Fallbacks apply only when the key is absent from ``config``.
     pretrained = config.get("pretrained", False)
     pretrained_path = config.get("pretrained_path", None)
     self.frcnn = GeneralizedRCNN(config)
     if pretrained:
         if not pretrained_path:
             # Fail early with an actionable message instead of letting
             # torch.load choke on a None/empty path.
             raise ValueError(
                 "pretrained=True requires 'pretrained_path' to be set"
             )
         # Load onto CPU first: load_state_dict copies the values into the
         # module's existing parameters, so this also works on CPU-only
         # hosts for checkpoints saved from GPU.
         # NOTE(review): torch.load unpickles the file; only load
         # checkpoints from trusted sources.
         state_dict = torch.load(pretrained_path, map_location="cpu")
         self.frcnn.load_state_dict(state_dict)
         self.frcnn.eval()
Ejemplo n.º 2
0
class FRCNNImageEncoder(Encoder):
    """Image encoder that wraps a Faster R-CNN ``GeneralizedRCNN`` model.

    When ``config.pretrained`` is truthy, weights are loaded from
    ``config.pretrained_path`` at construction time and the wrapped model is
    switched to eval mode.
    """

    @dataclass
    class Config(Encoder.Config):
        name: str = "frcnn"
        # Load weights from ``pretrained_path`` when the encoder is built.
        pretrained: bool = True
        # Path to a saved state dict for ``GeneralizedRCNN``; must be set
        # whenever ``pretrained`` is true.
        pretrained_path: str = None

    def __init__(self, config: Config, *args, **kwargs):
        """Build the wrapped ``GeneralizedRCNN`` and optionally load weights.

        Args:
            config: Encoder configuration. ``pretrained`` and
                ``pretrained_path`` are read with ``.get``; the whole config
                is also passed to ``GeneralizedRCNN``.
            *args, **kwargs: Accepted for interface compatibility; unused.

        Raises:
            ValueError: If ``pretrained`` is truthy but ``pretrained_path``
                is missing or empty.
        """
        super().__init__()
        self.config = config
        # NOTE(review): the ``.get`` fallback for "pretrained" is False while
        # the dataclass default above is True; the fallback is used only when
        # the key is absent from ``config``.
        pretrained = config.get("pretrained", False)
        pretrained_path = config.get("pretrained_path", None)
        self.frcnn = GeneralizedRCNN(config)
        if pretrained:
            if not pretrained_path:
                # The default Config (pretrained=True, pretrained_path=None)
                # would otherwise reach torch.load(None) and fail obscurely.
                raise ValueError(
                    "pretrained=True requires 'pretrained_path' to be set"
                )
            # Load onto CPU first: load_state_dict copies the values into the
            # module's existing parameters, so this also works on CPU-only
            # hosts for checkpoints saved from GPU.
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            state_dict = torch.load(pretrained_path, map_location="cpu")
            self.frcnn.load_state_dict(state_dict)
            self.frcnn.eval()

    def forward(
        self,
        x: torch.Tensor,
        sizes: torch.Tensor = None,
        scales_yx: torch.Tensor = None,
        padding: torch.Tensor = None,
        max_detections: int = 0,
        return_tensors: str = "pt",
    ):
        """Run the wrapped ``GeneralizedRCNN`` and return its output as-is.

        All arguments are forwarded unchanged: ``sizes`` positionally, the
        rest by keyword. Their meaning is defined by ``GeneralizedRCNN``
        (presumably ``x`` is a batch of preprocessed images and
        ``sizes``/``scales_yx``/``padding`` describe that preprocessing —
        TODO confirm against GeneralizedRCNN).
        """
        return self.frcnn(
            x,
            sizes,
            scales_yx=scales_yx,
            padding=padding,
            max_detections=max_detections,
            return_tensors=return_tensors,
        )