def __init__(self, params):
    """Load the dataset splits, the ELMo batcher and the token dictionary.

    Args:
        params: configuration object. Reads ``data_path``, ``IS_DEBUG``,
            ``DATA_DEBUG`` / ``DATA_TRAIN`` / ``DATA_DEV`` / ``DATA_TEST``
            (file names appended to ``data_path``), ``USE_CHAR_ELMO`` and
            ``DIC``.

    In debug mode all three splits are loaded from ``DATA_DEBUG``. Sets
    ``train``, ``dev``, ``test``, ``batcher``, ``dic`` (stripped lines of
    the dictionary file) and ``dic_size``.
    """
    self.data_path = params.data_path
    self.params = params

    if params.IS_DEBUG:
        print('debug mode')
        split_files = ('DATA_DEBUG', 'DATA_DEBUG', 'DATA_DEBUG')
    else:
        split_files = ('DATA_TRAIN', 'DATA_DEV', 'DATA_TEST')

    # Load train, then dev, then test, each from its configured file.
    for split_name, file_attr in zip(('train', 'dev', 'test'), split_files):
        file_name = getattr(self.params, file_attr)
        setattr(self, split_name, self.load_data(self.data_path + file_name))

    dic_path = self.data_path + self.params.DIC

    # Batcher for ELMo: character-level or cached token-level input.
    if self.params.USE_CHAR_ELMO:
        print('[INFO] character-level ELMo')
        self.batcher = Batcher(dic_path, 50)
    else:
        print('[INFO] cached-token-level ELMo')
        self.batcher = TokenBatcher(dic_path)

    self.dic_size = 0
    with open(dic_path, 'r') as dic_file:
        self.dic = [entry.strip() for entry in dic_file]
    self.dic_size = len(self.dic)

    print('[completed] load data, dic_size: ', self.dic_size)
Beispiel #2
0
def load_ELMo_data(filename, seq_len, entity_len,
                   vocab_file="./ELMo_file/vocab.txt"):
    """Read ``filename`` and convert it to per-example ELMo token-id features.

    Args:
        filename: dataset file, parsed by ``read_data`` into
            ``(entity_list, token_list, _)``.
        seq_len: maximum number of tokens kept per example.
        entity_len: maximum number of entity tokens kept per example.
        vocab_file: vocabulary for the ``TokenBatcher`` (defaults to the
            previously hard-coded path).

    Returns:
        ``np.ndarray`` of dtype ``object`` and shape ``(n_examples, 4)``;
        row ``i`` is ``[entity_ids, token_ids, mask, real_seq_len]`` where
        ``mask`` is a list of ``seq_len`` ints (1 per kept token, then 0).
    """
    batcher = TokenBatcher(vocab_file)
    entity_list, token_list, _ = read_data(filename)

    entity_id_list, token_id_list = [], []
    real_chars_list, seq_lens_list = [], []
    for index, tokens in enumerate(token_list):
        token_id_list.append(tokens[:seq_len])
        entity_id_list.append(entity_list[index][:entity_len])

        real_seq_len = min(len(tokens), seq_len)
        seq_lens_list.append(real_seq_len)
        # 1 for each kept token, 0 for the padding up to seq_len.
        real_chars_list.append([1] * real_seq_len +
                               [0] * (seq_len - real_seq_len))

    entity_pad = batcher.batch_sentences(entity_id_list)
    token_pad = batcher.batch_sentences(token_id_list)

    print("The shape of tokens after loading vocab:", token_pad.shape)

    # Pack one row per example. The four fields have different lengths, so
    # the result must be an object array filled cell by cell: a plain
    # np.array(features) over such ragged rows raises ValueError on
    # NumPy >= 1.24 (and only warned on older versions).
    features = np.empty((len(token_list), 4), dtype=object)
    for index in range(len(token_list)):
        features[index, 0] = entity_pad[index]
        features[index, 1] = token_pad[index]
        features[index, 2] = real_chars_list[index]
        features[index, 3] = seq_lens_list[index]
    return features
    def __init__(self):
        """Dump token embeddings for ``vocab_small.txt`` and build the biLM graph.

        Writes ``elmo_token_embeddings.hdf5`` via ``dump_token_embeddings``,
        then creates ``self.batcher``, the ``self.context_token_ids``
        placeholder and the ``self.elmo_context_input`` /
        ``self.elmo_context_output`` weight-layer dicts in the current
        default graph.
        """
        self.vocab_file = 'vocab_small.txt'
        # Pretrained LM files are read from ./pretrained.
        model_dir = os.path.join('pretrained')
        options_path = os.path.join(
            model_dir, 'elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json')
        weights_path = os.path.join(
            model_dir, 'elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5')
        embedding_cache = 'elmo_token_embeddings.hdf5'

        # Precompute context-independent token embeddings for the vocabulary.
        dump_token_embeddings(self.vocab_file, options_path, weights_path,
                              embedding_cache)

        self.batcher = TokenBatcher(self.vocab_file)
        # Token ids fed to the biLM: (batch, time), both dynamic.
        self.context_token_ids = tf.placeholder('int32', shape=(None, None))
        language_model = BidirectionalLanguageModel(
            options_path,
            weights_path,
            use_character_inputs=False,
            embedding_weight_file=embedding_cache)
        lm_ops = language_model(self.context_token_ids)
        # One ELMo weighting per scope, built in the order 'input', 'output'.
        self.elmo_context_input, self.elmo_context_output = (
            weight_layers(scope, lm_ops, l2_coef=0.0)
            for scope in ('input', 'output'))
Beispiel #4
0
 def build(self, options_file, weight_file, vocab_file, token_embedding_file):
     """Create the token-input biLM and the matching ``TokenBatcher``.

     The biLM uses ``token_embedding_file`` instead of character inputs and
     is capped at ``self.max_batch`` sentences per batch.
     """
     lm_kwargs = dict(use_character_inputs=False,
                      embedding_weight_file=token_embedding_file,
                      max_batch_size=self.max_batch)
     self._bilm = BidirectionalLanguageModel(options_file, weight_file,
                                             **lm_kwargs)
     self._token_batcher = TokenBatcher(vocab_file)
Beispiel #5
0
    def __init__(self, hparams):
        """Store ELMo file locations from ``hparams`` and create the batcher.

        If ``hparams.elmo_token_embedding_file`` does not exist yet it is
        generated by ``self._make_dump_token_embeddings()``.
        """
        self.hparams = hp = hparams
        self.vocab_path = hp.word_vocab_path
        self.elmo_options_file = hp.elmo_options_file
        self.elmo_weight_file = hp.elmo_weight_file
        self.token_embedding_file = hp.elmo_token_embedding_file

        self.batcher = TokenBatcher(self.vocab_path)
        if os.path.exists(self.token_embedding_file):
            return  # token embeddings already dumped

        print("making dump token embeddings")
        self._make_dump_token_embeddings()
        print("finished making dump_token_embeddings")
Beispiel #6
0
    def __init__(self, config):
        """Build the ELMo-fed tagging model graph from a ``config`` dict.

        Reads hyper-parameters ("lr", "dropout", "lstm_dim", "layer_type",
        "attention", "num_attention_heads", "size_per_head") and ELMo files
        ("vocab_file", "options_file", "weight_file", "token_embedding_file")
        from ``config``, then creates the biLM, the input placeholders, the
        logits (``self.inference``), the loss (``self.loss_layer``), the
        train op (``self.train``) and a Saver over all global variables.

        NOTE(review): TF1 graph construction is order-sensitive (op/variable
        creation order, and the Saver only sees variables created before
        it), so keep these statements in order.
        """
        self.lr = config["lr"]
        self.input_dropout = config["dropout"]
        self.lstm_dim = config["lstm_dim"]
        self.layer_type = config["layer_type"]
        self.use_attention = config["attention"]
        self.num_attention_heads = config['num_attention_heads']
        self.size_per_head = config['size_per_head']
        self.num_tags = 7  # NOTE(review): hard-coded tag-set size; confirm it matches the label scheme
        self.char_dim = 300  # not used in this block; presumably read by inference() — TODO confirm
        # Non-trainable bookkeeping variables (training step, best dev F1).
        self.global_step = tf.Variable(0, trainable=False)
        self.best_dev_f1 = tf.Variable(0.0, trainable=False)
        self.initializer = initializers.xavier_initializer()

        # elmo
        self.batcher = TokenBatcher(config['vocab_file'])
        # Input placeholders to the biLM.
        self.context_token_ids = tf.placeholder('int32', shape=(None, None))
        # Build the biLM graph.
        self.bilm = BidirectionalLanguageModel(
            config['options_file'],
            config['weight_file'],
            use_character_inputs=False,
            embedding_weight_file=config['token_embedding_file'])
        self.context_embeddings_op = self.bilm(self.context_token_ids)
        # ELMo weighted combination of the biLM layers ('input' scope, zero
        # L2 coefficient); only the 'weighted_op' tensor is kept.
        self.elmo_context_input = weight_layers('input',
                                                self.context_embeddings_op,
                                                l2_coef=0.0)['weighted_op']

        # add placeholders for the model
        self.mask_inputs = tf.placeholder(dtype=tf.int32,
                                          shape=[None, None],
                                          name="ChatInputs")
        self.targets = tf.placeholder(dtype=tf.int32,
                                      shape=[None, None],
                                      name="Targets")

        # dropout keep prob
        self.dropout = tf.placeholder(dtype=tf.float32, name="Dropout")
        # 1 for every non-zero id in mask_inputs, 0 for zeros (padding);
        # summing over axis 1 gives the true length of each row.
        used = tf.sign(tf.abs(self.mask_inputs))
        length = tf.reduce_sum(used, reduction_indices=1)
        self.lengths = tf.cast(length, tf.int32)
        # Dynamic batch size and sequence length of the current feed.
        self.batch_size = tf.shape(self.mask_inputs)[0]
        self.num_steps = tf.shape(self.mask_inputs)[-1]

        self.logits = self.inference(self.elmo_context_input)
        # loss of the model
        self.loss = self.loss_layer(self.logits, self.lengths)
        self.train_op = self.train(self.loss)
        # saver of the model
        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
Beispiel #7
0
def evaluate(args):
    """
    evaluate the trained model on the test files

    Converts ``args.test_files`` to ELMo token ids, restores the model named
    by ``args.algo`` from ``args.model_dir + args.algo`` and logs the test
    loss and the ``bleu_rouge`` result returned by ``model.evaluate``
    (which is given ``args.result_dir`` and the prefix "test.predicted").

    Raises:
        ValueError: if ``args.algo`` does not start with a supported model
            name ("BIDAF", "R-net" or "QANET").
    """
    logger = logging.getLogger(args.algo)

    # Resolve the model class first: an unsupported algo used to surface as
    # an UnboundLocalError on `model`, and only after the slow data loading.
    if args.algo.startswith("BIDAF"):
        model_cls = BiDAFModel
    elif args.algo.startswith("R-net"):
        model_cls = RNETModel
    elif args.algo.startswith("QANET"):
        model_cls = QANetModel
    else:
        raise ValueError("Unsupported algo: {}".format(args.algo))

    logger.info('Load data_set and vocab...')
    # NOTE(review): machine-specific absolute path; consider making it configurable.
    data_dir = '/home/home1/dmyan/codes/bilm-tf/bilm/data/'
    vocab_file = data_dir + 'vocab.txt'
    batcher = TokenBatcher(vocab_file)

    data = Dataset(test_files=args.test_files,
                   max_p_length=args.max_p_len,
                   max_q_length=args.max_q_len)
    logger.info('Converting text into ids...')
    data.convert_to_ids(batcher)
    logger.info('Initialize the model...')
    model = model_cls(args)
    model.restore(model_dir=args.model_dir + args.algo, model_prefix=args.algo)
    logger.info('Testing the model...')
    eval_batches = data.get_batches("test", args.batch_size, 0, shuffle=False)
    eval_loss, bleu_rouge = model.evaluate(eval_batches,
                                           result_dir=args.result_dir,
                                           result_prefix="test.predicted")
    logger.info("Test loss {}".format(eval_loss))
    logger.info("Test result: {}".format(bleu_rouge))
    logger.info('Done with model Testing!')
Beispiel #8
0
def contextualize(sequences):
    """Run the module-level ELMo graph over tokenized ``sequences``.

    Uses the module-level ``vocab_file``, ``context_token_ids`` placeholder
    and ``elmo_context_output`` dict. Returns a one-element list holding the
    evaluated ``'weighted_op'`` (the contextualized embedding sequences).
    """
    token_batcher = TokenBatcher(vocab_file)

    with tf.Session() as session:
        # Variables must be initialized once before running inference.
        session.run(tf.global_variables_initializer())
        batch_ids = token_batcher.batch_sentences(sequences)
        fetches = [elmo_context_output['weighted_op']]
        outputs = session.run(fetches,
                              feed_dict={context_token_ids: batch_ids})
    return outputs
Beispiel #9
0
            def elmo(reviews, inputData):
                """Compute ELMo vectors for one batch of tokenized reviews.

                Token ids for ``reviews`` are fed into the ``inputData``
                placeholder; ``config`` and ``elmoInput`` come from the
                enclosing scope. Returns a one-element list holding the
                evaluated ``"weighted_op"``.
                """
                # TokenBatcher maps tokenized text to token-id batches.
                id_batcher = TokenBatcher(config.vocabFile)
                with tf.Session() as sess:
                    sess.run(tf.global_variables_initializer())
                    batch_ids = id_batcher.batch_sentences(reviews)
                    vectors = sess.run([elmoInput["weighted_op"]],
                                       feed_dict={inputData: batch_ids})
                return vectors
Beispiel #10
0
        def elmo(reviews):
            """Compute ELMo vectors for one batch of tokenized reviews.

            ``config``, ``elmoInput`` and the ``inputData`` placeholder come
            from the enclosing scope. Returns a one-element list holding the
            evaluated ``'weighted_op'``.
            """
            # TokenBatcher maps tokenized text to token-id batches.
            id_batcher = TokenBatcher(config.vocabFile)
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                batch_ids = id_batcher.batch_sentences(reviews)
                return sess.run([elmoInput['weighted_op']],
                                feed_dict={inputData: batch_ids})
class elmo():
    """Token-level ELMo embedder built on the pretrained 5.5B biLM files.

    Construction dumps token embeddings for ``vocab_small.txt`` to
    ``elmo_token_embeddings.hdf5`` and builds the biLM graph in the current
    default TensorFlow graph; ``get_emb`` evaluates it.
    """

    def __init__(self):
        self.vocab_file = 'vocab_small.txt'
        # Location of pretrained LM.
        datadir = os.path.join('pretrained')
        options_file = os.path.join(
            datadir, 'elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json')
        weight_file = os.path.join(
            datadir, 'elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5')

        # Dump the token embeddings to a file. Run this once for your dataset.
        token_embedding_file = 'elmo_token_embeddings.hdf5'
        dump_token_embeddings(self.vocab_file, options_file, weight_file,
                              token_embedding_file)

        self.batcher = TokenBatcher(self.vocab_file)
        # Input placeholders to the biLM.
        self.context_token_ids = tf.placeholder('int32', shape=(None, None))
        # Build the biLM graph.
        bilm = BidirectionalLanguageModel(
            options_file,
            weight_file,
            use_character_inputs=False,
            embedding_weight_file=token_embedding_file)
        # Get ops to compute the LM embeddings.
        context_embeddings_op = bilm(self.context_token_ids)
        self.elmo_context_input = weight_layers('input',
                                                context_embeddings_op,
                                                l2_coef=0.0)
        self.elmo_context_output = weight_layers('output',
                                                 context_embeddings_op,
                                                 l2_coef=0.0)

    def get_emb(self, tokenized_context):
        """Return ELMo ``input`` and ``output`` embeddings for a batch.

        Args:
            tokenized_context: list of tokenized sentences (lists of str).

        Returns:
            ``([input_embeddings], [output_embeddings])``: each a
            one-element list holding the evaluated ``'weighted_op'``.

        Side effect: rewrites ``self.vocab_file`` with the tokens of
        ``tokenized_context`` plus ``<S>`` and ``</S>``.
        """
        all_tokens = {'<S>', '</S>'}
        for context_sentence in tokenized_context:
            all_tokens.update(context_sentence)
        # NOTE(review): this rewrites the vocabulary *after* the batcher and
        # the token-embedding dump were built from it, so it only affects
        # future instances; line order follows set iteration order.
        with open(self.vocab_file, 'w') as fout:
            fout.write('\n'.join(all_tokens))

        # Run in the graph that owns the placeholder and ELMo ops. The old
        # code called tf.reset_default_graph() here, which discarded that
        # graph, so fetching the ops from a new session always failed.
        graph = self.context_token_ids.graph
        with graph.as_default(), tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())
            # Create batches of data.
            context_ids = self.batcher.batch_sentences(tokenized_context)
            input_emb, output_emb = sess.run(
                [self.elmo_context_input['weighted_op'],
                 self.elmo_context_output['weighted_op']],
                feed_dict={self.context_token_ids: context_ids})
        return [input_emb], [output_emb]
