def load_names(self) -> List[str]:
    """Read the names file through ``self.fs`` and return its parsed JSON.

    The file location comes from ``self.get_names_file()`` and the file is
    opened in text mode with ``self.fs.open``.

    Returns:
        Whatever ``json.load`` produces for the file (annotated as a list of
        strings).

    Raises:
        CouldNotLoadData: if a ``FileNotFoundError`` is raised while locating,
            opening or reading the file.
    """
    try:
        names_path = self.get_names_file()
        with self.fs.open(names_path, mode="rt") as names_file:
            names = json.load(names_file)
    # TODO: check which error actually gets thrown
    except FileNotFoundError as err:
        raise CouldNotLoadData() from err
    return names
def load_stats(self) -> Generator[str, None, None]:
    """Lazily yield the lines of the stats file, one at a time.

    The path comes from ``self.get_stats_file()`` and is opened in text mode.
    Because this is a generator, nothing is opened until iteration starts, so
    a missing file is only reported when the first item is requested.

    Yields:
        Each line of the file exactly as the file iterator returns it
        (line terminators are not stripped).

    Raises:
        CouldNotLoadData: if a ``FileNotFoundError`` is raised while the
            generator is running.
    """
    try:
        stats_path = self.get_stats_file()
        with stats_path.open("rt") as stats_file:
            for stats_line in stats_file:
                yield stats_line
    except FileNotFoundError as err:
        raise CouldNotLoadData() from err
def load_count(self) -> int:
    """Read the count file through ``self.fs`` and return it as an integer.

    The file location comes from ``self.get_count_file()``; its whole content
    is read, surrounding whitespace is stripped and the rest is parsed with
    ``int``.

    Raises:
        CouldNotLoadData: if a ``FileNotFoundError`` is raised while locating,
            opening or reading the file.
        ValueError: if the stripped content is not a valid integer literal.
    """
    try:
        count_path = self.get_count_file()
        with self.fs.open(count_path, mode="rt") as count_file:
            raw_count = count_file.read()
    # TODO: check which error actually gets thrown
    except FileNotFoundError as err:
        raise CouldNotLoadData() from err
    return int(raw_count.strip())
