def main():
    """Fine-tune a PaddleHub sequence-labeling task and optionally predict.

    All settings are read from the module-level ``args``. The dataset is
    built from ``args.data_dir`` and the module-level ``schema_labels``.

    * When ``args.do_train`` is truthy, the task is fine-tuned and evaluated.
    * When ``args.do_predict`` is truthy, ``predict_data`` is labeled, each
      entry of ``predict_sents`` gets a ``"labels"`` key, and the
      JSON-serialized entries are handed to ``write_by_lines`` with the path
      ``"<args.predict_data>.<args.do_model>.pred"``.
    """
    # Pretrained backbone loaded through PaddleHub. More pretrained models:
    # https://www.paddlepaddle.org.cn/hublist?filter=en_category&value=SemanticModel
    # Alternative model name: "ernie_tiny".
    backbone = hub.Module(name="chinese-roberta-wwm-ext-large")
    inputs, outputs, _ = backbone.context(
        trainable=True,
        max_seq_len=args.max_seq_len,
    )

    # Dataset plus the reader that turns it into model inputs.
    corpus = EEDataset(args.data_dir, schema_labels, model=args.do_model)
    reader = hub.reader.SequenceLabelReader(
        dataset=corpus,
        vocab_path=backbone.get_vocab_path(),
        max_seq_len=args.max_seq_len,
        sp_model_path=backbone.get_spm_path(),
        word_dict_path=backbone.get_word_dict_path(),
    )

    # "sequence_output" is the token-level output of the module.
    token_features = outputs["sequence_output"]

    # Names of every input tensor handed to the task as its feed list.
    feed_keys = ("input_ids", "position_ids", "segment_ids", "input_mask")
    feed_list = [inputs[key].name for key in feed_keys]

    # Fine-tuning strategy.
    strategy = hub.AdamWeightDecayStrategy(
        warmup_proportion=args.warmup_proportion,
        weight_decay=args.weight_decay,
        learning_rate=args.learning_rate,
    )

    # Run configuration for the PaddleHub fine-tune API.
    config = hub.RunConfig(
        eval_interval=args.eval_step,
        save_ckpt_interval=args.model_save_step,
        use_data_parallel=args.use_data_parallel,
        use_cuda=args.use_gpu,
        num_epoch=args.num_epoch,
        batch_size=args.batch_size,
        checkpoint_dir=args.checkpoint_dir,
        strategy=strategy,
    )

    # Sequence-labeling task; ``add_crf`` toggles the CRF decoder.
    task = hub.SequenceLabelTask(
        data_reader=reader,
        feature=token_features,
        feed_list=feed_list,
        max_seq_len=args.max_seq_len,
        num_classes=corpus.num_labels,
        config=config,
        add_crf=args.add_crf,
    )

    if args.do_train:
        print("start finetune and eval process")
        task.finetune_and_eval()

    if not args.do_predict:
        return

    print("start predict process")
    id_to_tag = {idx: tag for tag, idx in reader.label_map.items()}

    # Each sample is wrapped in its own list; the first entry is skipped
    # (presumably a header row -- TODO confirm against data_process).
    wrapped = [[row] for row in predict_data]
    run_states = task.predict(data=wrapped[1:])

    def as_int_list(array):
        # Flatten to 1-D, cast to int32 and convert to a plain Python list.
        return array.reshape([-1]).astype(np.int32).tolist()

    decoded = []
    for state in run_states:
        fetched = state.run_results
        tag_ids = as_int_list(fetched[0])
        lengths = as_int_list(fetched[1])

        offset = 0
        for length in lengths:
            window = tag_ids[offset:offset + length]
            # The first and last positions of each sequence are dropped
            # (presumably the special start/end tokens -- TODO confirm).
            decoded.append([id_to_tag.get(tag_id) for tag_id in window[1:-1]])
            # Advance by the sequence length when CRF decoding is on,
            # otherwise by the fixed ``args.max_seq_len`` stride.
            offset += length if args.add_crf else args.max_seq_len

    # Attach the decoded labels to each sentence and serialize it.
    lines = []
    for sent, tags in zip(predict_sents, decoded):
        sent["labels"] = tags
        lines.append(json.dumps(sent, ensure_ascii=False))

    out_path = "{}.{}.pred".format(args.predict_data, args.do_model)
    write_by_lines(out_path, lines)
parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number in batch for training.") parser.add_argument("--add_crf", type=ast.literal_eval, default=True, help="add crf") parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint") parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether use data parallel.") args = parser.parse_args() # yapf: enable. # 先把数据处理好保存下来 train_data = data_process(args.train_data, args.do_model) # 处理训练数据 dev_data = data_process(args.dev_data, args.do_model) # 处理dev数据 test_data = data_process(args.test_data, args.do_model) predict_sents, predict_data = data_process(args.predict_data, args.do_model, is_predict=True) write_by_lines("{}/{}_train.tsv".format(args.data_dir, args.do_model), train_data) write_by_lines("{}/{}_dev.tsv".format(args.data_dir, args.do_model), dev_data) write_by_lines("{}/{}_test.tsv".format(args.data_dir, args.do_model), test_data) write_by_lines("{}/{}_predict.tsv".format(args.data_dir, args.do_model), predict_data) schema_labels = schema_process(args.schema_path, args.do_model) class EEDataset(BaseNLPDataset): """EEDataset""" def __init__(self, data_dir, labels, model="trigger"): # 数据集存放位置 super(EEDataset, self).__init__( base_path=data_dir,