Exemplo n.º 1
0
def do_train(args):
    """Run single-device static-graph BERT pretraining.

    Builds the BERT pretraining program, an AdamW optimizer with linear
    warmup followed by linear decay (optionally decorated for automatic mixed
    precision), and an executor. It then iterates over the files in
    ``args.input_dir`` whose names contain ``"training"`` until
    ``args.max_steps`` steps have run. Loss and throughput are printed every
    ``args.logging_steps`` steps; parameters and the tokenizer are saved
    every ``args.save_steps`` steps.

    Args:
        args: Parsed command-line namespace. Fields read directly here:
            ``select_device``, ``seed``, ``model_type``,
            ``model_name_or_path``, ``learning_rate``, ``warmup_steps``,
            ``max_steps``, ``adam_epsilon``, ``weight_decay``, ``use_amp``,
            ``scale_loss``, ``input_dir``, ``max_predictions_per_seq``,
            ``batch_size``, ``logging_steps``, ``save_steps`` and
            ``output_dir``. The helpers called here may read more fields.

    Raises:
        ValueError: If ``args.max_steps`` is not positive. ``max_steps`` is
            both the only stopping condition and the LR-schedule horizon.
            A non-positive value previously crashed with
            ``UnboundLocalError`` while building the scheduler.
        FileNotFoundError: If ``args.input_dir`` contains no training files
            (previously the epoch loop spun forever).
    """
    if args.max_steps <= 0:
        raise ValueError("args.max_steps must be a positive integer, got %r" %
                         (args.max_steps, ))

    # Initialize the paddle execute environment
    paddle.enable_static()
    place = paddle.set_device(args.select_device)

    # Set the random seed
    set_seed(args.seed)

    # Define the input data in the static mode
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    data_holders = create_data_holder(args)
    [
        input_ids, segment_ids, input_mask, masked_lm_positions,
        masked_lm_labels, next_sentence_labels, masked_lm_scale
    ] = data_holders

    # Define the model structure in static mode
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    # Copy so the padding below does not mutate the class-level config dict.
    config = dict(
        model_class.pretrained_init_configuration[args.model_name_or_path])
    # Round vocab_size up to the next multiple of 8 (presumably for fp16
    # kernel alignment - confirm).
    if config["vocab_size"] % 8 != 0:
        config["vocab_size"] += 8 - (config["vocab_size"] % 8)
    model = BertForPretraining(BertModel(**config))
    criterion = BertPretrainingCriterion(model.bert.config["vocab_size"])
    prediction_scores, seq_relationship_score = model(
        input_ids=input_ids,
        token_type_ids=segment_ids,
        attention_mask=input_mask,
        masked_positions=masked_lm_positions)
    loss = criterion(prediction_scores, seq_relationship_score,
                     masked_lm_labels, next_sentence_labels, masked_lm_scale)

    # Learning-rate schedule: linear warmup over warmup_steps, then linear
    # decay to 0 at max_steps.
    num_warmup_steps = args.warmup_steps
    num_training_steps = args.max_steps

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0,
            float(num_training_steps - current_step) /
            float(max(1, num_training_steps - num_warmup_steps)))

    lr_scheduler = paddle.optimizer.lr.LambdaDecay(args.learning_rate,
                                                   lr_lambda)

    # Names of parameters that receive weight decay (bias and norm params are
    # excluded). Built once so each membership check is O(1).
    decay_params = {
        p.name
        for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    }
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)
    if args.use_amp:
        amp_list = paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
            custom_white_list=['layer_norm', 'softmax', 'gelu'])
        optimizer = paddle.fluid.contrib.mixed_precision.decorate(
            optimizer,
            amp_list,
            init_loss_scaling=args.scale_loss,
            use_dynamic_loss_scaling=True)
    optimizer.minimize(loss)

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    state_dict = model.state_dict()

    # Use the state dict to update the parameter
    reset_state_dict = reset_program_state_dict(model, state_dict)
    paddle.static.set_program_state(main_program, reset_state_dict)
    # Construct the compiled program
    main_program = build_compiled_program(args, main_program, loss)

    global_step = 0
    tic_train = time.time()
    epoch = 0
    while True:
        files = [
            os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
            if os.path.isfile(os.path.join(args.input_dir, f))
            and "training" in f
        ]
        if not files:
            raise FileNotFoundError("No training files found in %s" %
                                    args.input_dir)
        files.sort()
        # Deterministic per-epoch shuffle of the file order.
        random.Random(args.seed + epoch).shuffle(files)

        for data_file in files:
            train_data_loader, _ = create_pretraining_dataset(
                data_file, args.max_predictions_per_seq, args, data_holders)
            for step, batch in enumerate(train_data_loader):
                global_step += 1
                loss_return = exe.run(main_program,
                                      feed=batch,
                                      fetch_list=[loss])
                # In the new 2.0 api, must call this function to change the learning_rate
                lr_scheduler.step()
                if global_step % args.logging_steps == 0:
                    time_cost = time.time() - tic_train
                    print(
                        "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s, ips: %.2f sequences/s"
                        % (global_step, epoch, step, loss_return[0],
                           args.logging_steps / time_cost,
                           args.logging_steps * args.batch_size / time_cost))
                    tic_train = time.time()
                if global_step % args.save_steps == 0:
                    output_dir = os.path.join(args.output_dir,
                                              "model_%d" % global_step)
                    os.makedirs(output_dir, exist_ok=True)
                    # TODO(fangzeyang): Update the save_params to paddle.static
                    paddle.fluid.io.save_params(exe, output_dir)
                    tokenizer.save_pretrained(output_dir)
                if global_step >= args.max_steps:
                    del train_data_loader
                    return
            del train_data_loader
        epoch += 1
