def process_one_batch(
    self,
    model: MultiRelationEmbedder,
    batch_edges: EdgeList,
) -> Stats:
    """Run a single optimization step on one batch of edges.

    Clears the model's gradients, scores ``batch_edges`` with ``model``,
    combines the lhs- and rhs-side losses scaled by the relation's weight,
    back-propagates, and then steps the global optimizer followed by every
    entity optimizer.

    Returns:
        A ``Stats`` record holding the scalar loss, the number of negative
        scores exceeding their positive score on each side, and the number of
        edges in the batch. It is computed before ``backward()`` is called.
    """
    model.zero_grad()
    batch_scores = model(batch_edges)

    left_term = self.loss_fn(batch_scores.lhs_pos, batch_scores.lhs_neg)
    right_term = self.loss_fn(batch_scores.rhs_pos, batch_scores.rhs_neg)

    # Batches without a single scalar relation type use the relation at index 0.
    if batch_edges.has_scalar_relation_type():
        rel_idx = batch_edges.get_relation_type_as_scalar()
    else:
        rel_idx = 0
    rel_config = self.relations[rel_idx]

    total_loss = rel_config.weight * (left_term + right_term)

    # A "violator" is a negative whose score is strictly greater than the
    # score of the positive in the same row (positives broadcast over dim 1).
    left_violators = int(
        (batch_scores.lhs_neg > batch_scores.lhs_pos.unsqueeze(1)).sum())
    right_violators = int(
        (batch_scores.rhs_neg > batch_scores.rhs_pos.unsqueeze(1)).sum())

    batch_stats = Stats(
        loss=float(total_loss),
        violators_lhs=left_violators,
        violators_rhs=right_violators,
        count=len(batch_edges),
    )

    total_loss.backward()
    self.global_optimizer.step(closure=None)
    for entity_opt in self.entity_optimizers.values():
        entity_opt.step(closure=None)

    return batch_stats
def calc_loss(self, scores: Scores, batch_edges: EdgeList):
    """Return the relation-weighted sum of the lhs- and rhs-side losses.

    Both sides are evaluated with ``self.loss_fn(pos, neg, weight)``, where
    ``weight`` is ``batch_edges.weight``. The sum is scaled by
    ``self.relation_weights[i]``. Here ``i`` is the batch's scalar relation
    type when it has one, and ``0`` otherwise.
    """
    edge_weight = batch_edges.weight
    side_pairs = (
        (scores.lhs_pos, scores.lhs_neg),
        (scores.rhs_pos, scores.rhs_neg),
    )
    lhs_term, rhs_term = (
        self.loss_fn(pos, neg, edge_weight) for pos, neg in side_pairs
    )

    rel_idx = 0
    if batch_edges.has_scalar_relation_type():
        rel_idx = batch_edges.get_relation_type_as_scalar()

    return self.relation_weights[rel_idx] * (lhs_term + rhs_term)
