Exemplo n.º 1
0
def test_shared_forward_step() -> None:
    """
    Run a single BYOL shared step on a random batch and check that the
    returned loss lies in the closed interval [-1, 1].
    """
    module_kwargs = {
        "num_samples": 16,
        "learning_rate": 1e-3,
        "batch_size": 4,
        "warmup_epochs": 10,
        "encoder_name": "resnet50",
    }
    model = BYOLInnerEye(**module_kwargs)

    # The batch holds the same random image tensor twice (as the two views)
    # together with one random label value per image.
    views = torch.rand((4, 3, 32, 32))
    targets = torch.rand((4, ))

    step_loss = model.shared_step(batch=([views, views], targets), batch_idx=0)
    assert torch.le(step_loss, 1.0)
    assert torch.ge(step_loss, -1.0)
def test_encoder_init(encoder_name: str) -> None:
    """
    Check that a BYOLInnerEye module can be constructed with the given encoder.

    :param encoder_name: Name of the encoder, forwarded unchanged to BYOLInnerEye.
    """
    init_kwargs = {
        "num_samples": 16,
        "learning_rate": 1e-3,
        "batch_size": 4,
        "warmup_epochs": 10,
        "encoder_name": encoder_name,
        "max_epochs": 100,
    }
    # The test passes as long as construction does not raise.
    BYOLInnerEye(**init_kwargs)
Exemplo n.º 3
0
def test_update_tau() -> None:
    """
    Check the tau value returned by ByolMovingAverageWeightUpdate.update_tau.

    The module's global_step is patched to 15, and the returned tau is compared
    against the cosine expression
    1 - 0.01 * (cos(pi * 15 / total_steps) + 1) / 2,
    where total_steps is the number of batches per epoch times the number of epochs.
    """
    class DummyRSNADataset(RSNAKaggleCXR):
        """Dataset stub whose items are two random image tensors plus a random integer label."""

        def __getitem__(self, item: Any) -> Any:
            return (torch.rand([3, 224, 224], dtype=torch.float32),
                    torch.rand([3, 224, 224], dtype=torch.float32)), \
                   randint(0, 1)

    dataset_dir = str(path_to_test_dataset)
    dummy_rsna_train_dataloader: DataLoader = torch.utils.data.DataLoader(
        DummyRSNADataset(root=dataset_dir, return_index=False, train=True),
        batch_size=20,
        num_workers=0,
        drop_last=True)

    byol_weight_update = ByolMovingAverageWeightUpdate(initial_tau=0.99)
    trainer = Trainer(max_epochs=5)
    # The dataloader is attached directly to the trainer; no training is run in this test.
    trainer.train_dataloader = dummy_rsna_train_dataloader
    n_steps_per_epoch = len(trainer.train_dataloader)
    total_steps = n_steps_per_epoch * trainer.max_epochs  # type: ignore
    byol_module = BYOLInnerEye(num_samples=16,
                               learning_rate=1e-3,
                               batch_size=4,
                               encoder_name="resnet50",
                               warmup_epochs=10)
    # Pretend that 15 optimisation steps have already been taken.
    with mock.patch(
            "InnerEye.ML.SSL.lightning_modules.byol.byol_module.BYOLInnerEye.global_step",
            15):
        new_tau = byol_weight_update.update_tau(pl_module=byol_module,
                                                trainer=trainer)
    expected_tau = 1 - 0.01 * (math.cos(math.pi * 15 / total_steps) + 1) / 2
    # Compare with a tolerance: (1 - 0.99) is not bit-identical to the literal 0.01,
    # so exact float equality is fragile here.
    assert math.isclose(new_tau, expected_tau), \
        f"Expected tau close to {expected_tau}, but got {new_tau}"
Exemplo n.º 4
0
def test_module_param_eq() -> None:
    """
    Check that, directly after construction, the online and target networks of
    a BYOLInnerEye module have element-wise equal parameters.
    """
    model = BYOLInnerEye(num_samples=16,
                         learning_rate=1e-3,
                         batch_size=4,
                         encoder_name="resnet50",
                         warmup_epochs=10)
    paired_params = zip(model.online_network.parameters(),
                        model.target_network.parameters())
    for online_param, target_param in paired_params:
        assert torch.all(online_param == target_param)