Exemplo n.º 2
0
def do_train(args):
    """Run static-graph BERT pretraining under paddle fleet (collective mode).

    With a single worker the optimizer may be decorated for automatic mixed
    precision and the program is compiled via ``build_compiled_program``.
    With several workers the optimizer is wrapped by ``dist_optimizer``.
    Each worker streams the files selected for it by
    ``select_dataset_file_for_each_worker``, prefetching the next file's
    DataLoader on a background thread, until ``args.max_steps`` steps have
    run. Worker 0 saves parameters and the tokenizer every
    ``args.save_steps`` steps.

    Args:
        args: Parsed command-line namespace (``select_device``, ``seed``,
            ``model_type``, ``model_name_or_path``, ``learning_rate``,
            ``warmup_steps``, ``max_steps``, ``adam_epsilon``,
            ``weight_decay``, ``use_amp``, ``use_pure_fp16``,
            ``scale_loss``, ``input_dir``, ``max_predictions_per_seq``,
            ``batch_size``, ``logging_steps``, ``save_steps``,
            ``output_dir`` are read here).

    Raises:
        ValueError: If ``args.max_steps`` is not positive (it is the only
            stopping condition and the LR horizon; previously this crashed
            with ``UnboundLocalError``).
        FileNotFoundError: If ``args.input_dir`` contains no training files.
    """
    if args.max_steps <= 0:
        raise ValueError("args.max_steps must be a positive integer, got %r" %
                         (args.max_steps, ))

    # Initialize the paddle and paddle fleet execute environment
    paddle.enable_static()
    place = paddle.set_device(args.select_device)
    fleet.init(is_collective=True)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()

    # Create the random seed for the worker
    set_seed(args.seed)
    worker_init = WorkerInitObj(args.seed + worker_index)

    # Define the input data in the static mode
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    data_holders = create_data_holder(args)

    [
        input_ids, segment_ids, input_mask, masked_lm_positions,
        masked_lm_labels, next_sentence_labels, masked_lm_scale
    ] = data_holders

    # Define the model structure in static mode
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    # Copy so the padding below does not mutate the class-level config dict.
    config = dict(
        model_class.pretrained_init_configuration[args.model_name_or_path])
    # Round vocab_size up to the next multiple of 8 (presumably for fp16
    # kernel alignment - confirm).
    if config["vocab_size"] % 8 != 0:
        config["vocab_size"] += 8 - (config["vocab_size"] % 8)
    model = BertForPretraining(BertModel(**config))
    criterion = BertPretrainingCriterion(model.bert.config["vocab_size"])
    prediction_scores, seq_relationship_score = model(
        input_ids=input_ids,
        token_type_ids=segment_ids,
        attention_mask=input_mask,
        masked_positions=masked_lm_positions)
    loss = criterion(prediction_scores, seq_relationship_score,
                     masked_lm_labels, next_sentence_labels, masked_lm_scale)

    # Learning-rate schedule: linear warmup over warmup_steps, then linear
    # decay to 0 at max_steps.
    num_warmup_steps = args.warmup_steps
    num_training_steps = args.max_steps

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0,
            float(num_training_steps - current_step) /
            float(max(1, num_training_steps - num_warmup_steps)))

    lr_scheduler = paddle.optimizer.lr.LambdaDecay(args.learning_rate,
                                                   lr_lambda)

    # Names of parameters that receive weight decay (bias and norm params are
    # excluded). Built once so each membership check is O(1).
    decay_params = {
        p.name
        for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    }
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params,
        multi_precision=args.use_pure_fp16)
    if worker_num == 1 and args.use_amp:
        custom_black_list = (['lookup_table', 'lookup_table_v2']
                             if args.use_pure_fp16 else None)
        amp_list = paddle.static.amp.AutoMixedPrecisionLists(
            custom_white_list=['softmax', 'layer_norm', 'gelu'],
            custom_black_list=custom_black_list)
        optimizer = paddle.static.amp.decorate(
            optimizer,
            amp_list,
            init_loss_scaling=args.scale_loss,
            use_dynamic_loss_scaling=True,
            use_pure_fp16=args.use_pure_fp16)

    if worker_num > 1:
        # Use the fleet api to compile the distributed optimizer
        optimizer = dist_optimizer(args, optimizer)
    optimizer.minimize(loss)

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    state_dict = model.state_dict()

    # Use the state dict to update the parameter
    reset_state_dict = reset_program_state_dict(model, state_dict)
    paddle.static.set_program_state(main_program, reset_state_dict)
    if args.use_amp:
        optimizer.amp_init(place)

    if worker_num == 1:
        # Construct the compiled program
        main_program = build_compiled_program(main_program, loss)

    # Single background thread that builds the next file's DataLoader while
    # the current one is being trained on.
    pool = ThreadPoolExecutor(1)
    global_step = 0
    # Timing accumulators cover exactly one logging window; they are reset
    # only after each log line, not at file boundaries.
    train_reader_cost = 0.0
    train_run_cost = 0.0
    total_samples = 0
    epoch = 0
    while True:
        files = [
            os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
            if os.path.isfile(os.path.join(args.input_dir, f))
            and "training" in f
        ]
        if not files:
            raise FileNotFoundError("No training files found in %s" %
                                    args.input_dir)
        files.sort()
        # Deterministic per-epoch shuffle of the file order.
        random.Random(args.seed + epoch).shuffle(files)
        f_start_id = 0

        # Select one file for each worker and create the DataLoader for the file
        data_file = select_dataset_file_for_each_worker(
            files, f_start_id, worker_num, worker_index)
        train_data_loader, _ = create_pretraining_dataset(
            data_file, args.max_predictions_per_seq, args, data_holders,
            worker_init, paddle.static.cuda_places())

        # Run one iteration past the last file index: each iteration trains
        # on the already-loaded loader and prefetches the next one (if any),
        # so the final file - or the only file - is also consumed.
        for f_id in range(f_start_id + 1, len(files) + 1):
            dataset_future = None
            if f_id < len(files):
                data_file = select_dataset_file_for_each_worker(
                    files, f_id, worker_num, worker_index)
                dataset_future = pool.submit(create_pretraining_dataset,
                                             data_file,
                                             args.max_predictions_per_seq,
                                             args, data_holders, worker_init,
                                             paddle.static.cuda_places())

            reader_start = time.time()
            for step, batch in enumerate(train_data_loader):
                train_reader_cost += time.time() - reader_start
                global_step += 1
                train_start = time.time()
                loss_return = exe.run(main_program,
                                      feed=batch,
                                      fetch_list=[loss])
                train_run_cost += time.time() - train_start
                total_samples += args.batch_size
                # In the new 2.0 api, must call this function to change the learning_rate
                lr_scheduler.step()
                if global_step % args.logging_steps == 0:
                    print(
                        "tobal step: %d, epoch: %d, batch: %d, loss: %f, "
                        "avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f sequences/sec"
                        % (global_step, epoch, step, loss_return[0],
                           train_reader_cost / args.logging_steps,
                           (train_reader_cost + train_run_cost) /
                           args.logging_steps,
                           total_samples / args.logging_steps, total_samples /
                           (train_reader_cost + train_run_cost)))
                    train_reader_cost = 0.0
                    train_run_cost = 0.0
                    total_samples = 0
                if global_step % args.save_steps == 0 and worker_index == 0:
                    output_dir = os.path.join(args.output_dir,
                                              "model_%d" % global_step)
                    os.makedirs(output_dir, exist_ok=True)
                    # TODO(fangzeyang): Update the save_params to paddle.static
                    paddle.fluid.io.save_params(exe, output_dir)
                    tokenizer.save_pretrained(output_dir)
                if global_step >= args.max_steps:
                    del train_data_loader
                    return
                reader_start = time.time()
            del train_data_loader
            if dataset_future is not None:
                train_data_loader, _ = dataset_future.result(timeout=None)
        epoch += 1
