예제 #1
0
def test_group_items_with_min_and_max_sequence_position_values() -> None:
    """
    Test grouping of sequence data when the allowed sequence positions are restricted.

    Calls `group_samples_into_sequences` with several combinations of
    `min_sequence_position_value` and `max_sequence_position_value`, and checks that
    repeated sequence positions for the same subject raise a ValueError.
    """
    items = [
        _create_item("a", 1, "a.1"),
        _create_item("a", 0, "a.0"),
        _create_item("a", 2, "a.2"),
        _create_item("b", 1, "b.1"),
        _create_item("b", 0, "b.0"),
    ]
    # When not providing a max_sequence_position_value, sequences of any length are OK.
    grouped = group_samples_into_sequences(items,
                                           max_sequence_position_value=None)
    assert len(grouped) == 2
    _assert_group(grouped[0], "a", ["a.0", "a.1", "a.2"])
    _assert_group(grouped[1], "b", ["b.0", "b.1"])
    # With both a min and a max position, only the items inside [min, max] are kept.
    # Subject "b" has no item at position 2 and is still returned, with "b.1" only.
    grouped = group_samples_into_sequences(items,
                                           min_sequence_position_value=1,
                                           max_sequence_position_value=2)
    assert len(grouped) == 2
    _assert_group(grouped[0], "a", ["a.1", "a.2"])
    _assert_group(grouped[1], "b", ["b.1"])
    # When a max position is given, the sequence will be truncated to at most contain the given value.
    grouped = group_samples_into_sequences(items,
                                           min_sequence_position_value=0,
                                           max_sequence_position_value=1)
    assert len(grouped) == 2
    _assert_group(grouped[0], "a", ["a.0", "a.1"])
    _assert_group(grouped[1], "b", ["b.0", "b.1"])
    # min == max selects exactly one position per subject.
    grouped = group_samples_into_sequences(items,
                                           min_sequence_position_value=1,
                                           max_sequence_position_value=1)
    assert len(grouped) == 2
    _assert_group(grouped[0], "a", ["a.1"])
    _assert_group(grouped[1], "b", ["b.1"])
    # Allow sequences up to max_sequence_position_value=2
    grouped = group_samples_into_sequences(items,
                                           min_sequence_position_value=1,
                                           max_sequence_position_value=2)
    assert len(grouped) == 2
    _assert_group(grouped[0], "a", ["a.1", "a.2"])
    _assert_group(grouped[1], "b", ["b.1"])

    # No item has sequence position 3, so requiring a minimum position of 3 returns nothing.
    grouped = group_samples_into_sequences(items,
                                           min_sequence_position_value=3)
    assert len(grouped) == 0
    # A maximum position of 3 is above every position present, so both subjects are returned.
    grouped = group_samples_into_sequences(items,
                                           max_sequence_position_value=3)
    assert len(grouped) == 2

    # Sequence positions must be unique
    with pytest.raises(ValueError) as ex:
        group_samples_into_sequences([_create_item("a", 0, "a.0")] * 2)
    # Check the message of the raised exception itself (ex.value), not the string form
    # of the pytest ExceptionInfo wrapper, whose format is a pytest implementation detail.
    assert "contains duplicates" in str(ex.value)
