def test_batch_prediction(self):
        """Batch-predict two sentences with the constituency parser fixture.

        For each sentence the span and class-probability counts, the tokens,
        the type of the rendered tree and the normalisation of every class
        distribution are checked.
        """
        inputs = [
                {"sentence": "What a great test sentence."},
                {"sentence": "Here's another good, interesting one."}
        ]

        fixture = self.FIXTURES_ROOT / 'constituency_parser' / 'serialization' / 'model.tar.gz'
        predictor = Predictor.from_archive(load_archive(fixture), 'constituency-parser')
        results = predictor.predict_batch_json(inputs)

        # (expected span count, expected tokens) per input sentence. The span
        # count is the number of contiguous sub-sequences: n * (n + 1) / 2.
        expectations = [
                (21, ["What", "a", "great", "test", "sentence", "."]),
                (36, ["Here", "'s", "another", "good", ",", "interesting", "one", "."]),
        ]

        for index, (expected_spans, expected_tokens) in enumerate(expectations):
            result = results[index]

            assert len(result["spans"]) == expected_spans
            assert len(result["class_probabilities"]) == expected_spans
            assert result["tokens"] == expected_tokens
            assert isinstance(result["trees"], str)

            # Each span's label distribution must sum to one.
            for class_distribution in result["class_probabilities"]:
                self.assertAlmostEqual(sum(class_distribution), 1.0, places=4)
    def test_uses_named_inputs(self):
        """Check the textual-entailment predictor's output format.

        ``label_probs`` must be three non-negative floats summing to one and
        ``label_logits`` must be three floats whose softmax equals
        ``label_probs``.
        """
        inputs = {
                "premise": "I always write unit tests for my code.",
                "hypothesis": "One time I didn't write any unit tests for my code."
        }

        fixture = self.FIXTURES_ROOT / 'decomposable_attention' / 'serialization' / 'model.tar.gz'
        predictor = Predictor.from_archive(load_archive(fixture), 'textual-entailment')
        result = predictor.predict_json(inputs)

        # Probabilities: a list of 3 non-negative floats that sum to one.
        label_probs = result.get("label_probs")
        assert label_probs is not None
        assert isinstance(label_probs, list)
        assert len(label_probs) == 3
        assert all(isinstance(prob, float) for prob in label_probs)
        assert all(prob >= 0 for prob in label_probs)
        assert sum(label_probs) == approx(1.0)

        # Logits: a list of 3 floats.
        label_logits = result.get("label_logits")
        assert label_logits is not None
        assert isinstance(label_logits, list)
        assert len(label_logits) == 3
        assert all(isinstance(logit, float) for logit in label_logits)

        # Softmax of the logits must reproduce the reported probabilities.
        exponentials = [math.exp(logit) for logit in label_logits]
        normaliser = sum(exponentials)
        for exponential, prob in zip(exponentials, label_probs):
            assert exponential / normaliser == approx(prob)
Esempio n. 3
0
    def test_uses_named_inputs(self):
        """Check the machine-comprehension predictor's output format on the bidaf fixture."""
        inputs = {
                "question": "What kind of test succeeded on its first attempt?",
                "passage": "One time I was writing a unit test, and it succeeded on the first attempt."
        }

        fixture = self.FIXTURES_ROOT / 'bidaf' / 'serialization' / 'model.tar.gz'
        predictor = Predictor.from_archive(load_archive(fixture), 'machine-comprehension')

        result = predictor.predict_json(inputs)

        # The best span is an ordered pair of integer indices.
        best_span = result.get("best_span")
        assert best_span is not None
        assert isinstance(best_span, list)
        assert len(best_span) == 2
        assert all(isinstance(index, int) for index in best_span)
        span_start, span_end = best_span
        assert span_start <= span_end

        # The answer text must be a non-empty string.
        best_span_str = result.get("best_span_str")
        assert isinstance(best_span_str, str)
        assert best_span_str != ""

        # Start and end distributions are floats that sum to one.
        for probs_key in ("span_start_probs", "span_end_probs"):
            distribution = result.get(probs_key)
            assert distribution is not None
            assert all(isinstance(prob, float) for prob in distribution)
            assert sum(distribution) == approx(1.0)
    def test_batch_prediction(self):
        """Batch-predict two sentences with the biaffine dependency parser fixture.

        Every result must carry one head and one string dependency label per word.
        """
        sentences = [
                "What kind of test succeeded on its first attempt?",
                "What kind of test succeeded on its first attempt at batch processing?",
        ]
        inputs = [{"sentence": sentence} for sentence in sentences]

        fixture = self.FIXTURES_ROOT / 'biaffine_dependency_parser' / 'serialization' / 'model.tar.gz'
        predictor = Predictor.from_archive(load_archive(fixture), 'biaffine-dependency-parser')

        results = predictor.predict_batch_json(inputs)
        assert len(results) == 2

        for result in results:
            sequence_length = len(result.get("words"))

            heads = result.get("predicted_heads")
            assert len(heads) == sequence_length

            labels = result.get("predicted_dependencies")
            assert len(labels) == sequence_length
            assert isinstance(labels, list)
            assert all(isinstance(label, str) for label in labels)
    def test_uses_named_inputs(self):
        """
        Checks that the open-information-extraction predictor returns the
        expected words and, for every verb, one string tag per word.
        """
        sentence = "Angela Merkel met and spoke to her EU counterparts during the climate summit."
        inputs = {"sentence": sentence}

        fixture = self.FIXTURES_ROOT / 'srl' / 'serialization' / 'model.tar.gz'
        predictor = Predictor.from_archive(load_archive(fixture), 'open-information-extraction')

        result = predictor.predict_json(inputs)

        expected_words = ["Angela", "Merkel", "met", "and", "spoke", "to", "her", "EU",
                          "counterparts", "during", "the", "climate", "summit", "."]
        words = result.get("words")
        assert words == expected_words
        num_words = len(words)

        verbs = result.get("verbs")
        assert verbs is not None
        assert isinstance(verbs, list)

        # Each extracted verb carries a tag sequence aligned with the words.
        for verb in verbs:
            tags = verb.get("tags")
            assert tags is not None
            assert isinstance(tags, list)
            assert all(isinstance(tag, str) for tag in tags)
            assert len(tags) == num_words
Esempio n. 6
0
    def test_uses_named_inputs(self):
        """Check the semantic-role-labeling predictor's words, verbs and tag sequences."""
        sentence = "The squirrel wrote a unit test to make sure its nuts worked as designed."
        inputs = {"sentence": sentence}

        fixture = self.FIXTURES_ROOT / 'srl' / 'serialization' / 'model.tar.gz'
        predictor = Predictor.from_archive(load_archive(fixture), 'semantic-role-labeling')

        result = predictor.predict_json(inputs)

        expected_words = ["The", "squirrel", "wrote", "a", "unit", "test", "to", "make",
                          "sure", "its", "nuts", "worked", "as", "designed", "."]
        words = result.get("words")
        assert words == expected_words
        num_words = len(words)

        verbs = result.get("verbs")
        assert verbs is not None
        assert isinstance(verbs, list)

        # All three verbs of the sentence must be found.
        for expected_verb in ("wrote", "make", "worked"):
            assert any(v["verb"] == expected_verb for v in verbs)

        # Each verb carries a tag sequence aligned with the words.
        for verb in verbs:
            tags = verb.get("tags")
            assert tags is not None
            assert isinstance(tags, list)
            assert all(isinstance(tag, str) for tag in tags)
            assert len(tags) == num_words