def forward(
    self,
    edges: EdgeList,
) -> Scores:
    """Score a batch of positive edges together with their negatives.

    The relation configuration is selected according to the mode:

    * dynamic relations (``self.num_dynamic_rels > 0``): ``edges`` must carry
      a per-edge relation type, and ``self.relations[0]`` is used;
    * static relations: all of ``edges`` must share one scalar relation type,
      which indexes ``self.relations``.

    The negative-sampling method and chunk size are chosen as follows:

    * ``Negatives.ALL`` with chunk size ``len(edges)`` if ``relation.all_negs``;
    * otherwise ``Negatives.UNIFORM`` with chunk size
      ``self.num_uniform_negs`` if ``self.num_batch_negs == 0``;
    * otherwise ``Negatives.BATCH_UNIFORM`` with chunk size
      ``self.num_batch_negs``.

    Args:
        edges: the positive edges to score.

    Returns:
        ``Scores(lhs_pos, rhs_pos, lhs_neg, rhs_neg)``. In static mode, the
        lhs and rhs positive scores are the same object.

    Raises:
        TypeError: if the relation type of ``edges`` does not fit the mode
            (scalar in dynamic mode, or non-scalar in static mode).
        RuntimeError: in static mode, if the relation has a left-hand side
            operator.
    """
    num_pos = len(edges)

    # Declared up front; each is bound in every branch of the if-chain below.
    chunk_size: int
    negative_sampling_method: Negatives

    if self.num_dynamic_rels > 0:
        if edges.has_scalar_relation_type():
            raise TypeError("Need relation for each positive pair")
        relation_idx = 0
    else:
        if not edges.has_scalar_relation_type():
            raise TypeError(
                "All positive pairs must come from the same relation")
        relation_idx = edges.get_relation_type_as_scalar()

    relation = self.relations[relation_idx]
    lhs_module: AbstractEmbedding = self.lhs_embs[self.EMB_PREFIX + relation.lhs]
    rhs_module: AbstractEmbedding = self.rhs_embs[self.EMB_PREFIX + relation.rhs]
    lhs_pos: FloatTensorType = lhs_module(edges.lhs)
    rhs_pos: FloatTensorType = rhs_module(edges.rhs)

    if relation.all_negs:
        chunk_size = num_pos
        negative_sampling_method = Negatives.ALL
    elif self.num_batch_negs == 0:
        chunk_size = self.num_uniform_negs
        negative_sampling_method = Negatives.UNIFORM
    else:
        chunk_size = self.num_batch_negs
        negative_sampling_method = Negatives.BATCH_UNIFORM

    # NOTE(review): the two positional operator arguments passed below
    # (``None`` then an operator) presumably map to (src_operator,
    # dst_operator) of forward_direction_agnostic; confirm against its
    # signature before reordering anything.
    if self.num_dynamic_rels == 0:
        # In this case the operator is only applied to the RHS. This means
        # that an edge (u, r, v) is scored with c(u, f_r(v)), whereas the
        # negatives (u', r, v) and (u, r, v') are scored respectively with
        # c(u', f_r(v)) and c(u, f_r(v')). Since r is always the same, each
        # positive and negative right-hand side entity is only passed once
        # through the operator.

        if self.lhs_operators[relation_idx] is not None:
            raise RuntimeError("In non-dynamic relation mode there should "
                               "be only a right-hand side operator")

        # Apply operator to right-hand side, sample negatives on both sides.
        pos_scores, lhs_neg_scores, rhs_neg_scores = self.forward_direction_agnostic(
            edges.lhs,
            edges.rhs,
            edges.get_relation_type(),
            relation.lhs,
            relation.rhs,
            None,
            self.rhs_operators[relation_idx],
            lhs_module,
            rhs_module,
            lhs_pos,
            rhs_pos,
            chunk_size,
            negative_sampling_method,
            negative_sampling_method,
        )
        lhs_pos_scores = rhs_pos_scores = pos_scores

    else:
        # In this case the positive edges may come from different relations.
        # This makes it inefficient to apply the operators to the negatives
        # in the way we do above, because for a negative edge (u, r, v') we
        # would need to compute f_r(v'), with r being different from the one
        # in any positive pair that has v' on the right-hand side, which
        # could lead to v being passed through many different (potentially
        # all) operators. This would result in a combinatorial explosion.
        # So, instead, we duplicate all operators, creating two versions of
        # them, one for each side, and only allow one of them to be applied
        # at any given time. The edge (u, r, v) can thus be scored in two
        # ways, either as c(g_r(u), v) or as c(u, h_r(v)). The negatives
        # (u', r, v) and (u, r, v') are scored respectively as c(u', h_r(v))
        # and c(g_r(u), v'). This way we only need to perform two operator
        # applications for every positive input edge, one for each side.

        # "Forward" edges: apply operator to rhs, sample negatives on lhs.
        lhs_pos_scores, lhs_neg_scores, _ = self.forward_direction_agnostic(
            edges.lhs,
            edges.rhs,
            edges.get_relation_type(),
            relation.lhs,
            relation.rhs,
            None,
            self.rhs_operators[relation_idx],
            lhs_module,
            rhs_module,
            lhs_pos,
            rhs_pos,
            chunk_size,
            negative_sampling_method,
            Negatives.NONE,
        )
        # "Reverse" edges: apply operator to lhs, sample negatives on rhs.
        rhs_pos_scores, rhs_neg_scores, _ = self.forward_direction_agnostic(
            edges.rhs,
            edges.lhs,
            edges.get_relation_type(),
            relation.rhs,
            relation.lhs,
            None,
            self.lhs_operators[relation_idx],
            rhs_module,
            lhs_module,
            rhs_pos,
            lhs_pos,
            chunk_size,
            negative_sampling_method,
            Negatives.NONE,
        )

    return Scores(lhs_pos_scores, rhs_pos_scores, lhs_neg_scores, rhs_neg_scores)