def load_chunk_of_edges(
    self,
    lhs_p: Partition,
    rhs_p: Partition,
    chunk_idx: int = 0,
    num_chunks: int = 1,
    shared: bool = False,
) -> EdgeList:
    """Load one contiguous chunk of the edges stored for a partition pair.

    The edge file returned by ``self.get_edges_file(lhs_p, rhs_p)`` is split
    into ``num_chunks`` chunks of ``div_roundup(num_edges, num_chunks)`` edges
    each. Because the chunk size is rounded up, the last chunks may be shorter
    or empty; a chunk that starts past the end of the data yields empty
    tensors instead of failing.

    Args:
        lhs_p: Left-hand side partition, forwarded to ``get_edges_file``.
        rhs_p: Right-hand side partition, forwarded to ``get_edges_file``.
        chunk_idx: Zero-based index of the chunk to load.
        num_chunks: Total number of chunks the edges are split into.
        shared: If true, tensors are allocated with
            ``allocate_shared_tensor`` instead of ``torch.empty``.

    Returns:
        An ``EdgeList`` built from the "lhs", "rhs" and "rel" datasets, the
        dynamic "lhsd"/"rhsd" data read by ``self.read_dynamic``, and the
        "weight" dataset when the file has one (``None`` otherwise).

    Raises:
        RuntimeError: if the file's format version attribute does not equal
            ``FORMAT_VERSION``.
        CouldNotLoadData: if an ``OSError`` whose message reports ``ENOENT``
            is raised while accessing the file.
        OSError: any other ``OSError`` is re-raised unchanged.
    """
    file_path = self.get_edges_file(lhs_p, rhs_p)
    try:
        with h5py.File(file_path, "r") as hf:
            if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
                raise RuntimeError(
                    f"Version mismatch in edge file {file_path}")
            lhs_ds = hf["lhs"]
            rhs_ds = hf["rhs"]
            rel_ds = hf["rel"]

            num_edges = rel_ds.len()
            chunk_size = div_roundup(num_edges, num_chunks)
            # Clamp *both* bounds: with a rounded-up chunk size the start of a
            # trailing chunk can lie past the end of the data (e.g. 5 edges in
            # 4 chunks -> chunk 3 would start at 6), which would otherwise
            # produce a negative size and crash the allocator.
            begin = min(chunk_idx * chunk_size, num_edges)
            end = min((chunk_idx + 1) * chunk_size, num_edges)
            chunk_size = end - begin

            allocator = allocate_shared_tensor if shared else torch.empty
            lhs = allocator((chunk_size, ), dtype=torch.long)
            rhs = allocator((chunk_size, ), dtype=torch.long)
            rel = allocator((chunk_size, ), dtype=torch.long)

            # Needed because https://github.com/h5py/h5py/issues/870.
            if chunk_size > 0:
                lhs_ds.read_direct(lhs.numpy(), source_sel=np.s_[begin:end])
                rhs_ds.read_direct(rhs.numpy(), source_sel=np.s_[begin:end])
                rel_ds.read_direct(rel.numpy(), source_sel=np.s_[begin:end])

            lhsd = self.read_dynamic(hf, "lhsd", begin, end, shared=shared)
            rhsd = self.read_dynamic(hf, "rhsd", begin, end, shared=shared)

            if "weight" in hf:
                weight_ds = hf["weight"]
                weight = allocator((chunk_size, ), dtype=torch.long)
                # Same empty-selection workaround as for the datasets above.
                if chunk_size > 0:
                    weight_ds.read_direct(
                        weight.numpy(), source_sel=np.s_[begin:end])
            else:
                weight = None

            return EdgeList(
                EntityList(lhs, lhsd), EntityList(rhs, rhsd), rel, weight)
    except OSError as err:
        # h5py refuses to make it easy to figure out what went wrong. The errno
        # attribute is set to None. See https://github.com/h5py/h5py/issues/493.
        if f"errno = {errno.ENOENT}" in str(err):
            raise CouldNotLoadData() from err
        raise
def get_number_of_edges(self, lhs_p: int, rhs_p: int) -> int:
    """Return the length of the "rel" dataset in a partition pair's edge file.

    Args:
        lhs_p: Left-hand side partition, forwarded to ``get_edges_file``.
        rhs_p: Right-hand side partition, forwarded to ``get_edges_file``.

    Raises:
        RuntimeError: if the file's format version attribute does not equal
            ``FORMAT_VERSION``.
        CouldNotLoadData: if an ``OSError`` whose message reports ``ENOENT``
            is raised while accessing the file.
        OSError: any other ``OSError`` is re-raised.
    """
    file_path = self.get_edges_file(lhs_p, rhs_p)
    try:
        with h5py.File(file_path, "r") as edge_file:
            stored_version = edge_file.attrs.get(FORMAT_VERSION_ATTR, None)
            if stored_version != FORMAT_VERSION:
                raise RuntimeError(f"Version mismatch in edge file {file_path}")
            edge_count = edge_file["rel"].len()
    except OSError as err:
        # h5py refuses to make it easy to figure out what went wrong. The errno
        # attribute is set to None. See https://github.com/h5py/h5py/issues/493.
        enoent_marker = f"errno = {errno.ENOENT}"
        if enoent_marker not in str(err):
            raise err
        raise CouldNotLoadData() from err
    return edge_count
