Example #1
parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether use data parallel.")
args = parser.parse_args()
# yapf: enable.

if __name__ == '__main__':

    # Load the PaddleHub ERNIE Tiny pretrained model
    module = hub.Module(name="ernie_tiny")
    inputs, outputs, program = module.context(
        trainable=True, max_seq_len=args.max_seq_len)

    # Use the appropriate tokenizer to preprocess the dataset.
    # For ernie_tiny, the tokenizer performs word segmentation to produce subwords. More details: https://www.jiqizhixin.com/articles/2019-11-06-9
    if module.name == "ernie_tiny":
        tokenizer = hub.ErnieTinyTokenizer(
            vocab_file=module.get_vocab_path(),
            spm_path=module.get_spm_path(),
            word_dict_path=module.get_word_dict_path())
    else:
        tokenizer = hub.BertTokenizer(vocab_file=module.get_vocab_path())

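    # ChnSentiCorp is a Chinese sentence-level sentiment classification
    # dataset; the tokenizer encodes each example to at most max_seq_len tokens.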
    dataset = hub.dataset.ChnSentiCorp(
        tokenizer=tokenizer, max_seq_len=args.max_seq_len)

    # Construct transfer learning network
    # Use "pooled_output" for classification tasks on an entire sentence.
    # Use "sequence_output" for token-level output.
    pooled_output = outputs["pooled_output"]

    # Select the fine-tuning strategy, set up the run config, and fine-tune.
    strategy = hub.AdamWeightDecayStrategy(
        warmup_proportion=args.warmup_proportion,
        weight_decay=args.weight_decay,
        learning_rate=args.learning_rate)
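    # The remainder is a sketch following the standard PaddleHub 1.x demo
    # rather than this (truncated) example: wire the strategy into a RunConfig,
    # build a TextClassifierTask on pooled_output, and fine-tune with
    # evaluation.
    config = hub.RunConfig(
        use_data_parallel=args.use_data_parallel,
        use_cuda=args.use_gpu,
        num_epoch=args.num_epoch,
        batch_size=args.batch_size,
        checkpoint_dir=args.checkpoint_dir,
        strategy=strategy)

    cls_task = hub.TextClassifierTask(
        dataset=dataset,
        feature=pooled_output,
        num_classes=dataset.num_labels,
        config=config)
    cls_task.finetune_and_eval()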

Example #2

# This example assumes the imports below and a TransformerClassifier: a
# fluid.dygraph.Layer defined elsewhere in the original demo that wraps the
# ERNIE module with a classification head.
import argparse
import os

import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from paddle.fluid.dygraph import to_variable
from paddle.fluid.optimizer import AdamOptimizer


def finetune(args):
    module = hub.Module(name="ernie", max_seq_len=args.max_seq_len)
    # Use the appropriate tokenizer to preprocess the dataset.
    # For ernie_tiny, the tokenizer performs word segmentation to produce subwords. More details: https://www.jiqizhixin.com/articles/2019-11-06-9
    if module.name == "ernie_tiny":
        tokenizer = hub.ErnieTinyTokenizer(
            vocab_file=module.get_vocab_path(),
            spm_path=module.get_spm_path(),
            word_dict_path=module.get_word_dict_path(),
        )
    else:
        tokenizer = hub.BertTokenizer(vocab_file=module.get_vocab_path())
    dataset = hub.dataset.ChnSentiCorp(tokenizer=tokenizer,
                                       max_seq_len=args.max_seq_len)

    with fluid.dygraph.guard():
        tc = TransformerClassifier(num_classes=dataset.num_labels,
                                   transformer=module)
        adam = AdamOptimizer(learning_rate=1e-5,
                             parameter_list=tc.parameters())
        state_dict_path = os.path.join(args.checkpoint_dir,
                                       'dygraph_state_dict')
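        # Resume from a previously saved checkpoint, if one exists.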
        if os.path.exists(state_dict_path + '.pdparams'):
            state_dict, _ = fluid.load_dygraph(state_dict_path)
            tc.load_dict(state_dict)

        loss_sum = acc_sum = cnt = 0
        for epoch in range(args.num_epoch):
            for batch_id, data in enumerate(
                    dataset.batch_records_generator(
                        phase="train",
                        batch_size=args.batch_size,
                        shuffle=True,
                        pad_to_batch_max_seq_len=False)):
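                # Convert the batched records to numpy arrays with the shapes
                # ERNIE's inputs expect: int64 [batch, seq_len, 1] for the id
                # tensors and float32 [batch, seq_len, 1] for the input mask.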
                batch_size = len(data["input_ids"])
                input_ids = np.array(data["input_ids"]).astype(
                    np.int64).reshape([batch_size, -1, 1])
                position_ids = np.array(data["position_ids"]).astype(
                    np.int64).reshape([batch_size, -1, 1])
                segment_ids = np.array(data["segment_ids"]).astype(
                    np.int64).reshape([batch_size, -1, 1])
                input_mask = np.array(data["input_mask"]).astype(
                    np.float32).reshape([batch_size, -1, 1])
                labels = np.array(data["label"]).astype(np.int64).reshape(
                    [batch_size, 1])
                pred = tc(input_ids, position_ids, segment_ids, input_mask)

                acc = fluid.layers.accuracy(pred, to_variable(labels))
                loss = fluid.layers.cross_entropy(pred, to_variable(labels))
                avg_loss = fluid.layers.mean(loss)
                avg_loss.backward()
                adam.minimize(avg_loss)
                # Clear gradients so they do not accumulate across steps.
                tc.clear_gradients()

                loss_sum += avg_loss.numpy() * labels.shape[0]
                acc_sum += acc.numpy() * labels.shape[0]
                cnt += labels.shape[0]
                if batch_id % args.log_interval == 0:
                    print('epoch {}: loss {}, acc {}'.format(
                        epoch, loss_sum / cnt, acc_sum / cnt))
                    loss_sum = acc_sum = cnt = 0

                if batch_id % args.save_interval == 0:
                    state_dict = tc.state_dict()
                    fluid.save_dygraph(state_dict, state_dict_path)
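

# A minimal entry point for finetune() above (a sketch: the flag set simply
# mirrors what the function reads from args, with assumed defaults).
if __name__ == "__main__":
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument("--num_epoch", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--max_seq_len", type=int, default=128)
    parser.add_argument("--checkpoint_dir", type=str, default="./ckpt")
    parser.add_argument("--log_interval", type=int, default=10)
    parser.add_argument("--save_interval", type=int, default=100)
    finetune(parser.parse_args())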