def _predict(self, model: torch.nn.Module, dataset: Dataset, input_reader: BaseInputReader):
    """Run the model over *dataset* in inference mode and store all predictions.

    Collects per-batch entity/relation predictions and writes them to
    ``self._args.predictions_path`` once the whole dataset is processed.
    """
    # evaluation sampling mode + deterministic, non-dropping loader
    dataset.switch_mode(Dataset.EVAL_MODE)
    loader = DataLoader(dataset,
                        batch_size=self._args.eval_batch_size,
                        shuffle=False,
                        drop_last=False,
                        num_workers=self._args.sampling_processes,
                        collate_fn=sampling.collate_fn_padding)

    pred_entities, pred_relations = [], []

    with torch.no_grad():
        model.eval()

        # number of batches for the progress bar
        batch_count = math.ceil(dataset.document_count / self._args.eval_batch_size)
        for batch in tqdm(loader, total=batch_count, desc='Predict'):
            # move tensors onto the configured device
            batch = util.to_device(batch, self._device)

            # forward pass in inference mode
            entity_clf, rel_clf, rels = model(encodings=batch['encodings'],
                                              context_masks=batch['context_masks'],
                                              entity_masks=batch['entity_masks'],
                                              entity_sizes=batch['entity_sizes'],
                                              entity_spans=batch['entity_spans'],
                                              entity_sample_masks=batch['entity_sample_masks'],
                                              inference=True)

            # turn raw classifier output into entity/relation predictions
            batch_entities, batch_relations = prediction.convert_predictions(
                entity_clf, rel_clf, rels, batch,
                self._args.rel_filter_threshold, input_reader)

            pred_entities.extend(batch_entities)
            pred_relations.extend(batch_relations)

    # persist accumulated predictions for the whole dataset
    prediction.store_predictions(dataset.documents, pred_entities, pred_relations,
                                 self._args.predictions_path)
def _predict(self, model: torch.nn.Module, dataset: Dataset, input_reader: JsonInputReader,
             epoch: int = 0, updates_epoch: int = 0, iteration: int = 0):
    """Run the model over *dataset* and store predictions via an Evaluator.

    Args:
        model: trained model (possibly wrapped in ``DataParallel``).
        dataset: dataset to predict on; switched to EVAL mode here.
        input_reader: reader providing label/type mappings for decoding.
        epoch, updates_epoch, iteration: bookkeeping values used only for
            progress-bar labelling / Evaluator construction.
    """
    if isinstance(model, DataParallel):
        # BUGFIX: unwrap the underlying module. ``model.modules()`` returns an
        # iterator over all submodules, not the wrapped model, and is not
        # callable — evaluation must run on ``model.module`` (see _eval).
        model = model.module

    # create evaluator (accumulates per-batch predictions, writes them at the end)
    evaluator = Evaluator(dataset, input_reader, self._tokenizer,
                          self.args.rel_filter_threshold, self.args.no_overlapping,
                          self.args.output_path, self._examples_path,
                          self.args.example_count, epoch, dataset.label)

    # create data loader
    dataset.switch_mode(Dataset.EVAL_MODE)
    data_loader = DataLoader(dataset, batch_size=self.args.eval_batch_size,
                             shuffle=False, drop_last=False,
                             num_workers=self.args.sampling_processes,
                             collate_fn=sampling.collate_fn_padding)

    with torch.no_grad():
        model.eval()

        # iterate batches
        total = math.ceil(dataset.document_count / self.args.eval_batch_size)
        for batch in tqdm(data_loader, total=total, desc='Evaluate epoch %s' % epoch):
            # move batch to selected device
            batch = util.to_device(batch, self._device)

            # run model (forward pass)
            # NOTE(review): this variant passes ``evaluate=True`` while the
            # sibling methods pass ``inference=True`` — confirm against the
            # model's forward() signature.
            result = model(encodings=batch['encodings'],
                           context_masks=batch['context_masks'],
                           entity_masks=batch['entity_masks'],
                           entity_sizes=batch['entity_sizes'],
                           entity_spans=batch['entity_spans'],
                           entity_sample_masks=batch['entity_sample_masks'],
                           evaluate=True)
            entity_clf, rel_clf, rels = result

            evaluator.predict_batch(entity_clf, rel_clf, rels, batch)

    evaluator.store_predictions()
def _train_epoch(self, model: torch.nn.Module, compute_loss: Loss, optimizer: Optimizer,
                 dataset: Dataset, updates_epoch: int, epoch: int):
    """Train *model* for one epoch over *dataset*.

    Args:
        model: model to train (set to ``train()`` each batch).
        compute_loss: loss wrapper; ``compute_loss.compute`` presumably also
            performs backward/optimizer stepping, since no explicit
            ``backward()``/``optimizer.step()`` appears here — TODO confirm.
        optimizer: optimizer, used here only for train-logging.
        dataset: training dataset; switched to TRAIN mode here.
        updates_epoch: updates per epoch, used to derive the global iteration.
        epoch: current epoch index (for logging).

    Returns:
        The number of iterations (batches) processed this epoch.
    """
    self._logger.info("Train epoch: %s" % epoch)

    # create data loader: shuffled, incomplete final batch dropped
    dataset.switch_mode(Dataset.TRAIN_MODE)
    data_loader = DataLoader(dataset, batch_size=self.args.train_batch_size,
                             shuffle=True, drop_last=True,
                             num_workers=self.args.sampling_processes,
                             collate_fn=sampling.collate_fn_padding)

    model.zero_grad()

    iteration = 0
    # drop_last=True, so the batch count is the floor division
    total = dataset.document_count // self.args.train_batch_size
    for batch in tqdm(data_loader, total=total, desc='Train epoch %s' % epoch):
        model.train()
        batch = util.to_device(batch, self._device)

        # forward step (trailing outputs were unused debug values)
        entity_logits, rel_logits, _, _ = model(
            encodings=batch['encodings'], context_masks=batch['context_masks'],
            entity_masks=batch['entity_masks'], entity_sizes=batch['entity_sizes'],
            relations=batch['rels'], rel_masks=batch['rel_masks'])

        # compute loss and optimize parameters
        batch_loss = compute_loss.compute(
            entity_logits=entity_logits, rel_logits=rel_logits,
            rel_types=batch['rel_types'], entity_types=batch['entity_types'],
            entity_sample_masks=batch['entity_sample_masks'],
            rel_sample_masks=batch['rel_sample_masks'])

        # logging
        iteration += 1
        global_iteration = epoch * updates_epoch + iteration

        if global_iteration % self.args.train_log_iter == 0:
            self._log_train(optimizer, batch_loss, epoch, iteration,
                            global_iteration, dataset.label)

    return iteration
