示例#1
0
 def test_truncation(self, lazy, use_contextualizer, include_raw_tokens,
                     max_instances):
     """Assert the number of instances read matches ``max_instances``.

     Reads the ``ptb.conllu`` fixture through a
     ``SyntacticDependencyArcPredictionDatasetReader`` configured with
     balanced negative sampling, then compares ``len(instances)`` with the
     entry of a lookup table keyed by ``max_instances``.
     """
     fixtures_root = CustomTestCase.FIXTURES_ROOT
     data_path = fixtures_root / "data" / "syntactic_dependency" / "ptb.conllu"
     contextualizer_path = (fixtures_root / "contextualizers" /
                            "precomputed_elmo" / "elmo_layers_all.hdf5")
     # A contextualizer is only constructed when the test case asks for one.
     contextualizer = (PrecomputedContextualizer(contextualizer_path)
                       if use_contextualizer else None)
     reader = SyntacticDependencyArcPredictionDatasetReader(
         negative_sampling_method="balanced",
         contextualizer=contextualizer,
         include_raw_tokens=include_raw_tokens,
         max_instances=max_instances,
         lazy=lazy)
     instances = list(reader.read(str(data_path)))

     num_total_instances = 2
     # NOTE(review): in Python ``1 == 1.0`` and both hash alike, so the
     # ``1`` and ``1.0`` entries below are a single dict key and the later
     # value wins; looking up ``1`` therefore yields ``num_total_instances``.
     # Confirm this is the intended expectation for ``max_instances=1``.
     expected_counts = {
         1: 1,
         2: 2,
         0.5: int(num_total_instances * 0.5),
         0.75: int(num_total_instances * 0.75),
         1.0: num_total_instances,
     }
     assert len(instances) == expected_counts[max_instances]
    def test_precomputed_contextualizer_all_elmo_layers_second_half(self):
        """Check ``second_half_only=True`` against slicing the full output.

        For each of layers 0, 1 and 2, two ``PrecomputedContextualizer``
        objects are built over ``elmo_layers_all.hdf5``: one plain and one
        with ``second_half_only=True``. Both are applied to the same three
        sentences, and every second-half representation must be close to
        columns ``512:`` of the corresponding full representation.
        """
        all_elmo_layers_path = self.model_paths / "elmo_layers_all.hdf5"
        num_sentences = 3

        # Exercise every layer index. The loop variable must be forwarded to
        # the contextualizers; passing a constant 0 would test only layer 0
        # three times over.
        for layer_num in [0, 1, 2]:
            all_elmo = PrecomputedContextualizer(all_elmo_layers_path,
                                                 layer_num=layer_num)
            second_half_elmo = PrecomputedContextualizer(all_elmo_layers_path,
                                                         layer_num=layer_num,
                                                         second_half_only=True)
            second_half_representations = second_half_elmo(
                [self.sentence_1, self.sentence_2, self.sentence_3])
            representations = all_elmo(
                [self.sentence_1, self.sentence_2, self.sentence_3])
            assert len(second_half_representations) == num_sentences
            assert len(representations) == num_sentences
            for second_half_repr, full_repr in zip(second_half_representations,
                                                   representations):
                assert_allclose(second_half_repr.cpu().numpy(),
                                full_repr[:, 512:].cpu().numpy(),
                                rtol=1e-5)
 def test_max_length(self, lazy, use_contextualizer, backward, max_length):
     """Assert no instance carries more than ``max_length`` raw tokens.

     Reads ``self.data_path`` with a ``LanguageModelingDatasetReader`` and
     checks the length of each instance's ``raw_tokens`` field list.
     """
     # A contextualizer is only constructed when the test case asks for one.
     contextualizer = (PrecomputedContextualizer(self.contextualizer_path)
                       if use_contextualizer else None)
     reader = LanguageModelingDatasetReader(max_length=max_length,
                                            backward=backward,
                                            contextualizer=contextualizer,
                                            lazy=lazy)
     # Materialize everything first so reading finishes before any check.
     all_instances = list(reader.read(str(self.data_path)))
     for instance in all_instances:
         raw_tokens = [
             token.metadata
             for token in instance.fields["raw_tokens"].field_list
         ]
         assert len(raw_tokens) <= max_length
示例#4
0
 def test_truncation(self, lazy, use_contextualizer, max_instances):
     """Assert the number of instances read matches ``max_instances``.

     Reads ``self.data_path`` through a
     ``CoreferenceArcPredictionDatasetReader`` and compares
     ``len(instances)`` with a lookup table keyed by ``max_instances``.
     """
     # A contextualizer is only constructed when the test case asks for one.
     contextualizer = (PrecomputedContextualizer(self.contextualizer_path)
                       if use_contextualizer else None)
     reader = CoreferenceArcPredictionDatasetReader(
         contextualizer=contextualizer,
         max_instances=max_instances,
         lazy=lazy)
     instances = list(reader.read(str(self.data_path)))

     num_total_instances = 1
     # ``1`` and ``1.0`` are the same dict key in Python; both entries map
     # to 1 here, so the overlap does not change any lookup result.
     expected_counts = {
         1: 1,
         2: 1,
         1.0: num_total_instances,
     }
     assert len(instances) == expected_counts[max_instances]
 def test_truncation(self, lazy, use_contextualizer, max_instances):
     """Assert the number of instances read matches ``max_instances``.

     Reads ``self.data_path`` through a ``Conll2003NERDatasetReader`` and
     compares ``len(instances)`` with a lookup table keyed by
     ``max_instances``.
     """
     # A contextualizer is only constructed when the test case asks for one.
     contextualizer = (PrecomputedContextualizer(self.contextualizer_path)
                       if use_contextualizer else None)
     reader = Conll2003NERDatasetReader(
         contextualizer=contextualizer,
         max_instances=max_instances,
         lazy=lazy)
     instances = list(reader.read(str(self.data_path)))

     num_total_instances = 3
     # NOTE(review): in Python ``1 == 1.0`` and both hash alike, so the
     # ``1`` and ``1.0`` entries below are a single dict key and the later
     # value wins; looking up ``1`` therefore yields ``num_total_instances``.
     # Confirm this is the intended expectation for ``max_instances=1``.
     expected_counts = {
         1: 1,
         2: 2,
         0.5: int(num_total_instances * 0.5),
         0.75: int(num_total_instances * 0.75),
         1.0: num_total_instances,
     }
     assert len(instances) == expected_counts[max_instances]
示例#6
0
 def test_truncation(self, lazy, use_contextualizer, directed,
                     include_raw_tokens, max_instances):
     """Assert the number of instances read matches ``max_instances``.

     Reads ``self.data_path`` through a
     ``SemanticDependencyArcClassificationDatasetReader`` and compares
     ``len(instances)`` with a lookup table keyed by ``max_instances``.
     """
     # A contextualizer is only constructed when the test case asks for one.
     contextualizer = (PrecomputedContextualizer(self.contextualizer_path)
                       if use_contextualizer else None)
     reader = SemanticDependencyArcClassificationDatasetReader(
         contextualizer=contextualizer,
         directed=directed,
         include_raw_tokens=include_raw_tokens,
         max_instances=max_instances,
         lazy=lazy)
     instances = list(reader.read(str(self.data_path)))

     num_total_instances = 2
     # NOTE(review): in Python ``1 == 1.0`` and both hash alike, so the
     # ``1`` and ``1.0`` entries below are a single dict key and the later
     # value wins; looking up ``1`` therefore yields ``num_total_instances``.
     # Confirm this is the intended expectation for ``max_instances=1``.
     expected_counts = {
         1: 1,
         2: 2,
         0.5: int(num_total_instances * 0.5),
         0.75: int(num_total_instances * 0.75),
         1.0: num_total_instances,
     }
     assert len(instances) == expected_counts[max_instances]
