def dump_outputs(src_ids, gold_ids, predicted_ids, gold_tok_dist, id2tok, out_file,
                 pred_dists=None):
    out_hits = []
    preds_for_bleu = []
    golds_for_bleu = []
    srcs_for_bleu = []
    if pred_dists is None:
        pred_dists = [''] * len(src_ids)
    for src_seq, gold_seq, pred_seq, gold_dist, pred_dist in zip(
            src_ids, gold_ids, predicted_ids, gold_tok_dist, pred_dists):
        src_seq = [id2tok[x] for x in src_seq]
        gold_seq = [id2tok[x] for x in gold_seq]
        pred_seq = [id2tok[x] for x in pred_seq[1:]]   # ignore start token
        if '止' in gold_seq:
            gold_seq = gold_seq[:gold_seq.index('止')]
        if '止' in pred_seq:
            pred_seq = pred_seq[:pred_seq.index('止')]

        gold_replace = [chunk for tag, chunk in diff(src_seq, gold_seq) if tag == '+']
        pred_replace = [chunk for tag, chunk in diff(src_seq, pred_seq) if tag == '+']

        src_seq = ' '.join(src_seq).replace('[PAD]', '').strip()
        gold_seq = ' '.join(gold_seq).replace('[PAD]', '').strip()
        pred_seq = ' '.join(pred_seq).replace('[PAD]', '').strip()

        # try:
        if out_file is not None:
            print('#' * 80, file=out_file)
            print('IN SEQ: \t', src_seq.encode('utf-8'), file=out_file)
            print('GOLD SEQ: \t', gold_seq.encode('utf-8'), file=out_file)
            print('PRED SEQ:\t', pred_seq.encode('utf-8'), file=out_file)
            print('GOLD DIST: \t', list(gold_dist), file=out_file)
            print('PRED DIST: \t', list(pred_dist), file=out_file)
            print('GOLD TOK: \t', list(gold_replace), file=out_file)
            print('PRED TOK: \t', list(pred_replace), file=out_file)
        # except UnicodeEncodeError:
        #     pass

        if gold_seq == pred_seq:
            out_hits.append(1)
        else:
            out_hits.append(0)

        preds_for_bleu.append(pred_seq.split())
        golds_for_bleu.append(gold_seq.split())
        srcs_for_bleu.append(src_seq.split())

    return out_hits, preds_for_bleu, golds_for_bleu, srcs_for_bleu
def should_keep(prev_raw, prev_tok, post_raw, post_tok, bleu, rev_id):
    global CTR_LOW_BLEU
    global CTR_LOW_LEVEN
    global CTR_TOO_MANY_1_TOKS
    global CTR_SPELLING
    global CTR_CHEMISTRY
    global CTR_ONLY_PUNC_CHANGED

    # KEEP -- exact match
    if bleu == 100 or prev_raw == post_raw:
        return True, None, '0', ['0' for _ in range(len(prev_tok.split()))]

    # clearly not a match
    if bleu < 15.0:
        CTR_LOW_BLEU += 1
        return False, None, None, None

    # too close
    if Levenshtein.distance(prev_tok, post_tok) < 4:
        CTR_LOW_LEVEN += 1
        return False, None, None, None

    tok_diff = diff(prev_tok.split(), post_tok.split())
    tok_labels = get_tok_labels(tok_diff)
    assert len(tok_labels) == len(prev_tok.split())

    changed_text = ''.join([''.join(chunk) for tag, chunk in tok_diff if tag != '='])
    if not re.search('[a-z]', changed_text):
        CTR_ONLY_PUNC_CHANGED += 1
        return False, None, None, None

    # too dissimilar -- less than half of toks shared
    tok_nums = [int(x) for x in tok_labels]
    if (sum(tok_nums) * 1.0 / len(tok_nums)) > 0.5:
        CTR_TOO_MANY_1_TOKS += 1
        return False, None, None, None

    # edit was just fixing a spelling error
    word_diff = diff(word_tokenize(prev_raw), word_tokenize(post_raw))
    if is_spelling_diff(word_diff):
        CTR_SPELLING += 1
        return False, None, None, None

    # some simple filtering to get out the chemistry "neutral" edits
    if ' molecules' in prev_raw or ' ions' in prev_raw or ' ionic' in prev_raw or ' atoms' in prev_raw:
        CTR_CHEMISTRY += 1
        return False, None, None, None

    single_word_edit = sum([len(chunk) for tag, chunk in word_diff if tag == '-']) == 1

    return True, single_word_edit, '1', tok_labels
def telegram_webhook():
    update = request.get_json()
    if "message" in update:
        m = update["message"]["text"]
        # ms = m.split(' ')
        chat = update["message"]["chat"]
        chat_id = chat["id"]
        skribajxo = Skribajxo.get()
        if m == '/g':
            bot.sendMessage(chat_id, skribajxo.enhavo)
        elif m == '/start':
            # translation: "A new chat was started with me!"
            bot.sendMessage(170378225, u'یک گپ تازه با من شروع شد!')
            # translation: "Hello. To get the text with the latest changes send /g and copy it;
            # change at most 7 characters (letters, spaces, etc.) and send it back."
            bot.sendMessage(chat_id, 'سلام.\n برای دریافت متن با آخرین تغییرات بزنید /g و متن را کپی کنید؛ حداکثر ۷ نویسه (حرف، فاصله و…) را تغییر دهید و بفرستید.')
        elif chat_id == 170378225 and m == '/uzantoj':
            Uzantoj = Uzanto.select().order_by(Uzanto.id)
            uzantoj = ''
            for uzanto in Uzantoj:
                uzantoj += '>>> ' + str(uzanto.tid) + ': ' + uzanto.nomo + ':' + uzanto.familio + ': @' + uzanto.uzantnomo + ': ' + str(uzanto.kontribuinta) + '\n'
            uzantoj += '---------\n' + str(Uzantoj.count())
            bot.sendMessage(170378225, uzantoj)
        else:
            try:
                uzanto = Uzanto.get(Uzanto.tid == chat_id)
            except Uzanto.DoesNotExist:
                chat.setdefault('username', '')
                chat.setdefault('first_name', '')
                chat.setdefault('last_name', '')
                uzanto = Uzanto.create(tid=chat['id'],
                                       uzantnomo=chat['username'],
                                       nomo=chat['first_name'],
                                       familio=chat['last_name'])
            if (datetime.datetime.now() - uzanto.lastaredakto).seconds >= datetime.timedelta(seconds=30).seconds:
                if ':\n' in m:
                    lasta_duponktoj = m.rindex(':\n')
                    m = m[lasta_duponktoj + 2:]
                diferencoj = diff(skribajxo.enhavo, m)
                diferenco_nombro = 0
                for diferenco in diferencoj:
                    if diferenco[0] != '=':
                        diferenco_nombro += len(diferenco[1])
                # bot.sendMessage(chat_id, str(diferenco_nombro)+'\n'+str(diferencoj))
                if diferenco_nombro < 8:
                    skribajxo.enhavo = m
                    skribajxo.save()
                    uzanto.lastaredakto = datetime.datetime.now()
                    uzanto.kontribuinta += diferenco_nombro
                    uzanto.save()
                    # translation: "The changes were applied successfully!"
                    bot.sendMessage(chat_id, 'تغییرات با موفقیت انجام شد!')
                else:
                    # translation: "At most 7 characters (letters, spaces, etc.) may change!
                    # Make sure your name is not at the start of the message; your copy of the
                    # text may be outdated. To get the latest text, send /g"
                    bot.sendMessage(chat_id, 'حداکثر ۷ نویسه (حرف، فاصله و…) باید تغییر کند! دقت کنید که نامتان در ابتدای پیام ارسالی نباشد. ممکن است متنی که دارید قدیمی باشد. برای دریافت متن با آخرین تغییرات بزنید /g')
            else:
                # translation: "Wait {} more seconds!"
                bot.sendMessage(chat_id, '{} ثانیهٔ دیگر منتظر بمانید!'.format(persi(30 - (datetime.datetime.now() - uzanto.lastaredakto).seconds)))
    return "OK"
def diff_changed(old, new):
    '''
    Returns the differences based on characters between two strings,
    wrapped with DIFFON and DIFFOFF using `diff`.
    '''
    con = {'=': (lambda x: x),
           '+': (lambda x: DIFFON + x + DIFFOFF),
           '-': (lambda x: '')}
    return "".join([(con[a])("".join(b)) for a, b in diff(old, new)])
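# A minimal usage sketch for diff_changed above, assuming `simplediff` is installed.
# DIFFON/DIFFOFF are module-level markers in the original code; the values below are
# placeholders chosen just for this example.
from simplediff import diff

DIFFON, DIFFOFF = '[', ']'
print(diff_changed('abcd', 'abxd'))
# ab[x]d  -- the deleted 'c' is dropped and the inserted 'x' is wrapped in the markers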
def punct_diff(a, b):
    d = diff(a.split(), b.split())
    changed_text = ''.join(
        [''.join(chunk).strip() for tag, chunk in d if tag != '='])
    if not re.search('[a-z]', changed_text):
        return True
    elif re.sub(r'[^\w\s]', '', a) == re.sub(r'[^\w\s]', '', b):
        return True
    return False
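# A small usage sketch for punct_diff (made-up sentences): it returns True when the
# two strings differ only in punctuation, and False otherwise.
import re
from simplediff import diff

print(punct_diff('hello , world', 'hello world'))  # True  (only ',' changed)
print(punct_diff('a good cat', 'a bad cat'))       # False (a word changed)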
def correct_url_list(self, url_list, parent_url):
    correct_url_list = []
    parent_parsed = urlparse(parent_url)
    for url in url_list:
        url_parsed = urlparse(url)
        # netloc chunks that were added or removed relative to the parent URL
        dif = set([d[1] for d in diff(parent_parsed.netloc, url_parsed.netloc)
                   if d[0] in ('+', '-')])
        if url_parsed.netloc == '':
            url = '{uri.scheme}://{uri.netloc}'.format(uri=parent_parsed) + "/" + url_parsed.path
            correct_url_list.append(url)
        elif url_parsed.netloc == parent_parsed.netloc or dif == {'www.'}:
            correct_url_list.append(url)
    return correct_url_list
def get_tok_labels(s1_toks, s2_toks):
    s_diff = diff(s1_toks, s2_toks)
    tok_labels = []
    for tag, chunk in s_diff:
        if tag == '=':
            tok_labels += ['0'] * len(chunk)
        elif tag == '-':
            tok_labels += ['1'] * len(chunk)
        else:
            pass

    assert len(tok_labels) == len(s1_toks)
    return tok_labels
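# A minimal usage sketch for the get_tok_labels helper above (hypothetical token
# lists, assuming `simplediff` is installed): deleted source tokens are labelled
# '1', kept tokens '0', and inserted tokens are ignored.
from simplediff import diff

src = ['the', 'clearly', 'biased', 'sentence']
tgt = ['the', 'sentence']
print(diff(src, tgt))
# [('=', ['the']), ('-', ['clearly', 'biased']), ('=', ['sentence'])]
print(get_tok_labels(src, tgt))
# ['0', '1', '1', '0']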
def word_diff(old, new):
    '''
    Returns the difference between the old and new strings based on words.
    Punctuation is not part of the word.

    Params:
        old     the old string
        new     the new string

    Returns:
        the output of `diff` on the two strings after splitting them on
        non-word characters, keeping the separators as tokens (a list of
        change instructions; see the docstring of `diff`)
    '''
    separator_pattern = r'(\W+)'
    return diff(re.split(separator_pattern, old, flags=re.UNICODE),
                re.split(separator_pattern, new, flags=re.UNICODE))
def word_diff(old, new):
    '''
    Returns the difference between the old and new strings based on words.
    Punctuation is not part of the word.

    Params:
        old     the old string
        new     the new string

    Returns:
        the output of `diff` on the two strings after splitting them on
        non-word characters, keeping the separators as tokens (a list of
        change instructions; see the docstring of `diff`)
    '''
    separator_pattern = r'(\W+)'
    return diff(re.split(separator_pattern, old, flags=re.UNICODE),
                re.split(separator_pattern, new, flags=re.UNICODE))
def list_inline_diff(oldlist, newlist, colors=None):
    if not colors:
        colors = init_colors(False)
    diff = simplediff.diff(oldlist, newlist)
    ret = []
    for change, value in diff:
        value = ' '.join(value)
        if change == '=':
            ret.append("'%s'" % value)
        elif change == '+':
            item = '{color_add}+{value}{color_default}'.format(value=value, **colors)
            ret.append(item)
        elif change == '-':
            item = '{color_remove}-{value}{color_default}'.format(value=value, **colors)
            ret.append(item)
    return '[%s]' % (', '.join(ret))
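# A small usage sketch for list_inline_diff with plain (uncolored) output; the
# dict below merely stands in for whatever init_colors(False) would return.
import simplediff

plain = {'color_add': '', 'color_remove': '', 'color_default': ''}
print(list_inline_diff(['a', 'b'], ['a', 'c'], colors=plain))
# ['a', -b, +c]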
def split_with_diff(src_lines, tgt_lines):
    content = []
    src_attr = []
    tgt_attr = []
    for src, tgt in zip(src_lines, tgt_lines):
        sent_diff = diff(src, tgt)
        tok_collector = defaultdict(list)
        for source, chunk in sent_diff:
            tok_collector[source] += chunk
        content.append(tok_collector['='][:])
        src_attr.append(tok_collector['-'][:])
        tgt_attr.append(tok_collector['+'][:])
    return content[:], content[:], src_attr, tgt_attr
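# A minimal usage sketch for split_with_diff (hypothetical tokenized lines):
# shared tokens go into `content`, deleted tokens into the source attributes,
# and inserted tokens into the target attributes.
from collections import defaultdict
from simplediff import diff

src_lines = [['a', 'good', 'cat']]
tgt_lines = [['a', 'bad', 'cat']]
print(split_with_diff(src_lines, tgt_lines))
# ([['a', 'cat']], [['a', 'cat']], [['good']], [['bad']])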
def cal_idx(previous_sql, gold_sql):
    pre_cur_id = -1
    gold_match_id = []
    diff_ans = diff(previous_sql, gold_sql)
    for x, y in diff_ans:
        if x == '=':
            for _ in range(len(y)):
                pre_cur_id += 1
                gold_match_id.append(pre_cur_id)
        elif x == '+':
            for _ in range(len(y)):
                gold_match_id.append(-1)
        else:
            for _ in range(len(y)):
                pre_cur_id += 1
    assert len(gold_match_id) == len(gold_sql)
    return gold_match_id
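# A small usage sketch for cal_idx (made-up SQL token lists): each gold token is
# mapped to its index in the previous query, or -1 if it was newly inserted.
from simplediff import diff

previous_sql = ['SELECT', 'name', 'FROM', 't']
gold_sql = ['SELECT', 'age', 'FROM', 't']
print(cal_idx(previous_sql, gold_sql))
# [0, -1, 2, 3]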
def worddiff_str(oldstr, newstr, colors=None):
    if not colors:
        colors = init_colors(False)
    diff = simplediff.diff(oldstr.split(' '), newstr.split(' '))
    ret = []
    for change, value in diff:
        value = ' '.join(value)
        if change == '=':
            ret.append(value)
        elif change == '+':
            item = '{color_add}{{+{value}+}}{color_default}'.format(value=value, **colors)
            ret.append(item)
        elif change == '-':
            item = '{color_remove}[-{value}-]{color_default}'.format(value=value, **colors)
            ret.append(item)
    whitespace_note = ''
    if oldstr != newstr and ' '.join(oldstr.split()) == ' '.join(newstr.split()):
        whitespace_note = ' (whitespace changed)'
    return '"%s"%s' % (' '.join(ret), whitespace_note)
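# A small usage sketch for worddiff_str, again passing a plain stand-in for the
# colors dict that init_colors(False) would normally provide.
import simplediff

plain = {'color_add': '', 'color_remove': '', 'color_default': ''}
print(worddiff_str('the quick fox', 'the slow fox', colors=plain))
# "the [-quick-] {+slow+} fox"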
def is_spelling_diff(prev_sent, post_sent):
    d = diff(word_tokenize(prev_sent), word_tokenize(post_sent))

    # only look at the one-word diffs
    if sum([len(chunk) for tag, chunk in d if tag == '-']) > 1:
        return False

    sp = SpellChecker()
    for i, (tag, words) in enumerate(d):
        # is one-word spelling replacement
        if tag == '-' and \
                i + 1 < len(d) - 1 and \
                len(words) == 1 and \
                d[i + 1][0] == '+' and \
                not sp.correction(words[0]) == words[0] and \
                sp.correction(words[0]) in ' '.join(d[i + 1][1]):
            return True

    return False
def test_character_diff(self):
    strings = TESTS['character']
    for check in strings:
        self.assertEqual(simplediff.diff(check['old'], check['new']),
                         check['diff'])
def get_examples(data_path, tok2id, max_seq_len,
                 noise=False, add_del_tok=False,
                 categories_path=None):
    global REL2ID
    global POS2ID
    global EDIT_TYPE2ID
    global ARGS

    if ARGS.drop_words is not None:
        drop_set = set([l.strip() for l in open(ARGS.drop_words)])
    else:
        drop_set = None

    def pad(id_arr, pad_idx):
        return id_arr + ([pad_idx] * (max_seq_len - len(id_arr)))

    skipped = 0
    out = defaultdict(list)

    if categories_path is not None:
        category_fp = open(categories_path)
        next(category_fp)  # ignore header
        revid2topic = {
            l.strip().split(',')[0]: [float(x) for x in l.strip().split(',')[1:]]
            for l in category_fp
        }

    for i, line in enumerate(tqdm(open(data_path))):
        parts = line.strip().split('\t')

        # if there is pos/rel info
        if len(parts) == 7:
            [revid, pre, post, _, _, pos, rels] = parts
        # no pos/rel info
        elif len(parts) == 5:
            [revid, pre, post, _, _] = parts
            pos = ' '.join(['<UNK>'] * len(pre.strip().split()))
            rels = ' '.join(['<UNK>'] * len(pre.strip().split()))
        # broken line
        else:
            skipped += 1
            continue

        # break up tokens
        tokens = pre.strip().split()
        post_tokens = post.strip().split()
        rels = rels.strip().split()
        pos = pos.strip().split()

        # get diff + binary diff masks
        tok_diff = diff(tokens, post_tokens)
        pre_tok_labels, post_tok_labels = get_tok_labels(tok_diff)

        # make sure everything lines up
        if len(tokens) != len(pre_tok_labels) \
                or len(tokens) != len(rels) \
                or len(tokens) != len(pos) \
                or len(post_tokens) != len(post_tok_labels):
            skipped += 1
            continue

        # leave room in the post for start/stop and possible category/class token
        if len(tokens) > max_seq_len - 1 or len(post_tokens) > max_seq_len - 1:
            skipped += 1
            continue

        # category info if provided
        # TODO -- if provided but not in diyi's data, we fill with random...is that ok?
        if categories_path is not None and revid in revid2topic:
            categories = revid2topic[revid]
        else:
            categories = np.random.uniform(size=43)  # 43 = number of categories
            categories = categories / sum(categories)  # normalize

        if ARGS.category_input:
            category_id = np.argmax(categories)
            tokens = ['[unused%d]' % category_id] + tokens
            pre_tok_labels = [EDIT_TYPE2ID['mask']] + pre_tok_labels
            post_tok_labels = [EDIT_TYPE2ID['mask']] + post_tok_labels

        # add start + end symbols to post in/out
        post_input_tokens = ['行'] + post_tokens
        post_output_tokens = post_tokens + ['止']

        # shuffle + convert to ids + pad
        try:
            if noise:
                pre_toks = noise_seq(
                    tokens[:],
                    drop_prob=ARGS.noise_prob,
                    shuf_dist=ARGS.shuf_dist,
                    drop_set=drop_set,
                    keep_bigrams=ARGS.keep_bigrams)
            else:
                pre_toks = tokens

            pre_ids = pad([tok2id[x] for x in pre_toks], 0)
            post_in_ids = pad([tok2id[x] for x in post_input_tokens], 0)
            post_out_ids = pad([tok2id[x] for x in post_output_tokens], 0)
            pre_tok_label_ids = pad(pre_tok_labels, EDIT_TYPE2ID['mask'])
            post_tok_label_ids = pad(post_tok_labels, EDIT_TYPE2ID['mask'])
            rel_ids = pad([REL2ID.get(x, REL2ID['<UNK>']) for x in rels], 0)
            pos_ids = pad([POS2ID.get(x, POS2ID['<UNK>']) for x in pos], 0)
        except KeyError:
            # TODO F**K THIS ENCODING BUG!!!
            skipped += 1
            continue

        input_mask = pad([0] * len(tokens), 1)
        pre_len = len(tokens)

        out['pre_ids'].append(pre_ids)
        out['pre_masks'].append(input_mask)
        out['pre_lens'].append(pre_len)
        out['post_in_ids'].append(post_in_ids)
        out['post_out_ids'].append(post_out_ids)
        out['pre_tok_label_ids'].append(pre_tok_label_ids)
        out['post_tok_label_ids'].append(post_tok_label_ids)
        out['rel_ids'].append(rel_ids)
        out['pos_ids'].append(pos_ids)
        out['categories'].append(categories)

    print('SKIPPED ', skipped)
    return out
def test_words_diff(self):
    strings = TESTS['words']
    for check in strings:
        self.assertEqual(simplediff.diff(check['old'], check['new']),
                         check['diff'])
def test_character_diff(self):
    strings = TESTS["character"]
    for check in strings:
        self.assertEqual(simplediff.diff(check["old"], check["new"]),
                         check["diff"])
def test_words_diff(self):
    strings = TESTS["words"]
    for check in strings:
        self.assertEqual(simplediff.diff(check["old"], check["new"]),
                         check["diff"])
new_element = new_tree.xpath( old_path )

# Check for a one-to-one path relationship.
if len( new_element ) != 1:
    # If the list is empty, the element is missing.
    # Otherwise, the path matches more than one element, which means elements have been added.
    if len( new_element ) == 0:
        print bill_id, "* Element missing:", old_path
    else:
        print bill_id, old_path, old_element, new_element
else:
    old_element_keys = sorted( old_element.keys() )
    new_element_keys = sorted( new_element[0].keys() )

    key_diff = diff( old_element_keys, new_element_keys )

    # Check for attribute changes.
    for state, attributes in key_diff:
        if state != '=':
            print bill_id, "* Element has difference in attributes:", old_path

            # Check for missing or added attributes.
            if state == "-":
                for attribute in attributes:
                    print bill_id, "** Attribute missing:", attribute + '="' + old_element.attrib[attribute] + '"'
            elif state == "+":
                for attribute in attributes:
                    print bill_id, "** Attribute added:", attribute + '="' + new_element[0].attrib[attribute] + '"'
        else:
            # Check for attribute value changes.
def get_examples(data_path, tok2id, max_seq_len,
                 noise=False, add_del_tok=False,
                 categories_path=None):
    global REL2ID
    global POS2ID
    global EDIT_TYPE2ID
    global ARGS

    if ARGS.drop_words is not None:
        drop_set = set([l.strip() for l in open(ARGS.drop_words)])
    else:
        drop_set = None

    # pad_idx: token-based padding, just need the pad token id
    def pad(id_arr, pad_idx):
        return id_arr + ([pad_idx] * (max_seq_len - len(id_arr)))

    skipped = 0
    out = defaultdict(list)

    if categories_path is not None:
        category_fp = open(categories_path)
        next(category_fp)  # ignore header

        ## previous
        # topic is a vector representation with different semantic meanings
        revid2topic = {
            l.strip().split(',')[0]: [float(x) for x in l.strip().split(',')[1:]]
            for l in category_fp
        }
        #
        ## Bing's opinion
        # TODO check the existing code, it should bypass the previous code
        # for l in category_fp:
        #     revid2topic = {
        #         l.strip().split(',')[0]: [float(x) for x in l.strip().split(',')[1:]]
        #     }

    # added the encoding setting (encoding="utf-8") by bing
    for i, line in enumerate(tqdm(open(data_path, encoding="utf-8"))):
        # iterate over all lines to build a defaultdict called `out`
        parts = line.strip().split('\t')

        # if there is pos/rel info
        if len(parts) == 7:
            [revid, pre, post, _, _, pos, rels] = parts
        # no pos/rel info
        elif len(parts) == 5:
            [revid, pre, post, _, _] = parts
            pos = ' '.join(['<UNK>'] * len(pre.strip().split()))
            rels = ' '.join(['<UNK>'] * len(pre.strip().split()))
        # broken line
        else:
            skipped += 1
            continue

        # break up tokens
        # by default, we use whitespace to split
        tokens = pre.strip().split()
        post_tokens = post.strip().split()
        rels = rels.strip().split()
        pos = pos.strip().split()

        # get diff + binary diff masks
        # print("tokens: ", tokens)
        # print("post_tokens: ", post_tokens)
        """ example setting
        a = ['ch', '##lor', '##of', '##or', '##m']
        b = ['ch', '##lor', '##of', '##or', 'new', 'things']
        diff(a, b)
        Out[19]:
        [('=', ['ch', '##lor', '##of', '##or']),
         ('-', ['##m']),
         ('+', ['new', 'things'])]
        """
        tok_diff = diff(tokens, post_tokens)
        pre_tok_labels, post_tok_labels = get_tok_labels(tok_diff)
        # print("pre_tok_labels", pre_tok_labels)
        # print("post_tok_labels", post_tok_labels)
        # exit(0)

        # make sure everything lines up
        # double check the data: bing --- they only keep examples whose lengths line up;
        # checking len(tokens) != len(pre_tok_labels) and len(post_tokens) != len(post_tok_labels)
        # should be useless? in my opinion
        if len(tokens) != len(pre_tok_labels) \
                or len(tokens) != len(rels) \
                or len(tokens) != len(pos) \
                or len(post_tokens) != len(post_tok_labels):
            skipped += 1
            continue

        # leave room in the post for start/stop and possible category/class token
        if len(tokens) > max_seq_len - 1 or len(post_tokens) > max_seq_len - 1:
            skipped += 1
            continue

        # category info if provided
        # TODO -- if provided but not in diyi's data, we fill with random...is that ok?
        if categories_path is not None and revid in revid2topic:
            categories = revid2topic[revid]
        else:
            categories = np.random.uniform(size=43)  # 43 = number of categories
            categories = categories / sum(categories)  # normalize

        if ARGS.category_input:
            category_id = np.argmax(categories)
            tokens = ['[unused%d]' % category_id] + tokens
            pre_tok_labels = [EDIT_TYPE2ID['mask']] + pre_tok_labels
            post_tok_labels = [EDIT_TYPE2ID['mask']] + post_tok_labels

        # add start + end symbols to post in/out
        # any usage for this Chinese character? TODO check by bing
        post_input_tokens = ['行'] + post_tokens
        post_output_tokens = post_tokens + ['止']

        # shuffle + convert to ids + pad
        try:
            # TODO check the possible usage of noise; in our setting we do not use it (by bing)
            if noise:
                pre_toks = noise_seq(
                    tokens[:],
                    drop_prob=ARGS.noise_prob,
                    shuf_dist=ARGS.shuf_dist,
                    drop_set=drop_set,
                    keep_bigrams=ARGS.keep_bigrams)
            else:
                pre_toks = tokens

            # by default, max 80, we need to pad
            pre_ids = pad([tok2id[x] for x in pre_toks], 0)  # not by zero, but [PAD] format
            # TODO: why two post sequences? by bing
            post_in_ids = pad([tok2id[x] for x in post_input_tokens], 0)
            post_out_ids = pad([tok2id[x] for x in post_output_tokens], 0)
            # label setting and processing
            pre_tok_label_ids = pad(pre_tok_labels, EDIT_TYPE2ID['mask'])
            post_tok_label_ids = pad(post_tok_labels, EDIT_TYPE2ID['mask'])
            # relation lookup: REL2ID.get(x, REL2ID['<UNK>']) --- if not found, return the UNK id
            rel_ids = pad([REL2ID.get(x, REL2ID['<UNK>']) for x in rels], 0)
            pos_ids = pad([POS2ID.get(x, POS2ID['<UNK>']) for x in pos], 0)
        except KeyError:
            # TODO F**K THIS ENCODING BUG!!!
            skipped += 1
            continue

        input_mask = pad([0] * len(tokens), 1)
        pre_len = len(tokens)

        # TODO: why do we set so many masks for this processing? by bing
        out['pre_ids'].append(pre_ids)
        out['pre_masks'].append(input_mask)
        out['pre_lens'].append(pre_len)
        out['post_in_ids'].append(post_in_ids)
        out['post_out_ids'].append(post_out_ids)
        out['pre_tok_label_ids'].append(pre_tok_label_ids)
        out['post_tok_label_ids'].append(post_tok_label_ids)
        out['rel_ids'].append(rel_ids)
        out['pos_ids'].append(pos_ids)
        out['categories'].append(categories)

    print('SKIPPED ', skipped)
    return out
def word_diff(old, new):
    """ similar to simplediff.string_diff, but keeping newlines """
    return simplediff.diff(old.split(' '), new.split(' '))
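# A quick usage sketch for this word_diff variant: splitting on single spaces
# (rather than arbitrary whitespace) keeps newline characters attached to tokens.
import simplediff

print(word_diff('one two', 'one three'))
# [('=', ['one']), ('-', ['two']), ('+', ['three'])]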
python test.py ../raw/tst.biased > ../raw/tst.tokbiased
"""
import sys

from simplediff import diff

i = 0
for l in open(sys.argv[1]):
    parts = l.strip().split('\t')
    if len(parts) != 7:
        continue

    pre_tok = parts[1].split()
    post_tok = parts[2].split()

    d = diff(pre_tok, post_tok)
    old = [x for x in d if x[0] == '-']
    new = [x for x in d if x[0] == '+']
    # print(diff(l1.strip().split(), l2.strip().split()))
    if len(old) == 1 and len(old[0][1]) == 1:
        if not new:
            print(l.strip())
            i += 1
    if len(new) == 1 and len(new[0][1]) == 1:
        print(l.strip())
        i += 1
def should_keep(prev_raw, prev_tok, post_raw, post_tok, bleu, rev_id):
    global CTR_LOW_BLEU
    global CTR_LOW_LEVEN
    global CTR_TOO_MANY_1_TOKS
    global CTR_SPELLING
    global CTR_CHEMISTRY
    global CTR_ONLY_PUNC_CHANGED

    # KEEP -- exact match
    if bleu == 100 or prev_raw == post_raw:
        return True, None, [0 for _ in range(len(prev_tok.split()))]

    # clearly not a match
    if bleu < 15.0:
        CTR_LOW_BLEU += 1
        return False, None, None

    # too close
    if Levenshtein.distance(prev_tok, post_tok) < 4:
        CTR_LOW_LEVEN += 1
        return False, None, None

    tok_diff = diff(prev_tok.split(), post_tok.split())
    tok_labels = get_tok_labels(tok_diff)
    assert len(tok_labels) == len(prev_tok.split())

    changed_text = ''.join(
        [''.join(chunk) for tag, chunk in tok_diff if tag != '='])
    if not re.search('[a-z]', changed_text):
        CTR_ONLY_PUNC_CHANGED += 1
        return False, None, None

    # too dissimilar -- less than half of toks shared
    tok_nums = [int(x) for x in tok_labels]
    if (sum(tok_nums) * 1.0 / len(tok_nums)) > 0.5:
        CTR_TOO_MANY_1_TOKS += 1
        return False, None, None

    # edit was just fixing a spelling error
    word_diff = diff(word_tokenize(prev_raw), word_tokenize(post_raw))
    if is_spelling_diff(word_diff):
        CTR_SPELLING += 1
        return False, None, None

    # some simple filtering to get out the chemistry "neutral" edits
    if ' molecules' in prev_raw or ' ions' in prev_raw or ' ionic' in prev_raw or ' atoms' in prev_raw:
        CTR_CHEMISTRY += 1
        return False, None, None

    # # use enchant to make sure example has enough normal words
    # prev_words = prev_words.translate(str.maketrans('', '', string.punctuation)).split()
    # n_words = sum(1 if d.check(w) else 0 for w in pre_words)
    # if len(prev_words) == 0 or (float(n_words) / len(prev_words)) < 0.5:
    #     return False, None, None

    # see if this is a "single word" edit, where a single word was replaced with 0+ words
    def is_single_word_edit(d):
        """ is this diff good for the final generation dataset """
        pre_chunks = [chunk for tag, chunk in d if tag == '-']
        post_chunks = [chunk for tag, chunk in d if tag == '+']
        # a single word on the pre side
        if sum([len(chunk) for chunk in pre_chunks]) != 1:
            return False
        # 0 words in the post
        if len(post_chunks) == 0:
            return True
        # ensure 1 post chunk
        if len(post_chunks) > 1:
            return False
        # post language chunk is directly after the pre chunk
        prei = next((i for i, x in enumerate(d) if x[0] == '-'))
        if prei < len(d) - 1 and d[prei + 1][0] == '+':
            return True

    single_word_edit = is_single_word_edit(word_diff)

    return True, single_word_edit, tok_labels
def get_examples(dataset_params, data_path, tok2id, max_seq_len,
                 noise=False, add_del_tok=False, categories_path=None,
                 convert_to_tensors=True, no_bias_type_labels=False):
    global REL2ID
    global POS2ID
    global EDIT_TYPE2ID

    if dataset_params['drop_words'] is not None:
        drop_set = set([l.strip() for l in open(dataset_params['drop_words'])])
    else:
        drop_set = None

    def pad(id_arr, pad_idx):
        return id_arr + ([pad_idx] * (max_seq_len - len(id_arr)))

    skipped = 0
    out = defaultdict(list)

    if categories_path is not None:
        category_fp = open(categories_path)
        next(category_fp)  # ignore header
        revid2topic = {
            l.strip().split(',')[0]: [float(x) for x in l.strip().split(',')[1:]]
            for l in category_fp
        }

    # NOTE: we require that the TSV begins with some header
    data_file = open(data_path)
    header_line = next(data_file).strip().split('\t')
    assert header_line[0] == 'id', "data file required to contain a header."

    for i, line in enumerate(tqdm(data_file)):
        parts = line.strip().split('\t')

        if no_bias_type_labels:
            if len(parts) == 5:
                [revid, pre, post, pos, rels] = parts
            elif len(parts) == 7:
                [revid, pre, post, _, _, pos, rels] = parts
            else:
                skipped += 1
                continue
        else:
            if len(parts) == 7:
                [revid, pre, post, pos, rels, epistemological, framing] = parts
            elif len(parts) == 9:
                [revid, pre, post, _, _, pos, rels, epistemological, framing] = parts
            else:
                skipped += 1
                continue

        # break up tokens
        tokens = pre.strip().split()
        post_tokens = post.strip().split()
        rels = rels.strip().split()
        pos = pos.strip().split()

        # get diff + binary diff masks
        tok_diff = diff(tokens, post_tokens)
        pre_tok_labels, post_tok_labels = get_tok_labels(tok_diff)

        # make sure everything lines up
        if len(tokens) != len(pre_tok_labels) \
                or len(tokens) != len(rels) \
                or len(tokens) != len(pos) \
                or len(post_tokens) != len(post_tok_labels):
            skipped += 1
            continue

        # leave room in the post for start/stop and possible category/class token
        if len(tokens) > max_seq_len - 1 or len(post_tokens) > max_seq_len - 1:
            skipped += 1
            continue

        # category info if provided
        # TODO -- if provided but not in diyi's data, we fill with random...is that ok?
        if categories_path is not None and revid in revid2topic:
            categories = revid2topic[revid]
        else:
            categories = np.random.uniform(size=43)  # 43 = number of categories
            categories = categories / sum(categories)  # normalize

        if dataset_params['category_input']:
            category_id = np.argmax(categories)
            tokens = ['[unused%d]' % category_id] + tokens
            pre_tok_labels = [EDIT_TYPE2ID['mask']] + pre_tok_labels
            post_tok_labels = [EDIT_TYPE2ID['mask']] + post_tok_labels

        # add start + end symbols to post in/out
        post_input_tokens = ['行'] + post_tokens
        post_output_tokens = post_tokens + ['止']

        # shuffle + convert to ids + pad
        try:
            if noise:
                pre_toks = noise_seq(
                    tokens[:],
                    drop_prob=dataset_params['noise_prob'],
                    shuf_dist=dataset_params['shuf_dist'],
                    drop_set=drop_set,
                    keep_bigrams=dataset_params['keep_bigrams'])
            else:
                pre_toks = tokens

            pre_ids = pad([tok2id[x] for x in pre_toks], 0)
            post_in_ids = pad([tok2id[x] for x in post_input_tokens], 0)
            post_out_ids = pad([tok2id[x] for x in post_output_tokens], 0)
            pre_tok_label_ids = pad(pre_tok_labels, EDIT_TYPE2ID['mask'])
            if 1 not in pre_tok_label_ids:
                skipped += 1
                continue
            post_tok_label_ids = pad(post_tok_labels, EDIT_TYPE2ID['mask'])
            rel_ids = pad([REL2ID.get(x, REL2ID['<UNK>']) for x in rels], 0)
            pos_ids = pad([POS2ID.get(x, POS2ID['<UNK>']) for x in pos], 0)
        except KeyError:
            skipped += 1
            continue

        input_mask = pad([0] * len(tokens), 1)
        pre_len = len(tokens)

        # Adding label encodings --> 0: epistemological; 1: framing
        if not no_bias_type_labels:
            try:
                bias_label = 0 if int(epistemological) else 1
                if bias_label == 1:
                    assert int(framing), "Processing error: both epistemological and framing labels are true."
            except KeyError:
                continue

        out['pre_ids'].append(pre_ids)
        out['masks'].append(input_mask)
        out['pre_lens'].append(pre_len)
        out['post_in_ids'].append(post_in_ids)
        out['post_out_ids'].append(post_out_ids)
        out['pre_tok_label_ids'].append(pre_tok_label_ids)
        out['post_tok_label_ids'].append(post_tok_label_ids)
        out['rel_ids'].append(rel_ids)
        out['pos_ids'].append(pos_ids)
        out['categories'].append(categories)
        out['index'].append(i)  # can do some hash thing
        if not no_bias_type_labels:
            out['bias_label'].append(bias_label)

    if convert_to_tensors:
        out['pre_ids'] = torch.tensor(out['pre_ids'], dtype=torch.long)
        out['masks'] = torch.tensor(out['masks'], dtype=torch.uint8)  # byte for masked_fill()
        out['pre_lens'] = torch.tensor(out['pre_lens'], dtype=torch.long)
        out['post_in_ids'] = torch.tensor(out['post_in_ids'], dtype=torch.long)
        out['post_out_ids'] = torch.tensor(out['post_out_ids'], dtype=torch.long)
        out['pre_tok_label_ids'] = torch.tensor(out['pre_tok_label_ids'], dtype=torch.float)  # for comparing to enrichment stuff
        out['post_tok_label_ids'] = torch.tensor(out['post_tok_label_ids'], dtype=torch.float)  # for loss multiplying
        out['rel_ids'] = torch.tensor(out['rel_ids'], dtype=torch.long)
        out['pos_ids'] = torch.tensor(out['pos_ids'], dtype=torch.long)
        out['categories'] = torch.tensor(out['categories'], dtype=torch.float)
        out['index'] = torch.tensor(out['index'], dtype=torch.int32)
        if not no_bias_type_labels:
            out['bias_label'] = torch.tensor(out['bias_label'], dtype=torch.int16)

    return out
def is_single_word_diff(s1, s2):
    s1 = word_tokenize(s1.lower().strip())
    s2 = word_tokenize(s2.lower().strip())
    word_diff = diff(s1, s2)
    return sum([len(chunk) for tag, chunk in word_diff if tag == '-']) == 1
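# A small usage sketch for is_single_word_diff (made-up sentences); this assumes
# nltk is installed and its 'punkt' tokenizer data has been downloaded.
from nltk import word_tokenize
from simplediff import diff

print(is_single_word_diff('He is a great leader .', 'He is a leader .'))    # True
print(is_single_word_diff('He is a great leader .', 'She was a leader .'))  # False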
out_pre = open(out_prefix + '.pre', 'w')
out_post = open(out_prefix + '.post', 'w')

len_skip_del = 0
len_skip_nodel = 0
skip = 0
dels = 0
nondels = 0
filtered = 0
for pre_l, post_l in zip(open(pre_fp), open(post_fp)):
    pre_l = pre_l.strip().split()
    post_l = post_l.strip().split()

    d = diff(pre_l, post_l)
    old = [x for x in d if x[0] == '-']
    new = [x for x in d if x[0] == '+']

    if len(old) == 0:
        skip += 1
        continue

    oldi = next((i for i, x in enumerate(d) if x[0] == '-'))

    if not new:
        if len(pre_l) > 100 or len(post_l) > 100:
            len_skip_del += 1
            continue
        out_pre.write(' '.join(pre_l) + '\n')
        out_post.write(' '.join(post_l) + '\n')
def export_word_edit_matrix(context: List,
                            current_sen: List,
                            label_sen: List,
                            super_mode: str = 'before',
                            # if multiple inserts are required, we only keep the longest one
                            only_one_insert: bool = False):
    if isinstance(context, str):
        context_seq = list(context)
        current_seq = list(current_sen)
        label_seq = list(label_sen)
    else:
        context_seq = context
        current_seq = current_sen
        label_seq = label_sen
    applied_changes = diff(current_seq, label_seq)

    def sub_finder(cus_list, pattern, used_pos):
        find_indices = []
        for i in range(len(cus_list)):
            if cus_list[i] == pattern[0] and \
                    cus_list[i:i + len(pattern)] == pattern \
                    and i not in used_pos:
                find_indices.append((i, i + len(pattern)))
        if len(find_indices) == 0:
            return 0, 0
        else:
            return find_indices[-1]

    def cont_sub_finder(cus_list, pattern, used_pos):
        context_len = len(cus_list)
        pattern_len = len(pattern)
        for i in range(context_len):
            k = i
            j = 0
            temp_indices = []
            while j < pattern_len and k < context_len:
                if cus_list[k] == pattern[j][0] and \
                        cus_list[k:k + len(pattern[j])] == pattern[j] \
                        and k not in used_pos:
                    temp_indices.append((k, k + len(pattern[j])))
                    j += 1
                else:
                    k += 1
            if j == pattern_len:
                return zip(*temp_indices)
        return 0, 0

    rm_range = None
    ret_ops = []
    context_used_pos = []
    current_used_pos = []
    pointer = 0
    for diff_sample in applied_changes:
        diff_op = diff_sample[0]
        diff_content = diff_sample[1]
        if diff_op == '-':
            if rm_range is not None:
                ret_ops.append(['remove', rm_range, []])
            start, end = sub_finder(current_seq, diff_content, current_used_pos)
            rm_range = [start, end]
            current_used_pos.extend(list(range(start, end)))
        elif diff_op == '+':
            start, end = sub_finder(context_seq, diff_content, context_used_pos)
            # cannot find an exact matching substring, so try to identify overlapping snippets
            if start == 0 and end == 0:
                inner_diff = diff(diff_content, context_seq)
                overlap_content = [inner_diff_sample[1] for inner_diff_sample in inner_diff
                                   if inner_diff_sample[0] == '=']
                if len(overlap_content) > 0:
                    # only take one insert
                    if len(overlap_content) == 1 or only_one_insert:
                        overlap_content = sorted(overlap_content,
                                                 key=lambda x: len(x), reverse=True)[0]
                        start, end = sub_finder(context_seq, overlap_content,
                                                context_used_pos)
                    else:
                        start_end_tuple = cont_sub_finder(context_seq, overlap_content,
                                                          context_used_pos)
                        # start is a list of starts, end is a list of ends
                        start, end = start_end_tuple
                else:
                    start, end = 0, 0
            if not (start == 0 and end == 0):
                if isinstance(start, int):
                    add_ranges = [[start, end]]
                else:
                    add_ranges = list(zip(start, end))
                if rm_range is not None:
                    for add_range in add_ranges:
                        context_used_pos.extend(list(range(add_range[0], add_range[1])))
                        ret_ops.append(['replace', rm_range, add_range])
                    rm_range = None
                else:
                    for add_range in add_ranges:
                        if super_mode in ['before', 'both']:
                            ret_ops.append(['before', [pointer, pointer], add_range])
                        if super_mode in ['after', 'both']:
                            if pointer >= 1:
                                ret_ops.append(['after', [pointer - 1, pointer - 1], add_range])
        elif diff_op == '=':
            if rm_range is not None:
                ret_ops.append(['remove', rm_range, []])
            start, end = sub_finder(current_seq, diff_content, current_used_pos)
            current_used_pos.extend(list(range(start, end)))
            rm_range = None
            pointer = end
    return ret_ops