def train(
        batch_size: int = 16,
        pretrain_model_path: str = '',
        name: str = '',
        model_type: str = 'mlp',
        after_bert_choice: str = 'last_cls',
        dim: int = 1024,
        lr: float = 1e-5,
        epoch: int = 12,
        smoothing: float = 0.05,
        sample: bool = False,
        #open_ad='',
        # NOTE(review): open_ad is disabled here but referenced in the training
        # loop below (fgm/pgd branches); reaching those branches will raise
        # NameError — confirm intended state.
        dialog_name: str = 'xxx'):
    """Fine-tune a BERT-based classifier and checkpoint the best model.

    Loads train/dev CSVs from fixed Kaggle paths, builds the model selected
    by ``model_type`` ('mlp', 'cnn' or 'rcnn'), trains with BertAdam for up
    to ``epoch`` epochs, and saves the best-dev-loss state dict to
    ``output/<name>.pkl``. Training stops early after 2 consecutive epochs
    without dev-loss improvement; on a non-improving epoch the learning rate
    is halved.

    Args:
        batch_size: mini-batch size for both the train and dev loaders.
        pretrain_model_path: path to pretrained BERT weights (required).
        name: checkpoint file stem (required).
        model_type: one of 'mlp', 'cnn', 'rcnn'; 'siam'/'esim'/'sbert' abort.
        after_bert_choice: pooling choice; NOTE the 'mlp' branch below passes
            the literal 'last_cls' regardless of this argument.
        dim: hidden dimension forwarded to the model constructors.
        lr: initial learning rate (halved on non-improving epochs).
        epoch: maximum number of training epochs.
        smoothing: label-smoothing factor forwarded to the model.
        sample: forwarded to MyDataset; semantics not visible in this file.
        dialog_name: unused in the visible portion of this function.
    """

    # Both arguments are mandatory; deliberately crash when missing.
    # NOTE(review): `assert 1 == -1` is stripped under `python -O`;
    # raising ValueError would be safer.
    if not pretrain_model_path or not name:
        assert 1 == -1

    print('\n********** model type:', model_type, '**********')
    print('batch_size:', batch_size)

    # load dataset
    train_file = '/kaggle/input/dataset/my_train.csv'
    dev_file = '/kaggle/input/dataset/my_dev.csv'

    # Row counts are used later as accuracy denominators.
    train_num = len(pd.read_csv(train_file).values.tolist())
    val_num = len(pd.read_csv(dev_file).values.tolist())
    print('train_num: %d, dev_num: %d' % (train_num, val_num))

    # Select the model type; these three are unsupported in this script.
    if model_type in ['siam', 'esim', 'sbert']:
        assert 1 == -1

    else:
        train_iter = MyDataset(file=train_file,
                               is_train=True,
                               sample=sample,
                               pretrain_model_path=pretrain_model_path)
        # drop_last=True: incomplete final batch is discarded, so slightly
        # fewer than train_num samples are scored per epoch (train_acc is
        # marginally understated).
        train_iter = get_dataloader(train_iter,
                                    batch_size,
                                    shuffle=True,
                                    drop_last=True)
        # NOTE(review): dev set is built with is_train=True — presumably so
        # labels/loss are available during evaluation; confirm.
        dev_iter = MyDataset(file=dev_file,
                             is_train=True,
                             sample=sample,
                             pretrain_model_path=pretrain_model_path)
        dev_iter = get_dataloader(dev_iter,
                                  batch_size,
                                  shuffle=False,
                                  drop_last=False)

        if model_type == 'mlp':
            # NOTE(review): hard-codes 'last_cls', ignoring the
            # after_bert_choice parameter — confirm intentional.
            model = MyModel(dim=dim,
                            pretrain_model_path=pretrain_model_path,
                            smoothing=smoothing,
                            after_bert_choice='last_cls')

        elif model_type == 'cnn':
            model = MyTextCNNModel(dim=dim,
                                   pretrain_model_path=pretrain_model_path,
                                   smoothing=smoothing)

        elif model_type == 'rcnn':
            model = MyRCNNModel(dim=dim,
                                pretrain_model_path=pretrain_model_path,
                                smoothing=smoothing)

    # Move the model to GPU.
    model.to(device)
    model_param_num = 0

    ##### 3.24 multi-GPU training
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Count trainable parameters for the log line below.
    for p in model.parameters():
        if p.requires_grad:
            model_param_num += p.nelement()
    print('param_num:%d\n' % model_param_num)

    # Adversarial training improves generalization but noticeably slows
    # training (plug-in style usage).
    # 3.12: switched to FGM — faster!
    """
    if open_ad == 'fgm':
        fgm = FGM(model)
    elif open_ad == 'pgd':
        pgd = PGD(model)
        K = 3
    """
    # model-store-path
    #model_path = '/kaggle/output/' + name + '.pkl' # output model stored under the current path by default
    output_dir = 'output'
    state = {}
    time0 = time.time()
    best_loss = 999
    early_stop = 0
    for e in range(epoch):
        print("*" * 100)
        print("Epoch:", e)
        # Standard BERT recipe: no weight decay on bias/LayerNorm parameters.
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        # Optimizer is rebuilt every epoch so warmup/t_total restart with the
        # (possibly halved) learning rate.
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=lr,
                             warmup=0.05,
                             t_total=len(train_iter))  # configure the optimizer
        train_loss = 0
        train_c = 0
        train_right_num = 0

        model.train()  # set the module to training mode
        print('training..., %s, e:%d, lr:%7f' % (name, e, lr))
        for batch in tqdm(train_iter):  # yields batch_size samples per step

            optimizer.zero_grad()  # clear accumulated gradients
            batch = [b.to(device) for b in batch]  # CPU -> GPU

            # Standard (non-adversarial) training step.
            labels = batch[-1].view(-1).cpu().numpy()
            # NOTE(review): passes the TOTAL epoch count `epoch`, not the
            # current epoch index `e` — confirm intended.
            loss, bert_enc = model(batch, task='train',
                                   epoch=epoch)  # forward pass; compute loss
            right_num = count_right_num(bert_enc, labels)

            # multi-gpu training!
            if n_gpu > 1:
                loss = loss.mean()

            loss.backward()  # backprop to compute parameter gradients

            #"""
            # NOTE(review): open_ad / fgm / pgd / K are never defined in the
            # visible code (parameter and setup are commented out above);
            # executing either branch raises NameError — confirm.
            if open_ad == 'fgm':
                # Adversarial training
                fgm.attack()  # add adversarial perturbation to embeddings

                if model_type == 'multi-task':
                    loss_adv, _, _ = model(batch, task='train')
                else:
                    loss_adv, _ = model(batch, task='train')

                if n_gpu > 1:
                    loss_adv = loss_adv.mean()

                loss_adv.backward()  # backprop, accumulating adversarial grads on top of the normal ones
                fgm.restore()  # restore embedding parameters

            elif open_ad == 'pgd':
                pgd.backup_grad()
                # Adversarial training
                for t in range(K):
                    pgd.attack(is_first_attack=(
                        t == 0
                    ))  # perturb embeddings; back up param.data on the first attack
                    if t != K - 1:
                        optimizer.zero_grad()
                    else:
                        pgd.restore_grad()

                    if model_type == 'multi-task':
                        loss_adv, _, _ = model(batch, task='train')
                    else:
                        loss_adv, _ = model(batch, task='train')

                    if n_gpu > 1:
                        loss_adv = loss_adv.mean()

                    loss_adv.backward()  # backprop, accumulating adversarial grads on top of the normal ones
                pgd.restore()  # restore embedding parameters
            #"""
            optimizer.step()  # update parameters

            train_loss += loss.item()  # accumulate loss
            train_c += 1
            train_right_num += right_num

        val_loss = 0
        val_c = 0
        val_right_num = 0

        model.eval()
        print('eval...')
        with torch.no_grad():  # disable gradient tracking for evaluation
            for batch in tqdm(dev_iter):  # yields batch_size samples per step
                batch = [b.to(device) for b in batch]

                labels = batch[-1].view(-1).cpu().numpy()
                # NOTE(review): task='train' during eval — presumably so the
                # model returns a loss; confirm.
                loss, bert_enc = model(batch, task='train',
                                       epoch=epoch)  # forward pass; compute loss
                right_num = count_right_num(bert_enc, labels)

                if n_gpu > 1:
                    loss = loss.mean()

                val_c += 1
                val_loss += loss.item()
                val_right_num += right_num

        train_acc = train_right_num / train_num
        val_acc = val_right_num / val_num

        print('train_acc: %.4f, val_acc: %.4f' % (train_acc, val_acc))
        print('train_loss: %.4f, val_loss: %.4f, time: %d' %
              (train_loss / train_c, val_loss / val_c, time.time() - time0))

        # Checkpoint when mean dev loss improves.
        if val_loss / val_c < best_loss:
            early_stop = 0
            best_loss = val_loss / val_c
            best_acc = val_acc

            # 3.24 update: multi-GPU save pitfall — unwrap DataParallel first
            # so state-dict keys lack the 'module.' prefix.
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)

            model_to_save = model.module if hasattr(model, 'module') else model
            state['model_state'] = model_to_save.state_dict()
            state['loss'] = val_loss / val_c
            state['acc'] = val_acc
            state['e'] = e
            state['time'] = time.time() - time0
            state['lr'] = lr

            output_model_file = os.path.join(output_dir, name + '.pkl')
            torch.save(state, output_model_file)
            #torch.save(state, model_path)

            best_epoch = e
            cost_time = time.time() - time0
            tmp_train_acc = train_acc
            # NOTE(review): stores a REFERENCE, not a copy — best_model
            # aliases model and reflects later weight updates; only the
            # state dict saved above is a true snapshot.
            best_model = model

        else:
            early_stop += 1
            # Stop after 2 consecutive epochs without improvement.
            if early_stop == 2:
                break

            # NOTE(review): because best_model aliases model (see above),
            # this assignment is a no-op rollback.
            model = best_model
            lr = lr * 0.5  # halve the learning rate and retry
        print("best_loss:", best_loss)

    # 3.12 add: print the final best results
    print('-' * 30)
    print('best_epoch:', best_epoch, 'best_loss:', best_loss, 'best_acc:',
          best_acc, 'reach time:', cost_time, '\n')

    # model-clean: release model memory
    del model
    gc.collect()

    # Write experiment results to the log
    """