Example #1
# Imports inferred from the code below.
import pickle

import numpy as np
import pandas as pd
import spacy
from gensim.corpora import Dictionary
from nltk.corpus import stopwords
from spacy.lang.en import English
from tqdm import tqdm

# NOTE: convert_to_string is assumed to be a module-level helper (not shown
# here) that maps a spaCy token to its text.


class MultiVectorizer:

    reserved = ["<PAD>", "<UNK>"]
    embedding_matrix = None
    embedding_word_vector = {}
    glove = False

    def __init__(self, reserved=None, min_occur=1, use_bert=False, glove_path=None, tokenizer=None, embedding_size=300):

        self.min_occur = min_occur
        self.embedding_size = embedding_size
        self.use_bert = use_bert

        # Requires the small English model: python -m spacy download en_core_web_sm
        self.nlp = spacy.load("en_core_web_sm")
        if tokenizer is None:
            self.tokenizer = self.nlp.tokenizer
        else:
            self.tokenizer = tokenizer

        if glove_path is not None:
            self.load_glove(glove_path)
            self.glove = True

        if reserved is not None:
            # list.extend returns None, so build the combined list explicitly.
            self.reserved = self.reserved + list(reserved)
        self.vocabulary = Dictionary([self.reserved])

    def get_vocabulary_size(self):
        if not self.use_bert:
            return len(self.vocabulary.token2id.items())
        else:
            return len(self.tokenizer.vocab.keys())

    def load_glove(self, glove_file_path):
        # Each line of the GloVe file is a word followed by its embedding values.
        with open(glove_file_path, encoding="utf-8") as f:
            for line in tqdm(f):
                value = line.split(" ")
                word = value[0]
                coef = np.array(value[1:], dtype='float32')
                self.embedding_word_vector[word] = coef

    def get_embedding_matrix(self):
        return self.embedding_matrix

    def is_word(self, string_value):
        # Truth-testing a NumPy array raises an error, so compare against None.
        return self.embedding_word_vector.get(string_value) is not None

    def get_vocabulary(self):
        if not self.use_bert:
            return self.vocabulary
        else:
            return self.tokenizer.vocab

    def get_word_id(self, word):
        if not self.use_bert:
            return self.vocabulary.token2id[word]
        else:
            return self.tokenizer.vocab[word]


    def get_word_from_id(self, index):
        if not self.use_bert:
            return self.vocabulary.id2token[index]
        else:
            return self.tokenizer.inv_vocab[index]

    def fit_document(self, documents):
        document_tokens = []
        for document in documents:
            section_tokens = []
            for section in document:
                sentence_tokens = []
                for sentence in section:
                    tokens = self.tokenizer(sentence.lower())
                    word_str_tokens = list(map(convert_to_string, tokens))
                    sentence_tokens.append(word_str_tokens)
                # Add the section's sentences to the vocabulary once, rather than
                # re-adding the growing list on every sentence.
                self.vocabulary.add_documents(sentence_tokens)
                section_tokens.append(sentence_tokens)
            document_tokens.append(section_tokens)
        return document_tokens

    def fit_bert_sentences(self, samples, remove_stop_words=True):
        output_tokens = []
        vocab = []
        stop_words = set(stopwords.words('english'))
        for sample in tqdm(samples):
            sentence_tokens = []
            for sentence in sample:
                tokens = self.tokenizer.tokenize(sentence.lower())
                if remove_stop_words:
                    tokens = [w for w in tokens if w not in stop_words]
                tokens = ["[CLS]"] + tokens + ["[SEP]"]
                sentence_tokens.append(tokens)
                vocab.append(tokens)
            output_tokens.append(sentence_tokens)
        #self.vocabulary.add_documents(vocab)
        return output_tokens

    def fit_samples_with_sentences(self, samples, remove_stop_words=True):
        output_tokens = []
        vocab = []
        for sample in tqdm(samples):
            sentence_tokens = []
            for sentence in sample:
                tokens = self.tokenizer(sentence.lower())
                if remove_stop_words:
                    tokens = [token for token in tokens if not token.is_stop]
                word_str_tokens = list(map(convert_to_string, tokens))
                sentence_tokens.append(word_str_tokens)
                vocab.append(word_str_tokens)
            output_tokens.append(sentence_tokens)
        self.vocabulary.add_documents(vocab)
        return output_tokens

    def fit(self, X, remove_stop_words=True, list_of_lists=False):
        if list_of_lists:
            if not self.use_bert:
                x_tokens = self.fit_samples_with_sentences(X, remove_stop_words=remove_stop_words)
            else:
                x_tokens = self.fit_bert_sentences(X, remove_stop_words=remove_stop_words)
        else:
            x_tokens = self.fit_text(X, remove_stop_words=remove_stop_words)

        self.vocabulary.filter_extremes(no_below=self.min_occur, no_above=1.0, keep_tokens=self.reserved)
        unknown_words = []
        if self.glove:
            print("Vocabulary Size:", self.get_vocabulary_size())
            self.embedding_matrix = np.zeros((self.get_vocabulary_size(), self.embedding_size))
            for word, i in tqdm(self.vocabulary.token2id.items()):
                if word == "<PAD>":
                    embedding_value = np.zeros((1, self.embedding_size))
                elif word == "<UNK>":
                    sd = 1 / np.sqrt(self.embedding_size)
                    np.random.seed(seed=42)
                    embedding_value = np.random.normal(0, scale=sd, size=[1, self.embedding_size])
                else:
                    embedding_value = self.embedding_word_vector.get(word)
                    if embedding_value is None:
                        embedding_value = self.embedding_word_vector.get(self.correct_word(word))
                        if embedding_value is None:
                            unknown_words.append(word)
                            # GloVe has no "<UNK>" entry, so this lookup usually
                            # returns None and the row stays all zeros.
                            embedding_value = self.embedding_word_vector.get("<UNK>")

                if embedding_value is not None:
                    self.embedding_matrix[i] = embedding_value
        print("Number of unknown words:", len(unknown_words))
        unknown_words_df = pd.DataFrame()
        unknown_words_df["Unknown Words"] = unknown_words
        encoded_tokens = self.transform(x_tokens, list_of_lists=list_of_lists)
        return encoded_tokens

    def fit_text(self, X, remove_stop_words=True):
        output_tokens = []
        for sample in tqdm(X):
            tokens = self.tokenizer(sample.lower())
            if remove_stop_words:
                tokens = [token for token in tokens if not token.is_stop]
            word_str_tokens = list(map(convert_to_string, tokens))
            output_tokens.append(word_str_tokens)
        self.vocabulary.add_documents(output_tokens)
        return output_tokens

    def correct_word(self, word):
        # Placeholder for word correction (e.g., spell checking); currently
        # returns the word unchanged.
        return word

    def transform(self, X, list_of_lists=False):
        if list_of_lists:
            if not self.use_bert:
                return self.transform_list_of_list(X)
            else:
                return self.transform_bert(X)
        else:
            return self.transform_text(X)

    def transform_list_of_list(self, samples):
        samples_tokens = []
        for sample in samples:
            encoded_tokens = self.transform_text(sample)
            samples_tokens.append(encoded_tokens)
        return samples_tokens

    def transform_document(self, documents):
        document_tokens = []
        for document in documents:
            section_tokens = []
            encoded_tokens = []
            for section in document:
                if isinstance(section, str):
                    encoded_tokens.append(section)
                    if len(encoded_tokens) == len(document):
                        section_tokens.append(encoded_tokens)
                        section_tokens = self.transform_text(section_tokens)
                else:
                    encoded_tokens = self.transform_text(section)
                    section_tokens.append(encoded_tokens)
            document_tokens.append(section_tokens)
        return document_tokens

    def transform_bert(self, samples):
        samples_tokens = []
        for sample in samples:
            encoded_sentences = []
            for sentence_tokens in sample:
                encoded_tokens = self.tokenizer.convert_tokens_to_ids(sentence_tokens)
                encoded_sentences.append(encoded_tokens)
            samples_tokens.append(encoded_sentences)
        return samples_tokens

    def transform_text(self, X):
        if hasattr(self, "limit"):
            return [[i if i < self.limit else self.reserved.index("<UNK>")
                     for i in self.vocabulary.doc2idx(x, unknown_word_index=self.reserved.index("<UNK>"))]
                    for x in X]
        else:
            return [self.vocabulary.doc2idx(x, unknown_word_index=self.reserved.index("<UNK>")) for x in X]

    def inverse_transform(self, X):
        return [[self.vocabulary[i] for i in x] for x in X]

    def save(self, file_path="./vectorizer.vec"):
        with open(file_path, "wb") as handle:
            pickle.dump(self, handle, protocol=pickle.HIGHEST_PROTOCOL)
        return file_path

    @classmethod
    def load(cls, file_path):
        with open(file_path, "rb") as handle:
            vectorizer = pickle.load(handle)
        return vectorizer
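
A minimal usage sketch for the class above (the corpus, the GloVe path, and the file name are illustrative, and convert_to_string is assumed to be available as noted above):

texts = [
    "The movie was great.",
    "The plot was thin but the acting was solid.",
]

# Build the vocabulary and the GloVe-backed embedding matrix, then encode the texts.
vectorizer = MultiVectorizer(glove_path="./glove.840B.300d.txt", embedding_size=300)
encoded = vectorizer.fit(texts)                       # list of token-id lists
embedding_matrix = vectorizer.get_embedding_matrix()  # rows align with token ids

# Round-trip through disk and decode back to tokens.
path = vectorizer.save("./vectorizer.vec")
restored = MultiVectorizer.load(path)
decoded = restored.inverse_transform(encoded)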
Example #2
# Imports inferred from the code below; the task-specific helpers (set_seed,
# get_vocab_SST2, get_vocab_CliniSTS, get_vocab_QNLI, word_normalize,
# cal_probability, SanText_plus, SanText_plus_init) are assumed to come from
# the surrounding project and are not shown here.
import argparse
import logging
import os
from functools import partial
from multiprocessing import Pool, cpu_count

import numpy as np
from spacy.lang.en import English
from tqdm import tqdm
from transformers import BertForMaskedLM, BertTokenizer

logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--data_dir",
                        default="./data/SST-2/",
                        type=str,
                        help="The input dir")

    parser.add_argument(
        "--bert_model_path",
        default="bert-base-uncased",
        type=str,
        help="bert model name or path. leave it bank if you are using Glove")

    parser.add_argument(
        "--output_dir",
        default="./output_SanText/QNLI/",
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written.",
    )

    parser.add_argument(
        "--word_embedding_path",
        default='./data/glove.840B.300d.txt',
        type=str,
        help=
        "The pretrained word embedding path. leave it blank if you are using BERT",
    )

    parser.add_argument(
        "--word_embedding_size",
        default=300,
        type=int,
        help=
        "The pretrained word embedding size. leave it blank if you are using BERT",
    )

    parser.add_argument('--method',
                        choices=['SanText', 'SanText_plus'],
                        default='SanText_plus',
                        help='Sanitization method')

    parser.add_argument('--embedding_type',
                        choices=['glove', 'bert'],
                        default='glove',
                        help='embedding used for sanitization')

    parser.add_argument('--task',
                        choices=['CliniSTS', "SST-2", "QNLI"],
                        default='SST-2',
                        help='NLP eval tasks')

    parser.add_argument("--seed",
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument("--epsilon",
                        type=float,
                        default=15,
                        help="privacy parameter epsilon")
    parser.add_argument(
        "--p",
        type=float,
        default=0.2,
        help="SanText+: probability of non-sensitive words to be sanitized")

    parser.add_argument(
        "--sensitive_word_percentage",
        type=float,
        default=0.5,
        help="SanText+: how many words are treated as sensitive")

    parser.add_argument("--threads",
                        type=int,
                        default=12,
                        help="number of processors")

    args = parser.parse_args()

    set_seed(args)

    logging.basicConfig(
        format="%(asctime)s -  %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    logger.info(
        "Running method: %s, task: %s,  epsilon = %s, random_seed: %d" %
        (args.method, args.task, args.epsilon, args.seed))

    if args.method == "SanText":
        args.sensitive_word_percentage = 1.0
        args.output_dir = os.path.join(args.output_dir,
                                       "eps_%.2f" % args.epsilon)
    else:
        args.output_dir = os.path.join(
            args.output_dir, "eps_%.2f" % args.epsilon,
            "sword_%.2f_p_%.2f" % (args.sensitive_word_percentage, args.p))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    logger.info("Building Vocabulary...")

    if args.embedding_type == "glove":
        tokenizer = English()
        tokenizer_type = "word"
    else:
        tokenizer = BertTokenizer.from_pretrained(args.bert_model_path)
        tokenizer_type = "subword"
    if args.task == "SST-2":
        vocab = get_vocab_SST2(args.data_dir,
                               tokenizer,
                               tokenizer_type=tokenizer_type)
    elif args.task == "CliniSTS":
        vocab = get_vocab_CliniSTS(args.data_dir,
                                   tokenizer,
                                   tokenizer_type=tokenizer_type)
    elif args.task == "QNLI":
        vocab = get_vocab_QNLI(args.data_dir,
                               tokenizer,
                               tokenizer_type=tokenizer_type)
    else:
        raise NotImplementedError

    # vocab.most_common() is sorted by frequency (descending), so the slice
    # below treats the least frequent words as the sensitive ones.
    sensitive_word_count = int(args.sensitive_word_percentage * len(vocab))
    words = [key for key, _ in vocab.most_common()]
    sensitive_words = words[-sensitive_word_count - 1:]

    sensitive_words2id = {word: k for k, word in enumerate(sensitive_words)}
    logger.info("#Total Words: %d, #Sensitive Words: %d" %
                (len(words), len(sensitive_words2id)))

    sensitive_word_embed = []
    all_word_embed = []

    word2id = {}
    sword2id = {}
    sensitive_count = 0
    all_count = 0
    if args.embedding_type == "glove":
        num_lines = sum(1 for _ in open(args.word_embedding_path))
        logger.info("Loading Word Embedding File: %s" %
                    args.word_embedding_path)

        with open(args.word_embedding_path) as f:
            # Skip first line if of form count/dim.
            line = f.readline().rstrip().split(' ')
            if len(line) != 2:
                f.seek(0)
            for row in tqdm(f, total=num_lines - 1):
                content = row.rstrip().split(' ')
                cur_word = word_normalize(content[0])
                if cur_word in vocab and cur_word not in word2id:
                    word2id[cur_word] = all_count
                    all_count += 1
                    emb = [float(i) for i in content[1:]]
                    all_word_embed.append(emb)
                    if cur_word in sensitive_words2id:
                        sword2id[cur_word] = sensitive_count
                        sensitive_count += 1
                        sensitive_word_embed.append(emb)
                assert len(word2id) == len(all_word_embed)
                assert len(sword2id) == len(sensitive_word_embed)
    else:
        logger.info("Loading BERT Embedding File: %s" % args.bert_model_path)
        model = BertForMaskedLM.from_pretrained(args.bert_model_path)
        embedding_matrix = model.bert.embeddings.word_embeddings.weight.data.cpu(
        ).numpy()

        for cur_word in tokenizer.vocab:
            if cur_word in vocab and cur_word not in word2id:
                word2id[cur_word] = all_count
                emb = embedding_matrix[tokenizer.convert_tokens_to_ids(
                    cur_word)]
                all_word_embed.append(emb)
                all_count += 1

                if cur_word in sensitive_words2id:
                    sword2id[cur_word] = sensitive_count
                    sensitive_count += 1
                    sensitive_word_embed.append(emb)
            assert len(word2id) == len(all_word_embed)
            assert len(sword2id) == len(sensitive_word_embed)

    all_word_embed = np.array(all_word_embed, dtype='f')
    sensitive_word_embed = np.array(sensitive_word_embed, dtype='f')

    logger.info("All Word Embedding Matrix: %s" % str(all_word_embed.shape))
    logger.info("Sensitive Word Embedding Matrix: %s" %
                str(sensitive_word_embed.shape))

    logger.info("Calculating Prob Matrix for Exponential Mechanism...")
    prob_matrix = cal_probability(all_word_embed, sensitive_word_embed,
                                  args.epsilon)

    threads = min(args.threads, cpu_count())

    for file_name in ['train.tsv', 'dev.tsv']:
        data_file = os.path.join(args.data_dir, file_name)
        out_file = open(os.path.join(args.output_dir, file_name), 'w')
        logger.info("Processing file: %s. Will write to: %s" %
                    (data_file, os.path.join(args.output_dir, file_name)))

        num_lines = sum(1 for _ in open(data_file))
        with open(data_file, 'r') as rf:
            # header
            header = next(rf)
            out_file.write(header)
            labels = []
            docs = []
            if args.task == "SST-2":
                for line in tqdm(rf, total=num_lines - 1):
                    content = line.strip().split("\t")
                    text = content[0]
                    label = int(content[1])
                    if args.embedding_type == "glove":
                        doc = [token.text for token in tokenizer(text)]
                    else:
                        doc = tokenizer.tokenize(text)
                    docs.append(doc)
                    labels.append(label)
            elif args.task == "CliniSTS":
                for line in tqdm(rf, total=num_lines - 1):
                    content = line.strip().split("\t")
                    text1 = content[7]
                    text2 = content[8]
                    label = content[-1]
                    if args.embedding_type == "glove":
                        doc1 = [token.text for token in tokenizer(text1)]
                        doc2 = [token.text for token in tokenizer(text2)]
                    else:
                        doc1 = tokenizer.tokenize(text1)
                        doc2 = tokenizer.tokenize(text2)
                    docs.append(doc1)
                    docs.append(doc2)
                    labels.append(label)
            elif args.task == "QNLI":
                for line in tqdm(rf, total=num_lines - 1):
                    content = line.strip().split("\t")
                    text1 = content[1]
                    text2 = content[2]
                    label = content[-1]
                    if args.embedding_type == "glove":
                        doc1 = [token.text for token in tokenizer(text1)]
                        doc2 = [token.text for token in tokenizer(text2)]
                    else:
                        doc1 = tokenizer.tokenize(text1)
                        doc2 = tokenizer.tokenize(text2)

                    docs.append(doc1)
                    docs.append(doc2)
                    labels.append(label)


        with Pool(threads,
                  initializer=SanText_plus_init,
                  initargs=(prob_matrix, word2id, sword2id, words, args.p,
                            tokenizer)) as p:
            annotate_ = partial(SanText_plus)
            results = list(
                tqdm(
                    p.imap(annotate_, docs, chunksize=32),
                    total=len(docs),
                    desc="Sanitize docs using SanText",
                ))
            p.close()

        logger.info("Saving ...")

        if args.task == "SST-2":
            for i, predicted_text in enumerate(results):
                write_content = predicted_text + "\t" + str(labels[i]) + "\n"
                out_file.write(write_content)
        elif args.task == "CliniSTS":
            assert len(results) / 2 == len(labels)
            for i in range(len(labels)):
                predicted_text1 = results[i * 2]
                predicted_text2 = results[i * 2 + 1]
                write_content = str(
                    i
                ) + "\t" + "none\t" * 6 + predicted_text1 + "\t" + predicted_text2 + "\t" + str(
                    labels[i]) + "\n"
                out_file.write(write_content)
        elif args.task == "QNLI":
            assert len(results) / 2 == len(labels)
            for i in range(len(labels)):
                predicted_text1 = results[i * 2]
                predicted_text2 = results[i * 2 + 1]
                write_content = str(
                    i
                ) + "\t" + predicted_text1 + "\t" + predicted_text2 + "\t" + str(
                    labels[i]) + "\n"
                out_file.write(write_content)

        out_file.close()
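
cal_probability is not shown in this snippet. As an illustration of the exponential mechanism it stands for, where sampling probabilities decay with the Euclidean distance between word embeddings and are scaled by epsilon, a self-contained sketch could look like the following; the helper name cal_probability_sketch is ours, and the project's actual cal_probability may differ in details such as the distance metric or scaling:

import numpy as np
from scipy.spatial.distance import cdist


def cal_probability_sketch(all_word_embed, sensitive_word_embed, epsilon=1.0):
    """For each sensitive word, a sampling distribution over all candidate words,
    giving higher mass to words whose embeddings are closer in Euclidean distance."""
    distances = cdist(sensitive_word_embed, all_word_embed)  # (n_sensitive, n_all)

    # Exponential mechanism: utility is the negative distance, scaled by epsilon / 2;
    # normalize each row with a numerically stable softmax.
    scores = -0.5 * epsilon * distances
    scores -= scores.max(axis=1, keepdims=True)
    probs = np.exp(scores)
    return probs / probs.sum(axis=1, keepdims=True)

Sanitizing a sensitive word then amounts to sampling a replacement index from the corresponding row of this matrix.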
Example #3
# Imports inferred from the code below; the task-specific helpers (set_seed,
# get_vocab_SST2, get_vocab_CliniSTS, get_vocab_QNLI, word_normalize,
# cal_probability, SanText, SanText_init) are assumed to come from the
# surrounding project and are not shown here.
import argparse
import copy
import logging
import os
from functools import partial
from multiprocessing import Pool, cpu_count

import numpy as np
import torch
from spacy.lang.en import English
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from tqdm import tqdm
from transformers import BertForMaskedLM, BertTokenizer

logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--data_dir",
                        default="./data/SST-2/",
                        type=str,
                        help="The input dir")

    parser.add_argument("--model_path",
                        default="bert-base-uncased",
                        type=str,
                        help="bert model_path")

    parser.add_argument(
        "--output_dir",
        default="./output_SanText/SST-2/",
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written.",
    )

    parser.add_argument(
        "--word_embedding_path",
        default='./data/glove.840B.300d.txt',
        type=str,
        help="The pretrained word embedding path",
    )

    parser.add_argument(
        "--word_embedding_size",
        default=300,
        type=int,
        help="The pretrained word embedding size",
    )

    parser.add_argument('--method',
                        choices=['WarmUp', 'SanText'],
                        default='WarmUp',
                        help='Sanitization method')

    parser.add_argument('--embedding_type',
                        choices=['glove', 'bert'],
                        default='bert',
                        help='embedding used for sanitization')
    parser.add_argument(
        "--max_seq_length",
        default=64,
        type=int,
        help="Optional input sequence length after tokenization."
        "The training dataset will be truncated in block of this size for training.",
    )

    parser.add_argument(
        "--batch_size",
        default=256,
        type=int,
        help="batch size",
    )

    parser.add_argument('--task',
                        choices=['CliniSTS', "SST-2", "QNLI"],
                        default='SST-2',
                        help='NLP eval tasks')

    parser.add_argument("--seed",
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument("--epsilon",
                        type=float,
                        default=10000.0,
                        help="privacy parameter epsilon")
    parser.add_argument(
        "--p",
        type=float,
        default=0.2,
        help="probability of non-sensitive words to be sanitized")

    parser.add_argument("--sensitive_word_percentage",
                        type=float,
                        default=0.5,
                        help="how many words are treated as sensitive")

    parser.add_argument("--threads",
                        type=int,
                        default=12,
                        help="number of processors")

    args = parser.parse_args()

    set_seed(args)

    logging.basicConfig(
        format="%(asctime)s -  %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    logger.info(
        "Running method: %s, task: %s,  epsilon = %s, random_seed: %d" %
        (args.method, args.task, args.epsilon, args.seed))

    if args.method == "WarmUp":
        args.sensitive_word_percentage = 1.0

    args.output_dir = os.path.join(
        args.output_dir, "eps_%.2f" % args.epsilon, "seed_" + str(args.seed),
        "sword_%.2f_p_%.2f" % (args.sensitive_word_percentage, args.p))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    logger.info("Building Vocabulary...")

    if args.embedding_type == "glove":
        tokenizer = English()
        tokenizer_type = "word"
    else:
        tokenizer = BertTokenizer.from_pretrained(args.model_path)
        tokenizer_type = "subword"
    if args.task == "SST-2":
        vocab = get_vocab_SST2(args.data_dir,
                               tokenizer,
                               tokenizer_type=tokenizer_type)
    elif args.task == "CliniSTS":
        vocab = get_vocab_CliniSTS(args.data_dir,
                                   tokenizer,
                                   tokenizer_type=tokenizer_type)
    elif args.task == "QNLI":
        vocab = get_vocab_QNLI(args.data_dir,
                               tokenizer,
                               tokenizer_type=tokenizer_type)
    else:
        raise NotImplementedError

    sensitive_word_count = int(args.sensitive_word_percentage * len(vocab))
    words = [key for key, _ in vocab.most_common()]
    sensitive_words = words[-sensitive_word_count - 1:]

    sensitive_words2id = {word: k for k, word in enumerate(sensitive_words)}
    logger.info("#Total Words: %d, #Sensitive Words: %d" %
                (len(words), len(sensitive_words2id)))

    sensitive_word_embed = []
    all_word_embed = []

    word2id = {}
    sword2id = {}
    sensitive_count = 0
    all_count = 0
    if args.embedding_type == "glove":
        num_lines = sum(1 for _ in open(args.word_embedding_path))
        logger.info("Loading Word Embedding File: %s" %
                    args.word_embedding_path)

        with open(args.word_embedding_path) as f:
            # Skip first line if of form count/dim.
            line = f.readline().rstrip().split(' ')
            if len(line) != 2:
                f.seek(0)
            for row in tqdm(f, total=num_lines - 1):
                content = row.rstrip().split(' ')
                cur_word = word_normalize(content[0])
                if cur_word in vocab and cur_word not in word2id:
                    word2id[cur_word] = all_count
                    all_count += 1
                    emb = [float(i) for i in content[1:]]
                    all_word_embed.append(emb)
                    if cur_word in sensitive_words2id:
                        sword2id[cur_word] = sensitive_count
                        sensitive_count += 1
                        sensitive_word_embed.append(emb)
                assert len(word2id) == len(all_word_embed)
                assert len(sword2id) == len(sensitive_word_embed)
    else:
        logger.info("Loading BERT Embedding File: %s" % args.model_path)
        model = BertForMaskedLM.from_pretrained(args.model_path)
        embedding_matrix = model.bert.embeddings.word_embeddings.weight.data.cpu(
        ).numpy()

        for cur_word in tokenizer.vocab:
            if cur_word in vocab and cur_word not in word2id:
                word2id[cur_word] = all_count
                emb = embedding_matrix[tokenizer.convert_tokens_to_ids(
                    cur_word)]
                all_word_embed.append(emb)
                all_count += 1

                if cur_word in sensitive_words2id:
                    sword2id[cur_word] = sensitive_count
                    sensitive_count += 1
                    sensitive_word_embed.append(emb)
            assert len(word2id) == len(all_word_embed)
            assert len(sword2id) == len(sensitive_word_embed)

    all_word_embed = np.array(all_word_embed, dtype='f')
    sensitive_word_embed = np.array(sensitive_word_embed, dtype='f')

    logger.info("All Word Embedding Matrix: %s" % str(all_word_embed.shape))
    logger.info("Sensitive Word Embedding Matrix: %s" %
                str(sensitive_word_embed.shape))

    logger.info("Calculating Prob Matrix for Exponential Mechanism...")
    prob_matrix = cal_probability(all_word_embed, sensitive_word_embed,
                                  args.epsilon)

    threads = min(args.threads, cpu_count())

    for file_name in ['dev.tsv']:
        data_file = os.path.join(args.data_dir, file_name)
        out_file = open(os.path.join(args.output_dir, file_name), 'w')
        logger.info("Processing file: %s. Will write to: %s" %
                    (data_file, os.path.join(args.output_dir, file_name)))

        num_lines = sum(1 for _ in open(data_file))
        with open(data_file, 'r') as rf:
            # header
            header = next(rf)
            out_file.write(header)
            labels = []
            docs = []
            if args.task == "SST-2":
                for line in tqdm(rf, total=num_lines - 1):
                    content = line.strip().split("\t")
                    text = content[0]
                    label = int(content[1])
                    if args.embedding_type == "glove":
                        doc = [token.text for token in tokenizer(text)]
                    else:
                        doc = tokenizer.tokenize(text)
                    docs.append(doc)
                    labels.append(label)
            elif args.task == "CliniSTS":
                for line in tqdm(rf, total=num_lines - 1):
                    content = line.strip().split("\t")
                    text1 = content[7]
                    text2 = content[8]
                    label = content[-1]
                    if args.embedding_type == "glove":
                        doc1 = [token.text for token in tokenizer(text1)]
                        doc2 = [token.text for token in tokenizer(text2)]
                    else:
                        doc1 = tokenizer.tokenize(text1)
                        doc2 = tokenizer.tokenize(text2)
                    docs.append(doc1)
                    docs.append(doc2)
                    labels.append(label)
            elif args.task == "QNLI":
                for line in tqdm(rf, total=num_lines - 1):
                    content = line.strip().split("\t")
                    text1 = content[1]
                    text2 = content[2]
                    label = content[-1]
                    if args.embedding_type == "glove":
                        doc1 = [token.text for token in tokenizer(text1)]
                        doc2 = [token.text for token in tokenizer(text2)]
                    else:
                        doc1 = tokenizer.tokenize(text1)
                        doc2 = tokenizer.tokenize(text2)

                    docs.append(doc1)
                    # docs.append(doc2)
                    labels.append(label)


        with Pool(threads,
                  initializer=SanText_init,
                  initargs=(prob_matrix, word2id, sword2id, words, args.p,
                            tokenizer)) as p:
            annotate_ = partial(SanText)
            results = list(
                tqdm(
                    p.imap(annotate_, docs, chunksize=32),
                    total=len(docs),
                    desc="Sanitize docs using SanText",
                ))
            p.close()

        logger.info("Saving ...")

        if args.max_seq_length <= 0:
            args.max_seq_length = (
                tokenizer.max_len_single_sentence
            )  # Our input block size will be the max possible for the model
        args.max_seq_length = min(args.max_seq_length,
                                  tokenizer.max_len_single_sentence)
        # The mask-inference evaluation below reuses `model` and the BERT
        # tokenizer, so it assumes embedding_type == "bert".
        args.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        model.to(args.device)
        model = torch.nn.DataParallel(model)
        tokenized_new_docs = []
        # `labels` is reused here to hold the original token ids at the masked
        # positions; the task labels read above are not needed for this check.
        labels = []
        for i, new_doc in enumerate(results):
            assert len(docs[i]) == len(new_doc)
            for j in range(len(new_doc)):
                tmp_doc = copy.deepcopy(new_doc)
                tmp_doc[j] = "[MASK]"
                tokenized_new_docs.append(
                    tokenizer.encode_plus(tmp_doc,
                                          padding="max_length",
                                          max_length=args.max_seq_length,
                                          truncation=True))
                labels.append(tokenizer.convert_tokens_to_ids(docs[i][j]))

        all_input_ids = torch.tensor(
            [doc.data['input_ids'] for doc in tokenized_new_docs],
            dtype=torch.long)
        all_token_type_ids = torch.tensor(
            [doc.data['token_type_ids'] for doc in tokenized_new_docs],
            dtype=torch.long)
        all_attention_mask = torch.tensor(
            [doc.data['attention_mask'] for doc in tokenized_new_docs],
            dtype=torch.long)
        all_labels = torch.tensor(labels, dtype=torch.long)
        dataset = TensorDataset(all_input_ids, all_token_type_ids,
                                all_attention_mask, all_labels)
        sampler = SequentialSampler(dataset)
        dataloader = DataLoader(dataset,
                                sampler=sampler,
                                batch_size=args.batch_size)

        intersect_num = 0
        total_num = 0
        for batch in tqdm(dataloader):
            batch = tuple(t.to(args.device) for t in batch)
            with torch.no_grad():
                # Order follows the TensorDataset above:
                # input_ids, token_type_ids, attention_mask, labels.
                inputs = {
                    "input_ids": batch[0],
                    "token_type_ids": batch[1],
                    "attention_mask": batch[2]
                }

                prediction = model(**inputs)[0]
                prediction = torch.argmax(prediction, dim=2)
                # 103 is the [MASK] token id for bert-base-uncased.
                prediction = prediction[torch.where(batch[0] == 103)]
                ground_truths = batch[3]
                intersect_num += (prediction == ground_truths).sum()
                total_num += len(prediction)

        # Mask-inference recovery rate: the fraction of masked positions where
        # BERT predicts the original (pre-sanitization) token.
        print(intersect_num.item())
        print(total_num)
        print(1.0 * intersect_num.item() / total_num)
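
The final loop measures how often BERT recovers the original token at each masked position of the sanitized text. A self-contained, single-sentence version of that check (bert-base-uncased and the token lists below are illustrative):

import torch
from transformers import BertForMaskedLM, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
model.eval()

original_tokens = ["the", "movie", "was", "great"]  # before sanitization
sanitized_tokens = ["the", "film", "was", "good"]   # after sanitization

recovered = 0
for j in range(len(sanitized_tokens)):
    masked = list(sanitized_tokens)
    masked[j] = tokenizer.mask_token                # "[MASK]"
    inputs = tokenizer(" ".join(masked), return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs)[0]                 # (1, seq_len, vocab_size)
    mask_pos = inputs["input_ids"] == tokenizer.mask_token_id
    predicted_id = logits[mask_pos].argmax(dim=-1)
    if predicted_id.item() == tokenizer.convert_tokens_to_ids(original_tokens[j]):
        recovered += 1

print(recovered / len(sanitized_tokens))            # token recovery rate

The script above applies the same per-position masking to every sanitized document in dev.tsv and aggregates the recovery rate over all masked positions.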