Exemplo n.º 3
0
def do_train(args):
    """Run static-graph BERT pretraining under paddle fleet (collective mode).

    The optimizer is always wrapped by ``dist_optimizer`` and uses a
    ``LinearDecayWithWarmup`` schedule. Each worker streams the files
    selected for it by ``select_dataset_file_for_each_worker``, prefetching
    the next file's DataLoader on a background thread, until
    ``args.max_steps`` steps have run. Reader/batch timing is logged every
    ``args.logging_steps`` steps. Worker 0 saves the model config, program
    state and tokenizer every ``args.save_steps`` steps.

    Args:
        args: Parsed command-line namespace (``device``, ``seed``,
            ``model_type``, ``model_name_or_path``, ``learning_rate``,
            ``warmup_steps``, ``max_steps``, ``adam_epsilon``,
            ``weight_decay``, ``use_pure_fp16``, ``use_amp``, ``input_dir``,
            ``max_predictions_per_seq``, ``batch_size``,
            ``profiler_options``, ``logging_steps``, ``save_steps``,
            ``output_dir`` are read here).

    Raises:
        ValueError: If ``args.max_steps`` is not positive (it is the only
            stopping condition and the LR horizon; previously this crashed
            with ``UnboundLocalError`` on ``train_data_loader``).
        FileNotFoundError: If ``args.input_dir`` contains no training files.
    """
    if args.max_steps <= 0:
        raise ValueError("args.max_steps must be a positive integer, got %r" %
                         (args.max_steps, ))

    # Initialize the paddle and paddle fleet execute environment
    paddle.enable_static()
    place = paddle.set_device(args.device)
    fleet.init(is_collective=True)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()

    # Create the random seed for the worker
    set_seed(args.seed)
    worker_init = WorkerInitObj(args.seed + worker_index)

    # Define the input data in the static mode
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()

    data_holders = create_data_holder(args)

    [
        input_ids, segment_ids, input_mask, masked_lm_positions,
        masked_lm_labels, next_sentence_labels, masked_lm_scale
    ] = data_holders

    # Define the model structure in static mode
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    # Copy so the padding below does not mutate the class-level config dict.
    config = dict(
        model_class.pretrained_init_configuration[args.model_name_or_path])
    # Round vocab_size up to the next multiple of 8 (presumably for fp16
    # kernel alignment - confirm).
    if config["vocab_size"] % 8 != 0:
        config["vocab_size"] += 8 - (config["vocab_size"] % 8)
    model = BertForPretraining(BertModel(**config))
    criterion = BertPretrainingCriterion(model.bert.config["vocab_size"])
    prediction_scores, seq_relationship_score = model(
        input_ids=input_ids,
        token_type_ids=segment_ids,
        attention_mask=input_mask,
        masked_positions=masked_lm_positions)
    loss = criterion(prediction_scores, seq_relationship_score,
                     masked_lm_labels, next_sentence_labels, masked_lm_scale)

    # Define the dynamic learning-rate scheduler and optimizer. Training is
    # bounded by max_steps (validated above), which is also the LR horizon.
    num_training_steps = args.max_steps

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps,
                                         args.warmup_steps)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded. A set gives O(1) lookup.
    decay_params = {
        p.name
        for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    }
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params,
        multi_precision=args.use_pure_fp16)

    # Use the fleet api to compile the distributed optimizer
    optimizer = dist_optimizer(args, optimizer)
    optimizer.minimize(loss)

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    state_dict = model.state_dict()

    # Use the state dict to update the parameter
    reset_state_dict = reset_program_state_dict(model, state_dict)
    paddle.static.set_program_state(main_program, reset_state_dict)
    if args.use_amp:
        optimizer.amp_init(place)

    # Single background thread that builds the next file's DataLoader while
    # the current one is being trained on.
    pool = ThreadPoolExecutor(1)
    global_step = 0
    # Timing accumulators cover exactly one logging window; they are reset
    # only after each log line, not at file boundaries.
    train_cost_avg = TimeCostAverage()
    reader_cost_avg = TimeCostAverage()
    total_samples = 0
    epoch = 0
    while True:
        files = [
            os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
            if os.path.isfile(os.path.join(args.input_dir, f))
            and "training" in f
        ]
        if not files:
            raise FileNotFoundError("No training files found in %s" %
                                    args.input_dir)
        files.sort()
        # Deterministic per-epoch shuffle of the file order.
        random.Random(args.seed + epoch).shuffle(files)
        f_start_id = 0

        # Select one file for each worker and create the DataLoader for the file
        data_file = select_dataset_file_for_each_worker(
            files, f_start_id, worker_num, worker_index)
        train_data_loader, _ = create_pretraining_dataset(
            data_file, args.max_predictions_per_seq, args, data_holders,
            worker_init, paddle.static.cuda_places())

        # Run one iteration past the last file index: each iteration trains
        # on the already-loaded loader and prefetches the next one (if any),
        # so the final file - or the only file - is also consumed.
        for f_id in range(f_start_id + 1, len(files) + 1):
            dataset_future = None
            if f_id < len(files):
                data_file = select_dataset_file_for_each_worker(
                    files, f_id, worker_num, worker_index)
                dataset_future = pool.submit(create_pretraining_dataset,
                                             data_file,
                                             args.max_predictions_per_seq,
                                             args, data_holders, worker_init,
                                             paddle.static.cuda_places())

            batch_start = time.time()
            for step, batch in enumerate(train_data_loader):
                reader_cost_avg.record(time.time() - batch_start)
                global_step += 1
                loss_return = exe.run(main_program,
                                      feed=batch,
                                      fetch_list=[loss])
                total_samples += args.batch_size
                # In the new 2.0 api, must call this function to change the learning_rate
                lr_scheduler.step()
                # Whole-step cost measured from batch_start, so it already
                # includes the reader time recorded above.
                train_cost_avg.record(time.time() - batch_start)

                # Profile for model benchmark
                if args.profiler_options is not None:
                    profiler.add_profiler_step(args.profiler_options)

                if global_step % args.logging_steps == 0:
                    avg_batch_cost = train_cost_avg.get_average()
                    # ips divides by the whole-step cost only; adding the
                    # reader average again would count reader time twice.
                    print(
                        "tobal step: %d, epoch: %d, batch: %d, loss: %f, "
                        "avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f sequences/sec"
                        % (global_step, epoch, step, loss_return[0],
                           reader_cost_avg.get_average(), avg_batch_cost,
                           total_samples / args.logging_steps,
                           args.batch_size / avg_batch_cost))
                    total_samples = 0
                    train_cost_avg.reset()
                    reader_cost_avg.reset()
                if global_step % args.save_steps == 0 and worker_index == 0:
                    output_dir = os.path.join(args.output_dir,
                                              "model_%d" % global_step)
                    os.makedirs(output_dir, exist_ok=True)
                    model.save_model_config(output_dir)
                    paddle.static.save(main_program,
                                       os.path.join(output_dir,
                                                    "model_state"))
                    tokenizer.save_pretrained(output_dir)
                if global_step >= args.max_steps:
                    del train_data_loader
                    return
                batch_start = time.time()
            del train_data_loader
            if dataset_future is not None:
                train_data_loader, _ = dataset_future.result(timeout=None)
        epoch += 1
