コード例 #1
0
    def test_annotation_num(self):
        # Initializes an AnnotationNumberizer from the training rows of
        # compositional_seq2seq_unit.tsv, checks its index attributes against
        # fixed expectations, then compares the numberized output of every
        # training row with the matching EXPECTED_ACTIONS entry.
        fixture = SafeFileWrapper(
            tests_module.test_file("compositional_seq2seq_unit.tsv")
        )
        source = TSVDataSource(
            fixture,
            test_file=None,
            eval_file=None,
            field_names=["text", "seqlogical"],
            schema={"text": str, "seqlogical": str},
        )

        numberizer = AnnotationNumberizer()
        # initialize() is driven through send()/close(): prime it with None,
        # feed it each training row, then close it.
        feeder = numberizer.initialize()
        feeder.send(None)
        for train_row in source.train:
            feeder.send(train_row)
        feeder.close()

        # Vocab the assertions below are written against:
        #   0: IN:GET_INFO_TRAFFIC         1: SHIFT
        #   2: SL:LOCATION                 3: REDUCE
        #   4: IN:GET_DIRECTIONS           5: SL:DESTINATION
        #   6: SL:SOURCE                   7: IN:GET_LOCATION_HOME
        #   8: SL:CONTACT                  9: SL:DATE_TIME_DEPARTURE
        #  10: IN:UNSUPPORTED_NAVIGATION  11: IN:GET_ESTIMATED_DURATION
        #  12: IN:GET_LOCATION_WORK       13: SL:PATH_AVOID
        #  14: IN:GET_DISTANCE
        expected_nt_idxs = [0, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
        expected_in_idxs = [0, 4, 7, 10, 11, 12, 14]
        expected_sl_idxs = [2, 5, 6, 8, 9, 13]

        self.assertEqual(15, len(numberizer.vocab))
        self.assertEqual(1, numberizer.shift_idx)
        self.assertEqual(3, numberizer.reduce_idx)
        self.assertEqual([10], numberizer.ignore_subNTs_roots)
        self.assertEqual(expected_nt_idxs, numberizer.valid_NT_idxs)
        self.assertEqual(expected_in_idxs, numberizer.valid_IN_idxs)
        self.assertEqual(expected_sl_idxs, numberizer.valid_SL_idxs)

        # Second pass over the training rows: numberize each one and compare.
        for train_row, expected_actions in zip(source.train, EXPECTED_ACTIONS):
            self.assertEqual(expected_actions, numberizer.numberize(train_row))
コード例 #2
0
 class ModelInput(BaseModel.Config.ModelInput):
     # Token input: a TokenTensorizer config pointed at the "tokenized_text" column.
     tokens: TokenTensorizer.Config = TokenTensorizer.Config(column="tokenized_text")
     # Action input: an AnnotationNumberizer config built with no arguments.
     actions: AnnotationNumberizer.Config = AnnotationNumberizer.Config()