def test_shared_forward_step() -> None:
    """
    Run BYOLInnerEye.shared_step on a small random batch and check that the returned loss lies in [-1, 1].
    """
    model = BYOLInnerEye(num_samples=16,
                         learning_rate=1e-3,
                         batch_size=4,
                         warmup_epochs=10,
                         encoder_name="resnet50")
    images = torch.rand((4, 3, 32, 32))
    targets = torch.rand((4, ))
    # The same tensor is passed as both entries of the view list.
    two_views = [images, images]
    loss = model.shared_step(batch=(two_views, targets), batch_idx=0)
    assert torch.le(loss, 1.0)
    assert torch.ge(loss, -1.0)
def test_encoder_init(encoder_name: str) -> None:
    """
    Check that BYOLInnerEye can be constructed with the given encoder; the test passes if no exception is raised.

    :param encoder_name: Name of the encoder to build (presumably supplied by a pytest parametrize decorator
        that is not visible here — TODO confirm).
    """
    shared_kwargs = dict(num_samples=16,
                         learning_rate=1e-3,
                         batch_size=4,
                         warmup_epochs=10,
                         max_epochs=100)
    BYOLInnerEye(encoder_name=encoder_name, **shared_kwargs)
def test_update_tau() -> None:
    """
    Check that ByolMovingAverageWeightUpdate.update_tau returns
    1 - (1 - initial_tau) * (cos(pi * step / total_steps) + 1) / 2, evaluated at global step 15,
    where total_steps = (batches per epoch) * max_epochs.
    """
    class DummyRSNADataset(RSNAKaggleCXR):
        # Returns a pair of random 3x224x224 tensors plus a random 0/1 value instead of real dataset items.
        def __getitem__(self, item: Any) -> Any:
            image_pair = (torch.rand([3, 224, 224], dtype=torch.float32),
                          torch.rand([3, 224, 224], dtype=torch.float32))
            return image_pair, randint(0, 1)

    dataset_dir = str(path_to_test_dataset)
    dummy_rsna_train_dataloader: DataLoader = torch.utils.data.DataLoader(
        DummyRSNADataset(root=dataset_dir, return_index=False, train=True),
        batch_size=20,
        num_workers=0,
        drop_last=True)

    initial_tau = 0.99
    byol_weight_update = ByolMovingAverageWeightUpdate(initial_tau=initial_tau)
    trainer = Trainer(max_epochs=5)
    trainer.train_dataloader = dummy_rsna_train_dataloader
    n_steps_per_epoch = len(trainer.train_dataloader)
    total_steps = n_steps_per_epoch * trainer.max_epochs  # type: ignore
    byol_module = BYOLInnerEye(num_samples=16,
                               learning_rate=1e-3,
                               batch_size=4,
                               encoder_name="resnet50",
                               warmup_epochs=10)
    global_step = 15
    # Pin the module's global_step so the schedule is evaluated at a known step.
    with mock.patch("InnerEye.ML.SSL.lightning_modules.byol.byol_module.BYOLInnerEye.global_step", global_step):
        new_tau = byol_weight_update.update_tau(pl_module=byol_module, trainer=trainer)
    expected_tau = 1 - (1 - initial_tau) * (math.cos(math.pi * global_step / total_steps) + 1) / 2
    # Compare with a tolerance: (1 - 0.99) is not exactly 0.01 in binary floating point, so exact equality
    # depends on how the implementation orders its arithmetic.
    assert math.isclose(new_tau, expected_tau, rel_tol=1e-9)
def test_module_param_eq() -> None:
    """
    Check that a freshly constructed BYOLInnerEye has identical parameters (same count, shapes and values)
    in its online and target networks.
    """
    byol = BYOLInnerEye(num_samples=16,
                        learning_rate=1e-3,
                        batch_size=4,
                        encoder_name="resnet50",
                        warmup_epochs=10)
    online_params = list(byol.online_network.parameters())
    target_params = list(byol.target_network.parameters())
    # zip() stops at the shorter sequence, so a mismatch in parameter count would otherwise go unnoticed.
    assert len(online_params) == len(target_params)
    for online_par, target_par in zip(online_params, target_params):
        # torch.equal also requires identical shapes; torch.eq would broadcast mismatched shapes.
        assert torch.equal(online_par, target_par)
def test_output_spatial_pooling() -> None:
    """
    Forward a batch of 4 random images through BYOLInnerEye and check that the output has 4 rows and 2048
    columns (presumably the pooled ResNet-50 feature width — TODO confirm).
    """
    model = BYOLInnerEye(num_samples=16,
                         learning_rate=1e-3,
                         batch_size=4,
                         warmup_epochs=10,
                         encoder_name="resnet50")
    output = model(torch.rand((4, 3, 32, 32)))
    output_shape = output.size()
    n_rows = output_shape[0]
    n_features = output_shape[1]
    assert n_rows == 4
    assert n_features == 2048