def load_chunk_of_edges(
    self,
    lhs_p: int,
    rhs_p: int,
    chunk_idx: int = 0,
    num_chunks: int = 1,
) -> EdgeList:
    """Load one contiguous chunk of the edges stored for a partition pair.

    The edges in the file returned by ``self.get_edges_file(lhs_p, rhs_p)``
    are split into ``num_chunks`` nearly equal chunks; chunk ``i`` covers
    ``[i * num_edges // num_chunks, (i + 1) * num_edges // num_chunks)``.

    Args:
        lhs_p: Left-hand side partition, forwarded to ``get_edges_file``.
        rhs_p: Right-hand side partition, forwarded to ``get_edges_file``.
        chunk_idx: Zero-based index of the chunk to load.
        num_chunks: Total number of chunks the edges are split into.

    Returns:
        An ``EdgeList`` built from the "lhs", "rhs" and "rel" datasets and the
        dynamic "lhsd"/"rhsd" data read by ``self.read_dynamic``.

    Raises:
        RuntimeError: if the file's format version attribute does not equal
            ``FORMAT_VERSION``.
        CouldNotLoadData: if an ``OSError`` whose message reports ``ENOENT``
            is raised while accessing the file.
        OSError: any other ``OSError`` is re-raised unchanged.
    """
    file_path = self.get_edges_file(lhs_p, rhs_p)
    try:
        with h5py.File(file_path, "r") as hf:
            if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
                raise RuntimeError(
                    f"Version mismatch in edge file {file_path}")
            lhs_ds = hf["lhs"]
            rhs_ds = hf["rhs"]
            rel_ds = hf["rel"]

            num_edges = rel_ds.len()
            # Exact integer arithmetic: going through float true division
            # loses precision once the product no longer fits in a double's
            # mantissa, which could make neighbouring chunks overlap or leave
            # gaps. For valid (non-negative) inputs in float range the bounds
            # are identical to the truncated float quotient.
            begin = chunk_idx * num_edges // num_chunks
            end = (chunk_idx + 1) * num_edges // num_chunks
            chunk_size = end - begin

            lhs = torch.empty((chunk_size, ), dtype=torch.long)
            rhs = torch.empty((chunk_size, ), dtype=torch.long)
            rel = torch.empty((chunk_size, ), dtype=torch.long)

            # Needed because https://github.com/h5py/h5py/issues/870.
            if chunk_size > 0:
                lhs_ds.read_direct(lhs.numpy(), source_sel=np.s_[begin:end])
                rhs_ds.read_direct(rhs.numpy(), source_sel=np.s_[begin:end])
                rel_ds.read_direct(rel.numpy(), source_sel=np.s_[begin:end])

            lhsd = self.read_dynamic(hf, "lhsd", begin, end)
            rhsd = self.read_dynamic(hf, "rhsd", begin, end)

            return EdgeList(EntityList(lhs, lhsd), EntityList(rhs, rhsd), rel)
    except OSError as err:
        # h5py refuses to make it easy to figure out what went wrong. The errno
        # attribute is set to None. See https://github.com/h5py/h5py/issues/493.
        if f"errno = {errno.ENOENT}" in str(err):
            raise CouldNotLoadData() from err
        raise
def load_model(
    self, version: int
) -> Tuple[Optional[Dict[str, torch.Tensor]], Optional[bytes]]:
    """Load the model state and optimizer state stored for ``version``.

    The file path comes from ``self.get_model_file(version)``. Debug messages
    are logged before and after loading.

    Returns:
        A pair ``(state_dict, optim_state)`` holding the results of
        ``load_model_state_dict`` and ``load_optimizer_state_dict`` applied
        to the opened file, in that order.

    Raises:
        RuntimeError: if the file's format version attribute does not equal
            ``FORMAT_VERSION``.
        CouldNotLoadData: if an ``OSError`` whose message reports ``ENOENT``
            is raised while accessing the file.
        OSError: any other ``OSError`` is re-raised.
    """
    path = self.get_model_file(version)
    logger.debug(f"Loading from {path}")
    try:
        with h5py.File(path, "r") as model_file:
            stored_version = model_file.attrs.get(FORMAT_VERSION_ATTR, None)
            if stored_version != FORMAT_VERSION:
                raise RuntimeError(f"Version mismatch in model file {path}")
            # Tuple elements are evaluated left to right, so the model state
            # is read before the optimizer state.
            loaded = (
                load_model_state_dict(model_file),
                load_optimizer_state_dict(model_file),
            )
    except OSError as err:
        # h5py refuses to make it easy to figure out what went wrong. The errno
        # attribute is set to None. See https://github.com/h5py/h5py/issues/493.
        enoent_marker = f"errno = {errno.ENOENT}"
        if enoent_marker not in str(err):
            raise err
        raise CouldNotLoadData() from err
    logger.debug(f"Done loading from {path}")
    return loaded
