def test_load_items_classification_versus_regression(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test loading file paths and labels from a dataframe with different configurations:
    classification-style labels must be rejected by a regression loss and vice versa,
    while matching loss/label combinations load successfully.
    """
    csv_string_classification = """USUBJID,week,path,value,scalar1,scalar2,categorical1,categorical2
S1,image,foo.nii
S1,label,,True,1.1,1.2,True,False
S2,image,bar.nii
S2,label,,False,2.1,2.2,False,True
"""
    csv_string_regression = """USUBJID,week,path,value,scalar1,scalar2,categorical1,categorical2
S1,image,foo.nii
S1,label,,1,1.1,1.2,Male,True
S2,image,bar.nii
S2,label,,-1.3,2.1,2.2,Female,True
"""

    def make_dataset(csv_contents: str, loss: ScalarLoss):
        # Write the CSV to the test output folder and build a dataset from it.
        csv_file = create_dataset_csv_file(csv_contents, test_output_dirs.root_dir)
        return _create_test_dataset(csv_file, scalar_loss=loss)

    # Mismatched loss/label combinations must be rejected.
    with pytest.raises(ValueError):
        make_dataset(csv_string_regression, ScalarLoss.BinaryCrossEntropyWithLogits)
    with pytest.raises(ValueError):
        make_dataset(csv_string_classification, ScalarLoss.MeanSquaredError)
    # Matching combinations load both subjects.
    classification_dataset = make_dataset(csv_string_classification,
                                          ScalarLoss.BinaryCrossEntropyWithLogits)
    assert len(classification_dataset.items) == 2
    regression_dataset = make_dataset(csv_string_regression, ScalarLoss.MeanSquaredError)
    assert len(regression_dataset.items) == 2
def test_load_items(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test loading file paths, labels and non-image features from a dataframe,
    including standardization of numerical features and one-hot encoding of
    categorical features.
    """
    csv_string = """USUBJID,week,path,value,scalar1,scalar2,categorical1,categorical2
S1,image,foo.nii
S1,label,,True,1.1,1.2,A1,A2
S2,image,bar.nii
S2,label,,False,2.1,2.2,B1,A2
"""
    csv_file = create_dataset_csv_file(csv_string, test_output_dirs.root_dir)
    dataset = _create_test_dataset(csv_file,
                                   categorical_columns=["categorical1", "categorical2"])
    items = dataset.items
    item0 = items[0]
    item1 = items[1]
    # Subject S1: positive label, standardized scalar features.
    assert isinstance(item0.metadata, GeneralSampleMetadata)
    assert item0.metadata.id == "S1"
    assert item0.label.tolist() == [1.0]
    assert item0.channel_files == ["foo.nii"]
    assert item0.numerical_non_image_features.shape == (2,)
    assert item0.numerical_non_image_features.tolist() == pytest.approx(
        [-0.7071067094802856, -0.7071067690849304])
    # One-hot encoding of the two categorical columns.
    assert item0.categorical_non_image_features.tolist() == [1.0, 0.0, 1.0]
    assert item1.categorical_non_image_features.tolist() == [0.0, 1.0, 1.0]
    # Subject S2: negative label, standardized scalar features.
    assert isinstance(item1.metadata, GeneralSampleMetadata)
    assert item1.metadata.id == "S2"
    assert item1.label.tolist() == [0.0]
    assert item1.channel_files == ["bar.nii"]
    assert item1.numerical_non_image_features.shape == (2,)
    assert item1.numerical_non_image_features.tolist() == pytest.approx(
        [0.7071068286895752, 0.7071067690849304])
def test_categorical_and_numerical_columns_are_mutually_exclusive(test_output_dirs: OutputFolderForTests) -> None:
    """
    Declaring a numerical (scalar) column as categorical must raise a ValueError.
    """
    csv_string = """USUBJID,week,path,value,scalar1,categorical1
S1,image,foo.nii
S1,label,,True,1.1,False
"""
    csv_file = create_dataset_csv_file(csv_string, test_output_dirs.root_dir)
    with pytest.raises(ValueError):
        _create_test_dataset(csv_file, categorical_columns=["scalar1"])
def test_sequence_dataset_all(test_output_dirs: OutputFolderForTests) -> None:
    """
    Check that the sequence dataset works end-to-end: the feature statistics are
    computed on the training split only, and that standardization is applied to
    the training, validation and test splits alike.
    """
    csv_string = """subject,seq,value,scalar1,scalar2,META,BETA
S1,0,False,0,0,M1,B1
S1,1,True,1,10,M2,B2
S2,0,False,2,20,M2,B1
S3,0,True,3,30,M1,B1
S4,0,True,4,40,M2,B1
"""
    csv_path = create_dataset_csv_file(csv_string, test_output_dirs.root_dir)
    config = SequenceModelBase(local_dataset=csv_path,
                               image_file_column=None,
                               label_value_column="value",
                               numerical_columns=["scalar1", "scalar2"],
                               sequence_target_positions=[0],
                               categorical_columns=["META", "BETA"],
                               sequence_column="seq",
                               num_dataload_workers=0,
                               train_batch_size=2,
                               should_validate=False,
                               shuffle=False)
    config.read_dataset_if_needed()
    full_df = config.dataset_data_frame
    assert full_df is not None
    # Fixed splits: S1/S2 train, S3 val, S4 test.
    train_df = full_df[full_df.subject.isin(["S1", "S2"])]
    val_df = full_df[full_df.subject == "S3"]
    test_df = full_df[full_df.subject == "S4"]
    splits = DatasetSplits(train=train_df, val=val_df, test=test_df)
    with mock.patch.object(SequenceModelBase,
                           'get_model_train_test_dataset_splits',
                           return_value=splits):
        train_val_loaders = config.create_data_loaders()
        # Expected feature mean: Mean of the training data (0, 0), (1, 10), (2, 20) = (1, 10)
        # Expected (biased corrected) std estimate: Std of (0, 0), (1, 10), (2, 20) = (1, 10)
        stats = config.get_torch_dataset_for_inference(
            ModelExecutionMode.TRAIN).feature_statistics
        assert stats is not None
        assert_tensors_equal(stats.mean, [1, 10])
        assert_tensors_equal(stats.std, [1, 10])
        train_batches = [ClassificationItemSequence.from_minibatch(batch)
                         for batch in train_val_loaders[ModelExecutionMode.TRAIN]]
        assert len(train_batches) == 1, "2 items in training set with batch size of 2 should return 1 minibatch"
        assert len(train_batches[0]) == 2
        first_sequence = train_batches[0][0]
        second_sequence = train_batches[0][1]
        assert first_sequence.id == "S1"
        assert_tensors_equal(first_sequence.items[0].get_all_non_imaging_features(),
                             [-1., -1., 1., 0., 1., 0.])
        assert_tensors_equal(first_sequence.items[1].get_all_non_imaging_features(),
                             [0., 0., 0., 1., 0., 1.])
        assert second_sequence.id == "S2"
        assert_tensors_equal(second_sequence.items[0].get_all_non_imaging_features(),
                             [1., 1., 0., 1., 1., 0.])
        val_batches = [ClassificationItemSequence.from_minibatch(batch)
                       for batch in train_val_loaders[ModelExecutionMode.VAL]]
        assert len(val_batches) == 1
        assert len(val_batches[0]) == 1
        assert val_batches[0][0].id == "S3"
        # Items in the validation set should be normalized using the mean and std on the training data.
        # Hence, the non-image features (3, 30) should turn into (2, 2)
        assert_tensors_equal(val_batches[0][0].items[0].get_all_non_imaging_features(),
                             [2., 2., 1., 0., 1., 0.])
        # Check that the test set is also normalized correctly using the training mean and std.
        test_items = [ClassificationItemSequence(**batch)
                      for batch in config.get_torch_dataset_for_inference(ModelExecutionMode.TEST)]
        assert test_items[0].id == "S4"
        # Check Non-image features of (4, 40)
        assert_tensors_equal(test_items[0].items[0].get_all_non_imaging_features(),
                             [3., 3., 0., 1., 1., 0.])