Exemple #1
0
def generate_edge_path_files_fast(
    edge_file_in: Path,
    edge_path_out: Path,
    edge_storage: AbstractEdgeStorage,
    entities_by_type: Dict[str, Dictionary],
    relation_types: Dictionary,
    relation_configs: List[RelationSchema],
    edgelist_reader: EdgelistReader,
) -> None:
    """Translate a text edgelist into a single (0, 0)-bucket tensor edge file.

    Reads (lhs, rhs, rel) word triples from ``edge_file_in`` via
    ``edgelist_reader``, maps them to numeric offsets/IDs through the given
    dictionaries, and writes the whole result as one EdgeList into bucket
    (0, 0) of ``edge_storage``. Edges whose relation type or entities are
    unknown are skipped (counted and reported at the end).
    """
    processed = 0
    skipped = 0

    log("Taking the fast train!")
    data = []
    for lhs_word, rhs_word, rel_word in edgelist_reader.read(edge_file_in):
        if rel_word is None:
            rel_id = 0
        else:
            try:
                rel_id = relation_types.get_id(rel_word)
            except KeyError:
                # Ignore edges whose relation type is not known.
                skipped += 1
                continue

        lhs_type = relation_configs[rel_id].lhs
        rhs_type = relation_configs[rel_id].rhs

        try:
            _, lhs_offset = entities_by_type[lhs_type].get_partition(lhs_word)
            _, rhs_offset = entities_by_type[rhs_type].get_partition(rhs_word)
        except KeyError:
            # Ignore edges whose entities are not known.
            skipped += 1
            continue

        data.append((lhs_offset, rhs_offset, rel_id))

        processed += 1
        if processed % 100000 == 0:
            log(f"- Processed {processed} edges so far...")

    # zip(*data) on an empty list yields nothing and the 3-way unpack would
    # raise ValueError; handle the no-surviving-edges case explicitly so an
    # empty (or fully filtered) input still produces a valid, empty bucket.
    if data:
        lhs_offsets, rhs_offsets, rel_ids = zip(*data)
    else:
        lhs_offsets, rhs_offsets, rel_ids = (), (), ()
    edge_list = EdgeList(
        EntityList.from_tensor(torch.tensor(list(lhs_offsets), dtype=torch.long)),
        EntityList.from_tensor(torch.tensor(list(rhs_offsets), dtype=torch.long)),
        torch.tensor(list(rel_ids), dtype=torch.long),
    )
    edge_storage.save_edges(0, 0, edge_list)

    log(f"- Processed {processed} edges in total")
    if skipped > 0:
        log(
            f"- Skipped {skipped} edges because their relation type or "
            f"entities were unknown (either not given in the config or "
            f"filtered out as too rare)."
        )
Exemple #2
0
def append_to_file(data, appender):
    """Append a batch of (lhs_offset, rhs_offset, rel_id) triples as one EdgeList.

    ``data`` is a sequence of 3-tuples of ints; the triples are transposed
    into three long tensors and appended to ``appender`` in a single call.
    An empty batch is a no-op.
    """
    if not data:
        # zip(*data) on an empty sequence yields nothing, so the 3-way
        # unpack below would raise ValueError; there is nothing to append.
        return
    lhs_offsets, rhs_offsets, rel_ids = zip(*data)
    appender.append_edges(
        EdgeList(
            EntityList.from_tensor(torch.tensor(lhs_offsets, dtype=torch.long)),
            EntityList.from_tensor(torch.tensor(rhs_offsets, dtype=torch.long)),
            torch.tensor(rel_ids, dtype=torch.long),
        )
    )
Exemple #3
0
 def test_get_relation_type_as_scalar(self):
     """A 0-dimensional relation tensor is reported back as a plain int."""
     lhs = EntityList.from_tensor(torch.tensor([3, 4], dtype=torch.long))
     rhs = EntityList.from_tensor(torch.tensor([0, 2], dtype=torch.long))
     edges = EdgeList(lhs, rhs, torch.tensor(3, dtype=torch.long))
     self.assertEqual(edges.get_relation_type_as_scalar(), 3)
Exemple #4
0
 def test_len(self):
     """len() of an EdgeList is the number of edges it holds."""
     edges = EdgeList(
         EntityList.from_tensor(torch.tensor([3, 4], dtype=torch.long)),
         EntityList.from_tensor(torch.tensor([0, 2], dtype=torch.long)),
         torch.tensor([2, 0], dtype=torch.long),
     )
     self.assertEqual(len(edges), 2)
Exemple #5
0
def append_to_file(data, appender):
    """Append a batch of (lhs_offset, rhs_offset, rel_id, weight) tuples.

    Weights are optional: if the first tuple carries ``None``, the whole
    batch is treated as unweighted (assumes weights are all-or-nothing
    within a batch — TODO confirm with callers). An empty batch is a no-op.
    """
    if not data:
        # On an empty sequence zip(*data) would fail the 4-way unpack and
        # weights[0] would raise IndexError; there is nothing to append.
        return
    lhs_offsets, rhs_offsets, rel_ids, weights = zip(*data)
    weights = torch.tensor(weights) if weights[0] is not None else None
    appender.append_edges(
        EdgeList(
            EntityList.from_tensor(torch.tensor(lhs_offsets,
                                                dtype=torch.long)),
            EntityList.from_tensor(torch.tensor(rhs_offsets,
                                                dtype=torch.long)),
            torch.tensor(rel_ids, dtype=torch.long),
            weights,
        ))
