def load_names(self) -> List[str]:
    """Read and return the JSON-encoded list stored in the names file.

    The file is located via ``self.get_names_file()`` and opened through
    ``self.fs``.

    Raises:
        CouldNotLoadData: if opening/reading the names file raises
            ``FileNotFoundError``.
    """
    try:
        # The whole read stays inside the ``try``: depending on the ``fs``
        # backend, a missing file might only be reported once it is read.
        with self.fs.open(self.get_names_file(), mode="rt") as names_file:
            names = json.load(names_file)
    # NOTE(review): FileNotFoundError is assumed to be what ``self.fs`` raises
    # for a missing file -- confirm against the filesystem implementation.
    except FileNotFoundError as missing:
        raise CouldNotLoadData() from missing
    return names
Exemple #2
0
 def load_stats(self) -> Generator[str, None, None]:
     """Lazily yield each line of the stats file, line endings included.

     Since this is a generator, the file returned by ``self.get_stats_file()``
     is only opened when the first item is requested. A missing file is
     therefore reported (as ``CouldNotLoadData``) on that first iteration,
     not when this method is called.

     Raises:
         CouldNotLoadData: if the stats file does not exist.
     """
     try:
         with self.get_stats_file().open("rt") as stats_file:
             yield from stats_file
     except FileNotFoundError as missing:
         raise CouldNotLoadData() from missing
 def load_count(self) -> int:
     """Return the integer stored in the count file, ignoring surrounding whitespace.

     The file is located via ``self.get_count_file()`` and opened through
     ``self.fs``.

     Raises:
         CouldNotLoadData: if opening/reading the count file raises
             ``FileNotFoundError``.
         ValueError: if the file's contents are not an integer.
     """
     try:
         with self.fs.open(self.get_count_file(), mode="rt") as count_file:
             raw_count = count_file.read()
     # NOTE(review): FileNotFoundError is assumed to be what ``self.fs`` raises
     # for a missing file -- confirm against the filesystem implementation.
     except FileNotFoundError as missing:
         raise CouldNotLoadData() from missing
     return int(raw_count.strip())
Exemple #4
0
    def load_chunk_of_edges(
        self,
        lhs_p: Partition,
        rhs_p: Partition,
        chunk_idx: int = 0,
        num_chunks: int = 1,
        shared: bool = False,
    ) -> EdgeList:
        """Load one contiguous slice of the edge file for bucket ``(lhs_p, rhs_p)``.

        The file's edges are split into ``num_chunks`` slices of
        ``div_roundup(num_edges, num_chunks)`` edges each. The last non-empty
        slice may be shorter, and trailing slices may be empty (e.g. 5 edges
        in 4 chunks gives sizes 2, 2, 1, 0).

        Args:
            lhs_p: left-hand-side partition of the bucket.
            rhs_p: right-hand-side partition of the bucket.
            chunk_idx: index of the slice to load (expected in
                ``[0, num_chunks)``).
            num_chunks: number of slices the file is divided into.
            shared: if true, tensors are allocated with
                ``allocate_shared_tensor`` instead of ``torch.empty``.

        Returns:
            An ``EdgeList`` holding the slice's lhs/rhs entities (with their
            dynamic parts), relation types and, when the file has a ``weight``
            dataset, per-edge weights (otherwise ``None``).

        Raises:
            RuntimeError: if the file's format version attribute does not match.
            CouldNotLoadData: if the edge file does not exist.
            OSError: any other ``OSError`` raised while reading is re-raised.
        """
        file_path = self.get_edges_file(lhs_p, rhs_p)
        try:
            with h5py.File(file_path, "r") as hf:
                if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
                    raise RuntimeError(
                        f"Version mismatch in edge file {file_path}")
                lhs_ds = hf["lhs"]
                rhs_ds = hf["rhs"]
                rel_ds = hf["rel"]

                num_edges = rel_ds.len()
                max_chunk_size = div_roundup(num_edges, num_chunks)
                # Clamp ``begin`` as well as ``end``: with ceil-sized chunks,
                # trailing chunks can start past the last edge (5 edges, 4
                # chunks -> chunk 3 would begin at 6), which used to yield a
                # negative chunk size and a failed allocation.
                begin = min(chunk_idx * max_chunk_size, num_edges)
                end = min((chunk_idx + 1) * max_chunk_size, num_edges)
                chunk_size = end - begin

                allocator = allocate_shared_tensor if shared else torch.empty
                lhs = allocator((chunk_size, ), dtype=torch.long)
                rhs = allocator((chunk_size, ), dtype=torch.long)
                rel = allocator((chunk_size, ), dtype=torch.long)

                # Skip read_direct on empty chunks; see
                # https://github.com/h5py/h5py/issues/870.
                if chunk_size > 0:
                    lhs_ds.read_direct(lhs.numpy(),
                                       source_sel=np.s_[begin:end])
                    rhs_ds.read_direct(rhs.numpy(),
                                       source_sel=np.s_[begin:end])
                    rel_ds.read_direct(rel.numpy(),
                                       source_sel=np.s_[begin:end])

                lhsd = self.read_dynamic(hf, "lhsd", begin, end, shared=shared)
                rhsd = self.read_dynamic(hf, "rhsd", begin, end, shared=shared)

                if "weight" in hf:
                    weight_ds = hf["weight"]
                    # NOTE(review): edge weights are assumed to be real-valued
                    # (confirm against the writer). Reading them into a long
                    # tensor made h5py's type conversion drop any fractional
                    # part, so they are read as float32.
                    weight = allocator((chunk_size, ), dtype=torch.float32)
                    if chunk_size > 0:
                        weight_ds.read_direct(weight.numpy(),
                                              source_sel=np.s_[begin:end])
                else:
                    weight = None
                return EdgeList(EntityList(lhs, lhsd), EntityList(rhs, rhsd),
                                rel, weight)
        except OSError as err:
            # h5py leaves the errno attribute as None on these errors (see
            # https://github.com/h5py/h5py/issues/493), so a missing file can
            # only be recognised from the message text.
            if f"errno = {errno.ENOENT}" in str(err):
                raise CouldNotLoadData() from err
            raise
