def test_sequence_dataloader() -> None:
    """
    Test if we can create a data loader from the dataset, and recover the items as expected
    in batched form. Including instances where not all elements of the sequence have labels.
    """
    csv_string = StringIO("""subject,seq,path,value,scalar1,scalar2,META
S1,0,foo.nii,,0,0,M1
S1,1,,True,1.1,1.2,M2
S2,0,bar.nii,False,2.1,2.2,M3
S2,1,,False,2.0,2.0,M4
""")
    df = pd.read_csv(csv_string, sep=",", dtype=str)
    config = SequenceModelBase(image_file_column=None,
                               label_value_column="value",
                               numerical_columns=["scalar1"],
                               sequence_target_positions=[1],
                               sequence_column="seq",
                               local_dataset=Path.cwd(),
                               should_validate=False)
    dataset = SequenceDataset(config, data_frame=df)
    assert len(dataset) == 2
    data_loader = dataset.as_data_loader(shuffle=False, batch_size=2, num_dataload_workers=0)
    # We have 2 subjects, with a batch size of 2 those should be turned into 1 batch
    data_loader_output = list(data_loader)
    assert len(data_loader_output) == 1
    loaded = [ClassificationItemSequence(**i) for i in data_loader_output]
    assert loaded[0].id == ["S1", "S2"]
    assert isinstance(loaded[0].items[0][0], ScalarItem)
    assert loaded[0].items[0][0].metadata.id == "S1"
    assert loaded[0].items[0][1].metadata.id == "S1"
    assert loaded[0].items[1][0].metadata.id == "S2"
    assert loaded[0].items[1][1].metadata.id == "S2"
    # The batched sequence data are awkward to work with. Check if we can un-roll them correctly via
    # from_minibatch
    un_batched = ClassificationItemSequence.from_minibatch(data_loader_output[0])
    assert len(un_batched) == 2
    # Each un-batched sequence must match the corresponding dataset item, element by element.
    for un_batched_item, dataset_item in zip(un_batched, dataset.items):
        assert un_batched_item.id == dataset_item.id
        assert len(un_batched_item.items) == len(dataset_item.items)
        for un_batched_element, dataset_element in zip(un_batched_item.items, dataset_item.items):
            assert un_batched_element.metadata.id == dataset_element.metadata.id
def get_scalar_model_inputs_and_labels(model_config: ScalarModelBase,
                                       model: torch.nn.Module,
                                       sample: Dict[str, Any]) -> ScalarModelInputsAndLabels:
    """
    For a model that predicts scalars, gets the model input tensors from a sample returned by the data loader.
    :param model_config: The configuration object for the model.
    :param model: The instantiated PyTorch model.
    :param sample: A training sample, as returned by a PyTorch data loader (dictionary mapping from field name to value)
    :return: An instance of ScalarModelInputsAndLabels, containing the list of model input tensors,
    label tensor, subject IDs, and the data item reconstructed from the data loader output
    """
    # If the model is wrapped for data-parallel execution, work with the inner module instead.
    if isinstance(model, DataParallelModel):
        model = model.get_module()

    if isinstance(model_config, SequenceModelBase):
        # Sequence case: re-assemble per-subject sequences from the batched dictionary,
        # then build a labels tensor for the configured target positions.
        seq_model: DeviceAwareModule[List[ClassificationItemSequence], torch.Tensor] = model  # type: ignore
        minibatch_sequences = ClassificationItemSequence.from_minibatch(sample)
        ids = [seq.id for seq in minibatch_sequences]
        label_tensor = ClassificationItemSequence.create_labels_tensor_for_minibatch(
            sequences=minibatch_sequences,
            target_indices=model_config.get_target_indices())
        inputs = seq_model.get_input_tensors(minibatch_sequences)
        return ScalarModelInputsAndLabels[List[ClassificationItemSequence], torch.Tensor](
            model_inputs=inputs,
            labels=label_tensor,
            subject_ids=ids,
            data_item=minibatch_sequences)

    # Non-sequence case: the sample dictionary maps directly onto a single ScalarItem.
    plain_model: DeviceAwareModule[ScalarItem, torch.Tensor] = model  # type: ignore
    item = ScalarItem.from_dict(sample)
    ids = [str(meta.id) for meta in item.metadata]  # type: ignore
    inputs = plain_model.get_input_tensors(item)
    return ScalarModelInputsAndLabels[ScalarItem, torch.Tensor](
        model_inputs=inputs,
        labels=item.label,
        subject_ids=ids,
        data_item=item)
