Ejemplo n.º 1
0
    def __init__(self, log_file: str = None, device: str = "cpu"):
        """ Set up the pretrained section-label model and its helpers

        Parameters
        ----------
        log_file : str
            Optional path of a log file. When given, a logger named
            ``sectlabel_logger`` is configured at ``INFO`` level for that
            file; otherwise the wasabi printer is used as the logger.
        device : str
            Device identifier stored on the instance (defaults to ``"cpu"``).
        """
        self.device = device

        # ``exist_ok=True`` replaces the former ``is_dir()`` probe: the
        # directory is created atomically, so concurrent processes cannot
        # race between the check and the creation.
        self.models_cache_dir = pathlib.Path(MODELS_CACHE_DIR)
        self.models_cache_dir.mkdir(parents=True, exist_ok=True)

        self.final_model_dir = self.models_cache_dir.joinpath(
            "sectlabel_elmo_bilstm")
        self.model_filepath = self.final_model_dir.joinpath("best_model.pt")

        self.data_dir = pathlib.Path(DATA_DIR)
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.train_data_url = DATA_FILE_URLS["SECT_LABEL_TRAIN_FILE"]
        self.dev_data_url = DATA_FILE_URLS["SECT_LABEL_DEV_FILE"]
        self.test_data_url = DATA_FILE_URLS["SECT_LABEL_TEST_FILE"]

        self.msg_printer = wasabi.Printer()

        # Order matters: each step reads attributes set by the previous ones.
        self._download_if_required()
        self.data_manager = self._get_data()
        self.hparams = self._get_hparams()
        self.model = self._get_model()
        self.infer = self._get_infer_client()
        self.cli_interact = SciWINGInteract(self.infer)
        self.log_file = log_file

        if log_file:
            self.logger = setup_logger("sectlabel_logger",
                                       logfile=self.log_file,
                                       level=logging.INFO)
        else:
            # No log file requested: messages go through the console printer.
            self.logger = self.msg_printer
Ejemplo n.º 2
0
Archivo: i2b2.py Proyecto: yyht/sciwing
    def __init__(self):
        """ Set up the pretrained I2B2 NER model and its helpers

        Creates the model-cache and data directories when they are missing,
        then builds the data manager, the model, the inference client, the
        tag visualizer and the command line interaction helper.
        """
        super(I2B2NER, self).__init__()

        # ``exist_ok=True`` replaces the former ``is_dir()`` probe, avoiding
        # a race between the check and the creation.
        self.models_cache_dir = pathlib.Path(MODELS_CACHE_DIR)
        self.models_cache_dir.mkdir(parents=True, exist_ok=True)

        self.final_model_dir = self.models_cache_dir.joinpath("i2b2")
        self.model_filepath = self.final_model_dir.joinpath("best_model.pt")

        # ``parents=True`` was missing here: creating the data directory
        # failed whenever its parent directory did not exist yet.
        self.data_dir = pathlib.Path(DATA_DIR)
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.train_data_url = DATA_FILE_URLS["I2B2_TRAIN"]
        self.dev_data_url = DATA_FILE_URLS["I2B2_DEV"]
        # NOTE(review): the test split reuses the dev URL - confirm that this
        # is intentional and not a copy/paste slip.
        self.test_data_url = DATA_FILE_URLS["I2B2_DEV"]
        self.msg_printer = wasabi.Printer()

        # Order matters: each step reads attributes set by the previous ones.
        self._download_if_required()
        self.hparams = self._get_hparams()
        self.data_manager = self._get_data()
        self.model: nn.Module = self._get_model()
        self.infer = self._get_infer_client()
        self.vis_tagger = VisTagging()
        self.cli_interact = SciWINGInteract(self.infer)
Ejemplo n.º 3
0
    def __init__(self):
        """ Set up the pretrained generic-section classifier and its helpers

        Creates the model-cache and data directories when they are missing,
        then builds the data manager, the model, the inference client and
        the command line interaction helper.
        """
        self.models_cache_dir = pathlib.Path(MODELS_CACHE_DIR)
        self.final_model_dir = self.models_cache_dir.joinpath(
            "genericsect_bow_elmo")

        # ``exist_ok=True`` replaces the former ``is_dir()`` probe, avoiding
        # a race between the check and the creation.
        self.models_cache_dir.mkdir(parents=True, exist_ok=True)

        self.model_filepath = self.final_model_dir.joinpath("best_model.pt")

        self.data_dir = pathlib.Path(DATA_DIR)
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.train_data_url = DATA_FILE_URLS["GENERIC_SECTION_TRAIN_FILE"]
        self.dev_data_url = DATA_FILE_URLS["GENERIC_SECTION_DEV_FILE"]
        self.test_data_url = DATA_FILE_URLS["GENERIC_SECTION_TEST_FILE"]

        self.msg_printer = wasabi.Printer()

        # Order matters: each step reads attributes set by the previous ones.
        self._download_if_required()
        self.data_manager = self._get_data()
        self.hparams = self._get_hparams()
        self.model = self._get_model()
        self.infer = self._get_infer_client()
        self.cli_interact = SciWINGInteract(self.infer)
Ejemplo n.º 4
0
    def __init__(self):
        """ Set up the pretrained citation intent classifier and its helpers

        Creates the model-cache and data directories when they are missing,
        then builds the data manager, the model, the inference client and
        the command line interaction helper.
        """
        super(CitationIntentClassification, self).__init__()

        # ``exist_ok=True`` replaces the former ``is_dir()`` probe, avoiding
        # a race between the check and the creation.
        self.models_cache_dir = pathlib.Path(MODELS_CACHE_DIR)
        self.models_cache_dir.mkdir(parents=True, exist_ok=True)

        self.final_model_dir = self.models_cache_dir.joinpath(
            "citation_intent_clf_elmo")

        self.data_dir = pathlib.Path(DATA_DIR)
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.train_data_url = DATA_FILE_URLS["SCICITE_TRAIN"]
        self.dev_data_url = DATA_FILE_URLS["SCICITE_DEV"]
        self.test_data_url = DATA_FILE_URLS["SCICITE_TEST"]
        self.msg_printer = wasabi.Printer()

        # Order matters: each step reads attributes set by the previous ones.
        self._download_if_required()
        self.hparams = self._get_hparams()
        self.data_manager = self._get_data()
        self.model: nn.Module = self._get_model()
        self.infer = self._get_infer_client()
        self.cli_interact = SciWINGInteract(infer_client=self.infer)
