Example #1
    def test_words_with_shared_prefix_should_retain_counts(self):
        print('  - %s' % inspect.stack()[0][3])
        sym_spell = SymSpell(16, 1, 3)
        sym_spell.create_dictionary_entry("pipe", 5)
        sym_spell.create_dictionary_entry("pips", 10)

        result = sym_spell.lookup("pipe", Verbosity.ALL, 1)
        self.assertEqual(2, len(result))
        self.assertEqual("pipe", result[0].term)
        self.assertEqual(5, result[0].count)
        self.assertEqual("pips", result[1].term)
        self.assertEqual(10, result[1].count)

        result = sym_spell.lookup("pips", Verbosity.ALL, 1)
        self.assertEqual(2, len(result))
        self.assertEqual("pips", result[0].term)
        self.assertEqual(10, result[0].count)
        self.assertEqual("pipe", result[1].term)
        self.assertEqual(5, result[1].count)

        result = sym_spell.lookup("pip", Verbosity.ALL, 1)
        self.assertEqual(2, len(result))
        self.assertEqual("pips", result[0].term)
        self.assertEqual(10, result[0].count)
        self.assertEqual("pipe", result[1].term)
        self.assertEqual(5, result[1].count)
Example #2
def initializeSymspell():
    print("inside initializeSymspell()")
    symspell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
    print("symspell created")
    resourceNames = [
        "symspellpy", "frequency_dictionary_en_82_765.txt",
        "frequency_bigramdictionary_en_243_342.txt"
    ]
    dictionaryPath = pkg_resources.resource_filename(resourceNames[0],
                                                     resourceNames[1])
    bigramPath = pkg_resources.resource_filename(resourceNames[0],
                                                 resourceNames[2])
    print("dictionaryPath created")
    symspell.load_dictionary(dictionaryPath, 0, 1)
    symspell.create_dictionary_entry(key='ap', count=500000000)
    print(list(islice(symspell.words.items(), 5)))
    print("symspell.load_ditionary() done")
    symspell.load_bigram_dictionary(bigramPath, 0, 1)
    print(list(islice(symspell.bigrams.items(), 5)))
    print("symspell.load_bigram_ditionary() done")

    # Create vocab
    vocab = set(symspell.words)

    return symspell, vocab
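
For context, here is a minimal usage sketch of the objects returned by initializeSymspell() above. The sample phrases are illustrative; lookup and lookup_compound are the standard symspellpy calls, and Verbosity is assumed to be imported from symspellpy alongside SymSpell.

from symspellpy import Verbosity

symspell, vocab = initializeSymspell()

# Single-token correction against the unigram dictionary loaded above
suggestions = symspell.lookup("memebers", Verbosity.CLOSEST, max_edit_distance=2)
print(suggestions[0].term if suggestions else "no suggestion")

# Multi-word correction using the bigram dictionary loaded above
compound = symspell.lookup_compound("whereis th elove", max_edit_distance=2)
print(compound[0].term)

# The vocabulary set can be used for simple membership checks
print("love" in vocab)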
Example #3
class Autocorrect:
    def __init__(self, words=None, max_edit_distance=2):
        self._symspell = SymSpell()
        self._max_edit_distance = max_edit_distance
        if words is not None:
            self.add_words(words)

    def add_word(self, word):
        if word is not None:
            self._symspell.create_dictionary_entry(word, 1)

    def add_words(self, words):
        if words is not None:
            self._symspell.create_dictionary(words)

    def delete_word(self, word):
        if word is not None:
            self._symspell.delete_dictionary_entry(word)

    def correct(self, bad_word):
        return self._symspell.lookup(bad_word,
                                     Verbosity.TOP,
                                     max_edit_distance=self._max_edit_distance,
                                     include_unknown=True)[0].term

    def predictions(self, bad_word):
        return self._symspell.lookup(bad_word,
                                     Verbosity.CLOSEST,
                                     max_edit_distance=self._max_edit_distance,
                                     include_unknown=True)
Example #4
    def _create_symspell_checker(self, language: AnyStr) -> SymSpell:
        """Private method to create a SymSpell instance for a given language

        Args:
            language: Language code in ISO 639-1 format

        Returns:
            SymSpell checker instance loaded with the language dictionary

        """
        start = perf_counter()
        logging.info(f"Loading spellchecker for language '{language}'...")
        symspell_checker = SymSpell(
            max_dictionary_edit_distance=self.edit_distance)
        frequency_dict_path = self.dictionary_folder_path + "/" + language + ".txt"
        symspell_checker.load_dictionary(frequency_dict_path,
                                         term_index=0,
                                         count_index=1,
                                         encoding="utf-8")
        if len(self.custom_vocabulary_set) != 0:
            for word in self.custom_vocabulary_set:
                symspell_checker.create_dictionary_entry(key=word, count=1)
        logging.info(
            f"Loading spellchecker for language '{language}': done in {perf_counter() - start:.2f} seconds"
        )
        return symspell_checker
Example #5
 def test_lookup_should_not_return_non_word_delete(self):
     print('  - %s' % inspect.stack()[0][3])
     sym_spell = SymSpell(16, 2, 7, 10)
     sym_spell.create_dictionary_entry("pawn", 10)
     result = sym_spell.lookup("paw", Verbosity.TOP, 0)
     self.assertEqual(0, len(result))
     result = sym_spell.lookup("awn", Verbosity.TOP, 0)
     self.assertEqual(0, len(result))
Example #6
 def test_lookup_should_not_return_low_count_word_that_are_also_delete_word(
         self):
     print('  - %s' % inspect.stack()[0][3])
     sym_spell = SymSpell(16, 2, 7, 10)
     sym_spell.create_dictionary_entry("flame", 20)
     sym_spell.create_dictionary_entry("flam", 1)
     result = sym_spell.lookup("flam", Verbosity.TOP, 0)
     self.assertEqual(0, len(result))
Example #7
 def test_lookup_should_find_exact_match(self):
     print('  - %s' % inspect.stack()[0][3])
     sym_spell = SymSpell()
     sym_spell.create_dictionary_entry("steama", 4)
     sym_spell.create_dictionary_entry("steamb", 6)
     sym_spell.create_dictionary_entry("steamc", 2)
     result = sym_spell.lookup("streama", Verbosity.TOP, 2)
     self.assertEqual(1, len(result))
     self.assertEqual("steama", result[0].term)
Example #8
    def test_add_additional_counts_should_not_add_word_again(self):
        print('  - %s' % inspect.stack()[0][3])
        sym_spell = SymSpell()
        word = "hello"
        sym_spell.create_dictionary_entry(word, 11)
        self.assertEqual(1, sym_spell.word_count)

        sym_spell.create_dictionary_entry(word, 3)
        self.assertEqual(1, sym_spell.word_count)
Example #9
 def test_lookup_should_return_most_frequent(self):
     print('  - %s' % inspect.stack()[0][3])
     sym_spell = SymSpell()
     sym_spell.create_dictionary_entry("steama", 4)
     sym_spell.create_dictionary_entry("steamb", 6)
     sym_spell.create_dictionary_entry("steamc", 2)
     result = sym_spell.lookup("stream", Verbosity.TOP, 2)
     self.assertEqual(1, len(result))
     self.assertEqual("steamb", result[0].term)
     self.assertEqual(6, result[0].count)
Example #10
    def test_add_additional_counts_should_not_overflow(self):
        print('  - %s' % inspect.stack()[0][3])
        sym_spell = SymSpell()
        word = "hello"
        sym_spell.create_dictionary_entry(word, sys.maxsize - 10)
        result = sym_spell.lookup(word, Verbosity.TOP)
        count = result[0].count if len(result) == 1 else 0
        self.assertEqual(sys.maxsize - 10, count)

        sym_spell.create_dictionary_entry(word, 11)
        result = sym_spell.lookup(word, Verbosity.TOP)
        count = result[0].count if len(result) == 1 else 0
        self.assertEqual(sys.maxsize, count)
Example #11
    def test_add_additional_counts_should_increase_count(self):
        print('  - %s' % inspect.stack()[0][3])
        sym_spell = SymSpell()
        word = "hello"
        sym_spell.create_dictionary_entry(word, 11)
        result = sym_spell.lookup(word, Verbosity.TOP)
        count = result[0].count if len(result) == 1 else 0
        self.assertEqual(11, count)

        sym_spell.create_dictionary_entry(word, 3)
        result = sym_spell.lookup(word, Verbosity.TOP)
        count = result[0].count if len(result) == 1 else 0
        self.assertEqual(11 + 3, count)
Example #12
def get_pubmedc_spellcheck():
    global pubmed_dir,pubmedc_url,pubmedc_spell,did_dir
    pubmedc_spell = load_pkl(fdir=pubmed_dir, f='pubmedc_spell')


    if pubmedc_spell is None:
        pubmedc_spell = SymSpell(initial_capacity, max_edit_distance_dictionary, prefix_length)

        fs = requests.get(pubmedc_url).content.decode('utf-8').split()
        fs = [x for x in fs if '1-grams' in x]

        with open(os.path.join(did_dir,'medwords.pkl'),'rb') as f:
            medwords = pickle.load(f)

        for i, f in enumerate(fs):
            r = requests.get(f)
            # vocab = defaultdict(int)
            with gzip.open(BytesIO(r.content), 'rb') as zfile:
                # Read every line; valid entries have four tab-separated fields.
                for line in zfile:
                    tmp = line.decode('utf-8').split('\t')
                    if tmp and len(tmp) == 4:
                        word, freq = tmp[0].lower(), int(tmp[2])
                        if '=' in word:
                            continue
                        if len(word) < 3 or len(word) > 45:
                            continue
                        if freq < 2:
                            continue
                        if word in medwords:
                            pubmedc_spell.create_dictionary_entry(word, freq)
                            print('added \'{}\''.format(word))
                    else:
                        print('skipped malformed line')
                        # if len(tmp) >= 2 and tmp not in glookup:
                        #     print('processing \'{}\''.format(tmp))
                        #     cur.append(tmp)
                print('added file {}: {}'.format(i, f))
        # save_pkl(fdir=pubmed_dir,f='pubmedc_spell',obj=pubmedc_spell)

    return pubmedc_spell
Example #13
    async def quote(self, message, args):
        msg = None
        if try_parse_int64(args[0]) is not None:
            msg_id = args[0]
            try:
                msg = await self.client.get_message(message.channel.id, msg_id)
            except Exception as exception:  # pylint: disable=W0703
                LOG.exception(exception)
        else:
            input_term = args[0]
            sym_spell = SymSpell()
            for term in input_term.split(" "):
                sym_spell.create_dictionary_entry(term, 1)
            target = sym_spell.lookup_compound(input_term, 2)[0].term
            iterator = message.channel.history(limit=100)
            for __ in range(100):
                try:
                    candidate = await iterator.next()
                    suggestion = sym_spell.lookup_compound(candidate.content, 2)[0]
                    if suggestion.term == target:
                        msg = await self.client.get_message(
                            message.channel.id, candidate.id)
                        break
                except NoMoreItems:
                    break
        if msg is not None:
            display_name = message.guild.get_member(int(
                msg["author"]["id"])).display_name
            time_str = (
                datetime.strptime(msg["timestamp"].split(".")[0],
                                  "%Y-%m-%dT%H:%M:%S") +
                timedelta(hours=TZ_OFFSET)).strftime("%Y-%m-%d %I:%M %p")
            quote_msg = "```{} - {} UTC+{}\n{}```".format(
                display_name, time_str, TZ_OFFSET, msg["content"])
        else:
            quote_msg = "Message not found!"

        await self.client.send_message(message.channel.id, quote_msg)
        await message.delete()
Example #14
    def test_verbosity_should_control_lookup_results(self):
        print('  - %s' % inspect.stack()[0][3])
        sym_spell = SymSpell()
        sym_spell.create_dictionary_entry("steam", 1)
        sym_spell.create_dictionary_entry("steams", 2)
        sym_spell.create_dictionary_entry("steem", 3)

        result = sym_spell.lookup("steems", Verbosity.TOP, 2)
        self.assertEqual(1, len(result))
        result = sym_spell.lookup("steems", Verbosity.CLOSEST, 2)
        self.assertEqual(2, len(result))
        result = sym_spell.lookup("steems", Verbosity.ALL, 2)
        self.assertEqual(3, len(result))
Example #15
class HeuristicAbbreviationsCompounds(HeuristicPunctuation):
    """
    Generate an abbreviated version of an input string based on its compounds.

    The abbreviation heuristic is far more complex than the other heuristics, because it is not known
    in advance whether the query mention or the reference entity is the abbreviation.
    Example:    Given the mention "embedded subscriber identity module" and the reference entity "eSIM",
                the system does not know whether the mention is already an abbreviation or not.
                Instead, two possible cases are assumed:
                1) The mention is already an abbreviation
                2) The mention is not an abbreviation yet

                For the first case we have to assume that the reference entity is NOT an abbreviation (otherwise a
                prior heuristic should have found a match already).
                The same holds for the second case: we have to assume that the reference entity IS an abbreviation.
                To check both cases, the reference entities are needed in two versions: original (without any
                stemming, only punctuation removal) and abbreviated.

                This heuristic checks both cases and returns the better result of the two.
                (A short, illustrative usage sketch follows the class definition below.)
    """
    def __init__(self,
                 max_edit_distance_dictionary: int = 0,
                 prefix_length: int = 10,
                 count_threshold: int = 1,
                 compact_level: int = 5,
                 prob_threshold: float = 0.1):
        """
        :param prob_threshold: Threshold used for the compound splitter. CharSplit returns a probability for how
        plausible a compound split of a word is; all splits with a lower probability are discarded. This threshold
        is also the stopping condition for the recursive compound splitting.

        :param max_edit_distance_dictionary: A threshold to ensure that the matched abbreviation is as similar as
        possible to the reference entity.
        """
        super().__init__(max_edit_distance_dictionary, prefix_length,
                         count_threshold, compact_level)
        assert prob_threshold > 0
        self._prop_threshold = prob_threshold
        self.max_edit_distance_dictionary = max_edit_distance_dictionary

        # The symspell dictionary and mapping for the unrefactored entities (case 1)
        self._original_sym_speller = None
        self._original_rule_mapping = None

    def name(self):
        return "abbreviations"

    def initialize_sym_speller(self):
        super().initialize_sym_speller()
        self._original_sym_speller = SymSpell(
            self.max_edit_distance_dictionary, self.prefix_length,
            self.count_threshold, self.compact_level)
        self._original_rule_mapping = {}

    def _split(self, s: str) -> List[str]:
        return split_compounds(s, prop_threshold=self._prop_threshold)

    def _refactor(self, s: str) -> str:
        s = super()._refactor(s)
        compounds = self._split(s)

        # Create an abbreviation from the compounds. However, pure digit compounds are kept as is.
        abbreviation = "".join([
            compound if compound.isdigit() else compound[:1].capitalize()
            for compound in compounds
        ])

        return abbreviation

    def lookup(self,
               mention: str,
               original_mention: str = "") -> Tuple[List[SuggestItem], str]:
        """
        The abbreviation heuristic only uses the original mention instead of a previously refactored mention.
        This is because previous refactoring might affect the compound splitting.
        """
        original_mention = str(original_mention)

        # Case1: the mention is the abbreviation, the original entity is not known as abbreviation
        # I only do a punctuation refactoring on the mention and then match against the abbreviation refactored entities
        case1_refactored_mention = super()._refactor(original_mention)
        case1_suggestions = self._sym_speller.lookup(case1_refactored_mention,
                                                     Verbosity.CLOSEST)
        case1_distance = 99999
        for suggestion in case1_suggestions:
            suggestion.reference_entities = self._rule_mapping[suggestion.term]
            case1_distance = suggestion.distance

        # Case2: the mention is currently not an abbreviation but the original entity is only known as abbreviation
        # I do the abbreviation refactoring on the mention and then match against the punctuation refactored entities
        case2_refactored_mention = self._refactor(original_mention)
        case2_suggestions = self._original_sym_speller.lookup(
            case2_refactored_mention, Verbosity.CLOSEST)
        case2_distance = 99999
        for suggestion in case2_suggestions:
            suggestion.reference_entities = self._original_rule_mapping[
                suggestion.term]
            case2_distance = suggestion.distance

        # Only return the better result
        if case1_distance < case2_distance:
            return case1_suggestions, case1_refactored_mention
        else:
            return case2_suggestions, case2_refactored_mention

    def add_dictionary_entity(self,
                              entity: str,
                              original_entity: str = "") -> str:
        """
        This heuristic needs the refactored entity as well as the original entity in its dictionary because it is
        possible, that either the mention or the reference entity is the abbreviation.
        """
        original_entity = str(original_entity)

        # Add the entities in abbreviation form to the own sym spell dictionary
        abbreviation = self._refactor(original_entity)
        self._sym_speller.create_dictionary_entry(abbreviation, 1)
        if abbreviation not in self._rule_mapping:
            self._rule_mapping[abbreviation] = {original_entity}
        else:
            self._rule_mapping[abbreviation].update({original_entity})

        # Because the original entity could already be the abbreviation, we also want to save the original one.
        # Do the same for the mapping and apply the punctuation heuristic to the original entity.
        punctuation_refactored_entity = super()._refactor(original_entity)
        self._original_sym_speller.create_dictionary_entry(
            punctuation_refactored_entity, 1)
        if punctuation_refactored_entity not in self._original_rule_mapping:
            self._original_rule_mapping[punctuation_refactored_entity] = {
                original_entity
            }
        else:
            self._original_rule_mapping[punctuation_refactored_entity].update(
                {original_entity})

        return abbreviation
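
To make the two-case matching described in the class docstring concrete, here is a small, self-contained sketch using plain symspellpy. The abbreviate helper, the two SymSpell indexes, and the way entities are stored are illustrative stand-ins rather than part of the class above, and the real compound splitting (CharSplit) is replaced by simple whitespace splitting; the entity/mention pair comes from the docstring example.

from symspellpy import SymSpell, Verbosity

# Index of entities stored in abbreviated form (case 1: the mention is already an abbreviation).
abbrev_index = SymSpell(max_dictionary_edit_distance=0)
# Index of entities stored in their cleaned original form (case 2: the mention still has to be abbreviated).
original_index = SymSpell(max_dictionary_edit_distance=0)


def abbreviate(text: str) -> str:
    # Keep digit tokens as they are; otherwise take the first letter of each token (lowercased for matching).
    return "".join(token if token.isdigit() else token[:1] for token in text.lower().split())


entity = "eSIM"
abbrev_index.create_dictionary_entry(abbreviate(entity), 1)    # stores "e"
original_index.create_dictionary_entry(entity.lower(), 1)      # stores "esim"

mention = "embedded subscriber identity module"
# Case 1: assume the mention is already an abbreviation and match it against the abbreviated
# entities (no hit here, because this mention is the long form).
case1 = abbrev_index.lookup(mention.lower(), Verbosity.CLOSEST)
# Case 2: abbreviate the mention and match it against the original entities (this case fires for this pair).
case2 = original_index.lookup(abbreviate(mention), Verbosity.CLOSEST)

print([s.term for s in case1])    # []
print([s.term for s in case2])    # ['esim']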
Example #16
class Heuristic(metaclass=ABCMeta):
    def __init__(self,
                 max_edit_distance_dictionary: int = 5,
                 prefix_length: int = 10,
                 count_threshold: int = 1,
                 compact_level: int = 5):
        """
        Note: a lower max_edit_distance combined with a higher prefix_length (roughly 2 * max_edit_distance) makes lookups faster.
        """
        self.max_edit_distance_dictionary = max_edit_distance_dictionary
        self.prefix_length = prefix_length
        self.count_threshold = count_threshold
        self.compact_level = compact_level

        self._sym_speller = None
        self._rule_mapping = {}

    @property
    @abstractmethod
    def name(self):
        pass

    @abstractmethod
    def _refactor(self, s: str) -> str:
        pass

    def initialize_sym_speller(self):
        self._sym_speller = SymSpell(self.max_edit_distance_dictionary,
                                     self.prefix_length, self.count_threshold,
                                     self.compact_level)
        self._rule_mapping = {}

    def lookup(self,
               mention: str,
               original_mention: str = "") -> Tuple[List[SuggestItem], str]:
        """
        The original mention is optional here because it is only necessary for very specific heuristics.
        In the general case, it will be ignored and a previously refactored mention will be further refactored here.
        """
        mention = str(mention)

        refactored_mention = self._refactor(mention)
        suggestions = self._sym_speller.lookup(refactored_mention,
                                               Verbosity.CLOSEST)

        # Look up the unrefactored entities for the suggestions and save them
        for suggestion in suggestions:
            suggestion.reference_entities = self._rule_mapping[suggestion.term]

        return suggestions, refactored_mention

    def add_dictionary_entity(self,
                              entity: str,
                              original_entity: str = "") -> str:
        """
        Adds an entity to the sym spell dictionary of this heuristic.
        The original entity is usually not necessary except for very specific heuristics that rely on the original,
        untouched entity. Otherwise, the previously refactored entity is used here.
        """
        entity = str(entity)

        # Apply the current heuristic
        refactored_entity = self._refactor(entity)

        # Save the refactored entity in the heuristic symspeller + the mapping to the untouched entity
        self._sym_speller.create_dictionary_entry(refactored_entity, 1)
        if refactored_entity not in self._rule_mapping:
            self._rule_mapping[refactored_entity] = {original_entity}
        else:
            self._rule_mapping[refactored_entity].update({original_entity})

        return refactored_entity
Example #17
class Corrector(Dictionary):

	def __init__(self):
		self.verbosity = Verbosity.ALL
		self.symspell = SymSpell(max_dictionary_edit_distance=3, count_threshold=0)

	def load_symspell(self):
		for k,v in tqdm(self.uni_dict.items(), 'Loading symspell...'):
			self.symspell.create_dictionary_entry(key=k, count=v)

	def _gen_character_candidates(self, character):
		for line in self.diacritic:
			if character in line:
				return line
		return [character]

	def _gen_word_candidates(self, word):
		cands = [word]
		for i in range(len(word)):
			can_list = self._gen_character_candidates(word[i])
			cands += [word[:i] + c + rest
						for c in can_list
						for rest in self._gen_word_candidates(word[i+1:])]
		return cands

	def _gen_diacritic_candidates(self, term):
		if term == '<num>':
			return [term]
		cands = list(set(self._gen_word_candidates(term)))
		cands = sorted(cands, key=lambda x: -self._c1w(x))[:min(10, len(cands))]
		return cands

	def gen_states(self, obs):
		states = {}
		new_obs = []
		for ob in obs:
			states[ob] = []
			if len(ob)==1 and re.match(r"[^aáàảãạăằắẳẵặâầấẩẫậeèéẻẽẹêềếểễệiìíỉĩịoóòỏõọôồổỗộơờớởỡợuùúũụưứửữựyỳýỷỹỵ]", ob):
				continue
			new_obs += [ob]

			if len(ob)<=4: edit_distance=1
			elif len(ob)>4 and len(ob)<=10: edit_distance=2
			elif len(ob)>10: edit_distance=3

			states[ob] = sorted([c.term for c in self.symspell.lookup(
				phrase=ob,
				verbosity=self.verbosity,
				max_edit_distance=edit_distance,
				include_unknown=False,
				ignore_token=r'<num>'
			)],
			key=lambda x: -self._c1w(x))[:30]
		
		return new_obs, states
	
	def _trans(self, cur, prev, prev_prev=None):
		if prev_prev is None:
			return self.cpw(cur, prev)
		else:
			return self.cp3w(cur, prev, prev_prev)

	def _emiss(self, cur_observe, cur_state):
		return self.words_similarity(cur_observe, cur_state) +\
				self.pw(cur_state)

	def _viterbi_decoder(self, obs, states):
		V = [{}]
		path = {}

		for st in states[obs[0]]:
			V[0][st] = 1.0
			path[st] = [st]

		for i in range(1, len(obs)):
			V.append({})
			new_path = {}

			for st in states[obs[i]]:
				# if i==1:
				prob, state = max([
					(V[i-1][prev_st]*self._trans(st, prev_st)*self._emiss(obs[i], st), prev_st)
					for prev_st in states[obs[i-1]]
				])

				# else:
				# 	prob, state = max([
				# 		(V[i-1][prev_st]*self._trans(st, prev_st, prev_prev_st)*self._emiss(obs[i], st), prev_st)
				# 		for prev_st in states[obs[i-1]]
				# 		for prev_prev_st in states[obs[i-2]]
				# 	])

				V[i][st] = prob
				new_path[st] = path[state] + [st]

			path = new_path

		with open('data/viterbi_tracking.txt', 'w+', encoding='utf-8') as writer:
			# writer.write(json.dumps(path, indent=4, ensure_ascii=False) + '\n')
			writer.write(json.dumps(V, indent=4, ensure_ascii=False) + '\n')

		prob, state = max([(V[-1][st], st) for st in states[obs[-1]]])

		return {
			"prob": prob,
			"result": ' '.join(path[state])
		}

	def correct(self, text):
		# obs = ['<START>'] + text.split() + ['<END>']
		obs = text.split()
		obs, states = self.gen_states(obs)
		result = self._viterbi_decoder(obs, states)
		# return [" ".join(r[1:-1]) for r in result]
		return result
Example #18
class SpellCheck:
    def __init__(self, progress, directory, countries_dict):
        self.progress = progress
        self.logger = logging.getLogger(__name__)
        self.spelling_update = Counter()
        self.directory = directory
        self.spell_path = os.path.join(self.directory, 'spelling.pkl')
        self.countries_dict = countries_dict
        self.sym_spell = SymSpell()

    def insert(self, name, iso):
        if 'gothland cemetery' not in name and name not in noise_words:
            name_tokens = name.split(' ')
            for word in name_tokens:
                if len(word) > 2:
                    self.spelling_update[word] += 1

    def write(self):
        # Create blank spelling dictionary
        path = os.path.join(self.directory, 'spelling.tmp')
        with open(path, 'w') as fl:
            fl.write('the,1\n')
        success = self.sym_spell.create_dictionary(corpus=path)
        if not success:
            self.logger.error(f"error creating spelling dictionary")

        self.logger.info('Building Spelling Dictionary')

        # Add all words from geonames into spelling dictionary
        for key in self.spelling_update:
            self.sym_spell.create_dictionary_entry(
                key=key, count=self.spelling_update[key])

        self.logger.info('Writing Spelling Dictionary')
        self.sym_spell.save_pickle(self.spell_path)

    def read(self):
        success = False
        if os.path.exists(self.spell_path):
            self.logger.info(
                f'Loading Spelling Dictionary from {self.spell_path}')
            success = self.sym_spell.load_pickle(self.spell_path)
        else:
            self.logger.error(
                f"spelling dictionary not found: {self.spell_path}")

        if not success:
            self.logger.error(
                f"error loading spelling dictionary from {self.spell_path}")
        else:
            self.sym_spell.delete_dictionary_entry(key='gothland')

        size = len(self.sym_spell.words)
        self.logger.info(f"Spelling Dictionary contains {size} words")

    def lookup(self, input_term):
        # suggestions is a list of symspellpy SuggestItem objects
        if '*' in input_term:
            return input_term
        res = ''
        if len(input_term) > 1:
            suggestions = self.sym_spell.lookup(input_term,
                                                Verbosity.CLOSEST,
                                                max_edit_distance=2,
                                                include_unknown=True)
            for idx, item in enumerate(suggestions):
                if idx > 3:
                    break
                #self.logger.debug(f'{item._term}')
                if item._term[0] == input_term[0]:
                    # Only accept results where first letter matches
                    res += item._term + ' '
            return res
        else:
            return input_term

    def lookup_compound(self, phrase):
        suggestions = self.sym_spell.lookup_compound(phrase=phrase,
                                                     max_edit_distance=2,
                                                     ignore_non_words=False)
        for item in suggestions:
            self.logger.debug(f'{item._term}')
        return suggestions[0]._term

    def fix_spelling(self, text):
        new_text = text
        if bool(re.search(r'\d', text)):
            # Has digits, just return text, no spellcheck
            pass
        elif 'st ' in text:
            # Spellcheck not handling St properly
            pass
        else:
            if len(text) > 0:
                new_text = self.lookup(text)
                self.logger.debug(f'Spell {text} -> {new_text}')

        return new_text.strip(' ')
Example #19
class NoisyChannelModel():
    
    def __init__(self, lm, max_ed=4, prefix_length=7, l=1, channel_method_poisson=True, channel_prob_param=0.02):
        self.show_progress = False
        self.lm = lm
        self.l = l
        self.channel_method_poisson = channel_method_poisson
        self.channel_prob_param = channel_prob_param
        
        self.sym_spell = SymSpell(max_ed, prefix_length)
        
        if isinstance(self.lm, GPT2LMHeadModel):
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.lm_sent_logscore = self.gpt2_sent_logscore
            self.beam_init = self.beam_GPT_init
            self.skipstart = 1
            self.skipend = -1
            self.update_sentence_history = self.updateGPT2history
            self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            for subword in range(self.tokenizer.vocab_size):
                self.sym_spell.create_dictionary_entry(key=self.tokenizer.decode(subword), count=1)
        else:
            self.lm_sent_logscore = self.ngram_sent_logscore
            self.beam_init = self.beam_ngram_init
            self.skipstart = self.lm.order-1
            self.skipend = None
            self.update_sentence_history = self.updatengramhistory
            self.tokenizer = ngramTokenizer(self.lm)
            for word in lm.vocab:
                self.sym_spell.create_dictionary_entry(key=word, count=self.lm.counts[word])
    
    
    def GPTrun(self, indexed_tokens, past=None):
        tokens_tensor = torch.tensor([indexed_tokens])
        tokens_tensor = tokens_tensor.to(self.device)
        with torch.no_grad():
            return self.lm(tokens_tensor, past=past, labels=tokens_tensor)
    
    
    def gpt2_sent_logscore(self, sentence):
        loss, next_loss = self.sentence_history[sentence[:self.pos]]
        return loss + next_loss[sentence[self.pos]]
    
    def gpt2_nohist_sent_logscore(self, sentence):
        loss, prediction_scores, past = self.GPTrun(sentence)
        return np.array(-(loss.cpu()))/np.log(2)
    
    
    def updateGPT2history(self):
        if self.pos > 1:
            for sentence in tuple(self.suggestion_sentences):
                formed_sentence = sentence[:self.pos]
                loss, prediction_scores, past = self.GPTrun(formed_sentence)
                next_loss = prediction_scores[0, -1].cpu().detach().numpy()
                self.sentence_history[formed_sentence] = (np.array(-(loss.cpu()))/np.log(2), np.log2(softmax(next_loss)))
        else:
            formed_sentence = torch.tensor([self.tokenizer.bos_token_id]).to(self.device)
            prediction_scores, past = self.lm(formed_sentence)
            formed_sentence = tuple([formed_sentence.item()])
            next_loss = prediction_scores[0].cpu().detach().numpy()
            loss = np.array(0)
            self.sentence_history[formed_sentence] = (loss, np.log2(softmax(next_loss)))
    
    
    def ngram_sent_logscore(self, sentence):
        qs = []
        for ngram in ngrams(sentence, self.lm.order):
            q = (ngram[-1], ngram[:-1])
            if q not in self.logscoredb:
                self.logscoredb[q] = self.lm.logscore(*q)
            qs += [q]
        return np.array([self.logscoredb[q] for q in qs]).sum()
    
    
    def updatengramhistory(self):
        return None
    
    
    def channel_probabilities(self):
        eds = np.array([candidate.distance for candidate in self.candidates])
        logprobs = self.poisson_channel_model(eds) if self.channel_method_poisson else self.inv_prop_channel_model(eds)
        self.channel_logprobs = {candidate.term: logprob for candidate, logprob in zip(self.candidates, logprobs)}
    
    
    def poisson_channel_model(self, eds):
        for ed in eds:
            if ed not in self.poisson_probsdb:
                self.poisson_probsdb[ed] = np.log2(poisson.pmf(k=ed, mu=self.channel_prob_param))
        return np.array([self.poisson_probsdb[ed] for ed in eds])
    
    
    def inv_prop_channel_model(self, eds):
        inv_eds = np.reciprocal(eds.astype(float), where=eds!=0)
        inv_eds[inv_eds < 1e-100] = 0.
        probs = (1-self.channel_prob_param)/inv_eds.sum() * inv_eds
        return np.log2(np.where(probs == 0., self.channel_prob_param, probs))
    
    
    def generate_suggestion_sentences(self):
        new_suggestion_sentences = {}
        self.update_sentence_history()
        for changed_word in tuple(self.channel_logprobs):
            if self.channel_logprobs[changed_word] != 0:
                for sentence in tuple(self.suggestion_sentences):
                    new_sentence = list(sentence)
                    new_sentence[self.pos] = changed_word
                    new_sentence = tuple(new_sentence)
                    new_suggestion_sentences[new_sentence] = self.lm_sent_logscore(new_sentence) * self.l + self.channel_logprobs[changed_word]
        self.suggestion_sentences.update(new_suggestion_sentences)
    
    
    def beam_all_init(self, input_sentence):
        self.logscoredb = {}
        self.poisson_probsdb = {}
        self.channel_logprobs = None
        self.suggestion_sentences = None
        self.candidates = None
        self.pos = 0
        if self.channel_method_poisson:
            chan_prob = np.log2(poisson.pmf(k=0, mu=self.channel_prob_param))
        else:
            chan_prob = np.log2(self.channel_prob_param)
        return self.beam_init(input_sentence, chan_prob)
    
    
    def beam_GPT_init(self, input_sentence, chan_prob):
        self.sentence_history = {}
        observed_sentence = tuple(self.tokenizer.encode(self.tokenizer.bos_token + input_sentence + self.tokenizer.eos_token))
        self.suggestion_sentences = {observed_sentence: self.gpt2_nohist_sent_logscore(observed_sentence) * self.l + chan_prob}
        return observed_sentence
    
    
    def beam_ngram_init(self, input_sentence, chan_prob):
        observed_sentence = self.tokenizer.encode(input_sentence)
        self.suggestion_sentences = {observed_sentence: self.lm_sent_logscore(observed_sentence) * self.l + chan_prob}
        return observed_sentence
    
    
    def beam_search(self, input_sentence, beam_width=10, max_ed=3, candidates_cutoff=50):
        observed_sentence = self.beam_all_init(input_sentence)
        for e, observed_word in enumerate(observed_sentence[self.skipstart:self.skipend]):
            self.pos = e + self.skipstart
            lookup_word = self.tokenizer.decode(observed_word) if isinstance(self.lm, GPT2LMHeadModel) else observed_word
            if lookup_word == ' ':
                continue
            self.candidates = self.sym_spell.lookup(lookup_word, Verbosity.ALL, max_ed)[:candidates_cutoff]
            if isinstance(self.lm, GPT2LMHeadModel):
                for candidate in self.candidates:
                    candidate.term = self.tokenizer.encode(candidate.term)[0]
            self.channel_probabilities()
            self.generate_suggestion_sentences()
            self.suggestion_sentences = dict(sorted(self.suggestion_sentences.items(), key = lambda kv:(kv[1], kv[0]), reverse=True)[:beam_width])
        if isinstance(self.lm, GPT2LMHeadModel):
            return {self.tokenizer.decode(sentence)[13:-13]: np.power(2, self.suggestion_sentences[sentence]) for sentence in self.suggestion_sentences}
        else:
            return {self.tokenizer.decode(sentence): np.power(2, self.suggestion_sentences[sentence]) for sentence in self.suggestion_sentences}
    
    
    def beam_search_sentences(self, sentences):
        iterate = tqdm(sentences) if self.show_progress else sentences
        df = pd.DataFrame()
        for sent in iterate:
            corrections = self.beam_search(sent)
            df_sents = pd.DataFrame(corrections.keys())
            df_probs = pd.DataFrame(corrections.values())
            df_sents = df_sents.append(df_probs, ignore_index=True).transpose()
            df = df.append(df_sents, ignore_index=True)
        return df
    
    
    def viterbi(self, input_sentence, max_ed=3, candidates_cutoff=10):
        observed_sentence = self.tokenizer.encode(input_sentence)
        V = {(1, ('<s>','<s>')): 0}
        backpointer = {}
        candidate_words = [['<s>'], ['<s>']]
        gengram = lambda ngram, w: tuple([w] + list(ngram[:-1]))
        self.poisson_probsdb = {}
        
        for e, observed_word in enumerate(observed_sentence[self.skipstart:]):
            t = e + self.skipstart
            self.candidates = self.sym_spell.lookup(observed_word, Verbosity.ALL, max_ed, transfer_casing=True)[:candidates_cutoff]
            self.channel_probabilities()
            candidate_words += [[candidate.term for candidate in self.candidates]]
            for ngram in itertools.product(*candidate_words[t-self.lm.order+2:t+1]):
                options = [V[t-1, gengram(ngram, w)] + self.lm.logscore(ngram[-1], gengram(ngram, w)) + self.channel_logprobs[ngram[-1]] for w in candidate_words[t-self.lm.order+1]]
                best_option = np.argmax(options)
                V[t, ngram] = options[best_option]
                backpointer[t] = candidate_words[t-self.lm.order+1][best_option]
        return self.tokenizer.decode([token for _, token in backpointer.items()])