def dump_token_bilm_embeddings(vocab_file, dataset_file, options_file,
                               weight_file, embedding_weight_file, outfile):
    """Write per-sentence biLM activations for ``dataset_file`` to HDF5.

    Each line of ``dataset_file`` is whitespace-tokenized and run through a
    token-input biLM on its own. Its ``lm_embeddings`` for the single batch
    entry is stored in ``outfile`` as a float32 dataset named by the
    0-based line index. Every 500 lines a progress percentage is printed
    relative to the module-level ``EXAMPLE_COUNT``.
    """
    token_batcher = TokenBatcher(vocab_file)
    token_ids_ph = tf.placeholder('int32', shape=(None, None))
    lm = BidirectionalLanguageModel(
        options_file,
        weight_file,
        use_character_inputs=False,
        embedding_weight_file=embedding_weight_file)
    lm_ops = lm(token_ids_ph)

    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as sess:
        sess.run(tf.global_variables_initializer())
        with open(dataset_file, 'r') as fin, \
                h5py.File(outfile, 'w') as fout:
            for done, line in enumerate(fin, start=1):
                ids = token_batcher.batch_sentences([line.strip().split()])
                lm_embeddings = sess.run(lm_ops['lm_embeddings'],
                                         feed_dict={token_ids_ph: ids})
                # Drop the batch axis: exactly one sentence was fed.
                layers = lm_embeddings[0, :, :, :]
                fout.create_dataset(str(done - 1),
                                    layers.shape,
                                    dtype='float32',
                                    data=layers)
                if done % 500 == 0:
                    print('%.2f%% finished!' %
                          (done / float(EXAMPLE_COUNT) * 100))
Beispiel #13
0
    def build(self, vocab_file, stop_word_file, synonym_file=None):
        """Load the vocabulary, optional synonym pairs and stop words.

        Args:
            vocab_file: vocabulary for the ``TokenBatcher``; its word<->id
                maps become ``self.word2id`` / ``self.id2word``.
            stop_word_file: UTF-8 file with one stop word per line, or None
                to skip. In-vocabulary words go to ``self.stop_word_ids``.
            synonym_file: tab-separated UTF-8 file whose 1st and 3rd columns
                are synonyms, or None (the default) to skip. Pairs of
                distinct in-vocabulary ids are added to ``self.synonyms``
                in both directions.
        """
        # 1. build TokenBatcher
        self.token_batcher = TokenBatcher(vocab_file)
        self.word2id = self.token_batcher._lm_vocab._word_to_id
        self.id2word = self.token_batcher._lm_vocab._id_to_word

        # 2. if synonym_file is not None, populate synonyms (two directions).
        #    Previously the file was opened unconditionally, so the default
        #    None crashed in open().
        if synonym_file is not None:
            with open(synonym_file, "r", encoding="utf-8") as f:
                for line in f:
                    fields = line.strip().split("\t")
                    if len(fields) < 3:  # malformed line: no 3rd column
                        continue
                    id0 = self.word2id.get(fields[0])
                    id1 = self.word2id.get(fields[2])
                    if id0 is None or id1 is None or id0 == id1:
                        continue
                    self.synonyms.setdefault(id0, set()).add(id1)
                    self.synonyms.setdefault(id1, set()).add(id0)

        # 3. if stop_word_file is not None, populate stop_word_ids
        if stop_word_file is not None:
            with open(stop_word_file, "r", encoding="utf-8") as f:
                for line in f:
                    word = line.strip()
                    if word in self.word2id:
                        self.stop_word_ids.add(self.word2id[word])
Beispiel #14
0
    def __init__(self,
                 vocab_file,
                 max_seq_length,
                 max_token_length=None,
                 stroke_vocab_file=None,
                 tran2sim=False,
                 sim2tran=False):
        """Create token (and optionally character) batchers.

        Args:
            vocab_file: vocabulary passed to the batchers.
            max_seq_length: sequence length including the two sentence
                markers; the batchers receive ``max_seq_length - 2``.
            max_token_length: when truthy, a character ``Batcher`` is also
                built as ``self.batcher``.
            stroke_vocab_file: forwarded to the character ``Batcher``.
            tran2sim: select ``"t2s.json"`` as ``self.convert_config``.
            sim2tran: select ``"s2t.json"`` as ``self.convert_config``.

        Raises:
            ValueError: if both ``tran2sim`` and ``sim2tran`` are set.
        """
        # Validate first. This used to be an `assert` evaluated after the
        # batchers were built; `python -O` strips asserts, which silently
        # left no conversion configured.
        if tran2sim and sim2tran:
            raise ValueError("tran2sim and sim2tran are mutually exclusive")

        self.vocab_file = vocab_file
        self.max_seq_length = max_seq_length
        self.max_token_length = max_token_length

        # <bos> and <eos> are added around each sentence, hence -2.
        max_seq_length = self.max_seq_length - 2
        self.token_batcher = TokenBatcher(self.vocab_file, max_seq_length)
        if max_token_length:
            self.batcher = Batcher(self.vocab_file, self.max_token_length,
                                   max_seq_length, stroke_vocab_file)

        self.convert_config = None
        if tran2sim:
            self.convert_config = "t2s.json"
        elif sim2tran:
            self.convert_config = "s2t.json"
Beispiel #15
0
    def __init__(self, path=embedding_path, embedding_dim=512,
                 sentence_len=max_sentence_len, pair_mode=False):
        """Store embedding options and set up a dedicated ELMo graph/session.

        Args:
            path: stored as ``self.embedding_path``.
            embedding_dim: stored as ``self.embedding_dim``.
            sentence_len: stored as ``self.sentence_len``.
            pair_mode: stored as ``self.pair_mode``.

        The biLM, its token placeholder and the ``input`` / ``output`` ELMo
        weight layers are built in a private ``tf.Graph``. Its variables are
        initialized once in ``self.sess_elmo``, which is kept open for reuse
        and is not closed here.
        """
        self.embedding_path = path
        self.embedding_dim = embedding_dim
        self.sentence_len = sentence_len
        self.pair_mode = pair_mode
        self.embedding_dict = dict()

        vocab_file = './bilmelmo/data/vocab.txt'
        options_file = './bilmelmo/try/options.json'
        weight_file = './bilmelmo/try/weights.hdf5'
        token_embedding_file = './bilmelmo/data/vocab_embedding.hdf5'

        # Build everything in its own graph, isolated from the default graph.
        # (The old code also created an unused tf.Graph() that this one
        # immediately shadowed.)
        g_elmo = tf.Graph()
        with g_elmo.as_default():
            self.batcher = TokenBatcher(vocab_file)
            self.context_token_ids = tf.placeholder('int32', shape=(None, None))
            self.bilm = BidirectionalLanguageModel(
                options_file,
                weight_file,
                use_character_inputs=False,
                embedding_weight_file=token_embedding_file
            )

            self.context_embeddings_op = self.bilm(self.context_token_ids)
            self.elmo_context_input = weight_layers(
                'input', self.context_embeddings_op, l2_coef=0.0
            )
            self.elmo_context_output = weight_layers(
                'output', self.context_embeddings_op, l2_coef=0.0
            )
            init = tf.global_variables_initializer()
        sess_elmo = tf.Session(graph=g_elmo)
        sess_elmo.run(init)
        self.sess_elmo = sess_elmo
Beispiel #16
0
class elmo_encoder(object):
    """Embed tokenized sentences with a token-input biLM.

    Call ``build`` once with the model files, then ``embed_sent_batch`` as
    often as needed.
    """

    def __init__(self):
        self.max_batch = 120000
        print("WARNING: Currently max_batch_size of elmo encoder is set to", self.max_batch)

    def build(self, options_file, weight_file, vocab_file, token_embedding_file):
        """Load the biLM configuration and the vocabulary.

        Args:
            options_file: biLM options JSON.
            weight_file: biLM weights file.
            vocab_file: vocabulary for the ``TokenBatcher``.
            token_embedding_file: precomputed token embeddings used instead
                of character inputs.
        """
        # Remembered so embed_sent_batch can create a fresh biLM per call.
        self._bilm_args = (options_file, weight_file, token_embedding_file)
        self._bilm = self._make_bilm()
        self._token_batcher = TokenBatcher(vocab_file)

    def _make_bilm(self):
        """Return a new token-input biLM built from the ``build`` arguments."""
        options_file, weight_file, token_embedding_file = self._bilm_args
        return BidirectionalLanguageModel(
            options_file,
            weight_file,
            use_character_inputs=False,
            embedding_weight_file=token_embedding_file,
            max_batch_size=self.max_batch)

    @staticmethod
    def _fit_to_length(token_ids, length):
        """Truncate or right-pad with 0 every row of 2-D ``token_ids`` to ``length``."""
        token_ids = np.asarray(token_ids)
        width = token_ids.shape[1]
        if width >= length:
            return token_ids[:, :length]
        return np.pad(token_ids, ((0, 0), (0, length - width)), 'constant',
                      constant_values=0)

    # sentences has to be a list of word lists, e.g.
    # [['You', 'see', '?'], ['That', 'is', 'very', 'interesting', '.']]
    def embed_sent_batch(self, sentences, length):
        """Return the ELMo ``output`` embeddings for ``sentences``.

        Args:
            sentences: list of tokenized sentences (lists of str).
            length: words per sentence to keep; token ids are truncated or
                zero-padded to ``length + 2`` to make room for the sentence
                boundary markers.

        Returns:
            The evaluated ``'weighted_op'`` of ``weight_layers('output', ...)``.
        """
        sentences_tokenid = self._token_batcher.batch_sentences(sentences)
        length += 2  # Take into account <s> and </s>
        token_ids = self._fit_to_length(sentences_tokenid, length)
        batch_size = token_ids.shape[0]

        # Build the ops in a private graph with a *fresh* biLM. The previous
        # version reset the default graph but reused self._bilm, whose later
        # calls build with variable reuse (bilm-tf behaviour), so every call
        # after the first looked for variables that no longer existed.
        graph = tf.Graph()
        with graph.as_default(), tf.device("/cpu:0"):
            bilm = self._make_bilm()
            context_token_ids = tf.placeholder('int32', shape=(batch_size, length))
            context_embeddings_op = bilm(context_token_ids)
            elmo_context_output = weight_layers(
                'output', context_embeddings_op, l2_coef=0.0)['weighted_op']
            init_op = tf.global_variables_initializer()

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.Session(graph=graph, config=config) as sess:
            sess.run(init_op)
            return sess.run(elmo_context_output,
                            feed_dict={context_token_ids: token_ids})

    def list_to_embeddings_with_dump(self,
                                     batch: List[List[str]],
                                     outfile_to_dump=None):
        """Embed ``batch`` one sentence at a time and dump each to HDF5.

        NOTE(review): relies on ``self.word_embedding_file``,
        ``self.voc_file_path``, ``self.max_word_length``, ``self.ops`` and
        ``self.ids_placeholder``, none of which this class sets.

        Parameters
        ----------
        batch : ``List[List[str]]``, required
            A list of tokenized sentences.
        outfile_to_dump : str, required in practice
            HDF5 path; dataset ``"<i>"`` receives sentence ``i``'s
            ``lm_embeddings``.

        Returns
        -------
        ``np.ndarray`` stacking the per-sentence ``lm_embeddings``.

        Raises
        ------
        ValueError
            If ``batch`` is empty or ``outfile_to_dump`` is None.
        """
        if not batch or batch == [[]]:
            raise ValueError('Batch should not be empty')
        if outfile_to_dump is None:
            raise ValueError('outfile_to_dump must be an output file path')

        if self.word_embedding_file is None:
            batcher = Batcher(self.voc_file_path, self.max_word_length)
        else:
            batcher = TokenBatcher(self.voc_file_path)

        document_embeddings = []
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            ids_list = batcher.batch_sentences(batch)
            with h5py.File(outfile_to_dump, 'w') as fout:
                for i, ids in enumerate(tqdm(ids_list,
                                             total=len(ids_list))):
                    _ops = sess.run(
                        self.ops, feed_dict={self.ids_placeholder: [ids]})
                    lm_embeddings = _ops['lm_embeddings'][0, :]
                    document_embeddings.append(lm_embeddings)
                    fout.create_dataset('{}'.format(i),
                                        lm_embeddings.shape,
                                        dtype='float32',
                                        data=lm_embeddings)
        return np.asarray(document_embeddings)
Beispiel #18
0
def train(args):
    """
    trains the reading comprehension model

    Converts ``args.train_files`` / ``args.dev_files`` to ELMo token ids,
    builds the model named by ``args.algo`` and trains it, saving under
    ``args.model_dir + args.algo``.

    Raises:
        ValueError: if ``args.algo`` does not start with a supported model
            name ("BIDAF", "R-net" or "QANET").
    """
    logger = logging.getLogger(args.algo)

    # Resolve the model class first: an unsupported algo used to surface as
    # an UnboundLocalError on `model`, and only after the slow data loading.
    if args.algo.startswith("BIDAF"):
        model_cls = BiDAFModel
    elif args.algo.startswith("R-net"):
        model_cls = RNETModel
    elif args.algo.startswith("QANET"):
        model_cls = QANetModel
    else:
        raise ValueError("Unsupported algo: {}".format(args.algo))

    logger.info('Load data_set and vocab...')
    # NOTE(review): machine-specific absolute path; consider making it configurable.
    data_dir = '/home/home1/dmyan/codes/bilm-tf/bilm/data/'
    vocab_file = data_dir + 'vocab.txt'
    batcher = TokenBatcher(vocab_file)

    data = Dataset(train_files=args.train_files,
                   dev_files=args.dev_files,
                   max_p_length=args.max_p_len,
                   max_q_length=args.max_q_len)
    logger.info('Converting text into ids...')
    data.convert_to_ids(batcher)
    logger.info('Initialize the model...')
    model = model_cls(args)
    logger.info("Load dev dataset...")
    model.dev_content_answer(args.dev_files)
    logger.info('Training the model...')
    model.train(data,
                args.epochs,
                args.batch_size,
                save_dir=args.model_dir + args.algo,
                save_prefix=args.algo,
                dropout_keep_prob=args.dropout_keep_prob)
    logger.info('Done with model training!')
Beispiel #19
0
def _load_json_pair(path):
    """Return the parsed UTF-8 JSON content of ``path``, closing the file."""
    with open(path, encoding='utf-8') as fin:
        return json.load(fin)


# Each file holds a two-element [id2x, x2id] sequence. The old
# json.load(open(...)) calls never closed their file handles.
id2char, char2id = _load_json_pair('./inputs/char2id.json')
id2bichar, bichar2id = _load_json_pair('./inputs/bichar2id.json')
id2BIO, BIO2id = _load_json_pair('./inputs/bio2id.json')