Ejemplo n.º 5
0
    def __init__(self, device=-1):
        """ Set up the pretrained neural parscit model and its helpers

        The original signature used a ``typing`` construct as the *default
        value* of ``device`` (``device=Optional[...]``), so calling the
        constructor without arguments always raised ``ValueError``. The
        default is now ``-1``, which selects the CPU.

        Parameters
        ----------
        device : torch.device or int
            Either a ``torch.device`` or a device number. ``-1`` selects the
            CPU, any other number ``n`` selects ``cuda:n``.

        Raises
        ------
        ValueError
            If ``device`` is neither a ``torch.device`` nor an ``int``.
        """
        super(NeuralParscit, self).__init__()

        if isinstance(device, torch.device):
            self.device = device
        elif isinstance(device, int):
            device_string = "cpu" if device == -1 else f"cuda:{device}"
            self.device = torch.device(device_string)
        else:
            raise ValueError(
                "Pass the device number or the device object from Pytorch"
            )

        self.models_cache_dir = pathlib.Path(MODELS_CACHE_DIR)
        self.final_model_dir = self.models_cache_dir.joinpath("lstm_crf_parscit_final")
        # ``exist_ok=True`` replaces the former ``is_dir()`` probe, avoiding
        # a race between the check and the creation.
        self.models_cache_dir.mkdir(parents=True, exist_ok=True)
        self.model_filepath = self.final_model_dir.joinpath("best_model.pt")

        self.data_dir = pathlib.Path(DATA_DIR)
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.train_data_file_url = DATA_FILE_URLS["PARSCIT_TRAIN"]
        self.dev_data_file_url = DATA_FILE_URLS["PARSCIT_DEV"]
        self.test_data_file_url = DATA_FILE_URLS["PARSCIT_TEST"]
        self.msg_printer = wasabi.Printer()

        # Order matters: each step reads attributes set by the previous ones.
        self._download_if_required()
        self.hparams = self._get_hparams()
        self.data_manager = self._get_data()
        self.model: nn.Module = self._get_model()
        self.infer = self._get_infer_client()
        self.vis_tagger = VisTagging()
        self.interact_ = SciWINGInteract(self.infer)
Ejemplo n.º 6
0
class NeuralParscit(nn.Module):
    """ It defines a neural parscit model. The model is used for citation string parsing. This model
    helps you use a pre-trained model whose architecture is fixed and is trained by SciWING.
    You can also fine-tune the model on your own dataset.

    For practitioners, we provide ways to obtain results quickly from a set of citations
    stored in a file or from a string. If you want to see the demo head over to our demo site.

    """

    def __init__(self, device=-1):
        """ Set up the pretrained neural parscit model and its helpers

        The original signature used a ``typing`` construct as the *default
        value* of ``device`` (``device=Optional[...]``), so calling the
        constructor without arguments always raised ``ValueError``. The
        default is now ``-1``, which selects the CPU.

        Parameters
        ----------
        device : torch.device or int
            Either a ``torch.device`` or a device number. ``-1`` selects the
            CPU, any other number ``n`` selects ``cuda:n``.

        Raises
        ------
        ValueError
            If ``device`` is neither a ``torch.device`` nor an ``int``.
        """
        super(NeuralParscit, self).__init__()
        self.device = self._resolve_device(device)

        self.models_cache_dir = pathlib.Path(MODELS_CACHE_DIR)
        self.final_model_dir = self.models_cache_dir.joinpath("lstm_crf_parscit_final")
        # ``exist_ok=True`` replaces the former ``is_dir()`` probe, avoiding
        # a race between the check and the creation.
        self.models_cache_dir.mkdir(parents=True, exist_ok=True)
        self.model_filepath = self.final_model_dir.joinpath("best_model.pt")

        self.data_dir = pathlib.Path(DATA_DIR)
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.train_data_file_url = DATA_FILE_URLS["PARSCIT_TRAIN"]
        self.dev_data_file_url = DATA_FILE_URLS["PARSCIT_DEV"]
        self.test_data_file_url = DATA_FILE_URLS["PARSCIT_TEST"]
        self.msg_printer = wasabi.Printer()

        # Order matters: each step reads attributes set by the previous ones.
        self._download_if_required()
        self.hparams = self._get_hparams()
        self.data_manager = self._get_data()
        self.model: nn.Module = self._get_model()
        self.infer = self._get_infer_client()
        self.vis_tagger = VisTagging()
        self.interact_ = SciWINGInteract(self.infer)

    @staticmethod
    def _resolve_device(device):
        """ Turn a device number or a ``torch.device`` into a ``torch.device``

        Parameters
        ----------
        device : torch.device or int
            ``-1`` maps to the CPU, any other number ``n`` to ``cuda:n``.
            A ``torch.device`` is returned unchanged.

        Returns
        -------
        torch.device
            The device to place the model on

        Raises
        ------
        ValueError
            If ``device`` is neither a ``torch.device`` nor an ``int``.
        """
        if isinstance(device, torch.device):
            return device
        if isinstance(device, int):
            device_string = "cpu" if device == -1 else f"cuda:{device}"
            return torch.device(device_string)
        raise ValueError(
            "Pass the device number or the device object from Pytorch"
        )

    def _get_model(self) -> nn.Module:
        """ Assemble the tagger from the stored hyper-parameters

        Word, character and ELMo embedders are concatenated, passed to an
        LSTM sequence encoder and wrapped in a ``RnnSeqCrfTagger``.

        Returns
        -------
        nn.Module
            The assembled tagger
        """
        word_embedder = TrainableWordEmbedder(
            embedding_type=self.hparams.get("emb_type"),
            datasets_manager=self.data_manager,
            device=self.device,
        )

        char_embedder = CharEmbedder(
            char_embedding_dimension=self.hparams.get("char_emb_dim"),
            hidden_dimension=self.hparams.get("char_encoder_hidden_dim"),
            datasets_manager=self.data_manager,
            device=self.device,
        )

        elmo_embedder = BowElmoEmbedder(
            datasets_manager=self.data_manager,
            layer_aggregation="sum",
            device=self.device,
        )

        embedder = ConcatEmbedders([word_embedder, char_embedder, elmo_embedder])

        hidden_dim = self.hparams.get("hidden_dim")
        bidirectional = self.hparams.get("bidirectional")
        combine_strategy = self.hparams.get("combine_strategy")

        lstm2seqencoder = Lstm2SeqEncoder(
            embedder=embedder,
            hidden_dim=hidden_dim,
            bidirectional=bidirectional,
            combine_strategy=combine_strategy,
            rnn_bias=True,
            dropout_value=self.hparams.get("lstm2seq_dropout", 0.0),
            add_projection_layer=False,
            device=self.device,
        )

        # The encoding dimension is doubled only for a bidirectional encoder
        # that uses the "concat" combine strategy.
        if bidirectional and combine_strategy == "concat":
            encoding_dim = 2 * hidden_dim
        else:
            encoding_dim = hidden_dim

        model = RnnSeqCrfTagger(
            rnn2seqencoder=lstm2seqencoder,
            encoding_dim=encoding_dim,
            datasets_manager=self.data_manager,
            device=self.device,
        )

        return model

    def _get_infer_client(self):
        """ Build the inference client for the model stored at ``self.model_filepath`` """
        infer_client = SequenceLabellingInference(
            model=self.model,
            model_filepath=self.model_filepath,
            datasets_manager=self.data_manager,
            device=self.device,
        )
        return infer_client

    def _predict(self, line: str):
        """ Forward a single line to the inference client and return its predictions """
        predictions = self.infer.on_user_input(line=line)
        return predictions

    def predict_for_file(self, filename: str) -> List[str]:
        """ Parse the references in a file where every line is a reference

        Every line is stripped, tagged, and the stylized result is printed
        for every namespace the inference client returns.

        Parameters
        ----------
        filename : str
            The filename where the references are stored

        Returns
        -------
        List[str]
            A list of parsed tags for the first label namespace of the
            data manager

        """
        predictions = defaultdict(list)
        with open(filename, "r") as fp:
            for line_idx, line in enumerate(fp):
                line = line.strip()
                pred_ = self._predict(line=line)
                for namespace, prediction in pred_.items():
                    predictions[namespace].append(prediction[0])
                    stylized_string = self.vis_tagger.visualize_tokens(
                        text=line.split(), labels=prediction[0].split()
                    )
                    self.msg_printer.divider(
                        f"Predictions for Line: {line_idx+1} from {filename}"
                    )
                    print(stylized_string)
                    print("\n")

        return predictions[self.data_manager.label_namespaces[0]]

    def predict_for_text(self, text: str, show=True) -> str:
        """ Parse the citation string for the given text

        Parameters
        ----------
        text : str
            reference string to parse
        show : bool
            If `True`, then we print the stylized string - where the stylized string provides
            different colors for different tags
            If `False` - then we do not print the stylized string

        Returns
        -------
        str
            The parsed citation string of the first namespace returned by
            the inference client (``None`` when it returns no namespace)

        """
        predictions = self._predict(line=text)
        for namespace, prediction in predictions.items():
            if show:
                self.msg_printer.divider(f"Prediction for {namespace.upper()}")
                stylized_string = self.vis_tagger.visualize_tokens(
                    text=text.split(), labels=prediction[0].split()
                )
                print(stylized_string)
            # Only the first namespace is reported back to the caller.
            return prediction[0]

    def _get_data(self):
        """ Build the dataset manager from the train, dev and test files

        Each file is resolved through ``cached_path`` with its local path in
        the data directory and its URL.
        """
        data_manager = SeqLabellingDatasetManager(
            train_filename=cached_path(
                path=self.data_dir.joinpath("parscit.train"),
                url=self.train_data_file_url,
                unzip=False,
            ),
            dev_filename=cached_path(
                path=self.data_dir.joinpath("parscit.dev"),
                url=self.dev_data_file_url,
                unzip=False,
            ),
            test_filename=cached_path(
                path=self.data_dir.joinpath("parscit.test"),
                url=self.test_data_file_url,
                unzip=False,
            ),
        )
        return data_manager

    def _get_hparams(self):
        """ Load ``hyperparams.json`` from the final model directory

        Returns
        -------
        The decoded JSON content of the file
        """
        with open(self.final_model_dir.joinpath("hyperparams.json")) as fp:
            hyperparams = json.load(fp)
        return hyperparams

    def _download_if_required(self):
        """ Make the pretrained model archive available in the models cache """
        # download the model weights and data to client machine
        cached_path(
            path=f"{self.final_model_dir}.zip",
            url="https://parsect-models.s3-ap-southeast-1.amazonaws.com/lstm_crf_parscit_final.zip",
            unzip=True,
        )

    def interact(self):
        """ Interact with the pretrained model
        You can also interact from command line using `sciwing interact neural-parscit`
        """
        self.interact_.interact()
