def test_getitem_int(self):
    """Indexing with a negative int selects one entry, counted from the end."""
    entities = EntityList(
        torch.tensor([3, 4, 1, 0], dtype=torch.long),
        tensor_list_from_lists([[2, 1], [0], [], [3, 4, 5]]),
    )
    picked = entities[-3]
    # Position -3 of the four entries is the second one: value 4, list [0].
    expected = EntityList(
        torch.tensor([4], dtype=torch.long),
        tensor_list_from_lists([[0]]),
    )
    self.assertEqual(picked, expected)
Пример #2
0
    def load_chunk_of_edges(
        self,
        lhs_p: Partition,
        rhs_p: Partition,
        chunk_idx: int = 0,
        num_chunks: int = 1,
        shared: bool = False,
    ) -> EdgeList:
        """Load one chunk of the edges stored for a pair of partitions.

        The edges in the file are split into pieces of
        ``div_roundup(num_edges, num_chunks)`` edges each and the piece at
        position ``chunk_idx`` is returned. A chunk that would start past the
        last edge is returned empty.

        Args:
            lhs_p: Left-hand side partition, passed to ``get_edges_file``.
            rhs_p: Right-hand side partition, passed to ``get_edges_file``.
            chunk_idx: Index of the chunk to load.
            num_chunks: Number of chunks the edges are divided into.
            shared: When true the tensors are allocated with
                ``allocate_shared_tensor`` instead of ``torch.empty``.

        Returns:
            An ``EdgeList`` built from the lhs and rhs entity lists, the
            relation tensor and the weight tensor (``None`` when the file has
            no ``"weight"`` dataset).

        Raises:
            RuntimeError: If the file's format version attribute does not
                equal ``FORMAT_VERSION``.
            CouldNotLoadData: If an ``OSError`` is raised whose message
                reports ``ENOENT``.
            OSError: Any other ``OSError``, re-raised unchanged.
        """
        file_path = self.get_edges_file(lhs_p, rhs_p)
        try:
            with h5py.File(file_path, "r") as hf:
                if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
                    raise RuntimeError(
                        f"Version mismatch in edge file {file_path}")
                lhs_ds = hf["lhs"]
                rhs_ds = hf["rhs"]
                rel_ds = hf["rel"]

                num_edges = rel_ds.len()
                chunk_size = div_roundup(num_edges, num_chunks)
                # Clamp both bounds to num_edges. Assuming div_roundup is a
                # ceiling division, a trailing chunk can otherwise start past
                # the last edge (5 edges in 4 chunks: chunk 3 would begin at
                # 6 and end at 5), giving a negative size that the allocator
                # rejects.
                begin = min(chunk_idx * chunk_size, num_edges)
                end = min((chunk_idx + 1) * chunk_size, num_edges)
                chunk_size = end - begin

                allocator = allocate_shared_tensor if shared else torch.empty
                lhs = allocator((chunk_size, ), dtype=torch.long)
                rhs = allocator((chunk_size, ), dtype=torch.long)
                rel = allocator((chunk_size, ), dtype=torch.long)

                # Needed because https://github.com/h5py/h5py/issues/870.
                if chunk_size > 0:
                    selection = np.s_[begin:end]
                    # read_direct fills the tensors' memory in place through
                    # their NumPy views.
                    for dataset, buffer in (
                        (lhs_ds, lhs),
                        (rhs_ds, rhs),
                        (rel_ds, rel),
                    ):
                        dataset.read_direct(buffer.numpy(),
                                            source_sel=selection)

                lhsd = self.read_dynamic(hf, "lhsd", begin, end, shared=shared)
                rhsd = self.read_dynamic(hf, "rhsd", begin, end, shared=shared)

                if "weight" in hf:
                    weight_ds = hf["weight"]
                    # NOTE(review): weights are read into a torch.long
                    # buffer; confirm an integer dtype is intended here.
                    weight = allocator((chunk_size, ), dtype=torch.long)
                    if chunk_size > 0:
                        weight_ds.read_direct(weight.numpy(),
                                              source_sel=np.s_[begin:end])
                else:
                    weight = None
                return EdgeList(EntityList(lhs, lhsd), EntityList(rhs, rhsd),
                                rel, weight)
        except OSError as err:
            # h5py refuses to make it easy to figure out what went wrong. The errno
            # attribute is set to None. See https://github.com/h5py/h5py/issues/493.
            if f"errno = {errno.ENOENT}" in str(err):
                raise CouldNotLoadData() from err
            # Bare raise keeps the original traceback intact.
            raise
 def test_getitem_longtensor(self):
     """Indexing with an index tensor gathers the entries in the given order."""
     entities = EntityList(
         torch.tensor([3, 4, 1, 0], dtype=torch.long),
         tensor_list_from_lists([[2, 1], [0], [], [3, 4, 5]]),
     )
     indices = torch.tensor([2, 0])
     gathered = entities[indices]
     # Entries 2 and 0, in that order: values 1 and 3, lists [] and [2, 1].
     expected = EntityList(
         torch.tensor([1, 3], dtype=torch.long),
         tensor_list_from_lists([[], [2, 1]]),
     )
     self.assertEqual(gathered, expected)
 def test_cat(self):
     """EntityList.cat joins both the tensors and the tensor lists in order."""
     first_ids = torch.tensor([2, 3], dtype=torch.long)
     second_ids = torch.tensor([0, 1], dtype=torch.long)
     joined_ids = torch.tensor([2, 3, 0, 1], dtype=torch.long)
     first_lists = tensor_list_from_lists([[3, 4], [0]])
     second_lists = tensor_list_from_lists([[1, 2, 0], []])
     joined_lists = tensor_list_from_lists([[3, 4], [0], [1, 2, 0], []])
     parts = [
         EntityList(first_ids, first_lists),
         EntityList(second_ids, second_lists),
     ]
     concatenated = EntityList.cat(parts)
     expected = EntityList(joined_ids, joined_lists)
     self.assertEqual(concatenated, expected)
 def test_from_tensor(self):
     """from_tensor pairs the tensor with an empty TensorList of equal length."""
     built = EntityList.from_tensor(torch.tensor([3, 4], dtype=torch.long))
     expected = EntityList(
         torch.tensor([3, 4], dtype=torch.long),
         TensorList.empty(num_tensors=2),
     )
     self.assertEqual(built, expected)
 def test_to_tensor_list(self):
     """to_tensor_list returns the tensor-list part when every tensor entry is -1."""
     entities = EntityList(
         torch.tensor([-1, -1], dtype=torch.long),
         tensor_list_from_lists([[3, 4], [0]]),
     )
     converted = entities.to_tensor_list()
     expected = tensor_list_from_lists([[3, 4], [0]])
     self.assertEqual(converted, expected)
