Example #1
def _get_predictor(args: argparse.Namespace) -> Predictor:
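    # The loaded archive bundles the trained model weights with the config used to train it.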
    archive = load_archive(args.archive_file,
                           weights_file=args.weights_file,
                           cuda_device=args.cuda_device,
                           overrides=args.overrides)

    if args.predictor:
        # Predictor explicitly specified, so use it
        return Predictor.from_archive(archive, args.predictor)

    # Otherwise, use the mapping
    model_type = archive.config.get("model").get("type")
    if model_type not in DEFAULT_PREDICTORS:
        raise ConfigurationError(f"No known predictor for model type {model_type}.\n"
                                 f"Specify one with the --predictor flag.")
    return Predictor.from_archive(archive, DEFAULT_PREDICTORS[model_type])
Example #2
    def test_uses_named_inputs(self):
        inputs = {
                "premise": "I always write unit tests for my code.",
                "hypothesis": "One time I didn't write any unit tests for my code."
        }

        archive = load_archive('tests/fixtures/decomposable_attention/serialization/model.tar.gz')
        predictor = Predictor.from_archive(archive, 'textual-entailment')
        result = predictor.predict_json(inputs)

        # Label probs should be 3 floats that sum to one
        label_probs = result.get("label_probs")
        assert label_probs is not None
        assert isinstance(label_probs, list)
        assert len(label_probs) == 3
        assert all(isinstance(x, float) for x in label_probs)
        assert all(x >= 0 for x in label_probs)
        assert sum(label_probs) == approx(1.0)

        # Logits should be 3 floats that softmax to label_probs
        label_logits = result.get("label_logits")
        assert label_logits is not None
        assert isinstance(label_logits, list)
        assert len(label_logits) == 3
        assert all(isinstance(x, float) for x in label_logits)

        exps = [math.exp(x) for x in label_logits]
        sumexps = sum(exps)
        for e, p in zip(exps, label_probs):
            assert e / sumexps == approx(p)
Example #3
    def test_uses_named_inputs(self):
        inputs = {
                "sentence": "The squirrel wrote a unit test to make sure its nuts worked as designed."
        }

        archive = load_archive('tests/fixtures/srl/serialization/model.tar.gz')
        predictor = Predictor.from_archive(archive, 'semantic-role-labeling')

        result = predictor.predict_json(inputs)
        words = result.get("words")
        assert words == ["The", "squirrel", "wrote", "a", "unit", "test",
                         "to", "make", "sure", "its", "nuts", "worked", "as", "designed", "."]
        num_words = len(words)

        verbs = result.get("verbs")
        assert verbs is not None
        assert isinstance(verbs, list)

        assert any(v["verb"] == "wrote" for v in verbs)
        assert any(v["verb"] == "make" for v in verbs)
        assert any(v["verb"] == "worked" for v in verbs)

        for verb in verbs:
            tags = verb.get("tags")
            assert tags is not None
            assert isinstance(tags, list)
            assert all(isinstance(tag, str) for tag in tags)
            assert len(tags) == num_words
Example #4
    def test_uses_named_inputs(self):
        inputs = {
                "question": "What kind of test succeeded on its first attempt?",
                "passage": "One time I was writing a unit test, and it succeeded on the first attempt."
        }

        archive = load_archive('tests/fixtures/bidaf/serialization/model.tar.gz')
        predictor = Predictor.from_archive(archive, 'machine-comprehension')

        result = predictor.predict_json(inputs)

        best_span = result.get("best_span")
        assert best_span is not None
        assert isinstance(best_span, list)
        assert len(best_span) == 2
        assert all(isinstance(x, int) for x in best_span)
        assert best_span[0] <= best_span[1]

        best_span_str = result.get("best_span_str")
        assert isinstance(best_span_str, str)
        assert best_span_str != ""

        for probs_key in ("span_start_probs", "span_end_probs"):
            probs = result.get(probs_key)
            assert probs is not None
            assert all(isinstance(x, float) for x in probs)
            assert sum(probs) == approx(1.0)
Example #5
    def test_batch_prediction(self):
        inputs = [
                {"sentence": "What a great test sentence."},
                {"sentence": "Here's another good, interesting one."}
        ]

        archive = load_archive('tests/fixtures/constituency_parser/serialization/model.tar.gz')
        predictor = Predictor.from_archive(archive, 'constituency-parser')
        results = predictor.predict_batch_json(inputs)

        result = results[0]
        assert len(result["spans"]) == 21 # number of possible substrings of the sentence.
        assert len(result["class_probabilities"]) == 21
        assert result["tokens"] == ["What", "a", "great", "test", "sentence", "."]
        assert isinstance(result["trees"], str)

        for class_distribution in result["class_probabilities"]:
            self.assertAlmostEqual(sum(class_distribution), 1.0, places=4)

        result = results[1]

        assert len(result["spans"]) == 36 # number of possible substrings of the sentence.
        assert len(result["class_probabilities"]) == 36
        assert result["tokens"] == ["Here", "'s", "another", "good", ",", "interesting", "one", "."]
        assert isinstance(result["trees"], str)

        for class_distribution in result["class_probabilities"]:
            self.assertAlmostEqual(sum(class_distribution), 1.0, places=4)
Example #6
    def test_batch_prediction(self):
        inputs = {
                "sentence": "The squirrel wrote a unit test to make sure its nuts worked as designed."
        }
        archive = load_archive('tests/fixtures/srl/serialization/model.tar.gz')
        predictor = Predictor.from_archive(archive, 'semantic-role-labeling')
        result = predictor.predict_batch_json([inputs, inputs])
        assert result[0] == result[1]
Example #7
    def test_uses_named_inputs(self):
        inputs = {
                "source": "What kind of test succeeded on its first attempt?",
        }

        archive = load_archive('tests/fixtures/encoder_decoder/simple_seq2seq/serialization/model.tar.gz')
        predictor = Predictor.from_archive(archive, 'simple_seq2seq')

        result = predictor.predict_json(inputs)

        predicted_tokens = result.get("predicted_tokens")
        assert predicted_tokens is not None
        assert isinstance(predicted_tokens, list)
        assert all(isinstance(x, str) for x in predicted_tokens)
Example #8
    def test_prediction_with_no_verbs(self):

        input1 = {"sentence": "Blah no verb sentence."}
        archive = load_archive('tests/fixtures/srl/serialization/model.tar.gz')
        predictor = Predictor.from_archive(archive, 'semantic-role-labeling')
        result = predictor.predict_json(input1)
        assert result == {'words': ['Blah', 'no', 'verb', 'sentence', '.'], 'verbs': []}

        input2 = {"sentence": "This sentence has a verb."}
        results = predictor.predict_batch_json([input1, input2])
        assert results[0] == {'words': ['Blah', 'no', 'verb', 'sentence', '.'], 'verbs': []}
        assert results[1] == {'words': ['This', 'sentence', 'has', 'a', 'verb', '.'],
                              'verbs': [{'verb': 'has', 'description': 'This sentence has a verb .',
                                         'tags': ['O', 'O', 'O', 'O', 'O', 'O']}]}
Example #9
    def test_uses_named_inputs(self):
        inputs = {
                "sentence": "What a great test sentence.",
        }

        archive = load_archive('tests/fixtures/constituency_parser/serialization/model.tar.gz')
        predictor = Predictor.from_archive(archive, 'constituency-parser')
        result = predictor.predict_json(inputs)

        assert len(result["spans"]) == 21 # number of possible substrings of the sentence.
        assert len(result["class_probabilities"]) == 21
        assert result["tokens"] == ["What", "a", "great", "test", "sentence", "."]
        assert isinstance(result["trees"], str)

        for class_distribution in result["class_probabilities"]:
            self.assertAlmostEqual(sum(class_distribution), 1.0, places=4)
Example #10
    def test_batch_prediction(self):
        batch_inputs = [
                {
                        "premise": "I always write unit tests for my code.",
                        "hypothesis": "One time I didn't write any unit tests for my code."
                },
                {
                        "premise": "I also write batched unit tests for throughput!",
                        "hypothesis": "Batch tests are slower."
                },
        ]

        archive = load_archive('tests/fixtures/decomposable_attention/serialization/model.tar.gz')
        predictor = Predictor.from_archive(archive, 'textual-entailment')
        results = predictor.predict_batch_json(batch_inputs)
        print(results)
        assert len(results) == 2

        for result in results:
            # Logits should be 3 floats that softmax to label_probs
            label_logits = result.get("label_logits")
            # Label probs should be 3 floats that sum to one
            label_probs = result.get("label_probs")
            assert label_probs is not None
            assert isinstance(label_probs, list)
            assert len(label_probs) == 3
            assert all(isinstance(x, float) for x in label_probs)
            assert all(x >= 0 for x in label_probs)
            assert sum(label_probs) == approx(1.0)

            assert label_logits is not None
            assert isinstance(label_logits, list)
            assert len(label_logits) == 3
            assert all(isinstance(x, float) for x in label_logits)

            exps = [math.exp(x) for x in label_logits]
            sumexps = sum(exps)
            for e, p in zip(exps, label_probs):
                assert e / sumexps == approx(p)
Example #11
def _run(predictor: Predictor,
         input_file: IO,
         output_file: Optional[IO],
         batch_size: int,
         print_to_console: bool,
         cuda_device: int) -> None:

    def _run_predictor(batch_data):
        if len(batch_data) == 1:
            result = predictor.predict_json(batch_data[0], cuda_device)
            # Batch results return a list of json objects, so in
            # order to iterate over the result below we wrap this in a list.
            results = [result]
        else:
            results = predictor.predict_batch_json(batch_data, cuda_device)

        for model_input, output in zip(batch_data, results):
            string_output = predictor.dump_line(output)
            if print_to_console:
                print("input: ", model_input)
                print("prediction: ", string_output)
            if output_file:
                output_file.write(string_output)

    batch_json_data = []
    for line in input_file:
        if not line.isspace():
            # Collect batch size amount of data.
            json_data = predictor.load_line(line)
            batch_json_data.append(json_data)
            if len(batch_json_data) == batch_size:
                _run_predictor(batch_json_data)
                batch_json_data = []

    # We might not have a dataset perfectly divisible by the batch size,
    # so tidy up the scraps.
    if batch_json_data:
        _run_predictor(batch_json_data)
Example #12
    def test_uses_named_inputs(self):
        inputs = {"document": "This is a single string document about a test. Sometimes it "
                              "contains coreferent parts."}
        archive = load_archive('tests/fixtures/coref/serialization/model.tar.gz')
        predictor = Predictor.from_archive(archive, 'coreference-resolution')
        result = predictor.predict_json(inputs)

        document = result["document"]
        assert document == ['This', 'is', 'a', 'single', 'string',
                            'document', 'about', 'a', 'test', '.', 'Sometimes',
                            'it', 'contains', 'coreferent', 'parts', '.']

        clusters = result["clusters"]
        assert isinstance(clusters, list)
        for cluster in clusters:
            assert isinstance(cluster, list)
            for mention in cluster:
                # Spans should be integer indices.
                assert isinstance(mention[0], int)
                assert isinstance(mention[1], int)
                # Spans should be inside document.
                assert 0 < mention[0] <= len(document)
                assert 0 < mention[1] <= len(document)
Example #13
    def test_batch_prediction(self):
        inputs = [
                {
                        "question": "What kind of test succeeded on its first attempt?",
                        "passage": "One time I was writing a unit test, and it succeeded on the first attempt."
                },
                {
                        "question": "What kind of test succeeded on its first attempt at batch processing?",
                        "passage": "One time I was writing a unit test, and it always failed!"
                }
        ]

        archive = load_archive('tests/fixtures/bidaf/serialization/model.tar.gz')
        predictor = Predictor.from_archive(archive, 'machine-comprehension')

        results = predictor.predict_batch_json(inputs)
        assert len(results) == 2

        for result in results:
            best_span = result.get("best_span")
            best_span_str = result.get("best_span_str")
            start_probs = result.get("span_start_probs")
            end_probs = result.get("span_end_probs")
            assert best_span is not None
            assert isinstance(best_span, list)
            assert len(best_span) == 2
            assert all(isinstance(x, int) for x in best_span)
            assert best_span[0] <= best_span[1]

            assert isinstance(best_span_str, str)
            assert best_span_str != ""

            for probs in (start_probs, end_probs):
                assert probs is not None
                assert all(isinstance(x, float) for x in probs)
                assert sum(probs) == approx(1.0)
Example #14
from allennlp.models.archival import load_archive
from allennlp.service.predictors import Predictor
import torch
#  pip install allennlp
#archive = load_archive("https://s3-us-west-2.amazonaws.com/allennlp/models/elmo-constituency-parser-2018.03.14.tar.gz", cuda_device=0)
archive = load_archive("elmo-constituency-parser-2018.03.14.tar.gz",
                       cuda_device=0)
predictor = Predictor.from_archive(archive, 'constituency-parser')

# Using readlines()
file1 = open('uniquesentences_unstripped.txt', 'r')
cands = file1.readlines()
to_predict = []
tags = []
batch_len = 50
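# Sentences are accumulated and parsed in batches of 50 to make better use of the GPU.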
start = 0
with torch.no_grad():
    #for index, cand in enumerate(cands):
    for index in range(start, len(cands)):
        cand = cands[index]
        curr_dict = {"sentence": cand.replace("_comma_", ", ")}
        to_predict.append(curr_dict)
        #z = predictor.predict_json()

        if len(to_predict) == batch_len:
            print(index)
            batch_answer = predictor.predict_batch_json(to_predict)
            to_predict = []
            for answer in batch_answer:
                tags.append(answer['trees'])
            torch.cuda.empty_cache()

    # Parse any leftover sentences that did not fill a complete batch.
    if to_predict:
        batch_answer = predictor.predict_batch_json(to_predict)
        for answer in batch_answer:
            tags.append(answer['trees'])

file1.close()
Example #15
from allennlp.models.archival import load_archive
from allennlp.service.predictors import Predictor
from field_classifier.classifier import Classifier
from field_classifier.predictor import ClassifierPredictor
from field_classifier.textcat import TextCatReader
import os
import json
import numpy as np

l0_archive = load_archive(
    os.path.abspath(os.path.join("data", "model_logs", "l0_model.tar.gz"))
)
l0_predictor = Predictor.from_archive(l0_archive, 'classifier')
l1_archive = load_archive(
    os.path.abspath(os.path.join("data", "model_logs", "l1_model.tar.gz"))
)
l1_predictor = Predictor.from_archive(l1_archive, 'classifier')
test_pubs = [{"title": "this is a test", "publication_id": 1}]
clf_output = []
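# Recover human-readable label names from each model's vocabulary.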
l0_label_map = l0_archive.model.vocab.get_index_to_token_vocabulary("labels")
l1_label_map = l1_archive.model.vocab.get_index_to_token_vocabulary("labels")
for test_pub in test_pubs:
    l0_prediction = l0_predictor.predict_json({"title": test_pub['title']})
    l1_prediction = l1_predictor.predict_json({"title": test_pub['title']})
    pred = {}
    pred['publication_id'] = test_pub['publication_id']
    l0_score = np.max(l0_prediction['label_probs'])
    l1_score = np.max(l1_prediction['label_probs'])
    l0_field = l0_label_map[np.argmax(l0_prediction['label_probs'])]
    l1_field = l1_label_map[np.argmax(l1_prediction['label_probs'])]
    if l1_score > 0.4:
Example #16
    def test_build_hierplane_tree(self):
        tree = Tree.fromstring("(S (NP (D the) (N dog)) (VP (V chased) (NP (D the) (N cat))))")
        archive = load_archive('tests/fixtures/constituency_parser/serialization/model.tar.gz')
        predictor = Predictor.from_archive(archive, 'constituency-parser')

        hierplane_tree = predictor._build_hierplane_tree(tree, 0, is_root=True)

        # pylint: disable=bad-continuation
        correct_tree = {
                'text': 'the dog chased the cat',
                "linkNameToLabel": LINK_TO_LABEL,
                "nodeTypeToStyle": NODE_TYPE_TO_STYLE,
                'root': {
                        'word': 'the dog chased the cat',
                        'nodeType': 'S',
                        'attributes': ['S'],
                        'link': 'S',
                        'children': [{
                                'word': 'the dog',
                                'nodeType': 'NP',
                                'attributes': ['NP'],
                                'link': 'NP',
                                'children': [{
                                        'word': 'the',
                                        'nodeType': 'D',
                                        'attributes': ['D'],
                                        'link': 'D'
                                        },
                                        {
                                        'word': 'dog',
                                        'nodeType': 'N',
                                        'attributes': ['N'],
                                        'link': 'N'}
                                        ]
                                },
                                {
                                'word': 'chased the cat',
                                'nodeType': 'VP',
                                'attributes': ['VP'],
                                'link': 'VP',
                                'children': [{
                                    'word': 'chased',
                                    'nodeType': 'V',
                                    'attributes': ['V'],
                                    'link': 'V'
                                    },
                                    {
                                    'word': 'the cat',
                                    'nodeType': 'NP',
                                    'attributes': ['NP'],
                                    'link': 'NP',
                                    'children': [{
                                            'word': 'the',
                                            'nodeType': 'D',
                                            'attributes': ['D'],
                                            'link': 'D'
                                            },
                                            {
                                            'word': 'cat',
                                            'nodeType': 'N',
                                            'attributes': ['N'],
                                            'link': 'N'}
                                        ]
                                    }
                                ]
                            }
                        ]
                    }
                }
        # pylint: enable=bad-continuation
        assert correct_tree == hierplane_tree
Example #17
    def setUp(self):
        super().setUp()

        archive = load_archive('tests/fixtures/bidaf/serialization/model.tar.gz')
        self.bidaf_predictor = Predictor.from_archive(archive, 'machine-comprehension')
Example #18
    def _caching_prediction(model: Predictor, data: str) -> JsonDict:
        """
        Just a wrapper around ``model.predict_json`` that allows us to use a cache decorator.
        """
        return model.predict_json(json.loads(data))
Example #19
import json


# Takes all sentences from the blind test jsonl, writes them to a file (blind_test.txt),
# then calls the pretrained model to get the named entities and stores them in a file in order.


f = open('NER_dev_txt','w')
from allennlp.service.predictors import Predictor
predictor = Predictor.from_path("https://s3-us-west-2.amazonaws.com/allennlp/models/ner-model-2018.04.26.tar.gz")
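# Tag each sentence with the pretrained NER model and write tab-separated word/tag pairs,
# with a blank line between sentences.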
for line in open('NER_shared_dev','r'):
	line = line.strip()
	results = predictor.predict(sentence=line)
	for word, tag in zip(results["words"], results["tags"]):
		f.write(word+'\t'+tag+'\n')
	f.write('\n')
Example #20
import json
from collections import defaultdict

from allennlp.common.util import JsonDict
from allennlp.common.testing import AllenNlpTestCase
from allennlp.models.archival import load_archive
from allennlp.service.predictors import Predictor
from allennlp.service.server_flask import make_app
from allennlp.service.db import InMemoryDemoDatabase

TEST_ARCHIVE_FILES = {
        'machine-comprehension': 'tests/fixtures/bidaf/serialization/model.tar.gz',
        'semantic-role-labeling': 'tests/fixtures/srl/serialization/model.tar.gz',
        'textual-entailment': 'tests/fixtures/decomposable_attention/serialization/model.tar.gz'
}

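# Build one predictor per fixture archive, keyed by predictor name.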
PREDICTORS = {
        name: Predictor.from_archive(load_archive(archive_file),
                                     predictor_name=name)
        for name, archive_file in TEST_ARCHIVE_FILES.items()
}


class CountingPredictor(Predictor):
    """
    bogus predictor that just returns a copy of its inputs
    and also counts how many times it was called with a given input
    """
    # pylint: disable=abstract-method
    def __init__(self):                 # pylint: disable=super-init-not-called
        self.calls = defaultdict(int)

    def predict_json(self, inputs: JsonDict, cuda_device: int = -1) -> JsonDict:
        key = json.dumps(inputs)
Example #21
import argparse
import json

from allennlp.service.predictors import Predictor
from allennlp.models import load_archive
from drqa import retriever

from doc.getDocuments import getDocsBatch, GoogleConfig

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('in_file', type=str)
    parser.add_argument('out_file', type=str)
    parser.add_argument('--config')
    parser.add_argument('--cuda_device', type=int, default=-1)

    args = parser.parse_args()

    with open(args.config) as f:
        config = json.load(f)

    ner_predictor = Predictor.from_path(
        "https://s3-us-west-2.amazonaws.com/allennlp/models/fine-grained-ner-model-elmo-2018.12.21.tar.gz"
    )

    google_config = GoogleConfig(**config['retrieval']['google'])
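    # DrQA TF-IDF document ranker, built from the prebuilt index path given in the config.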
    ranker = retriever.get_class('tfidf')(
        tfidf_path=config['retrieval']['tfidf']['index'])

    with open(args.out_file, 'w') as outfile:
        for docs in getDocsBatch(args.in_file, google_config, ner_predictor,
                                 ranker):
            print(json.dumps(docs), file=outfile)
Example #22
        'semantic-role-labeling': 'tests/fixtures/srl/model.tar.gz',
        'textual-entailment': 'tests/fixtures/decomposable_attention/model.tar.gz',
        'open-information-extraction': 'tests/fixtures/openie/model.tar.gz',
        'event2mind': 'tests/fixtures/event2mind/model.tar.gz'
}

PREDICTOR_NAMES = {
        'reading-comprehension': 'machine-comprehension',
        'semantic-role-labeling': 'semantic-role-labeling',
        'textual-entailment': 'textual-entailment',
        'open-information-extraction': 'open-information-extraction',
        'event2mind': 'event2mind'
}

PREDICTORS = {
        name: Predictor.from_archive(load_archive(archive_file),
                                     predictor_name=PREDICTOR_NAMES[name])
        for name, archive_file in TEST_ARCHIVE_FILES.items()
}

LIMITS = {
        'reading-comprehension': 311108,
        'semantic-role-labeling': 4590,
        'textual-entailment': 13129,
        'open-information-extraction': 19681,
        'event2mind': 11643
}


class CountingPredictor(Predictor):
    """
    bogus predictor that just returns a copy of its inputs
Example #23
def get_predictor(args: argparse.Namespace) -> Predictor:
    archive = load_archive(args.archive_file)
    predictor = Predictor.from_archive(archive)
    return predictor
Example #24
import json
from collections import defaultdict

from allennlp.common.util import JsonDict
from allennlp.common.testing import AllenNlpTestCase
from allennlp.models.archival import load_archive
from allennlp.service.predictors import Predictor
from allennlp.service.server_flask import make_app
from allennlp.service.db import InMemoryDemoDatabase

TEST_ARCHIVE_FILES = {
        'machine-comprehension': AllenNlpTestCase.FIXTURES_ROOT / 'bidaf'/ 'serialization'/ 'model.tar.gz',
        'semantic-role-labeling': AllenNlpTestCase.FIXTURES_ROOT / 'srl'/ 'serialization'/ 'model.tar.gz',
        'textual-entailment': (AllenNlpTestCase.FIXTURES_ROOT/ 'decomposable_attention'/
                               'serialization'/ 'model.tar.gz')
}

PREDICTORS = {
        name: Predictor.from_archive(load_archive(archive_file),
                                     predictor_name=name)
        for name, archive_file in TEST_ARCHIVE_FILES.items()
}


class CountingPredictor(Predictor):
    """
    bogus predictor that just returns a copy of its inputs
    and also counts how many times it was called with a given input
    """
    # pylint: disable=abstract-method
    def __init__(self):                 # pylint: disable=super-init-not-called
        self.calls = defaultdict(int)

    def predict_json(self, inputs: JsonDict) -> JsonDict:
        key = json.dumps(inputs)
Example #25
    def _caching_prediction(model: Predictor, data: str) -> JsonDict:
        """
        Just a wrapper around ``model.predict_json`` that allows us to use a cache decorator.
        """
        return model.predict_json(json.loads(data))
Example #26
    def test_build_hierplane_tree(self):
        tree = Tree.fromstring("(S (NP (D the) (N dog)) (VP (V chased) (NP (D the) (N cat))))")
        archive = load_archive('tests/fixtures/constituency_parser/serialization/model.tar.gz')
        predictor = Predictor.from_archive(archive, 'constituency-parser')

        hierplane_tree = predictor._build_hierplane_tree(tree, 0, is_root=True)

        # pylint: disable=bad-continuation
        correct_tree = {
                'text': 'the dog chased the cat',
                'root': {
                        'word': 'the dog chased the cat',
                        'nodeType': 'S',
                        'attributes': ['S'],
                        'link': 'S',
                        'children': [{
                                'word': 'the dog',
                                'nodeType': 'NP',
                                'attributes': ['NP'],
                                'link': 'NP',
                                'children': [{
                                        'word': 'the',
                                        'nodeType': 'D',
                                        'attributes': ['D'],
                                        'link': 'D'
                                        },
                                        {
                                        'word': 'dog',
                                        'nodeType': 'N',
                                        'attributes': ['N'],
                                        'link': 'N'}
                                        ]
                                },
                                {
                                'word': 'chased the cat',
                                'nodeType': 'VP',
                                'attributes': ['VP'],
                                'link': 'VP',
                                'children': [{
                                    'word': 'chased',
                                    'nodeType': 'V',
                                    'attributes': ['V'],
                                    'link': 'V'
                                    },
                                    {
                                    'word': 'the cat',
                                    'nodeType': 'NP',
                                    'attributes': ['NP'],
                                    'link': 'NP',
                                    'children': [{
                                            'word': 'the',
                                            'nodeType': 'D',
                                            'attributes': ['D'],
                                            'link': 'D'
                                            },
                                            {
                                            'word': 'cat',
                                            'nodeType': 'N',
                                            'attributes': ['N'],
                                            'link': 'N'}
                                        ]
                                    }
                                ]
                            }
                        ]
                    }
                }
        # pylint: enable=bad-continuation
        assert correct_tree == hierplane_tree
Example #27
import sys
import json
#from label import get_batch_trees
from allennlp.models.archival import load_archive
from allennlp.service.predictors import Predictor

archive = load_archive("srl-model-2018.05.25.tar.gz", cuda_device=0)
srl = Predictor.from_archive(archive)
batch_size = 30
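# SRL predictions are requested in batches of 30 sentences at a time.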
print ("Loading Done")

def get_srl(sentences, one_file, output):
	one_file["dlen"] = len(one_file["document"])
	sentences = [{"sentence": line} for line in sentences]
	srl_res = []; start_idx = 0
	while start_idx < len(sentences):
		batch_sentences = sentences[start_idx: min(start_idx + batch_size, len(sentences))]
		srl_res.extend(srl.predict_batch_json(batch_sentences))
		start_idx += batch_size
	if len(srl_res) > 1:
		one_file["srl_summary"] = srl_res[0]
		one_file["srl_document"] = srl_res[1:]
		output.write(json.dumps(one_file) + "\n")

if __name__ == '__main__':
	one_file = {}; one_file_sentences = []
	#input_filename = "./XSum.txt." + sys.argv[1]
	#xsum_trees_file = "./XSum.srl." + sys.argv[1]; existed_files = []
	input_filename = sys.argv[1]
	xsum_trees_file = sys.argv[1].split(".")[0] + ".srl"; existed_files = []