Exemplo n.º 4
0
def do_train(args):
    """Load converted BERT weights, dump one forward pass, then exit.

    NOTE(review): despite its name, this function currently behaves as a
    weight-conversion / debugging script. It:

    1. builds the GLUE-style train/dev DataLoaders;
    2. loads a pickled list of parameter arrays from ``args.params_pd_path``
       and assigns them positionally to the model's ``state_dict`` keys;
    3. saves ``"<model_name_or_path>.pdparams"``;
    4. prints the model output for the first training batch and terminates
       the process with ``sys.exit(0)``.

    The fine-tuning loop after step 4 is reached only if the train loader
    yields no batch at all.

    Args:
        args: Parsed command-line namespace (``eager_run``, ``n_gpu``,
            ``task_name``, ``model_type``, ``model_name_or_path``,
            ``max_seq_length``, ``batch_size``, ``max_steps``,
            ``num_train_epochs``, ``learning_rate``, ``warmup_steps``,
            ``adam_epsilon``, ``weight_decay``, ``params_pd_path``,
            ``logging_steps``, ``save_steps``, ``output_dir`` are read here;
            ``set_seed`` may read more).
    """
    import pickle
    import sys

    if not args.eager_run:
        paddle.enable_static()
    paddle.set_device("gpu" if args.n_gpu else "cpu")
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args)

    args.task_name = args.task_name.lower()
    dataset_class, metric_class = TASK_CLASSES[args.task_name]
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    train_dataset, dev_dataset = dataset_class.get_datasets(["train", "dev"])
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    trans_func = partial(convert_example,
                         tokenizer=tokenizer,
                         label_list=train_dataset.get_labels(),
                         max_seq_length=args.max_seq_length)
    train_dataset = train_dataset.apply(trans_func, lazy=True)
    # train_batch_sampler = SamplerHelper(train_dataset).shuffle().batch(
    #     batch_size=args.batch_size).shard()
    # NOTE: shuffling is disabled here (an earlier variant used shuffle=True).
    train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=args.batch_size, shuffle=False)

    # Collate: pad input ids and segment ids, stack length and label, then
    # drop the length field (index 2) from the batch.
    collate = Tuple(
        Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]),  # input
        Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]),  # segment
        Stack(),  # length
        Stack(dtype="int64"
              if train_dataset.get_labels() else "float32")  # label
    )

    def batchify_fn(samples):
        return [data for i, data in enumerate(collate(samples)) if i != 2]

    train_data_loader = DataLoader(dataset=train_dataset,
                                   batch_sampler=train_batch_sampler,
                                   collate_fn=batchify_fn,
                                   num_workers=0,
                                   return_list=True)
    dev_dataset = dev_dataset.apply(trans_func, lazy=True)
    # dev_batch_sampler = SamplerHelper(dev_dataset).batch(
    #     batch_size=args.batch_size)
    dev_batch_sampler = paddle.io.BatchSampler(dev_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False)
    dev_data_loader = DataLoader(dataset=dev_dataset,
                                 batch_sampler=dev_batch_sampler,
                                 collate_fn=batchify_fn,
                                 num_workers=0,
                                 return_list=True)

    # model = model_class.from_pretrained(
    #     args.model_name_or_path,) num_classes=len(train_dataset.get_labels()))
    model = BertForPretraining(
        BertModel(**model_class.pretrained_init_configuration[
            args.model_name_or_path]))
    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    num_training_steps = args.max_steps if args.max_steps > 0 else len(
        train_data_loader) * args.num_train_epochs

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate,
                                         num_training_steps, args.warmup_steps)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded. A set gives O(1) lookup.
    decay_params = {
        p.name
        for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    }
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    loss_fct = (paddle.nn.loss.CrossEntropyLoss()
                if train_dataset.get_labels() else paddle.nn.loss.MSELoss())

    metric = metric_class()

    ### TODO: use hapi
    # trainer = paddle.hapi.Model(model)
    # trainer.prepare(optimizer, loss_fct, paddle.metric.Accuracy())
    # trainer.fit(train_data_loader,
    #             dev_data_loader,
    #             log_freq=args.logging_steps,
    #             epochs=args.num_train_epochs,
    #             save_dir=args.output_dir)

    model.eval()
    param_names = list(model.state_dict().keys())
    # SECURITY NOTE(review): pickle.load can execute arbitrary code; only
    # point params_pd_path at trusted, locally produced files.
    with open(args.params_pd_path, "rb") as f:
        np_params = pickle.load(f)
    # Parameters are matched to state_dict keys purely by position.
    model.set_state_dict(dict(zip(param_names, np_params)))
    paddle.save(model.state_dict(), "%s.pdparams" % args.model_name_or_path)
    for data in train_data_loader():
        print(model(*data[:-1]))
        # `exit` is the interactive helper injected by `site`; use sys.exit.
        sys.exit(0)

    # NOTE(review): only reached when the train loader yields no batches.
    global_step = 0
    tic_train = time.time()
    for epoch in range(args.num_train_epochs):
        for step, batch in enumerate(train_data_loader):
            input_ids, segment_ids, labels = batch
            logits = model(input_ids, segment_ids)
            loss = loss_fct(logits, labels)
            if global_step % args.logging_steps == 0:
                print(
                    "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                    % (global_step, epoch, step, loss, args.logging_steps /
                       (time.time() - tic_train)))
                tic_train = time.time()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()
            if global_step % args.save_steps == 0:
                evaluate(model, loss_fct, metric, dev_data_loader)
                if args.n_gpu <= 1 or paddle.distributed.get_rank() == 0:
                    paddle.save(
                        model.state_dict(),
                        os.path.join(args.output_dir,
                                     "model_%d.pdparams" % global_step))
            global_step += 1