Exemple #6
0
 def test_get_relation_type_as_vector(self):
     """A 1-dimensional relation tensor is returned unchanged as a vector."""
     edges = EdgeList(
         EntityList.from_tensor(torch.tensor([3, 4], dtype=torch.long)),
         EntityList.from_tensor(torch.tensor([0, 2], dtype=torch.long)),
         torch.tensor([2, 0], dtype=torch.long),
     )
     self.assertTrue(torch.equal(
         edges.get_relation_type_as_vector(),
         torch.tensor([2, 0], dtype=torch.long),
     ))
Exemple #7
0
 def test_has_scalar_relation_type(self):
     """Scalar (0-dim) relation tensors are detected; vectors are not."""
     lhs = EntityList.from_tensor(torch.tensor([3, 4], dtype=torch.long))
     rhs = EntityList.from_tensor(torch.tensor([0, 2], dtype=torch.long))
     scalar_edges = EdgeList(lhs, rhs, torch.tensor(3, dtype=torch.long))
     self.assertTrue(scalar_edges.has_scalar_relation_type())
     vector_edges = EdgeList(lhs, rhs, torch.tensor([2, 0], dtype=torch.long))
     self.assertFalse(vector_edges.has_scalar_relation_type())
 def test_basic(self):
     """Edges are grouped by relation type, each group keeping the original
     left-to-right order of its edges."""
     self.assertEqual(
         group_by_relation_type(
             torch.tensor([1, 0, 0, 1, 2, 2, 0, 0, 2, 2], dtype=torch.long),
             EntityList.from_tensor(
                 torch.tensor([93, 24, 13, 31, 70, 66, 77, 38, 5, 5],
                              dtype=torch.long)),
             EntityList.from_tensor(
                 torch.tensor([90, 75, 9, 25, 23, 31, 49, 64, 42, 50],
                              dtype=torch.long)),
         ),
         # Expected: one (lhs, rhs, rel_type) triple per relation type,
         # ordered by relation type 0, 1, 2.
         [
             (
                 EntityList.from_tensor(
                     torch.tensor([24, 13, 77, 38], dtype=torch.long)),
                 EntityList.from_tensor(
                     torch.tensor([75, 9, 49, 64], dtype=torch.long)),
                 0,
             ),
             (
                 EntityList.from_tensor(
                     torch.tensor([93, 31], dtype=torch.long)),
                 EntityList.from_tensor(
                     torch.tensor([90, 25], dtype=torch.long)),
                 1,
             ),
             (
                 EntityList.from_tensor(
                     torch.tensor([70, 66, 5, 5], dtype=torch.long)),
                 EntityList.from_tensor(
                     torch.tensor([23, 31, 42, 50], dtype=torch.long)),
                 2,
             ),
         ],
     )
Exemple #9
0
 def test_getitem_longtensor(self):
     """Indexing with a LongTensor selects and reorders edges: [2, 0] picks
     edge #2 then edge #0."""
     self.assertEqual(
         EdgeList(
             EntityList.from_tensor(
                 torch.tensor([3, 4, 1, 0], dtype=torch.long)),
             EntityList.from_tensor(
                 torch.tensor([0, 2, 1, 3], dtype=torch.long)),
             torch.tensor([1, 1, 3, 0], dtype=torch.long),
         )[torch.tensor([2, 0])],
         EdgeList(
             EntityList.from_tensor(torch.tensor([1, 3], dtype=torch.long)),
             EntityList.from_tensor(torch.tensor([1, 0], dtype=torch.long)),
             torch.tensor([3, 1], dtype=torch.long),
         ),
     )
Exemple #10
0
 def test_getitem_int(self):
     """Indexing with an int (negative indices included: -3 is edge #1)
     yields a one-edge EdgeList whose relation type becomes a scalar."""
     self.assertEqual(
         EdgeList(
             EntityList.from_tensor(
                 torch.tensor([3, 4, 1, 0], dtype=torch.long)),
             EntityList.from_tensor(
                 torch.tensor([0, 2, 1, 3], dtype=torch.long)),
             torch.tensor([1, 1, 3, 0], dtype=torch.long),
         )[-3],
         EdgeList(
             EntityList.from_tensor(torch.tensor([4], dtype=torch.long)),
             EntityList.from_tensor(torch.tensor([2], dtype=torch.long)),
             torch.tensor(1, dtype=torch.long),
         ),
     )
 def test_empty(self):
     """Looking up an empty EntityList yields an empty embedding batch."""
     module = SimpleEmbedding(weight=torch.empty((0, 3)))
     empty_ids = EntityList.from_tensor(torch.empty((0,), dtype=torch.long))
     self.assertTensorEqual(module(empty_ids), torch.empty((0, 3)))
 def test_from_tensor(self):
     """from_tensor wraps a plain tensor, pairing it with an empty
     TensorList of matching length."""
     actual = EntityList.from_tensor(torch.tensor([3, 4], dtype=torch.long))
     expected = EntityList(
         torch.tensor([3, 4], dtype=torch.long),
         TensorList.empty(num_tensors=2),
     )
     self.assertEqual(actual, expected)
 def test_constant(self):
     """With a constant relation vector there is exactly one group, tagged
     with that relation type and holding every edge in order."""
     self.assertEqual(
         group_by_relation_type(
             torch.tensor([3, 3, 3, 3], dtype=torch.long),
             EntityList.from_tensor(
                 torch.tensor([93, 24, 13, 31], dtype=torch.long)),
             EntityList.from_tensor(
                 torch.tensor([90, 75, 9, 25], dtype=torch.long)),
         ),
         [
             (
                 EntityList.from_tensor(
                     torch.tensor([93, 24, 13, 31], dtype=torch.long)),
                 EntityList.from_tensor(
                     torch.tensor([90, 75, 9, 25], dtype=torch.long)),
                 3,
             ),
         ],
     )
