def test_group_items_with_min_and_max_sequence_position_values() -> None:
    """
    Test that grouping of sequence data respects min_sequence_position_value and max_sequence_position_value:
    only items whose sequence position lies in the inclusive range [min, max] are kept, sequences are not
    required to reach the max position, and subjects without any item in the range are dropped.
    """
    items = [
        _create_item("a", 1, "a.1"),
        _create_item("a", 0, "a.0"),
        _create_item("a", 2, "a.2"),
        _create_item("b", 1, "b.1"),
        _create_item("b", 0, "b.0"),
    ]
    # When not providing a max_sequence_position_value, sequences of any length are OK.
    grouped = group_samples_into_sequences(items, max_sequence_position_value=None)
    assert len(grouped) == 2
    _assert_group(grouped[0], "a", ["a.0", "a.1", "a.2"])
    _assert_group(grouped[1], "b", ["b.0", "b.1"])
    # When a max position is given, each sequence is truncated to positions up to and including that value.
    grouped = group_samples_into_sequences(items, min_sequence_position_value=0, max_sequence_position_value=1)
    assert len(grouped) == 2
    _assert_group(grouped[0], "a", ["a.0", "a.1"])
    _assert_group(grouped[1], "b", ["b.0", "b.1"])
    # min == max selects exactly one position per subject (the range is inclusive on both ends).
    grouped = group_samples_into_sequences(items, min_sequence_position_value=1, max_sequence_position_value=1)
    assert len(grouped) == 2
    _assert_group(grouped[0], "a", ["a.1"])
    _assert_group(grouped[1], "b", ["b.1"])
    # Sequences do not need to reach max_sequence_position_value: subject "b" has no item at position 2,
    # but is still returned with the items it has inside [1, 2].
    grouped = group_samples_into_sequences(items, min_sequence_position_value=1, max_sequence_position_value=2)
    assert len(grouped) == 2
    _assert_group(grouped[0], "a", ["a.1", "a.2"])
    _assert_group(grouped[1], "b", ["b.1"])
    # There are no items with sequence position >= 3, hence no groups are returned.
    grouped = group_samples_into_sequences(items, min_sequence_position_value=3)
    assert len(grouped) == 0
    # A max_sequence_position_value beyond the largest available position (2) still keeps both subjects.
    grouped = group_samples_into_sequences(items, max_sequence_position_value=3)
    assert len(grouped) == 2
    # Sequence positions must be unique within a subject.
    with pytest.raises(ValueError) as ex:
        group_samples_into_sequences([_create_item("a", 0, "a.0")] * 2)
    # ex.value is the raised exception; str(ex) would stringify the pytest ExceptionInfo wrapper instead.
    assert "contains duplicates" in str(ex.value)
