Exemplo n.º 1
0
    def _predict(self, model: torch.nn.Module, dataset: Dataset, input_reader: BaseInputReader):
        """Run inference over ``dataset`` and write the predicted entities and relations.

        The dataset is switched to ``Dataset.EVAL_MODE`` and iterated in order
        (no shuffling, last partial batch kept). Every batch is forwarded with
        ``inference=True`` under ``torch.no_grad()``, converted with
        ``prediction.convert_predictions`` and accumulated; the collected results
        are finally written by ``prediction.store_predictions`` to
        ``self._args.predictions_path``.

        Args:
            model: network to run; switched to ``eval()`` mode as a side effect.
            dataset: dataset to predict on.
            input_reader: reader forwarded to ``prediction.convert_predictions``.
        """
        args = self._args
        dataset.switch_mode(Dataset.EVAL_MODE)
        loader = DataLoader(dataset, batch_size=args.eval_batch_size, shuffle=False, drop_last=False,
                            num_workers=args.sampling_processes, collate_fn=sampling.collate_fn_padding)

        all_entities, all_relations = [], []

        with torch.no_grad():
            model.eval()
            batch_count = math.ceil(dataset.document_count / args.eval_batch_size)

            for raw_batch in tqdm(loader, total=batch_count, desc='Predict'):
                device_batch = util.to_device(raw_batch, self._device)

                entity_clf, rel_clf, rels = model(
                    encodings=device_batch['encodings'],
                    context_masks=device_batch['context_masks'],
                    entity_masks=device_batch['entity_masks'],
                    entity_sizes=device_batch['entity_sizes'],
                    entity_spans=device_batch['entity_spans'],
                    entity_sample_masks=device_batch['entity_sample_masks'],
                    inference=True)

                batch_entities, batch_relations = prediction.convert_predictions(
                    entity_clf, rel_clf, rels, device_batch,
                    args.rel_filter_threshold, input_reader)
                all_entities.extend(batch_entities)
                all_relations.extend(batch_relations)

        prediction.store_predictions(dataset.documents, all_entities, all_relations, args.predictions_path)
Exemplo n.º 2
0
    def _predict(self, model: torch.nn.Module, dataset: Dataset, input_reader: JsonInputReader,
                 epoch: int = 0, updates_epoch: int = 0, iteration: int = 0):
        """Run ``model`` over ``dataset`` and store the evaluator's predictions.

        Each batch is forwarded with ``evaluate=True`` under ``torch.no_grad()``
        and its raw outputs are passed to ``Evaluator.predict_batch``. After all
        batches are processed, ``Evaluator.store_predictions`` is called.

        Args:
            model: network to run. A ``DataParallel`` wrapper is unwrapped first.
                The model is switched to ``eval()`` mode as a side effect.
            dataset: dataset to predict on; switched to ``Dataset.EVAL_MODE``.
            input_reader: reader forwarded to the ``Evaluator``.
            epoch: epoch index, forwarded to the ``Evaluator`` and shown in the
                progress-bar description.
            updates_epoch: accepted but not used by this method.
            iteration: accepted but not used by this method.
        """
        if isinstance(model, DataParallel):
            # Prediction runs on the wrapped network itself. ``model.modules()``
            # would return a generator over all submodules, which can neither be
            # called nor switched to eval mode, so unwrap via ``.module``.
            model = model.module

        # create evaluator
        evaluator = Evaluator(dataset, input_reader, self._tokenizer,
                              self.args.rel_filter_threshold, self.args.no_overlapping, self.args.output_path,
                              self._examples_path, self.args.example_count, epoch, dataset.label)

        # create data loader (deterministic order, keep the last partial batch)
        dataset.switch_mode(Dataset.EVAL_MODE)
        data_loader = DataLoader(dataset, batch_size=self.args.eval_batch_size, shuffle=False, drop_last=False,
                                 num_workers=self.args.sampling_processes, collate_fn=sampling.collate_fn_padding)

        with torch.no_grad():
            model.eval()

            # iterate batches
            total = math.ceil(dataset.document_count / self.args.eval_batch_size)
            for batch in tqdm(data_loader, total=total, desc='Evaluate epoch %s' % epoch):
                # move batch to selected device
                batch = util.to_device(batch, self._device)

                # run model (forward pass)
                entity_clf, rel_clf, rels = model(encodings=batch['encodings'], context_masks=batch['context_masks'],
                                                  entity_masks=batch['entity_masks'],
                                                  entity_sizes=batch['entity_sizes'],
                                                  entity_spans=batch['entity_spans'],
                                                  entity_sample_masks=batch['entity_sample_masks'],
                                                  evaluate=True)

                evaluator.predict_batch(entity_clf, rel_clf, rels, batch)

        evaluator.store_predictions()
