Example #1
def dump_outputs(src_ids,
                 gold_ids,
                 predicted_ids,
                 gold_tok_dist,
                 id2tok,
                 out_file,
                 pred_dists=None):
    out_hits = []
    preds_for_bleu = []
    golds_for_bleu = []
    srcs_for_bleu = []

    if pred_dists is None:
        pred_dists = [''] * len(src_ids)
    for src_seq, gold_seq, pred_seq, gold_dist, pred_dist in zip(
            src_ids, gold_ids, predicted_ids, gold_tok_dist, pred_dists):

        src_seq = [id2tok[x] for x in src_seq]
        gold_seq = [id2tok[x] for x in gold_seq]
        pred_seq = [id2tok[x] for x in pred_seq[1:]]  # ignore start token
        if '止' in gold_seq:
            gold_seq = gold_seq[:gold_seq.index('止')]
        if '止' in pred_seq:
            pred_seq = pred_seq[:pred_seq.index('止')]

        gold_replace = [
            chunk for tag, chunk in diff(src_seq, gold_seq) if tag == '+'
        ]
        pred_replace = [
            chunk for tag, chunk in diff(src_seq, pred_seq) if tag == '+'
        ]

        src_seq = ' '.join(src_seq).replace('[PAD]', '').strip()
        gold_seq = ' '.join(gold_seq).replace('[PAD]', '').strip()
        pred_seq = ' '.join(pred_seq).replace('[PAD]', '').strip()

        # try:
        if out_file is not None:
            print('#' * 80, file=out_file)
            print('IN SEQ: \t', src_seq.encode('utf-8'), file=out_file)
            print('GOLD SEQ: \t', gold_seq.encode('utf-8'), file=out_file)
            print('PRED SEQ:\t', pred_seq.encode('utf-8'), file=out_file)
            print('GOLD DIST: \t', list(gold_dist), file=out_file)
            print('PRED DIST: \t', list(pred_dist), file=out_file)
            print('GOLD TOK: \t', list(gold_replace), file=out_file)
            print('PRED TOK: \t', list(pred_replace), file=out_file)
        # except UnicodeEncodeError:
        #     pass

        if gold_seq == pred_seq:
            out_hits.append(1)
        else:
            out_hits.append(0)

        preds_for_bleu.append(pred_seq.split())
        golds_for_bleu.append(gold_seq.split())
        srcs_for_bleu.append(src_seq.split())

    return out_hits, preds_for_bleu, golds_for_bleu, srcs_for_bleu
Example #2
def should_keep(prev_raw, prev_tok, post_raw, post_tok, bleu, rev_id):
    global CTR_LOW_BLEU
    global CTR_LOW_LEVEN
    global CTR_TOO_MANY_1_TOKS
    global CTR_SPELLING
    global CTR_CHEMISTRY
    global CTR_ONLY_PUNC_CHANGED

    # KEEP -- exact match
    if bleu == 100 or prev_raw == post_raw:
        return True, None, '0', ['0' for _ in range(len(prev_tok.split()))]
    # clearly not a match
    if bleu < 15.0:
        CTR_LOW_BLEU += 1
        return False, None, None, None
    # too close
    if Levenshtein.distance(prev_tok, post_tok) < 4:
        CTR_LOW_LEVEN += 1
        return False, None, None, None

    tok_diff = diff(prev_tok.split(), post_tok.split())
    tok_labels = get_tok_labels(tok_diff)
    assert len(tok_labels) == len(prev_tok.split())

    changed_text = ''.join([''.join(chunk) for tag, chunk in tok_diff if tag != '='])
    if not re.search('[a-z]', changed_text):
        CTR_ONLY_PUNC_CHANGED += 1
        return False, None, None, None  

    # too dissimilar -- less than half of toks shared
    tok_nums = [int(x) for x in tok_labels]
    if ( sum(tok_nums) * 1.0 / len(tok_nums) ) > 0.5:
        CTR_TOO_MANY_1_TOKS += 1
        return False, None, None, None  

    # edit was just fixing a spelling error
    word_diff = diff(word_tokenize(prev_raw), word_tokenize(post_raw))
    if is_spelling_diff(word_diff):
        CTR_SPELLING += 1
        return False, None, None, None

    # some simple filtering to get out the chemistry "neutral" edits
    if ' molecules' in prev_raw or ' ions' in prev_raw or ' ionic' in prev_raw or ' atoms' in prev_raw:
        CTR_CHEMISTRY += 1
        return False, None, None, None


    single_word_edit = sum([len(chunk) for tag, chunk in word_diff if tag == '-']) == 1

    return True, single_word_edit, '1', tok_labels
Example #3
def telegram_webhook():
    update = request.get_json()
    if "message" in update:
        m = update["message"]["text"]
        #ms = m.split(' ')
        chat = update["message"]["chat"]
        chat_id = chat["id"]
        skribajxo = Skribajxo.get()
        if m == '/g':
            bot.sendMessage(chat_id, skribajxo.enhavo)
        elif m == '/start':
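            # The two replies below say (Persian): to the admin, "A new chat with me has just started!";
            # to the user, "Hello. To get the text with the latest changes send /g and copy it;
            # change at most 7 characters (letters, spaces, etc.) and send it back."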
            bot.sendMessage(170378225, u'یک گپ تازه با من شروع شد!')
            bot.sendMessage(chat_id, 'سلام.\n برای دریافت متن با آخرین تغییرات بزنید /g و متن را کپی کنید؛ حداکثر ۷ نویسه (حرف، فاصله و…) را تغییر دهید و بفرستید.')
        elif chat_id == 170378225 and  m == '/uzantoj':
            Uzantoj = Uzanto.select().order_by(Uzanto.id)
            uzantoj = ''
            for uzanto in Uzantoj:
                uzantoj += '>>> '+str(uzanto.tid)+': '+uzanto.nomo+':'+uzanto.familio+': @'+uzanto.uzantnomo+': '+str(uzanto.kontribuinta)+'\n'
            uzantoj += '---------\n'+str(Uzantoj.count())
            bot.sendMessage(170378225, uzantoj)
        else:
            try:
                uzanto = Uzanto.get(Uzanto.tid == chat_id)
            except Uzanto.DoesNotExist:
                # fill in any chat fields Telegram may have omitted
                chat.setdefault('username', '')
                chat.setdefault('first_name', '')
                chat.setdefault('last_name', '')
                uzanto = Uzanto.create(tid=chat['id'], uzantnomo=chat['username'], nomo=chat['first_name'], familio=chat['last_name'])
            if (datetime.datetime.now() - uzanto.lastaredakto).seconds >= datetime.timedelta(seconds=30).seconds:
                if ':\n' in m:
                    lasta_duponktoj = m.rindex(':\n')
                    m = m[lasta_duponktoj+2:]
                diferencoj = diff(skribajxo.enhavo, m)
                diferenco_nombro = 0
                for diferenco in diferencoj:
                    if diferenco[0] != '=':
                        diferenco_nombro += len(diferenco[1])
                #bot.sendMessage(chat_id, str(diferenco_nombro)+'\n'+str(diferencoj))
                if diferenco_nombro < 8:
                    skribajxo.enhavo = m
                    skribajxo.save()
                    uzanto.lastaredakto = datetime.datetime.now()
                    uzanto.kontribuinta += diferenco_nombro
                    uzanto.save()
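                    # Persian: "The changes were applied successfully!"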
                    bot.sendMessage(chat_id, 'تغییرات با موفقیت انجام شد!')
                else:
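                    # Persian: "At most 7 characters (letters, spaces, etc.) may be changed! Make sure
                    # your name is not at the start of the message; your copy of the text may be
                    # outdated. Send /g to get the text with the latest changes."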
                    bot.sendMessage(chat_id, 'حداکثر ۷ نویسه (حرف، فاصله و…) باید تغییر کند! دقت کنید که نام‌تان در ابتدای پیام ارسالی نباشد. ممکن است متنی که دارید قدیمی باشد. برای دریافت متن با آخرین تغییرات بزنید /g')
            else:
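                # Persian: "Wait another {} seconds!"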
                bot.sendMessage(chat_id, '{} ثانیهٔ دیگر منتظر بمانید!'.format(persi(30-(datetime.datetime.now() - uzanto.lastaredakto).seconds)))
                    
    return "OK"
Example #4
def diff_changed(old, new):
    '''
    Returns the character-based differences between two strings, with
    added text wrapped in DIFFON and DIFFOFF and removed text dropped, using `diff`.
    '''
    con = {'=': (lambda x: x),
           '+': (lambda x: DIFFON + x + DIFFOFF),
           '-': (lambda x: '')}
    return "".join([(con[a])("".join(b)) for a, b in diff(old, new)])
Example #5
def diff_changed(old, new):
    '''
    Returns the character-based differences between two strings, with
    added text wrapped in DIFFON and DIFFOFF and removed text dropped, using `diff`.
    '''
    con = {'=': (lambda x: x),
           '+': (lambda x: DIFFON + x + DIFFOFF),
           '-': (lambda x: '')}
    return "".join([(con[a])("".join(b)) for a, b in diff(old, new)])
Example #6
def punct_diff(a, b):
    d = diff(a.split(), b.split())
    changed_text = ''.join(
        [''.join(chunk).strip() for tag, chunk in d if tag != '='])
    if not re.search('[a-z]', changed_text):
        return True
    elif re.sub(r'[^\w\s]', '', a) == re.sub(r'[^\w\s]', '', b):
        return True
    return False
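A quick sketch of how this filter behaves on made-up inputs (assuming the imports the function relies on):

import re
from simplediff import diff

print(punct_diff('Hello , world !', 'Hello world'))  # True: only punctuation changed
print(punct_diff('the cat sat', 'the dog sat'))      # False: a word changed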
Example #7
    def correct_url_list(self, url_list, parent_url):
        correct_url_list = []
        parent_parsed = urlparse(parent_url)
        for url in url_list:
            url_parsed = urlparse(url)
            dif = set([d[1] for d in diff(parent_parsed.netloc, url_parsed.netloc) if d[0] in ('+', '-')])
            if url_parsed.netloc == '':
                url = '{uri.scheme}://{uri.netloc}'.format(uri=parent_parsed) + "/" + url_parsed.path
                correct_url_list.append(url)

            elif url_parsed.netloc == parent_parsed.netloc or dif == {'www.'}:  # same host, or differs only by 'www.'
                correct_url_list.append(url)
        return correct_url_list
Example #8
def get_tok_labels(s1_toks, s2_toks):
    s_diff = diff(s1_toks, s2_toks)
    tok_labels = []
    for tag, chunk in s_diff:
        if tag == '=':
            tok_labels += ['0'] * len(chunk)
        elif tag == '-':
            tok_labels += ['1'] * len(chunk)
        else:
            pass
    assert len(tok_labels) == len(s1_toks)

    return tok_labels
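For example, on a small pair of token lists (made up for illustration):

from simplediff import diff

s1 = ['the', 'big', 'red', 'dog']
s2 = ['the', 'red', 'dog', 'barked']
print(diff(s1, s2))
# [('=', ['the']), ('-', ['big']), ('=', ['red', 'dog']), ('+', ['barked'])]
print(get_tok_labels(s1, s2))
# ['0', '1', '0', '0'] -- one label per token of s1; '+' chunks are ignored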
Example #9
def word_diff(old, new):
    '''
    Returns the difference between the old and new strings based on words. Punctuation is not part of the word.

    Params:
        old the old string
        new the new string

    Returns:
        the output of `diff` on the two strings after splitting them on
        runs of non-word characters, which are kept as separate tokens
        (a list of change instructions; see the docstring of `diff`)
    '''
    separator_pattern = r'(\W+)'
    return diff(re.split(separator_pattern, old, flags=re.UNICODE),
                re.split(separator_pattern, new, flags=re.UNICODE))
Example #10
def word_diff(old, new):
    '''
    Returns the difference between the old and new strings based on words. Punctuation is not part of the word.

    Params:
        old the old string
        new the new string

    Returns:
        the output of `diff` on the two strings after splitting them on
        runs of non-word characters, which are kept as separate tokens
        (a list of change instructions; see the docstring of `diff`)
    '''
    separator_pattern = r'(\W+)'
    return diff(re.split(separator_pattern, old, flags=re.UNICODE),
                re.split(separator_pattern, new, flags=re.UNICODE))
Example #11
def list_inline_diff(oldlist, newlist, colors=None):
    if not colors:
        colors = init_colors(False)
    diff = simplediff.diff(oldlist, newlist)
    ret = []
    for change, value in diff:
        value = ' '.join(value)
        if change == '=':
            ret.append("'%s'" % value)
        elif change == '+':
            item = '{color_add}+{value}{color_default}'.format(value=value, **colors)
            ret.append(item)
        elif change == '-':
            item = '{color_remove}-{value}{color_default}'.format(value=value, **colors)
            ret.append(item)
    return '[%s]' % (', '.join(ret))
Example #12
def list_inline_diff(oldlist, newlist, colors=None):
    if not colors:
        colors = init_colors(False)
    diff = simplediff.diff(oldlist, newlist)
    ret = []
    for change, value in diff:
        value = ' '.join(value)
        if change == '=':
            ret.append("'%s'" % value)
        elif change == '+':
            item = '{color_add}+{value}{color_default}'.format(value=value, **colors)
            ret.append(item)
        elif change == '-':
            item = '{color_remove}-{value}{color_default}'.format(value=value, **colors)
            ret.append(item)
    return '[%s]' % (', '.join(ret))
Example #13
def split_with_diff(src_lines, tgt_lines):
    content = []
    src_attr = []
    tgt_attr = []

    for src, tgt in zip(src_lines, tgt_lines):
        sent_diff = diff(src, tgt)
        tok_collector = defaultdict(list)
        for source, chunk in sent_diff:
            tok_collector[source] += chunk

        content.append(tok_collector['='][:])
        src_attr.append(tok_collector['-'][:])
        tgt_attr.append(tok_collector['+'][:])

    return content[:], content[:], src_attr, tgt_attr
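A minimal sketch of the bookkeeping on one aligned pair (hypothetical token lists):

from collections import defaultdict
from simplediff import diff

content, content_copy, src_attr, tgt_attr = split_with_diff(
    [['a', 'b', 'c']], [['a', 'x', 'c']])
# content == content_copy == [['a', 'c']]   (tokens shared by source and target)
# src_attr == [['b']]                       (tokens only in the source)
# tgt_attr == [['x']]                       (tokens only in the target)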
Example #14
def cal_idx(previous_sql, gold_sql):
    pre_cur_id = -1
    gold_match_id = []
    diff_ans = diff(previous_sql, gold_sql)
    for x, y in diff_ans:
        if x == '=':
            for _ in range(len(y)):
                pre_cur_id += 1
                gold_match_id.append(pre_cur_id)
        elif x == '+':
            for _ in range(len(y)):
                gold_match_id.append(-1)
        else:
            for _ in range(len(y)):
                pre_cur_id += 1
    assert len(gold_match_id) == len(gold_sql)
    return gold_match_id
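A worked example of the index mapping on hypothetical SQL token lists:

from simplediff import diff

prev = ['SELECT', 'name', 'FROM', 'users']
gold = ['SELECT', 'age', 'FROM', 'users']
print(cal_idx(prev, gold))
# [0, -1, 2, 3]: each gold token maps to its position in prev,
# and the inserted token ('age') maps to -1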
Example #15
def worddiff_str(oldstr, newstr, colors=None):
    if not colors:
        colors = init_colors(False)
    diff = simplediff.diff(oldstr.split(' '), newstr.split(' '))
    ret = []
    for change, value in diff:
        value = ' '.join(value)
        if change == '=':
            ret.append(value)
        elif change == '+':
            item = '{color_add}{{+{value}+}}{color_default}'.format(value=value, **colors)
            ret.append(item)
        elif change == '-':
            item = '{color_remove}[-{value}-]{color_default}'.format(value=value, **colors)
            ret.append(item)
    whitespace_note = ''
    if oldstr != newstr and ' '.join(oldstr.split()) == ' '.join(newstr.split()):
        whitespace_note = ' (whitespace changed)'
    return '"%s"%s' % (' '.join(ret), whitespace_note)
Example #16
def worddiff_str(oldstr, newstr, colors=None):
    if not colors:
        colors = init_colors(False)
    diff = simplediff.diff(oldstr.split(' '), newstr.split(' '))
    ret = []
    for change, value in diff:
        value = ' '.join(value)
        if change == '=':
            ret.append(value)
        elif change == '+':
            item = '{color_add}{{+{value}+}}{color_default}'.format(value=value, **colors)
            ret.append(item)
        elif change == '-':
            item = '{color_remove}[-{value}-]{color_default}'.format(value=value, **colors)
            ret.append(item)
    whitespace_note = ''
    if oldstr != newstr and ' '.join(oldstr.split()) == ' '.join(newstr.split()):
        whitespace_note = ' (whitespace changed)'
    return '"%s"%s' % (' '.join(ret), whitespace_note)
Example #17
def is_spelling_diff(prev_sent, post_sent):
    d = diff(word_tokenize(prev_sent), word_tokenize(post_sent))

    # only look at the one-word diffs
    if sum([len(chunk) for tag, chunk in d if tag == '-']) > 1:
        return False

    sp = SpellChecker()
    for i, (tag, words) in enumerate(d):
        # is one-word spelling replacement
        if tag == '-' and \
            i+1 < len(d) - 1 and \
            len(words) == 1 and \
            d[i+1][0] == '+' and \
            not sp.correction(words[0]) == words[0] and \
            sp.correction(words[0]) in ' '.join(d[i+1][1]):

            return True

    return False
Example #18
    def test_character_diff(self):
        strings = TESTS['character']

        for check in strings:
            self.assertEqual(simplediff.diff(check['old'], check['new']),
                             check['diff'])
Example #19
def get_examples(data_path,
                 tok2id,
                 max_seq_len,
                 noise=False,
                 add_del_tok=False,
                 categories_path=None):
    global REL2ID
    global POS2ID
    global EDIT_TYPE2ID
    global ARGS

    if ARGS.drop_words is not None:
        drop_set = set([l.strip() for l in open(ARGS.drop_words)])
    else:
        drop_set = None

    def pad(id_arr, pad_idx):
        return id_arr + ([pad_idx] * (max_seq_len - len(id_arr)))

    skipped = 0
    out = defaultdict(list)
    if categories_path is not None:
        category_fp = open(categories_path)
        next(category_fp)  # ignore header
        revid2topic = {
            l.strip().split(',')[0]:
            [float(x) for x in l.strip().split(',')[1:]]
            for l in category_fp
        }
    for i, (line) in enumerate(tqdm(open(data_path))):
        parts = line.strip().split('\t')

        # if there is pos/rel info
        if len(parts) == 7:
            [revid, pre, post, _, _, pos, rels] = parts
        # no pos/rel info
        elif len(parts) == 5:
            [revid, pre, post, _, _] = parts
            pos = ' '.join(['<UNK>'] * len(pre.strip().split()))
            rels = ' '.join(['<UNK>'] * len(pre.strip().split()))
        # broken line
        else:
            skipped += 1
            continue

        # break up tokens
        tokens = pre.strip().split()
        post_tokens = post.strip().split()
        rels = rels.strip().split()
        pos = pos.strip().split()

        # get diff + binary diff masks
        tok_diff = diff(tokens, post_tokens)
        pre_tok_labels, post_tok_labels = get_tok_labels(tok_diff)

        # make sure everything lines up
        if len(tokens) != len(pre_tok_labels) \
            or len(tokens) != len(rels) \
            or len(tokens) != len(pos) \
            or len(post_tokens) != len(post_tok_labels):
            skipped += 1
            continue

        # leave room in the post for start/stop and possible category/class token
        if len(tokens) > max_seq_len - 1 or len(post_tokens) > max_seq_len - 1:
            skipped += 1
            continue

        # category info if provided
        # TODO -- if provided but not in diyi's data, we fill with random...is that ok?
        if categories_path is not None and revid in revid2topic:
            categories = revid2topic[revid]
        else:
            categories = np.random.uniform(
                size=43)  # 43 = number of categories
            categories = categories / sum(categories)  # normalize

        if ARGS.category_input:
            category_id = np.argmax(categories)
            tokens = ['[unused%d]' % category_id] + tokens
            pre_tok_labels = [EDIT_TYPE2ID['mask']] + pre_tok_labels
            post_tok_labels = [EDIT_TYPE2ID['mask']] + post_tok_labels

        # add start + end symbols to post in/out
        post_input_tokens = ['行'] + post_tokens
        post_output_tokens = post_tokens + ['止']

        # shuffle + convert to ids + pad
        try:
            if noise:
                pre_toks = noise_seq(tokens[:],
                                     drop_prob=ARGS.noise_prob,
                                     shuf_dist=ARGS.shuf_dist,
                                     drop_set=drop_set,
                                     keep_bigrams=ARGS.keep_bigrams)
            else:
                pre_toks = tokens

            pre_ids = pad([tok2id[x] for x in pre_toks], 0)
            post_in_ids = pad([tok2id[x] for x in post_input_tokens], 0)
            post_out_ids = pad([tok2id[x] for x in post_output_tokens], 0)
            pre_tok_label_ids = pad(pre_tok_labels, EDIT_TYPE2ID['mask'])
            post_tok_label_ids = pad(post_tok_labels, EDIT_TYPE2ID['mask'])
            rel_ids = pad([REL2ID.get(x, REL2ID['<UNK>']) for x in rels], 0)
            pos_ids = pad([POS2ID.get(x, POS2ID['<UNK>']) for x in pos], 0)
        except KeyError:
            # TODO F**K THIS ENCODING BUG!!!
            skipped += 1
            continue

        input_mask = pad([0] * len(tokens), 1)
        pre_len = len(tokens)

        out['pre_ids'].append(pre_ids)
        out['pre_masks'].append(input_mask)
        out['pre_lens'].append(pre_len)
        out['post_in_ids'].append(post_in_ids)
        out['post_out_ids'].append(post_out_ids)
        out['pre_tok_label_ids'].append(pre_tok_label_ids)
        out['post_tok_label_ids'].append(post_tok_label_ids)
        out['rel_ids'].append(rel_ids)
        out['pos_ids'].append(pos_ids)
        out['categories'].append(categories)

    print('SKIPPED ', skipped)
    return out
Example #20
    def test_words_diff(self):
        strings = TESTS['words']
        for check in strings:
            self.assertEqual(simplediff.diff(check['old'], check['new']),
                             check['diff'])
Example #21
    def test_character_diff(self):
        strings = TESTS["character"]

        for check in strings:
            self.assertEqual(simplediff.diff(check["old"], check["new"]), check["diff"])
Example #22
    def test_words_diff(self):
        strings = TESTS["words"]
        for check in strings:
            self.assertEqual(simplediff.diff(check["old"], check["new"]), check["diff"])
Example #23
    def test_character_diff(self):
        strings = TESTS['character']

        for check in strings:
            self.assertEqual(simplediff.diff(check['old'], check['new']),
                             check['diff'])
Example #24
    def test_words_diff(self):
        strings = TESTS['words']
        for check in strings:
            self.assertEqual(simplediff.diff(check['old'], check['new']),
                             check['diff'])
Example #25
			new_element = new_tree.xpath( old_path )

			# Check for a one-to-one path relationship.
			if len( new_element ) != 1:
				# If the list is empty, the element is missing.
				# Otherwise, the path matches more than one element, which means elements have been added.
				if len( new_element ) == 0:
					print bill_id, "* Element missing:", old_path
				else:
					print bill_id, old_path, old_element, new_element
			else:
				old_element_keys = sorted( old_element.keys() )
				new_element_keys = sorted( new_element[0].keys() )

				key_diff = diff( old_element_keys, new_element_keys )

				# Check for attribute changes.
				for state, attributes in key_diff:
					if state != '=':
						print bill_id, "* Element has difference in attributes:", old_path

					# Check for missing or added attributes.
					if state == "-":
						for attribute in attributes:
							print bill_id, "** Attribute missing:", attribute + '="' + old_element.attrib[attribute] + '"'
					elif state == "+":
						for attribute in attributes:
							print bill_id, "** Attribute added:", attribute + '="' + new_element[0].attrib[attribute] + '"'
					else:
						# Check for attribute value changes.
Example #26
def get_examples(data_path,
                 tok2id,
                 max_seq_len,
                 noise=False,
                 add_del_tok=False,
                 categories_path=None):
    global REL2ID
    global POS2ID
    global EDIT_TYPE2ID
    global ARGS

    if ARGS.drop_words is not None:
        drop_set = set([l.strip() for l in open(ARGS.drop_words)])
    else:
        drop_set = None

    # pad id_arr out to max_seq_len by appending pad_idx
    def pad(id_arr, pad_idx):
        return id_arr + ([pad_idx] * (max_seq_len - len(id_arr)))

    skipped = 0
    out = defaultdict(list)
    if categories_path is not None:
        category_fp = open(categories_path)
        next(category_fp)  # ignore header
        ## previous
        # topic is a vector representation with different semantic meanings
        revid2topic = {
            l.strip().split(',')[0]:
            [float(x) for x in l.strip().split(',')[1:]]
            for l in category_fp
        }
        #
        # ## Bing's opinion
        # TODO check the existing code, it should bypass the previous code
        # for l in category_fp:
        #     revid2topic = {
        #         l.strip().split(',')[0]: [float(x) for x in l.strip().split(',')[1:]]
        #     }

    # encoding="utf-8" added by bing
    for i, (line) in enumerate(tqdm(open(data_path, encoding="utf-8"))):
        # iterate over all lines, building up the defaultdict `out`
        parts = line.strip().split('\t')

        # if there is pos/rel info
        if len(parts) == 7:
            [revid, pre, post, _, _, pos, rels] = parts
        # no pos/rel info
        elif len(parts) == 5:
            [revid, pre, post, _, _] = parts
            pos = ' '.join(['<UNK>'] * len(pre.strip().split()))
            rels = ' '.join(['<UNK>'] * len(pre.strip().split()))
        # broken line
        else:
            skipped += 1
            continue

        # break up tokens
        # by default, split on whitespace
        tokens = pre.strip().split()
        post_tokens = post.strip().split()
        rels = rels.strip().split()
        pos = pos.strip().split()

        # get diff + binary diff masks
        # print("tokens: ", tokens)
        # print("post_tokens: ", post_tokens)
        """" example setting
        a=['ch', '##lor', '##of', '##or', '##m']
        b=['ch',  '##lor', '##of', '##or', 'new', 'things']
        diff(a,b)
        Out[19]:
        [('=', ['ch', '##lor', '##of', '##or']),
         ('-', ['##m']),
         ('+', ['new', 'things'])]
        """
        tok_diff = diff(tokens, post_tokens)
        pre_tok_labels, post_tok_labels = get_tok_labels(tok_diff)
        # print("pre_tok_labels", pre_tok_labels)
        # print("post_tok_labels", post_tok_labels)
        # exit(0)

        # make sure everything lines up
        # double check the data (bing): only examples whose lengths line up are kept
        # in my opinion, the len(tokens) != len(pre_tok_labels) and
        # len(post_tokens) != len(post_tok_labels) checks should be redundant
        if len(tokens) != len(pre_tok_labels) \
                or len(tokens) != len(rels) \
                or len(tokens) != len(pos) \
                or len(post_tokens) != len(post_tok_labels):
            skipped += 1
            continue

        # leave room in the post for start/stop and possible category/class token
        if len(tokens) > max_seq_len - 1 or len(post_tokens) > max_seq_len - 1:
            skipped += 1
            continue

        # category info if provided
        # TODO -- if provided but not in diyi's data, we fill with random...is that ok?
        if categories_path is not None and revid in revid2topic:
            categories = revid2topic[revid]
        else:
            categories = np.random.uniform(
                size=43)  # 43 = number of categories
            categories = categories / sum(categories)  # normalize

        if ARGS.category_input:
            category_id = np.argmax(categories)
            tokens = ['[unused%d]' % category_id] + tokens
            pre_tok_labels = [EDIT_TYPE2ID['mask']] + pre_tok_labels
            post_tok_labels = [EDIT_TYPE2ID['mask']] + post_tok_labels

        # add start + end symbols to post in/out
        # what is this Chinese character used for? TODO check by bing
        post_input_tokens = ['行'] + post_tokens
        post_output_tokens = post_tokens + ['止']

        # shuffle + convert to ids + pad
        try:
            # TODO check the possible usage of noise; in our setting we do not use it (bing)
            if noise:
                pre_toks = noise_seq(tokens[:],
                                     drop_prob=ARGS.noise_prob,
                                     shuf_dist=ARGS.shuf_dist,
                                     drop_set=drop_set,
                                     keep_bigrams=ARGS.keep_bigrams)
            else:
                pre_toks = tokens

            # by default, max 80, so we need to pad
            pre_ids = pad([tok2id[x] for x in pre_toks],
                          0)  # pad id 0 corresponds to the [PAD] token
            # TODO: why two post token sequences (input and output)? (bing)
            post_in_ids = pad([tok2id[x] for x in post_input_tokens], 0)
            post_out_ids = pad([tok2id[x] for x in post_output_tokens], 0)
            # label setting and processing
            pre_tok_label_ids = pad(pre_tok_labels, EDIT_TYPE2ID['mask'])
            post_tok_label_ids = pad(post_tok_labels, EDIT_TYPE2ID['mask'])
            # REL2ID.get(x, REL2ID['<UNK>']): unknown relations fall back to <UNK>
            rel_ids = pad([REL2ID.get(x, REL2ID['<UNK>']) for x in rels], 0)
            pos_ids = pad([POS2ID.get(x, POS2ID['<UNK>']) for x in pos], 0)
        except KeyError:
            # TODO F**K THIS ENCODING BUG!!!
            skipped += 1
            continue

        input_mask = pad([0] * len(tokens), 1)
        pre_len = len(tokens)

        # TODO: why do we set so many masks in this processing? (bing)
        out['pre_ids'].append(pre_ids)
        out['pre_masks'].append(input_mask)
        out['pre_lens'].append(pre_len)
        out['post_in_ids'].append(post_in_ids)
        out['post_out_ids'].append(post_out_ids)
        out['pre_tok_label_ids'].append(pre_tok_label_ids)
        out['post_tok_label_ids'].append(post_tok_label_ids)
        out['rel_ids'].append(rel_ids)
        out['pos_ids'].append(pos_ids)
        out['categories'].append(categories)

    print('SKIPPED ', skipped)
    return out
Example #27
def word_diff(old, new):
    """ similar to simplediff.string_diff, but keeping newlines"""
    return simplediff.diff(old.split(' '), new.split(' '))
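For example, on made-up strings this shows the shape of simplediff's output:

import simplediff

print(word_diff('the quick brown fox', 'the quick red fox'))
# [('=', ['the', 'quick']), ('-', ['brown']), ('+', ['red']), ('=', ['fox'])]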
Example #28
python test.py ../raw/tst.biased > ../raw/tst.tokbiased
"""

import sys

from simplediff import diff

i = 0
for l in open(sys.argv[1]):
    parts = l.strip().split('\t')

    if len(parts) != 7:
        continue

    pre_tok = parts[1].split()
    post_tok = parts[2].split()
    d = diff(pre_tok, post_tok)

    old = [x for x in d if x[0] == '-']
    new = [x for x in d if x[0] == '+']

    #    print(diff(l1.strip().split(), l2.strip().split()))
    if len(old) == 1 and len(old[0][1]) == 1:
        if not new:
            print(l.strip())
            i += 1
        if len(new) == 1 and len(new[0][1]) == 1:
            print(l.strip())
            i += 1
Example #29
def should_keep(prev_raw, prev_tok, post_raw, post_tok, bleu, rev_id):
    global CTR_LOW_BLEU
    global CTR_LOW_LEVEN
    global CTR_TOO_MANY_1_TOKS
    global CTR_SPELLING
    global CTR_CHEMISTRY
    global CTR_ONLY_PUNC_CHANGED

    # KEEP -- exact match
    if bleu == 100 or prev_raw == post_raw:
        return True, None, [0 for _ in range(len(prev_tok.split()))]

    # clearly not a match
    if bleu < 15.0:
        CTR_LOW_BLEU += 1
        return False, None, None
    # too close
    if Levenshtein.distance(prev_tok, post_tok) < 4:
        CTR_LOW_LEVEN += 1
        return False, None, None

    tok_diff = diff(prev_tok.split(), post_tok.split())
    tok_labels = get_tok_labels(tok_diff)
    assert len(tok_labels) == len(prev_tok.split())

    changed_text = ''.join(
        [''.join(chunk) for tag, chunk in tok_diff if tag != '='])
    if not re.search('[a-z]', changed_text):
        CTR_ONLY_PUNC_CHANGED += 1
        return False, None, None

    # too dissimilar -- less than half of toks shared
    tok_nums = [int(x) for x in tok_labels]
    if (sum(tok_nums) * 1.0 / len(tok_nums)) > 0.5:
        CTR_TOO_MANY_1_TOKS += 1
        return False, None, None

    # edit was just fixing a spelling error
    word_diff = diff(word_tokenize(prev_raw), word_tokenize(post_raw))
    if is_spelling_diff(word_diff):
        CTR_SPELLING += 1
        return False, None, None

    # some simple filtering to get out the chemistry "neutral" edits
    if ' molecules' in prev_raw or ' ions' in prev_raw or ' ionic' in prev_raw or ' atoms' in prev_raw:
        CTR_CHEMISTRY += 1
        return False, None, None

    # # use enchant to make sure example has enough normal words
    # prev_words = prev_words.translate(str.maketrans('', '', string.punctuation)).split()
    # n_words = sum(1 if d.check(w) else 0 for w in pre_words)
    # if len(prev_words) == 0 or (float(n_words) / len(prev_words)) < 0.5:
    #     return False, None, None

    # see if this is a "single word" edit, where a single word was replaced with 0+ words
    def is_single_word_edit(d):
        """ is this diff good for the final generation dataset """
        pre_chunks = [chunk for tag, chunk in d if tag == '-']
        post_chunks = [chunk for tag, chunk in d if tag == '+']
        # a single word on the pre side
        if sum([len(chunk) for chunk in pre_chunks]) != 1:
            return False
        # 0 words in the post
        if len(post_chunks) == 0:
            return True
        # ensure 1 post chunk
        if len(post_chunks) > 1:
            return False
        # post language chunk is directly after the pre chunk
        prei = next((i for i, x in enumerate(d) if x[0] == '-'))
        if prei < len(d) - 1 and d[prei + 1][0] == '+':
            return True
        return False

    single_word_edit = is_single_word_edit(word_diff)

    return True, single_word_edit, tok_labels
Example #30
def get_examples(dataset_params, data_path, tok2id, max_seq_len, noise=False,
                 add_del_tok=False, categories_path=None, convert_to_tensors=True,
                 no_bias_type_labels=False):

    global REL2ID
    global POS2ID
    global EDIT_TYPE2ID

    if dataset_params['drop_words'] is not None:
        drop_set = set([l.strip() for l in
                   open(dataset_params['drop_words'])])
    else:
        drop_set = None

    def pad(id_arr, pad_idx):
        return id_arr + ([pad_idx] * (max_seq_len - len(id_arr)))

    skipped = 0
    out = defaultdict(list)
    if categories_path is not None:
        category_fp = open(categories_path)
        next(category_fp) # ignore header
        revid2topic = {
            l.strip().split(',')[0]: [float(x) for x in l.strip().split(',')[1:]]
            for l in category_fp
        }

    #NOTE: we require that the TSV begins with some header
    data_file = open(data_path)
    header_line = next(data_file).strip().split('\t')
    assert header_line[0] == 'id', "data file required to contain a header."

    for i, (line) in enumerate(tqdm(data_file)):

        parts = line.strip().split('\t')

        if no_bias_type_labels:
            if len(parts) == 5:
                [revid, pre, post, pos, rels] = parts
            elif len(parts) == 7:
                [revid, pre, post, _, _, pos, rels] = parts
            else:
                skipped += 1
                continue
        else:
            if len(parts) == 7:
                [revid, pre, post, pos, rels, epistemological, framing] = parts
            elif len(parts) == 9:
                [revid, pre, post, _, _, pos, rels, epistemological, framing] = parts
            else:
                skipped += 1
                continue

        # break up tokens
        tokens = pre.strip().split()
        post_tokens = post.strip().split()
        rels = rels.strip().split()
        pos = pos.strip().split()

        # get diff + binary diff masks
        tok_diff = diff(tokens, post_tokens)
        pre_tok_labels, post_tok_labels = get_tok_labels(tok_diff)

        # make sure everything lines up
        if len(tokens) != len(pre_tok_labels) \
            or len(tokens) != len(rels) \
            or len(tokens) != len(pos) \
            or len(post_tokens) != len(post_tok_labels):
            skipped += 1
            continue

        # leave room in the post for start/stop and possible category/class token
        if len(tokens) > max_seq_len - 1 or len(post_tokens) > max_seq_len - 1:
            skipped += 1
            continue

        # category info if provided
        # TODO -- if provided but not in diyi's data, we fill with random...is that ok?
        if categories_path is not None and revid in revid2topic:
            categories = revid2topic[revid]
        else:
            categories = np.random.uniform(size=43)   # 43 = number of categories
            categories = categories / sum(categories) # normalize

        if dataset_params['category_input']:
            category_id = np.argmax(categories)
            tokens = ['[unused%d]' % category_id] + tokens
            pre_tok_labels = [EDIT_TYPE2ID['mask']] + pre_tok_labels
            post_tok_labels = [EDIT_TYPE2ID['mask']] + post_tok_labels

        # add start + end symbols to post in/out
        post_input_tokens = ['行'] + post_tokens
        post_output_tokens = post_tokens + ['止']

        # shuffle + convert to ids + pad
        try:
            if noise:
                pre_toks = noise_seq(
                    tokens[:],
                    drop_prob=dataset_params['noise_prob'],
                    shuf_dist=dataset_params['shuf_dist'],
                    drop_set=drop_set,
                    keep_bigrams=dataset_params['keep_bigrams'])
            else:
                pre_toks = tokens

            pre_ids = pad([tok2id[x] for x in pre_toks], 0)
            post_in_ids = pad([tok2id[x] for x in post_input_tokens], 0)
            post_out_ids = pad([tok2id[x] for x in post_output_tokens], 0)
            pre_tok_label_ids = pad(pre_tok_labels, EDIT_TYPE2ID['mask'])
            if 1 not in pre_tok_label_ids:
                skipped += 1
                continue

            post_tok_label_ids = pad(post_tok_labels, EDIT_TYPE2ID['mask'])
            rel_ids = pad([REL2ID.get(x, REL2ID['<UNK>']) for x in rels], 0)
            pos_ids = pad([POS2ID.get(x, POS2ID['<UNK>']) for x in pos], 0)
        except KeyError:
            skipped += 1
            continue

        input_mask = pad([0] * len(tokens), 1)
        pre_len = len(tokens)

        # Adding label encodings --> 0: epistemological; 1:framing
        if not no_bias_type_labels:
            try:
                bias_label = 0 if int(epistemological) else 1
                if bias_label == 1:
                    assert int(framing), "Processing error: both epistemological and framing labels are true."
            except (KeyError, ValueError):  # malformed labels make int() raise ValueError
                continue

        out['pre_ids'].append(pre_ids)
        out['masks'].append(input_mask)
        out['pre_lens'].append(pre_len)
        out['post_in_ids'].append(post_in_ids)
        out['post_out_ids'].append(post_out_ids)
        out['pre_tok_label_ids'].append(pre_tok_label_ids)
        out['post_tok_label_ids'].append(post_tok_label_ids)
        out['rel_ids'].append(rel_ids)
        out['pos_ids'].append(pos_ids)
        out['categories'].append(categories)
        out['index'].append(i) # can do some hash thing
        if not no_bias_type_labels:
            out['bias_label'].append(bias_label)

    if convert_to_tensors:
        out['pre_ids'] = torch.tensor(out['pre_ids'], dtype=torch.long)
        out['masks'] = torch.tensor(out['masks'], dtype=torch.uint8) # byte for masked_fill()
        out['pre_lens'] = torch.tensor(out['pre_lens'], dtype=torch.long)
        out['post_in_ids'] = torch.tensor(out['post_in_ids'], dtype=torch.long)
        out['post_out_ids'] = torch.tensor(out['post_out_ids'], dtype=torch.long)
        out['pre_tok_label_ids'] = torch.tensor(out['pre_tok_label_ids'], dtype=torch.float)  # for comparing to enrichment stuff
        out['post_tok_label_ids'] = torch.tensor(out['post_tok_label_ids'], dtype=torch.float)  # for loss multiplying
        out['rel_ids'] = torch.tensor(out['rel_ids'], dtype=torch.long)
        out['pos_ids'] = torch.tensor(out['pos_ids'], dtype=torch.long)
        out['categories'] = torch.tensor(out['categories'], dtype=torch.float)
        out['index'] = torch.tensor(out['index'], dtype=torch.int32)
        if not no_bias_type_labels:
            out['bias_label'] = torch.tensor(out['bias_label'], dtype=torch.int16)


    return out
Example #31
def is_single_word_diff(s1, s2):
    s1 = word_tokenize(s1.lower().strip())
    s2 = word_tokenize(s2.lower().strip())

    word_diff = diff(s1, s2)
    return sum([len(chunk) for tag, chunk in word_diff if tag == '-']) == 1
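A quick sketch on made-up sentences (requires nltk's word_tokenize, just as the function above does):

from nltk import word_tokenize
from simplediff import diff

print(is_single_word_diff('The cat sat on the mat.',
                          'The cat slept on the mat.'))        # True: exactly one word replaced
print(is_single_word_diff('The cat sat.', 'A dog ran fast.'))  # False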
Example #32
out_pre = open(out_prefix + '.pre', 'w')
out_post = open(out_prefix + '.post', 'w')

len_skip_del = 0
len_skip_nodel = 0
skip = 0
dels = 0
nondels = 0
filtered = 0

for pre_l, post_l in zip(open(pre_fp), open(post_fp)):
    pre_l = pre_l.strip().split()
    post_l = post_l.strip().split()

    d = diff(pre_l, post_l)

    old = [x for x in d if x[0] == '-']
    new = [x for x in d if x[0] == '+']

    if len(old) == 0:
        skip += 1
        continue
    oldi = next((i for i, x in enumerate(d) if x[0] == '-'))

    if not new:
        if len(pre_l) > 100 or len(post_l) > 100:
            len_skip_del += 1
            continue
        out_pre.write(' '.join(pre_l) + '\n')
        out_post.write(' '.join(post_l) + '\n')
Example #33
def export_word_edit_matrix(context: List,
                            current_sen: List,
                            label_sen: List,
                            super_mode: str = 'before',
                            # if multiple inserts are required, we only
                            # keep the longest one
                            only_one_insert: bool = False):
    if isinstance(context, str):
        context_seq = list(context)
        current_seq = list(current_sen)
        label_seq = list(label_sen)
    else:
        context_seq = context
        current_seq = current_sen
        label_seq = label_sen
    applied_changes = diff(current_seq, label_seq)

    def sub_finder(cus_list, pattern, used_pos):
        find_indices = []
        for i in range(len(cus_list)):
            if cus_list[i] == pattern[0] and \
                    cus_list[i:i + len(pattern)] == pattern \
                    and i not in used_pos:
                find_indices.append((i, i + len(pattern)))
        if len(find_indices) == 0:
            return 0, 0
        else:
            return find_indices[-1]

    def cont_sub_finder(cus_list, pattern, used_pos):
        context_len = len(cus_list)
        pattern_len = len(pattern)
        for i in range(context_len):
            k = i
            j = 0
            temp_indices = []
            while j < pattern_len and k < context_len:
                if cus_list[k] == pattern[j][0] and \
                        cus_list[k:k + len(pattern[j])] == pattern[j] \
                        and k not in used_pos:
                    temp_indices.append((k, k + len(pattern[j])))
                    j += 1
                else:
                    k += 1
            if j == pattern_len:
                return zip(*temp_indices)
        else:
            return 0, 0

    rm_range = None
    ret_ops = []
    context_used_pos = []
    current_used_pos = []
    pointer = 0
    for diff_sample in applied_changes:
        diff_op = diff_sample[0]
        diff_content = diff_sample[1]
        if diff_op == '-':
            if rm_range is not None:
                ret_ops.append(['remove', rm_range, []])
            start, end = sub_finder(current_seq, diff_content, current_used_pos)
            rm_range = [start, end]
            current_used_pos.extend(list(range(start, end)))
        elif diff_op == '+':
            start, end = sub_finder(context_seq, diff_content, context_used_pos)
            # no exact matching substring found; fall back to identifying overlapping snippets
            if start == 0 and end == 0:
                inner_diff = diff(diff_content, context_seq)
                overlap_content = [inner_diff_sample[1] for
                                   inner_diff_sample in inner_diff if inner_diff_sample[0] == '=']
                if len(overlap_content) > 0:
                    # only take one insert
                    if len(overlap_content) == 1 or only_one_insert:
                        overlap_content = sorted(overlap_content, key=lambda x: len(x), reverse=True)[0]
                        start, end = sub_finder(context_seq, overlap_content,
                                                context_used_pos)
                    else:
                        start_end_tuple = cont_sub_finder(context_seq, overlap_content, context_used_pos)
                        # start and end are each tuples of positions here
                        start, end = start_end_tuple
                else:
                    start, end = 0, 0
            if not (start == 0 and end == 0):
                if isinstance(start, int):
                    add_ranges = [[start, end]]
                else:
                    add_ranges = list(zip(start, end))

                if rm_range is not None:
                    for add_range in add_ranges:
                        context_used_pos.extend(list(range(add_range[0], add_range[1])))
                        ret_ops.append(['replace', rm_range, add_range])
                    rm_range = None
                else:
                    for add_range in add_ranges:
                        if super_mode in ['before', 'both']:
                            ret_ops.append(['before', [pointer, pointer], add_range])
                        if super_mode in ['after', 'both']:
                            if pointer >= 1:
                                ret_ops.append(['after', [pointer - 1, pointer - 1], add_range])
        elif diff_op == '=':
            if rm_range is not None:
                ret_ops.append(['remove', rm_range, []])
            start, end = sub_finder(current_seq, diff_content, current_used_pos)
            current_used_pos.extend(list(range(start, end)))
            rm_range = None
            pointer = end
    return ret_ops