def __init__(self, num_classes: int, encoder: torch.nn.Module,
             freeze_encoder: bool, class_weights: Optional[torch.Tensor]):
    super().__init__()
    self.num_classes = num_classes
    self.encoder = encoder
    self.freeze_encoder = freeze_encoder
    self.class_weights = class_weights
    self.encoder.eval()
    self.classifier_head = SSLEvaluator(n_input=get_encoder_output_dim(self.encoder),
                                        n_hidden=None,
                                        n_classes=num_classes,
                                        p=0.20)
    if self.num_classes == 2:
        self.train_metrics = ModuleList([
            AreaUnderRocCurve(),
            AreaUnderPrecisionRecallCurve(),
            Accuracy05()
        ])
        self.val_metrics = ModuleList([
            AreaUnderRocCurve(),
            AreaUnderPrecisionRecallCurve(),
            Accuracy05()
        ])
    else:
        # Note that for multi-class, Accuracy05 is the standard multi-class accuracy.
        self.train_metrics = ModuleList([Accuracy05()])
        self.val_metrics = ModuleList([Accuracy05()])
Example 2
def __init__(self, backbone, in_features, num_classes, hidden_dim=1024):
    """
    Args:
        backbone: a pretrained model
        in_features: feature dim of backbone outputs
        num_classes: classes of the dataset
        hidden_dim: dim of the MLP (1024 default used in self-supervised literature)
    """
    super().__init__()
    self.backbone = backbone
    self.ft_network = SSLEvaluator(n_input=in_features,
                                   n_classes=num_classes,
                                   p=0.2,
                                   n_hidden=hidden_dim)
Example 3
    def __init__(self,
                 backbone,
                 in_features,
                 num_classes,
                 hidden_dim=1024,
                 **kwargs):
        super().__init__()
        self.save_hyperparameters(kwargs)

        self.backbone = backbone
        self.ft_network = SSLEvaluator(n_input=in_features,
                                       n_classes=num_classes,
                                       p=0.2,
                                       n_hidden=hidden_dim)
        self.criterion = nn.MultiLabelSoftMarginLoss()
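The MultiLabelSoftMarginLoss used above expects multi-hot targets with the same shape as the logits. A minimal illustrative check (batch size and label count are arbitrary placeholders):

import torch
from torch import nn

criterion = nn.MultiLabelSoftMarginLoss()
logits = torch.randn(4, 20)                       # batch of 4 samples, 20 candidate labels
targets = torch.randint(0, 2, (4, 20)).float()    # multi-hot target vectors
loss = criterion(logits, targets)
print(loss.item())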
Example 4
    def __init__(self, backbone, in_features, num_classes, hidden_dim=1024):
        """
        Finetunes a self-supervised learning backbone using the standard evaluation protocol of a single-layer MLP
        with 1024 units

        Example::

            from pl_bolts.utils.self_supervised import SSLFineTuner
            from pl_bolts.models.self_supervised import CPCV2
            from pl_bolts.datamodules import CIFAR10DataModule
            from pl_bolts.models.self_supervised.cpc.transforms import (CPCEvalTransformsCIFAR10,
                                                                        CPCTrainTransformsCIFAR10)

            # pretrained model
            backbone = CPCV2.load_from_checkpoint(PATH, strict=False)

            # dataset + transforms
            dm = CIFAR10DataModule(data_dir='.')
            dm.train_transforms = CPCTrainTransformsCIFAR10()
            dm.val_transforms = CPCEvalTransformsCIFAR10()

            # finetuner
            finetuner = SSLFineTuner(backbone, in_features=backbone.z_dim, num_classes=backbone.num_classes)

            # train
            trainer = pl.Trainer()
            trainer.fit(finetuner, dm)

            # test
            trainer.test(datamodule=dm)

        Args:
            backbone: a pretrained model
            in_features: feature dim of backbone outputs
            num_classes: classes of the dataset
            hidden_dim: dim of the MLP (1024 default used in self-supervised literature)
        """
        super().__init__()
        self.backbone = backbone
        self.ft_network = SSLEvaluator(
            n_input=in_features,
            n_classes=num_classes,
            p=0.2,
            n_hidden=hidden_dim
        )
    def __init__(self,
                 backbone: torch.nn.Module,
                 in_features: int = 2048,
                 num_classes: int = 1000,
                 epochs: int = 100,
                 hidden_dim: Optional[int] = None,
                 dropout: float = 0.,
                 learning_rate: float = 0.1,
                 weight_decay: float = 1e-6,
                 nesterov: bool = False,
                 scheduler_type: str = 'cosine',
                 decay_epochs: List = [60, 80],
                 gamma: float = 0.1,
                 final_lr: float = 0.,
                 fix_backbone: bool = True):
        """
        Args:
            backbone: a pretrained model
            in_features: feature dim of backbone outputs
            num_classes: classes of the dataset
            hidden_dim: dim of the MLP (1024 default used in self-supervised literature)
        """
        super().__init__()

        self.learning_rate = learning_rate
        self.nesterov = nesterov
        self.weight_decay = weight_decay
        self.fix_backbone = fix_backbone

        self.scheduler_type = scheduler_type
        self.decay_epochs = decay_epochs
        self.gamma = gamma
        self.epochs = epochs
        self.final_lr = final_lr

        self.backbone = backbone
        self.linear_layer = SSLEvaluator(n_input=in_features,
                                         n_classes=num_classes,
                                         p=dropout,
                                         n_hidden=hidden_dim)

        # metrics
        self.train_acc = Accuracy()
        self.val_acc = Accuracy(compute_on_step=False)
        self.test_acc = Accuracy(compute_on_step=False)
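The fragment above only stores the fix_backbone flag; the rest of the class is not shown here. A minimal sketch of how such a flag could be wired, under the assumption that the class follows the same structure as the other examples on this page (an illustration, not the original repository's code):

    def on_train_epoch_start(self) -> None:
        # keep a fixed backbone in eval mode so BatchNorm/Dropout statistics stay frozen
        if self.fix_backbone:
            self.backbone.eval()

    def configure_optimizers(self):
        # optimize only the evaluation head when the backbone is fixed
        params = list(self.linear_layer.parameters())
        if not self.fix_backbone:
            params += list(self.backbone.parameters())
        return torch.optim.SGD(params,
                               lr=self.learning_rate,
                               momentum=0.9,
                               nesterov=self.nesterov,
                               weight_decay=self.weight_decay)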
class SSLFineTuner(pl.LightningModule):
    """
    Finetunes a self-supervised learning backbone using the standard evaluation protocol of a single-layer MLP
    with 1024 units

    Example::

        from pl_bolts.utils.self_supervised import SSLFineTuner
        from pl_bolts.models.self_supervised import CPCV2
        from pl_bolts.datamodules import CIFAR10DataModule
        from pl_bolts.models.self_supervised.cpc.transforms import (CPCEvalTransformsCIFAR10,
                                                                    CPCTrainTransformsCIFAR10)

        # pretrained model
        backbone = CPCV2.load_from_checkpoint(PATH, strict=False)

        # dataset + transforms
        dm = CIFAR10DataModule(data_dir='.')
        dm.train_transforms = CPCTrainTransformsCIFAR10()
        dm.val_transforms = CPCEvalTransformsCIFAR10()

        # finetuner
        finetuner = SSLFineTuner(backbone, in_features=backbone.z_dim, num_classes=backbone.num_classes)

        # train
        trainer = pl.Trainer()
        trainer.fit(finetuner, dm)

        # test
        trainer.test(datamodule=dm)
    """
    def __init__(self,
                 backbone: torch.nn.Module,
                 in_features: int = 2048,
                 num_classes: int = 1000,
                 epochs: int = 100,
                 hidden_dim: Optional[int] = None,
                 dropout: float = 0.,
                 learning_rate: float = 0.1,
                 weight_decay: float = 1e-6,
                 nesterov: bool = False,
                 scheduler_type: str = 'cosine',
                 decay_epochs: List = [60, 80],
                 gamma: float = 0.1,
                 final_lr: float = 0.):
        """
        Args:
            backbone: a pretrained model
            in_features: feature dim of backbone outputs
            num_classes: classes of the dataset
            hidden_dim: dim of the MLP (1024 default used in self-supervised literature)
        """
        super().__init__()

        self.learning_rate = learning_rate
        self.nesterov = nesterov
        self.weight_decay = weight_decay

        self.scheduler_type = scheduler_type
        self.decay_epochs = decay_epochs
        self.gamma = gamma
        self.epochs = epochs
        self.final_lr = final_lr

        self.backbone = backbone
        self.linear_layer = SSLEvaluator(n_input=in_features,
                                         n_classes=num_classes,
                                         p=dropout,
                                         n_hidden=hidden_dim)

        # metrics
        self.train_acc = Accuracy()
        self.val_acc = Accuracy(compute_on_step=False)
        self.test_acc = Accuracy(compute_on_step=False)

    def on_train_epoch_start(self) -> None:
        self.backbone.eval()

    def training_step(self, batch, batch_idx):
        loss, logits, y = self.shared_step(batch)
        acc = self.train_acc(logits, y)

        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc_step', acc, prog_bar=True)
        self.log('train_acc_epoch', self.train_acc)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, logits, y = self.shared_step(batch)
        self.val_acc(logits, y)

        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        self.log('val_acc', self.val_acc)

        return loss

    def test_step(self, batch, batch_idx):
        loss, logits, y = self.shared_step(batch)
        self.test_acc(logits, y)

        self.log('test_loss', loss, sync_dist=True)
        self.log('test_acc', self.test_acc)

        return loss

    def shared_step(self, batch):
        x, y = batch

        with torch.no_grad():
            feats = self.backbone(x)

        feats = feats.view(feats.size(0), -1)
        logits = self.linear_layer(feats)
        loss = F.cross_entropy(logits, y)

        return loss, logits, y

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.linear_layer.parameters(),
            lr=self.learning_rate,
            nesterov=self.nesterov,
            momentum=0.9,
            weight_decay=self.weight_decay,
        )

        # set scheduler
        if self.scheduler_type == "step":
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                             self.decay_epochs,
                                                             gamma=self.gamma)
        elif self.scheduler_type == "cosine":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                self.epochs,
                eta_min=self.final_lr  # total epochs to run
            )

        return [optimizer], [scheduler]
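A hedged usage sketch for this variant, adapted from the docstring example above; PATH is a placeholder checkpoint path and the scheduler-related arguments are shown explicitly:

import pytorch_lightning as pl
from pl_bolts.models.self_supervised import CPCV2
from pl_bolts.datamodules import CIFAR10DataModule

# pretrained model (placeholder checkpoint path)
backbone = CPCV2.load_from_checkpoint(PATH, strict=False)

# fine-tuner with an explicit schedule: cosine annealing over 100 epochs
finetuner = SSLFineTuner(backbone,
                         in_features=backbone.z_dim,
                         num_classes=10,
                         epochs=100,
                         scheduler_type='cosine')

dm = CIFAR10DataModule(data_dir='.')
trainer = pl.Trainer(max_epochs=100)
trainer.fit(finetuner, dm)
trainer.test(datamodule=dm)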
class SSLFineTuner(pl.LightningModule):
    def __init__(self, backbone, in_features, num_classes, hidden_dim=1024):
        """
        Finetunes a self-supervised learning backbone using the standard evaluation protocol of a single-layer MLP
        with 1024 units

        Example::

            from pl_bolts.utils.self_supervised import SSLFineTuner
            from pl_bolts.models.self_supervised import CPCV2
            from pl_bolts.datamodules import CIFAR10DataModule
            from pl_bolts.models.self_supervised.cpc.transforms import (CPCEvalTransformsCIFAR10,
                                                                        CPCTrainTransformsCIFAR10)

            # pretrained model
            backbone = CPCV2.load_from_checkpoint(PATH, strict=False)

            # dataset + transforms
            dm = CIFAR10DataModule(data_dir='.')
            dm.train_transforms = CPCTrainTransformsCIFAR10()
            dm.val_transforms = CPCEvalTransformsCIFAR10()

            # finetuner
            finetuner = SSLFineTuner(backbone, in_features=backbone.z_dim, num_classes=backbone.num_classes)

            # train
            trainer = pl.Trainer()
            trainer.fit(finetuner, dm)

            # test
            trainer.test(datamodule=dm)

        Args:
            backbone: a pretrained model
            in_features: feature dim of backbone outputs
            num_classes: classes of the dataset
            hidden_dim: dim of the MLP (1024 default used in self-supervised literature)
        """
        super().__init__()
        self.backbone = backbone
        self.ft_network = SSLEvaluator(n_input=in_features,
                                       n_classes=num_classes,
                                       p=0.2,
                                       n_hidden=hidden_dim)

    def on_train_epoch_start(self) -> None:
        self.backbone.eval()

    def training_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        result = pl.TrainResult(loss)
        result.log('train_acc', acc, prog_bar=True)
        return result

    def validation_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        result = pl.EvalResult(checkpoint_on=loss, early_stop_on=loss)
        result.log_dict({'val_acc': acc, 'val_loss': loss}, prog_bar=True)
        return result

    def test_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        result = pl.EvalResult()
        result.log_dict({'test_acc': acc, 'test_loss': loss})
        return result

    def shared_step(self, batch):
        x, y = batch

        with torch.no_grad():
            feats = self.backbone(x)
        feats = feats.view(x.size(0), -1)
        logits = self.ft_network(feats)
        loss = F.cross_entropy(logits, y)
        acc = plm.accuracy(logits, y)

        return loss, acc

    def configure_optimizers(self):
        return torch.optim.Adam(self.ft_network.parameters(), lr=0.0002)
Example 8
class SSLFineTuner(pl.LightningModule):
    def __init__(self,
                 backbone: torch.nn.Module,
                 in_features: int = 2048,
                 num_classes: int = 1000,
                 epochs: int = 100,
                 hidden_dim: Optional[int] = None,
                 dropout: float = 0.1,
                 learning_rate: float = 1e-3,
                 weight_decay: float = 1e-6,
                 nesterov: bool = False,
                 scheduler_type: str = 'cosine',
                 decay_epochs: List = [60, 80],
                 gamma: float = 0.1,
                 final_lr: float = 0.):
        """
        Args:
            backbone: a pretrained model
            in_features: feature dim of backbone outputs
            num_classes: classes of the dataset
            hidden_dim: dim of the MLP (1024 default used in self-supervised literature)
        """
        super().__init__()

        self.learning_rate = learning_rate
        self.nesterov = nesterov
        self.weight_decay = weight_decay

        self.scheduler_type = scheduler_type
        self.decay_epochs = decay_epochs
        self.gamma = gamma
        self.epochs = epochs
        self.final_lr = final_lr

        self.backbone = backbone
        self.linear_layer = SSLEvaluator(n_input=in_features,
                                         n_classes=num_classes,
                                         p=dropout,
                                         n_hidden=hidden_dim)

        # metrics
        self.train_acc = Accuracy()
        self.val_acc = Accuracy(compute_on_step=False)
        self.test_acc = Accuracy(compute_on_step=False)

    def on_train_epoch_start(self) -> None:
        # Unlike the frozen variants above, the backbone stays in train mode here, so BatchNorm statistics
        # keep updating even though shared_step runs the backbone under torch.no_grad().
        self.backbone.train()

    def training_step(self, batch, batch_idx):
        loss, logits, y = self.shared_step(batch)
        acc = self.train_acc(logits, y)

        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc_step', acc, prog_bar=True)
        self.log('train_acc_epoch', self.train_acc)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, logits, y = self.shared_step(batch)
        acc = self.val_acc(logits, y)

        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        self.log('val_acc', self.val_acc)

        return loss

    def test_step(self, batch, batch_idx):
        loss, logits, y = self.shared_step(batch)
        acc = self.test_acc(logits, y)

        self.log('test_loss', loss, sync_dist=True)
        self.log('test_acc', self.test_acc)

        return loss

    def shared_step(self, batch):
        x, y = batch

        with torch.no_grad():
            feats = self.backbone(x)

        feats = feats.view(feats.size(0), -1)
        logits = self.linear_layer(feats)
        loss = F.cross_entropy(logits, y)

        return loss, logits, y

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            list(self.backbone.parameters()) +
            list(self.linear_layer.parameters()),
            lr=self.learning_rate,
            nesterov=self.nesterov,
            momentum=0.9,
            weight_decay=self.weight_decay,
        )

        # set scheduler
        if self.scheduler_type == "step":
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                             self.decay_epochs,
                                                             gamma=self.gamma)
        elif self.scheduler_type == "cosine":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                self.epochs,
                eta_min=self.final_lr  # total epochs to run
            )

        return [optimizer], [scheduler]
Example 9
class SSLFineTuner(LightningModule):
    def __init__(self,
                 backbone,
                 in_features,
                 num_classes,
                 hidden_dim=1024,
                 **kwargs):
        super().__init__()
        self.save_hyperparameters(kwargs)

        self.backbone = backbone
        self.ft_network = SSLEvaluator(n_input=in_features,
                                       n_classes=num_classes,
                                       p=0.2,
                                       n_hidden=hidden_dim)
        self.criterion = nn.MultiLabelSoftMarginLoss()

    def forward(self, x):
        with torch.set_grad_enabled(not self.hparams.freeze_backbone):
            feats = self.backbone(x)
        logits = self.ft_network(feats)
        return logits

    def on_train_epoch_start(self) -> None:
        self.backbone.eval()

    def training_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        self.log_dict({'acc/train': acc, 'loss/train': loss}, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        self.log_dict({'acc/val': acc, 'loss/val': loss}, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        self.log_dict({'acc/test': acc, 'loss/test': loss})
        return loss

    def shared_step(self, batch):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        # Micro-averaged average precision from sklearn, scaled to percent; logged under the 'acc/...' keys.
        acc = average_precision_score(y.cpu(),
                                      torch.sigmoid(logits).detach().cpu(),
                                      average='micro') * 100.0
        return loss, acc

    def configure_optimizers(self):
        params = self.ft_network.parameters()
        if not self.hparams.freeze_backbone:
            params = chain(self.backbone.parameters(), params)
        optimizer = optim.Adam(params, lr=self.hparams.learning_rate)
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.hparams.milestones)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--num_workers', type=int, default=32)
        parser.add_argument('--batch_size', type=int, default=256)
        parser.add_argument('--learning_rate', type=float, default=1e-5)
        parser.add_argument('--freeze_backbone', action='store_true')
        parser.add_argument('--milestones',
                            type=int,
                            nargs='*',
                            default=[60, 80])
        return parser
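A minimal sketch of how the add_model_specific_args hook above could be wired into a training script; my_backbone is a placeholder for a pretrained encoder, and the extra keyword arguments end up in self.hparams via save_hyperparameters(kwargs):

from argparse import ArgumentParser

parser = ArgumentParser()
parser = SSLFineTuner.add_model_specific_args(parser)
args = parser.parse_args()

finetuner = SSLFineTuner(backbone=my_backbone,   # placeholder pretrained encoder
                         in_features=2048,
                         num_classes=20,
                         learning_rate=args.learning_rate,
                         freeze_backbone=args.freeze_backbone,
                         milestones=args.milestones)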
class SSLClassifier(LightningModuleWithOptimizer, DeviceAwareModule):
    """
    SSL Image classifier that combines pre-trained SSL encoder with a trainable linear-head.
    """
    def __init__(self, num_classes: int, encoder: torch.nn.Module,
                 freeze_encoder: bool, class_weights: Optional[torch.Tensor]):
        super().__init__()
        self.num_classes = num_classes
        self.encoder = encoder
        self.freeze_encoder = freeze_encoder
        self.class_weights = class_weights
        self.encoder.eval()
        self.classifier_head = SSLEvaluator(n_input=get_encoder_output_dim(
            self.encoder),
                                            n_hidden=None,
                                            n_classes=num_classes,
                                            p=0.20)
        if self.num_classes == 2:
            self.train_metrics = ModuleList([
                AreaUnderRocCurve(),
                AreaUnderPrecisionRecallCurve(),
                Accuracy05()
            ])
            self.val_metrics = ModuleList([
                AreaUnderRocCurve(),
                AreaUnderPrecisionRecallCurve(),
                Accuracy05()
            ])
        else:
            # Note that for multi-class, Accuracy05 is the standard multi-class accuracy.
            self.train_metrics = ModuleList([Accuracy05()])
            self.val_metrics = ModuleList([Accuracy05()])

    def train(self, mode: bool = True) -> Any:
        self.classifier_head.train(mode)
        if self.freeze_encoder:
            return self
        self.encoder.train(mode)
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        if self.freeze_encoder:
            with torch.no_grad():
                agg_repr = self.encoder(x).flatten(1).detach()
        else:
            agg_repr = self.encoder(x).flatten(1)
        return self.classifier_head(agg_repr)

    def shared_step(self, batch: Any, is_training: bool) -> Any:
        _, x, y = batch
        mlp_preds = self.forward(x)
        weights = None if self.class_weights is None else self.class_weights.to(
            device=self.device)
        mlp_loss = F.cross_entropy(mlp_preds, y, weight=weights)

        with torch.no_grad():
            posteriors = F.softmax(mlp_preds, dim=-1)
            for metric in (self.train_metrics
                           if is_training else self.val_metrics):
                metric(posteriors, y)
        return mlp_loss

    def training_step(self, batch: Any, batch_id: int, *args: Any,
                      **kwargs: Any) -> Any:  # type: ignore
        loss = self.shared_step(batch, True)
        log_on_epoch(self, "train/loss", loss)
        for metric in self.train_metrics:
            log_on_epoch(self, f"train/{metric.name}", metric)
        return loss

    def validation_step(self, batch: Any, batch_id: int, *args: Any,
                        **kwargs: Any) -> None:  # type: ignore
        loss = self.shared_step(batch, is_training=False)
        log_on_epoch(self, 'val/loss', loss)
        for metric in self.val_metrics:
            log_on_epoch(self, f"val/{metric.name}", metric)

    def get_input_tensors(self, item: ScalarItem) -> List[torch.Tensor]:
        """
        Not used for the CXRImageClassifier container. This is only needed if we use this model within an InnerEyeContainer.
        """
        return [item.images]
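A hedged instantiation sketch for SSLClassifier; the encoder below is a stand-in (the original code loads a pre-trained SSL encoder), and get_encoder_output_dim is assumed to infer the feature size from a dummy forward pass:

import torch
from torch import nn

# stand-in encoder: any module whose flattened output has a fixed feature dimension
encoder = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
                        nn.ReLU(),
                        nn.AdaptiveAvgPool2d(1))

classifier = SSLClassifier(num_classes=2,
                           encoder=encoder,
                           freeze_encoder=True,
                           class_weights=None)
logits = classifier(torch.randn(4, 3, 224, 224))  # -> shape [4, 2]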
Example 11
def __init__(self, **kwargs: Any) -> None:
    super().__init__(**kwargs)
    # num_classes is expected to come from the enclosing scope of the original source (not shown here).
    self.non_linear_evaluator = SSLEvaluator(
        n_input=get_encoder_output_dim(self),
        n_classes=num_classes,
        n_hidden=None)
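get_encoder_output_dim itself is not shown in these snippets. A minimal sketch of how such a helper could infer the feature dimension, assuming the encoder accepts a standard image batch (an illustration, not the library's implementation):

import torch

def infer_encoder_output_dim(encoder: torch.nn.Module, image_size: int = 224) -> int:
    """Run a dummy batch through the encoder and return the flattened feature size."""
    was_training = encoder.training
    encoder.eval()
    with torch.no_grad():
        dummy = torch.zeros(1, 3, image_size, image_size)
        out = encoder(dummy)
    encoder.train(was_training)
    return out.flatten(1).shape[1]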