def test_custom_collate() -> None:
    """
    Checks that the custom collate function gathers the metadata field of all samples
    into a plain list, while the remaining numeric field is collated into a tensor.
    """
    key = "foo"
    patient_metadata = PatientMetadata(patient_id='42')
    # The metadata field carries a string in one sample and a PatientMetadata object in the other.
    first_sample = {key: 1, SAMPLE_METADATA_FIELD: "something"}
    second_sample = {key: 2, SAMPLE_METADATA_FIELD: patient_metadata}

    collated = collate_with_metadata([first_sample, second_sample])

    # Both fields must be present in the collated result.
    assert key in collated
    assert SAMPLE_METADATA_FIELD in collated

    # Metadata entries are kept as-is, in sample order, inside a list.
    collated_metadata = collated[SAMPLE_METADATA_FIELD]
    assert isinstance(collated_metadata, list)
    assert collated_metadata == ["something", patient_metadata]

    # The numeric field is stacked into a tensor.
    collated_values = collated[key]
    assert isinstance(collated_values, torch.Tensor)
    assert collated_values.tolist() == [1, 2]
def test_seq_dataset_loader() -> None:
    """
    Builds a SequenceDataset from the sequence classification test CSV and checks the
    subject ids and sequence lengths of its two items, then checks that the custom
    collate function turns the per-subject sequences into lists of lists.
    """
    csv_path = full_ml_test_data_path() / "sequence_data_for_classification" / "dataset.csv"
    frame = pd.read_csv(csv_path, sep=",", dtype=str)
    model_config = SequenceModelBase(
        image_file_column="IMG",
        label_value_column="Label",
        numerical_columns=["NUM1", "NUM2", "NUM3", "NUM4"],
        sequence_target_positions=[8],
        sequence_column="Position",
        local_dataset=Path(),
        should_validate=False,
    )
    dataset = SequenceDataset(args=model_config, data_frame=frame)
    assert len(dataset) == 2

    first_id, second_id = "2627.00001", "3250.00005"
    fake_images = ImageAndSegmentations[torch.Tensor](images=torch.ones(1),
                                                      segmentations=torch.empty(0))
    # Patch the load_images function that will be called once we access a dataset item.
    with mock.patch('InnerEye.ML.dataset.scalar_sample.load_images_and_stack',
                    return_value=fake_images):
        first_item = ClassificationItemSequence(**dataset[0])
        second_item = ClassificationItemSequence(**dataset[1])

        assert first_item.id == first_id
        expected_len_first = 3
        assert len(first_item.items) == expected_len_first

        assert second_item.id == second_id
        expected_len_second = 9
        assert len(second_item.items) == expected_len_second

        # Data loaders use a customized collate function, that must work with the sequences too.
        # The dataset is indexed again here, so this has to stay inside the patch context.
        batch = collate_with_metadata([dataset[0], dataset[1]])
        assert batch["id"] == [first_id, second_id]

        # All subject sequences should be turned into lists of lists.
        batch_items = batch["items"]
        assert isinstance(batch_items, list)
        assert len(batch_items) == 2
        assert isinstance(batch_items[0], list)
        assert isinstance(batch_items[1], list)
        assert len(batch_items[0]) == expected_len_first
        assert len(batch_items[1]) == expected_len_second

        # The collated dictionary must still be usable to construct a sequence object.
        restored = ClassificationItemSequence(**batch)
        assert restored.id == [first_id, second_id]