# NOTE: besides the standard imports below, this file relies on repo-local
# modules: `modeling` / `tokenization` (from the BERT codebase) plus this
# project's own model_fn_builder, EmbeddingLayer, TransformerLayer,
# EffectiveTransformerLayer, LanguageModelOutputLayer,
# load_transformer_weights and process_data; import them from wherever this
# repo defines them. create_attention_mask_from_input_mask is assumed to be
# the repo's variant that takes a dtype argument.
import os
from datetime import datetime

import numpy as np
import tensorflow as tf

import modeling
from modeling import BertConfig, create_attention_mask_from_input_mask
from tokenization import FullTokenizer


class ModelServer:
    def __init__(self, param):

        self.model_path = os.path.abspath(param["model_path"])
        self.bert_config_file = os.path.abspath(param["bert_config_file"])
        bert_config = modeling.BertConfig.from_json_file(self.bert_config_file)
        self.fulltoken = FullTokenizer(os.path.abspath(param["vocab_file"]))
        self.vocab_dict = self.fulltoken.vocab

        target_start_ids = self.vocab_dict["[CLS]"]
        target_end_ids = self.vocab_dict["[SEP]"]

        # Default to a single device when CUDA_VISIBLE_DEVICES is unset.
        num_gpus = len(os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(','))
        tf.logging.info("num_gpus is {}".format(num_gpus))
        if param["use_mul_gpu"]:
            distribute = tf.contrib.distribute.MirroredStrategy(
                num_gpus=num_gpus)
        else:
            distribute = None
        run_config = tf.estimator.RunConfig(
            model_dir=self.model_path,  # already absolute, see above
            save_summary_steps=200,
            keep_checkpoint_max=2,
            save_checkpoints_steps=3000,
            train_distribute=distribute,
            eval_distribute=distribute)
        self.input_max_seq_length = param["max_seq_length"]
        model_fn = model_fn_builder(
            bert_config,
            init_checkpoint=None,
            # Training hyper-parameters; presumably required by
            # model_fn_builder's signature but unused in predict mode.
            learning_rate=0.0001,
            num_train_steps=10000,
            num_warmup_steps=100,
            use_one_hot_embeddings=False,  # True only when running on TPU
            input_seq_length=param["max_seq_length"],
            target_seq_length=param["max_target_seq_length"],
            target_start_ids=target_start_ids,
            target_end_ids=target_end_ids,
            batch_size=param["batch_size"],
            mode_type=param["mode_type"])
        self.estimator = tf.estimator.Estimator(model_fn=model_fn,
                                                config=run_config)

    # input: a 3-tuple (str_tokens, str_labels, list_str_mask_words),
    # or a list of such tuples. Each character of str_labels is one of:
    #   0: Not mentioned
    #   1: Negative
    #   2: Neutral
    #   3: Positive
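    # An illustrative (hypothetical) input, assuming a Chinese vocab where one
    # token is one character:
    #   ("这家店的味道很好", "30000000000000000000", ["味道"])
    # predict() returns up to `limitNum` decoded strings per input.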
    def predict(self, inputs, limitNum=3):
        predicts = []
        if not isinstance(inputs, list):
            inputs = [inputs]

        def token_input():
            for sample in inputs:
                tokens = sample[0]
                labels = [int(label) for label in sample[1]][:20]
                mask_words = sample[2]
                assert all(0 <= label <= 3 for label in labels)
                tokens = self.fulltoken.tokenize(
                    tokens)[:self.input_max_seq_length - 2]

                def replace_mask(tokens, mask_words):
                    # Locate each mask word in the token sequence and record
                    # the token indices to overwrite with [MASK]. Matching
                    # joins WordPiece tokens with the "##" prefix stripped,
                    # which assumes roughly one character per token (as in a
                    # Chinese BERT vocab).
                    mask_index = []
                    first_maskwords = [x[0] for x in mask_words]

                    for index, token in enumerate(tokens):
                        if token in first_maskwords:
                            for mask_words_x in mask_words:
                                if token == mask_words_x[0]:
                                    _token = "".join([
                                        _t.replace("#", '')
                                        for _t in tokens[index:index +
                                                         len(mask_words_x)]
                                    ])
                                    if _token == mask_words_x:
                                        for i in range(len(mask_words_x)):
                                            mask_index.append(index + i)
                                        # A matched word is masked only once:
                                        # drop it, then refresh the lookup of
                                        # first characters.
                                        mask_words = [
                                            x_ for x_ in mask_words
                                            if x_ != mask_words_x
                                        ]
                                        first_maskwords = [
                                            x[0] for x in mask_words
                                        ]
                        if len(mask_words) < 1:
                            break
                    for mask_index_ in mask_index:
                        tokens[mask_index_] = '[MASK]'
                    return tokens
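                # Illustrative example (hypothetical tokens): with
                #   tokens = ['这', '家', '店', '好', '吃']
                #   mask_words = ['好吃']
                # replace_mask returns ['这', '家', '店', '[MASK]', '[MASK]'].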

                tokens = replace_mask(tokens, mask_words)
                ids = self.fulltoken.convert_tokens_to_ids(['[CLS]'] + tokens +
                                                           ['[SEP]'])
                input_mask = [1] * len(ids)
                segment_ids = [0] * self.input_max_seq_length
                while len(ids) < self.input_max_seq_length:
                    ids.append(0)
                    input_mask.append(0)
                while len(labels) < 20:
                    labels.append(0)

                yield ([ids], [input_mask], [labels], [segment_ids])

        def input_fn():
            # Each generator element is already a batch of one example.
            dataset = tf.data.Dataset.from_generator(
                token_input,
                (tf.int64, tf.int64, tf.int64, tf.int64),
                output_shapes=(
                    tf.TensorShape([None, self.input_max_seq_length]),
                    tf.TensorShape([None, self.input_max_seq_length]),
                    tf.TensorShape([None, 20]),
                    tf.TensorShape([None, self.input_max_seq_length]),
                ))
            dataset = dataset.map(
                lambda ids, input_mask, labels, segment_ids: {
                    "sentiment_labels": labels,
                    "input_token_ids": ids,
                    "input_mask": input_mask,
                    # Decoder targets are unused at inference time, so feed
                    # small all-zero dummy tensors.
                    "target_token_ids": tf.zeros_like([1, 1]),
                    "target_mask": tf.zeros_like([1, 1]),
                    "segment_ids": segment_ids
                })


            return dataset

        result = self.estimator.predict(input_fn=input_fn)
        for prediction in result:
            # Keep the `limitNum` best decoded sequences and iterate over
            # them candidate by candidate.
            sample_id = prediction['sample_id'][:, :limitNum].T.tolist()
            ans = []
            for sample_id_ in sample_id:
                token = self.fulltoken.convert_ids_to_tokens(sample_id_)
                # Drop the trailing [SEP] end-of-sequence token.
                ans.append("".join(token[:-1]))
            predicts.append(ans)
            input_ids = prediction['inputs'].tolist()
            print(self.fulltoken.convert_ids_to_tokens(input_ids))

        return predicts
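

# A minimal usage sketch (hypothetical paths and values; the param keys mirror
# those read in ModelServer.__init__, and the input tuple follows the format
# documented above predict()):
#
#     param = {
#         "model_path": "output",
#         "bert_config_file": "chinese_L-12_H-768_A-12/bert_config.json",
#         "vocab_file": "chinese_L-12_H-768_A-12/vocab.txt",
#         "max_seq_length": 128,
#         "max_target_seq_length": 20,
#         "batch_size": 1,
#         "use_mul_gpu": False,
#         "mode_type": "predict",
#     }
#     server = ModelServer(param)
#     server.predict([("这家店的味道很好", "30000000000000000000", ["味道"])])
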
def main(args):
    checkpoint_path = os.path.join(args.model_dir, "bert_model.ckpt")

    bert_config = BertConfig.from_json_file(os.path.join(args.model_dir, "bert_config.json"))
    bert_config.hidden_dropout_prob = 0.0
    bert_config.attention_probs_dropout_prob = 0.0

    batch_size = args.batch_size
    max_seq_len = args.max_seq_length
    top_k = args.top_k_answers
    tf_dtype = tf.float16 if args.precision == 'fp16' else tf.float32

    if args.effective_mode:
        # load transformer weights *before* building the computation graph
        weights_value = load_transformer_weights(checkpoint_path, bert_config, batch_size, max_seq_len, tf_dtype)

    # build model
    input_ids_placeholder = tf.placeholder(shape=[batch_size, max_seq_len], dtype=tf.int32, name="input_ids")
    input_mask_placeholder = tf.placeholder(shape=[batch_size, max_seq_len], dtype=tf.int32, name="input_mask")
    attention_mask_placeholder = tf.placeholder(shape=[batch_size, max_seq_len, max_seq_len], dtype=tf_dtype, name="attention_mask")
    input_embedding_placeholder = tf.placeholder(shape=[batch_size, max_seq_len, bert_config.hidden_size], dtype=tf_dtype, name="input_embedding")
    embedding_table_placeholder = tf.placeholder(shape=[bert_config.vocab_size, bert_config.hidden_size], dtype=tf_dtype, name="embedding_table")
    transformer_output_placeholder = tf.placeholder(shape=[batch_size, max_seq_len, bert_config.hidden_size], dtype=tf_dtype, name="transformer_output")

    embedding_layer = EmbeddingLayer(bert_config, tf_dtype, input_ids_placeholder)
    if args.effective_mode:
        effective_transformer_layer = EffectiveTransformerLayer(batch_size, max_seq_len, bert_config,
                                                                attention_mask_placeholder, input_mask_placeholder,
                                                                input_embedding_placeholder, weights_value)
    else:
        standard_transformer_layer = TransformerLayer(bert_config, input_embedding_placeholder, attention_mask_placeholder)
    output_layer = LanguageModelOutputLayer(bert_config, tf_dtype, transformer_output_placeholder, embedding_table_placeholder)

    config = tf.ConfigProto()
    # Turn on XLA JIT compilation for the whole session.
    config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
    with tf.Session(config=config) as sess:
        # restore the trainable variables (embedding and output layers)
        variables_to_restore = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if tf_dtype == tf.float16:
            # https://stackoverflow.com/questions/42793027/quantize-tensorflow-graph-to-float16
            sess.run(tf.global_variables_initializer())
            for variable in variables_to_restore:
                var = tf.contrib.framework.load_variable(checkpoint_path, variable.op.name)
                if var.dtype == np.float32:
                    tf.add_to_collection('assignOps', variable.assign(tf.cast(var, tf.float16)))
                else:
                    tf.add_to_collection('assignOps', variable.assign(var))
            sess.run(tf.get_collection('assignOps'))
        else:
            saver = tf.train.Saver(variables_to_restore)
            saver.restore(sess, checkpoint_path)

        # process input data
        tokenizer = FullTokenizer(vocab_file=os.path.join(args.model_dir, 'vocab.txt'))
        input_ids, input_mask, input_text, to_predict = process_data(batch_size, max_seq_len, tokenizer)
        input_ids_tensor = tf.convert_to_tensor(input_ids, dtype=tf.int32)
        input_mask_tensor = tf.convert_to_tensor(input_mask, dtype=tf.int32)

        # predict
        begin = datetime.now()
        input_embedding, embedding_table = sess.run(
            [embedding_layer.get_embedding_output(), embedding_layer.get_embedding_table()],
            feed_dict={input_ids_placeholder: input_ids})
        attention_mask = sess.run(create_attention_mask_from_input_mask(input_ids_tensor, input_mask_tensor, tf_dtype))
        if args.effective_mode:
            transformer_output = sess.run(effective_transformer_layer.get_transformer_output(),
                                          feed_dict={input_embedding_placeholder: input_embedding,
                                                     attention_mask_placeholder: attention_mask,
                                                     input_mask_placeholder: input_mask})
        else:
            transformer_output = sess.run(standard_transformer_layer.get_transformer_output(),
                                          feed_dict={input_embedding_placeholder: input_embedding,
                                                     attention_mask_placeholder: attention_mask})
        probs = sess.run(output_layer.get_predict_probs(),
                         feed_dict={transformer_output_placeholder: transformer_output,
                                    embedding_table_placeholder: embedding_table})
        end = datetime.now()
        print("time cost: ", (end - begin).total_seconds(), "s")

        # choose top k answers
        top_ids = np.argsort(-probs, axis=2)[:, :, :top_k]

        batch_results = []
        for sid, blank_ids in enumerate(to_predict):
            sentence_results = []
            for cid in blank_ids:
                result = []
                for idx in top_ids[sid][cid]:
                    token = tokenizer.convert_ids_to_tokens([idx])[0]
                    result.append((token, probs[sid][cid][idx]))
                sentence_results.append(result)
            batch_results.append(sentence_results)

    for text, blank_ids, sentence_results in zip(input_text, to_predict, batch_results):
        print("Q:", text)
        for cid, result in zip(blank_ids, sentence_results):
            print("A:", result)