class QueryReformulation:
    """Keras model wrapper that weights candidate terms for a query.

    The wrapper can build one of three encoder variants (LSTM, BiLSTM or
    CNN), train it with a reward signal computed by external evaluation
    helpers, evaluate it, and turn its predicted weights into a reformulated
    query.
    """

    def __init__(self, model_path=None, output_path=''):
        """Prepare the checkpoint filename template and optionally load a model.

        Args:
            model_path: Path of a saved model to load; when falsy no model is
                loaded and ``build_model`` has to be called before use.
            output_path: Directory in which checkpoints are written. An empty
                value means the current working directory.
        """
        self.model = None
        self.model_name = None

        # Only add the separator when there is a directory to join: blindly
        # concatenating '' + '/qr_...' would produce an absolute path in the
        # filesystem root.
        directory = output_path + '/' if output_path else ''
        # ``name``, ``epoch`` and ``precision`` are filled in by
        # ``train_model`` through ``str.format``.
        self.model_output = (directory
                             + 'qr_{name}_model_[e{epoch}]_[p{precision}]_'
                             + str(datetime.now().date()) + '.h5')

        if model_path:
            self.model = load_model(model_path)
            self.model.summary()

    def build_model(self, model_name, query_dim, terms_dim, output_dim,
                    word_embedding):
        """Build and compile the two-input network, storing it in ``self.model``.

        Args:
            model_name: ``'lstm'`` or ``'bilstm'`` select the recurrent
                variants; any other value selects the CNN variant.
            query_dim: Length of the query input sequence.
            terms_dim: Length of the candidate-terms input sequence.
            output_dim: Number of units of the final linear layer.
            word_embedding: Object exposing ``vocabulary_size``,
                ``dimensions`` and ``embedding_matrix``, used to initialise
                the (trainable) embedding layer.
        """
        self.model_name = model_name

        query_input = Input(shape=(query_dim, ), name='query_input')
        terms_input = Input(shape=(terms_dim, ), name='terms_input')

        # The embedding front-end is identical for every variant; only the
        # sequence encoder appended after it differs.
        encoder_layers = [
            Embedding(word_embedding.vocabulary_size,
                      word_embedding.dimensions,
                      weights=[word_embedding.embedding_matrix],
                      trainable=True,
                      mask_zero=False),
            BatchNormalization(),
        ]
        if model_name == 'lstm':
            encoder_layers.append(LSTM(64, return_sequences=True))
        elif model_name == 'bilstm':
            encoder_layers.append(
                Bidirectional(LSTM(64, return_sequences=True)))
        else:  # default cnn
            encoder_layers.append(
                Conv1D(filters=64, kernel_size=3, strides=1))
            encoder_layers.append(MaxPooling1D(pool_size=3))
        embedding_feature_block = Sequential(layers=encoder_layers)

        # Features: the same block (shared weights) encodes both inputs.
        query_feature = embedding_feature_block(query_input)
        terms_feature = embedding_feature_block(terms_input)

        # Query-Terms alignment
        attention = Dot(axes=-1)([query_feature, terms_feature])
        softmax_attention = Lambda(lambda x: softmax(x, axis=1),
                                   output_shape=unchanged_shape)(attention)
        # NOTE(review): Dot(axes=1) contracts axis 1 of both tensors, so this
        # appears to require the encoded query and terms sequences to have
        # the same length -- confirm against the callers' dimensions.
        terms_aligned = Dot(axes=1)([softmax_attention, terms_feature])

        # Aligned features, reduced to one vector per sample.
        if model_name == 'lstm':
            flatten_layer = LSTM(128, return_sequences=False)(terms_aligned)
        elif model_name == 'bilstm':
            flatten_layer = Bidirectional(
                LSTM(128, return_sequences=False))(terms_aligned)
        else:  # default cnn
            merged_cnn = Conv1D(filters=128, kernel_size=3,
                                strides=1)(terms_aligned)
            merged_cnn = MaxPooling1D(pool_size=3)(merged_cnn)
            flatten_layer = Flatten()(merged_cnn)

        # Output
        dense = BatchNormalization()(flatten_layer)
        dense = Dense(64, activation='sigmoid')(dense)
        out = Dense(output_dim, activation='linear')(dense)

        self.model = Model(inputs=[query_input, terms_input], outputs=out)
        self.model.compile(optimizer='adam', loss=losses.mean_squared_error)
        self.model.summary()

    def train_model(self,
                    query_objs,
                    query_sequence,
                    terms_sequence,
                    candidate_terms,
                    epochs=20,
                    batch_size=4):
        """Train the model, weighting each sample by its evaluated reward.

        For every batch the model's own predictions are used as targets and
        the reward returned by ``evaluate_reward_precision`` is passed as
        the per-sample weight. After each epoch the model is saved when the
        mean precision improves on the best value seen so far.

        Args:
            query_objs: Query objects; only their count and the batches
                produced by ``get_batch_data`` are used here.
            query_sequence: Encoded queries, batched by ``get_batch_data``.
            terms_sequence: Encoded candidate terms, batched likewise.
            candidate_terms: Candidate terms handed to the evaluation helper.
            epochs: Number of passes over the data.
            batch_size: Batch size, also used as the worker-pool size.
        """
        best_precision = 0
        pool = Pool(batch_size)
        # try/finally so worker processes are released even when prediction,
        # evaluation or training raises.
        try:
            for e in range(epochs):
                print('Epochs: %3d/%d' % (e + 1, epochs))

                # NOTE(review): ``reward`` is reset every epoch and each
                # slice is read before it is written, so the 0.2 term below
                # looks like it is always zero. If a carry-over from the
                # previous epoch was intended, this array should be created
                # once before the epoch loop -- confirm the intent.
                reward = np.zeros(shape=(len(query_objs)))
                precision = np.zeros(shape=(len(query_objs)))
                for i, query, q_seq, t_seq, terms in get_batch_data(
                        query_objs, query_sequence, terms_sequence,
                        candidate_terms, batch_size):
                    print('  [%4d-%-4d/%d]' %
                          (i, i + batch_size, len(query_objs)))

                    weights = self.model.predict(x=[q_seq, t_seq])

                    # Column 0 is used as the reward, column 1 as precision.
                    batch_reward_precision = np.array(
                        pool.map(evaluate_reward_precision,
                                 zip(weights, terms, query)))

                    batch_reward = (0.8 * batch_reward_precision[:, 0]
                                    + 0.2 * reward[i:i + batch_size])

                    self.model.train_on_batch(x=[q_seq, t_seq],
                                              y=weights,
                                              sample_weight=batch_reward)

                    reward[i:i + batch_size] = batch_reward_precision[:, 0]
                    precision[i:i + batch_size] = batch_reward_precision[:, 1]

                # Save model
                avg_precision = precision.mean()
                print('  Average precision %.5f on epoch %d, '
                      'best precision %.5f' %
                      (avg_precision, e + 1, best_precision))
                if avg_precision > best_precision:
                    best_precision = avg_precision
                    self.model.save(filepath=self.model_output.format(
                        name=self.model_name,
                        epoch=e + 1,
                        precision=round(avg_precision, 4)))
        finally:
            pool.close()
            pool.join()

    def test_model(self,
                   query_objs,
                   query_sequence,
                   terms_sequence,
                   candidate_terms,
                   batch_size=4):
        """Evaluate the model and return the mean of the per-query metrics.

        Args:
            query_objs: Query objects; only their count and the batches
                produced by ``get_batch_data`` are used here.
            query_sequence: Encoded queries, batched by ``get_batch_data``.
            terms_sequence: Encoded candidate terms, batched likewise.
            candidate_terms: Candidate terms handed to the evaluation helper.
            batch_size: Batch size, also used as the worker-pool size.

        Returns:
            Array of two values: the column-wise mean of the pairs returned
            by ``evaluate_precision_recall`` over all queries.
        """
        pool = Pool(batch_size)
        precision_recall = np.zeros(shape=(len(query_objs), 2))
        # try/finally so worker processes are released even on failure.
        try:
            for i, query, q_seq, t_seq, terms in get_batch_data(
                    query_objs, query_sequence, terms_sequence,
                    candidate_terms, batch_size):
                print('[%4d-%-4d/%d]' % (i, i + batch_size, len(query_objs)))

                weights = self.model.predict(x=[q_seq, t_seq])

                batch_precision_recall = pool.map(evaluate_precision_recall,
                                                  zip(weights, terms, query))

                precision_recall[i:i + batch_size] = np.array(
                    batch_precision_recall)
        finally:
            pool.close()
            pool.join()

        return precision_recall.mean(axis=0)

    def reformulate_query(self,
                          query_sequence,
                          terms_sequence,
                          candidate_terms,
                          threshold=0.5):
        """Predict term weights for one query and rebuild the query from them.

        Args:
            query_sequence: Encoded query (a single sample).
            terms_sequence: Encoded candidate terms (a single sample).
            candidate_terms: Candidate terms passed to ``recreate_query``.
            threshold: Threshold forwarded to ``recreate_query``.

        Returns:
            Whatever ``recreate_query`` returns for the predicted weights.
        """
        # Wrap each sequence in a list to form a batch of size one.
        weights = self.model.predict(x=[[query_sequence], [terms_sequence]])
        return recreate_query(terms=candidate_terms,
                              weights=weights[0],
                              threshold=threshold)
