Example #1
def train():
    # Load the pretrained BERT
    model = BertForQuestionAnswering.from_pretrained(
        "bert-base-chinese",
        cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                               'distributed_{}'.format(-1)))
    device = args.device
    model.to(device)

    # Prepare the optimizer
    optimizer = get_bert_optimizer()

    # Prepare the data
    data = Dureader()
    train_dataloader, dev_dataloader = data.train_iter, data.dev_iter

    best_loss = 100000.0
    model.train()
    for i in range(args.num_train_epochs):
        for step, batch in enumerate(tqdm(train_dataloader, desc="Epoch")):
            input_ids, input_mask, segment_ids, start_positions, end_positions = \
                batch.input_ids, batch.input_mask, batch.segment_ids, \
                batch.start_position, batch.end_position
            input_ids, input_mask, segment_ids, start_positions, end_positions = \
                input_ids.to(device), input_mask.to(device), segment_ids.to(device), \
                start_positions.to(device), end_positions.to(device)

            # Compute the loss
            loss, _, _ = model(input_ids,
                               token_type_ids=segment_ids,
                               attention_mask=input_mask,
                               start_positions=start_positions,
                               end_positions=end_positions)
            loss = loss / args.gradient_accumulation_steps
            loss.backward()

            # Update the parameters after gradient accumulation
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            # Validation
            if step % args.log_step == 4:
                eval_loss = evaluate.evaluate(model, dev_dataloader)
                if eval_loss < best_loss:
                    best_loss = eval_loss
                    torch.save(model.state_dict(),
                               './model_dir/' + "best_model")
                model.train()
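Example #1 calls a get_bert_optimizer() helper that is not included in this listing (at the call site it takes no arguments, so the real helper presumably picks up the model and args from an enclosing scope). A minimal sketch of what it plausibly does, assuming it mirrors the grouped-parameter BertAdam setup used in Examples #2 and #4 and that BertAdam comes from pytorch_pretrained_bert; it is written here with explicit parameters for clarity:

from pytorch_pretrained_bert import BertAdam

def get_bert_optimizer(model, args):
    # Exclude the pooler and apply weight decay only to non-bias, non-LayerNorm weights,
    # matching the optimizer setup in the later examples.
    param_optimizer = [(n, p) for n, p in model.named_parameters() if 'pooler' not in n]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    return BertAdam(optimizer_grouped_parameters,
                    lr=args.learning_rate,
                    warmup=0.1,
                    t_total=args.num_train_optimization_steps)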
Example #2
def train():
    # Load the pretrained BERT
    model = BertForQuestionAnswering.from_pretrained(
        "bert-base-chinese",
        cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                               'distributed_{}'.format(-1)))
    device = args.device
    model.to(device)

    # Prepare the optimizer
    param_optimizer = list(model.named_parameters())
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=0.1,
                         t_total=args.num_train_optimization_steps)

    # Prepare the data
    data = Dureader()
    train_dataloader, dev_dataloader = data.train_iter, data.dev_iter

    best_loss = 100000.0
    model.train()
    for i in range(args.num_train_epochs):
        for step, batch in enumerate(tqdm(train_dataloader, desc="Epoch")):
            input_ids, input_mask, segment_ids, start_positions, end_positions = \
                batch.input_ids, batch.input_mask, batch.segment_ids, \
                batch.start_position, batch.end_position
            input_ids, input_mask, segment_ids, start_positions, end_positions = \
                input_ids.to(device), input_mask.to(device), segment_ids.to(device), \
                start_positions.to(device), end_positions.to(device)

            # Compute the loss
            loss, _, _ = model(input_ids,
                               token_type_ids=segment_ids,
                               attention_mask=input_mask,
                               start_positions=start_positions,
                               end_positions=end_positions)
            loss = loss / args.gradient_accumulation_steps
            loss.backward()

            # Update the parameters after gradient accumulation
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            # Validation
            if step % args.log_step == 4:
                eval_loss = evaluate.evaluate(model, dev_dataloader)
                if eval_loss < best_loss:
                    best_loss = eval_loss
                    torch.save(model.state_dict(),
                               './model_dir/' + "best_model")
                model.train()
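Examples #1 and #2 periodically call evaluate.evaluate(model, dev_dataloader) on the dev iterator, but that function is not shown (the later examples pass extra device or gradient-accumulation arguments). A rough sketch of the two-argument variant, under the assumption that it simply averages the dev loss, puts the model into eval mode, and leaves it to the caller to switch back with model.train():

import torch

def evaluate(model, dev_dataloader):
    # Average the span-extraction loss over the dev set.
    model.eval()
    device = next(model.parameters()).device
    total_loss, n_batches = 0.0, 0
    with torch.no_grad():
        for batch in dev_dataloader:
            loss, _, _ = model(batch.input_ids.to(device),
                               token_type_ids=batch.segment_ids.to(device),
                               attention_mask=batch.input_mask.to(device),
                               start_positions=batch.start_position.to(device),
                               end_positions=batch.end_position.to(device))
            total_loss += loss.mean().item()
            n_batches += 1
    return total_loss / max(n_batches, 1)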
Example #3
def train():
    # Load the pretrained XLNet
    config = XLNetConfig.from_pretrained('xlnet_config.json')
    model = XLNetForQuestionAnswering.from_pretrained('xlnet_model.ckpt.index',
                                                      from_tf=True,
                                                      config=config)
    device = args.device
    model.to(device)

    # Prepare the optimizer
    param_optimizer = list(model.named_parameters())
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = adabound.AdaBound(optimizer_grouped_parameters,
                                  lr=1e-3,
                                  final_lr=0.1)
    # Prepare the data
    data = Dureader()
    train_dataloader, dev_dataloader = data.train_iter, data.dev_iter

    best_loss = 100000.0
    model.train()
    for i in range(args.num_train_epochs):
        for step, batch in enumerate(tqdm(train_dataloader, desc="Epoch")):
            input_ids, input_mask, segment_ids, start_positions, end_positions = \
                batch.input_ids, batch.input_mask, batch.segment_ids, \
                batch.start_position, batch.end_position
            input_ids, input_mask, segment_ids, start_positions, end_positions = \
                input_ids.to(device), input_mask.to(device), segment_ids.to(device), \
                start_positions.to(device), end_positions.to(device)

            # Compute the loss
            outputs = model(input_ids,
                            token_type_ids=segment_ids,
                            attention_mask=input_mask,
                            start_positions=start_positions,
                            end_positions=end_positions)
            loss = outputs[0]
            loss = loss / args.gradient_accumulation_steps
            loss.backward()

            # Update the parameters after gradient accumulation
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            # Validation
            if step % args.log_step == 4:
                eval_loss = evaluate.evaluate(model, dev_dataloader)
                if eval_loss < best_loss:
                    best_loss = eval_loss
                    torch.save(model.state_dict(),
                               './model_dir/' + "best_model")
                model.train()
Example #4
def train():
    # Step 1: load the pretrained BERT and first fine-tune it in-domain on the earlier data
    model = BertForQuestionAnswering.from_pretrained(
        './roberta_wwm_ext',
        cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                               'distributed_{}'.format(-1)))

    # Step 2: train on the task with the robustness data
    # output_model_file = "./model_dir/best_model"
    # output_config_file = "./model_dir/bert_configbase.json"
    #
    # config = BertConfig(output_config_file)
    # model = BertForQuestionAnswering(config)
    # # How to load a model that was trained on multiple GPUs:
    # state_dict = torch.load(output_model_file, map_location='cuda:0')
    # # Initialize an empty dict
    # new_state_dict = OrderedDict()
    # # Fix the keys: leave a key unchanged if it has no 'module.' prefix, otherwise strip the 'module.' prefix
    # for k, v in state_dict.items():
    #     if 'module' not in k:
    #         k = k
    #     else:
    #         k = k.replace('module.', '')
    #     new_state_dict[k] = v
    # model.load_state_dict(new_state_dict)

    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids)  # declare all available devices
        model = model.cuda(device=device_ids[0])  # place the model on the primary device
    elif len(device_ids) == 1:
        model = model.cuda()

    # Prepare the optimizer
    param_optimizer = list(model.named_parameters())
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=0.1,
                         t_total=args.num_train_optimization_steps)
    # optimizer = nn.DataParallel(optimizer, device_ids=device_ids)
    # Prepare the data
    data = Dureader()
    train_dataloader, dev_dataloader = data.train_iter, data.dev_iter

    best_loss = 100000.0
    model.train()
    for i in range(args.num_train_epochs):
        main_losses, ide_losses = 0, 0
        for step, batch in enumerate(tqdm(train_dataloader, desc="Epoch")):
            input_ids, input_mask, input_ids_q, input_mask_q, \
            segment_ids, can_answer, start_positions, end_positions = \
                batch.input_ids, batch.input_mask, batch.input_ids_q, batch.input_mask_q,\
                batch.segment_ids, batch.can_answer,\
                batch.start_position, batch.end_position

            if len(device_ids) > 1:
                input_ids, input_mask, input_ids_q, input_mask_q, \
                segment_ids, can_answer, start_positions, end_positions = \
                    input_ids.cuda(device=device_ids[0]), input_mask.cuda(device=device_ids[0]), \
                    input_ids_q.cuda(device=device_ids[0]), input_mask_q.cuda(device=device_ids[0]), \
                    segment_ids.cuda(device=device_ids[0]), can_answer.cuda(device=device_ids[0]),\
                    start_positions.cuda(device=device_ids[0]), end_positions.cuda(device=device_ids[0])
            elif len(device_ids) == 1:
                input_ids, input_mask, input_ids_q, input_mask_q, \
                segment_ids, can_answer, start_positions, end_positions = \
                    input_ids.cuda(), input_mask.cuda(), input_ids_q.cuda(), input_mask_q.cuda(), \
                    segment_ids.cuda(), can_answer.cuda(), start_positions.cuda(), \
                    end_positions.cuda()
                # print("gpu nums is 1.")

            # Compute the loss
            loss, main_loss, ide_loss, s, e = model(
                input_ids,
                input_ids_q,
                token_type_ids=segment_ids,
                attention_mask=input_mask,
                attention_mask_q=input_mask_q,
                can_answer=can_answer,
                start_positions=start_positions,
                end_positions=end_positions)
            main_losses += main_loss.mean().item()
            ide_losses += ide_loss.mean().item()
            if step % 100 == 0 and step:
                print('After {} steps, main_loss is {}, ide_loss is {}'.format(
                    step, main_losses / step, ide_losses / step))
            elif step == 0:
                print('After {} steps, main_loss is {}, ide_loss is {}'.format(
                    step, main_losses, ide_losses))
            # loss = loss / args.gradient_accumulation_steps
            # loss.backward()
            # if n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()

            # Update the parameters after gradient accumulation
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            # Validation
            if step % args.log_step == 4:
                eval_loss = evaluate.evaluate(model, dev_dataloader,
                                              device_ids)
                if eval_loss < best_loss:
                    best_loss = eval_loss
                    if len(device_ids) > 1:
                        torch.save(model.module.state_dict(),
                                   './model_dir/' + "best_model_ablation-mult")
                    if len(device_ids) == 1:
                        torch.save(model.state_dict(),
                                   './model_dir/' + "best_model")
                model.train()
Example #5
def train():
    # Load the pretrained BERT
    # model = BertForQuestionAnswering.from_pretrained('./roberta_wwm_ext',
    #                 cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(-1)))

    # Load the previously trained model
    output_model_file = "./model_dir_/best_base_qback"
    output_config_file = "./model_dir/bert_config.json"
    print('output_model_file is {}'.format(output_model_file))
    config = BertConfig(output_config_file)
    model = BertForQuestionAnswering(config)
    # How to load a model that was trained on multiple GPUs:
    state_dict = torch.load(output_model_file)
    # Initialize an empty dict
    new_state_dict = OrderedDict()
    # Fix the keys: leave a key unchanged if it has no 'module.' prefix, otherwise strip the 'module.' prefix
    for k, v in state_dict.items():
        if 'module' not in k:
            k = k
        else:
            k = k.replace('module.', '')
        new_state_dict[k] = v
    model.load_state_dict(new_state_dict)
    #
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids)  # declare all available devices
        model = model.cuda(device=device_ids[0])  # place the model on the primary device
    elif len(device_ids) == 1:
        # model.to(device)
        model.cuda()  # used on Windows

    # Prepare the optimizer
    param_optimizer = list(model.named_parameters())
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]
    optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=0.1, t_total=args.num_train_optimization_steps)
    # optimizer = nn.DataParallel(optimizer, device_ids=device_ids)
    # Prepare the data
    data = Dureader()
    train_dataloader, dev_dataloader = data.train_iter, data.dev_iter

    best_loss = 100000.0
    model.train()
    for epoch in range(args.num_train_epochs):
        main_losses, ide_losses = 0, 0
        train_loss, train_loss_total = 0.0, 0.0
        n_words, n_words_total = 0, 0
        n_sents, n_sents_total = 0, 0
        for step, batch in enumerate(tqdm(train_dataloader, desc="Epoch")):
            input_ids, input_mask, input_ids_q, input_mask_q, answer_ids, answer_mask, \
            segment_ids, can_answer = \
                batch.input_ids, batch.input_mask, batch.input_ids_q, batch.input_mask_q,\
                batch.answer_ids, batch.answer_mask, batch.segment_ids, batch.can_answer

            answer_inputs = answer_ids[:, :-1]  # drop the EOS token
            answer_targets = answer_ids[:, 1:]  # drop the BOS token
            answer_len = answer_mask.sum(1) - 1

            # flag = torch.ones(4).cuda(device=device_ids[0])

            if len(device_ids) > 1:
                input_ids, input_mask, input_ids_q, input_mask_q, answer_inputs, answer_len, answer_targets,\
                segment_ids, can_answer = \
                    input_ids.cuda(device=device_ids[0]), input_mask.cuda(device=device_ids[0]), \
                    input_ids_q.cuda(device=device_ids[0]), input_mask_q.cuda(device=device_ids[0]), \
                    answer_inputs.cuda(device=device_ids[0]), answer_len.cuda(device=device_ids[0]), \
                    answer_targets.cuda(device=device_ids[0]), \
                    segment_ids.cuda(device=device_ids[0]), can_answer.cuda(device=device_ids[0])
            elif len(device_ids) == 1:
                # input_ids, input_mask, segment_ids, can_answer, start_positions, end_positions = \
                #     input_ids.to(device), input_mask.to(device), \
                #     segment_ids.to(device), can_answer.to(device), start_positions.to(device), \
                #     end_positions.to(device)
                # windows
                input_ids, input_mask, input_ids_q, input_mask_q, answer_inputs, answer_len, answer_targets, \
                segment_ids, can_answer = \
                    input_ids.cuda(), input_mask.cuda(), input_ids_q.cuda(), input_mask_q.cuda(), \
                    answer_inputs.cuda(), answer_len.cuda(), answer_targets.cuda(), \
                    segment_ids.cuda(), can_answer.cuda()
                # print("gpu nums is 1.")

            # Compute the loss
            loss = model(input_ids, input_ids_q, token_type_ids=segment_ids,
                         attention_mask=input_mask, attention_mask_q=input_mask_q,
                         dec_inputs=answer_inputs, dec_inputs_len=answer_len,
                         dec_targets=answer_targets, can_answer=can_answer)
            # main_losses += main_loss.mean().item()
            # ide_losses += ide_loss.mean().item()
            train_loss_total += float(loss.item())
            n_words_total += torch.sum(answer_len)
            n_sents_total += answer_len.size(0)  # batch_size

            if step % args.display_freq == 0:
                loss_int = (train_loss_total - train_loss)
                n_words_int = (n_words_total - n_words)
                avg_loss = loss_int / n_words_int

                print('Epoch {0:<3}'.format(epoch),
                      'Step {0:<10}'.format(step),
                      'Avg_loss {0:<10.2f}'.format(avg_loss))
                train_loss, n_words, n_sents = (train_loss_total, n_words_total.item(), n_sents_total)
            # loss = loss / args.gradient_accumulation_steps
            # loss.backward()
            # if n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()

            # Update the parameters after gradient accumulation
            if (step+1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            # Validation
            if step % args.log_step == 4:
                eval_loss = evaluate.evaluate(model, dev_dataloader, device_ids)
                if eval_loss < best_loss:
                    best_loss = eval_loss
                    if len(device_ids) > 1:
                        torch.save(model.module.state_dict(), './model_dir/' + "best_base_backgate")
                    if len(device_ids) == 1:
                        torch.save(model.state_dict(), './model_dir_/' + "best_base_hype4_6")
                model.train()
Example #6
def train():
    print(args.config_name)
    # Load the pretrained BERT
    if args.state_dict is not None:
        model = BertForQuestionAnswering.from_pretrained(
            'bert-base-chinese',
            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                   'distributed_{}'.format(-1)),
            state_dict=torch.load(args.state_dict),  #, map_location='cpu'),
            config_name=args.config_name)


#        try:
#            model.load_state_dict(torch.load(output_model_file))
#        except:
#            model = nn.DataParallel(model)
#            model.load_state_dict(torch.load(output_model_file))
    else:
        model = BertForQuestionAnswering.from_pretrained(
            'bert-base-chinese',
            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                   'distributed_{}'.format(-1)),
            config_name=args.config_name)

    # use multiple GPUs
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.device_count() > 1:
        print("We have ", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    print('device: ', device)
    model.to(device)

    # Prepare the optimizer
    param_optimizer = list(model.named_parameters())
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=0.1,
                         t_total=NUM_TRAIN_OPTIMIZATION_STEPS)

    # Prepare the data
    data = Dureader(args.batch_size, args.dataset_path, TRAINSET_NAME,
                    DEVSET_NAME)
    train_dataloader, dev_dataloader = data.train_iter, data.dev_iter

    best_loss = 100000.0
    model.train()
    for i in range(args.epochs):
        print("==Epoch: ", i)
        for step, batch in enumerate(tqdm(train_dataloader, desc="Epoch")):
            input_ids, input_mask, segment_ids, start_positions, end_positions = \
                batch.input_ids, batch.input_mask, batch.segment_ids, batch.start_position, batch.end_position
            input_ids, input_mask, segment_ids, start_positions, end_positions = \
                input_ids.to(device), input_mask.to(device), segment_ids.to(device), start_positions.to(device), end_positions.to(device)

            # Compute the loss
            loss, _, _ = model(input_ids,
                               token_type_ids=segment_ids,
                               attention_mask=input_mask,
                               start_positions=start_positions,
                               end_positions=end_positions)
            loss = loss / args.gradient_accumulation_steps
            loss.sum().backward()

            # Update the parameters after gradient accumulation
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            # Validation
            if step % LOG_STEP == 4 and not args.no_eval:
                eval_loss = evaluate.evaluate(model, dev_dataloader, device,
                                              args.gradient_accumulation_steps)
                print('eval loss: ', eval_loss)
                if eval_loss < best_loss:
                    print("save the best model")
                    best_loss = eval_loss
                    torch.save(model.state_dict(), SAVE_MODEL_PATH)
                model.train()
    # Final evaluation
    eval_loss = evaluate.evaluate(model, dev_dataloader, device,
                                  args.gradient_accumulation_steps)
    print('final eval loss: ', eval_loss)
    if eval_loss < best_loss:
        print("save the best model")
        best_loss = eval_loss
        torch.save(model.state_dict(), SAVE_MODEL_PATH)
        model.train()

    print('training finished!')
Example #7
def train():
    # Load the pretrained BERT
    model = BertForQuestionAnswering.from_pretrained(
        './roberta_wwm_ext',
        cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                               'distributed_{}'.format(-1)))
    # output_model_file = "./model_dir_/baseextra"
    # output_config_file = "./model_dir/bert_config.json"
    #
    # config = BertConfig(output_config_file)
    # model = BertForQuestionAnswering(config)
    # # How to load a model that was trained on multiple GPUs:
    # state_dict = torch.load(output_model_file, map_location='cuda:0')
    # # Initialize an empty dict
    # new_state_dict = OrderedDict()
    # # Fix the keys: leave a key unchanged if it has no 'module.' prefix, otherwise strip the 'module.' prefix
    # for k, v in state_dict.items():
    #     if 'module' not in k:
    #         k = k
    #     else:
    #         k = k.replace('module.', '')
    #     new_state_dict[k] = v
    # model.load_state_dict(new_state_dict)

    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids)  # declare all available devices
        model = model.cuda(device=device_ids[0])  # place the model on the primary device
    elif len(device_ids) == 1:
        model = model.cuda()

    # Prepare the optimizer
    param_optimizer = list(model.named_parameters())
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=0.1,
                         t_total=args.num_train_optimization_steps)
    # optimizer = nn.DataParallel(optimizer, device_ids=device_ids)
    # Prepare the data
    data = Dureader()
    train_dataloader, dev_dataloader = data.train_iter, data.dev_iter

    best_loss = 100000.0
    model.train()
    # Initialize FGM for adversarial training
    fgm = FGM(model)
    for i in range(args.num_train_epochs):
        for step, batch in enumerate(tqdm(train_dataloader, desc="Epoch")):
            input_ids, input_mask, segment_ids, start_positions, end_positions = \
                batch.input_ids, batch.input_mask, batch.segment_ids, batch.start_position, batch.end_position
            if len(device_ids) > 1:
                input_ids, input_mask, segment_ids, start_positions, end_positions = \
                    input_ids.cuda(device=device_ids[0]), input_mask.cuda(device=device_ids[0]), \
                    segment_ids.cuda(device=device_ids[0]), start_positions.cuda(device=device_ids[0]), \
                    end_positions.cuda(device=device_ids[0])
            elif len(device_ids) == 1:
                input_ids, input_mask, segment_ids, start_positions, end_positions = \
                    input_ids.cuda(), input_mask.cuda(), \
                    segment_ids.cuda(), start_positions.cuda(), \
                    end_positions.cuda()
            # Normal training step
            # Compute the loss
            loss, s, e = model(input_ids,
                               token_type_ids=segment_ids,
                               attention_mask=input_mask,
                               start_positions=start_positions,
                               end_positions=end_positions)

            # loss = loss / args.gradient_accumulation_steps
            # loss.backward()
            # if n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()

            # Adversarial training
            fgm.attack()  # add an adversarial perturbation to the embeddings
            loss_adv, s, e = model(input_ids,
                                   token_type_ids=segment_ids,
                                   attention_mask=input_mask,
                                   start_positions=start_positions,
                                   end_positions=end_positions)
            loss_adv.backward()  # backprop, accumulating the adversarial gradient on top of the normal gradient
            fgm.restore()  # restore the embedding parameters
            # Update the parameters after gradient accumulation
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            # Validation
            if step % args.log_step == 4:
                # torch.save(model.state_dict(), './model_dir/' + "best_model_baseextraadv_d")
                eval_loss = evaluate.evaluate(model, dev_dataloader,
                                              device_ids)
                if eval_loss < best_loss:
                    best_loss = eval_loss
                    if len(device_ids) > 1:
                        torch.save(model.module.state_dict(),
                                   './model_dir/' + "best_model_base")
                    if len(device_ids) == 1:
                        torch.save(model.state_dict(),
                                   './model_dir/' + "best_model_base")
                model.train()
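Example #7 relies on an FGM helper class that is not defined in this listing. A minimal sketch of the commonly used FGM (Fast Gradient Method) adversarial-training wrapper that matches the attack()/restore() interface above; the epsilon value and the 'word_embeddings' parameter-name filter are assumptions, not taken from the original code:

import torch

class FGM:
    """Adversarial training on the embedding layer via a single FGM step."""

    def __init__(self, model):
        self.model = model
        self.backup = {}

    def attack(self, epsilon=1.0, emb_name='word_embeddings'):
        # Perturb the embedding weights along the (normalized) gradient direction.
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None and emb_name in name:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    r_at = epsilon * param.grad / norm
                    param.data.add_(r_at)

    def restore(self, emb_name='word_embeddings'):
        # Restore the original embedding weights after the adversarial backward pass.
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name and name in self.backup:
                param.data = self.backup[name]
        self.backup = {}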