示例#7
0
 def test_truncation(self, lazy, use_contextualizer, include_raw_tokens,
                     mode, max_instances):
     """Assert the number of instances read matches ``max_instances``.

     Reads ``self.data_path`` through an
     ``AdpositionSupersenseTaggingDatasetReader`` and compares
     ``len(instances)`` with a lookup table keyed by ``max_instances``.
     """
     # A contextualizer is only constructed when the test case asks for one.
     contextualizer = (PrecomputedContextualizer(self.contextualizer_path)
                       if use_contextualizer else None)
     reader = AdpositionSupersenseTaggingDatasetReader(
         mode=mode,
         include_raw_tokens=include_raw_tokens,
         contextualizer=contextualizer,
         max_instances=max_instances,
         lazy=lazy)
     instances = list(reader.read(str(self.data_path)))

     num_total_instances = 3
     # NOTE(review): in Python ``1 == 1.0`` and both hash alike, so the
     # ``1`` and ``1.0`` entries below are a single dict key and the later
     # value wins; looking up ``1`` therefore yields ``num_total_instances``.
     # Confirm this is the intended expectation for ``max_instances=1``.
     expected_counts = {
         1: 1,
         2: 2,
         0.5: int(num_total_instances * 0.5),
         0.75: int(num_total_instances * 0.75),
         1.0: num_total_instances,
     }
     assert len(instances) == expected_counts[max_instances]
    def test_read_from_file(self, lazy, use_contextualizer):
        """Read ``self.data_path`` and check the first two instances.

        For each instance, the raw tokens and the labels are compared with
        hard-coded expectations. When ``use_contextualizer`` is true, the
        first two columns of ``token_representations`` are compared with
        hard-coded values as well.
        """
        # A contextualizer is only constructed when the test case asks for one.
        contextualizer = (PrecomputedContextualizer(self.contextualizer_path)
                          if use_contextualizer else None)
        reader = Conll2000ChunkingDatasetReader(contextualizer=contextualizer,
                                                lazy=lazy)
        instances = list(reader.read(str(self.data_path)))

        # ---- Instance 0 ----
        fields = instances[0].fields
        expected_tokens = [
            'Confidence', 'in', 'the', 'pound', 'is', 'widely', 'expected', 'to', 'take', 'another',
            'sharp', 'dive', 'if', 'trade', 'figures', 'for', 'September', ',', 'due', 'for', 'release',
            'tomorrow', ',', 'fail', 'to', 'show', 'a', 'substantial', 'improvement', 'from', 'July', 'and',
            'August', "'s", 'near-record', 'deficits', '.']
        expected_labels = [
            'B-NP', 'B-PP', 'B-NP', 'I-NP', 'B-VP', 'I-VP', 'I-VP', 'I-VP', 'I-VP', 'B-NP', 'I-NP',
            'I-NP', 'B-SBAR', 'B-NP', 'I-NP', 'B-PP', 'B-NP', 'O', 'B-ADJP', 'B-PP', 'B-NP', 'B-NP', 'O',
            'B-VP', 'I-VP', 'I-VP', 'B-NP', 'I-NP', 'I-NP', 'B-PP', 'B-NP', 'I-NP', 'I-NP', 'B-NP', 'I-NP',
            'I-NP', 'O']
        actual_tokens = [token.metadata for token in fields["raw_tokens"].field_list]
        assert actual_tokens == expected_tokens
        assert fields["labels"].labels == expected_labels
        if use_contextualizer:
            actual_prefix = fields["token_representations"].array[:, :2]
            expected_prefix = np.array(
                [[-1.258513, -1.0370141], [-0.05817706, -0.34756088],
                 [-0.06353955, -0.4938563], [-1.8024218, -1.0316596],
                 [-2.257492, -0.5222637], [-2.4755964, -0.24860916],
                 [-1.4937682, -1.3631285], [-1.5874765, 0.58332765],
                 [-0.6599875, -0.34025198], [-2.0129712, -1.7125161],
                 [-2.0061035, -2.0411587], [-2.111752, -0.17662084],
                 [-1.036485, -0.95351875], [-1.1027372, -0.8811481],
                 [-3.2971778, -0.80117923], [0.14612085, 0.2907345],
                 [-1.0681806, -0.11506036], [-0.89108264, -0.75120807],
                 [1.4598572, -1.5135024], [-0.19162387, -0.5925277],
                 [0.3152356, -0.67221195], [0.0429894, -0.3435017],
                 [-2.107685, 0.02174884], [-0.6821988, -1.6696682],
                 [-1.8384202, -0.22075021], [-1.033319, -1.1420834],
                 [-0.6265656, -0.8096429], [-1.0296414, -0.834536],
                 [-0.9962367, -0.09962708], [0.16024095, 0.43128008],
                 [-0.28929204, -1.4249148], [0.00278845, 0.6611263],
                 [0.50334555, -0.35937083], [1.147023, -0.6687972],
                 [0.77036375, -0.23009405], [-1.0500407, -0.02021815],
                 [-1.3865266, -0.85197794]])
            assert_allclose(actual_prefix, expected_prefix, rtol=1e-4)

        # ---- Instance 1 ----
        fields = instances[1].fields
        expected_tokens = [
            'Chancellor', 'of', 'the', 'Exchequer', 'Nigel', 'Lawson', "'s", 'restated', 'commitment',
            'to', 'a', 'firm', 'monetary', 'policy', 'has', 'helped', 'to', 'prevent', 'a', 'freefall',
            'in', 'sterling', 'over', 'the', 'past', 'week', '.']
        expected_labels = [
            'O', 'B-PP', 'B-NP', 'I-NP', 'B-NP', 'I-NP', 'B-NP', 'I-NP', 'I-NP', 'B-PP', 'B-NP', 'I-NP',
            'I-NP', 'I-NP', 'B-VP', 'I-VP', 'I-VP', 'I-VP', 'B-NP', 'I-NP', 'B-PP', 'B-NP', 'B-PP', 'B-NP',
            'I-NP', 'I-NP', 'O']
        actual_tokens = [token.metadata for token in fields["raw_tokens"].field_list]
        assert actual_tokens == expected_tokens
        assert fields["labels"].labels == expected_labels
        if use_contextualizer:
            actual_prefix = fields["token_representations"].array[:, :2]
            expected_prefix = np.array(
                [[-2.4198189, -1.606727], [-1.2829566, -0.8627869],
                 [0.44851404, -0.8752346], [-1.2563871, -0.9349538],
                 [-2.3628764, 0.61601055], [-2.5294414, -0.8608694],
                 [-1.0940088, 0.36207741], [-1.3594072, -0.44920856],
                 [-2.1531758, -0.72061414], [-0.8710089, -0.01074989],
                 [1.1241767, 0.27293408], [-0.20302701, -0.3308825],
                 [-1.577058, -0.9223033], [-3.2015433, -1.4600563],
                 [-1.8444527, -0.3150784], [-1.4566939, -0.18184504],
                 [-2.097283, 0.02337693], [-1.4785317, 0.2928276],
                 [-0.47859374, -0.46162963], [-1.4853759, 0.30421454],
                 [0.25645372, -0.12769623], [-1.311865, -1.1461734],
                 [-0.75683033, -0.37533844], [-0.13498223, 1.1350582],
                 [0.3819366, 0.2941534], [-1.2304902, -0.67328024],
                 [-1.2757114, -0.43673947]])
            assert_allclose(actual_prefix, expected_prefix, rtol=1e-4)
示例#9
0
    def test_read_from_file(self, lazy, use_contextualizer, include_raw_tokens):
        """Read ``self.data_path`` and check the first two instances.

        For each instance, the arc indices and labels are compared with
        hard-coded expectations. Raw tokens are checked only when
        ``include_raw_tokens`` is true; the row count and first two columns
        of ``token_representations`` are checked only when
        ``use_contextualizer`` is true.
        """
        # A contextualizer is only constructed when the test case asks for one.
        contextualizer = (PrecomputedContextualizer(self.contextualizer_path)
                          if use_contextualizer else None)
        reader = SemanticDependencyArcClassificationDatasetReader(
            contextualizer=contextualizer,
            include_raw_tokens=include_raw_tokens,
            lazy=lazy)
        instances = list(reader.read(str(self.data_path)))

        # ---- Instance 0 ----
        first_sentence = ['Pierre', 'Vinken', ',', '61', 'years', 'old',
                          ',', 'will', 'join', 'the', 'board', 'as',
                          'a', 'nonexecutive', 'director', 'Nov.', '29', '.']
        fields = instances[0].fields
        if include_raw_tokens:
            actual_tokens = [token.metadata for token in fields["raw_tokens"].field_list]
            assert actual_tokens == first_sentence
        expected_arcs = np.array([
            [1, 0], [1, 5], [1, 8], [4, 3],
            [5, 4], [8, 11], [8, 16], [10, 8],
            [10, 9], [14, 11], [14, 12], [14, 13],
            [16, 15]])
        assert_allclose(fields["arc_indices"].array, expected_arcs)
        expected_labels = [
            'compound', 'ARG1', 'ARG1', 'ARG1', 'measure', 'ARG1', 'loc',
            'ARG2', 'BV', 'ARG2', 'BV', 'ARG1', 'of']
        assert fields["labels"].labels == expected_labels
        if use_contextualizer:
            representations = fields["token_representations"].array
            # One representation row per token of the sentence.
            assert representations.shape[0] == len(first_sentence)
            expected_prefix = np.array([[-2.2218986, -2.451917],
                                        [-0.7551352, -0.7919447],
                                        [-0.9525466, -0.985806],
                                        [-1.4567664, 0.1637534],
                                        [0.21862003, -0.00878072],
                                        [-0.7341557, -0.57776076],
                                        [-1.6816409, -1.2562131],
                                        [-0.9079286, 0.15607932],
                                        [-0.44011104, -0.3434037],
                                        [0.56341827, -0.97181696],
                                        [-0.7166206, -0.33435553],
                                        [-0.14051008, -1.260754],
                                        [0.42426592, -0.35762805],
                                        [-1.0153385, -0.7949409],
                                        [-0.7065723, 0.05164766],
                                        [-0.11002721, -0.11039695],
                                        [0.41112524, 0.27260625],
                                        [-1.0369725, -0.6278316]])
            assert_allclose(
                fields["token_representations"].array[:, :2],
                expected_prefix,
                rtol=1e-4)

        # ---- Instance 1 ----
        second_sentence = ['Mr.', 'Vinken', 'is', 'chairman', 'of',
                           'Elsevier', 'N.V.', ',', 'the', 'Dutch',
                           'publishing', 'group', '.']
        fields = instances[1].fields
        if include_raw_tokens:
            actual_tokens = [token.metadata for token in fields["raw_tokens"].field_list]
            assert actual_tokens == second_sentence
        expected_arcs = np.array([
            [1, 0], [1, 2], [3, 2], [3, 4], [5, 4], [5, 6], [5, 11], [11, 8],
            [11, 9], [11, 10]])
        assert_allclose(fields["arc_indices"].array, expected_arcs)
        expected_labels = [
            'compound', 'ARG1', 'ARG2', 'ARG1', 'ARG2', 'compound',
            'appos', 'BV', 'ARG1', 'compound']
        assert fields["labels"].labels == expected_labels
        if use_contextualizer:
            representations = fields["token_representations"].array
            # One representation row per token of the sentence.
            assert representations.shape[0] == len(second_sentence)
            expected_prefix = np.array([[0.7069745, -0.5422047],
                                        [-1.8885247, -1.4432149],
                                        [-1.7570897, -1.1201282],
                                        [-1.2288755, -0.8003752],
                                        [-0.08672556, -0.99020493],
                                        [-0.6416313, -1.147429],
                                        [-0.7924329, 0.14809224],
                                        [-1.0645872, -1.0505759],
                                        [0.69725895, -0.8735154],
                                        [0.27878952, -0.339666],
                                        [0.20708983, -0.7103262],
                                        [-1.1115363, -0.16295972],
                                        [-1.3495405, -0.8656957]])
            assert_allclose(
                fields["token_representations"].array[:, :2],
                expected_prefix,
                rtol=1e-4)
