def main():
    """Fine-tune and/or run prediction with a PaddleHub sequence-labeling task.

    Configuration is read from the module-level ``args`` namespace; the
    module-level ``schema_labels``, ``predict_data`` and ``predict_sents``
    are used as well.

    When ``args.do_train`` is set, the task is fine-tuned and evaluated.
    When ``args.do_predict`` is set, the predicted label sequence is stored
    under the ``"labels"`` key of each entry of ``predict_sents`` (the entries
    are mutated in place) and the entries are written, one JSON document per
    line, to ``"<args.predict_data>.<args.do_model>.pred"``.
    """
    # Load PaddleHub pretrained model.
    # More pretrained models:
    # https://www.paddlepaddle.org.cn/hublist?filter=en_category&value=SemanticModel
    # (for example, "ernie_tiny" can be substituted here).
    model_name = "chinese-roberta-wwm-ext-large"
    module = hub.Module(name=model_name)
    inputs, outputs, program = module.context(trainable=True,
                                              max_seq_len=args.max_seq_len)

    # Build the dataset and use SequenceLabelReader to read it
    dataset = EEDataset(args.data_dir, schema_labels, model=args.do_model)
    reader = hub.reader.SequenceLabelReader(
        dataset=dataset,
        vocab_path=module.get_vocab_path(),
        max_seq_len=args.max_seq_len,
        sp_model_path=module.get_spm_path(),
        word_dict_path=module.get_word_dict_path())

    # Construct transfer learning network
    # Use "sequence_output" for token-level output.
    sequence_output = outputs["sequence_output"]

    # Setup feed list for data feeder
    # Must feed all the tensors the module needs
    feed_list = [
        inputs["input_ids"].name, inputs["position_ids"].name,
        inputs["segment_ids"].name, inputs["input_mask"].name
    ]

    # Select a finetune strategy
    strategy = hub.AdamWeightDecayStrategy(
        warmup_proportion=args.warmup_proportion,
        weight_decay=args.weight_decay,
        learning_rate=args.learning_rate)

    # Setup running config for PaddleHub Finetune API
    config = hub.RunConfig(eval_interval=args.eval_step,
                           save_ckpt_interval=args.model_save_step,
                           use_data_parallel=args.use_data_parallel,
                           use_cuda=args.use_gpu,
                           num_epoch=args.num_epoch,
                           batch_size=args.batch_size,
                           checkpoint_dir=args.checkpoint_dir,
                           strategy=strategy)

    # Define a sequence labeling finetune task by PaddleHub's API
    # If add_crf is set, the network uses CRF as decoder
    seq_label_task = hub.SequenceLabelTask(data_reader=reader,
                                           feature=sequence_output,
                                           feed_list=feed_list,
                                           max_seq_len=args.max_seq_len,
                                           num_classes=dataset.num_labels,
                                           config=config,
                                           add_crf=args.add_crf)

    # Finetune and evaluate model by PaddleHub's API
    # will finish training, evaluation, testing, save model automatically
    if args.do_train:
        print("start finetune and eval process")
        seq_label_task.finetune_and_eval()

    if args.do_predict:
        print("start predict process")
        # Invert the reader's label -> id mapping so ids can be decoded.
        id2label = {val: key for key, val in reader.label_map.items()}
        input_data = [[d] for d in predict_data]
        # NOTE(review): the first entry of predict_data is skipped --
        # presumably a header row; confirm against data_process.
        run_states = seq_label_task.predict(data=input_data[1:])
        results = []
        for batch_states in run_states:
            batch_results = batch_states.run_results
            # Flatten the predicted label ids and the per-sequence lengths.
            batch_infers = batch_results[0].reshape([-1]).astype(
                np.int32).tolist()
            seq_lens = batch_results[1].reshape([-1]).astype(np.int32).tolist()
            current_id = 0
            for length in seq_lens:
                seq_infers = batch_infers[current_id:current_id + length]
                # Drop the first and last position of every sequence
                # (presumably the special start/end tokens -- TODO confirm).
                seq_result = list(map(id2label.get, seq_infers[1:-1]))
                # The offset advances by the real length when CRF is enabled
                # and by max_seq_len otherwise.
                current_id += length if args.add_crf else args.max_seq_len
                results.append(seq_result)

        ret = []
        # zip() stops at the shorter input, so unmatched entries are dropped.
        for sent, r_label in zip(predict_sents, results):
            sent["labels"] = r_label
            ret.append(json.dumps(sent, ensure_ascii=False))
        write_by_lines("{}.{}.pred".format(args.predict_data, args.do_model),
                       ret)
# Command-line options for batching, CRF decoding, checkpointing and data
# parallelism. The boolean options use ast.literal_eval as their type so that
# the strings "True"/"False" are parsed as Python literals.
parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number in batch for training.")
parser.add_argument("--add_crf", type=ast.literal_eval, default=True, help="add crf")
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether use data parallel.")
# Parsed at module level; `args` is read as a global by the rest of the script.
args = parser.parse_args()
# yapf: enable.

# Preprocess all the data first and save it to disk
train_data = data_process(args.train_data, args.do_model)  # process the training data
dev_data = data_process(args.dev_data, args.do_model)  # process the dev data
test_data = data_process(args.test_data, args.do_model)
# With is_predict=True the result is unpacked as a pair:
# (predict_sents, predict_data).
predict_sents, predict_data = data_process(args.predict_data,
                                           args.do_model,
                                           is_predict=True)

# Write each processed split to "<data_dir>/<do_model>_<split>.tsv".
write_by_lines("{}/{}_train.tsv".format(args.data_dir, args.do_model),
               train_data)
write_by_lines("{}/{}_dev.tsv".format(args.data_dir, args.do_model), dev_data)
write_by_lines("{}/{}_test.tsv".format(args.data_dir, args.do_model),
               test_data)
write_by_lines("{}/{}_predict.tsv".format(args.data_dir, args.do_model),
               predict_data)

# Labels derived from the schema file for the selected model (args.do_model);
# main() passes them to EEDataset.
schema_labels = schema_process(args.schema_path, args.do_model)


class EEDataset(BaseNLPDataset):
    """EEDataset"""
    def __init__(self, data_dir, labels, model="trigger"):
        # 数据集存放位置
        super(EEDataset, self).__init__(
            base_path=data_dir,