Esempio n. 7
0
 def test_batch_prediction(self):
     """Predicting the same input twice in one batch must yield equal results."""
     inputs = {
             "sentence": "The squirrel wrote a unit test to make sure its nuts worked as designed."
     }
     fixture = self.FIXTURES_ROOT / 'srl' / 'serialization' / 'model.tar.gz'
     predictor = Predictor.from_archive(load_archive(fixture), 'semantic-role-labeling')
     # The very same dict object is submitted twice.
     batch = [inputs] * 2
     result = predictor.predict_batch_json(batch)
     assert result[0] == result[1]
    def test_atis_parser_batch_predicted_sql_present(self):
        """A one-element batch through the ATIS parser must yield a predicted SQL query."""
        fixture = self.FIXTURES_ROOT / "atis" / "serialization" / "model.tar.gz"
        predictor = Predictor.from_archive(load_archive(fixture), "atis-parser")

        batch = [{"utterance": "show me flights to seattle"}]
        first_result = predictor.predict_batch_json(batch)[0]

        assert first_result.get("predicted_sql_query") is not None
Esempio n. 9
0
    def __init__(self):
        """Load the pre-trained AllenNLP coreference model, downloading it if absent.

        The archive is looked up at a path relative to the current working
        directory; when it does not exist it is fetched with ``wget`` first.
        The resulting predictor is stored on ``self.predictor``.
        """
        # AllenNLP coreference pre-trained model
        pretrained_coref_path = './allennlp_pretrained/allennlp_coref-model-2018.02.05.tar.gz'

        if not os.path.exists(pretrained_coref_path):
            coref_url = "https://s3-us-west-2.amazonaws.com/allennlp/models/coref-model-2018.02.05.tar.gz"
            # The directory may already exist without the archive (e.g. after an
            # interrupted download); a plain os.mkdir would raise FileExistsError.
            os.makedirs(os.path.dirname(pretrained_coref_path), exist_ok=True)
            wget.download(coref_url, pretrained_coref_path)

        self.predictor = Predictor.from_path(pretrained_coref_path)
    def test_atis_parser_predicted_sql_present(self):
        """A single utterance through the ATIS parser must yield a predicted SQL query."""
        fixture = self.FIXTURES_ROOT / 'semantic_parsing' / 'atis' / 'serialization' / 'model.tar.gz'
        predictor = Predictor.from_archive(load_archive(fixture), 'atis-parser')

        result = predictor.predict_json({"utterance": "show me flights to seattle"})

        assert result.get("predicted_sql_query") is not None
Esempio n. 11
0
    def load_predictor(self) -> Predictor:
        """Build the predictor described by this object's settings.

        Resolution order:

        1. ``pretrained_model_id`` set: load through ``allennlp_models.pretrained``.
        2. ``use_old_load_method`` set: load the archive explicitly and wrap it.
        3. Otherwise: ``Predictor.from_path`` on ``archive_file``.

        ``overrides`` is forwarded in every case.
        """
        if self.pretrained_model_id is not None:
            from allennlp_models.pretrained import load_predictor

            return load_predictor(self.pretrained_model_id, overrides=self.overrides)

        assert self.archive_file is not None

        if not self.use_old_load_method:
            return Predictor.from_path(
                self.archive_file, predictor_name=self.predictor_name, overrides=self.overrides
            )

        from allennlp.models.archival import load_archive

        # Older versions require overrides to be passed as a JSON string.
        serialized_overrides = None if self.overrides is None else json.dumps(self.overrides)
        archive = load_archive(self.archive_file, overrides=serialized_overrides)
        return Predictor.from_archive(archive, self.predictor_name)
 def test_batch_prediction(self):
     """Two copies of one input in a batch must produce equal SRL results."""
     inputs = {
         "sentence": "The squirrel wrote a unit test to make sure its nuts worked as designed."
     }
     fixture = self.FIXTURES_ROOT / 'srl' / 'serialization' / 'model.tar.gz'
     predictor = Predictor.from_archive(load_archive(fixture), 'semantic-role-labeling')
     # The very same dict object is submitted twice.
     batch = [inputs] * 2
     result = predictor.predict_batch_json(batch)
     assert result[0] == result[1]
Esempio n. 13
0
    def test_run(self, model: str):
        """Run a ``QuestionAnsweringSuite`` against the ``rc/<model>`` fixture archive."""
        fixture = FIXTURES_ROOT / "rc" / model / "serialization" / "model.tar.gz"
        predictor = Predictor.from_archive(load_archive(fixture))

        # (context, question) pairs handed to the suite as its data.
        context_question_pairs = [
            ("Alice is taller than Bob.", "Who is taller?"),
            ("Children were playing in the park.", "Was the park empty?"),
        ]
        suite = QuestionAnsweringSuite(
            context_key="passage",
            add_default_tests=True,
            data=context_question_pairs,
        )
        suite.run(predictor, max_examples=10)
Esempio n. 14
0
    def test_loads_correct_dataset_reader(self):
        """The predictor must pick the requested dataset reader from the archive."""
        # This model has a different dataset reader configuration for train and validation. The parameter that
        # differs is instances_per_file.
        fixture = self.FIXTURES_ROOT / "simple_tagger_with_span_f1" / "serialization" / "model.tar.gz"
        archive = load_archive(fixture)

        # (extra keyword arguments, expected number of token indexers)
        cases = [
            ({}, 2),
            ({"dataset_reader_to_load": "train"}, 1),
            ({"dataset_reader_to_load": "validation"}, 2),
        ]
        for extra_kwargs, expected_indexers in cases:
            predictor = Predictor.from_archive(archive, "sentence-tagger", **extra_kwargs)
            assert len(predictor._dataset_reader._token_indexers) == expected_indexers
Esempio n. 15
0
    def setUp(self):
        """Run the AllenNLP predictors over four sentences and cache their outputs.

        Populates:

        * ``self.allens``: dependency predictors keyed by tag formalism.
        * ``self.document``: the sentences joined by single spaces.
        * ``self.results[<formalism>]``: per-sentence ``tokens``, ``pos``,
          ``dep_types`` and ``dep_heads`` lists.
        * ``self.results['srl']``: per-sentence ``verbs`` and ``srl_tags`` lists.
        """
        self.allens = {
            # TODO: Current download model is wrong on Allennlp.
            # 'universal': Predictor.from_path(MODEL2URL['universal']),
            'stanford': Predictor.from_path(MODEL2URL['stanford'])
        }
        # Loaded once up front; the predictor does not depend on the sentence,
        # so re-loading it for every sentence only wastes time.
        srl_predictor = Predictor.from_path(MODEL2URL['srl'])

        self.results = {
            formalism: {'tokens': [], 'pos': [], 'dep_types': [], 'dep_heads': []}
            for formalism in self.allens
        }
        self.results['srl'] = {'verbs': [], 'srl_tags': []}

        sentences = [
            "This tool is called Forte.",
            "The goal of this project is to help you build NLP pipelines.",
            "NLP has never been made this easy before.",
            "Forte is named Forte because it is designed for text."
        ]
        self.document = ' '.join(sentences)

        for sent in sentences:
            for dep_type, predictor in self.allens.items():
                results = predictor.predict(  # type: ignore
                    sentence=sent)
                collected = self.results[dep_type]
                collected['tokens'].append(results['words'])
                collected['pos'].append(results['pos'])
                collected['dep_types'].append(results['predicted_dependencies'])
                collected['dep_heads'].append(results['predicted_heads'])

            srl_results = parse_allennlp_srl_results(
                srl_predictor.predict(sentence=sent)['verbs'])
            self.results['srl']['verbs'].append(srl_results['verbs'])
            self.results['srl']['srl_tags'].append(srl_results['srl_tags'])
Esempio n. 16
0
 def test_batch_prediction(self):
     """Two copies of one input in a batch must produce equal SRL results."""
     inputs = {
         "sentence": "The squirrel wrote a unit test to make sure its nuts worked as designed."
     }
     fixture = FIXTURES_ROOT / "structured_prediction" / "srl" / "serialization" / "model.tar.gz"
     predictor = Predictor.from_archive(load_archive(fixture), "semantic_role_labeling")
     # The very same dict object is submitted twice.
     batch = [inputs] * 2
     result = predictor.predict_batch_json(batch)
     assert result[0] == result[1]