示例#10
0
    def test_read_from_file(self, lazy, use_contextualizer, ancestor):
        """Read ``self.data_path`` and check the first two instances.

        The raw tokens of each instance are compared with hard-coded
        sentences, and the labels with a list selected by ``ancestor``
        (``"parent"``, ``"grandparent"``, or anything else, which is treated
        as the great-grandparent case). When ``use_contextualizer`` is true,
        the first two columns of ``token_representations`` are compared with
        hard-coded values.
        """
        # A contextualizer is only constructed when the test case asks for one.
        contextualizer = (PrecomputedContextualizer(self.contextualizer_path)
                          if use_contextualizer else None)
        reader = ConstituencyAncestorPredictionDatasetReader(
            ancestor=ancestor, contextualizer=contextualizer, lazy=lazy)
        instances = list(reader.read(str(self.data_path)))

        # ---- Instance 0 ----
        fields = instances[0].fields
        expected_tokens = [
            'In', 'an', 'Oct.', '19', 'review', 'of', '``', 'The',
            'Misanthrope', "''", 'at', 'Chicago', "'s", 'Goodman',
            'Theatre', '(', '``', 'Revitalized', 'Classics', 'Take',
            'the', 'Stage', 'in', 'Windy', 'City', ',', "''",
            'Leisure', '&', 'Arts', ')', ',', 'the', 'role', 'of',
            'Celimene', ',', 'played', 'by', 'Kim', 'Cattrall', ',',
            'was', 'mistakenly', 'attributed', 'to', 'Christina',
            'Haag', '.'
        ]
        assert [token.metadata
                for token in fields["raw_tokens"].field_list] == expected_tokens
        # There must be exactly one label per raw token.
        num_raw_tokens = len(
            [token.metadata for token in fields["raw_tokens"].field_list])
        assert num_raw_tokens == len(fields["labels"].labels)
        if ancestor == "parent":
            expected_labels = [
                'PP', 'NP', 'NP', 'NP', 'NP', 'PP', 'NP', 'NP', 'NP', 'NP',
                'PP', 'NP', 'NP', 'NP', 'NP', 'PRN', 'PRN', 'NP', 'NP', 'VP',
                'NP', 'NP', 'PP', 'NP', 'NP', 'PRN', 'PRN', 'NP', 'NP', 'NP',
                'PRN', 'S', 'NP', 'NP', 'PP', 'NP', 'NP', 'VP', 'PP', 'NP',
                'NP', 'NP', 'VP', 'ADVP', 'VP', 'PP', 'NP', 'NP', 'S'
            ]
        elif ancestor == "grandparent":
            expected_labels = [
                'S', 'NP', 'NP', 'NP', 'NP', 'NP', 'PP', 'NP', 'NP', 'PP',
                'NP', 'NP', 'NP', 'PP', 'PP', 'NP', 'NP', 'S', 'S', 'S', 'VP',
                'VP', 'VP', 'PP', 'PP', 'NP', 'NP', 'PRN', 'PRN', 'PRN', 'NP',
                'None', 'NP', 'NP', 'NP', 'PP', 'S', 'NP', 'VP', 'PP', 'PP',
                'S', 'S', 'VP', 'VP', 'VP', 'PP', 'PP', 'None'
            ]
        else:
            # Any other value is handled as the great-grandparent case.
            expected_labels = [
                'None', 'PP', 'PP', 'PP', 'PP', 'PP', 'NP', 'PP', 'PP', 'NP',
                'PP', 'PP', "PP", 'NP', 'NP', 'PP', 'PP', 'PRN', 'PRN', 'PRN',
                'S', 'S', 'S', 'VP', 'VP', 'PP', "PP", 'NP', 'NP', 'NP', 'PP',
                'None', 'NP', 'NP', 'NP', 'NP', 'None', 'S', 'NP', 'VP', 'VP',
                'None', 'None', 'VP', 'S', 'VP', 'VP', 'VP', 'None'
            ]
        assert fields["labels"].labels == expected_labels
        if use_contextualizer:
            actual_prefix = fields["token_representations"].array[:, :2]
            expected_prefix = np.array([[0.7541596, 0.36606207],
                                        [-0.3912218, 0.2728929],
                                        [0.4532569, 0.59446496],
                                        [-0.034773, 0.6178972],
                                        [0.05996126, -0.21075758],
                                        [-0.00675234, -0.19188942],
                                        [-0.25371405, -0.98044276],
                                        [0.55180097, -1.3375797],
                                        [-0.76439965, -0.8849516],
                                        [-0.1852389, -0.76670283],
                                        [-0.6538293, -2.109323],
                                        [0.11706313, -0.14159685],
                                        [-0.26565668, 0.08206904],
                                        [-1.0511935, -0.28469092],
                                        [0.22915375, 0.2485466],
                                        [1.4214072, 0.02810444],
                                        [0.7648947, -1.3637407],
                                        [-0.01231889, -0.02892348],
                                        [-0.1330762, 0.0219465],
                                        [0.8961761, -1.2976432],
                                        [0.83349395, -1.8242016],
                                        [0.15122458, -0.9597366],
                                        [0.7570322, -0.73728824],
                                        [-0.04838032, -0.8663991],
                                        [0.32632858, -0.5200325],
                                        [0.7823914, -1.020006],
                                        [0.5874542, -1.020459],
                                        [-0.4918128, -0.85094],
                                        [-0.24947, -0.20599724],
                                        [-1.4349735, 0.19630724],
                                        [-0.49690107, -0.58586204],
                                        [0.06130999, -0.14850587],
                                        [0.66610545, -0.06235093],
                                        [-0.29052478, 0.40215907],
                                        [0.24728307, 0.23677489],
                                        [-0.05339833, 0.22958362],
                                        [-0.44152835, -0.58153844],
                                        [0.4723678, -0.06656095],
                                        [0.32210657, -0.03144099],
                                        [0.6663985, 0.39230958],
                                        [0.57831913, 0.19480982],
                                        [-0.96823174, 0.00828598],
                                        [-0.7640736, 0.00441009],
                                        [-0.5589211, 0.17509514],
                                        [0.01523143, -0.7975017],
                                        [0.3268571, -0.1870772],
                                        [1.4704096, 0.8472788],
                                        [0.23348817, -0.48313117],
                                        [-0.57006484, -0.77375746]])
            assert_allclose(actual_prefix, expected_prefix, rtol=1e-3)

        # ---- Instance 1 ----
        fields = instances[1].fields
        expected_tokens = ['Ms.', 'Haag', 'plays', 'Elianti', '.']
        assert [token.metadata
                for token in fields["raw_tokens"].field_list] == expected_tokens
        if ancestor == "parent":
            expected_labels = ['NP', 'NP', 'VP', 'NP', 'S']
        elif ancestor == "grandparent":
            expected_labels = ['S', 'S', 'S', 'VP', 'None']
        else:
            # Any other value is handled as the great-grandparent case.
            expected_labels = ['None', 'None', 'None', 'S', 'None']
        assert fields["labels"].labels == expected_labels
        if use_contextualizer:
            actual_prefix = fields["token_representations"].array[:, :2]
            expected_prefix = np.array([[0.6757653, -0.80925614],
                                        [-1.9424553, -1.0854281],
                                        [-0.09960067, 0.17525218],
                                        [0.09222834, -0.8534998],
                                        [-0.66507375, -0.5633631]])
            assert_allclose(actual_prefix, expected_prefix, rtol=1e-3)
示例#11
0
    def test_read_from_file(self, lazy, use_contextualizer,
                            include_raw_tokens):
        """Read ``self.data_path`` and check both instances it yields.

        Exactly two instances are expected. For each one, the arc indices
        and labels are compared with hard-coded expectations. Raw tokens are
        checked only when ``include_raw_tokens`` is true; the row count and
        first two columns of ``token_representations`` are checked only when
        ``use_contextualizer`` is true.
        """
        # A contextualizer is only constructed when the test case asks for one.
        contextualizer = (PrecomputedContextualizer(self.contextualizer_path)
                          if use_contextualizer else None)
        reader = SyntacticDependencyArcClassificationDatasetReader(
            contextualizer=contextualizer,
            include_raw_tokens=include_raw_tokens,
            lazy=lazy)
        instances = list(reader.read(str(self.data_path)))
        assert len(instances) == 2

        # ---- Instance 0 ----
        first_sentence = [
            'In', 'an', 'Oct.', '19', 'review', 'of', '``', 'The',
            'Misanthrope', "''", 'at', 'Chicago', "'s", 'Goodman', 'Theatre',
            '(', '``', 'Revitalized', 'Classics', 'Take', 'the', 'Stage', 'in',
            'Windy', 'City', ',', "''", 'Leisure', '&', 'Arts', ')', ',',
            'the', 'role', 'of', 'Celimene', ',', 'played', 'by', 'Kim',
            'Cattrall', ',', 'was', 'mistakenly', 'attributed', 'to',
            'Christina', 'Haag', '.'
        ]
        fields = instances[0].fields
        if include_raw_tokens:
            actual_tokens = [
                token.metadata for token in fields["raw_tokens"].field_list
            ]
            assert actual_tokens == first_sentence
        expected_arcs = np.array([
            (0, 4), (1, 4), (2, 4), (3, 4), (4, 44), (5, 8), (6, 8),
            (7, 8), (8, 4), (9, 8), (10, 14), (11, 14), (12, 11), (13, 14),
            (14, 8), (15, 19), (16, 19), (17, 18), (18, 19), (19, 4),
            (20, 21), (21, 19), (22, 24), (23, 24), (24, 19), (25, 19),
            (26, 19), (27, 19), (28, 27), (29, 27), (30, 19), (31, 44),
            (32, 33), (33, 44), (34, 35), (35, 33), (36, 33), (37, 33),
            (38, 40), (39, 40), (40, 37), (41, 33), (42, 44), (43, 44),
            (45, 47), (46, 47), (47, 44), (48, 44)
        ])
        assert_allclose(fields["arc_indices"].array, expected_arcs)
        expected_labels = [
            'case', 'det', 'compound', 'nummod', 'nmod', 'case', 'punct',
            'det', 'nmod', 'punct', 'case', 'nmod:poss', 'case', 'compound',
            'nmod', 'punct', 'punct', 'amod', 'nsubj', 'dep', 'det', 'dobj',
            'case', 'compound', 'nmod', 'punct', 'punct', 'dep', 'cc', 'conj',
            'punct', 'punct', 'det', 'nsubjpass', 'case', 'nmod', 'punct',
            'acl', 'case', 'compound', 'nmod', 'punct', 'auxpass', 'advmod',
            'case', 'compound', 'nmod', 'punct'
        ]
        assert fields["labels"].labels == expected_labels
        if use_contextualizer:
            # One representation row per token of the sentence.
            assert fields["token_representations"].array.shape[0] == len(
                first_sentence)
            actual_prefix = fields["token_representations"].array[:, :2]
            expected_prefix = np.array([[0.7541596, 0.36606207],
                                        [-0.3912218, 0.2728929],
                                        [0.4532569, 0.59446496],
                                        [-0.034773, 0.6178972],
                                        [0.05996126, -0.21075758],
                                        [-0.00675234, -0.19188942],
                                        [-0.25371405, -0.98044276],
                                        [0.55180097, -1.3375797],
                                        [-0.76439965, -0.8849516],
                                        [-0.1852389, -0.76670283],
                                        [-0.6538293, -2.109323],
                                        [0.11706313, -0.14159685],
                                        [-0.26565668, 0.08206904],
                                        [-1.0511935, -0.28469092],
                                        [0.22915375, 0.2485466],
                                        [1.4214072, 0.02810444],
                                        [0.7648947, -1.3637407],
                                        [-0.01231889, -0.02892348],
                                        [-0.1330762, 0.0219465],
                                        [0.8961761, -1.2976432],
                                        [0.83349395, -1.8242016],
                                        [0.15122458, -0.9597366],
                                        [0.7570322, -0.73728824],
                                        [-0.04838032, -0.8663991],
                                        [0.32632858, -0.5200325],
                                        [0.7823914, -1.020006],
                                        [0.5874542, -1.020459],
                                        [-0.4918128, -0.85094],
                                        [-0.24947, -0.20599724],
                                        [-1.4349735, 0.19630724],
                                        [-0.49690107, -0.58586204],
                                        [0.06130999, -0.14850587],
                                        [0.66610545, -0.06235093],
                                        [-0.29052478, 0.40215907],
                                        [0.24728307, 0.23677489],
                                        [-0.05339833, 0.22958362],
                                        [-0.44152835, -0.58153844],
                                        [0.4723678, -0.06656095],
                                        [0.32210657, -0.03144099],
                                        [0.6663985, 0.39230958],
                                        [0.57831913, 0.19480982],
                                        [-0.96823174, 0.00828598],
                                        [-0.7640736, 0.00441009],
                                        [-0.5589211, 0.17509514],
                                        [0.01523143, -0.7975017],
                                        [0.3268571, -0.1870772],
                                        [1.4704096, 0.8472788],
                                        [0.23348817, -0.48313117],
                                        [-0.57006484, -0.77375746]])
            assert_allclose(actual_prefix, expected_prefix, rtol=1e-3)

        # ---- Instance 1 ----
        fields = instances[1].fields
        second_sentence = ['Ms.', 'Haag', 'plays', 'Elianti', '.']
        if include_raw_tokens:
            actual_tokens = [
                token.metadata for token in fields["raw_tokens"].field_list
            ]
            assert actual_tokens == second_sentence
        expected_arcs = np.array([[0, 1], [1, 2], [3, 2], [4, 2]])
        assert_allclose(fields["arc_indices"].array, expected_arcs)
        expected_labels = ['compound', 'nsubj', 'dobj', 'punct']
        assert fields["labels"].labels == expected_labels
        if use_contextualizer:
            # One representation row per token of the sentence.
            assert fields["token_representations"].array.shape[0] == len(
                second_sentence)
            actual_prefix = fields["token_representations"].array[:, :2]
            expected_prefix = np.array([[0.6757653, -0.80925614],
                                        [-1.9424553, -1.0854281],
                                        [-0.09960067, 0.17525218],
                                        [0.09222834, -0.8534998],
                                        [-0.66507375, -0.5633631]])
            assert_allclose(actual_prefix, expected_prefix, rtol=1e-3)