Exemplo n.º 3
0
    def _train_epoch(self, model: torch.nn.Module, compute_loss: Loss,
                     optimizer: Optimizer, dataset: Dataset,
                     updates_epoch: int, epoch: int):
        """Train ``model`` for a single epoch over ``dataset``.

        Batches are shuffled and the last incomplete batch is dropped. Every
        batch is moved to ``self._device``, forwarded through the model and
        handed to ``compute_loss.compute``. Training progress is logged through
        ``self._log_train`` whenever the global iteration is a multiple of
        ``self.args.train_log_iter``.

        Args:
            model: network to train; switched to ``train()`` mode for each batch.
            compute_loss: loss object whose ``compute`` method receives the model
                logits and the gold labels of each batch.
            optimizer: within this method it is only passed to ``self._log_train``
                (presumably the update step happens inside ``compute_loss.compute``,
                TODO confirm).
            dataset: training dataset; switched to ``Dataset.TRAIN_MODE``.
            updates_epoch: number of updates per epoch, used to derive the global
                iteration counter (``epoch * updates_epoch + iteration``).
            epoch: index of the current epoch.

        Returns:
            The number of batches processed in this epoch.
        """
        self._logger.info("Train epoch: %s" % epoch)

        # create data loader
        dataset.switch_mode(Dataset.TRAIN_MODE)
        data_loader = DataLoader(dataset,
                                 batch_size=self.args.train_batch_size,
                                 shuffle=True,
                                 drop_last=True,
                                 num_workers=self.args.sampling_processes,
                                 collate_fn=sampling.collate_fn_padding)

        model.zero_grad()

        iteration = 0
        # floor division matches drop_last=True: only full batches are yielded
        total = dataset.document_count // self.args.train_batch_size
        for batch in tqdm(data_loader,
                          total=total,
                          desc='Train epoch %s' % epoch):
            model.train()
            batch = util.to_device(batch, self._device)

            # forward step; the model's two trailing outputs are not needed here
            entity_logits, rel_logits, _, _ = model(
                encodings=batch['encodings'],
                context_masks=batch['context_masks'],
                entity_masks=batch['entity_masks'],
                entity_sizes=batch['entity_sizes'],
                relations=batch['rels'],
                rel_masks=batch['rel_masks'])

            # compute loss and optimize parameters
            batch_loss = compute_loss.compute(
                entity_logits=entity_logits,
                rel_logits=rel_logits,
                rel_types=batch['rel_types'],
                entity_types=batch['entity_types'],
                entity_sample_masks=batch['entity_sample_masks'],
                rel_sample_masks=batch['rel_sample_masks'])

            # logging
            iteration += 1
            global_iteration = epoch * updates_epoch + iteration

            if global_iteration % self.args.train_log_iter == 0:
                self._log_train(optimizer, batch_loss, epoch, iteration,
                                global_iteration, dataset.label)

        return iteration