Esempio n. 17
0
def _get_predictor(args: argparse.Namespace) -> Predictor:
    """Load the archive named by the parsed CLI arguments and wrap it in a predictor.

    Reads ``archive_path``, ``weights_file``, ``cuda_device``, ``overrides``
    and ``predictor`` from ``args``. ``check_for_gpu`` is called on the
    requested device before anything is loaded.
    """
    device = args.cuda_device
    check_for_gpu(device)

    load_options = {
        "weights_file": args.weights_file,
        "cuda_device": args.cuda_device,
        "overrides": args.overrides,
    }
    archive = load_archive(args.archive_path, **load_options)

    return Predictor.from_archive(archive, args.predictor)
Esempio n. 18
0
    def test_loads_correct_dataset_reader(self):
        """The predictor must pick the requested dataset reader from the NAQANET archive."""
        # The NAQANET archive has both a training and validation ``DatasetReader``
        # with different values for ``passage_length_limit`` (``1000`` for validation
        # and ``400`` for training).
        fixture = self.FIXTURES_ROOT / "naqanet" / "serialization" / "model.tar.gz"
        archive = load_archive(fixture)

        # (extra keyword arguments, expected passage_length_limit)
        cases = [
            ({}, 1000),
            ({"dataset_reader_to_load": "train"}, 400),
            ({"dataset_reader_to_load": "validation"}, 1000),
        ]
        for extra_kwargs, expected_limit in cases:
            predictor = Predictor.from_archive(archive, "machine-comprehension", **extra_kwargs)
            assert predictor._dataset_reader.passage_length_limit == expected_limit
Esempio n. 19
0
    def __init__(self,
                 nlp,
                 attrs=('has_nfh', 'is_nfh', 'nfh', 'is_deter_nfh', 'nfh_head',
                        'is_implicit'),
                 force_extension=True):
        """Initialise the pipeline component.

        Downloads the models, loads the pickled identification model and the
        archived resolution predictor from ``~/NFH_DIR``, then registers the
        custom extension attributes on ``Doc``, ``Span`` and ``Token``.

        nlp (Language): The shared nlp object (not referenced in this
            constructor).
        attrs (tuple): Six extension attribute names, in the order has_nfh,
            is_nfh, nfh, is_deter_nfh, nfh_head, is_implicit.
        force_extension (bool): Passed as ``force`` to every ``set_extension``
            call.
        """
        download_models()

        model_dir = path.join(path.expanduser("~"), NFH_DIR)

        # NOTE(review): pickle.load runs arbitrary code from the file; this
        # presumably reads a model fetched by download_models() - confirm the
        # source is trusted.
        with open(path.join(model_dir, IDENTIFICATION_NFH), 'rb') as model_file:
            self.identification = pickle.load(model_file)
        self.feature_extractor = FeatureExtractor(3)

        self.resolution_predictor = Predictor.from_archive(
            load_archive(path.join(model_dir, RESOLUTION_NFH)),
            'nfh_classification')

        (self._has_nfh, self._is_nfh, self._nfh, self._is_deter_nfh,
         self._nfh_head, self._is_implicit) = attrs
        self._nfh_items = 'nfh_items'

        # Getter-backed attributes, registered on both Doc and Span.
        getter_extensions = (
            (self._has_nfh, self.has_nfh),
            (self._nfh, self.iter_nfh),
        )
        for extension_name, getter in getter_extensions:
            for target in (Doc, Span):
                target.set_extension(extension_name,
                                     getter=getter,
                                     force=force_extension)

        # Plain attributes with default values.
        Span.set_extension(self._is_nfh, default=False, force=force_extension)
        token_defaults = (
            (self._is_nfh, False),
            (self._is_deter_nfh, False),
            (self._nfh_head, None),
            (self._is_implicit, False),
        )
        for extension_name, default_value in token_defaults:
            Token.set_extension(extension_name,
                                default=default_value,
                                force=force_extension)

        Doc.set_extension(self._nfh_items, default=[], force=force_extension)
    def test_input_reduction(self):
        """Run ``InputReduction`` on an entailment model and on a tagger model.

        In both cases the reduced input must be non-empty, no longer than the
        original, and made up only of words from the original.
        """
        # test using entailment model
        entailment_inputs = {
            "premise": "I always write unit tests for my code.",
            "hypothesis": "One time I didn't write any unit tests for my code.",
        }

        fixture = self.FIXTURES_ROOT / "decomposable_attention" / "serialization" / "model.tar.gz"
        predictor = Predictor.from_archive(load_archive(fixture), "textual-entailment")

        reduced = InputReduction(predictor).attack_from_json(
            entailment_inputs, "hypothesis", "grad_input_1")
        assert reduced is not None
        assert "final" in reduced
        assert "original" in reduced
        assert reduced["final"][0]  # always at least one token
        # input reduction removes tokens
        assert len(reduced["final"][0]) <= len(reduced["original"])
        for word in reduced["final"][0]:  # no new words entered
            assert word in reduced["original"]

        # test using NER model (tests different underlying logic)
        tagger_inputs = {"sentence": "Eric Wallace was an intern at AI2"}

        fixture = self.FIXTURES_ROOT / "simple_tagger" / "serialization" / "model.tar.gz"
        predictor = Predictor.from_archive(load_archive(fixture), "sentence-tagger")

        reduced = InputReduction(predictor).attack_from_json(
            tagger_inputs, "tokens", "grad_input_1")
        assert reduced is not None
        assert "final" in reduced
        assert "original" in reduced
        for reduced_input in reduced["final"]:
            assert reduced_input  # always at least one token
            # input reduction removes tokens
            assert len(reduced_input) <= len(reduced["original"])
            for word in reduced_input:  # no new words entered
                assert word in reduced["original"]
    def initialize(self, resources: Resources, configs: Config):
        """Load the configured AllenNLP predictors and warn about entry-handling flags.

        Raises ``ProcessorConfigError`` when ``configs.tag_formalism`` is not a
        key of ``MODEL2URL``. Predictors are stored in ``self.predictor`` keyed
        by name, each placed on the next device from ``configs['cuda_devices']``
        (cycled).
        """
        super().initialize(resources, configs)
        cuda_devices = itertools.cycle(configs['cuda_devices'])

        def load(url_key):
            # Each load consumes the next device from the cycle.
            return Predictor.from_path(configs[url_key],
                                       cuda_device=next(cuda_devices))

        if configs.tag_formalism not in MODEL2URL:
            raise ProcessorConfigError('Incorrect value for tag_formalism')

        if configs.tag_formalism == 'stanford':
            self.predictor = {'stanford': load('stanford_url')}

        # NOTE(review): when both conditions hold, the 'stanford' model is
        # loaded a second time here (on the next device) and the first
        # instance is discarded - confirm whether that is intended.
        if 'srl' in configs.processors:
            self.predictor = {
                'stanford': load('stanford_url'),
                'srl': load('srl_url'),
            }

        if configs.overwrite_entries:
            logger.warning("`overwrite_entries` is set to True, this means "
                           "that the entries of the same type as produced by "
                           "this processor will be overwritten if found.")
            if configs.allow_parallel_entries:
                logger.warning('Both `overwrite_entries` (whether to overwrite'
                               ' the entries of the same type as produced by '
                               'this processor) and '
                               '`allow_parallel_entries` (whether to allow '
                               'similar new entries when they already exist) '
                               'are True, all existing conflicting entries '
                               'will be deleted.')
        elif not configs.allow_parallel_entries:
            logger.warning('Both `overwrite_entries` (whether to overwrite'
                           ' the entries of the same type as produced by '
                           'this processor) and '
                           '`allow_parallel_entries` (whether to allow '
                           'similar new entries when they already exist) '
                           'are False, processor will only run if there '
                           'are no existing conflicting entries.')
    def test_predictor_with_coverage_parser(self):
        """The NLVR coverage parser must return a logical form and per-world denotations."""
        serialization_dir = self.FIXTURES_ROOT / "nlvr_coverage_semantic_parser" / "serialization"
        archive_file = os.path.join(serialization_dir, "model.tar.gz")
        predictor = Predictor.from_archive(load_archive(archive_file), "nlvr-parser")

        result = predictor.predict_json(self.inputs)
        for expected_key in ("logical_form", "denotations"):
            assert expected_key in result
        # result['denotations'] is a list corresponding to k-best logical forms, where k is 1 by
        # default.
        best_denotations = result["denotations"][0]
        assert len(best_denotations) == 2  # Because there are two worlds in the input.