Exemplo n.º 5
0
def do_train(args):
    """Run dynamic-graph BERT pretraining, data-parallel when multi-GPU.

    For each of ``args.num_train_epochs`` epochs the "training" files in
    ``args.input_dir`` are shuffled deterministically. Every rank then
    trains on its assigned file per file slot, while the next slot's
    DataLoader is prefetched on a background thread. Training stops once
    ``args.max_steps`` steps have run. The main process logs every
    ``args.logging_steps`` steps and saves model, tokenizer and optimizer
    state every ``args.save_steps`` steps.

    Args:
        args: Parsed command-line namespace (``n_gpu``, ``seed``,
            ``model_type``, ``model_name_or_path``, ``learning_rate``,
            ``warmup_steps``, ``max_steps``, ``adam_epsilon``,
            ``weight_decay``, ``num_train_epochs``, ``input_dir``,
            ``max_predictions_per_seq``, ``logging_steps``, ``save_steps``,
            ``output_dir`` are read here; ``set_seed`` may read more).

    Raises:
        ValueError: If ``args.max_steps`` is not positive (it is the stop
            condition and LR horizon; previously this crashed with
            ``UnboundLocalError`` on ``train_data_loader``).
        FileNotFoundError: If ``args.input_dir`` contains no training files
            (previously a ``ZeroDivisionError``).
    """
    if args.max_steps <= 0:
        raise ValueError("args.max_steps must be a positive integer, got %r" %
                         (args.max_steps, ))

    paddle.set_device("gpu" if args.n_gpu else "cpu")
    world_size = paddle.distributed.get_world_size()
    if world_size > 1:
        paddle.distributed.init_parallel_env()
    rank = paddle.distributed.get_rank()
    is_main_process = args.n_gpu <= 1 or rank == 0

    set_seed(args)
    worker_init = WorkerInitObj(args.seed + rank)

    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    model = BertForPretraining(
        BertModel(**model_class.pretrained_init_configuration[
            args.model_name_or_path]))
    criterion = BertPretrainingCriterion(
        getattr(model,
                BertForPretraining.base_model_prefix).config["vocab_size"])
    if world_size > 1:
        model = paddle.DataParallel(model)

    # Linear warmup over warmup_steps, then linear decay to 0 at max_steps.
    num_warmup_steps = args.warmup_steps
    num_training_steps = args.max_steps

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0,
            float(num_training_steps - current_step) /
            float(max(1, num_training_steps - num_warmup_steps)))

    # If use default last_epoch, lr of the first iteration is 0.
    # Use `last_epoch = 0` to be consistent with nv bert.
    lr_scheduler = paddle.optimizer.lr.LambdaDecay(args.learning_rate,
                                                   lr_lambda,
                                                   last_epoch=0)

    # Names of parameters that receive weight decay (bias and norm params are
    # excluded). Built once: in dygraph the predicate runs for every
    # parameter on every step, so rebuilding the list each call was O(P^2).
    decay_params = {
        p.name
        for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    }
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    def select_data_file(files, f_id):
        """Pick the file this rank trains on for file slot ``f_id``."""
        num_files = len(files)
        if world_size > num_files:
            # More ranks than files: shift by the remainder each slot so
            # ranks rotate over the files.
            remainder = world_size % num_files
            return files[(f_id * world_size + rank + remainder * f_id) %
                         num_files]
        return files[(f_id * world_size + rank) % num_files]

    # Single background thread that builds the next file's DataLoader while
    # the current one is being trained on.
    pool = ThreadPoolExecutor(1)
    global_step = 0
    tic_train = time.time()
    for epoch in range(args.num_train_epochs):
        files = [
            os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
            if os.path.isfile(os.path.join(args.input_dir, f))
            and "training" in f
        ]
        if not files:
            raise FileNotFoundError("No training files found in %s" %
                                    args.input_dir)
        files.sort()
        # Deterministic per-epoch shuffle of the file order.
        random.Random(args.seed + epoch).shuffle(files)
        f_start_id = 0

        shared_file_list = {}

        train_data_loader, _ = create_pretraining_dataset(
            select_data_file(files, f_start_id), args.max_predictions_per_seq,
            shared_file_list, args, worker_init)

        # Run one iteration past the last file slot: each iteration trains
        # on the already-loaded loader and prefetches the next one (if any),
        # so the final file - or the only file - is trained exactly once.
        for f_id in range(f_start_id + 1, len(files) + 1):
            dataset_future = None
            if f_id < len(files):
                dataset_future = pool.submit(create_pretraining_dataset,
                                             select_data_file(files, f_id),
                                             args.max_predictions_per_seq,
                                             shared_file_list, args,
                                             worker_init)
            for step, batch in enumerate(train_data_loader):
                global_step += 1
                (input_ids, segment_ids, input_mask, masked_lm_positions,
                 masked_lm_labels, next_sentence_labels,
                 masked_lm_scale) = batch
                prediction_scores, seq_relationship_score = model(
                    input_ids=input_ids,
                    token_type_ids=segment_ids,
                    attention_mask=input_mask,
                    masked_positions=masked_lm_positions)
                loss = criterion(prediction_scores, seq_relationship_score,
                                 masked_lm_labels, next_sentence_labels,
                                 masked_lm_scale)
                if global_step % args.logging_steps == 0:
                    if is_main_process:
                        logger.info(
                            "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                            % (global_step, epoch, step, loss,
                               args.logging_steps / (time.time() - tic_train)))
                    tic_train = time.time()
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.clear_gradients()
                if global_step % args.save_steps == 0 and is_main_process:
                    output_dir = os.path.join(args.output_dir,
                                              "model_%d" % global_step)
                    os.makedirs(output_dir, exist_ok=True)
                    # need better way to get inner model of DataParallel
                    model_to_save = model._layers if isinstance(
                        model, paddle.DataParallel) else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    paddle.save(
                        optimizer.state_dict(),
                        os.path.join(output_dir, "model_state.pdopt"))
                if global_step >= args.max_steps:
                    del train_data_loader
                    return

            del train_data_loader
            if dataset_future is not None:
                train_data_loader, _ = dataset_future.result(timeout=None)