def test_load_items_seq_from_dataset() -> None:
    """
    Load the sequence classification test dataset (numerical features, image file names and labels),
    check a few individual data sources, and check how valid items are grouped into per-subject sequences.
    """
    csv_path = full_ml_test_data_path() / "sequence_data_for_classification" / "dataset.csv"
    data_frame = pd.read_csv(csv_path, sep=",", dtype=str)
    reader = DataSourceReader[SequenceDataSource](
        data_frame=data_frame,
        image_channels=None,
        image_file_column="IMG",
        label_channels=None,
        label_value_column="Label",
        numerical_columns=["NUM1", "NUM2", "NUM3", "NUM4"],
        sequence_column="Position")
    items: List[SequenceDataSource] = reader.load_data_sources()
    # 3 subjects with 9 visits each, no visit rows missing from the file.
    assert len(items) == 3 * 9

    first = items[0]
    assert first.metadata.id == "2137.00005"
    assert first.metadata.sequence_position == 0
    assert first.metadata.props["CAT2"] == "category_A"
    # This visit's label is missing; missing labels should be encoded as NaN.
    assert math.isnan(first.label[0])
    assert first.channel_files == ["img_1"]
    # NaN != NaN, so the feature lists are compared through their string representation.
    assert str(first.numerical_non_image_features.tolist()) == str([362.0, np.nan, np.nan, 71.0])

    ninth = items[8]
    assert ninth.metadata.id == "2137.00005"
    assert ninth.metadata.sequence_position == 8
    assert ninth.label.tolist() == [0.0]
    assert ninth.channel_files == ['']
    assert str(ninth.numerical_non_image_features.tolist()) == str([350.0, np.nan, np.nan, 8.0])

    seventeenth = items[16]
    assert seventeenth.metadata.id == "2627.00001"
    assert seventeenth.metadata.sequence_position == 7
    assert seventeenth.label.tolist() == [0.0]
    assert seventeenth.channel_files == ["img_2"]
    assert_tensors_equal(seventeenth.numerical_non_image_features, [217.0, 0.0, 0.01, 153.0])

    last = items[26]
    assert last.metadata.id == "3250.00005"
    assert last.metadata.sequence_position == 8
    assert_tensors_equal(last.label, [0.0])
    assert last.channel_files == ["img_11"]
    assert_tensors_equal(last.numerical_non_image_features, [238.0, 0.0, 0.02, 84.0])

    valid_items = filter_valid_classification_data_sources_items(items,
                                                                  file_to_path_mapping=None,
                                                                  max_sequence_position_value=None)
    grouped = group_samples_into_sequences(valid_items)
    # There are 3 patients total, but one of them has missing measurements for all visits.
    assert len(grouped) == 2
    assert grouped[0].id == "2627.00001"
    assert grouped[1].id == "3250.00005"
    # 2627.00001 has full information for weeks 0, 4, and 8
    first_sequence = grouped[0].items
    assert len(first_sequence) == 3
    assert first_sequence[0].metadata["VISIT"] == "V1"
    assert first_sequence[2].metadata["VISIT"] == "VST 3"
    assert len(grouped[1].items) == 9
def test_group_items_with_label_positions() -> None:
    """
    Test extracting a sub-range of sequence positions from items created with an extra per-item value
    (the fourth argument to _create_item, assumed to be the label), one of which is math.inf.
    """
    item_specs = [
        ("a", 0, "a.0", 1),
        ("a", 3, "a.3", math.inf),
        ("a", 1, "a.1", 0),
        ("a", 2, "a.2", 1),
    ]
    items = [_create_item(*spec) for spec in item_specs]
    # Keep only sequence positions 2 to 3 (both inclusive).
    grouped = group_samples_into_sequences(items,
                                           min_sequence_position_value=2,
                                           max_sequence_position_value=3)
    assert len(grouped) == 1
    _assert_group(grouped[0], "a", ["a.2", "a.3"])
def test_group_items() -> None:
    """
    Test that grouping sequence data sources by subject returns items ordered by sequence position,
    and drops items and subjects that do not form a consecutive sequence starting at position 0.
    """

    def _make_source(subject_id: str, position: int, file: Optional[str], marker: str) -> SequenceDataSource:
        # Builds a data source with empty feature/label tensors; "M" in props identifies the item in assertions.
        return SequenceDataSource(
            channel_files=[file],
            numerical_non_image_features=torch.tensor([]),
            categorical_non_image_features=torch.tensor([]),
            label=torch.tensor([]),
            metadata=GeneralSampleMetadata(id=subject_id, sequence_position=position, props={"M": marker}))

    def _check_group(group: ClassificationItemSequence, subject: str, props: List[str]) -> None:
        # Verifies the group's subject and the ordered "M" markers of its items.
        assert isinstance(group, ClassificationItemSequence)
        assert group.id == subject
        assert [member.metadata.props["M"] for member in group.items] == props

    source_specs = [
        ("a", 1, "f", "a.1"),
        ("a", 0, "f", "a.0"),
        ("a", 4, "f", "a.4"),
        ("b", 1, None, "b.1"),
        ("b", 0, None, "b.0"),
        ("c", 0, "f", "c.0"),
        ("d", 1, "f", "d.1"),
    ]
    items = [_make_source(*spec) for spec in source_specs]
    grouped = group_samples_into_sequences(items)
    # Subject "d" (only an item at position 1) does not appear in the result.
    assert len(grouped) == 3
    # For subject a, item a.4 should be dropped because the consecutive sequence is only [0, 1]
    _check_group(grouped[0], "a", ["a.0", "a.1"])
    _check_group(grouped[1], "b", ["b.0", "b.1"])
    _check_group(grouped[2], "c", ["c.0"])