예제 #1
0
def test_class_label():
    """Check ClassLabel construction, name/index lookups and error cases."""
    by_count = ClassLabel(num_classes=5)
    by_names = ClassLabel(names=["apple", "orange", "banana"])

    # Both of these name lists repeat "apple"; construction must raise.
    for bad_kwargs in (
        {"names": ["apple", "orange", "banana", "apple"]},
        {"names": ["apple", "orange", "banana", "apple"], "num_classes": 2},
    ):
        with pytest.raises(ValueError):
            ClassLabel(**bad_kwargs)

    # Construction without any argument only has to succeed.
    ClassLabel()
    from_file = ClassLabel(names_file="./hub/schema/tests/class_label_names.txt")

    # Names are generated from the class count, kept as given, or read from file.
    assert by_count.names == ["0", "1", "2", "3", "4"]
    assert by_names.names == ["apple", "orange", "banana"]
    assert from_file.names == ["alpha", "beta", "gamma"]

    assert by_count.num_classes == 5
    assert by_names.num_classes == 3

    # Conversions in both directions.
    assert by_count.str2int("3") == 3
    assert by_names.str2int("orange") == 1
    assert by_count.int2str(4) == "4"
    assert by_names.int2str(2) == "banana"

    # A name that is not in an explicit name list is expected to raise KeyError.
    with pytest.raises(KeyError):
        by_names.str2int("2")

    # For the count-based label, these two strings are expected to raise
    # ValueError.
    for unknown in ("8", "abc"):
        with pytest.raises(ValueError):
            by_count.str2int(unknown)

    # Re-assigning names on an already configured label must be rejected.
    for label in (by_count, by_names):
        with pytest.raises(ValueError):
            label.names = ["ab", "cd", "ef", "gh"]
예제 #2
0
def test_class_label():
    """Check ClassLabel names, conversions and class counts."""
    numbered = ClassLabel(num_classes=4)
    named = ClassLabel(names=["alpha", "beta", "gamma"])
    # Construction from a names file only has to succeed. `names_file` is not
    # defined inside this function and must be provided by the enclosing scope.
    ClassLabel(names_file=names_file)

    assert numbered.names == ["0", "1", "2", "3"]
    assert named.names == ["alpha", "beta", "gamma"]

    assert numbered.str2int("1") == 1
    assert named.str2int("gamma") == 2

    # FIXME This is a bug, should raise an error
    assert numbered.int2str(2) is None
    assert named.int2str(0) == "alpha"

    assert numbered.num_classes == 4
    assert named.num_classes == 3

    # Only checks that the call completes; the result is not inspected.
    numbered.get_attr_dict()
예제 #3
0
class Segmentation(Tensor):
    """`HubSchema` for segmentation.

    Extends the `Tensor` schema and keeps a `ClassLabel` schema in
    `self.class_labels` describing the segmentation classes.
    """

    def __init__(
        self,
        shape: Tuple[int, ...] = None,
        dtype: str = None,
        num_classes: int = None,
        names: Tuple[str] = None,
        names_file: str = None,
        max_shape: Tuple[int, ...] = None,
        chunks=None,
        compressor="lz4",
    ):
        """Constructs a Segmentation HubSchema.
        Also constructs ClassLabel HubSchema for Segmentation classes.

        Parameters
        ----------
        shape: tuple of ints or None
            Shape in format (height, width, 1)
        dtype: str
            dtype of segmentation array: `uint16` or `uint8`
        num_classes: int
            Number of classes. All labels must be < num_classes.
        names: `list<str>`
            string names for the integer classes. The order in which the names are provided is kept.
        names_file: str
            Path to a file with names for the integer classes, one per line.
        max_shape : tuple[int]
            Maximum shape of tensor shape if tensor is dynamic
        chunks : tuple[int] | True
            Describes how to split tensor dimensions into chunks (files) to store them efficiently.
            It is anticipated that each file should be ~16MB.
            Sample Count is also in the list of tensor's dimensions (first dimension)
            If default value is chosen, automatically detects how to split into chunks
        compressor : str
            Compressor name forwarded to the ClassLabel schema. Defaults to "lz4".
        """
        # NOTE(review): `compressor` is not passed to the Tensor base class,
        # exactly as before; confirm whether Tensor.__init__ accepts it and
        # whether the mask tensor itself should honor it too.
        super().__init__(shape, dtype, max_shape=max_shape, chunks=chunks)
        # Forward the caller's compressor instead of a hard-coded "lz4" so the
        # argument is no longer silently ignored (same result for the default).
        self.class_labels = ClassLabel(
            num_classes=num_classes,
            names=names,
            names_file=names_file,
            chunks=chunks,
            compressor=compressor,
        )

    def get_segmentation_classes(self):
        """Get classes of the segmentation mask.

        Returns
        -------
        list
            The result of `class_labels.int2str` for every unique value found.
        """
        # NOTE(review): `np.unique` is applied to the schema object itself;
        # confirm that this yields the label values of an actual mask.
        class_indices = np.unique(self)
        return [self.class_labels.int2str(value) for value in class_indices]

    def get_attr_dict(self):
        """Return class attributes (the instance `__dict__`, not a copy)."""
        return self.__dict__

    def __str__(self):
        """Build the string form from the parent's, plus class-label details."""
        parent_str = super().__str__()
        # Drop the first six characters and the last one from the parent's
        # string (presumably the "Tensor" prefix and the closing parenthesis
        # -- confirm against Tensor.__str__) and put our own name in front.
        out = "Segmentation" + parent_str[6:-1]
        labels = self.class_labels
        if labels._names is not None:
            out = out + ", names=" + str(labels._names)
        if labels._num_classes is not None:
            out = out + ", num_classes=" + str(labels._num_classes)
        out += ")"
        return out

    def __repr__(self):
        """Return the same text as `__str__`."""
        return self.__str__()