Exemplo n.º 1
0
    def __init__(
        self,
        num_classes,
        backbone="resnet18",
        num_features: int = None,
        pretrained=True,
        loss_fn: Callable = F.cross_entropy,
        optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
        metrics: Union[Callable, Mapping, Sequence, None] = (Accuracy()),
        learning_rate: float = 1e-3,
    ):
        """Build a classifier from a feature-extracting backbone and a linear head.

        Args:
            num_classes: Number of output units of the final ``nn.Linear`` layer.
            backbone: Backbone identifier forwarded to ``backbone_and_num_features``.
            num_features: Accepted but not used to build the model; the head's
                input width is the value reported by ``backbone_and_num_features``.
            pretrained: Forwarded (by keyword) to ``backbone_and_num_features``.
            loss_fn: Forwarded unchanged to the parent constructor.
            optimizer: Optimizer class forwarded unchanged to the parent constructor.
            metrics: Forwarded unchanged to the parent constructor.
                NOTE(review): the default ``(Accuracy())`` is a single instance,
                not a one-element tuple, and it is evaluated once when this
                function is defined, so every object built with the default
                shares that same ``Accuracy`` object.
            learning_rate: Forwarded unchanged to the parent constructor.
        """
        super().__init__(
            model=None,
            loss_fn=loss_fn,
            optimizer=optimizer,
            metrics=metrics,
            learning_rate=learning_rate,
        )

        self.save_hyperparameters()

        # Pass ``pretrained`` by keyword. ``backbone_and_num_features`` is also
        # called with ``fpn`` as its second positional parameter, so passing
        # ``pretrained`` positionally could bind it to the wrong parameter.
        self.backbone, backbone_out_features = backbone_and_num_features(
            backbone, pretrained=pretrained)

        # Pool the backbone feature map to 1x1, flatten, and project to class
        # logits. Assumes the backbone returns a 4-D (N, C, H, W) map; TODO confirm.
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(backbone_out_features, num_classes),
        )
Exemplo n.º 2
0
    def __init__(
        self,
        embedding_dim: Optional[int] = None,
        backbone: str = "swav-imagenet",
        pretrained: bool = True,
        loss_fn: Callable = F.cross_entropy,
        optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
        metrics: Union[Callable, Mapping, Sequence, None] = (Accuracy()),
        learning_rate: float = 1e-3,
        pooling_fn: Callable = torch.max
    ):
        """Build an embedding model from a backbone and an optional linear projection.

        Args:
            embedding_dim: If None, the head is ``nn.Identity``. Otherwise the
                flattened backbone output is projected to ``embedding_dim`` units
                and a warning to finetune first is emitted.
            backbone: Backbone identifier forwarded to ``backbone_and_num_features``;
                also stored as ``self.backbone_name``.
            pretrained: Forwarded (by keyword) to ``backbone_and_num_features``.
            loss_fn: Forwarded unchanged to the parent constructor.
            optimizer: Optimizer class forwarded unchanged to the parent constructor.
            metrics: Forwarded unchanged to the parent constructor.
                NOTE(review): the default ``(Accuracy())`` is one instance created
                at definition time and shared by every object using the default.
            learning_rate: Forwarded unchanged to the parent constructor.
            pooling_fn: Must be ``torch.mean`` or ``torch.max``. It is stored as
                ``self.pooling_fn`` and is not used inside this constructor.

        Raises:
            AssertionError: If ``pooling_fn`` is neither ``torch.mean`` nor ``torch.max``.
        """
        super().__init__(
            model=None,
            loss_fn=loss_fn,
            optimizer=optimizer,
            metrics=metrics,
            learning_rate=learning_rate,
        )

        self.save_hyperparameters()
        self.backbone_name = backbone
        self.embedding_dim = embedding_dim
        # NOTE(review): ``assert`` is stripped under ``python -O``. It is kept so
        # callers still see ``AssertionError``; consider an explicit raise.
        assert pooling_fn in [torch.mean, torch.max]
        self.pooling_fn = pooling_fn

        # Pass ``pretrained`` by keyword. ``backbone_and_num_features`` is also
        # called with ``fpn`` as its second positional parameter, so passing
        # ``pretrained`` positionally could bind it to the wrong parameter.
        self.backbone, num_features = backbone_and_num_features(backbone, pretrained=pretrained)

        if embedding_dim is None:
            # No projection: the backbone output passes through unchanged.
            self.head = nn.Identity()
        else:
            self.head = nn.Sequential(
                nn.Flatten(),
                nn.Linear(num_features, embedding_dim),
            )
            rank_zero_warn('embedding_dim is not None. Remember to finetune first!')
