def test_cat_scalar_different(self):
    """Concatenating lists with *different* scalar relations expands them per edge."""
    def entities(*ids):
        return EntityList.from_tensor(torch.tensor(list(ids), dtype=torch.long))

    def relation(value):
        return torch.tensor(value, dtype=torch.long)

    first = EdgeList(entities(3, 4), entities(0, 2), relation(0))
    second = EdgeList(entities(1, 0), entities(1, 3), relation(1))
    concatenated = EdgeList.cat([first, second])

    # Scalar relations 0 and 1 become the per-edge vector [0, 0, 1, 1].
    expected = EdgeList(
        entities(3, 4, 1, 0),
        entities(0, 2, 1, 3),
        relation([0, 0, 1, 1]),
    )
    self.assertEqual(concatenated, expected)
def test_cat_vector(self):
    """Concatenating lists with per-edge relation vectors concatenates those vectors."""
    def entities(*ids):
        return EntityList.from_tensor(torch.tensor(list(ids), dtype=torch.long))

    def relation(value):
        return torch.tensor(value, dtype=torch.long)

    first = EdgeList(entities(3, 4), entities(0, 2), relation([2, 1]))
    second = EdgeList(entities(1, 0), entities(1, 3), relation([3, 0]))
    concatenated = EdgeList.cat([first, second])

    expected = EdgeList(
        entities(3, 4, 1, 0),
        entities(0, 2, 1, 3),
        relation([2, 1, 3, 0]),
    )
    self.assertEqual(concatenated, expected)
def process_one_batch(
    self,
    model: MultiRelationEmbedder,
    batch_edges: EdgeList,
) -> Stats:
    """Run one optimisation step on ``batch_edges`` and return its statistics.

    The loss is ``relation.weight * (lhs_loss + rhs_loss)``. ``relation`` is
    the batch's scalar relation type when it has one, and relation 0
    otherwise. ``violators_*`` count the negative scores that exceed the
    positive score of their row. Gradients are applied via the global
    optimizer and every entity optimizer.
    """
    model.zero_grad()

    scores = model(batch_edges)
    lhs_loss = self.loss_fn(scores.lhs_pos, scores.lhs_neg)
    rhs_loss = self.loss_fn(scores.rhs_pos, scores.rhs_neg)

    if batch_edges.has_scalar_relation_type():
        relation_idx = batch_edges.get_relation_type_as_scalar()
    else:
        relation_idx = 0
    relation = self.relations[relation_idx]
    loss = relation.weight * (lhs_loss + rhs_loss)

    loss_value = float(loss)
    lhs_violators = (scores.lhs_neg > scores.lhs_pos.unsqueeze(1)).sum()
    rhs_violators = (scores.rhs_neg > scores.rhs_pos.unsqueeze(1)).sum()
    stats = Stats(
        loss=loss_value,
        violators_lhs=int(lhs_violators),
        violators_rhs=int(rhs_violators),
        count=len(batch_edges),
    )

    loss.backward()
    self.global_optimizer.step(closure=None)
    for entity_optimizer in self.entity_optimizers.values():
        entity_optimizer.step(closure=None)

    return stats
def test_empty(self):
    """``EdgeList.empty()`` equals an EdgeList built from empty parts."""
    actual = EdgeList.empty()
    expected = EdgeList(
        EntityList.empty(),
        EntityList.empty(),
        torch.empty((0, ), dtype=torch.long),
    )
    self.assertEqual(actual, expected)
def calc_loss(self, scores: Scores, batch_edges: EdgeList):
    """Return the relation-weighted sum of the LHS and RHS losses of a batch.

    The per-edge ``batch_edges.weight`` is forwarded to ``self.loss_fn``. The
    relation weight comes from the batch's scalar relation type, or from
    relation 0 when the batch has a per-edge relation tensor.
    """
    edge_weight = batch_edges.weight
    lhs_loss = self.loss_fn(scores.lhs_pos, scores.lhs_neg, edge_weight)
    rhs_loss = self.loss_fn(scores.rhs_pos, scores.rhs_neg, edge_weight)

    if batch_edges.has_scalar_relation_type():
        relation = batch_edges.get_relation_type_as_scalar()
    else:
        relation = 0
    return self.relation_weights[relation] * (lhs_loss + rhs_loss)
def test_has_scalar_relation_type(self):
    """A 0-d relation tensor is scalar; a per-edge relation vector is not."""
    def make_edges(rel):
        return EdgeList(
            EntityList.from_tensor(torch.tensor([3, 4], dtype=torch.long)),
            EntityList.from_tensor(torch.tensor([0, 2], dtype=torch.long)),
            rel,
        )

    scalar_edges = make_edges(torch.tensor(3, dtype=torch.long))
    self.assertTrue(scalar_edges.has_scalar_relation_type())

    vector_edges = make_edges(torch.tensor([2, 0], dtype=torch.long))
    self.assertFalse(vector_edges.has_scalar_relation_type())
def test_getitem_int(self):
    """Integer indexing (including negative) selects one edge with a scalar relation."""
    def as_long(values):
        return torch.tensor(values, dtype=torch.long)

    edges = EdgeList(
        EntityList.from_tensor(as_long([3, 4, 1, 0])),
        EntityList.from_tensor(as_long([0, 2, 1, 3])),
        as_long([1, 1, 3, 0]),
    )
    # Index -3 of a 4-edge list is the second edge: (4, 2) with relation 1.
    picked = edges[-3]

    expected = EdgeList(
        EntityList.from_tensor(as_long([4])),
        EntityList.from_tensor(as_long([2])),
        as_long(1),
    )
    self.assertEqual(picked, expected)
