def random_attack(text_ls, true_label, predictor, perturb_ratio, stop_words_set, word2idx, idx2word, cos_sim,
                  sim_predictor=None, import_score_threshold=-1., sim_score_threshold=0.5, sim_score_window=15,
                  synonym_num=50, batch_size=32):
    # first check the prediction of the original text
    orig_probs = predictor([text_ls]).squeeze()
    orig_label = torch.argmax(orig_probs)
    orig_prob = orig_probs.max()
    if true_label != orig_label:
        return '', 0, orig_label, orig_label, 0
    else:
        len_text = len(text_ls)
        if len_text < sim_score_window:
            sim_score_threshold = 0.1  # shut down the similarity thresholding function
        half_sim_score_window = (sim_score_window - 1) // 2
        num_queries = 1

        # get the pos and verb tense info
        pos_ls = criteria.get_pos(text_ls)

        # randomly get perturbed words
        perturb_idxes = random.sample(range(len_text), int(len_text * perturb_ratio))
        words_perturb = [(idx, text_ls[idx]) for idx in perturb_idxes]

        # find synonyms
        words_perturb_idx = [word2idx[word] for idx, word in words_perturb if word in word2idx]
        synonym_words, _ = pick_most_similar_words_batch(words_perturb_idx, cos_sim, idx2word, synonym_num, 0.5)
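        # take up to synonym_num nearest neighbours per word from the cos_sim embedding
        # matrix; the final argument (0.5) is presumably a minimum-similarity cutoff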
        synonyms_all = []
        for idx, word in words_perturb:
            if word in word2idx:
                synonyms = synonym_words.pop(0)
                if synonyms:
                    synonyms_all.append((idx, synonyms))

        # start replacing and attacking
        text_prime = text_ls[:]
        text_cache = text_prime[:]
        num_changed = 0
        for idx, synonyms in synonyms_all:
            new_texts = [text_prime[:idx] + [synonym] + text_prime[min(idx + 1, len_text):] for synonym in synonyms]
            new_probs = predictor(new_texts, batch_size=batch_size)

            # compute semantic similarity
            if idx >= half_sim_score_window and len_text - idx - 1 >= half_sim_score_window:
                text_range_min = idx - half_sim_score_window
                text_range_max = idx + half_sim_score_window + 1
            elif idx < half_sim_score_window and len_text - idx - 1 >= half_sim_score_window:
                text_range_min = 0
                text_range_max = sim_score_window
            elif idx >= half_sim_score_window and len_text - idx - 1 < half_sim_score_window:
                text_range_min = len_text - sim_score_window
                text_range_max = len_text
            else:
                text_range_min = 0
                text_range_max = len_text
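            # the slice above centres a window of at most sim_score_window tokens on the
            # replaced position, clamped to the sentence boundaries, for the similarity check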
            semantic_sims = sim_predictor.semantic_sim(
                [' '.join(text_cache[text_range_min:text_range_max])] * len(new_texts),
                list(map(lambda x: ' '.join(x[text_range_min:text_range_max]), new_texts)))[0]

            num_queries += len(new_texts)
            if len(new_probs.shape) < 2:
                new_probs = new_probs.unsqueeze(0)
            new_probs_mask = (orig_label != torch.argmax(new_probs, dim=-1)).data.cpu().numpy()
            # prevent bad synonyms
            new_probs_mask *= (semantic_sims >= sim_score_threshold)
            # prevent incompatible pos
            synonyms_pos_ls = [criteria.get_pos(new_text[max(idx - 4, 0):idx + 5])[min(4, idx)]
                               if len(new_text) > 10 else criteria.get_pos(new_text)[idx] for new_text in new_texts]
            pos_mask = np.array(criteria.pos_filter(pos_ls[idx], synonyms_pos_ls))
            new_probs_mask *= pos_mask
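            # at this point new_probs_mask marks candidates that flip the prediction, stay
            # above the similarity threshold, and keep a compatible POS tag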

            if np.sum(new_probs_mask) > 0:
                text_prime[idx] = synonyms[(new_probs_mask * semantic_sims).argmax()]
                num_changed += 1
                break
            else:
                new_label_probs = new_probs[:, orig_label] + torch.from_numpy(
                        (semantic_sims < sim_score_threshold) + (1 - pos_mask).astype(float)).float().cuda()
                new_label_prob_min, new_label_prob_argmin = torch.min(new_label_probs, dim=-1)
                if new_label_prob_min < orig_prob:
                    text_prime[idx] = synonyms[new_label_prob_argmin]
                    num_changed += 1
            text_cache = text_prime[:]
        return ' '.join(text_prime), num_changed, orig_label, torch.argmax(predictor([text_prime])), num_queries
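
# A self-contained sketch (not part of the original functions) of the sliding
# similarity window used by every attack loop in this file: the window is centred on
# the replaced index and clamped to the sentence boundaries, mirroring the if/elif
# chains above and, presumably, what the calc_window helper used further below computes.
def _sim_window(idx, half_window, window, len_text):
    if idx >= half_window and len_text - idx - 1 >= half_window:
        return idx - half_window, idx + half_window + 1
    elif idx < half_window <= len_text - idx - 1:
        return 0, window
    elif idx >= half_window > len_text - idx - 1:
        return len_text - window, len_text
    else:
        return 0, len_text

# e.g. _sim_window(2, 7, 15, 40) == (0, 15) and _sim_window(20, 7, 15, 40) == (13, 28)
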
def contextual_attack(text_ls, true_label, predictor, maskedLM_predictor, stop_words_set, word2idx, idx2word, cos_sim,
                      sim_predictor=None, import_score_threshold=-1., sim_score_threshold=0.5, sim_score_window=15,
                      synonym_num=50, batch_size=32):
    # first check the prediction of the original text
    orig_probs = predictor([text_ls]).squeeze()
    orig_label = torch.argmax(orig_probs)
    orig_prob = orig_probs.max()
    if true_label != orig_label:
        return '', 0, orig_label, orig_label, 0
    else:
        len_text = len(text_ls)
        if len_text < sim_score_window:
            sim_score_threshold = 0.1  # shut down the similarity thresholding function
        half_sim_score_window = (sim_score_window - 1) // 2
        num_queries = 1

        # get the pos and verb tense info
        pos_ls = criteria.get_pos(text_ls)

        # get importance score
        leave_1_texts = [text_ls[:ii] + ['<oov>'] + text_ls[min(ii + 1, len_text):] for ii in range(len_text)]
        leave_1_probs = predictor(leave_1_texts, batch_size=batch_size)
        num_queries += len(leave_1_texts)
        leave_1_probs_argmax = torch.argmax(leave_1_probs, dim=-1)
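        # importance of a word = drop in the original-class probability when that word is
        # replaced by '<oov>', plus, if the prediction flips, the margin by which the new
        # top class gains over its probability under the unperturbed text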
        import_scores = (orig_prob - leave_1_probs[:, orig_label] + (leave_1_probs_argmax != orig_label).float() * (
                    leave_1_probs.max(dim=-1)[0] - torch.index_select(orig_probs, 0,
                                                                      leave_1_probs_argmax))).data.cpu().numpy()
        
        # get words to perturb ranked by importance score
        words_perturb = []
        for idx, score in sorted(enumerate(import_scores), key=lambda x: x[1], reverse=True):
            try:
                if score > import_score_threshold and text_ls[idx] not in stop_words_set:
                    words_perturb.append((idx, text_ls[idx]))
            except:
                print(idx, len(text_ls), import_scores.shape, text_ls, len(leave_1_texts))
        # find synonyms with BERT's masked language model: mask each candidate position,
        # query the masked LM, and keep the top predictions that appear in word2idx
        synonyms_all = []
        for idx, word in words_perturb:
            synonyms = []
            # the masked LM input is truncated/padded to 128 tokens, so skip later positions
            if idx >= 127:
                continue
            masked_text = text_ls[:idx] + ['[MASK]'] + text_ls[min(idx + 1, len_text):]
            masked_lm_probs = maskedLM_predictor.text_pred([masked_text], batch_size=batch_size)
            values, indices = torch.topk(masked_lm_probs, 25, dim=-1)
            tokens = maskedLM_predictor.convert_ids_to_tokens(indices.view(-1).cpu().numpy())
            tokens = np.reshape(tokens, (1, 128, -1))
            for i in range(25):
                candidate = tokens[0][idx][i]
                if candidate in word2idx:
                    synonyms.append(candidate)
            if synonyms:
                synonyms_all.append((idx, synonyms))

        # start replacing and attacking
        text_prime = text_ls[:]
        text_cache = text_prime[:]
        num_changed = 0
        for idx, synonyms in synonyms_all:
            new_texts = [text_prime[:idx] + [synonym] + text_prime[min(idx + 1, len_text):] for synonym in synonyms]
            if not new_texts:
                continue
            new_probs = predictor(new_texts, batch_size=batch_size)
            # compute semantic similarity
            if idx >= half_sim_score_window and len_text - idx - 1 >= half_sim_score_window:
                text_range_min = idx - half_sim_score_window
                text_range_max = idx + half_sim_score_window + 1
            elif idx < half_sim_score_window and len_text - idx - 1 >= half_sim_score_window:
                text_range_min = 0
                text_range_max = sim_score_window
            elif idx >= half_sim_score_window and len_text - idx - 1 < half_sim_score_window:
                text_range_min = len_text - sim_score_window
                text_range_max = len_text
            else:
                text_range_min = 0
                text_range_max = len_text
            semantic_sims = sim_predictor.semantic_sim(
                [' '.join(text_cache[text_range_min:text_range_max])] * len(new_texts),
                list(map(lambda x: ' '.join(x[text_range_min:text_range_max]), new_texts)))[0]

            num_queries += len(new_texts)
            if len(new_probs.shape) < 2:
                new_probs = new_probs.unsqueeze(0)
            new_probs_mask = (orig_label != torch.argmax(new_probs, dim=-1)).data.cpu().numpy()
            # prevent bad synonyms
            new_probs_mask *= (semantic_sims >= sim_score_threshold)
            # prevent incompatible pos
            synonyms_pos_ls = [criteria.get_pos(new_text[max(idx - 4, 0):idx + 5])[min(4, idx)]
                               if len(new_text) > 10 else criteria.get_pos(new_text)[idx] for new_text in new_texts]
            pos_mask = np.array(criteria.pos_filter(pos_ls[idx], synonyms_pos_ls))
            new_probs_mask *= pos_mask

            if np.sum(new_probs_mask) > 0:
                text_prime[idx] = synonyms[(new_probs_mask * semantic_sims).argmax()]
                num_changed += 1
                break
            else:
                new_label_probs = new_probs[:, orig_label] + torch.from_numpy(
                        (semantic_sims < sim_score_threshold) + (1 - pos_mask).astype(float)).float().cuda()
                new_label_prob_min, new_label_prob_argmin = torch.min(new_label_probs, dim=-1)
                if new_label_prob_min < orig_prob:
                    text_prime[idx] = synonyms[new_label_prob_argmin]
                    num_changed += 1
            text_cache = text_prime[:]
        return ' '.join(text_prime), num_changed, orig_label, torch.argmax(predictor([text_prime])), num_queries
def attack(fuzz_val,
           top_k_words,
           qrs,
           wts,
           sample_index,
           text_ls,
           true_label,
           predictor,
           stop_words_set,
           word2idx,
           idx2word,
           cos_sim,
           word_embedding,
           sim_predictor=None,
           import_score_threshold=-1.,
           sim_score_threshold=0.5,
           sim_score_window=15,
           synonym_num=50,
           batch_size=32):
    rows = []
    nlp = spacy.load('en_core_web_sm')
    masked_lang_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    changed_with = []
    doc = nlp(' '.join(text_ls))
    text = []
    for sent in doc.sents:
        for token in sent:
            text.append(token.text)
    tok_text = []
    for item in text:
        ap = item.find("'")
        if ap >= 0:
            tok_text.append(item[0:ap])
            tok_text.append("'")
            tok_text.append(item[ap + 1:len(item)])
        else:
            tok_text.append(item)
    text = []
    for item in tok_text:
        if len(item) > 0:
            text.append(item)

    text_ls = text[:]
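    # text_ls has now been re-tokenised with spaCy and contractions split on the apostrophe,
    # so word indices below line up with this whitespace-joined version of the text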

    # first check the prediction of the original text
    orig_probs = predictor([text_ls]).squeeze()
    orig_label = torch.argmax(orig_probs)
    orig_prob = orig_probs.max()
    if true_label != orig_label:
        return '', 0, orig_label, orig_label, 0, [], []
    else:

        len_text = len(text_ls)
        if len_text < sim_score_window:
            sim_score_threshold = 0.1  # shut down the similarity thresholding function
        half_sim_score_window = (sim_score_window - 1) // 2
        num_queries = 1

        # get the pos and verb tense info
        pos_ls = criteria.get_pos(text_ls)
        # get importance score
        leave_1_texts = [
            text_ls[:ii] + ['<oov>'] + text_ls[min(ii + 1, len_text):]
            for ii in range(len_text)
        ]
        leave_1_probs = predictor(leave_1_texts, batch_size=batch_size)
        num_queries += len(leave_1_texts)
        leave_1_probs_argmax = torch.argmax(leave_1_probs, dim=-1)
        import_scores = (
            orig_prob - leave_1_probs[:, orig_label] +
            (leave_1_probs_argmax != orig_label).float() *
            (leave_1_probs.max(dim=-1)[0] - torch.index_select(
                orig_probs, 0, leave_1_probs_argmax))).data.cpu().numpy()

        # get words to perturb ranked by importance score
        words_perturb = []
        for idx, score in sorted(enumerate(import_scores),
                                 key=lambda x: x[1],
                                 reverse=True):
            try:
                if score > import_score_threshold and text_ls[
                        idx] not in stop_words_set and len(text_ls[idx]) > 2:
                    words_perturb.append((idx, score))
            except:
                print(idx, len(text_ls), import_scores.shape, text_ls,
                      len(leave_1_texts))
        # find synonyms: candidates are generated inside the loop below with BERT's
        # masked language model instead of a precomputed embedding neighbourhood
        # start replacing and attacking
        text_prime = text_ls[:]
        sims = []
        text_cache = text_prime[:]
        num_changed = 0
        for idx, score in words_perturb:
            #print(text_ls[idx])
            text_range_min, text_range_max = calc_window(idx, 3, 10, len_text)

            sliced_text = text_prime[text_range_min:text_range_max]
            #print(sliced_text)
            new_index = idx - text_range_min
            #print(sliced_text[new_index])
            masked_idx = new_index

            tokens, words, position = gen.convert_sentence_to_token(
                ' '.join(sliced_text), 1000, tokenizer)
            assert len(words) == len(position)

            len_tokens = len(tokens)

            mask_position = position[masked_idx]

            if isinstance(mask_position, list):
                feature = gen.convert_whole_word_to_feature(
                    tokens, mask_position, 1000, tokenizer)
            else:
                feature = gen.convert_token_to_feature(tokens, mask_position,
                                                       1000, tokenizer)
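            # mask_position is a list when the word was split into several word-pieces; in
            # that case all pieces are masked together (whole-word masking) when the BERT
            # features are built above, otherwise only the single token is masked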

            tokens_tensor = torch.tensor([feature.input_ids])
            token_type_ids = torch.tensor([feature.input_type_ids])
            attention_mask = torch.tensor([feature.input_mask])
            tokens_tensor = tokens_tensor.to('cuda')
            token_type_ids = token_type_ids.to('cuda')
            attention_mask = attention_mask.to('cuda')
            #new_probs = predictor(new_texts, batch_size=batch_size)
            masked_lang_model.to('cuda')
            masked_lang_model.eval()
            ps = PorterStemmer()

            with torch.no_grad():
                prediction_scores = masked_lang_model(tokens_tensor,
                                                      token_type_ids,
                                                      attention_mask)

            if isinstance(mask_position, list):
                predicted_top = prediction_scores[0, mask_position[0]].topk(50)
            else:
                predicted_top = prediction_scores[0, mask_position].topk(50)

            pre_tokens = tokenizer.convert_ids_to_tokens(
                predicted_top[1].cpu().numpy())
            synonyms_initial = gen.substitution_generation(
                words[masked_idx], pre_tokens, predicted_top[0].cpu().numpy(),
                ps, 50)
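            # substitution_generation builds up to 50 candidate replacements for the masked
            # word from the top masked-LM predictions (the Porter stemmer is passed in,
            # presumably to drop inflections of the original word)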
            new_texts = []
            avg = []
            synonyms = []
            assert words[masked_idx] == text_ls[idx]
            #print(synonyms)
            for candidate_word in synonyms_initial:
                if candidate_word in word_embedding and words[
                        masked_idx] in word_embedding:
                    candidate_similarity = calc_similarity(
                        word_embedding[words[masked_idx]],
                        word_embedding[candidate_word])
                    avg.append(candidate_similarity)
                    #print(words[masked_idx], candidate_similarity, candidate_word)
                    if candidate_similarity >= 0.2:
                        new_texts.append(text_prime[:idx] + [candidate_word] +
                                         text_prime[min(idx + 1, len_text):])
                        synonyms.append(candidate_word)
                else:
                    new_texts.append(text_prime[:idx] + [candidate_word] +
                                     text_prime[min(idx + 1, len_text):])
                    synonyms.append(candidate_word)
            #print(len(new_texts))
            if len(new_texts) == 0:
                continue

            text_range_min, text_range_max = calc_window(
                idx, half_sim_score_window, sim_score_window, len_text)
            semantic_sims = sim_predictor.semantic_sim(
                [' '.join(text_cache[text_range_min:text_range_max])] * len(new_texts),
                list(map(lambda x: ' '.join(x[text_range_min:text_range_max]), new_texts)))[0]
            sims.append(np.sum(semantic_sims) / len(semantic_sims))

            new_probs_mask = np.ones(
                len(new_texts)
            )  #(orig_label != torch.argmax(new_probs, dim=-1)).data.cpu().numpy()
            # prevent bad synonyms
            new_probs_mask *= (semantic_sims >= sim_score_threshold)
            # prevent incompatible pos
            synonyms_pos_ls = [
                criteria.get_pos(new_text[max(idx - 4, 0):idx +
                                          5])[min(4, idx)]
                if len(new_text) > 10 else criteria.get_pos(new_text)[idx]
                for new_text in new_texts
            ]
            pos_mask = np.array(
                criteria.pos_filter(pos_ls[idx], synonyms_pos_ls))
            new_probs_mask *= pos_mask
            new_vals = semantic_sims * new_probs_mask
            index = []
            for i in range(len(new_vals)):
                if new_vals[i] > 0:
                    index.append((new_vals[i], i))
            if len(index) == 0:
                continue
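            # only candidates that passed both the similarity and POS filters are sent to
            # the victim model, which keeps num_queries within the qrs budget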
            new_texts1 = [new_texts[ind] for val, ind in index]
            #print(len(new_texts1))
            num_queries += len(new_texts1)
            if num_queries > qrs:
                return '', 0, orig_label, orig_label, 0, [], []
            new_probs = predictor(new_texts1, batch_size=batch_size)
            if len(new_probs.shape) < 2:
                new_probs = new_probs.unsqueeze(0)
            pr = (orig_label != torch.argmax(new_probs,
                                             dim=-1)).data.cpu().numpy()
            if np.sum(pr) > 0:
                text_prime[idx] = synonyms[index[pr.argmax(
                )][1]]  #synonyms[(new_probs_mask * semantic_sims).argmax()]
                num_changed += 1
                break
            else:
                new_label_probs = new_probs[:, orig_label]
                new_label_prob_min, new_label_prob_argmin = torch.min(
                    new_label_probs, dim=-1)
                if new_label_prob_min < orig_prob:
                    text_prime[idx] = synonyms[index[new_label_prob_argmin][1]]
                    num_changed += 1
            text_cache = text_prime[:]
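            # keep comparing similarity against the latest accepted text, and stop below
            # once fuzz.token_set_ratio against the original drops under fuzz_val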

            if fuzz.token_set_ratio(' '.join(text_ls),
                                    ' '.join(text_cache)) < fuzz_val:
                return ' '.join(
                    text_prime), num_changed, orig_label, torch.argmax(
                        predictor([text_prime
                                   ])), num_queries, words_perturb, sims
        return ' '.join(text_prime), num_changed, orig_label, torch.argmax(
            predictor([text_prime])), num_queries, words_perturb, sims
def attack(
    text_ls,
    true_label,
    predictor,
    stop_words_set,
    word2idx,
    idx2word,
    cos_sim,
    sim_predictor=None,
    import_score_threshold=-1.0,
    sim_score_threshold=0.5,
    sim_score_window=15,
    synonym_num=50,
    batch_size=32,
):
    # first check the prediction of the original text
    orig_probs = predictor([text_ls]).squeeze()
    orig_label = torch.argmax(orig_probs)
    # orig_label = (
    #     torch.tensor(
    #         list(map(lambda x: 1.0 if x[0] > 0.5 else 0.0, orig_probs)),
    #         dtype=torch.long,
    #     )
    #     .cuda()
    #     .unsqueeze(-1)
    # )
    # orig_label = torch.tensor(
    #     1 if orig_probs.data >= 0.5 else 0, dtype=torch.long
    # ).cuda()
    orig_prob = orig_probs.max()
    if true_label != orig_label:
        return "", 0, orig_label, orig_label, 0
    else:
        len_text = len(text_ls)
        if len_text < sim_score_window:
            sim_score_threshold = 0.1  # shut down the similarity thresholding function
        half_sim_score_window = (sim_score_window - 1) // 2
        num_queries = 1

        # get the pos and verb tense info
        pos_ls = criteria.get_pos(text_ls)

        # get importance score
        leave_1_texts = [
            text_ls[:ii] + ["<oov>"] + text_ls[min(ii + 1, len_text):]
            for ii in range(len_text)
        ]
        leave_1_probs = predictor(leave_1_texts, batch_size=batch_size)
        num_queries += len(leave_1_texts)
        leave_1_probs_argmax = torch.argmax(leave_1_probs, dim=-1)
        # leave_1_probs_argmax = torch.tensor(
        #     1 if leave_1_probs.data >= 0.5 else 0, dtype=torch.long
        # ).cuda()
        # leave_1_probs_argmax = (
        #     torch.tensor(
        #         list(map(lambda x: 1 if x[0] > 0.5 else 0, leave_1_probs)),
        #         dtype=torch.long,
        #     )
        #     .cuda()
        #     .unsqueeze(-1)
        # )
        import_scores = (
            (orig_prob - leave_1_probs[:, orig_label] +
             (leave_1_probs_argmax != orig_label).float() *
             (leave_1_probs.max(dim=-1)[0] - torch.index_select(
                 orig_probs, 0, leave_1_probs_argmax))).data.cpu().numpy())

        # get words to perturb ranked by importance score
        words_perturb = []
        for idx, score in sorted(enumerate(import_scores),
                                 key=lambda x: x[1],
                                 reverse=True):
            try:
                if (score > import_score_threshold
                        and text_ls[idx] not in stop_words_set):
                    words_perturb.append((idx, text_ls[idx]))
            except:
                print(idx, len(text_ls), import_scores.shape, text_ls,
                      len(leave_1_texts))

        # find synonyms
        words_perturb_idx = [
            word2idx[word] for idx, word in words_perturb if word in word2idx
        ]
        synonym_words, _ = pick_most_similar_words_batch(
            words_perturb_idx, cos_sim, idx2word, synonym_num, 0.5)
        synonyms_all = []
        for idx, word in words_perturb:
            if word in word2idx:
                synonyms = synonym_words.pop(0)
                if synonyms:
                    synonyms_all.append((idx, synonyms))

        # start replacing and attacking
        text_prime = text_ls[:]
        text_cache = text_prime[:]
        num_changed = 0
        for idx, synonyms in synonyms_all:
            new_texts = [
                text_prime[:idx] + [synonym] +
                text_prime[min(idx + 1, len_text):] for synonym in synonyms
            ]
            new_probs = predictor(new_texts, batch_size=batch_size)

            # compute semantic similarity
            if (idx >= half_sim_score_window
                    and len_text - idx - 1 >= half_sim_score_window):
                text_range_min = idx - half_sim_score_window
                text_range_max = idx + half_sim_score_window + 1
            elif (idx < half_sim_score_window
                  and len_text - idx - 1 >= half_sim_score_window):
                text_range_min = 0
                text_range_max = sim_score_window
            elif (idx >= half_sim_score_window
                  and len_text - idx - 1 < half_sim_score_window):
                text_range_min = len_text - sim_score_window
                text_range_max = len_text
            else:
                text_range_min = 0
                text_range_max = len_text
            semantic_sims = sim_predictor.semantic_sim(
                [" ".join(text_cache[text_range_min:text_range_max])] *
                len(new_texts),
                list(
                    map(lambda x: " ".join(x[text_range_min:text_range_max]),
                        new_texts)),
            )[0]

            num_queries += len(new_texts)
            if len(new_probs.shape) < 2:
                new_probs = new_probs.unsqueeze(0)
            new_probs_mask = ((orig_label != torch.argmax(
                new_probs, dim=-1)).data.cpu().numpy())
            # prevent bad synonyms
            new_probs_mask *= semantic_sims >= sim_score_threshold
            # prevent incompatible pos
            synonyms_pos_ls = [
                criteria.get_pos(new_text[max(idx - 4, 0):idx +
                                          5])[min(4, idx)]
                if len(new_text) > 10 else criteria.get_pos(new_text)[idx]
                for new_text in new_texts
            ]
            pos_mask = np.array(
                criteria.pos_filter(pos_ls[idx], synonyms_pos_ls))
            new_probs_mask *= pos_mask

            if np.sum(new_probs_mask) > 0:
                text_prime[idx] = synonyms[(new_probs_mask *
                                            semantic_sims).argmax()]
                num_changed += 1
                break
            else:
                new_label_probs = (new_probs[:, orig_label] + torch.from_numpy(
                    (semantic_sims < sim_score_threshold) +
                    (1 - pos_mask).astype(float)).float().cuda())
                new_label_prob_min, new_label_prob_argmin = torch.min(
                    new_label_probs, dim=-1)
                if new_label_prob_min < orig_prob:
                    text_prime[idx] = synonyms[new_label_prob_argmin]
                    num_changed += 1
            text_cache = text_prime[:]
        return (
            " ".join(text_prime),
            num_changed,
            orig_label,
            torch.argmax(predictor([text_prime])),
            num_queries,
        )
def text_fooler(text_ls,
                true_label,
                model,
                stop_words_set,
                word2idx,
                idx2word,
                cos_sim,
                sim_predictor=None,
                import_score_threshold=-1.,
                sim_score_threshold=0.7,
                sim_score_window=15,
                synonym_num=50,
                batch_size=32):
    adversaries = []
    # first check the prediction of the original text
    ref_ans, stud_ans = text_ls
    stud_ans = list_to_string(stud_ans).split(" ")
    orig_logits = predict(model, ref_ans, stud_ans, true_label)
    orig_probs = F.softmax(orig_logits, dim=0)
    orig_label = torch.argmax(orig_probs).item()
    orig_prob = orig_probs.max().item()
    if true_label != orig_label:
        return '', 0, orig_label, orig_label, 0
    else:
        len_text = len(stud_ans)
        if len_text < sim_score_window:
            sim_score_threshold = 0.1  # shut down the similarity thresholding function
        half_sim_score_window = (sim_score_window - 1) // 2
        num_queries = 1

        # get the pos and verb tense info
        pos_ls = criteria.get_pos(stud_ans)

        # get importance score
        leave_1_texts = [
            stud_ans[:ii] + ['[UNK]'] + stud_ans[min(ii + 1, len_text):]
            for ii in range(len_text)
        ]
        leave_1_probs = []
        num_queries += len(leave_1_texts)

        for new_ans in leave_1_texts:
            new_logits = predict(model, ref_ans, new_ans, true_label)
            new_probs = F.softmax(new_logits, dim=0)
            leave_1_probs.append(new_probs)
        leave_1_probs = torch.stack(leave_1_probs)
        leave_1_probs_argmax = torch.argmax(leave_1_probs, dim=-1)

        import_scores = (
            orig_prob - leave_1_probs[:, orig_label] +
            (leave_1_probs_argmax != orig_label).float() *
            (leave_1_probs.max(dim=-1)[0] - torch.index_select(
                orig_probs, 0, leave_1_probs_argmax))).data.cpu().numpy()

        # get words to perturb ranked by importance score
        words_perturb = []
        for idx, score in sorted(enumerate(import_scores),
                                 key=lambda x: x[1],
                                 reverse=True):
            try:
                if score > import_score_threshold and stud_ans[
                        idx] not in stop_words_set:
                    words_perturb.append((idx, stud_ans[idx]))
            except:
                print(idx, len(stud_ans), import_scores.shape, stud_ans,
                      len(leave_1_texts))

        # find synonyms
        words_perturb_idx = [
            word2idx[word] for idx, word in words_perturb if word in word2idx
        ]
        synonym_words, _ = pick_most_similar_words_batch(
            words_perturb_idx, cos_sim, idx2word, synonym_num, 0.5)

        synonyms_all = []
        for idx, word in words_perturb:
            if word in word2idx:
                synonyms = synonym_words.pop(0)
                if synonyms:
                    synonyms_all.append((idx, synonyms))

        # start replacing and attacking
        text_prime = stud_ans[:]
        text_cache = text_prime[:]
        num_changed = 0
        for idx, synonyms in synonyms_all:
            new_texts = [
                text_prime[:idx] + [synonym] +
                text_prime[min(idx + 1, len_text):] for synonym in synonyms
            ]
            new_probs = []
            new_labels = []
            for syn_text in new_texts:
                syn_logits = predict(model, ref_ans, syn_text, true_label)
                new_probs.append(F.softmax(syn_logits, dim=0))

            new_probs = torch.stack(new_probs)

            # compute semantic similarity
            if idx >= half_sim_score_window and len_text - idx - 1 >= half_sim_score_window:
                text_range_min = idx - half_sim_score_window
                text_range_max = idx + half_sim_score_window + 1
            elif idx < half_sim_score_window <= len_text - idx - 1:
                text_range_min = 0
                text_range_max = sim_score_window
            elif idx >= half_sim_score_window > len_text - idx - 1:
                text_range_min = len_text - sim_score_window
                text_range_max = len_text
            else:
                text_range_min = 0
                text_range_max = len_text
            semantic_sims = sim_predictor.semantic_sim(
                [' '.join(text_cache[text_range_min:text_range_max])] * len(new_texts),
                list(map(lambda x: ' '.join(x[text_range_min:text_range_max]), new_texts)))[0]

            num_queries += len(new_texts)

            if len(new_probs.shape) < 2:
                new_probs = new_probs.unsqueeze(0)
            new_probs_mask = (2 == torch.argmax(new_probs,
                                                dim=-1)).data.cpu().numpy()
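            # NOTE: unlike the other variants, candidates are kept when the predicted class
            # equals the hard-coded label 2, not simply when it differs from orig_label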
            # prevent bad synonyms
            new_probs_mask *= (semantic_sims >= sim_score_threshold)
            # prevent incompatible pos (maybe not)

            synonyms_pos_ls = [
                criteria.get_pos(new_text[max(idx - 4, 0):idx +
                                          5])[min(4, idx)]
                if len(new_text) > 10 else criteria.get_pos(new_text)[idx]
                for new_text in new_texts
            ]

            pos_mask = np.array(
                criteria.pos_filter(pos_ls[idx], synonyms_pos_ls))
            # Uncomment to inverse mask and only allow candidates where POS is not the same
            # pos_mask = np.invert(pos_mask)
            new_probs_mask *= pos_mask

            if np.sum(new_probs_mask) > 0:
                text_prime[idx] = synonyms[(new_probs_mask *
                                            semantic_sims).argmax()]
                num_changed += 1
                adversaries.append(tuple(text_prime))
                break
            """
            else:
                new_label_probs = new_probs[:, orig_label] + torch.from_numpy(
                    (semantic_sims < sim_score_threshold) + (1 - pos_mask).astype(float)).float()
                new_label_prob_min, new_label_prob_argmin = torch.min(new_label_probs, dim=-1)
                if new_label_prob_min < orig_prob:
                    text_prime[idx] = synonyms[new_label_prob_argmin]
                    num_changed += 1
            text_cache = text_prime[:]
            adversaries.append(text_cache)
            #new_labels.append()
            """
        # collect the adversaries that actually differ from the original student answer
        result = set(i for i in adversaries if list(i) != stud_ans)
        return num_changed, num_queries, result