Exemplo n.º 3
0
def test_backbone_and_num_features(backbone, expected_num_features):
    """A backbone built without pretrained weights or FPN reports the expected width."""
    model, reported_features = backbone_and_num_features(
        model_name=backbone,
        fpn=False,
        pretrained=False,
    )

    # A backbone module must come back, along with the right feature count.
    assert model
    assert reported_features == expected_num_features
Exemplo n.º 4
0
    def get_model(
        model_name,
        num_classes,
        backbone,
        fpn,
        pretrained,
        pretrained_backbone,
        trainable_backbone_layers,
        anchor_generator,
        **kwargs,
    ):
        """Construct a Faster R-CNN or RetinaNet detection model.

        Args:
            model_name: ``"fasterrcnn"`` selects Faster R-CNN; any other value
                builds a RetinaNet. When ``backbone`` is None it is also the key
                into ``_models``.
            num_classes: Number of classes for the (replaced) prediction head.
            backbone: Backbone name for ``backbone_and_num_features``, or None to
                use the stock constructor from ``_models`` (described as a
                ResNet-50-FPN model).
            fpn: Forwarded to ``backbone_and_num_features`` (custom backbone only).
            pretrained: Forwarded to the ``_models`` constructor (backbone None only).
            pretrained_backbone: Forwarded to the ``_models`` constructor or to
                ``backbone_and_num_features``.
            trainable_backbone_layers: Forwarded to whichever call builds the backbone.
            anchor_generator: Anchor generator for the custom-backbone path. If
                None, a single-feature-map generator is built, unless the backbone
                has an ``fpn`` attribute; in that case None is passed through.
            **kwargs: Forwarded to ``RetinaNetHead`` (backbone None, RetinaNet) or
                to ``backbone_and_num_features`` (custom backbone). They are not
                used for Faster R-CNN when backbone is None.

        Returns:
            The constructed detection model.
        """
        if backbone is None:
            # Stock ResNet-50-FPN model; only the prediction head is resized.
            if model_name == "fasterrcnn":
                model = _models[model_name](
                    pretrained=pretrained,
                    pretrained_backbone=pretrained_backbone,
                    trainable_backbone_layers=trainable_backbone_layers,
                )
                in_features = model.roi_heads.box_predictor.cls_score.in_features
                model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
            else:
                # Forward ``trainable_backbone_layers`` here as well. Previously it
                # was silently ignored on the RetinaNet path.
                model = _models[model_name](
                    pretrained=pretrained,
                    pretrained_backbone=pretrained_backbone,
                    trainable_backbone_layers=trainable_backbone_layers,
                )
                model.head = RetinaNetHead(
                    in_channels=model.backbone.out_channels,
                    num_anchors=model.head.classification_head.num_anchors,
                    num_classes=num_classes,
                    **kwargs
                )
            return model

        backbone_model, num_features = backbone_and_num_features(
            backbone,
            fpn,
            pretrained_backbone,
            trainable_backbone_layers,
            **kwargs,
        )
        # The torchvision detector classes read the channel count from the
        # backbone's ``out_channels`` attribute.
        backbone_model.out_channels = num_features

        if anchor_generator is None and not hasattr(backbone_model, "fpn"):
            # One size tuple and one aspect-ratio tuple: a single feature map.
            anchor_generator = AnchorGenerator(
                sizes=((32, 64, 128, 256, 512), ), aspect_ratios=((0.5, 1.0, 2.0), )
            )

        if model_name == "fasterrcnn":
            return FasterRCNN(backbone_model, num_classes=num_classes, rpn_anchor_generator=anchor_generator)
        return RetinaNet(backbone_model, num_classes=num_classes, anchor_generator=anchor_generator)