Exemplo n.º 5
0
def test_output_spatial_pooling() -> None:
    """
    Check the shape of the embeddings returned by calling a BYOLInnerEye module
    on a batch of images: one row per image, 2048 features per row.
    """
    module_kwargs = {
        "num_samples": 16,
        "learning_rate": 1e-3,
        "batch_size": 4,
        "warmup_epochs": 10,
        "encoder_name": "resnet50",
    }
    model = BYOLInnerEye(**module_kwargs)
    images = torch.rand((4, 3, 32, 32))

    output_shape = model(images).size()
    n_rows, n_features = output_shape[0], output_shape[1]

    assert n_rows == 4
    # 2048 is the embedding size this test expects for the "resnet50" encoder.
    assert n_features == 2048
 def create_model(self) -> LightningModule:
     """
     Build the Lightning model to train, either SimCLR or BYOL depending on
     `self.ssl_training_type`.

     The model's hyperparameters are extended with the SSL type and the number of
     classes, and `self.encoder_output_dim` is set from the created model.

     :return: The created SimCLRInnerEye or BYOLInnerEye module.
     :raises ValueError: If `self.ssl_training_type` is neither SimCLR nor BYOL.
     """
     # Small images like CIFAR get a 3x3 first conv layer in a resnet encoder,
     # everything else keeps the 7x7 one.
     dataset_name = self.ssl_training_dataset_name.value
     use_7x7_first_conv_in_resnet = not dataset_name.startswith("CIFAR")

     model: LightningModule
     if self.ssl_training_type == SSLTrainingType.SimCLR:
         model = SimCLRInnerEye(
             encoder_name=self.ssl_encoder.value,
             dataset_name=self.ssl_training_dataset_name.value,
             use_7x7_first_conv_in_resnet=use_7x7_first_conv_in_resnet,
             num_samples=self.data_module.num_train_samples,
             batch_size=self.data_module.batch_size,
             gpus=self.num_gpus_per_node(),
             num_nodes=self.num_nodes,
             learning_rate=self.l_rate,
             max_epochs=self.num_epochs)
         logging.info(
             f"LR scheduling is using train_iters_per_epoch = {model.train_iters_per_epoch}"
         )
     elif self.ssl_training_type == SSLTrainingType.BYOL:
         model = BYOLInnerEye(
             encoder_name=self.ssl_encoder.value,
             num_samples=self.data_module.num_train_samples,
             batch_size=self.data_module.batch_size,
             learning_rate=self.l_rate,
             use_7x7_first_conv_in_resnet=use_7x7_first_conv_in_resnet,
             warmup_epochs=10,
             max_epochs=self.num_epochs)
     else:
         raise ValueError(
             f"Unknown value for ssl_training_type, should be {SSLTrainingType.SimCLR.value} or "
             f"{SSLTrainingType.BYOL.value}. "
             f"Found {self.ssl_training_type.value}")

     # Store the SSL type and class count alongside the model's own hyperparameters.
     extra_hparams = {
         'ssl_type': self.ssl_training_type.value,
         "num_classes": self.data_module.num_classes
     }
     model.hparams.update(extra_hparams)
     self.encoder_output_dim = get_encoder_output_dim(
         model, self.data_module)
     return model
Exemplo n.º 7
0
def create_ssl_image_classifier(
    num_classes: int,
    freeze_encoder: bool,
    pl_checkpoint_path: str,
    class_weights: Optional[torch.Tensor] = None
) -> LightningModuleWithOptimizer:
    """
    Creates an SSL image classifier on top of an encoder that was trained in an
    unsupervised manner (BYOL or SimCLR) and stored in a Lightning checkpoint.

    The SSL type is read from the "ssl_type" entry of the checkpoint's
    "hyper_parameters". For BYOL, the encoder of the target network is used;
    for SimCLR, the module's encoder is used.

    :param num_classes: Passed through to SSLClassifier.
    :param freeze_encoder: Passed through to SSLClassifier.
    :param pl_checkpoint_path: Path of the checkpoint file to load.
    :param class_weights: Passed through to SSLClassifier.
    :return: The SSLClassifier built around the loaded encoder.
    :raises NotImplementedError: If the stored SSL type is neither BYOL nor SimCLR.
    """

    # Use local imports to avoid circular imports
    from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye
    from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye
    from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier

    logging.info(f"Size of ckpt {Path(pl_checkpoint_path).stat().st_size}")
    # NOTE(review): torch.load unpickles the file - only use with trusted checkpoints.
    # Only the stored hyperparameters are kept from this first load.
    hyper_params = torch.load(
        pl_checkpoint_path,
        map_location=lambda storage, loc: storage)["hyper_parameters"]
    ssl_type = hyper_params["ssl_type"]

    logging.info(f"Creating a {ssl_type} based image classifier")
    logging.info(
        f"Loading pretrained {ssl_type} weights from:\n {pl_checkpoint_path}")

    # The stored SSL type is accepted either as the enum's value or as the enum member itself.
    if ssl_type in (SSLTrainingType.BYOL.value, SSLTrainingType.BYOL):
        encoder = BYOLInnerEye.load_from_checkpoint(pl_checkpoint_path).target_network.encoder
    elif ssl_type in (SSLTrainingType.SimCLR.value, SSLTrainingType.SimCLR):
        encoder = SimCLRInnerEye.load_from_checkpoint(pl_checkpoint_path).encoder
    else:
        raise NotImplementedError(f"Unknown unsupervised model: {ssl_type}")

    return SSLClassifier(num_classes=num_classes,
                         encoder=encoder,
                         freeze_encoder=freeze_encoder,
                         class_weights=class_weights)