Ejemplo n.º 7
0
            train_filename=train_filename,
            dev_filename=dev_filename,
            test_filename=test_filename,
        )
        return data_manager

    def build_infer(self):
        """ Create the sequence labelling inference client

        The checkpoint location is read from the ``model_filepath`` entry of
        ``self.hparams``.
        """
        checkpoint = self.hparams.get("model_filepath")
        return SequenceLabellingInference(
            model=self.model,
            model_filepath=checkpoint,
            datasets_manager=self.data_manager,
        )


if __name__ == "__main__":
    # Script entry point: build the parscit inference client from a local
    # checkpoint (output/checkpoints/best_model.pt) and start an interactive
    # session with it.
    dirname = pathlib.Path(".") / "output"
    model_filepath = dirname / "checkpoints" / "best_model.pt"
    hparams = dict(
        emb_type="parscit",
        char_emb_dim=25,
        char_encoder_hidden_dim=50,
        hidden_dim=256,
        bidirectional=True,
        combine_strategy="concat",
        model_filepath=str(model_filepath),
    )
    parscit_inference = BuildParscitInterference(hparams)
    cli = SciWINGInteract(parscit_inference)
    cli.interact()
Ejemplo n.º 8
0
        data_manager = TextClassificationDatasetManager(
            train_filename=train_filename,
            dev_filename=dev_filename,
            test_filename=test_filename,
        )
        return data_manager

    def build_infer(self):
        """ Create the classification inference client

        The checkpoint location is read from the ``model_filepath`` entry of
        ``self.hparams``.
        """
        checkpoint = self.hparams.get("model_filepath")
        return ClassificationInference(
            model=self.model,
            model_filepath=checkpoint,
            datasets_manager=self.data_manager,
        )


if __name__ == "__main__":
    # Script entry point: build the generic-section (bag of words + ELMo)
    # inference client from a local checkpoint
    # (output/checkpoints/best_model.pt) and start an interactive session.
    dirname = pathlib.Path(".") / "output"
    model_filepath = dirname / "checkpoints" / "best_model.pt"
    # NOTE(review): the checkpoint is handed over as a ``Path`` object here,
    # not as a string - confirm that the inference client accepts both.
    hparams = dict(
        layer_aggregation="last",
        word_aggregation="sum",
        encoding_dim=1024,
        num_classes=12,
        model_filepath=model_filepath,
    )
    infer = BuildGenericSectBowElmo(hparams)
    cli = SciWINGInteract(infer)
    cli.interact()
Ejemplo n.º 9
0
        test_filename = data_dir.joinpath("genericSect.test")

        data_manager = TextClassificationDatasetManager(
            train_filename=train_filename,
            dev_filename=dev_filename,
            test_filename=test_filename,
        )
        return data_manager

    def build_infer(self):
        """ Create the classification inference client

        The checkpoint location is read from the ``model_filepath`` entry of
        ``self.hparams``.
        """
        checkpoint = self.hparams.get("model_filepath")
        return ClassificationInference(
            model=self.model,
            model_filepath=checkpoint,
            datasets_manager=self.data_manager,
        )