示例#12
0
    def test_read_from_file(self, lazy, use_contextualizer, include_raw_tokens, mode):
        """Read the adposition supersense data and verify all three instances.

        For each instance this checks the raw tokens (only when
        ``include_raw_tokens`` is set), the label indices, the first two
        columns of the token representations (only when ``use_contextualizer``
        is set), and the labels. The expected labels of the second and third
        instances depend on whether ``mode`` is ``"role"``.
        """
        # Build the precomputed contextualizer only when requested.
        precomputed = (PrecomputedContextualizer(self.contextualizer_path)
                       if use_contextualizer else None)
        dataset_reader = AdpositionSupersenseTaggingDatasetReader(
            mode=mode,
            include_raw_tokens=include_raw_tokens,
            contextualizer=precomputed,
            lazy=lazy)
        read_instances = list(dataset_reader.read(str(self.data_path)))
        assert len(read_instances) == 3
        role_mode = mode == "role"

        # ---- Instance 0 ----
        first_fields = read_instances[0].fields
        if include_raw_tokens:
            first_words = [tok.metadata
                           for tok in first_fields["raw_tokens"].field_list]
            assert first_words == [
                'Have', 'a', 'real', 'mechanic', 'check', 'before', 'you',
                'buy', '!!!!'
            ]
        assert_allclose(first_fields["label_indices"].array, np.array([5]))
        if use_contextualizer:
            first_expected = np.array([
                [1.0734302, 0.01504201],
                [0.22717158, -0.07871069],
                [-0.04494515, -1.5083733],
                [-0.99489564, -0.6427601],
                [-0.6134137, 0.21780868],
                [-0.72287357, 0.00998633],
                [-2.4904299, -0.49546975],
                [-1.2544577, -0.3230043],
                [-0.33858764, -0.4852887],
            ])
            assert_allclose(
                first_fields["token_representations"].array[:, :2],
                first_expected,
                rtol=1e-4)
        # The first instance carries the same label in every mode.
        assert first_fields["labels"].labels == ['p.Time']

        # ---- Instance 1 ----
        second_fields = read_instances[1].fields
        if include_raw_tokens:
            second_words = [tok.metadata
                            for tok in second_fields["raw_tokens"].field_list]
            assert second_words == [
                "Very", "good", "with", "my", "5", "year", "old", "daughter",
                "."
            ]
        assert_allclose(second_fields["label_indices"].array, np.array([3]))
        if use_contextualizer:
            second_expected = np.array([
                [1.6298808, 1.1453593],
                [0.5952766, -0.9594625],
                [0.97710425, -0.19498791],
                [0.01909786, -0.9163474],
                [-0.06802255, -1.0863125],
                [-1.223998, 0.2686447],
                [-0.3791673, -0.71468884],
                [-1.1185161, -1.2551097],
                [-1.3264754, -0.55683744],
            ])
            assert_allclose(
                second_fields["token_representations"].array[:, :2],
                second_expected,
                rtol=1e-4)
        second_labels = ["p.SocialRel"] if role_mode else ["p.Gestalt"]
        assert second_fields["labels"].labels == second_labels

        # ---- Instance 2 ----
        third_fields = read_instances[2].fields
        if include_raw_tokens:
            third_words = [tok.metadata
                           for tok in third_fields["raw_tokens"].field_list]
            assert third_words == [
                'After', 'firing', 'this', 'company', 'my', 'next', 'pool',
                'service', 'found', 'the', 'filters', 'had', 'not', 'been',
                'cleaned', 'as', 'they', 'should', 'have', 'been', '.'
            ]
        assert_allclose(third_fields["label_indices"].array,
                        np.array([0, 4, 15]))
        if use_contextualizer:
            third_expected = np.array([
                [0.11395034, -0.20784551],
                [0.15244134, -0.55084044],
                [0.7266098, -0.6036074],
                [0.6412934, -0.3448509],
                [0.57441425, -0.73139024],
                [0.5528518, -0.19321561],
                [0.5668789, -0.20008],
                [0.5737846, -1.2053688],
                [-0.3721336, -0.8618743],
                [0.59511614, -0.18732266],
                [0.72423995, 0.4306308],
                [0.96764237, 0.21643513],
                [-0.40797114, 0.67060745],
                [-0.04556704, 0.5140952],
                [0.422831, 0.32669073],
                [0.6339446, -0.44046107],
                [-0.19289528, -0.18465114],
                [0.09728494, -1.0248029],
                [0.791354, -0.2504376],
                [0.7951995, -0.7192571],
                [-0.345582, -0.8098198],
            ])
            assert_allclose(
                third_fields["token_representations"].array[:, :2],
                third_expected,
                rtol=1e-4)
        # Only the middle adposition's label differs between the two modes.
        third_labels = (["p.Time", "p.OrgRole", "p.ComparisonRef"]
                        if role_mode else
                        ["p.Time", "p.Gestalt", "p.ComparisonRef"])
        assert third_fields["labels"].labels == third_labels
    def test_read_from_file(self, lazy, use_contextualizer):
        """Read the CoNLL-X POS data and verify the first two instances.

        Checks the raw tokens and POS labels of each instance and, when
        ``use_contextualizer`` is set, the first two columns of the token
        representations.
        """
        # Build the precomputed contextualizer only when requested.
        precomputed = (PrecomputedContextualizer(self.contextualizer_path)
                       if use_contextualizer else None)
        dataset_reader = ConllXPOSDatasetReader(contextualizer=precomputed,
                                                lazy=lazy)
        read_instances = list(dataset_reader.read(str(self.data_path)))

        # ---- Instance 0 ----
        first_fields = read_instances[0].fields
        first_words = [tok.metadata
                       for tok in first_fields["raw_tokens"].field_list]
        assert first_words == [
            'Pierre', 'Vinken', ',', '61', 'years', 'old', ',', 'will', 'join',
            'the', 'board', 'as', 'a', 'nonexecutive', 'director', 'Nov.',
            '29', '.'
        ]
        first_tags = [
            'NNP', 'NNP', ',', 'CD', 'NNS', 'JJ', ',', 'MD', 'VB', 'DT', 'NN',
            'IN', 'DT', 'JJ', 'NN', 'NNP', 'CD', '.'
        ]
        assert first_fields["labels"].labels == first_tags
        if use_contextualizer:
            first_expected = np.array([
                [-2.2218986, -2.451917], [-0.7551352, -0.7919447],
                [-0.9525466, -0.985806], [-1.4567664, 0.1637534],
                [0.21862003, -0.00878072], [-0.7341557, -0.57776076],
                [-1.6816409, -1.2562131], [-0.9079286, 0.15607932],
                [-0.44011104, -0.3434037], [0.56341827, -0.97181696],
                [-0.7166206, -0.33435553], [-0.14051008, -1.260754],
                [0.42426592, -0.35762805], [-1.0153385, -0.7949409],
                [-0.7065723, 0.05164766], [-0.11002721, -0.11039695],
                [0.41112524, 0.27260625], [-1.0369725, -0.6278316],
            ])
            assert_allclose(
                first_fields["token_representations"].array[:, :2],
                first_expected,
                rtol=1e-4)

        # ---- Instance 1 ----
        second_fields = read_instances[1].fields
        second_words = [tok.metadata
                        for tok in second_fields["raw_tokens"].field_list]
        assert second_words == [
            'Mr.', 'Vinken', 'is', 'chairman', 'of', 'Elsevier', 'N.V.', ',',
            'the', 'Dutch', 'publishing', 'group', '.'
        ]
        second_tags = [
            'NNP', 'NNP', 'VBZ', 'NN', 'IN', 'NNP', 'NNP', ',', 'DT', 'NNP',
            'VBG', 'NN', '.'
        ]
        assert second_fields["labels"].labels == second_tags
        if use_contextualizer:
            second_expected = np.array([
                [0.7069745, -0.5422047], [-1.8885247, -1.4432149],
                [-1.7570897, -1.1201282], [-1.2288755, -0.8003752],
                [-0.08672556, -0.99020493], [-0.6416313, -1.147429],
                [-0.7924329, 0.14809224], [-1.0645872, -1.0505759],
                [0.69725895, -0.8735154], [0.27878952, -0.339666],
                [0.20708983, -0.7103262], [-1.1115363, -0.16295972],
                [-1.3495405, -0.8656957],
            ])
            assert_allclose(
                second_fields["token_representations"].array[:, :2],
                second_expected,
                rtol=1e-4)