Exemple #14
0
 def test_forward(self):
     """Forward looks up rows of the weight matrix by entity index and the
     result stays differentiable back to the weights."""
     embeddings = torch.tensor(
         [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], requires_grad=True
     )
     module = SimpleEmbedding(weight=embeddings)
     result = module(EntityList.from_tensor(torch.tensor([2, 0, 0])))
     self.assertTensorEqual(
         result, torch.tensor([[3.0, 3.0, 3.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
     )
     result.sum().backward()
     # The gradient may be a sparse tensor (as with sparse embedding
     # lookups), hence to_dense() before checking it is non-zero.
     self.assertTrue((embeddings.grad.to_dense() != 0).any())
Exemple #15
0
 def test_max_norm(self):
     """Rows whose norm exceeds max_norm get renormalized: [3, 3, 3] has
     norm 3*sqrt(3) > 2 so each component becomes 2/sqrt(3) ~= 1.1547,
     while [1, 1, 1] (norm sqrt(3) < 2) is left untouched."""
     embeddings = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0],
                                [3.0, 3.0, 3.0]])
     module = SimpleEmbedding(weight=embeddings, max_norm=2)
     self.assertTensorEqual(
         module(EntityList.from_tensor(torch.tensor([2, 0, 0]))),
         torch.tensor([
             [1.1547, 1.1547, 1.1547],
             [1.0000, 1.0000, 1.0000],
             [1.0000, 1.0000, 1.0000],
         ]),
     )
Exemple #16
0
 def test_cat_vector(self):
     """cat concatenates edge lists; vector relation types concatenate
     into one longer per-edge vector."""
     self.assertEqual(
         EdgeList.cat([
             EdgeList(
                 EntityList.from_tensor(
                     torch.tensor([3, 4], dtype=torch.long)),
                 EntityList.from_tensor(
                     torch.tensor([0, 2], dtype=torch.long)),
                 torch.tensor([2, 1], dtype=torch.long),
             ),
             EdgeList(
                 EntityList.from_tensor(
                     torch.tensor([1, 0], dtype=torch.long)),
                 EntityList.from_tensor(
                     torch.tensor([1, 3], dtype=torch.long)),
                 torch.tensor([3, 0], dtype=torch.long),
             ),
         ]),
         EdgeList(
             EntityList.from_tensor(
                 torch.tensor([3, 4, 1, 0], dtype=torch.long)),
             EntityList.from_tensor(
                 torch.tensor([0, 2, 1, 3], dtype=torch.long)),
             torch.tensor([2, 1, 3, 0], dtype=torch.long),
         ),
     )
Exemple #17
0
 def test_cat_scalar_different(self):
     """cat of lists with *different* scalar relation types expands the
     scalars into one per-edge relation vector."""
     self.assertEqual(
         EdgeList.cat([
             EdgeList(
                 EntityList.from_tensor(
                     torch.tensor([3, 4], dtype=torch.long)),
                 EntityList.from_tensor(
                     torch.tensor([0, 2], dtype=torch.long)),
                 torch.tensor(0, dtype=torch.long),
             ),
             EdgeList(
                 EntityList.from_tensor(
                     torch.tensor([1, 0], dtype=torch.long)),
                 EntityList.from_tensor(
                     torch.tensor([1, 3], dtype=torch.long)),
                 torch.tensor(1, dtype=torch.long),
             ),
         ]),
         EdgeList(
             EntityList.from_tensor(
                 torch.tensor([3, 4, 1, 0], dtype=torch.long)),
             EntityList.from_tensor(
                 torch.tensor([0, 2, 1, 3], dtype=torch.long)),
             torch.tensor([0, 0, 1, 1], dtype=torch.long),
         ),
     )