if __name__ == "__main__":
    # Script entry point: build the generic-section inference client from a
    # local checkpoint (output/checkpoints/best_model.pt) and start an
    # interactive session with it.
    dirname = pathlib.Path(".") / "output"
    model_filepath = dirname / "checkpoints" / "best_model.pt"
    hparams = dict(
        emb_type="glove_6B_50",
        model_filepath=str(model_filepath),
        num_classes=12,
        encoding_dim=50,
    )
    sectlabel_infer = BuildGenericSectBowRandom(hparams)
    cli = SciWINGInteract(sectlabel_infer)
    cli.interact()
Ejemplo n.º 10
0
Archivo: i2b2.py Proyecto: yyht/sciwing
class I2B2NER(nn.Module):
    """ It defines a I2B2 clinical NER model trained using SciWING

    For practitioners, we provide ways to obtain results quickly from a set of lines
    stored in a file or from a string. If you want to see the demo head over to our demo site.

    """
    def __init__(self):
        """ Set up the pretrained I2B2 NER model and its helpers

        Creates the model-cache and data directories when they are missing,
        then builds the data manager, the model, the inference client, the
        tag visualizer and the command line interaction helper.
        """
        super(I2B2NER, self).__init__()

        # ``exist_ok=True`` replaces the former ``is_dir()`` probe, avoiding
        # a race between the check and the creation.
        self.models_cache_dir = pathlib.Path(MODELS_CACHE_DIR)
        self.models_cache_dir.mkdir(parents=True, exist_ok=True)

        self.final_model_dir = self.models_cache_dir.joinpath("i2b2")
        self.model_filepath = self.final_model_dir.joinpath("best_model.pt")

        # ``parents=True`` was missing here: creating the data directory
        # failed whenever its parent directory did not exist yet.
        self.data_dir = pathlib.Path(DATA_DIR)
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.train_data_url = DATA_FILE_URLS["I2B2_TRAIN"]
        self.dev_data_url = DATA_FILE_URLS["I2B2_DEV"]
        # NOTE(review): the test split reuses the dev URL - confirm that this
        # is intentional and not a copy/paste slip.
        self.test_data_url = DATA_FILE_URLS["I2B2_DEV"]
        self.msg_printer = wasabi.Printer()

        # Order matters: each step reads attributes set by the previous ones.
        self._download_if_required()
        self.hparams = self._get_hparams()
        self.data_manager = self._get_data()
        self.model: nn.Module = self._get_model()
        self.infer = self._get_infer_client()
        self.vis_tagger = VisTagging()
        self.cli_interact = SciWINGInteract(self.infer)

    def _get_model(self) -> nn.Module:
        """ Assemble the tagger from the stored hyper-parameters

        Word and ELMo embedders are concatenated, passed to an LSTM sequence
        encoder and wrapped in a ``RnnSeqCrfTagger``.

        Returns
        -------
        nn.Module
            The assembled tagger
        """
        word_embedder = TrainableWordEmbedder(
            embedding_type=self.hparams.get("emb_type"),
            datasets_manager=self.data_manager,
        )

        elmo_embedder = BowElmoEmbedder(datasets_manager=self.data_manager,
                                        layer_aggregation="sum")

        embedder = ConcatEmbedders([word_embedder, elmo_embedder])

        hidden_dim = self.hparams.get("hidden_dim")
        bidirectional = self.hparams.get("bidirectional")
        combine_strategy = self.hparams.get("combine_strategy")

        lstm2seqencoder = Lstm2SeqEncoder(
            embedder=embedder,
            hidden_dim=hidden_dim,
            bidirectional=bidirectional,
            combine_strategy=combine_strategy,
            rnn_bias=True,
            dropout_value=self.hparams.get("lstm2seq_dropout", 0.0),
            add_projection_layer=False,
        )

        # The encoding dimension is doubled only for a bidirectional encoder
        # that uses the "concat" combine strategy.
        if bidirectional and combine_strategy == "concat":
            encoding_dim = 2 * hidden_dim
        else:
            encoding_dim = hidden_dim

        model = RnnSeqCrfTagger(
            rnn2seqencoder=lstm2seqencoder,
            encoding_dim=encoding_dim,
            datasets_manager=self.data_manager,
        )

        return model

    def _get_infer_client(self):
        """ Build the inference client for the model stored at ``self.model_filepath`` """
        infer_client = SequenceLabellingInference(
            model=self.model,
            model_filepath=self.model_filepath,
            datasets_manager=self.data_manager,
        )
        return infer_client

    def _predict(self, line: str):
        """ Forward a single line to the inference client and return its predictions """
        predictions = self.infer.on_user_input(line=line)
        return predictions

    def predict_for_file(self, filename: str) -> List[str]:
        """ Tag every line of a file

        Every line is stripped, tagged, and the stylized result is printed
        for every namespace the inference client returns.

        Parameters
        ----------
        filename : str
            The file to read, one input per line

        Returns
        -------
        List[str]
            The predicted tag strings for the first label namespace of the
            data manager
        """
        predictions = defaultdict(list)
        with open(filename, "r") as fp:
            for line_idx, line in enumerate(fp):
                line = line.strip()
                pred_ = self._predict(line=line)
                for namespace, prediction in pred_.items():
                    predictions[namespace].append(prediction[0])
                    stylized_string = self.vis_tagger.visualize_tokens(
                        text=line.split(), labels=prediction[0].split())
                    self.msg_printer.divider(
                        f"Predictions for Line: {line_idx+1} from {filename}")
                    print(stylized_string)
                    print("\n")

        return predictions[self.data_manager.label_namespaces[0]]

    def predict_for_text(self, text: str):
        """ Tag a single string and print the stylized result

        Parameters
        ----------
        text : str
            The string to tag

        Returns
        -------
        The tag string of the first namespace returned by the inference
        client (``None`` when it returns no namespace)
        """
        predictions = self._predict(line=text)
        for namespace, prediction in predictions.items():
            self.msg_printer.divider(f"Prediction for {namespace.upper()}")
            stylized_string = self.vis_tagger.visualize_tokens(
                text=text.split(), labels=prediction[0].split())
            print(stylized_string)
            # Only the first namespace is reported back to the caller.
            return prediction[0]

    def _get_data(self):
        """ Build the CoNLL dataset manager from the train and dev files

        Each file is resolved through ``cached_path`` with its local path in
        the data directory and its URL.
        """
        train_filename = cached_path(
            path=self.data_dir.joinpath("i2b2.train"),
            url=self.train_data_url,
            unzip=False,
        )

        dev_filename = cached_path(path=self.data_dir.joinpath("i2b2.dev"),
                                   url=self.dev_data_url,
                                   unzip=False)

        # NOTE(review): the test split points at the dev file and dev URL -
        # confirm that this is intentional.
        test_filename = cached_path(path=self.data_dir.joinpath("i2b2.dev"),
                                    url=self.dev_data_url,
                                    unzip=False)

        data_manager = CoNLLDatasetManager(
            train_filename=train_filename,
            dev_filename=dev_filename,
            test_filename=test_filename,
            column_names=["NER", "NER", "NER"],
            train_only="ner",
        )
        return data_manager

    def _get_hparams(self):
        """ Load ``hyperparams.json`` from the final model directory

        Returns
        -------
        The decoded JSON content of the file
        """
        with open(self.final_model_dir.joinpath("hyperparams.json")) as fp:
            hyperparams = json.load(fp)
        return hyperparams

    def _download_if_required(self):
        """ Make the pretrained model archive available in the models cache """
        # download the model weights and data to client machine
        cached_path(
            path=f"{self.final_model_dir}.zip",
            url=
            "https://parsect-models.s3-ap-southeast-1.amazonaws.com/i2b2.zip",
            unzip=True,
        )

    def interact(self):
        """ Interact with the pretrained model from the command line """
        self.cli_interact.interact()