def create_model(self) -> LightningModule:
    """
    This method must create the actual Lightning model that will be trained.

    Builds a SimCLRInnerEye or a BYOLInnerEye depending on ``self.ssl_training_type``. It then adds the SSL type
    and number of classes to the model's hyperparameters and stores the encoder output dimension in
    ``self.encoder_output_dim``.

    :return: The SSL LightningModule to train.
    :raises ValueError: If ``self.ssl_training_type`` is neither SimCLR nor BYOL.
    """
    # For small images like CIFAR, if using a resnet encoder, switch the first conv layer to a 3x3 kernel instead
    # of a 7x7 conv layer.
    use_7x7_first_conv_in_resnet = not self.ssl_training_dataset_name.value.startswith("CIFAR")
    if self.ssl_training_type == SSLTrainingType.SimCLR:
        model: LightningModule = SimCLRInnerEye(
            encoder_name=self.ssl_encoder.value,
            dataset_name=self.ssl_training_dataset_name.value,
            use_7x7_first_conv_in_resnet=use_7x7_first_conv_in_resnet,
            num_samples=self.data_module.num_train_samples,
            batch_size=self.data_module.batch_size,
            gpus=self.num_gpus_per_node(),
            num_nodes=self.num_nodes,
            learning_rate=self.l_rate,
            max_epochs=self.num_epochs)
        logging.info(
            f"LR scheduling is using train_iters_per_epoch = {model.train_iters_per_epoch}"
        )
    elif self.ssl_training_type == SSLTrainingType.BYOL:
        # NOTE(review): warmup_epochs is hard-coded to 10 here rather than read from config — confirm intended.
        model = BYOLInnerEye(
            encoder_name=self.ssl_encoder.value,
            num_samples=self.data_module.num_train_samples,
            batch_size=self.data_module.batch_size,
            learning_rate=self.l_rate,
            use_7x7_first_conv_in_resnet=use_7x7_first_conv_in_resnet,
            warmup_epochs=10,
            max_epochs=self.num_epochs)
    else:
        raise ValueError(
            f"Unknown value for ssl_training_type, should be {SSLTrainingType.SimCLR.value} or "
            f"{SSLTrainingType.BYOL.value}. "
            f"Found {self.ssl_training_type.value}")
    # 'ssl_type' is the key create_ssl_image_classifier looks up in a checkpoint's hyper_parameters.
    model.hparams.update({
        'ssl_type': self.ssl_training_type.value,
        "num_classes": self.data_module.num_classes
    })
    self.encoder_output_dim = get_encoder_output_dim(
        model, self.data_module)
    return model
def create_ssl_image_classifier(
        num_classes: int,
        freeze_encoder: bool,
        pl_checkpoint_path: str,
        class_weights: Optional[torch.Tensor] = None
) -> LightningModuleWithOptimizer:
    """
    Creates an SSL image classifier from an encoder trained in an unsupervised manner.

    The SSL method (BYOL or SimCLR) is read from the checkpoint's ``hyper_parameters["ssl_type"]``. The matching
    module is then loaded from the same checkpoint, and its encoder is wrapped in an SSLClassifier.

    :param num_classes: Number of output classes of the classifier.
    :param freeze_encoder: Passed through to SSLClassifier as ``freeze_encoder``.
    :param pl_checkpoint_path: Path to a PyTorch Lightning checkpoint produced by SSL training.
    :param class_weights: Optional class weights, passed through to SSLClassifier.
    :return: The SSLClassifier built around the pretrained encoder.
    :raises NotImplementedError: If the checkpoint's ``ssl_type`` is neither BYOL nor SimCLR.
    """
    # Use local imports to avoid circular imports
    from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye
    from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye
    from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier

    logging.info(f"Size of ckpt {Path(pl_checkpoint_path).stat().st_size}")
    # torch.load unpickles the file, so only load checkpoints from trusted sources.
    # The map_location lambda keeps every tensor on the storage it was deserialized to, i.e. the CPU.
    loaded_params = torch.load(
        pl_checkpoint_path,
        map_location=lambda storage, loc: storage)["hyper_parameters"]
    ssl_type = loaded_params["ssl_type"]

    logging.info(f"Creating a {ssl_type} based image classifier")
    logging.info(
        f"Loading pretrained {ssl_type} weights from:\n {pl_checkpoint_path}")

    # ssl_type may be stored either as the enum's value or as the enum member itself; accept both.
    if ssl_type in (SSLTrainingType.BYOL.value, SSLTrainingType.BYOL):
        byol_module = BYOLInnerEye.load_from_checkpoint(pl_checkpoint_path)
        encoder = byol_module.target_network.encoder
    elif ssl_type in (SSLTrainingType.SimCLR.value, SSLTrainingType.SimCLR):
        simclr_module = SimCLRInnerEye.load_from_checkpoint(pl_checkpoint_path)
        encoder = simclr_module.encoder
    else:
        raise NotImplementedError(f"Unknown unsupervised model: {ssl_type}")

    model = SSLClassifier(num_classes=num_classes,
                          encoder=encoder,
                          freeze_encoder=freeze_encoder,
                          class_weights=class_weights)

    return model