Ejemplo n.º 1
0
    def test_should_create_multimodal_train_valid_test_datasets(self):
        """Building iSEG slice datasets from T1 + T2 should yield four Dataset objects."""
        # NOTE(review): presumably 0 is the dataset id and 0.2 the test split
        # fraction — confirm against the factory's positional signature.
        # Unpacking into exactly four names also checks the factory returns
        # a four-element sequence.
        train_split, valid_split, test_split, reconstruction_split = iSEGSliceDatasetFactory.create_train_valid_test(
            self.DATA_PATH, [Modality.T1, Modality.T2], 0, 0.2)

        # Every returned split (train, valid, test, reconstruction) must be a Dataset.
        for split in (train_split, valid_split, test_split, reconstruction_split):
            assert_that(split, instance_of(Dataset))
Ejemplo n.º 2
0
    def test_should_produce_a_single_modality_input_with_one_channel(self):
        """A T1-only dataset should yield single-channel 32x32x32 input patches.

        Slice 16 of the sampled input and of its label is also rendered for
        visual inspection.
        """
        # Only the training split is inspected; unpacking into four targets
        # still checks that the factory returns exactly four datasets.
        train_dataset, _, _, _ = iSEGSliceDatasetFactory.create_train_valid_test(
            self.DATA_PATH, Modality.T1, 0, 0.2)

        # NOTE(review): hard-coded index assumes the training split holds more
        # than 32000 patches for the test data — confirm.
        sample = train_dataset[32000]

        assert_that(sample.x.size(), is_(torch.Size([1, 32, 32, 32])))

        # Show input and label slices side by side, and always close the figure:
        # with a non-interactive backend plt.show() returns immediately and an
        # unclosed figure would otherwise stay alive for the rest of the session.
        fig, (input_axis, label_axis) = plt.subplots(1, 2)
        try:
            input_axis.imshow(sample.x[0, 16, :, :], cmap="gray")
            # Assumes the label has the same (channel, depth, H, W) layout as x,
            # as the original indexing did.
            label_axis.imshow(sample.y[0, 16, :, :], cmap="gray")
            plt.show()
        finally:
            plt.close(fig)
Ejemplo n.º 3
0
    # Build the model trainers from their configs, using the project's custom
    # model and criterion factories.
    model_trainer_factory = ModelTrainerFactory(
        model_factory=CustomModelFactory(),
        criterion_factory=CustomCriterionFactory())
    model_trainers = model_trainer_factory.create(model_trainer_configs)
    # create() may return a single trainer or a list; normalize to a list so
    # later code can treat both cases uniformly.
    if not isinstance(model_trainers, list):
        model_trainers = [model_trainers]

    # Create datasets: each configured dataset contributes one split to each of
    # train_datasets / valid_datasets / test_datasets / reconstruction_datasets
    # (presumably lists initialized earlier in this function — not visible here).
    if dataset_configs.get("iSEG", None) is not None:
        iSEG_train, iSEG_valid, iSEG_test, iSEG_reconstruction = iSEGSliceDatasetFactory.create_train_valid_test(
            source_dir=dataset_configs["iSEG"].path,
            modalities=dataset_configs["iSEG"].modalities,
            dataset_id=ISEG_ID,
            # NOTE(review): the configured validation_split is passed as the
            # factory's test_size — confirm this naming mismatch is intended.
            test_size=dataset_configs["iSEG"].validation_split,
            max_subjects=dataset_configs["iSEG"].max_subjects,
            max_num_patches=dataset_configs["iSEG"].max_num_patches,
            augment=dataset_configs["iSEG"].augment,
            patch_size=dataset_configs["iSEG"].patch_size,
            step=dataset_configs["iSEG"].step,
            # Separate patch size / step used for the test split.
            test_patch_size=dataset_configs["iSEG"].test_patch_size,
            test_step=dataset_configs["iSEG"].test_step,
            data_augmentation_config=data_augmentation_config)
        train_datasets.append(iSEG_train)
        valid_datasets.append(iSEG_valid)
        test_datasets.append(iSEG_test)
        reconstruction_datasets.append(iSEG_reconstruction)

    if dataset_configs.get("MRBrainS", None) is not None:
        MRBrainS_train, MRBrainS_valid, MRBrainS_test, MRBrainS_reconstruction = MRBrainSSliceDatasetFactory.create_train_valid_test(
            source_dir=dataset_configs["MRBrainS"].path,
            modalities=dataset_configs["MRBrainS"].modalities,