Ejemplo n.º 11
0
class GenericSect:
    """ Pretrained classifier that predicts generic section headers

    The model is a bag of words encoder over ELMo embeddings followed by a
    ``SimpleClassifier``. Predictions can be obtained for a file with one
    section header per line or for a single string.
    """
    def __init__(self):
        """ Set up the pretrained generic-section classifier and its helpers

        Creates the model-cache and data directories when they are missing,
        then builds the data manager, the model, the inference client and
        the command line interaction helper.
        """
        self.models_cache_dir = pathlib.Path(MODELS_CACHE_DIR)
        self.final_model_dir = self.models_cache_dir.joinpath(
            "genericsect_bow_elmo")

        # ``exist_ok=True`` replaces the former ``is_dir()`` probe, avoiding
        # a race between the check and the creation.
        self.models_cache_dir.mkdir(parents=True, exist_ok=True)

        self.model_filepath = self.final_model_dir.joinpath("best_model.pt")

        self.data_dir = pathlib.Path(DATA_DIR)
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.train_data_url = DATA_FILE_URLS["GENERIC_SECTION_TRAIN_FILE"]
        self.dev_data_url = DATA_FILE_URLS["GENERIC_SECTION_DEV_FILE"]
        self.test_data_url = DATA_FILE_URLS["GENERIC_SECTION_TEST_FILE"]

        self.msg_printer = wasabi.Printer()

        # Order matters: each step reads attributes set by the previous ones.
        self._download_if_required()
        self.data_manager = self._get_data()
        self.hparams = self._get_hparams()
        self.model = self._get_model()
        self.infer = self._get_infer_client()
        self.cli_interact = SciWINGInteract(self.infer)

    def _get_model(self):
        """ Assemble the classifier from the stored hyper-parameters

        The layer and word aggregation strategies come from ``self.hparams``;
        the encoding dimension (1024) and the number of classes (12) are
        fixed in code.
        """
        embedder = BowElmoEmbedder(
            layer_aggregation=self.hparams.get("layer_aggregation"),
            datasets_manager=self.data_manager,
        )
        encoder = BOW_Encoder(
            aggregation_type=self.hparams.get("word_aggregation"),
            embedder=embedder)

        model = SimpleClassifier(
            encoder=encoder,
            encoding_dim=1024,
            num_classes=12,
            classification_layer_bias=True,
            datasets_manager=self.data_manager,
        )
        return model

    def _get_infer_client(self):
        """ Build the inference client for ``best_model.pt`` in the final model directory """
        client = ClassificationInference(
            model=self.model,
            model_filepath=self.final_model_dir.joinpath("best_model.pt"),
            datasets_manager=self.data_manager,
        )
        return client

    def predict_for_file(self, filename: str) -> List[str]:
        """ Make predictions for every line in the file

        Surrounding whitespace, including the trailing newline, is stripped
        from every line before it is handed to the model and printed.

        Parameters
        ----------
        filename: str
            The filename where section headers are stored one per line

        Returns
        -------
        List[str]
            A list of predictions

        """
        with open(filename) as fp:
            # Previously the raw lines, newline included, were passed on.
            lines = [line.strip() for line in fp]

        predictions = self.infer.infer_batch(lines=lines)
        for line, prediction in zip(lines, predictions):
            self.msg_printer.text(title=line, text=prediction)

        return predictions

    def predict_for_text(self, text: str, show=True) -> str:
        """ Predicts the generic section headers of the text

        Parameters
        ----------
        text: str
            The section header string to be normalized
        show : bool
            If True then we print the prediction.

        Returns
        -------
        str
            The prediction for the section header

        """
        prediction = self.infer.on_user_input(line=text)
        if show:
            self.msg_printer.text(title=text, text=prediction)
        return prediction

    def _get_data(self):
        """ Build the dataset manager from the train, dev and test files

        Each file is resolved through ``cached_path`` with its local path in
        the data directory and its URL.
        """
        train_filename = self.data_dir.joinpath("genericSect.train")
        dev_filename = self.data_dir.joinpath("genericSect.dev")
        test_filename = self.data_dir.joinpath("genericSect.test")

        train_filename = cached_path(path=train_filename,
                                     url=self.train_data_url,
                                     unzip=False)

        dev_filename = cached_path(path=dev_filename,
                                   url=self.dev_data_url,
                                   unzip=False)

        test_filename = cached_path(path=test_filename,
                                    url=self.test_data_url,
                                    unzip=False)

        data_manager = TextClassificationDatasetManager(
            train_filename=train_filename,
            dev_filename=dev_filename,
            test_filename=test_filename,
        )

        return data_manager

    def _get_hparams(self):
        """ Load ``hyperparams.json`` from the final model directory

        Returns
        -------
        The decoded JSON content of the file
        """
        with open(self.final_model_dir.joinpath("hyperparams.json")) as fp:
            hyperparams = json.load(fp)
        return hyperparams

    def _download_if_required(self):
        """ Make the pretrained model archive available in the models cache """
        cached_path(
            path=f"{self.final_model_dir}.zip",
            url=
            "https://parsect-models.s3-ap-southeast-1.amazonaws.com/genericsect_bow_elmo.zip",
            unzip=True,
        )

    def interact(self):
        """ Interact with the pretrained model
        """
        self.cli_interact.interact()