Esempio n. 23
0
    def test_predictor_with_direct_parser(self):
        """The NLVR direct parser must return a logical form and per-world denotations."""
        serialization_dir = (self.FIXTURES_ROOT / 'semantic_parsing'
                             / 'nlvr_direct_semantic_parser' / 'serialization')
        archive_file = os.path.join(serialization_dir, 'model.tar.gz')
        predictor = Predictor.from_archive(load_archive(archive_file), 'nlvr-parser')

        result = predictor.predict_json(self.inputs)
        for expected_key in ('logical_form', 'denotations'):
            assert expected_key in result
        # result['denotations'] is a list corresponding to k-best logical forms, where k is 1 by
        # default.
        best_denotations = result['denotations'][0]
        assert len(best_denotations) == 2  # Because there are two worlds in the input.
def _get_predictor(**params) -> Predictor:
    """Import the requested packages, load the model archive and build a predictor.

    Keys read from ``params``: ``include_package``, ``cuda_device``,
    ``model_file``, ``weights_file``, ``overrides`` and ``predictor``.
    """
    # Import every listed package before the archive is loaded.
    for package_name in params["include_package"]:
        import_submodules(package_name)

    cuda_device = params["cuda_device"]
    check_for_gpu(cuda_device)

    archive = load_archive(
        params["model_file"],
        weights_file=params["weights_file"],
        cuda_device=cuda_device,
        overrides=params["overrides"],
    )
    predictor_name = params["predictor"]
    return Predictor.from_archive(archive, predictor_name)
Esempio n. 25
0
    def test_loads_correct_dataset_reader(self):
        """The predictor must pick the requested dataset reader from the ATIS archive."""
        # pylint: disable=protected-access
        # The ATIS archive has both training and validation ``DatasetReaders``. The
        # ``keep_if_unparseable`` argument has a different value in each of them
        # (``True`` for validation, ``False`` for training).
        fixture = self.FIXTURES_ROOT / 'semantic_parsing' / 'atis' / 'serialization' / 'model.tar.gz'
        archive = load_archive(fixture)

        # Without an explicit choice the validation value (True) is expected.
        default_predictor = Predictor.from_archive(archive, 'atis-parser')
        assert default_predictor._dataset_reader._keep_if_unparseable is True

        train_predictor = Predictor.from_archive(
            archive, 'atis-parser', dataset_reader_to_load='train')
        assert train_predictor._dataset_reader._keep_if_unparseable is False

        validation_predictor = Predictor.from_archive(
            archive, 'atis-parser', dataset_reader_to_load='validation')
        assert validation_predictor._dataset_reader._keep_if_unparseable is True
Esempio n. 26
0
def load_model(slug):
    """Return an AllenNLP Predictor for the trained model for `slug`.

    The archive is read from ``SERIALIZATION_DIR/<slug>/model.tar.gz``. An
    ``AssertionError`` is raised when ``is_trained(slug)`` is false, and a
    ``KeyError`` when ``slug`` has no predictor name in the table below.
    """
    assert is_trained(slug), "We haven't trained a model to load yet"
    archive_file = os.path.join(SERIALIZATION_DIR, slug, "model.tar.gz")
    # Predictor name to use for each known slug.
    predictor_name_by_slug = {
        "pos-tagging": "sentence-tagger",
        "translation": "seq2seq",
        "classification": "text_classifier",
    }
    return Predictor.from_path(
        archive_file,
        predictor_name=predictor_name_by_slug[slug],
    )
    def test_prediction_with_no_verbs(self):
        """
        Checks that a sentence without verbs yields its words and an empty verb list.
        """
        fixture = self.FIXTURES_ROOT / 'srl' / 'serialization' / 'model.tar.gz'
        predictor = Predictor.from_archive(load_archive(fixture), 'open-information-extraction')

        result = predictor.predict_json({"sentence": "Blah no verb sentence."})

        expected = {'words': ['Blah', 'no', 'verb', 'sentence', '.'], 'verbs': []}
        assert result == expected
    def test_loads_correct_dataset_reader(self):
        """The predictor must pick the requested dataset reader from the ATIS archive."""
        # The ATIS archive has both training and validation ``DatasetReaders``. The
        # ``keep_if_unparseable`` argument has a different value in each of them
        # (``True`` for validation, ``False`` for training).
        fixture = self.FIXTURES_ROOT / "semantic_parsing" / "atis" / "serialization" / "model.tar.gz"
        archive = load_archive(fixture)

        # (extra keyword arguments, expected keep_if_unparseable)
        cases = [
            ({}, True),
            ({"dataset_reader_to_load": "train"}, False),
            ({"dataset_reader_to_load": "validation"}, True),
        ]
        for extra_kwargs, expected_flag in cases:
            predictor = Predictor.from_archive(archive, "atis-parser", **extra_kwargs)
            assert predictor._dataset_reader._keep_if_unparseable is expected_flag
Esempio n. 29
0
    def test_uses_named_inputs(self):
        """Every tag returned by the streusle-tagger predictor must be a non-empty string."""
        tokens = ["This", "is", "a", "sample", "sentence", "."]

        archive = load_archive('fixtures/streusle_tagger/serialization/model.tar.gz')
        predictor = Predictor.from_archive(archive, 'streusle-tagger')

        result = predictor.predict_json({"tokens": tokens})

        for tag in result.get("tags"):
            assert isinstance(tag, str)
            assert tag != ""
Esempio n. 30
0
 def test_works(self):
     """Smoke test: the cloze predictor must run on the same input twice without raising."""
     inputs = {
         "prefix": ["Benton", "Brindge", "is", "in"],
         "expected_tail": "Washington",
         "entity_id": "Q4890550",
         "entity_indices": [0, 2],
         "shortlist": ["Q4890550", "Q35657"]
     }
     predictor = Predictor.from_archive(
         load_archive('kglm/tests/fixtures/kglm.model.tar.gz'), 'cloze')
     # Two consecutive predictions on the same input; results are not inspected.
     for _ in range(2):
         predictor.predict_json(inputs)
     predictor.predict_json(inputs)
 def start_bundle(self):
     """Lazily create ``self.predictor`` the first time a bundle starts.

     Does nothing when a predictor is already present; otherwise prepares
     the model directory and loads a ``text_classifier`` predictor from it.
     """
     # Already initialised by an earlier bundle: nothing to do.
     if self.predictor is not None:
         return
     model_dir = self.prepare_model()
     # The following function-scope imports are a necessary bad practice: otherwise beam tries to
     # serialize allennlp and the deserialization breaks on dataflow.
     # NOTE(review): the scibert imports are not referenced by name below; presumably they are
     # needed for their import side effects (e.g. registering the model/predictor) - confirm.
     from scibert.models import text_classifier
     from scibert.predictors.predictor import ScibertPredictor
     from allennlp.predictors import Predictor
     import scibert
     self.predictor = Predictor.from_path(model_dir,
                                          predictor_name="text_classifier")
