def _load_config(self) -> None: # For Chest-XRay you need to specify the parameters of the augmentations via a config file. self.ssl_augmentation_params = load_yaml_augmentation_config( self.ssl_augmentation_config) if self.ssl_augmentation_config is not None \ else None self.classifier_augmentation_params = load_yaml_augmentation_config( self.linear_head_augmentation_config) if self.linear_head_augmentation_config is not None else \ self.ssl_augmentation_params
def get_image_transform(self) -> ModelTransformsPerExecutionMode:
    """Build the per-execution-mode image transforms from the linear-head CXR config.

    The augmentation config is loaded from ``path_linear_head_augmentation_cxr``. Each
    pipeline is a ``Compose`` of ``DicomPreparation()`` followed by the transforms that
    ``create_transforms_from_config`` builds from that config. Only the training pipeline
    passes ``apply_augmentations=True``. The test mode reuses the validation pipeline
    object.
    """
    config = load_yaml_augmentation_config(path_linear_head_augmentation_cxr)

    def _pipeline(augment: bool) -> Compose:
        # DicomPreparation always runs first, then the config-driven transforms.
        return Compose([DicomPreparation(), create_transforms_from_config(config, apply_augmentations=augment)])

    train_transforms = _pipeline(augment=True)
    val_transforms = _pipeline(augment=False)
    return ModelTransformsPerExecutionMode(train=train_transforms, val=val_transforms, test=val_transforms)
from pytorch_lightning.trainer.supporters import CombinedLoader from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path from InnerEye.ML.SSL.datamodules_and_datasets.cifar_datasets import InnerEyeCIFAR10 from InnerEye.ML.SSL.datamodules_and_datasets.cxr_datasets import RSNAKaggleCXR from InnerEye.ML.SSL.datamodules_and_datasets.datamodules import CombinedDataModule, InnerEyeVisionDataModule from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import InnerEyeCIFARLinearHeadTransform, \ InnerEyeCIFARTrainTransform, get_ssl_transforms_from_config from InnerEye.ML.SSL.lightning_containers.ssl_container import SSLContainer, SSLDatasetName from InnerEye.ML.SSL.utils import SSLDataModuleType, load_yaml_augmentation_config from InnerEye.ML.configs.ssl.CXR_SSL_configs import path_encoder_augmentation_cxr from Tests.SSL.test_ssl_containers import create_cxr_test_dataset path_to_test_dataset = full_ml_test_data_path("cxr_test_dataset") create_cxr_test_dataset(path_to_test_dataset) cxr_augmentation_config = load_yaml_augmentation_config(path_encoder_augmentation_cxr) def test_weights_innereye_module() -> None: """ Tests if weights in CXR data module are correctly initialized """ transforms = get_ssl_transforms_from_config(cxr_augmentation_config, return_two_views_per_sample=True) data_module = InnerEyeVisionDataModule(dataset_cls=RSNAKaggleCXR, return_index=False, train_transforms=transforms[0], val_transforms=transforms[1], data_dir=str(path_to_test_dataset), batch_size=1, seed=1,