def test_getitem_longtensor(self):
    """Indexing with an index tensor gathers the selected edges in that order."""
    def as_long(values):
        return torch.tensor(values, dtype=torch.long)

    edges = EdgeList(
        EntityList.from_tensor(as_long([3, 4, 1, 0])),
        EntityList.from_tensor(as_long([0, 2, 1, 3])),
        as_long([1, 1, 3, 0]),
    )
    picked = edges[torch.tensor([2, 0])]

    expected = EdgeList(
        EntityList.from_tensor(as_long([1, 3])),
        EntityList.from_tensor(as_long([1, 0])),
        as_long([3, 1]),
    )
    self.assertEqual(picked, expected)
def test_basic(self):
    """Batches are homogeneous, bounded by batch_size, and cover every edge."""
    def as_long(values):
        return torch.tensor(values, dtype=torch.long)

    def make_edges(lhs_ids, rhs_ids, rel):
        return EdgeList(
            EntityList.from_tensor(as_long(lhs_ids)),
            EntityList.from_tensor(as_long(rhs_ids)),
            as_long(rel),
        )

    edges = make_edges(
        [93, 24, 13, 31, 70, 66, 77, 38, 5, 5],
        [90, 75, 9, 25, 23, 31, 49, 64, 42, 50],
        [1, 0, 0, 1, 2, 2, 0, 0, 2, 2],
    )

    batches_by_type = defaultdict(list)
    for batch in batch_edges_group_by_relation_type(edges, batch_size=3):
        self.assertIsInstance(batch, EdgeList)
        self.assertLessEqual(len(batch), 3)
        self.assertTrue(batch.has_scalar_relation_type())
        batches_by_type[batch.get_relation_type_as_scalar()].append(batch)

    merged = {
        rel_type: EdgeList.cat(batches)
        for rel_type, batches in batches_by_type.items()
    }
    expected = {
        0: make_edges([24, 13, 77, 38], [75, 9, 49, 64], 0),
        1: make_edges([93, 31], [90, 25], 1),
        2: make_edges([70, 66, 5, 5], [23, 31, 42, 50], 2),
    }
    self.assertEqual(merged, expected)
def test_get_relation_type_as_scalar(self):
    """A scalar relation tensor is reported as the matching value."""
    edges = EdgeList(
        EntityList.from_tensor(torch.tensor([3, 4], dtype=torch.long)),
        EntityList.from_tensor(torch.tensor([0, 2], dtype=torch.long)),
        torch.tensor(3, dtype=torch.long),
    )
    self.assertEqual(edges.get_relation_type_as_scalar(), 3)
def generate_edge_path_files_fast(
    edge_file_in: Path,
    edge_path_out: Path,
    edge_storage: AbstractEdgeStorage,
    entities_by_type: Dict[str, Dictionary],
    relation_types: Dictionary,
    relation_configs: List[RelationSchema],
    edgelist_reader: EdgelistReader,
) -> None:
    """Convert an edge-list file into a single bucket ``(0, 0)`` in one pass.

    Each ``(lhs_word, rhs_word, rel_word)`` from ``edgelist_reader`` is
    mapped to entity offsets and a relation id. A ``None`` relation word maps
    to relation 0. Edges whose relation or entities are unknown are skipped
    and counted. All surviving edges are saved via
    ``edge_storage.save_edges(0, 0, ...)``. If no edge survives, an empty
    edge list is saved instead of crashing.

    Note: ``edge_path_out`` is not used by this function.
    """
    processed = 0
    skipped = 0

    log("Taking the fast train!")
    # Collect columns separately: ``zip(*rows)`` on an empty row list cannot
    # be unpacked into three names and used to raise ValueError when every
    # edge was filtered out.
    lhs_offsets: List[int] = []
    rhs_offsets: List[int] = []
    rel_ids: List[int] = []
    for lhs_word, rhs_word, rel_word in edgelist_reader.read(edge_file_in):
        if rel_word is None:
            rel_id = 0
        else:
            try:
                rel_id = relation_types.get_id(rel_word)
            except KeyError:
                # Ignore edges whose relation type is not known.
                skipped += 1
                continue

        lhs_type = relation_configs[rel_id].lhs
        rhs_type = relation_configs[rel_id].rhs

        try:
            _, lhs_offset = entities_by_type[lhs_type].get_partition(lhs_word)
            _, rhs_offset = entities_by_type[rhs_type].get_partition(rhs_word)
        except KeyError:
            # Ignore edges whose entities are not known.
            skipped += 1
            continue

        lhs_offsets.append(lhs_offset)
        rhs_offsets.append(rhs_offset)
        rel_ids.append(rel_id)

        processed += 1
        if processed % 100000 == 0:
            log(f"- Processed {processed} edges so far...")

    edge_list = EdgeList(
        EntityList.from_tensor(torch.tensor(lhs_offsets, dtype=torch.long)),
        EntityList.from_tensor(torch.tensor(rhs_offsets, dtype=torch.long)),
        torch.tensor(rel_ids, dtype=torch.long),
    )
    edge_storage.save_edges(0, 0, edge_list)

    log(f"- Processed {processed} edges in total")
    if skipped > 0:
        log(
            f"- Skipped {skipped} edges because their relation type or "
            f"entities were unknown (either not given in the config or "
            f"filtered out as too rare)."
        )
def append_to_file(data, appender):
    """Append a batch of ``(lhs_offset, rhs_offset, rel_id)`` rows via ``appender``.

    The rows are packed into an EdgeList of int64 tensors and handed to
    ``appender.append_edges``. An empty batch is a no-op: ``zip(*[])``
    yields nothing, so unpacking it into three columns would otherwise
    raise ValueError.
    """
    if not data:
        return
    lhs_offsets, rhs_offsets, rel_ids = zip(*data)
    appender.append_edges(
        EdgeList(
            EntityList.from_tensor(torch.tensor(lhs_offsets, dtype=torch.long)),
            EntityList.from_tensor(torch.tensor(rhs_offsets, dtype=torch.long)),
            torch.tensor(rel_ids, dtype=torch.long),
        )
    )