Esempio n. 32
0
def ensemble_model_avg_logits(models: List[dict],
                              test_dataframe,
                              weight: Optional[List] = None):
    """Combine span start/end logits from several archived models by a weighted sum.

    Args:
        models: One dict per model; each is unpacked as keyword arguments to
            ``load_archive``.
        test_dataframe: Data handed to ``predict_test_data`` for every model.
            A copy of it receives the new logit columns.
        weight: One weight per model. When ``None`` every model gets
            ``1 / len(models)``.

    Returns:
        A tuple ``(output_dataframe, cv_outputs, cv_results)`` where
        ``output_dataframe`` is the copy with ``model<k>_start_logits``,
        ``ensemble_start_logits``, ``model<k>_end_logits`` and
        ``ensemble_end_logits`` columns, and the other two are the per-model
        lists of outputs and results from ``predict_test_data``.
    """
    output_dataframe = test_dataframe.copy()

    cv_outputs = []
    cv_results = []
    start_cv_logits = []
    end_cv_logits = []

    for model_dict in tqdm(models):
        archive = load_archive(**model_dict)
        archive.model._delay = 50000
        predictor = Predictor.from_archive(archive, "tweet_sentiment")

        results, outputs = predict_test_data(test_dataframe, predictor)
        start_cv_logits.append([output["span_start_logits"] for output in outputs])
        end_cv_logits.append([output["span_end_logits"] for output in outputs])
        cv_outputs.append(outputs)
        cv_results.append(results)

    if weight is None:
        weight = [1 / len(models)] * len(models)
    # Built once: the weights do not change from sample to sample.
    weight_tensor = torch.tensor(weight)

    ensemble_start_logits = []
    ensemble_end_logits = []
    for i in range(test_dataframe.shape[0]):
        # Stack the models' logits for sample i along a new trailing axis, so
        # the matmul with the weight vector sums over models.
        stacked_start = torch.stack(
            [torch.tensor(model_logits[i]) for model_logits in start_cv_logits], dim=1)
        stacked_end = torch.stack(
            [torch.tensor(model_logits[i]) for model_logits in end_cv_logits], dim=1)
        ensemble_start_logits.append(torch.matmul(stacked_start, weight_tensor).tolist())
        ensemble_end_logits.append(torch.matmul(stacked_end, weight_tensor).tolist())

    for idx, cv_logits in enumerate(start_cv_logits):
        output_dataframe[f"model{idx+1}_start_logits"] = cv_logits
    output_dataframe["ensemble_start_logits"] = ensemble_start_logits
    for idx, cv_logits in enumerate(end_cv_logits):
        output_dataframe[f"model{idx+1}_end_logits"] = cv_logits
    output_dataframe["ensemble_end_logits"] = ensemble_end_logits
    return output_dataframe, cv_outputs, cv_results
Esempio n. 33
0
    def test_atis_parser_batch_predicted_sql_present(self):
        """A one-element batch through the ATIS parser must yield a predicted SQL query."""
        fixture = self.FIXTURES_ROOT / 'semantic_parsing' / 'atis' / 'serialization' / 'model.tar.gz'
        predictor = Predictor.from_archive(load_archive(fixture), 'atis-parser')

        batch = [{"utterance": "show me flights to seattle"}]
        first_result = predictor.predict_batch_json(batch)[0]

        assert first_result.get("predicted_sql_query") is not None
Esempio n. 34
0
 def __init__(self, database_path, add_claim=False, max_pages_per_query=None):
     """Set up the document database, NLTK helpers and the constituency parser.

     ``database_path`` is handed to ``FeverDocDB``; ``add_claim`` and
     ``max_pages_per_query`` are stored unchanged. The parser is loaded
     from the hosted ELMo constituency-parser archive.
     """
     self.db = FeverDocDB(database_path)

     # Stored options.
     self.add_claim = add_claim
     self.max_pages_per_query = max_pages_per_query

     # NLTK stemmer and tokenizer function.
     self.proter_stemm = nltk.PorterStemmer()
     self.tokenizer = nltk.word_tokenize

     parser_archive = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo-constituency-parser-2018.03.14.tar.gz"
     self.predictor = Predictor.from_path(parser_archive)
Esempio n. 35
0
class SRL:
    """Holder for one shared semantic-role-labeling predictor.

    The predictor is created once, when the class body is executed.
    """

    # Alternative (commented-out) model location carried over from the original:
    # "/root/.allennlp/models/srl-model-2018.05.25.tar.gz"
    predictor = Predictor.from_path("./models/srl-model-2018.05.25.tar.gz")

    @staticmethod
    def validate_doc(document):
        """Accept every document; currently this always returns ``True``."""
        return True

    @staticmethod
    def get_srl(document):
        """Return the predictor output for ``document``, or ``None`` if validation rejects it."""
        if not SRL.validate_doc(document):
            return None
        return SRL.predictor.predict(document)
Esempio n. 36
0
 def test_copynet_predictions(self):
     """CopyNet returns one string-token sequence per beam, none containing the end token."""
     fixture_path = (self.FIXTURES_ROOT / 'encoder_decoder' / 'copynet_seq2seq' /
                     'serialization' / 'model.tar.gz')
     predictor = Predictor.from_archive(load_archive(fixture_path), 'seq2seq')
     copynet = predictor._model
     stop_token = copynet.vocab.get_token_from_index(copynet._end_index, copynet._target_namespace)

     prediction = predictor.predict("these tokens should be copied over : hello world")

     # Both output lists must hold exactly one entry per beam.
     expected_count = copynet._beam_search.beam_size
     assert len(prediction["predictions"]) == expected_count
     assert len(prediction["predicted_tokens"]) == expected_count

     for token_sequence in prediction["predicted_tokens"]:
         assert all(isinstance(token, str) for token in token_sequence)
         assert stop_token not in token_sequence
Esempio n. 37
0
 def test_copynet_predictions(self):
     """CopyNet returns one string-token sequence per beam, none containing the end token."""
     fixture_path = (self.FIXTURES_ROOT / 'encoder_decoder' / 'copynet_seq2seq' /
                     'serialization' / 'model.tar.gz')
     predictor = Predictor.from_archive(load_archive(fixture_path), 'seq2seq')
     copynet = predictor._model
     stop_token = copynet.vocab.get_token_from_index(copynet._end_index, copynet._target_namespace)

     prediction = predictor.predict("these tokens should be copied over : hello world")

     # Both output lists must hold exactly one entry per beam.
     expected_count = copynet._beam_search.beam_size
     assert len(prediction["predictions"]) == expected_count
     assert len(prediction["predicted_tokens"]) == expected_count

     for token_sequence in prediction["predicted_tokens"]:
         assert all(isinstance(token, str) for token in token_sequence)
         assert stop_token not in token_sequence
Esempio n. 38
0
 def __init__(self,
              pretrained_model_name_or_path: str,
              sphereize: bool = False,
              **kwargs) -> None:
     """Load an archive and wrap it in a ``declutr`` predictor.

     Args:
         pretrained_model_name_or_path: If this is a key of ``PRETRAINED_MODELS``,
             the mapped value is loaded; otherwise it is passed to
             ``load_archive`` unchanged.
         sphereize: Stored on ``self._sphereize``; not otherwise used here.
         **kwargs: Forwarded unchanged to ``load_archive``.
     """
     archive_location = pretrained_model_name_or_path
     if archive_location in PRETRAINED_MODELS:
         # Registered short name: swap it for the value it maps to.
         archive_location = PRETRAINED_MODELS[archive_location]

     # The package's modules are imported before the archive is loaded.
     common_util.import_module_and_submodules("declutr")
     self._predictor = Predictor.from_archive(
         load_archive(archive_location, **kwargs),
         predictor_name="declutr",
     )
     self._sphereize = sphereize