Exemple #18
0
 def test_basic(self):
     """Each yielded batch is an EdgeList of at most batch_size edges with a
     single (scalar) relation type; concatenating the batches of each type
     recovers exactly the original edges of that type."""
     edges = EdgeList(
         EntityList.from_tensor(
             torch.tensor([93, 24, 13, 31, 70, 66, 77, 38, 5, 5],
                          dtype=torch.long)),
         EntityList.from_tensor(
             torch.tensor([90, 75, 9, 25, 23, 31, 49, 64, 42, 50],
                          dtype=torch.long)),
         torch.tensor([1, 0, 0, 1, 2, 2, 0, 0, 2, 2], dtype=torch.long),
     )
     # Collect the batches per relation type, then compare against the
     # per-type slices of the input.
     edges_by_type = defaultdict(list)
     for batch_edges in batch_edges_group_by_relation_type(edges,
                                                           batch_size=3):
         self.assertIsInstance(batch_edges, EdgeList)
         self.assertLessEqual(len(batch_edges), 3)
         self.assertTrue(batch_edges.has_scalar_relation_type())
         edges_by_type[batch_edges.get_relation_type_as_scalar()].append(
             batch_edges)
     self.assertEqual(
         {k: EdgeList.cat(v)
          for k, v in edges_by_type.items()},
         {
             0:
             EdgeList(
                 EntityList.from_tensor(
                     torch.tensor([24, 13, 77, 38], dtype=torch.long)),
                 EntityList.from_tensor(
                     torch.tensor([75, 9, 49, 64], dtype=torch.long)),
                 torch.tensor(0, dtype=torch.long),
             ),
             1:
             EdgeList(
                 EntityList.from_tensor(
                     torch.tensor([93, 31], dtype=torch.long)),
                 EntityList.from_tensor(
                     torch.tensor([90, 25], dtype=torch.long)),
                 torch.tensor(1, dtype=torch.long),
             ),
             2:
             EdgeList(
                 EntityList.from_tensor(
                     torch.tensor([70, 66, 5, 5], dtype=torch.long)),
                 EntityList.from_tensor(
                     torch.tensor([23, 31, 42, 50], dtype=torch.long)),
                 torch.tensor(2, dtype=torch.long),
             ),
         },
     )
 def test_basic(self):
     """batch_edges_mix_relation_types slices the edges into consecutive
     batches of batch_size (the last may be shorter), mixing relation
     types freely within a batch."""
     self.assertEqual(
         list(
             batch_edges_mix_relation_types(
                 EdgeList(
                     EntityList.from_tensor(
                         torch.tensor(
                             [93, 24, 13, 31, 70, 66, 77, 38, 5, 5],
                             dtype=torch.long)),
                     EntityList.from_tensor(
                         torch.tensor(
                             [90, 75, 9, 25, 23, 31, 49, 64, 42, 50],
                             dtype=torch.long,
                         )),
                     torch.tensor([1, 0, 0, 1, 2, 2, 0, 0, 2, 2],
                                  dtype=torch.long),
                 ),
                 batch_size=4,
             )),
         # 10 edges with batch_size=4 give batches of sizes 4, 4, 2.
         [
             EdgeList(
                 EntityList.from_tensor(
                     torch.tensor([93, 24, 13, 31], dtype=torch.long)),
                 EntityList.from_tensor(
                     torch.tensor([90, 75, 9, 25], dtype=torch.long)),
                 torch.tensor([1, 0, 0, 1], dtype=torch.long),
             ),
             EdgeList(
                 EntityList.from_tensor(
                     torch.tensor([70, 66, 77, 38], dtype=torch.long)),
                 EntityList.from_tensor(
                     torch.tensor([23, 31, 49, 64], dtype=torch.long)),
                 torch.tensor([2, 2, 0, 0], dtype=torch.long),
             ),
             EdgeList(
                 EntityList.from_tensor(
                     torch.tensor([5, 5], dtype=torch.long)),
                 EntityList.from_tensor(
                     torch.tensor([42, 50], dtype=torch.long)),
                 torch.tensor([2, 2], dtype=torch.long),
             ),
         ],
     )
 def test_basic(self):
     """Tuple-based variant: each yielded (lhs, rhs, rel_type) batch is at
     most batch_size long and homogeneous in relation type; per-type
     concatenation recovers the original edges of that type."""
     lhs = EntityList.from_tensor(
         torch.tensor([93, 24, 13, 31, 70, 66, 77, 38, 5, 5],
                      dtype=torch.long))
     rhs = EntityList.from_tensor(
         torch.tensor([90, 75, 9, 25, 23, 31, 49, 64, 42, 50],
                      dtype=torch.long))
     rel = torch.tensor([1, 0, 0, 1, 2, 2, 0, 0, 2, 2], dtype=torch.long)
     lhs_by_type = defaultdict(list)
     rhs_by_type = defaultdict(list)
     for batch_lhs, batch_rhs, rel_type in batch_edges_group_by_relation_type(
             lhs, rhs, rel, batch_size=3):
         self.assertIsInstance(batch_lhs, EntityList)
         self.assertLessEqual(batch_lhs.size(0), 3)
         lhs_by_type[rel_type].append(batch_lhs)
         self.assertIsInstance(batch_rhs, EntityList)
         self.assertLessEqual(batch_rhs.size(0), 3)
         rhs_by_type[rel_type].append(batch_rhs)
     # All three relation types present in the input must have been seen.
     self.assertCountEqual(lhs_by_type.keys(), [0, 1, 2])
     self.assertCountEqual(rhs_by_type.keys(), [0, 1, 2])
     self.assertEqual(
         EntityList.cat(lhs_by_type[0]),
         EntityList.from_tensor(
             torch.tensor([24, 13, 77, 38], dtype=torch.long)))
     self.assertEqual(
         EntityList.cat(rhs_by_type[0]),
         EntityList.from_tensor(
             torch.tensor([75, 9, 49, 64], dtype=torch.long)))
     self.assertEqual(
         EntityList.cat(lhs_by_type[1]),
         EntityList.from_tensor(torch.tensor([93, 31], dtype=torch.long)))
     self.assertEqual(
         EntityList.cat(rhs_by_type[1]),
         EntityList.from_tensor(torch.tensor([90, 25], dtype=torch.long)))
     self.assertEqual(
         EntityList.cat(lhs_by_type[2]),
         EntityList.from_tensor(
             torch.tensor([70, 66, 5, 5], dtype=torch.long)))
     self.assertEqual(
         EntityList.cat(rhs_by_type[2]),
         EntityList.from_tensor(
             torch.tensor([23, 31, 42, 50], dtype=torch.long)))
 def test_basic(self):
     """Tuple-based mix-relation-types batching: consecutive slices of
     batch_size edges (4, 4, then the 2 leftovers)."""
     actual_batches = batch_edges_mix_relation_types(
         EntityList.from_tensor(
             torch.tensor([93, 24, 13, 31, 70, 66, 77, 38, 5, 5],
                          dtype=torch.long)),
         EntityList.from_tensor(
             torch.tensor([90, 75, 9, 25, 23, 31, 49, 64, 42, 50],
                          dtype=torch.long)),
         torch.tensor([1, 0, 0, 1, 2, 2, 0, 0, 2, 2], dtype=torch.long),
         batch_size=4,
     )
     expected_batches = [
         (
             EntityList.from_tensor(
                 torch.tensor([93, 24, 13, 31], dtype=torch.long)),
             EntityList.from_tensor(
                 torch.tensor([90, 75, 9, 25], dtype=torch.long)),
             torch.tensor([1, 0, 0, 1], dtype=torch.long),
         ),
         (
             EntityList.from_tensor(
                 torch.tensor([70, 66, 77, 38], dtype=torch.long)),
             EntityList.from_tensor(
                 torch.tensor([23, 31, 49, 64], dtype=torch.long)),
             torch.tensor([2, 2, 0, 0], dtype=torch.long),
         ),
         (
             EntityList.from_tensor(torch.tensor([5, 5], dtype=torch.long)),
             EntityList.from_tensor(torch.tensor([42, 50],
                                                 dtype=torch.long)),
             torch.tensor([2, 2], dtype=torch.long),
         ),
     ]
     # We can't use assertEqual because == between tensors doesn't work.
     # zip_longest also catches a mismatch in the number of batches.
     for actual_batch, expected_batch \
             in zip_longest(actual_batches, expected_batches):
         a_lhs, a_rhs, a_rel = actual_batch
         e_lhs, e_rhs, e_rel = expected_batch
         self.assertEqual(a_lhs, e_lhs)
         self.assertEqual(a_rhs, e_rhs)
         self.assertTrue(torch.equal(a_rel, e_rel),
                         "%s != %s" % (a_rel, e_rel))