예제 #2
0
def test_load_items_seq_from_dataset() -> None:
    """
    Load the sequence classification test dataset (numerical features, categorical
    features and image files), check a selection of the loaded items, and then check
    how the valid items are grouped into sequences.
    """
    csv_path = full_ml_test_data_path() / "sequence_data_for_classification" / "dataset.csv"
    data_frame = pd.read_csv(csv_path, sep=",", dtype=str)
    reader = DataSourceReader[SequenceDataSource](
        data_frame=data_frame,
        image_channels=None,
        image_file_column="IMG",
        label_channels=None,
        label_value_column="Label",
        numerical_columns=["NUM1", "NUM2", "NUM3", "NUM4"],
        sequence_column="Position")
    items: List[SequenceDataSource] = reader.load_data_sources()
    # 3 subjects with 9 visits each, nothing missing.
    assert len(items) == 3 * 9

    # First visit of the first subject.
    first_item = items[0]
    assert first_item.metadata.id == "2137.00005"
    assert first_item.metadata.sequence_position == 0
    assert first_item.metadata.props["CAT2"] == "category_A"
    # The label is missing for this row; missing labels should be encoded as NaN.
    assert math.isnan(first_item.label[0])
    assert first_item.channel_files == ["img_1"]
    # NaN never compares equal to itself, hence the comparison of string representations.
    expected_features = [362.0, np.nan, np.nan, 71.0]
    assert str(first_item.numerical_non_image_features.tolist()) == str(expected_features)

    # Last visit of the first subject.
    item_8 = items[8]
    assert item_8.metadata.id == "2137.00005"
    assert item_8.metadata.sequence_position == 8
    assert item_8.label.tolist() == [0.0]
    assert item_8.channel_files == [""]
    expected_features = [350.0, np.nan, np.nan, 8.0]
    assert str(item_8.numerical_non_image_features.tolist()) == str(expected_features)

    item_16 = items[16]
    assert item_16.metadata.id == "2627.00001"
    assert item_16.label.tolist() == [0.0]
    assert item_16.channel_files == ["img_2"]
    assert_tensors_equal(item_16.numerical_non_image_features, [217.0, 0.0, 0.01, 153.0])

    # Very last item in the dataset.
    last_item = items[26]
    assert last_item.metadata.id == "3250.00005"
    assert last_item.metadata.sequence_position == 8
    assert_tensors_equal(last_item.label, [0.0])
    assert last_item.channel_files == ["img_11"]
    assert_tensors_equal(last_item.numerical_non_image_features, [238.0, 0.0, 0.02, 84.0])

    valid_items = filter_valid_classification_data_sources_items(
        items,
        file_to_path_mapping=None,
        max_sequence_position_value=None)
    grouped = group_samples_into_sequences(valid_items)
    # There are 3 patients total, but one of them has missing measurements for all visits
    assert len(grouped) == 2
    first_group = grouped[0]
    second_group = grouped[1]
    assert first_group.id == "2627.00001"
    assert second_group.id == "3250.00005"
    # 2627.00001 has full information for weeks 0, 4, and 8
    assert len(first_group.items) == 3
    assert first_group.items[0].metadata["VISIT"] == "V1"
    assert first_group.items[2].metadata["VISIT"] == "VST 3"
    assert len(second_group.items) == 9
    assert item_16.metadata.sequence_position == 7
예제 #3
0
def test_group_items_with_label_positions() -> None:
    """
    Test grouping of items that were created with an extra fourth argument, when only
    the sequence positions 2 to 3 are requested.
    """
    # NOTE(review): the fourth value is presumably the label position (going by the
    # test name) - confirm against _create_item.
    item_specs = [
        ("a", 0, "a.0", 1),
        ("a", 3, "a.3", math.inf),
        ("a", 1, "a.1", 0),
        ("a", 2, "a.2", 1),
    ]
    items = [_create_item(*spec) for spec in item_specs]

    # Extract the sequence that covers positions 2 to 3.
    sequences = group_samples_into_sequences(
        items,
        min_sequence_position_value=2,
        max_sequence_position_value=3)
    assert len(sequences) == 1
    _assert_group(sequences[0], "a", ["a.2", "a.3"])
예제 #4
0
def test_group_items() -> None:
    """
    Check that samples are grouped into per-subject sequences, and that items or
    subjects that do not form a sequence are filtered out.
    """
    def _make_source(subject_id: str, position: int, channel_file: Optional[str],
                     tag: str) -> SequenceDataSource:
        """Build a data source with empty features and label; `tag` is stored in props["M"]."""
        sample_metadata = GeneralSampleMetadata(id=subject_id,
                                                sequence_position=position,
                                                props={"M": tag})
        empty = torch.tensor([])
        return SequenceDataSource(
            channel_files=[channel_file],
            numerical_non_image_features=empty,
            categorical_non_image_features=torch.tensor([]),
            label=torch.tensor([]),
            metadata=sample_metadata)

    def _check_sequence(sequence: ClassificationItemSequence, subject: str,
                        tags: List[str]) -> None:
        """Assert that `sequence` belongs to `subject` and holds items with the given tags, in order."""
        assert isinstance(sequence, ClassificationItemSequence)
        assert sequence.id == subject
        found_tags = [entry.metadata.props["M"] for entry in sequence.items]
        assert found_tags == tags

    # (subject, sequence position, channel file, tag), deliberately not sorted by position.
    specs = [
        ("a", 1, "f", "a.1"),
        ("a", 0, "f", "a.0"),
        ("a", 4, "f", "a.4"),
        ("b", 1, None, "b.1"),
        ("b", 0, None, "b.0"),
        ("c", 0, "f", "c.0"),
        ("d", 1, "f", "d.1"),
    ]
    items = [_make_source(*spec) for spec in specs]
    grouped = group_samples_into_sequences(items)
    # Only subjects a, b and c produce a sequence; subject d (position 1 only) does not.
    assert len(grouped) == 3

    expected = [
        # For subject a, item a.4 should be dropped because the consecutive sequence is only [0, 1]
        ("a", ["a.0", "a.1"]),
        ("b", ["b.0", "b.1"]),
        ("c", ["c.0"]),
    ]
    for sequence, (subject, tags) in zip(grouped, expected):
        _check_sequence(sequence, subject, tags)