Esempio n. 39
0
    def test_answer_present_with_batch_predict(self):
        """Batch prediction with the WikiTables parser yields a non-null answer."""
        question = "Who is 18 years old?"
        table = "Name\tAge\nShallan\t16\nKaladin\t18"

        fixture = self.FIXTURES_ROOT / "wikitables" / "serialization" / "model.tar.gz"
        parser = Predictor.from_archive(load_archive(fixture), "wikitables-parser")

        outputs = parser.predict_batch_json([{"question": question, "table": table}])

        # Only the presence of an answer is checked, not its value.
        assert outputs[0].get("answer") is not None
Esempio n. 40
0
    def test_answer_present_with_batch_predict(self):
        """Batch prediction with the WikiTables parser yields a non-null answer."""
        question = "Who is 18 years old?"
        table = "Name\tAge\nShallan\t16\nKaladin\t18"

        fixture = self.FIXTURES_ROOT / 'semantic_parsing' / 'wikitables' / 'serialization' / 'model.tar.gz'
        parser = Predictor.from_archive(load_archive(fixture), 'wikitables-parser')

        outputs = parser.predict_batch_json([{"question": question, "table": table}])

        # Only the presence of an answer is checked, not its value.
        assert outputs[0].get("answer") is not None
Esempio n. 41
0
    def test_prediction_with_no_verbs(self):
        """A verbless sentence gives an empty verb list, both alone and inside a batch."""
        verbless = {"sentence": "Blah no verb sentence."}
        with_verb = {"sentence": "This sentence has a verb."}

        expected_verbless = {'words': ['Blah', 'no', 'verb', 'sentence', '.'], 'verbs': []}
        expected_with_verb = {
                'words': ['This', 'sentence', 'has', 'a', 'verb', '.'],
                'verbs': [{'verb': 'has',
                           'description': 'This sentence has a verb .',
                           'tags': ['O', 'O', 'O', 'O', 'O', 'O']}],
        }

        srl_archive = load_archive(self.FIXTURES_ROOT / 'srl' / 'serialization' / 'model.tar.gz')
        srl_predictor = Predictor.from_archive(srl_archive, 'semantic-role-labeling')

        # Single-instance path.
        assert srl_predictor.predict_json(verbless) == expected_verbless

        # Batch path: the verbless instance is mixed with one that has a verb.
        batch_outputs = srl_predictor.predict_batch_json([verbless, with_verb])
        assert batch_outputs[0] == expected_verbless
        assert batch_outputs[1] == expected_with_verb
Esempio n. 42
0
    def test_uses_named_inputs_with_simple_seq2seq(self):
        """The seq2seq predictor returns ``predicted_tokens`` as a list of strings."""
        fixture = (self.FIXTURES_ROOT / 'encoder_decoder' / 'simple_seq2seq' /
                   'serialization' / 'model.tar.gz')
        seq2seq = Predictor.from_archive(load_archive(fixture), 'seq2seq')

        output = seq2seq.predict_json(
                {"source": "What kind of test succeeded on its first attempt?"}
        )

        tokens = output.get("predicted_tokens")
        assert tokens is not None
        assert isinstance(tokens, list)
        for token in tokens:
            assert isinstance(token, str)
Esempio n. 43
0
    def test_uses_named_inputs(self):
        """The coreference predictor accepts both a raw document string and a token list."""
        text = ("This is a single string document about a test. Sometimes it "
                "contains coreferent parts.")
        tokens = [
                'This', 'is', 'a', 'single', 'string',
                'document', 'about', 'a', 'test', '.', 'Sometimes',
                'it', 'contains', 'coreferent', 'parts', '.',
        ]

        fixture = self.FIXTURES_ROOT / 'coref' / 'serialization' / 'model.tar.gz'
        coref = Predictor.from_archive(load_archive(fixture), 'coreference-resolution')

        # JSON entry point with the untokenized document.
        self.assert_predict_result(coref.predict_json({"document": text}))

        # Tokenized entry point with the same document already split into words.
        self.assert_predict_result(coref.predict_tokenized(tokens))
Esempio n. 44
0
    def test_uses_named_inputs(self):
        """The SRL predictor accepts both a raw sentence and a pre-tokenized word list."""
        sentence = "The squirrel wrote a unit test to make sure its nuts worked as designed."
        tokens = [
                "The", "squirrel", "wrote", "a", "unit", "test",
                "to", "make", "sure", "its", "nuts", "worked", "as", "designed", ".",
        ]

        fixture = self.FIXTURES_ROOT / 'srl' / 'serialization' / 'model.tar.gz'
        srl = Predictor.from_archive(load_archive(fixture), 'semantic-role-labeling')

        # JSON entry point with the untokenized sentence.
        self.assert_predict_result(srl.predict_json({"sentence": sentence}))

        # Tokenized entry point with the same sentence already split into words.
        self.assert_predict_result(srl.predict_tokenized(tokens))
    def test_uses_named_inputs(self):
        """Constituency parser output for one sentence has the expected spans, tokens and tree."""
        expected_tokens = ["What", "a", "great", "test", "sentence", "."]
        # Number of possible substrings of the sentence.
        expected_span_count = 21

        fixture = self.FIXTURES_ROOT / 'constituency_parser' / 'serialization' / 'model.tar.gz'
        parser = Predictor.from_archive(load_archive(fixture), 'constituency-parser')

        parse = parser.predict_json({"sentence": "What a great test sentence."})

        assert len(parse["spans"]) == expected_span_count
        assert len(parse["class_probabilities"]) == expected_span_count
        assert parse["tokens"] == expected_tokens
        assert isinstance(parse["trees"], str)

        # Every per-span class distribution must sum to one (to four places).
        for distribution in parse["class_probabilities"]:
            self.assertAlmostEqual(sum(distribution), 1.0, places=4)
    def test_predictor_uses_dataset_reader_to_determine_pos_set(self):
        """POS tags on built instances follow the reader's ``use_language_specific_pos`` flag."""
        # pylint: disable=protected-access
        fixture = (self.FIXTURES_ROOT / 'biaffine_dependency_parser'
                   / 'serialization' / 'model.tar.gz')
        parser = Predictor.from_archive(load_archive(fixture), 'biaffine-dependency-parser')
        sentence_json = {"sentence": "Dogs eat cats."}

        def current_pos_tags():
            """Build a fresh instance from the sentence and return its POS tag labels."""
            return parser._json_to_instance(sentence_json).fields["pos_tags"].labels

        # With the reader as loaded from the archive.
        assert current_pos_tags() == ['NOUN', 'VERB', 'NOUN', 'PUNCT']

        # After switching the reader to language-specific POS tags.
        parser._dataset_reader.use_language_specific_pos = True
        assert current_pos_tags() == ['NNS', 'VBP', 'NNS', '.']
Esempio n. 47
0
    def test_uses_named_inputs(self):
        """The dialog-QA predictor returns a non-empty string span for every question."""
        # The original literal continued this string with a trailing backslash, which
        # folds the next source line's 35 leading spaces into the value; they are
        # reproduced explicitly here so the input is unchanged.
        context = ("One time I was writing a unit test," + " " * 35 +
                   "and it succeeded on the first attempt.")
        qas = [
                {"followup": "y", "yesno": "x", "question": "When was the first one?",
                 "answers": [{"answer_start": 0, "text": "One time"}], "id": "C_q#0"},
                {"followup": "n", "yesno": "x", "question": "What were you doing?",
                 "answers": [{"answer_start": 15, "text": "writing a"}], "id": "C_q#1"},
                {"followup": "m", "yesno": "y", "question": "How often?",
                 "answers": [{"answer_start": 4, "text": "time I"}], "id": "C_q#2"},
        ]
        payload = {"paragraphs": [{"qas": qas, "context": context}]}

        fixture = self.FIXTURES_ROOT / 'dialog_qa' / 'serialization' / 'model.tar.gz'
        dialog_predictor = Predictor.from_archive(load_archive(fixture), 'dialog_qa')

        prediction = dialog_predictor.predict_json(payload)

        for span_text in prediction.get("best_span_str"):
            assert isinstance(span_text, str)
            assert span_text != ""
