Example #1
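# Assumed imports for this snippet (a sketch; the code targets the TensorFlow 1.x API,
# i.e. tf.placeholder / tf.Session). Tokenizer, language_model, xlnet, conf, FLAGS,
# preprocess_text, normalization, get_feature, cal_sim, special_words, SEP_ID and
# CLS_ID come from project-local modules that are not shown here.
import json
import logging
import math
import traceback

import numpy as np
import tensorflow as tf
import xgboost as xgb
from scipy import sparse
from xgboost import DMatrix
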
class query_weight:
    def __init__(self, ckpt_num=156000, is_training=False):
        #init_log()
        batch_size = 1
        logging.info("Init query weight model ...")
        self.sp = Tokenizer()
        self.lm = language_model()
        self.xgb_model = xgb.Booster(model_file=conf.rank_model)
        tf.logging.set_verbosity(tf.logging.INFO)
        tf_float = tf.bfloat16 if FLAGS.use_bfloat16 else tf.float32
        self.input_ids = tf.placeholder(dtype=tf.int64,
                                        shape=[batch_size, FLAGS.seq_len],
                                        name="input_ids")
        self.segment_ids = tf.placeholder(dtype=tf.int32,
                                          shape=[batch_size, FLAGS.seq_len],
                                          name="segment_ids")
        self.input_mask = tf.placeholder(dtype=tf_float,
                                         shape=[batch_size, FLAGS.seq_len],
                                         name="input_mask")
        self.label_ids = tf.placeholder(dtype=tf.int64,
                                        shape=[batch_size],
                                        name="label_ids")
        inp = tf.transpose(self.input_ids, [1, 0])
        seg_id = tf.transpose(self.segment_ids, [1, 0])
        inp_mask = tf.transpose(self.input_mask, [1, 0])
        self.sess = tf.Session()
        xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
        run_config = xlnet.create_run_config(is_training, True, FLAGS)

        xlnet_model = xlnet.XLNetModel(xlnet_config=xlnet_config,
                                       run_config=run_config,
                                       input_ids=inp,
                                       seg_ids=seg_id,
                                       input_mask=inp_mask)
        self.output = xlnet_model.output_encode
        self.attn_prob = xlnet_model.attn_prob
        self.attention_out = xlnet_model.attention_out

        num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
        tf.logging.info('#params: {}'.format(num_params))
        xlnet_model.saver.restore(
            self.sess, FLAGS.init_checkpoint + "/model.ckpt-" + str(ckpt_num))
        #### load pretrained models
        # scaffold_fn = model_utils.init_from_checkpoint(FLAGS)
        logging.info("Init query weight model finished ...")

    def run(self, req_dict):
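        """Extract the query string from the request dict and score it with run_step.

        Returns the (token, weight) list, or None if anything fails (the error is logged).
        """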
        result = None
        try:
            query = req_dict["request"]["p"]["query"]
            result = self.run_step(query)
        except Exception as e:
            logging.warning("run_error: %s" % traceback.format_exc())
        return result

    def run_step(self, text):
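        """Compute term weights for a single query.

        The query is tokenized, padded into XLNet inputs, run through the model, and
        the attention-, IDF- and LM-based weights plus an xgboost ranking score are
        merged into one normalized (token, weight) list.
        """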
        cur_sent = preprocess_text(text.strip(), lower=FLAGS.uncased)
        tokens, ids = self.sp.encode_ids(cur_sent)
        sent_len, diff_len = len(ids) - 1, FLAGS.seq_len - len(ids)

        # cat_data = np.concatenate([inp, a_data, sep_array, b_data, sep_array, cls_array])
        input_ids = ids + [SEP_ID] * (diff_len - 1) + [CLS_ID]
        input_tokens = tokens + ["<sep>"] * (diff_len - 1) + ["<cls>"]
        input_mask = [1] + [0] * sent_len + [1] * diff_len
        # seg_id = ([0] * (reuse_len + a_data.shape[0]) + [0] + [1] * b_data.shape[0] + [1] + [2])
        segment_ids = [0] * (sent_len + 1) + [2] * diff_len
        input_ids = input_ids[:FLAGS.seq_len]
        input_tokens = input_tokens[:FLAGS.seq_len]
        input_mask = input_mask[:FLAGS.seq_len]
        segment_ids = segment_ids[:FLAGS.seq_len]
        '''
        logging.info("text: %s, seg_text: %s" % (text, " ".join([str(x) for x in tokens])))
        logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        '''
        il = {
            'text': text,
            'seg_text': " ".join([str(x) for x in tokens]),
            'input_ids': " ".join([str(x) for x in input_ids]),
            'input_mask': " ".join([str(x) for x in input_mask]),
            'segment_ids': " ".join([str(x) for x in segment_ids])
        }
        logging.info(json.dumps(il, ensure_ascii=False))

        feed_dict = {
            self.input_ids: [input_ids],
            self.segment_ids: [segment_ids],
            self.input_mask: [input_mask]
        }
        fetch = self.sess.run(
            [self.output, self.attn_prob, self.attention_out], feed_dict)
        out_encode, atten_prob = fetch[0], fetch[1]
        #weight0 = normalization(self.cal_weight(out_encode, input_tokens))
        weight_attn = normalization(self.weight_attenprob(atten_prob, tokens))
        weight_idf = normalization(self.sp.cal_weight_idf(tokens[1:]))
        weight_lm = normalization(self.lm.cal_weight_lm(tokens[1:]))
        weight_rule = self.merge_weight([(weight_attn, 0.5), (weight_idf, 0.5),
                                         (weight_lm, 0.5)])
        self.weight_attn, self.weight_idf, self.weight_lm = weight_attn, weight_idf, weight_lm
        sen2terms = [e for e in tokens[1:]]
        weightrank = self.rank_weight(sen2terms, weight_attn, weight_idf,
                                      weight_lm)
        weight_rank = normalization(weightrank)
        weight = self.merge_weight([(weight_rank, 0.7),
                                    (weight_rule, 0.0)])  # 0.6-0.4
        wl = {
            'weight_rank': ' '.join([str(k) + ':' + str(v) for k, v in weight_rank]),
            'weight_rule': ' '.join([str(k) + ':' + str(v) for k, v in weight_rule]),
            'weight': ' '.join([str(k) + ':' + str(v) for k, v in weight])
        }
        logging.info(json.dumps(wl, ensure_ascii=False))
        return weight

    def rank_weight(self, sen2terms, weight_attn, weight_idf, weight_lm):
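        """Score every term with the xgboost ranker, squash each score through a
        sigmoid, and normalize the resulting probabilities so they sum to ~1."""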
        tmp, score_sum = [], 1e-8
        for term in sen2terms:
            feature_vector, _ = get_feature(term, sen2terms, weight_attn,
                                            weight_idf, weight_lm)
            feature = np.array(feature_vector)
            feature_csr = sparse.csr_matrix(feature)
            dmat = DMatrix(feature_csr)
            score = self.xgb_model.predict(dmat)[0]
            prob = 1.0 / (1 + math.exp(-1 * score))
            tmp.append((term, prob))
            score_sum += prob
        res = [(k, round(v / score_sum, 3)) for k, v in tmp]
        return res

    def merge_weight(self, weight_tuple):
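        """Linearly combine several (token, weight) lists.

        `weight_tuple` is a list of (weight_list, coefficient) pairs whose lists are
        aligned on the same tokens; the combined weights are normalized to sum to ~1.
        """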
        weight, weight_sum = [], 1e-8
        for j in range(len(weight_tuple[0][0])):
            tmp = 0.0
            for i in range(len(weight_tuple)):
                (word, val), coef = weight_tuple[i][0][j], weight_tuple[i][1]
                tmp += val * coef
            weight.append((weight_tuple[0][0][j][0], tmp))
            weight_sum += tmp
        token_weight = [(k, round(v / weight_sum, 3)) for k, v in weight]
        return token_weight

    def weight_attenprob(self, attention_probs, input_tokens):
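        """For each position j, sum attention_probs[i][j][0][0] over all other
        positions i (the diagonal is skipped) and attach the resulting weights to
        the non-special input tokens."""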
        weights = []
        (row, col, batch, dim) = attention_probs.shape
        for j in range(col):
            tmp = 0.0
            for i in range(row):
                if i == j: continue
                tmp += attention_probs[i][j][0][0]
            weights.append(tmp)
        token_weight = [(input_tokens[i], weights[i])
                        for i in range(min(len(input_tokens), len(weights)))
                        if input_tokens[i] not in special_words]
        token_weights = token_weight + [
            (input_tokens[i], 0.0)
            for i in range(len(token_weight) + 1, len(input_tokens))
        ]
        return token_weights

    def cal_weight(self, encode_vects, input_tokens):
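        """Compare each token's encoding against the summed encoding of the whole
        input via cal_sim and keep the weights of the non-special tokens."""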
        vects, vect = encode_vects[0], np.sum(encode_vects, axis=1)[0]
        token_weights = [(input_tokens[i], cal_sim(vect, vects[i]))
                         for i in range(len(vects))
                         if input_tokens[i] not in special_words]
        #token_weight = [(input_tokens[i], weight[i-1]) if input_tokens[i] not in special_words else (input_tokens[i], 0.0) for i in range(len(input_tokens))]
        return token_weights
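
# Minimal usage sketch (hypothetical values; FLAGS, conf and the model checkpoint
# must be configured before this runs):
# qw = query_weight(ckpt_num=156000)
# weights = qw.run({"request": {"p": {"query": "java开发工程师"}}})
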
Example #2
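# Assumed imports for this snippet (a sketch, again against the TensorFlow 1.x API;
# Tokenizer, preprocess_text, EOD_ID, create_tfrecords and FLAGS come from
# project-local modules not shown here).
import os

import numpy as np
import tensorflow as tf
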
def _create_data(idx, input_paths):
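    """Tokenize the raw input files into id arrays with alternating sentence ids,
    shuffle the shards with a pass-dependent seed, concatenate them, and write the
    result as TFRecords; returns a record-info dict (filenames and batch count)."""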
    # Load sentence-piece model
    #sp = spm.SentencePieceProcessor(); sp.Load(FLAGS.sp_path)
    sp = Tokenizer()

    input_shards = []
    total_line_cnt = 0
    for input_path in input_paths:
        input_data, sent_ids = [], []
        sent_id, line_cnt = True, 0
        tf.logging.info("Processing %s", input_path)
        for line in tf.gfile.Open(input_path):
            if line_cnt % 100000 == 0:
                tf.logging.info("Loading line %d", line_cnt)
            line_cnt += 1

            if not line.strip():
                if FLAGS.use_eod:
                    sent_id = not sent_id
                    cur_sent = [EOD_ID]
                else:
                    continue
            else:
                if FLAGS.from_raw_text:
                    cur_sent = preprocess_text(line.strip(),
                                               lower=FLAGS.uncased)
                    #cur_sent = encode_ids(sp, cur_sent)
                    _, cur_sent = sp.encode_ids(cur_sent)
                    #a=sp.encode_ids("java开发工程师")
                else:
                    cur_sent = list(map(int, line.strip().split()))

            input_data.extend(cur_sent)
            sent_ids.extend([sent_id] * len(cur_sent))
            sent_id = not sent_id

        tf.logging.info("Finish with line %d", line_cnt)
        if line_cnt == 0:
            continue

        input_data = np.array(input_data, dtype=np.int64)
        sent_ids = np.array(sent_ids, dtype=np.bool)

        total_line_cnt += line_cnt
        input_shards.append((input_data, sent_ids))

    tf.logging.info("[Task %d] Total number line: %d", idx, total_line_cnt)

    tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")

    filenames, num_batch = [], 0

    # Randomly shuffle input shards (with a fixed but distinct random seed)
    np.random.seed(100 * FLAGS.task + FLAGS.pass_id)

    perm_indices = np.random.permutation(len(input_shards))
    tf.logging.info("Using perm indices %s for pass %d", perm_indices.tolist(),
                    FLAGS.pass_id)

    input_data_list, sent_ids_list = [], []
    prev_sent_id = None
    for perm_idx in perm_indices:
        input_data, sent_ids = input_shards[perm_idx]
        # make sure that `sent_ids[0] == not prev_sent_id`
        if prev_sent_id is not None and sent_ids[0] == prev_sent_id:
            sent_ids = np.logical_not(sent_ids)

        # append to temporary list
        input_data_list.append(input_data)
        sent_ids_list.append(sent_ids)

        # update `prev_sent_id`
        prev_sent_id = sent_ids[-1]

    input_data = np.concatenate(input_data_list)
    sent_ids = np.concatenate(sent_ids_list)

    file_name, cur_num_batch = create_tfrecords(
        save_dir=tfrecord_dir,
        basename="{}-{}-{}".format(FLAGS.split, idx, FLAGS.pass_id),
        data=[input_data, sent_ids],
        bsz_per_host=FLAGS.bsz_per_host,
        seq_len=FLAGS.seq_len,
        bi_data=FLAGS.bi_data,
        sp=sp,
    )

    filenames.append(file_name)
    num_batch += cur_num_batch

    record_info = {"filenames": filenames, "num_batch": num_batch}

    return record_info