Example #1
# Imports needed to run this snippet. The InnerEye paths below are assumed from
# the InnerEye-DeepLearning repository layout and may need adjusting.
from io import StringIO
from pathlib import Path

import pandas as pd

from InnerEye.ML.dataset.scalar_sample import ScalarItem
from InnerEye.ML.dataset.sequence_dataset import SequenceDataset
from InnerEye.ML.dataset.sequence_sample import ClassificationItemSequence
from InnerEye.ML.sequence_config import SequenceModelBase
def test_sequence_dataloader() -> None:
    """
    Test that we can create a data loader from the dataset and recover the items as expected in batched form,
    including instances where not all elements of the sequence have labels.
    """
    csv_string = StringIO("""subject,seq,path,value,scalar1,scalar2,META
S1,0,foo.nii,,0,0,M1
S1,1,,True,1.1,1.2,M2
S2,0,bar.nii,False,2.1,2.2,M3
S2,1,,False,2.0,2.0,M4
""")
    df = pd.read_csv(csv_string, sep=",", dtype=str)
    config = SequenceModelBase(image_file_column=None,
                               label_value_column="value",
                               numerical_columns=["scalar1"],
                               sequence_target_positions=[1],
                               sequence_column="seq",
                               local_dataset=Path.cwd(),
                               should_validate=False)
    dataset = SequenceDataset(config, data_frame=df)
    assert len(dataset) == 2
    data_loader = dataset.as_data_loader(shuffle=False,
                                         batch_size=2,
                                         num_dataload_workers=0)
    # With 2 subjects and a batch size of 2, the loader should produce exactly 1 batch
    data_loader_output = list(data_loader)
    assert len(data_loader_output) == 1
    loaded = [ClassificationItemSequence(**i) for i in data_loader_output]
    assert loaded[0].id == ["S1", "S2"]
    assert isinstance(loaded[0].items[0][0], ScalarItem)
    assert loaded[0].items[0][0].metadata.id == "S1"
    assert loaded[0].items[0][1].metadata.id == "S1"
    assert loaded[0].items[1][0].metadata.id == "S2"
    assert loaded[0].items[1][1].metadata.id == "S2"

    # The batched sequence data are awkward to work with. Check that from_minibatch un-rolls them
    # back into one ClassificationItemSequence per subject.
    un_batched = ClassificationItemSequence.from_minibatch(data_loader_output[0])
    assert len(un_batched) == 2
    for i in range(2):
        assert un_batched[i].id == dataset.items[i].id
        assert len(un_batched[i].items) == len(dataset.items[i].items)
        for j in range(len(un_batched[i].items)):
            assert un_batched[i].items[j].metadata.id == dataset.items[i].items[j].metadata.id
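
For context on why this un-rolling step exists: PyTorch's default collate function merges a list of per-subject dictionaries into a single dictionary of batched fields, so the per-subject objects have to be reconstructed afterwards. A minimal, self-contained sketch of that round trip; the field names here are illustrative, not taken from the dataset above (`default_collate` is importable from `torch.utils.data` in PyTorch 1.11+):

import torch
from torch.utils.data import default_collate

# Two per-subject samples, as a map-style dataset would return them.
samples = [
    {"id": "S1", "scalar1": torch.tensor([0.0, 1.1])},
    {"id": "S2", "scalar1": torch.tensor([2.1, 2.0])},
]

# default_collate stacks tensors along a new batch dimension and collects
# string fields into lists, yielding one dictionary for the whole batch.
batch = default_collate(samples)
assert batch["id"] == ["S1", "S2"]
assert batch["scalar1"].shape == (2, 2)

# Un-rolling reverses this: one dictionary per subject again.
un_batched = [{key: value[i] for key, value in batch.items()} for i in range(2)]
assert un_batched[0]["id"] == "S1"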
Example #2
def get_scalar_model_inputs_and_labels(
        model_config: ScalarModelBase, model: torch.nn.Module,
        sample: Dict[str, Any]) -> ScalarModelInputsAndLabels:
    """
    For a model that predicts scalars, gets the model input tensors from a sample returned by the data loader.
    :param model_config: The configuration object for the model.
    :param model: The instantiated PyTorch model.
    :param sample: A training sample, as returned by a PyTorch data loader (a dictionary mapping field names
        to values).
    :return: An instance of ScalarModelInputsAndLabels, containing the list of model input tensors, the label
        tensor, the subject IDs, and the data item reconstructed from the data loader output.
    """
    if isinstance(model, DataParallelModel):
        model = model.get_module()

    if isinstance(model_config, SequenceModelBase):
        sequence_model: DeviceAwareModule[List[ClassificationItemSequence], torch.Tensor] = model  # type: ignore
        sequences = ClassificationItemSequence.from_minibatch(sample)
        subject_ids = [x.id for x in sequences]
        labels = ClassificationItemSequence.create_labels_tensor_for_minibatch(
            sequences=sequences,
            target_indices=model_config.get_target_indices())
        model_inputs = sequence_model.get_input_tensors(sequences)

        return ScalarModelInputsAndLabels[List[ClassificationItemSequence], torch.Tensor](
            model_inputs=model_inputs,
            labels=labels,
            subject_ids=subject_ids,
            data_item=sequences)
    else:
        scalar_model: DeviceAwareModule[ScalarItem, torch.Tensor] = model  # type: ignore
        scalar_item = ScalarItem.from_dict(sample)
        subject_ids = [str(x.id) for x in scalar_item.metadata]  # type: ignore
        model_inputs = scalar_model.get_input_tensors(scalar_item)

        return ScalarModelInputsAndLabels[ScalarItem, torch.Tensor](
            model_inputs=model_inputs,
            labels=scalar_item.label,
            subject_ids=subject_ids,
            data_item=scalar_item)
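
The unwrapping at the top of this function mirrors what stock PyTorch data parallelism requires: the wrapper exposes the underlying module, and it is the underlying module that knows how to build its input tensors. A small sketch with plain `torch.nn.DataParallel`; InnerEye's `DataParallelModel.get_module()` is assumed to serve the same purpose:

import torch

# DataParallel keeps the wrapped network in its `module` attribute; calls such
# as get_input_tensors must be made on the unwrapped network.
wrapped = torch.nn.DataParallel(torch.nn.Linear(2, 1))
assert isinstance(wrapped.module, torch.nn.Linear)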
Example #3
def get_scalar_model_inputs_and_labels(model: torch.nn.Module,
                                       target_indices: List[int],
                                       sample: Dict[str, Any]) -> ScalarModelInputsAndLabels:
    """
    For a model that predicts scalars, gets the model input tensors from a sample returned by the data loader.
    :param model: The instantiated PyTorch model.
    :param target_indices: If this list is non-empty, assume that the model is a sequence model, and build the
        model inputs and labels for a model that predicts those specific positions in the sequence. If the list
        is empty, assume that the model is a normal (non-sequence) model.
    :param sample: A training sample, as returned by a PyTorch data loader (a dictionary mapping field names
        to values).
    :return: An instance of ScalarModelInputsAndLabels, containing the list of model input tensors, the label
        tensor, the subject IDs, and the data item reconstructed from the data loader output.
    """
    if target_indices:
        sequence_model: DeviceAwareModule[List[ClassificationItemSequence], torch.Tensor] = model  # type: ignore
        sequences = ClassificationItemSequence.from_minibatch(sample)
        subject_ids = [x.id for x in sequences]
        labels = ClassificationItemSequence.create_labels_tensor_for_minibatch(
            sequences=sequences,
            target_indices=target_indices
        )
        model_inputs = sequence_model.get_input_tensors(sequences)

        return ScalarModelInputsAndLabels[List[ClassificationItemSequence]](
            model_inputs=model_inputs,
            labels=labels,
            subject_ids=subject_ids,
            data_item=sequences
        )
    else:
        scalar_model: DeviceAwareModule[ScalarItem, torch.Tensor] = model  # type: ignore
        scalar_item = ScalarItem.from_dict(sample)
        subject_ids = [str(x.id) for x in scalar_item.metadata]  # type: ignore
        model_inputs = scalar_model.get_input_tensors(scalar_item)

        return ScalarModelInputsAndLabels[ScalarItem](
            model_inputs=model_inputs,
            labels=scalar_item.label,
            subject_ids=subject_ids,
            data_item=scalar_item
        )
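
The label construction for sequence models gathers, for each subject, only the labels at the requested target positions; subjects whose sequences are shorter than a target position get a NaN placeholder (consistent with Example #1, where not every sequence element has a label). A self-contained sketch of that gathering step, with assumed semantics since `create_labels_tensor_for_minibatch` itself is not shown here:

import math
from typing import List

import torch

def labels_for_positions(label_lists: List[List[float]],
                         target_indices: List[int]) -> torch.Tensor:
    """One row per subject, one column per target position, NaN where the
    sequence has no entry at that position."""
    out = torch.full((len(label_lists), len(target_indices)), math.nan)
    for row, labels in enumerate(label_lists):
        for col, position in enumerate(target_indices):
            if position < len(labels):
                out[row, col] = labels[position]
    return out

# Subject 1 has labels at positions 0 and 1, subject 2 only at position 0:
labels = labels_for_positions([[0.0, 1.0], [1.0]], target_indices=[1])
assert labels.shape == (2, 1)
assert labels[0, 0] == 1.0 and torch.isnan(labels[1, 0])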
Example #4
def test_sequence_dataset_all(test_output_dirs: OutputFolderForTests) -> None:
    """
    Check that the sequence dataset works end-to-end, including applying the right standardization.
    """
    csv_string = """subject,seq,value,scalar1,scalar2,META,BETA
S1,0,False,0,0,M1,B1
S1,1,True,1,10,M2,B2
S2,0,False,2,20,M2,B1
S3,0,True,3,30,M1,B1
S4,0,True,4,40,M2,B1
"""
    csv_path = create_dataset_csv_file(csv_string, test_output_dirs.root_dir)
    config = SequenceModelBase(local_dataset=csv_path,
                               image_file_column=None,
                               label_value_column="value",
                               numerical_columns=["scalar1", "scalar2"],
                               sequence_target_positions=[0],
                               categorical_columns=["META", "BETA"],
                               sequence_column="seq",
                               num_dataload_workers=0,
                               train_batch_size=2,
                               should_validate=False,
                               shuffle=False)
    config.read_dataset_if_needed()
    df = config.dataset_data_frame
    assert df is not None
    df1 = df[df.subject.isin(["S1", "S2"])]
    df2 = df[df.subject == "S3"]
    df3 = df[df.subject == "S4"]
    splits = DatasetSplits(train=df1, val=df2, test=df3)
    with mock.patch.object(SequenceModelBase,
                           'get_model_train_test_dataset_splits',
                           return_value=splits):
        train_val_loaders = config.create_data_loaders()
        # Expected feature mean: mean of the training data (0, 0), (1, 10), (2, 20) = (1, 10)
        # Expected (bias-corrected) std estimate: std of (0, 0), (1, 10), (2, 20) = (1, 10)
        feature_stats = config.get_torch_dataset_for_inference(
            ModelExecutionMode.TRAIN).feature_statistics
        assert feature_stats is not None
        assert_tensors_equal(feature_stats.mean, [1, 10])
        assert_tensors_equal(feature_stats.std, [1, 10])

        train_items = [ClassificationItemSequence.from_minibatch(b)
                       for b in train_val_loaders[ModelExecutionMode.TRAIN]]
        assert len(train_items) == 1, "2 training subjects with a batch size of 2 should yield 1 minibatch"
        assert len(train_items[0]) == 2
        assert train_items[0][0].id == "S1"
        assert_tensors_equal(
            train_items[0][0].items[0].get_all_non_imaging_features(),
            [-1., -1., 1., 0., 1., 0.])
        assert_tensors_equal(
            train_items[0][0].items[1].get_all_non_imaging_features(),
            [0., 0., 0., 1., 0., 1.])
        assert train_items[0][1].id == "S2"
        assert_tensors_equal(
            train_items[0][1].items[0].get_all_non_imaging_features(),
            [1., 1., 0., 1., 1., 0.])
        val_items = [ClassificationItemSequence.from_minibatch(b)
                     for b in train_val_loaders[ModelExecutionMode.VAL]]
        assert len(val_items) == 1
        assert len(val_items[0]) == 1
        assert val_items[0][0].id == "S3"
        # Items in the validation set should be normalized using the mean and std of the training data.
        # Hence, the non-image features (3, 30) should turn into ((3 - 1) / 1, (30 - 10) / 10) = (2, 2).
        assert_tensors_equal(
            val_items[0][0].items[0].get_all_non_imaging_features(),
            [2., 2., 1., 0., 1., 0.])

        # Check that the test set is also normalized correctly using the training mean and std.
        test_items = list(
            ClassificationItemSequence(**b) for b in
            config.get_torch_dataset_for_inference(ModelExecutionMode.TEST))
        assert test_items[0].id == "S4"
        # The non-image features (4, 40) should normalize to ((4 - 1) / 1, (40 - 10) / 10) = (3, 3).
        assert_tensors_equal(
            test_items[0].items[0].get_all_non_imaging_features(),
            [3., 3., 0., 1., 1., 0.])
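
The normalization checked above is plain per-feature standardization: mean and std are computed on the training rows only (std with Bessel's correction, matching the bias-corrected estimate asserted earlier) and then applied unchanged to validation and test items. A minimal sketch of that arithmetic:

import torch

# Training values of (scalar1, scalar2): subject S1 (two visits) and S2.
train = torch.tensor([[0.0, 0.0], [1.0, 10.0], [2.0, 20.0]])
mean = train.mean(dim=0)  # -> (1, 10)
std = train.std(dim=0)    # unbiased by default -> (1, 10)

def standardize(x: torch.Tensor) -> torch.Tensor:
    return (x - mean) / std

# Validation subject S3 and test subject S4, as asserted in the test above:
assert torch.equal(standardize(torch.tensor([3.0, 30.0])), torch.tensor([2.0, 2.0]))
assert torch.equal(standardize(torch.tensor([4.0, 40.0])), torch.tensor([3.0, 3.0]))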