Exemple #22
0
 def test_constructor_checks(self):
     """EdgeList rejects inconsistent constructor arguments."""
     # lhs and rhs have different lengths (3 vs 1).
     with self.assertRaises(ValueError):
         EdgeList(
             EntityList.from_tensor(
                 torch.tensor([3, 4, 0], dtype=torch.long)),
             EntityList.from_tensor(torch.tensor([2], dtype=torch.long)),
             torch.tensor(1, dtype=torch.long),
         )
     # Relation vector length (1) doesn't match the number of edges (2).
     with self.assertRaises(ValueError):
         EdgeList(
             EntityList.from_tensor(torch.tensor([3, 4], dtype=torch.long)),
             EntityList.from_tensor(torch.tensor([0, 2], dtype=torch.long)),
             torch.tensor([1], dtype=torch.long),
         )
     # Relation tensor must be 0- or 1-dimensional, not 2-dimensional.
     with self.assertRaises(ValueError):
         EdgeList(
             EntityList.from_tensor(torch.tensor([3, 4], dtype=torch.long)),
             EntityList.from_tensor(torch.tensor([0, 2], dtype=torch.long)),
             torch.tensor([[1]], dtype=torch.long),
         )
Exemple #23
0
 def test_equal(self):
     """EdgeList equality compares lhs, rhs and relation tensors."""
     el = EdgeList(
         EntityList.from_tensor(torch.tensor([3, 4], dtype=torch.long)),
         EntityList.from_tensor(torch.tensor([0, 2], dtype=torch.long)),
         torch.tensor([2, 0], dtype=torch.long),
     )
     self.assertEqual(el, el)
     # Swapping lhs and rhs makes the lists unequal.
     self.assertNotEqual(
         el,
         EdgeList(
             EntityList.from_tensor(torch.tensor([0, 2], dtype=torch.long)),
             EntityList.from_tensor(torch.tensor([3, 4], dtype=torch.long)),
             torch.tensor([2, 0], dtype=torch.long),
         ),
     )
     # A scalar relation type differs from a vector one.
     self.assertNotEqual(
         el,
         EdgeList(
             EntityList.from_tensor(torch.tensor([3, 4], dtype=torch.long)),
             EntityList.from_tensor(torch.tensor([0, 2], dtype=torch.long)),
             torch.tensor(1, dtype=torch.long),
         ),
     )
