def _process_one_batch(
    self, model: MultiRelationEmbedder, batch_edges: EdgeList
) -> Stats:
    """Run one optimization step on a single batch of edges.

    Gradients are cleared, the batch is scored, the loss is computed via
    ``self.calc_loss`` and statistics are recorded. The loss (plus the
    regularization term, when the model returns one) is then
    back-propagated, and every optimizer owned by the trainer is stepped.

    Args:
        model: Embedder that, called on ``batch_edges``, returns a
            ``(scores, reg)`` pair; ``reg`` may be ``None``.
        batch_edges: The edges forming this batch.

    Returns:
        ``Stats`` holding:
        - the loss value;
        - the regularization value (0.0 when ``reg`` is ``None``);
        - the number of lhs/rhs negative scores strictly greater than
          their corresponding positive score;
        - the batch size.
    """

    def count_violators(positives, negatives) -> int:
        # unsqueeze(1) lines each positive up against its row of negatives.
        # Presumably positives are (batch,) and negatives are
        # (batch, num_negs); TODO confirm.
        return int((negatives > positives.unsqueeze(1)).sum())

    model.zero_grad()
    scores, reg = model(batch_edges)
    loss = self.calc_loss(scores, batch_edges)

    loss_value = float(loss)
    reg_value = 0.0 if reg is None else float(reg)
    lhs_violators = count_violators(scores.lhs_pos, scores.lhs_neg)
    rhs_violators = count_violators(scores.rhs_pos, scores.rhs_neg)
    stats = Stats(
        loss=loss_value,
        reg=reg_value,
        violators_lhs=lhs_violators,
        violators_rhs=rhs_violators,
        count=len(batch_edges),
    )

    # Back-propagate through the regularizer too, but only when one exists.
    objective = loss if reg is None else loss + reg
    objective.backward()

    self.model_optimizer.step(closure=None)
    for optimizer_group in (
        self.unpartitioned_optimizers,
        self.partitioned_optimizers,
    ):
        for optimizer in optimizer_group.values():
            optimizer.step(closure=None)

    return stats
def process_one_batch(
    self,
    model: MultiRelationEmbedder,
    batch_edges: EdgeList,
) -> Stats:
    """Run one optimization step on a single batch of edges.

    The batch is scored, and ``self.loss_fn`` is applied separately to the
    lhs and rhs positive/negative scores. The summed loss is scaled by a
    relation weight, back-propagated, and the global and per-entity
    optimizers are stepped.

    Args:
        model: Embedder that, called on ``batch_edges``, returns scores
            exposing ``lhs_pos``, ``lhs_neg``, ``rhs_pos`` and ``rhs_neg``.
        batch_edges: The edges forming this batch.

    Returns:
        ``Stats`` holding:
        - the weighted loss value;
        - the number of lhs/rhs negative scores strictly greater than
          their corresponding positive score;
        - the batch size.
    """

    def count_violators(positives, negatives) -> int:
        # unsqueeze(1) lines each positive up against its row of negatives.
        # Presumably positives are (batch,) and negatives are
        # (batch, num_negs); TODO confirm.
        return int((negatives > positives.unsqueeze(1)).sum())

    model.zero_grad()
    scores = model(batch_edges)

    lhs_loss = self.loss_fn(scores.lhs_pos, scores.lhs_neg)
    rhs_loss = self.loss_fn(scores.rhs_pos, scores.rhs_neg)

    # A batch with a single (scalar) relation type uses that relation's
    # weight; any other batch falls back to relation 0.
    # NOTE(review): for mixed-relation batches this applies relation 0's
    # weight to every edge; confirm that is intended.
    if batch_edges.has_scalar_relation_type():
        relation_idx = batch_edges.get_relation_type_as_scalar()
    else:
        relation_idx = 0
    relation_weight = self.relations[relation_idx].weight
    loss = relation_weight * (lhs_loss + rhs_loss)

    loss_value = float(loss)
    lhs_violators = count_violators(scores.lhs_pos, scores.lhs_neg)
    rhs_violators = count_violators(scores.rhs_pos, scores.rhs_neg)
    stats = Stats(
        loss=loss_value,
        violators_lhs=lhs_violators,
        violators_rhs=rhs_violators,
        count=len(batch_edges),
    )

    loss.backward()

    self.global_optimizer.step(closure=None)
    for entity_optimizer in self.entity_optimizers.values():
        entity_optimizer.step(closure=None)

    return stats