def load_chunk_of_edges(
    self,
    lhs_p: Partition,
    rhs_p: Partition,
    chunk_idx: int = 0,
    num_chunks: int = 1,
    shared: bool = False,
) -> EdgeList:
    """Load one contiguous chunk of the edges of bucket ``(lhs_p, rhs_p)``.

    The bucket is divided into ``num_chunks`` slices of
    ``div_roundup(num_edges, num_chunks)`` edges. Trailing slices may be
    shorter or empty. Slice ``chunk_idx`` is returned.

    Args:
        lhs_p, rhs_p: partitions identifying the edge file.
        chunk_idx: which slice to load.
        num_chunks: number of slices the bucket is divided into.
        shared: allocate with ``allocate_shared_tensor`` instead of
            ``torch.empty``.

    Returns:
        An EdgeList with lhs/rhs entities (plus the dynamic parts returned by
        ``read_dynamic``), relation types, and a float32 ``weight`` tensor if
        the file contains a ``weight`` dataset (``None`` otherwise).

    Raises:
        RuntimeError: if the file's format version does not match.
        CouldNotLoadData: if h5py reports the file does not exist (ENOENT).
        OSError: for any other h5py I/O error.
    """
    file_path = self.get_edges_file(lhs_p, rhs_p)
    try:
        with h5py.File(file_path, "r") as hf:
            if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
                raise RuntimeError(
                    f"Version mismatch in edge file {file_path}")
            lhs_ds = hf["lhs"]
            rhs_ds = hf["rhs"]
            rel_ds = hf["rel"]

            num_edges = rel_ds.len()
            chunk_size = div_roundup(num_edges, num_chunks)
            # Clamp ``begin`` too. With rounded-up chunk sizes, trailing chunks
            # can start past the last edge. For example, 5 edges in 4 chunks of
            # 2 makes chunk 3 start at 6. That gave a negative size and a
            # failed allocation; such chunks are now simply empty.
            begin = min(chunk_idx * chunk_size, num_edges)
            end = min(begin + chunk_size, num_edges)
            chunk_size = end - begin

            allocator = allocate_shared_tensor if shared else torch.empty
            lhs = allocator((chunk_size, ), dtype=torch.long)
            rhs = allocator((chunk_size, ), dtype=torch.long)
            rel = allocator((chunk_size, ), dtype=torch.long)

            # Needed because https://github.com/h5py/h5py/issues/870.
            if chunk_size > 0:
                lhs_ds.read_direct(lhs.numpy(), source_sel=np.s_[begin:end])
                rhs_ds.read_direct(rhs.numpy(), source_sel=np.s_[begin:end])
                rel_ds.read_direct(rel.numpy(), source_sel=np.s_[begin:end])

            lhsd = self.read_dynamic(hf, "lhsd", begin, end, shared=shared)
            rhsd = self.read_dynamic(hf, "rhsd", begin, end, shared=shared)

            if "weight" in hf:
                weight_ds = hf["weight"]
                # Edge weights scale the loss and are presumably stored as
                # floats. Reading them into an int64 buffer truncated them.
                weight = allocator((chunk_size, ), dtype=torch.float32)
                if chunk_size > 0:
                    weight_ds.read_direct(weight.numpy(),
                                          source_sel=np.s_[begin:end])
            else:
                weight = None

            return EdgeList(EntityList(lhs, lhsd), EntityList(rhs, rhsd),
                            rel, weight)
    except OSError as err:
        # h5py refuses to make it easy to figure out what went wrong. The errno
        # attribute is set to None. See https://github.com/h5py/h5py/issues/493.
        if f"errno = {errno.ENOENT}" in str(err):
            raise CouldNotLoadData() from err
        raise err
def test_basic(self):
    """Mixed-relation batching keeps input order and slices into batch_size pieces."""
    def make_edges(lhs_ids, rhs_ids, rel):
        return EdgeList(
            EntityList.from_tensor(torch.tensor(lhs_ids, dtype=torch.long)),
            EntityList.from_tensor(torch.tensor(rhs_ids, dtype=torch.long)),
            torch.tensor(rel, dtype=torch.long),
        )

    edges = make_edges(
        [93, 24, 13, 31, 70, 66, 77, 38, 5, 5],
        [90, 75, 9, 25, 23, 31, 49, 64, 42, 50],
        [1, 0, 0, 1, 2, 2, 0, 0, 2, 2],
    )
    batches = list(batch_edges_mix_relation_types(edges, batch_size=4))

    expected = [
        make_edges([93, 24, 13, 31], [90, 75, 9, 25], [1, 0, 0, 1]),
        make_edges([70, 66, 77, 38], [23, 31, 49, 64], [2, 2, 0, 0]),
        make_edges([5, 5], [42, 50], [2, 2]),
    ]
    self.assertEqual(batches, expected)
def test_empty(self):
    """Grouping an edge list with no edges yields no groups."""
    no_edges = EdgeList(
        EntityList.empty(),
        EntityList.empty(),
        torch.empty((0, ), dtype=torch.long),
    )
    self.assertEqual(group_by_relation_type(no_edges), [])