def load_entity_partition(
    self,
    version: int,
    entity_name: EntityName,
    partition: Partition,
) -> Tuple[FloatTensorType, Optional[bytes]]:
    """Load the embeddings and optimizer state of one entity partition.

    The file path comes from
    ``self.get_entity_partition_file(version, entity_name, partition)``.
    Debug messages are logged before and after loading.

    Returns:
        A pair ``(embs, optim_state)`` holding the results of
        ``load_embeddings`` and ``load_optimizer_state_dict`` applied to the
        opened file, in that order.

    Raises:
        RuntimeError: if the file's format version attribute does not equal
            ``FORMAT_VERSION``.
        CouldNotLoadData: if an ``OSError`` whose message reports ``ENOENT``
            is raised while accessing the file.
        OSError: any other ``OSError`` is re-raised.
    """
    path = self.get_entity_partition_file(version, entity_name, partition)
    logger.debug(f"Loading from {path}")
    try:
        with h5py.File(path, "r") as embeddings_file:
            stored_version = embeddings_file.attrs.get(FORMAT_VERSION_ATTR, None)
            if stored_version != FORMAT_VERSION:
                raise RuntimeError(f"Version mismatch in embeddings file {path}")
            # Tuple elements are evaluated left to right, so the embeddings
            # are read before the optimizer state.
            loaded = (
                load_embeddings(embeddings_file),
                load_optimizer_state_dict(embeddings_file),
            )
    except OSError as err:
        # h5py refuses to make it easy to figure out what went wrong. The errno
        # attribute is set to None. See https://github.com/h5py/h5py/issues/493.
        enoent_marker = f"errno = {errno.ENOENT}"
        if enoent_marker not in str(err):
            raise err
        raise CouldNotLoadData() from err
    logger.debug(f"Done loading from {path}")
    return loaded
def load_names(path: Path) -> List[str]:
    """Parse the JSON file at ``path`` and return its content.

    Args:
        path: Location of the file; it is opened in text mode.

    Returns:
        Whatever ``json.load`` produces for the file (annotated as a list of
        strings).

    Raises:
        CouldNotLoadData: if opening or reading the file raises
            ``FileNotFoundError``.
    """
    try:
        with path.open("rt") as names_file:
            names = json.load(names_file)
    except FileNotFoundError as err:
        raise CouldNotLoadData() from err
    return names
def load_count(path: Path) -> int:
    """Read the file at ``path`` and return its content as an integer.

    The whole file is read in text mode, surrounding whitespace is stripped
    and the rest is parsed with ``int``.

    Args:
        path: Location of the file to read.

    Raises:
        CouldNotLoadData: if opening or reading the file raises
            ``FileNotFoundError``.
        ValueError: if the stripped content is not a valid integer literal.
    """
    try:
        with path.open("rt") as count_file:
            raw_count = count_file.read()
    except FileNotFoundError as err:
        raise CouldNotLoadData() from err
    return int(raw_count.strip())
def load_config(self) -> str:
    """Return the full text of the config file.

    The path comes from ``self.get_config_file()`` and is opened in text
    mode; the content is returned unmodified.

    Raises:
        CouldNotLoadData: if a ``FileNotFoundError`` is raised while locating,
            opening or reading the file.
    """
    try:
        config_path = self.get_config_file()
        with config_path.open("rt") as config_file:
            contents = config_file.read()
    except FileNotFoundError as err:
        raise CouldNotLoadData() from err
    return contents