def test_words_with_shared_prefix_should_retain_counts(self):
    print(' - %s' % inspect.stack()[0][3])
    sym_spell = SymSpell(16, 1, 3)
    sym_spell.create_dictionary_entry("pipe", 5)
    sym_spell.create_dictionary_entry("pips", 10)

    result = sym_spell.lookup("pipe", Verbosity.ALL, 1)
    self.assertEqual(2, len(result))
    self.assertEqual("pipe", result[0].term)
    self.assertEqual(5, result[0].count)
    self.assertEqual("pips", result[1].term)
    self.assertEqual(10, result[1].count)

    result = sym_spell.lookup("pips", Verbosity.ALL, 1)
    self.assertEqual(2, len(result))
    self.assertEqual("pips", result[0].term)
    self.assertEqual(10, result[0].count)
    self.assertEqual("pipe", result[1].term)
    self.assertEqual(5, result[1].count)

    result = sym_spell.lookup("pip", Verbosity.ALL, 1)
    self.assertEqual(2, len(result))
    self.assertEqual("pips", result[0].term)
    self.assertEqual(10, result[0].count)
    self.assertEqual("pipe", result[1].term)
    self.assertEqual(5, result[1].count)
def initializeSymspell():
    print("inside initializeSymspell()")
    symspell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
    print("symspell created")
    resourceNames = ["symspellpy",
                     "frequency_dictionary_en_82_765.txt",
                     "frequency_bigramdictionary_en_243_342.txt"]
    dictionaryPath = pkg_resources.resource_filename(resourceNames[0], resourceNames[1])
    bigramPath = pkg_resources.resource_filename(resourceNames[0], resourceNames[2])
    print("dictionaryPath created")
    symspell.load_dictionary(dictionaryPath, 0, 1)
    symspell.create_dictionary_entry(key='ap', count=500000000)
    print(list(islice(symspell.words.items(), 5)))
    print("symspell.load_dictionary() done")
    # The packaged bigram dictionary stores "word1 word2 count", so the count is in the third column.
    symspell.load_bigram_dictionary(bigramPath, 0, 2)
    print(list(islice(symspell.bigrams.items(), 5)))
    print("symspell.load_bigram_dictionary() done")

    # Create vocab
    vocab = set([w for w, f in symspell.words.items()])
    return symspell, vocab
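# A minimal usage sketch for initializeSymspell() above, assuming symspellpy's standard
# lookup / lookup_compound API and that the packaged dictionaries loaded successfully.
# The misspelled inputs are illustrative only.
from symspellpy import Verbosity

symspell, vocab = initializeSymspell()

# Single-word correction.
suggestions = symspell.lookup("memebers", Verbosity.CLOSEST, max_edit_distance=2)
print(suggestions[0].term if suggestions else "no suggestion")

# Multi-word correction; benefits from the bigram dictionary loaded above.
compound = symspell.lookup_compound("whereis th elove", max_edit_distance=2)
print(compound[0].term)

# The vocab set gives a fast "is this word known?" membership check.
print("love" in vocab)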
class Autocorrect:
    def __init__(self, words=None, max_edit_distance=2):
        self._symspell = SymSpell()
        self._max_edit_distance = max_edit_distance
        if words is not None:
            self.add_words(words)

    def add_word(self, word):
        if word is not None:
            self._symspell.create_dictionary_entry(word, 1)

    def add_words(self, words):
        if words is not None:
            self._symspell.create_dictionary(words)

    def delete_word(self, word):
        if word is not None:
            self._symspell.delete_dictionary_entry(word)

    def correct(self, bad_word):
        return self._symspell.lookup(bad_word, Verbosity.TOP,
                                     max_edit_distance=self._max_edit_distance,
                                     include_unknown=True)[0].term

    def predictions(self, bad_word):
        return self._symspell.lookup(bad_word, Verbosity.CLOSEST,
                                     max_edit_distance=self._max_edit_distance,
                                     include_unknown=True)
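# Illustrative usage of the Autocorrect wrapper above; the vocabulary and the
# misspelling "helo" are made-up examples, not part of the original code.
ac = Autocorrect()
for w in ["hello", "help", "world"]:
    ac.add_word(w)

print(ac.correct("helo"))  # likely "hello" or "help": both are one edit away, ties broken by count
for item in ac.predictions("helo"):
    print(item.term, item.distance, item.count)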
def _create_symspell_checker(self, language: AnyStr) -> SymSpell:
    """Private method to create a SymSpell instance for a given language

    Args:
        language: Language code in ISO 639-1 format

    Returns:
        SymSpell checker instance loaded with the language dictionary
    """
    start = perf_counter()
    logging.info(f"Loading spellchecker for language '{language}'...")
    symspell_checker = SymSpell(max_dictionary_edit_distance=self.edit_distance)
    frequency_dict_path = self.dictionary_folder_path + "/" + language + ".txt"
    symspell_checker.load_dictionary(frequency_dict_path, term_index=0, count_index=1, encoding="utf-8")
    if len(self.custom_vocabulary_set) != 0:
        for word in self.custom_vocabulary_set:
            symspell_checker.create_dictionary_entry(key=word, count=1)
    logging.info(
        f"Loading spellchecker for language '{language}': done in {perf_counter() - start:.2f} seconds"
    )
    return symspell_checker
def test_lookup_should_not_return_non_word_delete(self):
    print(' - %s' % inspect.stack()[0][3])
    sym_spell = SymSpell(16, 2, 7, 10)
    sym_spell.create_dictionary_entry("pawn", 10)
    result = sym_spell.lookup("paw", Verbosity.TOP, 0)
    self.assertEqual(0, len(result))
    result = sym_spell.lookup("awn", Verbosity.TOP, 0)
    self.assertEqual(0, len(result))
def test_lookup_should_not_return_low_count_word_that_are_also_delete_word(
        self):
    print(' - %s' % inspect.stack()[0][3])
    sym_spell = SymSpell(16, 2, 7, 10)
    sym_spell.create_dictionary_entry("flame", 20)
    sym_spell.create_dictionary_entry("flam", 1)
    result = sym_spell.lookup("flam", Verbosity.TOP, 0)
    self.assertEqual(0, len(result))
def test_lookup_should_find_exact_match(self):
    print(' - %s' % inspect.stack()[0][3])
    sym_spell = SymSpell()
    sym_spell.create_dictionary_entry("steama", 4)
    sym_spell.create_dictionary_entry("steamb", 6)
    sym_spell.create_dictionary_entry("steamc", 2)
    result = sym_spell.lookup("streama", Verbosity.TOP, 2)
    self.assertEqual(1, len(result))
    self.assertEqual("steama", result[0].term)
def test_add_additional_counts_should_not_add_word_again(self):
    print(' - %s' % inspect.stack()[0][3])
    sym_spell = SymSpell()
    word = "hello"
    sym_spell.create_dictionary_entry(word, 11)
    self.assertEqual(1, sym_spell.word_count)
    sym_spell.create_dictionary_entry(word, 3)
    self.assertEqual(1, sym_spell.word_count)
def test_lookup_should_return_most_frequent(self):
    print(' - %s' % inspect.stack()[0][3])
    sym_spell = SymSpell()
    sym_spell.create_dictionary_entry("steama", 4)
    sym_spell.create_dictionary_entry("steamb", 6)
    sym_spell.create_dictionary_entry("steamc", 2)
    result = sym_spell.lookup("stream", Verbosity.TOP, 2)
    self.assertEqual(1, len(result))
    self.assertEqual("steamb", result[0].term)
    self.assertEqual(6, result[0].count)
def test_add_additional_counts_should_not_overflow(self):
    print(' - %s' % inspect.stack()[0][3])
    sym_spell = SymSpell()
    word = "hello"
    sym_spell.create_dictionary_entry(word, sys.maxsize - 10)
    result = sym_spell.lookup(word, Verbosity.TOP)
    count = result[0].count if len(result) == 1 else 0
    self.assertEqual(sys.maxsize - 10, count)
    sym_spell.create_dictionary_entry(word, 11)
    result = sym_spell.lookup(word, Verbosity.TOP)
    count = result[0].count if len(result) == 1 else 0
    self.assertEqual(sys.maxsize, count)
def test_add_additional_counts_should_increase_count(self):
    print(' - %s' % inspect.stack()[0][3])
    sym_spell = SymSpell()
    word = "hello"
    sym_spell.create_dictionary_entry(word, 11)
    result = sym_spell.lookup(word, Verbosity.TOP)
    count = result[0].count if len(result) == 1 else 0
    self.assertEqual(11, count)
    sym_spell.create_dictionary_entry(word, 3)
    result = sym_spell.lookup(word, Verbosity.TOP)
    count = result[0].count if len(result) == 1 else 0
    self.assertEqual(11 + 3, count)
def get_pubmedc_spellcheck():
    global pubmed_dir, pubmedc_url, pubmedc_spell, did_dir
    pubmedc_spell = load_pkl(fdir=pubmed_dir, f='pubmedc_spell')
    if pubmedc_spell is None:
        pubmedc_spell = SymSpell(initial_capacity, max_edit_distance_dictionary,
                                 prefix_length)
        fs = requests.get(pubmedc_url).content.decode('utf-8').split()
        fs = [x for x in fs if '1-grams' in x]
        with open(os.path.join(did_dir, 'medwords.pkl'), 'rb') as f:
            medwords = pickle.load(f)
        for i, f in enumerate(fs):
            r = requests.get(f)
            # vocab = defaultdict(int)
            with gzip.open(BytesIO(r.content), 'rb') as zfile:
                # Iterate over every line of the gzipped n-gram file
                # (the old while/readline pattern skipped every other line).
                for line in zfile:
                    tmp = line.decode('utf-8').split('\t')
                    if tmp and len(tmp) == 4:
                        word, freq = tmp[0].lower(), int(tmp[2])
                        if '=' in word:
                            continue
                        if len(word) < 3 or len(word) > 45:
                            continue
                        if freq < 2:
                            continue
                        if word in medwords:
                            pubmedc_spell.create_dictionary_entry(word, freq)
                            print('added \'{}\''.format(word))
                        else:
                            print('skipped None word')
                    # if len(tmp) >= 2 and tmp not in glookup:
                    #     print('processing \'{}\''.format(tmp))
                    #     cur.append(tmp)
            print('added {}th file: {}'.format(i, f))
        # save_pkl(fdir=pubmed_dir, f='pubmedc_spell', obj=pubmedc_spell)
    return pubmedc_spell
async def quote(self, message, args):
    msg = None
    if try_parse_int64(args[0]) is not None:
        msg_id = args[0]
        try:
            msg = await self.client.get_message(message.channel.id, msg_id)
        except Exception as exception:  # pylint: disable=W0703
            LOG.exception(exception)
    else:
        input_term = args[0]
        sym_spell = SymSpell()
        for term in input_term.split(" "):
            sym_spell.create_dictionary_entry(term, 1)
        target = sym_spell.lookup_compound(input_term, 2)[0].term
        iterator = message.channel.history(limit=100)
        for __ in range(100):
            try:
                msg = await iterator.next()
                suggestion = sym_spell.lookup_compound(msg.content, 2)[0]
                if suggestion.term == target:
                    msg = await self.client.get_message(
                        message.channel.id, msg.id)
                    break
            except NoMoreItems:
                msg = None
    if msg is not None:
        display_name = message.guild.get_member(int(
            msg["author"]["id"])).display_name
        time_str = (
            datetime.strptime(msg["timestamp"].split(".")[0],
                              "%Y-%m-%dT%H:%M:%S")
            + timedelta(hours=TZ_OFFSET)).strftime("%Y-%m-%d %I:%M %p")
        quote_msg = "```{} - {} UTC+{}\n{}```".format(
            display_name, time_str, TZ_OFFSET, msg["content"])
    else:
        quote_msg = "Message not found!"
    await self.client.send_message(message.channel.id, quote_msg)
    await message.delete()
def test_verbosity_should_control_lookup_results(self):
    print(' - %s' % inspect.stack()[0][3])
    sym_spell = SymSpell()
    sym_spell.create_dictionary_entry("steam", 1)
    sym_spell.create_dictionary_entry("steams", 2)
    sym_spell.create_dictionary_entry("steem", 3)
    result = sym_spell.lookup("steems", Verbosity.TOP, 2)
    self.assertEqual(1, len(result))
    result = sym_spell.lookup("steems", Verbosity.CLOSEST, 2)
    self.assertEqual(2, len(result))
    result = sym_spell.lookup("steems", Verbosity.ALL, 2)
    self.assertEqual(3, len(result))
class HeuristicAbbreviationsCompounds(HeuristicPunctuation):
    """
    Generate an abbreviation version of an input string based on its compounds.

    The abbreviation heuristic is far more complex than the other heuristics. This is because I can not know
    if the given query mention is already an abbreviation or the reference entity.
    Example: Given the mention "embedded subscriber identity module" and the reference entity "eSIM".
    My system does not know if the mention is already an abbreviation or not. So instead I assume two
    possible cases:
    1) The mention is already an abbreviation
    2) The mention is not an abbreviation yet

    For the first case we have to assume that the reference entity is NOT an abbreviation (otherwise a prior
    heuristic should have found a match already). The same is valid for the second case, we have to assume
    that the reference entity IS an abbreviation.
    In order to check both cases I need the reference entities in two versions: original (without any
    stemming, only punctuation removal) and abbreviation. This heuristic checks both cases and returns the
    better result of the two.
    """

    def __init__(self, max_edit_distance_dictionary: int = 0, prefix_length: int = 10,
                 count_threshold: int = 1, compact_level: int = 5, prob_threshold: float = 0.1):
        """
        :param prob_threshold: This threshold is used for the compound splitter. CharSplit returns a
               probability for how likely the compounds are for a word. All compounds that have a lower
               probability are discarded. Furthermore, this threshold is the condition to stop the recursive
               compound splitting.
        :param max_edit_distance_dictionary: A threshold to ensure that the matched abbreviation is as
               similar as possible to the reference entity.
        """
        super().__init__(max_edit_distance_dictionary, prefix_length, count_threshold, compact_level)
        assert prob_threshold > 0
        self._prop_threshold = prob_threshold
        self.max_edit_distance_dictionary = max_edit_distance_dictionary

        # The symspell dictionary and mapping for the unrefactored entities (case 1)
        self._original_sym_speller = None
        self._original_rule_mapping = None

    def name(self):
        return "abbreviations"

    def initialize_sym_speller(self):
        super().initialize_sym_speller()
        self._original_sym_speller = SymSpell(self.max_edit_distance_dictionary, self.prefix_length,
                                              self.count_threshold, self.compact_level)
        self._original_rule_mapping = {}

    def _split(self, s: str) -> List[str]:
        return split_compounds(s, prop_threshold=self._prop_threshold)

    def _refactor(self, s: str) -> str:
        s = super()._refactor(s)
        compounds = self._split(s)

        # Create an abbreviation from the compounds. However, pure digit compounds are kept as is.
        abbreviation = "".join([compound if compound.isdigit() else compound[:1].capitalize()
                                for compound in compounds])
        return abbreviation

    def lookup(self, mention: str, original_mention: str = "") -> Tuple[List[SuggestItem], str]:
        """
        The abbreviation heuristic only uses the original mention instead of a previously refactored mention.
        This is because previous refactoring might affect the compound splitting.
        """
        original_mention = str(original_mention)

        # Case 1: the mention is the abbreviation, the original entity is not known as abbreviation.
        # I only do a punctuation refactoring on the mention and then match against the abbreviation
        # refactored entities.
        case1_refactored_mention = super()._refactor(original_mention)
        case1_suggestions = self._sym_speller.lookup(case1_refactored_mention, Verbosity.CLOSEST)
        case1_distance = 99999
        for suggestion in case1_suggestions:
            suggestion.reference_entities = self._rule_mapping[suggestion.term]
            case1_distance = suggestion.distance

        # Case 2: the mention is currently not an abbreviation but the original entity is only known as
        # abbreviation. I do the abbreviation refactoring on the mention and then match against the
        # punctuation refactored entities.
        case2_refactored_mention = self._refactor(original_mention)
        case2_suggestions = self._original_sym_speller.lookup(case2_refactored_mention, Verbosity.CLOSEST)
        case2_distance = 99999
        for suggestion in case2_suggestions:
            suggestion.reference_entities = self._original_rule_mapping[suggestion.term]
            case2_distance = suggestion.distance

        # Only return the better result
        if case1_distance < case2_distance:
            return case1_suggestions, case1_refactored_mention
        else:
            return case2_suggestions, case2_refactored_mention

    def add_dictionary_entity(self, entity: str, original_entity: str = "") -> str:
        """
        This heuristic needs the refactored entity as well as the original entity in its dictionary because
        it is possible, that either the mention or the reference entity is the abbreviation.
        """
        original_entity = str(original_entity)

        # Add the entities in abbreviation form to the own sym spell dictionary
        abbreviation = self._refactor(original_entity)
        self._sym_speller.create_dictionary_entry(abbreviation, 1)
        if abbreviation not in self._rule_mapping:
            self._rule_mapping[abbreviation] = {original_entity}
        else:
            self._rule_mapping[abbreviation].update({original_entity})

        # Because the original entity could already be the abbreviation, we also want to save the original one.
        # Do the same for the mapping and apply the punctuation heuristic to the original entity.
        punctuation_refactored_entity = super()._refactor(original_entity)
        self._original_sym_speller.create_dictionary_entry(punctuation_refactored_entity, 1)
        if punctuation_refactored_entity not in self._original_rule_mapping:
            self._original_rule_mapping[punctuation_refactored_entity] = {original_entity}
        else:
            self._original_rule_mapping[punctuation_refactored_entity].update({original_entity})

        return abbreviation
class Heuristic(metaclass=ABCMeta):
    def __init__(self, max_edit_distance_dictionary: int = 5, prefix_length: int = 10,
                 count_threshold: int = 1, compact_level: int = 5):
        """
        Note: lower max_edit_distance and higher prefix_length=2*max_edit_distance == faster
        """
        self.max_edit_distance_dictionary = max_edit_distance_dictionary
        self.prefix_length = prefix_length
        self.count_threshold = count_threshold
        self.compact_level = compact_level
        self._sym_speller = None
        self._rule_mapping = {}

    @property
    @abstractmethod
    def name(self):
        pass

    @abstractmethod
    def _refactor(self, s: str) -> str:
        pass

    def initialize_sym_speller(self):
        self._sym_speller = SymSpell(self.max_edit_distance_dictionary, self.prefix_length,
                                     self.count_threshold, self.compact_level)
        self._rule_mapping = {}

    def lookup(self, mention: str, original_mention: str = "") -> Tuple[List[SuggestItem], str]:
        """
        The original mention is optional here because it is only necessary for very specific heuristics.
        In the general case, it will be ignored and a previously refactored mention will be further
        refactored here.
        """
        mention = str(mention)
        refactored_mention = self._refactor(mention)
        suggestions = self._sym_speller.lookup(refactored_mention, Verbosity.CLOSEST)

        # Look up the unrefactored entities for the suggestions and save them
        for suggestion in suggestions:
            suggestion.reference_entities = self._rule_mapping[suggestion.term]

        return suggestions, refactored_mention

    def add_dictionary_entity(self, entity: str, original_entity: str = "") -> str:
        """
        Adds an entity to the sym spell dictionary of this heuristic.
        The original entity is usually not necessary except for very specific heuristics that rely on the
        original, untouched entity. Otherwise, the previously refactored entity is used here.
        """
        entity = str(entity)

        # Apply the current heuristic
        refactored_entity = self._refactor(entity)

        # Save the refactored entity in the heuristic symspeller + the mapping to the untouched entity
        self._sym_speller.create_dictionary_entry(refactored_entity, 1)
        if refactored_entity not in self._rule_mapping:
            self._rule_mapping[refactored_entity] = {original_entity}
        else:
            self._rule_mapping[refactored_entity].update({original_entity})

        return refactored_entity
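# Hypothetical minimal subclass (not part of the original code) illustrating the extension points of
# the abstract Heuristic base class above: supply `name` and `_refactor`, then the inherited
# dictionary/lookup machinery works unchanged. The entity strings are illustrative only.
class HeuristicLowercase(Heuristic):
    @property
    def name(self):
        return "lowercase"

    def _refactor(self, s: str) -> str:
        # The "heuristic" here is simply lowercasing before indexing and lookup.
        return s.lower()


heuristic = HeuristicLowercase()
heuristic.initialize_sym_speller()
heuristic.add_dictionary_entity("ESim", original_entity="eSIM")
suggestions, refactored = heuristic.lookup("esim")
for suggestion in suggestions:
    print(suggestion.term, suggestion.reference_entities)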
class Corrector(Dictionary):
    def __init__(self):
        self.verbosity = Verbosity.ALL
        self.symspell = SymSpell(max_dictionary_edit_distance=3, count_threshold=0)

    def load_symspell(self):
        for k, v in tqdm(self.uni_dict.items(), 'Loading symspell...'):
            self.symspell.create_dictionary_entry(key=k, count=v)

    def _gen_character_candidates(self, character):
        for line in self.diacritic:
            if character in line:
                return line
        return [character]

    def _gen_word_candidates(self, word):
        cands = [word]
        for i in range(len(word)):
            can_list = self._gen_character_candidates(word[i])
            cands += [word[:i] + c + rest
                      for c in can_list
                      for rest in self._gen_word_candidates(word[i+1:])]
        return cands

    def _gen_diacritic_candidates(self, term):
        if term == '<num>':
            return [term]
        cands = list(set(self._gen_word_candidates(term)))
        cands = sorted(cands, key=lambda x: -self._c1w(x))[:min(10, len(cands))]
        return cands

    def gen_states(self, obs):
        states = {}
        new_obs = []
        for ob in obs:
            states[ob] = []
            if len(ob) == 1 and re.match(r"[^aáàảãạăằắẳẵặâầấẩẫậeèéẻẽẹêềếểễệiìíỉĩịoóòỏõọôồổỗộơờớởỡợuùúũụưứửữựyỳýỷỹỵ]", ob):
                continue
            new_obs += [ob]
            if len(ob) <= 4:
                edit_distance = 1
            elif len(ob) > 4 and len(ob) <= 10:
                edit_distance = 2
            elif len(ob) > 10:
                edit_distance = 3
            states[ob] = sorted([c.term for c in self.symspell.lookup(
                phrase=ob,
                verbosity=self.verbosity,
                max_edit_distance=edit_distance,
                include_unknown=False,
                ignore_token=r'<num>'
            )], key=lambda x: -self._c1w(x))[:30]
        return new_obs, states

    def _trans(self, cur, prev, prev_prev=None):
        if prev_prev is None:
            return self.cpw(cur, prev)
        else:
            return self.cp3w(cur, prev, prev_prev)

    def _emiss(self, cur_observe, cur_state):
        return self.words_similarity(cur_observe, cur_state) + self.pw(cur_state)

    def _viterbi_decoder(self, obs, states):
        V = [{}]
        path = {}
        for st in states[obs[0]]:
            V[0][st] = 1.0
            path[st] = [st]

        for i in range(1, len(obs)):
            V.append({})
            new_path = {}
            for st in states[obs[i]]:
                # if i==1:
                prob, state = max([
                    (V[i-1][prev_st] * self._trans(st, prev_st) * self._emiss(obs[i], st), prev_st)
                    for prev_st in states[obs[i-1]]
                ])
                # else:
                #     prob, state = max([
                #         (V[i-1][prev_st]*self._trans(st, prev_st, prev_prev_st)*self._emiss(obs[i], st), prev_st)
                #         for prev_st in states[obs[i-1]]
                #         for prev_prev_st in states[obs[i-2]]
                #     ])
                V[i][st] = prob
                new_path[st] = path[state] + [st]
            path = new_path

        with open('data/viterbi_tracking.txt', 'w+', encoding='utf-8') as writer:
            # writer.write(json.dumps(path, indent=4, ensure_ascii=False) + '\n')
            writer.write(json.dumps(V, indent=4, ensure_ascii=False) + '\n')

        prob, state = max([(V[-1][st], st) for st in states[obs[-1]]])
        return {
            "prob": prob,
            "result": ' '.join(path[state])
        }

    def correct(self, text):
        # obs = ['<START>'] + text.split() + ['<END>']
        obs = text.split()
        obs, states = self.gen_states(obs)
        result = self._viterbi_decoder(obs, states)
        # return [" ".join(r[1:-1]) for r in result]
        return result
class SpellCheck:
    def __init__(self, progress, directory, countries_dict):
        self.progress = progress
        self.logger = logging.getLogger(__name__)
        self.spelling_update = Counter()
        self.directory = directory
        self.spell_path = os.path.join(self.directory, 'spelling.pkl')
        self.countries_dict = countries_dict
        self.sym_spell = SymSpell()

    def insert(self, name, iso):
        if 'gothland cemetery' not in name and name not in noise_words:
            name_tokens = name.split(' ')
            for word in name_tokens:
                key = f'{word}'
                if len(key) > 2:
                    self.spelling_update[key] += 1

    def write(self):
        # Create blank spelling dictionary
        path = os.path.join(self.directory, 'spelling.tmp')
        fl = open(path, 'w')
        fl.write('the,1\n')
        fl.close()
        success = self.sym_spell.create_dictionary(corpus=path)
        if not success:
            self.logger.error(f"error creating spelling dictionary")

        self.logger.info('Building Spelling Dictionary')
        # Add all words from geonames into spelling dictionary
        for key in self.spelling_update:
            self.sym_spell.create_dictionary_entry(key=key, count=self.spelling_update[key])

        self.logger.info('Writing Spelling Dictionary')
        self.sym_spell.save_pickle(self.spell_path)

    def read(self):
        success = False
        if os.path.exists(self.spell_path):
            self.logger.info(f'Loading Spelling Dictionary from {self.spell_path}')
            success = self.sym_spell.load_pickle(self.spell_path)
        else:
            self.logger.error(f"spelling dictionary not found: {self.spell_path}")
        if not success:
            self.logger.error(f"error loading spelling dictionary from {self.spell_path}")
        else:
            self.sym_spell.delete_dictionary_entry(key='gothland')
            size = len(self.sym_spell.words)
            self.logger.info(f"Spelling Dictionary contains {size} words")

    def lookup(self, input_term):
        # suggestions = [SymSpell.SuggestItem]
        if '*' in input_term:
            return input_term
        res = ''
        if len(input_term) > 1:
            suggestions = self.sym_spell.lookup(input_term, Verbosity.CLOSEST,
                                                max_edit_distance=2, include_unknown=True)
            for idx, item in enumerate(suggestions):
                if idx > 3:
                    break
                # self.logger.debug(f'{item._term}')
                if item._term[0] == input_term[0]:
                    # Only accept results where first letter matches
                    res += item._term + ' '
            return res
        else:
            return input_term

    def lookup_compound(self, phrase):
        suggestions = self.sym_spell.lookup_compound(phrase=phrase, max_edit_distance=2,
                                                     ignore_non_words=False)
        for item in suggestions:
            self.logger.debug(f'{item._term}')
        return suggestions[0]._term

    def fix_spelling(self, text):
        new_text = text
        if bool(re.search(r'\d', text)):
            # Has digits, just return text, no spellcheck
            pass
        elif 'st ' in text:
            # Spellcheck not handling St properly
            pass
        else:
            if len(text) > 0:
                new_text = self.lookup(text)
                self.logger.debug(f'Spell {text} -> {new_text}')
        return new_text.strip(' ')
class NoisyChannelModel():
    def __init__(self, lm, max_ed=4, prefix_length=7, l=1,
                 channel_method_poisson=True, channel_prob_param=0.02):
        self.show_progress = False
        self.lm = lm
        self.l = l
        self.channel_method_poisson = channel_method_poisson
        self.channel_prob_param = channel_prob_param
        self.sym_spell = SymSpell(max_ed, prefix_length)

        if isinstance(self.lm, GPT2LMHeadModel):
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.lm_sent_logscore = self.gpt2_sent_logscore
            self.beam_init = self.beam_GPT_init
            self.skipstart = 1
            self.skipend = -1
            self.update_sentence_history = self.updateGPT2history
            self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            for subword in range(self.tokenizer.vocab_size):
                self.sym_spell.create_dictionary_entry(key=self.tokenizer.decode(subword), count=1)
        else:
            self.lm_sent_logscore = self.ngram_sent_logscore
            self.beam_init = self.beam_ngram_init
            self.skipstart = self.lm.order - 1
            self.skipend = None
            self.update_sentence_history = self.updatengramhistory
            self.tokenizer = ngramTokenizer(self.lm)
            for word in lm.vocab:
                self.sym_spell.create_dictionary_entry(key=word, count=self.lm.counts[word])

    def GPTrun(self, indexed_tokens, past=None):
        tokens_tensor = torch.tensor([indexed_tokens])
        tokens_tensor = tokens_tensor.to(self.device)
        with torch.no_grad():
            return self.lm(tokens_tensor, past=past, labels=tokens_tensor)

    def gpt2_sent_logscore(self, sentence):
        loss, next_loss = self.sentence_history[sentence[:self.pos]]
        return loss + next_loss[sentence[self.pos]]

    def gpt2_nohist_sent_logscore(self, sentence):
        loss, prediction_scores, past = self.GPTrun(sentence)
        return np.array(-(loss.cpu())) / np.log(2)

    def updateGPT2history(self):
        if self.pos > 1:
            for sentence in tuple(self.suggestion_sentences):
                formed_sentence = sentence[:self.pos]
                loss, prediction_scores, past = self.GPTrun(formed_sentence)
                next_loss = prediction_scores[0, -1].cpu().detach().numpy()
                self.sentence_history[formed_sentence] = (np.array(-(loss.cpu())) / np.log(2),
                                                          np.log2(softmax(next_loss)))
        else:
            formed_sentence = torch.tensor([self.tokenizer.bos_token_id]).to(self.device)
            prediction_scores, past = self.lm(formed_sentence)
            formed_sentence = tuple([formed_sentence.item()])
            next_loss = prediction_scores[0].cpu().detach().numpy()
            loss = np.array(0)
            self.sentence_history[formed_sentence] = (loss, np.log2(softmax(next_loss)))

    def ngram_sent_logscore(self, sentence):
        qs = []
        for ngram in ngrams(sentence, self.lm.order):
            q = (ngram[-1], ngram[:-1])
            if q not in self.logscoredb:
                self.logscoredb[q] = self.lm.logscore(*q)
            qs += [q]
        return np.array([self.logscoredb[q] for q in qs]).sum()

    def updatengramhistory(self):
        return None

    def channel_probabilities(self):
        eds = np.array([candidate.distance for candidate in self.candidates])
        logprobs = (self.poisson_channel_model(eds) if self.channel_method_poisson
                    else self.inv_prop_channel_model(eds))
        self.channel_logprobs = {candidate.term: logprob
                                 for candidate, logprob in zip(self.candidates, logprobs)}

    def poisson_channel_model(self, eds):
        for ed in eds:
            if ed not in self.poisson_probsdb:
                self.poisson_probsdb[ed] = np.log2(poisson.pmf(k=ed, mu=self.channel_prob_param))
        return np.array([self.poisson_probsdb[ed] for ed in eds])

    def inv_prop_channel_model(self, eds):
        inv_eds = np.reciprocal(eds.astype(float), where=eds != 0)
        inv_eds[inv_eds < 1e-100] = 0.
        probs = (1 - self.channel_prob_param) / inv_eds.sum() * inv_eds
        return np.log2(np.where(probs == 0., self.channel_prob_param, probs))

    def generate_suggestion_sentences(self):
        new_suggestion_sentences = {}
        self.update_sentence_history()
        for changed_word in tuple(self.channel_logprobs):
            if self.channel_logprobs[changed_word] != 0:
                for sentence in tuple(self.suggestion_sentences):
                    new_sentence = list(sentence)
                    new_sentence[self.pos] = changed_word
                    new_sentence = tuple(new_sentence)
                    new_suggestion_sentences[new_sentence] = (
                        self.lm_sent_logscore(new_sentence) * self.l
                        + self.channel_logprobs[changed_word])
        self.suggestion_sentences.update(new_suggestion_sentences)

    def beam_all_init(self, input_sentence):
        self.logscoredb = {}
        self.poisson_probsdb = {}
        self.channel_logprobs = None
        self.suggestion_sentences = None
        self.candidates = None
        self.pos = 0
        if self.channel_method_poisson:
            chan_prob = np.log2(poisson.pmf(k=0, mu=self.channel_prob_param))
        else:
            chan_prob = np.log2(self.channel_prob_param)
        return self.beam_init(input_sentence, chan_prob)

    def beam_GPT_init(self, input_sentence, chan_prob):
        self.sentence_history = {}
        observed_sentence = tuple(self.tokenizer.encode(
            self.tokenizer.bos_token + input_sentence + self.tokenizer.eos_token))
        self.suggestion_sentences = {
            observed_sentence: self.gpt2_nohist_sent_logscore(observed_sentence) * self.l + chan_prob}
        return observed_sentence

    def beam_ngram_init(self, input_sentence, chan_prob):
        observed_sentence = self.tokenizer.encode(input_sentence)
        self.suggestion_sentences = {
            observed_sentence: self.lm_sent_logscore(observed_sentence) * self.l + chan_prob}
        return observed_sentence

    def beam_search(self, input_sentence, beam_width=10, max_ed=3, candidates_cutoff=50):
        observed_sentence = self.beam_all_init(input_sentence)
        for e, observed_word in enumerate(observed_sentence[self.skipstart:self.skipend]):
            self.pos = e + self.skipstart
            lookup_word = (self.tokenizer.decode(observed_word)
                           if isinstance(self.lm, GPT2LMHeadModel) else observed_word)
            if lookup_word == ' ':
                continue
            self.candidates = self.sym_spell.lookup(lookup_word, Verbosity.ALL, max_ed)[:candidates_cutoff]
            if isinstance(self.lm, GPT2LMHeadModel):
                for candidate in self.candidates:
                    candidate.term = self.tokenizer.encode(candidate.term)[0]
            self.channel_probabilities()
            self.generate_suggestion_sentences()
            self.suggestion_sentences = dict(sorted(self.suggestion_sentences.items(),
                                                    key=lambda kv: (kv[1], kv[0]),
                                                    reverse=True)[:beam_width])
        if isinstance(self.lm, GPT2LMHeadModel):
            return {self.tokenizer.decode(sentence)[13:-13]: np.power(2, self.suggestion_sentences[sentence])
                    for sentence in self.suggestion_sentences}
        else:
            return {self.tokenizer.decode(sentence): np.power(2, self.suggestion_sentences[sentence])
                    for sentence in self.suggestion_sentences}

    def beam_search_sentences(self, sentences):
        iterate = tqdm(sentences) if self.show_progress else sentences
        df = pd.DataFrame()
        for sent in iterate:
            corrections = self.beam_search(sent)
            df_sents = pd.DataFrame(corrections.keys())
            df_probs = pd.DataFrame(corrections.values())
            df_sents = df_sents.append(df_probs, ignore_index=True).transpose()
            df = df.append(df_sents, ignore_index=True)
        return df

    def viterbi(self, input_sentence, max_ed=3, candidates_cutoff=10):
        observed_sentence = self.tokenizer.encode(input_sentence)
        V = {(1, ('<s>', '<s>')): 0}
        backpointer = {}
        candidate_words = [['<s>'], ['<s>']]
        gengram = lambda ngram, w: tuple([w] + list(ngram[:-1]))
        self.poisson_probsdb = {}
        for e, observed_word in enumerate(observed_sentence[self.skipstart:]):
            t = e + self.skipstart
            self.candidates = self.sym_spell.lookup(observed_word, Verbosity.ALL, max_ed,
                                                    transfer_casing=True)[:candidates_cutoff]
            self.channel_probabilities()
            candidate_words += [[candidate.term for candidate in self.candidates]]
            for ngram in itertools.product(*candidate_words[t - self.lm.order + 2:t + 1]):
                options = [V[t - 1, gengram(ngram, w)]
                           + self.lm.logscore(ngram[-1], gengram(ngram, w))
                           + self.channel_logprobs[ngram[-1]]
                           for w in candidate_words[t - self.lm.order + 1]]
                best_option = np.argmax(options)
                V[t, ngram] = options[best_option]
                backpointer[t] = candidate_words[t - self.lm.order + 1][best_option]
        return self.tokenizer.decode([token for _, token in backpointer.items()])
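# Standalone sketch (not part of the original code) of the Poisson channel model used in
# poisson_channel_model above, with the constructor's default channel_prob_param=0.02 as mu:
# the log2 channel probability drops sharply with edit distance, so candidates far from the
# observed word need strong language-model support to survive the beam.
import numpy as np
from scipy.stats import poisson

mu = 0.02  # default channel_prob_param
for ed in range(4):
    # Each additional edit multiplies the Poisson probability by mu/ed, i.e. a large log2 penalty.
    print(ed, np.log2(poisson.pmf(k=ed, mu=mu)))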