Пример #1
0
def _make_symbol_table_dict():
    """Build a two-entry dict of symbol tables keyed by type name.

    The 'type_safety' table receives the letters 'a' through 'p' and is then
    frozen. The 'fancy_type' table receives the ten digit characters and is
    then padded to a vocabulary size of 20 instead of being frozen.

    Returns:
      A dict mapping each type name to its symbol.SymbolTable.
    """
    # Letters table: fill, then freeze.
    letters_table = symbol.SymbolTable()
    for letter in 'abcdefghijklmnop':
        letters_table.insert(letter)
    letters_table.freeze()

    # Digits table: fill, then pad.
    digits_table = symbol.SymbolTable()
    for digit in '1234567890':
        digits_table.insert(digit)
    digits_table.pad_to_vocab_size(20)

    return {'type_safety': letters_table, 'fancy_type': digits_table}
Пример #2
0
 def test_padding(self):
     """Check max id and membership across padding and a reset."""
     table = symbol.SymbolTable()
     for letter in 'abcdefg':
         table.insert(letter)
     self.assertTrue(table.has_id('a'))
     self.assertEqual(table.get_max_id(), 7)

     # After padding, the reported max id is expected to be the padded size.
     table.pad_to_vocab_size(20)
     self.assertEqual(table.get_max_id(), 20)

     # After a reset and fresh inserts, the test expects the max id to
     # remain 20 and the symbols inserted before the reset to be gone.
     table.reset()
     for letter in 'tuvwx':
         table.insert(letter)
     self.assertEqual(table.get_max_id(), 20)
     self.assertTrue(table.has_id('x'))
     self.assertFalse(table.has_id('a'))
Пример #3
0
 def test_unk(self):
     """Check that freezing adds an UNK id which absorbs unseen symbols."""
     tab = symbol.SymbolTable()
     for s in 'abcdefg':
         tab.insert(s)
     self.assertEqual(tab.get_max_id(), 7)
     # No UNK id is expected before freezing. assertIsNone checks identity
     # and reports a clearer failure than comparing against None.
     self.assertIsNone(tab.get_unk_id())
     # freezing adds an UNK symbol
     tab.freeze()
     self.assertEqual(tab.get_unk_id(), 7)
     self.assertEqual(tab.get_max_id(), 8)
     # new strings are now UNK'd out
     self.assertEqual(tab.get_id('h'), tab.get_id('z'))
     # even if you insert them: the insert is expected to leave the max id
     # unchanged and 'h' still mapped to the UNK id
     tab.insert('h')
     self.assertEqual(tab.get_max_id(), 8)
     self.assertEqual(tab.get_id('h'), tab.get_id('z'))
     self.assertEqual(tab.get_id('h'), 7)
Пример #4
0
    def _test_restrict_filtering(self, input_types, restrict_write,
                                 restrict_read, expected_output_types):
        """Round-trip empty symbol tables through self.filename with restrictions.

        One empty SymbolTable is created per entry of input_types, the dict
        is written with restrict_write, read back with restrict_read, and
        the keys that come back are compared with expected_output_types.

        Args:
          input_types: Iterable of type names to start with.
          restrict_write: Iterable of type names used as restrict_to when
            writing.
          restrict_read: Iterable of type names used as restrict_to when
            reading.
          expected_output_types: String that the sorted, concatenated type
            names read back must equal.
        """
        tables_by_type = {}
        for type_name in input_types:
            tables_by_type[type_name] = symbol.SymbolTable()

        io.write_symbol_table_dict(
            self.filename, tables_by_type, restrict_to=list(restrict_write))
        tables_read_back = io.read_symbol_table_dict(
            self.filename, restrict_to=list(restrict_read))

        # Sort the returned type names so the comparison is order-independent.
        returned_names = sorted(tables_read_back.keys())
        self.assertEqual(''.join(returned_names), expected_output_types)
Пример #5
0
 def test_fixed_freeze_none(self):
     """Check that an unseen symbol has no id after freezing without UNK."""
     tab = symbol.SymbolTable()
     for s in 'abcdefg':
         tab.insert(s)
     # Freeze with the unknown marker explicitly disabled.
     tab.freeze(unknown_marker=None)
     # The test expects a never-inserted symbol to get no id at all.
     # assertIsNone checks identity and gives a clearer failure message
     # than comparing against None.
     self.assertIsNone(tab.get_id('Z'))