def group_by_relation_type(edges: EdgeList) -> List[EdgeList]:
    """Split the edge list in groups that have the same relation type.

    Groups are returned in increasing order of relation type. Each group
    carries its relation type as a 0-d tensor. Within a group, edges keep
    their relative order from ``edges``, so a prior random shuffle is not
    disturbed. An empty input gives ``[]``. An input whose relation type is
    already scalar is returned as the single group ``[edges]``.
    """
    if len(edges) == 0:
        return []
    if edges.has_scalar_relation_type():
        return [edges]

    num_edges = len(edges)
    # ``Tensor.sort`` is not stable by default, so equal relation types could
    # come back in arbitrary order. Sorting on a key that is unique per edge
    # (relation type first, original position as tie-breaker) makes the order
    # deterministic and order-preserving on every PyTorch version. This
    # assumes |rel| * num_edges fits in int64.
    rel = edges.rel.long()
    positions = torch.arange(num_edges, dtype=torch.long, device=rel.device)
    _, order = (rel * num_edges + positions).sort()
    sorted_rel = edges.rel[order]

    delta = sorted_rel[1:] - sorted_rel[:-1]
    cutpoints = (delta.nonzero().flatten() + 1).tolist()

    result: List[EdgeList] = []
    for start, end in zip([0] + cutpoints, cutpoints + [num_edges]):
        rel_type = sorted_rel[start]
        edges_for_rel_type = edges[order[start:end]]
        result.append(
            EdgeList(edges_for_rel_type.lhs, edges_for_rel_type.rhs, rel_type))
    return result
def append_edges(self, edgelist: EdgeList) -> None:
    """Append every component of ``edgelist`` to the underlying storage.

    ``lhs``, ``rhs`` and ``rel`` are always appended. The dynamic
    ``lhsd``/``rhsd`` tensor lists are appended only when their ``data`` is
    non-empty. ``weight`` is appended only when the edge list has weights.
    """
    lhs_entities, rhs_entities = edgelist.lhs, edgelist.rhs
    self.append_tensor("lhs", lhs_entities.tensor)
    self.append_tensor("rhs", rhs_entities.tensor)
    self.append_tensor("rel", edgelist.rel)

    for dataset_name, entities in (("lhsd", lhs_entities), ("rhsd", rhs_entities)):
        dynamic_part = entities.tensor_list
        if len(dynamic_part.data) != 0:
            self.append_tensor_list(dataset_name, dynamic_part)

    if edgelist.has_weight():
        self.append_tensor("weight", edgelist.weight)
def test_constant(self):
    """A constant relation vector collapses into one group with a scalar relation."""
    def entities(ids):
        return EntityList.from_tensor(torch.tensor(ids, dtype=torch.long))

    edges = EdgeList(
        entities([93, 24, 13, 31]),
        entities([90, 75, 9, 25]),
        torch.tensor([3, 3, 3, 3], dtype=torch.long),
    )
    groups = group_by_relation_type(edges)

    expected = [
        EdgeList(
            entities([93, 24, 13, 31]),
            entities([90, 75, 9, 25]),
            torch.tensor(3, dtype=torch.long),
        ),
    ]
    self.assertEqual(groups, expected)
def test_constructor_checks(self):
    """The EdgeList constructor rejects inconsistent component shapes."""
    invalid_cases = [
        # lhs has three entities but rhs only one.
        ([3, 4, 0], [2], 1),
        # relation vector length differs from the number of edges.
        ([3, 4], [0, 2], [1]),
        # relation tensor is two-dimensional.
        ([3, 4], [0, 2], [[1]]),
    ]
    for lhs_ids, rhs_ids, rel in invalid_cases:
        with self.assertRaises(ValueError):
            EdgeList(
                EntityList.from_tensor(torch.tensor(lhs_ids, dtype=torch.long)),
                EntityList.from_tensor(torch.tensor(rhs_ids, dtype=torch.long)),
                torch.tensor(rel, dtype=torch.long),
            )
def test_get_relation_type(self):
    """get_relation_type returns the scalar relation or the per-edge tensor."""
    def make_edges(rel):
        return EdgeList(
            EntityList.from_tensor(torch.tensor([3, 4], dtype=torch.long)),
            EntityList.from_tensor(torch.tensor([0, 2], dtype=torch.long)),
            rel,
        )

    scalar_edges = make_edges(torch.tensor(3, dtype=torch.long))
    self.assertEqual(scalar_edges.get_relation_type(), 3)

    vector_edges = make_edges(torch.tensor([2, 0], dtype=torch.long))
    self.assertTrue(
        torch.equal(
            vector_edges.get_relation_type(),
            torch.tensor([2, 0], dtype=torch.long),
        ))
def append_to_file(data, appender):
    """Append a batch of ``(lhs_offset, rhs_offset, rel_id, weight)`` rows via ``appender``.

    Offsets and relation ids become int64 tensors. If the first row's
    weight is ``None`` the batch is treated as unweighted; otherwise all
    weights are packed with ``torch.tensor``. An empty batch is a no-op:
    ``zip(*[])`` yields nothing, so unpacking it into four columns would
    otherwise raise ValueError.
    """
    if not data:
        return
    lhs_offsets, rhs_offsets, rel_ids, weights = zip(*data)
    weights = torch.tensor(weights) if weights[0] is not None else None
    appender.append_edges(
        EdgeList(
            EntityList.from_tensor(torch.tensor(lhs_offsets, dtype=torch.long)),
            EntityList.from_tensor(torch.tensor(rhs_offsets, dtype=torch.long)),
            torch.tensor(rel_ids, dtype=torch.long),
            weights,
        ))
def test_len(self):
    """len() of an EdgeList is its number of edges."""
    edges = EdgeList(
        EntityList.from_tensor(torch.tensor([3, 4], dtype=torch.long)),
        EntityList.from_tensor(torch.tensor([0, 2], dtype=torch.long)),
        torch.tensor([2, 0], dtype=torch.long),
    )
    self.assertEqual(len(edges), 2)
def test_equal(self):
    """Equality compares lhs, rhs and relation component-wise."""
    def make_edges(lhs_ids, rhs_ids, rel):
        return EdgeList(
            EntityList.from_tensor(torch.tensor(lhs_ids, dtype=torch.long)),
            EntityList.from_tensor(torch.tensor(rhs_ids, dtype=torch.long)),
            torch.tensor(rel, dtype=torch.long),
        )

    el = make_edges([3, 4], [0, 2], [2, 0])
    self.assertEqual(el, el)
    # Swapping the two sides must break equality.
    self.assertNotEqual(el, make_edges([0, 2], [3, 4], [2, 0]))
    # Replacing the per-edge relations by a scalar must break equality.
    self.assertNotEqual(el, make_edges([3, 4], [0, 2], 1))
