def test_annotation_num(self):
    """Check the vocabulary and action sequences built by AnnotationNumberizer.

    The numberizer is initialized on the training rows of
    ``compositional_seq2seq_unit.tsv``; the test then verifies the vocab
    size, the shift/reduce indices, the derived index lists, and finally
    that each training row numberizes to its entry in ``EXPECTED_ACTIONS``.
    """
    unit_file = SafeFileWrapper(
        tests_module.test_file("compositional_seq2seq_unit.tsv")
    )
    data_source = TSVDataSource(
        unit_file,
        test_file=None,
        eval_file=None,
        field_names=["text", "seqlogical"],
        schema={"text": str, "seqlogical": str},
    )
    numberizer = AnnotationNumberizer()

    # Drive the initializer by hand: prime it with None, feed it every
    # training row, then close it.
    initializer = numberizer.initialize()
    initializer.send(None)
    for train_row in data_source.train:
        initializer.send(train_row)
    initializer.close()

    # Vocabulary the assertions below are written against:
    #   {'IN:GET_INFO_TRAFFIC': 0, 'SHIFT': 1, 'SL:LOCATION': 2, 'REDUCE': 3,
    #    'IN:GET_DIRECTIONS': 4, 'SL:DESTINATION': 5, 'SL:SOURCE': 6,
    #    'IN:GET_LOCATION_HOME': 7, 'SL:CONTACT': 8,
    #    'SL:DATE_TIME_DEPARTURE': 9, 'IN:UNSUPPORTED_NAVIGATION': 10,
    #    'IN:GET_ESTIMATED_DURATION': 11, 'IN:GET_LOCATION_WORK': 12,
    #    'SL:PATH_AVOID': 13, 'IN:GET_DISTANCE': 14}
    self.assertEqual(15, len(numberizer.vocab))
    self.assertEqual(1, numberizer.shift_idx)
    self.assertEqual(3, numberizer.reduce_idx)
    self.assertEqual([10], numberizer.ignore_subNTs_roots)

    # Every index except SHIFT (1) and REDUCE (3).
    expected_nt_idxs = [0, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
    self.assertEqual(expected_nt_idxs, numberizer.valid_NT_idxs)
    # Indices of the 'IN:' and 'SL:' entries of the vocab, respectively.
    self.assertEqual([0, 4, 7, 10, 11, 12, 14], numberizer.valid_IN_idxs)
    self.assertEqual([2, 5, 6, 8, 9, 13], numberizer.valid_SL_idxs)

    # Second pass over the training rows: compare numberized output
    # pairwise against the expected action sequences.
    for train_row, expected_actions in zip(data_source.train, EXPECTED_ACTIONS):
        self.assertEqual(expected_actions, numberizer.numberize(train_row))
class ModelInput(BaseModel.Config.ModelInput):
    """Tensorizer configuration for the two model inputs.

    Extends ``BaseModel.Config.ModelInput`` with a ``tokens`` field and an
    ``actions`` field.
    """

    # Token tensorizer config, reading from the "tokenized_text" column.
    tokens: TokenTensorizer.Config = TokenTensorizer.Config(
        column="tokenized_text",
    )
    # Annotation numberizer config, left at its defaults.
    actions: AnnotationNumberizer.Config = AnnotationNumberizer.Config()