示例#14
0
    def test_read_from_file_balanced_negative_sampling(self, lazy,
                                                       use_contextualizer,
                                                       include_raw_tokens):
        """Verify the first two instances read with "balanced" negative sampling.

        For each instance this checks the raw tokens (only when
        ``include_raw_tokens`` is set), the arc index pairs, the labels, which
        alternate '1' / '0' across consecutive arc pairs, and (only when
        ``use_contextualizer`` is set) the number of token representations and
        their first two columns.
        """
        # Build the precomputed contextualizer only when requested.
        precomputed = (PrecomputedContextualizer(self.contextualizer_path)
                       if use_contextualizer else None)
        dataset_reader = SyntacticDependencyArcPredictionDatasetReader(
            negative_sampling_method="balanced",
            contextualizer=precomputed,
            include_raw_tokens=include_raw_tokens,
            lazy=lazy)
        read_instances = list(dataset_reader.read(str(self.data_path)))

        def check_instance(instance, sentence, arcs, labels, representations):
            # Shared assertions for one instance, in the same order for each.
            fields = instance.fields
            if include_raw_tokens:
                words = [tok.metadata
                         for tok in fields["raw_tokens"].field_list]
                assert words == sentence
            assert_allclose(fields["arc_indices"].array, np.array(arcs))
            assert fields["labels"].labels == labels
            if use_contextualizer:
                token_reps = fields["token_representations"].array
                assert token_reps.shape[0] == len(sentence)
                assert_allclose(token_reps[:, :2],
                                np.array(representations),
                                rtol=1e-3)

        # ---- Instance 0 ----
        first_sentence = [
            'In', 'an', 'Oct.', '19', 'review', 'of', '``', 'The',
            'Misanthrope', "''", 'at', 'Chicago', "'s", 'Goodman', 'Theatre',
            '(', '``', 'Revitalized', 'Classics', 'Take', 'the', 'Stage', 'in',
            'Windy', 'City', ',', "''", 'Leisure', '&', 'Arts', ')', ',',
            'the', 'role', 'of', 'Celimene', ',', 'played', 'by', 'Kim',
            'Cattrall', ',', 'was', 'mistakenly', 'attributed', 'to',
            'Christina', 'Haag', '.'
        ]
        # Four pairs per line; index 44 never appears as the first element.
        first_arcs = [
            [0, 4], [0, 26], [1, 4], [1, 28],
            [2, 4], [2, 3], [3, 4], [3, 18],
            [4, 44], [4, 33], [5, 8], [5, 33],
            [6, 8], [6, 27], [7, 8], [7, 21],
            [8, 4], [8, 32], [9, 8], [9, 24],
            [10, 14], [10, 39], [11, 14], [11, 15],
            [12, 11], [12, 34], [13, 14], [13, 8],
            [14, 8], [14, 20], [15, 19], [15, 8],
            [16, 19], [16, 6], [17, 18], [17, 41],
            [18, 19], [18, 16], [19, 4], [19, 36],
            [20, 21], [20, 47], [21, 19], [21, 40],
            [22, 24], [22, 9], [23, 24], [23, 19],
            [24, 19], [24, 6], [25, 19], [25, 48],
            [26, 19], [26, 4], [27, 19], [27, 45],
            [28, 27], [28, 21], [29, 27], [29, 32],
            [30, 19], [30, 37], [31, 44], [31, 6],
            [32, 33], [32, 22], [33, 44], [33, 27],
            [34, 35], [34, 20], [35, 33], [35, 41],
            [36, 33], [36, 42], [37, 33], [37, 13],
            [38, 40], [38, 35], [39, 40], [39, 30],
            [40, 37], [40, 28], [41, 33], [41, 34],
            [42, 44], [42, 16], [43, 44], [43, 3],
            [45, 47], [45, 35], [46, 47], [46, 0],
            [47, 44], [47, 5], [48, 44], [48, 47],
        ]
        # 96 labels in total, strictly alternating '1', '0'.
        first_labels = ['1', '0'] * 48
        first_representations = [
            [0.7541596, 0.36606207], [-0.3912218, 0.2728929],
            [0.4532569, 0.59446496], [-0.034773, 0.6178972],
            [0.05996126, -0.21075758], [-0.00675234, -0.19188942],
            [-0.25371405, -0.98044276], [0.55180097, -1.3375797],
            [-0.76439965, -0.8849516], [-0.1852389, -0.76670283],
            [-0.6538293, -2.109323], [0.11706313, -0.14159685],
            [-0.26565668, 0.08206904], [-1.0511935, -0.28469092],
            [0.22915375, 0.2485466], [1.4214072, 0.02810444],
            [0.7648947, -1.3637407], [-0.01231889, -0.02892348],
            [-0.1330762, 0.0219465], [0.8961761, -1.2976432],
            [0.83349395, -1.8242016], [0.15122458, -0.9597366],
            [0.7570322, -0.73728824], [-0.04838032, -0.8663991],
            [0.32632858, -0.5200325], [0.7823914, -1.020006],
            [0.5874542, -1.020459], [-0.4918128, -0.85094],
            [-0.24947, -0.20599724], [-1.4349735, 0.19630724],
            [-0.49690107, -0.58586204], [0.06130999, -0.14850587],
            [0.66610545, -0.06235093], [-0.29052478, 0.40215907],
            [0.24728307, 0.23677489], [-0.05339833, 0.22958362],
            [-0.44152835, -0.58153844], [0.4723678, -0.06656095],
            [0.32210657, -0.03144099], [0.6663985, 0.39230958],
            [0.57831913, 0.19480982], [-0.96823174, 0.00828598],
            [-0.7640736, 0.00441009], [-0.5589211, 0.17509514],
            [0.01523143, -0.7975017], [0.3268571, -0.1870772],
            [1.4704096, 0.8472788], [0.23348817, -0.48313117],
            [-0.57006484, -0.77375746],
        ]
        check_instance(read_instances[0], first_sentence, first_arcs,
                       first_labels, first_representations)

        # ---- Instance 1 ----
        second_sentence = ['Ms.', 'Haag', 'plays', 'Elianti', '.']
        second_arcs = [
            [0, 1], [0, 3], [1, 2], [1, 4],
            [3, 2], [3, 4], [4, 2], [4, 3],
        ]
        second_labels = ['1', '0'] * 4
        second_representations = [
            [0.6757653, -0.80925614],
            [-1.9424553, -1.0854281],
            [-0.09960067, 0.17525218],
            [0.09222834, -0.8534998],
            [-0.66507375, -0.5633631],
        ]
        check_instance(read_instances[1], second_sentence, second_arcs,
                       second_labels, second_representations)
    def test_read_from_file(self, lazy, use_contextualizer, backward):
        """Read the language modeling data and verify the first four instances.

        Each instance must expose the sentence as its raw tokens. Its labels
        are the sentence shifted by one position: preceded by ``"<S>"`` when
        ``backward`` is set, otherwise followed by ``"</S>"``.
        """
        # Build the precomputed contextualizer only when requested.
        precomputed = (PrecomputedContextualizer(self.contextualizer_path)
                       if use_contextualizer else None)
        dataset_reader = LanguageModelingDatasetReader(
            backward=backward, contextualizer=precomputed, lazy=lazy)
        read_instances = list(dataset_reader.read(str(self.data_path)))

        # Expected sentences, in the order the reader yields them.
        expected_sentences = [
            [
                'The', 'party', 'stopped', 'at', 'the', 'remains', 'of', 'a',
                '15th-century', 'khan', '--', 'a', 'roadside', 'inn', 'built',
                'by', 'the', 'Mameluks', ',', 'former', 'slave', 'soldiers',
                'turned', 'rulers', 'of', 'Egypt', 'and', 'Palestine', '.'
            ],
            [
                'He', 'claimed', 'that', 'people', 'who', 'followed', 'his',
                'cultivation', 'formula', 'acquired', 'a', '"', 'third', 'eye',
                '"', 'that', 'allowed', 'them', 'to', 'peer', 'into', 'other',
                'dimensions', 'and', 'escape', 'the', 'molecular', 'world', '.'
            ],
            [
                'If', 'they', 'were', 'losing', 'so', 'much', 'money', 'on',
                'me', ',', 'why', 'would', 'they', 'send', 'me', 'a', 'new',
                'credit', 'card', ',', 'under', 'the', 'same', 'terms', ',',
                'when', 'my', 'old', 'one', 'expired', '?'
            ],
            [
                'His', 'most', 'impressive', 'performance', 'was', 'his',
                'last', '.'
            ],
        ]

        for position, sentence in enumerate(expected_sentences):
            # Index one at a time so a short read fails at the same instance.
            fields = read_instances[position].fields
            observed = [tok.metadata
                        for tok in fields["raw_tokens"].field_list]
            assert observed == sentence
            if backward:
                expected_labels = ["<S>"] + sentence[:-1]
            else:
                expected_labels = sentence[1:] + ["</S>"]
            assert fields["labels"].labels == expected_labels
    def test_precomputed_contextualizer_all_elmo_layers(self):
        """Check per-layer representations loaded from ``elmo_layers_all.hdf5``.

        A contextualizer is built three times: with ``layer_num=0``, with
        ``layer_num=1``, and with ``layer_num`` left at its default. For each
        one the three test sentences must yield representations of size
        ``(seq_len, 1024)`` whose first column matches the stored reference
        values for the first four tokens.
        """
        all_elmo_layers_path = self.model_paths / "elmo_layers_all.hdf5"
        rep_dim = 1024
        num_sentences = 3
        # Expected sequence length of each of the three sentences.
        expected_seq_lens = (16, 11, 11)

        # (constructor kwargs, expected first-column values per sentence)
        layer_cases = [
            # First layer (index 0).
            ({"layer_num": 0}, [
                [[-0.3288476], [-0.28436223], [0.9835328], [0.1915474]],
                [[-0.23547989], [-1.7968968], [-0.09795779], [0.10400581]],
                [[0.7506348], [-0.09795779], [0.08865512], [0.6102083]],
            ]),
            # Second layer (index 1).
            ({"layer_num": 1}, [
                [[0.02916196], [-0.618347], [0.04200662], [-0.28494996]],
                [[0.04939255], [-0.08436887], [-0.10033038], [0.23103642]],
                [[0.19448458], [-0.014540106], [0.23244698], [-1.1397098]],
            ]),
            # layer_num omitted: intended to exercise the third / last layer
            # (index 2) via the constructor default.
            ({}, [
                [[0.28029996], [-1.1247718], [-0.45496008], [-0.25592107]],
                [[-0.12891075], [-0.67801315], [0.021882683], [0.03998524]],
                [[0.17843074], [0.49779615], [0.36996722], [-1.154212]],
            ]),
        ]

        for layer_kwargs, expected_heads in layer_cases:
            layer_contextualizer = PrecomputedContextualizer(
                all_elmo_layers_path, **layer_kwargs)
            representations = layer_contextualizer(
                [self.sentence_1, self.sentence_2, self.sentence_3])
            assert len(representations) == num_sentences

            for position in range(num_sentences):
                sentence_representation = representations[position]
                seq_len = expected_seq_lens[position]
                assert sentence_representation.size() == (seq_len, rep_dim)
                # Compare only the first dimension of the first four tokens.
                assert_allclose(
                    sentence_representation[:, :1].cpu().numpy()[:4],
                    np.array(expected_heads[position]),
                    rtol=1e-5)
