コード例 #1
0
def test_custom_collate() -> None:
    """
    Check that collate_with_metadata stacks plain numeric values into a tensor, while the
    per-sample metadata field is gathered into a plain Python list in input order.
    """
    patient_metadata = PatientMetadata(patient_id='42')
    key = "foo"
    first_sample = {key: 1, SAMPLE_METADATA_FIELD: "something"}
    second_sample = {key: 2, SAMPLE_METADATA_FIELD: patient_metadata}
    collated = collate_with_metadata([first_sample, second_sample])

    # Every key of the input dictionaries must survive collation.
    assert key in collated
    assert SAMPLE_METADATA_FIELD in collated

    # The metadata field is not turned into a tensor: one list entry per sample, in order.
    gathered_metadata = collated[SAMPLE_METADATA_FIELD]
    assert isinstance(gathered_metadata, list)
    assert gathered_metadata == ["something", patient_metadata]

    # The integer values are stacked into a single tensor.
    stacked_values = collated[key]
    assert isinstance(stacked_values, torch.Tensor)
    assert stacked_values.tolist() == [1, 2]
コード例 #2
0
def test_seq_dataset_loader() -> None:
    """
    Build a SequenceDataset from the sequence-classification test CSV and check each subject's
    sequence. Then verify that collate_with_metadata turns the sequences into a list of lists
    that can be passed back into ClassificationItemSequence.
    """
    csv_path = full_ml_test_data_path() / "sequence_data_for_classification" / "dataset.csv"
    data_frame = pd.read_csv(csv_path, sep=",", dtype=str)
    model_config = SequenceModelBase(
        image_file_column="IMG",
        label_value_column="Label",
        numerical_columns=["NUM1", "NUM2", "NUM3", "NUM4"],
        sequence_target_positions=[8],
        sequence_column="Position",
        local_dataset=Path(),
        should_validate=False,
    )
    dataset = SequenceDataset(args=model_config, data_frame=data_frame)
    assert len(dataset) == 2

    # Subject IDs and sequence lengths expected for dataset items 0 and 1, respectively.
    expected_ids = ["2627.00001", "3250.00005"]
    expected_lengths = [3, 9]

    # Accessing a dataset item calls load_images_and_stack; replace it with a stub that returns
    # a tiny dummy tensor so that no real images are read.
    stub_images = ImageAndSegmentations[torch.Tensor](images=torch.ones(1),
                                                      segmentations=torch.empty(0))
    with mock.patch('InnerEye.ML.dataset.scalar_sample.load_images_and_stack',
                    return_value=stub_images):
        sequences = [ClassificationItemSequence(**dataset[index]) for index in range(2)]
        for sequence, subject_id, length in zip(sequences, expected_ids, expected_lengths):
            assert sequence.id == subject_id
            assert len(sequence.items) == length

        # The data loaders use this custom collate function, so it must handle sequences too.
        collated = collate_with_metadata([dataset[0], dataset[1]])
        assert collated["id"] == expected_ids
        # Each subject's sequence becomes a list, so the batch holds a list of lists.
        batched_items = collated["items"]
        assert isinstance(batched_items, list)
        assert len(batched_items) == 2
        for subject_items in batched_items:
            assert isinstance(subject_items, list)
        for subject_items, length in zip(batched_items, expected_lengths):
            assert len(subject_items) == length

        # The collated dictionary must be usable to rebuild a (batched) sequence object.
        round_trip = ClassificationItemSequence(**collated)
        assert round_trip.id == expected_ids