Ejemplo n.º 12
0
class SectLabel:
    def __init__(self, log_file: str = None, device: str = "cpu"):
        """ Set up the pretrained SectLabel model and its inference client

        Parameters
        ----------
        log_file : str
            When given, a logger named ``sectlabel_logger`` writing to this
            file at INFO level is used. Otherwise the wasabi printer is
            used as the logger.
        device : str
            The device string handed to the embedders, the encoder, the
            classifier and the inference client

        """
        self.device = device
        self.models_cache_dir = pathlib.Path(MODELS_CACHE_DIR)

        # exist_ok=True removes the window between an ``is_dir`` check and
        # ``mkdir`` in which another process could create the directory
        self.models_cache_dir.mkdir(parents=True, exist_ok=True)

        self.final_model_dir = self.models_cache_dir.joinpath(
            "sectlabel_elmo_bilstm")
        self.model_filepath = self.final_model_dir.joinpath("best_model.pt")
        self.data_dir = pathlib.Path(DATA_DIR)
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.train_data_url = DATA_FILE_URLS["SECT_LABEL_TRAIN_FILE"]
        self.dev_data_url = DATA_FILE_URLS["SECT_LABEL_DEV_FILE"]
        self.test_data_url = DATA_FILE_URLS["SECT_LABEL_TEST_FILE"]

        self.msg_printer = wasabi.Printer()

        # the steps below depend on each other: the model archive provides
        # the hyper-parameters, and the model needs both the data manager
        # and the hyper-parameters
        self._download_if_required()
        self.data_manager = self._get_data()
        self.hparams = self._get_hparams()
        self.model = self._get_model()
        self.infer = self._get_infer_client()
        self.cli_interact = SciWINGInteract(self.infer)
        self.log_file = log_file

        if log_file:
            self.logger = setup_logger("sectlabel_logger",
                                       logfile=self.log_file,
                                       level=logging.INFO)
        else:
            self.logger = self.msg_printer

    def _get_model(self):
        """ Assemble the classifier described by the stored hyper-parameters

        A word embedder and a bag-of-words ELMo embedder are concatenated,
        encoded with an LSTM encoder and fed to a 23-class classifier.

        Returns
        -------
        SimpleClassifier
            The classifier moved to ``self.device``

        """
        device = self.device
        hparams = self.hparams

        contextual_embedder = BowElmoEmbedder(layer_aggregation="sum",
                                              device=device)
        word_embedder = WordEmbedder(embedding_type=hparams.get("emb_type"),
                                     device=device)
        combined_embedder = ConcatEmbedders(
            [word_embedder, contextual_embedder])

        hidden_dim = hparams.get("hidden_dim")
        bidirectional = hparams.get("bidirectional")
        combine_strategy = hparams.get("combine_strategy")

        lstm_encoder = LSTM2VecEncoder(
            embedder=combined_embedder,
            hidden_dim=hidden_dim,
            bidirectional=bidirectional,
            combine_strategy=combine_strategy,
            device=device,
        )

        # the encoding width doubles only when both directions are concatenated
        if bidirectional and combine_strategy == "concat":
            encoding_dim = 2 * hidden_dim
        else:
            encoding_dim = hidden_dim

        classifier = SimpleClassifier(
            encoder=lstm_encoder,
            encoding_dim=encoding_dim,
            num_classes=23,
            classification_layer_bias=True,
            datasets_manager=self.data_manager,
            device=device,
        )
        classifier.to(device)
        return classifier

    def _get_infer_client(self):
        """ Build the inference client around the model

        Returns
        -------
        ClassificationInference
            The client created with ``best_model.pt`` from
            ``self.final_model_dir`` as its model file

        """
        checkpoint_path = self.final_model_dir.joinpath("best_model.pt")
        return ClassificationInference(
            model=self.model,
            model_filepath=checkpoint_path,
            datasets_manager=self.data_manager,
            device=self.device,
        )

    def predict_for_file(self, filename: str) -> List[str]:
        """ Predicts the logical sections for all the sentences in a file, with one sentence per line

        Parameters
        ----------
        filename : str
            The path of the file

        Returns
        -------
        List[str]
            The predictions for each line.

        """
        lines = []
        with open(filename) as fp:
            for line in fp:
                lines.append(line)

        predictions = self.infer.infer_batch(lines=lines)
        for line, prediction in zip(lines, predictions):
            self.msg_printer.text(title=line, text=prediction)

        return predictions

    def predict_for_pdf(self,
                        pdf_filename: pathlib.Path) -> (List[str], List[str]):
        """ Predicts lines and labels given a pdf filename

        Parameters
        ----------
        pdf_filename : pathlib.Path
            The location where pdf files are stored

        Returns
        -------
        List[str], List[str]
            The lines and labels inferred on the file. Two empty lists are
            returned when no lines could be read, so that callers can always
            unpack the result.
        """
        pdf_reader = PdfReader(filepath=pdf_filename)
        lines = self._preprocess(pdf_reader.read_pdf())

        if not lines:
            # ``self.logger`` is either a logging.Logger (``warning``) or the
            # wasabi printer, which only offers ``warn``
            warn = getattr(self.logger, "warning", None) or self.logger.warn
            warn(f"No lines were read from file {pdf_filename}")
            return [], []

        all_lines = []
        all_labels = []

        # inference runs on batches of 64 lines
        for batch_lines in chunks(lines, 64):
            all_labels.extend(self.infer.infer_batch(lines=batch_lines))
            all_lines.extend(batch_lines)

        return all_lines, all_labels

    def predict_for_text(self, text: str) -> str:
        """ Predicts the logical section that the line belongs to

        Parameters
        ----------
        text: str
            A single line of text

        Returns
        -------
        str
            The logical section of the text.

        """
        prediction = self.infer.on_user_input(line=text)
        self.msg_printer.text(title=text, text=prediction)
        return prediction

    def predict_for_text_batch(self, texts: List[str]) -> List[str]:
        """ Predicts the logical section for a batch of text.

        Parameters
        ----------
        texts: List[str]
            A batch of text

        Returns
        -------
        List[str]
            A batch of predictions

        """
        predictions = self.infer.infer_batch(lines=texts)
        return predictions

    def _get_data(self):
        """ Resolve the SectLabel dataset splits and wrap them in a manager

        Every split is passed through ``cached_path`` with its file under
        ``self.data_dir``, its URL and ``unzip=False``.

        Returns
        -------
        TextClassificationDatasetManager
            The manager built from the train, dev and test files

        """
        split_sources = (
            ("sectLabel.train", self.train_data_url),
            ("sectLabel.dev", self.dev_data_url),
            ("sectLabel.test", self.test_data_url),
        )

        # resolved in train, dev, test order
        train_path, dev_path, test_path = [
            cached_path(path=self.data_dir.joinpath(basename),
                        url=split_url,
                        unzip=False)
            for basename, split_url in split_sources
        ]

        return TextClassificationDatasetManager(
            train_filename=train_path,
            dev_filename=dev_path,
            test_filename=test_path,
        )

    def _get_hparams(self):
        with open(self.final_model_dir.joinpath("hyperparams.json")) as fp:
            hyperparams = json.load(fp)
        return hyperparams

    def _download_if_required(self):
        """ Hand the pretrained model archive to ``cached_path``

        The archive location is ``self.final_model_dir`` with a ``.zip``
        suffix and ``cached_path`` is asked to unzip it.
        """
        archive_url = "https://parsect-models.s3-ap-southeast-1.amazonaws.com/sectlabel_elmo_bilstm.zip"
        archive_path = f"{self.final_model_dir}.zip"
        cached_path(path=archive_path, url=archive_url, unzip=True)

    @staticmethod
    def _preprocess(lines: str):
        preprocessed_lines = []
        for line in lines:
            line_ = line.strip()
            if bool(line_):
                line_words = line_.split()
                num_single_character_words = sum(
                    [1 for word in line_words if len(word) == 1])
                num_words = len(line_words)
                percentage_single_character_words = (
                    num_single_character_words / num_words) * 100
                if percentage_single_character_words > 40:
                    line_ = "".join(line_words)
                    preprocessed_lines.append(line_)
                else:
                    preprocessed_lines.append(line_)
        return preprocessed_lines

    @staticmethod
    def _extract_abstract_for_file(lines: List[str],
                                   labels: List[str]) -> List[str]:
        """ Given the linse

        Parameters
        ----------
        lines: List[str]
            A set of lines
        labels: List[str]
            A set of labels

        Returns
        -------
        List[str]
            Lines in the abstract

        """
        response_tuples = []
        for line, label in zip(lines, labels):
            response_tuples.append((line, label))

        abstract_lines = []
        found_abstract = False
        for line, label in response_tuples:
            if label == "sectionHeader" and line.strip().lower() == "abstract":
                found_abstract = True
                continue
            if found_abstract and label == "sectionHeader":
                break
            if found_abstract:
                abstract_lines.append(line.strip())

        return abstract_lines

    def dehyphenate(self, lines: List[str]) -> List[str]:
        """ Dehyphenates a list of strings

        Parameters
        ----------
        lines: List[str]
            A list of hyphenated strings

        Returns
        -------
        List[str]
            A list of dehyphenated strings
        """
        buffer_lines = []  # holds lines that should be a single line
        final_lines = []
        for line in lines:
            if line.endswith("-"):
                line_ = line.replace("-", "")  # replace the hyphen
                buffer_lines.append(line_)
            else:
                # if the hyphenation ended on the previous
                # line then the next line also needs to be
                # added to the buffer line
                if len(buffer_lines) > 0:
                    buffer_lines.append(line)

                    line_ = "".join(buffer_lines)

                    # add the line from buffer first
                    final_lines.append(line_)

                else:
                    # add the current line
                    final_lines.append(line)

                buffer_lines = []
        return final_lines

    def extract_abstract_for_file(self,
                                  pdf_filename: pathlib.Path,
                                  dehyphenate: bool = True) -> str:
        """ Extracts abstracts from a pdf using sectlabel. This is the python programmatic version of
        the API. The APIs can be found in sciwing/api. You can see that for more information

        Parameters
        ----------
        pdf_filename : pathlib.Path
            The path where the pdf is stored
        dehyphenate : bool
            Scientific documents are two columns sometimes and there are a lot of hyphenation
            introduced. If this is true, we remove the hyphens from the code

        Returns
        -------
        str
            The abstract of the pdf

        """
        self.msg_printer.info(f"Extracting abstract for {pdf_filename}")
        all_lines, all_labels = self.predict_for_pdf(pdf_filename=pdf_filename)
        abstract_lines = self._extract_abstract_for_file(lines=all_lines,
                                                         labels=all_labels)

        if dehyphenate:
            abstract_lines = self.dehyphenate(abstract_lines)

        abstract = " ".join(abstract_lines)
        return abstract

    def extract_abstract_for_folder(self,
                                    foldername: pathlib.Path,
                                    dehyphenate=True):
        """ Extracts the abstracts for all the pdf files stored in a folder

        Every entry of the folder with a ``.pdf`` suffix is processed. The
        abstract is printed and written to ``<stem>.abstract``; the output
        path is relative, so it is resolved against the current working
        directory.

        Parameters
        ----------
        foldername : pathlib.Path
            The path of the folder containing pdf files
        dehyphenate : bool
            Forwarded to ``extract_abstract_for_file``

        Returns
        -------
        None
            Writes the abstracts to files

        """
        # counted up front only to give the progress bar a total
        num_entries = len(list(foldername.iterdir()))
        progress = tqdm(foldername.iterdir(),
                        total=num_entries,
                        desc="Extracting Abstracts")
        for entry in progress:
            if entry.suffix != ".pdf":
                continue

            abstract = self.extract_abstract_for_file(pdf_filename=entry,
                                                      dehyphenate=dehyphenate)
            self.msg_printer.text(title="abstract", text=abstract)
            with open(f"{entry.stem}.abstract", "w") as out_file:
                out_file.write(abstract)
                out_file.write("\n")

    @staticmethod
    def _extract_section_headers(lines: List[str],
                                 labels: List[str]) -> List[str]:
        section_headers = []
        for line, label in zip(lines, labels):
            if label == "sectionHeader" or label == "subsectionHeader":
                section_headers.append(line.strip())

        return section_headers

    def _extract_references(self, lines: List[str],
                            labels: List[str]) -> List[str]:
        references = []
        for line, label in zip(lines, labels):
            if label == "reference":
                references.append(line.strip())

        # references = self.dehyphenate(references)

        return references

    def extract_all_info(self, pdf_filename: pathlib.Path):
        """ Extracts information from the pdf file.

        Parameters
        ----------
        pdf_filename: pathlib.Path
            The path of the pdf file

        Returns
        -------
        Dict[str, Any]
            A dictionary containing information parsed from the pdf file

        """
        all_lines, all_labels = self.predict_for_pdf(pdf_filename=pdf_filename)
        abstract = self._extract_abstract_for_file(lines=all_lines,
                                                   labels=all_labels)
        abstract = " ".join(abstract)
        section_headers = self._extract_section_headers(lines=all_lines,
                                                        labels=all_labels)
        reference_strings = self._extract_references(lines=all_lines,
                                                     labels=all_labels)

        return {
            "abstract": abstract,
            "section_headers": section_headers,
            "references": reference_strings,
        }

    def interact(self):
        """ Interact with the pre-trained model
        """
        self.cli_interact.interact()