Пример #7
0
 def test_len(self):
     """len() of an EntityList built from two entries is 2."""
     entities = EntityList(
         torch.tensor([3, 4], dtype=torch.long),
         tensor_list_from_lists([[], [2, 1, 0]]),
     )
     size = len(entities)
     self.assertEqual(size, 2)
Пример #8
0
    def load_chunk_of_edges(
        self,
        lhs_p: int,
        rhs_p: int,
        chunk_idx: int = 0,
        num_chunks: int = 1,
    ) -> EdgeList:
        """Load one chunk of the edges stored for a pair of partitions.

        The edges in the file are divided into ``num_chunks`` contiguous
        ranges and the range at position ``chunk_idx`` is returned.

        Args:
            lhs_p: Left-hand side partition, passed to ``get_edges_file``.
            rhs_p: Right-hand side partition, passed to ``get_edges_file``.
            chunk_idx: Index of the chunk to load.
            num_chunks: Number of chunks the edges are divided into.

        Returns:
            An ``EdgeList`` built from the lhs and rhs entity lists and the
            relation tensor.

        Raises:
            RuntimeError: If the file's format version attribute does not
                equal ``FORMAT_VERSION``.
            CouldNotLoadData: If an ``OSError`` is raised whose message
                reports ``ENOENT``.
            OSError: Any other ``OSError``, re-raised unchanged.
        """
        file_path = self.get_edges_file(lhs_p, rhs_p)
        try:
            with h5py.File(file_path, "r") as hf:
                if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
                    raise RuntimeError(
                        f"Version mismatch in edge file {file_path}")
                lhs_ds = hf["lhs"]
                rhs_ds = hf["rhs"]
                rel_ds = hf["rel"]

                num_edges = rel_ds.len()
                # Floor division keeps the arithmetic in exact integers; going
                # through float division loses precision once the product
                # exceeds 2**53 and could shift a chunk boundary.
                begin = chunk_idx * num_edges // num_chunks
                end = (chunk_idx + 1) * num_edges // num_chunks
                chunk_size = end - begin

                lhs = torch.empty((chunk_size, ), dtype=torch.long)
                rhs = torch.empty((chunk_size, ), dtype=torch.long)
                rel = torch.empty((chunk_size, ), dtype=torch.long)

                # Needed because https://github.com/h5py/h5py/issues/870.
                if chunk_size > 0:
                    selection = np.s_[begin:end]
                    # read_direct fills the tensors' memory in place through
                    # their NumPy views.
                    lhs_ds.read_direct(lhs.numpy(), source_sel=selection)
                    rhs_ds.read_direct(rhs.numpy(), source_sel=selection)
                    rel_ds.read_direct(rel.numpy(), source_sel=selection)

                lhsd = self.read_dynamic(hf, "lhsd", begin, end)
                rhsd = self.read_dynamic(hf, "rhsd", begin, end)

                return EdgeList(EntityList(lhs, lhsd), EntityList(rhs, rhsd),
                                rel)
        except OSError as err:
            # h5py refuses to make it easy to figure out what went wrong. The errno
            # attribute is set to None. See https://github.com/h5py/h5py/issues/493.
            if f"errno = {errno.ENOENT}" in str(err):
                raise CouldNotLoadData() from err
            # Bare raise keeps the original traceback intact.
            raise
 def test_to_tensor(self):
     """to_tensor returns the tensor part when every list entry is empty."""
     entities = EntityList(
         torch.tensor([2, 3], dtype=torch.long),
         tensor_list_from_lists([[], []]),
     )
     converted = entities.to_tensor()
     expected = torch.tensor([2, 3], dtype=torch.long)
     self.assertTrue(torch.equal(converted, expected))
 def test_equal(self):
     """An EntityList equals itself and differs when either component differs."""
     el = EntityList(
         torch.tensor([3, 4], dtype=torch.long),
         tensor_list_from_lists([[], [2, 1, 0]]),
     )
     self.assertEqual(el, el)

     # Same tensor list, different tensor.
     other_tensor = EntityList(
         torch.tensor([4, 2], dtype=torch.long),
         tensor_list_from_lists([[], [2, 1, 0]]),
     )
     self.assertNotEqual(el, other_tensor)

     # Same tensor, different tensor list.
     other_lists = EntityList(
         torch.tensor([3, 4], dtype=torch.long),
         tensor_list_from_lists([[3], [2, 0]]),
     )
     self.assertNotEqual(el, other_lists)
