Пример #1
0
    def respond(self, req):
        """Produce a reply string for the incoming request.

        The request is tokenized into encoder/decoder inputs, the beam-search
        container is initialized, and the decoder is stepped until every
        requested beam has finished or the maximum length is reached.  The
        best (top-1) beam is decoded back to text and wrapped in '<...>'.
        """
        # Preprocess the request into encoder input and initial decoder input.
        enc_input, dec_input = data_utils.preprocess_request(
            sentence=req, start_sign=self.start_sign, end_sign=self.end_sign,
            token=self.token, max_length=self.max_length)

        self.beam_search_container.init_all_inner_variables(
            inputs=enc_input, dec_input=dec_input)
        enc_input, dec_input = self.beam_search_container.expand_beam_size_inputs()

        end_id = self.token.get(self.end_sign)
        for _ in range(self.max_length):
            predictions = self._create_predictions(enc_input, dec_input)
            self.beam_search_container.add(predictions=predictions, end_sign=end_id)
            # Zero remaining beams means all requested results are collected.
            if self.beam_search_container.beam_size == 0:
                break
            enc_input, dec_input = self.beam_search_container.expand_beam_size_inputs()
            # Feed only the most recently generated token back to the decoder.
            dec_input = torch.unsqueeze(dec_input[:, -1], dim=-1)

        # Decode each resulting sequence, strip markers/spaces, and prepend
        # so later results come first (same order as the original accumulator).
        pieces = []
        for sequence in self.beam_search_container.get_result(top_k=1):
            decoded = data_utils.sequences_to_texts(sequence.numpy(), self.token)
            cleaned = (decoded[0].replace(self.start_sign, '')
                       .replace(self.end_sign, '').replace(' ', ''))
            pieces.append('<' + cleaned + '>')
        return ''.join(reversed(pieces))
Пример #2
0
    def respond(self, req, input_dict_fn, target_dict_fn):
        """Reply to *req* using token dictionaries loaded from files.

        Loads the input/target dictionaries, preprocesses the request, and
        runs beam search for up to ``_config.max_length_tar`` decoder steps.
        Every surviving beam is decoded to text, cleaned of the 'start'/'end'
        markers and spaces, wrapped in '<...>', and concatenated.
        """
        input_token, target_token = _data.load_token_dict(
            input_dict_fn=input_dict_fn, target_dict_fn=target_dict_fn)
        enc_input, dec_input = self._pre_treat_inputs(req, input_token,
                                                      target_token)

        self.beam_search_container.init_variables(inputs=enc_input,
                                                  dec_input=dec_input)
        enc_input, dec_input = self.beam_search_container.get_variables()

        for step in range(_config.max_length_tar):
            predictions = self._create_predictions(enc_input, dec_input, step)
            self.beam_search_container.add(predictions=predictions,
                                           end_sign=target_token.get('end'))
            # All requested beams finished -> nothing left to search.
            if self.beam_search_container.beam_size == 0:
                break
            enc_input, dec_input = self.beam_search_container.get_variables()

        # Decode results; prepending (via reversed join) keeps the original
        # newest-first ordering of the accumulated string.
        pieces = []
        for sequence in self.beam_search_container.get_result():
            decoded = _data.sequences_to_texts(sequence.numpy(), target_token)
            cleaned = (decoded[0].replace('start', '')
                       .replace('end', '').replace(' ', ''))
            pieces.append('<' + cleaned + '>')
        return ''.join(reversed(pieces))
Пример #3
0
    def respond(self, req):
        """Produce a reply for the incoming request via beam search.

        Preprocesses the request, runs the decoder for up to
        ``_config.max_length_tar`` steps, then decodes the top-3 beams to
        text, strips the start/end markers and spaces, and wraps each in
        '<...>'.
        """
        # Initial preprocessing of the request.
        enc_input, dec_input = _data.preprocess_request(sentence=req,
                                                        token=self.token)

        self.beam_search_container.init_all_inner_variables(
            inputs=enc_input, dec_input=dec_input)
        enc_input, dec_input = self.beam_search_container.expand_beam_size_inputs()

        for step in range(_config.max_length_tar):
            predictions = self._create_predictions(enc_input, dec_input, step)
            self.beam_search_container.add(predictions=predictions,
                                           end_sign=self.token.get('end'))
            # Beam size of zero means enough finished results were collected.
            if self.beam_search_container.beam_size == 0:
                break
            enc_input, dec_input = self.beam_search_container.expand_beam_size_inputs()

        # Decode the best three beams; reversed join reproduces the original
        # prepend-accumulation order.
        pieces = []
        for sequence in self.beam_search_container.get_result(top_k=3):
            decoded = _data.sequences_to_texts(sequence.numpy(), self.token)
            cleaned = (decoded[0].replace(_config.start_sign, '')
                       .replace(_config.end_sign, '').replace(' ', ''))
            pieces.append('<' + cleaned + '>')
        return ''.join(reversed(pieces))