示例#17
0
    def test_read_from_file(self, lazy, use_contextualizer,
                            include_raw_tokens):
        """Read the coreference arc prediction data and verify the output.

        When ``include_raw_tokens`` is set, every instance must carry the full
        document as its raw tokens. The first instance is then checked in
        detail: its arc index pairs, its labels, and (only when
        ``use_contextualizer`` is set) the number of token representations and
        the first two columns of the first four of them.
        """
        # Build the precomputed contextualizer only when requested.
        precomputed = (PrecomputedContextualizer(self.contextualizer_path)
                       if use_contextualizer else None)
        dataset_reader = CoreferenceArcPredictionDatasetReader(
            contextualizer=precomputed,
            include_raw_tokens=include_raw_tokens,
            lazy=lazy)
        read_instances = list(dataset_reader.read(str(self.data_path)))

        document = [
            'What', 'kind', 'of', 'memory', '?', 'We', 'respectfully',
            'invite', 'you', 'to', 'watch', 'a', 'special', 'edition', 'of',
            'Across', 'China', '.', 'WW', 'II', 'Landmarks', 'on', 'the',
            'Great', 'Earth', 'of', 'China', ':', 'Eternal', 'Memories', 'of',
            'Taihang', 'Mountain', 'Standing', 'tall', 'on', 'Taihang',
            'Mountain', 'is', 'the', 'Monument', 'to', 'the', 'Hundred',
            'Regiments', 'Offensive', '.', 'It', 'is', 'composed', 'of', 'a',
            'primary', 'stele', ',', 'secondary', 'steles', ',', 'a', 'huge',
            'round', 'sculpture', 'and', 'beacon', 'tower', ',', 'and', 'the',
            'Great', 'Wall', ',', 'among', 'other', 'things', '.', 'A',
            'primary', 'stele', ',', 'three', 'secondary', 'steles', ',',
            'and', 'two', 'inscribed', 'steles', '.', 'The', 'Hundred',
            'Regiments', 'Offensive', 'was', 'the', 'campaign', 'of', 'the',
            'largest', 'scale', 'launched', 'by', 'the', 'Eighth', 'Route',
            'Army', 'during', 'the', 'War', 'of', 'Resistance', 'against',
            'Japan', '.', 'This', 'campaign', 'broke', 'through', 'the',
            'Japanese', 'army', "'s", 'blockade', 'to', 'reach', 'base',
            'areas', 'behind', 'enemy', 'lines', ',', 'stirring', 'up',
            'anti-Japanese', 'spirit', 'throughout', 'the', 'nation', 'and',
            'influencing', 'the', 'situation', 'of', 'the', 'anti-fascist',
            'war', 'of', 'the', 'people', 'worldwide', '.', 'This', 'is',
            'Zhuanbi', 'Village', ',', 'Wuxiang', 'County', 'of', 'Shanxi',
            'Province', ',', 'where', 'the', 'Eighth', 'Route', 'Army', 'was',
            'headquartered', 'back', 'then', '.', 'On', 'a', 'wall', 'outside',
            'the', 'headquarters', 'we', 'found', 'a', 'map', '.', 'This',
            'map', 'was', 'the', 'Eighth', 'Route', 'Army', "'s", 'depiction',
            'of', 'the', 'Mediterranean', 'Sea', 'situation', 'at', 'that',
            'time', '.', 'This', 'map', 'reflected', 'the', 'European',
            'battlefield', 'situation', '.', 'In', '1940', ',', 'the',
            'German', 'army', 'invaded', 'and', 'occupied', 'Czechoslovakia',
            ',', 'Poland', ',', 'the', 'Netherlands', ',', 'Belgium', ',',
            'and', 'France', '.', 'It', 'was', 'during', 'this', 'year',
            'that', 'the', 'Japanese', 'army', 'developed', 'a', 'strategy',
            'to', 'rapidly', 'force', 'the', 'Chinese', 'people', 'into',
            'submission', 'by', 'the', 'end', 'of', '1940', '.', 'In', 'May',
            ',', 'the', 'Japanese', 'army', 'launched', '--', 'From', 'one',
            'side', ',', 'it', 'seized', 'an', 'important', 'city', 'in',
            'China', 'called', 'Yichang', '.', 'Um', ',', ',', 'uh', ',',
            'through', 'Yichang', ',', 'it', 'could', 'directly', 'reach',
            'Chongqing', '.', 'Ah', ',', 'that', 'threatened', 'Chongqing',
            '.', 'Then', 'they', 'would', ',', 'ah', ',', 'bomb', 'these',
            'large', 'rear', 'areas', 'such', 'as', 'Chongqing', '.', 'So',
            ',', 'along', 'with', 'the', 'coordinated', ',', 'er', ',',
            'economic', 'blockade', ',', 'military', 'offensives', ',', 'and',
            'strategic', 'bombings', ',', 'er', ',', 'a', 'simultaneous',
            'attack', 'was', 'launched', 'in', 'Hong', 'Kong', 'to', 'lure',
            'the', 'KMT', 'government', 'into', 'surrender', '.', 'The',
            'progress', 'of', 'this', 'coordinated', 'offensive', 'was',
            'already', 'very', 'entrenched', 'by', 'then', '.'
        ]

        # Every instance is expected to carry the same full document.
        if include_raw_tokens:
            for read_instance in read_instances:
                instance_words = [
                    tok.metadata
                    for tok in read_instance.fields["raw_tokens"].field_list
                ]
                assert instance_words == document

        # ---- Instance 0 ----
        first_fields = read_instances[0].fields
        if include_raw_tokens:
            first_words = [tok.metadata
                           for tok in first_fields["raw_tokens"].field_list]
            assert first_words == document
        expected_arcs = np.array([
            (298, 285), (298, 288),
            (298, 267), (298, 288),
            (293, 288), (293, 273),
        ])
        assert_allclose(first_fields["arc_indices"].array, expected_arcs)
        assert first_fields["labels"].labels == ['1', '0', '1', '0', '1', '0']
        if use_contextualizer:
            token_reps = first_fields["token_representations"].array
            assert token_reps.shape[0] == len(document)
            # Only the first four tokens of the document are spot-checked.
            expected_head = np.array([
                [-0.40419546, 0.18443017],
                [-0.4557378, -0.50057644],
                [0.10493508, -0.7943226],
                [-0.8075396, 0.87755275],
            ])
            assert_allclose(token_reps[:4, :2], expected_head, rtol=1e-4)
示例#18
0
    def test_read_from_file(self, lazy, use_contextualizer):
        """Read the conjunct identification data and verify both instances.

        Exactly two instances are expected. For each one this checks the raw
        tokens, the labels (which use 'O', 'B' and 'I' tags), and, when
        ``use_contextualizer`` is set, the first two columns of the token
        representations.
        """
        # Build the precomputed contextualizer only when requested.
        precomputed = (PrecomputedContextualizer(self.contextualizer_path)
                       if use_contextualizer else None)
        dataset_reader = ConjunctIdentificationDatasetReader(
            contextualizer=precomputed, lazy=lazy)
        read_instances = list(dataset_reader.read(str(self.data_path)))
        # One instance is skipped because of nested coordination
        assert len(read_instances) == 2

        # ---- Instance 0 ----
        first_fields = read_instances[0].fields
        first_words = [tok.metadata
                       for tok in first_fields["raw_tokens"].field_list]
        assert first_words == [
            'They', 'shredded', 'it', 'simply', 'because', 'it', 'contained',
            'financial', 'information', 'about', 'their', 'creditors', 'and',
            'depositors', '.', "''"
        ]
        first_tags = [
            'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O',
            'O', 'O', 'O', 'B', 'O', 'B', 'O', 'O'
        ]
        assert first_fields["labels"].labels == first_tags
        if use_contextualizer:
            first_expected = np.array([
                [-2.047799, -0.3107947], [0.40388513, 0.15957603],
                [0.4006851, -0.1980469], [0.409753, -0.48708656],
                [0.65417755, -0.03706935], [-0.53143466, -1.057557],
                [0.7815078, -0.21813926], [-1.3369036, -0.77031285],
                [0.11985331, -0.39474356], [0.68627775, -0.72502434],
                [0.569624, -2.3243494], [-0.69559455, -1.248917],
                [0.2524291, -0.47938287], [0.2019696, -0.66839015],
                [-0.5914014, -0.8587656], [-0.521717, 0.04716678],
            ])
            assert_allclose(
                first_fields["token_representations"].array[:, :2],
                first_expected,
                rtol=1e-4)

        # ---- Instance 1 ----
        second_fields = read_instances[1].fields
        second_words = [tok.metadata
                        for tok in second_fields["raw_tokens"].field_list]
        assert second_words == [
            'Suppression', 'of', 'the', 'book', ',', 'Judge', 'Oakes',
            'observed', ',', 'would', 'operate', 'as', 'a', 'prior',
            'restraint', 'and', 'thus', 'involve', 'the', 'First', 'Amendment',
            '.'
        ]
        second_tags = [
            'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B',
            'I', 'I', 'I', 'I', 'O', 'O', 'B', 'I', 'I', 'I', 'O'
        ]
        assert second_fields["labels"].labels == second_tags
        if use_contextualizer:
            second_expected = np.array([
                [-0.00731754, 0.00914195], [0.6760086, -0.11198741],
                [0.8685149, 0.15874714], [-0.9228251, -1.3684492],
                [-0.17535079, 0.36266953], [-0.85589266, -1.4212742],
                [-1.8647766, -0.9377552], [-0.34395775, 0.18579313],
                [-1.6104316, 0.5044512], [-1.6913524, 0.5832756],
                [0.6513059, 1.1528094], [-0.24509574, 0.49362227],
                [-0.47929475, 0.6173321], [-0.431388, 0.15780556],
                [-1.4048593, 0.44075668], [-0.32530123, 0.23048985],
                [-0.23973304, 1.2190828], [0.4657239, 0.20590879],
                [0.16104633, 0.04873788], [0.8202704, -0.7126241],
                [-0.59338295, 1.2020597], [-0.5741635, -0.05905316],
            ])
            assert_allclose(
                second_fields["token_representations"].array[:, :2],
                second_expected,
                rtol=1e-4)
    def test_read_from_file(self, lazy, use_contextualizer, label_encoding):
        """Check the first three instances read by ``Conll2003NERDatasetReader``.

        For each instance this verifies the raw tokens, the labels expected
        for the requested ``label_encoding`` and, when ``use_contextualizer``
        is set, the first two columns of the token representations.
        """
        # Set up contextualizer, if using.
        contextualizer = None
        if use_contextualizer:
            contextualizer = PrecomputedContextualizer(self.contextualizer_path)
        reader = Conll2003NERDatasetReader(label_encoding=label_encoding,
                                           contextualizer=contextualizer,
                                           lazy=lazy)
        instances = list(reader.read(str(self.data_path)))

        # One entry per checked instance, in read order:
        # (raw tokens,
        #  labels when label_encoding == "IOB1",
        #  labels for any other label_encoding,
        #  first two columns of the token representations).
        expected_instances = [
            (["EU", "rejects", "German", "call", "to", "boycott", "British",
              "lamb", "."],
             ["I-ORG", "O", "I-MISC", "O", "O", "O", "I-MISC", "O", "O"],
             ["U-ORG", "O", "U-MISC", "O", "O", "O", "U-MISC", "O", "O"],
             [[0.9611156, -0.63280857], [-0.5790216, -0.13829914],
              [-0.35322708, -0.22807068], [-1.6707208, -1.1125797],
              [-2.0587592, -1.5086308], [-1.3151755, -1.6046834],
              [0.5072891, -1.5075727], [0.11287686, -1.2473724],
              [-0.5029946, -1.4319026]]),
            (["Peter", "Blackburn"],
             ["I-PER", "I-PER"],
             ["B-PER", "L-PER"],
             [[-0.5346743, -1.1600235], [-0.7508778, -1.031188]]),
            (["BRUSSELS", "1996-08-22"],
             ["I-LOC", "O"],
             ["U-LOC", "O"],
             [[0.0324477, -0.06925768], [-0.47278678, 0.530316]]),
        ]

        for index, expected in enumerate(expected_instances):
            tokens, iob1_labels, other_labels, representations = expected
            # Index explicitly (rather than zip) so that a reader returning
            # too few instances fails loudly instead of being skipped.
            fields = instances[index].fields
            assert [token.metadata
                    for token in fields["raw_tokens"].field_list] == tokens
            if label_encoding == "IOB1":
                assert fields["labels"].labels == iob1_labels
            else:
                assert fields["labels"].labels == other_labels
            if use_contextualizer:
                assert_allclose(
                    fields["token_representations"].array[:, :2],
                    np.array(representations),
                    rtol=1e-4)
