# Imports assumed by the code below; Model, Graph, and Result are project-specific classes.
from math import ceil

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm


def loop_through_data_for_eval(self,
                               dataset: DataLoader,
                               model: Model,
                               graph: Graph) -> Result:
    graph.to(self.device)
    model.to(device=self.device)
    model.eval()
    result = Result(entity_dict=self.entity_id_to_str_dict,
                    relation_dict=self.relation_id_to_str_dict)
    for idx, (paths, mask, _, triplet, num_paths) in enumerate(tqdm(dataset)):
        labels = triplet[:, 1]  # the relations in <subject, relation, object>
        assert len(triplet) == len(labels)
        if num_paths.size() == torch.Size([1, 1]) and num_paths.item() == 0:
            # no paths available for this triple: fall back to a random score
            score = torch.randn(1, self.num_relations)
        else:
            paths = paths.to(device=self.device)
            mask = mask.to(device=self.device)
            triplet = triplet.to(device=self.device)
            score = model(triplet, graph,
                          paths=paths, masks=mask, num_paths=num_paths)
        result.append(score.cpu(), labels.cpu())
    return result
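
The `Result` container above is only exercised through its constructor and `append(scores, labels)`; its real implementation is not shown here. The following is a minimal, assumed sketch of such an accumulator, with the `mrr()` metric added purely for illustration.

class Result:
    """Minimal assumed accumulator for per-batch scores and gold labels (illustrative only)."""

    def __init__(self, entity_dict=None, relation_dict=None):
        self.entity_dict = entity_dict      # id -> string lookups, kept for reporting
        self.relation_dict = relation_dict
        self.scores = []                    # list of (batch, num_classes) tensors
        self.labels = []                    # list of (batch,) tensors of gold ids

    def append(self, scores: torch.Tensor, labels: torch.Tensor) -> None:
        self.scores.append(scores)
        self.labels.append(labels)

    def mrr(self) -> float:
        # rank of each gold label within its score row (1 = best), then mean reciprocal rank
        scores = torch.cat(self.scores, dim=0)
        labels = torch.cat(self.labels, dim=0)
        gold = scores.gather(1, labels.view(-1, 1))
        ranks = (scores >= gold).sum(dim=1)
        return (1.0 / ranks.float()).mean().item()
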
def loop_through_data_for_eval(self,
                               dataset: np.ndarray,  # shape: (num_triples, 3)
                               model: Model,
                               graph: Graph,
                               batch_size: int) -> Result:
    graph.to(self.device)
    model.to(device=self.device)
    model.eval()
    result = Result()
    num_batches = ceil(len(dataset) / batch_size)
    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = start_idx + batch_size
        batch = torch.from_numpy(
            dataset[start_idx:end_idx]).long().to(device=self.device)
        if self.config.link_predict:
            labels = batch[:, 1]  # the relations in <subject, relation, object>
        else:
            labels = batch[:, 2]  # the objects in <subject, relation, object>
        scores = model(batch, graph, link_predict=self.config.link_predict)
        result.append(scores.cpu(), labels.cpu())
    return result
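
Because this loop only scores triples, callers would normally run it with gradient tracking disabled. The wrapper below is a small assumed sketch of that usage; `evaluate` and `self.test_triples` are hypothetical names, not part of the original code.

def evaluate(self, model: Model, graph: Graph, batch_size: int = 512) -> Result:
    # hypothetical convenience wrapper: no autograd state is kept during evaluation
    with torch.no_grad():
        return self.loop_through_data_for_eval(
            self.test_triples,  # hypothetical (num_triples, 3) array of <subject, relation, object> ids
            model, graph, batch_size)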