def test_basic(self):
    """Groups come out sorted by relation, each preserving input order."""
    def make_edges(lhs_ids, rhs_ids, rel):
        return EdgeList(
            EntityList.from_tensor(torch.tensor(lhs_ids, dtype=torch.long)),
            EntityList.from_tensor(torch.tensor(rhs_ids, dtype=torch.long)),
            torch.tensor(rel, dtype=torch.long),
        )

    edges = make_edges(
        [93, 24, 13, 31, 70, 66, 77, 38, 5, 5],
        [90, 75, 9, 25, 23, 31, 49, 64, 42, 50],
        [1, 0, 0, 1, 2, 2, 0, 0, 2, 2],
    )
    groups = group_by_relation_type(edges)

    expected = [
        make_edges([24, 13, 77, 38], [75, 9, 49, 64], 0),
        make_edges([93, 31], [90, 25], 1),
        make_edges([70, 66, 5, 5], [23, 31, 42, 50], 2),
    ]
    self.assertEqual(groups, expected)
def load_chunk_of_edges(
    self,
    lhs_p: int,
    rhs_p: int,
    chunk_idx: int = 0,
    num_chunks: int = 1,
) -> EdgeList:
    """Load the ``chunk_idx``-th of ``num_chunks`` contiguous slices of a bucket.

    Slice boundaries are ``floor(i * num_edges / num_chunks)``, so
    consecutive chunks tile the bucket exactly.

    Raises:
        RuntimeError: if the file's format version does not match.
        CouldNotLoadData: if h5py reports the file does not exist (ENOENT).
        OSError: for any other h5py I/O error.
    """
    file_path = self.get_edges_file(lhs_p, rhs_p)
    try:
        with h5py.File(file_path, "r") as hf:
            if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
                raise RuntimeError(
                    f"Version mismatch in edge file {file_path}")
            lhs_ds = hf["lhs"]
            rhs_ds = hf["rhs"]
            rel_ds = hf["rel"]

            num_edges = rel_ds.len()
            # Exact integer floor division. Float division can round across
            # an integer boundary once the product exceeds 2**52.
            begin = chunk_idx * num_edges // num_chunks
            end = (chunk_idx + 1) * num_edges // num_chunks
            chunk_size = end - begin

            lhs = torch.empty((chunk_size, ), dtype=torch.long)
            rhs = torch.empty((chunk_size, ), dtype=torch.long)
            rel = torch.empty((chunk_size, ), dtype=torch.long)

            # Needed because https://github.com/h5py/h5py/issues/870.
            if chunk_size > 0:
                lhs_ds.read_direct(lhs.numpy(), source_sel=np.s_[begin:end])
                rhs_ds.read_direct(rhs.numpy(), source_sel=np.s_[begin:end])
                rel_ds.read_direct(rel.numpy(), source_sel=np.s_[begin:end])

            lhsd = self.read_dynamic(hf, "lhsd", begin, end)
            rhsd = self.read_dynamic(hf, "rhsd", begin, end)

            return EdgeList(EntityList(lhs, lhsd), EntityList(rhs, rhsd), rel)
    except OSError as err:
        # h5py refuses to make it easy to figure out what went wrong. The errno
        # attribute is set to None. See https://github.com/h5py/h5py/issues/493.
        if f"errno = {errno.ENOENT}" in str(err):
            raise CouldNotLoadData() from err
        raise err
def read(
    self,
    lhs_p: Partition,
    rhs_p: Partition,
    chunk_idx: int = 0,
    num_chunks: int = 1,
) -> EdgeList:
    """Read the ``chunk_idx``-th of ``num_chunks`` slices of ``edges_<lhs>_<rhs>.h5``.

    Files lacking a format-version attribute are read with a warning. A
    mismatching version raises RuntimeError.

    Raises:
        RuntimeError: if the file is missing or has a different format version.
    """
    file_path = os.path.join(self.path, "edges_%d_%d.h5" % (lhs_p, rhs_p))
    # An explicit check rather than ``assert``, which is stripped under -O.
    if not os.path.exists(file_path):
        raise RuntimeError("%s does not exist" % file_path)
    with h5py.File(file_path, 'r') as hf:
        if FORMAT_VERSION_ATTR not in hf.attrs:
            log("WARNING: It may be that one of your edge paths contains "
                "files using the old format. See D14241362 for how to "
                "update them.")
        elif hf.attrs[FORMAT_VERSION_ATTR] != FORMAT_VERSION:
            raise RuntimeError("Version mismatch in edge file %s" % file_path)
        lhs_ds = hf['lhs']
        rhs_ds = hf['rhs']
        rel_ds = hf['rel']

        num_edges = rel_ds.len()
        # Exact integer floor division; float division can misround for
        # very large products.
        begin = chunk_idx * num_edges // num_chunks
        end = (chunk_idx + 1) * num_edges // num_chunks
        chunk_size = end - begin

        lhs = torch.empty((chunk_size,), dtype=torch.long)
        rhs = torch.empty((chunk_size,), dtype=torch.long)
        rel = torch.empty((chunk_size,), dtype=torch.long)

        # Needed because https://github.com/h5py/h5py/issues/870.
        if chunk_size > 0:
            lhs_ds.read_direct(lhs.numpy(), source_sel=np.s_[begin:end])
            rhs_ds.read_direct(rhs.numpy(), source_sel=np.s_[begin:end])
            rel_ds.read_direct(rel.numpy(), source_sel=np.s_[begin:end])

        lhsd = self.read_dynamic(hf, 'lhsd', begin, end)
        rhsd = self.read_dynamic(hf, 'rhsd', begin, end)

        return EdgeList(EntityList(lhs, lhsd), EntityList(rhs, rhsd), rel)