Exemple #24
0
def generate_edge_path_files(
    edge_file_in: Path,
    edge_path_out: Path,
    edge_storage: AbstractEdgeStorage,
    entities_by_type: Dict[str, Dictionary],
    relation_types: Dictionary,
    relation_configs: List[RelationSchema],
    dynamic_relations: bool,
    lhs_col: int,
    rhs_col: int,
    rel_col: Optional[int],
) -> None:
    """Partition a whitespace-separated text edgelist into bucketed edge files.

    Reads one edge per line from ``edge_file_in`` (columns selected by
    ``lhs_col``/``rhs_col``/``rel_col``; no relation column means relation
    ID 0), maps words to partitions and offsets via the dictionaries,
    accumulates all edges in memory grouped by (lhs_part, rhs_part), and
    finally writes every bucket to ``edge_storage``. Edges whose relation
    type or entities are unknown are skipped; a malformed (too-short) line
    aborts with a RuntimeError.
    """
    print(f"Preparing edge path {edge_path_out}, "
          f"out of the edges found in {edge_file_in}")
    edge_storage.prepare()

    num_lhs_parts = max(entities_by_type[rconfig.lhs].num_parts
                        for rconfig in relation_configs)
    num_rhs_parts = max(entities_by_type[rconfig.rhs].num_parts
                        for rconfig in relation_configs)

    print(
        f"- Edges will be partitioned in {num_lhs_parts} x {num_rhs_parts} buckets."
    )

    # NOTE(review): typing.DefaultDict is being *called* as a constructor;
    # collections.defaultdict is the conventional spelling — confirm imports.
    buckets: DefaultDict[Tuple[int, int], List[Tuple[int, int, int]]] = \
        DefaultDict(list)
    processed = 0
    skipped = 0

    with edge_file_in.open("rt") as tf:
        for line_num, line in enumerate(tf, start=1):
            words = line.split()
            try:
                lhs_word = words[lhs_col]
                rhs_word = words[rhs_col]
                rel_word = words[rel_col] if rel_col is not None else None
            except IndexError:
                raise RuntimeError(
                    f"Line {line_num} of {edge_file_in} has only {len(words)} words"
                ) from None

            if rel_col is None:
                rel_id = 0
            else:
                try:
                    rel_id = relation_types.get_id(rel_word)
                except KeyError:
                    # Ignore edges whose relation type is not known.
                    skipped += 1
                    continue

            if dynamic_relations:
                # With dynamic relations all relations share the entity
                # types of the first config entry.
                lhs_type = relation_configs[0].lhs
                rhs_type = relation_configs[0].rhs
            else:
                lhs_type = relation_configs[rel_id].lhs
                rhs_type = relation_configs[rel_id].rhs

            try:
                lhs_part, lhs_offset = \
                    entities_by_type[lhs_type].get_partition(lhs_word)
                rhs_part, rhs_offset = \
                    entities_by_type[rhs_type].get_partition(rhs_word)
            except KeyError:
                # Ignore edges whose entities are not known.
                skipped += 1
                continue

            buckets[lhs_part, rhs_part].append(
                (lhs_offset, rhs_offset, rel_id))

            processed = processed + 1
            if processed % 100000 == 0:
                print(f"- Processed {processed} edges so far...")

    print(f"- Processed {processed} edges in total")
    if skipped > 0:
        print(f"- Skipped {skipped} edges because their relation type or "
              f"entities were unknown (either not given in the config or "
              f"filtered out as too rare).")

    for i in range(num_lhs_parts):
        for j in range(num_rhs_parts):
            print(f"- Writing bucket ({i}, {j}), "
                  f"containing {len(buckets[i, j])} edges...")
            # view((-1, 3)) keeps an empty bucket well-formed as shape (0, 3).
            edges = torch.tensor(buckets[i, j], dtype=torch.long).view((-1, 3))
            edge_storage.save_edges(
                i, j,
                EdgeList(
                    EntityList.from_tensor(edges[:, 0]),
                    EntityList.from_tensor(edges[:, 1]),
                    edges[:, 2],
                ))
