Example #1
class SRLPredictor(RequestPackingProcessor):
    """
    A semantic role labeler trained according to `He, Luheng, et al.
    "Jointly predicting predicates and arguments in neural semantic role
    labeling." <https://aclweb.org/anthology/P18-2058>`_.
    """

    word_vocab: tx.data.Vocab
    char_vocab: tx.data.Vocab
    model: LabeledSpanGraphNetwork

    def __init__(self):
        super().__init__()
        # Run on the current CUDA device when one is available, else CPU.
        self.device = torch.device(
            torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
        )

    def initialize(self, resources: Resources, configs: Optional[Config]):
        super().initialize(resources, configs)

        model_dir = configs.storage_path if configs is not None else None
        logger.info("restoring SRL model from %s", model_dir)

        # Initialize the batcher.
        if configs:
            self.batcher.initialize(configs.batcher)

        # Load the word- and character-level vocabularies shipped with the
        # pretrained model.
        self.word_vocab = tx.data.Vocab(
            os.path.join(model_dir, "embeddings/word_vocab.english.txt")
        )
        self.char_vocab = tx.data.Vocab(
            os.path.join(model_dir, "embeddings/char_vocab.english.txt")
        )
        # Resolve the embedding paths relative to the model directory before
        # constructing the network.
        model_hparams = LabeledSpanGraphNetwork.default_hparams()
        model_hparams["context_embeddings"]["path"] = os.path.join(
            model_dir, model_hparams["context_embeddings"]["path"]
        )
        model_hparams["head_embeddings"]["path"] = os.path.join(
            model_dir, model_hparams["head_embeddings"]["path"]
        )
        self.model = LabeledSpanGraphNetwork(
            self.word_vocab, self.char_vocab, model_hparams
        )
        # Restore the pretrained weights and switch to inference mode.
        self.model.load_state_dict(
            torch.load(
                os.path.join(model_dir, "pretrained/model.pt"),
                map_location=self.device,
            )
        )
        self.model.eval()

    def predict(self, data_batch: Dict) -> Dict[str, List[Prediction]]:
        # Map each tokenized sentence to word IDs, then pad the batch to a
        # uniform length.
        text: List[List[str]] = [
            sentence.tolist() for sentence in data_batch["Token"]["text"]
        ]
        text_ids, length = tx.data.padded_batch(
            [
                self.word_vocab.map_tokens_to_ids_py(sentence)
                for sentence in text
            ]
        )
        text_ids = torch.from_numpy(text_ids).to(device=self.device)
        length = torch.tensor(length, dtype=torch.long, device=self.device)
        batch_size = len(text)
        # The model's batch interface carries an `srl` field for gold spans;
        # pass empty lists at inference time.
        batch = tx.data.Batch(
            batch_size,
            text=text,
            text_ids=text_ids,
            length=length,
            srl=[[]] * batch_size,
        )
        self.model = self.model.to(self.device)
        batch_srl_spans = self.model.decode(batch)

        # Convert predictions into annotations.
        batch_predictions: List[Prediction] = []
        for idx, srl_spans in enumerate(batch_srl_spans):
            word_spans = data_batch["Token"]["span"][idx]
            predictions: Prediction = []
            for pred_idx, pred_args in srl_spans.items():
                begin, end = word_spans[pred_idx]
                # TODO: cannot create the annotation here; convert the NumPy
                #  numbers to plain ints and build a Span instead.
                pred_span = Span(begin.item(), end.item())
                arguments = []
                for arg in pred_args:
                    begin = word_spans[arg.start][0].item()
                    end = word_spans[arg.end][1].item()
                    arg_annotation = Span(begin, end)
                    arguments.append((arg_annotation, arg.label))
                predictions.append((pred_span, arguments))
            batch_predictions.append(predictions)
        return {"predictions": batch_predictions}

    def pack(
        self,
        pack: DataPack,
        predict_results: Dict[str, List[Prediction]],
        _: Optional[Annotation] = None,
    ):
        batch_predictions = predict_results["predictions"]
        for predictions in batch_predictions:
            for pred_span, arg_result in predictions:
                # Create the predicate mention, then link each argument to it.
                pred = PredicateMention(pack, pred_span.begin, pred_span.end)

                for arg_span, label in arg_result:
                    arg = PredicateArgument(pack, arg_span.begin, arg_span.end)
                    link = PredicateLink(pack, pred, arg)
                    link.arg_type = label

    @classmethod
    def default_configs(cls):
        """
        This defines the default configuration structure for the predictor.
        """
        return {
            "storage_path": None,
            "batcher": {
                "batch_size": 4,
                "context_type": "ft.onto.base_ontology.Sentence",
                "requests": {"ft.onto.base_ontology.Token": []},
            },
        }
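
A minimal usage sketch for Example #1 (not part of the original source): it assumes Forte's standard `Pipeline` and `StringReader` APIs, and the "srl_model/" path and the input sentence are hypothetical placeholders.

from forte.data.data_pack import DataPack
from forte.data.readers import StringReader
from forte.pipeline import Pipeline
from ft.onto.base_ontology import PredicateLink

# Build a pipeline around the predictor; "srl_model/" stands in for wherever
# the pretrained weights and vocabularies were unpacked.
pipeline = Pipeline[DataPack]()
pipeline.set_reader(StringReader())
pipeline.add(SRLPredictor(), config={"storage_path": "srl_model/"})
pipeline.initialize()

# Run a raw string through the pipeline and read back the predicate-argument
# links that `pack` wrote into the DataPack.
pack = pipeline.process("John gave Mary a book.")
for link in pack.get(PredicateLink):
    print(link.get_parent().text, "->", link.get_child().text, link.arg_type)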
Example #2
class SRLPredictor(FixedSizeBatchProcessor):
    """
    A semantic role labeler trained according to `He, Luheng, et al.
    "Jointly predicting predicates and arguments in neural semantic role
    labeling." <https://aclweb.org/anthology/P18-2058>`_.
    """

    word_vocab: tx.data.Vocab
    char_vocab: tx.data.Vocab
    model: LabeledSpanGraphNetwork

    def __init__(self):
        super().__init__()

        self.define_context()

        self.batch_size = 4
        self.batcher = self.define_batcher()

        self.device = torch.device(
            torch.cuda.current_device() if torch.cuda.is_available() else 'cpu')

    def initialize(self, _: Resources, configs: Optional[HParams]):

        model_dir = configs.storage_path if configs is not None else None
        logger.info("restoring SRL model from %s", model_dir)

        self.word_vocab = tx.data.Vocab(
            os.path.join(model_dir, "embeddings/word_vocab.english.txt"))
        self.char_vocab = tx.data.Vocab(
            os.path.join(model_dir, "embeddings/char_vocab.english.txt"))
        model_hparams = LabeledSpanGraphNetwork.default_hparams()
        model_hparams["context_embeddings"]["path"] = os.path.join(
            model_dir, model_hparams["context_embeddings"]["path"])
        model_hparams["head_embeddings"]["path"] = os.path.join(
            model_dir, model_hparams["head_embeddings"]["path"])
        self.model = LabeledSpanGraphNetwork(self.word_vocab, self.char_vocab,
                                             model_hparams)
        self.model.load_state_dict(
            torch.load(os.path.join(model_dir, "pretrained/model.pt"),
                       map_location=self.device))
        self.model.eval()

    def define_context(self):
        self.context_type = Sentence

    # pylint: disable=no-self-use
    def _define_input_info(self) -> DataRequest:
        input_info: DataRequest = {Token: []}
        return input_info

    def predict(self, data_batch: Dict) -> Dict[str, List[Prediction]]:
        text: List[List[str]] = [
            sentence.tolist() for sentence in data_batch["Token"]["text"]
        ]
        text_ids, length = tx.data.padded_batch([
            self.word_vocab.map_tokens_to_ids_py(sentence) for sentence in text
        ])
        text_ids = torch.from_numpy(text_ids).to(device=self.device)
        length = torch.tensor(length, dtype=torch.long, device=self.device)
        batch_size = len(text)
        batch = tx.data.Batch(batch_size,
                              text=text,
                              text_ids=text_ids,
                              length=length,
                              srl=[[]] * batch_size)
        self.model = self.model.to(self.device)
        batch_srl_spans = self.model.decode(batch)

        # Convert predictions into annotations.
        batch_predictions: List[Prediction] = []
        for idx, srl_spans in enumerate(batch_srl_spans):
            word_spans = data_batch["Token"]["span"][idx]
            predictions: Prediction = []
            for pred_idx, pred_args in srl_spans.items():
                begin, end = word_spans[pred_idx]
                # TODO: cannot create the annotation here; convert the NumPy
                #  numbers to plain ints and build a Span instead.
                pred_span = Span(begin.item(), end.item())
                arguments = []
                for arg in pred_args:
                    begin = word_spans[arg.start][0].item()
                    end = word_spans[arg.end][1].item()
                    arg_annotation = Span(begin, end)
                    arguments.append((arg_annotation, arg.label))
                predictions.append((pred_span, arguments))
            batch_predictions.append(predictions)
        return {"predictions": batch_predictions}

    def pack(self, data_pack: DataPack,
             inputs: Dict[str, List[Prediction]]) -> None:
        batch_predictions = inputs["predictions"]
        for predictions in batch_predictions:
            for pred_span, arg_result in predictions:

                pred = data_pack.add_entry(
                    PredicateMention(data_pack, pred_span.begin,
                                     pred_span.end))

                for arg_span, label in arg_result:
                    arg = data_pack.add_or_get_entry(
                        PredicateArgument(data_pack, arg_span.begin,
                                          arg_span.end))
                    link = PredicateLink(data_pack, pred, arg)
                    link.set_fields(arg_type=label)
                    data_pack.add_or_get_entry(link)

    @staticmethod
    def default_hparams():
        """
        This defines a basic Hparams structure
        :return:
        """
        hparams_dict = {
            'storage_path': None,
        }
        return hparams_dict
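
Both versions return the same nested structure from `predict`. The `Prediction` alias is defined elsewhere in the source; the sketch below spells out a plausible definition inferred from how the spans are assembled above (the import location of `Span` is an assumption).

from typing import List, Tuple

from forte.data.span import Span  # import path assumed

# One (predicate span, arguments) pair per predicate in a sentence, where
# arguments is a list of (argument span, role label) pairs; predict returns
# one such list per sentence in the batch.
Prediction = List[Tuple[Span, List[Tuple[Span, str]]]]

For reference, `tx.data.padded_batch` (Texar-PyTorch) pads the variable-length ID lists into a rectangular NumPy array and also returns the original lengths, e.g. [[3, 1, 4], [1, 5]] becomes [[3, 1, 4], [1, 5, 0]] with lengths [3, 2]; this is why `text_ids` can be handed to `torch.from_numpy` directly.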