# NOTE(review): another `def get_scalar_model_inputs_and_labels` with a different signature appears
# earlier in this file; whichever definition comes later shadows the other — likely a refactor leftover
# that should be resolved.
def get_scalar_model_inputs_and_labels(model: torch.nn.Module,
                                       target_indices: List[int],
                                       sample: Dict[str, Any]) -> ScalarModelInputsAndLabels:
    """
    For a model that predicts scalars, gets the model input tensors from a sample returned by the data loader.
    :param model: The instantiated PyTorch model.
    :param target_indices: If this list is non-empty, assume that the model is a sequence model, and build the
    model inputs and labels for a model that predicts those specific positions in the sequence. If the list is
    empty, assume that the model is a normal (non-sequence) model.
    :param sample: A training sample, as returned by a PyTorch data loader (dictionary mapping from field name to value)
    :return: An instance of ScalarModelInputsAndLabels, containing the list of model input tensors,
    label tensor, subject IDs, and the data item reconstructed from the data loader output
    """
    if target_indices:
        # Sequence model: re-assemble the per-subject sequences from the batched sample dictionary.
        sequence_model: DeviceAwareModule[List[ClassificationItemSequence], torch.Tensor] = model  # type: ignore
        sequences = ClassificationItemSequence.from_minibatch(sample)
        subject_ids = [x.id for x in sequences]
        # Labels are gathered only at the requested positions within each sequence.
        labels = ClassificationItemSequence.create_labels_tensor_for_minibatch(
            sequences=sequences,
            target_indices=target_indices
        )
        model_inputs = sequence_model.get_input_tensors(sequences)
        # NOTE(review): this subscripts ScalarModelInputsAndLabels with a single type argument, while the
        # sibling overload in this file uses two ([..., torch.Tensor]) — confirm which arity the generic expects.
        return ScalarModelInputsAndLabels[List[ClassificationItemSequence]](
            model_inputs=model_inputs,
            labels=labels,
            subject_ids=subject_ids,
            data_item=sequences
        )
    else:
        # Non-sequence model: the sample dictionary maps directly onto a single ScalarItem.
        scalar_model: DeviceAwareModule[ScalarItem, torch.Tensor] = model  # type: ignore
        scalar_item = ScalarItem.from_dict(sample)
        subject_ids = [str(x.id) for x in scalar_item.metadata]  # type: ignore
        model_inputs = scalar_model.get_input_tensors(scalar_item)
        return ScalarModelInputsAndLabels[ScalarItem](
            model_inputs=model_inputs,
            labels=scalar_item.label,
            subject_ids=subject_ids,
            data_item=scalar_item
        )
def test_sequence_dataset_all(test_output_dirs: OutputFolderForTests) -> None:
    """
    Check that the sequence dataset works end-to-end, including applying the right standardization.
    """
    csv_string = """subject,seq,value,scalar1,scalar2,META,BETA
S1,0,False,0,0,M1,B1
S1,1,True,1,10,M2,B2
S2,0,False,2,20,M2,B1
S3,0,True,3,30,M1,B1
S4,0,True,4,40,M2,B1
"""
    csv_path = create_dataset_csv_file(csv_string, test_output_dirs.root_dir)
    config = SequenceModelBase(local_dataset=csv_path,
                               image_file_column=None,
                               label_value_column="value",
                               numerical_columns=["scalar1", "scalar2"],
                               sequence_target_positions=[0],
                               categorical_columns=["META", "BETA"],
                               sequence_column="seq",
                               num_dataload_workers=0,
                               train_batch_size=2,
                               should_validate=False,
                               shuffle=False)
    config.read_dataset_if_needed()
    df = config.dataset_data_frame
    assert df is not None
    df1 = df[df.subject.isin(["S1", "S2"])]
    df2 = df[df.subject == "S3"]
    df3 = df[df.subject == "S4"]
    splits = DatasetSplits(train=df1, val=df2, test=df3)
    # Force the configured splits, so that train/val/test membership is deterministic for this test.
    with mock.patch.object(SequenceModelBase,
                           'get_model_train_test_dataset_splits',
                           return_value=splits):
        train_val_loaders = config.create_data_loaders()
        # Expected feature mean: Mean of the training data (0, 0), (1, 10), (2, 20) = (1, 10)
        # Expected (biased corrected) std estimate: Std of (0, 0), (1, 10), (2, 20) = (1, 10)
        feature_stats = config.get_torch_dataset_for_inference(ModelExecutionMode.TRAIN).feature_statistics
        assert feature_stats is not None
        assert_tensors_equal(feature_stats.mean, [1, 10])
        assert_tensors_equal(feature_stats.std, [1, 10])
        train_items = [ClassificationItemSequence.from_minibatch(b)
                       for b in train_val_loaders[ModelExecutionMode.TRAIN]]
        assert len(train_items) == 1, "2 items in training set with batch size of 2 should return 1 minibatch"
        assert len(train_items[0]) == 2
        assert train_items[0][0].id == "S1"
        assert_tensors_equal(train_items[0][0].items[0].get_all_non_imaging_features(),
                             [-1., -1., 1., 0., 1., 0.])
        assert_tensors_equal(train_items[0][0].items[1].get_all_non_imaging_features(),
                             [0., 0., 0., 1., 0., 1.])
        assert train_items[0][1].id == "S2"
        assert_tensors_equal(train_items[0][1].items[0].get_all_non_imaging_features(),
                             [1., 1., 0., 1., 1., 0.])
        val_items = [ClassificationItemSequence.from_minibatch(b)
                     for b in train_val_loaders[ModelExecutionMode.VAL]]
        assert len(val_items) == 1
        assert len(val_items[0]) == 1
        assert val_items[0][0].id == "S3"
        # Items in the validation set should be normalized using the mean and std on the training data.
        # Hence, the non-image features (3, 30) should turn into (2, 2)
        assert_tensors_equal(val_items[0][0].items[0].get_all_non_imaging_features(),
                             [2., 2., 1., 0., 1., 0.])
        # Check that the test set is also normalized correctly using the training mean and std.
        test_items = [ClassificationItemSequence(**b)
                      for b in config.get_torch_dataset_for_inference(ModelExecutionMode.TEST)]
        assert test_items[0].id == "S4"
        # Check Non-image features of (4, 40)
        assert_tensors_equal(test_items[0].items[0].get_all_non_imaging_features(),
                             [3., 3., 0., 1., 1., 0.])