def _scoring_head(features):
    """Reduce a feature vector to one sigmoid score via two dense stages."""
    x = features
    for units in (embedding_size // 2, embedding_size // 4):
        x = Dropout(0.5)(x)
        x = Dense(units, kernel_regularizer=l2(0.001))(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    return Dense(1, 'sigmoid')(x)


def construct_keras_api_model(embedding_weights):
    """Build, plot and compile a five-tower binary classifier.

    The single input is a fixed-length integer sequence that is embedded
    with frozen, pre-computed weights. Five towers (global max pooling,
    global average pooling, conv + max pooling, conv + average pooling and
    conv + stacked LSTMs) each produce one sigmoid score; the five scores
    are concatenated and combined by a final sigmoid unit.

    Reads the module-level settings ``time_id_max``, ``period_days``,
    ``period_length``, ``creative_id_window``, ``embedding_size`` and
    ``RMSProp_lr``. As a side effect, prints the model summary and writes
    the model diagram to ``model/keras_api_word2vec_model.png``.

    Args:
        embedding_weights: Initial weight matrix of the (non-trainable)
            embedding layer.

    Returns:
        The compiled Keras model.
    """
    # Only the fixed-length creative-id sequence is used as input; the
    # "no time / no repeat" and "no time / with repeat" variants that used
    # to sit next to it are disabled.
    sequence_length = math.ceil(time_id_max / period_days) * period_length
    # ``shape`` must be a tuple: ``(n)`` is just the int ``n``, which not
    # every Keras version accepts, hence the explicit trailing comma.
    input_fix_creative_id = Input(shape=(sequence_length, ),
                                  dtype='int32',
                                  name='input_fix_creative_id')
    embedded_fix_creative_id = Embedding(
        creative_id_window,
        embedding_size,
        weights=[embedding_weights],
        trainable=False)(input_fix_creative_id)

    # Towers working directly on the embedded sequence.
    gm_x = _scoring_head(
        keras.layers.GlobalMaxPooling1D()(embedded_fix_creative_id))
    ga_x = _scoring_head(
        GlobalAveragePooling1D()(embedded_fix_creative_id))

    # Convolutional features shared by the three remaining towers.
    conv_creative_id = Conv1D(embedding_size,
                              kernel_size=15,
                              strides=5,
                              activation='relu')(embedded_fix_creative_id)

    conv_gm_x = MaxPooling1D(7)(conv_creative_id)
    conv_gm_x = Conv1D(embedding_size,
                       kernel_size=2,
                       strides=1,
                       activation='relu')(conv_gm_x)
    conv_gm_x = _scoring_head(GlobalMaxPooling1D()(conv_gm_x))

    conv_ga_x = AveragePooling1D(7)(conv_creative_id)
    conv_ga_x = Conv1D(embedding_size,
                       kernel_size=2,
                       strides=1,
                       activation='relu')(conv_ga_x)
    conv_ga_x = _scoring_head(GlobalAveragePooling1D()(conv_ga_x))

    lstm_x = Conv1D(embedding_size,
                    kernel_size=14,
                    strides=7,
                    activation='relu')(conv_creative_id)
    lstm_x = LSTM(embedding_size, return_sequences=True)(lstm_x)
    lstm_x = LSTM(embedding_size, return_sequences=True)(lstm_x)
    lstm_x = LSTM(embedding_size)(lstm_x)
    lstm_x = _scoring_head(lstm_x)

    # Combine the five per-tower scores into the final prediction.
    concatenated = concatenate(
        [gm_x, ga_x, conv_gm_x, conv_ga_x, lstm_x], axis=-1)
    output_tensor = Dense(1, 'sigmoid')(concatenated)

    keras_api_model = Model([input_fix_creative_id], output_tensor)
    keras_api_model.summary()
    plot_model(keras_api_model, to_file='model/keras_api_word2vec_model.png')
    print('-' * 5 + ' ' * 3 + "编译模型" + ' ' * 3 + '-' * 5)
    # The learning rate is passed positionally because its keyword was
    # renamed from ``lr`` to ``learning_rate`` across Keras releases; the
    # positional form works with both.
    keras_api_model.compile(optimizer=optimizers.RMSprop(RMSProp_lr),
                            loss=losses.binary_crossentropy,
                            metrics=[metrics.binary_accuracy])
    return keras_api_model