Esempio n. 48
0
    def test_atis_parser_uses_named_inputs(self):
        """Smoke-test the ATIS parser; stricter checks run only if it produced actions."""
        fixture = self.FIXTURES_ROOT / 'semantic_parsing' / 'atis' / 'serialization' / 'model.tar.gz'
        atis = Predictor.from_archive(load_archive(fixture), 'atis-parser')

        output = atis.predict_json({"utterance": "show me the flights to seattle"})

        actions = output.get("best_action_sequence")
        if not actions:
            # An untrained model will likely get into a loop and not reach a finished
            # state. In that case it produces no valid SQL and therefore no actions,
            # so this test basically just checks that the model runs.
            return

        assert len(actions) > 1
        assert all(isinstance(step, str) for step in actions)
        assert output.get("predicted_sql_query") is not None
Esempio n. 49
0
    def test_answer_present(self):
        """The QuaRel parser returns an answer index both with and without entity cues."""
        fixture = self.FIXTURES_ROOT / 'semantic_parsing' / 'quarel' / 'serialization_parser_zeroshot' / 'model.tar.gz'  # pylint: disable=line-too-long
        quarel = Predictor.from_archive(load_archive(fixture), 'quarel-parser')

        payload = {
                'question': 'Mike was snowboarding on the snow and hit a piece of ice. He went much faster on the ice because _____ is smoother. (A) snow (B) ice',  # pylint: disable=line-too-long
                'world_literals': {'world1': 'snow', 'world2': 'ice'},  # Added to avoid world tagger
                'qrspec': '[smoothness, +speed]',
                'entitycues': 'smoothness: smoother\nspeed:faster',
        }

        assert quarel.predict_json(payload).get('answer_index') is not None

        # Check input modality where entity cues are not given
        payload.pop('entitycues')
        assert quarel.predict_json(payload).get('answer_index') is not None
    def test_uses_named_inputs(self):
        """The biaffine dependency parser returns heads, labels, losses and a hierplane tree."""
        fixture = (self.FIXTURES_ROOT / 'biaffine_dependency_parser'
                   / 'serialization' / 'model.tar.gz')
        parser = Predictor.from_archive(load_archive(fixture), 'biaffine-dependency-parser')

        output = parser.predict_json({"sentence": "Please could you parse this sentence?"})

        tokens = output.get("words")
        assert len(output.get("predicted_heads")) == len(tokens)

        dependency_labels = output.get("predicted_dependencies")
        assert len(dependency_labels) == len(tokens)
        assert isinstance(dependency_labels, list)
        assert all(isinstance(label, str) for label in dependency_labels)

        for loss_key in ("loss", "arc_loss", "tag_loss"):
            assert output.get(loss_key) is not None

        # Remove these two keys in place so the comparison below covers only
        # the remaining 'text' and 'root' entries.
        tree = output.get("hierplane_tree")
        tree.pop("nodeTypeToStyle")
        tree.pop("linkToPosition")

        # (word, POS attribute, span start, span end) for each child of the root.
        child_specs = [
                ('could', 'VERB', 7, 13),
                ('you', 'PRON', 13, 17),
                ('parse', 'VERB', 17, 23),
                ('this', 'DET', 23, 28),
                ('sentence', 'NOUN', 28, 37),
                ('?', 'PUNCT', 37, 39),
        ]
        expected_children = [
                {'word': word, 'nodeType': 'nummod', 'attributes': [pos], 'link': 'nummod',
                 'spans': [{'start': start, 'end': end}]}
                for word, pos, start, end in child_specs
        ]
        expected_tree = {
                'text': 'Please could you parse this sentence ?',
                'root': {'word': 'Please', 'nodeType': 'det', 'attributes': ['INTJ'], 'link': 'det',
                         'spans': [{'start': 0, 'end': 7}],
                         'children': expected_children},
        }
        assert output.get("hierplane_tree") == expected_tree
Esempio n. 51
0
def main(args):
    """Parse command-line options, load an archived model and serve a demo over HTTP.

    ``args`` is the argument list handed to ``argparse``. ``--archive-path``,
    ``--predictor`` and at least one ``--field-name`` are required; the other
    options have defaults. The function does not return until the server's
    ``serve_forever`` call does.
    """
    cli = argparse.ArgumentParser(description='Serve up a simple model')

    cli.add_argument('--archive-path', type=str, required=True, help='path to trained archive file')
    cli.add_argument('--predictor', type=str, required=True, help='name of predictor')
    cli.add_argument('--static-dir', type=str, help='serve index.html from this directory')
    cli.add_argument('--title', type=str, help='change the default page title', default="AllenNLP Demo")
    cli.add_argument('--field-name', type=str, required=True, action='append',
                     help='field names to include in the demo')
    cli.add_argument('--port', type=int, default=8000, help='port to serve the demo on')
    cli.add_argument('--include-package',
                     type=str,
                     action='append',
                     default=[],
                     help='additional packages to include')

    options = cli.parse_args(args)

    # Import any extra packages named on the command line before loading the archive.
    for extra_package in options.include_package:
        import_submodules(extra_package)

    predictor = Predictor.from_archive(load_archive(options.archive_path), options.predictor)

    app = make_app(predictor=predictor,
                   field_names=options.field_name,
                   static_dir=options.static_dir,
                   title=options.title)
    CORS(app)

    # Listens on all interfaces ('0.0.0.0') at the requested port.
    listen_port = options.port
    http_server = WSGIServer(('0.0.0.0', listen_port), app)
    print(f"Model loaded, serving demo on port {listen_port}")
    http_server.serve_forever()
Esempio n. 52
0
    def test_uses_named_inputs(self):
        """The event2mind predictor returns lists of string tokens for all three outputs."""
        fixture = (self.FIXTURES_ROOT / 'event2mind' /
                   'serialization' / 'model.tar.gz')
        event2mind = Predictor.from_archive(load_archive(fixture), 'event2mind')

        output = event2mind.predict_json({"source": "personx gave persony a present"})

        output_keys = (
                'xintent_top_k_predicted_tokens',
                'xreact_top_k_predicted_tokens',
                'oreact_top_k_predicted_tokens',
        )
        for key in output_keys:
            # Each entry under the key is one predicted token sequence.
            for candidate in output.get(key):
                assert isinstance(candidate, list)
                assert all(isinstance(token, str) for token in candidate)
    def test_batch_prediction(self):
        """Batch textual-entailment prediction returns consistent logits and probabilities.

        For each of the two premise/hypothesis pairs the result must carry three
        float ``label_probs`` that are non-negative and sum to one, and three float
        ``label_logits`` whose softmax reproduces ``label_probs``.
        """
        batch_inputs = [
                {
                        "premise": "I always write unit tests for my code.",
                        "hypothesis": "One time I didn't write any unit tests for my code."
                },
                {
                        "premise": "I also write batched unit tests for throughput!",
                        "hypothesis": "Batch tests are slower."
                },
        ]

        archive = load_archive(self.FIXTURES_ROOT / 'decomposable_attention' / 'serialization' / 'model.tar.gz')
        predictor = Predictor.from_archive(archive, 'textual-entailment')
        results = predictor.predict_batch_json(batch_inputs)
        # One result per input, in a batch of two.
        assert len(results) == 2

        for result in results:
            # Logits should be 3 floats that softmax to label_probs
            label_logits = result.get("label_logits")
            # Label probs should be 3 floats that sum to one
            label_probs = result.get("label_probs")
            assert label_probs is not None
            assert isinstance(label_probs, list)
            assert len(label_probs) == 3
            assert all(isinstance(x, float) for x in label_probs)
            assert all(x >= 0 for x in label_probs)
            assert sum(label_probs) == approx(1.0)

            assert label_logits is not None
            assert isinstance(label_logits, list)
            assert len(label_logits) == 3
            assert all(isinstance(x, float) for x in label_logits)

            # Recompute the softmax by hand and compare it entry by entry.
            exps = [math.exp(x) for x in label_logits]
            sumexps = sum(exps)
            for e, p in zip(exps, label_probs):
                assert e / sumexps == approx(p)