def _eval(self, model: torch.nn.Module, dataset: Dataset, input_reader: BaseInputReader,
          epoch: int = 0, updates_epoch: int = 0, iteration: int = 0):
    """Evaluate *model* on *dataset*, log scores and optionally store outputs."""
    self._logger.info("Evaluate: %s" % dataset.label)

    # currently no multi GPU support during evaluation
    if isinstance(model, DataParallel):
        model = model.module

    # per-epoch output locations (the '%s' placeholder in the examples path
    # is filled in later by the evaluator)
    predictions_path = os.path.join(self._log_path, f'predictions_{dataset.label}_epoch_{epoch}.json')
    examples_path = os.path.join(self._log_path, f'examples_%s_{dataset.label}_epoch_{epoch}.html')
    evaluator = Evaluator(dataset, input_reader, self._tokenizer,
                          self._args.rel_filter_threshold, self._args.no_overlapping,
                          predictions_path, examples_path, self._args.example_count)

    # evaluation sampling mode + deterministic loader
    dataset.switch_mode(Dataset.EVAL_MODE)
    loader = DataLoader(dataset,
                        batch_size=self._args.eval_batch_size,
                        shuffle=False,
                        drop_last=False,
                        num_workers=self._args.sampling_processes,
                        collate_fn=sampling.collate_fn_padding)

    with torch.no_grad():
        model.eval()

        batch_count = math.ceil(dataset.document_count / self._args.eval_batch_size)
        for batch in tqdm(loader, total=batch_count, desc='Evaluate epoch %s' % epoch):
            batch = util.to_device(batch, self._device)

            # forward pass in inference mode
            entity_clf, rel_clf, rels = model(encodings=batch['encodings'],
                                              context_masks=batch['context_masks'],
                                              entity_masks=batch['entity_masks'],
                                              entity_sizes=batch['entity_sizes'],
                                              entity_spans=batch['entity_spans'],
                                              entity_sample_masks=batch['entity_sample_masks'],
                                              inference=True)

            # accumulate batch metrics
            evaluator.eval_batch(entity_clf, rel_clf, rels, batch)

    # compute and log NER / relation / relation+NEC scores
    global_iteration = epoch * updates_epoch + iteration
    ner_eval, rel_eval, rel_nec_eval = evaluator.compute_scores()
    self._log_eval(*ner_eval, *rel_eval, *rel_nec_eval,
                   epoch, iteration, global_iteration, dataset.label)

    if self._args.store_predictions and not self._args.no_overlapping:
        evaluator.store_predictions()
    if self._args.store_examples:
        evaluator.store_examples()
def forward(self, tokens):
    """Run inference on raw *tokens* and return the evaluator's predictions.

    The tokens are loaded into an ad-hoc "infer" dataset via the input
    reader, the model is run over it batch by batch, and the accumulated
    predictions are returned.
    """
    # build an inference dataset from the incoming tokens
    self.input_reader.read_for_infer(tokens)
    dataset = self.input_reader.get_dataset("infer")
    dataset.switch_mode(Dataset.EVAL_MODE)

    loader = DataLoader(dataset,
                        batch_size=self.args.eval_batch_size,
                        shuffle=False,
                        drop_last=False,
                        num_workers=self.args.sampling_processes,
                        collate_fn=sampling.collate_fn_padding)

    # no predictions path; the examples path is a template filled in later
    evaluator = Evaluator(dataset, self.input_reader, self._tokenizer,
                          self.args.rel_filter_threshold, self.args.no_overlapping,
                          None, "predictions_%s_epoch_%s.json",
                          self.args.example_count, 0, dataset.label)

    with torch.no_grad():
        self.model.eval()

        batch_count = math.ceil(dataset.document_count / self.args.eval_batch_size)
        for batch in tqdm(loader, total=batch_count, desc='Evaluate epoch %s' % 0):
            batch = util.to_device(batch, self._device)

            # forward pass
            entity_clf, rel_clf, rels = self.model(
                encodings=batch['encodings'],
                context_masks=batch['context_masks'],
                entity_masks=batch['entity_masks'],
                entity_sizes=batch['entity_sizes'],
                entity_spans=batch['entity_spans'],
                entity_sample_masks=batch['entity_sample_masks'],
                evaluate=True)

            evaluator.eval_batch(entity_clf, rel_clf, rels, batch)

    return evaluator.get_preds()