def test_SizedDict_shared():
    """Check that a shared SizedDict reflects a write made in a child process.

    The dict is created with ``shared=True``, seeded with a tensor under "a",
    and handed to a child started from the "forkserver" context that runs
    ``_set`` on it. After the child finishes, the parent expects
    ``d["a"][0] == 10`` (presumably ``_set`` writes that value — see helper).
    """
    d = SizedDict(shared=True)
    x = torch.randn(10)
    d["a"] = x

    mp = multiprocessing.get_context("forkserver")
    p = mp.Process(target=_set, args=(d,))
    p.start()
    p.join()
    # Fail loudly if the child crashed; otherwise the problem would surface
    # only as a confusing value mismatch in the assertion below.
    assert p.exitcode == 0, f"child process exited with code {p.exitcode}"
    assert d["a"][0] == 10
def test_SizedDict_size():
    """Check that ``SizedDict.size`` tracks inserted and overwritten entries.

    The expected size is the sum of ``get_size`` over the stored values plus
    ``sys.getsizeof`` of each key currently present.
    """
    d = SizedDict()
    assert d.size == 0

    key_a_size = sys.getsizeof("a")
    key_b_size = sys.getsizeof("b")

    first = np.random.randn(10)
    d["a"] = first
    assert d.size == get_size(first) + key_a_size

    second = np.random.randn(10)
    d["b"] = second
    assert d.size == get_size(first) + get_size(second) + key_a_size + key_b_size

    # Overwrite: the new value under "b" must replace the old one's
    # contribution rather than add to it.
    replacement = np.random.randn(10)
    d["b"] = replacement
    assert (
        d.size
        == get_size(first) + get_size(replacement) + key_a_size + key_b_size
    )
def __init__(
    self,
    path_name_type_list: Collection[Tuple[str, str, str]],
    # `None` is spelled out explicitly: a None default no longer implies
    # Optional (typing.get_type_hints stopped adding it in Python 3.11), and
    # check_argument_types validates against those hints.
    preprocess: Union[
        Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]], None
    ] = None,
    float_dtype: str = "float32",
    int_dtype: str = "long",
    max_cache_size: Union[float, int, str] = 0.0,
    max_cache_fd: int = 0,
):
    """Build one loader per ``(path, name, type)`` entry and set up the cache.

    Args:
        path_name_type_list: Triples of ``(path, name, type)``. ``name`` is the
            data key in ``self.loader_dict``; ``path`` and ``type`` are passed
            to ``self._build_loader``. Must be non-empty with unique names.
        preprocess: Optional callable stored as ``self.preprocess``.
        float_dtype: Stored as ``self.float_dtype``.
        int_dtype: Stored as ``self.int_dtype``.
        max_cache_size: Cache budget. A string is parsed with
            ``humanfriendly.parse_size``. A value > 0 creates a shared
            ``SizedDict`` as ``self.cache``; otherwise ``self.cache`` is None.
        max_cache_fd: Stored as ``self.max_cache_fd`` (not used in this
            method; presumably consumed by the loaders/getter — verify).

    Raises:
        ValueError: If ``path_name_type_list`` is empty.
        RuntimeError: If a name is duplicated or a loader has no samples.
    """
    assert check_argument_types()
    if len(path_name_type_list) == 0:
        raise ValueError(
            '1 or more elements are required for "path_name_type_list"'
        )

    # Work on a private copy of the caller's collection.
    path_name_type_list = copy.deepcopy(path_name_type_list)
    self.preprocess = preprocess
    self.float_dtype = float_dtype
    self.int_dtype = int_dtype
    self.max_cache_fd = max_cache_fd

    self.loader_dict = {}
    self.debug_info = {}
    for path, name, _type in path_name_type_list:
        # Check before building so an existing loader is never replaced.
        if name in self.loader_dict:
            raise RuntimeError(f'"{name}" is duplicated for data-key')

        loader = self._build_loader(path, _type)
        self.loader_dict[name] = loader
        self.debug_info[name] = path, _type
        if len(loader) == 0:
            raise RuntimeError(f"{path} has no samples")
        # TODO(kamo): Should check consistency of each utt-keys?

    if isinstance(max_cache_size, str):
        max_cache_size = humanfriendly.parse_size(max_cache_size)
    self.max_cache_size = max_cache_size
    # A positive budget enables a process-shared SizedDict; otherwise caching
    # is disabled.
    self.cache = SizedDict(shared=True) if max_cache_size > 0 else None
def test_SizedDict_len():
    """A SizedDict constructed from three items reports a length of three."""
    initial = {"a": 2, "b": 5, "c": 10}
    sized = SizedDict(data=initial)
    assert len(sized) == 3
def test_SizedDict_contains():
    """Membership tests on a SizedDict see keys supplied at construction."""
    sized = SizedDict(data=dict(a=2, b=5, c=10))
    assert "a" in sized
def test_SizedDict_iter():
    """Iterating a SizedDict yields its keys in insertion order."""
    sized = SizedDict(data={"a": 2, "b": 5, "c": 10})
    expected_keys = ["a", "b", "c"]
    assert [key for key in iter(sized)] == expected_keys
def test_SizedDict_getitem():
    """Indexing a SizedDict by key returns the value given at construction."""
    sized = SizedDict(data={"a": 2, "b": 5, "c": 10})
    value = sized["a"]
    assert value == 2