Пример #11
0
    def read(
        self,
        lhs_p: Partition,
        rhs_p: Partition,
        chunk_idx: int = 0,
        num_chunks: int = 1,
    ) -> EdgeList:
        """Read one chunk of the edges stored for a pair of partitions.

        The file ``edges_<lhs_p>_<rhs_p>.h5`` under ``self.path`` is opened,
        its edges are divided into ``num_chunks`` contiguous ranges and the
        range at position ``chunk_idx`` is returned.

        Args:
            lhs_p: Left-hand side partition, used in the file name.
            rhs_p: Right-hand side partition, used in the file name.
            chunk_idx: Index of the chunk to read.
            num_chunks: Number of chunks the edges are divided into.

        Returns:
            An ``EdgeList`` built from the lhs and rhs entity lists and the
            relation tensor.

        Raises:
            AssertionError: If the edge file does not exist (only while
                assertions are enabled).
            RuntimeError: If the file has a format version attribute that
                does not equal ``FORMAT_VERSION``.
        """
        file_path = os.path.join(self.path, "edges_%d_%d.h5" % (lhs_p, rhs_p))
        # NOTE(review): this check disappears under ``python -O``; h5py would
        # then report the missing file itself.
        assert os.path.exists(file_path), "%s does not exist" % file_path
        with h5py.File(file_path, 'r') as hf:
            # A missing version attribute only warns; a wrong one is fatal.
            if FORMAT_VERSION_ATTR not in hf.attrs:
                log("WARNING: It may be that one of your edge paths contains "
                    "files using the old format. See D14241362 for how to "
                    "update them.")
            elif hf.attrs[FORMAT_VERSION_ATTR] != FORMAT_VERSION:
                raise RuntimeError("Version mismatch in edge file %s" % file_path)
            lhs_ds = hf['lhs']
            rhs_ds = hf['rhs']
            rel_ds = hf['rel']

            num_edges = rel_ds.len()
            # Floor division keeps the arithmetic in exact integers; going
            # through float division loses precision once the product exceeds
            # 2**53 and could shift a chunk boundary.
            begin = chunk_idx * num_edges // num_chunks
            end = (chunk_idx + 1) * num_edges // num_chunks
            chunk_size = end - begin

            lhs = torch.empty((chunk_size,), dtype=torch.long)
            rhs = torch.empty((chunk_size,), dtype=torch.long)
            rel = torch.empty((chunk_size,), dtype=torch.long)

            # Needed because https://github.com/h5py/h5py/issues/870.
            if chunk_size > 0:
                selection = np.s_[begin:end]
                # read_direct fills the tensors' memory in place through their
                # NumPy views.
                lhs_ds.read_direct(lhs.numpy(), source_sel=selection)
                rhs_ds.read_direct(rhs.numpy(), source_sel=selection)
                rel_ds.read_direct(rel.numpy(), source_sel=selection)

            lhsd = self.read_dynamic(hf, 'lhsd', begin, end)
            rhsd = self.read_dynamic(hf, 'rhsd', begin, end)

            return EdgeList(EntityList(lhs, lhsd),
                            EntityList(rhs, rhsd),
                            rel)