Exemplo n.º 6
0
def do_train(args):
    """Run static-graph BERT pretraining with paddle fleet (collective mode).

    Builds the pretraining program (with ``args.num_partitions`` passed to
    the model), wires up a warmup + linear-decay learning rate and AdamW,
    then trains over the files in ``args.input_dir`` whose names contain
    ``"training"``, epoch after epoch, until ``args.max_steps`` global steps
    have run. While one file is being trained on, the DataLoader for the
    next one is built in a background thread.

    This variant keeps several debugging aids:

    - Worker 0 dumps the startup and main programs to ``startup_<N>`` and
      ``main_<N>`` in the working directory.
    - Worker 0 profiles batches 10-19 of each file into ``/tmp/profile``.
    - Every worker reads the files as if it were the only worker
      (``select_dataset_file_for_each_worker(..., 1, 0)``).

    Args:
        args: parsed command-line namespace. This function reads:
            ``select_device``, ``seed``, ``model_type``,
            ``model_name_or_path``, ``num_partitions``, ``learning_rate``,
            ``warmup_steps``, ``max_steps``, ``adam_epsilon``,
            ``weight_decay``, ``input_dir``, ``max_predictions_per_seq``,
            ``logging_steps``, ``batch_size``, ``save_steps`` and
            ``output_dir``. Helper functions may read further fields.

    Raises:
        ValueError: if ``args.max_steps`` is not positive, or if
            ``args.input_dir`` holds no training files.
    """
    # The schedule needs the total step count before any DataLoader exists,
    # and the training loop stops on max_steps, so it must be given.
    if args.max_steps <= 0:
        raise ValueError(
            "args.max_steps must be positive for static-mode pretraining, "
            "got %r" % (args.max_steps, ))

    # Initialize the paddle and paddle fleet execute environment
    paddle.enable_static()
    place = paddle.set_device(args.select_device)
    fleet.init(is_collective=True)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()

    # Seed the global RNG, the data-loader workers and the rng-state tracker.
    # NOTE: every worker uses the same data-loader seed (args.seed) here.
    # Only the tracker's 'local_seed' differs per worker.
    set_seed(args.seed)
    worker_init = WorkerInitObj(args.seed)
    tracker = get_rng_state_tracker()
    tracker.add('global_seed', args.seed)
    tracker.add('local_seed', args.seed + worker_index + 2021)

    # Define the input data in the static mode
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    data_holders = create_data_holder(args)
    [
        input_ids, segment_ids, input_mask, masked_lm_positions,
        masked_lm_labels, next_sentence_labels, masked_lm_scale
    ] = data_holders

    # Define the model structure in static mode
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    # Work on a copy so the edits below do not mutate the class attribute
    # `pretrained_init_configuration`.
    config = dict(
        model_class.pretrained_init_configuration[args.model_name_or_path])
    # Pad the vocabulary size up to the next multiple of 8.
    if config["vocab_size"] % 8 != 0:
        config["vocab_size"] += 8 - (config["vocab_size"] % 8)
    config['num_partitions'] = args.num_partitions
    model = BertForPretraining(BertModel(**config), args.num_partitions)
    criterion = BertPretrainingCriterion(model.bert.config["vocab_size"])
    prediction_scores, seq_relationship_score = model(
        input_ids=input_ids,
        token_type_ids=segment_ids,
        attention_mask=input_mask,
        masked_positions=masked_lm_positions)
    loss = criterion(prediction_scores, seq_relationship_score,
                     masked_lm_labels, next_sentence_labels, masked_lm_scale)

    # Learning-rate multiplier schedule:
    # - linear warmup from 0 to 1 over warmup_steps;
    # - then linear decay, clamped at 0, reaching 0 at num_training_steps.
    num_training_steps = args.max_steps
    num_warmup_steps = args.warmup_steps

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0,
            float(num_training_steps - current_step) /
            float(max(1, num_training_steps - num_warmup_steps)))

    lr_scheduler = paddle.optimizer.lr.LambdaDecay(args.learning_rate,
                                                   lr_lambda)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)
    # NOTE: no AMP decoration is applied in this variant.

    if worker_num > 1:
        # Use the fleet api to compile the distributed optimizer
        optimizer = dist_optimizer(args, optimizer)
    optimizer.minimize(loss)

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    # Parameters keep the startup program's initialisation. No state-dict
    # reset and no compiled-program wrapping happen here.

    # Drop the program's cached private `_graph`. Presumably this is a
    # leftover from a graph pass; TODO confirm why this is required.
    main_program._graph = None

    if worker_index == 0:
        # Debug aid: dump both programs as text for inspection.
        with open('startup_%d' % worker_num, 'w') as f:
            f.write(str(startup_program))
        with open('main_%d' % worker_num, 'w') as f:
            f.write(str(main_program))

    pool = ThreadPoolExecutor(1)
    global_step = 0
    tic_train = time.time()
    epoch = 0
    while True:
        files = [
            os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
            if os.path.isfile(os.path.join(args.input_dir, f))
            and "training" in f
        ]
        if not files:
            raise ValueError("no training files found in %s" %
                             args.input_dir)
        files.sort()
        random.Random(args.seed + epoch).shuffle(files)
        f_start_id = 0

        # Every worker reads the files as worker 0 of 1 (no sharding).
        data_file = select_dataset_file_for_each_worker(
            files, f_start_id, 1, 0)
        train_data_loader, _ = create_pretraining_dataset(
            data_file, args.max_predictions_per_seq, args, data_holders,
            worker_init, paddle.static.cuda_places())

        # Iterate one past the last file index. Each pass trains on the
        # current loader while the next file is prefetched. The final pass
        # trains on the last prefetched loader and prefetches nothing.
        for f_id in range(f_start_id + 1, len(files) + 1):
            dataset_future = None
            if f_id < len(files):
                data_file = select_dataset_file_for_each_worker(
                    files, f_id, 1, 0)
                dataset_future = pool.submit(create_pretraining_dataset,
                                             data_file,
                                             args.max_predictions_per_seq,
                                             args, data_holders, worker_init,
                                             paddle.static.cuda_places())

            for step, batch in enumerate(train_data_loader):
                global_step += 1
                # Worker 0 profiles batches 10-19 of each file: the profiler
                # starts before batch 10 and stops before batch 20.
                if step == 10 and worker_index == 0:
                    profiler.start_profiler("All")
                if step == 20 and worker_index == 0:
                    profiler.stop_profiler("total", "/tmp/profile")

                loss_return = exe.run(main_program,
                                      feed=batch,
                                      fetch_list=[loss])
                # In the new 2.0 api, must call this function to change the learning_rate
                lr_scheduler.step()
                if global_step % args.logging_steps == 0:
                    time_cost = time.time() - tic_train
                    print(
                        "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s, ips: %.2f sequences/s"
                        % (global_step, epoch, step, loss_return[0],
                           args.logging_steps / time_cost,
                           args.logging_steps * args.batch_size / time_cost))
                    tic_train = time.time()
                if global_step % args.save_steps == 0 and worker_index == 0:
                    output_dir = os.path.join(args.output_dir,
                                              "model_%d" % global_step)
                    os.makedirs(output_dir, exist_ok=True)
                    # TODO(fangzeyang): Update the save_params to paddle.static
                    paddle.fluid.io.save_params(exe, output_dir)
                    tokenizer.save_pretrained(output_dir)
                if global_step >= args.max_steps:
                    del train_data_loader
                    return
            del train_data_loader
            if dataset_future is not None:
                train_data_loader, _ = dataset_future.result(timeout=None)
        epoch += 1
Exemplo n.º 7
0
def do_train(args):
    """Run data-parallel static-graph BERT pretraining with paddle fleet.

    Builds the pretraining program on the GPU named by the
    ``FLAGS_selected_gpus`` environment variable (default 0). It uses a
    ``LinearDecayWithWarmup`` learning rate and AdamW wrapped by
    ``fleet.distributed_optimizer``, and resets parameters via
    ``reset_program_state_dict``. It then trains over the files in
    ``args.input_dir`` whose names contain ``"training"``, epoch after
    epoch, until ``args.max_steps`` global steps have run.

    Each worker picks its file through
    ``select_dataset_file_for_each_worker(files, f_id, worker_num,
    worker_index)``. The DataLoader for the next file is built in a
    background thread while the current one is trained on. Worker 0 saves
    parameters and the tokenizer every ``args.save_steps`` steps.

    Args:
        args: parsed command-line namespace. This function reads:
            ``seed``, ``model_type``, ``model_name_or_path``,
            ``learning_rate``, ``warmup_steps``, ``max_steps``,
            ``adam_epsilon``, ``weight_decay``, ``input_dir``,
            ``max_predictions_per_seq``, ``logging_steps``,
            ``batch_size``, ``save_steps`` and ``output_dir``.
            Helper functions may read further fields.

    Raises:
        ValueError: if ``args.max_steps`` is not positive, or if
            ``args.input_dir`` holds no training files.
    """
    # The schedule needs the total step count before any DataLoader exists,
    # and the training loop stops on max_steps, so it must be given.
    if args.max_steps <= 0:
        raise ValueError(
            "args.max_steps must be positive for static-mode pretraining, "
            "got %r" % (args.max_steps, ))

    # Initialize the paddle and paddle fleet execute environment
    paddle.enable_static()
    place = paddle.CUDAPlace(int(os.environ.get('FLAGS_selected_gpus', 0)))
    fleet.init(is_collective=True)
    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()

    # Create the random seed for the worker
    set_seed(args.seed)
    worker_init = WorkerInitObj(args.seed + worker_index)

    # Define the input data in the static mode
    data_holders = create_data_holder(args)
    [
        input_ids, segment_ids, input_mask, masked_lm_positions,
        masked_lm_labels, next_sentence_labels, masked_lm_scale
    ] = data_holders

    # Define the model structure in static mode
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    model = BertForPretraining(
        BertModel(**model_class.pretrained_init_configuration[
            args.model_name_or_path]))
    criterion = BertPretrainingCriterion(model.bert.config["vocab_size"])
    prediction_scores, seq_relationship_score = model(
        input_ids=input_ids,
        token_type_ids=segment_ids,
        attention_mask=input_mask,
        masked_positions=masked_lm_positions)
    loss = criterion(prediction_scores, seq_relationship_score,
                     masked_lm_labels, next_sentence_labels, masked_lm_scale)

    # Define the dynamic learning rate scheduler and optimizer
    num_training_steps = args.max_steps
    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps,
                                         args.warmup_steps)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    # Use the fleet api to compile the distributed optimizer
    strategy = fleet.DistributedStrategy()
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    optimizer.minimize(loss)

    # Define the Executor for running the static model
    main_program = paddle.static.default_main_program()
    exe = paddle.static.Executor(place)
    exe.run(paddle.static.default_startup_program())
    state_dict = model.state_dict()

    # Use the state dict to update the parameter
    reset_state_dict = reset_program_state_dict(model, state_dict)
    paddle.static.set_program_state(main_program, reset_state_dict)

    pool = ThreadPoolExecutor(1)
    global_step = 0
    tic_train = time.time()
    epoch = 0
    while True:
        files = [
            os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
            if os.path.isfile(os.path.join(args.input_dir, f)) and "training" in
            f
        ]
        if not files:
            raise ValueError("no training files found in %s" %
                             args.input_dir)
        files.sort()
        random.Random(args.seed + epoch).shuffle(files)
        f_start_id = 0

        # Select one file for each worker and create the DataLoader for the file
        data_file = select_dataset_file_for_each_worker(
            files, f_start_id, worker_num, worker_index)
        train_data_loader, _ = create_pretraining_dataset(
            data_file, args.max_predictions_per_seq, args, data_holders,
            worker_init, paddle.static.cuda_places())

        # Iterate one past the last file index. Each pass trains on the
        # current loader while the next file is prefetched. The final pass
        # trains on the last prefetched loader and prefetches nothing.
        for f_id in range(f_start_id + 1, len(files) + 1):
            dataset_future = None
            if f_id < len(files):
                data_file = select_dataset_file_for_each_worker(
                    files, f_id, worker_num, worker_index)
                dataset_future = pool.submit(create_pretraining_dataset,
                                             data_file,
                                             args.max_predictions_per_seq,
                                             args, data_holders, worker_init,
                                             paddle.static.cuda_places())

            for step, batch in enumerate(train_data_loader):
                global_step += 1
                loss_return = exe.run(main_program,
                                      feed=batch,
                                      fetch_list=[loss])
                # In the new 2.0 api, must call this function to change the learning_rate
                lr_scheduler.step()
                if global_step % args.logging_steps == 0:
                    time_cost = time.time() - tic_train
                    print(
                        "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s, ips :%.2f sequences/s"
                        % (global_step, epoch, step, loss_return[0],
                           args.logging_steps / time_cost,
                           args.logging_steps * args.batch_size / time_cost))
                    tic_train = time.time()
                if global_step % args.save_steps == 0 and worker_index == 0:
                    output_dir = os.path.join(args.output_dir,
                                              "model_%d" % global_step)
                    os.makedirs(output_dir, exist_ok=True)
                    # TODO(fangzeyang): Update the save_params to paddle.static
                    paddle.fluid.io.save_params(exe, output_dir)
                    tokenizer.save_pretrained(output_dir)
                if global_step >= args.max_steps:
                    del train_data_loader
                    return
            del train_data_loader
            if dataset_future is not None:
                train_data_loader, _ = dataset_future.result(timeout=None)
        epoch += 1