Exemple #5
0
 def get_number_of_edges(self, lhs_p: int, rhs_p: int) -> int:
     """Return the number of edges stored in the ``(lhs_p, rhs_p)`` edge file.

     The count is the length of the file's ``rel`` dataset.

     Raises:
         RuntimeError: if the file's format version attribute does not match.
         CouldNotLoadData: if the edge file does not exist.
         OSError: any other ``OSError`` from h5py is re-raised.
     """
     edges_path = self.get_edges_file(lhs_p, rhs_p)
     try:
         with h5py.File(edges_path, "r") as edges_file:
             stored_version = edges_file.attrs.get(FORMAT_VERSION_ATTR, None)
             if stored_version != FORMAT_VERSION:
                 raise RuntimeError(f"Version mismatch in edge file {edges_path}")
             return edges_file["rel"].len()
     except OSError as err:
         # h5py leaves errno as None on these errors (h5py issue #493), so a
         # missing file can only be recognised from the message text.
         is_missing = f"errno = {errno.ENOENT}" in str(err)
         if is_missing:
             raise CouldNotLoadData() from err
         raise err
Exemple #6
0
    def load_chunk_of_edges(
        self,
        lhs_p: int,
        rhs_p: int,
        chunk_idx: int = 0,
        num_chunks: int = 1,
    ) -> EdgeList:
        """Load one contiguous slice of the edge file for bucket ``(lhs_p, rhs_p)``.

        Slice ``chunk_idx`` covers edges
        ``[chunk_idx * num_edges // num_chunks, (chunk_idx + 1) * num_edges // num_chunks)``,
        so slice sizes differ by at most one edge.

        Args:
            lhs_p: left-hand-side partition of the bucket.
            rhs_p: right-hand-side partition of the bucket.
            chunk_idx: index of the slice to load (expected in
                ``[0, num_chunks)``).
            num_chunks: number of slices the file is divided into.

        Returns:
            An ``EdgeList`` holding the slice's lhs/rhs entities (with their
            dynamic parts) and relation types.

        Raises:
            RuntimeError: if the file's format version attribute does not match.
            CouldNotLoadData: if the edge file does not exist.
            OSError: any other ``OSError`` raised while reading is re-raised.
        """
        file_path = self.get_edges_file(lhs_p, rhs_p)
        try:
            with h5py.File(file_path, "r") as hf:
                if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
                    raise RuntimeError(
                        f"Version mismatch in edge file {file_path}")
                lhs_ds = hf["lhs"]
                rhs_ds = hf["rhs"]
                rel_ds = hf["rel"]

                num_edges = rel_ds.len()
                # Exact integer floor division. ``int(a / b)`` goes through a
                # float and can round a boundary to the wrong integer once the
                # product exceeds float precision (~2**53).
                begin = chunk_idx * num_edges // num_chunks
                end = (chunk_idx + 1) * num_edges // num_chunks
                chunk_size = end - begin

                lhs = torch.empty((chunk_size, ), dtype=torch.long)
                rhs = torch.empty((chunk_size, ), dtype=torch.long)
                rel = torch.empty((chunk_size, ), dtype=torch.long)

                # Skip read_direct on empty chunks; see
                # https://github.com/h5py/h5py/issues/870.
                if chunk_size > 0:
                    lhs_ds.read_direct(lhs.numpy(),
                                       source_sel=np.s_[begin:end])
                    rhs_ds.read_direct(rhs.numpy(),
                                       source_sel=np.s_[begin:end])
                    rel_ds.read_direct(rel.numpy(),
                                       source_sel=np.s_[begin:end])

                lhsd = self.read_dynamic(hf, "lhsd", begin, end)
                rhsd = self.read_dynamic(hf, "rhsd", begin, end)

                return EdgeList(EntityList(lhs, lhsd), EntityList(rhs, rhsd),
                                rel)
        except OSError as err:
            # h5py leaves the errno attribute as None on these errors (see
            # https://github.com/h5py/h5py/issues/493), so a missing file can
            # only be recognised from the message text.
            if f"errno = {errno.ENOENT}" in str(err):
                raise CouldNotLoadData() from err
            raise
 def load_model(
     self, version: int
 ) -> Tuple[Optional[Dict[str, torch.Tensor]], Optional[bytes]]:
     """Load the model state dict and optimizer state saved for ``version``.

     Returns:
         A ``(state_dict, optim_state)`` pair, as returned by
         ``load_model_state_dict`` and ``load_optimizer_state_dict`` on the
         opened HDF5 file.

     Raises:
         RuntimeError: if the file's format version attribute does not match.
         CouldNotLoadData: if the model file does not exist.
         OSError: any other ``OSError`` from h5py is re-raised.
     """
     model_path = self.get_model_file(version)
     logger.debug(f"Loading from {model_path}")
     try:
         with h5py.File(model_path, "r") as model_file:
             if model_file.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
                 raise RuntimeError(f"Version mismatch in model file {model_path}")
             model_state = load_model_state_dict(model_file)
             optimizer_state = load_optimizer_state_dict(model_file)
     except OSError as err:
         # h5py leaves errno as None on these errors (h5py issue #493), so a
         # missing file can only be recognised from the message text.
         if f"errno = {errno.ENOENT}" in str(err):
             raise CouldNotLoadData() from err
         raise err
     else:
         logger.debug(f"Done loading from {model_path}")
         return model_state, optimizer_state
 def load_entity_partition(
     self,
     version: int,
     entity_name: EntityName,
     partition: Partition,
 ) -> Tuple[FloatTensorType, Optional[bytes]]:
     """Load the embeddings and optimizer state for one entity partition.

     Args:
         version: checkpoint version to read.
         entity_name: entity type whose embeddings are requested.
         partition: partition of that entity type.

     Returns:
         An ``(embeddings, optim_state)`` pair, as returned by
         ``load_embeddings`` and ``load_optimizer_state_dict`` on the opened
         HDF5 file.

     Raises:
         RuntimeError: if the file's format version attribute does not match.
         CouldNotLoadData: if the embeddings file does not exist.
         OSError: any other ``OSError`` from h5py is re-raised.
     """
     embeddings_path = self.get_entity_partition_file(version, entity_name, partition)
     logger.debug(f"Loading from {embeddings_path}")
     try:
         with h5py.File(embeddings_path, "r") as embeddings_file:
             if embeddings_file.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
                 raise RuntimeError(f"Version mismatch in embeddings file {embeddings_path}")
             embeddings = load_embeddings(embeddings_file)
             optimizer_state = load_optimizer_state_dict(embeddings_file)
     except OSError as err:
         # h5py leaves errno as None on these errors (h5py issue #493), so a
         # missing file can only be recognised from the message text.
         if f"errno = {errno.ENOENT}" in str(err):
             raise CouldNotLoadData() from err
         raise err
     else:
         logger.debug(f"Done loading from {embeddings_path}")
         return embeddings, optimizer_state
