Example #1
def main():
    data_path = './data/chatbot.txt'
    voc, pairs = loadPrepareData(data_path)

    # Drop sentence pairs that contain low-frequency words
    MIN_COUNT = Config.MIN_COUNT
    pairs = trimRareWords(voc, pairs, MIN_COUNT)

    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(Config.batch_size)])
                        for _ in range(Config.total_step)]

    # Word embedding layer
    embedding = nn.Embedding(voc.num_words, Config.hidden_size)

    # Define the encoder and decoder
    encoder = EncoderRNN(Config.hidden_size, embedding, Config.encoder_n_layers, Config.dropout)
    decoder = LuongAttnDecoderRNN(Config.attn_model, embedding, Config.hidden_size, voc.num_words, Config.decoder_n_layers, Config.dropout)

    # Define the optimizers
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=Config.learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=Config.learning_rate * Config.decoder_learning_ratio)

    start_iteration = 1
    save_every = 4000   # how many steps between checkpoint saves

    for iteration in range(start_iteration, Config.total_step + 1):
        training_batch = training_batches[iteration - 1]
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        start_time = time.time()
        # Run a training iteration with batch
        loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                     decoder, embedding, encoder_optimizer, decoder_optimizer, Config.batch_size, Config.clip)

        time_str = datetime.datetime.now().isoformat()
        log_str = "time: {}, Iteration: {}; Percent complete: {:.1f}%; loss: {:.4f}, spend_time: {:6f}".format(time_str, iteration, iteration / Config.total_step * 100, loss, time.time() - start_time)
        rainbow(log_str)

        # Save checkpoint
        if iteration % save_every == 0:
            save_path = './save_model/'
            if not os.path.exists(save_path):
                os.makedirs(save_path)

            torch.save({
                'iteration': iteration,
                'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'voc_dict': voc.__dict__,
                'embedding': embedding.state_dict()
            }, os.path.join(save_path, '{}_{}_model.tar'.format(iteration, 'checkpoint')))
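The checkpoint saved above bundles everything needed to resume or serve the model. A minimal reload sketch, assuming the same Config, loadPrepareData, EncoderRNN and LuongAttnDecoderRNN are importable; the file name is illustrative:

import torch
import torch.nn as nn

# Illustrative checkpoint path following the naming pattern above (iteration 4000)
checkpoint = torch.load('./save_model/4000_checkpoint_model.tar', map_location='cpu')

# Rebuild the vocabulary object and overwrite its state with the saved one
voc, _ = loadPrepareData('./data/chatbot.txt')
voc.__dict__ = checkpoint['voc_dict']

embedding = nn.Embedding(voc.num_words, Config.hidden_size)
embedding.load_state_dict(checkpoint['embedding'])

encoder = EncoderRNN(Config.hidden_size, embedding, Config.encoder_n_layers, Config.dropout)
decoder = LuongAttnDecoderRNN(Config.attn_model, embedding, Config.hidden_size,
                              voc.num_words, Config.decoder_n_layers, Config.dropout)
encoder.load_state_dict(checkpoint['encoder'])
decoder.load_state_dict(checkpoint['decoder'])
encoder.eval()
decoder.eval()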
Example #2
def train():
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = optim.Adam(model.parameters(), lr=Config.lr, weight_decay=Config.weight_decay)

    for epoch in range(Config.max_epoch):
        model.train()
        model.to(Config.device)
        for index, batch in enumerate(train_dataloader):
            optimizer.zero_grad()

            X = batch['x'].long().to(Config.device)    # torch.Size([4, 60])    (batch_size, max_len)
            y = batch['y'].long().to(Config.device)    # torch.Size([4, 60])    (batch_size, max_len)

            # CRF
            loss = model.log_likelihood(X, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=10)
            optimizer.step()

            now_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
            o_str = 'time: {}, epoch: {}, step: {}, loss: {:6f}'.format(now_time, epoch, index, loss.item())
            rainbow(o_str)

        aver_loss = 0
        preds, labels = [], []
        for index, batch in enumerate(valid_dataloader):
            model.eval()
            val_x, val_y = batch['x'].long().to(Config.device), batch['y'].long().to(Config.device)
            predict = model(val_x)
            # CRF
            loss = model.log_likelihood(val_x, val_y)
            aver_loss += loss.item()
            # Count the non-zero (non-padding) entries, i.e. the true label length of each sequence
            leng = []
            for i in val_y.cpu():
                tmp = []
                for j in i:
                    if j.item() > 0:
                        tmp.append(j.item())
                leng.append(tmp)

            for idx, i in enumerate(predict):
                preds += i[:len(leng[idx])]

            for idx, i in enumerate(val_y.tolist()):
                labels += i[:len(leng[idx])]
        aver_loss /= (len(valid_dataloader) * 64)
        precision = precision_score(labels, preds, average='macro')
        recall = recall_score(labels, preds, average='macro')
        f1 = f1_score(labels, preds, average='macro')
        report = classification_report(labels, preds)
        print(report)
        torch.save(model.state_dict(), './save_model/bilstm_ner.bin')
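The model's log_likelihood method (the CRF loss used above) is not shown. As a rough sketch of how such a BiLSTM-CRF is often written, here is a hypothetical version built on the pytorch-crf package, keeping the snippet's convention that tag id 0 is padding; the class name and layer sizes are assumptions, not taken from this code:

import torch.nn as nn
from torchcrf import CRF  # pip install pytorch-crf


class BiLSTMCRF(nn.Module):
    def __init__(self, vocab_size, num_tags, embed_dim=128, hidden_dim=256, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_dim, hidden_dim // 2, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def _emissions(self, x):
        out, _ = self.lstm(self.embedding(x))
        return self.fc(out)

    def forward(self, x):
        # Viterbi decoding: a list of predicted tag sequences, one per example
        return self.crf.decode(self._emissions(x), mask=x.gt(0))

    def log_likelihood(self, x, y):
        # Negative log-likelihood of the gold tag sequence under the CRF
        return -self.crf(self._emissions(x), y, mask=x.gt(0))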
Example #3
        def train_step(x_batch, y_batch):
            """
            A single training step
            """
            feed_dict = {
                model.input_x: x_batch,
                model.input_y: y_batch,
                model.dropout_keep_prob: FLAGS.dropout_keep_prob
            }

            _, step, summaries, lr, loss, l2_loss, accuracy = sess.run([
                train_op, global_step, train_summary_op, dlearning_rate,
                model.loss, model.l2_losses, model.accuracy
            ], feed_dict)

            time_str = datetime.datetime.now().strftime("%H:%M:%S.%f")
            rainbow(
                "train set:*** {}: step {}, learning_rate {:5f}, loss {:g}, l2_loss {:g}, acc {:g}"
                .format(time_str, step, lr, loss, l2_loss, accuracy),
                time_tag=True)

            train_summary_writer.add_summary(summaries, step)

            return loss
Example #4
def train():
    device = Config.device
    # Prepare the data
    train_data, dev_data = build_dataset(Config)
    train_iter = DatasetIterater(train_data, Config)
    dev_iter = DatasetIterater(dev_data, Config)

    model = Model().to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    # optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    # Use the AdamW optimizer (BertAdam-style) instead
    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=Config.learning_rate,
        correct_bias=False)  # set correct_bias=False to reproduce BertAdam's behavior
    # Linear warmup/decay schedule; num_warmup_steps is a step count, so warm up over 5% of the total steps
    total_steps = len(train_iter) * Config.num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.05 * total_steps),
        num_training_steps=total_steps)

    model.to(device)
    model.train()

    best_loss = 100000.0
    for epoch in range(Config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, Config.num_epochs))
        for step, batch in enumerate(train_iter):
            start_time = time.time()
            ids, input_ids, input_mask, start_positions, end_positions = \
                batch[0], batch[1], batch[2], batch[3], batch[4]
            input_ids, input_mask, start_positions, end_positions = \
                input_ids.to(device), input_mask.to(device), start_positions.to(device), end_positions.to(device)

            # print(input_ids.size())
            # print(input_mask.size())
            # print(start_positions.size())
            # print(end_positions.size())

            loss, _, _ = model(input_ids,
                               attention_mask=input_mask,
                               start_positions=start_positions,
                               end_positions=end_positions)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=20)
            optimizer.step()
            scheduler.step()

            time_str = datetime.datetime.now().isoformat()
            log_str = 'time:{}, epoch:{}, step:{}, loss:{:8f}, spend_time:{:6f}'.format(
                time_str, epoch, step, loss,
                time.time() - start_time)
            rainbow(log_str)

            train_loss.append(loss.item())

        if epoch % 1 == 0:
            eval_loss = valid(model, dev_iter)
            if eval_loss < best_loss:
                best_loss = eval_loss
                torch.save(model.state_dict(),
                           './save_model/' + 'best_model.bin')
                model.train()
Example #5
            bert_encode = model(batch_data,
                                token_type_ids=None,
                                attention_mask=batch_masks,
                                labels=batch_tags)
            train_loss = model.loss_fn(bert_encode=bert_encode,
                                       tags=batch_tags,
                                       output_mask=batch_masks)
            train_loss.backward()
            # gradient clipping
            nn.utils.clip_grad_norm_(parameters=model.parameters(),
                                     max_norm=Config.clip_grad)
            # performs updates using calculated gradients
            optimizer.step()

            predicts = model.predict(bert_encode, batch_masks)
            label_ids = batch_tags.view(1, -1)
            label_ids = label_ids[label_ids != -1]
            label_ids = label_ids.cpu()
            train_acc, f1 = model.acc_f1(predicts, label_ids)
            s = "Epoch:{}, step:{}, loss:{:8f}, acc:{:5f}, f1:{:5f}, spend_time:{:6f}".format(
                epoch, i, train_loss, train_acc, f1,
                time.time() - start_time)
            rainbow(s)

        if epoch % 1 == 0:
            eval_f1 = evaluate(model, val_data)
            if eval_f1 > best_f1:
                best_f1 = eval_f1
                torch.save(model.state_dict(),
                           './save_model/' + 'best_model.bin')
Example #6
def train():
    # 1. Prepare the dataset
    data = json.load(open(Config.train_data_path, 'r'))

    input_data = data['input_data']
    input_len = data['input_len']
    output_data = data['output_data']
    mask_data = data['mask']
    output_len = data['output_len']

    total_len = len(input_data)
    step = total_len // Config.batch_size

    # Word embedding layer
    embedding = nn.Embedding(Config.vocab_size,
                             Config.hidden_size,
                             padding_idx=Config.PAD)

    # 2. Build the models
    encoder = Encoder(embedding)
    attn_model = 'dot'
    decoder = Decoder(
        attn_model,
        embedding,
    )

    encoder_optimizer = torch.optim.Adam(encoder.parameters(),
                                         lr=Config.learning_rate)
    decoder_optimizer = torch.optim.Adam(decoder.parameters(),
                                         lr=Config.learning_rate)

    for epoch in range(Config.num_epochs):
        for i in range(step - 1):
            start_time = time.time()
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            input_ids = torch.LongTensor(
                input_data[i * Config.batch_size:(i + 1) *
                           Config.batch_size]).to(Config.device)
            inp_len = torch.LongTensor(
                input_len[i * Config.batch_size:(i + 1) *
                          Config.batch_size]).to(Config.device)
            output_ids = torch.LongTensor(
                output_data[i * Config.batch_size:(i + 1) *
                            Config.batch_size]).to(Config.device)
            mask = torch.BoolTensor(mask_data[i * Config.batch_size:(i + 1) *
                                              Config.batch_size]).to(
                                                  Config.device)
            out_len = output_len[i * Config.batch_size:(i + 1) *
                                 Config.batch_size]

            max_ans_len = max(out_len)

            mask = mask.permute(1, 0)
            output_ids = output_ids.permute(1, 0)
            encoder_outputs, hidden = encoder(input_ids, inp_len)
            encoder_outputs = encoder_outputs.permute(1, 0, 2)
            decoder_hidden = hidden.unsqueeze(0)

            # Create the initial decoder input (an SOS token for every sequence in the batch)
            decoder_input = torch.LongTensor(
                [[Config.SOS for _ in range(Config.batch_size)]])
            decoder_input = decoder_input.to(Config.device)

            # Determine if we are using teacher forcing this iteration
            teacher_forcing_ratio = 0.3
            use_teacher_forcing = random.random() < teacher_forcing_ratio

            loss = 0
            print_losses = []
            n_totals = 0
            if use_teacher_forcing:
                # Teacher forcing: feed the ground-truth token from the previous step as the next decoder input
                for t in range(max_ans_len):
                    decoder_output, decoder_hidden = decoder(
                        decoder_input, decoder_hidden, encoder_outputs)
                    # print(decoder_output.size())  # torch.Size([2, 2672])
                    # print(decoder_hidden.size())   # torch.Size([1, 2, 512])

                    decoder_input = output_ids[t].view(1, -1)
                    # 计算损失
                    mask_loss, nTotal = maskNLLLoss(decoder_output,
                                                    output_ids[t], mask[t])
                    # print('1', mask_loss)
                    loss += mask_loss
                    print_losses.append(mask_loss.item() * nTotal)
                    n_totals += nTotal
            else:
                # No teacher forcing: feed the decoder's own prediction from the previous step
                for t in range(max_ans_len):
                    decoder_output, decoder_hidden = decoder(
                        decoder_input, decoder_hidden, encoder_outputs)

                    _, topi = decoder_output.topk(1)
                    decoder_input = torch.LongTensor(
                        [[topi[i][0] for i in range(Config.batch_size)]])
                    decoder_input = decoder_input.to(Config.device)
                    # Calculate and accumulate loss
                    mask_loss, nTotal = maskNLLLoss(decoder_output,
                                                    output_ids[t], mask[t])
                    # print('2', mask_loss)
                    loss += mask_loss
                    print_losses.append(mask_loss.item() * nTotal)
                    n_totals += nTotal

            # Perform backpropagation
            loss.backward()

            # Gradient clipping
            _ = nn.utils.clip_grad_norm_(encoder.parameters(), Config.clip)
            _ = nn.utils.clip_grad_norm_(decoder.parameters(), Config.clip)

            # Adjust model weights
            encoder_optimizer.step()
            decoder_optimizer.step()
            avg_loss = sum(print_losses) / n_totals

            time_str = datetime.datetime.now().isoformat()
            log_str = 'time:{}, epoch:{}, step:{}, loss:{:5f}, spend_time:{:6f}'.format(
                time_str, epoch, i, avg_loss,
                time.time() - start_time)
            rainbow(log_str)

        if epoch % 1 == 0:
            save_path = './save_model/'
            if not os.path.exists(save_path):
                os.makedirs(save_path)

            torch.save(
                {
                    'epoch': epoch,
                    'encoder': encoder.state_dict(),
                    'decoder': decoder.state_dict(),
                    'en_opt': encoder_optimizer.state_dict(),
                    'de_opt': decoder_optimizer.state_dict(),
                    'loss': avg_loss,
                    'embedding': embedding.state_dict()
                },
                os.path.join(
                    save_path,
                    'epoch{}_{}_model.tar'.format(epoch, 'checkpoint')))
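maskNLLLoss, used in both decoding branches above, is not defined in this snippet. In the PyTorch chatbot tutorial that this loop closely follows, it averages the negative log of the probability assigned to the gold token over non-padded positions; a sketch under that assumption (decoder_output is already a softmax distribution):

import torch


def maskNLLLoss(inp, target, mask):
    # inp:    (batch_size, vocab_size) softmaxed decoder output for one time step
    # target: (batch_size,) gold token ids for this time step
    # mask:   (batch_size,) bool, True where the position is real (not padding)
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    return loss, nTotal.item()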
Example #7
def train():
    # Load the dataset
    dataset = DreamDataset()
    dataloader = DataLoader(dataset,
                            batch_size=Config.batch_size,
                            shuffle=True,
                            collate_fn=collate_fn)

    # Instantiate the model
    word2idx = load_bert_vocab()
    bertconfig = BertConfig(vocab_size=len(word2idx))
    bert_model = Seq2SeqModel(config=bertconfig)
    # Load the pretrained weights
    load_model(bert_model, Config.pretrain_model_path)
    bert_model.to(Config.device)

    # Collect the parameters to optimize and define the optimizer
    optim_parameters = list(bert_model.parameters())
    optimizer = torch.optim.Adam(optim_parameters,
                                 lr=Config.learning_rate,
                                 weight_decay=1e-3)

    step = 0
    for epoch in range(Config.EPOCH):
        total_loss = 0
        i = 0
        for token_ids, token_type_ids, target_ids in dataloader:
            start_time = time.time()
            step += 1
            i += 1
            token_ids = token_ids.to(Config.device)
            token_type_ids = token_type_ids.to(Config.device)
            target_ids = target_ids.to(Config.device)
            # Because target labels are passed in, the model computes and returns the loss
            predictions, loss = bert_model(token_ids,
                                           token_type_ids,
                                           labels=target_ids,
                                           device=Config.device)

            # 1. Clear previous gradients
            optimizer.zero_grad()
            # 2. Backpropagation
            loss.backward()
            # 3. Update the parameters
            optimizer.step()

            time_str = datetime.datetime.now().isoformat()

            log_str = 'time:{}, epoch:{}, step:{}, loss:{:8f}, spend_time:{:6f}'.format(
                time_str, epoch, step, loss,
                time.time() - start_time)
            rainbow(log_str)
            # print('epoch:{}, step:{}, loss:{:6f}, spend_time:{}'.format(epoch, step, loss, time.time() - start_time))

            # Accumulate loss to compute this epoch's average
            total_loss += loss.item()

            if step % 30 == 0:
                torch.save(bert_model.state_dict(), './bert_dream.bin')

        print("当前epoch:{}, 平均损失:{}".format(epoch, total_loss / i))

        if epoch % 10 == 0:
            save_path = "./data/" + "pytorch_bert_gen_epoch{}.bin".format(
                str(epoch))
            torch.save(bert_model.state_dict(), save_path)
            print("{} saved!".format(save_path))
Example #8
def main():
    args = set_args()
    # Load the training set
    with gzip.open(args.train_data_path, 'rb') as f:
        train_features = pickle.load(f)
    
    # Load the validation set
    with gzip.open(args.dev_data_path, 'rb') as f:
        eval_features = pickle.load(f)
    
    # Total number of training steps
    num_train_steps = int(
        len(train_features) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
    
    # Model
    model = Model()

    # Run on multiple GPUs if available
    if torch.cuda.is_available():
        model.cuda()

    if torch.cuda.device_count() > 1:
        args.n_gpu = torch.cuda.device_count()
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # This single line enables multi-GPU data parallelism
        model = nn.DataParallel(model)

    tokenizer = BertTokenizer.from_pretrained(args.vocab_file)
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    warmup_steps = int(0.05 * num_train_steps)
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_steps)

    best_loss = None
    global_step = 0

    # Start training
    print("***** Running training *****")
    print("  Num examples = {}".format(len(train_features)))
    print("  Batch size = {}".format(args.train_batch_size))
    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label for f in train_features], dtype=torch.float32)

    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)

    model.train()
    for epoch in range(args.num_train_epochs):
        train_dataloader = DataLoader(train_data, shuffle=True, batch_size=args.train_batch_size)
        for step, batch in enumerate(train_dataloader):
            start_time = time.time()
            if torch.cuda.is_available():
                batch = tuple(t.cuda() for t in batch)
            input_ids, input_mask, segment_ids, label = batch

            logits = model(input_ids=input_ids, attention_mask=input_mask, segment_ids=segment_ids, labels=label)
            loss = loss_fct(logits, label)
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            s = '****Epoch: {}, step: {}, loss: {:10f}, time_cost: {:10f}'.format(epoch, step, loss, time.time() - start_time)
            rainbow(s)
            loss.backward()
            # nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)   # optional gradient clipping

            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1
            # test_loss, test_acc = evaluate(epoch, eval_features, args, model)

        # Evaluate once the epoch is finished
        test_loss, test_acc = evaluate(epoch, eval_features, args, model)
        model.train()
        if best_loss is None or best_loss > test_loss:
            best_loss = test_loss
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
            os.makedirs(args.save_model, exist_ok=True)

            output_model_file = os.path.join(args.save_model, "best_pytorch_model.bin")
            torch.save(model_to_save.state_dict(), output_model_file)

        # Save a trained model
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
        output_model_file = os.path.join(args.save_model, "epoch{}_ckpt.bin".format(epoch))
        torch.save(model_to_save.state_dict(), output_model_file)
Example #9
                                            Config.batch_size]).to(
                                                Config.device)
            length = torch.LongTensor(lengths[i * Config.batch_size:(i + 1) *
                                              Config.batch_size]).to(
                                                  Config.device)
            # print(sentence_id.size())   # torch.Size([2, 225])
            # print(label.size())  # torch.Size([2])
            # print(length.size())  # torch.Size([2])
            logits = model(sentence_id)
            # print(logits.size())  # torch.Size([16, 2])

            pred = np.argmax(logits.cpu().data.numpy(), axis=1)
            acc = accuracy_score(label.cpu().data.numpy(), pred)

            loss = loss_func(logits, label)
            optimizer.zero_grad()  # clear gradients for this training step
            loss.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients
            out = 'Epoch:{}, steps:{}, loss:{:6f}, accuracy:{:6f}'.format(
                epoch, i, loss, acc)
            rainbow(out)
        if epoch % 2 == 0:
            acc_dev, loss_dev = evaluate(sentence_id_dev, labels_dev,
                                         lengths_dev, model)
            print("epoch:{}, dev_loss:{:6f}, dev_acc:{:6f}".format(
                epoch, loss_dev, acc_dev))
            if acc_dev > acc_best:
                acc_best = acc_dev
                torch.save(model.state_dict(), './best_model.bin')
                model.train()
Example #10
def train(args, train_features, model, tokenizer, eval_features,
          teacher_model):
    all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                 dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                   dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in train_features],
                                 dtype=torch.long)

    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                               all_label_ids)
    train_dataloader = DataLoader(train_data,
                                  shuffle=True,
                                  batch_size=args.train_batch_size)

    t_total = len(train_dataloader
                  ) // args.gradient_accumulation_steps * args.num_train_epochs

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]

    warmup_steps = 0.05 * t_total
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=t_total)

    global_step = 0
    tr_loss = 0.0
    model.zero_grad()
    set_seed(args)
    loss_mse = MSELoss()
    for epoch in range(args.num_train_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            batch = tuple(t.cuda() for t in batch)
            teacher_inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'labels': batch[3],
                'segment_ids': batch[2]
            }
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'labels': batch[3],
                'token_type_ids': batch[2]
            }

            with torch.no_grad():
                teacher_logits, layer_13_output = teacher_model(
                    **teacher_inputs)

            start_time = time.time()
            # First shrink along the depth dimension, then along the width dimension

            # accumulate grads for all sub-networks
            for depth_mult in sorted(args.depth_mult_list, reverse=True):
                model.apply(lambda m: setattr(m, 'depth_mult', depth_mult))
                n_layers = model.config.num_hidden_layers

                depth = round(depth_mult * n_layers)
                kept_layers_index = []
                for i in range(depth):
                    kept_layers_index.append(math.floor(i / depth_mult))
                kept_layers_index.append(n_layers)
                s = ''
                width_idx = 0
                for width_mult in sorted(args.width_mult_list, reverse=True):
                    model.apply(lambda m: setattr(m, 'width_mult', width_mult))

                    loss, student_logit, student_reps, _, _ = model(**inputs)
                    logit_loss = soft_cross_entropy(student_logit,
                                                    teacher_logits.detach())
                    # loss = args.width_lambda1 * logit_loss + args.width_lambda2 * rep_loss
                    # loss = logit_loss   # only distill the final logits; skip the layer-wise loss
                    # for student_rep, teacher_rep in zip(student_reps, list(layer_13_output[i] for i in kept_layers_index)):
                    #     print('------------------------------------')
                    #     print(student_reps)
                    #     print('*************************************')
                    #     print(teacher_rep)
                    #     print('------------------------------------')
                    #     # print(student_reps.size())
                    #     # print(teacher_rep.size())
                    #     exit()

                    #     tmp_loss = loss_mse(student_reps, teacher_rep.detach())
                    #     rep_loss += tmp_loss
                    # loss = logit_loss + rep_loss
                    loss = logit_loss

                    s += 'width={}: {}, '.format(width_mult, loss)
                    if args.n_gpu > 1:
                        loss = loss.mean()
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps
                    loss.backward()

                t = 'tailoring*****epoch:{}, step:{}, depth:{}, {}time:{}'.format(
                    epoch, step, len(kept_layers_index), s,
                    time.time() - start_time)
                rainbow(t)
            # clip the accumulated grad from all widths
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                # evaluate
                current_best = 0
                if global_step > 0 and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    acc = []
                    for depth_mult in sorted(args.depth_mult_list, reverse=True):
                        model.apply(lambda m: setattr(m, 'depth_mult', depth_mult))
                        for width_mult in sorted(args.width_mult_list,
                                                 reverse=True):
                            model.apply(
                                lambda m: setattr(m, 'width_mult', width_mult))
                            eval_loss, eval_accuracy = evaluate(
                                args, model, tokenizer, eval_features)
                            acc.append(eval_accuracy)
                        if sum(acc) > current_best:
                            current_best = sum(acc)
                            os.makedirs(args.save_student_model, exist_ok=True)
                            print('Saving model checkpoint to %s' %
                                  args.save_student_model)
                            model_to_save = model.module if hasattr(
                                model, 'module') else model
                            model_to_save.save_pretrained(args.save_student_model)
                            torch.save(
                                args,
                                os.path.join(args.save_student_model,
                                             'student_model.bin'))
                            model_to_save.config.to_json_file(
                                os.path.join(args.save_student_model,
                                             'config.json'))
                            tokenizer.save_vocabulary(args.save_student_model)
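soft_cross_entropy, which carries the logit-distillation loss above, is not shown. A minimal sketch of the usual definition in this kind of width/depth-adaptive distillation code (the exact reduction may differ in the original helper):

import torch.nn.functional as F


def soft_cross_entropy(predicts, targets):
    # Cross-entropy between the student distribution (from its logits) and the
    # teacher distribution (from detached teacher logits), averaged over the batch.
    student_log_probs = F.log_softmax(predicts, dim=-1)
    teacher_probs = F.softmax(targets, dim=-1)
    return (-teacher_probs * student_log_probs).sum(dim=-1).mean()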