示例#20
0
    def test_read_from_file(self, lazy, use_contextualizer):
        """Compare the first two instances read from file to fixed values.

        Reads ``self.data_path`` with a ``SemanticTaggingDatasetReader`` and
        checks the raw tokens and labels of each instance. The first two
        columns of the token representations are only compared when
        ``use_contextualizer`` is set.
        """
        # Only build a contextualizer when this parametrization asks for one.
        contextualizer = (
            PrecomputedContextualizer(self.contextualizer_path)
            if use_contextualizer else None)
        reader = SemanticTaggingDatasetReader(contextualizer=contextualizer,
                                              lazy=lazy)
        instances = list(reader.read(str(self.data_path)))

        # ---- First read instance ----
        first_fields = instances[0].fields
        first_tokens = [
            "The", "report", "comes", "as", "Afghan", "officials",
            "announced", "that", "the", "bodies", "of", "20",
            "Taleban", "militants", "had", "been", "recovered", "from",
            "Bermel", "district", "in", "Paktika", "province", ",",
            "where", "NATO", "and", "Afghan", "forces", "recently",
            "conducted", "a", "mission", "."
        ]
        first_tags = [
            "DEF", "CON", "ENS", "SUB", "GPE", "CON", "PST", "SUB", "DEF",
            "CON", "REL", "DOM", "ORG", "CON", "EPT", "ETV", "EXV", "REL",
            "LOC", "CON", "REL", "LOC", "CON", "NIL", "PRO", "ORG", "AND",
            "GPE", "CON", "REL", "PST", "DIS", "CON", "NIL"
        ]
        first_read_tokens = [
            token.metadata for token in first_fields["raw_tokens"].field_list
        ]
        assert first_read_tokens == first_tokens
        assert first_fields["labels"].labels == first_tags
        if use_contextualizer:
            first_vectors = [
                [0.2937507, -0.03459462],
                [-2.0376801, -1.6185987],
                [0.80633676, 0.25174493],
                [-0.31453115, -0.16706648],
                [0.2778436, 0.8754083],
                [-3.912532, 0.15716752],
                [0.03259511, 1.074891],
                [0.60919964, 0.28122807],
                [0.2766431, -0.57389474],
                [-1.5917854, 0.14402057],
                [0.46617347, 0.5476148],
                [-0.3859496, 0.55521],
                [-0.19902334, 0.51852816],
                [-0.49617743, 0.50021535],
                [0.89773405, 0.33418086],
                [-1.0823509, 0.8463002],
                [0.9214894, 0.17294498],
                [-0.98676234, 0.46858853],
                [-1.1950549, 1.0456221],
                [-0.06810452, 1.8754647],
                [-0.31319135, 0.5955827],
                [0.8572887, 0.9902405],
                [0.18385345, 0.88080823],
                [-0.2386447, 0.273946],
                [1.0159383, 0.2908004],
                [-0.84152496, -1.8987631],
                [0.6318563, -1.3307623],
                [0.77291626, -0.9464708],
                [-2.5105689, 0.05288363],
                [-1.8620715, 0.05540787],
                [0.8963124, 0.88138795],
                [1.0833803, 0.29445225],
                [-0.33804226, -0.5501779],
                [-0.80601907, -0.6653841],
            ]
            assert_allclose(
                first_fields["token_representations"].array[:, :2],
                np.array(first_vectors),
                rtol=1e-4)

        # ---- Second read instance ----
        second_fields = instances[1].fields
        second_tokens = [
            "Turkish", "Prime", "Minister", "Tayyip", "Erdogan", ",",
            "in", "London", "for", "talks", "with", "British", "Prime",
            "Minister", "Tony", "Blair", ",", "said", "Wednesday",
            "Ankara", "would", "sign", "the", "EU", "protocol", "soon",
            "."
        ]
        second_tags = [
            "GPE", "UNK", "UNK", "PER", "PER", "NIL", "REL", "LOC", "REL",
            "CON", "REL", "GPE", "UNK", "UNK", "PER", "PER", "NIL", "PST",
            "TIM", "TIM", "FUT", "EXS", "DEF", "ORG", "CON", "REL", "NIL"
        ]
        second_read_tokens = [
            token.metadata for token in second_fields["raw_tokens"].field_list
        ]
        assert second_read_tokens == second_tokens
        assert second_fields["labels"].labels == second_tags
        if use_contextualizer:
            second_vectors = [
                [0.28560728, 0.34812376],
                [-1.7316533, -0.5265728],
                [-2.6642923, -0.9582914],
                [-1.6637948, -2.5388384],
                [-2.9503021, -0.74373335],
                [-3.1062536, -0.47450644],
                [-2.2821736, -0.08023855],
                [-1.9760342, -0.4066736],
                [-1.9215266, -0.81184065],
                [-2.2942708, -0.13005577],
                [-1.1666149, -0.82010025],
                [1.2843199, -0.04729652],
                [-0.35602665, -1.9205997],
                [0.1594456, -2.390737],
                [-1.0997499, -0.11030376],
                [-1.7266417, 0.01889065],
                [-2.9103873, -1.6603167],
                [-1.3453144, 0.0276348],
                [-1.5531495, 0.24530894],
                [-4.1084657, -0.24038172],
                [-3.6353674, -1.2928469],
                [-1.527199, 1.9692067],
                [-0.86209273, 1.5000844],
                [-1.3264929, 0.35947016],
                [-2.4620879, 1.5387912],
                [-1.9274603, 0.67314804],
                [-1.1620884, -0.63547856],
            ]
            assert_allclose(
                second_fields["token_representations"].array[:, :2],
                np.array(second_vectors),
                rtol=1e-4)
    def test_read_from_file(self, lazy, use_contextualizer,
                            include_raw_tokens):
        """Check the first three instances read by ``EventFactualityDatasetReader``.

        Verifies the total number of instances and, for each of the first
        three, the label indices and labels (which must have equal length).
        Raw tokens are only compared when ``include_raw_tokens`` is set, and
        the first two columns of the token representations only when
        ``use_contextualizer`` is set.
        """
        # Set up contextualizer, if using.
        contextualizer = None
        if use_contextualizer:
            contextualizer = PrecomputedContextualizer(
                self.contextualizer_path)
        reader = EventFactualityDatasetReader(
            include_raw_tokens=include_raw_tokens,
            contextualizer=contextualizer,
            lazy=lazy)
        instances = list(reader.read(str(self.data_path)))
        assert len(instances) == 15

        # First read instance
        instance = instances[0]
        fields = instance.fields
        if include_raw_tokens:
            assert [
                token.metadata for token in fields["raw_tokens"].field_list
            ] == [
                'Joe', 'asked', 'that', 'you', 'fax', 'the', 'revised', 'gtee',
                'wording', 'that', 'has', 'been', 'agreed', '(', 'I',
                'believe', 'it', 'was', 'our', 'agreeing', 'to', 'reduce',
                'the', 'claim', 'period', 'from', '15', 'days', 'down', 'to',
                '5', ')', 'and', 'the', 'new', 'l/c', 'wording', '(', 'drops',
                'the', '2', 'day', 'period', 'to', 'replace', 'an', 'l/c',
                'with', 'a', 'different', 'bank', 'if', 'the', 'first',
                'refuses', 'to', 'pay', ')', '.'
            ]
        assert_allclose(fields["label_indices"].array,
                        np.array([15, 19, 21, 38, 44, 54, 56, 1, 4, 6, 12]))
        if use_contextualizer:
            assert_allclose(fields["token_representations"].array[:, :2],
                            np.array([[-0.36685583, -1.0248482],
                                      [-0.08719254, -0.12769365],
                                      [0.64198303, 0.7540561],
                                      [-2.480215, 0.04852793],
                                      [0.05662279, 0.19068614],
                                      [0.8952136, 0.18563624],
                                      [0.61201894, -0.21791479],
                                      [0.16892922, -0.79595846],
                                      [-0.27208328, -0.13422441],
                                      [0.04730925, -0.43866983],
                                      [-0.18922694, 0.41402912],
                                      [-1.2735212, -0.7098247],
                                      [-0.35325307, -0.1886746],
                                      [0.24240366, -0.2627995],
                                      [-2.657272, -0.85991454],
                                      [-0.19721821, -0.28280562],
                                      [-1.2974384, -1.5685275],
                                      [-0.17114338, -1.3488747],
                                      [-0.14475444, -1.3091846],
                                      [-0.9352702, -0.42290983],
                                      [-1.9790481, -0.19222577],
                                      [-0.7576624, -1.3168397],
                                      [0.04005039, -0.9087254],
                                      [-1.1224419, -1.2120944],
                                      [-1.1654481, -1.2385485],
                                      [-0.53110546, -0.37541062],
                                      [-0.43803376, -0.5062414],
                                      [-1.0063732, -1.4231381],
                                      [-1.6299391, -0.08710647],
                                      [-0.4013245, 1.336797],
                                      [-0.31591064, 0.11186421],
                                      [-0.9240766, -0.19987631],
                                      [-0.91462064, -0.2551515],
                                      [0.48850712, -0.05782498],
                                      [0.26612586, -0.7230994],
                                      [-0.00594145, -1.11585],
                                      [-0.82872486, -0.6029454],
                                      [0.10594115, 0.6299722],
                                      [-0.23010078, 0.5210506],
                                      [0.57265085, -0.76853454],
                                      [-0.2151854, 0.1495785],
                                      [-0.5665817, 0.10349956],
                                      [-0.0619593, 0.15140474],
                                      [0.47662088, 0.9349986],
                                      [0.4795642, 0.4577945],
                                      [0.3688566, -0.06091809],
                                      [0.29802012, -0.25112373],
                                      [0.8288579, 0.28962702],
                                      [0.90991616, 0.24866864],
                                      [-0.2174969, -1.3967221],
                                      [-0.26998952, -1.2395245],
                                      [0.40867922, 0.41572857],
                                      [0.34937006, -0.21592987],
                                      [0.02204479, -1.1068783],
                                      [-0.81269974, -0.71383244],
                                      [-1.6719012, -0.24751332],
                                      [-0.7133447, -0.9015558],
                                      [-0.36663392, 0.00226176],
                                      [-0.66520894, 0.02220622]]),
                            rtol=1e-4)
        assert_allclose(
            fields["labels"].array,
            np.array([
                2.625, 2.625, -1.125, -1.125, -1.5, -2.25, 2.25, 3.0, -2.625,
                2.625, 3.0
            ]))
        assert len(fields["label_indices"].array) == len(
            fields["labels"].array)

        # Second read instance
        instance = instances[1]
        fields = instance.fields
        if include_raw_tokens:
            assert [
                token.metadata for token in fields["raw_tokens"].field_list
            ] == [
                'you', 'need', 'to', 'do', 'that', 'for', 'a', 'least', 'a',
                'week', ',', 'or', 'else', 'this', 'type', 'of', 'territorial',
                'fighting', 'will', 'happen', '.'
            ]

        assert_allclose(fields["label_indices"].array, np.array([19, 1, 3]))
        if use_contextualizer:
            assert_allclose(fields["token_representations"].array[:, :2],
                            np.array([[-1.8767732, 0.017495558],
                                      [0.6025951, -0.20260233],
                                      [-1.2275248, -0.09547685],
                                      [0.77881116, 0.58967566],
                                      [-0.48033506, -0.7382006],
                                      [0.81346595, -0.5705998],
                                      [0.4814891, 0.13833025],
                                      [0.29612598, 0.4745674],
                                      [1.047016, -0.10455979],
                                      [-0.42458856, -1.1668162],
                                      [-0.12459692, 1.0916736],
                                      [0.28291142, 0.17336448],
                                      [-0.08204004, 0.6720216],
                                      [-0.55279577, -0.3378092],
                                      [0.046703815, 0.0627833],
                                      [-0.17136925, 0.07279006],
                                      [-0.61967653, -0.36650854],
                                      [0.22994132, -0.17982215],
                                      [-0.039243788, -0.19590409],
                                      [1.0741227, -0.46452063],
                                      [-0.99690104, -0.20734516]]),
                            rtol=1e-4)
        assert_allclose(fields["labels"].array, np.array([-2.25, -2.25,
                                                          -2.25]))
        assert len(fields["label_indices"].array) == len(
            fields["labels"].array)

        # Third read instance
        instance = instances[2]
        fields = instance.fields
        if include_raw_tokens:
            assert [
                token.metadata for token in fields["raw_tokens"].field_list
            ] == [
                'The', 'police', 'commander', 'of', 'Ninevah', 'Province',
                'announced', 'that', 'bombings', 'had', 'declined', '80',
                'percent', 'in', 'Mosul', ',', 'whereas', 'there', 'had',
                'been', 'a', 'big', 'jump', 'in', 'the', 'number', 'of',
                'kidnappings', '.'
            ]

        assert_allclose(fields["label_indices"].array, np.array([6, 10]))
        if use_contextualizer:
            assert_allclose(fields["token_representations"].array[:, :2],
                            np.array([[0.30556452, -0.03631544],
                                      [-0.408598, -0.06455576],
                                      [-1.5492078, -0.54190993],
                                      [-0.2587847, -0.2953575],
                                      [-0.73995805, 0.4683098],
                                      [-0.7363724, 0.79993117],
                                      [-0.16775642, 1.7011331],
                                      [0.41263822, 1.7853746],
                                      [0.20824504, 0.94526154],
                                      [0.38856286, 0.55201274],
                                      [0.21549016, 0.29676253],
                                      [0.19024657, 1.6858654],
                                      [-0.3601446, 0.9940252],
                                      [0.06638061, 2.2731574],
                                      [-0.83813465, 1.996573],
                                      [-0.2564547, 1.3016648],
                                      [1.2408254, 1.2657689],
                                      [1.2441401, 0.26394492],
                                      [1.2946486, 0.4354594],
                                      [1.5132289, -0.28065175],
                                      [1.3383818, 0.99084306],
                                      [1.04397, -0.52631915],
                                      [1.026963, 0.8950106],
                                      [1.1683758, 0.3674168],
                                      [1.568187, 0.60855913],
                                      [0.00261295, 1.0362052],
                                      [1.0013494, 1.1375219],
                                      [0.46779868, 0.85086995],
                                      [-0.23202378, -0.3398294]]),
                            rtol=1e-4)
        assert_allclose(fields["labels"].array, np.array([3.0, 3.0]))
        # Same index/label length check as for the first two instances.
        assert len(fields["label_indices"].array) == len(
            fields["labels"].array)
    def test_read_from_file(self, lazy, use_contextualizer):
        """Compare the first two instances read from file to fixed values.

        Reads ``self.data_path`` with a ``ConllUPOSDatasetReader`` and checks
        the raw tokens and labels of each instance. The first two columns of
        the token representations are only compared when
        ``use_contextualizer`` is set.
        """

        def read_tokens(instance_fields):
            # Metadata of each entry in the instance's raw_tokens field.
            return [
                token.metadata
                for token in instance_fields["raw_tokens"].field_list
            ]

        # Only build a contextualizer when this parametrization asks for one.
        contextualizer = (
            PrecomputedContextualizer(self.contextualizer_path)
            if use_contextualizer else None)
        reader = ConllUPOSDatasetReader(contextualizer=contextualizer,
                                        lazy=lazy)
        instances = list(reader.read(str(self.data_path)))

        # ---- First read instance ----
        first_fields = instances[0].fields
        first_tokens = [
            'Al', '-', 'Zaman', ':', 'American', 'forces', 'killed',
            'Shaikh', 'Abdullah', 'al', '-', 'Ani', ',', 'the',
            'preacher', 'at', 'the', 'mosque', 'in', 'the', 'town',
            'of', 'Qaim', ',', 'near', 'the', 'Syrian', 'border', '.'
        ]
        first_tags = [
            'PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'ADJ', 'NOUN', 'VERB', 'PROPN',
            'PROPN', 'PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'DET', 'NOUN', 'ADP',
            'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'PROPN', 'PUNCT',
            'ADP', 'DET', 'ADJ', 'NOUN', 'PUNCT'
        ]
        assert read_tokens(first_fields) == first_tokens
        assert first_fields["labels"].labels == first_tags
        if use_contextualizer:
            first_vectors = [
                [0.43633628, -0.5755847],
                [-0.2244201, -0.3955103],
                [-1.8495967, -1.6728945],
                [-1.0596983, -0.10573974],
                [-0.15140322, -0.7195155],
                [-2.3639536, -0.42766416],
                [-0.3464077, -0.6743664],
                [-0.5407328, -0.9869094],
                [-1.2095747, 0.8123201],
                [0.46097872, 0.8609313],
                [-0.46175557, 0.42401582],
                [-0.42247432, -0.91118157],
                [-0.41762316, -0.5272959],
                [0.69995964, -0.16589859],
                [-1.4730558, -0.23568547],
                [-0.30440047, -0.8264297],
                [-0.40472034, -0.15715468],
                [-1.3681564, -0.08945632],
                [-0.6464306, 0.52979404],
                [-0.35902542, 0.8537967],
                [-2.1601028, 1.0484889],
                [-0.42148307, 0.11593458],
                [-0.81707406, 0.47127616],
                [-0.8185376, -0.20927876],
                [-1.4944136, 0.2279036],
                [-1.244726, 0.27427846],
                [-1.366718, 0.9977276],
                [-1.0117195, 0.27465925],
                [-0.6697843, -0.24481633],
            ]
            assert_allclose(
                first_fields["token_representations"].array[:, :2],
                np.array(first_vectors),
                rtol=1e-4)

        # ---- Second read instance ----
        second_fields = instances[1].fields
        second_tokens = [
            '[', 'This', 'killing', 'of', 'a', 'respected', 'cleric',
            'will', 'be', 'causing', 'us', 'trouble', 'for', 'years',
            'to', 'come', '.', ']'
        ]
        second_tags = [
            'PUNCT', 'DET', 'NOUN', 'ADP', 'DET', 'ADJ', 'NOUN', 'AUX', 'AUX',
            'VERB', 'PRON', 'NOUN', 'ADP', 'NOUN', 'PART', 'VERB', 'PUNCT',
            'PUNCT'
        ]
        assert read_tokens(second_fields) == second_tokens
        assert second_fields["labels"].labels == second_tags
        if use_contextualizer:
            second_vectors = [
                [-0.21313506, -0.9986056],
                [-0.9670943, -1.293689],
                [-0.9337523, -0.2829439],
                [-0.14427447, -1.3481213],
                [1.0426146, -1.2611127],
                [-0.03402041, -0.90879065],
                [-2.1094723, -0.65833807],
                [-2.52652, 0.05855975],
                [-1.5565295, -0.62821376],
                [-1.016165, -0.6203798],
                [-0.5337064, -1.0520142],
                [-1.2524656, -1.2280166],
                [0.05167481, -0.63919723],
                [-1.9454485, -1.7038071],
                [0.24676055, 1.0511997],
                [-1.4455109, -2.3033257],
                [-2.0335193, -1.3011322],
                [-0.9321909, -0.09861001],
            ]
            assert_allclose(
                second_fields["token_representations"].array[:, :2],
                np.array(second_vectors),
                rtol=1e-4)