params = {
    'char2id_size': len(char2id),
    'epochs': 100,
    'early_stopping': 8,
    'bichar2id_size': len(bichar2id),
    'n_class_labels': len(BIO2id)
}

vocab_file = './ELMo/DaGuanVocabForElmo.txt'
batcher = TokenBatcher(vocab_file)


def process_batch_data(batch_data, char2id, BIO2id, mode):
    """Convert a batch of examples into id sequences (visible part only).

    NOTE(review): this body appears truncated. Nothing is appended to
    ``new_batch_data`` and nothing is returned (so it returns None), ``dic``
    is never filled, and only ``mode == 'dev'`` is handled. Confirm against
    the full source before relying on it.

    Args:
        batch_data: iterable of dicts with at least 'text', 'bichar' and
            'bio' keys (as read below).
        char2id: char -> id map; unknown chars map to 0.
        BIO2id: BIO tag -> id map; unknown tags map to None (dict.get).
        mode: only 'dev' is handled in the visible code.

    Note: ``bichar2id`` is read from module scope, not passed in.
    """

    new_batch_data = []
    elmo_text = []  # per-example list of characters, presumably for ELMo
    if mode == 'dev':
        for data in batch_data:
            dic = {}
            elmo_text.append([char for char in data['text']])
            # NOTE(review): the original note "1,UNK,0 pad" suggests id 1 is
            # UNK and 0 is padding, but unknown chars are mapped to 0 here.
            text = [char2id.get(_char, 0)
                    for _char in data['text']]  #1,UNK,0 pad
            bichar = [bichar2id.get(_bichar, 0) for _bichar in data['bichar']]

            bio = [BIO2id.get(_bio) for _bio in data['bio']]
Beispiel #20
0
import data as dt

#split trainset and validationset
def _split_tail(seq, n_tail):
    """Split ``seq`` into ``(head, last n_tail items)``; safe for ``n_tail == 0``."""
    cut = len(seq) - n_tail
    return seq[:cut], seq[cut:]


alen = len(dt.X)
val_ratio = 0.1
val_len = int(alen * val_ratio)

# The old `[:-val_len]` / `[-val_len:]` slices broke when val_len == 0
# (fewer than 10 samples): -0 == 0, so training got nothing and validation
# got everything. Split on an explicit index instead.
tokenized_sentences, tokenized_sentences_val = _split_tail(dt.X, val_len)
y, y_val = _split_tail(dt.y, val_len)

# Create a TokenBatcher to map text to token ids.
batcher = TokenBatcher(vocab_file)

#input placeholder to the biLM
token_ids = tf.placeholder('int32', shape=(None, None))
y_label = tf.placeholder('float32', shape=(None, None, 17))

#Build the biLM graph
bilm = BidirectionalLanguageModel(options_file,
                                  weight_file,
                                  use_character_inputs=False,
                                  embedding_weight_file=token_embedding_file)

#Get ops to compute the LM embeddings
embeddings_op = bilm(token_ids)

#Get an op to compute ELMo(weighted average of the internal biLM layers)
Beispiel #21
0
# NOTE(review): the commented lines below are the tail of an earlier,
# commented-out example; consider deleting them.
#         feed_dict={context_token_ids: context_ids,
#                    question_token_ids: question_ids}
#     )
#
# print(elmo_context_input_.shape, elmo_question_input_.shape)
"==================="
# Example input: one tokenized sentence ("这 是 什么" = "what is this").
tokenized_context = [
    ['这', '是', '什么'],
]

# ELMo resources: vocabulary, biLM options/weights and pre-dumped token embeddings.
vocab_file = './data/vocab.txt'
options_file = './try/options.json'
weight_file = './try/weights.hdf5'
token_embedding_file = './data/vocab_embedding.hdf5'

# Maps tokenized sentences to token-id batches.
batcher = TokenBatcher(vocab_file)
# Token ids fed to the biLM: (batch, time), both dynamic.
context_token_ids = tf.placeholder('int32', shape=(None, None))
bilm = BidirectionalLanguageModel(options_file,
                                  weight_file,
                                  use_character_inputs=False,
                                  embedding_weight_file=token_embedding_file)

# biLM ops for the context placeholder.
context_embeddings_op = bilm(context_token_ids)
# ELMo weighted combinations of the biLM layers, one per scope
# ('input' / 'output'), each with a zero L2 coefficient.
elmo_context_input = weight_layers('input', context_embeddings_op, l2_coef=0.0)

elmo_context_output = weight_layers('output',
                                    context_embeddings_op,
                                    l2_coef=0.0)
# NOTE(review): the session body appears truncated; nothing is evaluated
# after variable initialization in the visible code.
with tf.Session() as sess:
    # It is necessary to initialize variables once before running inference.
    sess.run(tf.global_variables_initializer())
Beispiel #22
0
# Location of the test-fixture biLM (options + weights).
datadir = os.path.join('tests', 'fixtures', 'model')
options_file = os.path.join(datadir, 'options.json')
weight_file = os.path.join(datadir, 'lm_weights.hdf5')

# Dump the token embeddings to a file. Run this once for your dataset.
# NOTE(review): `vocab_file` is not defined in this snippet; it must be set
# before this point.
token_embedding_file = 'elmo_token_embeddings.hdf5'
dump_token_embeddings(
    vocab_file, options_file, weight_file, token_embedding_file
)
# Start the inference graph below from a fresh default graph.
tf.reset_default_graph()



## Now we can do inference.
# Create a TokenBatcher to map text to token ids.
batcher = TokenBatcher(vocab_file)

# Input placeholders to the biLM.
# Both are (batch, time) with dynamic sizes; only the context placeholder is
# wired to the biLM in the visible code.
context_token_ids = tf.placeholder('int32', shape=(None, None))
question_token_ids = tf.placeholder('int32', shape=(None, None))

# Build the biLM graph.
bilm = BidirectionalLanguageModel(
    options_file,
    weight_file,
    use_character_inputs=False,
    embedding_weight_file=token_embedding_file
)

# Get ops to compute the LM embeddings.
context_embeddings_op = bilm(context_token_ids)
Beispiel #23
0
class Data(object):
    """Sentence-pair dataset for paraphrase training, stored as token ids.

    Sentences are mapped to ids with the ``TokenBatcher`` created in
    :meth:`build`, zero-padded to ``length`` and de-duplicated. For every
    loaded pair, aligned token positions (shared tokens and, optionally,
    synonyms) are stored as 4-tuples
    ``(sent_id, sent_id, token_index, token_index)``. They go in
    ``para_tuples`` for label "1" and in ``neg_tuples`` for any other label.
    """

    def __init__(self, length=0, use_synonym=False):
        # Member variables (dictionaries and lists) are declared here.
        # [(sent_id, sent_id, index_of_an_overlapping/synonym_token,
        #   index_of_an_overlapping/synonym_token), ...] for paraphrase pairs.
        self.para_tuples = []
        # Same layout as para_tuples, for non-paraphrase pairs.
        self.neg_tuples = []
        # {(token_id, token_id): set([neg_tuple_id, ...])}
        self.token_pair2neg_tuples = {}
        # One padded token-id array per distinct sentence. It becomes a numpy
        # array at the end of load_sentece_pairs.
        self.id2sent = []
        # {tuple(token_ids): sent_id}, used to de-duplicate sentences.
        self.sent2id = {}
        # {(sent_id, sent_id), ...}, stored in both directions, to quickly
        # check whether two sentences are paraphrases.
        self.paraphrases = set()
        # Reverse index {token_id: set([(sent_id, index_in_sentence), ...])}.
        self.token2sents = {}
        # {token_id: set([token_id, ...])}, filled symmetrically by build().
        self.synonyms = {}
        self.use_synonym = use_synonym
        self.stop_word_ids = set()
        # Every stored sentence is zero-padded to exactly this many ids.
        self.length = length

        # Vocabulary maps, populated from the TokenBatcher in build().
        self.word2id = {}
        self.id2word = []

    def build(self, vocab_file, stop_word_file, synonym_file=None):
        """Create the TokenBatcher and load the optional synonym/stop-word lists.

        Args:
            vocab_file: vocabulary file passed to ``TokenBatcher``.
            stop_word_file: UTF-8 file with one stop word per line, or None
                to skip stop words.
            synonym_file: UTF-8 tab-separated file whose 1st and 3rd columns
                are synonyms, or None (the default) to skip synonyms.
        """
        # 1. Build the TokenBatcher and reuse its vocabulary maps.
        self.token_batcher = TokenBatcher(vocab_file)
        self.word2id = self.token_batcher._lm_vocab._word_to_id
        self.id2word = self.token_batcher._lm_vocab._id_to_word

        # 2. Populate synonyms in both directions. Previously the file was
        #    opened unconditionally, so the default synonym_file=None crashed.
        if synonym_file is not None:
            with open(synonym_file, "r", encoding="utf-8") as f:
                for line in f:
                    cols = line.strip().split("\t")
                    if len(cols) < 3:
                        continue  # blank or malformed line
                    if cols[0] in self.word2id and cols[2] in self.word2id:
                        id0 = self.word2id[cols[0]]
                        id1 = self.word2id[cols[2]]
                        if id1 == id0:
                            continue
                        self.synonyms.setdefault(id0, set()).add(id1)
                        self.synonyms.setdefault(id1, set()).add(id0)

        # 3. Populate stop_word_ids (only words present in the vocabulary).
        if stop_word_file is not None:
            with open(stop_word_file, "r", encoding="utf-8") as f:
                for line in f:
                    word = line.strip()
                    if word in self.word2id:
                        self.stop_word_ids.add(self.word2id[word])

    def load_sentece_pairs(self, data_file_list, bad_words, data_type_list,
                           max_lines=20000):
        """Load labelled sentence pairs and build all pair and token indexes.

        Each line is tab-separated as label<TAB>sentence<TAB>sentence. For
        ``data_type == "mrpc"`` the two sentences are in columns 3 and 4.
        Sentences are whitespace-tokenized. Pairs containing any of
        ``bad_words`` are skipped, as are sentences whose id sequence is
        longer than ``self.length`` or shorter than 3 ids.

        After loading, ``neg_tuples``, ``para_tuples`` and ``id2sent`` are
        converted to numpy arrays and length statistics are printed.

        Args:
            data_file_list: paths of the UTF-8 data files.
            bad_words: iterable of tokens that disqualify a pair.
            data_type_list: one data-type string per file (e.g. "mrpc").
            max_lines: reading a file stops before line number ``max_lines``
                (so at most ``max_lines - 1`` lines per file). None disables
                the cap. The default keeps the previous hard-coded limit.
        """
        s_len = []
        for data_file, data_type in zip(data_file_list, data_type_list):
            with open(data_file, "rt", encoding="utf-8") as f:
                for count, line in enumerate(f, start=1):
                    if max_lines is not None and count >= max_lines:
                        break
                    self._add_pair(line, data_type, bad_words, s_len)

        self.neg_tuples = np.array(self.neg_tuples)
        self.para_tuples = np.array(self.para_tuples)
        self.id2sent = np.array(self.id2sent)
        if s_len:
            lengths = np.array(s_len)
            print("s length", np.min(lengths), np.max(lengths),
                  np.mean(lengths), np.median(lengths))
        else:
            # np.min on an empty array raises, so report explicitly instead.
            print("s length: no sentences were read")

    def _add_pair(self, line, data_type, bad_words, s_len):
        """Parse one data line and register its sentences, tuples and tokens."""
        cols = line.strip().split('\t')
        label = cols[0]
        if data_type == "mrpc":
            s1, s2 = cols[3].split(), cols[4].split()
        else:
            s1, s2 = cols[1].split(), cols[2].split()

        if any(w in s1 or w in s2 for w in bad_words):
            return

        s1_tokenid = self.token_batcher.batch_sentences([s1])[0]
        s2_tokenid = self.token_batcher.batch_sentences([s2])[0]
        # Statistics include sentences rejected below.
        s_len.append(len(s1_tokenid))
        s_len.append(len(s2_tokenid))
        # Too-long sentences are dropped rather than truncated, so only
        # padding is needed afterwards.
        if len(s1_tokenid) > self.length or len(s1_tokenid) < 3:
            print(s1_tokenid, s1)
            return
        if len(s2_tokenid) > self.length or len(s2_tokenid) < 3:
            print(s2_tokenid, s2)
            return

        s1_tokenid = self._pad(s1_tokenid)
        s2_tokenid = self._pad(s2_tokenid)
        s1_id = self._sentence_id(s1_tokenid)
        s2_id = self._sentence_id(s2_tokenid)

        # Update paraphrases, para_tuples and neg_tuples.
        overlap_index_pairs, synonym_index_pairs = self.overlap(
            s1_tokenid, s2_tokenid)
        total_index_pairs = overlap_index_pairs + synonym_index_pairs
        if label == "1":
            self.paraphrases.add((s1_id, s2_id))
            self.paraphrases.add((s2_id, s1_id))
            for p in total_index_pairs:
                self.para_tuples.append((s1_id, s2_id, p[0], p[1]))
        else:
            for p in total_index_pairs:
                self.neg_tuples.append((s1_id, s2_id, p[0], p[1]))
                w1 = s1_tokenid[p[0]]
                w2 = s2_tokenid[p[1]]
                if w1 in self.stop_word_ids or w2 in self.stop_word_ids:
                    continue
                self.token_pair2neg_tuples.setdefault(
                    (w1, w2), set()).add(len(self.neg_tuples) - 1)

        self._index_tokens(s1_id, s1_tokenid)
        self._index_tokens(s2_id, s2_tokenid)

    def _pad(self, token_ids):
        """Zero-pad a 1-D token-id sequence on the right to ``self.length``."""
        return np.pad(token_ids, (0, self.length - len(token_ids)),
                      'constant', constant_values=0)

    def _sentence_id(self, token_ids):
        """Return the id of a padded sentence, registering it if unseen."""
        key = tuple(token_ids)
        sent_id = self.sent2id.get(key)
        if sent_id is None:
            self.id2sent.append(token_ids)
            sent_id = len(self.id2sent) - 1
            self.sent2id[key] = sent_id
        return sent_id

    def _index_tokens(self, sent_id, token_ids):
        """Add (sent_id, position) to token2sents for every token except ids 1 and 2."""
        for index, token_id in enumerate(token_ids):
            if token_id == 2 or token_id == 1:
                continue
            self.token2sents.setdefault(token_id, set()).add((sent_id, index))

    def overlap(self, s1, s2):
        """Find aligned token positions shared (or synonymous) between two sentences.

        Returns:
            (word_pairs, synonym_pairs). ``word_pairs`` is a list of
            ``[index_in_s1, index_in_s2]`` for token ids present in both
            sentences. It excludes ids 0, 1, 2, stop words, and tokens whose
            string is all digits or starts with '-'. For repeated tokens, the
            last occurrence is used. ``synonym_pairs`` is a de-duplicated
            list of ``(index_in_s1, index_in_s2)`` tuples for synonym tokens,
            and is empty unless ``use_synonym`` is set.
        """
        s1_dict = {k: i for i, k in enumerate(s1)}
        s2_dict = {k: i for i, k in enumerate(s2)}
        inter = set(s1_dict).intersection(set(s2_dict))
        inter.difference_update((0, 1, 2))
        inter.difference_update(self.stop_word_ids)

        word_pairs = []
        for w in inter:
            word = self.id2word[w]
            # Skip numbers and tokens starting with '-'.
            if word.isdigit() or word.startswith('-'):
                continue
            word_pairs.append([s1_dict[w], s2_dict[w]])

        synonym_pairs = []
        if self.use_synonym:
            for token_id, i1 in s1_dict.items():
                for syn in self.synonyms.get(token_id, ()):
                    if syn in s2_dict:
                        synonym_pairs.append((i1, s2_dict[syn]))
            for token_id, i2 in s2_dict.items():
                for syn in self.synonyms.get(token_id, ()):
                    if syn in s1_dict:
                        synonym_pairs.append((s1_dict[syn], i2))
        synonym_pairs = list(set(synonym_pairs))
        return word_pairs, synonym_pairs

    def _random_neg(self):
        """Return a random element of ``neg_tuples``, or None if it is empty."""
        if len(self.neg_tuples) == 0:
            return None
        return random.choice(self.neg_tuples)

    def _corrupt_side(self, anchor_sent, anchor_index, other_sent, other_index):
        """Pair the anchor token with the same token in a non-paraphrase sentence.

        Returns ``(corrupt_sent, anchor_sent, corrupt_index, anchor_index)``.
        If no candidate exists, or only paraphrases of the anchor are found
        within 11 draws, a random negative tuple is returned instead (None if
        there are no negative tuples).
        """
        token = self.id2sent[anchor_sent][anchor_index]
        # Work on a copy. The old code removed entries from the set stored in
        # self.token2sents, permanently shrinking the index on every call.
        candidates = set(self.token2sents.get(token, ()))
        candidates.discard((anchor_sent, anchor_index))
        candidates.discard((other_sent, other_index))
        if not candidates:
            return self._random_neg()
        candidates = list(candidates)
        corrupt_s = random.choice(candidates)
        attempts = 0
        while self.is_paraphrase(corrupt_s[0], anchor_sent):
            corrupt_s = random.choice(candidates)
            attempts += 1
            if attempts > 10:
                # The old code computed this fallback but discarded it,
                # returning a paraphrase as a "negative" sample.
                return self._random_neg()
        return (corrupt_s[0], anchor_sent, corrupt_s[1], anchor_index)

    def corrupt(self, para_tuple, tar=None):
        """Corrupt a paraphrase tuple into a negative sample.

        Args:
            para_tuple: ``(sent_id, sent_id, token_index, token_index)``.
            tar: 0 keeps the first sentence/token, 1 keeps the second one.
                None picks one at random.

        Returns:
            A 4-tuple in the same layout with the kept sentence in the second
            slot, or a random entry of ``neg_tuples`` as a fallback. Returns
            None when ``tar`` is not 0 or 1, or when no fallback exists.
        """
        if tar is None:
            tar = random.randint(0, 1)
        s1, s2 = para_tuple[0], para_tuple[1]
        s1_index, s2_index = para_tuple[2], para_tuple[3]
        if tar == 0:
            return self._corrupt_side(s1, s1_index, s2, s2_index)
        if tar == 1:
            return self._corrupt_side(s2, s2_index, s1, s1_index)
        return None

    def neg(self, para_tuple):
        """Return a random negative tuple with the same token pair, or None."""
        s1_token = self.id2sent[para_tuple[0]][para_tuple[2]]
        s2_token = self.id2sent[para_tuple[1]][para_tuple[3]]
        neg_ids = self.token_pair2neg_tuples.get((s1_token, s2_token))
        if neg_ids is None:
            return None
        return self.neg_tuples[random.choice(list(neg_ids))]

    def corrupt_n(self, para_tuple, n=2):
        """Call :meth:`corrupt` ``n`` times, re-seeding the RNG before each call.

        Returns a list of ``n`` negative samples, or None if any corruption
        produced nothing.
        """
        corrupt_tuples = []
        for _ in range(n):
            # Re-seed from system entropy/time. random.seed(datetime.now())
            # raises TypeError on Python 3.11+.
            random.seed()
            corrupt_tuple = self.corrupt(para_tuple)
            # `is None`: the result may be a numpy row, whose truth value is
            # ambiguous.
            if corrupt_tuple is None:
                return None
            corrupt_tuples.append(corrupt_tuple)
        return corrupt_tuples

    def is_synonym(self, token_id1, token_id2):
        """Return True if token_id1 is a registered synonym of token_id2."""
        # self.synonyms is a dict; the old code called it like a function.
        return token_id1 in self.synonyms.get(token_id2, ())

    def is_paraphrase(self, sent_id1, sent_id2):
        """Return True if (sent_id1, sent_id2) was loaded as a paraphrase pair."""
        return (sent_id1, sent_id2) in self.paraphrases

    def save(self, filename):
        """Pickle all instance attributes to ``filename``."""
        with open(filename, 'wb') as f:
            pickle.dump(self.__dict__, f, pickle.HIGHEST_PROTOCOL)
        print("Save data object as", filename)

    def load(self, filename):
        """Restore instance attributes pickled by :meth:`save`.

        Security: pickle can execute arbitrary code. Only load trusted files.
        """
        with open(filename, 'rb') as f:
            tmp_dict = pickle.load(f)
        self.__dict__.update(tmp_dict)
        print("Loaded data object from", filename)
        print(
            "===============\nCaution: need to reload desc embeddings.\n====================="
        )