Exemple #25
0
    def do_one_job(  # noqa
        self,
        lhs_types: Set[str],
        rhs_types: Set[str],
        lhs_part: Partition,
        rhs_part: Partition,
        lhs_subpart: SubPartition,
        rhs_subpart: SubPartition,
        next_lhs_subpart: Optional[SubPartition],
        next_rhs_subpart: Optional[SubPartition],
        model: MultiRelationEmbedder,
        trainer: Trainer,
        all_embs: Dict[Tuple[EntityName, Partition], FloatTensorType],
        subpart_slices: Dict[Tuple[EntityName, Partition, SubPartition],
                             slice],
        subbuckets: Dict[Tuple[int, int], Tuple[LongTensorType, LongTensorType,
                                                LongTensorType]],
        batch_size: int,
        lr: float,
    ) -> Stats:
        """Train one sub-bucket on this worker's GPU and return its stats.

        Four phases, individually timed via TimeKeeper:
        1. copy_to_device: for each (entity, part, subpart) needed by this
           job and not already resident, copy its embedding slice and
           Adagrad state from pinned CPU memory to the GPU.
        2. translate_edges: shuffle the sub-bucket's pinned edge tensors
           in-place and upload them as one EdgeList.
        3. processing: run training over the edges in batches.
        4. copy_from_device: copy back (and evict) every resident subpart
           that the *next* job will not reuse.
        """
        tk = TimeKeeper()

        # Pinned memory is required for the non-blocking copies below.
        for embeddings in all_embs.values():
            assert embeddings.is_pinned()

        occurrences: Dict[Tuple[EntityName, Partition, SubPartition],
                          Set[Side]] = defaultdict(set)
        for entity_name in lhs_types:
            occurrences[entity_name, lhs_part, lhs_subpart].add(Side.LHS)
        for entity_name in rhs_types:
            occurrences[entity_name, rhs_part, rhs_subpart].add(Side.RHS)

        if lhs_part != rhs_part:  # Bipartite
            assert all(len(v) == 1 for v in occurrences.values())

        tk.start("copy_to_device")
        for entity_name, part, subpart in occurrences.keys():
            # Skip subparts already resident on the GPU from a previous job.
            if (entity_name, part, subpart) in self.sub_holder:
                continue
            embeddings = all_embs[entity_name, part]
            optimizer = trainer.partitioned_optimizers[entity_name, part]
            subpart_slice = subpart_slices[entity_name, part, subpart]

            # TODO have two permanent storages on GPU and move stuff in and out
            # from them
            # logger.info(f"GPU #{self.gpu_idx} allocating {(subpart_slice.stop - subpart_slice.start) * embeddings.shape[1] * 4:,} bytes")
            gpu_embeddings = torch.empty(
                (subpart_slice.stop - subpart_slice.start,
                 embeddings.shape[1]),
                dtype=torch.float32,
                device=self.my_device,
            )
            gpu_embeddings.copy_(embeddings[subpart_slice], non_blocking=True)
            gpu_embeddings = torch.nn.Parameter(gpu_embeddings)
            gpu_optimizer = RowAdagrad([gpu_embeddings], lr=lr)
            # Both optimizers hold state for exactly one parameter group.
            (cpu_state, ) = optimizer.state.values()
            (gpu_state, ) = gpu_optimizer.state.values()
            # logger.info(f"GPU #{self.gpu_idx} allocating {(subpart_slice.stop - subpart_slice.start) * 4:,} bytes")
            gpu_state["sum"].copy_(cpu_state["sum"][subpart_slice],
                                   non_blocking=True)

            self.sub_holder[entity_name, part, subpart] = (
                gpu_embeddings,
                gpu_optimizer,
            )
        logger.debug(
            f"Time spent copying subparts to GPU: {tk.stop('copy_to_device'):.4f} s"
        )

        # Point the model and trainer at the GPU-resident copies.
        for (
            (entity_name, part, subpart),
            (gpu_embeddings, gpu_optimizer),
        ) in self.sub_holder.items():
            for side in occurrences[entity_name, part, subpart]:
                model.set_embeddings(entity_name, side, gpu_embeddings)
                trainer.partitioned_optimizers[entity_name, part,
                                               subpart] = gpu_optimizer

        tk.start("translate_edges")
        num_edges = subbuckets[lhs_subpart, rhs_subpart][0].shape[0]
        # Shuffle the (pinned) edge tensors in-place with a shared
        # permutation so lhs/rhs/rel stay aligned.
        edge_perm = torch.randperm(num_edges)
        edges_lhs, edges_rhs, edges_rel = subbuckets[lhs_subpart, rhs_subpart]
        _C.shuffle(edges_lhs, edge_perm, os.cpu_count())
        _C.shuffle(edges_rhs, edge_perm, os.cpu_count())
        _C.shuffle(edges_rel, edge_perm, os.cpu_count())
        assert edges_lhs.is_pinned()
        assert edges_rhs.is_pinned()
        assert edges_rel.is_pinned()
        gpu_edges = EdgeList(
            EntityList.from_tensor(edges_lhs),
            EntityList.from_tensor(edges_rhs),
            edges_rel,
        ).to(self.my_device, non_blocking=True)
        logger.debug(f"GPU #{self.gpu_idx} got {num_edges} edges")
        logger.debug(
            f"Time spent copying edges to GPU: {tk.stop('translate_edges'):.4f} s"
        )

        tk.start("processing")
        stats = process_in_batches(batch_size=batch_size,
                                   model=model,
                                   batch_processor=trainer,
                                   edges=gpu_edges)
        logger.debug(f"Time spent processing: {tk.stop('processing'):.4f} s")

        # Compute which subparts the *next* job will need, so they can stay
        # resident on the GPU instead of round-tripping through the CPU.
        next_occurrences: Dict[Tuple[EntityName, Partition, SubPartition],
                               Set[Side]] = defaultdict(set)
        if next_lhs_subpart is not None:
            for entity_name in lhs_types:
                next_occurrences[entity_name, lhs_part,
                                 next_lhs_subpart].add(Side.LHS)
        if next_rhs_subpart is not None:
            for entity_name in rhs_types:
                next_occurrences[entity_name, rhs_part,
                                 next_rhs_subpart].add(Side.RHS)

        tk.start("copy_from_device")
        # list(...) because entries are deleted from sub_holder while iterating.
        for (entity_name, part,
             subpart), (gpu_embeddings,
                        gpu_optimizer) in list(self.sub_holder.items()):
            if (entity_name, part, subpart) in next_occurrences:
                continue
            embeddings = all_embs[entity_name, part]
            optimizer = trainer.partitioned_optimizers[entity_name, part]
            subpart_slice = subpart_slices[entity_name, part, subpart]

            embeddings[subpart_slice].data.copy_(gpu_embeddings.detach(),
                                                 non_blocking=True)
            del gpu_embeddings
            (cpu_state, ) = optimizer.state.values()
            (gpu_state, ) = gpu_optimizer.state.values()
            cpu_state["sum"][subpart_slice].copy_(gpu_state["sum"],
                                                  non_blocking=True)
            del gpu_state["sum"]
            del self.sub_holder[entity_name, part, subpart]
        logger.debug(
            f"Time spent copying subparts from GPU: {tk.stop('copy_from_device'):.4f} s"
        )

        logger.debug(
            f"do_one_job: Time unaccounted for: {tk.unaccounted():.4f} s")

        return stats
