Example #1
# Assumed imports for this snippet (kashgari 1.x / TF 1.x era; RAdam from the
# keras-radam package). KashModel, EvalCallBack and ROOT_DIR are project-local.
import argparse
import os
import pickle

import kashgari
from kashgari.embeddings import BERTEmbedding
from keras_radam import RAdam
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard


def train():
    parser = argparse.ArgumentParser()
    parser.add_argument('model_dir', help='model dir')
    args = parser.parse_args()

    model_dir = args.model_dir
    hdf_dir = os.path.join(model_dir, "hdf5")
    os.makedirs(hdf_dir, exist_ok=True)

    bert_model_path = os.path.join(ROOT_DIR, 'BERT-baseline')
    data_path = os.path.join(model_dir, "feature.pkl")
    with open(data_path, 'rb') as fr:
        train_data, train_label, test_data, test_label = pickle.load(fr)
    print("load {}/{} train/dev items ".format(len(train_data),
                                               len(test_data)))

    bert_embed = BERTEmbedding(bert_model_path,
                               task=kashgari.LABELING,
                               sequence_length=50)
    model = KashModel(bert_embed)
    model.build_model(x_train=train_data,
                      y_train=train_label,
                      x_validate=test_data,
                      y_validate=test_label)

    from src.get_model_path import get_model_path
    model_path, init_epoch = get_model_path(hdf_dir)
    if init_epoch > 0:
        print("load epoch from {}".format(model_path))
        model.tf_model.load_weights(model_path)

    optimizer = RAdam(learning_rate=0.0001)
    model.compile_model(optimizer=optimizer)

    hdf5_path = os.path.join(hdf_dir,
                             "crf-{epoch:03d}-{val_accuracy:.3f}.hdf5")
    checkpoint = ModelCheckpoint(hdf5_path,
                                 monitor='val_accuracy',
                                 verbose=1,
                                 save_best_only=True,
                                 save_weights_only=False,
                                 mode='auto',
                                 period=1)
    tensorboard = TensorBoard(log_dir=os.path.join(model_dir, "log"))
    eval_callback = EvalCallBack(kash_model=model,
                                 valid_x=test_data,
                                 valid_y=test_label,
                                 step=1,
                                 log_path=os.path.join(model_dir, "acc.txt"))
    callbacks = [checkpoint, tensorboard, eval_callback]

    model.fit(train_data,
              train_label,
              x_validate=test_data,
              y_validate=test_label,
              epochs=100,
              batch_size=256,
              callbacks=callbacks)
    return
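
train() resumes from the newest checkpoint via src.get_model_path. A minimal
sketch of that helper, assuming the crf-{epoch:03d}-... naming used above
(hypothetical implementation; the project's real one may differ):

import os
import re


def get_model_path(hdf_dir):
    """Return (path, epoch) of the newest 'crf-EEE-*.hdf5' checkpoint,
    or ("", 0) if no checkpoint has been written yet."""
    pattern = re.compile(r"crf-(\d{3})-.*\.hdf5$")
    best_epoch, best_path = 0, ""
    for name in os.listdir(hdf_dir):
        match = pattern.match(name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch = int(match.group(1))
            best_path = os.path.join(hdf_dir, name)
    return best_path, best_epoch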
Example #2
# Assumed imports (kashgari 1.x / TF 1.x era); read_lines, split_phone_format,
# h5_to_pb, sandhi and change_qingyin are project-local helpers.
import os
import pickle

import kashgari
import numpy as np
import tensorflow as tf
from kashgari.embeddings import BERTEmbedding
from kashgari.tasks.labeling import BiLSTM_CRF_Model
from tensorflow.keras.backend import set_session


class BertPolyPhone:
  """Main class for polyphone (pinyin) prediction."""
  def __init__(self):
    super().__init__()
    self.poly_dict = dict()
    poly_dict_path = "/data1/liufeng/synthesis/frontend/data/simple_poly_dict"
    for line in read_lines(poly_dict_path):
      line = line.replace(" ", "").replace("*", "")
      key = line.split(":")[0]
      value = line.split(":")[1].split(",")
      self.poly_dict[key] = value
    self.model, self.model_dir = None, None
    self.sess = None
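
  # The parsing above expects one entry per line, "char: pinyin1,pinyin2,...",
  # with spaces and '*' markers stripped. Illustrative (hypothetical) lines of
  # simple_poly_dict:
  #   行: hang2,xing2
  #   乐: le4,yue4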

  def initialize_model(self, bert_model_path, poly_model_path):
    print('=============init phone model=========================')
    print("bert model path:", bert_model_path)
    print("crf model path:", poly_model_path)
    # the training-data path is needed to rebuild the dictionaries
    self.sess = tf.Session()
    set_session(self.sess)
    self.model_dir = os.path.dirname(os.path.dirname(poly_model_path))
    data_path = os.path.join(self.model_dir, "feature.pkl")

    with open(data_path, 'rb') as fr:
      train_data, train_label, test_data, test_label = pickle.load(fr)

    bert_embed = BERTEmbedding(bert_model_path, task=kashgari.LABELING,
                               sequence_length=50)
    self.model = BiLSTM_CRF_Model(bert_embed)

    self.model.build_model(x_train=train_data, y_train=train_label,
                           x_validate=test_data, y_validate=test_label)
    self.model.compile_model()
    self.model.tf_model.load_weights(poly_model_path)
    print('=============successful loaded=========================')

  def _lookup_dict(self, bert_result, pred_ph_pairs):
    """查字典的方法对拼音进行修正 """
    # todo: 如果词在词典中,不用bert的结果。
    bert_phone_result = []
    for index_c, (char, ph, _) in enumerate(pred_ph_pairs):
      if char in self.poly_dict.keys():
        # if the BERT prediction is not in the polyphone dict, it went astray
        if bert_result[index_c] not in self.poly_dict[char]:
          bert_phone_result.append((char, ph))
        else:
          bert_result[index_c] = split_phone_format(bert_result[index_c])
          bert_phone_result.append((char, bert_result[index_c]))
          if ph != bert_result[index_c]:
            print("using bert result {}:{} instead of {}".format(
              char, bert_result[index_c], ph))
      else:
        bert_phone_result.append((char, ph))
    return bert_phone_result

  def predict(self, sentence_list):
    """ 通过句子预测韵律,标点断开 """
    bert_input = []
    for sent in sentence_list:
      assert len(sent) < 50
      bert_input.append([c for c in sent])
    print("bert-input:", bert_input)
    prosody = self.model.predict(bert_input)
    return prosody
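
  # Minimal usage sketch (paths are placeholders):
  #   poly = BertPolyPhone()
  #   poly.initialize_model("/path/to/BERT-baseline",
  #                         "/path/to/model_dir/hdf5/crf-010-0.980.hdf5")
  #   print(poly.predict(["今天天气不错"]))  # one label sequence per sentence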

  def save_pb(self):
    self._write_dict()
    pb_dir = os.path.join(self.model_dir, "pb")
    os.makedirs(pb_dir, exist_ok=True)
    h5_to_pb(self.model.tf_model, pb_dir, self.sess, "model_phone.pb",
             ["output_phone"])
    return

  def _write_dict(self):
    label_path = os.path.join(self.model_dir, "pb/phone_idx2label.txt")
    with open(label_path, "w", encoding="utf-8") as fw:
      for key, value in self.model.embedding.label2idx.items():
        fw.write("{} {}\n".format(value, key))
    print("write {}".format(label_path))

    token_path = os.path.join(self.model_dir, "pb/phone_token2idx.txt")
    with open(token_path, "w", encoding="utf-8") as fw:
      for key, value in self.model.embedding.token2idx.items():
        if len(key) > 0:
          fw.write("{} {}\n".format(key, value))
    print("write {}".format(token_path))
    return

  def compute_embed(self, sentence_list):
    bert_input = [[c for c in sent] for sent in sentence_list]
    print("bert-input:", bert_input)
    tensor = self.model.embedding.process_x_dataset(bert_input)
    print("debug:", np.shape(tensor), tensor)
    res = self.model.tf_model.predict(tensor)
    print("debug:", np.shape(res), res[0][0: len(sentence_list[0]) + 1])
    return tensor

  @staticmethod
  def _merge_eng_char(bert_phone_result, dict_phone_pairs):
    from src.utils import check_all_chinese
    index = 0
    new_bert_phone = []
    for word, _, _ in dict_phone_pairs:
      if (not check_all_chinese(word)) and len(word) > 1:
        new_bert_phone.append(bert_phone_result[index])
        index += len(word)
      else:
        new_bert_phone.append(bert_phone_result[index])
        index += 1
    return new_bert_phone
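
  # For intuition: a non-Chinese token spans len(word) per-character BERT
  # positions but contributes one merged entry, so only its first position's
  # result is kept (illustrative values):
  #   bert = [("O", "ou1"), ("K", "kei1"), ("好", "hao3")]
  #   pairs = [("OK", "ou1kei1", None), ("好", "hao3", None)]
  #   BertPolyPhone._merge_eng_char(bert, pairs)
  #   # -> [("O", "ou1"), ("好", "hao3")]; "OK" consumed two positions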

  def modify_result(self, bert_result, dict_phone_pairs):
    bert_result = self._merge_eng_char(bert_result, dict_phone_pairs)
    bert_phone_pairs = self._lookup_dict(bert_result, dict_phone_pairs)
    phone_pairs = bert_phone_pairs
    # phone_pairs = change_yi(phone_pairs)
    # phone_pairs = change_bu(phone_pairs)
    phone_pairs = sandhi(phone_pairs)
    bert_result = [ph for _, ph in phone_pairs]
    chars = "".join([c for c, _ in phone_pairs])
    bert_result = change_qingyin(bert_result, chars)
    return bert_result
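
A sketch of how modify_result fits after prediction; dict_phone_pairs would
come from a dictionary front end that is not shown here (placeholder names):

poly = BertPolyPhone()
poly.initialize_model(bert_model_path, poly_model_path)
bert_labels = poly.predict(["今天天气不错"])[0]
phones = poly.modify_result(bert_labels, dict_phone_pairs)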
Example #3
# Assumed imports (kashgari 1.x / TF 1.x era), as in Example #2; h5_to_pb is
# a project-local helper.
import os
import pickle

import kashgari
import numpy as np
import tensorflow as tf
from kashgari.embeddings import BERTEmbedding
from kashgari.tasks.labeling import BiLSTM_CRF_Model
from tensorflow.keras.backend import set_session


class BertProsody:
  """Currently only supports length 50: 49 input characters + end token."""
  def __init__(self):
    self.model, self.model_dir, self.model_path = None, None, None
    self.sess = None
    return

  def initial_model(self, bert_model_path, psd_model_path):
    print('=============init bert model=========================')
    print("bert model path:", bert_model_path)
    print("crf model path:", psd_model_path)
    self.sess = tf.Session()
    set_session(self.sess)
    self.model_dir = os.path.dirname(os.path.dirname(psd_model_path))
    self.model_path = psd_model_path
    data_path = os.path.join(self.model_dir, "feature_psd.pkl")
    with open(data_path, 'rb') as fr:
      train_data, train_label, test_data, test_label = pickle.load(fr)

    bert_embed = BERTEmbedding(bert_model_path, task=kashgari.LABELING,
                               sequence_length=50)
    self.model = BiLSTM_CRF_Model(bert_embed)
    self.model.build_model(x_train=train_data, y_train=train_label,
                           x_validate=test_data, y_validate=test_label)
    self.model.compile_model()
    self.model.tf_model.load_weights(psd_model_path)
    print('=============bert model loaded=========================')
    return

  def _write_dict(self):
    label_path = os.path.join(self.model_dir, "idx2label.txt")
    with open(label_path, "w", encoding="utf-8") as fw:
      for key, value in self.model.embedding.label2idx.items():
        fw.write("{} {}\n".format(value, key))

    token_path = os.path.join(self.model_dir, "token2idx.txt")
    with open(token_path, "w", encoding="utf-8") as fw:
      for key, value in self.model.embedding.token2idx.items():
        if len(key) > 0:
          fw.write("{} {}\n".format(key, value))

  def predict(self, sentence_list):
    """ 通过句子预测韵律,标点断开 """
    bert_input = []
    for sent in sentence_list:
      assert len(sent) < 50
      bert_input.append([c for c in sent])
    print("bert-input:", bert_input)
    prosody = self.model.predict(bert_input)
    return prosody

  def compute_embed(self, sentence_list):
    bert_input = [[c for c in sent] for sent in sentence_list]
    print("bert-input:", bert_input)
    tensor = self.model.embedding.process_x_dataset(bert_input)
    res = self.model.tf_model.predict(tensor)
    print("debug:", np.shape(res), res[0])
    return tensor

  def save_pb(self):
    self._write_dict()
    pb_dir = os.path.join(self.model_dir, "pb")
    os.makedirs(pb_dir, exist_ok=True)
    # [print(n.name) for n in tf.get_default_graph().as_graph_def().node]
    h5_to_pb(self.model.tf_model, pb_dir, self.sess, "model_psd.pb",
             ["output_psd"])
    return

  @staticmethod
  def change_by_rules(old_pairs):
    """ 强制规则:
    1. 逗号之前是#3,句号之前是#4
    2. 其他位置,#3 -> #2
    """
    new_pairs = []
    for i, (char, ph, psd) in enumerate(old_pairs[0:-1]):
      next_char, _, _ = old_pairs[i+1]
      if next_char == ",":
        new_pairs.append((char, ph, "3"))
      elif next_char in ["。", "?", "!"]:
        new_pairs.append((char, ph, "4"))
      else:
        if psd == "3":
          new_pairs.append((char, ph, "2"))
        else:
          new_pairs.append((char, ph, psd))
    new_pairs.append(old_pairs[-1])
    return new_pairs
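
A worked example of the rules (illustrative tuples):

old = [("你", "ni3", "0"), ("好", "hao3", "3"),
       ("世", "shi4", "0"), ("界", "jie4", "0"), ("。", "sil", "0")]
new = BertProsody.change_by_rules(old)
# "好" drops from #3 to #2 (rule 2); "界" becomes #4 because the next token
# is "。" (rule 1); the final token is kept unchanged.
# new == [("你", "ni3", "0"), ("好", "hao3", "2"), ("世", "shi4", "0"),
#         ("界", "jie4", "4"), ("。", "sil", "0")]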