def load_chunk_of_edges(
    self,
    lhs_p: int,
    rhs_p: int,
    chunk_idx: int = 0,
    num_chunks: int = 1,
) -> EdgeList:
    """Load the ``chunk_idx``-th of ``num_chunks`` contiguous slices of a bucket.

    Raises:
        RuntimeError: if the edge file is missing or its format version
            does not match.
    """
    file_path = self.get_edges_file(lhs_p, rhs_p)
    if not file_path.is_file():
        raise RuntimeError(f"{file_path} does not exist")
    with h5py.File(file_path, 'r') as hf:
        if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
            raise RuntimeError(
                f"Version mismatch in edge file {file_path}")
        lhs_ds = hf['lhs']
        rhs_ds = hf['rhs']
        rel_ds = hf['rel']

        num_edges = rel_ds.len()
        # Exact integer floor division keeps chunk boundaries precise
        # regardless of file size.
        begin = chunk_idx * num_edges // num_chunks
        end = (chunk_idx + 1) * num_edges // num_chunks
        chunk_size = end - begin

        lhs = torch.empty((chunk_size, ), dtype=torch.long)
        rhs = torch.empty((chunk_size, ), dtype=torch.long)
        rel = torch.empty((chunk_size, ), dtype=torch.long)

        # Needed because https://github.com/h5py/h5py/issues/870.
        if chunk_size > 0:
            lhs_ds.read_direct(lhs.numpy(), source_sel=np.s_[begin:end])
            rhs_ds.read_direct(rhs.numpy(), source_sel=np.s_[begin:end])
            rel_ds.read_direct(rel.numpy(), source_sel=np.s_[begin:end])

        lhsd = self.read_dynamic(hf, 'lhsd', begin, end)
        rhsd = self.read_dynamic(hf, 'rhsd', begin, end)

        return EdgeList(EntityList(lhs, lhsd), EntityList(rhs, rhsd), rel)
def read(
    self,
    lhs_p: Partition,
    rhs_p: Partition,
    chunk_idx: int = 0,
    num_chunks: int = 1,
) -> EdgeList:
    """Read the ``chunk_idx``-th of ``num_chunks`` slices of ``edges_<lhs>_<rhs>.h5``.

    Raises:
        RuntimeError: if the file is missing or its format version does
            not match.
    """
    file_path = os.path.join(self.path, f"edges_{lhs_p}_{rhs_p}.h5")
    # An explicit check rather than ``assert``, which is stripped under -O.
    if not os.path.exists(file_path):
        raise RuntimeError("%s does not exist" % file_path)
    with h5py.File(file_path, 'r') as hf:
        if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
            raise RuntimeError("Version mismatch in edge file %s" % file_path)
        lhs_ds = hf['lhs']
        rhs_ds = hf['rhs']
        rel_ds = hf['rel']

        num_edges = rel_ds.len()
        # Exact integer floor division; float division can misround for
        # very large products.
        begin = chunk_idx * num_edges // num_chunks
        end = (chunk_idx + 1) * num_edges // num_chunks
        chunk_size = end - begin

        lhs = torch.empty((chunk_size, ), dtype=torch.long)
        rhs = torch.empty((chunk_size, ), dtype=torch.long)
        rel = torch.empty((chunk_size, ), dtype=torch.long)

        # Needed because https://github.com/h5py/h5py/issues/870.
        if chunk_size > 0:
            lhs_ds.read_direct(lhs.numpy(), source_sel=np.s_[begin:end])
            rhs_ds.read_direct(rhs.numpy(), source_sel=np.s_[begin:end])
            rel_ds.read_direct(rel.numpy(), source_sel=np.s_[begin:end])

        lhsd = self.read_dynamic(hf, 'lhsd', begin, end)
        rhsd = self.read_dynamic(hf, 'rhsd', begin, end)

        return EdgeList(EntityList(lhs, lhsd), EntityList(rhs, rhsd), rel)