Ejemplo n.º 13
0
            dev_filename=dev_filename,
            test_filename=test_filename,
        )
        return data_manager

    def build_infer(self):
        """ Build the inference client around the model

        Returns
        -------
        ClassificationInference
            The client created with the ``model_filepath`` entry of
            ``self.hparams`` as its model file

        """
        checkpoint_path = self.hparams.get("model_filepath")
        return ClassificationInference(
            model=self.model,
            model_filepath=checkpoint_path,
            datasets_manager=self.data_manager,
        )


if __name__ == "__main__":
    # Script entry point: builds the CODA-19 inference client from a local
    # checkpoint and starts an interactive session with it.
    dirname = (pathlib.Path("./backend/abstract_tagging/sciwing/") /
               "coda19_classification_elmo_slower")
    model_filepath = dirname / "checkpoints" / "best_model.pt"

    hparams = {}
    hparams["embedding_type"] = "glove_6B_100"
    hparams["hidden_dim"] = 50
    hparams["bidirectional"] = True
    hparams["combine_strategy"] = "concat"
    hparams["num_classes"] = 5
    hparams["model_filepath"] = model_filepath
    hparams["device"] = "cpu"

    infer = BuildCoda19ClassificationInfer(hparams=hparams)
    cli = SciWINGInteract(infer_client=infer)
    cli.interact()
