Exemplo n.º 1
0
def test_create_labels_tensor_for_minibatch() -> None:
    """
    Check the labels tensor that create_labels_tensor_for_minibatch builds for a minibatch of two sequences,
    one with a single item and one with two items. Target positions past the end of a sequence are
    expected to come back as NaN.
    """
    item_sequences = []
    for position, subject_id in enumerate(["A", "B"]):
        # Subject "A" gets a 1-item sequence, subject "B" a 2-item sequence.
        item_sequences.append(
            ClassificationItemSequence(id=subject_id,
                                       items=_create_scalar_items(length=position + 1)))

    # Cases where at least one requested position is missing for some sequence, so NaN must compare equal.
    nan_padded_cases = [
        ([0, 1, 2], [[[1.0], [np.nan], [np.nan]],
                     [[1.0], [1.0], [np.nan]]]),
        ([0, 1], [[[1.0], [np.nan]], [[1.0], [1.0]]]),
    ]
    for requested_indices, expected_values in nan_padded_cases:
        actual = ClassificationItemSequence.create_labels_tensor_for_minibatch(
            item_sequences, target_indices=requested_indices)
        assert torch.allclose(actual, torch.tensor(expected_values), equal_nan=True)

    # Position 0 exists in both sequences, so no NaN is expected and an exact comparison is used.
    actual = ClassificationItemSequence.create_labels_tensor_for_minibatch(
        item_sequences, target_indices=[0])
    assert torch.equal(actual, torch.tensor([[[1.0]], [[1.0]]]))
Exemplo n.º 2
0
def get_scalar_model_inputs_and_labels(
        model_config: ScalarModelBase, model: torch.nn.Module,
        sample: Dict[str, Any]) -> ScalarModelInputsAndLabels:
    """
    Build the model inputs, labels and subject IDs for a scalar-predicting model from one data loader sample.

    If model_config is a SequenceModelBase, the sample is decoded into a list of ClassificationItemSequence,
    and labels are taken at the positions given by model_config.get_target_indices(). Otherwise the sample is
    decoded into a single ScalarItem and its own label is used.

    :param model_config: The configuration object for the model; its type selects sequence vs. non-sequence handling.
    :param model: The instantiated PyTorch model. A DataParallelModel wrapper is unwrapped before use.
    :param sample: A training sample, as returned by a PyTorch data loader (dictionary mapping from field name to value)
    :return: A ScalarModelInputsAndLabels holding the model input tensors, the label tensor, the subject IDs,
    and the data item reconstructed from the data loader output.
    """
    # get_input_tensors is called on the underlying module, not on the data-parallel wrapper.
    if isinstance(model, DataParallelModel):
        model = model.get_module()

    if not isinstance(model_config, SequenceModelBase):
        # Non-sequence model: the whole sample is a single ScalarItem.
        plain_model: DeviceAwareModule[ScalarItem, torch.Tensor] = model  # type: ignore
        item = ScalarItem.from_dict(sample)
        ids = [str(entry.id) for entry in item.metadata]  # type: ignore
        inputs = plain_model.get_input_tensors(item)
        return ScalarModelInputsAndLabels[ScalarItem, torch.Tensor](
            model_inputs=inputs,
            labels=item.label,
            subject_ids=ids,
            data_item=item)

    # Sequence model: the sample is a minibatch of per-subject item sequences.
    seq_model: DeviceAwareModule[List[ClassificationItemSequence], torch.Tensor] = model  # type: ignore
    batch_sequences = ClassificationItemSequence.from_minibatch(sample)
    ids = [seq.id for seq in batch_sequences]
    sequence_labels = ClassificationItemSequence.create_labels_tensor_for_minibatch(
        sequences=batch_sequences,
        target_indices=model_config.get_target_indices())
    inputs = seq_model.get_input_tensors(batch_sequences)
    return ScalarModelInputsAndLabels[List[ClassificationItemSequence], torch.Tensor](
        model_inputs=inputs,
        labels=sequence_labels,
        subject_ids=ids,
        data_item=batch_sequences)
Exemplo n.º 3
0
def get_scalar_model_inputs_and_labels(model: torch.nn.Module,
                                       target_indices: List[int],
                                       sample: Dict[str, Any]) -> ScalarModelInputsAndLabels:
    """
    Build the model inputs, labels and subject IDs for a scalar-predicting model from one data loader sample.

    :param model: The instantiated PyTorch model.
    :param target_indices: Selects the model kind. A non-empty list means a sequence model: the sample is decoded
    into ClassificationItemSequence objects and labels are taken at these sequence positions. An empty list means
    a normal (non-sequence) model: the sample is decoded into a single ScalarItem and its label is used.
    :param sample: A training sample, as returned by a PyTorch data loader (dictionary mapping from field name to value)
    :return: A ScalarModelInputsAndLabels holding the model input tensors, the label tensor, the subject IDs,
    and the data item reconstructed from the data loader output.
    """
    if not target_indices:
        # No target positions: treat the model as a plain (non-sequence) scalar model.
        plain_model: DeviceAwareModule[ScalarItem, torch.Tensor] = model  # type: ignore
        item = ScalarItem.from_dict(sample)
        ids = [str(entry.id) for entry in item.metadata]  # type: ignore
        inputs = plain_model.get_input_tensors(item)
        return ScalarModelInputsAndLabels[ScalarItem](
            model_inputs=inputs,
            labels=item.label,
            subject_ids=ids,
            data_item=item
        )

    # Target positions given: treat the model as a sequence model.
    seq_model: DeviceAwareModule[List[ClassificationItemSequence], torch.Tensor] = model  # type: ignore
    batch_sequences = ClassificationItemSequence.from_minibatch(sample)
    ids = [seq.id for seq in batch_sequences]
    sequence_labels = ClassificationItemSequence.create_labels_tensor_for_minibatch(
        sequences=batch_sequences,
        target_indices=target_indices
    )
    inputs = seq_model.get_input_tensors(batch_sequences)
    return ScalarModelInputsAndLabels[List[ClassificationItemSequence]](
        model_inputs=inputs,
        labels=sequence_labels,
        subject_ids=ids,
        data_item=batch_sequences
    )