def forward(
    self,
    edges: EdgeList,
) -> Scores:
    """Score a batch of positive edges together with their sampled negatives.

    Returns ``Scores(lhs_pos, rhs_pos, lhs_neg, rhs_neg)``.

    When ``self.num_dynamic_rels == 0`` the batch must have a single (scalar)
    relation type, and that relation's operator and entity types are used.
    Otherwise the batch must carry a per-edge relation tensor; the entity
    embedding modules and operators at index 0 are used, with the per-edge
    relation types passed to ``forward_direction_agnostic``.

    The negative sampling method depends on the relation and configuration:
    ``relation.all_negs`` gives ALL, ``num_batch_negs == 0`` gives UNIFORM,
    and anything else gives BATCH_UNIFORM.

    Raises:
        TypeError: if the relation-type shape of ``edges`` does not match the
            dynamic/non-dynamic mode.
        RuntimeError: in non-dynamic mode, if a left-hand side operator is set
            for the relation.
    """
    num_pos = len(edges)

    chunk_size: int
    # NOTE(review): the four declarations below are not assigned or read
    # anywhere in this method body.
    lhs_negatives: Negatives
    lhs_num_uniform_negs: int
    rhs_negatives: Negatives
    rhs_num_uniform_negs: int

    if self.num_dynamic_rels > 0:
        if edges.has_scalar_relation_type():
            raise TypeError("Need relation for each positive pair")
        relation_idx = 0
    else:
        if not edges.has_scalar_relation_type():
            raise TypeError(
                "All positive pairs must come from the same relation")
        relation_idx = edges.get_relation_type_as_scalar()

    relation = self.relations[relation_idx]
    lhs_module: AbstractEmbedding = self.lhs_embs[self.EMB_PREFIX + relation.lhs]
    rhs_module: AbstractEmbedding = self.rhs_embs[self.EMB_PREFIX + relation.rhs]
    # Embeddings of the positive endpoints.
    lhs_pos: FloatTensorType = lhs_module(edges.lhs)
    rhs_pos: FloatTensorType = rhs_module(edges.rhs)

    if relation.all_negs:
        chunk_size = num_pos
        negative_sampling_method = Negatives.ALL
    elif self.num_batch_negs == 0:
        chunk_size = self.num_uniform_negs
        negative_sampling_method = Negatives.UNIFORM
    else:
        chunk_size = self.num_batch_negs
        negative_sampling_method = Negatives.BATCH_UNIFORM

    if self.num_dynamic_rels == 0:
        # In this case the operator is only applied to the RHS. This means
        # that an edge (u, r, v) is scored with c(u, f_r(v)), whereas the
        # negatives (u', r, v) and (u, r, v') are scored respectively with
        # c(u', f_r(v)) and c(u, f_r(v')). Since r is always the same, each
        # positive and negative right-hand side entity is only passed once
        # through the operator.

        if self.lhs_operators[relation_idx] is not None:
            raise RuntimeError("In non-dynamic relation mode there should "
                               "be only a right-hand side operator")

        # Apply operator to right-hand side, sample negatives on both sides.
        pos_scores, lhs_neg_scores, rhs_neg_scores = self.forward_direction_agnostic(
            edges.lhs,
            edges.rhs,
            edges.get_relation_type(),
            relation.lhs,
            relation.rhs,
            None,
            self.rhs_operators[relation_idx],
            lhs_module,
            rhs_module,
            lhs_pos,
            rhs_pos,
            chunk_size,
            negative_sampling_method,
            negative_sampling_method,
        )
        # A single operator application scores positives for both sides.
        lhs_pos_scores = rhs_pos_scores = pos_scores

    else:
        # In this case the positive edges may come from different relations.
        # This makes it inefficient to apply the operators to the negatives
        # in the way we do above, because for a negative edge (u, r, v') we
        # would need to compute f_r(v'), with r being different from the one
        # in any positive pair that has v' on the right-hand side, which
        # could lead to v being passed through many different (potentially
        # all) operators. This would result in a combinatorial explosion.
        # So, instead, we duplicate all operators, creating two versions of
        # them, one for each side, and only allow one of them to be applied
        # at any given time. The edge (u, r, v) can thus be scored in two
        # ways, either as c(g_r(u), v) or as c(u, h_r(v)). The negatives
        # (u', r, v) and (u, r, v') are scored respectively as c(u', h_r(v))
        # and c(g_r(u), v'). This way we only need to perform two operator
        # applications for every positive input edge, one for each side.

        # "Forward" edges: apply operator to rhs, sample negatives on lhs.
        lhs_pos_scores, lhs_neg_scores, _ = self.forward_direction_agnostic(
            edges.lhs,
            edges.rhs,
            edges.get_relation_type(),
            relation.lhs,
            relation.rhs,
            None,
            self.rhs_operators[relation_idx],
            lhs_module,
            rhs_module,
            lhs_pos,
            rhs_pos,
            chunk_size,
            negative_sampling_method,
            Negatives.NONE,
        )
        # "Reverse" edges: apply operator to lhs, sample negatives on rhs.
        # Every side-specific argument is swapped relative to the call above,
        # so the lhs operator is passed in the "rhs operator" slot.
        rhs_pos_scores, rhs_neg_scores, _ = self.forward_direction_agnostic(
            edges.rhs,
            edges.lhs,
            edges.get_relation_type(),
            relation.rhs,
            relation.lhs,
            None,
            self.lhs_operators[relation_idx],
            rhs_module,
            lhs_module,
            rhs_pos,
            lhs_pos,
            chunk_size,
            negative_sampling_method,
            Negatives.NONE,
        )

    return Scores(lhs_pos_scores, rhs_pos_scores, lhs_neg_scores, rhs_neg_scores)
