def test_encoder_init(encoder_name: str) -> None:
    """
    Smoke test: building a BYOLInnerEye model with the given encoder name must not raise.

    :param encoder_name: Encoder name forwarded to the BYOLInnerEye constructor
        (presumably supplied by a test parametrization outside this view).
    """
    # Collect the constructor arguments first so the call under test is a single, obvious line.
    init_kwargs = dict(num_samples=16,
                       learning_rate=1e-3,
                       batch_size=4,
                       warmup_epochs=10,
                       encoder_name=encoder_name,
                       max_epochs=100)
    BYOLInnerEye(**init_kwargs)
Exemple #2
0
def test_update_tau() -> None:
    """
    Check the tau value returned by ByolMovingAverageWeightUpdate.update_tau at a patched global step.

    A trainer with max_epochs=5 is given a dataloader over a dummy dataset; the total number of steps is
    the dataloader length times max_epochs. With initial_tau=0.99 and global_step patched to 15, the
    returned tau must match 1 - 0.01 * (cos(pi * 15 / total_steps) + 1) / 2.
    """
    class DummyRSNADataset(RSNAKaggleCXR):
        """Dataset stub whose items are random tensors rather than whatever the parent class returns."""

        def __getitem__(self, item: Any) -> Any:
            # A pair of random 3x224x224 float32 tensors plus a label drawn with randint(0, 1).
            return (torch.rand([3, 224, 224], dtype=torch.float32),
                    torch.rand([3, 224, 224], dtype=torch.float32)), \
                   randint(0, 1)

    dataset_dir = str(path_to_test_dataset)
    dummy_rsna_train_dataloader: DataLoader = torch.utils.data.DataLoader(
        DummyRSNADataset(root=dataset_dir, return_index=False, train=True),
        batch_size=20,
        num_workers=0,
        drop_last=True)

    byol_weight_update = ByolMovingAverageWeightUpdate(initial_tau=0.99)
    trainer = Trainer(max_epochs=5)
    trainer.train_dataloader = dummy_rsna_train_dataloader
    n_steps_per_epoch = len(trainer.train_dataloader)
    total_steps = n_steps_per_epoch * trainer.max_epochs  # type: ignore
    byol_module = BYOLInnerEye(num_samples=16,
                               learning_rate=1e-3,
                               batch_size=4,
                               encoder_name="resnet50",
                               warmup_epochs=10)
    # Replace BYOLInnerEye.global_step with the constant 15 for the duration of the call.
    with mock.patch(
            "InnerEye.ML.SSL.lightning_modules.byol.byol_module.BYOLInnerEye.global_step",
            15):
        new_tau = byol_weight_update.update_tau(pl_module=byol_module,
                                                trainer=trainer)
    expected_tau = 1 - 0.01 * (math.cos(math.pi * 15 / total_steps) + 1) / 2
    # Compare with a tolerance: both sides are floating-point results of trigonometric arithmetic, so
    # an exact == could fail on a last-bit rounding difference that does not indicate a real bug.
    assert math.isclose(new_tau, expected_tau)
Exemple #3
0
def test_module_param_eq() -> None:
    """
    Check that a freshly constructed BYOLInnerEye has identical parameters in its online and target networks.

    The parameters of `online_network` and `target_network` are compared pairwise, in the order
    returned by `parameters()`.
    """
    byol = BYOLInnerEye(num_samples=16,
                        learning_rate=1e-3,
                        batch_size=4,
                        encoder_name="resnet50",
                        warmup_epochs=10)
    online_params = list(byol.online_network.parameters())
    target_params = list(byol.target_network.parameters())
    # zip() stops at the shorter input, so a differing (or zero) number of parameters would otherwise
    # let the loop below pass without comparing everything.
    assert len(online_params) > 0
    assert len(online_params) == len(target_params)
    for online_par, target_par in zip(online_params, target_params):
        # torch.equal requires the same size as well as the same elements.
        assert torch.equal(online_par, target_par)
