Example #1
    def preProcessData(self):

        self.toi_box_batch, self.label_batch, _ = select_toi(self.toi_batch)

        self.toi_box_entity, _, self.space_list_entity = select_toi(
            self.entity_batch)
        self.word_batch_var = self.whetherUseGpu(
            Variable(torch.LongTensor(np.array(self.word_batch))),
            self.config.if_gpu)
        self.mask_batch_var = self.whetherUseGpu(
            generate_mask(self.word_batch_var.shape), self.config.if_gpu)
        self.char_batch_var = self.whetherUseGpu(
            Variable(torch.LongTensor(np.array(self.char_batch))),
            self.config.if_gpu)
        self.pos_tag_batch_var = self.whetherUseGpu(
            Variable(torch.LongTensor(np.array(self.pos_tag_batch))),
            self.config.if_gpu)
        self.gold_label_vec = self.whetherUseGpu(
            Variable(torch.LongTensor(np.hstack(self.label_batch))),
            self.config.if_gpu)
        self.entity_nested_depth = []
        for each_entity_list in self.entity_batch:
            gt_layer = [-1] * len(each_entiy_list)
            for id in range(len(gt_layer)):
                if gt_layer[id] == -1:
                    dfs(id, each_entity_list, gt_layer)
            self.entity_nested_depth.append(gt_layer)
        self.entity_nested_depth = np.hstack(self.entity_nested_depth)
        self.entity_nested_depth[np.where(
            self.entity_nested_depth >= self.config.nested_depth
        )] = self.config.nested_depth - 1
        self.entity_nested_depth = self.whetherUseGpu(
            Variable(torch.LongTensor(self.entity_nested_depth)),
            self.config.if_gpu)
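
The dfs helper used above to fill gt_layer is not shown in this example. Below is a minimal sketch, assuming it assigns each entity span its nesting depth (0 for outermost spans), which matches how gt_layer is consumed here; the real helper may differ in details.

# Hedged sketch of the assumed layer-assignment helper. Entities are
# (start, end, label) tuples; gt_layer[i] receives entity i's nesting
# depth, where 0 means the span is not contained in any other span.
def dfs(i, entities, gt_layer):
    depth = 0
    for j, other in enumerate(entities):
        if j == i:
            continue
        # entity i strictly contained in entity j
        if (other[0] <= entities[i][0] and entities[i][1] <= other[1]
                and (other[0], other[1]) != (entities[i][0], entities[i][1])):
            if gt_layer[j] == -1:
                dfs(j, entities, gt_layer)
            depth = max(depth, gt_layer[j] + 1)
    gt_layer[i] = depth

entities = [(0, 5, 1), (1, 3, 2)]   # (1, 3) is nested inside (0, 5)
layers = [-1] * len(entities)
for i in range(len(entities)):
    if layers[i] == -1:
        dfs(i, entities, layers)
print(layers)                        # [0, 1]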
Example #2
    def launching_kosaraju(self):
        """
        Launches Kosaraju on the graph stored
        :return: an array of the scc contained in the graph
        """
        # we launch the first dfs on the graph and then we reverse the graph
        dfs_result = utils.full_dfs(self.graph)
        reverse_graph = utils.reverse_succ(self.graph)
        scc = []

        # while we haven't visited every node from the DFS result, we run a DFS on the reversed graph
        # starting from the node at the top of the stack (the DFS result) and build an SCC from all
        # nodes visited
        while len(dfs_result) > 0:
            node = dfs_result.pop()
            (res, self.visited) = utils.dfs(reverse_graph, node, self.visited, [])
            current_scc = ""
            for i in res:
                current_scc += str(i) + " "
                # current_scc += chr(i + 97) # results as char from a-z
                if node != i:
                    # remove is O(n), which could cost some performance compared to pop,
                    # but both give the same results on my small test cases
                    # and take about the same time when plotting complexity
                    dfs_result.remove(i)
                    # dfs_result.pop()

            scc.append(current_scc.rstrip())

        return scc
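
The utils helpers above (full_dfs, reverse_succ, dfs) are external to this example. For reference, here is a self-contained sketch of the same two-pass idea, assuming the graph is an adjacency dict {node: [successors]}; it is an illustration under those assumptions, not the code used above.

# Two-pass Kosaraju sketch: first DFS records finish order, second DFS on
# the reversed graph peels off one strongly connected component per root.
def kosaraju(graph):
    visited, order = set(), []

    def dfs1(u):                       # first pass: record finish order
        visited.add(u)
        for v in graph.get(u, []):
            if v not in visited:
                dfs1(v)
        order.append(u)

    for u in graph:
        if u not in visited:
            dfs1(u)

    reverse = {u: [] for u in graph}   # reverse every edge
    for u in graph:
        for v in graph[u]:
            reverse.setdefault(v, []).append(u)

    visited.clear()
    sccs = []
    while order:                       # second pass on the reversed graph
        u = order.pop()
        if u in visited:
            continue
        stack, comp = [u], []
        visited.add(u)
        while stack:
            x = stack.pop()
            comp.append(x)
            for y in reverse.get(x, []):
                if y not in visited:
                    visited.add(y)
                    stack.append(y)
        sccs.append(comp)
    return sccs

print(kosaraju({0: [1], 1: [2], 2: [0, 3], 3: []}))  # [[0, 2, 1], [3]]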
Example #3
    def to_batch(self, mode):
        word_dic = defaultdict(list)
        char_dic = defaultdict(list)
        pos_tag_dic = defaultdict(list)
        entity_dic = defaultdict(list)
        char_len_dic = defaultdict(list)
        toi_dic = defaultdict(list)
        toi_dic_layer0 = defaultdict(list)
        toi_dic_layer1 = defaultdict(list)
        origin_word = defaultdict(list)

        word_batches = []
        char_batches = []
        char_len_batches = []
        pos_tag_batches = []
        entity_batches = []
        toi_batches = []
        toi_batches_layer0 = []
        toi_batches_layer1 = []
        word_origin_batches = []
        len_entity = np.zeros((2, 200))

        for i, sent_info in enumerate(self.infos[mode]):
            entity_vec = [(e[0], e[1], self.label2id[e[2]])
                          for e in sent_info.entities]
            #entity_vec =  sorted(entity_vec,key=lambda x: (x[0], x[1]))
            word_vec = [self.word2id[w] for w in sent_info.words]
            word_num = len(word_vec)

            char_mat = [[self.char2id[c] for c in w] for w in sent_info.chars]
            char_len_vec = [len(w) for w in char_mat]
            pad_id = self.char2id["."]
            char_mat = [
                w + [pad_id] * (self.max_word_len - len(w)) for w in char_mat
            ]

            pos_tag_vec = [self.pos_tag2id[p] for p in sent_info.pos_tags]

            gt_layer = [-1] * len(entity_vec)
            for id in range(len(gt_layer)):
                if gt_layer[id] == -1:
                    dfs(id, entity_vec, gt_layer)

            # compute the distribution of entity lengths in different layers
            for each in range(len(entity_vec)):
                if gt_layer[each] == 0:
                    len_entity[0][entity_vec[each][1] -
                                  entity_vec[each][0]] += 1
                else:
                    len_entity[1][entity_vec[each][1] -
                                  entity_vec[each][0]] += 1

            entity_layer0 = []
            entity_layer1 = []
            entity_different_layer = []

            for id in range(len(gt_layer)):
                if gt_layer[id] == 0:
                    entity_layer0.append(entity_vec[id])
                else:
                    entity_layer1.append(entity_vec[id])
                entity_different_layer.append(entity_vec[id])

            tois = self.generate_toi(word_num, entity_vec, mode, -1)

            if mode == 'train':
                for each in entity_vec:
                    if each not in tois:
                        tois.append(each)
            tois = sorted(tois)

            word_dic[word_num].append(word_vec)
            char_dic[word_num].append(char_mat)
            char_len_dic[word_num].append(char_len_vec)
            pos_tag_dic[word_num].append(pos_tag_vec)
            entity_dic[word_num].append(entity_vec)
            toi_dic[word_num].append(tois)
            origin_word[word_num].append(sent_info.words)

        # print the per-layer length distribution of entity spans
        for each in len_entity:
            for i in range(len(each)):
                print(each[i], end=' ')
            print()
        for length in word_dic.keys():
            word_batch = [
                word_dic[length][i:i + self.config.batch_size] for i in range(
                    0, len(word_dic[length]), self.config.batch_size)
            ]
            char_batch = [
                char_dic[length][i:i + self.config.batch_size] for i in range(
                    0, len(char_dic[length]), self.config.batch_size)
            ]
            char_len_batch = [
                char_len_dic[length][i:i + self.config.batch_size] for i in
                range(0, len(char_len_dic[length]), self.config.batch_size)
            ]
            pos_tag_batch = [
                pos_tag_dic[length][i:i + self.config.batch_size] for i in
                range(0, len(pos_tag_dic[length]), self.config.batch_size)
            ]
            entity_batch = [
                entity_dic[length][i:i + self.config.batch_size] for i in
                range(0, len(entity_dic[length]), self.config.batch_size)
            ]
            toi_batch = [
                toi_dic[length][i:i + self.config.batch_size]
                for i in range(0, len(toi_dic[length]), self.config.batch_size)
            ]

            word_origin_batch = [
                origin_word[length][i:i + self.config.batch_size] for i in
                range(0, len(origin_word[length]), self.config.batch_size)
            ]

            word_batches.extend(word_batch)
            char_batches.extend(char_batch)
            char_len_batches.extend(char_len_batch)
            pos_tag_batches.extend(pos_tag_batch)
            entity_batches.extend(entity_batch)
            toi_batches.extend(toi_batch)
            word_origin_batches.extend(word_origin_batch)

        return (word_batches, char_batches, char_len_batches, pos_tag_batches,
                entity_batches, toi_batches, word_origin_batches)
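
The core batching pattern in to_batch is: group sentences by length so every batch is rectangular, then split each length bucket into chunks of batch_size. A minimal, self-contained sketch of that pattern (the names here are illustrative, not from the original code):

from collections import defaultdict

def bucket_batches(sentences, batch_size):
    # bucket by length so every batch holds equal-length sentences
    buckets = defaultdict(list)
    for sent in sentences:
        buckets[len(sent)].append(sent)
    # chunk each bucket into batches of at most batch_size
    batches = []
    for length, group in buckets.items():
        batches.extend(group[i:i + batch_size]
                       for i in range(0, len(group), batch_size))
    return batches

sents = [[1, 2], [3, 4], [5, 6, 7], [8, 9]]
print(bucket_batches(sents, batch_size=2))
# [[[1, 2], [3, 4]], [[8, 9]], [[5, 6, 7]]]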
Example #4
    def preProcessData(self):

        self.toi_box_batch, self.label_batch, _ = select_toi(self.toi_batch)

        self.toi_box_entity, _, self.space_list_entity = select_toi(
            self.entity_batch)
        self.word_batch_var = self.whetherUseGpu(
            Variable(torch.LongTensor(np.array(self.word_batch))),
            self.config.if_gpu)
        self.mask_batch_var = self.whetherUseGpu(
            generate_mask(self.word_batch_var.shape), self.config.if_gpu)
        self.char_batch_var = self.whetherUseGpu(
            Variable(torch.LongTensor(np.array(self.char_batch))),
            self.config.if_gpu)
        self.pos_tag_batch_var = self.whetherUseGpu(
            Variable(torch.LongTensor(np.array(self.pos_tag_batch))),
            self.config.if_gpu)
        self.gold_label_vec = self.whetherUseGpu(
            Variable(torch.LongTensor(np.hstack(self.label_batch))),
            self.config.if_gpu)
        self.entity_nested_depth = []
        for each_entity_list in self.entity_batch:
            gt_layer = [-1] * len(each_entiy_list)
            for id in range(len(gt_layer)):
                if gt_layer[id] == -1:
                    dfs(id, each_entity_list, gt_layer)
            self.entity_nested_depth.append(gt_layer)
        self.entity_nested_depth = np.hstack(self.entity_nested_depth)
        self.entity_nested_depth[np.where(
            self.entity_nested_depth >= self.config.nested_depth
        )] = self.config.nested_depth - 1
        self.entity_nested_depth = self.whetherUseGpu(
            Variable(torch.LongTensor(self.entity_nested_depth)),
            self.config.if_gpu)

        if self.config.use_bert:
            # BERT initialization
            tokens_tensors = []
            for each in self.word_origin_batch:
                text = "[CLS] " + " ".join(each) + " [SEP]"

                if text not in self.text_to_bert:
                    # cache miss: tokenize once and record each word's
                    # [start, end) span in the subword sequence
                    tokens_length = []
                    text_subwords = self.tokenizer.tokenize(text)
                    st = 0
                    for each_word in each:
                        aim = self.tokenizer.tokenize(each_word)
                        tokens_length.append([st, st + len(aim)])
                        st = st + len(aim)
                    tokens_length = np.array(tokens_length)
                    sub_tokens = torch.tensor([
                        self.tokenizer.convert_tokens_to_ids(text_subwords)
                    ]).cuda()
                    text_embedding, _ = self.bertModel(sub_tokens)
                    text_embedding = torch.cat(text_embedding,
                                               dim=0).squeeze(1)
                    word_embedding = self.bertModel.embeddings.word_embeddings(
                        sub_tokens)
                    word_embedding = self.bertModel.embeddings.LayerNorm(
                        word_embedding)
                    text_embedding = torch.cat(
                        (text_embedding, word_embedding), dim=0)[:, 1:-1, :]
                    # prepend a zero row so the mean over subwords [st, en)
                    # equals (cumsum[en] - cumsum[st]) / (en - st)
                    cumsum = torch.cat([
                        torch.zeros(text_embedding.size(0), 1,
                                    text_embedding.size(2)).cuda(),
                        torch.cumsum(text_embedding, 1)
                    ], dim=1)
                    boundary_len = Variable(
                        torch.FloatTensor(tokens_length[:, 1] -
                                          tokens_length[:, 0]),
                        requires_grad=False).cuda()
                    hidden_list = (
                        cumsum[:, tokens_length[:, 1], :] -
                        cumsum[:, tokens_length[:, 0], :]) / boundary_len.view(
                            1, boundary_len.size(0), 1)
                    self.text_to_bert_out[self.cnt] = hidden_list.cpu().numpy()
                    self.text_to_bert[text] = self.cnt
                    self.cnt += 1
                else:
                    hidden_list = torch.tensor(self.text_to_bert_out[
                        self.text_to_bert.get(text)]).cuda()
                tokens_tensors.append(hidden_list.unsqueeze(0))

            hidden_list = torch.cat(tokens_tensors, dim=0)

            if not self.config.fusion:
                hidden_list = hidden_list[:, -2:-1, :, :]
            else:
                pass
            self.hiddenList = hidden_list
        else:
            self.hiddenList = None
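
The cumulative-sum step above mean-pools subword embeddings into word embeddings: with a zero row prepended, the mean over the subwords in [st, en) is (cumsum[en] - cumsum[st]) / (en - st). A minimal sketch isolating that trick with toy numbers (CPU-only, illustrative shapes):

import torch

subword_emb = torch.tensor([[1.0, 1.0],    # word 0 -> subwords 0..1
                            [3.0, 3.0],
                            [5.0, 7.0]])   # word 1 -> subword 2
spans = torch.tensor([[0, 2], [2, 3]])     # [start, end) per word

# zero row prepended so cumsum differences give span sums
cumsum = torch.cat([torch.zeros(1, subword_emb.size(1)),
                    torch.cumsum(subword_emb, dim=0)], dim=0)
lengths = (spans[:, 1] - spans[:, 0]).float().unsqueeze(1)
word_emb = (cumsum[spans[:, 1]] - cumsum[spans[:, 0]]) / lengths
print(word_emb)   # tensor([[2., 2.], [5., 7.]])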
Example #5
def main(args):
    # Initial process
    args = vars(args)
    unicode_enc = args['unicode_enc']  # choice of message encoding
    mode = args['mode']  # choice of steganography algorithm
    block_size = args['block_size']  # steganography parameter (block size for huffman/bins)
    temp = args['temp']  # steganography parameter TEMPERATURE; avoid creating a new variable named temp below
    precision = args['precision']  # steganography parameter
    topk = args['topk']  # text-generation parameter
    device = args['device']  # text-generation device, GPU/CPU, defaults to 'cuda'
    finish_sent = args['finish_sent']  # steganography parameter
    nucleus = args['nucleus']  # saac-specific steganography parameter
    delta = args['delta']  # saac-specific steganography parameter
    model_name = args['language_model']  # text-generation model
    context_file = args['context_file']  # path to the context file
    message_str = args['name']
    # sample_tokens = 100               # variable used for testing

    # PARAMETERS: default hidden message for the first run (a name)
    # message_str = "Chhenl"              # string to be hidden.

    # VALIDATE PARAMETERS: check the steganography mode
    if mode not in ['arithmetic', 'huffman', 'bins', 'saac']:
        raise NotImplementedError

    # print the hidden message (a name)
    print("Default plain_text is ", message_str)

    # read the context
    with open(context_file, 'r', encoding='utf-8') as f:
        context = f.read()
    print("sample context is ",
          context)  # related to the text generation procedure.

    # load the text-generation model
    print("loading GPT-2 LM to GPU")
    enc, model = get_model(model_name=model_name)
    print("finished loading!")

    print("implementation of {}".format(mode))

    # setup for the bins steganography algorithm
    if mode == 'bins':
        bin2words, words2bin = get_bins(len(enc.encoder), block_size)

    # setup for the saac steganography algorithm
    if delta and mode == "saac":
        nucleus = 2**(-1.0 * delta)

    # the comments below are left over from earlier debugging
    # fix situation: directly encode the text.
    # print("directly encode the plain txt:\n", enc.encode(message_str))
    # print("Decode back:\n", enc.decode(enc.encode(message_str)))

    # can ensure the problem arise in the arithmetic_decode as well as the arithmetic_encode function.

    # ----------------------start test----------------------------
    # test_str = "hello world."
    # print("test_str = ", test_str)
    # out = enc.encode(test_str)
    # print("out = ", out)
    # decode_str = enc.decode(out)
    # print("decode_str = ", decode_str)
    # print("enc.encode(decode_str) = ", enc.encode(decode_str))
    # ----------------------stop test-----------------------------

    # Archive Basic Initialization----------------------------------
    # print("plain_text is {}".format(message_str))
    # unicode_enc = False
    # mode = 'huffman'
    # block_size = 3 # for huffman and bins
    # temp = 0.9 # for arithmetic
    # precision = 26 # for arithmetic
    # sample_tokens = 100 # for sample, delete sample
    # topk = 300
    # device = 'cuda'
    # finish_sent=False # whether or not to force finish sent. If so, stats displayed will be for non-finished sentence
    # nucleus = 0.95
    # Archive Basic Initialization----------------------------------

    first_flag = 1  # marks the first pass through the loop below
    context_tokens = encode_context(context, enc)  # encode the context with the language model's tokenizer

    while True:
        # --- inside this loop we repeatedly wait for a new hidden message (a name) ----------
        # ------------------------------------------------------------------------------------
        # list_for_bpw = [] # used to compute the bits/word statistic
        # list_for_DKL = [] # used to compute the KL statistic
        # list_for_seq = [] # used for indexing

        if first_flag == 0:
            message_str = input("Please reenter a new plaintext:")
            # output_amount = len(message_str)

        # generate the set of upper/lowercase variants of the hidden message (the name)
        message_str = message_str.upper()
        arr = list(message_str)
        generated_array = dfs(arr, 0, [])

        first_flag = 0
        covertext_list = []

        for temp_count in range(0, len(generated_array)):
            # First encode message to uniform bits, without any context
            # (not essential this is arithmetic vs ascii, but it's more efficient when the message is natural language)

            # if temp_count > 10:
            #     break                 # for testing: caps the number of covertexts generated at 10; tidy up before release

            print("=" * 80)
            print("Altering the #{} msg_str:".format(temp_count), message_str)
            message_str = generated_array[temp_count]  # pick one case variant of the hidden message (e.g., KiErAn)

            # obtain message, i.e., the bit stream described above
            if unicode_enc:
                ba = bitarray.bitarray()
                ba.frombytes(message_str.encode('utf-8'))
                message = ba.tolist()
            else:
                message_ctx = [enc.encoder['<|endoftext|>']]
                message_str += '<eos>'
                message = decode_arithmetic(model,
                                            enc,
                                            message_str,
                                            message_ctx,
                                            precision=40,
                                            topk=60000)

            # print("First encode the text to a bit sequence!")
            # print(message)  # the binary stream. text--arithmetic-->binary stream
            # print("the length is {}".format(len(message)))

            # Next encode bits into cover text, using arbitrary context

            # the chosen steganography algorithm embeds the bit stream into generated text; decoding out with the GPT-2 tokenizer yields the COVERTEXT
            Hq = 0
            if mode == 'arithmetic':
                out, nll, kl, words_per_bit, Hq = encode_arithmetic(
                    model,
                    enc,
                    message,
                    context_tokens,
                    temp=temp,
                    finish_sent=finish_sent,
                    precision=precision,
                    topk=topk)
            elif mode == 'huffman':
                out, nll, kl, words_per_bit = encode_huffman(
                    model,
                    enc,
                    message,
                    context_tokens,
                    block_size,
                    finish_sent=finish_sent)
            elif mode == 'bins':
                out, nll, kl, words_per_bit = encode_block(
                    model,
                    enc,
                    message,
                    context_tokens,
                    block_size,
                    bin2words,
                    words2bin,
                    finish_sent=finish_sent)
            elif mode == 'saac':
                out, nll, kl, words_per_bit, Hq, topk_list, case_studies = encode_saac(
                    model,
                    enc,
                    message,
                    context_tokens,
                    device=device,
                    temp=temp,
                    precision=precision,
                    topk=topk,
                    nucleus=nucleus)
            #     defaults here are device='cuda', temp=1.0, precision=26, topk=50, nucleus=0.95.
            covertext = enc.decode(out)
            covertext_list.append(covertext)  # collect every COVERTEXT in one list for later use

            # list_for_bpw.append(1/words_per_bit)      # used to compute statistics
            # list_for_DKL.append(kl)                   # used to compute statistics
            # list_for_seq.append(temp_count)
            # print("="*40 + " Encoding " + "="*40)

            # print the resulting COVERTEXT; it can be extracted from here
            print(
                '#{} generated covertext:\n'.format(temp_count), covertext
            )  # covertext. generated covertext that contains secret information.
            print(
                'ppl: %0.2f, kl: %0.3f, words/bit: %0.2f, bits/word: %0.2f, entropy: %.2f'
                % (math.exp(nll), kl, words_per_bit, 1 / words_per_bit,
                   Hq / 0.69315))

            # -----------------------------------------------------------------------------------
            # extraction: decode the covertext with the matching algorithm to recover the bit stream MESSAGE_REC
            # Decode binary message from bits using the same arbitrary context

            # possibly useful later (commented out for now): the receiver enters their own name and the covertext for verification
            # input_name = input("Please input ur name:")
            # input_covertext = input("Please input the covertext:")
            # covertext = input_covertext

            if mode == 'arithmetic':
                message_rec = decode_arithmetic(model,
                                                enc,
                                                covertext,
                                                context_tokens,
                                                temp=temp,
                                                precision=precision,
                                                topk=topk)
            elif mode == 'huffman':
                message_rec = decode_huffman(model, enc, covertext,
                                             context_tokens, block_size)
            elif mode == 'bins':
                message_rec = decode_block(model, enc, covertext,
                                           context_tokens, block_size,
                                           bin2words, words2bin)
            elif mode == 'saac':
                message_rec = decode_saac(model,
                                          enc,
                                          covertext,
                                          context_tokens,
                                          device=device,
                                          temp=temp,
                                          precision=precision,
                                          topk=topk,
                                          nucleus=nucleus)

            # print("="*40 + " Recovered Message " + "="*40)
            # print(message_rec)  # binary stream extracted from stego_text.
            # print("=" * 80)
            # Finally map message bits back to original text

            # decode the bit stream; the resulting reconst variable is the extracted hidden message, normally a name
            if unicode_enc:
                message_rec = [bool(item) for item in message_rec]
                ba = bitarray.bitarray(message_rec)
                reconst = ba.tobytes().decode('utf-8', 'ignore')
            else:
                reconst = encode_arithmetic(model,
                                            enc,
                                            message_rec,
                                            message_ctx,
                                            precision=40,
                                            topk=60000)
                # reconst = encode_arithmetic(model, enc, message_rec, message_ctx, temp=temp, precision=precision, topk=topk)
                # print("reconst[0] is", format(reconst[0]))
                reconst = enc.decode(reconst[0])
            print("The decode text is ")
            print(reconst[0:-5]
                  )  # Decoded text. message_rec --arithmetic decode--> reconst
Example #6
def encrypt(unicode_enc, mode, block_size, temp, precision, topk, device,
            finish_sent, model_name, delta, context, message_str):
    print("loading GPT-2 LM to GPU")
    enc, model = get_model(model_name=model_name)
    print("finish loading !")

    print("implication of {}".format(mode))
    if mode == 'bins':
        bin2words, words2bin = get_bins(len(enc.encoder), block_size)

    nucleus = 0.95  # default, matching the other entry points; overridden below for saac
    if delta and mode == "saac":
        nucleus = 2**(-1.0 * delta)

    first_flag = 1
    context_tokens = encode_context(context, enc)
    while True:
        sentence_assemble = []
        if first_flag == 0:
            message_str = input("Please reenter a new plaintext:")
            # output_amount = len(message_str)
        message_str = message_str.upper()
        arr = list(message_str)
        generated_array = dfs(arr, 0, [])
        first_flag = 0
        for temp_count in range(0, len(generated_array)):
            # First encode message to uniform bits, without any context
            # (not essential this is arithmetic vs ascii, but it's more efficient when the message is natural language)

            # if temp_count > 10: # protect against running too many times.
            #     break

            print("=" * 80)
            print("Altering the #{} msg_str:".format(temp_count), message_str)
            message_str = generated_array[temp_count]

            if unicode_enc:
                ba = bitarray.bitarray()
                ba.frombytes(message_str.encode('utf-8'))
                message = ba.tolist()
            else:
                message_ctx = [enc.encoder['<|endoftext|>']]
                message_str += '<eos>'
                message = decode_arithmetic(model,
                                            enc,
                                            message_str,
                                            message_ctx,
                                            precision=40,
                                            topk=60000)
                # message = decode_arithmetic(model, enc, message_str, message_ctx, precision=precision, topk=topk, temp=temp)

            Hq = 0
            if mode == 'arithmetic':
                out, nll, kl, words_per_bit, Hq = encode_arithmetic(
                    model,
                    enc,
                    message,
                    context_tokens,
                    temp=temp,
                    finish_sent=finish_sent,
                    precision=precision,
                    topk=topk)
            elif mode == 'huffman':
                out, nll, kl, words_per_bit = encode_huffman(
                    model,
                    enc,
                    message,
                    context_tokens,
                    block_size,
                    finish_sent=finish_sent)
            elif mode == 'bins':
                out, nll, kl, words_per_bit = encode_block(
                    model,
                    enc,
                    message,
                    context_tokens,
                    block_size,
                    bin2words,
                    words2bin,
                    finish_sent=finish_sent)
                words_per_bit = 1  # note: overrides the value returned by encode_block
            elif mode == 'saac':
                out, nll, kl, words_per_bit, Hq, topk_list, case_studies = encode_saac(
                    model,
                    enc,
                    message,
                    context_tokens,
                    device=device,
                    temp=temp,
                    precision=precision,
                    topk=topk,
                    nucleus=nucleus)
            #     defaults here are device='cuda', temp=1.0, precision=26, topk=50, nucleus=0.95.
            text = enc.decode(out)
            # print("="*40 + " Encoding " + "="*40)
            print(
                '#{} generated covertext:\n'.format(temp_count), text
            )  # covertext. generated text that contains secret information.
            # print('ppl: %0.2f, kl: %0.3f, words/bit: %0.2f, bits/word: %0.2f, entropy: %.2f' % (math.exp(nll), kl, words_per_bit, 1/words_per_bit, Hq/0.69315))
            sentence_assemble.append(text)
        dataframe = pd.DataFrame({'Sentences': sentence_assemble})
        # the [0:-5] strips the '<eos>' appended to message_str above
        dataframe.to_csv("User_{}_Name_{}_Amount_{}.csv".format(
            random.randint(1, 10000),
            message_str.upper()[0:-5], len(generated_array)),
                         index=False,
                         sep=',')
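
The statistics reported by these entry points are simple transforms of the raw encoder outputs: perplexity is exp(nll), bits/word is the reciprocal of words/bit, and Hq (in nats) is converted to bits by dividing by ln 2 (the 0.69315 constant in Example #5). A small helper sketch with illustrative inputs:

import math

def report(nll, kl, words_per_bit, Hq):
    # package the derived statistics printed after each encoding run
    return {
        "ppl": math.exp(nll),
        "kl": kl,
        "words/bit": words_per_bit,
        "bits/word": 1.0 / words_per_bit,
        "entropy_bits": Hq / math.log(2),
    }

print(report(nll=4.2, kl=0.03, words_per_bit=2.5, Hq=3.1))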
Example #7
def embed(unicode_enc=False, mode='saac', block_size=1, temp=0.9, precision=26, topk=300, device='cuda',
          finish_sent=False, nucleus=0.95, delta=0.01, model_name='gpt2',
          context_file='D:/OneDrive - whu.edu.cn/桌面/NeuralSteganography-master1/context.txt', name='Gogo'):
    # Example: embed(mode='saac', name='Chhenl', temp=0.9)
    # covertext_list: the list of generated covertexts

    temp = float(temp)
    message_str = name
    # VALIDATE PARAMETERS: check the steganography mode
    if mode not in ['arithmetic', 'huffman', 'bins', 'saac']:
        raise NotImplementedError

    # print the hidden message (a name)
    print("Plain_text is ", message_str)

    # read the context
    with open(context_file, 'r', encoding='utf-8') as f:
        context = f.read()
    print("sample context is ", context)  # related to the text generation procedure.

    # load the text-generation model
    print("loading GPT-2 LM to GPU")
    enc, model = get_model(model_name=model_name)
    print("finish loading !")

    print("implication of {}".format(mode))

    # setup for the bins steganography algorithm
    if mode == 'bins':
        bin2words, words2bin = get_bins(len(enc.encoder), block_size)

    # setup for the saac steganography algorithm
    if delta and mode == "saac":
        nucleus = 2 ** (-1.0 * delta)

    # first_flag = 1  # marks the first pass in the interactive version
    context_tokens = encode_context(context, enc)  # encode the context with the language model's tokenizer

    # generate the set of upper/lowercase variants of the hidden message (the name)
    message_str = message_str.upper()
    arr = list(message_str)
    generated_array = dfs(arr, 0, [])

    # first_flag = 0
    covertext_list = []

    for temp_count in range(0, len(generated_array)):
        # First encode message to uniform bits, without any context
        # (not essential this is arithmetic vs ascii, but it's more efficient when the message is natural language)

        if temp_count > 1:
            break                 # for testing: caps the number of covertexts generated; tidy up before release

        print("=" * 80)
        print("Altering the #{} msg_str:".format(temp_count), message_str)
        message_str = generated_array[temp_count]  # pick one case variant of the hidden message (e.g., KiErAn)

        # obtain message, i.e., the bit stream described above
        if unicode_enc:
            ba = bitarray.bitarray()
            ba.frombytes(message_str.encode('utf-8'))
            message = ba.tolist()
        else:
            message_ctx = [enc.encoder['<|endoftext|>']]
            message_str += '<eos>'
            message = decode_arithmetic(model, enc, message_str, message_ctx, precision=40, topk=60000)

        # Next encode bits into cover text, using arbitrary context

        # the chosen steganography algorithm embeds the bit stream into generated text; decoding out with the GPT-2 tokenizer yields the COVERTEXT
        Hq = 0
        if mode == 'arithmetic':
            out, nll, kl, words_per_bit, Hq = encode_arithmetic(model, enc, message, context_tokens, temp=temp,
                                                                finish_sent=finish_sent, precision=precision,
                                                                topk=topk)
        elif mode == 'huffman':
            out, nll, kl, words_per_bit = encode_huffman(model, enc, message, context_tokens, block_size,
                                                         finish_sent=finish_sent)
        elif mode == 'bins':
            out, nll, kl, words_per_bit = encode_block(model, enc, message, context_tokens, block_size, bin2words,
                                                       words2bin, finish_sent=finish_sent)
        elif mode == 'saac':
            out, nll, kl, words_per_bit, Hq, topk_list, case_studies = encode_saac(model, enc, message,
                                                                                   context_tokens, device=device,
                                                                                   temp=temp, precision=precision,
                                                                                   topk=topk, nucleus=nucleus)
        covertext = enc.decode(out)
        covertext_list.append(covertext)  # collect every COVERTEXT in one list for later use

        # print the resulting COVERTEXT; it can be extracted from here
        print('#{} generated covertext:\n'.format(temp_count),
              covertext)  # covertext. generated covertext that contains secret information.

    return covertext_list
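
A hypothetical usage of embed, assuming the surrounding helper modules are importable and that a local context file exists at the path given (both are assumptions, not part of the original):

if __name__ == "__main__":
    # generate a handful of covertexts for the name 'Gogo' with the saac mode
    covertexts = embed(mode='saac', name='Gogo', temp=0.9,
                       context_file='context.txt')
    for i, text in enumerate(covertexts):
        print(i, text)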