Пример #12
0
    def load_chunk_of_edges(
        self,
        lhs_p: int,
        rhs_p: int,
        chunk_idx: int = 0,
        num_chunks: int = 1,
    ) -> EdgeList:
        """Load one chunk of the edges stored for a pair of partitions.

        The edges in the file are divided into ``num_chunks`` contiguous
        ranges and the range at position ``chunk_idx`` is returned.

        Args:
            lhs_p: Left-hand side partition, passed to ``get_edges_file``.
            rhs_p: Right-hand side partition, passed to ``get_edges_file``.
            chunk_idx: Index of the chunk to load.
            num_chunks: Number of chunks the edges are divided into.

        Returns:
            An ``EdgeList`` built from the lhs and rhs entity lists and the
            relation tensor.

        Raises:
            RuntimeError: If the edge file does not exist, or if its format
                version attribute does not equal ``FORMAT_VERSION``.
        """
        file_path = self.get_edges_file(lhs_p, rhs_p)
        if not file_path.is_file():
            raise RuntimeError(f"{file_path} does not exist")
        with h5py.File(file_path, 'r') as hf:
            if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
                raise RuntimeError(
                    f"Version mismatch in edge file {file_path}")
            lhs_ds = hf['lhs']
            rhs_ds = hf['rhs']
            rel_ds = hf['rel']

            num_edges = rel_ds.len()
            # Floor division keeps the arithmetic in exact integers; going
            # through float division loses precision once the product exceeds
            # 2**53 and could shift a chunk boundary.
            begin = chunk_idx * num_edges // num_chunks
            end = (chunk_idx + 1) * num_edges // num_chunks
            chunk_size = end - begin

            lhs = torch.empty((chunk_size, ), dtype=torch.long)
            rhs = torch.empty((chunk_size, ), dtype=torch.long)
            rel = torch.empty((chunk_size, ), dtype=torch.long)

            # Needed because https://github.com/h5py/h5py/issues/870.
            if chunk_size > 0:
                selection = np.s_[begin:end]
                # read_direct fills the tensors' memory in place through their
                # NumPy views.
                lhs_ds.read_direct(lhs.numpy(), source_sel=selection)
                rhs_ds.read_direct(rhs.numpy(), source_sel=selection)
                rel_ds.read_direct(rel.numpy(), source_sel=selection)

            lhsd = self.read_dynamic(hf, 'lhsd', begin, end)
            rhsd = self.read_dynamic(hf, 'rhsd', begin, end)

            return EdgeList(EntityList(lhs, lhsd), EntityList(rhs, rhsd), rel)