Exemple #4
0
def test_shared_forward_step() -> None:
    """
    Run one `shared_step` of BYOLInnerEye on random data and check that the loss lies in [-1, 1].

    The batch consists of the same random image tensor used for both views, plus random labels.
    """
    model = BYOLInnerEye(num_samples=16,
                         learning_rate=1e-3,
                         batch_size=4,
                         warmup_epochs=10,
                         encoder_name="resnet50")
    images = torch.rand((4, 3, 32, 32))
    labels = torch.rand((4, ))
    # Both views of the batch are the very same tensor object.
    views = [images, images]

    step_loss = model.shared_step(batch=(views, labels), batch_idx=0)
    upper_bound_ok = torch.le(step_loss, 1.0)
    assert upper_bound_ok
    lower_bound_ok = torch.ge(step_loss, -1.0)
    assert lower_bound_ok
Exemple #5
0
def test_output_spatial_pooling() -> None:
    """
    Check the output shape of a BYOLInnerEye forward pass with a resnet50 encoder.

    For a batch of 4 random 3x32x32 images, the first two output dimensions must be 4 and 2048.
    """
    model = BYOLInnerEye(num_samples=16,
                         learning_rate=1e-3,
                         batch_size=4,
                         warmup_epochs=10,
                         encoder_name="resnet50")
    images = torch.rand((4, 3, 32, 32))

    output_size = model(images).size()
    # Read both dimensions before asserting, so a malformed output fails on indexing first.
    n_items, n_features = output_size[0], output_size[1]

    assert n_items == 4
    assert n_features == 2048
 def create_model(self) -> LightningModule:
     """
     Build and return the Lightning model for the configured SSL training type.

     SimCLR and BYOL are supported. After construction, the SSL type and the number of classes are
     added to the model's hparams, and `self.encoder_output_dim` is set from `get_encoder_output_dim`.

     :return: The constructed SimCLRInnerEye or BYOLInnerEye model.
     :raises ValueError: If `ssl_training_type` is neither SimCLR nor BYOL.
     """
     dataset_name = self.ssl_training_dataset_name.value
     # Datasets whose name starts with "CIFAR" have small images: with a resnet encoder, the first conv
     # layer is then switched to a 3x3 kernel instead of a 7x7 one.
     keep_7x7_first_conv = not dataset_name.startswith("CIFAR")
     training_type = self.ssl_training_type
     if training_type == SSLTrainingType.SimCLR:
         model: LightningModule = SimCLRInnerEye(encoder_name=self.ssl_encoder.value,
                                                 dataset_name=dataset_name,
                                                 use_7x7_first_conv_in_resnet=keep_7x7_first_conv,
                                                 num_samples=self.data_module.num_train_samples,
                                                 batch_size=self.data_module.batch_size,
                                                 gpus=self.num_gpus_per_node(),
                                                 num_nodes=self.num_nodes,
                                                 learning_rate=self.l_rate,
                                                 max_epochs=self.num_epochs)
         logging.info(f"LR scheduling is using train_iters_per_epoch = {model.train_iters_per_epoch}")
     elif training_type == SSLTrainingType.BYOL:
         # The number of warmup epochs is fixed at 10 here.
         model = BYOLInnerEye(encoder_name=self.ssl_encoder.value,
                              num_samples=self.data_module.num_train_samples,
                              batch_size=self.data_module.batch_size,
                              learning_rate=self.l_rate,
                              use_7x7_first_conv_in_resnet=keep_7x7_first_conv,
                              warmup_epochs=10,
                              max_epochs=self.num_epochs)
     else:
         error_message = (f"Unknown value for ssl_training_type, should be {SSLTrainingType.SimCLR.value} or "
                          f"{SSLTrainingType.BYOL.value}. "
                          f"Found {training_type.value}")
         raise ValueError(error_message)
     # Record the SSL type and class count alongside the model's own hyperparameters.
     model.hparams.update({'ssl_type': training_type.value,
                           "num_classes": self.data_module.num_classes})
     self.encoder_output_dim = get_encoder_output_dim(model, self.data_module)
     return model