Example #1
0
def test_SizedDict_shared():
    """Check that a shared SizedDict reflects a write made by a child process.

    A tensor is stored under ``"a"``, a ``forkserver`` child runs ``_set`` on
    the dict, and the parent then expects the first element to equal 10.
    ``_set`` is defined elsewhere; presumably it writes 10 into
    ``d["a"][0]`` -- confirm against its definition.
    """
    shared = SizedDict(shared=True)
    shared["a"] = torch.randn(10)

    ctx = multiprocessing.get_context("forkserver")
    worker = ctx.Process(target=_set, args=(shared,))
    worker.start()
    worker.join()

    # Read back in the parent only after the child has finished.
    first_element = shared["a"][0]
    assert first_element == 10
Example #2
0
def test_SizedDict_size():
    """Check that ``SizedDict.size`` tracks inserts and overwrites.

    The expected size is the sum of ``get_size(value)`` and
    ``sys.getsizeof(key)`` over the entries currently stored.
    """
    sized = SizedDict()
    assert sized.size == 0

    key_a_size = sys.getsizeof("a")
    key_b_size = sys.getsizeof("b")

    # First insertion: one value plus one key.
    first = np.random.randn(10)
    sized["a"] = first
    assert sized.size == get_size(first) + key_a_size

    # Second insertion under a new key.
    second = np.random.randn(10)
    sized["b"] = second
    expected_two = get_size(first) + get_size(second) + key_a_size + key_b_size
    assert sized.size == expected_two

    # Overwrite
    replacement = np.random.randn(10)
    sized["b"] = replacement
    expected_overwritten = (
        get_size(first) + get_size(replacement) + key_a_size + key_b_size
    )
    assert sized.size == expected_overwritten
Example #3
0
    def __init__(
        self,
        path_name_type_list: Collection[Tuple[str, str, str]],
        preprocess: Union[
            Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]],
            None,
        ] = None,
        float_dtype: str = "float32",
        int_dtype: str = "long",
        max_cache_size: Union[float, int, str] = 0.0,
        max_cache_fd: int = 0,
    ):
        """Build one loader per data key and optionally set up a cache.

        Args:
            path_name_type_list: Non-empty collection of
                ``(path, name, type)`` triples. ``name`` is the data key and
                must be unique; ``path`` and ``type`` are handed to
                ``self._build_loader``.
            preprocess: Optional callable stored on the instance as-is
                (``None`` means no callable was supplied).
            float_dtype: Stored on the instance as ``self.float_dtype``.
            int_dtype: Stored on the instance as ``self.int_dtype``.
            max_cache_size: Cache limit. A string is converted with
                ``humanfriendly.parse_size``. A value greater than zero
                creates a shared ``SizedDict`` cache; otherwise
                ``self.cache`` is ``None``.
            max_cache_fd: Stored on the instance as ``self.max_cache_fd``.

        Raises:
            ValueError: If ``path_name_type_list`` is empty.
            RuntimeError: If a data key is duplicated, or a built loader
                reports a length of zero.
        """
        # typeguard idiom: the check is skipped when asserts are disabled.
        assert check_argument_types()
        if len(path_name_type_list) == 0:
            raise ValueError(
                '1 or more elements are required for "path_name_type_list"')

        # Iterate over a private deep copy rather than the caller's object.
        path_name_type_list = copy.deepcopy(path_name_type_list)
        self.preprocess = preprocess

        self.float_dtype = float_dtype
        self.int_dtype = int_dtype
        self.max_cache_fd = max_cache_fd

        self.loader_dict = {}
        self.debug_info = {}
        for path, name, _type in path_name_type_list:
            if name in self.loader_dict:
                raise RuntimeError(f'"{name}" is duplicated for data-key')

            loader = self._build_loader(path, _type)
            self.loader_dict[name] = loader
            self.debug_info[name] = path, _type
            if len(loader) == 0:
                raise RuntimeError(f"{path} has no samples")

            # TODO(kamo): Should check consistency of each utt-keys?

        # Normalise a human-readable size string to a number before comparing.
        if isinstance(max_cache_size, str):
            max_cache_size = humanfriendly.parse_size(max_cache_size)
        self.max_cache_size = max_cache_size
        self.cache = SizedDict(shared=True) if max_cache_size > 0 else None
Example #4
0
def test_SizedDict_len():
    """A SizedDict seeded with three entries reports a length of three."""
    initial = {"a": 2, "b": 5, "c": 10}
    sized = SizedDict(data=initial)
    entry_count = len(sized)
    assert entry_count == 3
Example #5
0
def test_SizedDict_contains():
    """A key supplied through ``data`` is found by the ``in`` operator."""
    initial = {"a": 2, "b": 5, "c": 10}
    sized = SizedDict(data=initial)
    has_key = "a" in sized
    assert has_key
Example #6
0
def test_SizedDict_iter():
    """Iterating a SizedDict yields its keys in the order they were given."""
    initial = {"a": 2, "b": 5, "c": 10}
    sized = SizedDict(data=initial)
    seen_keys = list(iter(sized))
    assert seen_keys == ["a", "b", "c"]
Example #7
0
def test_SizedDict_getitem():
    """Indexing a SizedDict by key returns the value stored for that key."""
    initial = {"a": 2, "b": 5, "c": 10}
    sized = SizedDict(data=initial)
    value = sized["a"]
    assert value == 2