Пример #13
0
    def read(
        self,
        lhs_p: Partition,
        rhs_p: Partition,
        chunk_idx: int = 0,
        num_chunks: int = 1,
    ) -> EdgeList:
        """Read one chunk of the edges stored for a pair of partitions.

        The file ``edges_<lhs_p>_<rhs_p>.h5`` under ``self.path`` is opened,
        its edges are divided into ``num_chunks`` contiguous ranges and the
        range at position ``chunk_idx`` is returned.

        Args:
            lhs_p: Left-hand side partition, used in the file name.
            rhs_p: Right-hand side partition, used in the file name.
            chunk_idx: Index of the chunk to read.
            num_chunks: Number of chunks the edges are divided into.

        Returns:
            An ``EdgeList`` built from the lhs and rhs entity lists and the
            relation tensor.

        Raises:
            AssertionError: If the edge file does not exist (only while
                assertions are enabled).
            RuntimeError: If the file's format version attribute does not
                equal ``FORMAT_VERSION``.
        """
        file_path = os.path.join(self.path, f"edges_{lhs_p}_{rhs_p}.h5")
        # NOTE(review): this check disappears under ``python -O``; h5py would
        # then report the missing file itself.
        assert os.path.exists(file_path), "%s does not exist" % file_path
        with h5py.File(file_path, 'r') as hf:
            if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
                raise RuntimeError("Version mismatch in edge file %s" %
                                   file_path)
            lhs_ds = hf['lhs']
            rhs_ds = hf['rhs']
            rel_ds = hf['rel']

            num_edges = rel_ds.len()
            # Floor division keeps the arithmetic in exact integers; going
            # through float division loses precision once the product exceeds
            # 2**53 and could shift a chunk boundary.
            begin = chunk_idx * num_edges // num_chunks
            end = (chunk_idx + 1) * num_edges // num_chunks
            chunk_size = end - begin

            lhs = torch.empty((chunk_size, ), dtype=torch.long)
            rhs = torch.empty((chunk_size, ), dtype=torch.long)
            rel = torch.empty((chunk_size, ), dtype=torch.long)

            # Needed because https://github.com/h5py/h5py/issues/870.
            if chunk_size > 0:
                selection = np.s_[begin:end]
                # read_direct fills the tensors' memory in place through their
                # NumPy views.
                lhs_ds.read_direct(lhs.numpy(), source_sel=selection)
                rhs_ds.read_direct(rhs.numpy(), source_sel=selection)
                rel_ds.read_direct(rel.numpy(), source_sel=selection)

            lhsd = self.read_dynamic(hf, 'lhsd', begin, end)
            rhsd = self.read_dynamic(hf, 'rhsd', begin, end)

            return EdgeList(EntityList(lhs, lhsd), EntityList(rhs, rhsd), rel)
 def test_constructor_checks(self):
     """Building an EntityList from 3 tensor entries and 1 list raises ValueError."""
     with self.assertRaises(ValueError):
         # Everything is built inside the context, as in the original check.
         three_ids = torch.tensor([3, 4, 0], dtype=torch.long)
         one_list = tensor_list_from_lists([[2, 1]])
         EntityList(three_ids, one_list)
 def test_from_tensor_list(self):
     """from_tensor_list pairs the list with a tensor filled with -1."""
     tensor_list = tensor_list_from_lists([[3, 4], [0, 2]])
     built = EntityList.from_tensor_list(tensor_list)
     placeholder_ids = torch.full((2, ), -1, dtype=torch.long)
     expected = EntityList(placeholder_ids, tensor_list)
     self.assertEqual(built, expected)
 def test_empty(self):
     """EntityList.empty() equals one built from a 0-length tensor and an empty TensorList."""
     built = EntityList.empty()
     no_ids = torch.empty((0, ), dtype=torch.long)
     expected = EntityList(no_ids, TensorList.empty())
     self.assertEqual(built, expected)