Exemple #9
0
def load_names(path: Path) -> List[str]:
    """Deserialize and return the JSON list of names stored at ``path``.

    Raises:
        CouldNotLoadData: if ``path`` does not exist.
    """
    # Only opening the file can raise FileNotFoundError, so the ``try`` guards
    # just that call; parsing happens afterwards under the ``with``.
    try:
        names_file = path.open("rt")
    except FileNotFoundError as missing:
        raise CouldNotLoadData() from missing
    with names_file:
        return json.load(names_file)
Exemple #10
0
def load_count(path: Path) -> int:
    """Return the integer stored at ``path``, ignoring surrounding whitespace.

    Raises:
        CouldNotLoadData: if ``path`` does not exist.
        ValueError: if the file's contents are not an integer.
    """
    # Only opening the file can raise FileNotFoundError, so the ``try`` guards
    # just that call.
    try:
        count_file = path.open("rt")
    except FileNotFoundError as missing:
        raise CouldNotLoadData() from missing
    with count_file:
        text = count_file.read()
    return int(text.strip())
Exemple #11
0
 def load_config(self) -> str:
     """Return the full text of the file returned by ``self.get_config_file()``.

     Raises:
         CouldNotLoadData: if the config file does not exist.
     """
     # Only opening the file can raise FileNotFoundError, so the ``try``
     # guards just that call.
     try:
         config_file = self.get_config_file().open("rt")
     except FileNotFoundError as missing:
         raise CouldNotLoadData() from missing
     with config_file:
         return config_file.read()