Пример #4
0
def response(sentence):
    """Generate a chat reply for *sentence* with the GPT-2 model.

    The sentence is wrapped with the "cls"/"sep" markers, segmented with
    jieba, mapped to token ids (id 1 is the out-of-vocabulary fallback),
    and decoded greedily: at each step the argmax token of the last
    position is appended until the end token is produced or the configured
    maximum length is reached.

    Args:
        sentence: raw user utterance (string).

    Returns:
        The decoded reply, as produced by ``data_utils.sequences_to_texts``.
    """
    # Id of the end-of-sequence token; decoding stops when it is predicted.
    # NOTE(review): assumed to be the "sep" entry of the dict — confirm.
    end_token_id = 2

    config = get_config()
    tokenizer = data_utils.load_dict(config["gpt2_dict"])

    # Tokenize: jieba segmentation, then word -> id lookup (1 = <unk>).
    words = " ".join(jieba.cut("cls" + sentence + "sep"))
    token_ids = [tokenizer.get(word, 1) for word in words.split(' ')]
    inputs = tf.cast(tf.expand_dims(tf.convert_to_tensor(token_ids), axis=0),
                     dtype=tf.int64)

    model = gpt2.gpt2(vocab_size=config["vocab_size"],
                      num_layers=config["num_layers"],
                      units=config["units"],
                      deep=config["deep"],
                      num_heads=config["num_heads"],
                      dropout=config["dropout"])

    # The optimizer is unused at inference time but is part of the saved
    # checkpoint object graph, so it must be present to restore cleanly.
    learning_rate = gpt2.CustomSchedule(config["deep"])
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate,
                                         beta_1=0.9,
                                         beta_2=0.98,
                                         epsilon=1e-9)

    # Restore the latest checkpoint if one exists.  makedirs(exist_ok=True)
    # already handles the pre-existing-directory case, so no exists() check.
    checkpoint_dir = config["checkpoint_dir"]
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    os.makedirs(checkpoint_dir, exist_ok=True)
    if os.listdir(checkpoint_dir):
        checkpoint.restore(
            tf.train.latest_checkpoint(checkpoint_dir)).expect_partial()

    result = []
    for _ in range(config["max_length"]):
        predictions = model(inputs=inputs, training=False)
        # softmax is monotonic, so the argmax below is unchanged by it;
        # kept to mirror "distribution over the vocabulary" semantics.
        predictions = tf.nn.softmax(predictions, axis=-1)
        # Only the last position matters for greedy next-token decoding.
        last_step = tf.squeeze(predictions[:, -1:, :], axis=1)
        pred = tf.argmax(input=last_step, axis=-1)
        next_id = pred.numpy()[0]
        if next_id == end_token_id:
            break
        result.append(next_id)
        # Feed the chosen token back in for the next decoding step.
        inputs = tf.concat([inputs, tf.expand_dims(pred, axis=0)], axis=-1)

    return data_utils.sequences_to_texts(result, tokenizer)
Пример #5
0
    def respond(self, req: str):
        """Answer an external chat request.

        Subclasses perform model inference and search here to produce the
        reply.

        Args:
            req: the input sentence.
        Returns: the system reply as a string.
        """
        # Preliminary preprocessing of the request.
        enc_input, dec_input = data_utils.preprocess_request(
            sentence=req,
            token=self.token,
            max_length=self.max_length,
            start_sign=self.start_sign,
            end_sign=self.end_sign)

        self.beam_search_container.reset(inputs=enc_input, dec_input=dec_input)
        enc_input, dec_input = self.beam_search_container.get_search_inputs()

        end_id = self.token.get(self.end_sign)
        for step in range(self.max_length):
            predictions = self._create_predictions(enc_input, dec_input, step)
            self.beam_search_container.expand(predictions=predictions,
                                              end_sign=end_id)
            # When the container's beam_size drops to zero, the requested
            # number of finished results has been found — stop searching.
            if self.beam_search_container.beam_size == 0:
                break
            enc_input, dec_input = self.beam_search_container.get_search_inputs()

        # Extract sequences from the container and build the final reply;
        # reversed join reproduces the original prepend-accumulation order.
        pieces = []
        for sequence in self.beam_search_container.get_result(top_k=3):
            decoded = data_utils.sequences_to_texts(sequence.numpy(), self.token)
            cleaned = (decoded[0].replace(self.start_sign, '')
                       .replace(self.end_sign, '').replace(' ', ''))
            pieces.append('<' + cleaned + '>')
        return ''.join(reversed(pieces))