Ejemplo n.º 14
0
class CitationIntentClassification(nn.Module):
    def __init__(self):
        """ Set up the pretrained citation intent model and its inference client

        The cache and data directories are created if needed, the model
        archive is handed to ``cached_path``, and the hyper-parameters,
        the data manager, the model and the inference client are built
        from it.
        """
        super().__init__()
        self.models_cache_dir = pathlib.Path(MODELS_CACHE_DIR)

        # exist_ok=True removes the window between an ``is_dir`` check and
        # ``mkdir`` in which another process could create the directory
        self.models_cache_dir.mkdir(parents=True, exist_ok=True)

        self.final_model_dir = self.models_cache_dir.joinpath(
            "citation_intent_clf_elmo")

        self.data_dir = pathlib.Path(DATA_DIR)
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.train_data_url = DATA_FILE_URLS["SCICITE_TRAIN"]
        self.dev_data_url = DATA_FILE_URLS["SCICITE_DEV"]
        self.test_data_url = DATA_FILE_URLS["SCICITE_TEST"]
        self.msg_printer = wasabi.Printer()

        # the steps below depend on each other: the model archive provides
        # the hyper-parameters, and the model needs both the hyper-parameters
        # and the data manager
        self._download_if_required()
        self.hparams = self._get_hparams()
        self.data_manager = self._get_data()
        self.model: nn.Module = self._get_model()
        self.infer = self._get_infer_client()
        self.cli_interact = SciWINGInteract(infer_client=self.infer)

    def _get_model(self) -> nn.Module:
        """ Assemble the classifier described by the stored hyper-parameters

        A word embedder and an ELMo embedder are concatenated, encoded with
        an LSTM encoder and fed to a 3-class classifier.

        Returns
        -------
        nn.Module
            The citation intent classifier

        """
        embedding_type = self.hparams.get("emb_type")
        word_embedder = WordEmbedder(embedding_type=embedding_type)
        elmo_embedder = ElmoEmbedder(datasets_manager=self.data_manager)
        embedder = ConcatEmbedders([word_embedder, elmo_embedder])

        hidden_dim = self.hparams.get("hidden_dim")
        combine_strategy = self.hparams.get("combine_strategy")
        bidirectional = self.hparams.get("bidirectional")

        encoder = LSTM2VecEncoder(
            embedder=embedder,
            hidden_dim=hidden_dim,
            combine_strategy=combine_strategy,
            bidirectional=bidirectional,
        )

        # the encoding is twice as wide only when the two directions are
        # concatenated; with any other combine strategy it stays hidden_dim
        # (same rule as SectLabel uses for its classifier)
        if bidirectional and combine_strategy == "concat":
            classifier_encoding_dim = 2 * hidden_dim
        else:
            classifier_encoding_dim = hidden_dim

        model = SimpleClassifier(
            encoder=encoder,
            encoding_dim=classifier_encoding_dim,
            num_classes=3,
            classification_layer_bias=True,
            datasets_manager=self.data_manager,
        )
        return model

    def _get_infer_client(self):
        """ Build the inference client around the model

        Returns
        -------
        ClassificationInference
            The client created with ``checkpoints/best_model.pt`` from
            ``self.final_model_dir`` as its model file

        """
        checkpoint_path = self.final_model_dir.joinpath(
            "checkpoints", "best_model.pt")
        return ClassificationInference(
            model=self.model,
            model_filepath=checkpoint_path,
            datasets_manager=self.data_manager,
        )

    def predict_for_file(self, filename: str) -> List[str]:
        """ Predict the intents for all the citations in the filename
        The citations should be contained one per line

        Parameters
        ----------
        filename : str
            The filename where the citations are stored

        Returns
        -------
        List[str]
            Returns the intents for each line of citation

        """
        with open(filename, "r") as fp:
            lines = []
            for line in fp:
                line = line.strip()
                lines.append(line)

            predictions = self.infer.infer_batch(lines=lines)
            for prediction, line in zip(predictions, lines):
                self.msg_printer.text(title=line, text=prediction)

        return predictions

    def predict_for_text(self, text: str) -> str:
        """ Predict the intent for citation

        Parameters
        ----------
        text : str
            The citation string

        Returns
        -------
        str
            The predicted label for the citation

        """
        label = self.infer.on_user_input(line=text)
        self.msg_printer.text(title=text, text=label)
        return label

    def _get_data(self):
        """ Resolve the SciCite dataset splits and wrap them in a manager

        Every split is passed through ``cached_path`` with its file under
        ``self.data_dir``, its URL and ``unzip=False``.

        Returns
        -------
        TextClassificationDatasetManager
            The manager built from the train, dev and test files

        """
        split_sources = (
            ("scicite.train", self.train_data_url),
            ("scicite.dev", self.dev_data_url),
            ("scicite.test", self.test_data_url),
        )

        # resolved in train, dev, test order
        train_path, dev_path, test_path = [
            cached_path(
                path=self.data_dir.joinpath(basename),
                url=split_url,
                unzip=False,
            ) for basename, split_url in split_sources
        ]

        return TextClassificationDatasetManager(
            train_filename=train_path,
            dev_filename=dev_path,
            test_filename=test_path)

    def _get_hparams(self):
        with open(
                self.final_model_dir.joinpath("checkpoints",
                                              "hyperparams.json")) as fp:
            hyperparams = json.load(fp)
        return hyperparams

    def _download_if_required(self):
        """ Hand the pretrained model archive to ``cached_path``

        The archive location is ``self.final_model_dir`` with a ``.zip``
        suffix and ``cached_path`` is asked to unzip it.
        """
        archive_url = "https://parsect-models.s3-ap-southeast-1.amazonaws.com/citation_intent_clf_elmo.zip"
        archive_path = f"{self.final_model_dir}.zip"
        cached_path(path=archive_path, url=archive_url, unzip=True)

    def interact(self):
        """ Interact with the pretrained model
        """
        self.cli_interact.interact()