def do_one_job(  # noqa
    self,
    lhs_types: Set[str],
    rhs_types: Set[str],
    lhs_part: Partition,
    rhs_part: Partition,
    lhs_subpart: SubPartition,
    rhs_subpart: SubPartition,
    next_lhs_subpart: Optional[SubPartition],
    next_rhs_subpart: Optional[SubPartition],
    model: MultiRelationEmbedder,
    trainer: Trainer,
    all_embs: Dict[Tuple[EntityName, Partition], FloatTensorType],
    subpart_slices: Dict[Tuple[EntityName, Partition, SubPartition], slice],
    subbuckets: Dict[Tuple[int, int], Tuple[LongTensorType, LongTensorType, LongTensorType]],
    batch_size: int,
    lr: float,
) -> Stats:
    """Train on sub-bucket ``(lhs_subpart, rhs_subpart)`` on this GPU and return its stats.

    Steps, in order:
      1. Copy every needed embedding sub-partition and its RowAdagrad
         ``sum`` state to ``self.my_device``. Ones already held in
         ``self.sub_holder`` (e.g. kept from the previous job) are reused.
      2. Point the model's embeddings and the trainer's per-subpart
         optimizers at those device copies.
      3. Shuffle the sub-bucket's edges with one shared permutation and
         move them to the device.
      4. Run ``process_in_batches`` with ``trainer`` as batch processor.
      5. Copy back to the pinned CPU tensors every held sub-partition that
         the next job (``next_lhs_subpart``/``next_rhs_subpart``) does not
         need, and drop it from ``self.sub_holder``. The others stay on the
         device.
    """
    tk = TimeKeeper()

    # Non-blocking host<->device copies below rely on pinned host memory.
    for embeddings in all_embs.values():
        assert embeddings.is_pinned()

    # Which sides each (entity, partition, subpartition) appears on in this job.
    occurrences: Dict[Tuple[EntityName, Partition, SubPartition], Set[Side]] = defaultdict(set)
    for entity_name in lhs_types:
        occurrences[entity_name, lhs_part, lhs_subpart].add(Side.LHS)
    for entity_name in rhs_types:
        occurrences[entity_name, rhs_part, rhs_subpart].add(Side.RHS)

    if lhs_part != rhs_part:  # Bipartite
        assert all(len(v) == 1 for v in occurrences.values())

    tk.start("copy_to_device")
    for entity_name, part, subpart in occurrences.keys():
        if (entity_name, part, subpart) in self.sub_holder:
            # Already resident on the device.
            continue
        embeddings = all_embs[entity_name, part]
        optimizer = trainer.partitioned_optimizers[entity_name, part]
        subpart_slice = subpart_slices[entity_name, part, subpart]

        # TODO have two permanent storages on GPU and move stuff in and out
        # from them
        # logger.info(f"GPU #{self.gpu_idx} allocating {(subpart_slice.stop - subpart_slice.start) * embeddings.shape[1] * 4:,} bytes")
        gpu_embeddings = torch.empty(
            (subpart_slice.stop - subpart_slice.start, embeddings.shape[1]),
            dtype=torch.float32,
            device=self.my_device,
        )
        gpu_embeddings.copy_(embeddings[subpart_slice], non_blocking=True)
        gpu_embeddings = torch.nn.Parameter(gpu_embeddings)
        gpu_optimizer = RowAdagrad([gpu_embeddings], lr=lr)
        # Both optimizers are expected to hold exactly one state entry; the
        # tuple unpacking below fails otherwise.
        (cpu_state, ) = optimizer.state.values()
        (gpu_state, ) = gpu_optimizer.state.values()
        # logger.info(f"GPU #{self.gpu_idx} allocating {(subpart_slice.stop - subpart_slice.start) * 4:,} bytes")
        gpu_state["sum"].copy_(cpu_state["sum"][subpart_slice], non_blocking=True)

        self.sub_holder[entity_name, part, subpart] = (
            gpu_embeddings,
            gpu_optimizer,
        )
    logger.debug(
        f"Time spent copying subparts to GPU: {tk.stop('copy_to_device'):.4f} s"
    )

    # Wire the device copies into the model and trainer. Held subparts not
    # used by this job have an empty side set (defaultdict) and are skipped.
    for (
        (entity_name, part, subpart),
        (gpu_embeddings, gpu_optimizer),
    ) in self.sub_holder.items():
        for side in occurrences[entity_name, part, subpart]:
            model.set_embeddings(entity_name, side, gpu_embeddings)
            trainer.partitioned_optimizers[entity_name, part, subpart] = gpu_optimizer

    tk.start("translate_edges")
    num_edges = subbuckets[lhs_subpart, rhs_subpart][0].shape[0]
    edge_perm = torch.randperm(num_edges)
    edges_lhs, edges_rhs, edges_rel = subbuckets[lhs_subpart, rhs_subpart]
    # The same permutation is applied to all three tensors so edges stay
    # aligned. Presumably _C.shuffle permutes in place — confirm.
    _C.shuffle(edges_lhs, edge_perm, os.cpu_count())
    _C.shuffle(edges_rhs, edge_perm, os.cpu_count())
    _C.shuffle(edges_rel, edge_perm, os.cpu_count())
    assert edges_lhs.is_pinned()
    assert edges_rhs.is_pinned()
    assert edges_rel.is_pinned()
    gpu_edges = EdgeList(
        EntityList.from_tensor(edges_lhs),
        EntityList.from_tensor(edges_rhs),
        edges_rel,
    ).to(self.my_device, non_blocking=True)
    logger.debug(f"GPU #{self.gpu_idx} got {num_edges} edges")
    logger.debug(
        f"Time spent copying edges to GPU: {tk.stop('translate_edges'):.4f} s"
    )

    tk.start("processing")
    stats = process_in_batches(batch_size=batch_size, model=model, batch_processor=trainer, edges=gpu_edges)
    logger.debug(f"Time spent processing: {tk.stop('processing'):.4f} s")

    # Subparts the next job will need; these are left on the device.
    next_occurrences: Dict[Tuple[EntityName, Partition, SubPartition], Set[Side]] = defaultdict(set)
    if next_lhs_subpart is not None:
        for entity_name in lhs_types:
            next_occurrences[entity_name, lhs_part, next_lhs_subpart].add(Side.LHS)
    if next_rhs_subpart is not None:
        for entity_name in rhs_types:
            next_occurrences[entity_name, rhs_part, next_rhs_subpart].add(Side.RHS)

    tk.start("copy_from_device")
    # Iterate over a snapshot because entries are deleted inside the loop.
    for (entity_name, part, subpart), (gpu_embeddings, gpu_optimizer) in list(self.sub_holder.items()):
        if (entity_name, part, subpart) in next_occurrences:
            continue
        embeddings = all_embs[entity_name, part]
        optimizer = trainer.partitioned_optimizers[entity_name, part]
        subpart_slice = subpart_slices[entity_name, part, subpart]

        embeddings[subpart_slice].data.copy_(gpu_embeddings.detach(), non_blocking=True)
        del gpu_embeddings
        (cpu_state, ) = optimizer.state.values()
        (gpu_state, ) = gpu_optimizer.state.values()
        cpu_state["sum"][subpart_slice].copy_(gpu_state["sum"], non_blocking=True)
        del gpu_state["sum"]
        # NOTE(review): gpu_optimizer is still referenced from
        # trainer.partitioned_optimizers[entity_name, part, subpart] (set
        # above), and the model may still hold the Parameter via
        # set_embeddings. Whether that keeps the device embeddings alive
        # after this del depends on code not shown here — confirm.
        del self.sub_holder[entity_name, part, subpart]
    logger.debug(
        f"Time spent copying subparts from GPU: {tk.stop('copy_from_device'):.4f} s"
    )

    logger.debug(
        f"do_one_job: Time unaccounted for: {tk.unaccounted():.4f} s")

    return stats