Esempio n. 54
0
    def test_batch_prediction(self):
        """Run the dialog-QA predictor on a batch of two inputs and check that two results come back."""
        # Both batch entries are the same three-question paragraph.
        # NOTE: each "context" literal ends its first source line with a backslash,
        # so the leading spaces of the following line are part of the string value.
        inputs = [{"paragraphs": [{"qas": [{"followup": "y", "yesno": "x", "question": "When was the first one?",
                                            "answers": [{"answer_start": 0, "text": "One time"}], "id": "C_q#0"},
                                           {"followup": "n", "yesno": "x", "question": "What were you doing?",
                                            "answers": [{"answer_start": 15, "text": "writing a"}], "id": "C_q#1"},
                                           {"followup": "m", "yesno": "y", "question": "How often?",
                                            "answers": [{"answer_start": 4, "text": "time I"}], "id": "C_q#2"}],
                                   "context": "One time I was writing a unit test,\
                                    and it succeeded on the first attempt."}]},
                  {"paragraphs": [{"qas": [{"followup": "y", "yesno": "x", "question": "When was the first one?",
                                            "answers": [{"answer_start": 0, "text": "One time"}], "id": "C_q#0"},
                                           {"followup": "n", "yesno": "x", "question": "What were you doing?",
                                            "answers": [{"answer_start": 15, "text": "writing a"}], "id": "C_q#1"},
                                           {"followup": "m", "yesno": "y", "question": "How often?",
                                            "answers": [{"answer_start": 4, "text": "time I"}], "id": "C_q#2"}],
                                   "context": "One time I was writing a unit test,\
                                    and it succeeded on the first attempt."}]}]

        # Load the serialized dialog-QA fixture model and wrap it in its predictor.
        archive = load_archive(self.FIXTURES_ROOT / 'dialog_qa' / 'serialization' / 'model.tar.gz')
        predictor = Predictor.from_archive(archive, 'dialog_qa')

        # Only the number of results is checked, not their content.
        results = predictor.predict_batch_json(inputs)
        assert len(results) == 2
Esempio n. 55
0
    def test_uses_named_inputs(self):
        """Smoke-test the WikiTables parser; stricter checks run only if it produced actions."""
        fixture = self.FIXTURES_ROOT / 'semantic_parsing' / 'wikitables' / 'serialization' / 'model.tar.gz'
        parser = Predictor.from_archive(load_archive(fixture), 'wikitables-parser')

        output = parser.predict_json({
                "question": "names",
                "table": "name\tdate\nmatt\t2017\npradeep\t2018",
        })

        actions = output.get("best_action_sequence")
        if not actions:
            # We don't currently disallow endless loops in the decoder, and an untrained
            # seq2seq model will easily get itself into a loop. An endless loop isn't a
            # finished logical form, so decoding returns no finished states and hence no
            # actions. So, sadly, this mostly just tests that the predictor runs.
            return

        assert len(actions) > 1
        assert all(isinstance(step, str) for step in actions)
        assert output.get("logical_form") is not None
Esempio n. 56
0
    def test_uses_named_inputs(self):
        """The coreference predictor tokenizes the document and returns well-formed clusters.

        Checks that the returned ``document`` is the expected token list and that
        ``clusters`` is a list of lists of mentions whose first two entries are
        integer indices that fall inside the document.
        """
        inputs = {"document": "This is a single string document about a test. Sometimes it "
                              "contains coreferent parts."}
        archive = load_archive(self.FIXTURES_ROOT / 'coref' / 'serialization' / 'model.tar.gz')
        predictor = Predictor.from_archive(archive, 'coreference-resolution')

        result = predictor.predict_json(inputs)

        document = result["document"]
        assert document == ['This', 'is', 'a', 'single', 'string',
                            'document', 'about', 'a', 'test', '.', 'Sometimes',
                            'it', 'contains', 'coreferent', 'parts', '.']

        clusters = result["clusters"]
        assert isinstance(clusters, list)
        for cluster in clusters:
            assert isinstance(cluster, list)
            for mention in cluster:
                # Spans should be integer indices.
                assert isinstance(mention[0], int)
                assert isinstance(mention[1], int)
                # Spans should be inside document. Valid token indices run from 0 to
                # len(document) - 1, so index 0 is allowed and len(document) is not.
                assert 0 <= mention[0] < len(document)
                assert 0 <= mention[1] < len(document)
Esempio n. 57
0
    def test_batch_prediction(self):
        """Batch machine-comprehension prediction returns a valid span and probabilities per input."""
        batch = [
                {
                        "question": "What kind of test succeeded on its first attempt?",
                        "passage": "One time I was writing a unit test, and it succeeded on the first attempt."
                },
                {
                        "question": "What kind of test succeeded on its first attempt at batch processing?",
                        "passage": "One time I was writing a unit test, and it always failed!"
                },
        ]

        fixture = self.FIXTURES_ROOT / 'bidaf' / 'serialization' / 'model.tar.gz'
        bidaf = Predictor.from_archive(load_archive(fixture), 'machine-comprehension')

        outputs = bidaf.predict_batch_json(batch)
        assert len(outputs) == 2

        for output in outputs:
            # The best span is a [start, end] pair of ints with start <= end.
            span = output.get("best_span")
            assert span is not None
            assert isinstance(span, list)
            assert len(span) == 2
            assert all(isinstance(index, int) for index in span)
            assert span[0] <= span[1]

            span_text = output.get("best_span_str")
            assert isinstance(span_text, str)
            assert span_text != ""

            # Start and end distributions are lists of floats that sum to one.
            for prob_key in ("span_start_probs", "span_end_probs"):
                distribution = output.get(prob_key)
                assert distribution is not None
                assert all(isinstance(p, float) for p in distribution)
                assert sum(distribution) == approx(1.0)
    def test_build_hierplane_tree(self):
        """``_build_hierplane_tree`` turns a bracketed parse tree into the expected hierplane dict."""
        def node(word, label, children=None):
            """Build one expected node; ``children`` is attached only when given."""
            entry = {'word': word, 'nodeType': label, 'attributes': [label], 'link': label}
            if children is not None:
                entry['children'] = children
            return entry

        parsed = Tree.fromstring("(S (NP (D the) (N dog)) (VP (V chased) (NP (D the) (N cat))))")
        fixture = self.FIXTURES_ROOT / 'constituency_parser' / 'serialization' / 'model.tar.gz'
        constituency = Predictor.from_archive(load_archive(fixture), 'constituency-parser')

        built = constituency._build_hierplane_tree(parsed, 0, is_root=True)

        # Expected structure mirrors the bracketing above: S -> NP VP, VP -> V NP.
        expected = {
                'text': 'the dog chased the cat',
                "linkNameToLabel": LINK_TO_LABEL,
                "nodeTypeToStyle": NODE_TYPE_TO_STYLE,
                'root': node('the dog chased the cat', 'S', [
                        node('the dog', 'NP', [node('the', 'D'), node('dog', 'N')]),
                        node('chased the cat', 'VP', [
                                node('chased', 'V'),
                                node('the cat', 'NP', [node('the', 'D'), node('cat', 'N')]),
                        ]),
                ]),
        }
        assert expected == built