Exemplo n.º 4
0
    def _eval(self, model: torch.nn.Module, dataset: Dataset, input_reader: BaseInputReader,
              epoch: int = 0, updates_epoch: int = 0, iteration: int = 0):
        """Evaluate ``model`` on ``dataset``, log the scores and optionally dump outputs.

        A ``DataParallel`` model is unwrapped because evaluation is single-device.
        Predictions and example files are placed in ``self._log_path`` and are
        named after the dataset label and epoch. Every batch is forwarded with
        ``inference=True`` under ``torch.no_grad()`` and fed to
        ``Evaluator.eval_batch``. The scores returned by
        ``Evaluator.compute_scores`` are passed to ``self._log_eval``. Predictions
        are stored only when ``store_predictions`` is set and ``no_overlapping``
        is not, and examples only when ``store_examples`` is set.

        Args:
            model: network to evaluate; switched to ``eval()`` mode as a side effect.
            dataset: dataset to evaluate on; switched to ``Dataset.EVAL_MODE``.
            input_reader: reader forwarded to the ``Evaluator``.
            epoch: epoch index, used in file names, logging and the progress bar.
            updates_epoch: updates per epoch, used to derive the global iteration.
            iteration: iteration within the epoch, used for logging.
        """
        args = self._args
        label = dataset.label
        self._logger.info("Evaluate: %s" % label)

        # currently no multi GPU support during evaluation
        net = model.module if isinstance(model, DataParallel) else model

        evaluator = Evaluator(
            dataset, input_reader, self._tokenizer,
            args.rel_filter_threshold, args.no_overlapping,
            os.path.join(self._log_path, f'predictions_{label}_epoch_{epoch}.json'),
            os.path.join(self._log_path, f'examples_%s_{label}_epoch_{epoch}.html'),
            args.example_count)

        dataset.switch_mode(Dataset.EVAL_MODE)
        loader = DataLoader(dataset, batch_size=args.eval_batch_size, shuffle=False, drop_last=False,
                            num_workers=args.sampling_processes, collate_fn=sampling.collate_fn_padding)

        with torch.no_grad():
            net.eval()
            batch_count = math.ceil(dataset.document_count / args.eval_batch_size)

            for raw_batch in tqdm(loader, total=batch_count, desc='Evaluate epoch %s' % epoch):
                device_batch = util.to_device(raw_batch, self._device)
                entity_clf, rel_clf, rels = net(
                    encodings=device_batch['encodings'],
                    context_masks=device_batch['context_masks'],
                    entity_masks=device_batch['entity_masks'],
                    entity_sizes=device_batch['entity_sizes'],
                    entity_spans=device_batch['entity_spans'],
                    entity_sample_masks=device_batch['entity_sample_masks'],
                    inference=True)
                evaluator.eval_batch(entity_clf, rel_clf, rels, device_batch)

        global_iteration = epoch * updates_epoch + iteration
        ner_eval, rel_eval, rel_nec_eval = evaluator.compute_scores()
        self._log_eval(*ner_eval, *rel_eval, *rel_nec_eval,
                       epoch, iteration, global_iteration, label)

        if args.store_predictions and not args.no_overlapping:
            evaluator.store_predictions()
        if args.store_examples:
            evaluator.store_examples()
Exemplo n.º 5
0
    def forward(self, tokens):
        """Run the model on raw ``tokens`` and return the evaluator's predictions.

        The tokens are loaded through ``self.input_reader.read_for_infer``. The
        resulting ``"infer"`` dataset is switched to ``Dataset.EVAL_MODE`` and
        iterated in order. Every batch is forwarded with ``evaluate=True`` under
        ``torch.no_grad()`` and accumulated via ``Evaluator.eval_batch``.

        Returns:
            Whatever ``Evaluator.get_preds()`` returns for the processed batches.
        """
        args = self.args
        reader = self.input_reader

        reader.read_for_infer(tokens)
        dataset = reader.get_dataset("infer")
        dataset.switch_mode(Dataset.EVAL_MODE)

        loader = DataLoader(dataset, batch_size=args.eval_batch_size, shuffle=False, drop_last=False,
                            num_workers=args.sampling_processes, collate_fn=sampling.collate_fn_padding)

        # NOTE(review): output path is None and the examples-path slot receives a
        # "predictions_..." pattern; kept as-is, confirm against Evaluator's signature.
        evaluator = Evaluator(dataset, reader, self._tokenizer, args.rel_filter_threshold,
                              args.no_overlapping, None, "predictions_%s_epoch_%s.json",
                              args.example_count, 0, dataset.label)

        net = self.model
        with torch.no_grad():
            net.eval()
            batch_count = math.ceil(dataset.document_count / args.eval_batch_size)

            for raw_batch in tqdm(loader, total=batch_count, desc='Evaluate epoch %s' % 0):
                device_batch = util.to_device(raw_batch, self._device)
                entity_clf, rel_clf, rels = net(
                    encodings=device_batch['encodings'],
                    context_masks=device_batch['context_masks'],
                    entity_masks=device_batch['entity_masks'],
                    entity_sizes=device_batch['entity_sizes'],
                    entity_spans=device_batch['entity_spans'],
                    entity_sample_masks=device_batch['entity_sample_masks'],
                    evaluate=True)
                evaluator.eval_batch(entity_clf, rel_clf, rels, device_batch)

        return evaluator.get_preds()