def _make_symbol_table_dict():
    """Build a fixture mapping type names to populated symbol tables.

    Returns:
      A dict with two entries, in this insertion order:
        'type_safety': a table holding the letters 'a' through 'p', frozen
          after insertion.
        'fancy_type': a table holding the characters of '1234567890',
          padded to a vocab size of 20 (not frozen).
    """
    letters_table = symbol.SymbolTable()
    for letter in 'abcdefghijklmnop':
        letters_table.insert(letter)
    letters_table.freeze()

    digits_table = symbol.SymbolTable()
    for digit in '1234567890':
        digits_table.insert(digit)
    digits_table.pad_to_vocab_size(20)

    return {
        'type_safety': letters_table,
        'fancy_type': digits_table,
    }
def test_padding(self):
    """Check that padding lifts the max id and that it is kept across reset()."""
    table = symbol.SymbolTable()
    for ch in 'abcdefg':
        table.insert(ch)
    self.assertTrue(table.has_id('a'))
    self.assertEqual(table.get_max_id(), 7)

    # Padding extends the id space to the requested vocab size.
    table.pad_to_vocab_size(20)
    self.assertEqual(table.get_max_id(), 20)

    # After reset(), the test expects the old symbols to be gone while the
    # padded max id of 20 is kept.
    table.reset()
    for ch in 'tuvwx':
        table.insert(ch)
    self.assertEqual(table.get_max_id(), 20)
    self.assertTrue(table.has_id('x'))
    self.assertFalse(table.has_id('a'))
def test_unk(self):
    """Check that freeze() adds an UNK id and that unseen strings map to it."""
    tab = symbol.SymbolTable()
    for s in 'abcdefg':
        tab.insert(s)
    self.assertEqual(tab.get_max_id(), 7)
    # No UNK id exists before the table is frozen.
    self.assertIsNone(tab.get_unk_id())

    # Freezing adds an UNK symbol, which takes the next free id (7).
    tab.freeze()
    self.assertEqual(tab.get_unk_id(), 7)
    self.assertEqual(tab.get_max_id(), 8)

    # New strings are now UNK'd out...
    self.assertEqual(tab.get_id('h'), tab.get_id('z'))
    # ...even if you insert them: the test expects a frozen table to leave
    # its size unchanged and keep mapping 'h' to the UNK id.
    tab.insert('h')
    self.assertEqual(tab.get_max_id(), 8)
    self.assertEqual(tab.get_id('h'), tab.get_id('z'))
    self.assertEqual(tab.get_id('h'), 7)
def _test_restrict_filtering(self, input_types, restrict_write, restrict_read,
                             expected_output_types):
    """Test restrict behavior on types written and read through self.filename.

    One empty SymbolTable is created for each entry of ``input_types``. The
    resulting dict is written to ``self.filename`` with
    ``restrict_to=restrict_write``. It is then read back with
    ``restrict_to=restrict_read``. Finally, the type names that survive are
    compared with ``expected_output_types``.

    Args:
      input_types: Iterable of str type names to start with. Callers may pass
        a plain string, one character per type (TODO confirm).
      restrict_write: Iterable of type names used as ``restrict_to`` when
        writing. It is converted to a list before the call.
      restrict_read: Iterable of type names used as ``restrict_to`` when
        reading. It is converted to a list before the call.
      expected_output_types: String made by concatenating the expected
        returned type names in sorted order. Note that this is a string, not
        a list.
    """
    symbol_table_dict = {t: symbol.SymbolTable() for t in input_types}
    io.write_symbol_table_dict(
        self.filename, symbol_table_dict, restrict_to=list(restrict_write))
    io_symbol_table_dict = io.read_symbol_table_dict(
        self.filename, restrict_to=list(restrict_read))
    # Join the sorted names so the result can be compared to a plain string.
    io_output_types = ''.join(sorted(io_symbol_table_dict.keys()))
    self.assertEqual(io_output_types, expected_output_types)
def test_fixed_freeze_none(self):
    """Check that after freeze(unknown_marker=None), unseen strings give None."""
    tab = symbol.SymbolTable()
    for s in 'abcdefg':
        tab.insert(s)
    tab.freeze(unknown_marker=None)
    # 'Z' was never inserted, so the lookup should give None, not an UNK id.
    self.assertIsNone(tab.get_id('Z'))