Exemplo n.º 8
0
def do_train(args):
    """Run single-process static-graph BERT pretraining.

    Builds the pretraining program on ``args.device``, with a
    ``LinearDecayWithWarmup`` learning rate and AdamW. When ``args.use_amp``
    is set, the optimizer is additionally decorated for mixed precision
    (pure fp16 when ``args.use_pure_fp16``). Parameters are reset via
    ``reset_program_state_dict`` and the program is compiled with
    ``build_compiled_program``.

    Training runs over the files in ``args.input_dir`` whose names contain
    ``"training"``, one DataLoader per file, epoch after epoch, until
    ``args.max_steps`` global steps have run. Every ``args.logging_steps``
    steps it prints average reader/batch cost and throughput. Every
    ``args.save_steps`` steps it saves parameters and the tokenizer.

    Args:
        args: parsed command-line namespace. This function reads:
            ``device``, ``seed``, ``model_type``, ``model_name_or_path``,
            ``learning_rate``, ``warmup_steps``, ``max_steps``,
            ``adam_epsilon``, ``weight_decay``, ``use_amp``,
            ``use_pure_fp16``, ``scale_loss``, ``input_dir``,
            ``max_predictions_per_seq``, ``batch_size``,
            ``logging_steps``, ``save_steps`` and ``output_dir``.
            Helper functions may read further fields.

    Raises:
        ValueError: if ``args.max_steps`` is not positive, or if
            ``args.input_dir`` holds no training files.
    """
    # The schedule needs the total step count before any DataLoader exists,
    # and the training loop stops on max_steps, so it must be given.
    if args.max_steps <= 0:
        raise ValueError(
            "args.max_steps must be positive for static-mode pretraining, "
            "got %r" % (args.max_steps, ))

    # Initialize the paddle execute environment
    paddle.enable_static()
    place = paddle.set_device(args.device)

    # Set the random seed
    set_seed(args.seed)

    # Define the input data in the static mode
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    data_holders = create_data_holder(args)
    [
        input_ids, segment_ids, input_mask, masked_lm_positions,
        masked_lm_labels, next_sentence_labels, masked_lm_scale
    ] = data_holders

    # Define the model structure in static mode
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    # Work on a copy so the padding below does not mutate the class
    # attribute `pretrained_init_configuration`.
    config = dict(
        model_class.pretrained_init_configuration[args.model_name_or_path])
    # Pad the vocabulary size up to the next multiple of 8.
    if config["vocab_size"] % 8 != 0:
        config["vocab_size"] += 8 - (config["vocab_size"] % 8)
    model = BertForPretraining(BertModel(**config))
    criterion = BertPretrainingCriterion(model.bert.config["vocab_size"])
    prediction_scores, seq_relationship_score = model(
        input_ids=input_ids,
        token_type_ids=segment_ids,
        attention_mask=input_mask,
        masked_positions=masked_lm_positions)
    loss = criterion(prediction_scores, seq_relationship_score,
                     masked_lm_labels, next_sentence_labels, masked_lm_scale)

    # Define the dynamic learning rate scheduler and optimizer
    num_training_steps = args.max_steps
    lr_scheduler = LinearDecayWithWarmup(args.learning_rate,
                                         num_training_steps, args.warmup_steps)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params,
        multi_precision=False)
    if args.use_amp:
        # In pure-fp16 mode the embedding lookups are kept out of fp16.
        custom_black_list = (['lookup_table', 'lookup_table_v2']
                             if args.use_pure_fp16 else None)
        amp_list = paddle.static.amp.AutoMixedPrecisionLists(
            custom_white_list=['layer_norm', 'softmax', 'gelu'],
            custom_black_list=custom_black_list)
        optimizer = paddle.static.amp.decorate(
            optimizer,
            amp_list,
            init_loss_scaling=args.scale_loss,
            use_dynamic_loss_scaling=True,
            use_pure_fp16=args.use_pure_fp16)
    optimizer.minimize(loss)

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    state_dict = model.state_dict()

    # Use the state dict to update the parameter
    reset_state_dict = reset_program_state_dict(model, state_dict)
    paddle.static.set_program_state(main_program, reset_state_dict)
    if args.use_amp:
        optimizer.amp_init(place)
    # Construct the compiled program
    main_program = build_compiled_program(args, main_program, loss)

    global_step = 0
    epoch = 0
    # Cost/sample accumulators cover the steps since the last log line.
    # They are only reset when a log line is printed, so the averages stay
    # correct across file boundaries.
    train_reader_cost = 0.0
    train_run_cost = 0.0
    total_samples = 0
    while True:
        files = [
            os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
            if os.path.isfile(os.path.join(args.input_dir, f))
            and "training" in f
        ]
        if not files:
            raise ValueError("no training files found in %s" %
                             args.input_dir)
        files.sort()
        random.Random(args.seed + epoch).shuffle(files)

        for data_file in files:
            train_data_loader, _ = create_pretraining_dataset(
                data_file, args.max_predictions_per_seq, args, data_holders)
            # Time spent building the DataLoader is not counted as reader cost.
            reader_start = time.time()
            for step, batch in enumerate(train_data_loader):
                train_reader_cost += time.time() - reader_start
                global_step += 1
                train_start = time.time()
                loss_return = exe.run(main_program,
                                      feed=batch,
                                      fetch_list=[loss])
                train_run_cost += time.time() - train_start
                total_samples += args.batch_size
                # In the new 2.0 api, must call this function to change the learning_rate
                lr_scheduler.step()
                if global_step % args.logging_steps == 0:
                    print(
                        "global step: %d, epoch: %d, batch: %d, loss: %f, "
                        "avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f sequences/sec"
                        % (global_step, epoch, step, loss_return[0],
                           train_reader_cost / args.logging_steps,
                           (train_reader_cost + train_run_cost) /
                           args.logging_steps,
                           total_samples / args.logging_steps, total_samples /
                           (train_reader_cost + train_run_cost)))
                    train_reader_cost = 0.0
                    train_run_cost = 0.0
                    total_samples = 0
                if global_step % args.save_steps == 0:
                    output_dir = os.path.join(args.output_dir,
                                              "model_%d" % global_step)
                    os.makedirs(output_dir, exist_ok=True)
                    # TODO(fangzeyang): Update the save_params to paddle.static
                    paddle.fluid.io.save_params(exe, output_dir)
                    tokenizer.save_pretrained(output_dir)
                if global_step >= args.max_steps:
                    del train_data_loader
                    return
                reader_start = time.time()
            del train_data_loader
        epoch += 1