Beispiel #24
0
def elmo(reviews):
    """Run the weighted ELMo op on a batch of tokenized reviews.

    Reads the module-level ``batcher``, ``sess``, ``inputData`` and
    ``elmoInput`` set up in the ``__main__`` section.
    Returns the list produced by ``sess.run`` (one entry: the weighted op).
    """
    feed = {inputData: batcher.batch_sentences(reviews)}
    return sess.run([elmoInput['weighted_op']], feed_dict=feed)

if __name__=='__main__':
    # `sess`, `inputData`, `elmoInput` and `batcher` are deliberately
    # module-level names, because elmo() above reads them as globals.
    sess=tf.Session()
    with sess.as_default():
        # NOTE(review): ELMo, config, epochs and get_batch are not defined in
        # the visible code; they must come from elsewhere in the module.
        model=ELMo(flags=config)

        # Token-input biLM built from the paths in config.
        abilm=BidirectionalLanguageModel(config.option_file,config.weight_file,use_character_inputs=False,embedding_weight_file=config.tokenEmbeddingFile)
        inputData=tf.placeholder('int32',shape=(None,None))
        inputEmbeddingsOp=abilm(inputData)
        # Weighted ELMo combination with no L2 penalty on the layer weights.
        elmoInput=weight_layers('input',inputEmbeddingsOp,l2_coef=0.0)
        batcher=TokenBatcher(config.vocab_file)
        # Initialise every variable first. A checkpoint restore below then
        # overwrites the variables covered by model.saver.
        sess.run(tf.global_variables_initializer())

        # Resume from the latest checkpoint in config.model_dir if one exists.
        ckpt = tf.train.get_checkpoint_state(config.model_dir)
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            print('Reloading model parameters..')
            model.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print('Created new model parameters..')
            # Re-initialises all variables again (already done above).
            sess.run(tf.global_variables_initializer())
    
        # NOTE(review): current_step is not used in the visible loop.
        current_step = 0
        for e in range(epochs):
            print("----- Epoch {}/{} -----".format(e + 1, epochs))
            for batch in get_batch():
                # loss/acc are computed per batch but not reported here.
                loss,acc=model.train(sess, batch)
Beispiel #25
0
# Batch size used when dumping the token embeddings.
batchsize = 64

# Dump the token embeddings to token_embedding_file (run once per vocabulary).
# NOTE(review): `gpu` is not defined in the visible code; it must be set
# before this point.
dump_token_embeddings(
    vocab_file, options_file, weight_file, token_embedding_file,
    gpu=gpu, batchsize=batchsize
)

###########################################
# The bare string literal below is a no-op note kept from the original example.
"""
Differences from usage of character-elmo are only simple two points:
1. use TokenBatcher(vocab_file) instead of Batcher(vocab_file)
2. add token_embedding_file and token batcher for Elmo instantiation
"""

# Create a TokenBatcher to map text to token ids.
batcher = TokenBatcher(vocab_file)  # REQUIRED

# Build the Elmo with biLM and weight layers.
# NOTE(review): this rebinds the module-level name `elmo`, shadowing the
# elmo(reviews) function defined earlier in this file.
elmo = Elmo(
    options_file,
    weight_file,
    token_embedding_file=token_embedding_file,  # REQUIRED
    token_batcher=batcher,  # REQUIRED
    num_output_representations=1,
    requires_grad=False,
    do_layer_norm=False,
    dropout=0.)

# Create batches of data.
# add_bos_eos=False: presumably the Elmo module adds <S>/</S> itself (confirm).
context_token_ids = batcher.batch_sentences(
    tokenized_context, add_bos_eos=False)
import numpy as np

# NOTE(review): hard-coded, machine-specific absolute path. All relative
# file names below are resolved against it.
os.chdir('/local/datdb/MuhaoChelseaBiLmEncoder/model')
vocab_file = 'vocab.txt'
options_file = 'behm_32skip_2l.ckpt/options.json'
weight_file = 'behm_32skip_2l.hdf5'
token_embedding_file = 'vocab_embedding_32skip_2l.hdf5'

## COMMENT need to load own sequences
# Placeholder amino-acid sequences, one token per residue letter.
sequences = [['A', 'K', 'J', 'T', 'C', 'N'], ['C', 'A', 'D', 'A', 'A']]

## Serving contextualized embeddings of amino acids ================================

## Now we can do inference.
# Create a TokenBatcher to map text to token ids.
batcher = TokenBatcher(vocab_file)

# Input placeholders to the biLM.
context_token_ids = tf.placeholder('int32', shape=(None, None))

# Build the biLM graph (token inputs with pre-computed token embeddings).
bilm = BidirectionalLanguageModel(options_file,
                                  weight_file,
                                  use_character_inputs=False,
                                  embedding_weight_file=token_embedding_file)