def generate_edge_path_files(
    edge_file_in: Path,
    edge_path_out: Path,
    edge_storage: AbstractEdgeStorage,
    entities_by_type: Dict[str, Dictionary],
    relation_types: Dictionary,
    relation_configs: List[RelationSchema],
    dynamic_relations: bool,
    edgelist_reader: EdgelistReader,
) -> None:
    """Partition an edgelist into bucketed edge files via streaming appenders.

    Reads (lhs, rhs, rel) word triples from ``edge_file_in`` through
    ``edgelist_reader``, maps them to partitions/offsets, and streams each
    edge to a per-(lhs_part, rhs_part) appender created lazily on first
    use. Edges with unknown relation types or entities are skipped and
    reported at the end.

    NOTE(review): each edge is appended as its own one-element EdgeList;
    batching appends would likely be much faster — confirm whether the
    appenders buffer internally.
    """
    print(f"Preparing edge path {edge_path_out}, "
          f"out of the edges found in {edge_file_in}")
    edge_storage.prepare()

    num_lhs_parts = max(entities_by_type[rconfig.lhs].num_parts
                        for rconfig in relation_configs)
    num_rhs_parts = max(entities_by_type[rconfig.rhs].num_parts
                        for rconfig in relation_configs)

    print(
        f"- Edges will be partitioned in {num_lhs_parts} x {num_rhs_parts} buckets."
    )

    processed = 0
    skipped = 0

    # We use an ExitStack in order to close the dynamically-created edge appenders.
    with ExitStack() as appender_stack:
        appenders: Dict[Tuple[int, int], AbstractEdgeAppender] = {}
        for lhs_word, rhs_word, rel_word in edgelist_reader.read(edge_file_in):
            if rel_word is None:
                rel_id = 0
            else:
                try:
                    rel_id = relation_types.get_id(rel_word)
                except KeyError:
                    # Ignore edges whose relation type is not known.
                    skipped += 1
                    continue

            if dynamic_relations:
                # With dynamic relations all relations share the entity
                # types of the first config entry.
                lhs_type = relation_configs[0].lhs
                rhs_type = relation_configs[0].rhs
            else:
                lhs_type = relation_configs[rel_id].lhs
                rhs_type = relation_configs[rel_id].rhs

            try:
                lhs_part, lhs_offset = \
                    entities_by_type[lhs_type].get_partition(lhs_word)
                rhs_part, rhs_offset = \
                    entities_by_type[rhs_type].get_partition(rhs_word)
            except KeyError:
                # Ignore edges whose entities are not known.
                skipped += 1
                continue

            # Lazily open one appender per bucket; ExitStack closes them all.
            if (lhs_part, rhs_part) not in appenders:
                appenders[lhs_part, rhs_part] = appender_stack.enter_context(
                    edge_storage.save_edges_by_appending(lhs_part, rhs_part))
            appenders[lhs_part, rhs_part].append_edges(
                EdgeList(
                    EntityList.from_tensor(
                        torch.tensor([lhs_offset], dtype=torch.long)),
                    EntityList.from_tensor(
                        torch.tensor([rhs_offset], dtype=torch.long)),
                    torch.tensor([rel_id], dtype=torch.long),
                ))

            processed = processed + 1
            if processed % 100000 == 0:
                print(f"- Processed {processed} edges so far...")

    print(f"- Processed {processed} edges in total")
    if skipped > 0:
        print(f"- Skipped {skipped} edges because their relation type or "
              f"entities were unknown (either not given in the config or "
              f"filtered out as too rare).")