Exemplo n.º 9
0
def do_train(args):
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args)

    args.task_name = args.task_name.lower()
    metric_class = METRIC_CLASSES[args.task_name]
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    train_ds = load_dataset('glue', args.task_name, splits="train")
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    trans_func = partial(convert_example,
                         tokenizer=tokenizer,
                         label_list=train_ds.label_list,
                         max_seq_length=args.max_seq_length)
    train_ds = train_ds.map(trans_func, lazy=True)
    train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_ds, batch_size=args.batch_size,
        shuffle=False)  # for same data when converting
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment
        Stack(dtype="int64" if train_ds.label_list else "float32")  # label
    ): fn(samples)
    train_data_loader = DataLoader(dataset=train_ds,
                                   batch_sampler=train_batch_sampler,
                                   collate_fn=batchify_fn,
                                   num_workers=0,
                                   return_list=True)
    if args.task_name == "mnli":
        dev_ds_matched, dev_ds_mismatched = load_dataset(
            'glue', args.task_name, splits=["dev_matched", "dev_mismatched"])

        dev_ds_matched = dev_ds_matched.map(trans_func, lazy=True)
        dev_ds_mismatched = dev_ds_mismatched.map(trans_func, lazy=True)
        dev_batch_sampler_matched = paddle.io.BatchSampler(
            dev_ds_matched, batch_size=args.batch_size, shuffle=False)
        dev_data_loader_matched = DataLoader(
            dataset=dev_ds_matched,
            batch_sampler=dev_batch_sampler_matched,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)
        dev_batch_sampler_mismatched = paddle.io.BatchSampler(
            dev_ds_mismatched, batch_size=args.batch_size, shuffle=False)
        dev_data_loader_mismatched = DataLoader(
            dataset=dev_ds_mismatched,
            batch_sampler=dev_batch_sampler_mismatched,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)
    else:
        dev_ds = load_dataset('glue', args.task_name, splits='dev')
        dev_ds = dev_ds.map(trans_func, lazy=True)
        dev_batch_sampler = paddle.io.BatchSampler(dev_ds,
                                                   batch_size=args.batch_size,
                                                   shuffle=False)
        dev_data_loader = DataLoader(dataset=dev_ds,
                                     batch_sampler=dev_batch_sampler,
                                     collate_fn=batchify_fn,
                                     num_workers=0,
                                     return_list=True)

    num_classes = 1 if train_ds.label_list == None else len(
        train_ds.label_list)
    # model = model_class.from_pretrained(
    #     args.model_name_or_path, num_classes=num_classes)
    model = BertForPretraining(
        BertModel(**model_class.pretrained_init_configuration[
            args.model_name_or_path]))
    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    num_training_steps = args.max_steps if args.max_steps > 0 else (
        len(train_data_loader) * args.num_train_epochs)
    warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate,
                                         num_training_steps, warmup)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        beta1=0.9,
        beta2=0.999,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    loss_fct = paddle.nn.loss.CrossEntropyLoss(
    ) if train_ds.label_list else paddle.nn.loss.MSELoss()

    metric = metric_class()
    if args.use_amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)

    # load converted model and run once to compare
    model.eval()
    param_names = list(model.state_dict().keys())
    import pickle
    with open(args.params_pd_path, "rb") as f:
        np_params = pickle.load(f)
    model.set_state_dict(dict(zip(param_names, np_params)))
    paddle.save(model.state_dict(), "%s.pdparams" % args.model_name_or_path)
    for data in train_data_loader():
        print(model(*data[:-1]))
        exit(0)

    global_step = 0
    tic_train = time.time()
    for epoch in range(args.num_train_epochs):
        for step, batch in enumerate(train_data_loader):
            global_step += 1

            input_ids, segment_ids, labels = batch
            with paddle.amp.auto_cast(
                    args.use_amp,
                    custom_white_list=["layer_norm", "softmax", "gelu"]):
                logits = model(input_ids, segment_ids)
                loss = loss_fct(logits, labels)
            if args.use_amp:
                scaler.scale(loss).backward()
                scaler.minimize(optimizer, loss)
            else:
                loss.backward()
                optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()
            if global_step % args.logging_steps == 0:
                print(
                    "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s"
                    % (global_step, num_training_steps, epoch, step,
                       paddle.distributed.get_rank(), loss, optimizer.get_lr(),
                       args.logging_steps / (time.time() - tic_train)))
                tic_train = time.time()
            if global_step % args.save_steps == 0 or global_step == num_training_steps:
                tic_eval = time.time()
                if args.task_name == "mnli":
                    evaluate(model, loss_fct, metric, dev_data_loader_matched)
                    evaluate(model, loss_fct, metric,
                             dev_data_loader_mismatched)
                    print("eval done total : %s s" % (time.time() - tic_eval))
                else:
                    evaluate(model, loss_fct, metric, dev_data_loader)
                    print("eval done total : %s s" % (time.time() - tic_eval))
                if paddle.distributed.get_rank() == 0:
                    output_dir = os.path.join(
                        args.output_dir, "%s_ft_model_%d.pdparams" %
                        (args.task_name, global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Need better way to get inner model of DataParallel
                    model_to_save = model._layers if isinstance(
                        model, paddle.DataParallel) else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)