# Get ops to compute the LM embeddings.
context_embeddings_op = bilm(context_token_ids)
elmo_context_top = weight_layers('output_top_only',
                                 context_embeddings_op,
Beispiel #27
0
class Tokenizer(object):
    """Split text into tokens and map them to ELMo token or character ids."""

    def __init__(self,
                 vocab_file,
                 max_seq_length,
                 max_token_length=None,
                 stroke_vocab_file=None,
                 tran2sim=False,
                 sim2tran=False):
        """
        Args:
            vocab_file: vocabulary file for the TokenBatcher (and Batcher).
            max_seq_length: sequence length including <bos> and <eos>.
            max_token_length: characters per token for the character Batcher.
                When falsy, no Batcher is built and
                convert_tokens_to_char_ids is unavailable.
            stroke_vocab_file: forwarded to the Batcher.
            tran2sim: convert Traditional to Simplified Chinese in convert().
            sim2tran: convert Simplified to Traditional Chinese in convert().

        Raises:
            ValueError: if both tran2sim and sim2tran are set.
        """
        # Validate first so a bad configuration fails fast. The previous
        # `assert` always failed here and was silently skipped under -O.
        if tran2sim and sim2tran:
            raise ValueError("tran2sim and sim2tran are mutually exclusive")

        self.vocab_file = vocab_file
        self.max_seq_length = max_seq_length
        self.max_token_length = max_token_length

        # <bos> and <eos> will be added, so reserve 2 positions for them.
        inner_seq_length = self.max_seq_length - 2
        self.token_batcher = TokenBatcher(self.vocab_file, inner_seq_length)
        self.batcher = None
        if max_token_length:
            self.batcher = Batcher(self.vocab_file, self.max_token_length,
                                   inner_seq_length, stroke_vocab_file)

        if tran2sim:
            self.convert_config = "t2s.json"
        elif sim2tran:
            self.convert_config = "s2t.json"
        else:
            self.convert_config = None

    def convert(self, text):
        """Optionally convert between Traditional and Simplified Chinese.

        Returns ``text`` unchanged when no conversion is configured, otherwise
        the OpenCC-converted text. Very slow; not recommended.
        """
        if self.convert_config is None:
            return text
        return opencc.convert(text, config=self.convert_config)

    def tokenize(self, text):
        """Convert text to tokens, splitting off punctuation.

        Example:
            text = 'Pretrained biLMs compute representations useful for NLP tasks.'
            tokens = ['Pretrained', 'biLMs', 'compute', 'representations',
                      'useful', 'for', 'NLP', 'tasks', '.']
        """
        text = self.convert(text)
        text = tokenize_chinese_chars(text)
        text = text.strip()
        tokens = []
        for word in text.split():
            tokens.extend(self._run_split_on_punc(word))
        return tokens

    def convert_tokens_to_ids(self, tokens):
        """Return the token-id row for a single tokenized sentence."""
        return self.token_batcher.batch_sentences([tokens])[0]

    def convert_tokens_to_char_ids(self, tokens):
        """Return flattened character ids for a single tokenized sentence.

        Args:
            tokens: output of :meth:`tokenize`.

        Returns:
            A flat list of length max_seq_length * max_token_length.

        Raises:
            AttributeError: if the Tokenizer was built without
                max_token_length (no character Batcher exists).
        """
        batcher = getattr(self, 'batcher', None)
        if batcher is None:
            raise AttributeError(
                "character ids need a Batcher; construct the Tokenizer "
                "with max_token_length")
        # char_ids: [max_seq_length, max_token_length]
        char_ids = batcher.batch_sentences([tokens])[0]
        return [char_id for sublist in char_ids for char_id in sublist]

    def _is_punctuation(self, char):
        """Checks whether `char` is a punctuation character."""
        cp = ord(char)
        # We treat all non-letter/number ASCII as punctuation.
        # Characters such as "^", "$", and "`" are not in the Unicode
        # Punctuation class but we treat them as punctuation anyways, for
        # consistency.
        if (33 <= cp <= 47 or 58 <= cp <= 64 or 91 <= cp <= 96
                or 123 <= cp <= 126):
            return True
        return unicodedata.category(char).startswith("P")

    def _run_split_on_punc(self, text):
        """Split text so each punctuation character becomes its own token."""
        output = []
        start_new_word = True
        for char in text:
            if self._is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
        return ["".join(x) for x in output]
Beispiel #28
0
    def __init__(self,
                 elmo_vocab_file,
                 elmo_weight_file,
                 elmo_option_file,
                 use_character_elmo,
                 use_concat_p,
                 question_window,
                 utterance_cache_file='',
                 passage_cache_file='',
                 question_cache_file=''):
        """Load ELMo options, build the batcher and biLM, and open cache files.

        If any of the weight, option or vocab files is missing, only the cache
        flags are initialised and a warning is logged. No batcher or biLM is
        built in that case.

        Args:
            elmo_vocab_file: vocabulary file used by the (Token)Batcher.
            elmo_weight_file: biLM weight file.
            elmo_option_file: JSON options file of the biLM.
            use_character_elmo: build a character Batcher if true, otherwise
                a TokenBatcher.
            use_concat_p: whether a passage cache is needed.
            question_window: a question cache is needed when this is > 1.
            utterance_cache_file, passage_cache_file, question_cache_file:
                optional cache paths ('' means no cache file).
        """
        self.logger = logging.getLogger("dial")
        self.utterance_cache = None
        self.passage_cache = None
        self.question_cache = None
        self.need_q_cache = (question_window > 1)
        self.need_p_cache = use_concat_p
        if os.path.exists(elmo_weight_file) and os.path.exists(
                elmo_option_file) and os.path.exists(elmo_vocab_file):
            # the vocab file exported from the corpus
            self.elmo_vocab_file = elmo_vocab_file
            # elmo weight file
            self.elmo_weight_file = elmo_weight_file
            # elmo option file
            self.elmo_option_file = elmo_option_file
            self.utterance_cache_file = utterance_cache_file
            self.passage_cache_file = passage_cache_file
            self.question_cache_file = question_cache_file
            self.use_character_elmo = use_character_elmo
            with open(self.elmo_option_file, 'r') as fin:
                options = json.load(fin)
            # Output layers = LSTM layers + 1 (the token layer). The output
            # dimension is twice the projection dimension (both directions).
            self.output_layers = options['lstm']['n_layers'] + 1
            self.output_dim = 2 * options['lstm']['projection_dim']
            self.logger.info("output_layers :{}, output_dim :{}".format(
                self.output_layers, self.output_dim))
            if self.use_character_elmo:
                # max_num_char for characters for a token.
                self.elmo_max_num_char = options['char_cnn'][
                    'max_characters_per_token']
                # line 207 https://github.com/allenai/bilm-tf/blob/ebf52c6ec1012a3672247c2d14ff7bcad7fb812b/bilm/data.py
                # the mask for char id is 0
                self.PAD_TOKEN_CHAR_IDS = np.zeros((self.elmo_max_num_char),
                                                   dtype=np.int32).tolist()
                # Use subword characters first, which shows extra improvements
                # beside the contextual information.
                self.elmo_char_batcher = Batcher(self.elmo_vocab_file,
                                                 self.elmo_max_num_char)
                # language model with the default (character) inputs
                self.elmo_bilm = BidirectionalLanguageModel(
                    self.elmo_option_file, self.elmo_weight_file)
            else:
                # Token-level ELMo: map tokens to ids with a TokenBatcher.
                self.elmo_token_batcher = TokenBatcher(self.elmo_vocab_file)
                # NOTE(review): the biLM here is built exactly like the
                # character branch (no use_character_inputs=False and no
                # embedding_weight_file), although the original comment said
                # it should use use_character_inputs=False. Every other
                # token-level example in this file passes both arguments.
                # Confirm which input mode these weights expect.
                self.elmo_bilm = BidirectionalLanguageModel(
                    self.elmo_option_file, self.elmo_weight_file)

            self.chk_load_utterance_cache()
            self.chk_load_passage_cache()
            self.chk_load_question_cache()
        else:
            # Logger.warn is a deprecated alias of Logger.warning.
            self.logger.warning(
                "elmo_weight_file = {}, elmo_option_file={}, elmo_vocab_file={}"
                .format(elmo_weight_file, elmo_option_file, elmo_vocab_file))
Beispiel #29
0
class ELMo_Utils(object):
    """
    Impements Elmo functions used by downstream task
    Each tokenized sentence is a list of str, with a batch of sentences a list of tokenized sentences (List[List[str]]).

The Batcher packs these into a shape (n_sentences, max_sentence_length + 2, 50) numpy array of character ids, padding on the right with 0 ids for sentences less then the maximum length. The first and last tokens for each sentence are special begin and end of sentence ids added by the Batcher.

The input character id placeholder can be dimensioned (None, None, 50), with both the batch dimension (axis=0) and time dimension (axis=1) determined for each batch, up the the maximum batch size specified in the BidirectionalLanguageModel constructor.

After running inference with the batch, the return biLM embeddings are a numpy array with shape (n_sentences, 3, max_sentence_length, 1024), after removing the special begin/end tokens.
    """

    START_TOKEN = '<S>'
    END_TOKEN = '</S>'
    UNK_TOKEN = '<UNK>'
    PAD_SNT = '<S></S>'
    PAD_SNT_ID = 0

    def __init__(self,
                 elmo_vocab_file,
                 elmo_weight_file,
                 elmo_option_file,
                 use_character_elmo,
                 use_concat_p,
                 question_window,
                 utterance_cache_file='',
                 passage_cache_file='',
                 question_cache_file=''):
        self.logger = logging.getLogger("dial")
        self.utterance_cache = None
        self.passage_cache = None
        self.question_cache = None
        self.need_q_cache = (question_window > 1)
        self.need_p_cache = use_concat_p
        if os.path.exists(elmo_weight_file) and os.path.exists(
                elmo_option_file) and os.path.exists(elmo_vocab_file):
            # the vocab file exported from the corpus
            self.elmo_vocab_file = elmo_vocab_file
            # elmo weight file
            self.elmo_weight_file = elmo_weight_file
            # elmo option file
            self.elmo_option_file = elmo_option_file
            self.utterance_cache_file = utterance_cache_file
            self.passage_cache_file = passage_cache_file
            self.question_cache_file = question_cache_file
            self.use_character_elmo = use_character_elmo
            with open(self.elmo_option_file, 'r') as fin:
                options = json.load(fin)
            self.output_layers = options['lstm']['n_layers'] + 1
            self.output_dim = 2 * options['lstm']['projection_dim']
            self.logger.info("output_layers :{}, output_dim :{}".format(
                self.output_layers, self.output_dim))
            # by default, the bilm use the character_elmo
            if self.use_character_elmo:
                # max_num_char for characters for a token.
                self.elmo_max_num_char = options['char_cnn'][
                    'max_characters_per_token']
                # line 207 https://github.com/allenai/bilm-tf/blob/ebf52c6ec1012a3672247c2d14ff7bcad7fb812b/bilm/data.py
                # the mask for char id is 0
                self.PAD_TOKEN_CHAR_IDS = np.zeros((self.elmo_max_num_char),
                                                   dtype=np.int32).tolist()
                # use subword character first, which shows extra improvements beside the contextual information.
                self.elmo_char_batcher = Batcher(self.elmo_vocab_file,
                                                 self.elmo_max_num_char)
                # language mode with use_character_inputs = True
                self.elmo_bilm = BidirectionalLanguageModel(
                    self.elmo_option_file, self.elmo_weight_file)
            else:
                # use token batcher
                self.elmo_token_batcher = TokenBatcher(self.elmo_vocab_file)
                # use elmo_bilm with use_character_inputs = False
                self.elmo_bilm = BidirectionalLanguageModel(
                    self.elmo_option_file, self.elmo_weight_file)

            self.chk_load_utterance_cache()
            self.chk_load_passage_cache()
            self.chk_load_question_cache()
        else:
            self.logger.warn(
                "elmo_weight_file = {}, elmo_option_file={}, elmo_vocab_file={}"
                .format(elmo_weight_file, elmo_option_file, elmo_vocab_file))

    def chk_load_utterance_cache(self):
        if self.utterance_cache_file and os.path.exists(
                self.utterance_cache_file):
            self.utterance_cache = h5py.File(self.utterance_cache_file, 'r')
            #self.utterance_cache_in_mem = {}
            #self.utterance_cache_in_mem['lm_embeddings'] = self.load_h5(self.utterance_cache['lm_embeddings'])
            #self.utterance_cache_in_mem['lengths'] = self.load_h5_lengths(self.utterance_cache['lengths'])
            #self.utterance_cache_in_mem['mask'] = self.load_h5(self.utterance_cache['mask'])
            self.logger.info(
                "Utterance cache loaded from {}, size = {}".format(
                    self.utterance_cache_file,
                    len(self.utterance_cache['lm_embeddings'].keys())))
        else:
            self.utterance_cache = None

    def load_h5(self, h5group):
        x = []
        for index in range(len(h5group.keys())):
            # https://stackoverflow.com/questions/10274476/how-to-export-hdf5-file-to-numpy-using-h5py
            x.append(h5group['{}'.format(index)][...].tolist())
        return x

    def load_h5_lengths(self, h5group):
        x = []
        for index in range(len(h5group.keys())):
            x.extend(h5group['{}'.format(index)][...].tolist())
        return x

    def chk_load_passage_cache(self):
        if self.need_p_cache:
            if self.passage_cache_file and os.path.exists(
                    self.passage_cache_file):
                self.passage_cache = h5py.File(self.passage_cache_file, 'r')
                self.logger.info("Passage cache loaded from {}".format(
                    self.passage_cache_file))
            else:
                self.passage_cache = None
                self.logger.info(
                    "Passage cache needed from {}, it will build soon.".format(
                        self.passage_cache_file))
        else:
            self.passage_cache = None
            self.logger.info("Passage cache not needed")

    def chk_load_question_cache(self):
        if self.need_q_cache:
            if self.question_cache_file and os.path.exists(
                    self.question_cache_file):
                self.question_cache = h5py.File(self.question_cache_file, 'r')
                self.logger.info("Question cache loaded from {}".format(
                    self.question_cache_file))
            else:
                self.question_cache = None
                self.logger.info(
                    "Question cache needed from {}, it will build soon.".
                    format(self.question_cache_file))
        else:
            self.question_cache = None
            self.logger.info("Question cache not needed")

    def need_build_passage_cache(self):
        return self.need_p_cache and self.passage_cache_file != '' and self.passage_cache == None

    def need_build_question_cache(self):
        return self.need_q_cache and self.question_cache_file != '' and self.question_cache == None

    def cleanup(self):
        """
        Close every open ELMo cache (utterance, passage, question), in that
        order, skipping any that are falsy (e.g. None), then log completion.
        """
        for cache_attr in ('utterance_cache', 'passage_cache',
                           'question_cache'):
            # Look each attribute up lazily so the access order matches a
            # one-after-the-other close sequence.
            cache = getattr(self, cache_attr)
            if cache:
                cache.close()
        self.logger.info("Clean up elmo cahce")

    def get_elmo_char_ids(self, sentences):
        '''
        Convert tokenized sentences into ELMo character ids.

        Arguments:
            sentences: List[List[str]], each already wrapped with the start
                and end tokens.

        Return: nested Python lists produced by
            ``self.elmo_char_batcher.batch_sentences(sentences).tolist()``,
            i.e. [sentence_num, token_num, max_char_num].
        '''
        char_id_array = self.elmo_char_batcher.batch_sentences(sentences)
        return char_id_array.tolist()

    def get_elmo_token_ids(self, sentences):
        '''
        Convert tokenized sentences into ELMo token ids.

        Arguments:
            sentences: List[List[str]], WITHOUT start and end tokens.

        Return: nested Python lists produced by
            ``self.elmo_token_batcher.batch_sentences(sentences).tolist()``
            (one row of token ids per sentence).
        '''
        token_id_array = self.elmo_token_batcher.batch_sentences(sentences)
        return token_id_array.tolist()

    def get_elmo_emb_op(self, input_ids_place_holder):
        '''
        Build the biLM ops for the given input-id placeholder.

        Delegates to ``self.elmo_bilm`` and returns its result, a dict of the form
        {
         'lm_embeddings': embedding_op, (None, 3, None, 1024)
         'lengths': sequence_lengths_op, (None, )
         'mask': op to compute mask (None, None)
        }
        '''
        bilm = self.elmo_bilm
        return bilm(input_ids_place_holder)

    def weight_layers(self,
                      name,
                      bilm_ops,
                      l2_coef=None,
                      use_top_only=False,
                      do_layer_norm=False):
        '''
        Combine the biLM layers into an ELMo representation with trainable
        scalar weights, by delegating to the module-level ``weight_layers``.

        See https://github.com/allenai/bilm-tf/blob/81a4b54937f4dfb93308f709c1cf34dbb37c553e/bilm/elmo.py
        Return: dict
        {
           'weighted_op': op to compute weighted average for output,
           'regularization_op': op to compute regularization term
        }
        '''
        # Passed positionally, in the same order as the helper's signature.
        layer_args = (name, bilm_ops, l2_coef, use_top_only, do_layer_norm)
        return weight_layers(*layer_args)

    @staticmethod
    def prepare_elmo_vocab_file(vocab, elmo_vocab_file):
        """
        Write an ELMo vocabulary file, one token per line.

        The file starts with ELMo_Utils.START_TOKEN, END_TOKEN and UNK_TOKEN.
        After them come the keys of ``vocab.token_cnt``, ordered by descending
        count. Ties keep the dict's iteration order because ``sorted`` is stable.
        """
        counts = vocab.token_cnt
        tokens_by_frequency = sorted(counts, key=counts.get, reverse=True)
        with open(elmo_vocab_file, 'w') as f:
            special_tokens = (ELMo_Utils.START_TOKEN, ELMo_Utils.END_TOKEN,
                              ELMo_Utils.UNK_TOKEN)
            for special in special_tokens:
                f.write('{}\n'.format(special))
            for token in tokens_by_frequency:
                f.write('%s\n' % token)

    def build_elmo_char_cache(self, snt_dict_file, max_snt_length,
                              output_cache_file):
        """
        Run the character-level biLM over every sentence of ``snt_dict_file``
        and save the outputs into the HDF5 file ``output_cache_file``.

        Each line of ``snt_dict_file`` is one whitespace-tokenized sentence.
        Sentences are processed in batches of 10 by ``consume_batch_snts``,
        which stores them under the sentence's 0-based line index in the
        ``lm_embeddings``, ``lengths`` and ``mask`` groups.
        """
        self.logger.info(
            'Prepare ELMo character embeddings for {} with ELMo_Utils ...'.
            format(snt_dict_file))
        char_ids_ph = tf.placeholder('int32',
                                     shape=(None, max_snt_length,
                                            self.elmo_max_num_char))
        bilm_ops = self.elmo_bilm(char_ids_ph)
        sess_config = tf.ConfigProto(allow_soft_placement=True)
        SNT_BATCH_SIZE = 10
        with tf.Session(config=sess_config) as sess:
            sess.run(tf.global_variables_initializer())
            with open(snt_dict_file, 'r') as fin, \
                    h5py.File(output_cache_file, 'w') as fout:
                lm_group = fout.create_group('lm_embeddings')
                len_group = fout.create_group('lengths')
                mask_group = fout.create_group('mask')
                num_written = 0
                pending = []
                for line in tqdm(fin, total=get_num_lines(snt_dict_file)):
                    pending.append(line.strip().split())
                    if len(pending) < SNT_BATCH_SIZE:
                        continue
                    # A full batch: embed it and continue numbering after it.
                    num_written += self.consume_batch_snts(
                        sess, char_ids_ph, bilm_ops, pending, max_snt_length,
                        num_written, lm_group, len_group, mask_group)
                    pending = []
                if pending:
                    # Flush the trailing partial batch.
                    num_written += self.consume_batch_snts(
                        sess, char_ids_ph, bilm_ops, pending, max_snt_length,
                        num_written, lm_group, len_group, mask_group)
                    pending = []
                self.logger.info(
                    "Finished ELMo embeddings for {} senencesm in {}".format(
                        num_written, output_cache_file))

    def consume_batch_snts(self, sess, ids_placeholder, ops, batch_snts,
                           max_snt_length, start_snt_id_in_batch,
                           lm_embeddings_h5, lengths_h5, mask_h5):
        """
        Embed one batch of sentences and write the results into HDF5 groups.

        Character ids are padded with ``self.PAD_TOKEN_CHAR_IDS`` and truncated
        to ``max_snt_length`` rows, then fed to ``ops`` through ``sess``. Row
        ``k`` of the outputs is stored under key ``str(start_snt_id_in_batch + k)``
        in each of the three groups.

        Return: the number of sentences written (``len(batch_snts)``).
        """
        raw_char_ids = self.get_elmo_char_ids(batch_snts)
        pad_row = self.PAD_TOKEN_CHAR_IDS
        padded_char_ids = []
        for snt_char_ids in raw_char_ids:
            filler = [pad_row] * (max_snt_length - len(snt_char_ids))
            padded_char_ids.append((snt_char_ids + filler)[:max_snt_length])

        outputs = sess.run(ops, feed_dict={ids_placeholder: padded_char_ids})
        lm_embeddings = outputs['lm_embeddings']
        lengths = outputs['lengths']
        masks = outputs['mask']

        num_snts = len(batch_snts)
        for offset in range(num_snts):
            key = '{}'.format(start_snt_id_in_batch + offset)
            lm_embeddings_h5.create_dataset(key,
                                            lm_embeddings.shape[1:],
                                            dtype='float32',
                                            data=lm_embeddings[offset, :, :, :],
                                            compression="gzip")
            lengths_h5.create_dataset(key, (1, ),
                                      dtype='int32',
                                      data=lengths[offset])
            mask_h5.create_dataset(key,
                                   masks.shape[1:],
                                   dtype='int32',
                                   data=masks[offset],
                                   compression="gzip")
        return num_snts

    # TODO for token level embedding.
    def build_elmo_token_cache(self, snt_dict_file, max_snt_length,
                               output_cache_file):
        """
        Placeholder for building a token-level ELMo cache (not implemented).

        Nothing is written to ``output_cache_file``. A warning is logged so
        that the "Finished" message ``build_elmo_cache`` logs afterwards is
        not mistaken for a successfully built cache.
        """
        # NOTE(review): assumes self.logger is a logging.Logger (has .warning).
        self.logger.warning(
            "Token-level ELMo cache building is not implemented; nothing was "
            "written for {} -> {} (max_snt_length={})".format(
                snt_dict_file, output_cache_file, max_snt_length))

    def build_elmo_cache(self, snt_dict_file, max_snt_length,
                         output_cache_file):
        """
        Build the sentence-level ELMo cache with the character- or token-level
        builder, depending on ``self.use_character_elmo``, then log completion.
        """
        builder = (self.build_elmo_char_cache if self.use_character_elmo else
                   self.build_elmo_token_cache)
        builder(snt_dict_file, max_snt_length, output_cache_file)

        self.logger.info(
            'Finished ELMo embeddings for utterance cache with ELMo_Utils')

    def build_elmo_cache_for_samples(self, dataset, max_p_len, max_q_len):
        """
        Run the biLM over every sample of the train/dev/test splits and cache
        the per-sample outputs into the passage / question HDF5 files.

        Each sample is stored under ``str(sample['example-id'])`` in the
        ``lm_embeddings``, ``lengths`` and ``mask`` groups of
        ``self.passage_cache_file`` (when ``self.need_p_cache``) and of
        ``self.question_cache_file`` (when ``self.need_q_cache``).

        Arguments:
            dataset: object with ``gen_mini_batches(set_name, batch_size,
                shuffle=False)`` yielding batch dicts that contain
                ``'raw_data'`` plus the ELMo char-id lists ``run_pq_ops`` feeds.
            max_p_len: static passage length used for padding.
            max_q_len: static question length used for padding.
        """
        if (not self.need_p_cache) and (not self.need_q_cache):
            self.logger.info(
                'No need for ELMo embeddings for concated passage and question with ELMo_Utils'
            )
        else:
            # build graph for getting forward elmo embedding.
            self.logger.info('Build ELMo embeddings for p = {}, q = {}'.format(
                self.need_p_cache, self.need_q_cache))
            self.build_pq_elmo_graph()
            p_out = q_out = None
            p_groups = q_groups = None
            skipped = 0
            try:
                if self.need_p_cache:
                    p_out = h5py.File(self.passage_cache_file, 'w')
                    p_groups = (p_out.create_group('lm_embeddings'),
                                p_out.create_group('lengths'),
                                p_out.create_group('mask'))
                if self.need_q_cache:
                    q_out = h5py.File(self.question_cache_file, 'w')
                    q_groups = (q_out.create_group('lm_embeddings'),
                                q_out.create_group('lengths'),
                                q_out.create_group('mask'))

                config = tf.ConfigProto(allow_soft_placement=True)
                with tf.Session(config=config) as sess:
                    sess.run(tf.global_variables_initializer())
                    for set_name in ['train', 'dev', 'test']:
                        for batch_data in tqdm(
                                dataset.gen_mini_batches(set_name,
                                                         20,
                                                         shuffle=False)):
                            samples = batch_data['raw_data']
                            # run_pq_ops pads batch_data and stores the evaluated
                            # ops on self.p_ops / self.q_ops (not on locals).
                            self.run_pq_ops(sess, batch_data, max_p_len,
                                            max_q_len)
                            p_ops = (self._unwrap_fetched_elmo_ops(self.p_ops)
                                     if self.need_p_cache else None)
                            q_ops = (self._unwrap_fetched_elmo_ops(self.q_ops)
                                     if self.need_q_cache else None)
                            for i, sample in enumerate(samples):
                                e_id = '{}'.format(sample['example-id'])
                                try:
                                    if p_ops is not None:
                                        self._write_sample_elmo_ops(
                                            p_groups, e_id, p_ops, i)
                                    if q_ops is not None:
                                        self._write_sample_elmo_ops(
                                            q_groups, e_id, q_ops, i)
                                except (ValueError, RuntimeError):
                                    # h5py refuses to create a dataset whose name
                                    # already exists (presumably an example-id
                                    # repeated across splits); keep the first copy.
                                    skipped += 1
            finally:
                # Always flush and release the HDF5 files, even on failure.
                for h5_file in (p_out, q_out):
                    if h5_file is not None:
                        h5_file.close()
            if skipped:
                self.logger.info(
                    'Skipped {} samples whose ELMo cache entries could not be written'
                    .format(skipped))

        self.logger.info(
            'Finished ELMo embeddings for concated passage and question with ELMo_Utils'
        )

    @staticmethod
    def _unwrap_fetched_elmo_ops(ops):
        """Return the op dict, unwrapping the one-element list ``sess.run([op])`` yields."""
        if isinstance(ops, (list, tuple)):
            return ops[0]
        return ops

    @staticmethod
    def _write_sample_elmo_ops(h5_groups, e_id, ops, i):
        """Write row ``i`` of evaluated biLM ``ops`` under ``e_id`` into the
        ``(lm_embeddings, lengths, mask)`` HDF5 groups."""
        lm_embeddings_h5, lengths_h5, mask_h5 = h5_groups
        lm_embeddings_h5.create_dataset(e_id,
                                        ops['lm_embeddings'].shape[1:],
                                        dtype='float32',
                                        data=ops['lm_embeddings'][i, :, :, :],
                                        compression="gzip")
        lengths_h5.create_dataset(e_id, (1, ),
                                  dtype='int32',
                                  data=ops['lengths'][i])
        mask_h5.create_dataset(e_id,
                               ops['mask'].shape[1:],
                               dtype='int32',
                               data=ops['mask'][i, :],
                               compression="gzip")

    def run_pq_ops(self, sess, batch_data, max_p_len, max_q_len):
        """
        Statically pad the batch's ELMo char ids, then evaluate the biLM ops
        for the passage and/or question.

        The evaluated op dicts (``lm_embeddings``, ``lengths``, ``mask``) are
        stored on ``self.p_ops`` / ``self.q_ops``. Nothing is run when neither
        cache is needed.

        Arguments:
            sess: an open tf.Session.
            batch_data: batch dict holding ``elmo_passage_char_ids`` /
                ``elmo_question_char_ids``; padded in place.
            max_p_len: static passage length used for padding.
            max_q_len: static question length used for padding.
        """
        self._static_pq_padding(batch_data, max_p_len, max_q_len)

        if self.need_p_cache and self.need_q_cache:
            self.p_ops, self.q_ops = sess.run(
                [self.p_emb_elmo_op, self.q_emb_elmo_op],
                feed_dict={
                    self.elmo_p: batch_data['elmo_passage_char_ids'],
                    self.elmo_q: batch_data['elmo_question_char_ids']
                })
        elif self.need_p_cache:
            # Fetch the op dict itself rather than a one-element list, so
            # self.p_ops has the same structure as in the branch above.
            self.p_ops = sess.run(
                self.p_emb_elmo_op,
                feed_dict={self.elmo_p: batch_data['elmo_passage_char_ids']})
        elif self.need_q_cache:
            self.q_ops = sess.run(
                self.q_emb_elmo_op,
                feed_dict={
                    self.elmo_q: batch_data['elmo_question_char_ids'],
                })

    def build_pq_elmo_graph(self):
        """
        Build the biLM ops used to compute passage/question ELMo embeddings
        for the sample-level caches.

        This creates the char-id placeholders via ``add_elmo_placeholders``.
        Then, on ``/device:GPU:0`` inside a root variable scope with
        ``AUTO_REUSE``, it sets ``self.p_emb_elmo_op`` (when
        ``self.need_p_cache``) and ``self.q_emb_elmo_op`` (when
        ``self.need_q_cache``) to the op dicts returned by ``self.elmo_bilm``.
        For sample-level caching, the first dimension of every tensor must be
        batch_size.
        """
        start_t = time.time()
        self.logger.info(
            "Start building elmo graph for concatenated p and q ...")
        self.add_elmo_placeholders()
        with tf.device('/device:GPU:0'):
            with tf.variable_scope("", reuse=tf.AUTO_REUSE):
                # Each op dict holds (per get_elmo_emb_op's documentation):
                #   lm_embeddings : [batch_size, layers, max_length, hidden_dims * 2]
                #   lengths       : [batch_size]
                #   mask          : [batch_size, max_length]
                if self.need_p_cache:
                    self.p_emb_elmo_op = self.elmo_bilm(self.elmo_p)

                if self.need_q_cache:
                    # input: [batch_size, question_length, elmo_max_num_char]
                    self.q_emb_elmo_op = self.elmo_bilm(self.elmo_q)
        # Report how long graph construction took.
        self.logger.info("Finished building elmo graph in {:.2f}s".format(
            time.time() - start_t))

    def add_elmo_placeholders(self):
        """
        Create the int32 character-id placeholders fed to the biLM when
        building the sample-level caches.

        ``self.elmo_p`` ([batch_size, passage_length, elmo_max_num_char]) is
        only created when ``self.need_p_cache``. ``self.elmo_q``
        ([batch_size, question_length, elmo_max_num_char]) is always created.
        """
        char_ids_shape = [None, None, self.elmo_max_num_char]
        if self.need_p_cache:
            self.elmo_p = tf.placeholder(tf.int32, char_ids_shape, 'elmo_p')
        self.elmo_q = tf.placeholder(tf.int32, char_ids_shape, 'elmo_q')

    def _static_pq_padding(self, batch_data, max_p_len, max_q_len):
        """
        Pad the ELMo char-id lists in ``batch_data`` to static lengths.

        Static padding is needed when the contextual embeddings are saved with
        a mask of the whole static length. The char-id sequences produced by
        ``batch_sentences`` include the start and end tokens, so they are padded
        to ``max_len + 2``. The final embeddings keep their own lengths.
        The question ids are always padded; the passage ids are padded only
        when ``self.need_p_cache``.
        """
        num_boundary_tokens = 2
        padding(batch_data, 'elmo_question_char_ids',
                max_q_len + num_boundary_tokens, self.PAD_TOKEN_CHAR_IDS)
        if not self.need_p_cache:
            return
        padding(batch_data, 'elmo_passage_char_ids',
                max_p_len + num_boundary_tokens, self.PAD_TOKEN_CHAR_IDS)

    def _prepare_passage_elmo_feed_dict(self, sample, batch_data,
                                        context_window, token_key_to_use):
        """
        Add the passage ELMo inputs of one sample to ``batch_data``.

        The last ``context_window`` utterances of ``sample['messages-so-far']``
        are concatenated between a single START/END token pair. Slots past the
        available utterances are padding slots.

        - When the passage cache still has to be built
          (``need_build_passage_cache()``), the char ids of the concatenated
          passage are appended to ``batch_data['elmo_passage_char_ids']``.
        - Otherwise, the cached passage entry for the sample's example-id is
          assembled when ``self.need_p_cache``. Then one ``'pu'`` entry per
          collected sentence id is assembled from ``self.utterance_cache``.

        Arguments:
            sample: raw sample dict with ``'example-id'`` and ``'messages-so-far'``.
            batch_data: batch dict being built; mutated in place.
            context_window: number of trailing utterances to use.
            token_key_to_use: key of the token list inside each utterance.
        """
        e_id_str = '{}'.format(sample['example-id'])
        # NOTE(review): the two per-utterance lists below are filled but not
        # read anywhere in this method.
        passage_utterance_tokens_elmo = []
        passage_utterance_length_elmo = []
        passage_tokens_elmo = [ELMo_Utils.START_TOKEN]
        passage_snt_ids = []
        pruned_context_utterances_elmo = sample['messages-so-far'][
            -context_window:]
        for i in range(context_window):
            if i >= len(pruned_context_utterances_elmo):
                # Padding slot (placed after the real utterances): an empty
                # START/END utterance with the PAD sentence id.
                current_utterance_tokens_elmo = [
                    ELMo_Utils.START_TOKEN, ELMo_Utils.END_TOKEN
                ]
                passage_snt_ids.append(ELMo_Utils.PAD_SNT_ID)
                passage_utterance_tokens_elmo.append(
                    current_utterance_tokens_elmo)
                passage_utterance_length_elmo.append(
                    len(current_utterance_tokens_elmo))
            else:
                utterance = pruned_context_utterances_elmo[i]
                # NOTE(review): utterances without 'snt_id' add no entry, so
                # passage_snt_ids can be shorter than context_window.
                if 'snt_id' in utterance:
                    passage_snt_ids.append(utterance['snt_id'])
                # split version of passages
                current_utterance_tokens_elmo = [ELMo_Utils.START_TOKEN]
                current_utterance_tokens_elmo.extend(
                    utterance[token_key_to_use])
                current_utterance_tokens_elmo.extend([ELMo_Utils.END_TOKEN])
                passage_utterance_tokens_elmo.append(
                    current_utterance_tokens_elmo)
                passage_utterance_length_elmo.append(
                    len(current_utterance_tokens_elmo))
                # concatenated version of passages:
                # append this utterance's tokens (no per-utterance START/END)
                passage_tokens_elmo.extend(utterance[token_key_to_use])

        passage_tokens_elmo.extend([ELMo_Utils.END_TOKEN])
        if self.need_build_passage_cache():
            # Cache still being built: only collect the char ids of the
            # concatenated passage, one row per sample.
            # [batch_size, passage_length, max_char_num]
            batch_data['elmo_passage_char_ids'].append(
                self.get_elmo_char_ids([passage_tokens_elmo])[0])
        else:
            #TODO add passage and question elmo retrieve here.
            if self.need_p_cache:
                self.assemble_elmo_batch_data('p', batch_data, e_id_str,
                                              self.passage_cache)
            for snt_id in passage_snt_ids:
                # self.assemble_elmo_with_snt_ids('pu', batch_data, snt_id)
                # self.assemble_elmo_batch_data_with_mem('pu', batch_data, snt_id, self.utterance_cache_in_mem)
                self.assemble_elmo_batch_data('pu', batch_data, snt_id,
                                              self.utterance_cache)

    def _prepare_question_elmo_feed_dict(self, sample, batch_data,
                                         question_window, token_key_to_use):
        """
        Add the question ELMo inputs of one sample to ``batch_data``, in the
        same style as the regular question feed_dict.

        The last ``question_window`` utterances of ``sample['messages-so-far']``
        are concatenated between a single START/END token pair. Slots past the
        available utterances are padding slots.

        - With ``question_window == 0`` nothing is added.
        - When the question cache still has to be built, the char ids of the
          concatenated question are appended to
          ``batch_data['elmo_question_char_ids']``.
        - Otherwise, with ``question_window == 1`` the first sentence id is
          looked up in ``self.utterance_cache``. With larger windows, the
          sample's example-id is looked up in ``self.question_cache``.

        Arguments:
            sample: raw sample dict with ``'example-id'`` and ``'messages-so-far'``.
            batch_data: batch dict being built; mutated in place.
            question_window: number of trailing utterances to use.
            token_key_to_use: key of the token list inside each utterance.
        """
        e_id_str = '{}'.format(sample['example-id'])
        # for each utterance in question
        # NOTE(review): this list and the length list below are filled but
        # not read anywhere in this method.
        question_utterance_tokens_elmo = []
        # for the concatenated question
        # for question utterance length
        question_utterance_length_elmo = []
        question_snt_ids = []
        # Add the start token, which is also in the ELMo vocabulary.
        # The non-ELMo path adds self.vocab.sos / self.vocab.eos, which the
        # downstream LSTM encodes. ELMo uses its own (upper-case) start/end
        # tokens, so those must be used here to get an ELMo embedding for them.
        question_tokens_elmo = [ELMo_Utils.START_TOKEN]
        pruned_question_utterance_elmo = sample['messages-so-far'][
            -question_window:]
        for i in range(question_window):
            if i >= len(pruned_question_utterance_elmo):
                # Padding slot: an empty START/END utterance with the PAD sentence id.
                current_utterance_tokens_elmo = [
                    ELMo_Utils.START_TOKEN, ELMo_Utils.END_TOKEN
                ]
                question_snt_ids.append(ELMo_Utils.PAD_SNT_ID)
                question_utterance_tokens_elmo.append(
                    current_utterance_tokens_elmo)
                question_utterance_length_elmo.append(
                    len(current_utterance_tokens_elmo))
            else:
                utterance = pruned_question_utterance_elmo[i]
                # split version of question
                # NOTE(review): utterances without 'snt_id' add no entry, so
                # question_snt_ids may be empty when question_window == 1,
                # making question_snt_ids[0] below raise IndexError.
                if 'snt_id' in utterance:
                    question_snt_ids.append(utterance['snt_id'])
                current_utterance_tokens_elmo = [ELMo_Utils.START_TOKEN]
                current_utterance_tokens_elmo.extend(
                    utterance[token_key_to_use])
                current_utterance_tokens_elmo.extend([ELMo_Utils.END_TOKEN])
                # add each utterance token_ids into a parental list
                question_utterance_tokens_elmo.append(
                    current_utterance_tokens_elmo)
                question_utterance_length_elmo.append(
                    len(current_utterance_tokens_elmo))
                # concatenated version of question
                # append question utterance tokens
                question_tokens_elmo.extend(utterance[token_key_to_use])

        question_tokens_elmo.extend([ELMo_Utils.END_TOKEN])
        if question_window == 0:
            # Known limitation (author's note): with no question window nothing
            # is added for the question, which breaks downstream use; keep
            # question_window >= 1.
            pass
        else:
            # add elmo question tokenids into batch_data
            if self.need_build_question_cache():
                # Cache still being built: collect char ids of the
                # concatenated question, one row per sample.
                # [batch_size, question_length, max_char_num]
                batch_data['elmo_question_char_ids'].append(
                    self.get_elmo_char_ids([question_tokens_elmo])[0])
            else:
                # if question_window = 1, then just use the utterance cache
                if question_window == 1:
                    # self.assemble_elmo_with_snt_ids('q', batch_data, question_snt_ids[0])
                    # self.assemble_elmo_batch_data_with_mem('q', batch_data, question_snt_ids[0], self.utterance_cache_in_mem)
                    self.assemble_elmo_batch_data('q', batch_data,
                                                  question_snt_ids[0],
                                                  self.utterance_cache)
                else:
                    self.assemble_elmo_batch_data('q', batch_data, e_id_str,
                                                  self.question_cache)

    def _prepare_response_elmo_feed_dict(self, sample, batch_data,
                                         token_key_to_use):
        """
        Add the cached ELMo outputs of the first correct response to
        ``batch_data``.

        Only ``sample['options-for-correct-answers'][0]`` is used, and only when
        that utterance carries a ``'snt_id'``. Its entry is then assembled from
        ``self.utterance_cache`` under the ``'r'`` prefix. Samples without any
        correct-answer options are skipped.

        Arguments:
            sample: raw sample dict.
            batch_data: batch dict being built; mutated in place by
                ``assemble_elmo_batch_data``.
            token_key_to_use: kept for signature parity with the passage and
                question variants; unused because the response is looked up
                in the cache by ``snt_id`` only.
        """
        answers_key = 'options-for-correct-answers'
        # Test the sample itself: a bare non-empty string literal is always
        # truthy, which would raise KeyError/IndexError on unanswered samples.
        if answers_key not in sample or not sample[answers_key]:
            return
        utterance = sample[answers_key][0]
        if 'snt_id' in utterance:
            self.assemble_elmo_batch_data('r', batch_data, utterance['snt_id'],
                                          self.utterance_cache)

    def init_elmo_batch_data_sntids(self, batch_data):
        """
        Reset the per-batch ELMo lists in ``batch_data`` for the sentence-id
        lookup mode.

        Passage embedding lists (only when ``self.need_p_cache``) and the
        ``pu`` / ``q`` / ``r`` sentence-id lists are each set to a fresh empty
        list.
        """
        keys = []
        if self.need_p_cache:
            # passage outputs are still retrieved from the cache as embeddings
            keys += ['elmo_p_lm_embeddings', 'elmo_p_lengths', 'elmo_p_mask']
        keys += ['elmo_pu_snt_ids', 'elmo_q_snt_ids', 'elmo_r_snt_ids']
        for key in keys:
            batch_data[key] = []

    def init_elmo_batch_data_emb(self, batch_data):
        """
        Reset the per-batch ELMo embedding lists in ``batch_data``.

        The passage group (only when ``self.need_p_cache``) and the passage
        utterance (``pu``), question (``q``) and response (``r``) groups each
        get three entries: ``elmo_<x>_lm_embeddings``, ``elmo_<x>_lengths`` and
        ``elmo_<x>_mask``. Each entry is set to a fresh empty list.
        """
        keys = []
        if self.need_p_cache:
            keys += ['elmo_p_lm_embeddings', 'elmo_p_lengths', 'elmo_p_mask']
        keys += [
            'elmo_pu_lm_embeddings', 'elmo_pu_lengths', 'elmo_pu_mask',
            'elmo_q_lm_embeddings', 'elmo_q_lengths', 'elmo_q_mask',
            'elmo_r_lm_embeddings', 'elmo_r_lengths', 'elmo_r_mask',
        ]
        for key in keys:
            batch_data[key] = []

    def add_elmo_placeholder_with_cache_sntids(self):
        """
        Create the placeholders used by the sentence-id lookup mode; they feed
        the ELMo ops consumed by ``weight_layers``.

        The passage lm_embeddings/lengths/mask placeholders are created only
        when ``self.need_p_cache``. The ``pu`` / ``q`` / ``r`` sentence-id
        placeholders (int32, [None]) are always created.
        """
        if self.need_p_cache:
            lm_shape = [None, self.output_layers, None, self.output_dim]
            # The 'elmp_' op name is kept as-is so existing graph names are unchanged.
            self.elmo_p_lm_embeddings = tf.placeholder(
                tf.float32, lm_shape, name='elmp_p_lm_embeddings')
            self.elmo_p_lengths = tf.placeholder(tf.int32, [None],
                                                 name='elmo_p_lengths')
            self.elmo_p_mask = tf.placeholder(tf.int32, [None, None],
                                              name='elmo_p_mask')

        snt_id_names = ('elmo_pu_snt_ids', 'elmo_q_snt_ids', 'elmo_r_snt_ids')
        (self.elmo_pu_snt_ids, self.elmo_q_snt_ids,
         self.elmo_r_snt_ids) = [
             tf.placeholder(tf.int32, [None], name=op_name)
             for op_name in snt_id_names
         ]

    def add_elmo_placeholder_with_cache_emb(self):
        """
        Create the placeholders through which cached ELMo outputs are fed;
        they feed the ELMo ops consumed by ``weight_layers``.

        Each group is a triple: lm_embeddings (float32,
        [None, output_layers, None, output_dim]), lengths (int32, [None]) and
        mask (int32, [None, None]). The passage group is created only when
        ``self.need_p_cache``. The ``pu``, ``q`` and ``r`` groups are always
        created, in that order.
        """
        lm_shape = [None, self.output_layers, None, self.output_dim]

        def make_ops_placeholders(lm_name, lengths_name, mask_name):
            # Returned in (lm_embeddings, lengths, mask) order.
            return (tf.placeholder(tf.float32, lm_shape, name=lm_name),
                    tf.placeholder(tf.int32, [None], name=lengths_name),
                    tf.placeholder(tf.int32, [None, None], name=mask_name))

        if self.need_p_cache:
            # The 'elmp_' op name is kept as-is so existing graph names are unchanged.
            (self.elmo_p_lm_embeddings, self.elmo_p_lengths,
             self.elmo_p_mask) = make_ops_placeholders(
                 'elmp_p_lm_embeddings', 'elmo_p_lengths', 'elmo_p_mask')

        (self.elmo_pu_lm_embeddings, self.elmo_pu_lengths,
         self.elmo_pu_mask) = make_ops_placeholders(
             'elmo_pu_lm_embeddings', 'elmo_pu_lengths', 'elmo_pu_mask')
        (self.elmo_q_lm_embeddings, self.elmo_q_lengths,
         self.elmo_q_mask) = make_ops_placeholders(
             'elmo_q_lm_embeddings', 'elmo_q_lengths', 'elmo_q_mask')
        (self.elmo_r_lm_embeddings, self.elmo_r_lengths,
         self.elmo_r_mask) = make_ops_placeholders(
             'elmo_r_lm_embeddings', 'elmo_r_lengths', 'elmo_r_mask')

    def prepare_elmo_cache_feed_dict_sntids(self, feed_dict, batch):
        """
        Copy the batch's ELMo inputs into ``feed_dict`` for the sentence-id
        lookup mode.

        Each placeholder attribute on ``self`` is fed the ``batch`` entry with
        the same name. The passage embeddings are included only when
        ``self.need_p_cache``. ``feed_dict`` is mutated in place.
        """
        names = []
        if self.need_p_cache:
            names += ['elmo_p_lm_embeddings', 'elmo_p_lengths', 'elmo_p_mask']
        names += ['elmo_q_snt_ids', 'elmo_pu_snt_ids', 'elmo_r_snt_ids']
        for name in names:
            feed_dict[getattr(self, name)] = batch[name]

    def prepare_elmo_cache_feed_dict_emb(self, feed_dict, batch):
        """
        Copy the batch's cached ELMo outputs into ``feed_dict``.

        Each placeholder attribute on ``self`` is fed the ``batch`` entry with
        the same name. The order is passage (only when ``self.need_p_cache``),
        question, passage utterance, then response. ``feed_dict`` is mutated
        in place.
        """
        names = []
        if self.need_p_cache:
            names += ['elmo_p_lm_embeddings', 'elmo_p_lengths', 'elmo_p_mask']
        names += [
            'elmo_q_lm_embeddings', 'elmo_q_lengths', 'elmo_q_mask',
            'elmo_pu_lm_embeddings', 'elmo_pu_lengths', 'elmo_pu_mask',
            'elmo_r_lm_embeddings', 'elmo_r_lengths', 'elmo_r_mask',
        ]
        for name in names:
            feed_dict[getattr(self, name)] = batch[name]

    def elmo_embedding_layer_emb(self, elmo_emb_output):
        """
        Build ELMo embeddings from the cached-output placeholders.

        Sets ``self.q_elmo_emb``, ``self.pu_elmo_emb``, ``self.r_elmo_emb``
        and, when ``self.need_p_cache``, ``self.p_elmo_emb``. Each is the
        ``'weighted_op'`` of ``self.weight_layers`` over the matching
        lm_embeddings/lengths/mask placeholders.

        If ``elmo_emb_output == self.output_dim`` no projection is done.
        Otherwise each embedding is projected to ``elmo_emb_output`` units with
        ``tf.contrib.layers.fully_connected``.

        Arguments:
            elmo_emb_output: target embedding dimension.
        """
        self.logger.info('build elmo embedding layer')
        if self.need_p_cache:
            p_emb_elmo_op = {
                'lm_embeddings': self.elmo_p_lm_embeddings,
                'lengths': self.elmo_p_lengths,
                'mask': self.elmo_p_mask
            }

        q_emb_elmo_op = {
            'lm_embeddings': self.elmo_q_lm_embeddings,
            'lengths': self.elmo_q_lengths,
            'mask': self.elmo_q_mask
        }

        pu_emb_elmo_op = {
            'lm_embeddings': self.elmo_pu_lm_embeddings,
            'lengths': self.elmo_pu_lengths,
            'mask': self.elmo_pu_mask
        }

        r_emb_elmo_op = {
            'lm_embeddings': self.elmo_r_lm_embeddings,
            'lengths': self.elmo_r_lengths,
            'mask': self.elmo_r_mask
        }

        with tf.device('/device:GPU:1'):
            with tf.variable_scope("", reuse=tf.AUTO_REUSE):
                # All four calls use the same name 'input' under AUTO_REUSE,
                # presumably so p/q/pu/r share one set of layer weights —
                # confirm against weight_layers' variable naming.
                if self.need_p_cache:
                    self.p_elmo_emb = self.weight_layers(
                        'input', p_emb_elmo_op, l2_coef=0.0)['weighted_op']
                self.q_elmo_emb = self.weight_layers(
                    'input', q_emb_elmo_op, l2_coef=0.0)['weighted_op']
                self.pu_elmo_emb = self.weight_layers(
                    'input', pu_emb_elmo_op, l2_coef=0.0)['weighted_op']
                self.r_elmo_emb = self.weight_layers(
                    'input', r_emb_elmo_op, l2_coef=0.0)['weighted_op']
                # Project the ELMo embedding down to elmo_emb_output so it can
                # be concatenated with the word embedding.
                if elmo_emb_output == self.output_dim:
                    self.logger.info(
                        "Elmo_emb_output={} is just equal to the output_dim={}, no need to project with fully connected layers for passage and questions"
                        .format(elmo_emb_output, self.output_dim))
                else:
                    self.logger.info(
                        "Elmo_emb_output={}, output_dim={}, project with fully connected layers for question and passage"
                        .format(elmo_emb_output, self.output_dim))
                    # NOTE(review): softmax as the projection activation is
                    # unusual for an embedding projection — confirm intent.
                    if self.need_p_cache:
                        self.p_elmo_emb = tf.contrib.layers.fully_connected(
                            inputs=self.p_elmo_emb,
                            num_outputs=elmo_emb_output,
                            activation_fn=tf.nn.softmax)

                    self.q_elmo_emb = tf.contrib.layers.fully_connected(
                        inputs=self.q_elmo_emb,
                        num_outputs=elmo_emb_output,
                        activation_fn=tf.nn.softmax)
                    self.pu_elmo_emb = tf.contrib.layers.fully_connected(
                        inputs=self.pu_elmo_emb,
                        num_outputs=elmo_emb_output,
                        activation_fn=tf.nn.softmax)
                    self.r_elmo_emb = tf.contrib.layers.fully_connected(
                        inputs=self.r_elmo_emb,
                        num_outputs=elmo_emb_output,
                        activation_fn=tf.nn.softmax)

    def elmo_embedding_layer_sntids(self, elmo_emb_output):
        """
        Build ELMo embeddings for p, q, pu and r from cached biLM outputs.

        The cached ``lm_embeddings`` / ``lengths`` / ``mask`` arrays in
        ``self.utterance_cache_in_mem`` are loaded into non-trainable
        variables on the CPU. The q, pu and r inputs look up rows by
        sentence id (``self.elmo_q_snt_ids``, ``self.elmo_pu_snt_ids``,
        ``self.elmo_r_snt_ids``). The passage input (only built when
        ``self.need_p_cache`` is true) instead uses ``self.elmo_p_embeddings``,
        ``self.elmo_p_lengths`` and ``self.elmo_p_mask`` directly.

        Each input is combined by ``self.weight_layers('input', ...)``. The
        results are stored on ``self.p_elmo_emb`` (if ``need_p_cache``),
        ``self.q_elmo_emb``, ``self.pu_elmo_emb`` and ``self.r_elmo_emb``.
        Nothing is returned.

        If ``elmo_emb_output == self.output_dim`` no projection is done.
        Otherwise each embedding is passed through a fully connected layer
        with ``elmo_emb_output`` outputs and a softmax activation.

        Args:
            elmo_emb_output: number of outputs of the optional projection.
        """
        # Lookup tables holding the whole in-memory cache, frozen
        # (trainable=False) and pinned to the CPU.
        # NOTE(review): constant_initializer embeds the full cache arrays in
        # the graph definition; very large caches may exceed the GraphDef
        # size limit — consider a placeholder-fed initializer if that happens.
        with tf.device('/cpu:0'), tf.variable_scope('elmo_embedding'):
            self.elmo_lm_embeddings_lookup = tf.get_variable(
                'lm_embeddings_lookup',
                shape=np.shape(self.utterance_cache_in_mem['lm_embeddings']),
                initializer=tf.constant_initializer(
                    self.utterance_cache_in_mem['lm_embeddings']),
                trainable=False)

            self.elmo_lengths_lookup = tf.get_variable(
                'lengths_lookup',
                shape=(np.shape(self.utterance_cache_in_mem['lengths'])),
                initializer=tf.constant_initializer(
                    self.utterance_cache_in_mem['lengths']),
                trainable=False)

            self.elmo_mask_lookup = tf.get_variable(
                'mask_lookup',
                shape=np.shape(self.utterance_cache_in_mem['mask']),
                initializer=tf.constant_initializer(
                    self.utterance_cache_in_mem['mask']),
                trainable=False)

        # The passage ELMo tensors are used as-is, not looked up by sentence id.
        if self.need_p_cache:
            p_emb_elmo_op = {
                'lm_embeddings': self.elmo_p_embeddings,
                'lengths': self.elmo_p_lengths,
                'mask': self.elmo_p_mask
            }

        # For q / pu / r, gather the cached rows selected by each sentence id.
        q_emb_elmo_op = {
            'lm_embeddings':
            tf.nn.embedding_lookup(self.elmo_lm_embeddings_lookup,
                                   self.elmo_q_snt_ids),
            'lengths':
            tf.nn.embedding_lookup(self.elmo_lengths_lookup,
                                   self.elmo_q_snt_ids),
            'mask':
            tf.nn.embedding_lookup(self.elmo_mask_lookup, self.elmo_q_snt_ids)
        }

        pu_emb_elmo_op = {
            'lm_embeddings':
            tf.nn.embedding_lookup(self.elmo_lm_embeddings_lookup,
                                   self.elmo_pu_snt_ids),
            'lengths':
            tf.nn.embedding_lookup(self.elmo_lengths_lookup,
                                   self.elmo_pu_snt_ids),
            'mask':
            tf.nn.embedding_lookup(self.elmo_mask_lookup, self.elmo_pu_snt_ids)
        }

        r_emb_elmo_op = {
            'lm_embeddings':
            tf.nn.embedding_lookup(self.elmo_lm_embeddings_lookup,
                                   self.elmo_r_snt_ids),
            'lengths':
            tf.nn.embedding_lookup(self.elmo_lengths_lookup,
                                   self.elmo_r_snt_ids),
            'mask':
            tf.nn.embedding_lookup(self.elmo_mask_lookup, self.elmo_r_snt_ids)
        }

        # NOTE(review): device is hard-coded to GPU:1; confirm the target
        # machine has a second GPU.
        with tf.device('/device:GPU:1'):
            # AUTO_REUSE plus the shared 'input' name presumably makes all
            # inputs share one set of weight_layers variables — confirm
            # against weight_layers.
            with tf.variable_scope("", reuse=tf.AUTO_REUSE):
                if self.need_p_cache:
                    self.p_elmo_emb = self.weight_layers(
                        'input', p_emb_elmo_op, l2_coef=0.0)['weighted_op']
                self.q_elmo_emb = self.weight_layers(
                    'input', q_emb_elmo_op, l2_coef=0.0)['weighted_op']
                self.pu_elmo_emb = self.weight_layers(
                    'input', pu_emb_elmo_op, l2_coef=0.0)['weighted_op']
                self.r_elmo_emb = self.weight_layers(
                    'input', r_emb_elmo_op, l2_coef=0.0)['weighted_op']
                # Optionally project the ELMo embeddings to elmo_emb_output
                # dims, e.g. so they can be concatenated with word embeddings.
                if elmo_emb_output == self.output_dim:
                    self.logger.info(
                        "Elmo_emb_output={} is just equal to the output_dim={}, no need to project with fully connected layers for question and passage"
                        .format(elmo_emb_output, self.output_dim))
                else:
                    self.logger.info(
                        "Elmo_emb_output={}, output_dim={}, project with fully connected layers for question and passage"
                        .format(elmo_emb_output, self.output_dim))
                    # NOTE(review): a softmax activation on a projection layer
                    # is unusual (it normalizes each vector to sum to 1);
                    # confirm this is intended rather than None/linear.
                    if self.need_p_cache:
                        self.p_elmo_emb = tf.contrib.layers.fully_connected(
                            inputs=self.p_elmo_emb,
                            num_outputs=elmo_emb_output,
                            activation_fn=tf.nn.softmax)

                    self.q_elmo_emb = tf.contrib.layers.fully_connected(
                        inputs=self.q_elmo_emb,
                        num_outputs=elmo_emb_output,
                        activation_fn=tf.nn.softmax)
                    self.pu_elmo_emb = tf.contrib.layers.fully_connected(
                        inputs=self.pu_elmo_emb,
                        num_outputs=elmo_emb_output,
                        activation_fn=tf.nn.softmax)
                    self.r_elmo_emb = tf.contrib.layers.fully_connected(
                        inputs=self.r_elmo_emb,
                        num_outputs=elmo_emb_output,
                        activation_fn=tf.nn.softmax)

    def assemble_elmo_batch_data(self, name, batch_data, id_key, cache):
        """Append one cached ELMo entry to the batch lists for ``name``.

        ``cache`` maps 'lm_embeddings', 'lengths' and 'mask' to containers
        keyed by the string form of ``id_key`` (presumably h5py groups —
        TODO confirm). The full embedding and mask arrays are read; only the
        first element of the lengths entry is kept. Each value is appended
        to ``batch_data['elmo_<name>_<field>']``.
        """
        key = '{}'.format(id_key)
        # All three reads happen before any append, so a missing entry
        # leaves batch_data unchanged.
        fetched = (
            ('lm_embeddings', cache['lm_embeddings'][key][...]),
            ('lengths', cache['lengths'][key][0]),
            ('mask', cache['mask'][key][...]),
        )
        for field, value in fetched:
            batch_data['elmo_{}_{}'.format(name, field)].append(value)

    def assemble_elmo_batch_data_with_mem(self, name, batch_data, id_key,
                                          cache_in_mem):
        """Append the in-memory ELMo cache row ``id_key`` to the batch lists.

        ``id_key`` is an integer sentence id used to index
        ``cache_in_mem['lm_embeddings']``, ``['lengths']`` and ``['mask']``.
        The rows are appended to ``batch_data['elmo_<name>_<field>']``.
        """
        fields = ('lm_embeddings', 'lengths', 'mask')
        # Fetch every row first, then append, in the same field order.
        rows = [cache_in_mem[field][id_key] for field in fields]
        for field, row in zip(fields, rows):
            batch_data['elmo_{}_{}'.format(name, field)].append(row)

    def assemble_elmo_with_snt_ids(self, name, batch_data, id_key):
        """Record the integer sentence id ``id_key`` for input ``name``.

        The id is appended to ``batch_data['elmo_<name>_snt_ids']``.
        """
        ids_for_name = batch_data['elmo_{}_snt_ids'.format(name)]
        ids_for_name.append(id_key)
def dump_bilm_embeddings(vocab_file, options_file, weight_file,
                         embedding_file):
    """Run the biLM over the dev/train_full/test splits and dump to HDF5.

    Each split is read from ``data/<split>.json``. Every line is parsed as
    JSON, and its ``'text'`` field is batched with ``TokenBatcher`` and fed
    to the biLM one sentence at a time. For each split, three HDF5 files are
    written under ``ELMo_data``:

    - layer index 2 of the LM embeddings;
    - the mean over all layers;
    - the mean of layers 1 and 2.

    Each file holds one float32 dataset per sentence, named by the
    sentence's 0-based line index.

    Args:
        vocab_file: vocabulary file passed to ``TokenBatcher``.
        options_file: biLM options file.
        weight_file: biLM weight file.
        embedding_file: token embedding file (the model is built with
            ``use_character_inputs=False``).
    """
    import contextlib

    batcher = TokenBatcher(vocab_file)

    ids_placeholder = tf.placeholder('int32', shape=(None, None))
    model = BidirectionalLanguageModel(options_file,
                                       weight_file,
                                       use_character_inputs=False,
                                       embedding_weight_file=embedding_file)
    ops = model(ids_placeholder)

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True

    def _create_dataset(fout, key, data):
        """Store ``data`` in ``fout`` as a float32 dataset named ``key``."""
        fout.create_dataset(key, data.shape, dtype='float32', data=data)

    def dump_split_bl_embeddings(_sess, dataset_file, outfile_last,
                                 outfile_avg, outfile_avglasttwo,
                                 example_count):
        """Dump one split with ``_sess``; falsy output paths are skipped.

        ``example_count`` is only used for the progress percentage.
        """
        # ExitStack closes every file opened so far, even if a later open
        # (e.g. of a missing dataset file) or the embedding loop raises.
        with contextlib.ExitStack() as stack:

            def _open_h5(path):
                if not path:
                    return None
                return stack.enter_context(h5py.File(path, 'w'))

            fout_avglasttwo = _open_h5(outfile_avglasttwo)
            fout_last = _open_h5(outfile_last)
            fout_avg = _open_h5(outfile_avg)
            fin = stack.enter_context(
                open(dataset_file, 'r', encoding='utf-8'))

            for sentence_id, line in enumerate(fin):
                sentence = json.loads(line.strip())['text']
                _ids = batcher.batch_sentences([sentence])
                embeddings = _sess.run(ops['lm_embeddings'],
                                       feed_dict={ids_placeholder: _ids})
                # Single-sentence batch: drop the batch axis. Axis 0 of the
                # result is the layer axis (presumably (layer, token, dim)).
                embedding = embeddings[0, :, :, :]
                key = '{}'.format(sentence_id)
                if fout_last is not None:
                    _create_dataset(fout_last, key, embedding[2, :, :])
                if fout_avg is not None:
                    _create_dataset(fout_avg, key,
                                    np.mean(embedding, axis=0))
                if fout_avglasttwo is not None:
                    _create_dataset(fout_avglasttwo, key,
                                    np.mean(embedding[1:3, :, :], axis=0))
                done = sentence_id + 1
                if done % 500 == 0:
                    print('%.2f%% finished!' %
                          (done / float(example_count) * 100))

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        # The counts are approximate sentences per split, used only for the
        # progress output.
        for example_count, split in zip([1500, 17000, 3000],
                                        ['dev', 'train_full', 'test']):
            dataset_file = 'data/%s.json' % split
            output_file_last = os.path.join(
                ELMo_data, '%s_elmo_embeddings_last.hdf5' % split)
            output_file_avg = os.path.join(
                ELMo_data, '%s_elmo_embeddings_avg.hdf5' % split)
            output_file_avg_of_last_two = os.path.join(
                ELMo_data, '%s_elmo_embeddings_avg_of_last_two.hdf5' % split)
            print('start to dump %s split...' % split)
            start = time.time()
            dump_split_bl_embeddings(sess, dataset_file, output_file_last,
                                     output_file_avg,
                                     output_file_avg_of_last_two,
                                     example_count)
            print('%.1f mins.' % ((time.time() - start) / 60.))
args = ap.parse_args()

# Expose only the requested GPU to TensorFlow (defaults to GPU 3).
os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU if args.GPU else '3'
tfconfig = tf.ConfigProto()
# Let this process allocate at most 80% of the GPU memory ...
tfconfig.gpu_options.per_process_gpu_memory_fraction = 0.8
# ... and grow the allocation on demand instead of reserving it up front.
tfconfig.gpu_options.allow_growth = True
# Hide TensorFlow C++ INFO and WARNING messages.
# NOTE(review): this is set after tensorflow has been imported; confirm it
# still takes effect for the messages you want to suppress.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

config = Config()

if __name__ == '__main__':
    with open('../data/test2.json', encoding='utf-8') as fin:
        datas = json.load(fin)
    # Assumes each entry is a whitespace-separated string: keep every field
    # after the first, for the first 10 entries only.
    ndatas = [line.split()[1:] for line in datas[:10]]

    # Batcher that turns tokenized sentences into token-id batches.
    batcher = TokenBatcher(config.vocab_file)

    inputData = tf.placeholder('int32', shape=(None, None))

    abilm = BidirectionalLanguageModel(
        config.option_file,
        config.weight_file,
        use_character_inputs=False,
        embedding_weight_file=config.tokenEmbeddingFile)
    inputEmbeddingsOp = abilm(inputData)

    elmoInput = weight_layers('input', inputEmbeddingsOp, l2_coef=0.0)

    # Pass tfconfig so the GPU memory settings configured above apply to
    # this session.
    sess = tf.Session(config=tfconfig)
    with sess.as_default():
        sess.run(tf.global_variables_initializer())
    def list_to_lazy_embeddings_with_dump(self,
                                          batch: List[List[str]],
                                          outfile_to_dump=None,
                                          partition=20):
        """
        Embed ``batch`` item by item and dump the results in partitions.

        Each element of ``batch`` is fed to the model as a one-token sentence
        (``batcher.batch_sentences([[item]])``). So, despite the annotation,
        the elements are used as individual tokens.

        ``batch`` is split into ``partition`` contiguous chunks of
        ``ceil(len(batch) / partition)`` items. For chunk ``k`` (1-based), the
        stacked embeddings are written as a float32 dataset named
        ``'embeddings'`` to ``outfile_to_dump`` with ``'@@'`` replaced by
        ``k``. A fresh ``tf.Session`` is opened, and its variables are
        initialized, for every chunk.

        Parameters
        ----------
        batch : ``List[List[str]]``, required
            Items to embed; each is wrapped as ``[[item]]`` before batching.
        outfile_to_dump : ``str``, required
            Output path template containing ``'@@'`` for the partition number.
        partition : ``int``, optional (default = 20)
            Number of chunks / output files. Trailing chunks may be empty
            (and are still written) when ``len(batch)`` is small relative to
            ``partition``.

        Returns
        -------
        An empty list.

        Raises
        ------
        ValueError
            If ``batch`` is empty, ``outfile_to_dump`` is None, or
            ``partition`` is less than 1.
        """
        if not batch or batch == [[]]:
            raise ValueError('Batch should not be empty')
        # Validate up front: otherwise a missing path or a zero partition
        # count only fails after model setup / a whole partition of work.
        if outfile_to_dump is None:
            raise ValueError(
                'outfile_to_dump must be a path template containing "@@"')
        if partition < 1:
            raise ValueError('partition must be a positive integer')

        if self.word_embedding_file is None:
            batcher = Batcher(self.voc_file_path, self.max_word_length)
        else:
            batcher = TokenBatcher(self.voc_file_path)
        config = tf.ConfigProto(allow_soft_placement=True)
        num_of_total_tokens = len(batch)
        each_partition_size = math.ceil(num_of_total_tokens / partition)
        print('Parition Size:{}'.format(partition))
        for _pi in range(partition):
            document_embeddings = []
            _begin_index = _pi * each_partition_size
            _end_index = _begin_index + each_partition_size
            chunk = batch[_begin_index:_end_index]
            with tf.Session(config=config) as sess:
                sess.run(tf.global_variables_initializer())
                print(15 * '-')
                print('Itration: {}, Data Range {} - {}'.format(
                    _pi + 1, _begin_index, _end_index))
                for token in tqdm(chunk, total=len(chunk)):
                    char_ids = batcher.batch_sentences([[token]])
                    _ops = sess.run(
                        self.ops,
                        feed_dict={self.ids_placeholder: char_ids})
                    # lm_embeddings for a one-token sentence is noted as
                    # [1, 3, 1, 1024]. Swapping axes 1 and 2 and dropping the
                    # two leading singleton axes leaves a (layers, dim) matrix.
                    new_embedding = np.swapaxes(_ops['lm_embeddings'], 1, 2)
                    new_embedding = new_embedding.reshape(
                        (new_embedding.shape[2], new_embedding.shape[3]))
                    document_embeddings.append(new_embedding)
            document_embeddings = np.asarray(document_embeddings)
            with h5py.File(outfile_to_dump.replace('@@', str(_pi + 1)),
                           'w') as fout:
                fout.create_dataset('embeddings',
                                    document_embeddings.shape,
                                    dtype='float32',
                                    data=document_embeddings)

        return []